Skip to main content

roder_api/
notifications.rs

1use serde::{Deserialize, Serialize};
2use time::OffsetDateTime;
3
4use crate::events::{ThreadId, TurnId};
5use crate::extension::NotificationSinkId;
6use crate::tasks::TaskId;
7
/// Identifier of a [`Notification`], represented as a plain owned string.
pub type NotificationId = String;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
11#[serde(rename_all = "snake_case")]
12pub enum NotificationKind {
13    NeedsInput,
14    TurnIdle,
15    TaskCompleted,
16    TaskFailed,
17    Custom(String),
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21pub struct Notification {
22    pub id: NotificationId,
23    pub kind: NotificationKind,
24    pub title: String,
25    #[serde(default, skip_serializing_if = "Option::is_none")]
26    pub body: Option<String>,
27    #[serde(default, skip_serializing_if = "Option::is_none")]
28    pub task_id: Option<TaskId>,
29    #[serde(default, skip_serializing_if = "Option::is_none")]
30    pub thread_id: Option<ThreadId>,
31    #[serde(default, skip_serializing_if = "Option::is_none")]
32    pub turn_id: Option<TurnId>,
33    #[serde(with = "time::serde::rfc3339")]
34    pub timestamp: OffsetDateTime,
35    #[serde(default)]
36    pub metadata: serde_json::Value,
37}
38
39#[async_trait::async_trait]
40pub trait NotificationSink: Send + Sync + 'static {
41    fn id(&self) -> NotificationSinkId;
42
43    async fn deliver(&self, notification: Notification) -> anyhow::Result<()>;
44}
45
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;

    /// Test sink that records every delivered notification in memory.
    struct CapturingSink {
        delivered: Arc<Mutex<Vec<Notification>>>,
    }

    #[async_trait::async_trait]
    impl NotificationSink for CapturingSink {
        fn id(&self) -> NotificationSinkId {
            "capture".to_string()
        }

        async fn deliver(&self, notification: Notification) -> anyhow::Result<()> {
            // The std mutex guard is a temporary dropped at the end of this
            // statement; nothing is awaited while it is held.
            self.delivered.lock().unwrap().push(notification);
            Ok(())
        }
    }

    #[test]
    fn notification_round_trips_json() {
        let notification = Notification {
            id: "notice-1".to_string(),
            kind: NotificationKind::TaskCompleted,
            title: "Task completed".to_string(),
            body: Some("process finished".to_string()),
            task_id: Some("task-1".to_string()),
            thread_id: Some("thread-a".to_string()),
            turn_id: Some("turn-a".to_string()),
            timestamp: OffsetDateTime::UNIX_EPOCH,
            metadata: serde_json::json!({ "sink": "test" }),
        };

        let encoded = serde_json::to_value(&notification).expect("serialize notification");
        assert_eq!(encoded["kind"], "task_completed");

        let decoded: Notification =
            serde_json::from_value(encoded).expect("deserialize notification");
        assert_eq!(decoded, notification);
    }

    #[test]
    fn custom_kind_uses_externally_tagged_form() {
        let kind = NotificationKind::Custom("deploy_finished".to_string());

        let encoded = serde_json::to_value(&kind).expect("serialize kind");
        assert_eq!(encoded, serde_json::json!({ "custom": "deploy_finished" }));

        let decoded: NotificationKind =
            serde_json::from_value(encoded).expect("deserialize kind");
        assert_eq!(decoded, kind);
    }

    #[test]
    fn minimal_payload_defaults_optional_fields() {
        let decoded: Notification = serde_json::from_value(serde_json::json!({
            "id": "notice-2",
            "kind": "turn_idle",
            "title": "Turn idle",
            "timestamp": "1970-01-01T00:00:00Z"
        }))
        .expect("deserialize minimal notification");

        assert_eq!(decoded.kind, NotificationKind::TurnIdle);
        assert_eq!(decoded.body, None);
        assert_eq!(decoded.task_id, None);
        assert_eq!(decoded.thread_id, None);
        assert_eq!(decoded.turn_id, None);
        assert_eq!(decoded.timestamp, OffsetDateTime::UNIX_EPOCH);
        assert_eq!(decoded.metadata, serde_json::Value::Null);

        // `None` optionals must be skipped when serializing back out.
        let encoded = serde_json::to_value(&decoded).expect("serialize notification");
        for key in ["body", "task_id", "thread_id", "turn_id"] {
            assert!(encoded.get(key).is_none(), "{key} should be omitted when None");
        }
    }

    #[tokio::test]
    async fn notification_sink_trait_is_object_safe() {
        let delivered = Arc::new(Mutex::new(Vec::new()));
        let sink: Arc<dyn NotificationSink> = Arc::new(CapturingSink {
            delivered: Arc::clone(&delivered),
        });

        sink.deliver(Notification {
            id: "notice-1".to_string(),
            kind: NotificationKind::NeedsInput,
            title: "Approval needed".to_string(),
            body: None,
            task_id: None,
            thread_id: Some("thread-a".to_string()),
            turn_id: Some("turn-a".to_string()),
            timestamp: OffsetDateTime::UNIX_EPOCH,
            metadata: serde_json::json!({}),
        })
        .await
        .unwrap();

        assert_eq!(sink.id(), "capture");
        let captured = delivered.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].kind, NotificationKind::NeedsInput);
        assert_eq!(captured[0].title, "Approval needed");
    }
}