use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio_stream::Stream;
use crate::api::llm::LlmHandle;
use crate::api::runtime::NemoFlowContextState;
use crate::api::runtime::global_context;
use crate::api::runtime::{ScopeStackHandle, current_scope_stack};
use crate::codec::response::AnnotatedLlmResponse;
use crate::codec::traits::LlmResponseCodec;
use crate::error::Result;
use crate::json::Json;
/// Stream adapter around a raw LLM response stream.
///
/// Forwards each chunk from `inner`, hands a clone of every successful chunk to `collector`,
/// and emits a single LLM end event when the stream is exhausted, yields an error, or is
/// dropped before finishing.
pub struct LlmStreamWrapper {
/// Underlying stream of raw response chunks.
inner: Pin<Box<dyn Stream<Item = Result<Json>> + Send>>,
/// LLM handle the end event is reported against.
handle: LlmHandle,
/// Scope stack captured at construction; consulted for scope-local sanitize guardrails and
/// subscribers when the end event is emitted.
scope_stack: ScopeStackHandle,
/// Receives a clone of each successful chunk; returning `Err` terminates the stream.
collector: Box<dyn FnMut(Json) -> Result<()> + Send>,
/// Produces the aggregated response; taken and invoked at most once, when the end event is built.
finalizer: Option<Box<dyn FnOnce() -> Json + Send>>,
/// Optional codec used to decode the aggregated response for the end event.
response_codec: Option<Arc<dyn LlmResponseCodec>>,
/// Metadata attached to the end event.
metadata: Option<Json>,
/// Set when the stream finishes (just before the end event is emitted); afterwards every
/// poll returns `None` and further finish attempts are no-ops.
ended: bool,
}
impl LlmStreamWrapper {
    /// Wraps `inner` so each chunk is handed to `collector` and a single LLM end event is
    /// emitted when the stream ends, yields an error, or is dropped.
    ///
    /// # Arguments
    ///
    /// * `inner` - Stream of raw response chunks to forward.
    /// * `handle` - LLM handle the end event is reported against.
    /// * `collector` - Receives a clone of every successful chunk; returning `Err` ends the
    ///   stream after that error is yielded.
    /// * `finalizer` - Called at most once, when the end event is built, to produce the
    ///   aggregated response.
    /// * `_data` - Currently unused by this wrapper.
    /// * `metadata` - Attached to the end event.
    /// * `response_codec` - Optional codec used to decode the aggregated response into an
    ///   [`AnnotatedLlmResponse`] for the end event; decode failures are ignored.
    ///
    /// The scope stack that is current at construction time is captured and used when the
    /// end event is emitted, rather than whichever scope stack is current at that moment.
    pub fn new(
        inner: Pin<Box<dyn Stream<Item = Result<Json>> + Send>>,
        handle: LlmHandle,
        collector: Box<dyn FnMut(Json) -> Result<()> + Send>,
        finalizer: Box<dyn FnOnce() -> Json + Send>,
        _data: Option<Json>,
        metadata: Option<Json>,
        response_codec: Option<Arc<dyn LlmResponseCodec>>,
    ) -> Self {
        Self {
            inner,
            handle,
            scope_stack: current_scope_stack(),
            collector,
            finalizer: Some(finalizer),
            response_codec,
            metadata,
            ended: false,
        }
    }

    /// Returns the scope stack captured when this wrapper was created.
    pub fn scope_stack(&self) -> &ScopeStackHandle {
        &self.scope_stack
    }

    /// Marks the stream as ended and emits the end event.
    ///
    /// Idempotent: only the first call emits; later calls return immediately.
    fn finish(&mut self) {
        if self.ended {
            return;
        }
        // Flip the flag before emitting so nothing reached from the emit path can cause a
        // second end event.
        self.ended = true;
        self.emit_end_event();
    }

    /// Builds the LLM end event from the aggregated response and dispatches it.
    ///
    /// The aggregated response (or `Json::Null` if the finalizer was already consumed) is
    /// optionally decoded with the response codec, then passed through the sanitize-response
    /// chain together with the scope-local guardrails. If sanitization yields `null`, the
    /// handle's own data is reported instead. Subscribers are snapshotted while the locks
    /// are held; the event itself is dispatched only after both locks are released.
    ///
    /// This is best-effort and does not panic on lock poisoning: it is reached from `Drop`,
    /// where a panic during unwinding aborts the process. If either the captured scope stack
    /// or the global context lock is poisoned, no event is emitted.
    fn emit_end_event(&mut self) {
        let aggregated = match self.finalizer.take() {
            Some(finalizer) => finalizer(),
            None => Json::Null,
        };
        let annotated_response: Option<Arc<AnnotatedLlmResponse>> = self
            .response_codec
            .as_ref()
            .and_then(|c| c.decode_response(&aggregated).ok())
            .map(Arc::new);
        let event_snapshot = {
            // Previously `.expect(...)`: a poisoned lock panicked inside `Drop`. Skip the
            // event instead, consistent with the global-context handling below.
            let ss_guard = match self.scope_stack.read() {
                Ok(guard) => guard,
                Err(_) => return,
            };
            let sl =
                ss_guard.collect_scope_local_registries(|r| &r.llm_sanitize_response_guardrails);
            let sl_subs = ss_guard.collect_scope_local_subscribers();
            let ctx = global_context();
            // Bound to a local (not matched inline) so the read guard is dropped before `ctx`.
            let state = ctx.read();
            match state {
                Ok(state) => {
                    let subscribers = state.collect_event_subscribers(&sl_subs);
                    let sanitized = state.llm_sanitize_response_chain(aggregated, &sl);
                    let data = if sanitized.is_null() {
                        self.handle.data.clone()
                    } else {
                        Some(sanitized)
                    };
                    let event = state.end_llm_handle(
                        &self.handle,
                        data,
                        self.metadata.clone(),
                        annotated_response,
                    );
                    Some((event, subscribers))
                }
                Err(_) => None,
            }
        };
        // Dispatch outside the locks so subscribers cannot deadlock against them.
        if let Some((event, subscribers)) = event_snapshot {
            NemoFlowContextState::emit_event(&event, &subscribers);
        }
    }
}
impl Stream for LlmStreamWrapper {
    type Item = Result<Json>;

    /// Polls the inner stream, forwarding successful chunks after handing a copy to the
    /// collector.
    ///
    /// The wrapper finishes (emitting its end event) on any terminal outcome: the inner
    /// stream is exhausted, the inner stream yields an error, or the collector rejects a
    /// chunk. Once finished, every subsequent poll yields `None` without touching the inner
    /// stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let wrapper = self.get_mut();
        if wrapper.ended {
            return Poll::Ready(None);
        }

        let next = match wrapper.inner.as_mut().poll_next(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(next) => next,
        };

        let terminal = match next {
            Some(Ok(chunk)) => match (wrapper.collector)(chunk.clone()) {
                // The only non-terminal outcome: pass the chunk through unchanged.
                Ok(()) => return Poll::Ready(Some(Ok(chunk))),
                Err(err) => Some(Err(err)),
            },
            Some(Err(err)) => Some(Err(err)),
            None => None,
        };

        // Every remaining path ends the stream, so finish exactly here.
        wrapper.finish();
        Poll::Ready(terminal)
    }
}
impl Drop for LlmStreamWrapper {
fn drop(&mut self) {
self.finish();
}
}