use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures::stream::SplitStream;
use futures::{SinkExt, StreamExt};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Stdin, Stdout};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::client_sender::ResponseMap;
use crate::lifecycle::{ExitCode, LifecycleState, ProtocolError};
use crate::ClientSender;
use crate::{transport, Message, RequestQueue, Transport};
/// Moves the write half of a transport onto a background Tokio task and
/// returns the queue that feeds it.
///
/// Messages pushed into the returned sender are written to `sender` one at a
/// time, in the order they were queued. The task stops once every clone of the
/// sender has been dropped, or as soon as a write fails. The returned
/// [`CancellationToken`] is cancelled when the task ends, so holders of a clone
/// can tell that nothing is draining the queue any more.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime (the task is started with
/// `tokio::spawn`).
fn spawn_drain_task<T>(
    mut sender: futures::stream::SplitSink<Transport<T>, Message>,
) -> (mpsc::UnboundedSender<Message>, CancellationToken)
where
    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let (queue_tx, mut queue_rx) = mpsc::unbounded_channel();
    let alive = CancellationToken::new();
    let cancel_when_done = alive.clone().drop_guard();

    tokio::spawn(async move {
        // Owned by the task so the token is cancelled exactly when the task
        // finishes, whichever way it exits.
        let _cancel_when_done = cancel_when_done;
        loop {
            let Some(next) = queue_rx.recv().await else {
                break; // every sender is gone
            };
            if sender.send(next).await.is_err() {
                break; // transport write failed; stop draining
            }
        }
    });

    (queue_tx, alive)
}
/// A [`Connection`] over this process's stdin/stdout ([`StdioTransport`]).
///
/// `I` is the data stored alongside each incoming request in the request
/// queue; it defaults to `()`.
pub type StdioConnection<I = ()> = Connection<StdioTransport, I>;
pin_project! {
/// Inbound half of a [`Connection`]: a stream of `Result<Message, io::Error>`
/// items read from the transport.
pub struct MessageStream<T> {
// Read half obtained by splitting the transport; polled through a pinned
// projection in the `Stream` impl.
#[pin]
inner: SplitStream<Transport<T>>,
}
}
impl<T> MessageStream<T> {
fn new(inner: SplitStream<Transport<T>>) -> Self {
Self { inner }
}
}
impl<T> futures::Stream for MessageStream<T>
where
    T: AsyncRead + AsyncWrite,
{
    type Item = Result<Message, io::Error>;

    /// Forwards polling to the wrapped split stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();
        futures::Stream::poll_next(projected.inner, cx)
    }
}
/// A bidirectional message connection (requests, notifications, responses)
/// over an async I/O object `T`.
///
/// Outbound messages are queued to a background drain task; inbound messages
/// are read from a [`MessageStream`]. `I` is the data stored with each incoming
/// request in the [`RequestQueue`].
pub struct Connection<T, I = ()>
where
T: AsyncRead + AsyncWrite,
{
/// Queue consumed by the background drain task that writes to the transport.
outbound_tx: mpsc::UnboundedSender<Message>,
/// Inbound stream of messages read from the transport.
receiver: MessageStream<T>,
/// Tracks incoming requests (with their cancellation tokens and `I` data) and
/// outgoing requests awaiting a response.
pub request_queue: RequestQueue<I>,
/// Current lifecycle phase; advanced by the initialize/shutdown/exit helpers.
lifecycle_state: LifecycleState,
/// Cancelled by `handle_shutdown`; per-request tokens handed out by `route`
/// are children of it.
shutdown_token: CancellationToken,
/// Created lazily by `client_sender`; consulted first when routing responses.
response_map: Option<ResponseMap>,
/// Cancelled when the drain task exits; shared with every `ClientSender`.
drain_alive: CancellationToken,
}
impl<T, I> Connection<T, I>
where
    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    /// Wraps `io` in the crate's framed transport and builds a connection with
    /// an empty request queue.
    ///
    /// # Panics
    ///
    /// Spawns the outbound drain task with `tokio::spawn`, so this panics if
    /// called outside a Tokio runtime.
    pub fn new(io: T) -> Self {
        Self::from_transport(transport(io))
    }

    /// Builds a connection from an already-constructed transport, with an
    /// empty request queue.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (see [`Self::new`]).
    pub fn from_transport(transport: Transport<T>) -> Self {
        Self::from_transport_and_queue(transport, RequestQueue::new())
    }

    /// Like [`Self::new`], but starts from a caller-supplied request queue
    /// (for example one with requests already registered).
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (see [`Self::new`]).
    pub fn with_request_queue(io: T, request_queue: RequestQueue<I>) -> Self {
        Self::from_transport_and_queue(transport(io), request_queue)
    }

    /// Shared constructor: splits the transport, hands the write half to the
    /// drain task, and assembles the connection in its initial lifecycle state.
    fn from_transport_and_queue(transport: Transport<T>, request_queue: RequestQueue<I>) -> Self {
        let (sender, receiver) = transport.split();
        let (outbound_tx, drain_alive) = spawn_drain_task(sender);
        Self {
            outbound_tx,
            receiver: MessageStream::new(receiver),
            request_queue,
            lifecycle_state: LifecycleState::default(),
            shutdown_token: CancellationToken::new(),
            response_map: None,
            drain_alive,
        }
    }

    /// Current lifecycle phase of the connection.
    #[must_use]
    pub fn lifecycle_state(&self) -> LifecycleState {
        self.lifecycle_state
    }

    /// A clone of the shutdown token, which [`Self::handle_shutdown`] cancels.
    #[must_use]
    pub fn shutdown_token(&self) -> CancellationToken {
        self.shutdown_token.clone()
    }

    /// Returns a future that resolves once the shutdown token is cancelled
    /// (for example by [`Self::handle_shutdown`]).
    pub fn on_shutdown(&self) -> impl std::future::Future<Output = ()> + '_ {
        self.shutdown_token.cancelled()
    }
}
impl<T, I> Connection<T, I>
where
    T: AsyncRead + AsyncWrite,
{
    /// Borrows the inbound message stream mutably, e.g. to call
    /// `StreamExt::next` on it.
    pub fn receiver_mut(&mut self) -> &mut MessageStream<T> {
        &mut self.receiver
    }

    /// Consumes the connection, keeping only its inbound message stream.
    ///
    /// Every other field, including this connection's handle to the outbound
    /// queue, is dropped.
    #[must_use]
    pub fn into_receiver(self) -> MessageStream<T> {
        let Self { receiver, .. } = self;
        receiver
    }

    /// Queues `message` for the background writer.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::BrokenPipe`] error when the writer task has
    /// stopped and its end of the queue is gone.
    pub fn send(&self, message: Message) -> Result<(), io::Error> {
        match self.outbound_tx.send(message) {
            Ok(()) => Ok(()),
            Err(_) => Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "connection closed",
            )),
        }
    }
}
impl<T, I> Connection<T, I>
where
T: AsyncRead + AsyncWrite,
I: Default,
{
pub fn route(&mut self, message: Message) -> crate::IncomingMessage {
match message {
Message::Request(req) => {
let token = self.shutdown_token.child_token();
self.request_queue
.incoming
.register(req.id.clone(), I::default(), token.clone());
crate::IncomingMessage::Request(req, token)
}
Message::Notification(notif) => {
if notif.method == crate::request_queue::CANCEL_REQUEST_METHOD {
if let Some(id) = crate::parse_cancel_params(¬if.params) {
let _ = self.request_queue.incoming.cancel(&id);
crate::IncomingMessage::CancelHandled
} else {
crate::IncomingMessage::Notification(notif)
}
} else {
crate::IncomingMessage::Notification(notif)
}
}
Message::Response(resp) => {
if let Some(id) = resp.id.clone() {
if let Some(response_map) = self.response_map.as_ref() {
if response_map.contains(&id) && response_map.deliver(&id, resp.clone()) {
return crate::IncomingMessage::ResponseRouted;
}
}
if self.request_queue.outgoing.is_pending(&id) {
self.request_queue.outgoing.complete(&id, resp);
crate::IncomingMessage::ResponseRouted
} else {
crate::IncomingMessage::ResponseUnknown(resp)
}
} else {
crate::IncomingMessage::ResponseUnknown(resp)
}
}
}
}
pub fn cancel_incoming(&mut self, id: impl Into<crate::RequestId>) -> bool {
self.request_queue.incoming.cancel(&id.into())
}
#[must_use]
pub fn client_sender(&mut self) -> ClientSender {
if let Some(ref response_map) = self.response_map {
return ClientSender::new(
self.outbound_tx.clone(),
response_map.clone(),
self.drain_alive.clone(),
);
}
let response_map = ResponseMap::new();
self.response_map = Some(response_map.clone());
ClientSender::new(
self.outbound_tx.clone(),
response_map,
self.drain_alive.clone(),
)
}
}
impl<T, I> Connection<T, I>
where
T: AsyncRead + AsyncWrite,
{
pub fn cancel(&mut self, id: impl Into<crate::RequestId>) -> Result<bool, std::io::Error> {
use crate::request_queue::CANCEL_REQUEST_METHOD;
let id = id.into();
let notification =
crate::Notification::new(CANCEL_REQUEST_METHOD, Some(serde_json::json!({"id": id})));
self.send(Message::Notification(notification))?;
Ok(self.request_queue.outgoing.cancel(&id))
}
}
impl<T, I> Connection<T, I>
where
T: AsyncRead + AsyncWrite,
{
pub async fn initialize_start(
&mut self,
) -> Result<(crate::RequestId, serde_json::Value), ProtocolError> {
loop {
match self.receiver_mut().next().await {
Some(Ok(Message::Request(req))) => {
if req.method == "initialize" {
self.lifecycle_state = LifecycleState::Initializing;
return Ok((req.id, req.params.unwrap_or(serde_json::Value::Null)));
}
let error = crate::ResponseError::new(
crate::ErrorCode::ServerNotInitialized,
"Server not yet initialized",
);
let response = Message::Response(crate::Response::err(req.id, error));
if let Err(e) = self.send(response) {
return Err(ProtocolError::Io(e));
}
}
Some(Ok(Message::Notification(notif))) => {
if notif.method == "exit" {
return Err(ProtocolError::Disconnected);
}
}
Some(Ok(Message::Response(_))) => {
}
Some(Err(e)) => {
return Err(ProtocolError::Io(e));
}
None => {
return Err(ProtocolError::Disconnected);
}
}
}
}
pub async fn initialize_finish(
&mut self,
id: crate::RequestId,
initialize_result: serde_json::Value,
) -> Result<(), ProtocolError> {
use std::time::Duration;
let response = Message::Response(crate::Response::ok(id, initialize_result));
if let Err(e) = self.send(response) {
return Err(ProtocolError::Io(e));
}
tokio::time::timeout(Duration::from_mins(1), async {
loop {
match self.receiver_mut().next().await {
Some(Ok(Message::Notification(notif))) => {
if notif.method == "initialized" {
self.lifecycle_state = LifecycleState::Running;
return Ok(());
}
}
Some(Ok(Message::Request(req))) => {
let error = crate::ResponseError::new(
crate::ErrorCode::ServerNotInitialized,
"Server not yet initialized",
);
let response = Message::Response(crate::Response::err(req.id, error));
if let Err(e) = self.send(response) {
return Err(ProtocolError::Io(e));
}
}
Some(Ok(Message::Response(_))) => {
}
Some(Err(e)) => {
return Err(ProtocolError::Io(e));
}
None => {
return Err(ProtocolError::Disconnected);
}
}
}
})
.await
.map_err(|_| ProtocolError::InitializeTimeout)?
}
pub async fn initialize(
&mut self,
server_capabilities: serde_json::Value,
) -> Result<serde_json::Value, ProtocolError> {
let (id, params) = self.initialize_start().await?;
let initialize_result = serde_json::json!({
"capabilities": server_capabilities
});
self.initialize_finish(id, initialize_result).await?;
Ok(params)
}
pub fn handle_shutdown(&mut self, id: crate::RequestId) -> Result<(), ProtocolError> {
self.shutdown_token.cancel();
self.lifecycle_state = LifecycleState::ShuttingDown;
let response = Message::Response(crate::Response::ok(id, serde_json::Value::Null));
if let Err(e) = self.send(response) {
return Err(ProtocolError::Io(e));
}
Ok(())
}
pub fn handle_exit(&mut self) -> ExitCode {
let was_shutting_down = self.lifecycle_state == LifecycleState::ShuttingDown;
self.lifecycle_state = LifecycleState::Exited;
if was_shutting_down {
ExitCode::Success
} else {
ExitCode::Error
}
}
#[must_use]
pub fn is_running(&self) -> bool {
self.lifecycle_state == LifecycleState::Running
}
#[must_use]
pub fn is_shutting_down(&self) -> bool {
self.lifecycle_state == LifecycleState::ShuttingDown
}
}
pin_project! {
/// Byte transport that reads from the process's stdin and writes to its
/// stdout (Tokio handles).
pub struct StdioTransport {
// Source for all `AsyncRead` calls.
#[pin]
stdin: Stdin,
// Sink for all `AsyncWrite` calls.
#[pin]
stdout: Stdout,
}
}
impl StdioTransport {
#[must_use]
pub fn new() -> Self {
Self {
stdin: tokio::io::stdin(),
stdout: tokio::io::stdout(),
}
}
}
impl Default for StdioTransport {
fn default() -> Self {
Self::new()
}
}
impl AsyncRead for StdioTransport {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.project().stdin.poll_read(cx, buf)
}
}
impl AsyncWrite for StdioTransport {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.project().stdout.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().stdout.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().stdout.poll_shutdown(cx)
}
}
impl Connection<StdioTransport, ()> {
#[must_use]
pub fn stdio() -> Self {
Self::new(StdioTransport::new())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Request, Response, StdioConnection};
use serde_json::json;
use std::time::Duration;
#[tokio::test]
async fn connection_from_duplex_test() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let client: Connection<_, ()> = Connection::new(client_stream);
let mut server: Connection<_, ()> = Connection::new(server_stream);
let request = Message::Request(Request::new(1, "test", None));
client.send(request).unwrap();
let received = server.receiver_mut().next().await.unwrap().unwrap();
assert!(received.is_request());
if let Message::Request(req) = received {
assert_eq!(req.method, "test");
assert_eq!(req.id, 1.into());
} else {
panic!("Expected Request");
}
}
#[tokio::test]
async fn connection_bidirectional_test() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let mut client: Connection<_, ()> = Connection::new(client_stream);
let mut server: Connection<_, ()> = Connection::new(server_stream);
let request = Message::Request(Request::new(
1,
"textDocument/hover",
Some(json!({
"textDocument": {"uri": "file:///test.rs"},
"position": {"line": 10, "character": 5}
})),
));
client.send(request).unwrap();
let received = server.receiver_mut().next().await.unwrap().unwrap();
assert!(received.is_request());
let response = Message::Response(Response::ok(
1,
json!({
"contents": "fn main()"
}),
));
server.send(response).unwrap();
let received = client.receiver_mut().next().await.unwrap().unwrap();
assert!(received.is_response());
if let Message::Response(resp) = received {
assert_eq!(resp.id, Some(1.into()));
assert!(resp.result().is_some());
} else {
panic!("Expected Response");
}
}
#[tokio::test]
async fn connection_from_transport_test() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let client_transport = transport(client_stream);
let server_transport = transport(server_stream);
let client: Connection<_, ()> = Connection::from_transport(client_transport);
let mut server: Connection<_, ()> = Connection::from_transport(server_transport);
let request = Message::Request(Request::new(42, "test", None));
client.send(request).unwrap();
let received = server.receiver_mut().next().await.unwrap().unwrap();
assert!(received.is_request());
if let Message::Request(req) = received {
assert_eq!(req.id, 42.into());
}
}
#[tokio::test]
async fn connection_multiple_messages_test() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let client: Connection<_, ()> = Connection::new(client_stream);
let mut server: Connection<_, ()> = Connection::new(server_stream);
let msg1 = Message::Request(Request::new(1, "first", None));
let msg2 = Message::Request(Request::new(2, "second", None));
let msg3 = Message::Request(Request::new(3, "third", None));
client.send(msg1).unwrap();
client.send(msg2).unwrap();
client.send(msg3).unwrap();
let recv1 = server.receiver_mut().next().await.unwrap().unwrap();
let recv2 = server.receiver_mut().next().await.unwrap().unwrap();
let recv3 = server.receiver_mut().next().await.unwrap().unwrap();
if let Message::Request(r) = recv1 {
assert_eq!(r.method, "first");
assert_eq!(r.id, 1.into());
} else {
panic!("Expected request");
}
if let Message::Request(r) = recv2 {
assert_eq!(r.method, "second");
assert_eq!(r.id, 2.into());
} else {
panic!("Expected request");
}
if let Message::Request(r) = recv3 {
assert_eq!(r.method, "third");
assert_eq!(r.id, 3.into());
} else {
panic!("Expected request");
}
}
#[tokio::test]
async fn sender_receiver_independent_test() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let client: Connection<_, ()> = Connection::new(client_stream);
let server: Connection<_, ()> = Connection::new(server_stream);
let send_task =
tokio::spawn(
async move { client.send(Message::Request(Request::new(1, "test", None))) },
);
let mut server_receiver = server.receiver;
let recv_task = tokio::spawn(async move { server_receiver.next().await });
let (send_result, recv_result) = tokio::join!(send_task, recv_task);
send_result.unwrap().unwrap();
assert!(recv_result.unwrap().unwrap().unwrap().is_request());
}
#[tokio::test]
async fn connection_has_request_queue_test() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, String> = Connection::new(stream);
let token = CancellationToken::new();
conn.request_queue
.incoming
.register(1.into(), "handler_data".to_string(), token);
assert!(conn.request_queue.incoming.is_pending(&1.into()));
let data = conn.request_queue.incoming.complete(&1.into());
assert_eq!(data, Some("handler_data".to_string()));
}
#[tokio::test]
async fn connection_with_request_queue_test() {
let (stream, _) = tokio::io::duplex(4096);
let mut queue: RequestQueue<u32> = RequestQueue::new();
let token = CancellationToken::new();
queue.incoming.register(42.into(), 100, token);
let conn = Connection::with_request_queue(stream, queue);
assert!(conn.request_queue.incoming.is_pending(&42.into()));
}
#[tokio::test]
async fn connection_outgoing_request_queue_test() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let rx = conn.request_queue.outgoing.register(1.into());
assert!(conn.request_queue.outgoing.is_pending(&1.into()));
let completed = conn.request_queue.outgoing.complete(
&1.into(),
Response::ok(1, serde_json::json!("response data")),
);
assert!(completed);
let response = rx.await.unwrap();
assert_eq!(response.id, Some(1.into()));
assert_eq!(
response.result().cloned(),
Some(serde_json::json!("response data"))
);
}
#[test]
fn stdio_transport_constructible() {
let _transport = StdioTransport::new();
let _transport_default = StdioTransport::default();
}
use crate::{ExitCode, LifecycleState, Notification, ProtocolError};
#[tokio::test]
async fn test_initialize_handshake() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let mut client: Connection<_, ()> = Connection::new(client_stream);
let mut server: Connection<_, ()> = Connection::new(server_stream);
let init_params = json!({"processId": 1234, "capabilities": {}});
let init_request =
Message::Request(Request::new(1, "initialize", Some(init_params.clone())));
client.send(init_request).unwrap();
let (id, params) = server.initialize_start().await.unwrap();
assert_eq!(id, 1.into());
assert_eq!(params["processId"], 1234);
assert_eq!(server.lifecycle_state(), LifecycleState::Initializing);
let server_task = tokio::spawn(async move {
let result = json!({"capabilities": {"textDocumentSync": 1}});
server.initialize_finish(id, result).await.unwrap();
server
});
let response = client.receiver_mut().next().await.unwrap().unwrap();
assert!(response.is_response());
if let Message::Response(resp) = response {
assert_eq!(resp.id, Some(1.into()));
assert!(resp.result().is_some());
let result = resp.into_result().unwrap();
assert_eq!(result["capabilities"]["textDocumentSync"], 1);
}
let initialized = Message::Notification(Notification::new("initialized", None));
client.send(initialized).unwrap();
let server = server_task.await.unwrap();
assert!(server.is_running());
assert_eq!(server.lifecycle_state(), LifecycleState::Running);
}
#[tokio::test]
async fn stdio_connection_alias_constructs_with_custom_metadata() {
let conn: StdioConnection<String> = Connection::new(StdioTransport::new());
assert_eq!(conn.lifecycle_state(), LifecycleState::Uninitialized);
assert!(!conn.request_queue.incoming.is_pending(&1.into()));
assert!(!conn.request_queue.outgoing.is_pending(&1.into()));
}
#[tokio::test]
async fn test_initialize_rejects_non_init_requests() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let mut client: Connection<_, ()> = Connection::new(client_stream);
let mut server: Connection<_, ()> = Connection::new(server_stream);
let hover_request = Message::Request(Request::new(1, "textDocument/hover", None));
client.send(hover_request).unwrap();
let server_task = tokio::spawn(async move {
server.initialize_start().await.unwrap();
server
});
let response = client.receiver_mut().next().await.unwrap().unwrap();
assert!(response.is_response());
if let Message::Response(resp) = response {
assert_eq!(resp.id, Some(1.into()));
assert!(resp.error().is_some());
let error = resp.into_error().unwrap();
assert_eq!(error.code, crate::ErrorCode::ServerNotInitialized as i32);
}
let init_request = Message::Request(Request::new(2, "initialize", None));
client.send(init_request).unwrap();
let server = server_task.await.unwrap();
assert_eq!(server.lifecycle_state(), LifecycleState::Initializing);
}
#[tokio::test(start_paused = true)]
async fn test_initialize_finish_times_out_without_initialized() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let mut client: Connection<_, ()> = Connection::new(client_stream);
let mut server: Connection<_, ()> = Connection::new(server_stream);
client
.send(Message::Request(Request::new(1, "initialize", None)))
.unwrap();
let (id, _params) = server.initialize_start().await.unwrap();
let server_task = tokio::spawn(async move {
server
.initialize_finish(id, json!({"capabilities": {}}))
.await
});
let response = client.receiver_mut().next().await.unwrap().unwrap();
assert!(response.is_response());
tokio::time::advance(Duration::from_secs(61)).await;
let result = server_task.await.unwrap();
assert!(matches!(result, Err(ProtocolError::InitializeTimeout)));
}
#[tokio::test]
async fn test_initialize_drops_notifications() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let client: Connection<_, ()> = Connection::new(client_stream);
let mut server: Connection<_, ()> = Connection::new(server_stream);
let random_notif = Message::Notification(Notification::new("textDocument/didOpen", None));
client.send(random_notif).unwrap();
let init_request = Message::Request(Request::new(1, "initialize", None));
client.send(init_request).unwrap();
let (id, _params) = server.initialize_start().await.unwrap();
assert_eq!(id, 1.into());
assert_eq!(server.lifecycle_state(), LifecycleState::Initializing);
}
#[tokio::test]
async fn test_exit_during_init_disconnects() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let client: Connection<_, ()> = Connection::new(client_stream);
let mut server: Connection<_, ()> = Connection::new(server_stream);
let exit_notif = Message::Notification(Notification::new("exit", None));
client.send(exit_notif).unwrap();
let result = server.initialize_start().await;
assert!(matches!(result, Err(ProtocolError::Disconnected)));
}
#[tokio::test]
async fn test_shutdown_then_exit() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let mut client: Connection<_, ()> = Connection::new(client_stream);
let mut server: Connection<_, ()> = Connection::new(server_stream);
let init_request = Message::Request(Request::new(1, "initialize", None));
client.send(init_request).unwrap();
let (id, _params) = server.initialize_start().await.unwrap();
let server_task = tokio::spawn(async move {
server
.initialize_finish(id, json!({"capabilities": {}}))
.await
.unwrap();
server
});
let _ = client.receiver_mut().next().await; let initialized = Message::Notification(Notification::new("initialized", None));
client.send(initialized).unwrap();
let mut server = server_task.await.unwrap();
assert!(server.is_running());
let shutdown_request = Message::Request(Request::new(2, "shutdown", None));
client.send(shutdown_request).unwrap();
let msg = server.receiver_mut().next().await.unwrap().unwrap();
if let Message::Request(req) = msg {
assert_eq!(req.method, "shutdown");
server.handle_shutdown(req.id).unwrap();
} else {
panic!("Expected shutdown request");
}
assert!(server.is_shutting_down());
assert!(server.shutdown_token().is_cancelled());
let response = client.receiver_mut().next().await.unwrap().unwrap();
if let Message::Response(resp) = response {
assert_eq!(resp.id, Some(2.into()));
assert_eq!(resp.result().cloned(), Some(serde_json::Value::Null));
}
let exit_code = server.handle_exit();
assert_eq!(exit_code, ExitCode::Success);
assert_eq!(server.lifecycle_state(), LifecycleState::Exited);
}
#[tokio::test]
async fn test_exit_without_shutdown() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let mut client: Connection<_, ()> = Connection::new(client_stream);
let mut server: Connection<_, ()> = Connection::new(server_stream);
let init_request = Message::Request(Request::new(1, "initialize", None));
client.send(init_request).unwrap();
let (id, _params) = server.initialize_start().await.unwrap();
let server_task = tokio::spawn(async move {
server
.initialize_finish(id, json!({"capabilities": {}}))
.await
.unwrap();
server
});
let _ = client.receiver_mut().next().await;
let initialized = Message::Notification(Notification::new("initialized", None));
client.send(initialized).unwrap();
let mut server = server_task.await.unwrap();
assert!(server.is_running());
let exit_code = server.handle_exit();
assert_eq!(exit_code, ExitCode::Error);
assert_eq!(server.lifecycle_state(), LifecycleState::Exited);
}
#[tokio::test]
async fn test_on_shutdown_future() {
let (client_stream, server_stream) = tokio::io::duplex(4096);
let mut client: Connection<_, ()> = Connection::new(client_stream);
let mut server: Connection<_, ()> = Connection::new(server_stream);
let init_request = Message::Request(Request::new(1, "initialize", None));
client.send(init_request).unwrap();
let (id, _params) = server.initialize_start().await.unwrap();
let server_task = tokio::spawn(async move {
server
.initialize_finish(id, json!({"capabilities": {}}))
.await
.unwrap();
server
});
let _ = client.receiver_mut().next().await;
let initialized = Message::Notification(Notification::new("initialized", None));
client.send(initialized).unwrap();
let mut server = server_task.await.unwrap();
let token = server.shutdown_token();
let wait_task = tokio::spawn(async move {
token.cancelled().await;
"shutdown received"
});
let shutdown_request = Message::Request(Request::new(2, "shutdown", None));
client.send(shutdown_request).unwrap();
let msg = server.receiver_mut().next().await.unwrap().unwrap();
if let Message::Request(req) = msg {
server.handle_shutdown(req.id).unwrap();
}
let result = tokio::time::timeout(std::time::Duration::from_millis(100), wait_task)
.await
.expect("wait task should complete quickly")
.unwrap();
assert_eq!(result, "shutdown received");
}
use crate::IncomingMessage;
#[tokio::test]
async fn route_request_returns_incoming_request() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let request = Request::new(42, "textDocument/hover", Some(json!({"line": 10})));
let message = Message::Request(request);
let result = conn.route(message);
match result {
IncomingMessage::Request(req, token) => {
assert_eq!(req.id, 42.into());
assert_eq!(req.method, "textDocument/hover");
assert!(!token.is_cancelled());
assert!(conn.request_queue.incoming.is_pending(&42.into()));
}
_ => panic!("Expected IncomingMessage::Request"),
}
}
#[tokio::test]
async fn route_notification_returns_incoming_notification() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let notification = Notification::new(
"textDocument/didOpen",
Some(json!({"uri": "file:///test.rs"})),
);
let message = Message::Notification(notification);
let result = conn.route(message);
match result {
IncomingMessage::Notification(notif) => {
assert_eq!(notif.method, "textDocument/didOpen");
}
_ => panic!("Expected IncomingMessage::Notification"),
}
}
#[tokio::test]
async fn route_response_to_pending_outgoing_request() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let rx = conn.request_queue.outgoing.register(42.into());
let response = Response::ok(42, json!({"result": "success"}));
let message = Message::Response(response);
let result = conn.route(message);
assert!(
matches!(result, IncomingMessage::ResponseRouted),
"Expected ResponseRouted, got {result:?}"
);
let received = rx.await.expect("Should receive response");
assert_eq!(received.id, Some(42.into()));
assert!(received.result().is_some());
assert_eq!(received.into_result().unwrap()["result"], "success");
}
#[tokio::test]
async fn route_response_for_unknown_id_returns_response_unknown() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let response = Response::ok(999, json!({"unexpected": true}));
let message = Message::Response(response);
let result = conn.route(message);
match result {
IncomingMessage::ResponseUnknown(resp) => {
assert_eq!(resp.id, Some(999.into()));
}
_ => panic!("Expected IncomingMessage::ResponseUnknown"),
}
}
#[tokio::test]
async fn route_response_with_null_id_returns_response_unknown() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let response = Response::parse_error(crate::ResponseError::new(
crate::ErrorCode::ParseError,
"Parse error",
));
let message = Message::Response(response);
let result = conn.route(message);
match result {
IncomingMessage::ResponseUnknown(resp) => {
assert!(
resp.id.is_none(),
"Expected null id for parse error response"
);
assert!(resp.error().is_some());
}
_ => panic!("Expected IncomingMessage::ResponseUnknown"),
}
}
#[tokio::test]
async fn route_response_with_string_id() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let rx = conn.request_queue.outgoing.register("request-abc".into());
let response = Response::ok("request-abc", json!(null));
let message = Message::Response(response);
let result = conn.route(message);
assert!(matches!(result, IncomingMessage::ResponseRouted));
let received = rx.await.expect("Should receive response");
assert_eq!(
received.id,
Some(crate::RequestId::String("request-abc".to_string()))
);
}
#[tokio::test]
async fn route_multiple_responses_to_different_requests() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let rx1 = conn.request_queue.outgoing.register(1.into());
let rx2 = conn.request_queue.outgoing.register(2.into());
let rx3 = conn.request_queue.outgoing.register(3.into());
let result2 = conn.route(Message::Response(Response::ok(2, json!("second"))));
assert!(matches!(result2, IncomingMessage::ResponseRouted));
let result1 = conn.route(Message::Response(Response::ok(1, json!("first"))));
assert!(matches!(result1, IncomingMessage::ResponseRouted));
let result3 = conn.route(Message::Response(Response::ok(3, json!("third"))));
assert!(matches!(result3, IncomingMessage::ResponseRouted));
let resp1 = rx1.await.unwrap();
assert_eq!(resp1.into_result().unwrap(), json!("first"));
let resp2 = rx2.await.unwrap();
assert_eq!(resp2.into_result().unwrap(), json!("second"));
let resp3 = rx3.await.unwrap();
assert_eq!(resp3.into_result().unwrap(), json!("third"));
}
#[tokio::test]
async fn route_error_response_to_pending_request() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let _rx = conn.request_queue.outgoing.register(42.into());
let response = Response::err(
42,
crate::ResponseError::new(crate::ErrorCode::MethodNotFound, "Not found"),
);
let message = Message::Response(response);
let result = conn.route(message);
assert!(matches!(result, IncomingMessage::ResponseRouted));
}
#[tokio::test]
async fn route_cancel_request_returns_cancel_handled_and_cancels_pending() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let request = Request::new(42, "test", None);
let routed = conn.route(Message::Request(request));
let IncomingMessage::Request(_, token) = routed else {
panic!("expected Request");
};
assert!(!token.is_cancelled());
let cancel = Notification::new("$/cancelRequest", Some(json!({"id": 42})));
let result = conn.route(Message::Notification(cancel));
assert!(matches!(result, IncomingMessage::CancelHandled));
assert!(token.is_cancelled());
}
#[tokio::test]
async fn route_cancel_request_unknown_id_still_returns_cancel_handled() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let cancel = Notification::new("$/cancelRequest", Some(json!({"id": 99})));
let result = conn.route(Message::Notification(cancel));
assert!(matches!(result, IncomingMessage::CancelHandled));
}
#[tokio::test]
async fn route_regular_notification_not_affected() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let notif = Notification::new("textDocument/didOpen", None);
let result = conn.route(Message::Notification(notif));
assert!(matches!(result, IncomingMessage::Notification(_)));
}
#[tokio::test]
async fn cancellation_propagates_to_spawned_handler() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let request = Request::new(1, "test", None);
let result = conn.route(Message::Request(request));
let IncomingMessage::Request(_, token) = result else {
panic!("Expected IncomingMessage::Request")
};
let handle = tokio::spawn(async move {
token.cancelled().await;
"cancelled"
});
let _ = conn.request_queue.incoming.cancel(&1.into());
let result = tokio::time::timeout(std::time::Duration::from_millis(100), handle)
.await
.expect("Handler should complete quickly")
.unwrap();
assert_eq!(result, "cancelled");
}
#[tokio::test]
async fn route_request_auto_registers_and_cancellation_works() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let request = Request::new(42, "test", None);
let result = conn.route(Message::Request(request));
let IncomingMessage::Request(_, token) = result else {
panic!("Expected IncomingMessage::Request")
};
assert!(!token.is_cancelled());
let was_cancelled = conn.cancel_incoming(42);
assert!(was_cancelled);
assert!(token.is_cancelled());
}
#[tokio::test]
async fn test_cancel_outgoing_request() {
    let (near, far) = tokio::io::duplex(4096);
    let mut client: Connection<_, ()> = Connection::new(near);
    let mut server: Connection<_, ()> = Connection::new(far);

    // Mark request 42 as in flight, then cancel it locally.
    let pending_rx = client.request_queue.outgoing.register(42.into());
    let cancelled = client.cancel(42).unwrap();
    assert!(cancelled);
    assert!(!client.request_queue.outgoing.is_pending(&42.into()));

    // The peer must observe a `$/cancelRequest` notification carrying the id.
    let received = server.receiver_mut().next().await.unwrap().unwrap();
    assert!(received.is_notification());
    match received {
        Message::Notification(n) => {
            assert_eq!(n.method, "$/cancelRequest");
            assert_eq!(n.params.unwrap()["id"], 42);
        }
        _ => panic!("Expected notification"),
    }

    // The receiver registered for the request resolves with an error once it is cancelled.
    assert!(pending_rx.await.is_err());
}
#[tokio::test]
async fn test_cancel_unknown_outgoing_request() {
    // Bind (rather than discard) the peer half so it lives until the end of the test.
    let (local, _peer) = tokio::io::duplex(4096);
    let mut client: Connection<_, ()> = Connection::new(local);
    // Nothing was registered under 999, so cancel must report it was not pending.
    assert!(!client.cancel(999).unwrap());
}
#[tokio::test]
async fn test_cancel_with_string_id() {
    let (client_stream, server_stream) = tokio::io::duplex(4096);
    let mut client: Connection<_, ()> = Connection::new(client_stream);
    let mut server: Connection<_, ()> = Connection::new(server_stream);
    // Same scenario as the numeric-id test, but with a string request id.
    let rx = client.request_queue.outgoing.register("req-abc".into());
    let was_pending = client.cancel("req-abc").unwrap();
    assert!(was_pending);
    // Cancelling must also remove the string id from the outgoing queue.
    assert!(!client.request_queue.outgoing.is_pending(&"req-abc".into()));
    let msg = server.receiver_mut().next().await.unwrap().unwrap();
    if let Message::Notification(notif) = msg {
        assert_eq!(notif.method, "$/cancelRequest");
        assert_eq!(notif.params.unwrap()["id"], "req-abc");
    } else {
        panic!("Expected notification");
    }
    assert!(rx.await.is_err());
}
#[tokio::test]
async fn client_sender_messages_arrive_on_receiver() {
    let (client_io, server_io) = tokio::io::duplex(4096);
    let mut server: Connection<_, ()> = Connection::new(server_io);
    let mut client: Connection<_, ()> = Connection::new(client_io);

    // A notification pushed through the server's ClientSender must reach the client side.
    let handle = server.client_sender();
    let log_params = json!({"type": 3, "message": "hello"});
    handle.notify("window/logMessage", Some(log_params)).unwrap();

    let incoming = client.receiver_mut().next().await.unwrap().unwrap();
    assert!(incoming.is_notification());
    if let Message::Notification(log) = incoming {
        assert_eq!(log.method, "window/logMessage");
        assert_eq!(log.params.unwrap()["message"], "hello");
    }
}
#[tokio::test]
async fn client_sender_request_routed_through_response_map() {
    let (client_io, server_io) = tokio::io::duplex(4096);
    let mut server: Connection<_, ()> = Connection::new(server_io);
    let mut client: Connection<_, ()> = Connection::new(client_io);
    let sender = server.client_sender();

    // Issue a server->client request from a background task. It completes once the matching
    // response has been routed back through `server.route`.
    let pending = {
        let sender = sender.clone();
        tokio::spawn(async move {
            sender
                .request(
                    "client/registerCapability",
                    Some(json!({"registrations": []})),
                )
                .await
        })
    };

    // Client side: observe the request and answer it with a null result.
    let Message::Request(req) = client.receiver_mut().next().await.unwrap().unwrap() else {
        panic!("Expected Request");
    };
    assert_eq!(req.method, "client/registerCapability");
    let req_id = req.id.clone();
    client
        .send(Message::Response(Response::ok(req_id.clone(), json!(null))))
        .unwrap();

    // Server side: the response must be claimed by the ClientSender's response map.
    let resp_msg = server.receiver_mut().next().await.unwrap().unwrap();
    let routed = server.route(resp_msg);
    assert!(
        matches!(routed, IncomingMessage::ResponseRouted),
        "Expected ResponseRouted, got {routed:?}"
    );

    let response = tokio::time::timeout(Duration::from_secs(1), pending)
        .await
        .expect("request should complete")
        .unwrap()
        .unwrap();
    assert_eq!(response.id, Some(req_id));
}
#[tokio::test]
async fn route_prefers_response_map_over_outgoing_queue() {
    // Only the local half gets a Connection; the peer half is bound so it lives for the test.
    let (local_io, _peer_io) = tokio::io::duplex(4096);
    let mut server: Connection<_, ()> = Connection::new(local_io);
    let sender = server.client_sender();

    // Register id 1 in the outgoing queue AND issue a ClientSender request. The test relies
    // on that request also being assigned id 1, so both paths could claim the response.
    let mut outgoing_rx = server.request_queue.outgoing.register(1.into());
    let req_task = {
        let sender = sender.clone();
        tokio::spawn(async move { sender.request("test", None).await })
    };
    // Yield so the spawned task can run far enough to register its pending request.
    for _ in 0..2 {
        tokio::task::yield_now().await;
    }

    let routed = server.route(Message::Response(Response::ok(1, json!("via-response-map"))));
    assert!(matches!(routed, IncomingMessage::ResponseRouted));

    let resp = tokio::time::timeout(Duration::from_secs(1), req_task)
        .await
        .expect("should complete")
        .unwrap()
        .unwrap();
    assert_eq!(resp.result().cloned(), Some(json!("via-response-map")));
    // The outgoing-queue waiter must not have received the response.
    assert!(outgoing_rx.try_recv().is_err());
}
#[tokio::test]
async fn send_works_before_and_after_client_sender() {
    let (client_stream, server_stream) = tokio::io::duplex(4096);
    // `mut` is required because `client_sender()` is called below.
    let mut server: Connection<_, ()> = Connection::new(server_stream);
    let mut client: Connection<_, ()> = Connection::new(client_stream);

    // `send` must work before a ClientSender has ever been handed out...
    server
        .send(Message::Notification(Notification::new("test/before", None)))
        .unwrap();
    // ...and must keep working while one is alive.
    let _sender = server.client_sender();
    server
        .send(Message::Notification(Notification::new("test/after", None)))
        .unwrap();

    // Both messages go through the connection's own `send`, so they arrive in order.
    for expected in ["test/before", "test/after"] {
        let msg = client.receiver_mut().next().await.unwrap().unwrap();
        let Message::Notification(notif) = msg else {
            panic!("Expected notification");
        };
        assert_eq!(notif.method, expected);
    }
}
#[tokio::test]
async fn client_sender_can_be_called_multiple_times() {
let (stream, _) = tokio::io::duplex(4096);
let mut conn: Connection<_, ()> = Connection::new(stream);
let sender1 = conn.client_sender();
let sender2 = conn.client_sender(); drop(sender1);
drop(sender2);
}
#[tokio::test]
async fn send_works_after_client_sender() {
    let (client_io, server_io) = tokio::io::duplex(4096);
    let mut server: Connection<_, ()> = Connection::new(server_io);
    let mut client: Connection<_, ()> = Connection::new(client_io);

    // Keep a ClientSender alive for the rest of the test, then use the connection's own `send`.
    let _sender = server.client_sender();
    let ping = Message::Notification(Notification::new("test/ping", None));
    server.send(ping).unwrap();

    let received = client.receiver_mut().next().await.unwrap().unwrap();
    assert!(received.is_notification());
    if let Message::Notification(n) = received {
        assert_eq!(n.method, "test/ping");
    }
}
}