use std::pin::Pin;
use std::task::{Context, Poll};
use futures_core::Stream;
use futures_util::StreamExt;
use crate::error::OpenAIError;
use crate::streaming::SseStream;
use crate::types::chat::{
ChatCompletionChunk, ChatCompletionMessage, ChatCompletionResponse, ChunkChoice,
};
use crate::types::common::{FinishReason, Role, Usage};
/// Events yielded while a streamed chat completion is being consumed.
///
/// Every received chunk is surfaced as [`ChatStreamEvent::Chunk`]; the other
/// variants are derived from the chunk's choices and carry the text
/// accumulated so far alongside the newest fragment.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ChatStreamEvent {
    /// The chunk exactly as it was received (boxed).
    Chunk(Box<ChatCompletionChunk>),
    /// A non-empty piece of message content arrived.
    ContentDelta {
        /// The newly received content fragment.
        delta: String,
        /// All content received so far, including `delta`.
        snapshot: String,
    },
    /// A finish reason arrived and some content had been accumulated.
    ContentDone {
        /// The complete accumulated content.
        content: String,
    },
    /// A non-empty piece of refusal text arrived.
    RefusalDelta {
        /// The newly received refusal fragment.
        delta: String,
        /// All refusal text received so far, including `delta`.
        snapshot: String,
    },
    /// A finish reason arrived and some refusal text had been accumulated.
    RefusalDone {
        /// The complete accumulated refusal text.
        refusal: String,
    },
    /// A fragment of a tool call's arguments arrived.
    ToolCallDelta {
        /// Index of the tool call as reported by the delta.
        index: i32,
        /// Function name known so far for this tool call (empty until a
        /// delta has supplied one).
        name: String,
        /// The newly received arguments fragment.
        arguments_delta: String,
        /// All arguments text received so far for this tool call.
        arguments_snapshot: String,
    },
    /// A finish reason arrived; emitted once per accumulated tool call.
    ToolCallDone {
        /// Position of the tool call in the accumulated list.
        index: i32,
        /// Identifier recorded for the tool call (empty if none was received).
        call_id: String,
        /// Function name recorded for the tool call.
        name: String,
        /// The complete accumulated arguments text.
        arguments: String,
    },
    /// A choice reported its finish reason.
    Done {
        /// The finish reason carried by the choice.
        finish_reason: Option<FinishReason>,
    },
}
/// Accumulated pieces of a single streamed tool call.
///
/// All fields start out empty (via `Default`) and are filled in as deltas
/// arrive.
#[derive(Debug, Clone, Default)]
struct ToolCallState {
    /// Identifier of the tool call; replaced whenever a delta carries one.
    id: String,
    /// Function name; replaced whenever a delta carries one.
    name: String,
    /// Concatenation of every arguments fragment received so far.
    arguments: String,
}
/// Running aggregate of everything seen on a chat completion stream.
///
/// Chunks are folded in one at a time; the aggregate can finally be turned
/// into a non-streaming style response.
#[derive(Debug)]
struct StreamState {
    /// Completion id, copied from a chunk while this field is still empty.
    id: String,
    /// Model name, copied together with `id`.
    model: String,
    /// Creation timestamp, copied together with `id`.
    created: i64,
    /// Concatenation of all content fragments received so far.
    content: String,
    /// Concatenation of all refusal fragments received so far.
    refusal: String,
    /// Per-tool-call accumulators, positioned by the index the deltas report.
    tool_calls: Vec<ToolCallState>,
    /// Most recent finish reason reported by a choice, if any.
    finish_reason: Option<FinishReason>,
    /// Most recent usage block carried by a chunk, if any.
    usage: Option<Usage>,
    /// Most recent system fingerprint carried by a chunk, if any.
    system_fingerprint: Option<String>,
    /// Most recent service tier carried by a chunk, if any.
    service_tier: Option<crate::types::common::ServiceTier>,
    /// Whether `ContentDone` has already been emitted.
    content_done: bool,
    /// Whether `RefusalDone` has already been emitted.
    refusal_done: bool,
    /// Parallel to `tool_calls`: whether `ToolCallDone` was emitted for the
    /// entry at the same position.
    tools_done: Vec<bool>,
}
impl StreamState {
fn new() -> Self {
Self {
id: String::new(),
model: String::new(),
created: 0,
content: String::new(),
refusal: String::new(),
tool_calls: Vec::new(),
finish_reason: None,
usage: None,
system_fingerprint: None,
service_tier: None,
content_done: false,
refusal_done: false,
tools_done: Vec::new(),
}
}
fn handle_chunk(&mut self, chunk: &ChatCompletionChunk) -> Vec<ChatStreamEvent> {
let mut events = Vec::new();
if self.id.is_empty() {
self.id.clone_from(&chunk.id);
self.model.clone_from(&chunk.model);
self.created = chunk.created;
}
if chunk.system_fingerprint.is_some() {
self.system_fingerprint
.clone_from(&chunk.system_fingerprint);
}
if chunk.service_tier.is_some() {
self.service_tier = chunk.service_tier.clone();
}
if chunk.usage.is_some() {
self.usage.clone_from(&chunk.usage);
}
for choice in &chunk.choices {
self.handle_choice(choice, &mut events);
}
events
}
fn handle_choice(&mut self, choice: &ChunkChoice, events: &mut Vec<ChatStreamEvent>) {
let delta = &choice.delta;
if let Some(ref text) = delta.content
&& !text.is_empty()
{
self.content.push_str(text);
events.push(ChatStreamEvent::ContentDelta {
delta: text.clone(),
snapshot: self.content.clone(),
});
}
if let Some(ref refusal) = delta.refusal
&& !refusal.is_empty()
{
self.refusal.push_str(refusal);
events.push(ChatStreamEvent::RefusalDelta {
delta: refusal.clone(),
snapshot: self.refusal.clone(),
});
}
if let Some(ref tool_calls) = delta.tool_calls {
for tc in tool_calls {
let idx = tc.index as usize;
while self.tool_calls.len() <= idx {
self.tool_calls.push(ToolCallState::default());
self.tools_done.push(false);
}
let state = &mut self.tool_calls[idx];
if let Some(ref id) = tc.id {
state.id = id.clone();
}
if let Some(ref func) = tc.function {
if let Some(ref name) = func.name {
state.name = name.clone();
}
if let Some(ref args) = func.arguments {
state.arguments.push_str(args);
events.push(ChatStreamEvent::ToolCallDelta {
index: tc.index,
name: state.name.clone(),
arguments_delta: args.clone(),
arguments_snapshot: state.arguments.clone(),
});
}
}
}
}
if let Some(ref fr) = choice.finish_reason {
self.finish_reason = Some(fr.clone());
if !self.content.is_empty() && !self.content_done {
self.content_done = true;
events.push(ChatStreamEvent::ContentDone {
content: self.content.clone(),
});
}
if !self.refusal.is_empty() && !self.refusal_done {
self.refusal_done = true;
events.push(ChatStreamEvent::RefusalDone {
refusal: self.refusal.clone(),
});
}
for (i, tc) in self.tool_calls.iter().enumerate() {
if !self.tools_done[i] {
self.tools_done[i] = true;
events.push(ChatStreamEvent::ToolCallDone {
index: i as i32,
call_id: tc.id.clone(),
name: tc.name.clone(),
arguments: tc.arguments.clone(),
});
}
}
events.push(ChatStreamEvent::Done {
finish_reason: Some(fr.clone()),
});
}
}
fn into_completion(self) -> ChatCompletionResponse {
let tool_calls = if self.tool_calls.is_empty() {
None
} else {
Some(
self.tool_calls
.into_iter()
.map(|tc| crate::types::chat::ToolCall {
id: tc.id,
type_: "function".into(),
function: crate::types::chat::FunctionCall {
name: tc.name,
arguments: tc.arguments,
},
})
.collect(),
)
};
ChatCompletionResponse {
id: self.id,
choices: vec![crate::types::chat::ChatCompletionChoice {
index: 0,
finish_reason: self.finish_reason.unwrap_or(FinishReason::Stop),
message: ChatCompletionMessage {
role: Role::Assistant,
content: if self.content.is_empty() {
None
} else {
Some(self.content)
},
refusal: if self.refusal.is_empty() {
None
} else {
Some(self.refusal)
},
tool_calls,
annotations: None,
},
logprobs: None,
}],
created: self.created,
model: self.model,
object: "chat.completion".into(),
service_tier: self.service_tier,
system_fingerprint: self.system_fingerprint,
usage: self.usage,
session_id: None,
}
}
}
/// Stream adapter that turns raw chat completion chunks into
/// [`ChatStreamEvent`]s while aggregating them into a final response.
pub struct ChatCompletionStream {
    /// Source of parsed chunks.
    inner: SseStream<ChatCompletionChunk>,
    /// Aggregate of every chunk processed so far.
    state: StreamState,
    /// Derived events that have been produced but not yet yielded.
    event_buffer: Vec<ChatStreamEvent>,
    /// Set once `inner` has ended or returned an error; no further polling
    /// of `inner` happens through the `Stream` implementation afterwards.
    done: bool,
}
impl ChatCompletionStream {
    /// Wraps a chunk stream with an empty aggregate and no pending events.
    pub(crate) fn new(inner: SseStream<ChatCompletionChunk>) -> Self {
        Self {
            inner,
            state: StreamState::new(),
            event_buffer: Vec::new(),
            done: false,
        }
    }

    /// Drains the remaining chunks and returns the aggregated response.
    ///
    /// Chunks already consumed through the `Stream` implementation are
    /// included in the result. Events derived while draining are discarded.
    ///
    /// # Errors
    ///
    /// Returns the first error yielded by the underlying chunk stream.
    pub async fn get_final_completion(mut self) -> Result<ChatCompletionResponse, OpenAIError> {
        // If the `Stream` side already observed the end (or an error), the
        // inner stream must not be polled again: polling a finished stream
        // is unspecified behaviour under the `Stream` contract.
        if !self.done {
            while let Some(result) = self.inner.next().await {
                let chunk = result?;
                // Only the aggregation side effect matters here.
                self.state.handle_chunk(&chunk);
            }
        }
        Ok(self.state.into_completion())
    }

    /// Returns the message content accumulated so far.
    pub fn current_content(&self) -> &str {
        &self.state.content
    }
}
impl Stream for ChatCompletionStream {
    type Item = Result<ChatStreamEvent, OpenAIError>;

    /// Yields, for each received chunk, a `Chunk` event followed by the
    /// events derived from that chunk.
    ///
    /// An error from the underlying stream is yielded once, after which the
    /// stream ends. The underlying stream is never polled again once it has
    /// ended or failed.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // Pending events are kept in reverse order so that the next one to
        // yield sits at the back and can be taken with an O(1) `pop`.
        if let Some(event) = this.event_buffer.pop() {
            return Poll::Ready(Some(Ok(event)));
        }
        if this.done {
            return Poll::Ready(None);
        }

        match this.inner.poll_next_unpin(cx) {
            Poll::Ready(Some(Ok(chunk))) => {
                let events = this.state.handle_chunk(&chunk);
                // The buffer is empty at this point (it was drained above),
                // so reversing the fresh batch preserves emission order.
                this.event_buffer.extend(events.into_iter().rev());
                Poll::Ready(Some(Ok(ChatStreamEvent::Chunk(Box::new(chunk)))))
            }
            Poll::Ready(Some(Err(e))) => {
                this.done = true;
                Poll::Ready(Some(Err(e)))
            }
            Poll::Ready(None) => {
                this.done = true;
                Poll::Ready(None)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
// `poll_next` calls `Pin::get_mut`, which is only available when `Self: Unpin`.
// The inner stream is polled through `poll_next_unpin`, so it is never pinned
// in place by this type.
impl Unpin for ChatCompletionStream {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::chat::{ChoiceDelta, DeltaFunctionCall, DeltaToolCall};

    /// Wraps `delta` in a single-choice chunk carrying the fixed metadata
    /// shared by every test.
    fn chunk_from_delta(
        delta: ChoiceDelta,
        finish_reason: Option<FinishReason>,
    ) -> ChatCompletionChunk {
        ChatCompletionChunk {
            id: "c1".into(),
            choices: vec![ChunkChoice {
                delta,
                finish_reason,
                index: 0,
                logprobs: None,
            }],
            created: 1,
            model: "gpt-4o".into(),
            object: "chat.completion.chunk".into(),
            service_tier: None,
            system_fingerprint: None,
            usage: None,
            session_id: None,
        }
    }

    /// Builds a chunk whose only payload is optional content and an
    /// optional finish reason.
    fn make_chunk(
        content: Option<&str>,
        finish_reason: Option<FinishReason>,
    ) -> ChatCompletionChunk {
        let delta = ChoiceDelta {
            content: content.map(String::from),
            role: None,
            refusal: None,
            tool_calls: None,
        };
        chunk_from_delta(delta, finish_reason)
    }

    /// Returns `(delta, snapshot)` of a `ContentDelta`, panicking on any
    /// other event.
    fn expect_content_delta(event: &ChatStreamEvent) -> (&str, &str) {
        match event {
            ChatStreamEvent::ContentDelta { delta, snapshot } => {
                (delta.as_str(), snapshot.as_str())
            }
            other => panic!("expected ContentDelta, got: {other:?}"),
        }
    }

    #[test]
    fn test_accumulate_content() {
        let mut state = StreamState::new();

        let first = state.handle_chunk(&make_chunk(Some("Hello"), None));
        assert_eq!(first.len(), 1);
        assert_eq!(expect_content_delta(&first[0]), ("Hello", "Hello"));

        let second = state.handle_chunk(&make_chunk(Some(" world"), None));
        assert_eq!(expect_content_delta(&second[0]), (" world", "Hello world"));
    }

    #[test]
    fn test_content_done_on_finish() {
        let mut state = StreamState::new();
        state.handle_chunk(&make_chunk(Some("Hi"), None));
        let events = state.handle_chunk(&make_chunk(None, Some(FinishReason::Stop)));

        let saw_content_done = events
            .iter()
            .any(|e| matches!(e, ChatStreamEvent::ContentDone { content } if content == "Hi"));
        assert!(saw_content_done);

        let saw_done = events.iter().any(|e| {
            matches!(
                e,
                ChatStreamEvent::Done {
                    finish_reason: Some(FinishReason::Stop)
                }
            )
        });
        assert!(saw_done);
    }

    #[test]
    fn test_tool_call_accumulation() {
        let mut state = StreamState::new();

        // First fragment: carries the id, the function name and the start
        // of the JSON arguments.
        let opening = chunk_from_delta(
            ChoiceDelta {
                content: None,
                role: Some(Role::Assistant),
                refusal: None,
                tool_calls: Some(vec![DeltaToolCall {
                    index: 0,
                    id: Some("call_1".into()),
                    function: Some(DeltaFunctionCall {
                        name: Some("get_weather".into()),
                        arguments: Some("{\"loc".into()),
                    }),
                    type_: Some("function".into()),
                }]),
            },
            None,
        );
        let events = state.handle_chunk(&opening);
        let saw_named_delta = events.iter().any(
            |e| matches!(e, ChatStreamEvent::ToolCallDelta { name, .. } if name == "get_weather"),
        );
        assert!(saw_named_delta);

        // Second fragment: only the remainder of the arguments.
        let continuation = chunk_from_delta(
            ChoiceDelta {
                content: None,
                role: None,
                refusal: None,
                tool_calls: Some(vec![DeltaToolCall {
                    index: 0,
                    id: None,
                    function: Some(DeltaFunctionCall {
                        name: None,
                        arguments: Some("ation\": \"SF\"}".into()),
                    }),
                    type_: None,
                }]),
            },
            None,
        );
        state.handle_chunk(&continuation);

        let events = state.handle_chunk(&make_chunk(None, Some(FinishReason::ToolCalls)));
        let saw_tool_done = events.iter().any(|e| {
            matches!(
                e,
                ChatStreamEvent::ToolCallDone { name, arguments, call_id, .. }
                    if name == "get_weather"
                        && arguments == "{\"location\": \"SF\"}"
                        && call_id == "call_1"
            )
        });
        assert!(saw_tool_done);
    }

    #[test]
    fn test_into_completion() {
        let mut state = StreamState::new();
        for chunk in [
            make_chunk(Some("Hello"), None),
            make_chunk(Some(" world"), None),
            make_chunk(None, Some(FinishReason::Stop)),
        ] {
            state.handle_chunk(&chunk);
        }

        let completion = state.into_completion();
        let choice = &completion.choices[0];
        assert_eq!(choice.message.content.as_deref(), Some("Hello world"));
        assert_eq!(choice.finish_reason, FinishReason::Stop);
    }
}