Skip to main content

a2a_rs/adapter/storage/
task_storage.rs

1//! In-memory task storage implementation
2
3// This module is already conditionally compiled with #[cfg(feature = "server")] in mod.rs
4
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use tokio::sync::Mutex; // Changed from std::sync::Mutex
10
11use crate::adapter::business::push_notification::{
12    PushNotificationRegistry, PushNotificationSender,
13};
14
15#[cfg(feature = "http-client")]
16use crate::adapter::business::push_notification::HttpPushNotificationSender;
17#[cfg(not(feature = "http-client"))]
18use crate::adapter::business::push_notification::NoopPushNotificationSender;
19use crate::domain::{
20    A2AError, ContextId, Message, Task, TaskId, TaskPushNotificationConfig, TaskState,
21    VersionedTask,
22};
23use crate::port::{
24    AsyncNotificationManager, AsyncPushNotifier, AsyncTaskLifecycle, AsyncTaskQuery,
25    AsyncTaskVersioning,
26};
27
28/// Simple in-memory task storage for testing and example purposes.
29///
30/// Persistence-only: streaming fan-out lives in
31/// [`InMemoryStreamingHandler`](crate::adapter::InMemoryStreamingHandler) and
32/// push-webhook delivery behind the [`AsyncPushNotifier`] port (this struct hands
33/// out its registry via [`push_notifier`](Self::push_notifier)). The store still
34/// owns push-config CRUD ([`AsyncNotificationManager`]) because that is config
35/// *persistence*.
36pub struct InMemoryTaskStorage {
37    /// Tasks stored by ID
38    pub(crate) tasks: Arc<Mutex<HashMap<String, Task>>>,
39    /// Per-task optimistic-concurrency version, bumped on every mutation.
40    ///
41    /// A separate map keyed by the same task id. Mutators always lock `tasks`
42    /// first and `versions` second, so the two stay consistent and never
43    /// deadlock (see [`AsyncTaskVersioning`]).
44    pub(crate) versions: Arc<Mutex<HashMap<String, u64>>>,
45    /// Push notification registry (config store + delivery backend)
46    pub(crate) push_notification_registry: Arc<PushNotificationRegistry>,
47}
48
49impl InMemoryTaskStorage {
50    /// Create a new empty task storage
51    pub fn new() -> Self {
52        // Use the appropriate push notification sender based on available features
53        #[cfg(feature = "http-client")]
54        let push_sender = HttpPushNotificationSender::new();
55        #[cfg(not(feature = "http-client"))]
56        let push_sender = NoopPushNotificationSender;
57
58        let push_registry = PushNotificationRegistry::new(push_sender);
59
60        Self {
61            tasks: Arc::new(Mutex::new(HashMap::new())),
62            versions: Arc::new(Mutex::new(HashMap::new())),
63            push_notification_registry: Arc::new(push_registry),
64        }
65    }
66
67    /// Create a new task storage with a custom push notification sender
68    pub fn with_push_sender(push_sender: impl PushNotificationSender + 'static) -> Self {
69        let push_registry = PushNotificationRegistry::new(push_sender);
70
71        Self {
72            tasks: Arc::new(Mutex::new(HashMap::new())),
73            versions: Arc::new(Mutex::new(HashMap::new())),
74            push_notification_registry: Arc::new(push_registry),
75        }
76    }
77
78    /// Bump (or initialize) the stored version for a task, returning the new
79    /// value. Callers already hold the `tasks` lock; this acquires `versions`
80    /// second, preserving the global lock order.
81    async fn bump_version(&self, task_id: &str) -> u64 {
82        let mut versions = self.versions.lock().await;
83        let v = versions.entry(task_id.to_string()).or_insert(0);
84        *v += 1;
85        *v
86    }
87
88    /// Hand out this store's push-notification registry as an
89    /// [`AsyncPushNotifier`].
90    ///
91    /// The returned notifier shares the same config registry the store writes to
92    /// via [`AsyncNotificationManager::set_config`], so a config registered on
93    /// the store is immediately visible to the notifier at the composition edge.
94    pub fn push_notifier(&self) -> Arc<dyn AsyncPushNotifier> {
95        self.push_notification_registry.clone()
96    }
97}
98
99impl Default for InMemoryTaskStorage {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105#[async_trait]
106impl AsyncTaskLifecycle for InMemoryTaskStorage {
107    async fn create(&self, id: &TaskId, context_id: &ContextId) -> Result<Task, A2AError> {
108        let task_id = id.as_str();
109        let context_id = context_id.as_str();
110        let mut tasks_guard = self.tasks.lock().await;
111
112        if tasks_guard.contains_key(task_id) {
113            return Err(A2AError::TaskNotFound(format!(
114                "Task {} already exists",
115                task_id
116            )));
117        }
118
119        let task = Task::new(task_id.to_string(), context_id.to_string());
120        tasks_guard.insert(task_id.to_string(), task.clone());
121        self.bump_version(task_id).await; // version 0 -> 1
122
123        Ok(task)
124    }
125
126    async fn update_status(
127        &self,
128        id: &TaskId,
129        state: TaskState,
130        message: Option<Message>,
131    ) -> Result<Task, A2AError> {
132        let task_id = id.as_str();
133        let mut tasks_guard = self.tasks.lock().await;
134
135        let task = tasks_guard
136            .get_mut(task_id)
137            .ok_or_else(|| A2AError::TaskNotFound(task_id.to_string()))?;
138
139        // Update the task status with the optional message
140        task.update_status(state, message);
141        let updated = task.clone();
142        self.bump_version(task_id).await;
143
144        // Persistence only: announcing the change to streaming subscribers is
145        // the orchestration layer's job (see `TaskStatusBroadcast`), not a side
146        // effect of the mutator.
147        Ok(updated)
148    }
149
150    async fn exists(&self, id: &TaskId) -> Result<bool, A2AError> {
151        let task_id = id.as_str();
152        let tasks_guard = self.tasks.lock().await;
153        Ok(tasks_guard.contains_key(task_id))
154    }
155
156    async fn get(&self, id: &TaskId, history_length: Option<u32>) -> Result<Task, A2AError> {
157        let task_id = id.as_str();
158        // Get the task
159        let task = {
160            let tasks_guard = self.tasks.lock().await;
161
162            let Some(task) = tasks_guard.get(task_id) else {
163                return Err(A2AError::TaskNotFound(task_id.to_string()));
164            };
165
166            // Apply history length limitation if specified
167            task.with_limited_history(history_length)
168        }; // Lock is dropped here
169
170        Ok(task)
171    }
172
173    async fn cancel(&self, id: &TaskId) -> Result<Task, A2AError> {
174        let task_id = id.as_str();
175        let mut tasks_guard = self.tasks.lock().await;
176
177        let Some(task) = tasks_guard.get(task_id) else {
178            return Err(A2AError::TaskNotFound(task_id.to_string()));
179        };
180
181        let mut updated_task = task.clone();
182
183        // Only working tasks can be canceled
184        if updated_task.status.state != TaskState::Working {
185            return Err(A2AError::TaskNotCancelable(format!(
186                "Task {} is in state {:?} and cannot be canceled",
187                task_id, updated_task.status.state
188            )));
189        }
190
191        // Create a cancellation message to add to history
192        let cancel_message = Message {
193            role: ::buffa::EnumValue::from(crate::domain::Role::Agent),
194            parts: vec![crate::domain::Part::text(format!(
195                "Task {} canceled.",
196                task_id
197            ))],
198            message_id: uuid::Uuid::new_v4().to_string(),
199            task_id: task_id.to_string(),
200            context_id: updated_task.context_id.clone(),
201            ..Default::default()
202        };
203
204        // Update the status with the cancellation message to track in history
205        updated_task.update_status(TaskState::Canceled, Some(cancel_message));
206        tasks_guard.insert(task_id.to_string(), updated_task.clone());
207        self.bump_version(task_id).await;
208
209        // Persistence only: the orchestration layer announces the cancellation
210        // to streaming subscribers (see `TaskStatusBroadcast`).
211        Ok(updated_task)
212    }
213}
214
215#[async_trait]
216impl AsyncTaskVersioning for InMemoryTaskStorage {
217    async fn version(&self, id: &TaskId) -> Result<u64, A2AError> {
218        let task_id = id.as_str();
219        let tasks_guard = self.tasks.lock().await;
220        if !tasks_guard.contains_key(task_id) {
221            return Err(A2AError::TaskNotFound(task_id.to_string()));
222        }
223        let versions = self.versions.lock().await;
224        Ok(versions.get(task_id).copied().unwrap_or(0))
225    }
226
227    async fn get_versioned(
228        &self,
229        id: &TaskId,
230        history_length: Option<u32>,
231    ) -> Result<VersionedTask, A2AError> {
232        let task_id = id.as_str();
233        let tasks_guard = self.tasks.lock().await;
234        let Some(task) = tasks_guard.get(task_id) else {
235            return Err(A2AError::TaskNotFound(task_id.to_string()));
236        };
237        let task = task.with_limited_history(history_length);
238        let versions = self.versions.lock().await;
239        let version = versions.get(task_id).copied().unwrap_or(0);
240        Ok(VersionedTask::new(task, version))
241    }
242
243    async fn update_status_checked(
244        &self,
245        id: &TaskId,
246        expected: u64,
247        state: TaskState,
248        message: Option<Message>,
249    ) -> Result<VersionedTask, A2AError> {
250        let task_id = id.as_str();
251        // Lock order: tasks, then versions — the compare-and-swap holds both so
252        // the check and the bump are atomic against every other mutator.
253        let mut tasks_guard = self.tasks.lock().await;
254        let task = tasks_guard
255            .get_mut(task_id)
256            .ok_or_else(|| A2AError::TaskNotFound(task_id.to_string()))?;
257        let mut versions = self.versions.lock().await;
258        let current = versions.get(task_id).copied().unwrap_or(0);
259        if current != expected {
260            return Err(A2AError::VersionConflict {
261                id: task_id.to_string(),
262                expected,
263                actual: current,
264            });
265        }
266        task.update_status(state, message);
267        let new_version = current + 1;
268        versions.insert(task_id.to_string(), new_version);
269        Ok(VersionedTask::new(task.clone(), new_version))
270    }
271}
272
273#[async_trait]
274impl AsyncTaskQuery for InMemoryTaskStorage {
275    async fn list(
276        &self,
277        params: &crate::domain::ListTasksParams,
278    ) -> Result<crate::domain::ListTasksResult, A2AError> {
279        use crate::domain::ListTasksResult;
280
281        let tasks_guard = self.tasks.lock().await;
282
283        // Filter tasks based on parameters
284        let mut filtered_tasks: Vec<_> = tasks_guard
285            .values()
286            .filter(|task| {
287                // Filter by context_id if provided
288                if let Some(ref context_id) = params.context_id {
289                    if &task.context_id != context_id {
290                        return false;
291                    }
292                }
293
294                // Filter by status if provided
295                if let Some(ref status) = params.status {
296                    if &task.status.state != status {
297                        return false;
298                    }
299                }
300
301                // Filter by status_timestamp_after if provided
302                if let Some(status_timestamp_after) = &params.status_timestamp_after {
303                    if let Ok(after_dt) =
304                        chrono::DateTime::parse_from_rfc3339(status_timestamp_after)
305                    {
306                        let after_utc = after_dt.with_timezone(&chrono::Utc);
307                        if let Some(timestamp) = task.status.timestamp_utc() {
308                            if timestamp <= after_utc {
309                                return false;
310                            }
311                        }
312                    }
313                }
314
315                true
316            })
317            .cloned()
318            .collect();
319
320        // Sort by timestamp (most recent first)
321        filtered_tasks.sort_by(|a, b| {
322            let a_time = a
323                .status
324                .timestamp_utc()
325                .map(|t| t.timestamp_millis())
326                .unwrap_or(0);
327            let b_time = b
328                .status
329                .timestamp_utc()
330                .map(|t| t.timestamp_millis())
331                .unwrap_or(0);
332            b_time.cmp(&a_time)
333        });
334
335        let total_size = filtered_tasks.len() as i32;
336
337        // Handle pagination
338        let page_size = params.page_size.unwrap_or(50).clamp(1, 100) as usize;
339        let page_start = if let Some(ref token) = params.page_token {
340            // Parse page token as a number (simple implementation)
341            token.parse::<usize>().unwrap_or(0)
342        } else {
343            0
344        };
345
346        let page_end = (page_start + page_size).min(filtered_tasks.len());
347        let has_more = page_end < filtered_tasks.len();
348
349        // Get the page of tasks
350        let mut page_tasks: Vec<_> = filtered_tasks[page_start..page_end].to_vec();
351
352        // Apply history length limit
353        let history_length = params.history_length.unwrap_or(0);
354        for task in &mut page_tasks {
355            *task = task.with_limited_history(Some(history_length as u32));
356
357            // Remove artifacts if not requested
358            if !params.include_artifacts.unwrap_or(false) {
359                task.artifacts.clear();
360            }
361        }
362
363        // Generate next page token
364        let next_page_token = if has_more {
365            page_end.to_string()
366        } else {
367            String::new()
368        };
369
370        Ok(ListTasksResult {
371            tasks: page_tasks,
372            total_size,
373            page_size: page_size as i32,
374            next_page_token,
375        })
376    }
377}
378
379// AsyncNotificationManager implementation.
380//
381// In-memory storage keeps a single config per task in the push-notification
382// registry, so the multi-config CRUD surface is expressed in those terms.
383#[async_trait]
384impl AsyncNotificationManager for InMemoryTaskStorage {
385    async fn set_config(
386        &self,
387        config: &TaskPushNotificationConfig,
388    ) -> Result<TaskPushNotificationConfig, A2AError> {
389        #[cfg(feature = "tracing")]
390        tracing::info!(
391            task_id = %config.task_id,
392            url = %config.url,
393            "🚀 Registering push notification config for task"
394        );
395
396        // Register with the push notification registry
397        self.push_notification_registry
398            .register(&config.task_id, config.clone())
399            .await?;
400
401        #[cfg(feature = "tracing")]
402        tracing::info!(
403            task_id = %config.task_id,
404            "✅ Push notification config registered successfully"
405        );
406
407        Ok(config.clone())
408    }
409
410    async fn get_config(
411        &self,
412        params: &crate::domain::GetTaskPushNotificationConfigParams,
413    ) -> Result<TaskPushNotificationConfig, A2AError> {
414        match self
415            .push_notification_registry
416            .get_config(&params.id)
417            .await?
418        {
419            Some(config) => Ok(config),
420            None => Err(A2AError::PushNotificationNotSupported),
421        }
422    }
423
424    async fn list_configs(
425        &self,
426        params: &crate::domain::ListTaskPushNotificationConfigsParams,
427    ) -> Result<Vec<TaskPushNotificationConfig>, A2AError> {
428        // In-memory storage supports one config per task; return it as a
429        // single-item vec (or empty if none registered).
430        match self
431            .push_notification_registry
432            .get_config(&params.id)
433            .await?
434        {
435            Some(config) => Ok(vec![config]),
436            None => Ok(vec![]),
437        }
438    }
439
440    async fn delete_config(
441        &self,
442        params: &crate::domain::DeleteTaskPushNotificationConfigParams,
443    ) -> Result<(), A2AError> {
444        // In-memory storage keeps a single config per task, so config_id is
445        // not used for lookup. Idempotent per the v1.0.0 spec.
446        self.push_notification_registry
447            .unregister(&params.id)
448            .await?;
449        Ok(())
450    }
451}
452
453impl Clone for InMemoryTaskStorage {
454    fn clone(&self) -> Self {
455        Self {
456            tasks: self.tasks.clone(),
457            versions: self.versions.clone(),
458            push_notification_registry: self.push_notification_registry.clone(),
459        }
460    }
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::ContextId;

    /// Parses a task id literal for test setup.
    fn tid(s: &str) -> TaskId {
        s.parse().unwrap()
    }

    /// Parses a context id literal for test setup.
    fn cid(s: &str) -> ContextId {
        s.parse().unwrap()
    }

    /// Covers the versioning contract: creation starts at 1, unversioned
    /// updates bump, stale conditional updates are rejected without side
    /// effects, and current conditional updates succeed and bump.
    #[tokio::test]
    async fn versioning_tracks_and_guards_mutations() {
        let store = InMemoryTaskStorage::new();
        let task = tid("t1");
        let context = cid("c1");

        store.create(&task, &context).await.unwrap();
        let initial = store.version(&task).await.unwrap();
        assert_eq!(initial, 1);

        // An unversioned mutation still bumps the version.
        store
            .update_status(&task, TaskState::Working, None)
            .await
            .unwrap();
        let snapshot = store.get_versioned(&task, None).await.unwrap();
        assert_eq!(snapshot.version, 2);

        // A conditional update against a stale version is refused...
        let conflict = store
            .update_status_checked(&task, 1, TaskState::Completed, None)
            .await
            .unwrap_err();
        assert!(matches!(
            conflict,
            A2AError::VersionConflict {
                expected: 1,
                actual: 2,
                ..
            }
        ));
        // ...and the stored task keeps its previous state.
        let unchanged = store.get(&task, None).await.unwrap();
        assert_eq!(unchanged.status.state, TaskState::Working);

        // A conditional update against the current version goes through.
        let applied = store
            .update_status_checked(&task, 2, TaskState::Completed, None)
            .await
            .unwrap();
        assert_eq!(applied.version, 3);
        assert_eq!(applied.task.status.state, TaskState::Completed);
    }
}