1use async_trait::async_trait;
8use dashmap::DashMap;
9
10use crate::error::Result;
11use crate::types::*;
12
13#[async_trait]
17pub trait PoolStore: Send + Sync {
18 async fn put_task(&self, record: TaskRecord) -> Result<()>;
20
21 async fn get_task(&self, id: &TaskId) -> Result<Option<TaskRecord>>;
23
24 async fn list_tasks(&self, filter: &TaskFilter) -> Result<Vec<TaskRecord>>;
26
27 async fn delete_task(&self, id: &TaskId) -> Result<bool>;
29
30 async fn put_slot(&self, record: SlotRecord) -> Result<()>;
32
33 async fn get_slot(&self, id: &SlotId) -> Result<Option<SlotRecord>>;
35
36 async fn list_slots(&self) -> Result<Vec<SlotRecord>>;
38
39 async fn delete_slot(&self, id: &SlotId) -> Result<bool>;
41}
42
43#[derive(Debug, Default)]
48pub struct InMemoryStore {
49 tasks: DashMap<String, TaskRecord>,
50 slots: DashMap<String, SlotRecord>,
51}
52
53impl InMemoryStore {
54 pub fn new() -> Self {
56 Self::default()
57 }
58}
59
60#[async_trait]
61impl PoolStore for InMemoryStore {
62 async fn put_task(&self, record: TaskRecord) -> Result<()> {
63 self.tasks.insert(record.id.0.clone(), record);
64 Ok(())
65 }
66
67 async fn get_task(&self, id: &TaskId) -> Result<Option<TaskRecord>> {
68 Ok(self.tasks.get(&id.0).map(|r| r.value().clone()))
69 }
70
71 async fn list_tasks(&self, filter: &TaskFilter) -> Result<Vec<TaskRecord>> {
72 let tasks: Vec<TaskRecord> = self
73 .tasks
74 .iter()
75 .map(|r| r.value().clone())
76 .filter(|t| {
77 if let Some(state) = filter.state
78 && t.state != state
79 {
80 return false;
81 }
82 if let Some(ref wid) = filter.slot_id
83 && t.slot_id.as_ref() != Some(wid)
84 {
85 return false;
86 }
87 if let Some(ref tags) = filter.tags
88 && !tags.iter().any(|tag| t.tags.contains(tag))
89 {
90 return false;
91 }
92 true
93 })
94 .collect();
95 Ok(tasks)
96 }
97
98 async fn delete_task(&self, id: &TaskId) -> Result<bool> {
99 Ok(self.tasks.remove(&id.0).is_some())
100 }
101
102 async fn put_slot(&self, record: SlotRecord) -> Result<()> {
103 self.slots.insert(record.id.0.clone(), record);
104 Ok(())
105 }
106
107 async fn get_slot(&self, id: &SlotId) -> Result<Option<SlotRecord>> {
108 Ok(self.slots.get(&id.0).map(|r| r.value().clone()))
109 }
110
111 async fn list_slots(&self) -> Result<Vec<SlotRecord>> {
112 Ok(self.slots.iter().map(|r| r.value().clone()).collect())
113 }
114
115 async fn delete_slot(&self, id: &SlotId) -> Result<bool> {
116 Ok(self.slots.remove(&id.0).is_some())
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an untagged task with the given id, state and optional slot.
    fn task(id: &str, state: TaskState, slot: Option<&str>) -> TaskRecord {
        TaskRecord {
            id: TaskId(id.into()),
            prompt: format!("task {id}"),
            state,
            slot_id: slot.map(|s| SlotId(s.into())),
            result: None,
            tags: vec![],
            config: None,
        }
    }

    #[tokio::test]
    async fn task_crud() {
        let store = InMemoryStore::new();
        let id = TaskId("t-1".into());

        let record = TaskRecord {
            id: id.clone(),
            prompt: "write tests".into(),
            state: TaskState::Pending,
            slot_id: None,
            result: None,
            tags: vec!["testing".into()],
            config: None,
        };

        store.put_task(record).await.unwrap();

        let fetched = store.get_task(&id).await.unwrap().unwrap();
        assert_eq!(fetched.prompt, "write tests");
        assert_eq!(fetched.state, TaskState::Pending);

        let all = store.list_tasks(&TaskFilter::default()).await.unwrap();
        assert_eq!(all.len(), 1);

        let deleted = store.delete_task(&id).await.unwrap();
        assert!(deleted);
        assert!(store.get_task(&id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn slot_crud() {
        let store = InMemoryStore::new();
        let id = SlotId("w-0".into());

        let record = SlotRecord {
            id: id.clone(),
            state: SlotState::Idle,
            config: SlotConfig::default(),
            current_task: None,
            session_id: None,
            tasks_completed: 0,
            cost_microdollars: 0,
            restart_count: 0,
            worktree_path: None,
            mcp_config_path: None,
        };

        store.put_slot(record).await.unwrap();

        let fetched = store.get_slot(&id).await.unwrap().unwrap();
        assert_eq!(fetched.state, SlotState::Idle);

        let all = store.list_slots().await.unwrap();
        assert_eq!(all.len(), 1);

        let deleted = store.delete_slot(&id).await.unwrap();
        assert!(deleted);
        assert!(store.get_slot(&id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn empty_store_lists_nothing() {
        let store = InMemoryStore::new();
        assert!(store
            .list_tasks(&TaskFilter::default())
            .await
            .unwrap()
            .is_empty());
        assert!(store.list_slots().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn delete_missing_returns_false() {
        let store = InMemoryStore::new();
        assert!(!store.delete_task(&TaskId("nope".into())).await.unwrap());
        assert!(!store.delete_slot(&SlotId("nope".into())).await.unwrap());
    }

    #[tokio::test]
    async fn put_task_replaces_existing() {
        let store = InMemoryStore::new();
        store
            .put_task(task("t-0", TaskState::Pending, None))
            .await
            .unwrap();
        store
            .put_task(task("t-0", TaskState::Completed, None))
            .await
            .unwrap();

        let all = store.list_tasks(&TaskFilter::default()).await.unwrap();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].state, TaskState::Completed);
    }

    #[tokio::test]
    async fn task_filter_by_state() {
        let store = InMemoryStore::new();

        for i in 0..3 {
            let state = if i == 0 {
                TaskState::Pending
            } else {
                TaskState::Completed
            };
            store
                .put_task(task(&format!("t-{i}"), state, None))
                .await
                .unwrap();
        }

        let pending = store
            .list_tasks(&TaskFilter {
                state: Some(TaskState::Pending),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(pending.len(), 1);

        let completed = store
            .list_tasks(&TaskFilter {
                state: Some(TaskState::Completed),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(completed.len(), 2);
    }

    #[tokio::test]
    async fn task_filter_by_slot() {
        let store = InMemoryStore::new();
        store
            .put_task(task("t-0", TaskState::Pending, Some("w-0")))
            .await
            .unwrap();
        store
            .put_task(task("t-1", TaskState::Pending, Some("w-1")))
            .await
            .unwrap();
        store
            .put_task(task("t-2", TaskState::Pending, None))
            .await
            .unwrap();

        let on_w0 = store
            .list_tasks(&TaskFilter {
                slot_id: Some(SlotId("w-0".into())),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(on_w0.len(), 1);
        assert_eq!(on_w0[0].id.0, "t-0");

        // Criteria combine with AND: the slot matches but the state does not.
        let none = store
            .list_tasks(&TaskFilter {
                state: Some(TaskState::Completed),
                slot_id: Some(SlotId("w-0".into())),
                ..Default::default()
            })
            .await
            .unwrap();
        assert!(none.is_empty());
    }
}