1use crate::store::{InMemoryTaskStore, TaskStore};
4use axum::{
5 extract::State,
6 response::{
7 sse::{Event, Sse},
8 IntoResponse, Json,
9 },
10 routing::{get, post},
11 Router,
12};
13use jamjet_a2a_types::*;
14use std::convert::Infallible;
15use std::sync::Arc;
16use tokio_stream::wrappers::BroadcastStream;
17use tokio_stream::StreamExt;
18use tracing::{debug, error, info};
19
20#[async_trait::async_trait]
29pub trait TaskHandler: Send + Sync {
30 async fn handle(
31 &self,
32 task_id: String,
33 message: Message,
34 store: Arc<dyn TaskStore>,
35 ) -> Result<(), A2aError>;
36}
37
38#[derive(Clone)]
43struct ServerState {
44 card: Arc<AgentCard>,
45 store: Arc<dyn TaskStore>,
46 handler: Arc<Option<Box<dyn TaskHandler>>>,
47}
48
49pub struct A2aServer {
87 card: AgentCard,
88 port: u16,
89 handler: Option<Box<dyn TaskHandler>>,
90 store: Option<Box<dyn TaskStore>>,
91}
92
93impl A2aServer {
94 pub fn new(card: AgentCard) -> Self {
96 Self {
97 card,
98 port: 3000,
99 handler: None,
100 store: None,
101 }
102 }
103
104 pub fn with_port(mut self, port: u16) -> Self {
106 self.port = port;
107 self
108 }
109
110 pub fn with_handler(mut self, handler: impl TaskHandler + 'static) -> Self {
112 self.handler = Some(Box::new(handler));
113 self
114 }
115
116 pub fn with_store(mut self, store: impl TaskStore + 'static) -> Self {
118 self.store = Some(Box::new(store));
119 self
120 }
121
122 pub fn into_router(self) -> Router {
127 let store: Arc<dyn TaskStore> = match self.store {
128 Some(s) => Arc::from(s),
129 None => Arc::new(InMemoryTaskStore::new()),
130 };
131
132 let state = ServerState {
133 card: Arc::new(self.card),
134 store,
135 handler: Arc::new(self.handler),
136 };
137
138 Router::new()
139 .route("/.well-known/agent-card.json", get(agent_card_handler))
140 .route("/.well-known/agent.json", get(agent_card_handler))
141 .route("/", post(jsonrpc_handler))
142 .with_state(state)
143 }
144
145 pub async fn start(self) -> Result<(), A2aError> {
147 let port = self.port;
148 let router = self.into_router();
149 let addr = format!("0.0.0.0:{port}");
150 info!(addr = %addr, "starting A2A server");
151
152 let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
153 A2aTransportError::Connection {
154 url: addr.clone(),
155 source: Box::new(e),
156 }
157 })?;
158
159 axum::serve(listener, router)
160 .await
161 .map_err(|e| A2aTransportError::Connection {
162 url: addr,
163 source: Box::new(e),
164 })?;
165
166 Ok(())
167 }
168}
169
170async fn agent_card_handler(State(state): State<ServerState>) -> impl IntoResponse {
176 let headers = [
177 (axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"),
178 (axum::http::header::CACHE_CONTROL, "no-cache"),
179 ];
180 (
181 headers,
182 Json(serde_json::to_value(state.card.as_ref()).unwrap_or(serde_json::json!({}))),
183 )
184}
185
186#[allow(dead_code)]
189struct IncomingRpc {
190 jsonrpc: String,
191 id: serde_json::Value,
192 method: String,
193 params: serde_json::Value,
194}
195
196async fn jsonrpc_handler(
202 State(state): State<ServerState>,
203 body: axum::body::Bytes,
204) -> axum::response::Response {
205 let body: serde_json::Value = match serde_json::from_slice(&body) {
207 Ok(v) => v,
208 Err(_) => {
209 return Json(serde_json::json!({
210 "jsonrpc": "2.0",
211 "id": null,
212 "error": {"code": -32700, "message": "Parse error"}
213 }))
214 .into_response();
215 }
216 };
217
218 let jsonrpc = body.get("jsonrpc").and_then(|v| v.as_str());
220 if jsonrpc != Some("2.0") {
221 let id = body.get("id").cloned().unwrap_or(serde_json::Value::Null);
222 return Json(serde_json::json!({
223 "jsonrpc": "2.0",
224 "id": id,
225 "error": {"code": -32600, "message": "Invalid Request"}
226 }))
227 .into_response();
228 }
229
230 if let Some(id_val) = body.get("id") {
232 match id_val {
233 serde_json::Value::String(_)
234 | serde_json::Value::Number(_)
235 | serde_json::Value::Null => {} _ => {
237 return Json(serde_json::json!({
238 "jsonrpc": "2.0",
239 "id": null,
240 "error": {"code": -32600, "message": "Invalid Request"}
241 }))
242 .into_response();
243 }
244 }
245 }
246
247 let method = match body.get("method").and_then(|v| v.as_str()) {
249 Some(m) => m.to_string(),
250 None => {
251 let id = body.get("id").cloned().unwrap_or(serde_json::Value::Null);
252 return Json(serde_json::json!({
253 "jsonrpc": "2.0",
254 "id": id,
255 "error": {"code": -32600, "message": "Invalid Request"}
256 }))
257 .into_response();
258 }
259 };
260
261 let rpc_id = body.get("id").cloned().unwrap_or(serde_json::Value::Null);
263 let params = body.get("params").cloned().unwrap_or(serde_json::Value::Null);
264
265 let rpc = IncomingRpc {
266 jsonrpc: "2.0".to_string(),
267 id: rpc_id,
268 method: method.clone(),
269 params,
270 };
271
272 debug!(method = %method, "incoming JSON-RPC request");
273
274 match method.as_str() {
275 "SendMessage" | "message/send" => handle_send_message(state, rpc).await,
277 "GetTask" | "tasks/get" => handle_get_task(state, rpc).await,
278 "ListTasks" => handle_list_tasks(state, rpc).await,
279 "CancelTask" | "tasks/cancel" => handle_cancel_task(state, rpc).await,
280 "SendStreamingMessage" | "message/stream" => handle_send_streaming(state, rpc).await,
281 "SubscribeToTask" | "tasks/resubscribe" => handle_subscribe(state, rpc).await,
282 "GetExtendedAgentCard" => handle_get_extended_card(state, rpc).await,
283 "CreateTaskPushNotificationConfig"
284 | "tasks/pushNotificationConfig/set"
285 | "GetTaskPushNotificationConfig"
286 | "ListTaskPushNotificationConfigs"
287 | "DeleteTaskPushNotificationConfig" => make_error_response(
288 rpc.id,
289 A2aProtocolError::UnsupportedOperation { method: rpc.method },
290 )
291 .into_response(),
292 _ => make_json_rpc_error_response(rpc.id, -32601, "Method not found", None).into_response(),
293 }
294}
295
296async fn handle_send_message(state: ServerState, rpc: IncomingRpc) -> axum::response::Response {
301 let req: SendMessageRequest = match serde_json::from_value(rpc.params) {
302 Ok(r) => r,
303 Err(e) => {
304 return make_json_rpc_error_response(
305 rpc.id,
306 -32602,
307 &format!("Invalid params: {e}"),
308 None,
309 )
310 .into_response();
311 }
312 };
313
314 let message = req.message.clone();
315 let task_id = message
316 .task_id
317 .clone()
318 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
319 let context_id = message.context_id.clone();
320
321 let existing_task = if message.task_id.is_some() {
323 state.store.get(&task_id).await.unwrap_or(None)
324 } else {
325 None
326 };
327
328 if let Some(_existing) = existing_task {
329 if let Err(e) = state.store.append_message(&task_id, message.clone()).await {
331 error!(error = %e, "failed to append message");
332 return make_error_response(rpc.id, e).into_response();
333 }
334
335 if state.handler.is_some() {
337 let handler = Arc::clone(&state.handler);
338 let store = Arc::clone(&state.store);
339 let tid = task_id.clone();
340 let msg = message;
341
342 tokio::spawn(async move {
343 if let Some(ref h) = *handler {
344 if let Err(e) = h.handle(tid.clone(), msg, store.clone()).await {
345 error!(task_id = %tid, error = %e, "handler failed");
346 }
347 }
348 });
349 }
350
351 let resp_task = state.store.get(&task_id).await.unwrap_or(None);
352 return make_success_response(rpc.id, &serde_json::json!({"task": resp_task}))
353 .into_response();
354 }
355
356 let task = Task {
358 id: task_id.clone(),
359 context_id,
360 status: TaskStatus {
361 state: TaskState::Submitted,
362 message: None,
363 timestamp: Some(chrono::Utc::now().to_rfc3339()),
364 },
365 artifacts: vec![],
366 history: Some(vec![message.clone()]),
367 metadata: req.metadata.clone(),
368 };
369
370 if let Err(e) = state.store.insert(task.clone()).await {
371 error!(error = %e, "failed to insert task");
372 return make_error_response(rpc.id, e).into_response();
373 }
374
375 if state.handler.is_some() {
377 let handler = Arc::clone(&state.handler);
378 let store = Arc::clone(&state.store);
379 let tid = task_id.clone();
380 let msg = message;
381
382 tokio::spawn(async move {
383 let working_status = TaskStatus {
385 state: TaskState::Working,
386 message: None,
387 timestamp: Some(chrono::Utc::now().to_rfc3339()),
388 };
389 store.update_status(&tid, working_status).await.ok();
390
391 if let Some(ref h) = *handler {
392 match h.handle(tid.clone(), msg, store.clone()).await {
393 Ok(()) => {}
394 Err(e) => {
395 error!(task_id = %tid, error = %e, "handler failed");
396 let failed_status = TaskStatus {
397 state: TaskState::Failed,
398 message: Some(Message {
399 message_id: uuid::Uuid::new_v4().to_string(),
400 context_id: None,
401 task_id: Some(tid.clone()),
402 role: Role::Agent,
403 parts: vec![Part {
404 content: PartContent::Text(format!("Handler error: {e}")),
405 metadata: None,
406 filename: None,
407 media_type: None,
408 }],
409 metadata: None,
410 extensions: vec![],
411 reference_task_ids: vec![],
412 }),
413 timestamp: Some(chrono::Utc::now().to_rfc3339()),
414 };
415 store.update_status(&tid, failed_status).await.ok();
416 }
417 }
418 }
419 });
420 }
421
422 let resp_task = state.store.get(&task_id).await.unwrap_or(Some(task));
425 make_success_response(rpc.id, &serde_json::json!({"task": resp_task})).into_response()
427}
428
429async fn handle_get_task(state: ServerState, rpc: IncomingRpc) -> axum::response::Response {
430 let req: GetTaskRequest = match serde_json::from_value(rpc.params) {
431 Ok(r) => r,
432 Err(e) => {
433 return make_json_rpc_error_response(
434 rpc.id,
435 -32602,
436 &format!("Invalid params: {e}"),
437 None,
438 )
439 .into_response();
440 }
441 };
442
443 match state.store.get(&req.id).await {
444 Ok(Some(mut task)) => {
445 if let Some(hl) = req.history_length {
447 if hl == 0 {
448 task.history = None;
449 } else if let Some(ref mut hist) = task.history {
450 let hl = hl as usize;
451 if hist.len() > hl {
452 let start = hist.len() - hl;
453 *hist = hist.split_off(start);
454 }
455 }
456 }
457 make_success_response(rpc.id, &task).into_response()
458 }
459 Ok(None) => make_error_response(rpc.id, A2aProtocolError::TaskNotFound { task_id: req.id })
460 .into_response(),
461 Err(e) => make_error_response(rpc.id, e).into_response(),
462 }
463}
464
465async fn handle_list_tasks(state: ServerState, rpc: IncomingRpc) -> axum::response::Response {
466 let req: ListTasksRequest = match serde_json::from_value(rpc.params) {
467 Ok(r) => r,
468 Err(e) => {
469 return make_json_rpc_error_response(
470 rpc.id,
471 -32602,
472 &format!("Invalid params: {e}"),
473 None,
474 )
475 .into_response();
476 }
477 };
478
479 if let Some(ps) = req.page_size {
481 if ps <= 0 || ps > 100 {
482 return make_json_rpc_error_response(
483 rpc.id,
484 -32602,
485 "Invalid params: pageSize must be between 1 and 100",
486 None,
487 )
488 .into_response();
489 }
490 }
491
492 if let Some(hl) = req.history_length {
494 if hl < 0 {
495 return make_json_rpc_error_response(
496 rpc.id,
497 -32602,
498 "Invalid params: historyLength must be >= 0",
499 None,
500 )
501 .into_response();
502 }
503 }
504
505 if let Some(ref token) = req.page_token {
507 if !token.is_empty() {
508 match state.store.get(token).await {
510 Ok(None) => {
511 return make_json_rpc_error_response(
512 rpc.id,
513 -32602,
514 "Invalid params: invalid pageToken",
515 None,
516 )
517 .into_response();
518 }
519 Err(e) => return make_error_response(rpc.id, e).into_response(),
520 Ok(Some(_)) => {} }
522 }
523 }
524
525 if let Some(ref ts) = req.status_timestamp_after {
527 if chrono::DateTime::parse_from_rfc3339(ts).is_err() {
528 return make_json_rpc_error_response(
529 rpc.id,
530 -32602,
531 "Invalid params: statusTimestampAfter must be a valid ISO 8601 timestamp",
532 None,
533 )
534 .into_response();
535 }
536 }
537
538 match state.store.list(&req).await {
539 Ok(resp) => make_success_response(rpc.id, &resp).into_response(),
540 Err(e) => make_error_response(rpc.id, e).into_response(),
541 }
542}
543
544async fn handle_cancel_task(state: ServerState, rpc: IncomingRpc) -> axum::response::Response {
545 let req: CancelTaskRequest = match serde_json::from_value(rpc.params) {
546 Ok(r) => r,
547 Err(e) => {
548 return make_json_rpc_error_response(
549 rpc.id,
550 -32602,
551 &format!("Invalid params: {e}"),
552 None,
553 )
554 .into_response();
555 }
556 };
557
558 let task_id = req.id.clone();
559 match state.store.cancel(&task_id).await {
560 Ok(()) => {
561 match state.store.get(&task_id).await {
563 Ok(Some(task)) => make_success_response(rpc.id, &task).into_response(),
564 Ok(None) => make_error_response(rpc.id, A2aProtocolError::TaskNotFound { task_id })
565 .into_response(),
566 Err(e) => make_error_response(rpc.id, e).into_response(),
567 }
568 }
569 Err(e) => make_error_response(rpc.id, e).into_response(),
570 }
571}
572
573async fn handle_send_streaming(state: ServerState, rpc: IncomingRpc) -> axum::response::Response {
574 let req: SendMessageRequest = match serde_json::from_value(rpc.params) {
575 Ok(r) => r,
576 Err(e) => {
577 return make_json_rpc_error_response(
578 rpc.id,
579 -32602,
580 &format!("Invalid params: {e}"),
581 None,
582 )
583 .into_response();
584 }
585 };
586
587 let message = req.message.clone();
588 let task_id = message
589 .task_id
590 .clone()
591 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
592 let context_id = message.context_id.clone();
593
594 let task = Task {
595 id: task_id.clone(),
596 context_id,
597 status: TaskStatus {
598 state: TaskState::Submitted,
599 message: None,
600 timestamp: Some(chrono::Utc::now().to_rfc3339()),
601 },
602 artifacts: vec![],
603 history: Some(vec![message.clone()]),
604 metadata: req.metadata.clone(),
605 };
606
607 if let Err(e) = state.store.insert(task).await {
608 error!(error = %e, "failed to insert task");
609 return make_error_response(rpc.id, e).into_response();
610 }
611
612 let rx = state.store.subscribe(&task_id).await;
614
615 if state.handler.is_some() {
617 let handler = Arc::clone(&state.handler);
618 let store = Arc::clone(&state.store);
619 let tid = task_id.clone();
620 let msg = message;
621
622 tokio::spawn(async move {
623 let working_status = TaskStatus {
624 state: TaskState::Working,
625 message: None,
626 timestamp: Some(chrono::Utc::now().to_rfc3339()),
627 };
628 store.update_status(&tid, working_status).await.ok();
629
630 if let Some(ref h) = *handler {
631 match h.handle(tid.clone(), msg, store.clone()).await {
632 Ok(()) => {}
633 Err(e) => {
634 error!(task_id = %tid, error = %e, "handler failed");
635 let failed_status = TaskStatus {
636 state: TaskState::Failed,
637 message: Some(Message {
638 message_id: uuid::Uuid::new_v4().to_string(),
639 context_id: None,
640 task_id: Some(tid.clone()),
641 role: Role::Agent,
642 parts: vec![Part {
643 content: PartContent::Text(format!("Handler error: {e}")),
644 metadata: None,
645 filename: None,
646 media_type: None,
647 }],
648 metadata: None,
649 extensions: vec![],
650 reference_task_ids: vec![],
651 }),
652 timestamp: Some(chrono::Utc::now().to_rfc3339()),
653 };
654 store.update_status(&tid, failed_status).await.ok();
655 }
656 }
657 }
658 });
659 }
660
661 match rx {
662 Some(rx) => make_sse_stream(rx).into_response(),
663 None => make_json_rpc_error_response(rpc.id, -32603, "Failed to create event stream", None)
664 .into_response(),
665 }
666}
667
668async fn handle_subscribe(state: ServerState, rpc: IncomingRpc) -> axum::response::Response {
669 let req: SubscribeToTaskRequest = match serde_json::from_value(rpc.params) {
670 Ok(r) => r,
671 Err(e) => {
672 return make_json_rpc_error_response(
673 rpc.id,
674 -32602,
675 &format!("Invalid params: {e}"),
676 None,
677 )
678 .into_response();
679 }
680 };
681
682 match state.store.get(&req.id).await {
684 Ok(None) => {
685 return make_error_response(rpc.id, A2aProtocolError::TaskNotFound { task_id: req.id })
686 .into_response();
687 }
688 Err(e) => return make_error_response(rpc.id, e).into_response(),
689 Ok(Some(_)) => {}
690 }
691
692 match state.store.subscribe(&req.id).await {
693 Some(rx) => make_sse_stream(rx).into_response(),
694 None => {
695 make_json_rpc_error_response(rpc.id, -32603, "Failed to subscribe to task events", None)
696 .into_response()
697 }
698 }
699}
700
701async fn handle_get_extended_card(
702 state: ServerState,
703 rpc: IncomingRpc,
704) -> axum::response::Response {
705 make_success_response(rpc.id, state.card.as_ref()).into_response()
706}
707
708fn make_sse_stream(
713 rx: tokio::sync::broadcast::Receiver<StreamResponse>,
714) -> Sse<impl tokio_stream::Stream<Item = Result<Event, Infallible>>> {
715 let stream = BroadcastStream::new(rx).filter_map(|result| match result {
716 Ok(event) => {
717 let json = serde_json::to_string(&event).unwrap_or_default();
718 Some(Ok(Event::default().data(json)))
719 }
720 Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => {
721 debug!(lagged = n, "SSE client lagged behind");
722 None
723 }
724 });
725 Sse::new(stream)
726}
727
728fn make_success_response<T: serde::Serialize>(
733 id: serde_json::Value,
734 result: &T,
735) -> Json<serde_json::Value> {
736 Json(serde_json::json!({
737 "jsonrpc": "2.0",
738 "id": id,
739 "result": result,
740 }))
741}
742
743fn make_error_response(
744 id: serde_json::Value,
745 error: impl Into<A2aError>,
746) -> Json<serde_json::Value> {
747 let error: A2aError = error.into();
748 match error {
749 A2aError::Protocol(ref proto) => {
750 make_json_rpc_error_response(id, proto.json_rpc_code(), &error.to_string(), None)
751 }
752 _ => make_json_rpc_error_response(id, -32603, &error.to_string(), None),
753 }
754}
755
756fn make_json_rpc_error_response(
757 id: serde_json::Value,
758 code: i32,
759 message: &str,
760 data: Option<serde_json::Value>,
761) -> Json<serde_json::Value> {
762 let mut error = serde_json::json!({
763 "code": code,
764 "message": message,
765 });
766 if let Some(d) = data {
767 error["data"] = d;
768 }
769 Json(serde_json::json!({
770 "jsonrpc": "2.0",
771 "id": id,
772 "error": error,
773 }))
774}