use bytes::Bytes;
use futures::{Stream, StreamExt};
use reqwest::Client;
use serde::Deserialize;
use std::collections::BTreeMap;
use std::pin::Pin;
use tt_shared::{
filter_extra_headers,
messages::{ChunkChoice, ChunkDelta, ToolCall, ToolCallFunction},
ChatCompletionChunk, ChatCompletionRequest, ProviderError, RequestContext, Usage,
};
use crate::errors::{map_reqwest_error, map_response_error};
use crate::translate;
/// OpenAI `stream_options` object attached to a streaming request body.
#[derive(Debug, serde::Serialize)]
struct StreamOptions {
    /// Maps to OpenAI's `stream_options.include_usage`; when `true` the server is
    /// expected to report token `usage` in the stream. `stream_chat_completion` sets it
    /// to `true`.
    include_usage: bool,
}
/// JSON body for a streaming request: the translated OpenAI request plus a top-level
/// `stream_options` object.
#[derive(Debug, serde::Serialize)]
struct OpenAiStreamBody {
    /// The translated request. `#[serde(flatten)]` emits its fields at the top level of
    /// this object instead of nesting them under an `inner` key.
    #[serde(flatten)]
    inner: translate::OpenAiRequestBody,
    stream_options: StreamOptions,
}
/// Pinned, boxed, `Send` stream of canonical chunks (or provider errors) returned by
/// `stream_chat_completion`.
pub type ChunkStream =
    Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk, ProviderError>> + Send>>;
/// Opens a streaming chat completion against an OpenAI-compatible `/chat/completions`
/// endpoint.
///
/// The request is translated with `translate::translate_request` and forced to
/// `stream: true`. It is sent with `stream_options.include_usage = true`, a bearer
/// `Authorization` header built from `ctx.credentials.api_key`, and whichever extra
/// headers survive `filter_extra_headers`. The response body is decoded by
/// `build_sse_stream`.
///
/// # Errors
///
/// - Any error returned by `translate::translate_request`.
/// - `ProviderError::Internal` if the request body cannot be serialized.
/// - The `map_reqwest_error` mapping of a failure to send the request.
/// - The `map_response_error` mapping of an HTTP status `>= 400`, including the
///   `Retry-After` header value when present.
///
/// Errors that occur after the response headers have arrived are yielded as items of
/// the returned [`ChunkStream`] instead.
pub async fn stream_chat_completion(
    client: Client,
    base_url: &str,
    req: ChatCompletionRequest,
    ctx: &RequestContext,
) -> Result<ChunkStream, ProviderError> {
    // Trim trailing slashes so a base URL configured as `.../v1/` does not yield a
    // `//chat/completions` path, which many servers reject.
    let url = format!("{}/chat/completions", base_url.trim_end_matches('/'));
    let api_key = ctx.credentials.api_key.expose().to_string();
    let extra_headers: Vec<(String, String)> = filter_extra_headers(&ctx.credentials.extra_headers);

    let mut translated = translate::translate_request(req)?;
    translated.stream = true;
    let body = OpenAiStreamBody {
        inner: translated,
        stream_options: StreamOptions {
            include_usage: true,
        },
    };
    let body_bytes = serde_json::to_vec(&body)
        .map_err(|e| ProviderError::Internal(format!("failed to serialize stream body: {e}")))?;

    let mut request_builder = client
        .post(&url)
        .header("Authorization", format!("Bearer {api_key}"))
        .header("Content-Type", "application/json")
        .body(body_bytes);
    for (name, value) in &extra_headers {
        request_builder = request_builder.header(name, value);
    }

    let response = request_builder.send().await.map_err(map_reqwest_error)?;
    let status = response.status().as_u16();
    let retry_after = response
        .headers()
        .get("retry-after")
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_string());
    if status >= 400 {
        // The status and Retry-After are the important signals here. If reading the error
        // body fails, report the status with an empty body rather than replacing it with
        // a transport error.
        let body_text = response.text().await.unwrap_or_default();
        return Err(map_response_error(
            status,
            &body_text,
            retry_after.as_deref(),
        ));
    }

    Ok(Box::pin(build_sse_stream(response.bytes_stream())))
}
/// Turns a raw HTTP body byte stream into a stream of canonical chat-completion chunks.
///
/// Bytes are buffered until a complete SSE event (terminated by a blank line, located
/// with `find_event_boundary`) is available. Each event is parsed with `parse_sse_event`,
/// and its chunks are routed through a `ToolAccum` by `handle_raw_chunk`. As a result,
/// tool-call fragments are emitted as reassembled tool calls, not as raw deltas.
///
/// How the stream ends:
/// - `data: [DONE]`: buffered tool calls are flushed with finish reason `"tool_calls"`
///   and the stream ends; anything after `[DONE]` is ignored.
/// - Transport error: the mapped error is yielded and the stream ends. Buffered tool
///   calls are not flushed.
/// - End of body: trailing bytes without a terminator are still parsed, buffered tool
///   calls are flushed, and the stream ends.
///
/// Per-event parse errors are yielded as `Err` items and do not end the stream.
fn build_sse_stream<S>(
    bytes_stream: S,
) -> impl Stream<Item = Result<ChatCompletionChunk, ProviderError>> + Send
where
    S: Stream<Item = Result<Bytes, reqwest::Error>> + Send + 'static,
{
    async_stream::stream! {
        // Bytes received but not yet consumed as a complete event.
        let mut buffer: Vec<u8> = Vec::new();
        let mut acc = ToolAccum::default();
        futures::pin_mut!(bytes_stream);
        loop {
            let next_item: Option<Result<Bytes, reqwest::Error>> = bytes_stream.next().await;
            match next_item {
                Some(Ok(chunk)) => {
                    buffer.extend_from_slice(&chunk);
                    // One network chunk may complete zero, one, or several events.
                    while let Some((event_end, sep_len)) = find_event_boundary(&buffer) {
                        // Remove the event and its terminator from the buffer. The terminator
                        // bytes go to the parser too, which skips blank lines.
                        let event_bytes = buffer.drain(..event_end + sep_len).collect::<Vec<_>>();
                        let mut done = false;
                        for event in parse_sse_event(&event_bytes) {
                            match event {
                                SseEvent::Done => {
                                    if let Some(c) = acc.drain(Some("tool_calls".to_string()), None) {
                                        yield Ok(c);
                                    }
                                    done = true;
                                    break;
                                }
                                SseEvent::Chunk(raw) => {
                                    for c in handle_raw_chunk(&mut acc, raw) {
                                        yield Ok(c);
                                    }
                                }
                                SseEvent::Err(e) => {
                                    yield Err(e);
                                }
                                SseEvent::Skip => {}
                            }
                        }
                        if done {
                            return;
                        }
                    }
                }
                Some(Err(e)) => {
                    yield Err(map_reqwest_error(e));
                    return;
                }
                None => {
                    // Body ended: parse whatever is left, even without a trailing blank line.
                    if !buffer.is_empty() {
                        for event in parse_sse_event(&buffer) {
                            match event {
                                SseEvent::Chunk(raw) => {
                                    for c in handle_raw_chunk(&mut acc, raw) {
                                        yield Ok(c);
                                    }
                                }
                                SseEvent::Err(e) => yield Err(e),
                                // `[DONE]` needs no special handling here; the flush below
                                // runs either way.
                                SseEvent::Done | SseEvent::Skip => {}
                            }
                        }
                    }
                    if let Some(c) = acc.drain(Some("tool_calls".to_string()), None) {
                        yield Ok(c);
                    }
                    return;
                }
            }
        }
    }
}
/// One chunk object as decoded from an SSE `data:` line.
///
/// `id`, `object`, `created`, and `model` must be present. A missing `choices` array
/// deserializes as empty and a missing `usage` as `None`, so usage-only chunks with no
/// choices are accepted.
#[derive(Debug, Deserialize)]
struct RawChunk {
    id: String,
    object: String,
    created: i64,
    model: String,
    #[serde(default)]
    choices: Vec<RawChoice>,
    #[serde(default)]
    usage: Option<Usage>,
}
/// One entry of a chunk's `choices` array. Missing fields fall back to their defaults:
/// `index` 0, an empty delta, and no finish reason.
#[derive(Debug, Deserialize)]
struct RawChoice {
    #[serde(default)]
    index: u32,
    #[serde(default)]
    delta: RawDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}
#[derive(Debug, Default, Deserialize)]
struct RawDelta {
#[serde(default)]
role: Option<String>,
#[serde(default)]
content: Option<String>,
#[serde(default)]
tool_calls: Vec<RawToolCallDelta>,
}
/// One streamed tool-call fragment.
///
/// `ToolAccum::merge` merges fragments that share the same choice index and `index`
/// into a single tool call. Every field may be absent.
#[derive(Debug, Default, Deserialize)]
struct RawToolCallDelta {
    #[serde(default)]
    index: u32,
    #[serde(default)]
    id: Option<String>,
    // JSON key `type`.
    #[serde(default, rename = "type")]
    r#type: Option<String>,
    #[serde(default)]
    function: Option<RawFnDelta>,
}
/// Function part of a tool-call fragment.
///
/// `ToolAccum::merge` keeps the last non-empty `name` and concatenates the `arguments`
/// pieces in arrival order.
#[derive(Debug, Default, Deserialize)]
struct RawFnDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
impl RawChunk {
fn into_canonical(self) -> ChatCompletionChunk {
ChatCompletionChunk {
id: self.id,
object: self.object,
created: self.created,
model: self.model,
choices: self
.choices
.into_iter()
.map(|c| ChunkChoice {
index: c.index,
delta: ChunkDelta {
role: c.delta.role,
content: c.delta.content,
tool_calls: Vec::new(),
},
finish_reason: c.finish_reason,
})
.collect(),
usage: self.usage,
}
}
}
/// A tool call being reassembled from streamed fragments.
///
/// All fields start empty (`Default`) and are filled in by `ToolAccum::merge`.
#[derive(Default)]
struct PartialToolCall {
    id: String,
    r#type: String,
    name: String,
    arguments: String,
}
impl PartialToolCall {
    /// Finalizes the buffered fragments into a canonical `ToolCall`. A `type` that was
    /// never set defaults to `"function"`.
    fn into_tool_call(self) -> ToolCall {
        let Self {
            id,
            r#type: kind,
            name,
            arguments,
        } = self;
        let kind = if kind.is_empty() {
            String::from("function")
        } else {
            kind
        };
        ToolCall {
            id,
            r#type: kind,
            function: ToolCallFunction { name, arguments },
        }
    }
}
/// Envelope fields copied from a chunk that carried tool-call fragments. They are used to
/// stamp the chunk that `ToolAccum::drain` emits.
#[derive(Clone)]
struct ChunkMeta {
    id: String,
    object: String,
    created: i64,
    model: String,
}
/// Buffers streamed tool-call fragments until a finish reason (or the end of the stream)
/// flushes them as complete tool calls.
#[derive(Default)]
struct ToolAccum {
    // Partial calls keyed by (choice index, tool-call index). The BTreeMap ordering makes
    // the flush emit choices in ascending order, with calls in tool-index order.
    calls: BTreeMap<(u32, u32), PartialToolCall>,
    // Envelope of the most recent chunk that contributed a fragment. Set by `merge`,
    // taken by `drain`.
    meta: Option<ChunkMeta>,
}
impl ToolAccum {
    /// Reports whether no tool-call fragments are waiting to be flushed.
    fn is_empty(&self) -> bool {
        self.calls.is_empty()
    }

    /// Folds the tool-call fragments of `raw` into the per-`(choice, tool)` buffers.
    ///
    /// Non-empty `id`, `type`, and function `name` values replace what is stored. Every
    /// `arguments` piece is appended in arrival order. If the chunk contributed at least
    /// one fragment, its envelope (`id`, `object`, `created`, `model`) is remembered for
    /// the eventual flush.
    fn merge(&mut self, raw: &RawChunk) {
        let mut contributed = false;
        for choice in raw.choices.iter() {
            for fragment in choice.delta.tool_calls.iter() {
                contributed = true;
                let slot = self
                    .calls
                    .entry((choice.index, fragment.index))
                    .or_default();
                if let Some(id) = fragment.id.as_ref().filter(|v| !v.is_empty()) {
                    slot.id = id.clone();
                }
                if let Some(kind) = fragment.r#type.as_ref().filter(|v| !v.is_empty()) {
                    slot.r#type = kind.clone();
                }
                if let Some(function) = fragment.function.as_ref() {
                    if let Some(name) = function.name.as_ref().filter(|v| !v.is_empty()) {
                        slot.name = name.clone();
                    }
                    if let Some(piece) = function.arguments.as_deref() {
                        slot.arguments.push_str(piece);
                    }
                }
            }
        }
        if contributed {
            self.meta = Some(ChunkMeta {
                id: raw.id.clone(),
                object: raw.object.clone(),
                created: raw.created,
                model: raw.model.clone(),
            });
        }
    }

    /// Flushes every buffered tool call as one canonical chunk and clears the buffer.
    ///
    /// Calls are grouped into one `ChunkChoice` per choice index, in ascending order.
    /// Each choice holds its calls in tool-index order and carries a clone of
    /// `finish_reason`. The chunk reuses the remembered envelope and carries `usage`.
    /// Returns `None` without touching any state when nothing is buffered.
    fn drain(
        &mut self,
        finish_reason: Option<String>,
        usage: Option<Usage>,
    ) -> Option<ChatCompletionChunk> {
        if self.is_empty() {
            return None;
        }
        let ChunkMeta {
            id,
            object,
            created,
            model,
        } = self.meta.take()?;
        let pending = std::mem::take(&mut self.calls);
        let mut per_choice: BTreeMap<u32, Vec<ToolCall>> = BTreeMap::new();
        for ((choice_index, _), partial) in pending {
            per_choice
                .entry(choice_index)
                .or_default()
                .push(partial.into_tool_call());
        }
        let choices = per_choice
            .into_iter()
            .map(|(index, tool_calls)| {
                let delta = ChunkDelta {
                    role: None,
                    content: None,
                    tool_calls,
                };
                ChunkChoice {
                    index,
                    delta,
                    finish_reason: finish_reason.clone(),
                }
            })
            .collect();
        Some(ChatCompletionChunk {
            id,
            object,
            created,
            model,
            choices,
            usage,
        })
    }
}
/// Extracts the role/content part of `raw` as a canonical chunk.
///
/// Returns `None` unless at least one choice carries a role or content. Otherwise every
/// choice is copied with its role and content, an empty `tool_calls`, no finish reason,
/// and no usage. Finish reasons and usage are left for the tool-call flush chunk.
fn content_chunk(raw: &RawChunk) -> Option<ChatCompletionChunk> {
    let carries_text_or_role = raw.choices.iter().any(|choice| {
        let delta = &choice.delta;
        delta.role.is_some() || delta.content.is_some()
    });
    carries_text_or_role.then(|| ChatCompletionChunk {
        id: raw.id.clone(),
        object: raw.object.clone(),
        created: raw.created,
        model: raw.model.clone(),
        choices: raw
            .choices
            .iter()
            .map(|choice| {
                let delta = ChunkDelta {
                    role: choice.delta.role.clone(),
                    content: choice.delta.content.clone(),
                    tool_calls: Vec::new(),
                };
                ChunkChoice {
                    index: choice.index,
                    delta,
                    finish_reason: None,
                }
            })
            .collect(),
        usage: None,
    })
}
/// Routes one decoded upstream chunk through the tool-call accumulator.
///
/// Some chunks carry no tool-call fragments and do not trigger a flush: they have no
/// finish reason, or nothing is buffered. Those are forwarded as-is.
///
/// Every other chunk is handled in three steps:
/// 1. Any role/content it carries is forwarded first.
/// 2. Its fragments are merged into `acc`.
/// 3. If it carries a `finish_reason`, all buffered tool calls are flushed as one chunk
///    with that finish reason and this chunk's `usage`.
///
/// Returns the resulting canonical chunks in emission order. There may be zero, one,
/// or two of them.
fn handle_raw_chunk(acc: &mut ToolAccum, raw: RawChunk) -> Vec<ChatCompletionChunk> {
    let has_tool_frag = raw.choices.iter().any(|c| !c.delta.tool_calls.is_empty());
    // The first finish_reason found is applied to every flushed choice.
    let finish_reason = raw.choices.iter().find_map(|c| c.finish_reason.clone());
    let flushes_pending = finish_reason.is_some() && !acc.is_empty();
    if !has_tool_frag && !flushes_pending {
        return vec![raw.into_canonical()];
    }
    // Forward role/content even when this chunk only finishes the buffered tool calls.
    // Otherwise text riding on the final chunk would be lost.
    let mut out: Vec<ChatCompletionChunk> = content_chunk(&raw).into_iter().collect();
    acc.merge(&raw);
    if finish_reason.is_some() {
        out.extend(acc.drain(finish_reason, raw.usage));
    }
    out
}
/// Outcome of parsing one SSE event; `parse_sse_event` produces one per `data:` line.
#[derive(Debug)]
enum SseEvent {
    /// A successfully decoded chunk.
    Chunk(RawChunk),
    /// The `[DONE]` sentinel.
    Done,
    /// The event was not valid UTF-8, or a `data:` payload failed to decode.
    Err(ProviderError),
    /// The event contained no `data:` line (for example only comments or blank lines).
    Skip,
}
/// Parses the bytes of one SSE event into zero or more [`SseEvent`]s.
///
/// Each `data:` line (with at most one space after the colon stripped) becomes its own
/// event. `[DONE]` yields `SseEvent::Done` and stops parsing. Other payloads are decoded
/// as a `RawChunk`. Blank lines, `:` comments, and non-`data` fields are ignored. If
/// nothing was produced, a single `SseEvent::Skip` is returned. Invalid UTF-8 yields a
/// single `SseEvent::Err`.
fn parse_sse_event(event_bytes: &[u8]) -> Vec<SseEvent> {
    let text = match std::str::from_utf8(event_bytes) {
        Ok(decoded) => decoded,
        Err(_) => {
            return vec![SseEvent::Err(ProviderError::Deserialize(
                "SSE event contained invalid UTF-8".to_string(),
            ))]
        }
    };
    let payloads = text
        .lines()
        .map(|line| line.trim_end_matches('\r'))
        .filter(|line| !line.is_empty() && !line.starts_with(':'))
        .filter_map(|line| line.strip_prefix("data:"))
        .map(|rest| rest.strip_prefix(' ').unwrap_or(rest));
    let mut events = Vec::new();
    for payload in payloads {
        if payload == "[DONE]" {
            events.push(SseEvent::Done);
            break;
        }
        let event = match serde_json::from_str::<RawChunk>(payload) {
            Ok(chunk) => SseEvent::Chunk(chunk),
            Err(e) => SseEvent::Err(ProviderError::Deserialize(format!(
                "failed to parse SSE chunk: {e}"
            ))),
        };
        events.push(event);
    }
    if events.is_empty() {
        events.push(SseEvent::Skip);
    }
    events
}
/// Finds the first SSE event terminator (a blank line) in `buf`.
///
/// A terminator is either `\r\n\r\n` (CRLF framing) or `\n\n` (LF framing). The LF form
/// also matches a CRLF line followed by a bare LF. Returns `(start, len)` of the earliest
/// terminator: `buf[..start]` is the event text and `start + len` bytes should be
/// consumed. Returns `None` when no complete event is buffered.
///
/// The earliest terminator of either kind wins. Preferring one kind regardless of
/// position would merge several events into one when a stream mixes line endings.
fn find_event_boundary(buf: &[u8]) -> Option<(usize, usize)> {
    (0..buf.len()).find_map(|pos| {
        let rest = &buf[pos..];
        if rest.starts_with(b"\r\n\r\n") {
            Some((pos, 4))
        } else if rest.starts_with(b"\n\n") {
            Some((pos, 2))
        } else {
            None
        }
    })
}
/// Shape of an error object (`{"error": {"message": ...}}`) that an OpenAI-compatible
/// server may send in place of a chunk.
///
/// NOTE(review): nothing constructs this yet, hence `allow(dead_code)`. `parse_sse_event`
/// only tries to decode `RawChunk`, so such payloads currently surface as generic
/// "failed to parse SSE chunk" errors.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAiSseError {
    error: OpenAiSseErrorInner,
}
/// Inner `error` object of [`OpenAiSseError`]; only `message` is read.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct OpenAiSseErrorInner {
    message: String,
}
// Unit tests for SSE framing (`find_event_boundary`), line parsing (`parse_sse_event`),
// and tool-call reassembly (`handle_raw_chunk` / `ToolAccum`).
#[cfg(test)]
mod tests {
    use super::*;
    // --- find_event_boundary ---
    #[test]
    fn find_event_boundary_lf() {
        let buf = b"data: hello\n\ndata: world\n\n";
        assert_eq!(find_event_boundary(buf), Some((11, 2)));
    }
    #[test]
    fn find_event_boundary_crlf() {
        let buf = b"data: hello\r\n\r\ndata: world\r\n\r\n";
        assert_eq!(find_event_boundary(buf), Some((11, 4)));
    }
    #[test]
    fn find_event_boundary_none() {
        let buf = b"data: hello\n";
        assert_eq!(find_event_boundary(buf), None);
    }
    // --- parse_sse_event ---
    #[test]
    fn parse_sse_event_data_line() {
        let chunk_json = r#"{"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let event = format!("data: {chunk_json}\n\n");
        let results = parse_sse_event(event.as_bytes());
        assert_eq!(results.len(), 1);
        assert!(matches!(&results[0], SseEvent::Chunk(c) if c.id == "c1"));
    }
    #[test]
    fn parse_sse_event_done() {
        let results = parse_sse_event(b"data: [DONE]\n\n");
        assert_eq!(results.len(), 1);
        assert!(matches!(results[0], SseEvent::Done));
    }
    #[test]
    fn parse_sse_event_tool_call_fragment_no_id() {
        let data = r#"{"id":"c","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"x\":"}}]},"finish_reason":null}]}"#;
        let event = format!("data: {data}\n\n");
        let results = parse_sse_event(event.as_bytes());
        assert_eq!(results.len(), 1);
        assert!(matches!(&results[0], SseEvent::Chunk(c) if c.id == "c"));
    }
    // Builds a tool-call fragment. `type` is set to "function" only when an `id` is given,
    // mirroring a first fragment; `arguments` is always present.
    fn frag(index: u32, id: Option<&str>, name: Option<&str>, args: &str) -> RawToolCallDelta {
        RawToolCallDelta {
            index,
            id: id.map(String::from),
            r#type: id.map(|_| "function".to_string()),
            function: Some(RawFnDelta {
                name: name.map(String::from),
                arguments: Some(args.to_string()),
            }),
        }
    }
    // Builds a single-choice (index 0) chunk with the given fragments and finish reason,
    // and no role, content, or usage.
    fn raw_chunk(tool_calls: Vec<RawToolCallDelta>, finish_reason: Option<&str>) -> RawChunk {
        RawChunk {
            id: "c".into(),
            object: "chat.completion.chunk".into(),
            created: 1,
            model: "gpt-4o".into(),
            choices: vec![RawChoice {
                index: 0,
                delta: RawDelta {
                    role: None,
                    content: None,
                    tool_calls,
                },
                finish_reason: finish_reason.map(String::from),
            }],
            usage: None,
        }
    }
    // --- handle_raw_chunk / ToolAccum ---
    // Fragments are buffered silently and emitted as one call when the finish reason arrives.
    #[test]
    fn handle_reassembles_single_tool_call() {
        let mut acc = ToolAccum::default();
        assert!(handle_raw_chunk(
            &mut acc,
            raw_chunk(vec![frag(0, Some("call_1"), Some("f"), "")], None)
        )
        .is_empty());
        assert!(handle_raw_chunk(
            &mut acc,
            raw_chunk(vec![frag(0, None, None, "{\"a\":")], None)
        )
        .is_empty());
        let out = handle_raw_chunk(
            &mut acc,
            raw_chunk(vec![frag(0, None, None, "1}")], Some("tool_calls")),
        );
        assert_eq!(out.len(), 1);
        let tc = &out[0].choices[0].delta.tool_calls;
        assert_eq!(tc.len(), 1);
        assert_eq!(tc[0].id, "call_1");
        assert_eq!(tc[0].r#type, "function");
        assert_eq!(tc[0].function.name, "f");
        assert_eq!(tc[0].function.arguments, "{\"a\":1}");
        assert_eq!(
            out[0].choices[0].finish_reason.as_deref(),
            Some("tool_calls")
        );
        assert!(acc.is_empty());
    }
    // Distinct tool indices become separate calls, emitted in index order.
    #[test]
    fn handle_reassembles_two_tool_calls_by_index() {
        let mut acc = ToolAccum::default();
        handle_raw_chunk(
            &mut acc,
            raw_chunk(vec![frag(0, Some("a"), Some("fa"), "{}")], None),
        );
        handle_raw_chunk(
            &mut acc,
            raw_chunk(vec![frag(1, Some("b"), Some("fb"), "{}")], None),
        );
        let out = handle_raw_chunk(&mut acc, raw_chunk(vec![], Some("tool_calls")));
        assert_eq!(out.len(), 1);
        let tc = &out[0].choices[0].delta.tool_calls;
        assert_eq!(tc.len(), 2);
        assert_eq!(tc[0].id, "a"); assert_eq!(tc[1].id, "b");
    }
    #[test]
    fn handle_forwards_content_chunk() {
        let mut acc = ToolAccum::default();
        let raw = RawChunk {
            id: "c".into(),
            object: "chat.completion.chunk".into(),
            created: 1,
            model: "gpt-4o".into(),
            choices: vec![RawChoice {
                index: 0,
                delta: RawDelta {
                    role: None,
                    content: Some("Hi".into()),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: None,
        };
        let out = handle_raw_chunk(&mut acc, raw);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].choices[0].delta.content.as_deref(), Some("Hi"));
        assert!(out[0].choices[0].delta.tool_calls.is_empty());
    }
    // A role sent alongside the first fragment is forwarded immediately; the call is
    // emitted later on the finish chunk.
    #[test]
    fn handle_preserves_role_riding_with_tool_fragment() {
        let mut acc = ToolAccum::default();
        let first = RawChunk {
            id: "c".into(),
            object: "chat.completion.chunk".into(),
            created: 1,
            model: "gpt-4o".into(),
            choices: vec![RawChoice {
                index: 0,
                delta: RawDelta {
                    role: Some("assistant".into()),
                    content: None,
                    tool_calls: vec![frag(0, Some("call_1"), Some("f"), "{}")],
                },
                finish_reason: None,
            }],
            usage: None,
        };
        let out = handle_raw_chunk(&mut acc, first);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].choices[0].delta.role.as_deref(), Some("assistant"));
        assert!(out[0].choices[0].delta.tool_calls.is_empty());
        let done = handle_raw_chunk(&mut acc, raw_chunk(vec![], Some("tool_calls")));
        assert_eq!(done.len(), 1);
        assert_eq!(done[0].choices[0].delta.tool_calls.len(), 1);
        assert_eq!(done[0].choices[0].delta.tool_calls[0].id, "call_1");
    }
    // The same tool index under different choices must not be merged together.
    #[test]
    fn handle_keys_tool_calls_by_choice_index() {
        let mut acc = ToolAccum::default();
        let raw = RawChunk {
            id: "c".into(),
            object: "chat.completion.chunk".into(),
            created: 1,
            model: "gpt-4o".into(),
            choices: vec![
                RawChoice {
                    index: 0,
                    delta: RawDelta {
                        role: None,
                        content: None,
                        tool_calls: vec![frag(0, Some("call_0"), Some("f0"), "{}")],
                    },
                    finish_reason: Some("tool_calls".into()),
                },
                RawChoice {
                    index: 1,
                    delta: RawDelta {
                        role: None,
                        content: None,
                        tool_calls: vec![frag(0, Some("call_1"), Some("f1"), "{}")],
                    },
                    finish_reason: Some("tool_calls".into()),
                },
            ],
            usage: None,
        };
        let out = handle_raw_chunk(&mut acc, raw);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].choices.len(), 2, "one ChunkChoice per choice index");
        assert_eq!(out[0].choices[0].index, 0);
        assert_eq!(out[0].choices[0].delta.tool_calls[0].id, "call_0");
        assert_eq!(out[0].choices[1].index, 1);
        assert_eq!(out[0].choices[1].delta.tool_calls[0].id, "call_1");
    }
    // --- parse_sse_event edge cases ---
    #[test]
    fn parse_sse_event_comment_skipped() {
        let results = parse_sse_event(b":keep-alive\n\n");
        assert_eq!(results.len(), 1);
        assert!(matches!(results[0], SseEvent::Skip));
    }
    #[test]
    fn parse_sse_event_malformed_json() {
        let results = parse_sse_event(b"data: {not valid json}\n\n");
        assert_eq!(results.len(), 1);
        assert!(matches!(results[0], SseEvent::Err(_)));
    }
    #[test]
    fn parse_sse_event_crlf_data_line() {
        let chunk_json = r#"{"id":"c2","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let event = format!("data: {chunk_json}\r\n\r\n");
        let results = parse_sse_event(event.as_bytes());
        assert_eq!(results.len(), 1, "CRLF event should parse to one chunk");
        assert!(
            matches!(&results[0], SseEvent::Chunk(c) if c.id == "c2"),
            "CRLF-delimited event should parse identical to LF form"
        );
    }
    #[test]
    fn parse_sse_event_no_space_data_prefix() {
        let chunk_json = r#"{"id":"c3","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let event = format!("data:{chunk_json}\n\n");
        let results = parse_sse_event(event.as_bytes());
        assert_eq!(
            results.len(),
            1,
            "no-space data: event should parse to one chunk"
        );
        assert!(
            matches!(&results[0], SseEvent::Chunk(c) if c.id == "c3"),
            "data:{{...}} (no space) should parse the same as `data: {{...}}`"
        );
    }
    // The CRLF terminator here is also the earliest one in the buffer.
    #[test]
    fn find_event_boundary_prefers_crlf_over_lf_in_same_buf() {
        let buf = b"data: a\r\n\r\ndata: b\n\n";
        let (pos, sep) = find_event_boundary(buf).expect("boundary found");
        assert_eq!(sep, 4, "should detect CRLF boundary");
        assert_eq!(pos, 7);
    }
}