Skip to main content

nemo_flow/
stream.rs

1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Streaming LLM response wrapper.
5//!
6//! This module provides [`LlmStreamWrapper`], a [`Stream`] adapter
7//! that sits between the raw stream from an LLM API and the consumer. It
8//! feeds chunks to a user-supplied collector, and automatically emits
9//! lifecycle events when the stream ends.
10//!
11//! ## Pipeline
12//!
13//! ```text
14//! raw chunk (Json) -> collector(chunk) -> Ok(()) -> yield chunk
15//!                                      -> Err(e) -> terminate stream with error -> finalizer() -> Json -> SanitizeResponseGuardrails -> END event
16//! upstream error -> terminate stream with error -> finalizer() -> Json -> SanitizeResponseGuardrails -> END event
17//! stream ends -> finalizer() -> Json -> SanitizeResponseGuardrails -> END event
18//! ```
19//!
20//! The **collector** receives each chunk (Json) and can accumulate state
21//! (e.g., concatenating tokens). If the collector returns `Err`, the stream
22//! terminates immediately with that error. Upstream stream errors also
23//! terminate the stream immediately. The **finalizer** is called once when the
24//! stream terminates and returns the aggregated response as [`Json`]. That
25//! aggregated response then flows through sanitize response guardrails before
26//! being included in the END event.
27
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31
32use tokio_stream::Stream;
33
34use crate::api::llm::LlmHandle;
35use crate::api::runtime::NemoFlowContextState;
36use crate::api::runtime::global_context;
37use crate::api::runtime::{ScopeStackHandle, current_scope_stack};
38use crate::codec::response::AnnotatedLlmResponse;
39use crate::codec::traits::LlmResponseCodec;
40use crate::error::Result;
41use crate::json::Json;
42
43/// Wraps an inner `Stream<Item = Result<Json>>` of raw chunks and:
44///
45/// 1. Passes each chunk to the user-supplied **collector** closure.
46///    If the collector returns `Err`, the stream terminates with that error.
47/// 2. On stream exhaustion, calls the **finalizer** to produce an aggregated
48///    [`Json`] response, runs sanitize response guardrails on it, then emits
49///    the LLM END event.
50///
51/// This type is returned by [`crate::api::llm::llm_stream_call_execute`] and
52/// is usually consumed as an ordinary async stream. The wrapper preserves the
53/// originating scope stack so end-of-stream bookkeeping still uses the correct
54/// scope-local middleware and subscribers even when polling happens elsewhere.
55pub struct LlmStreamWrapper {
56    inner: Pin<Box<dyn Stream<Item = Result<Json>> + Send>>,
57    handle: LlmHandle,
58    scope_stack: ScopeStackHandle,
59    collector: Box<dyn FnMut(Json) -> Result<()> + Send>,
60    finalizer: Option<Box<dyn FnOnce() -> Json + Send>>,
61    response_codec: Option<Arc<dyn LlmResponseCodec>>,
62    metadata: Option<Json>,
63    ended: bool,
64}
65
66impl LlmStreamWrapper {
67    /// Create a new `LlmStreamWrapper` around the given raw stream.
68    ///
69    /// Captures the current [`ScopeStackHandle`] at creation time so the
70    /// correct scope stack is used when the stream is later polled, even if
71    /// polling happens on a different task or thread.
72    ///
73    /// # Parameters
74    /// - `inner`: Raw stream of JSON chunks from the provider callback.
75    /// - `handle`: [`LlmHandle`] identifying the managed LLM span.
76    /// - `collector`: Per-chunk callback used to accumulate stream state or
77    ///   forward chunks elsewhere. Returning `Err` terminates the stream.
78    /// - `finalizer`: One-shot callback invoked when the stream finishes to
79    ///   synthesize the aggregated response payload.
80    /// - `data`: Retained compatibility payload; ATOF end data is the finalized response.
81    /// - `metadata`: Optional event metadata merged into the emitted LLM-end event.
82    /// - `response_codec`: Optional codec used to derive annotated response
83    ///   metadata from the aggregated final payload.
84    ///
85    /// # Returns
86    /// A new [`LlmStreamWrapper`] ready to be polled.
87    pub fn new(
88        inner: Pin<Box<dyn Stream<Item = Result<Json>> + Send>>,
89        handle: LlmHandle,
90        collector: Box<dyn FnMut(Json) -> Result<()> + Send>,
91        finalizer: Box<dyn FnOnce() -> Json + Send>,
92        _data: Option<Json>,
93        metadata: Option<Json>,
94        response_codec: Option<Arc<dyn LlmResponseCodec>>,
95    ) -> Self {
96        Self {
97            inner,
98            handle,
99            scope_stack: current_scope_stack(),
100            collector,
101            finalizer: Some(finalizer),
102            response_codec,
103            metadata,
104            ended: false,
105        }
106    }
107
108    /// Return the captured scope stack handle for this stream.
109    ///
110    /// Callers can use this to bind the correct scope stack when spawning
111    /// the stream on a different task via `TASK_SCOPE_STACK.scope(...)`.
112    ///
113    /// # Returns
114    /// A shared reference to the [`ScopeStackHandle`] captured when the stream
115    /// wrapper was created.
116    pub fn scope_stack(&self) -> &ScopeStackHandle {
117        &self.scope_stack
118    }
119
120    fn finish(&mut self) {
121        if self.ended {
122            return;
123        }
124        self.ended = true;
125        self.emit_end_event();
126    }
127
128    /// Emit the LLM END event with aggregated response data.
129    ///
130    /// Calls the finalizer to produce the aggregated response, runs sanitize
131    /// response guardrails, and emits the END event.
132    fn emit_end_event(&mut self) {
133        let aggregated = match self.finalizer.take() {
134            Some(finalizer) => finalizer(),
135            None => Json::Null,
136        };
137
138        // Decode aggregated response if response codec is present (non-fatal)
139        let annotated_response: Option<Arc<AnnotatedLlmResponse>> = self
140            .response_codec
141            .as_ref()
142            .and_then(|c| c.decode_response(&aggregated).ok())
143            .map(Arc::new);
144
145        let event_snapshot = {
146            let ss_guard = self.scope_stack.read().expect("scope stack lock poisoned");
147            let sl =
148                ss_guard.collect_scope_local_registries(|r| &r.llm_sanitize_response_guardrails);
149            let sl_subs = ss_guard.collect_scope_local_subscribers();
150            let ctx = global_context();
151            let state = ctx.read();
152            match state {
153                Ok(state) => {
154                    let subscribers = state.collect_event_subscribers(&sl_subs);
155                    let sanitized = state.llm_sanitize_response_chain(aggregated, &sl);
156                    let data = if sanitized.is_null() {
157                        self.handle.data.clone()
158                    } else {
159                        Some(sanitized)
160                    };
161                    let event = state.end_llm_handle(
162                        &self.handle,
163                        data,
164                        self.metadata.clone(),
165                        annotated_response,
166                    );
167                    Some((event, subscribers))
168                }
169                Err(_) => None,
170            }
171        };
172        if let Some((event, subscribers)) = event_snapshot {
173            NemoFlowContextState::emit_event(&event, &subscribers);
174        }
175    }
176}
177
178impl Stream for LlmStreamWrapper {
179    type Item = Result<Json>;
180
181    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
182        let this = self.get_mut();
183
184        if this.ended {
185            return Poll::Ready(None);
186        }
187
188        // Poll the inner stream
189        match this.inner.as_mut().poll_next(cx) {
190            Poll::Ready(Some(Ok(raw_chunk))) => {
191                // Feed chunk to the collector; if it returns Err, terminate the stream
192                match (this.collector)(raw_chunk.clone()) {
193                    Ok(()) => Poll::Ready(Some(Ok(raw_chunk))),
194                    Err(e) => {
195                        this.finish();
196                        Poll::Ready(Some(Err(e)))
197                    }
198                }
199            }
200            Poll::Ready(Some(Err(e))) => {
201                this.finish();
202                Poll::Ready(Some(Err(e)))
203            }
204            Poll::Ready(None) => {
205                this.finish();
206                Poll::Ready(None)
207            }
208            Poll::Pending => Poll::Pending,
209        }
210    }
211}