use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use rmcp::model::{
CreateMessageRequestParams, CreateMessageResult, ModelHint,
ModelPreferences, Role as RmcpRole, SamplingMessage,
SamplingMessageContent,
};
use rmcp::service::{Peer, RoleServer, ServiceError};
use solo_core::{Error as CoreError, LlmClient, Message, Result as CoreResult, Role};
use solo_storage::{AuditEvent, AuditOperation, AuditResult, WriteHandle};
/// Default deadline for a single sampling round-trip issued by
/// `SamplingLlmClient` (30 seconds).
pub const DEFAULT_SAMPLING_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Default `max_tokens` requested from the MCP client when no override is set.
const DEFAULT_SAMPLING_MAX_TOKENS: u32 = 512;
/// Failure reported by a [`SamplingClient`] implementation.
///
/// `Service` wraps errors coming from the real rmcp peer. `Fake` only exists in
/// `test` / `test-support` builds and carries scripted failures from the test
/// double.
#[derive(Debug)]
pub enum SamplingError {
    /// Error returned by the rmcp peer while issuing the create-message request.
    Service(ServiceError),
    /// Scripted failure produced by the test-support fake client.
    #[cfg(any(test, feature = "test-support"))]
    Fake(crate::test_support::FakeSamplingError),
}

impl std::fmt::Display for SamplingError {
    /// Renders the wrapped error's own message verbatim, with no prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner: &dyn std::fmt::Display = match self {
            SamplingError::Service(err) => err,
            #[cfg(any(test, feature = "test-support"))]
            SamplingError::Fake(err) => err,
        };
        write!(f, "{inner}")
    }
}

impl std::error::Error for SamplingError {}
impl SamplingError {
    /// Maps this error to `(audit_reason, is_forbidden)`.
    ///
    /// `audit_reason` is the static reason code written into the audit row.
    /// `is_forbidden` says whether the failure should surface as a forbidden
    /// error rather than an LLM error.
    ///
    /// Every `ServiceError`, whatever its cause, is reported as
    /// `transport_error` and is never treated as forbidden. Only the test fake
    /// can currently produce `client_refused` / `malformed_response`.
    pub fn classify(&self) -> (&'static str, bool) {
        match self {
            Self::Service(_) => ("transport_error", false),
            #[cfg(any(test, feature = "test-support"))]
            Self::Fake(fake) => {
                use crate::test_support::FakeSamplingError as F;
                match fake {
                    F::Refused { .. } => ("client_refused", true),
                    F::Transport { .. } => ("transport_error", false),
                    F::MalformedResponse { .. } => ("malformed_response", false),
                }
            }
        }
    }
}
/// Transport abstraction for issuing MCP sampling (create-message) requests.
///
/// `SamplingLlmClient` holds this as a trait object so the transport can be a
/// direct peer (`PeerSamplingClient`), a wrapper such as the coalescing
/// coordinator used by `build_sampling_steward`, or a test fake.
#[async_trait]
pub trait SamplingClient: Send + Sync {
/// Sends one create-message request and returns the client's result.
async fn create_message(
&self,
params: CreateMessageRequestParams,
) -> Result<CreateMessageResult, SamplingError>;
}
/// [`SamplingClient`] that forwards requests straight to the connected rmcp
/// peer, wrapping any failure as [`SamplingError::Service`].
pub struct PeerSamplingClient {
    peer: Peer<RoleServer>,
}

impl PeerSamplingClient {
    /// Wraps `peer` so it can be used as a [`SamplingClient`].
    pub fn new(peer: Peer<RoleServer>) -> Self {
        PeerSamplingClient { peer }
    }
}

#[async_trait]
impl SamplingClient for PeerSamplingClient {
    /// Issues the request on the peer; no retries or timeouts are applied here.
    async fn create_message(
        &self,
        params: CreateMessageRequestParams,
    ) -> Result<CreateMessageResult, SamplingError> {
        match self.peer.create_message(params).await {
            Ok(result) => Ok(result),
            Err(err) => Err(SamplingError::Service(err)),
        }
    }
}
/// `LlmClient` implementation that delegates completions to the connected MCP
/// client via sampling and emits one audit event per call.
#[derive(Clone)]
pub struct SamplingLlmClient {
/// Transport used to issue sampling requests (peer, coordinator, or test fake).
sampling_client: Arc<dyn SamplingClient>,
/// Storage handle used to persist the per-call audit event.
write_handle: WriteHandle,
/// Recorded as `principal_subject` on audit rows; `None` records no principal.
audit_principal: Option<String>,
/// `max_tokens` sent with every request (kept >= 1 by `with_max_tokens`).
max_tokens: u32,
/// Deadline applied to each sampling round-trip.
timeout: Duration,
}
impl SamplingLlmClient {
pub fn new(
peer: Peer<RoleServer>,
write_handle: WriteHandle,
audit_principal: Option<String>,
) -> Self {
Self::with_sampling_client(
Arc::new(PeerSamplingClient::new(peer)),
write_handle,
audit_principal,
)
}
pub fn with_sampling_client(
sampling_client: Arc<dyn SamplingClient>,
write_handle: WriteHandle,
audit_principal: Option<String>,
) -> Self {
Self {
sampling_client,
write_handle,
audit_principal,
max_tokens: DEFAULT_SAMPLING_MAX_TOKENS,
timeout: DEFAULT_SAMPLING_TIMEOUT,
}
}
pub fn with_max_tokens(mut self, n: u32) -> Self {
self.max_tokens = n.max(1);
self
}
pub fn with_timeout(mut self, t: Duration) -> Self {
self.timeout = t;
self
}
fn build_request(&self, messages: &[Message]) -> CreateMessageRequestParams {
let mut system_parts: Vec<String> = Vec::new();
let mut samp_messages: Vec<SamplingMessage> = Vec::new();
for m in messages {
match m.role {
Role::System => system_parts.push(m.content.clone()),
Role::User => {
samp_messages.push(SamplingMessage::user_text(&m.content));
}
Role::Assistant => {
samp_messages
.push(SamplingMessage::assistant_text(&m.content));
}
}
}
let preferences = ModelPreferences::new()
.with_hints(vec![ModelHint::new("claude")])
.with_intelligence_priority(0.7)
.with_speed_priority(0.3)
.with_cost_priority(0.4);
let mut params =
CreateMessageRequestParams::new(samp_messages, self.max_tokens)
.with_model_preferences(preferences);
if !system_parts.is_empty() {
params = params.with_system_prompt(system_parts.join("\n\n"));
}
params
}
fn audit_event(
&self,
params: &CreateMessageRequestParams,
outcome: SamplingOutcome,
) -> AuditEvent {
let raw_prompt_chars: usize = params
.messages
.iter()
.flat_map(|m| m.content.iter())
.filter_map(|c| c.as_text().map(|t| t.text.len()))
.sum::<usize>()
+ params
.system_prompt
.as_ref()
.map(|s| s.len())
.unwrap_or(0);
let prompt_chars = next_pow2_bucket(raw_prompt_chars);
let input_tokens_est = next_pow2_bucket(raw_prompt_chars / 4) as u64;
let model_hint = params
.model_preferences
.as_ref()
.and_then(|p| p.hints.as_ref())
.and_then(|h| h.first())
.and_then(|h| h.name.clone())
.unwrap_or_else(|| "(none)".to_string());
let mut details = serde_json::Map::new();
details.insert(
"model_hint".to_string(),
serde_json::Value::String(model_hint),
);
details.insert(
"messages_count".to_string(),
serde_json::Value::Number(params.messages.len().into()),
);
details.insert(
"max_tokens".to_string(),
serde_json::Value::Number(params.max_tokens.into()),
);
details.insert(
"prompt_chars".to_string(),
serde_json::Value::Number(prompt_chars.into()),
);
details.insert(
"input_tokens_est".to_string(),
serde_json::Value::Number(input_tokens_est.into()),
);
let result = match &outcome {
SamplingOutcome::Ok {
duration_ms,
model,
output_chars,
} => {
let bucketed_output_chars = next_pow2_bucket(*output_chars);
let output_tokens_est = next_pow2_bucket(*output_chars / 4) as u64;
details.insert(
"duration_ms".to_string(),
serde_json::Value::Number((*duration_ms).into()),
);
details.insert(
"model".to_string(),
serde_json::Value::String(model.clone()),
);
details.insert(
"output_chars".to_string(),
serde_json::Value::Number(bucketed_output_chars.into()),
);
details.insert(
"output_tokens_est".to_string(),
serde_json::Value::Number(output_tokens_est.into()),
);
AuditResult::Ok
}
SamplingOutcome::Forbidden {
reason,
duration_ms,
} => {
details.insert(
"duration_ms".to_string(),
serde_json::Value::Number((*duration_ms).into()),
);
details.insert(
"reason".to_string(),
serde_json::Value::String(reason.to_string()),
);
AuditResult::Forbidden
}
SamplingOutcome::Error {
reason,
duration_ms,
} => {
details.insert(
"duration_ms".to_string(),
serde_json::Value::Number((*duration_ms).into()),
);
details.insert(
"reason".to_string(),
serde_json::Value::String(reason.to_string()),
);
AuditResult::Error
}
};
AuditEvent {
ts_ms: chrono::Utc::now().timestamp_millis(),
principal_subject: self.audit_principal.clone(),
operation: AuditOperation::LlmSamplingCall,
target_id: None,
result,
details: Some(serde_json::Value::Object(details)),
}
}
}
/// Result of one sampling call as recorded in its audit event.
/// `Ok` maps to `AuditResult::Ok`, `Forbidden` to `AuditResult::Forbidden`,
/// and `Error` to `AuditResult::Error` (see `audit_event`).
enum SamplingOutcome {
/// The client returned usable assistant text.
Ok {
/// Wall-clock time of the RPC in milliseconds.
duration_ms: u64,
/// Model name reported by the client in its result.
model: String,
/// Byte length (`str::len`) of the extracted text; bucketed before storage.
output_chars: usize,
},
/// The client refused the request (reason code from `SamplingError::classify`).
Forbidden {
reason: &'static str,
duration_ms: u64,
},
/// Any other failure: transport error, malformed response, or timeout.
Error {
reason: &'static str,
duration_ms: u64,
},
}
#[async_trait]
impl LlmClient for SamplingLlmClient {
fn name(&self) -> &str {
"mcp-sampling"
}
async fn complete(&self, messages: &[Message]) -> CoreResult<Message> {
let params = self.build_request(messages);
let start = Instant::now();
let rpc = tokio::time::timeout(
self.timeout,
self.sampling_client.create_message(params.clone()),
)
.await;
let duration_ms = start.elapsed().as_millis().min(u128::from(u64::MAX))
as u64;
let (core_result, outcome): (CoreResult<Message>, SamplingOutcome) =
match rpc {
Ok(Ok(result)) => {
match extract_text(&result) {
Ok(text) => {
let output_chars = text.len();
let outcome = SamplingOutcome::Ok {
duration_ms,
model: result.model.clone(),
output_chars,
};
(Ok(Message::assistant(text)), outcome)
}
Err(reason) => (
Err(CoreError::llm(format!(
"mcp sampling: malformed response: {reason}"
))),
SamplingOutcome::Error {
reason: "malformed_response",
duration_ms,
},
),
}
}
Ok(Err(e)) => {
let (category, is_forbidden) = e.classify();
let outcome = if is_forbidden {
SamplingOutcome::Forbidden {
reason: category,
duration_ms,
}
} else {
SamplingOutcome::Error {
reason: category,
duration_ms,
}
};
let err = if is_forbidden {
CoreError::forbidden(format!("mcp sampling: {e}"))
} else {
CoreError::llm(format!("mcp sampling: {e}"))
};
(Err(err), outcome)
}
Err(_elapsed) => (
Err(CoreError::llm(format!(
"mcp sampling: timeout after {}ms",
duration_ms
))),
SamplingOutcome::Error {
reason: "timeout",
duration_ms,
},
),
};
let event = self.audit_event(¶ms, outcome);
match (
core_result,
self.write_handle.emit_llm_sampling_audit(event).await,
) {
(Ok(text), Ok(())) => Ok(text),
(Ok(_text), Err(audit_err)) => {
Err(CoreError::storage(format!(
"mcp sampling: audit emit failed: {audit_err}"
)))
}
(Err(core_err), Ok(())) => Err(core_err),
(Err(core_err), Err(audit_err)) => {
tracing::error!(
audit_error = %audit_err,
core_error = %core_err,
"mcp sampling: audit emit failed alongside core \
error; surfacing core error to caller"
);
Err(core_err)
}
}
}
}
/// Rounds `n` up to the next power of two. Audit events then record only a
/// coarse size bucket, not an exact prompt/output length.
///
/// `0` maps to `0`, keeping empty inputs distinct;
/// `usize::next_power_of_two` would return `1` for `0`. Inputs above the
/// largest representable power of two saturate to `usize::MAX`, so the bucket
/// is always >= `n`. Without this, they would panic in debug builds or wrap
/// to `0` in release builds.
fn next_pow2_bucket(n: usize) -> usize {
    if n == 0 {
        return 0;
    }
    n.checked_next_power_of_two().unwrap_or(usize::MAX)
}
/// Pulls the assistant text out of a create-message result.
///
/// Non-text content blocks are skipped. Text blocks are joined with `'\n'`,
/// but a separator is only added once the accumulated output is non-empty. As
/// a result, leading empty text blocks contribute nothing, while later empty
/// blocks each contribute a newline.
///
/// # Errors
///
/// - `"response role was not Assistant"` if the message role is not assistant.
/// - `"no text content blocks"` if the joined text is empty.
fn extract_text(result: &CreateMessageResult) -> Result<String, &'static str> {
    let message = &result.message;
    if message.role != RmcpRole::Assistant {
        return Err("response role was not Assistant");
    }

    let joined = message
        .content
        .iter()
        .filter_map(|block| {
            if let SamplingMessageContent::Text(t) = block {
                Some(&t.text)
            } else {
                None
            }
        })
        .fold(String::new(), |mut acc, piece| {
            if !acc.is_empty() {
                acc.push('\n');
            }
            acc.push_str(piece);
            acc
        });

    if joined.is_empty() {
        Err("no text content blocks")
    } else {
        Ok(joined)
    }
}
pub fn build_sampling_steward(
peer: Peer<RoleServer>,
write_handle: WriteHandle,
audit_principal: Option<String>,
steward_config: solo_steward::StewardConfig,
sampling_config: solo_storage::SamplingConfig,
) -> Arc<solo_steward::Steward> {
let inner: Arc<dyn SamplingClient> = Arc::new(PeerSamplingClient::new(peer));
let coordinator: Arc<dyn SamplingClient> = super::SamplingCoordinator::with_settings(
inner,
std::time::Duration::from_millis(sampling_config.coalesce_window_ms),
sampling_config.coalesce_max_requests as usize,
);
let client = SamplingLlmClient::with_sampling_client(
coordinator,
write_handle,
audit_principal,
)
.with_max_tokens(steward_config.abstraction_max_tokens.min(65_536) as u32);
Arc::new(solo_steward::Steward::new(Arc::new(client), steward_config))
}
// Tests drive SamplingLlmClient against FakeMcpClient on a real, freshly
// initialised tenant, then read audit rows back with SQL to check what was
// recorded.
#[cfg(test)]
mod tests {
use super::*;
use crate::test_support::{FakeMcpClient, FakeResponse, FakeSamplingError};
use rmcp::model::CreateMessageResult;
use solo_core::TenantId;
use solo_storage::{
EmbedderConfig, HnswParams, InitParams, KeyMaterial, StubEmbedder,
TenantHandle, TenantRegistry, TenantRegistryParams, init,
open_sqlcipher,
};
use std::path::PathBuf;
use std::sync::Arc;
use tempfile::TempDir;
use zeroize::Zeroizing;
// Passphrase for the throwaway data dir created by `harness()`.
const TEST_PASSPHRASE: &str = "v0.9.0-p2-sampling-tests";
// Per-test fixture. The underscore-prefixed fields are held only to keep the
// temp dir, registry and tenant alive for the test's duration.
struct Harness {
_tmp: TempDir,
_registry: Arc<TenantRegistry>,
_tenant: Arc<TenantHandle>,
write_handle: solo_storage::WriteHandle,
db_path: PathBuf,
key: KeyMaterial,
}
// Initialises a data dir in a temp directory with a 32-dim stub embedder.
// It then opens the default tenant and returns its write handle, plus the
// DB path and derived key needed to query audit rows directly.
async fn harness() -> Harness {
let tmp = TempDir::new().expect("tempdir");
let data_dir = tmp.path().to_path_buf();
let _ = init(InitParams {
data_dir: data_dir.clone(),
passphrase: Zeroizing::new(TEST_PASSPHRASE.into()),
force: false,
embedder: EmbedderConfig {
name: "stub".into(),
version: "v1".into(),
dim: 32,
dtype: "f32".into(),
},
})
.expect("init");
let cfg = solo_storage::SoloConfig::read(
&data_dir.join("solo.config.toml"),
)
.expect("read cfg");
// Re-derive the DB key from the passphrase and the salt stored in config.
let key = KeyMaterial::derive(
TEST_PASSPHRASE,
&cfg.salt_bytes().expect("salt"),
)
.expect("derive key");
let embedder: Arc<dyn solo_core::Embedder> =
Arc::new(StubEmbedder::new("stub", "v1", 32));
let registry = Arc::new(
TenantRegistry::open(TenantRegistryParams {
data_dir: data_dir.clone(),
key: key.clone(),
embedder: embedder.clone(),
hnsw_params: HnswParams::default(),
steward: None,
runtime_handle: Some(tokio::runtime::Handle::current()),
steward_factory: None,
triples_batch_signal: None,
})
.expect("open registry"),
);
let tenant_id = TenantId::default_tenant();
let tenant = registry
.get_or_open(&tenant_id)
.await
.expect("get_or_open default tenant");
let write_handle = tenant.write().clone();
let db_path = tenant.db_path().to_path_buf();
Harness {
_tmp: tmp,
_registry: registry,
_tenant: tenant,
write_handle,
db_path,
key,
}
}
// Counts audit_events rows whose `operation` column equals `op`.
fn count_audit_rows(db_path: &std::path::Path, key: &KeyMaterial, op: &str) -> i64 {
let conn = open_sqlcipher(db_path, key).expect("open db");
conn.query_row(
"SELECT COUNT(*) FROM audit_events WHERE operation = ?",
rusqlite::params![op],
|r| r.get(0),
)
.expect("count")
}
// Returns (result, principal_subject, parsed details_json) for the newest
// `llm.sampling_call` audit row. Ties on ts_ms are broken by rowid.
fn latest_sampling_audit_details(
db_path: &std::path::Path,
key: &KeyMaterial,
) -> (String, Option<String>, serde_json::Value) {
let conn = open_sqlcipher(db_path, key).expect("open db");
let (result, principal, details_str): (String, Option<String>, Option<String>) = conn
.query_row(
"SELECT result, principal_subject, details_json
FROM audit_events
WHERE operation = 'llm.sampling_call'
ORDER BY ts_ms DESC, rowid DESC
LIMIT 1",
[],
|r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
)
.expect("query");
let details: serde_json::Value =
serde_json::from_str(&details_str.expect("details_json present"))
.expect("parse details");
(result, principal, details)
}
// Successful call returns the assistant text and writes one "ok" audit row
// carrying principal, hint, model, message count and max_tokens.
#[tokio::test]
async fn sampling_complete_happy_path_returns_text() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("derived theme")));
let client = SamplingLlmClient::with_sampling_client(
fake.clone(),
h.write_handle.clone(),
Some("alice".into()),
);
let messages = vec![Message::user("summarise these episodes")];
let result = client.complete(&messages).await.expect("ok");
assert_eq!(result.role, Role::Assistant);
assert_eq!(result.content, "derived theme");
assert_eq!(
count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1
);
let (result_str, principal, details) =
latest_sampling_audit_details(&h.db_path, &h.key);
assert_eq!(result_str, "ok");
assert_eq!(principal.as_deref(), Some("alice"));
assert_eq!(details["model_hint"], "claude");
assert_eq!(details["model"], "fake-claude");
assert_eq!(details["messages_count"], 1);
assert_eq!(details["max_tokens"], 512);
}
// Privacy: neither user nor system prompt text may appear in audit details.
// System messages are also excluded from messages_count.
#[tokio::test]
async fn audit_row_omits_raw_prompt_text() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
let client = SamplingLlmClient::with_sampling_client(
fake,
h.write_handle.clone(),
None,
);
let secret = "THE-USER-ID-IS-bobby-1234";
let messages = vec![
Message::system("you are a friendly assistant"),
Message::user(secret),
];
client.complete(&messages).await.expect("ok");
let (_, _, details) =
latest_sampling_audit_details(&h.db_path, &h.key);
let serialised =
serde_json::to_string(&details).expect("serialise details");
assert!(
!serialised.contains(secret),
"audit details must not carry raw prompt content; was: {serialised}"
);
assert!(
!serialised.contains("you are a friendly assistant"),
"audit details must not carry system prompt; was: {serialised}"
);
assert_eq!(details["messages_count"], 1);
assert!(details["prompt_chars"].as_u64().unwrap() > 0);
}
// prompt_chars counts system + message text (6 + 1 = 7) and buckets it to 8.
#[tokio::test]
async fn audit_row_bucket_prompt_chars_to_pow2() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
let client = SamplingLlmClient::with_sampling_client(
fake,
h.write_handle.clone(),
None,
);
client
.complete(&[Message::system("hello "), Message::user("x")])
.await
.expect("ok");
let (_, _, details) =
latest_sampling_audit_details(&h.db_path, &h.key);
assert_eq!(
details["prompt_chars"].as_u64().unwrap(),
8,
"prompt_chars must be bucketed to next pow2 (7 → 8). \
raw count is a privacy side-channel; see Fix 4 F6 in \
v0.9.1 P1 dev log. got details={details}"
);
}
// Two prompts of different length in the same bucket record equal values.
#[tokio::test]
async fn audit_row_bucket_prompt_chars_is_stable_within_bucket() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
let client = SamplingLlmClient::with_sampling_client(
fake,
h.write_handle.clone(),
None,
);
client
.complete(&[Message::user("hello")])
.await
.expect("ok");
let (_, _, details_5) =
latest_sampling_audit_details(&h.db_path, &h.key);
client
.complete(&[Message::user("hellooo")])
.await
.expect("ok");
let (_, _, details_7) =
latest_sampling_audit_details(&h.db_path, &h.key);
assert_eq!(
details_5["prompt_chars"], details_7["prompt_chars"],
"5 chars and 7 chars must hash to the same bucket (8) — \
otherwise the bucketing is leaking raw fidelity. \
5-char details: {details_5}, 7-char details: {details_7}"
);
assert_eq!(details_5["prompt_chars"].as_u64().unwrap(), 8);
}
// Table of expected buckets, including the 0 → 0 special case.
#[test]
fn next_pow2_bucket_table() {
assert_eq!(next_pow2_bucket(0), 0, "0 stays 0");
assert_eq!(next_pow2_bucket(1), 1, "1 stays 1");
assert_eq!(next_pow2_bucket(2), 2, "2 stays 2");
assert_eq!(next_pow2_bucket(3), 4, "3 rounds up to 4");
assert_eq!(next_pow2_bucket(4), 4, "4 stays 4");
assert_eq!(next_pow2_bucket(5), 8);
assert_eq!(next_pow2_bucket(6), 8, "6-char prompt (brief case) → 8");
assert_eq!(next_pow2_bucket(7), 8);
assert_eq!(next_pow2_bucket(8), 8);
assert_eq!(next_pow2_bucket(9), 16);
assert_eq!(next_pow2_bucket(1023), 1024);
assert_eq!(next_pow2_bucket(1024), 1024);
assert_eq!(next_pow2_bucket(1025), 2048);
}
// A client refusal surfaces as CoreError::Forbidden and audits "forbidden".
#[tokio::test]
async fn client_refusal_returns_forbidden_and_audits() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ignored")));
fake.reject_with("user dismissed approval");
let client = SamplingLlmClient::with_sampling_client(
fake,
h.write_handle.clone(),
Some("alice".into()),
);
let err = client
.complete(&[Message::user("anything")])
.await
.unwrap_err();
match err {
CoreError::Forbidden(_) => {}
other => panic!("expected Forbidden, got {other:?}"),
}
let (result_str, _, details) =
latest_sampling_audit_details(&h.db_path, &h.key);
assert_eq!(result_str, "forbidden");
assert_eq!(details["reason"], "client_refused");
}
// A response slower than the configured timeout yields an Llm error and an
// "error" audit row with reason "timeout".
#[tokio::test]
async fn timeout_returns_error_with_timeout_reason() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::slow(
"late",
Duration::from_millis(800),
)));
let client = SamplingLlmClient::with_sampling_client(
fake,
h.write_handle.clone(),
None,
)
.with_timeout(Duration::from_millis(30));
let err = client
.complete(&[Message::user("hello")])
.await
.unwrap_err();
match err {
CoreError::Llm(msg) => assert!(msg.contains("timeout")),
other => panic!("expected Llm, got {other:?}"),
}
let (result_str, _, details) =
latest_sampling_audit_details(&h.db_path, &h.key);
assert_eq!(result_str, "error");
assert_eq!(details["reason"], "timeout");
}
// A response without text content is audited as "malformed_response".
#[tokio::test]
async fn malformed_response_returns_error_with_reason() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::EmptyContent));
let client = SamplingLlmClient::with_sampling_client(
fake,
h.write_handle.clone(),
None,
);
let err = client
.complete(&[Message::user("hi")])
.await
.unwrap_err();
assert!(matches!(err, CoreError::Llm(_)));
let (result_str, _, details) =
latest_sampling_audit_details(&h.db_path, &h.key);
assert_eq!(result_str, "error");
assert_eq!(details["reason"], "malformed_response");
}
// With no audit principal configured, principal_subject is stored as NULL.
#[tokio::test]
async fn no_principal_emits_audit_with_null_principal() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
let client = SamplingLlmClient::with_sampling_client(
fake,
h.write_handle.clone(),
None,
);
client.complete(&[Message::user("hi")]).await.expect("ok");
let (_, principal, _) =
latest_sampling_audit_details(&h.db_path, &h.key);
assert_eq!(principal, None);
}
// Eight concurrent calls on cloned clients each land an audit row and an RPC.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn parallel_completes_serialise_audit_rows() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
let client = SamplingLlmClient::with_sampling_client(
fake.clone(),
h.write_handle.clone(),
Some("alice".into()),
);
let mut futs = Vec::new();
for _ in 0..8 {
let c = client.clone();
futs.push(tokio::spawn(async move {
c.complete(&[Message::user("hi")]).await
}));
}
for f in futs {
f.await.expect("join").expect("ok");
}
assert_eq!(
count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
8,
"8 parallel calls must land 8 audit rows"
);
assert_eq!(fake.record_requests().len(), 8);
}
// System messages go to system_prompt; user/assistant stay in order.
#[tokio::test]
async fn build_request_splits_system_from_messages() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
let client = SamplingLlmClient::with_sampling_client(
fake.clone(),
h.write_handle.clone(),
None,
);
client
.complete(&[
Message::system("be terse"),
Message::user("question"),
Message::assistant("answer"),
])
.await
.expect("ok");
let recorded = fake.record_requests();
assert_eq!(recorded.len(), 1);
let req = &recorded[0];
assert_eq!(
req.system_prompt.as_deref(),
Some("be terse"),
"Role::System must map to system_prompt"
);
assert_eq!(req.messages.len(), 2);
assert_eq!(req.messages[0].role, RmcpRole::User);
assert_eq!(req.messages[1].role, RmcpRole::Assistant);
}
// Outgoing requests carry the "claude" model hint as the first hint.
#[tokio::test]
async fn build_request_includes_claude_model_hint() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
let client = SamplingLlmClient::with_sampling_client(
fake.clone(),
h.write_handle.clone(),
None,
);
client
.complete(&[Message::user("hi")])
.await
.expect("ok");
let recorded = fake.record_requests();
let prefs = recorded[0].model_preferences.as_ref().expect("prefs");
let hint = prefs
.hints
.as_ref()
.and_then(|h| h.first())
.and_then(|h| h.name.clone())
.expect("hint name");
assert_eq!(hint, "claude");
}
// with_max_tokens changes the max_tokens sent on the wire.
#[tokio::test]
async fn with_max_tokens_overrides_default() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
let client = SamplingLlmClient::with_sampling_client(
fake.clone(),
h.write_handle.clone(),
None,
)
.with_max_tokens(2048);
client
.complete(&[Message::user("hi")])
.await
.expect("ok");
let recorded = fake.record_requests();
assert_eq!(recorded[0].max_tokens, 2048);
}
// Reconfiguring the fake mid-test yields an "ok" row then a "forbidden" row.
#[tokio::test]
async fn reconfigurable_fake_distinguishes_audit_rows() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
let client = SamplingLlmClient::with_sampling_client(
fake.clone(),
h.write_handle.clone(),
Some("alice".into()),
);
client.complete(&[Message::user("a")]).await.expect("ok");
fake.reject_with("user said no");
let _ = client.complete(&[Message::user("b")]).await;
let conn = open_sqlcipher(&h.db_path, &h.key).expect("open");
let mut stmt = conn
.prepare(
"SELECT result FROM audit_events WHERE operation = 'llm.sampling_call' ORDER BY ts_ms ASC, rowid ASC",
)
.expect("prepare");
let rows: Vec<String> = stmt
.query_map([], |r| r.get::<_, String>(0))
.expect("query")
.map(|r| r.expect("row"))
.collect();
assert_eq!(rows, vec!["ok".to_string(), "forbidden".to_string()]);
}
// extract_text: a single text block is returned as-is.
#[test]
fn extract_text_pulls_text_from_single_block() {
let result = CreateMessageResult::new(
SamplingMessage::assistant_text("hello"),
"fake".into(),
);
assert_eq!(extract_text(&result).unwrap(), "hello");
}
// extract_text: no content blocks at all is an error.
#[test]
fn extract_text_rejects_empty_content() {
let result = CreateMessageResult::new(
SamplingMessage::new_multiple(RmcpRole::Assistant, Vec::new()),
"fake".into(),
);
assert!(extract_text(&result).is_err());
}
// extract_text: a non-assistant role is an error even with text present.
#[test]
fn extract_text_rejects_non_assistant_role() {
let result = CreateMessageResult::new(
SamplingMessage::user_text("hello"),
"fake".into(),
);
assert!(extract_text(&result).is_err());
}
// extract_text: multiple blocks are joined by a single '\n'.
#[test]
fn extract_text_joins_multi_block_with_newline_separator() {
let blocks = vec![
SamplingMessageContent::text("abc"),
SamplingMessageContent::text("def"),
];
let result = CreateMessageResult::new(
SamplingMessage::new_multiple(RmcpRole::Assistant, blocks),
"fake".into(),
);
assert_eq!(
extract_text(&result).unwrap(),
"abc\ndef",
"two non-newline-terminated blocks must join with a single newline"
);
}
// extract_text: block content is not trimmed before joining.
#[test]
fn extract_text_preserves_trailing_newlines_in_blocks() {
let blocks = vec![
SamplingMessageContent::text("abc\n"),
SamplingMessageContent::text("def"),
];
let result = CreateMessageResult::new(
SamplingMessage::new_multiple(RmcpRole::Assistant, blocks),
"fake".into(),
);
assert_eq!(
extract_text(&result).unwrap(),
"abc\n\ndef",
"trailing newline in block 1 + join newline => '\\n\\n' between blocks"
);
}
// extract_text: a single block with inner newlines is returned verbatim.
#[test]
fn extract_text_single_block_returns_verbatim_including_inner_newlines() {
let blocks = vec![SamplingMessageContent::text("line1\nline2")];
let result = CreateMessageResult::new(
SamplingMessage::new_multiple(RmcpRole::Assistant, blocks),
"fake".into(),
);
assert_eq!(
extract_text(&result).unwrap(),
"line1\nline2",
"single block must return verbatim, no extra newlines added"
);
}
// extract_text: an empty block between non-empty ones yields a blank line.
#[test]
fn extract_text_empty_middle_block_inserts_blank_line() {
let blocks = vec![
SamplingMessageContent::text("a"),
SamplingMessageContent::text(""),
SamplingMessageContent::text("b"),
];
let result = CreateMessageResult::new(
SamplingMessage::new_multiple(RmcpRole::Assistant, blocks),
"fake".into(),
);
assert_eq!(
extract_text(&result).unwrap(),
"a\n\nb",
"empty middle block leaves a blank line between non-empty blocks"
);
}
// classify: each fake error variant maps to its reason code / forbidden flag.
#[test]
fn sampling_error_classify_maps_fake_variants() {
let refused = SamplingError::Fake(FakeSamplingError::Refused {
reason: "x".into(),
});
let (cat, forb) = refused.classify();
assert_eq!(cat, "client_refused");
assert!(forb);
let transport = SamplingError::Fake(FakeSamplingError::Transport {
message: "x".into(),
});
let (cat, forb) = transport.classify();
assert_eq!(cat, "transport_error");
assert!(!forb);
let malformed =
SamplingError::Fake(FakeSamplingError::MalformedResponse {
message: "x".into(),
});
let (cat, forb) = malformed.classify();
assert_eq!(cat, "malformed_response");
assert!(!forb);
}
// A call routed through SamplingCoordinator still returns text and writes
// exactly one audit row.
#[tokio::test]
async fn sampling_llm_client_uses_coordinator_in_production_path() {
let h = harness().await;
let fake: Arc<dyn SamplingClient> =
Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
let coord: Arc<dyn SamplingClient> =
super::super::SamplingCoordinator::with_settings(
fake.clone(),
Duration::from_millis(50),
10,
);
let client = SamplingLlmClient::with_sampling_client(
coord,
h.write_handle.clone(),
Some("alice".into()),
);
let result = client
.complete(&[Message::user("test")])
.await
.expect("ok");
assert_eq!(result.role, Role::Assistant);
assert_eq!(result.content, "ok");
assert_eq!(
count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1,
"one logical call → one audit row, even through coordinator"
);
}
// Five concurrent calls within the coalesce window share one inner RPC. The
// fake returns a JSON array of per-task responses, but each logical call
// still writes its own audit row.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn coordinator_coalesces_concurrent_calls_into_one_inner_rpc() {
let response = serde_json::to_string(&(0..5)
.map(|i| serde_json::json!({
"task_index": i,
"response": format!("response-{i}"),
}))
.collect::<Vec<_>>())
.unwrap();
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
let coord: Arc<dyn SamplingClient> =
super::super::SamplingCoordinator::with_settings(
fake.clone(),
Duration::from_secs(5),
10,
);
let client = SamplingLlmClient::with_sampling_client(
coord,
h.write_handle.clone(),
Some("alice".into()),
);
let mut futs = Vec::new();
for i in 0..5 {
let c = client.clone();
futs.push(tokio::spawn(async move {
c.complete(&[Message::user(format!("task-{i}"))]).await
}));
}
for f in futs {
f.await.expect("join").expect("ok");
}
assert_eq!(
fake.record_requests().len(),
1,
"5 logical calls within window must coalesce to 1 inner RPC"
);
assert_eq!(
count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
5,
"5 logical calls → 5 audit rows (coordinator doesn't merge audits)"
);
}
// With a max batch of 1 the coordinator forwards every call as its own RPC.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn coordinator_max_batch_one_acts_as_passthrough() {
let h = harness().await;
let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
let coord: Arc<dyn SamplingClient> =
super::super::SamplingCoordinator::with_settings(
fake.clone(),
Duration::from_secs(5),
1,
);
let client = SamplingLlmClient::with_sampling_client(
coord,
h.write_handle.clone(),
None,
);
let mut futs = Vec::new();
for _ in 0..3 {
let c = client.clone();
futs.push(tokio::spawn(async move {
c.complete(&[Message::user("hi")]).await
}));
}
for f in futs {
f.await.expect("join").expect("ok");
}
assert_eq!(
fake.record_requests().len(),
3,
"max_batch=1 must pass through every submission as its own RPC"
);
assert_eq!(
count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
3
);
}
}