use std::io::{BufRead, BufReader};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::{
ApiClient, ApiRequest, AssistantEvent, ContentBlock, MessageRole, RuntimeError, TokenUsage,
};
use serde_json::{json, Value};
use crate::tool_groups::ToolRegistry;
/// Boxed, thread-safe (`Send + Sync`) closure that receives a fragment of
/// streamed assistant text as a `&str`.
pub type TextCallback = Box<dyn Fn(&str) + Send + Sync>;
/// Returns a [`TextCallback`] that echoes every fragment to standard output.
///
/// Each fragment is written and flushed immediately so partial output is
/// visible as it arrives. Write and flush errors are ignored (best effort).
#[must_use]
pub fn stdout_text_callback() -> TextCallback {
    use std::io::Write;

    let echo = |delta: &str| {
        // Take the stdout lock once per fragment for the write + flush pair.
        let mut handle = std::io::stdout().lock();
        let _ = handle.write_all(delta.as_bytes());
        let _ = handle.flush();
    };
    Box::new(echo)
}
/// Returns a [`TextCallback`] that forwards every fragment to the TUI as a
/// `TuiEvent::Token` over the given synchronous channel.
///
/// A failed send is ignored.
#[must_use]
pub fn tui_text_callback(
    tx: std::sync::mpsc::SyncSender<crate::tui_events::TuiEvent>,
) -> TextCallback {
    let forward = move |delta: &str| {
        let event = crate::tui_events::TuiEvent::Token(delta.to_owned());
        let _ = tx.send(event);
    };
    Box::new(forward)
}
// Process-wide text buffer, lazily created by `telegram_stream_buffer()`.
// `telegram_text_callback()` appends to it and `telegram_stream_reset()` clears it.
static TELEGRAM_STREAM_BUFFER: std::sync::OnceLock<Mutex<String>> = std::sync::OnceLock::new();
/// Returns the process-wide Telegram stream buffer, creating it (empty) on
/// first use.
#[must_use]
pub fn telegram_stream_buffer() -> &'static Mutex<String> {
    // `Mutex::<String>::default()` is an unlocked mutex around an empty string.
    TELEGRAM_STREAM_BUFFER.get_or_init(Mutex::default)
}
/// Empties the process-wide Telegram stream buffer.
///
/// A poisoned lock is recovered rather than skipped: the buffer is a plain
/// `String`, so clearing it is always safe, and skipping the reset would leave
/// stale text in place.
pub fn telegram_stream_reset() {
    telegram_stream_buffer()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .clear();
}
/// Returns a [`TextCallback`] that appends every fragment to the process-wide
/// Telegram stream buffer and also echoes it to standard output.
///
/// A poisoned buffer lock is recovered instead of silently dropping the
/// fragment. Stdout write and flush errors are ignored (best effort).
#[must_use]
pub fn telegram_text_callback() -> TextCallback {
    Box::new(|delta: &str| {
        use std::io::Write;
        // The guard is a temporary, so the buffer lock is released before the
        // stdout write below.
        telegram_stream_buffer()
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .push_str(delta);
        let mut out = std::io::stdout().lock();
        let _ = out.write_all(delta.as_bytes());
        let _ = out.flush();
    })
}
// Base URL returned by `resolve_ollama_url()` when `OLLAMA_HOST` is unset or empty.
const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
/// Default `num_ctx` value.
/// NOTE(review): not referenced in this part of the file; presumably the
/// fallback behind `model_config` — confirm.
pub const DEFAULT_NUM_CTX: u32 = 16384;
/// Default `num_predict` value.
/// NOTE(review): not referenced in this part of the file; presumably the
/// fallback behind `model_config` — confirm.
pub const DEFAULT_NUM_PREDICT: u32 = 6144;
/// Returns the `num_ctx` setting of the `brain` entry in the active model
/// configuration.
#[must_use]
pub fn current_num_ctx() -> u32 {
    let config = crate::model_config::active();
    config.brain.num_ctx
}
/// Returns the `num_predict` setting of the `brain` entry in the active model
/// configuration.
#[must_use]
pub fn current_num_predict() -> u32 {
    let config = crate::model_config::active();
    config.brain.num_predict
}
// Timeout, in seconds, applied to the blocking HTTP client built in `OllamaApiClient::build`.
const REQUEST_TIMEOUT_SECS: u64 = 300;
// Conversion factor used by the history budget: `num_ctx` and `num_predict`
// are multiplied by this to get a size in characters.
const CHARS_PER_TOKEN: usize = 4;
// Extra headroom subtracted from the history budget.
const SAFETY_CHARS: usize = 1024;
/// Source of the tool schema sent with each chat request.
pub enum ToolsProvider {
// A schema value that is cloned as-is on every request.
Fixed(Value),
// A shared registry whose `current_tools()` is queried on every request.
Dynamic(Arc<Mutex<ToolRegistry>>),
}
impl ToolsProvider {
    /// Returns the tool schema to send right now.
    ///
    /// `Fixed` yields a clone of the stored value. `Dynamic` asks the shared
    /// registry for its current tools; a poisoned registry lock is recovered
    /// rather than treated as an error.
    #[must_use]
    pub fn current(&self) -> Value {
        let registry = match self {
            Self::Fixed(schema) => return schema.clone(),
            Self::Dynamic(registry) => registry,
        };
        let guard = registry
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        guard.current_tools()
    }
}
/// Blocking client for an Ollama server's `/api/chat` endpoint.
pub struct OllamaApiClient {
// Blocking HTTP client; built with a `REQUEST_TIMEOUT_SECS` timeout.
http: reqwest::blocking::Client,
// Server base URL, taken from `resolve_ollama_url()` at construction.
base_url: String,
// Model name sent as `"model"` in every request body.
model: String,
// Where the `"tools"` schema of each request comes from.
tools: ToolsProvider,
// Sent as `options.num_ctx`; also the starting point of the history budget.
num_ctx: u32,
// Sent as `options.num_predict`; subtracted from the history budget.
num_predict: u32,
// Optional sink invoked with each non-empty streamed text fragment.
text_callback: Option<TextCallback>,
}
impl OllamaApiClient {
    /// Creates a client for `model` that sends the same fixed `tools` schema
    /// with every request.
    #[must_use]
    pub fn new(model: impl Into<String>, tools: Value) -> Self {
        let provider = ToolsProvider::Fixed(tools);
        Self::build(model.into(), provider)
    }

    /// Creates a client for `model` whose tool schema is read from the shared
    /// `registry` on every request.
    #[must_use]
    pub fn with_registry(model: impl Into<String>, registry: Arc<Mutex<ToolRegistry>>) -> Self {
        let provider = ToolsProvider::Dynamic(registry);
        Self::build(model.into(), provider)
    }

    /// Shared constructor: resolves the server URL, builds the HTTP client and
    /// seeds the context/prediction limits from the active model config.
    ///
    /// # Panics
    ///
    /// Panics if the blocking HTTP client cannot be built.
    fn build(model: String, tools: ToolsProvider) -> Self {
        let timeout = Duration::from_secs(REQUEST_TIMEOUT_SECS);
        let http = reqwest::blocking::Client::builder()
            .timeout(timeout)
            .build()
            .expect("failed to build reqwest blocking client");
        Self {
            http,
            base_url: resolve_ollama_url(),
            model,
            tools,
            num_ctx: current_num_ctx(),
            num_predict: current_num_predict(),
            text_callback: None,
        }
    }

    /// Overrides the `num_ctx` value sent with each request.
    #[must_use]
    pub fn with_context(self, num_ctx: u32) -> Self {
        Self { num_ctx, ..self }
    }

    /// Overrides the `num_predict` value sent with each request.
    #[must_use]
    pub fn with_max_predict(self, num_predict: u32) -> Self {
        Self {
            num_predict,
            ..self
        }
    }

    /// Installs a callback that receives streamed text fragments.
    #[must_use]
    pub fn with_text_callback(self, callback: TextCallback) -> Self {
        Self {
            text_callback: Some(callback),
            ..self
        }
    }

    /// Name of the model this client talks to.
    #[must_use]
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
/// Resolves the Ollama base URL from the `OLLAMA_HOST` environment variable.
///
/// Surrounding whitespace and trailing `/` characters are removed. If the
/// value already starts with `http://` or `https://` (compared
/// case-insensitively) it is returned as given; otherwise `http://` is
/// prepended. When the variable is unset, not valid Unicode, or empty after
/// trimming, `DEFAULT_OLLAMA_URL` is returned.
#[must_use]
pub fn resolve_ollama_url() -> String {
    let raw = std::env::var("OLLAMA_HOST").unwrap_or_default();
    let host = raw.trim().trim_end_matches('/');
    if host.is_empty() {
        return DEFAULT_OLLAMA_URL.to_string();
    }
    // URL schemes are case-insensitive, so `HTTP://host` must not get a second
    // scheme prepended. `get` yields `None` rather than panicking when the cut
    // would land inside a multi-byte character.
    let has_scheme = ["http://", "https://"].iter().any(|scheme| {
        host.get(..scheme.len())
            .is_some_and(|prefix| prefix.eq_ignore_ascii_case(scheme))
    });
    if has_scheme {
        host.to_string()
    } else {
        format!("http://{host}")
    }
}
/// Reports whether `url` points at a loopback host.
///
/// Recognised as local: `localhost` (any ASCII case), the IPv6 loopback
/// `[::1]`, and any dotted-quad address in `127.0.0.0/8`. An optional
/// `http://` / `https://` scheme (any case), userinfo (`user:pass@`), port,
/// path, query and fragment are ignored. Anything else, including input that
/// cannot be interpreted, is treated as remote.
///
/// This is a conservative textual check; it performs no DNS resolution.
#[must_use]
pub fn is_local_ollama_url(url: &str) -> bool {
    // Strip the scheme. `get` returns `None` instead of panicking when the cut
    // would land inside a multi-byte character (e.g. `http://€…` at byte 8).
    let mut rest = url;
    for scheme in ["https://", "http://"].iter() {
        if let Some(prefix) = url.get(..scheme.len()) {
            if prefix.eq_ignore_ascii_case(scheme) {
                rest = &url[scheme.len()..];
                break;
            }
        }
    }

    // The authority ends at the first `/`, `?` or `#`. Stopping only at `/`
    // would let `http://evil.com?@localhost` pass as local. `\` is included
    // because WHATWG-style URL parsers treat it like `/` for http(s).
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#' | '\\'))
        .next()
        .unwrap_or(rest);

    // Userinfo ends at the LAST `@`, so `localhost@evil.com` resolves to `evil.com`.
    let host_and_port = authority.rsplit('@').next().unwrap_or(authority);

    // A bracketed IPv6 literal keeps its colons; otherwise the port follows the first `:`.
    let host = match host_and_port.strip_prefix('[') {
        Some(bracketed) => bracketed.split(']').next().unwrap_or(bracketed),
        None => host_and_port.split(':').next().unwrap_or(host_and_port),
    };

    if host.eq_ignore_ascii_case("localhost") || host == "::1" {
        return true;
    }

    // 127.0.0.0/8: exactly three more octets, each plain decimal digits that
    // fit in a `u8`. The digit check rejects `+1`, which `parse` would accept.
    match host.strip_prefix("127.") {
        Some(tail) => {
            tail.split('.').count() == 3
                && tail.split('.').all(|octet| {
                    octet.bytes().all(|b| b.is_ascii_digit()) && octet.parse::<u8>().is_ok()
                })
        }
        None => false,
    }
}
/// Checks that an Ollama server answers at the resolved URL.
///
/// Before probing, a warning is printed to stderr when the URL is not a
/// loopback address, unless `CLAUDETTE_ALLOW_REMOTE_OLLAMA` is set to a
/// non-empty value other than `0`.
///
/// When `CLAUDETTE_SKIP_OLLAMA_PROBE` is set to a non-empty value other than
/// `0`, no request is made and the URL is returned immediately.
///
/// # Errors
///
/// Returns a human-readable message when the probe client cannot be built,
/// the server cannot be reached within 3 seconds, or it answers with a status
/// that is neither success nor redirection.
pub fn probe_ollama() -> Result<String, String> {
    let url = resolve_ollama_url();
    if !is_local_ollama_url(&url)
        && std::env::var("CLAUDETTE_ALLOW_REMOTE_OLLAMA")
            .ok()
            .map_or(true, |v| v.is_empty() || v == "0")
    {
        // The warning sign is written as a real U+26A0 character; it was
        // previously stored as mis-decoded bytes and printed as garbage.
        eprintln!(
            "⚠ OLLAMA_HOST points at a non-loopback address: {url}\n\
             Every prompt, tool call, and piece of memory/email/calendar\n\
             data will be sent to that host. Claudette's default posture\n\
             is local-only; a remote endpoint turns it into a cloud client.\n\
             If this is intentional, set CLAUDETTE_ALLOW_REMOTE_OLLAMA=1\n\
             to silence this warning."
        );
    }
    if std::env::var("CLAUDETTE_SKIP_OLLAMA_PROBE")
        .ok()
        .is_some_and(|v| !v.is_empty() && v != "0")
    {
        return Ok(url);
    }
    // Short timeout: this is only a liveness check, not a real request.
    let client = reqwest::blocking::Client::builder()
        .timeout(Duration::from_secs(3))
        .build()
        .map_err(|e| format!("could not build probe client: {e}"))?;
    match client.get(&url).send() {
        Ok(resp) if resp.status().is_success() || resp.status().is_redirection() => Ok(url),
        Ok(resp) => Err(format!(
            "Ollama at {url} returned HTTP {} — is a different service bound to that port?",
            resp.status()
        )),
        Err(e) => Err(format!(
            "Ollama not reachable at {url} ({e}). Start it with `ollama serve` \
             (or set OLLAMA_HOST), then retry. Set CLAUDETTE_SKIP_OLLAMA_PROBE=1 to bypass."
        )),
    }
}
impl ApiClient for OllamaApiClient {
    /// Sends `request` to `{base_url}/api/chat` and returns the assistant
    /// events collected from the streamed response.
    ///
    /// # Errors
    ///
    /// Returns a `RuntimeError` when the request cannot be sent, when the
    /// server answers with a non-success status (the message carries up to
    /// the first 400 characters of the body), or when reading or parsing the
    /// stream fails.
    fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
        let payload = self.build_chat_body(&request);
        let endpoint = format!("{}/api/chat", self.base_url);
        let response = match self.http.post(&endpoint).json(&payload).send() {
            Ok(response) => response,
            Err(e) => {
                return Err(RuntimeError::new(format!("Ollama request failed: {e}")));
            }
        };
        let status = response.status();
        if status.is_success() {
            return self.consume_stream_lines(BufReader::new(response));
        }
        // Keep the error short: only the head of the body is included.
        let body = response.text().unwrap_or_default();
        let snippet: String = body.chars().take(400).collect();
        Err(RuntimeError::new(format!("Ollama HTTP {status}: {snippet}")))
    }
}
impl OllamaApiClient {
    /// Builds the JSON body for a streaming `/api/chat` request.
    ///
    /// The tool schema is read once, so the history budget is computed against
    /// exactly the schema that is sent.
    fn build_chat_body(&self, request: &ApiRequest) -> Value {
        let tools = self.tools.current();
        let history_budget = self.history_budget_chars_for_tools(request, &tools);
        let messages = build_messages(request, history_budget);
        json!({
            "model": self.model,
            "messages": messages,
            "tools": tools,
            "stream": true,
            "think": false,
            "options": {
                "temperature": 0.0,
                "num_ctx": self.num_ctx,
                "num_predict": self.num_predict
            }
        })
    }

    /// Reads a newline-delimited JSON chat stream and folds it into events.
    ///
    /// Blank lines are skipped. Non-empty `message.content` fragments are
    /// concatenated and passed to the text callback as they arrive, followed
    /// by a single `"\n"` once the stream ends (only if any text was seen).
    /// Tool calls are accumulated across all chunks. Token counts are taken
    /// from chunks with `"done": true`.
    ///
    /// The returned events are, in order: one `TextDelta` (if any text), one
    /// `ToolUse` per tool call, one `Usage`, then `MessageStop`.
    ///
    /// # Errors
    ///
    /// Returns a `RuntimeError` when a line cannot be read or parsed as JSON,
    /// or when a chunk carries a string `"error"` field.
    fn consume_stream_lines<R: BufRead>(
        &self,
        reader: R,
    ) -> Result<Vec<AssistantEvent>, RuntimeError> {
        // Saturating u64 -> u32 conversion: an oversized counter must not wrap
        // around to a small number.
        let token_count = |chunk: &Value, key: &str| -> u32 {
            chunk
                .get(key)
                .and_then(Value::as_u64)
                .map_or(0, |n| u32::try_from(n).unwrap_or(u32::MAX))
        };

        let mut accumulated_text = String::new();
        let mut tool_calls: Vec<Value> = Vec::new();
        let mut input_tokens: u32 = 0;
        let mut output_tokens: u32 = 0;
        for line in reader.lines() {
            let line =
                line.map_err(|e| RuntimeError::new(format!("Ollama stream read failed: {e}")))?;
            if line.trim().is_empty() {
                continue;
            }
            let chunk: Value = serde_json::from_str(&line)
                .map_err(|e| RuntimeError::new(format!("Ollama stream parse failed: {e}")))?;
            if let Some(err) = chunk.get("error").and_then(Value::as_str) {
                return Err(RuntimeError::new(format!("Ollama error: {err}")));
            }
            if let Some(content) = chunk.pointer("/message/content").and_then(Value::as_str) {
                if !content.is_empty() {
                    accumulated_text.push_str(content);
                    if let Some(cb) = &self.text_callback {
                        cb(content);
                    }
                }
            }
            if let Some(arr) = chunk
                .pointer("/message/tool_calls")
                .and_then(Value::as_array)
            {
                // Append rather than replace: calls may be spread over several
                // chunks, and a later empty array must not erase earlier ones.
                tool_calls.extend(arr.iter().cloned());
            }
            if chunk.get("done").and_then(Value::as_bool) == Some(true) {
                input_tokens = token_count(&chunk, "prompt_eval_count");
                output_tokens = token_count(&chunk, "eval_count");
            }
        }

        let mut events = Vec::with_capacity(tool_calls.len() + 3);
        if !accumulated_text.is_empty() {
            // Terminate the streamed text so later output starts on a fresh line.
            if let Some(cb) = &self.text_callback {
                cb("\n");
            }
            events.push(AssistantEvent::TextDelta(accumulated_text));
        }
        for (idx, tc) in tool_calls.iter().enumerate() {
            let name = tc
                .pointer("/function/name")
                .and_then(Value::as_str)
                .unwrap_or("unknown")
                .to_string();
            let arguments = tc
                .pointer("/function/arguments")
                .cloned()
                .unwrap_or_else(|| json!({}));
            let input = serde_json::to_string(&arguments).unwrap_or_else(|_| "{}".to_string());
            // Calls without an id get a positional one.
            let id = tc
                .get("id")
                .and_then(Value::as_str)
                .map_or_else(|| format!("call_{idx}"), String::from);
            events.push(AssistantEvent::ToolUse { id, name, input });
        }
        events.push(AssistantEvent::Usage(TokenUsage {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        }));
        events.push(AssistantEvent::MessageStop);
        Ok(events)
    }

    /// Test-only convenience: history budget for the client's current tool schema.
    #[cfg(test)]
    fn history_budget_chars(&self, request: &ApiRequest) -> usize {
        self.history_budget_chars_for_tools(request, &self.tools.current())
    }

    /// Estimates how many characters of conversation history fit in the context.
    ///
    /// Starts from `num_ctx * CHARS_PER_TOKEN` and subtracts the reserved
    /// output (`num_predict * CHARS_PER_TOKEN`), the system prompt (each part
    /// plus 2 for the `"\n\n"` joiner), the serialized tool schema and
    /// `SAFETY_CHARS`. All arithmetic saturates, so the result bottoms out at 0.
    fn history_budget_chars_for_tools(&self, request: &ApiRequest, tools: &Value) -> usize {
        let total = (self.num_ctx as usize).saturating_mul(CHARS_PER_TOKEN);
        let output = (self.num_predict as usize).saturating_mul(CHARS_PER_TOKEN);
        let system: usize = request.system_prompt.iter().map(|s| s.len() + 2).sum();
        let tools_chars = tools.to_string().len();
        total
            .saturating_sub(output)
            .saturating_sub(system)
            .saturating_sub(tools_chars)
            .saturating_sub(SAFETY_CHARS)
    }
}
/// Produces the `messages` array for a chat request.
///
/// The conversation history is converted and trimmed to
/// `history_budget_chars`. If the system prompt parts, joined with blank
/// lines, are non-empty they are emitted first as a single `system` message.
fn build_messages(request: &ApiRequest, history_budget_chars: usize) -> Vec<Value> {
    let trimmed = truncate_to_budget(
        build_history_messages(&request.messages),
        history_budget_chars,
    );

    let system_text = request.system_prompt.join("\n\n");
    let system_entry = (!system_text.is_empty()).then(|| {
        json!({
            "role": "system",
            "content": system_text,
        })
    });

    let mut out = Vec::with_capacity(trimmed.len() + 1);
    out.extend(system_entry);
    out.extend(trimmed);
    out
}
fn build_history_messages(msgs: &[crate::ConversationMessage]) -> Vec<Value> {
let mut messages = Vec::with_capacity(msgs.len());
for msg in msgs {
let role = role_str(msg.role);
let mut content_parts: Vec<String> = Vec::new();
let mut tool_calls: Vec<Value> = Vec::new();
for block in &msg.blocks {
match block {
ContentBlock::Text { text } => {
content_parts.push(text.clone());
}
ContentBlock::ToolUse { id, name, input } => {
let arguments: Value =
serde_json::from_str(input).unwrap_or_else(|_| json!({}));
tool_calls.push(json!({
"id": id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
}
}));
}
ContentBlock::ToolResult { output, .. } => {
content_parts.push(output.clone());
}
}
}
let content = content_parts.join("\n");
let mut obj = json!({
"role": role,
"content": content,
});
if !tool_calls.is_empty() {
obj["tool_calls"] = json!(tool_calls);
}
messages.push(obj);
}
messages
}
/// Drops messages so that their estimated size fits in `budget_chars`.
///
/// Messages are considered from newest to oldest. The newest message is always
/// kept, even if it alone exceeds the budget. An older message is kept only if
/// it still fits in what remains; one that does not fit is skipped, and even
/// older, smaller messages may still be kept after it. The original order is
/// preserved in the result.
fn truncate_to_budget(messages: Vec<Value>, budget_chars: usize) -> Vec<Value> {
    let mut survivors: Vec<Value> = Vec::with_capacity(messages.len());
    let mut spent = 0usize;
    let mut newest_first = messages.into_iter().rev();

    // The newest message is exempt from the budget check.
    if let Some(newest) = newest_first.next() {
        spent = spent.saturating_add(estimate_message_chars(&newest));
        survivors.push(newest);
    }

    for older in newest_first {
        let projected = spent.saturating_add(estimate_message_chars(&older));
        if projected <= budget_chars {
            spent = projected;
            survivors.push(older);
        }
    }

    survivors.reverse();
    survivors
}
/// Estimates the size of one chat message: the byte length of its string
/// `content` plus the length of its serialized `tool_calls`, each counted as
/// 0 when absent.
fn estimate_message_chars(msg: &Value) -> usize {
    let text_len = match msg.get("content") {
        Some(Value::String(text)) => text.len(),
        _ => 0,
    };
    let calls_len = match msg.get("tool_calls") {
        Some(calls) => calls.to_string().len(),
        None => 0,
    };
    text_len + calls_len
}
/// Maps a [`MessageRole`] to the role string used in chat message objects.
fn role_str(role: MessageRole) -> &'static str {
    match role {
        MessageRole::System => "system",
        MessageRole::User => "user",
        MessageRole::Assistant => "assistant",
        MessageRole::Tool => "tool",
    }
}
// Unit tests for URL classification, the probe skip switch, history
// truncation/budgeting and stream parsing.
#[cfg(test)]
mod tests {
use super::*;
use crate::{ConversationMessage, MessageRole};
// Loopback hosts are recognised regardless of scheme, case, userinfo or port.
#[test]
fn is_local_ollama_url_recognises_loopback() {
for url in [
"http://localhost:11434",
"https://localhost:11434",
"http://LOCALHOST:11434",
"HTTP://localhost:11434",
"HTTPS://localhost:11434",
"http://user:pass@localhost:11434",
"http://127.0.0.1:11434",
"http://127.0.0.2:11434",
"http://127.255.255.255:11434",
"http://[::1]:11434",
"localhost:11434",
"127.0.0.1",
] {
assert!(
is_local_ollama_url(url),
"expected local, but {url} was flagged remote"
);
}
}
// Non-loopback hosts are rejected, including userinfo tricks that put
// "localhost" before the `@`.
#[test]
fn is_local_ollama_url_rejects_remote() {
for url in [
"http://ollama.example.com:11434",
"https://attacker.evil:11434",
"http://192.168.1.10:11434", "http://10.0.0.1:11434", "http://1.2.3.4:11434",
"http://[2001:db8::1]:11434",
"http://0.0.0.0:11434",
"http://[::]:11434",
"http://localhost:fakepass@evil.com:11434",
"http://localhost@evil.com:11434",
] {
assert!(
!is_local_ollama_url(url),
"expected remote, but {url} was flagged local"
);
}
}
// With CLAUDETTE_SKIP_OLLAMA_PROBE=1 the probe returns Ok without contacting
// the (unresolvable) host. The previous values of both environment variables
// are restored afterwards.
#[test]
fn probe_ollama_skip_env_short_circuits() {
let prev_host = std::env::var("OLLAMA_HOST").ok();
let prev_skip = std::env::var("CLAUDETTE_SKIP_OLLAMA_PROBE").ok();
std::env::set_var(
"OLLAMA_HOST",
"http://definitely-not-a-real-host.invalid:11434",
);
std::env::set_var("CLAUDETTE_SKIP_OLLAMA_PROBE", "1");
let result = probe_ollama();
assert!(
result.is_ok(),
"skip env should bypass the probe; got {result:?}"
);
match prev_host {
Some(v) => std::env::set_var("OLLAMA_HOST", v),
None => std::env::remove_var("OLLAMA_HOST"),
}
match prev_skip {
Some(v) => std::env::set_var("CLAUDETTE_SKIP_OLLAMA_PROBE", v),
None => std::env::remove_var("CLAUDETTE_SKIP_OLLAMA_PROBE"),
}
}
// Helper: a plain chat message object with the given role and content.
fn text_msg(role: &str, content: &str) -> Value {
json!({ "role": role, "content": content })
}
// Helper: a user conversation message holding a single text block.
fn user_text(text: &str) -> ConversationMessage {
ConversationMessage {
role: MessageRole::User,
blocks: vec![ContentBlock::Text {
text: text.to_string(),
}],
usage: None,
}
}
// A generous budget keeps every message, in order.
#[test]
fn truncate_keeps_everything_when_under_budget() {
let messages = vec![
text_msg("user", "hello"),
text_msg("assistant", "hi"),
text_msg("user", "how are you"),
];
let kept = truncate_to_budget(messages, 1000);
assert_eq!(kept.len(), 3);
assert_eq!(kept[0]["content"], "hello");
assert_eq!(kept[2]["content"], "how are you");
}
// Contents are 10, 9 and 10 bytes long; a budget of 25 fits only the two newest.
#[test]
fn truncate_drops_oldest_first() {
let messages = vec![
text_msg("user", "first-old0"), text_msg("assistant", "second-mi"), text_msg("user", "third-new0"), ];
let kept = truncate_to_budget(messages, 25);
assert_eq!(kept.len(), 2, "expected 2 kept, got {kept:?}");
assert_eq!(kept[0]["content"], "second-mi");
assert_eq!(kept[1]["content"], "third-new0");
}
// The newest message survives even a zero budget.
#[test]
fn truncate_zero_budget_still_keeps_newest() {
let messages = vec![text_msg("user", "anything")];
let kept = truncate_to_budget(messages, 0);
assert_eq!(kept.len(), 1);
assert_eq!(kept[0]["content"], "anything");
}
// Empty input yields empty output.
#[test]
fn truncate_empty_input_returns_empty() {
let kept = truncate_to_budget(Vec::new(), 1000);
assert!(kept.is_empty());
}
// A newest message larger than the budget is still kept.
#[test]
fn truncate_keeps_oversized_newest_alone() {
let messages = vec![text_msg("user", "way too long for the budget")];
let kept = truncate_to_budget(messages, 5);
assert_eq!(kept.len(), 1, "newest must always survive");
assert_eq!(kept[0]["content"], "way too long for the budget");
}
// An oversized middle message is skipped, while a smaller, older one that
// still fits is kept.
#[test]
fn truncate_skips_oversized_older_keeps_smaller_oldest() {
let messages = vec![
text_msg("user", "tiny old"), text_msg("assistant", &"X".repeat(500)), text_msg("user", "newest"), ];
let kept = truncate_to_budget(messages, 30);
assert_eq!(kept.len(), 2, "kept: {kept:?}");
assert_eq!(kept[0]["content"], "tiny old");
assert_eq!(kept[1]["content"], "newest");
}
// The estimate is the content length, plus something for serialized tool calls.
#[test]
fn estimate_message_chars_counts_content_and_tool_calls() {
let plain = text_msg("user", "hello"); assert_eq!(estimate_message_chars(&plain), 5);
let with_tools = json!({
"role": "assistant",
"content": "ok",
"tool_calls": [{ "id": "x", "type": "function", "function": { "name": "f", "arguments": {} }}],
});
let chars = estimate_message_chars(&with_tools);
assert!(chars > 2, "expected >2, got {chars}");
}
// Even with a zero history budget, the system prompt and the newest message are sent.
#[test]
fn build_messages_always_keeps_system_prompt_and_newest() {
let request = ApiRequest {
messages: vec![user_text("this is the only thing the user said")],
system_prompt: vec!["you are an assistant".to_string()],
};
let result = build_messages(&request, 0);
assert_eq!(result.len(), 2, "expected system + newest, got {result:?}");
assert_eq!(result[0]["role"], "system");
assert_eq!(result[0]["content"], "you are an assistant");
assert_eq!(result[1]["content"], "this is the only thing the user said");
}
// A small budget leaves only the system prompt and the newest history message.
#[test]
fn build_messages_truncates_history_under_budget() {
let request = ApiRequest {
messages: vec![
user_text("ancient turn that should fall off"),
user_text("middle turn that should also fall off"),
user_text("newest"),
],
system_prompt: vec!["sys".to_string()],
};
let result = build_messages(&request, 20);
assert_eq!(
result.len(),
2,
"expected system + 1 history, got {result:?}"
);
assert_eq!(result[0]["role"], "system");
assert_eq!(result[1]["content"], "newest");
}
// A larger system prompt leaves a smaller history budget, by roughly the
// difference in prompt length.
#[test]
fn history_budget_shrinks_with_larger_system_prompt() {
let mut client = OllamaApiClient::new("test", json!([]));
client.num_ctx = 1000; client.num_predict = 100;
let small_sys = ApiRequest {
messages: Vec::new(),
system_prompt: vec!["short".to_string()],
};
let big_sys = ApiRequest {
messages: Vec::new(),
system_prompt: vec!["x".repeat(500)],
};
let small_budget = client.history_budget_chars(&small_sys);
let big_budget = client.history_budget_chars(&big_sys);
assert!(
small_budget > big_budget,
"smaller system prompt should leave more room for history"
);
assert!(big_budget + 500 <= small_budget + 10);
}
use std::io::Cursor;
// Helper: an in-memory reader over the given lines joined with "\n".
fn fake_stream(lines: &[&str]) -> Cursor<Vec<u8>> {
Cursor::new(lines.join("\n").into_bytes())
}
// One text chunk plus a done chunk yields TextDelta, Usage, MessageStop.
#[test]
fn stream_text_only_single_chunk() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[
r#"{"message":{"role":"assistant","content":"Hello"},"done":false}"#,
r#"{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":10,"eval_count":3}"#,
]);
let events = client.consume_stream_lines(stream).unwrap();
assert_eq!(events.len(), 3);
match &events[0] {
AssistantEvent::TextDelta(t) => assert_eq!(t, "Hello"),
other => panic!("expected TextDelta, got {other:?}"),
}
match &events[1] {
AssistantEvent::Usage(u) => {
assert_eq!(u.input_tokens, 10);
assert_eq!(u.output_tokens, 3);
}
other => panic!("expected Usage, got {other:?}"),
}
assert!(matches!(events[2], AssistantEvent::MessageStop));
}
// Text fragments from several chunks are concatenated into one TextDelta.
#[test]
fn stream_text_accumulates_multiple_deltas() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[
r#"{"message":{"role":"assistant","content":"Hel"},"done":false}"#,
r#"{"message":{"role":"assistant","content":"lo, "},"done":false}"#,
r#"{"message":{"role":"assistant","content":"world"},"done":false}"#,
r#"{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":5,"eval_count":7}"#,
]);
let events = client.consume_stream_lines(stream).unwrap();
match &events[0] {
AssistantEvent::TextDelta(t) => assert_eq!(t, "Hello, world"),
other => panic!("expected TextDelta, got {other:?}"),
}
}
// A tool call carried on the done chunk becomes a ToolUse event with its id.
#[test]
fn stream_tool_call_on_done_chunk() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[
r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_time","arguments":{}}}]},"done":true,"prompt_eval_count":20,"eval_count":2}"#,
]);
let events = client.consume_stream_lines(stream).unwrap();
assert_eq!(events.len(), 3);
match &events[0] {
AssistantEvent::ToolUse { name, id, .. } => {
assert_eq!(name, "get_time");
assert_eq!(id, "call_1");
}
other => panic!("expected ToolUse, got {other:?}"),
}
}
// Text followed by a tool call yields TextDelta first, then ToolUse.
#[test]
fn stream_text_then_tool_call() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[
r#"{"message":{"role":"assistant","content":"Let me check"},"done":false}"#,
r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"x","type":"function","function":{"name":"get_time","arguments":{}}}]},"done":true,"prompt_eval_count":15,"eval_count":4}"#,
]);
let events = client.consume_stream_lines(stream).unwrap();
assert_eq!(events.len(), 4);
assert!(matches!(&events[0], AssistantEvent::TextDelta(t) if t == "Let me check"));
assert!(matches!(&events[1], AssistantEvent::ToolUse { name, .. } if name == "get_time"));
}
// A chunk with an "error" string aborts the stream with that message.
#[test]
fn stream_error_chunk_returns_error() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[r#"{"error":"model not found"}"#]);
let result = client.consume_stream_lines(stream);
assert!(result.is_err());
let err = format!("{:?}", result.unwrap_err());
assert!(err.contains("model not found"), "got: {err}");
}
// A tool call without an id gets a synthesised "call_…" id.
#[test]
fn stream_missing_id_synthesises_one() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[
r#"{"message":{"role":"assistant","content":"","tool_calls":[{"type":"function","function":{"name":"a","arguments":{}}}]},"done":true,"prompt_eval_count":0,"eval_count":0}"#,
]);
let events = client.consume_stream_lines(stream).unwrap();
match &events[0] {
AssistantEvent::ToolUse { id, .. } => {
assert!(id.starts_with("call_"), "expected synthesised id, got {id}");
}
other => panic!("expected ToolUse, got {other:?}"),
}
}
// An empty stream still yields a zeroed Usage and MessageStop.
#[test]
fn stream_empty_returns_only_usage_and_stop() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[]);
let events = client.consume_stream_lines(stream).unwrap();
assert_eq!(events.len(), 2);
match &events[0] {
AssistantEvent::Usage(u) => {
assert_eq!(u.input_tokens, 0);
assert_eq!(u.output_tokens, 0);
}
other => panic!("expected Usage, got {other:?}"),
}
assert!(matches!(events[1], AssistantEvent::MessageStop));
}
// The callback sees every non-empty fragment and then one trailing "\n".
#[test]
fn stream_callback_fires_per_delta_and_trailing_newline() {
use std::sync::{Arc, Mutex};
let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let log_clone = log.clone();
let cb: TextCallback = Box::new(move |s: &str| {
log_clone.lock().unwrap().push(s.to_string());
});
let client = OllamaApiClient::new("test", json!([])).with_text_callback(cb);
let stream = fake_stream(&[
r#"{"message":{"role":"assistant","content":"foo"},"done":false}"#,
r#"{"message":{"role":"assistant","content":"bar"},"done":true,"prompt_eval_count":1,"eval_count":1}"#,
]);
let _ = client.consume_stream_lines(stream).unwrap();
let entries = log.lock().unwrap();
assert_eq!(
*entries,
vec!["foo".to_string(), "bar".to_string(), "\n".to_string()],
"callback should fire foo, bar, then trailing \\n"
);
}
// With no text at all (only a tool call) the callback is never invoked.
#[test]
fn stream_callback_no_trailing_newline_when_only_tool_call() {
use std::sync::{Arc, Mutex};
let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let log_clone = log.clone();
let cb: TextCallback = Box::new(move |s: &str| {
log_clone.lock().unwrap().push(s.to_string());
});
let client = OllamaApiClient::new("test", json!([])).with_text_callback(cb);
let stream = fake_stream(&[
r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"x","type":"function","function":{"name":"a","arguments":{}}}]},"done":true,"prompt_eval_count":0,"eval_count":0}"#,
]);
let _ = client.consume_stream_lines(stream).unwrap();
let entries = log.lock().unwrap();
assert!(
entries.is_empty(),
"no callbacks expected when content is empty (only a tool call), got {entries:?}"
);
}
// Blank lines between chunks are ignored.
#[test]
fn stream_skips_blank_lines() {
let client = OllamaApiClient::new("test", json!([]));
let stream = fake_stream(&[
"",
r#"{"message":{"role":"assistant","content":"hi"},"done":false}"#,
"",
r#"{"message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":1,"eval_count":1}"#,
"",
]);
let events = client.consume_stream_lines(stream).unwrap();
assert!(matches!(&events[0], AssistantEvent::TextDelta(t) if t == "hi"));
}
// The history budget shrinks by about the serialized size of the tool schema.
#[test]
fn history_budget_subtracts_tools_schema() {
let request = ApiRequest {
messages: Vec::new(),
system_prompt: vec!["sys".to_string()],
};
let mut empty_tools = OllamaApiClient::new("test", json!([]));
empty_tools.num_ctx = 16384;
empty_tools.num_predict = 1024;
let mut full_tools = OllamaApiClient::new("test", crate::secretary_tools_json());
full_tools.num_ctx = 16384;
full_tools.num_predict = 1024;
let empty_budget = empty_tools.history_budget_chars(&request);
let full_budget = full_tools.history_budget_chars(&request);
let tools_chars = crate::secretary_tools_json().to_string().len();
assert!(
full_budget < empty_budget,
"tool registry must shrink the history budget"
);
let delta = empty_budget - full_budget;
assert!(
delta + 4 >= tools_chars && delta <= tools_chars + 4,
"delta {delta} should approximately equal tools_chars {tools_chars}"
);
}
// With a shared registry, enabling a tool group after the client is built
// shrinks the budget on the next call.
#[test]
fn dynamic_registry_budget_shrinks_when_group_is_enabled() {
use crate::tool_groups::{ToolGroup, ToolRegistry};
let registry = Arc::new(Mutex::new(ToolRegistry::new()));
let mut client = OllamaApiClient::with_registry("test", registry.clone());
client.num_ctx = 16384;
client.num_predict = 1024;
let request = ApiRequest {
messages: Vec::new(),
system_prompt: vec!["sys".to_string()],
};
let before = client.history_budget_chars(&request);
registry.lock().unwrap().enable(ToolGroup::Git);
let after = client.history_budget_chars(&request);
assert!(
after < before,
"enabling a tool group must shrink the history budget (before={before}, after={after})"
);
}
}