use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use rmcp::model::{
CreateMessageRequestParams, CreateMessageResult, Role as RmcpRole, SamplingMessage,
SamplingMessageContent,
};
use tokio::sync::{Mutex, mpsc, oneshot};
use crate::llm::sampling::{SamplingClient, SamplingError};
/// Coalescing window handed to the worker by [`SamplingCoordinator::new`]:
/// the longest the worker keeps collecting further submissions after the
/// first one of a batch has arrived (5 seconds).
pub const DEFAULT_COALESCE_WINDOW: Duration = Duration::from_secs(5);

/// Batch-size cap handed to the worker by [`SamplingCoordinator::new`]: a
/// batch is flushed as soon as it holds this many submissions.
pub const DEFAULT_COALESCE_MAX_BATCH: usize = 10;
/// Handle that queues sampling requests for a single background worker task,
/// which groups requests arriving close together and forwards them to the
/// wrapped [`SamplingClient`].
pub struct SamplingCoordinator {
/// Sending half of the bounded queue that the background worker drains.
tx: mpsc::Sender<Submission>,
/// Join handle of the background worker; `shutdown` takes it out, leaving `None`.
worker: Mutex<Option<tokio::task::JoinHandle<()>>>,
}
impl SamplingCoordinator {
    /// Creates a coordinator around `inner` using
    /// [`DEFAULT_COALESCE_WINDOW`] and [`DEFAULT_COALESCE_MAX_BATCH`].
    pub fn new(inner: Arc<dyn SamplingClient>) -> Arc<Self> {
        Self::with_settings(inner, DEFAULT_COALESCE_WINDOW, DEFAULT_COALESCE_MAX_BATCH)
    }

    /// Creates a coordinator with an explicit coalescing `window` and batch cap.
    ///
    /// A `max_batch` of `0` is treated as `1`. The submission queue is sized
    /// at twice the effective batch cap plus 16 slots. The worker task is
    /// started with `tokio::spawn`, so this must be called from inside a
    /// Tokio runtime.
    pub fn with_settings(
        inner: Arc<dyn SamplingClient>,
        window: Duration,
        max_batch: usize,
    ) -> Arc<Self> {
        let batch_limit = max_batch.max(1);
        let queue_capacity = batch_limit * 2 + 16;
        let (tx, rx) = mpsc::channel::<Submission>(queue_capacity);
        let handle = tokio::spawn(coordinator_worker(rx, inner, window, batch_limit));
        Arc::new(Self {
            tx,
            worker: Mutex::new(Some(handle)),
        })
    }

    /// Queues `params` for the worker and waits for the routed result.
    ///
    /// # Errors
    ///
    /// Returns an internal-error [`SamplingError::Service`] when the worker's
    /// queue is closed or when the worker drops the reply channel without
    /// answering. Otherwise returns whatever result the worker sent back.
    pub async fn submit(
        &self,
        params: CreateMessageRequestParams,
    ) -> Result<CreateMessageResult, SamplingError> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let submission = Submission {
            params,
            reply: reply_tx,
        };
        if self.tx.send(submission).await.is_err() {
            return Err(Self::worker_failure(
                "sampling coordinator worker is gone (channel closed)",
            ));
        }
        match reply_rx.await {
            Ok(outcome) => outcome,
            Err(_) => Err(Self::worker_failure(
                "sampling coordinator worker dropped reply channel",
            )),
        }
    }

    /// Aborts the background worker and waits for the task to finish.
    ///
    /// The worker lock stays held while joining, so a concurrent `shutdown`
    /// call only returns after the first one has completed. Calling it again
    /// afterwards is a no-op because the handle has already been taken.
    pub async fn shutdown(self: Arc<Self>) {
        let mut slot = self.worker.lock().await;
        match slot.take() {
            Some(handle) => {
                handle.abort();
                // The join result (normally a cancellation error) is not interesting.
                let _ = handle.await;
            }
            None => {}
        }
    }

    /// Builds the internal-error value reported when the worker cannot answer.
    fn worker_failure(message: &'static str) -> SamplingError {
        SamplingError::Service(rmcp::service::ServiceError::McpError(
            rmcp::model::ErrorData::internal_error(message, None),
        ))
    }
}
/// One queued sampling request together with the channel used to answer it.
struct Submission {
/// The caller's original request parameters.
params: CreateMessageRequestParams,
/// One-shot channel on which the worker sends this request's result.
reply: oneshot::Sender<Result<CreateMessageResult, SamplingError>>,
}
/// Background loop that groups queued submissions into batches and flushes them.
///
/// A batch opens when the first submission arrives and is flushed when it
/// holds `max_batch` submissions, when `window` has elapsed since it opened,
/// or when the queue is closed, whichever comes first. The loop ends once the
/// queue is closed; a partially filled batch is flushed before exiting.
/// Batches are flushed one at a time: the next batch is not opened until the
/// current flush has finished.
async fn coordinator_worker(
    mut rx: mpsc::Receiver<Submission>,
    inner: Arc<dyn SamplingClient>,
    window: Duration,
    max_batch: usize,
) {
    while let Some(head) = rx.recv().await {
        let mut pending: Vec<Submission> = Vec::from([head]);
        let flush_by = tokio::time::Instant::now() + window;
        let mut queue_closed = false;

        while pending.len() < max_batch {
            let budget = flush_by.saturating_duration_since(tokio::time::Instant::now());
            if budget.is_zero() {
                break;
            }
            match tokio::time::timeout(budget, rx.recv()).await {
                Ok(Some(next)) => pending.push(next),
                Ok(None) => {
                    // All senders are gone: flush what we have, then stop.
                    queue_closed = true;
                    break;
                }
                // Window elapsed while waiting for another submission.
                Err(_) => break,
            }
        }

        flush_batch(&inner, pending).await;
        if queue_closed {
            return;
        }
    }
}
async fn flush_batch(inner: &Arc<dyn SamplingClient>, batch: Vec<Submission>) {
if batch.is_empty() {
return;
}
if batch.len() == 1 {
let mut iter = batch.into_iter();
let s = iter.next().unwrap();
let result = inner.create_message(s.params).await;
let _ = s.reply.send(result);
return;
}
let coalesced = build_coalesced_request(&batch);
let result = inner.create_message(coalesced).await;
match result {
Ok(rendered) => {
match demux_coalesced(&rendered, &batch) {
Ok(per_task) => {
for (sub, task_result) in batch.into_iter().zip(per_task) {
let _ = sub.reply.send(task_result);
}
}
Err(parse_err) => {
let err_msg = format!(
"sampling coordinator: failed to parse coalesced response: {parse_err}"
);
for sub in batch {
let _ = sub.reply.send(Err(SamplingError::Service(
rmcp::service::ServiceError::McpError(
rmcp::model::ErrorData::internal_error(err_msg.clone(), None),
),
)));
}
}
}
}
Err(e) => {
let err_msg = format!("{e}");
for sub in batch {
let _ = sub.reply.send(Err(SamplingError::Service(
rmcp::service::ServiceError::McpError(
rmcp::model::ErrorData::internal_error(
format!("sampling coordinator: coalesced RPC failed: {err_msg}"),
None,
),
),
)));
}
}
}
}
/// Merges every submission in `batch` into a single request.
///
/// The merged request consists of:
/// * a system prompt that instructs the model to answer with a JSON array of
///   `{ "task_index", "response" }` objects, followed by each task's own
///   system prompt (when it has one) labelled `Task-<idx> sub-system`;
/// * one user message whose text is the JSON document `{ "tasks": [...] }`,
///   each task carrying its `task_index` and its messages as
///   `{ "role", "content" }` pairs;
/// * `max_tokens` equal to the saturating sum of the tasks' budgets (at least 1);
/// * the model preferences of the first submission, if it has any.
///
/// Only text content blocks are carried over; other content kinds are left
/// out of the merged payload. An empty `batch` yields a request with an empty
/// task list instead of panicking.
fn build_coalesced_request(batch: &[Submission]) -> CreateMessageRequestParams {
    let mut system_parts: Vec<String> = vec![
        "You are a batch task processor. Process EVERY task listed in the \
         user message and reply with a JSON array of objects where each \
         object has shape: { \"task_index\": <int starting from 0>, \
         \"response\": \"<string>\" }. The array MUST have exactly N entries \
         (one per task) in the SAME ORDER. Do NOT include any prose outside \
         the JSON."
            .to_string(),
    ];
    let mut tasks: Vec<serde_json::Value> = Vec::with_capacity(batch.len());

    for (idx, sub) in batch.iter().enumerate() {
        if let Some(sys) = sub.params.system_prompt.as_ref() {
            system_parts.push(format!("Task-{idx} sub-system: {sys}"));
        }

        let task_messages: Vec<serde_json::Value> = sub
            .params
            .messages
            .iter()
            .map(|sm| {
                let role_str = match sm.role {
                    RmcpRole::User => "user",
                    RmcpRole::Assistant => "assistant",
                };
                // Text blocks of one message are joined with newlines.
                let text = sm
                    .content
                    .iter()
                    .filter_map(|content| match content {
                        SamplingMessageContent::Text(t) => Some(t.text.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<&str>>()
                    .join("\n");
                serde_json::json!({
                    "role": role_str,
                    "content": text,
                })
            })
            .collect();

        tasks.push(serde_json::json!({
            "task_index": idx,
            "messages": task_messages,
        }));
    }

    let user_payload = serde_json::json!({ "tasks": tasks }).to_string();
    let max_tokens = batch
        .iter()
        .map(|s| s.params.max_tokens)
        .fold(0u32, |acc, n| acc.saturating_add(n));

    let mut params = CreateMessageRequestParams::new(
        vec![SamplingMessage::user_text(&user_payload)],
        max_tokens.max(1),
    );
    params = params.with_system_prompt(system_parts.join("\n\n"));

    // `first()` instead of `batch[0]`: an empty slice must not panic here.
    if let Some(prefs) = batch
        .first()
        .and_then(|sub| sub.params.model_preferences.as_ref())
    {
        params = params.with_model_preferences(prefs.clone());
    }
    params
}
/// Splits the reply to a coalesced request into one result per submission.
///
/// The reply text must be a JSON array, either bare or inside a fenced code
/// block. For every position `idx` in `batch`, the first array entry whose
/// `task_index` equals `idx` supplies that submission's result:
///
/// * a string `response` becomes the assistant text;
/// * any other non-null JSON `response` is passed on as its JSON text, so the
///   submitter receives the data instead of an empty reply;
/// * a missing or `null` `response`, or no matching entry at all, yields an
///   internal error for that submission only.
///
/// The returned vector has exactly `batch.len()` elements, in batch order.
///
/// # Errors
///
/// Returns a description of the failure when the reply has no usable text,
/// is not valid JSON (bare or fenced), or its root is not an array. In that
/// case no per-task results are produced.
fn demux_coalesced(
    rendered: &CreateMessageResult,
    batch: &[Submission],
) -> Result<Vec<Result<CreateMessageResult, SamplingError>>, String> {
    // Per-task failures all use the same internal-error shape.
    let task_error = |message: String| {
        SamplingError::Service(rmcp::service::ServiceError::McpError(
            rmcp::model::ErrorData::internal_error(message, None),
        ))
    };

    let text = extract_text_from_result(rendered).map_err(|e| e.to_string())?;
    let parsed: serde_json::Value = match serde_json::from_str(&text) {
        Ok(value) => value,
        // Models sometimes wrap the array in a Markdown fence; try that before giving up.
        Err(top_err) => match extract_fenced_json(&text) {
            Some(inner) => {
                serde_json::from_str(inner).map_err(|fe| format!("fenced parse: {fe}"))?
            }
            None => return Err(format!("top-level JSON parse: {top_err}")),
        },
    };
    let entries = parsed
        .as_array()
        .ok_or_else(|| "response root is not a JSON array".to_string())?;

    let out: Vec<Result<CreateMessageResult, SamplingError>> = (0..batch.len())
        .map(|idx| {
            // Compare as u64: widening `usize` is lossless, whereas casting the
            // model-supplied integer down to `usize` could wrap or truncate.
            let wanted = idx as u64;
            let entry = entries
                .iter()
                .find(|e| e.get("task_index").and_then(|v| v.as_u64()) == Some(wanted));

            match entry {
                None => Err(task_error(format!(
                    "sampling coordinator: response missing task_index {idx}"
                ))),
                Some(e) => match e.get("response") {
                    Some(serde_json::Value::String(s)) => {
                        Ok(make_assistant_result(s, &rendered.model))
                    }
                    None | Some(serde_json::Value::Null) => Err(task_error(format!(
                        "sampling coordinator: task_index {idx} has no response"
                    ))),
                    Some(other) => Ok(make_assistant_result(
                        &other.to_string(),
                        &rendered.model,
                    )),
                },
            }
        })
        .collect();
    Ok(out)
}
/// Returns the trimmed contents of a Markdown fenced code block found in `text`.
///
/// A block opened with "```json" is preferred: the first such block is
/// returned, or `None` if it is never closed. When `text` contains no
/// "```json" fence at all, the first fence of any kind is used instead, and a
/// first line that is empty or a plain alphanumeric language tag (for example
/// `JSON`) is dropped from the result.
///
/// Returns `None` when no complete fenced block is found.
fn extract_fenced_json(text: &str) -> Option<&str> {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    if let Some(start) = text.find(JSON_FENCE) {
        let after = &text[start + JSON_FENCE.len()..];
        let end = after.find(FENCE)?;
        return Some(after[..end].trim());
    }

    // Fallback: a bare fence, or one whose tag is not exactly lowercase "json".
    let start = text.find(FENCE)?;
    let after = &text[start + FENCE.len()..];
    let end = after.find(FENCE)?;
    let body = &after[..end];
    let body = match body.split_once('\n') {
        // The opening line holds only an optional language tag, not payload.
        Some((tag, rest)) if tag.trim().chars().all(|c| c.is_ascii_alphanumeric()) => rest,
        _ => body,
    };
    Some(body.trim())
}
/// Concatenates the text content blocks of an assistant reply.
///
/// A newline is inserted before a block whenever the text gathered so far is
/// non-empty. Content blocks that are not text are skipped.
///
/// # Errors
///
/// Returns a static description when the reply's role is not `Assistant`, or
/// when the gathered text is empty.
fn extract_text_from_result(result: &CreateMessageResult) -> Result<String, &'static str> {
    if result.message.role != RmcpRole::Assistant {
        return Err("response role was not Assistant");
    }

    let joined = result
        .message
        .content
        .iter()
        .filter_map(|content| match content {
            SamplingMessageContent::Text(text) => Some(text.text.as_str()),
            _ => None,
        })
        .fold(String::new(), |mut acc, piece| {
            if !acc.is_empty() {
                acc.push('\n');
            }
            acc.push_str(piece);
            acc
        });

    match joined.is_empty() {
        true => Err("no text content blocks"),
        false => Ok(joined),
    }
}
/// Builds a result whose message is an assistant text message containing
/// `text`, attributed to `model`.
fn make_assistant_result(text: &str, model: &str) -> CreateMessageResult {
    let message = SamplingMessage::assistant_text(text.to_owned());
    CreateMessageResult::new(message, model.to_owned())
}
/// Lets a `SamplingCoordinator` be used wherever a `SamplingClient` is expected.
#[async_trait]
impl SamplingClient for SamplingCoordinator {
/// Delegates to [`SamplingCoordinator::submit`]; the request goes through
/// the coordinator's queue instead of being sent directly.
async fn create_message(
&self,
params: CreateMessageRequestParams,
) -> Result<CreateMessageResult, SamplingError> {
self.submit(params).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{FakeMcpClient, FakeResponse};

    /// Builds a minimal request holding a single user text message.
    fn mk_params(prompt: &str) -> CreateMessageRequestParams {
        CreateMessageRequestParams::new(vec![SamplingMessage::user_text(prompt)], 128)
    }

    /// Renders a well-formed coalesced reply: task `i` gets `response-<i>`.
    fn coalesced_response_for(n_tasks: usize) -> String {
        let mut arr = Vec::with_capacity(n_tasks);
        for i in 0..n_tasks {
            arr.push(serde_json::json!({
                "task_index": i,
                "response": format!("response-{i}"),
            }));
        }
        serde_json::to_string(&arr).unwrap()
    }

    /// Three submissions inside one window must produce exactly one inner call.
    #[tokio::test]
    async fn coalesces_n_concurrent_submissions_into_one_create_message_call() {
        let response = coalesced_response_for(3);
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(100),
            10,
        );
        let c1 = coord.clone();
        let c2 = coord.clone();
        let c3 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("task-A")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("task-B")).await });
        let h3 = tokio::spawn(async move { c3.submit(mk_params("task-C")).await });
        let r1 = h1.await.unwrap().expect("submission 1 ok");
        let r2 = h2.await.unwrap().expect("submission 2 ok");
        let r3 = h3.await.unwrap().expect("submission 3 ok");
        let recorded = fake.record_requests();
        assert_eq!(
            recorded.len(),
            1,
            "coordinator must coalesce 3 submissions into 1 inner call"
        );
        assert_eq!(
            extract_text_from_result(&r1).unwrap(),
            "response-0",
            "task-0 response routed to first submission"
        );
        assert_eq!(extract_text_from_result(&r2).unwrap(), "response-1");
        assert_eq!(extract_text_from_result(&r3).unwrap(), "response-2");
    }

    /// Reaching `max_batch` must flush immediately instead of waiting out the window.
    #[tokio::test]
    async fn flushes_at_max_batch_before_window_expires() {
        let response = coalesced_response_for(2);
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_secs(5),
            2,
        );
        let started = tokio::time::Instant::now();
        let c1 = coord.clone();
        let c2 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("task-A")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("task-B")).await });
        let _ = h1.await.unwrap();
        let _ = h2.await.unwrap();
        let elapsed = started.elapsed();
        assert!(
            elapsed < Duration::from_secs(2),
            "max_batch must flush before window expires; took {elapsed:?}"
        );
        assert_eq!(fake.record_requests().len(), 1);
    }

    /// A batch of one is forwarded as is, without the batch wrapper prompt.
    #[tokio::test]
    async fn single_submission_passes_through_unwrapped() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(
            "direct-response",
        )));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(50),
            10,
        );
        let result = coord
            .submit(mk_params("lonely-task"))
            .await
            .expect("submission ok");
        let recorded = fake.record_requests();
        assert_eq!(recorded.len(), 1);
        let inner_text = extract_first_user_text(&recorded[0]);
        assert_eq!(
            inner_text, "lonely-task",
            "single-batch path must NOT wrap the prompt"
        );
        assert_eq!(extract_text_from_result(&result).unwrap(), "direct-response");
    }

    /// Submissions separated by more than the window go out as separate calls.
    #[tokio::test]
    async fn window_expiry_flushes_each_submission_individually() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("r-first")));
        fake.respond_each(vec![
            FakeResponse::text("r-first"),
            FakeResponse::text("r-second"),
        ]);
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(20),
            10,
        );
        let r1 = coord
            .submit(mk_params("first"))
            .await
            .expect("submission 1");
        tokio::time::sleep(Duration::from_millis(50)).await;
        let r2 = coord
            .submit(mk_params("second"))
            .await
            .expect("submission 2");
        assert_eq!(fake.record_requests().len(), 2);
        assert_eq!(extract_text_from_result(&r1).unwrap(), "r-first");
        assert_eq!(extract_text_from_result(&r2).unwrap(), "r-second");
    }

    /// A task missing from the coalesced reply fails alone; the others succeed.
    #[tokio::test]
    async fn demux_propagates_per_request_failures() {
        let response = serde_json::json!([
            { "task_index": 0, "response": "ok-0" },
            { "task_index": 2, "response": "ok-2" },
        ])
        .to_string();
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(100),
            10,
        );
        let c1 = coord.clone();
        let c2 = coord.clone();
        let c3 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("t0")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("t1")).await });
        let h3 = tokio::spawn(async move { c3.submit(mk_params("t2")).await });
        let r1 = h1.await.unwrap();
        let r2 = h2.await.unwrap();
        let r3 = h3.await.unwrap();
        assert!(r1.is_ok());
        assert!(r2.is_err(), "missing task_index must surface as error");
        assert!(r3.is_ok());
    }

    /// When the coalesced call itself fails, every submitter sees an error.
    #[tokio::test]
    async fn coalesced_rpc_failure_surfaces_to_every_submission() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::Error(
            crate::test_support::FakeSamplingError::Transport {
                message: "simulated transport failure".into(),
            },
        )));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(100),
            10,
        );
        let c1 = coord.clone();
        let c2 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("a")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("b")).await });
        assert!(h1.await.unwrap().is_err());
        assert!(h2.await.unwrap().is_err());
    }

    /// Returns the first text block of the first user message that has one,
    /// or an empty string when there is none.
    fn extract_first_user_text(params: &CreateMessageRequestParams) -> String {
        for m in &params.messages {
            if m.role == RmcpRole::User {
                for c in m.content.iter() {
                    if let SamplingMessageContent::Text(t) = c {
                        return t.text.clone();
                    }
                }
            }
        }
        String::new()
    }
}