Skip to main content

a2a_rs/adapter/storage/
task_storage.rs

1//! In-memory task storage implementation
2
3// This module is already conditionally compiled with #[cfg(feature = "server")] in mod.rs
4
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use tokio::sync::Mutex; // Changed from std::sync::Mutex
10
11use crate::adapter::business::push_notification::{
12    PushNotificationRegistry, PushNotificationSender,
13};
14
15#[cfg(feature = "http-client")]
16use crate::adapter::business::push_notification::HttpPushNotificationSender;
17#[cfg(not(feature = "http-client"))]
18use crate::adapter::business::push_notification::NoopPushNotificationSender;
19use crate::domain::{
20    A2AError, Artifact, Message, Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig,
21    TaskState, TaskStatus, TaskStatusUpdateEvent,
22};
23use crate::port::{
24    AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager,
25    streaming_handler::Subscriber,
26};
27
28type StatusSubscribers = Vec<Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>>;
29type ArtifactSubscribers = Vec<Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>>;
30
31/// Structure to hold subscribers for a task
32pub(crate) struct TaskSubscribers {
33    status: StatusSubscribers,
34    artifacts: ArtifactSubscribers,
35}
36
37impl TaskSubscribers {
38    fn new() -> Self {
39        Self {
40            status: Vec::new(),
41            artifacts: Vec::new(),
42        }
43    }
44}
45
46/// Simple in-memory task storage for testing and example purposes
47pub struct InMemoryTaskStorage {
48    /// Tasks stored by ID
49    pub(crate) tasks: Arc<Mutex<HashMap<String, Task>>>,
50    /// Subscribers for task updates
51    pub(crate) subscribers: Arc<Mutex<HashMap<String, TaskSubscribers>>>,
52    /// Push notification registry
53    pub(crate) push_notification_registry: Arc<PushNotificationRegistry>,
54}
55
56impl InMemoryTaskStorage {
57    /// Create a new empty task storage
58    pub fn new() -> Self {
59        // Use the appropriate push notification sender based on available features
60        #[cfg(feature = "http-client")]
61        let push_sender = HttpPushNotificationSender::new();
62        #[cfg(not(feature = "http-client"))]
63        let push_sender = NoopPushNotificationSender;
64
65        let push_registry = PushNotificationRegistry::new(push_sender);
66
67        Self {
68            tasks: Arc::new(Mutex::new(HashMap::new())),
69            subscribers: Arc::new(Mutex::new(HashMap::new())),
70            push_notification_registry: Arc::new(push_registry),
71        }
72    }
73
74    /// Create a new task storage with a custom push notification sender
75    pub fn with_push_sender(push_sender: impl PushNotificationSender + 'static) -> Self {
76        let push_registry = PushNotificationRegistry::new(push_sender);
77
78        Self {
79            tasks: Arc::new(Mutex::new(HashMap::new())),
80            subscribers: Arc::new(Mutex::new(HashMap::new())),
81            push_notification_registry: Arc::new(push_registry),
82        }
83    }
84
85    /// Add a status update subscriber for streaming (convenience method)
86    pub async fn add_status_subscriber_legacy(
87        &self,
88        task_id: &str,
89        subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
90    ) -> Result<(), A2AError> {
91        self.add_status_subscriber(task_id, subscriber)
92            .await
93            .map(|_| ())
94    }
95
96    /// Add an artifact update subscriber for streaming (convenience method)
97    pub async fn add_artifact_subscriber_legacy(
98        &self,
99        task_id: &str,
100        subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
101    ) -> Result<(), A2AError> {
102        self.add_artifact_subscriber(task_id, subscriber)
103            .await
104            .map(|_| ())
105    }
106}
107
108impl Default for InMemoryTaskStorage {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl InMemoryTaskStorage {
115    /// Look up the context_id for a task
116    async fn get_task_context_id(&self, task_id: &str) -> String {
117        let tasks_guard = self.tasks.lock().await;
118        tasks_guard
119            .get(task_id)
120            .map(|t| t.context_id.clone())
121            .unwrap_or_else(|| "default".to_string())
122    }
123
124    /// Send a status update to all subscribers for a task
125    pub(crate) async fn broadcast_status_update(
126        &self,
127        task_id: &str,
128        status: TaskStatus,
129        final_: bool,
130    ) -> Result<(), A2AError> {
131        let context_id = self.get_task_context_id(task_id).await;
132
133        // Create the update event
134        let event = TaskStatusUpdateEvent {
135            task_id: task_id.to_string(),
136            context_id,
137            kind: "status-update".to_string(),
138            status: status.clone(),
139            final_,
140            metadata: None,
141        };
142
143        #[cfg(feature = "tracing")]
144        tracing::debug!(
145            task_id = %task_id,
146            state = ?status.state,
147            "📡 Broadcasting status update to subscribers"
148        );
149
150        // Get all subscribers for this task and notify them
151        let subscriber_count = {
152            let subscribers_guard = self.subscribers.lock().await;
153
154            if let Some(task_subscribers) = subscribers_guard.get(task_id) {
155                let count = task_subscribers.status.len();
156                #[cfg(feature = "tracing")]
157                tracing::info!(
158                    task_id = %task_id,
159                    subscriber_count = count,
160                    state = ?status.state,
161                    "📡 Notifying WebSocket subscribers of status update"
162                );
163
164                // Clone the subscribers so we don't hold the lock during notification
165                for (i, subscriber) in task_subscribers.status.iter().enumerate() {
166                    if let Err(e) = subscriber.on_update(event.clone()).await {
167                        #[cfg(feature = "tracing")]
168                        tracing::error!(
169                            task_id = %task_id,
170                            subscriber_index = i,
171                            error = %e,
172                            "❌ Failed to notify subscriber"
173                        );
174                        eprintln!("Failed to notify subscriber: {}", e);
175                    } else {
176                        #[cfg(feature = "tracing")]
177                        tracing::debug!(
178                            task_id = %task_id,
179                            subscriber_index = i,
180                            "✅ Successfully notified subscriber"
181                        );
182                    }
183                }
184                count
185            } else {
186                #[cfg(feature = "tracing")]
187                tracing::warn!(
188                    task_id = %task_id,
189                    "⚠️  No WebSocket subscribers found for task"
190                );
191                0
192            }
193        }; // Lock is dropped here
194
195        #[cfg(feature = "tracing")]
196        tracing::debug!(
197            task_id = %task_id,
198            notified_count = subscriber_count,
199            "📡 Finished broadcasting to WebSocket subscribers"
200        );
201
202        // Send push notification if configured
203        if let Err(e) = self
204            .push_notification_registry
205            .send_status_update(task_id, &event)
206            .await
207        {
208            eprintln!("Failed to send push notification: {}", e);
209        }
210
211        Ok(())
212    }
213
214    /// Send an artifact update to all subscribers for a task
215    pub(crate) async fn broadcast_artifact_update(
216        &self,
217        task_id: &str,
218        artifact: Artifact,
219        _index: Option<u32>,
220        _final: bool,
221    ) -> Result<(), A2AError> {
222        let context_id = self.get_task_context_id(task_id).await;
223
224        // Create the update event
225        let event = TaskArtifactUpdateEvent {
226            task_id: task_id.to_string(),
227            context_id,
228            kind: "artifact-update".to_string(),
229            artifact,
230            append: None,
231            last_chunk: None,
232            metadata: None,
233        };
234
235        // Get all subscribers for this task
236        {
237            let subscribers_guard = self.subscribers.lock().await;
238
239            if let Some(task_subscribers) = subscribers_guard.get(task_id) {
240                // Clone the subscribers so we don't hold the lock during notification
241                for subscriber in task_subscribers.artifacts.iter() {
242                    if let Err(e) = subscriber.on_update(event.clone()).await {
243                        eprintln!("Failed to notify subscriber: {}", e);
244                    }
245                }
246            }
247        }; // Lock is dropped here
248
249        // Send push notification if configured
250        if let Err(e) = self
251            .push_notification_registry
252            .send_artifact_update(task_id, &event)
253            .await
254        {
255            eprintln!("Failed to send push notification: {}", e);
256        }
257
258        Ok(())
259    }
260}
261
262#[async_trait]
263impl AsyncTaskManager for InMemoryTaskStorage {
264    async fn create_task(&self, task_id: &str, context_id: &str) -> Result<Task, A2AError> {
265        let mut tasks_guard = self.tasks.lock().await;
266
267        if tasks_guard.contains_key(task_id) {
268            return Err(A2AError::TaskNotFound(format!(
269                "Task {} already exists",
270                task_id
271            )));
272        }
273
274        let task = Task::new(task_id.to_string(), context_id.to_string());
275        tasks_guard.insert(task_id.to_string(), task.clone());
276
277        Ok(task)
278    }
279
280    async fn update_task_status(
281        &self,
282        task_id: &str,
283        state: TaskState,
284        message: Option<Message>,
285    ) -> Result<Task, A2AError> {
286        let mut tasks_guard = self.tasks.lock().await;
287
288        let task = tasks_guard
289            .get_mut(task_id)
290            .ok_or_else(|| A2AError::TaskNotFound(task_id.to_string()))?;
291
292        // Update the task status with the optional message
293        task.update_status(state, message);
294
295        // Clone status before cloning the entire task to avoid double clone
296        let status_for_broadcast = task.status.clone();
297        let updated_task = task.clone();
298
299        // Release the lock before broadcasting
300        drop(tasks_guard);
301
302        // Broadcast status update
303        self.broadcast_status_update(task_id, status_for_broadcast, false)
304            .await?;
305
306        Ok(updated_task)
307    }
308
309    async fn task_exists(&self, task_id: &str) -> Result<bool, A2AError> {
310        let tasks_guard = self.tasks.lock().await;
311        Ok(tasks_guard.contains_key(task_id))
312    }
313
314    async fn get_task(&self, task_id: &str, history_length: Option<u32>) -> Result<Task, A2AError> {
315        // Get the task
316        let task = {
317            let tasks_guard = self.tasks.lock().await;
318
319            let Some(task) = tasks_guard.get(task_id) else {
320                return Err(A2AError::TaskNotFound(task_id.to_string()));
321            };
322
323            // Apply history length limitation if specified
324            task.with_limited_history(history_length)
325        }; // Lock is dropped here
326
327        Ok(task)
328    }
329
330    async fn cancel_task(&self, task_id: &str) -> Result<Task, A2AError> {
331        // Get and update the task
332        let (task, status_for_broadcast) = {
333            let mut tasks_guard = self.tasks.lock().await;
334
335            let Some(task) = tasks_guard.get(task_id) else {
336                return Err(A2AError::TaskNotFound(task_id.to_string()));
337            };
338
339            let mut updated_task = task.clone();
340
341            // Only working tasks can be canceled
342            if updated_task.status.state != TaskState::Working {
343                return Err(A2AError::TaskNotCancelable(format!(
344                    "Task {} is in state {:?} and cannot be canceled",
345                    task_id, updated_task.status.state
346                )));
347            }
348
349            // Create a cancellation message to add to history
350            let cancel_message = Message {
351                role: crate::domain::Role::Agent,
352                parts: vec![crate::domain::Part::Text {
353                    text: format!("Task {} canceled.", task_id),
354                    metadata: None,
355                }],
356                metadata: None,
357                reference_task_ids: None,
358                message_id: uuid::Uuid::new_v4().to_string(),
359                task_id: Some(task_id.to_string()),
360                context_id: Some(updated_task.context_id.clone()),
361                extensions: None,
362                kind: "message".to_string(),
363            };
364
365            // Update the status with the cancellation message to track in history
366            updated_task.update_status(TaskState::Canceled, Some(cancel_message));
367
368            // Clone status before updating storage to avoid cloning task twice
369            let status_for_broadcast = updated_task.status.clone();
370            tasks_guard.insert(task_id.to_string(), updated_task.clone());
371
372            // Drop guard early and return status for use after broadcasting
373            drop(tasks_guard);
374            (updated_task, status_for_broadcast)
375        }; // Lock is dropped here
376
377        // Broadcast status update (with final flag set to true)
378        self.broadcast_status_update(task_id, status_for_broadcast, true)
379            .await?;
380
381        Ok(task)
382    }
383
384    // ===== v0.3.0 New Methods =====
385
386    async fn list_tasks_v3(
387        &self,
388        params: &crate::domain::ListTasksParams,
389    ) -> Result<crate::domain::ListTasksResult, A2AError> {
390        use crate::domain::ListTasksResult;
391
392        let tasks_guard = self.tasks.lock().await;
393
394        // Filter tasks based on parameters
395        let mut filtered_tasks: Vec<_> = tasks_guard
396            .values()
397            .filter(|task| {
398                // Filter by context_id if provided
399                if let Some(ref context_id) = params.context_id {
400                    if &task.context_id != context_id {
401                        return false;
402                    }
403                }
404
405                // Filter by status if provided
406                if let Some(ref status) = params.status {
407                    if &task.status.state != status {
408                        return false;
409                    }
410                }
411
412                // Filter by lastUpdatedAfter if provided
413                if let Some(last_updated_after) = params.last_updated_after {
414                    if let Some(timestamp) = task.status.timestamp {
415                        let task_time_ms = timestamp.timestamp_millis();
416                        if task_time_ms <= last_updated_after {
417                            return false;
418                        }
419                    }
420                }
421
422                true
423            })
424            .cloned()
425            .collect();
426
427        // Sort by timestamp (most recent first)
428        filtered_tasks.sort_by(|a, b| {
429            let a_time = a
430                .status
431                .timestamp
432                .map(|t| t.timestamp_millis())
433                .unwrap_or(0);
434            let b_time = b
435                .status
436                .timestamp
437                .map(|t| t.timestamp_millis())
438                .unwrap_or(0);
439            b_time.cmp(&a_time)
440        });
441
442        let total_size = filtered_tasks.len() as i32;
443
444        // Handle pagination
445        let page_size = params.page_size.unwrap_or(50).clamp(1, 100) as usize;
446        let page_start = if let Some(ref token) = params.page_token {
447            // Parse page token as a number (simple implementation)
448            token.parse::<usize>().unwrap_or(0)
449        } else {
450            0
451        };
452
453        let page_end = (page_start + page_size).min(filtered_tasks.len());
454        let has_more = page_end < filtered_tasks.len();
455
456        // Get the page of tasks
457        let mut page_tasks: Vec<_> = filtered_tasks[page_start..page_end].to_vec();
458
459        // Apply history length limit
460        let history_length = params.history_length.unwrap_or(0);
461        for task in &mut page_tasks {
462            *task = task.with_limited_history(Some(history_length as u32));
463
464            // Remove artifacts if not requested
465            if !params.include_artifacts.unwrap_or(false) {
466                task.artifacts = None;
467            }
468        }
469
470        // Generate next page token
471        let next_page_token = if has_more {
472            page_end.to_string()
473        } else {
474            String::new()
475        };
476
477        Ok(ListTasksResult {
478            tasks: page_tasks,
479            total_size,
480            page_size: page_size as i32,
481            next_page_token,
482        })
483    }
484
485    async fn get_push_notification_config(
486        &self,
487        params: &crate::domain::GetTaskPushNotificationConfigParams,
488    ) -> Result<crate::domain::TaskPushNotificationConfig, A2AError> {
489        // For in-memory storage, we don't support multiple configs per task yet
490        // Just use the existing get_task_notification method
491        self.get_task_notification(&params.id).await
492    }
493
494    async fn list_push_notification_configs(
495        &self,
496        params: &crate::domain::ListTaskPushNotificationConfigParams,
497    ) -> Result<Vec<crate::domain::TaskPushNotificationConfig>, A2AError> {
498        // For in-memory storage, we only support one config per task
499        // Return it as a single-item vec
500        match self
501            .push_notification_registry
502            .get_config(&params.id)
503            .await?
504        {
505            Some(config) => Ok(vec![crate::domain::TaskPushNotificationConfig {
506                task_id: params.id.clone(),
507                push_notification_config: config,
508            }]),
509            None => Ok(vec![]),
510        }
511    }
512
513    async fn delete_push_notification_config(
514        &self,
515        params: &crate::domain::DeleteTaskPushNotificationConfigParams,
516    ) -> Result<(), A2AError> {
517        // For in-memory storage, just remove the single config
518        // In a full implementation, would need to handle config_id
519        self.remove_task_notification(&params.id).await
520    }
521}
522
523// AsyncNotificationManager implementation
524#[async_trait]
525impl AsyncNotificationManager for InMemoryTaskStorage {
526    async fn set_task_notification(
527        &self,
528        config: &TaskPushNotificationConfig,
529    ) -> Result<TaskPushNotificationConfig, A2AError> {
530        #[cfg(feature = "tracing")]
531        tracing::info!(
532            task_id = %config.task_id,
533            url = %config.push_notification_config.url,
534            "✅ Registering push notification config for task"
535        );
536
537        // Register with the push notification registry
538        self.push_notification_registry
539            .register(&config.task_id, config.push_notification_config.clone())
540            .await?;
541
542        #[cfg(feature = "tracing")]
543        tracing::info!(
544            task_id = %config.task_id,
545            "✅ Push notification config registered successfully"
546        );
547
548        Ok(config.clone())
549    }
550
551    async fn get_task_notification(
552        &self,
553        task_id: &str,
554    ) -> Result<TaskPushNotificationConfig, A2AError> {
555        // Get the push notification config from the registry
556        match self.push_notification_registry.get_config(task_id).await? {
557            Some(config) => Ok(TaskPushNotificationConfig {
558                task_id: task_id.to_string(),
559                push_notification_config: config,
560            }),
561            None => Err(A2AError::PushNotificationNotSupported),
562        }
563    }
564
565    async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError> {
566        self.push_notification_registry.unregister(task_id).await?;
567        Ok(())
568    }
569}
570
571// AsyncStreamingHandler implementation
572#[async_trait]
573impl AsyncStreamingHandler for InMemoryTaskStorage {
574    async fn add_status_subscriber(
575        &self,
576        task_id: &str,
577        subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
578    ) -> Result<String, A2AError> {
579        #[cfg(feature = "tracing")]
580        tracing::info!(
581            task_id = %task_id,
582            "✅ Adding WebSocket subscriber for status updates"
583        );
584
585        // Add the subscriber
586        {
587            let mut subscribers_guard = self.subscribers.lock().await;
588
589            let task_subscribers = subscribers_guard
590                .entry(task_id.to_string())
591                .or_insert_with(TaskSubscribers::new);
592
593            task_subscribers.status.push(subscriber);
594
595            #[cfg(feature = "tracing")]
596            tracing::info!(
597                task_id = %task_id,
598                subscriber_count = task_subscribers.status.len(),
599                "✅ WebSocket subscriber added successfully"
600            );
601        } // Lock is dropped here
602
603        // Try to get the current status to send as an initial update
604        // But don't fail if the task doesn't exist yet - the subscriber will get updates when it's created
605        if let Ok(task) = self.get_task(task_id, None).await {
606            let _ = self
607                .broadcast_status_update(task_id, task.status, false)
608                .await;
609        }
610
611        Ok(format!("status-{}-{}", task_id, uuid::Uuid::new_v4()))
612    }
613
614    async fn add_artifact_subscriber(
615        &self,
616        task_id: &str,
617        subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
618    ) -> Result<String, A2AError> {
619        // Add the subscriber
620        {
621            let mut subscribers_guard = self.subscribers.lock().await;
622
623            let task_subscribers = subscribers_guard
624                .entry(task_id.to_string())
625                .or_insert_with(TaskSubscribers::new);
626
627            task_subscribers.artifacts.push(subscriber);
628        } // Lock is dropped here
629
630        // If there are existing artifacts, broadcast them
631        // But don't fail if the task doesn't exist yet - the subscriber will get updates when it's created
632        if let Ok(task) = self.get_task(task_id, None).await {
633            if let Some(artifacts) = task.artifacts {
634                for artifact in artifacts {
635                    let _ = self
636                        .broadcast_artifact_update(task_id, artifact, None, false)
637                        .await;
638                }
639            }
640        }
641
642        Ok(format!("artifact-{}-{}", task_id, uuid::Uuid::new_v4()))
643    }
644
645    async fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> {
646        Err(A2AError::UnsupportedOperation(
647            "Subscription removal by ID requires storage layer refactoring".to_string(),
648        ))
649    }
650
651    async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
652        // Remove all subscribers
653        {
654            let mut subscribers_guard = self.subscribers.lock().await;
655            subscribers_guard.remove(task_id);
656        } // Lock is dropped here
657
658        Ok(())
659    }
660
661    async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
662        let subscribers_guard = self.subscribers.lock().await;
663
664        if let Some(task_subscribers) = subscribers_guard.get(task_id) {
665            Ok(task_subscribers.status.len() + task_subscribers.artifacts.len())
666        } else {
667            Ok(0)
668        }
669    }
670
671    async fn broadcast_status_update(
672        &self,
673        task_id: &str,
674        update: TaskStatusUpdateEvent,
675    ) -> Result<(), A2AError> {
676        self.broadcast_status_update(task_id, update.status, update.final_)
677            .await
678    }
679
680    async fn broadcast_artifact_update(
681        &self,
682        task_id: &str,
683        update: TaskArtifactUpdateEvent,
684    ) -> Result<(), A2AError> {
685        self.broadcast_artifact_update(
686            task_id,
687            update.artifact,
688            None,
689            update.last_chunk.unwrap_or(false),
690        )
691        .await
692    }
693
694    async fn status_update_stream(
695        &self,
696        _task_id: &str,
697    ) -> Result<
698        std::pin::Pin<
699            Box<dyn futures::Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>,
700        >,
701        A2AError,
702    > {
703        Err(A2AError::UnsupportedOperation(
704            "Status update stream requires storage layer refactoring".to_string(),
705        ))
706    }
707
708    async fn artifact_update_stream(
709        &self,
710        _task_id: &str,
711    ) -> Result<
712        std::pin::Pin<
713            Box<dyn futures::Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>,
714        >,
715        A2AError,
716    > {
717        Err(A2AError::UnsupportedOperation(
718            "Artifact update stream requires storage layer refactoring".to_string(),
719        ))
720    }
721
722    async fn combined_update_stream(
723        &self,
724        _task_id: &str,
725    ) -> Result<
726        std::pin::Pin<
727            Box<
728                dyn futures::Stream<
729                        Item = Result<crate::port::streaming_handler::UpdateEvent, A2AError>,
730                    > + Send,
731            >,
732        >,
733        A2AError,
734    > {
735        Err(A2AError::UnsupportedOperation(
736            "Combined update stream requires storage layer refactoring".to_string(),
737        ))
738    }
739}
740
741impl Clone for InMemoryTaskStorage {
742    fn clone(&self) -> Self {
743        Self {
744            tasks: self.tasks.clone(),
745            subscribers: self.subscribers.clone(),
746            push_notification_registry: self.push_notification_registry.clone(),
747        }
748    }
749}