Skip to main content

roder_tasks/
process_registry.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use roder_api::events::RoderEvent;
5use roder_api::processes::{
6    ProcessDescriptor, ProcessExited, ProcessFailed, ProcessId, ProcessOutput, ProcessRegistrySink,
7    ProcessState, ProcessStopResult, ProcessStopped, ProcessStopper, ProcessStopping,
8};
9use roder_api::tasks::TaskOutputStream;
10use time::OffsetDateTime;
11use tokio::sync::{Mutex, broadcast};
12
/// Retention limits applied by the process registry.
#[derive(Debug, Clone)]
pub struct ProcessRegistryConfig {
    /// Number of terminal (exited, failed, or stopped) processes kept before
    /// the least recently updated ones are pruned.
    pub max_completed: usize,
    /// Upper bound on the summed chunk length, in bytes, of the output
    /// retained for a single process.
    pub max_output_bytes: usize,
}

impl Default for ProcessRegistryConfig {
    /// Keeps up to 64 completed processes and 64 KiB of output per process.
    fn default() -> Self {
        const DEFAULT_MAX_COMPLETED: usize = 64;
        const DEFAULT_MAX_OUTPUT_BYTES: usize = 64 * 1024;

        Self {
            max_completed: DEFAULT_MAX_COMPLETED,
            max_output_bytes: DEFAULT_MAX_OUTPUT_BYTES,
        }
    }
}
27
/// Handle to the registry of tracked processes.
///
/// The handle is `Clone`; clones share the same state (it lives behind an
/// `Arc`) and the same event sender.
#[derive(Clone)]
pub struct ProcessRegistry {
    /// Registry state, guarded by an async mutex.
    inner: Arc<Mutex<ProcessRegistryInner>>,
    /// Sender side of the broadcast channel that lifecycle events are published on.
    events: broadcast::Sender<RoderEvent>,
}
33
/// Mutable state held behind the mutex of a `ProcessRegistry`.
#[derive(Default)]
struct ProcessRegistryInner {
    /// Limits used when trimming output and pruning completed processes.
    config: ProcessRegistryConfig,
    /// Every tracked process, keyed by its process id.
    processes: BTreeMap<ProcessId, ProcessRecord>,
}
39
/// Everything the registry stores for one process.
struct ProcessRecord {
    /// Current descriptor; cloned out to callers and into events.
    descriptor: ProcessDescriptor,
    /// Retained output chunks, oldest first.
    output: Vec<ProcessOutput>,
    /// Running total of `chunk.len()` over the entries in `output`.
    output_bytes: usize,
    /// Handle used to request a stop; cleared when the process becomes terminal.
    stopper: Option<Arc<dyn ProcessStopper>>,
}
46
47impl ProcessRegistry {
48    pub fn new(config: ProcessRegistryConfig) -> Self {
49        let (events, _) = broadcast::channel(1024);
50        Self {
51            inner: Arc::new(Mutex::new(ProcessRegistryInner {
52                config,
53                processes: BTreeMap::new(),
54            })),
55            events,
56        }
57    }
58
    /// Returns a new receiver attached to this registry's event channel.
    pub fn subscribe(&self) -> broadcast::Receiver<RoderEvent> {
        self.events.subscribe()
    }
62
63    pub async fn register(
64        &self,
65        mut process: ProcessDescriptor,
66        stopper: Option<Arc<dyn ProcessStopper>>,
67    ) -> anyhow::Result<ProcessDescriptor> {
68        process.updated_at = OffsetDateTime::now_utc();
69        if process.started_at > process.updated_at {
70            process.started_at = process.updated_at;
71        }
72        let registered = process.clone();
73        {
74            let mut inner = self.inner.lock().await;
75            inner.processes.insert(
76                process.process_id.clone(),
77                ProcessRecord {
78                    descriptor: process,
79                    output: Vec::new(),
80                    output_bytes: 0,
81                    stopper,
82                },
83            );
84            inner.prune_completed();
85        }
86        self.emit(RoderEvent::ProcessStarted(
87            roder_api::processes::ProcessStarted {
88                process: registered.clone(),
89                timestamp: OffsetDateTime::now_utc(),
90            },
91        ));
92        Ok(registered)
93    }
94
95    pub async fn list(&self, include_completed: bool) -> Vec<ProcessDescriptor> {
96        self.inner
97            .lock()
98            .await
99            .processes
100            .values()
101            .filter(|record| include_completed || !is_terminal(&record.descriptor.state))
102            .map(|record| record.descriptor.clone())
103            .collect()
104    }
105
106    pub async fn get(&self, process_id: &str) -> Option<ProcessDescriptor> {
107        self.inner
108            .lock()
109            .await
110            .processes
111            .get(process_id)
112            .map(|record| record.descriptor.clone())
113    }
114
115    pub async fn output(&self, process_id: &str) -> Vec<ProcessOutput> {
116        self.inner
117            .lock()
118            .await
119            .processes
120            .get(process_id)
121            .map(|record| record.output.clone())
122            .unwrap_or_default()
123    }
124
125    pub async fn output_for_task(&self, task_id: &str) -> Vec<ProcessOutput> {
126        self.inner
127            .lock()
128            .await
129            .processes
130            .values()
131            .find(|record| record.descriptor.task_id.as_deref() == Some(task_id))
132            .map(|record| record.output.clone())
133            .unwrap_or_default()
134    }
135
136    pub async fn append_output(&self, output: ProcessOutput) -> anyhow::Result<()> {
137        let stored = {
138            let mut inner = self.inner.lock().await;
139            let max_output_bytes = inner.config.max_output_bytes;
140            let Some(record) = inner.processes.get_mut(&output.process_id) else {
141                anyhow::bail!("unknown process {:?}", output.process_id);
142            };
143            let stream = output.stream.clone();
144            let chunk = output.chunk.clone();
145            let chunk_len = chunk.len();
146            record.output.push(output.clone());
147            record.output_bytes = record.output_bytes.saturating_add(chunk_len);
148            while record.output_bytes > max_output_bytes {
149                let Some(removed) = record.output.first().cloned() else {
150                    break;
151                };
152                record.output.remove(0);
153                record.output_bytes = record.output_bytes.saturating_sub(removed.chunk.len());
154            }
155            match stream {
156                TaskOutputStream::Stdout => record.descriptor.stdout_tail = Some(chunk),
157                TaskOutputStream::Stderr => record.descriptor.stderr_tail = Some(chunk),
158                TaskOutputStream::Log => {}
159            }
160            record.descriptor.updated_at = OffsetDateTime::now_utc();
161            output
162        };
163        self.emit(RoderEvent::ProcessOutput(stored));
164        Ok(())
165    }
166
167    pub async fn mark_exited(
168        &self,
169        process_id: &str,
170        exit_code: Option<i32>,
171    ) -> anyhow::Result<()> {
172        let process = self
173            .update_terminal(process_id, ProcessState::Exited { exit_code })
174            .await?;
175        self.emit(RoderEvent::ProcessExited(ProcessExited {
176            process,
177            exit_code,
178            timestamp: OffsetDateTime::now_utc(),
179        }));
180        Ok(())
181    }
182
183    pub async fn mark_failed(&self, process_id: &str, error: String) -> anyhow::Result<()> {
184        let process = self
185            .update_terminal(
186                process_id,
187                ProcessState::Failed {
188                    error: error.clone(),
189                },
190            )
191            .await?;
192        self.emit(RoderEvent::ProcessFailed(ProcessFailed {
193            process,
194            error,
195            timestamp: OffsetDateTime::now_utc(),
196        }));
197        Ok(())
198    }
199
200    pub async fn mark_stopped(
201        &self,
202        process_id: &str,
203        reason: Option<String>,
204    ) -> anyhow::Result<()> {
205        let process = self
206            .update_terminal(process_id, ProcessState::Stopped)
207            .await?;
208        self.emit(RoderEvent::ProcessStopped(ProcessStopped {
209            process,
210            reason,
211            timestamp: OffsetDateTime::now_utc(),
212        }));
213        Ok(())
214    }
215
216    pub async fn stop(
217        &self,
218        process_id: &str,
219        reason: Option<String>,
220    ) -> anyhow::Result<ProcessStopResult> {
221        let (stopper, process) = {
222            let mut inner = self.inner.lock().await;
223            let Some(record) = inner.processes.get_mut(process_id) else {
224                anyhow::bail!("unknown process {process_id:?}");
225            };
226            if is_terminal(&record.descriptor.state) || !record.descriptor.stoppable {
227                return Ok(ProcessStopResult {
228                    process_id: process_id.to_string(),
229                    stopped: false,
230                    process: Some(record.descriptor.clone()),
231                });
232            }
233            record.descriptor.state = ProcessState::Stopping;
234            record.descriptor.updated_at = OffsetDateTime::now_utc();
235            let process = record.descriptor.clone();
236            (record.stopper.clone(), process)
237        };
238        self.emit(RoderEvent::ProcessStopping(ProcessStopping {
239            process_id: process_id.to_string(),
240            reason: reason.clone(),
241            timestamp: OffsetDateTime::now_utc(),
242        }));
243        if let Some(stopper) = stopper
244            && let Err(error) = stopper.stop(reason).await
245        {
246            let mut inner = self.inner.lock().await;
247            if let Some(record) = inner.processes.get_mut(process_id)
248                && matches!(record.descriptor.state, ProcessState::Stopping)
249            {
250                record.descriptor.state = ProcessState::Running;
251                record.descriptor.updated_at = OffsetDateTime::now_utc();
252            }
253            return Err(error);
254        }
255        Ok(ProcessStopResult {
256            process_id: process_id.to_string(),
257            stopped: true,
258            process: Some(process),
259        })
260    }
261
262    pub async fn stop_all(&self, reason: Option<String>) -> Vec<ProcessStopResult> {
263        let process_ids = {
264            self.inner
265                .lock()
266                .await
267                .processes
268                .values()
269                .filter(|record| {
270                    record.descriptor.stoppable && !is_terminal(&record.descriptor.state)
271                })
272                .map(|record| record.descriptor.process_id.clone())
273                .collect::<Vec<_>>()
274        };
275        let mut results = Vec::new();
276        for process_id in process_ids {
277            match self.stop(&process_id, reason.clone()).await {
278                Ok(result) => results.push(result),
279                Err(_) => results.push(ProcessStopResult {
280                    process_id,
281                    stopped: false,
282                    process: None,
283                }),
284            }
285        }
286        results
287    }
288
289    pub async fn append_task_output(
290        &self,
291        task_id: &str,
292        stream: TaskOutputStream,
293        chunk: String,
294        dropped_bytes: u64,
295        thread_id: Option<String>,
296        turn_id: Option<String>,
297    ) -> anyhow::Result<()> {
298        let process_id = {
299            self.inner
300                .lock()
301                .await
302                .processes
303                .values()
304                .find(|record| record.descriptor.task_id.as_deref() == Some(task_id))
305                .map(|record| record.descriptor.process_id.clone())
306        };
307        if let Some(process_id) = process_id {
308            self.append_output(ProcessOutput {
309                process_id,
310                stream,
311                chunk,
312                dropped_bytes,
313                thread_id,
314                turn_id,
315                timestamp: OffsetDateTime::now_utc(),
316            })
317            .await?;
318        }
319        Ok(())
320    }
321
322    async fn update_terminal(
323        &self,
324        process_id: &str,
325        state: ProcessState,
326    ) -> anyhow::Result<ProcessDescriptor> {
327        let process = {
328            let mut inner = self.inner.lock().await;
329            let Some(record) = inner.processes.get_mut(process_id) else {
330                anyhow::bail!("unknown process {process_id:?}");
331            };
332            if is_terminal(&record.descriptor.state) {
333                return Ok(record.descriptor.clone());
334            }
335            record.descriptor.state = state;
336            record.descriptor.stoppable = false;
337            record.descriptor.updated_at = OffsetDateTime::now_utc();
338            record.stopper = None;
339            let process = record.descriptor.clone();
340            inner.prune_completed();
341            process
342        };
343        Ok(process)
344    }
345
    /// Publishes `event` on the broadcast channel.
    ///
    /// The send result is deliberately discarded, so a failed send (tokio
    /// documents this as happening when there are no receivers) is not an
    /// error for the caller.
    fn emit(&self, event: RoderEvent) {
        let _ = self.events.send(event);
    }
349}
350
351impl Default for ProcessRegistry {
352    fn default() -> Self {
353        Self::new(ProcessRegistryConfig::default())
354    }
355}
356
// Trait adapter: each method forwards unchanged to the inherent
// `ProcessRegistry` method that does the work.
#[async_trait::async_trait]
impl ProcessRegistrySink for ProcessRegistry {
    // Forwards to `ProcessRegistry::register`.
    async fn register_process(
        &self,
        process: ProcessDescriptor,
        stopper: Option<Arc<dyn ProcessStopper>>,
    ) -> anyhow::Result<ProcessDescriptor> {
        self.register(process, stopper).await
    }

    // Forwards to `ProcessRegistry::append_output`.
    async fn append_process_output(&self, output: ProcessOutput) -> anyhow::Result<()> {
        self.append_output(output).await
    }

    // Forwards to `ProcessRegistry::mark_exited`.
    async fn mark_process_exited(
        &self,
        process_id: &str,
        exit_code: Option<i32>,
    ) -> anyhow::Result<()> {
        self.mark_exited(process_id, exit_code).await
    }

    // Forwards to `ProcessRegistry::mark_failed`.
    async fn mark_process_failed(&self, process_id: &str, error: String) -> anyhow::Result<()> {
        self.mark_failed(process_id, error).await
    }

    // Forwards to `ProcessRegistry::mark_stopped`.
    async fn mark_process_stopped(
        &self,
        process_id: &str,
        reason: Option<String>,
    ) -> anyhow::Result<()> {
        self.mark_stopped(process_id, reason).await
    }
}
391
392impl ProcessRegistryInner {
393    fn prune_completed(&mut self) {
394        let completed = self
395            .processes
396            .values()
397            .filter(|record| is_terminal(&record.descriptor.state))
398            .count();
399        if completed <= self.config.max_completed {
400            return;
401        }
402        let remove_count = completed - self.config.max_completed;
403        let mut terminal = self
404            .processes
405            .values()
406            .filter(|record| is_terminal(&record.descriptor.state))
407            .map(|record| {
408                (
409                    record.descriptor.updated_at,
410                    record.descriptor.process_id.clone(),
411                )
412            })
413            .collect::<Vec<_>>();
414        terminal.sort_by_key(|(updated_at, _)| *updated_at);
415        for (_, process_id) in terminal.into_iter().take(remove_count) {
416            self.processes.remove(&process_id);
417        }
418    }
419}
420
/// Returns `true` when `state` is `Exited`, `Failed`, or `Stopped`; every
/// other state counts as not yet finished.
fn is_terminal(state: &ProcessState) -> bool {
    matches!(
        state,
        ProcessState::Exited { .. } | ProcessState::Failed { .. } | ProcessState::Stopped
    )
}