uira-core 0.1.1

Shared types, events, protocol definitions, and configuration loading for Uira
Documentation
use crate::events::{Event, EventCategory};
use async_trait::async_trait;
use std::collections::HashSet;
use tokio::sync::broadcast;

/// Criteria deciding which [`Event`]s are accepted (see `SubscriptionFilter::matches`).
///
/// Each criterion is optional: `None` means "no constraint" on that dimension.
/// When several criteria are set, an event must satisfy all of them. The
/// `Default` value has no criteria and therefore accepts every event.
#[derive(Debug, Clone, Default)]
pub struct SubscriptionFilter {
    /// Accepted event categories, or `None` to accept any category.
    categories: Option<HashSet<EventCategory>>,
    /// Accepted event names (compared against `Event::event_name`), or `None` for any name.
    event_names: Option<HashSet<String>>,
    /// Accepted session IDs, or `None` to ignore the session entirely.
    session_ids: Option<HashSet<String>>,
}

impl SubscriptionFilter {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn categories(mut self, categories: impl IntoIterator<Item = EventCategory>) -> Self {
        self.categories = Some(categories.into_iter().collect());
        self
    }

    pub fn event_names(mut self, names: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.event_names = Some(names.into_iter().map(|s| s.into()).collect());
        self
    }

    pub fn session_ids(mut self, ids: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.session_ids = Some(ids.into_iter().map(|s| s.into()).collect());
        self
    }

    pub fn matches(&self, event: &Event) -> bool {
        if let Some(ref cats) = self.categories {
            if !cats.contains(&event.category()) {
                return false;
            }
        }

        if let Some(ref names) = self.event_names {
            if !names.contains(event.event_name()) {
                return false;
            }
        }

        if let Some(ref ids) = self.session_ids {
            if let Some(session_id) = event.session_id() {
                if !ids.contains(session_id) {
                    return false;
                }
            } else {
                return false;
            }
        }

        true
    }

    pub fn is_wildcard(&self) -> bool {
        self.categories.is_none() && self.event_names.is_none() && self.session_ids.is_none()
    }
}

/// An asynchronous consumer of [`Event`]s.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait EventHandler: Send + Sync {
    /// Identifying name of this handler.
    fn name(&self) -> &str;

    /// Priority value for this handler; defaults to `0`.
    ///
    /// NOTE(review): how the dispatcher orders handlers by this value is not
    /// shown in this module — confirm against the caller.
    fn priority(&self) -> i32 {
        0
    }

    /// Filter describing the events this handler wants to see.
    ///
    /// The default is a filter with no criteria, which accepts every event.
    fn filter(&self) -> SubscriptionFilter {
        SubscriptionFilter::default()
    }

    /// Processes `event` and returns a [`HandlerResult`] describing the outcome.
    async fn handle(&self, event: &Event) -> HandlerResult;
}

/// Outcome returned by `EventHandler::handle`.
#[derive(Debug, Clone)]
pub struct HandlerResult {
    /// `true` to let processing continue; `HandlerResult::block` sets this to `false`.
    pub should_continue: bool,
    /// Optional message from the handler; for a blocked result this holds the reason.
    pub message: Option<String>,
    /// Optional event produced by the handler. All constructors in this module leave it
    /// `None`; presumably a replacement for the handled event — confirm with the dispatcher.
    pub modified_event: Option<Event>,
}

impl HandlerResult {
    /// A result that lets processing continue, with no message and no modified event.
    pub fn pass() -> Self {
        Self {
            should_continue: true,
            message: None,
            modified_event: None,
        }
    }

    /// Like [`HandlerResult::pass`], but carrying `message`.
    pub fn with_message(message: impl Into<String>) -> Self {
        let mut result = Self::pass();
        result.message = Some(message.into());
        result
    }

    /// A result that stops processing, with `reason` recorded as its message.
    pub fn block(reason: impl Into<String>) -> Self {
        Self {
            should_continue: false,
            ..Self::with_message(reason)
        }
    }
}

/// A filtered receiving end of a `tokio::sync::broadcast` channel of [`Event`]s.
///
/// Events that do not satisfy `filter` are skipped by the receive methods.
pub struct Subscriber {
    /// Broadcast receiver the events are read from.
    receiver: broadcast::Receiver<Event>,
    /// Criteria an event must satisfy to be returned (accepts everything unless replaced).
    filter: SubscriptionFilter,
}

impl Subscriber {
    pub fn new(receiver: broadcast::Receiver<Event>) -> Self {
        Self {
            receiver,
            filter: SubscriptionFilter::new(),
        }
    }

    pub fn with_filter(mut self, filter: SubscriptionFilter) -> Self {
        self.filter = filter;
        self
    }

    pub async fn recv(&mut self) -> Option<Event> {
        loop {
            match self.receiver.recv().await {
                Ok(event) => {
                    if self.filter.matches(&event) {
                        return Some(event);
                    }
                }
                Err(broadcast::error::RecvError::Closed) => return None,
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    tracing::warn!("Subscriber lagged, missed {} events", n);
                    continue;
                }
            }
        }
    }

    pub fn try_recv(&mut self) -> Option<Event> {
        loop {
            match self.receiver.try_recv() {
                Ok(event) => {
                    if self.filter.matches(&event) {
                        return Some(event);
                    }
                }
                Err(broadcast::error::TryRecvError::Empty) => return None,
                Err(broadcast::error::TryRecvError::Closed) => return None,
                Err(broadcast::error::TryRecvError::Lagged(n)) => {
                    tracing::warn!("Subscriber lagged, missed {} events", n);
                    continue;
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolExecutionStarted` event (category `Tool`) for `session_id`.
    fn tool_started(session_id: &str) -> Event {
        Event::ToolExecutionStarted {
            session_id: session_id.to_string(),
            tool_call_id: "tc_1".to_string(),
            tool_name: "bash".to_string(),
            input: serde_json::json!({}),
        }
    }

    /// Builds a `SessionStarted` event with no parent session.
    fn session_started(session_id: &str) -> Event {
        Event::SessionStarted {
            session_id: session_id.to_string(),
            parent_id: None,
        }
    }

    /// Builds the first `TurnStarted` event for `session_id`.
    fn turn_started(session_id: &str) -> Event {
        Event::TurnStarted {
            session_id: session_id.to_string(),
            turn_number: 1,
        }
    }

    #[test]
    fn test_filter_categories() {
        let filter = SubscriptionFilter::new().categories([EventCategory::Tool]);

        assert!(filter.matches(&tool_started("test")));
        assert!(!filter.matches(&session_started("test")));
    }

    #[test]
    fn test_filter_event_names() {
        let filter = SubscriptionFilter::new().event_names(["session_started", "session_ended"]);

        assert!(filter.matches(&session_started("test")));
        assert!(!filter.matches(&turn_started("test")));
    }

    #[test]
    fn test_filter_session_ids() {
        let filter = SubscriptionFilter::new().session_ids(["ses_123"]);

        assert!(filter.matches(&turn_started("ses_123")));
        assert!(!filter.matches(&turn_started("ses_456")));
    }

    #[test]
    fn test_filter_wildcard() {
        let filter = SubscriptionFilter::new();
        assert!(filter.is_wildcard());

        let filter_with_cat = SubscriptionFilter::new().categories([EventCategory::Tool]);
        assert!(!filter_with_cat.is_wildcard());

        // Every criterion, not just categories, disables wildcard status.
        assert!(!SubscriptionFilter::new()
            .event_names(["session_started"])
            .is_wildcard());
        assert!(!SubscriptionFilter::new()
            .session_ids(["ses_123"])
            .is_wildcard());
    }

    #[test]
    fn test_wildcard_matches_every_event() {
        let filter = SubscriptionFilter::new();

        assert!(filter.matches(&tool_started("a")));
        assert!(filter.matches(&session_started("b")));
        assert!(filter.matches(&turn_started("c")));
    }

    #[test]
    fn test_filter_criteria_are_combined_with_and() {
        let filter = SubscriptionFilter::new()
            .categories([EventCategory::Tool])
            .session_ids(["ses_123"]);

        assert!(filter.matches(&tool_started("ses_123")));
        // Right category, wrong session.
        assert!(!filter.matches(&tool_started("ses_456")));
        // Right session, wrong category.
        assert!(!filter.matches(&session_started("ses_123")));
    }

    #[test]
    fn test_filter_empty_set_matches_nothing() {
        let filter = SubscriptionFilter::new().event_names(Vec::<String>::new());

        assert!(!filter.is_wildcard());
        assert!(!filter.matches(&session_started("test")));
        assert!(!filter.matches(&tool_started("test")));
    }

    #[test]
    fn test_filter_builder_replaces_previous_set() {
        let filter = SubscriptionFilter::new()
            .event_names(["no_such_event"])
            .event_names(["session_started", "session_ended"]);

        assert!(filter.matches(&session_started("test")));
    }

    #[test]
    fn test_handler_result() {
        let pass = HandlerResult::pass();
        assert!(pass.should_continue);
        assert!(pass.message.is_none());
        assert!(pass.modified_event.is_none());

        let with_msg = HandlerResult::with_message("hello");
        assert!(with_msg.should_continue);
        assert_eq!(with_msg.message, Some("hello".to_string()));
        assert!(with_msg.modified_event.is_none());

        let block = HandlerResult::block(String::from("blocked"));
        assert!(!block.should_continue);
        assert_eq!(block.message, Some("blocked".to_string()));
        assert!(block.modified_event.is_none());
    }
}