1use std::{
2 future::Future,
3 marker::PhantomData,
4 sync::{Arc, Mutex},
5};
6
7use futures::{SinkExt, StreamExt};
8use rmcp::{
9 service::{RoleServer, RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage},
10 transport::Transport,
11};
12use serde::{de::DeserializeOwned, Serialize};
13use thiserror::Error;
14use tokio::{
15 io::{AsyncRead, AsyncWrite},
16 sync::Mutex as AsyncMutex,
17};
18use tokio_util::{
19 bytes::{Buf, BufMut, BytesMut},
20 codec::{Decoder, Encoder, FramedRead, FramedWrite},
21};
22
/// Framing used on the wire for JSON-RPC messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WireProtocol {
    /// One JSON document per line, terminated by `\n`.
    JsonLine,
    /// A `Content-Length: N` header block followed by exactly `N` payload bytes.
    ContentLength,
}
28
29#[derive(Debug, Clone)]
30struct SharedProtocol(Arc<Mutex<Option<WireProtocol>>>);
31
32impl SharedProtocol {
33 fn new() -> Self {
34 Self(Arc::new(Mutex::new(None)))
35 }
36
37 fn get(&self) -> Option<WireProtocol> {
38 *self.0.lock().expect("protocol mutex poisoned")
39 }
40
41 fn set_if_unset(&self, protocol: WireProtocol) {
42 let mut guard = self.0.lock().expect("protocol mutex poisoned");
43 if guard.is_none() {
44 *guard = Some(protocol);
45 }
46 }
47}
48
49pub type TransportWriter<Role, W> =
50 FramedWrite<W, HybridJsonRpcMessageCodec<TxJsonRpcMessage<Role>>>;
51
52pub struct HybridStdioTransport<Role: ServiceRole, R: AsyncRead, W: AsyncWrite> {
53 read: FramedRead<R, HybridJsonRpcMessageCodec<RxJsonRpcMessage<Role>>>,
54 write: Arc<AsyncMutex<Option<TransportWriter<Role, W>>>>,
55}
56
57impl<Role: ServiceRole, R, W> HybridStdioTransport<Role, R, W>
58where
59 R: Send + AsyncRead + Unpin,
60 W: Send + AsyncWrite + Unpin + 'static,
61{
62 pub fn new(read: R, write: W) -> Self {
63 let protocol = SharedProtocol::new();
64 let read = FramedRead::new(
65 read,
66 HybridJsonRpcMessageCodec::<RxJsonRpcMessage<Role>>::new(protocol.clone()),
67 );
68 let write = Arc::new(AsyncMutex::new(Some(FramedWrite::new(
69 write,
70 HybridJsonRpcMessageCodec::<TxJsonRpcMessage<Role>>::new(protocol),
71 ))));
72 Self { read, write }
73 }
74}
75
76impl<R, W> HybridStdioTransport<RoleServer, R, W>
77where
78 R: Send + AsyncRead + Unpin,
79 W: Send + AsyncWrite + Unpin + 'static,
80{
81 pub fn new_server(read: R, write: W) -> Self {
82 Self::new(read, write)
83 }
84}
85
86impl<Role: ServiceRole, R, W> Transport<Role> for HybridStdioTransport<Role, R, W>
87where
88 R: Send + AsyncRead + Unpin,
89 W: Send + AsyncWrite + Unpin + 'static,
90{
91 type Error = std::io::Error;
92
93 fn send(
94 &mut self,
95 item: TxJsonRpcMessage<Role>,
96 ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
97 let lock = self.write.clone();
98 async move {
99 let mut write = lock.lock().await;
100 if let Some(ref mut write) = *write {
101 write.send(item).await.map_err(Into::into)
102 } else {
103 Err(std::io::Error::new(
104 std::io::ErrorKind::NotConnected,
105 "Transport is closed",
106 ))
107 }
108 }
109 }
110
111 fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<Role>>> + Send {
112 let next = self.read.next();
113 async {
114 next.await.and_then(|result| {
115 result
116 .inspect_err(|error| {
117 tracing::error!("Error reading from stream: {}", error);
118 })
119 .ok()
120 })
121 }
122 }
123
124 async fn close(&mut self) -> Result<(), Self::Error> {
125 let mut write = self.write.lock().await;
126 drop(write.take());
127 Ok(())
128 }
129}
130
131#[derive(Debug, Clone)]
132pub struct HybridJsonRpcMessageCodec<T> {
133 _marker: PhantomData<fn() -> T>,
134 next_index: usize,
135 max_length: usize,
136 is_discarding: bool,
137 protocol: SharedProtocol,
138}
139
140impl<T> HybridJsonRpcMessageCodec<T> {
141 fn new(protocol: SharedProtocol) -> Self {
142 Self {
143 _marker: PhantomData,
144 next_index: 0,
145 max_length: 32 * 1024 * 1024, is_discarding: false,
147 protocol,
148 }
149 }
150}
151
/// Returns `s` without a single trailing `\r`, or `s` unchanged when it does
/// not end in one. At most one byte is removed.
fn without_carriage_return(s: &[u8]) -> &[u8] {
    s.strip_suffix(b"\r").unwrap_or(s)
}
159
160fn is_standard_method(method: &str) -> bool {
161 matches!(
162 method,
163 "initialize"
164 | "ping"
165 | "prompts/get"
166 | "prompts/list"
167 | "resources/list"
168 | "resources/read"
169 | "resources/subscribe"
170 | "resources/unsubscribe"
171 | "resources/templates/list"
172 | "tools/call"
173 | "tools/list"
174 | "completion/complete"
175 | "logging/setLevel"
176 | "roots/list"
177 | "sampling/createMessage"
178 ) || is_standard_notification(method)
179}
180
/// Reports whether `method` is exactly one of the notification names listed
/// below. The comparison is case-sensitive.
fn is_standard_notification(method: &str) -> bool {
    const STANDARD_NOTIFICATIONS: [&str; 9] = [
        "notifications/cancelled",
        "notifications/initialized",
        "notifications/message",
        "notifications/progress",
        "notifications/prompts/list_changed",
        "notifications/resources/list_changed",
        "notifications/resources/updated",
        "notifications/roots/list_changed",
        "notifications/tools/list_changed",
    ];

    STANDARD_NOTIFICATIONS.iter().any(|known| *known == method)
}
195
196fn should_ignore_notification(json_value: &serde_json::Value, method: &str) -> bool {
197 let is_notification = json_value.get("id").is_none();
198 if is_notification && !is_standard_method(method) {
199 tracing::trace!(
200 "Ignoring non-MCP notification '{}' for compatibility",
201 method
202 );
203 return true;
204 }
205
206 matches!(
207 (
208 method.starts_with("notifications/"),
209 is_standard_notification(method)
210 ),
211 (true, false)
212 )
213}
214
215fn try_parse_with_compatibility<T: DeserializeOwned>(
216 payload: &[u8],
217 context: &str,
218) -> Result<Option<T>, HybridCodecError> {
219 if let Ok(line_str) = std::str::from_utf8(payload) {
220 match serde_json::from_slice(payload) {
221 Ok(item) => Ok(Some(item)),
222 Err(error) => {
223 if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(line_str) {
224 if let Some(method) =
225 json_value.get("method").and_then(serde_json::Value::as_str)
226 {
227 if should_ignore_notification(&json_value, method) {
228 return Ok(None);
229 }
230 }
231 }
232
233 tracing::debug!(
234 "Failed to parse message {}: {} | Error: {}",
235 context,
236 line_str,
237 error
238 );
239 Err(HybridCodecError::Serde(error))
240 }
241 }
242 } else {
243 serde_json::from_slice(payload)
244 .map(Some)
245 .map_err(HybridCodecError::Serde)
246 }
247}
248
249#[derive(Debug, Error)]
250pub enum HybridCodecError {
251 #[error("max line length exceeded")]
252 MaxLineLengthExceeded,
253 #[error("missing Content-Length header")]
254 MissingContentLength,
255 #[error("invalid Content-Length value: {0}")]
256 InvalidContentLength(String),
257 #[error("invalid header frame: {0}")]
258 InvalidHeaderFrame(String),
259 #[error("serde error {0}")]
260 Serde(#[from] serde_json::Error),
261 #[error("io error {0}")]
262 Io(#[from] std::io::Error),
263}
264
265impl From<HybridCodecError> for std::io::Error {
266 fn from(value: HybridCodecError) -> Self {
267 match value {
268 HybridCodecError::MaxLineLengthExceeded
269 | HybridCodecError::MissingContentLength
270 | HybridCodecError::InvalidContentLength(_)
271 | HybridCodecError::InvalidHeaderFrame(_) => {
272 std::io::Error::new(std::io::ErrorKind::InvalidData, value)
273 }
274 HybridCodecError::Serde(error) => error.into(),
275 HybridCodecError::Io(error) => error,
276 }
277 }
278}
279
/// Reports whether `buf` begins with the header name `content-length`,
/// compared ASCII-case-insensitively.
///
/// Returns `false` when fewer bytes than the header name are available, and
/// when anything (including whitespace) precedes the name.
///
/// Takes a plain byte slice; callers holding a `BytesMut` can pass it
/// directly because it dereferences to `[u8]`.
fn looks_like_content_length_frame(buf: &[u8]) -> bool {
    const HEADER_NAME: &[u8] = b"content-length";

    buf.get(..HEADER_NAME.len())
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case(HEADER_NAME))
}
287
/// Locates the blank line that ends a header block.
///
/// Returns `(index, len)` where `index` is the offset of the terminator and
/// `len` is its length: 4 for `\r\n\r\n`, 2 for `\n\n`. Returns `None` when
/// neither sequence occurs in `buf`.
///
/// The terminator that occurs *first* wins. Searching the whole buffer for
/// `\r\n\r\n` before considering `\n\n` would let a CRLF pair inside the body
/// (or in a later frame) be mistaken for the end of an LF-delimited header.
///
/// Takes a plain byte slice; callers holding a `BytesMut` can pass it
/// directly because it dereferences to `[u8]`.
fn find_header_terminator(buf: &[u8]) -> Option<(usize, usize)> {
    (0..buf.len()).find_map(|index| {
        let rest = &buf[index..];
        if rest.starts_with(b"\r\n\r\n") {
            Some((index, 4))
        } else if rest.starts_with(b"\n\n") {
            Some((index, 2))
        } else {
            None
        }
    })
}
296
297fn parse_content_length(header: &str) -> Result<usize, HybridCodecError> {
298 for raw_line in header.lines() {
299 let line = raw_line.trim_end_matches('\r');
300 let Some((name, value)) = line.split_once(':') else {
301 continue;
302 };
303 if name.trim().eq_ignore_ascii_case("content-length") {
304 return value
305 .trim()
306 .parse::<usize>()
307 .map_err(|_| HybridCodecError::InvalidContentLength(value.trim().to_string()));
308 }
309 }
310
311 Err(HybridCodecError::MissingContentLength)
312}
313
314impl<T: DeserializeOwned> HybridJsonRpcMessageCodec<T> {
315 fn decode_content_length(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
316 let Some((header_end, delimiter_len)) = find_header_terminator(buf) else {
317 return Ok(None);
318 };
319
320 let header = std::str::from_utf8(&buf[..header_end])
321 .map_err(|error| HybridCodecError::InvalidHeaderFrame(error.to_string()))?;
322 let content_length = parse_content_length(header)?;
323 if content_length > self.max_length {
324 return Err(HybridCodecError::MaxLineLengthExceeded);
325 }
326 let body_start = header_end + delimiter_len;
327 let frame_len = body_start
328 .checked_add(content_length)
329 .ok_or(HybridCodecError::MaxLineLengthExceeded)?;
330 if buf.len() < frame_len {
331 return Ok(None);
332 }
333
334 let frame = buf.split_to(frame_len);
335 let payload = &frame[body_start..];
336 self.protocol.set_if_unset(WireProtocol::ContentLength);
337
338 try_parse_with_compatibility(payload, "decode_content_length")
339 }
340
341 fn decode_json_line(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
342 loop {
343 let read_to = std::cmp::min(self.max_length.saturating_add(1), buf.len());
344 let newline_offset = buf[self.next_index..read_to]
345 .iter()
346 .position(|byte| *byte == b'\n');
347
348 match (self.is_discarding, newline_offset) {
349 (true, Some(offset)) => {
350 buf.advance(offset + self.next_index + 1);
351 self.is_discarding = false;
352 self.next_index = 0;
353 }
354 (true, None) => {
355 buf.advance(read_to);
356 self.next_index = 0;
357 if buf.is_empty() {
358 return Ok(None);
359 }
360 }
361 (false, Some(offset)) => {
362 let newline_index = offset + self.next_index;
363 self.next_index = 0;
364 let line = buf.split_to(newline_index + 1);
365 let line = &line[..line.len() - 1];
366 let payload = without_carriage_return(line);
367 self.protocol.set_if_unset(WireProtocol::JsonLine);
368
369 if let Some(item) = try_parse_with_compatibility(payload, "decode_json_line")? {
370 return Ok(Some(item));
371 }
372 }
373 (false, None) if buf.len() > self.max_length => {
374 self.is_discarding = true;
375 return Err(HybridCodecError::MaxLineLengthExceeded);
376 }
377 (false, None) => {
378 self.next_index = read_to;
379 return Ok(None);
380 }
381 }
382 }
383 }
384}
385
386impl<T: DeserializeOwned> Decoder for HybridJsonRpcMessageCodec<T> {
387 type Item = T;
388 type Error = HybridCodecError;
389
390 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
391 match self.protocol.get() {
392 Some(WireProtocol::ContentLength) => self.decode_content_length(buf),
393 Some(WireProtocol::JsonLine) => self.decode_json_line(buf),
394 None => {
395 if looks_like_content_length_frame(buf) {
396 self.decode_content_length(buf)
397 } else {
398 self.decode_json_line(buf)
399 }
400 }
401 }
402 }
403
404 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<T>, HybridCodecError> {
405 match self.protocol.get() {
406 Some(WireProtocol::ContentLength) if !buf.is_empty() => self.decode_content_length(buf),
407 _ => Ok(if let Some(frame) = self.decode(buf)? {
408 Some(frame)
409 } else {
410 self.next_index = 0;
411 if buf.is_empty() || buf == &b"\r"[..] {
412 None
413 } else {
414 let line = buf.split_to(buf.len());
415 let payload = without_carriage_return(&line);
416 try_parse_with_compatibility(payload, "decode_eof")?
417 }
418 }),
419 }
420 }
421}
422
423impl<T: Serialize> Encoder<T> for HybridJsonRpcMessageCodec<T> {
424 type Error = HybridCodecError;
425
426 fn encode(&mut self, item: T, buf: &mut BytesMut) -> Result<(), HybridCodecError> {
427 let payload = serde_json::to_vec(&item)?;
428
429 match self.protocol.get().unwrap_or(WireProtocol::ContentLength) {
430 WireProtocol::ContentLength => {
431 buf.extend_from_slice(
432 format!("Content-Length: {}\r\n\r\n", payload.len()).as_bytes(),
433 );
434 buf.extend_from_slice(&payload);
435 }
436 WireProtocol::JsonLine => {
437 buf.extend_from_slice(&payload);
438 buf.put_u8(b'\n');
439 }
440 }
441
442 Ok(())
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    use tokio_util::bytes::BytesMut;

    /// Representative `initialize` request used as decoder input.
    fn sample_message() -> serde_json::Value {
        serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {
                    "name": "probe",
                    "version": "0.0.0"
                }
            }
        })
    }

    /// Fresh codec over `serde_json::Value` plus a handle onto the protocol
    /// slot it shares, so tests can observe what the codec recorded.
    fn codec_with_handle() -> (
        HybridJsonRpcMessageCodec<serde_json::Value>,
        SharedProtocol,
    ) {
        let handle = SharedProtocol::new();
        (HybridJsonRpcMessageCodec::new(handle.clone()), handle)
    }

    #[test]
    fn decodes_json_line_and_marks_protocol() {
        let (mut codec, handle) = codec_with_handle();
        let body = serde_json::to_vec(&sample_message()).unwrap();
        let mut input = BytesMut::from(&body[..]);
        input.put_u8(b'\n');

        let decoded = codec.decode(&mut input).unwrap();

        assert!(decoded.is_some());
        assert_eq!(handle.get(), Some(WireProtocol::JsonLine));
    }

    #[test]
    fn decodes_content_length_and_marks_protocol() {
        let (mut codec, handle) = codec_with_handle();
        let body = serde_json::to_vec(&sample_message()).unwrap();
        let header = format!("Content-Length: {}\r\n\r\n", body.len());
        let mut input = BytesMut::new();
        input.extend_from_slice(header.as_bytes());
        input.extend_from_slice(&body);

        let decoded = codec.decode(&mut input).unwrap();

        assert!(decoded.is_some());
        assert_eq!(handle.get(), Some(WireProtocol::ContentLength));
    }

    #[test]
    fn encodes_using_content_length_when_protocol_is_detected() {
        let handle = SharedProtocol::new();
        handle.set_if_unset(WireProtocol::ContentLength);
        let mut codec = HybridJsonRpcMessageCodec::<serde_json::Value>::new(handle);
        let response = serde_json::json!({"jsonrpc":"2.0","id":1,"result":{"ok":true}});
        let mut output = BytesMut::new();

        codec.encode(response, &mut output).unwrap();

        let encoded = std::str::from_utf8(&output).unwrap();
        assert!(encoded.starts_with("Content-Length: "));
    }
}