Skip to main content

solo_api/llm/
sampling.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! [`SamplingLlmClient`] — `LlmClient` impl backed by an MCP client's
4//! `sampling/createMessage` capability.
5//!
6//! Per the v0.9.0 design (`docs/dev-log/0098-v0.9.0-implementation-plan.md`
7//! §6 "Sampling-backed LLM client" / MAJOR 1 + MAJOR 3 resolutions):
8//!
9//!   * Steward holds an `Arc<dyn LlmClient>`. When `LlmConfig::McpSampling`
10//!     is configured, the Steward's `LlmClient` is a `SamplingLlmClient`
11//!     constructed at MCP `initialize` time (when the live peer becomes
12//!     available — the `TenantHandle::steward_slot` LATE-population path).
13//!
14//!   * `SamplingLlmClient::complete()` translates the workspace's
15//!     `Message` → `rmcp::SamplingMessage`, calls
16//!     `peer.create_message(params).await`, extracts the assistant's
17//!     text from the returned `CreateMessageResult`, and emits a
18//!     per-call `AuditOperation::LlmSamplingCall` row through the
19//!     tenant's `WriteHandle` (lesson #30: sync in writer-actor tx
20//!     for ACID).
21//!
//!   * **Privacy invariant**: the audit `details_json` carries metadata
//!     only — model hint, message count, max_tokens, duration_ms, and
//!     power-of-two-bucketed prompt / output sizes (`prompt_chars`,
//!     `output_chars`, plus their token estimates). **The raw prompt
//!     content MUST NOT appear in the audit row**. Pinned by
//!     [`tests::audit_row_omits_raw_prompt_text`].
27//!
28//!   * Error paths land structured audit rows:
29//!     - Client refusal → `result = "forbidden"`,
30//!       `details_json.reason = "client_refused"`.
31//!     - Timeout → `result = "error"`,
32//!       `details_json.reason = "timeout"`.
33//!     - Other transport / malformed-response → `result = "error"`,
34//!       `details_json.reason = <category>`.
35//!
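//!   * Per-call rate-limit / coalescing is not done by
//!     `SamplingLlmClient` itself, which only implements the per-call
//!     path. Coalescing lives in `SamplingCoordinator`, which
//!     [`build_sampling_steward`] wraps around the live peer (v0.9.0 P5).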
38
39use std::sync::Arc;
40use std::time::{Duration, Instant};
41
42use async_trait::async_trait;
43use rmcp::model::{
44    CreateMessageRequestParams, CreateMessageResult, ModelHint, ModelPreferences, Role as RmcpRole,
45    SamplingMessage, SamplingMessageContent,
46};
47use rmcp::service::{Peer, RoleServer, ServiceError};
48use solo_core::{Error as CoreError, LlmClient, Message, Result as CoreResult, Role};
49use solo_storage::{AuditEvent, AuditOperation, AuditResult, WriteHandle};
50
/// Upper bound on a single `sampling/createMessage` round-trip when no
/// explicit timeout has been configured.
///
/// [`SamplingLlmClient`] wraps the RPC in a bounded wait of this length,
/// so a client that stalls produces a structured timeout error instead
/// of an indefinite hang.
///
/// The 30-second figure follows the consolidate-timer's cadence margins:
/// an LLM call slower than this would already starve the Steward batch
/// in the coordinator. Override per instance with
/// [`SamplingLlmClient::with_timeout`].
pub const DEFAULT_SAMPLING_TIMEOUT: Duration = Duration::from_secs(30);
60
/// `max_tokens` sent with every sampling request unless the client was
/// built with [`SamplingLlmClient::with_max_tokens`].
///
/// Kept equal to `solo-steward::StewardConfig::default().abstraction_max_tokens`
/// so the request has the same wire shape the Steward would have asked
/// of any other backend.
const DEFAULT_SAMPLING_MAX_TOKENS: u32 = 512;
66
67/// Error surface for [`SamplingClient::create_message`]. Combines the
68/// real rmcp `ServiceError` (when wrapping a live `Peer<RoleServer>`)
69/// with [`super::super::test_support::fake_mcp_client::FakeSamplingError`]
70/// (when driving the fixture from tests).
71#[derive(Debug)]
72pub enum SamplingError {
73    /// Forwarded from `rmcp::Peer::create_message`.
74    Service(ServiceError),
75    /// Routed from [`super::super::test_support::fake_mcp_client::
76    /// FakeSamplingError`] in test paths.
77    #[cfg(any(test, feature = "test-support"))]
78    Fake(crate::test_support::FakeSamplingError),
79}
80
81impl std::fmt::Display for SamplingError {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        match self {
84            Self::Service(e) => write!(f, "{e}"),
85            #[cfg(any(test, feature = "test-support"))]
86            Self::Fake(e) => write!(f, "{e}"),
87        }
88    }
89}
90
91impl std::error::Error for SamplingError {}
92
93impl SamplingError {
94    /// Classifier used by [`SamplingLlmClient::complete`] to map the
95    /// transport-level error to an audit-row category + a Solo
96    /// [`CoreError`] variant.
97    ///
98    /// `(category_for_audit, treat_as_forbidden)` — `forbidden` becomes
99    /// `AuditResult::Forbidden` + `CoreError::Forbidden`; everything
100    /// else is `AuditResult::Error` + `CoreError::Llm`.
101    pub fn classify(&self) -> (&'static str, bool) {
102        match self {
103            Self::Service(_) => ("transport_error", false),
104            #[cfg(any(test, feature = "test-support"))]
105            Self::Fake(e) => match e {
106                crate::test_support::FakeSamplingError::Refused { .. } => ("client_refused", true),
107                crate::test_support::FakeSamplingError::Transport { .. } => {
108                    ("transport_error", false)
109                }
110                crate::test_support::FakeSamplingError::MalformedResponse { .. } => {
111                    ("malformed_response", false)
112                }
113            },
114        }
115    }
116}
117
118/// Trait abstracting the `sampling/createMessage` RPC. The production
119/// impl wraps `Arc<Peer<RoleServer>>`; the test impl is
120/// [`super::super::test_support::fake_mcp_client::FakeMcpClient`].
121///
122/// Separating the trait from the concrete `Peer<RoleServer>` is the
123/// way around rmcp's `Peer` having private constructors — we can't
124/// build a fake `Peer` for tests, so we inject behind a trait.
125#[async_trait]
126pub trait SamplingClient: Send + Sync {
127    async fn create_message(
128        &self,
129        params: CreateMessageRequestParams,
130    ) -> Result<CreateMessageResult, SamplingError>;
131}
132
133/// Production wrapper around `rmcp::Peer<RoleServer>`. The Peer is
134/// cheap to clone (internally `Arc`-backed) and stays valid for the
135/// lifetime of the MCP session.
136pub struct PeerSamplingClient {
137    peer: Peer<RoleServer>,
138}
139
140impl PeerSamplingClient {
141    pub fn new(peer: Peer<RoleServer>) -> Self {
142        Self { peer }
143    }
144}
145
146#[async_trait]
147impl SamplingClient for PeerSamplingClient {
148    async fn create_message(
149        &self,
150        params: CreateMessageRequestParams,
151    ) -> Result<CreateMessageResult, SamplingError> {
152        self.peer
153            .create_message(params)
154            .await
155            .map_err(SamplingError::Service)
156    }
157}
158
159/// `LlmClient` impl whose `complete()` calls back via the connected
160/// MCP client's sampling capability.
161///
162/// Construct via [`SamplingLlmClient::new`] (production path: wraps a
163/// real `Peer<RoleServer>`) or [`SamplingLlmClient::with_sampling_client`]
164/// (test path: takes the abstracted [`SamplingClient`] trait object
165/// directly so [`super::super::test_support::fake_mcp_client::
166/// FakeMcpClient`] can drive it).
167///
168/// Cheap to clone — every field is `Arc`-shared.
169#[derive(Clone)]
170pub struct SamplingLlmClient {
171    /// The RPC channel back to the MCP client.
172    sampling_client: Arc<dyn SamplingClient>,
173    /// Per-tenant `WriteHandle` for the synchronous audit emit. Routes
174    /// through the writer-actor's mpsc so the
175    /// `AuditOperation::LlmSamplingCall` INSERT lands in a dedicated
176    /// `BEGIN IMMEDIATE` transaction on the writer's connection.
177    write_handle: WriteHandle,
178    /// Cached audit `principal_subject` for the MCP session. Resolved
179    /// at session init time (see `mcp::resolve_mcp_principal`); `None`
180    /// for unauthenticated stdio sessions.
181    audit_principal: Option<String>,
182    /// `max_tokens` value sent on every `sampling/createMessage`.
183    /// Defaults to [`DEFAULT_SAMPLING_MAX_TOKENS`]; configurable via
184    /// [`Self::with_max_tokens`].
185    max_tokens: u32,
186    /// Bounded wait on `create_message`. See
187    /// [`DEFAULT_SAMPLING_TIMEOUT`].
188    timeout: Duration,
189}
190
191impl SamplingLlmClient {
192    /// Build a client wrapping a real `Peer<RoleServer>`. Production
193    /// path — called from
194    /// [`crate::mcp::SoloMcpServer::populate_sampling_steward`] when an MCP
195    /// session reaches `initialize` with a sampling-capable peer.
196    pub fn new(
197        peer: Peer<RoleServer>,
198        write_handle: WriteHandle,
199        audit_principal: Option<String>,
200    ) -> Self {
201        Self::with_sampling_client(
202            Arc::new(PeerSamplingClient::new(peer)),
203            write_handle,
204            audit_principal,
205        )
206    }
207
208    /// Test-friendly constructor accepting any [`SamplingClient`]
209    /// implementation. Pair with
210    /// [`super::super::test_support::fake_mcp_client::FakeMcpClient`]
211    /// in tests.
212    pub fn with_sampling_client(
213        sampling_client: Arc<dyn SamplingClient>,
214        write_handle: WriteHandle,
215        audit_principal: Option<String>,
216    ) -> Self {
217        Self {
218            sampling_client,
219            write_handle,
220            audit_principal,
221            max_tokens: DEFAULT_SAMPLING_MAX_TOKENS,
222            timeout: DEFAULT_SAMPLING_TIMEOUT,
223        }
224    }
225
226    /// Override the per-call `max_tokens` cap.
227    pub fn with_max_tokens(mut self, n: u32) -> Self {
228        self.max_tokens = n.max(1);
229        self
230    }
231
232    /// Override the per-call timeout.
233    pub fn with_timeout(mut self, t: Duration) -> Self {
234        self.timeout = t;
235        self
236    }
237
238    /// Build the `CreateMessageRequestParams` from Solo's `Message`
239    /// vec. Splits out `Role::System` into the `system_prompt` field
240    /// (rmcp's `SamplingMessage::role` is only User / Assistant) and
241    /// hints the user's MCP client toward a Claude-class model.
242    fn build_request(&self, messages: &[Message]) -> CreateMessageRequestParams {
243        // Split system messages out of the conversation history; the
244        // sampling protocol carries the system prompt as a top-level
245        // field rather than inline.
246        let mut system_parts: Vec<String> = Vec::new();
247        let mut samp_messages: Vec<SamplingMessage> = Vec::new();
248        for m in messages {
249            match m.role {
250                Role::System => system_parts.push(m.content.clone()),
251                Role::User => {
252                    samp_messages.push(SamplingMessage::user_text(&m.content));
253                }
254                Role::Assistant => {
255                    samp_messages.push(SamplingMessage::assistant_text(&m.content));
256                }
257            }
258        }
259        // rmcp 1.7's struct literals are non-exhaustive across crate
260        // boundaries; build via the typed constructors + builders.
261        let preferences = ModelPreferences::new()
262            .with_hints(vec![ModelHint::new("claude")])
263            .with_intelligence_priority(0.7)
264            .with_speed_priority(0.3)
265            .with_cost_priority(0.4);
266        let mut params = CreateMessageRequestParams::new(samp_messages, self.max_tokens)
267            .with_model_preferences(preferences);
268        if !system_parts.is_empty() {
269            params = params.with_system_prompt(system_parts.join("\n\n"));
270        }
271        params
272    }
273
274    /// Build the audit `AuditEvent` carrying ONLY metadata. No raw
275    /// prompt content lands in `details_json`.
276    ///
277    /// Pinned by [`tests::audit_row_omits_raw_prompt_text`].
278    ///
279    /// v0.9.1 P1 Fix 4 (F6 privacy bucketing): the raw character count
280    /// of the prompt is itself a side-channel — a 6-char prompt
281    /// uniquely identifies very-short refusal paths (e.g. a leaked
282    /// password length). `prompt_chars` and `input_tokens_est` are
283    /// rounded up to the next power of two before persistence. This
284    /// preserves operator capacity-planning (the bucket is within ~2x
285    /// of the real size for any sufficiently large prompt) while
286    /// removing the per-character precision.
287    ///
288    /// Buckets: `0, 1, 2, 4, 8, 16, 32, 64, ..., 1024, 2048, ...`
289    /// (next power of two `>= n`). 0 stays 0.
290    fn audit_event(
291        &self,
292        params: &CreateMessageRequestParams,
293        outcome: SamplingOutcome,
294    ) -> AuditEvent {
295        let raw_prompt_chars: usize = params
296            .messages
297            .iter()
298            .flat_map(|m| m.content.iter())
299            .filter_map(|c| c.as_text().map(|t| t.text.len()))
300            .sum::<usize>()
301            + params.system_prompt.as_ref().map(|s| s.len()).unwrap_or(0);
302        // v0.9.1 P1 Fix 4: bucket the raw count to the next power of
303        // two. Pinned by `tests::audit_row_bucket_prompt_chars_to_pow2`.
304        let prompt_chars = next_pow2_bucket(raw_prompt_chars);
305        // ~4 chars per token for the rough English-text estimate used
306        // by `solo doctor --check-llm` and Anthropic's docs. Recorded
307        // for operator capacity-planning. Bucketed for the same
308        // privacy reason — and to stay consistent with `prompt_chars`.
309        let input_tokens_est = next_pow2_bucket(raw_prompt_chars / 4) as u64;
310        let model_hint = params
311            .model_preferences
312            .as_ref()
313            .and_then(|p| p.hints.as_ref())
314            .and_then(|h| h.first())
315            .and_then(|h| h.name.clone())
316            .unwrap_or_else(|| "(none)".to_string());
317
318        let mut details = serde_json::Map::new();
319        details.insert(
320            "model_hint".to_string(),
321            serde_json::Value::String(model_hint),
322        );
323        details.insert(
324            "messages_count".to_string(),
325            serde_json::Value::Number(params.messages.len().into()),
326        );
327        details.insert(
328            "max_tokens".to_string(),
329            serde_json::Value::Number(params.max_tokens.into()),
330        );
331        details.insert(
332            "prompt_chars".to_string(),
333            serde_json::Value::Number(prompt_chars.into()),
334        );
335        details.insert(
336            "input_tokens_est".to_string(),
337            serde_json::Value::Number(input_tokens_est.into()),
338        );
339
340        let result = match &outcome {
341            SamplingOutcome::Ok {
342                duration_ms,
343                model,
344                output_chars,
345            } => {
346                // v0.9.1 P1 Fix 4: same power-of-2 bucketing as
347                // `prompt_chars` for the output side. A model that
348                // always replies with a one-token refusal (e.g. an
349                // assistant trained to say "no.") would otherwise leak
350                // the response-length distribution; bucketing
351                // collapses 1-2-3-4 chars all into bucket 4.
352                let bucketed_output_chars = next_pow2_bucket(*output_chars);
353                let output_tokens_est = next_pow2_bucket(*output_chars / 4) as u64;
354                details.insert(
355                    "duration_ms".to_string(),
356                    serde_json::Value::Number((*duration_ms).into()),
357                );
358                details.insert(
359                    "model".to_string(),
360                    serde_json::Value::String(model.clone()),
361                );
362                details.insert(
363                    "output_chars".to_string(),
364                    serde_json::Value::Number(bucketed_output_chars.into()),
365                );
366                details.insert(
367                    "output_tokens_est".to_string(),
368                    serde_json::Value::Number(output_tokens_est.into()),
369                );
370                AuditResult::Ok
371            }
372            SamplingOutcome::Forbidden {
373                reason,
374                duration_ms,
375            } => {
376                details.insert(
377                    "duration_ms".to_string(),
378                    serde_json::Value::Number((*duration_ms).into()),
379                );
380                details.insert(
381                    "reason".to_string(),
382                    serde_json::Value::String(reason.to_string()),
383                );
384                AuditResult::Forbidden
385            }
386            SamplingOutcome::Error {
387                reason,
388                duration_ms,
389            } => {
390                details.insert(
391                    "duration_ms".to_string(),
392                    serde_json::Value::Number((*duration_ms).into()),
393                );
394                details.insert(
395                    "reason".to_string(),
396                    serde_json::Value::String(reason.to_string()),
397                );
398                AuditResult::Error
399            }
400        };
401
402        AuditEvent {
403            ts_ms: chrono::Utc::now().timestamp_millis(),
404            principal_subject: self.audit_principal.clone(),
405            operation: AuditOperation::LlmSamplingCall,
406            target_id: None,
407            result,
408            details: Some(serde_json::Value::Object(details)),
409        }
410    }
411}
412
/// How one sampling call ended, as seen by the audit-row builder.
/// Every variant carries the measured wall-clock duration.
enum SamplingOutcome {
    /// The client answered and usable text was extracted.
    Ok {
        /// Elapsed time of the RPC, in milliseconds.
        duration_ms: u64,
        /// Model name reported in the client's result.
        model: String,
        /// Unbucketed length of the extracted text.
        output_chars: usize,
    },
    /// The call was refused; recorded as `AuditResult::Forbidden`.
    Forbidden {
        /// Static category stored as `details_json.reason`.
        reason: &'static str,
        /// Elapsed time of the RPC, in milliseconds.
        duration_ms: u64,
    },
    /// Any other failure (timeout, transport, malformed response);
    /// recorded as `AuditResult::Error`.
    Error {
        /// Static category stored as `details_json.reason`.
        reason: &'static str,
        /// Elapsed time of the RPC, in milliseconds.
        duration_ms: u64,
    },
}
429
430#[async_trait]
431impl LlmClient for SamplingLlmClient {
432    fn name(&self) -> &str {
433        "mcp-sampling"
434    }
435
436    async fn complete(&self, messages: &[Message]) -> CoreResult<Message> {
437        let params = self.build_request(messages);
438        let start = Instant::now();
439
440        // Bounded wait on `peer.create_message`. The fold of (rmcp
441        // ServiceError | FakeError | tokio timeout) into the
442        // `Outcome` enum keeps the audit path single-sourced.
443        let rpc = tokio::time::timeout(
444            self.timeout,
445            self.sampling_client.create_message(params.clone()),
446        )
447        .await;
448        let duration_ms = start.elapsed().as_millis().min(u128::from(u64::MAX)) as u64;
449
450        let (core_result, outcome): (CoreResult<Message>, SamplingOutcome) = match rpc {
451            Ok(Ok(result)) => match extract_text(&result) {
452                Ok(text) => {
453                    let output_chars = text.len();
454                    let outcome = SamplingOutcome::Ok {
455                        duration_ms,
456                        model: result.model.clone(),
457                        output_chars,
458                    };
459                    (Ok(Message::assistant(text)), outcome)
460                }
461                Err(reason) => (
462                    Err(CoreError::llm(format!(
463                        "mcp sampling: malformed response: {reason}"
464                    ))),
465                    SamplingOutcome::Error {
466                        reason: "malformed_response",
467                        duration_ms,
468                    },
469                ),
470            },
471            Ok(Err(e)) => {
472                let (category, is_forbidden) = e.classify();
473                let outcome = if is_forbidden {
474                    SamplingOutcome::Forbidden {
475                        reason: category,
476                        duration_ms,
477                    }
478                } else {
479                    SamplingOutcome::Error {
480                        reason: category,
481                        duration_ms,
482                    }
483                };
484                let err = if is_forbidden {
485                    CoreError::forbidden(format!("mcp sampling: {e}"))
486                } else {
487                    CoreError::llm(format!("mcp sampling: {e}"))
488                };
489                (Err(err), outcome)
490            }
491            Err(_elapsed) => (
492                Err(CoreError::llm(format!(
493                    "mcp sampling: timeout after {}ms",
494                    duration_ms
495                ))),
496                SamplingOutcome::Error {
497                    reason: "timeout",
498                    duration_ms,
499                },
500            ),
501        };
502
503        // Synchronous audit emit through the writer-actor (lesson #30).
504        // Failure to land the audit row is operator-visible: the
505        // sampling call's caller sees the storage error and can decide
506        // whether to abort (we DO abort here — without the audit row
507        // we have no record of the call).
508        //
509        // v0.9.1 P1 Fix 3 (F4 Result-shadowing): when BOTH `core_result`
510        // is `Err(..)` AND the audit emit also fails, return the
511        // ORIGINAL LLM-side error (more actionable for callers — they
512        // can retry the LLM call, or decide whether the upstream
513        // refusal/timeout is recoverable). Surface the audit failure
514        // via `tracing::error!` for operator visibility — operators
515        // alarming on storage errors see it; callers see the actionable
516        // error.
517        //
518        // Policy summary:
519        //   * RPC Ok  + audit Ok  → return Ok(text)
520        //   * RPC Ok  + audit Err → return Err(storage) [audit failure
521        //                            wins — no undocumented sampling
522        //                            calls per lesson #30]
523        //   * RPC Err + audit Ok  → return Err(llm/forbidden) [unchanged]
524        //   * RPC Err + audit Err → return Err(llm/forbidden) AND log
525        //                            audit failure at error level
526        //                            [v0.9.1 P1 Fix 3]
527        let event = self.audit_event(&params, outcome);
528        match (
529            core_result,
530            self.write_handle.emit_llm_sampling_audit(event).await,
531        ) {
532            (Ok(text), Ok(())) => Ok(text),
533            (Ok(_text), Err(audit_err)) => {
534                // RPC succeeded but the audit row didn't land. Drop
535                // the success — without a durable audit row we can't
536                // honor the "every sampling call leaves a trace"
537                // invariant.
538                Err(CoreError::storage(format!(
539                    "mcp sampling: audit emit failed: {audit_err}"
540                )))
541            }
542            (Err(core_err), Ok(())) => Err(core_err),
543            (Err(core_err), Err(audit_err)) => {
544                // Both failed. Return the LLM-side error (the caller's
545                // most actionable signal); log the audit failure so an
546                // operator who alarms on storage errors still sees it.
547                tracing::error!(
548                    audit_error = %audit_err,
549                    core_error = %core_err,
550                    "mcp sampling: audit emit failed alongside core \
551                     error; surfacing core error to caller"
552                );
553                Err(core_err)
554            }
555        }
556    }
557}
558
/// Rounds `n` up to the nearest power of two, saturating at `usize::MAX`.
///
/// Used to bucket `prompt_chars` / `output_chars` / `*_tokens_est` in the
/// `LlmSamplingCall` audit row's `details_json` (v0.9.1 P1 Fix 4 "F6" —
/// exact sizes were a privacy side-channel for short prompts).
///
/// Buckets: `0 → 0`, `1 → 1`, `2 → 2`, `3 → 4`, `4 → 4`, `5..=8 → 8`,
/// `9..=16 → 16`, `17..=32 → 32`, ... — every value inside a bucket
/// persists as the same number. The worst-case fidelity loss is just
/// under 2x (e.g. 9 persists as 16).
///
/// # Overflow
///
/// Inputs above the largest representable power of two
/// (`1 << (usize::BITS - 1)`) return `usize::MAX`. Plain
/// `next_power_of_two` does not saturate there: it panics in debug
/// builds and wraps to `0` in release builds, which would either abort
/// the audit path or record a size of zero.
///
/// Pinned by the `tests::next_pow2_bucket_*` tests and
/// `tests::audit_row_bucket_prompt_chars_to_pow2`.
fn next_pow2_bucket(n: usize) -> usize {
    if n == 0 {
        // `0usize.next_power_of_two()` is 1; an empty input must stay 0.
        return 0;
    }
    n.checked_next_power_of_two().unwrap_or(usize::MAX)
}
583
584/// Pull the assistant's text out of the rmcp result. Walks every text
585/// content block in the message (the spec allows either a single
586/// `SamplingContent::Single` or a `SamplingContent::Multiple`) and
587/// concatenates them with newlines. Returns `Err(reason)` if no text
588/// blocks were present — the malformed-response path.
589fn extract_text(result: &CreateMessageResult) -> Result<String, &'static str> {
590    if result.message.role != RmcpRole::Assistant {
591        return Err("response role was not Assistant");
592    }
593    let mut out = String::new();
594    for content in result.message.content.iter() {
595        if let SamplingMessageContent::Text(text) = content {
596            if !out.is_empty() {
597                out.push('\n');
598            }
599            out.push_str(&text.text);
600        }
601    }
602    if out.is_empty() {
603        Err("no text content blocks")
604    } else {
605        Ok(out)
606    }
607}
608
609/// v0.9.0 P2: build a sampling-backed `Arc<Steward>` for a tenant that
610/// has resolved `LlmConfig::McpSampling` and just attached an MCP
611/// session.
612///
613/// Called from [`crate::mcp::SoloMcpServer::populate_sampling_steward`] at
614/// MCP `initialize` time once the peer's sampling capability is
615/// confirmed. The returned `Arc<Steward>` is written into
616/// `tenant.steward_slot()` so the writer-actor + consolidate timer
617/// can read a populated slot on their next tick.
618///
619/// v0.9.0 P5 (M3 wiring): the live `PeerSamplingClient` is now wrapped
620/// in a [`super::SamplingCoordinator`] before being handed to
621/// `SamplingLlmClient`. Concurrent `complete()` calls within the
622/// coalesce window collapse into one `peer.create_message` RPC and the
623/// response is demultiplexed back per-task — matching the
624/// `[sampling] coalesce_window_ms` / `coalesce_max_requests` config the
625/// operator wrote in `solo.config.toml`. Per-call audit emit semantics
626/// are unchanged: every logical request still lands one
627/// `AuditOperation::LlmSamplingCall` row, no raw prompt content escapes
628/// to the audit row.
629///
630/// Edge case (clamping): the `[sampling]` block accepts values that
631/// effectively disable batching — `coalesce_max_requests = 1` and / or
632/// `coalesce_window_ms = 0` reduce the coordinator to pass-through (one
633/// inner call per submission). The coordinator's
634/// [`super::SamplingCoordinator::with_settings`] clamps `max_batch` to
635/// `max(1)` so a zero value still produces a single-element flush
636/// immediately rather than panicking or deadlocking.
637pub fn build_sampling_steward(
638    peer: Peer<RoleServer>,
639    write_handle: WriteHandle,
640    audit_principal: Option<String>,
641    steward_config: solo_steward::StewardConfig,
642    sampling_config: solo_storage::SamplingConfig,
643) -> Arc<solo_steward::Steward> {
644    let inner: Arc<dyn SamplingClient> = Arc::new(PeerSamplingClient::new(peer));
645    let coordinator: Arc<dyn SamplingClient> = super::SamplingCoordinator::with_settings(
646        inner,
647        std::time::Duration::from_millis(sampling_config.coalesce_window_ms),
648        sampling_config.coalesce_max_requests as usize,
649    );
650    let client =
651        SamplingLlmClient::with_sampling_client(coordinator, write_handle, audit_principal)
652            .with_max_tokens(steward_config.abstraction_max_tokens.min(65_536) as u32);
653    Arc::new(solo_steward::Steward::new(Arc::new(client), steward_config))
654}
655
656#[cfg(test)]
657mod tests {
658    use super::*;
659    use crate::test_support::{FakeMcpClient, FakeResponse, FakeSamplingError};
660    use rmcp::model::CreateMessageResult;
661    use solo_core::TenantId;
662    use solo_storage::{
663        EmbedderConfig, HnswParams, InitParams, KeyMaterial, StubEmbedder, TenantHandle,
664        TenantRegistry, TenantRegistryParams, init, open_sqlcipher,
665    };
666    use std::path::PathBuf;
667    use std::sync::Arc;
668    use tempfile::TempDir;
669    use zeroize::Zeroizing;
670
671    const TEST_PASSPHRASE: &str = "v0.9.0-p2-sampling-tests";
672
673    /// Bootstrap a per-tenant `TenantHandle` whose writer-actor accepts
674    /// the new `WriteCommand::EmitLlmSamplingAudit` variant.
675    ///
676    /// Mirrors the v0.8.x test discipline (see
677    /// `crates/solo-storage/src/tenants/handle_registry_tests.rs`'s
678    /// `fresh_init_dir`): build a real tenant DB on disk via the same
679    /// `init()` helper users invoke, wrap in a `TenantRegistry`, and
680    /// surface the `WriteHandle` for direct `SamplingLlmClient`
681    /// wiring.
682    struct Harness {
683        _tmp: TempDir,
684        _registry: Arc<TenantRegistry>,
685        _tenant: Arc<TenantHandle>,
686        write_handle: solo_storage::WriteHandle,
687        db_path: PathBuf,
688        key: KeyMaterial,
689    }
690
691    async fn harness() -> Harness {
692        let tmp = TempDir::new().expect("tempdir");
693        let data_dir = tmp.path().to_path_buf();
694        let _ = init(InitParams {
695            data_dir: data_dir.clone(),
696            passphrase: Zeroizing::new(TEST_PASSPHRASE.into()),
697            force: false,
698            embedder: EmbedderConfig {
699                name: "stub".into(),
700                version: "v1".into(),
701                dim: 32,
702                dtype: "f32".into(),
703            },
704        })
705        .expect("init");
706
707        let cfg =
708            solo_storage::SoloConfig::read(&data_dir.join("solo.config.toml")).expect("read cfg");
709        let key = KeyMaterial::derive(TEST_PASSPHRASE, &cfg.salt_bytes().expect("salt"))
710            .expect("derive key");
711
712        let embedder: Arc<dyn solo_core::Embedder> = Arc::new(StubEmbedder::new("stub", "v1", 32));
713        let registry = Arc::new(
714            TenantRegistry::open(TenantRegistryParams {
715                data_dir: data_dir.clone(),
716                key: key.clone(),
717                embedder: embedder.clone(),
718                hnsw_params: HnswParams::default(),
719                steward: None,
720                runtime_handle: Some(tokio::runtime::Handle::current()),
721                steward_factory: None,
722                triples_batch_signal: None,
723            })
724            .expect("open registry"),
725        );
726
727        let tenant_id = TenantId::default_tenant();
728        let tenant = registry
729            .get_or_open(&tenant_id)
730            .await
731            .expect("get_or_open default tenant");
732        let write_handle = tenant.write().clone();
733        let db_path = tenant.db_path().to_path_buf();
734
735        Harness {
736            _tmp: tmp,
737            _registry: registry,
738            _tenant: tenant,
739            write_handle,
740            db_path,
741            key,
742        }
743    }
744
745    /// Helper: count the `audit_events` rows whose `operation` is the
746    /// given string. Opens a fresh connection to the tenant DB so we
747    /// avoid contention with the writer-actor's own connection.
748    fn count_audit_rows(db_path: &std::path::Path, key: &KeyMaterial, op: &str) -> i64 {
749        let conn = open_sqlcipher(db_path, key).expect("open db");
750        conn.query_row(
751            "SELECT COUNT(*) FROM audit_events WHERE operation = ?",
752            rusqlite::params![op],
753            |r| r.get(0),
754        )
755        .expect("count")
756    }
757
758    /// Helper: load the most-recent `llm.sampling_call` audit row and
759    /// return `(result, principal_subject, details_json)`.
760    fn latest_sampling_audit_details(
761        db_path: &std::path::Path,
762        key: &KeyMaterial,
763    ) -> (String, Option<String>, serde_json::Value) {
764        let conn = open_sqlcipher(db_path, key).expect("open db");
765        let (result, principal, details_str): (String, Option<String>, Option<String>) = conn
766            .query_row(
767                "SELECT result, principal_subject, details_json
768                 FROM audit_events
769                 WHERE operation = 'llm.sampling_call'
770                 ORDER BY ts_ms DESC, rowid DESC
771                 LIMIT 1",
772                [],
773                |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
774            )
775            .expect("query");
776        let details: serde_json::Value =
777            serde_json::from_str(&details_str.expect("details_json present"))
778                .expect("parse details");
779        (result, principal, details)
780    }
781
782    /// Happy path: a successful `create_message` round-trip returns
783    /// the assistant text wrapped in a `Message::assistant`, and lands
784    /// exactly one `llm.sampling_call` audit row with `result = 'ok'`.
785    #[tokio::test]
786    async fn sampling_complete_happy_path_returns_text() {
787        let h = harness().await;
788        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("derived theme")));
789        let client = SamplingLlmClient::with_sampling_client(
790            fake.clone(),
791            h.write_handle.clone(),
792            Some("alice".into()),
793        );
794        let messages = vec![Message::user("summarise these episodes")];
795        let result = client.complete(&messages).await.expect("ok");
796        assert_eq!(result.role, Role::Assistant);
797        assert_eq!(result.content, "derived theme");
798
799        // Exactly one audit row landed.
800        assert_eq!(count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"), 1);
801        let (result_str, principal, details) = latest_sampling_audit_details(&h.db_path, &h.key);
802        assert_eq!(result_str, "ok");
803        assert_eq!(principal.as_deref(), Some("alice"));
804        assert_eq!(details["model_hint"], "claude");
805        assert_eq!(details["model"], "fake-claude");
806        assert_eq!(details["messages_count"], 1);
807        assert_eq!(details["max_tokens"], 512);
808    }
809
810    /// Privacy invariant: the audit row's `details_json` MUST NOT
811    /// contain the raw prompt content. Pinned by string inspection of
812    /// the persisted JSON.
813    #[tokio::test]
814    async fn audit_row_omits_raw_prompt_text() {
815        let h = harness().await;
816        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
817        let client = SamplingLlmClient::with_sampling_client(fake, h.write_handle.clone(), None);
818        let secret = "THE-USER-ID-IS-bobby-1234";
819        let messages = vec![
820            Message::system("you are a friendly assistant"),
821            Message::user(secret),
822        ];
823        client.complete(&messages).await.expect("ok");
824
825        let (_, _, details) = latest_sampling_audit_details(&h.db_path, &h.key);
826        let serialised = serde_json::to_string(&details).expect("serialise details");
827        assert!(
828            !serialised.contains(secret),
829            "audit details must not carry raw prompt content; was: {serialised}"
830        );
831        assert!(
832            !serialised.contains("you are a friendly assistant"),
833            "audit details must not carry system prompt; was: {serialised}"
834        );
835        // Metadata IS present, even though the prompt is not.
836        assert_eq!(details["messages_count"], 1);
837        assert!(details["prompt_chars"].as_u64().unwrap() > 0);
838    }
839
840    /// v0.9.1 P1 Fix 4 (F6 privacy bucketing): the audit row's
841    /// `prompt_chars` MUST be the power-of-2 bucket, never the raw
842    /// character count. Pins the bucketing behavior end-to-end (raw
843    /// `audit_event` → SQLite → re-read).
844    ///
845    /// Test recipe: drive a prompt with a known raw length (6 chars
846    /// total, `"hello "` system + `"x"` user → 6+1 = 7) and assert the
847    /// audit row carries `8` (next pow2 ≥ 7), not 7.
848    #[tokio::test]
849    async fn audit_row_bucket_prompt_chars_to_pow2() {
850        let h = harness().await;
851        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
852        let client = SamplingLlmClient::with_sampling_client(fake, h.write_handle.clone(), None);
853        // System: 6 chars + user: 1 char = 7 chars raw → bucket 8.
854        client
855            .complete(&[Message::system("hello "), Message::user("x")])
856            .await
857            .expect("ok");
858        let (_, _, details) = latest_sampling_audit_details(&h.db_path, &h.key);
859        assert_eq!(
860            details["prompt_chars"].as_u64().unwrap(),
861            8,
862            "prompt_chars must be bucketed to next pow2 (7 → 8). \
863             raw count is a privacy side-channel; see Fix 4 F6 in \
864             v0.9.1 P1 dev log. got details={details}"
865        );
866    }
867
868    /// Stability invariant: two prompts that fall in the SAME bucket
869    /// must persist identical `prompt_chars`. Distinguishes "the
870    /// implementation buckets" from "the implementation hashes/leaks
871    /// raw values".
872    ///
873    /// 5 chars and 7 chars both round to 8 → must persist identically.
874    /// (Mirrors the brief's "test that bucketed values are stable
875    /// across exact-character variations within the same bucket".)
876    #[tokio::test]
877    async fn audit_row_bucket_prompt_chars_is_stable_within_bucket() {
878        let h = harness().await;
879        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
880        let client = SamplingLlmClient::with_sampling_client(fake, h.write_handle.clone(), None);
881        // 5 chars raw → bucket 8.
882        client
883            .complete(&[Message::user("hello")])
884            .await
885            .expect("ok");
886        let (_, _, details_5) = latest_sampling_audit_details(&h.db_path, &h.key);
887        // 7 chars raw → bucket 8.
888        client
889            .complete(&[Message::user("hellooo")])
890            .await
891            .expect("ok");
892        let (_, _, details_7) = latest_sampling_audit_details(&h.db_path, &h.key);
893        assert_eq!(
894            details_5["prompt_chars"], details_7["prompt_chars"],
895            "5 chars and 7 chars must hash to the same bucket (8) — \
896             otherwise the bucketing is leaking raw fidelity. \
897             5-char details: {details_5}, 7-char details: {details_7}"
898        );
899        assert_eq!(details_5["prompt_chars"].as_u64().unwrap(), 8);
900    }
901
902    /// Unit-level pins for the bucketing helper. Catches a regression
903    /// where someone "simplifies" `next_pow2_bucket` into a no-op.
904    #[test]
905    fn next_pow2_bucket_table() {
906        assert_eq!(next_pow2_bucket(0), 0, "0 stays 0");
907        assert_eq!(next_pow2_bucket(1), 1, "1 stays 1");
908        assert_eq!(next_pow2_bucket(2), 2, "2 stays 2");
909        assert_eq!(next_pow2_bucket(3), 4, "3 rounds up to 4");
910        assert_eq!(next_pow2_bucket(4), 4, "4 stays 4");
911        assert_eq!(next_pow2_bucket(5), 8);
912        assert_eq!(next_pow2_bucket(6), 8, "6-char prompt (brief case) → 8");
913        assert_eq!(next_pow2_bucket(7), 8);
914        assert_eq!(next_pow2_bucket(8), 8);
915        assert_eq!(next_pow2_bucket(9), 16);
916        assert_eq!(next_pow2_bucket(1023), 1024);
917        assert_eq!(next_pow2_bucket(1024), 1024);
918        assert_eq!(next_pow2_bucket(1025), 2048);
919    }
920
921    /// Client refusal: maps to `CoreError::Forbidden` + audit row
922    /// `result = 'forbidden'` + `details_json.reason = 'client_refused'`.
923    #[tokio::test]
924    async fn client_refusal_returns_forbidden_and_audits() {
925        let h = harness().await;
926        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ignored")));
927        fake.reject_with("user dismissed approval");
928        let client = SamplingLlmClient::with_sampling_client(
929            fake,
930            h.write_handle.clone(),
931            Some("alice".into()),
932        );
933        let err = client
934            .complete(&[Message::user("anything")])
935            .await
936            .unwrap_err();
937        match err {
938            CoreError::Forbidden(_) => {}
939            other => panic!("expected Forbidden, got {other:?}"),
940        }
941        let (result_str, _, details) = latest_sampling_audit_details(&h.db_path, &h.key);
942        assert_eq!(result_str, "forbidden");
943        assert_eq!(details["reason"], "client_refused");
944    }
945
946    /// Timeout: tokio::time::timeout fires before the fake's `Slow`
947    /// response resolves; client returns `CoreError::Llm` + audit row
948    /// `result = 'error'` + `details_json.reason = 'timeout'`.
949    ///
950    /// Real wall-clock: 80ms slow response vs 30ms client timeout.
951    /// Margin is loose enough for slow CI without making the test
952    /// drag.
953    #[tokio::test]
954    async fn timeout_returns_error_with_timeout_reason() {
955        let h = harness().await;
956        let fake = Arc::new(FakeMcpClient::new(FakeResponse::slow(
957            "late",
958            Duration::from_millis(800),
959        )));
960        let client = SamplingLlmClient::with_sampling_client(fake, h.write_handle.clone(), None)
961            .with_timeout(Duration::from_millis(30));
962        let err = client
963            .complete(&[Message::user("hello")])
964            .await
965            .unwrap_err();
966        match err {
967            CoreError::Llm(msg) => assert!(msg.contains("timeout")),
968            other => panic!("expected Llm, got {other:?}"),
969        }
970        let (result_str, _, details) = latest_sampling_audit_details(&h.db_path, &h.key);
971        assert_eq!(result_str, "error");
972        assert_eq!(details["reason"], "timeout");
973    }
974
975    /// Malformed response: the fake returns a result with zero text
976    /// content blocks; client surfaces `CoreError::Llm` + audit row
977    /// `result = 'error'` + `details_json.reason = 'malformed_response'`.
978    #[tokio::test]
979    async fn malformed_response_returns_error_with_reason() {
980        let h = harness().await;
981        let fake = Arc::new(FakeMcpClient::new(FakeResponse::EmptyContent));
982        let client = SamplingLlmClient::with_sampling_client(fake, h.write_handle.clone(), None);
983        let err = client.complete(&[Message::user("hi")]).await.unwrap_err();
984        assert!(matches!(err, CoreError::Llm(_)));
985        let (result_str, _, details) = latest_sampling_audit_details(&h.db_path, &h.key);
986        assert_eq!(result_str, "error");
987        assert_eq!(details["reason"], "malformed_response");
988    }
989
990    /// `principal_subject = None` works — audit row still emits with
991    /// NULL.
992    #[tokio::test]
993    async fn no_principal_emits_audit_with_null_principal() {
994        let h = harness().await;
995        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
996        let client = SamplingLlmClient::with_sampling_client(fake, h.write_handle.clone(), None);
997        client.complete(&[Message::user("hi")]).await.expect("ok");
998        let (_, principal, _) = latest_sampling_audit_details(&h.db_path, &h.key);
999        assert_eq!(principal, None);
1000    }
1001
1002    /// Concurrency: 8 parallel `complete()` calls land 8 audit rows.
1003    /// Audit IDs (autoincrement rowid) must be distinct — verifies the
1004    /// writer-actor serialises the per-call audit emit (no
1005    /// interleaving / dropped rows).
1006    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1007    async fn parallel_completes_serialise_audit_rows() {
1008        let h = harness().await;
1009        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1010        let client = SamplingLlmClient::with_sampling_client(
1011            fake.clone(),
1012            h.write_handle.clone(),
1013            Some("alice".into()),
1014        );
1015        let mut futs = Vec::new();
1016        for _ in 0..8 {
1017            let c = client.clone();
1018            futs.push(tokio::spawn(async move {
1019                c.complete(&[Message::user("hi")]).await
1020            }));
1021        }
1022        for f in futs {
1023            f.await.expect("join").expect("ok");
1024        }
1025        assert_eq!(
1026            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1027            8,
1028            "8 parallel calls must land 8 audit rows"
1029        );
1030
1031        // Each was a separate request to the fake.
1032        assert_eq!(fake.record_requests().len(), 8);
1033    }
1034
1035    /// `complete` translates the workspace's `Message::system` into the
1036    /// `system_prompt` top-level field; user/assistant roles map to
1037    /// rmcp's `SamplingMessage::user_text` / `assistant_text`.
1038    #[tokio::test]
1039    async fn build_request_splits_system_from_messages() {
1040        let h = harness().await;
1041        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1042        let client =
1043            SamplingLlmClient::with_sampling_client(fake.clone(), h.write_handle.clone(), None);
1044        client
1045            .complete(&[
1046                Message::system("be terse"),
1047                Message::user("question"),
1048                Message::assistant("answer"),
1049            ])
1050            .await
1051            .expect("ok");
1052        let recorded = fake.record_requests();
1053        assert_eq!(recorded.len(), 1);
1054        let req = &recorded[0];
1055        assert_eq!(
1056            req.system_prompt.as_deref(),
1057            Some("be terse"),
1058            "Role::System must map to system_prompt"
1059        );
1060        assert_eq!(req.messages.len(), 2);
1061        // The remaining two messages are the user + assistant turns.
1062        assert_eq!(req.messages[0].role, RmcpRole::User);
1063        assert_eq!(req.messages[1].role, RmcpRole::Assistant);
1064    }
1065
1066    /// `model_preferences` carries the `claude` hint per plan §6.
1067    /// Pins the wire shape so a future change is a conscious decision.
1068    #[tokio::test]
1069    async fn build_request_includes_claude_model_hint() {
1070        let h = harness().await;
1071        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1072        let client =
1073            SamplingLlmClient::with_sampling_client(fake.clone(), h.write_handle.clone(), None);
1074        client.complete(&[Message::user("hi")]).await.expect("ok");
1075        let recorded = fake.record_requests();
1076        let prefs = recorded[0].model_preferences.as_ref().expect("prefs");
1077        let hint = prefs
1078            .hints
1079            .as_ref()
1080            .and_then(|h| h.first())
1081            .and_then(|h| h.name.clone())
1082            .expect("hint name");
1083        assert_eq!(hint, "claude");
1084    }
1085
1086    /// `with_max_tokens(n)` propagates to the request's
1087    /// `max_tokens` field.
1088    #[tokio::test]
1089    async fn with_max_tokens_overrides_default() {
1090        let h = harness().await;
1091        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1092        let client =
1093            SamplingLlmClient::with_sampling_client(fake.clone(), h.write_handle.clone(), None)
1094                .with_max_tokens(2048);
1095        client.complete(&[Message::user("hi")]).await.expect("ok");
1096        let recorded = fake.record_requests();
1097        assert_eq!(recorded[0].max_tokens, 2048);
1098    }
1099
1100    /// Reconfiguring the fake mid-test produces distinct audit rows
1101    /// for each call (positive then negative).
1102    #[tokio::test]
1103    async fn reconfigurable_fake_distinguishes_audit_rows() {
1104        let h = harness().await;
1105        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1106        let client = SamplingLlmClient::with_sampling_client(
1107            fake.clone(),
1108            h.write_handle.clone(),
1109            Some("alice".into()),
1110        );
1111
1112        client.complete(&[Message::user("a")]).await.expect("ok");
1113        fake.reject_with("user said no");
1114        let _ = client.complete(&[Message::user("b")]).await;
1115
1116        let conn = open_sqlcipher(&h.db_path, &h.key).expect("open");
1117        let mut stmt = conn
1118            .prepare(
1119                "SELECT result FROM audit_events WHERE operation = 'llm.sampling_call' ORDER BY ts_ms ASC, rowid ASC",
1120            )
1121            .expect("prepare");
1122        let rows: Vec<String> = stmt
1123            .query_map([], |r| r.get::<_, String>(0))
1124            .expect("query")
1125            .map(|r| r.expect("row"))
1126            .collect();
1127        assert_eq!(rows, vec!["ok".to_string(), "forbidden".to_string()]);
1128    }
1129
1130    /// `extract_text` walks single-block content.
1131    #[test]
1132    fn extract_text_pulls_text_from_single_block() {
1133        let result =
1134            CreateMessageResult::new(SamplingMessage::assistant_text("hello"), "fake".into());
1135        assert_eq!(extract_text(&result).unwrap(), "hello");
1136    }
1137
1138    /// `extract_text` rejects an empty-content response.
1139    #[test]
1140    fn extract_text_rejects_empty_content() {
1141        let result = CreateMessageResult::new(
1142            SamplingMessage::new_multiple(RmcpRole::Assistant, Vec::new()),
1143            "fake".into(),
1144        );
1145        assert!(extract_text(&result).is_err());
1146    }
1147
1148    /// `extract_text` rejects a User-role response (impossible per
1149    /// spec but pinning the defensive check).
1150    #[test]
1151    fn extract_text_rejects_non_assistant_role() {
1152        let result = CreateMessageResult::new(SamplingMessage::user_text("hello"), "fake".into());
1153        assert!(extract_text(&result).is_err());
1154    }
1155
1156    // ---- v0.10.1 F5 audit-minor closure: pin multi-block concat
1157    //      semantics (deferred from v0.9.0 P2). ----
1158    //
1159    // `extract_text` walks `result.message.content` and pushes a `\n`
1160    // between successive `SamplingMessageContent::Text` blocks. The
1161    // exact semantics are wire-shape decisions that downstream Steward
1162    // parsers depend on; we pin them here so a future refactor can't
1163    // drift silently. The audit minor (v0.9.0 P2 §F5) flagged that the
1164    // existing tests only exercised single-block + empty + non-
1165    // assistant cases — not the multi-block join behavior. Closure
1166    // checklist:
1167    //
1168    //   1. join inserts ONE `\n` between non-trailing-newline blocks
1169    //   2. trailing `\n` in a block is preserved (so a `"abc\n"` block
1170    //      + a `"def"` block produces `"abc\n\ndef"`)
1171    //   3. single block returns verbatim (no leading/trailing newline
1172    //      added)
1173    //
1174    // These tests are intentionally identical-output to what the code
1175    // produces today. They're a regression net, not a behavior change.
1176    //
1177    // Grep terms: F5, extract_text_joins_multi_block, extract_text_preserves_trailing_newlines.
1178
1179    /// F5 pin: a two-block message joins with exactly ONE `\n` between
1180    /// blocks when neither block ends in a newline.
1181    #[test]
1182    fn extract_text_joins_multi_block_with_newline_separator() {
1183        let blocks = vec![
1184            SamplingMessageContent::text("abc"),
1185            SamplingMessageContent::text("def"),
1186        ];
1187        let result = CreateMessageResult::new(
1188            SamplingMessage::new_multiple(RmcpRole::Assistant, blocks),
1189            "fake".into(),
1190        );
1191        // Exact value pinned: "abc" + "\n" + "def".
1192        assert_eq!(
1193            extract_text(&result).unwrap(),
1194            "abc\ndef",
1195            "two non-newline-terminated blocks must join with a single newline"
1196        );
1197    }
1198
1199    /// F5 pin: a trailing newline in a content block is preserved
1200    /// verbatim AND a join newline is still inserted, so the result
1201    /// contains `\n\n` between such a block and the next. This is the
1202    /// "honest current behavior" pin — a future refactor that strips
1203    /// trailing newlines must explicitly update this test.
1204    #[test]
1205    fn extract_text_preserves_trailing_newlines_in_blocks() {
1206        let blocks = vec![
1207            SamplingMessageContent::text("abc\n"),
1208            SamplingMessageContent::text("def"),
1209        ];
1210        let result = CreateMessageResult::new(
1211            SamplingMessage::new_multiple(RmcpRole::Assistant, blocks),
1212            "fake".into(),
1213        );
1214        assert_eq!(
1215            extract_text(&result).unwrap(),
1216            "abc\n\ndef",
1217            "trailing newline in block 1 + join newline => '\\n\\n' between blocks"
1218        );
1219    }
1220
1221    /// F5 pin: a single block returns verbatim — no leading or
1222    /// trailing newline added by the helper. Already covered by
1223    /// `extract_text_pulls_text_from_single_block` for the "hello"
1224    /// case; this variant pins the negative space for inputs with
1225    /// internal newlines (the helper does NOT mutate inner whitespace).
1226    #[test]
1227    fn extract_text_single_block_returns_verbatim_including_inner_newlines() {
1228        let blocks = vec![SamplingMessageContent::text("line1\nline2")];
1229        let result = CreateMessageResult::new(
1230            SamplingMessage::new_multiple(RmcpRole::Assistant, blocks),
1231            "fake".into(),
1232        );
1233        assert_eq!(
1234            extract_text(&result).unwrap(),
1235            "line1\nline2",
1236            "single block must return verbatim, no extra newlines added"
1237        );
1238    }
1239
1240    /// F5 pin: three blocks, second is empty-string. The first block
1241    /// emits its text + `\n`; the empty middle pushes nothing but
1242    /// `out.is_empty()` is false so the NEXT iteration's pre-newline
1243    /// fires again. Result: `"a\n\nb"`. This pins the "empty middle
1244    /// block adds a blank line" semantic — surprising at first read
1245    /// but consistent with the join-between-non-empty rule applied
1246    /// uniformly to every iteration.
1247    #[test]
1248    fn extract_text_empty_middle_block_inserts_blank_line() {
1249        let blocks = vec![
1250            SamplingMessageContent::text("a"),
1251            SamplingMessageContent::text(""),
1252            SamplingMessageContent::text("b"),
1253        ];
1254        let result = CreateMessageResult::new(
1255            SamplingMessage::new_multiple(RmcpRole::Assistant, blocks),
1256            "fake".into(),
1257        );
1258        // The implementation's exact behavior: after "a" is pushed,
1259        // out = "a". For block 2 (empty): out.is_empty() is false, so
1260        // push '\n' → out = "a\n"; then push_str("") → out = "a\n".
1261        // For block 3 ("b"): out.is_empty() is false, so push '\n' →
1262        // out = "a\n\n"; then push_str("b") → out = "a\n\nb".
1263        assert_eq!(
1264            extract_text(&result).unwrap(),
1265            "a\n\nb",
1266            "empty middle block leaves a blank line between non-empty blocks"
1267        );
1268    }
1269
1270    /// `SamplingError::classify` maps each fake variant to the right
1271    /// audit category.
1272    #[test]
1273    fn sampling_error_classify_maps_fake_variants() {
1274        let refused = SamplingError::Fake(FakeSamplingError::Refused { reason: "x".into() });
1275        let (cat, forb) = refused.classify();
1276        assert_eq!(cat, "client_refused");
1277        assert!(forb);
1278
1279        let transport = SamplingError::Fake(FakeSamplingError::Transport {
1280            message: "x".into(),
1281        });
1282        let (cat, forb) = transport.classify();
1283        assert_eq!(cat, "transport_error");
1284        assert!(!forb);
1285
1286        let malformed = SamplingError::Fake(FakeSamplingError::MalformedResponse {
1287            message: "x".into(),
1288        });
1289        let (cat, forb) = malformed.classify();
1290        assert_eq!(cat, "malformed_response");
1291        assert!(!forb);
1292    }
1293
1294    // -------- v0.9.0 P5a (M3 wiring) — SamplingCoordinator integration --------
1295    //
1296    // These tests pin the contract that `build_sampling_steward` wraps the
1297    // live peer in a `SamplingCoordinator` before handing it to
1298    // `SamplingLlmClient`. They cannot call `build_sampling_steward`
1299    // directly (it takes a real `Peer<RoleServer>` whose constructors are
1300    // private inside rmcp), but they exercise the **exact same wiring
1301    // shape** by substituting `FakeMcpClient` for `PeerSamplingClient`.
1302    // The production code path is:
1303    //
1304    //     PeerSamplingClient -> SamplingCoordinator -> SamplingLlmClient
1305    //
1306    // The tested shape is:
1307    //
1308    //     FakeMcpClient      -> SamplingCoordinator -> SamplingLlmClient
1309    //
1310    // Only the leaf `SamplingClient` impl differs; the
1311    // `SamplingClient` trait is the same Arc-of-dyn in both paths.
1312
1313    /// SamplingCoordinator wrapping a `FakeMcpClient` and feeding
1314    /// `SamplingLlmClient::with_sampling_client` is the same Arc-of-dyn
1315    /// shape `build_sampling_steward` constructs at MCP-initialize
1316    /// time. Single-element flushes pass through unwrapped, so a lone
1317    /// `complete()` call still emits one audit row and produces the
1318    /// expected text.
1319    #[tokio::test]
1320    async fn sampling_llm_client_uses_coordinator_in_production_path() {
1321        let h = harness().await;
1322        let fake: Arc<dyn SamplingClient> = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1323        let coord: Arc<dyn SamplingClient> = super::super::SamplingCoordinator::with_settings(
1324            fake.clone(),
1325            Duration::from_millis(50),
1326            10,
1327        );
1328        let client = SamplingLlmClient::with_sampling_client(
1329            coord,
1330            h.write_handle.clone(),
1331            Some("alice".into()),
1332        );
1333        let result = client.complete(&[Message::user("test")]).await.expect("ok");
1334        assert_eq!(result.role, Role::Assistant);
1335        assert_eq!(result.content, "ok");
1336        // Single audit row landed — per-call audit semantics
1337        // unchanged by the coordinator wrap.
1338        assert_eq!(
1339            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1340            1,
1341            "one logical call → one audit row, even through coordinator"
1342        );
1343    }
1344
1345    /// End-to-end batching pin: N concurrent `complete()` calls within
1346    /// the coalesce window resolve as ONE inner `create_message` RPC
1347    /// on the underlying `FakeMcpClient`, but N audit rows still land
1348    /// (one per logical call — the privacy + audit invariants from P2
1349    /// hold).
1350    ///
1351    /// This is the v0.9.0 release notes' "⌈N/M⌉ peer.create_message
1352    /// calls per coalesce window" claim, exercised through the same
1353    /// trait-object chain that `build_sampling_steward` constructs in
1354    /// production.
1355    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1356    async fn coordinator_coalesces_concurrent_calls_into_one_inner_rpc() {
1357        // Coalesced JSON response for 5 tasks — matches the
1358        // `[{task_index, response}]` shape `flush_batch` demuxes
1359        // multi-element batches into.
1360        let response = serde_json::to_string(
1361            &(0..5)
1362                .map(|i| {
1363                    serde_json::json!({
1364                        "task_index": i,
1365                        "response": format!("response-{i}"),
1366                    })
1367                })
1368                .collect::<Vec<_>>(),
1369        )
1370        .unwrap();
1371
1372        let h = harness().await;
1373        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
1374        let coord: Arc<dyn SamplingClient> = super::super::SamplingCoordinator::with_settings(
1375            fake.clone(),
1376            // Wide window so all 5 submissions land in one batch.
1377            Duration::from_secs(5),
1378            10,
1379        );
1380        let client = SamplingLlmClient::with_sampling_client(
1381            coord,
1382            h.write_handle.clone(),
1383            Some("alice".into()),
1384        );
1385
1386        // Fire 5 concurrent `complete()` calls; the coordinator should
1387        // coalesce them into ONE `FakeMcpClient::create_message` call.
1388        let mut futs = Vec::new();
1389        for i in 0..5 {
1390            let c = client.clone();
1391            futs.push(tokio::spawn(async move {
1392                c.complete(&[Message::user(format!("task-{i}"))]).await
1393            }));
1394        }
1395        for f in futs {
1396            f.await.expect("join").expect("ok");
1397        }
1398
1399        // EXACTLY one inner RPC.
1400        assert_eq!(
1401            fake.record_requests().len(),
1402            1,
1403            "5 logical calls within window must coalesce to 1 inner RPC"
1404        );
1405        // BUT 5 audit rows — per-logical-call audit invariant preserved.
1406        assert_eq!(
1407            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1408            5,
1409            "5 logical calls → 5 audit rows (coordinator doesn't merge audits)"
1410        );
1411    }
1412
1413    /// Edge case: `coalesce_max_requests = 1` reduces the coordinator
1414    /// to pass-through (each submit flushes a 1-element batch
1415    /// immediately). With max_batch=1 and a wide window, 3 concurrent
1416    /// calls land 3 inner RPCs — coordinator is operating as if no
1417    /// batching were configured.
1418    ///
1419    /// Pins the brief's documented edge-case: zero / one-valued config
1420    /// reduces to pass-through, never panics or deadlocks. Mirrors
1421    /// `SamplingCoordinator::with_settings`'s `max_batch.max(1)`
1422    /// clamping for the `coalesce_max_requests = 0` case.
1423    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1424    async fn coordinator_max_batch_one_acts_as_passthrough() {
1425        let h = harness().await;
1426        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1427        let coord: Arc<dyn SamplingClient> = super::super::SamplingCoordinator::with_settings(
1428            fake.clone(),
1429            Duration::from_secs(5),
1430            // max_batch=1 → every submission flushes immediately as
1431            // a 1-element batch; pass-through behaviour.
1432            1,
1433        );
1434        let client = SamplingLlmClient::with_sampling_client(coord, h.write_handle.clone(), None);
1435        let mut futs = Vec::new();
1436        for _ in 0..3 {
1437            let c = client.clone();
1438            futs.push(tokio::spawn(async move {
1439                c.complete(&[Message::user("hi")]).await
1440            }));
1441        }
1442        for f in futs {
1443            f.await.expect("join").expect("ok");
1444        }
1445        // 3 logical calls → 3 inner RPCs (no coalescing).
1446        assert_eq!(
1447            fake.record_requests().len(),
1448            3,
1449            "max_batch=1 must pass through every submission as its own RPC"
1450        );
1451        assert_eq!(count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"), 3);
1452    }
1453}