use std::{
collections::HashMap,
fmt::Debug,
time::{Duration, SystemTime},
};
use async_channel::{Receiver, Sender};
use futures::{FutureExt, StreamExt, TryStreamExt, channel::oneshot};
use waybar_cffi::gtk::glib;
use zbus::{
Connection, MatchRule, MessageStream,
fdo::{DBusProxy, MonitoringProxy, NameOwnerChanged},
message::Type,
names::UniqueName,
};
/// Cloneable handle to a background cache mapping D-Bus connection names to
/// process IDs.
///
/// All cache state lives in a worker task spawned by [`ConnectionCache::new`];
/// a handle only forwards requests to that task over a channel.
#[derive(Debug, Clone)]
pub struct ConnectionCache {
    tx: Sender<Request>,
}

impl ConnectionCache {
    /// Spawns the cache worker with `glib::spawn_future_local` and returns a
    /// handle to it.
    ///
    /// `expiry` is passed to the worker as the lifetime of cached entries.
    /// If the worker fails, its error is logged. Once the worker has stopped
    /// for any reason, [`get`](Self::get) returns `None`.
    pub fn new(expiry: Duration) -> Self {
        let (tx, rx) = async_channel::unbounded();
        let task = async move {
            let outcome = worker(rx, expiry).await;
            if let Err(e) = outcome {
                tracing::error!(%e, "connection cache worker error");
            }
        };
        glib::spawn_future_local(task);
        Self { tx }
    }

    /// Returns the process ID behind the D-Bus connection `connection`.
    ///
    /// Returns `None` if the worker could not resolve it, did not reply, or
    /// can no longer be reached.
    #[tracing::instrument(level = "TRACE", skip(self))]
    pub async fn get(&self, connection: impl ToString + Debug) -> Option<u32> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = Request::Get {
            connection: connection.to_string(),
            result: reply_tx,
        };
        match self.tx.send(request).await {
            // A reply sender dropped without answering counts as "unknown".
            Ok(_) => reply_rx.await.ok().flatten(),
            Err(e) => {
                tracing::error!(%e, "error sending request to connection cache");
                None
            }
        }
    }
}
/// Requests sent from [`ConnectionCache`] handles to the background worker.
#[derive(Debug)]
enum Request {
/// Ask for the process ID behind the D-Bus connection named `connection`.
///
/// The worker answers on `result`. If `result` is dropped without an answer,
/// `ConnectionCache::get` treats that as `None`.
Get {
connection: String,
result: oneshot::Sender<Option<u32>>,
},
}
/// One cached lookup result, keyed by connection name in `Cache`.
#[derive(Debug)]
struct Entry {
/// Cached process ID. The `Option` leaves room for recording a negative
/// result. NOTE(review): no code visible here inserts `None` — confirm
/// whether negative caching was intended.
pid: Option<u32>,
/// Deadline after which the entry counts as expired; pushed forward by
/// `Cache::insert` and `Cache::get`, checked by `Cache::expire`.
expiry: SystemTime,
}
/// Interface name used in the monitor match rule to select the message bus's
/// `NameOwnerChanged` signals.
static DBUS_INTERFACE: &str = "org.freedesktop.DBus";
/// Runs the connection cache until every [`ConnectionCache`] handle is dropped.
///
/// Opens two session-bus connections:
/// * one answers requests via `GetConnectionUnixProcessID`;
/// * the other is switched into monitor mode. Its `NameOwnerChanged`
///   signals, sent by the message bus itself, keep the cache in sync as
///   connections come and go.
///
/// Expired entries are swept once a minute.
///
/// # Errors
///
/// Fails if either connection cannot be established, if becoming a monitor is
/// refused, or if the monitor stream reports an error. Returns `Ok(())` when
/// the request channel closes or, after logging, when the monitor stream ends.
async fn worker(rx: Receiver<Request>, expiry: Duration) -> anyhow::Result<()> {
    // Bus name of the message bus itself; genuine `NameOwnerChanged` signals
    // carry it as their sender.
    const DBUS_BUS_NAME: &str = "org.freedesktop.DBus";

    let mut cache = Cache::new(expiry);

    // Connection used for PID queries.
    let dbus_conn = Connection::session().await?;
    let dbus_proxy = DBusProxy::new(&dbus_conn).await?;

    // Separate connection that is turned into a monitor below.
    let monitor_conn = Connection::session().await?;
    let monitor_proxy = MonitoringProxy::new(&monitor_conn).await?;
    monitor_proxy
        .become_monitor(
            &[MatchRule::builder()
                .msg_type(Type::Signal)
                // Any client can emit a signal with this interface and member.
                // Pinning the sender restricts the monitor to the bus's own
                // signals, so spoofed ones cannot evict entries or trigger
                // lookups.
                .sender(DBUS_BUS_NAME)?
                .interface(DBUS_INTERFACE)?
                .member("NameOwnerChanged")?
                .build()],
            0,
        )
        .await?;

    // Periodic sweep of expired cache entries.
    let mut cleanup = glib::interval_stream(Duration::from_secs(60)).fuse();
    let mut stream = MessageStream::from(monitor_conn);
    loop {
        futures::select! {
            result = stream.try_next() => {
                match result {
                    Ok(Some(msg)) => {
                        handle_zbus_message(&mut cache, &dbus_proxy, msg).await;
                    }
                    Ok(None) => {
                        tracing::error!("D-Bus monitor stream closed unexpectedly");
                        break;
                    }
                    Err(e) => {
                        tracing::error!(%e, "D-Bus monitor stream error");
                        anyhow::bail!(e);
                    }
                }
            }
            result = rx.recv().fuse() => {
                match result {
                    Ok(msg) => {
                        handle_message(&mut cache, &dbus_proxy, msg).await;
                    }
                    // Every sender is gone: no handle can ask us anything again.
                    Err(_) => {
                        break;
                    }
                }
            }
            _ = cleanup.next() => {
                cache.expire(SystemTime::now());
            }
        }
    }
    Ok(())
}
async fn handle_zbus_message(
cache: &mut Cache,
dbus_proxy: &DBusProxy<'_>,
message: zbus::Message,
) {
if let Some(message) = NameOwnerChanged::from_message(message) {
if let Ok(args) = message.args() {
if let Some(new_owner) = args.new_owner().as_ref() {
if let Ok(pid) = dbus_proxy
.get_connection_unix_process_id(new_owner.clone().into())
.await
{
cache.insert(new_owner, Some(pid));
}
} else if let Some(old_owner) = args.old_owner.as_ref() {
cache.remove(old_owner);
}
}
}
}
async fn handle_message(cache: &mut Cache, dbus_proxy: &DBusProxy<'_>, message: Request) {
match message {
Request::Get { connection, result } => {
if let Some(maybe_pid) = cache.get(&connection) {
let _ = result.send(maybe_pid);
} else if let Ok(name) = UniqueName::try_from(connection.as_str()) {
if let Ok(pid) = dbus_proxy.get_connection_unix_process_id(name.into()).await {
cache.insert(connection, Some(pid));
let _ = result.send(Some(pid));
}
}
}
}
}
/// Map from D-Bus connection name to a cached process-ID lookup result.
///
/// Entries use a sliding deadline: inserting or reading an entry moves its
/// deadline to `expiry` from now. An entry whose deadline has passed is
/// treated as absent by [`Cache::get`]. It is physically removed either there
/// or by the periodic [`Cache::expire`] sweep, so the configured expiry holds
/// even when it is shorter than the sweep interval.
#[derive(Debug)]
struct Cache {
    cache: HashMap<String, Entry>,
    expiry: Duration,
}

impl Cache {
    /// Creates an empty cache whose entries stay valid for `expiry` after
    /// their last insert or read.
    pub fn new(expiry: Duration) -> Self {
        Self {
            cache: Default::default(),
            expiry,
        }
    }

    /// Drops every entry whose deadline is not later than `now`.
    pub fn expire(&mut self, now: SystemTime) {
        self.cache.retain(|_, entry| entry.expiry > now);
    }

    /// Looks up `connection` and refreshes its deadline on a hit.
    ///
    /// The outer `Option` is `None` on a miss. That includes an entry whose
    /// deadline has passed but which has not been swept yet; such an entry
    /// is removed here. The inner `Option` is the cached PID.
    pub fn get(&mut self, connection: &str) -> Option<Option<u32>> {
        let now = SystemTime::now();
        match self.cache.get_mut(connection) {
            // Same liveness rule as `expire`: live iff deadline > now.
            Some(entry) if entry.expiry > now => {
                entry.expiry = now + self.expiry;
                Some(entry.pid)
            }
            Some(_) => {
                self.cache.remove(connection);
                None
            }
            None => None,
        }
    }

    /// Stores `pid` for `connection`, replacing any previous entry, with a
    /// fresh deadline.
    pub fn insert(&mut self, connection: impl ToString, pid: Option<u32>) {
        self.cache.insert(
            connection.to_string(),
            Entry {
                pid,
                expiry: SystemTime::now() + self.expiry,
            },
        );
    }

    /// Forgets `connection`, if present.
    pub fn remove(&mut self, connection: &str) {
        self.cache.remove(connection);
    }
}