use {
crate::{
connect::lsp::ClientKind,
protocol::{
ipc::{
Endpoint,
Handshake,
IpcReadHalf,
IpcStream,
IpcWriteHalf,
framing,
},
jsonrpc::Message,
},
},
async_channel::{Receiver, Sender},
crossbeam_channel::bounded,
std::{
collections::HashMap,
io::{self, BufReader, Write},
net::{TcpListener, TcpStream, ToSocketAddrs},
thread,
},
};
/// Everything `make_writer` hands back to its caller, in order:
/// - the sender used to queue outgoing messages for the socket writer thread;
/// - the writer thread's join handle, which yields the first write error (if
///   any) once the thread stops;
/// - the receiving end of the zero-capacity channel the writer thread passes
///   each message to after writing it, so another thread can drop it.
type WriterHandle = (
  Sender<Message>,
  thread::JoinHandle<io::Result<()>>,
  crossbeam_channel::Receiver<Message>,
);
/// A bidirectional message link: `sender` queues outgoing [`Message`]s and
/// `receiver` yields incoming ones, whichever transport (stdio, TCP,
/// in-memory, IPC) backs them. Cloning yields further handles to the same
/// underlying channels.
#[derive(Clone, Debug)]
pub struct Connection {
  /// Outgoing messages. For the thread/task-backed transports, the writer
  /// stops once every clone of this sender has been dropped.
  pub sender: Sender<Message>,
  /// Incoming messages delivered by the transport's reader.
  pub receiver: Receiver<Message>,
}
impl Connection {
  /// Creates a connection over this process's stdin/stdout.
  ///
  /// The returned [`IoThreads`] owns the background threads; join it after
  /// dropping the connection.
  pub fn stdio() -> io::Result<(Connection, IoThreads)> {
    stdio_transport().map(|(sender, receiver, io_threads)| {
      (Connection { sender, receiver }, io_threads)
    })
  }

  /// Connects to the TCP peer at `addr` and runs the transport over it.
  pub fn connect<A: ToSocketAddrs>(
    addr: A,
  ) -> io::Result<(Connection, IoThreads)> {
    let socket = TcpStream::connect(addr)?;
    socket_transport(socket).map(|(sender, receiver, io_threads)| {
      (Connection { sender, receiver }, io_threads)
    })
  }

  /// Binds `addr`, accepts exactly one peer, and runs the transport over
  /// that socket. The listener is closed when this function returns.
  pub fn listen<A: ToSocketAddrs>(
    addr: A,
  ) -> io::Result<(Connection, IoThreads)> {
    let listener = TcpListener::bind(addr)?;
    let (socket, _peer_addr) = listener.accept()?;
    socket_transport(socket).map(|(sender, receiver, io_threads)| {
      (Connection { sender, receiver }, io_threads)
    })
  }

  /// Returns two in-process connections wired back to back: whatever one
  /// sends, the other receives.
  pub fn memory() -> (Connection, Connection) {
    let (left_tx, right_rx) = async_channel::unbounded();
    let (right_tx, left_rx) = async_channel::unbounded();
    let left = Connection {
      sender: left_tx,
      receiver: left_rx,
    };
    let right = Connection {
      sender: right_tx,
      receiver: right_rx,
    };
    (left, right)
  }

  /// Connects to `endpoint` as a CLI client with no handshake metadata.
  pub async fn ipc(
    endpoint: &Endpoint,
    version: impl Into<String>,
  ) -> io::Result<(Connection, IpcHandle)> {
    let no_metadata: HashMap<String, String> = HashMap::new();
    Self::ipc_as(endpoint, version, ClientKind::Cli, no_metadata).await
  }

  /// Connects to `endpoint`, exchanges handshakes as `client_kind`, and
  /// spawns the tasks that pump messages to and from the IPC stream.
  ///
  /// # Errors
  /// - any error from connecting or from sending/receiving the handshake;
  /// - `InvalidData` ("version mismatch: ...") if the peer is incompatible.
  pub async fn ipc_as(
    endpoint: &Endpoint,
    version: impl Into<String>,
    client_kind: ClientKind,
    metadata: HashMap<String, String>,
  ) -> io::Result<(Connection, IpcHandle)> {
    let stream = IpcStream::connect(endpoint).await?;
    let local = Handshake::with_metadata(version, client_kind, metadata);
    let (mut reader, mut writer) = stream.into_split();
    framing::send_handshake(&mut writer, &local).await?;
    let remote = framing::recv_handshake(&mut reader).await?;
    if !local.is_compatible(&remote) {
      let detail = format!(
        "version mismatch: local={}, remote={}",
        local.version, remote.version
      );
      return Err(io::Error::new(io::ErrorKind::InvalidData, detail));
    }
    // outbound: caller -> IPC stream; inbound: IPC stream -> caller.
    let (outbound_tx, outbound_rx) = async_channel::unbounded::<Message>();
    let (inbound_tx, inbound_rx) = async_channel::unbounded::<Message>();
    let handle = IpcHandle::spawn(reader, writer, outbound_rx, inbound_tx);
    let connection = Connection {
      sender: outbound_tx,
      receiver: inbound_rx,
    };
    Ok((connection, handle))
  }
}
/// Join handles for the background threads servicing a stdio or TCP
/// [`Connection`]; wait for them with [`IoThreads::join`].
#[derive(Debug)]
pub struct IoThreads {
  /// Parses incoming messages and forwards them to the connection's receiver.
  reader: thread::JoinHandle<io::Result<()>>,
  /// Writes queued outgoing messages; yields the first write error, if any.
  writer: thread::JoinHandle<io::Result<()>>,
  /// Drops messages after the writer has finished with them (presumably so
  /// deallocation does not delay writing — NOTE(review): confirm intent).
  dropper: thread::JoinHandle<()>,
}
impl IoThreads {
pub fn join(self) -> io::Result<()> {
match self.reader.join() {
| Ok(r) => r?,
| Err(err) => std::panic::panic_any(err),
}
match self.dropper.join() {
| Ok(_) => (),
| Err(err) => {
std::panic::panic_any(err);
},
}
match self.writer.join() {
| Ok(r) => r,
| Err(err) => {
std::panic::panic_any(err);
},
}
}
}
/// Handle to the two smol tasks that pump messages between a [`Connection`]
/// created by [`Connection::ipc_as`] and its IPC stream.
/// NOTE(review): smol cancels a `Task` when it is dropped, so dropping this
/// handle presumably stops forwarding — confirm against the smol version used.
pub struct IpcHandle {
  /// Reads framed messages from the IPC stream into the connection's receiver.
  reader_task: smol::Task<io::Result<()>>,
  /// Writes messages queued on the connection's sender to the IPC stream.
  writer_task: smol::Task<io::Result<()>>,
}
impl std::fmt::Debug for IpcHandle {
  /// Renders the opaque form `IpcHandle { .. }`; the task handles have
  /// nothing useful to show.
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.write_str("IpcHandle { .. }")
  }
}
impl IpcHandle {
fn spawn(
mut reader: IpcReadHalf,
mut writer: IpcWriteHalf,
to_ipc: Receiver<Message>,
from_ipc: Sender<Message>,
) -> Self {
let writer_task = smol::spawn(async move {
while let Ok(msg) = to_ipc.recv().await {
framing::send_message(&mut writer, &msg).await?;
}
Ok(())
});
let reader_task = smol::spawn(async move {
while let Some(msg) = framing::recv_message(&mut reader).await? {
if from_ipc.send(msg).await.is_err() {
break;
}
}
Ok(())
});
Self {
reader_task,
writer_task,
}
}
pub async fn join(self) -> io::Result<()> {
self.reader_task.await?;
self.writer_task.await
}
}
/// Builds a transport over this process's stdin/stdout.
///
/// Spawns three threads:
/// - `LaburnumWriter` writes each queued message to stdout. It holds the
///   stdout lock for its whole lifetime, then hands the message to the dropper.
/// - `MessageDropper` drops messages the writer has finished with.
/// - `LaburnumReader` parses messages from stdin and forwards them until EOF.
///   If the receiving side has been dropped, it logs the failure and returns
///   an error.
///
/// Returns the sender for outgoing messages, the receiver for incoming
/// messages, and the thread handles.
///
/// # Errors
/// Returns an error if any of the threads cannot be spawned.
fn stdio_transport()
-> io::Result<(Sender<Message>, Receiver<Message>, IoThreads)> {
  // Zero-capacity channel: each hand-off waits until the dropper takes it.
  let (drop_sender, drop_receiver) = bounded::<Message>(0);
  let (writer_sender, writer_receiver) = async_channel::unbounded::<Message>();
  let writer = thread::Builder::new()
    // Named consistently with the "LaburnumReader" thread below.
    .name("LaburnumWriter".to_owned())
    .spawn(move || {
      let stdout = io::stdout();
      let mut stdout = stdout.lock();
      while let Ok(msg) = writer_receiver.recv_blocking() {
        let result = msg.write(&mut stdout);
        // Hand the message off even if the write failed; a send error only
        // means the dropper is gone, which is harmless here.
        let _ = drop_sender.send(msg);
        result?;
      }
      Ok(())
    })
    .map_err(io::Error::other)?;
  let dropper = thread::Builder::new()
    .name("MessageDropper".to_owned())
    .spawn(move || drop_receiver.into_iter().for_each(drop))
    .map_err(io::Error::other)?;
  let (reader_sender, reader_receiver) = async_channel::unbounded::<Message>();
  let reader = thread::Builder::new()
    .name("LaburnumReader".to_owned())
    .spawn(move || {
      let stdin = io::stdin();
      let mut stdin = stdin.lock();
      while let Some(msg) = Message::read(&mut stdin)? {
        if let Err(e) = reader_sender.send_blocking(msg) {
          otel::error!(
            "reader_channel_send_failed",
            format!("Failed to send message to reader channel: {:?}", e)
          );
          return Err(io::Error::other(e));
        }
      }
      Ok(())
    })
    .map_err(io::Error::other)?;
  let threads = IoThreads {
    reader,
    writer,
    dropper,
  };
  Ok((writer_sender, reader_receiver, threads))
}
/// Builds a transport over an already-connected TCP socket.
///
/// The socket is cloned so the reader and the writer threads each own a
/// handle. A third thread drops messages once the writer is done with them.
///
/// # Errors
/// Fails if the socket cannot be cloned or any thread cannot be spawned.
fn socket_transport(
  stream: TcpStream,
) -> io::Result<(Sender<Message>, Receiver<Message>, IoThreads)> {
  let read_half = match stream.try_clone() {
    | Ok(clone) => clone,
    | Err(e) => {
      return Err(io::Error::other(format!(
        "Failed to clone stream: {}",
        e
      )));
    },
  };
  let (incoming, reader) = make_reader(read_half)?;
  let (outgoing, writer, retired) = make_writer(stream)?;
  let dropper = thread::Builder::new()
    .name("MessageDropper".to_owned())
    .spawn(move || {
      for msg in retired {
        drop(msg);
      }
    })
    .map_err(io::Error::other)?;
  Ok((
    outgoing,
    incoming,
    IoThreads {
      reader,
      writer,
      dropper,
    },
  ))
}
/// Spawns the `SocketReader` thread, which parses messages from `stream` and
/// forwards them on the returned receiver.
///
/// The thread stops cleanly at end of stream or once the receiver has been
/// dropped; a read/parse error ends it with that error.
fn make_reader(
  stream: TcpStream,
) -> io::Result<(Receiver<Message>, thread::JoinHandle<io::Result<()>>)> {
  let (inbound_tx, inbound_rx) = async_channel::unbounded::<Message>();
  let read_loop = move || -> io::Result<()> {
    let mut buffered = BufReader::new(stream);
    loop {
      let Some(msg) = Message::read(&mut buffered)? else {
        return Ok(());
      };
      if inbound_tx.send_blocking(msg).is_err() {
        return Ok(());
      }
    }
  };
  let reader = thread::Builder::new()
    .name("SocketReader".to_owned())
    .spawn(read_loop)
    .map_err(io::Error::other)?;
  Ok((inbound_rx, reader))
}
/// Spawns the `SocketWriter` thread, which writes every message queued on the
/// returned sender to `stream`.
///
/// After each write attempt, the message is passed through a zero-capacity
/// channel to whoever holds the returned crossbeam receiver. The thread stops
/// cleanly once all senders are dropped; the first write error ends it with
/// that error.
fn make_writer(mut stream: TcpStream) -> io::Result<WriterHandle> {
  let (outgoing_tx, outgoing_rx) = async_channel::unbounded::<Message>();
  let (retired_tx, retired_rx) = bounded::<Message>(0);
  let write_loop = move || -> io::Result<()> {
    loop {
      let Ok(msg) = outgoing_rx.recv_blocking() else {
        return Ok(());
      };
      let outcome = msg.write(&mut stream);
      let _ = retired_tx.send(msg);
      outcome?;
    }
  };
  let writer = thread::Builder::new()
    .name("SocketWriter".to_owned())
    .spawn(write_loop)
    .map_err(io::Error::other)?;
  Ok((outgoing_tx, writer, retired_rx))
}
/// Runs a long-lived stdio <-> IPC bridge as an IDE client.
///
/// Messages read from stdin are forwarded to `endpoint`, and messages
/// received from `endpoint` are written to stdout (flushed after each). When
/// the IPC connection drops, the bridge reconnects, retrying for up to 15
/// seconds, and keeps forwarding. The stdio threads and queues are shared
/// across reconnections.
///
/// Returns `Ok(())` once stdin has been closed. It also returns `Ok(())` when
/// (re)connecting fails for any reason, including an incompatible handshake,
/// so connection errors are not reported to the caller.
///
/// # Errors
/// Only if one of the stdio threads cannot be spawned.
pub async fn run_bridge(
  endpoint: &Endpoint,
  version: impl Into<String>,
  metadata: HashMap<String, String>,
) -> io::Result<()> {
  use std::time::Duration;
  const RECONNECT_TIMEOUT: Duration = Duration::from_secs(15);
  let version = version.into();
  // stdin -> IPC queue and IPC -> stdout queue; both outlive any single IPC
  // connection.
  let (stdio_tx, stdio_rx) = async_channel::unbounded::<Message>();
  let (to_stdout_tx, to_stdout_rx) = async_channel::unbounded::<Message>();
  // The stdin reader owns the only `stdio_tx`. When it exits (EOF or a read
  // error), `stdio_rx` becomes closed; the loop below relies on that to
  // detect that the editor has gone away.
  let _stdio_reader = thread::Builder::new()
    .name("BridgeStdioReader".to_owned())
    .spawn({
      let tx = stdio_tx;
      move || {
        let stdin = io::stdin();
        let mut stdin = stdin.lock();
        while let Ok(Some(msg)) = Message::read(&mut stdin) {
          if tx.send_blocking(msg).is_err() {
            break;
          }
        }
      }
    })
    .map_err(io::Error::other)?;
  let _stdio_writer = thread::Builder::new()
    .name("BridgeStdioWriter".to_owned())
    .spawn({
      let rx = to_stdout_rx;
      move || {
        let stdout = io::stdout();
        let mut stdout = stdout.lock();
        while let Ok(msg) = rx.recv_blocking() {
          if msg.write(&mut stdout).is_err() {
            break;
          }
          let _ = stdout.flush();
        }
      }
    })
    .map_err(io::Error::other)?;
  loop {
    let connect_result =
      connect_with_retry(endpoint, &version, &metadata, RECONNECT_TIMEOUT).await;
    let (mut ipc_reader, mut ipc_writer) = match connect_result {
      | Ok(streams) => streams,
      // Retry budget exhausted or incompatible peer: end the bridge quietly.
      | Err(_) => {
        return Ok(());
      },
    };
    // Each pump signals here when it stops; one pending wake-up is enough.
    let (disconnect_tx, disconnect_rx) = async_channel::bounded::<()>(1);
    let ipc_writer_task = {
      let rx = stdio_rx.clone();
      let disconnect_tx = disconnect_tx.clone();
      smol::spawn(async move {
        while let Ok(msg) = rx.recv().await {
          // On failure the message being sent is dropped, not re-queued for
          // the next connection.
          if framing::send_message(&mut ipc_writer, &msg).await.is_err() {
            break;
          }
        }
        // Signal on every exit, not only on write failure. If stdin closed,
        // `rx` ends here and the supervisor must wake up to notice
        // `stdio_rx.is_closed()`; otherwise it would wait until the IPC peer
        // happened to disconnect.
        let _ = disconnect_tx.try_send(());
      })
    };
    let ipc_reader_task = {
      let tx = to_stdout_tx.clone();
      let disconnect_tx = disconnect_tx;
      smol::spawn(async move {
        while let Ok(Some(msg)) = framing::recv_message(&mut ipc_reader).await {
          if tx.send(msg).await.is_err() {
            break;
          }
        }
        let _ = disconnect_tx.try_send(());
      })
    };
    let _ = disconnect_rx.recv().await;
    ipc_writer_task.cancel().await;
    ipc_reader_task.cancel().await;
    // stdin is gone, so there is nothing left to bridge; otherwise the IPC
    // side dropped and we reconnect.
    if stdio_rx.is_closed() {
      return Ok(());
    }
  }
}
/// Why a single connection attempt made by [`connect_with_retry`] failed.
enum AttemptError {
  /// A transient failure worth retrying. Carries the message reported as a
  /// `TimedOut` error if the retry budget runs out right after it.
  Retryable(&'static str),
  /// A permanent failure (an incompatible peer); retrying cannot help.
  Fatal(io::Error),
}

/// Makes one attempt to connect to `endpoint` and complete the handshake as
/// an IDE client carrying `version` and `metadata`.
async fn try_connect_once(
  endpoint: &Endpoint,
  version: &str,
  metadata: &HashMap<String, String>,
) -> Result<(IpcReadHalf, IpcWriteHalf), AttemptError> {
  let stream = IpcStream::connect(endpoint)
    .await
    .map_err(|_| AttemptError::Retryable("reconnect timeout"))?;
  let (mut ipc_reader, mut ipc_writer) = stream.into_split();
  let handshake =
    Handshake::with_metadata(version, ClientKind::Ide, metadata.clone());
  framing::send_handshake(&mut ipc_writer, &handshake)
    .await
    .map_err(|_| {
      AttemptError::Retryable("reconnect timeout during handshake send")
    })?;
  let peer = framing::recv_handshake(&mut ipc_reader)
    .await
    .map_err(|_| {
      AttemptError::Retryable("reconnect timeout during handshake recv")
    })?;
  if !handshake.is_compatible(&peer) {
    return Err(AttemptError::Fatal(io::Error::new(
      io::ErrorKind::InvalidData,
      format!(
        "version mismatch: local={}, remote={}",
        handshake.version, peer.version
      ),
    )));
  }
  Ok((ipc_reader, ipc_writer))
}

/// Connects to `endpoint` as an IDE client. It retries every 100 ms until the
/// handshake succeeds or `timeout` has elapsed since the first attempt.
///
/// Each attempt is also bounded by the remaining budget. This means a peer
/// that accepts the connection but never answers the handshake cannot stall
/// the caller past `timeout`.
///
/// # Errors
/// - `InvalidData` immediately if the peer's handshake is incompatible;
/// - `TimedOut` once the budget is exhausted. The message names the stage
///   (connect, handshake send, handshake recv) that failed last.
async fn connect_with_retry(
  endpoint: &Endpoint,
  version: &str,
  metadata: &HashMap<String, String>,
  timeout: std::time::Duration,
) -> io::Result<(IpcReadHalf, IpcWriteHalf)> {
  use std::time::Instant;
  const RECONNECT_INTERVAL: std::time::Duration =
    std::time::Duration::from_millis(100);
  let start = Instant::now();
  loop {
    let remaining = timeout.saturating_sub(start.elapsed());
    // Race the attempt against the remaining budget; the attempt is polled
    // first, so it wins a tie.
    let attempt = smol::future::or(
      async { Some(try_connect_once(endpoint, version, metadata).await) },
      async {
        smol::Timer::after(remaining).await;
        None
      },
    )
    .await;
    let reason = match attempt {
      | Some(Ok(halves)) => return Ok(halves),
      | Some(Err(AttemptError::Fatal(err))) => return Err(err),
      | Some(Err(AttemptError::Retryable(reason))) => reason,
      // The attempt was still in flight when the budget ran out.
      | None => {
        return Err(io::Error::new(
          io::ErrorKind::TimedOut,
          "reconnect timeout",
        ));
      },
    };
    if start.elapsed() >= timeout {
      return Err(io::Error::new(io::ErrorKind::TimedOut, reason));
    }
    smol::Timer::after(RECONNECT_INTERVAL).await;
  }
}
/// Handle to a single-connection stdio <-> IPC bridge created by
/// [`BridgeHandle::start`]. Unlike `run_bridge`, it never reconnects.
pub struct BridgeHandle {
  /// Forwards messages read from stdin to the IPC stream.
  ipc_writer_task: smol::Task<()>,
  /// Forwards messages read from the IPC stream to stdout.
  ipc_reader_task: smol::Task<()>,
  /// Thread reading stdin. It is never joined; the handle only keeps it.
  _stdio_reader: thread::JoinHandle<()>,
  /// Thread writing to stdout. Likewise never joined.
  _stdio_writer: thread::JoinHandle<()>,
}
impl BridgeHandle {
pub async fn start(
endpoint: &Endpoint,
version: impl Into<String>,
metadata: HashMap<String, String>,
) -> io::Result<Self> {
let stream = IpcStream::connect(endpoint).await?;
let (mut ipc_reader, mut ipc_writer) = stream.into_split();
let handshake = Handshake::with_metadata(version, ClientKind::Ide, metadata);
framing::send_handshake(&mut ipc_writer, &handshake).await?;
let peer = framing::recv_handshake(&mut ipc_reader).await?;
if !handshake.is_compatible(&peer) {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"version mismatch: local={}, remote={}",
handshake.version, peer.version
),
));
}
let (stdio_tx, stdio_rx) = async_channel::unbounded::<Message>();
let (to_stdout_tx, to_stdout_rx) = async_channel::unbounded::<Message>();
let stdio_reader = thread::Builder::new()
.name("BridgeStdioReader".to_owned())
.spawn({
let tx = stdio_tx;
move || {
let stdin = io::stdin();
let mut stdin = stdin.lock();
while let Ok(Some(msg)) = Message::read(&mut stdin) {
if tx.send_blocking(msg).is_err() {
break;
}
}
}
})
.map_err(io::Error::other)?;
let stdio_writer = thread::Builder::new()
.name("BridgeStdioWriter".to_owned())
.spawn({
let rx = to_stdout_rx;
move || {
let stdout = io::stdout();
let mut stdout = stdout.lock();
while let Ok(msg) = rx.recv_blocking() {
if msg.write(&mut stdout).is_err() {
break;
}
let _ = stdout.flush();
}
}
})
.map_err(io::Error::other)?;
let ipc_writer_task = {
let rx = stdio_rx;
smol::spawn(async move {
while let Ok(msg) = rx.recv().await {
if framing::send_message(&mut ipc_writer, &msg).await.is_err() {
break;
}
}
})
};
let ipc_reader_task = {
let tx = to_stdout_tx;
smol::spawn(async move {
while let Ok(Some(msg)) = framing::recv_message(&mut ipc_reader).await {
if tx.send(msg).await.is_err() {
break;
}
}
})
};
Ok(Self {
ipc_writer_task,
ipc_reader_task,
_stdio_reader: stdio_reader,
_stdio_writer: stdio_writer,
})
}
pub async fn wait(self) {
smol::future::race(self.ipc_writer_task, self.ipc_reader_task).await;
}
}
impl std::fmt::Debug for BridgeHandle {
  /// Renders the opaque form `BridgeHandle { .. }`; the task and thread
  /// handles have nothing useful to show.
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.write_str("BridgeHandle { .. }")
  }
}
/// Tests for the in-memory, TCP, IPC, and bridge transports.
#[cfg(test)]
mod tests {
  use {
    super::*,
    crate::protocol::{
      ipc::{IpcListener, MemoryTransport},
      jsonrpc::{Notification, Request, Response},
      lsp::LSPAny,
    },
    serde_json::json,
  };

  /// Builds a `test/method` request with the given numeric id.
  fn make_test_request(id: i64) -> Message {
    Message::Request(
      Request::build("test/method", id)
        .params(json!({"key": "value"}))
        .finish(),
    )
  }

  /// Builds a `test/notif` notification.
  fn make_test_notification() -> Message {
    Message::Notification(
      Notification::build("test/notif")
        .params(json!({"data": 123}))
        .finish(),
    )
  }

  /// Builds a successful response with the given numeric id.
  fn make_test_response(id: i64) -> Message {
    let result: LSPAny =
      serde_json::from_value(json!({"result": "ok"})).unwrap();
    Message::Response(Response::from_ok(id.into(), result))
  }

  // Request/response round trip over an in-memory IPC connection.
  #[test]
  fn test_connection_ipc_success() {
    smol::block_on(async {
      let transport = MemoryTransport::new();
      let endpoint = Endpoint::memory(transport.clone(), "ipc-test");
      let mut listener = IpcListener::bind(&endpoint).await.unwrap();
      let t = transport.clone();
      let server_handle = smol::spawn(async move {
        let stream = listener.accept().await.unwrap();
        let (mut reader, mut writer) = stream.into_split();
        let _client_hs = framing::recv_handshake(&mut reader).await.unwrap();
        let server_hs = Handshake::new("v1.0.0");
        framing::send_handshake(&mut writer, &server_hs)
          .await
          .unwrap();
        let msg = framing::recv_message(&mut reader).await.unwrap().unwrap();
        assert!(matches!(msg, Message::Request(_)));
        let response = make_test_response(1);
        framing::send_message(&mut writer, &response).await.unwrap();
      });
      let ep = Endpoint::memory(t, "ipc-test");
      let (conn, _handle) = Connection::ipc(&ep, "v1.0.0").await.unwrap();
      let request = make_test_request(1);
      conn.sender.send(request).await.unwrap();
      let response = conn.receiver.recv().await.unwrap();
      assert!(matches!(response, Message::Response(_)));
      drop(conn);
      server_handle.await;
    });
  }

  // An incompatible server version is rejected with InvalidData.
  #[test]
  fn test_connection_ipc_handshake_fail() {
    smol::block_on(async {
      let transport = MemoryTransport::new();
      let endpoint = Endpoint::memory(transport.clone(), "ipc-mismatch");
      let mut listener = IpcListener::bind(&endpoint).await.unwrap();
      let t = transport.clone();
      smol::spawn(async move {
        let stream = listener.accept().await.unwrap();
        let (mut reader, mut writer) = stream.into_split();
        let _client_hs = framing::recv_handshake(&mut reader).await.unwrap();
        let server_hs = Handshake::new("v2.0.0");
        framing::send_handshake(&mut writer, &server_hs)
          .await
          .unwrap();
      })
      .detach();
      let ep = Endpoint::memory(t, "ipc-mismatch");
      let result = Connection::ipc(&ep, "v1.0.0").await;
      assert!(result.is_err());
      let err = result.unwrap_err();
      assert_eq!(err.kind(), io::ErrorKind::InvalidData);
      assert!(err.to_string().contains("version mismatch"));
    });
  }

  // Connecting to an endpoint nobody listens on is refused.
  #[test]
  fn test_connection_ipc_connect_fail() {
    smol::block_on(async {
      let transport = MemoryTransport::new();
      let ep = Endpoint::memory(transport, "nonexistent");
      let result = Connection::ipc(&ep, "v1.0.0").await;
      assert!(result.is_err());
      assert_eq!(result.unwrap_err().kind(), io::ErrorKind::ConnectionRefused);
    });
  }

  // Several sequential round trips over one IPC connection.
  #[test]
  fn test_connection_ipc_bidirectional() {
    smol::block_on(async {
      let transport = MemoryTransport::new();
      let endpoint = Endpoint::memory(transport.clone(), "ipc-bidir");
      let mut listener = IpcListener::bind(&endpoint).await.unwrap();
      let t = transport.clone();
      let server_handle = smol::spawn(async move {
        let stream = listener.accept().await.unwrap();
        let (mut reader, mut writer) = stream.into_split();
        let _client_hs = framing::recv_handshake(&mut reader).await.unwrap();
        let server_hs = Handshake::new("test-version");
        framing::send_handshake(&mut writer, &server_hs)
          .await
          .unwrap();
        for i in 0..5 {
          let msg = framing::recv_message(&mut reader).await.unwrap().unwrap();
          assert!(matches!(msg, Message::Request(_)));
          let response = make_test_response(i as i64);
          framing::send_message(&mut writer, &response).await.unwrap();
        }
      });
      let ep = Endpoint::memory(t, "ipc-bidir");
      let (conn, _handle) = Connection::ipc(&ep, "test-version").await.unwrap();
      for i in 0..5 {
        let request = make_test_request(i);
        conn.sender.send(request).await.unwrap();
        let response = conn.receiver.recv().await.unwrap();
        assert!(matches!(response, Message::Response(_)));
      }
      drop(conn);
      server_handle.await;
    });
  }

  // `IpcHandle::join` completes cleanly after both sides hang up.
  #[test]
  fn test_ipc_handle_join() {
    smol::block_on(async {
      let transport = MemoryTransport::new();
      let endpoint = Endpoint::memory(transport.clone(), "ipc-join");
      let mut listener = IpcListener::bind(&endpoint).await.unwrap();
      let t = transport.clone();
      let server_handle = smol::spawn(async move {
        let stream = listener.accept().await.unwrap();
        let (mut reader, mut writer) = stream.into_split();
        let _client_hs = framing::recv_handshake(&mut reader).await.unwrap();
        let server_hs = Handshake::new("test-v1");
        framing::send_handshake(&mut writer, &server_hs)
          .await
          .unwrap();
        let _msg = framing::recv_message(&mut reader).await.unwrap().unwrap();
        let response = make_test_response(1);
        framing::send_message(&mut writer, &response).await.unwrap();
        drop(writer);
      });
      let ep = Endpoint::memory(t, "ipc-join");
      let (conn, handle) = Connection::ipc(&ep, "test-v1").await.unwrap();
      conn.sender.send(make_test_request(1)).await.unwrap();
      let _ = conn.receiver.recv().await.unwrap();
      drop(conn.sender);
      server_handle.await;
      let result = handle.join().await;
      assert!(result.is_ok());
    });
  }

  // Round trip over a real loopback TCP socket.
  #[test]
  fn test_connection_tcp_listen_and_connect() {
    use std::{io::BufReader, net::TcpListener, thread};
    let listener =
      TcpListener::bind("127.0.0.1:0").expect("bind to random port");
    let addr = listener.local_addr().expect("get local addr");
    let server_thread = thread::spawn(move || {
      let (mut stream, _) = listener.accept().expect("accept connection");
      let msg = make_test_request(1);
      msg.write(&mut stream).expect("write request");
      let mut reader = BufReader::new(&stream);
      let response = Message::read(&mut reader)
        .expect("read response")
        .expect("response not none");
      assert!(matches!(response, Message::Response(_)));
    });
    let (conn, io_threads) =
      Connection::connect(addr).expect("connect to server");
    let received = conn.receiver.recv_blocking().expect("receive request");
    assert!(matches!(received, Message::Request(_)));
    let response = make_test_response(1);
    conn.sender.send_blocking(response).expect("send response");
    drop(conn);
    server_thread.join().expect("server thread join");
    io_threads.join().expect("io threads join");
  }

  // Connecting to a closed port fails with a connection-level error.
  #[test]
  fn test_connection_tcp_connect_refused() {
    let result = Connection::connect("127.0.0.1:1");
    assert!(result.is_err());
    let err = result.unwrap_err();
    assert!(
      err.kind() == io::ErrorKind::ConnectionRefused
        || err.kind() == io::ErrorKind::AddrNotAvailable
        || err.kind() == io::ErrorKind::PermissionDenied,
      "expected connection error, got: {:?}",
      err.kind()
    );
  }

  // Pipelined requests over TCP are answered with matching ids, in order.
  #[test]
  fn test_connection_tcp_multiple_messages() {
    use std::{io::BufReader, net::TcpListener, thread};
    let listener =
      TcpListener::bind("127.0.0.1:0").expect("bind to random port");
    let addr = listener.local_addr().expect("get local addr");
    let server_thread = thread::spawn(move || {
      let (mut stream, _) = listener.accept().expect("accept connection");
      for i in 1..=5 {
        let msg = make_test_request(i);
        msg.write(&mut stream).expect("write request");
      }
      let mut reader = BufReader::new(&stream);
      for i in 1..=5i64 {
        let response = Message::read(&mut reader)
          .expect("read response")
          .expect("response not none");
        if let Message::Response(r) = response {
          assert_eq!(r.id(), &crate::protocol::jsonrpc::Id::from(i));
        } else {
          panic!("expected response, got: {:?}", response);
        }
      }
    });
    let (conn, io_threads) =
      Connection::connect(addr).expect("connect to server");
    for _ in 0..5 {
      let received = conn.receiver.recv_blocking().expect("receive request");
      if let Message::Request(r) = received {
        if let crate::protocol::jsonrpc::Id::Number(n) = r.id() {
          let response = make_test_response(*n);
          conn.sender.send_blocking(response).expect("send response");
        } else {
          panic!("expected numeric id");
        }
      } else {
        panic!("expected request");
      }
    }
    drop(conn);
    server_thread.join().expect("server thread join");
    io_threads.join().expect("io threads join");
  }

  // Messages cross an in-memory pair unchanged in both directions.
  #[test]
  fn test_connection_memory_pair() {
    let (server_conn, client_conn) = Connection::memory();
    let request = make_test_request(1);
    client_conn
      .sender
      .send_blocking(request.clone())
      .expect("send request");
    let received = server_conn
      .receiver
      .recv_blocking()
      .expect("receive request");
    assert_eq!(received, request);
    let response = make_test_response(1);
    server_conn
      .sender
      .send_blocking(response.clone())
      .expect("send response");
    let received = client_conn
      .receiver
      .recv_blocking()
      .expect("receive response");
    assert_eq!(received, response);
  }

  // Repeated request/response exchanges over an in-memory pair.
  #[test]
  fn test_connection_memory_bidirectional() {
    let (server_conn, client_conn) = Connection::memory();
    for i in 1..=10 {
      let request = make_test_request(i);
      client_conn
        .sender
        .send_blocking(request)
        .expect("send request");
      let received = server_conn
        .receiver
        .recv_blocking()
        .expect("receive request");
      assert!(matches!(received, Message::Request(_)));
      let response = make_test_response(i);
      server_conn
        .sender
        .send_blocking(response)
        .expect("send response");
      let received = client_conn
        .receiver
        .recv_blocking()
        .expect("receive response");
      assert!(matches!(received, Message::Response(_)));
    }
  }

  // The bridge handshakes successfully and `wait` returns once the server
  // hangs up.
  #[test]
  fn test_bridge_handshake_success() {
    smol::block_on(async {
      let transport = MemoryTransport::new();
      let endpoint = Endpoint::memory(transport.clone(), "bridge-success");
      let mut listener = IpcListener::bind(&endpoint).await.unwrap();
      let t = transport.clone();
      let server_handle = smol::spawn(async move {
        let stream = listener.accept().await.unwrap();
        let (mut reader, mut writer) = stream.into_split();
        let client_hs = framing::recv_handshake(&mut reader).await.unwrap();
        assert_eq!(client_hs.version, "v1.0.0");
        let server_hs = Handshake::new("v1.0.0");
        framing::send_handshake(&mut writer, &server_hs)
          .await
          .unwrap();
        drop(writer);
      });
      let ep = Endpoint::memory(t, "bridge-success");
      let result = BridgeHandle::start(&ep, "v1.0.0", HashMap::new()).await;
      assert!(result.is_ok(), "bridge should connect successfully");
      let bridge = result.unwrap();
      server_handle.await;
      bridge.wait().await;
    });
  }

  // The bridge accepts messages pushed by the server and shuts down when the
  // server closes.
  #[test]
  fn test_bridge_ipc_message_forwarding() {
    smol::block_on(async {
      let transport = MemoryTransport::new();
      let endpoint = Endpoint::memory(transport.clone(), "bridge-forward");
      let mut listener = IpcListener::bind(&endpoint).await.unwrap();
      let (messages_sent_tx, messages_sent_rx) =
        async_channel::bounded::<()>(1);
      let t = transport.clone();
      let server_handle = smol::spawn(async move {
        let stream = listener.accept().await.unwrap();
        let (mut reader, mut writer) = stream.into_split();
        let client_hs = framing::recv_handshake(&mut reader).await.unwrap();
        let server_hs = Handshake::new(client_hs.version);
        framing::send_handshake(&mut writer, &server_hs)
          .await
          .unwrap();
        let notification = make_test_notification();
        framing::send_message(&mut writer, &notification)
          .await
          .unwrap();
        let request = make_test_request(42);
        framing::send_message(&mut writer, &request).await.unwrap();
        let _ = messages_sent_tx.send(()).await;
        drop(writer);
      });
      let ep = Endpoint::memory(t, "bridge-forward");
      let bridge = BridgeHandle::start(&ep, "test-version", HashMap::new())
        .await
        .expect("bridge should connect");
      messages_sent_rx.recv().await.unwrap();
      server_handle.await;
      bridge.wait().await;
    });
  }

  // An incompatible server version makes `BridgeHandle::start` fail.
  #[test]
  fn test_bridge_handshake_version_mismatch() {
    smol::block_on(async {
      let transport = MemoryTransport::new();
      let endpoint = Endpoint::memory(transport.clone(), "bridge-mismatch");
      let mut listener = IpcListener::bind(&endpoint).await.unwrap();
      smol::spawn(async move {
        let stream = listener.accept().await.unwrap();
        let (mut reader, mut writer) = stream.into_split();
        let _client_hs = framing::recv_handshake(&mut reader).await.unwrap();
        let server_hs = Handshake::new("different-version");
        framing::send_handshake(&mut writer, &server_hs)
          .await
          .unwrap();
      })
      .detach();
      let ep = Endpoint::memory(transport, "bridge-mismatch");
      let result = BridgeHandle::start(&ep, "my-version", HashMap::new()).await;
      assert!(result.is_err());
      let err = result.unwrap_err();
      assert_eq!(err.kind(), io::ErrorKind::InvalidData);
      assert!(err.to_string().contains("version mismatch"));
    });
  }

  // Starting a bridge to an endpoint nobody listens on is refused.
  #[test]
  fn test_bridge_connect_fail() {
    smol::block_on(async {
      let transport = MemoryTransport::new();
      let ep = Endpoint::memory(transport, "nonexistent");
      let result = BridgeHandle::start(&ep, "v1.0.0", HashMap::new()).await;
      assert!(result.is_err());
      assert_eq!(result.unwrap_err().kind(), io::ErrorKind::ConnectionRefused);
    });
  }

  // `wait` returns once the server drops both halves right after the
  // handshake.
  #[test]
  fn test_bridge_server_disconnect() {
    smol::block_on(async {
      let transport = MemoryTransport::new();
      let endpoint = Endpoint::memory(transport.clone(), "bridge-disconnect");
      let mut listener = IpcListener::bind(&endpoint).await.unwrap();
      let t = transport.clone();
      let server_handle = smol::spawn(async move {
        let stream = listener.accept().await.unwrap();
        let (mut reader, mut writer) = stream.into_split();
        let client_hs = framing::recv_handshake(&mut reader).await.unwrap();
        let server_hs = Handshake::new(client_hs.version);
        framing::send_handshake(&mut writer, &server_hs)
          .await
          .unwrap();
        drop(writer);
        drop(reader);
      });
      let ep = Endpoint::memory(t, "bridge-disconnect");
      let bridge = BridgeHandle::start(&ep, "v1.0.0", HashMap::new())
        .await
        .expect("bridge should connect");
      server_handle.await;
      bridge.wait().await;
    });
  }
}