Skip to main content

envoy/task/
store.rs

1use sqlitegraph::GraphEntity;
2
3use crate::error::{EnvoyError, Result};
4use crate::task::{Task, TaskState, KIND_TASK};
5
/// Stateless entry point for task persistence.
///
/// `TaskStore` holds no data of its own: every operation takes the
/// [`sqlitegraph::SqliteGraph`] to read from and write to, so the handle is
/// zero-sized and can be copied freely.
#[derive(Debug, Clone, Copy)]
pub struct TaskStore;
7
8impl Default for TaskStore {
9    fn default() -> Self {
10        Self::new()
11    }
12}
13
14impl TaskStore {
15    pub fn new() -> Self {
16        Self
17    }
18
19    pub fn propose(
20        &self,
21        graph: &sqlitegraph::SqliteGraph,
22        project: String,
23        description: String,
24        blocked_by: Vec<String>,
25    ) -> Result<Task> {
26        let now = chrono::Utc::now().to_rfc3339();
27        let name = format!("task-{}", uuid::Uuid::new_v4());
28        let blocked_json: Vec<serde_json::Value> =
29            blocked_by.iter().map(|s| serde_json::json!(s)).collect();
30        let entity = GraphEntity {
31            id: 0,
32            kind: KIND_TASK.to_string(),
33            name,
34            file_path: None,
35            data: serde_json::json!({
36                "project": project,
37                "description": description,
38                "state": "proposed",
39                "claimed_by": null,
40                "blocked_by": blocked_json,
41                "checkpoint": null,
42                "created_at": now,
43                "updated_at": now,
44            }),
45        };
46        let id = graph.insert_entity(&entity)?;
47        Ok(Task {
48            id: id.to_string(),
49            project,
50            description,
51            state: TaskState::Proposed,
52            claimed_by: None,
53            blocked_by,
54            checkpoint: None,
55            created_at: now.clone(),
56            updated_at: now,
57        })
58    }
59
60    pub fn claim(
61        &self,
62        graph: &sqlitegraph::SqliteGraph,
63        task_id: &str,
64        agent_id: String,
65    ) -> Result<Task> {
66        let task = self.get(graph, task_id)?;
67        if task.state != TaskState::Proposed {
68            return Err(EnvoyError::TaskAlreadyClaimed(
69                task_id.to_string(),
70                task.claimed_by.unwrap_or_default(),
71            ));
72        }
73        let now = chrono::Utc::now().to_rfc3339();
74        self.update_entity(graph, task_id, |data| {
75            data["state"] = serde_json::json!("claimed");
76            data["claimed_by"] = serde_json::json!(&agent_id);
77            data["updated_at"] = serde_json::json!(&now);
78            data["checkpoint"] = serde_json::json!("claimed");
79        })?;
80        Ok(Task {
81            state: TaskState::Claimed,
82            claimed_by: Some(agent_id),
83            updated_at: now,
84            ..task
85        })
86    }
87
88    pub fn claim_next(
89        &self,
90        graph: &sqlitegraph::SqliteGraph,
91        project: &str,
92        agent_id: String,
93    ) -> Result<Task> {
94        let tasks = self.list(graph, project, Some(&TaskState::Proposed))?;
95        if tasks.is_empty() {
96            return Err(EnvoyError::TaskNotFound("no proposed tasks".into()));
97        }
98        let oldest = tasks
99            .into_iter()
100            .min_by_key(|t| t.created_at.clone())
101            .ok_or_else(|| EnvoyError::TaskNotFound("no proposed tasks after filter".into()))?;
102        self.claim(graph, &oldest.id, agent_id)
103    }
104
105    pub fn update_state(
106        &self,
107        graph: &sqlitegraph::SqliteGraph,
108        task_id: &str,
109        new_state: TaskState,
110        checkpoint: Option<String>,
111        agent_id: Option<&str>,
112    ) -> Result<Task> {
113        let task = self.get(graph, task_id)?;
114        if !task.state.can_transition_to(&new_state) {
115            return Err(EnvoyError::InvalidTaskState {
116                task_id: task_id.to_string(),
117                from: task.state.as_str().to_string(),
118                to: new_state.as_str().to_string(),
119            });
120        }
121        if let (Some(ref claimant), Some(agent)) = (&task.claimed_by, agent_id) {
122            if claimant != agent && new_state != TaskState::Proposed {
123                return Err(EnvoyError::NotTaskClaimant {
124                    agent: agent.to_string(),
125                    task_id: task_id.to_string(),
126                });
127            }
128        }
129        let now = chrono::Utc::now().to_rfc3339();
130        self.update_entity(graph, task_id, |data| {
131            data["state"] = serde_json::json!(new_state.as_str());
132            data["updated_at"] = serde_json::json!(&now);
133            if let Some(ref cp) = checkpoint {
134                data["checkpoint"] = serde_json::json!(cp);
135            }
136        })?;
137        Ok(Task {
138            state: new_state,
139            checkpoint: checkpoint.or(task.checkpoint),
140            updated_at: now,
141            ..task
142        })
143    }
144
145    pub fn get(&self, graph: &sqlitegraph::SqliteGraph, task_id: &str) -> Result<Task> {
146        let id: i64 = task_id
147            .parse()
148            .map_err(|_| EnvoyError::TaskNotFound(task_id.to_string()))?;
149        let entity = graph
150            .get_entity(id)
151            .map_err(|_| EnvoyError::TaskNotFound(task_id.to_string()))?;
152        if entity.kind != KIND_TASK {
153            return Err(EnvoyError::TaskNotFound(task_id.to_string()));
154        }
155        entity_to_task(&entity)
156    }
157
158    pub fn list(
159        &self,
160        graph: &sqlitegraph::SqliteGraph,
161        project: &str,
162        state_filter: Option<&TaskState>,
163    ) -> Result<Vec<Task>> {
164        let entities = graph.find_entities_by_kind(KIND_TASK)?;
165        let mut tasks: Vec<Task> = entities
166            .iter()
167            .filter(|e| read_str(&e.data, "project") == project)
168            .filter_map(|e| entity_to_task(e).ok())
169            .filter(|t| state_filter.is_none_or(|f| t.state == *f))
170            .collect();
171        tasks.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
172        Ok(tasks)
173    }
174
175    pub fn find_blocked_by(
176        &self,
177        graph: &sqlitegraph::SqliteGraph,
178        blocker_id: &str,
179    ) -> Result<Vec<Task>> {
180        let entities = graph.find_entities_by_kind(KIND_TASK)?;
181        Ok(entities
182            .iter()
183            .filter(|e| {
184                read_json_array(&e.data, "blocked_by")
185                    .iter()
186                    .any(|b| b == blocker_id)
187            })
188            .filter_map(|e| entity_to_task(e).ok())
189            .collect())
190    }
191
192    pub fn reclaim_stale(
193        &self,
194        graph: &sqlitegraph::SqliteGraph,
195        stale_agent_id: &str,
196    ) -> Result<Vec<String>> {
197        let entities = graph.find_entities_by_kind(KIND_TASK)?;
198        let mut reclaimed = Vec::new();
199        for e in &entities {
200            let claimed_by = read_str(&e.data, "claimed_by");
201            let state = read_str(&e.data, "state");
202            if claimed_by == stale_agent_id && (state == "claimed" || state == "in_progress") {
203                let now = chrono::Utc::now().to_rfc3339();
204                let mut entity = e.clone();
205                entity.data["state"] = serde_json::json!("proposed");
206                entity.data["claimed_by"] = serde_json::json!(null);
207                entity.data["checkpoint"] = serde_json::json!(null);
208                entity.data["updated_at"] = serde_json::json!(&now);
209                graph.update_entity(&entity)?;
210                reclaimed.push(entity.id.to_string());
211            }
212        }
213        Ok(reclaimed)
214    }
215
216    fn update_entity(
217        &self,
218        graph: &sqlitegraph::SqliteGraph,
219        task_id: &str,
220        updater: impl FnOnce(&mut serde_json::Value),
221    ) -> Result<()> {
222        let id: i64 = task_id
223            .parse()
224            .map_err(|_| EnvoyError::TaskNotFound(task_id.to_string()))?;
225        let mut entity = graph
226            .get_entity(id)
227            .map_err(|_| EnvoyError::TaskNotFound(task_id.to_string()))?;
228        if entity.kind != KIND_TASK {
229            return Err(EnvoyError::TaskNotFound(task_id.to_string()));
230        }
231        updater(&mut entity.data);
232        graph.update_entity(&entity)?;
233        Ok(())
234    }
235}
236
237fn entity_to_task(entity: &sqlitegraph::GraphEntity) -> Result<Task> {
238    Ok(Task {
239        id: entity.id.to_string(),
240        project: read_str(&entity.data, "project"),
241        description: read_str(&entity.data, "description"),
242        state: read_str(&entity.data, "state")
243            .parse()
244            .unwrap_or(TaskState::Proposed),
245        claimed_by: entity
246            .data
247            .get("claimed_by")
248            .and_then(|v| v.as_str())
249            .filter(|s| !s.is_empty())
250            .map(String::from),
251        blocked_by: read_json_array(&entity.data, "blocked_by"),
252        checkpoint: entity
253            .data
254            .get("checkpoint")
255            .and_then(|v| v.as_str())
256            .filter(|s| !s.is_empty())
257            .map(String::from),
258        created_at: read_str(&entity.data, "created_at"),
259        updated_at: read_str(&entity.data, "updated_at"),
260    })
261}
262
263fn read_str(data: &serde_json::Value, key: &str) -> String {
264    data.get(key)
265        .and_then(|v| v.as_str())
266        .unwrap_or("")
267        .to_string()
268}
269
270fn read_json_array(data: &serde_json::Value, key: &str) -> Vec<String> {
271    data.get(key)
272        .and_then(|v| v.as_array())
273        .map(|arr| {
274            arr.iter()
275                .filter_map(|v| v.as_str().map(String::from))
276                .collect()
277        })
278        .unwrap_or_default()
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::Engine;

    /// Proposes a task in `project` and panics if the store rejects it.
    fn seed(
        store: &TaskStore,
        graph: &sqlitegraph::SqliteGraph,
        project: &str,
        description: &str,
        blocked_by: Vec<String>,
    ) -> Task {
        store
            .propose(graph, project.to_string(), description.to_string(), blocked_by)
            .unwrap()
    }

    #[test]
    fn propose_and_claim() {
        let engine = Engine::open_in_memory().unwrap();
        let (graph, store) = (engine.graph(), TaskStore::new());

        let task = seed(&store, graph, "magellan", "fix", Vec::new());
        assert_eq!(task.state, TaskState::Proposed);

        let claimed = store.claim(graph, &task.id, "agent-1".to_string()).unwrap();
        assert_eq!(claimed.state, TaskState::Claimed);
        assert_eq!(claimed.claimed_by.as_deref(), Some("agent-1"));
    }

    #[test]
    fn reject_double_claim() {
        let engine = Engine::open_in_memory().unwrap();
        let (graph, store) = (engine.graph(), TaskStore::new());

        let task = seed(&store, graph, "m", "fix", Vec::new());
        store.claim(graph, &task.id, "a".to_string()).unwrap();
        let second = store.claim(graph, &task.id, "b".to_string());
        assert!(second.is_err());
    }

    #[test]
    fn state_transitions() {
        let engine = Engine::open_in_memory().unwrap();
        let (graph, store) = (engine.graph(), TaskStore::new());

        let task = seed(&store, graph, "m", "fix", Vec::new());
        store.claim(graph, &task.id, "a".to_string()).unwrap();
        let updated = store
            .update_state(
                graph,
                &task.id,
                TaskState::InProgress,
                Some("impl".to_string()),
                Some("a"),
            )
            .unwrap();
        assert_eq!(updated.state, TaskState::InProgress);
    }

    #[test]
    fn reject_invalid_transition() {
        let engine = Engine::open_in_memory().unwrap();
        let (graph, store) = (engine.graph(), TaskStore::new());

        let task = seed(&store, graph, "m", "fix", Vec::new());
        let result = store.update_state(graph, &task.id, TaskState::Done, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn find_blocked_by() {
        let engine = Engine::open_in_memory().unwrap();
        let (graph, store) = (engine.graph(), TaskStore::new());

        let blocker = seed(&store, graph, "m", "A", Vec::new());
        seed(&store, graph, "m", "B", vec![blocker.id.clone()]);
        let dependents = store.find_blocked_by(graph, &blocker.id).unwrap();
        assert_eq!(dependents.len(), 1);
    }

    #[test]
    fn claim_next_oldest_proposed() {
        let engine = Engine::open_in_memory().unwrap();
        let (graph, store) = (engine.graph(), TaskStore::new());

        let first = seed(&store, graph, "m", "A", Vec::new());
        let second = seed(&store, graph, "m", "B", Vec::new());
        let next = store.claim_next(graph, "m", "agent-1".to_string()).unwrap();

        // The earliest-created proposal is the one handed out.
        assert_eq!(next.id, first.id);
        assert_eq!(next.state, TaskState::Claimed);
        assert_eq!(next.claimed_by.as_deref(), Some("agent-1"));

        // The later proposal is left untouched.
        let untouched = store.get(graph, &second.id).unwrap();
        assert_eq!(untouched.state, TaskState::Proposed);
    }

    #[test]
    fn claim_next_empty_project() {
        let engine = Engine::open_in_memory().unwrap();
        let (graph, store) = (engine.graph(), TaskStore::new());

        let outcome = store.claim_next(graph, "no-project", "a".to_string());
        assert!(outcome.is_err());
    }

    #[test]
    fn reclaim_stale() {
        let engine = Engine::open_in_memory().unwrap();
        let (graph, store) = (engine.graph(), TaskStore::new());

        let task = seed(&store, graph, "m", "fix", Vec::new());
        store.claim(graph, &task.id, "agent-1".to_string()).unwrap();

        let reclaimed = store.reclaim_stale(graph, "agent-1").unwrap();
        assert_eq!(reclaimed.len(), 1);

        let after = store.get(graph, &task.id).unwrap();
        assert_eq!(after.state, TaskState::Proposed);
    }
}