Skip to main content

nemo_flow/
stream.rs

1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Streaming LLM response wrapper.
5//!
6//! This module provides [`LlmStreamWrapper`], a [`Stream`] adapter
7//! that sits between the raw stream from an LLM API and the consumer. It
8//! feeds chunks to a user-supplied collector, and automatically emits
9//! lifecycle events when the stream ends.
10//!
11//! ## Pipeline
12//!
13//! ```text
14//! raw chunk (Json) -> collector(chunk) -> Ok(()) -> yield chunk
15//!                                      -> Err(e) -> terminate stream with error
16//! upstream error -> terminate stream with error -> finalizer() -> Json -> SanitizeResponseGuardrails -> END event
17//! stream ends -> finalizer() -> Json -> SanitizeResponseGuardrails -> END event
18//! ```
19//!
20//! The **collector** receives each chunk (Json) and can accumulate state
21//! (e.g., concatenating tokens). If the collector returns `Err`, the stream
22//! terminates immediately with that error. Upstream stream errors also
23//! terminate the stream immediately. The **finalizer** is called once when the
24//! stream terminates and returns the aggregated response as [`Json`]. That
25//! aggregated response then flows through sanitize response guardrails before
26//! being included in the END event.
27
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31
32use tokio_stream::Stream;
33
34use crate::api::llm::LlmHandle;
35use crate::api::runtime::NemoFlowContextState;
36use crate::api::runtime::global_context;
37use crate::api::runtime::{ScopeStackHandle, current_scope_stack};
38use crate::codec::response::AnnotatedLlmResponse;
39use crate::codec::traits::LlmResponseCodec;
40use crate::error::Result;
41use crate::json::Json;
42
43/// Wraps an inner `Stream<Item = Result<Json>>` of raw chunks and:
44///
45/// 1. Passes each chunk to the user-supplied **collector** closure.
46///    If the collector returns `Err`, the stream terminates with that error.
47/// 2. On stream exhaustion, calls the **finalizer** to produce an aggregated
48///    [`Json`] response, runs sanitize response guardrails on it, then emits
49///    the LLM END event.
50///
51/// This type is returned by [`crate::api::llm::llm_stream_call_execute`] and
52/// is usually consumed as an ordinary async stream. The wrapper preserves the
53/// originating scope stack so end-of-stream bookkeeping still uses the correct
54/// scope-local middleware and subscribers even when polling happens elsewhere.
55pub struct LlmStreamWrapper {
56    inner: Pin<Box<dyn Stream<Item = Result<Json>> + Send>>,
57    handle: LlmHandle,
58    scope_stack: ScopeStackHandle,
59    collector: Box<dyn FnMut(Json) -> Result<()> + Send>,
60    finalizer: Option<Box<dyn FnOnce() -> Json + Send>>,
61    response_codec: Option<Arc<dyn LlmResponseCodec>>,
62    metadata: Option<Json>,
63    ended: bool,
64}
65
66impl LlmStreamWrapper {
67    /// Create a new `LlmStreamWrapper` around the given raw stream.
68    ///
69    /// Captures the current [`ScopeStackHandle`] at creation time so the
70    /// correct scope stack is used when the stream is later polled, even if
71    /// polling happens on a different task or thread.
72    ///
73    /// # Parameters
74    /// - `inner`: Raw stream of JSON chunks from the provider callback.
75    /// - `handle`: [`LlmHandle`] identifying the managed LLM span.
76    /// - `collector`: Per-chunk callback used to accumulate stream state or
77    ///   forward chunks elsewhere. Returning `Err` terminates the stream.
78    /// - `finalizer`: One-shot callback invoked when the stream finishes to
79    ///   synthesize the aggregated response payload.
80    /// - `data`: Retained compatibility payload; Agent Trajectory
81    ///   Observability Format (ATOF) end data is the finalized response.
82    /// - `metadata`: Optional event metadata merged into the emitted LLM-end event.
83    /// - `response_codec`: Optional codec used to derive annotated response
84    ///   metadata from the aggregated final payload.
85    ///
86    /// # Returns
87    /// A new [`LlmStreamWrapper`] ready to be polled.
88    pub fn new(
89        inner: Pin<Box<dyn Stream<Item = Result<Json>> + Send>>,
90        handle: LlmHandle,
91        collector: Box<dyn FnMut(Json) -> Result<()> + Send>,
92        finalizer: Box<dyn FnOnce() -> Json + Send>,
93        _data: Option<Json>,
94        metadata: Option<Json>,
95        response_codec: Option<Arc<dyn LlmResponseCodec>>,
96    ) -> Self {
97        Self {
98            inner,
99            handle,
100            scope_stack: current_scope_stack(),
101            collector,
102            finalizer: Some(finalizer),
103            response_codec,
104            metadata,
105            ended: false,
106        }
107    }
108
109    /// Return the captured scope stack handle for this stream.
110    ///
111    /// Callers can use this to bind the correct scope stack when spawning
112    /// the stream on a different task via `TASK_SCOPE_STACK.scope(...)`.
113    ///
114    /// # Returns
115    /// A shared reference to the [`ScopeStackHandle`] captured when the stream
116    /// wrapper was created.
117    pub fn scope_stack(&self) -> &ScopeStackHandle {
118        &self.scope_stack
119    }
120
121    fn finish(&mut self) {
122        if self.ended {
123            return;
124        }
125        self.ended = true;
126        self.emit_end_event();
127    }
128
129    /// Emit the LLM END event with aggregated response data.
130    ///
131    /// Calls the finalizer to produce the aggregated response, runs sanitize
132    /// response guardrails, and emits the END event.
133    fn emit_end_event(&mut self) {
134        let aggregated = match self.finalizer.take() {
135            Some(finalizer) => finalizer(),
136            None => Json::Null,
137        };
138
139        // Decode aggregated response if response codec is present (non-fatal)
140        let annotated_response: Option<Arc<AnnotatedLlmResponse>> = self
141            .response_codec
142            .as_ref()
143            .and_then(|c| c.decode_response(&aggregated).ok())
144            .map(Arc::new);
145
146        let event_snapshot = {
147            let ss_guard = self.scope_stack.read().expect("scope stack lock poisoned");
148            let sl =
149                ss_guard.collect_scope_local_registries(|r| &r.llm_sanitize_response_guardrails);
150            let sl_subs = ss_guard.collect_scope_local_subscribers();
151            let ctx = global_context();
152            let state = ctx.read();
153            match state {
154                Ok(state) => {
155                    let subscribers = state.collect_event_subscribers(&sl_subs);
156                    let sanitized = state.llm_sanitize_response_chain(aggregated, &sl);
157                    let data = if sanitized.is_null() {
158                        self.handle.data.clone()
159                    } else {
160                        Some(sanitized)
161                    };
162                    let event = state.end_llm_handle(
163                        &self.handle,
164                        data,
165                        self.metadata.clone(),
166                        annotated_response,
167                    );
168                    Some((event, subscribers))
169                }
170                Err(_) => None,
171            }
172        };
173        if let Some((event, subscribers)) = event_snapshot {
174            NemoFlowContextState::emit_event(&event, &subscribers);
175        }
176    }
177}
178
179impl Stream for LlmStreamWrapper {
180    type Item = Result<Json>;
181
182    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
183        let this = self.get_mut();
184
185        if this.ended {
186            return Poll::Ready(None);
187        }
188
189        // Poll the inner stream
190        match this.inner.as_mut().poll_next(cx) {
191            Poll::Ready(Some(Ok(raw_chunk))) => {
192                // Feed chunk to the collector; if it returns Err, terminate the stream
193                match (this.collector)(raw_chunk.clone()) {
194                    Ok(()) => Poll::Ready(Some(Ok(raw_chunk))),
195                    Err(e) => {
196                        this.finish();
197                        Poll::Ready(Some(Err(e)))
198                    }
199                }
200            }
201            Poll::Ready(Some(Err(e))) => {
202                this.finish();
203                Poll::Ready(Some(Err(e)))
204            }
205            Poll::Ready(None) => {
206                this.finish();
207                Poll::Ready(None)
208            }
209            Poll::Pending => Poll::Pending,
210        }
211    }
212}
213
214impl Drop for LlmStreamWrapper {
215    fn drop(&mut self) {
216        self.finish();
217    }
218}