1use std::collections::HashMap;
30use std::net::SocketAddr;
31use std::sync::Arc;
32
33use futures_util::stream::SplitSink;
34use futures_util::{SinkExt, StreamExt};
35use tokio::net::{TcpListener, TcpStream};
36use tokio_tungstenite::tungstenite::Message as WsMessage;
37use tokio_tungstenite::WebSocketStream;
38
39use a2a_protocol_types::jsonrpc::{
40 JsonRpcError, JsonRpcErrorResponse, JsonRpcId, JsonRpcRequest, JsonRpcSuccessResponse,
41 JsonRpcVersion,
42};
43
44use crate::error::ServerError;
45use crate::handler::{RequestHandler, SendMessageResult};
46use crate::streaming::EventQueueReader;
47
/// Serves JSON-RPC requests over WebSocket connections by forwarding them to
/// a shared [`RequestHandler`].
pub struct WebSocketDispatcher {
    /// Handler that services every request; cloned (by `Arc`) into each
    /// per-message task.
    handler: Arc<RequestHandler>,
}
55
56impl WebSocketDispatcher {
57 #[must_use]
59 pub const fn new(handler: Arc<RequestHandler>) -> Self {
60 Self { handler }
61 }
62
63 pub async fn serve(
69 self: Arc<Self>,
70 addr: impl tokio::net::ToSocketAddrs,
71 ) -> std::io::Result<()> {
72 let listener = TcpListener::bind(addr).await?;
73
74 trace_info!(
75 addr = %listener.local_addr().unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 0))),
76 "A2A WebSocket server listening"
77 );
78
79 loop {
80 let (stream, _peer) = listener.accept().await?;
81 let dispatcher = Arc::clone(&self);
82 tokio::spawn(async move {
83 trace_debug!("WebSocket connection accepted");
84 if let Err(_e) = dispatcher.handle_connection(stream).await {
85 trace_warn!("WebSocket connection error");
86 }
87 });
88 }
89 }
90
91 pub async fn serve_with_addr(
99 self: Arc<Self>,
100 addr: impl tokio::net::ToSocketAddrs,
101 ) -> std::io::Result<SocketAddr> {
102 let listener = TcpListener::bind(addr).await?;
103 let local_addr = listener.local_addr()?;
104
105 trace_info!(%local_addr, "A2A WebSocket server listening");
106
107 tokio::spawn(async move {
108 loop {
109 let Ok((stream, _peer)) = listener.accept().await else {
110 break;
111 };
112 let dispatcher = Arc::clone(&self);
113 tokio::spawn(async move {
114 let _ = dispatcher.handle_connection(stream).await;
115 });
116 }
117 });
118
119 Ok(local_addr)
120 }
121
122 async fn handle_connection(&self, stream: TcpStream) -> Result<(), WsError> {
124 let ws_stream = tokio_tungstenite::accept_async(stream)
125 .await
126 .map_err(WsError::Handshake)?;
127
128 let (writer, mut reader) = ws_stream.split();
129 let writer = Arc::new(tokio::sync::Mutex::new(writer));
130
131 let semaphore = Arc::new(tokio::sync::Semaphore::new(64));
133
134 while let Some(msg) = reader.next().await {
135 match msg {
136 Ok(WsMessage::Text(text)) => {
137 if text.len() > 4 * 1024 * 1024 {
139 let err_resp = JsonRpcErrorResponse::new(
140 None,
141 JsonRpcError::new(-32000, "message too large".to_string()),
142 );
143 send_json(&writer, &err_resp).await;
144 continue;
145 }
146
147 let Ok(permit) = semaphore.clone().try_acquire_owned() else {
149 let err_resp = JsonRpcErrorResponse::new(
150 None,
151 JsonRpcError::new(
152 -32000,
153 "server busy: too many concurrent requests".to_string(),
154 ),
155 );
156 send_json(&writer, &err_resp).await;
157 continue;
158 };
159
160 let writer = Arc::clone(&writer);
161 let handler = Arc::clone(&self.handler);
162 tokio::spawn(async move {
163 process_ws_message(&handler, &text, writer).await;
164 drop(permit); });
166 }
167 Ok(WsMessage::Ping(data)) => {
168 let mut w = writer.lock().await;
169 let _ = w.send(WsMessage::Pong(data)).await;
170 drop(w);
171 }
172 Ok(WsMessage::Close(_)) | Err(_) => break,
173 Ok(_) => {} }
175 }
176
177 Ok(())
178 }
179}
180
/// Errors returned by per-connection handling.
#[derive(Debug)]
enum WsError {
    /// The WebSocket upgrade on an accepted TCP stream failed; wraps the
    /// error reported by tungstenite.
    Handshake(tokio_tungstenite::tungstenite::Error),
}
186
187impl std::fmt::Display for WsError {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 match self {
190 Self::Handshake(e) => write!(f, "WebSocket handshake failed: {e}"),
191 }
192 }
193}
194
/// Write half of a server-side WebSocket, wrapped in `Arc<Mutex<..>>` so the
/// connection's read loop and its per-message tasks can all send frames.
type WsSink = Arc<tokio::sync::Mutex<SplitSink<WebSocketStream<TcpStream>, WsMessage>>>;
196
197#[allow(clippy::too_many_lines)]
199async fn process_ws_message(handler: &RequestHandler, text: &str, writer: WsSink) {
200 let rpc_req: JsonRpcRequest = match serde_json::from_str(text) {
201 Ok(req) => req,
202 Err(e) => {
203 let err_resp = JsonRpcErrorResponse::new(
204 None,
205 JsonRpcError::new(-32700, format!("parse error: {e}")),
206 );
207 send_json(&writer, &err_resp).await;
208 return;
209 }
210 };
211
212 let id = rpc_req.id.clone();
213 let headers = HashMap::new();
214
215 match rpc_req.method.as_str() {
216 "SendMessage" => {
217 dispatch_send_message(handler, &rpc_req, false, &headers, id, &writer).await;
218 }
219 "SendStreamingMessage" => {
220 dispatch_send_message(handler, &rpc_req, true, &headers, id, &writer).await;
221 }
222 "GetTask" => {
223 dispatch_simple(handler, &rpc_req, id, &headers, &writer, |h, p, hdr| {
224 Box::pin(async move {
225 let params: a2a_protocol_types::params::TaskQueryParams =
226 serde_json::from_value(p).map_err(|e| {
227 a2a_protocol_types::error::A2aError::invalid_params(e.to_string())
228 })?;
229 h.on_get_task(params, Some(hdr))
230 .await
231 .map(|r| serde_json::to_value(&r).unwrap_or_default())
232 .map_err(|e| e.to_a2a_error())
233 })
234 })
235 .await;
236 }
237 "ListTasks" => {
238 dispatch_simple(handler, &rpc_req, id, &headers, &writer, |h, p, hdr| {
239 Box::pin(async move {
240 let params: a2a_protocol_types::params::ListTasksParams =
241 serde_json::from_value(p).map_err(|e| {
242 a2a_protocol_types::error::A2aError::invalid_params(e.to_string())
243 })?;
244 h.on_list_tasks(params, Some(hdr))
245 .await
246 .map(|r| serde_json::to_value(&r).unwrap_or_default())
247 .map_err(|e| e.to_a2a_error())
248 })
249 })
250 .await;
251 }
252 "CancelTask" => {
253 dispatch_simple(handler, &rpc_req, id, &headers, &writer, |h, p, hdr| {
254 Box::pin(async move {
255 let params: a2a_protocol_types::params::CancelTaskParams =
256 serde_json::from_value(p).map_err(|e| {
257 a2a_protocol_types::error::A2aError::invalid_params(e.to_string())
258 })?;
259 h.on_cancel_task(params, Some(hdr))
260 .await
261 .map(|r| serde_json::to_value(&r).unwrap_or_default())
262 .map_err(|e| e.to_a2a_error())
263 })
264 })
265 .await;
266 }
267 "SubscribeToTask" => {
268 let params = match parse_params::<a2a_protocol_types::params::TaskIdParams>(
269 rpc_req.params.as_ref(),
270 ) {
271 Ok(p) => p,
272 Err(e) => {
273 send_error(&writer, id, &e).await;
274 return;
275 }
276 };
277 match handler.on_resubscribe(params, Some(&headers)).await {
278 Ok(reader) => {
279 stream_events(&writer, reader, id).await;
280 }
281 Err(e) => {
282 send_error(&writer, id, &e).await;
283 }
284 }
285 }
286 other => {
287 let err = ServerError::MethodNotFound(other.to_owned());
288 send_error(&writer, id, &err).await;
289 }
290 }
291}
292
293async fn dispatch_send_message(
295 handler: &RequestHandler,
296 rpc_req: &JsonRpcRequest,
297 streaming: bool,
298 headers: &HashMap<String, String>,
299 id: JsonRpcId,
300 writer: &WsSink,
301) {
302 let params = match parse_params::<a2a_protocol_types::params::MessageSendParams>(
303 rpc_req.params.as_ref(),
304 ) {
305 Ok(p) => p,
306 Err(e) => {
307 send_error(writer, id, &e).await;
308 return;
309 }
310 };
311
312 match handler
313 .on_send_message(params, streaming, Some(headers))
314 .await
315 {
316 Ok(SendMessageResult::Response(resp)) => {
317 let result = serde_json::to_value(&resp).unwrap_or(serde_json::Value::Null);
318 let success = JsonRpcSuccessResponse {
319 jsonrpc: JsonRpcVersion,
320 id,
321 result,
322 };
323 send_json(writer, &success).await;
324 }
325 Ok(SendMessageResult::Stream(reader)) => {
326 stream_events(writer, reader, id).await;
327 }
328 Err(e) => {
329 send_error(writer, id, &e).await;
330 }
331 }
332}
333
334async fn stream_events(
336 writer: &WsSink,
337 mut reader: crate::streaming::InMemoryQueueReader,
338 id: JsonRpcId,
339) {
340 while let Some(event) = reader.read().await {
341 match event {
342 Ok(stream_resp) => {
343 let envelope = JsonRpcSuccessResponse {
346 jsonrpc: JsonRpcVersion,
347 id: id.clone(),
348 result: stream_resp,
349 };
350 let json = serde_json::to_string(&envelope).unwrap_or_default();
351 let mut w = writer.lock().await;
352 if w.send(WsMessage::Text(json)).await.is_err() {
353 return; }
355 drop(w);
356 }
357 Err(e) => {
358 let err_resp =
359 JsonRpcErrorResponse::new(id.clone(), JsonRpcError::new(-32000, e.to_string()));
360 send_json(writer, &err_resp).await;
361 return;
362 }
363 }
364 }
365
366 let success = JsonRpcSuccessResponse {
368 jsonrpc: JsonRpcVersion,
369 id,
370 result: serde_json::json!({"status": "stream_complete"}),
371 };
372 send_json(writer, &success).await;
373}
374
375async fn dispatch_simple<'a, F>(
377 handler: &'a RequestHandler,
378 rpc_req: &JsonRpcRequest,
379 id: JsonRpcId,
380 headers: &'a HashMap<String, String>,
381 writer: &WsSink,
382 f: F,
383) where
384 F: FnOnce(
385 &'a RequestHandler,
386 serde_json::Value,
387 &'a HashMap<String, String>,
388 ) -> std::pin::Pin<
389 Box<
390 dyn std::future::Future<
391 Output = Result<serde_json::Value, a2a_protocol_types::error::A2aError>,
392 > + Send
393 + 'a,
394 >,
395 >,
396{
397 let params = rpc_req.params.clone().unwrap_or(serde_json::Value::Null);
398 match f(handler, params, headers).await {
399 Ok(result) => {
400 let success = JsonRpcSuccessResponse {
401 jsonrpc: JsonRpcVersion,
402 id,
403 result,
404 };
405 send_json(writer, &success).await;
406 }
407 Err(e) => {
408 let err_resp =
409 JsonRpcErrorResponse::new(id, JsonRpcError::new(e.code.as_i32(), e.message));
410 send_json(writer, &err_resp).await;
411 }
412 }
413}
414
415async fn send_json<T: serde::Serialize + Sync>(writer: &WsSink, value: &T) {
417 let json = serde_json::to_string(value).unwrap_or_default();
418 let mut w = writer.lock().await;
419 let _ = w.send(WsMessage::Text(json)).await;
420 drop(w);
421}
422
423async fn send_error(writer: &WsSink, id: JsonRpcId, err: &ServerError) {
425 let a2a_err = err.to_a2a_error();
426 let resp = JsonRpcErrorResponse::new(
427 id,
428 JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
429 );
430 send_json(writer, &resp).await;
431}
432
433fn parse_params<T: serde::de::DeserializeOwned>(
435 params: Option<&serde_json::Value>,
436) -> Result<T, ServerError> {
437 let value = params.cloned().unwrap_or(serde_json::Value::Null);
438 serde_json::from_value(value)
439 .map_err(|e| ServerError::InvalidParams(format!("invalid params: {e}")))
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    use crate::agent_executor;
    use crate::RequestHandlerBuilder;
    use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
    use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};
    use futures_util::{SinkExt, StreamExt};

    /// Client-side WebSocket as returned by `connect_async`.
    type ClientWs =
        tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;

    /// Parameter type used to exercise `parse_params`.
    type TaskQuery = a2a_protocol_types::params::TaskQueryParams;

    // ── Unit tests ──────────────────────────────────────────────────────

    #[test]
    fn parse_params_with_valid_json() {
        let raw = serde_json::json!({"id": "task-1"});
        let result = parse_params::<TaskQuery>(Some(&raw));
        assert!(result.is_ok());
        assert_eq!(result.unwrap().id, "task-1");
    }

    #[test]
    fn parse_params_with_none_returns_error() {
        let result = parse_params::<TaskQuery>(None);
        assert!(result.is_err());
    }

    #[test]
    fn parse_params_with_wrong_type_returns_error() {
        let raw = serde_json::json!("not an object");
        let result = parse_params::<TaskQuery>(Some(&raw));
        assert!(result.is_err());
    }

    #[test]
    fn ws_error_display_contains_message() {
        let rendered =
            WsError::Handshake(tokio_tungstenite::tungstenite::Error::ConnectionClosed)
                .to_string();
        assert!(rendered.contains("WebSocket handshake failed"));
    }

    #[test]
    fn websocket_dispatcher_new() {
        struct DummyExec;
        agent_executor!(DummyExec, |_ctx, _queue| async { Ok(()) });
        let handler = Arc::new(RequestHandlerBuilder::new(DummyExec).build().unwrap());
        let _dispatcher = WebSocketDispatcher::new(handler);
    }

    // ── Integration-test fixtures ───────────────────────────────────────

    /// Executor that reports `Working` and then `Completed` for every task.
    struct EchoExec;
    agent_executor!(EchoExec, |ctx, queue| async {
        for state in [TaskState::Working, TaskState::Completed] {
            queue
                .write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
                    task_id: ctx.task_id.clone(),
                    context_id: ContextId::new(ctx.context_id.clone()),
                    status: TaskStatus::new(state),
                    metadata: None,
                }))
                .await?;
        }
        Ok(())
    });

    /// Starts a dispatcher backed by `EchoExec` on an OS-assigned port.
    async fn spawn_ws_server() -> std::net::SocketAddr {
        let handler = RequestHandlerBuilder::new(EchoExec).build().unwrap();
        let dispatcher = Arc::new(WebSocketDispatcher::new(Arc::new(handler)));
        dispatcher
            .serve_with_addr("127.0.0.1:0")
            .await
            .expect("bind to port 0")
    }

    /// Opens a WebSocket client connection to `addr`.
    async fn ws_connect(addr: std::net::SocketAddr) -> ClientWs {
        let (ws, _) = tokio_tungstenite::connect_async(format!("ws://{addr}"))
            .await
            .expect("ws connect");
        ws
    }

    /// Starts a fresh server and returns a client connected to it.
    async fn connect_to_new_server() -> ClientWs {
        let addr = spawn_ws_server().await;
        ws_connect(addr).await
    }

    /// Reads the next frame, waiting at most five seconds, and returns its text.
    async fn read_text(ws: &mut ClientWs) -> String {
        let next_frame = tokio::time::timeout(std::time::Duration::from_secs(5), ws.next());
        let msg = next_frame
            .await
            .expect("timeout waiting for WS frame")
            .expect("stream ended")
            .expect("ws error");
        msg.into_text().expect("not a text frame")
    }

    /// Renders a JSON-RPC 2.0 request for `method` with the given id and params.
    fn rpc_frame(method: &str, id: &str, params: serde_json::Value) -> String {
        serde_json::json!({
            "jsonrpc": "2.0",
            "method": method,
            "id": id,
            "params": params
        })
        .to_string()
    }

    /// Sends `frame` as text and returns the next reply, both raw and parsed.
    async fn round_trip(ws: &mut ClientWs, frame: String) -> (String, serde_json::Value) {
        ws.send(WsMessage::Text(frame)).await.unwrap();
        let text = read_text(ws).await;
        let value = serde_json::from_str(&text).unwrap();
        (text, value)
    }

    /// Renders a `SendMessage` request carrying a one-part user message.
    fn send_message_json(id: &str) -> String {
        rpc_frame(
            "SendMessage",
            id,
            serde_json::json!({
                "message": {
                    "messageId": "msg-1",
                    "role": "user",
                    "parts": [{"type": "text", "text": "hello"}]
                }
            }),
        )
    }

    // ── Integration tests ───────────────────────────────────────────────

    #[tokio::test]
    async fn ws_send_message_success() {
        let mut ws = connect_to_new_server().await;

        let (text, v) = round_trip(&mut ws, send_message_json("sm-1")).await;
        assert_eq!(v["id"], "sm-1");
        assert!(v.get("result").is_some(), "expected result key: {text}");
    }

    #[tokio::test]
    async fn ws_get_task_not_found() {
        let mut ws = connect_to_new_server().await;

        let req = rpc_frame("GetTask", "gt-1", serde_json::json!({"id": "nonexistent"}));
        let (text, v) = round_trip(&mut ws, req).await;
        assert!(v.get("error").is_some(), "expected error: {text}");
    }

    #[tokio::test]
    async fn ws_list_tasks_success() {
        let mut ws = connect_to_new_server().await;

        let req = rpc_frame("ListTasks", "lt-1", serde_json::json!({}));
        let (text, v) = round_trip(&mut ws, req).await;
        assert_eq!(v["id"], "lt-1");
        assert!(v.get("result").is_some(), "expected result: {text}");
    }

    #[tokio::test]
    async fn ws_cancel_task_not_found() {
        let mut ws = connect_to_new_server().await;

        let req = rpc_frame("CancelTask", "ct-1", serde_json::json!({"id": "nonexistent"}));
        let (text, v) = round_trip(&mut ws, req).await;
        assert!(v.get("error").is_some(), "expected error: {text}");
    }

    #[tokio::test]
    async fn ws_subscribe_task_not_found() {
        let mut ws = connect_to_new_server().await;

        let req = rpc_frame(
            "SubscribeToTask",
            "sub-1",
            serde_json::json!({"id": "nonexistent"}),
        );
        let (text, v) = round_trip(&mut ws, req).await;
        assert!(v.get("error").is_some(), "expected error: {text}");
    }

    #[tokio::test]
    async fn ws_unknown_method_error() {
        let mut ws = connect_to_new_server().await;

        let req = rpc_frame("FooBar", "unk-1", serde_json::json!({}));
        let (text, v) = round_trip(&mut ws, req).await;
        assert!(v.get("error").is_some(), "expected error: {text}");

        let msg = v["error"]["message"].as_str().unwrap_or("");
        let lowered = msg.to_lowercase();
        let mentions_method = ["method", "not found", "unsupported"]
            .iter()
            .any(|needle| lowered.contains(needle));
        assert!(
            mentions_method,
            "error message should mention method not found: {msg}"
        );
    }

    #[tokio::test]
    async fn ws_invalid_json_parse_error() {
        let mut ws = connect_to_new_server().await;

        let (_text, v) = round_trip(&mut ws, "this is not json {{".to_string()).await;
        assert_eq!(v["error"]["code"], -32700, "expected parse error code");
    }

    #[tokio::test]
    async fn ws_oversized_message_rejected() {
        let mut ws = connect_to_new_server().await;

        // One byte over the server's 4 MiB frame limit.
        let big = "x".repeat(4 * 1024 * 1024 + 1);
        let (text, v) = round_trip(&mut ws, big).await;
        assert!(v.get("error").is_some(), "expected error: {text}");

        let msg = v["error"]["message"].as_str().unwrap_or("");
        assert!(
            msg.contains("too large"),
            "error should mention 'too large': {msg}"
        );
    }

    #[tokio::test]
    async fn ws_ping_pong_response() {
        let mut ws = connect_to_new_server().await;

        ws.send(WsMessage::Ping(vec![42, 43])).await.unwrap();

        // Skip any other frames until the pong shows up.
        let wait_for_pong = async {
            loop {
                if let WsMessage::Pong(data) = ws.next().await.unwrap().unwrap() {
                    break data;
                }
            }
        };
        let pong = tokio::time::timeout(std::time::Duration::from_secs(3), wait_for_pong)
            .await
            .expect("should get pong within 3s");
        assert_eq!(pong, vec![42, 43]);
    }

    #[tokio::test]
    async fn ws_get_task_invalid_params() {
        let mut ws = connect_to_new_server().await;

        let req = rpc_frame("GetTask", "gti-1", serde_json::json!({"wrong_field": 123}));
        let (text, v) = round_trip(&mut ws, req).await;
        assert!(
            v.get("error").is_some(),
            "expected error for bad params: {text}"
        );
    }

    #[tokio::test]
    async fn ws_send_streaming_message_events() {
        let mut ws = connect_to_new_server().await;

        let req = rpc_frame(
            "SendStreamingMessage",
            "ssm-1",
            serde_json::json!({
                "message": {
                    "messageId": "msg-stream-1",
                    "role": "user",
                    "parts": [{"type": "text", "text": "stream me"}]
                }
            }),
        );
        ws.send(WsMessage::Text(req)).await.unwrap();

        // Collect frames up to and including the completion marker.
        let collect_frames = async {
            let mut frames = Vec::new();
            loop {
                let text = ws.next().await.unwrap().unwrap().into_text().unwrap();
                let done = text.contains("stream_complete");
                frames.push(text);
                if done {
                    break frames;
                }
            }
        };
        let frames = tokio::time::timeout(std::time::Duration::from_secs(5), collect_frames)
            .await
            .expect("streaming should complete within 5s");

        assert!(
            frames.len() >= 3,
            "expected >= 3 frames, got {}: {:?}",
            frames.len(),
            frames
        );
        assert!(frames.last().unwrap().contains("stream_complete"));
    }

    #[tokio::test]
    async fn ws_send_message_invalid_params() {
        let mut ws = connect_to_new_server().await;

        let req = rpc_frame("SendMessage", "smi-1", serde_json::json!({"not_message": true}));
        let (text, v) = round_trip(&mut ws, req).await;
        assert!(
            v.get("error").is_some(),
            "expected error for bad send params: {text}"
        );
    }

    #[tokio::test]
    async fn ws_subscribe_invalid_params() {
        let mut ws = connect_to_new_server().await;

        let req = rpc_frame("SubscribeToTask", "subi-1", serde_json::json!({}));
        let (text, v) = round_trip(&mut ws, req).await;
        assert!(
            v.get("error").is_some(),
            "expected error for bad subscribe params: {text}"
        );
    }

    #[tokio::test]
    async fn ws_cancel_task_invalid_params() {
        let mut ws = connect_to_new_server().await;

        let req = rpc_frame("CancelTask", "cti-1", serde_json::json!({"wrong": 1}));
        let (text, v) = round_trip(&mut ws, req).await;
        assert!(v.get("error").is_some(), "expected error: {text}");
    }

    #[tokio::test]
    async fn ws_list_tasks_with_filters() {
        let mut ws = connect_to_new_server().await;

        let req = rpc_frame(
            "ListTasks",
            "ltf-1",
            serde_json::json!({
                "contextId": "ctx-1",
                "pageSize": 10
            }),
        );
        let (text, v) = round_trip(&mut ws, req).await;
        assert_eq!(v["id"], "ltf-1");
        assert!(v.get("result").is_some(), "expected result: {text}");
    }
}