#![cfg(test)]
use crate::{
Call, Result,
connection::{
Connection,
socket::{ReadHalf, Socket, WriteHalf},
},
test_utils::mock_socket::{MockSocket, MockWriteHalf},
};
use alloc::vec::Vec;
use futures_util::{pin_mut, stream::StreamExt};
use rustix::{fd::AsFd, io::write};
use serde::{Deserialize, Serialize};
use std::os::unix::net::UnixStream;
/// Two chained replies with one batch of two server FDs: the test expects both FDs on the
/// first reply and none on the second.
#[tokio::test]
async fn chain_replies_with_fds() {
    // Named (not `_`) bindings keep the peer ends open until the test ends.
    let (r1, _w1) = UnixStream::pair().unwrap();
    let (r2, _w2) = UnixStream::pair().unwrap();
    let reply1 = r#"{"parameters":{"id":1,"name":"Alice"}}"#;
    let reply2 = r#"{"parameters":{"id":2,"name":"Bob"}}"#;
    // One inner vec holding both FDs; the assertions below expect them all on reply 1.
    let fds = vec![vec![r1.into(), r2.into()]];
    let socket = MockSocket::new(&[reply1, reply2], fds);
    let mut conn = Connection::new(socket);
    let call1 = Call::new(GetUser { id: 1 });
    let call2 = Call::new(GetUser { id: 2 });
    let replies = conn
        .chain_call::<GetUser>(&call1, vec![])
        .unwrap()
        .append(&call2, vec![])
        .unwrap()
        .send::<User, ApiError>()
        .await
        .unwrap();
    pin_mut!(replies);

    let (reply1, fds1) = replies.next().await.unwrap().unwrap();
    let reply1 = reply1.unwrap();
    assert_eq!(reply1.parameters().unwrap().id, 1);
    assert_eq!(fds1.len(), 2);

    let (reply2, fds2) = replies.next().await.unwrap().unwrap();
    let reply2 = reply2.unwrap();
    assert_eq!(reply2.parameters().unwrap().id, 2);
    assert!(fds2.is_empty(), "second reply should carry no FDs");
}
/// Two chained replies where the server attaches no FDs: every reply must arrive with an
/// empty FD list.
#[tokio::test]
async fn chain_replies_with_no_fds() {
    let socket = MockSocket::new(
        &[
            r#"{"parameters":{"id":1,"name":"Alice"}}"#,
            r#"{"parameters":{"id":2,"name":"Bob"}}"#,
        ],
        vec![vec![]],
    );
    let mut conn = Connection::new(socket);
    let first = Call::new(GetUser { id: 1 });
    let second = Call::new(GetUser { id: 2 });
    let stream = conn
        .chain_call::<GetUser>(&first, vec![])
        .unwrap()
        .append(&second, vec![])
        .unwrap()
        .send::<User, ApiError>()
        .await
        .unwrap();
    pin_mut!(stream);

    // Replies come back in call order, so ids are 1 then 2.
    for expected_id in 1u32..=2 {
        let (reply, received_fds) = stream.next().await.unwrap().unwrap();
        let user_reply = reply.unwrap();
        assert_eq!(user_reply.parameters().unwrap().id, expected_id);
        assert!(received_fds.is_empty());
    }
}
#[tokio::test]
async fn chain_send_with_fds() {
let (r1, w1) = UnixStream::pair().unwrap();
let (r2, w2) = UnixStream::pair().unwrap();
write(w1.as_fd(), b"data1").unwrap();
write(w2.as_fd(), b"data2").unwrap();
let reply1 = r#"{"parameters":{"id":1,"name":"Alice"}}"#;
let reply2 = r#"{"parameters":{"id":2,"name":"Bob"}}"#;
let socket = MockSocket::new(&[reply1, reply2], vec![vec![]]);
let (read_half, write_half) = socket.split();
let tracking_write = TrackingWriteHalf { mock: write_half };
let mut conn = Connection::new(TrackingSocket {
read: read_half,
write: tracking_write,
});
let call1 = Call::new(GetUser { id: 1 });
let call2 = Call::new(GetUser { id: 2 });
let chain = conn
.chain_call::<GetUser>(&call1, vec![r1.into()])
.unwrap()
.append(&call2, vec![r2.into()])
.unwrap();
let replies = chain.send::<User, ApiError>().await.unwrap();
let reply_results: Vec<_> = {
pin_mut!(replies);
replies.collect().await
};
let fds_written = conn.write_mut().socket.mock.fds_written();
assert_eq!(fds_written.len(), 2, "Should have written FDs twice");
assert_eq!(fds_written[0].len(), 1, "First call should send 1 FD");
assert_eq!(fds_written[1].len(), 1, "Second call should send 1 FD");
let mut buf1 = [0u8; 5];
rustix::io::read(fds_written[0][0].as_fd(), &mut buf1).unwrap();
assert_eq!(&buf1, b"data1");
let mut buf2 = [0u8; 5];
rustix::io::read(fds_written[1][0].as_fd(), &mut buf2).unwrap();
assert_eq!(&buf2, b"data2");
assert_eq!(reply_results.len(), 2);
let (reply1, _fds) = reply_results[0].as_ref().unwrap();
let reply1 = reply1.as_ref().unwrap();
assert_eq!(reply1.parameters().unwrap().id, 1);
let (reply2, _fds) = reply_results[1].as_ref().unwrap();
let reply2 = reply2.as_ref().unwrap();
assert_eq!(reply2.parameters().unwrap().id, 2);
}
#[tokio::test]
async fn chain_oneway_call_with_fds() {
let reply1 = r#"{"parameters":{"id":1,"name":"Alice"}}"#;
let socket = MockSocket::new(&[reply1], vec![vec![]]);
let mut conn = Connection::new(socket);
let call1 = Call::new(GetUser { id: 1 });
let oneway_call = Call::new(GetUser { id: 2 }).set_oneway(true);
let replies = conn
.chain_call::<GetUser>(&call1, vec![])
.unwrap()
.append(&oneway_call, vec![])
.unwrap()
.send::<User, ApiError>()
.await
.unwrap();
pin_mut!(replies);
let (reply1, _fds) = replies.next().await.unwrap().unwrap();
let reply1 = reply1.unwrap();
assert_eq!(reply1.parameters().unwrap().id, 1);
assert!(replies.next().await.is_none());
}
/// An error reply keeps the FD the server sent with it and decodes into the matching
/// `TestError` variant and payload.
#[tokio::test]
async fn chain_error_reply_with_fds() {
    use crate::ReplyError;

    #[derive(Debug, ReplyError)]
    #[zlink(interface = "org.example")]
    enum TestError {
        NotFound { code: i32 },
    }

    // Named (not `_`) binding keeps the peer end open until the test ends.
    let (r_fd, _w_fd) = UnixStream::pair().unwrap();
    let error_reply = r#"{"error":"org.example.NotFound","parameters":{"code":404}}"#;
    let fds = vec![vec![r_fd.into()]];
    let socket = MockSocket::new(&[error_reply], fds);
    let mut conn = Connection::new(socket);
    let call1 = Call::new(GetUser { id: 1 });
    let replies = conn
        .chain_call::<GetUser>(&call1, vec![])
        .unwrap()
        .send::<User, TestError>()
        .await
        .unwrap();
    pin_mut!(replies);

    let (reply, fds) = replies.next().await.unwrap().unwrap();
    // Pin the decoded variant and payload, not merely that *some* error came back.
    let err = reply.as_ref().err();
    assert!(
        matches!(err, Some(TestError::NotFound { code: 404 })),
        "expected NotFound {{ code: 404 }}, got {err:?}"
    );
    assert_eq!(fds.len(), 1);
}
/// Three chained replies with one batch of three server FDs: the test expects every FD on the
/// first reply and none on the later ones.
#[tokio::test]
async fn chain_receive_fds_from_server() {
    // Named (not `_`) bindings keep the peer ends open until the test ends.
    let (r1, _w1) = UnixStream::pair().unwrap();
    let (r2, _w2) = UnixStream::pair().unwrap();
    let (r3, _w3) = UnixStream::pair().unwrap();
    let reply1 = r#"{"parameters":{"id":1,"name":"Alice"}}"#;
    let reply2 = r#"{"parameters":{"id":2,"name":"Bob"}}"#;
    let reply3 = r#"{"parameters":{"id":3,"name":"Charlie"}}"#;
    let fds = vec![vec![r1.into(), r2.into(), r3.into()]];
    let socket = MockSocket::new(&[reply1, reply2, reply3], fds);
    let mut conn = Connection::new(socket);
    let call1 = Call::new(GetUser { id: 1 });
    let call2 = Call::new(GetUser { id: 2 });
    let call3 = Call::new(GetUser { id: 3 });
    let replies = conn
        .chain_call::<GetUser>(&call1, vec![])
        .unwrap()
        .append(&call2, vec![])
        .unwrap()
        .append(&call3, vec![])
        .unwrap()
        .send::<User, ApiError>()
        .await
        .unwrap();
    pin_mut!(replies);
    let results: Vec<_> = replies.collect().await;
    assert_eq!(results.len(), 3);

    let (reply, fds) = results[0].as_ref().unwrap();
    let user = reply.as_ref().unwrap();
    assert_eq!(user.parameters().unwrap().id, 1);
    assert_eq!(fds.len(), 3);

    // Every later reply carries no FDs; ids continue from 2 in call order.
    for (expected_id, result) in (2u32..).zip(&results[1..]) {
        let (reply, fds) = result.as_ref().unwrap();
        let user = reply.as_ref().unwrap();
        assert_eq!(user.parameters().unwrap().id, expected_id);
        assert!(fds.is_empty(), "reply {expected_id} should carry no FDs");
    }
}
#[tokio::test]
async fn chain_fds_sent_only_with_their_message() {
let (r_fd, w_fd) = UnixStream::pair().unwrap();
write(w_fd.as_fd(), b"test").unwrap();
let reply1 = r#"{"parameters":{"id":1,"name":"Alice"}}"#;
let reply2 = r#"{"parameters":{"id":2,"name":"Bob"}}"#;
let reply3 = r#"{"parameters":{"id":3,"name":"Charlie"}}"#;
let socket = MockSocket::new(&[reply1, reply2, reply3], vec![vec![]]);
let (read_half, write_half) = socket.split();
let tracking_write = WriteOperationTracker {
mock: write_half,
operations: Vec::new(),
};
let mut conn = Connection::new(TrackingSocket {
read: read_half,
write: tracking_write,
});
let call1 = Call::new(GetUser { id: 1 });
let call2 = Call::new(GetUser { id: 2 });
let call3 = Call::new(GetUser { id: 3 });
let chain = conn
.chain_call::<GetUser>(&call1, vec![])
.unwrap()
.append(&call2, vec![r_fd.into()])
.unwrap()
.append(&call3, vec![])
.unwrap();
chain.connection.write_mut().flush().await.unwrap();
let ops = &mut conn.write_mut().socket.operations;
assert_eq!(ops.len(), 3, "Expected 3 write operations");
assert_eq!(ops[0].fd_count, 0, "call1 should have no FDs");
let call1_str = core::str::from_utf8(&ops[0].data[..ops[0].data.len() - 1]).unwrap();
assert!(
call1_str.contains("\"id\":1"),
"call1 write should contain id:1"
);
assert!(
ops[0].data.ends_with(b"\0"),
"call1 should be null-terminated"
);
assert_eq!(ops[1].fd_count, 1, "call2 should have 1 FD");
let call2_str = core::str::from_utf8(&ops[1].data[..ops[1].data.len() - 1]).unwrap();
assert!(
call2_str.contains("\"id\":2"),
"call2 write should contain id:2"
);
assert!(
ops[1].data.ends_with(b"\0"),
"call2 should be null-terminated"
);
assert!(
!call2_str.contains("\"id\":1"),
"call2 write should NOT contain id:1 from call1"
);
assert!(
!call2_str.contains("\"id\":3"),
"call2 write should NOT contain id:3 from call3"
);
assert_eq!(ops[2].fd_count, 0, "call3 should have no FDs");
let call3_str = core::str::from_utf8(&ops[2].data[..ops[2].data.len() - 1]).unwrap();
assert!(
call3_str.contains("\"id\":3"),
"call3 write should contain id:3"
);
assert!(
ops[2].data.ends_with(b"\0"),
"call3 should be null-terminated"
);
}
/// Call parameters shared by every test in this file; carries only a numeric `id`.
#[derive(Debug, Deserialize, Serialize)]
struct GetUser {
    id: u32,
}
/// Reply payload decoded from the mocked `{"id":..,"name":..}` parameters.
#[derive(Debug, PartialEq, Deserialize, Serialize)]
struct User {
    id: u32,
    name: String,
}
/// Error parameter type passed to `send` by the tests that expect successful replies.
#[derive(Debug, PartialEq, Deserialize, Serialize)]
struct ApiError {
    code: i32,
}
/// A [`Socket`] assembled from an arbitrary read half and write half, letting tests pair the
/// mock reader with an instrumented writer.
#[derive(Debug)]
struct TrackingSocket<R, W> {
    read: R,
    write: W,
}

impl<R, W> Socket for TrackingSocket<R, W>
where
    R: ReadHalf,
    W: WriteHalf,
{
    type ReadHalf = R;
    type WriteHalf = W;

    /// Hands back the two halves exactly as they were stored.
    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
        let Self { read, write } = self;
        (read, write)
    }
}
/// A [`WriteHalf`] that forwards every write unchanged to the wrapped [`MockWriteHalf`].
/// The mock's `fds_written` is what the tests inspect afterwards.
#[derive(Debug)]
struct TrackingWriteHalf {
    mock: MockWriteHalf,
}

impl WriteHalf for TrackingWriteHalf {
    async fn write(&mut self, buf: &[u8], fds: &[impl AsFd]) -> Result<()> {
        let Self { mock } = self;
        mock.write(buf, fds).await
    }
}
/// Snapshot of a single `write` call made through [`WriteOperationTracker`].
#[derive(Debug)]
struct WriteOperation {
    // Copy of the bytes passed to `write`.
    data: Vec<u8>,
    // Number of FDs passed alongside those bytes.
    fd_count: usize,
}
/// A [`WriteHalf`] that records every write (bytes and FD count) in order, then delegates to
/// the wrapped [`MockWriteHalf`].
#[derive(Debug)]
struct WriteOperationTracker {
    mock: MockWriteHalf,
    operations: Vec<WriteOperation>,
}

impl WriteHalf for WriteOperationTracker {
    async fn write(&mut self, buf: &[u8], fds: &[impl AsFd]) -> Result<()> {
        // Record before delegating so the operation is logged even if the mock write fails.
        let recorded = WriteOperation {
            data: Vec::from(buf),
            fd_count: fds.len(),
        };
        self.operations.push(recorded);
        self.mock.write(buf, fds).await
    }
}