pub mod sse;
pub mod anthropic;
pub mod openai;
use crate::diagnostic::StopReason;
use crate::value::FlexValue;
use sse::{SseEvent, SseParser};
/// The LLM API whose streaming wire format a [`FlexStream`] decodes.
///
/// [`FlexStream::new`] uses this to pick the built-in [`StreamHandler`]. `Hash` is derived
/// so a provider can key a `HashMap`/`HashSet` (e.g. per-provider settings).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Provider {
    /// Decoded by `anthropic::AnthropicStreamHandler`.
    Anthropic,
    /// Decoded by `openai::OpenAiStreamHandler`.
    OpenAI,
}
/// Options used by [`FlexStream::new`] to build a stream.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Which provider's wire format to decode; selects the built-in handler.
    pub provider: Provider,
    /// Maximum number of bytes to buffer (the default is 1 MiB).
    ///
    /// NOTE(review): `FlexStream` stores its config but never reads it in this module,
    /// so this limit is not enforced here. TODO: confirm whether it should be.
    pub max_buffer_bytes: usize,
}
impl Default for StreamConfig {
fn default() -> Self {
Self {
provider: Provider::Anthropic,
max_buffer_bytes: 1_048_576, }
}
}
/// A provider-neutral event produced by a [`StreamHandler`] and emitted by [`FlexStream`].
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// A fragment of message text; [`MessageSnapshot`] appends it to its `text`.
    TextDelta(String),
    /// Opens a content block; `FlexStream` starts collecting fragments for `index`.
    BlockStart {
        /// Key that later `BlockDelta`/`BlockComplete` events for this block share.
        index: usize,
        /// Block identifier as reported by the handler.
        id: String,
        /// Block kind; `"tool_use"` and `"function"` are recorded as tool calls.
        block_type: String,
        /// Optional block name (the tool name for tool-call blocks).
        name: Option<String>,
    },
    /// One piece of a block's content, collected under `index` by `FlexStream`.
    BlockDelta {
        /// Block this fragment belongs to.
        index: usize,
        /// Raw content fragment; fragments are concatenated in arrival order.
        fragment: String,
    },
    /// Closes a content block.
    ///
    /// When the block was opened by a `BlockStart`, `FlexStream` re-emits this event with
    /// `id`, `block_type` and `name` from that start and `content` built from the
    /// collected fragments.
    BlockComplete {
        /// Block being closed.
        index: usize,
        /// Block identifier.
        id: String,
        /// Block kind.
        block_type: String,
        /// Optional block name.
        name: Option<String>,
        /// Block content; parsed JSON when the fragments form valid JSON.
        content: FlexValue,
    },
    /// Extra information from the provider that has no dedicated variant
    /// (presumably usage or message details; confirm against the handlers).
    Metadata(FlexValue),
    /// End of the message; [`MessageSnapshot`] records the reason and marks itself done.
    Stop(StopReason),
    /// An SSE event the handler did not recognise, with its type and payload.
    Unknown {
        /// SSE event type as received.
        event_type: String,
        /// Event payload.
        data: FlexValue,
    },
    /// The handler could not parse an SSE event's data.
    ParseError {
        /// SSE event type, if the event had one.
        event_type: Option<String>,
        /// The unparsed data text.
        raw_data: String,
        /// Description of the parse failure.
        error: String,
    },
}
/// Translates one provider's SSE events into [`StreamEvent`]s.
///
/// The `Send` supertrait makes `Box<dyn StreamHandler>` itself `Send`.
pub trait StreamHandler: Send {
    /// Converts a single SSE event into zero or more stream events.
    ///
    /// Takes `&self`, so the stream never gives the handler mutable access.
    fn process_event(&self, sse: &SseEvent) -> Vec<StreamEvent>;
}
/// Running view of the message assembled so far from emitted [`StreamEvent`]s.
#[derive(Debug, Clone)]
pub struct MessageSnapshot {
    /// Concatenation of every `TextDelta` applied so far.
    pub text: String,
    /// Completed tool calls as `(id, name, content)`, in the order they completed.
    pub tool_calls: Vec<(String, String, FlexValue)>,
    /// Reason carried by the `Stop` event, once one has been applied.
    pub stop_reason: Option<StopReason>,
    /// `true` once a `Stop` event has been applied.
    pub done: bool,
}
impl MessageSnapshot {
fn new() -> Self {
Self {
text: String::new(),
tool_calls: Vec::new(),
stop_reason: None,
done: false,
}
}
fn apply_event(&mut self, event: &StreamEvent) {
match event {
StreamEvent::TextDelta(text) => {
self.text.push_str(text);
}
StreamEvent::BlockComplete {
id,
name,
content,
block_type,
..
} => {
if block_type == "tool_use" || block_type == "function" {
self.tool_calls.push((
id.clone(),
name.clone().unwrap_or_default(),
content.clone(),
));
}
}
StreamEvent::Stop(reason) => {
self.stop_reason = Some(reason.clone());
self.done = true;
}
_ => {}
}
}
}
/// Incremental decoder that turns raw SSE bytes or text into [`StreamEvent`]s.
///
/// Input goes through an [`SseParser`]. Each resulting SSE event is translated by a
/// [`StreamHandler`]. Block fragments are collected per block index, and a
/// [`MessageSnapshot`] is updated with every emitted event.
pub struct FlexStream {
    /// Configuration the stream was built with.
    /// NOTE(review): never read in this module, so `max_buffer_bytes` is not enforced
    /// here. Confirm whether that is intended.
    _config: StreamConfig,
    /// Splits incoming bytes/text into SSE events.
    sse_parser: SseParser,
    /// Provider-specific translation from SSE events to stream events.
    handler: Box<dyn StreamHandler>,
    /// Blocks that have started but not yet completed, keyed by block index.
    block_accumulators: std::collections::HashMap<usize, BlockAccumulator>,
    /// Message state built from the events emitted so far.
    snapshot: MessageSnapshot,
}
/// Bookkeeping for one open content block, from its `BlockStart` until its
/// `BlockComplete`.
///
/// Every field is read when the block completes, so no `dead_code` allowance is needed.
#[derive(Debug, Clone)]
struct BlockAccumulator {
    /// `id` taken from the `BlockStart` event.
    id: String,
    /// `block_type` taken from the `BlockStart` event.
    block_type: String,
    /// `name` taken from the `BlockStart` event.
    name: Option<String>,
    /// `BlockDelta` fragments in arrival order; concatenated when the block completes.
    content_fragments: Vec<String>,
}
fn handler_for_provider(provider: Provider) -> Box<dyn StreamHandler> {
match provider {
Provider::Anthropic => Box::new(anthropic::AnthropicStreamHandler),
Provider::OpenAI => Box::new(openai::OpenAiStreamHandler),
}
}
impl FlexStream {
pub fn new(config: StreamConfig) -> Self {
let handler = handler_for_provider(config.provider);
Self {
_config: config,
sse_parser: SseParser::new(),
handler,
block_accumulators: std::collections::HashMap::new(),
snapshot: MessageSnapshot::new(),
}
}
pub fn with_handler(handler: Box<dyn StreamHandler>) -> Self {
Self {
_config: StreamConfig::default(),
sse_parser: SseParser::new(),
handler,
block_accumulators: std::collections::HashMap::new(),
snapshot: MessageSnapshot::new(),
}
}
pub fn current_message(&self) -> &MessageSnapshot {
&self.snapshot
}
pub fn feed(&mut self, chunk: &[u8]) -> Vec<StreamEvent> {
let sse_events = self.sse_parser.feed_bytes(chunk);
self.process_sse_events(sse_events)
}
pub fn feed_str(&mut self, chunk: &str) -> Vec<StreamEvent> {
let sse_events = self.sse_parser.feed(chunk);
self.process_sse_events(sse_events)
}
pub fn finish(self) -> Vec<StreamEvent> {
let sse_events = self.sse_parser.finish();
let mut stream_events = Vec::new();
for sse_event in sse_events {
stream_events.extend(self.handler.process_event(&sse_event));
}
stream_events
}
fn process_sse_events(&mut self, sse_events: Vec<SseEvent>) -> Vec<StreamEvent> {
let mut stream_events = Vec::new();
for sse_event in sse_events {
let events = self.handler.process_event(&sse_event);
for event in events {
match event {
StreamEvent::BlockStart {
index,
ref id,
ref block_type,
ref name,
} => {
self.block_accumulators.insert(
index,
BlockAccumulator {
id: id.clone(),
block_type: block_type.clone(),
name: name.clone(),
content_fragments: Vec::new(),
},
);
stream_events.push(event);
}
StreamEvent::BlockDelta {
index,
ref fragment,
} => {
if let Some(acc) = self.block_accumulators.get_mut(&index) {
acc.content_fragments.push(fragment.clone());
}
stream_events.push(event);
}
StreamEvent::BlockComplete { index, .. } => {
if let Some(acc) = self.block_accumulators.remove(&index) {
let assembled_json = acc.content_fragments.join("");
let content = if assembled_json.is_empty() {
FlexValue::new(serde_json::Value::Null)
} else {
match serde_json::from_str::<serde_json::Value>(&assembled_json) {
Ok(v) => FlexValue::new(v),
Err(_) => {
FlexValue::new(serde_json::Value::String(assembled_json))
}
}
};
stream_events.push(StreamEvent::BlockComplete {
index,
id: acc.id,
block_type: acc.block_type,
name: acc.name,
content,
});
} else {
stream_events.push(event);
}
}
_ => {
stream_events.push(event);
}
}
}
}
for event in &stream_events {
self.snapshot.apply_event(event);
}
stream_events
}
}