use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use zbus::message::Header;
use zbus::zvariant::{OwnedObjectPath, OwnedValue};
use zbus::{interface, Connection, ObjectServer};
use crate::error::{Error, Result};
/// Well-known bus name requested (best effort) for the fake notification daemon.
const NOTIFICATIONS_NAME: &str = "org.freedesktop.Notifications";
/// Object path at which the notifications interface is exported.
const NOTIFICATIONS_PATH: &str = "/org/freedesktop/Notifications";
/// Well-known bus name requested (best effort) for the fake desktop portal.
const PORTAL_NAME: &str = "org.freedesktop.portal.Desktop";
/// Object path at which the OpenURI portal interface is exported.
const PORTAL_PATH: &str = "/org/freedesktop/portal/desktop";
/// One `org.freedesktop.Notifications.Notify` call as received by the sink.
#[derive(Debug, Clone)]
pub struct CapturedNotification {
    /// Capture order within the notification log, starting at 0.
    pub seq: u64,
    /// `app_name` argument as sent by the caller.
    pub app_name: String,
    /// `replaces_id` argument as sent; a non-zero value is echoed back as `id`.
    pub replaces_id: u32,
    /// `app_icon` argument as sent.
    pub app_icon: String,
    /// `summary` argument as sent.
    pub summary: String,
    /// `body` argument as sent.
    pub body: String,
    /// `actions` argument as sent.
    pub actions: Vec<String>,
    /// `hints` dictionary flattened into sorted `key=<Debug of value>` strings.
    pub hints: Vec<String>,
    /// `expire_timeout` argument as sent.
    pub expire_timeout: i32,
    /// Id returned to the caller: `replaces_id` when non-zero, otherwise a
    /// freshly allocated id.
    pub id: u32,
}
/// One `org.freedesktop.portal.OpenURI.OpenURI` call as received by the sink.
#[derive(Debug, Clone)]
pub struct CapturedOpenUri {
    /// Capture order within the OpenURI log, starting at 0.
    pub seq: u64,
    /// `parent_window` argument as sent by the caller.
    pub parent_window: String,
    /// URI the caller asked to open.
    pub uri: String,
    /// `options` dictionary flattened into sorted `key=<Debug of value>` strings
    /// (this includes `handle_token` when the caller supplied one).
    pub options: Vec<String>,
}
struct Records<T> {
items: Mutex<Vec<T>>,
notify: Notify,
seq: AtomicU64,
}
impl<T> Default for Records<T> {
fn default() -> Self {
Self {
items: Mutex::new(Vec::new()),
notify: Notify::new(),
seq: AtomicU64::new(0),
}
}
}
impl<T: Clone> Records<T> {
fn push(&self, make: impl FnOnce(u64) -> T) {
let seq = self.seq.fetch_add(1, Ordering::Relaxed);
self.items.lock().unwrap().push(make(seq));
self.notify.notify_waiters();
}
fn snapshot(&self) -> Vec<T> {
self.items.lock().unwrap().clone()
}
fn len(&self) -> usize {
self.items.lock().unwrap().len()
}
async fn wait_for<F>(
&self,
after: usize,
pred: F,
timeout: Duration,
cancel: &CancellationToken,
) -> Result<T>
where
F: Fn(&T) -> bool,
{
let deadline = tokio::time::Instant::now() + timeout;
loop {
let notified = self.notify.notified();
tokio::pin!(notified);
{
let guard = self.items.lock().unwrap();
if let Some(found) = guard.iter().skip(after).find(|i| pred(i)) {
return Ok(found.clone());
}
}
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
return Err(Error::Timeout(format!(
"no captured entry matched within {timeout:?}"
)));
}
tokio::select! {
_ = &mut notified => {}
_ = tokio::time::sleep(remaining) => {
return Err(Error::Timeout(format!(
"no captured entry matched within {timeout:?}"
)));
}
_ = cancel.cancelled() => return Err(Error::Cancelled),
}
}
}
}
/// Flattens a D-Bus dictionary into `key=value` strings ordered by key.
///
/// Values are rendered with their `Debug` representation. The output is ordered
/// by key rather than by the rendered string, which matters for keys that share
/// a prefix. For example, `a` sorts before `a-b`, whereas comparing rendered
/// strings would put `a-b=…` first because `-` sorts before `=`.
fn render_dict<V: std::fmt::Debug>(dict: &HashMap<String, V>) -> Vec<String> {
    let mut entries: Vec<(&String, &V)> = dict.iter().collect();
    // Map keys are unique, so an unstable sort still yields a deterministic order.
    entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
    entries
        .into_iter()
        .map(|(key, value)| format!("{key}={value:?}"))
        .collect()
}
/// Builds the portal `Request` object path for a call from `sender` using `token`.
///
/// The sender's unique name has every leading `:` stripped and every `.` turned
/// into `_`. If nothing is left, `wd` is used in its place.
fn request_handle_path(sender: &str, token: &str) -> String {
    let unique = sender.trim_start_matches(':');
    let sender_id: String = if unique.is_empty() {
        String::from("wd")
    } else {
        unique
            .chars()
            .map(|c| if c == '.' { '_' } else { c })
            .collect()
    };
    format!("/org/freedesktop/portal/desktop/request/{sender_id}/{token}")
}
/// `org.freedesktop.Notifications` server that records every `Notify` call.
struct NotificationsIface {
    /// Log that each received notification is appended to.
    records: Arc<Records<CapturedNotification>>,
    /// Id given to the next notification that does not replace an existing one.
    next_id: AtomicU32,
}
#[interface(name = "org.freedesktop.Notifications")]
impl NotificationsIface {
    /// Records the notification and returns the id assigned to it.
    ///
    /// A non-zero `replaces_id` is returned unchanged. Otherwise the next value
    /// of `next_id` is taken.
    // The parameter list mirrors the D-Bus `Notify` signature, so it cannot shrink.
    #[allow(clippy::too_many_arguments)]
    fn notify(
        &self,
        app_name: String,
        replaces_id: u32,
        app_icon: String,
        summary: String,
        body: String,
        actions: Vec<String>,
        hints: HashMap<String, OwnedValue>,
        expire_timeout: i32,
    ) -> u32 {
        let assigned = match replaces_id {
            0 => self.next_id.fetch_add(1, Ordering::Relaxed),
            reused => reused,
        };
        let flattened_hints = render_dict(&hints);
        self.records.push(move |seq| CapturedNotification {
            seq,
            id: assigned,
            app_name,
            replaces_id,
            app_icon,
            summary,
            body,
            actions,
            hints: flattened_hints,
            expire_timeout,
        });
        assigned
    }

    /// Accepts `CloseNotification` and ignores it.
    fn close_notification(&self, _id: u32) {}

    /// Advertises support for bodies, body markup and actions.
    fn get_capabilities(&self) -> Vec<String> {
        ["body", "body-markup", "actions"]
            .iter()
            .copied()
            .map(String::from)
            .collect()
    }

    /// Returns the `(name, vendor, version, spec version)` tuple expected by the
    /// Notifications spec. `waydriver` serves as both name and vendor.
    fn get_server_information(&self) -> (String, String, String, String) {
        let product = String::from("waydriver");
        let version = String::from(env!("CARGO_PKG_VERSION"));
        (product.clone(), product, version, String::from("1.2"))
    }
}
/// Placeholder `org.freedesktop.portal.Request` object, exported at the handle
/// path returned from each `OpenURI` call.
struct PortalRequest;
#[interface(name = "org.freedesktop.portal.Request")]
impl PortalRequest {
    /// Accepts `Close` and does nothing. The object is not unexported here.
    fn close(&self) {}
}
/// `org.freedesktop.portal.OpenURI` implementation that records every `OpenURI`
/// call and answers it straight away with a `Response` signal.
struct OpenUriIface {
    /// Log that each received request is appended to.
    records: Arc<Records<CapturedOpenUri>>,
    /// Source of generated request tokens for callers that omit `handle_token`.
    token_counter: AtomicU64,
}
#[interface(name = "org.freedesktop.portal.OpenURI")]
impl OpenUriIface {
    /// Interface version advertised through the `version` property.
    #[zbus(property, name = "version")]
    fn version(&self) -> u32 {
        4
    }

    /// Records the request, exports a `Request` object at the handle path, emits
    /// `Response(0, {})` on it, and returns the handle path.
    #[zbus(name = "OpenURI")]
    async fn open_uri(
        &self,
        #[zbus(header)] header: Header<'_>,
        #[zbus(connection)] conn: &Connection,
        #[zbus(object_server)] server: &ObjectServer,
        parent_window: String,
        uri: String,
        options: HashMap<String, OwnedValue>,
    ) -> OwnedObjectPath {
        // Caller's bus name, or empty if the message carries no sender.
        let sender = header
            .sender()
            .map(|s| s.as_str().to_string())
            .unwrap_or_default();
        // Use the caller-chosen `handle_token` when present so the handle path is
        // predictable for the caller; otherwise mint a `wd<N>` token.
        let token = options
            .get("handle_token")
            .and_then(|v| String::try_from(v.clone()).ok())
            .unwrap_or_else(|| format!("wd{}", self.token_counter.fetch_add(1, Ordering::Relaxed)));
        let handle = request_handle_path(&sender, &token);
        let rendered = render_dict(&options);
        self.records.push(|seq| CapturedOpenUri {
            seq,
            parent_window,
            uri,
            options: rendered,
        });
        // A token with characters that are invalid in an object path yields an
        // invalid handle. Fall back to a fixed path rather than failing the call.
        let path = OwnedObjectPath::try_from(handle.clone()).unwrap_or_else(|e| {
            tracing::warn!(error = %e, handle, "portal: invalid request handle path; using fallback");
            OwnedObjectPath::try_from("/org/freedesktop/portal/desktop/request/wd/wd")
                .expect("static fallback path is valid")
        });
        // `Ok(false)` only means an object already lives at this path (e.g. a
        // reused token), and sharing it is fine. A real export error is logged
        // rather than silently discarded.
        if let Err(e) = server.at(&path, PortalRequest).await {
            tracing::warn!(
                error = %e,
                path = path.as_str(),
                "portal: failed to export Request object"
            );
        }
        // Response code 0 with empty results (0 means success in the portal
        // `Request::Response` convention). The signal is sent only to the caller
        // when its name is known.
        // NOTE(review): the signal is emitted before this method's reply, so only
        // callers that subscribed to the handle path in advance (i.e. those that
        // passed `handle_token`) can observe it — confirm this is acceptable.
        let results: HashMap<String, OwnedValue> = HashMap::new();
        let dest = (!sender.is_empty()).then_some(sender.as_str());
        if let Err(e) = conn
            .emit_signal(
                dest,
                &path,
                "org.freedesktop.portal.Request",
                "Response",
                &(0u32, results),
            )
            .await
        {
            tracing::warn!(error = %e, "portal: failed to emit Response signal");
        }
        path
    }
}
/// Fake notification daemon and OpenURI portal living on one D-Bus connection.
///
/// Every call made to the exported interfaces is recorded and can be inspected
/// or awaited through the accessor methods.
pub struct ExternalSinks {
    // Held for the lifetime of the sinks; never read otherwise.
    _conn: Connection,
    // Shared with the exported notifications interface.
    notifications: Arc<Records<CapturedNotification>>,
    // Shared with the exported OpenURI interface.
    open_uris: Arc<Records<CapturedOpenUri>>,
}
impl ExternalSinks {
    /// Connects to `dbus_address`, exports the notification and portal objects,
    /// and then tries to claim their well-known bus names.
    ///
    /// # Errors
    /// Fails if the address does not parse, the connection cannot be built, or an
    /// object cannot be exported. Failing to claim a bus name is only logged.
    pub async fn start(dbus_address: &str) -> Result<Self> {
        let notifications = Arc::new(Records::<CapturedNotification>::default());
        let open_uris = Arc::new(Records::<CapturedOpenUri>::default());

        let address: zbus::address::Address =
            dbus_address.try_into().map_err(|e: zbus::Error| {
                Error::process_with("external sinks: invalid dbus address", e)
            })?;
        let conn = zbus::connection::Builder::address(address)?.build().await?;

        let notifications_iface = NotificationsIface {
            records: Arc::clone(&notifications),
            next_id: AtomicU32::new(1),
        };
        let portal_iface = OpenUriIface {
            records: Arc::clone(&open_uris),
            token_counter: AtomicU64::new(1),
        };
        // Export the objects first so they are already reachable by the time a
        // well-known name is owned.
        conn.object_server()
            .at(NOTIFICATIONS_PATH, notifications_iface)
            .await?;
        conn.object_server().at(PORTAL_PATH, portal_iface).await?;

        for name in [NOTIFICATIONS_NAME, PORTAL_NAME] {
            request_name_best_effort(&conn, name).await;
        }

        Ok(Self {
            _conn: conn,
            notifications,
            open_uris,
        })
    }

    /// Returns every notification captured so far, oldest first.
    pub fn notifications(&self) -> Vec<CapturedNotification> {
        let log = &self.notifications;
        log.snapshot()
    }

    /// Returns every OpenURI request captured so far, oldest first.
    pub fn open_uri_requests(&self) -> Vec<CapturedOpenUri> {
        let log = &self.open_uris;
        log.snapshot()
    }

    /// Number of notifications captured so far.
    pub fn notification_count(&self) -> usize {
        let log = &self.notifications;
        log.len()
    }

    /// Number of OpenURI requests captured so far.
    pub fn open_uri_count(&self) -> usize {
        let log = &self.open_uris;
        log.len()
    }

    /// Waits for a notification at index `after` or later that satisfies `pred`.
    ///
    /// Passing the `notification_count()` taken before triggering the action
    /// under test ignores notifications that were already captured.
    ///
    /// # Errors
    /// Times out after `timeout`, or stops early when `cancel` fires.
    pub async fn wait_for_notification<F>(
        &self,
        after: usize,
        pred: F,
        timeout: Duration,
        cancel: &CancellationToken,
    ) -> Result<CapturedNotification>
    where
        F: Fn(&CapturedNotification) -> bool,
    {
        let log = &self.notifications;
        log.wait_for(after, pred, timeout, cancel).await
    }

    /// Waits for an OpenURI request at index `after` or later that satisfies
    /// `pred`.
    ///
    /// # Errors
    /// Times out after `timeout`, or stops early when `cancel` fires.
    pub async fn wait_for_open_uri<F>(
        &self,
        after: usize,
        pred: F,
        timeout: Duration,
        cancel: &CancellationToken,
    ) -> Result<CapturedOpenUri>
    where
        F: Fn(&CapturedOpenUri) -> bool,
    {
        let log = &self.open_uris;
        log.wait_for(after, pred, timeout, cancel).await
    }
}
/// Requests the well-known `name` without queueing. Every outcome other than
/// becoming primary owner is logged; this function never fails.
async fn request_name_best_effort(conn: &Connection, name: &str) {
    use zbus::fdo::{RequestNameFlags, RequestNameReply};

    let reply = conn
        .request_name_with_flags(name, RequestNameFlags::DoNotQueue.into())
        .await;
    match reply {
        Err(e) => tracing::warn!(
            name,
            error = %e,
            "external sink RequestName failed; capture for this interface will be inactive"
        ),
        Ok(RequestNameReply::PrimaryOwner) => {
            tracing::info!(name, "external sink claimed bus name")
        }
        Ok(other) => tracing::warn!(
            name,
            reply = %other,
            "external sink could not claim bus name (already owned?); \
             capture for this interface will be inactive"
        ),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use zbus::zvariant::Value;

    #[test]
    fn records_push_assigns_monotonic_seq_and_snapshots() {
        let r: Records<u64> = Records::default();
        assert_eq!(r.len(), 0);
        r.push(|seq| seq);
        r.push(|seq| seq * 10);
        assert_eq!(r.snapshot(), vec![0, 10]);
        assert_eq!(r.len(), 2);
    }

    #[test]
    fn request_handle_path_strips_unique_name() {
        assert_eq!(
            request_handle_path(":1.42", "tok"),
            "/org/freedesktop/portal/desktop/request/1_42/tok"
        );
        assert_eq!(
            request_handle_path("", "tok"),
            "/org/freedesktop/portal/desktop/request/wd/tok"
        );
    }

    #[test]
    fn render_dict_is_sorted_key_value() {
        let mut d: HashMap<String, OwnedValue> = HashMap::new();
        d.insert(
            "zeta".to_string(),
            OwnedValue::try_from(Value::from(2u32)).unwrap(),
        );
        d.insert(
            "alpha".to_string(),
            OwnedValue::try_from(Value::from(1u32)).unwrap(),
        );
        let r = render_dict(&d);
        assert_eq!(r.len(), 2);
        assert!(r[0].starts_with("alpha="), "got {r:?}");
        assert!(r[1].starts_with("zeta="), "got {r:?}");
    }

    #[tokio::test]
    async fn wait_for_returns_existing_match() {
        let r: Records<u64> = Records::default();
        r.push(|_| 7);
        let token = CancellationToken::new();
        let got = r
            .wait_for(0, |v| *v == 7, Duration::from_secs(1), &token)
            .await
            .unwrap();
        assert_eq!(got, 7);
    }

    /// Entries before index `after` must be ignored even when they match.
    #[tokio::test]
    async fn wait_for_skips_entries_before_after() {
        let r: Records<u64> = Records::default();
        r.push(|_| 7);
        r.push(|_| 8);
        let token = CancellationToken::new();
        let got = r
            .wait_for(1, |_| true, Duration::from_secs(1), &token)
            .await
            .unwrap();
        assert_eq!(got, 8);
        let res = r
            .wait_for(2, |_| true, Duration::from_millis(20), &token)
            .await;
        assert!(matches!(res, Err(Error::Timeout(_))));
    }

    #[tokio::test]
    async fn wait_for_times_out_without_match() {
        let r: Records<u64> = Records::default();
        let token = CancellationToken::new();
        let res = r
            .wait_for(0, |_| false, Duration::from_millis(50), &token)
            .await;
        assert!(matches!(res, Err(Error::Timeout(_))));
    }

    #[tokio::test]
    async fn wait_for_wakes_on_push() {
        let r: Arc<Records<u64>> = Arc::new(Records::default());
        let r2 = Arc::clone(&r);
        // Moved into the waiter; the test never cancels it.
        let token = CancellationToken::new();
        let waiter = tokio::spawn(async move {
            r2.wait_for(0, |v| *v == 99, Duration::from_secs(5), &token)
                .await
        });
        tokio::time::sleep(Duration::from_millis(20)).await;
        r.push(|_| 99);
        let got = waiter.await.unwrap().unwrap();
        assert_eq!(got, 99);
    }

    #[tokio::test]
    async fn wait_for_cancels() {
        let r: Records<u64> = Records::default();
        let token = CancellationToken::new();
        token.cancel();
        let res = r
            .wait_for(0, |_| false, Duration::from_secs(5), &token)
            .await;
        assert!(matches!(res, Err(Error::Cancelled)));
    }
}