use std::collections::HashMap;
use anyhow::{Context, Result};
use futures::{StreamExt, TryStreamExt, future::BoxFuture};
use serde_json::{Value, json};
use super::{
ContentPart, LlmProvider, LlmRequest, LlmStream, Message, MessageContent, Role, StreamEvent,
TokenUsage,
};
/// Base URL used by [`OpenAiProvider::new`] when the caller does not supply one.
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
// Not referenced anywhere in the visible part of this file, hence the lint
// suppression. NOTE(review): confirm whether this default is still needed.
#[allow(dead_code)]
const DEFAULT_MAX_TOKENS: u32 = 65536;
/// Marker prepended by `sanitize_tool_name` to every encoded tool name and
/// stripped again by `restore_tool_name`.
const TOOL_NAME_PREFIX: &str = "rc_";
/// Encodes a tool name so it only contains ASCII letters, digits, `_` and `-`.
///
/// Names that already satisfy that alphabet and do not start with
/// `TOOL_NAME_PREFIX` are returned unchanged. Everything else is prefixed with
/// `TOOL_NAME_PREFIX` and escaped character by character:
///
/// * `_` becomes `_u_`
/// * `.` becomes `_d_`
/// * ASCII letters, digits and `-` are copied verbatim
/// * any other character becomes one `_xHH_` group per UTF-8 byte (upper-case hex)
pub(crate) fn sanitize_tool_name(name: &str) -> String {
    let is_wire_safe = |c: char| c.is_ascii_alphanumeric() || c == '_' || c == '-';
    if !name.starts_with(TOOL_NAME_PREFIX) && name.chars().all(is_wire_safe) {
        return name.to_owned();
    }

    let mut encoded = String::with_capacity(name.len() + TOOL_NAME_PREFIX.len() + 2);
    encoded.push_str(TOOL_NAME_PREFIX);
    // Scratch space large enough for the UTF-8 form of any `char`.
    let mut utf8 = [0u8; 4];
    for ch in name.chars() {
        if ch == '_' {
            encoded.push_str("_u_");
        } else if ch == '.' {
            encoded.push_str("_d_");
        } else if ch.is_ascii_alphanumeric() || ch == '-' {
            encoded.push(ch);
        } else {
            for byte in ch.encode_utf8(&mut utf8).bytes() {
                encoded.push_str(&format!("_x{byte:02X}_"));
            }
        }
    }
    encoded
}
pub fn restore_tool_name(name: &str) -> String {
let Some(encoded) = name.strip_prefix(TOOL_NAME_PREFIX) else {
return name.to_owned();
};
let bytes = encoded.as_bytes();
let mut out = Vec::with_capacity(bytes.len());
let mut i = 0;
while i < bytes.len() {
if bytes[i] == b'_'
&& let Some(end_rel) = bytes[i + 1..].iter().position(|b| *b == b'_')
{
let end = i + 1 + end_rel;
let code = &encoded[i + 1..end];
match code {
"u" => {
out.push(b'_');
i = end + 1;
continue;
}
"d" => {
out.push(b'.');
i = end + 1;
continue;
}
_ if code.len() == 3 && code.starts_with('x') => {
if let Ok(value) = u8::from_str_radix(&code[1..], 16) {
out.push(value);
i = end + 1;
continue;
}
}
_ => {}
}
}
out.push(bytes[i]);
i += 1;
}
String::from_utf8(out).unwrap_or_else(|_| name.to_owned())
}
#[cfg(test)]
#[test]
fn tool_name_roundtrip() {
    // Names that are already wire-safe must pass through untouched.
    for plain in ["read_file", "shell", "video_gen", "douyin__publish"] {
        assert_eq!(sanitize_tool_name(plain), plain);
    }
    assert_eq!(restore_tool_name("read_file"), "read_file");

    // Dots and underscores are escaped once the name needs encoding.
    let dotted = "wechat.send_text";
    let dotted_wire = "rc_wechat_d_send_u_text";
    assert_eq!(sanitize_tool_name(dotted), dotted_wire);
    assert_eq!(restore_tool_name(dotted_wire), dotted);

    // Other characters are hex-escaped byte by byte.
    assert_eq!(sanitize_tool_name("ns:tool"), "rc_ns_x3A_tool");

    // A name that already starts with the marker prefix still round-trips.
    let prefixed = "rc_weird";
    assert_eq!(restore_tool_name(&sanitize_tool_name(prefixed)), prefixed);
}
#[cfg(test)]
#[test]
fn tool_name_roundtrip_preserves_double_underscore() {
    // Consecutive underscores must survive an encode/decode cycle.
    let name = "plugin__alpha.send__text";
    let decoded = restore_tool_name(&sanitize_tool_name(name));
    assert_eq!(decoded, name);
}
/// Selects which HTTP API an [`OpenAiProvider`] talks to.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OpenAiMode {
    /// Requests go to `{base_url}/chat/completions`.
    Chat,
    /// Requests go to `{base_url}/responses`.
    Responses,
}
/// LLM provider for OpenAI-style HTTP APIs (chat completions, Responses, and
/// the Ollama native chat endpoint).
pub struct OpenAiProvider {
// HTTP client used for every request issued by this provider.
client: reqwest::Client,
// Sent as `Authorization: Bearer <key>` when present; omitted otherwise.
api_key: Option<String>,
// API root; a trailing `/` is trimmed before endpoint paths are appended.
base_url: String,
// When true and `base_url` does not contain "/v1", `stream` uses the
// Ollama native `/api/chat` endpoint instead of the OpenAI-style routes.
is_ollama: bool,
// Chooses between the chat-completions and Responses request paths.
mode: OpenAiMode,
}
impl OpenAiProvider {
pub fn new(api_key: impl Into<String>) -> Self {
Self {
client: super::http_client(),
api_key: Some(api_key.into()),
base_url: OPENAI_API_BASE.to_owned(),
is_ollama: false,
mode: OpenAiMode::Chat,
}
}
pub fn with_base_url(base_url: impl Into<String>, api_key: Option<String>) -> Self {
Self {
client: super::http_client(),
api_key,
base_url: base_url.into(),
is_ollama: false,
mode: OpenAiMode::Chat,
}
}
pub fn chat(base_url: impl Into<String>, api_key: Option<String>) -> Self {
Self {
client: super::http_client(),
api_key,
base_url: base_url.into(),
is_ollama: false,
mode: OpenAiMode::Chat,
}
}
pub fn responses(base_url: impl Into<String>, api_key: Option<String>) -> Self {
Self {
client: super::http_client(),
api_key,
base_url: base_url.into(),
is_ollama: false,
mode: OpenAiMode::Responses,
}
}
pub fn ollama(base_url: impl Into<String>, api_key: Option<String>) -> Self {
Self {
client: super::http_client(),
api_key,
base_url: base_url.into(),
is_ollama: true,
mode: OpenAiMode::Chat,
}
}
pub fn with_user_agent(
base_url: impl Into<String>,
api_key: Option<String>,
user_agent: Option<String>,
) -> Self {
Self {
client: super::http_client_with_ua(user_agent.as_deref()),
api_key,
base_url: base_url.into(),
is_ollama: false,
mode: OpenAiMode::Chat,
}
}
pub fn responses_with_ua(
base_url: impl Into<String>,
api_key: Option<String>,
user_agent: Option<String>,
) -> Self {
Self {
client: super::http_client_with_ua(user_agent.as_deref()),
api_key,
base_url: base_url.into(),
is_ollama: false,
mode: OpenAiMode::Responses,
}
}
pub fn ollama_with_ua(
base_url: impl Into<String>,
api_key: Option<String>,
user_agent: Option<String>,
) -> Self {
Self {
client: super::http_client_with_ua(user_agent.as_deref()),
api_key,
base_url: base_url.into(),
is_ollama: true,
mode: OpenAiMode::Chat,
}
}
}
impl LlmProvider for OpenAiProvider {
/// Provider identifier: "openai" in chat mode, "openai-responses" in
/// Responses mode.
fn name(&self) -> &str {
match self.mode {
OpenAiMode::Chat => "openai",
OpenAiMode::Responses => "openai-responses",
}
}
/// Sends `req` and returns a stream of [`StreamEvent`]s.
///
/// Dispatch order: Ollama native (Ollama flag set and no "/v1" in the base
/// URL), then the Responses API, otherwise `{base_url}/chat/completions`.
///
/// # Errors
/// Fails when the request body cannot be built, the HTTP request fails, the
/// server answers with a non-success status, or a non-streaming JSON reply
/// carries no text at `/choices/0/message/content`.
fn stream(&self, req: LlmRequest) -> BoxFuture<'_, Result<LlmStream>> {
Box::pin(async move {
tracing::info!(
model = %req.model,
max_tokens = ?req.max_tokens,
thinking_budget = ?req.thinking_budget,
"openai: preparing LLM request"
);
// NOTE(review): helper is defined in the parent module; presumably it
// only logs a warning — confirm.
super::warn_unsupported_kv_cache_mode_2(self.name(), &req);
// Ollama servers addressed without the OpenAI-compatible "/v1" path
// are served through the native endpoint.
if self.is_ollama && !self.base_url.contains("/v1") {
return self.stream_ollama_native(&req).await;
}
if self.mode == OpenAiMode::Responses {
return self.stream_responses(&req).await;
}
let body = build_request_body(&req)?;
// Serialized only to report the body size in the debug log below.
let body_str = serde_json::to_string(&body).unwrap_or_default();
tracing::debug!(
model = %req.model,
tools_count = req.tools.len(),
has_tools_in_body = body.get("tools").is_some(),
body_len = body_str.len(),
"openai: request prepared"
);
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
let mut builder = self
.client
.post(url)
.header("content-type", "application/json");
if let Some(key) = &self.api_key {
builder = builder.header("authorization", format!("Bearer {key}"));
}
let resp = super::send_with_transport_retry(builder.json(&body))
.await
.context("OpenAI request failed")?;
let status = resp.status();
// Captured before the body is consumed; decides between the JSON and
// the SSE handling below.
let content_type = resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_owned();
tracing::info!(
%status,
content_type = %content_type,
"openai: response received"
);
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("OpenAI API error {status}: {body}");
}
// Some servers answer with a plain JSON document even though streaming
// was requested; turn its text into a two-event stream (text + done).
// Only the message text is read on this path.
if !content_type.contains("text/event-stream") && content_type.contains("json") {
let body: serde_json::Value =
resp.json().await.context("OpenAI: parse JSON response")?;
tracing::info!(body = %body.to_string().chars().take(500).collect::<String>(), "openai: non-streaming JSON response");
let text = body
.pointer("/choices/0/message/content")
.and_then(|v| v.as_str())
.unwrap_or("");
if !text.is_empty() {
let stream = futures::stream::iter(vec![
Ok(StreamEvent::TextDelta(text.to_owned())),
Ok(StreamEvent::Done { usage: None }),
]);
let llm_stream: LlmStream = Box::pin(stream);
return Ok(llm_stream);
}
anyhow::bail!(
"OpenAI: empty non-streaming response: {}",
body.to_string().chars().take(500).collect::<String>()
);
}
let byte_stream = resp.bytes_stream();
// State shared across chunks: text after the last newline seen so far,
// and trailing bytes that did not yet form valid UTF-8.
let line_buffer = std::sync::Arc::new(tokio::sync::Mutex::new(String::new()));
let utf8_remainder = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::<u8>::new()));
let event_stream = byte_stream
.map_err(|e| anyhow::anyhow!("stream read error: {e}"))
.then(move |chunk| {
let line_buffer = line_buffer.clone();
let utf8_remainder = utf8_remainder.clone();
async move {
parse_sse_chunk_with_buffer(chunk, &line_buffer, &utf8_remainder).await
}
})
// Each chunk yields zero or more events; flatten them into one stream.
.flat_map(|events| futures::stream::iter(events));
let stream: LlmStream = Box::pin(event_stream);
Ok(stream)
})
}
}
impl OpenAiProvider {
async fn stream_ollama_native(&self, req: &LlmRequest) -> Result<LlmStream> {
let base = self.base_url.trim_end_matches('/');
let url = format!("{base}/api/chat");
let mut messages: Vec<Value> = Vec::new();
if let Some(ref sys) = req.system {
messages.push(json!({"role": "system", "content": sys}));
}
let thinking_enabled = matches!(req.thinking_budget, Some(b) if b > 0);
for msg in &req.messages {
let mut m = serialize_message(msg, thinking_enabled);
if let Some(tcs) = m.get_mut("tool_calls").and_then(|v| v.as_array_mut()) {
for tc in tcs {
if let Some(args) = tc.pointer_mut("/function/arguments") {
if let Some(s) = args.as_str() {
if let Ok(parsed) = serde_json::from_str::<Value>(s) {
*args = parsed;
}
}
}
}
}
if let Some(content) = m.get("content") {
if content.is_array() {
let empty = Vec::new();
let parts = content.as_array().unwrap_or(&empty);
let mut texts = Vec::new();
let mut images = Vec::new();
for p in parts {
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
texts.push(t.to_owned());
} else if let Some(url) =
p.pointer("/image_url/url").and_then(|v| v.as_str())
{
let b64 = url.split(",").last().unwrap_or(url);
images.push(json!(b64));
}
}
m["content"] = json!(texts.join("\n"));
if !images.is_empty() {
m["images"] = json!(images);
}
}
}
messages.push(m);
}
let mut body = json!({
"model": req.model,
"messages": messages,
"stream": true,
"think": false,
});
let mut options = serde_json::Map::new();
if let Some(t) = req.temperature {
options.insert("temperature".into(), super::json_f32(t));
}
if let Some(max) = req.max_tokens {
if max > 0 {
options.insert("num_predict".into(), json!(max));
}
}
if !options.is_empty() {
body["options"] = Value::Object(options);
}
if !req.tools.is_empty() {
let tools: Vec<Value> = req
.tools
.iter()
.map(|t| {
json!({
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.parameters,
}
})
})
.collect();
body["tools"] = json!(tools);
}
if let Some(msgs) = body["messages"].as_array_mut() {
normalize_messages_for_cache(msgs);
}
tracing::debug!(
tools_count = req.tools.len(),
think = req.tools.is_empty(),
"ollama native: calling {url}"
);
let resp = super::send_with_transport_retry(
self.client
.post(&url)
.header("content-type", "application/json")
.json(&body),
)
.await
.context("ollama native request failed")?;
if !resp.status().is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("ollama native API error: {body}");
}
let byte_stream = resp.bytes_stream();
let in_thinking = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
let event_stream = byte_stream
.map_err(|e| anyhow::anyhow!("stream read error: {e}"))
.flat_map(move |chunk| {
let in_thinking = std::sync::Arc::clone(&in_thinking);
let events: Vec<Result<StreamEvent>> = match chunk {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
text.lines()
.filter_map(|line| {
let line = line.trim();
if line.is_empty() {
return None;
}
let v: Value = serde_json::from_str(line).ok()?;
if let Some(tc) = v
.get("message")
.and_then(|m| m.get("tool_calls"))
.and_then(|tc| tc.as_array())
.and_then(|a| a.first())
{
let func = &tc["function"];
let name = func["name"].as_str().unwrap_or("").to_owned();
let input = if func["arguments"].is_object() {
func["arguments"].clone()
} else {
let args_str = func["arguments"].as_str().unwrap_or("{}");
serde_json::from_str(args_str).unwrap_or(json!({}))
};
return Some(Ok(StreamEvent::ToolCall {
id: format!("call_{}", name),
name,
input,
}));
}
let thinking = v
.pointer("/message/thinking")
.and_then(|c| c.as_str())
.unwrap_or("");
let content = v
.pointer("/message/content")
.and_then(|c| c.as_str())
.unwrap_or("");
let done = v["done"].as_bool().unwrap_or(false);
if done {
if in_thinking.swap(false, std::sync::atomic::Ordering::Relaxed)
{
return Some(Ok(StreamEvent::TextDelta(
"</think>".to_owned(),
)));
}
return Some(Ok(StreamEvent::Done { usage: None }));
}
if !thinking.is_empty() {
let was_thinking = in_thinking
.swap(true, std::sync::atomic::Ordering::Relaxed);
if !was_thinking {
return Some(Ok(StreamEvent::TextDelta(format!(
"<think>{thinking}"
))));
}
return Some(Ok(StreamEvent::TextDelta(thinking.to_owned())));
}
if !content.is_empty() {
let was_thinking = in_thinking
.swap(false, std::sync::atomic::Ordering::Relaxed);
if was_thinking {
return Some(Ok(StreamEvent::TextDelta(format!(
"</think>{content}"
))));
}
Some(Ok(StreamEvent::TextDelta(content.to_owned())))
} else {
None
}
})
.collect()
}
Err(e) => vec![Err(e)],
};
futures::stream::iter(events)
});
Ok(Box::pin(event_stream))
}
async fn upload_image_to_files(&self, data_uri: &str) -> Result<String> {
use base64::Engine;
let rest = data_uri
.strip_prefix("data:")
.ok_or_else(|| anyhow::anyhow!("invalid data URI: missing data: prefix"))?;
let (meta, b64_data) = rest
.split_once(',')
.ok_or_else(|| anyhow::anyhow!("invalid data URI: missing comma"))?;
let mime_type = meta.split(';').next().unwrap_or("image/png");
let ext = match mime_type {
"image/jpeg" | "image/jpg" => "jpg",
"image/png" => "png",
"image/gif" => "gif",
"image/webp" => "webp",
"video/mp4" => "mp4",
"video/quicktime" => "mov",
"video/webm" => "webm",
"video/x-msvideo" => "avi",
_ => "bin",
};
let image_bytes = base64::engine::general_purpose::STANDARD
.decode(b64_data)
.context("failed to decode base64 image data")?;
let filename = format!("upload.{ext}");
let file_part = reqwest::multipart::Part::bytes(image_bytes)
.file_name(filename)
.mime_str(mime_type)
.context("invalid mime type")?;
let form = reqwest::multipart::Form::new()
.text("purpose", "user_data")
.part("file", file_part);
let url = format!("{}/files", self.base_url.trim_end_matches('/'));
let mut builder = self.client.post(&url);
if let Some(key) = &self.api_key {
builder = builder.header("authorization", format!("Bearer {key}"));
}
let resp = builder
.multipart(form)
.timeout(std::time::Duration::from_secs(500))
.send()
.await
.context("Files API upload request failed")?;
if !resp.status().is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("Files API upload error: {body}");
}
let body: Value = resp.json().await.context("Files API: parse response")?;
let file_id = body["id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Files API response missing id field: {body}"))?
.to_owned();
tracing::debug!(file_id = %file_id, "uploaded image to Files API");
Ok(file_id)
}
/// Uploads every distinct inline (`data:`) image found in `messages` and
/// returns a map from the original data URI to the uploaded file id.
///
/// Uploads run one after another in first-seen order. A failed upload is
/// logged and simply left out of the map.
async fn upload_images_for_messages(&self, messages: &[Message]) -> HashMap<String, String> {
    // Distinct data URIs, in the order they first appear.
    let mut pending: Vec<String> = Vec::new();
    let inline_images = messages
        .iter()
        .filter_map(|msg| match &msg.content {
            MessageContent::Parts(parts) => Some(parts),
            _ => None,
        })
        .flatten()
        .filter_map(|part| match part {
            ContentPart::Image { url } if url.starts_with("data:") => Some(url),
            _ => None,
        });
    for url in inline_images {
        if !pending.contains(url) {
            pending.push(url.clone());
        }
    }
    tracing::info!(
        count = pending.len(),
        "upload_images: found data URIs to upload"
    );

    let mut uploaded = HashMap::new();
    for uri in pending {
        let outcome = self.upload_image_to_files(&uri).await;
        match outcome {
            Err(e) => {
                tracing::warn!(error = %e, "failed to upload image to Files API, falling back to base64 inline");
            }
            Ok(file_id) => {
                uploaded.insert(uri, file_id);
            }
        }
    }
    uploaded
}
/// Sends `req` to `{base_url}/responses` and returns the resulting event
/// stream.
///
/// Inline images are uploaded first so the request body can refer to them
/// through the returned id map.
///
/// # Errors
/// Fails when the body cannot be built, the HTTP request fails, the server
/// answers with a non-success status, or a non-streaming JSON reply has no
/// text at `/output/0/content/0/text`.
async fn stream_responses(&self, req: &LlmRequest) -> Result<LlmStream> {
let file_id_map = self.upload_images_for_messages(&req.messages).await;
let body = build_responses_body(req, &file_id_map)?;
// Serialized only to report the body size in the debug log below.
let body_str = serde_json::to_string(&body).unwrap_or_default();
tracing::debug!(
model = %req.model,
tools_count = req.tools.len(),
body_len = body_str.len(),
"openai-responses: request prepared"
);
let url = format!("{}/responses", self.base_url.trim_end_matches('/'));
let mut builder = self
.client
.post(&url)
.header("content-type", "application/json");
if let Some(key) = &self.api_key {
builder = builder.header("authorization", format!("Bearer {key}"));
}
let resp = super::send_with_transport_retry(builder.json(&body))
.await
.context("OpenAI Responses request failed")?;
let status = resp.status();
// Captured before the body is consumed; decides between the JSON and the
// SSE handling below.
let content_type = resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_owned();
tracing::debug!(
%status,
content_type = %content_type,
"openai-responses: response received"
);
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("OpenAI Responses API error {status}: {body}");
}
// A plain JSON reply (no event stream) is turned into a two-event stream:
// the text of the first output item, then `Done`.
if !content_type.contains("text/event-stream") && content_type.contains("json") {
let body: Value = resp
.json()
.await
.context("OpenAI Responses: parse JSON response")?;
tracing::debug!(
body = %body.to_string().chars().take(300).collect::<String>(),
"openai-responses: non-streaming JSON response"
);
let text = body
.pointer("/output/0/content/0/text")
.and_then(|v| v.as_str())
.unwrap_or("");
if !text.is_empty() {
let stream = futures::stream::iter(vec![
Ok(StreamEvent::TextDelta(text.to_owned())),
Ok(StreamEvent::Done { usage: None }),
]);
return Ok(Box::pin(stream) as LlmStream);
}
anyhow::bail!(
"OpenAI Responses: empty non-streaming response: {}",
body.to_string().chars().take(500).collect::<String>()
);
}
let byte_stream = resp.bytes_stream();
// State shared across chunks and handed to the chunk parser on every call.
let line_buffer = std::sync::Arc::new(tokio::sync::Mutex::new(String::new()));
let utf8_remainder = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::<u8>::new()));
let event_stream = byte_stream
.map_err(|e| anyhow::anyhow!("stream read error: {e}"))
.then(move |chunk| {
let line_buffer = line_buffer.clone();
let utf8_remainder = utf8_remainder.clone();
async move {
parse_responses_sse_chunk_buffered(chunk, &line_buffer, &utf8_remainder).await
}
})
// Each chunk yields zero or more events; flatten them into one stream.
.flat_map(|events| futures::stream::iter(events));
let stream: LlmStream = Box::pin(event_stream);
Ok(stream)
}
}
/// Builds the JSON body for a streaming `/chat/completions` request.
///
/// Steps, in order: serialize the messages, strip any `provider/` prefix from
/// the model name, set `max_tokens` when positive, add the thinking-related
/// switches, merge system text into one leading system message, add sampling
/// parameters and tools, then normalize, pair and reorder the messages.
fn build_request_body(req: &LlmRequest) -> Result<Value> {
let thinking_enabled = matches!(req.thinking_budget, Some(b) if b > 0);
let messages: Vec<Value> = req
.messages
.iter()
.map(|m| serialize_message(m, thinking_enabled))
.collect();
// Keep only what follows the last '/' of the model identifier.
let bare_model = req
.model
.rsplit_once('/')
.map(|(_, m)| m)
.unwrap_or(&req.model);
let mut body = json!({
"model": bare_model,
"stream": true,
"messages": messages,
});
if let Some(max_tokens) = req.max_tokens {
if max_tokens > 0 {
body["max_tokens"] = json!(max_tokens);
}
}
let model_lower = req.model.to_lowercase();
let is_minimax = model_lower.contains("minimax");
let is_vision = req.endpoint == super::AgentEndpoint::Vision;
// MiniMax models only get `reasoning_split`; vision requests get no
// thinking switches at all.
if is_minimax {
body["reasoning_split"] = json!(true);
} else if !is_vision {
// NOTE(review): both arms send `thinking: {"type": "disabled"}`, even
// when a positive budget turns `enable_thinking` on — confirm this is
// intentional and not a copy-paste slip.
match req.thinking_budget {
Some(budget) if budget > 0 => {
body["enable_thinking"] = json!(true);
body["thinking"] = json!({"type": "disabled"});
body["thinking_budget"] = json!(budget);
body["chat_template_kwargs"] = json!({"enable_thinking": true});
}
_ => {
body["enable_thinking"] = json!(false);
body["thinking"] = json!({"type": "disabled"});
body["chat_template_kwargs"] = json!({"enable_thinking": false});
}
}
}
// With a request-level system prompt, fold it and every system message
// that has non-empty string content into a single leading system message
// (joined with newlines). System messages with non-string content are
// left out of the merged text.
if let Some(sys) = &req.system {
let mut msgs = vec![json!({"role": "system", "content": sys})];
msgs.extend(body["messages"].as_array().cloned().unwrap_or_default());
let mut system_parts: Vec<String> = Vec::new();
let mut non_system: Vec<Value> = Vec::new();
for msg in &msgs {
if msg["role"].as_str() == Some("system") {
if let Some(c) = msg["content"].as_str() {
if !c.is_empty() {
system_parts.push(c.to_owned());
}
}
} else {
non_system.push(msg.clone());
}
}
let mut merged = vec![json!({"role": "system", "content": system_parts.join("\n")})];
merged.extend(non_system);
body["messages"] = json!(merged);
}
if let Some(t) = req.temperature {
body["temperature"] = super::json_f32(t);
}
// Only a strictly positive frequency penalty is forwarded.
if let Some(fp) = req.frequency_penalty {
if fp > 0.0 {
body["frequency_penalty"] = super::json_f32(fp);
}
}
if !req.tools.is_empty() {
let tools: Vec<Value> = req
.tools
.iter()
.map(|t| {
json!({
"type": "function",
"function": {
"name": sanitize_tool_name(&t.name),
"description": t.description,
"parameters": t.parameters,
}
})
})
.collect();
body["tools"] = json!(tools);
// Models whose name contains one of these markers get an explicit
// `tool_choice: "auto"`.
let needs_explicit_tool_choice = {
let m = req.model.to_lowercase();
m.contains("kimi") || m.contains("moonshot") || m.contains("k1.5") || m.contains("k2")
};
if needs_explicit_tool_choice {
body["tool_choice"] = json!("auto");
}
}
// Final clean-up of the message list: canonical text/arguments, drop
// unmatched tool calls/results, then place results after their call.
if let Some(msgs) = body["messages"].as_array_mut() {
normalize_messages_for_cache(msgs);
fix_tool_call_pairing(msgs);
reorder_tool_messages(msgs);
}
Ok(body)
}
fn normalize_messages_for_cache(messages: &mut [Value]) {
for msg in messages.iter_mut() {
if let Some(content) = msg
.get_mut("content")
.and_then(|v| v.as_str())
.map(|s| s.trim().to_owned())
{
msg["content"] = json!(content);
}
if let Some(tcs) = msg.get_mut("tool_calls").and_then(|v| v.as_array_mut()) {
for tc in tcs.iter_mut() {
if let Some(args_str) = tc.pointer("/function/arguments").and_then(|v| v.as_str()) {
if let Ok(parsed) = serde_json::from_str::<Value>(args_str) {
if let Ok(canonical) = serde_json::from_str::<
std::collections::BTreeMap<String, Value>,
>(&parsed.to_string())
{
if let Ok(sorted) = serde_json::to_string(&canonical) {
tc["function"]["arguments"] = json!(sorted);
}
}
}
}
}
}
}
}
/// Removes tool calls and tool results that have no counterpart.
///
/// * A `tool` message whose `tool_call_id` matches no assistant tool call is
///   dropped; `tool` messages without an id are kept.
/// * An assistant tool call whose `id` has no `tool` result is removed; calls
///   without an id are kept. When no calls remain, the `tool_calls` key is
///   removed from the message.
fn fix_tool_call_pairing(messages: &mut Vec<Value>) {
    use std::collections::HashSet;

    fn has_role(msg: &Value, role: &str) -> bool {
        msg.get("role").and_then(|r| r.as_str()) == Some(role)
    }

    // Ids announced by assistant tool calls / ids answered by tool messages.
    let mut declared: HashSet<String> = HashSet::new();
    let mut answered: HashSet<String> = HashSet::new();
    for msg in messages.iter() {
        if has_role(msg, "assistant") {
            if let Some(calls) = msg.get("tool_calls").and_then(|v| v.as_array()) {
                declared.extend(
                    calls
                        .iter()
                        .filter_map(|call| call.get("id").and_then(|v| v.as_str()))
                        .map(str::to_owned),
                );
            }
        } else if has_role(msg, "tool") {
            if let Some(id) = msg.get("tool_call_id").and_then(|v| v.as_str()) {
                answered.insert(id.to_owned());
            }
        }
    }

    messages.retain(|msg| {
        if !has_role(msg, "tool") {
            return true;
        }
        match msg.get("tool_call_id").and_then(|v| v.as_str()) {
            Some(id) => declared.contains(id),
            None => true,
        }
    });

    for msg in messages.iter_mut() {
        if !has_role(msg, "assistant") {
            continue;
        }
        let Some(calls) = msg.get_mut("tool_calls").and_then(|v| v.as_array_mut()) else {
            continue;
        };
        calls.retain(|call| {
            call.get("id")
                .and_then(|v| v.as_str())
                .map_or(true, |id| answered.contains(id))
        });
        let emptied = calls.is_empty();
        if emptied {
            if let Some(obj) = msg.as_object_mut() {
                obj.remove("tool_calls");
            }
        }
    }
}
/// Moves every `tool` message so it directly follows the assistant message
/// that issued the matching tool call.
///
/// Results are emitted in the order of the assistant's `tool_calls` ids.
/// `tool` messages without a `tool_call_id`, and results whose id is never
/// referenced by any assistant message, are dropped (the latter is logged at
/// debug level).
fn reorder_tool_messages(messages: &mut Vec<Value>) {
    let (tool_msgs, others): (Vec<Value>, Vec<Value>) = messages
        .drain(..)
        .partition(|msg| msg.get("role").and_then(|r| r.as_str()) == Some("tool"));

    // Results grouped by the call id they answer, preserving arrival order.
    let mut results_by_call: std::collections::HashMap<String, Vec<Value>> =
        std::collections::HashMap::new();
    for msg in tool_msgs {
        let call_id = msg
            .get("tool_call_id")
            .and_then(|v| v.as_str())
            .map(str::to_owned);
        if let Some(id) = call_id {
            results_by_call.entry(id).or_default().push(msg);
        }
    }

    for msg in others {
        let is_assistant = msg.get("role").and_then(|r| r.as_str()) == Some("assistant");
        let call_ids: Vec<String> = if is_assistant {
            msg.get("tool_calls")
                .and_then(|v| v.as_array())
                .map(|calls| {
                    calls
                        .iter()
                        .filter_map(|tc| tc.get("id").and_then(|v| v.as_str()).map(String::from))
                        .collect()
                })
                .unwrap_or_default()
        } else {
            Vec::new()
        };
        messages.push(msg);
        for cid in &call_ids {
            if let Some(results) = results_by_call.remove(cid) {
                messages.extend(results);
            }
        }
    }

    if !results_by_call.is_empty() {
        tracing::debug!(
            orphaned = results_by_call.len(),
            "reorder_tool_messages: dropped orphaned tool results"
        );
    }
}
fn serialize_message(msg: &Message, thinking_enabled: bool) -> Value {
let role_str = match msg.role {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
};
if msg.role == Role::Tool {
if let MessageContent::Parts(parts) = &msg.content {
for part in parts {
if let ContentPart::ToolResult {
tool_use_id,
content,
..
} = part
{
return json!({
"role": "tool",
"tool_call_id": tool_use_id,
"content": content,
});
}
}
}
let text = match &msg.content {
MessageContent::Text(t) => t.clone(),
MessageContent::Parts(_) => String::new(),
};
return json!({ "role": "tool", "content": text });
}
if msg.role == Role::Assistant {
if let MessageContent::Parts(parts) = &msg.content {
let mut text_parts = Vec::new();
let mut reasoning_parts = Vec::new();
let mut tool_calls = Vec::new();
for part in parts {
match part {
ContentPart::ToolUse { id, name, input } => {
tool_calls.push(json!({
"id": id,
"type": "function",
"function": {
"name": sanitize_tool_name(name),
"arguments": input.to_string()
}
}));
}
ContentPart::Text { text } => text_parts.push(text.clone()),
ContentPart::Reasoning { text } => reasoning_parts.push(text.clone()),
_ => {}
}
}
let text = text_parts.join("");
let reasoning = reasoning_parts.join("");
let needs_obj = !tool_calls.is_empty() || !reasoning.is_empty() || thinking_enabled;
if needs_obj {
let mut obj = json!({
"role": "assistant",
"content": text,
});
if let Some(obj_map) = obj.as_object_mut() {
if !tool_calls.is_empty() {
obj_map.insert("tool_calls".to_owned(), json!(tool_calls));
}
if thinking_enabled || !reasoning.is_empty() {
obj_map.insert("reasoning_content".to_owned(), json!(reasoning));
}
}
return obj;
}
}
if thinking_enabled && let MessageContent::Text(t) = &msg.content {
return json!({
"role": "assistant",
"content": t,
"reasoning_content": "",
});
}
}
let content = match &msg.content {
MessageContent::Text(t) => json!(t),
MessageContent::Parts(parts) => {
let serialized: Vec<Value> = parts.iter().map(serialize_part).collect();
json!(serialized)
}
};
json!({ "role": role_str, "content": content })
}
/// Serializes a single content part into its chat-completions JSON shape.
fn serialize_part(part: &ContentPart) -> Value {
    match part {
        ContentPart::Text { text } => {
            json!({ "type": "text", "text": text })
        }
        ContentPart::Image { url } => {
            let image_url = json!({ "url": url });
            json!({ "type": "image_url", "image_url": image_url })
        }
        ContentPart::ToolUse { id, name, input } => {
            // Arguments travel as JSON text, the tool name in its wire-safe form.
            let function = json!({
                "name": sanitize_tool_name(name),
                "arguments": input.to_string()
            });
            json!({ "type": "function", "id": id, "function": function })
        }
        ContentPart::ToolResult {
            tool_use_id,
            content,
            ..
        } => {
            json!({ "role": "tool", "tool_call_id": tool_use_id, "content": content })
        }
        ContentPart::Reasoning { text } => {
            json!({ "type": "reasoning", "reasoning": text })
        }
    }
}
/// Parses one network chunk of a chat-completions SSE stream.
///
/// State carried between calls:
/// * `utf8_remainder` — trailing bytes of a multi-byte character that was cut
///   off at the end of a chunk;
/// * `line_buffer` — text received after the last newline.
///
/// Only complete lines are interpreted. `data: [DONE]` yields
/// `Done { usage: None }`; every other `data:` payload goes through
/// `parse_event`. Lines without a `data:` prefix are ignored.
///
/// Bytes that can never become valid UTF-8 are replaced with U+FFFD instead
/// of being buffered: keeping them would block decoding of every later chunk.
/// A transport error is returned as a single `Err` item.
async fn parse_sse_chunk_with_buffer(
    chunk: Result<bytes::Bytes>,
    line_buffer: &tokio::sync::Mutex<String>,
    utf8_remainder: &tokio::sync::Mutex<Vec<u8>>,
) -> Vec<Result<StreamEvent>> {
    let bytes = match chunk {
        Ok(b) => b,
        Err(e) => return vec![Err(e)],
    };

    let text = {
        let mut remainder = utf8_remainder.lock().await;
        // Prepend whatever was left over from the previous chunk.
        let mut pending = std::mem::take(&mut *remainder);
        pending.extend_from_slice(&bytes);

        let mut decoded = String::with_capacity(pending.len());
        let mut rest: &[u8] = &pending;
        loop {
            match std::str::from_utf8(rest) {
                Ok(valid) => {
                    decoded.push_str(valid);
                    break;
                }
                Err(err) => {
                    let (valid, tail) = rest.split_at(err.valid_up_to());
                    // `valid` is well-formed by construction, so this borrows.
                    decoded.push_str(&String::from_utf8_lossy(valid));
                    match err.error_len() {
                        // Input ended mid-character: wait for the next chunk.
                        None => {
                            *remainder = tail.to_vec();
                            break;
                        }
                        // Invalid sequence: substitute it and keep decoding.
                        Some(bad) => {
                            decoded.push('\u{FFFD}');
                            rest = &tail[bad..];
                        }
                    }
                }
            }
        }
        decoded
    };
    if text.is_empty() {
        return Vec::new();
    }

    let complete = {
        let mut buffer = line_buffer.lock().await;
        buffer.push_str(&text);
        let Some(last_newline_pos) = buffer.rfind('\n') else {
            return Vec::new();
        };
        // Keep the unfinished tail for the next call, take the rest.
        let tail = buffer.split_off(last_newline_pos + 1);
        std::mem::replace(&mut *buffer, tail)
    };

    let mut events = Vec::new();
    for line in complete.lines() {
        let Some(data) = line.strip_prefix("data:").map(|s| s.trim_start()) else {
            continue;
        };
        if data == "[DONE]" {
            events.push(Ok(StreamEvent::Done { usage: None }));
            continue;
        }
        let parsed = parse_event(data);
        if parsed.is_empty() {
            tracing::debug!(data, "openai: unparsed SSE data");
        } else {
            events.extend(parsed.into_iter().map(Ok));
        }
    }
    events
}
/// Stateless variant of the SSE chunk parser: decodes the chunk lossily and
/// interprets every `data:` line it contains.
///
/// A chunk with visible text but no `data:` line at all is reported with a
/// warning. A transport error is returned as a single `Err` item.
#[allow(dead_code)]
fn parse_sse_chunk(chunk: Result<bytes::Bytes>) -> Vec<Result<StreamEvent>> {
    let raw = match chunk {
        Err(e) => return vec![Err(e)],
        Ok(b) => b,
    };
    let text = String::from_utf8_lossy(&raw);

    let mut events = Vec::new();
    let mut saw_data = false;
    let payloads = text
        .lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .map(str::trim_start);
    for data in payloads {
        saw_data = true;
        if data == "[DONE]" {
            events.push(Ok(StreamEvent::Done { usage: None }));
            continue;
        }
        let parsed = parse_event(data);
        if parsed.is_empty() {
            tracing::debug!(data, "openai: unparsed SSE data");
        } else {
            events.extend(parsed.into_iter().map(Ok));
        }
    }

    if !saw_data && !text.trim().is_empty() {
        tracing::warn!(
            raw = rsclaw_util::truncate_str(&text, 500),
            "openai: non-SSE chunk received"
        );
    }
    events
}
/// Public entry point that forwards to the private `strip_think_tags` and
/// returns its result unchanged.
pub fn strip_think_tags_pub(text: &str) -> String {
strip_think_tags(text)
}
/// Removes `<think>…</think>` sections from `text`.
///
/// Closed sections are cut out one at a time, always starting from the first
/// `<think>` in the current text. An opening tag without a closing tag
/// discards everything from that tag to the end. Any `</think>` left over
/// afterwards is removed as well.
fn strip_think_tags(text: &str) -> String {
    const OPEN: &str = "<think>";
    const CLOSE: &str = "</think>";

    let mut cleaned = String::from(text);
    loop {
        let Some(open_at) = cleaned.find(OPEN) else {
            break;
        };
        let body_at = open_at + OPEN.len();
        match cleaned[body_at..].find(CLOSE) {
            Some(rel_close) => {
                let after_close = body_at + rel_close + CLOSE.len();
                cleaned.replace_range(open_at..after_close, "");
            }
            None => {
                cleaned.truncate(open_at);
                break;
            }
        }
    }
    cleaned.replace(CLOSE, "")
}
/// Converts the JSON payload of one chat-completions SSE `data:` line into
/// stream events.
///
/// * An `error` object yields a single `Error` with its `message`
///   ("unknown API error" when absent).
/// * From the first choice's `delta`: non-empty `reasoning_content` yields
///   `ReasoningDelta`, non-empty `content` yields `TextDelta`, and **every**
///   entry of `tool_calls` yields a `ToolCall` (previously only the first
///   entry was forwarded, so additional calls in the same delta were lost).
///   Argument fragments are passed on as a JSON string; an empty fragment
///   becomes an empty object.
/// * A string `finish_reason` yields `Done`, carrying token usage when the
///   payload has a `usage` object.
///
/// Payloads that are not valid JSON, or that have no (or an empty) `choices`
/// array, are logged and yield no events.
fn parse_event(data: &str) -> Vec<StreamEvent> {
    let v: Value = match serde_json::from_str(data) {
        Ok(v) => v,
        Err(e) => {
            tracing::warn!(data, error = %e, "openai: failed to parse SSE JSON");
            return Vec::new();
        }
    };
    if let Some(err) = v.get("error") {
        let msg = err["message"].as_str().unwrap_or("unknown API error");
        return vec![StreamEvent::Error(msg.to_owned())];
    }
    let Some(choices) = v["choices"].as_array() else {
        tracing::warn!(data, "openai: SSE response missing choices array");
        return Vec::new();
    };
    let Some(choice) = choices.first() else {
        tracing::warn!(data, "openai: SSE response has empty choices array");
        return Vec::new();
    };

    let delta = &choice["delta"];
    let mut events = Vec::new();
    if let Some(text) = delta["reasoning_content"]
        .as_str()
        .filter(|s| !s.is_empty())
    {
        events.push(StreamEvent::ReasoningDelta(text.to_owned()));
    }
    if let Some(text) = delta["content"].as_str().filter(|s| !s.is_empty()) {
        events.push(StreamEvent::TextDelta(text.to_owned()));
    }
    if let Some(tool_calls) = delta["tool_calls"].as_array() {
        for tc in tool_calls {
            let func = &tc["function"];
            let id = tc["id"].as_str().unwrap_or("").to_owned();
            let name = func["name"].as_str().unwrap_or("").to_owned();
            let args_str = func["arguments"].as_str().unwrap_or("");
            let input = if args_str.is_empty() {
                Value::Object(Default::default())
            } else {
                // Fragments are forwarded verbatim as a JSON string.
                Value::String(args_str.to_owned())
            };
            tracing::debug!(id = %id, name = %name, args_len = args_str.len(), "openai: tool call chunk");
            events.push(StreamEvent::ToolCall { id, name, input });
        }
    }
    if choice["finish_reason"].is_string() {
        let usage = v["usage"].as_object().map(|u| TokenUsage {
            input: u.get("prompt_tokens").and_then(Value::as_u64).unwrap_or(0),
            output: u
                .get("completion_tokens")
                .and_then(Value::as_u64)
                .unwrap_or(0),
            cache_creation: 0,
            cache_read: u
                .get("prompt_tokens_details")
                .and_then(|d| d.get("cached_tokens"))
                .and_then(Value::as_u64)
                .unwrap_or(0),
            ..Default::default()
        });
        events.push(StreamEvent::Done { usage });
    }
    events
}
/// Builds the JSON body for a streaming Responses API request.
///
/// Non-system messages become `input` items (normalized so calls and outputs
/// are paired). The request-level system prompt and the text of every system
/// message are joined with blank lines into `instructions`. A positive
/// `max_tokens` is sent as `max_output_tokens`; tool names are sent in their
/// wire-safe form.
fn build_responses_body(req: &LlmRequest, file_id_map: &HashMap<String, String>) -> Result<Value> {
    let mut input: Vec<Value> = Vec::new();
    for msg in req.messages.iter().filter(|m| m.role != Role::System) {
        input.extend(serialize_input_items(msg, file_id_map));
    }
    normalize_responses_items(&mut input);

    let mut body = json!({
        "model": req.model,
        "stream": true,
        "input": input,
    });

    // Request-level prompt first, then system messages in their original order.
    let system_texts = req
        .messages
        .iter()
        .filter(|m| m.role == Role::System)
        .filter_map(|m| match &m.content {
            MessageContent::Text(t) => Some(t.as_str()),
            MessageContent::Parts(_) => None,
        });
    let instructions: Vec<&str> = req.system.as_deref().into_iter().chain(system_texts).collect();
    if !instructions.is_empty() {
        body["instructions"] = json!(instructions.join("\n\n"));
    }

    if let Some(max_tokens) = req.max_tokens.filter(|limit| *limit > 0) {
        body["max_output_tokens"] = json!(max_tokens);
    }
    if let Some(t) = req.temperature {
        body["temperature"] = super::json_f32(t);
    }

    if !req.tools.is_empty() {
        let mut tools: Vec<Value> = Vec::with_capacity(req.tools.len());
        for tool in req.tools.iter() {
            tools.push(json!({
                "type": "function",
                "name": sanitize_tool_name(&tool.name),
                "description": tool.description,
                "parameters": tool.parameters,
            }));
        }
        body["tools"] = json!(tools);
    }
    Ok(body)
}
fn normalize_responses_items(items: &mut Vec<Value>) {
use std::collections::{HashMap, HashSet};
let item_type = |v: &Value| v.get("type").and_then(|t| t.as_str()).map(str::to_owned);
let call_id = |v: &Value| v.get("call_id").and_then(|c| c.as_str()).map(str::to_owned);
let mut have_call: HashSet<String> = HashSet::new();
let mut have_output: HashSet<String> = HashSet::new();
for it in items.iter() {
match item_type(it).as_deref() {
Some("function_call") => {
if let Some(id) = call_id(it) {
have_call.insert(id);
}
}
Some("function_call_output") => {
if let Some(id) = call_id(it) {
have_output.insert(id);
}
}
_ => {}
}
}
let mut outputs: HashMap<String, Vec<Value>> = HashMap::new();
let mut rest: Vec<Value> = Vec::new();
for it in items.drain(..) {
match item_type(&it).as_deref() {
Some("function_call_output") => match call_id(&it) {
Some(id) if have_call.contains(&id) => {
outputs.entry(id).or_default().push(it);
}
_ => {} },
Some("function_call") => match call_id(&it) {
Some(id) if have_output.contains(&id) => rest.push(it),
_ => {} },
_ => rest.push(it),
}
}
for it in rest {
if item_type(&it).as_deref() == Some("function_call") {
let id = call_id(&it);
items.push(it);
if let Some(outs) = id.and_then(|id| outputs.remove(&id)) {
items.extend(outs);
}
} else {
items.push(it);
}
}
}
/// Converts one chat message into zero or more Responses-style input items.
///
/// * `Role::Tool`: one `function_call_output` item per `ToolResult` part with
///   a non-empty `tool_use_id`; anything else yields no items.
/// * `Role::Assistant`: an optional `message` item holding the concatenated
///   text (as `output_text`), followed by one `function_call` item per
///   `ToolUse` part with a non-empty id. If that produces nothing, the message
///   falls through to the generic form below with role `"assistant"`.
/// * `Role::User` / `Role::System`: a single `{ role: "user", content: [...] }`
///   item built from text and image parts; an empty `input_text` part is used
///   when no part could be serialized.
fn serialize_input_items(msg: &Message, file_id_map: &HashMap<String, String>) -> Vec<Value> {
    match msg.role {
        Role::Tool => {
            let MessageContent::Parts(parts) = &msg.content else {
                return Vec::new();
            };
            let mut outputs: Vec<Value> = Vec::new();
            for part in parts {
                if let ContentPart::ToolResult { tool_use_id, content, .. } = part {
                    if !tool_use_id.is_empty() {
                        outputs.push(json!({ "type": "function_call_output", "call_id": tool_use_id, "output": content }));
                    }
                }
            }
            return outputs;
        }
        Role::Assistant => {
            let mut text = String::new();
            let mut calls: Vec<Value> = Vec::new();
            match &msg.content {
                MessageContent::Text(t) => text.push_str(t),
                MessageContent::Parts(parts) => {
                    for part in parts {
                        match part {
                            ContentPart::Text { text: chunk } => text.push_str(chunk),
                            ContentPart::ToolUse { id, name, input } if !id.is_empty() => {
                                calls.push(json!({
                                    "type": "function_call",
                                    "call_id": id,
                                    "name": sanitize_tool_name(name),
                                    "arguments": input.to_string(),
                                    "status": "completed",
                                }));
                            }
                            _ => {}
                        }
                    }
                }
            }

            // The text message, when present, always precedes the calls.
            let mut items: Vec<Value> = Vec::with_capacity(calls.len() + 1);
            if !text.is_empty() {
                items.push(json!({
                    "type": "message",
                    "role": "assistant",
                    "status": "completed",
                    "content": [{ "type": "output_text", "text": text }],
                }));
            }
            items.extend(calls);
            if !items.is_empty() {
                return items;
            }
            // Nothing serializable: fall through to the generic form.
        }
        Role::User | Role::System => {}
    }

    let role_str = match msg.role {
        Role::Assistant => "assistant",
        Role::User | Role::System => "user",
        // Tool messages have already returned above.
        Role::Tool => "tool",
    };

    let content = match &msg.content {
        MessageContent::Text(t) => json!([{ "type": "input_text", "text": t }]),
        MessageContent::Parts(parts) => {
            let mut serialized: Vec<Value> = Vec::new();
            for part in parts {
                match part {
                    ContentPart::Text { text } => {
                        serialized.push(json!({ "type": "input_text", "text": text }));
                    }
                    ContentPart::Image { url } => {
                        serialized.push(serialize_media_for_responses(url, file_id_map));
                    }
                    _ => {}
                }
            }
            if serialized.is_empty() {
                json!([{ "type": "input_text", "text": "" }])
            } else {
                json!(serialized)
            }
        }
    };

    vec![json!({ "role": role_str, "content": content })]
}
/// Serializes a media reference (carried in an image part) into a
/// Responses-style content part.
///
/// `data:` URLs:
/// * if `file_id_map` has an entry for the exact URL, the part references that
///   file id (`input_video` for `data:video/…`, otherwise `input_image`);
/// * an unmapped `data:video/…` URL is replaced by a placeholder text part and
///   a warning is logged;
/// * any other unmapped `data:` URL is sent inline as `input_image`.
///
/// Other URLs are classified by the lowercase extension of the part before the
/// first `?`: `.mp4`, `.mov`, `.avi`, `.webm` and `.mkv` become `input_video`,
/// everything else `input_image`.
fn serialize_media_for_responses(url: &str, file_id_map: &HashMap<String, String>) -> Value {
    const VIDEO_EXTENSIONS: [&str; 5] = [".mp4", ".mov", ".avi", ".webm", ".mkv"];

    if url.starts_with("data:") {
        let is_video = url.starts_with("data:video/");
        return match (file_id_map.get(url), is_video) {
            (Some(file_id), true) => json!({ "type": "input_video", "file_id": file_id }),
            (Some(file_id), false) => json!({ "type": "input_image", "file_id": file_id }),
            (None, true) => {
                tracing::warn!("video upload failed, skipping input_video (fallback to transcription)");
                json!({ "type": "input_text", "text": "[video attached — audio transcription fallback]" })
            }
            (None, false) => json!({ "type": "input_image", "image_url": url }),
        };
    }

    // Ignore the query string and letter case when looking at the extension.
    let lowered = url.to_lowercase();
    let path = lowered.split('?').next().unwrap_or(&lowered);
    let looks_like_video = VIDEO_EXTENSIONS.iter().any(|&ext| path.ends_with(ext));

    if looks_like_video {
        json!({ "type": "input_video", "video_url": url })
    } else {
        json!({ "type": "input_image", "image_url": url })
    }
}
/// Feeds one network chunk of a Responses-style SSE stream through the
/// decoding buffers and returns the events completed by it.
///
/// # Parameters
/// * `chunk` – the next body chunk; an `Err` is returned as a single-element
///   vector without touching the buffers.
/// * `line_buffer` – text received after the last newline seen so far; only
///   complete lines are parsed, the unfinished tail stays here for the next
///   call.
/// * `utf8_remainder` – bytes of a multi-byte UTF-8 sequence that was cut off
///   at the end of the previous chunk.
///
/// # Behaviour
/// * A UTF-8 sequence split across chunks is held back until its remaining
///   bytes arrive. Bytes that can never form valid UTF-8 are replaced with
///   U+FFFD and decoding continues, so a single corrupt byte cannot stall the
///   stream or make the remainder buffer grow without bound.
/// * `event:` lines set the event type for the following `data:` lines until
///   a blank line; the space after the colon is optional for both fields.
/// * `data: [DONE]` yields `StreamEvent::Done { usage: None }`; other payloads
///   go through `parse_responses_event`, and payloads it does not map to an
///   event are logged at debug level and skipped.
///
/// NOTE(review): the current event type is local to one call. If an `event:`
/// line ends one chunk and its `data:` line starts the next, the data is
/// parsed with no explicit event type and `parse_responses_event` relies on
/// the payload's own `type` field instead.
async fn parse_responses_sse_chunk_buffered(
    chunk: Result<bytes::Bytes>,
    line_buffer: &tokio::sync::Mutex<String>,
    utf8_remainder: &tokio::sync::Mutex<Vec<u8>>,
) -> Vec<Result<StreamEvent>> {
    let bytes = match chunk {
        Ok(b) => b,
        Err(e) => return vec![Err(e)],
    };

    // Step 1: bytes -> text, carrying an incomplete trailing sequence over to
    // the next call.
    let text = {
        let mut remainder = utf8_remainder.lock().await;
        let mut pending = std::mem::take(&mut *remainder);
        pending.extend_from_slice(&bytes);

        let mut decoded = String::with_capacity(pending.len());
        let mut rest: &[u8] = &pending;
        while !rest.is_empty() {
            match std::str::from_utf8(rest) {
                Ok(valid) => {
                    decoded.push_str(valid);
                    break;
                }
                Err(err) => {
                    let (valid, tail) = rest.split_at(err.valid_up_to());
                    // `valid` is well-formed by definition of `valid_up_to`,
                    // so this never substitutes anything.
                    decoded.push_str(&String::from_utf8_lossy(valid));
                    match err.error_len() {
                        // Definitely invalid bytes: waiting for more input
                        // cannot repair them, so substitute and move on.
                        Some(bad_len) => {
                            decoded.push(char::REPLACEMENT_CHARACTER);
                            rest = &tail[bad_len..];
                        }
                        // Input ended in the middle of a sequence: keep the
                        // partial bytes for the next chunk.
                        None => {
                            remainder.extend_from_slice(tail);
                            break;
                        }
                    }
                }
            }
        }
        decoded
    };
    if text.is_empty() {
        return Vec::new();
    }

    // Step 2: text -> complete lines, leaving the unfinished tail buffered.
    let complete = {
        let mut buffer = line_buffer.lock().await;
        buffer.push_str(&text);
        let Some(last_newline) = buffer.rfind('\n') else {
            return Vec::new();
        };
        let tail = buffer.split_off(last_newline + 1);
        let mut complete = std::mem::replace(&mut *buffer, tail);
        complete.truncate(last_newline); // drop the final '\n'
        complete
    };

    // Step 3: lines -> events.
    let mut events = Vec::new();
    let mut current_event_type: Option<&str> = None;
    for line in complete.lines().map(str::trim) {
        if line.is_empty() {
            // Blank line terminates the current SSE event.
            current_event_type = None;
            continue;
        }
        if let Some(event_type) = line.strip_prefix("event:") {
            current_event_type = Some(event_type.trim());
            continue;
        }
        let Some(data) = line.strip_prefix("data:").map(str::trim_start) else {
            continue;
        };
        if data == "[DONE]" {
            events.push(Ok(StreamEvent::Done { usage: None }));
            continue;
        }
        match parse_responses_event(data, current_event_type) {
            Some(event) => events.push(Ok(event)),
            None => tracing::debug!(data, "openai-responses: unparsed SSE data"),
        }
    }
    events
}
fn parse_responses_event(data: &str, event_type: Option<&str>) -> Option<StreamEvent> {
let v: Value = serde_json::from_str(data).ok()?;
if let Some(err) = v.get("error") {
let msg = err["message"]
.as_str()
.unwrap_or("unknown API error")
.to_owned();
return Some(StreamEvent::Error(msg));
}
let evt_type = event_type.or_else(|| v["type"].as_str()).unwrap_or("");
match evt_type {
"response.output_text.delta" => {
let delta = match &v["delta"] {
Value::String(s) => s.clone(),
Value::Number(n) => n.to_string(),
Value::Bool(b) => b.to_string(),
_ => String::new(),
};
if delta.is_empty() {
None
} else {
Some(StreamEvent::TextDelta(delta))
}
}
"response.output_item.done" => {
let item = &v["item"];
if item["type"].as_str() == Some("function_call") {
let id = item["call_id"]
.as_str()
.or_else(|| item["id"].as_str())
.unwrap_or("")
.to_owned();
let name = restore_tool_name(item["name"].as_str().unwrap_or(""));
let args_str = item["arguments"].as_str().unwrap_or("{}");
let input = serde_json::from_str(args_str)
.unwrap_or_else(|_| Value::String(args_str.to_owned()));
Some(StreamEvent::ToolCall { id, name, input })
} else {
None
}
}
"response.completed" | "response.done" => {
let usage = v
.pointer("/response/usage")
.or_else(|| v.get("usage"))
.and_then(|u| u.as_object())
.map(|u| TokenUsage {
input: u.get("input_tokens").and_then(Value::as_u64).unwrap_or(0),
output: u.get("output_tokens").and_then(Value::as_u64).unwrap_or(0),
cache_creation: 0,
cache_read: u
.get("input_tokens_details")
.and_then(|d| d.get("cached_tokens"))
.and_then(Value::as_u64)
.unwrap_or(0),
..Default::default()
});
Some(StreamEvent::Done { usage })
}
_ if evt_type.starts_with("response.") => None,
_ => parse_completions_fallback(&v),
}
}
/// Interprets a chat-completions-shaped chunk (`choices[0].delta`) as a
/// single stream event.
///
/// Only the first choice is inspected, and at most one event is produced, in
/// this order of precedence:
/// 1. the first entry of `delta.tool_calls` → `StreamEvent::ToolCall`, with
///    non-empty arguments kept as a raw JSON string and empty or missing
///    arguments replaced by an empty object;
/// 2. non-empty `delta.content` → `StreamEvent::TextDelta`;
/// 3. a string `finish_reason` → `StreamEvent::Done`, with usage taken from
///    the top-level `usage` object when present.
///
/// Returns `None` when `choices` is missing or empty, or nothing above applies.
///
/// NOTE(review): unlike the `response.output_item.done` path, the tool name is
/// returned as received, without `restore_tool_name` — confirm whether a
/// later stage restores it.
fn parse_completions_fallback(v: &Value) -> Option<StreamEvent> {
    let choice = v["choices"].as_array()?.first()?;
    let delta = &choice["delta"];

    let first_call = delta["tool_calls"]
        .as_array()
        .and_then(|calls| calls.first());
    if let Some(call) = first_call {
        let function = &call["function"];
        let input = match function["arguments"].as_str() {
            Some(raw) if !raw.is_empty() => Value::String(raw.to_owned()),
            _ => Value::Object(Default::default()),
        };
        return Some(StreamEvent::ToolCall {
            id: call["id"].as_str().unwrap_or("").to_owned(),
            name: function["name"].as_str().unwrap_or("").to_owned(),
            input,
        });
    }

    if let Some(text) = delta["content"].as_str().filter(|t| !t.is_empty()) {
        return Some(StreamEvent::TextDelta(text.to_owned()));
    }

    if !choice["finish_reason"].is_string() {
        return None;
    }

    let usage = v["usage"].as_object().map(|fields| {
        let count = |key: &str| fields.get(key).and_then(Value::as_u64).unwrap_or(0);
        TokenUsage {
            input: count("prompt_tokens"),
            output: count("completion_tokens"),
            cache_creation: 0,
            cache_read: fields
                .get("prompt_tokens_details")
                .and_then(|details| details.get("cached_tokens"))
                .and_then(Value::as_u64)
                .unwrap_or(0),
            ..Default::default()
        }
    });
    Some(StreamEvent::Done { usage })
}
// Unit tests for the request builders and stream parsers in this file.
// The outer module exercises the chat-completions helpers (`build_request_body`,
// `serialize_message`, `parse_event`, `parse_sse_chunk_with_buffer`) plus
// `normalize_responses_items`; the nested `responses_tests` module covers the
// Responses-style builders/parsers and `strip_think_tags`.
#[cfg(test)]
mod tests {
use super::{
super::{LlmRequest, Message, MessageContent, Role},
*,
};
// Baseline request: model "gpt-4o", everything else defaulted.
fn make_request() -> LlmRequest {
LlmRequest {
fallback_models: Vec::new(),
model: "gpt-4o".to_owned(),
..Default::default()
}
}
// The model name is copied verbatim into the request body.
#[test]
fn request_serializes_model() {
let req = make_request();
let body = build_request_body(&req).unwrap();
assert_eq!(body["model"].as_str().unwrap(), "gpt-4o");
}
// Input: call A, an output for unknown call Z, the output for A, call B with
// no output, and a user message. Expected: Z's output and call B are dropped,
// A's output directly follows call A, and the user message is kept.
#[test]
fn responses_items_pair_reorder_and_drop_orphans() {
let mut items = vec![
serde_json::json!({"type":"function_call","call_id":"A","name":"shell"}),
serde_json::json!({"type":"function_call_output","call_id":"Z","output":"orphan"}),
serde_json::json!({"type":"function_call_output","call_id":"A","output":"ok"}),
serde_json::json!({"type":"function_call","call_id":"B","name":"shell"}),
serde_json::json!({"role":"user","content":[{"type":"input_text","text":"hi"}]}),
];
normalize_responses_items(&mut items);
// Label each item by its `type`, falling back to `role` for plain messages.
let types: Vec<String> = items
.iter()
.map(|i| {
i.get("type")
.and_then(|t| t.as_str())
.or_else(|| i.get("role").and_then(|r| r.as_str()))
.unwrap_or("")
.to_owned()
})
.collect();
assert_eq!(types, vec!["function_call", "function_call_output", "user"]);
assert_eq!(items[0]["call_id"].as_str(), Some("A"));
assert_eq!(items[1]["call_id"].as_str(), Some("A"));
}
// A user message is serialized with role "user" in `messages`.
#[test]
fn message_role_user() {
let req = LlmRequest {
fallback_models: Vec::new(),
messages: vec![Message {
role: Role::User,
content: MessageContent::Text("hello".to_owned()),
rsclaw_hidden: None,
}],
..make_request()
};
let body = build_request_body(&req).unwrap();
let msgs = body["messages"].as_array().unwrap();
assert_eq!(msgs[0]["role"].as_str().unwrap(), "user");
}
// With the second argument of `serialize_message` set to `true` (thinking
// enabled, per the test name), an assistant message that only carries a tool
// call still gets a `reasoning_content` field, set to the empty string.
#[test]
fn assistant_with_tool_calls_includes_reasoning_content_when_thinking_enabled() {
let msg = Message {
role: Role::Assistant,
content: MessageContent::Parts(vec![ContentPart::ToolUse {
id: "call_1".into(),
name: "web_search".into(),
input: serde_json::json!({"q": "rust"}),
}]),
rsclaw_hidden: None,
};
let out = serialize_message(&msg, true);
assert_eq!(out["role"], "assistant");
assert!(out["tool_calls"].is_array());
assert!(
out.get("reasoning_content").is_some(),
"thinking-enabled assistant MUST emit reasoning_content field; got {out}"
);
assert_eq!(out["reasoning_content"], "");
}
// Same flag, text-only assistant message: content is preserved and
// `reasoning_content` is the empty string.
#[test]
fn assistant_with_text_only_includes_reasoning_content_when_thinking_enabled() {
let msg = Message {
role: Role::Assistant,
content: MessageContent::Text("hello".into()),
rsclaw_hidden: None,
};
let out = serialize_message(&msg, true);
assert_eq!(out["role"], "assistant");
assert_eq!(out["content"], "hello");
assert_eq!(out["reasoning_content"], "");
}
// A `Reasoning` part in the message is echoed back as `reasoning_content`.
#[test]
fn assistant_preserves_captured_reasoning_when_thinking_enabled() {
let msg = Message {
role: Role::Assistant,
content: MessageContent::Parts(vec![
ContentPart::Reasoning {
text: "Let me think...".into(),
},
ContentPart::ToolUse {
id: "call_2".into(),
name: "web_search".into(),
input: serde_json::json!({}),
},
]),
rsclaw_hidden: None,
};
let out = serialize_message(&msg, true);
assert_eq!(out["reasoning_content"], "Let me think...");
}
// One delta carrying both `reasoning_content` and a tool call must produce
// two events, reasoning first.
#[test]
fn parse_event_emits_both_reasoning_and_tool_call_in_one_chunk() {
let data = r#"{"choices":[{"delta":{"reasoning_content":"Let me think","tool_calls":[{"id":"call_1","type":"function","function":{"name":"web_search","arguments":"{\"q\":\"rust\"}"}}]}}]}"#;
let events = parse_event(data);
assert_eq!(
events.len(),
2,
"expected reasoning + tool_call; got {events:?}"
);
match &events[0] {
StreamEvent::ReasoningDelta(t) => assert_eq!(t, "Let me think"),
other => panic!("expected ReasoningDelta first, got {other:?}"),
}
match &events[1] {
StreamEvent::ToolCall { id, name, .. } => {
assert_eq!(id, "call_1");
assert_eq!(name, "web_search");
}
other => panic!("expected ToolCall second, got {other:?}"),
}
}
// One delta carrying both reasoning and content yields reasoning, then text.
#[test]
fn parse_event_emits_reasoning_then_text_in_one_chunk() {
let data =
r#"{"choices":[{"delta":{"reasoning_content":"...so the answer is","content":"42"}}]}"#;
let events = parse_event(data);
assert_eq!(events.len(), 2);
assert!(matches!(&events[0], StreamEvent::ReasoningDelta(t) if t == "...so the answer is"));
assert!(matches!(&events[1], StreamEvent::TextDelta(t) if t == "42"));
}
// A chunk with both a content delta and a finish_reason yields the text
// first and then `Done` carrying the usage counters.
#[test]
fn parse_event_emits_done_with_usage_after_deltas() {
let data = r#"{"choices":[{"delta":{"content":"!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":3}}"#;
let events = parse_event(data);
assert_eq!(events.len(), 2);
assert!(matches!(&events[0], StreamEvent::TextDelta(t) if t == "!"));
match &events[1] {
StreamEvent::Done { usage } => {
let u = usage.as_ref().expect("usage populated");
assert_eq!(u.input, 10);
assert_eq!(u.output, 3);
}
other => panic!("expected Done, got {other:?}"),
}
}
// Empty-string reasoning/content fields produce no events at all.
#[test]
fn parse_event_drops_empty_reasoning_and_empty_content() {
let data = r#"{"choices":[{"delta":{"reasoning_content":"","content":""}}]}"#;
let events = parse_event(data);
assert!(
events.is_empty(),
"empty fields should produce no events, got {events:?}"
);
}
// With the flag set to `false`, no `reasoning_content` key is emitted.
#[test]
fn assistant_omits_reasoning_content_when_thinking_disabled() {
let msg = Message {
role: Role::Assistant,
content: MessageContent::Parts(vec![ContentPart::ToolUse {
id: "call_3".into(),
name: "x".into(),
input: serde_json::json!({}),
}]),
rsclaw_hidden: None,
};
let out = serialize_message(&msg, false);
assert!(
out.get("reasoning_content").is_none(),
"non-thinking assistant must NOT emit reasoning_content; got {out}"
);
}
// Tests for the Responses-style request body and event parsing, plus
// `strip_think_tags`.
mod responses_tests {
use super::*;
// Baseline request for the Responses-style tests; only the model is set.
fn make_responses_request() -> LlmRequest {
LlmRequest {
fallback_models: Vec::new(),
model: "doubao-seed-2-0-pro-260215".to_owned(),
..Default::default()
}
}
// The Responses body carries conversation items under `input`, never
// under `messages`.
#[test]
fn request_uses_input_not_messages() {
let req = LlmRequest {
fallback_models: Vec::new(),
messages: vec![Message {
role: Role::User,
content: MessageContent::Text("hello".to_owned()),
rsclaw_hidden: None,
}],
..make_responses_request()
};
let body = build_responses_body(&req, &HashMap::new()).unwrap();
assert!(body.get("input").is_some(), "should have 'input' field");
assert!(
body.get("messages").is_none(),
"should NOT have 'messages' field"
);
}
// `req.system` is sent as the `instructions` field.
#[test]
fn system_goes_to_instructions() {
let req = LlmRequest {
fallback_models: Vec::new(),
system: Some("be helpful".to_owned()),
..make_responses_request()
};
let body = build_responses_body(&req, &HashMap::new()).unwrap();
assert_eq!(body["instructions"].as_str().unwrap(), "be helpful");
}
// Plain user text becomes a content part of type `input_text`.
#[test]
fn content_parts_use_input_text() {
let req = LlmRequest {
fallback_models: Vec::new(),
messages: vec![Message {
role: Role::User,
content: MessageContent::Text("hello".to_owned()),
rsclaw_hidden: None,
}],
..make_responses_request()
};
let body = build_responses_body(&req, &HashMap::new()).unwrap();
let input = body["input"].as_array().unwrap();
let part_type = input[0]["content"][0]["type"].as_str().unwrap();
assert_eq!(part_type, "input_text");
}
// A remote image URL becomes an `input_image` part with the URL in
// `image_url`.
#[test]
fn image_uses_input_image() {
let req = LlmRequest {
fallback_models: Vec::new(),
messages: vec![Message {
role: Role::User,
content: MessageContent::Parts(vec![ContentPart::Image {
url: "https://example.com/img.png".to_owned(),
}]),
rsclaw_hidden: None,
}],
..make_responses_request()
};
let body = build_responses_body(&req, &HashMap::new()).unwrap();
let input = body["input"].as_array().unwrap();
let part = &input[0]["content"][0];
assert_eq!(part["type"].as_str().unwrap(), "input_image");
assert_eq!(
part["image_url"].as_str().unwrap(),
"https://example.com/img.png"
);
}
// `response.output_text.delta` maps to a `TextDelta` with the delta text.
#[test]
fn parse_text_delta_event() {
let data = r#"{"type":"response.output_text.delta","delta":"hello"}"#;
let event = parse_responses_event(data, Some("response.output_text.delta"));
assert!(matches!(event, Some(StreamEvent::TextDelta(ref t)) if t == "hello"));
}
// `response.completed` maps to `Done`, with usage read from
// `response.usage`.
#[test]
fn parse_done_event() {
let data = r#"{"type":"response.completed","response":{"usage":{"input_tokens":10,"output_tokens":20}}}"#;
let event = parse_responses_event(data, Some("response.completed"));
match event {
Some(StreamEvent::Done { usage: Some(u) }) => {
assert_eq!(u.input, 10);
assert_eq!(u.output, 20);
}
other => panic!("expected Done with usage, got {other:?}"),
}
}
// A finished `function_call` item maps to `ToolCall`; with no `call_id`
// in the item, `id` is used, and the arguments string is parsed as JSON.
#[test]
fn parse_tool_call_event() {
let data = r#"{"type":"response.output_item.done","item":{"type":"function_call","id":"call_123","name":"read_file","arguments":"{\"path\":\"/tmp/x\"}"}}"#;
let event = parse_responses_event(data, Some("response.output_item.done"));
match event {
Some(StreamEvent::ToolCall { id, name, input }) => {
assert_eq!(id, "call_123");
assert_eq!(name, "read_file");
assert_eq!(input["path"].as_str().unwrap(), "/tmp/x");
}
other => panic!("expected ToolCall, got {other:?}"),
}
}
// A chat-completions-shaped payload with no event type goes through the
// fallback parser and still yields a `TextDelta`.
#[test]
fn parse_completions_fallback_test() {
let data = r#"{"choices":[{"delta":{"content":"world"},"finish_reason":null}]}"#;
let event = parse_responses_event(data, None);
assert!(matches!(event, Some(StreamEvent::TextDelta(ref t)) if t == "world"));
}
// A leading <think>…</think> block is removed.
#[test]
fn strip_think_normal() {
let text = "<think>reasoning</think>answer text";
assert_eq!(strip_think_tags(text), "answer text");
}
// A stray closing tag before any opening tag must not cause the text
// between it and the next block to be removed.
#[test]
fn strip_think_lone_close_before_open() {
let text = "</think>\nThe IP is 127.0.0.1 port 5432\n<think>extra</think>";
let result = strip_think_tags(text);
assert!(
result.contains("127.0.0.1"),
"IP should not be eaten: {result:?}"
);
assert!(
result.contains("5432"),
"port should not be eaten: {result:?}"
);
}
// Text without tags is returned unchanged.
#[test]
fn strip_think_no_tags() {
let text = "The answer is 127.0.0.1 and port 5432";
assert_eq!(strip_think_tags(text), text);
}
// An unclosed <think> removes everything from the tag to the end.
#[test]
fn strip_think_unclosed() {
let text = "prefix <think>partial reasoning";
assert_eq!(strip_think_tags(text), "prefix ");
}
// Multiple blocks are each removed; the text between and after stays.
#[test]
fn strip_think_multiple_blocks() {
let text = "<think>a</think>answer<think>b</think> rest";
assert_eq!(strip_think_tags(text), "answer rest");
}
// Payloads here carry no `type` field, so the event kind must come from
// the preceding `event:` line.
#[tokio::test]
async fn parse_responses_sse_chunk_with_event_lines() {
let raw = b"event: response.output_text.delta\ndata: {\"delta\":\"hi\"}\n\nevent: response.completed\ndata: {\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}}\n\n";
let buffer = tokio::sync::Mutex::new(String::new());
let events = parse_responses_sse_chunk_buffered(
Ok(bytes::Bytes::from_static(raw)),
&buffer,
&tokio::sync::Mutex::new(Vec::new()),
)
.await;
assert_eq!(events.len(), 2);
assert!(matches!(&events[0], Ok(StreamEvent::TextDelta(t)) if t == "hi"));
assert!(matches!(&events[1], Ok(StreamEvent::Done { .. })));
}
// A numeric `delta` value is stringified rather than dropped.
#[tokio::test]
async fn parse_responses_numeric_delta_not_lost() {
let raw = b"event: response.output_text.delta\ndata: {\"delta\":42}\n\n";
let buffer = tokio::sync::Mutex::new(String::new());
let events = parse_responses_sse_chunk_buffered(
Ok(bytes::Bytes::from_static(raw)),
&buffer,
&tokio::sync::Mutex::new(Vec::new()),
)
.await;
assert_eq!(events.len(), 1);
assert!(matches!(&events[0], Ok(StreamEvent::TextDelta(t)) if t == "42"));
}
}
// A `data:` line split across two chunks yields nothing until the newline
// arrives; the partial line is held in the shared buffer in between.
#[tokio::test]
async fn sse_line_buffer_handles_split_lines() {
let buffer = tokio::sync::Mutex::new(String::new());
let utf8_rem = tokio::sync::Mutex::new(Vec::new());
let chunk1 = Ok(bytes::Bytes::from(
r#"data: {"choices":[{"delta":{"content":"he"#,
));
let events1 = parse_sse_chunk_with_buffer(chunk1, &buffer, &utf8_rem).await;
assert!(
events1.is_empty(),
"Expected no events from incomplete chunk, got {:?}",
events1
);
{
let buf = buffer.lock().await;
assert!(
buf.contains("he"),
"Buffer should contain 'he', got: {}",
*buf
);
}
let chunk2 = Ok(bytes::Bytes::from("l\"}}]}\n"));
let events2 = parse_sse_chunk_with_buffer(chunk2, &buffer, &utf8_rem).await;
assert_eq!(events2.len(), 1, "Expected 1 event, got {:?}", events2);
match &events2[0] {
Ok(StreamEvent::TextDelta(text)) => assert_eq!(text, "hel"),
other => panic!("Expected TextDelta, got {:?}", other),
}
}
// Two complete `data:` lines in a single chunk yield two events in order.
#[tokio::test]
async fn sse_line_buffer_handles_multiple_lines() {
let buffer = tokio::sync::Mutex::new(String::new());
let utf8_rem = tokio::sync::Mutex::new(Vec::new());
let chunk = Ok(bytes::Bytes::from(
"data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\
data: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n",
));
let events = parse_sse_chunk_with_buffer(chunk, &buffer, &utf8_rem).await;
assert_eq!(events.len(), 2);
match &events[0] {
Ok(StreamEvent::TextDelta(text)) => assert_eq!(text, "hello"),
other => panic!("Expected TextDelta, got {:?}", other),
}
match &events[1] {
Ok(StreamEvent::TextDelta(text)) => assert_eq!(text, " world"),
other => panic!("Expected TextDelta, got {:?}", other),
}
}
// A chunk ending mid-line emits the complete line immediately and the
// unfinished one only after the next chunk completes it.
#[tokio::test]
async fn sse_line_buffer_handles_trailing_incomplete_line() {
let buffer = tokio::sync::Mutex::new(String::new());
let utf8_rem = tokio::sync::Mutex::new(Vec::new());
let chunk1 = Ok(bytes::Bytes::from(
"data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\
data: {\"choices\":[{\"delta\":{\"content\":\"incom",
));
let events1 = parse_sse_chunk_with_buffer(chunk1, &buffer, &utf8_rem).await;
assert_eq!(events1.len(), 1);
let chunk2 = Ok(bytes::Bytes::from("plete\"}}]}\n"));
let events2 = parse_sse_chunk_with_buffer(chunk2, &buffer, &utf8_rem).await;
assert_eq!(events2.len(), 1);
match &events2[0] {
Ok(StreamEvent::TextDelta(text)) => assert_eq!(text, "incomplete"),
other => panic!("Expected TextDelta, got {:?}", other),
}
}
}