use serde::{Deserialize, Serialize};
use crate::usage::Usage;
/// One incremental event produced while a model response is being streamed.
///
/// The serde attributes select an internally tagged representation: the
/// variant name, converted to `snake_case`, is stored under the `"type"` key.
///
/// NOTE(review): serde's internally tagged representation is documented as
/// supporting newtype variants only when they wrap a struct or map. `Text` and
/// `ReasoningContent` wrap a bare `String`, so serializing or deserializing
/// them may fail at runtime — confirm with a round-trip test.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum StreamChunk {
/// A fragment of generated text.
Text(String),
/// A fragment of reasoning content, kept separate from the regular text.
ReasoningContent(String),
/// A fragment of audio output.
Audio {
/// Audio payload carried as a string (encoding not visible here —
/// presumably base64; TODO confirm against the producer).
data: String,
/// Optional transcript; the field is omitted from serialized output
/// when it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
transcript: Option<String>,
},
/// Announces a new tool call and carries its identity.
ToolUseStart {
/// Position of the tool call; later chunks refer to the call by this index.
index: usize,
/// Identifier of the tool call.
id: String,
/// Name of the tool being called.
name: String,
},
/// A further piece of the arguments for a tool call.
ToolUseDelta {
/// Index of the tool call this fragment belongs to.
index: usize,
/// Fragment of the argument payload; fragments are meant to be
/// concatenated, so a single one need not be valid JSON on its own.
partial_json: String,
},
/// Marks the tool call at `index` as finished.
ToolUseComplete {
/// Index of the tool call that is complete.
index: usize,
},
/// Token usage reported by the stream.
Usage(Usage),
/// Signals the end of the stream.
Done {
/// Why generation stopped, when that is known.
stop_reason: Option<StopReason>,
},
/// An error reported through the stream.
Error {
/// Human-readable description of the error.
message: String,
},
}
impl StreamChunk {
    // ---- constructors ----------------------------------------------------

    /// Builds a [`StreamChunk::Text`] chunk from anything convertible into a `String`.
    #[inline]
    #[must_use]
    pub fn text(content: impl Into<String>) -> Self {
        let body: String = content.into();
        Self::Text(body)
    }

    /// Builds a [`StreamChunk::ReasoningContent`] chunk from anything
    /// convertible into a `String`.
    #[inline]
    #[must_use]
    pub fn reasoning(content: impl Into<String>) -> Self {
        let body: String = content.into();
        Self::ReasoningContent(body)
    }

    /// Builds a [`StreamChunk::Audio`] chunk with the given payload and
    /// optional transcript.
    #[must_use]
    pub fn audio(data: impl Into<String>, transcript: Option<String>) -> Self {
        let data = data.into();
        Self::Audio { data, transcript }
    }

    /// Builds a [`StreamChunk::ToolUseStart`] chunk for the tool call at `index`.
    #[must_use]
    pub fn tool_use_start(index: usize, id: impl Into<String>, name: impl Into<String>) -> Self {
        let (id, name) = (id.into(), name.into());
        Self::ToolUseStart { index, id, name }
    }

    /// Builds a [`StreamChunk::ToolUseDelta`] chunk carrying one fragment of
    /// the arguments for the tool call at `index`.
    #[must_use]
    pub fn tool_use_delta(index: usize, partial_json: impl Into<String>) -> Self {
        let partial_json = partial_json.into();
        Self::ToolUseDelta {
            index,
            partial_json,
        }
    }

    /// Builds a [`StreamChunk::Done`] chunk with an optional stop reason.
    #[must_use]
    pub const fn done(stop_reason: Option<StopReason>) -> Self {
        Self::Done { stop_reason }
    }

    /// Builds a [`StreamChunk::Error`] chunk with the given message.
    #[must_use]
    pub fn error(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::Error { message }
    }

    // ---- accessors -------------------------------------------------------

    /// Returns the text payload if this is a [`StreamChunk::Text`] chunk,
    /// otherwise `None`.
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        if let Self::Text(body) = self {
            Some(body.as_str())
        } else {
            None
        }
    }

    /// Returns the reasoning payload if this is a
    /// [`StreamChunk::ReasoningContent`] chunk, otherwise `None`.
    #[must_use]
    pub fn as_reasoning(&self) -> Option<&str> {
        if let Self::ReasoningContent(body) = self {
            Some(body.as_str())
        } else {
            None
        }
    }

    // ---- predicates ------------------------------------------------------

    /// Returns `true` for [`StreamChunk::Text`].
    #[must_use]
    pub const fn is_text(&self) -> bool {
        matches!(*self, Self::Text(..))
    }

    /// Returns `true` for [`StreamChunk::ReasoningContent`].
    #[must_use]
    pub const fn is_reasoning(&self) -> bool {
        matches!(*self, Self::ReasoningContent(..))
    }

    /// Returns `true` for [`StreamChunk::Audio`].
    #[must_use]
    pub const fn is_audio(&self) -> bool {
        matches!(*self, Self::Audio { .. })
    }

    /// Returns `true` for [`StreamChunk::Done`].
    #[must_use]
    pub const fn is_done(&self) -> bool {
        matches!(*self, Self::Done { .. })
    }

    /// Returns `true` for [`StreamChunk::Error`].
    #[must_use]
    pub const fn is_error(&self) -> bool {
        matches!(*self, Self::Error { .. })
    }
}
/// Why the model stopped generating.
///
/// Variant names are serialized in `snake_case`. `Stop` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum StopReason {
/// Generation stopped normally; this is the default variant.
#[default]
Stop,
/// Generation stopped because of a length limit.
Length,
/// Generation stopped so that tool calls can be made. The name
/// `"function_call"` is also accepted when deserializing.
#[serde(alias = "function_call")]
ToolCalls,
/// Output was stopped by a content filter.
ContentFilter,
/// A stop reason reported as the literal label `"null"`.
/// NOTE(review): presumably mirrors providers that send a textual "null";
/// confirm the intended meaning with the callers.
Null,
}
impl StopReason {
    /// Returns the canonical `snake_case` label for this stop reason.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Null => "null",
            Self::ContentFilter => "content_filter",
            Self::ToolCalls => "tool_calls",
            Self::Length => "length",
            Self::Stop => "stop",
        }
    }

    /// Maps a textual stop reason onto a [`StopReason`].
    ///
    /// The input is lowercased before comparison, so matching ignores case.
    /// Besides the canonical labels, `"max_tokens"` is read as
    /// [`StopReason::Length`], and `"tool_use"` / `"function_call"` as
    /// [`StopReason::ToolCalls`]. Every unrecognised string — including the
    /// empty string — yields [`StopReason::Stop`].
    #[must_use]
    pub fn parse(s: &str) -> Self {
        let normalized = s.to_lowercase();
        match &*normalized {
            "null" => Self::Null,
            "content_filter" => Self::ContentFilter,
            "tool_calls" | "tool_use" | "function_call" => Self::ToolCalls,
            "length" | "max_tokens" => Self::Length,
            // Anything else, "stop" included, is treated as a normal stop.
            _ => Self::Stop,
        }
    }

    /// Returns `true` when generation ended on its own terms: either
    /// [`StopReason::Stop`] or [`StopReason::ToolCalls`].
    #[must_use]
    pub const fn is_complete(&self) -> bool {
        self.is_tool_call() || matches!(*self, Self::Stop)
    }

    /// Returns `true` for [`StopReason::Length`].
    #[must_use]
    pub const fn is_truncated(&self) -> bool {
        matches!(*self, Self::Length)
    }

    /// Returns `true` for [`StopReason::ContentFilter`].
    #[must_use]
    pub const fn is_filtered(&self) -> bool {
        matches!(*self, Self::ContentFilter)
    }

    /// Returns `true` for [`StopReason::ToolCalls`].
    #[must_use]
    pub const fn is_tool_call(&self) -> bool {
        matches!(*self, Self::ToolCalls)
    }
}
impl std::fmt::Display for StopReason {
    /// Writes the label returned by `as_str`; width and alignment flags on
    /// the surrounding formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
/// Accumulates [`StreamChunk`]s into the pieces of a complete response.
///
/// Chunks are folded in with `apply`; the collected state can be read back
/// through the accessor methods or turned into a chat response with
/// `into_chat_response`.
#[derive(Debug, Clone, Default)]
pub struct StreamAggregator {
/// Concatenation of every `Text` chunk applied so far.
text: String,
/// Concatenation of every `ReasoningContent` chunk applied so far.
reasoning_content: String,
/// Tool calls under construction, keyed by their stream index. A
/// `BTreeMap` keeps them ordered by index when they are iterated.
tool_calls: std::collections::BTreeMap<usize, ToolCallBuilder>,
/// Usage taken from the most recent `Usage` chunk, if any.
usage: Option<Usage>,
/// Stop reason taken from a `Done` chunk, if any.
stop_reason: Option<StopReason>,
}
/// Partially assembled tool call, filled in as stream chunks arrive.
#[derive(Debug, Clone, Default)]
struct ToolCallBuilder {
/// Identifier copied from the `ToolUseStart` chunk.
id: String,
/// Tool name copied from the `ToolUseStart` chunk.
name: String,
/// Argument text built by concatenating `ToolUseDelta` fragments.
arguments: String,
}
impl StreamAggregator {
    /// Creates an aggregator with no accumulated state.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Folds one streamed chunk into the accumulated state.
    ///
    /// * `Text` / `ReasoningContent` — appended to the respective buffer.
    /// * `ToolUseStart` — registers a tool call at `index` with empty
    ///   arguments, replacing any call already stored at that index.
    /// * `ToolUseDelta` — appended to the arguments of the call at `index`;
    ///   fragments for an index that was never started are dropped.
    /// * `Usage` — replaces the stored usage.
    /// * `Done` — records the stop reason when the chunk carries one.
    /// * `Audio`, `ToolUseComplete`, `Error` — ignored.
    pub fn apply(&mut self, chunk: &StreamChunk) {
        match chunk {
            StreamChunk::Text(text) => self.text.push_str(text),
            StreamChunk::ReasoningContent(content) => {
                self.reasoning_content.push_str(content);
            }
            StreamChunk::ToolUseStart { index, id, name } => {
                self.tool_calls.insert(
                    *index,
                    ToolCallBuilder {
                        id: id.clone(),
                        name: name.clone(),
                        arguments: String::new(),
                    },
                );
            }
            StreamChunk::ToolUseDelta {
                index,
                partial_json,
            } => {
                if let Some(call) = self.tool_calls.get_mut(index) {
                    call.arguments.push_str(partial_json);
                }
            }
            StreamChunk::Usage(usage) => {
                self.usage = Some(*usage);
            }
            StreamChunk::Done { stop_reason } => {
                // A `Done` without a reason (for example an end-of-stream
                // marker arriving after the chunk that carried the real
                // reason) must not erase what was already recorded.
                if let Some(reason) = stop_reason {
                    self.stop_reason = Some(*reason);
                }
            }
            StreamChunk::Audio { .. }
            | StreamChunk::ToolUseComplete { .. }
            | StreamChunk::Error { .. } => {}
        }
    }

    /// Returns all text accumulated so far.
    #[must_use]
    pub fn text(&self) -> &str {
        &self.text
    }

    /// Returns all reasoning content accumulated so far.
    #[must_use]
    pub fn reasoning_content(&self) -> &str {
        &self.reasoning_content
    }

    /// Returns `true` when at least some reasoning content has been accumulated.
    #[must_use]
    pub const fn has_reasoning_content(&self) -> bool {
        !self.reasoning_content.is_empty()
    }

    /// Returns the most recently applied usage, if any.
    #[must_use]
    pub const fn usage(&self) -> Option<Usage> {
        self.usage
    }

    /// Returns the recorded stop reason, if any.
    #[must_use]
    pub const fn stop_reason(&self) -> Option<StopReason> {
        self.stop_reason
    }

    /// Returns `true` when at least one tool call has been started.
    #[must_use]
    pub fn has_tool_calls(&self) -> bool {
        !self.tool_calls.is_empty()
    }

    /// Builds the tool calls collected so far, ordered by ascending stream index.
    ///
    /// The aggregator is left untouched; ids, names and arguments are passed
    /// to `ToolCall::function` by reference.
    #[must_use]
    pub fn build_tool_calls(&self) -> Vec<crate::message::ToolCall> {
        self.tool_calls
            .values()
            .map(|tc| crate::message::ToolCall::function(&tc.id, &tc.name, &tc.arguments))
            .collect()
    }

    /// Consumes the aggregator and assembles a chat response.
    ///
    /// With tool calls present, the message is built from them and the
    /// accumulated text is attached only when it is non-empty. Without tool
    /// calls, the message is a plain assistant message holding the text
    /// (possibly empty). Reasoning content, stop reason and usage are attached
    /// only when they were collected.
    #[must_use]
    pub fn into_chat_response(self) -> crate::chat::ChatResponse {
        use crate::chat::ChatResponse;
        use crate::message::{Content, Message, Role};

        let tool_calls = self.build_tool_calls();

        let mut message = if tool_calls.is_empty() {
            Message::new(Role::Assistant, Content::text(self.text))
        } else {
            let mut msg = Message::assistant_tool_calls(tool_calls);
            if !self.text.is_empty() {
                msg.content = Some(Content::text(self.text));
            }
            msg
        };

        if !self.reasoning_content.is_empty() {
            message.reasoning_content = Some(self.reasoning_content);
        }

        let mut response = ChatResponse::new(message);
        if let Some(reason) = self.stop_reason {
            response = response.with_stop_reason(reason);
        }
        if let Some(usage) = self.usage {
            response = response.with_usage(usage);
        }
        response
    }
}