1use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::SystemTime;
7
8use bytes::Bytes;
9use parking_lot::RwLock;
10use solti_model::{OutputChunk, OutputEvent, StreamKind, TaskId};
11use tokio::sync::broadcast;
12
13#[derive(Clone)]
18pub struct OutputSink {
19 attempt: u32,
20 seq_stdout: Arc<AtomicU64>,
21 seq_stderr: Arc<AtomicU64>,
22 sender: broadcast::Sender<OutputEvent>,
23}
24
25impl OutputSink {
26 pub fn new(sender: broadcast::Sender<OutputEvent>, attempt: u32) -> Self {
28 Self {
29 sender,
30 attempt,
31 seq_stdout: Arc::new(AtomicU64::new(0)),
32 seq_stderr: Arc::new(AtomicU64::new(0)),
33 }
34 }
35
36 pub fn stdout_line(&self, line: Bytes) {
42 let seq = self.seq_stdout.fetch_add(1, Ordering::Relaxed);
43 self.push(StreamKind::Stdout, seq, line);
44 }
45
46 pub fn stderr_line(&self, line: Bytes) {
48 let seq = self.seq_stderr.fetch_add(1, Ordering::Relaxed);
49 self.push(StreamKind::Stderr, seq, line);
50 }
51
52 pub fn attempt(&self) -> u32 {
54 self.attempt
55 }
56
57 fn push(&self, stream: StreamKind, seq: u64, line: Bytes) {
58 let chunk = OutputChunk {
59 attempt: self.attempt,
60 stream,
61 seq,
62 ts: SystemTime::now(),
63 line,
64 };
65 let _ = self.sender.send(OutputEvent::Chunk(chunk));
66 }
67}
68
69pub struct OutputRegistry {
78 channels: RwLock<HashMap<TaskId, broadcast::Sender<OutputEvent>>>,
79 capacity: usize,
80}
81
82impl OutputRegistry {
83 pub fn new(capacity: usize) -> Self {
85 Self {
86 channels: RwLock::new(HashMap::new()),
87 capacity,
88 }
89 }
90
91 pub fn ensure_channel(&self, task_id: TaskId) {
98 let mut channels = self.channels.write();
99 channels
100 .entry(task_id)
101 .or_insert_with(|| broadcast::channel::<OutputEvent>(self.capacity).0);
102 }
103
104 pub fn sink_for(&self, task_id: TaskId, attempt: u32) -> OutputSink {
109 let mut channels = self.channels.write();
110 let sender = channels
111 .entry(task_id)
112 .or_insert_with(|| broadcast::channel::<OutputEvent>(self.capacity).0)
113 .clone();
114 OutputSink::new(sender, attempt)
115 }
116
117 pub fn subscribe(&self, task_id: &TaskId) -> Option<broadcast::Receiver<OutputEvent>> {
119 let channels = self.channels.read();
120 channels.get(task_id).map(|s| s.subscribe())
121 }
122
123 pub fn announce_run_started(&self, task_id: &TaskId, attempt: u32) {
127 let channels = self.channels.read();
128 if let Some(sender) = channels.get(task_id) {
129 let _ = sender.send(OutputEvent::RunStarted {
130 attempt,
131 started_at: SystemTime::now(),
132 });
133 }
134 }
135
136 pub fn announce_run_finished(&self, task_id: &TaskId, attempt: u32, exit_code: Option<i32>) {
140 let channels = self.channels.read();
141 if let Some(sender) = channels.get(task_id) {
142 let _ = sender.send(OutputEvent::RunFinished {
143 attempt,
144 exit_code,
145 finished_at: SystemTime::now(),
146 });
147 }
148 }
149
150 pub fn evict(&self, task_id: &TaskId) {
152 let mut channels = self.channels.write();
153 channels.remove(task_id);
154 }
155
156 pub fn active_channels(&self) -> usize {
158 self.channels.read().len()
159 }
160}
161
162impl Default for OutputRegistry {
163 fn default() -> Self {
165 Self::new(1024)
166 }
167}
168
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use solti_model::{OutputEvent, StreamKind, TaskId};
    use tokio::sync::broadcast;

    use super::{OutputRegistry, OutputSink};

    #[tokio::test]
    async fn output_sink_pushes_stdout_line_to_subscriber() {
        let (tx, mut rx) = broadcast::channel::<OutputEvent>(16);
        let sink = OutputSink::new(tx, 1);

        sink.stdout_line(Bytes::from_static(b"hello"));

        match rx.recv().await.unwrap() {
            OutputEvent::Chunk(chunk) => {
                assert_eq!(chunk.attempt, 1);
                assert_eq!(chunk.stream, StreamKind::Stdout);
                assert_eq!(chunk.seq, 0);
                assert_eq!(&chunk.line[..], b"hello");
            }
            other => panic!("expected Chunk, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn output_sink_pushes_stderr_line_to_subscriber() {
        let (tx, mut rx) = broadcast::channel::<OutputEvent>(16);
        let sink = OutputSink::new(tx, 5);

        sink.stderr_line(Bytes::from_static(b"oops"));

        match rx.recv().await.unwrap() {
            OutputEvent::Chunk(chunk) => {
                assert_eq!(chunk.attempt, 5);
                assert_eq!(chunk.stream, StreamKind::Stderr);
                assert_eq!(&chunk.line[..], b"oops");
            }
            other => panic!("expected Chunk, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn output_sink_assigns_monotonic_seq_per_stream() {
        let (tx, mut rx) = broadcast::channel::<OutputEvent>(16);
        let sink = OutputSink::new(tx, 1);

        sink.stdout_line(Bytes::from_static(b"a"));
        sink.stdout_line(Bytes::from_static(b"b"));
        sink.stdout_line(Bytes::from_static(b"c"));

        let mut seqs = Vec::new();
        for _ in 0..3 {
            if let OutputEvent::Chunk(c) = rx.recv().await.unwrap() {
                seqs.push(c.seq);
            }
        }
        assert_eq!(seqs, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn output_sink_seq_is_independent_per_stream() {
        let (tx, mut rx) = broadcast::channel::<OutputEvent>(16);
        let sink = OutputSink::new(tx, 1);

        sink.stdout_line(Bytes::from_static(b"o1"));
        sink.stderr_line(Bytes::from_static(b"e1"));
        sink.stdout_line(Bytes::from_static(b"o2"));
        sink.stderr_line(Bytes::from_static(b"e2"));

        let mut stdout_seqs = Vec::new();
        let mut stderr_seqs = Vec::new();
        for _ in 0..4 {
            if let OutputEvent::Chunk(c) = rx.recv().await.unwrap() {
                match c.stream {
                    StreamKind::Stdout => stdout_seqs.push(c.seq),
                    StreamKind::Stderr => stderr_seqs.push(c.seq),
                }
            }
        }
        assert_eq!(stdout_seqs, vec![0, 1]);
        assert_eq!(stderr_seqs, vec![0, 1]);
    }

    #[tokio::test]
    async fn output_sink_does_not_panic_without_subscribers() {
        let (tx, _) = broadcast::channel::<OutputEvent>(16);
        let sink = OutputSink::new(tx, 1);

        sink.stdout_line(Bytes::from_static(b"nobody-listens"));
        sink.stderr_line(Bytes::from_static(b"still-no-one"));
    }

    #[tokio::test]
    async fn output_sink_fans_out_to_multiple_subscribers() {
        let (tx, mut rx1) = broadcast::channel::<OutputEvent>(16);
        let mut rx2 = tx.subscribe();
        let sink = OutputSink::new(tx, 2);

        sink.stdout_line(Bytes::from_static(b"hello"));

        for rx in [&mut rx1, &mut rx2] {
            if let OutputEvent::Chunk(c) = rx.recv().await.unwrap() {
                assert_eq!(&c.line[..], b"hello");
            } else {
                panic!("expected Chunk");
            }
        }
    }

    #[tokio::test]
    async fn output_sink_forwards_line_to_subscribers_without_byte_copy() {
        let (tx, mut rx1) = broadcast::channel::<OutputEvent>(16);
        let mut rx2 = tx.subscribe();
        let sink = OutputSink::new(tx, 1);

        let payload = Bytes::from_static(b"shared-line");
        let payload_ptr = payload.as_ptr();
        sink.stdout_line(payload);

        for rx in [&mut rx1, &mut rx2] {
            if let OutputEvent::Chunk(c) = rx.recv().await.unwrap() {
                assert_eq!(
                    c.line.as_ptr(),
                    payload_ptr,
                    "line bytes must be shared across subscribers"
                );
            } else {
                panic!("expected Chunk");
            }
        }
    }

    #[tokio::test]
    async fn registry_subscribe_returns_none_before_first_sink_for() {
        let reg = OutputRegistry::new(16);
        let task = TaskId::from("t-1");
        assert!(reg.subscribe(&task).is_none());
    }

    #[tokio::test]
    async fn registry_subscribe_returns_some_after_sink_for() {
        let reg = OutputRegistry::new(16);
        let task = TaskId::from("t-2");
        let _sink = reg.sink_for(task.clone(), 1);
        assert!(reg.subscribe(&task).is_some());
    }

    #[tokio::test]
    async fn registry_sink_for_reuses_channel_across_attempts() {
        let reg = OutputRegistry::new(16);
        let task = TaskId::from("t-merge");

        let sink_a1 = reg.sink_for(task.clone(), 1);
        let mut rx = reg.subscribe(&task).unwrap();

        sink_a1.stdout_line(Bytes::from_static(b"from-attempt-1"));

        let sink_a2 = reg.sink_for(task.clone(), 2);
        sink_a2.stdout_line(Bytes::from_static(b"from-attempt-2"));

        let mut seen = Vec::new();
        for _ in 0..2 {
            if let OutputEvent::Chunk(c) = rx.recv().await.unwrap() {
                seen.push((c.attempt, std::str::from_utf8(&c.line).unwrap().to_string()));
            }
        }
        assert_eq!(
            seen,
            vec![
                (1u32, "from-attempt-1".to_string()),
                (2u32, "from-attempt-2".to_string()),
            ]
        );
    }

    #[tokio::test]
    async fn registry_announce_run_started_emits_boundary_event() {
        let reg = OutputRegistry::new(16);
        let task = TaskId::from("t-bound");
        let _sink = reg.sink_for(task.clone(), 1);
        let mut rx = reg.subscribe(&task).unwrap();

        reg.announce_run_started(&task, 1);

        match rx.recv().await.unwrap() {
            OutputEvent::RunStarted { attempt, .. } => assert_eq!(attempt, 1),
            other => panic!("expected RunStarted, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn registry_announce_run_finished_carries_exit_code() {
        let reg = OutputRegistry::new(16);
        let task = TaskId::from("t-fin");
        let _sink = reg.sink_for(task.clone(), 3);
        let mut rx = reg.subscribe(&task).unwrap();

        reg.announce_run_finished(&task, 3, Some(0));

        match rx.recv().await.unwrap() {
            OutputEvent::RunFinished {
                attempt, exit_code, ..
            } => {
                assert_eq!(attempt, 3);
                assert_eq!(exit_code, Some(0));
            }
            other => panic!("expected RunFinished, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn registry_evict_drops_channel() {
        let reg = OutputRegistry::new(16);
        let task = TaskId::from("t-evict");
        let _sink = reg.sink_for(task.clone(), 1);
        assert!(reg.subscribe(&task).is_some());

        reg.evict(&task);
        assert!(reg.subscribe(&task).is_none());
    }

    #[tokio::test]
    async fn registry_active_channels_reflects_state() {
        let reg = OutputRegistry::new(16);
        assert_eq!(reg.active_channels(), 0);

        let _ = reg.sink_for(TaskId::from("a"), 1);
        let _ = reg.sink_for(TaskId::from("b"), 1);
        assert_eq!(reg.active_channels(), 2);

        reg.evict(&TaskId::from("a"));
        assert_eq!(reg.active_channels(), 1);
    }

    #[tokio::test]
    async fn registry_ensure_channel_creates_subscribable_channel() {
        let reg = OutputRegistry::new(16);
        let task = TaskId::from("t-ensure");
        assert!(reg.subscribe(&task).is_none());

        reg.ensure_channel(task.clone());
        assert!(reg.subscribe(&task).is_some());
    }

    #[tokio::test]
    async fn registry_ensure_channel_is_idempotent() {
        let reg = OutputRegistry::new(16);
        let task = TaskId::from("t-idem");

        reg.ensure_channel(task.clone());
        let mut rx = reg.subscribe(&task).unwrap();

        // A second call must keep the existing channel. If it replaced it, the
        // old sender would be dropped and `rx` would see `Closed` instead of
        // the line published below, so we assert on an actual delivery rather
        // than on `try_recv().is_err()` (which holds in both cases).
        reg.ensure_channel(task.clone());
        assert_eq!(reg.active_channels(), 1);

        let sink = reg.sink_for(task.clone(), 1);
        sink.stdout_line(Bytes::from_static(b"still-connected"));

        match rx.try_recv() {
            Ok(OutputEvent::Chunk(c)) => assert_eq!(&c.line[..], b"still-connected"),
            other => panic!("expected Chunk on the original receiver, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn registry_announce_without_channel_is_noop() {
        let reg = OutputRegistry::new(16);
        let task = TaskId::from("t-ghost");

        reg.announce_run_started(&task, 1);
        reg.announce_run_finished(&task, 1, None);

        assert!(
            reg.subscribe(&task).is_none(),
            "must not auto-create channel"
        );
    }
}