Skip to main content

akribes_sdk/sub/
run_stream.rs

1//! [`RunStream`] — a handle that wraps a script run together with its SSE
2//! event stream, translates the wire events into [`WorkflowEvent`]s and
3//! detects terminal events so callers can `await` a final output without
4//! hand-rolling a 30-line receiver loop.
5
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11use futures::Stream;
12use tokio::sync::{mpsc, oneshot};
13
14use crate::client::Inner;
15use crate::error::{AkribesError, Result};
16use crate::events::WorkflowEvent;
17use crate::models::HubEvent;
18use crate::sub::events::{EventSubscription, stream_sse_with_retry};
19use crate::sub::executions::RunBuilder;
20use crate::suspend::SuspendTrigger;
21
22// ── Callback payloads ────────────────────────────────────────────────────────
23//
24// Owned snapshots passed to category callbacks. Decoupling these from
25// `WorkflowEvent` variants lets us add fields to the variants without
26// breaking callback signatures.
27
28/// Payload passed to `on_task_end` callbacks.
29#[derive(Debug, Clone)]
30pub struct TaskEndPayload {
31    pub task: String,
32    pub output: serde_json::Value,
33    pub duration: Duration,
34    pub usage: Option<akribes_types::event::TokenUsage>,
35    pub variant: crate::task_end::TaskEndVariant,
36}
37
38/// Payload passed to `on_suspend` callbacks (mirrors a `Checkpoint` event).
39#[derive(Debug, Clone)]
40pub struct SuspendPayload {
41    pub name: String,
42    pub token: String,
43    pub prompt: String,
44    pub schema: serde_json::Value,
45    pub timeout_secs: Option<u64>,
46    pub trigger: SuspendTrigger,
47}
48
49/// Payload passed to `on_error` callbacks.
50#[derive(Debug, Clone)]
51pub struct EngineErrorPayload {
52    pub message: String,
53    pub kind: akribes_types::error::ErrorKind,
54}
55
56// Boxed callback aliases. `Send` so callbacks can be registered from one
57// task and the stream polled on another (common in async runtimes).
58type OutputCb = Box<dyn Fn(&serde_json::Value) + Send + 'static>;
59type TaskEndCb = Box<dyn Fn(&TaskEndPayload) + Send + 'static>;
60type SuspendCb = Box<dyn Fn(&SuspendPayload) + Send + 'static>;
61type ErrorCb = Box<dyn Fn(&EngineErrorPayload) + Send + 'static>;
62type AnyCb = Box<dyn Fn(&WorkflowEvent) + Send + 'static>;
63
/// A live handle to a running workflow execution.
///
/// Obtain one from [`crate::sub::executions::ScopedExecutionsClient::run_stream`].
/// The stream yields [`WorkflowEvent`] items until the workflow reaches a
/// terminal event (`End` or `Error`), at which point it ends. Call
/// [`output`](Self::output) to consume the stream to completion and get the
/// final workflow output (or an error).
///
/// Events can be pulled with the inherent async [`next`](Self::next) or
/// through the [`Stream`] impl; both record terminal state and fire the
/// registered callbacks.
///
/// Dropping the `RunStream` cancels the underlying SSE subscription.
pub struct RunStream {
    /// Id of the execution this handle follows.
    pub execution_id: String,
    // Translated events. An `Err` item or a closed channel ends the stream.
    rx: mpsc::UnboundedReceiver<Result<WorkflowEvent>>,
    // Held for cancel-on-drop semantics; the background SSE listener AND
    // the filter/translator task are both aborted when this field is dropped.
    _subscription: EventSubscription,
    // Set to true once the stream has terminated: an `End` or `Error` event
    // was yielded, an `Err` item was yielded, or the channel closed.
    terminated: bool,
    // Populated when a `WorkflowEvent::End` is yielded, so `output()` can
    // resolve to the final output without re-reading the stream.
    final_output: Option<serde_json::Value>,
    // Populated when a `WorkflowEvent::Error` is yielded (message + kind),
    // later fed to `classify_error` by `output()` / `summary()`.
    final_error: Option<(String, akribes_types::error::ErrorKind)>,
    // ── Callback hooks ──────────────────────────────────────────────────
    //
    // Each list is invoked in registration order while the polling
    // thread holds &mut self. Callbacks must be `Send` so the
    // `RunStream` itself stays `Send`, but they execute *synchronously*
    // on the polling thread — long-running work belongs in a spawned
    // task. See the per-method docs for the contract.
    on_output_cbs: Vec<OutputCb>,
    on_task_end_cbs: Vec<TaskEndCb>,
    on_suspend_cbs: Vec<SuspendCb>,
    on_error_cbs: Vec<ErrorCb>,
    on_any_cbs: Vec<AnyCb>,
}
100
101impl std::fmt::Debug for RunStream {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        f.debug_struct("RunStream")
104            .field("execution_id", &self.execution_id)
105            .field("terminated", &self.terminated)
106            .finish()
107    }
108}
109
110impl RunStream {
111    /// Wire up a run stream from its pieces. Usually you don't call this
112    /// directly — see [`ScopedExecutionsClient::run_stream`].
113    ///
114    /// [`ScopedExecutionsClient::run_stream`]:
115    ///     crate::sub::executions::ScopedExecutionsClient::run_stream
116    pub(crate) fn new(
117        execution_id: String,
118        rx: mpsc::UnboundedReceiver<Result<WorkflowEvent>>,
119        subscription: EventSubscription,
120    ) -> Self {
121        Self {
122            execution_id,
123            rx,
124            _subscription: subscription,
125            terminated: false,
126            final_output: None,
127            final_error: None,
128            on_output_cbs: Vec::new(),
129            on_task_end_cbs: Vec::new(),
130            on_suspend_cbs: Vec::new(),
131            on_error_cbs: Vec::new(),
132            on_any_cbs: Vec::new(),
133        }
134    }
135
136    // ── Callback registration ───────────────────────────────────────────
137    //
138    // The callback API is convenience sugar layered over the iterator —
139    // every event still flows through `next()` / `poll_next()`. Use it when
140    // you want fire-and-forget sinks (logging, metrics, UI updates) without
141    // hand-rolling a match-arm loop.
142    //
143    // **Threading.** Callbacks must be `Send + 'static` because `RunStream`
144    // itself is `Send` and may be polled across thread boundaries by the
145    // async runtime. They run synchronously on the polling thread between
146    // the time an event arrives and the time it's yielded to the caller —
147    // **don't block, sleep, or `.await` inside them**. If you need to do
148    // I/O, spawn a task or push onto a channel.
149    //
150    // Callbacks fire in registration order; multiple callbacks per category
151    // are supported. Calls are additive: there is no `clear` or `replace`
152    // helper today (the iterator is the canonical surface; callbacks are
153    // best registered once at stream construction).
154
155    /// Register a callback for streaming agent output chunks.
156    ///
157    /// Fires once per [`WorkflowEvent::AgentChunk`]. The callback receives
158    /// the chunk text wrapped in a `serde_json::Value::String` so the API
159    /// stays uniform across SDKs (TS/Python use `Value`-like shapes too).
160    pub fn on_output<F>(&mut self, cb: F)
161    where
162        F: Fn(&serde_json::Value) + Send + 'static,
163    {
164        self.on_output_cbs.push(Box::new(cb));
165    }
166
167    /// Register a callback for task completion events ([`WorkflowEvent::TaskEnd`]).
168    pub fn on_task_end<F>(&mut self, cb: F)
169    where
170        F: Fn(&TaskEndPayload) + Send + 'static,
171    {
172        self.on_task_end_cbs.push(Box::new(cb));
173    }
174
175    /// Register a callback for workflow suspensions ([`WorkflowEvent::Checkpoint`]).
176    ///
177    /// `WorkflowEvent::ToolApproval` and `Breakpoint` are not routed here —
178    /// register `on_any` if you need to observe every suspend-category event.
179    pub fn on_suspend<F>(&mut self, cb: F)
180    where
181        F: Fn(&SuspendPayload) + Send + 'static,
182    {
183        self.on_suspend_cbs.push(Box::new(cb));
184    }
185
186    /// Register a callback for terminal error events ([`WorkflowEvent::Error`]).
187    ///
188    /// The stream still terminates on the next poll after an error; the
189    /// callback fires once, before termination is observed.
190    pub fn on_error<F>(&mut self, cb: F)
191    where
192        F: Fn(&EngineErrorPayload) + Send + 'static,
193    {
194        self.on_error_cbs.push(Box::new(cb));
195    }
196
197    /// Register a catch-all callback that sees every yielded event.
198    ///
199    /// Fires *after* category-specific callbacks for the same event, in
200    /// registration order. Use for logging or generic event sinks; prefer
201    /// the category callbacks for typed access.
202    pub fn on_any<F>(&mut self, cb: F)
203    where
204        F: Fn(&WorkflowEvent) + Send + 'static,
205    {
206        self.on_any_cbs.push(Box::new(cb));
207    }
208
209    /// Dispatch the configured callbacks for one event. Called from both
210    /// `next()` and `poll_next()` after a successful receive.
211    fn dispatch_callbacks(&self, evt: &WorkflowEvent) {
212        match evt {
213            WorkflowEvent::AgentChunk { chunk, .. } if !self.on_output_cbs.is_empty() => {
214                let v = serde_json::Value::String(chunk.clone());
215                for cb in &self.on_output_cbs {
216                    cb(&v);
217                }
218            }
219            WorkflowEvent::TaskEnd {
220                task,
221                output,
222                duration,
223                usage,
224                variant,
225            } if !self.on_task_end_cbs.is_empty() => {
226                let payload = TaskEndPayload {
227                    task: task.clone(),
228                    output: output.clone(),
229                    duration: *duration,
230                    usage: usage.clone(),
231                    variant: *variant,
232                };
233                for cb in &self.on_task_end_cbs {
234                    cb(&payload);
235                }
236            }
237            WorkflowEvent::Checkpoint {
238                name,
239                token,
240                prompt,
241                schema,
242                timeout_secs,
243                trigger,
244            } if !self.on_suspend_cbs.is_empty() => {
245                let payload = SuspendPayload {
246                    name: name.clone(),
247                    token: token.clone(),
248                    prompt: prompt.clone(),
249                    schema: schema.clone(),
250                    timeout_secs: *timeout_secs,
251                    trigger: trigger.clone(),
252                };
253                for cb in &self.on_suspend_cbs {
254                    cb(&payload);
255                }
256            }
257            WorkflowEvent::Error { message, kind, .. } if !self.on_error_cbs.is_empty() => {
258                let payload = EngineErrorPayload {
259                    message: message.clone(),
260                    kind: *kind,
261                };
262                for cb in &self.on_error_cbs {
263                    cb(&payload);
264                }
265            }
266            _ => {}
267        }
268        for cb in &self.on_any_cbs {
269            cb(evt);
270        }
271    }
272
273    /// Pull the next typed event. Returns `None` once the stream terminates.
274    ///
275    /// `Error` events are yielded as `Some(Ok(WorkflowEvent::Error{..}))` so
276    /// the caller can observe them, and they also cause the stream to end
277    /// immediately after.
278    pub async fn next(&mut self) -> Option<Result<WorkflowEvent>> {
279        if self.terminated {
280            return None;
281        }
282        match self.rx.recv().await {
283            Some(Ok(evt)) => {
284                // Capture terminal state before yielding so `output()` can
285                // resolve cheaply afterwards.
286                match &evt {
287                    WorkflowEvent::End { output, .. } => {
288                        self.final_output = Some(output.clone());
289                        self.terminated = true;
290                    }
291                    WorkflowEvent::Error { message, kind, .. } => {
292                        self.final_error = Some((message.clone(), *kind));
293                        self.terminated = true;
294                    }
295                    _ => {}
296                }
297                self.dispatch_callbacks(&evt);
298                Some(Ok(evt))
299            }
300            Some(Err(e)) => {
301                self.terminated = true;
302                Some(Err(e))
303            }
304            None => {
305                self.terminated = true;
306                None
307            }
308        }
309    }
310
311    /// Drain the stream and resolve to the final workflow output.
312    ///
313    /// Resolves to `Ok(output)` when a `WorkflowEvent::End` was observed
314    /// (either already, or while draining). If the workflow ended with an
315    /// `Error` event, resolves to an [`AkribesError::Script`] / `::Transient` /
316    /// `::Fatal` depending on the `ErrorKind` — same classification as
317    /// [`crate::sub::executions::ExecutionsClient::await_execution`]. If the
318    /// stream closes without a terminal event, resolves to
319    /// [`AkribesError::Other`].
320    pub async fn output(mut self) -> Result<serde_json::Value> {
321        while !self.terminated {
322            if self.next().await.is_none() {
323                break;
324            }
325        }
326
327        if let Some(out) = self.final_output.take() {
328            return Ok(out);
329        }
330        if let Some((message, kind)) = self.final_error.take() {
331            return Err(classify_error(message, kind, self.execution_id.clone()));
332        }
333        Err(AkribesError::Other(format!(
334            "run stream for execution {} ended without a terminal event",
335            self.execution_id
336        )))
337    }
338
339    /// Drain the stream to terminal and return a [`RunSummary`] aggregated
340    /// from observed events (#1033 — mirrors TS `RunStream.summary()`).
341    ///
342    /// Resolves the same way as [`output`](Self::output): rejects when the
343    /// workflow ended with an `Error` event, or when the stream closed
344    /// without a terminal event. On success, the returned `RunSummary`
345    /// rolls up workflow duration, per-task durations, task pass/fail
346    /// counts, and per-model token totals collected from `TaskEnd` usage
347    /// blocks.
348    pub async fn summary(mut self) -> Result<RunSummary> {
349        let mut total: Duration = Duration::ZERO;
350        let mut per_task_ms: std::collections::HashMap<String, u128> =
351            std::collections::HashMap::new();
352        // `passed` / `failed` is determined by the last variant we see for
353        // each task — `unable` overrides a prior success on retry (matches
354        // TS).
355        let mut tasks_status: std::collections::HashMap<String, bool> =
356            std::collections::HashMap::new();
357        let mut by_model_tokens: std::collections::HashMap<String, u64> =
358            std::collections::HashMap::new();
359        let mut usage_observed = false;
360        let mut mock_observed = false;
361        let mut final_output: Option<serde_json::Value> = None;
362
363        while !self.terminated {
364            match self.next().await {
365                Some(Ok(evt)) => match &evt {
366                    WorkflowEvent::End {
367                        output, duration, ..
368                    } => {
369                        total = *duration;
370                        final_output = Some(output.clone());
371                    }
372                    WorkflowEvent::TaskEnd {
373                        task,
374                        duration,
375                        usage,
376                        variant,
377                        ..
378                    } => {
379                        *per_task_ms.entry(task.clone()).or_insert(0) += duration.as_millis();
380                        // Latest variant wins.
381                        let passed = matches!(variant, crate::task_end::TaskEndVariant::Success);
382                        tasks_status.insert(task.clone(), passed);
383                        if let Some(u) = usage {
384                            usage_observed = true;
385                            if u.provider == "mock" {
386                                mock_observed = true;
387                            }
388                            let tokens = u.input_tokens.saturating_add(u.output_tokens);
389                            let model = if u.model.is_empty() {
390                                "unknown".to_string()
391                            } else {
392                                u.model.clone()
393                            };
394                            *by_model_tokens.entry(model).or_insert(0) += tokens;
395                        }
396                    }
397                    _ => {}
398                },
399                Some(Err(e)) => return Err(e),
400                None => break,
401            }
402        }
403
404        if let Some((message, kind)) = self.final_error.take() {
405            return Err(classify_error(message, kind, self.execution_id.clone()));
406        }
407        let Some(out) = final_output.or(self.final_output.take()) else {
408            return Err(AkribesError::Other(format!(
409                "run stream for execution {} ended without a terminal event",
410                self.execution_id
411            )));
412        };
413
414        let total_tasks = tasks_status.len();
415        let passed = tasks_status.values().filter(|p| **p).count();
416        let failed = total_tasks - passed;
417
418        // Mirrors TS: when we have no real usage signal (no usage block, or
419        // the engine reported the `mock` provider) we report `cost = None`.
420        // When usage is real, `by_model` carries the total (input + output)
421        // token count per model; `total_usd` stays 0 until a pricing table
422        // is wired in.
423        let cost = if !usage_observed || mock_observed {
424            None
425        } else {
426            Some(RunSummaryCost {
427                total_usd: 0.0,
428                by_model: by_model_tokens,
429            })
430        };
431
432        Ok(RunSummary {
433            execution_id: self.execution_id.clone(),
434            output: out,
435            cost,
436            duration: RunSummaryDuration {
437                total_ms: total.as_millis(),
438                per_task_ms,
439            },
440            tasks: RunSummaryTasks {
441                passed,
442                failed,
443                total: total_tasks,
444            },
445        })
446    }
447}
448
449/// Aggregated summary of a run, returned by [`RunStream::summary`] (#1033).
450/// Mirrors TS `RunSummary` from `runStream.ts`.
451#[derive(Debug, Clone)]
452pub struct RunSummary {
453    pub execution_id: String,
454    pub output: serde_json::Value,
455    /// `None` when the stream observed no usage (`TaskEnd.usage` was
456    /// absent or the engine reported the `mock` provider). When `Some`,
457    /// the SDK currently leaves `total_usd` at 0 — `by_model` carries the
458    /// raw (input + output) token total per model so callers can multiply
459    /// by their own pricing table.
460    pub cost: Option<RunSummaryCost>,
461    pub duration: RunSummaryDuration,
462    pub tasks: RunSummaryTasks,
463}
464
/// Cost section of a [`RunSummary`].
///
/// `total_usd` is currently always `0.0` (no pricing table is wired in);
/// `by_model` maps a model name (`"unknown"` when the engine left it empty)
/// to its summed input + output token count.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct RunSummaryCost {
    /// Total cost in US dollars; `0.0` until pricing is implemented.
    pub total_usd: f64,
    /// Input + output tokens per model.
    pub by_model: std::collections::HashMap<String, u64>,
}
470
/// Duration section of a [`RunSummary`], in milliseconds.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RunSummaryDuration {
    /// Duration reported on the terminal `End` event.
    pub total_ms: u128,
    /// Per-task durations; repeated `TaskEnd` events for the same task
    /// (e.g. retries) are summed.
    pub per_task_ms: std::collections::HashMap<String, u128>,
}
476
/// Task-outcome section of a [`RunSummary`].
///
/// A task counts as passed when the last `TaskEnd` seen for it reported
/// `TaskEndVariant::Success`; `failed` is always `total - passed`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RunSummaryTasks {
    /// Distinct tasks whose latest outcome was a success.
    pub passed: usize,
    /// Distinct tasks whose latest outcome was anything else.
    pub failed: usize,
    /// Number of distinct tasks that reported a `TaskEnd`.
    pub total: usize,
}
483
484impl Stream for RunStream {
485    type Item = Result<WorkflowEvent>;
486
487    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
488        let this = self.get_mut();
489        if this.terminated {
490            return Poll::Ready(None);
491        }
492        match this.rx.poll_recv(cx) {
493            Poll::Ready(Some(Ok(evt))) => {
494                match &evt {
495                    WorkflowEvent::End { output, .. } => {
496                        this.final_output = Some(output.clone());
497                        this.terminated = true;
498                    }
499                    WorkflowEvent::Error { message, kind, .. } => {
500                        this.final_error = Some((message.clone(), *kind));
501                        this.terminated = true;
502                    }
503                    _ => {}
504                }
505                this.dispatch_callbacks(&evt);
506                Poll::Ready(Some(Ok(evt)))
507            }
508            Poll::Ready(Some(Err(e))) => {
509                this.terminated = true;
510                Poll::Ready(Some(Err(e)))
511            }
512            Poll::Ready(None) => {
513                this.terminated = true;
514                Poll::Ready(None)
515            }
516            Poll::Pending => Poll::Pending,
517        }
518    }
519}
520
521fn classify_error(
522    message: String,
523    kind: akribes_types::error::ErrorKind,
524    execution_id: String,
525) -> AkribesError {
526    use akribes_types::error::ErrorKind;
527    let eid = Some(execution_id);
528    match kind {
529        ErrorKind::RateLimit
530        | ErrorKind::ServerError500
531        | ErrorKind::BadGateway502
532        | ErrorKind::ServiceUnavailable503
533        | ErrorKind::GatewayTimeout504
534        | ErrorKind::NetworkError => {
535            // #1296: surface the status when the kind maps cleanly so
536            // callers can prefer the per-status base backoff over the
537            // sdk-wide default.
538            let status = match kind {
539                ErrorKind::RateLimit => Some(429u16),
540                ErrorKind::ServerError500 => Some(500u16),
541                ErrorKind::BadGateway502 => Some(502u16),
542                ErrorKind::ServiceUnavailable503 => Some(503u16),
543                ErrorKind::GatewayTimeout504 => Some(504u16),
544                _ => None,
545            };
546            AkribesError::Transient {
547                message,
548                execution_id: eid,
549                retry_after: None,
550                status,
551            }
552        }
553        ErrorKind::AuthError | ErrorKind::TokenLimit => AkribesError::Fatal {
554            message,
555            execution_id: eid,
556        },
557        _ => AkribesError::Script {
558            message,
559            execution_id: eid,
560        },
561    }
562}
563
564// ── Assembling a RunStream ──────────────────────────────────────────────────
565
566/// Start an SSE subscription filtered to the given script, then kick off the
567/// run and return a [`RunStream`] wired to translate `HubEvent::Execution`
568/// payloads into [`WorkflowEvent`]s.
569///
570/// Subscribes to SSE *first*, waits for the subscription to be live on the
571/// server (ready signal), then POSTs `/run`. This avoids the race where
572/// opening events broadcast by the hub are lost if the GET /events handshake
573/// hasn't completed when the POST response fires.
574///
575/// Called by [`ScopedExecutionsClient::run_stream`].
576pub(crate) async fn start_run_stream(
577    inner: Arc<Inner>,
578    project_id: i64,
579    builder: RunBuilder,
580) -> Result<RunStream> {
581    let script_name = builder.script_name().to_string();
582
583    // ── 1. Spawn the SSE listener with a ready-signal oneshot. Wait for the
584    //        server to confirm the subscription before POSTing `/run`.
585    let (hub_tx, mut hub_rx) = mpsc::unbounded_channel();
586    let (ready_tx, ready_rx) = oneshot::channel::<Result<()>>();
587    let http = inner.http.clone();
588    let token = inner.token.clone();
589    let base_url = inner.base_url.clone();
590    let script_for_sse = script_name.clone();
591    let sse_handle = tokio::spawn(async move {
592        let _ = stream_sse_with_retry(
593            http,
594            token,
595            base_url,
596            project_id,
597            Some(script_for_sse),
598            hub_tx,
599            Some(ready_tx),
600        )
601        .await;
602    });
603
604    // Wait for "subscribed" signal. If the server rejects the subscription
605    // or the task dies before firing, surface the error to the caller.
606    match ready_rx.await {
607        Ok(Ok(())) => {}
608        Ok(Err(e)) => {
609            sse_handle.abort();
610            return Err(e);
611        }
612        Err(_) => {
613            sse_handle.abort();
614            return Err(AkribesError::Other(
615                "SSE listener died before subscription was confirmed".into(),
616            ));
617        }
618    }
619
620    // ── 2. Kick off the run now that we're guaranteed to receive events.
621    let run = match builder.execute().await {
622        Ok(r) => r,
623        Err(e) => {
624            sse_handle.abort();
625            return Err(e);
626        }
627    };
628    let execution_id = run.execution_id;
629
630    // ── 3. Filter-and-translate task: pull `HubEvent::Execution` entries
631    //        whose script_name AND execution_id match this run, convert
632    //        them to WorkflowEvent, and forward. Stop as soon as a
633    //        terminal event is seen.
634    //
635    //  Filtering by script name alone would conflate concurrent runs of
636    //  the same script started by another caller — their `WorkflowEnd`
637    //  would resolve this handle's `output()` with the wrong value.
638    //  Matches the TS SDK's `RunStream` execution-id filter (see
639    //  `packages/akribes-sdk-ts/src/runStream.ts::routeRaw`).
640    //  Pre-#1042 servers that don't stamp `execution_id` on the
641    //  broadcast envelope still flow through (back-compat: `None`
642    //  matches anything) — but every server in production today does.
643    let (out_tx, out_rx) = mpsc::unbounded_channel::<Result<WorkflowEvent>>();
644    let script_for_filter = script_name.clone();
645    let exec_id_for_filter = execution_id.clone();
646    let filter_handle = tokio::spawn(async move {
647        while let Some(hub) = hub_rx.recv().await {
648            if let HubEvent::Execution {
649                script_name: evt_script,
650                execution_id: evt_exec_id,
651                event,
652                ..
653            } = hub
654            {
655                if evt_script != script_for_filter {
656                    continue;
657                }
658                if let Some(eid) = evt_exec_id {
659                    if eid != exec_id_for_filter {
660                        continue;
661                    }
662                }
663                let wf: WorkflowEvent = event.into();
664                let is_terminal = wf.is_terminal();
665                if out_tx.send(Ok(wf)).is_err() {
666                    break;
667                }
668                if is_terminal {
669                    break;
670                }
671            }
672        }
673    });
674
675    // Drop-guard: both the SSE listener AND the filter task abort when the
676    // RunStream is dropped. Previously only the filter was tracked, which
677    // leaked the SSE task whenever a RunStream was dropped pre-terminal.
678    let subscription = EventSubscription::from_handles(vec![sse_handle, filter_handle]);
679    Ok(RunStream::new(execution_id, out_rx, subscription))
680}