use crate::messages::{StopReason, ToolCall};
use crate::types::CorrelationId;
use std::collections::HashMap;
/// Accumulated state of a single stream, identified by its correlation ID.
///
/// All fields are public; the methods on this type are thin conveniences for
/// updating them.
#[derive(Debug, Clone)]
pub struct ActiveStream {
    /// Correlation ID this stream was created for.
    pub correlation_id: CorrelationId,
    /// Text built up so far; [`ActiveStream::append_token`] appends to it.
    pub content: String,
    /// Tool calls recorded so far, in the order they were added.
    pub tool_calls: Vec<ToolCall>,
    /// `false` on construction; set to `true` by [`ActiveStream::mark_started`].
    pub started: bool,
    /// `None` on construction; set to `Some(..)` by [`ActiveStream::mark_ended`].
    pub stop_reason: Option<StopReason>,
}
impl ActiveStream {
    /// Creates a stream for `correlation_id` with empty content, no tool
    /// calls, `started == false`, and no stop reason.
    #[must_use]
    pub fn new(correlation_id: CorrelationId) -> Self {
        Self {
            correlation_id,
            content: String::new(),
            tool_calls: Vec::new(),
            started: false,
            stop_reason: None,
        }
    }

    /// Flags the stream as started.
    pub fn mark_started(&mut self) {
        self.started = true;
    }

    /// Appends `token` to the end of the accumulated content.
    pub fn append_token(&mut self, token: &str) {
        self.content += token;
    }

    /// Records `tool_call` after any previously recorded tool calls.
    pub fn add_tool_call(&mut self, tool_call: ToolCall) {
        self.tool_calls.push(tool_call);
    }

    /// Stores `stop_reason`, overwriting any stop reason already present.
    pub fn mark_ended(&mut self, stop_reason: StopReason) {
        self.stop_reason = Some(stop_reason);
    }

    /// Returns `true` once a stop reason has been stored.
    #[must_use]
    pub fn is_ended(&self) -> bool {
        self.stop_reason.is_some()
    }

    /// Returns `true` when at least one tool call has been recorded.
    #[must_use]
    pub fn has_tool_calls(&self) -> bool {
        !self.tool_calls.is_empty()
    }

    /// Returns the length of the accumulated content in bytes
    /// (this is `String::len`, not a character count).
    #[must_use]
    pub fn content_length(&self) -> usize {
        self.content.len()
    }
}
/// A set of [`ActiveStream`]s that can be looked up by correlation ID.
///
/// The default value holds no streams.
#[derive(Debug, Clone, Default)]
pub struct StreamAccumulator {
    /// Streams keyed by the `to_string()` form of their correlation ID.
    streams: HashMap<String, ActiveStream>,
}
impl StreamAccumulator {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn start_stream(&mut self, correlation_id: &CorrelationId) {
let key = correlation_id.to_string();
let mut stream = ActiveStream::new(correlation_id.clone());
stream.mark_started();
self.streams.insert(key, stream);
}
pub fn get_stream_mut(&mut self, correlation_id: &CorrelationId) -> Option<&mut ActiveStream> {
let key = correlation_id.to_string();
self.streams.get_mut(&key)
}
#[must_use]
pub fn get_stream(&self, correlation_id: &CorrelationId) -> Option<&ActiveStream> {
let key = correlation_id.to_string();
self.streams.get(&key)
}
pub fn append_token(&mut self, correlation_id: &CorrelationId, token: &str) {
if let Some(stream) = self.get_stream_mut(correlation_id) {
stream.append_token(token);
}
}
pub fn add_tool_call(&mut self, correlation_id: &CorrelationId, tool_call: ToolCall) {
if let Some(stream) = self.get_stream_mut(correlation_id) {
stream.add_tool_call(tool_call);
}
}
pub fn end_stream(
&mut self,
correlation_id: &CorrelationId,
stop_reason: StopReason,
) -> Option<ActiveStream> {
let key = correlation_id.to_string();
if let Some(mut stream) = self.streams.remove(&key) {
stream.mark_ended(stop_reason);
Some(stream)
} else {
None
}
}
pub fn remove_stream(&mut self, correlation_id: &CorrelationId) -> Option<ActiveStream> {
let key = correlation_id.to_string();
self.streams.remove(&key)
}
#[must_use]
pub fn active_count(&self) -> usize {
self.streams.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.streams.is_empty()
}
pub fn clear(&mut self) {
self.streams.clear();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::ToolCall;

    // A newly built stream keeps its ID and starts out empty and unstarted.
    #[test]
    fn active_stream_new() {
        let id = CorrelationId::new();
        let fresh = ActiveStream::new(id.clone());

        assert_eq!(fresh.correlation_id, id);
        assert!(fresh.content.is_empty());
        assert!(fresh.tool_calls.is_empty());
        assert!(!fresh.started);
        assert!(fresh.stop_reason.is_none());
    }

    // Tokens are concatenated in call order.
    #[test]
    fn active_stream_append_token() {
        let mut active = ActiveStream::new(CorrelationId::new());

        active.append_token("Hello");
        active.append_token(" ");
        active.append_token("World");

        assert_eq!(active.content, "Hello World");
        assert_eq!(active.content_length(), 11);
    }

    // Adding a tool call makes it visible through `has_tool_calls`.
    #[test]
    fn active_stream_add_tool_call() {
        let mut active = ActiveStream::new(CorrelationId::new());
        let call = ToolCall {
            id: "tc_123".to_string(),
            name: "search".to_string(),
            arguments: serde_json::json!({"query": "rust"}),
        };

        active.add_tool_call(call);

        assert!(active.has_tool_calls());
        assert_eq!(active.tool_calls.len(), 1);
    }

    // new -> started -> ended, checking the flags at every step.
    #[test]
    fn active_stream_lifecycle() {
        let mut active = ActiveStream::new(CorrelationId::new());
        assert!(!active.started);
        assert!(!active.is_ended());

        active.mark_started();
        assert!(active.started);
        assert!(!active.is_ended());

        active.mark_ended(StopReason::EndTurn);
        assert!(active.is_ended());
        assert_eq!(active.stop_reason, Some(StopReason::EndTurn));
    }

    // Starting a stream registers it and marks it as started.
    #[test]
    fn stream_accumulator_start_stream() {
        let mut acc = StreamAccumulator::new();
        let id = CorrelationId::new();
        assert!(acc.is_empty());

        acc.start_stream(&id);

        assert_eq!(acc.active_count(), 1);
        assert!(!acc.is_empty());
        let registered = acc.get_stream(&id).unwrap();
        assert!(registered.started);
    }

    // Tokens routed through the accumulator land on the right stream.
    #[test]
    fn stream_accumulator_append_token() {
        let mut acc = StreamAccumulator::new();
        let id = CorrelationId::new();
        acc.start_stream(&id);

        acc.append_token(&id, "Hello");
        acc.append_token(&id, " World");

        let registered = acc.get_stream(&id).unwrap();
        assert_eq!(registered.content, "Hello World");
    }

    // Ending a stream returns it with a stop reason and removes it.
    #[test]
    fn stream_accumulator_end_stream() {
        let mut acc = StreamAccumulator::new();
        let id = CorrelationId::new();
        acc.start_stream(&id);
        acc.append_token(&id, "Test content");

        let finished = acc.end_stream(&id, StopReason::EndTurn).unwrap();

        assert_eq!(finished.content, "Test content");
        assert!(finished.is_ended());
        assert!(acc.is_empty());
    }

    // Removing a stream returns it without marking it as ended.
    #[test]
    fn stream_accumulator_remove_stream() {
        let mut acc = StreamAccumulator::new();
        let id = CorrelationId::new();
        acc.start_stream(&id);

        let removed = acc.remove_stream(&id);

        assert!(removed.is_some());
        assert!(!removed.unwrap().is_ended());
        assert!(acc.is_empty());
    }

    // Two streams accumulate independently of each other.
    #[test]
    fn stream_accumulator_multiple_streams() {
        let mut acc = StreamAccumulator::new();
        let first_id = CorrelationId::new();
        let second_id = CorrelationId::new();
        acc.start_stream(&first_id);
        acc.start_stream(&second_id);

        acc.append_token(&first_id, "Stream 1");
        acc.append_token(&second_id, "Stream 2");

        assert_eq!(acc.active_count(), 2);
        let first = acc.get_stream(&first_id).unwrap();
        let second = acc.get_stream(&second_id).unwrap();
        assert_eq!(first.content, "Stream 1");
        assert_eq!(second.content, "Stream 2");
    }

    // `clear` drops every registered stream.
    #[test]
    fn stream_accumulator_clear() {
        let mut acc = StreamAccumulator::new();
        acc.start_stream(&CorrelationId::new());
        acc.start_stream(&CorrelationId::new());
        assert_eq!(acc.active_count(), 2);

        acc.clear();

        assert!(acc.is_empty());
    }

    // Tool calls routed through the accumulator land on the right stream.
    #[test]
    fn stream_accumulator_add_tool_call() {
        let mut acc = StreamAccumulator::new();
        let id = CorrelationId::new();
        acc.start_stream(&id);
        let call = ToolCall {
            id: "tc_456".to_string(),
            name: "calculator".to_string(),
            arguments: serde_json::json!({"expression": "2+2"}),
        };

        acc.add_tool_call(&id, call);

        let registered = acc.get_stream(&id).unwrap();
        assert!(registered.has_tool_calls());
        assert_eq!(registered.tool_calls[0].name, "calculator");
    }
}