1use std::fmt;
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5use time::OffsetDateTime;
6
7use crate::events::{ThreadId, TurnId};
8use crate::extension::TaskExecutorId;
9use crate::processes::ProcessRegistrySink;
10use crate::remote_runner::{RemoteRunnerSession, RunnerDestination};
11use crate::{ToolSchemaPolicy, normalize_tool_schema};
12
13pub type TaskId = String;
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16pub struct TaskSpec {
17 pub kind: String,
18 pub description: String,
19 pub input_schema: serde_json::Value,
20 #[serde(default, skip_serializing_if = "Option::is_none")]
21 pub default_timeout_seconds: Option<u64>,
22 #[serde(default)]
23 pub metadata: serde_json::Value,
24}
25
26impl TaskSpec {
27 pub fn normalized_for_model(&self, policy: ToolSchemaPolicy) -> Self {
28 let mut spec = self.clone();
29 spec.input_schema = normalize_tool_schema(&spec.kind, &spec.input_schema, policy).schema;
30 spec
31 }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
35#[serde(rename_all = "snake_case")]
36pub enum TaskState {
37 Queued,
38 Running,
39 Completed,
40 Failed,
41 Cancelled,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
45pub struct TaskHandle {
46 pub task_id: TaskId,
47 pub executor_id: TaskExecutorId,
48 pub spec: TaskSpec,
49 pub state: TaskState,
50 #[serde(with = "time::serde::rfc3339")]
51 pub created_at: OffsetDateTime,
52 #[serde(default, with = "time::serde::rfc3339::option")]
53 pub started_at: Option<OffsetDateTime>,
54 #[serde(default, with = "time::serde::rfc3339::option")]
55 pub finished_at: Option<OffsetDateTime>,
56}
57
58#[derive(Clone)]
59pub struct TaskExecutionContext {
60 pub task_id: TaskId,
61 pub thread_id: Option<ThreadId>,
62 pub turn_id: Option<TurnId>,
63 pub workspace_root: Option<String>,
64 pub runner_destination: Option<RunnerDestination>,
65 pub runner_session: Option<Arc<dyn RemoteRunnerSession>>,
66 pub deadline: Option<OffsetDateTime>,
67 pub metadata: serde_json::Value,
68 pub process_registry: Option<Arc<dyn ProcessRegistrySink>>,
69 pub output: TaskOutputSink,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
73pub struct TaskExecutionResult {
74 #[serde(default, skip_serializing_if = "Option::is_none")]
75 pub exit_code: Option<i32>,
76 #[serde(default)]
77 pub payload: serde_json::Value,
78}
79
80impl TaskExecutionResult {
81 pub fn success(payload: serde_json::Value) -> Self {
82 Self {
83 exit_code: None,
84 payload,
85 }
86 }
87}
88
89impl fmt::Debug for TaskExecutionContext {
90 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91 f.debug_struct("TaskExecutionContext")
92 .field("task_id", &self.task_id)
93 .field("thread_id", &self.thread_id)
94 .field("turn_id", &self.turn_id)
95 .field("workspace_root", &self.workspace_root)
96 .field("runner_destination", &self.runner_destination)
97 .field(
98 "runner_session",
99 &self.runner_session.as_ref().map(|session| session.state()),
100 )
101 .field("deadline", &self.deadline)
102 .field("metadata", &self.metadata)
103 .field(
104 "process_registry",
105 &self.process_registry.as_ref().map(|_| "<process-registry>"),
106 )
107 .finish_non_exhaustive()
108 }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
112#[serde(rename_all = "snake_case")]
113pub enum TaskOutputStream {
114 Stdout,
115 Stderr,
116 Log,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
120pub struct TaskStarted {
121 pub task_id: TaskId,
122 pub executor_id: TaskExecutorId,
123 pub task_kind: String,
124 #[serde(default)]
125 pub queue_depth: usize,
126 #[serde(default, skip_serializing_if = "Option::is_none")]
127 pub thread_id: Option<ThreadId>,
128 #[serde(default, skip_serializing_if = "Option::is_none")]
129 pub turn_id: Option<TurnId>,
130 #[serde(with = "time::serde::rfc3339")]
131 pub timestamp: OffsetDateTime,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
135pub struct TaskOutput {
136 pub task_id: TaskId,
137 pub stream: TaskOutputStream,
138 pub chunk: String,
139 #[serde(default)]
140 pub dropped_bytes: u64,
141 #[serde(default, skip_serializing_if = "Option::is_none")]
142 pub thread_id: Option<ThreadId>,
143 #[serde(default, skip_serializing_if = "Option::is_none")]
144 pub turn_id: Option<TurnId>,
145 #[serde(with = "time::serde::rfc3339")]
146 pub timestamp: OffsetDateTime,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
150pub struct TaskCompleted {
151 pub task_id: TaskId,
152 #[serde(default, skip_serializing_if = "Option::is_none")]
153 pub exit_code: Option<i32>,
154 #[serde(default)]
155 pub payload: serde_json::Value,
156 #[serde(default, skip_serializing_if = "Option::is_none")]
157 pub thread_id: Option<ThreadId>,
158 #[serde(default, skip_serializing_if = "Option::is_none")]
159 pub turn_id: Option<TurnId>,
160 #[serde(with = "time::serde::rfc3339")]
161 pub timestamp: OffsetDateTime,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
165pub struct TaskFailed {
166 pub task_id: TaskId,
167 pub error: String,
168 #[serde(default, skip_serializing_if = "Option::is_none")]
169 pub error_kind: Option<String>,
170 #[serde(default, skip_serializing_if = "Option::is_none")]
171 pub partial_result: Option<String>,
172 #[serde(default, skip_serializing_if = "Option::is_none")]
173 pub thread_id: Option<ThreadId>,
174 #[serde(default, skip_serializing_if = "Option::is_none")]
175 pub turn_id: Option<TurnId>,
176 #[serde(with = "time::serde::rfc3339")]
177 pub timestamp: OffsetDateTime,
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
181pub struct TaskCancelled {
182 pub task_id: TaskId,
183 #[serde(default, skip_serializing_if = "Option::is_none")]
184 pub reason: Option<String>,
185 #[serde(default, skip_serializing_if = "Option::is_none")]
186 pub thread_id: Option<ThreadId>,
187 #[serde(default, skip_serializing_if = "Option::is_none")]
188 pub turn_id: Option<TurnId>,
189 #[serde(with = "time::serde::rfc3339")]
190 pub timestamp: OffsetDateTime,
191}
192
193#[derive(Clone)]
194pub struct TaskOutputSink {
195 writer: Arc<dyn TaskOutputWriter>,
196}
197
198impl Default for TaskOutputSink {
199 fn default() -> Self {
200 Self {
201 writer: Arc::new(NoopTaskOutputWriter),
202 }
203 }
204}
205
206impl fmt::Debug for TaskOutputSink {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 f.debug_struct("TaskOutputSink").finish_non_exhaustive()
209 }
210}
211
212impl TaskOutputSink {
213 pub fn new(writer: Arc<dyn TaskOutputWriter>) -> Self {
214 Self { writer }
215 }
216
217 pub async fn write(
218 &self,
219 stream: TaskOutputStream,
220 chunk: impl Into<String>,
221 ) -> anyhow::Result<()> {
222 self.writer.write(stream, chunk.into()).await
223 }
224}
225
226#[async_trait::async_trait]
227pub trait TaskOutputWriter: Send + Sync + 'static {
228 async fn write(&self, stream: TaskOutputStream, chunk: String) -> anyhow::Result<()>;
229}
230
231struct NoopTaskOutputWriter;
232
233#[async_trait::async_trait]
234impl TaskOutputWriter for NoopTaskOutputWriter {
235 async fn write(&self, _stream: TaskOutputStream, _chunk: String) -> anyhow::Result<()> {
236 Ok(())
237 }
238}
239
240#[async_trait::async_trait]
241pub trait TaskExecutor: Send + Sync + 'static {
242 fn id(&self) -> TaskExecutorId;
243
244 fn spec(&self) -> TaskSpec;
245
246 async fn execute(
247 &self,
248 ctx: TaskExecutionContext,
249 input: serde_json::Value,
250 ) -> anyhow::Result<TaskExecutionResult>;
251}
252
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;

    /// Serializes `value` to a `serde_json::Value` and decodes it back.
    fn json_round_trip<T>(value: &T) -> T
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let encoded = serde_json::to_value(value).expect("serialize to json value");
        serde_json::from_value(encoded).expect("deserialize from json value")
    }

    /// Executor used to prove `TaskExecutor` can be driven as `Arc<dyn TaskExecutor>`.
    struct NoopTaskExecutor;

    #[async_trait::async_trait]
    impl TaskExecutor for NoopTaskExecutor {
        fn id(&self) -> TaskExecutorId {
            "noop-task".to_string()
        }

        fn spec(&self) -> TaskSpec {
            TaskSpec {
                kind: "noop".to_string(),
                description: "No-op task".to_string(),
                input_schema: serde_json::json!({ "type": "object" }),
                default_timeout_seconds: Some(30),
                metadata: serde_json::json!({ "category": "test" }),
            }
        }

        async fn execute(
            &self,
            ctx: TaskExecutionContext,
            input: serde_json::Value,
        ) -> anyhow::Result<TaskExecutionResult> {
            Ok(TaskExecutionResult::success(serde_json::json!({
                "task_id": ctx.task_id,
                "input": input,
            })))
        }
    }

    /// Writer that records every chunk it receives, for asserting on `TaskOutputSink`.
    #[derive(Default)]
    struct RecordingWriter {
        chunks: Mutex<Vec<(TaskOutputStream, String)>>,
    }

    #[async_trait::async_trait]
    impl TaskOutputWriter for RecordingWriter {
        async fn write(&self, stream: TaskOutputStream, chunk: String) -> anyhow::Result<()> {
            // The guard is a statement temporary, so it is never held across an await.
            self.chunks
                .lock()
                .expect("recording writer mutex poisoned")
                .push((stream, chunk));
            Ok(())
        }
    }

    #[test]
    fn task_handle_round_trips_json() {
        let handle = TaskHandle {
            task_id: "task-1".to_string(),
            executor_id: "process".to_string(),
            spec: TaskSpec {
                kind: "process".to_string(),
                description: "Run a process".to_string(),
                input_schema: serde_json::json!({ "type": "object" }),
                default_timeout_seconds: Some(60),
                metadata: serde_json::json!({}),
            },
            state: TaskState::Queued,
            created_at: OffsetDateTime::UNIX_EPOCH,
            started_at: None,
            finished_at: None,
        };

        let encoded = serde_json::to_string(&handle).expect("serialize task handle");
        let decoded: TaskHandle = serde_json::from_str(&encoded).expect("deserialize task handle");

        assert_eq!(decoded, handle);
    }

    #[test]
    fn task_events_round_trip_json() {
        let started = TaskStarted {
            task_id: "task-1".to_string(),
            executor_id: "process".to_string(),
            task_kind: "process".to_string(),
            queue_depth: 0,
            thread_id: Some("thread-a".to_string()),
            turn_id: Some("turn-a".to_string()),
            timestamp: OffsetDateTime::UNIX_EPOCH,
        };
        let output = TaskOutput {
            task_id: "task-1".to_string(),
            stream: TaskOutputStream::Stdout,
            chunk: "hello\n".to_string(),
            dropped_bytes: 0,
            thread_id: Some("thread-a".to_string()),
            turn_id: Some("turn-a".to_string()),
            timestamp: OffsetDateTime::UNIX_EPOCH,
        };
        let completed = TaskCompleted {
            task_id: "task-1".to_string(),
            exit_code: Some(0),
            payload: serde_json::json!({ "ok": true }),
            thread_id: Some("thread-a".to_string()),
            turn_id: Some("turn-a".to_string()),
            timestamp: OffsetDateTime::UNIX_EPOCH,
        };
        let failed = TaskFailed {
            task_id: "task-2".to_string(),
            error: "process exited with status 1".to_string(),
            error_kind: Some("exit_status".to_string()),
            partial_result: Some("partial output".to_string()),
            thread_id: None,
            turn_id: None,
            timestamp: OffsetDateTime::UNIX_EPOCH,
        };
        let cancelled = TaskCancelled {
            task_id: "task-3".to_string(),
            reason: Some("user requested".to_string()),
            thread_id: Some("thread-a".to_string()),
            turn_id: None,
            timestamp: OffsetDateTime::UNIX_EPOCH,
        };

        assert_eq!(json_round_trip(&started), started);
        assert_eq!(json_round_trip(&output), output);
        assert_eq!(json_round_trip(&completed), completed);
        assert_eq!(json_round_trip(&failed), failed);
        assert_eq!(json_round_trip(&cancelled), cancelled);
    }

    #[test]
    fn task_enums_serialize_as_snake_case() {
        assert_eq!(
            serde_json::to_value(TaskState::Queued).unwrap(),
            serde_json::json!("queued")
        );
        assert_eq!(
            serde_json::to_value(TaskState::Cancelled).unwrap(),
            serde_json::json!("cancelled")
        );
        assert_eq!(
            serde_json::from_value::<TaskState>(serde_json::json!("running")).unwrap(),
            TaskState::Running
        );
        assert_eq!(
            serde_json::to_value(TaskOutputStream::Stderr).unwrap(),
            serde_json::json!("stderr")
        );
    }

    #[test]
    fn optional_event_fields_are_omitted_and_defaulted() {
        let cancelled = TaskCancelled {
            task_id: "task-1".to_string(),
            reason: None,
            thread_id: None,
            turn_id: None,
            timestamp: OffsetDateTime::UNIX_EPOCH,
        };
        let encoded = serde_json::to_value(&cancelled).expect("serialize cancelled event");
        for key in ["reason", "thread_id", "turn_id"] {
            assert!(encoded.get(key).is_none(), "`{key}` should be omitted when None");
        }

        let started: TaskStarted = serde_json::from_value(serde_json::json!({
            "task_id": "task-1",
            "executor_id": "process",
            "task_kind": "process",
            "timestamp": "1970-01-01T00:00:00Z",
        }))
        .expect("deserialize minimal task started event");
        assert_eq!(started.queue_depth, 0);
        assert_eq!(started.thread_id, None);
        assert_eq!(started.turn_id, None);
        assert_eq!(started.timestamp, OffsetDateTime::UNIX_EPOCH);
    }

    #[test]
    fn execution_result_success_has_no_exit_code() {
        let result = TaskExecutionResult::success(serde_json::json!({ "n": 1 }));
        assert_eq!(result.exit_code, None);

        let encoded = serde_json::to_value(&result).expect("serialize execution result");
        assert!(encoded.get("exit_code").is_none());
        assert_eq!(encoded["payload"]["n"], 1);
        assert_eq!(json_round_trip(&result), result);
    }

    #[tokio::test]
    async fn output_sink_forwards_chunks_to_writer() {
        let recorder = Arc::new(RecordingWriter::default());
        let writer: Arc<dyn TaskOutputWriter> = recorder.clone();
        let sink = TaskOutputSink::new(writer);

        sink.write(TaskOutputStream::Stdout, "hello")
            .await
            .expect("write stdout chunk");
        sink.write(TaskOutputStream::Stderr, String::from("oops"))
            .await
            .expect("write stderr chunk");

        let chunks = recorder
            .chunks
            .lock()
            .expect("recording writer mutex poisoned")
            .clone();
        assert_eq!(
            chunks,
            vec![
                (TaskOutputStream::Stdout, "hello".to_string()),
                (TaskOutputStream::Stderr, "oops".to_string()),
            ]
        );
    }

    #[tokio::test]
    async fn default_output_sink_discards_chunks() {
        TaskOutputSink::default()
            .write(TaskOutputStream::Log, "ignored")
            .await
            .expect("noop writer never fails");
    }

    #[tokio::test]
    async fn task_executor_trait_is_object_safe() {
        let executor: Arc<dyn TaskExecutor> = Arc::new(NoopTaskExecutor);
        let result = executor
            .execute(
                TaskExecutionContext {
                    task_id: "task-1".to_string(),
                    thread_id: None,
                    turn_id: None,
                    workspace_root: None,
                    runner_destination: None,
                    runner_session: None,
                    deadline: None,
                    metadata: serde_json::json!({}),
                    process_registry: None,
                    output: TaskOutputSink::default(),
                },
                serde_json::json!({ "ok": true }),
            )
            .await
            .unwrap();

        assert_eq!(executor.id(), "noop-task");
        assert_eq!(executor.spec().kind, "noop");
        assert_eq!(result.payload["task_id"], "task-1");
    }
}