Skip to main content

ralph_api/
task_domain.rs

1use std::collections::{BTreeMap, HashSet};
2use std::path::{Path, PathBuf};
3
4use chrono::Utc;
5use serde::{Deserialize, Serialize};
6
7use crate::errors::ApiError;
8use crate::loop_support::now_ts;
9
10mod storage;
11
12#[derive(Debug, Clone, Deserialize)]
13#[serde(rename_all = "camelCase")]
14pub struct TaskListParams {
15    pub status: Option<String>,
16    pub include_archived: Option<bool>,
17}
18
19#[derive(Debug, Clone, Deserialize)]
20#[serde(rename_all = "camelCase")]
21pub struct TaskCreateParams {
22    pub id: String,
23    pub title: String,
24    pub status: Option<String>,
25    pub priority: Option<u8>,
26    pub blocked_by: Option<String>,
27    pub auto_execute: Option<bool>,
28    pub merge_loop_prompt: Option<String>,
29}
30
/// Partial update applied by [`TaskDomain::update`]; `None` fields are left unchanged.
#[derive(Clone, Debug)]
pub struct TaskUpdateInput {
    /// Id of the task to modify.
    pub id: String,
    /// Replacement title.
    pub title: Option<String>,
    /// Replacement status; `update` also reconciles the dependent bookkeeping fields.
    pub status: Option<String>,
    /// Replacement priority (clamped to 1..=5 by `update`).
    pub priority: Option<u8>,
    /// Tri-state blocker edit: `None` keeps the current blocker, `Some(None)`
    /// clears it, and `Some(Some(id))` sets it.
    pub blocked_by: Option<Option<String>>,
}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41#[serde(rename_all = "camelCase")]
42pub struct TaskRecord {
43    pub id: String,
44    pub title: String,
45    pub status: String,
46    pub priority: u8,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub blocked_by: Option<String>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub archived_at: Option<String>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub queued_task_id: Option<String>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub merge_loop_prompt: Option<String>,
55    pub created_at: String,
56    pub updated_at: String,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub completed_at: Option<String>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub error_message: Option<String>,
61}
62
63#[derive(Debug, Clone, Serialize)]
64#[serde(rename_all = "camelCase")]
65pub struct TaskRunResult {
66    pub success: bool,
67    pub queued_task_id: String,
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub task: Option<TaskRecord>,
70}
71
72#[derive(Debug, Clone, Serialize)]
73#[serde(rename_all = "camelCase")]
74pub struct TaskRunAllResult {
75    pub enqueued: u64,
76    pub errors: Vec<String>,
77}
78
79#[derive(Debug, Clone, Serialize)]
80#[serde(rename_all = "camelCase")]
81pub struct TaskStatusResult {
82    pub is_queued: bool,
83    #[serde(skip_serializing_if = "Option::is_none")]
84    pub queue_position: Option<u64>,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub runner_pid: Option<u32>,
87}
88
89pub struct TaskDomain {
90    store_path: PathBuf,
91    tasks: BTreeMap<String, TaskRecord>,
92    queue_counter: u64,
93}
94
95impl TaskDomain {
96    pub fn new(workspace_root: impl AsRef<Path>) -> Self {
97        let store_path = workspace_root.as_ref().join(".ralph/api/tasks-v1.json");
98        let mut domain = Self {
99            store_path,
100            tasks: BTreeMap::new(),
101            queue_counter: 0,
102        };
103        domain.load();
104        domain
105    }
106
107    pub fn list(&self, params: TaskListParams) -> Vec<TaskRecord> {
108        let include_archived = params.include_archived.unwrap_or(false);
109        let mut tasks = self.sorted_tasks();
110
111        if let Some(status) = params.status {
112            tasks.retain(|task| task.status == status);
113        }
114
115        if !include_archived {
116            tasks.retain(|task| task.archived_at.is_none());
117        }
118
119        tasks
120    }
121
122    pub fn get(&self, id: &str) -> Result<TaskRecord, ApiError> {
123        self.tasks
124            .get(id)
125            .cloned()
126            .ok_or_else(|| task_not_found_error(id))
127    }
128
129    pub fn ready(&self) -> Vec<TaskRecord> {
130        let unblocking_ids = self.unblocking_ids();
131        let mut tasks: Vec<_> = self
132            .tasks
133            .values()
134            .filter(|task| task.status == "open" && task.archived_at.is_none())
135            .filter(|task| {
136                task.blocked_by
137                    .as_ref()
138                    .is_none_or(|blocker_id| unblocking_ids.contains(blocker_id))
139            })
140            .cloned()
141            .collect();
142
143        tasks.sort_by(|a, b| a.created_at.cmp(&b.created_at));
144        tasks
145    }
146
147    pub fn create(&mut self, params: TaskCreateParams) -> Result<TaskRecord, ApiError> {
148        if self.tasks.contains_key(&params.id) {
149            return Err(
150                ApiError::conflict(format!("Task with id '{}' already exists", params.id))
151                    .with_details(serde_json::json!({ "taskId": params.id })),
152            );
153        }
154
155        let requested_status = params.status.unwrap_or_else(|| "open".to_string());
156        let auto_execute = params.auto_execute.unwrap_or(true);
157
158        if auto_execute && requested_status != "open" {
159            return Err(ApiError::invalid_params(
160                "task.create autoExecute=true is only valid when status is 'open'",
161            )
162            .with_details(serde_json::json!({
163                "taskId": params.id,
164                "status": requested_status,
165                "autoExecute": auto_execute,
166            })));
167        }
168
169        let now = now_ts();
170        let completed_at = is_terminal_status(&requested_status).then_some(now.clone());
171
172        let task = TaskRecord {
173            id: params.id.clone(),
174            title: params.title,
175            status: requested_status,
176            priority: params.priority.unwrap_or(2).clamp(1, 5),
177            blocked_by: params.blocked_by,
178            archived_at: None,
179            queued_task_id: None,
180            merge_loop_prompt: params.merge_loop_prompt,
181            created_at: now.clone(),
182            updated_at: now,
183            completed_at,
184            error_message: None,
185        };
186
187        let task_id = task.id.clone();
188        self.tasks.insert(task_id.clone(), task);
189
190        let should_auto_execute = auto_execute
191            && self
192                .tasks
193                .get(&task_id)
194                .is_some_and(|task| task.blocked_by.is_none() && task.status == "open");
195
196        if should_auto_execute {
197            let _ = self.run(&task_id)?;
198        } else {
199            self.persist()?;
200        }
201
202        self.get(&task_id)
203    }
204
205    pub fn update(&mut self, input: TaskUpdateInput) -> Result<TaskRecord, ApiError> {
206        let now = now_ts();
207        let task = self
208            .tasks
209            .get_mut(&input.id)
210            .ok_or_else(|| task_not_found_error(&input.id))?;
211
212        if let Some(title) = input.title {
213            task.title = title;
214        }
215        if let Some(status) = input.status {
216            task.status = status;
217
218            if is_terminal_status(&task.status) {
219                task.completed_at = Some(now.clone());
220                task.queued_task_id = None;
221            } else {
222                task.completed_at = None;
223                if !matches!(task.status.as_str(), "pending" | "running") {
224                    task.queued_task_id = None;
225                }
226            }
227
228            if task.status != "failed" {
229                task.error_message = None;
230            }
231        }
232        if let Some(priority) = input.priority {
233            task.priority = priority.clamp(1, 5);
234        }
235        if let Some(blocked_by) = input.blocked_by {
236            task.blocked_by = blocked_by;
237        }
238
239        task.updated_at = now;
240        self.persist()?;
241        self.get(&input.id)
242    }
243
244    pub fn close(&mut self, id: &str) -> Result<TaskRecord, ApiError> {
245        self.transition_task(id, "closed")
246    }
247
248    pub fn archive(&mut self, id: &str) -> Result<TaskRecord, ApiError> {
249        let task = self
250            .tasks
251            .get_mut(id)
252            .ok_or_else(|| task_not_found_error(id))?;
253
254        task.archived_at = Some(now_ts());
255        task.updated_at = now_ts();
256        self.persist()?;
257        self.get(id)
258    }
259
260    pub fn unarchive(&mut self, id: &str) -> Result<TaskRecord, ApiError> {
261        let task = self
262            .tasks
263            .get_mut(id)
264            .ok_or_else(|| task_not_found_error(id))?;
265
266        task.archived_at = None;
267        task.updated_at = now_ts();
268        self.persist()?;
269        self.get(id)
270    }
271
272    pub fn delete(&mut self, id: &str) -> Result<(), ApiError> {
273        let task = self.tasks.get(id).ok_or_else(|| task_not_found_error(id))?;
274
275        if !matches!(task.status.as_str(), "failed" | "closed") {
276            return Err(ApiError::precondition_failed(format!(
277                "Cannot delete task in '{}' state. Only failed or closed tasks can be deleted.",
278                task.status
279            ))
280            .with_details(serde_json::json!({
281                "taskId": id,
282                "status": task.status,
283                "allowedStatuses": ["failed", "closed"]
284            })));
285        }
286
287        self.tasks.remove(id);
288        self.persist()?;
289        Ok(())
290    }
291
292    pub fn clear(&mut self) -> Result<(), ApiError> {
293        self.tasks.clear();
294        self.persist()?;
295        Ok(())
296    }
297
298    pub fn run(&mut self, id: &str) -> Result<TaskRunResult, ApiError> {
299        let queued_task_id = self.queue_task(id)?;
300        Ok(TaskRunResult {
301            success: true,
302            queued_task_id,
303            task: Some(self.get(id)?),
304        })
305    }
306
307    pub fn run_all(&mut self) -> TaskRunAllResult {
308        let ready_task_ids: Vec<String> = self.ready().into_iter().map(|task| task.id).collect();
309        let mut enqueued = 0_u64;
310        let mut errors = Vec::new();
311
312        for task_id in ready_task_ids {
313            match self.queue_task(&task_id) {
314                Ok(_) => {
315                    enqueued = enqueued.saturating_add(1);
316                }
317                Err(error) => {
318                    errors.push(format!("{task_id}: {}", error.message));
319                }
320            }
321        }
322
323        TaskRunAllResult { enqueued, errors }
324    }
325
326    pub fn retry(&mut self, id: &str) -> Result<TaskRunResult, ApiError> {
327        {
328            let task = self
329                .tasks
330                .get_mut(id)
331                .ok_or_else(|| task_not_found_error(id))?;
332
333            if task.status != "failed" {
334                return Err(
335                    ApiError::precondition_failed("Only failed tasks can be retried").with_details(
336                        serde_json::json!({
337                            "taskId": id,
338                            "status": task.status,
339                        }),
340                    ),
341                );
342            }
343
344            let now = now_ts();
345            task.status = "open".to_string();
346            task.queued_task_id = None;
347            task.completed_at = None;
348            task.error_message = None;
349            task.updated_at = now;
350        }
351
352        self.run(id)
353    }
354
355    pub fn cancel(&mut self, id: &str) -> Result<TaskRecord, ApiError> {
356        let task = self
357            .tasks
358            .get_mut(id)
359            .ok_or_else(|| task_not_found_error(id))?;
360
361        if !matches!(task.status.as_str(), "pending" | "running") {
362            return Err(ApiError::precondition_failed(
363                "Only running or pending tasks can be cancelled",
364            )
365            .with_details(serde_json::json!({
366                "taskId": id,
367                "status": task.status,
368            })));
369        }
370
371        let now = now_ts();
372        task.status = "failed".to_string();
373        task.completed_at = Some(now.clone());
374        task.updated_at = now;
375        task.error_message = Some("Task cancelled by user".to_string());
376        task.queued_task_id = None;
377
378        self.persist()?;
379        self.get(id)
380    }
381
382    pub fn status(&self, id: &str) -> TaskStatusResult {
383        let Some(task) = self.tasks.get(id) else {
384            return TaskStatusResult {
385                is_queued: false,
386                queue_position: None,
387                runner_pid: None,
388            };
389        };
390
391        let is_queued =
392            task.queued_task_id.is_some() && matches!(task.status.as_str(), "pending" | "running");
393
394        let queue_position = if is_queued {
395            self.queue_position(id)
396        } else {
397            None
398        };
399
400        let runner_pid = if task.status == "running" {
401            Some(std::process::id())
402        } else {
403            None
404        };
405
406        TaskStatusResult {
407            is_queued,
408            queue_position,
409            runner_pid,
410        }
411    }
412
413    fn transition_task(&mut self, id: &str, status: &str) -> Result<TaskRecord, ApiError> {
414        let task = self
415            .tasks
416            .get_mut(id)
417            .ok_or_else(|| task_not_found_error(id))?;
418
419        let now = now_ts();
420        task.status = status.to_string();
421        task.updated_at = now.clone();
422
423        if is_terminal_status(status) {
424            task.completed_at = Some(now);
425            task.queued_task_id = None;
426        } else {
427            task.completed_at = None;
428            if !matches!(status, "pending" | "running") {
429                task.queued_task_id = None;
430            }
431        }
432
433        if status != "failed" {
434            task.error_message = None;
435        }
436
437        self.persist()?;
438        self.get(id)
439    }
440
441    fn queue_task(&mut self, id: &str) -> Result<String, ApiError> {
442        let queued_task_id = self.next_queued_task_id();
443        let now = now_ts();
444
445        let task = self
446            .tasks
447            .get_mut(id)
448            .ok_or_else(|| task_not_found_error(id))?;
449
450        if task.archived_at.is_some() {
451            return Err(
452                ApiError::precondition_failed("Cannot run archived task").with_details(
453                    serde_json::json!({
454                        "taskId": id,
455                    }),
456                ),
457            );
458        }
459
460        if matches!(task.status.as_str(), "pending" | "running") {
461            return Err(
462                ApiError::precondition_failed("Task is already queued or running").with_details(
463                    serde_json::json!({
464                        "taskId": id,
465                        "status": task.status
466                    }),
467                ),
468            );
469        }
470
471        task.status = "pending".to_string();
472        task.queued_task_id = Some(queued_task_id.clone());
473        task.completed_at = None;
474        task.error_message = None;
475        task.updated_at = now;
476        self.persist()?;
477
478        Ok(queued_task_id)
479    }
480
481    fn queue_position(&self, id: &str) -> Option<u64> {
482        let mut queued: Vec<&TaskRecord> = self
483            .tasks
484            .values()
485            .filter(|task| {
486                task.queued_task_id.is_some()
487                    && matches!(task.status.as_str(), "pending" | "running")
488            })
489            .collect();
490        queued.sort_by(|a, b| a.updated_at.cmp(&b.updated_at));
491
492        queued
493            .iter()
494            .position(|task| task.id == id)
495            .map(|index| index as u64)
496    }
497
498    fn unblocking_ids(&self) -> HashSet<String> {
499        self.tasks
500            .values()
501            .filter(|task| task.status == "closed" || task.archived_at.is_some())
502            .map(|task| task.id.clone())
503            .collect()
504    }
505
506    fn next_queued_task_id(&mut self) -> String {
507        self.queue_counter = self.queue_counter.saturating_add(1);
508        format!(
509            "queued-{}-{:04x}",
510            Utc::now().timestamp_millis(),
511            self.queue_counter
512        )
513    }
514
515    fn sorted_tasks(&self) -> Vec<TaskRecord> {
516        let mut tasks: Vec<_> = self.tasks.values().cloned().collect();
517        tasks.sort_by(|a, b| a.created_at.cmp(&b.created_at));
518        tasks
519    }
520}
521
522fn task_not_found_error(task_id: &str) -> ApiError {
523    ApiError::task_not_found(format!("Task with id '{task_id}' not found"))
524        .with_details(serde_json::json!({ "taskId": task_id }))
525}
526
/// Returns true for statuses that end a task's lifecycle ("closed" or "failed").
/// The comparison is exact and case-sensitive.
fn is_terminal_status(status: &str) -> bool {
    const TERMINAL_STATUSES: [&str; 2] = ["closed", "failed"];
    TERMINAL_STATUSES.contains(&status)
}