use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio_util::sync::CancellationToken;
use tracing::error;
use swink_agent::AgentTool;
use swink_agent::ContentBlock;
use swink_agent::{
AssistantMessage as HarnessAssistantMessage, AssistantMessageEvent, Cost, StopReason,
StreamErrorKind, ToolResultMessage, Usage, UserMessage,
};
use crate::convert::{MessageConverter, extract_tool_schemas};
use crate::sse::{SseAction, SseLine, sse_data_lines_with_callback};
/// One entry of the request's `messages` array.
///
/// Fields set to `None` are left out of the serialized JSON.
#[derive(Debug, Serialize)]
pub struct OaiMessage {
    /// Message role; [`OaiConverter`] produces `"system"`, `"user"`, `"assistant"` or `"tool"`.
    pub role: String,
    /// Plain-text content, omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool calls issued by an assistant message, omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OaiToolCallRequest>>,
    /// Id of the tool call a `"tool"` message answers, omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
/// A tool call replayed inside an assistant message of the request.
#[derive(Debug, Serialize)]
pub struct OaiToolCallRequest {
    /// Tool-call id, echoed back by the matching `"tool"` message.
    pub id: String,
    /// Call kind; [`OaiConverter`] always sets `"function"`.
    pub r#type: String,
    /// Function name and arguments.
    pub function: OaiFunctionCallRequest,
}
/// Function name plus arguments for an [`OaiToolCallRequest`].
#[derive(Debug, Serialize)]
pub struct OaiFunctionCallRequest {
    /// Name of the tool being called.
    pub name: String,
    /// Arguments as a string; the converter fills it with `arguments.to_string()` of the
    /// harness tool call (presumably JSON text — confirm against `ContentBlock::ToolCall`).
    pub arguments: String,
}
/// A tool definition advertised in the request's `tools` array.
#[derive(Debug, Serialize)]
pub struct OaiTool {
    /// Tool kind; [`build_oai_tools`] always sets `"function"`.
    pub r#type: String,
    /// The function's name, description and parameter schema.
    pub function: OaiToolDef,
}
/// Function description for an [`OaiTool`], built from the agent tool's schema.
#[derive(Debug, Serialize)]
pub struct OaiToolDef {
    /// Tool name the model uses when calling it.
    pub name: String,
    /// Human-readable description shown to the model.
    pub description: String,
    /// Parameter schema taken verbatim from `extract_tool_schemas`.
    pub parameters: Value,
}
/// The request's `stream_options` object.
#[derive(Debug, Serialize)]
pub struct OaiStreamOptions {
    /// Presumably asks the server to send a usage report in the stream, which
    /// `process_oai_chunk` stores when a chunk carries `usage` — confirm per provider.
    pub include_usage: bool,
}
/// Request body for a chat-completions call.
///
/// Optional fields are omitted when `None`, and `tools` is omitted when empty.
#[derive(Debug, Serialize)]
pub struct OaiChatRequest {
    /// Model identifier.
    pub model: String,
    /// Conversation history, typically built by [`OaiConverter`].
    pub messages: Vec<OaiMessage>,
    /// Whether the response should be streamed.
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<OaiStreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Tool definitions; see [`build_oai_tools`].
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<OaiTool>,
    /// Tool-selection mode; [`build_oai_tools`] yields `"auto"` when tools exist.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
}
/// One decoded SSE `data:` payload of the streaming response.
#[derive(Deserialize)]
pub struct OaiChunk {
    /// Per-choice deltas; empty when the key is absent (e.g. a usage-only chunk).
    #[serde(default)]
    pub choices: Vec<OaiChoice>,
    /// Token usage, when the chunk carries it.
    #[serde(default)]
    pub usage: Option<OaiUsage>,
}
/// One streamed choice inside an [`OaiChunk`].
#[derive(Deserialize)]
pub struct OaiChoice {
    /// Incremental content; an empty delta when the key is missing.
    #[serde(default)]
    pub delta: OaiDelta,
    /// Set on the chunk that ends the choice (e.g. `"stop"`, `"tool_calls"`, `"length"`).
    #[serde(default)]
    pub finish_reason: Option<String>,
}
/// Incremental payload of a streamed choice. Every part is optional.
#[derive(Default, Deserialize)]
pub struct OaiDelta {
    /// Visible assistant text fragment.
    #[serde(default)]
    pub content: Option<String>,
    /// Tool-call fragments, merged by their `index`.
    #[serde(default)]
    pub tool_calls: Option<Vec<OaiToolCallDelta>>,
    /// Reasoning text fragment; rendered as a thinking block.
    #[serde(default)]
    pub reasoning_content: Option<String>,
}
/// A fragment of one streamed tool call.
#[derive(Deserialize)]
pub struct OaiToolCallDelta {
    /// Tool-call slot; fragments with the same index are merged into one call.
    /// Required: a chunk without it fails to deserialize.
    pub index: usize,
    /// Provider-assigned call id, usually only on the first fragment.
    #[serde(default)]
    pub id: Option<String>,
    /// Name and/or argument fragment.
    #[serde(default)]
    pub function: Option<OaiFunctionDelta>,
}
/// Function part of an [`OaiToolCallDelta`].
#[derive(Deserialize)]
pub struct OaiFunctionDelta {
    /// Function name; a non-empty value replaces any previously seen name.
    #[serde(default)]
    pub name: Option<String>,
    /// Argument text fragment, appended to the arguments received so far.
    #[serde(default)]
    pub arguments: Option<String>,
}
/// Token usage reported by the server.
#[derive(Deserialize)]
pub struct OaiUsage {
    /// Input tokens; 0 when absent.
    #[serde(default)]
    pub prompt_tokens: u64,
    /// Output tokens; 0 when absent.
    #[serde(default)]
    pub completion_tokens: u64,
    /// Total tokens, if reported.
    #[serde(default)]
    pub total_tokens: Option<u64>,
    /// Every other usage field, kept raw so numeric values can be surfaced in `Usage::extra`.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
impl OaiUsage {
    /// Converts the provider's usage report into the harness [`Usage`].
    ///
    /// Numeric fields that are not modelled explicitly are flattened into `Usage::extra`.
    /// Nested objects use dotted keys such as `parent.child`, e.g. a
    /// `prompt_tokens_details` object, if the server sends one.
    /// Without an explicit `total_tokens`, the total is the *saturating* sum of prompt
    /// and completion tokens. The values come from the server, so a plain `+` could
    /// panic (debug) or wrap (release) on absurd inputs.
    fn to_usage(&self) -> Usage {
        let mut extra = HashMap::new();
        for (key, value) in &self.extra {
            collect_numeric_usage_fields(key.clone(), value, &mut extra);
        }
        Usage {
            input: self.prompt_tokens,
            output: self.completion_tokens,
            // Cache counters are not mapped for this provider family; any numeric
            // cache fields the server reports land in `extra` instead.
            cache_read: 0,
            cache_write: 0,
            total: self
                .total_tokens
                .unwrap_or_else(|| self.prompt_tokens.saturating_add(self.completion_tokens)),
            extra,
        }
    }
}
fn collect_numeric_usage_fields(key: String, value: &Value, extra: &mut HashMap<String, u64>) {
match value {
Value::Number(number) => {
if let Some(value) = number.as_u64() {
extra.insert(key, value);
}
}
Value::Object(fields) => {
for (child_key, child_value) in fields {
collect_numeric_usage_fields(format!("{key}.{child_key}"), child_value, extra);
}
}
_ => {}
}
}
/// Accumulated state of one streamed tool call, keyed by its tool-call `index`.
pub struct OaiToolCallEntry {
    /// Provider id, or the `{provider}-tool-{index}` placeholder until a real id
    /// arrives. The id can only change before the block is opened.
    pub id: String,
    /// Latest non-empty function name seen; the block opens once this is `Some`.
    pub name: Option<String>,
    /// All argument fragments received so far, concatenated.
    pub arguments: String,
    /// Content index of the opened tool-call block; `None` while still buffering.
    pub content_index: Option<usize>,
}
/// Converts harness messages into [`OaiMessage`]s via [`MessageConverter`].
pub struct OaiConverter;
impl MessageConverter for OaiConverter {
type Message = OaiMessage;
fn system_message(system_prompt: &str) -> Option<OaiMessage> {
Some(OaiMessage {
role: "system".to_string(),
content: Some(system_prompt.to_string()),
tool_calls: None,
tool_call_id: None,
})
}
fn user_message(user: &UserMessage) -> OaiMessage {
let content = ContentBlock::extract_text(&user.content);
OaiMessage {
role: "user".to_string(),
content: Some(content),
tool_calls: None,
tool_call_id: None,
}
}
fn assistant_message(assistant: &HarnessAssistantMessage) -> OaiMessage {
let mut content = String::new();
let mut tool_calls = Vec::new();
for block in &assistant.content {
match block {
ContentBlock::Text { text } => {
content.push_str(text);
}
ContentBlock::ToolCall {
id,
name,
arguments,
..
} => {
tool_calls.push(OaiToolCallRequest {
id: id.clone(),
r#type: "function".to_string(),
function: OaiFunctionCallRequest {
name: name.clone(),
arguments: arguments.to_string(),
},
});
}
_ => {}
}
}
OaiMessage {
role: "assistant".to_string(),
content: if content.is_empty() {
None
} else {
Some(content)
},
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
tool_call_id: None,
}
}
fn tool_result_message(result: &ToolResultMessage) -> OaiMessage {
let content = ContentBlock::extract_text(&result.content);
OaiMessage {
role: "tool".to_string(),
content: Some(content),
tool_calls: None,
tool_call_id: Some(result.tool_call_id.clone()),
}
}
}
/// Converts agent tools into chat-completions tool definitions.
///
/// Returns the `function`-typed tool list and the matching `tool_choice`.
/// The choice is `Some("auto")` when at least one tool exists and `None` otherwise.
pub fn build_oai_tools(tools: &[Arc<dyn AgentTool>]) -> (Vec<OaiTool>, Option<String>) {
    let mut oai_tools: Vec<OaiTool> = Vec::new();
    for schema in extract_tool_schemas(tools) {
        let function = OaiToolDef {
            name: schema.name,
            description: schema.description,
            parameters: schema.parameters,
        };
        oai_tools.push(OaiTool {
            r#type: "function".to_string(),
            function,
        });
    }
    let tool_choice = (!oai_tools.is_empty()).then(|| "auto".to_string());
    (oai_tools, tool_choice)
}
/// Mutable per-stream state threaded through [`parse_oai_sse_stream`].
#[derive(Default)]
pub struct OaiSseStreamState {
    /// Tracks thinking, text and tool-call blocks and assigns their content indices.
    pub blocks: crate::block_accumulator::BlockAccumulator,
    /// In-flight tool calls keyed by the provider's tool-call `index`.
    pub tool_calls: HashMap<usize, OaiToolCallEntry>,
    /// Latest usage report; each usage-bearing chunk replaces the previous value.
    pub usage: Option<Usage>,
    /// Stop reason from the last `finish_reason`, consumed when the stream ends.
    pub stop_reason: Option<StopReason>,
    /// Error event recorded while processing a chunk, emitted as the final event.
    pub terminal_error: Option<AssistantMessageEvent>,
}
impl crate::finalize::StreamFinalize for OaiSseStreamState {
    /// Drains the accumulator's open blocks and forgets every tracked tool call.
    ///
    /// The tool-call bookkeeping refers to blocks that are being closed here. A later
    /// fragment with the same index therefore starts a brand-new entry.
    fn drain_open_blocks(&mut self) -> Vec<crate::finalize::OpenBlock> {
        let drained = crate::finalize::StreamFinalize::drain_open_blocks(&mut self.blocks);
        self.tool_calls.clear();
        drained
    }
}
/// Applies one decoded streaming chunk to `state` and appends the resulting events.
///
/// A usage report, if present, replaces `state.usage`. Then, for each choice in order:
/// - reasoning text feeds a thinking block;
/// - visible content closes any thinking block and feeds a text block;
/// - a non-empty tool-call list closes both, then each fragment is merged per index;
/// - a `finish_reason` flushes tool calls that never got a name and finalizes open
///   blocks. It then records either the stop reason, or (for `content_filter`, and
///   for `error` when `provider == "Mistral"`) a terminal error in
///   `state.terminal_error`. In the error case the rest of the chunk is ignored.
///
/// `provider` appears in error messages and placeholder tool-call ids.
pub fn process_oai_chunk(
    chunk: &OaiChunk,
    state: &mut OaiSseStreamState,
    events: &mut Vec<AssistantMessageEvent>,
    provider: &str,
) {
    if let Some(u) = &chunk.usage {
        state.usage = Some(u.to_usage());
    }
    for choice in &chunk.choices {
        if let Some(reasoning) = &choice.delta.reasoning_content
            && !reasoning.is_empty()
        {
            if let Some(ev) = state.blocks.ensure_thinking_open() {
                events.push(ev);
            }
            if let Some(ev) = state.blocks.thinking_delta(reasoning.clone()) {
                events.push(ev);
            }
        }
        if let Some(content) = &choice.delta.content
            && !content.is_empty()
        {
            if let Some(ev) = state.blocks.close_thinking(None) {
                events.push(ev);
            }
            if let Some(ev) = state.blocks.ensure_text_open() {
                events.push(ev);
            }
            if let Some(ev) = state.blocks.text_delta(content.clone()) {
                events.push(ev);
            }
        }
        // A server may send `"tool_calls": []` next to ordinary content deltas. Only a
        // non-empty list should close the thinking/text blocks. Otherwise every such
        // chunk would end the current text block and the next content would open a new one.
        if let Some(tool_calls) = &choice.delta.tool_calls
            && !tool_calls.is_empty()
        {
            if let Some(ev) = state.blocks.close_thinking(None) {
                events.push(ev);
            }
            if let Some(ev) = state.blocks.close_text() {
                events.push(ev);
            }
            for tc_delta in tool_calls {
                process_oai_tool_call_delta(tc_delta, state, events, provider);
            }
        }
        if let Some(reason) = &choice.finish_reason {
            if reason == "content_filter" {
                flush_pending_oai_tool_calls(state, events);
                events.extend(crate::finalize::finalize_blocks(state));
                state.terminal_error = Some(AssistantMessageEvent::error_content_filtered(
                    format!("{provider} response stopped by content filter"),
                ));
                return;
            }
            if provider == "Mistral" && reason == "error" {
                flush_pending_oai_tool_calls(state, events);
                events.extend(crate::finalize::finalize_blocks(state));
                state.terminal_error = Some(AssistantMessageEvent::Error {
                    stop_reason: StopReason::Error,
                    error_message: "Mistral reported finish_reason=error".to_string(),
                    usage: state.usage.clone(),
                    error_kind: Some(StreamErrorKind::Network),
                });
                return;
            }
            let stop_reason = match reason.as_str() {
                "tool_calls" => StopReason::ToolUse,
                "length" | "model_length" => StopReason::Length,
                _ => StopReason::Stop,
            };
            flush_pending_oai_tool_calls(state, events);
            events.extend(crate::finalize::finalize_blocks(state));
            state.stop_reason = Some(stop_reason);
        }
    }
}
/// Merges one streamed tool-call fragment into the entry for its `index`.
///
/// Fragments are buffered until a non-empty function name is known. At that point a
/// tool-call block is opened, and all arguments received so far go out as one delta.
/// After the block is open, each non-empty argument fragment is forwarded at once.
/// A provider-supplied id replaces the `{provider}-tool-{index}` placeholder only
/// while the block is still unopened.
fn process_oai_tool_call_delta(
    tc_delta: &OaiToolCallDelta,
    state: &mut OaiSseStreamState,
    events: &mut Vec<AssistantMessageEvent>,
    provider: &str,
) {
    let index = tc_delta.index;
    let function = tc_delta.function.as_ref();
    let name_fragment = function
        .and_then(|f| f.name.as_ref())
        .filter(|name| !name.is_empty());
    let args_fragment = function
        .and_then(|f| f.arguments.as_ref())
        .filter(|args| !args.is_empty());

    // `entry` borrows only `state.tool_calls`, so `state.blocks` stays usable below.
    let entry = state.tool_calls.entry(index).or_insert_with(|| OaiToolCallEntry {
        id: tc_delta
            .id
            .clone()
            .unwrap_or_else(|| format!("{provider}-tool-{index}")),
        name: None,
        arguments: String::new(),
        content_index: None,
    });

    if let Some(content_index) = entry.content_index {
        // Block already open: keep the latest name and stream arguments straight through.
        if let Some(name) = name_fragment {
            entry.name = Some(name.clone());
        }
        if let Some(args) = args_fragment {
            entry.arguments.push_str(args);
            events.push(crate::block_accumulator::BlockAccumulator::tool_call_delta(
                content_index,
                args.clone(),
            ));
        }
        return;
    }

    // Block not open yet: keep buffering, adopting the real id if one shows up.
    if let Some(id) = &tc_delta.id {
        entry.id.clone_from(id);
    }
    if let Some(name) = name_fragment {
        entry.name = Some(name.clone());
    }
    if let Some(args) = args_fragment {
        entry.arguments.push_str(args);
    }
    let Some(name) = entry.name.clone() else {
        return;
    };

    let (content_index, start_ev) = state.blocks.open_tool_call(entry.id.clone(), name);
    events.push(start_ev);
    if !entry.arguments.is_empty() {
        events.push(crate::block_accumulator::BlockAccumulator::tool_call_delta(
            content_index,
            entry.arguments.clone(),
        ));
    }
    entry.content_index = Some(content_index);
}
fn flush_pending_oai_tool_calls(
state: &mut OaiSseStreamState,
events: &mut Vec<AssistantMessageEvent>,
) {
let mut pending_indices: Vec<_> = state
.tool_calls
.iter()
.filter_map(|(tc_index, entry)| entry.content_index.is_none().then_some(*tc_index))
.collect();
pending_indices.sort_unstable();
for tc_index in pending_indices {
let (id, name, arguments) = {
let entry = state
.tool_calls
.get(&tc_index)
.expect("pending entry should exist");
(
entry.id.clone(),
entry.name.clone().unwrap_or_default(),
entry.arguments.clone(),
)
};
let (content_index, start_ev) = state.blocks.open_tool_call(id, name);
events.push(start_ev);
if !arguments.is_empty() {
events.push(crate::block_accumulator::BlockAccumulator::tool_call_delta(
content_index,
arguments,
));
}
let entry = state
.tool_calls
.get_mut(&tc_index)
.expect("pending entry should still exist");
entry.content_index = Some(content_index);
}
}
#[allow(clippy::too_many_lines)]
pub fn parse_oai_sse_stream(
response: reqwest::Response,
cancellation_token: CancellationToken,
provider: &'static str,
on_raw_payload: Option<swink_agent::OnRawPayload>,
) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send>> {
let line_stream = sse_data_lines_with_callback(response.bytes_stream(), on_raw_payload);
crate::sse::sse_adapter_stream(
line_stream,
cancellation_token,
OaiSseStreamState::default(),
"operation cancelled",
move |item, state| match item {
None => {
let mut events = Vec::new();
flush_pending_oai_tool_calls(state, &mut events);
events.extend(crate::finalize::finalize_blocks(state));
if let Some(error) = state.terminal_error.take() {
events.push(error);
} else if let Some(stop_reason) = state.stop_reason.take() {
let usage = state.usage.take();
events.push(AssistantMessageEvent::Done {
stop_reason,
usage: usage.unwrap_or_default(),
cost: Cost::default(),
});
} else {
events.push(AssistantMessageEvent::error_network(format!(
"{provider} stream ended unexpectedly",
)));
}
SseAction::Done(events)
}
Some(SseLine::Done) => {
let mut events = Vec::new();
flush_pending_oai_tool_calls(state, &mut events);
events.extend(crate::finalize::finalize_blocks(state));
if let Some(error) = state.terminal_error.take() {
events.push(error);
} else {
let stop_reason = state.stop_reason.take().unwrap_or(StopReason::Stop);
let usage = state.usage.take();
events.push(AssistantMessageEvent::Done {
stop_reason,
usage: usage.unwrap_or_default(),
cost: Cost::default(),
});
}
SseAction::Done(events)
}
Some(SseLine::Data(data)) => {
let chunk: OaiChunk = match serde_json::from_str(&data) {
Ok(c) => c,
Err(e) => {
error!(error = %e, "{provider} JSON parse error");
let mut events = Vec::new();
flush_pending_oai_tool_calls(state, &mut events);
events.extend(crate::finalize::finalize_blocks(state));
events.push(AssistantMessageEvent::error_network(format!(
"{provider} JSON parse error: {e}",
)));
return SseAction::Done(events);
}
};
let mut events = Vec::new();
process_oai_chunk(&chunk, state, &mut events, provider);
if let Some(error) = state.terminal_error.take() {
events.push(error);
SseAction::Done(events)
} else {
SseAction::Continue(events)
}
}
Some(SseLine::TransportError(message)) => {
let mut events = Vec::new();
flush_pending_oai_tool_calls(state, &mut events);
events.extend(crate::finalize::finalize_blocks(state));
events.push(AssistantMessageEvent::error_network(format!(
"{provider} {message}",
)));
SseAction::Done(events)
}
Some(_) => SseAction::Skip,
},
)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-choice chunk carrying `delta` and an optional `finish_reason`.
    fn chunk_with_delta(delta: OaiDelta, finish_reason: Option<&str>) -> OaiChunk {
        OaiChunk {
            choices: vec![OaiChoice {
                delta,
                finish_reason: finish_reason.map(String::from),
            }],
            usage: None,
        }
    }

    /// Builds a one-element tool-call fragment list for `index`.
    fn single_tool_call(
        index: usize,
        id: Option<&str>,
        name: Option<&str>,
        arguments: Option<&str>,
    ) -> Option<Vec<OaiToolCallDelta>> {
        Some(vec![OaiToolCallDelta {
            index,
            id: id.map(String::from),
            function: Some(OaiFunctionDelta {
                name: name.map(String::from),
                arguments: arguments.map(String::from),
            }),
        }])
    }

    /// Counts `ToolCallStart` events for `content_index`.
    fn tool_call_starts(events: &[AssistantMessageEvent], content_index: usize) -> usize {
        events
            .iter()
            .filter(|ev| {
                matches!(ev, AssistantMessageEvent::ToolCallStart { content_index: ci, .. } if *ci == content_index)
            })
            .count()
    }

    #[test]
    fn reasoning_content_emits_thinking_events() {
        let mut state = OaiSseStreamState::default();
        let mut events = Vec::new();
        let chunk = chunk_with_delta(
            OaiDelta {
                reasoning_content: Some("Let me think".to_string()),
                ..Default::default()
            },
            None,
        );
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        assert_eq!(events.len(), 2);
        assert!(matches!(
            &events[0],
            AssistantMessageEvent::ThinkingStart { content_index: 0 }
        ));
        assert!(
            matches!(&events[1], AssistantMessageEvent::ThinkingDelta { content_index: 0, delta } if delta == "Let me think")
        );
        events.clear();
        let chunk = chunk_with_delta(
            OaiDelta {
                reasoning_content: Some(" more".to_string()),
                ..Default::default()
            },
            None,
        );
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        assert_eq!(events.len(), 1);
        assert!(
            matches!(&events[0], AssistantMessageEvent::ThinkingDelta { content_index: 0, delta } if delta == " more")
        );
    }

    #[test]
    fn reasoning_to_content_transition_closes_thinking() {
        let mut state = OaiSseStreamState::default();
        let mut events = Vec::new();
        let chunk = chunk_with_delta(
            OaiDelta {
                reasoning_content: Some("thinking...".to_string()),
                ..Default::default()
            },
            None,
        );
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        assert_eq!(events.len(), 2);
        events.clear();
        let chunk = chunk_with_delta(
            OaiDelta {
                content: Some("Hello".to_string()),
                ..Default::default()
            },
            None,
        );
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        assert_eq!(events.len(), 3);
        assert!(matches!(
            &events[0],
            AssistantMessageEvent::ThinkingEnd {
                content_index: 0,
                ..
            }
        ));
        assert!(matches!(
            &events[1],
            AssistantMessageEvent::TextStart { content_index: 1 }
        ));
        assert!(matches!(
            &events[2],
            AssistantMessageEvent::TextDelta { content_index: 1, delta } if delta == "Hello"
        ));
    }

    #[test]
    fn reasoning_to_tool_call_closes_thinking() {
        let mut state = OaiSseStreamState::default();
        let mut events = Vec::new();
        let chunk = chunk_with_delta(
            OaiDelta {
                reasoning_content: Some("planning...".to_string()),
                ..Default::default()
            },
            None,
        );
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        events.clear();
        let chunk = chunk_with_delta(
            OaiDelta {
                tool_calls: Some(vec![OaiToolCallDelta {
                    index: 0,
                    id: Some("call_1".to_string()),
                    function: Some(OaiFunctionDelta {
                        name: Some("my_tool".to_string()),
                        arguments: Some(r#"{"a":1}"#.to_string()),
                    }),
                }]),
                ..Default::default()
            },
            None,
        );
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        assert!(matches!(
            &events[0],
            AssistantMessageEvent::ThinkingEnd {
                content_index: 0,
                ..
            }
        ));
        assert!(matches!(
            &events[1],
            AssistantMessageEvent::ToolCallStart {
                content_index: 1,
                ..
            }
        ));
    }

    #[test]
    fn chunks_without_reasoning_work_normally() {
        let mut state = OaiSseStreamState::default();
        let mut events = Vec::new();
        let chunk = chunk_with_delta(
            OaiDelta {
                content: Some("Hello world".to_string()),
                ..Default::default()
            },
            None,
        );
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        assert_eq!(events.len(), 2);
        assert!(matches!(
            &events[0],
            AssistantMessageEvent::TextStart { content_index: 0 }
        ));
        assert!(matches!(
            &events[1],
            AssistantMessageEvent::TextDelta { content_index: 0, delta } if delta == "Hello world"
        ));
    }

    #[test]
    fn empty_reasoning_content_ignored() {
        let mut state = OaiSseStreamState::default();
        let mut events = Vec::new();
        let chunk = chunk_with_delta(
            OaiDelta {
                reasoning_content: Some(String::new()),
                ..Default::default()
            },
            None,
        );
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        assert!(events.is_empty());
    }

    #[test]
    fn null_reasoning_content_ignored() {
        let mut state = OaiSseStreamState::default();
        let mut events = Vec::new();
        let chunk = chunk_with_delta(
            OaiDelta {
                reasoning_content: None,
                content: Some("text".to_string()),
                ..Default::default()
            },
            None,
        );
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        assert_eq!(events.len(), 2);
        assert!(matches!(
            &events[0],
            AssistantMessageEvent::TextStart { content_index: 0 }
        ));
    }

    #[test]
    fn reasoning_content_deserialized_from_json() {
        let json = r#"{
"choices": [{
"delta": {
"reasoning_content": "step by step"
},
"finish_reason": null
}]
}"#;
        let chunk: OaiChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(
            chunk.choices[0].delta.reasoning_content.as_deref(),
            Some("step by step")
        );
    }

    /// Arguments that arrive before the name are buffered, not lost, and the block
    /// opens exactly once when the name shows up.
    #[test]
    fn tool_call_arguments_buffer_until_name_arrives() {
        let mut state = OaiSseStreamState::default();
        let mut events = Vec::new();
        let chunk = chunk_with_delta(
            OaiDelta {
                tool_calls: single_tool_call(0, Some("call_1"), None, Some(r#"{"a":"#)),
                ..Default::default()
            },
            None,
        );
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        assert_eq!(tool_call_starts(&events, 0), 0);
        events.clear();
        let chunk = chunk_with_delta(
            OaiDelta {
                tool_calls: single_tool_call(0, None, Some("my_tool"), Some("1}")),
                ..Default::default()
            },
            None,
        );
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        assert_eq!(tool_call_starts(&events, 0), 1);
        let entry = &state.tool_calls[&0];
        assert_eq!(entry.id, "call_1");
        assert_eq!(entry.arguments, r#"{"a":1}"#);
        assert_eq!(entry.content_index, Some(0));
    }

    /// A tool call whose name never arrives is still surfaced when the choice finishes.
    #[test]
    fn finish_reason_flushes_nameless_tool_call() {
        let mut state = OaiSseStreamState::default();
        let mut events = Vec::new();
        let chunk = chunk_with_delta(
            OaiDelta {
                tool_calls: single_tool_call(0, Some("call_1"), None, Some("{}")),
                ..Default::default()
            },
            Some("tool_calls"),
        );
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        assert_eq!(tool_call_starts(&events, 0), 1);
        assert!(matches!(state.stop_reason, Some(StopReason::ToolUse)));
        assert!(state.terminal_error.is_none());
    }

    /// Missing `total_tokens` falls back to the sum; nested numeric fields are flattened.
    #[test]
    fn usage_chunk_populates_state_usage() {
        let json = r#"{"choices": [], "usage": {"prompt_tokens": 3, "completion_tokens": 4, "prompt_tokens_details": {"cached_tokens": 2}}}"#;
        let chunk: OaiChunk = serde_json::from_str(json).unwrap();
        let mut state = OaiSseStreamState::default();
        let mut events = Vec::new();
        process_oai_chunk(&chunk, &mut state, &mut events, "test");
        assert!(events.is_empty());
        let usage = state
            .usage
            .take()
            .expect("a usage-bearing chunk should populate state.usage");
        assert_eq!(usage.input, 3);
        assert_eq!(usage.output, 4);
        assert_eq!(usage.total, 7);
        assert_eq!(
            usage.extra.get("prompt_tokens_details.cached_tokens"),
            Some(&2)
        );
    }
}