1use async_trait::async_trait;
8use dashmap::DashMap;
9
10use crate::error::Result;
11use crate::types::*;
12
/// Storage backend for the pool's task and slot records.
///
/// Provides basic put/get/list/delete operations for [`TaskRecord`]s and
/// [`SlotRecord`]s. The `Send + Sync` bound allows a store to be shared
/// between threads.
///
/// The per-method notes below describe how [`InMemoryStore`] behaves.
/// NOTE(review): other backends presumably follow the same contract — confirm.
#[async_trait]
pub trait PoolStore: Send + Sync {
    /// Stores `record`. [`InMemoryStore`] replaces any existing task with the same id.
    async fn put_task(&self, record: TaskRecord) -> Result<()>;

    /// Returns the task stored under `id`, or `Ok(None)` if there is none.
    async fn get_task(&self, id: &TaskId) -> Result<Option<TaskRecord>>;

    /// Returns every task matching `filter`.
    ///
    /// In [`InMemoryStore`], a filter field left as `None` imposes no constraint.
    async fn list_tasks(&self, filter: &TaskFilter) -> Result<Vec<TaskRecord>>;

    /// Removes the task stored under `id`. Returns `Ok(true)` if a record was removed.
    async fn delete_task(&self, id: &TaskId) -> Result<bool>;

    /// Stores `record`. [`InMemoryStore`] replaces any existing slot with the same id.
    async fn put_slot(&self, record: SlotRecord) -> Result<()>;

    /// Returns the slot stored under `id`, or `Ok(None)` if there is none.
    async fn get_slot(&self, id: &SlotId) -> Result<Option<SlotRecord>>;

    /// Returns every stored slot.
    async fn list_slots(&self) -> Result<Vec<SlotRecord>>;

    /// Removes the slot stored under `id`. Returns `Ok(true)` if a record was removed.
    async fn delete_slot(&self, id: &SlotId) -> Result<bool>;
}
42
/// [`PoolStore`] implementation that keeps all records in process memory.
///
/// Records live in two `DashMap`s keyed by the inner string of their id.
/// Nothing is persisted: the data exists only as long as the store value.
#[derive(Debug, Default)]
pub struct InMemoryStore {
    // Task records keyed by `TaskRecord::id.0`.
    tasks: DashMap<String, TaskRecord>,
    // Slot records keyed by `SlotRecord::id.0`.
    slots: DashMap<String, SlotRecord>,
}
52
53impl InMemoryStore {
54 pub fn new() -> Self {
56 Self::default()
57 }
58}
59
60#[async_trait]
61impl PoolStore for InMemoryStore {
62 async fn put_task(&self, record: TaskRecord) -> Result<()> {
63 self.tasks.insert(record.id.0.clone(), record);
64 Ok(())
65 }
66
67 async fn get_task(&self, id: &TaskId) -> Result<Option<TaskRecord>> {
68 Ok(self.tasks.get(&id.0).map(|r| r.value().clone()))
69 }
70
71 async fn list_tasks(&self, filter: &TaskFilter) -> Result<Vec<TaskRecord>> {
72 let tasks: Vec<TaskRecord> = self
73 .tasks
74 .iter()
75 .map(|r| r.value().clone())
76 .filter(|t| {
77 if let Some(state) = filter.state
78 && t.state != state
79 {
80 return false;
81 }
82 if let Some(ref wid) = filter.slot_id
83 && t.slot_id.as_ref() != Some(wid)
84 {
85 return false;
86 }
87 if let Some(ref tags) = filter.tags
88 && !tags.iter().any(|tag| t.tags.contains(tag))
89 {
90 return false;
91 }
92 true
93 })
94 .collect();
95 Ok(tasks)
96 }
97
98 async fn delete_task(&self, id: &TaskId) -> Result<bool> {
99 Ok(self.tasks.remove(&id.0).is_some())
100 }
101
102 async fn put_slot(&self, record: SlotRecord) -> Result<()> {
103 self.slots.insert(record.id.0.clone(), record);
104 Ok(())
105 }
106
107 async fn get_slot(&self, id: &SlotId) -> Result<Option<SlotRecord>> {
108 Ok(self.slots.get(&id.0).map(|r| r.value().clone()))
109 }
110
111 async fn list_slots(&self) -> Result<Vec<SlotRecord>> {
112 Ok(self.slots.iter().map(|r| r.value().clone()).collect())
113 }
114
115 async fn delete_slot(&self, id: &SlotId) -> Result<bool> {
116 Ok(self.slots.remove(&id.0).is_some())
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a task record with the given id and state and neutral defaults
    /// (no slot, no tags). Tests override individual fields with struct-update
    /// syntax.
    fn task(id: &str, state: TaskState) -> TaskRecord {
        TaskRecord {
            id: TaskId(id.into()),
            prompt: format!("task {id}"),
            state,
            slot_id: None,
            result: None,
            tags: vec![],
            config: None,
            review_required: false,
            max_rejections: 3,
            rejection_count: 0,
            original_prompt: None,
        }
    }

    #[tokio::test]
    async fn task_crud() {
        let store = InMemoryStore::new();
        let id = TaskId("t-1".into());

        let record = TaskRecord {
            prompt: "write tests".into(),
            tags: vec!["testing".into()],
            ..task("t-1", TaskState::Pending)
        };

        store.put_task(record).await.unwrap();

        let fetched = store.get_task(&id).await.unwrap().unwrap();
        assert_eq!(fetched.prompt, "write tests");
        assert_eq!(fetched.state, TaskState::Pending);

        let all = store.list_tasks(&TaskFilter::default()).await.unwrap();
        assert_eq!(all.len(), 1);

        let deleted = store.delete_task(&id).await.unwrap();
        assert!(deleted);
        assert!(store.get_task(&id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn put_task_overwrites_same_id() {
        let store = InMemoryStore::new();

        store.put_task(task("t-1", TaskState::Pending)).await.unwrap();
        store.put_task(task("t-1", TaskState::Completed)).await.unwrap();

        let all = store.list_tasks(&TaskFilter::default()).await.unwrap();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].state, TaskState::Completed);
    }

    #[tokio::test]
    async fn missing_entries() {
        let store = InMemoryStore::new();
        let task_id = TaskId("absent".into());
        let slot_id = SlotId("absent".into());

        assert!(store.get_task(&task_id).await.unwrap().is_none());
        assert!(!store.delete_task(&task_id).await.unwrap());
        assert!(store.get_slot(&slot_id).await.unwrap().is_none());
        assert!(!store.delete_slot(&slot_id).await.unwrap());
    }

    #[tokio::test]
    async fn slot_crud() {
        let store = InMemoryStore::new();
        let id = SlotId("w-0".into());

        let record = SlotRecord {
            id: id.clone(),
            state: SlotState::Idle,
            config: SlotConfig::default(),
            current_task: None,
            session_id: None,
            tasks_completed: 0,
            cost_microdollars: 0,
            restart_count: 0,
            worktree_path: None,
            mcp_config_path: None,
        };

        store.put_slot(record).await.unwrap();

        let fetched = store.get_slot(&id).await.unwrap().unwrap();
        assert_eq!(fetched.state, SlotState::Idle);

        let all = store.list_slots().await.unwrap();
        assert_eq!(all.len(), 1);

        let deleted = store.delete_slot(&id).await.unwrap();
        assert!(deleted);
        assert!(store.get_slot(&id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn task_filter_by_state() {
        let store = InMemoryStore::new();

        for i in 0..3 {
            let state = if i == 0 {
                TaskState::Pending
            } else {
                TaskState::Completed
            };
            store
                .put_task(task(&format!("t-{i}"), state))
                .await
                .unwrap();
        }

        let pending = store
            .list_tasks(&TaskFilter {
                state: Some(TaskState::Pending),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(pending.len(), 1);

        let completed = store
            .list_tasks(&TaskFilter {
                state: Some(TaskState::Completed),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(completed.len(), 2);
    }

    #[tokio::test]
    async fn task_filter_by_slot() {
        let store = InMemoryStore::new();

        store
            .put_task(TaskRecord {
                slot_id: Some(SlotId("w-0".into())),
                ..task("t-0", TaskState::Pending)
            })
            .await
            .unwrap();
        store.put_task(task("t-1", TaskState::Pending)).await.unwrap();

        let on_slot = store
            .list_tasks(&TaskFilter {
                slot_id: Some(SlotId("w-0".into())),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(on_slot.len(), 1);
        assert_eq!(on_slot[0].id.0, "t-0");
    }
}