use {
super::handshake::Handshake,
crate::protocol::jsonrpc::Message,
futures_lite::io::{
AsyncRead,
AsyncReadExt,
AsyncWrite,
AsyncWriteExt,
},
std::io,
};
const MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const MAX_HANDSHAKE_SIZE: usize = 4096;
pub async fn send_message<W>(writer: &mut W, msg: &Message) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let bytes = serde_json::to_vec(msg)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let len = (bytes.len() as u32).to_le_bytes();
writer.write_all(&len).await?;
writer.write_all(&bytes).await?;
writer.flush().await?;
Ok(())
}
pub async fn recv_message<R>(reader: &mut R) -> io::Result<Option<Message>>
where
R: AsyncRead + Unpin,
{
let mut len_buf = [0u8; 4];
match reader.read_exact(&mut len_buf).await {
| Ok(()) => {},
| Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
| Err(e) => return Err(e),
}
let len = u32::from_le_bytes(len_buf) as usize;
if len > MAX_MESSAGE_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"message too large: {} bytes (max {})",
len, MAX_MESSAGE_SIZE
),
));
}
let mut buf = vec![0u8; len];
reader.read_exact(&mut buf).await?;
let msg = serde_json::from_slice(&buf)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
Ok(Some(msg))
}
pub async fn send_handshake<W>(
writer: &mut W,
handshake: &Handshake,
) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let bytes = serde_json::to_vec(handshake)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let len = (bytes.len() as u32).to_le_bytes();
writer.write_all(&len).await?;
writer.write_all(&bytes).await?;
writer.flush().await?;
Ok(())
}
pub async fn recv_handshake<R>(reader: &mut R) -> io::Result<Handshake>
where
R: AsyncRead + Unpin,
{
let mut len_buf = [0u8; 4];
reader.read_exact(&mut len_buf).await?;
let len = u32::from_le_bytes(len_buf) as usize;
if len > MAX_HANDSHAKE_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"handshake too large: {} bytes (max {})",
len, MAX_HANDSHAKE_SIZE
),
));
}
let mut buf = vec![0u8; len];
reader.read_exact(&mut buf).await?;
serde_json::from_slice(&buf)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}
#[cfg(test)]
mod tests {
use {
super::*,
crate::protocol::{
jsonrpc::{
Notification,
Request,
Response,
},
lsp::LSPAny,
},
futures_lite::io::Cursor,
serde_json::json,
};
fn make_test_request() -> Message {
Message::Request(
Request::build("test/method", 1)
.params(json!({"key": "value"}))
.finish(),
)
}
fn make_test_notification() -> Message {
Message::Notification(
Notification::build("test/notif")
.params(json!({"data": 123}))
.finish(),
)
}
fn make_test_response() -> Message {
let result: LSPAny =
serde_json::from_value(json!({"result": "ok"})).unwrap();
Message::Response(Response::from_ok(1.into(), result))
}
#[test]
fn test_send_message_format() {
smol::block_on(async {
let msg = make_test_request();
let mut buf = Vec::new();
send_message(&mut buf, &msg).await.unwrap();
let len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
assert_eq!(len, buf.len() - 4);
let json_bytes = &buf[4..];
let parsed: Message = serde_json::from_slice(json_bytes).unwrap();
assert_eq!(parsed, msg);
});
}
#[test]
fn test_recv_message_success() {
smol::block_on(async {
let msg = make_test_request();
let json_bytes = serde_json::to_vec(&msg).unwrap();
let len = (json_bytes.len() as u32).to_le_bytes();
let mut data = Vec::new();
data.extend_from_slice(&len);
data.extend_from_slice(&json_bytes);
let mut cursor = Cursor::new(data);
let received = recv_message(&mut cursor).await.unwrap().unwrap();
assert_eq!(received, msg);
});
}
#[test]
fn test_recv_message_eof() {
smol::block_on(async {
let mut cursor = Cursor::new(Vec::<u8>::new());
let result = recv_message(&mut cursor).await.unwrap();
assert!(result.is_none());
});
}
#[test]
fn test_recv_message_too_large() {
smol::block_on(async {
let len = (MAX_MESSAGE_SIZE + 1) as u32;
let mut data = Vec::new();
data.extend_from_slice(&len.to_le_bytes());
let mut cursor = Cursor::new(data);
let result = recv_message(&mut cursor).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
assert!(err.to_string().contains("too large"));
});
}
#[test]
fn test_send_recv_roundtrip_request() {
smol::block_on(async {
let msg = make_test_request();
let mut buf = Vec::new();
send_message(&mut buf, &msg).await.unwrap();
let mut cursor = Cursor::new(buf);
let received = recv_message(&mut cursor).await.unwrap().unwrap();
assert_eq!(received, msg);
});
}
#[test]
fn test_send_recv_roundtrip_notification() {
smol::block_on(async {
let msg = make_test_notification();
let mut buf = Vec::new();
send_message(&mut buf, &msg).await.unwrap();
let mut cursor = Cursor::new(buf);
let received = recv_message(&mut cursor).await.unwrap().unwrap();
assert_eq!(received, msg);
});
}
#[test]
fn test_send_recv_roundtrip_response() {
smol::block_on(async {
let msg = make_test_response();
let mut buf = Vec::new();
send_message(&mut buf, &msg).await.unwrap();
let mut cursor = Cursor::new(buf);
let received = recv_message(&mut cursor).await.unwrap().unwrap();
assert_eq!(received, msg);
});
}
#[test]
fn test_send_handshake_format() {
smol::block_on(async {
let handshake = Handshake::new("v1.0.0-abc123");
let mut buf = Vec::new();
send_handshake(&mut buf, &handshake).await.unwrap();
let len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
assert_eq!(len, buf.len() - 4);
let json_bytes = &buf[4..];
let parsed: Handshake = serde_json::from_slice(json_bytes).unwrap();
assert_eq!(parsed, handshake);
});
}
#[test]
fn test_recv_handshake_success() {
smol::block_on(async {
let handshake = Handshake::new("v1.0.0-abc123");
let json_bytes = serde_json::to_vec(&handshake).unwrap();
let len = (json_bytes.len() as u32).to_le_bytes();
let mut data = Vec::new();
data.extend_from_slice(&len);
data.extend_from_slice(&json_bytes);
let mut cursor = Cursor::new(data);
let received = recv_handshake(&mut cursor).await.unwrap();
assert_eq!(received, handshake);
});
}
#[test]
fn test_recv_handshake_too_large() {
smol::block_on(async {
let len = (MAX_HANDSHAKE_SIZE + 1) as u32;
let mut data = Vec::new();
data.extend_from_slice(&len.to_le_bytes());
let mut cursor = Cursor::new(data);
let result = recv_handshake(&mut cursor).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
assert!(err.to_string().contains("too large"));
});
}
#[test]
fn test_handshake_roundtrip() {
smol::block_on(async {
let handshake = Handshake::new("commit-deadbeef");
let mut buf = Vec::new();
send_handshake(&mut buf, &handshake).await.unwrap();
let mut cursor = Cursor::new(buf);
let received = recv_handshake(&mut cursor).await.unwrap();
assert_eq!(received, handshake);
});
}
#[test]
fn test_frame_partial_length_eof() {
smol::block_on(async {
let mut cursor = Cursor::new(vec![0x10, 0x00]);
let result = recv_message(&mut cursor).await;
assert!(result.is_ok());
assert!(result.unwrap().is_none());
});
}
#[test]
fn test_frame_partial_payload() {
smol::block_on(async {
let len: u32 = 100;
let mut data = Vec::new();
data.extend_from_slice(&len.to_le_bytes());
data.extend_from_slice(&[0u8; 10]);
let mut cursor = Cursor::new(data);
let result = recv_message(&mut cursor).await;
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
});
}
#[test]
fn test_frame_sizes() {
smol::block_on(async {
let sizes = [0usize, 1, 127, 128, 255, 256, 65535];
for &size in &sizes {
let padding = "x".repeat(size);
let notif = Notification::build("test")
.params(json!({"pad": padding}))
.finish();
let msg = Message::Notification(notif);
let mut buf = Vec::new();
send_message(&mut buf, &msg).await.unwrap();
let mut cursor = Cursor::new(buf);
let received = recv_message(&mut cursor).await.unwrap().unwrap();
assert_eq!(received, msg);
}
});
}
#[test]
fn test_multiple_messages_sequential() {
smol::block_on(async {
let msg1 = make_test_request();
let msg2 = make_test_notification();
let msg3 = make_test_response();
let mut buf = Vec::new();
send_message(&mut buf, &msg1).await.unwrap();
send_message(&mut buf, &msg2).await.unwrap();
send_message(&mut buf, &msg3).await.unwrap();
let mut cursor = Cursor::new(buf);
let recv1 = recv_message(&mut cursor).await.unwrap().unwrap();
let recv2 = recv_message(&mut cursor).await.unwrap().unwrap();
let recv3 = recv_message(&mut cursor).await.unwrap().unwrap();
let recv4 = recv_message(&mut cursor).await.unwrap();
assert_eq!(recv1, msg1);
assert_eq!(recv2, msg2);
assert_eq!(recv3, msg3);
assert!(recv4.is_none()); });
}
#[test]
fn test_empty_json_object_message() {
smol::block_on(async {
let json_bytes = b"{}";
let len = (json_bytes.len() as u32).to_le_bytes();
let mut data = Vec::new();
data.extend_from_slice(&len);
data.extend_from_slice(json_bytes);
let mut cursor = Cursor::new(data);
let result = recv_message(&mut cursor).await;
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
});
}
#[test]
fn test_zero_length_message() {
smol::block_on(async {
let len: u32 = 0;
let mut data = Vec::new();
data.extend_from_slice(&len.to_le_bytes());
let mut cursor = Cursor::new(data);
let result = recv_message(&mut cursor).await;
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
});
}
#[test]
fn test_memory_stream_integration() {
smol::block_on(async {
use super::super::memory::MemoryTransport;
let transport = MemoryTransport::new();
let mut listener = transport.bind("framing-test").await.unwrap();
let t = transport.clone();
let client_handle = smol::spawn(async move {
let mut stream = t.connect("framing-test").await.unwrap();
let msg = make_test_request();
send_message(&mut stream, &msg).await.unwrap();
let response = recv_message(&mut stream).await.unwrap().unwrap();
assert!(matches!(response, Message::Response(_)));
});
let mut server = listener.accept().await.unwrap();
let request = recv_message(&mut server).await.unwrap().unwrap();
assert!(matches!(request, Message::Request(_)));
let response = make_test_response();
send_message(&mut server, &response).await.unwrap();
client_handle.await;
});
}
#[test]
fn test_handshake_then_messages() {
smol::block_on(async {
use super::super::memory::MemoryTransport;
let transport = MemoryTransport::new();
let mut listener = transport.bind("handshake-test").await.unwrap();
let t = transport.clone();
let client_handle = smol::spawn(async move {
let mut stream = t.connect("handshake-test").await.unwrap();
let handshake = Handshake::new("client-v1");
send_handshake(&mut stream, &handshake).await.unwrap();
let server_hs = recv_handshake(&mut stream).await.unwrap();
assert!(handshake.is_compatible(&server_hs));
let msg = make_test_request();
send_message(&mut stream, &msg).await.unwrap();
let response = recv_message(&mut stream).await.unwrap().unwrap();
assert!(matches!(response, Message::Response(_)));
});
let mut server = listener.accept().await.unwrap();
let client_hs = recv_handshake(&mut server).await.unwrap();
assert_eq!(client_hs.version, "client-v1");
let handshake = Handshake::new("client-v1"); send_handshake(&mut server, &handshake).await.unwrap();
let request = recv_message(&mut server).await.unwrap().unwrap();
assert!(matches!(request, Message::Request(_)));
let response = make_test_response();
send_message(&mut server, &response).await.unwrap();
client_handle.await;
});
}
}