use std::{
any::{Any, TypeId},
cell::RefCell,
collections::{hash_map::Entry, HashMap},
convert::TryInto,
fmt::Write,
rc::Rc,
};
use scoped_tls::scoped_thread_local;
use zvariant::{ObjectPath, OwnedObjectPath, OwnedValue, Value};
use crate::{dbus_interface, fdo, Connection, Error, Message, MessageHeader, MessageType, Result};
// Scoped thread-local pointing at the `Node` currently being serviced. It is
// set by `ObjectServer::with` and by method-call dispatch, and read by the
// built-in interfaces and by `ObjectServer::local_node_emit_signal`.
scoped_thread_local!(static LOCAL_NODE: Node);
// Scoped thread-local pointing at the `Connection` of the `ObjectServer` that
// is currently servicing a node; `Node::emit_signal` reads it to send signals.
scoped_thread_local!(static LOCAL_CONNECTION: Connection);
/// A D-Bus interface that can be registered on an object served by
/// `ObjectServer`.
///
/// Every method that returns an `Option` uses `None` to signal "this
/// interface has no such property/method"; the callers in this file turn that
/// into the matching `UnknownProperty` / `UnknownMethod` D-Bus error.
pub trait Interface: Any {
/// Returns the D-Bus name of the interface. `Node` uses it as the key under
/// which the interface is registered.
fn name() -> &'static str
where
Self: Sized;
/// Reads the property `property_name`; `None` if the interface has no such
/// property.
fn get(&self, property_name: &str) -> Option<fdo::Result<OwnedValue>>;
/// Returns all properties of the interface, keyed by property name.
fn get_all(&self) -> HashMap<String, OwnedValue>;
/// Writes `value` to the property `property_name`; `None` if the interface
/// has no such property.
fn set(&mut self, property_name: &str, value: &Value<'_>) -> Option<fdo::Result<()>>;
/// Handles the method call `name` carried by `msg` with shared access to the
/// interface. `None` means no method was handled here; dispatch then tries
/// `call_mut`.
///
/// NOTE(review): the `u32` is presumably the serial number of the reply that
/// was sent on `connection` -- confirm against the macro-generated code.
fn call(&self, connection: &Connection, msg: &Message, name: &str) -> Option<Result<u32>>;
/// Same as `call`, but with exclusive access to the interface. `None` here
/// is reported by dispatch as an unknown method.
fn call_mut(
&mut self,
connection: &Connection,
msg: &Message,
name: &str,
) -> Option<Result<u32>>;
/// Writes the introspection description of the interface to `writer`.
/// `level` is the nesting level passed down by `Node` (it passes its own
/// level plus two).
fn introspect_to_writer(&self, writer: &mut dyn Write, level: usize);
}
impl dyn Interface {
/// Returns a reference to the concrete interface type `T` behind this trait
/// object, or `None` if the type identifier reported for the object is not
/// that of `T`.
fn downcast_ref<T: Any>(&self) -> Option<&T> {
if <dyn Interface as Any>::type_id(self) == TypeId::of::<T>() {
// SAFETY: the `TypeId` comparison above is the only guard for this cast;
// it is relied upon to establish that the value behind the trait object
// is a `T`, so reinterpreting the data pointer as `*const T` is valid for
// the lifetime of `self`.
Some(unsafe { &*(self as *const dyn Interface as *const T) })
} else {
None
}
}
}
struct Introspectable;
#[dbus_interface(name = "org.freedesktop.DBus.Introspectable")]
impl Introspectable {
fn introspect(&self) -> String {
LOCAL_NODE.with(|node| node.introspect())
}
}
struct Peer;
#[dbus_interface(name = "org.freedesktop.DBus.Peer")]
impl Peer {
fn ping(&self) {}
fn get_machine_id(&self) -> fdo::Result<String> {
let mut id = match std::fs::read_to_string("/var/lib/dbus/machine-id") {
Ok(id) => id,
Err(e) => {
if let Ok(id) = std::fs::read_to_string("/etc/machine-id") {
id
} else {
return Err(fdo::Error::IOError(format!(
"Failed to read from /var/lib/dbus/machine-id or /etc/machine-id: {}",
e
)));
}
}
};
let len = id.trim_end().len();
id.truncate(len);
Ok(id)
}
}
struct Properties;
#[dbus_interface(name = "org.freedesktop.DBus.Properties")]
impl Properties {
fn get(&self, interface_name: &str, property_name: &str) -> fdo::Result<OwnedValue> {
LOCAL_NODE.with(|node| {
let iface = node.get_interface(interface_name).ok_or_else(|| {
fdo::Error::UnknownInterface(format!("Unknown interface '{}'", interface_name))
})?;
let res = iface.borrow().get(property_name);
res.ok_or_else(|| {
fdo::Error::UnknownProperty(format!("Unknown property '{}'", property_name))
})?
})
}
fn set(
&mut self,
interface_name: &str,
property_name: &str,
value: OwnedValue,
) -> fdo::Result<()> {
LOCAL_NODE.with(|node| {
let iface = node.get_interface(interface_name).ok_or_else(|| {
fdo::Error::UnknownInterface(format!("Unknown interface '{}'", interface_name))
})?;
let res = iface.borrow_mut().set(property_name, &value);
res.ok_or_else(|| {
fdo::Error::UnknownProperty(format!("Unknown property '{}'", property_name))
})?
})
}
fn get_all(&self, interface_name: &str) -> fdo::Result<HashMap<String, OwnedValue>> {
LOCAL_NODE.with(|node| {
let iface = node.get_interface(interface_name).ok_or_else(|| {
fdo::Error::UnknownInterface(format!("Unknown interface '{}'", interface_name))
})?;
let res = iface.borrow().get_all();
Ok(res)
})
}
#[dbus_interface(signal)]
fn properties_changed(
&self,
interface_name: &str,
changed_properties: &HashMap<&str, &Value<'_>>,
invalidated_properties: &[&str],
) -> Result<()>;
}
/// One object in the server's object tree: its path, its child objects and
/// the interfaces registered on it.
#[derive(Default, derivative::Derivative)]
#[derivative(Debug)]
struct Node {
/// Full object path of this node.
path: OwnedObjectPath,
/// Child nodes, keyed by their single path segment.
children: HashMap<String, Node>,
/// Registered interfaces, keyed by interface name. Left out of the `Debug`
/// output.
#[derivative(Debug = "ignore")]
interfaces: HashMap<&'static str, Rc<RefCell<dyn Interface>>>,
}
impl Node {
fn new(path: OwnedObjectPath) -> Self {
let mut node = Self {
path,
..Default::default()
};
node.at(Peer::name(), Peer);
node.at(Introspectable::name(), Introspectable);
node.at(Properties::name(), Properties);
node
}
/// Returns a new shared handle to the interface registered under `iface`,
/// or `None` if there is none.
fn get_interface(&self, iface: &str) -> Option<Rc<RefCell<dyn Interface>>> {
    self.interfaces.get(iface).map(Rc::clone)
}
/// Unregisters the interface named `iface`; returns whether it was present.
fn remove_interface(&mut self, iface: &str) -> bool {
    match self.interfaces.remove(iface) {
        Some(_) => true,
        None => false,
    }
}
/// Returns `true` when the only interfaces left on this node are the
/// built-in ones (`Peer`, `Introspectable`, `Properties`). Child nodes are
/// not taken into account.
fn is_empty(&self) -> bool {
    let builtin = [Peer::name(), Introspectable::name(), Properties::name()];
    self.interfaces.keys().all(|name| builtin.contains(name))
}
/// Drops the child stored under the path segment `node`; returns whether
/// such a child existed.
fn remove_node(&mut self, node: &str) -> bool {
    let removed = self.children.remove(node);
    removed.is_some()
}
/// Registers `iface` under `name`. Returns `false` and leaves the node
/// untouched if an interface with that name is already registered.
fn at<I>(&mut self, name: &'static str, iface: I) -> bool
where
    I: Interface,
{
    match self.interfaces.entry(name) {
        Entry::Occupied(_) => false,
        Entry::Vacant(slot) => {
            slot.insert(Rc::new(RefCell::new(iface)));
            true
        }
    }
}
fn with_iface_func<F, I>(&self, func: F) -> Result<()>
where
F: Fn(&I) -> Result<()>,
I: Interface,
{
let iface = self
.interfaces
.get(I::name())
.ok_or(Error::InterfaceNotFound)?
.borrow();
let iface = iface.downcast_ref::<I>().ok_or(Error::InterfaceNotFound)?;
func(iface)
}
/// Writes the introspection XML of this node, its interfaces and all of its
/// descendants to `writer`.
///
/// `level` is the current indentation; at level 0 the DOCTYPE header and the
/// enclosing `<node>` element are written as well. Panics if writing fails.
fn introspect_to_writer<W: Write>(&self, writer: &mut W, level: usize) {
    let is_root = level == 0;
    if is_root {
        writeln!(
            writer,
            r#"
<!DOCTYPE node PUBLIC "-//freedesktop//DTD D-BUS Object Introspection 1.0//EN"
"http://www.freedesktop.org/standards/dbus/1.0/introspect.dtd">
<node>"#
        )
        .unwrap();
    }
    // Interfaces and child elements are nested one step (two columns) deeper.
    let nested = level + 2;
    for registered in self.interfaces.values() {
        registered.borrow().introspect_to_writer(writer, nested);
    }
    for (segment, child) in &self.children {
        writeln!(
            writer,
            "{:indent$}<node name=\"{}\">",
            "",
            segment,
            indent = nested
        )
        .unwrap();
        child.introspect_to_writer(writer, nested);
        writeln!(writer, "{:indent$}</node>", "", indent = nested).unwrap();
    }
    if is_root {
        writeln!(writer, "</node>").unwrap();
    }
}
/// Returns the complete introspection XML document for this node.
fn introspect(&self) -> String {
    let mut document = String::with_capacity(1024);
    self.introspect_to_writer(&mut document, 0);
    document
}
/// Emits the signal `signal_name` of interface `iface` from this node's path
/// on the connection stored in `LOCAL_CONNECTION`.
///
/// # Panics
///
/// Panics if `LOCAL_CONNECTION` is not set.
fn emit_signal<B>(
    &self,
    dest: Option<&str>,
    iface: &str,
    signal_name: &str,
    body: &B,
) -> Result<()>
where
    B: serde::ser::Serialize + zvariant::Type,
{
    assert!(
        LOCAL_CONNECTION.is_set(),
        "emit_signal: Connection TLS not set"
    );
    let object_path = self.path.as_str();
    LOCAL_CONNECTION.with(|conn| conn.emit_signal(dest, object_path, iface, signal_name, body))
}
}
/// Serves a tree of objects, each exposing `Interface` implementations, over
/// a `Connection`.
#[derive(Debug)]
pub struct ObjectServer {
/// Connection used to receive calls and to send replies and signals.
conn: Connection,
/// Root (`/`) of the object tree.
root: Node,
}
impl ObjectServer {
pub fn new(connection: &Connection) -> Self {
Self {
conn: connection.clone(),
root: Node::new("/".try_into().expect("zvariant bug")),
}
}
fn get_node(&self, path: &ObjectPath<'_>) -> Option<&Node> {
let mut node = &self.root;
let mut node_path = String::new();
for i in path.split('/').skip(1) {
if i.is_empty() {
continue;
}
write!(&mut node_path, "/{}", i).unwrap();
match node.children.get(i) {
Some(n) => node = n,
None => return None,
}
}
Some(node)
}
/// Mutable lookup of the node at `path`, walking the tree from the root.
///
/// With `create` set, every missing node along the way is created (each one
/// gets the object path accumulated so far) and the result is always `Some`.
/// Without it, `None` is returned at the first missing segment.
///
/// # Panics
///
/// Panics if an accumulated path is not a valid object path.
fn get_node_mut(&mut self, path: &ObjectPath<'_>, create: bool) -> Option<&mut Node> {
    let mut current = &mut self.root;
    let mut prefix = String::new();
    for segment in path.split('/').skip(1).filter(|s| !s.is_empty()) {
        write!(&mut prefix, "/{}", segment).unwrap();
        current = match current.children.entry(segment.into()) {
            Entry::Occupied(slot) => slot.into_mut(),
            Entry::Vacant(_) if !create => return None,
            Entry::Vacant(slot) => {
                let node_path = prefix.as_str().try_into().expect("Invalid Object Path");
                slot.insert(Node::new(node_path))
            }
        };
    }
    Some(current)
}
/// Registers `iface` on the object at `path`, creating the object (and any
/// missing ancestors) first.
///
/// Returns `Ok(false)` if an interface with the same name is already
/// registered there, and an error if `path` is not a valid object path.
pub fn at<'p, P, I>(&mut self, path: P, iface: I) -> Result<bool>
where
    I: Interface,
    P: TryInto<ObjectPath<'p>, Error = zvariant::Error>,
{
    let object_path = path.try_into()?;
    // Lookup with `create == true` always yields a node.
    let node = self.get_node_mut(&object_path, true).unwrap();
    Ok(node.at(I::name(), iface))
}
/// Unregisters interface `I` from the object at `path`.
///
/// If that leaves the object with nothing but the built-in interfaces and no
/// child objects, the object itself is removed from its parent and
/// `Ok(true)` is returned; otherwise `Ok(false)`. The root object is never
/// removed.
///
/// # Errors
///
/// Returns `Error::InterfaceNotFound` if there is no object at `path` or `I`
/// is not registered on it, and a conversion error if `path` is not a valid
/// object path.
pub fn remove<'p, I, P>(&mut self, path: P) -> Result<bool>
where
    I: Interface,
    P: TryInto<ObjectPath<'p>, Error = zvariant::Error>,
{
    let path = path.try_into()?;
    let node = self
        .get_node_mut(&path, false)
        .ok_or(Error::InterfaceNotFound)?;
    if !node.remove_interface(I::name()) {
        return Err(Error::InterfaceNotFound);
    }
    // Keep the node while it still serves user interfaces, and also while it
    // has children: pruning it would silently drop every object below it.
    if !node.is_empty() || !node.children.is_empty() {
        return Ok(false);
    }
    let mut path_parts = path.rsplit('/').filter(|i| !i.is_empty());
    let last_part = match path_parts.next() {
        Some(part) => part,
        // `/` has no segments: the root has no parent to be removed from.
        None => return Ok(false),
    };
    // Rebuild the parent path from the remaining segments (they come in
    // reverse order). For a direct child of the root this is the empty
    // string, which the lookup resolves to the root.
    let ppath = ObjectPath::from_string_unchecked(
        path_parts.fold(String::new(), |a, p| format!("/{}{}", p, a)),
    );
    self.get_node_mut(&ppath, false)
        .expect("parent of an existing node must exist")
        .remove_node(last_part);
    Ok(true)
}
pub fn with<'p, P, F, I>(&self, path: P, func: F) -> Result<()>
where
F: Fn(&I) -> Result<()>,
I: Interface,
P: TryInto<ObjectPath<'p>, Error = zvariant::Error>,
{
let path = path.try_into()?;
let node = self.get_node(&path).ok_or(Error::InterfaceNotFound)?;
LOCAL_CONNECTION.set(&self.conn, || {
LOCAL_NODE.set(node, || node.with_iface_func(func))
})
}
/// Emits a signal from the node stored in `LOCAL_NODE`.
///
/// # Panics
///
/// Panics if `LOCAL_NODE` is not set.
pub fn local_node_emit_signal<B>(
    destination: Option<&str>,
    iface: &str,
    signal_name: &str,
    body: &B,
) -> Result<()>
where
    B: serde::ser::Serialize + zvariant::Type,
{
    assert!(LOCAL_NODE.is_set(), "emit_signal: Node TLS not set");
    LOCAL_NODE.with(|current| current.emit_signal(destination, iface, signal_name, body))
}
fn dispatch_method_call_try(
&mut self,
msg_header: &MessageHeader<'_>,
msg: &Message,
) -> fdo::Result<Result<u32>> {
let conn = self.conn.clone();
let path = msg_header
.path()
.ok()
.flatten()
.ok_or_else(|| fdo::Error::Failed("Missing object path".into()))?;
let iface = msg_header
.interface()
.ok()
.flatten()
.ok_or_else(|| fdo::Error::Failed("Missing interface".into()))?;
let member = msg_header
.member()
.ok()
.flatten()
.ok_or_else(|| fdo::Error::Failed("Missing member".into()))?;
let node = self
.get_node_mut(&path, false)
.ok_or_else(|| fdo::Error::UnknownObject(format!("Unknown object '{}'", path)))?;
let iface = node.get_interface(iface).ok_or_else(|| {
fdo::Error::UnknownInterface(format!("Unknown interface '{}'", iface))
})?;
LOCAL_CONNECTION.set(&conn, || {
LOCAL_NODE.set(node, || {
let res = iface.borrow().call(&conn, &msg, member);
res.or_else(|| iface.borrow_mut().call_mut(&conn, &msg, member))
.ok_or_else(|| {
fdo::Error::UnknownMethod(format!("Unknown method '{}'", member))
})
})
})
}
/// Dispatches a method call; a routing failure is turned into an error reply
/// to `msg` on the server's connection.
fn dispatch_method_call(
    &mut self,
    msg_header: &MessageHeader<'_>,
    msg: &Message,
) -> Result<u32> {
    let routed = self.dispatch_method_call_try(msg_header, msg);
    routed.unwrap_or_else(|routing_err| routing_err.reply(&self.conn, msg))
}
/// Dispatches `msg` if it is a method call.
///
/// Returns `Ok(true)` when the message was a method call and was handled,
/// `Ok(false)` for any other message type.
pub fn dispatch_message(&mut self, msg: &Message) -> Result<bool> {
    let header = msg.header()?;
    if let MessageType::MethodCall = header.message_type()? {
        self.dispatch_method_call(&header, msg)?;
        return Ok(true);
    }
    Ok(false)
}
/// Receives the next message from the connection and dispatches it.
///
/// Returns `Ok(None)` if the message was handled, or hands the message back
/// as `Ok(Some(msg))` if it was not a method call.
pub fn try_handle_next(&mut self) -> Result<Option<Message>> {
    let msg = self.conn.receive_message()?;
    let handled = self.dispatch_message(&msg)?;
    Ok(if handled { None } else { Some(msg) })
}
}
#[cfg(test)]
mod tests {
use std::{cell::Cell, collections::HashMap, error::Error, rc::Rc, thread};
use ntest::timeout;
use serde::{Deserialize, Serialize};
use zvariant::derive::Type;
use crate::{
dbus_interface, dbus_proxy, fdo, Connection, MessageHeader, MessageType, ObjectServer,
};
/// Struct used by the tests both as a single method argument and as a single
/// return value.
#[derive(Deserialize, Serialize, Type)]
pub struct ArgStructTest {
foo: i32,
bar: String,
}
// Client-side view of the test interface; each item mirrors a method or
// property of the same name on `MyIfaceImpl` below.
#[dbus_proxy]
trait MyIface {
fn ping(&self) -> zbus::Result<u32>;
fn quit(&self) -> zbus::Result<()>;
fn test_header(&self) -> zbus::Result<()>;
fn test_error(&self) -> zbus::Result<()>;
fn test_single_struct_arg(&self, arg: ArgStructTest) -> zbus::Result<()>;
fn test_single_struct_ret(&self) -> zbus::Result<ArgStructTest>;
fn test_multi_ret(&self) -> zbus::Result<(i32, String)>;
fn test_hashmap_return(&self) -> zbus::Result<HashMap<String, String>>;
fn create_obj(&self, key: &str) -> zbus::Result<()>;
fn destroy_obj(&self, key: &str) -> zbus::Result<()>;
// Property accessors.
#[dbus_proxy(property)]
fn count(&self) -> zbus::Result<u32>;
#[dbus_proxy(property)]
fn set_count(&self, count: u32) -> zbus::Result<()>;
#[dbus_proxy(property)]
fn hash_map(&self) -> zbus::Result<HashMap<String, String>>;
}
/// What the server loop in `basic_iface` should do after handling the current
/// message; set by the interface methods of `MyIfaceImpl`.
#[derive(Debug, Clone)]
enum NextAction {
/// Keep serving.
Nothing,
/// Leave the server loop.
Quit,
/// Register a new `MyIfaceImpl` at `/zbus/test/<key>`.
CreateObj(String),
/// Unregister the `MyIfaceImpl` at `/zbus/test/<key>`.
DestroyObj(String),
}
/// Server-side implementation of the test interface.
struct MyIfaceImpl {
/// Shared slot through which methods tell the server loop what to do next.
action: Rc<Cell<NextAction>>,
/// Counter incremented by `ping` and exposed as the `count` property.
count: u32,
}
impl MyIfaceImpl {
/// Creates an instance with a zero counter that reports through `action`.
fn new(action: Rc<Cell<NextAction>>) -> Self {
Self { action, count: 0 }
}
}
// D-Bus face of `MyIfaceImpl`, exported as `org.freedesktop.MyIface`.
#[dbus_interface(interface = "org.freedesktop.MyIface")]
impl MyIfaceImpl {
// Increments the counter and returns it; every third call also emits the
// `alert_count` signal.
fn ping(&mut self) -> u32 {
self.count += 1;
if self.count % 3 == 0 {
self.alert_count(self.count).expect("Failed to emit signal");
}
self.count
}
// Asks the server loop to stop.
fn quit(&mut self) {
self.action.set(NextAction::Quit);
}
// Checks that the header of the incoming call is passed in.
fn test_header(&self, #[zbus(header)] header: MessageHeader<'_>) {
assert_eq!(header.message_type().unwrap(), MessageType::MethodCall);
assert_eq!(header.member().unwrap(), Some("TestHeader"));
}
// Always fails.
fn test_error(&self) -> zbus::fdo::Result<()> {
Err(zbus::fdo::Error::Failed("error raised".to_string()))
}
// Checks the struct argument sent by `my_iface_test`.
fn test_single_struct_arg(&self, arg: ArgStructTest) {
assert_eq!(arg.foo, 1);
assert_eq!(arg.bar, "TestString");
}
// The introspection check in `my_iface_test` expects a single `(is)` out
// argument for this method.
#[dbus_interface(struct_return)]
fn test_single_struct_ret(&self) -> zbus::Result<ArgStructTest> {
Ok(ArgStructTest {
foo: 42,
bar: String::from("Meaning of life"),
})
}
// The introspection check in `my_iface_test` expects two out arguments
// named `foo` and `bar` for this method.
#[dbus_interface(out_args("foo", "bar"))]
fn test_multi_ret(&self) -> zbus::Result<(i32, String)> {
Ok((42, String::from("Meaning of life")))
}
// Returns the fixed map verified by `check_hash_map`.
fn test_hashmap_return(&self) -> zbus::Result<HashMap<String, String>> {
let mut map = HashMap::new();
map.insert("hi".into(), "hello".into());
map.insert("bye".into(), "now".into());
Ok(map)
}
// Asks the server loop to register a new object for `key`.
fn create_obj(&self, key: String) {
self.action.set(NextAction::CreateObj(key));
}
// Asks the server loop to unregister the object for `key`.
fn destroy_obj(&self, key: String) {
self.action.set(NextAction::DestroyObj(key));
}
// Property setter; rejects the value 42 with `InvalidArgs`.
#[dbus_interface(property)]
fn set_count(&mut self, val: u32) -> zbus::fdo::Result<()> {
if val == 42 {
return Err(zbus::fdo::Error::InvalidArgs("Tsss tsss!".to_string()));
}
self.count = val;
Ok(())
}
// Property getter for the counter.
#[dbus_interface(property)]
fn count(&self) -> u32 {
self.count
}
// Property returning the same map as `test_hashmap_return`.
#[dbus_interface(property)]
fn hash_map(&self) -> HashMap<String, String> {
self.test_hashmap_return().unwrap()
}
// Signal declaration; the body is generated by `dbus_interface`.
#[dbus_interface(signal)]
fn alert_count(&self, val: u32) -> zbus::Result<()>;
}
/// Asserts that `map` holds the two entries produced by
/// `test_hashmap_return`. Panics if a key is missing or a value differs.
fn check_hash_map(map: HashMap<String, String>) {
    let expected = [("hi", "hello"), ("bye", "now")];
    for (key, value) in expected.iter() {
        assert_eq!(map[*key], *value);
    }
}
// Client half of the `basic_iface` test: exercises the service through the
// generated proxy and returns the value of the second `ping` call.
fn my_iface_test() -> std::result::Result<u32, Box<dyn Error>> {
let conn = Connection::new_session()?;
let proxy = MyIfaceProxy::new_for(
&conn,
"org.freedesktop.MyService",
"/org/freedesktop/MyService",
)?;
// First ping: the counter goes from 0 to 1.
proxy.ping()?;
assert_eq!(proxy.count()?, 1);
proxy.test_header()?;
proxy.test_single_struct_arg(ArgStructTest {
foo: 1,
bar: "TestString".into(),
})?;
check_hash_map(proxy.test_hashmap_return()?);
check_hash_map(proxy.hash_map()?);
// With the `xml` feature, also verify the introspected out arguments of the
// two methods that customise their return signature.
#[cfg(feature = "xml")]
{
let xml = proxy.introspect()?;
let node = crate::xml::Node::from_reader(xml.as_bytes())?;
let ifaces = node.interfaces();
let iface = ifaces
.iter()
.find(|i| i.name() == "org.freedesktop.MyIface")
.unwrap();
let methods = iface.methods();
for method in methods {
if method.name() != "TestSingleStructRet" && method.name() != "TestMultiRet" {
continue;
}
let args = method.args();
let mut out_args = args.iter().filter(|a| a.direction().unwrap() == "out");
if method.name() == "TestSingleStructRet" {
assert_eq!(args.len(), 1);
assert_eq!(out_args.next().unwrap().ty(), "(is)");
assert!(out_args.next().is_none());
} else {
assert_eq!(args.len(), 2);
let foo = out_args.find(|a| a.name() == Some("foo")).unwrap();
assert_eq!(foo.ty(), "i");
let bar = out_args.find(|a| a.name() == Some("bar")).unwrap();
assert_eq!(bar.ty(), "s");
}
}
}
let _ = proxy.test_single_struct_ret()?.foo;
let _ = proxy.test_multi_ret()?.1;
// Second ping; `basic_iface` asserts that this value is 2.
let val = proxy.ping()?;
// Have the server register a second object, talk to it, then have it
// removed again and check that it is no longer reachable.
proxy.create_obj("MyObj")?;
let my_obj_proxy =
MyIfaceProxy::new_for(&conn, "org.freedesktop.MyService", "/zbus/test/MyObj")?;
my_obj_proxy.ping()?;
proxy.destroy_obj("MyObj")?;
assert!(my_obj_proxy.introspect().is_err());
assert!(my_obj_proxy.ping().is_err());
// Tell the server loop to stop.
proxy.quit()?;
Ok(val)
}
// Server half of the test: registers `MyIfaceImpl`, runs `my_iface_test` on a
// second thread and serves its calls until the client asks to quit.
#[test]
#[timeout(2000)]
fn basic_iface() {
let conn = Connection::new_session().unwrap();
let mut object_server = ObjectServer::new(&conn);
let action = Rc::new(Cell::new(NextAction::Nothing));
fdo::DBusProxy::new(&conn)
.unwrap()
.request_name(
"org.freedesktop.MyService",
fdo::RequestNameFlags::ReplaceExisting.into(),
)
.unwrap();
let iface = MyIfaceImpl::new(action.clone());
object_server
.at("/org/freedesktop/MyService", iface)
.unwrap();
let child = thread::spawn(|| my_iface_test().expect("child failed"));
loop {
let m = conn.receive_message().unwrap();
if let Err(e) = object_server.dispatch_message(&m) {
eprintln!("{}", e);
}
// Exercise `ObjectServer::with` by emitting a signal after every message.
object_server
.with("/org/freedesktop/MyService", |iface: &MyIfaceImpl| {
iface.alert_count(51)
})
.unwrap();
// Carry out whatever the last method call requested and reset the slot.
match action.replace(NextAction::Nothing) {
NextAction::Nothing => (),
NextAction::Quit => break,
NextAction::CreateObj(key) => {
let path = format!("/zbus/test/{}", key);
object_server
.at(path, MyIfaceImpl::new(action.clone()))
.unwrap();
}
NextAction::DestroyObj(key) => {
let path = format!("/zbus/test/{}", key);
object_server.remove::<MyIfaceImpl, _>(path).unwrap();
}
}
}
let val = child.join().expect("failed to join");
assert_eq!(val, 2);
}
}