use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use serde::{Deserialize, Serialize};
/// A focus change for an application, as delivered through the tracker's
/// event channel.
///
/// Equality compares both the application id and the window title, which lets
/// consumers drop repeated identical events.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct FocusEvent {
    /// Application identifier such as `org.mozilla.firefox`; this is the
    /// value tested by [`FocusEvent::matches`].
    pub app_id: String,
    /// Title of the focused window, when one was reported.
    pub window_title: Option<String>,
}
impl FocusEvent {
    /// Builds an event for `app_id`, optionally carrying the window title.
    pub fn new(app_id: impl Into<String>, window_title: Option<String>) -> Self {
        let app_id = app_id.into();
        Self {
            app_id,
            window_title,
        }
    }

    /// Builds an event that carries only an application id (no title).
    pub fn from_app_id(app_id: impl Into<String>) -> Self {
        Self::new(app_id, None)
    }

    /// Reports whether this event's `app_id` satisfies `pattern`.
    ///
    /// Rules are tried in this order:
    /// 1. `"*"` matches every id.
    /// 2. A pattern equal to the id matches.
    /// 3. A pattern starting with `.` is a suffix test (`".firefox"` matches
    ///    `"org.mozilla.firefox"`). Rule 4 is not consulted for such patterns.
    /// 4. A pattern ending with `.` is a prefix test (`"org."` matches
    ///    `"org.mozilla.firefox"`).
    ///
    /// Anything else does not match.
    pub fn matches(&self, pattern: &str) -> bool {
        let id = self.app_id.as_str();
        if pattern == "*" || pattern == id {
            true
        } else if pattern.starts_with('.') {
            id.ends_with(pattern)
        } else if pattern.ends_with('.') {
            id.starts_with(pattern)
        } else {
            false
        }
    }
}
/// Handle for portal-based application focus tracking.
///
/// [`FocusTracker::new`] probes xdg-desktop-portal once. When the probe fails,
/// the tracker is still constructed, but [`FocusTracker::is_available`]
/// returns `false`.
pub struct FocusTracker {
/// Probe result: `Some` when the portal connection succeeded, `None` otherwise.
portal: Option<Arc<FocusPortal>>,
/// Running flag, shared with the task spawned by `start` and cleared by `stop`.
running: Arc<AtomicBool>,
}
/// Result of a successful xdg-desktop-portal probe.
// The fields are recorded at probe time but are not read anywhere in this
// module yet; the allow keeps the build warning-free until they are consumed.
#[allow(dead_code)]
#[derive(Debug)]
struct FocusPortal {
    /// Set to `true` by `FocusPortal::try_new` when the probe succeeds.
    available: bool,
    /// Name of the portal backend (`try_new` stores `"xdg-desktop-portal"`).
    backend: String,
}
impl FocusPortal {
    /// Probes whether xdg-desktop-portal is reachable.
    ///
    /// Returns `None`, after logging a warning, in two cases:
    /// - `WAYLAND_DISPLAY` cannot be read (unset or not valid Unicode);
    /// - creating a `GlobalShortcuts` portal proxy fails.
    ///
    /// On success the proxy is discarded: only the fact that the connection
    /// worked is kept, with `backend` set to `"xdg-desktop-portal"`.
    async fn try_new() -> Option<Self> {
        let wayland_session = std::env::var("WAYLAND_DISPLAY").is_ok();
        if !wayland_session {
            tracing::warn!("Not running on Wayland, focus tracking unavailable");
            return None;
        }

        let probe = ashpd::desktop::global_shortcuts::GlobalShortcuts::new().await;
        match probe {
            Err(err) => {
                tracing::warn!("Failed to connect to xdg-desktop-portal: {}", err);
                tracing::warn!("Focus tracking will be unavailable");
                None
            }
            Ok(_) => {
                tracing::info!("Successfully connected to xdg-desktop-portal");
                let backend = String::from("xdg-desktop-portal");
                Some(Self {
                    available: true,
                    backend,
                })
            }
        }
    }
}
impl FocusTracker {
pub async fn new() -> Self {
let portal = tokio::task::spawn(async {
FocusPortal::try_new().await
})
.await
.ok()
.and_then(|r| r);
Self {
portal: portal.map(Arc::new),
running: Arc::new(AtomicBool::new(false)),
}
}
pub fn is_available(&self) -> bool {
self.portal.is_some()
}
pub async fn start(&self, _tx: mpsc::Sender<FocusEvent>) -> Result<(), String> {
if self.running.swap(true, Ordering::SeqCst) {
return Err("Focus tracking is already running".to_string());
}
let running = self.running.clone();
let _portal = self.portal.clone();
tokio::spawn(async move {
tracing::info!("Focus tracking task started");
if _portal.is_some() {
while running.load(Ordering::SeqCst) {
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
}
} else {
tracing::warn!("Focus tracking portal unavailable, task exiting");
}
tracing::info!("Focus tracking task stopped");
});
Ok(())
}
pub fn stop(&self) {
self.running.store(false, Ordering::SeqCst);
}
pub fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
}
impl Default for FocusTracker {
    /// Builds a tracker synchronously by running [`FocusTracker::new`] on a
    /// private tokio runtime.
    ///
    /// The runtime is created, used and dropped on a dedicated thread. Calling
    /// `Runtime::block_on` directly panics when done from inside an existing
    /// async context, so the thread lets this work from plain synchronous code
    /// and from inside a tokio runtime alike. The calling thread blocks until
    /// the portal probe finishes.
    ///
    /// # Panics
    ///
    /// Panics if the runtime cannot be created. Any panic on the helper thread
    /// is re-raised on the caller with its original payload.
    fn default() -> Self {
        std::thread::spawn(|| {
            tokio::runtime::Runtime::new()
                .expect("Failed to create tokio runtime")
                .block_on(Self::new())
        })
        .join()
        .unwrap_or_else(|payload| std::panic::resume_unwind(payload))
    }
}
/// Probes the portal, starts a [`FocusTracker`], and forwards its events to
/// `callback` on a spawned task.
///
/// Returns the started tracker together with the handle of the forwarding
/// task; call [`FocusTracker::stop`] on the tracker to end tracking. The
/// forwarding task finishes once every sender for the internal channel has
/// been dropped.
///
/// `callback` is invoked sequentially from that single task, so it may mutate
/// captured state (`FnMut`). Any `Fn` closure is still accepted.
///
/// # Errors
///
/// Returns `Err` if the portal is unavailable or if [`FocusTracker::start`]
/// fails.
pub async fn start_focus_tracking<F>(
    mut callback: F,
) -> Result<(FocusTracker, tokio::task::JoinHandle<()>), String>
where
    F: FnMut(FocusEvent) + Send + 'static,
{
    let tracker = FocusTracker::new().await;
    if !tracker.is_available() {
        return Err("Focus tracking portal unavailable".to_string());
    }
    // Bounded channel: senders wait when the callback falls 32 events behind.
    let (tx, mut rx) = mpsc::channel(32);
    tracker.start(tx).await?;
    let handle = tokio::spawn(async move {
        while let Some(event) = rx.recv().await {
            callback(event);
        }
    });
    Ok((tracker, handle))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_focus_event_creation() {
        let event = FocusEvent::new("org.alacritty".to_string(), Some("Alacritty".to_string()));
        assert_eq!(event.app_id, "org.alacritty");
        assert_eq!(event.window_title, Some("Alacritty".to_string()));
    }

    #[test]
    fn test_focus_event_from_app_id() {
        let event = FocusEvent::from_app_id("firefox");
        assert_eq!(event.app_id, "firefox");
        assert_eq!(event.window_title, None);
    }

    #[test]
    fn test_focus_event_matches_exact() {
        let event = FocusEvent::from_app_id("org.alacritty");
        assert!(event.matches("org.alacritty"));
        assert!(!event.matches("org.mozilla.firefox"));
    }

    #[test]
    fn test_focus_event_matches_wildcard() {
        let event = FocusEvent::from_app_id("org.alacritty");
        assert!(event.matches("*"));
    }

    #[test]
    fn test_focus_event_matches_suffix() {
        let event = FocusEvent::from_app_id("org.mozilla.firefox");
        assert!(event.matches(".firefox"));
        assert!(event.matches(".mozilla.firefox"));
        assert!(!event.matches(".alacritty"));
    }

    #[test]
    fn test_focus_event_matches_prefix() {
        let event = FocusEvent::from_app_id("org.mozilla.firefox");
        assert!(event.matches("org.mozilla."));
        assert!(event.matches("org."));
        assert!(!event.matches("com."));
    }

    #[test]
    fn test_focus_event_matches_rejects_bare_and_empty_patterns() {
        // A bare name is neither exact, a suffix (leading '.'), nor a prefix
        // (trailing '.'), so it must not match a longer dotted id.
        let event = FocusEvent::from_app_id("org.mozilla.firefox");
        assert!(!event.matches("firefox"));
        assert!(!event.matches(""));
    }

    #[test]
    fn test_focus_event_suffix_requires_dot_boundary() {
        // ".firefox" only matches ids that contain the leading dot.
        let event = FocusEvent::from_app_id("firefox");
        assert!(!event.matches(".firefox"));
    }

    #[tokio::test]
    async fn test_focus_tracker_creation() {
        let tracker = FocusTracker::new().await;
        assert!(!tracker.is_running());
    }

    #[tokio::test]
    async fn test_focus_tracker_stop_when_idle() {
        let tracker = FocusTracker::new().await;
        tracker.stop();
        assert!(!tracker.is_running());
    }

    #[test]
    fn test_focus_tracker_default() {
        let tracker = FocusTracker::default();
        assert!(!tracker.is_running());
    }
}