akribes_sdk/sub/run_stream.rs
1//! [`RunStream`] — a handle that wraps a script run together with its SSE
2//! event stream, translates the wire events into [`WorkflowEvent`]s and
3//! detects terminal events so callers can `await` a final output without
4//! hand-rolling a 30-line receiver loop.
5
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11use futures::Stream;
12use tokio::sync::{mpsc, oneshot};
13
14use crate::client::Inner;
15use crate::error::{AkribesError, Result};
16use crate::events::WorkflowEvent;
17use crate::models::HubEvent;
18use crate::sub::events::{EventSubscription, stream_sse_with_retry};
19use crate::sub::executions::RunBuilder;
20use crate::suspend::SuspendTrigger;
21
22// ── Callback payloads ────────────────────────────────────────────────────────
23//
24// Owned snapshots passed to category callbacks. Decoupling these from
25// `WorkflowEvent` variants lets us add fields to the variants without
26// breaking callback signatures.
27
/// Payload passed to `on_task_end` callbacks.
///
/// An owned snapshot of a `WorkflowEvent::TaskEnd` event: every field is
/// cloned (or copied) from the event field of the same name right before the
/// callbacks run, so a callback never borrows from the event itself.
#[derive(Debug, Clone)]
pub struct TaskEndPayload {
    /// The event's `task` field. `RunStream::summary` uses the same value as
    /// the key for its per-task aggregation.
    pub task: String,
    /// The event's `output` field, as raw JSON.
    pub output: serde_json::Value,
    /// The event's `duration` field.
    pub duration: Duration,
    /// Token usage attached to the event, if any.
    pub usage: Option<akribes_types::event::TokenUsage>,
    /// The event's `variant` field; `RunStream::summary` counts a task as
    /// passed only when its latest variant is `TaskEndVariant::Success`.
    pub variant: crate::task_end::TaskEndVariant,
}
37
/// Payload passed to `on_suspend` callbacks (mirrors a `Checkpoint` event).
///
/// An owned snapshot: every field is cloned (or copied) from the
/// `WorkflowEvent::Checkpoint` field of the same name right before the
/// callbacks run.
#[derive(Debug, Clone)]
pub struct SuspendPayload {
    /// The event's `name` field.
    pub name: String,
    /// The event's `token` field.
    // NOTE(review): presumably the token used to resume the suspended run —
    // confirm against the `Checkpoint` event definition.
    pub token: String,
    /// The event's `prompt` field.
    pub prompt: String,
    /// The event's `schema` field, as raw JSON.
    pub schema: serde_json::Value,
    /// The event's `timeout_secs` field; `None` when the event carried none.
    pub timeout_secs: Option<u64>,
    /// The event's `trigger` field.
    pub trigger: SuspendTrigger,
}
48
/// Payload passed to `on_error` callbacks.
///
/// An owned snapshot of the `message` and `kind` fields of a
/// `WorkflowEvent::Error` event.
#[derive(Debug, Clone)]
pub struct EngineErrorPayload {
    /// The event's `message` field.
    pub message: String,
    /// The event's `kind` field. The same kind decides which `AkribesError`
    /// variant `RunStream::output` / `RunStream::summary` resolve to.
    pub kind: akribes_types::error::ErrorKind,
}
55
// Boxed callback aliases, one per callback category stored on `RunStream`.
// `Send` so callbacks can be registered from one task and the stream polled
// on another (common in async runtimes). `'static` so each boxed closure owns
// everything it captures and can live as long as the stream that stores it.
type OutputCb = Box<dyn Fn(&serde_json::Value) + Send + 'static>;
type TaskEndCb = Box<dyn Fn(&TaskEndPayload) + Send + 'static>;
type SuspendCb = Box<dyn Fn(&SuspendPayload) + Send + 'static>;
type ErrorCb = Box<dyn Fn(&EngineErrorPayload) + Send + 'static>;
type AnyCb = Box<dyn Fn(&WorkflowEvent) + Send + 'static>;
63
/// A live handle to a running workflow execution.
///
/// Obtain one from [`crate::sub::executions::ScopedExecutionsClient::run_stream`].
/// The stream yields [`WorkflowEvent`] items until the workflow reaches a
/// terminal event (`End` or `Error`), at which point it ends. Call
/// [`output`](Self::output) to consume the stream to completion and get the
/// final workflow output (or an error).
///
/// Dropping the `RunStream` cancels the underlying SSE subscription.
pub struct RunStream {
    /// Execution id reported by the run request that started this stream.
    pub execution_id: String,
    // Receiving half of the channel fed by the filter/translator task; each
    // item is either a translated event or an error.
    rx: mpsc::UnboundedReceiver<Result<WorkflowEvent>>,
    // Held for cancel-on-drop semantics; the background SSE listener AND
    // the filter/translator task are both aborted when this field is dropped.
    _subscription: EventSubscription,
    // Set to true once the stream has terminated (End or Error observed,
    // an error item received, or the channel closed).
    terminated: bool,
    // Populated when a `WorkflowEvent::End` is yielded, so `output()` can
    // resolve to the final output without re-reading the stream.
    final_output: Option<serde_json::Value>,
    // Populated when a `WorkflowEvent::Error` is yielded.
    final_error: Option<(String, akribes_types::error::ErrorKind)>,
    // ── Callback hooks ──────────────────────────────────────────────────
    //
    // Each list is invoked in registration order while the polling
    // thread holds &mut self. Callbacks must be `Send` so the
    // `RunStream` itself stays `Send`, but they execute *synchronously*
    // on the polling thread — long-running work belongs in a spawned
    // task. See the per-method docs for the contract.
    on_output_cbs: Vec<OutputCb>,
    on_task_end_cbs: Vec<TaskEndCb>,
    on_suspend_cbs: Vec<SuspendCb>,
    on_error_cbs: Vec<ErrorCb>,
    on_any_cbs: Vec<AnyCb>,
}
100
101impl std::fmt::Debug for RunStream {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 f.debug_struct("RunStream")
104 .field("execution_id", &self.execution_id)
105 .field("terminated", &self.terminated)
106 .finish()
107 }
108}
109
110impl RunStream {
111 /// Wire up a run stream from its pieces. Usually you don't call this
112 /// directly — see [`ScopedExecutionsClient::run_stream`].
113 ///
114 /// [`ScopedExecutionsClient::run_stream`]:
115 /// crate::sub::executions::ScopedExecutionsClient::run_stream
116 pub(crate) fn new(
117 execution_id: String,
118 rx: mpsc::UnboundedReceiver<Result<WorkflowEvent>>,
119 subscription: EventSubscription,
120 ) -> Self {
121 Self {
122 execution_id,
123 rx,
124 _subscription: subscription,
125 terminated: false,
126 final_output: None,
127 final_error: None,
128 on_output_cbs: Vec::new(),
129 on_task_end_cbs: Vec::new(),
130 on_suspend_cbs: Vec::new(),
131 on_error_cbs: Vec::new(),
132 on_any_cbs: Vec::new(),
133 }
134 }
135
136 // ── Callback registration ───────────────────────────────────────────
137 //
138 // The callback API is convenience sugar layered over the iterator —
139 // every event still flows through `next()` / `poll_next()`. Use it when
140 // you want fire-and-forget sinks (logging, metrics, UI updates) without
141 // hand-rolling a match-arm loop.
142 //
143 // **Threading.** Callbacks must be `Send + 'static` because `RunStream`
144 // itself is `Send` and may be polled across thread boundaries by the
145 // async runtime. They run synchronously on the polling thread between
146 // the time an event arrives and the time it's yielded to the caller —
147 // **don't block, sleep, or `.await` inside them**. If you need to do
148 // I/O, spawn a task or push onto a channel.
149 //
150 // Callbacks fire in registration order; multiple callbacks per category
151 // are supported. Calls are additive: there is no `clear` or `replace`
152 // helper today (the iterator is the canonical surface; callbacks are
153 // best registered once at stream construction).
154
155 /// Register a callback for streaming agent output chunks.
156 ///
157 /// Fires once per [`WorkflowEvent::AgentChunk`]. The callback receives
158 /// the chunk text wrapped in a `serde_json::Value::String` so the API
159 /// stays uniform across SDKs (TS/Python use `Value`-like shapes too).
160 pub fn on_output<F>(&mut self, cb: F)
161 where
162 F: Fn(&serde_json::Value) + Send + 'static,
163 {
164 self.on_output_cbs.push(Box::new(cb));
165 }
166
167 /// Register a callback for task completion events ([`WorkflowEvent::TaskEnd`]).
168 pub fn on_task_end<F>(&mut self, cb: F)
169 where
170 F: Fn(&TaskEndPayload) + Send + 'static,
171 {
172 self.on_task_end_cbs.push(Box::new(cb));
173 }
174
175 /// Register a callback for workflow suspensions ([`WorkflowEvent::Checkpoint`]).
176 ///
177 /// `WorkflowEvent::ToolApproval` and `Breakpoint` are not routed here —
178 /// register `on_any` if you need to observe every suspend-category event.
179 pub fn on_suspend<F>(&mut self, cb: F)
180 where
181 F: Fn(&SuspendPayload) + Send + 'static,
182 {
183 self.on_suspend_cbs.push(Box::new(cb));
184 }
185
186 /// Register a callback for terminal error events ([`WorkflowEvent::Error`]).
187 ///
188 /// The stream still terminates on the next poll after an error; the
189 /// callback fires once, before termination is observed.
190 pub fn on_error<F>(&mut self, cb: F)
191 where
192 F: Fn(&EngineErrorPayload) + Send + 'static,
193 {
194 self.on_error_cbs.push(Box::new(cb));
195 }
196
197 /// Register a catch-all callback that sees every yielded event.
198 ///
199 /// Fires *after* category-specific callbacks for the same event, in
200 /// registration order. Use for logging or generic event sinks; prefer
201 /// the category callbacks for typed access.
202 pub fn on_any<F>(&mut self, cb: F)
203 where
204 F: Fn(&WorkflowEvent) + Send + 'static,
205 {
206 self.on_any_cbs.push(Box::new(cb));
207 }
208
209 /// Dispatch the configured callbacks for one event. Called from both
210 /// `next()` and `poll_next()` after a successful receive.
211 fn dispatch_callbacks(&self, evt: &WorkflowEvent) {
212 match evt {
213 WorkflowEvent::AgentChunk { chunk, .. } if !self.on_output_cbs.is_empty() => {
214 let v = serde_json::Value::String(chunk.clone());
215 for cb in &self.on_output_cbs {
216 cb(&v);
217 }
218 }
219 WorkflowEvent::TaskEnd {
220 task,
221 output,
222 duration,
223 usage,
224 variant,
225 } if !self.on_task_end_cbs.is_empty() => {
226 let payload = TaskEndPayload {
227 task: task.clone(),
228 output: output.clone(),
229 duration: *duration,
230 usage: usage.clone(),
231 variant: *variant,
232 };
233 for cb in &self.on_task_end_cbs {
234 cb(&payload);
235 }
236 }
237 WorkflowEvent::Checkpoint {
238 name,
239 token,
240 prompt,
241 schema,
242 timeout_secs,
243 trigger,
244 } if !self.on_suspend_cbs.is_empty() => {
245 let payload = SuspendPayload {
246 name: name.clone(),
247 token: token.clone(),
248 prompt: prompt.clone(),
249 schema: schema.clone(),
250 timeout_secs: *timeout_secs,
251 trigger: trigger.clone(),
252 };
253 for cb in &self.on_suspend_cbs {
254 cb(&payload);
255 }
256 }
257 WorkflowEvent::Error { message, kind, .. } if !self.on_error_cbs.is_empty() => {
258 let payload = EngineErrorPayload {
259 message: message.clone(),
260 kind: *kind,
261 };
262 for cb in &self.on_error_cbs {
263 cb(&payload);
264 }
265 }
266 _ => {}
267 }
268 for cb in &self.on_any_cbs {
269 cb(evt);
270 }
271 }
272
273 /// Pull the next typed event. Returns `None` once the stream terminates.
274 ///
275 /// `Error` events are yielded as `Some(Ok(WorkflowEvent::Error{..}))` so
276 /// the caller can observe them, and they also cause the stream to end
277 /// immediately after.
278 pub async fn next(&mut self) -> Option<Result<WorkflowEvent>> {
279 if self.terminated {
280 return None;
281 }
282 match self.rx.recv().await {
283 Some(Ok(evt)) => {
284 // Capture terminal state before yielding so `output()` can
285 // resolve cheaply afterwards.
286 match &evt {
287 WorkflowEvent::End { output, .. } => {
288 self.final_output = Some(output.clone());
289 self.terminated = true;
290 }
291 WorkflowEvent::Error { message, kind, .. } => {
292 self.final_error = Some((message.clone(), *kind));
293 self.terminated = true;
294 }
295 _ => {}
296 }
297 self.dispatch_callbacks(&evt);
298 Some(Ok(evt))
299 }
300 Some(Err(e)) => {
301 self.terminated = true;
302 Some(Err(e))
303 }
304 None => {
305 self.terminated = true;
306 None
307 }
308 }
309 }
310
311 /// Drain the stream and resolve to the final workflow output.
312 ///
313 /// Resolves to `Ok(output)` when a `WorkflowEvent::End` was observed
314 /// (either already, or while draining). If the workflow ended with an
315 /// `Error` event, resolves to an [`AkribesError::Script`] / `::Transient` /
316 /// `::Fatal` depending on the `ErrorKind` — same classification as
317 /// [`crate::sub::executions::ExecutionsClient::await_execution`]. If the
318 /// stream closes without a terminal event, resolves to
319 /// [`AkribesError::Other`].
320 pub async fn output(mut self) -> Result<serde_json::Value> {
321 while !self.terminated {
322 if self.next().await.is_none() {
323 break;
324 }
325 }
326
327 if let Some(out) = self.final_output.take() {
328 return Ok(out);
329 }
330 if let Some((message, kind)) = self.final_error.take() {
331 return Err(classify_error(message, kind, self.execution_id.clone()));
332 }
333 Err(AkribesError::Other(format!(
334 "run stream for execution {} ended without a terminal event",
335 self.execution_id
336 )))
337 }
338
339 /// Drain the stream to terminal and return a [`RunSummary`] aggregated
340 /// from observed events (#1033 — mirrors TS `RunStream.summary()`).
341 ///
342 /// Resolves the same way as [`output`](Self::output): rejects when the
343 /// workflow ended with an `Error` event, or when the stream closed
344 /// without a terminal event. On success, the returned `RunSummary`
345 /// rolls up workflow duration, per-task durations, task pass/fail
346 /// counts, and per-model token totals collected from `TaskEnd` usage
347 /// blocks.
348 pub async fn summary(mut self) -> Result<RunSummary> {
349 let mut total: Duration = Duration::ZERO;
350 let mut per_task_ms: std::collections::HashMap<String, u128> =
351 std::collections::HashMap::new();
352 // `passed` / `failed` is determined by the last variant we see for
353 // each task — `unable` overrides a prior success on retry (matches
354 // TS).
355 let mut tasks_status: std::collections::HashMap<String, bool> =
356 std::collections::HashMap::new();
357 let mut by_model_tokens: std::collections::HashMap<String, u64> =
358 std::collections::HashMap::new();
359 let mut usage_observed = false;
360 let mut mock_observed = false;
361 let mut final_output: Option<serde_json::Value> = None;
362
363 while !self.terminated {
364 match self.next().await {
365 Some(Ok(evt)) => match &evt {
366 WorkflowEvent::End {
367 output, duration, ..
368 } => {
369 total = *duration;
370 final_output = Some(output.clone());
371 }
372 WorkflowEvent::TaskEnd {
373 task,
374 duration,
375 usage,
376 variant,
377 ..
378 } => {
379 *per_task_ms.entry(task.clone()).or_insert(0) += duration.as_millis();
380 // Latest variant wins.
381 let passed = matches!(variant, crate::task_end::TaskEndVariant::Success);
382 tasks_status.insert(task.clone(), passed);
383 if let Some(u) = usage {
384 usage_observed = true;
385 if u.provider == "mock" {
386 mock_observed = true;
387 }
388 let tokens = u.input_tokens.saturating_add(u.output_tokens);
389 let model = if u.model.is_empty() {
390 "unknown".to_string()
391 } else {
392 u.model.clone()
393 };
394 *by_model_tokens.entry(model).or_insert(0) += tokens;
395 }
396 }
397 _ => {}
398 },
399 Some(Err(e)) => return Err(e),
400 None => break,
401 }
402 }
403
404 if let Some((message, kind)) = self.final_error.take() {
405 return Err(classify_error(message, kind, self.execution_id.clone()));
406 }
407 let Some(out) = final_output.or(self.final_output.take()) else {
408 return Err(AkribesError::Other(format!(
409 "run stream for execution {} ended without a terminal event",
410 self.execution_id
411 )));
412 };
413
414 let total_tasks = tasks_status.len();
415 let passed = tasks_status.values().filter(|p| **p).count();
416 let failed = total_tasks - passed;
417
418 // Mirrors TS: when we have no real usage signal (no usage block, or
419 // the engine reported the `mock` provider) we report `cost = None`.
420 // When usage is real, `by_model` carries the total (input + output)
421 // token count per model; `total_usd` stays 0 until a pricing table
422 // is wired in.
423 let cost = if !usage_observed || mock_observed {
424 None
425 } else {
426 Some(RunSummaryCost {
427 total_usd: 0.0,
428 by_model: by_model_tokens,
429 })
430 };
431
432 Ok(RunSummary {
433 execution_id: self.execution_id.clone(),
434 output: out,
435 cost,
436 duration: RunSummaryDuration {
437 total_ms: total.as_millis(),
438 per_task_ms,
439 },
440 tasks: RunSummaryTasks {
441 passed,
442 failed,
443 total: total_tasks,
444 },
445 })
446 }
447}
448
/// Aggregated summary of a run, returned by [`RunStream::summary`] (#1033).
/// Mirrors TS `RunSummary` from `runStream.ts`.
#[derive(Debug, Clone)]
pub struct RunSummary {
    /// Execution id of the summarised run (copied from the `RunStream`).
    pub execution_id: String,
    /// Final workflow output, taken from the `End` event.
    pub output: serde_json::Value,
    /// `None` when the stream observed no usage (`TaskEnd.usage` was
    /// absent or the engine reported the `mock` provider). When `Some`,
    /// the SDK currently leaves `total_usd` at 0 — `by_model` carries the
    /// raw (input + output) token total per model so callers can multiply
    /// by their own pricing table.
    pub cost: Option<RunSummaryCost>,
    /// Workflow-level and per-task durations, in milliseconds.
    pub duration: RunSummaryDuration,
    /// Pass/fail counts over the distinct tasks that reported a `TaskEnd`.
    pub tasks: RunSummaryTasks,
}
464
/// Token/cost roll-up for a run, carried by [`RunSummary::cost`].
///
/// Derives `PartialEq` and `Default` so summaries can be compared in tests
/// and built incrementally. There is no `Eq` because `total_usd` is a float.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct RunSummaryCost {
    /// Total cost in USD. `RunStream::summary` currently always sets this to
    /// `0.0` (no pricing table is wired in).
    pub total_usd: f64,
    /// Total (input + output) tokens per model name. Usage blocks with an
    /// empty model name are accumulated under the key `"unknown"`.
    pub by_model: std::collections::HashMap<String, u64>,
}
470
/// Duration roll-up for a run, carried by [`RunSummary::duration`].
///
/// Derives `PartialEq`, `Eq` and `Default` so summaries can be compared in
/// tests and built incrementally.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct RunSummaryDuration {
    /// Workflow duration in milliseconds, taken from the `End` event.
    pub total_ms: u128,
    /// Milliseconds per task, summed over every `TaskEnd` observed for that
    /// task (so retries of the same task accumulate).
    pub per_task_ms: std::collections::HashMap<String, u128>,
}
476
/// Task pass/fail counts for a run, carried by [`RunSummary::tasks`].
///
/// Counts are over distinct task names; a task's latest `TaskEnd` variant
/// decides whether it passed. Derives `PartialEq`, `Eq` and `Default` so
/// summaries can be compared in tests and built incrementally.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct RunSummaryTasks {
    /// Tasks whose latest variant was `Success`.
    pub passed: usize,
    /// Tasks whose latest variant was anything else (`total - passed`).
    pub failed: usize,
    /// Number of distinct tasks that reported at least one `TaskEnd`.
    pub total: usize,
}
483
484impl Stream for RunStream {
485 type Item = Result<WorkflowEvent>;
486
487 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
488 let this = self.get_mut();
489 if this.terminated {
490 return Poll::Ready(None);
491 }
492 match this.rx.poll_recv(cx) {
493 Poll::Ready(Some(Ok(evt))) => {
494 match &evt {
495 WorkflowEvent::End { output, .. } => {
496 this.final_output = Some(output.clone());
497 this.terminated = true;
498 }
499 WorkflowEvent::Error { message, kind, .. } => {
500 this.final_error = Some((message.clone(), *kind));
501 this.terminated = true;
502 }
503 _ => {}
504 }
505 this.dispatch_callbacks(&evt);
506 Poll::Ready(Some(Ok(evt)))
507 }
508 Poll::Ready(Some(Err(e))) => {
509 this.terminated = true;
510 Poll::Ready(Some(Err(e)))
511 }
512 Poll::Ready(None) => {
513 this.terminated = true;
514 Poll::Ready(None)
515 }
516 Poll::Pending => Poll::Pending,
517 }
518 }
519}
520
521fn classify_error(
522 message: String,
523 kind: akribes_types::error::ErrorKind,
524 execution_id: String,
525) -> AkribesError {
526 use akribes_types::error::ErrorKind;
527 let eid = Some(execution_id);
528 match kind {
529 ErrorKind::RateLimit
530 | ErrorKind::ServerError500
531 | ErrorKind::BadGateway502
532 | ErrorKind::ServiceUnavailable503
533 | ErrorKind::GatewayTimeout504
534 | ErrorKind::NetworkError => {
535 // #1296: surface the status when the kind maps cleanly so
536 // callers can prefer the per-status base backoff over the
537 // sdk-wide default.
538 let status = match kind {
539 ErrorKind::RateLimit => Some(429u16),
540 ErrorKind::ServerError500 => Some(500u16),
541 ErrorKind::BadGateway502 => Some(502u16),
542 ErrorKind::ServiceUnavailable503 => Some(503u16),
543 ErrorKind::GatewayTimeout504 => Some(504u16),
544 _ => None,
545 };
546 AkribesError::Transient {
547 message,
548 execution_id: eid,
549 retry_after: None,
550 status,
551 }
552 }
553 ErrorKind::AuthError | ErrorKind::TokenLimit => AkribesError::Fatal {
554 message,
555 execution_id: eid,
556 },
557 _ => AkribesError::Script {
558 message,
559 execution_id: eid,
560 },
561 }
562}
563
564// ── Assembling a RunStream ──────────────────────────────────────────────────
565
566/// Start an SSE subscription filtered to the given script, then kick off the
567/// run and return a [`RunStream`] wired to translate `HubEvent::Execution`
568/// payloads into [`WorkflowEvent`]s.
569///
570/// Subscribes to SSE *first*, waits for the subscription to be live on the
571/// server (ready signal), then POSTs `/run`. This avoids the race where
572/// opening events broadcast by the hub are lost if the GET /events handshake
573/// hasn't completed when the POST response fires.
574///
575/// Called by [`ScopedExecutionsClient::run_stream`].
576pub(crate) async fn start_run_stream(
577 inner: Arc<Inner>,
578 project_id: i64,
579 builder: RunBuilder,
580) -> Result<RunStream> {
581 let script_name = builder.script_name().to_string();
582
583 // ── 1. Spawn the SSE listener with a ready-signal oneshot. Wait for the
584 // server to confirm the subscription before POSTing `/run`.
585 let (hub_tx, mut hub_rx) = mpsc::unbounded_channel();
586 let (ready_tx, ready_rx) = oneshot::channel::<Result<()>>();
587 let http = inner.http.clone();
588 let token = inner.token.clone();
589 let base_url = inner.base_url.clone();
590 let script_for_sse = script_name.clone();
591 let sse_handle = tokio::spawn(async move {
592 let _ = stream_sse_with_retry(
593 http,
594 token,
595 base_url,
596 project_id,
597 Some(script_for_sse),
598 hub_tx,
599 Some(ready_tx),
600 )
601 .await;
602 });
603
604 // Wait for "subscribed" signal. If the server rejects the subscription
605 // or the task dies before firing, surface the error to the caller.
606 match ready_rx.await {
607 Ok(Ok(())) => {}
608 Ok(Err(e)) => {
609 sse_handle.abort();
610 return Err(e);
611 }
612 Err(_) => {
613 sse_handle.abort();
614 return Err(AkribesError::Other(
615 "SSE listener died before subscription was confirmed".into(),
616 ));
617 }
618 }
619
620 // ── 2. Kick off the run now that we're guaranteed to receive events.
621 let run = match builder.execute().await {
622 Ok(r) => r,
623 Err(e) => {
624 sse_handle.abort();
625 return Err(e);
626 }
627 };
628 let execution_id = run.execution_id;
629
630 // ── 3. Filter-and-translate task: pull `HubEvent::Execution` entries
631 // whose script_name AND execution_id match this run, convert
632 // them to WorkflowEvent, and forward. Stop as soon as a
633 // terminal event is seen.
634 //
635 // Filtering by script name alone would conflate concurrent runs of
636 // the same script started by another caller — their `WorkflowEnd`
637 // would resolve this handle's `output()` with the wrong value.
638 // Matches the TS SDK's `RunStream` execution-id filter (see
639 // `packages/akribes-sdk-ts/src/runStream.ts::routeRaw`).
640 // Pre-#1042 servers that don't stamp `execution_id` on the
641 // broadcast envelope still flow through (back-compat: `None`
642 // matches anything) — but every server in production today does.
643 let (out_tx, out_rx) = mpsc::unbounded_channel::<Result<WorkflowEvent>>();
644 let script_for_filter = script_name.clone();
645 let exec_id_for_filter = execution_id.clone();
646 let filter_handle = tokio::spawn(async move {
647 while let Some(hub) = hub_rx.recv().await {
648 if let HubEvent::Execution {
649 script_name: evt_script,
650 execution_id: evt_exec_id,
651 event,
652 ..
653 } = hub
654 {
655 if evt_script != script_for_filter {
656 continue;
657 }
658 if let Some(eid) = evt_exec_id {
659 if eid != exec_id_for_filter {
660 continue;
661 }
662 }
663 let wf: WorkflowEvent = event.into();
664 let is_terminal = wf.is_terminal();
665 if out_tx.send(Ok(wf)).is_err() {
666 break;
667 }
668 if is_terminal {
669 break;
670 }
671 }
672 }
673 });
674
675 // Drop-guard: both the SSE listener AND the filter task abort when the
676 // RunStream is dropped. Previously only the filter was tracked, which
677 // leaked the SSE task whenever a RunStream was dropped pre-terminal.
678 let subscription = EventSubscription::from_handles(vec![sse_handle, filter_handle]);
679 Ok(RunStream::new(execution_id, out_rx, subscription))
680}