// roder_api/notifications.rs

use serde::{Deserialize, Serialize};
use time::OffsetDateTime;

use crate::events::{ThreadId, TurnId};
use crate::extension::NotificationSinkId;
use crate::tasks::TaskId;

/// Identifier carried in [`Notification::id`].
///
/// This is a plain owned string. This module imposes no format on it.
pub type NotificationId = String;

10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
11#[serde(rename_all = "snake_case")]
12pub enum NotificationKind {
13 NeedsInput,
14 TurnIdle,
15 TaskCompleted,
16 TaskFailed,
17 Custom(String),
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21pub struct Notification {
22 pub id: NotificationId,
23 pub kind: NotificationKind,
24 pub title: String,
25 #[serde(default, skip_serializing_if = "Option::is_none")]
26 pub body: Option<String>,
27 #[serde(default, skip_serializing_if = "Option::is_none")]
28 pub task_id: Option<TaskId>,
29 #[serde(default, skip_serializing_if = "Option::is_none")]
30 pub thread_id: Option<ThreadId>,
31 #[serde(default, skip_serializing_if = "Option::is_none")]
32 pub turn_id: Option<TurnId>,
33 #[serde(with = "time::serde::rfc3339")]
34 pub timestamp: OffsetDateTime,
35 #[serde(default)]
36 pub metadata: serde_json::Value,
37}
38
39#[async_trait::async_trait]
40pub trait NotificationSink: Send + Sync + 'static {
41 fn id(&self) -> NotificationSinkId;
42
43 async fn deliver(&self, notification: Notification) -> anyhow::Result<()>;
44}
45
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;

    /// Test sink that pushes every delivered notification into a shared vector.
    struct CapturingSink {
        delivered: Arc<Mutex<Vec<Notification>>>,
    }

    #[async_trait::async_trait]
    impl NotificationSink for CapturingSink {
        fn id(&self) -> NotificationSinkId {
            "capture".to_string()
        }

        async fn deliver(&self, notification: Notification) -> anyhow::Result<()> {
            self.delivered.lock().unwrap().push(notification);
            Ok(())
        }
    }

    #[test]
    fn notification_round_trips_json() {
        let notification = Notification {
            id: "notice-1".to_string(),
            kind: NotificationKind::TaskCompleted,
            title: "Task completed".to_string(),
            body: Some("process finished".to_string()),
            task_id: Some("task-1".to_string()),
            thread_id: Some("thread-a".to_string()),
            turn_id: Some("turn-a".to_string()),
            timestamp: OffsetDateTime::UNIX_EPOCH,
            metadata: serde_json::json!({ "sink": "test" }),
        };

        let encoded = serde_json::to_value(&notification).expect("serialize notification");
        assert_eq!(encoded["kind"], "task_completed");

        let decoded: Notification =
            serde_json::from_value(encoded).expect("deserialize notification");
        assert_eq!(decoded, notification);
    }

    /// `Custom` is a newtype variant, so serde encodes it externally tagged
    /// under the snake_case variant name rather than as a bare string.
    #[test]
    fn custom_kind_serializes_as_tagged_object() {
        let kind = NotificationKind::Custom("build_ready".to_string());

        let encoded = serde_json::to_value(&kind).expect("serialize kind");
        assert_eq!(encoded, serde_json::json!({ "custom": "build_ready" }));

        let decoded: NotificationKind =
            serde_json::from_value(encoded).expect("deserialize kind");
        assert_eq!(decoded, kind);
    }

    /// `None` fields must be skipped on output, and both the optional fields
    /// and `metadata` must default when absent on input.
    #[test]
    fn optional_fields_are_omitted_and_defaulted() {
        let notification = Notification {
            id: "notice-2".to_string(),
            kind: NotificationKind::TurnIdle,
            title: "Turn idle".to_string(),
            body: None,
            task_id: None,
            thread_id: None,
            turn_id: None,
            timestamp: OffsetDateTime::UNIX_EPOCH,
            metadata: serde_json::Value::Null,
        };

        let mut encoded = serde_json::to_value(&notification).expect("serialize notification");
        for key in ["body", "task_id", "thread_id", "turn_id"] {
            assert!(encoded.get(key).is_none(), "`{key}` should be omitted when None");
        }

        // `metadata` is always written, so strip it to exercise `#[serde(default)]`.
        encoded
            .as_object_mut()
            .expect("notification serializes as a JSON object")
            .remove("metadata");

        let decoded: Notification =
            serde_json::from_value(encoded).expect("deserialize notification without metadata");
        assert_eq!(decoded, notification);
    }

    #[tokio::test]
    async fn notification_sink_trait_is_object_safe() {
        let delivered = Arc::new(Mutex::new(Vec::new()));
        let sink: Arc<dyn NotificationSink> = Arc::new(CapturingSink {
            delivered: Arc::clone(&delivered),
        });

        sink.deliver(Notification {
            id: "notice-1".to_string(),
            kind: NotificationKind::NeedsInput,
            title: "Approval needed".to_string(),
            body: None,
            task_id: None,
            thread_id: Some("thread-a".to_string()),
            turn_id: Some("turn-a".to_string()),
            timestamp: OffsetDateTime::UNIX_EPOCH,
            metadata: serde_json::json!({}),
        })
        .await
        .unwrap();

        assert_eq!(sink.id(), "capture");
        assert_eq!(delivered.lock().unwrap().len(), 1);
    }
}