Skip to main content

solo_api/llm/
sampling.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! [`SamplingLlmClient`] — `LlmClient` impl backed by an MCP client's
4//! `sampling/createMessage` capability.
5//!
6//! Per the v0.9.0 design (`docs/dev-log/0098-v0.9.0-implementation-plan.md`
7//! §6 "Sampling-backed LLM client" / MAJOR 1 + MAJOR 3 resolutions):
8//!
9//!   * Steward holds an `Arc<dyn LlmClient>`. When `LlmConfig::McpSampling`
10//!     is configured, the Steward's `LlmClient` is a `SamplingLlmClient`
11//!     constructed at MCP `initialize` time (when the live peer becomes
12//!     available — the `TenantHandle::steward_slot` LATE-population path).
13//!
14//!   * `SamplingLlmClient::complete()` translates the workspace's
15//!     `Message` → `rmcp::SamplingMessage`, calls
16//!     `peer.create_message(params).await`, extracts the assistant's
17//!     text from the returned `CreateMessageResult`, and emits a
18//!     per-call `AuditOperation::LlmSamplingCall` row through the
19//!     tenant's `WriteHandle` (lesson #30: sync in writer-actor tx
20//!     for ACID).
21//!
22//!   * **Privacy invariant**: the audit `details_json` carries metadata
23//!     only — model hint, message count, max_tokens, duration_ms,
24//!     total prompt character count, output character count. **The raw
25//!     prompt content MUST NOT appear in the audit row**. Pinned by
26//!     [`tests::audit_row_omits_raw_prompt_text`].
27//!
28//!   * Error paths land structured audit rows:
29//!     - Client refusal → `result = "forbidden"`,
30//!       `details_json.reason = "client_refused"`.
31//!     - Timeout → `result = "error"`,
32//!       `details_json.reason = "timeout"`.
33//!     - Other transport / malformed-response → `result = "error"`,
34//!       `details_json.reason = <category>`.
35//!
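//!   * Per-call rate-limit / coalescing is **not** implemented in this
//!     client: `SamplingLlmClient::complete()` always issues exactly one
//!     `create_message` call on its `SamplingClient`. Batching lives in
//!     `SamplingCoordinator`, which [`build_sampling_steward`] places in
//!     front of the live peer (v0.9.0 P5 wiring).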
38
39use std::sync::Arc;
40use std::time::{Duration, Instant};
41
42use async_trait::async_trait;
43use rmcp::model::{
44    CreateMessageRequestParams, CreateMessageResult, ModelHint,
45    ModelPreferences, Role as RmcpRole, SamplingMessage,
46    SamplingMessageContent,
47};
48use rmcp::service::{Peer, RoleServer, ServiceError};
49use solo_core::{Error as CoreError, LlmClient, Message, Result as CoreResult, Role};
50use solo_storage::{AuditEvent, AuditOperation, AuditResult, WriteHandle};
51
/// Upper bound applied to a single `sampling/createMessage` round-trip
/// when the caller has not overridden it.
///
/// [`SamplingLlmClient`] wraps the RPC in a bounded wait of this length,
/// so a stalled MCP client surfaces as a structured timeout error rather
/// than an indefinite hang.
///
/// The value is 30 seconds, chosen to fit the consolidate-timer's
/// cadence margins: a slower LLM call would already starve the Steward
/// batch in P4's coordinator. Override per instance with
/// [`SamplingLlmClient::with_timeout`].
pub const DEFAULT_SAMPLING_TIMEOUT: Duration = Duration::from_millis(30_000);
61
/// `max_tokens` sent with a sampling request when the caller has not
/// overridden it via [`SamplingLlmClient::with_max_tokens`].
///
/// Kept equal to
/// `solo-steward::StewardConfig::default().abstraction_max_tokens` so a
/// sampling-backed request has the same wire shape the Steward would
/// have produced against any other backend.
const DEFAULT_SAMPLING_MAX_TOKENS: u32 = 512;
67
68/// Error surface for [`SamplingClient::create_message`]. Combines the
69/// real rmcp `ServiceError` (when wrapping a live `Peer<RoleServer>`)
70/// with [`super::super::test_support::fake_mcp_client::FakeSamplingError`]
71/// (when driving the fixture from tests).
72#[derive(Debug)]
73pub enum SamplingError {
74    /// Forwarded from `rmcp::Peer::create_message`.
75    Service(ServiceError),
76    /// Routed from [`super::super::test_support::fake_mcp_client::
77    /// FakeSamplingError`] in test paths.
78    #[cfg(any(test, feature = "test-support"))]
79    Fake(crate::test_support::FakeSamplingError),
80}
81
82impl std::fmt::Display for SamplingError {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        match self {
85            Self::Service(e) => write!(f, "{e}"),
86            #[cfg(any(test, feature = "test-support"))]
87            Self::Fake(e) => write!(f, "{e}"),
88        }
89    }
90}
91
92impl std::error::Error for SamplingError {}
93
94impl SamplingError {
95    /// Classifier used by [`SamplingLlmClient::complete`] to map the
96    /// transport-level error to an audit-row category + a Solo
97    /// [`CoreError`] variant.
98    ///
99    /// `(category_for_audit, treat_as_forbidden)` — `forbidden` becomes
100    /// `AuditResult::Forbidden` + `CoreError::Forbidden`; everything
101    /// else is `AuditResult::Error` + `CoreError::Llm`.
102    pub fn classify(&self) -> (&'static str, bool) {
103        match self {
104            Self::Service(_) => ("transport_error", false),
105            #[cfg(any(test, feature = "test-support"))]
106            Self::Fake(e) => match e {
107                crate::test_support::FakeSamplingError::Refused { .. } => {
108                    ("client_refused", true)
109                }
110                crate::test_support::FakeSamplingError::Transport { .. } => {
111                    ("transport_error", false)
112                }
113                crate::test_support::FakeSamplingError::MalformedResponse {
114                    ..
115                } => ("malformed_response", false),
116            },
117        }
118    }
119}
120
121/// Trait abstracting the `sampling/createMessage` RPC. The production
122/// impl wraps `Arc<Peer<RoleServer>>`; the test impl is
123/// [`super::super::test_support::fake_mcp_client::FakeMcpClient`].
124///
125/// Separating the trait from the concrete `Peer<RoleServer>` is the
126/// way around rmcp's `Peer` having private constructors — we can't
127/// build a fake `Peer` for tests, so we inject behind a trait.
128#[async_trait]
129pub trait SamplingClient: Send + Sync {
130    async fn create_message(
131        &self,
132        params: CreateMessageRequestParams,
133    ) -> Result<CreateMessageResult, SamplingError>;
134}
135
136/// Production wrapper around `rmcp::Peer<RoleServer>`. The Peer is
137/// cheap to clone (internally `Arc`-backed) and stays valid for the
138/// lifetime of the MCP session.
139pub struct PeerSamplingClient {
140    peer: Peer<RoleServer>,
141}
142
143impl PeerSamplingClient {
144    pub fn new(peer: Peer<RoleServer>) -> Self {
145        Self { peer }
146    }
147}
148
149#[async_trait]
150impl SamplingClient for PeerSamplingClient {
151    async fn create_message(
152        &self,
153        params: CreateMessageRequestParams,
154    ) -> Result<CreateMessageResult, SamplingError> {
155        self.peer
156            .create_message(params)
157            .await
158            .map_err(SamplingError::Service)
159    }
160}
161
162/// `LlmClient` impl whose `complete()` calls back via the connected
163/// MCP client's sampling capability.
164///
165/// Construct via [`SamplingLlmClient::new`] (production path: wraps a
166/// real `Peer<RoleServer>`) or [`SamplingLlmClient::with_sampling_client`]
167/// (test path: takes the abstracted [`SamplingClient`] trait object
168/// directly so [`super::super::test_support::fake_mcp_client::
169/// FakeMcpClient`] can drive it).
170///
171/// Cheap to clone — every field is `Arc`-shared.
172#[derive(Clone)]
173pub struct SamplingLlmClient {
174    /// The RPC channel back to the MCP client.
175    sampling_client: Arc<dyn SamplingClient>,
176    /// Per-tenant `WriteHandle` for the synchronous audit emit. Routes
177    /// through the writer-actor's mpsc so the
178    /// `AuditOperation::LlmSamplingCall` INSERT lands in a dedicated
179    /// `BEGIN IMMEDIATE` transaction on the writer's connection.
180    write_handle: WriteHandle,
181    /// Cached audit `principal_subject` for the MCP session. Resolved
182    /// at session init time (see `mcp::resolve_mcp_principal`); `None`
183    /// for unauthenticated stdio sessions.
184    audit_principal: Option<String>,
185    /// `max_tokens` value sent on every `sampling/createMessage`.
186    /// Defaults to [`DEFAULT_SAMPLING_MAX_TOKENS`]; configurable via
187    /// [`Self::with_max_tokens`].
188    max_tokens: u32,
189    /// Bounded wait on `create_message`. See
190    /// [`DEFAULT_SAMPLING_TIMEOUT`].
191    timeout: Duration,
192}
193
194impl SamplingLlmClient {
195    /// Build a client wrapping a real `Peer<RoleServer>`. Production
196    /// path — called from
197    /// [`crate::mcp::SoloMcpServer::populate_sampling_steward`] when an MCP
198    /// session reaches `initialize` with a sampling-capable peer.
199    pub fn new(
200        peer: Peer<RoleServer>,
201        write_handle: WriteHandle,
202        audit_principal: Option<String>,
203    ) -> Self {
204        Self::with_sampling_client(
205            Arc::new(PeerSamplingClient::new(peer)),
206            write_handle,
207            audit_principal,
208        )
209    }
210
211    /// Test-friendly constructor accepting any [`SamplingClient`]
212    /// implementation. Pair with
213    /// [`super::super::test_support::fake_mcp_client::FakeMcpClient`]
214    /// in tests.
215    pub fn with_sampling_client(
216        sampling_client: Arc<dyn SamplingClient>,
217        write_handle: WriteHandle,
218        audit_principal: Option<String>,
219    ) -> Self {
220        Self {
221            sampling_client,
222            write_handle,
223            audit_principal,
224            max_tokens: DEFAULT_SAMPLING_MAX_TOKENS,
225            timeout: DEFAULT_SAMPLING_TIMEOUT,
226        }
227    }
228
229    /// Override the per-call `max_tokens` cap.
230    pub fn with_max_tokens(mut self, n: u32) -> Self {
231        self.max_tokens = n.max(1);
232        self
233    }
234
235    /// Override the per-call timeout.
236    pub fn with_timeout(mut self, t: Duration) -> Self {
237        self.timeout = t;
238        self
239    }
240
241    /// Build the `CreateMessageRequestParams` from Solo's `Message`
242    /// vec. Splits out `Role::System` into the `system_prompt` field
243    /// (rmcp's `SamplingMessage::role` is only User / Assistant) and
244    /// hints the user's MCP client toward a Claude-class model.
245    fn build_request(&self, messages: &[Message]) -> CreateMessageRequestParams {
246        // Split system messages out of the conversation history; the
247        // sampling protocol carries the system prompt as a top-level
248        // field rather than inline.
249        let mut system_parts: Vec<String> = Vec::new();
250        let mut samp_messages: Vec<SamplingMessage> = Vec::new();
251        for m in messages {
252            match m.role {
253                Role::System => system_parts.push(m.content.clone()),
254                Role::User => {
255                    samp_messages.push(SamplingMessage::user_text(&m.content));
256                }
257                Role::Assistant => {
258                    samp_messages
259                        .push(SamplingMessage::assistant_text(&m.content));
260                }
261            }
262        }
263        // rmcp 1.7's struct literals are non-exhaustive across crate
264        // boundaries; build via the typed constructors + builders.
265        let preferences = ModelPreferences::new()
266            .with_hints(vec![ModelHint::new("claude")])
267            .with_intelligence_priority(0.7)
268            .with_speed_priority(0.3)
269            .with_cost_priority(0.4);
270        let mut params =
271            CreateMessageRequestParams::new(samp_messages, self.max_tokens)
272                .with_model_preferences(preferences);
273        if !system_parts.is_empty() {
274            params = params.with_system_prompt(system_parts.join("\n\n"));
275        }
276        params
277    }
278
279    /// Build the audit `AuditEvent` carrying ONLY metadata. No raw
280    /// prompt content lands in `details_json`.
281    ///
282    /// Pinned by [`tests::audit_row_omits_raw_prompt_text`].
283    fn audit_event(
284        &self,
285        params: &CreateMessageRequestParams,
286        outcome: SamplingOutcome,
287    ) -> AuditEvent {
288        let prompt_chars: usize = params
289            .messages
290            .iter()
291            .flat_map(|m| m.content.iter())
292            .filter_map(|c| c.as_text().map(|t| t.text.len()))
293            .sum::<usize>()
294            + params
295                .system_prompt
296                .as_ref()
297                .map(|s| s.len())
298                .unwrap_or(0);
299        // ~4 chars per token for the rough English-text estimate used
300        // by `solo doctor --check-llm` and Anthropic's docs. Recorded
301        // for operator capacity-planning.
302        let input_tokens_est = (prompt_chars / 4) as u64;
303        let model_hint = params
304            .model_preferences
305            .as_ref()
306            .and_then(|p| p.hints.as_ref())
307            .and_then(|h| h.first())
308            .and_then(|h| h.name.clone())
309            .unwrap_or_else(|| "(none)".to_string());
310
311        let mut details = serde_json::Map::new();
312        details.insert(
313            "model_hint".to_string(),
314            serde_json::Value::String(model_hint),
315        );
316        details.insert(
317            "messages_count".to_string(),
318            serde_json::Value::Number(params.messages.len().into()),
319        );
320        details.insert(
321            "max_tokens".to_string(),
322            serde_json::Value::Number(params.max_tokens.into()),
323        );
324        details.insert(
325            "prompt_chars".to_string(),
326            serde_json::Value::Number(prompt_chars.into()),
327        );
328        details.insert(
329            "input_tokens_est".to_string(),
330            serde_json::Value::Number(input_tokens_est.into()),
331        );
332
333        let result = match &outcome {
334            SamplingOutcome::Ok {
335                duration_ms,
336                model,
337                output_chars,
338            } => {
339                let output_tokens_est = (*output_chars / 4) as u64;
340                details.insert(
341                    "duration_ms".to_string(),
342                    serde_json::Value::Number((*duration_ms).into()),
343                );
344                details.insert(
345                    "model".to_string(),
346                    serde_json::Value::String(model.clone()),
347                );
348                details.insert(
349                    "output_chars".to_string(),
350                    serde_json::Value::Number((*output_chars).into()),
351                );
352                details.insert(
353                    "output_tokens_est".to_string(),
354                    serde_json::Value::Number(output_tokens_est.into()),
355                );
356                AuditResult::Ok
357            }
358            SamplingOutcome::Forbidden {
359                reason,
360                duration_ms,
361            } => {
362                details.insert(
363                    "duration_ms".to_string(),
364                    serde_json::Value::Number((*duration_ms).into()),
365                );
366                details.insert(
367                    "reason".to_string(),
368                    serde_json::Value::String(reason.to_string()),
369                );
370                AuditResult::Forbidden
371            }
372            SamplingOutcome::Error {
373                reason,
374                duration_ms,
375            } => {
376                details.insert(
377                    "duration_ms".to_string(),
378                    serde_json::Value::Number((*duration_ms).into()),
379                );
380                details.insert(
381                    "reason".to_string(),
382                    serde_json::Value::String(reason.to_string()),
383                );
384                AuditResult::Error
385            }
386        };
387
388        AuditEvent {
389            ts_ms: chrono::Utc::now().timestamp_millis(),
390            principal_subject: self.audit_principal.clone(),
391            operation: AuditOperation::LlmSamplingCall,
392            target_id: None,
393            result,
394            details: Some(serde_json::Value::Object(details)),
395        }
396    }
397}
398
/// How a single sampling call ended, as seen by the audit-row builder
/// (`SamplingLlmClient::audit_event`). Private to this module.
enum SamplingOutcome {
    /// The RPC returned a usable assistant reply.
    Ok {
        /// Wall-clock time spent waiting on the RPC, in milliseconds.
        duration_ms: u64,
        /// Model name reported by the MCP client in its reply.
        model: String,
        /// Size of the extracted reply text, recorded as
        /// `output_chars` in the audit row.
        output_chars: usize,
    },
    /// The call was refused; audited as `AuditResult::Forbidden`.
    Forbidden {
        /// Category stored under `details_json.reason`.
        reason: &'static str,
        /// Wall-clock time spent waiting on the RPC, in milliseconds.
        duration_ms: u64,
    },
    /// Any other failure (timeout, transport, malformed reply); audited
    /// as `AuditResult::Error`.
    Error {
        /// Category stored under `details_json.reason`.
        reason: &'static str,
        /// Wall-clock time spent waiting on the RPC, in milliseconds.
        duration_ms: u64,
    },
}
415
416#[async_trait]
417impl LlmClient for SamplingLlmClient {
418    fn name(&self) -> &str {
419        "mcp-sampling"
420    }
421
422    async fn complete(&self, messages: &[Message]) -> CoreResult<Message> {
423        let params = self.build_request(messages);
424        let start = Instant::now();
425
426        // Bounded wait on `peer.create_message`. The fold of (rmcp
427        // ServiceError | FakeError | tokio timeout) into the
428        // `Outcome` enum keeps the audit path single-sourced.
429        let rpc = tokio::time::timeout(
430            self.timeout,
431            self.sampling_client.create_message(params.clone()),
432        )
433        .await;
434        let duration_ms = start.elapsed().as_millis().min(u128::from(u64::MAX))
435            as u64;
436
437        let (core_result, outcome): (CoreResult<Message>, SamplingOutcome) =
438            match rpc {
439                Ok(Ok(result)) => {
440                    match extract_text(&result) {
441                        Ok(text) => {
442                            let output_chars = text.len();
443                            let outcome = SamplingOutcome::Ok {
444                                duration_ms,
445                                model: result.model.clone(),
446                                output_chars,
447                            };
448                            (Ok(Message::assistant(text)), outcome)
449                        }
450                        Err(reason) => (
451                            Err(CoreError::llm(format!(
452                                "mcp sampling: malformed response: {reason}"
453                            ))),
454                            SamplingOutcome::Error {
455                                reason: "malformed_response",
456                                duration_ms,
457                            },
458                        ),
459                    }
460                }
461                Ok(Err(e)) => {
462                    let (category, is_forbidden) = e.classify();
463                    let outcome = if is_forbidden {
464                        SamplingOutcome::Forbidden {
465                            reason: category,
466                            duration_ms,
467                        }
468                    } else {
469                        SamplingOutcome::Error {
470                            reason: category,
471                            duration_ms,
472                        }
473                    };
474                    let err = if is_forbidden {
475                        CoreError::forbidden(format!("mcp sampling: {e}"))
476                    } else {
477                        CoreError::llm(format!("mcp sampling: {e}"))
478                    };
479                    (Err(err), outcome)
480                }
481                Err(_elapsed) => (
482                    Err(CoreError::llm(format!(
483                        "mcp sampling: timeout after {}ms",
484                        duration_ms
485                    ))),
486                    SamplingOutcome::Error {
487                        reason: "timeout",
488                        duration_ms,
489                    },
490                ),
491            };
492
493        // Synchronous audit emit through the writer-actor (lesson #30).
494        // Failure to land the audit row is operator-visible: the
495        // sampling call's caller sees the storage error and can decide
496        // whether to abort (we DO abort here — without the audit row
497        // we have no record of the call).
498        let event = self.audit_event(&params, outcome);
499        if let Err(audit_err) = self.write_handle.emit_llm_sampling_audit(event).await
500        {
501            // If the RPC itself succeeded but the audit failed, treat
502            // the call as failed. The audit row IS the persisted
503            // record of the call; an undocumented sampling call is
504            // not acceptable.
505            return Err(CoreError::storage(format!(
506                "mcp sampling: audit emit failed: {audit_err}"
507            )));
508        }
509
510        core_result
511    }
512}
513
514/// Pull the assistant's text out of the rmcp result. Walks every text
515/// content block in the message (the spec allows either a single
516/// `SamplingContent::Single` or a `SamplingContent::Multiple`) and
517/// concatenates them with newlines. Returns `Err(reason)` if no text
518/// blocks were present — the malformed-response path.
519fn extract_text(result: &CreateMessageResult) -> Result<String, &'static str> {
520    if result.message.role != RmcpRole::Assistant {
521        return Err("response role was not Assistant");
522    }
523    let mut out = String::new();
524    for content in result.message.content.iter() {
525        if let SamplingMessageContent::Text(text) = content {
526            if !out.is_empty() {
527                out.push('\n');
528            }
529            out.push_str(&text.text);
530        }
531    }
532    if out.is_empty() {
533        Err("no text content blocks")
534    } else {
535        Ok(out)
536    }
537}
538
539/// v0.9.0 P2: build a sampling-backed `Arc<Steward>` for a tenant that
540/// has resolved `LlmConfig::McpSampling` and just attached an MCP
541/// session.
542///
543/// Called from [`crate::mcp::SoloMcpServer::populate_sampling_steward`] at
544/// MCP `initialize` time once the peer's sampling capability is
545/// confirmed. The returned `Arc<Steward>` is written into
546/// `tenant.steward_slot()` so the writer-actor + consolidate timer
547/// can read a populated slot on their next tick.
548///
549/// v0.9.0 P5 (M3 wiring): the live `PeerSamplingClient` is now wrapped
550/// in a [`super::SamplingCoordinator`] before being handed to
551/// `SamplingLlmClient`. Concurrent `complete()` calls within the
552/// coalesce window collapse into one `peer.create_message` RPC and the
553/// response is demultiplexed back per-task — matching the
554/// `[sampling] coalesce_window_ms` / `coalesce_max_requests` config the
555/// operator wrote in `solo.config.toml`. Per-call audit emit semantics
556/// are unchanged: every logical request still lands one
557/// `AuditOperation::LlmSamplingCall` row, no raw prompt content escapes
558/// to the audit row.
559///
560/// Edge case (clamping): the `[sampling]` block accepts values that
561/// effectively disable batching — `coalesce_max_requests = 1` and / or
562/// `coalesce_window_ms = 0` reduce the coordinator to pass-through (one
563/// inner call per submission). The coordinator's
564/// [`super::SamplingCoordinator::with_settings`] clamps `max_batch` to
565/// `max(1)` so a zero value still produces a single-element flush
566/// immediately rather than panicking or deadlocking.
567pub fn build_sampling_steward(
568    peer: Peer<RoleServer>,
569    write_handle: WriteHandle,
570    audit_principal: Option<String>,
571    steward_config: solo_steward::StewardConfig,
572    sampling_config: solo_storage::SamplingConfig,
573) -> Arc<solo_steward::Steward> {
574    let inner: Arc<dyn SamplingClient> = Arc::new(PeerSamplingClient::new(peer));
575    let coordinator: Arc<dyn SamplingClient> = super::SamplingCoordinator::with_settings(
576        inner,
577        std::time::Duration::from_millis(sampling_config.coalesce_window_ms),
578        sampling_config.coalesce_max_requests as usize,
579    );
580    let client = SamplingLlmClient::with_sampling_client(
581        coordinator,
582        write_handle,
583        audit_principal,
584    )
585    .with_max_tokens(steward_config.abstraction_max_tokens.min(65_536) as u32);
586    Arc::new(solo_steward::Steward::new(Arc::new(client), steward_config))
587}
588
589#[cfg(test)]
590mod tests {
591    use super::*;
592    use crate::test_support::{FakeMcpClient, FakeResponse, FakeSamplingError};
593    use rmcp::model::CreateMessageResult;
594    use solo_core::TenantId;
595    use solo_storage::{
596        EmbedderConfig, HnswParams, InitParams, KeyMaterial, StubEmbedder,
597        TenantHandle, TenantRegistry, TenantRegistryParams, init,
598        open_sqlcipher,
599    };
600    use std::path::PathBuf;
601    use std::sync::Arc;
602    use tempfile::TempDir;
603    use zeroize::Zeroizing;
604
605    const TEST_PASSPHRASE: &str = "v0.9.0-p2-sampling-tests";
606
607    /// Bootstrap a per-tenant `TenantHandle` whose writer-actor accepts
608    /// the new `WriteCommand::EmitLlmSamplingAudit` variant.
609    ///
610    /// Mirrors the v0.8.x test discipline (see
611    /// `crates/solo-storage/src/tenants/handle_registry_tests.rs`'s
612    /// `fresh_init_dir`): build a real tenant DB on disk via the same
613    /// `init()` helper users invoke, wrap in a `TenantRegistry`, and
614    /// surface the `WriteHandle` for direct `SamplingLlmClient`
615    /// wiring.
616    struct Harness {
617        _tmp: TempDir,
618        _registry: Arc<TenantRegistry>,
619        _tenant: Arc<TenantHandle>,
620        write_handle: solo_storage::WriteHandle,
621        db_path: PathBuf,
622        key: KeyMaterial,
623    }
624
625    async fn harness() -> Harness {
626        let tmp = TempDir::new().expect("tempdir");
627        let data_dir = tmp.path().to_path_buf();
628        let _ = init(InitParams {
629            data_dir: data_dir.clone(),
630            passphrase: Zeroizing::new(TEST_PASSPHRASE.into()),
631            force: false,
632            embedder: EmbedderConfig {
633                name: "stub".into(),
634                version: "v1".into(),
635                dim: 32,
636                dtype: "f32".into(),
637            },
638        })
639        .expect("init");
640
641        let cfg = solo_storage::SoloConfig::read(
642            &data_dir.join("solo.config.toml"),
643        )
644        .expect("read cfg");
645        let key = KeyMaterial::derive(
646            TEST_PASSPHRASE,
647            &cfg.salt_bytes().expect("salt"),
648        )
649        .expect("derive key");
650
651        let embedder: Arc<dyn solo_core::Embedder> =
652            Arc::new(StubEmbedder::new("stub", "v1", 32));
653        let registry = Arc::new(
654            TenantRegistry::open(TenantRegistryParams {
655                data_dir: data_dir.clone(),
656                key: key.clone(),
657                embedder: embedder.clone(),
658                hnsw_params: HnswParams::default(),
659                steward: None,
660                runtime_handle: Some(tokio::runtime::Handle::current()),
661                steward_factory: None,
662                triples_batch_signal: None,
663            })
664            .expect("open registry"),
665        );
666
667        let tenant_id = TenantId::default_tenant();
668        let tenant = registry
669            .get_or_open(&tenant_id)
670            .await
671            .expect("get_or_open default tenant");
672        let write_handle = tenant.write().clone();
673        let db_path = tenant.db_path().to_path_buf();
674
675        Harness {
676            _tmp: tmp,
677            _registry: registry,
678            _tenant: tenant,
679            write_handle,
680            db_path,
681            key,
682        }
683    }
684
685    /// Helper: count the `audit_events` rows whose `operation` is the
686    /// given string. Opens a fresh connection to the tenant DB so we
687    /// avoid contention with the writer-actor's own connection.
688    fn count_audit_rows(db_path: &std::path::Path, key: &KeyMaterial, op: &str) -> i64 {
689        let conn = open_sqlcipher(db_path, key).expect("open db");
690        conn.query_row(
691            "SELECT COUNT(*) FROM audit_events WHERE operation = ?",
692            rusqlite::params![op],
693            |r| r.get(0),
694        )
695        .expect("count")
696    }
697
698    /// Helper: load the most-recent `llm.sampling_call` audit row and
699    /// return `(result, principal_subject, details_json)`.
700    fn latest_sampling_audit_details(
701        db_path: &std::path::Path,
702        key: &KeyMaterial,
703    ) -> (String, Option<String>, serde_json::Value) {
704        let conn = open_sqlcipher(db_path, key).expect("open db");
705        let (result, principal, details_str): (String, Option<String>, Option<String>) = conn
706            .query_row(
707                "SELECT result, principal_subject, details_json
708                 FROM audit_events
709                 WHERE operation = 'llm.sampling_call'
710                 ORDER BY ts_ms DESC, rowid DESC
711                 LIMIT 1",
712                [],
713                |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
714            )
715            .expect("query");
716        let details: serde_json::Value =
717            serde_json::from_str(&details_str.expect("details_json present"))
718                .expect("parse details");
719        (result, principal, details)
720    }
721
722    /// Happy path: a successful `create_message` round-trip returns
723    /// the assistant text wrapped in a `Message::assistant`, and lands
724    /// exactly one `llm.sampling_call` audit row with `result = 'ok'`.
725    #[tokio::test]
726    async fn sampling_complete_happy_path_returns_text() {
727        let h = harness().await;
728        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("derived theme")));
729        let client = SamplingLlmClient::with_sampling_client(
730            fake.clone(),
731            h.write_handle.clone(),
732            Some("alice".into()),
733        );
734        let messages = vec![Message::user("summarise these episodes")];
735        let result = client.complete(&messages).await.expect("ok");
736        assert_eq!(result.role, Role::Assistant);
737        assert_eq!(result.content, "derived theme");
738
739        // Exactly one audit row landed.
740        assert_eq!(
741            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
742            1
743        );
744        let (result_str, principal, details) =
745            latest_sampling_audit_details(&h.db_path, &h.key);
746        assert_eq!(result_str, "ok");
747        assert_eq!(principal.as_deref(), Some("alice"));
748        assert_eq!(details["model_hint"], "claude");
749        assert_eq!(details["model"], "fake-claude");
750        assert_eq!(details["messages_count"], 1);
751        assert_eq!(details["max_tokens"], 512);
752    }
753
754    /// Privacy invariant: the audit row's `details_json` MUST NOT
755    /// contain the raw prompt content. Pinned by string inspection of
756    /// the persisted JSON.
757    #[tokio::test]
758    async fn audit_row_omits_raw_prompt_text() {
759        let h = harness().await;
760        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
761        let client = SamplingLlmClient::with_sampling_client(
762            fake,
763            h.write_handle.clone(),
764            None,
765        );
766        let secret = "THE-USER-ID-IS-bobby-1234";
767        let messages = vec![
768            Message::system("you are a friendly assistant"),
769            Message::user(secret),
770        ];
771        client.complete(&messages).await.expect("ok");
772
773        let (_, _, details) =
774            latest_sampling_audit_details(&h.db_path, &h.key);
775        let serialised =
776            serde_json::to_string(&details).expect("serialise details");
777        assert!(
778            !serialised.contains(secret),
779            "audit details must not carry raw prompt content; was: {serialised}"
780        );
781        assert!(
782            !serialised.contains("you are a friendly assistant"),
783            "audit details must not carry system prompt; was: {serialised}"
784        );
785        // Metadata IS present, even though the prompt is not.
786        assert_eq!(details["messages_count"], 1);
787        assert!(details["prompt_chars"].as_u64().unwrap() > 0);
788    }
789
790    /// Client refusal: maps to `CoreError::Forbidden` + audit row
791    /// `result = 'forbidden'` + `details_json.reason = 'client_refused'`.
792    #[tokio::test]
793    async fn client_refusal_returns_forbidden_and_audits() {
794        let h = harness().await;
795        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ignored")));
796        fake.reject_with("user dismissed approval");
797        let client = SamplingLlmClient::with_sampling_client(
798            fake,
799            h.write_handle.clone(),
800            Some("alice".into()),
801        );
802        let err = client
803            .complete(&[Message::user("anything")])
804            .await
805            .unwrap_err();
806        match err {
807            CoreError::Forbidden(_) => {}
808            other => panic!("expected Forbidden, got {other:?}"),
809        }
810        let (result_str, _, details) =
811            latest_sampling_audit_details(&h.db_path, &h.key);
812        assert_eq!(result_str, "forbidden");
813        assert_eq!(details["reason"], "client_refused");
814    }
815
816    /// Timeout: tokio::time::timeout fires before the fake's `Slow`
817    /// response resolves; client returns `CoreError::Llm` + audit row
818    /// `result = 'error'` + `details_json.reason = 'timeout'`.
819    ///
820    /// Real wall-clock: 80ms slow response vs 30ms client timeout.
821    /// Margin is loose enough for slow CI without making the test
822    /// drag.
823    #[tokio::test]
824    async fn timeout_returns_error_with_timeout_reason() {
825        let h = harness().await;
826        let fake = Arc::new(FakeMcpClient::new(FakeResponse::slow(
827            "late",
828            Duration::from_millis(800),
829        )));
830        let client = SamplingLlmClient::with_sampling_client(
831            fake,
832            h.write_handle.clone(),
833            None,
834        )
835        .with_timeout(Duration::from_millis(30));
836        let err = client
837            .complete(&[Message::user("hello")])
838            .await
839            .unwrap_err();
840        match err {
841            CoreError::Llm(msg) => assert!(msg.contains("timeout")),
842            other => panic!("expected Llm, got {other:?}"),
843        }
844        let (result_str, _, details) =
845            latest_sampling_audit_details(&h.db_path, &h.key);
846        assert_eq!(result_str, "error");
847        assert_eq!(details["reason"], "timeout");
848    }
849
850    /// Malformed response: the fake returns a result with zero text
851    /// content blocks; client surfaces `CoreError::Llm` + audit row
852    /// `result = 'error'` + `details_json.reason = 'malformed_response'`.
853    #[tokio::test]
854    async fn malformed_response_returns_error_with_reason() {
855        let h = harness().await;
856        let fake = Arc::new(FakeMcpClient::new(FakeResponse::EmptyContent));
857        let client = SamplingLlmClient::with_sampling_client(
858            fake,
859            h.write_handle.clone(),
860            None,
861        );
862        let err = client
863            .complete(&[Message::user("hi")])
864            .await
865            .unwrap_err();
866        assert!(matches!(err, CoreError::Llm(_)));
867        let (result_str, _, details) =
868            latest_sampling_audit_details(&h.db_path, &h.key);
869        assert_eq!(result_str, "error");
870        assert_eq!(details["reason"], "malformed_response");
871    }
872
873    /// `principal_subject = None` works — audit row still emits with
874    /// NULL.
875    #[tokio::test]
876    async fn no_principal_emits_audit_with_null_principal() {
877        let h = harness().await;
878        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
879        let client = SamplingLlmClient::with_sampling_client(
880            fake,
881            h.write_handle.clone(),
882            None,
883        );
884        client.complete(&[Message::user("hi")]).await.expect("ok");
885        let (_, principal, _) =
886            latest_sampling_audit_details(&h.db_path, &h.key);
887        assert_eq!(principal, None);
888    }
889
890    /// Concurrency: 8 parallel `complete()` calls land 8 audit rows.
891    /// Audit IDs (autoincrement rowid) must be distinct — verifies the
892    /// writer-actor serialises the per-call audit emit (no
893    /// interleaving / dropped rows).
894    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
895    async fn parallel_completes_serialise_audit_rows() {
896        let h = harness().await;
897        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
898        let client = SamplingLlmClient::with_sampling_client(
899            fake.clone(),
900            h.write_handle.clone(),
901            Some("alice".into()),
902        );
903        let mut futs = Vec::new();
904        for _ in 0..8 {
905            let c = client.clone();
906            futs.push(tokio::spawn(async move {
907                c.complete(&[Message::user("hi")]).await
908            }));
909        }
910        for f in futs {
911            f.await.expect("join").expect("ok");
912        }
913        assert_eq!(
914            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
915            8,
916            "8 parallel calls must land 8 audit rows"
917        );
918
919        // Each was a separate request to the fake.
920        assert_eq!(fake.record_requests().len(), 8);
921    }
922
923    /// `complete` translates the workspace's `Message::system` into the
924    /// `system_prompt` top-level field; user/assistant roles map to
925    /// rmcp's `SamplingMessage::user_text` / `assistant_text`.
926    #[tokio::test]
927    async fn build_request_splits_system_from_messages() {
928        let h = harness().await;
929        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
930        let client = SamplingLlmClient::with_sampling_client(
931            fake.clone(),
932            h.write_handle.clone(),
933            None,
934        );
935        client
936            .complete(&[
937                Message::system("be terse"),
938                Message::user("question"),
939                Message::assistant("answer"),
940            ])
941            .await
942            .expect("ok");
943        let recorded = fake.record_requests();
944        assert_eq!(recorded.len(), 1);
945        let req = &recorded[0];
946        assert_eq!(
947            req.system_prompt.as_deref(),
948            Some("be terse"),
949            "Role::System must map to system_prompt"
950        );
951        assert_eq!(req.messages.len(), 2);
952        // The remaining two messages are the user + assistant turns.
953        assert_eq!(req.messages[0].role, RmcpRole::User);
954        assert_eq!(req.messages[1].role, RmcpRole::Assistant);
955    }
956
957    /// `model_preferences` carries the `claude` hint per plan §6.
958    /// Pins the wire shape so a future change is a conscious decision.
959    #[tokio::test]
960    async fn build_request_includes_claude_model_hint() {
961        let h = harness().await;
962        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
963        let client = SamplingLlmClient::with_sampling_client(
964            fake.clone(),
965            h.write_handle.clone(),
966            None,
967        );
968        client
969            .complete(&[Message::user("hi")])
970            .await
971            .expect("ok");
972        let recorded = fake.record_requests();
973        let prefs = recorded[0].model_preferences.as_ref().expect("prefs");
974        let hint = prefs
975            .hints
976            .as_ref()
977            .and_then(|h| h.first())
978            .and_then(|h| h.name.clone())
979            .expect("hint name");
980        assert_eq!(hint, "claude");
981    }
982
983    /// `with_max_tokens(n)` propagates to the request's
984    /// `max_tokens` field.
985    #[tokio::test]
986    async fn with_max_tokens_overrides_default() {
987        let h = harness().await;
988        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
989        let client = SamplingLlmClient::with_sampling_client(
990            fake.clone(),
991            h.write_handle.clone(),
992            None,
993        )
994        .with_max_tokens(2048);
995        client
996            .complete(&[Message::user("hi")])
997            .await
998            .expect("ok");
999        let recorded = fake.record_requests();
1000        assert_eq!(recorded[0].max_tokens, 2048);
1001    }
1002
1003    /// Reconfiguring the fake mid-test produces distinct audit rows
1004    /// for each call (positive then negative).
1005    #[tokio::test]
1006    async fn reconfigurable_fake_distinguishes_audit_rows() {
1007        let h = harness().await;
1008        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1009        let client = SamplingLlmClient::with_sampling_client(
1010            fake.clone(),
1011            h.write_handle.clone(),
1012            Some("alice".into()),
1013        );
1014
1015        client.complete(&[Message::user("a")]).await.expect("ok");
1016        fake.reject_with("user said no");
1017        let _ = client.complete(&[Message::user("b")]).await;
1018
1019        let conn = open_sqlcipher(&h.db_path, &h.key).expect("open");
1020        let mut stmt = conn
1021            .prepare(
1022                "SELECT result FROM audit_events WHERE operation = 'llm.sampling_call' ORDER BY ts_ms ASC, rowid ASC",
1023            )
1024            .expect("prepare");
1025        let rows: Vec<String> = stmt
1026            .query_map([], |r| r.get::<_, String>(0))
1027            .expect("query")
1028            .map(|r| r.expect("row"))
1029            .collect();
1030        assert_eq!(rows, vec!["ok".to_string(), "forbidden".to_string()]);
1031    }
1032
/// `extract_text` returns the text of a single-block assistant message.
#[test]
fn extract_text_pulls_text_from_single_block() {
    let message = SamplingMessage::assistant_text("hello");
    let result = CreateMessageResult::new(message, "fake".into());
    let text = extract_text(&result).unwrap();
    assert_eq!(text, "hello");
}
1042
/// `extract_text` fails on an assistant message that has no content
/// blocks at all.
#[test]
fn extract_text_rejects_empty_content() {
    let message = SamplingMessage::new_multiple(RmcpRole::Assistant, Vec::new());
    let result = CreateMessageResult::new(message, "fake".into());
    let outcome = extract_text(&result);
    assert!(outcome.is_err());
}
1052
/// `extract_text` fails when the returned message has the User role.
/// This should not occur per the spec; the test pins the defensive
/// check.
#[test]
fn extract_text_rejects_non_assistant_role() {
    let message = SamplingMessage::user_text("hello");
    let result = CreateMessageResult::new(message, "fake".into());
    let outcome = extract_text(&result);
    assert!(outcome.is_err());
}
1063
/// `SamplingError::classify` returns an (audit category, forbidden
/// flag) pair; each fake variant must map to its own category, and
/// only a refusal sets the forbidden flag.
#[test]
fn sampling_error_classify_maps_fake_variants() {
    // Refusal -> "client_refused", flagged as forbidden.
    let refused = SamplingError::Fake(FakeSamplingError::Refused { reason: "x".into() });
    let (category, is_forbidden) = refused.classify();
    assert_eq!(category, "client_refused");
    assert!(is_forbidden);

    // Transport failure -> "transport_error", not forbidden.
    let transport = SamplingError::Fake(FakeSamplingError::Transport { message: "x".into() });
    let (category, is_forbidden) = transport.classify();
    assert_eq!(category, "transport_error");
    assert!(!is_forbidden);

    // Malformed response -> "malformed_response", not forbidden.
    let malformed =
        SamplingError::Fake(FakeSamplingError::MalformedResponse { message: "x".into() });
    let (category, is_forbidden) = malformed.classify();
    assert_eq!(category, "malformed_response");
    assert!(!is_forbidden);
}
1090
1091    // -------- v0.9.0 P5a (M3 wiring) — SamplingCoordinator integration --------
1092    //
1093    // These tests pin the contract that `build_sampling_steward` wraps the
1094    // live peer in a `SamplingCoordinator` before handing it to
1095    // `SamplingLlmClient`. They cannot call `build_sampling_steward`
1096    // directly (it takes a real `Peer<RoleServer>` whose constructors are
1097    // private inside rmcp), but they exercise the **exact same wiring
1098    // shape** by substituting `FakeMcpClient` for `PeerSamplingClient`.
1099    // The production code path is:
1100    //
1101    //     PeerSamplingClient -> SamplingCoordinator -> SamplingLlmClient
1102    //
1103    // The tested shape is:
1104    //
1105    //     FakeMcpClient      -> SamplingCoordinator -> SamplingLlmClient
1106    //
1107    // Only the leaf `SamplingClient` impl differs; the
1108    // `SamplingClient` trait is the same Arc-of-dyn in both paths.
1109
1110    /// SamplingCoordinator wrapping a `FakeMcpClient` and feeding
1111    /// `SamplingLlmClient::with_sampling_client` is the same Arc-of-dyn
1112    /// shape `build_sampling_steward` constructs at MCP-initialize
1113    /// time. Single-element flushes pass through unwrapped, so a lone
1114    /// `complete()` call still emits one audit row and produces the
1115    /// expected text.
1116    #[tokio::test]
1117    async fn sampling_llm_client_uses_coordinator_in_production_path() {
1118        let h = harness().await;
1119        let fake: Arc<dyn SamplingClient> =
1120            Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1121        let coord: Arc<dyn SamplingClient> =
1122            super::super::SamplingCoordinator::with_settings(
1123                fake.clone(),
1124                Duration::from_millis(50),
1125                10,
1126            );
1127        let client = SamplingLlmClient::with_sampling_client(
1128            coord,
1129            h.write_handle.clone(),
1130            Some("alice".into()),
1131        );
1132        let result = client
1133            .complete(&[Message::user("test")])
1134            .await
1135            .expect("ok");
1136        assert_eq!(result.role, Role::Assistant);
1137        assert_eq!(result.content, "ok");
1138        // Single audit row landed — per-call audit semantics
1139        // unchanged by the coordinator wrap.
1140        assert_eq!(
1141            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1142            1,
1143            "one logical call → one audit row, even through coordinator"
1144        );
1145    }
1146
1147    /// End-to-end batching pin: N concurrent `complete()` calls within
1148    /// the coalesce window resolve as ONE inner `create_message` RPC
1149    /// on the underlying `FakeMcpClient`, but N audit rows still land
1150    /// (one per logical call — the privacy + audit invariants from P2
1151    /// hold).
1152    ///
1153    /// This is the v0.9.0 release notes' "⌈N/M⌉ peer.create_message
1154    /// calls per coalesce window" claim, exercised through the same
1155    /// trait-object chain that `build_sampling_steward` constructs in
1156    /// production.
1157    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1158    async fn coordinator_coalesces_concurrent_calls_into_one_inner_rpc() {
1159        // Coalesced JSON response for 5 tasks — matches the
1160        // `[{task_index, response}]` shape `flush_batch` demuxes
1161        // multi-element batches into.
1162        let response = serde_json::to_string(&(0..5)
1163            .map(|i| serde_json::json!({
1164                "task_index": i,
1165                "response": format!("response-{i}"),
1166            }))
1167            .collect::<Vec<_>>())
1168            .unwrap();
1169
1170        let h = harness().await;
1171        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
1172        let coord: Arc<dyn SamplingClient> =
1173            super::super::SamplingCoordinator::with_settings(
1174                fake.clone(),
1175                // Wide window so all 5 submissions land in one batch.
1176                Duration::from_secs(5),
1177                10,
1178            );
1179        let client = SamplingLlmClient::with_sampling_client(
1180            coord,
1181            h.write_handle.clone(),
1182            Some("alice".into()),
1183        );
1184
1185        // Fire 5 concurrent `complete()` calls; the coordinator should
1186        // coalesce them into ONE `FakeMcpClient::create_message` call.
1187        let mut futs = Vec::new();
1188        for i in 0..5 {
1189            let c = client.clone();
1190            futs.push(tokio::spawn(async move {
1191                c.complete(&[Message::user(format!("task-{i}"))]).await
1192            }));
1193        }
1194        for f in futs {
1195            f.await.expect("join").expect("ok");
1196        }
1197
1198        // EXACTLY one inner RPC.
1199        assert_eq!(
1200            fake.record_requests().len(),
1201            1,
1202            "5 logical calls within window must coalesce to 1 inner RPC"
1203        );
1204        // BUT 5 audit rows — per-logical-call audit invariant preserved.
1205        assert_eq!(
1206            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1207            5,
1208            "5 logical calls → 5 audit rows (coordinator doesn't merge audits)"
1209        );
1210    }
1211
1212    /// Edge case: `coalesce_max_requests = 1` reduces the coordinator
1213    /// to pass-through (each submit flushes a 1-element batch
1214    /// immediately). With max_batch=1 and a wide window, 3 concurrent
1215    /// calls land 3 inner RPCs — coordinator is operating as if no
1216    /// batching were configured.
1217    ///
1218    /// Pins the brief's documented edge-case: zero / one-valued config
1219    /// reduces to pass-through, never panics or deadlocks. Mirrors
1220    /// `SamplingCoordinator::with_settings`'s `max_batch.max(1)`
1221    /// clamping for the `coalesce_max_requests = 0` case.
1222    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1223    async fn coordinator_max_batch_one_acts_as_passthrough() {
1224        let h = harness().await;
1225        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1226        let coord: Arc<dyn SamplingClient> =
1227            super::super::SamplingCoordinator::with_settings(
1228                fake.clone(),
1229                Duration::from_secs(5),
1230                // max_batch=1 → every submission flushes immediately as
1231                // a 1-element batch; pass-through behaviour.
1232                1,
1233            );
1234        let client = SamplingLlmClient::with_sampling_client(
1235            coord,
1236            h.write_handle.clone(),
1237            None,
1238        );
1239        let mut futs = Vec::new();
1240        for _ in 0..3 {
1241            let c = client.clone();
1242            futs.push(tokio::spawn(async move {
1243                c.complete(&[Message::user("hi")]).await
1244            }));
1245        }
1246        for f in futs {
1247            f.await.expect("join").expect("ok");
1248        }
1249        // 3 logical calls → 3 inner RPCs (no coalescing).
1250        assert_eq!(
1251            fake.record_requests().len(),
1252            3,
1253            "max_batch=1 must pass through every submission as its own RPC"
1254        );
1255        assert_eq!(
1256            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1257            3
1258        );
1259    }
1260}