use std::io::ErrorKind;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
use std::pin::Pin;
use std::process::id;
use std::time::Duration;
use async_std::future::{timeout, TimeoutError};
use async_std::os::unix::net::UnixStream;
use futures::future::Future;
use futures::future::{ready, select, try_join, try_join_all, Either};
use futures::pin_mut;
use futures::prelude::*;
use futures::task::{noop_waker_ref, Context, Poll};
use async_rustbus::conn::DBusAddr;
use async_rustbus::prime_future;
use async_rustbus::rustbus_core;
use async_rustbus::rustbus_core::wire::UnixFd;
use async_rustbus::CallAction;
use async_rustbus::{MatchRule, RpcConn, EMPTY_MATCH};
use rustbus_core::message_builder::{MarshalledMessage, MessageBuilder, MessageType};
use rustbus_core::wire::unmarshal::traits::Unmarshal;
use rand::prelude::*;
use rand::rngs::{OsRng, SmallRng};
// Well-known bus name that `get_name` requests and then looks up on the session bus.
const DBUS_NAME: &str = "io.test.dbus";
// Default timeout applied to the individual send/receive steps in `get_name`.
const DEFAULT_TO: Duration = Duration::from_secs(1);
/// Error type returned by the tests in this file.
///
/// It lets `?` be used on I/O results and on `timeout(..)` results alike, and
/// carries an unexpected message back to the test harness for inspection.
#[derive(Debug)]
// NOTE(review): presumably `MarshalledMessage` is what triggers this lint; boxing it
// is not worthwhile for a test-only error type — confirm.
#[allow(clippy::large_enum_variant)]
enum TestingError {
/// An I/O error reported by the connection or by a socket.
Io(std::io::Error),
/// A message that had the wrong type or whose body could not be parsed as expected.
Bad(MarshalledMessage),
/// A step did not finish within its timeout.
Timeout(TimeoutError),
}
impl From<std::io::Error> for TestingError {
fn from(err: std::io::Error) -> Self {
TestingError::Io(err)
}
}
impl From<TimeoutError> for TestingError {
fn from(err: TimeoutError) -> Self {
TestingError::Timeout(err)
}
}
/// Smoke test: a session-bus connection can be opened when `false` is passed to
/// `session_conn` (per the test name, the variant without FD support).
#[async_std::test]
async fn session_wo_fd() -> std::io::Result<()> {
    // Only the connection attempt matters; the connection is discarded right away.
    RpcConn::session_conn(false).await.map(drop)
}
/// Smoke test: a session-bus connection can be opened when `true` is passed to
/// `session_conn` (per the test name, the variant with FD support).
#[async_std::test]
async fn session_w_fd() -> std::io::Result<()> {
    // Only the connection attempt matters; the connection is discarded right away.
    RpcConn::session_conn(true).await.map(drop)
}
/// Smoke test: a system-bus connection can be opened when `false` is passed to
/// `system_conn` (per the test name, the variant without FD support).
#[async_std::test]
async fn system_wo_fd() -> std::io::Result<()> {
    // Only the connection attempt matters; the connection is discarded right away.
    RpcConn::system_conn(false).await.map(drop)
}
/// Smoke test: a system-bus connection can be opened when `true` is passed to
/// `system_conn` (per the test name, the variant with FD support).
#[async_std::test]
async fn system_w_fd() -> std::io::Result<()> {
    // Only the connection attempt matters; the connection is discarded right away.
    RpcConn::system_conn(true).await.map(drop)
}
/// Connects to a bus over TCP at `localhost:29011` with `false` as the second argument.
///
/// Ignored by default: it needs something listening on that address.
#[async_std::test]
#[ignore]
async fn tcp_wo_fd() -> std::io::Result<()> {
    let tcp_addr = DBusAddr::tcp_addr("localhost:29011");
    // Only the connection attempt matters; the connection is discarded right away.
    RpcConn::connect_to_addr(&tcp_addr, false).await.map(drop)
}
/// Asserts that connecting over TCP with `true` as the second argument is rejected.
///
/// Ignored by default, like the other TCP test.
#[async_std::test]
#[ignore]
async fn tcp_w_fd() -> std::io::Result<()> {
    let tcp_addr = DBusAddr::tcp_addr("localhost:29011");
    let attempt = RpcConn::connect_to_addr(&tcp_addr, true).await;
    assert!(attempt.is_err());
    Ok(())
}
/// Requests `DBUS_NAME` on the session bus and cross-checks it with the bus daemon.
///
/// Steps:
/// 1. Send `RequestName` for `DBUS_NAME` and expect a `Reply`.
/// 2. Call `GetConnectionUnixProcessID` for that name and expect this process's PID.
/// 3. Call `GetNameOwner` for that name and expect this connection's own name.
///
/// Every send and every receive is bounded by `DEFAULT_TO`.
///
/// # Errors
/// Returns `TestingError::Bad` carrying the offending message when any of the
/// three responses is not a `Reply`.
#[async_std::test]
async fn get_name() -> Result<(), TestingError> {
    use rustbus_core::standard_messages::request_name;
    let conn = RpcConn::session_conn(false).await?;
    let msg_fut = timeout(DEFAULT_TO, conn.send_msg(&request_name(DBUS_NAME, 0)))
        .await??
        .unwrap();
    println!("Name Request sent");
    let msg = timeout(DEFAULT_TO, msg_fut).await??;
    match &msg.typ {
        MessageType::Reply => {}
        _ => return Err(TestingError::Bad(msg)),
    }
    let mut call = MessageBuilder::new()
        .call(String::from("GetConnectionUnixProcessID"))
        .with_interface(String::from("org.freedesktop.DBus"))
        .on(String::from("/org/freedesktop/DBus"))
        .at(String::from("org.freedesktop.DBus"))
        .build();
    call.body.push_param(DBUS_NAME).unwrap();
    let res_fut = timeout(DEFAULT_TO, conn.send_msg(&call)).await??.unwrap();
    let res = timeout(DEFAULT_TO, res_fut).await??;
    match &res.typ {
        MessageType::Reply => {
            let i: u32 = res.body.parser().get().unwrap();
            assert_eq!(i, id());
        }
        // Report the response that actually failed, not the earlier RequestName reply.
        _ => return Err(TestingError::Bad(res)),
    }
    // Reuse the same call (same body) with a different member.
    call.dynheader.member = Some(String::from("GetNameOwner"));
    let res_fut = timeout(DEFAULT_TO, conn.send_msg(&call)).await??.unwrap();
    let res = timeout(DEFAULT_TO, res_fut).await??;
    match &res.typ {
        MessageType::Reply => {
            let name: &str = res.body.parser().get().unwrap();
            assert_eq!(name, conn.get_name());
            Ok(())
        }
        // Report the response that actually failed, not the earlier RequestName reply.
        _ => Err(TestingError::Bad(res)),
    }
}
/// Calls `org.freedesktop.DBus.Peer.GetMachineId` on the bus daemon and checks
/// that the reply body parses as a string.
#[async_std::test]
async fn get_mach_id() -> Result<(), TestingError> {
    let conn = RpcConn::session_conn(false).await?;
    let request = MessageBuilder::new()
        .call(String::from("GetMachineId"))
        .with_interface(String::from("org.freedesktop.DBus.Peer"))
        .on(String::from("/org/freedesktop/DBus"))
        .at(String::from("org.freedesktop.DBus"))
        .build();
    let pending_reply = conn.send_msg(&request).await?.unwrap();
    let reply = pending_reply.await?;
    if !matches!(reply.typ, MessageType::Reply) {
        return Err(TestingError::Bad(reply));
    }
    // The value itself is not checked, only that it is a string.
    let _machine_id: &str = reply.body.parser().get().unwrap();
    Ok(())
}
/// Checks that a long-running call can still be received after a later,
/// shorter call has already been primed (polled once) on the same connection.
///
/// Ignored by default: it needs the `io.test.LongWait` and `io.test.ShortWait`
/// services to be running on the session bus.
#[async_std::test]
#[ignore]
async fn no_recv_deadlock_overcut() -> Result<(), TestingError> {
    let conn = RpcConn::session_conn(false).await?;
    let mut call = MessageBuilder::new()
        .call(String::from("Echo"))
        .with_interface(String::from("org.freedesktop.DBus.Testing"))
        .on(String::from("/"))
        .build();
    for round in 0..5u8 {
        println!("no_recv_deadlock_overcut() iteration {}", round);
        call.dynheader.destination = Some(String::from("io.test.LongWait"));
        let long_fut = conn.send_msg(&call).await?.unwrap();
        call.dynheader.destination = Some(String::from("io.test.ShortWait"));
        // Keep sending short calls until one is still pending after being primed;
        // a call that already completed is discarded and retried.
        let mut completed_early = 0u8;
        let short_fut = loop {
            let candidate = conn.send_msg_w_rsp(&call).await?.boxed();
            match prime_future(candidate) {
                Either::Right(still_pending) => break still_pending,
                Either::Left(_) => {
                    completed_early += 1;
                    if completed_early >= 25 {
                        panic!("Failed to prime short fut!");
                    }
                }
            }
        };
        // Await the long call first even though the short one was primed first.
        let long_res = timeout(Duration::from_secs(2), long_fut).await??;
        is_msg_reply(long_res)?;
        let short_res = short_fut.await?;
        is_msg_reply(short_res)?;
    }
    Ok(())
}
/// Sends batches of 50 calls and awaits their responses in random order.
///
/// The seed is printed so a failing order can be reproduced. Ignored by
/// default: it needs the `io.test.ShortWait` service on the session bus.
#[async_std::test]
#[ignore]
async fn no_recv_deadlock_shuffle() -> Result<(), TestingError> {
    let seed: <SmallRng as SeedableRng>::Seed = OsRng.gen();
    println!("no_recv_deadlock_shuffle(rng: {:X?})", seed);
    let mut rng = SmallRng::from_seed(seed);
    let conn = RpcConn::session_conn(false).await?;
    let call = MessageBuilder::new()
        .call(String::from("Echo"))
        .with_interface(String::from("org.freedesktop.DBus.Testing"))
        .on(String::from("/"))
        .at("io.test.ShortWait")
        .build();
    for round in 0..5u8 {
        println!("no_recv_deadlock_shuffle() iteration {}", round);
        let mut sent = Vec::with_capacity(50);
        for _ in 0..50u8 {
            let response = conn.send_msg_w_rsp(&call).await?.boxed();
            sent.push(response);
        }
        sent.shuffle(&mut rng);
        // Poll every future once in shuffled order; the ones that complete
        // immediately are dropped, the rest are kept for awaiting.
        let mut pending = Vec::with_capacity(sent.len());
        for mut response in sent {
            let mut cx = Context::from_waker(noop_waker_ref());
            if response.poll_unpin(&mut cx).is_pending() {
                pending.push(response);
            }
        }
        pending.shuffle(&mut rng);
        for response in pending {
            let reply = timeout(Duration::from_secs(2), response).await??;
            is_msg_reply(reply)?;
        }
    }
    Ok(())
}
/// Checks that a later, faster call can be received while an earlier,
/// long-running call is still pending on the same connection.
///
/// Ignored by default: it needs the `io.test.LongWait` and `io.test.NoWait`
/// services to be running on the session bus.
///
/// # Panics
/// Panics if the long-running call has already completed when it is primed.
#[async_std::test]
#[ignore]
async fn no_recv_deadlock_undercut() -> Result<(), TestingError> {
    let conn = RpcConn::session_conn(false).await?;
    let mut call = MessageBuilder::new()
        .call(String::from("Echo"))
        .with_interface(String::from("org.freedesktop.DBus.Testing"))
        .on(String::from("/"))
        .build();
    for i in 0u8..5 {
        println!(
            "no_recv_deadlock_under(): iteration {}, sending messages",
            i
        );
        call.dynheader.destination = Some(String::from("io.test.LongWait"));
        let long_fut = conn.send_msg(&call).await?.unwrap();
        call.dynheader.destination = Some(String::from("io.test.NoWait"));
        let short_fut = conn.send_msg(&call).await?.unwrap();
        pin_mut!(long_fut);
        pin_mut!(short_fut);
        println!(
            "no_recv_deadlock_under(): iteration {}, awaiting responses",
            i
        );
        // Poll the long call once so it is registered as waiting before the
        // short call is awaited.
        let long_fut = match prime_future(long_fut) {
            Either::Right(f) => f,
            // The message previously named the wrong test (`..._overcut`).
            Either::Left(o) => panic!(
                "no_recv_deadlock_undercut(): long_fut finished too early: {:?}",
                o
            ),
        };
        let short_res = short_fut.await?;
        is_msg_reply(short_res)?;
        println!("no_recv_deadlock_undercut(): iteration {}: first recvd", i);
        is_msg_reply(long_fut.await?)?;
    }
    Ok(())
}
/// Sends one end of a Unix socket pair through the bus and back, then checks
/// that the returned descriptor is still connected to the end kept locally.
///
/// Flow:
/// 1. `conn` sends a call carrying the descriptor to `recv_conn`.
/// 2. `recv_conn` turns the queued call into a reply (same body) and sends it back.
/// 3. `conn` extracts the descriptor from the reply, writes `Hello World!`
///    through it, and reads the bytes from the local end.
#[async_std::test]
async fn fd_send_recv() -> Result<(), TestingError> {
    let (conn, recv_conn) =
        try_join(RpcConn::session_conn(true), RpcConn::session_conn(true)).await?;
    recv_conn
        .insert_call_path("/", CallAction::Queue)
        .await
        .unwrap();
    let mut call = MessageBuilder::new()
        .call(String::from("Echo"))
        .with_interface(String::from("org.freedesktop.DBus.Testing"))
        .on(String::from("/"))
        .at(String::from(recv_conn.get_name()))
        .build();
    let (mut ours, theirs) = UnixStream::pair()?;
    call.body
        .push_param(UnixFd::new(theirs.into_raw_fd()))
        .unwrap();
    println!("fd_send_recv(): sending first msg");
    let res_fut = conn.send_msg(&call).await?.unwrap();
    println!("fd_send_recv(): get first msg");
    let mut call_msg = recv_conn.get_call("/").await?;
    // Convert the received call into its own reply, addressed back to the sender.
    call_msg.typ = MessageType::Reply;
    call_msg.dynheader.response_serial = call_msg.dynheader.serial;
    call_msg.dynheader.serial = None;
    call_msg.dynheader.destination = Some(String::from(conn.get_name()));
    println!("fd_send_recv(): sending first response");
    // Sending a reply yields no response future.
    assert!(recv_conn.send_msg(&call_msg).await?.is_none());
    println!("fd_send_recv(): getting first response");
    let res = res_fut.await?;
    println!("Response sig: {:?}", res.dynheader.signature);
    let unix_fd: UnixFd = res.body.parser().get().unwrap();
    // SAFETY: NOTE(review): presumably `take_raw_fd` hands over ownership of an
    // open descriptor, so the new `UnixStream` becomes its only owner — confirm
    // against the `UnixFd` documentation.
    let mut theirs = unsafe { UnixStream::from_raw_fd(unix_fd.take_raw_fd().unwrap()) };
    theirs.write_all(b"Hello World!").await?;
    let mut buf = [0u8; 12];
    timeout(Duration::from_millis(50), ours.read_exact(&mut buf[..])).await??;
    // Reading 12 bytes is not enough; make sure they are the bytes just written.
    assert_eq!(&buf, b"Hello World!");
    Ok(())
}
/// Checks that sending a message carrying a Unix FD over a connection opened
/// with `session_conn(false)` fails with `ErrorKind::InvalidInput`.
#[async_std::test]
async fn send_fd_wo_fd_conn() -> Result<(), TestingError> {
    let conn = RpcConn::session_conn(false).await?;
    let mut fd_call = MessageBuilder::new()
        .call(String::from("Echo"))
        .with_interface(String::from("org.freedesktop.DBus.Testing"))
        .on(String::from("/"))
        .at(String::from("io.test.NoWait"))
        .build();
    let (local_end, remote_end) = UnixStream::pair()?;
    let fd_param = UnixFd::new(remote_end.into_raw_fd());
    fd_call.body.push_param(fd_param).unwrap();
    let send_result = conn.send_msg(&fd_call).await;
    match send_result {
        Ok(_) => panic!("Message was send successfully when it should fail."),
        Err(e) => {
            if e.kind() != ErrorKind::InvalidInput {
                panic!("Received wrong error type: {:?}", e);
            }
        }
    }
    // The local end is kept open until the send attempt has finished.
    drop(local_end);
    Ok(())
}
/// Sends eight 16 MiB calls back-to-back while a spawned task answers them,
/// checking that each send finishes within 500 ms and every call gets a reply.
#[async_std::test]
async fn no_send_deadlock_long() -> Result<(), TestingError> {
    let (conn, recv_conn) =
        try_join(RpcConn::session_conn(false), RpcConn::session_conn(false)).await?;
    recv_conn
        .insert_call_path("/", CallAction::Queue)
        .await
        .unwrap();
    let mut big_call = MessageBuilder::new()
        .call(String::from("Echo"))
        .with_interface(String::from("org.freedesktop.DBus.Testing"))
        .on(String::from("/"))
        .at(recv_conn.get_name().to_string())
        .build();
    let payload = vec![1u8; 16 * 1024 * 1024];
    big_call.body.push_param(payload).unwrap();
    let reply_dest = conn.get_name().to_string();
    // Responder: turns each queued call into its own reply and sends it back.
    let responder = async_std::task::spawn(async move {
        for i in 0..8 {
            println!("no_send_deadlock_long(): other thread waiting for {}", i);
            let mut reply = recv_conn.get_call("/").await.unwrap();
            println!("no_send_deadlock_long(): other thread recvd {}", i);
            reply.typ = MessageType::Reply;
            reply.dynheader.response_serial = reply.dynheader.serial;
            reply.dynheader.serial = None;
            reply.dynheader.destination = Some(reply_dest.clone());
            let sent = recv_conn.send_msg(&reply).await.unwrap();
            assert!(sent.is_none());
        }
    });
    let mut pending_replies = Vec::new();
    for i in 0..8 {
        println!("no_send_deadlock_long(): send iteration {}", i);
        let sending = conn.send_msg(&big_call);
        let pending = timeout(Duration::from_millis(500), sending)
            .await??
            .unwrap();
        pending_replies.push(pending);
    }
    responder.await;
    let replies = try_join_all(pending_replies).await?;
    for reply in replies {
        is_msg_reply(reply)?;
    }
    Ok(())
}
/// Returns `Ok(())` when `msg` is a `Reply`; otherwise hands the message back
/// inside `TestingError::Bad`.
fn is_msg_reply(msg: MarshalledMessage) -> Result<(), TestingError> {
    if matches!(msg.typ, MessageType::Reply) {
        Ok(())
    } else {
        Err(TestingError::Bad(msg))
    }
}
#[allow(dead_code)]
fn is_msg_bad(msg: MarshalledMessage) -> Result<(), TestingError> {
match msg.typ {
MessageType::Error => Ok(()),
_ => Err(TestingError::Bad(msg)),
}
}
/// Extracts the first body value of a `Reply` message as a `T`.
///
/// # Errors
/// Returns `TestingError::Bad` carrying the message when it is not a `Reply`
/// or when its body cannot be unmarshalled as `T`.
fn is_msg_good<T>(msg: MarshalledMessage) -> Result<T, TestingError>
where
    for<'r> T: Unmarshal<'r, 'r>,
{
    if !matches!(msg.typ, MessageType::Reply) {
        return Err(TestingError::Bad(msg));
    }
    // Bind the parse result first so the parser's borrow of `msg` has ended
    // before `msg` is moved into the error.
    let parsed: Result<T, _> = msg.body.parser().get();
    parsed.map_err(|_| TestingError::Bad(msg))
}
/// Registers a tree of call paths on one connection and checks, from a second
/// connection, which child nodes each level's `Introspect` reply lists.
///
/// A background task blocks on `get_call("/tmp/log")`, which is never expected
/// to complete; it is cancelled at the end of the test.
#[async_std::test]
async fn introspect() -> Result<(), TestingError> {
    let (conn, recv_conn) =
        try_join(RpcConn::session_conn(false), RpcConn::session_conn(false)).await?;
    // Registered in this exact order.
    let registrations = vec![
        ("/", CallAction::Intro),
        ("/usr/local/lib/libdbus", CallAction::Intro),
        ("/usr/local/lib/libssl", CallAction::Intro),
        ("/usr/local/bin/ls", CallAction::Intro),
        ("/tmp", CallAction::Drop),
        ("/tmp/log", CallAction::Queue),
    ];
    for (path, action) in registrations {
        recv_conn.insert_call_path(path, action).await.unwrap();
    }
    let mut request = MessageBuilder::new()
        .call("Introspect".to_string())
        .at(recv_conn.get_name().to_string())
        .on("/".to_string())
        .with_interface("org.freedesktop.DBus.Introspectable".to_string())
        .build();
    let blocked_task = async_std::task::spawn(async move {
        println!("introspect(): spawned: Being await");
        let res = recv_conn.get_call("/tmp/log").await;
        unreachable!(
            "introspect(): spawned: unfinishable task finished: {:?}",
            res
        );
    });

    // Root: `usr` is listed, `tmp` is not.
    let reply = conn.send_msg_w_rsp(&request).await?.await?;
    let xml: String = is_msg_good(reply)?;
    assert!(xml.contains("<node name=\"usr\"/>"));
    assert!(!xml.contains("<node name=\"tmp\"/>"));

    request.dynheader.object = Some("/usr".to_string());
    let reply = conn.send_msg_w_rsp(&request).await?.await?;
    let xml: String = is_msg_good(reply)?;
    assert!(xml.contains("<node name=\"local\"/>"));

    request.dynheader.object = Some("/usr/local".to_string());
    let reply = conn.send_msg_w_rsp(&request).await?.await?;
    let xml: String = is_msg_good(reply)?;
    assert!(xml.contains("<node name=\"lib\"/>"));
    assert!(xml.contains("<node name=\"bin\"/>"));

    request.dynheader.object = Some("/usr/local/lib".to_string());
    let reply = conn.send_msg_w_rsp(&request).await?.await?;
    let xml: String = is_msg_good(reply)?;
    assert!(xml.contains("<node name=\"libdbus\"/>"));
    assert!(xml.contains("<node name=\"libssl\"/>"));

    request.dynheader.object = Some("/usr/local/bin".to_string());
    let reply = conn.send_msg_w_rsp(&request).await?.await?;
    let xml: String = is_msg_good(reply)?;
    assert!(xml.contains("<node name=\"ls\"/>"));

    blocked_task.cancel().await;
    Ok(())
}
/// Writes 32 bytes of `0xFF` directly to the connection's socket and checks
/// that a subsequent `get_call` returns an error within one second.
#[async_std::test]
async fn detect_hangup() -> Result<(), TestingError> {
    let conn = RpcConn::session_conn(false).await?;
    conn.insert_call_path("/", CallAction::Exact).await.unwrap();
    let raw_fd = conn.as_raw_fd();
    println!("Writing bad buffer");
    let garbage = [0xFFu8; 32];
    let garbage_ptr = garbage.as_ptr() as *const libc::c_void;
    // SAFETY: `garbage_ptr` points at `garbage`, which outlives the call and is
    // exactly `garbage.len()` bytes long; `raw_fd` comes from the live `conn`.
    let written = unsafe { libc::write(raw_fd, garbage_ptr, garbage.len()) };
    assert_eq!(written, 32);
    println!("Checking if hung up");
    let outcome = timeout(Duration::from_secs(1), conn.get_call("/")).await?;
    assert!(outcome.is_err());
    Ok(())
}
/// Registers several signal match rules on one connection, sends seven signals
/// from another, and checks which rule each signal is delivered to.
#[async_std::test]
async fn signal_send_and_receive() -> Result<(), TestingError> {
    let (conn, recv_conn) =
        try_join(RpcConn::session_conn(false), RpcConn::session_conn(false)).await?;
    let recv_dest = recv_conn.get_name();

    println!("Inserting signal matches");
    recv_conn.insert_sig_match(EMPTY_MATCH).await?;
    // Exact path.
    let mut m1 = MatchRule::new();
    m1.path("/io/test/specific");
    recv_conn.insert_sig_match(&m1).await?;
    // Path namespace.
    let mut m2 = MatchRule::new();
    m2.path_namespace("/io/test");
    recv_conn.insert_sig_match(&m2).await?;
    // Interface only.
    let mut m3 = MatchRule::new();
    m3.interface("io.test.Test1");
    recv_conn.insert_sig_match(&m3).await?;
    // Interface and member.
    let mut m4 = m3.clone();
    m4.member("TestSignal");
    recv_conn.insert_sig_match(&m4).await?;
    // Member only.
    let mut m5 = m4.clone();
    m5.interface = None;
    recv_conn.insert_sig_match(&m5).await?;

    let default_signal = MessageBuilder::new()
        .signal("io.test.Test3", "TestSignal2", "/")
        .to(recv_dest)
        .build();
    println!("Sending signals");
    conn.send_msg_wo_rsp(&default_signal).await?;
    // One signal is mutated step by step and re-sent after each change.
    let mut sig = MessageBuilder::new()
        .signal("io.test.Test2", "TestSignal", "/io/test/specific")
        .build();
    conn.send_msg_wo_rsp(&sig).await?;
    sig.dynheader.object = Some("/io/test/other".into());
    conn.send_msg_wo_rsp(&sig).await?;
    sig.dynheader.object = Some("/io/test".into());
    conn.send_msg_wo_rsp(&sig).await?;
    sig.dynheader.object = Some("/io".into());
    conn.send_msg_wo_rsp(&sig).await?;
    sig.dynheader.interface = Some("io.test.Test1".into());
    conn.send_msg_wo_rsp(&sig).await?;
    sig.dynheader.member = Some("TestSignal2".into());
    conn.send_msg_wo_rsp(&sig).await?;

    println!("Receiving signals");
    let rs1 = recv_conn.get_signal(&m1).await?.dynheader;
    let rs2 = recv_conn.get_signal(&m2).await?.dynheader;
    let rs3 = recv_conn.get_signal(&m2).await?.dynheader;
    let rs4 = recv_conn.get_signal(&m5).await?.dynheader;
    let rs5 = recv_conn.get_signal(&m4).await?.dynheader;
    let rs6 = recv_conn.get_signal(&m3).await?.dynheader;

    // Drain whatever is already queued for the empty match until the directly
    // addressed signal shows up.
    let mut found = false;
    while let Some(queued) = recv_conn.get_signal(EMPTY_MATCH).now_or_never() {
        let hdr = queued?.dynheader;
        let is_default_signal = hdr.interface.as_deref() == Some("io.test.Test3")
            && hdr.member.as_deref() == Some("TestSignal2")
            && hdr.object.as_deref() == Some("/");
        if is_default_signal {
            found = true;
            break;
        }
    }
    assert!(found);

    assert_eq!(rs1.interface.as_deref(), Some("io.test.Test2"));
    assert_eq!(rs1.member.as_deref(), Some("TestSignal"));
    assert_eq!(rs1.object.as_deref(), Some("/io/test/specific"));

    assert_eq!(rs2.interface.as_deref(), Some("io.test.Test2"));
    assert_eq!(rs2.member.as_deref(), Some("TestSignal"));
    assert_eq!(rs2.object.as_deref(), Some("/io/test/other"));

    assert_eq!(rs3.interface.as_deref(), Some("io.test.Test2"));
    assert_eq!(rs3.member.as_deref(), Some("TestSignal"));
    assert_eq!(rs3.object.as_deref(), Some("/io/test"));

    assert_eq!(rs4.interface.as_deref(), Some("io.test.Test2"));
    assert_eq!(rs4.member.as_deref(), Some("TestSignal"));
    assert_eq!(rs4.object.as_deref(), Some("/io"));

    assert_eq!(rs5.interface.as_deref(), Some("io.test.Test1"));
    assert_eq!(rs5.member.as_deref(), Some("TestSignal"));
    assert_eq!(rs5.object.as_deref(), Some("/io"));

    assert_eq!(rs6.interface.as_deref(), Some("io.test.Test1"));
    assert_eq!(rs6.member.as_deref(), Some("TestSignal2"));
    assert_eq!(rs6.object.as_deref(), Some("/io"));
    Ok(())
}
/// Expects a panic when a match rule with both `path` and `path_namespace`
/// set is inserted.
#[async_std::test]
#[should_panic]
async fn panic_on_bad_match() {
    let conn = RpcConn::session_conn(false).await.unwrap();
    // The fields are assigned directly rather than through the setter methods.
    // NOTE(review): presumably the setters would not allow both to be set at
    // once — confirm against `MatchRule`.
    let mut conflicting_rule = MatchRule::new();
    conflicting_rule.path = Some("/io/test/specific".into());
    conflicting_rule.path_namespace = Some("/io".into());
    let inserted = conn.insert_sig_match(&conflicting_rule).await;
    inserted.unwrap();
}
/// Checks that `select` polls its left future even when the right future is
/// immediately ready: a never-finishing counter on the left must be polled
/// exactly once per `select` call.
#[async_std::test]
async fn select_left_priority() {
    /// Future that never completes and records how often it has been polled.
    struct PollCounter {
        cnt: usize,
    }
    impl PollCounter {
        fn new() -> Self {
            PollCounter { cnt: 0 }
        }
        fn get_count(&self) -> usize {
            self.cnt
        }
    }
    impl Future for PollCounter {
        type Output = ();
        fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.cnt += 1;
            Poll::Pending
        }
    }

    let mut counter = PollCounter::new();
    for expected_polls in 1..=16 {
        let outcome = select(counter, ready(())).await;
        // The counter never finishes, so only the right side can win; the
        // still-pending counter is handed back for the next round.
        counter = match outcome {
            Either::Right((_, still_pending)) => still_pending,
            _ => unreachable!(),
        };
        assert_eq!(counter.get_count(), expected_polls);
    }
}