use super::error::{ProviderError, Result};
use super::rate_limiter::RateLimiter;
use super::r#trait::{Provider, ProviderStream};
use super::types::*;
use crate::brain::tokenizer::{count_message_tokens, count_tokens};
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
const DEFAULT_OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
/// Overall per-request timeout (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
const DEFAULT_POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(90);
const DEFAULT_TCP_KEEPALIVE: Duration = Duration::from_secs(15);

/// Markers that open a block hidden from visible output.
///
/// Entry `i` is closed by any tag in `STRIP_CLOSE_TAGS[i]`. Indices 0 and 1
/// are treated as reasoning by `filter_think_tags`. Index 2 is a generic HTML
/// comment and is dropped entirely.
const STRIP_OPEN_TAGS: &[&str] = &["<think>", "<!-- reasoning -->", "<!--"];
/// Accepted closing tags, index-aligned with `STRIP_OPEN_TAGS`.
const STRIP_CLOSE_TAGS: &[&[&str]] = &[
    &["</think>"],
    &["<!-- /reasoning -->", "</think>", "-->"],
    &["-->"],
];
/// Bytes swallowed inside one block between warnings about a missing close tag.
const THINK_BLOCK_MAX_BYTES: usize = 200_000;
/// Longest partial open tag that can straddle a chunk boundary: one byte
/// shorter than the longest entry of `STRIP_OPEN_TAGS`.
const MAX_OPEN_TAG_CARRY: usize = {
    let mut longest = 0;
    let mut idx = 0;
    while idx < STRIP_OPEN_TAGS.len() {
        let len = STRIP_OPEN_TAGS[idx].len();
        if len > longest {
            longest = len;
        }
        idx += 1;
    }
    longest - 1
};
fn retry_reason(err: &super::error::ProviderError) -> String {
use super::error::ProviderError;
match err {
ProviderError::HttpError(e) => super::error::describe_reqwest_error(e),
ProviderError::Timeout(_) => "timed out".to_string(),
ProviderError::RateLimitExceeded(_) => "rate limited".to_string(),
ProviderError::ApiError { status, .. } if *status == 429 => "rate limited".to_string(),
ProviderError::ApiError { status, .. } if *status >= 500 => {
format!("server error {status}")
}
ProviderError::ApiError { status, .. } => format!("HTTP {status}"),
_ => "transient error".to_string(),
}
}
/// Streaming filter that removes hidden blocks (see `STRIP_OPEN_TAGS`) from a
/// chunk of model output.
///
/// State is carried across calls through the `&mut` parameters:
/// - `inside_think`: whether the previous chunk ended inside a hidden block.
/// - `active_close_tag`: index into `STRIP_CLOSE_TAGS` of the block being hidden.
/// - `bytes_consumed`: bytes seen inside the current block since the last
///   reset. It drives a periodic warning when no close tag shows up.
/// - `carry`: a trailing partial open tag held back from the previous chunk.
///   It is prepended to `text` on the next call.
///
/// Returns `(visible_text, reasoning_text)`. Content of reasoning blocks
/// (indices 0 and 1) goes to `reasoning_text`. Content of plain HTML comments
/// is dropped. Text before an orphan close tag (one with no matching open tag
/// in this call) is also treated as reasoning.
pub(crate) fn filter_think_tags(
    text: &str,
    inside_think: &mut bool,
    active_close_tag: &mut usize,
    bytes_consumed: &mut usize,
    carry: &mut String,
) -> (String, String) {
    // Prepend any partial open tag held back from the previous chunk.
    let mut owned: String;
    let input_ref: &str = if carry.is_empty() {
        text
    } else {
        owned = std::mem::take(carry);
        owned.push_str(text);
        owned.as_str()
    };
    let mut result = String::new();
    let mut reasoning = String::new();
    let mut remaining = input_ref;
    // Tag indices 0 (`<think>`) and 1 (`<!-- reasoning -->`) hold reasoning.
    let is_reasoning_block = |idx: usize| idx < 2;
    loop {
        if *inside_think {
            // Always look for the close tag first. Previously the byte-budget
            // check ran before this search. In the chunk that crossed the budget,
            // a close tag was ignored and all following output stayed hidden.
            let close_candidates = STRIP_CLOSE_TAGS[*active_close_tag];
            let earliest_close = close_candidates
                .iter()
                .filter_map(|close| remaining.find(close).map(|pos| (pos, *close)))
                .min_by_key(|(pos, _)| *pos);
            if let Some((end, close)) = earliest_close {
                if is_reasoning_block(*active_close_tag) {
                    reasoning.push_str(&remaining[..end]);
                }
                remaining = &remaining[end + close.len()..];
                *inside_think = false;
                *bytes_consumed = 0;
                continue;
            }
            // No close tag in this chunk: hide all of it and track how long
            // the block has been open.
            if is_reasoning_block(*active_close_tag) {
                reasoning.push_str(remaining);
            }
            *bytes_consumed += remaining.len();
            if *bytes_consumed > THINK_BLOCK_MAX_BYTES {
                tracing::warn!(
                    "⚠️ Think-tag filter consumed {} bytes without close tag \
                     (tag_idx={}) — still waiting for close, continuing to suppress",
                    *bytes_consumed,
                    *active_close_tag,
                );
                *bytes_consumed = 0;
            }
            break;
        } else {
            let mut earliest: Option<(usize, usize)> = None;
            for (i, open) in STRIP_OPEN_TAGS.iter().enumerate() {
                if let Some(pos) = remaining.find(open)
                    && earliest.is_none_or(|(best, _)| pos < best)
                {
                    earliest = Some((pos, i));
                }
            }
            if let Some((pos, tag_idx)) = earliest {
                result.push_str(&remaining[..pos]);
                remaining = &remaining[pos + STRIP_OPEN_TAGS[tag_idx].len()..];
                *inside_think = true;
                *active_close_tag = tag_idx;
                *bytes_consumed = 0;
            } else {
                // An explicit close tag with no preceding open tag means the
                // model started reasoning without announcing it. `-->` is
                // skipped because it is too common in ordinary text.
                let mut orphan: Option<(usize, &str)> = None;
                for close_candidates in STRIP_CLOSE_TAGS.iter() {
                    for close in close_candidates.iter() {
                        if *close == "-->" || close.is_empty() {
                            continue;
                        }
                        if let Some(pos) = remaining.find(close)
                            && orphan.is_none_or(|(best, _)| pos < best)
                        {
                            orphan = Some((pos, close));
                        }
                    }
                }
                if let Some((pos, close)) = orphan {
                    reasoning.push_str(&remaining[..pos]);
                    remaining = &remaining[pos + close.len()..];
                    continue;
                }
                // Hold back a trailing partial open tag so it can be matched
                // once the next chunk arrives.
                let tail_keep = open_tag_prefix_len(remaining);
                if tail_keep > 0 {
                    let split_at = remaining.len() - tail_keep;
                    result.push_str(&remaining[..split_at]);
                    carry.push_str(&remaining[split_at..]);
                } else {
                    result.push_str(remaining);
                }
                break;
            }
        }
    }
    (result, reasoning)
}
/// Returns the length of the longest suffix of `s` that is a proper prefix of
/// one of `markers`, or 0 if there is none.
///
/// Streaming callers can hold that suffix back until the next chunk shows
/// whether it really starts a marker. A suffix equal to a whole marker does
/// not count. Only UTF-8 character boundaries are considered.
pub(crate) fn tool_marker_prefix_len(s: &str, markers: &[&str]) -> usize {
    let longest = markers.iter().map(|m| m.len()).max().unwrap_or(0);
    // A one-byte marker has no proper non-empty prefix.
    if longest < 2 {
        return 0;
    }
    let first = s.len().saturating_sub(longest - 1);
    (first..s.len())
        .filter(|&idx| s.is_char_boundary(idx))
        .map(|idx| &s[idx..])
        .find(|tail| {
            !tail.is_empty()
                && markers
                    .iter()
                    .any(|marker| marker.len() > tail.len() && marker.starts_with(*tail))
        })
        .map_or(0, str::len)
}
/// Length of the longest suffix of `s` that is a proper prefix of some entry
/// in `STRIP_OPEN_TAGS`, or 0 if there is none.
///
/// Only the last `MAX_OPEN_TAG_CARRY` bytes are examined, at character
/// boundaries.
fn open_tag_prefix_len(s: &str) -> usize {
    let window_start = s.len().saturating_sub(MAX_OPEN_TAG_CARRY);
    (window_start..s.len())
        .filter(|&idx| s.is_char_boundary(idx))
        .map(|idx| &s[idx..])
        .find(|tail| {
            STRIP_OPEN_TAGS
                .iter()
                .any(|open| open.len() > tail.len() && open.starts_with(*tail))
        })
        .map_or(0, str::len)
}
/// Removes every hidden block (see `STRIP_OPEN_TAGS`) from a complete
/// response and returns the trimmed remainder.
///
/// For each open tag, the matching close tag is searched *after* the open
/// tag. An unclosed block truncates the text at its open tag. Afterwards,
/// any orphan close tag still present (except the ambiguous `-->`) drops
/// everything up to and including it. This is repeated until none remain,
/// mirroring `filter_think_tags`.
pub(crate) fn strip_think_blocks(text: &str) -> String {
    let mut result = text.to_string();
    for (open, close_candidates) in STRIP_OPEN_TAGS.iter().zip(STRIP_CLOSE_TAGS.iter()) {
        while let Some(start) = result.find(open) {
            // Search only past the opener. `<!-- reasoning -->` itself ends in
            // `-->`, one of its own close candidates, so searching from `start`
            // used to match inside the opener and leak the block body.
            let body_start = start + open.len();
            let earliest_close = close_candidates
                .iter()
                .filter_map(|close| result[body_start..].find(close).map(|rel| (rel, *close)))
                .min_by_key(|(rel, _)| *rel);
            match earliest_close {
                Some((rel, close)) => {
                    result.replace_range(start..body_start + rel + close.len(), "");
                }
                None => {
                    result.truncate(start);
                    break;
                }
            }
        }
    }
    // Orphan close tags: the model reasoned without an opening tag.
    loop {
        let orphan = STRIP_CLOSE_TAGS
            .iter()
            .flat_map(|group| group.iter().copied())
            .filter(|close| *close != "-->" && !close.is_empty())
            .filter_map(|close| result.find(close).map(|pos| (pos, close)))
            .min_by_key(|(pos, _)| *pos);
        let Some((pos, close)) = orphan else { break };
        result.replace_range(..pos + close.len(), "");
    }
    result.trim().to_string()
}
/// Outcome of checking whether text looks like (the start of) a bare JSON
/// tool-call array, i.e. `[{"id": "call_…`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum BareToolArrayMatch {
    /// The text (ignoring whitespace between tokens) starts with the full shape.
    Full,
    /// The text is empty/whitespace or a strict prefix of the shape, so more
    /// input could still complete it.
    Prefix,
    /// The text cannot be the start of such an array.
    None,
}

/// Classifies `s` against the token sequence `[`, `{`, `"id"`, `:`, `"call_`.
///
/// Leading whitespace before each token is skipped.
pub(crate) fn classify_bare_tool_array(s: &str) -> BareToolArrayMatch {
    const SHAPE: [&str; 5] = ["[", "{", "\"id\"", ":", "\"call_"];
    let mut rest = s;
    for token in SHAPE {
        rest = rest.trim_start();
        if rest.is_empty() {
            return BareToolArrayMatch::Prefix;
        }
        if rest.len() < token.len() {
            // Input ends mid-token: still viable only if it matches so far.
            return if token.starts_with(rest) {
                BareToolArrayMatch::Prefix
            } else {
                BareToolArrayMatch::None
            };
        }
        let Some(after) = rest.strip_prefix(token) else {
            return BareToolArrayMatch::None;
        };
        rest = after;
    }
    BareToolArrayMatch::Full
}
pub(crate) fn extract_text_tool_calls(text: &str) -> (Vec<(String, serde_json::Value)>, String) {
let has_claude_style = KNOWN_TOOL_NAMES
.iter()
.any(|t| text.contains(&format!("<{}>", t)));
let has_bare_array_signal = text.contains("\"id\":\"call_") || text.contains("\"id\": \"call_");
let has_dict_by_id_signal =
text.contains("\"call_") && (text.contains("\"name\"") || text.contains("\"function\""));
let has_bare_command_args = text.contains("{\"command\":")
|| text.contains("{ \"command\":")
|| text.contains("{\"command\" :");
let has_invoke_signal = text.contains("<invoke name=") || text.contains("invoke name=\"");
let has_bare_name_args = super::bare_tool_call_extractor::has_bare_name_args_signal(text);
if !text.contains("<tool_call>")
&& !text.contains("<tool_call_list>")
&& !text.contains("<function=")
&& !text.contains("tool_call:")
&& !text.contains("\"tool_calls\"")
&& !text.contains("\"tool_call\"")
&& !has_claude_style
&& !has_bare_array_signal
&& !has_dict_by_id_signal
&& !has_bare_command_args
&& !has_invoke_signal
&& !has_bare_name_args
{
return (Vec::new(), text.to_string());
}
let mut tool_calls: Vec<(String, serde_json::Value)> = Vec::new();
let mut strip_ranges: Vec<(usize, usize)> = Vec::new();
if has_claude_style {
for (start, end, name, input) in extract_claude_style_tool_calls(text) {
tool_calls.push((name, input));
strip_ranges.push((start, end));
}
}
if has_bare_array_signal {
let anchors = ["\"id\":\"call_", "\"id\": \"call_"];
let mut search_from = 0;
loop {
let next = anchors
.iter()
.filter_map(|a| text[search_from..].find(a).map(|p| (search_from + p, *a)))
.min_by_key(|(p, _)| *p);
let Some((anchor_pos, anchor_lit)) = next else {
break;
};
let window_start = anchor_pos.saturating_sub(64);
let bracket_pos = text[window_start..anchor_pos]
.rfind('[')
.map(|r| window_start + r);
let Some(arr_pos) = bracket_pos else {
search_from = anchor_pos + anchor_lit.len();
continue;
};
if strip_ranges
.iter()
.any(|(s, e)| *s <= arr_pos && arr_pos < *e)
{
search_from = anchor_pos + anchor_lit.len();
continue;
}
match extract_balanced_json(&text[arr_pos..]) {
Some(consumed) => {
let arr_slice = &text[arr_pos..arr_pos + consumed];
if let Ok(v) = serde_json::from_str::<serde_json::Value>(arr_slice)
&& let Some(items) = v.as_array()
{
let mut found_any = false;
for item in items {
if let Some(call) = parse_tool_call_value(item) {
tool_calls.push(call);
found_any = true;
}
}
if found_any {
strip_ranges.push((arr_pos, arr_pos + consumed));
search_from = arr_pos + consumed;
continue;
}
}
search_from = anchor_pos + anchor_lit.len();
}
None => {
search_from = anchor_pos + anchor_lit.len();
}
}
}
}
if text.contains("<tool_call_list>") {
let mut seen_blocks: Vec<(usize, usize)> = Vec::new();
for (start, end, name, input) in extract_tool_call_list_calls(text) {
tool_calls.push((name, input));
if !seen_blocks.contains(&(start, end)) {
seen_blocks.push((start, end));
strip_ranges.push((start, end));
}
}
}
if has_dict_by_id_signal {
let mut search_from = 0;
while let Some(rel) = text[search_from..].find("\"call_") {
let anchor_pos = search_from + rel;
let mut back = anchor_pos;
while back > 0 {
let b = text.as_bytes()[back - 1];
if b.is_ascii_whitespace() || b == b'\n' || b == b'\r' {
back -= 1;
continue;
}
break;
}
if back == 0 || text.as_bytes()[back - 1] != b'{' {
search_from = anchor_pos + "\"call_".len();
continue;
}
let obj_start = back - 1;
if strip_ranges
.iter()
.any(|(s, e)| *s <= obj_start && obj_start < *e)
{
search_from = anchor_pos + "\"call_".len();
continue;
}
let Some(consumed) = extract_balanced_json(&text[obj_start..]) else {
search_from = anchor_pos + "\"call_".len();
continue;
};
let obj_slice = &text[obj_start..obj_start + consumed];
let Ok(v) = serde_json::from_str::<serde_json::Value>(obj_slice) else {
search_from = anchor_pos + "\"call_".len();
continue;
};
let Some(obj) = v.as_object() else {
search_from = anchor_pos + "\"call_".len();
continue;
};
let mut found_any = false;
for (key, val) in obj {
if !key.starts_with("call_") {
continue;
}
if let Some(call) = parse_tool_call_value(val) {
tool_calls.push(call);
found_any = true;
}
}
if found_any {
strip_ranges.push((obj_start, obj_start + consumed));
search_from = obj_start + consumed;
} else {
search_from = anchor_pos + "\"call_".len();
}
}
}
let mut i: usize = 0;
let bytes = text.as_bytes();
while i < bytes.len() {
let tc_at = text[i..].find("<tool_call>").map(|r| i + r);
let fn_at = text[i..].find("<function=").map(|r| i + r);
let bare_at = text[i..].find("tool_call:").map(|r| i + r);
let arr_at = text[i..].find("\"tool_calls\"").map(|r| i + r);
let sing_at = {
let candidate = text[i..].find("\"tool_call\"").map(|r| i + r);
match candidate {
Some(p)
if text.as_bytes().get(p + "\"tool_call\"".len() - 1).copied()
!= Some(b'"') =>
{
None
}
Some(p) if text[p..].starts_with("\"tool_calls\"") => None,
other => other,
}
};
let next = [tc_at, fn_at, bare_at, arr_at, sing_at]
.into_iter()
.flatten()
.min();
let Some(start) = next else { break };
if bare_at == Some(start) {
if start > 0 {
let prev = text.as_bytes()[start - 1];
let is_boundary = prev.is_ascii_whitespace()
|| matches!(
prev,
b',' | b';' | b':' | b'[' | b'(' | b'{' | b'\n' | b'\r'
);
if !is_boundary {
i = start + "tool_call:".len();
continue;
}
}
let body_start = start + "tool_call:".len();
let brace_rel = text[body_start..]
.char_indices()
.find(|(_, c)| !c.is_whitespace())
.map(|(idx, _)| idx);
let brace_abs = match brace_rel {
Some(rel) if text.as_bytes().get(body_start + rel) == Some(&b'{') => {
body_start + rel
}
_ => {
i = body_start;
continue;
}
};
match extract_balanced_json(&text[brace_abs..]) {
Some(consumed) => {
let json_slice = &text[brace_abs..brace_abs + consumed];
if let Some(call) = parse_qwen_tool_json(json_slice) {
tool_calls.push(call);
strip_ranges.push((start, brace_abs + consumed));
i = brace_abs + consumed;
continue;
}
i = body_start;
}
None => {
i = body_start;
}
}
continue;
} else if arr_at == Some(start) {
let wrapper = text[..start].rfind('{');
let wrapper_start = match wrapper {
Some(br) if start - br <= 4 => br,
_ => {
i = start + "\"tool_calls\"".len();
continue;
}
};
match extract_balanced_json(&text[wrapper_start..]) {
Some(consumed) => {
let env_slice = &text[wrapper_start..wrapper_start + consumed];
if let Ok(v) = serde_json::from_str::<serde_json::Value>(env_slice)
&& let Some(arr) = v.get("tool_calls").and_then(|a| a.as_array())
{
let mut found_any = false;
for item in arr {
if let Some(call) = parse_tool_call_value(item) {
tool_calls.push(call);
found_any = true;
}
}
if found_any {
strip_ranges.push((wrapper_start, wrapper_start + consumed));
i = wrapper_start + consumed;
continue;
}
}
i = start + "\"tool_calls\"".len();
}
None => {
i = start + "\"tool_calls\"".len();
}
}
continue;
} else if sing_at == Some(start) {
let wrapper = text[..start].rfind('{');
let wrapper_start = match wrapper {
Some(br) if start - br <= 4 => br,
_ => {
i = start + "\"tool_call\"".len();
continue;
}
};
match extract_balanced_json_tolerant(&text[wrapper_start..]) {
Some(consumed) => {
let env_slice = &text[wrapper_start..wrapper_start + consumed];
let recovered = recover_tool_call_from_malformed_json(env_slice);
if let Some(call) = recovered {
tool_calls.push(call);
strip_ranges.push((wrapper_start, wrapper_start + consumed));
i = wrapper_start + consumed;
continue;
}
i = start + "\"tool_call\"".len();
}
None => {
i = start + "\"tool_call\"".len();
}
}
continue;
} else if tc_at == Some(start) {
let body_start = start + "<tool_call>".len();
let brace_rel = text[body_start..]
.char_indices()
.find(|(_, c)| !c.is_whitespace())
.map(|(idx, _)| idx);
let brace_abs = match brace_rel {
Some(rel) if text.as_bytes().get(body_start + rel) == Some(&b'{') => {
body_start + rel
}
_ => {
i = body_start;
continue;
}
};
match extract_balanced_json(&text[brace_abs..]) {
Some(consumed) => {
let json_slice = &text[brace_abs..brace_abs + consumed];
if let Some(call) = parse_qwen_tool_json(json_slice) {
tool_calls.push(call);
}
let mut end = brace_abs + consumed;
let after = &text[end..];
let ws_len = after.len() - after.trim_start().len();
if after.trim_start().starts_with("</tool_call>") {
end += ws_len + "</tool_call>".len();
}
strip_ranges.push((start, end));
i = end;
}
None => {
i = body_start;
}
}
} else {
let tag_start = start;
let after = &text[tag_start..];
let open_end = match after.find('>') {
Some(r) => tag_start + r + 1,
None => {
i = tag_start + "<function=".len();
continue;
}
};
let name = text[tag_start + "<function=".len()..open_end - 1].trim();
if name.is_empty() || !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
i = open_end;
continue;
}
let tail = &text[open_end..];
let candidates = [
tail.find("</tool_call>").map(|r| (r, "</tool_call>".len())),
tail.find("<function=").map(|r| (r, 0usize)),
tail.find("</function>").map(|r| (r, "</function>".len())),
];
let pick = candidates.iter().filter_map(|o| *o).min_by_key(|(r, _)| *r);
let (body_rel, close_len) = match pick {
Some(p) => p,
None => {
(tail.len(), 0)
}
};
let body = &tail[..body_rel];
let input = parse_function_params(body);
tool_calls.push((name.to_string(), input));
let end = open_end + body_rel + close_len;
strip_ranges.push((start, end));
i = end;
}
}
if has_bare_command_args {
let mut search_from = 0;
while let Some(rel) = text[search_from..].find("\"command\"") {
let anchor_pos = search_from + rel;
let mut back = anchor_pos;
while back > 0 {
let b = text.as_bytes()[back - 1];
if b.is_ascii_whitespace() || b == b'\n' || b == b'\r' {
back -= 1;
continue;
}
break;
}
if back == 0 || text.as_bytes()[back - 1] != b'{' {
search_from = anchor_pos + "\"command\"".len();
continue;
}
let obj_start = back - 1;
if strip_ranges
.iter()
.any(|(s, e)| *s <= obj_start && obj_start < *e)
{
search_from = anchor_pos + "\"command\"".len();
continue;
}
let mut prev = obj_start;
while prev > 0 {
let b = text.as_bytes()[prev - 1];
if b.is_ascii_whitespace() || b == b'\n' || b == b'\r' {
prev -= 1;
continue;
}
break;
}
if prev > 0 && text.as_bytes()[prev - 1] == b':' {
let mut k = prev - 1;
while k > 0 {
let b = text.as_bytes()[k - 1];
if b.is_ascii_whitespace() {
k -= 1;
continue;
}
break;
}
if k > 0 && text.as_bytes()[k - 1] == b'"' {
search_from = anchor_pos + "\"command\"".len();
continue;
}
}
let Some(consumed) = extract_balanced_json(&text[obj_start..]) else {
search_from = anchor_pos + "\"command\"".len();
continue;
};
let obj_slice = &text[obj_start..obj_start + consumed];
let Ok(v) = serde_json::from_str::<serde_json::Value>(obj_slice) else {
search_from = anchor_pos + "\"command\"".len();
continue;
};
let Some(obj) = v.as_object() else {
search_from = anchor_pos + "\"command\"".len();
continue;
};
let known = ["command", "working_dir", "timeout_secs"];
let all_known = obj.keys().all(|k| known.contains(&k.as_str()));
let has_command_string = obj.get("command").and_then(|c| c.as_str()).is_some();
if !all_known || !has_command_string {
search_from = anchor_pos + "\"command\"".len();
continue;
}
tool_calls.push(("bash".to_string(), v));
strip_ranges.push((obj_start, obj_start + consumed));
search_from = obj_start + consumed;
}
}
if has_invoke_signal {
let invoke_calls = extract_invoke_style_tool_calls(text, &strip_ranges);
for (s, e, name, args) in invoke_calls {
tool_calls.push((name, args));
strip_ranges.push((s, e));
}
widen_strip_to_known_wrappers(text, &mut strip_ranges);
}
if has_bare_name_args {
for m in super::bare_tool_call_extractor::extract_bare_name_args_calls(
text,
&strip_ranges,
&tool_calls,
) {
if !m.already_in_existing {
tool_calls.push((m.name, m.args));
}
strip_ranges.push((m.strip_start, m.strip_end));
}
}
if strip_ranges.is_empty() {
return (tool_calls, text.to_string());
}
strip_ranges.sort_by_key(|(s, _)| *s);
let mut out = String::with_capacity(text.len());
let mut cursor = 0;
for (s, e) in strip_ranges {
if s >= cursor {
if s > cursor {
out.push_str(&text[cursor..s]);
}
cursor = e;
} else if e > cursor {
cursor = e;
}
}
if cursor < text.len() {
out.push_str(&text[cursor..]);
}
(tool_calls, out.trim().to_string())
}
/// Tool names recognised when parsing tool calls from plain text.
///
/// Used as `<name>` XML tags by `extract_claude_style_tool_calls`, as the
/// allow-list of `invoke name=` targets, and as the deny-list of
/// `<parameter name=…>` keys in `parse_invoke_parameters`.
pub(crate) const KNOWN_TOOL_NAMES: &[&str] = &[
    "bash", "ls", "glob", "grep",
    "read_file", "write_file", "edit_file", "patch_file",
    "web_search", "web_fetch", "web_request", "http_request",
    "plan", "task_manager", "cron_manage",
    "memory_search", "session_search",
    "lsp", "agent",
    "slack_send", "telegram_send", "discord_send", "trello_send",
];
/// Finds Claude-style XML tool calls such as
/// `<bash><command>ls</command></bash>`.
///
/// Returns `(start, end, tool_name, input)` per call. `start..end` spans the
/// whole element, and `input` is an object of string parameters. Blocks with
/// no parseable `<param>value</param>` children are skipped. An opening tag
/// without a matching close is ignored and scanning resumes after it.
fn extract_claude_style_tool_calls(text: &str) -> Vec<(usize, usize, String, serde_json::Value)> {
    // Build every `<tool>` needle once instead of on each scan step.
    let open_tags: Vec<(String, &'static str)> = KNOWN_TOOL_NAMES
        .iter()
        .map(|&tool| (format!("<{tool}>"), tool))
        .collect();
    let mut calls = Vec::new();
    let mut pos = 0;
    while pos < text.len() {
        let next_open = open_tags
            .iter()
            .filter_map(|(tag, tool)| {
                text[pos..]
                    .find(tag.as_str())
                    .map(|rel| (pos + rel, tag.len(), *tool))
            })
            .min_by_key(|&(at, _, _)| at);
        let Some((start, tag_len, tool)) = next_open else {
            break;
        };
        let body_start = start + tag_len;
        let close_tag = format!("</{tool}>");
        let Some(close_rel) = text[body_start..].find(&close_tag) else {
            pos = body_start;
            continue;
        };
        let close_start = body_start + close_rel;
        let end = close_start + close_tag.len();
        let params = parse_xml_param_pairs(&text[body_start..close_start]);
        if !params.is_empty() {
            let input: serde_json::Map<String, serde_json::Value> = params
                .into_iter()
                .map(|(key, value)| (key, serde_json::Value::String(value)))
                .collect();
            calls.push((start, end, tool.to_string(), serde_json::Value::Object(input)));
        }
        pos = end;
    }
    calls
}
/// Parses sequential `<name>value</name>` children into `(name, value)` pairs.
///
/// Names must be ASCII alphanumeric/underscore. Values are trimmed and kept
/// verbatim, so nested markup stays inside the value. Parsing stops at the
/// first well-formed opening tag that has no matching close tag.
fn parse_xml_param_pairs(body: &str) -> Vec<(String, String)> {
    let mut pairs = Vec::new();
    let mut pos = 0;
    while let Some(rel) = body[pos..].find('<') {
        let lt = pos + rel;
        let after_lt = &body[lt + 1..];
        let name_len = after_lt
            .bytes()
            .take_while(|b| b.is_ascii_alphanumeric() || *b == b'_')
            .count();
        let is_open_tag = name_len > 0 && after_lt.as_bytes().get(name_len) == Some(&b'>');
        if !is_open_tag {
            pos = lt + 1;
            continue;
        }
        let name = &after_lt[..name_len];
        // Skip `<`, the name, and `>`.
        let value_start = lt + name_len + 2;
        let closing = format!("</{name}>");
        match body[value_start..].find(&closing) {
            Some(close_rel) => {
                let value = body[value_start..value_start + close_rel].trim();
                pairs.push((name.to_string(), value.to_string()));
                pos = value_start + close_rel + closing.len();
            }
            None => break,
        }
    }
    pairs
}
/// Finds `invoke name="tool"` style calls (optionally preceded by `<`) whose
/// name is in `KNOWN_TOOL_NAMES`.
///
/// The body runs to `</invoke>`. If that is missing, it runs to the nearest
/// `</qwen:tool_call>` / `</function_calls>` (which are not consumed), or to
/// the end of the text. Parameters come from `<parameter name="k">v</parameter>`
/// children, with values coerced by `coerce_xml_param_value`. Anchors inside
/// `existing_strip_ranges` are skipped.
///
/// Returns `(block_start, block_end, tool_name, args_object)` per call.
fn extract_invoke_style_tool_calls(
    text: &str,
    existing_strip_ranges: &[(usize, usize)],
) -> Vec<(usize, usize, String, serde_json::Value)> {
    let mut results: Vec<(usize, usize, String, serde_json::Value)> = Vec::new();
    let mut cursor: usize = 0;
    while cursor < text.len() {
        // Anchor on whichever comes first: an `invoke name=` or a bare
        // `<parameter name=` (the latter is only accepted if its name is a
        // known tool name, checked below).
        let invoke_at = text[cursor..]
            .find("invoke name=")
            .map(|r| (cursor + r, "invoke name=".len()));
        let param_at = text[cursor..]
            .find("<parameter name=")
            .map(|r| (cursor + r, "<parameter name=".len()));
        let (anchor, anchor_len) = match (invoke_at, param_at) {
            (Some(i), Some(p)) => {
                if i.0 <= p.0 {
                    i
                } else {
                    p
                }
            }
            (Some(i), None) => i,
            (None, Some(p)) => p,
            (None, None) => break,
        };
        if existing_strip_ranges
            .iter()
            .any(|(s, e)| *s <= anchor && anchor < *e)
        {
            cursor = anchor + anchor_len;
            continue;
        }
        let after_eq = anchor + anchor_len;
        let bytes = text.as_bytes();
        if after_eq >= bytes.len() {
            break;
        }
        // The name must be quoted with `"` or `'`.
        let quote = bytes[after_eq];
        if quote != b'"' && quote != b'\'' {
            cursor = after_eq;
            continue;
        }
        let name_start = after_eq + 1;
        let name_end_rel = match text[name_start..].find(quote as char) {
            Some(r) => r,
            None => {
                cursor = name_start;
                continue;
            }
        };
        let name = text[name_start..name_start + name_end_rel].to_string();
        if !KNOWN_TOOL_NAMES.iter().any(|&n| n == name) {
            cursor = name_start + name_end_rel;
            continue;
        }
        let body_search_start = name_start + name_end_rel;
        let close = text[body_search_start..]
            .find("</invoke>")
            .map(|r| (body_search_start + r, "</invoke>".len()));
        // Without `</invoke>`, stop at a known wrapper close (left in place
        // with close_len 0) or at end of text.
        let (body_end, close_len) = close.unwrap_or_else(|| {
            let qwen_close = text[body_search_start..].find("</qwen:tool_call>");
            let func_close = text[body_search_start..].find("</function_calls>");
            let stop_rel = [qwen_close, func_close]
                .into_iter()
                .flatten()
                .min()
                .unwrap_or(text.len() - body_search_start);
            (body_search_start + stop_rel, 0)
        });
        let body = &text[body_search_start..body_end];
        let params = parse_invoke_parameters(body);
        let mut obj = serde_json::Map::new();
        for (k, v) in params {
            obj.insert(k, coerce_xml_param_value(&v));
        }
        // Include a directly preceding `<` so `<invoke ...>` is fully stripped.
        let mut block_start = anchor;
        if block_start > 0 && text.as_bytes()[block_start - 1] == b'<' {
            block_start -= 1;
        }
        let block_end = body_end + close_len;
        results.push((block_start, block_end, name, serde_json::Value::Object(obj)));
        cursor = block_end;
    }
    results
}
/// Parses `<parameter name="key">value</parameter>` children (single or
/// double quotes) into `(key, trimmed_value)` pairs.
///
/// Keys that are themselves tool names in `KNOWN_TOOL_NAMES` are dropped.
/// Parsing stops at the first structurally incomplete parameter.
fn parse_invoke_parameters(body: &str) -> Vec<(String, String)> {
    const OPEN: &str = "<parameter name=";
    const CLOSE: &str = "</parameter>";
    let mut params = Vec::new();
    let mut pos = 0;
    while pos < body.len() {
        let Some(rel) = body[pos..].find(OPEN) else {
            break;
        };
        let quote_at = pos + rel + OPEN.len();
        let Some(&quote) = body.as_bytes().get(quote_at) else {
            break;
        };
        if quote != b'"' && quote != b'\'' {
            pos = quote_at;
            continue;
        }
        let key_start = quote_at + 1;
        let Some(key_len) = body[key_start..].find(char::from(quote)) else {
            break;
        };
        let key = &body[key_start..key_start + key_len];
        let after_key = key_start + key_len + 1;
        let Some(gt) = body[after_key..].find('>') else {
            break;
        };
        let value_start = after_key + gt + 1;
        let Some(value_len) = body[value_start..].find(CLOSE) else {
            break;
        };
        if !KNOWN_TOOL_NAMES.iter().any(|&n| n == key) {
            let value = body[value_start..value_start + value_len].trim();
            params.push((key.to_string(), value.to_string()));
        }
        pos = value_start + value_len + CLOSE.len();
    }
    params
}
fn coerce_xml_param_value(raw: &str) -> serde_json::Value {
let trimmed = raw.trim();
if let Ok(n) = trimmed.parse::<i64>() {
return serde_json::Value::from(n);
}
if let Ok(n) = trimmed.parse::<f64>() {
if n.is_finite() {
return serde_json::json!(n);
}
}
match trimmed.to_ascii_lowercase().as_str() {
"true" | "yes" => return serde_json::Value::Bool(true),
"false" | "no" => return serde_json::Value::Bool(false),
_ => {}
}
serde_json::Value::String(trimmed.to_string())
}
/// Extends stripping to `<qwen:tool_call>` / `<function_calls>` wrappers that
/// enclose at least one existing strip range.
///
/// A wrapper qualifies when some range in `strip_ranges` *starts* inside
/// `open..=close`. Each qualifying wrapper's full span is appended after the
/// existing ranges. Scanning for a wrapper kind stops at its first unclosed
/// opener.
fn widen_strip_to_known_wrappers(text: &str, strip_ranges: &mut Vec<(usize, usize)>) {
    const WRAPPERS: [(&str, &str); 2] = [
        ("<qwen:tool_call>", "</qwen:tool_call>"),
        ("<function_calls>", "</function_calls>"),
    ];
    let mut widened: Vec<(usize, usize)> = Vec::new();
    for &(open, close) in &WRAPPERS {
        let mut from = 0;
        while let Some(rel) = text[from..].find(open) {
            let open_pos = from + rel;
            let body_start = open_pos + open.len();
            let Some(close_rel) = text[body_start..].find(close) else {
                break;
            };
            let close_end = body_start + close_rel + close.len();
            // Only ranges recorded before this call count, not earlier additions.
            let wraps_existing = strip_ranges
                .iter()
                .any(|&(s, _)| (open_pos..close_end).contains(&s));
            if wraps_existing {
                widened.push((open_pos, close_end));
            }
            from = close_end;
        }
    }
    strip_ranges.extend(widened);
}
/// Returns the byte length of the balanced JSON object or array that starts
/// at the very beginning of `s`, or `None` if `s` does not start with `{`/`[`
/// or the value never closes.
///
/// Only the bracket kind of the outermost value is counted. String literals
/// (including backslash escapes) are skipped so brackets inside them are
/// ignored. The slice is not validated as JSON.
pub(crate) fn extract_balanced_json(s: &str) -> Option<usize> {
    let bytes = s.as_bytes();
    let (open, close) = match *bytes.first()? {
        b'{' => (b'{', b'}'),
        b'[' => (b'[', b']'),
        _ => return None,
    };
    let mut depth: i32 = 0;
    let mut in_string = false;
    let mut escaped = false;
    for (idx, &byte) in bytes.iter().enumerate() {
        if in_string {
            if escaped {
                escaped = false;
            } else if byte == b'\\' {
                escaped = true;
            } else if byte == b'"' {
                in_string = false;
            }
            continue;
        }
        match byte {
            b'"' => in_string = true,
            b if b == open => depth += 1,
            b if b == close => {
                depth -= 1;
                if depth == 0 {
                    return Some(idx + 1);
                }
            }
            _ => {}
        }
    }
    None
}
/// Parses a JSON object such as `{"name": …, "arguments": …}` into a tool call,
/// returning `None` if it is not valid JSON or has no usable name.
fn parse_qwen_tool_json(json: &str) -> Option<(String, serde_json::Value)> {
    serde_json::from_str::<serde_json::Value>(json)
        .ok()
        .and_then(|value| parse_tool_call_value(&value))
}
/// Extracts tool calls from `<tool_call_list>…</tool_call_list>` blocks.
///
/// Every top-level JSON object or array inside a block is parsed. Arrays
/// contribute each item that parses as a call, and objects contribute
/// themselves. An unclosed block ends at the next `<tool_call_list>` or at the
/// end of text. Each call is returned with the `(block_start, block_end)` of
/// its enclosing block, so several calls may share one range.
fn extract_tool_call_list_calls(text: &str) -> Vec<(usize, usize, String, serde_json::Value)> {
    const OPEN: &str = "<tool_call_list>";
    const CLOSE: &str = "</tool_call_list>";
    let mut calls = Vec::new();
    let mut pos = 0;
    while let Some(rel) = text[pos..].find(OPEN) {
        let block_start = pos + rel;
        let inner_start = block_start + OPEN.len();
        let (inner_end, block_end) = if let Some(r) = text[inner_start..].find(CLOSE) {
            (inner_start + r, inner_start + r + CLOSE.len())
        } else {
            let stop = text[inner_start..]
                .find(OPEN)
                .map_or(text.len(), |r| inner_start + r);
            (stop, stop)
        };
        let inner = &text[inner_start..inner_end];
        let mut push_call = |value: &serde_json::Value| {
            if let Some((name, input)) = parse_tool_call_value(value) {
                calls.push((block_start, block_end, name, input));
            }
        };
        let mut offset = 0;
        while offset < inner.len() {
            let byte = inner.as_bytes()[offset];
            // `{` and `[` are ASCII, so `offset` is a char boundary here.
            let consumed = if byte == b'{' || byte == b'[' {
                extract_balanced_json(&inner[offset..])
            } else {
                None
            };
            let Some(consumed) = consumed else {
                offset += 1;
                continue;
            };
            if let Ok(value) =
                serde_json::from_str::<serde_json::Value>(&inner[offset..offset + consumed])
            {
                match value.as_array() {
                    Some(items) => items.iter().for_each(&mut push_call),
                    None => push_call(&value),
                }
            }
            offset += consumed.max(1);
        }
        pos = block_end.max(inner_start);
    }
    calls
}
/// Interprets a JSON value as a `(tool_name, input)` pair.
///
/// The name is taken from `name`, `tool_name`, `function.name`, or a string
/// `function`, in that order. An empty name yields `None`. The input is taken
/// from the first of `arguments`, `args`, `input`, `tool_input`,
/// `parameters`, `function.arguments` or `function.parameters`. A string
/// input is parsed as JSON, falling back to `{}` if it is not valid JSON.
/// When no input key exists, the input is `{}`.
fn parse_tool_call_value(v: &serde_json::Value) -> Option<(String, serde_json::Value)> {
    const ARG_KEYS: [&str; 5] = ["arguments", "args", "input", "tool_input", "parameters"];
    let function = v.get("function");
    let name = v
        .get("name")
        .and_then(serde_json::Value::as_str)
        .or_else(|| v.get("tool_name").and_then(serde_json::Value::as_str))
        .or_else(|| {
            function
                .and_then(|f| f.get("name"))
                .and_then(serde_json::Value::as_str)
        })
        .or_else(|| function.and_then(serde_json::Value::as_str))?;
    if name.is_empty() {
        return None;
    }
    let args = ARG_KEYS
        .iter()
        .find_map(|key| v.get(*key))
        .or_else(|| function.and_then(|f| f.get("arguments")))
        .or_else(|| function.and_then(|f| f.get("parameters")));
    let input = match args {
        Some(serde_json::Value::String(encoded)) => {
            serde_json::from_str(encoded).unwrap_or_else(|_| serde_json::json!({}))
        }
        Some(other) => other.clone(),
        None => serde_json::json!({}),
    };
    Some((name.to_string(), input))
}
/// Span finder used for `{"tool_call": …}` envelopes that may be malformed.
///
/// Currently identical to `extract_balanced_json`. The separate name marks
/// the call site where a more forgiving scanner could be substituted.
fn extract_balanced_json_tolerant(s: &str) -> Option<usize> {
    extract_balanced_json(s)
}
/// Recovers a tool call from a `{"tool_call": …}`-style envelope that may not
/// be valid JSON.
///
/// If `env` parses, the `tool_call` (or `function`) member, or else the whole
/// value, is handed to `parse_tool_call_value`. Otherwise a regex fallback is
/// used:
/// - The name is the first `"name": "…"`. The colon is optional to tolerate
///   missing punctuation.
/// - Arguments are the scalar `"key": value` pairs after the name.
/// - Envelope keys (`name`, `tool_call`, `type`, `id`, `function`) are skipped.
///
/// String values are decoded as JSON string literals, so `\n`, `\\`, `\t` and
/// `\uXXXX` are unescaped. Tokens that serde_json rejects fall back to
/// unescaping only `\"`.
fn recover_tool_call_from_malformed_json(env: &str) -> Option<(String, serde_json::Value)> {
    if let Ok(v) = serde_json::from_str::<serde_json::Value>(env) {
        let inner = v
            .get("tool_call")
            .or_else(|| v.get("function"))
            .cloned()
            .unwrap_or(v);
        if let Some(call) = parse_tool_call_value(&inner) {
            return Some(call);
        }
    }
    use regex::Regex;
    use std::sync::LazyLock;
    static NAME_RE: LazyLock<Regex> =
        LazyLock::new(|| Regex::new(r#""name"\s*:?\s*"([^"]+)""#).unwrap());
    static ARG_RE: LazyLock<Regex> = LazyLock::new(|| {
        Regex::new(
            r#""([a-zA-Z_][a-zA-Z0-9_]*)"\s*:?\s*("([^"\\]|\\.)*"|true|false|-?\d+(\.\d+)?)"#,
        )
        .unwrap()
    });
    let name_cap = NAME_RE.captures(env)?;
    let name = name_cap.get(1)?.as_str().to_string();
    if name.is_empty() {
        return None;
    }
    let name_end = name_cap.get(0)?.end();
    let args_region = &env[name_end..];
    let mut map = serde_json::Map::new();
    for cap in ARG_RE.captures_iter(args_region) {
        let key = cap.get(1).map(|m| m.as_str().to_string());
        let raw = cap.get(2).map(|m| m.as_str().to_string());
        if let (Some(k), Some(r)) = (key, raw) {
            if matches!(
                k.as_str(),
                "name" | "tool_call" | "type" | "id" | "function"
            ) {
                continue;
            }
            let val = if let Some(stripped) = r.strip_prefix('"').and_then(|s| s.strip_suffix('"'))
            {
                // Decode the full JSON string literal. Raw control characters
                // or invalid escapes make serde_json reject it, in which case
                // only `\"` is unescaped, as before.
                let decoded = serde_json::from_str::<String>(&r)
                    .unwrap_or_else(|_| stripped.replace("\\\"", "\""));
                serde_json::Value::String(decoded)
            } else if r == "true" {
                serde_json::Value::Bool(true)
            } else if r == "false" {
                serde_json::Value::Bool(false)
            } else if let Ok(n) = r.parse::<i64>() {
                serde_json::Value::Number(n.into())
            } else if let Ok(f) = r.parse::<f64>()
                && let Some(n) = serde_json::Number::from_f64(f)
            {
                serde_json::Value::Number(n)
            } else {
                continue;
            };
            map.insert(k, val);
        }
    }
    Some((name, serde_json::Value::Object(map)))
}
/// Parses `<parameter=key>value</parameter>` entries into a JSON object of
/// string values.
///
/// A value ends at `</parameter>` (which is consumed), at the next
/// `<parameter=` (left for the next entry), or at the end of `body`. Keys and
/// values are trimmed, entries with empty keys are skipped, and later
/// duplicates overwrite earlier ones.
fn parse_function_params(body: &str) -> serde_json::Value {
    const OPEN: &str = "<parameter=";
    const CLOSE: &str = "</parameter>";
    let mut params = serde_json::Map::new();
    let mut pos = 0usize;
    while let Some(rel) = body[pos..].find(OPEN) {
        let tag_start = pos + rel;
        let Some(gt) = body[tag_start..].find('>') else {
            break;
        };
        let tag_end = tag_start + gt;
        let value_start = tag_end + 1;
        let key = body[tag_start + OPEN.len()..tag_end].trim();
        if key.is_empty() {
            pos = value_start;
            continue;
        }
        let tail = &body[value_start..];
        let close_at = tail.find(CLOSE);
        let next_at = tail.find(OPEN);
        let stop = match (close_at, next_at) {
            (Some(a), Some(b)) => Some(a.min(b)),
            (a, b) => a.or(b),
        };
        let (value, next_pos) = match stop {
            None => (tail.trim(), body.len()),
            Some(rel) if close_at == Some(rel) => {
                (tail[..rel].trim(), value_start + rel + CLOSE.len())
            }
            Some(rel) => (tail[..rel].trim(), value_start + rel),
        };
        params.insert(key.to_string(), serde_json::Value::String(value.to_string()));
        pos = next_pos;
    }
    serde_json::Value::Object(params)
}
/// Produces the bearer token for each request. An empty string means no
/// `Authorization` header is attached.
pub type TokenFn = Arc<dyn Fn() -> String + Sync + Send>;

/// Rewrites the serialized JSON request body just before it is sent.
pub type BodyTransformFn =
    Arc<dyn Fn(serde_json::Value) -> serde_json::Value + Sync + Send>;

/// Returns the URL to POST to. An empty string falls back to the configured base URL.
pub type BaseUrlFn = Arc<dyn Fn() -> String + Sync + Send>;

/// Asynchronously refreshes credentials after an authentication error.
/// Resolves to `Err(reason)` when the refresh fails.
pub type AuthRefreshFn = Arc<
    dyn Fn() -> std::pin::Pin<
            Box<dyn std::future::Future<Output = std::result::Result<(), String>> + Send>,
        > + Sync
        + Send,
>;

/// Invalidates the account after a refresh failure that looks permanent
/// (a refresh error mentioning `HTTP 400`).
pub type AuthInvalidateFn = Arc<dyn Fn() + Sync + Send>;
/// Chat-completions client for OpenAI and OpenAI-compatible HTTP endpoints.
///
/// Construct it with `new`, `local`, or `with_base_url`, then customise it with the
/// `with_*` builder methods.
#[derive(Clone)]
pub struct OpenAIProvider {
/// Static API key. The placeholder `"not-needed"` (set by `local`) suppresses the
/// `Authorization` header, and `token_fn` takes precedence when set.
api_key: String,
/// Endpoint that requests are POSTed to (unless `base_url_fn` supplies one).
/// It is also inspected to detect OpenRouter.
base_url: String,
/// HTTP client built with the module's `DEFAULT_*` timeout/keepalive constants.
client: Client,
/// Model set via `with_default_model`. Not read in the code shown here.
custom_default_model: Option<String>,
/// Provider name used in logs and retry notices. `"qwen"` also selects the
/// qwen retry policy in `retry_config`.
name: String,
/// Optional vision model name, returned by `vision_model()`.
vision_model: Option<String>,
/// Extra headers added to every request. Entries that fail to parse are skipped.
pub(crate) extra_headers: Vec<(String, String)>,
/// Context-window size set via `with_context_window`. Not read in the code shown here.
configured_context_window: Option<u32>,
/// Model list set via `with_models`. Not read in the code shown here.
configured_models: Vec<String>,
/// Per-request bearer-token source. An empty token means no `Authorization` header.
token_fn: Option<TokenFn>,
/// Optional limiter awaited once before each `complete`/`stream` request.
rate_limiter: Option<Arc<RateLimiter>>,
/// Hook applied to the serialized JSON body before sending (see `encode_body`).
body_transform: Option<BodyTransformFn>,
/// Dynamic URL source. An empty result falls back to `base_url` (see `send_url`).
base_url_fn: Option<BaseUrlFn>,
/// Called after an auth error (401/403/invalid key) to refresh credentials before retrying.
auth_refresh_fn: Option<AuthRefreshFn>,
/// Called when a refresh fails with an `HTTP 400` message, to invalidate the account.
auth_invalidate_fn: Option<AuthInvalidateFn>,
/// Replaces the name/model-based policy chosen by `retry_config`.
retry_config_override: Option<crate::utils::retry::RetryConfig>,
/// Sends `X-OpenRouter-Cache: true` when talking to OpenRouter.
cache_enabled: bool,
/// Value for `X-OpenRouter-Cache-TTL`. Only sent when `cache_enabled` is set.
cache_ttl: Option<u32>,
/// Retry attempts recorded by `stream` as `(attempt, max, description)`.
retry_notices: Arc<std::sync::Mutex<Vec<(u32, u32, String)>>>,
}
impl OpenAIProvider {
/// Whether the configured `base_url` points at OpenRouter (case-insensitive substring match).
fn is_openrouter(&self) -> bool {
    let endpoint = self.base_url.to_lowercase();
    endpoint.contains("openrouter")
}
/// Retry policy for a request to `model`.
///
/// Resolution order:
/// 1. the explicit override, if set;
/// 2. otherwise `RetryConfig::qwen_cli_match()` for the `qwen` provider, OpenRouter
///    endpoints, or `:free` models;
/// 3. otherwise `RetryConfig::default()`.
fn retry_config(&self, model: &str) -> crate::utils::retry::RetryConfig {
    use crate::utils::retry::RetryConfig;
    match self.retry_config_override {
        Some(ref explicit) => explicit.clone(),
        None => {
            let use_qwen_policy =
                self.name == "qwen" || self.is_openrouter() || model.ends_with(":free");
            if use_qwen_policy {
                RetryConfig::qwen_cli_match()
            } else {
                RetryConfig::default()
            }
        }
    }
}
pub fn new(api_key: String) -> Self {
let client = Client::builder()
.timeout(DEFAULT_TIMEOUT)
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
.pool_idle_timeout(DEFAULT_POOL_IDLE_TIMEOUT)
.pool_max_idle_per_host(2)
.tcp_keepalive(DEFAULT_TCP_KEEPALIVE)
.build()
.expect("Failed to create HTTP client");
Self {
api_key,
base_url: DEFAULT_OPENAI_API_URL.to_string(),
client,
custom_default_model: None,
name: "openai".to_string(),
vision_model: None,
extra_headers: vec![],
configured_context_window: None,
configured_models: Vec::new(),
token_fn: None,
rate_limiter: None,
body_transform: None,
base_url_fn: None,
auth_refresh_fn: None,
auth_invalidate_fn: None,
retry_config_override: None,
cache_enabled: false,
cache_ttl: None,
retry_notices: Arc::new(std::sync::Mutex::new(Vec::new())),
}
}
pub fn local(base_url: String) -> Self {
let client = Client::builder()
.timeout(DEFAULT_TIMEOUT)
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
.pool_idle_timeout(DEFAULT_POOL_IDLE_TIMEOUT)
.pool_max_idle_per_host(2)
.tcp_keepalive(DEFAULT_TCP_KEEPALIVE)
.build()
.expect("Failed to create HTTP client");
Self {
api_key: "not-needed".to_string(),
base_url,
client,
custom_default_model: None,
name: "openai-compatible".to_string(),
vision_model: None,
extra_headers: vec![],
configured_context_window: None,
configured_models: Vec::new(),
token_fn: None,
rate_limiter: None,
body_transform: None,
base_url_fn: None,
auth_refresh_fn: None,
auth_invalidate_fn: None,
retry_config_override: None,
cache_enabled: false,
cache_ttl: None,
retry_notices: Arc::new(std::sync::Mutex::new(Vec::new())),
}
}
pub fn with_base_url(api_key: String, base_url: String) -> Self {
let client = Client::builder()
.timeout(DEFAULT_TIMEOUT)
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
.pool_idle_timeout(DEFAULT_POOL_IDLE_TIMEOUT)
.pool_max_idle_per_host(2)
.tcp_keepalive(DEFAULT_TCP_KEEPALIVE)
.build()
.expect("Failed to create HTTP client");
Self {
api_key,
base_url,
client,
custom_default_model: None,
name: "openai-compatible".to_string(),
vision_model: None,
extra_headers: vec![],
configured_context_window: None,
configured_models: Vec::new(),
token_fn: None,
rate_limiter: None,
body_transform: None,
base_url_fn: None,
auth_refresh_fn: None,
auth_invalidate_fn: None,
retry_config_override: None,
cache_enabled: false,
cache_ttl: None,
retry_notices: Arc::new(std::sync::Mutex::new(Vec::new())),
}
}
pub fn with_extra_headers(mut self, headers: Vec<(String, String)>) -> Self {
self.extra_headers = headers;
self
}
pub fn with_context_window(mut self, size: u32) -> Self {
self.configured_context_window = Some(size);
self
}
pub fn with_models(mut self, models: Vec<String>) -> Self {
self.configured_models = models;
self
}
pub fn with_name(mut self, name: &str) -> Self {
self.name = name.to_string();
self
}
pub fn with_default_model(mut self, model: String) -> Self {
self.custom_default_model = Some(model);
self
}
pub fn with_token_fn(mut self, f: TokenFn) -> Self {
self.token_fn = Some(f);
self
}
pub fn with_vision_model(mut self, model: String) -> Self {
self.vision_model = Some(model);
self
}
pub fn with_rate_limiter(mut self, limiter: Arc<RateLimiter>) -> Self {
self.rate_limiter = Some(limiter);
self
}
pub fn with_body_transform(mut self, f: BodyTransformFn) -> Self {
self.body_transform = Some(f);
self
}
pub fn with_base_url_fn(mut self, f: BaseUrlFn) -> Self {
self.base_url_fn = Some(f);
self
}
pub fn with_auth_refresh_fn(mut self, f: AuthRefreshFn) -> Self {
self.auth_refresh_fn = Some(f);
self
}
pub fn with_auth_invalidate_fn(mut self, f: AuthInvalidateFn) -> Self {
self.auth_invalidate_fn = Some(f);
self
}
pub fn with_retry_config(mut self, config: crate::utils::retry::RetryConfig) -> Self {
self.retry_config_override = Some(config);
self
}
pub fn with_cache_enabled(mut self, enabled: bool) -> Self {
self.cache_enabled = enabled;
self
}
pub fn with_cache_ttl(mut self, ttl: u32) -> Self {
self.cache_ttl = Some(ttl);
self
}
/// Serialize `body` to JSON, then apply the optional body-transform hook.
///
/// # Errors
/// Propagates the serialization error (converted into a `ProviderError`).
fn encode_body<T: Serialize>(&self, body: &T) -> Result<serde_json::Value> {
    let encoded = serde_json::to_value(body)?;
    let transformed = match &self.body_transform {
        Some(transform) => transform(encoded),
        None => encoded,
    };
    Ok(transformed)
}
/// URL to POST the request to.
///
/// Uses the dynamic URL hook's result when the hook is set and returns a non-empty
/// string; otherwise returns the configured `base_url`.
fn send_url(&self) -> String {
    self.base_url_fn
        .as_ref()
        .map(|resolve| resolve())
        .filter(|url| !url.is_empty())
        .unwrap_or_else(|| self.base_url.clone())
}
/// Whether `err` signals rejected credentials: `InvalidApiKey`, or an API error
/// with status 401 or 403.
fn is_auth_error(err: &ProviderError) -> bool {
    match err {
        ProviderError::InvalidApiKey => true,
        ProviderError::ApiError { status, .. } => *status == 401 || *status == 403,
        _ => false,
    }
}
/// The configured vision model name, if any.
pub fn vision_model(&self) -> Option<&str> {
    Option::as_deref(&self.vision_model)
}
fn headers(&self) -> std::result::Result<reqwest::header::HeaderMap, ProviderError> {
let mut headers = reqwest::header::HeaderMap::new();
let bearer_key = if let Some(ref f) = self.token_fn {
let token = f();
if token.is_empty() { None } else { Some(token) }
} else if self.api_key != "not-needed" {
Some(self.api_key.trim().to_string())
} else {
None
};
if let Some(key) = bearer_key {
let header_value: reqwest::header::HeaderValue =
format!("Bearer {}", key).parse().map_err(|_| {
tracing::error!(
"API key contains invalid characters (length={}). Check keys.toml.",
key.len()
);
ProviderError::InvalidApiKey
})?;
headers.insert(reqwest::header::AUTHORIZATION, header_value);
}
headers.insert(
reqwest::header::CONTENT_TYPE,
"application/json".parse().expect("valid content-type"),
);
headers.insert(
reqwest::header::ACCEPT,
"application/json".parse().expect("valid accept"),
);
if self.base_url.to_lowercase().contains("openrouter") {
if let (Ok(k1), Ok(v1)) = (
"HTTP-Referer".parse::<reqwest::header::HeaderName>(),
"https://opencrabs.com".parse::<reqwest::header::HeaderValue>(),
) {
headers.insert(k1, v1);
}
if let (Ok(k2), Ok(v2)) = (
"X-Title".parse::<reqwest::header::HeaderName>(),
"OpenCrabs".parse::<reqwest::header::HeaderValue>(),
) {
headers.insert(k2, v2);
}
if let (Ok(k3), Ok(v3)) = (
"X-OpenRouter-Category".parse::<reqwest::header::HeaderName>(),
"cli-agent,personal-agent,programming-app".parse::<reqwest::header::HeaderValue>(),
) {
headers.insert(k3, v3);
}
if self.cache_enabled {
if let (Ok(k), Ok(v)) = (
"X-OpenRouter-Cache".parse::<reqwest::header::HeaderName>(),
"true".parse::<reqwest::header::HeaderValue>(),
) {
headers.insert(k, v);
}
if let Some(ttl) = self.cache_ttl
&& let (Ok(k), Ok(v)) = (
"X-OpenRouter-Cache-TTL".parse::<reqwest::header::HeaderName>(),
ttl.to_string().parse::<reqwest::header::HeaderValue>(),
)
{
headers.insert(k, v);
}
}
tracing::debug!("OpenRouter optimization headers attached");
}
for (key, value) in &self.extra_headers {
if let (Ok(k), Ok(v)) = (
key.parse::<reqwest::header::HeaderName>(),
value.parse::<reqwest::header::HeaderValue>(),
) {
headers.insert(k, v);
}
}
Ok(headers)
}
pub(crate) fn to_openai_request(&self, request: LLMRequest) -> OpenAIRequest {
let mut messages = Vec::new();
if let Some(ref system) = request.system {
tracing::debug!("System brain present: {} chars", system.len());
} else {
tracing::warn!("NO SYSTEM BRAIN in request!");
}
if let Some(system) = request.system {
messages.push(OpenAIMessage {
role: "system".to_string(),
content: Some(serde_json::Value::String(system)),
tool_calls: None,
tool_call_id: None,
reasoning_content: None,
});
}
let needs_reasoning_content = needs_reasoning_content_for(&self.base_url, &request.model);
for msg in request.messages {
let role = match msg.role {
Role::User => "user",
Role::Assistant => "assistant",
Role::System => "system",
};
let mut text_parts = Vec::new();
let mut image_parts: Vec<serde_json::Value> = Vec::new();
let mut tool_uses = Vec::new();
let mut tool_results = Vec::new();
let mut thinking_parts: Vec<String> = Vec::new();
for block in msg.content {
match block {
ContentBlock::Text { text } => {
text_parts.push(text);
}
ContentBlock::ToolUse { id, name, input } => {
tool_uses.push((id, name, input));
}
ContentBlock::ToolResult {
tool_use_id,
content,
..
} => {
tool_results.push((tool_use_id, content));
}
ContentBlock::Thinking { thinking, .. } => {
if !thinking.is_empty() {
thinking_parts.push(thinking);
}
}
ContentBlock::Image { source } => {
let url = match source {
ImageSource::Base64 { media_type, data } => {
format!("data:{};base64,{}", media_type, data)
}
ImageSource::Url { url } => url,
};
image_parts.push(serde_json::json!({
"type": "image_url",
"image_url": { "url": url }
}));
}
}
}
let make_content =
|texts: &[String], images: &[serde_json::Value]| -> Option<serde_json::Value> {
if !images.is_empty() {
let mut parts: Vec<serde_json::Value> = Vec::new();
if !texts.is_empty() {
parts.push(serde_json::json!({
"type": "text",
"text": texts.join("\n")
}));
}
parts.extend(images.iter().cloned());
Some(serde_json::Value::Array(parts))
} else if !texts.is_empty() {
Some(serde_json::Value::String(texts.join("\n")))
} else {
None
}
};
if !tool_uses.is_empty() {
let openai_tool_calls = tool_uses
.into_iter()
.map(|(id, name, input)| OpenAIToolCall {
id,
r#type: "function".to_string(),
function: OpenAIFunctionCall {
name,
arguments: serde_json::to_string(&input).unwrap_or_default(),
},
})
.collect();
let content_val = make_content(&text_parts, &image_parts);
let reasoning_content = if !thinking_parts.is_empty() {
Some(thinking_parts.join("\n"))
} else if needs_reasoning_content {
Some(" ".to_string())
} else {
None
};
messages.push(OpenAIMessage {
role: role.to_string(),
content: content_val,
tool_calls: Some(openai_tool_calls),
tool_call_id: None,
reasoning_content,
});
}
else if !tool_results.is_empty() {
for (tool_use_id, content) in tool_results {
messages.push(OpenAIMessage {
role: "tool".to_string(),
content: Some(serde_json::Value::String(content)),
tool_calls: None,
tool_call_id: Some(tool_use_id),
reasoning_content: None,
});
}
}
else {
let content_val = make_content(&text_parts, &image_parts)
.unwrap_or(serde_json::Value::String(String::new()));
let reasoning_content = if role == "assistant" && !thinking_parts.is_empty() {
Some(thinking_parts.join("\n"))
} else {
None
};
messages.push(OpenAIMessage {
role: role.to_string(),
content: Some(content_val),
tool_calls: None,
tool_call_id: None,
reasoning_content,
});
}
}
let tools: Option<Vec<OpenAITool>> = request.tools.map(|tools| {
tools
.iter()
.map(|tool| OpenAITool {
r#type: "function".to_string(),
function: OpenAIFunction {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.input_schema.clone(),
},
})
.collect()
});
let uses_completion_tokens = uses_max_completion_tokens(&request.model);
let (max_tokens, max_completion_tokens) = if uses_completion_tokens {
(None, request.max_tokens)
} else {
(request.max_tokens, None)
};
let tool_choice = tools
.as_ref()
.filter(|t| !t.is_empty())
.map(|_| serde_json::json!("auto"));
let base = self.base_url.to_lowercase();
let include_reasoning = if base.contains("openrouter")
|| base.contains("openrouter.ai")
|| std::env::var("OPENCRABS_ENABLE_REASONING").is_ok()
{
Some(true)
} else {
None
};
OpenAIRequest {
model: request.model,
messages,
temperature: request.temperature,
max_tokens,
max_completion_tokens,
stream: Some(request.stream),
stream_options: None,
tools,
tool_choice,
include_reasoning,
}
}
/// Convert a non-streaming chat-completions response into an [`LLMResponse`].
///
/// Choice and text handling:
/// * Only the first choice is used. When the response has no choices, an empty
///   assistant message with finish reason `"error"` stands in, which maps to no
///   stop reason.
/// * Text content (a string, or the `text` parts of a content array) is passed
///   through `strip_think_blocks`.
///
/// Tool-call handling:
/// * If the message carries no structured `tool_calls`, tool calls written into the
///   text are recovered with `extract_text_tool_calls`. They become `ToolUse` blocks
///   with ids `call_text_0`, `call_text_1`, … placed after the remaining text.
/// * Structured tool-call arguments are parsed with `json_repair::parse_or_repair`.
/// * If some text still looks like tool-call JSON but no `ToolUse` block exists, a
///   warning text block is appended.
// Named like a constructor (`from_*`) yet takes `&self`, which clippy's
// `wrong_self_convention` lint would otherwise flag.
#[allow(clippy::wrong_self_convention)]
fn from_openai_response(&self, response: OpenAIResponse) -> LLMResponse {
// Fall back to an empty assistant message when `choices` is empty.
let choice = response
.choices
.into_iter()
.next()
.unwrap_or_else(|| OpenAIChoice {
message: OpenAIMessage {
role: "assistant".to_string(),
content: Some(serde_json::Value::String(String::new())),
tool_calls: None,
tool_call_id: None,
reasoning_content: None,
},
finish_reason: Some("error".to_string()),
});
let mut content_blocks = Vec::new();
if let Some(ref content_val) = choice.message.content {
// Collapse string or multimodal-array content into one string; non-text parts are dropped.
let text = match content_val {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(parts) => {
parts
.iter()
.filter_map(|p| {
if p.get("type")?.as_str()? == "text" {
p.get("text")?.as_str().map(String::from)
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n")
}
_ => String::new(),
};
if !text.is_empty() {
// Remove reasoning blocks from the visible text; emit nothing if only reasoning remained.
let clean = strip_think_blocks(&text);
if !clean.is_empty() {
content_blocks.push(ContentBlock::Text { text: clean });
}
}
}
// Fallback for models that write tool calls into the text instead of `tool_calls`.
let structured_tool_calls_present = choice
.message
.tool_calls
.as_ref()
.map(|v| !v.is_empty())
.unwrap_or(false);
if !structured_tool_calls_present {
let mut recovered: Vec<ContentBlock> = Vec::new();
for block in content_blocks.iter_mut() {
if let ContentBlock::Text { text } = block {
let (calls, cleaned) = extract_text_tool_calls(text);
if calls.is_empty() {
continue;
}
tracing::info!(
"Recovered {} tool call(s) from text content (local-model fallback)",
calls.len()
);
// Keep only the text left over after removing the recovered calls.
*text = cleaned;
for (idx, (name, input)) in calls.into_iter().enumerate() {
recovered.push(ContentBlock::ToolUse {
id: format!("call_text_{}", idx),
name,
input,
});
}
}
}
// Drop text blocks that the extraction left blank.
content_blocks.retain(|b| match b {
ContentBlock::Text { text } => !text.trim().is_empty(),
_ => true,
});
content_blocks.extend(recovered);
}
if let Some(tool_calls) = choice.message.tool_calls {
tracing::debug!(
"Converting {} tool calls from OpenAI response",
tool_calls.len()
);
for tool_call in tool_calls {
// Arguments arrive as a JSON string; malformed JSON is repaired rather than rejected.
let input = super::json_repair::parse_or_repair(&tool_call.function.arguments);
tracing::debug!(
"Converted tool call: {} with id {}",
tool_call.function.name,
tool_call.id
);
content_blocks.push(ContentBlock::ToolUse {
id: tool_call.id,
name: tool_call.function.name,
input,
});
}
}
// Heuristic: text that still looks like tool-call JSON, with no ToolUse block at all,
// suggests the model lacks function-calling support.
let has_tool_text = content_blocks.iter().any(|b| {
if let ContentBlock::Text { text } = b {
(text.contains("\"function\"") && text.contains("\"arguments\""))
|| (text.contains("tool_call") && text.contains("\"name\""))
|| (text.contains("```json") && text.contains("\"command\""))
} else {
false
}
});
let has_structured_tools = content_blocks
.iter()
.any(|b| matches!(b, ContentBlock::ToolUse { .. }));
if has_tool_text && !has_structured_tools {
tracing::warn!(
"Model returned tool call JSON as text — likely does not support function calling"
);
content_blocks.push(ContentBlock::Text {
text: "\n\n⚠️ **This model does not support function calling.** Tool requests were returned as text instead of executable calls. Consider switching to a model that supports tool use (e.g. Claude, GPT-4, Gemini).".to_string(),
});
}
// Unrecognised finish reasons (including the `"error"` placeholder) map to `None`.
let stop_reason = choice
.finish_reason
.and_then(|reason| match reason.as_str() {
"stop" => Some(StopReason::EndTurn),
"length" => Some(StopReason::MaxTokens),
"tool_calls" | "function_call" => Some(StopReason::ToolUse),
_ => None,
});
LLMResponse {
id: response.id,
model: response.model,
content: content_blocks,
stop_reason,
// Missing usage counters default to 0.
usage: TokenUsage {
input_tokens: response.usage.prompt_tokens.unwrap_or(0),
output_tokens: response.usage.completion_tokens.unwrap_or(0),
cache_creation_tokens: response.usage.cache_creation_input_tokens.unwrap_or(0),
cache_read_tokens: response.usage.effective_cache_read(),
..Default::default()
},
streaming_active_secs: None,
}
}
/// Convert a non-success HTTP response into a [`ProviderError`], consuming its body.
///
/// Interpretation order:
/// 1. An OpenAI-style error body, unwrapped via `unwrap_proxy_error`.
/// 2. For status 422, a Pydantic-style validation body (`detail[].loc` /
///    `detail[].msg`), summarised to at most five field errors. An Unsloth Studio
///    hint is added when the base URL matches.
/// 3. Otherwise, the raw body truncated to 400 characters.
///
/// A 429 always becomes `RateLimitExceeded`, mentioning the `Retry-After` seconds
/// when present.
async fn handle_error(&self, response: reqwest::Response) -> ProviderError {
let status = response.status().as_u16();
// Only the integer-seconds form of `Retry-After` is understood; other values yield `None`.
let retry_after = response.headers().get("retry-after").and_then(|v| {
v.to_str().ok().and_then(|s| {
s.parse::<u64>().ok()
})
});
// Copied for the Unsloth Studio check in the 422 branch below.
let base_url_snapshot = self.base_url.clone();
// An unreadable body is treated as empty.
let body_text = response.text().await.unwrap_or_default();
tracing::warn!(
"[HANDLE_ERROR] {} → {}: raw body (first 1500 chars): {}",
self.name,
status,
body_text.chars().take(1500).collect::<String>(),
);
// 1. OpenAI-style `{"error": {...}}` body (possibly wrapped by a proxy).
if let Ok(error_body) = serde_json::from_str::<OpenAIErrorResponse>(&body_text) {
let (inner_message, inner_type) = unwrap_proxy_error(&error_body.error);
let message = if status == 429 {
if let Some(secs) = retry_after {
format!("{} (retry after {} seconds)", inner_message, secs)
} else {
format!("{} (rate limited, please retry later)", inner_message)
}
} else {
inner_message
};
return if status == 429 {
tracing::warn!("[RATE_LIMIT] {} → {}: {}", self.name, status, message,);
ProviderError::RateLimitExceeded(message)
} else {
ProviderError::ApiError {
status,
message,
// Always Some: an empty string when the upstream error carried no type.
error_type: inner_type.or(Some(String::new())),
}
};
}
// 2. Pydantic-style 422 validation body: summarise the first five field errors.
if status == 422
&& let Ok(pydantic) = serde_json::from_str::<PydanticValidationError>(&body_text)
{
let mut parts: Vec<String> = pydantic
.detail
.iter()
.take(5)
.map(|d| {
// Render the error location path as dotted segments, e.g. `body.messages.0`.
let field = d
.loc
.iter()
.map(|v| match v {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Number(n) => n.to_string(),
_ => "?".to_string(),
})
.collect::<Vec<_>>()
.join(".");
format!("{}: {}", field, d.msg)
})
.collect();
if pydantic.detail.len() > 5 {
parts.push(format!("… +{} more", pydantic.detail.len() - 5));
}
let fields_summary = parts.join("; ");
let unsloth_hint = if is_unsloth_studio_url(&base_url_snapshot) {
" — Unsloth Studio's API is NOT OpenAI-compat (rejects `tools`, `role:\"tool\"`, etc.). \
Find the internal llama-server port with `lsof -iTCP -sTCP:LISTEN | grep llama-server` \
and point base_url at that directly."
} else {
""
};
return ProviderError::ApiError {
status,
message: format!(
"Schema validation failed ({} field(s)): {}{}",
pydantic.detail.len(),
fields_summary,
unsloth_hint,
),
error_type: Some("validation_error".to_string()),
};
}
// 3. Unrecognised body: include a bounded prefix of the raw text.
let fallback_message = if body_text.is_empty() {
"Unknown error".to_string()
} else {
format!(
"HTTP {}: {}",
status,
body_text.chars().take(400).collect::<String>()
)
};
if status == 429 {
let message = if let Some(secs) = retry_after {
format!("Rate limit exceeded (retry after {} seconds)", secs)
} else {
"Rate limit exceeded, please retry later".to_string()
};
ProviderError::RateLimitExceeded(message)
} else {
ProviderError::ApiError {
status,
message: fallback_message,
error_type: None,
}
}
}
}
#[async_trait]
impl Provider for OpenAIProvider {
/// Send a non-streaming chat-completion request, with retries.
///
/// Waits on the rate limiter (if any) once, then sends using the policy from
/// `retry_config(model)`. On failure:
/// * If the error text matches `is_token_field_mismatch`, the request's token-limit
///   fields are swapped (`swap_token_fields`) and the retry loop runs once more.
/// * Otherwise, on an auth error with `auth_refresh_fn` set, credentials are
///   refreshed and the retry loop runs again.
/// * A refresh failure whose message contains `HTTP 400` calls `auth_invalidate_fn`,
///   and the original error is returned.
///
/// The result of any second retry loop is returned as-is.
// NOTE(review): unlike `stream`, these loops use plain `retry`, so nothing is pushed
// to `retry_notices`, and a failure in a second loop gets no further recovery.
async fn complete(&self, request: LLMRequest) -> Result<LLMResponse> {
use crate::utils::retry::retry;
let model = request.model.clone();
let message_count = request.messages.len();
let retry_config = self.retry_config(&model);
// Mutable so the token-field swap below can adjust it in place.
let mut openai_request = self.to_openai_request(request);
let tool_count = openai_request.tools.as_ref().map(|t| t.len()).unwrap_or(0);
tracing::info!(
"OpenAI API request: model={}, messages={}, max_tokens={:?}, max_completion_tokens={:?}, tools={}",
model,
message_count,
openai_request.max_tokens,
openai_request.max_completion_tokens,
tool_count
);
if tool_count == 0 {
tracing::warn!(
"OpenAI request has NO tools - LLM won't know about file/bash operations!"
);
}
// The limiter is awaited once per call, not once per retry attempt.
if let Some(ref limiter) = self.rate_limiter {
let waited = limiter.wait().await;
if !waited.is_zero() {
tracing::debug!(
"Rate limiter: waited {:?} before request to {}",
waited,
self.base_url
);
}
}
// Each attempt re-encodes the body and re-evaluates the URL and header hooks.
let result = retry(
|| async {
tracing::debug!("Sending request to OpenAI API: {}", self.base_url);
let body = self.encode_body(&openai_request)?;
let response = self
.client
.post(self.send_url())
.headers(self.headers()?)
.json(&body)
.send()
.await?;
let status = response.status();
tracing::debug!("OpenAI API response status: {}", status);
if !status.is_success() {
return Err(self.handle_error(response).await);
}
// Read before `response.json()` consumes the response.
let cache_status = response
.headers()
.get("x-openrouter-cache-status")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let openai_response: OpenAIResponse = response.json().await?;
let llm_response = self.from_openai_response(openai_response);
if let Some(ref status) = cache_status {
if status == "HIT" {
tracing::info!("OpenRouter cache HIT — zero tokens billed");
} else {
tracing::debug!("OpenRouter cache MISS");
}
}
tracing::info!(
"OpenAI API response: input_tokens={}, output_tokens={}, stop_reason={:?}",
llm_response.usage.input_tokens,
llm_response.usage.output_tokens,
llm_response.stop_reason
);
Ok(llm_response)
},
&retry_config,
)
.await;
if let Err(ref e) = result {
// Recovery 1: wrong token-limit field for this model; swap fields and retry.
if is_token_field_mismatch(&e.to_string()) {
tracing::warn!(
"Token field mismatch for model {}, retrying with swapped fields",
model
);
openai_request.swap_token_fields();
return retry(
|| async {
let body = self.encode_body(&openai_request)?;
let response = self
.client
.post(self.send_url())
.headers(self.headers()?)
.json(&body)
.send()
.await?;
if !response.status().is_success() {
return Err(self.handle_error(response).await);
}
let openai_response: OpenAIResponse = response.json().await?;
Ok(self.from_openai_response(openai_response))
},
&retry_config,
)
.await;
}
// Recovery 2: credentials rejected; refresh them (if a hook exists) and retry.
if Self::is_auth_error(e)
&& let Some(ref refresh) = self.auth_refresh_fn
{
tracing::warn!("{} auth error — refreshing and retrying", self.name);
match refresh().await {
Ok(()) => {
return retry(
|| async {
let body = self.encode_body(&openai_request)?;
let response = self
.client
.post(self.send_url())
.headers(self.headers()?)
.json(&body)
.send()
.await?;
if !response.status().is_success() {
return Err(self.handle_error(response).await);
}
let openai_response: OpenAIResponse = response.json().await?;
Ok(self.from_openai_response(openai_response))
},
&retry_config,
)
.await;
}
Err(msg) => {
tracing::error!("{} auth refresh failed: {}", self.name, msg);
// An HTTP 400 from the refresh endpoint is treated as a permanently dead refresh token.
if msg.contains("HTTP 400")
&& let Some(ref invalidate) = self.auth_invalidate_fn
{
tracing::warn!(
"{} refresh_token permanently dead — invalidating account",
self.name
);
invalidate();
}
}
}
}
// Reached when no recovery path returned; the original error is propagated below.
tracing::error!("OpenAI API request failed: {}", e);
}
result
}
async fn stream(&self, request: LLMRequest) -> Result<ProviderStream> {
use crate::utils::retry::retry;
let model = request.model.clone();
let message_count = request.messages.len();
if let Some(ref limiter) = self.rate_limiter {
let waited = limiter.wait().await;
if !waited.is_zero() {
tracing::debug!(
"Rate limiter: waited {:?} before streaming request to {}",
waited,
self.base_url
);
}
}
tracing::info!(
"{} streaming request: model={}, messages={}, base_url={}",
self.name(),
model,
message_count,
self.base_url
);
let mut openai_request = self.to_openai_request(request);
openai_request.stream = Some(true);
openai_request.stream_options = Some(StreamOptions {
include_usage: true,
});
let tools_count = openai_request.tools.as_ref().map(|t| t.len()).unwrap_or(0);
let message_tokens: usize = openai_request
.messages
.iter()
.map(|m| {
let content = m
.content
.as_ref()
.map(|v| {
let s = v.as_str().unwrap_or("");
count_message_tokens(s)
})
.unwrap_or(4);
let tool_calls = m
.tool_calls
.as_ref()
.map(|tc| count_tokens(&serde_json::to_string(tc).unwrap_or_default()))
.unwrap_or(0);
content + tool_calls
})
.sum();
let tool_schema_tokens = openai_request
.tools
.as_ref()
.map(|tools| count_tokens(&serde_json::to_string(tools).unwrap_or_default()))
.unwrap_or(0);
let total_input_tokens = message_tokens + tool_schema_tokens;
let context_pct = (total_input_tokens as f32 / 200_000.0 * 100.0).round() as u32;
tracing::debug!(
"OpenAI stream request: ~{} input tokens ({}% of 200k window) — {} messages, {} tool schemas",
total_input_tokens,
context_pct,
openai_request.messages.len(),
tools_count
);
let retry_config = self.retry_config(&model);
let notices = self.retry_notices.clone();
let pname = self.name().to_string();
let mut response = crate::utils::retry::retry_with_notify(
|| async {
let body = self.encode_body(&openai_request)?;
let response = self
.client
.post(self.send_url())
.headers(self.headers()?)
.json(&body)
.send()
.await?;
tracing::debug!("OpenAI response status: {}", response.status());
if !response.status().is_success() {
return Err(self.handle_error(response).await);
}
Ok(response)
},
&retry_config,
|attempt, max, err| {
if let Ok(mut v) = notices.lock() {
v.push((attempt, max, format!("{} — {}", pname, retry_reason(err))));
}
},
)
.await;
if let Err(ref e) = response
&& is_token_field_mismatch(&e.to_string())
{
tracing::warn!(
"Token field mismatch for model {} (stream), retrying with swapped fields",
model
);
openai_request.swap_token_fields();
response = retry(
|| async {
let body = self.encode_body(&openai_request)?;
let r = self
.client
.post(self.send_url())
.headers(self.headers()?)
.json(&body)
.send()
.await?;
if !r.status().is_success() {
return Err(self.handle_error(r).await);
}
Ok(r)
},
&retry_config,
)
.await;
}
if let Err(ref e) = response
&& Self::is_auth_error(e)
&& let Some(ref refresh) = self.auth_refresh_fn
{
tracing::warn!("{} stream auth error — refreshing and retrying", self.name);
match refresh().await {
Ok(()) => {
response = retry(
|| async {
let body = self.encode_body(&openai_request)?;
let r = self
.client
.post(self.send_url())
.headers(self.headers()?)
.json(&body)
.send()
.await?;
if !r.status().is_success() {
return Err(self.handle_error(r).await);
}
Ok(r)
},
&retry_config,
)
.await;
}
Err(msg) => {
tracing::error!("{} stream auth refresh failed: {}", self.name, msg);
if msg.contains("HTTP 400")
&& let Some(ref invalidate) = self.auth_invalidate_fn
{
tracing::warn!(
"{} refresh_token permanently dead — invalidating account",
self.name
);
invalidate();
}
}
}
}
let response = response?;
let cache_status = response
.headers()
.get("x-openrouter-cache-status")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
if let Some(ref status) = cache_status {
if status == "HIT" {
tracing::info!("OpenRouter cache HIT (stream) — zero tokens billed");
} else {
tracing::debug!("OpenRouter cache MISS (stream)");
}
}
let byte_stream = response.bytes_stream();
let buffer = std::sync::Arc::new(std::sync::Mutex::new(String::new()));
#[derive(Debug, Clone, Default)]
struct ToolCallAccum {
id: String,
name: String,
arguments: String,
}
struct StreamState {
emitted_message_start: bool,
emitted_content_start: bool,
emitted_content_stop: bool,
seen_delta_content: bool,
tool_calls: std::collections::HashMap<usize, ToolCallAccum>,
inside_think: bool,
active_close_tag: usize,
think_bytes_consumed: usize,
think_carry: String,
pending_stop_reason: Option<crate::brain::provider::types::StopReason>,
leak_probe: String,
leak_active: bool,
response_text_accum: String,
tool_capture_buffer: String,
tool_block_active: bool,
marker_carry: String,
}
let state = std::sync::Arc::new(std::sync::Mutex::new(StreamState {
emitted_message_start: false,
emitted_content_start: false,
emitted_content_stop: false,
seen_delta_content: false,
tool_calls: std::collections::HashMap::new(),
inside_think: false,
active_close_tag: 0,
think_bytes_consumed: 0,
think_carry: String::new(),
pending_stop_reason: None,
leak_probe: String::new(),
leak_active: false,
response_text_accum: String::new(),
tool_capture_buffer: String::new(),
tool_block_active: false,
marker_carry: String::new(),
}));
let event_stream = byte_stream
.map(move |chunk_result| -> Vec<std::result::Result<StreamEvent, ProviderError>> {
match chunk_result {
Err(e) => vec![Err(ProviderError::StreamError(e.to_string()))],
Ok(chunk) => {
let raw_text = String::from_utf8_lossy(&chunk);
tracing::trace!("[STREAM_RAW] SSE chunk: {}", raw_text.chars().take(500).collect::<String>());
if raw_text.contains("tool_calls") {
tracing::trace!("[STREAM_RAW] SSE chunk with tool_calls: {}", raw_text.chars().take(500).collect::<String>());
}
let mut buf = buffer.lock().expect("SSE buffer lock poisoned");
buf.push_str(&raw_text);
let mut events = Vec::new();
let mut st = state.lock().expect("SSE state lock");
while let Some(newline_pos) = buf.find('\n') {
let line = buf[..newline_pos].trim().to_string();
buf.drain(..=newline_pos);
if let Some(json_str) = line.strip_prefix("data: ") {
if json_str == "[DONE]" {
if st.emitted_content_start && !st.emitted_content_stop {
events.push(Ok(StreamEvent::ContentBlockStop { index: 0 }));
st.emitted_content_stop = true;
}
for (_idx, accum) in st.tool_calls.drain() {
let input = super::json_repair::parse_or_repair(&accum.arguments);
tracing::info!(
"[TOOL_EMIT] Flushing tool on DONE: id={}, name={}, args={}",
accum.id, accum.name, &accum.arguments.chars().take(200).collect::<String>()
);
let tool_index = _idx + 1; events.push(Ok(StreamEvent::ContentBlockStart {
index: tool_index,
content_block: ContentBlock::ToolUse {
id: accum.id,
name: accum.name,
input,
},
}));
events.push(Ok(StreamEvent::ContentBlockStop { index: tool_index }));
}
if let Some(stop_reason) = st.pending_stop_reason.take() {
tracing::info!("[STREAM_USAGE] Final usage (fallback on DONE): input={}, output=0", total_input_tokens);
events.push(Ok(StreamEvent::MessageDelta {
delta: crate::brain::provider::types::MessageDelta {
stop_reason: Some(stop_reason),
stop_sequence: None,
},
usage: crate::brain::provider::types::TokenUsage {
input_tokens: total_input_tokens as u32,
output_tokens: 0, ..Default::default() },
}));
}
events.push(Ok(StreamEvent::MessageStop));
continue;
}
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(json_str)
&& let Some(status_msg) = raw.pointer("/base_resp/status_msg").and_then(|v| v.as_str())
{
let status_code = raw.pointer("/base_resp/status_code").and_then(|v| v.as_u64()).unwrap_or(0);
if status_code != 0 {
tracing::error!("[STREAM_ERROR] Provider returned inline error: code={}, msg={}", status_code, status_msg);
events.push(Err(ProviderError::ApiError {
status: status_code as u16,
message: status_msg.to_string(),
error_type: Some("provider_error".to_string()),
}));
continue;
}
}
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(json_str)
&& let Some(err_obj) = raw.get("error").and_then(|v| v.as_object())
{
let message = err_obj.get("message").and_then(|v| v.as_str()).unwrap_or("stream error").to_string();
let err_type = err_obj.get("type").and_then(|v| v.as_str()).map(|s| s.to_string());
let read_status = |key: &str| -> u16 {
err_obj
.get(key)
.and_then(|v| {
v.as_u64().map(|n| n as u16).or_else(|| {
v.as_str().and_then(|s| s.parse::<u16>().ok())
})
})
.unwrap_or(0)
};
let code = {
let c = read_status("code");
if c != 0 { c } else { read_status("http_code") }
};
tracing::error!("[STREAM_ERROR] Inline SSE error: type={:?}, code={}, msg={}", err_type, code, message);
let msg_lc = message.to_lowercase();
let is_rate_limit = code == 429
|| err_type.as_deref() == Some("rate_limit_exceeded")
|| msg_lc.contains("rate limit")
|| msg_lc.contains("quota");
let is_overloaded = code == 529
|| code == 503
|| err_type.as_deref() == Some("overloaded_error")
|| msg_lc.contains("overloaded")
|| msg_lc.contains("server cluster")
|| msg_lc.contains("high load");
let pe = if is_rate_limit {
ProviderError::RateLimitExceeded(message)
} else if is_overloaded {
ProviderError::StreamError(format!(
"upstream overloaded ({}): {}",
code, message
))
} else {
ProviderError::ApiError {
status: if code == 0 { 500 } else { code },
message,
error_type: err_type,
}
};
events.push(Err(pe));
continue;
}
match serde_json::from_str::<OpenAIStreamChunk>(json_str) {
Ok(chunk) => {
if !st.emitted_message_start && !chunk.id.is_empty() {
st.emitted_message_start = true;
let model = chunk.model.clone().unwrap_or_default();
events.push(Ok(StreamEvent::MessageStart {
message: crate::brain::provider::types::StreamMessage {
id: chunk.id,
model,
role: Role::Assistant,
usage: crate::brain::provider::types::TokenUsage {
input_tokens: 0,
output_tokens: 0, ..Default::default() },
},
}));
}
let delta_content = chunk.choices.first()
.and_then(|c| c.delta.as_ref())
.and_then(|d| d.content.as_ref())
.cloned();
let content = if delta_content.is_some() {
if delta_content.as_ref().is_some_and(|s| !s.is_empty()) {
st.seen_delta_content = true;
}
delta_content
} else if !st.seen_delta_content {
chunk.choices.first()
.and_then(|c| c.message.as_ref())
.and_then(|d| d.content.as_ref())
.cloned()
} else {
None
};
let tool_calls = chunk.choices.first()
.and_then(|c| c.delta.as_ref().or(c.message.as_ref()))
.and_then(|d| d.tool_calls.as_ref());
if let Some(tc_list) = tool_calls {
for tc_item in tc_list {
let idx = tc_item.index;
let accum = st.tool_calls.entry(idx).or_default();
if let Some(ref id) = tc_item.id
&& !id.is_empty() {
accum.id = id.clone();
}
if let Some(ref func) = tc_item.function {
if let Some(ref name) = func.name
&& !name.is_empty() {
accum.name = name.clone();
}
if let Some(ref args) = func.arguments {
accum.arguments.push_str(args);
}
}
tracing::debug!(
"[TOOL_ACCUM] idx={}, id={}, name={}, args_len={}, args_tail={}",
idx, accum.id, accum.name, accum.arguments.len(),
accum.arguments.chars().rev().take(60).collect::<String>().chars().rev().collect::<String>()
);
}
}
let finish_reason_str = chunk.choices.first()
.and_then(|c| c.finish_reason.as_ref());
if finish_reason_str.is_some() && !st.tool_calls.is_empty() {
if st.emitted_content_start && !st.emitted_content_stop {
events.push(Ok(StreamEvent::ContentBlockStop { index: 0 }));
st.emitted_content_stop = true;
}
for (idx, accum) in st.tool_calls.drain() {
let input = super::json_repair::parse_or_repair(&accum.arguments);
tracing::info!(
"[TOOL_EMIT] Emitting tool call: idx={}, id={}, name={}, args_len={}",
idx, accum.id, accum.name, accum.arguments.len()
);
let tool_index = idx + 1; events.push(Ok(StreamEvent::ContentBlockStart {
index: tool_index,
content_block: ContentBlock::ToolUse {
id: accum.id,
name: accum.name,
input,
},
}));
events.push(Ok(StreamEvent::ContentBlockStop { index: tool_index }));
}
}
if let Some(ref c) = content {
const LEAK_MARKERS: &[&str] = &[
"{\"tool_calls\"",
"{ \"tool_calls\"",
];
let drop_all = st.leak_active;
let mut to_emit: Option<String> = None;
if drop_all {
if !c.is_empty() {
tracing::debug!(
"[STREAM_FILTER] Suppressing {} chars of content during active tool_calls leak",
c.len()
);
}
} else {
st.leak_probe.push_str(c);
let probe_trimmed = st.leak_probe.trim_start();
let full_match = LEAK_MARKERS
.iter()
.any(|m| probe_trimmed.starts_with(m));
let bare_array_state =
classify_bare_tool_array(probe_trimmed);
if full_match
|| bare_array_state == BareToolArrayMatch::Full
{
tracing::warn!(
"[STREAM_FILTER] Dropping hallucinated tool-call JSON across accumulated content ({} chars buffered, bare_array={})",
st.leak_probe.len(),
bare_array_state == BareToolArrayMatch::Full
);
st.leak_active = true;
st.leak_probe.clear();
} else {
let still_prefix = probe_trimmed.is_empty()
|| LEAK_MARKERS.iter().any(|m| {
m.starts_with(probe_trimmed)
})
|| bare_array_state == BareToolArrayMatch::Prefix;
if still_prefix {
if st.leak_probe.len() > 128 {
to_emit = Some(std::mem::take(&mut st.leak_probe));
}
} else {
to_emit = Some(std::mem::take(&mut st.leak_probe));
}
}
}
if let Some(ref flushed) = to_emit {
let (mut inside, mut close_idx, mut consumed) =
(st.inside_think, st.active_close_tag, st.think_bytes_consumed);
let mut carry = std::mem::take(&mut st.think_carry);
let (filtered, reasoning_from_think) = filter_think_tags(
flushed,
&mut inside,
&mut close_idx,
&mut consumed,
&mut carry,
);
st.inside_think = inside;
st.active_close_tag = close_idx;
st.think_bytes_consumed = consumed;
st.think_carry = carry;
if !reasoning_from_think.is_empty() {
if !st.emitted_content_start {
st.emitted_content_start = true;
events.push(Ok(StreamEvent::ContentBlockStart {
index: 0,
content_block: ContentBlock::Text { text: String::new() },
}));
}
events.push(Ok(StreamEvent::ContentBlockDelta {
index: 0,
delta: ContentDelta::ReasoningDelta {
text: reasoning_from_think,
},
}));
}
const TOOL_MARKERS: &[&str] = &[
"<tool_call>",
"<function=",
"<|tool\u{2581}calls_section_begin|>",
"<|tool\u{2581}call_begin|>",
"<|tool_calls_section_begin|>",
"<|tool_call_begin|>",
];
let mut display_text: String = String::new();
if st.tool_block_active {
st.tool_capture_buffer.push_str(&filtered);
} else {
let mut working = std::mem::take(&mut st.marker_carry);
working.push_str(&filtered);
let first = TOOL_MARKERS
.iter()
.filter_map(|m| working.find(m).map(|p| (p, *m)))
.min_by_key(|(p, _)| *p);
if let Some((pos, marker)) = first {
let before: String = working[..pos].to_string();
let after = &working[pos..];
st.tool_capture_buffer.push_str(after);
st.tool_block_active = true;
display_text = before;
tracing::info!(
"[STREAM_FILTER] Tool-call marker {:?} detected — routing {} bytes to capture buffer",
marker, after.len()
);
} else {
let tail_keep = tool_marker_prefix_len(
&working,
TOOL_MARKERS,
);
if tail_keep >= working.len() {
st.marker_carry = working;
} else if tail_keep > 0 {
let split = working.len() - tail_keep;
display_text = working[..split].to_string();
st.marker_carry =
working[split..].to_string();
} else {
display_text = working;
}
}
}
if !display_text.is_empty() {
if !st.emitted_content_start {
st.emitted_content_start = true;
events.push(Ok(StreamEvent::ContentBlockStart {
index: 0,
content_block: ContentBlock::Text { text: String::new() },
}));
}
st.response_text_accum.push_str(&display_text);
events.push(Ok(StreamEvent::ContentBlockDelta {
index: 0,
delta: ContentDelta::TextDelta {
text: display_text,
},
}));
} else if !st.emitted_content_start
&& flushed.is_empty()
&& !st.tool_block_active
{
st.emitted_content_start = true;
events.push(Ok(StreamEvent::ContentBlockStart {
index: 0,
content_block: ContentBlock::Text { text: String::new() },
}));
}
}
}
let reasoning = chunk.choices.first()
.and_then(|c| c.delta.as_ref())
.and_then(|d| d.reasoning_content.as_ref())
.cloned();
if let Some(rc) = reasoning && !rc.is_empty() {
if !st.emitted_content_start {
st.emitted_content_start = true;
events.push(Ok(StreamEvent::ContentBlockStart {
index: 0,
content_block: ContentBlock::Text { text: String::new() },
}));
}
events.push(Ok(StreamEvent::ContentBlockDelta {
index: 0,
delta: ContentDelta::ReasoningDelta {
text: rc,
},
}));
}
if let Some(reason) = finish_reason_str {
if !st.leak_active && !st.leak_probe.is_empty() {
let flushed = std::mem::take(&mut st.leak_probe);
let (mut inside, mut close_idx, mut consumed) =
(st.inside_think, st.active_close_tag, st.think_bytes_consumed);
let mut carry = std::mem::take(&mut st.think_carry);
let (filtered, reasoning_from_think) = filter_think_tags(
&flushed,
&mut inside,
&mut close_idx,
&mut consumed,
&mut carry,
);
st.inside_think = inside;
st.active_close_tag = close_idx;
st.think_bytes_consumed = consumed;
st.think_carry = carry;
if !reasoning_from_think.is_empty() {
if !st.emitted_content_start {
st.emitted_content_start = true;
events.push(Ok(StreamEvent::ContentBlockStart {
index: 0,
content_block: ContentBlock::Text { text: String::new() },
}));
}
events.push(Ok(StreamEvent::ContentBlockDelta {
index: 0,
delta: ContentDelta::ReasoningDelta {
text: reasoning_from_think,
},
}));
}
if !filtered.is_empty() {
if !st.emitted_content_start {
st.emitted_content_start = true;
events.push(Ok(StreamEvent::ContentBlockStart {
index: 0,
content_block: ContentBlock::Text { text: String::new() },
}));
}
st.response_text_accum.push_str(&filtered);
events.push(Ok(StreamEvent::ContentBlockDelta {
index: 0,
delta: ContentDelta::TextDelta { text: filtered },
}));
}
}
let has_markers_in_accum = st.response_text_accum.contains("<tool_call>")
|| st.response_text_accum.contains("<function=")
|| st.response_text_accum.contains("tool_call:")
|| st.response_text_accum.contains("\"tool_calls\"")
|| st.response_text_accum.contains("\"id\":\"call_")
|| st.response_text_accum.contains("\"id\": \"call_");
let has_capture = !st.tool_capture_buffer.is_empty();
if st.tool_calls.is_empty() && (has_markers_in_accum || has_capture)
{
let mut combined = st.tool_capture_buffer.clone();
if has_markers_in_accum {
let mut prefix = st.response_text_accum.clone();
prefix.push('\n');
prefix.push_str(&combined);
combined = prefix;
}
let (recovered, _cleaned) =
extract_text_tool_calls(&combined);
if !recovered.is_empty() {
tracing::info!(
"Recovered {} streaming tool call(s) from text content (local-model fallback; capture_bytes={}, display_markers={})",
recovered.len(),
st.tool_capture_buffer.len(),
has_markers_in_accum,
);
if st.emitted_content_start && !st.emitted_content_stop {
events.push(Ok(StreamEvent::ContentBlockStop { index: 0 }));
st.emitted_content_stop = true;
}
for (tc_idx, (name, input)) in recovered.into_iter().enumerate() {
let tool_index = tc_idx + 1;
events.push(Ok(StreamEvent::ContentBlockStart {
index: tool_index,
content_block: ContentBlock::ToolUse {
id: format!("call_text_{}", tc_idx),
name,
input,
},
}));
events.push(Ok(StreamEvent::ContentBlockStop { index: tool_index }));
}
}
}
if st.emitted_content_start && !st.emitted_content_stop {
events.push(Ok(StreamEvent::ContentBlockStop { index: 0 }));
st.emitted_content_stop = true;
}
let (raw_input, raw_output, raw_cache_read, raw_cache_create) = if let Some(ref usage) = chunk.usage {
(
usage.prompt_tokens.unwrap_or(0),
usage.completion_tokens.unwrap_or(0),
usage.effective_cache_read(),
usage.cache_creation_input_tokens.unwrap_or(0),
)
} else {
(0, 0, 0, 0)
};
let stop_reason = Some(match reason.as_str() {
"stop" => crate::brain::provider::types::StopReason::EndTurn,
"length" => crate::brain::provider::types::StopReason::MaxTokens,
"tool_calls" | "function_call" => crate::brain::provider::types::StopReason::ToolUse,
_ => crate::brain::provider::types::StopReason::EndTurn,
});
if raw_input > 0 || raw_output > 0 {
tracing::info!(
"[STREAM_USAGE] Final usage (inline): input={}, output={}, cache_read={}, cache_create={}",
raw_input, raw_output, raw_cache_read, raw_cache_create
);
events.push(Ok(StreamEvent::MessageDelta {
delta: crate::brain::provider::types::MessageDelta {
stop_reason,
stop_sequence: None,
},
usage: crate::brain::provider::types::TokenUsage {
input_tokens: raw_input,
output_tokens: raw_output,
cache_creation_tokens: raw_cache_create,
cache_read_tokens: raw_cache_read,
..Default::default()
},
}));
events.push(Ok(StreamEvent::MessageStop));
} else {
st.pending_stop_reason = stop_reason;
}
}
if chunk.choices.is_empty()
&& let Some(ref usage) = chunk.usage {
let input = usage.prompt_tokens.unwrap_or(0);
let output = usage.completion_tokens.unwrap_or(0);
let cache_read = usage.effective_cache_read();
let cache_create = usage.cache_creation_input_tokens.unwrap_or(0);
let reasoning = usage.reasoning_tokens();
if input > 0 || output > 0 {
tracing::info!(
"[STREAM_USAGE] Final usage: input={}, output={}, cache_read={}, cache_create={}, reasoning={}",
input, output, cache_read, cache_create, reasoning
);
events.push(Ok(StreamEvent::MessageDelta {
delta: crate::brain::provider::types::MessageDelta {
stop_reason: st.pending_stop_reason.take(),
stop_sequence: None,
},
usage: crate::brain::provider::types::TokenUsage {
input_tokens: input,
output_tokens: output,
cache_creation_tokens: cache_create,
cache_read_tokens: cache_read,
..Default::default()
},
}));
events.push(Ok(StreamEvent::MessageStop));
}
}
}
Err(e) => {
let json_preview = json_str.chars().take(300).collect::<String>();
tracing::warn!(
"[STREAM_PARSE] Failed to parse chunk: {} | Raw: {}",
e, json_preview
);
}
}
}
}
if events.is_empty()
&& !st.emitted_message_start
&& super::nonstream_compat::is_nonstream_response(&buf)
&& let Some(synth) = super::nonstream_compat::synthesize_stream_events(&buf)
{
st.emitted_message_start = true;
st.emitted_content_start = true;
st.emitted_content_stop = true;
buf.clear();
events.extend(synth);
}
if events.is_empty() {
vec![Ok(StreamEvent::Ping)]
} else {
events
}
}
}
})
.flat_map(futures::stream::iter);
Ok(Box::pin(event_stream))
}
/// Streaming responses are always supported by this provider.
fn supports_streaming(&self) -> bool {
    true
}
/// Tool/function calling is always advertised as supported.
fn supports_tools(&self) -> bool {
    true
}
/// Vision is available only when a dedicated vision model has been configured.
fn supports_vision(&self) -> bool {
    matches!(self.vision_model, Some(_))
}
/// Display name of this provider instance.
fn name(&self) -> &str {
    &self.name
}
/// Drains and returns every retry notice recorded so far, leaving the store empty.
///
/// NOTE(review): the tuple layout `(u32, u32, String)` is defined by the producer;
/// presumably (attempt, max_attempts, reason) — confirm at the call site that pushes them.
/// If the mutex is poisoned, nothing is drained and an empty list is returned.
fn take_retry_notices(&self) -> Vec<(u32, u32, String)> {
    match self.retry_notices.lock() {
        Ok(mut guard) => std::mem::take(&mut *guard),
        Err(_) => Vec::new(),
    }
}
/// Returns the configured endpoint URL (the chat-completions URL; `fetch_models`
/// derives the `/models` URL from it).
fn base_url(&self) -> Option<&str> {
    Some(&self.base_url)
}
/// Returns the configured default model.
///
/// If no default model is configured, an error is logged on every call and the
/// placeholder `"MISSING_MODEL"` is returned instead of panicking.
fn default_model(&self) -> &str {
    if let Some(model) = self.custom_default_model.as_deref() {
        return model;
    }
    tracing::error!(
        "No default_model configured for provider '{}' — check config.toml",
        self.name
    );
    "MISSING_MODEL"
}
/// Models this provider offers without contacting the server.
///
/// Returns the configured model list when it is non-empty; otherwise falls back
/// to a fixed list of legacy OpenAI chat model names.
fn supported_models(&self) -> Vec<String> {
    if self.configured_models.is_empty() {
        [
            "gpt-4-turbo-preview",
            "gpt-4",
            "gpt-4-32k",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
        ]
        .into_iter()
        .map(String::from)
        .collect()
    } else {
        self.configured_models.clone()
    }
}
/// Queries the server's `/models` endpoint and returns the model ids, sorted.
///
/// The listing URL is derived by replacing `/chat/completions` in `base_url` with
/// `/models`. Any failure falls back to `supported_models()`. This covers header
/// construction, transport errors, non-2xx status, an undecodable body, or an empty list.
async fn fetch_models(&self) -> Vec<String> {
    // One element of the `data` array in the `/models` response.
    #[derive(Deserialize)]
    struct ModelEntry {
        id: String,
    }
    // Top-level shape of the `/models` response body.
    #[derive(Deserialize)]
    struct ModelsResponse {
        data: Vec<ModelEntry>,
    }
    let models_url = self.base_url.replace("/chat/completions", "/models");
    let Ok(headers) = self.headers() else {
        return self.supported_models();
    };
    let response = match self.client.get(&models_url).headers(headers).send().await {
        Ok(resp) if resp.status().is_success() => resp,
        _ => return self.supported_models(),
    };
    let Ok(body) = response.json::<ModelsResponse>().await else {
        return self.supported_models();
    };
    let mut ids: Vec<String> = body.data.into_iter().map(|entry| entry.id).collect();
    if ids.is_empty() {
        return self.supported_models();
    }
    ids.sort();
    ids
}
/// Returns the context-window override configured for this provider, if any.
fn configured_context_window(&self) -> Option<u32> {
    self.configured_context_window
}
/// Returns the context window (in tokens) for `model`.
///
/// A configured override always wins. Otherwise the model name is lowercased and
/// matched against known OpenAI model families, first by prefix and then by exact
/// legacy name. Unknown models yield `None`.
fn context_window(&self, model: &str) -> Option<u32> {
    if let Some(cw) = self.configured_context_window {
        return Some(cw);
    }
    let m = model.to_lowercase();
    // Prefix families: suffixed/dated variants (e.g. `gpt-4o-mini`,
    // `gpt-4-turbo-2024-04-09`) share their family's window.
    if m.starts_with("gpt-5") || m.starts_with("gpt-4.1") {
        return Some(1_047_576);
    }
    if m.starts_with("o4") || m.starts_with("o3") || m.starts_with("o1") {
        return Some(200_000);
    }
    if m.starts_with("gpt-4o") || m.starts_with("gpt-4-turbo") {
        return Some(128_000);
    }
    // Exact legacy names. Match on the lowercased `m` (not the raw `model`) so that
    // e.g. "GPT-4" resolves the same way "GPT-4o" does above.
    match m.as_str() {
        "gpt-4" => Some(8_192),
        "gpt-4-32k" => Some(32_768),
        "gpt-3.5-turbo" | "gpt-3.5-turbo-16k" => Some(16_384),
        _ => None,
    }
}
/// Estimates the cost of a request from the pricing configuration.
///
/// The pricing config is loaded on every call. NOTE(review): if `load()` reads
/// from disk, consider caching it. Returns `0.0` when it cannot be loaded.
fn calculate_cost(&self, model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
    crate::usage::pricing::PricingConfig::load()
        .map_or(0.0, |pricing| pricing.calculate_cost(model, input_tokens, output_tokens))
}
}
/// Returns `true` if `model` should be sent `max_completion_tokens` instead of `max_tokens`.
///
/// The check is case-insensitive. It matches the `gpt-4.1`, `gpt-5`, `o1`, `o3` and
/// `o4` prefixes, plus any model whose name contains `thinking`.
pub(crate) fn uses_max_completion_tokens(model: &str) -> bool {
    const COMPLETION_TOKEN_PREFIXES: [&str; 5] = ["gpt-4.1", "gpt-5", "o1", "o3", "o4"];
    let lowered = model.to_lowercase();
    lowered.contains("thinking")
        || COMPLETION_TOKEN_PREFIXES
            .iter()
            .any(|prefix| lowered.starts_with(*prefix))
}
/// JSON request body for an OpenAI-compatible chat-completions call.
///
/// Every `Option` field is omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize)]
pub(crate) struct OpenAIRequest {
    pub(crate) model: String,
    pub(crate) messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // NOTE(review): presumably only one of `max_tokens` / `max_completion_tokens` is set,
    // chosen via `uses_max_completion_tokens`; `swap_token_fields` exchanges the two.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAITool>>,
    // Free-form JSON so both string values and object forms can be sent.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<serde_json::Value>,
    // NOTE(review): non-standard extension field; presumably asks the server to return reasoning text.
    #[serde(skip_serializing_if = "Option::is_none")]
    include_reasoning: Option<bool>,
}
impl OpenAIRequest {
    /// Exchanges the values of `max_tokens` and `max_completion_tokens`.
    ///
    /// Presumably used to re-send a request with the other token-limit field name
    /// after a server rejects one of them (see `is_token_field_mismatch`).
    fn swap_token_fields(&mut self) {
        std::mem::swap(&mut self.max_tokens, &mut self.max_completion_tokens);
    }
}
/// Returns `true` if an error message says a token-limit parameter is unsupported.
///
/// Matching is case-insensitive. The message must mention `max_tokens` or
/// `max_completion_tokens`, and must also contain the word `unsupported`.
pub(crate) fn is_token_field_mismatch(msg: &str) -> bool {
    let lowered = msg.to_lowercase();
    let mentions_token_field = ["max_tokens", "max_completion_tokens"]
        .iter()
        .any(|field| lowered.contains(*field));
    mentions_token_field && lowered.contains("unsupported")
}
/// The `stream_options` object of a chat-completions request.
#[derive(Debug, Clone, Serialize)]
struct StreamOptions {
    // When true, asks the server to report token usage within the stream
    // (consumed as the `usage` field of stream chunks).
    include_usage: bool,
}
/// A chat message in OpenAI wire format.
///
/// Used both when sending requests and when reading non-streaming responses
/// (`OpenAIChoice::message`). `None` fields are omitted when serializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OpenAIMessage {
    pub(crate) role: String,
    // Arbitrary JSON, so either a plain string or a structured value can be carried
    // (NOTE(review): presumably an array of content parts for multimodal input).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) content: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) tool_calls: Option<Vec<OpenAIToolCall>>,
    // Links a tool-result message back to the tool call it answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) tool_call_id: Option<String>,
    // NOTE(review): presumably populated only for endpoints where
    // `needs_reasoning_content_for` returns true.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) reasoning_content: Option<String>,
}
/// A complete (non-streaming) tool call attached to a message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OpenAIToolCall {
    pub(crate) id: String,
    // Raw identifier so the JSON key is `type`.
    pub(crate) r#type: String,
    pub(crate) function: OpenAIFunctionCall,
}
/// Function name and arguments of a tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OpenAIFunctionCall {
    pub(crate) name: String,
    // Kept as a string, not parsed JSON (NOTE(review): presumably JSON-encoded arguments).
    pub(crate) arguments: String,
}
/// A tool definition sent in the request's `tools` array.
#[derive(Debug, Clone, Serialize)]
struct OpenAITool {
    // Raw identifier so the JSON key is `type`.
    r#type: String,
    function: OpenAIFunction,
}
/// Name, description and parameter schema of a callable function.
#[derive(Debug, Clone, Serialize)]
struct OpenAIFunction {
    name: String,
    description: String,
    // Arbitrary JSON (NOTE(review): presumably a JSON Schema object).
    parameters: serde_json::Value,
}
/// Body of a non-streaming chat-completions response.
///
/// `id`, `model`, `choices` and `usage` are all required; a response missing any
/// of them fails to deserialize.
#[derive(Debug, Clone, Deserialize)]
struct OpenAIResponse {
    id: String,
    model: String,
    choices: Vec<OpenAIChoice>,
    usage: OpenAIUsage,
}
/// One choice of a non-streaming response.
#[derive(Debug, Clone, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessage,
    finish_reason: Option<String>,
}
/// Token usage as reported by an OpenAI-compatible server.
///
/// Every field is optional because providers report different subsets. Cache reads
/// may appear either at top level or under `prompt_tokens_details`; see
/// `effective_cache_read`.
#[derive(Debug, Clone, Deserialize)]
struct OpenAIUsage {
    // NOTE(review): the `rename` attributes below repeat the field names and are redundant.
    #[serde(rename = "prompt_tokens")]
    prompt_tokens: Option<u32>,
    #[serde(rename = "completion_tokens")]
    completion_tokens: Option<u32>,
    #[serde(default)]
    cache_creation_input_tokens: Option<u32>,
    #[serde(default)]
    cache_read_input_tokens: Option<u32>,
    #[serde(default)]
    prompt_tokens_details: Option<OpenAIPromptTokensDetails>,
    #[serde(default)]
    completion_tokens_details: Option<OpenAICompletionTokensDetails>,
}
impl OpenAIUsage {
fn effective_cache_read(&self) -> u32 {
self.cache_read_input_tokens
.or_else(|| {
self.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens)
})
.unwrap_or(0)
}
fn reasoning_tokens(&self) -> u32 {
self.completion_tokens_details
.as_ref()
.and_then(|d| d.reasoning_tokens)
.unwrap_or(0)
}
}
/// Breakdown of prompt-token usage (`usage.prompt_tokens_details`).
#[derive(Debug, Clone, Deserialize, Default)]
struct OpenAIPromptTokensDetails {
    // Prompt tokens served from cache, when reported.
    #[serde(default)]
    cached_tokens: Option<u32>,
}
/// Breakdown of completion-token usage (`usage.completion_tokens_details`).
#[derive(Debug, Clone, Deserialize, Default)]
struct OpenAICompletionTokensDetails {
    // Completion tokens spent on reasoning, when reported.
    #[serde(default)]
    reasoning_tokens: Option<u32>,
}
/// One JSON payload of an OpenAI-compatible streaming chat-completions response.
///
/// `id` and `choices` default to empty when a provider omits them, e.g. on a trailing
/// usage-only chunk. Such chunks therefore still parse and reach the stream loop's
/// `chunk.id.is_empty()` / `chunk.choices.is_empty()` handling, instead of failing
/// deserialization and being discarded with a parse warning.
#[derive(Debug, Clone, Deserialize)]
struct OpenAIStreamChunk {
    /// Completion id; empty when the provider did not send one.
    #[serde(default)]
    id: String,
    /// Model name echoed by the provider, if any.
    model: Option<String>,
    /// Per-choice payloads; empty for usage-only chunks.
    #[serde(default)]
    choices: Vec<OpenAIStreamChoice>,
    /// Token usage. It may arrive on the finish chunk or on a separate chunk with no
    /// choices (see `StreamOptions::include_usage`).
    usage: Option<OpenAIUsage>,
}
/// One choice within a stream chunk.
#[derive(Debug, Clone, Deserialize)]
struct OpenAIStreamChoice {
    // Incremental update (the normal streaming shape).
    delta: Option<OpenAIMessageDelta>,
    // Full message. Some servers send this in a chunk instead of `delta`; the stream
    // loop uses its content only while no delta content has been seen.
    message: Option<OpenAIMessageDelta>,
    // Set on the chunk that ends the choice (e.g. "stop", "length", "tool_calls").
    finish_reason: Option<String>,
}
/// A fragment of a tool call streamed across several chunks.
#[derive(Debug, Clone, Deserialize)]
struct StreamingToolCall {
    // Groups fragments belonging to the same tool call (accumulator key); required.
    index: usize,
    // Usually only present on the first fragment; a non-empty value overwrites the accumulated id.
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<StreamingFunctionCall>,
}
/// Function part of a streamed tool-call fragment.
#[derive(Debug, Clone, Deserialize)]
struct StreamingFunctionCall {
    // A non-empty value overwrites the accumulated name.
    #[serde(default)]
    name: Option<String>,
    // Appended to the accumulated argument string, chunk by chunk.
    #[serde(default)]
    arguments: Option<String>,
}
/// Content carried by a stream choice's `delta` or `message`.
#[derive(Debug, Clone, Deserialize)]
struct OpenAIMessageDelta {
    content: Option<String>,
    // Also accepts the key `reasoning`, which some providers use instead of `reasoning_content`.
    #[serde(default, alias = "reasoning")]
    reasoning_content: Option<String>,
    tool_calls: Option<Vec<StreamingToolCall>>,
}
/// Top-level error body of the form `{"error": {...}}`.
#[derive(Debug, Clone, Deserialize)]
pub(crate) struct OpenAIErrorResponse {
    pub(crate) error: OpenAIError,
}
/// The `error` object of an OpenAI-style error response.
#[derive(Debug, Clone, Deserialize)]
pub(crate) struct OpenAIError {
    pub(crate) message: String,
    #[serde(rename = "type")]
    pub(crate) error_type: Option<String>,
    // Present when a proxy wraps an upstream error; see `unwrap_proxy_error`.
    #[serde(default)]
    pub(crate) metadata: Option<OpenAIErrorMetadata>,
}
/// Proxy metadata attached to an error.
#[derive(Debug, Clone, Deserialize)]
pub(crate) struct OpenAIErrorMetadata {
    // Upstream error body as a string; `unwrap_proxy_error` tries to parse it as an `OpenAIErrorResponse`.
    #[serde(default)]
    pub(crate) raw: Option<String>,
    // Name of the upstream provider, used as a `[name] ` message prefix.
    #[serde(default)]
    pub(crate) provider_name: Option<String>,
}
/// Returns `true` for endpoints that need `reasoning_content` on messages.
///
/// These are any Moonshot URL, and Kimi models served through `opencode.ai`. Both
/// checks are case-insensitive substring matches. NOTE(review): presumably this
/// decides whether `OpenAIMessage::reasoning_content` is populated — confirm at the caller.
pub(crate) fn needs_reasoning_content_for(base_url: &str, model: &str) -> bool {
    let url = base_url.to_ascii_lowercase();
    if url.contains("moonshot") {
        return true;
    }
    url.contains("opencode.ai") && model.to_ascii_lowercase().contains("kimi")
}
/// Extracts the most specific message from a possibly-proxied OpenAI-style error.
///
/// Some proxies put the upstream error body, as a string, in `error.metadata.raw`.
/// They may also name the upstream in `metadata.provider_name`.
///
/// Returns `(message, error_type)`:
/// - no metadata or no `raw`: the outer message and type, unchanged;
/// - blank `raw`: the outer message with the provider prefix, and the outer type;
/// - `raw` parses as an `OpenAIErrorResponse`: the inner message with the provider
///   prefix, and the inner type falling back to the outer type;
/// - otherwise: `"<prefix><outer message>: <raw>"` with the outer type.
///
/// The provider prefix is `"[name] "` when `provider_name` is set, else empty.
pub(crate) fn unwrap_proxy_error(outer: &OpenAIError) -> (String, Option<String>) {
    let passthrough = || (outer.message.clone(), outer.error_type.clone());
    let Some(metadata) = outer.metadata.as_ref() else {
        return passthrough();
    };
    let Some(raw) = metadata.raw.as_deref() else {
        return passthrough();
    };
    let prefix = metadata
        .provider_name
        .as_deref()
        .map(|p| format!("[{p}] "))
        .unwrap_or_default();
    // A blank `raw` carries no upstream detail; appending it would only leave a dangling ": ".
    if raw.trim().is_empty() {
        return (format!("{prefix}{}", outer.message), outer.error_type.clone());
    }
    match serde_json::from_str::<OpenAIErrorResponse>(raw) {
        Ok(inner) => (
            format!("{prefix}{}", inner.error.message),
            inner.error.error_type.or_else(|| outer.error_type.clone()),
        ),
        Err(_) => (
            format!("{prefix}{}: {raw}", outer.message),
            outer.error_type.clone(),
        ),
    }
}
/// Validation-error body with a `detail` array.
///
/// NOTE(review): the name suggests a Pydantic/FastAPI-style 422 response; confirm the producing server.
#[derive(Debug, Clone, Deserialize)]
struct PydanticValidationError {
    detail: Vec<PydanticDetail>,
}
/// One entry of the `detail` array.
#[derive(Debug, Clone, Deserialize)]
struct PydanticDetail {
    // Location path of the offending field, as mixed JSON values.
    loc: Vec<serde_json::Value>,
    msg: String,
    // Deserialized from `type`; the leading underscore marks it as currently unused.
    #[serde(rename = "type")]
    _kind: Option<String>,
}
/// Heuristic: treats any URL mentioning `localhost` or `127.0.0.1` as a possible
/// Unsloth Studio server. The match is a case-insensitive substring check.
///
/// NOTE(review): this matches every local server, not only Unsloth Studio, and also
/// hosts that merely contain the substring (e.g. `notlocalhost.example`).
fn is_unsloth_studio_url(url: &str) -> bool {
    let lowered = url.to_ascii_lowercase();
    ["localhost", "127.0.0.1"]
        .iter()
        .any(|host| lowered.contains(*host))
}