use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use serde_json::Value;
use uuid::Uuid;
use cognis_core::callbacks::base::CallbackHandler;
use cognis_core::error::Result;
use cognis_core::runnables::base::Runnable;
use cognis_core::runnables::config::RunnableConfig;
/// Result of one streaming chain execution.
///
/// Bundles the final output string, the streamed chunks, their count,
/// and run-level metadata (chain name, run id) into a single value.
#[derive(Debug, Clone)]
pub struct StreamingChainResult {
/// Full output text produced by the chain.
pub output: String,
/// Number of chunks recorded for this run.
pub token_count: usize,
/// Individual streamed chunks, in arrival order.
pub chunks: Vec<String>,
/// Run-level metadata, e.g. "chain_name" and "run_id" entries.
pub metadata: HashMap<String, Value>,
}
impl StreamingChainResult {
/// Builds a result from its parts. No validation is performed
/// (e.g. `token_count` is not checked against `chunks.len()`).
pub fn new(
output: String,
token_count: usize,
chunks: Vec<String>,
metadata: HashMap<String, Value>,
) -> Self {
Self {
output,
token_count,
chunks,
metadata,
}
}
}
/// Thread-safe accumulator for tokens streamed during an LLM run.
///
/// Tokens are appended by the `CallbackHandler` implementation below and
/// can be read back individually, joined into one string, counted, or
/// discarded.
pub struct TokenCollector {
    tokens: Mutex<Vec<String>>,
}

impl TokenCollector {
    /// Creates an empty collector.
    pub fn new() -> Self {
        TokenCollector {
            tokens: Mutex::new(Vec::new()),
        }
    }

    /// Returns a snapshot of every collected token, in arrival order.
    pub fn get_tokens(&self) -> Vec<String> {
        let guard = self.tokens.lock().unwrap();
        guard.clone()
    }

    /// Concatenates all collected tokens into a single string.
    pub fn get_full_text(&self) -> String {
        let guard = self.tokens.lock().unwrap();
        guard.concat()
    }

    /// Number of tokens collected so far.
    pub fn token_count(&self) -> usize {
        let guard = self.tokens.lock().unwrap();
        guard.len()
    }

    /// Discards all collected tokens.
    pub fn clear(&self) {
        let mut guard = self.tokens.lock().unwrap();
        guard.clear();
    }
}

impl Default for TokenCollector {
    fn default() -> Self {
        TokenCollector::new()
    }
}
#[async_trait]
impl CallbackHandler for TokenCollector {
    /// Identifies this handler in callback diagnostics.
    fn name(&self) -> &str {
        "TokenCollector"
    }

    /// Appends each newly streamed LLM token to the internal buffer.
    async fn on_llm_new_token(
        &self,
        token: &str,
        _run_id: Uuid,
        _parent_run_id: Option<Uuid>,
    ) -> Result<()> {
        let mut tokens = self.tokens.lock().unwrap();
        tokens.push(token.to_owned());
        Ok(())
    }
}
/// Simplified streaming interface for consumers that only care about
/// token and lifecycle text events, without run ids or serialized
/// payloads.
#[async_trait]
pub trait StreamingCallback: Send + Sync {
/// Called once per streamed token.
async fn on_token(&self, token: &str);
/// Called when a chain run begins; `name` identifies the chain.
async fn on_chain_start(&self, name: &str);
/// Called when a chain run completes, with the final output text.
async fn on_chain_end(&self, output: &str);
/// Called when a chain or LLM run fails, with the error message.
async fn on_error(&self, error: &str);
}
/// Bridges a [`StreamingCallback`] to the richer `CallbackHandler`
/// interface expected by the callback machinery.
pub struct StreamingCallbackAdapter {
// Wrapped simplified callback; `Arc` so the caller can keep a handle
// to it (e.g. to inspect recorded events in tests).
inner: Arc<dyn StreamingCallback>,
}
impl StreamingCallbackAdapter {
/// Wraps the given streaming callback.
pub fn new(callback: Arc<dyn StreamingCallback>) -> Self {
Self { inner: callback }
}
}
#[async_trait]
impl CallbackHandler for StreamingCallbackAdapter {
fn name(&self) -> &str {
"StreamingCallbackAdapter"
}
async fn on_llm_new_token(
&self,
token: &str,
_run_id: Uuid,
_parent_run_id: Option<Uuid>,
) -> Result<()> {
self.inner.on_token(token).await;
Ok(())
}
async fn on_chain_start(
&self,
_serialized: &Value,
_inputs: &Value,
_run_id: Uuid,
_parent_run_id: Option<Uuid>,
) -> Result<()> {
self.inner.on_chain_start("chain").await;
Ok(())
}
async fn on_chain_end(
&self,
outputs: &Value,
_run_id: Uuid,
_parent_run_id: Option<Uuid>,
) -> Result<()> {
let output_str = outputs
.as_str()
.map(|s| s.to_string())
.unwrap_or_else(|| serde_json::to_string(outputs).unwrap_or_default());
self.inner.on_chain_end(&output_str).await;
Ok(())
}
async fn on_chain_error(
&self,
error: &str,
_run_id: Uuid,
_parent_run_id: Option<Uuid>,
) -> Result<()> {
self.inner.on_error(error).await;
Ok(())
}
async fn on_llm_error(
&self,
error: &str,
_run_id: Uuid,
_parent_run_id: Option<Uuid>,
) -> Result<()> {
self.inner.on_error(error).await;
Ok(())
}
}
/// Streaming callback that echoes tokens and lifecycle events to the
/// terminal (tokens/status on stdout, errors on stderr).
pub struct ConsoleStreamingCallback;

impl ConsoleStreamingCallback {
    /// Creates a new console callback; the type is stateless.
    pub fn new() -> Self {
        ConsoleStreamingCallback
    }
}

impl Default for ConsoleStreamingCallback {
    fn default() -> Self {
        ConsoleStreamingCallback
    }
}
#[async_trait]
impl StreamingCallback for ConsoleStreamingCallback {
    /// Prints each token without a trailing newline and flushes stdout
    /// so the text appears immediately; flush failures are ignored.
    async fn on_token(&self, token: &str) {
        use std::io::Write;
        print!("{}", token);
        let _ = std::io::stdout().flush();
    }

    /// Announces the start of a chain run.
    async fn on_chain_start(&self, name: &str) {
        println!("\n> Starting chain: {}", name);
    }

    /// Announces completion along with the final output text.
    async fn on_chain_end(&self, output: &str) {
        println!("\n> Chain complete: {}", output);
    }

    /// Reports errors on stderr.
    async fn on_error(&self, error: &str) {
        eprintln!("\n> Chain error: {}", error);
    }
}
/// Runs a chain while fanning out lifecycle and token events to a set
/// of registered callback handlers.
pub struct StreamingChainExecutor {
// The runnable to execute.
chain: Arc<dyn Runnable>,
// Handlers notified for every run started through this executor.
callbacks: Vec<Arc<dyn CallbackHandler>>,
}
impl StreamingChainExecutor {
    /// Creates an executor for `chain` with no callbacks registered.
    pub fn new(chain: Arc<dyn Runnable>) -> Self {
        Self {
            chain,
            callbacks: Vec::new(),
        }
    }

    /// Builder-style registration of a callback handler.
    pub fn with_callback(mut self, handler: Arc<dyn CallbackHandler>) -> Self {
        self.callbacks.push(handler);
        self
    }

    /// Registers a callback handler on an existing executor.
    pub fn add_callback(&mut self, handler: Arc<dyn CallbackHandler>) {
        self.callbacks.push(handler);
    }

    /// Executes the chain once, notifying both the executor's own
    /// callbacks and any per-call `extra_callbacks` of chain
    /// start/end/error events. All callbacks are also passed to the
    /// chain via `RunnableConfig`, so handlers such as `TokenCollector`
    /// receive per-token events from the underlying model.
    ///
    /// # Errors
    /// Propagates the chain's own error (after notifying callbacks via
    /// `on_chain_error`, best-effort), as well as any error returned by
    /// a callback's `on_chain_start`/`on_chain_end`.
    pub async fn execute_streaming(
        &self,
        input: Value,
        extra_callbacks: Option<Vec<Arc<dyn CallbackHandler>>>,
    ) -> Result<StreamingChainResult> {
        let run_id = Uuid::new_v4();
        let chain_name = self.chain.name().to_string();
        // Own callbacks first, then any per-call extras.
        let all_callbacks: Vec<&Arc<dyn CallbackHandler>> = self
            .callbacks
            .iter()
            .chain(extra_callbacks.iter().flatten())
            .collect();
        let serialized = serde_json::json!({"name": chain_name});
        for cb in &all_callbacks {
            if !cb.ignore_chain() {
                cb.on_chain_start(&serialized, &input, run_id, None).await?;
            }
        }
        // Hand every callback to the chain so nested runnables (e.g. the
        // model) can emit token-level events to them.
        let config = RunnableConfig {
            callbacks: all_callbacks.iter().map(|cb| Arc::clone(cb)).collect(),
            ..RunnableConfig::default()
        };
        match self.chain.invoke(input, Some(&config)).await {
            Ok(output) => {
                for cb in &all_callbacks {
                    if !cb.ignore_chain() {
                        cb.on_chain_end(&output, run_id, None).await?;
                    }
                }
                // Prefer the first string value of an object output;
                // otherwise fall back to the JSON encoding of the whole
                // output (empty string if encoding fails).
                let output_text = output
                    .as_object()
                    .and_then(|m| m.values().next().and_then(|v| v.as_str().map(String::from)))
                    .unwrap_or_else(|| serde_json::to_string(&output).unwrap_or_default());
                // The executor does not intercept tokens itself, so the
                // result carries the whole output as a single chunk.
                // Attach a TokenCollector callback to observe individual
                // tokens.
                let chunks = vec![output_text.clone()];
                let token_count = chunks.len();
                let mut metadata = HashMap::new();
                metadata.insert("chain_name".to_string(), Value::String(chain_name));
                metadata.insert("run_id".to_string(), Value::String(run_id.to_string()));
                Ok(StreamingChainResult::new(
                    output_text,
                    token_count,
                    chunks,
                    metadata,
                ))
            }
            Err(e) => {
                // Best-effort notification: a failing callback must not
                // mask the chain's original error.
                let error_str = e.to_string();
                for cb in &all_callbacks {
                    if !cb.ignore_chain() {
                        let _ = cb.on_chain_error(&error_str, run_id, None).await;
                    }
                }
                Err(e)
            }
        }
    }
}
/// Convenience wrapper: runs `chain` with `input`, forwarding streaming
/// events to `callback` through a [`StreamingCallbackAdapter`], and
/// returns the chain's final output text.
pub async fn stream_chain(
    chain: Arc<dyn Runnable>,
    input: Value,
    callback: Arc<dyn StreamingCallback>,
) -> Result<String> {
    let executor = StreamingChainExecutor::new(chain)
        .with_callback(Arc::new(StreamingCallbackAdapter::new(callback)));
    let outcome = executor.execute_streaming(input, None).await?;
    Ok(outcome.output)
}
// Unit tests. Chains are built around FakeListChatModel so no network or
// real model is needed; RecordingCallback captures the event stream for
// ordering/content assertions.
#[cfg(test)]
mod tests {
use super::*;
use cognis_core::language_models::chat_model::BaseChatModel;
use cognis_core::language_models::fake::FakeListChatModel;
// Builds a fake chat model that replays the given canned responses.
fn fake_model(responses: Vec<&str>) -> Arc<dyn BaseChatModel> {
Arc::new(FakeListChatModel::new(
responses.into_iter().map(String::from).collect(),
))
}
// Builds an LLMChain over a fake model with a pass-through prompt.
fn make_chain(responses: Vec<&str>) -> Arc<dyn Runnable> {
use crate::chains::LLMChain;
Arc::new(
LLMChain::builder()
.model(fake_model(responses))
.prompt("{input}")
.build(),
)
}
// Test double for StreamingCallback that records each event as a
// "kind:payload" string so tests can assert content and ordering.
struct RecordingCallback {
tokens: Mutex<Vec<String>>,
events: Mutex<Vec<String>>,
}
impl RecordingCallback {
fn new() -> Self {
Self {
tokens: Mutex::new(Vec::new()),
events: Mutex::new(Vec::new()),
}
}
// Snapshot of all recorded events, in arrival order.
fn events(&self) -> Vec<String> {
self.events.lock().unwrap().clone()
}
}
#[async_trait]
impl StreamingCallback for RecordingCallback {
async fn on_token(&self, token: &str) {
self.tokens.lock().unwrap().push(token.to_string());
self.events.lock().unwrap().push(format!("token:{}", token));
}
async fn on_chain_start(&self, name: &str) {
self.events.lock().unwrap().push(format!("start:{}", name));
}
async fn on_chain_end(&self, output: &str) {
self.events.lock().unwrap().push(format!("end:{}", output));
}
async fn on_error(&self, error: &str) {
self.events.lock().unwrap().push(format!("error:{}", error));
}
}
// TokenCollector stores tokens delivered via on_llm_new_token in order.
#[tokio::test]
async fn test_token_collector_captures_tokens() {
let collector = TokenCollector::new();
let run_id = Uuid::new_v4();
collector
.on_llm_new_token("Hello", run_id, None)
.await
.unwrap();
collector
.on_llm_new_token(" world", run_id, None)
.await
.unwrap();
let tokens = collector.get_tokens();
assert_eq!(tokens, vec!["Hello", " world"]);
}
// get_full_text joins tokens with no separator.
#[tokio::test]
async fn test_token_collector_get_full_text() {
let collector = TokenCollector::new();
let run_id = Uuid::new_v4();
collector
.on_llm_new_token("Hello", run_id, None)
.await
.unwrap();
collector
.on_llm_new_token(", ", run_id, None)
.await
.unwrap();
collector
.on_llm_new_token("world!", run_id, None)
.await
.unwrap();
assert_eq!(collector.get_full_text(), "Hello, world!");
}
// clear() empties the collector completely.
#[tokio::test]
async fn test_token_collector_clear() {
let collector = TokenCollector::new();
let run_id = Uuid::new_v4();
collector
.on_llm_new_token("token1", run_id, None)
.await
.unwrap();
collector
.on_llm_new_token("token2", run_id, None)
.await
.unwrap();
assert_eq!(collector.token_count(), 2);
collector.clear();
assert_eq!(collector.token_count(), 0);
assert!(collector.get_tokens().is_empty());
assert_eq!(collector.get_full_text(), "");
}
// The adapter forwards token/start/end CallbackHandler events to the
// wrapped StreamingCallback with the expected payloads.
#[tokio::test]
async fn test_streaming_callback_adapter_wraps_correctly() {
let recording = Arc::new(RecordingCallback::new());
let adapter = StreamingCallbackAdapter::new(recording.clone());
let run_id = Uuid::new_v4();
adapter
.on_llm_new_token("test_token", run_id, None)
.await
.unwrap();
adapter
.on_chain_start(&serde_json::json!({}), &serde_json::json!({}), run_id, None)
.await
.unwrap();
adapter
.on_chain_end(&serde_json::json!("done"), run_id, None)
.await
.unwrap();
let events = recording.events();
assert!(events.contains(&"token:test_token".to_string()));
assert!(events.contains(&"start:chain".to_string()));
assert!(events.contains(&"end:done".to_string()));
}
// Smoke test: both construction paths for the console callback work.
#[test]
fn test_console_streaming_callback_creation() {
let _cb = ConsoleStreamingCallback::new();
let _cb_default = ConsoleStreamingCallback::default();
}
// StreamingChainResult::new stores all fields verbatim.
#[test]
fn test_streaming_chain_result_construction() {
let mut metadata = HashMap::new();
metadata.insert("key".to_string(), Value::String("value".to_string()));
let result = StreamingChainResult::new(
"full output".to_string(),
3,
vec!["a".to_string(), "b".to_string(), "c".to_string()],
metadata.clone(),
);
assert_eq!(result.output, "full output");
assert_eq!(result.token_count, 3);
assert_eq!(result.chunks.len(), 3);
assert_eq!(result.chunks[0], "a");
assert_eq!(result.chunks[1], "b");
assert_eq!(result.chunks[2], "c");
assert_eq!(
result.metadata.get("key"),
Some(&Value::String("value".to_string()))
);
}
// Independent collectors fed the same tokens end up with identical state.
#[tokio::test]
async fn test_multiple_callbacks_receive_same_tokens() {
let collector1 = Arc::new(TokenCollector::new());
let collector2 = Arc::new(TokenCollector::new());
let run_id = Uuid::new_v4();
let tokens = vec!["Hello", " ", "world"];
for token in &tokens {
collector1
.on_llm_new_token(token, run_id, None)
.await
.unwrap();
collector2
.on_llm_new_token(token, run_id, None)
.await
.unwrap();
}
assert_eq!(collector1.get_tokens(), vec!["Hello", " ", "world"]);
assert_eq!(collector2.get_tokens(), vec!["Hello", " ", "world"]);
assert_eq!(collector1.get_full_text(), collector2.get_full_text());
}
// token_count tracks the number of on_llm_new_token calls.
#[tokio::test]
async fn test_token_count_tracking() {
let collector = TokenCollector::new();
let run_id = Uuid::new_v4();
assert_eq!(collector.token_count(), 0);
for i in 0..5 {
collector
.on_llm_new_token(&format!("t{}", i), run_id, None)
.await
.unwrap();
}
assert_eq!(collector.token_count(), 5);
}
// Empty collector and empty result are well-formed.
#[tokio::test]
async fn test_empty_stream_handling() {
let collector = TokenCollector::new();
assert_eq!(collector.token_count(), 0);
assert!(collector.get_tokens().is_empty());
assert_eq!(collector.get_full_text(), "");
let result = StreamingChainResult::new(String::new(), 0, Vec::new(), HashMap::new());
assert_eq!(result.output, "");
assert_eq!(result.token_count, 0);
assert!(result.chunks.is_empty());
}
// execute_streaming records the chain name and run id in metadata.
#[tokio::test]
async fn test_chain_metadata_in_result() {
let chain = make_chain(vec!["response text"]);
let executor = StreamingChainExecutor::new(chain);
let result = executor
.execute_streaming(serde_json::json!({"input": "hello"}), None)
.await
.unwrap();
assert!(result.metadata.contains_key("chain_name"));
assert_eq!(
result.metadata.get("chain_name"),
Some(&Value::String("LLMChain".to_string()))
);
assert!(result.metadata.contains_key("run_id"));
}
// RecordingCallback itself records one event per trait method call.
#[tokio::test]
async fn test_streaming_callback_trait_impl() {
let recording = Arc::new(RecordingCallback::new());
recording.on_token("hello").await;
recording.on_chain_start("test_chain").await;
recording.on_chain_end("done").await;
recording.on_error("something failed").await;
let events = recording.events();
assert_eq!(events.len(), 4);
assert_eq!(events[0], "token:hello");
assert_eq!(events[1], "start:test_chain");
assert_eq!(events[2], "end:done");
assert_eq!(events[3], "error:something failed");
}
// A successful run emits start first and end last.
#[tokio::test]
async fn test_callback_ordering() {
let recording = Arc::new(RecordingCallback::new());
let adapter: Arc<dyn CallbackHandler> =
Arc::new(StreamingCallbackAdapter::new(recording.clone()));
let chain = make_chain(vec!["response"]);
let executor = StreamingChainExecutor::new(chain).with_callback(adapter);
let _result = executor
.execute_streaming(serde_json::json!({"input": "test"}), None)
.await
.unwrap();
let events = recording.events();
assert!(
events.first().map_or(false, |e| e.starts_with("start:")),
"First event should be chain start, got: {:?}",
events
);
assert!(
events.last().map_or(false, |e| e.starts_with("end:")),
"Last event should be chain end, got: {:?}",
events
);
}
// A failing run (prompt references a variable missing from the input)
// still emits a start event and then an error event.
#[tokio::test]
async fn test_error_callback_on_failure() {
let chain: Arc<dyn Runnable> = Arc::new(
crate::chains::LLMChain::builder()
.model(fake_model(vec!["response"]))
.prompt("{missing_var}")
.build(),
);
let recording = Arc::new(RecordingCallback::new());
let adapter: Arc<dyn CallbackHandler> =
Arc::new(StreamingCallbackAdapter::new(recording.clone()));
let executor = StreamingChainExecutor::new(chain).with_callback(adapter);
let result = executor
.execute_streaming(serde_json::json!({"input": "test"}), None)
.await;
assert!(result.is_err());
let events = recording.events();
assert!(
events.iter().any(|e| e.starts_with("start:")),
"Should have a start event, got: {:?}",
events
);
assert!(
events.iter().any(|e| e.starts_with("error:")),
"Should have an error event, got: {:?}",
events
);
}
}