1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use roder_api::events::RoderEvent;
5use roder_api::processes::{
6 ProcessDescriptor, ProcessExited, ProcessFailed, ProcessId, ProcessOutput, ProcessRegistrySink,
7 ProcessState, ProcessStopResult, ProcessStopped, ProcessStopper, ProcessStopping,
8};
9use roder_api::tasks::TaskOutputStream;
10use time::OffsetDateTime;
11use tokio::sync::{Mutex, broadcast};
12
/// Retention limits applied by the process registry.
#[derive(Debug, Clone)]
pub struct ProcessRegistryConfig {
    /// Maximum number of processes in a terminal state that are kept; once
    /// exceeded, the ones with the oldest `updated_at` are pruned.
    pub max_completed: usize,
    /// Per-process cap on the summed byte length of retained output chunks.
    pub max_output_bytes: usize,
}

impl Default for ProcessRegistryConfig {
    /// Keeps up to 64 completed processes and 64 KiB of output per process.
    fn default() -> Self {
        const KIB: usize = 1024;

        let max_completed = 64;
        let max_output_bytes = 64 * KIB;
        Self {
            max_completed,
            max_output_bytes,
        }
    }
}
27
/// Handle to the registry of tracked processes.
///
/// Cloning is cheap: every clone shares the same state (behind an `Arc`) and
/// the same event sender.
#[derive(Clone)]
pub struct ProcessRegistry {
    /// Shared, mutex-guarded registry state.
    inner: Arc<Mutex<ProcessRegistryInner>>,
    /// Sender used to publish process lifecycle and output events.
    events: broadcast::Sender<RoderEvent>,
}
33
/// Mutable registry state guarded by the registry mutex.
#[derive(Default)]
struct ProcessRegistryInner {
    /// Retention limits used when pruning processes and trimming output.
    config: ProcessRegistryConfig,
    /// Tracked processes keyed by process id.
    processes: BTreeMap<ProcessId, ProcessRecord>,
}
39
/// Everything the registry stores about a single process.
struct ProcessRecord {
    /// Latest known descriptor (state, timestamps, output tails).
    descriptor: ProcessDescriptor,
    /// Retained output chunks, oldest first.
    output: Vec<ProcessOutput>,
    /// Running total of the byte lengths of the chunks in `output`.
    output_bytes: usize,
    /// Hook used to ask the process to stop; cleared once the process reaches
    /// a terminal state.
    stopper: Option<Arc<dyn ProcessStopper>>,
}
46
47impl ProcessRegistry {
48 pub fn new(config: ProcessRegistryConfig) -> Self {
49 let (events, _) = broadcast::channel(1024);
50 Self {
51 inner: Arc::new(Mutex::new(ProcessRegistryInner {
52 config,
53 processes: BTreeMap::new(),
54 })),
55 events,
56 }
57 }
58
59 pub fn subscribe(&self) -> broadcast::Receiver<RoderEvent> {
60 self.events.subscribe()
61 }
62
63 pub async fn register(
64 &self,
65 mut process: ProcessDescriptor,
66 stopper: Option<Arc<dyn ProcessStopper>>,
67 ) -> anyhow::Result<ProcessDescriptor> {
68 process.updated_at = OffsetDateTime::now_utc();
69 if process.started_at > process.updated_at {
70 process.started_at = process.updated_at;
71 }
72 let registered = process.clone();
73 {
74 let mut inner = self.inner.lock().await;
75 inner.processes.insert(
76 process.process_id.clone(),
77 ProcessRecord {
78 descriptor: process,
79 output: Vec::new(),
80 output_bytes: 0,
81 stopper,
82 },
83 );
84 inner.prune_completed();
85 }
86 self.emit(RoderEvent::ProcessStarted(
87 roder_api::processes::ProcessStarted {
88 process: registered.clone(),
89 timestamp: OffsetDateTime::now_utc(),
90 },
91 ));
92 Ok(registered)
93 }
94
95 pub async fn list(&self, include_completed: bool) -> Vec<ProcessDescriptor> {
96 self.inner
97 .lock()
98 .await
99 .processes
100 .values()
101 .filter(|record| include_completed || !is_terminal(&record.descriptor.state))
102 .map(|record| record.descriptor.clone())
103 .collect()
104 }
105
106 pub async fn get(&self, process_id: &str) -> Option<ProcessDescriptor> {
107 self.inner
108 .lock()
109 .await
110 .processes
111 .get(process_id)
112 .map(|record| record.descriptor.clone())
113 }
114
115 pub async fn output(&self, process_id: &str) -> Vec<ProcessOutput> {
116 self.inner
117 .lock()
118 .await
119 .processes
120 .get(process_id)
121 .map(|record| record.output.clone())
122 .unwrap_or_default()
123 }
124
125 pub async fn output_for_task(&self, task_id: &str) -> Vec<ProcessOutput> {
126 self.inner
127 .lock()
128 .await
129 .processes
130 .values()
131 .find(|record| record.descriptor.task_id.as_deref() == Some(task_id))
132 .map(|record| record.output.clone())
133 .unwrap_or_default()
134 }
135
136 pub async fn append_output(&self, output: ProcessOutput) -> anyhow::Result<()> {
137 let stored = {
138 let mut inner = self.inner.lock().await;
139 let max_output_bytes = inner.config.max_output_bytes;
140 let Some(record) = inner.processes.get_mut(&output.process_id) else {
141 anyhow::bail!("unknown process {:?}", output.process_id);
142 };
143 let stream = output.stream.clone();
144 let chunk = output.chunk.clone();
145 let chunk_len = chunk.len();
146 record.output.push(output.clone());
147 record.output_bytes = record.output_bytes.saturating_add(chunk_len);
148 while record.output_bytes > max_output_bytes {
149 let Some(removed) = record.output.first().cloned() else {
150 break;
151 };
152 record.output.remove(0);
153 record.output_bytes = record.output_bytes.saturating_sub(removed.chunk.len());
154 }
155 match stream {
156 TaskOutputStream::Stdout => record.descriptor.stdout_tail = Some(chunk),
157 TaskOutputStream::Stderr => record.descriptor.stderr_tail = Some(chunk),
158 TaskOutputStream::Log => {}
159 }
160 record.descriptor.updated_at = OffsetDateTime::now_utc();
161 output
162 };
163 self.emit(RoderEvent::ProcessOutput(stored));
164 Ok(())
165 }
166
167 pub async fn mark_exited(
168 &self,
169 process_id: &str,
170 exit_code: Option<i32>,
171 ) -> anyhow::Result<()> {
172 let process = self
173 .update_terminal(process_id, ProcessState::Exited { exit_code })
174 .await?;
175 self.emit(RoderEvent::ProcessExited(ProcessExited {
176 process,
177 exit_code,
178 timestamp: OffsetDateTime::now_utc(),
179 }));
180 Ok(())
181 }
182
183 pub async fn mark_failed(&self, process_id: &str, error: String) -> anyhow::Result<()> {
184 let process = self
185 .update_terminal(
186 process_id,
187 ProcessState::Failed {
188 error: error.clone(),
189 },
190 )
191 .await?;
192 self.emit(RoderEvent::ProcessFailed(ProcessFailed {
193 process,
194 error,
195 timestamp: OffsetDateTime::now_utc(),
196 }));
197 Ok(())
198 }
199
200 pub async fn mark_stopped(
201 &self,
202 process_id: &str,
203 reason: Option<String>,
204 ) -> anyhow::Result<()> {
205 let process = self
206 .update_terminal(process_id, ProcessState::Stopped)
207 .await?;
208 self.emit(RoderEvent::ProcessStopped(ProcessStopped {
209 process,
210 reason,
211 timestamp: OffsetDateTime::now_utc(),
212 }));
213 Ok(())
214 }
215
216 pub async fn stop(
217 &self,
218 process_id: &str,
219 reason: Option<String>,
220 ) -> anyhow::Result<ProcessStopResult> {
221 let (stopper, process) = {
222 let mut inner = self.inner.lock().await;
223 let Some(record) = inner.processes.get_mut(process_id) else {
224 anyhow::bail!("unknown process {process_id:?}");
225 };
226 if is_terminal(&record.descriptor.state) || !record.descriptor.stoppable {
227 return Ok(ProcessStopResult {
228 process_id: process_id.to_string(),
229 stopped: false,
230 process: Some(record.descriptor.clone()),
231 });
232 }
233 record.descriptor.state = ProcessState::Stopping;
234 record.descriptor.updated_at = OffsetDateTime::now_utc();
235 let process = record.descriptor.clone();
236 (record.stopper.clone(), process)
237 };
238 self.emit(RoderEvent::ProcessStopping(ProcessStopping {
239 process_id: process_id.to_string(),
240 reason: reason.clone(),
241 timestamp: OffsetDateTime::now_utc(),
242 }));
243 if let Some(stopper) = stopper
244 && let Err(error) = stopper.stop(reason).await
245 {
246 let mut inner = self.inner.lock().await;
247 if let Some(record) = inner.processes.get_mut(process_id)
248 && matches!(record.descriptor.state, ProcessState::Stopping)
249 {
250 record.descriptor.state = ProcessState::Running;
251 record.descriptor.updated_at = OffsetDateTime::now_utc();
252 }
253 return Err(error);
254 }
255 Ok(ProcessStopResult {
256 process_id: process_id.to_string(),
257 stopped: true,
258 process: Some(process),
259 })
260 }
261
262 pub async fn stop_all(&self, reason: Option<String>) -> Vec<ProcessStopResult> {
263 let process_ids = {
264 self.inner
265 .lock()
266 .await
267 .processes
268 .values()
269 .filter(|record| {
270 record.descriptor.stoppable && !is_terminal(&record.descriptor.state)
271 })
272 .map(|record| record.descriptor.process_id.clone())
273 .collect::<Vec<_>>()
274 };
275 let mut results = Vec::new();
276 for process_id in process_ids {
277 match self.stop(&process_id, reason.clone()).await {
278 Ok(result) => results.push(result),
279 Err(_) => results.push(ProcessStopResult {
280 process_id,
281 stopped: false,
282 process: None,
283 }),
284 }
285 }
286 results
287 }
288
289 pub async fn append_task_output(
290 &self,
291 task_id: &str,
292 stream: TaskOutputStream,
293 chunk: String,
294 dropped_bytes: u64,
295 thread_id: Option<String>,
296 turn_id: Option<String>,
297 ) -> anyhow::Result<()> {
298 let process_id = {
299 self.inner
300 .lock()
301 .await
302 .processes
303 .values()
304 .find(|record| record.descriptor.task_id.as_deref() == Some(task_id))
305 .map(|record| record.descriptor.process_id.clone())
306 };
307 if let Some(process_id) = process_id {
308 self.append_output(ProcessOutput {
309 process_id,
310 stream,
311 chunk,
312 dropped_bytes,
313 thread_id,
314 turn_id,
315 timestamp: OffsetDateTime::now_utc(),
316 })
317 .await?;
318 }
319 Ok(())
320 }
321
322 async fn update_terminal(
323 &self,
324 process_id: &str,
325 state: ProcessState,
326 ) -> anyhow::Result<ProcessDescriptor> {
327 let process = {
328 let mut inner = self.inner.lock().await;
329 let Some(record) = inner.processes.get_mut(process_id) else {
330 anyhow::bail!("unknown process {process_id:?}");
331 };
332 if is_terminal(&record.descriptor.state) {
333 return Ok(record.descriptor.clone());
334 }
335 record.descriptor.state = state;
336 record.descriptor.stoppable = false;
337 record.descriptor.updated_at = OffsetDateTime::now_utc();
338 record.stopper = None;
339 let process = record.descriptor.clone();
340 inner.prune_completed();
341 process
342 };
343 Ok(process)
344 }
345
346 fn emit(&self, event: RoderEvent) {
347 let _ = self.events.send(event);
348 }
349}
350
351impl Default for ProcessRegistry {
352 fn default() -> Self {
353 Self::new(ProcessRegistryConfig::default())
354 }
355}
356
357#[async_trait::async_trait]
358impl ProcessRegistrySink for ProcessRegistry {
359 async fn register_process(
360 &self,
361 process: ProcessDescriptor,
362 stopper: Option<Arc<dyn ProcessStopper>>,
363 ) -> anyhow::Result<ProcessDescriptor> {
364 self.register(process, stopper).await
365 }
366
367 async fn append_process_output(&self, output: ProcessOutput) -> anyhow::Result<()> {
368 self.append_output(output).await
369 }
370
371 async fn mark_process_exited(
372 &self,
373 process_id: &str,
374 exit_code: Option<i32>,
375 ) -> anyhow::Result<()> {
376 self.mark_exited(process_id, exit_code).await
377 }
378
379 async fn mark_process_failed(&self, process_id: &str, error: String) -> anyhow::Result<()> {
380 self.mark_failed(process_id, error).await
381 }
382
383 async fn mark_process_stopped(
384 &self,
385 process_id: &str,
386 reason: Option<String>,
387 ) -> anyhow::Result<()> {
388 self.mark_stopped(process_id, reason).await
389 }
390}
391
392impl ProcessRegistryInner {
393 fn prune_completed(&mut self) {
394 let completed = self
395 .processes
396 .values()
397 .filter(|record| is_terminal(&record.descriptor.state))
398 .count();
399 if completed <= self.config.max_completed {
400 return;
401 }
402 let remove_count = completed - self.config.max_completed;
403 let mut terminal = self
404 .processes
405 .values()
406 .filter(|record| is_terminal(&record.descriptor.state))
407 .map(|record| {
408 (
409 record.descriptor.updated_at,
410 record.descriptor.process_id.clone(),
411 )
412 })
413 .collect::<Vec<_>>();
414 terminal.sort_by_key(|(updated_at, _)| *updated_at);
415 for (_, process_id) in terminal.into_iter().take(remove_count) {
416 self.processes.remove(&process_id);
417 }
418 }
419}
420
/// Returns `true` when `state` is one of the terminal states: `Exited`,
/// `Failed`, or `Stopped`.
fn is_terminal(state: &ProcessState) -> bool {
    matches!(
        state,
        ProcessState::Exited { .. } | ProcessState::Failed { .. } | ProcessState::Stopped
    )
}