nemo_flow/stream.rs
1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Streaming LLM response wrapper.
5//!
//! This module provides [`LlmStreamWrapper`], a [`Stream`] adapter that sits
//! between the raw stream from an LLM API and the consumer. It feeds each
//! chunk to a user-supplied collector and automatically emits the LLM END
//! lifecycle event when the stream terminates, whether by exhaustion or error.
10//!
11//! ## Pipeline
12//!
//! ```text
//! raw chunk (Json) -> collector(chunk) -> Ok(())  -> yield chunk
//!                                      -> Err(e)  -> terminate stream with error -> finalizer() -> Json -> SanitizeResponseGuardrails -> END event
//! upstream error   -> terminate stream with error -> finalizer() -> Json -> SanitizeResponseGuardrails -> END event
//! stream ends      -> finalizer() -> Json -> SanitizeResponseGuardrails -> END event
//! ```
19//!
20//! The **collector** receives each chunk (Json) and can accumulate state
21//! (e.g., concatenating tokens). If the collector returns `Err`, the stream
22//! terminates immediately with that error. Upstream stream errors also
23//! terminate the stream immediately. The **finalizer** is called once when the
24//! stream terminates and returns the aggregated response as [`Json`]. That
25//! aggregated response then flows through sanitize response guardrails before
26//! being included in the END event.
27
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31
32use tokio_stream::Stream;
33
34use crate::api::llm::LlmHandle;
35use crate::api::runtime::NemoFlowContextState;
36use crate::api::runtime::global_context;
37use crate::api::runtime::{ScopeStackHandle, current_scope_stack};
38use crate::codec::response::AnnotatedLlmResponse;
39use crate::codec::traits::LlmResponseCodec;
40use crate::error::Result;
41use crate::json::Json;
42
43/// Wraps an inner `Stream<Item = Result<Json>>` of raw chunks and:
44///
45/// 1. Passes each chunk to the user-supplied **collector** closure.
46/// If the collector returns `Err`, the stream terminates with that error.
47/// 2. On stream exhaustion, calls the **finalizer** to produce an aggregated
48/// [`Json`] response, runs sanitize response guardrails on it, then emits
49/// the LLM END event.
50///
51/// This type is returned by [`crate::api::llm::llm_stream_call_execute`] and
52/// is usually consumed as an ordinary async stream. The wrapper preserves the
53/// originating scope stack so end-of-stream bookkeeping still uses the correct
54/// scope-local middleware and subscribers even when polling happens elsewhere.
55pub struct LlmStreamWrapper {
56 inner: Pin<Box<dyn Stream<Item = Result<Json>> + Send>>,
57 handle: LlmHandle,
58 scope_stack: ScopeStackHandle,
59 collector: Box<dyn FnMut(Json) -> Result<()> + Send>,
60 finalizer: Option<Box<dyn FnOnce() -> Json + Send>>,
61 response_codec: Option<Arc<dyn LlmResponseCodec>>,
62 metadata: Option<Json>,
63 ended: bool,
64}
65
66impl LlmStreamWrapper {
67 /// Create a new `LlmStreamWrapper` around the given raw stream.
68 ///
69 /// Captures the current [`ScopeStackHandle`] at creation time so the
70 /// correct scope stack is used when the stream is later polled, even if
71 /// polling happens on a different task or thread.
72 ///
73 /// # Parameters
74 /// - `inner`: Raw stream of JSON chunks from the provider callback.
75 /// - `handle`: [`LlmHandle`] identifying the managed LLM span.
76 /// - `collector`: Per-chunk callback used to accumulate stream state or
77 /// forward chunks elsewhere. Returning `Err` terminates the stream.
78 /// - `finalizer`: One-shot callback invoked when the stream finishes to
79 /// synthesize the aggregated response payload.
80 /// - `data`: Retained compatibility payload; ATOF end data is the finalized response.
81 /// - `metadata`: Optional event metadata merged into the emitted LLM-end event.
82 /// - `response_codec`: Optional codec used to derive annotated response
83 /// metadata from the aggregated final payload.
84 ///
85 /// # Returns
86 /// A new [`LlmStreamWrapper`] ready to be polled.
87 pub fn new(
88 inner: Pin<Box<dyn Stream<Item = Result<Json>> + Send>>,
89 handle: LlmHandle,
90 collector: Box<dyn FnMut(Json) -> Result<()> + Send>,
91 finalizer: Box<dyn FnOnce() -> Json + Send>,
92 _data: Option<Json>,
93 metadata: Option<Json>,
94 response_codec: Option<Arc<dyn LlmResponseCodec>>,
95 ) -> Self {
96 Self {
97 inner,
98 handle,
99 scope_stack: current_scope_stack(),
100 collector,
101 finalizer: Some(finalizer),
102 response_codec,
103 metadata,
104 ended: false,
105 }
106 }
107
108 /// Return the captured scope stack handle for this stream.
109 ///
110 /// Callers can use this to bind the correct scope stack when spawning
111 /// the stream on a different task via `TASK_SCOPE_STACK.scope(...)`.
112 ///
113 /// # Returns
114 /// A shared reference to the [`ScopeStackHandle`] captured when the stream
115 /// wrapper was created.
116 pub fn scope_stack(&self) -> &ScopeStackHandle {
117 &self.scope_stack
118 }
119
120 fn finish(&mut self) {
121 if self.ended {
122 return;
123 }
124 self.ended = true;
125 self.emit_end_event();
126 }
127
128 /// Emit the LLM END event with aggregated response data.
129 ///
130 /// Calls the finalizer to produce the aggregated response, runs sanitize
131 /// response guardrails, and emits the END event.
132 fn emit_end_event(&mut self) {
133 let aggregated = match self.finalizer.take() {
134 Some(finalizer) => finalizer(),
135 None => Json::Null,
136 };
137
138 // Decode aggregated response if response codec is present (non-fatal)
139 let annotated_response: Option<Arc<AnnotatedLlmResponse>> = self
140 .response_codec
141 .as_ref()
142 .and_then(|c| c.decode_response(&aggregated).ok())
143 .map(Arc::new);
144
145 let event_snapshot = {
146 let ss_guard = self.scope_stack.read().expect("scope stack lock poisoned");
147 let sl =
148 ss_guard.collect_scope_local_registries(|r| &r.llm_sanitize_response_guardrails);
149 let sl_subs = ss_guard.collect_scope_local_subscribers();
150 let ctx = global_context();
151 let state = ctx.read();
152 match state {
153 Ok(state) => {
154 let subscribers = state.collect_event_subscribers(&sl_subs);
155 let sanitized = state.llm_sanitize_response_chain(aggregated, &sl);
156 let data = if sanitized.is_null() {
157 self.handle.data.clone()
158 } else {
159 Some(sanitized)
160 };
161 let event = state.end_llm_handle(
162 &self.handle,
163 data,
164 self.metadata.clone(),
165 annotated_response,
166 );
167 Some((event, subscribers))
168 }
169 Err(_) => None,
170 }
171 };
172 if let Some((event, subscribers)) = event_snapshot {
173 NemoFlowContextState::emit_event(&event, &subscribers);
174 }
175 }
176}
177
178impl Stream for LlmStreamWrapper {
179 type Item = Result<Json>;
180
181 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
182 let this = self.get_mut();
183
184 if this.ended {
185 return Poll::Ready(None);
186 }
187
188 // Poll the inner stream
189 match this.inner.as_mut().poll_next(cx) {
190 Poll::Ready(Some(Ok(raw_chunk))) => {
191 // Feed chunk to the collector; if it returns Err, terminate the stream
192 match (this.collector)(raw_chunk.clone()) {
193 Ok(()) => Poll::Ready(Some(Ok(raw_chunk))),
194 Err(e) => {
195 this.finish();
196 Poll::Ready(Some(Err(e)))
197 }
198 }
199 }
200 Poll::Ready(Some(Err(e))) => {
201 this.finish();
202 Poll::Ready(Some(Err(e)))
203 }
204 Poll::Ready(None) => {
205 this.finish();
206 Poll::Ready(None)
207 }
208 Poll::Pending => Poll::Pending,
209 }
210 }
211}