use crate::error::LlmConnectorError;
use futures_util::{Stream, StreamExt};
use std::pin::Pin;
/// Wire format of a streaming response body, as consumed by [`create_text_stream`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum StreamFormat {
    /// Server-Sent Events: `data:` lines grouped into events separated by a blank line.
    Sse,
    /// Newline-delimited JSON: one payload per line.
    NdJson,
    /// Infer the format by sniffing the buffered body bytes.
    Auto,
}
/// Converts a streaming HTTP response body into a stream of text payloads.
///
/// The body is decoded incrementally according to `format`:
///
/// * [`StreamFormat::Sse`] — lines are grouped into events separated by a blank
///   line. The `data:` lines of each event are joined with `\n` and yielded as
///   one item. All other lines (`event:`, `id:`, `:` comments, ...) are ignored.
/// * [`StreamFormat::NdJson`] — each non-empty line is yielded as one item. An
///   optional `data:` prefix on a line is stripped.
/// * [`StreamFormat::Auto`] — the format is sniffed from the buffered bytes.
///   Any occurrence of `data:` selects SSE. Otherwise, a buffer whose first
///   non-whitespace byte is `{` selects NDJSON. That choice needs a newline to
///   have arrived, or the body to have ended.
///
/// In every format, empty payloads and the `[DONE]` sentinel are skipped.
///
/// Bytes are buffered until a whole line (`\n`) is available and only then
/// decoded. As a result, multi-byte UTF-8 characters and `\r\n` pairs that
/// straddle chunk boundaries are handled correctly.
///
/// When the body ends, a trailing line without a newline is still processed,
/// and a pending SSE event is dispatched instead of discarded. This is more
/// lenient than the SSE specification, so a final payload is never lost. A
/// body whose format could not be determined yields nothing.
///
/// # Errors
///
/// Each transport error reported by the body stream is yielded in-band as a
/// [`LlmConnectorError::NetworkError`] item.
pub fn create_text_stream(
    response: reqwest::Response,
    format: StreamFormat,
) -> Pin<Box<dyn Stream<Item = Result<String, LlmConnectorError>> + Send>> {
    /// Incremental body decoder carried across chunks by `scan`.
    struct Decoder {
        /// Raw bytes received but not yet terminated by `\n`. Kept as bytes so a
        /// UTF-8 sequence split across chunks is only decoded once complete.
        buffer: Vec<u8>,
        /// Resolved wire format; `None` while `Auto` detection is still pending.
        format: Option<StreamFormat>,
        /// `data:` payloads of the SSE event currently being assembled.
        sse_data: Vec<String>,
    }

    impl Decoder {
        fn new(format: StreamFormat) -> Self {
            Self {
                buffer: Vec::new(),
                format: match format {
                    StreamFormat::Auto => None,
                    fixed => Some(fixed),
                },
                sse_data: Vec::new(),
            }
        }

        /// Appends a body chunk and returns every payload it completed.
        fn feed(&mut self, chunk: &[u8]) -> Vec<String> {
            self.buffer.extend_from_slice(chunk);
            self.detect_format(false);
            let mut out = Vec::new();
            self.drain_complete_lines(&mut out);
            out
        }

        /// Processes what is left once the body has ended. An unterminated
        /// final line and a pending SSE event are flushed rather than dropped.
        fn finish(&mut self) -> Vec<String> {
            self.detect_format(true);
            let mut out = Vec::new();
            if self.format.is_none() {
                return out;
            }
            self.drain_complete_lines(&mut out);
            let rest = std::mem::take(&mut self.buffer);
            if !rest.is_empty() {
                self.handle_line(&Self::decode_line(&rest), &mut out);
            }
            self.flush_sse_event(&mut out);
            out
        }

        /// Sniffs the wire format from the buffered bytes while it is unknown.
        ///
        /// A newline is normally required before choosing NDJSON, so a partial
        /// first chunk is not misclassified. At end of body (`at_eof`), a lone
        /// object without a trailing newline is accepted as NDJSON too.
        fn detect_format(&mut self, at_eof: bool) {
            if self.format.is_some() {
                return;
            }
            let starts_with_brace =
                self.buffer.iter().find(|b| !b.is_ascii_whitespace()) == Some(&b'{');
            if self.buffer.windows(5).any(|w| w == b"data:") {
                self.format = Some(StreamFormat::Sse);
            } else if starts_with_brace && (at_eof || self.buffer.contains(&b'\n')) {
                self.format = Some(StreamFormat::NdJson);
            }
        }

        /// Decodes every `\n`-terminated line currently buffered. Any
        /// incomplete tail stays in the buffer for the next chunk.
        fn drain_complete_lines(&mut self, out: &mut Vec<String>) {
            if self.format.is_none() {
                return;
            }
            let Some(last_newline) = self.buffer.iter().rposition(|&b| b == b'\n') else {
                return;
            };
            let complete: Vec<u8> = self.buffer.drain(..=last_newline).collect();
            // `complete` ends with '\n'. Splitting what precedes that final
            // newline yields each line exactly once, including the empty lines
            // that separate SSE events.
            for line in complete[..last_newline].split(|&b| b == b'\n') {
                self.handle_line(&Self::decode_line(line), out);
            }
        }

        /// Interprets one decoded line according to the resolved format.
        fn handle_line(&mut self, line: &str, out: &mut Vec<String>) {
            match self.format {
                Some(StreamFormat::Sse) => {
                    if line.is_empty() {
                        // A blank line terminates the current event.
                        self.flush_sse_event(out);
                    } else if let Some(payload) = line.trim().strip_prefix("data:") {
                        let payload = payload.trim();
                        if Self::is_payload(payload) {
                            self.sse_data.push(payload.to_owned());
                        }
                    }
                }
                Some(StreamFormat::NdJson) => {
                    let trimmed = line.trim();
                    // Tolerate SSE-style `data:` prefixes on NDJSON lines.
                    let payload = trimmed.strip_prefix("data:").map_or(trimmed, str::trim);
                    if Self::is_payload(payload) {
                        out.push(payload.to_owned());
                    }
                }
                // Lines are only drained once a concrete format is known.
                Some(StreamFormat::Auto) | None => {}
            }
        }

        /// Emits the assembled SSE event, with its `data:` lines joined by `\n`.
        fn flush_sse_event(&mut self, out: &mut Vec<String>) {
            if !self.sse_data.is_empty() {
                out.push(self.sse_data.join("\n"));
                self.sse_data.clear();
            }
        }

        /// Empty payloads and the `[DONE]` sentinel are not forwarded.
        fn is_payload(payload: &str) -> bool {
            !payload.is_empty() && payload != "[DONE]"
        }

        /// Decodes one complete line, dropping a trailing `\r` from CRLF endings.
        fn decode_line(line: &[u8]) -> std::borrow::Cow<'_, str> {
            String::from_utf8_lossy(line.strip_suffix(b"\r").unwrap_or(line))
        }
    }

    let body = response
        .bytes_stream()
        .map(Some)
        // A trailing `None` marks end-of-body so the decoder can flush its buffer.
        .chain(futures_util::stream::once(std::future::ready(None)));

    let payloads = body
        .scan(Decoder::new(format), |decoder, item| {
            let out: Vec<Result<String, LlmConnectorError>> = match item {
                Some(Ok(chunk)) => decoder.feed(&chunk).into_iter().map(Ok).collect(),
                Some(Err(e)) => vec![Err(LlmConnectorError::NetworkError(e.to_string()))],
                None => decoder.finish().into_iter().map(Ok).collect(),
            };
            std::future::ready(Some(out))
        })
        .flat_map(futures_util::stream::iter);
    Box::pin(payloads)
}
/// Streams the joined `data:` payloads of a Server-Sent Events response.
///
/// Shorthand for [`create_text_stream`] with [`StreamFormat::Sse`].
#[inline]
pub fn sse_events(
    response: reqwest::Response,
) -> Pin<Box<dyn Stream<Item = Result<String, LlmConnectorError>> + Send>> {
    let format = StreamFormat::Sse;
    create_text_stream(response, format)
}
/// Streams one payload per non-empty line of a newline-delimited JSON response.
///
/// Shorthand for [`create_text_stream`] with [`StreamFormat::NdJson`].
#[inline]
pub fn json_lines_events(
    response: reqwest::Response,
) -> Pin<Box<dyn Stream<Item = Result<String, LlmConnectorError>> + Send>> {
    let format = StreamFormat::NdJson;
    create_text_stream(response, format)
}
/// Parses a single SSE line into a JSON value.
///
/// Returns `Ok(None)` for any of these lines:
/// * blank lines and `:` comments;
/// * lines that are not `data:` fields;
/// * `data:` lines whose payload is empty or the `[DONE]` sentinel.
///
/// # Errors
///
/// Returns [`LlmConnectorError::ParseError`] when a `data:` payload is not valid JSON.
pub fn parse_sse_line(line: &str) -> Result<Option<serde_json::Value>, LlmConnectorError> {
    let trimmed = line.trim();
    // Blank lines and `:` comments carry no payload.
    if trimmed.is_empty() || trimmed.starts_with(':') {
        return Ok(None);
    }
    let Some(data) = trimmed.strip_prefix("data:").map(str::trim) else {
        return Ok(None);
    };
    if data.is_empty() || data == "[DONE]" {
        return Ok(None);
    }
    serde_json::from_str::<serde_json::Value>(data)
        .map(Some)
        .map_err(|e| LlmConnectorError::ParseError(format!("Failed to parse SSE JSON: {}", e)))
}
/// Adapts a streaming chat response into a [`crate::types::ChatStream`].
///
/// The body is decoded with [`StreamFormat::Auto`]. Each text payload is then:
/// 1. deserialized as a `StreamingResponse`;
/// 2. given a filled-in top-level `content` field;
/// 3. merged with earlier chunks' tool-call deltas.
///
/// Transport errors pass through unchanged. Payloads that fail to deserialize
/// become `ParseError`s that include the offending text.
#[cfg(feature = "streaming")]
pub fn sse_to_streaming_response(response: reqwest::Response) -> crate::types::ChatStream {
    use crate::types::{StreamingResponse, ToolCall};
    use std::collections::HashMap;

    let tool_calls_by_index: HashMap<usize, ToolCall> = HashMap::new();
    let chunks = create_text_stream(response, StreamFormat::Auto).scan(
        tool_calls_by_index,
        |tool_calls, payload| {
            let chunk = match payload {
                Err(err) => Err(err),
                Ok(json_str) => match serde_json::from_str::<StreamingResponse>(&json_str) {
                    Err(e) => Err(crate::error::LlmConnectorError::ParseError(format!(
                        "Failed to parse streaming response: {}. Content: {}",
                        e, json_str
                    ))),
                    Ok(mut parsed) => {
                        populate_convenience_fields(&mut parsed);
                        accumulate_tool_calls(&mut parsed, tool_calls);
                        Ok(parsed)
                    }
                },
            };
            std::future::ready(Some(chunk))
        },
    );
    Box::pin(chunks)
}
/// Fills `response.content` from the first choice's delta when it is empty.
///
/// Candidates are tried in this order: `content`, `reasoning_content`,
/// `reasoning`, `thought`, `thinking`. The first present, non-empty candidate
/// wins. Every candidate is checked for emptiness, so an empty earlier field
/// no longer shadows a non-empty later one. The function does nothing if
/// `content` is already set, if there are no choices, or if no candidate has
/// text.
#[cfg(feature = "streaming")]
fn populate_convenience_fields(response: &mut crate::types::StreamingResponse) {
    if !response.content.is_empty() {
        return;
    }
    let Some(choice) = response.choices.first() else {
        return;
    };
    let candidate = [
        &choice.delta.content,
        &choice.delta.reasoning_content,
        &choice.delta.reasoning,
        &choice.delta.thought,
        &choice.delta.thinking,
    ]
    .into_iter()
    .filter_map(Option::as_ref)
    .find(|text| !text.is_empty());
    if let Some(text) = candidate {
        response.content = text.clone();
    }
}
/// Merges the first choice's tool-call deltas into `accumulated`.
///
/// The choice's `delta.tool_calls` is then replaced with every accumulated call
/// that reports itself complete, or `None` if there are none. Chunks without
/// tool-call deltas are left untouched.
///
/// Deltas are keyed by their `index`; a missing index is treated as `0`.
/// Completed calls are emitted in ascending index order, so the output does not
/// depend on `HashMap` iteration order, which is arbitrary.
///
/// NOTE(review): every chunk that carries tool-call deltas re-emits *all* calls
/// completed so far. A consumer that concatenates `tool_calls` across chunks may
/// therefore see duplicates; confirm this snapshot behaviour is what callers
/// expect.
#[cfg(feature = "streaming")]
fn accumulate_tool_calls(
    response: &mut crate::types::StreamingResponse,
    accumulated: &mut std::collections::HashMap<usize, crate::types::ToolCall>,
) {
    if let Some(choice) = response.choices.first_mut()
        && let Some(delta_tool_calls) = &choice.delta.tool_calls
    {
        for delta_call in delta_tool_calls {
            let index = delta_call.index.unwrap_or(0);
            accumulated
                .entry(index)
                .and_modify(|existing| existing.merge_delta(delta_call))
                .or_insert_with(|| delta_call.clone());
        }
        let mut completed: Vec<(usize, &crate::types::ToolCall)> = accumulated
            .iter()
            .filter(|(_, call)| call.is_complete())
            .map(|(&index, call)| (index, call))
            .collect();
        completed.sort_unstable_by_key(|&(index, _)| index);
        let complete_calls: Vec<crate::types::ToolCall> =
            completed.into_iter().map(|(_, call)| call.clone()).collect();
        choice.delta.tool_calls = if complete_calls.is_empty() {
            None
        } else {
            Some(complete_calls)
        };
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a two-event SSE transcript through `parse_sse_line` line by line.
    /// Exactly the two `data:` payloads should be decoded; the blank
    /// separators yield nothing.
    #[test]
    fn test_sse_detection() {
        let mock_response = "data: {\"test\":1}\n\ndata: {\"test\":2}\n\n";
        let values: Vec<serde_json::Value> = mock_response
            .lines()
            .filter_map(|line| parse_sse_line(line).ok().flatten())
            .collect();
        assert_eq!(
            values,
            vec![serde_json::json!({"test": 1}), serde_json::json!({"test": 2})]
        );
    }

    /// Blank lines, comments, non-data fields, empty data and `[DONE]` carry no payload.
    #[test]
    fn parse_sse_line_skips_non_payload_lines() {
        assert!(matches!(parse_sse_line(""), Ok(None)));
        assert!(matches!(parse_sse_line("   "), Ok(None)));
        assert!(matches!(parse_sse_line(": keep-alive"), Ok(None)));
        assert!(matches!(parse_sse_line("event: message"), Ok(None)));
        assert!(matches!(parse_sse_line("data:"), Ok(None)));
        assert!(matches!(parse_sse_line("data: [DONE]"), Ok(None)));
    }

    /// A `data:` payload that is not valid JSON is reported as an error.
    #[test]
    fn parse_sse_line_reports_invalid_json() {
        assert!(parse_sse_line("data: {not json").is_err());
    }
}