Skip to main content

a2a_client/components/
streaming.rs

1//! Server-Sent Events (SSE) streaming components
2
3use a2a_rs::services::{AsyncA2AClient, StreamItem};
4use axum::response::sse::{Event, KeepAlive, Sse};
5use futures::StreamExt;
6use std::{convert::Infallible, sync::Arc, time::Duration};
7use tracing::{error, info, warn};
8
9use crate::WebA2AClient;
10
11/// Create an SSE stream for task updates
12///
13/// This function handles:
14/// - WebSocket streaming if available
15/// - Fallback to HTTP polling
16/// - Automatic retry logic
17/// - Serialization to JSON events
18pub fn create_sse_stream(
19    client: Arc<WebA2AClient>,
20    task_id: String,
21) -> Sse<impl futures::Stream<Item = Result<Event, Infallible>>> {
22    let stream = async_stream::stream! {
23        // Check if we have a WebSocket client
24        if let Some(ws_client) = client.websocket() {
25            info!("Attempting to subscribe to task {} via WebSocket", task_id);
26
27            let mut retry_count = 0;
28            let max_retries = 60; // 60 retries with 1 second delay = 1 minute
29
30            loop {
31                match ws_client.subscribe_to_task(&task_id, Some(50)).await {
32                    Ok(mut event_stream) => {
33                        info!("Successfully subscribed to task {} via WebSocket", task_id);
34
35                        while let Some(result) = event_stream.next().await {
36                            match result {
37                                Ok(stream_item) => {
38                                    let (event_type, event_data) = match &stream_item {
39                                        StreamItem::Task(task) => {
40                                            match serde_json::to_string(task) {
41                                                Ok(json) => ("task-update", json),
42                                                Err(e) => {
43                                                    error!("Failed to serialize task: {}", e);
44                                                    continue;
45                                                }
46                                            }
47                                        }
48                                        StreamItem::StatusUpdate(status) => {
49                                            match serde_json::to_string(status) {
50                                                Ok(json) => ("task-status", json),
51                                                Err(e) => {
52                                                    error!("Failed to serialize status: {}", e);
53                                                    continue;
54                                                }
55                                            }
56                                        }
57                                        StreamItem::ArtifactUpdate(artifact) => {
58                                            match serde_json::to_string(artifact) {
59                                                Ok(json) => ("artifact", json),
60                                                Err(e) => {
61                                                    error!("Failed to serialize artifact: {}", e);
62                                                    continue;
63                                                }
64                                            }
65                                        }
66                                    };
67
68                                    yield Ok(Event::default()
69                                        .event(event_type)
70                                        .data(event_data));
71                                }
72                                Err(e) => {
73                                    warn!("Stream error (continuing): {}", e);
74                                    continue;
75                                }
76                            }
77                        }
78                        break;
79                    }
80                    Err(e) => {
81                        retry_count += 1;
82
83                        if retry_count <= max_retries {
84                            if retry_count == 1 {
85                                info!("Task {} not ready yet, will retry", task_id);
86                            }
87                            tokio::time::sleep(Duration::from_secs(1)).await;
88                            continue;
89                        } else {
90                            warn!("Failed to subscribe after {} retries: {}, falling back to polling", max_retries, e);
91                            loop {
92                                match client.http.get_task(&task_id, Some(50)).await {
93                                    Ok(task) => {
94                                        let task_json = match serde_json::to_string(&task) {
95                                            Ok(json) => json,
96                                            Err(e) => {
97                                                error!("Failed to serialize task: {}", e);
98                                                tokio::time::sleep(Duration::from_secs(2)).await;
99                                                continue;
100                                            }
101                                        };
102
103                                        yield Ok(Event::default()
104                                            .event("task-update")
105                                            .data(task_json));
106                                    }
107                                    Err(_) => {
108                                        // Task doesn't exist yet, keep polling silently
109                                    }
110                                }
111
112                                tokio::time::sleep(Duration::from_secs(2)).await;
113                            }
114                        }
115                    }
116                }
117            }
118        } else {
119            // Fallback: Poll for updates every 2 seconds
120            warn!("WebSocket not available, using polling fallback for task {}", task_id);
121            loop {
122                match client.http.get_task(&task_id, Some(50)).await {
123                    Ok(task) => {
124                        let task_json = match serde_json::to_string(&task) {
125                            Ok(json) => json,
126                            Err(e) => {
127                                error!("Failed to serialize task: {}", e);
128                                continue;
129                            }
130                        };
131
132                        yield Ok(Event::default()
133                            .event("task-update")
134                            .data(task_json));
135                    }
136                    Err(_) => {
137                        // Task doesn't exist yet, keep polling silently
138                    }
139                }
140
141                tokio::time::sleep(Duration::from_secs(2)).await;
142            }
143        }
144    };
145
146    Sse::new(stream).keep_alive(KeepAlive::default())
147}