1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use roder_api::events::{EventEnvelope, RoderEvent, ThreadId, TurnId};
5use roder_api::extension::TaskExecutorId;
6use roder_api::remote_runner::{RemoteRunnerSession, RunnerDestination};
7use roder_api::tasks::{
8 TaskCancelled, TaskCompleted, TaskExecutionContext, TaskFailed, TaskHandle, TaskId, TaskOutput,
9 TaskOutputSink, TaskOutputStream, TaskOutputWriter, TaskStarted, TaskState,
10};
11use time::OffsetDateTime;
12use tokio::sync::{Mutex, Semaphore, broadcast};
13use tokio::task::AbortHandle;
14
15use crate::log_buffer::{BoundedLogBuffer, TaskLogEntry};
16use crate::process_registry::ProcessRegistry;
17use crate::registry::TaskExecutorRegistry;
18
/// Tuning knobs for [`BackgroundRunner`].
#[derive(Debug, Clone)]
pub struct BackgroundRunnerConfig {
    /// Upper bound on tasks executing at the same time; the runner treats 0 as 1.
    pub max_concurrent: usize,
    /// Byte limit handed to each task's `BoundedLogBuffer`.
    pub max_log_bytes: usize,
    /// When `true`, still-active tasks of a thread are cancelled once that thread's session
    /// ends or one of its turns completes, fails or is interrupted.
    pub auto_cancel_on_session_end: bool,
}

impl Default for BackgroundRunnerConfig {
    /// Four concurrent tasks, 64 KiB of log per task, auto-cancellation enabled.
    fn default() -> Self {
        const DEFAULT_LOG_BYTES: usize = 64 * 1024;
        BackgroundRunnerConfig {
            auto_cancel_on_session_end: true,
            max_log_bytes: DEFAULT_LOG_BYTES,
            max_concurrent: 4,
        }
    }
}
35
36#[derive(Clone, Default)]
37pub struct TaskSubmitOptions {
38 pub thread_id: Option<ThreadId>,
39 pub turn_id: Option<TurnId>,
40 pub workspace_root: Option<String>,
41 pub runner_destination: Option<RunnerDestination>,
42 pub runner_session: Option<Arc<dyn RemoteRunnerSession>>,
43 pub deadline: Option<OffsetDateTime>,
44 pub metadata: serde_json::Value,
45}
46
47#[derive(Clone)]
48pub struct BackgroundRunner {
49 registry: TaskExecutorRegistry,
50 config: BackgroundRunnerConfig,
51 semaphore: Arc<Semaphore>,
52 tasks: Arc<Mutex<BTreeMap<TaskId, TaskRecord>>>,
53 processes: ProcessRegistry,
54 events: broadcast::Sender<RoderEvent>,
55}
56
57struct TaskRecord {
58 handle: TaskHandle,
59 log: BoundedLogBuffer,
60 abort_handle: Option<AbortHandle>,
61 thread_id: Option<ThreadId>,
62 turn_id: Option<TurnId>,
63}
64
65impl BackgroundRunner {
66 pub fn new(registry: TaskExecutorRegistry, config: BackgroundRunnerConfig) -> Self {
67 let (events, _) = broadcast::channel(1024);
68 let processes = ProcessRegistry::default();
69 if tokio::runtime::Handle::try_current().is_ok() {
70 let mut process_events = processes.subscribe();
71 let task_events = events.clone();
72 tokio::spawn(async move {
73 while let Ok(event) = process_events.recv().await {
74 let _ = task_events.send(event);
75 }
76 });
77 }
78 Self {
79 registry,
80 semaphore: Arc::new(Semaphore::new(config.max_concurrent.max(1))),
81 config,
82 tasks: Arc::new(Mutex::new(BTreeMap::new())),
83 processes,
84 events,
85 }
86 }
87
88 pub fn subscribe(&self) -> broadcast::Receiver<RoderEvent> {
89 self.events.subscribe()
90 }
91
92 pub fn processes(&self) -> ProcessRegistry {
93 self.processes.clone()
94 }
95
96 pub async fn submit(
97 &self,
98 executor_id: impl Into<TaskExecutorId>,
99 input: serde_json::Value,
100 options: TaskSubmitOptions,
101 ) -> anyhow::Result<TaskHandle> {
102 let executor_id = executor_id.into();
103 let executor = self
104 .registry
105 .get(&executor_id)
106 .ok_or_else(|| anyhow::anyhow!("unknown task executor {executor_id:?}"))?;
107 let spec = executor.spec();
108 let task_id = uuid::Uuid::new_v4().to_string();
109 let handle = TaskHandle {
110 task_id: task_id.clone(),
111 executor_id: executor_id.clone(),
112 spec: spec.clone(),
113 state: TaskState::Queued,
114 created_at: OffsetDateTime::now_utc(),
115 started_at: None,
116 finished_at: None,
117 };
118
119 {
120 let mut tasks = self.tasks.lock().await;
121 tasks.insert(
122 task_id.clone(),
123 TaskRecord {
124 handle: handle.clone(),
125 log: BoundedLogBuffer::new(self.config.max_log_bytes),
126 abort_handle: None,
127 thread_id: options.thread_id.clone(),
128 turn_id: options.turn_id.clone(),
129 },
130 );
131 }
132
133 let runner = self.clone();
134 let task_id_for_spawn = task_id.clone();
135 let spawn_options = options.clone();
136 let join = tokio::spawn(async move {
137 runner
138 .run_task(
139 task_id_for_spawn,
140 executor_id,
141 executor,
142 input,
143 spawn_options,
144 )
145 .await;
146 });
147 let abort_handle = join.abort_handle();
148 {
149 let mut tasks = self.tasks.lock().await;
150 if let Some(record) = tasks.get_mut(&task_id) {
151 record.abort_handle = Some(abort_handle);
152 }
153 }
154
155 Ok(handle)
156 }
157
158 pub async fn cancel(&self, task_id: &str, reason: Option<String>) -> anyhow::Result<bool> {
159 let cancelled = {
160 let mut tasks = self.tasks.lock().await;
161 let Some(record) = tasks.get_mut(task_id) else {
162 anyhow::bail!("unknown task {task_id:?}");
163 };
164 if matches!(
165 record.handle.state,
166 TaskState::Completed | TaskState::Failed | TaskState::Cancelled
167 ) {
168 return Ok(false);
169 }
170 record.handle.state = TaskState::Cancelled;
171 record.handle.finished_at = Some(OffsetDateTime::now_utc());
172 if let Some(abort_handle) = record.abort_handle.take() {
173 abort_handle.abort();
174 }
175 true
176 };
177
178 if cancelled {
179 self.emit(RoderEvent::TaskCancelled(TaskCancelled {
180 task_id: task_id.to_string(),
181 reason,
182 thread_id: self.thread_id(task_id).await,
183 turn_id: self.turn_id(task_id).await,
184 timestamp: OffsetDateTime::now_utc(),
185 }));
186 }
187
188 Ok(cancelled)
189 }
190
191 pub async fn list(&self) -> Vec<TaskHandle> {
192 self.tasks
193 .lock()
194 .await
195 .values()
196 .map(|record| record.handle.clone())
197 .collect()
198 }
199
200 pub async fn get(&self, task_id: &str) -> Option<TaskHandle> {
201 self.tasks
202 .lock()
203 .await
204 .get(task_id)
205 .map(|record| record.handle.clone())
206 }
207
208 pub async fn logs(&self, task_id: &str) -> Option<(Vec<TaskLogEntry>, u64)> {
209 self.tasks
210 .lock()
211 .await
212 .get(task_id)
213 .map(|record| (record.log.entries(), record.log.dropped_bytes()))
214 }
215
216 pub async fn handle_event(&self, envelope: &EventEnvelope) -> anyhow::Result<()> {
217 if !self.config.auto_cancel_on_session_end {
218 return Ok(());
219 }
220 if !matches!(
221 envelope.kind.as_str(),
222 "session.ended" | "turn.completed" | "turn.failed" | "turn.interrupted"
223 ) {
224 return Ok(());
225 }
226 let Some(thread_id) = envelope.thread_id.as_deref() else {
227 return Ok(());
228 };
229 let task_ids = {
230 self.tasks
231 .lock()
232 .await
233 .iter()
234 .filter_map(|(task_id, record)| {
235 let active = !matches!(
236 record.handle.state,
237 TaskState::Completed | TaskState::Failed | TaskState::Cancelled
238 );
239 let same_thread =
240 active && self.record_thread_id(record).as_deref() == Some(thread_id);
241 same_thread.then(|| task_id.clone())
242 })
243 .collect::<Vec<_>>()
244 };
245 for task_id in task_ids {
246 self.cancel(&task_id, Some("session ended".to_string()))
247 .await?;
248 }
249 Ok(())
250 }
251
252 async fn run_task(
253 &self,
254 task_id: TaskId,
255 executor_id: TaskExecutorId,
256 executor: Arc<dyn roder_api::tasks::TaskExecutor>,
257 input: serde_json::Value,
258 options: TaskSubmitOptions,
259 ) {
260 let permit = match self.semaphore.clone().acquire_owned().await {
261 Ok(permit) => permit,
262 Err(_) => return,
263 };
264 let _permit = permit;
265
266 let queue_depth = {
267 let mut tasks = self.tasks.lock().await;
268 let queue_depth = tasks
269 .values()
270 .filter(|record| record.handle.state == TaskState::Queued)
271 .count()
272 .saturating_sub(1);
273 if let Some(record) = tasks.get_mut(&task_id) {
274 if record.handle.state == TaskState::Cancelled {
275 return;
276 }
277 record.handle.state = TaskState::Running;
278 record.handle.started_at = Some(OffsetDateTime::now_utc());
279 }
280 queue_depth
281 };
282
283 self.emit(RoderEvent::TaskStarted(TaskStarted {
284 task_id: task_id.clone(),
285 executor_id,
286 task_kind: executor.spec().kind,
287 thread_id: options.thread_id.clone(),
288 turn_id: options.turn_id.clone(),
289 queue_depth,
290 timestamp: OffsetDateTime::now_utc(),
291 }));
292
293 let ctx = TaskExecutionContext {
294 task_id: task_id.clone(),
295 thread_id: options.thread_id.clone(),
296 turn_id: options.turn_id.clone(),
297 workspace_root: options.workspace_root,
298 runner_destination: options.runner_destination,
299 runner_session: options.runner_session,
300 deadline: options.deadline,
301 metadata: options.metadata,
302 process_registry: Some(Arc::new(self.processes.clone())),
303 output: TaskOutputSink::new(Arc::new(RunnerOutputWriter {
304 runner: self.clone(),
305 task_id: task_id.clone(),
306 thread_id: options.thread_id.clone(),
307 turn_id: options.turn_id.clone(),
308 })),
309 };
310
311 let mut timeout_partial_result = None;
312 let result = if let Some(deadline) = options.deadline {
313 let now = OffsetDateTime::now_utc();
314 let duration = (deadline - now).unsigned_abs();
315 let deadline_instant = if deadline > now {
316 tokio::time::Instant::now() + duration
317 } else {
318 tokio::time::Instant::now()
319 };
320 match tokio::time::timeout_at(deadline_instant, executor.execute(ctx, input)).await {
321 Ok(result) => result,
322 Err(_) => {
323 let partial = self.partial_result(&task_id).await;
324 timeout_partial_result = Some(partial.clone());
325 self.emit(RoderEvent::TaskOutput(TaskOutput {
326 task_id: task_id.clone(),
327 stream: TaskOutputStream::Log,
328 chunk: format!("task deadline expired; partial result: {partial}"),
329 dropped_bytes: 0,
330 thread_id: options.thread_id.clone(),
331 turn_id: options.turn_id.clone(),
332 timestamp: OffsetDateTime::now_utc(),
333 }));
334 Err(anyhow::anyhow!("task deadline expired"))
335 }
336 }
337 } else {
338 executor.execute(ctx, input).await
339 };
340
341 match result {
342 Ok(payload) => {
343 self.finish_task(&task_id, TaskState::Completed).await;
344 self.emit(RoderEvent::TaskCompleted(TaskCompleted {
345 task_id,
346 exit_code: payload.exit_code,
347 payload: payload.payload,
348 thread_id: options.thread_id,
349 turn_id: options.turn_id,
350 timestamp: OffsetDateTime::now_utc(),
351 }));
352 }
353 Err(error) => {
354 self.finish_task(&task_id, TaskState::Failed).await;
355 self.emit(RoderEvent::TaskFailed(TaskFailed {
356 task_id,
357 error: error.to_string(),
358 error_kind: timeout_partial_result
359 .as_ref()
360 .map(|_| "deadline_timeout".to_string()),
361 partial_result: timeout_partial_result,
362 thread_id: options.thread_id,
363 turn_id: options.turn_id,
364 timestamp: OffsetDateTime::now_utc(),
365 }));
366 }
367 }
368 }
369
370 async fn finish_task(&self, task_id: &str, state: TaskState) {
371 let mut tasks = self.tasks.lock().await;
372 if let Some(record) = tasks.get_mut(task_id) {
373 if record.handle.state == TaskState::Cancelled {
374 return;
375 }
376 record.handle.state = state;
377 record.handle.finished_at = Some(OffsetDateTime::now_utc());
378 record.abort_handle = None;
379 }
380 }
381
382 async fn append_output(
383 &self,
384 task_id: &str,
385 stream: TaskOutputStream,
386 chunk: String,
387 thread_id: Option<ThreadId>,
388 turn_id: Option<TurnId>,
389 ) -> anyhow::Result<()> {
390 let dropped_bytes = {
391 let mut tasks = self.tasks.lock().await;
392 let Some(record) = tasks.get_mut(task_id) else {
393 anyhow::bail!("unknown task {task_id:?}");
394 };
395 record.log.push(stream.clone(), chunk.clone())
396 };
397 let _ = self
398 .processes
399 .append_task_output(
400 task_id,
401 stream.clone(),
402 chunk.clone(),
403 dropped_bytes,
404 thread_id.clone(),
405 turn_id.clone(),
406 )
407 .await;
408 self.emit(RoderEvent::TaskOutput(TaskOutput {
409 task_id: task_id.to_string(),
410 stream,
411 chunk,
412 dropped_bytes,
413 thread_id,
414 turn_id,
415 timestamp: OffsetDateTime::now_utc(),
416 }));
417 Ok(())
418 }
419
420 async fn partial_result(&self, task_id: &str) -> String {
421 let Some((logs, dropped)) = self.logs(task_id).await else {
422 return "no task output captured before timeout".to_string();
423 };
424 if logs.is_empty() {
425 return "no task output captured before timeout".to_string();
426 }
427 let mut text = logs
428 .iter()
429 .rev()
430 .take(3)
431 .map(|entry| entry.chunk.trim())
432 .collect::<Vec<_>>();
433 text.reverse();
434 let mut partial = text.join("\n");
435 if dropped > 0 {
436 partial.push_str(&format!("\n... {dropped} bytes dropped"));
437 }
438 partial
439 }
440
441 fn emit(&self, event: RoderEvent) {
442 let _ = self.events.send(event);
443 }
444
445 async fn thread_id(&self, task_id: &str) -> Option<ThreadId> {
446 self.tasks
447 .lock()
448 .await
449 .get(task_id)
450 .and_then(|record| self.record_thread_id(record))
451 }
452
453 async fn turn_id(&self, task_id: &str) -> Option<TurnId> {
454 self.tasks
455 .lock()
456 .await
457 .get(task_id)
458 .and_then(|record| self.record_turn_id(record))
459 }
460
461 fn record_thread_id(&self, record: &TaskRecord) -> Option<ThreadId> {
462 record.thread_id.clone()
463 }
464
465 fn record_turn_id(&self, record: &TaskRecord) -> Option<TurnId> {
466 record.turn_id.clone()
467 }
468}
469
470struct RunnerOutputWriter {
471 runner: BackgroundRunner,
472 task_id: TaskId,
473 thread_id: Option<ThreadId>,
474 turn_id: Option<TurnId>,
475}
476
477#[async_trait::async_trait]
478impl TaskOutputWriter for RunnerOutputWriter {
479 async fn write(&self, stream: TaskOutputStream, chunk: String) -> anyhow::Result<()> {
480 self.runner
481 .append_output(
482 &self.task_id,
483 stream,
484 chunk,
485 self.thread_id.clone(),
486 self.turn_id.clone(),
487 )
488 .await
489 }
490}