use anyhow::{Context, Result, anyhow};
use serde_json::{Value, json};
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Base URL of a local Ollama server on port 11434; used when no other base URL is configured.
pub(crate) const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
/// Drives a future to completion from synchronous code, whether or not the caller is already
/// inside a Tokio runtime.
///
/// `make_fut` is called exactly once to create the future:
/// - Inside a multi-threaded runtime, `block_in_place` tells Tokio this worker is about to
///   block, and the future runs on the existing runtime via `handle.block_on`.
/// - Inside any other runtime flavor (e.g. current-thread), `block_in_place` is not usable, so
///   a scoped helper thread builds a throwaway current-thread runtime and runs the future
///   there while this thread waits on `join`.
/// - Outside any runtime, a throwaway current-thread runtime is built on this thread.
///
/// `F: Send` and `T: Send` are required because the closure and its output may cross to the
/// helper thread. The future itself is created on the thread that polls it, so it need not be
/// `Send`.
///
/// # Panics
/// Panics if a throwaway runtime cannot be built, or if the helper thread panics (for example
/// because the future panicked).
fn block_on_local<F, Fut, T>(make_fut: F) -> T
where
F: FnOnce() -> Fut + Send,
Fut: std::future::Future<Output = T>,
T: Send,
{
if let Ok(handle) = tokio::runtime::Handle::try_current() {
match handle.runtime_flavor() {
tokio::runtime::RuntimeFlavor::MultiThread => {
tokio::task::block_in_place(|| handle.block_on(make_fut()))
}
_ => {
// block_in_place panics on a current-thread runtime, so hop to a helper thread with
// its own runtime; the scope guarantees the thread is joined before returning.
std::thread::scope(|s| {
s.spawn(move || {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("ephemeral runtime builds")
.block_on(make_fut())
})
.join()
.expect(
"block_on_local current-thread bridge thread panicked; \
underlying future panicked",
)
})
}
}
} else {
// No ambient runtime: a private one on this thread is enough.
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("ephemeral runtime builds")
.block_on(make_fut())
}
}
/// Backend identifier that selects the native Ollama HTTP API.
pub const BACKEND_OLLAMA: &str = "ollama";
/// Path of the embeddings endpoint, appended to an OpenAI-compatible base URL.
pub const OPENAI_COMPAT_EMBEDDINGS_PATH: &str = "/embeddings";
/// Returns the default API base URL for a known OpenAI-compatible vendor alias.
///
/// Some aliases share an endpoint (`kimi`/`moonshot`, `qwen`/`dashscope`). Any name not listed
/// here, including `ollama` and `openai-compatible`, yields `None`.
pub(crate) fn default_base_url_for_alias(alias: &str) -> Option<&'static str> {
    // (aliases, base URL); the first row whose alias list contains `alias` wins.
    const BASE_URLS: &[(&[&str], &str)] = &[
        (&["openai"], "https://api.openai.com/v1"),
        (&["xai"], "https://api.x.ai/v1"),
        (&["anthropic"], "https://api.anthropic.com/v1"),
        (
            &["gemini"],
            "https://generativelanguage.googleapis.com/v1beta/openai",
        ),
        (&["deepseek"], "https://api.deepseek.com/v1"),
        (&["kimi", "moonshot"], "https://api.moonshot.cn/v1"),
        (
            &["qwen", "dashscope"],
            "https://dashscope.aliyuncs.com/compatible-mode/v1",
        ),
        (&["mistral"], "https://api.mistral.ai/v1"),
        (&["groq"], "https://api.groq.com/openai/v1"),
        (&["together"], "https://api.together.xyz/v1"),
        (&["cerebras"], "https://api.cerebras.ai/v1"),
        (&["openrouter"], "https://openrouter.ai/api/v1"),
        (&["fireworks"], "https://api.fireworks.ai/inference/v1"),
        (&["lmstudio"], "http://localhost:1234/v1"),
    ];
    BASE_URLS
        .iter()
        .find(|(aliases, _)| aliases.iter().any(|a| *a == alias))
        .map(|(_, url)| *url)
}
/// Builds the URL of Ollama's model-listing endpoint (`/api/tags`) under `base_url`.
///
/// `base_url` is used verbatim, so callers are expected to have stripped any trailing `/`.
pub(crate) fn ollama_tags_url(base_url: &str) -> String {
    const TAGS_PATH: &str = "/api/tags";
    let mut url = String::with_capacity(base_url.len() + TAGS_PATH.len());
    url.push_str(base_url);
    url.push_str(TAGS_PATH);
    url
}
/// Lists the vendor-specific environment variables that may hold the API key for `alias`,
/// in lookup order.
///
/// Unknown aliases, and aliases that need no key (such as `lmstudio`), yield an empty slice.
fn alias_api_key_env_vars(alias: &str) -> &'static [&'static str] {
    // (aliases, key variables); the first row whose alias list contains `alias` wins.
    const KEY_VARS: &[(&[&str], &[&str])] = &[
        (&["openai"], &["OPENAI_API_KEY"]),
        (&["xai"], &["XAI_API_KEY"]),
        (&["anthropic"], &["ANTHROPIC_API_KEY"]),
        (&["gemini"], &["GEMINI_API_KEY", "GOOGLE_API_KEY"]),
        (&["deepseek"], &["DEEPSEEK_API_KEY"]),
        (&["kimi", "moonshot"], &["MOONSHOT_API_KEY", "KIMI_API_KEY"]),
        (&["qwen", "dashscope"], &["DASHSCOPE_API_KEY", "QWEN_API_KEY"]),
        (&["mistral"], &["MISTRAL_API_KEY"]),
        (&["groq"], &["GROQ_API_KEY"]),
        (&["together"], &["TOGETHER_API_KEY"]),
        (&["cerebras"], &["CEREBRAS_API_KEY"]),
        (&["openrouter"], &["OPENROUTER_API_KEY"]),
        (&["fireworks"], &["FIREWORKS_API_KEY"]),
    ];
    KEY_VARS
        .iter()
        .find(|(aliases, _)| aliases.iter().any(|a| *a == alias))
        .map(|(_, vars)| *vars)
        .unwrap_or(&[])
}
/// Wire protocol an [`OllamaClient`] speaks, plus any credentials it needs.
///
/// `Debug` is implemented by hand to redact the API key, and `Drop` zeroizes it. Cloning
/// copies the key; each copy is zeroized when it is dropped.
#[derive(Clone)]
pub enum LlmProvider {
/// Native Ollama HTTP API; requests carry no authentication.
Ollama,
/// OpenAI-compatible HTTP API; `api_key` is sent as a bearer token.
OpenAiCompatible { api_key: String },
}
impl std::fmt::Debug for LlmProvider {
    /// Prints the variant name and masks the API key, so secrets never reach logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let LlmProvider::OpenAiCompatible { .. } = self else {
            return f.debug_struct("Ollama").finish();
        };
        let mut out = f.debug_struct("OpenAiCompatible");
        out.field("api_key", &"<redacted>");
        out.finish()
    }
}
impl LlmProvider {
    /// Scrubs any secret held by this provider in place, using the `zeroize` crate.
    ///
    /// Does nothing for [`LlmProvider::Ollama`], which holds no credentials. It runs
    /// automatically on drop and may also be called earlier.
    pub fn zeroize_secrets(&mut self) {
        use zeroize::Zeroize;
        match self {
            LlmProvider::OpenAiCompatible { api_key } => api_key.zeroize(),
            LlmProvider::Ollama => {}
        }
    }
}
impl Drop for LlmProvider {
fn drop(&mut self) {
self.zeroize_secrets();
}
}
/// Whole-request timeout for chat, generate and embedding calls; also the main client's default.
const GENERATE_TIMEOUT: Duration = Duration::from_secs(30);
/// Whole-request timeout for an Ollama `/api/pull` model download.
const PULL_TIMEOUT: Duration = Duration::from_secs(120);
/// Connect timeout for each client's main `reqwest::Client`.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Timeout for the reachability probe in `is_available_async`.
const HEALTH_TIMEOUT: Duration = Duration::from_secs(5);
/// How long the circuit breaker stays open after the most recent failure.
const CIRCUIT_BREAKER_COOLDOWN: Duration = Duration::from_secs(30);
/// Consecutive failures required before the circuit breaker opens.
const CIRCUIT_BREAKER_THRESHOLD: u32 = 3;
/// Maximum number of inputs sent in one batched embeddings request.
const EMBED_BATCH_MAX_INPUTS: usize = 100;
/// Soft cap on the summed byte length of inputs per batched embeddings request. A single
/// oversized input is still sent, alone.
const EMBED_BATCH_MAX_BYTES: usize = 256 * 1024;
/// Upper bound (16 MiB) on any LLM response body read into memory.
const MAX_LLM_RESPONSE_BYTES: usize = 16 * 1024 * 1024;
/// Reads a whole response body into memory, refusing anything larger than
/// [`MAX_LLM_RESPONSE_BYTES`].
async fn read_capped_bytes(resp: reqwest::Response) -> Result<Vec<u8>> {
    let limit = MAX_LLM_RESPONSE_BYTES;
    read_capped_bytes_inner(resp, limit).await
}
/// Streams `resp`'s body into memory and fails as soon as it would exceed `cap` bytes.
///
/// A declared `Content-Length` above `cap` is rejected before any chunk is read. The running
/// total is also checked per chunk, so a server that omits or understates the header still
/// cannot push more than `cap` bytes into memory.
///
/// # Errors
/// Returns an error if the body is, or would become, larger than `cap`, or if reading a chunk
/// fails.
async fn read_capped_bytes_inner(mut resp: reqwest::Response, cap: usize) -> Result<Vec<u8>> {
    let declared = resp.content_length();
    if let Some(len) = declared.filter(|&len| len > cap as u64) {
        return Err(anyhow!(
            "LLM response too large: Content-Length {len} exceeds cap of {cap} bytes"
        ));
    }
    let mut body: Vec<u8> = Vec::new();
    loop {
        let next = resp
            .chunk()
            .await
            .context("Failed to read LLM response chunk")?;
        let Some(piece) = next else {
            break;
        };
        let total = body.len().saturating_add(piece.len());
        if total > cap {
            return Err(anyhow!(
                "LLM response exceeded cap of {cap} bytes while streaming"
            ));
        }
        body.extend_from_slice(&piece);
    }
    Ok(body)
}
/// Reads a size-capped response body and parses it as JSON.
///
/// # Errors
/// Propagates read and size-cap errors, and reports a body that is not valid JSON.
async fn read_capped_json(resp: reqwest::Response) -> Result<Value> {
    let raw = read_capped_bytes(resp).await?;
    let parsed = serde_json::from_slice::<Value>(&raw);
    parsed.context("Failed to parse LLM response body as JSON")
}
/// Best-effort read of an (error) response body as text, for embedding in error messages.
///
/// Never fails. Invalid UTF-8 is replaced lossily, and a read failure (including exceeding the
/// size cap) becomes a placeholder string describing that failure.
async fn read_capped_text(resp: reqwest::Response) -> String {
    read_capped_bytes(resp).await.map_or_else(
        |e| format!("<error body unavailable: {e}>"),
        |bytes| String::from_utf8_lossy(&bytes).into_owned(),
    )
}
fn parse_openai_embeddings_batch(body: &Value, expected_len: usize) -> Result<Vec<Vec<f32>>> {
let data = body["data"]
.as_array()
.ok_or_else(|| anyhow!("Missing 'data' array in OpenAI-compatible embed response"))?;
if data.len() != expected_len {
return Err(anyhow!(
"Embed response carried {} vector(s) for {expected_len} input(s)",
data.len()
));
}
let mut out: Vec<Option<Vec<f32>>> = vec![None; expected_len];
for (pos, item) in data.iter().enumerate() {
let idx = match item["index"].as_u64() {
Some(i) => usize::try_from(i)
.map_err(|_| anyhow!("Embed response 'index' {i} does not fit usize"))?,
None => pos,
};
if idx >= expected_len {
return Err(anyhow!(
"Embed response 'index' {idx} out of range for {expected_len} input(s)"
));
}
if out[idx].is_some() {
return Err(anyhow!("Embed response carried duplicate 'index' {idx}"));
}
let arr = item["embedding"].as_array().ok_or_else(|| {
anyhow!("Missing 'data[{pos}].embedding' in OpenAI-compatible embed response")
})?;
#[allow(clippy::cast_possible_truncation)]
let floats: Vec<f32> = arr
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
if floats.is_empty() {
return Err(anyhow!("Empty embedding at index {idx} in embed response"));
}
out[idx] = Some(floats);
}
Ok(out.into_iter().flatten().collect())
}
/// Prompt asking for 5-8 related search terms, one per line; `{query}` is replaced with the
/// user's query.
const QUERY_EXPANSION_PROMPT: &str = r"You are a search query expander. Given a search query, generate 5-8 additional search terms that are semantically related. Return ONLY the terms, one per line, no numbering or explanation.
Query: {query}";
/// Prompt asking for a single-paragraph summary; `{memories}` is replaced with the formatted
/// memories.
const SUMMARIZE_PROMPT: &str = r"Summarize the following memories into a single concise paragraph. Preserve all key facts, decisions, and technical details.
{memories}";
/// Prompt asking for 3-5 lowercase tags, one per line; `{title}` and `{content}` are replaced
/// with the memory's title and body.
const AUTO_TAG_PROMPT: &str = r"Generate 3-5 short tags for categorizing this memory. Return ONLY the tags, one per line, lowercase, no symbols.
Title: {title}
Content: {content}";
/// Yes/no contradiction check; `{a}` and `{b}` are replaced with the two statements. The
/// caller treats an answer that starts with "yes" as a contradiction.
const CONTRADICTION_PROMPT: &str = r#"Do these two statements contradict each other? Answer ONLY "yes" or "no".
Statement A: {a}
Statement B: {b}"#;
/// Mutable state behind each client's circuit breaker.
#[derive(Debug)]
struct BreakerState {
/// Failures recorded since the last success (saturating at `u32::MAX`).
consecutive_failures: u32,
/// Time of the most recent failure; `None` before any failure and after a success.
last_failure_at: Option<Instant>,
}
impl BreakerState {
    /// A closed breaker with no recorded failures.
    const fn new() -> Self {
        Self {
            last_failure_at: None,
            consecutive_failures: 0,
        }
    }

    /// Reports whether requests should currently be short-circuited.
    ///
    /// The breaker is open after at least [`CIRCUIT_BREAKER_THRESHOLD`] consecutive failures,
    /// the latest of which happened less than [`CIRCUIT_BREAKER_COOLDOWN`] ago. Once the
    /// cooldown elapses it reports closed, so a trial request can go through. Because the
    /// count stays at or above the threshold, one more failure re-opens it immediately.
    fn is_open(&self) -> bool {
        let tripped = self.consecutive_failures >= CIRCUIT_BREAKER_THRESHOLD;
        tripped
            && self
                .last_failure_at
                .is_some_and(|at| at.elapsed() < CIRCUIT_BREAKER_COOLDOWN)
    }

    /// Counts one more failure and timestamps it.
    fn record_failure(&mut self) {
        self.last_failure_at = Some(Instant::now());
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
    }

    /// Returns the breaker to its closed, failure-free state.
    fn record_success(&mut self) {
        *self = Self::new();
    }
}
/// HTTP client for an LLM backend: a native Ollama server or an OpenAI-compatible endpoint.
///
/// It provides text generation, query expansion, summarization, tagging, contradiction checks
/// and embeddings, each with a blocking wrapper and an `_async` variant. Generation and
/// embedding requests pass through a circuit breaker and a governance check before any
/// network I/O.
pub struct OllamaClient {
/// Protocol and credentials.
provider: LlmProvider,
/// Endpoint root; constructors strip any trailing `/` so paths can be appended directly.
base_url: String,
/// Default model for chat/generation requests.
model: String,
/// Shared HTTP client; individual requests may override its timeout.
client: reqwest::Client,
/// Circuit-breaker state. A poisoned lock is treated as a closed breaker, and outcomes
/// recorded while it is poisoned are dropped.
breaker: Mutex<BreakerState>,
/// Optional `dimensions` value sent with OpenAI-compatible embedding requests.
embed_dimensions: Option<u32>,
}
impl OllamaClient {
#[must_use]
pub fn model_name(&self) -> &str {
&self.model
}
#[allow(dead_code)]
pub fn new(model: &str) -> Result<Self> {
Self::new_with_url(DEFAULT_OLLAMA_URL, model)
}
#[cfg(test)]
pub fn new_for_testing(model: &str) -> Self {
Self {
provider: LlmProvider::Ollama,
base_url: DEFAULT_OLLAMA_URL.trim_end_matches('/').to_string(),
model: model.to_string(),
client: reqwest::Client::builder()
.timeout(GENERATE_TIMEOUT)
.connect_timeout(CONNECT_TIMEOUT)
.build()
.expect("test reqwest client builds"),
breaker: Mutex::new(BreakerState::new()),
embed_dimensions: None,
}
}
#[allow(clippy::too_many_lines)]
pub fn from_env() -> Result<Option<Self>> {
let backend = std::env::var("AI_MEMORY_LLM_BACKEND")
.ok()
.map(|s| s.trim().to_ascii_lowercase())
.unwrap_or_else(|| BACKEND_OLLAMA.to_string());
let model = std::env::var("AI_MEMORY_LLM_MODEL")
.ok()
.filter(|s| !s.trim().is_empty())
.unwrap_or_else(|| match backend.as_str() {
"xai" => "grok-4.3".to_string(),
"openai" => "gpt-5".to_string(),
"anthropic" => "claude-opus-4.7".to_string(),
"gemini" => "gemini-2.0-flash".to_string(),
"deepseek" => "deepseek-chat".to_string(),
"kimi" | "moonshot" => "moonshot-v1-8k".to_string(),
"qwen" | "dashscope" => "qwen-max".to_string(),
"mistral" => "mistral-large-latest".to_string(),
"groq" => "llama-3.3-70b-versatile".to_string(),
"together" => "meta-llama/Llama-3.3-70B-Instruct-Turbo".to_string(),
"cerebras" => "llama-3.3-70b".to_string(),
"openrouter" => "openai/gpt-5".to_string(),
"fireworks" => "accounts/fireworks/models/llama-v3p3-70b-instruct".to_string(),
"lmstudio" => "local-model".to_string(),
_ => "gemma3:4b".to_string(),
});
match backend.as_str() {
BACKEND_OLLAMA => {
let base_url = std::env::var("AI_MEMORY_LLM_BASE_URL")
.ok()
.or_else(|| std::env::var("OLLAMA_BASE_URL").ok())
.filter(|s| !s.trim().is_empty())
.unwrap_or_else(|| DEFAULT_OLLAMA_URL.to_string());
Self::new_with_url(&base_url, &model).map(Some)
}
"openai-compatible" => {
let base_url = std::env::var("AI_MEMORY_LLM_BASE_URL")
.ok()
.filter(|s| !s.trim().is_empty())
.ok_or_else(|| {
anyhow!(
"AI_MEMORY_LLM_BACKEND=openai-compatible requires \
AI_MEMORY_LLM_BASE_URL to be set (no default URL \
— operator must supply the vendor's endpoint)"
)
})?;
let api_key = std::env::var("AI_MEMORY_LLM_API_KEY")
.ok()
.filter(|s| !s.trim().is_empty())
.ok_or_else(|| {
anyhow!(
"AI_MEMORY_LLM_BACKEND=openai-compatible requires \
AI_MEMORY_LLM_API_KEY to be set"
)
})?;
Self::new_openai_compatible(&base_url, &model, &api_key).map(Some)
}
alias => {
let Some(default_url) = default_base_url_for_alias(alias) else {
return Err(anyhow!(
"AI_MEMORY_LLM_BACKEND={alias} is not a recognized \
backend alias. Valid values: ollama, openai-compatible, \
openai, xai, anthropic, gemini, deepseek, kimi, qwen, \
mistral, groq, together, cerebras, openrouter, \
fireworks, lmstudio"
));
};
let base_url = std::env::var("AI_MEMORY_LLM_BASE_URL")
.ok()
.filter(|s| !s.trim().is_empty())
.unwrap_or_else(|| default_url.to_string());
let api_key = std::env::var("AI_MEMORY_LLM_API_KEY")
.ok()
.filter(|s| !s.trim().is_empty())
.or_else(|| {
alias_api_key_env_vars(alias).iter().find_map(|name| {
std::env::var(name).ok().filter(|s| !s.trim().is_empty())
})
})
.ok_or_else(|| {
anyhow!(
"AI_MEMORY_LLM_BACKEND={alias} requires an API key \
— set AI_MEMORY_LLM_API_KEY or one of the \
per-vendor env vars: {:?}",
alias_api_key_env_vars(alias)
)
})?;
Self::new_openai_compatible(&base_url, &model, &api_key).map(Some)
}
}
}
pub fn build_for_init(legacy_url: &str, legacy_model: &str) -> Result<Option<Self>> {
let backend_env = std::env::var("AI_MEMORY_LLM_BACKEND")
.ok()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty());
if backend_env.is_some() {
return Self::from_env();
}
Self::new_with_url(legacy_url, legacy_model).map(Some)
}
pub fn build_from_resolved(resolved: &crate::config::ResolvedLlm) -> Result<Option<Self>> {
tracing::debug!(
"LLM client construction via #1146 resolver — backend={}, model={}, base_url={}, key_source={}, source={}",
resolved.backend,
resolved.model,
resolved.base_url,
resolved.api_key_source.as_str(),
resolved.source.as_str(),
);
if resolved.backend == BACKEND_OLLAMA {
return Self::new_with_url(&resolved.base_url, &resolved.model).map(Some);
}
let Some(api_key) = resolved.api_key() else {
return Err(anyhow!(
"LLM backend `{}` requires an API key but the resolver \
produced none. KeySource = {}. Configure either \
AI_MEMORY_LLM_API_KEY, a per-vendor env var (e.g. \
XAI_API_KEY), [llm].api_key_env, or [llm].api_key_file \
in config.toml. See \
https://github.com/alphaonedev/ai-memory-mcp/issues/1146",
resolved.backend,
resolved.api_key_source.as_str(),
));
};
Self::new_openai_compatible(&resolved.base_url, &resolved.model, api_key).map(Some)
}
pub async fn build_from_resolved_async(
resolved: &crate::config::ResolvedLlm,
) -> Result<Option<Self>> {
tracing::debug!(
"LLM client construction via #1146 resolver (async, FX-D1) — backend={}, model={}, base_url={}, key_source={}, source={}",
resolved.backend,
resolved.model,
resolved.base_url,
resolved.api_key_source.as_str(),
resolved.source.as_str(),
);
if resolved.backend == BACKEND_OLLAMA {
return Self::new_with_url_async(&resolved.base_url, &resolved.model)
.await
.map(Some);
}
let Some(api_key) = resolved.api_key() else {
return Err(anyhow!(
"LLM backend `{}` requires an API key but the resolver \
produced none. KeySource = {}. Configure either \
AI_MEMORY_LLM_API_KEY, a per-vendor env var (e.g. \
XAI_API_KEY), [llm].api_key_env, or [llm].api_key_file \
in config.toml. See \
https://github.com/alphaonedev/ai-memory-mcp/issues/1146",
resolved.backend,
resolved.api_key_source.as_str(),
));
};
Self::new_openai_compatible(&resolved.base_url, &resolved.model, api_key).map(Some)
}
#[must_use]
pub fn is_ollama_native(&self) -> bool {
matches!(self.provider, LlmProvider::Ollama)
}
pub fn new_openai_compatible(base_url: &str, model: &str, api_key: &str) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(GENERATE_TIMEOUT)
.connect_timeout(CONNECT_TIMEOUT)
.build()
.context("Failed to build HTTP client")?;
Ok(Self {
provider: LlmProvider::OpenAiCompatible {
api_key: api_key.to_string(),
},
base_url: base_url.trim_end_matches('/').to_string(),
model: model.to_string(),
client,
breaker: Mutex::new(BreakerState::new()),
embed_dimensions: None,
})
}
#[must_use]
pub fn with_embed_dimensions(mut self, dims: Option<u32>) -> Self {
self.embed_dimensions = dims;
self
}
pub fn new_with_url(base_url: &str, model: &str) -> Result<Self> {
block_on_local(|| Self::new_with_url_async(base_url, model))
}
pub async fn new_with_url_async(base_url: &str, model: &str) -> Result<Self> {
let instance = Self::new_with_url_no_health_check(base_url, model)?;
if !instance.is_available_async().await {
return Err(anyhow!(
"Ollama is not running or not reachable at {}. \
Start it with: ollama serve",
instance.base_url
));
}
Ok(instance)
}
pub fn new_with_url_no_health_check(base_url: &str, model: &str) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(GENERATE_TIMEOUT)
.connect_timeout(CONNECT_TIMEOUT)
.build()
.context("Failed to build HTTP client")?;
Ok(Self {
provider: LlmProvider::Ollama,
base_url: base_url.trim_end_matches('/').to_string(),
model: model.to_string(),
client,
breaker: Mutex::new(BreakerState::new()),
embed_dimensions: None,
})
}
fn breaker_is_open(&self) -> bool {
self.breaker.lock().map(|b| b.is_open()).unwrap_or(false)
}
fn note_failure(&self) {
if let Ok(mut b) = self.breaker.lock() {
b.record_failure();
}
}
fn note_success(&self) {
if let Ok(mut b) = self.breaker.lock() {
b.record_success();
}
}
#[doc(hidden)]
pub fn circuit_breaker_open(&self) -> bool {
self.breaker_is_open()
}
pub fn is_available(&self) -> bool {
block_on_local(|| self.is_available_async())
}
pub async fn is_available_async(&self) -> bool {
let (url, bearer) = match &self.provider {
LlmProvider::Ollama => (ollama_tags_url(&self.base_url), None),
LlmProvider::OpenAiCompatible { api_key } => {
(format!("{}/models", self.base_url), Some(api_key.as_str()))
}
};
let mut req = self.client.get(&url).timeout(HEALTH_TIMEOUT);
if let Some(key) = bearer {
req = req.bearer_auth(key);
}
match req.send().await {
Ok(r) => r.status().is_success(),
Err(_) => false,
}
}
pub fn ensure_model(&self) -> Result<()> {
block_on_local(|| self.ensure_model_async())
}
pub async fn ensure_model_async(&self) -> Result<()> {
if matches!(self.provider, LlmProvider::OpenAiCompatible { .. }) {
return Ok(());
}
let url = ollama_tags_url(&self.base_url);
let resp = self
.client
.get(&url)
.timeout(Duration::from_secs(10))
.send()
.await
.context("Failed to list Ollama models")?;
let body: Value = read_capped_json(resp)
.await
.context("Failed to parse /api/tags response")?;
let model_exists = body["models"].as_array().is_some_and(|models| {
models.iter().any(|m| {
let name = m["name"].as_str().unwrap_or("");
let our_base = self.model.split(':').next().unwrap_or(&self.model);
name == self.model
|| name.starts_with(&format!("{}:", self.model))
|| self.model == name.split(':').next().unwrap_or("")
|| name == our_base
})
});
if model_exists {
return Ok(());
}
tracing::info!(
"Pulling Ollama model '{}' (this may take a while)...",
self.model
);
let pull_url = format!("{}/api/pull", self.base_url);
let pull_client = reqwest::Client::builder()
.timeout(PULL_TIMEOUT)
.build()
.context("Failed to build pull client")?;
let resp = pull_client
.post(&pull_url)
.json(&json!({ "name": self.model }))
.send()
.await
.context("Failed to pull model from Ollama")?;
if !resp.status().is_success() {
let status = resp.status();
let text = read_capped_text(resp).await;
return Err(anyhow!("Ollama pull failed ({status}): {text}"));
}
tracing::info!("Model '{}' pulled successfully", self.model);
Ok(())
}
pub fn generate(&self, prompt: &str, system: Option<&str>) -> Result<String> {
block_on_local(|| self.generate_async(prompt, system))
}
pub async fn generate_async(&self, prompt: &str, system: Option<&str>) -> Result<String> {
if self.breaker_is_open() {
return Err(anyhow!(
"Failed to send chat request: circuit breaker open \
(last failure within {}s); LLM at {} is not responding",
CIRCUIT_BREAKER_COOLDOWN.as_secs(),
self.base_url,
));
}
self.check_outbound()?;
let (url, payload, bearer): (String, Value, Option<&str>) = match &self.provider {
LlmProvider::Ollama => {
let mut messages = Vec::new();
if let Some(sys) = system {
messages.push(json!({"role": "system", "content": sys}));
}
messages.push(json!({"role": "user", "content": prompt}));
(
format!("{}/api/chat", self.base_url),
json!({
"model": self.model,
"messages": messages,
"stream": false,
}),
None,
)
}
LlmProvider::OpenAiCompatible { api_key } => {
let mut messages = Vec::new();
if let Some(sys) = system {
messages.push(json!({"role": "system", "content": sys}));
}
messages.push(json!({"role": "user", "content": prompt}));
(
format!("{}/chat/completions", self.base_url),
json!({
"model": self.model,
"messages": messages,
"stream": false,
}),
Some(api_key.as_str()),
)
}
};
let mut req = self
.client
.post(&url)
.timeout(GENERATE_TIMEOUT)
.json(&payload);
if let Some(key) = bearer {
req = req.bearer_auth(key);
}
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
self.note_failure();
return Err(anyhow::Error::new(e).context("Failed to send chat request"));
}
};
if !resp.status().is_success() {
let status = resp.status();
if status.is_server_error() {
self.note_failure();
}
let text = read_capped_text(resp).await;
return Err(anyhow!("Chat generate failed ({status}): {text}"));
}
let body: Value = match read_capped_json(resp).await {
Ok(b) => b,
Err(e) => {
self.note_failure();
return Err(e.context("Failed to parse chat response"));
}
};
let response_text = match &self.provider {
LlmProvider::Ollama => body["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Missing 'message.content' field in chat output"))?
.to_string(),
LlmProvider::OpenAiCompatible { .. } => body["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| {
anyhow!(
"Missing 'choices[0].message.content' field in OpenAI-compatible \
chat response; got: {body}"
)
})?
.to_string(),
};
self.note_success();
Ok(response_text)
}
pub fn expand_query(&self, query: &str) -> Result<Vec<String>> {
block_on_local(|| self.expand_query_async(query))
}
pub async fn expand_query_async(&self, query: &str) -> Result<Vec<String>> {
let prompt = QUERY_EXPANSION_PROMPT.replace("{query}", query);
let response = self.generate_async(&prompt, None).await?;
let terms: Vec<String> = response
.lines()
.map(|line| line.trim().to_string())
.filter(|line| !line.is_empty())
.collect();
Ok(terms)
}
pub fn summarize_memories(&self, memories: &[(String, String)]) -> Result<String> {
block_on_local(|| self.summarize_memories_async(memories))
}
pub async fn summarize_memories_async(&self, memories: &[(String, String)]) -> Result<String> {
let formatted = memories
.iter()
.enumerate()
.map(|(i, (title, content))| {
format!("--- Memory {} ---\nTitle: {}\n{}", i + 1, title, content)
})
.collect::<Vec<_>>()
.join("\n\n");
let prompt = SUMMARIZE_PROMPT.replace("{memories}", &formatted);
let response = self.generate_async(&prompt, None).await?;
Ok(response.trim().to_string())
}
pub fn auto_tag(
&self,
title: &str,
content: &str,
model_override: Option<&str>,
) -> Result<Vec<String>> {
block_on_local(|| self.auto_tag_async(title, content, model_override))
}
pub async fn auto_tag_async(
&self,
title: &str,
content: &str,
model_override: Option<&str>,
) -> Result<Vec<String>> {
let prompt = AUTO_TAG_PROMPT
.replace("{title}", title)
.replace("{content}", content);
let response = self
.generate_with_model_override_async(&prompt, None, model_override)
.await?;
let tags: Vec<String> = response
.lines()
.map(|line| line.trim().to_lowercase())
.filter(|line| !line.is_empty() && line.len() <= 64)
.take(8)
.collect();
Ok(tags)
}
#[allow(dead_code)]
fn generate_with_model_override(
&self,
prompt: &str,
system: Option<&str>,
model_override: Option<&str>,
) -> Result<String> {
block_on_local(|| self.generate_with_model_override_async(prompt, system, model_override))
}
#[allow(clippy::too_many_lines)]
pub async fn generate_with_model_override_async(
&self,
prompt: &str,
system: Option<&str>,
model_override: Option<&str>,
) -> Result<String> {
if self.breaker_is_open() {
return Err(anyhow!(
"Failed to send chat request: circuit breaker open \
(last failure within {}s); LLM at {} is not responding",
CIRCUIT_BREAKER_COOLDOWN.as_secs(),
self.base_url,
));
}
self.check_outbound()?;
let model = model_override.unwrap_or(&self.model);
let (url, payload, bearer): (String, Value, Option<&str>) = match &self.provider {
LlmProvider::Ollama => {
let mut messages = Vec::new();
if let Some(sys) = system {
messages.push(json!({"role": "system", "content": sys}));
}
messages.push(json!({"role": "user", "content": prompt}));
(
format!("{}/api/chat", self.base_url),
json!({"model": model, "messages": messages, "stream": false}),
None,
)
}
LlmProvider::OpenAiCompatible { api_key } => {
let mut messages = Vec::new();
if let Some(sys) = system {
messages.push(json!({"role": "system", "content": sys}));
}
messages.push(json!({"role": "user", "content": prompt}));
(
format!("{}/chat/completions", self.base_url),
json!({"model": model, "messages": messages, "stream": false}),
Some(api_key.as_str()),
)
}
};
let mut req = self
.client
.post(&url)
.timeout(GENERATE_TIMEOUT)
.json(&payload);
if let Some(key) = bearer {
req = req.bearer_auth(key);
}
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
self.note_failure();
return Err(anyhow::Error::new(e).context("Failed to send chat request"));
}
};
if !resp.status().is_success() {
let status = resp.status();
if status.is_server_error() {
self.note_failure();
}
let text = read_capped_text(resp).await;
return Err(anyhow!("Generate failed ({status}): {text}"));
}
let body: Value = match read_capped_json(resp).await {
Ok(b) => b,
Err(e) => {
self.note_failure();
return Err(e.context("Failed to parse chat response"));
}
};
let response_text = match &self.provider {
LlmProvider::Ollama => body["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Missing 'message.content' in chat response"))?
.to_string(),
LlmProvider::OpenAiCompatible { .. } => body["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| {
anyhow!(
"Missing 'choices[0].message.content' in OpenAI-compatible \
chat response; got: {body}"
)
})?
.to_string(),
};
self.note_success();
Ok(response_text)
}
fn check_outbound(&self) -> Result<()> {
let url = reqwest::Url::parse(&self.base_url).ok();
let host = url
.as_ref()
.and_then(|u| u.host_str().map(str::to_string))
.unwrap_or_else(|| self.base_url.clone());
let scheme = url
.as_ref()
.map(|u| u.scheme().to_string())
.unwrap_or_default();
let action = crate::governance::agent_action::AgentAction::NetworkRequest {
host: host.clone(),
scheme,
};
crate::governance::wire_check::check_anyhow(&action)
.with_context(|| format!("governance refused outbound to ollama at {host}"))
}
#[allow(dead_code)]
fn generate_with_body(&self, body: &Value) -> Result<String> {
block_on_local(|| self.generate_with_body_async(body))
}
#[allow(dead_code)]
async fn generate_with_body_async(&self, body: &Value) -> Result<String> {
if self.breaker_is_open() {
return Err(anyhow!(
"Failed to send generate request: circuit breaker open \
(last failure within {}s); ollama at {} is not responding",
CIRCUIT_BREAKER_COOLDOWN.as_secs(),
self.base_url,
));
}
self.check_outbound()?;
let url = format!("{}/api/generate", self.base_url);
let resp = match self
.client
.post(&url)
.timeout(GENERATE_TIMEOUT)
.json(body)
.send()
.await
{
Ok(r) => r,
Err(e) => {
self.note_failure();
return Err(anyhow::Error::new(e).context("Failed to send generate request"));
}
};
if !resp.status().is_success() {
let status = resp.status();
if status.is_server_error() {
self.note_failure();
}
let text = read_capped_text(resp).await;
return Err(anyhow!("Generate failed ({status}): {text}"));
}
let parsed: Value = match read_capped_json(resp).await {
Ok(v) => v,
Err(e) => {
self.note_failure();
return Err(e.context("Failed to parse generate response"));
}
};
let response_text = parsed["response"]
.as_str()
.ok_or_else(|| anyhow!("Missing 'response' field in generate output"))?
.to_string();
self.note_success();
Ok(response_text)
}
pub fn embed_text(&self, text: &str, embed_model: &str) -> Result<Vec<f32>> {
block_on_local(|| self.embed_text_async(text, embed_model))
}
pub async fn embed_text_async(&self, text: &str, embed_model: &str) -> Result<Vec<f32>> {
if self.breaker_is_open() {
return Err(anyhow!(
"Failed to send embed request: circuit breaker open \
(last failure within {}s); LLM at {} is not responding",
CIRCUIT_BREAKER_COOLDOWN.as_secs(),
self.base_url,
));
}
self.check_outbound()?;
let (url, payload, bearer): (String, Value, Option<&str>) = match &self.provider {
LlmProvider::Ollama => (
format!("{}/api/embed", self.base_url),
json!({"model": embed_model, "input": text, "truncate": true}),
None,
),
LlmProvider::OpenAiCompatible { api_key } => (
format!("{}{}", self.base_url, OPENAI_COMPAT_EMBEDDINGS_PATH),
match self.embed_dimensions {
Some(dims) => {
json!({"model": embed_model, "input": text, "dimensions": dims})
}
None => json!({"model": embed_model, "input": text}),
},
Some(api_key.as_str()),
),
};
let mut req = self
.client
.post(&url)
.timeout(GENERATE_TIMEOUT)
.json(&payload);
if let Some(key) = bearer {
req = req.bearer_auth(key);
}
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
self.note_failure();
return Err(anyhow::Error::new(e).context("Failed to send embed request"));
}
};
if !resp.status().is_success() {
let status = resp.status();
if status.is_server_error() {
self.note_failure();
}
let text = read_capped_text(resp).await;
return Err(anyhow!("Embed failed ({status}): {text}"));
}
let body: Value = match read_capped_json(resp).await {
Ok(b) => b,
Err(e) => {
self.note_failure();
return Err(e.context("Failed to parse embed response"));
}
};
let embedding_array = match &self.provider {
LlmProvider::Ollama => body["embeddings"]
.as_array()
.and_then(|arr| arr.first())
.and_then(|v| v.as_array())
.ok_or_else(|| anyhow!("Missing 'embeddings[0]' in Ollama embed response"))?,
LlmProvider::OpenAiCompatible { .. } => {
body["data"][0]["embedding"].as_array().ok_or_else(|| {
anyhow!(
"Missing 'data[0].embedding' in OpenAI-compatible embed response; \
got: {body}"
)
})?
}
};
#[allow(clippy::cast_possible_truncation)]
let floats: Vec<f32> = embedding_array
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
if floats.is_empty() {
return Err(anyhow!("Empty embedding returned from LLM"));
}
self.note_success();
Ok(floats)
}
pub fn embed_texts(&self, texts: &[&str], embed_model: &str) -> Result<Vec<Vec<f32>>> {
block_on_local(|| self.embed_texts_async(texts, embed_model))
}
pub async fn embed_texts_async(
&self,
texts: &[&str],
embed_model: &str,
) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
if matches!(self.provider, LlmProvider::Ollama) {
let mut out = Vec::with_capacity(texts.len());
for t in texts {
out.push(self.embed_text_async(t, embed_model).await?);
}
return Ok(out);
}
let mut out: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
let mut start = 0usize;
while start < texts.len() {
let mut end = start;
let mut bytes = 0usize;
while end < texts.len()
&& (end - start) < EMBED_BATCH_MAX_INPUTS
&& (end == start || bytes + texts[end].len() <= EMBED_BATCH_MAX_BYTES)
{
bytes += texts[end].len();
end += 1;
}
let chunk = &texts[start..end];
match self.embed_texts_one_request(chunk, embed_model).await {
Ok(vecs) => out.extend(vecs),
Err(batch_err) => {
tracing::warn!(
"batched embed of {} text(s) failed ({batch_err}); \
falling back to per-text requests",
chunk.len()
);
for t in chunk {
out.push(self.embed_text_async(t, embed_model).await?);
}
}
}
start = end;
}
Ok(out)
}
async fn embed_texts_one_request(
&self,
chunk: &[&str],
embed_model: &str,
) -> Result<Vec<Vec<f32>>> {
if self.breaker_is_open() {
return Err(anyhow!(
"Failed to send embed request: circuit breaker open \
(last failure within {}s); LLM at {} is not responding",
CIRCUIT_BREAKER_COOLDOWN.as_secs(),
self.base_url,
));
}
self.check_outbound()?;
let LlmProvider::OpenAiCompatible { api_key } = &self.provider else {
return Err(anyhow!(
"embed_texts_one_request requires an OpenAI-compatible provider"
));
};
let payload = match self.embed_dimensions {
Some(dims) => {
json!({"model": embed_model, "input": chunk, "dimensions": dims})
}
None => json!({"model": embed_model, "input": chunk}),
};
let resp = match self
.client
.post(format!(
"{}{}",
self.base_url, OPENAI_COMPAT_EMBEDDINGS_PATH
))
.timeout(GENERATE_TIMEOUT)
.json(&payload)
.bearer_auth(api_key)
.send()
.await
{
Ok(r) => r,
Err(e) => {
self.note_failure();
return Err(anyhow::Error::new(e).context("Failed to send embed request"));
}
};
if !resp.status().is_success() {
let status = resp.status();
if status.is_server_error() {
self.note_failure();
}
let text = read_capped_text(resp).await;
return Err(anyhow!("Embed failed ({status}): {text}"));
}
let body: Value = match read_capped_json(resp).await {
Ok(b) => b,
Err(e) => {
self.note_failure();
return Err(e.context("Failed to parse embed response"));
}
};
let parsed = parse_openai_embeddings_batch(&body, chunk.len())?;
self.note_success();
Ok(parsed)
}
pub fn ensure_embed_model(&self, model: &str) -> Result<()> {
block_on_local(|| self.ensure_embed_model_async(model))
}
pub async fn ensure_embed_model_async(&self, model: &str) -> Result<()> {
if matches!(self.provider, LlmProvider::OpenAiCompatible { .. }) {
return Ok(());
}
let url = ollama_tags_url(&self.base_url);
let resp = self
.client
.get(&url)
.timeout(std::time::Duration::from_secs(10))
.send()
.await
.context("Failed to list Ollama models")?;
let body: Value = read_capped_json(resp)
.await
.context("Failed to parse /api/tags response")?;
let model_exists = body["models"].as_array().is_some_and(|models| {
models.iter().any(|m| {
let name = m["name"].as_str().unwrap_or("");
name == model
|| name.starts_with(&format!("{model}:"))
|| model == name.split(':').next().unwrap_or("")
})
});
if model_exists {
return Ok(());
}
tracing::info!("Pulling Ollama embedding model '{}'...", model);
let pull_url = format!("{}/api/pull", self.base_url);
let pull_client = reqwest::Client::builder()
.timeout(PULL_TIMEOUT)
.build()
.context("Failed to build pull client")?;
let resp = pull_client
.post(&pull_url)
.json(&json!({ "name": model }))
.send()
.await
.context("Failed to pull embedding model from Ollama")?;
if !resp.status().is_success() {
let status = resp.status();
let text = read_capped_text(resp).await;
return Err(anyhow!("Ollama embed model pull failed ({status}): {text}"));
}
tracing::info!("Embedding model '{}' pulled successfully", model);
Ok(())
}
/// Blocking counterpart of [`Self::detect_contradiction_async`].
///
/// # Errors
/// Returns whatever error the async variant produces (e.g. a failed
/// generate call).
pub fn detect_contradiction(&self, mem_a: &str, mem_b: &str) -> Result<bool> {
    let make_fut = move || self.detect_contradiction_async(mem_a, mem_b);
    block_on_local(make_fut)
}
/// Asks the configured LLM whether two memories contradict each other.
///
/// `mem_a` and `mem_b` fill the `{a}` and `{b}` placeholders of
/// `CONTRADICTION_PROMPT` in a single pass. Memory text that itself contains
/// `{a}` or `{b}` (e.g. code snippets) is inserted verbatim instead of being
/// substituted a second time.
///
/// The reply counts as a contradiction when it starts with `yes`. Before the
/// check, the reply is trimmed, lowercased, and stripped of leading
/// non-alphanumeric characters, so `**Yes**` or `"yes"` also qualify.
///
/// # Errors
/// Propagates any error from `generate_async`.
pub async fn detect_contradiction_async(&self, mem_a: &str, mem_b: &str) -> Result<bool> {
    // Split on `{a}` first: every piece comes only from the template, so the
    // `{b}` replacement never scans `mem_a`. Joining with `mem_a` last means
    // neither memory is ever re-scanned for placeholders.
    let prompt = CONTRADICTION_PROMPT
        .split("{a}")
        .map(|part| part.replace("{b}", mem_b))
        .collect::<Vec<_>>()
        .join(mem_a);
    let response = self.generate_async(&prompt, None).await?;
    let answer = response.trim().to_lowercase();
    let answer = answer.trim_start_matches(|c: char| !c.is_alphanumeric());
    Ok(answer.starts_with("yes"))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each prompt template must still expose the placeholders callers fill in.
    #[test]
    fn test_prompt_templates_have_placeholders() {
        let required = [
            (QUERY_EXPANSION_PROMPT, "{query}"),
            (SUMMARIZE_PROMPT, "{memories}"),
            (AUTO_TAG_PROMPT, "{title}"),
            (AUTO_TAG_PROMPT, "{content}"),
            (CONTRADICTION_PROMPT, "{a}"),
            (CONTRADICTION_PROMPT, "{b}"),
        ];
        for (template, placeholder) in required {
            assert!(template.contains(placeholder));
        }
    }

    #[test]
    fn test_default_url() {
        let expected = "http://localhost:11434";
        assert_eq!(DEFAULT_OLLAMA_URL, expected);
    }

    #[test]
    fn parse_openai_embeddings_batch_orders_by_index_1603() {
        // Out-of-order `index` fields are sorted back into request order.
        let shuffled = serde_json::json!({"data": [
            {"index": 1, "embedding": [2.0, 2.0]},
            {"index": 0, "embedding": [1.0, 1.0]},
        ]});
        let ordered = parse_openai_embeddings_batch(&shuffled, 2).expect("parse");
        assert_eq!(ordered, vec![vec![1.0, 1.0], vec![2.0, 2.0]]);

        // Without `index` fields the response order is taken as-is.
        let positional = serde_json::json!({"data": [
            {"embedding": [1.0]},
            {"embedding": [2.0]},
        ]});
        let in_order =
            parse_openai_embeddings_batch(&positional, 2).expect("positional parse");
        assert_eq!(in_order, vec![vec![1.0], vec![2.0]]);
    }

    #[test]
    fn parse_openai_embeddings_batch_rejects_malformed_1603() {
        // (response body, expected vector count, why it must be rejected)
        let malformed = [
            (
                serde_json::json!({"data": [{"index": 0, "embedding": [1.0]}]}),
                2,
                "count mismatch",
            ),
            (
                serde_json::json!({"data": [
                    {"index": 0, "embedding": [1.0]},
                    {"index": 0, "embedding": [2.0]},
                ]}),
                2,
                "duplicate index",
            ),
            (
                serde_json::json!({"data": [
                    {"index": 0, "embedding": [1.0]},
                    {"index": 9, "embedding": [2.0]},
                ]}),
                2,
                "out-of-range index",
            ),
            (
                serde_json::json!({"data": [{"index": 0}]}),
                1,
                "missing embedding",
            ),
            (
                serde_json::json!({"data": [{"index": 0, "embedding": []}]}),
                1,
                "empty vector",
            ),
            (serde_json::json!({"object": "list"}), 1, "missing data"),
        ];
        for (body, expected_len, why) in &malformed {
            assert!(
                parse_openai_embeddings_batch(body, *expected_len).is_err(),
                "{why}"
            );
        }
    }

    #[test]
    fn default_base_url_for_alias_covers_all_15_aliases_1067() {
        // Aliases that share an endpoint are grouped together.
        let expectations: &[(&[&str], Option<&str>)] = &[
            (&["openai"], Some("https://api.openai.com/v1")),
            (&["xai"], Some("https://api.x.ai/v1")),
            (&["anthropic"], Some("https://api.anthropic.com/v1")),
            (
                &["gemini"],
                Some("https://generativelanguage.googleapis.com/v1beta/openai"),
            ),
            (&["deepseek"], Some("https://api.deepseek.com/v1")),
            (&["kimi", "moonshot"], Some("https://api.moonshot.cn/v1")),
            (
                &["qwen", "dashscope"],
                Some("https://dashscope.aliyuncs.com/compatible-mode/v1"),
            ),
            (&["mistral"], Some("https://api.mistral.ai/v1")),
            (&["groq"], Some("https://api.groq.com/openai/v1")),
            (&["together"], Some("https://api.together.xyz/v1")),
            (&["cerebras"], Some("https://api.cerebras.ai/v1")),
            (&["openrouter"], Some("https://openrouter.ai/api/v1")),
            (&["fireworks"], Some("https://api.fireworks.ai/inference/v1")),
            (&["lmstudio"], Some("http://localhost:1234/v1")),
            (&["openai-compatible", "totally-unknown-vendor"], None),
        ];
        for &(aliases, expected) in expectations {
            for &alias in aliases {
                let got = default_base_url_for_alias(alias);
                assert_eq!(
                    got, expected,
                    "#1067: alias `{alias}` must resolve to {expected:?}; got {got:?}"
                );
            }
        }
    }

    #[test]
    fn alias_api_key_env_vars_per_alias_pins_1067() {
        // Aliases that share an env-var preference list are grouped together.
        let expectations: &[(&[&str], &[&str])] = &[
            (&["openai"], &["OPENAI_API_KEY"]),
            (&["xai"], &["XAI_API_KEY"]),
            (&["anthropic"], &["ANTHROPIC_API_KEY"]),
            (&["gemini"], &["GEMINI_API_KEY", "GOOGLE_API_KEY"]),
            (&["deepseek"], &["DEEPSEEK_API_KEY"]),
            (&["kimi", "moonshot"], &["MOONSHOT_API_KEY", "KIMI_API_KEY"]),
            (&["qwen", "dashscope"], &["DASHSCOPE_API_KEY", "QWEN_API_KEY"]),
            (&["mistral"], &["MISTRAL_API_KEY"]),
            (&["groq"], &["GROQ_API_KEY"]),
            (&["together"], &["TOGETHER_API_KEY"]),
            (&["cerebras"], &["CEREBRAS_API_KEY"]),
            (&["openrouter"], &["OPENROUTER_API_KEY"]),
            (&["fireworks"], &["FIREWORKS_API_KEY"]),
            (
                &[
                    BACKEND_OLLAMA,
                    "lmstudio",
                    "openai-compatible",
                    "totally-unknown-vendor",
                ],
                &[],
            ),
        ];
        for &(aliases, expected) in expectations {
            for &alias in aliases {
                let got = alias_api_key_env_vars(alias);
                assert_eq!(
                    got, expected,
                    "#1067: alias `{alias}` env-var preference list must be {expected:?}; got {got:?}"
                );
            }
        }
    }
}
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports,
    clippy::doc_markdown
)]
pub mod test_support {
    //! Network-free stand-in for the real client. Unit tests use it to
    //! exercise success paths and canned failure modes deterministically.
    use super::*;

    /// Failure mode a [`MockOllamaClient`] simulates on every call.
    pub enum MockFailure {
        ModelNotFound,
        Timeout,
        MalformedResponse,
        ApiError(String),
        EmptyResponse,
        NetworkError,
    }

    /// Mock client: returns canned data, or the configured failure's error.
    pub struct MockOllamaClient {
        pub base_url: String,
        pub model: String,
        pub fail_with: Option<MockFailure>,
    }

    /// Error returned by the chat-backed helpers (`expand_query`,
    /// `summarize_memories`, `auto_tag`, `detect_contradiction`) for `failure`.
    fn chat_failure(failure: &MockFailure) -> anyhow::Error {
        match failure {
            MockFailure::Timeout => anyhow!("Failed to send chat request: operation timed out"),
            MockFailure::MalformedResponse => {
                anyhow!("Failed to parse chat response: invalid JSON")
            }
            MockFailure::ApiError(msg) => anyhow!("Chat generate failed (500): {}", msg),
            MockFailure::ModelNotFound
            | MockFailure::EmptyResponse
            | MockFailure::NetworkError => anyhow!("Generate failed"),
        }
    }

    impl MockOllamaClient {
        /// Builds a mock that succeeds on every call; trailing `/`s are
        /// stripped from `base_url`.
        pub fn new_with_url(base_url: &str, model: &str) -> Result<Self> {
            Ok(Self {
                base_url: base_url.trim_end_matches('/').to_string(),
                model: model.to_string(),
                fail_with: None,
            })
        }

        /// Builds a mock whose every call fails according to `failure`.
        pub fn with_failure(base_url: &str, model: &str, failure: MockFailure) -> Result<Self> {
            let mut client = Self::new_with_url(base_url, model)?;
            client.fail_with = Some(failure);
            Ok(client)
        }

        fn should_fail(&self) -> Option<&MockFailure> {
            self.fail_with.as_ref()
        }

        /// Reports unavailability only for the `NetworkError` failure mode.
        pub fn is_available(&self) -> bool {
            !matches!(self.should_fail(), Some(MockFailure::NetworkError))
        }

        pub fn ensure_model(&self) -> Result<()> {
            match self.should_fail() {
                // Name the configured model rather than a hard-coded one.
                Some(MockFailure::ModelNotFound) => Err(anyhow!(
                    "Model '{}' not found in Ollama registry",
                    self.model
                )),
                Some(MockFailure::Timeout) => {
                    Err(anyhow!("Failed to list Ollama models: operation timed out"))
                }
                Some(MockFailure::ApiError(msg)) => {
                    Err(anyhow!("Ollama pull failed (404): {}", msg))
                }
                Some(MockFailure::NetworkError) => Err(anyhow!(
                    "Failed to pull model from Ollama: connection refused"
                )),
                _ => Ok(()),
            }
        }

        pub fn ensure_embed_model(&self, _model: &str) -> Result<()> {
            match self.should_fail() {
                Some(MockFailure::ModelNotFound) => Err(anyhow!("Embedding model not found")),
                Some(MockFailure::Timeout) => {
                    Err(anyhow!("Failed to list Ollama models: operation timed out"))
                }
                Some(MockFailure::ApiError(msg)) => {
                    Err(anyhow!("Ollama embed model pull failed (404): {}", msg))
                }
                Some(MockFailure::NetworkError) => Err(anyhow!(
                    "Failed to pull embedding model from Ollama: connection refused"
                )),
                _ => Ok(()),
            }
        }

        /// Returns a canned reply chosen by keywords in `prompt`.
        pub fn generate(&self, prompt: &str, _system: Option<&str>) -> Result<String> {
            match self.should_fail() {
                Some(MockFailure::Timeout) => {
                    return Err(anyhow!("Failed to send chat request: operation timed out"));
                }
                Some(MockFailure::MalformedResponse) => {
                    return Err(anyhow!("Failed to parse chat response: invalid JSON"));
                }
                Some(MockFailure::EmptyResponse) => {
                    return Err(anyhow!("Missing 'message.content' field in chat output"));
                }
                Some(MockFailure::ApiError(msg)) => {
                    return Err(anyhow!("Chat generate failed (500): {}", msg));
                }
                Some(MockFailure::NetworkError) => {
                    return Err(anyhow!("Failed to send chat request: connection refused"));
                }
                _ => {}
            }
            if prompt.contains("expand") || prompt.contains("search") {
                Ok("semantic search\nquery terms\nvector retrieval\ninformation retrieval\nsimilarity matching"
                    .to_string())
            } else if prompt.contains("Summarize") {
                Ok("This is a consolidated summary of multiple memories covering key facts and decisions."
                    .to_string())
            } else if prompt.contains("tags") {
                Ok("important\nkey-fact\nstatus-update\ntechnical".to_string())
            } else if prompt.contains("contradict") {
                if prompt.contains("yes") || prompt.contains("true") {
                    Ok("yes".to_string())
                } else {
                    Ok("no".to_string())
                }
            } else {
                // Echo at most 50 bytes, backing off to a char boundary so a
                // multi-byte character straddling byte 50 cannot panic.
                let mut end = prompt.len().min(50);
                while !prompt.is_char_boundary(end) {
                    end -= 1;
                }
                Ok(format!("Mock response for: {}", &prompt[..end]))
            }
        }

        pub fn expand_query(&self, query: &str) -> Result<Vec<String>> {
            if let Some(failure) = self.should_fail() {
                return Err(chat_failure(failure));
            }
            Ok(vec![
                format!("{}-related", query),
                format!("{}-expanded", query),
                "semantic-search".to_string(),
                "vector-expansion".to_string(),
                "query-variants".to_string(),
            ])
        }

        /// Rejects an empty list before consulting the failure mode.
        pub fn summarize_memories(&self, memories: &[(String, String)]) -> Result<String> {
            if memories.is_empty() {
                return Err(anyhow!("Cannot summarize empty memories list"));
            }
            if let Some(failure) = self.should_fail() {
                return Err(chat_failure(failure));
            }
            let count = memories.len();
            Ok(format!(
                "Summary of {count} memories: consolidated facts and key decisions preserved"
            ))
        }

        pub fn auto_tag(
            &self,
            title: &str,
            _content: &str,
            _model_override: Option<&str>,
        ) -> Result<Vec<String>> {
            if let Some(failure) = self.should_fail() {
                return Err(chat_failure(failure));
            }
            Ok(vec![
                "important".to_string(),
                format!("{}-tag", title.split_whitespace().next().unwrap_or("data")),
                "memory".to_string(),
            ])
        }

        /// Returns a deterministic 768-dimensional vector derived from `text`.
        pub fn embed_text(&self, text: &str, _embed_model: &str) -> Result<Vec<f32>> {
            match self.should_fail() {
                Some(MockFailure::Timeout) => {
                    return Err(anyhow!(
                        "Failed to send embed request to Ollama: operation timed out"
                    ));
                }
                Some(MockFailure::MalformedResponse) => {
                    return Err(anyhow!(
                        "Failed to parse Ollama embed response: invalid JSON"
                    ));
                }
                Some(MockFailure::EmptyResponse) => {
                    return Err(anyhow!("Missing embeddings in Ollama response"));
                }
                Some(MockFailure::ApiError(msg)) => {
                    return Err(anyhow!("Ollama embed failed (500): {}", msg));
                }
                Some(MockFailure::NetworkError) => {
                    return Err(anyhow!(
                        "Failed to send embed request to Ollama: connection refused"
                    ));
                }
                Some(MockFailure::ModelNotFound) => {
                    return Err(anyhow!("Ollama embed failed (404): model not found"));
                }
                None => {}
            }
            let base_val = (text.len() % 10) as f32 / 100.0;
            let embedding: Vec<f32> = (0..768).map(|i| base_val + (i as f32) * 0.0001).collect();
            Ok(embedding)
        }

        /// Flags a contradiction when more than one "contradictory" keyword
        /// appears in the combined, lowercased text.
        pub fn detect_contradiction(&self, mem_a: &str, mem_b: &str) -> Result<bool> {
            if let Some(failure) = self.should_fail() {
                return Err(chat_failure(failure));
            }
            let combined = format!("{mem_a} {mem_b}").to_lowercase();
            let contradictory_keywords = &["not", "never", "always", "contradiction", "opposite"];
            let count = contradictory_keywords
                .iter()
                .filter(|&&kw| combined.contains(kw))
                .count();
            Ok(count > 1)
        }
    }
}
#[cfg(test)]
mod mock_tests {
    use super::test_support::{MockFailure, MockOllamaClient};
    use super::{AUTO_TAG_PROMPT, CONTRADICTION_PROMPT, QUERY_EXPANSION_PROMPT, SUMMARIZE_PROMPT};

    const BASE_URL: &str = "http://localhost:11434";

    /// Mock client configured to succeed on every call.
    fn healthy_client() -> MockOllamaClient {
        MockOllamaClient::new_with_url(BASE_URL, "test-model")
            .expect("mock construction never fails")
    }

    /// Mock client configured to fail every call with `failure`.
    fn failing_client(failure: MockFailure) -> MockOllamaClient {
        MockOllamaClient::with_failure(BASE_URL, "test-model", failure)
            .expect("mock construction never fails")
    }

    /// Asserts that `result` failed and returns its rendered error message.
    fn error_text<T: std::fmt::Debug>(result: anyhow::Result<T>) -> String {
        result.expect_err("call should have failed").to_string()
    }

    #[test]
    fn test_mock_new_with_url() {
        let client = MockOllamaClient::new_with_url(BASE_URL, "test-model")
            .expect("construction should succeed");
        assert_eq!(client.base_url, "http://localhost:11434");
        assert_eq!(client.model, "test-model");
    }

    #[test]
    fn test_mock_new_with_url_trailing_slash() {
        let client = MockOllamaClient::new_with_url("http://localhost:11434/", "test-model")
            .expect("construction should succeed");
        assert_eq!(client.base_url, "http://localhost:11434");
    }

    #[test]
    fn test_mock_is_available() {
        assert!(healthy_client().is_available());
    }

    #[test]
    fn test_mock_ensure_model() {
        assert!(healthy_client().ensure_model().is_ok());
    }

    #[test]
    fn test_mock_ensure_embed_model() {
        assert!(healthy_client().ensure_embed_model("nomic-embed-text").is_ok());
    }

    #[test]
    fn test_mock_generate_query_expansion() {
        let prompt = QUERY_EXPANSION_PROMPT.replace("{query}", "search test");
        let response = healthy_client().generate(&prompt, None).expect("generate");
        assert!(!response.is_empty());
    }

    #[test]
    fn test_mock_expand_query() {
        let terms = healthy_client().expand_query("test query").expect("expand");
        assert!(terms.len() >= 3);
        assert!(terms.iter().any(|t| t == "test query-related"));
    }

    #[test]
    fn test_mock_summarize_memories() {
        let memories = vec![
            ("Title 1".to_string(), "Content 1".to_string()),
            ("Title 2".to_string(), "Content 2".to_string()),
        ];
        let summary = healthy_client()
            .summarize_memories(&memories)
            .expect("summarize");
        assert!(summary.contains('2'));
    }

    #[test]
    fn test_mock_auto_tag() {
        let tags = healthy_client()
            .auto_tag("Test Title", "test content", None)
            .expect("auto_tag");
        assert!(tags.len() >= 2);
        assert!(tags.iter().any(|t| t == "Test-tag"));
    }

    #[test]
    fn test_mock_embed_text() {
        let embedding = healthy_client()
            .embed_text("test text", "nomic-embed-text")
            .expect("embed");
        assert_eq!(embedding.len(), 768);
        assert!(embedding.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn test_mock_embed_text_deterministic() {
        let client = healthy_client();
        let first = client.embed_text("same text", "nomic-embed-text").expect("embed");
        let second = client.embed_text("same text", "nomic-embed-text").expect("embed");
        assert_eq!(first, second);
    }

    #[test]
    fn test_mock_detect_contradiction_true() {
        let is_contradiction = healthy_client()
            .detect_contradiction(
                "The system always works",
                "The system never works correctly",
            )
            .expect("detect");
        assert!(is_contradiction);
    }

    #[test]
    fn test_mock_detect_contradiction_false() {
        let is_contradiction = healthy_client()
            .detect_contradiction(
                "The memory is about search",
                "Additional details about the same search",
            )
            .expect("detect");
        // Previously only `is_ok` was checked; pin the expected verdict.
        assert!(!is_contradiction);
    }

    #[test]
    fn test_mock_generate_summarize_prompt() {
        let prompt = SUMMARIZE_PROMPT.replace(
            "{memories}",
            "--- Memory 1 ---\nTitle: Test\nThis is a test",
        );
        let response = healthy_client().generate(&prompt, None).expect("generate");
        assert!(response.contains("summary") || response.contains("Summary"));
    }

    #[test]
    fn test_mock_generate_auto_tag_prompt() {
        let prompt = AUTO_TAG_PROMPT
            .replace("{title}", "Important Update")
            .replace("{content}", "Some content");
        let response = healthy_client().generate(&prompt, None).expect("generate");
        assert!(!response.is_empty());
    }

    #[test]
    fn test_mock_generate_contradiction_prompt() {
        let prompt = CONTRADICTION_PROMPT
            .replace("{a}", "Statement A")
            .replace("{b}", "Statement B");
        let response = healthy_client().generate(&prompt, None).expect("generate");
        assert!(!response.is_empty());
    }

    #[test]
    fn test_mock_ensure_model_returns_not_found_error() {
        let client =
            MockOllamaClient::with_failure(BASE_URL, "unknown-model", MockFailure::ModelNotFound)
                .expect("mock construction never fails");
        assert!(error_text(client.ensure_model()).contains("not found"));
    }

    #[test]
    fn test_mock_ensure_model_returns_timeout_error() {
        let client = failing_client(MockFailure::Timeout);
        assert!(error_text(client.ensure_model()).contains("timed out"));
    }

    #[test]
    fn test_mock_ensure_model_returns_network_error() {
        let client = failing_client(MockFailure::NetworkError);
        assert!(error_text(client.ensure_model()).contains("connection"));
    }

    #[test]
    fn test_mock_ensure_embed_model_returns_not_found_error() {
        let client = failing_client(MockFailure::ModelNotFound);
        let msg = error_text(client.ensure_embed_model("unknown-embed-model"));
        assert!(msg.contains("not found"), "unexpected error: {msg}");
    }

    #[test]
    fn test_mock_generate_returns_timeout_error() {
        let client = failing_client(MockFailure::Timeout);
        assert!(error_text(client.generate("test prompt", None)).contains("timed out"));
    }

    #[test]
    fn test_mock_generate_handles_malformed_json() {
        let client = failing_client(MockFailure::MalformedResponse);
        assert!(error_text(client.generate("test prompt", None)).contains("parse"));
    }

    #[test]
    fn test_mock_generate_handles_empty_response() {
        let client = failing_client(MockFailure::EmptyResponse);
        assert!(error_text(client.generate("test prompt", None)).contains("message.content"));
    }

    #[test]
    fn test_mock_generate_handles_api_error() {
        let client = failing_client(MockFailure::ApiError("Internal Error".to_string()));
        assert!(error_text(client.generate("test prompt", None)).contains("Internal Error"));
    }

    #[test]
    fn test_mock_expand_query_passes_through_generate_error() {
        let client = failing_client(MockFailure::Timeout);
        assert!(error_text(client.expand_query("test query")).contains("timed out"));
    }

    #[test]
    fn test_mock_summarize_memories_handles_empty_input() {
        let empty_memories: Vec<(String, String)> = vec![];
        let msg = error_text(healthy_client().summarize_memories(&empty_memories));
        assert!(msg.contains("empty"), "unexpected error: {msg}");
    }

    #[test]
    fn test_mock_summarize_memories_handles_timeout() {
        let client = failing_client(MockFailure::Timeout);
        let memories = vec![("Title".to_string(), "Content".to_string())];
        assert!(error_text(client.summarize_memories(&memories)).contains("timed out"));
    }

    #[test]
    fn test_mock_auto_tag_handles_special_characters() {
        let tags = healthy_client()
            .auto_tag("Title @#$%", "content", None)
            .expect("auto_tag");
        assert!(tags.iter().any(|t| t == "Title-tag"));
    }

    #[test]
    fn test_mock_auto_tag_timeout() {
        let client = failing_client(MockFailure::Timeout);
        assert!(error_text(client.auto_tag("Test", "content", None)).contains("timed out"));
    }

    #[test]
    fn test_mock_embed_text_returns_768_dim() {
        let embedding = healthy_client()
            .embed_text("test", "nomic-embed-text-v1.5")
            .expect("embed");
        assert_eq!(embedding.len(), 768);
    }

    #[test]
    fn test_mock_embed_text_timeout() {
        let client = failing_client(MockFailure::Timeout);
        assert!(error_text(client.embed_text("test", "nomic-embed-text")).contains("timed out"));
    }

    #[test]
    fn test_mock_embed_text_malformed() {
        let client = failing_client(MockFailure::MalformedResponse);
        assert!(error_text(client.embed_text("test", "nomic-embed-text")).contains("parse"));
    }

    #[test]
    fn test_mock_embed_text_empty_response() {
        let client = failing_client(MockFailure::EmptyResponse);
        assert!(
            error_text(client.embed_text("test", "nomic-embed-text"))
                .contains("Missing embeddings")
        );
    }

    #[test]
    fn test_mock_embed_text_model_not_found() {
        let client = failing_client(MockFailure::ModelNotFound);
        assert!(error_text(client.embed_text("test", "unknown")).contains("404"));
    }

    #[test]
    fn test_mock_embed_text_network_error() {
        let client = failing_client(MockFailure::NetworkError);
        assert!(
            error_text(client.embed_text("test", "nomic-embed-text"))
                .contains("connection refused")
        );
    }

    #[test]
    fn test_mock_detect_contradiction_yes_case() {
        let is_contradiction = healthy_client()
            .detect_contradiction("The system always works", "The system never works")
            .expect("detect");
        assert!(is_contradiction);
    }

    #[test]
    fn test_mock_detect_contradiction_no_case() {
        let is_contradiction = healthy_client()
            .detect_contradiction("Consistent statement A", "Consistent statement B")
            .expect("detect");
        // Previously only `is_ok` was checked; pin the expected verdict.
        assert!(!is_contradiction);
    }

    #[test]
    fn test_mock_detect_contradiction_timeout() {
        let client = failing_client(MockFailure::Timeout);
        assert!(error_text(client.detect_contradiction("A", "B")).contains("timed out"));
    }

    #[test]
    fn test_mock_is_available_network_error() {
        assert!(!failing_client(MockFailure::NetworkError).is_available());
    }

    #[test]
    fn test_mock_with_failure_creates_client_that_fails() {
        let client = failing_client(MockFailure::Timeout);
        assert!(client.generate("any", None).is_err());
    }

    #[test]
    fn test_mock_api_error_variant() {
        let client = failing_client(MockFailure::ApiError("Custom msg".to_string()));
        assert!(error_text(client.generate("test", None)).contains("Custom msg"));
    }
}
#[cfg(test)]
#[allow(clippy::too_many_lines, clippy::similar_names)]
mod wiremock_tests {
use super::OllamaClient;
use serde_json::json;
use std::net::TcpListener;
use wiremock::matchers::{body_partial_json, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
async fn mount_tags_ok(server: &MockServer, models: serde_json::Value) {
Mock::given(method("GET"))
.and(path("/api/tags"))
.respond_with(ResponseTemplate::new(200).set_body_json(models))
.mount(server)
.await;
}
#[tokio::test(flavor = "multi_thread")]
async fn read_capped_bytes_rejects_oversize_1459() {
use super::read_capped_bytes_inner;
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/big"))
.respond_with(ResponseTemplate::new(200).set_body_string("x".repeat(4096)))
.mount(&server)
.await;
let url = format!("{}/big", server.uri());
let resp = reqwest::Client::new().get(&url).send().await.unwrap();
let err = read_capped_bytes_inner(resp, 64)
.await
.expect_err("oversize body MUST be rejected by the cap");
let msg = err.to_string();
assert!(
msg.contains("exceeds cap") || msg.contains("exceeded cap"),
"rejection must name the cap: {msg}"
);
}
#[tokio::test(flavor = "multi_thread")]
async fn read_capped_json_parses_small_body_1459() {
use super::read_capped_json;
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/ok"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({"hello": "world"})))
.mount(&server)
.await;
let url = format!("{}/ok", server.uri());
let resp = reqwest::Client::new().get(&url).send().await.unwrap();
let v = read_capped_json(resp).await.unwrap();
assert_eq!(v["hello"], "world");
}
#[tokio::test(flavor = "multi_thread")]
async fn perf_12_new_with_url_no_health_check_skips_probe() {
let url = tokio::task::spawn_blocking(|| {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
drop(listener);
format!("http://127.0.0.1:{port}")
})
.await
.unwrap();
let (constructed_ok, is_available_after) = tokio::task::spawn_blocking(move || {
let client = OllamaClient::new_with_url_no_health_check(&url, "test-model")
.expect("PERF-12: new_with_url_no_health_check must not probe");
let avail = client.is_available();
(true, avail)
})
.await
.unwrap();
assert!(constructed_ok);
assert!(
!is_available_after,
"PERF-12: lazy is_available() must return false for an unreachable endpoint",
);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_is_available_returns_false_on_connection_refused() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
drop(listener);
let url = format!("http://127.0.0.1:{port}");
let result = tokio::task::spawn_blocking(move || {
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(5))
.build()
.unwrap();
let probe = format!("{url}/api/tags");
client
.get(&probe)
.send()
.is_ok_and(|r| r.status().is_success())
})
.await
.unwrap();
assert!(
!result,
"is_available should return false when nothing is listening"
);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_is_available_returns_false_on_500_response() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/api/tags"))
.respond_with(ResponseTemplate::new(500))
.mount(&server)
.await;
let uri = server.uri();
let result = tokio::task::spawn_blocking(move || {
OllamaClient::new_with_url(&uri, "test-model")
})
.await
.unwrap();
let err = match result {
Ok(_) => panic!("client construction should fail on 500"),
Err(e) => e.to_string(),
};
assert!(
err.contains("not running") || err.contains("not reachable"),
"expected unreachable-style error, got: {err}"
);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_is_available_returns_true_on_200_with_json_body() {
let server = MockServer::start().await;
mount_tags_ok(&server, json!({"models": []})).await;
let uri = server.uri();
let available = tokio::task::spawn_blocking(move || {
let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
client.is_available()
})
.await
.unwrap();
assert!(available);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_pull_if_missing_skips_pull_if_model_already_in_tags() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/api/tags"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"models": [
{"name": "test-model:latest"},
]
})))
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/api/pull"))
.respond_with(ResponseTemplate::new(200))
.expect(0)
.mount(&server)
.await;
let uri = server.uri();
let result = tokio::task::spawn_blocking(move || {
let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
client.ensure_model()
})
.await
.unwrap();
assert!(
result.is_ok(),
"ensure_model should succeed; got {result:?}"
);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_pull_if_missing_initiates_pull_if_not() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/api/tags"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/api/pull"))
.and(body_partial_json(json!({"name": "test-model"})))
.respond_with(ResponseTemplate::new(200).set_body_string(""))
.expect(1)
.mount(&server)
.await;
let uri = server.uri();
let result = tokio::task::spawn_blocking(move || {
let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
client.ensure_model()
})
.await
.unwrap();
assert!(
result.is_ok(),
"ensure_model should succeed; got {result:?}"
);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_generate_parses_success_response() {
let server = MockServer::start().await;
mount_tags_ok(&server, json!({"models": []})).await;
Mock::given(method("POST"))
.and(path("/api/chat"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"message": {"role": "assistant", "content": "hello"},
"done": true,
})))
.mount(&server)
.await;
let uri = server.uri();
let result = tokio::task::spawn_blocking(move || {
let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
client.generate("ping", None)
})
.await
.unwrap();
assert_eq!(result.unwrap(), "hello");
}
#[tokio::test(flavor = "multi_thread")]
async fn test_generate_returns_error_on_malformed_json() {
let server = MockServer::start().await;
mount_tags_ok(&server, json!({"models": []})).await;
Mock::given(method("POST"))
.and(path("/api/chat"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("{not valid json")
.insert_header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON),
)
.mount(&server)
.await;
let uri = server.uri();
let result = tokio::task::spawn_blocking(move || {
let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
client.generate("ping", None)
})
.await
.unwrap();
assert!(result.is_err(), "malformed JSON should surface an error");
let err = result.unwrap_err().to_string();
assert!(
err.contains("parse") || err.to_lowercase().contains("json"),
"expected a parse error, got: {err}"
);
}
/// A 500 from `/api/chat` must surface as an error that mentions the status
/// code or the "Chat generate failed" context.
#[tokio::test(flavor = "multi_thread")]
async fn test_generate_returns_error_on_500() {
    let server = MockServer::start().await;
    mount_tags_ok(&server, json!({"models": []})).await;
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .respond_with(ResponseTemplate::new(500).set_body_string("internal boom"))
        .mount(&server)
        .await;
    let uri = server.uri();
    let result = tokio::task::spawn_blocking(move || {
        let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
        client.generate("ping", None)
    })
    .await
    .unwrap();
    // `expect_err` prints the unexpected Ok value. The message below prints the
    // actual error text, matching the diagnostics sibling tests already give.
    let err = result
        .expect_err("generate must fail when /api/chat answers 500")
        .to_string();
    assert!(
        err.contains("500") || err.contains("Chat generate failed"),
        "expected an HTTP 500 error, got: {err}"
    );
}
/// With a system prompt, `generate` must send a leading `system` message, then
/// the `user` message, with `stream: false`. Only a request body that includes
/// these fields is matched by the mock. Anything else gets wiremock's 404, and
/// the final assertion fails.
#[tokio::test(flavor = "multi_thread")]
async fn test_generate_passes_system_prompt_when_provided() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock, json!({"models": []})).await;
    let expected_body = json!({
        "messages": [
            {"role": "system", "content": "be terse"},
            {"role": "user", "content": "hi"},
        ],
        "stream": false,
    });
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"role": "assistant", "content": "ok"},
    }));
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .and(body_partial_json(expected_body))
        .respond_with(reply)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let answer = tokio::task::spawn_blocking(move || {
        OllamaClient::new_with_url(&base_url, "test-model")
            .unwrap()
            .generate("hi", Some("be terse"))
    })
    .await
    .unwrap();
    assert_eq!(answer.unwrap(), "ok");
}
/// `embed_text` unpacks the first row of `/api/embed`'s `embeddings` array.
/// Values are compared with a small tolerance because they round-trip
/// through JSON.
#[tokio::test(flavor = "multi_thread")]
async fn test_embed_parses_embedding_array() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock, json!({"models": []})).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "embeddings": [[0.1_f32, 0.2_f32, 0.3_f32]],
    }));
    Mock::given(method("POST"))
        .and(path("/api/embed"))
        .respond_with(reply)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let embedding = tokio::task::spawn_blocking(move || {
        OllamaClient::new_with_url(&base_url, "test-model")
            .unwrap()
            .embed_text("hello", "nomic-embed-text-v1.5")
    })
    .await
    .unwrap();
    let v = embedding.unwrap();
    assert_eq!(v.len(), 3);
    assert!((v[0] - 0.1_f32).abs() < 1e-5);
    assert!((v[1] - 0.2_f32).abs() < 1e-5);
    assert!((v[2] - 0.3_f32).abs() < 1e-5);
}
/// Ollama-native `/api/embed` requests must carry `"truncate": true` together
/// with the model and input. The mock only answers a body that includes
/// those fields.
#[tokio::test(flavor = "multi_thread")]
async fn ollama_embed_payload_sets_truncate_1595() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock, json!({"models": []})).await;
    let required_fields = json!({
        "model": "nomic-embed-text",
        "input": "hello",
        "truncate": true,
    });
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "embeddings": [[0.5_f32, 0.25_f32]],
    }));
    Mock::given(method("POST"))
        .and(path("/api/embed"))
        .and(body_partial_json(required_fields))
        .respond_with(reply)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let embedding = tokio::task::spawn_blocking(move || {
        OllamaClient::new_with_url(&base_url, "test-model")
            .unwrap()
            .embed_text("hello", "nomic-embed-text")
    })
    .await
    .unwrap();
    assert_eq!(embedding.unwrap().len(), 2);
}
/// OpenAI-compatible `/embeddings` requests must not carry the Ollama-only
/// `truncate` key. The recorded request body is inspected directly.
#[tokio::test(flavor = "multi_thread")]
async fn openai_embed_payload_omits_truncate_1595() {
    let mock = MockServer::start().await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "data": [{"embedding": [0.5_f32, 0.25_f32]}],
    }));
    Mock::given(method("POST"))
        .and(path("/embeddings"))
        .respond_with(reply)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let embedding = tokio::task::spawn_blocking(move || {
        OllamaClient::new_openai_compatible(&base_url, "test-model", "fake-key")
            .unwrap()
            .embed_text("hello", "test-model")
    })
    .await
    .unwrap();
    assert_eq!(embedding.unwrap().len(), 2);
    let recorded = mock
        .received_requests()
        .await
        .expect("request recording enabled");
    let embed_req = recorded
        .iter()
        .find(|req| req.url.path() == "/embeddings")
        .expect("embed request recorded");
    let payload: serde_json::Value =
        serde_json::from_slice(&embed_req.body).expect("json body");
    assert!(
        payload.get("truncate").is_none(),
        "OpenAI-compatible embed payload must not carry the \
         Ollama-native truncate key, got: {payload}"
    );
}
/// A 200 from `/api/embed` without an `embeddings` array (here a scalar
/// `embedding`) must be rejected with an embeddings-related error.
#[tokio::test(flavor = "multi_thread")]
async fn test_embed_returns_error_on_wrong_shape() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock, json!({"models": []})).await;
    let wrong_shape = ResponseTemplate::new(200).set_body_json(json!({
        "embedding": 0.5,
    }));
    Mock::given(method("POST"))
        .and(path("/api/embed"))
        .respond_with(wrong_shape)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let outcome = tokio::task::spawn_blocking(move || {
        OllamaClient::new_with_url(&base_url, "test-model")
            .unwrap()
            .embed_text("hi", "nomic-embed-text")
    })
    .await
    .unwrap();
    assert!(outcome.is_err());
    let msg = outcome.unwrap_err().to_string();
    assert!(
        msg.contains("Missing embeddings") || msg.to_lowercase().contains("embed"),
        "expected missing-embeddings error, got: {msg}"
    );
}
/// A 500 from `/api/embed` must surface as an error that mentions the status.
#[tokio::test(flavor = "multi_thread")]
async fn test_embed_returns_error_on_500() {
    let server = MockServer::start().await;
    mount_tags_ok(&server, json!({"models": []})).await;
    Mock::given(method("POST"))
        .and(path("/api/embed"))
        .respond_with(ResponseTemplate::new(500).set_body_string("nope"))
        .mount(&server)
        .await;
    let uri = server.uri();
    let result = tokio::task::spawn_blocking(move || {
        let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
        client.embed_text("hi", "nomic-embed-text")
    })
    .await
    .unwrap();
    // Print the actual error on failure, consistent with sibling tests.
    let err = result
        .expect_err("embed_text must fail when /api/embed answers 500")
        .to_string();
    assert!(
        err.contains("500"),
        "expected the 500 status in the error, got: {err}"
    );
}
/// `expand_query` yields one term per reply line and drops trailing blank lines.
#[tokio::test(flavor = "multi_thread")]
async fn test_expand_query_returns_parsed_terms_one_per_line() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock, json!({"models": []})).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "term1\nterm2\nterm3\n\n"},
    }));
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .respond_with(reply)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let terms = tokio::task::spawn_blocking(move || {
        OllamaClient::new_with_url(&base_url, "test-model")
            .unwrap()
            .expand_query("anything")
    })
    .await
    .unwrap();
    let expected: Vec<String> = ["term1", "term2", "term3"]
        .iter()
        .map(|t| t.to_string())
        .collect();
    assert_eq!(terms.unwrap(), expected);
}
/// `auto_tag` yields one lower-cased tag per reply line.
#[tokio::test(flavor = "multi_thread")]
async fn test_auto_tag_returns_parsed_tags() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock, json!({"models": []})).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "Tag1\nTAG2\ntag3"},
    }));
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .respond_with(reply)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let tags = tokio::task::spawn_blocking(move || {
        OllamaClient::new_with_url(&base_url, "test-model")
            .unwrap()
            .auto_tag("Title", "content", None)
    })
    .await
    .unwrap();
    let expected: Vec<String> = ["tag1", "tag2", "tag3"]
        .iter()
        .map(|t| t.to_string())
        .collect();
    assert_eq!(tags.unwrap(), expected);
}
/// `detect_contradiction` maps a "yes" reply to `true`, "no" to `false`, and
/// any unrecognised reply to `false`. Each verdict is checked against its own
/// mock server.
#[tokio::test(flavor = "multi_thread")]
async fn test_detect_contradiction_parses_yes_no() {
    /// Spins up a server whose `/api/chat` replies with `reply`, then returns
    /// the client's verdict for the fixed statement pair ("a", "b").
    async fn verdict_for(reply: &str) -> bool {
        let mock = MockServer::start().await;
        mount_tags_ok(&mock, json!({"models": []})).await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "message": {"content": reply},
            })))
            .mount(&mock)
            .await;
        let base_url = mock.uri();
        tokio::task::spawn_blocking(move || {
            OllamaClient::new_with_url(&base_url, "test-model")
                .unwrap()
                .detect_contradiction("a", "b")
        })
        .await
        .unwrap()
        .unwrap()
    }

    assert!(
        verdict_for("yes\n").await,
        "'yes' should be detected as contradiction"
    );
    assert!(
        !verdict_for("no").await,
        "'no' should NOT be detected as contradiction"
    );
    assert!(
        !verdict_for("definitely-not-yes-or-no").await,
        "garbage answer should default to non-contradiction"
    );
}
/// When `/api/tags` already lists the embed model (as `name:latest`),
/// `ensure_embed_model` must not POST `/api/pull`. The `.expect(0)` is
/// verified by wiremock when the server is dropped.
#[tokio::test(flavor = "multi_thread")]
async fn test_ensure_embed_model_skips_pull_if_present() {
    let mock = MockServer::start().await;
    let tags_listing = json!({
        "models": [{"name": "nomic-embed-text:latest"}]
    });
    Mock::given(method("GET"))
        .and(path("/api/tags"))
        .respond_with(ResponseTemplate::new(200).set_body_json(tags_listing))
        .mount(&mock)
        .await;
    Mock::given(method("POST"))
        .and(path("/api/pull"))
        .respond_with(ResponseTemplate::new(200))
        .expect(0)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let outcome = tokio::task::spawn_blocking(move || {
        OllamaClient::new_with_url(&base_url, "test-model")
            .unwrap()
            .ensure_embed_model("nomic-embed-text")
    })
    .await
    .unwrap();
    assert!(outcome.is_ok());
}
/// An explicit model override passed to `auto_tag` is sent as the request
/// `model`, replacing the client's default (`gemma4:e2b`). The mock only
/// matches, and expects exactly one, request naming the override.
#[tokio::test(flavor = "multi_thread")]
async fn auto_tag_model_override_takes_precedence_l15() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock, json!({"models": []})).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "alpha\nbeta\ngamma"},
    }));
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .and(body_partial_json(json!({"model": "gemma3:4b"})))
        .respond_with(reply)
        .expect(1)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let outcome = tokio::task::spawn_blocking(move || {
        OllamaClient::new_with_url(&base_url, "gemma4:e2b")
            .unwrap()
            .auto_tag("Title", "content", Some("gemma3:4b"))
    })
    .await
    .unwrap();
    let tags = outcome.expect("auto_tag with override should succeed");
    let expected: Vec<String> = ["alpha", "beta", "gamma"]
        .iter()
        .map(|t| t.to_string())
        .collect();
    assert_eq!(tags, expected);
}
/// `auto_tag` goes through `POST /api/chat` exactly once and splits the
/// reply into tags by line.
#[tokio::test(flavor = "multi_thread")]
async fn auto_tag_chat_shape_post_1067() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock, json!({"models": []})).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "one\ntwo"},
    }));
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .respond_with(reply)
        .expect(1)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let outcome = tokio::task::spawn_blocking(move || {
        OllamaClient::new_with_url(&base_url, "any-model")
            .unwrap()
            .auto_tag("Title", "content", None)
    })
    .await
    .unwrap();
    let tags = outcome.expect("auto_tag should succeed");
    let expected: Vec<String> = ["one", "two"].iter().map(|t| t.to_string()).collect();
    assert_eq!(tags, expected);
}
/// Serializes tests that mutate the process-wide LLM environment variables.
pub(super) static ENV_GUARD_1143: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Acquires [`ENV_GUARD_1143`]. If a previous holder panicked, the poisoned
/// guard is recovered so one failing env test does not cascade into the rest.
pub(super) fn lock_env_1143() -> std::sync::MutexGuard<'static, ()> {
    match ENV_GUARD_1143.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Removes every environment variable that LLM backend selection in these
/// tests may consult, so each env test starts from a clean slate.
pub(super) fn clear_llm_env_1143() {
    const LLM_ENV_VARS: [&str; 10] = [
        "AI_MEMORY_LLM_BACKEND",
        "AI_MEMORY_LLM_MODEL",
        "AI_MEMORY_LLM_BASE_URL",
        "AI_MEMORY_LLM_API_KEY",
        "OLLAMA_BASE_URL",
        "XAI_API_KEY",
        "OPENAI_API_KEY",
        "ANTHROPIC_API_KEY",
        "GEMINI_API_KEY",
        "GOOGLE_API_KEY",
    ];
    for var in LLM_ENV_VARS {
        // SAFETY: the env tests in this module call this while holding
        // `ENV_GUARD_1143` (via `lock_env_1143`), which serializes them with
        // each other. NOTE(review): tests that read the environment without
        // taking that guard are not protected; confirm none run concurrently.
        unsafe { std::env::remove_var(var) };
    }
}
/// A client built for the Ollama provider reports itself as Ollama-native.
#[test]
fn is_ollama_native_true_for_ollama_client_1143() {
    let native = OllamaClient::new_for_testing("gemma4:e4b").is_ollama_native();
    assert!(
        native,
        "#1143: Ollama-provider client must report is_ollama_native()=true"
    );
}
/// A client built through the OpenAI-compatible constructor is not Ollama-native.
#[test]
fn is_ollama_native_false_for_openai_compatible_1143() {
    let native =
        OllamaClient::new_openai_compatible("https://api.x.ai/v1", "grok-4.3", "fake-key")
            .expect("openai-compatible client builds")
            .is_ollama_native();
    assert!(
        !native,
        "#1143: OpenAI-compatible client must report is_ollama_native()=false"
    );
}
/// With every LLM env var cleared, `build_for_init` takes the legacy arm and
/// returns an Ollama-native client for the supplied URL and model.
#[tokio::test(flavor = "multi_thread")]
async fn build_for_init_legacy_arm_when_env_unset_1143() {
    // Named binding (not `_`) so the guard is held for the whole test.
    let _env_lock = lock_env_1143();
    clear_llm_env_1143();
    let mock = MockServer::start().await;
    mount_tags_ok(&mock, json!({"models": []})).await;
    let base_url = mock.uri();
    let built = tokio::task::spawn_blocking(move || {
        OllamaClient::build_for_init(&base_url, "gemma4:e4b")
    })
    .await
    .unwrap();
    let client = match built {
        Err(e) => panic!("#1143: legacy arm must yield Ok(Some(client)); got Err({e})"),
        Ok(None) => panic!("#1143: legacy arm must yield Ok(Some(client)); got Ok(None)"),
        Ok(Some(c)) => c,
    };
    assert!(
        client.is_ollama_native(),
        "#1143: legacy arm constructs an Ollama-provider client"
    );
    assert_eq!(client.model, "gemma4:e4b");
}
/// With `AI_MEMORY_LLM_BACKEND=xai` set, `build_for_init` ignores its legacy
/// URL/model arguments and builds an OpenAI-compatible client from the env:
/// the model comes from `AI_MEMORY_LLM_MODEL` and the base URL is xai's default.
#[tokio::test(flavor = "multi_thread")]
async fn build_for_init_env_arm_routes_to_from_env_1143() {
    let _env_lock = lock_env_1143();
    clear_llm_env_1143();
    let overrides = [
        ("AI_MEMORY_LLM_BACKEND", "xai"),
        ("AI_MEMORY_LLM_API_KEY", "fake-xai-key"),
        ("AI_MEMORY_LLM_MODEL", "grok-4.3"),
    ];
    for (key, value) in overrides {
        // SAFETY: env mutation is serialized against the other env tests by
        // `_env_lock` (ENV_GUARD_1143).
        unsafe { std::env::set_var(key, value) };
    }
    let built = tokio::task::spawn_blocking(|| {
        OllamaClient::build_for_init("http://127.0.0.1:1", "ignored-legacy-model")
    })
    .await
    .unwrap();
    clear_llm_env_1143();
    let client = match built {
        Err(e) => panic!(
            "#1143: env arm with AI_MEMORY_LLM_BACKEND=xai must yield \
             Ok(Some(client)); got Err({e})"
        ),
        Ok(None) => panic!(
            "#1143: env arm with AI_MEMORY_LLM_BACKEND=xai must yield \
             Ok(Some(client)); got Ok(None)"
        ),
        Ok(Some(c)) => c,
    };
    assert!(
        !client.is_ollama_native(),
        "#1143: xai backend yields an OpenAI-compatible (non-Ollama) client"
    );
    assert_eq!(
        client.model, "grok-4.3",
        "#1143: AI_MEMORY_LLM_MODEL must override the legacy model arg"
    );
    assert_eq!(
        client.base_url, "https://api.x.ai/v1",
        "#1143: xai default base URL must override the legacy URL arg"
    );
}
/// An unrecognised `AI_MEMORY_LLM_BACKEND` value must produce an error rather
/// than silently falling back to the legacy Ollama arm.
#[tokio::test(flavor = "multi_thread")]
async fn build_for_init_env_arm_unknown_alias_errors_1143() {
    let _env_lock = lock_env_1143();
    clear_llm_env_1143();
    // SAFETY: env mutation is serialized by `_env_lock` (ENV_GUARD_1143).
    unsafe { std::env::set_var("AI_MEMORY_LLM_BACKEND", "totally-bogus-vendor") };
    let built = tokio::task::spawn_blocking(|| {
        OllamaClient::build_for_init("http://127.0.0.1:1", "ignored")
    })
    .await
    .unwrap();
    clear_llm_env_1143();
    let surfaced_error = built.is_err();
    assert!(
        surfaced_error,
        "#1143: unknown backend alias must surface the error \
         instead of silently falling through to the legacy arm"
    );
}
/// A whitespace-only `AI_MEMORY_LLM_BACKEND` is treated as unset: the legacy
/// Ollama arm is used with the supplied URL and model.
#[tokio::test(flavor = "multi_thread")]
async fn build_for_init_env_arm_empty_string_falls_back_to_legacy_1143() {
    let _env_lock = lock_env_1143();
    clear_llm_env_1143();
    // SAFETY: env mutation is serialized by `_env_lock` (ENV_GUARD_1143).
    unsafe { std::env::set_var("AI_MEMORY_LLM_BACKEND", " ") };
    let mock = MockServer::start().await;
    mount_tags_ok(&mock, json!({"models": []})).await;
    let base_url = mock.uri();
    let built = tokio::task::spawn_blocking(move || {
        OllamaClient::build_for_init(&base_url, "gemma4:e2b")
    })
    .await
    .unwrap();
    clear_llm_env_1143();
    let maybe_client = built.expect("legacy arm should not error on whitespace env");
    let client = maybe_client.expect("legacy arm yields Some(client)");
    assert!(client.is_ollama_native());
    assert_eq!(client.model, "gemma4:e2b");
}
}
#[cfg(test)]
// NOTE(review): no fn in this module looks near clippy's line limit; this
// allow may be removable.
#[allow(clippy::too_many_lines)]
mod c5_breaker_tests {
    //! Circuit-breaker behaviour of the blocking `OllamaClient` API: after
    //! `CIRCUIT_BREAKER_THRESHOLD` consecutive 5xx responses the breaker opens
    //! and further calls fail fast with a "circuit breaker open" error.
    use super::OllamaClient;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Mounts a `GET /api/tags` stub answering 200 with an empty model list.
    async fn mount_tags_ok(server: &MockServer) {
        let empty_tags = json!({"models": []});
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(200).set_body_json(empty_tags))
            .mount(server)
            .await;
    }

    /// Mounts a `POST /api/chat` stub that always answers with `template`.
    async fn mount_chat(server: &MockServer, template: ResponseTemplate) {
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(template)
            .mount(server)
            .await;
    }

    /// Issues exactly `CIRCUIT_BREAKER_THRESHOLD` `generate` calls and discards
    /// their results.
    fn fire_threshold_generates(client: &OllamaClient) {
        for _ in 0..super::CIRCUIT_BREAKER_THRESHOLD {
            let _ = client.generate("ping", None);
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn generate_fast_fails_after_breaker_trips() {
        let mock = MockServer::start().await;
        mount_tags_ok(&mock).await;
        mount_chat(
            &mock,
            ResponseTemplate::new(500).set_body_string("upstream sick"),
        )
        .await;
        let base_url = mock.uri();
        let outcome = tokio::task::spawn_blocking(move || {
            let client = OllamaClient::new_with_url(&base_url, "test-model").unwrap();
            assert!(
                !client.circuit_breaker_open(),
                "breaker open before any failure"
            );
            fire_threshold_generates(&client);
            assert!(
                client.circuit_breaker_open(),
                "breaker should be open after {} consecutive 5xx",
                super::CIRCUIT_BREAKER_THRESHOLD
            );
            client
                .generate("ping", None)
                .expect_err("breaker-open path must Err")
                .to_string()
        })
        .await
        .unwrap();
        assert!(
            outcome.contains("circuit breaker open"),
            "expected breaker-open envelope, got: {outcome}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn embed_text_fast_fails_after_breaker_trips() {
        let mock = MockServer::start().await;
        mount_tags_ok(&mock).await;
        mount_chat(&mock, ResponseTemplate::new(500)).await;
        let base_url = mock.uri();
        let outcome = tokio::task::spawn_blocking(move || {
            let client = OllamaClient::new_with_url(&base_url, "test-model").unwrap();
            // The breaker is tripped via chat; embed must then fail fast too.
            fire_threshold_generates(&client);
            assert!(client.circuit_breaker_open());
            client
                .embed_text("hello", "nomic-embed-text")
                .expect_err("embed_text must fast-fail when breaker open")
                .to_string()
        })
        .await
        .unwrap();
        assert!(
            outcome.contains("circuit breaker open"),
            "expected breaker-open envelope on embed_text, got: {outcome}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn circuit_breaker_open_starts_closed() {
        let mock = MockServer::start().await;
        mount_tags_ok(&mock).await;
        let base_url = mock.uri();
        let closed = tokio::task::spawn_blocking(move || {
            OllamaClient::new_with_url(&base_url, "test-model")
                .unwrap()
                .circuit_breaker_open()
        })
        .await
        .unwrap();
        assert!(
            !closed,
            "freshly-constructed client must have closed breaker"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn breaker_stays_closed_under_threshold() {
        let mock = MockServer::start().await;
        mount_tags_ok(&mock).await;
        mount_chat(&mock, ResponseTemplate::new(500)).await;
        let base_url = mock.uri();
        let still_closed = tokio::task::spawn_blocking(move || {
            let client = OllamaClient::new_with_url(&base_url, "test-model").unwrap();
            // One failure short of the threshold.
            for _ in 0..(super::CIRCUIT_BREAKER_THRESHOLD - 1) {
                let _ = client.generate("ping", None);
            }
            client.circuit_breaker_open()
        })
        .await
        .unwrap();
        assert!(
            !still_closed,
            "breaker must stay closed strictly below the threshold"
        );
    }
}
#[cfg(test)]
#[allow(clippy::too_many_lines, clippy::similar_names)]
mod perf9_async_tests {
use super::OllamaClient;
use serde_json::json;
use std::net::TcpListener;
use wiremock::matchers::{body_partial_json, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Mounts a `GET /api/tags` stub answering 200 with an empty model list.
async fn mount_tags_ok(server: &MockServer) {
    let empty_tags = json!({"models": []});
    let healthy = ResponseTemplate::new(200).set_body_json(empty_tags);
    Mock::given(method("GET"))
        .and(path("/api/tags"))
        .respond_with(healthy)
        .mount(server)
        .await;
}
/// `new_with_url_async` succeeds when `GET /api/tags` answers 200, yielding
/// an Ollama-native client.
#[tokio::test(flavor = "multi_thread")]
async fn new_with_url_async_succeeds_against_healthy_endpoint() {
    let healthy = MockServer::start().await;
    mount_tags_ok(&healthy).await;
    let base_url = healthy.uri();
    let client = OllamaClient::new_with_url_async(&base_url, "test-model")
        .await
        .expect("constructor succeeds against healthy /api/tags");
    assert!(client.is_ollama_native());
}
/// `new_with_url_async` must refuse an endpoint whose `GET /api/tags` answers
/// 500, with an unreachable-style error.
#[tokio::test(flavor = "multi_thread")]
async fn new_with_url_async_errors_when_endpoint_500s() {
    let failing = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/api/tags"))
        .respond_with(ResponseTemplate::new(500))
        .mount(&failing)
        .await;
    // `let ... else` avoids needing `OllamaClient: Debug` for `expect_err`.
    let Err(e) = OllamaClient::new_with_url_async(&failing.uri(), "test-model").await else {
        panic!("constructor must fail on 500");
    };
    let msg = e.to_string();
    assert!(
        msg.contains("not running") || msg.contains("not reachable"),
        "expected unreachable-style error, got: {msg}"
    );
}
/// A connection refused during construction must surface as an
/// unreachable-style error.
#[tokio::test(flavor = "multi_thread")]
async fn new_with_url_async_errors_when_nothing_listening() {
    // Reserve an ephemeral port, then release it. Barring a race, nothing is
    // listening there afterwards.
    let port = {
        let probe = TcpListener::bind("127.0.0.1:0").unwrap();
        probe.local_addr().unwrap().port()
    };
    let url = format!("http://127.0.0.1:{port}");
    let Err(e) = OllamaClient::new_with_url_async(&url, "test-model").await else {
        panic!("connect-refused must surface an error");
    };
    let msg = e.to_string();
    assert!(msg.contains("not running") || msg.contains("not reachable"));
}
/// A healthy Ollama endpoint reports as available.
#[tokio::test(flavor = "multi_thread")]
async fn is_available_async_true_on_200() {
    let healthy = MockServer::start().await;
    mount_tags_ok(&healthy).await;
    let base_url = healthy.uri();
    let client = OllamaClient::new_with_url_async(&base_url, "test-model")
        .await
        .unwrap();
    let available = client.is_available_async().await;
    assert!(available);
}
/// A client that was constructed successfully against a healthy endpoint must
/// report unavailable once its endpoint answers `GET /api/tags` with 500.
///
/// The original body built this client, then discarded it (`let _ = client;`
/// is a no-op) and probed a separate `new_for_testing` client instead. This
/// version re-points the constructed client itself, which is the scenario the
/// test name describes.
#[tokio::test(flavor = "multi_thread")]
async fn is_available_async_false_on_500_after_construction() {
    let healthy = MockServer::start().await;
    mount_tags_ok(&healthy).await;
    let mut client = OllamaClient::new_with_url_async(&healthy.uri(), "test-model")
        .await
        .unwrap();
    drop(healthy);
    let server500 = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/api/tags"))
        .respond_with(ResponseTemplate::new(500))
        .mount(&server500)
        .await;
    // wiremock cannot change an existing server's responses, so swap the
    // client's base URL over to the 500-ing server instead.
    client.base_url = server500.uri().trim_end_matches('/').to_string();
    assert!(
        !client.is_available_async().await,
        "a client must report unavailable once /api/tags answers 500"
    );
}
/// A base URL with nothing listening reports as unavailable rather than
/// erroring.
#[tokio::test(flavor = "multi_thread")]
async fn is_available_async_false_on_network_error() {
    // Reserve an ephemeral port, then release it. Barring a race, nothing is
    // listening there afterwards.
    let port = {
        let probe = TcpListener::bind("127.0.0.1:0").unwrap();
        probe.local_addr().unwrap().port()
    };
    let mut client = OllamaClient::new_for_testing("test-model");
    client.base_url = format!("http://127.0.0.1:{port}");
    let available = client.is_available_async().await;
    assert!(!available);
}
/// For an OpenAI-compatible client, availability is probed via `GET /models`.
/// Only that route is mocked here, and it answers 200.
#[tokio::test(flavor = "multi_thread")]
async fn is_available_async_openai_compatible_path_hits_models() {
    let mock = MockServer::start().await;
    let models_ok = ResponseTemplate::new(200).set_body_json(json!({"data": []}));
    Mock::given(method("GET"))
        .and(path("/models"))
        .respond_with(models_ok)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client = OllamaClient::new_openai_compatible(&base_url, "test-model", "fake-key")
        .expect("OpenAI-compat client builds");
    assert!(client.is_available_async().await);
}
/// A 401 from `GET /models` (e.g. a rejected API key) means unavailable.
#[tokio::test(flavor = "multi_thread")]
async fn is_available_async_openai_compatible_false_on_401() {
    let mock = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/models"))
        .respond_with(ResponseTemplate::new(401))
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client =
        OllamaClient::new_openai_compatible(&base_url, "test-model", "fake-key").unwrap();
    let available = client.is_available_async().await;
    assert!(!available);
}
/// On an OpenAI-compatible backend, `ensure_model_async` is a no-op: it must
/// succeed without needing the base URL to be reachable.
///
/// The original started a `MockServer` and dropped it immediately without
/// using it; that dead setup is removed.
#[tokio::test(flavor = "multi_thread")]
async fn ensure_model_async_noop_on_openai_compatible() {
    // Port 1 on loopback is expected to have no listener. Any network call
    // would then error and trip the `expect` below.
    let client =
        OllamaClient::new_openai_compatible("http://127.0.0.1:1", "any-model", "fake-key")
            .unwrap();
    client
        .ensure_model_async()
        .await
        .expect("OpenAI-compatible ensure_model_async is a no-op");
}
/// If `/api/tags` already lists the model (as `name:latest`), no
/// `/api/pull` may be issued. The `.expect(0)` is verified when the mock
/// server drops.
#[tokio::test(flavor = "multi_thread")]
async fn ensure_model_async_skips_pull_when_model_present() {
    let mock = MockServer::start().await;
    let tags_listing = json!({
        "models": [{"name": "test-model:latest"}]
    });
    Mock::given(method("GET"))
        .and(path("/api/tags"))
        .respond_with(ResponseTemplate::new(200).set_body_json(tags_listing))
        .mount(&mock)
        .await;
    Mock::given(method("POST"))
        .and(path("/api/pull"))
        .respond_with(ResponseTemplate::new(200))
        .expect(0)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client = OllamaClient::new_with_url_async(&base_url, "test-model")
        .await
        .unwrap();
    client.ensure_model_async().await.expect("no pull needed");
}
/// If `/api/tags` does not list the model, exactly one `/api/pull` naming it
/// must be issued.
#[tokio::test(flavor = "multi_thread")]
async fn ensure_model_async_pulls_when_missing() {
    let mock = MockServer::start().await;
    let no_models = ResponseTemplate::new(200).set_body_json(json!({"models": []}));
    Mock::given(method("GET"))
        .and(path("/api/tags"))
        .respond_with(no_models)
        .mount(&mock)
        .await;
    Mock::given(method("POST"))
        .and(path("/api/pull"))
        .and(body_partial_json(json!({"name": "test-model"})))
        .respond_with(ResponseTemplate::new(200).set_body_string(""))
        .expect(1)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client = OllamaClient::new_with_url_async(&base_url, "test-model")
        .await
        .unwrap();
    client.ensure_model_async().await.expect("pull succeeds");
}
/// A 500 from `/api/pull` must be reported as "Ollama pull failed".
#[tokio::test(flavor = "multi_thread")]
async fn ensure_model_async_surfaces_pull_failure() {
    let mock = MockServer::start().await;
    let no_models = ResponseTemplate::new(200).set_body_json(json!({"models": []}));
    Mock::given(method("GET"))
        .and(path("/api/tags"))
        .respond_with(no_models)
        .mount(&mock)
        .await;
    Mock::given(method("POST"))
        .and(path("/api/pull"))
        .respond_with(ResponseTemplate::new(500).set_body_string("upstream sick"))
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client = OllamaClient::new_with_url_async(&base_url, "test-model")
        .await
        .unwrap();
    let msg = client
        .ensure_model_async()
        .await
        .expect_err("500 on pull must surface")
        .to_string();
    assert!(msg.contains("Ollama pull failed"));
}
/// Malformed JSON from `GET /api/tags` must surface as a parse/JSON error.
/// The client comes from `new_for_testing` and is re-pointed at the mock,
/// rather than from `new_with_url_async`, which itself probes `/api/tags`.
#[tokio::test(flavor = "multi_thread")]
async fn ensure_model_async_errors_on_malformed_tags_response() {
    let mock = MockServer::start().await;
    let broken_tags = ResponseTemplate::new(200)
        .set_body_string("{not json")
        .insert_header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON);
    Mock::given(method("GET"))
        .and(path("/api/tags"))
        .respond_with(broken_tags)
        .mount(&mock)
        .await;
    let mut client = OllamaClient::new_for_testing("test-model");
    client.base_url = mock.uri().trim_end_matches('/').to_string();
    let msg = client
        .ensure_model_async()
        .await
        .expect_err("malformed tags must surface")
        .to_string();
    assert!(msg.contains("parse") || msg.to_lowercase().contains("json"));
}
/// `generate_async` returns the assistant `message.content` from `/api/chat`.
#[tokio::test(flavor = "multi_thread")]
async fn generate_async_happy_path() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"role": "assistant", "content": "hello world"},
    }));
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .respond_with(reply)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client = OllamaClient::new_with_url_async(&base_url, "test-model")
        .await
        .unwrap();
    let text = client.generate_async("ping", None).await.unwrap();
    assert_eq!(text, "hello world");
}
/// With a system prompt, `generate_async` sends `system` then `user`
/// messages. Requests whose body does not include them are not matched and
/// get wiremock's 404.
#[tokio::test(flavor = "multi_thread")]
async fn generate_async_with_system_prompt() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let expected_body = json!({
        "messages": [
            {"role": "system", "content": "be terse"},
            {"role": "user", "content": "hi"},
        ],
    });
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .and(body_partial_json(expected_body))
        .respond_with(ResponseTemplate::new(200).set_body_json(json!({
            "message": {"content": "ok"},
        })))
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client = OllamaClient::new_with_url_async(&base_url, "test-model")
        .await
        .unwrap();
    let text = client.generate_async("hi", Some("be terse")).await.unwrap();
    assert_eq!(text, "ok");
}
/// A 500 from `/api/chat` must surface from `generate_async` as an error that
/// mentions the status or the "Chat generate failed" context.
#[tokio::test(flavor = "multi_thread")]
async fn generate_async_returns_error_on_500() {
    let server = MockServer::start().await;
    mount_tags_ok(&server).await;
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .respond_with(ResponseTemplate::new(500).set_body_string("upstream sick"))
        .mount(&server)
        .await;
    let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
        .await
        .unwrap();
    // Print the actual error text on failure instead of a bare assertion.
    let err = client
        .generate_async("ping", None)
        .await
        .expect_err("generate_async must fail when /api/chat answers 500")
        .to_string();
    assert!(
        err.contains("500") || err.contains("Chat generate failed"),
        "expected an HTTP 500 error, got: {err}"
    );
}
/// A 400 from `/api/chat` is a client error. Every call must return `Err`, but
/// 4xx responses must not count toward the circuit breaker, even one call past
/// the trip threshold.
///
/// The original body never checked the returned `Result`, so the "returns
/// error" half of the name went untested.
#[tokio::test(flavor = "multi_thread")]
async fn generate_async_returns_error_on_400() {
    let server = MockServer::start().await;
    mount_tags_ok(&server).await;
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .respond_with(ResponseTemplate::new(400).set_body_string("bad request"))
        .mount(&server)
        .await;
    let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
        .await
        .unwrap();
    for attempt in 0..(super::CIRCUIT_BREAKER_THRESHOLD + 1) {
        let outcome = client.generate_async("ping", None).await;
        assert!(
            outcome.is_err(),
            "400 must surface as an error (attempt {attempt}), got: {outcome:?}"
        );
    }
    assert!(
        !client.circuit_breaker_open(),
        "4xx must not trip the circuit breaker"
    );
}
/// A 200 from `/api/chat` whose body is not valid JSON must surface as a
/// parse/JSON error from `generate_async`.
#[tokio::test(flavor = "multi_thread")]
async fn generate_async_returns_error_on_malformed_json() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let broken_reply = ResponseTemplate::new(200)
        .set_body_string("{not valid json")
        .insert_header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON);
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .respond_with(broken_reply)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client = OllamaClient::new_with_url_async(&base_url, "test-model")
        .await
        .unwrap();
    let msg = client
        .generate_async("ping", None)
        .await
        .unwrap_err()
        .to_string();
    assert!(msg.contains("parse") || msg.to_lowercase().contains("json"));
}
/// A well-formed `/api/chat` reply without `message.content` must be rejected.
#[tokio::test(flavor = "multi_thread")]
async fn generate_async_errors_when_message_content_missing() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let contentless = ResponseTemplate::new(200).set_body_json(json!({"done": true}));
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .respond_with(contentless)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client = OllamaClient::new_with_url_async(&base_url, "test-model")
        .await
        .unwrap();
    let msg = client
        .generate_async("ping", None)
        .await
        .unwrap_err()
        .to_string();
    assert!(msg.contains("Missing 'message.content'"));
}
/// After `CIRCUIT_BREAKER_THRESHOLD` consecutive 500s, the breaker is open
/// and the next `generate_async` call errors with "circuit breaker open".
#[tokio::test(flavor = "multi_thread")]
async fn generate_async_breaker_open_short_circuits() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .respond_with(ResponseTemplate::new(500))
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client = OllamaClient::new_with_url_async(&base_url, "test-model")
        .await
        .unwrap();
    let mut remaining = super::CIRCUIT_BREAKER_THRESHOLD;
    while remaining > 0 {
        let _ = client.generate_async("x", None).await;
        remaining -= 1;
    }
    assert!(client.circuit_breaker_open(), "breaker should be tripped");
    let msg = client
        .generate_async("y", None)
        .await
        .expect_err("breaker-open path Errs")
        .to_string();
    assert!(msg.contains("circuit breaker open"));
}
/// The OpenAI-compatible path reads `choices[0].message.content` from
/// `/chat/completions`.
#[tokio::test(flavor = "multi_thread")]
async fn generate_async_openai_compatible_happy_path() {
    let mock = MockServer::start().await;
    let completion = ResponseTemplate::new(200).set_body_json(json!({
        "choices": [{"message": {"role": "assistant", "content": "hi from openai"}}]
    }));
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(completion)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client =
        OllamaClient::new_openai_compatible(&base_url, "test-model", "fake-key").unwrap();
    let text = client.generate_async("ping", None).await.unwrap();
    assert_eq!(text, "hi from openai");
}
/// An OpenAI-compatible reply without `choices` must be rejected with a
/// message naming the missing path.
#[tokio::test(flavor = "multi_thread")]
async fn generate_async_openai_compatible_missing_choices() {
    let mock = MockServer::start().await;
    let wrong_shape = ResponseTemplate::new(200).set_body_json(json!({"data": "wrong shape"}));
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(wrong_shape)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client =
        OllamaClient::new_openai_compatible(&base_url, "test-model", "fake-key").unwrap();
    let msg = client
        .generate_async("ping", None)
        .await
        .unwrap_err()
        .to_string();
    assert!(msg.contains("Missing 'choices[0].message.content'"));
}
/// `embed_text_async` returns the first row of `/api/embed`'s `embeddings`.
#[tokio::test(flavor = "multi_thread")]
async fn embed_text_async_happy_path() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "embeddings": [[0.1_f32, 0.2_f32, 0.3_f32]],
    }));
    Mock::given(method("POST"))
        .and(path("/api/embed"))
        .respond_with(reply)
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client = OllamaClient::new_with_url_async(&base_url, "test-model")
        .await
        .unwrap();
    let embedding = client
        .embed_text_async("hello", "nomic-embed-text")
        .await
        .unwrap();
    assert_eq!(embedding.len(), 3);
}
/// `CIRCUIT_BREAKER_THRESHOLD` consecutive 500s from `/api/embed` must open
/// the circuit breaker.
#[tokio::test(flavor = "multi_thread")]
async fn embed_text_async_500_trips_breaker_after_threshold() {
    let server = MockServer::start().await;
    mount_tags_ok(&server).await;
    Mock::given(method("POST"))
        .and(path("/api/embed"))
        .respond_with(ResponseTemplate::new(500))
        .mount(&server)
        .await;
    let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
        .await
        .unwrap();
    for _ in 0..super::CIRCUIT_BREAKER_THRESHOLD {
        let _ = client.embed_text_async("hello", "m").await;
    }
    // Take the count from the constant (as the sync breaker tests do) so the
    // message stays accurate if the threshold is retuned.
    assert!(
        client.circuit_breaker_open(),
        "{}× 5xx must trip the breaker on embed_text_async",
        super::CIRCUIT_BREAKER_THRESHOLD
    );
}
/// 400s from `/api/embed` must not count toward the breaker, even one call
/// past the threshold.
#[tokio::test(flavor = "multi_thread")]
async fn embed_text_async_400_does_not_trip_breaker() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    Mock::given(method("POST"))
        .and(path("/api/embed"))
        .respond_with(ResponseTemplate::new(400))
        .mount(&mock)
        .await;
    let base_url = mock.uri();
    let client = OllamaClient::new_with_url_async(&base_url, "test-model")
        .await
        .unwrap();
    let attempts = super::CIRCUIT_BREAKER_THRESHOLD + 1;
    for _ in 0..attempts {
        let _ = client.embed_text_async("hello", "m").await;
    }
    let tripped = client.circuit_breaker_open();
    assert!(!tripped);
}
/// A 200 reply whose only embedding is empty is rejected with an
/// "Empty embedding" error. It is not returned as a zero-length vector.
#[tokio::test(flavor = "multi_thread")]
async fn embed_text_async_empty_vec_errors() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let empty_reply = ResponseTemplate::new(200).set_body_json(json!({"embeddings": [[]]}));
    Mock::given(path("/api/embed"))
        .and(method("POST"))
        .respond_with(empty_reply)
        .mount(&mock)
        .await;
    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();

    let outcome = llm.embed_text_async("hello", "m").await;
    let message = outcome.expect_err("empty vector must error").to_string();
    assert!(message.contains("Empty embedding"));
}
/// A 200 reply from `/api/embed` whose body is not valid JSON (even though it
/// carries a JSON content type) must surface as a parse/JSON error.
#[tokio::test(flavor = "multi_thread")]
async fn embed_text_async_malformed_json_errors() {
    let server = MockServer::start().await;
    mount_tags_ok(&server).await;
    Mock::given(method("POST"))
        .and(path("/api/embed"))
        .respond_with(
            ResponseTemplate::new(200)
                .set_body_string("{bad json")
                .insert_header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON),
        )
        .mount(&server)
        .await;
    let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
        .await
        .unwrap();
    let err = client.embed_text_async("hi", "m").await.unwrap_err();
    // For `anyhow::Error`, the alternate form `{:#}` renders the whole context
    // chain. The check then holds whether the parse failure is the outermost
    // message or an underlying cause. Rendering once also lets the failure
    // message show what was actually returned.
    let rendered = format!("{err:#}");
    let lowered = rendered.to_lowercase();
    assert!(
        lowered.contains("parse") || lowered.contains("json"),
        "expected a JSON parse error, got: {rendered}"
    );
}
/// OpenAI-compatible backends embed via `POST /embeddings`. The vector in
/// `data[0].embedding` is handed back to the caller.
#[tokio::test(flavor = "multi_thread")]
async fn embed_text_async_openai_compatible_happy_path() {
    let mock = MockServer::start().await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "data": [{"embedding": [0.5_f32, 0.6_f32]}]
    }));
    Mock::given(path("/embeddings"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm =
        OllamaClient::new_openai_compatible(&mock.uri(), "test-model", "fake-key").unwrap();
    let embedding = llm
        .embed_text_async("hello", "nomic-embed-text")
        .await
        .unwrap();
    assert_eq!(embedding.len(), 2);
}
/// An OpenAI-compatible reply with no `data` array is reported as a missing
/// `data[0].embedding` field.
#[tokio::test(flavor = "multi_thread")]
async fn embed_text_async_openai_compatible_missing_data_errors() {
    let mock = MockServer::start().await;
    Mock::given(path("/embeddings"))
        .and(method("POST"))
        .respond_with(ResponseTemplate::new(200).set_body_json(json!({})))
        .mount(&mock)
        .await;

    let llm =
        OllamaClient::new_openai_compatible(&mock.uri(), "test-model", "fake-key").unwrap();
    let message = llm
        .embed_text_async("hi", "m")
        .await
        .unwrap_err()
        .to_string();
    assert!(message.contains("Missing 'data[0].embedding'"));
}
/// Once a run of 5xx failures has tripped the breaker, the next embed call
/// fails with a "circuit breaker open" error.
#[tokio::test(flavor = "multi_thread")]
async fn embed_text_async_breaker_open_short_circuits() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    Mock::given(path("/api/embed"))
        .and(method("POST"))
        .respond_with(ResponseTemplate::new(500))
        .mount(&mock)
        .await;
    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();

    // Trip the breaker.
    for _round in 0..super::CIRCUIT_BREAKER_THRESHOLD {
        let _ = llm.embed_text_async("x", "m").await;
    }
    let rejected = llm.embed_text_async("y", "m").await.unwrap_err();
    assert!(rejected.to_string().contains("circuit breaker open"));
}
/// For OpenAI-compatible backends, `ensure_embed_model_async` succeeds without
/// needing a reachable server.
#[tokio::test(flavor = "multi_thread")]
async fn ensure_embed_model_async_noop_on_openai_compatible() {
    // NOTE(review): port 1 is presumably never listening, so any network call
    // made here would fail the test.
    let dead_endpoint = "http://127.0.0.1:1";
    let llm = OllamaClient::new_openai_compatible(dead_endpoint, "any-model", "fake-key")
        .unwrap();
    llm.ensure_embed_model_async("any").await.expect("no-op");
}
/// If `/api/tags` already lists the embed model (as `<name>:latest`), no
/// `POST /api/pull` is issued.
#[tokio::test(flavor = "multi_thread")]
async fn ensure_embed_model_async_skips_when_present() {
    let mock = MockServer::start().await;
    let installed = ResponseTemplate::new(200).set_body_json(json!({
        "models": [{"name": "nomic-embed-text:latest"}]
    }));
    Mock::given(path("/api/tags"))
        .and(method("GET"))
        .respond_with(installed)
        .mount(&mock)
        .await;
    // wiremock checks `.expect(..)` counts when the server is dropped, so any
    // pull request fails the test.
    Mock::given(path("/api/pull"))
        .and(method("POST"))
        .respond_with(ResponseTemplate::new(200))
        .expect(0)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    llm.ensure_embed_model_async("nomic-embed-text")
        .await
        .unwrap();
}
/// An empty model list triggers exactly one `POST /api/pull` whose body names
/// the requested embed model.
#[tokio::test(flavor = "multi_thread")]
async fn ensure_embed_model_async_pulls_when_missing() {
    let mock = MockServer::start().await;
    Mock::given(path("/api/tags"))
        .and(method("GET"))
        .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
        .mount(&mock)
        .await;
    let expected_pull = json!({"name": "nomic-embed-text"});
    Mock::given(path("/api/pull"))
        .and(method("POST"))
        .and(body_partial_json(expected_pull))
        .respond_with(ResponseTemplate::new(200))
        .expect(1)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    llm.ensure_embed_model_async("nomic-embed-text")
        .await
        .unwrap();
}
/// A failing pull (HTTP 500) is reported as
/// "Ollama embed model pull failed". It is not silently ignored.
#[tokio::test(flavor = "multi_thread")]
async fn ensure_embed_model_async_pull_failure_surfaces() {
    let mock = MockServer::start().await;
    Mock::given(path("/api/tags"))
        .and(method("GET"))
        .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
        .mount(&mock)
        .await;
    Mock::given(path("/api/pull"))
        .and(method("POST"))
        .respond_with(ResponseTemplate::new(500))
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let failure = llm
        .ensure_embed_model_async("nomic-embed-text")
        .await
        .unwrap_err();
    let message = failure.to_string();
    assert!(message.contains("Ollama embed model pull failed"));
}
/// Each non-blank line of the chat reply becomes one expansion term, in order.
#[tokio::test(flavor = "multi_thread")]
async fn expand_query_async_parses_lines() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "one\ntwo\n\nthree"},
    }));
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let expansions = llm.expand_query_async("anything").await.unwrap();
    assert_eq!(expansions, ["one", "two", "three"]);
}
/// Summarising a list of `(title, content)` memories returns the chat reply's
/// content as the summary.
#[tokio::test(flavor = "multi_thread")]
async fn summarize_memories_async_renders_prompt_and_returns_summary() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "summarized"},
    }));
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let memories = [("t1", "c1"), ("t2", "c2")]
        .map(|(title, content)| (title.to_string(), content.to_string()));
    let summary = llm.summarize_memories_async(&memories).await.unwrap();
    assert_eq!(summary, "summarized");
}
/// Ten upper-case reply lines produce at most eight tags, each lowercased.
#[tokio::test(flavor = "multi_thread")]
async fn auto_tag_async_normalises_lines_and_caps_at_8() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    // More candidates than the cap, and none of them already lowercase.
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "A\nB\nC\nD\nE\nF\nG\nH\nI\nJ"},
    }));
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let tags = llm
        .auto_tag_async("title", "content", None)
        .await
        .unwrap();
    assert_eq!(tags.len(), 8);
    for tag in tags.iter() {
        assert_eq!(*tag, tag.to_lowercase());
    }
}
/// A per-call model override replaces the client's default model in the chat
/// request body. The mock only matches `"model": "fast-model"` and expects one hit.
#[tokio::test(flavor = "multi_thread")]
async fn auto_tag_async_model_override_stamps_body() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "a\nb\nc"},
    }));
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .and(body_partial_json(json!({"model": "fast-model"})))
        .respond_with(reply)
        .expect(1)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "primary-model")
        .await
        .unwrap();
    let tags = llm
        .auto_tag_async("t", "c", Some("fast-model"))
        .await
        .unwrap();
    assert_eq!(tags, ["a", "b", "c"]);
}
/// A reply of `"Yes."` is interpreted as "the two statements contradict".
#[tokio::test(flavor = "multi_thread")]
async fn detect_contradiction_async_parses_yes() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "Yes."},
    }));
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let contradicts = llm.detect_contradiction_async("a", "b").await.unwrap();
    assert!(contradicts);
}
/// A reply of `"no, they don't"` is interpreted as "no contradiction".
#[tokio::test(flavor = "multi_thread")]
async fn detect_contradiction_async_parses_no() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "no, they don't"},
    }));
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let contradicts = llm.detect_contradiction_async("a", "b").await.unwrap();
    assert!(!contradicts);
}
/// A failing chat call (HTTP 500) is propagated as an error. It is not
/// treated as "no contradiction".
#[tokio::test(flavor = "multi_thread")]
async fn detect_contradiction_async_propagates_generate_error() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .respond_with(ResponseTemplate::new(500))
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let verdict = llm.detect_contradiction_async("a", "b").await;
    assert!(verdict.is_err());
}
/// The override-model generate path shares the breaker. After enough 5xx
/// failures, the next call reports "circuit breaker open".
#[tokio::test(flavor = "multi_thread")]
async fn generate_with_model_override_async_breaker_open_short_circuits() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .respond_with(ResponseTemplate::new(500))
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    for _round in 0..super::CIRCUIT_BREAKER_THRESHOLD {
        let _ = llm
            .generate_with_model_override_async("p", None, Some("m"))
            .await;
    }
    let rejected = llm
        .generate_with_model_override_async("p", None, Some("m"))
        .await
        .unwrap_err();
    assert!(rejected.to_string().contains("circuit breaker open"));
}
/// The blocking `generate` wrapper, called from inside a multi-thread runtime,
/// still returns the model's reply.
/// NOTE(review): presumably this goes through the `block_in_place` branch of
/// `block_on_local`.
#[tokio::test(flavor = "multi_thread")]
async fn sync_wrapper_runs_under_block_in_place_path() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "bridge ok"},
    }));
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let answer = llm.generate("p", None).expect("sync wrapper ok");
    assert_eq!(answer, "bridge ok");
}
/// `Debug` output names each provider variant. For the OpenAI-compatible
/// variant it shows `<redacted>` in place of the API key.
#[test]
fn llm_provider_debug_redacts_api_key() {
    const SECRET: &str = "secret-token-do-not-leak";

    let rendered_ollama = format!("{:?}", super::LlmProvider::Ollama);
    let openai = super::LlmProvider::OpenAiCompatible {
        api_key: SECRET.to_string(),
    };
    let rendered_openai = format!("{openai:?}");

    assert!(rendered_ollama.contains("Ollama"));
    assert!(rendered_openai.contains("OpenAiCompatible"));
    assert!(rendered_openai.contains("<redacted>"));
    assert!(
        !rendered_openai.contains(SECRET),
        "Debug impl must not leak the api_key"
    );
}
/// `model_name()` returns the model the client was built with.
#[test]
fn model_name_returns_resolved_model() {
    let expected = "gemma-test-model";
    let client = OllamaClient::new_for_testing(expected);
    assert_eq!(client.model_name(), expected);
}
/// Scrubbing clears or NUL-fills the API key, and scrubbing a second time is safe.
#[test]
fn llm_provider_zeroize_secrets_is_idempotent() {
    let mut provider = super::LlmProvider::OpenAiCompatible {
        api_key: "abcdef".to_string(),
    };
    provider.zeroize_secrets();

    if let super::LlmProvider::OpenAiCompatible { api_key } = &provider {
        // `all` is vacuously true for an empty string, so this accepts either a
        // cleared key or one overwritten with NUL bytes.
        assert!(api_key.bytes().all(|byte| byte == 0));
    } else {
        unreachable!()
    }

    // A second call on already-scrubbed state must not panic.
    provider.zeroize_secrets();
}
/// Scrubbing the keyless `Ollama` provider leaves it as `Ollama`.
#[test]
fn llm_provider_zeroize_secrets_noop_on_ollama() {
    let mut provider = super::LlmProvider::Ollama;
    provider.zeroize_secrets();
    let still_ollama = matches!(provider, super::LlmProvider::Ollama);
    assert!(still_ollama);
}
/// A newly constructed breaker state reports closed.
#[test]
fn breaker_state_is_open_returns_false_when_last_failure_none() {
    let fresh = super::BreakerState::new();
    let open = fresh.is_open();
    assert!(!open, "fresh breaker must be closed");
}
/// `OllamaClient::new` either succeeds (a local Ollama is up) or fails with an
/// error phrased as "not running" / "not reachable".
#[tokio::test(flavor = "multi_thread")]
async fn new_convenience_constructor_routes_to_default_url() {
    // `new` is a synchronous constructor, so it runs on the blocking pool
    // rather than directly on an async worker.
    let outcome = tokio::task::spawn_blocking(|| OllamaClient::new("test-model"))
        .await
        .unwrap();

    // A successful construction needs no further checks.
    if let Err(e) = outcome {
        let msg = e.to_string();
        let looks_unreachable = msg.contains("not running") || msg.contains("not reachable");
        assert!(
            looks_unreachable,
            "expected an unreachable-style error, got: {msg}"
        );
    }
}
/// The blocking `is_available` probe reports true against a healthy mock
/// when called inside a multi-thread runtime.
#[tokio::test(flavor = "multi_thread")]
async fn sync_wrapper_path_is_available() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let available = llm.is_available();
    assert!(available);
}
/// The blocking `embed_text` wrapper returns the mocked one-element embedding.
#[tokio::test(flavor = "multi_thread")]
async fn sync_wrapper_path_embed_text() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "embeddings": [[0.42_f32]],
    }));
    Mock::given(path("/api/embed"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let embedding = llm.embed_text("hi", "m").unwrap();
    assert_eq!(embedding.len(), 1);
}
/// The blocking `expand_query` wrapper splits the reply into one term per line.
#[tokio::test(flavor = "multi_thread")]
async fn sync_wrapper_path_expand_query() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "a\nb"},
    }));
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let expansions = llm.expand_query("q").unwrap();
    assert_eq!(expansions, ["a", "b"]);
}
/// The blocking `summarize_memories` wrapper returns the chat reply as the summary.
#[tokio::test(flavor = "multi_thread")]
async fn sync_wrapper_path_summarize_memories() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "compacted"},
    }));
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let memories = [("t".to_string(), "c".to_string())];
    let summary = llm.summarize_memories(&memories).unwrap();
    assert_eq!(summary, "compacted");
}
/// The blocking `auto_tag` wrapper yields one tag per reply line.
#[tokio::test(flavor = "multi_thread")]
async fn sync_wrapper_path_auto_tag() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "x\ny\nz"},
    }));
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let tags = llm.auto_tag("t", "c", None).unwrap();
    assert_eq!(tags, ["x", "y", "z"]);
}
/// The blocking `detect_contradiction` wrapper reads a `"yes"` reply as true.
#[tokio::test(flavor = "multi_thread")]
async fn sync_wrapper_path_detect_contradiction() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "message": {"content": "yes"},
    }));
    Mock::given(path("/api/chat"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let contradicts = llm.detect_contradiction("a", "b").unwrap();
    assert!(contradicts);
}
/// The blocking `ensure_model` wrapper succeeds when `/api/tags` already
/// lists the client's model as `test-model:latest`.
#[tokio::test(flavor = "multi_thread")]
async fn sync_wrapper_path_ensure_model() {
    let mock = MockServer::start().await;
    let installed = ResponseTemplate::new(200).set_body_json(json!({
        "models": [{"name": "test-model:latest"}]
    }));
    Mock::given(path("/api/tags"))
        .and(method("GET"))
        .respond_with(installed)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    llm.ensure_model().unwrap();
}
/// The blocking `ensure_embed_model` wrapper succeeds when the embed model is
/// already listed by `/api/tags`.
#[tokio::test(flavor = "multi_thread")]
async fn sync_wrapper_path_ensure_embed_model() {
    let mock = MockServer::start().await;
    let installed = ResponseTemplate::new(200).set_body_json(json!({
        "models": [{"name": "nomic-embed-text:latest"}]
    }));
    Mock::given(path("/api/tags"))
        .and(method("GET"))
        .respond_with(installed)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    llm.ensure_embed_model("nomic-embed-text").unwrap();
}
/// A caller-supplied body posted to the legacy `/api/generate` endpoint
/// returns the reply's `response` field.
#[tokio::test(flavor = "multi_thread")]
async fn generate_with_body_async_happy_path() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({
        "response": "legacy text",
    }));
    Mock::given(path("/api/generate"))
        .and(method("POST"))
        .respond_with(reply)
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let request = json!({"model": "test-model", "prompt": "p", "stream": false});
    let text = llm.generate_with_body_async(&request).await.unwrap();
    assert_eq!(text, "legacy text");
}
/// An HTTP 500 from `/api/generate` surfaces as an error that mentions either
/// the status code or the failed "Generate" operation.
#[tokio::test(flavor = "multi_thread")]
async fn generate_with_body_async_returns_error_on_500() {
    let server = MockServer::start().await;
    mount_tags_ok(&server).await;
    Mock::given(method("POST"))
        .and(path("/api/generate"))
        .respond_with(ResponseTemplate::new(500).set_body_string("bad"))
        .mount(&server)
        .await;
    let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
        .await
        .unwrap();
    let body = json!({"model": "test-model"});
    let err = client.generate_with_body_async(&body).await.unwrap_err();
    // Render once, with the full context chain (`{:#}` on `anyhow::Error`).
    // The expected text may then appear at any level, and a failure shows the
    // actual message instead of a bare "assertion failed".
    let rendered = format!("{err:#}");
    assert!(
        rendered.contains("500") || rendered.contains("Generate failed"),
        "expected an HTTP 500 / generate-failed error, got: {rendered}"
    );
}
/// A 200 reply from `/api/generate` whose body is not valid JSON (despite a
/// JSON content type) must surface as a parse/JSON error.
#[tokio::test(flavor = "multi_thread")]
async fn generate_with_body_async_returns_error_on_malformed_json() {
    let server = MockServer::start().await;
    mount_tags_ok(&server).await;
    Mock::given(method("POST"))
        .and(path("/api/generate"))
        .respond_with(
            ResponseTemplate::new(200)
                .set_body_string("{bad json")
                .insert_header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON),
        )
        .mount(&server)
        .await;
    let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
        .await
        .unwrap();
    let body = json!({"model": "test-model"});
    let err = client.generate_with_body_async(&body).await.unwrap_err();
    // For `anyhow::Error`, `{:#}` renders the whole context chain, so the parse
    // failure is found even when it is wrapped in outer context. Rendering
    // once also lets the failure message report what was actually returned.
    let rendered = format!("{err:#}");
    let lowered = rendered.to_lowercase();
    assert!(
        lowered.contains("parse") || lowered.contains("json"),
        "expected a JSON parse error, got: {rendered}"
    );
}
/// The raw-body generate path shares the breaker. After enough 5xx failures,
/// the next call reports "circuit breaker open".
#[tokio::test(flavor = "multi_thread")]
async fn generate_with_body_async_breaker_open_short_circuits() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    Mock::given(path("/api/generate"))
        .and(method("POST"))
        .respond_with(ResponseTemplate::new(500))
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let request = json!({"model": "test-model"});
    for _round in 0..super::CIRCUIT_BREAKER_THRESHOLD {
        let _ = llm.generate_with_body_async(&request).await;
    }
    let rejected = llm.generate_with_body_async(&request).await.unwrap_err();
    assert!(rejected.to_string().contains("circuit breaker open"));
}
/// A 200 reply without a `response` field is an error naming the missing field.
#[tokio::test(flavor = "multi_thread")]
async fn generate_with_body_async_missing_response_field_errors() {
    let mock = MockServer::start().await;
    mount_tags_ok(&mock).await;
    Mock::given(path("/api/generate"))
        .and(method("POST"))
        .respond_with(ResponseTemplate::new(200).set_body_json(json!({"done": true})))
        .mount(&mock)
        .await;

    let llm = OllamaClient::new_with_url_async(&mock.uri(), "test-model")
        .await
        .unwrap();
    let request = json!({});
    let message = llm
        .generate_with_body_async(&request)
        .await
        .unwrap_err()
        .to_string();
    assert!(message.contains("Missing 'response'"));
}
/// Selecting the `openai-compatible` backend without `AI_MEMORY_LLM_BASE_URL`
/// fails, and the error names that variable.
#[tokio::test(flavor = "multi_thread")]
async fn from_env_openai_compatible_requires_base_url() {
    // Held until the end of the test; presumably a mutex guard that serialises
    // env-mutating tests in this module.
    let _env_guard = super::wiremock_tests::lock_env_1143();
    super::wiremock_tests::clear_llm_env_1143();
    // SAFETY: `set_var` is unsafe because another thread may read the process
    // environment at the same time. `_env_guard` is held for the whole test so
    // other tests taking the same lock cannot race with this write.
    // NOTE(review): threads that read the environment without that lock are
    // not covered.
    unsafe {
        std::env::set_var("AI_MEMORY_LLM_BACKEND", "openai-compatible");
        std::env::set_var("AI_MEMORY_LLM_API_KEY", "k");
    }
    let res = OllamaClient::from_env();
    // Reset before asserting, so a failed assertion does not leave these
    // variables set.
    super::wiremock_tests::clear_llm_env_1143();

    let Err(e) = res else {
        panic!("openai-compatible without base_url must error");
    };
    let err = e.to_string();
    assert!(err.contains("AI_MEMORY_LLM_BASE_URL"));
}
/// Selecting the `openai-compatible` backend without `AI_MEMORY_LLM_API_KEY`
/// fails, and the error names that variable.
#[tokio::test(flavor = "multi_thread")]
async fn from_env_openai_compatible_requires_api_key() {
    // Held until the end of the test; presumably a mutex guard that serialises
    // env-mutating tests in this module.
    let _env_guard = super::wiremock_tests::lock_env_1143();
    super::wiremock_tests::clear_llm_env_1143();
    // SAFETY: `set_var` is unsafe because another thread may read the process
    // environment at the same time. `_env_guard` is held for the whole test so
    // other tests taking the same lock cannot race with this write.
    // NOTE(review): threads that read the environment without that lock are
    // not covered.
    unsafe {
        std::env::set_var("AI_MEMORY_LLM_BACKEND", "openai-compatible");
        std::env::set_var("AI_MEMORY_LLM_BASE_URL", "https://example.test/v1");
    }
    let res = OllamaClient::from_env();
    // Reset before asserting, so a failed assertion does not leave these
    // variables set.
    super::wiremock_tests::clear_llm_env_1143();

    let Err(e) = res else {
        panic!("openai-compatible without key must error");
    };
    let err = e.to_string();
    assert!(err.contains("AI_MEMORY_LLM_API_KEY"));
}
/// A named provider alias (`xai`) with no resolvable API key fails, and the
/// error mentions "API key".
#[tokio::test(flavor = "multi_thread")]
async fn from_env_alias_requires_api_key_when_none_resolvable() {
    // Held until the end of the test; presumably a mutex guard that serialises
    // env-mutating tests in this module.
    let _env_guard = super::wiremock_tests::lock_env_1143();
    super::wiremock_tests::clear_llm_env_1143();
    // SAFETY: `set_var` is unsafe because another thread may read the process
    // environment at the same time. `_env_guard` is held for the whole test so
    // other tests taking the same lock cannot race with this write.
    // NOTE(review): threads that read the environment without that lock are
    // not covered.
    unsafe {
        std::env::set_var("AI_MEMORY_LLM_BACKEND", "xai");
    }
    let res = OllamaClient::from_env();
    // Reset before asserting, so a failed assertion does not leave these
    // variables set.
    super::wiremock_tests::clear_llm_env_1143();

    let Err(e) = res else {
        panic!("xai without key must error");
    };
    let err = e.to_string();
    assert!(err.contains("API key"));
}
/// The blocking constructor and `generate` must work from a thread that has no
/// Tokio runtime context at all.
/// NOTE(review): presumably this exercises the ephemeral current-thread runtime
/// fallback in `block_on_local`.
#[test]
fn sync_wrapper_outside_runtime_constructs_ephemeral() {
    // A separate multi-thread runtime is used only to drive wiremock's async
    // setup API; it stays alive for the whole test.
    let rt = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap();
    let server = rt.block_on(async {
        let mock = MockServer::start().await;
        mount_tags_ok(&mock).await;
        let reply = ResponseTemplate::new(200).set_body_json(json!({
            "message": {"content": "no-rt bridge ok"},
        }));
        Mock::given(path("/api/chat"))
            .and(method("POST"))
            .respond_with(reply)
            .mount(&mock)
            .await;
        mock
    });

    // A freshly spawned OS thread has no Tokio context entered, so the sync
    // calls below cannot pick up `rt` from thread-local state.
    std::thread::scope(|scope| {
        let worker = scope.spawn(|| {
            let client = OllamaClient::new_with_url(&server.uri(), "test-model")
                .expect("sync new_with_url ok");
            let out = client.generate("ping", None).expect("sync generate ok");
            assert_eq!(out, "no-rt bridge ok");
        });
        worker.join().unwrap();
    });
}
}