use std::io::{BufRead, BufReader};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::{
ApiClient, ApiRequest, AssistantEvent, ContentBlock, MessageRole, RuntimeError, TokenUsage,
};
use serde_json::{json, Value};
use crate::tool_groups::ToolRegistry;
/// Sink for streamed assistant text. The clients in this file call it with each
/// text delta as it arrives and, after a reply that produced text, once more
/// with `"\n"`. The `Send + Sync` bounds let a boxed callback be stored in a
/// client that crosses thread boundaries.
pub type TextCallback = Box<dyn Fn(&str) + Send + Sync>;
/// Returns a [`TextCallback`] that echoes every delta straight to stdout.
///
/// Each delta is written while holding the stdout lock and is flushed at once,
/// so partial tokens show up without waiting for a newline. Write and flush
/// errors are deliberately ignored: losing console echo must not abort a turn.
#[must_use]
pub fn stdout_text_callback() -> TextCallback {
    fn echo_to_stdout(delta: &str) {
        use std::io::Write;
        let mut handle = std::io::stdout().lock();
        let _ = handle.write_all(delta.as_bytes());
        let _ = handle.flush();
    }
    Box::new(echo_to_stdout)
}
/// Returns a [`TextCallback`] that forwards each delta to the TUI as a
/// `TuiEvent::Token`.
///
/// `SyncSender::send` blocks while the channel is full, so a slow TUI applies
/// back-pressure to the streaming thread. A send error (receiver gone) is ignored.
#[must_use]
pub fn tui_text_callback(
    tx: std::sync::mpsc::SyncSender<crate::tui_events::TuiEvent>,
) -> TextCallback {
    Box::new(move |delta: &str| {
        let token = crate::tui_events::TuiEvent::Token(String::from(delta));
        // Nothing useful to do if the TUI side has hung up.
        let _ = tx.send(token);
    })
}
/// Process-wide accumulator for text streamed through `telegram_text_callback`.
/// It is lazily initialised (empty) by `telegram_stream_buffer` and cleared by
/// `telegram_stream_reset`.
static TELEGRAM_STREAM_BUFFER: std::sync::OnceLock<Mutex<String>> = std::sync::OnceLock::new();
/// Returns the shared Telegram stream buffer, creating an empty one on first use.
#[must_use]
pub fn telegram_stream_buffer() -> &'static Mutex<String> {
    TELEGRAM_STREAM_BUFFER.get_or_init(Mutex::default)
}
/// Empties the shared Telegram stream buffer.
///
/// If the lock is poisoned the buffer is left as-is (best effort).
pub fn telegram_stream_reset() {
    let Ok(mut collected) = telegram_stream_buffer().lock() else {
        return;
    };
    collected.clear();
}
/// Returns a [`TextCallback`] that appends each delta to the shared Telegram
/// stream buffer and also echoes it to stdout, flushing immediately.
///
/// When the buffer lock is poisoned the delta is still echoed, just not recorded.
/// Console write errors are ignored.
#[must_use]
pub fn telegram_text_callback() -> TextCallback {
    Box::new(|delta: &str| {
        use std::io::Write;
        // The buffer guard is dropped at the end of this `if let`, before stdout is locked.
        if let Ok(mut collected) = telegram_stream_buffer().lock() {
            collected.push_str(delta);
        }
        let mut console = std::io::stdout().lock();
        let _ = console.write_all(delta.as_bytes());
        let _ = console.flush();
    })
}
/// Base URL used when `OLLAMA_HOST` is unset or empty.
const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
/// Default context window, in tokens. It is not read directly in this part of the
/// file; presumably it is the default for `model_config`'s `brain.num_ctx` (TODO confirm).
pub const DEFAULT_NUM_CTX: u32 = 16384;
/// Default cap on generated tokens. It is not read directly in this part of the
/// file; presumably it is the default for `model_config`'s `brain.num_predict` (TODO confirm).
pub const DEFAULT_NUM_PREDICT: u32 = 6144;
/// Context-window size (`num_ctx`) from the active model configuration's brain section.
#[must_use]
pub fn current_num_ctx() -> u32 {
    let config = crate::model_config::active();
    config.brain.num_ctx
}
/// Generation cap (`num_predict`) from the active model configuration's brain section.
#[must_use]
pub fn current_num_predict() -> u32 {
    let config = crate::model_config::active();
    config.brain.num_predict
}
/// Timeout, in seconds, configured on the blocking HTTP client used for chat requests.
const REQUEST_TIMEOUT_SECS: u64 = 300;
/// Rough characters-per-token ratio used to turn token budgets into character budgets.
const CHARS_PER_TOKEN: usize = 4;
/// Slack, in characters, subtracted from the history budget to absorb estimation error.
const SAFETY_CHARS: usize = 1024;
/// Source of the tool definitions sent with each chat request.
pub enum ToolsProvider {
    /// A fixed JSON value, cloned for every request.
    Fixed(Value),
    /// A shared registry whose current tool set is read on every request
    /// (presumably so tools enabled mid-session are picked up — confirm).
    Dynamic(Arc<Mutex<ToolRegistry>>),
}
impl ToolsProvider {
    /// Returns the tool definitions to send right now.
    ///
    /// A poisoned registry lock is recovered instead of propagated, because the
    /// tool list is only read here.
    #[must_use]
    pub fn current(&self) -> Value {
        match self {
            Self::Fixed(tools) => tools.clone(),
            Self::Dynamic(registry) => registry
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .current_tools(),
        }
    }
}
/// Blocking chat client for Ollama's native `/api/chat` endpoint, or for an
/// OpenAI-compatible `/v1/chat/completions` endpoint when `openai_compat` is set.
pub struct OllamaApiClient {
    /// Blocking HTTP client used for every chat request (built with `REQUEST_TIMEOUT_SECS`).
    http: reqwest::blocking::Client,
    /// Server base URL as produced by `resolve_ollama_url` (trailing `/` trimmed).
    base_url: String,
    /// Model name sent in every request body.
    model: String,
    /// Where each request's tool list comes from.
    tools: ToolsProvider,
    /// Context window in tokens: Ollama `options.num_ctx`, and the basis of the history budget.
    num_ctx: u32,
    /// Generation cap in tokens: Ollama `options.num_predict` / OpenAI `max_tokens`;
    /// also reserved out of the history budget.
    num_predict: u32,
    /// Optional sink that receives streamed text deltas.
    text_callback: Option<TextCallback>,
    /// When true, POST non-streaming requests to `/v1/chat/completions` instead of `/api/chat`.
    openai_compat: bool,
}
impl OllamaApiClient {
    /// Creates a client that always advertises the fixed tool list `tools`.
    #[must_use]
    pub fn new(model: impl Into<String>, tools: Value) -> Self {
        let provider = ToolsProvider::Fixed(tools);
        Self::build(model.into(), provider)
    }

    /// Creates a client that reads its tool list from `registry` on every request.
    #[must_use]
    pub fn with_registry(model: impl Into<String>, registry: Arc<Mutex<ToolRegistry>>) -> Self {
        let provider = ToolsProvider::Dynamic(registry);
        Self::build(model.into(), provider)
    }

    /// Shared constructor. It takes the base URL and compat mode from the
    /// environment and the context sizes from the active model configuration,
    /// and starts with no text callback.
    ///
    /// # Panics
    /// Panics if the reqwest blocking client cannot be built.
    fn build(model: String, tools: ToolsProvider) -> Self {
        let base_url = resolve_ollama_url();
        let request_timeout = Duration::from_secs(REQUEST_TIMEOUT_SECS);
        let http = reqwest::blocking::Client::builder()
            .timeout(request_timeout)
            .build()
            .expect("failed to build reqwest blocking client");
        let num_ctx = current_num_ctx();
        let num_predict = current_num_predict();
        let openai_compat = resolve_openai_compat();
        Self {
            http,
            base_url,
            model,
            tools,
            num_ctx,
            num_predict,
            text_callback: None,
            openai_compat,
        }
    }

    /// Overrides the OpenAI-compat mode chosen from `CLAUDETTE_OPENAI_COMPAT`.
    #[must_use]
    pub fn with_openai_compat(self, on: bool) -> Self {
        Self {
            openai_compat: on,
            ..self
        }
    }

    /// Overrides the context window (tokens).
    #[must_use]
    pub fn with_context(self, num_ctx: u32) -> Self {
        Self { num_ctx, ..self }
    }

    /// Overrides the generation cap (tokens).
    #[must_use]
    pub fn with_max_predict(self, num_predict: u32) -> Self {
        Self { num_predict, ..self }
    }

    /// Installs a sink for streamed text deltas.
    #[must_use]
    pub fn with_text_callback(self, callback: TextCallback) -> Self {
        Self {
            text_callback: Some(callback),
            ..self
        }
    }

    /// Model name sent with every request.
    #[must_use]
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
/// Resolves the Ollama base URL from `OLLAMA_HOST`, following the same rules as
/// the Ollama CLI.
///
/// * Unset, non-UTF-8, empty, whitespace-only or `/`-only values give
///   `DEFAULT_OLLAMA_URL`.
/// * Surrounding whitespace and trailing `/` are stripped.
/// * A value with an explicit `http://` or `https://` scheme (any case) is used
///   verbatim.
/// * A bare `host[/path]` gets `http://`. If it has no port, Ollama's default
///   port 11434 is added, so the common `OLLAMA_HOST=0.0.0.0` resolves to
///   `http://0.0.0.0:11434` rather than port 80. A bare IPv6 literal (`::1`) is
///   bracketed first.
#[must_use]
pub fn resolve_ollama_url() -> String {
    std::env::var("OLLAMA_HOST")
        .ok()
        .and_then(|raw| normalize_ollama_host(&raw))
        .unwrap_or_else(|| DEFAULT_OLLAMA_URL.to_string())
}

/// Normalises a raw `OLLAMA_HOST` value into a base URL.
/// Returns `None` when nothing usable remains after trimming.
fn normalize_ollama_host(raw: &str) -> Option<String> {
    let host = raw.trim().trim_end_matches('/');
    if host.is_empty() {
        return None;
    }
    // `get` returns None instead of panicking if the prefix would split a multi-byte char.
    let has_scheme = ["http://", "https://"].iter().any(|scheme| {
        host.get(..scheme.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(scheme))
    });
    if has_scheme {
        return Some(host.to_string());
    }
    let (authority, path) = host.split_at(host.find('/').unwrap_or(host.len()));
    Some(format!("http://{}{path}", authority_with_default_port(authority)))
}

/// Adds Ollama's default port to a scheme-less authority that lacks one.
/// A bare IPv6 literal is wrapped in brackets first.
fn authority_with_default_port(authority: &str) -> String {
    const DEFAULT_OLLAMA_PORT: u16 = 11434;
    if authority.parse::<std::net::Ipv6Addr>().is_ok() {
        return format!("[{authority}]:{DEFAULT_OLLAMA_PORT}");
    }
    let has_port = match authority.strip_prefix('[') {
        Some(bracketed) => bracketed.contains("]:"),
        None => authority.contains(':'),
    };
    if has_port {
        authority.to_string()
    } else {
        format!("{authority}:{DEFAULT_OLLAMA_PORT}")
    }
}
/// Optional cap on how many tools are sent per request, read from `CLAUDETTE_MAX_TOOLS`.
///
/// Returns `None` when the variable is unset, does not parse as `usize`, or is `0`.
#[must_use]
pub fn resolve_max_tools() -> Option<usize> {
    let raw = std::env::var("CLAUDETTE_MAX_TOOLS").ok()?;
    match raw.parse::<usize>() {
        Ok(0) | Err(_) => None,
        Ok(limit) => Some(limit),
    }
}
/// Trims a JSON array of tool definitions to at most `cap` entries.
///
/// Non-array values, and arrays already within the cap, are returned unchanged.
/// If an entry whose `function.name` is `enable_tools` exists, it is moved to
/// the front and always survives, using one of the `cap` slots (presumably so
/// the model can still ask for more tools).
fn cap_tools(tools: Value, cap: usize) -> Value {
    match tools {
        Value::Array(mut list) if list.len() > cap => {
            let enabler_idx = list.iter().position(|tool| {
                tool.pointer("/function/name").and_then(Value::as_str) == Some("enable_tools")
            });
            match enabler_idx {
                Some(idx) => {
                    let enabler = list.remove(idx);
                    list.truncate(cap.saturating_sub(1));
                    list.insert(0, enabler);
                }
                None => list.truncate(cap),
            }
            Value::Array(list)
        }
        other => other,
    }
}
/// Whether to use an OpenAI-compatible `/v1/chat/completions` endpoint instead of
/// Ollama's native `/api/chat`, read from `CLAUDETTE_OPENAI_COMPAT`.
///
/// Any non-empty value other than `"0"` enables it. Unset or non-UTF-8 disables it.
#[must_use]
pub fn resolve_openai_compat() -> bool {
    match std::env::var("CLAUDETTE_OPENAI_COMPAT") {
        Ok(flag) => !(flag.is_empty() || flag == "0"),
        Err(_) => false,
    }
}
/// Returns `true` when `url` targets a loopback host: `localhost`, `::1`, or
/// any `127.x.y.z` IPv4 address. Matching is ASCII case-insensitive, and the
/// URL may carry a scheme, userinfo, port or path.
///
/// This gates a privacy warning, so it errs toward "remote": anything it cannot
/// positively identify as loopback (`0.0.0.0`, `127.1`, IPv4-mapped IPv6, IDN
/// hosts, ...) returns `false`.
///
/// The authority ends at the first `/`, `?`, `#` or `\`. These are the
/// delimiters a WHATWG URL parser applies to http(s) URLs, so for a URL such as
/// `http://evil.com?@localhost` the check sees `evil.com`, the host that would
/// actually be contacted. Scheme detection never slices inside a multi-byte
/// character, so non-ASCII input cannot panic.
#[must_use]
pub fn is_local_ollama_url(url: &str) -> bool {
    // `get` yields None (rather than panicking) when the prefix length would
    // fall inside a multi-byte character.
    let rest = ["https://", "http://"]
        .iter()
        .find_map(|scheme| {
            url.get(..scheme.len())
                .filter(|head| head.eq_ignore_ascii_case(scheme))
                .map(|_| &url[scheme.len()..])
        })
        .unwrap_or(url);
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#' | '\\'))
        .next()
        .unwrap_or(rest);
    // Strip userinfo: the host follows the last '@' in the authority.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);
    let host = match host_port.strip_prefix('[') {
        Some(bracketed) => bracketed.split(']').next().unwrap_or(bracketed),
        None => host_port.split(':').next().unwrap_or(host_port),
    };
    let host = host.to_ascii_lowercase();
    if host == "localhost" || host == "::1" {
        return true;
    }
    match host.strip_prefix("127.") {
        Some(octets) => {
            let parts: Vec<&str> = octets.split('.').collect();
            parts.len() == 3
                && parts
                    .iter()
                    .all(|p| !p.is_empty() && p.parse::<u8>().is_ok())
        }
        None => false,
    }
}
/// Checks, before the first chat request, that the configured brain endpoint answers.
///
/// 1. Resolves the base URL with `resolve_ollama_url`.
/// 2. If the URL is not loopback (per `is_local_ollama_url`) and
///    `CLAUDETTE_ALLOW_REMOTE_OLLAMA` is unset, empty or `0`, prints a privacy
///    warning to stderr. The probe still proceeds.
/// 3. If `CLAUDETTE_SKIP_OLLAMA_PROBE` is any non-empty value other than `0`,
///    returns the URL without network I/O.
/// 4. Otherwise sends a GET with a 3-second timeout: to `{url}/v1/models` in
///    OpenAI-compat mode, or to `{url}` itself for native Ollama. Any 2xx or
///    3xx status counts as reachable.
///
/// On success returns the resolved base URL.
///
/// # Errors
/// Returns a human-readable message when the probe client cannot be built,
/// the endpoint is unreachable, or it answers with a non-2xx/3xx status.
pub fn probe_ollama() -> Result<String, String> {
    // Shared parsing for boolean env switches: set, non-empty and not "0".
    let env_flag =
        |name: &str| std::env::var(name).ok().is_some_and(|v| !v.is_empty() && v != "0");
    let url = resolve_ollama_url();
    if !is_local_ollama_url(&url) && !env_flag("CLAUDETTE_ALLOW_REMOTE_OLLAMA") {
        // The warning sign is written as an escape so an editor with the wrong
        // encoding cannot turn it into mojibake.
        eprintln!(
            "\u{26a0} OLLAMA_HOST points at a non-loopback address: {url}\n\
             Every prompt, tool call, and piece of memory/email/calendar\n\
             data will be sent to that host. Claudette's default posture\n\
             is local-only; a remote endpoint turns it into a cloud client.\n\
             If this is intentional, set CLAUDETTE_ALLOW_REMOTE_OLLAMA=1\n\
             to silence this warning."
        );
    }
    if env_flag("CLAUDETTE_SKIP_OLLAMA_PROBE") {
        return Ok(url);
    }
    let client = reqwest::blocking::Client::builder()
        .timeout(Duration::from_secs(3))
        .build()
        .map_err(|e| format!("could not build probe client: {e}"))?;
    let (probe_url, mode_label) = if resolve_openai_compat() {
        (format!("{url}/v1/models"), "OpenAI-compat brain")
    } else {
        (url.clone(), "Ollama")
    };
    match client.get(&probe_url).send() {
        Ok(resp) if resp.status().is_success() || resp.status().is_redirection() => Ok(url),
        Ok(resp) => Err(format!(
            "{mode_label} at {probe_url} returned HTTP {} — is a different service bound to that port?",
            resp.status()
        )),
        Err(e) => Err(format!(
            "{mode_label} not reachable at {probe_url} ({e}). Start the server \
             (or set OLLAMA_HOST), then retry. Set CLAUDETTE_SKIP_OLLAMA_PROBE=1 to bypass."
        )),
    }
}
impl ApiClient for OllamaApiClient {
    /// Sends one chat turn and returns the assistant's events.
    ///
    /// Native mode POSTs a streaming request to `/api/chat` and consumes the
    /// NDJSON lines. OpenAI-compat mode POSTs a non-streaming request to
    /// `/v1/chat/completions` and parses the single JSON reply. A non-2xx
    /// response becomes an error carrying at most the first 400 characters of
    /// its body.
    fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
        let payload = self.build_chat_body(&request);
        let endpoint = if self.openai_compat {
            "/v1/chat/completions"
        } else {
            "/api/chat"
        };
        let endpoint_url = format!("{}{endpoint}", self.base_url);
        let response = self
            .http
            .post(&endpoint_url)
            .json(&payload)
            .send()
            .map_err(|e| RuntimeError::new(format!("Brain request failed: {e}")))?;
        let status = response.status();
        if !status.is_success() {
            let body_text = response.text().unwrap_or_default();
            let snippet: String = body_text.chars().take(400).collect();
            return Err(RuntimeError::new(format!("Brain HTTP {status}: {snippet}")));
        }
        if !self.openai_compat {
            return self.consume_stream_lines(BufReader::new(response));
        }
        let reply: Value = response.json().map_err(|e| {
            RuntimeError::new(format!("OpenAI-compat response parse failed: {e}"))
        })?;
        self.parse_openai_response(&reply)
    }
}
impl OllamaApiClient {
    /// Builds the JSON request body for whichever endpoint is configured.
    ///
    /// The tool list comes from [`ToolsProvider::current`] and is capped via
    /// `CLAUDETTE_MAX_TOOLS` *before* the history budget is computed, so only
    /// the tools actually sent are charged against the context window. Both
    /// modes sample greedily (`temperature: 0.0`). Native requests stream and
    /// set `"think": false`.
    fn build_chat_body(&self, request: &ApiRequest) -> Value {
        let tools = self.tools.current();
        let tools = match resolve_max_tools() {
            Some(cap) => cap_tools(tools, cap),
            None => tools,
        };
        let history_budget = self.history_budget_chars_for_tools(request, &tools);
        if self.openai_compat {
            let messages = build_messages_openai_compat(request, history_budget);
            return json!({
                "model": self.model,
                "messages": messages,
                "tools": tools,
                "stream": false,
                "temperature": 0.0,
                "max_tokens": self.num_predict,
            });
        }
        let messages = build_messages(request, history_budget);
        json!({
            "model": self.model,
            "messages": messages,
            "tools": tools,
            "stream": true,
            "think": false,
            "options": {
                "temperature": 0.0,
                "num_ctx": self.num_ctx,
                "num_predict": self.num_predict
            }
        })
    }

    /// Converts a non-streaming OpenAI-compatible chat completion into events.
    ///
    /// Events are emitted in this order:
    /// 1. one `TextDelta` for non-empty `choices[0].message.content`, also passed
    ///    to the text callback followed by `"\n"`;
    /// 2. one `ToolUse` per `tool_calls` entry;
    /// 3. `Usage`;
    /// 4. `MessageStop`.
    ///
    /// The spec sends `function.arguments` as a JSON string. Some compatible
    /// servers send the object itself, which is serialized rather than dropped.
    ///
    /// # Errors
    /// Fails when the body carries `error.message` or lacks `choices[0].message`.
    fn parse_openai_response(&self, body: &Value) -> Result<Vec<AssistantEvent>, RuntimeError> {
        if let Some(err) = body.pointer("/error/message").and_then(Value::as_str) {
            return Err(RuntimeError::new(format!("OpenAI-compat error: {err}")));
        }
        let message = body
            .pointer("/choices/0/message")
            .ok_or_else(|| RuntimeError::new("OpenAI response missing choices[0].message"))?;
        let mut events = Vec::new();
        let content = message.get("content").and_then(Value::as_str).unwrap_or("");
        if !content.is_empty() {
            if let Some(cb) = &self.text_callback {
                cb(content);
                cb("\n");
            }
            events.push(AssistantEvent::TextDelta(content.to_string()));
        }
        if let Some(arr) = message.get("tool_calls").and_then(Value::as_array) {
            for (idx, tc) in arr.iter().enumerate() {
                let name = tc
                    .pointer("/function/name")
                    .and_then(Value::as_str)
                    .unwrap_or("unknown")
                    .to_string();
                let arguments_str = match tc.pointer("/function/arguments") {
                    Some(Value::String(s)) => s.clone(),
                    None | Some(Value::Null) => "{}".to_string(),
                    // Non-spec servers send the arguments object itself.
                    Some(other) => other.to_string(),
                };
                let id = tc
                    .get("id")
                    .and_then(Value::as_str)
                    .map_or_else(|| format!("call_{idx}"), String::from);
                events.push(AssistantEvent::ToolUse {
                    id,
                    name,
                    input: arguments_str,
                });
            }
        }
        let usage = body.get("usage");
        let input_tokens = Self::token_count(usage.and_then(|u| u.get("prompt_tokens")));
        let output_tokens = Self::token_count(usage.and_then(|u| u.get("completion_tokens")));
        events.push(AssistantEvent::Usage(TokenUsage {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        }));
        events.push(AssistantEvent::MessageStop);
        Ok(events)
    }

    /// Consumes an Ollama `/api/chat` NDJSON stream and converts it into events.
    ///
    /// Text deltas go to the text callback as they arrive and are accumulated
    /// into a single `TextDelta`. A trailing `"\n"` is sent to the callback only
    /// if any text was produced. Tool calls from *every* chunk are collected,
    /// because newer Ollama versions stream each call as soon as it is parsed.
    /// Token counts come from the `done` chunk. The result is, in order: an
    /// optional `TextDelta`, the `ToolUse`s, `Usage`, then `MessageStop`.
    ///
    /// # Errors
    /// Fails on a read error, a non-JSON line, or a chunk carrying `error`.
    fn consume_stream_lines<R: BufRead>(
        &self,
        reader: R,
    ) -> Result<Vec<AssistantEvent>, RuntimeError> {
        let mut accumulated_text = String::new();
        let mut tool_calls: Vec<Value> = Vec::new();
        let mut input_tokens: u32 = 0;
        let mut output_tokens: u32 = 0;
        for line in reader.lines() {
            let line =
                line.map_err(|e| RuntimeError::new(format!("Ollama stream read failed: {e}")))?;
            if line.trim().is_empty() {
                continue;
            }
            let chunk: Value = serde_json::from_str(&line)
                .map_err(|e| RuntimeError::new(format!("Ollama stream parse failed: {e}")))?;
            if let Some(err) = chunk.get("error").and_then(Value::as_str) {
                return Err(RuntimeError::new(format!("Ollama error: {err}")));
            }
            if let Some(content) = chunk.pointer("/message/content").and_then(Value::as_str) {
                if !content.is_empty() {
                    accumulated_text.push_str(content);
                    if let Some(cb) = &self.text_callback {
                        cb(content);
                    }
                }
            }
            if let Some(arr) = chunk
                .pointer("/message/tool_calls")
                .and_then(Value::as_array)
            {
                // Accumulate: calls may arrive spread over several chunks, so
                // keeping only the latest chunk's list would drop earlier calls.
                tool_calls.extend(arr.iter().cloned());
            }
            if chunk.get("done").and_then(Value::as_bool) == Some(true) {
                input_tokens = Self::token_count(chunk.get("prompt_eval_count"));
                output_tokens = Self::token_count(chunk.get("eval_count"));
            }
        }
        let mut events = Vec::with_capacity(tool_calls.len() + 3);
        if !accumulated_text.is_empty() {
            if let Some(cb) = &self.text_callback {
                cb("\n");
            }
            events.push(AssistantEvent::TextDelta(accumulated_text));
        }
        for (idx, tc) in tool_calls.iter().enumerate() {
            let name = tc
                .pointer("/function/name")
                .and_then(Value::as_str)
                .unwrap_or("unknown")
                .to_string();
            let arguments = tc
                .pointer("/function/arguments")
                .cloned()
                .unwrap_or(json!({}));
            let input = serde_json::to_string(&arguments).unwrap_or_else(|_| "{}".to_string());
            let id = tc
                .get("id")
                .and_then(Value::as_str)
                .map_or_else(|| format!("call_{idx}"), String::from);
            events.push(AssistantEvent::ToolUse { id, name, input });
        }
        events.push(AssistantEvent::Usage(TokenUsage {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        }));
        events.push(AssistantEvent::MessageStop);
        Ok(events)
    }

    /// Reads an optional JSON token counter as `u32`. Missing or non-integer
    /// values count as 0; values above `u32::MAX` saturate instead of wrapping.
    fn token_count(value: Option<&Value>) -> u32 {
        value
            .and_then(Value::as_u64)
            .map_or(0, |n| n.min(u64::from(u32::MAX)) as u32)
    }

    /// Test helper: the history budget computed against the provider's current tools.
    #[cfg(test)]
    fn history_budget_chars(&self, request: &ApiRequest) -> usize {
        self.history_budget_chars_for_tools(request, &self.tools.current())
    }

    /// Character budget left for conversation history once everything else that
    /// shares the context window is accounted for.
    ///
    /// `num_ctx` is converted to characters at `CHARS_PER_TOKEN`. The following
    /// are then subtracted, saturating at 0:
    /// * the reserved output (`num_predict`, converted the same way);
    /// * each system-prompt part plus 2 for the `"\n\n"` joiner (a slight over-count);
    /// * the serialized tool list;
    /// * `SAFETY_CHARS`.
    fn history_budget_chars_for_tools(&self, request: &ApiRequest, tools: &Value) -> usize {
        let total = self.num_ctx as usize * CHARS_PER_TOKEN;
        let output = self.num_predict as usize * CHARS_PER_TOKEN;
        let system: usize = request
            .system_prompt
            .iter()
            .map(|s| s.len() + 2)
            .sum();
        let tools_chars = tools.to_string().len();
        total
            .saturating_sub(output)
            .saturating_sub(system)
            .saturating_sub(tools_chars)
            .saturating_sub(SAFETY_CHARS)
    }
}
/// Assembles the native Ollama `messages` array. It starts with a system message
/// (all system-prompt parts joined by a blank line, omitted when that is empty),
/// followed by the conversation history trimmed to `history_budget_chars`.
fn build_messages(request: &ApiRequest, history_budget_chars: usize) -> Vec<Value> {
    let trimmed_history = truncate_to_budget(
        build_history_messages(&request.messages),
        history_budget_chars,
    );
    let system_prompt = request.system_prompt.join("\n\n");
    let system_message = if system_prompt.is_empty() {
        None
    } else {
        Some(json!({ "role": "system", "content": system_prompt }))
    };
    system_message.into_iter().chain(trimmed_history).collect()
}
fn build_history_messages(msgs: &[crate::ConversationMessage]) -> Vec<Value> {
let mut messages = Vec::with_capacity(msgs.len());
for msg in msgs {
let role = role_str(msg.role);
let mut content_parts: Vec<String> = Vec::new();
let mut tool_calls: Vec<Value> = Vec::new();
let mut images: Vec<String> = Vec::new();
for block in &msg.blocks {
match block {
ContentBlock::Text { text } => {
content_parts.push(text.clone());
}
ContentBlock::Image { data_b64, .. } => {
images.push(data_b64.clone());
}
ContentBlock::ToolUse { id, name, input } => {
let arguments: Value =
serde_json::from_str(input).unwrap_or_else(|_| json!({}));
tool_calls.push(json!({
"id": id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
}
}));
}
ContentBlock::ToolResult { output, .. } => {
content_parts.push(output.clone());
}
}
}
let content = content_parts.join("\n");
let mut obj = json!({
"role": role,
"content": content,
});
if !tool_calls.is_empty() {
obj["tool_calls"] = json!(tool_calls);
}
if !images.is_empty() {
obj["images"] = json!(images);
}
messages.push(obj);
}
messages
}
/// Assembles the OpenAI-compatible `messages` array. It starts with a system
/// message (system-prompt parts joined by a blank line, omitted when empty),
/// followed by the history in OpenAI shape, trimmed to `history_budget_chars`.
fn build_messages_openai_compat(request: &ApiRequest, history_budget_chars: usize) -> Vec<Value> {
    let trimmed_history = truncate_to_budget(
        build_history_messages_openai_compat(&request.messages),
        history_budget_chars,
    );
    let system_prompt = request.system_prompt.join("\n\n");
    let system_message = if system_prompt.is_empty() {
        None
    } else {
        Some(json!({ "role": "system", "content": system_prompt }))
    };
    system_message.into_iter().chain(trimmed_history).collect()
}
/// Converts conversation messages into OpenAI chat-completions messages.
///
/// * A `Tool`-role message expands into one `{"role":"tool","tool_call_id",...}`
///   message per tool-result block; its other blocks are ignored.
/// * Any other message becomes one message. Its `content` takes one of three shapes:
///   - if there are tool calls: a string, or `null` when there is no text;
///   - otherwise, with images: an array of an optional text part followed by
///     `image_url` data-URI parts;
///   - otherwise: the plain text joined with `"\n"`.
/// * Tool uses keep `arguments` as the original string.
fn build_history_messages_openai_compat(msgs: &[crate::ConversationMessage]) -> Vec<Value> {
    let mut out = Vec::with_capacity(msgs.len());
    for msg in msgs {
        if let MessageRole::Tool = msg.role {
            out.extend(msg.blocks.iter().filter_map(|block| match block {
                ContentBlock::ToolResult {
                    tool_use_id,
                    output,
                    ..
                } => Some(json!({
                    "role": "tool",
                    "tool_call_id": tool_use_id,
                    "content": output,
                })),
                _ => None,
            }));
            continue;
        }
        let mut text_parts: Vec<String> = Vec::new();
        let mut calls: Vec<Value> = Vec::new();
        let mut images: Vec<Value> = Vec::new();
        for block in &msg.blocks {
            match block {
                ContentBlock::Text { text } => text_parts.push(text.clone()),
                ContentBlock::Image {
                    media_type,
                    data_b64,
                } => images.push(json!({
                    "type": "image_url",
                    "image_url": {
                        "url": format!("data:{media_type};base64,{data_b64}")
                    }
                })),
                ContentBlock::ToolUse { id, name, input } => calls.push(json!({
                    "id": id,
                    "type": "function",
                    "function": {
                        "name": name,
                        "arguments": input,
                    }
                })),
                // Tool results only travel in Tool-role messages (handled above).
                ContentBlock::ToolResult { .. } => {}
            }
        }
        let text = text_parts.join("\n");
        let content = if !calls.is_empty() {
            if text.is_empty() {
                Value::Null
            } else {
                Value::String(text)
            }
        } else if images.is_empty() {
            json!(text)
        } else {
            let mut parts: Vec<Value> =
                Vec::with_capacity(images.len() + usize::from(!text.is_empty()));
            if !text.is_empty() {
                parts.push(json!({ "type": "text", "text": text }));
            }
            parts.extend(images);
            Value::Array(parts)
        };
        let mut entry = json!({ "role": role_str(msg.role) });
        entry["content"] = content;
        if !calls.is_empty() {
            entry["tool_calls"] = json!(calls);
        }
        out.push(entry);
    }
    out
}
/// Trims a chat history (oldest first) to about `budget_chars` characters by
/// dropping the oldest messages first. Survivors are returned in their
/// original order.
///
/// These hold regardless of budget:
/// * The newest message and the most recent `user` message always survive.
/// * A kept `tool` message pins the message just before it. A tool result is
///   therefore never sent without the `assistant.tool_calls` turn (or the
///   earlier result in the same run) that it answers.
/// * When a `tool` message is dropped, the older `tool` messages directly before
///   it are dropped too, because they answer the same assistant turn. Otherwise
///   that turn could be sent with only some of its tool calls answered, which
///   strict OpenAI-compatible servers reject.
/// * Finally, any non-final message with `tool_calls` that is no longer followed
///   by a `tool` message is removed as an orphan.
///
/// Costs come from `estimate_message_chars`. Forced messages may push the
/// total over budget.
fn truncate_to_budget(messages: Vec<Value>, budget_chars: usize) -> Vec<Value> {
    let total = messages.len();
    if total == 0 {
        return Vec::new();
    }
    let mut must_keep = vec![false; total];
    must_keep[total - 1] = true;
    if let Some(idx) = messages
        .iter()
        .rposition(|m| m.get("role").and_then(Value::as_str) == Some("user"))
    {
        must_keep[idx] = true;
    }
    let mut kept: Vec<Value> = Vec::with_capacity(total);
    let mut used = 0usize;
    // True after a tool message was dropped, until a non-tool message is reached:
    // older tool messages in that same contiguous run are dropped as well.
    let mut dropping_tool_run = false;
    for (idx_from_end, msg) in messages.into_iter().rev().enumerate() {
        let idx = total - 1 - idx_from_end;
        let is_tool = msg.get("role").and_then(Value::as_str) == Some("tool");
        if !is_tool {
            dropping_tool_run = false;
        }
        let cost = estimate_message_chars(&msg);
        let force = must_keep[idx];
        let over_budget = used.saturating_add(cost) > budget_chars;
        if !force && (over_budget || (is_tool && dropping_tool_run)) {
            if is_tool {
                dropping_tool_run = true;
            }
            continue;
        }
        if is_tool && idx > 0 {
            must_keep[idx - 1] = true;
        }
        used = used.saturating_add(cost);
        kept.push(msg);
    }
    kept.reverse();
    let kept_len = kept.len();
    if kept_len >= 2 {
        let drop_mask: Vec<bool> = kept
            .iter()
            .enumerate()
            .map(|(i, msg)| {
                if i == kept_len - 1 {
                    return false;
                }
                let has_tool_calls = msg
                    .get("tool_calls")
                    .and_then(|v| v.as_array())
                    .is_some_and(|a| !a.is_empty());
                has_tool_calls
                    && kept
                        .get(i + 1)
                        .and_then(|n| n.get("role"))
                        .and_then(Value::as_str)
                        != Some("tool")
            })
            .collect();
        if drop_mask.iter().any(|d| *d) {
            kept = kept
                .into_iter()
                .zip(drop_mask)
                .filter(|(_, drop)| !*drop)
                .map(|(m, _)| m)
                .collect();
        }
    }
    kept
}
/// Rough size of a chat message in characters (strictly, UTF-8 bytes).
///
/// * A string `content` counts its length.
/// * An array `content` (multimodal parts) counts each part's serialized JSON.
/// * `tool_calls` and `images` count their serialized JSON.
///
/// All other fields, including `role`, are ignored.
fn estimate_message_chars(msg: &Value) -> usize {
    let serialized_len = |key: &str| msg.get(key).map_or(0, |v| v.to_string().len());
    let content_len = match msg.get("content") {
        Some(Value::String(text)) => text.len(),
        Some(Value::Array(parts)) => parts
            .iter()
            .fold(0, |acc, part| acc + part.to_string().len()),
        _ => 0,
    };
    content_len + serialized_len("tool_calls") + serialized_len("images")
}
/// Wire-format role name for a `MessageRole`, used by both request formats.
fn role_str(role: MessageRole) -> &'static str {
    match role {
        MessageRole::System => "system",
        MessageRole::Tool => "tool",
        MessageRole::Assistant => "assistant",
        MessageRole::User => "user",
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{ConversationMessage, MessageRole};
#[test]
fn is_local_ollama_url_recognises_loopback() {
for url in [
"http://localhost:11434",
"https://localhost:11434",
"http://LOCALHOST:11434",
"HTTP://localhost:11434",
"HTTPS://localhost:11434",
"http://user:pass@localhost:11434",
"http://127.0.0.1:11434",
"http://127.0.0.2:11434",
"http://127.255.255.255:11434",
"http://[::1]:11434",
"localhost:11434",
"127.0.0.1",
] {
assert!(
is_local_ollama_url(url),
"expected local, but {url} was flagged remote"
);
}
}
#[test]
fn is_local_ollama_url_rejects_remote() {
for url in [
"http://ollama.example.com:11434",
"https://attacker.evil:11434",
"http://192.168.1.10:11434", "http://10.0.0.1:11434", "http://1.2.3.4:11434",
"http://[2001:db8::1]:11434",
"http://0.0.0.0:11434",
"http://[::]:11434",
"http://localhost:fakepass@evil.com:11434",
"http://localhost@evil.com:11434",
] {
assert!(
!is_local_ollama_url(url),
"expected remote, but {url} was flagged local"
);
}
}
#[test]
fn probe_ollama_skip_env_short_circuits() {
let prev_host = std::env::var("OLLAMA_HOST").ok();
let prev_skip = std::env::var("CLAUDETTE_SKIP_OLLAMA_PROBE").ok();
std::env::set_var(
"OLLAMA_HOST",
"http://definitely-not-a-real-host.invalid:11434",
);
std::env::set_var("CLAUDETTE_SKIP_OLLAMA_PROBE", "1");
let result = probe_ollama();
assert!(
result.is_ok(),
"skip env should bypass the probe; got {result:?}"
);
match prev_host {
Some(v) => std::env::set_var("OLLAMA_HOST", v),
None => std::env::remove_var("OLLAMA_HOST"),
}
match prev_skip {
Some(v) => std::env::set_var("CLAUDETTE_SKIP_OLLAMA_PROBE", v),
None => std::env::remove_var("CLAUDETTE_SKIP_OLLAMA_PROBE"),
}
}
fn text_msg(role: &str, content: &str) -> Value {
json!({ "role": role, "content": content })
}
fn user_text(text: &str) -> ConversationMessage {
ConversationMessage {
role: MessageRole::User,
blocks: vec![ContentBlock::Text {
text: text.to_string(),
}],
usage: None,
}
}
#[test]
fn truncate_keeps_everything_when_under_budget() {
let messages = vec![
text_msg("user", "hello"),
text_msg("assistant", "hi"),
text_msg("user", "how are you"),
];
let kept = truncate_to_budget(messages, 1000);
assert_eq!(kept.len(), 3);
assert_eq!(kept[0]["content"], "hello");
assert_eq!(kept[2]["content"], "how are you");
}
#[test]
fn truncate_drops_oldest_first() {
let messages = vec![
text_msg("user", "first-old0"), text_msg("assistant", "second-mi"), text_msg("user", "third-new0"), ];
let kept = truncate_to_budget(messages, 25);
assert_eq!(kept.len(), 2, "expected 2 kept, got {kept:?}");
assert_eq!(kept[0]["content"], "second-mi");
assert_eq!(kept[1]["content"], "third-new0");
}
#[test]
fn truncate_zero_budget_still_keeps_newest() {
let messages = vec![text_msg("user", "anything")];
let kept = truncate_to_budget(messages, 0);
assert_eq!(kept.len(), 1);
assert_eq!(kept[0]["content"], "anything");
}
#[test]
fn truncate_empty_input_returns_empty() {
let kept = truncate_to_budget(Vec::new(), 1000);
assert!(kept.is_empty());
}
#[test]
fn truncate_keeps_oversized_newest_alone() {
let messages = vec![text_msg("user", "way too long for the budget")];
let kept = truncate_to_budget(messages, 5);
assert_eq!(kept.len(), 1, "newest must always survive");
assert_eq!(kept[0]["content"], "way too long for the budget");
}
#[test]
fn truncate_skips_oversized_older_keeps_smaller_oldest() {
let messages = vec![
text_msg("user", "tiny old"), text_msg("assistant", &"X".repeat(500)), text_msg("user", "newest"), ];
let kept = truncate_to_budget(messages, 30);
assert_eq!(kept.len(), 2, "kept: {kept:?}");
assert_eq!(kept[0]["content"], "tiny old");
assert_eq!(kept[1]["content"], "newest");
}
fn assistant_with_tool_call(call_id: &str, fn_name: &str) -> Value {
json!({
"role": "assistant",
"content": "",
"tool_calls": [{
"id": call_id,
"type": "function",
"function": { "name": fn_name, "arguments": "{}" }
}]
})
}
fn openai_tool_msg(call_id: &str, content: &str) -> Value {
json!({
"role": "tool",
"tool_call_id": call_id,
"content": content,
})
}
#[test]
fn truncate_pins_user_query_under_giant_tool_result_ollama_shape() {
let messages = vec![
text_msg("user", "read the big file"),
assistant_with_tool_call("call_1", "read_file"),
text_msg("tool", &"X".repeat(50_000)),
];
let kept = truncate_to_budget(messages, 100);
let last_user = kept
.iter()
.rev()
.find(|m| m.get("role").and_then(Value::as_str) == Some("user"));
assert!(
last_user.is_some(),
"user query must survive even under giant tool results: {kept:?}"
);
assert_eq!(last_user.unwrap()["content"], "read the big file");
}
#[test]
fn truncate_pins_user_query_under_giant_tool_result_openai_shape() {
let messages = vec![
text_msg("user", "read the big file"),
assistant_with_tool_call("call_1", "read_file"),
openai_tool_msg("call_1", &"X".repeat(50_000)),
];
let kept = truncate_to_budget(messages, 100);
let last_user = kept
.iter()
.rev()
.find(|m| m.get("role").and_then(Value::as_str) == Some("user"));
assert!(
last_user.is_some(),
"user query must survive even under giant tool results: {kept:?}"
);
assert_eq!(last_user.unwrap()["content"], "read the big file");
}
#[test]
fn truncate_pins_assistant_tool_calls_when_keeping_tool_ollama_shape() {
let messages = vec![
text_msg("user", "first turn"),
assistant_with_tool_call("call_1", "list_dir"),
text_msg("tool", "(small result)"),
];
let kept = truncate_to_budget(messages, 200);
let last = kept.last().unwrap();
assert_eq!(last["role"], "tool");
let assistant_idx = kept.len() - 2;
assert_eq!(kept[assistant_idx]["role"], "assistant");
assert!(
kept[assistant_idx]["tool_calls"]
.as_array()
.is_some_and(|a| !a.is_empty()),
"expected paired tool_calls before tool, got {kept:?}"
);
}
#[test]
fn truncate_pins_assistant_tool_calls_when_keeping_tool_openai_shape() {
let messages = vec![
text_msg("user", "first turn"),
assistant_with_tool_call("call_1", "list_dir"),
openai_tool_msg("call_1", "(small result)"),
];
let kept = truncate_to_budget(messages, 200);
let last = kept.last().unwrap();
assert_eq!(last["role"], "tool");
assert_eq!(last["tool_call_id"], "call_1");
let assistant_idx = kept.len() - 2;
assert_eq!(kept[assistant_idx]["role"], "assistant");
}
#[test]
fn truncate_drops_orphan_assistant_when_tool_skipped_ollama_shape() {
let messages = vec![
text_msg("user", "what's in src?"),
assistant_with_tool_call("call_1", "list_dir"),
text_msg("tool", &"X".repeat(50_000)),
text_msg("user", "and what about tests?"),
];
let kept = truncate_to_budget(messages, 300);
for (i, msg) in kept.iter().enumerate() {
if i == kept.len() - 1 {
continue;
}
let has_tc = msg
.get("tool_calls")
.and_then(|v| v.as_array())
.is_some_and(|a| !a.is_empty());
if has_tc {
let next = kept.get(i + 1);
let next_is_tool =
next.and_then(|n| n.get("role")).and_then(Value::as_str) == Some("tool");
assert!(
next_is_tool,
"assistant.tool_calls at idx {i} is orphaned (next: {next:?}); full kept: {kept:?}"
);
}
}
}
#[test]
fn truncate_drops_orphan_assistant_when_tool_skipped_openai_shape() {
let messages = vec![
text_msg("user", "what's in src?"),
assistant_with_tool_call("call_1", "list_dir"),
openai_tool_msg("call_1", &"X".repeat(50_000)),
text_msg("user", "and what about tests?"),
];
let kept = truncate_to_budget(messages, 300);
for (i, msg) in kept.iter().enumerate() {
if i == kept.len() - 1 {
continue;
}
let has_tc = msg
.get("tool_calls")
.and_then(|v| v.as_array())
.is_some_and(|a| !a.is_empty());
if has_tc {
let next = kept.get(i + 1);
let next_is_tool =
next.and_then(|n| n.get("role")).and_then(Value::as_str) == Some("tool");
assert!(
next_is_tool,
"assistant.tool_calls at idx {i} is orphaned (next: {next:?}); full kept: {kept:?}"
);
}
}
}
#[test]
fn estimate_message_chars_counts_content_and_tool_calls() {
let plain = text_msg("user", "hello"); assert_eq!(estimate_message_chars(&plain), 5);
let with_tools = json!({
"role": "assistant",
"content": "ok",
"tool_calls": [{ "id": "x", "type": "function", "function": { "name": "f", "arguments": {} }}],
});
let chars = estimate_message_chars(&with_tools);
assert!(chars > 2, "expected >2, got {chars}");
}
#[test]
fn build_messages_always_keeps_system_prompt_and_newest() {
let request = ApiRequest {
messages: vec![user_text("this is the only thing the user said")],
system_prompt: vec!["you are an assistant".to_string()],
};
let result = build_messages(&request, 0);
assert_eq!(result.len(), 2, "expected system + newest, got {result:?}");
assert_eq!(result[0]["role"], "system");
assert_eq!(result[0]["content"], "you are an assistant");
assert_eq!(result[1]["content"], "this is the only thing the user said");
}
#[test]
fn build_messages_truncates_history_under_budget() {
let request = ApiRequest {
messages: vec![
user_text("ancient turn that should fall off"),
user_text("middle turn that should also fall off"),
user_text("newest"),
],
system_prompt: vec!["sys".to_string()],
};
let result = build_messages(&request, 20);
assert_eq!(
result.len(),
2,
"expected system + 1 history, got {result:?}"
);
assert_eq!(result[0]["role"], "system");
assert_eq!(result[1]["content"], "newest");
}
#[test]
fn history_budget_shrinks_with_larger_system_prompt() {
let mut client = OllamaApiClient::new("test", json!([]));
client.num_ctx = 1000; client.num_predict = 100;
let small_sys = ApiRequest {
messages: Vec::new(),
system_prompt: vec!["short".to_string()],
};
let big_sys = ApiRequest {
messages: Vec::new(),
system_prompt: vec!["x".repeat(500)],
};
let small_budget = client.history_budget_chars(&small_sys);
let big_budget = client.history_budget_chars(&big_sys);
assert!(
small_budget > big_budget,
"smaller system prompt should leave more room for history"
);
assert!(big_budget + 500 <= small_budget + 10);
}
use std::io::Cursor;
fn fake_stream(lines: &[&str]) -> Cursor<Vec<u8>> {
Cursor::new(lines.join("\n").into_bytes())
}
#[test]
fn stream_text_only_single_chunk() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[
r#"{"message":{"role":"assistant","content":"Hello"},"done":false}"#,
r#"{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":10,"eval_count":3}"#,
]);
let events = client.consume_stream_lines(stream).unwrap();
assert_eq!(events.len(), 3);
match &events[0] {
AssistantEvent::TextDelta(t) => assert_eq!(t, "Hello"),
other => panic!("expected TextDelta, got {other:?}"),
}
match &events[1] {
AssistantEvent::Usage(u) => {
assert_eq!(u.input_tokens, 10);
assert_eq!(u.output_tokens, 3);
}
other => panic!("expected Usage, got {other:?}"),
}
assert!(matches!(events[2], AssistantEvent::MessageStop));
}
#[test]
fn stream_text_accumulates_multiple_deltas() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[
r#"{"message":{"role":"assistant","content":"Hel"},"done":false}"#,
r#"{"message":{"role":"assistant","content":"lo, "},"done":false}"#,
r#"{"message":{"role":"assistant","content":"world"},"done":false}"#,
r#"{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":5,"eval_count":7}"#,
]);
let events = client.consume_stream_lines(stream).unwrap();
match &events[0] {
AssistantEvent::TextDelta(t) => assert_eq!(t, "Hello, world"),
other => panic!("expected TextDelta, got {other:?}"),
}
}
#[test]
fn stream_tool_call_on_done_chunk() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[
r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_time","arguments":{}}}]},"done":true,"prompt_eval_count":20,"eval_count":2}"#,
]);
let events = client.consume_stream_lines(stream).unwrap();
assert_eq!(events.len(), 3);
match &events[0] {
AssistantEvent::ToolUse { name, id, .. } => {
assert_eq!(name, "get_time");
assert_eq!(id, "call_1");
}
other => panic!("expected ToolUse, got {other:?}"),
}
}
#[test]
fn stream_text_then_tool_call() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[
r#"{"message":{"role":"assistant","content":"Let me check"},"done":false}"#,
r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"x","type":"function","function":{"name":"get_time","arguments":{}}}]},"done":true,"prompt_eval_count":15,"eval_count":4}"#,
]);
let events = client.consume_stream_lines(stream).unwrap();
assert_eq!(events.len(), 4);
assert!(matches!(&events[0], AssistantEvent::TextDelta(t) if t == "Let me check"));
assert!(matches!(&events[1], AssistantEvent::ToolUse { name, .. } if name == "get_time"));
}
#[test]
fn stream_error_chunk_returns_error() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[r#"{"error":"model not found"}"#]);
let result = client.consume_stream_lines(stream);
assert!(result.is_err());
let err = format!("{:?}", result.unwrap_err());
assert!(err.contains("model not found"), "got: {err}");
}
#[test]
fn stream_missing_id_synthesises_one() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[
r#"{"message":{"role":"assistant","content":"","tool_calls":[{"type":"function","function":{"name":"a","arguments":{}}}]},"done":true,"prompt_eval_count":0,"eval_count":0}"#,
]);
let events = client.consume_stream_lines(stream).unwrap();
match &events[0] {
AssistantEvent::ToolUse { id, .. } => {
assert!(id.starts_with("call_"), "expected synthesised id, got {id}");
}
other => panic!("expected ToolUse, got {other:?}"),
}
}
/// A stream with no lines at all still yields exactly two events: a `Usage`
/// with zero input/output tokens, then `MessageStop`.
#[test]
fn stream_empty_returns_only_usage_and_stop() {
    let client = OllamaApiClient::new("test", json!([]));
    let events = client.consume_stream_lines(fake_stream(&[])).unwrap();
    assert_eq!(events.len(), 2);
    let (head, tail) = (&events[0], &events[1]);
    match head {
        AssistantEvent::Usage(u) => {
            assert_eq!(u.input_tokens, 0);
            assert_eq!(u.output_tokens, 0);
        }
        other => panic!("expected Usage, got {other:?}"),
    }
    assert!(matches!(*tail, AssistantEvent::MessageStop));
}
/// The text callback must be invoked once per non-empty content delta, in
/// order, followed by a single trailing `"\n"` once the stream finishes.
#[test]
fn stream_callback_fires_per_delta_and_trailing_newline() {
    use std::sync::{Arc, Mutex};
    let log: Arc<Mutex<Vec<String>>> = Arc::default();
    let sink = Arc::clone(&log);
    let cb: TextCallback = Box::new(move |s: &str| sink.lock().unwrap().push(s.to_owned()));
    let client = OllamaApiClient::new("test", json!([])).with_text_callback(cb);
    let first = r#"{"message":{"role":"assistant","content":"foo"},"done":false}"#;
    let last = r#"{"message":{"role":"assistant","content":"bar"},"done":true,"prompt_eval_count":1,"eval_count":1}"#;
    client.consume_stream_lines(fake_stream(&[first, last])).unwrap();
    let entries = log.lock().unwrap();
    let expected: Vec<String> = ["foo", "bar", "\n"].iter().map(|s| s.to_string()).collect();
    assert_eq!(
        *entries,
        expected,
        "callback should fire foo, bar, then trailing \\n"
    );
}
/// When the stream carries only a tool call (empty content), the text callback
/// must never fire, not even for the trailing newline.
#[test]
fn stream_callback_no_trailing_newline_when_only_tool_call() {
    use std::sync::{Arc, Mutex};
    let log: Arc<Mutex<Vec<String>>> = Arc::default();
    let sink = Arc::clone(&log);
    let cb: TextCallback = Box::new(move |s: &str| sink.lock().unwrap().push(s.to_owned()));
    let client = OllamaApiClient::new("test", json!([])).with_text_callback(cb);
    let tool_only = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"x","type":"function","function":{"name":"a","arguments":{}}}]},"done":true,"prompt_eval_count":0,"eval_count":0}"#;
    client.consume_stream_lines(fake_stream(&[tool_only])).unwrap();
    let entries = log.lock().unwrap();
    assert!(
        entries.is_empty(),
        "no callbacks expected when content is empty (only a tool call), got {entries:?}"
    );
}
/// Empty lines interleaved in the NDJSON stream are ignored rather than parsed,
/// so the first event is still the `"hi"` text delta.
#[test]
fn stream_skips_blank_lines() {
    let lines = [
        "",
        r#"{"message":{"role":"assistant","content":"hi"},"done":false}"#,
        "",
        r#"{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":1,"eval_count":1}"#,
        "",
    ];
    let client = OllamaApiClient::new("test", json!([]));
    let events = client.consume_stream_lines(fake_stream(&lines)).unwrap();
    let first = &events[0];
    assert!(matches!(first, AssistantEvent::TextDelta(t) if t == "hi"));
}
/// A client carrying the secretary tool set must get a smaller history budget
/// than an otherwise identical tool-less client, and the two budgets must
/// differ by roughly the serialized length of that tool set (within 4 chars).
#[test]
fn history_budget_subtracts_tools_schema() {
    let request = ApiRequest {
        system_prompt: vec!["sys".to_string()],
        messages: Vec::new(),
    };
    let mut empty_tools = OllamaApiClient::new("test", json!([]));
    let mut full_tools = OllamaApiClient::new("test", crate::secretary_tools_json());
    // Pin both clients to the same window so only the tool schema differs.
    for client in [&mut empty_tools, &mut full_tools] {
        client.num_ctx = 16384;
        client.num_predict = 1024;
    }
    let empty_budget = empty_tools.history_budget_chars(&request);
    let full_budget = full_tools.history_budget_chars(&request);
    let tools_chars = crate::secretary_tools_json().to_string().len();
    assert!(
        full_budget < empty_budget,
        "tool registry must shrink the history budget"
    );
    let delta = empty_budget - full_budget;
    assert!(
        delta.abs_diff(tools_chars) <= 4,
        "delta {delta} should approximately equal tools_chars {tools_chars}"
    );
}
/// Serialises the tests below that mutate `CLAUDETTE_OPENAI_COMPAT`.
///
/// The process environment is shared by every test thread, and the libtest
/// harness runs tests in parallel by default. Without this lock, one test can
/// observe a value written by a sibling; for example, the "unset" case can read
/// the "1" set by `resolve_openai_compat_set_to_one_returns_true`. Tests outside
/// this group that read the variable without taking this lock are not protected.
static OPENAI_COMPAT_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Scoped override of `CLAUDETTE_OPENAI_COMPAT` for a single test.
///
/// Holds [`OPENAI_COMPAT_ENV_LOCK`] for its whole lifetime. On drop it restores
/// whatever value the variable had before, or removes it if it was unset.
/// Restoration lives in `Drop`, so it also runs when an assertion panics; a
/// failing test therefore cannot leak its override into later tests.
struct OpenAiCompatEnv {
    prev: Option<std::ffi::OsString>,
    _lock: std::sync::MutexGuard<'static, ()>,
}

impl OpenAiCompatEnv {
    /// Acquires the lock, remembers the current value, then sets the variable
    /// to `value` (`None` removes it).
    fn pin(value: Option<&str>) -> Self {
        // A sibling test panicking while holding the lock poisons it. The `()`
        // payload has no invariants to protect, so just take the guard.
        let lock = OPENAI_COMPAT_ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // `var_os` so a non-UTF-8 previous value is restored faithfully.
        let prev = std::env::var_os("CLAUDETTE_OPENAI_COMPAT");
        match value {
            Some(v) => std::env::set_var("CLAUDETTE_OPENAI_COMPAT", v),
            None => std::env::remove_var("CLAUDETTE_OPENAI_COMPAT"),
        }
        Self { prev, _lock: lock }
    }
}

impl Drop for OpenAiCompatEnv {
    fn drop(&mut self) {
        // `Drop::drop` runs before the fields are dropped, so the lock is
        // still held while the previous value is put back.
        match self.prev.take() {
            Some(v) => std::env::set_var("CLAUDETTE_OPENAI_COMPAT", v),
            None => std::env::remove_var("CLAUDETTE_OPENAI_COMPAT"),
        }
    }
}

/// With the variable absent, compat mode is off.
#[test]
fn resolve_openai_compat_unset_returns_false() {
    let _env = OpenAiCompatEnv::pin(None);
    assert!(!resolve_openai_compat());
}

/// `CLAUDETTE_OPENAI_COMPAT=1` turns compat mode on.
#[test]
fn resolve_openai_compat_set_to_one_returns_true() {
    let _env = OpenAiCompatEnv::pin(Some("1"));
    assert!(resolve_openai_compat());
}

/// `CLAUDETTE_OPENAI_COMPAT=0` leaves compat mode off.
#[test]
fn resolve_openai_compat_set_to_zero_returns_false() {
    let _env = OpenAiCompatEnv::pin(Some("0"));
    assert!(!resolve_openai_compat());
}
/// In OpenAI-compat mode the request body must be non-streaming, use
/// `temperature` 0.0 and carry `max_tokens`, and it must omit the
/// Ollama-only `think` and `options` fields.
#[test]
fn build_chat_body_compat_uses_openai_shape() {
    let client = OllamaApiClient::new("openai/gpt-oss-20b", json!([])).with_openai_compat(true);
    let req = ApiRequest {
        messages: vec![user_text("hi")],
        system_prompt: vec!["sys".to_string()],
    };
    let body = client.build_chat_body(&req);
    assert_eq!(body["stream"], json!(false));
    assert_eq!(body["temperature"], json!(0.0));
    assert!(body.get("max_tokens").is_some(), "max_tokens missing");
    let ollama_only = [
        ("think", "think field must NOT be sent in compat mode"),
        ("options", "options.* must NOT be sent in compat mode"),
    ];
    for (key, why) in ollama_only {
        assert!(body.get(key).is_none(), "{why}");
    }
}
/// Without compat mode the body keeps the native Ollama shape: it streams,
/// sends `think: false` and an `options` object, and omits `max_tokens`.
#[test]
fn build_chat_body_default_stays_ollama_shape() {
    let client = OllamaApiClient::new("qwen3.5:4b", json!([]));
    let req = ApiRequest {
        messages: vec![user_text("hi")],
        system_prompt: vec!["sys".to_string()],
    };
    let body = client.build_chat_body(&req);
    assert_eq!(body["stream"], json!(true));
    assert_eq!(body["think"], json!(false));
    let has_options = body.get("options").is_some();
    assert!(has_options, "options.* required for ollama");
    let has_max_tokens = body.get("max_tokens").is_some();
    assert!(!has_max_tokens, "max_tokens is openai-only");
}
/// A plain-text chat completion maps to `TextDelta`, then `Usage`
/// (`prompt_tokens` → input, `completion_tokens` → output), then `MessageStop`.
#[test]
fn parse_openai_response_text_only() {
    let client = OllamaApiClient::new("test", json!([])).with_openai_compat(true);
    let choice = json!({
        "index": 0,
        "message": {"role": "assistant", "content": "Hello world"},
        "finish_reason": "stop"
    });
    let usage = json!({"prompt_tokens": 10, "completion_tokens": 3, "total_tokens": 13});
    let body = json!({
        "id": "chatcmpl-x",
        "choices": [choice],
        "usage": usage
    });
    let events = client.parse_openai_response(&body).unwrap();
    assert_eq!(events.len(), 3);
    match &events[0] {
        AssistantEvent::TextDelta(t) => assert_eq!(t, "Hello world"),
        other => panic!("expected TextDelta, got {other:?}"),
    }
    match &events[1] {
        AssistantEvent::Usage(u) => {
            assert_eq!(u.input_tokens, 10);
            assert_eq!(u.output_tokens, 3);
        }
        other => panic!("expected Usage, got {other:?}"),
    }
    let last = &events[2];
    assert!(matches!(*last, AssistantEvent::MessageStop));
}
/// A completion whose message is only a tool call (null content) yields a
/// `ToolUse` first. The `arguments` string is passed through verbatim as `input`.
#[test]
fn parse_openai_response_with_tool_calls() {
    let client = OllamaApiClient::new("test", json!([])).with_openai_compat(true);
    let call = json!({
        "id": "call_abc",
        "type": "function",
        "function": {
            "name": "get_time",
            "arguments": "{\"tz\":\"UTC\"}"
        }
    });
    let message = json!({
        "role": "assistant",
        "content": null,
        "tool_calls": [call]
    });
    let body = json!({
        "choices": [{
            "message": message,
            "finish_reason": "tool_calls"
        }],
        "usage": {"prompt_tokens": 50, "completion_tokens": 12}
    });
    let events = client.parse_openai_response(&body).unwrap();
    assert_eq!(events.len(), 3);
    let first = &events[0];
    match first {
        AssistantEvent::ToolUse { id, name, input } => {
            assert_eq!(id, "call_abc");
            assert_eq!(name, "get_time");
            assert_eq!(input, "{\"tz\":\"UTC\"}");
        }
        other => panic!("expected ToolUse, got {other:?}"),
    }
}
/// When a completion carries both text and a tool call, the text delta comes
/// first and the tool use second, four events in total.
#[test]
fn parse_openai_response_text_then_tool_call() {
    let client = OllamaApiClient::new("test", json!([])).with_openai_compat(true);
    let call = json!({
        "id": "x",
        "type": "function",
        "function": {"name": "get_time", "arguments": "{}"}
    });
    let body = json!({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": "Let me check the time.",
                "tool_calls": [call]
            },
            "finish_reason": "tool_calls"
        }]
    });
    let events = client.parse_openai_response(&body).unwrap();
    assert_eq!(events.len(), 4);
    let (text, tool) = (&events[0], &events[1]);
    assert!(matches!(text, AssistantEvent::TextDelta(t) if t == "Let me check the time."));
    assert!(matches!(tool, AssistantEvent::ToolUse { name, .. } if name == "get_time"));
}
/// A top-level `error` object in the response must produce `Err`, and the
/// error's debug rendering must include the upstream message.
#[test]
fn parse_openai_response_error_field_returns_err() {
    let client = OllamaApiClient::new("test", json!([])).with_openai_compat(true);
    let body = json!({
        "error": {"message": "model not found", "type": "invalid_request_error"}
    });
    let outcome = client.parse_openai_response(&body);
    assert!(outcome.is_err());
    let err = match outcome {
        Err(e) => format!("{e:?}"),
        Ok(_) => unreachable!("is_err() was just asserted"),
    };
    assert!(err.contains("model not found"), "got: {err}");
}
/// A response with no `choices` array is malformed and must be rejected.
#[test]
fn parse_openai_response_missing_choices_is_err() {
    let client = OllamaApiClient::new("test", json!([])).with_openai_compat(true);
    let choiceless = json!({"id": "x", "object": "chat.completion"});
    assert!(client.parse_openai_response(&choiceless).is_err());
}
/// An OpenAI-style tool call without an `id` still yields a `ToolUse`, whose
/// id the client synthesises with a `call_` prefix.
#[test]
fn parse_openai_response_missing_id_synthesises_one() {
    let client = OllamaApiClient::new("test", json!([])).with_openai_compat(true);
    let id_less_call = json!({
        "type": "function",
        "function": {"name": "a", "arguments": "{}"}
    });
    let body = json!({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [id_less_call]
            }
        }]
    });
    let events = client.parse_openai_response(&body).unwrap();
    let first = &events[0];
    match first {
        AssistantEvent::ToolUse { id, .. } => {
            assert!(id.starts_with("call_"), "expected synthesised id, got {id}");
        }
        other => panic!("expected ToolUse, got {other:?}"),
    }
}
/// Builds a minimal function-tool schema entry named `name`, with a
/// `"desc for <name>"` description and an empty object parameter list.
/// Used as filler input for the `cap_tools` tests.
fn fake_tool(name: &str) -> Value {
    let description = format!("desc for {name}");
    let parameters = json!({"type": "object", "properties": {}, "required": []});
    json!({
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters
        }
    })
}
/// Serialises the tests below that mutate `CLAUDETTE_MAX_TOOLS`.
///
/// The libtest harness runs tests on several threads by default, and the
/// process environment is shared between them. Unsynchronised, one of these
/// tests can read the value another just wrote (e.g. "unset" observing "0" or
/// "not-a-number"). Tests outside this group that read the variable without
/// this lock are not protected.
static MAX_TOOLS_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Scoped override of `CLAUDETTE_MAX_TOOLS` for a single test.
///
/// Keeps [`MAX_TOOLS_ENV_LOCK`] held while alive. On drop it restores the
/// variable's previous value, or removes it if it was unset. Because this
/// happens in `Drop`, it also runs when an assertion fails and panics.
struct MaxToolsEnv {
    prev: Option<std::ffi::OsString>,
    _lock: std::sync::MutexGuard<'static, ()>,
}

impl MaxToolsEnv {
    /// Acquires the lock, remembers the current value, then sets the variable
    /// to `value` (`None` removes it).
    fn pin(value: Option<&str>) -> Self {
        // Recover from poisoning: the `()` payload cannot be left inconsistent.
        let lock = MAX_TOOLS_ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // `var_os` so a non-UTF-8 previous value is restored faithfully.
        let prev = std::env::var_os("CLAUDETTE_MAX_TOOLS");
        match value {
            Some(v) => std::env::set_var("CLAUDETTE_MAX_TOOLS", v),
            None => std::env::remove_var("CLAUDETTE_MAX_TOOLS"),
        }
        Self { prev, _lock: lock }
    }
}

impl Drop for MaxToolsEnv {
    fn drop(&mut self) {
        // Runs before `_lock` is released, so restoration is still serialised.
        match self.prev.take() {
            Some(v) => std::env::set_var("CLAUDETTE_MAX_TOOLS", v),
            None => std::env::remove_var("CLAUDETTE_MAX_TOOLS"),
        }
    }
}

/// With the variable absent there is no tool cap.
#[test]
fn resolve_max_tools_unset_returns_none() {
    let _env = MaxToolsEnv::pin(None);
    assert_eq!(resolve_max_tools(), None);
}

/// `CLAUDETTE_MAX_TOOLS=0` is treated as "no cap" rather than "zero tools".
#[test]
fn resolve_max_tools_zero_is_treated_as_no_cap() {
    let _env = MaxToolsEnv::pin(Some("0"));
    assert_eq!(resolve_max_tools(), None);
}

/// A value that does not parse as a number is ignored (no cap).
#[test]
fn resolve_max_tools_garbage_returns_none() {
    let _env = MaxToolsEnv::pin(Some("not-a-number"));
    assert_eq!(resolve_max_tools(), None);
}
/// A tool list already within the cap is returned unchanged.
#[test]
fn cap_tools_passthrough_when_under_cap() {
    let tools = Value::Array(["a", "b"].iter().copied().map(fake_tool).collect());
    let capped = cap_tools(tools.clone(), 5);
    assert_eq!(capped, tools);
}
/// Without an `enable_tools` entry, an over-long list is cut to the first
/// `cap` tools in their original order.
#[test]
fn cap_tools_truncates_when_over_cap() {
    let tools: Value = (0..10).map(|i| fake_tool(&format!("t{i}"))).collect();
    let capped = cap_tools(tools, 3);
    let arr = capped.as_array().unwrap();
    assert_eq!(arr.len(), 3);
    let (first, last) = (&arr[0], &arr[2]);
    assert_eq!(first["function"]["name"], "t0");
    assert_eq!(last["function"]["name"], "t2");
}
/// When `enable_tools` sits beyond the cap it is hoisted to slot 0. The
/// remaining slots are filled with the other tools in their original order.
#[test]
fn cap_tools_moves_enable_tools_to_front_when_present() {
    let names = ["a", "b", "enable_tools", "c", "d"];
    let tools = Value::Array(names.iter().copied().map(fake_tool).collect());
    let capped = cap_tools(tools, 3);
    let arr = capped.as_array().unwrap();
    assert_eq!(arr.len(), 3);
    for (slot, want) in ["enable_tools", "a", "b"].iter().enumerate() {
        assert_eq!(arr[slot]["function"]["name"], *want);
    }
}
/// If `enable_tools` is already first, truncation keeps it there, followed by
/// the next tool in order.
#[test]
fn cap_tools_keeps_enable_tools_at_front_when_already_first() {
    let names = ["enable_tools", "a", "b", "c"];
    let tools = Value::Array(names.iter().copied().map(fake_tool).collect());
    let capped = cap_tools(tools, 2);
    let arr = capped.as_array().unwrap();
    assert_eq!(arr.len(), 2);
    for (slot, want) in ["enable_tools", "a"].iter().enumerate() {
        assert_eq!(arr[slot]["function"]["name"], *want);
    }
}
/// With a cap of one, the single surviving slot goes to `enable_tools` even
/// though it was listed last.
#[test]
fn cap_tools_preserves_enable_tools_even_when_cap_is_one() {
    let names = ["a", "b", "enable_tools"];
    let tools = Value::Array(names.iter().copied().map(fake_tool).collect());
    let capped = cap_tools(tools, 1);
    let arr = capped.as_array().unwrap();
    assert_eq!(arr.len(), 1);
    let only = &arr[0];
    assert_eq!(only["function"]["name"], "enable_tools");
}
/// A tools value that is not a JSON array is returned as-is.
#[test]
fn cap_tools_passes_through_non_array() {
    let not_an_array = json!({"not": "an array"});
    let out = cap_tools(not_an_array.clone(), 5);
    assert_eq!(out, not_an_array);
}
/// Without an `enable_tools` entry, capping keeps the first `cap` tools in order.
#[test]
fn cap_tools_no_enable_tools_just_takes_first_n() {
    let names = ["a", "b", "c"];
    let tools = Value::Array(names.iter().copied().map(fake_tool).collect());
    let capped = cap_tools(tools, 2);
    let arr = capped.as_array().unwrap();
    assert_eq!(arr.len(), 2);
    for (slot, want) in ["a", "b"].iter().enumerate() {
        assert_eq!(arr[slot]["function"]["name"], *want);
    }
}
/// An assistant tool call followed by its tool result must become two
/// OpenAI messages:
/// - an assistant message with JSON-null `content` and a `tool_calls` entry
///   whose `arguments` is the raw input string;
/// - a `tool`-role message linked back via `tool_call_id`.
#[test]
fn openai_history_emits_separate_tool_messages_with_tool_call_id() {
    let call = ContentBlock::ToolUse {
        id: "call_a".into(),
        name: "note_list".into(),
        input: "{}".into(),
    };
    let result = ContentBlock::ToolResult {
        tool_use_id: "call_a".into(),
        tool_name: "note_list".into(),
        output: "no notes yet".into(),
        is_error: false,
    };
    let msgs = vec![
        ConversationMessage {
            role: MessageRole::Assistant,
            blocks: vec![call],
            usage: None,
        },
        ConversationMessage {
            role: MessageRole::Tool,
            blocks: vec![result],
            usage: None,
        },
    ];
    let out = build_history_messages_openai_compat(&msgs);
    assert_eq!(out.len(), 2);
    let (assistant, tool) = (&out[0], &out[1]);
    assert_eq!(assistant["role"], "assistant");
    assert!(
        assistant["content"].is_null(),
        "assistant content must be JSON null when tool_calls present, got {:?}",
        assistant["content"]
    );
    let first_call = &assistant["tool_calls"][0];
    assert_eq!(first_call["id"], "call_a");
    assert_eq!(first_call["function"]["arguments"], "{}");
    assert_eq!(tool["role"], "tool");
    assert_eq!(tool["tool_call_id"], "call_a");
    assert_eq!(tool["content"], "no notes yet");
}
/// An assistant turn holding both text and a tool call collapses into one
/// message that keeps the text as `content` and the call under `tool_calls`.
#[test]
fn openai_history_assistant_with_text_and_tool_calls_keeps_both() {
    let text = ContentBlock::Text {
        text: "Looking up your notes.".into(),
    };
    let call = ContentBlock::ToolUse {
        id: "x".into(),
        name: "note_list".into(),
        input: "{\"limit\":5}".into(),
    };
    let msgs = vec![ConversationMessage {
        role: MessageRole::Assistant,
        blocks: vec![text, call],
        usage: None,
    }];
    let out = build_history_messages_openai_compat(&msgs);
    assert_eq!(out.len(), 1);
    let message = &out[0];
    assert_eq!(message["content"], "Looking up your notes.");
    let function = &message["tool_calls"][0]["function"];
    assert_eq!(function["name"], "note_list");
    assert_eq!(function["arguments"], "{\"limit\":5}");
}
/// A plain user text message maps one-to-one, with no `tool_calls` key added.
#[test]
fn openai_history_plain_text_message_unchanged() {
    let greeting = ContentBlock::Text { text: "hey".into() };
    let msgs = vec![ConversationMessage {
        role: MessageRole::User,
        blocks: vec![greeting],
        usage: None,
    }];
    let out = build_history_messages_openai_compat(&msgs);
    assert_eq!(out.len(), 1);
    let message = &out[0];
    assert_eq!(message["role"], "user");
    assert_eq!(message["content"], "hey");
    assert!(message.get("tool_calls").is_none());
}
/// Several tool results packed into one `Tool` message are split into one
/// OpenAI message each, in order, with the matching `tool_call_id`.
#[test]
fn openai_history_multiple_tool_results_in_one_message_become_separate() {
    let first = ContentBlock::ToolResult {
        tool_use_id: "id1".into(),
        tool_name: "a".into(),
        output: "result one".into(),
        is_error: false,
    };
    let second = ContentBlock::ToolResult {
        tool_use_id: "id2".into(),
        tool_name: "b".into(),
        output: "result two".into(),
        is_error: false,
    };
    let msgs = vec![ConversationMessage {
        role: MessageRole::Tool,
        blocks: vec![first, second],
        usage: None,
    }];
    let out = build_history_messages_openai_compat(&msgs);
    assert_eq!(out.len(), 2);
    let expected = [("id1", "result one"), ("id2", "result two")];
    for (slot, (call_id, content)) in expected.iter().enumerate() {
        assert_eq!(out[slot]["tool_call_id"], *call_id);
        assert_eq!(out[slot]["content"], *content);
    }
}
/// End to end through `build_chat_body` in compat mode: the system prompt is
/// prepended and the tool round-trip uses the OpenAI history shape.
/// - The assistant message has null `content`.
/// - The tool result gets `role: "tool"` plus `tool_call_id`.
#[test]
fn build_chat_body_compat_uses_openai_history_shape() {
    let client = OllamaApiClient::new("openai/gpt-oss-20b", json!([])).with_openai_compat(true);
    let tool_use = ConversationMessage {
        role: MessageRole::Assistant,
        blocks: vec![ContentBlock::ToolUse {
            id: "c1".into(),
            name: "note_list".into(),
            input: "{}".into(),
        }],
        usage: None,
    };
    let tool_result = ConversationMessage {
        role: MessageRole::Tool,
        blocks: vec![ContentBlock::ToolResult {
            tool_use_id: "c1".into(),
            tool_name: "note_list".into(),
            output: "[]".into(),
            is_error: false,
        }],
        usage: None,
    };
    let req = ApiRequest {
        messages: vec![user_text("show notes"), tool_use, tool_result],
        system_prompt: vec!["sys".to_string()],
    };
    let body = client.build_chat_body(&req);
    let msgs = body["messages"].as_array().expect("messages array");
    assert_eq!(msgs.len(), 4);
    for (slot, role) in ["system", "user", "assistant"].iter().enumerate() {
        assert_eq!(msgs[slot]["role"], *role);
    }
    assert!(msgs[2]["content"].is_null());
    let tool_msg = &msgs[3];
    assert_eq!(tool_msg["role"], "tool");
    assert_eq!(tool_msg["tool_call_id"], "c1");
}
/// In compat mode the callback receives the whole completion text in one call,
/// followed by a trailing `"\n"`.
#[test]
fn parse_openai_response_callback_fires_with_full_text() {
    use std::sync::{Arc, Mutex};
    let log: Arc<Mutex<Vec<String>>> = Arc::default();
    let sink = Arc::clone(&log);
    let cb: TextCallback = Box::new(move |s: &str| sink.lock().unwrap().push(s.to_owned()));
    let client = OllamaApiClient::new("test", json!([]))
        .with_openai_compat(true)
        .with_text_callback(cb);
    let body = json!({
        "choices": [{
            "message": {"role": "assistant", "content": "foo bar"}
        }]
    });
    client.parse_openai_response(&body).unwrap();
    let entries = log.lock().unwrap();
    let expected: Vec<String> = ["foo bar", "\n"].iter().map(|s| s.to_string()).collect();
    assert_eq!(
        *entries,
        expected,
        "callback should fire full text + trailing newline (no per-token streaming yet)"
    );
}
/// A client built over a shared `ToolRegistry` must see groups enabled after
/// construction: enabling the Git group must shrink the history budget.
#[test]
fn dynamic_registry_budget_shrinks_when_group_is_enabled() {
    use crate::tool_groups::{ToolGroup, ToolRegistry};
    let request = ApiRequest {
        system_prompt: vec!["sys".to_string()],
        messages: Vec::new(),
    };
    let registry = Arc::new(Mutex::new(ToolRegistry::new()));
    let mut client = OllamaApiClient::with_registry("test", Arc::clone(&registry));
    client.num_ctx = 16384;
    client.num_predict = 1024;
    let before = client.history_budget_chars(&request);
    // The lock guard is a temporary, released at the end of this statement.
    registry.lock().unwrap().enable(ToolGroup::Git);
    let after = client.history_budget_chars(&request);
    assert!(
        after < before,
        "enabling a tool group must shrink the history budget (before={before}, after={after})"
    );
}
}