spool-memory 0.2.3

Local-first developer memory system — persistent, structured knowledge for AI coding tools
Documentation
use crate::sampling::{SamplingClient, SamplingError, SamplingFuture};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::sync::oneshot;

/// Maximum time a reverse `sampling/createMessage` request waits for the
/// client's reply before failing with a timeout (three minutes).
pub(super) const SAMPLING_REVERSE_TIMEOUT: Duration = Duration::from_secs(3 * 60);

/// Shared handle for issuing reverse `sampling/createMessage` JSON-RPC
/// requests (as used by [`McpSamplingClient`]) over the thread-backed stdio
/// loop.
///
/// Requests are pushed onto the outgoing queue; replies are handed back to
/// the waiting request through [`StdSamplingChannel::route_response`]. Clones
/// share the same pending-waiter map and request-id counter.
#[derive(Clone)]
pub(super) struct StdSamplingChannel {
    /// Queue of messages consumed by the stdio writer.
    outgoing: std::sync::mpsc::Sender<Value>,
    /// Waiters keyed by request id, completed when the matching reply arrives.
    pending: Arc<StdMutex<HashMap<String, oneshot::Sender<Value>>>>,
    /// Counter used to mint unique request ids.
    next_id: Arc<AtomicU64>,
}

impl StdSamplingChannel {
    /// Creates a channel with no pending requests whose first request id will
    /// be `spool-sampling-0`.
    pub fn new(outgoing: std::sync::mpsc::Sender<Value>) -> Self {
        Self {
            outgoing,
            pending: Arc::default(),
            next_id: Arc::default(),
        }
    }

    /// Hands `message` to the request registered under `id_key`.
    ///
    /// Returns `true` when a waiter was registered under `id_key` (the entry is
    /// removed), `false` when no such waiter exists. If the waiter's receiving
    /// side is already gone, the message is silently discarded.
    ///
    /// # Panics
    ///
    /// Panics if the pending map's mutex is poisoned.
    pub fn route_response(&self, id_key: &str, message: Value) -> bool {
        // The guard is a temporary of the scrutinee, so the lock stays held
        // until the whole `match` finishes, including the send below.
        match self.pending.lock().expect("pending map poisoned").remove(id_key) {
            None => false,
            Some(waiter) => {
                // An error here only means the receiver was dropped.
                let _ = waiter.send(message);
                true
            }
        }
    }

    /// Returns the next unique request id, e.g. `spool-sampling-7`.
    fn next_request_id(&self) -> String {
        format!("spool-sampling-{}", self.next_id.fetch_add(1, Ordering::Relaxed))
    }
}

/// [`SamplingClient`] that turns each prompt into an MCP
/// `sampling/createMessage` reverse call sent over the stdio loop via a
/// [`StdSamplingChannel`].
pub(super) struct McpSamplingClient {
    /// Transport used to send requests and receive routed replies.
    channel: StdSamplingChannel,
    /// Reported as-is by `is_available`; presumably reflects whether the
    /// connected client advertised sampling support — confirm at call site.
    supports: bool,
}

impl McpSamplingClient {
    /// Builds a client over `channel`, recording the `supports` flag.
    pub fn new(channel: StdSamplingChannel, supports: bool) -> Self {
        McpSamplingClient { supports, channel }
    }
}

impl SamplingClient for McpSamplingClient {
    fn is_available(&self) -> bool {
        self.supports
    }

    fn create_message<'a>(&'a self, prompt: &'a str) -> SamplingFuture<'a> {
        let channel = self.channel.clone();
        let prompt = prompt.to_string();
        Box::pin(async move {
            let id = channel.next_request_id();
            let request = json!({
                "jsonrpc": "2.0",
                "id": id,
                "method": "sampling/createMessage",
                "params": {
                    "messages": [{
                        "role": "user",
                        "content": {"type": "text", "text": prompt}
                    }],
                    "maxTokens": 800,
                    "systemPrompt": "You are a memory-extraction assistant. Reply with strict JSON only.",
                    "includeContext": "none"
                }
            });

            let (tx, rx) = oneshot::channel();
            channel
                .pending
                .lock()
                .expect("pending map poisoned")
                .insert(id.clone(), tx);

            channel
                .outgoing
                .send(request)
                .map_err(|err| SamplingError::Other(format!("stdio writer closed: {err}")))?;

            let response = match tokio::time::timeout(SAMPLING_REVERSE_TIMEOUT, rx).await {
                Ok(Ok(value)) => value,
                Ok(Err(_canceled)) => {
                    channel
                        .pending
                        .lock()
                        .expect("pending map poisoned")
                        .remove(&id);
                    return Err(SamplingError::Other(
                        "sampling waiter dropped before response".into(),
                    ));
                }
                Err(_elapsed) => {
                    channel
                        .pending
                        .lock()
                        .expect("pending map poisoned")
                        .remove(&id);
                    return Err(SamplingError::Timeout);
                }
            };

            extract_sampling_text(response)
        })
    }
}

/// Converts the JSON-RPC response to a `sampling/createMessage` request into
/// the text produced by the client.
///
/// Text is taken from the first of these that yields any:
/// 1. `result.text` (string),
/// 2. `result.content.text` (string),
/// 3. the concatenation of every string `text` field in a `result.content`
///    array (used only if the concatenation is non-empty).
///
/// An `"error": null` member is ignored and treated like an absent one.
///
/// # Errors
///
/// * `SamplingError::Unavailable` when the error object has code `-32601`.
/// * `SamplingError::Rejected` when the error message contains "reject",
///   "declin", or "denied" (case-insensitive).
/// * `SamplingError::Other` for any other error object, a missing `result`,
///   or a `result` from which no text can be extracted.
pub(super) fn extract_sampling_text(response: Value) -> Result<String, SamplingError> {
    // JSON-RPC 2.0 says `error` must be absent on success, but tolerate peers
    // that send `"error": null` alongside a valid `result`.
    if let Some(err_obj) = response.get("error").filter(|e| !e.is_null()) {
        let code = err_obj.get("code").and_then(Value::as_i64).unwrap_or(0);
        let message = err_obj
            .get("message")
            .and_then(Value::as_str)
            .unwrap_or("unknown sampling error");
        if code == -32601 {
            return Err(SamplingError::Unavailable);
        }
        // Lowercase once; the stems match "rejected", "declined", "denied", ...
        let lowered = message.to_lowercase();
        if ["reject", "declin", "denied"]
            .iter()
            .any(|&needle| lowered.contains(needle))
        {
            return Err(SamplingError::Rejected(message.to_string()));
        }
        return Err(SamplingError::Other(format!("error {code}: {message}")));
    }

    let result = response
        .get("result")
        .ok_or_else(|| SamplingError::Other("missing result envelope".into()))?;

    if let Some(text) = result.get("text").and_then(Value::as_str) {
        return Ok(text.to_string());
    }
    if let Some(content) = result.get("content") {
        if let Some(text) = content.get("text").and_then(Value::as_str) {
            return Ok(text.to_string());
        }
        if let Some(pieces) = content.as_array() {
            // Non-text pieces (or pieces without a string `text`) are skipped.
            let joined: String = pieces
                .iter()
                .filter_map(|piece| piece.get("text").and_then(Value::as_str))
                .collect();
            if !joined.is_empty() {
                return Ok(joined);
            }
        }
    }
    Err(SamplingError::Other(format!(
        "could not extract text from sampling result: {result}"
    )))
}