Skip to main content

a2a_rs/port/
streaming_handler.rs

1//! Streaming and real-time update handling port definitions
2
3#[cfg(feature = "server")]
4use async_trait::async_trait;
5use futures::Stream;
6use std::pin::Pin;
7
8use crate::domain::{A2AError, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
9
/// Receives real-time updates of type `T` pushed by a streaming source.
///
/// Only [`Subscriber::on_update`] has to be implemented; the error and
/// completion hooks ship with default implementations that may be overridden.
#[cfg(feature = "server")]
#[async_trait]
pub trait Subscriber<T>: Send + Sync {
    /// Processes one update delivered to this subscriber.
    ///
    /// # Errors
    ///
    /// Returns an [`A2AError`] when the implementor cannot handle `update`.
    async fn on_update(&self, update: T) -> Result<(), A2AError>;

    /// Reacts to an error reported on this subscription.
    ///
    /// By default the error is printed to standard error and then swallowed:
    /// the method always returns `Ok(())`, so nothing reaches the caller.
    async fn on_error(&self, error: A2AError) -> Result<(), A2AError> {
        // Best-effort reporting only; override to propagate or route errors elsewhere.
        eprintln!("Subscription error: {}", error);
        Ok(())
    }

    /// Hook for when the subscription completes.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    async fn on_complete(&self) -> Result<(), A2AError> {
        Ok(())
    }
}
30
/// A synchronous trait for managing streaming connections and real-time updates.
///
/// Subscribers are registered per task, and each registration returns a
/// subscription ID that can later be passed to
/// [`StreamingHandler::remove_subscription`].
///
/// Gated on the `server` feature because its method signatures use
/// [`Subscriber`], which is only defined when that feature is enabled; without
/// the gate this module fails to compile with `server` disabled.
#[cfg(feature = "server")]
pub trait StreamingHandler {
    /// Registers a status-update subscriber for `task_id`.
    ///
    /// Returns the ID of the new subscription.
    fn add_status_subscriber(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError>;

    /// Registers an artifact-update subscriber for `task_id`.
    ///
    /// Returns the ID of the new subscription.
    fn add_artifact_subscriber(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError>;

    /// Removes the subscription identified by `subscription_id`.
    fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError>;

    /// Removes every subscriber registered for `task_id`.
    fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError>;

    /// Returns the number of active subscribers for `task_id`.
    fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError>;

    /// Returns whether `task_id` has at least one active subscriber.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`StreamingHandler::get_subscriber_count`].
    fn has_subscribers(&self, task_id: &str) -> Result<bool, A2AError> {
        self.get_subscriber_count(task_id).map(|count| count > 0)
    }
}
62
/// An async trait for managing streaming connections and real-time updates.
///
/// Covers three concerns:
/// - subscriber bookkeeping (add, remove, count);
/// - pushing updates to subscribers via the `broadcast_*` methods;
/// - pull-based access to updates through boxed [`Stream`]s.
#[cfg(feature = "server")]
#[async_trait]
pub trait AsyncStreamingHandler: Send + Sync {
    /// Registers a status-update subscriber for `task_id`.
    ///
    /// Returns the ID of the new subscription, suitable for
    /// [`AsyncStreamingHandler::remove_subscription`].
    async fn add_status_subscriber(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError>;

    /// Registers an artifact-update subscriber for `task_id`.
    ///
    /// Returns the ID of the new subscription, suitable for
    /// [`AsyncStreamingHandler::remove_subscription`].
    async fn add_artifact_subscriber(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError>;

    /// Removes the subscription identified by `subscription_id`.
    async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError>;

    /// Removes every subscriber registered for `task_id`.
    async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError>;

    /// Returns the number of active subscribers for `task_id`.
    async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError>;

    /// Returns whether `task_id` has at least one active subscriber.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`AsyncStreamingHandler::get_subscriber_count`].
    async fn has_subscribers(&self, task_id: &str) -> Result<bool, A2AError> {
        let count = self.get_subscriber_count(task_id).await?;
        Ok(count > 0)
    }

    /// Delivers `update` to all subscribers of `task_id`.
    async fn broadcast_status_update(
        &self,
        task_id: &str,
        update: TaskStatusUpdateEvent,
    ) -> Result<(), A2AError>;

    /// Delivers `update` to all subscribers of `task_id`.
    async fn broadcast_artifact_update(
        &self,
        task_id: &str,
        update: TaskArtifactUpdateEvent,
    ) -> Result<(), A2AError>;

    /// Opens a stream of status updates for `task_id`.
    async fn status_update_stream(
        &self,
        task_id: &str,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>,
        A2AError,
    >;

    /// Opens a stream of artifact updates for `task_id`.
    async fn artifact_update_stream(
        &self,
        task_id: &str,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
        A2AError,
    >;

    /// Opens a single stream carrying both status and artifact updates for `task_id`.
    async fn combined_update_stream(
        &self,
        task_id: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<UpdateEvent, A2AError>> + Send>>, A2AError>;

    /// Checks the parameters of a streaming request.
    ///
    /// The default implementation only rejects a `task_id` that is empty or
    /// consists solely of whitespace.
    ///
    /// # Errors
    ///
    /// Returns [`A2AError::ValidationError`] with `field` set to `"task_id"` when
    /// the task ID is blank.
    async fn validate_streaming_params(&self, task_id: &str) -> Result<(), A2AError> {
        if task_id.trim().is_empty() {
            return Err(A2AError::ValidationError {
                field: "task_id".to_string(),
                message: "Task ID cannot be empty for streaming".to_string(),
            });
        }
        Ok(())
    }

    /// Validates `task_id` and then opens a combined update stream for it.
    ///
    /// The default implementation is exactly
    /// [`AsyncStreamingHandler::validate_streaming_params`] followed by
    /// [`AsyncStreamingHandler::combined_update_stream`]. It performs no cleanup
    /// of its own. Releasing per-task resources is left either to the implementor
    /// or to an explicit call to [`AsyncStreamingHandler::stop_task_streaming`].
    ///
    /// # Errors
    ///
    /// Propagates the error from either step. In particular, a blank `task_id`
    /// yields [`A2AError::ValidationError`] before any stream is requested.
    async fn start_task_streaming(
        &self,
        task_id: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<UpdateEvent, A2AError>> + Send>>, A2AError> {
        self.validate_streaming_params(task_id).await?;
        self.combined_update_stream(task_id).await
    }

    /// Stops streaming for `task_id` by removing all of its subscribers.
    ///
    /// The default implementation only delegates to
    /// [`AsyncStreamingHandler::remove_task_subscribers`]. Whether streams already
    /// returned by the `*_update_stream` methods end as a result depends on the
    /// implementor.
    async fn stop_task_streaming(&self, task_id: &str) -> Result<(), A2AError> {
        self.remove_task_subscribers(task_id).await
    }
}
156
157/// Union type for different kinds of updates that can be streamed
158#[derive(Debug, Clone)]
159pub enum UpdateEvent {
160    StatusUpdate(TaskStatusUpdateEvent),
161    ArtifactUpdate(TaskArtifactUpdateEvent),
162}
163
164impl UpdateEvent {
165    /// Get the task ID from the update event
166    #[inline]
167    pub fn task_id(&self) -> &str {
168        match self {
169            UpdateEvent::StatusUpdate(event) => &event.task_id,
170            UpdateEvent::ArtifactUpdate(event) => &event.task_id,
171        }
172    }
173
174    /// Get the context ID from the update event
175    #[inline]
176    pub fn context_id(&self) -> &str {
177        match self {
178            UpdateEvent::StatusUpdate(event) => &event.context_id,
179            UpdateEvent::ArtifactUpdate(event) => &event.context_id,
180        }
181    }
182
183    /// Check if this is a final update
184    #[inline]
185    pub fn is_final(&self) -> bool {
186        match self {
187            UpdateEvent::StatusUpdate(event) => event.final_,
188            UpdateEvent::ArtifactUpdate(event) => event.last_chunk.unwrap_or(false),
189        }
190    }
191}