Skip to main content

inference_gateway_adk/server/
storage.rs

1//! Storage abstraction backing the A2A task manager.
2//!
3//! [`Storage`] is the trait the A2A server holds as `Arc<dyn Storage>` to
4//! persist tasks, contexts, and push-notification configurations and to
5//! drive the background-task queue. [`InMemoryStorage`] is the bundled
6//! default - a `Mutex`+`Notify`-backed structure suitable for tests,
7//! single-instance deployments, and bootstrap. Implement [`Storage`]
8//! yourself to plug in Redis, Postgres, or any other backend without
9//! forking the crate.
10//!
//! The trait surface covers a queue (enqueue/dequeue/length/clear),
//! an active-task store (create/get/update), a dead-letter store
//! (store/list), context bookkeeping, cleanup helpers, stats, and
//! push-notification config storage.
14
15use crate::a2a_types::{Task, TaskPushNotificationConfig, TaskState};
16use anyhow::{Result, anyhow};
17use async_trait::async_trait;
18use chrono::{DateTime, Utc};
19use serde_json::Value;
20use std::collections::{HashMap, HashSet, VecDeque};
21use std::sync::Mutex;
22use tokio::sync::Notify;
23
24/// A task pulled off the queue, plus the JSON-RPC `request_id` that
25/// originally enqueued it. The `request_id` is preserved for
26/// correlation/tracing - it is not consumed by the worker today,
27/// `Serialize`/`Deserialize` so the Redis backend
28/// can JSON-encode the value into a Redis LIST entry.
29#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
30pub struct QueuedTask {
31    pub task: Task,
32    pub request_id: Value,
33    pub enqueued_at: DateTime<Utc>,
34}
35
36/// Filter / pagination applied to `list_tasks` / `list_tasks_by_context`.
37/// `state == None` and `limit == None` means "no filtering / no cap".
38#[derive(Debug, Clone, Default)]
39pub struct TaskFilter {
40    pub state: Option<TaskState>,
41    pub limit: Option<usize>,
42    pub offset: Option<usize>,
43}
44
/// Point-in-time counters reported by [`Storage::get_stats`], intended
/// for health endpoints and operational dashboards.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StorageStats {
    /// Tasks currently waiting in the work queue.
    pub queue_length: usize,
    /// Tasks held in the active store.
    pub active_tasks: usize,
    /// Tasks held in the dead-letter store.
    pub dead_letter_tasks: usize,
    /// Distinct context ids the backend is tracking.
    pub contexts: usize,
}
54
55/// Pluggable storage for the A2A task manager. backend can be swapped
56/// in via [`A2AServerBuilder::with_storage`](super::server_builder::A2AServerBuilder::with_storage).
57///
58/// All methods are async to accommodate backends that need to issue
59/// I/O. Implementations use interior mutability (mutex, connection
60/// pool, etc.) so signatures stay `Arc<dyn Storage>`-friendly.
61#[async_trait]
62pub trait Storage: Send + Sync + std::fmt::Debug {
63    // ----- Queue -----------------------------------------------------
64
65    /// Push a task onto the back of the work queue.
66    async fn enqueue_task(&self, task: Task, request_id: Value) -> Result<()>;
67
68    /// Pop the next task off the front of the queue, **blocking** until
69    /// one is available (Redis: `BRPOP`).
70    async fn dequeue_task(&self) -> Result<QueuedTask>;
71
72    async fn queue_length(&self) -> usize;
73
74    async fn clear_queue(&self) -> Result<()>;
75
76    // ----- Active-task store ----------------------------------------
77
78    async fn create_active_task(&self, task: &Task) -> Result<()>;
79
80    async fn get_active_task(&self, task_id: &str) -> Result<Option<Task>>;
81
82    async fn update_active_task(&self, task: &Task) -> Result<()>;
83
84    // ----- Dead-letter + general task read --------------------------
85
86    /// Move a task to the dead-letter store (terminal-state archive).
87    /// Implementations also remove the task from the active store if
88    /// it is present there.
89    async fn store_dead_letter_task(&self, task: &Task) -> Result<()>;
90
91    /// Look up a task in any store (active first, then dead-letter).
92    async fn get_task(&self, task_id: &str) -> Option<Task>;
93
94    /// Upsert a task into the active store, replacing any existing entry
95    /// with the same id. Convenience method retained for callers that
96    /// don't want to distinguish create-vs-update.
97    async fn put_task(&self, task: Task);
98
99    async fn get_task_by_context_and_id(&self, context_id: &str, task_id: &str) -> Option<Task>;
100
101    /// Remove a task from both active and dead-letter stores.
102    async fn delete_task(&self, task_id: &str) -> Result<()>;
103
104    /// List tasks across active + dead-letter, applying `filter`.
105    async fn list_tasks(&self, filter: TaskFilter) -> Vec<Task>;
106
107    async fn list_tasks_by_context(&self, context_id: &str, filter: TaskFilter) -> Vec<Task>;
108
109    // ----- Contexts -------------------------------------------------
110
111    async fn get_contexts(&self) -> Vec<String>;
112
113    async fn get_contexts_with_tasks(&self) -> Vec<String>;
114
115    async fn delete_context(&self, context_id: &str) -> Result<()>;
116
117    async fn delete_context_and_tasks(&self, context_id: &str) -> Result<()>;
118
119    // ----- Cleanup / stats ------------------------------------------
120
121    /// Remove every task in dead-letter whose state is `Completed`.
122    /// Returns the number of tasks deleted.
123    async fn cleanup_completed_tasks(&self) -> usize;
124
125    /// Trim the dead-letter store to at most `max_completed` `Completed`
126    /// tasks and `max_failed` `Failed` tasks (oldest first). Returns the
127    /// total number of tasks deleted.
128    async fn cleanup_tasks_with_retention(&self, max_completed: usize, max_failed: usize) -> usize;
129
130    async fn get_stats(&self) -> StorageStats;
131
132    // ----- Push-notification configs (Rust-specific) ----------------
133
134    async fn put_push_notification_config(&self, config: TaskPushNotificationConfig);
135
136    async fn get_push_notification_config(&self, name: &str) -> Option<TaskPushNotificationConfig>;
137
138    async fn list_push_notification_configs(&self, parent: &str)
139    -> Vec<TaskPushNotificationConfig>;
140
141    async fn delete_push_notification_config(&self, name: &str) -> bool;
142}
143
144/// Simple in-memory [`Storage`] implementation. Suitable for tests,
145/// single-instance deployments, and as a baseline reference. Holds a
146/// `std::sync::Mutex` plus a `tokio::sync::Notify` to park the dequeue
147/// loop until an enqueue notifies it.
148#[derive(Debug, Default)]
149pub struct InMemoryStorage {
150    inner: Mutex<StorageInner>,
151    queue_notify: Notify,
152}
153
154#[derive(Debug, Default)]
155struct StorageInner {
156    queue: VecDeque<QueuedTask>,
157    active_tasks: HashMap<String, Task>,
158    dead_letter_tasks: HashMap<String, Task>,
159    /// Set of `context_id` values seen via `enqueue_task` /
160    /// `create_active_task` / `store_dead_letter_task`. Used by
161    /// `get_contexts`.
162    contexts: HashSet<String>,
163    push_notification_configs: HashMap<String, TaskPushNotificationConfig>,
164}
165
166impl InMemoryStorage {
167    pub fn new() -> Self {
168        Self::default()
169    }
170}
171
172fn apply_filter(mut tasks: Vec<Task>, filter: TaskFilter) -> Vec<Task> {
173    if let Some(state) = filter.state {
174        tasks.retain(|t| t.status.state == state);
175    }
176    if let Some(offset) = filter.offset {
177        if offset >= tasks.len() {
178            return Vec::new();
179        }
180        tasks.drain(..offset);
181    }
182    if let Some(limit) = filter.limit
183        && tasks.len() > limit
184    {
185        tasks.truncate(limit);
186    }
187    tasks
188}
189
190#[async_trait]
191impl Storage for InMemoryStorage {
192    // ----- Queue -----------------------------------------------------
193
194    async fn enqueue_task(&self, task: Task, request_id: Value) -> Result<()> {
195        {
196            let mut inner = self.inner.lock().expect("storage mutex poisoned");
197            inner.contexts.insert(task.context_id.clone());
198            inner.queue.push_back(QueuedTask {
199                task,
200                request_id,
201                enqueued_at: Utc::now(),
202            });
203        }
204        self.queue_notify.notify_one();
205        Ok(())
206    }
207
208    async fn dequeue_task(&self) -> Result<QueuedTask> {
209        loop {
210            let notified = self.queue_notify.notified();
211            {
212                let mut inner = self.inner.lock().expect("storage mutex poisoned");
213                if let Some(queued) = inner.queue.pop_front() {
214                    return Ok(queued);
215                }
216            }
217            notified.await;
218        }
219    }
220
221    async fn queue_length(&self) -> usize {
222        let inner = self.inner.lock().expect("storage mutex poisoned");
223        inner.queue.len()
224    }
225
226    async fn clear_queue(&self) -> Result<()> {
227        let mut inner = self.inner.lock().expect("storage mutex poisoned");
228        inner.queue.clear();
229        Ok(())
230    }
231
232    // ----- Active-task store ----------------------------------------
233
234    async fn create_active_task(&self, task: &Task) -> Result<()> {
235        let mut inner = self.inner.lock().expect("storage mutex poisoned");
236        if inner.active_tasks.contains_key(&task.id) {
237            return Err(anyhow!("active task {:?} already exists", task.id));
238        }
239        inner.contexts.insert(task.context_id.clone());
240        inner.active_tasks.insert(task.id.clone(), task.clone());
241        Ok(())
242    }
243
244    async fn get_active_task(&self, task_id: &str) -> Result<Option<Task>> {
245        let inner = self.inner.lock().expect("storage mutex poisoned");
246        Ok(inner.active_tasks.get(task_id).cloned())
247    }
248
249    async fn update_active_task(&self, task: &Task) -> Result<()> {
250        let mut inner = self.inner.lock().expect("storage mutex poisoned");
251        if !inner.active_tasks.contains_key(&task.id) {
252            return Err(anyhow!(
253                "cannot update active task {:?}: not found",
254                task.id
255            ));
256        }
257        inner.active_tasks.insert(task.id.clone(), task.clone());
258        Ok(())
259    }
260
261    // ----- Dead-letter + general task read --------------------------
262
263    async fn store_dead_letter_task(&self, task: &Task) -> Result<()> {
264        let mut inner = self.inner.lock().expect("storage mutex poisoned");
265        inner.contexts.insert(task.context_id.clone());
266        inner.active_tasks.remove(&task.id);
267        inner
268            .dead_letter_tasks
269            .insert(task.id.clone(), task.clone());
270        Ok(())
271    }
272
273    async fn get_task(&self, task_id: &str) -> Option<Task> {
274        let inner = self.inner.lock().expect("storage mutex poisoned");
275        inner
276            .active_tasks
277            .get(task_id)
278            .or_else(|| inner.dead_letter_tasks.get(task_id))
279            .cloned()
280    }
281
282    async fn put_task(&self, task: Task) {
283        let mut inner = self.inner.lock().expect("storage mutex poisoned");
284        inner.contexts.insert(task.context_id.clone());
285        inner.active_tasks.insert(task.id.clone(), task);
286    }
287
288    async fn get_task_by_context_and_id(&self, context_id: &str, task_id: &str) -> Option<Task> {
289        let inner = self.inner.lock().expect("storage mutex poisoned");
290        inner
291            .active_tasks
292            .get(task_id)
293            .or_else(|| inner.dead_letter_tasks.get(task_id))
294            .filter(|t| t.context_id == context_id)
295            .cloned()
296    }
297
298    async fn delete_task(&self, task_id: &str) -> Result<()> {
299        let mut inner = self.inner.lock().expect("storage mutex poisoned");
300        let active_removed = inner.active_tasks.remove(task_id).is_some();
301        let dead_removed = inner.dead_letter_tasks.remove(task_id).is_some();
302        if !active_removed && !dead_removed {
303            return Err(anyhow!("task {task_id:?} not found in any store"));
304        }
305        Ok(())
306    }
307
308    async fn list_tasks(&self, filter: TaskFilter) -> Vec<Task> {
309        let inner = self.inner.lock().expect("storage mutex poisoned");
310        let tasks: Vec<Task> = inner
311            .active_tasks
312            .values()
313            .chain(inner.dead_letter_tasks.values())
314            .cloned()
315            .collect();
316        drop(inner);
317        apply_filter(tasks, filter)
318    }
319
320    async fn list_tasks_by_context(&self, context_id: &str, filter: TaskFilter) -> Vec<Task> {
321        let inner = self.inner.lock().expect("storage mutex poisoned");
322        let tasks: Vec<Task> = inner
323            .active_tasks
324            .values()
325            .chain(inner.dead_letter_tasks.values())
326            .filter(|t| t.context_id == context_id)
327            .cloned()
328            .collect();
329        drop(inner);
330        apply_filter(tasks, filter)
331    }
332
333    // ----- Contexts -------------------------------------------------
334
335    async fn get_contexts(&self) -> Vec<String> {
336        let inner = self.inner.lock().expect("storage mutex poisoned");
337        inner.contexts.iter().cloned().collect()
338    }
339
340    async fn get_contexts_with_tasks(&self) -> Vec<String> {
341        let inner = self.inner.lock().expect("storage mutex poisoned");
342        let mut out: HashSet<String> = HashSet::new();
343        for t in inner.active_tasks.values() {
344            out.insert(t.context_id.clone());
345        }
346        for t in inner.dead_letter_tasks.values() {
347            out.insert(t.context_id.clone());
348        }
349        out.into_iter().collect()
350    }
351
352    async fn delete_context(&self, context_id: &str) -> Result<()> {
353        let mut inner = self.inner.lock().expect("storage mutex poisoned");
354        inner.contexts.remove(context_id);
355        Ok(())
356    }
357
358    async fn delete_context_and_tasks(&self, context_id: &str) -> Result<()> {
359        let mut inner = self.inner.lock().expect("storage mutex poisoned");
360        inner.active_tasks.retain(|_, t| t.context_id != context_id);
361        inner
362            .dead_letter_tasks
363            .retain(|_, t| t.context_id != context_id);
364        inner.contexts.remove(context_id);
365        Ok(())
366    }
367
368    // ----- Cleanup / stats ------------------------------------------
369
370    async fn cleanup_completed_tasks(&self) -> usize {
371        let mut inner = self.inner.lock().expect("storage mutex poisoned");
372        let before = inner.dead_letter_tasks.len();
373        inner
374            .dead_letter_tasks
375            .retain(|_, t| t.status.state != TaskState::TaskStateCompleted);
376        before - inner.dead_letter_tasks.len()
377    }
378
379    async fn cleanup_tasks_with_retention(&self, max_completed: usize, max_failed: usize) -> usize {
380        let mut inner = self.inner.lock().expect("storage mutex poisoned");
381
382        fn evict(store: &mut HashMap<String, Task>, state: TaskState, keep: usize) -> usize {
383            let mut matching: Vec<(String, Option<DateTime<Utc>>)> = store
384                .iter()
385                .filter(|(_, t)| t.status.state == state)
386                .map(|(k, t)| (k.clone(), t.status.timestamp.as_ref().map(|ts| ts.0)))
387                .collect();
388            if matching.len() <= keep {
389                return 0;
390            }
391            matching.sort_by_key(|(_, ts)| *ts);
392            let evict_count = matching.len() - keep;
393            for (id, _) in matching.into_iter().take(evict_count) {
394                store.remove(&id);
395            }
396            evict_count
397        }
398
399        let completed_removed = evict(
400            &mut inner.dead_letter_tasks,
401            TaskState::TaskStateCompleted,
402            max_completed,
403        );
404        let failed_removed = evict(
405            &mut inner.dead_letter_tasks,
406            TaskState::TaskStateFailed,
407            max_failed,
408        );
409        completed_removed + failed_removed
410    }
411
412    async fn get_stats(&self) -> StorageStats {
413        let inner = self.inner.lock().expect("storage mutex poisoned");
414        StorageStats {
415            queue_length: inner.queue.len(),
416            active_tasks: inner.active_tasks.len(),
417            dead_letter_tasks: inner.dead_letter_tasks.len(),
418            contexts: inner.contexts.len(),
419        }
420    }
421
422    // ----- Push-notification configs -------------------------------
423
424    async fn put_push_notification_config(&self, config: TaskPushNotificationConfig) {
425        let mut inner = self.inner.lock().expect("storage mutex poisoned");
426        inner
427            .push_notification_configs
428            .insert(config.name.clone(), config);
429    }
430
431    async fn get_push_notification_config(&self, name: &str) -> Option<TaskPushNotificationConfig> {
432        let inner = self.inner.lock().expect("storage mutex poisoned");
433        inner.push_notification_configs.get(name).cloned()
434    }
435
436    async fn list_push_notification_configs(
437        &self,
438        parent: &str,
439    ) -> Vec<TaskPushNotificationConfig> {
440        let prefix = format!("{parent}/pushNotificationConfigs/");
441        let inner = self.inner.lock().expect("storage mutex poisoned");
442        inner
443            .push_notification_configs
444            .values()
445            .filter(|c| c.name.starts_with(&prefix))
446            .cloned()
447            .collect()
448    }
449
450    async fn delete_push_notification_config(&self, name: &str) -> bool {
451        let mut inner = self.inner.lock().expect("storage mutex poisoned");
452        inner.push_notification_configs.remove(name).is_some()
453    }
454}
455
/// Extract the bare task id from a resource name of the exact form
/// `tasks/{task_id}`.
///
/// Returns `None` when `name` does not start with `tasks/`, when
/// nothing follows the prefix, or when the remainder contains another
/// `/` (for example `tasks/abc/pushNotificationConfigs/c1`).
pub fn parse_task_name(name: &str) -> Option<&str> {
    let task_id = name.strip_prefix("tasks/")?;
    if task_id.is_empty() || task_id.contains('/') {
        None
    } else {
        Some(task_id)
    }
}
462
463/// Construct a `Storage` backend from a [`QueueConfig`]. The default
464/// (`provider = Memory`) returns an [`InMemoryStorage`]; with the
465/// `redis` cargo feature enabled and `provider = Redis`, this connects
466/// to the configured URL and returns a `RedisStorage`. Without the
467/// `redis` feature, selecting Redis errors at construction time.
468pub async fn create_storage(
469    cfg: &crate::config::QueueConfig,
470) -> Result<std::sync::Arc<dyn Storage>> {
471    use crate::config::QueueProvider;
472    match cfg.provider {
473        QueueProvider::Memory => Ok(std::sync::Arc::new(InMemoryStorage::new())),
474        QueueProvider::Redis => {
475            #[cfg(feature = "redis")]
476            {
477                let url = cfg.url.as_deref().ok_or_else(|| {
478                    anyhow!("A2A_QUEUE_URL is required when A2A_QUEUE_PROVIDER=redis")
479                })?;
480                let storage =
481                    super::storage_redis::RedisStorage::connect(url, &cfg.namespace).await?;
482                Ok(std::sync::Arc::new(storage))
483            }
484            #[cfg(not(feature = "redis"))]
485            {
486                Err(anyhow!(
487                    "A2A_QUEUE_PROVIDER=redis requires the `redis` cargo feature. \
488                     Rebuild with `--features redis`."
489                ))
490            }
491        }
492    }
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::a2a_types::{
        PushNotificationConfig, TaskPushNotificationConfig, TaskState, TaskStatus, Timestamp,
    };

    // ----- Fixtures -------------------------------------------------

    /// A freshly submitted task in the shared `"ctx"` context.
    fn make_task(id: &str) -> Task {
        make_task_in_context(id, "ctx")
    }

    /// A freshly submitted task, timestamped "now", in `context_id`.
    fn make_task_in_context(id: &str, context_id: &str) -> Task {
        let status = TaskStatus {
            message: None,
            state: TaskState::TaskStateSubmitted,
            timestamp: Some(Timestamp(Utc::now())),
        };
        Task {
            artifacts: Vec::new(),
            context_id: context_id.to_owned(),
            history: Vec::new(),
            id: id.to_owned(),
            metadata: None,
            status,
        }
    }

    /// `make_task_in_context`, with the status state overridden.
    fn make_task_in_state(id: &str, context_id: &str, state: TaskState) -> Task {
        let mut task = make_task_in_context(id, context_id);
        task.status.state = state;
        task
    }

    /// A push-notification config with only `name` and `url` set.
    fn make_config(name: &str, url: &str) -> TaskPushNotificationConfig {
        let push_notification_config = PushNotificationConfig {
            authentication: None,
            id: None,
            token: None,
            url: url.to_owned(),
        };
        TaskPushNotificationConfig {
            name: name.to_owned(),
            push_notification_config,
        }
    }

    /// Ids of `tasks`, in the order given.
    fn ids_of(tasks: &[Task]) -> Vec<&str> {
        tasks.iter().map(|t| t.id.as_str()).collect()
    }

    // ----- Queue ----------------------------------------------------

    #[tokio::test]
    async fn queue_enqueue_dequeue_round_trip() {
        let storage = InMemoryStorage::new();
        let request_id = Value::String("req-1".to_owned());

        storage
            .enqueue_task(make_task("t1"), request_id.clone())
            .await
            .expect("enqueue");
        assert_eq!(storage.queue_length().await, 1);

        let popped = storage.dequeue_task().await.expect("dequeue");
        assert_eq!(popped.task.id, "t1");
        assert_eq!(popped.request_id, request_id);
        assert_eq!(storage.queue_length().await, 0);
    }

    #[tokio::test]
    async fn dequeue_parks_until_enqueue() {
        let storage = std::sync::Arc::new(InMemoryStorage::new());
        let consumer = {
            let storage = std::sync::Arc::clone(&storage);
            tokio::spawn(async move { storage.dequeue_task().await })
        };
        // Give the consumer a chance to run and park on the empty queue.
        tokio::task::yield_now().await;

        storage
            .enqueue_task(make_task("t2"), Value::Null)
            .await
            .expect("enqueue");

        let delivered = consumer.await.expect("join").expect("dequeue");
        assert_eq!(delivered.task.id, "t2");
    }

    #[tokio::test]
    async fn clear_queue_drops_pending_tasks() {
        let storage = InMemoryStorage::new();
        for id in ["t0", "t1", "t2"] {
            storage
                .enqueue_task(make_task(id), Value::Null)
                .await
                .expect("enqueue");
        }
        assert_eq!(storage.queue_length().await, 3);

        storage.clear_queue().await.expect("clear");
        assert_eq!(storage.queue_length().await, 0);
    }

    // ----- Active + dead-letter -------------------------------------

    #[tokio::test]
    async fn create_active_task_rejects_duplicates() {
        let storage = InMemoryStorage::new();
        storage
            .create_active_task(&make_task("t1"))
            .await
            .expect("first create");

        let second = storage.create_active_task(&make_task("t1")).await;
        let err = second.expect_err("duplicate must fail");
        assert!(err.to_string().contains("already exists"));
    }

    #[tokio::test]
    async fn update_active_task_requires_existing_entry() {
        let storage = InMemoryStorage::new();
        let outcome = storage.update_active_task(&make_task("missing")).await;
        let err = outcome.expect_err("update of missing task must fail");
        assert!(err.to_string().contains("not found"));
    }

    #[tokio::test]
    async fn store_dead_letter_moves_task_out_of_active() {
        let storage = InMemoryStorage::new();
        let submitted = make_task("t1");
        storage
            .create_active_task(&submitted)
            .await
            .expect("create");

        let mut finished = submitted;
        finished.status.state = TaskState::TaskStateCompleted;
        storage
            .store_dead_letter_task(&finished)
            .await
            .expect("dead-letter");

        let still_active = storage.get_active_task("t1").await.expect("ok");
        assert!(still_active.is_none(), "task should be removed from active");

        let archived = storage.get_task("t1").await.expect("task in dead-letter");
        assert_eq!(archived.status.state, TaskState::TaskStateCompleted);
    }

    #[tokio::test]
    async fn list_tasks_includes_active_and_dead_letter() {
        let storage = InMemoryStorage::new();
        storage
            .create_active_task(&make_task("active-1"))
            .await
            .expect("active create");
        storage
            .store_dead_letter_task(&make_task_in_state(
                "dead-1",
                "ctx",
                TaskState::TaskStateFailed,
            ))
            .await
            .expect("dead-letter store");

        let all = storage.list_tasks(TaskFilter::default()).await;
        let ids = ids_of(&all);
        assert_eq!(all.len(), 2, "expected 2 tasks across stores, got {ids:?}");
        for expected in ["active-1", "dead-1"] {
            assert!(ids.contains(&expected));
        }
    }

    #[tokio::test]
    async fn list_tasks_applies_state_filter() {
        let storage = InMemoryStorage::new();
        storage
            .create_active_task(&make_task_in_state(
                "a",
                "ctx",
                TaskState::TaskStateWorking,
            ))
            .await
            .expect("create");
        storage
            .store_dead_letter_task(&make_task_in_state(
                "d",
                "ctx",
                TaskState::TaskStateFailed,
            ))
            .await
            .expect("dead-letter store");

        let only_failed = TaskFilter {
            state: Some(TaskState::TaskStateFailed),
            ..Default::default()
        };
        let filtered = storage.list_tasks(only_failed).await;
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].id, "d");
    }

    #[tokio::test]
    async fn list_tasks_by_context_scopes_results() {
        let storage = InMemoryStorage::new();
        for (id, context, label) in [("a", "ctx-1", "create a"), ("b", "ctx-2", "create b")] {
            storage
                .create_active_task(&make_task_in_context(id, context))
                .await
                .expect(label);
        }

        let scoped = storage
            .list_tasks_by_context("ctx-1", TaskFilter::default())
            .await;
        assert_eq!(scoped.len(), 1);
        assert_eq!(scoped[0].id, "a");
    }

    #[tokio::test]
    async fn delete_task_removes_from_both_stores() {
        let storage = InMemoryStorage::new();
        storage
            .create_active_task(&make_task("active"))
            .await
            .expect("create");
        storage
            .store_dead_letter_task(&make_task_in_state(
                "dead",
                "ctx",
                TaskState::TaskStateFailed,
            ))
            .await
            .expect("dead-letter store");

        storage.delete_task("active").await.expect("delete active");
        storage.delete_task("dead").await.expect("delete dead");

        for id in ["active", "dead"] {
            assert!(storage.get_task(id).await.is_none());
        }
        assert!(
            storage.delete_task("nonexistent").await.is_err(),
            "deleting unknown task must error"
        );
    }

    // ----- Contexts -------------------------------------------------

    #[tokio::test]
    async fn contexts_track_seen_tasks() {
        let storage = InMemoryStorage::new();
        storage
            .enqueue_task(make_task_in_context("t1", "ctx-q"), Value::Null)
            .await
            .expect("enqueue");
        storage
            .create_active_task(&make_task_in_context("t2", "ctx-a"))
            .await
            .expect("create");

        // `get_contexts` gives no ordering guarantee, so sort first.
        let mut contexts = storage.get_contexts().await;
        contexts.sort();
        assert_eq!(contexts, vec!["ctx-a".to_string(), "ctx-q".to_string()]);
    }

    #[tokio::test]
    async fn delete_context_and_tasks_clears_both_stores() {
        let storage = InMemoryStorage::new();
        storage
            .create_active_task(&make_task_in_context("a", "ctx-x"))
            .await
            .expect("create active");
        storage
            .store_dead_letter_task(&make_task_in_state(
                "b",
                "ctx-x",
                TaskState::TaskStateFailed,
            ))
            .await
            .expect("dead-letter");
        storage
            .create_active_task(&make_task_in_context("survivor", "ctx-y"))
            .await
            .expect("create survivor");

        storage
            .delete_context_and_tasks("ctx-x")
            .await
            .expect("delete context");

        let remaining = storage.list_tasks(TaskFilter::default()).await;
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id, "survivor");
    }

    // ----- Cleanup / stats ------------------------------------------

    #[tokio::test]
    async fn cleanup_completed_tasks_only_drops_completed() {
        let storage = InMemoryStorage::new();
        storage
            .store_dead_letter_task(&make_task_in_state(
                "done",
                "ctx",
                TaskState::TaskStateCompleted,
            ))
            .await
            .expect("dead-letter completed");
        storage
            .store_dead_letter_task(&make_task_in_state(
                "failed",
                "ctx",
                TaskState::TaskStateFailed,
            ))
            .await
            .expect("dead-letter failed");

        let removed = storage.cleanup_completed_tasks().await;
        assert_eq!(removed, 1);

        let remaining = storage.list_tasks(TaskFilter::default()).await;
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id, "failed");
    }

    #[tokio::test]
    async fn cleanup_with_retention_keeps_newest() {
        let storage = InMemoryStorage::new();
        // Stagger the timestamps one second apart so the eviction
        // order is unambiguous.
        for (age_rank, id) in (0_i64..).zip(["old", "mid", "new"]) {
            let mut task = make_task_in_state(id, "ctx", TaskState::TaskStateCompleted);
            let stamp = Utc::now() + chrono::Duration::seconds(age_rank);
            task.status.timestamp = Some(Timestamp(stamp));
            storage
                .store_dead_letter_task(&task)
                .await
                .expect("dead-letter store");
        }

        let removed = storage.cleanup_tasks_with_retention(1, 0).await;
        assert_eq!(removed, 2);

        let remaining = storage.list_tasks(TaskFilter::default()).await;
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id, "new");
    }

    #[tokio::test]
    async fn stats_count_everything() {
        let storage = InMemoryStorage::new();
        storage
            .enqueue_task(make_task("queued"), Value::Null)
            .await
            .expect("enqueue");
        storage
            .create_active_task(&make_task("active"))
            .await
            .expect("create");
        storage
            .store_dead_letter_task(&make_task_in_state(
                "dead",
                "ctx",
                TaskState::TaskStateFailed,
            ))
            .await
            .expect("dead-letter store");

        let StorageStats {
            queue_length,
            active_tasks,
            dead_letter_tasks,
            contexts,
        } = storage.get_stats().await;
        assert_eq!(queue_length, 1);
        assert_eq!(active_tasks, 1);
        assert_eq!(dead_letter_tasks, 1);
        assert_eq!(contexts, 1, "all three share 'ctx'");
    }

    // ----- Existing surface (smoke tests) ---------------------------

    #[tokio::test]
    async fn get_task_falls_back_to_dead_letter() {
        let storage = InMemoryStorage::new();
        storage
            .store_dead_letter_task(&make_task_in_state(
                "t1",
                "ctx",
                TaskState::TaskStateCompleted,
            ))
            .await
            .expect("dead-letter store");

        let got = storage.get_task("t1").await.expect("dead-letter read");
        assert_eq!(got.status.state, TaskState::TaskStateCompleted);
    }

    #[tokio::test]
    async fn push_notification_configs_filter_by_parent() {
        let storage = InMemoryStorage::new();
        let seeded = [
            (
                "tasks/abc/pushNotificationConfigs/c1",
                "https://a.example/webhook",
            ),
            (
                "tasks/abc/pushNotificationConfigs/c2",
                "https://b.example/webhook",
            ),
            (
                "tasks/other/pushNotificationConfigs/c3",
                "https://c.example/webhook",
            ),
        ];
        for (name, url) in seeded {
            storage
                .put_push_notification_config(make_config(name, url))
                .await;
        }

        let under_abc = storage.list_push_notification_configs("tasks/abc").await;
        assert_eq!(under_abc.len(), 2);

        let target = "tasks/abc/pushNotificationConfigs/c1";
        let first_delete = storage.delete_push_notification_config(target).await;
        assert!(first_delete);

        let after_delete = storage.list_push_notification_configs("tasks/abc").await;
        assert_eq!(after_delete.len(), 1);

        // Deleting the same name again reports that nothing was removed.
        let second_delete = storage.delete_push_notification_config(target).await;
        assert!(!second_delete);
    }

    #[test]
    fn parse_task_name_strips_prefix() {
        let cases = [
            ("tasks/abc", Some("abc")),
            ("tasks/abc/pushNotificationConfigs/c1", None),
            ("tasks/", None),
            ("notasks/abc", None),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_task_name(input), expected);
        }
    }

    #[tokio::test]
    async fn dyn_storage_dispatches_through_trait() {
        let storage: std::sync::Arc<dyn Storage> = std::sync::Arc::new(InMemoryStorage::new());
        storage
            .create_active_task(&make_task("abc"))
            .await
            .expect("create");

        let got = storage.get_task("abc").await;
        let got = got.expect("task should be present");
        assert_eq!(got.id, "abc");
    }
}