Skip to main content

lean_ctx/
mcp_stdio.rs

1use std::{
2    future::Future,
3    marker::PhantomData,
4    sync::{Arc, Mutex},
5};
6
7use futures::{SinkExt, StreamExt};
8use rmcp::{
9    service::{RoleServer, RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage},
10    transport::Transport,
11};
12use serde::{de::DeserializeOwned, Serialize};
13use thiserror::Error;
14use tokio::{
15    io::{AsyncRead, AsyncWrite},
16    sync::Mutex as AsyncMutex,
17};
18use tokio_util::{
19    bytes::{Buf, BufMut, BytesMut},
20    codec::{Decoder, Encoder, FramedRead, FramedWrite},
21};
22
/// Framing used on the stdio byte stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WireProtocol {
    /// One JSON document per line, terminated by `\n` (a preceding `\r` is tolerated).
    JsonLine,
    /// A `Content-Length: N` header block, a blank line, then exactly `N` body bytes.
    ContentLength,
}

/// Framing choice shared between the reading and writing halves of a transport.
///
/// Starts out undecided (`None`). The first call to `set_if_unset` fixes the value; later
/// calls leave it untouched. Clones share the same slot.
#[derive(Debug, Clone)]
struct SharedProtocol(Arc<Mutex<Option<WireProtocol>>>);

impl SharedProtocol {
    /// Creates an undecided protocol slot.
    fn new() -> Self {
        Self(Arc::default())
    }

    /// Returns the recorded framing, or `None` if none has been recorded yet.
    fn get(&self) -> Option<WireProtocol> {
        *self.slot()
    }

    /// Records `protocol` unless a framing has already been recorded (first writer wins).
    fn set_if_unset(&self, protocol: WireProtocol) {
        let mut slot = self.slot();
        *slot = slot.or(Some(protocol));
    }

    /// Locks the shared slot, recovering the value if a previous holder panicked.
    fn slot(&self) -> std::sync::MutexGuard<'_, Option<WireProtocol>> {
        match self.0.lock() {
            Ok(guard) => guard,
            // The slot holds a plain `Copy` value that is written in one step, so a poisoned
            // lock cannot expose a half-updated state; keep using it.
            Err(poisoned) => poisoned.into_inner(),
        }
    }
}
54
55pub type TransportWriter<Role, W> =
56    FramedWrite<W, HybridJsonRpcMessageCodec<TxJsonRpcMessage<Role>>>;
57
58pub struct HybridStdioTransport<Role: ServiceRole, R: AsyncRead, W: AsyncWrite> {
59    read: FramedRead<R, HybridJsonRpcMessageCodec<RxJsonRpcMessage<Role>>>,
60    write: Arc<AsyncMutex<Option<TransportWriter<Role, W>>>>,
61}
62
63impl<Role: ServiceRole, R, W> HybridStdioTransport<Role, R, W>
64where
65    R: Send + AsyncRead + Unpin,
66    W: Send + AsyncWrite + Unpin + 'static,
67{
68    pub fn new(read: R, write: W) -> Self {
69        let protocol = SharedProtocol::new();
70        let read = FramedRead::new(
71            read,
72            HybridJsonRpcMessageCodec::<RxJsonRpcMessage<Role>>::new(protocol.clone()),
73        );
74        let write = Arc::new(AsyncMutex::new(Some(FramedWrite::new(
75            write,
76            HybridJsonRpcMessageCodec::<TxJsonRpcMessage<Role>>::new(protocol),
77        ))));
78        Self { read, write }
79    }
80}
81
82impl<R, W> HybridStdioTransport<RoleServer, R, W>
83where
84    R: Send + AsyncRead + Unpin,
85    W: Send + AsyncWrite + Unpin + 'static,
86{
87    pub fn new_server(read: R, write: W) -> Self {
88        Self::new(read, write)
89    }
90}
91
92impl<Role: ServiceRole, R, W> Transport<Role> for HybridStdioTransport<Role, R, W>
93where
94    R: Send + AsyncRead + Unpin,
95    W: Send + AsyncWrite + Unpin + 'static,
96{
97    type Error = std::io::Error;
98
99    fn send(
100        &mut self,
101        item: TxJsonRpcMessage<Role>,
102    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
103        let lock = self.write.clone();
104        async move {
105            let mut write = lock.lock().await;
106            if let Some(ref mut write) = *write {
107                write.send(item).await.map_err(Into::into)
108            } else {
109                Err(std::io::Error::new(
110                    std::io::ErrorKind::NotConnected,
111                    "Transport is closed",
112                ))
113            }
114        }
115    }
116
117    fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<Role>>> + Send {
118        let next = self.read.next();
119        async {
120            next.await.and_then(|result| {
121                result
122                    .inspect_err(|error| {
123                        tracing::error!("Error reading from stream: {}", error);
124                    })
125                    .ok()
126            })
127        }
128    }
129
130    async fn close(&mut self) -> Result<(), Self::Error> {
131        let mut write = self.write.lock().await;
132        drop(write.take());
133        Ok(())
134    }
135}
136
137#[derive(Debug, Clone)]
138pub struct HybridJsonRpcMessageCodec<T> {
139    _marker: PhantomData<fn() -> T>,
140    next_index: usize,
141    max_length: usize,
142    is_discarding: bool,
143    protocol: SharedProtocol,
144}
145
146impl<T> HybridJsonRpcMessageCodec<T> {
147    fn new(protocol: SharedProtocol) -> Self {
148        Self {
149            _marker: PhantomData,
150            next_index: 0,
151            max_length: 32 * 1024 * 1024, // 32 MiB — prevents OOM from oversized messages
152            is_discarding: false,
153            protocol,
154        }
155    }
156}
157
/// Returns `s` without one trailing `\r` (left over from `\r\n` line endings), if present.
fn without_carriage_return(s: &[u8]) -> &[u8] {
    s.strip_suffix(b"\r").unwrap_or(s)
}
165
166fn is_standard_method(method: &str) -> bool {
167    matches!(
168        method,
169        "initialize"
170            | "ping"
171            | "prompts/get"
172            | "prompts/list"
173            | "resources/list"
174            | "resources/read"
175            | "resources/subscribe"
176            | "resources/unsubscribe"
177            | "resources/templates/list"
178            | "tools/call"
179            | "tools/list"
180            | "completion/complete"
181            | "logging/setLevel"
182            | "roots/list"
183            | "sampling/createMessage"
184    ) || is_standard_notification(method)
185}
186
/// Returns `true` if `method` is one of the standard MCP notifications known to this
/// transport, all of which live under the `notifications/` namespace.
fn is_standard_notification(method: &str) -> bool {
    let Some(topic) = method.strip_prefix("notifications/") else {
        return false;
    };
    matches!(
        topic,
        "cancelled"
            | "initialized"
            | "message"
            | "progress"
            | "prompts/list_changed"
            | "resources/list_changed"
            | "resources/updated"
            | "roots/list_changed"
            | "tools/list_changed"
    )
}
201
202fn should_ignore_notification(json_value: &serde_json::Value, method: &str) -> bool {
203    let is_notification = json_value.get("id").is_none();
204    if is_notification && !is_standard_method(method) {
205        tracing::trace!(
206            "Ignoring non-MCP notification '{}' for compatibility",
207            method
208        );
209        return true;
210    }
211
212    matches!(
213        (
214            method.starts_with("notifications/"),
215            is_standard_notification(method)
216        ),
217        (true, false)
218    )
219}
220
221fn try_parse_with_compatibility<T: DeserializeOwned>(
222    payload: &[u8],
223    context: &str,
224) -> Result<Option<T>, HybridCodecError> {
225    if let Ok(line_str) = std::str::from_utf8(payload) {
226        match serde_json::from_slice(payload) {
227            Ok(item) => Ok(Some(item)),
228            Err(error) => {
229                if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(line_str) {
230                    if let Some(method) =
231                        json_value.get("method").and_then(serde_json::Value::as_str)
232                    {
233                        if should_ignore_notification(&json_value, method) {
234                            return Ok(None);
235                        }
236                    }
237                }
238
239                tracing::debug!(
240                    "Failed to parse message {}: {} | Error: {}",
241                    context,
242                    line_str,
243                    error
244                );
245                Err(HybridCodecError::Serde(error))
246            }
247        }
248    } else {
249        serde_json::from_slice(payload)
250            .map(Some)
251            .map_err(HybridCodecError::Serde)
252    }
253}
254
255#[derive(Debug, Error)]
256pub enum HybridCodecError {
257    #[error("max line length exceeded")]
258    MaxLineLengthExceeded,
259    #[error("missing Content-Length header")]
260    MissingContentLength,
261    #[error("invalid Content-Length value: {0}")]
262    InvalidContentLength(String),
263    #[error("invalid header frame: {0}")]
264    InvalidHeaderFrame(String),
265    #[error("serde error {0}")]
266    Serde(#[from] serde_json::Error),
267    #[error("io error {0}")]
268    Io(#[from] std::io::Error),
269}
270
271impl From<HybridCodecError> for std::io::Error {
272    fn from(value: HybridCodecError) -> Self {
273        match value {
274            HybridCodecError::MaxLineLengthExceeded
275            | HybridCodecError::MissingContentLength
276            | HybridCodecError::InvalidContentLength(_)
277            | HybridCodecError::InvalidHeaderFrame(_) => {
278                std::io::Error::new(std::io::ErrorKind::InvalidData, value)
279            }
280            HybridCodecError::Serde(error) => error.into(),
281            HybridCodecError::Io(error) => error,
282        }
283    }
284}
285
/// Returns `true` when `buf` begins with the header name `Content-Length` (ASCII
/// case-insensitive), i.e. the peer appears to use `Content-Length` framing.
///
/// Buffers shorter than the header name return `false`. Takes a plain byte slice, so callers
/// may pass a `&BytesMut` (via deref) as well as any other byte buffer.
fn looks_like_content_length_frame(buf: &[u8]) -> bool {
    const HEADER_NAME: &[u8] = b"content-length";
    buf.get(..HEADER_NAME.len())
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case(HEADER_NAME))
}
293
/// Locates the blank line that ends a `Content-Length` header block.
///
/// Returns `(index, len)`, where `index` is the offset of the terminator and `len` its length:
/// 4 for `\r\n\r\n`, 2 for `\n\n`. Returns `None` while no complete terminator is buffered.
///
/// The earliest terminator of either style wins. An LF-only header is therefore never
/// mistaken for a `\r\n\r\n` that appears later in the buffer, such as in the body or a
/// following frame. The scan is a single pass that stops at the first terminator, so
/// repeated calls while a large body is still arriving do not rescan the whole buffer.
fn find_header_terminator(buf: &[u8]) -> Option<(usize, usize)> {
    buf.iter().enumerate().find_map(|(index, &byte)| {
        if byte != b'\n' {
            return None;
        }
        let after = &buf[index + 1..];
        if after.first() == Some(&b'\n') {
            // `\n\n` starting at this newline.
            Some((index, 2))
        } else if after.starts_with(b"\r\n") && index > 0 && buf[index - 1] == b'\r' {
            // `\r\n\r\n` whose first `\n` is at `index`.
            Some((index - 1, 4))
        } else {
            None
        }
    })
}
302
303fn parse_content_length(header: &str) -> Result<usize, HybridCodecError> {
304    for raw_line in header.lines() {
305        let line = raw_line.trim_end_matches('\r');
306        let Some((name, value)) = line.split_once(':') else {
307            continue;
308        };
309        if name.trim().eq_ignore_ascii_case("content-length") {
310            return value
311                .trim()
312                .parse::<usize>()
313                .map_err(|_| HybridCodecError::InvalidContentLength(value.trim().to_string()));
314        }
315    }
316
317    Err(HybridCodecError::MissingContentLength)
318}
319
320impl<T: DeserializeOwned> HybridJsonRpcMessageCodec<T> {
321    fn decode_content_length(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
322        let Some((header_end, delimiter_len)) = find_header_terminator(buf) else {
323            return Ok(None);
324        };
325
326        let header = std::str::from_utf8(&buf[..header_end])
327            .map_err(|error| HybridCodecError::InvalidHeaderFrame(error.to_string()))?;
328        let content_length = parse_content_length(header)?;
329        if content_length > self.max_length {
330            return Err(HybridCodecError::MaxLineLengthExceeded);
331        }
332        let body_start = header_end + delimiter_len;
333        let frame_len = body_start
334            .checked_add(content_length)
335            .ok_or(HybridCodecError::MaxLineLengthExceeded)?;
336        if buf.len() < frame_len {
337            return Ok(None);
338        }
339
340        let frame = buf.split_to(frame_len);
341        let payload = &frame[body_start..];
342        self.protocol.set_if_unset(WireProtocol::ContentLength);
343
344        try_parse_with_compatibility(payload, "decode_content_length")
345    }
346
347    fn decode_json_line(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
348        loop {
349            let read_to = std::cmp::min(self.max_length.saturating_add(1), buf.len());
350            let newline_offset = buf[self.next_index..read_to]
351                .iter()
352                .position(|byte| *byte == b'\n');
353
354            match (self.is_discarding, newline_offset) {
355                (true, Some(offset)) => {
356                    buf.advance(offset + self.next_index + 1);
357                    self.is_discarding = false;
358                    self.next_index = 0;
359                }
360                (true, None) => {
361                    buf.advance(read_to);
362                    self.next_index = 0;
363                    if buf.is_empty() {
364                        return Ok(None);
365                    }
366                }
367                (false, Some(offset)) => {
368                    let newline_index = offset + self.next_index;
369                    self.next_index = 0;
370                    let line = buf.split_to(newline_index + 1);
371                    let line = &line[..line.len() - 1];
372                    let payload = without_carriage_return(line);
373                    self.protocol.set_if_unset(WireProtocol::JsonLine);
374
375                    if let Some(item) = try_parse_with_compatibility(payload, "decode_json_line")? {
376                        return Ok(Some(item));
377                    }
378                }
379                (false, None) if buf.len() > self.max_length => {
380                    self.is_discarding = true;
381                    return Err(HybridCodecError::MaxLineLengthExceeded);
382                }
383                (false, None) => {
384                    self.next_index = read_to;
385                    return Ok(None);
386                }
387            }
388        }
389    }
390}
391
392impl<T: DeserializeOwned> Decoder for HybridJsonRpcMessageCodec<T> {
393    type Item = T;
394    type Error = HybridCodecError;
395
396    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
397        match self.protocol.get() {
398            Some(WireProtocol::ContentLength) => self.decode_content_length(buf),
399            Some(WireProtocol::JsonLine) => self.decode_json_line(buf),
400            None => {
401                if looks_like_content_length_frame(buf) {
402                    self.decode_content_length(buf)
403                } else {
404                    self.decode_json_line(buf)
405                }
406            }
407        }
408    }
409
410    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
411        match self.protocol.get() {
412            Some(WireProtocol::ContentLength) if !buf.is_empty() => self.decode_content_length(buf),
413            _ => Ok(if let Some(frame) = self.decode(buf)? {
414                Some(frame)
415            } else {
416                self.next_index = 0;
417                if buf.is_empty() || buf == &b"\r"[..] {
418                    None
419                } else {
420                    let line = buf.split_to(buf.len());
421                    let payload = without_carriage_return(&line);
422                    try_parse_with_compatibility(payload, "decode_eof")?
423                }
424            }),
425        }
426    }
427}
428
429impl<T: Serialize> Encoder<T> for HybridJsonRpcMessageCodec<T> {
430    type Error = HybridCodecError;
431
432    fn encode(&mut self, item: T, buf: &mut BytesMut) -> Result<(), HybridCodecError> {
433        let payload = serde_json::to_vec(&item)?;
434
435        match self.protocol.get().unwrap_or(WireProtocol::ContentLength) {
436            WireProtocol::ContentLength => {
437                buf.extend_from_slice(
438                    format!("Content-Length: {}\r\n\r\n", payload.len()).as_bytes(),
439                );
440                buf.extend_from_slice(&payload);
441            }
442            WireProtocol::JsonLine => {
443                buf.extend_from_slice(&payload);
444                buf.put_u8(b'\n');
445            }
446        }
447
448        Ok(())
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    use tokio_util::bytes::BytesMut;

    /// A minimal `initialize` request used as a round-trip fixture.
    fn sample_message() -> serde_json::Value {
        serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {
                    "name": "probe",
                    "version": "0.0.0"
                }
            }
        })
    }

    /// Wraps `payload` in a CRLF-terminated `Content-Length` header block.
    fn content_length_frame(payload: &[u8]) -> BytesMut {
        let mut frame = BytesMut::new();
        frame.extend_from_slice(format!("Content-Length: {}\r\n\r\n", payload.len()).as_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    #[test]
    fn decodes_json_line_and_marks_protocol() {
        let protocol = SharedProtocol::new();
        let mut codec = HybridJsonRpcMessageCodec::<serde_json::Value>::new(protocol.clone());
        let payload = serde_json::to_vec(&sample_message()).unwrap();
        let mut buf = BytesMut::from(&payload[..]);
        buf.put_u8(b'\n');

        let item = codec.decode(&mut buf).unwrap();
        assert!(item.is_some());
        assert_eq!(protocol.get(), Some(WireProtocol::JsonLine));
    }

    #[test]
    fn decodes_content_length_and_marks_protocol() {
        let protocol = SharedProtocol::new();
        let mut codec = HybridJsonRpcMessageCodec::<serde_json::Value>::new(protocol.clone());
        let payload = serde_json::to_vec(&sample_message()).unwrap();
        let mut frame = content_length_frame(&payload);

        let item = codec.decode(&mut frame).unwrap();
        assert!(item.is_some());
        assert_eq!(protocol.get(), Some(WireProtocol::ContentLength));
    }

    #[test]
    fn decodes_lf_only_content_length_header() {
        let protocol = SharedProtocol::new();
        let mut codec = HybridJsonRpcMessageCodec::<serde_json::Value>::new(protocol.clone());
        let payload = serde_json::to_vec(&sample_message()).unwrap();
        let mut frame = BytesMut::new();
        frame.extend_from_slice(format!("Content-Length: {}\n\n", payload.len()).as_bytes());
        frame.extend_from_slice(&payload);

        assert_eq!(codec.decode(&mut frame).unwrap(), Some(sample_message()));
        assert!(frame.is_empty());
        assert_eq!(protocol.get(), Some(WireProtocol::ContentLength));
    }

    #[test]
    fn waits_for_complete_content_length_body() {
        let mut codec = HybridJsonRpcMessageCodec::<serde_json::Value>::new(SharedProtocol::new());
        let payload = serde_json::to_vec(&sample_message()).unwrap();
        let full = content_length_frame(&payload);
        // Deliver the whole header but only half of the body first.
        let split_at = full.len() - payload.len() / 2;

        let mut buf = BytesMut::from(&full[..split_at]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(&full[split_at..]);
        assert_eq!(codec.decode(&mut buf).unwrap(), Some(sample_message()));
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_eof_accepts_final_json_line_without_newline() {
        let mut codec = HybridJsonRpcMessageCodec::<serde_json::Value>::new(SharedProtocol::new());
        let payload = serde_json::to_vec(&sample_message()).unwrap();
        let mut buf = BytesMut::from(&payload[..]);

        assert_eq!(codec.decode_eof(&mut buf).unwrap(), Some(sample_message()));
        assert!(buf.is_empty());
    }

    #[test]
    fn encodes_using_content_length_when_protocol_is_detected() {
        let protocol = SharedProtocol::new();
        protocol.set_if_unset(WireProtocol::ContentLength);
        let mut codec = HybridJsonRpcMessageCodec::<serde_json::Value>::new(protocol);
        let mut buf = BytesMut::new();
        codec
            .encode(
                serde_json::json!({"jsonrpc":"2.0","id":1,"result":{"ok":true}}),
                &mut buf,
            )
            .unwrap();

        assert!(std::str::from_utf8(&buf)
            .unwrap()
            .starts_with("Content-Length: "));
    }

    #[test]
    fn encodes_using_content_length_before_protocol_is_detected() {
        let mut codec = HybridJsonRpcMessageCodec::<serde_json::Value>::new(SharedProtocol::new());
        let message = serde_json::json!({"jsonrpc":"2.0","id":1,"result":{}});
        let mut buf = BytesMut::new();
        codec.encode(message.clone(), &mut buf).unwrap();

        let body = serde_json::to_vec(&message).unwrap();
        let mut expected = format!("Content-Length: {}\r\n\r\n", body.len()).into_bytes();
        expected.extend_from_slice(&body);
        assert_eq!(&buf[..], &expected[..]);
    }

    #[test]
    fn encodes_json_line_when_protocol_is_json_line() {
        let protocol = SharedProtocol::new();
        protocol.set_if_unset(WireProtocol::JsonLine);
        let mut codec = HybridJsonRpcMessageCodec::<serde_json::Value>::new(protocol);
        let message = serde_json::json!({"jsonrpc":"2.0","id":1,"result":{"ok":true}});
        let mut buf = BytesMut::new();
        codec.encode(message.clone(), &mut buf).unwrap();

        let mut expected = serde_json::to_vec(&message).unwrap();
        expected.push(b'\n');
        assert_eq!(&buf[..], &expected[..]);
    }
}