a2a_client/components/streaming.rs
1//! Server-Sent Events (SSE) streaming components
2
3use a2a_rs::services::{AsyncA2AClient, StreamItem};
4use axum::response::sse::{Event, KeepAlive, Sse};
5use futures::StreamExt;
6use std::{convert::Infallible, sync::Arc, time::Duration};
7use tracing::{error, info, warn};
8
9use crate::WebA2AClient;
10
11/// Create an SSE stream for task updates
12///
13/// This function handles:
14/// - WebSocket streaming if available
15/// - Fallback to HTTP polling
16/// - Automatic retry logic
17/// - Serialization to JSON events
18pub fn create_sse_stream(
19 client: Arc<WebA2AClient>,
20 task_id: String,
21) -> Sse<impl futures::Stream<Item = Result<Event, Infallible>>> {
22 let stream = async_stream::stream! {
23 // Check if we have a WebSocket client
24 if let Some(ws_client) = client.websocket() {
25 info!("Attempting to subscribe to task {} via WebSocket", task_id);
26
27 let mut retry_count = 0;
28 let max_retries = 60; // 60 retries with 1 second delay = 1 minute
29
30 loop {
31 match ws_client.subscribe_to_task(&task_id, Some(50)).await {
32 Ok(mut event_stream) => {
33 info!("Successfully subscribed to task {} via WebSocket", task_id);
34
35 while let Some(result) = event_stream.next().await {
36 match result {
37 Ok(stream_item) => {
38 let (event_type, event_data) = match &stream_item {
39 StreamItem::Task(task) => {
40 match serde_json::to_string(task) {
41 Ok(json) => ("task-update", json),
42 Err(e) => {
43 error!("Failed to serialize task: {}", e);
44 continue;
45 }
46 }
47 }
48 StreamItem::StatusUpdate(status) => {
49 match serde_json::to_string(status) {
50 Ok(json) => ("task-status", json),
51 Err(e) => {
52 error!("Failed to serialize status: {}", e);
53 continue;
54 }
55 }
56 }
57 StreamItem::ArtifactUpdate(artifact) => {
58 match serde_json::to_string(artifact) {
59 Ok(json) => ("artifact", json),
60 Err(e) => {
61 error!("Failed to serialize artifact: {}", e);
62 continue;
63 }
64 }
65 }
66 };
67
68 yield Ok(Event::default()
69 .event(event_type)
70 .data(event_data));
71 }
72 Err(e) => {
73 warn!("Stream error (continuing): {}", e);
74 continue;
75 }
76 }
77 }
78 break;
79 }
80 Err(e) => {
81 retry_count += 1;
82
83 if retry_count <= max_retries {
84 if retry_count == 1 {
85 info!("Task {} not ready yet, will retry", task_id);
86 }
87 tokio::time::sleep(Duration::from_secs(1)).await;
88 continue;
89 } else {
90 warn!("Failed to subscribe after {} retries: {}, falling back to polling", max_retries, e);
91 loop {
92 match client.http.get_task(&task_id, Some(50)).await {
93 Ok(task) => {
94 let task_json = match serde_json::to_string(&task) {
95 Ok(json) => json,
96 Err(e) => {
97 error!("Failed to serialize task: {}", e);
98 tokio::time::sleep(Duration::from_secs(2)).await;
99 continue;
100 }
101 };
102
103 yield Ok(Event::default()
104 .event("task-update")
105 .data(task_json));
106 }
107 Err(_) => {
108 // Task doesn't exist yet, keep polling silently
109 }
110 }
111
112 tokio::time::sleep(Duration::from_secs(2)).await;
113 }
114 }
115 }
116 }
117 }
118 } else {
119 // Fallback: Poll for updates every 2 seconds
120 warn!("WebSocket not available, using polling fallback for task {}", task_id);
121 loop {
122 match client.http.get_task(&task_id, Some(50)).await {
123 Ok(task) => {
124 let task_json = match serde_json::to_string(&task) {
125 Ok(json) => json,
126 Err(e) => {
127 error!("Failed to serialize task: {}", e);
128 continue;
129 }
130 };
131
132 yield Ok(Event::default()
133 .event("task-update")
134 .data(task_json));
135 }
136 Err(_) => {
137 // Task doesn't exist yet, keep polling silently
138 }
139 }
140
141 tokio::time::sleep(Duration::from_secs(2)).await;
142 }
143 }
144 };
145
146 Sse::new(stream).keep_alive(KeepAlive::default())
147}