a2a_protocol_client/transport/
websocket.rs1use std::collections::HashMap;
38use std::future::Future;
39use std::pin::Pin;
40use std::sync::Arc;
41use std::time::Duration;
42
43use futures_util::{SinkExt, StreamExt};
44use tokio::sync::{mpsc, oneshot, Mutex};
45use tokio_tungstenite::tungstenite::client::IntoClientRequest;
46use tokio_tungstenite::tungstenite::Message as WsMessage;
47use uuid::Uuid;
48
49use a2a_protocol_types::{JsonRpcRequest, JsonRpcResponse};
50
51use crate::error::{ClientError, ClientResult};
52use crate::streaming::EventStream;
53use crate::transport::Transport;
54
55enum PendingRequest {
59 Unary(oneshot::Sender<Result<String, ClientError>>),
61 Streaming(mpsc::Sender<crate::streaming::event_stream::BodyChunk>),
63}
64
65struct WriteCommand {
67 text: String,
68 request_id: String,
69 pending: PendingRequest,
70}
71
72pub struct WebSocketTransport {
83 inner: Arc<Inner>,
84}
85
86struct Inner {
87 write_tx: mpsc::Sender<WriteCommand>,
89 endpoint: String,
90 request_timeout: Duration,
91 _reader_handle: tokio::task::JoinHandle<()>,
93 _writer_handle: tokio::task::JoinHandle<()>,
95}
96
97impl WebSocketTransport {
98 pub async fn connect(endpoint: impl Into<String>) -> ClientResult<Self> {
106 Self::connect_with_options(endpoint, Duration::from_secs(30), &HashMap::new()).await
107 }
108
109 pub async fn connect_with_timeout(
115 endpoint: impl Into<String>,
116 request_timeout: Duration,
117 ) -> ClientResult<Self> {
118 Self::connect_with_options(endpoint, request_timeout, &HashMap::new()).await
119 }
120
121 #[allow(clippy::too_many_lines)]
132 pub async fn connect_with_options(
133 endpoint: impl Into<String>,
134 request_timeout: Duration,
135 extra_headers: &HashMap<String, String>,
136 ) -> ClientResult<Self> {
137 let endpoint = endpoint.into();
138 validate_ws_url(&endpoint)?;
139
140 let mut ws_request = endpoint
144 .as_str()
145 .into_client_request()
146 .map_err(|e| ClientError::Transport(format!("WebSocket request build failed: {e}")))?;
147 for (k, v) in extra_headers {
148 if let (Ok(name), Ok(val)) = (
149 k.parse::<tokio_tungstenite::tungstenite::http::HeaderName>(),
150 v.parse::<tokio_tungstenite::tungstenite::http::HeaderValue>(),
151 ) {
152 ws_request.headers_mut().insert(name, val);
153 }
154 }
155
156 let (ws_stream, _resp) = tokio_tungstenite::connect_async(ws_request)
157 .await
158 .map_err(|e| ClientError::Transport(format!("WebSocket connect failed: {e}")))?;
159
160 let (ws_writer, ws_reader) = ws_stream.split();
161
162 let pending: Arc<Mutex<HashMap<String, PendingRequest>>> =
164 Arc::new(Mutex::new(HashMap::new()));
165
166 let (write_tx, mut write_rx) = mpsc::channel::<WriteCommand>(64);
168
169 let pending_for_writer = Arc::clone(&pending);
172 let writer_handle = tokio::spawn(async move {
173 let mut ws_writer = ws_writer;
174 while let Some(cmd) = write_rx.recv().await {
175 {
177 let mut map = pending_for_writer.lock().await;
178 map.insert(cmd.request_id, cmd.pending);
179 }
180 if ws_writer
181 .send(WsMessage::Text(cmd.text.into()))
182 .await
183 .is_err()
184 {
185 break;
186 }
187 }
188 });
189
190 let pending_for_reader = Arc::clone(&pending);
193 let reader_handle = tokio::spawn(async move {
194 let mut ws_reader = ws_reader;
195 loop {
196 match ws_reader.next().await {
197 Some(Ok(WsMessage::Text(text))) => {
198 route_frame(&pending_for_reader, text.as_str()).await;
199 }
200 Some(Ok(WsMessage::Close(_))) | None => break,
201 Some(Ok(_)) => {}
203 Some(Err(_e)) => {
204 let entries: Vec<PendingRequest> = {
207 let mut map = pending_for_reader.lock().await;
208 map.drain().map(|(_, v)| v).collect()
209 };
210 for pending in entries {
211 match pending {
212 PendingRequest::Unary(tx) => {
213 let _ = tx.send(Err(ClientError::Transport(
214 "WebSocket connection error".into(),
215 )));
216 }
217 PendingRequest::Streaming(tx) => {
218 let _ = tx
219 .send(Err(ClientError::Transport(
220 "WebSocket connection error".into(),
221 )))
222 .await;
223 }
224 }
225 }
226 break;
227 }
228 }
229 }
230 });
231
232 let endpoint_stored = endpoint;
234
235 Ok(Self {
236 inner: Arc::new(Inner {
237 write_tx,
238 endpoint: endpoint_stored,
239 request_timeout,
240 _reader_handle: reader_handle,
241 _writer_handle: writer_handle,
242 }),
243 })
244 }
245
246 #[must_use]
248 pub fn endpoint(&self) -> &str {
249 &self.inner.endpoint
250 }
251
252 async fn execute_request(
254 &self,
255 method: &str,
256 params: serde_json::Value,
257 _extra_headers: &HashMap<String, String>,
258 ) -> ClientResult<serde_json::Value> {
259 trace_info!(method, endpoint = %self.inner.endpoint, "sending WebSocket JSON-RPC request");
260
261 let rpc_req = build_rpc_request(method, params);
262 let request_id = rpc_req
263 .id
264 .as_ref()
265 .and_then(|v| v.as_str())
266 .unwrap_or("")
267 .to_owned();
268 let body = serde_json::to_string(&rpc_req).map_err(ClientError::Serialization)?;
269
270 let (tx, rx) = oneshot::channel();
271
272 self.inner
273 .write_tx
274 .send(WriteCommand {
275 text: body,
276 request_id,
277 pending: PendingRequest::Unary(tx),
278 })
279 .await
280 .map_err(|_| ClientError::Transport("WebSocket writer task closed".into()))?;
281
282 let response_text = tokio::time::timeout(self.inner.request_timeout, rx)
283 .await
284 .map_err(|_| ClientError::Timeout("WebSocket response timed out".into()))?
285 .map_err(|_| ClientError::Transport("WebSocket reader task closed".into()))??;
286
287 let envelope: JsonRpcResponse<serde_json::Value> =
288 serde_json::from_str(&response_text).map_err(ClientError::Serialization)?;
289
290 match envelope {
291 JsonRpcResponse::Success(ok) => {
292 trace_info!(method, "WebSocket request succeeded");
293 Ok(ok.result)
294 }
295 JsonRpcResponse::Error(err) => {
296 trace_warn!(
297 method,
298 code = err.error.code,
299 "JSON-RPC error over WebSocket"
300 );
301 let a2a = a2a_protocol_types::A2aError::new(
302 a2a_protocol_types::ErrorCode::try_from(err.error.code)
303 .unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
304 err.error.message,
305 );
306 Err(ClientError::Protocol(a2a))
307 }
308 }
309 }
310
311 async fn execute_streaming_request(
313 &self,
314 method: &str,
315 params: serde_json::Value,
316 _extra_headers: &HashMap<String, String>,
317 ) -> ClientResult<EventStream> {
318 trace_info!(method, endpoint = %self.inner.endpoint, "opening WebSocket stream");
319
320 let rpc_req = build_rpc_request(method, params);
321 let request_id = rpc_req
322 .id
323 .as_ref()
324 .and_then(|v| v.as_str())
325 .unwrap_or("")
326 .to_owned();
327 let body = serde_json::to_string(&rpc_req).map_err(ClientError::Serialization)?;
328
329 let (tx, rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(64);
331
332 self.inner
333 .write_tx
334 .send(WriteCommand {
335 text: body,
336 request_id,
337 pending: PendingRequest::Streaming(tx),
338 })
339 .await
340 .map_err(|_| ClientError::Transport("WebSocket writer task closed".into()))?;
341
342 Ok(EventStream::new(rx))
343 }
344}
345
346impl Transport for WebSocketTransport {
347 fn send_request<'a>(
348 &'a self,
349 method: &'a str,
350 params: serde_json::Value,
351 extra_headers: &'a HashMap<String, String>,
352 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
353 Box::pin(self.execute_request(method, params, extra_headers))
354 }
355
356 fn send_streaming_request<'a>(
357 &'a self,
358 method: &'a str,
359 params: serde_json::Value,
360 extra_headers: &'a HashMap<String, String>,
361 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
362 Box::pin(self.execute_streaming_request(method, params, extra_headers))
363 }
364}
365
366impl std::fmt::Debug for WebSocketTransport {
367 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 f.debug_struct("WebSocketTransport")
369 .field("endpoint", &self.inner.endpoint)
370 .finish()
371 }
372}
373
374async fn route_frame(pending: &Arc<Mutex<HashMap<String, PendingRequest>>>, text: &str) {
381 let frame_id = extract_jsonrpc_id(text);
383
384 let mut map = pending.lock().await;
385
386 let request_id = if let Some(ref id) = frame_id {
387 id.clone()
388 } else {
389 return;
392 };
393
394 if let Some(entry) = map.get(&request_id) {
395 match entry {
396 PendingRequest::Unary(_) => {
397 if let Some(PendingRequest::Unary(tx)) = map.remove(&request_id) {
399 let _ = tx.send(Ok(text.to_owned()));
400 }
401 }
402 PendingRequest::Streaming(tx) => {
403 let sse_line = format!("data: {text}\n\n");
405 if tx
406 .send(Ok(hyper::body::Bytes::from(sse_line)))
407 .await
408 .is_err()
409 {
410 map.remove(&request_id);
412 return;
413 }
414
415 if is_stream_terminal(text) {
417 map.remove(&request_id);
418 }
419 }
420 }
421 }
422}
423
424fn extract_jsonrpc_id(text: &str) -> Option<String> {
426 let v: serde_json::Value = serde_json::from_str(text).ok()?;
427 match v.get("id") {
428 Some(serde_json::Value::String(s)) => Some(s.clone()),
429 Some(serde_json::Value::Number(n)) => Some(n.to_string()),
430 _ => None,
431 }
432}
433
434fn is_stream_terminal(text: &str) -> bool {
445 let Ok(frame) = serde_json::from_str::<serde_json::Value>(text) else {
446 return false;
447 };
448
449 let has_terminal_state = |obj: &serde_json::Value| -> bool {
452 if let Some(status_update) = obj.get("statusUpdate") {
454 if let Some(status) = status_update.get("status") {
455 if let Some(state) = status.get("state").and_then(|s| s.as_str()) {
456 return matches!(state, "completed" | "failed" | "canceled" | "rejected");
457 }
458 }
459 }
460 if let Some(status) = obj.get("status") {
462 if let Some(state) = status.get("state").and_then(|s| s.as_str()) {
463 return matches!(state, "completed" | "failed" | "canceled" | "rejected");
464 }
465 }
466 false
467 };
468
469 if let Some(r) = frame.get("result") {
471 if r.get("stream_complete").is_some() {
475 return true;
476 }
477 if r.get("status").and_then(|s| s.as_str()) == Some("stream_complete") {
478 return true;
479 }
480 return has_terminal_state(r);
481 }
482
483 has_terminal_state(&frame)
486}
487
488fn build_rpc_request(method: &str, params: serde_json::Value) -> JsonRpcRequest {
489 let id = serde_json::Value::String(Uuid::new_v4().to_string());
490 JsonRpcRequest::with_params(id, method, params)
491}
492
493fn validate_ws_url(url: &str) -> ClientResult<()> {
494 if url.is_empty() {
495 return Err(ClientError::InvalidEndpoint("URL must not be empty".into()));
496 }
497 if !url.starts_with("ws://") && !url.starts_with("wss://") {
498 return Err(ClientError::InvalidEndpoint(format!(
499 "WebSocket URL must start with ws:// or wss://: {url}"
500 )));
501 }
502 Ok(())
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_ws_url_rejects_empty() {
        assert!(validate_ws_url("").is_err());
    }

    #[test]
    fn validate_ws_url_rejects_http() {
        assert!(validate_ws_url("http://localhost:8080").is_err());
    }

    #[test]
    fn validate_ws_url_accepts_ws() {
        assert!(validate_ws_url("ws://localhost:8080").is_ok());
    }

    #[test]
    fn validate_ws_url_accepts_wss() {
        assert!(validate_ws_url("wss://agent.example.com/a2a").is_ok());
    }

    #[test]
    fn validate_ws_url_rejects_scheme_without_separator() {
        assert!(validate_ws_url("ws:localhost").is_err());
    }

    #[test]
    fn is_stream_terminal_completed_status() {
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"statusUpdate":{"status":{"state":"completed"}}}}"#;
        assert!(is_stream_terminal(frame));
    }

    #[test]
    fn is_stream_terminal_failed_status() {
        let frame =
            r#"{"jsonrpc":"2.0","id":"1","result":{"statusUpdate":{"status":{"state":"failed"}}}}"#;
        assert!(is_stream_terminal(frame));
    }

    #[test]
    fn is_stream_terminal_working_is_not_terminal() {
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"statusUpdate":{"status":{"state":"working"}}}}"#;
        assert!(!is_stream_terminal(frame));
    }

    #[test]
    fn is_stream_terminal_input_required_is_not_terminal() {
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"statusUpdate":{"status":{"state":"input-required"}}}}"#;
        assert!(!is_stream_terminal(frame));
    }

    #[test]
    fn is_stream_terminal_stream_complete_sentinel() {
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"stream_complete":true}}"#;
        assert!(is_stream_terminal(frame));
    }

    #[test]
    fn is_stream_terminal_status_string_sentinel() {
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"status":"stream_complete"}}"#;
        assert!(is_stream_terminal(frame));
    }

    #[test]
    fn is_stream_terminal_null_error_member_not_terminal() {
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"statusUpdate":{"status":{"state":"working"}}},"error":null}"#;
        assert!(!is_stream_terminal(frame));
    }

    #[test]
    fn is_stream_terminal_artifact_not_terminal() {
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"artifactUpdate":{"artifact":{"id":"a1","parts":[]}}}}"#;
        assert!(!is_stream_terminal(frame));
    }

    #[test]
    fn is_stream_terminal_payload_containing_word_not_terminal() {
        // Terminal detection must look at structured fields, not at free text.
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"artifactUpdate":{"artifact":{"id":"a1","parts":[{"text":"task completed successfully"}]}}}}"#;
        assert!(!is_stream_terminal(frame));
    }

    #[test]
    fn build_rpc_request_has_method() {
        let req = build_rpc_request("TestMethod", serde_json::json!({"key": "val"}));
        assert_eq!(req.method, "TestMethod");
        let params = req.params.expect("params should be present");
        assert_eq!(params["key"], "val");
        let id = req.id.expect("id should be present");
        assert!(id.is_string(), "id should be a string UUID");
        assert!(!id.as_str().unwrap().is_empty(), "id should not be empty");
    }

    #[test]
    fn build_rpc_request_ids_are_unique() {
        let first = build_rpc_request("M", serde_json::json!({})).id;
        let second = build_rpc_request("M", serde_json::json!({})).id;
        assert_ne!(first, second, "each request must get a distinct id");
    }

    #[test]
    fn is_stream_terminal_invalid_json() {
        assert!(!is_stream_terminal("not json"));
    }

    #[test]
    fn is_stream_terminal_no_result() {
        assert!(!is_stream_terminal(r#"{"jsonrpc":"2.0","id":"1"}"#));
    }

    #[test]
    fn is_stream_terminal_task_level_completed() {
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"status":{"state":"completed"}}}"#;
        assert!(is_stream_terminal(frame));
    }

    #[test]
    fn is_stream_terminal_canceled() {
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"statusUpdate":{"status":{"state":"canceled"}}}}"#;
        assert!(is_stream_terminal(frame));
    }

    #[test]
    fn is_stream_terminal_rejected() {
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"statusUpdate":{"status":{"state":"rejected"}}}}"#;
        assert!(is_stream_terminal(frame));
    }

    #[test]
    fn is_stream_terminal_task_level_failed() {
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"status":{"state":"failed"}}}"#;
        assert!(is_stream_terminal(frame));
    }

    #[test]
    fn is_stream_terminal_non_string_state() {
        let frame = r#"{"jsonrpc":"2.0","id":"1","result":{"status":{"state":42}}}"#;
        assert!(!is_stream_terminal(frame));
    }

    #[test]
    fn validate_ws_url_rejects_https() {
        assert!(validate_ws_url("https://example.com").is_err());
    }

    #[test]
    fn validate_ws_url_error_message_contains_url() {
        let err = validate_ws_url("http://bad").unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("http://bad") || msg.contains("ws://"));
    }

    #[test]
    fn extract_jsonrpc_id_string() {
        let id = extract_jsonrpc_id(r#"{"jsonrpc":"2.0","id":"abc","result":{}}"#);
        assert_eq!(id.as_deref(), Some("abc"));
    }

    #[test]
    fn extract_jsonrpc_id_number() {
        let id = extract_jsonrpc_id(r#"{"jsonrpc":"2.0","id":42,"result":{}}"#);
        assert_eq!(id.as_deref(), Some("42"));
    }

    #[test]
    fn extract_jsonrpc_id_null_returns_none() {
        let id = extract_jsonrpc_id(r#"{"jsonrpc":"2.0","id":null,"result":{}}"#);
        assert!(id.is_none());
    }

    #[test]
    fn extract_jsonrpc_id_missing_returns_none() {
        let id = extract_jsonrpc_id(r#"{"jsonrpc":"2.0","result":{}}"#);
        assert!(id.is_none());
    }

    #[test]
    fn extract_jsonrpc_id_invalid_json_returns_none() {
        assert!(extract_jsonrpc_id("not json").is_none());
    }

    #[test]
    fn extract_jsonrpc_id_non_object_returns_none() {
        assert!(extract_jsonrpc_id("[1,2,3]").is_none());
    }
}
656}