nemo_flow/stream.rs
1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Streaming LLM response wrapper.
5//!
6//! This module provides [`LlmStreamWrapper`], a [`Stream`] adapter
7//! that sits between the raw stream from an LLM API and the consumer. It
8//! feeds chunks to a user-supplied collector, and automatically emits
9//! lifecycle events when the stream ends.
10//!
11//! ## Pipeline
12//!
13//! ```text
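//! raw chunk (Json) -> collector(chunk) -> Ok(())  -> yield chunk
//!                                      -> Err(e)  -> terminate stream with error -> finalizer() -> Json -> SanitizeResponseGuardrails -> END event
//! upstream error   -> terminate stream with error -> finalizer() -> Json -> SanitizeResponseGuardrails -> END event
//! stream ends      -> finalizer() -> Json -> SanitizeResponseGuardrails -> END event
//! wrapper dropped  -> finalizer() -> Json -> SanitizeResponseGuardrails -> END event (only if not already emitted)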
18//! ```
19//!
20//! The **collector** receives each chunk (Json) and can accumulate state
21//! (e.g., concatenating tokens). If the collector returns `Err`, the stream
22//! terminates immediately with that error. Upstream stream errors also
23//! terminate the stream immediately. The **finalizer** is called once when the
24//! stream terminates and returns the aggregated response as [`Json`]. That
25//! aggregated response then flows through sanitize response guardrails before
26//! being included in the END event.
27
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31
32use tokio_stream::Stream;
33
34use crate::api::llm::LlmHandle;
35use crate::api::runtime::NemoFlowContextState;
36use crate::api::runtime::global_context;
37use crate::api::runtime::{ScopeStackHandle, current_scope_stack};
38use crate::codec::response::AnnotatedLlmResponse;
39use crate::codec::traits::LlmResponseCodec;
40use crate::error::Result;
41use crate::json::Json;
42
43/// Wraps an inner `Stream<Item = Result<Json>>` of raw chunks and:
44///
45/// 1. Passes each chunk to the user-supplied **collector** closure.
46/// If the collector returns `Err`, the stream terminates with that error.
47/// 2. On stream exhaustion, calls the **finalizer** to produce an aggregated
48/// [`Json`] response, runs sanitize response guardrails on it, then emits
49/// the LLM END event.
50///
51/// This type is returned by [`crate::api::llm::llm_stream_call_execute`] and
52/// is usually consumed as an ordinary async stream. The wrapper preserves the
53/// originating scope stack so end-of-stream bookkeeping still uses the correct
54/// scope-local middleware and subscribers even when polling happens elsewhere.
55pub struct LlmStreamWrapper {
56 inner: Pin<Box<dyn Stream<Item = Result<Json>> + Send>>,
57 handle: LlmHandle,
58 scope_stack: ScopeStackHandle,
59 collector: Box<dyn FnMut(Json) -> Result<()> + Send>,
60 finalizer: Option<Box<dyn FnOnce() -> Json + Send>>,
61 response_codec: Option<Arc<dyn LlmResponseCodec>>,
62 metadata: Option<Json>,
63 ended: bool,
64}
65
66impl LlmStreamWrapper {
67 /// Create a new `LlmStreamWrapper` around the given raw stream.
68 ///
69 /// Captures the current [`ScopeStackHandle`] at creation time so the
70 /// correct scope stack is used when the stream is later polled, even if
71 /// polling happens on a different task or thread.
72 ///
73 /// # Parameters
74 /// - `inner`: Raw stream of JSON chunks from the provider callback.
75 /// - `handle`: [`LlmHandle`] identifying the managed LLM span.
76 /// - `collector`: Per-chunk callback used to accumulate stream state or
77 /// forward chunks elsewhere. Returning `Err` terminates the stream.
78 /// - `finalizer`: One-shot callback invoked when the stream finishes to
79 /// synthesize the aggregated response payload.
80 /// - `data`: Retained compatibility payload; Agent Trajectory
81 /// Observability Format (ATOF) end data is the finalized response.
82 /// - `metadata`: Optional event metadata merged into the emitted LLM-end event.
83 /// - `response_codec`: Optional codec used to derive annotated response
84 /// metadata from the aggregated final payload.
85 ///
86 /// # Returns
87 /// A new [`LlmStreamWrapper`] ready to be polled.
88 pub fn new(
89 inner: Pin<Box<dyn Stream<Item = Result<Json>> + Send>>,
90 handle: LlmHandle,
91 collector: Box<dyn FnMut(Json) -> Result<()> + Send>,
92 finalizer: Box<dyn FnOnce() -> Json + Send>,
93 _data: Option<Json>,
94 metadata: Option<Json>,
95 response_codec: Option<Arc<dyn LlmResponseCodec>>,
96 ) -> Self {
97 Self {
98 inner,
99 handle,
100 scope_stack: current_scope_stack(),
101 collector,
102 finalizer: Some(finalizer),
103 response_codec,
104 metadata,
105 ended: false,
106 }
107 }
108
109 /// Return the captured scope stack handle for this stream.
110 ///
111 /// Callers can use this to bind the correct scope stack when spawning
112 /// the stream on a different task via `TASK_SCOPE_STACK.scope(...)`.
113 ///
114 /// # Returns
115 /// A shared reference to the [`ScopeStackHandle`] captured when the stream
116 /// wrapper was created.
117 pub fn scope_stack(&self) -> &ScopeStackHandle {
118 &self.scope_stack
119 }
120
121 fn finish(&mut self) {
122 if self.ended {
123 return;
124 }
125 self.ended = true;
126 self.emit_end_event();
127 }
128
129 /// Emit the LLM END event with aggregated response data.
130 ///
131 /// Calls the finalizer to produce the aggregated response, runs sanitize
132 /// response guardrails, and emits the END event.
133 fn emit_end_event(&mut self) {
134 let aggregated = match self.finalizer.take() {
135 Some(finalizer) => finalizer(),
136 None => Json::Null,
137 };
138
139 // Decode aggregated response if response codec is present (non-fatal)
140 let annotated_response: Option<Arc<AnnotatedLlmResponse>> = self
141 .response_codec
142 .as_ref()
143 .and_then(|c| c.decode_response(&aggregated).ok())
144 .map(Arc::new);
145
146 let event_snapshot = {
147 let ss_guard = self.scope_stack.read().expect("scope stack lock poisoned");
148 let sl =
149 ss_guard.collect_scope_local_registries(|r| &r.llm_sanitize_response_guardrails);
150 let sl_subs = ss_guard.collect_scope_local_subscribers();
151 let ctx = global_context();
152 let state = ctx.read();
153 match state {
154 Ok(state) => {
155 let subscribers = state.collect_event_subscribers(&sl_subs);
156 let sanitized = state.llm_sanitize_response_chain(aggregated, &sl);
157 let data = if sanitized.is_null() {
158 self.handle.data.clone()
159 } else {
160 Some(sanitized)
161 };
162 let event = state.end_llm_handle(
163 &self.handle,
164 data,
165 self.metadata.clone(),
166 annotated_response,
167 );
168 Some((event, subscribers))
169 }
170 Err(_) => None,
171 }
172 };
173 if let Some((event, subscribers)) = event_snapshot {
174 NemoFlowContextState::emit_event(&event, &subscribers);
175 }
176 }
177}
178
179impl Stream for LlmStreamWrapper {
180 type Item = Result<Json>;
181
182 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
183 let this = self.get_mut();
184
185 if this.ended {
186 return Poll::Ready(None);
187 }
188
189 // Poll the inner stream
190 match this.inner.as_mut().poll_next(cx) {
191 Poll::Ready(Some(Ok(raw_chunk))) => {
192 // Feed chunk to the collector; if it returns Err, terminate the stream
193 match (this.collector)(raw_chunk.clone()) {
194 Ok(()) => Poll::Ready(Some(Ok(raw_chunk))),
195 Err(e) => {
196 this.finish();
197 Poll::Ready(Some(Err(e)))
198 }
199 }
200 }
201 Poll::Ready(Some(Err(e))) => {
202 this.finish();
203 Poll::Ready(Some(Err(e)))
204 }
205 Poll::Ready(None) => {
206 this.finish();
207 Poll::Ready(None)
208 }
209 Poll::Pending => Poll::Pending,
210 }
211 }
212}
213
214impl Drop for LlmStreamWrapper {
215 fn drop(&mut self) {
216 self.finish();
217 }
218}