Skip to main content

roder_api/
tasks.rs

1use std::fmt;
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5use time::OffsetDateTime;
6
7use crate::events::{ThreadId, TurnId};
8use crate::extension::TaskExecutorId;
9use crate::processes::ProcessRegistrySink;
10use crate::remote_runner::{RemoteRunnerSession, RunnerDestination};
11use crate::{ToolSchemaPolicy, normalize_tool_schema};
12
/// Identifier of a single task instance, represented as a plain owned string.
pub type TaskId = String;
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16pub struct TaskSpec {
17    pub kind: String,
18    pub description: String,
19    pub input_schema: serde_json::Value,
20    #[serde(default, skip_serializing_if = "Option::is_none")]
21    pub default_timeout_seconds: Option<u64>,
22    #[serde(default)]
23    pub metadata: serde_json::Value,
24}
25
26impl TaskSpec {
27    pub fn normalized_for_model(&self, policy: ToolSchemaPolicy) -> Self {
28        let mut spec = self.clone();
29        spec.input_schema = normalize_tool_schema(&spec.kind, &spec.input_schema, policy).schema;
30        spec
31    }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
35#[serde(rename_all = "snake_case")]
36pub enum TaskState {
37    Queued,
38    Running,
39    Completed,
40    Failed,
41    Cancelled,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
45pub struct TaskHandle {
46    pub task_id: TaskId,
47    pub executor_id: TaskExecutorId,
48    pub spec: TaskSpec,
49    pub state: TaskState,
50    #[serde(with = "time::serde::rfc3339")]
51    pub created_at: OffsetDateTime,
52    #[serde(default, with = "time::serde::rfc3339::option")]
53    pub started_at: Option<OffsetDateTime>,
54    #[serde(default, with = "time::serde::rfc3339::option")]
55    pub finished_at: Option<OffsetDateTime>,
56}
57
58#[derive(Clone)]
59pub struct TaskExecutionContext {
60    pub task_id: TaskId,
61    pub thread_id: Option<ThreadId>,
62    pub turn_id: Option<TurnId>,
63    pub workspace_root: Option<String>,
64    pub runner_destination: Option<RunnerDestination>,
65    pub runner_session: Option<Arc<dyn RemoteRunnerSession>>,
66    pub deadline: Option<OffsetDateTime>,
67    pub metadata: serde_json::Value,
68    pub process_registry: Option<Arc<dyn ProcessRegistrySink>>,
69    pub output: TaskOutputSink,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
73pub struct TaskExecutionResult {
74    #[serde(default, skip_serializing_if = "Option::is_none")]
75    pub exit_code: Option<i32>,
76    #[serde(default)]
77    pub payload: serde_json::Value,
78}
79
80impl TaskExecutionResult {
81    pub fn success(payload: serde_json::Value) -> Self {
82        Self {
83            exit_code: None,
84            payload,
85        }
86    }
87}
88
89impl fmt::Debug for TaskExecutionContext {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        f.debug_struct("TaskExecutionContext")
92            .field("task_id", &self.task_id)
93            .field("thread_id", &self.thread_id)
94            .field("turn_id", &self.turn_id)
95            .field("workspace_root", &self.workspace_root)
96            .field("runner_destination", &self.runner_destination)
97            .field(
98                "runner_session",
99                &self.runner_session.as_ref().map(|session| session.state()),
100            )
101            .field("deadline", &self.deadline)
102            .field("metadata", &self.metadata)
103            .field(
104                "process_registry",
105                &self.process_registry.as_ref().map(|_| "<process-registry>"),
106            )
107            .finish_non_exhaustive()
108    }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
112#[serde(rename_all = "snake_case")]
113pub enum TaskOutputStream {
114    Stdout,
115    Stderr,
116    Log,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
120pub struct TaskStarted {
121    pub task_id: TaskId,
122    pub executor_id: TaskExecutorId,
123    pub task_kind: String,
124    #[serde(default)]
125    pub queue_depth: usize,
126    #[serde(default, skip_serializing_if = "Option::is_none")]
127    pub thread_id: Option<ThreadId>,
128    #[serde(default, skip_serializing_if = "Option::is_none")]
129    pub turn_id: Option<TurnId>,
130    #[serde(with = "time::serde::rfc3339")]
131    pub timestamp: OffsetDateTime,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
135pub struct TaskOutput {
136    pub task_id: TaskId,
137    pub stream: TaskOutputStream,
138    pub chunk: String,
139    #[serde(default)]
140    pub dropped_bytes: u64,
141    #[serde(default, skip_serializing_if = "Option::is_none")]
142    pub thread_id: Option<ThreadId>,
143    #[serde(default, skip_serializing_if = "Option::is_none")]
144    pub turn_id: Option<TurnId>,
145    #[serde(with = "time::serde::rfc3339")]
146    pub timestamp: OffsetDateTime,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
150pub struct TaskCompleted {
151    pub task_id: TaskId,
152    #[serde(default, skip_serializing_if = "Option::is_none")]
153    pub exit_code: Option<i32>,
154    #[serde(default)]
155    pub payload: serde_json::Value,
156    #[serde(default, skip_serializing_if = "Option::is_none")]
157    pub thread_id: Option<ThreadId>,
158    #[serde(default, skip_serializing_if = "Option::is_none")]
159    pub turn_id: Option<TurnId>,
160    #[serde(with = "time::serde::rfc3339")]
161    pub timestamp: OffsetDateTime,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
165pub struct TaskFailed {
166    pub task_id: TaskId,
167    pub error: String,
168    #[serde(default, skip_serializing_if = "Option::is_none")]
169    pub error_kind: Option<String>,
170    #[serde(default, skip_serializing_if = "Option::is_none")]
171    pub partial_result: Option<String>,
172    #[serde(default, skip_serializing_if = "Option::is_none")]
173    pub thread_id: Option<ThreadId>,
174    #[serde(default, skip_serializing_if = "Option::is_none")]
175    pub turn_id: Option<TurnId>,
176    #[serde(with = "time::serde::rfc3339")]
177    pub timestamp: OffsetDateTime,
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
181pub struct TaskCancelled {
182    pub task_id: TaskId,
183    #[serde(default, skip_serializing_if = "Option::is_none")]
184    pub reason: Option<String>,
185    #[serde(default, skip_serializing_if = "Option::is_none")]
186    pub thread_id: Option<ThreadId>,
187    #[serde(default, skip_serializing_if = "Option::is_none")]
188    pub turn_id: Option<TurnId>,
189    #[serde(with = "time::serde::rfc3339")]
190    pub timestamp: OffsetDateTime,
191}
192
193#[derive(Clone)]
194pub struct TaskOutputSink {
195    writer: Arc<dyn TaskOutputWriter>,
196}
197
198impl Default for TaskOutputSink {
199    fn default() -> Self {
200        Self {
201            writer: Arc::new(NoopTaskOutputWriter),
202        }
203    }
204}
205
206impl fmt::Debug for TaskOutputSink {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        f.debug_struct("TaskOutputSink").finish_non_exhaustive()
209    }
210}
211
212impl TaskOutputSink {
213    pub fn new(writer: Arc<dyn TaskOutputWriter>) -> Self {
214        Self { writer }
215    }
216
217    pub async fn write(
218        &self,
219        stream: TaskOutputStream,
220        chunk: impl Into<String>,
221    ) -> anyhow::Result<()> {
222        self.writer.write(stream, chunk.into()).await
223    }
224}
225
226#[async_trait::async_trait]
227pub trait TaskOutputWriter: Send + Sync + 'static {
228    async fn write(&self, stream: TaskOutputStream, chunk: String) -> anyhow::Result<()>;
229}
230
231struct NoopTaskOutputWriter;
232
233#[async_trait::async_trait]
234impl TaskOutputWriter for NoopTaskOutputWriter {
235    async fn write(&self, _stream: TaskOutputStream, _chunk: String) -> anyhow::Result<()> {
236        Ok(())
237    }
238}
239
240#[async_trait::async_trait]
241pub trait TaskExecutor: Send + Sync + 'static {
242    fn id(&self) -> TaskExecutorId;
243
244    fn spec(&self) -> TaskSpec;
245
246    async fn execute(
247        &self,
248        ctx: TaskExecutionContext,
249        input: serde_json::Value,
250    ) -> anyhow::Result<TaskExecutionResult>;
251}
252
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Minimal executor that echoes its task id and input back as the payload.
    struct NoopTaskExecutor;

    #[async_trait::async_trait]
    impl TaskExecutor for NoopTaskExecutor {
        fn id(&self) -> TaskExecutorId {
            "noop-task".to_string()
        }

        fn spec(&self) -> TaskSpec {
            TaskSpec {
                kind: "noop".to_string(),
                description: "No-op task".to_string(),
                input_schema: serde_json::json!({ "type": "object" }),
                default_timeout_seconds: Some(30),
                metadata: serde_json::json!({ "category": "test" }),
            }
        }

        async fn execute(
            &self,
            ctx: TaskExecutionContext,
            input: serde_json::Value,
        ) -> anyhow::Result<TaskExecutionResult> {
            Ok(TaskExecutionResult::success(serde_json::json!({
                "task_id": ctx.task_id,
                "input": input,
            })))
        }
    }

    /// Serializes `value` to a JSON value, decodes it again, and asserts the
    /// decoded copy equals the original.
    fn assert_json_round_trip<T>(value: &T)
    where
        T: serde::Serialize + serde::de::DeserializeOwned + PartialEq + std::fmt::Debug,
    {
        let encoded = serde_json::to_value(value).expect("serialize");
        let decoded: T = serde_json::from_value(encoded).expect("deserialize");
        assert_eq!(&decoded, value);
    }

    #[test]
    fn task_handle_round_trips_json() {
        let handle = TaskHandle {
            task_id: "task-1".to_string(),
            executor_id: "process".to_string(),
            spec: TaskSpec {
                kind: "process".to_string(),
                description: "Run a process".to_string(),
                input_schema: serde_json::json!({ "type": "object" }),
                default_timeout_seconds: Some(60),
                metadata: serde_json::json!({}),
            },
            state: TaskState::Queued,
            created_at: OffsetDateTime::UNIX_EPOCH,
            started_at: None,
            finished_at: None,
        };

        let encoded = serde_json::to_string(&handle).expect("serialize task handle");
        let decoded: TaskHandle = serde_json::from_str(&encoded).expect("deserialize task handle");

        assert_eq!(decoded, handle);
    }

    #[test]
    fn task_events_round_trip_json() {
        assert_json_round_trip(&TaskStarted {
            task_id: "task-1".to_string(),
            executor_id: "process".to_string(),
            task_kind: "process".to_string(),
            queue_depth: 0,
            thread_id: Some("thread-a".to_string()),
            turn_id: Some("turn-a".to_string()),
            timestamp: OffsetDateTime::UNIX_EPOCH,
        });
        assert_json_round_trip(&TaskOutput {
            task_id: "task-1".to_string(),
            stream: TaskOutputStream::Stdout,
            chunk: "hello\n".to_string(),
            dropped_bytes: 0,
            thread_id: Some("thread-a".to_string()),
            turn_id: Some("turn-a".to_string()),
            timestamp: OffsetDateTime::UNIX_EPOCH,
        });
        assert_json_round_trip(&TaskCompleted {
            task_id: "task-1".to_string(),
            exit_code: Some(0),
            payload: serde_json::json!({ "ok": true }),
            thread_id: Some("thread-a".to_string()),
            turn_id: None,
            timestamp: OffsetDateTime::UNIX_EPOCH,
        });
        assert_json_round_trip(&TaskFailed {
            task_id: "task-1".to_string(),
            error: "boom".to_string(),
            error_kind: Some("timeout".to_string()),
            partial_result: None,
            thread_id: None,
            turn_id: Some("turn-a".to_string()),
            timestamp: OffsetDateTime::UNIX_EPOCH,
        });
        assert_json_round_trip(&TaskCancelled {
            task_id: "task-1".to_string(),
            reason: Some("user request".to_string()),
            thread_id: None,
            turn_id: None,
            timestamp: OffsetDateTime::UNIX_EPOCH,
        });
    }

    #[test]
    fn enums_serialize_as_snake_case() {
        assert_eq!(
            serde_json::to_value(TaskState::Cancelled).unwrap(),
            serde_json::json!("cancelled")
        );
        assert_eq!(
            serde_json::from_value::<TaskState>(serde_json::json!("running")).unwrap(),
            TaskState::Running
        );
        assert_eq!(
            serde_json::to_value(TaskOutputStream::Stderr).unwrap(),
            serde_json::json!("stderr")
        );
    }

    #[test]
    fn absent_optional_fields_are_omitted_and_defaulted() {
        let cancelled = TaskCancelled {
            task_id: "task-1".to_string(),
            reason: None,
            thread_id: None,
            turn_id: None,
            timestamp: OffsetDateTime::UNIX_EPOCH,
        };
        let encoded = serde_json::to_value(&cancelled).unwrap();
        let object = encoded.as_object().expect("event serializes as a JSON object");
        for key in ["reason", "thread_id", "turn_id"] {
            assert!(!object.contains_key(key), "{key} should be omitted when None");
        }

        let started: TaskStarted = serde_json::from_value(serde_json::json!({
            "task_id": "task-1",
            "executor_id": "process",
            "task_kind": "process",
            "timestamp": "1970-01-01T00:00:00Z",
        }))
        .expect("deserialize task started without optional fields");
        assert_eq!(started.queue_depth, 0);
        assert!(started.thread_id.is_none());
        assert!(started.turn_id.is_none());
        assert_eq!(started.timestamp, OffsetDateTime::UNIX_EPOCH);
    }

    #[tokio::test]
    async fn default_output_sink_accepts_writes() {
        let sink = TaskOutputSink::default();
        sink.write(TaskOutputStream::Stdout, "hello").await.unwrap();
        sink.write(TaskOutputStream::Log, String::from("done"))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn task_executor_trait_is_object_safe() {
        let executor: Arc<dyn TaskExecutor> = Arc::new(NoopTaskExecutor);
        let result = executor
            .execute(
                TaskExecutionContext {
                    task_id: "task-1".to_string(),
                    thread_id: None,
                    turn_id: None,
                    workspace_root: None,
                    runner_destination: None,
                    runner_session: None,
                    deadline: None,
                    metadata: serde_json::json!({}),
                    process_registry: None,
                    output: TaskOutputSink::default(),
                },
                serde_json::json!({ "ok": true }),
            )
            .await
            .unwrap();

        assert_eq!(executor.id(), "noop-task");
        assert_eq!(executor.spec().kind, "noop");
        assert_eq!(result.payload["task_id"], "task-1");
    }
}