use crate::sampling::{SamplingClient, SamplingError, SamplingFuture};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::sync::oneshot;
/// Upper bound on how long a `sampling/createMessage` request sent to the MCP
/// client may wait for its reply before failing with a timeout (three minutes).
pub(super) const SAMPLING_REVERSE_TIMEOUT: Duration = Duration::from_secs(3 * 60);
/// Outbound JSON-RPC request channel paired with a table of waiters keyed by
/// request id.
///
/// Clones share the same waiter table and id counter (both are `Arc`-wrapped);
/// each clone holds its own handle to the outgoing sender.
#[derive(Clone)]
pub(super) struct StdSamplingChannel {
    /// Waiters for in-flight requests, keyed by request id string; each one is
    /// completed with the raw response message.
    pending: Arc<StdMutex<HashMap<String, oneshot::Sender<Value>>>>,
    /// Counter used to mint unique request ids.
    next_id: Arc<AtomicU64>,
    /// Sender for outgoing JSON-RPC messages (presumably drained by the stdio
    /// writer — confirm against the transport setup).
    outgoing: std::sync::mpsc::Sender<Value>,
}
impl StdSamplingChannel {
    /// Creates a channel that writes requests to `outgoing`, starting with an
    /// empty waiter table and the id counter at zero.
    pub fn new(outgoing: std::sync::mpsc::Sender<Value>) -> Self {
        let pending = Arc::new(StdMutex::new(HashMap::new()));
        let next_id = Arc::new(AtomicU64::new(0));
        Self {
            outgoing,
            pending,
            next_id,
        }
    }

    /// Hands `message` to the waiter registered under `id_key`, removing it.
    ///
    /// Returns `true` if a waiter was registered for that id and `false`
    /// otherwise. A waiter whose receiver has already been dropped still
    /// counts as routed; the failed send is ignored.
    ///
    /// # Panics
    ///
    /// Panics if the waiter table's mutex is poisoned.
    pub fn route_response(&self, id_key: &str, message: Value) -> bool {
        // The guard stays alive until the end of the function, so the send
        // below happens while the table is still locked.
        let mut waiters = self.pending.lock().expect("pending map poisoned");
        let Some(sender) = waiters.remove(id_key) else {
            return false;
        };
        let _ = sender.send(message);
        true
    }

    /// Mints the next request id (`spool-sampling-0`, `spool-sampling-1`, ...).
    fn next_request_id(&self) -> String {
        // `fetch_add` is atomic under any ordering; only uniqueness matters
        // here, so `Relaxed` suffices.
        format!(
            "spool-sampling-{n}",
            n = self.next_id.fetch_add(1, Ordering::Relaxed)
        )
    }
}
/// `SamplingClient` that issues `sampling/createMessage` requests to the MCP
/// client through a [`StdSamplingChannel`].
pub(super) struct McpSamplingClient {
    /// Sampling-support flag supplied at construction.
    supports: bool,
    /// Transport used to send requests and receive routed responses.
    channel: StdSamplingChannel,
}

impl McpSamplingClient {
    /// Wraps `channel`, recording whether the peer supports sampling.
    pub fn new(channel: StdSamplingChannel, supports: bool) -> Self {
        Self { supports, channel }
    }
}
impl SamplingClient for McpSamplingClient {
fn is_available(&self) -> bool {
self.supports
}
fn create_message<'a>(&'a self, prompt: &'a str) -> SamplingFuture<'a> {
let channel = self.channel.clone();
let prompt = prompt.to_string();
Box::pin(async move {
let id = channel.next_request_id();
let request = json!({
"jsonrpc": "2.0",
"id": id,
"method": "sampling/createMessage",
"params": {
"messages": [{
"role": "user",
"content": {"type": "text", "text": prompt}
}],
"maxTokens": 800,
"systemPrompt": "You are a memory-extraction assistant. Reply with strict JSON only.",
"includeContext": "none"
}
});
let (tx, rx) = oneshot::channel();
channel
.pending
.lock()
.expect("pending map poisoned")
.insert(id.clone(), tx);
channel
.outgoing
.send(request)
.map_err(|err| SamplingError::Other(format!("stdio writer closed: {err}")))?;
let response = match tokio::time::timeout(SAMPLING_REVERSE_TIMEOUT, rx).await {
Ok(Ok(value)) => value,
Ok(Err(_canceled)) => {
channel
.pending
.lock()
.expect("pending map poisoned")
.remove(&id);
return Err(SamplingError::Other(
"sampling waiter dropped before response".into(),
));
}
Err(_elapsed) => {
channel
.pending
.lock()
.expect("pending map poisoned")
.remove(&id);
return Err(SamplingError::Timeout);
}
};
extract_sampling_text(response)
})
}
}
/// Extracts the generated text from a `sampling/createMessage` JSON-RPC response.
///
/// A non-null `error` member takes precedence over any `result`. Otherwise the
/// `result` object is probed for text in this order:
/// 1. `{"text": "..."}`
/// 2. `{"content": {"text": "..."}}`
/// 3. `{"content": [{"text": "..."}, ...]}`, where the string `text` fields
///    are concatenated and pieces without one are skipped.
///
/// # Errors
///
/// - `SamplingError::Unavailable` for error code `-32601`.
/// - `SamplingError::Rejected` when the error message (case-insensitively)
///   contains `reject`, `declin`, or `denied`.
/// - `SamplingError::Other` for any other error object, a missing `result`,
///   or a result from which no non-empty text could be extracted.
pub(super) fn extract_sampling_text(response: Value) -> Result<String, SamplingError> {
    // JSON-RPC 2.0 says `error` must be absent on success, but some peers send
    // `"error": null` next to a valid result; only a non-null value is an error.
    if let Some(err_obj) = response.get("error").filter(|err| !err.is_null()) {
        let code = err_obj.get("code").and_then(Value::as_i64).unwrap_or(0);
        let message = err_obj
            .get("message")
            .and_then(Value::as_str)
            .unwrap_or("unknown sampling error");
        // -32601 is JSON-RPC "Method not found"; mapped to sampling being unavailable.
        if code == -32601 {
            return Err(SamplingError::Unavailable);
        }
        // Heuristic refusal detection on the message text (lowercased once).
        let lowered = message.to_lowercase();
        if ["reject", "declin", "denied"]
            .iter()
            .any(|&needle| lowered.contains(needle))
        {
            return Err(SamplingError::Rejected(message.to_string()));
        }
        return Err(SamplingError::Other(format!("error {code}: {message}")));
    }
    let result = response
        .get("result")
        .ok_or_else(|| SamplingError::Other("missing result envelope".into()))?;
    if let Some(text) = result.get("text").and_then(Value::as_str) {
        return Ok(text.to_string());
    }
    if let Some(content) = result.get("content") {
        if let Some(text) = content.get("text").and_then(Value::as_str) {
            return Ok(text.to_string());
        }
        if let Some(pieces) = content.as_array() {
            let joined: String = pieces
                .iter()
                .filter_map(|piece| piece.get("text").and_then(Value::as_str))
                .collect();
            if !joined.is_empty() {
                return Ok(joined);
            }
        }
    }
    Err(SamplingError::Other(format!(
        "could not extract text from sampling result: {result}"
    )))
}