Skip to main content

claude_pool/
store.rs

//! Pluggable storage backend for pool state.
//!
//! The [`PoolStore`] trait abstracts where task and slot records live.
//! [`InMemoryStore`] keeps everything in-process; a future `RedisStore`
//! could share state across multiple pool server instances.
6
7use async_trait::async_trait;
8use dashmap::DashMap;
9
10use crate::error::Result;
11use crate::types::*;
12
13/// Trait for storing and retrieving pool state.
14///
15/// Implementations must be `Send + Sync` for use in async contexts.
16#[async_trait]
17pub trait PoolStore: Send + Sync {
18    /// Insert or update a task record.
19    async fn put_task(&self, record: TaskRecord) -> Result<()>;
20
21    /// Get a task by ID.
22    async fn get_task(&self, id: &TaskId) -> Result<Option<TaskRecord>>;
23
24    /// List tasks matching an optional filter.
25    async fn list_tasks(&self, filter: &TaskFilter) -> Result<Vec<TaskRecord>>;
26
27    /// Delete a task record.
28    async fn delete_task(&self, id: &TaskId) -> Result<bool>;
29
30    /// Insert or update a slot record.
31    async fn put_slot(&self, record: SlotRecord) -> Result<()>;
32
33    /// Get a slot by ID.
34    async fn get_slot(&self, id: &SlotId) -> Result<Option<SlotRecord>>;
35
36    /// List all slots.
37    async fn list_slots(&self) -> Result<Vec<SlotRecord>>;
38
39    /// Delete a slot record.
40    async fn delete_slot(&self, id: &SlotId) -> Result<bool>;
41}
42
43/// In-memory store using [`DashMap`] for concurrent access.
44///
45/// All data is lost when the process exits. Suitable for single-session
46/// usage and development.
47#[derive(Debug, Default)]
48pub struct InMemoryStore {
49    tasks: DashMap<String, TaskRecord>,
50    slots: DashMap<String, SlotRecord>,
51}
52
53impl InMemoryStore {
54    /// Create a new empty in-memory store.
55    pub fn new() -> Self {
56        Self::default()
57    }
58}
59
60#[async_trait]
61impl PoolStore for InMemoryStore {
62    async fn put_task(&self, record: TaskRecord) -> Result<()> {
63        self.tasks.insert(record.id.0.clone(), record);
64        Ok(())
65    }
66
67    async fn get_task(&self, id: &TaskId) -> Result<Option<TaskRecord>> {
68        Ok(self.tasks.get(&id.0).map(|r| r.value().clone()))
69    }
70
71    async fn list_tasks(&self, filter: &TaskFilter) -> Result<Vec<TaskRecord>> {
72        let tasks: Vec<TaskRecord> = self
73            .tasks
74            .iter()
75            .map(|r| r.value().clone())
76            .filter(|t| {
77                if let Some(state) = filter.state
78                    && t.state != state
79                {
80                    return false;
81                }
82                if let Some(ref wid) = filter.slot_id
83                    && t.slot_id.as_ref() != Some(wid)
84                {
85                    return false;
86                }
87                if let Some(ref tags) = filter.tags
88                    && !tags.iter().any(|tag| t.tags.contains(tag))
89                {
90                    return false;
91                }
92                true
93            })
94            .collect();
95        Ok(tasks)
96    }
97
98    async fn delete_task(&self, id: &TaskId) -> Result<bool> {
99        Ok(self.tasks.remove(&id.0).is_some())
100    }
101
102    async fn put_slot(&self, record: SlotRecord) -> Result<()> {
103        self.slots.insert(record.id.0.clone(), record);
104        Ok(())
105    }
106
107    async fn get_slot(&self, id: &SlotId) -> Result<Option<SlotRecord>> {
108        Ok(self.slots.get(&id.0).map(|r| r.value().clone()))
109    }
110
111    async fn list_slots(&self) -> Result<Vec<SlotRecord>> {
112        Ok(self.slots.iter().map(|r| r.value().clone()).collect())
113    }
114
115    async fn delete_slot(&self, id: &SlotId) -> Result<bool> {
116        Ok(self.slots.remove(&id.0).is_some())
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a task record with the fields the filter tests vary.
    fn task(id: &str, state: TaskState, slot_id: Option<&str>, tags: &[&str]) -> TaskRecord {
        TaskRecord {
            id: TaskId(id.into()),
            prompt: format!("task {id}"),
            state,
            slot_id: slot_id.map(|s| SlotId(s.into())),
            result: None,
            tags: tags.iter().map(|&t| t.into()).collect(),
            config: None,
        }
    }

    /// Task IDs in sorted order, since the store's listing order is unspecified.
    fn sorted_ids(tasks: &[TaskRecord]) -> Vec<String> {
        let mut ids: Vec<String> = tasks.iter().map(|t| t.id.0.clone()).collect();
        ids.sort();
        ids
    }

    #[tokio::test]
    async fn task_crud() {
        let store = InMemoryStore::new();
        let id = TaskId("t-1".into());

        let record = TaskRecord {
            id: id.clone(),
            prompt: "write tests".into(),
            state: TaskState::Pending,
            slot_id: None,
            result: None,
            tags: vec!["testing".into()],
            config: None,
        };

        store.put_task(record).await.unwrap();

        let fetched = store.get_task(&id).await.unwrap().unwrap();
        assert_eq!(fetched.prompt, "write tests");
        assert_eq!(fetched.state, TaskState::Pending);

        let all = store.list_tasks(&TaskFilter::default()).await.unwrap();
        assert_eq!(all.len(), 1);

        let deleted = store.delete_task(&id).await.unwrap();
        assert!(deleted);
        assert!(store.get_task(&id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn put_task_replaces_existing_record() {
        let store = InMemoryStore::new();
        store
            .put_task(task("t-1", TaskState::Pending, None, &[]))
            .await
            .unwrap();
        store
            .put_task(task("t-1", TaskState::Completed, None, &[]))
            .await
            .unwrap();

        let all = store.list_tasks(&TaskFilter::default()).await.unwrap();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].state, TaskState::Completed);
    }

    #[tokio::test]
    async fn delete_missing_returns_false() {
        let store = InMemoryStore::new();
        assert!(!store.delete_task(&TaskId("nope".into())).await.unwrap());
        assert!(!store.delete_slot(&SlotId("nope".into())).await.unwrap());
    }

    #[tokio::test]
    async fn slot_crud() {
        let store = InMemoryStore::new();
        let id = SlotId("w-0".into());

        let record = SlotRecord {
            id: id.clone(),
            state: SlotState::Idle,
            config: SlotConfig::default(),
            current_task: None,
            session_id: None,
            tasks_completed: 0,
            cost_microdollars: 0,
            restart_count: 0,
            worktree_path: None,
        };

        store.put_slot(record).await.unwrap();

        let fetched = store.get_slot(&id).await.unwrap().unwrap();
        assert_eq!(fetched.state, SlotState::Idle);

        let all = store.list_slots().await.unwrap();
        assert_eq!(all.len(), 1);

        let deleted = store.delete_slot(&id).await.unwrap();
        assert!(deleted);
        assert!(store.get_slot(&id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn task_filter_by_state() {
        let store = InMemoryStore::new();

        for i in 0..3 {
            let state = if i == 0 {
                TaskState::Pending
            } else {
                TaskState::Completed
            };
            store
                .put_task(task(&format!("t-{i}"), state, None, &[]))
                .await
                .unwrap();
        }

        let pending = store
            .list_tasks(&TaskFilter {
                state: Some(TaskState::Pending),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(sorted_ids(&pending), ["t-0"]);

        let completed = store
            .list_tasks(&TaskFilter {
                state: Some(TaskState::Completed),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(sorted_ids(&completed), ["t-1", "t-2"]);
    }

    #[tokio::test]
    async fn task_filter_by_slot_and_tags() {
        let store = InMemoryStore::new();
        store
            .put_task(task("t-0", TaskState::Pending, Some("w-0"), &["gpu"]))
            .await
            .unwrap();
        store
            .put_task(task("t-1", TaskState::Pending, Some("w-1"), &["cpu"]))
            .await
            .unwrap();
        store
            .put_task(task("t-2", TaskState::Pending, None, &["gpu", "cpu"]))
            .await
            .unwrap();

        let on_w0 = store
            .list_tasks(&TaskFilter {
                slot_id: Some(SlotId("w-0".into())),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(sorted_ids(&on_w0), ["t-0"]);

        let gpu = store
            .list_tasks(&TaskFilter {
                tags: Some(["gpu".to_string()].into_iter().collect()),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(sorted_ids(&gpu), ["t-0", "t-2"]);

        // Criteria combine with AND: t-0 is on w-0 but is not tagged "cpu".
        let none = store
            .list_tasks(&TaskFilter {
                slot_id: Some(SlotId("w-0".into())),
                tags: Some(["cpu".to_string()].into_iter().collect()),
                ..Default::default()
            })
            .await
            .unwrap();
        assert!(none.is_empty());
    }
}