1use std::collections::HashMap;
30use std::net::SocketAddr;
31use std::sync::Arc;
32
33use futures_util::stream::SplitSink;
34use futures_util::{SinkExt, StreamExt};
35use tokio::net::{TcpListener, TcpStream};
36use tokio_tungstenite::tungstenite::Message as WsMessage;
37use tokio_tungstenite::WebSocketStream;
38
39use a2a_protocol_types::jsonrpc::{
40 JsonRpcError, JsonRpcErrorResponse, JsonRpcId, JsonRpcRequest, JsonRpcSuccessResponse,
41 JsonRpcVersion,
42};
43
44use crate::error::ServerError;
45use crate::handler::{RequestHandler, SendMessageResult};
46use crate::streaming::EventQueueReader;
47
/// WebSocket front end that serves JSON-RPC requests over accepted
/// TCP connections.
///
/// Each text frame received on a connection is dispatched to the shared
/// [`RequestHandler`].
pub struct WebSocketDispatcher {
    /// Handler shared by every connection; cloned (by `Arc`) into each
    /// per-message task.
    handler: Arc<RequestHandler>,
}
55
56impl WebSocketDispatcher {
57 #[must_use]
59 pub const fn new(handler: Arc<RequestHandler>) -> Self {
60 Self { handler }
61 }
62
63 pub async fn serve(
69 self: Arc<Self>,
70 addr: impl tokio::net::ToSocketAddrs,
71 ) -> std::io::Result<()> {
72 let listener = TcpListener::bind(addr).await?;
73
74 trace_info!(
75 addr = %listener.local_addr().unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 0))),
76 "A2A WebSocket server listening"
77 );
78
79 loop {
80 let (stream, _peer) = listener.accept().await?;
81 let dispatcher = Arc::clone(&self);
82 tokio::spawn(async move {
83 trace_debug!("WebSocket connection accepted");
84 if let Err(_e) = dispatcher.handle_connection(stream).await {
85 trace_warn!("WebSocket connection error");
86 }
87 });
88 }
89 }
90
91 pub async fn serve_with_addr(
99 self: Arc<Self>,
100 addr: impl tokio::net::ToSocketAddrs,
101 ) -> std::io::Result<SocketAddr> {
102 let listener = TcpListener::bind(addr).await?;
103 let local_addr = listener.local_addr()?;
104
105 trace_info!(%local_addr, "A2A WebSocket server listening");
106
107 tokio::spawn(async move {
108 loop {
109 let Ok((stream, _peer)) = listener.accept().await else {
110 break;
111 };
112 let dispatcher = Arc::clone(&self);
113 tokio::spawn(async move {
114 let _ = dispatcher.handle_connection(stream).await;
115 });
116 }
117 });
118
119 Ok(local_addr)
120 }
121
122 async fn handle_connection(&self, stream: TcpStream) -> Result<(), WsError> {
124 let ws_stream = tokio_tungstenite::accept_async(stream)
125 .await
126 .map_err(WsError::Handshake)?;
127
128 let (writer, mut reader) = ws_stream.split();
129 let writer = Arc::new(tokio::sync::Mutex::new(writer));
130
131 let semaphore = Arc::new(tokio::sync::Semaphore::new(64));
133
134 while let Some(msg) = reader.next().await {
135 match msg {
136 Ok(WsMessage::Text(text)) => {
137 if text.len() > 4 * 1024 * 1024 {
139 let err_resp = JsonRpcErrorResponse::new(
140 None,
141 JsonRpcError::new(-32000, "message too large".to_string()),
142 );
143 send_json(&writer, &err_resp).await;
144 continue;
145 }
146
147 let Ok(permit) = semaphore.clone().try_acquire_owned() else {
149 let err_resp = JsonRpcErrorResponse::new(
150 None,
151 JsonRpcError::new(
152 -32000,
153 "server busy: too many concurrent requests".to_string(),
154 ),
155 );
156 send_json(&writer, &err_resp).await;
157 continue;
158 };
159
160 let writer = Arc::clone(&writer);
161 let handler = Arc::clone(&self.handler);
162 tokio::spawn(async move {
163 process_ws_message(&handler, &text, writer).await;
164 drop(permit); });
166 }
167 Ok(WsMessage::Ping(data)) => {
168 let mut w = writer.lock().await;
169 let _ = w.send(WsMessage::Pong(data)).await;
170 drop(w);
171 }
172 Ok(WsMessage::Close(_)) | Err(_) => break,
173 Ok(_) => {} }
175 }
176
177 Ok(())
178 }
179}
180
/// Errors returned by the per-connection WebSocket handler.
#[derive(Debug)]
enum WsError {
    /// The WebSocket handshake on an accepted TCP stream failed.
    Handshake(tokio_tungstenite::tungstenite::Error),
}
186
187impl std::fmt::Display for WsError {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 match self {
190 Self::Handshake(e) => write!(f, "WebSocket handshake failed: {e}"),
191 }
192 }
193}
194
/// Write half of a server-side WebSocket connection, wrapped in an async
/// mutex inside an `Arc` so several tasks can send frames on it.
type WsSink = Arc<tokio::sync::Mutex<SplitSink<WebSocketStream<TcpStream>, WsMessage>>>;
196
197#[allow(clippy::too_many_lines)]
199async fn process_ws_message(handler: &RequestHandler, text: &str, writer: WsSink) {
200 let rpc_req: JsonRpcRequest = match serde_json::from_str(text) {
201 Ok(req) => req,
202 Err(e) => {
203 let err_resp = JsonRpcErrorResponse::new(
204 None,
205 JsonRpcError::new(-32700, format!("parse error: {e}")),
206 );
207 send_json(&writer, &err_resp).await;
208 return;
209 }
210 };
211
212 let id = rpc_req.id.clone();
213 let headers = HashMap::new();
214
215 match rpc_req.method.as_str() {
216 "SendMessage" => {
217 dispatch_send_message(handler, &rpc_req, false, &headers, id, &writer).await;
218 }
219 "SendStreamingMessage" => {
220 dispatch_send_message(handler, &rpc_req, true, &headers, id, &writer).await;
221 }
222 "GetTask" => {
223 dispatch_simple(handler, &rpc_req, id, &headers, &writer, |h, p, hdr| {
224 Box::pin(async move {
225 let params: a2a_protocol_types::params::TaskQueryParams =
226 serde_json::from_value(p).map_err(|e| {
227 a2a_protocol_types::error::A2aError::invalid_params(e.to_string())
228 })?;
229 h.on_get_task(params, Some(hdr))
230 .await
231 .map(|r| serde_json::to_value(&r).unwrap_or_default())
232 .map_err(|e| e.to_a2a_error())
233 })
234 })
235 .await;
236 }
237 "ListTasks" => {
238 dispatch_simple(handler, &rpc_req, id, &headers, &writer, |h, p, hdr| {
239 Box::pin(async move {
240 let params: a2a_protocol_types::params::ListTasksParams =
241 serde_json::from_value(p).map_err(|e| {
242 a2a_protocol_types::error::A2aError::invalid_params(e.to_string())
243 })?;
244 h.on_list_tasks(params, Some(hdr))
245 .await
246 .map(|r| serde_json::to_value(&r).unwrap_or_default())
247 .map_err(|e| e.to_a2a_error())
248 })
249 })
250 .await;
251 }
252 "CancelTask" => {
253 dispatch_simple(handler, &rpc_req, id, &headers, &writer, |h, p, hdr| {
254 Box::pin(async move {
255 let params: a2a_protocol_types::params::CancelTaskParams =
256 serde_json::from_value(p).map_err(|e| {
257 a2a_protocol_types::error::A2aError::invalid_params(e.to_string())
258 })?;
259 h.on_cancel_task(params, Some(hdr))
260 .await
261 .map(|r| serde_json::to_value(&r).unwrap_or_default())
262 .map_err(|e| e.to_a2a_error())
263 })
264 })
265 .await;
266 }
267 "SubscribeToTask" => {
268 let params = match parse_params::<a2a_protocol_types::params::TaskIdParams>(
269 rpc_req.params.as_ref(),
270 ) {
271 Ok(p) => p,
272 Err(e) => {
273 send_error(&writer, id, &e).await;
274 return;
275 }
276 };
277 match handler.on_resubscribe(params, Some(&headers)).await {
278 Ok(reader) => {
279 stream_events(&writer, reader, id).await;
280 }
281 Err(e) => {
282 send_error(&writer, id, &e).await;
283 }
284 }
285 }
286 other => {
287 let err = ServerError::MethodNotFound(other.to_owned());
288 send_error(&writer, id, &err).await;
289 }
290 }
291}
292
293async fn dispatch_send_message(
295 handler: &RequestHandler,
296 rpc_req: &JsonRpcRequest,
297 streaming: bool,
298 headers: &HashMap<String, String>,
299 id: JsonRpcId,
300 writer: &WsSink,
301) {
302 let params = match parse_params::<a2a_protocol_types::params::MessageSendParams>(
303 rpc_req.params.as_ref(),
304 ) {
305 Ok(p) => p,
306 Err(e) => {
307 send_error(writer, id, &e).await;
308 return;
309 }
310 };
311
312 match handler
313 .on_send_message(params, streaming, Some(headers))
314 .await
315 {
316 Ok(SendMessageResult::Response(resp)) => {
317 let result = serde_json::to_value(&resp).unwrap_or(serde_json::Value::Null);
318 let success = JsonRpcSuccessResponse {
319 jsonrpc: JsonRpcVersion,
320 id,
321 result,
322 };
323 send_json(writer, &success).await;
324 }
325 Ok(SendMessageResult::Stream(reader)) => {
326 stream_events(writer, reader, id).await;
327 }
328 Err(e) => {
329 send_error(writer, id, &e).await;
330 }
331 }
332}
333
334async fn stream_events(
336 writer: &WsSink,
337 mut reader: crate::streaming::InMemoryQueueReader,
338 id: JsonRpcId,
339) {
340 while let Some(event) = reader.read().await {
341 match event {
342 Ok(stream_resp) => {
343 let envelope = JsonRpcSuccessResponse {
346 jsonrpc: JsonRpcVersion,
347 id: id.clone(),
348 result: stream_resp,
349 };
350 let json = serde_json::to_string(&envelope).unwrap_or_default();
351 let mut w = writer.lock().await;
352 if w.send(WsMessage::Text(json.into())).await.is_err() {
353 return; }
355 drop(w);
356 }
357 Err(e) => {
358 let err_resp =
359 JsonRpcErrorResponse::new(id.clone(), JsonRpcError::new(-32000, e.to_string()));
360 send_json(writer, &err_resp).await;
361 return;
362 }
363 }
364 }
365
366 let success = JsonRpcSuccessResponse {
368 jsonrpc: JsonRpcVersion,
369 id,
370 result: serde_json::json!({"status": "stream_complete"}),
371 };
372 send_json(writer, &success).await;
373}
374
375async fn dispatch_simple<'a, F>(
377 handler: &'a RequestHandler,
378 rpc_req: &JsonRpcRequest,
379 id: JsonRpcId,
380 headers: &'a HashMap<String, String>,
381 writer: &WsSink,
382 f: F,
383) where
384 F: FnOnce(
385 &'a RequestHandler,
386 serde_json::Value,
387 &'a HashMap<String, String>,
388 ) -> std::pin::Pin<
389 Box<
390 dyn std::future::Future<
391 Output = Result<serde_json::Value, a2a_protocol_types::error::A2aError>,
392 > + Send
393 + 'a,
394 >,
395 >,
396{
397 let params = rpc_req.params.clone().unwrap_or(serde_json::Value::Null);
398 match f(handler, params, headers).await {
399 Ok(result) => {
400 let success = JsonRpcSuccessResponse {
401 jsonrpc: JsonRpcVersion,
402 id,
403 result,
404 };
405 send_json(writer, &success).await;
406 }
407 Err(e) => {
408 let err_resp =
409 JsonRpcErrorResponse::new(id, JsonRpcError::new(e.code.as_i32(), e.message));
410 send_json(writer, &err_resp).await;
411 }
412 }
413}
414
415async fn send_json<T: serde::Serialize + Sync>(writer: &WsSink, value: &T) {
417 let json = serde_json::to_string(value).unwrap_or_default();
418 let mut w = writer.lock().await;
419 let _ = w.send(WsMessage::Text(json.into())).await;
420 drop(w);
421}
422
423async fn send_error(writer: &WsSink, id: JsonRpcId, err: &ServerError) {
425 let a2a_err = err.to_a2a_error();
426 let resp = JsonRpcErrorResponse::new(
427 id,
428 JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
429 );
430 send_json(writer, &resp).await;
431}
432
433fn parse_params<T: serde::de::DeserializeOwned>(
435 params: Option<&serde_json::Value>,
436) -> Result<T, ServerError> {
437 let value = params.cloned().unwrap_or(serde_json::Value::Null);
438 serde_json::from_value(value)
439 .map_err(|e| ServerError::InvalidParams(format!("invalid params: {e}")))
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    use crate::agent_executor;
    use crate::RequestHandlerBuilder;
    use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
    use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};
    use futures_util::{SinkExt, StreamExt};

    /// Client side of a test WebSocket connection.
    type ClientWs =
        tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;

    // ── Unit tests ──────────────────────────────────────────────────────

    #[test]
    fn parse_params_with_valid_json() {
        let raw = serde_json::json!({"id": "task-1"});
        let parsed: Result<a2a_protocol_types::params::TaskQueryParams, _> =
            parse_params(Some(&raw));
        assert!(parsed.is_ok());
        assert_eq!(parsed.unwrap().id, "task-1");
    }

    #[test]
    fn parse_params_with_none_returns_error() {
        let parsed: Result<a2a_protocol_types::params::TaskQueryParams, _> = parse_params(None);
        assert!(parsed.is_err());
    }

    #[test]
    fn parse_params_with_wrong_type_returns_error() {
        let raw = serde_json::json!("not an object");
        let parsed: Result<a2a_protocol_types::params::TaskQueryParams, _> =
            parse_params(Some(&raw));
        assert!(parsed.is_err());
    }

    #[test]
    fn ws_error_display_contains_message() {
        let rendered =
            WsError::Handshake(tokio_tungstenite::tungstenite::Error::ConnectionClosed).to_string();
        assert!(rendered.contains("WebSocket handshake failed"));
    }

    #[test]
    fn websocket_dispatcher_new() {
        struct DummyExec;
        agent_executor!(DummyExec, |_ctx, _queue| async { Ok(()) });
        let handler = Arc::new(RequestHandlerBuilder::new(DummyExec).build().unwrap());
        let _dispatcher = WebSocketDispatcher::new(handler);
    }

    // ── Integration-test fixtures ───────────────────────────────────────

    /// Executor that emits a `Working` and then a `Completed` status update.
    struct EchoExec;
    agent_executor!(EchoExec, |ctx, queue| async {
        queue
            .write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
                task_id: ctx.task_id.clone(),
                context_id: ContextId::new(ctx.context_id.clone()),
                status: TaskStatus::new(TaskState::Working),
                metadata: None,
            }))
            .await?;
        queue
            .write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
                task_id: ctx.task_id.clone(),
                context_id: ContextId::new(ctx.context_id.clone()),
                status: TaskStatus::new(TaskState::Completed),
                metadata: None,
            }))
            .await?;
        Ok(())
    });

    /// Starts a dispatcher on an ephemeral local port and returns its address.
    async fn spawn_ws_server() -> std::net::SocketAddr {
        let handler = Arc::new(RequestHandlerBuilder::new(EchoExec).build().unwrap());
        Arc::new(WebSocketDispatcher::new(handler))
            .serve_with_addr("127.0.0.1:0")
            .await
            .expect("bind to port 0")
    }

    /// Opens a client WebSocket connection to `addr`.
    async fn ws_connect(addr: std::net::SocketAddr) -> ClientWs {
        let (ws, _) = tokio_tungstenite::connect_async(format!("ws://{addr}"))
            .await
            .expect("ws connect");
        ws
    }

    /// Reads the next frame (5 s timeout) and returns it as text.
    async fn read_text(ws: &mut ClientWs) -> String {
        let msg = tokio::time::timeout(std::time::Duration::from_secs(5), ws.next())
            .await
            .expect("timeout waiting for WS frame")
            .expect("stream ended")
            .expect("ws error");
        msg.into_text()
            .expect("not a text frame")
            .as_str()
            .to_owned()
    }

    /// Builds a JSON-RPC 2.0 request string.
    fn rpc_request(method: &str, id: &str, params: serde_json::Value) -> String {
        serde_json::json!({
            "jsonrpc": "2.0",
            "method": method,
            "id": id,
            "params": params
        })
        .to_string()
    }

    /// Builds a `SendMessage` request carrying a single text part.
    fn send_message_json(id: &str) -> String {
        rpc_request(
            "SendMessage",
            id,
            serde_json::json!({
                "message": {
                    "messageId": "msg-1",
                    "role": "ROLE_USER",
                    "parts": [{"text": "hello"}]
                }
            }),
        )
    }

    /// Starts a fresh server, sends `frame`, and returns the first text
    /// reply both parsed and raw.
    async fn exchange(frame: WsMessage) -> (serde_json::Value, String) {
        let addr = spawn_ws_server().await;
        let mut ws = ws_connect(addr).await;
        ws.send(frame).await.unwrap();

        let text = read_text(&mut ws).await;
        let v: serde_json::Value = serde_json::from_str(&text).unwrap();
        (v, text)
    }

    /// Sends one JSON-RPC request to a fresh server and returns its reply.
    async fn call(method: &str, id: &str, params: serde_json::Value) -> (serde_json::Value, String) {
        exchange(WsMessage::Text(rpc_request(method, id, params).into())).await
    }

    // ── Integration tests ───────────────────────────────────────────────

    #[tokio::test]
    async fn ws_send_message_success() {
        let (v, text) = exchange(WsMessage::Text(send_message_json("sm-1").into())).await;
        assert_eq!(v["id"], "sm-1");
        assert!(v.get("result").is_some(), "expected result key: {text}");
    }

    #[tokio::test]
    async fn ws_get_task_not_found() {
        let (v, text) = call("GetTask", "gt-1", serde_json::json!({"id": "nonexistent"})).await;
        assert!(v.get("error").is_some(), "expected error: {text}");
    }

    #[tokio::test]
    async fn ws_list_tasks_success() {
        let (v, text) = call("ListTasks", "lt-1", serde_json::json!({})).await;
        assert_eq!(v["id"], "lt-1");
        assert!(v.get("result").is_some(), "expected result: {text}");
    }

    #[tokio::test]
    async fn ws_cancel_task_not_found() {
        let (v, text) = call("CancelTask", "ct-1", serde_json::json!({"id": "nonexistent"})).await;
        assert!(v.get("error").is_some(), "expected error: {text}");
    }

    #[tokio::test]
    async fn ws_subscribe_task_not_found() {
        let (v, text) = call(
            "SubscribeToTask",
            "sub-1",
            serde_json::json!({"id": "nonexistent"}),
        )
        .await;
        assert!(v.get("error").is_some(), "expected error: {text}");
    }

    #[tokio::test]
    async fn ws_unknown_method_error() {
        let (v, text) = call("FooBar", "unk-1", serde_json::json!({})).await;
        assert!(v.get("error").is_some(), "expected error: {text}");

        let msg = v["error"]["message"].as_str().unwrap_or("");
        let mentions_method = msg.to_lowercase().contains("method")
            || msg.to_lowercase().contains("not found")
            || msg.to_lowercase().contains("unsupported");
        assert!(
            mentions_method,
            "error message should mention method not found: {msg}"
        );
    }

    #[tokio::test]
    async fn ws_invalid_json_parse_error() {
        let (v, _text) = exchange(WsMessage::Text("this is not json {{".into())).await;
        assert_eq!(v["error"]["code"], -32700, "expected parse error code");
    }

    #[tokio::test]
    async fn ws_oversized_message_rejected() {
        // One byte over the server's 4 MiB limit.
        let big = "x".repeat(4 * 1024 * 1024 + 1);
        let (v, text) = exchange(WsMessage::Text(big.into())).await;

        assert!(v.get("error").is_some(), "expected error: {text}");
        let msg = v["error"]["message"].as_str().unwrap_or("");
        assert!(
            msg.contains("too large"),
            "error should mention 'too large': {msg}"
        );
    }

    #[tokio::test]
    async fn ws_ping_pong_response() {
        let addr = spawn_ws_server().await;
        let mut ws = ws_connect(addr).await;

        ws.send(WsMessage::Ping(vec![42, 43].into())).await.unwrap();

        // Skip any other frames until the pong arrives.
        let wait_for_pong = async {
            loop {
                if let WsMessage::Pong(data) = ws.next().await.unwrap().unwrap() {
                    return data;
                }
            }
        };
        let pong = tokio::time::timeout(std::time::Duration::from_secs(3), wait_for_pong)
            .await
            .expect("should get pong within 3s");
        assert_eq!(pong, vec![42, 43]);
    }

    #[tokio::test]
    async fn ws_get_task_invalid_params() {
        let (v, text) = call("GetTask", "gti-1", serde_json::json!({"wrong_field": 123})).await;
        assert!(
            v.get("error").is_some(),
            "expected error for bad params: {text}"
        );
    }

    #[tokio::test]
    async fn ws_send_streaming_message_events() {
        let addr = spawn_ws_server().await;
        let mut ws = ws_connect(addr).await;

        let req = rpc_request(
            "SendStreamingMessage",
            "ssm-1",
            serde_json::json!({
                "message": {
                    "messageId": "msg-stream-1",
                    "role": "ROLE_USER",
                    "parts": [{"text": "stream me"}]
                }
            }),
        );
        ws.send(WsMessage::Text(req.into())).await.unwrap();

        // Collect frames up to and including the completion marker.
        let mut frames = Vec::new();
        let collect = async {
            loop {
                let text = ws.next().await.unwrap().unwrap().into_text().unwrap();
                let done = text.contains("stream_complete");
                frames.push(text);
                if done {
                    break;
                }
            }
        };
        tokio::time::timeout(std::time::Duration::from_secs(5), collect)
            .await
            .expect("streaming should complete within 5s");

        assert!(
            frames.len() >= 3,
            "expected >= 3 frames, got {}: {:?}",
            frames.len(),
            frames
        );
        assert!(frames.last().unwrap().contains("stream_complete"));
    }

    #[tokio::test]
    async fn ws_send_message_invalid_params() {
        let (v, text) = call("SendMessage", "smi-1", serde_json::json!({"not_message": true})).await;
        assert!(
            v.get("error").is_some(),
            "expected error for bad send params: {text}"
        );
    }

    #[tokio::test]
    async fn ws_subscribe_invalid_params() {
        let (v, text) = call("SubscribeToTask", "subi-1", serde_json::json!({})).await;
        assert!(
            v.get("error").is_some(),
            "expected error for bad subscribe params: {text}"
        );
    }

    #[tokio::test]
    async fn ws_cancel_task_invalid_params() {
        let (v, text) = call("CancelTask", "cti-1", serde_json::json!({"wrong": 1})).await;
        assert!(v.get("error").is_some(), "expected error: {text}");
    }

    #[tokio::test]
    async fn ws_list_tasks_with_filters() {
        let (v, text) = call(
            "ListTasks",
            "ltf-1",
            serde_json::json!({
                "contextId": "ctx-1",
                "pageSize": 10
            }),
        )
        .await;
        assert_eq!(v["id"], "ltf-1");
        assert!(v.get("result").is_some(), "expected result: {text}");
    }
}