use dbus::channel::{Channel, BusType};
use dbus::nonblock::{LocalConnection, SyncConnection, Process, NonblockReply};
use std::{future, io, task, pin};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::os::unix::io::RawFd;
use tokio::io::unix::{AsyncFd, AsyncFdReadyGuard};
/// Wake-up handshake state shared between `IOResource::poll_internal` and the waker
/// closure that `new` installs on the connection via `set_waker`.
#[derive(Debug)]
enum WakeStatus {
/// No task waker is stored. The connection's waker closure always switches the state to
/// `Waiting { ready: true }`; when `poll_internal` later replaces that state it treats
/// `ready == true` as a request to run an I/O round even if the socket is not readable.
/// `new` starts in `Waiting { ready: false }`.
Waiting { ready: bool },
/// `poll_internal` stored the waker of the task driving the connection. The connection's
/// waker closure takes it out, calls `wake()` on it and switches to `Waiting { ready: true }`.
Polled { waker: task::Waker },
}
/// Registration state of the connection's watch fd with tokio's reactor.
enum IOResourceRegistration {
/// Fd and interest captured in `new`, not yet registered. `poll_internal` registers it on
/// its first call — presumably deferred so `new` can run outside a tokio runtime
/// context (NOTE(review): confirm).
Unregistered(RawFd, tokio::io::Interest),
/// Fd registered with tokio via `AsyncFd::with_interest`.
Registered(AsyncFd<RawFd>),
}
/// Future that drives a D-Bus connection on tokio: each poll performs whatever reading,
/// writing and message dispatch is currently possible.
///
/// It never completes successfully. It resolves only with an [`IOResourceError`], which
/// happens when registering/polling the watch fd or the channel's read/write fails. The
/// connection only makes progress while this future is being polled (e.g. spawned).
pub struct IOResource<C> {
/// The connection being driven; `new` returns another handle to the same `Arc`.
connection: Arc<C>,
/// Watch fd, registered with the reactor lazily on first poll.
registration: IOResourceRegistration,
/// Handshake state shared with the waker closure installed on the connection.
wake: Arc<Mutex<WakeStatus>>,
/// True when the last I/O round left outgoing messages queued, so write readiness alone
/// should trigger another round.
write_pending: bool,
}
/// Error with which an [`IOResource`] future resolves once it can no longer drive its
/// connection.
#[derive(Debug)]
#[non_exhaustive]
pub enum IOResourceError {
/// The D-Bus channel reported a failure (e.g. `read_write` failed).
Dbus(dbus::Error),
/// Registering or polling the watch fd with tokio failed.
Io(io::Error),
}
impl From<dbus::Error> for IOResourceError {
fn from(e: dbus::Error) -> Self {
IOResourceError::Dbus(e)
}
}
impl From<io::Error> for IOResourceError {
fn from(e: io::Error) -> Self {
IOResourceError::Io(e)
}
}
impl std::fmt::Display for IOResourceError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
IOResourceError::Dbus(e) => e.fmt(f),
IOResourceError::Io(e) => e.fmt(f),
}
}
}
impl std::error::Error for IOResourceError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(match self {
IOResourceError::Dbus(e) => e,
IOResourceError::Io(e) => e,
})
}
}
impl<C: AsRef<Channel> + Process> IOResource<C> {
fn poll_internal(&mut self, ctx: &mut task::Context<'_>) -> Result<(), IOResourceError> {
let c: &Channel = (*self.connection).as_ref();
let mut wake_status = self.wake.lock().unwrap();
if let IOResourceRegistration::Unregistered(watch_fd, interest) = self.registration {
let watch_reg = AsyncFd::with_interest(watch_fd, interest)?;
self.registration = IOResourceRegistration::Registered(watch_reg);
}
let watch_reg = match &self.registration {
IOResourceRegistration::Registered(res) => res,
IOResourceRegistration::Unregistered(..) => unreachable!(),
};
let mut read_guard = watch_reg.poll_read_ready(ctx)?;
let send_ready = match &*wake_status {
WakeStatus::Polled { waker } if ctx.waker().will_wake(waker) => false,
_ => {
let prev_status = std::mem::replace(
&mut *wake_status,
WakeStatus::Polled { waker: ctx.waker().clone() },
);
matches!(prev_status, WakeStatus::Waiting { ready: true })
}
};
let mut write_guard = watch_reg.poll_write_ready(ctx)?;
if read_guard.is_ready() || send_ready || (self.write_pending && write_guard.is_ready()) {
loop {
self.write_pending = false;
c.read_write(Some(Duration::default())).map_err(|_| dbus::Error::new_failed("Read/write failed"))?;
self.connection.process_all();
if c.has_messages_to_send() {
self.write_pending = true;
if check_ready_now(&mut write_guard, || watch_reg.poll_write_ready(ctx))? {
continue
}
}
let watch_fd = *watch_reg.get_ref();
let mut x = 0u8;
let r = unsafe {
libc::recv(watch_fd, &mut x as *mut _ as *mut libc::c_void, 1, libc::MSG_DONTWAIT | libc::MSG_PEEK)
};
if r != 1 {
if check_ready_now(&mut read_guard, || watch_reg.poll_read_ready(ctx))? {
continue
}
break;
}
}
}
Ok(())
}
}
/// Clears the readiness held by `guard` (if any) and immediately re-polls through `poll_ready`.
///
/// The fresh poll result is stored back into `guard`. Returns `Ok(true)` when the fd is still
/// ready, meaning the caller should do another I/O round. Returns `Ok(false)` when a wake-up
/// has been registered instead. A polling error is returned as-is and leaves `guard` untouched.
fn check_ready_now<'a>(
    guard: &mut task::Poll<AsyncFdReadyGuard<'a, RawFd>>,
    poll_ready: impl FnOnce() -> task::Poll<std::io::Result<AsyncFdReadyGuard<'a, RawFd>>>,
) -> std::io::Result<bool> {
    if let task::Poll::Ready(stale) = guard {
        stale.clear_ready();
    }
    *guard = poll_ready()?;
    Ok(guard.is_ready())
}
impl<C: AsRef<Channel> + Process> future::Future for IOResource<C> {
    type Output = IOResourceError;

    /// Runs one I/O round. The future stays pending for as long as rounds succeed and
    /// resolves with the first error encountered.
    fn poll(mut self: pin::Pin<&mut Self>, ctx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
        if let Err(failure) = self.poll_internal(ctx) {
            task::Poll::Ready(failure)
        } else {
            task::Poll::Pending
        }
    }
}
fn make_timeout(timeout: Instant) -> pin::Pin<Box<dyn future::Future<Output=()> + Send + Sync + 'static>> {
let t = tokio::time::sleep_until(timeout.into());
Box::pin(t)
}
/// Opens a private connection to bus `b` and prepares it to be driven by tokio.
///
/// The setup steps are:
/// - enable the channel's watch and capture its fd and interest for lazy reactor registration;
/// - install `make_timeout` as the connection's timeout maker;
/// - install a waker that forwards wake-ups to the task polling the returned [`IOResource`].
///
/// Returns the `IOResource`, which must be polled (e.g. spawned) for any I/O to happen, and a
/// shared handle to the connection.
///
/// # Errors
/// Propagates the error from `Channel::get_private`.
pub fn new<C: From<Channel> + NonblockReply>(b: BusType) -> Result<(IOResource<C>, Arc<C>), dbus::Error> {
    let mut channel = Channel::get_private(b)?;
    channel.set_watch_enabled(true);

    let watch = channel.watch();
    let interest = if watch.write {
        tokio::io::Interest::READABLE | tokio::io::Interest::WRITABLE
    } else {
        tokio::io::Interest::READABLE
    };
    let registration = IOResourceRegistration::Unregistered(watch.fd, interest);

    let mut connection = C::from(channel);
    connection.set_timeout_maker(Some(make_timeout));

    let wake = Arc::new(Mutex::new(WakeStatus::Waiting { ready: false }));
    let shared_wake = Arc::clone(&wake);
    connection.set_waker(Some(Box::new(move || {
        // The guard lives until the end of the closure, so the state switch and the wake
        // both happen under the lock.
        let mut status = shared_wake.lock().unwrap();
        match std::mem::replace(&mut *status, WakeStatus::Waiting { ready: true }) {
            WakeStatus::Polled { waker } => {
                waker.wake();
                Ok(())
            }
            // No task waker stored: report failure. The recorded `ready: true` still makes
            // the next poll run an I/O round.
            WakeStatus::Waiting { .. } => Err(()),
        }
    })));

    let connection = Arc::new(connection);
    let resource = IOResource {
        connection: Arc::clone(&connection),
        registration,
        wake,
        write_pending: false,
    };
    Ok((resource, connection))
}
pub fn new_session_local() -> Result<(IOResource<LocalConnection>, Arc<LocalConnection>), dbus::Error> { new(BusType::Session) }
pub fn new_system_local() -> Result<(IOResource<LocalConnection>, Arc<LocalConnection>), dbus::Error> { new(BusType::System) }
pub fn new_session_sync() -> Result<(IOResource<SyncConnection>, Arc<SyncConnection>), dbus::Error> { new(BusType::Session) }
pub fn new_system_sync() -> Result<(IOResource<SyncConnection>, Arc<SyncConnection>), dbus::Error> { new(BusType::System) }
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips a method call to the bus daemon over a `LocalConnection` on a `LocalSet`.
    #[test]
    fn method_call_local() {
        use tokio::task;
        use std::time::Duration;

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap();
        let local = task::LocalSet::new();

        let (res, conn) = new_session_local().unwrap();
        local.spawn_local(async move {
            let err = res.await;
            panic!("Lost connection to D-Bus: {}", err);
        });

        let proxy = dbus::nonblock::Proxy::new("org.freedesktop.DBus", "/", Duration::from_secs(2), conn);
        let fut = proxy.method_call("org.freedesktop.DBus", "NameHasOwner", ("dummy.name.without.owner",));
        let (has_owner,): (bool,) = local.block_on(&rt, fut).unwrap();
        assert!(!has_owner);
    }

    /// A method call to a server that never replies must fail with the D-Bus Timeout error.
    #[tokio::test]
    async fn timeout() {
        use std::time::Duration;
        use dbus::channel::MatchingReceiver;

        let (ress, conns) = new_session_sync().unwrap();
        tokio::spawn(async move {
            let err = ress.await;
            panic!("Lost connection to D-Bus: {}", err);
        });
        conns.request_name("com.example.dbusrs.tokiotest", true, true, true).await.unwrap();
        // Accept every method call but never answer it.
        conns.start_receive(dbus::message::MatchRule::new_method_call(), Box::new(|_,_| true));

        let (res, conn) = new_session_sync().unwrap();
        tokio::spawn(async move {
            let err = res.await;
            panic!("Lost connection to D-Bus: {}", err);
        });
        let proxy = dbus::nonblock::Proxy::new("com.example.dbusrs.tokiotest", "/", Duration::from_millis(150), conn);
        let e: Result<(), _> = proxy.method_call("com.example.dbusrs.tokiotest", "Whatever", ()).await;
        let e = e.unwrap_err();
        assert_eq!(e.name(), Some("org.freedesktop.DBus.Error.Timeout"));
    }

    /// Replies larger than a single socket read must be reassembled correctly.
    #[tokio::test]
    async fn large_message() -> Result<(), Box<dyn std::error::Error>> {
        use dbus::arg::Variant;
        use dbus_tree::Factory;
        use std::{
            collections::HashMap,
            sync::{
                atomic::{AtomicBool, Ordering},
                Arc,
            },
            time::Duration,
        };

        type BigProps<'a> = Vec<(dbus::Path<'a>, HashMap<String, Variant<Box<i32>>>)>;

        fn make_big_reply<'a>() -> Result<BigProps<'a>, String> {
            let prop_map: HashMap<String, Variant<Box<i32>>> = (0..500).map(|i| (format!("key {}", i), Variant(Box::new(i)))).collect();
            (0..30u8).map(|i| Ok((dbus::strings::Path::new(format!("/{}", i))?, prop_map.clone()))).collect()
        }

        let server_conn = dbus::blocking::SyncConnection::new_session()?;
        server_conn.request_name("com.example.dbusrs.tokiobigtest", false, true, false)?;
        let f = Factory::new_sync::<()>();
        let tree =
            f.tree(()).add(f.object_path("/", ()).add(f.interface("com.example.dbusrs.tokiobigtest", ()).add_m(f.method("Ping", (), |m| {
                Ok(vec![m.msg.method_return().append1(make_big_reply().map_err(|err| dbus::MethodErr::failed(&err))?)])
            }))));
        tree.start_receive_sync(&server_conn);

        let done = Arc::new(AtomicBool::new(false));
        let done2 = done.clone();
        tokio::task::spawn_blocking(move || {
            while !done2.load(Ordering::Acquire) {
                server_conn.process(Duration::from_millis(100)).unwrap();
            }
        });

        // Stops the server loop when the test ends, including by panic. Otherwise runtime
        // shutdown would keep waiting on the still-running spawn_blocking task and a failing
        // test would hang instead of reporting.
        struct StopServer(Arc<AtomicBool>);
        impl Drop for StopServer {
            fn drop(&mut self) {
                self.0.store(true, Ordering::Release);
            }
        }
        let _stop_server = StopServer(done);

        let (resource, client_conn) = new_session_sync()?;
        tokio::spawn(async {
            let err = resource.await;
            panic!("Lost connection to D-Bus: {}", err);
        });

        let mut client_interval = tokio::time::interval(Duration::from_millis(10));
        let proxy = dbus::nonblock::Proxy::new("com.example.dbusrs.tokiobigtest", "/", Duration::from_secs(1), client_conn);
        for _ in 0..10 {
            client_interval.tick().await;
            println!("sending ping");
            proxy.method_call::<(BigProps,), _, _, _>("com.example.dbusrs.tokiobigtest", "Ping", ()).await.unwrap();
            println!("received prop list!");
        }
        Ok(())
    }
}