1use std::collections::HashMap;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use tokio::sync::Mutex; use crate::adapter::business::push_notification::{
12 PushNotificationRegistry, PushNotificationSender,
13};
14
15#[cfg(feature = "http-client")]
16use crate::adapter::business::push_notification::HttpPushNotificationSender;
17#[cfg(not(feature = "http-client"))]
18use crate::adapter::business::push_notification::NoopPushNotificationSender;
19use crate::domain::{
20 A2AError, Artifact, Message, Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig,
21 TaskState, TaskStatus, TaskStatusUpdateEvent,
22};
23use crate::port::{
24 AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager,
25 streaming_handler::Subscriber,
26};
27
/// Status-update subscribers registered for a single task.
type StatusSubscribers = Vec<Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>>;
/// Artifact-update subscribers registered for a single task.
type ArtifactSubscribers = Vec<Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>>;
30
31pub(crate) struct TaskSubscribers {
33 status: StatusSubscribers,
34 artifacts: ArtifactSubscribers,
35}
36
37impl TaskSubscribers {
38 fn new() -> Self {
39 Self {
40 status: Vec::new(),
41 artifacts: Vec::new(),
42 }
43 }
44}
45
/// In-memory task store that also fans task updates out to streaming
/// subscribers and to a push-notification registry.
///
/// Every field is behind an `Arc`, so clones of this storage share the same
/// tasks, subscribers, and push-notification registry.
pub struct InMemoryTaskStorage {
    /// Stored tasks, keyed by task id.
    pub(crate) tasks: Arc<Mutex<HashMap<String, Task>>>,
    /// Streaming subscribers, keyed by task id.
    pub(crate) subscribers: Arc<Mutex<HashMap<String, TaskSubscribers>>>,
    /// Registry holding per-task push-notification configs; used to send
    /// status and artifact updates as push notifications.
    pub(crate) push_notification_registry: Arc<PushNotificationRegistry>,
}
55
56impl InMemoryTaskStorage {
57 pub fn new() -> Self {
59 #[cfg(feature = "http-client")]
61 let push_sender = HttpPushNotificationSender::new();
62 #[cfg(not(feature = "http-client"))]
63 let push_sender = NoopPushNotificationSender;
64
65 let push_registry = PushNotificationRegistry::new(push_sender);
66
67 Self {
68 tasks: Arc::new(Mutex::new(HashMap::new())),
69 subscribers: Arc::new(Mutex::new(HashMap::new())),
70 push_notification_registry: Arc::new(push_registry),
71 }
72 }
73
74 pub fn with_push_sender(push_sender: impl PushNotificationSender + 'static) -> Self {
76 let push_registry = PushNotificationRegistry::new(push_sender);
77
78 Self {
79 tasks: Arc::new(Mutex::new(HashMap::new())),
80 subscribers: Arc::new(Mutex::new(HashMap::new())),
81 push_notification_registry: Arc::new(push_registry),
82 }
83 }
84
85 pub async fn add_status_subscriber_legacy(
87 &self,
88 task_id: &str,
89 subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
90 ) -> Result<(), A2AError> {
91 self.add_status_subscriber(task_id, subscriber)
92 .await
93 .map(|_| ())
94 }
95
96 pub async fn add_artifact_subscriber_legacy(
98 &self,
99 task_id: &str,
100 subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
101 ) -> Result<(), A2AError> {
102 self.add_artifact_subscriber(task_id, subscriber)
103 .await
104 .map(|_| ())
105 }
106}
107
108impl Default for InMemoryTaskStorage {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl InMemoryTaskStorage {
115 async fn get_task_context_id(&self, task_id: &str) -> String {
117 let tasks_guard = self.tasks.lock().await;
118 tasks_guard
119 .get(task_id)
120 .map(|t| t.context_id.clone())
121 .unwrap_or_else(|| "default".to_string())
122 }
123
124 pub(crate) async fn broadcast_status_update(
126 &self,
127 task_id: &str,
128 status: TaskStatus,
129 final_: bool,
130 ) -> Result<(), A2AError> {
131 let context_id = self.get_task_context_id(task_id).await;
132
133 let event = TaskStatusUpdateEvent {
135 task_id: task_id.to_string(),
136 context_id,
137 kind: "status-update".to_string(),
138 status: status.clone(),
139 final_,
140 metadata: None,
141 };
142
143 #[cfg(feature = "tracing")]
144 tracing::debug!(
145 task_id = %task_id,
146 state = ?status.state,
147 "📡 Broadcasting status update to subscribers"
148 );
149
150 let subscriber_count = {
152 let subscribers_guard = self.subscribers.lock().await;
153
154 if let Some(task_subscribers) = subscribers_guard.get(task_id) {
155 let count = task_subscribers.status.len();
156 #[cfg(feature = "tracing")]
157 tracing::info!(
158 task_id = %task_id,
159 subscriber_count = count,
160 state = ?status.state,
161 "📡 Notifying WebSocket subscribers of status update"
162 );
163
164 for (i, subscriber) in task_subscribers.status.iter().enumerate() {
166 if let Err(e) = subscriber.on_update(event.clone()).await {
167 #[cfg(feature = "tracing")]
168 tracing::error!(
169 task_id = %task_id,
170 subscriber_index = i,
171 error = %e,
172 "❌ Failed to notify subscriber"
173 );
174 eprintln!("Failed to notify subscriber: {}", e);
175 } else {
176 #[cfg(feature = "tracing")]
177 tracing::debug!(
178 task_id = %task_id,
179 subscriber_index = i,
180 "✅ Successfully notified subscriber"
181 );
182 }
183 }
184 count
185 } else {
186 #[cfg(feature = "tracing")]
187 tracing::warn!(
188 task_id = %task_id,
189 "⚠️ No WebSocket subscribers found for task"
190 );
191 0
192 }
193 }; #[cfg(feature = "tracing")]
196 tracing::debug!(
197 task_id = %task_id,
198 notified_count = subscriber_count,
199 "📡 Finished broadcasting to WebSocket subscribers"
200 );
201
202 if let Err(e) = self
204 .push_notification_registry
205 .send_status_update(task_id, &event)
206 .await
207 {
208 eprintln!("Failed to send push notification: {}", e);
209 }
210
211 Ok(())
212 }
213
214 pub(crate) async fn broadcast_artifact_update(
216 &self,
217 task_id: &str,
218 artifact: Artifact,
219 _index: Option<u32>,
220 _final: bool,
221 ) -> Result<(), A2AError> {
222 let context_id = self.get_task_context_id(task_id).await;
223
224 let event = TaskArtifactUpdateEvent {
226 task_id: task_id.to_string(),
227 context_id,
228 kind: "artifact-update".to_string(),
229 artifact,
230 append: None,
231 last_chunk: None,
232 metadata: None,
233 };
234
235 {
237 let subscribers_guard = self.subscribers.lock().await;
238
239 if let Some(task_subscribers) = subscribers_guard.get(task_id) {
240 for subscriber in task_subscribers.artifacts.iter() {
242 if let Err(e) = subscriber.on_update(event.clone()).await {
243 eprintln!("Failed to notify subscriber: {}", e);
244 }
245 }
246 }
247 }; if let Err(e) = self
251 .push_notification_registry
252 .send_artifact_update(task_id, &event)
253 .await
254 {
255 eprintln!("Failed to send push notification: {}", e);
256 }
257
258 Ok(())
259 }
260}
261
262#[async_trait]
263impl AsyncTaskManager for InMemoryTaskStorage {
264 async fn create_task(&self, task_id: &str, context_id: &str) -> Result<Task, A2AError> {
265 let mut tasks_guard = self.tasks.lock().await;
266
267 if tasks_guard.contains_key(task_id) {
268 return Err(A2AError::TaskNotFound(format!(
269 "Task {} already exists",
270 task_id
271 )));
272 }
273
274 let task = Task::new(task_id.to_string(), context_id.to_string());
275 tasks_guard.insert(task_id.to_string(), task.clone());
276
277 Ok(task)
278 }
279
280 async fn update_task_status(
281 &self,
282 task_id: &str,
283 state: TaskState,
284 message: Option<Message>,
285 ) -> Result<Task, A2AError> {
286 let mut tasks_guard = self.tasks.lock().await;
287
288 let task = tasks_guard
289 .get_mut(task_id)
290 .ok_or_else(|| A2AError::TaskNotFound(task_id.to_string()))?;
291
292 task.update_status(state, message);
294
295 let status_for_broadcast = task.status.clone();
297 let updated_task = task.clone();
298
299 drop(tasks_guard);
301
302 self.broadcast_status_update(task_id, status_for_broadcast, false)
304 .await?;
305
306 Ok(updated_task)
307 }
308
309 async fn task_exists(&self, task_id: &str) -> Result<bool, A2AError> {
310 let tasks_guard = self.tasks.lock().await;
311 Ok(tasks_guard.contains_key(task_id))
312 }
313
314 async fn get_task(&self, task_id: &str, history_length: Option<u32>) -> Result<Task, A2AError> {
315 let task = {
317 let tasks_guard = self.tasks.lock().await;
318
319 let Some(task) = tasks_guard.get(task_id) else {
320 return Err(A2AError::TaskNotFound(task_id.to_string()));
321 };
322
323 task.with_limited_history(history_length)
325 }; Ok(task)
328 }
329
330 async fn cancel_task(&self, task_id: &str) -> Result<Task, A2AError> {
331 let (task, status_for_broadcast) = {
333 let mut tasks_guard = self.tasks.lock().await;
334
335 let Some(task) = tasks_guard.get(task_id) else {
336 return Err(A2AError::TaskNotFound(task_id.to_string()));
337 };
338
339 let mut updated_task = task.clone();
340
341 if updated_task.status.state != TaskState::Working {
343 return Err(A2AError::TaskNotCancelable(format!(
344 "Task {} is in state {:?} and cannot be canceled",
345 task_id, updated_task.status.state
346 )));
347 }
348
349 let cancel_message = Message {
351 role: crate::domain::Role::Agent,
352 parts: vec![crate::domain::Part::Text {
353 text: format!("Task {} canceled.", task_id),
354 metadata: None,
355 }],
356 metadata: None,
357 reference_task_ids: None,
358 message_id: uuid::Uuid::new_v4().to_string(),
359 task_id: Some(task_id.to_string()),
360 context_id: Some(updated_task.context_id.clone()),
361 extensions: None,
362 kind: "message".to_string(),
363 };
364
365 updated_task.update_status(TaskState::Canceled, Some(cancel_message));
367
368 let status_for_broadcast = updated_task.status.clone();
370 tasks_guard.insert(task_id.to_string(), updated_task.clone());
371
372 drop(tasks_guard);
374 (updated_task, status_for_broadcast)
375 }; self.broadcast_status_update(task_id, status_for_broadcast, true)
379 .await?;
380
381 Ok(task)
382 }
383
384 async fn list_tasks_v3(
387 &self,
388 params: &crate::domain::ListTasksParams,
389 ) -> Result<crate::domain::ListTasksResult, A2AError> {
390 use crate::domain::ListTasksResult;
391
392 let tasks_guard = self.tasks.lock().await;
393
394 let mut filtered_tasks: Vec<_> = tasks_guard
396 .values()
397 .filter(|task| {
398 if let Some(ref context_id) = params.context_id {
400 if &task.context_id != context_id {
401 return false;
402 }
403 }
404
405 if let Some(ref status) = params.status {
407 if &task.status.state != status {
408 return false;
409 }
410 }
411
412 if let Some(last_updated_after) = params.last_updated_after {
414 if let Some(timestamp) = task.status.timestamp {
415 let task_time_ms = timestamp.timestamp_millis();
416 if task_time_ms <= last_updated_after {
417 return false;
418 }
419 }
420 }
421
422 true
423 })
424 .cloned()
425 .collect();
426
427 filtered_tasks.sort_by(|a, b| {
429 let a_time = a
430 .status
431 .timestamp
432 .map(|t| t.timestamp_millis())
433 .unwrap_or(0);
434 let b_time = b
435 .status
436 .timestamp
437 .map(|t| t.timestamp_millis())
438 .unwrap_or(0);
439 b_time.cmp(&a_time)
440 });
441
442 let total_size = filtered_tasks.len() as i32;
443
444 let page_size = params.page_size.unwrap_or(50).clamp(1, 100) as usize;
446 let page_start = if let Some(ref token) = params.page_token {
447 token.parse::<usize>().unwrap_or(0)
449 } else {
450 0
451 };
452
453 let page_end = (page_start + page_size).min(filtered_tasks.len());
454 let has_more = page_end < filtered_tasks.len();
455
456 let mut page_tasks: Vec<_> = filtered_tasks[page_start..page_end].to_vec();
458
459 let history_length = params.history_length.unwrap_or(0);
461 for task in &mut page_tasks {
462 *task = task.with_limited_history(Some(history_length as u32));
463
464 if !params.include_artifacts.unwrap_or(false) {
466 task.artifacts = None;
467 }
468 }
469
470 let next_page_token = if has_more {
472 page_end.to_string()
473 } else {
474 String::new()
475 };
476
477 Ok(ListTasksResult {
478 tasks: page_tasks,
479 total_size,
480 page_size: page_size as i32,
481 next_page_token,
482 })
483 }
484
485 async fn get_push_notification_config(
486 &self,
487 params: &crate::domain::GetTaskPushNotificationConfigParams,
488 ) -> Result<crate::domain::TaskPushNotificationConfig, A2AError> {
489 self.get_task_notification(¶ms.id).await
492 }
493
494 async fn list_push_notification_configs(
495 &self,
496 params: &crate::domain::ListTaskPushNotificationConfigParams,
497 ) -> Result<Vec<crate::domain::TaskPushNotificationConfig>, A2AError> {
498 match self
501 .push_notification_registry
502 .get_config(¶ms.id)
503 .await?
504 {
505 Some(config) => Ok(vec![crate::domain::TaskPushNotificationConfig {
506 task_id: params.id.clone(),
507 push_notification_config: config,
508 }]),
509 None => Ok(vec![]),
510 }
511 }
512
513 async fn delete_push_notification_config(
514 &self,
515 params: &crate::domain::DeleteTaskPushNotificationConfigParams,
516 ) -> Result<(), A2AError> {
517 self.remove_task_notification(¶ms.id).await
520 }
521}
522
523#[async_trait]
525impl AsyncNotificationManager for InMemoryTaskStorage {
526 async fn set_task_notification(
527 &self,
528 config: &TaskPushNotificationConfig,
529 ) -> Result<TaskPushNotificationConfig, A2AError> {
530 #[cfg(feature = "tracing")]
531 tracing::info!(
532 task_id = %config.task_id,
533 url = %config.push_notification_config.url,
534 "✅ Registering push notification config for task"
535 );
536
537 self.push_notification_registry
539 .register(&config.task_id, config.push_notification_config.clone())
540 .await?;
541
542 #[cfg(feature = "tracing")]
543 tracing::info!(
544 task_id = %config.task_id,
545 "✅ Push notification config registered successfully"
546 );
547
548 Ok(config.clone())
549 }
550
551 async fn get_task_notification(
552 &self,
553 task_id: &str,
554 ) -> Result<TaskPushNotificationConfig, A2AError> {
555 match self.push_notification_registry.get_config(task_id).await? {
557 Some(config) => Ok(TaskPushNotificationConfig {
558 task_id: task_id.to_string(),
559 push_notification_config: config,
560 }),
561 None => Err(A2AError::PushNotificationNotSupported),
562 }
563 }
564
565 async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError> {
566 self.push_notification_registry.unregister(task_id).await?;
567 Ok(())
568 }
569}
570
571#[async_trait]
573impl AsyncStreamingHandler for InMemoryTaskStorage {
574 async fn add_status_subscriber(
575 &self,
576 task_id: &str,
577 subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
578 ) -> Result<String, A2AError> {
579 #[cfg(feature = "tracing")]
580 tracing::info!(
581 task_id = %task_id,
582 "✅ Adding WebSocket subscriber for status updates"
583 );
584
585 {
587 let mut subscribers_guard = self.subscribers.lock().await;
588
589 let task_subscribers = subscribers_guard
590 .entry(task_id.to_string())
591 .or_insert_with(TaskSubscribers::new);
592
593 task_subscribers.status.push(subscriber);
594
595 #[cfg(feature = "tracing")]
596 tracing::info!(
597 task_id = %task_id,
598 subscriber_count = task_subscribers.status.len(),
599 "✅ WebSocket subscriber added successfully"
600 );
601 } if let Ok(task) = self.get_task(task_id, None).await {
606 let _ = self
607 .broadcast_status_update(task_id, task.status, false)
608 .await;
609 }
610
611 Ok(format!("status-{}-{}", task_id, uuid::Uuid::new_v4()))
612 }
613
614 async fn add_artifact_subscriber(
615 &self,
616 task_id: &str,
617 subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
618 ) -> Result<String, A2AError> {
619 {
621 let mut subscribers_guard = self.subscribers.lock().await;
622
623 let task_subscribers = subscribers_guard
624 .entry(task_id.to_string())
625 .or_insert_with(TaskSubscribers::new);
626
627 task_subscribers.artifacts.push(subscriber);
628 } if let Ok(task) = self.get_task(task_id, None).await {
633 if let Some(artifacts) = task.artifacts {
634 for artifact in artifacts {
635 let _ = self
636 .broadcast_artifact_update(task_id, artifact, None, false)
637 .await;
638 }
639 }
640 }
641
642 Ok(format!("artifact-{}-{}", task_id, uuid::Uuid::new_v4()))
643 }
644
645 async fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> {
646 Err(A2AError::UnsupportedOperation(
647 "Subscription removal by ID requires storage layer refactoring".to_string(),
648 ))
649 }
650
651 async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
652 {
654 let mut subscribers_guard = self.subscribers.lock().await;
655 subscribers_guard.remove(task_id);
656 } Ok(())
659 }
660
661 async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
662 let subscribers_guard = self.subscribers.lock().await;
663
664 if let Some(task_subscribers) = subscribers_guard.get(task_id) {
665 Ok(task_subscribers.status.len() + task_subscribers.artifacts.len())
666 } else {
667 Ok(0)
668 }
669 }
670
671 async fn broadcast_status_update(
672 &self,
673 task_id: &str,
674 update: TaskStatusUpdateEvent,
675 ) -> Result<(), A2AError> {
676 self.broadcast_status_update(task_id, update.status, update.final_)
677 .await
678 }
679
680 async fn broadcast_artifact_update(
681 &self,
682 task_id: &str,
683 update: TaskArtifactUpdateEvent,
684 ) -> Result<(), A2AError> {
685 self.broadcast_artifact_update(
686 task_id,
687 update.artifact,
688 None,
689 update.last_chunk.unwrap_or(false),
690 )
691 .await
692 }
693
694 async fn status_update_stream(
695 &self,
696 _task_id: &str,
697 ) -> Result<
698 std::pin::Pin<
699 Box<dyn futures::Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>,
700 >,
701 A2AError,
702 > {
703 Err(A2AError::UnsupportedOperation(
704 "Status update stream requires storage layer refactoring".to_string(),
705 ))
706 }
707
708 async fn artifact_update_stream(
709 &self,
710 _task_id: &str,
711 ) -> Result<
712 std::pin::Pin<
713 Box<dyn futures::Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>,
714 >,
715 A2AError,
716 > {
717 Err(A2AError::UnsupportedOperation(
718 "Artifact update stream requires storage layer refactoring".to_string(),
719 ))
720 }
721
722 async fn combined_update_stream(
723 &self,
724 _task_id: &str,
725 ) -> Result<
726 std::pin::Pin<
727 Box<
728 dyn futures::Stream<
729 Item = Result<crate::port::streaming_handler::UpdateEvent, A2AError>,
730 > + Send,
731 >,
732 >,
733 A2AError,
734 > {
735 Err(A2AError::UnsupportedOperation(
736 "Combined update stream requires storage layer refactoring".to_string(),
737 ))
738 }
739}
740
741impl Clone for InMemoryTaskStorage {
742 fn clone(&self) -> Self {
743 Self {
744 tasks: self.tasks.clone(),
745 subscribers: self.subscribers.clone(),
746 push_notification_registry: self.push_notification_registry.clone(),
747 }
748 }
749}