Skip to main content

roder_tasks/
runner.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use roder_api::events::{EventEnvelope, RoderEvent, ThreadId, TurnId};
5use roder_api::extension::TaskExecutorId;
6use roder_api::remote_runner::{RemoteRunnerSession, RunnerDestination};
7use roder_api::tasks::{
8    TaskCancelled, TaskCompleted, TaskExecutionContext, TaskFailed, TaskHandle, TaskId, TaskOutput,
9    TaskOutputSink, TaskOutputStream, TaskOutputWriter, TaskStarted, TaskState,
10};
11use time::OffsetDateTime;
12use tokio::sync::{Mutex, Semaphore, broadcast};
13use tokio::task::AbortHandle;
14
15use crate::log_buffer::{BoundedLogBuffer, TaskLogEntry};
16use crate::process_registry::ProcessRegistry;
17use crate::registry::TaskExecutorRegistry;
18
/// Tuning knobs for [`BackgroundRunner`].
#[derive(Debug, Clone)]
pub struct BackgroundRunnerConfig {
    /// Maximum number of tasks executing at the same time.
    /// `BackgroundRunner::new` clamps a value of 0 up to 1.
    pub max_concurrent: usize,
    /// Capacity handed to each task's `BoundedLogBuffer`.
    pub max_log_bytes: usize,
    /// Whether `BackgroundRunner::handle_event` cancels a thread's active tasks
    /// when that thread's session or turn ends.
    pub auto_cancel_on_session_end: bool,
}

impl Default for BackgroundRunnerConfig {
    /// Four concurrent tasks, a 64 KiB log buffer per task, auto-cancel enabled.
    fn default() -> Self {
        const KIB: usize = 1024;
        let max_log_bytes = 64 * KIB;
        BackgroundRunnerConfig {
            auto_cancel_on_session_end: true,
            max_concurrent: 4,
            max_log_bytes,
        }
    }
}
35
36#[derive(Clone, Default)]
37pub struct TaskSubmitOptions {
38    pub thread_id: Option<ThreadId>,
39    pub turn_id: Option<TurnId>,
40    pub workspace_root: Option<String>,
41    pub runner_destination: Option<RunnerDestination>,
42    pub runner_session: Option<Arc<dyn RemoteRunnerSession>>,
43    pub deadline: Option<OffsetDateTime>,
44    pub metadata: serde_json::Value,
45}
46
47#[derive(Clone)]
48pub struct BackgroundRunner {
49    registry: TaskExecutorRegistry,
50    config: BackgroundRunnerConfig,
51    semaphore: Arc<Semaphore>,
52    tasks: Arc<Mutex<BTreeMap<TaskId, TaskRecord>>>,
53    processes: ProcessRegistry,
54    events: broadcast::Sender<RoderEvent>,
55}
56
57struct TaskRecord {
58    handle: TaskHandle,
59    log: BoundedLogBuffer,
60    abort_handle: Option<AbortHandle>,
61    thread_id: Option<ThreadId>,
62    turn_id: Option<TurnId>,
63}
64
65impl BackgroundRunner {
66    pub fn new(registry: TaskExecutorRegistry, config: BackgroundRunnerConfig) -> Self {
67        let (events, _) = broadcast::channel(1024);
68        let processes = ProcessRegistry::default();
69        if tokio::runtime::Handle::try_current().is_ok() {
70            let mut process_events = processes.subscribe();
71            let task_events = events.clone();
72            tokio::spawn(async move {
73                while let Ok(event) = process_events.recv().await {
74                    let _ = task_events.send(event);
75                }
76            });
77        }
78        Self {
79            registry,
80            semaphore: Arc::new(Semaphore::new(config.max_concurrent.max(1))),
81            config,
82            tasks: Arc::new(Mutex::new(BTreeMap::new())),
83            processes,
84            events,
85        }
86    }
87
88    pub fn subscribe(&self) -> broadcast::Receiver<RoderEvent> {
89        self.events.subscribe()
90    }
91
92    pub fn processes(&self) -> ProcessRegistry {
93        self.processes.clone()
94    }
95
96    pub async fn submit(
97        &self,
98        executor_id: impl Into<TaskExecutorId>,
99        input: serde_json::Value,
100        options: TaskSubmitOptions,
101    ) -> anyhow::Result<TaskHandle> {
102        let executor_id = executor_id.into();
103        let executor = self
104            .registry
105            .get(&executor_id)
106            .ok_or_else(|| anyhow::anyhow!("unknown task executor {executor_id:?}"))?;
107        let spec = executor.spec();
108        let task_id = uuid::Uuid::new_v4().to_string();
109        let handle = TaskHandle {
110            task_id: task_id.clone(),
111            executor_id: executor_id.clone(),
112            spec: spec.clone(),
113            state: TaskState::Queued,
114            created_at: OffsetDateTime::now_utc(),
115            started_at: None,
116            finished_at: None,
117        };
118
119        {
120            let mut tasks = self.tasks.lock().await;
121            tasks.insert(
122                task_id.clone(),
123                TaskRecord {
124                    handle: handle.clone(),
125                    log: BoundedLogBuffer::new(self.config.max_log_bytes),
126                    abort_handle: None,
127                    thread_id: options.thread_id.clone(),
128                    turn_id: options.turn_id.clone(),
129                },
130            );
131        }
132
133        let runner = self.clone();
134        let task_id_for_spawn = task_id.clone();
135        let spawn_options = options.clone();
136        let join = tokio::spawn(async move {
137            runner
138                .run_task(
139                    task_id_for_spawn,
140                    executor_id,
141                    executor,
142                    input,
143                    spawn_options,
144                )
145                .await;
146        });
147        let abort_handle = join.abort_handle();
148        {
149            let mut tasks = self.tasks.lock().await;
150            if let Some(record) = tasks.get_mut(&task_id) {
151                record.abort_handle = Some(abort_handle);
152            }
153        }
154
155        Ok(handle)
156    }
157
158    pub async fn cancel(&self, task_id: &str, reason: Option<String>) -> anyhow::Result<bool> {
159        let cancelled = {
160            let mut tasks = self.tasks.lock().await;
161            let Some(record) = tasks.get_mut(task_id) else {
162                anyhow::bail!("unknown task {task_id:?}");
163            };
164            if matches!(
165                record.handle.state,
166                TaskState::Completed | TaskState::Failed | TaskState::Cancelled
167            ) {
168                return Ok(false);
169            }
170            record.handle.state = TaskState::Cancelled;
171            record.handle.finished_at = Some(OffsetDateTime::now_utc());
172            if let Some(abort_handle) = record.abort_handle.take() {
173                abort_handle.abort();
174            }
175            true
176        };
177
178        if cancelled {
179            self.emit(RoderEvent::TaskCancelled(TaskCancelled {
180                task_id: task_id.to_string(),
181                reason,
182                thread_id: self.thread_id(task_id).await,
183                turn_id: self.turn_id(task_id).await,
184                timestamp: OffsetDateTime::now_utc(),
185            }));
186        }
187
188        Ok(cancelled)
189    }
190
191    pub async fn list(&self) -> Vec<TaskHandle> {
192        self.tasks
193            .lock()
194            .await
195            .values()
196            .map(|record| record.handle.clone())
197            .collect()
198    }
199
200    pub async fn get(&self, task_id: &str) -> Option<TaskHandle> {
201        self.tasks
202            .lock()
203            .await
204            .get(task_id)
205            .map(|record| record.handle.clone())
206    }
207
208    pub async fn logs(&self, task_id: &str) -> Option<(Vec<TaskLogEntry>, u64)> {
209        self.tasks
210            .lock()
211            .await
212            .get(task_id)
213            .map(|record| (record.log.entries(), record.log.dropped_bytes()))
214    }
215
216    pub async fn handle_event(&self, envelope: &EventEnvelope) -> anyhow::Result<()> {
217        if !self.config.auto_cancel_on_session_end {
218            return Ok(());
219        }
220        if !matches!(
221            envelope.kind.as_str(),
222            "session.ended" | "turn.completed" | "turn.failed" | "turn.interrupted"
223        ) {
224            return Ok(());
225        }
226        let Some(thread_id) = envelope.thread_id.as_deref() else {
227            return Ok(());
228        };
229        let task_ids = {
230            self.tasks
231                .lock()
232                .await
233                .iter()
234                .filter_map(|(task_id, record)| {
235                    let active = !matches!(
236                        record.handle.state,
237                        TaskState::Completed | TaskState::Failed | TaskState::Cancelled
238                    );
239                    let same_thread =
240                        active && self.record_thread_id(record).as_deref() == Some(thread_id);
241                    same_thread.then(|| task_id.clone())
242                })
243                .collect::<Vec<_>>()
244        };
245        for task_id in task_ids {
246            self.cancel(&task_id, Some("session ended".to_string()))
247                .await?;
248        }
249        Ok(())
250    }
251
252    async fn run_task(
253        &self,
254        task_id: TaskId,
255        executor_id: TaskExecutorId,
256        executor: Arc<dyn roder_api::tasks::TaskExecutor>,
257        input: serde_json::Value,
258        options: TaskSubmitOptions,
259    ) {
260        let permit = match self.semaphore.clone().acquire_owned().await {
261            Ok(permit) => permit,
262            Err(_) => return,
263        };
264        let _permit = permit;
265
266        let queue_depth = {
267            let mut tasks = self.tasks.lock().await;
268            let queue_depth = tasks
269                .values()
270                .filter(|record| record.handle.state == TaskState::Queued)
271                .count()
272                .saturating_sub(1);
273            if let Some(record) = tasks.get_mut(&task_id) {
274                if record.handle.state == TaskState::Cancelled {
275                    return;
276                }
277                record.handle.state = TaskState::Running;
278                record.handle.started_at = Some(OffsetDateTime::now_utc());
279            }
280            queue_depth
281        };
282
283        self.emit(RoderEvent::TaskStarted(TaskStarted {
284            task_id: task_id.clone(),
285            executor_id,
286            task_kind: executor.spec().kind,
287            thread_id: options.thread_id.clone(),
288            turn_id: options.turn_id.clone(),
289            queue_depth,
290            timestamp: OffsetDateTime::now_utc(),
291        }));
292
293        let ctx = TaskExecutionContext {
294            task_id: task_id.clone(),
295            thread_id: options.thread_id.clone(),
296            turn_id: options.turn_id.clone(),
297            workspace_root: options.workspace_root,
298            runner_destination: options.runner_destination,
299            runner_session: options.runner_session,
300            deadline: options.deadline,
301            metadata: options.metadata,
302            process_registry: Some(Arc::new(self.processes.clone())),
303            output: TaskOutputSink::new(Arc::new(RunnerOutputWriter {
304                runner: self.clone(),
305                task_id: task_id.clone(),
306                thread_id: options.thread_id.clone(),
307                turn_id: options.turn_id.clone(),
308            })),
309        };
310
311        let mut timeout_partial_result = None;
312        let result = if let Some(deadline) = options.deadline {
313            let now = OffsetDateTime::now_utc();
314            let duration = (deadline - now).unsigned_abs();
315            let deadline_instant = if deadline > now {
316                tokio::time::Instant::now() + duration
317            } else {
318                tokio::time::Instant::now()
319            };
320            match tokio::time::timeout_at(deadline_instant, executor.execute(ctx, input)).await {
321                Ok(result) => result,
322                Err(_) => {
323                    let partial = self.partial_result(&task_id).await;
324                    timeout_partial_result = Some(partial.clone());
325                    self.emit(RoderEvent::TaskOutput(TaskOutput {
326                        task_id: task_id.clone(),
327                        stream: TaskOutputStream::Log,
328                        chunk: format!("task deadline expired; partial result: {partial}"),
329                        dropped_bytes: 0,
330                        thread_id: options.thread_id.clone(),
331                        turn_id: options.turn_id.clone(),
332                        timestamp: OffsetDateTime::now_utc(),
333                    }));
334                    Err(anyhow::anyhow!("task deadline expired"))
335                }
336            }
337        } else {
338            executor.execute(ctx, input).await
339        };
340
341        match result {
342            Ok(payload) => {
343                self.finish_task(&task_id, TaskState::Completed).await;
344                self.emit(RoderEvent::TaskCompleted(TaskCompleted {
345                    task_id,
346                    exit_code: payload.exit_code,
347                    payload: payload.payload,
348                    thread_id: options.thread_id,
349                    turn_id: options.turn_id,
350                    timestamp: OffsetDateTime::now_utc(),
351                }));
352            }
353            Err(error) => {
354                self.finish_task(&task_id, TaskState::Failed).await;
355                self.emit(RoderEvent::TaskFailed(TaskFailed {
356                    task_id,
357                    error: error.to_string(),
358                    error_kind: timeout_partial_result
359                        .as_ref()
360                        .map(|_| "deadline_timeout".to_string()),
361                    partial_result: timeout_partial_result,
362                    thread_id: options.thread_id,
363                    turn_id: options.turn_id,
364                    timestamp: OffsetDateTime::now_utc(),
365                }));
366            }
367        }
368    }
369
370    async fn finish_task(&self, task_id: &str, state: TaskState) {
371        let mut tasks = self.tasks.lock().await;
372        if let Some(record) = tasks.get_mut(task_id) {
373            if record.handle.state == TaskState::Cancelled {
374                return;
375            }
376            record.handle.state = state;
377            record.handle.finished_at = Some(OffsetDateTime::now_utc());
378            record.abort_handle = None;
379        }
380    }
381
382    async fn append_output(
383        &self,
384        task_id: &str,
385        stream: TaskOutputStream,
386        chunk: String,
387        thread_id: Option<ThreadId>,
388        turn_id: Option<TurnId>,
389    ) -> anyhow::Result<()> {
390        let dropped_bytes = {
391            let mut tasks = self.tasks.lock().await;
392            let Some(record) = tasks.get_mut(task_id) else {
393                anyhow::bail!("unknown task {task_id:?}");
394            };
395            record.log.push(stream.clone(), chunk.clone())
396        };
397        let _ = self
398            .processes
399            .append_task_output(
400                task_id,
401                stream.clone(),
402                chunk.clone(),
403                dropped_bytes,
404                thread_id.clone(),
405                turn_id.clone(),
406            )
407            .await;
408        self.emit(RoderEvent::TaskOutput(TaskOutput {
409            task_id: task_id.to_string(),
410            stream,
411            chunk,
412            dropped_bytes,
413            thread_id,
414            turn_id,
415            timestamp: OffsetDateTime::now_utc(),
416        }));
417        Ok(())
418    }
419
420    async fn partial_result(&self, task_id: &str) -> String {
421        let Some((logs, dropped)) = self.logs(task_id).await else {
422            return "no task output captured before timeout".to_string();
423        };
424        if logs.is_empty() {
425            return "no task output captured before timeout".to_string();
426        }
427        let mut text = logs
428            .iter()
429            .rev()
430            .take(3)
431            .map(|entry| entry.chunk.trim())
432            .collect::<Vec<_>>();
433        text.reverse();
434        let mut partial = text.join("\n");
435        if dropped > 0 {
436            partial.push_str(&format!("\n... {dropped} bytes dropped"));
437        }
438        partial
439    }
440
441    fn emit(&self, event: RoderEvent) {
442        let _ = self.events.send(event);
443    }
444
445    async fn thread_id(&self, task_id: &str) -> Option<ThreadId> {
446        self.tasks
447            .lock()
448            .await
449            .get(task_id)
450            .and_then(|record| self.record_thread_id(record))
451    }
452
453    async fn turn_id(&self, task_id: &str) -> Option<TurnId> {
454        self.tasks
455            .lock()
456            .await
457            .get(task_id)
458            .and_then(|record| self.record_turn_id(record))
459    }
460
461    fn record_thread_id(&self, record: &TaskRecord) -> Option<ThreadId> {
462        record.thread_id.clone()
463    }
464
465    fn record_turn_id(&self, record: &TaskRecord) -> Option<TurnId> {
466        record.turn_id.clone()
467    }
468}
469
470struct RunnerOutputWriter {
471    runner: BackgroundRunner,
472    task_id: TaskId,
473    thread_id: Option<ThreadId>,
474    turn_id: Option<TurnId>,
475}
476
477#[async_trait::async_trait]
478impl TaskOutputWriter for RunnerOutputWriter {
479    async fn write(&self, stream: TaskOutputStream, chunk: String) -> anyhow::Result<()> {
480        self.runner
481            .append_output(
482                &self.task_id,
483                stream,
484                chunk,
485                self.thread_id.clone(),
486                self.turn_id.clone(),
487            )
488            .await
489    }
490}