Skip to main content

claude_pool/
store.rs

//! Pluggable storage backend for pool state.
//!
//! The [`PoolStore`] trait abstracts where task and slot records live.
//! [`InMemoryStore`] keeps everything in-process; a future `RedisStore`
//! could share state across multiple pool server instances.
6
7use async_trait::async_trait;
8use dashmap::DashMap;
9
10use crate::error::Result;
11use crate::types::*;
12
13/// Trait for storing and retrieving pool state.
14///
15/// Implementations must be `Send + Sync` for use in async contexts.
16#[async_trait]
17pub trait PoolStore: Send + Sync {
18    /// Insert or update a task record.
19    async fn put_task(&self, record: TaskRecord) -> Result<()>;
20
21    /// Get a task by ID.
22    async fn get_task(&self, id: &TaskId) -> Result<Option<TaskRecord>>;
23
24    /// List tasks matching an optional filter.
25    async fn list_tasks(&self, filter: &TaskFilter) -> Result<Vec<TaskRecord>>;
26
27    /// Delete a task record.
28    async fn delete_task(&self, id: &TaskId) -> Result<bool>;
29
30    /// Insert or update a slot record.
31    async fn put_slot(&self, record: SlotRecord) -> Result<()>;
32
33    /// Get a slot by ID.
34    async fn get_slot(&self, id: &SlotId) -> Result<Option<SlotRecord>>;
35
36    /// List all slots.
37    async fn list_slots(&self) -> Result<Vec<SlotRecord>>;
38
39    /// Delete a slot record.
40    async fn delete_slot(&self, id: &SlotId) -> Result<bool>;
41}
42
43/// In-memory store using [`DashMap`] for concurrent access.
44///
45/// All data is lost when the process exits. Suitable for single-session
46/// usage and development.
47#[derive(Debug, Default)]
48pub struct InMemoryStore {
49    tasks: DashMap<String, TaskRecord>,
50    slots: DashMap<String, SlotRecord>,
51}
52
53impl InMemoryStore {
54    /// Create a new empty in-memory store.
55    pub fn new() -> Self {
56        Self::default()
57    }
58}
59
60#[async_trait]
61impl PoolStore for InMemoryStore {
62    async fn put_task(&self, record: TaskRecord) -> Result<()> {
63        self.tasks.insert(record.id.0.clone(), record);
64        Ok(())
65    }
66
67    async fn get_task(&self, id: &TaskId) -> Result<Option<TaskRecord>> {
68        Ok(self.tasks.get(&id.0).map(|r| r.value().clone()))
69    }
70
71    async fn list_tasks(&self, filter: &TaskFilter) -> Result<Vec<TaskRecord>> {
72        let tasks: Vec<TaskRecord> = self
73            .tasks
74            .iter()
75            .map(|r| r.value().clone())
76            .filter(|t| {
77                if let Some(state) = filter.state
78                    && t.state != state
79                {
80                    return false;
81                }
82                if let Some(ref wid) = filter.slot_id
83                    && t.slot_id.as_ref() != Some(wid)
84                {
85                    return false;
86                }
87                if let Some(ref tags) = filter.tags
88                    && !tags.iter().any(|tag| t.tags.contains(tag))
89                {
90                    return false;
91                }
92                true
93            })
94            .collect();
95        Ok(tasks)
96    }
97
98    async fn delete_task(&self, id: &TaskId) -> Result<bool> {
99        Ok(self.tasks.remove(&id.0).is_some())
100    }
101
102    async fn put_slot(&self, record: SlotRecord) -> Result<()> {
103        self.slots.insert(record.id.0.clone(), record);
104        Ok(())
105    }
106
107    async fn get_slot(&self, id: &SlotId) -> Result<Option<SlotRecord>> {
108        Ok(self.slots.get(&id.0).map(|r| r.value().clone()))
109    }
110
111    async fn list_slots(&self) -> Result<Vec<SlotRecord>> {
112        Ok(self.slots.iter().map(|r| r.value().clone()).collect())
113    }
114
115    async fn delete_slot(&self, id: &SlotId) -> Result<bool> {
116        Ok(self.slots.remove(&id.0).is_some())
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a task record with the given ID, state, slot assignment and tags;
    /// every other field gets a fixed placeholder value.
    fn task(id: &str, state: TaskState, slot_id: Option<SlotId>, tags: &[&str]) -> TaskRecord {
        TaskRecord {
            id: TaskId(id.into()),
            prompt: format!("task {id}"),
            state,
            slot_id,
            result: None,
            tags: tags.iter().map(|tag| (*tag).to_string()).collect(),
            config: None,
            review_required: false,
            max_rejections: 3,
            rejection_count: 0,
            original_prompt: None,
        }
    }

    #[tokio::test]
    async fn task_crud() {
        let store = InMemoryStore::new();
        let id = TaskId("t-1".into());

        let record = TaskRecord {
            prompt: "write tests".into(),
            ..task("t-1", TaskState::Pending, None, &["testing"])
        };

        store.put_task(record).await.unwrap();

        let fetched = store.get_task(&id).await.unwrap().unwrap();
        assert_eq!(fetched.prompt, "write tests");
        assert_eq!(fetched.state, TaskState::Pending);

        let all = store.list_tasks(&TaskFilter::default()).await.unwrap();
        assert_eq!(all.len(), 1);

        let deleted = store.delete_task(&id).await.unwrap();
        assert!(deleted);
        assert!(store.get_task(&id).await.unwrap().is_none());

        // Deleting an ID that is no longer stored reports that nothing was removed.
        assert!(!store.delete_task(&id).await.unwrap());
    }

    #[tokio::test]
    async fn put_task_replaces_existing_record() {
        let store = InMemoryStore::new();

        store
            .put_task(task("t-1", TaskState::Pending, None, &[]))
            .await
            .unwrap();
        store
            .put_task(task("t-1", TaskState::Completed, None, &[]))
            .await
            .unwrap();

        let all = store.list_tasks(&TaskFilter::default()).await.unwrap();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].state, TaskState::Completed);
    }

    #[tokio::test]
    async fn slot_crud() {
        let store = InMemoryStore::new();
        let id = SlotId("w-0".into());

        let record = SlotRecord {
            id: id.clone(),
            state: SlotState::Idle,
            config: SlotConfig::default(),
            current_task: None,
            session_id: None,
            tasks_completed: 0,
            cost_microdollars: 0,
            restart_count: 0,
            worktree_path: None,
            mcp_config_path: None,
        };

        store.put_slot(record).await.unwrap();

        let fetched = store.get_slot(&id).await.unwrap().unwrap();
        assert_eq!(fetched.state, SlotState::Idle);

        let all = store.list_slots().await.unwrap();
        assert_eq!(all.len(), 1);

        let deleted = store.delete_slot(&id).await.unwrap();
        assert!(deleted);
        assert!(store.get_slot(&id).await.unwrap().is_none());

        // Deleting an ID that is no longer stored reports that nothing was removed.
        assert!(!store.delete_slot(&id).await.unwrap());
    }

    #[tokio::test]
    async fn task_filter_by_state() {
        let store = InMemoryStore::new();

        for i in 0..3 {
            let state = if i == 0 {
                TaskState::Pending
            } else {
                TaskState::Completed
            };
            store
                .put_task(task(&format!("t-{i}"), state, None, &[]))
                .await
                .unwrap();
        }

        let pending = store
            .list_tasks(&TaskFilter {
                state: Some(TaskState::Pending),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(pending.len(), 1);

        let completed = store
            .list_tasks(&TaskFilter {
                state: Some(TaskState::Completed),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(completed.len(), 2);
    }

    #[tokio::test]
    async fn task_filter_by_slot_and_tags() {
        let store = InMemoryStore::new();
        let slot = SlotId("w-0".into());

        store
            .put_task(task("t-0", TaskState::Pending, Some(slot.clone()), &["a"]))
            .await
            .unwrap();
        store
            .put_task(task("t-1", TaskState::Pending, None, &["a", "b"]))
            .await
            .unwrap();
        store
            .put_task(task("t-2", TaskState::Completed, Some(slot.clone()), &["c"]))
            .await
            .unwrap();

        let on_slot = store
            .list_tasks(&TaskFilter {
                slot_id: Some(slot.clone()),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(on_slot.len(), 2);

        // Tag matching is "any": t-1 carries "b" and t-2 carries "c".
        let tagged = store
            .list_tasks(&TaskFilter {
                tags: Some(["b", "c"].iter().map(|tag| (*tag).to_string()).collect()),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(tagged.len(), 2);

        // Criteria combine with AND: only t-0 is both pending and on the slot.
        let combined = store
            .list_tasks(&TaskFilter {
                state: Some(TaskState::Pending),
                slot_id: Some(slot),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(combined.len(), 1);
        assert_eq!(combined[0].id.0, "t-0");
    }
}