use std::collections::HashMap;
use std::convert::TryInto;
use std::io::ErrorKind;
use std::num::NonZeroU32;
use std::os::unix::io::{AsRawFd, RawFd};
use std::pin::Pin;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use async_channel::{unbounded, Receiver as CReceiver, Sender as CSender};
use futures::future::{Either, TryFutureExt};
#[doc(hidden)]
pub use utils::poll_once;
use std::path::Path;
use tokio::io::unix::AsyncFd;
use tokio::net::ToSocketAddrs;
use tokio::sync::oneshot::{
channel as oneshot_channel, Receiver as OneReceiver, Sender as OneSender,
};
use futures::pin_mut;
use futures::prelude::*;
use futures::task::{Context, Poll};
use tokio::sync::watch::{channel as watch_channel, Sender as WatchSender};
use tokio::sync::Mutex;
pub mod rustbus_core;
use rustbus_core::message_builder::{MarshalledMessage, MessageType};
use rustbus_core::path::ObjectPath;
use rustbus_core::standard_messages::{hello, release_name, request_name};
use rustbus_core::standard_messages::{
DBUS_NAME_FLAG_DO_NOT_QUEUE, DBUS_REQUEST_NAME_REPLY_ALREADY_OWNER,
DBUS_REQUEST_NAME_REPLY_PRIMARY_OWNER,
};
pub mod conn;
use conn::{Conn, GenStream, RecvState, SendState};
mod utils;
mod routing;
use routing::{queue_sig, CallHierarchy};
pub use routing::{CallAction, MatchRule, EMPTY_MATCH};
pub use conn::{get_session_bus_addr, get_system_bus_addr, DBusAddr};
/// Bit in a message's `flags` byte marking a method call that wants no reply
/// (tested by `expects_reply`).
const NO_REPLY_EXPECTED: u8 = 0b0000_0001;
struct MsgQueue {
sender: CSender<MarshalledMessage>,
recv: CReceiver<MarshalledMessage>,
}
impl MsgQueue {
fn new() -> Self {
let (sender, recv) = unbounded::<MarshalledMessage>();
Self { sender, recv }
}
fn get_receiver(&self) -> CReceiver<MarshalledMessage> {
self.recv.clone()
}
fn send(&self, msg: MarshalledMessage) {
self.sender.try_send(msg).unwrap()
}
}
/// Receive-side state of an `RpcConn`, shared behind `RpcConn::recv_data`.
struct RecvData {
    /// Parsing state for the incoming byte stream (drives `get_next_message`).
    state: RecvState,
    /// Senders for outstanding method calls, keyed by the call's serial; a reply is
    /// routed here by its `response_serial`.
    reply_map: HashMap<NonZeroU32, OneSender<MarshalledMessage>>,
    /// Registered object paths and the routing of incoming method calls to them.
    hierarchy: CallHierarchy,
    /// Signal match rules, kept sorted because they are located with `binary_search`.
    sig_matches: Vec<MatchRule>,
}
/// An asynchronous D-Bus connection multiplexing method calls, replies, signals and
/// incoming calls over one socket.
///
/// No dedicated reader task is visible here: whichever task is waiting for a message
/// reads from the socket and dispatches messages meant for other waiters to their
/// reply slots, signal queues or call queues.
pub struct RpcConn {
    /// The bus socket, registered with Tokio's reactor for readiness notifications.
    conn: AsyncFd<GenStream>,
    /// Watch sender updated (`send_replace`) after a reader processes a message.
    /// NOTE(review): no receiver is subscribed in this file — confirm its consumers.
    recv_watch: WatchSender<()>,
    /// Receive-side state; in an `Arc` so that a dropped `ResponseFuture` can clean up
    /// from a spawned task.
    recv_data: Arc<Mutex<RecvData>>,
    /// Send-side stream state plus an optional serial.
    /// NOTE(review): the `Option<NonZeroU32>` is only initialised to `None` here — confirm its use.
    send_data: Mutex<(SendState, Option<NonZeroU32>)>,
    /// Counter used to allocate serials for outgoing messages.
    serial: AtomicU32,
    /// Unique name assigned by the bus in reply to `Hello`.
    auto_name: String,
}
impl RpcConn {
    /// Wraps an already-connected [`Conn`] and performs the D-Bus `Hello` handshake.
    ///
    /// The socket is switched to non-blocking mode (as required by [`AsyncFd`]). The
    /// unique name returned in the `Hello` reply becomes this connection's name.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the descriptor flags cannot be read or changed;
    /// - the socket cannot be registered with the reactor;
    /// - sending `Hello` or receiving its reply fails;
    /// - the bus answers with an error or an unparsable reply.
    async fn new(conn: Conn) -> std::io::Result<Self> {
        let fd = conn.as_raw_fd();
        // SAFETY: F_GETFL / F_SETFL only read and update the status flags of `fd`, which
        // belongs to `conn` and stays open for this whole block; no pointers are passed.
        unsafe {
            let flags = libc::fcntl(fd, libc::F_GETFL);
            if flags == -1 {
                return Err(std::io::Error::last_os_error());
            }
            if libc::O_NONBLOCK & flags == 0
                && libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) == -1
            {
                return Err(std::io::Error::last_os_error());
            }
        }
        let recv_data = RecvData {
            state: conn.recv_state,
            reply_map: HashMap::new(),
            hierarchy: CallHierarchy::new(),
            sig_matches: Vec::new(),
        };
        let (recv_watch, _) = watch_channel(());
        let mut ret = Self {
            conn: AsyncFd::new(conn.stream)?,
            send_data: Mutex::new((conn.send_state, None)),
            recv_data: Arc::new(Mutex::new(recv_data)),
            recv_watch,
            serial: AtomicU32::new(1),
            auto_name: String::new(),
        };
        // `hello()` builds a method call that expects a reply, so use the reply path directly.
        let hello_res = ret.send_msg_w_rsp(&hello()).await?.await?;
        match hello_res.typ {
            MessageType::Reply => {
                ret.auto_name = hello_res.body.parser().get().map_err(|_| {
                    std::io::Error::new(ErrorKind::ConnectionRefused, "Unable to parse name")
                })?;
                Ok(ret)
            }
            MessageType::Error => {
                let (err, details): (&str, &str) = hello_res
                    .body
                    .parser()
                    .get()
                    .unwrap_or(("Unable to parse message", ""));
                Err(std::io::Error::new(
                    ErrorKind::ConnectionRefused,
                    format!("Hello message failed with: {}: {}", err, details),
                ))
            }
            _ => Err(std::io::Error::new(
                ErrorKind::ConnectionAborted,
                "Unexpected reply to hello message!",
            )),
        }
    }
    /// Returns the unique bus name assigned to this connection by `Hello`.
    pub fn get_name(&self) -> &str {
        &self.auto_name
    }
    /// Connects to the session bus at the address reported by [`get_session_bus_addr`].
    ///
    /// `with_fd` is forwarded to the connection setup (presumably enabling Unix
    /// file-descriptor passing — confirm in `conn`).
    pub async fn session_conn(with_fd: bool) -> std::io::Result<Self> {
        let addr = get_session_bus_addr().await?;
        Self::connect_to_addr(&addr, with_fd).await
    }
    /// Connects to the system bus at the address reported by [`get_system_bus_addr`].
    ///
    /// `with_fd` is forwarded to the connection setup.
    pub async fn system_conn(with_fd: bool) -> std::io::Result<Self> {
        let addr = get_system_bus_addr().await?;
        Self::connect_to_addr(&addr, with_fd).await
    }
    /// Connects to the bus at `addr` and performs the `Hello` handshake.
    pub async fn connect_to_addr<P: AsRef<Path>, S: ToSocketAddrs, B: AsRef<[u8]>>(
        addr: &DBusAddr<P, S, B>,
        with_fd: bool,
    ) -> std::io::Result<Self> {
        let conn = Conn::connect_to_addr(addr, with_fd).await?;
        Self::new(conn).await
    }
    /// Connects to the bus socket at filesystem path `path` and performs `Hello`.
    pub async fn connect_to_path<P: AsRef<Path>>(path: P, with_fd: bool) -> std::io::Result<Self> {
        let conn = Conn::connect_to_path(path, with_fd).await?;
        Self::new(conn).await
    }
    /// Allocates the next message serial, skipping 0 (reachable when the counter wraps).
    fn allocate_idx(&self) -> NonZeroU32 {
        let mut idx = 0;
        while idx == 0 {
            idx = self.serial.fetch_add(1, Ordering::Relaxed);
        }
        NonZeroU32::new(idx).unwrap()
    }
    /// Sends `msg`.
    ///
    /// Returns `Some(future)` resolving to the reply when the message expects one
    /// (see `expects_reply`), otherwise `None`.
    pub async fn send_msg(
        &self,
        msg: &MarshalledMessage,
    ) -> std::io::Result<Option<impl Future<Output = std::io::Result<MarshalledMessage>> + '_>>
    {
        Ok(if expects_reply(msg) {
            Some(self.send_msg_w_rsp(msg).await?)
        } else {
            self.send_msg_wo_rsp(msg).await?;
            None
        })
    }
    /// Writes `msg` with serial `idx` to the socket, waiting for writability as needed.
    ///
    /// NOTE(review): relies on `conn::SendState` semantics not visible here. It appears
    /// that `write_next_message` returns `None` when the message was fully written and
    /// `Some(i)` when it was only partially written or queued. In the latter case
    /// `finish_sending_next` is driven until the sent index passes `i` — confirm.
    async fn send_msg_loop(&self, msg: &MarshalledMessage, idx: NonZeroU32) -> std::io::Result<()> {
        let mut send_idx = None;
        loop {
            let mut write_guard = self.conn.writable().await?;
            let mut send_lock = self.send_data.lock().await;
            let stream = self.conn.get_ref();
            match send_idx {
                Some(send_idx) => {
                    if send_lock.0.current_idx() > send_idx {
                        return Ok(());
                    }
                    let new_idx = match send_lock.0.finish_sending_next(stream) {
                        Ok(i) => i,
                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
                            write_guard.clear_ready();
                            continue;
                        }
                        Err(e) => return Err(e),
                    };
                    if new_idx > send_idx {
                        return Ok(());
                    }
                }
                None => {
                    send_idx = match send_lock.0.write_next_message(stream, msg, idx) {
                        Ok(si) => si,
                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
                            write_guard.clear_ready();
                            continue;
                        }
                        Err(e) => return Err(e),
                    };
                    if send_idx.is_none() {
                        return Ok(());
                    }
                }
            }
            drop(send_lock);
        }
    }
    /// Sends a message that does not expect a reply.
    ///
    /// # Panics
    /// Panics if `msg` is a method call expecting a reply.
    pub async fn send_msg_wo_rsp(&self, msg: &MarshalledMessage) -> std::io::Result<()> {
        assert!(!expects_reply(msg));
        let idx = self.allocate_idx();
        self.send_msg_loop(msg, idx).await
    }
    /// Sends a method call and returns a future resolving to its reply or error message.
    ///
    /// # Panics
    /// Panics if `msg` does not expect a reply.
    ///
    /// # Errors
    /// Returns the I/O error if the call could not be written. In that case the reply
    /// slot registered for it is removed again.
    pub async fn send_msg_w_rsp(
        &self,
        msg: &MarshalledMessage,
    ) -> std::io::Result<impl Future<Output = std::io::Result<MarshalledMessage>> + '_> {
        assert!(expects_reply(msg));
        let idx = self.allocate_idx();
        // Register the reply slot before sending. `queue_msg` discards replies that have
        // no registered slot, so a reply read by another task must find one.
        let recv = self.get_recv_and_insert_sender(idx).await;
        let msg_fut = recv.map_err(|_| panic!("Message reply channel should never be closed"));
        if let Err(e) = self.send_msg_loop(msg, idx).await {
            // No `ResponseFuture` will exist for `idx`, so nothing else would ever remove
            // this registration; drop it here instead of leaking it.
            self.recv_data.lock().await.reply_map.remove(&idx);
            return Err(e);
        }
        let res_pred = move |msg: &MarshalledMessage, _: &mut RecvData| match &msg.typ {
            MessageType::Reply | MessageType::Error => {
                let res_idx = match msg.dynheader.response_serial {
                    Some(res_idx) => res_idx,
                    None => {
                        unreachable!("Should never reply/err without res serial.")
                    }
                };
                res_idx == idx
            }
            _ => false,
        };
        Ok(ResponseFuture {
            idx,
            rpc_conn: self,
            fut: self.get_msg(msg_fut, res_pred).boxed(),
        })
    }
    /// Registers a oneshot reply slot for serial `idx` and returns its receiving end.
    async fn get_recv_and_insert_sender(&self, idx: NonZeroU32) -> OneReceiver<MarshalledMessage> {
        let (sender, recv) = oneshot_channel();
        let mut recv_lock = self.recv_data.lock().await;
        recv_lock.reply_map.insert(idx, sender);
        recv
    }
    /// Removes `sig_match` from the locally tracked signal matches, if present.
    async fn remove_local_sig_match(&self, sig_match: &MatchRule) {
        let mut recv_data = self.recv_data.lock().await;
        if let Ok(idx) = recv_data.sig_matches.binary_search(sig_match) {
            recv_data.sig_matches.remove(idx);
        }
    }
    /// Registers a signal match rule locally and with the bus (`AddMatch`).
    ///
    /// Matching signals are queued and can be retrieved with
    /// [`get_signal`](Self::get_signal).
    ///
    /// # Panics
    /// Panics if both `path` and `path_namespace` are set on `sig_match`.
    ///
    /// # Errors
    /// - `InvalidInput` if the rule is already registered.
    /// - The I/O error, if the `AddMatch` call could not be completed.
    /// - `Other`, carrying the first string of the error reply, if the bus rejects the rule.
    ///
    /// In the last two cases the local registration is rolled back.
    pub async fn insert_sig_match(&self, sig_match: &MatchRule) -> std::io::Result<()> {
        assert!(!(sig_match.path.is_some() && sig_match.path_namespace.is_some()));
        let mut recv_data = self.recv_data.lock().await;
        let insert_idx = match recv_data.sig_matches.binary_search(sig_match) {
            Ok(_) => {
                return Err(std::io::Error::new(
                    ErrorKind::InvalidInput,
                    "Already exists",
                ))
            }
            Err(i) => i,
        };
        let mut to_insert = sig_match.clone();
        to_insert.queue = Some(MsgQueue::new());
        recv_data.sig_matches.insert(insert_idx, to_insert);
        drop(recv_data);
        let match_str = sig_match.match_string();
        let call = rustbus_core::standard_messages::add_match(&match_str);
        let res = match self.send_msg_w_rsp(&call).await {
            Ok(rsp_fut) => rsp_fut.await,
            Err(e) => Err(e),
        };
        let res = match res {
            Ok(res) => res,
            Err(e) => {
                // The bus never acknowledged the rule; undo the local registration so a
                // retry is not rejected as "Already exists".
                self.remove_local_sig_match(sig_match).await;
                return Err(e);
            }
        };
        match res.typ {
            MessageType::Reply => Ok(()),
            MessageType::Error => {
                self.remove_local_sig_match(sig_match).await;
                let err_str: &str = res
                    .body
                    .parser()
                    .get()
                    .unwrap_or("Unknown DBus Error Type!");
                Err(std::io::Error::new(ErrorKind::Other, err_str))
            }
            _ => unreachable!(),
        }
    }
    /// Removes a signal match rule locally, then asks the bus to drop it (`RemoveMatch`).
    ///
    /// # Errors
    /// - `InvalidInput` if the rule is not registered.
    /// - The I/O error, if the call fails.
    /// - `Other`, if the bus returns an error reply.
    ///
    /// The local removal is not undone on failure.
    pub async fn remove_sig_match(&self, sig_match: &MatchRule) -> std::io::Result<()> {
        let mut recv_data = self.recv_data.lock().await;
        let idx = match recv_data.sig_matches.binary_search(sig_match) {
            Err(_) => {
                return Err(std::io::Error::new(
                    ErrorKind::InvalidInput,
                    "MatchRule doesn't exist!",
                ))
            }
            Ok(i) => i,
        };
        recv_data.sig_matches.remove(idx);
        drop(recv_data);
        let match_str = sig_match.match_string();
        let call = rustbus_core::standard_messages::remove_match(&match_str);
        let res = self.send_msg_w_rsp(&call).await?.await?;
        match res.typ {
            MessageType::Reply => Ok(()),
            MessageType::Error => {
                let err_str: &str = res
                    .body
                    .parser()
                    .get()
                    .unwrap_or("Unknown DBus Error Type!");
                Err(std::io::Error::new(ErrorKind::Other, err_str))
            }
            _ => unreachable!(),
        }
    }
    /// Reads messages from the socket until one satisfies `pred`, dispatching the others.
    ///
    /// Messages that do not satisfy `pred` are routed by type:
    /// - signals go to the matching signal queues;
    /// - replies and errors go to their registered reply slots, and are dropped if none exists;
    /// - method calls go to the call hierarchy.
    ///
    /// Returns `(msg, false)` for the message that satisfied `pred`. Returns `(msg, true)`
    /// when the hierarchy hands a message back for an incoming call; the caller must send
    /// that message (presumably a reply to the unhandled call).
    ///
    /// # Errors
    /// Propagates read errors, including `WouldBlock` when no complete message is available.
    fn queue_msg<F>(
        &self,
        recv_data: &mut RecvData,
        pred: F,
    ) -> std::io::Result<(MarshalledMessage, bool)>
    where
        F: Fn(&MarshalledMessage, &mut RecvData) -> bool,
    {
        let stream = self.conn.get_ref();
        loop {
            let msg = recv_data.state.get_next_message(stream)?;
            if pred(&msg, recv_data) {
                return Ok((msg, false));
            } else {
                match &msg.typ {
                    MessageType::Signal => queue_sig(&recv_data.sig_matches, msg),
                    MessageType::Reply | MessageType::Error => {
                        let idx = msg
                            .dynheader
                            .response_serial
                            .expect("Reply should always have a response serial!");
                        if let Some(sender) = recv_data.reply_map.remove(&idx) {
                            sender.send(msg).ok();
                        }
                    }
                    MessageType::Call => {
                        if let Err(msg) = recv_data.hierarchy.send(msg) {
                            return Ok((msg, true));
                        }
                    }
                    MessageType::Invalid => unreachable!(),
                }
            }
        }
    }
    /// Waits until `msg_fut` resolves or this task reads a message satisfying `pred`.
    ///
    /// While reading, non-matching messages are dispatched by
    /// [`queue_msg`](Self::queue_msg), which may fulfil other waiters, including
    /// `msg_fut`. Messages handed back for unhandled incoming calls are sent with
    /// `send_msg_wo_rsp`.
    ///
    /// `msg_fut` is polled at every await point. This matters when another task reads
    /// and delivers our message and then clears the socket's readiness: the delivery is
    /// still noticed even if no further data arrives.
    async fn get_msg<F, P>(&self, msg_fut: F, pred: P) -> std::io::Result<MarshalledMessage>
    where
        F: Future<Output = std::io::Result<MarshalledMessage>>,
        P: Fn(&MarshalledMessage, &mut RecvData) -> bool,
    {
        pin_mut!(msg_fut);
        let mut msg_fut = match poll_once(msg_fut) {
            Either::Left(res) => return res,
            Either::Right(f) => f,
        };
        loop {
            let read_res = tokio::select! {
                biased;
                res = &mut msg_fut => { return res }
                read_res = self.conn.readable() => read_res,
            };
            let mut read_guard = read_res?;
            tokio::select! {
                biased;
                res = &mut msg_fut => { return res }
                mut recv_guard = self.recv_data.lock() => match self.queue_msg(&mut recv_guard, &pred) {
                    Err(e) if e.kind() == ErrorKind::WouldBlock => {
                        read_guard.clear_ready();
                        continue;
                    }
                    Err(e) => { return Err(e) }
                    Ok((msg, should_reply)) => {
                        self.recv_watch.send_replace(());
                        if !should_reply {
                            return Ok(msg);
                        }
                        drop(recv_guard);
                        self.send_msg_wo_rsp(&msg).await?;
                    }
                }
            }
        }
    }
    /// Waits for the next signal matching a previously inserted rule.
    ///
    /// # Errors
    /// - `InvalidInput` if `sig_match` was never inserted.
    /// - `Interrupted` if the rule is removed while waiting.
    /// - Any I/O error from reading the socket.
    pub async fn get_signal(&self, sig_match: &MatchRule) -> std::io::Result<MarshalledMessage> {
        let recv_data = self.recv_data.lock().await;
        let idx = recv_data
            .sig_matches
            .binary_search(sig_match)
            .map_err(|_| {
                std::io::Error::new(ErrorKind::InvalidInput, "Unknown match rule given!")
            })?;
        let recv = recv_data.sig_matches[idx]
            .queue
            .as_ref()
            .unwrap()
            .get_receiver();
        drop(recv_data);
        let msg_fut = recv.recv().map_err(|_| {
            std::io::Error::new(
                ErrorKind::Interrupted,
                "Signal match was deleted while waiting!",
            )
        });
        let sig_pred = |msg: &MarshalledMessage, _: &mut RecvData| sig_match.matches(msg);
        self.get_msg(msg_fut, sig_pred).await
    }
    /// Waits for the next incoming method call routed to the queue of `path`.
    ///
    /// # Errors
    /// - `InvalidInput` if `path` is not a valid object path or has no call queue.
    /// - `Interrupted` if the queue is removed while waiting.
    /// - Any I/O error from reading the socket.
    ///
    /// NOTE(review): the predicate panics if an incoming call has no object-path header
    /// or the path fails to parse.
    pub async fn get_call<'a, S, D>(&self, path: S) -> std::io::Result<MarshalledMessage>
    where
        S: TryInto<&'a ObjectPath, Error = D>,
        D: std::fmt::Debug,
    {
        let path = path.try_into().map_err(|e| {
            std::io::Error::new(ErrorKind::InvalidInput, format!("Invalid path: {:?}", e))
        })?;
        let recv_guard = self.recv_data.lock().await;
        let msg_queue = recv_guard.hierarchy.get_queue(path).ok_or_else(|| {
            std::io::Error::new(ErrorKind::InvalidInput, "Unknown message path given!")
        })?;
        let recv = msg_queue.get_receiver();
        drop(recv_guard);
        let call_fut = recv.recv().map_err(|_| {
            std::io::Error::new(
                ErrorKind::Interrupted,
                "Call Queue was deleted while waiting!",
            )
        });
        let call_pred = |msg: &MarshalledMessage, recv_data: &mut RecvData| match &msg.typ {
            MessageType::Call => {
                let msg_path =
                    ObjectPath::from_str(msg.dynheader.object.as_ref().unwrap()).unwrap();
                recv_data.hierarchy.is_match(path, msg_path)
            }
            _ => false,
        };
        self.get_msg(call_fut, call_pred).await
    }
    /// Sets how incoming method calls for `path` are handled.
    ///
    /// # Errors
    /// Returns the conversion error if `path` cannot be turned into an [`ObjectPath`].
    pub async fn insert_call_path<'a, S, D>(&self, path: S, action: CallAction) -> Result<(), D>
    where
        S: TryInto<&'a ObjectPath, Error = D>,
    {
        let path = path.try_into()?;
        let mut recv_data = self.recv_data.lock().await;
        recv_data.hierarchy.insert_path(path, action);
        Ok(())
    }
    /// Returns the action configured for `path`.
    ///
    /// Returns `None` if the conversion fails or the hierarchy returns none.
    pub async fn get_call_path_action<'a, S: TryInto<&'a ObjectPath>>(
        &self,
        path: S,
    ) -> Option<CallAction> {
        let path = path.try_into().ok()?;
        let recv_data = self.recv_data.lock().await;
        recv_data.hierarchy.get_action(path)
    }
    /// Returns a receiver for the call queue of `path`.
    ///
    /// Returns `None` if the conversion fails or `path` has no queue.
    pub async fn get_call_recv<'a, S: TryInto<&'a ObjectPath>>(
        &self,
        path: S,
    ) -> Option<CReceiver<MarshalledMessage>> {
        let path = path.try_into().ok()?;
        let recv_data = self.recv_data.lock().await;
        Some(recv_data.hierarchy.get_queue(path)?.get_receiver())
    }
    /// Requests ownership of bus name `name` with `DBUS_NAME_FLAG_DO_NOT_QUEUE`.
    ///
    /// Returns `Ok(true)` if the reply code is `PRIMARY_OWNER` or `ALREADY_OWNER`.
    /// Returns `Ok(false)` for an error reply, any other code, or an unparsable body.
    pub async fn request_name(&self, name: &str) -> std::io::Result<bool> {
        let req = request_name(name, DBUS_NAME_FLAG_DO_NOT_QUEUE);
        let res = self.send_msg_w_rsp(&req).await?.await?;
        if MessageType::Error == res.typ {
            return Ok(false);
        }
        Ok(match res.body.parser().get::<u32>() {
            Ok(ret) => matches!(
                ret,
                DBUS_REQUEST_NAME_REPLY_ALREADY_OWNER | DBUS_REQUEST_NAME_REPLY_PRIMARY_OWNER
            ),
            Err(_) => false,
        })
    }
    /// Releases bus name `name`.
    ///
    /// Only I/O failures are reported; the content of the bus's reply (including an
    /// error reply) is ignored.
    pub async fn release_name(&self, name: &str) -> std::io::Result<()> {
        let rel_name = release_name(name);
        self.send_msg_w_rsp(&rel_name).await?.await?;
        Ok(())
    }
}
impl AsRawFd for RpcConn {
    /// Exposes the raw descriptor of the underlying bus socket.
    fn as_raw_fd(&self) -> RawFd {
        let sock = &self.conn;
        AsRawFd::as_raw_fd(sock)
    }
}
/// Future returned by `RpcConn::send_msg_w_rsp`, resolving to the reply of a call.
///
/// When dropped, it removes the reply slot for `idx` from `rpc_conn`.
struct ResponseFuture<'a, T>
where
    T: Future<Output = std::io::Result<MarshalledMessage>>,
{
    /// Connection the call was sent on.
    rpc_conn: &'a RpcConn,
    /// Serial of the outstanding call; its key in the connection's reply map.
    idx: NonZeroU32,
    /// Inner future that waits for (and, when needed, reads) the reply.
    fut: T,
}
impl<T> Future for ResponseFuture<'_, T>
where
    T: Future<Output = std::io::Result<MarshalledMessage>>,
{
    type Output = T::Output;
    /// Delegates to the wrapped future; `ResponseFuture` only adds cleanup on drop.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: structural pin projection onto `fut`. `fut` is never moved out of
        // `self` (the `Drop` impl only touches `rpc_conn` and `idx`), so it stays
        // pinned for as long as `self` is.
        unsafe { self.map_unchecked_mut(|s| &mut s.fut).poll(cx) }
    }
}
impl<T> Drop for ResponseFuture<'_, T>
where
    T: Future<Output = std::io::Result<MarshalledMessage>>,
{
    /// Removes the reply slot for `idx`, so the connection does not keep a sender
    /// for a reply nobody awaits.
    ///
    /// If the receive state is locked right now, the removal is deferred to a task
    /// spawned on the current Tokio runtime.
    ///
    /// `tokio::spawn` panics outside a runtime context, and a panic in `drop` during
    /// unwinding aborts the process. So when no runtime is current, the deferred
    /// removal is skipped and the (harmless) entry is left behind.
    fn drop(&mut self) {
        if let Ok(mut recv_lock) = self.rpc_conn.recv_data.try_lock() {
            recv_lock.reply_map.remove(&self.idx);
            return;
        }
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            let reply_arc = Arc::clone(&self.rpc_conn.recv_data);
            let idx = self.idx;
            handle.spawn(async move {
                let mut recv_lock = reply_arc.lock().await;
                recv_lock.reply_map.remove(&idx);
            });
        }
    }
}
/// Returns `true` when `msg` is a method call whose `NO_REPLY_EXPECTED` flag is clear,
/// i.e. the sender should wait for a reply.
fn expects_reply(msg: &MarshalledMessage) -> bool {
    if msg.typ != MessageType::Call {
        return false;
    }
    (msg.flags & NO_REPLY_EXPECTED) == 0
}