a2a_protocol_client/transport/
websocket.rs1use std::collections::HashMap;
38use std::future::Future;
39use std::pin::Pin;
40use std::sync::Arc;
41use std::time::Duration;
42
43use futures_util::{SinkExt, StreamExt};
44use tokio::sync::{mpsc, oneshot, Mutex};
45use tokio_tungstenite::tungstenite::client::IntoClientRequest;
46use tokio_tungstenite::tungstenite::Message as WsMessage;
47use uuid::Uuid;
48
49use a2a_protocol_types::{JsonRpcRequest, JsonRpcResponse};
50
51use crate::error::{ClientError, ClientResult};
52use crate::streaming::EventStream;
53use crate::transport::Transport;
54
55enum PendingRequest {
59 Unary(oneshot::Sender<Result<String, ClientError>>),
61 Streaming(mpsc::Sender<crate::streaming::event_stream::BodyChunk>),
63}
64
65struct WriteCommand {
67 text: String,
68 request_id: String,
69 pending: PendingRequest,
70}
71
72pub struct WebSocketTransport {
83 inner: Arc<Inner>,
84}
85
86struct Inner {
87 write_tx: mpsc::Sender<WriteCommand>,
89 endpoint: String,
90 request_timeout: Duration,
91 _reader_handle: tokio::task::JoinHandle<()>,
93 _writer_handle: tokio::task::JoinHandle<()>,
95}
96
97impl WebSocketTransport {
98 pub async fn connect(endpoint: impl Into<String>) -> ClientResult<Self> {
106 Self::connect_with_options(endpoint, Duration::from_secs(30), &HashMap::new()).await
107 }
108
109 pub async fn connect_with_timeout(
115 endpoint: impl Into<String>,
116 request_timeout: Duration,
117 ) -> ClientResult<Self> {
118 Self::connect_with_options(endpoint, request_timeout, &HashMap::new()).await
119 }
120
121 #[allow(clippy::too_many_lines)]
132 pub async fn connect_with_options(
133 endpoint: impl Into<String>,
134 request_timeout: Duration,
135 extra_headers: &HashMap<String, String>,
136 ) -> ClientResult<Self> {
137 let endpoint = endpoint.into();
138 validate_ws_url(&endpoint)?;
139
140 let mut ws_request = endpoint
144 .as_str()
145 .into_client_request()
146 .map_err(|e| ClientError::Transport(format!("WebSocket request build failed: {e}")))?;
147 for (k, v) in extra_headers {
148 if let (Ok(name), Ok(val)) = (
149 k.parse::<tokio_tungstenite::tungstenite::http::HeaderName>(),
150 v.parse::<tokio_tungstenite::tungstenite::http::HeaderValue>(),
151 ) {
152 ws_request.headers_mut().insert(name, val);
153 }
154 }
155
156 let (ws_stream, _resp) = tokio_tungstenite::connect_async(ws_request)
157 .await
158 .map_err(|e| ClientError::Transport(format!("WebSocket connect failed: {e}")))?;
159
160 let (ws_writer, ws_reader) = ws_stream.split();
161
162 let pending: Arc<Mutex<HashMap<String, PendingRequest>>> =
164 Arc::new(Mutex::new(HashMap::new()));
165
166 let (write_tx, mut write_rx) = mpsc::channel::<WriteCommand>(64);
168
169 let pending_for_writer = Arc::clone(&pending);
172 let writer_handle = tokio::spawn(async move {
173 let mut ws_writer = ws_writer;
174 while let Some(cmd) = write_rx.recv().await {
175 {
177 let mut map = pending_for_writer.lock().await;
178 map.insert(cmd.request_id, cmd.pending);
179 }
180 if ws_writer.send(WsMessage::Text(cmd.text)).await.is_err() {
181 break;
182 }
183 }
184 });
185
186 let pending_for_reader = Arc::clone(&pending);
189 let reader_handle = tokio::spawn(async move {
190 let mut ws_reader = ws_reader;
191 loop {
192 match ws_reader.next().await {
193 Some(Ok(WsMessage::Text(text))) => {
194 route_frame(&pending_for_reader, &text).await;
195 }
196 Some(Ok(WsMessage::Close(_))) | None => break,
197 Some(Ok(_)) => {}
199 Some(Err(_e)) => {
200 let entries: Vec<PendingRequest> = {
203 let mut map = pending_for_reader.lock().await;
204 map.drain().map(|(_, v)| v).collect()
205 };
206 for pending in entries {
207 match pending {
208 PendingRequest::Unary(tx) => {
209 let _ = tx.send(Err(ClientError::Transport(
210 "WebSocket connection error".into(),
211 )));
212 }
213 PendingRequest::Streaming(tx) => {
214 let _ = tx
215 .send(Err(ClientError::Transport(
216 "WebSocket connection error".into(),
217 )))
218 .await;
219 }
220 }
221 }
222 break;
223 }
224 }
225 }
226 });
227
228 let endpoint_stored = endpoint;
230
231 Ok(Self {
232 inner: Arc::new(Inner {
233 write_tx,
234 endpoint: endpoint_stored,
235 request_timeout,
236 _reader_handle: reader_handle,
237 _writer_handle: writer_handle,
238 }),
239 })
240 }
241
242 #[must_use]
244 pub fn endpoint(&self) -> &str {
245 &self.inner.endpoint
246 }
247
248 async fn execute_request(
250 &self,
251 method: &str,
252 params: serde_json::Value,
253 _extra_headers: &HashMap<String, String>,
254 ) -> ClientResult<serde_json::Value> {
255 trace_info!(method, endpoint = %self.inner.endpoint, "sending WebSocket JSON-RPC request");
256
257 let rpc_req = build_rpc_request(method, params);
258 let request_id = rpc_req
259 .id
260 .as_ref()
261 .and_then(|v| v.as_str())
262 .unwrap_or("")
263 .to_owned();
264 let body = serde_json::to_string(&rpc_req).map_err(ClientError::Serialization)?;
265
266 let (tx, rx) = oneshot::channel();
267
268 self.inner
269 .write_tx
270 .send(WriteCommand {
271 text: body,
272 request_id,
273 pending: PendingRequest::Unary(tx),
274 })
275 .await
276 .map_err(|_| ClientError::Transport("WebSocket writer task closed".into()))?;
277
278 let response_text = tokio::time::timeout(self.inner.request_timeout, rx)
279 .await
280 .map_err(|_| ClientError::Timeout("WebSocket response timed out".into()))?
281 .map_err(|_| ClientError::Transport("WebSocket reader task closed".into()))??;
282
283 let envelope: JsonRpcResponse<serde_json::Value> =
284 serde_json::from_str(&response_text).map_err(ClientError::Serialization)?;
285
286 match envelope {
287 JsonRpcResponse::Success(ok) => {
288 trace_info!(method, "WebSocket request succeeded");
289 Ok(ok.result)
290 }
291 JsonRpcResponse::Error(err) => {
292 trace_warn!(
293 method,
294 code = err.error.code,
295 "JSON-RPC error over WebSocket"
296 );
297 let a2a = a2a_protocol_types::A2aError::new(
298 a2a_protocol_types::ErrorCode::try_from(err.error.code)
299 .unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
300 err.error.message,
301 );
302 Err(ClientError::Protocol(a2a))
303 }
304 }
305 }
306
307 async fn execute_streaming_request(
309 &self,
310 method: &str,
311 params: serde_json::Value,
312 _extra_headers: &HashMap<String, String>,
313 ) -> ClientResult<EventStream> {
314 trace_info!(method, endpoint = %self.inner.endpoint, "opening WebSocket stream");
315
316 let rpc_req = build_rpc_request(method, params);
317 let request_id = rpc_req
318 .id
319 .as_ref()
320 .and_then(|v| v.as_str())
321 .unwrap_or("")
322 .to_owned();
323 let body = serde_json::to_string(&rpc_req).map_err(ClientError::Serialization)?;
324
325 let (tx, rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(64);
327
328 self.inner
329 .write_tx
330 .send(WriteCommand {
331 text: body,
332 request_id,
333 pending: PendingRequest::Streaming(tx),
334 })
335 .await
336 .map_err(|_| ClientError::Transport("WebSocket writer task closed".into()))?;
337
338 Ok(EventStream::new(rx))
339 }
340}
341
342impl Transport for WebSocketTransport {
343 fn send_request<'a>(
344 &'a self,
345 method: &'a str,
346 params: serde_json::Value,
347 extra_headers: &'a HashMap<String, String>,
348 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
349 Box::pin(self.execute_request(method, params, extra_headers))
350 }
351
352 fn send_streaming_request<'a>(
353 &'a self,
354 method: &'a str,
355 params: serde_json::Value,
356 extra_headers: &'a HashMap<String, String>,
357 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
358 Box::pin(self.execute_streaming_request(method, params, extra_headers))
359 }
360}
361
362impl std::fmt::Debug for WebSocketTransport {
363 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 f.debug_struct("WebSocketTransport")
365 .field("endpoint", &self.inner.endpoint)
366 .finish()
367 }
368}
369
370async fn route_frame(pending: &Arc<Mutex<HashMap<String, PendingRequest>>>, text: &str) {
377 let frame_id = extract_jsonrpc_id(text);
379
380 let mut map = pending.lock().await;
381
382 let request_id = if let Some(ref id) = frame_id {
383 id.clone()
384 } else {
385 return;
388 };
389
390 if let Some(entry) = map.get(&request_id) {
391 match entry {
392 PendingRequest::Unary(_) => {
393 if let Some(PendingRequest::Unary(tx)) = map.remove(&request_id) {
395 let _ = tx.send(Ok(text.to_owned()));
396 }
397 }
398 PendingRequest::Streaming(tx) => {
399 let sse_line = format!("data: {text}\n\n");
401 if tx
402 .send(Ok(hyper::body::Bytes::from(sse_line)))
403 .await
404 .is_err()
405 {
406 map.remove(&request_id);
408 return;
409 }
410
411 if is_stream_terminal(text) {
413 map.remove(&request_id);
414 }
415 }
416 }
417 }
418}
419
420fn extract_jsonrpc_id(text: &str) -> Option<String> {
422 let v: serde_json::Value = serde_json::from_str(text).ok()?;
423 match v.get("id") {
424 Some(serde_json::Value::String(s)) => Some(s.clone()),
425 Some(serde_json::Value::Number(n)) => Some(n.to_string()),
426 _ => None,
427 }
428}
429
430fn is_stream_terminal(text: &str) -> bool {
441 let Ok(frame) = serde_json::from_str::<serde_json::Value>(text) else {
442 return false;
443 };
444
445 let has_terminal_state = |obj: &serde_json::Value| -> bool {
448 if let Some(status_update) = obj.get("statusUpdate") {
450 if let Some(status) = status_update.get("status") {
451 if let Some(state) = status.get("state").and_then(|s| s.as_str()) {
452 return matches!(state, "completed" | "failed" | "canceled" | "rejected");
453 }
454 }
455 }
456 if let Some(status) = obj.get("status") {
458 if let Some(state) = status.get("state").and_then(|s| s.as_str()) {
459 return matches!(state, "completed" | "failed" | "canceled" | "rejected");
460 }
461 }
462 false
463 };
464
465 if let Some(r) = frame.get("result") {
467 if r.get("stream_complete").is_some() {
471 return true;
472 }
473 if r.get("status").and_then(|s| s.as_str()) == Some("stream_complete") {
474 return true;
475 }
476 return has_terminal_state(r);
477 }
478
479 has_terminal_state(&frame)
482}
483
484fn build_rpc_request(method: &str, params: serde_json::Value) -> JsonRpcRequest {
485 let id = serde_json::Value::String(Uuid::new_v4().to_string());
486 JsonRpcRequest::with_params(id, method, params)
487}
488
489fn validate_ws_url(url: &str) -> ClientResult<()> {
490 if url.is_empty() {
491 return Err(ClientError::InvalidEndpoint("URL must not be empty".into()));
492 }
493 if !url.starts_with("ws://") && !url.starts_with("wss://") {
494 return Err(ClientError::InvalidEndpoint(format!(
495 "WebSocket URL must start with ws:// or wss://: {url}"
496 )));
497 }
498 Ok(())
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a raw JSON `result` value in a JSON-RPC 2.0 success envelope with id `"1"`.
    fn rpc_result(result: &str) -> String {
        format!(r#"{{"jsonrpc":"2.0","id":"1","result":{result}}}"#)
    }

    /// A frame whose `statusUpdate` event reports the given task `state`.
    fn status_update_frame(state: &str) -> String {
        rpc_result(&format!(
            r#"{{"statusUpdate":{{"status":{{"state":"{state}"}}}}}}"#
        ))
    }

    /// A frame whose task-level `status` reports the given `state`.
    fn task_status_frame(state: &str) -> String {
        rpc_result(&format!(r#"{{"status":{{"state":"{state}"}}}}"#))
    }

    // --- validate_ws_url -----------------------------------------------------

    #[test]
    fn validate_ws_url_rejects_empty() {
        assert!(validate_ws_url("").is_err());
    }

    #[test]
    fn validate_ws_url_rejects_http() {
        assert!(validate_ws_url("http://localhost:8080").is_err());
    }

    #[test]
    fn validate_ws_url_rejects_https() {
        assert!(validate_ws_url("https://example.com").is_err());
    }

    #[test]
    fn validate_ws_url_accepts_ws() {
        assert!(validate_ws_url("ws://localhost:8080").is_ok());
    }

    #[test]
    fn validate_ws_url_accepts_wss() {
        assert!(validate_ws_url("wss://agent.example.com/a2a").is_ok());
    }

    #[test]
    fn validate_ws_url_error_message_contains_url() {
        let msg = validate_ws_url("http://bad").unwrap_err().to_string();
        assert!(msg.contains("http://bad") || msg.contains("ws://"));
    }

    // --- is_stream_terminal --------------------------------------------------

    #[test]
    fn is_stream_terminal_completed_status() {
        assert!(is_stream_terminal(&status_update_frame("completed")));
    }

    #[test]
    fn is_stream_terminal_failed_status() {
        assert!(is_stream_terminal(&status_update_frame("failed")));
    }

    #[test]
    fn is_stream_terminal_canceled() {
        assert!(is_stream_terminal(&status_update_frame("canceled")));
    }

    #[test]
    fn is_stream_terminal_rejected() {
        assert!(is_stream_terminal(&status_update_frame("rejected")));
    }

    #[test]
    fn is_stream_terminal_working_is_not_terminal() {
        assert!(!is_stream_terminal(&status_update_frame("working")));
    }

    #[test]
    fn is_stream_terminal_task_level_completed() {
        assert!(is_stream_terminal(&task_status_frame("completed")));
    }

    #[test]
    fn is_stream_terminal_task_level_failed() {
        assert!(is_stream_terminal(&task_status_frame("failed")));
    }

    #[test]
    fn is_stream_terminal_non_string_state() {
        assert!(!is_stream_terminal(&rpc_result(r#"{"status":{"state":42}}"#)));
    }

    #[test]
    fn is_stream_terminal_stream_complete_sentinel() {
        assert!(is_stream_terminal(&rpc_result(r#"{"stream_complete":true}"#)));
    }

    #[test]
    fn is_stream_terminal_artifact_not_terminal() {
        let frame = rpc_result(r#"{"artifactUpdate":{"artifact":{"id":"a1","parts":[]}}}"#);
        assert!(!is_stream_terminal(&frame));
    }

    #[test]
    fn is_stream_terminal_payload_containing_word_not_terminal() {
        let frame = rpc_result(
            r#"{"artifactUpdate":{"artifact":{"id":"a1","parts":[{"text":"task completed successfully"}]}}}"#,
        );
        assert!(!is_stream_terminal(&frame));
    }

    #[test]
    fn is_stream_terminal_invalid_json() {
        assert!(!is_stream_terminal("not json"));
    }

    #[test]
    fn is_stream_terminal_no_result() {
        assert!(!is_stream_terminal(r#"{"jsonrpc":"2.0","id":"1"}"#));
    }

    // --- build_rpc_request ---------------------------------------------------

    #[test]
    fn build_rpc_request_has_method() {
        let req = build_rpc_request("TestMethod", serde_json::json!({"key": "val"}));
        assert_eq!(req.method, "TestMethod");

        let params = req.params.expect("params should be present");
        assert_eq!(params["key"], "val");

        let id = req.id.expect("id should be present");
        let id = id.as_str().expect("id should be a string UUID");
        assert!(!id.is_empty(), "id should not be empty");
    }

    // --- extract_jsonrpc_id --------------------------------------------------

    #[test]
    fn extract_jsonrpc_id_string() {
        let id = extract_jsonrpc_id(r#"{"jsonrpc":"2.0","id":"abc","result":{}}"#);
        assert_eq!(id.as_deref(), Some("abc"));
    }

    #[test]
    fn extract_jsonrpc_id_number() {
        let id = extract_jsonrpc_id(r#"{"jsonrpc":"2.0","id":42,"result":{}}"#);
        assert_eq!(id.as_deref(), Some("42"));
    }

    #[test]
    fn extract_jsonrpc_id_null_returns_none() {
        assert!(extract_jsonrpc_id(r#"{"jsonrpc":"2.0","id":null,"result":{}}"#).is_none());
    }

    #[test]
    fn extract_jsonrpc_id_missing_returns_none() {
        assert!(extract_jsonrpc_id(r#"{"jsonrpc":"2.0","result":{}}"#).is_none());
    }
}