#![allow(clippy::await_holding_lock)]
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use nemo_flow::api::event::{Event, ScopeCategory};
use nemo_flow::api::llm::{LlmAttributes, LlmHandle, LlmRequest};
use nemo_flow::api::llm::{LlmCallParams, llm_call};
use nemo_flow::api::runtime::NemoFlowContextState;
use nemo_flow::api::runtime::global_context;
use nemo_flow::api::subscriber::{deregister_subscriber, register_subscriber};
use nemo_flow::error::FlowError;
use nemo_flow::error::Result;
use nemo_flow::json::Json;
use nemo_flow::stream::LlmStreamWrapper;
use serde_json::json;
use tokio_stream::{Stream, StreamExt};
static TEST_MUTEX: Mutex<()> = Mutex::new(());
fn is_llm_end(event: &Event) -> bool {
event.scope_type() == Some(nemo_flow::api::scope::ScopeType::Llm)
&& event.scope_category() == Some(ScopeCategory::End)
}
fn reset_global() {
let ctx = global_context();
let mut state = ctx.write().unwrap();
*state = NemoFlowContextState::new();
}
fn make_llm_handle(name: &str) -> LlmHandle {
LlmHandle::builder()
.name(name.to_string())
.attributes(LlmAttributes::STREAMING)
.build()
}
fn make_stream(items: Vec<Result<Json>>) -> Pin<Box<dyn Stream<Item = Result<Json>> + Send>> {
Box::pin(tokio_stream::iter(items))
}
#[allow(clippy::type_complexity)]
fn make_collector_finalizer() -> (
Box<dyn FnMut(Json) -> Result<()> + Send>,
Box<dyn FnOnce() -> Json + Send>,
Arc<Mutex<Vec<Json>>>,
) {
let collected = Arc::new(Mutex::new(Vec::<Json>::new()));
let cc = collected.clone();
let collector: Box<dyn FnMut(Json) -> Result<()> + Send> = Box::new(move |chunk| {
cc.lock().unwrap().push(chunk);
Ok(())
});
let fc = collected.clone();
let finalizer: Box<dyn FnOnce() -> Json + Send> = Box::new(move || {
let chunks = fc.lock().unwrap();
Json::Array(chunks.clone())
});
(collector, finalizer, collected)
}
#[tokio::test]
async fn test_stream_wrapper_basic_chunks() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let items = vec![Ok(json!({"token": "hello"})), Ok(json!({"token": "world"}))];
let inner = make_stream(items);
let handle = make_llm_handle("test_llm");
let (collector, finalizer, _collected) = make_collector_finalizer();
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
let mut chunks = Vec::new();
while let Some(item) = wrapper.next().await {
chunks.push(item.unwrap());
}
assert_eq!(chunks.len(), 2);
assert_eq!(chunks[0]["token"], "hello");
assert_eq!(chunks[1]["token"], "world");
}
#[tokio::test]
async fn test_stream_wrapper_passthrough() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let items = vec![Ok(json!("data: partial")), Ok(json!("more data"))];
let inner = make_stream(items);
let handle = make_llm_handle("test_llm");
let (collector, finalizer, _collected) = make_collector_finalizer();
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
let mut chunks = Vec::new();
while let Some(item) = wrapper.next().await {
chunks.push(item.unwrap());
}
assert_eq!(chunks.len(), 2);
assert_eq!(chunks[0], json!("data: partial"));
assert_eq!(chunks[1], json!("more data"));
}
#[tokio::test]
async fn test_stream_wrapper_empty_stream() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let inner: Pin<Box<dyn Stream<Item = Result<Json>> + Send>> = Box::pin(tokio_stream::empty());
let handle = make_llm_handle("test_llm");
let (collector, finalizer, _collected) = make_collector_finalizer();
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
let mut count = 0;
while let Some(_item) = wrapper.next().await {
count += 1;
}
assert_eq!(count, 0);
}
#[tokio::test]
async fn test_stream_wrapper_single_chunk() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let items = vec![Ok(json!("only chunk"))];
let inner = make_stream(items);
let handle = make_llm_handle("test_llm");
let (collector, finalizer, _collected) = make_collector_finalizer();
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
let mut chunks = Vec::new();
while let Some(item) = wrapper.next().await {
chunks.push(item.unwrap());
}
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0], json!("only chunk"));
}
#[tokio::test]
async fn test_stream_wrapper_emits_end_event() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let events = Arc::new(Mutex::new(Vec::new()));
let ec = events.clone();
register_subscriber(
"stream_end_test",
Arc::new(move |e: &Event| {
let phase = match e.scope_category() {
Some(ScopeCategory::Start) => "start",
Some(ScopeCategory::End) => "end",
None => e.kind(),
};
ec.lock().unwrap().push((phase.to_string(), e.scope_type()));
}),
)
.unwrap();
let items = vec![Ok(json!({"token": "hi"}))];
let inner = make_stream(items);
let request = LlmRequest {
headers: serde_json::Map::new(),
content: json!({"messages": []}),
};
let handle = llm_call(
LlmCallParams::builder()
.name("test_llm")
.request(&request)
.attributes(LlmAttributes::STREAMING)
.build(),
)
.unwrap();
let (collector, finalizer, _collected) = make_collector_finalizer();
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
while let Some(_item) = wrapper.next().await {}
let captured = events.lock().unwrap();
assert!(captured.len() >= 2);
assert_eq!(captured[0].0, "start");
assert_eq!(captured.last().unwrap().0, "end");
drop(captured);
deregister_subscriber("stream_end_test").unwrap();
}
#[tokio::test]
async fn test_stream_wrapper_drop_emits_end_event_for_partial_stream() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let events = Arc::new(Mutex::new(Vec::new()));
let captured = events.clone();
register_subscriber(
"stream_drop_end_test",
Arc::new(move |e: &Event| {
captured.lock().unwrap().push(e.clone());
}),
)
.unwrap();
let inner = make_stream(vec![
Ok(json!({"token": "partial"})),
Ok(json!({"token": "unread"})),
]);
let request = LlmRequest {
headers: serde_json::Map::new(),
content: json!({"messages": []}),
};
let handle = llm_call(
LlmCallParams::builder()
.name("stream_drop_llm")
.request(&request)
.attributes(LlmAttributes::STREAMING)
.build(),
)
.unwrap();
let (collector, finalizer, _collected) = make_collector_finalizer();
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
assert_eq!(
wrapper.next().await.unwrap().unwrap(),
json!({"token": "partial"})
);
drop(wrapper);
let events = events.lock().unwrap();
let end_event = events
.iter()
.find(|event| is_llm_end(event))
.expect("expected END event when a partial stream is dropped");
assert_eq!(end_event.output(), Some(&json!([{"token": "partial"}])));
drop(events);
deregister_subscriber("stream_drop_end_test").unwrap();
}
#[tokio::test]
async fn test_stream_wrapper_error_propagation() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let items: Vec<Result<Json>> = vec![
Ok(json!("good chunk")),
Err(FlowError::Internal("stream error".into())),
];
let inner = make_stream(items);
let handle = make_llm_handle("test_llm");
let (collector, finalizer, _collected) = make_collector_finalizer();
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
let first = wrapper.next().await.unwrap();
assert!(first.is_ok());
assert_eq!(first.unwrap(), json!("good chunk"));
let second = wrapper.next().await.unwrap();
assert!(second.is_err());
}
#[tokio::test]
async fn test_stream_wrapper_json_chunks() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let items = vec![Ok(json!({"token": "hello"})), Ok(json!({"token": "world"}))];
let inner = make_stream(items);
let handle = make_llm_handle("test_llm");
let (collector, finalizer, _collected) = make_collector_finalizer();
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
let mut chunks = Vec::new();
while let Some(item) = wrapper.next().await {
chunks.push(item.unwrap());
}
assert_eq!(chunks.len(), 2);
assert_eq!(chunks[0]["token"], "hello");
assert_eq!(chunks[1]["token"], "world");
}
#[tokio::test]
async fn test_stream_wrapper_collector_receives_all_chunks() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let items = vec![
Ok(json!("chunk1")),
Ok(json!("chunk2")),
Ok(json!("chunk3")),
];
let inner = make_stream(items);
let handle = make_llm_handle("test_llm");
let (collector, finalizer, collected) = make_collector_finalizer();
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
while let Some(_item) = wrapper.next().await {}
let chunks = collected.lock().unwrap();
assert_eq!(chunks.len(), 3);
assert_eq!(chunks[0], json!("chunk1"));
assert_eq!(chunks[1], json!("chunk2"));
assert_eq!(chunks[2], json!("chunk3"));
}
#[tokio::test]
async fn test_stream_wrapper_finalizer_called_on_exhaustion() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let finalizer_called = Arc::new(Mutex::new(false));
let fc = finalizer_called.clone();
let items = vec![Ok(json!("chunk"))];
let inner = make_stream(items);
let handle = make_llm_handle("test_llm");
let collector: Box<dyn FnMut(Json) -> Result<()> + Send> = Box::new(|_| Ok(()));
let finalizer: Box<dyn FnOnce() -> Json + Send> = Box::new(move || {
*fc.lock().unwrap() = true;
json!({"finalized": true})
});
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
assert!(!*finalizer_called.lock().unwrap());
while let Some(_item) = wrapper.next().await {}
assert!(*finalizer_called.lock().unwrap());
}
#[tokio::test]
async fn test_stream_wrapper_error_skips_collector_and_finalizes_immediately() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let collector_calls = Arc::new(Mutex::new(0u32));
let cc = collector_calls.clone();
let finalizer_called = Arc::new(Mutex::new(false));
let fc = finalizer_called.clone();
let items: Vec<Result<Json>> = vec![Err(FlowError::Internal("error".into()))];
let inner = make_stream(items);
let handle = make_llm_handle("test_llm");
let collector: Box<dyn FnMut(Json) -> Result<()> + Send> = Box::new(move |_| {
*cc.lock().unwrap() += 1;
Ok(())
});
let finalizer: Box<dyn FnOnce() -> Json + Send> = Box::new(move || {
*fc.lock().unwrap() = true;
Json::Null
});
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
let result = wrapper.next().await.unwrap();
assert!(result.is_err());
assert_eq!(*collector_calls.lock().unwrap(), 0);
assert!(*finalizer_called.lock().unwrap());
assert!(wrapper.next().await.is_none());
}
#[tokio::test]
async fn test_stream_wrapper_error_emits_end_event_on_first_error_poll() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let events = Arc::new(Mutex::new(Vec::new()));
let captured = events.clone();
register_subscriber(
"stream_error_end_test",
Arc::new(move |e: &Event| {
captured.lock().unwrap().push(e.clone());
}),
)
.unwrap();
let items: Vec<Result<Json>> = vec![Err(FlowError::Internal("error".into()))];
let inner = make_stream(items);
let request = LlmRequest {
headers: serde_json::Map::new(),
content: json!({"messages": []}),
};
let handle = llm_call(
LlmCallParams::builder()
.name("stream_error_llm")
.request(&request)
.attributes(LlmAttributes::STREAMING)
.build(),
)
.unwrap();
let collector: Box<dyn FnMut(Json) -> Result<()> + Send> = Box::new(|_| Ok(()));
let finalizer: Box<dyn FnOnce() -> Json + Send> = Box::new(|| json!({"partial": true}));
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
let result = wrapper.next().await.unwrap();
assert!(result.is_err());
let events = events.lock().unwrap();
let end_event = events
.iter()
.find(|event| is_llm_end(event))
.expect("expected END event on first error poll");
assert_eq!(end_event.output(), Some(&json!({"partial": true})));
drop(events);
deregister_subscriber("stream_error_end_test").unwrap();
}
#[tokio::test]
async fn test_stream_wrapper_end_event_contains_intercepted_response() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let events = Arc::new(Mutex::new(Vec::new()));
let ec = events.clone();
register_subscriber(
"end_event_test",
Arc::new(move |e: &Event| {
ec.lock().unwrap().push(e.clone());
}),
)
.unwrap();
let items = vec![Ok(json!({"token": "a"})), Ok(json!({"token": "b"}))];
let inner = make_stream(items);
let request = LlmRequest {
headers: serde_json::Map::new(),
content: json!({"messages": []}),
};
let handle = llm_call(
LlmCallParams::builder()
.name("test_llm")
.request(&request)
.attributes(LlmAttributes::STREAMING)
.build(),
)
.unwrap();
let (collector, finalizer, _collected) = make_collector_finalizer();
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
while let Some(_item) = wrapper.next().await {}
let captured = events.lock().unwrap();
let end_event = captured.iter().find(|e| is_llm_end(e)).unwrap();
let output = end_event.output().unwrap();
assert!(output.is_array());
let arr = output.as_array().unwrap();
assert_eq!(arr.len(), 2);
assert_eq!(arr[0]["token"], "a");
assert_eq!(arr[1]["token"], "b");
drop(captured);
deregister_subscriber("end_event_test").unwrap();
}
#[tokio::test]
async fn test_stream_wrapper_collector_error_terminates_stream() {
let _lock = TEST_MUTEX.lock().unwrap();
reset_global();
let collector_calls = Arc::new(Mutex::new(0u32));
let cc = collector_calls.clone();
let items = vec![
Ok(json!("chunk1")),
Ok(json!("chunk2")),
Ok(json!("chunk3")),
];
let inner = make_stream(items);
let handle = make_llm_handle("test_llm");
let collector: Box<dyn FnMut(Json) -> Result<()> + Send> = Box::new(move |_chunk| {
let mut count = cc.lock().unwrap();
*count += 1;
if *count >= 2 {
Err(FlowError::Internal("collector error".into()))
} else {
Ok(())
}
});
let finalizer: Box<dyn FnOnce() -> Json + Send> = Box::new(|| Json::Null);
let mut wrapper = LlmStreamWrapper::new(inner, handle, collector, finalizer, None, None, None);
let first = wrapper.next().await;
assert!(first.is_some());
assert!(first.unwrap().is_ok());
let second = wrapper.next().await;
assert!(second.is_some());
let second_result = second.unwrap();
assert!(second_result.is_err());
match second_result {
Err(FlowError::Internal(msg)) => {
assert_eq!(msg, "collector error");
}
other => panic!("expected Internal error, got {other:?}"),
}
let third = wrapper.next().await;
assert!(third.is_none());
assert_eq!(*collector_calls.lock().unwrap(), 2);
}