use event_listener::{Event, EventListener};
use serde::Serialize;
use std::{
collections::{hash_map::Entry, HashMap},
convert::TryInto,
fmt::Write,
marker::PhantomData,
ops::{Deref, DerefMut},
sync::Arc,
};
use tracing::{debug, instrument, trace};
use static_assertions::assert_impl_all;
use zbus_names::InterfaceName;
use zvariant::{ObjectPath, OwnedObjectPath, OwnedValue, Signature, Type, Value};
use crate::{
async_lock::{RwLock, RwLockReadGuard, RwLockWriteGuard},
fdo,
fdo::{Introspectable, ManagedObjects, ObjectManager, Peer, Properties},
Connection, DispatchResult, Error, Interface, Message, Result, SignalContext, WeakConnection,
};
/// Read guard for an interface of concrete type `I`, returned by [`InterfaceRef::get`].
///
/// The interface's read lock stays held for as long as this guard is alive; the guard
/// dereferences to `&I`.
pub struct InterfaceDeref<'d, I> {
    iface: RwLockReadGuard<'d, dyn Interface>,
    phantom: PhantomData<I>,
}

impl<I> Deref for InterfaceDeref<'_, I>
where
    I: Interface,
{
    type Target = I;

    /// Downcast the type-erased interface behind the guard to `I`.
    ///
    /// Panics if the interface is not an `I` (`InterfaceRef::get` verifies the type before
    /// constructing the guard).
    fn deref(&self) -> &Self::Target {
        let erased = &*self.iface;
        erased.downcast_ref::<I>().unwrap()
    }
}
/// Write guard for an interface of concrete type `I`, returned by [`InterfaceRef::get_mut`].
///
/// The interface's write lock stays held for as long as this guard is alive; the guard
/// dereferences to `I` both immutably and mutably.
pub struct InterfaceDerefMut<'d, I> {
    iface: RwLockWriteGuard<'d, dyn Interface>,
    phantom: PhantomData<I>,
}

impl<I> Deref for InterfaceDerefMut<'_, I>
where
    I: Interface,
{
    type Target = I;

    /// Downcast the type-erased interface behind the guard to `&I`.
    ///
    /// Panics if the interface is not an `I`.
    fn deref(&self) -> &Self::Target {
        let erased = &*self.iface;
        erased.downcast_ref::<I>().unwrap()
    }
}

impl<I> DerefMut for InterfaceDerefMut<'_, I>
where
    I: Interface,
{
    /// Downcast the type-erased interface behind the guard to `&mut I`.
    ///
    /// Panics if the interface is not an `I`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let erased = &mut *self.iface;
        erased.downcast_mut::<I>().unwrap()
    }
}
/// A handle to an interface of type `I` registered on an [`ObjectServer`].
///
/// Obtained from [`ObjectServer::interface`]. It shares the interface's lock with the object
/// server, so access through it observes the same instance that incoming method calls are
/// dispatched to. It also carries a [`SignalContext`] for the object path the interface was
/// looked up at.
pub struct InterfaceRef<I> {
    ctxt: SignalContext<'static>,
    lock: Arc<RwLock<dyn Interface>>,
    phantom: PhantomData<I>,
}

impl<I> InterfaceRef<I>
where
    I: 'static,
{
    /// Acquire a read lock on the interface and return a guard that dereferences to `&I`.
    ///
    /// # Panics
    ///
    /// Panics if the registered interface is not of type `I`.
    pub async fn get(&self) -> InterfaceDeref<'_, I> {
        let iface = self.lock.read().await;
        // Validate the type up front so a mismatch panics here with a descriptive message
        // rather than later, on first dereference.
        iface
            .downcast_ref::<I>()
            .expect("Unexpected interface type");
        InterfaceDeref {
            iface,
            phantom: PhantomData,
        }
    }

    /// Acquire a write lock on the interface and return a guard that dereferences to `&mut I`.
    ///
    /// The write lock is held for as long as the returned guard is alive.
    ///
    /// # Panics
    ///
    /// Panics if the registered interface is not of type `I`.
    pub async fn get_mut(&self) -> InterfaceDerefMut<'_, I> {
        let mut iface = self.lock.write().await;
        // One mutable downcast is enough to validate the type; a preceding shared downcast
        // would only repeat the same check.
        iface
            .downcast_mut::<I>()
            .expect("Unexpected interface type");
        InterfaceDerefMut {
            iface,
            phantom: PhantomData,
        }
    }

    /// The signal context for the object path this interface was looked up at.
    pub fn signal_context(&self) -> &SignalContext<'static> {
        &self.ctxt
    }
}

impl<I> Clone for InterfaceRef<I> {
    /// Clone the handle; the clone refers to the same interface instance (the lock is shared).
    fn clone(&self) -> Self {
        Self {
            ctxt: self.ctxt.clone(),
            lock: Arc::clone(&self.lock),
            phantom: PhantomData,
        }
    }
}
/// A node in the object server's tree.
///
/// Each node records its own object path, its child nodes keyed by path element, and the
/// interfaces registered on it keyed by interface name. Nodes built with [`Node::new`] start
/// out with the `Peer`, `Introspectable` and `Properties` interfaces.
#[derive(Default, derivative::Derivative)]
#[derivative(Debug)]
pub(crate) struct Node {
    path: OwnedObjectPath,
    children: HashMap<String, Node>,
    #[derivative(Debug = "ignore")]
    interfaces: HashMap<InterfaceName<'static>, Arc<RwLock<dyn Interface>>>,
}

impl Node {
    /// Create a node at `path`, pre-populated with the `Peer`, `Introspectable` and
    /// `Properties` interfaces.
    pub(crate) fn new(path: OwnedObjectPath) -> Self {
        let mut node = Self {
            path,
            ..Default::default()
        };
        node.at(Peer::name(), || Arc::new(RwLock::new(Peer)));
        node.at(Introspectable::name(), || {
            Arc::new(RwLock::new(Introspectable))
        });
        node.at(Properties::name(), || Arc::new(RwLock::new(Properties)));
        node
    }

    /// Whether `name` is one of the standard interfaces (`Peer`, `Introspectable`,
    /// `Properties`, `ObjectManager`) that `is_empty` ignores and `get_managed_objects`
    /// leaves out of its result.
    fn is_standard_interface(name: &InterfaceName<'static>) -> bool {
        *name == Peer::name()
            || *name == Introspectable::name()
            || *name == Properties::name()
            || *name == ObjectManager::name()
    }

    /// Look up the descendant node at `path`.
    ///
    /// Empty path elements are skipped, so `/` resolves to `self`. Returns `None` as soon as a
    /// path element has no matching child.
    pub(crate) fn get_child(&self, path: &ObjectPath<'_>) -> Option<&Node> {
        let mut node = self;
        for element in path.split('/').skip(1).filter(|e| !e.is_empty()) {
            node = node.children.get(element)?;
        }
        Some(node)
    }

    /// Walk to the descendant node at `path`, optionally creating missing nodes on the way.
    ///
    /// When `create` is `true`, every missing node is created with [`Node::new`]; when it is
    /// `false`, the walk stops at the first missing element and the node part of the result is
    /// `None`.
    ///
    /// The second element of the result is the path of the deepest node visited *above* the
    /// target (the target itself is never checked) that has the `ObjectManager` interface, if
    /// any.
    fn get_child_mut(
        &mut self,
        path: &ObjectPath<'_>,
        create: bool,
    ) -> (Option<&mut Node>, Option<ObjectPath<'_>>) {
        let mut node = self;
        let mut node_path = String::new();
        let mut obj_manager_path = None;

        for i in path.split('/').skip(1) {
            if i.is_empty() {
                continue;
            }

            // Checked before descending, so later (deeper) ancestors overwrite earlier ones.
            if node.interfaces.contains_key(&ObjectManager::name()) {
                obj_manager_path = Some((*node.path).clone());
            }

            write!(&mut node_path, "/{i}").unwrap();
            match node.children.entry(i.into()) {
                Entry::Vacant(e) => {
                    if create {
                        let path = node_path.as_str().try_into().expect("Invalid Object Path");
                        node = e.insert(Node::new(path));
                    } else {
                        return (None, obj_manager_path);
                    }
                }
                Entry::Occupied(e) => node = e.into_mut(),
            }
        }

        (Some(node), obj_manager_path)
    }

    /// The shared lock of the interface named `interface_name` on this node, if registered.
    pub(crate) fn interface_lock(
        &self,
        interface_name: InterfaceName<'_>,
    ) -> Option<Arc<RwLock<dyn Interface>>> {
        self.interfaces.get(&interface_name).cloned()
    }

    /// Unregister the interface named `interface_name`; returns whether it was present.
    fn remove_interface(&mut self, interface_name: InterfaceName<'static>) -> bool {
        self.interfaces.remove(&interface_name).is_some()
    }

    /// Whether this node has no interfaces other than the standard ones.
    ///
    /// Only interfaces are considered; child nodes are not.
    fn is_empty(&self) -> bool {
        self.interfaces.keys().all(Self::is_standard_interface)
    }

    /// Remove the direct child named `node` (a single path element); returns whether it existed.
    fn remove_node(&mut self, node: &str) -> bool {
        self.children.remove(node).is_some()
    }

    /// Register the interface produced by `iface_creator` under `name`, unless an interface
    /// with that name is already present.
    ///
    /// `iface_creator` is only invoked when the name is free. Returns `true` if the interface
    /// was added.
    fn at<F>(&mut self, name: InterfaceName<'static>, iface_creator: F) -> bool
    where
        F: FnOnce() -> Arc<RwLock<dyn Interface>>,
    {
        match self.interfaces.entry(name) {
            Entry::Vacant(e) => {
                e.insert(iface_creator());
                true
            }
            Entry::Occupied(_) => false,
        }
    }

    /// Append introspection XML for this node, its interfaces and (recursively) its children
    /// to `writer`.
    ///
    /// At `level == 0` the DOCTYPE header and the enclosing `<node>` element are written as
    /// well. `level` is used as the indentation width of nested elements, which are written at
    /// `level + 2`.
    ///
    /// # Panics
    ///
    /// Panics if writing to `writer` fails.
    #[async_recursion::async_recursion]
    async fn introspect_to_writer<W: Write + Send>(&self, writer: &mut W, level: usize) {
        if level == 0 {
            writeln!(
                writer,
                r#"
<!DOCTYPE node PUBLIC "-//freedesktop//DTD D-BUS Object Introspection 1.0//EN"
"http://www.freedesktop.org/standards/dbus/1.0/introspect.dtd">
<node>"#
            )
            .unwrap();
        }

        for iface in self.interfaces.values() {
            iface.read().await.introspect_to_writer(writer, level + 2);
        }

        for (path, node) in &self.children {
            let level = level + 2;
            writeln!(
                writer,
                "{:indent$}<node name=\"{}\">",
                "",
                path,
                indent = level
            )
            .unwrap();
            node.introspect_to_writer(writer, level).await;
            writeln!(writer, "{:indent$}</node>", "", indent = level).unwrap();
        }

        if level == 0 {
            writeln!(writer, "</node>").unwrap();
        }
    }

    /// Produce the complete introspection XML document for this node and its descendants.
    pub(crate) async fn introspect(&self) -> String {
        let mut xml = String::with_capacity(1024);
        self.introspect_to_writer(&mut xml, 0).await;
        xml
    }

    /// Collect every descendant of this node (not the node itself), mapping each descendant's
    /// path to the properties of each of its non-standard interfaces.
    ///
    /// Descendants carrying only standard interfaces are still listed, with an empty
    /// interface map.
    #[async_recursion::async_recursion]
    pub(crate) async fn get_managed_objects(&self) -> ManagedObjects {
        let mut managed_objects = ManagedObjects::new();
        for node in self.children.values() {
            let mut interfaces = HashMap::new();
            for iface_name in node
                .interfaces
                .keys()
                .filter(|name| !Self::is_standard_interface(name))
            {
                let props = node.get_properties(iface_name.clone()).await;
                interfaces.insert(iface_name.clone().into(), props);
            }
            managed_objects.insert(node.path.clone(), interfaces);
            managed_objects.extend(node.get_managed_objects().await);
        }
        managed_objects
    }

    /// Fetch all properties of the interface named `interface_name` on this node via its
    /// `get_all`.
    ///
    /// # Panics
    ///
    /// Panics if no interface with that name is registered on this node.
    async fn get_properties(
        &self,
        interface_name: InterfaceName<'_>,
    ) -> HashMap<String, OwnedValue> {
        self.interface_lock(interface_name)
            .expect("Interface was added but not found")
            .read()
            .await
            .get_all()
            .await
    }
}
/// An object server, holding server-side D-Bus objects and their interfaces.
///
/// Objects live in a tree of nodes rooted at `/`. Each node can carry any number of interfaces,
/// keyed by interface name. Incoming method calls are routed to the interface registered at the
/// message's object path under the message's interface name.
#[derive(Debug)]
pub struct ObjectServer {
    conn: WeakConnection,
    root: RwLock<Node>,
}

assert_impl_all!(ObjectServer: Send, Sync, Unpin);

impl ObjectServer {
    /// Create an object server for `conn` whose tree contains only the root node (`/`).
    ///
    /// Only a weak reference to the connection is stored.
    pub(crate) fn new(conn: &Connection) -> Self {
        Self {
            conn: conn.into(),
            root: RwLock::new(Node::new("/".try_into().expect("zvariant bug"))),
        }
    }

    /// The lock guarding the root node of the object tree.
    pub(crate) fn root(&self) -> &RwLock<Node> {
        &self.root
    }

    /// Register `iface` at `path`, creating any missing nodes along the way.
    ///
    /// Returns `Ok(true)` if the interface was added, or `Ok(false)` if an interface with the
    /// same name is already registered at `path` (in which case `iface` is dropped).
    ///
    /// # Errors
    ///
    /// Fails if `path` cannot be converted into an object path, or if emitting the
    /// `InterfacesAdded` signal of an `ObjectManager` fails.
    pub async fn at<'p, P, I>(&self, path: P, iface: I) -> Result<bool>
    where
        I: Interface,
        P: TryInto<ObjectPath<'p>>,
        P::Error: Into<Error>,
    {
        self.at_ready(path, I::name(), move || Arc::new(RwLock::new(iface)))
            .await
    }

    /// Register the interface built by `iface_creator` under `name` at `path`.
    ///
    /// `iface_creator` is only called if `name` is not yet registered at `path`. If the new
    /// interface is an `ObjectManager`, `InterfacesAdded` is emitted from `path` for every
    /// object already below it. Otherwise, if an ancestor node has an `ObjectManager`,
    /// `InterfacesAdded` is emitted from the closest such ancestor for the new interface.
    pub(crate) async fn at_ready<'node, 'p, P, F>(
        &'node self,
        path: P,
        name: InterfaceName<'static>,
        iface_creator: F,
    ) -> Result<bool>
    where
        P: TryInto<ObjectPath<'p>>,
        P::Error: Into<Error>,
        F: FnOnce() -> Arc<RwLock<dyn Interface + 'static>>,
    {
        let path = path.try_into().map_err(Into::into)?;
        let mut root = self.root().write().await;
        let (node, manager_path) = root.get_child_mut(&path, true);
        // `create == true`, so the node always exists.
        let node = node.unwrap();
        let added = node.at(name.clone(), iface_creator);
        if added {
            if name == ObjectManager::name() {
                // A new manager announces every object it now manages.
                let ctxt = SignalContext::new(&self.connection(), path)?;
                let objects = node.get_managed_objects().await;
                for (path, owned_interfaces) in objects {
                    let interfaces = owned_interfaces
                        .iter()
                        .map(|(i, props)| {
                            let props = props
                                .iter()
                                .map(|(k, v)| (k.as_str(), Value::from(v)))
                                .collect();
                            (i.into(), props)
                        })
                        .collect();
                    ObjectManager::interfaces_added(&ctxt, &path, &interfaces).await?;
                }
            } else if let Some(manager_path) = manager_path {
                let ctxt = SignalContext::new(&self.connection(), manager_path.clone())?;
                let mut interfaces = HashMap::new();
                let owned_props = node.get_properties(name.clone()).await;
                let props = owned_props
                    .iter()
                    .map(|(k, v)| (k.as_str(), Value::from(v)))
                    .collect();
                interfaces.insert(name, props);

                ObjectManager::interfaces_added(&ctxt, &path, &interfaces).await?;
            }
        }

        Ok(added)
    }

    /// Unregister the interface `I` from `path`.
    ///
    /// If the node at `path` is left with only the standard interfaces and has no child
    /// objects, the node itself is pruned from the tree. The root node (`/`) is never pruned.
    ///
    /// Returns `Ok(true)` if the node was pruned, `Ok(false)` otherwise.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InterfaceNotFound`] if there is no node at `path` or `I` is not
    /// registered there. Also fails if `path` cannot be converted into an object path, or if
    /// emitting the `InterfacesRemoved` signal of an `ObjectManager` fails.
    pub async fn remove<'p, I, P>(&self, path: P) -> Result<bool>
    where
        I: Interface,
        P: TryInto<ObjectPath<'p>>,
        P::Error: Into<Error>,
    {
        let path = path.try_into().map_err(Into::into)?;
        let mut root = self.root.write().await;
        let (node, manager_path) = root.get_child_mut(&path, false);
        let node = node.ok_or(Error::InterfaceNotFound)?;
        if !node.remove_interface(I::name()) {
            return Err(Error::InterfaceNotFound);
        }
        if let Some(manager_path) = manager_path {
            let ctxt = SignalContext::new(&self.connection(), manager_path.clone())?;
            ObjectManager::interfaces_removed(&ctxt, &path, &[I::name()]).await?;
        }

        // Pruning a node that still has children would silently destroy those child objects
        // (and their interfaces) too, so keep it around in that case.
        if !node.is_empty() || !node.children.is_empty() {
            return Ok(false);
        }

        // Split `path` into its last element and its parent path. The root node has no last
        // element and can never be pruned.
        let mut path_parts = path.rsplit('/').filter(|i| !i.is_empty());
        let last_part = match path_parts.next() {
            Some(last_part) => last_part,
            None => return Ok(false),
        };
        let parent_path = path_parts.fold(String::new(), |a, p| format!("/{p}{a}"));
        // A top-level node's parent is the root; use "/" rather than the invalid empty path.
        let parent_path = if parent_path.is_empty() {
            String::from("/")
        } else {
            parent_path
        };
        let ppath = ObjectPath::from_string_unchecked(parent_path);
        root.get_child_mut(&ppath, false)
            .0
            .expect("parent of an existing node must exist")
            .remove_node(last_part);

        Ok(true)
    }

    /// Get a handle to the interface `I` registered at `path`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InterfaceNotFound`] if there is no node at `path`, no interface named
    /// `I::name()` there, or the registered interface is not of type `I`. Also fails if `path`
    /// cannot be converted into an object path.
    pub async fn interface<'p, P, I>(&self, path: P) -> Result<InterfaceRef<I>>
    where
        I: Interface,
        P: TryInto<ObjectPath<'p>>,
        P::Error: Into<Error>,
    {
        let path = path.try_into().map_err(Into::into)?;
        let root = self.root().read().await;
        let node = root.get_child(&path).ok_or(Error::InterfaceNotFound)?;

        // `interface_lock` already hands back a cloned `Arc`.
        let lock = node
            .interface_lock(I::name())
            .ok_or(Error::InterfaceNotFound)?;

        // Verify the concrete type so the handle's guards can downcast safely later.
        lock.read()
            .await
            .downcast_ref::<I>()
            .ok_or(Error::InterfaceNotFound)?;

        let conn = self.connection();
        // Propagate instead of panicking: this function already returns `Result`.
        let ctxt = SignalContext::new(&conn, path)?.into_owned();

        Ok(InterfaceRef {
            ctxt,
            lock,
            phantom: PhantomData,
        })
    }

    /// Route a method call to its interface.
    ///
    /// The outer `fdo::Result` carries D-Bus-level errors (missing header fields, unknown
    /// object/interface/method) that should be replied to the caller. The inner `Result` is
    /// whatever the interface's handler returned.
    ///
    /// The interface is first tried under a read lock. Only if it reports
    /// `DispatchResult::RequiresMut` is the read lock released and the call retried with
    /// `call_mut` under a write lock.
    #[instrument(skip(self, connection))]
    async fn dispatch_method_call_try(
        &self,
        connection: &Connection,
        msg: &Message,
    ) -> fdo::Result<Result<()>> {
        let path = msg
            .path()
            .ok_or_else(|| fdo::Error::Failed("Missing object path".into()))?;
        let iface_name = msg
            .interface()
            .ok_or_else(|| fdo::Error::Failed("Missing interface".into()))?;
        let member = msg
            .member()
            .ok_or_else(|| fdo::Error::Failed("Missing member".into()))?;

        // Scope the tree read lock so it is released before awaiting the interface's lock.
        let iface = {
            let root = self.root.read().await;
            let node = root
                .get_child(&path)
                .ok_or_else(|| fdo::Error::UnknownObject(format!("Unknown object '{path}'")))?;

            node.interface_lock(iface_name.as_ref()).ok_or_else(|| {
                fdo::Error::UnknownInterface(format!("Unknown interface '{iface_name}'"))
            })?
        };

        trace!("acquiring read lock on interface `{}`", iface_name);
        let read_lock = iface.read().await;
        trace!("acquired read lock on interface `{}`", iface_name);
        match read_lock.call(self, connection, msg, member.as_ref()) {
            DispatchResult::NotFound => {
                return Err(fdo::Error::UnknownMethod(format!(
                    "Unknown method '{member}'"
                )));
            }
            DispatchResult::Async(f) => {
                return Ok(f.await);
            }
            DispatchResult::RequiresMut => {}
        }
        drop(read_lock);

        trace!("acquiring write lock on interface `{}`", iface_name);
        let mut write_lock = iface.write().await;
        trace!("acquired write lock on interface `{}`", iface_name);
        match write_lock.call_mut(self, connection, msg, member.as_ref()) {
            DispatchResult::NotFound => {}
            DispatchResult::RequiresMut => {}
            DispatchResult::Async(f) => {
                return Ok(f.await);
            }
        }
        drop(write_lock);

        Err(fdo::Error::UnknownMethod(format!(
            "Unknown method '{member}'"
        )))
    }

    /// Dispatch a method call, replying with a D-Bus error to the caller for dispatch-level
    /// failures and returning the handler's own result otherwise.
    #[instrument(skip(self, connection))]
    async fn dispatch_method_call(&self, connection: &Connection, msg: &Message) -> Result<()> {
        match self.dispatch_method_call_try(connection, msg).await {
            Err(e) => {
                let hdr = msg.header()?;
                debug!("Returning error: {}", e);
                connection.reply_dbus_error(&hdr, e).await?;
                Ok(())
            }
            Ok(r) => r,
        }
    }

    /// Dispatch `msg` on this server's connection; returns `Ok(true)` once it has been handled.
    #[instrument(skip(self))]
    pub(crate) async fn dispatch_message(&self, msg: &Message) -> Result<bool> {
        let conn = self.connection();
        self.dispatch_method_call(&conn, msg).await?;
        trace!("Handled: {}", msg);

        Ok(true)
    }

    /// Upgrade the stored weak connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection has already been dropped.
    pub(crate) fn connection(&self) -> Connection {
        self.conn
            .upgrade()
            .expect("ObjectServer can't exist w/o an associated Connection")
    }
}
/// Convert a blocking `ObjectServer` into this async one via its `into_inner`.
impl From<crate::blocking::ObjectServer> for ObjectServer {
    fn from(blocking_server: crate::blocking::ObjectServer) -> Self {
        crate::blocking::ObjectServer::into_inner(blocking_server)
    }
}
#[derive(Debug)]
pub struct ResponseDispatchNotifier<R> {
response: R,
event: Option<Event>,
}
impl<R> ResponseDispatchNotifier<R> {
pub fn new(response: R) -> (Self, EventListener) {
let event = Event::new();
let listener = event.listen();
(
Self {
response,
event: Some(event),
},
listener,
)
}
}
impl<R> Serialize for ResponseDispatchNotifier<R>
where
R: Serialize,
{
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.response.serialize(serializer)
}
}
impl<R> Type for ResponseDispatchNotifier<R>
where
R: Type,
{
fn signature() -> Signature<'static> {
R::signature()
}
}
impl<T> Drop for ResponseDispatchNotifier<T> {
fn drop(&mut self) {
if let Some(event) = self.event.take() {
event.notify(usize::MAX);
}
}
}