1use super::protocol::*;
4use crate::types::Workflow;
5use crate::workflow::parser;
6use anyhow::Result;
7use chrono::Utc;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
/// Size of the in-memory task table at which `tasks/send` first evicts every
/// task in a terminal state (completed, failed or canceled).
const MAX_TASKS: usize = 1_000;
13
14pub struct A2aServer {
16 tasks: Arc<RwLock<HashMap<String, TaskResponse>>>,
18 config: A2aServerConfig,
20 agent_card_cache: RwLock<Option<AgentCard>>,
22}
23
/// Identity and endpoint settings advertised in the server's agent card.
///
/// All fields are plain strings, so the type is cheaply `Clone` and can be
/// printed with `Debug` for diagnostics.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct A2aServerConfig {
    /// Human-readable agent name.
    pub name: String,
    /// Short description of what the agent does.
    pub description: String,
    /// Base URL of the server; the card's endpoint URL is this plus `/a2a`.
    pub base_url: String,
    /// Version string reported in the agent card.
    pub version: String,
}
35
36impl Default for A2aServerConfig {
37 fn default() -> Self {
38 Self {
39 name: "MUR Commander".into(),
40 description: "Autonomous workflow execution agent".into(),
41 base_url: "http://localhost:3939".into(),
42 version: env!("CARGO_PKG_VERSION").into(),
43 }
44 }
45}
46
47impl A2aServer {
48 pub fn new(config: A2aServerConfig) -> Self {
50 Self {
51 tasks: Arc::new(RwLock::new(HashMap::new())),
52 config,
53 agent_card_cache: RwLock::new(None),
54 }
55 }
56
57 pub fn agent_card(&self) -> AgentCard {
60 {
62 let cache = self.agent_card_cache.read().unwrap_or_else(|e| e.into_inner());
63 if let Some(ref card) = *cache {
64 return card.clone();
65 }
66 }
67
68 let card = self.generate_agent_card();
70 {
71 let mut cache = self.agent_card_cache.write().unwrap_or_else(|e| e.into_inner());
72 *cache = Some(card.clone());
73 }
74 card
75 }
76
77 pub fn invalidate_agent_card(&self) {
80 let mut cache = self.agent_card_cache.write().unwrap_or_else(|e| e.into_inner());
81 *cache = None;
82 }
83
84 fn generate_agent_card(&self) -> AgentCard {
86 let workflows = parser::load_all_workflows().unwrap_or_default();
87 let skills = Self::workflows_to_skills(&workflows);
88
89 AgentCard {
90 name: self.config.name.clone(),
91 description: self.config.description.clone(),
92 url: format!("{}/a2a", self.config.base_url),
93 version: self.config.version.clone(),
94 protocol_version: "0.1".into(),
95 capabilities: AgentCapabilities {
96 streaming: true,
97 push_notifications: false,
98 state_management: true,
99 },
100 skills,
101 authentication: None,
102 }
103 }
104
105 pub fn handle_request(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
107 match request.method.as_str() {
108 methods::TASKS_SEND => self.handle_task_send(request),
109 methods::TASKS_GET => self.handle_task_get(request),
110 methods::TASKS_CANCEL => self.handle_task_cancel(request),
111 _ => JsonRpcResponse::error(
112 request.id.clone(),
113 error_codes::METHOD_NOT_FOUND,
114 &format!("Unknown method: {}", request.method),
115 ),
116 }
117 }
118
119 fn handle_task_send(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
121 let params = match &request.params {
122 Some(p) => p,
123 None => {
124 return JsonRpcResponse::error(
125 request.id.clone(),
126 error_codes::INVALID_PARAMS,
127 "Missing params",
128 );
129 }
130 };
131
132 let task_req: TaskRequest = match serde_json::from_value(params.clone()) {
133 Ok(t) => t,
134 Err(e) => {
135 return JsonRpcResponse::error(
136 request.id.clone(),
137 error_codes::INVALID_PARAMS,
138 &format!("Invalid task request: {}", e),
139 );
140 }
141 };
142
143 let task_response = TaskResponse {
145 id: task_req.id.clone(),
146 status: TaskStatus {
147 state: TaskState::Submitted,
148 message: Some(TaskMessage {
149 role: MessageRole::Agent,
150 parts: vec![MessagePart::Text {
151 text: "Task received and queued for execution".into(),
152 }],
153 }),
154 },
155 artifacts: vec![],
156 history: task_req.messages.clone(),
157 metadata: task_req.metadata.clone(),
158 };
159
160 match self.tasks.write() {
162 Ok(mut tasks) => {
163 if tasks.len() >= MAX_TASKS {
164 Self::evict_terminal_tasks(&mut tasks);
165 }
166 tasks.insert(task_req.id.clone(), task_response.clone());
167 }
168 Err(_) => {
169 return JsonRpcResponse::error(
170 request.id.clone(),
171 error_codes::INTERNAL_ERROR,
172 "Internal error: task store unavailable",
173 );
174 }
175 }
176
177 match serde_json::to_value(&task_response) {
178 Ok(v) => JsonRpcResponse::success(request.id.clone(), v),
179 Err(e) => JsonRpcResponse::error(
180 request.id.clone(),
181 error_codes::INTERNAL_ERROR,
182 &format!("Serialization error: {}", e),
183 ),
184 }
185 }
186
187 fn handle_task_get(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
189 let task_id = request
190 .params
191 .as_ref()
192 .and_then(|p| p.get("id"))
193 .and_then(|v| v.as_str());
194
195 let task_id = match task_id {
196 Some(id) => id,
197 None => {
198 return JsonRpcResponse::error(
199 request.id.clone(),
200 error_codes::INVALID_PARAMS,
201 "Missing task id",
202 );
203 }
204 };
205
206 let tasks = self.tasks.read().unwrap_or_else(|e| e.into_inner());
207 match tasks.get(task_id) {
208 Some(task) => match serde_json::to_value(task) {
209 Ok(v) => JsonRpcResponse::success(request.id.clone(), v),
210 Err(e) => JsonRpcResponse::error(
211 request.id.clone(),
212 error_codes::INTERNAL_ERROR,
213 &format!("Serialization error: {}", e),
214 ),
215 },
216 None => JsonRpcResponse::error(
217 request.id.clone(),
218 error_codes::TASK_NOT_FOUND,
219 &format!("Task '{}' not found", task_id),
220 ),
221 }
222 }
223
224 fn handle_task_cancel(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
226 let task_id = request
227 .params
228 .as_ref()
229 .and_then(|p| p.get("id"))
230 .and_then(|v| v.as_str());
231
232 let task_id = match task_id {
233 Some(id) => id,
234 None => {
235 return JsonRpcResponse::error(
236 request.id.clone(),
237 error_codes::INVALID_PARAMS,
238 "Missing task id",
239 );
240 }
241 };
242
243 let mut tasks = self.tasks.write().unwrap_or_else(|e| e.into_inner());
244 match tasks.get_mut(task_id) {
245 Some(task) => {
246 match task.status.state {
248 TaskState::Submitted | TaskState::Working => {}
249 _ => {
250 return JsonRpcResponse::error(
251 request.id.clone(),
252 error_codes::INVALID_PARAMS,
253 &format!(
254 "Task '{}' cannot be canceled in state {:?}",
255 task_id, task.status.state
256 ),
257 );
258 }
259 }
260
261 task.status = TaskStatus {
262 state: TaskState::Canceled,
263 message: Some(TaskMessage {
264 role: MessageRole::Agent,
265 parts: vec![MessagePart::Text {
266 text: "Task canceled".into(),
267 }],
268 }),
269 };
270
271 let update = TaskStatusUpdate {
272 id: task_id.to_string(),
273 status: task.status.clone(),
274 final_update: true,
275 timestamp: Utc::now(),
276 };
277
278 match serde_json::to_value(&update) {
279 Ok(v) => JsonRpcResponse::success(request.id.clone(), v),
280 Err(e) => JsonRpcResponse::error(
281 request.id.clone(),
282 error_codes::INTERNAL_ERROR,
283 &format!("Serialization error: {}", e),
284 ),
285 }
286 }
287 None => JsonRpcResponse::error(
288 request.id.clone(),
289 error_codes::TASK_NOT_FOUND,
290 &format!("Task '{}' not found", task_id),
291 ),
292 }
293 }
294
295 pub fn update_task_status(
297 &self,
298 task_id: &str,
299 state: TaskState,
300 message: Option<String>,
301 artifacts: Option<Vec<TaskArtifact>>,
302 ) -> Result<()> {
303 let mut tasks = self
304 .tasks
305 .write()
306 .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
307
308 let task = tasks
309 .get_mut(task_id)
310 .ok_or_else(|| anyhow::anyhow!("Task not found: {}", task_id))?;
311
312 task.status = TaskStatus {
313 state,
314 message: message.map(|text| TaskMessage {
315 role: MessageRole::Agent,
316 parts: vec![MessagePart::Text { text }],
317 }),
318 };
319
320 if let Some(arts) = artifacts {
321 task.artifacts = arts;
322 }
323
324 Ok(())
325 }
326
327 pub fn list_tasks(&self) -> Vec<TaskResponse> {
329 let tasks = self.tasks.read().unwrap_or_else(|e| e.into_inner());
330 tasks.values().cloned().collect()
331 }
332
333 fn evict_terminal_tasks(tasks: &mut HashMap<String, TaskResponse>) {
336 let terminal_ids: Vec<String> = tasks
337 .iter()
338 .filter(|(_, t)| matches!(
339 t.status.state,
340 TaskState::Completed | TaskState::Failed | TaskState::Canceled
341 ))
342 .map(|(id, _)| id.clone())
343 .collect();
344
345 for id in terminal_ids {
346 tasks.remove(&id);
347 }
348 }
349
350 fn workflows_to_skills(workflows: &[Workflow]) -> Vec<AgentSkill> {
352 workflows
353 .iter()
354 .map(|w| AgentSkill {
355 id: w.id.clone(),
356 name: w.name.clone(),
357 description: w.description.clone(),
358 tags: vec!["workflow".into()],
359 input_schema: Some(serde_json::json!({
360 "type": "object",
361 "properties": {
362 "variables": {
363 "type": "object",
364 "description": "Workflow variables",
365 "properties": w.variables.iter().map(|(k, v)| {
366 (k.clone(), serde_json::json!({
367 "type": "string",
368 "default": v,
369 }))
370 }).collect::<HashMap<_, _>>(),
371 },
372 "shadow": {
373 "type": "boolean",
374 "default": false,
375 "description": "Run in dry-run mode",
376 },
377 },
378 })),
379 output_schema: None,
380 })
381 .collect()
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    fn test_server() -> A2aServer {
        A2aServer::new(A2aServerConfig::default())
    }

    /// Builds a `tasks/send` request for a task with no messages or metadata.
    fn send_request(task_id: &str, rpc_id: serde_json::Value) -> JsonRpcRequest {
        let task_req = TaskRequest {
            id: task_id.into(),
            skill_id: None,
            messages: vec![],
            metadata: HashMap::new(),
        };
        JsonRpcRequest::new(
            methods::TASKS_SEND,
            Some(serde_json::to_value(&task_req).unwrap()),
            rpc_id,
        )
    }

    /// Builds a request for `method` whose params are just `{"id": task_id}`.
    fn id_request(method: &'static str, task_id: &str) -> JsonRpcRequest {
        JsonRpcRequest::new(
            method,
            Some(serde_json::json!({ "id": task_id })),
            serde_json::json!(2),
        )
    }

    #[test]
    fn test_agent_card_generation() {
        let server = test_server();
        let card = server.agent_card();
        assert_eq!(card.name, "MUR Commander");
        assert!(card.capabilities.streaming);
        assert_eq!(card.protocol_version, "0.1");
        assert_eq!(card.url, "http://localhost:3939/a2a");
    }

    #[test]
    fn test_agent_card_cache_invalidation() {
        let server = test_server();
        let first = server.agent_card();
        let cached = server.agent_card();
        assert_eq!(first.name, cached.name);

        server.invalidate_agent_card();
        let regenerated = server.agent_card();
        assert_eq!(regenerated.name, "MUR Commander");
    }

    #[test]
    fn test_handle_unknown_method() {
        let server = test_server();
        let req = JsonRpcRequest::new("unknown/method", None, serde_json::json!(1));
        let resp = server.handle_request(&req);
        assert!(resp.error.is_some());
        assert_eq!(resp.error.unwrap().code, error_codes::METHOD_NOT_FOUND);
    }

    #[test]
    fn test_handle_task_send() {
        let server = test_server();
        let task_req = TaskRequest {
            id: "test-task-1".into(),
            skill_id: None,
            messages: vec![TaskMessage {
                role: MessageRole::User,
                parts: vec![MessagePart::Text {
                    text: "Hello".into(),
                }],
            }],
            metadata: HashMap::new(),
        };
        let req = JsonRpcRequest::new(
            methods::TASKS_SEND,
            Some(serde_json::to_value(&task_req).unwrap()),
            serde_json::json!(1),
        );
        let resp = server.handle_request(&req);
        assert!(resp.error.is_none());
        assert!(resp.result.is_some());

        let tasks = server.list_tasks();
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].status.state, TaskState::Submitted);
    }

    #[test]
    fn test_handle_task_send_missing_params() {
        let server = test_server();
        let req = JsonRpcRequest::new(methods::TASKS_SEND, None, serde_json::json!(1));
        let resp = server.handle_request(&req);
        assert_eq!(resp.error.unwrap().code, error_codes::INVALID_PARAMS);
        assert!(server.list_tasks().is_empty());
    }

    #[test]
    fn test_handle_task_get() {
        let server = test_server();
        server.handle_request(&send_request("test-task-2", serde_json::json!(1)));

        let resp = server.handle_request(&id_request(methods::TASKS_GET, "test-task-2"));
        assert!(resp.error.is_none());
        assert!(resp.result.is_some());
    }

    #[test]
    fn test_handle_task_get_not_found() {
        let server = test_server();
        let resp = server.handle_request(&id_request(methods::TASKS_GET, "nonexistent"));
        assert!(resp.error.is_some());
        assert_eq!(resp.error.unwrap().code, error_codes::TASK_NOT_FOUND);
    }

    #[test]
    fn test_handle_task_get_missing_id() {
        let server = test_server();
        let req = JsonRpcRequest::new(
            methods::TASKS_GET,
            Some(serde_json::json!({})),
            serde_json::json!(1),
        );
        let resp = server.handle_request(&req);
        assert_eq!(resp.error.unwrap().code, error_codes::INVALID_PARAMS);
    }

    #[test]
    fn test_handle_task_cancel() {
        let server = test_server();
        server.handle_request(&send_request("cancel-me", serde_json::json!(1)));

        let resp = server.handle_request(&id_request(methods::TASKS_CANCEL, "cancel-me"));
        assert!(resp.error.is_none());

        let tasks = server.list_tasks();
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].status.state, TaskState::Canceled);
    }

    #[test]
    fn test_handle_task_cancel_not_found() {
        let server = test_server();
        let resp = server.handle_request(&id_request(methods::TASKS_CANCEL, "missing"));
        assert_eq!(resp.error.unwrap().code, error_codes::TASK_NOT_FOUND);
    }

    #[test]
    fn test_handle_task_cancel_missing_id() {
        let server = test_server();
        let req = JsonRpcRequest::new(methods::TASKS_CANCEL, None, serde_json::json!(1));
        let resp = server.handle_request(&req);
        assert_eq!(resp.error.unwrap().code, error_codes::INVALID_PARAMS);
    }

    #[test]
    fn test_update_task_status() {
        let server = test_server();
        server.handle_request(&send_request("update-me", serde_json::json!(1)));

        server
            .update_task_status(
                "update-me",
                TaskState::Working,
                Some("Processing...".into()),
                None,
            )
            .unwrap();

        let tasks = server.list_tasks();
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].status.state, TaskState::Working);
    }

    #[test]
    fn test_update_task_status_unknown_task() {
        let server = test_server();
        assert!(server
            .update_task_status("ghost", TaskState::Working, None, None)
            .is_err());
    }

    #[test]
    fn test_cancel_already_completed_task() {
        let server = test_server();
        server.handle_request(&send_request("done-task", serde_json::json!(1)));
        server
            .update_task_status("done-task", TaskState::Completed, None, None)
            .unwrap();

        let resp = server.handle_request(&id_request(methods::TASKS_CANCEL, "done-task"));
        assert!(resp.error.is_some());
        assert_eq!(resp.error.unwrap().code, error_codes::INVALID_PARAMS);

        // A rejected cancel must not alter the stored state.
        assert_eq!(server.list_tasks()[0].status.state, TaskState::Completed);
    }

    #[test]
    fn test_task_eviction() {
        let server = test_server();

        for i in 0..MAX_TASKS {
            let id = format!("task-{}", i);
            server.handle_request(&send_request(&id, serde_json::json!(i)));
            server
                .update_task_status(&id, TaskState::Completed, None, None)
                .unwrap();
        }

        assert_eq!(server.list_tasks().len(), MAX_TASKS);

        server.handle_request(&send_request("new-task", serde_json::json!(9999)));

        let tasks = server.list_tasks();
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].id, "new-task");
    }
}