1use std::{
2 future::Future,
3 marker::PhantomData,
4 sync::{Arc, Mutex},
5};
6
7use futures::{SinkExt, StreamExt};
8use rmcp::{
9 service::{RoleServer, RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage},
10 transport::Transport,
11};
12use serde::{de::DeserializeOwned, Serialize};
13use thiserror::Error;
14use tokio::{
15 io::{AsyncRead, AsyncWrite},
16 sync::Mutex as AsyncMutex,
17};
18use tokio_util::{
19 bytes::{Buf, BufMut, BytesMut},
20 codec::{Decoder, Encoder, FramedRead, FramedWrite},
21};
22
/// Framing scheme used on the wire by the connected peer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WireProtocol {
    /// One JSON document per line, terminated by `\n` (a preceding `\r` is tolerated).
    JsonLine,
    /// Header block containing `Content-Length`, a blank line, then exactly that many bytes.
    ContentLength,
}

/// Wire protocol detected for one connection, shared by every clone of the handle.
///
/// Starts out unknown. The first value passed to `set_if_unset` sticks; later calls do not
/// change it.
#[derive(Debug, Clone)]
struct SharedProtocol(Arc<Mutex<Option<WireProtocol>>>);

impl SharedProtocol {
    /// Creates a handle with no protocol recorded yet.
    fn new() -> Self {
        let slot = Mutex::new(None);
        Self(Arc::new(slot))
    }

    /// Locks the shared slot, recovering the value even if a previous holder panicked.
    fn slot(&self) -> std::sync::MutexGuard<'_, Option<WireProtocol>> {
        match self.0.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Returns the recorded protocol, or `None` if nothing has been recorded yet.
    fn get(&self) -> Option<WireProtocol> {
        *self.slot()
    }

    /// Records `protocol` unless one is already recorded, in which case this is a no-op.
    fn set_if_unset(&self, protocol: WireProtocol) {
        let mut guard = self.slot();
        // `or` keeps an existing value and only fills an empty slot.
        *guard = guard.or(Some(protocol));
    }
}
54
55pub type TransportWriter<Role, W> =
56 FramedWrite<W, HybridJsonRpcMessageCodec<TxJsonRpcMessage<Role>>>;
57
58pub struct HybridStdioTransport<Role: ServiceRole, R: AsyncRead, W: AsyncWrite> {
59 read: FramedRead<R, HybridJsonRpcMessageCodec<RxJsonRpcMessage<Role>>>,
60 write: Arc<AsyncMutex<Option<TransportWriter<Role, W>>>>,
61}
62
63impl<Role: ServiceRole, R, W> HybridStdioTransport<Role, R, W>
64where
65 R: Send + AsyncRead + Unpin,
66 W: Send + AsyncWrite + Unpin + 'static,
67{
68 pub fn new(read: R, write: W) -> Self {
69 let protocol = SharedProtocol::new();
70 let read = FramedRead::new(
71 read,
72 HybridJsonRpcMessageCodec::<RxJsonRpcMessage<Role>>::new(protocol.clone()),
73 );
74 let write = Arc::new(AsyncMutex::new(Some(FramedWrite::new(
75 write,
76 HybridJsonRpcMessageCodec::<TxJsonRpcMessage<Role>>::new(protocol),
77 ))));
78 Self { read, write }
79 }
80}
81
82impl<R, W> HybridStdioTransport<RoleServer, R, W>
83where
84 R: Send + AsyncRead + Unpin,
85 W: Send + AsyncWrite + Unpin + 'static,
86{
87 pub fn new_server(read: R, write: W) -> Self {
88 Self::new(read, write)
89 }
90}
91
92impl<Role: ServiceRole, R, W> Transport<Role> for HybridStdioTransport<Role, R, W>
93where
94 R: Send + AsyncRead + Unpin,
95 W: Send + AsyncWrite + Unpin + 'static,
96{
97 type Error = std::io::Error;
98
99 fn send(
100 &mut self,
101 item: TxJsonRpcMessage<Role>,
102 ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
103 let lock = self.write.clone();
104 async move {
105 let mut write = lock.lock().await;
106 if let Some(ref mut write) = *write {
107 write.send(item).await.map_err(Into::into)
108 } else {
109 Err(std::io::Error::new(
110 std::io::ErrorKind::NotConnected,
111 "Transport is closed",
112 ))
113 }
114 }
115 }
116
117 fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<Role>>> + Send {
118 let next = self.read.next();
119 async {
120 next.await.and_then(|result| {
121 result
122 .inspect_err(|error| {
123 tracing::error!("Error reading from stream: {}", error);
124 })
125 .ok()
126 })
127 }
128 }
129
130 async fn close(&mut self) -> Result<(), Self::Error> {
131 let mut write = self.write.lock().await;
132 drop(write.take());
133 Ok(())
134 }
135}
136
137#[derive(Debug, Clone)]
138pub struct HybridJsonRpcMessageCodec<T> {
139 _marker: PhantomData<fn() -> T>,
140 next_index: usize,
141 max_length: usize,
142 is_discarding: bool,
143 protocol: SharedProtocol,
144}
145
146impl<T> HybridJsonRpcMessageCodec<T> {
147 fn new(protocol: SharedProtocol) -> Self {
148 Self {
149 _marker: PhantomData,
150 next_index: 0,
151 max_length: 32 * 1024 * 1024, is_discarding: false,
153 protocol,
154 }
155 }
156}
157
/// Returns `s` with one trailing `\r` removed, if present (turns a CRLF line body into LF form).
fn without_carriage_return(s: &[u8]) -> &[u8] {
    s.strip_suffix(b"\r").unwrap_or(s)
}
165
/// Returns `true` if `method` is a standard MCP request method or standard notification.
fn is_standard_method(method: &str) -> bool {
    const REQUEST_METHODS: &[&str] = &[
        "initialize",
        "ping",
        "prompts/get",
        "prompts/list",
        "resources/list",
        "resources/read",
        "resources/subscribe",
        "resources/unsubscribe",
        "resources/templates/list",
        "tools/call",
        "tools/list",
        "completion/complete",
        "logging/setLevel",
        "roots/list",
        "sampling/createMessage",
    ];
    REQUEST_METHODS.contains(&method) || is_standard_notification(method)
}

/// Returns `true` if `method` is one of the standard MCP `notifications/*` methods.
fn is_standard_notification(method: &str) -> bool {
    const NOTIFICATION_METHODS: &[&str] = &[
        "notifications/cancelled",
        "notifications/initialized",
        "notifications/message",
        "notifications/progress",
        "notifications/prompts/list_changed",
        "notifications/resources/list_changed",
        "notifications/resources/updated",
        "notifications/roots/list_changed",
        "notifications/tools/list_changed",
    ];
    NOTIFICATION_METHODS.contains(&method)
}
201
202fn should_ignore_notification(json_value: &serde_json::Value, method: &str) -> bool {
203 let is_notification = json_value.get("id").is_none();
204 if is_notification && !is_standard_method(method) {
205 tracing::trace!(
206 "Ignoring non-MCP notification '{}' for compatibility",
207 method
208 );
209 return true;
210 }
211
212 matches!(
213 (
214 method.starts_with("notifications/"),
215 is_standard_notification(method)
216 ),
217 (true, false)
218 )
219}
220
221fn try_parse_with_compatibility<T: DeserializeOwned>(
222 payload: &[u8],
223 context: &str,
224) -> Result<Option<T>, HybridCodecError> {
225 if let Ok(line_str) = std::str::from_utf8(payload) {
226 match serde_json::from_slice(payload) {
227 Ok(item) => Ok(Some(item)),
228 Err(error) => {
229 if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(line_str) {
230 if let Some(method) =
231 json_value.get("method").and_then(serde_json::Value::as_str)
232 {
233 if should_ignore_notification(&json_value, method) {
234 return Ok(None);
235 }
236 }
237 }
238
239 tracing::debug!(
240 "Failed to parse message {}: {} | Error: {}",
241 context,
242 line_str,
243 error
244 );
245 Err(HybridCodecError::Serde(error))
246 }
247 }
248 } else {
249 serde_json::from_slice(payload)
250 .map(Some)
251 .map_err(HybridCodecError::Serde)
252 }
253}
254
255#[derive(Debug, Error)]
256pub enum HybridCodecError {
257 #[error("max line length exceeded")]
258 MaxLineLengthExceeded,
259 #[error("missing Content-Length header")]
260 MissingContentLength,
261 #[error("invalid Content-Length value: {0}")]
262 InvalidContentLength(String),
263 #[error("invalid header frame: {0}")]
264 InvalidHeaderFrame(String),
265 #[error("serde error {0}")]
266 Serde(#[from] serde_json::Error),
267 #[error("io error {0}")]
268 Io(#[from] std::io::Error),
269}
270
271impl From<HybridCodecError> for std::io::Error {
272 fn from(value: HybridCodecError) -> Self {
273 match value {
274 HybridCodecError::MaxLineLengthExceeded
275 | HybridCodecError::MissingContentLength
276 | HybridCodecError::InvalidContentLength(_)
277 | HybridCodecError::InvalidHeaderFrame(_) => {
278 std::io::Error::new(std::io::ErrorKind::InvalidData, value)
279 }
280 HybridCodecError::Serde(error) => error.into(),
281 HybridCodecError::Io(error) => error,
282 }
283 }
284}
285
/// Returns `true` if `buf` begins with the header name `content-length` (ASCII
/// case-insensitive), i.e. the peer appears to use `Content-Length` framing.
///
/// Returns `false` while fewer bytes than the header name have arrived. Callers then fall back
/// to JSON-line scanning, which simply waits because no newline can appear in that prefix.
///
/// Accepts any byte slice; a `&BytesMut` or `&mut BytesMut` argument coerces automatically.
fn looks_like_content_length_frame(buf: &[u8]) -> bool {
    const HEADER_NAME: &[u8] = b"content-length";
    buf.get(..HEADER_NAME.len())
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case(HEADER_NAME))
}
293
/// Locates the blank line that ends a `Content-Length` header block.
///
/// Returns `(index, delimiter_len)` for the earliest terminator in `buf`:
/// * `index` is where the terminator starts, which is also the header's length.
/// * `delimiter_len` is 4 for `\r\n\r\n` or 2 for `\n\n`.
///
/// Returns `None` if no terminator has arrived yet.
///
/// The search takes the earliest match of either form. Preferring `\r\n\r\n` anywhere in the
/// buffer would mis-split a `\n\n`-terminated header whenever a CRLF pair appears later (for
/// example inside a pretty-printed body or a following frame). The scan also stops at the
/// first terminator instead of walking the whole buffer twice.
///
/// Accepts any byte slice; a `&BytesMut` or `&mut BytesMut` argument coerces automatically.
fn find_header_terminator(buf: &[u8]) -> Option<(usize, usize)> {
    (0..buf.len()).find_map(|index| {
        let rest = &buf[index..];
        if rest.starts_with(b"\r\n\r\n") {
            Some((index, 4))
        } else if rest.starts_with(b"\n\n") {
            Some((index, 2))
        } else {
            None
        }
    })
}
302
303fn parse_content_length(header: &str) -> Result<usize, HybridCodecError> {
304 for raw_line in header.lines() {
305 let line = raw_line.trim_end_matches('\r');
306 let Some((name, value)) = line.split_once(':') else {
307 continue;
308 };
309 if name.trim().eq_ignore_ascii_case("content-length") {
310 return value
311 .trim()
312 .parse::<usize>()
313 .map_err(|_| HybridCodecError::InvalidContentLength(value.trim().to_string()));
314 }
315 }
316
317 Err(HybridCodecError::MissingContentLength)
318}
319
320impl<T: DeserializeOwned> HybridJsonRpcMessageCodec<T> {
321 fn decode_content_length(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
322 let Some((header_end, delimiter_len)) = find_header_terminator(buf) else {
323 return Ok(None);
324 };
325
326 let header = std::str::from_utf8(&buf[..header_end])
327 .map_err(|error| HybridCodecError::InvalidHeaderFrame(error.to_string()))?;
328 let content_length = parse_content_length(header)?;
329 if content_length > self.max_length {
330 return Err(HybridCodecError::MaxLineLengthExceeded);
331 }
332 let body_start = header_end + delimiter_len;
333 let frame_len = body_start
334 .checked_add(content_length)
335 .ok_or(HybridCodecError::MaxLineLengthExceeded)?;
336 if buf.len() < frame_len {
337 return Ok(None);
338 }
339
340 let frame = buf.split_to(frame_len);
341 let payload = &frame[body_start..];
342 self.protocol.set_if_unset(WireProtocol::ContentLength);
343
344 try_parse_with_compatibility(payload, "decode_content_length")
345 }
346
347 fn decode_json_line(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
348 loop {
349 let read_to = std::cmp::min(self.max_length.saturating_add(1), buf.len());
350 let newline_offset = buf[self.next_index..read_to]
351 .iter()
352 .position(|byte| *byte == b'\n');
353
354 match (self.is_discarding, newline_offset) {
355 (true, Some(offset)) => {
356 buf.advance(offset + self.next_index + 1);
357 self.is_discarding = false;
358 self.next_index = 0;
359 }
360 (true, None) => {
361 buf.advance(read_to);
362 self.next_index = 0;
363 if buf.is_empty() {
364 return Ok(None);
365 }
366 }
367 (false, Some(offset)) => {
368 let newline_index = offset + self.next_index;
369 self.next_index = 0;
370 let line = buf.split_to(newline_index + 1);
371 let line = &line[..line.len() - 1];
372 let payload = without_carriage_return(line);
373 self.protocol.set_if_unset(WireProtocol::JsonLine);
374
375 if let Some(item) = try_parse_with_compatibility(payload, "decode_json_line")? {
376 return Ok(Some(item));
377 }
378 }
379 (false, None) if buf.len() > self.max_length => {
380 self.is_discarding = true;
381 return Err(HybridCodecError::MaxLineLengthExceeded);
382 }
383 (false, None) => {
384 self.next_index = read_to;
385 return Ok(None);
386 }
387 }
388 }
389 }
390}
391
392impl<T: DeserializeOwned> Decoder for HybridJsonRpcMessageCodec<T> {
393 type Item = T;
394 type Error = HybridCodecError;
395
396 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
397 match self.protocol.get() {
398 Some(WireProtocol::ContentLength) => self.decode_content_length(buf),
399 Some(WireProtocol::JsonLine) => self.decode_json_line(buf),
400 None => {
401 if looks_like_content_length_frame(buf) {
402 self.decode_content_length(buf)
403 } else {
404 self.decode_json_line(buf)
405 }
406 }
407 }
408 }
409
410 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
411 match self.protocol.get() {
412 Some(WireProtocol::ContentLength) if !buf.is_empty() => self.decode_content_length(buf),
413 _ => Ok(if let Some(frame) = self.decode(buf)? {
414 Some(frame)
415 } else {
416 self.next_index = 0;
417 if buf.is_empty() || buf == &b"\r"[..] {
418 None
419 } else {
420 let line = buf.split_to(buf.len());
421 let payload = without_carriage_return(&line);
422 try_parse_with_compatibility(payload, "decode_eof")?
423 }
424 }),
425 }
426 }
427}
428
429impl<T: Serialize> Encoder<T> for HybridJsonRpcMessageCodec<T> {
430 type Error = HybridCodecError;
431
432 fn encode(&mut self, item: T, buf: &mut BytesMut) -> Result<(), HybridCodecError> {
433 let payload = serde_json::to_vec(&item)?;
434
435 match self.protocol.get().unwrap_or(WireProtocol::ContentLength) {
436 WireProtocol::ContentLength => {
437 buf.extend_from_slice(
438 format!("Content-Length: {}\r\n\r\n", payload.len()).as_bytes(),
439 );
440 buf.extend_from_slice(&payload);
441 }
442 WireProtocol::JsonLine => {
443 buf.extend_from_slice(&payload);
444 buf.put_u8(b'\n');
445 }
446 }
447
448 Ok(())
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    use tokio_util::bytes::BytesMut;

    /// A minimal `initialize` request used as a round-trip payload.
    fn sample_message() -> serde_json::Value {
        serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {
                    "name": "probe",
                    "version": "0.0.0"
                }
            }
        })
    }

    /// Codec over untyped JSON values bound to `protocol`.
    fn value_codec(protocol: &SharedProtocol) -> HybridJsonRpcMessageCodec<serde_json::Value> {
        HybridJsonRpcMessageCodec::new(protocol.clone())
    }

    /// Builds a `Content-Length` frame whose header ends with `separator`.
    fn content_length_frame(payload: &[u8], separator: &str) -> BytesMut {
        let mut frame = BytesMut::new();
        frame.extend_from_slice(format!("Content-Length: {}{separator}", payload.len()).as_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    #[test]
    fn decodes_json_line_and_marks_protocol() {
        let protocol = SharedProtocol::new();
        let mut codec = value_codec(&protocol);
        let payload = serde_json::to_vec(&sample_message()).unwrap();
        let mut buf = BytesMut::from(&payload[..]);
        buf.put_u8(b'\n');

        let item = codec.decode(&mut buf).unwrap();
        assert!(item.is_some());
        assert_eq!(protocol.get(), Some(WireProtocol::JsonLine));
    }

    #[test]
    fn decodes_content_length_and_marks_protocol() {
        let protocol = SharedProtocol::new();
        let mut codec = value_codec(&protocol);
        let payload = serde_json::to_vec(&sample_message()).unwrap();
        let mut frame = content_length_frame(&payload, "\r\n\r\n");

        let item = codec.decode(&mut frame).unwrap();
        assert!(item.is_some());
        assert_eq!(protocol.get(), Some(WireProtocol::ContentLength));
    }

    #[test]
    fn decodes_content_length_with_bare_newline_separator() {
        let protocol = SharedProtocol::new();
        let mut codec = value_codec(&protocol);
        let payload = serde_json::to_vec(&sample_message()).unwrap();
        let mut frame = content_length_frame(&payload, "\n\n");

        assert_eq!(codec.decode(&mut frame).unwrap(), Some(sample_message()));
        assert!(frame.is_empty());
        assert_eq!(protocol.get(), Some(WireProtocol::ContentLength));
    }

    #[test]
    fn waits_for_complete_content_length_body() {
        let protocol = SharedProtocol::new();
        let mut codec = value_codec(&protocol);
        let payload = serde_json::to_vec(&sample_message()).unwrap();
        let full = content_length_frame(&payload, "\r\n\r\n");
        let mut partial = BytesMut::from(&full[..full.len() - 1]);

        assert!(codec.decode(&mut partial).unwrap().is_none());
        partial.extend_from_slice(&full[full.len() - 1..]);
        assert_eq!(codec.decode(&mut partial).unwrap(), Some(sample_message()));
    }

    #[test]
    fn decodes_multiple_json_lines_from_one_buffer() {
        let protocol = SharedProtocol::new();
        let mut codec = value_codec(&protocol);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"{\"id\":1}\r\n{\"id\":2}\n");

        assert_eq!(
            codec.decode(&mut buf).unwrap(),
            Some(serde_json::json!({"id": 1}))
        );
        assert_eq!(
            codec.decode(&mut buf).unwrap(),
            Some(serde_json::json!({"id": 2}))
        );
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn encodes_using_content_length_when_protocol_is_detected() {
        let protocol = SharedProtocol::new();
        protocol.set_if_unset(WireProtocol::ContentLength);
        let mut codec = value_codec(&protocol);
        let mut buf = BytesMut::new();
        codec
            .encode(
                serde_json::json!({"jsonrpc":"2.0","id":1,"result":{"ok":true}}),
                &mut buf,
            )
            .unwrap();

        assert!(std::str::from_utf8(&buf)
            .unwrap()
            .starts_with("Content-Length: "));
    }

    #[test]
    fn encodes_json_line_when_protocol_is_detected() {
        let protocol = SharedProtocol::new();
        protocol.set_if_unset(WireProtocol::JsonLine);
        let mut codec = value_codec(&protocol);
        let mut buf = BytesMut::new();
        codec
            .encode(
                serde_json::json!({"jsonrpc":"2.0","id":1,"result":{}}),
                &mut buf,
            )
            .unwrap();

        let text = std::str::from_utf8(&buf).unwrap();
        assert!(text.ends_with('\n'));
        assert!(!text.contains("Content-Length"));
        let decoded: serde_json::Value = serde_json::from_str(text.trim_end()).unwrap();
        assert_eq!(decoded["id"], serde_json::json!(1));
    }
}