Skip to main content

lean_ctx/
mcp_stdio.rs

1use std::{
2    future::Future,
3    marker::PhantomData,
4    sync::{Arc, Mutex},
5};
6
7use futures::{SinkExt, StreamExt};
8use rmcp::{
9    service::{RoleServer, RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage},
10    transport::Transport,
11};
12use serde::{de::DeserializeOwned, Serialize};
13use thiserror::Error;
14use tokio::{
15    io::{AsyncRead, AsyncWrite},
16    sync::Mutex as AsyncMutex,
17};
18use tokio_util::{
19    bytes::{Buf, BufMut, BytesMut},
20    codec::{Decoder, Encoder, FramedRead, FramedWrite},
21};
22
/// Framing used for JSON-RPC messages on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WireProtocol {
    /// One JSON document per line; the encoder terminates each message with `\n`.
    JsonLine,
    /// A `Content-Length: N` header block, a blank line, then the JSON body.
    ContentLength,
}
28
29#[derive(Debug, Clone)]
30struct SharedProtocol(Arc<Mutex<Option<WireProtocol>>>);
31
32impl SharedProtocol {
33    fn new() -> Self {
34        Self(Arc::new(Mutex::new(None)))
35    }
36
37    fn get(&self) -> Option<WireProtocol> {
38        *self.0.lock().expect("protocol mutex poisoned")
39    }
40
41    fn set_if_unset(&self, protocol: WireProtocol) {
42        let mut guard = self.0.lock().expect("protocol mutex poisoned");
43        if guard.is_none() {
44            *guard = Some(protocol);
45        }
46    }
47}
48
/// Framed writer over `W` that encodes outgoing `TxJsonRpcMessage<Role>`
/// values with [`HybridJsonRpcMessageCodec`].
pub type TransportWriter<Role, W> =
    FramedWrite<W, HybridJsonRpcMessageCodec<TxJsonRpcMessage<Role>>>;
51
/// Transport over an `AsyncRead`/`AsyncWrite` pair whose codec handles both
/// newline-delimited JSON and `Content-Length`-framed JSON-RPC messages.
pub struct HybridStdioTransport<Role: ServiceRole, R: AsyncRead, W: AsyncWrite> {
    /// Framed reader that decodes incoming messages.
    read: FramedRead<R, HybridJsonRpcMessageCodec<RxJsonRpcMessage<Role>>>,
    /// Framed writer behind an async mutex; the slot is emptied by `close`,
    /// after which `send` reports the transport as closed.
    write: Arc<AsyncMutex<Option<TransportWriter<Role, W>>>>,
}
56
57impl<Role: ServiceRole, R, W> HybridStdioTransport<Role, R, W>
58where
59    R: Send + AsyncRead + Unpin,
60    W: Send + AsyncWrite + Unpin + 'static,
61{
62    pub fn new(read: R, write: W) -> Self {
63        let protocol = SharedProtocol::new();
64        let read = FramedRead::new(
65            read,
66            HybridJsonRpcMessageCodec::<RxJsonRpcMessage<Role>>::new(protocol.clone()),
67        );
68        let write = Arc::new(AsyncMutex::new(Some(FramedWrite::new(
69            write,
70            HybridJsonRpcMessageCodec::<TxJsonRpcMessage<Role>>::new(protocol),
71        ))));
72        Self { read, write }
73    }
74}
75
76impl<R, W> HybridStdioTransport<RoleServer, R, W>
77where
78    R: Send + AsyncRead + Unpin,
79    W: Send + AsyncWrite + Unpin + 'static,
80{
81    pub fn new_server(read: R, write: W) -> Self {
82        Self::new(read, write)
83    }
84}
85
86impl<Role: ServiceRole, R, W> Transport<Role> for HybridStdioTransport<Role, R, W>
87where
88    R: Send + AsyncRead + Unpin,
89    W: Send + AsyncWrite + Unpin + 'static,
90{
91    type Error = std::io::Error;
92
93    fn send(
94        &mut self,
95        item: TxJsonRpcMessage<Role>,
96    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
97        let lock = self.write.clone();
98        async move {
99            let mut write = lock.lock().await;
100            if let Some(ref mut write) = *write {
101                write.send(item).await.map_err(Into::into)
102            } else {
103                Err(std::io::Error::new(
104                    std::io::ErrorKind::NotConnected,
105                    "Transport is closed",
106                ))
107            }
108        }
109    }
110
111    fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<Role>>> + Send {
112        let next = self.read.next();
113        async {
114            next.await.and_then(|result| {
115                result
116                    .inspect_err(|error| {
117                        tracing::error!("Error reading from stream: {}", error);
118                    })
119                    .ok()
120            })
121        }
122    }
123
124    async fn close(&mut self) -> Result<(), Self::Error> {
125        let mut write = self.write.lock().await;
126        drop(write.take());
127        Ok(())
128    }
129}
130
/// Codec for JSON-RPC messages of type `T` that supports both
/// newline-delimited and `Content-Length`-framed transport.
#[derive(Debug, Clone)]
pub struct HybridJsonRpcMessageCodec<T> {
    /// Ties the codec to `T` without storing a value of that type.
    _marker: PhantomData<fn() -> T>,
    /// Offset in the read buffer at which the next newline search starts
    /// (bytes before it were already scanned by the line decoder).
    next_index: usize,
    /// Maximum number of buffered bytes tolerated without a newline in
    /// line mode.
    max_length: usize,
    /// Set after an over-long line was reported; the line decoder then
    /// discards input up to and including the next newline.
    is_discarding: bool,
    /// Shared slot holding the wire protocol once it has been recorded.
    protocol: SharedProtocol,
}

impl<T> HybridJsonRpcMessageCodec<T> {
    /// Creates a codec in its initial state, with no line-length limit
    /// (`max_length` is `usize::MAX`), bound to the given protocol handle.
    fn new(protocol: SharedProtocol) -> Self {
        Self {
            _marker: PhantomData,
            next_index: 0,
            max_length: usize::MAX,
            is_discarding: false,
            protocol,
        }
    }
}
151
/// Returns `s` without a single trailing `\r`, or `s` unchanged when it does
/// not end in one. At most one byte is removed.
fn without_carriage_return(s: &[u8]) -> &[u8] {
    match s.split_last() {
        Some((&b'\r', head)) => head,
        _ => s,
    }
}
159
160fn is_standard_method(method: &str) -> bool {
161    matches!(
162        method,
163        "initialize"
164            | "ping"
165            | "prompts/get"
166            | "prompts/list"
167            | "resources/list"
168            | "resources/read"
169            | "resources/subscribe"
170            | "resources/unsubscribe"
171            | "resources/templates/list"
172            | "tools/call"
173            | "tools/list"
174            | "completion/complete"
175            | "logging/setLevel"
176            | "roots/list"
177            | "sampling/createMessage"
178    ) || is_standard_notification(method)
179}
180
/// Returns `true` when `method` exactly matches one of the notification
/// names listed below (comparison is case-sensitive).
fn is_standard_notification(method: &str) -> bool {
    const STANDARD_NOTIFICATIONS: [&str; 9] = [
        "notifications/cancelled",
        "notifications/initialized",
        "notifications/message",
        "notifications/progress",
        "notifications/prompts/list_changed",
        "notifications/resources/list_changed",
        "notifications/resources/updated",
        "notifications/roots/list_changed",
        "notifications/tools/list_changed",
    ];

    STANDARD_NOTIFICATIONS.iter().any(|known| *known == method)
}
195
196fn should_ignore_notification(json_value: &serde_json::Value, method: &str) -> bool {
197    let is_notification = json_value.get("id").is_none();
198    if is_notification && !is_standard_method(method) {
199        tracing::trace!(
200            "Ignoring non-MCP notification '{}' for compatibility",
201            method
202        );
203        return true;
204    }
205
206    matches!(
207        (
208            method.starts_with("notifications/"),
209            is_standard_notification(method)
210        ),
211        (true, false)
212    )
213}
214
215fn try_parse_with_compatibility<T: DeserializeOwned>(
216    payload: &[u8],
217    context: &str,
218) -> Result<Option<T>, HybridCodecError> {
219    if let Ok(line_str) = std::str::from_utf8(payload) {
220        match serde_json::from_slice(payload) {
221            Ok(item) => Ok(Some(item)),
222            Err(error) => {
223                if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(line_str) {
224                    if let Some(method) =
225                        json_value.get("method").and_then(serde_json::Value::as_str)
226                    {
227                        if should_ignore_notification(&json_value, method) {
228                            return Ok(None);
229                        }
230                    }
231                }
232
233                tracing::debug!(
234                    "Failed to parse message {}: {} | Error: {}",
235                    context,
236                    line_str,
237                    error
238                );
239                Err(HybridCodecError::Serde(error))
240            }
241        }
242    } else {
243        serde_json::from_slice(payload)
244            .map(Some)
245            .map_err(HybridCodecError::Serde)
246    }
247}
248
/// Errors produced while decoding or encoding framed JSON-RPC messages.
#[derive(Debug, Error)]
pub enum HybridCodecError {
    /// In line mode, more than `max_length` bytes were buffered without a
    /// newline.
    #[error("max line length exceeded")]
    MaxLineLengthExceeded,
    /// The header block contained no `Content-Length` field.
    #[error("missing Content-Length header")]
    MissingContentLength,
    /// The `Content-Length` value could not be parsed as `usize`; carries the
    /// trimmed offending value.
    #[error("invalid Content-Length value: {0}")]
    InvalidContentLength(String),
    /// The header block was malformed (e.g. not valid UTF-8); carries a
    /// description of the problem.
    #[error("invalid header frame: {0}")]
    InvalidHeaderFrame(String),
    /// JSON serialization or deserialization failed.
    #[error("serde error {0}")]
    Serde(#[from] serde_json::Error),
    /// The underlying reader or writer failed.
    #[error("io error {0}")]
    Io(#[from] std::io::Error),
}
264
265impl From<HybridCodecError> for std::io::Error {
266    fn from(value: HybridCodecError) -> Self {
267        match value {
268            HybridCodecError::MaxLineLengthExceeded
269            | HybridCodecError::MissingContentLength
270            | HybridCodecError::InvalidContentLength(_)
271            | HybridCodecError::InvalidHeaderFrame(_) => {
272                std::io::Error::new(std::io::ErrorKind::InvalidData, value)
273            }
274            HybridCodecError::Serde(error) => error.into(),
275            HybridCodecError::Io(error) => error,
276        }
277    }
278}
279
280fn looks_like_content_length_frame(buf: &BytesMut) -> bool {
281    let prefix = &buf[..buf.len().min(32)];
282    prefix
283        .windows(b"content-length".len())
284        .next()
285        .map(|candidate| candidate.eq_ignore_ascii_case(b"content-length"))
286        .unwrap_or(false)
287}
288
289fn find_header_terminator(buf: &BytesMut) -> Option<(usize, usize)> {
290    if let Some(index) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
291        return Some((index, 4));
292    }
293    buf.windows(2)
294        .position(|window| window == b"\n\n")
295        .map(|index| (index, 2))
296}
297
298fn parse_content_length(header: &str) -> Result<usize, HybridCodecError> {
299    for raw_line in header.lines() {
300        let line = raw_line.trim_end_matches('\r');
301        let (name, value) = match line.split_once(':') {
302            Some(parts) => parts,
303            None => continue,
304        };
305        if name.trim().eq_ignore_ascii_case("content-length") {
306            return value
307                .trim()
308                .parse::<usize>()
309                .map_err(|_| HybridCodecError::InvalidContentLength(value.trim().to_string()));
310        }
311    }
312
313    Err(HybridCodecError::MissingContentLength)
314}
315
316impl<T: DeserializeOwned> HybridJsonRpcMessageCodec<T> {
317    fn decode_content_length(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
318        let Some((header_end, delimiter_len)) = find_header_terminator(buf) else {
319            return Ok(None);
320        };
321
322        let header = std::str::from_utf8(&buf[..header_end])
323            .map_err(|error| HybridCodecError::InvalidHeaderFrame(error.to_string()))?;
324        let content_length = parse_content_length(header)?;
325        let body_start = header_end + delimiter_len;
326        let frame_len = body_start + content_length;
327        if buf.len() < frame_len {
328            return Ok(None);
329        }
330
331        let frame = buf.split_to(frame_len);
332        let payload = &frame[body_start..];
333        self.protocol.set_if_unset(WireProtocol::ContentLength);
334
335        try_parse_with_compatibility(payload, "decode_content_length")
336    }
337
338    fn decode_json_line(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
339        loop {
340            let read_to = std::cmp::min(self.max_length.saturating_add(1), buf.len());
341            let newline_offset = buf[self.next_index..read_to]
342                .iter()
343                .position(|byte| *byte == b'\n');
344
345            match (self.is_discarding, newline_offset) {
346                (true, Some(offset)) => {
347                    buf.advance(offset + self.next_index + 1);
348                    self.is_discarding = false;
349                    self.next_index = 0;
350                }
351                (true, None) => {
352                    buf.advance(read_to);
353                    self.next_index = 0;
354                    if buf.is_empty() {
355                        return Ok(None);
356                    }
357                }
358                (false, Some(offset)) => {
359                    let newline_index = offset + self.next_index;
360                    self.next_index = 0;
361                    let line = buf.split_to(newline_index + 1);
362                    let line = &line[..line.len() - 1];
363                    let payload = without_carriage_return(line);
364                    self.protocol.set_if_unset(WireProtocol::JsonLine);
365
366                    if let Some(item) = try_parse_with_compatibility(payload, "decode_json_line")? {
367                        return Ok(Some(item));
368                    }
369                }
370                (false, None) if buf.len() > self.max_length => {
371                    self.is_discarding = true;
372                    return Err(HybridCodecError::MaxLineLengthExceeded);
373                }
374                (false, None) => {
375                    self.next_index = read_to;
376                    return Ok(None);
377                }
378            }
379        }
380    }
381}
382
383impl<T: DeserializeOwned> Decoder for HybridJsonRpcMessageCodec<T> {
384    type Item = T;
385    type Error = HybridCodecError;
386
387    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
388        match self.protocol.get() {
389            Some(WireProtocol::ContentLength) => self.decode_content_length(buf),
390            Some(WireProtocol::JsonLine) => self.decode_json_line(buf),
391            None => {
392                if looks_like_content_length_frame(buf) {
393                    self.decode_content_length(buf)
394                } else {
395                    self.decode_json_line(buf)
396                }
397            }
398        }
399    }
400
401    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
402        match self.protocol.get() {
403            Some(WireProtocol::ContentLength) if !buf.is_empty() => self.decode_content_length(buf),
404            _ => Ok(match self.decode(buf)? {
405                Some(frame) => Some(frame),
406                None => {
407                    self.next_index = 0;
408                    if buf.is_empty() || buf == &b"\r"[..] {
409                        None
410                    } else {
411                        let line = buf.split_to(buf.len());
412                        let payload = without_carriage_return(&line);
413                        try_parse_with_compatibility(payload, "decode_eof")?
414                    }
415                }
416            }),
417        }
418    }
419}
420
421impl<T: Serialize> Encoder<T> for HybridJsonRpcMessageCodec<T> {
422    type Error = HybridCodecError;
423
424    fn encode(&mut self, item: T, buf: &mut BytesMut) -> Result<(), HybridCodecError> {
425        let payload = serde_json::to_vec(&item)?;
426
427        match self.protocol.get().unwrap_or(WireProtocol::ContentLength) {
428            WireProtocol::ContentLength => {
429                buf.extend_from_slice(
430                    format!("Content-Length: {}\r\n\r\n", payload.len()).as_bytes(),
431                );
432                buf.extend_from_slice(&payload);
433            }
434            WireProtocol::JsonLine => {
435                buf.extend_from_slice(&payload);
436                buf.put_u8(b'\n');
437            }
438        }
439
440        Ok(())
441    }
442}
443
// Unit tests for protocol detection while decoding and framing selection
// while encoding. `serde_json::Value` stands in for the message type.
#[cfg(test)]
mod tests {
    use super::*;

    use tokio_util::bytes::BytesMut;

    /// Builds an `initialize` request used as the payload in decode tests.
    fn sample_message() -> serde_json::Value {
        serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {
                    "name": "probe",
                    "version": "0.0.0"
                }
            }
        })
    }

    /// A newline-terminated JSON document decodes and records `JsonLine`.
    #[test]
    fn decodes_json_line_and_marks_protocol() {
        let protocol = SharedProtocol::new();
        let mut codec = HybridJsonRpcMessageCodec::<serde_json::Value>::new(protocol.clone());
        let payload = serde_json::to_vec(&sample_message()).unwrap();
        let mut buf = BytesMut::from(&payload[..]);
        buf.put_u8(b'\n');

        let item = codec.decode(&mut buf).unwrap();
        assert!(item.is_some());
        assert_eq!(protocol.get(), Some(WireProtocol::JsonLine));
    }

    /// A `Content-Length` frame decodes and records `ContentLength`.
    #[test]
    fn decodes_content_length_and_marks_protocol() {
        let protocol = SharedProtocol::new();
        let mut codec = HybridJsonRpcMessageCodec::<serde_json::Value>::new(protocol.clone());
        let payload = serde_json::to_vec(&sample_message()).unwrap();
        let mut frame = BytesMut::new();
        frame.extend_from_slice(format!("Content-Length: {}\r\n\r\n", payload.len()).as_bytes());
        frame.extend_from_slice(&payload);

        let item = codec.decode(&mut frame).unwrap();
        assert!(item.is_some());
        assert_eq!(protocol.get(), Some(WireProtocol::ContentLength));
    }

    /// With `ContentLength` recorded, encoded output starts with the header.
    #[test]
    fn encodes_using_content_length_when_protocol_is_detected() {
        let protocol = SharedProtocol::new();
        protocol.set_if_unset(WireProtocol::ContentLength);
        let mut codec = HybridJsonRpcMessageCodec::<serde_json::Value>::new(protocol);
        let mut buf = BytesMut::new();
        codec
            .encode(
                serde_json::json!({"jsonrpc":"2.0","id":1,"result":{"ok":true}}),
                &mut buf,
            )
            .unwrap();

        assert!(std::str::from_utf8(&buf)
            .unwrap()
            .starts_with("Content-Length: "));
    }
}