1use anyhow::{Context, Result, anyhow};
37use serde_json::{Value, json};
38use std::sync::Mutex;
39use std::time::{Duration, Instant};
40
/// Base URL of the Ollama server used when the caller or environment supplies none.
pub(crate) const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
42
43fn block_on_local<F, Fut, T>(make_fut: F) -> T
99where
100 F: FnOnce() -> Fut + Send,
101 Fut: std::future::Future<Output = T>,
102 T: Send,
103{
104 if let Ok(handle) = tokio::runtime::Handle::try_current() {
105 match handle.runtime_flavor() {
106 tokio::runtime::RuntimeFlavor::MultiThread => {
107 tokio::task::block_in_place(|| handle.block_on(make_fut()))
114 }
115 _ => {
116 std::thread::scope(|s| {
148 s.spawn(move || {
149 tokio::runtime::Builder::new_current_thread()
150 .enable_all()
151 .build()
152 .expect("ephemeral runtime builds")
153 .block_on(make_fut())
154 })
155 .join()
156 .expect(
157 "block_on_local current-thread bridge thread panicked; \
158 underlying future panicked",
159 )
160 })
161 }
162 }
163 } else {
164 tokio::runtime::Builder::new_current_thread()
168 .enable_all()
169 .build()
170 .expect("ephemeral runtime builds")
171 .block_on(make_fut())
172 }
173}
174
/// Backend identifier that selects the native Ollama provider.
pub const BACKEND_OLLAMA: &str = "ollama";

/// Path appended to the base URL when sending OpenAI-compatible embedding requests.
pub const OPENAI_COMPAT_EMBEDDINGS_PATH: &str = "/embeddings";
197
/// Looks up the built-in base URL for a vendor alias.
///
/// Matching is exact (case-sensitive). Returns `None` for aliases that are
/// not in the table.
pub(crate) fn default_base_url_for_alias(alias: &str) -> Option<&'static str> {
    // Each row: the alias spellings that share an endpoint, then the endpoint.
    const BASE_URLS: &[(&[&str], &str)] = &[
        (&["openai"], "https://api.openai.com/v1"),
        (&["xai"], "https://api.x.ai/v1"),
        (&["anthropic"], "https://api.anthropic.com/v1"),
        (
            &["gemini"],
            "https://generativelanguage.googleapis.com/v1beta/openai",
        ),
        (&["deepseek"], "https://api.deepseek.com/v1"),
        (&["kimi", "moonshot"], "https://api.moonshot.cn/v1"),
        (
            &["qwen", "dashscope"],
            "https://dashscope.aliyuncs.com/compatible-mode/v1",
        ),
        (&["mistral"], "https://api.mistral.ai/v1"),
        (&["groq"], "https://api.groq.com/openai/v1"),
        (&["together"], "https://api.together.xyz/v1"),
        (&["cerebras"], "https://api.cerebras.ai/v1"),
        (&["openrouter"], "https://openrouter.ai/api/v1"),
        (&["fireworks"], "https://api.fireworks.ai/inference/v1"),
        (&["lmstudio"], "http://localhost:1234/v1"),
    ];

    BASE_URLS
        .iter()
        .find(|(names, _)| names.iter().any(|known| *known == alias))
        .map(|(_, url)| *url)
}
225
/// Builds the URL of Ollama's model-listing endpoint (`/api/tags`).
///
/// Trailing slashes on `base_url` are ignored so that both
/// `http://host:11434` and `http://host:11434/` produce
/// `http://host:11434/api/tags` rather than a `//api/tags` path. This matches
/// the trimming the client constructors apply before storing a base URL.
pub(crate) fn ollama_tags_url(base_url: &str) -> String {
    format!("{}/api/tags", base_url.trim_end_matches('/'))
}
233
/// Returns the vendor-specific environment variable names that may hold the
/// API key for `alias`, in lookup order.
///
/// Matching is exact. Aliases without an entry yield an empty slice.
fn alias_api_key_env_vars(alias: &str) -> &'static [&'static str] {
    // Each row: the alias spellings, then the env vars to try in order.
    const KEY_VARS: &[(&[&str], &[&str])] = &[
        (&["openai"], &["OPENAI_API_KEY"]),
        (&["xai"], &["XAI_API_KEY"]),
        (&["anthropic"], &["ANTHROPIC_API_KEY"]),
        (&["gemini"], &["GEMINI_API_KEY", "GOOGLE_API_KEY"]),
        (&["deepseek"], &["DEEPSEEK_API_KEY"]),
        (&["kimi", "moonshot"], &["MOONSHOT_API_KEY", "KIMI_API_KEY"]),
        (&["qwen", "dashscope"], &["DASHSCOPE_API_KEY", "QWEN_API_KEY"]),
        (&["mistral"], &["MISTRAL_API_KEY"]),
        (&["groq"], &["GROQ_API_KEY"]),
        (&["together"], &["TOGETHER_API_KEY"]),
        (&["cerebras"], &["CEREBRAS_API_KEY"]),
        (&["openrouter"], &["OPENROUTER_API_KEY"]),
        (&["fireworks"], &["FIREWORKS_API_KEY"]),
    ];

    KEY_VARS
        .iter()
        .find(|(names, _)| names.iter().any(|known| *known == alias))
        .map(|(_, vars)| *vars)
        .unwrap_or(&[])
}
256
/// Wire protocol a client speaks, plus any credential that protocol needs.
///
/// `Debug` is implemented by hand so the API key is never printed.
#[derive(Clone)]
pub enum LlmProvider {
    /// Native Ollama HTTP API; requests are sent without a bearer token.
    Ollama,
    /// OpenAI-compatible HTTP API; `api_key` is sent as a bearer token.
    OpenAiCompatible { api_key: String },
}
281
282impl std::fmt::Debug for LlmProvider {
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287 match self {
288 LlmProvider::Ollama => f.debug_struct("Ollama").finish(),
289 LlmProvider::OpenAiCompatible { .. } => f
290 .debug_struct("OpenAiCompatible")
291 .field("api_key", &"<redacted>")
292 .finish(),
293 }
294 }
295}
296
297impl LlmProvider {
298 pub fn zeroize_secrets(&mut self) {
307 if let LlmProvider::OpenAiCompatible { api_key } = self {
308 use zeroize::Zeroize;
309 api_key.zeroize();
310 }
311 }
312}
313
/// Wipes any held API key when the provider is dropped.
impl Drop for LlmProvider {
    fn drop(&mut self) {
        // Delegates to the same routine callers can invoke explicitly.
        self.zeroize_secrets();
    }
}
323
/// Whole-request timeout set on the shared HTTP client and on each
/// chat/generate/embed request.
const GENERATE_TIMEOUT: Duration = Duration::from_secs(30);
/// Timeout of the dedicated HTTP client used for the Ollama model pull.
const PULL_TIMEOUT: Duration = Duration::from_secs(120);
/// Connect timeout set on the shared HTTP client.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Per-request timeout of the availability probe.
const HEALTH_TIMEOUT: Duration = Duration::from_secs(5);
/// How long after the most recent failure a tripped circuit breaker stays open.
const CIRCUIT_BREAKER_COOLDOWN: Duration = Duration::from_secs(30);
/// Number of consecutive failures at which the circuit breaker trips.
const CIRCUIT_BREAKER_THRESHOLD: u32 = 3;

// NOTE(review): the two EMBED_BATCH_* limits are not referenced in the
// visible part of this file; presumably they bound one batched embedding
// request (input count and total bytes) — confirm against `embed_texts_async`.
const EMBED_BATCH_MAX_INPUTS: usize = 100;
const EMBED_BATCH_MAX_BYTES: usize = 256 * 1024;

/// Cap (16 MiB) on the number of response-body bytes `read_capped_bytes` will buffer.
const MAX_LLM_RESPONSE_BYTES: usize = 16 * 1024 * 1024;
362
/// Reads the whole response body, refusing bodies larger than
/// `MAX_LLM_RESPONSE_BYTES`.
///
/// # Errors
/// Propagates any error from `read_capped_bytes_inner` (oversized body or a
/// failed chunk read).
async fn read_capped_bytes(resp: reqwest::Response) -> Result<Vec<u8>> {
    read_capped_bytes_inner(resp, MAX_LLM_RESPONSE_BYTES).await
}
370
371async fn read_capped_bytes_inner(mut resp: reqwest::Response, cap: usize) -> Result<Vec<u8>> {
375 if let Some(len) = resp.content_length() {
376 if len > cap as u64 {
377 return Err(anyhow!(
378 "LLM response too large: Content-Length {len} exceeds cap of {cap} bytes"
379 ));
380 }
381 }
382 let mut buf: Vec<u8> = Vec::new();
383 while let Some(chunk) = resp
384 .chunk()
385 .await
386 .context("Failed to read LLM response chunk")?
387 {
388 if buf.len().saturating_add(chunk.len()) > cap {
389 return Err(anyhow!(
390 "LLM response exceeded cap of {cap} bytes while streaming"
391 ));
392 }
393 buf.extend_from_slice(&chunk);
394 }
395 Ok(buf)
396}
397
398async fn read_capped_json(resp: reqwest::Response) -> Result<Value> {
402 let bytes = read_capped_bytes(resp).await?;
403 serde_json::from_slice(&bytes).context("Failed to parse LLM response body as JSON")
404}
405
406async fn read_capped_text(resp: reqwest::Response) -> String {
411 match read_capped_bytes(resp).await {
412 Ok(bytes) => String::from_utf8_lossy(&bytes).into_owned(),
413 Err(e) => format!("<error body unavailable: {e}>"),
414 }
415}
416
417fn parse_openai_embeddings_batch(body: &Value, expected_len: usize) -> Result<Vec<Vec<f32>>> {
427 let data = body["data"]
428 .as_array()
429 .ok_or_else(|| anyhow!("Missing 'data' array in OpenAI-compatible embed response"))?;
430 if data.len() != expected_len {
431 return Err(anyhow!(
432 "Embed response carried {} vector(s) for {expected_len} input(s)",
433 data.len()
434 ));
435 }
436 let mut out: Vec<Option<Vec<f32>>> = vec![None; expected_len];
437 for (pos, item) in data.iter().enumerate() {
438 let idx = match item["index"].as_u64() {
439 Some(i) => usize::try_from(i)
440 .map_err(|_| anyhow!("Embed response 'index' {i} does not fit usize"))?,
441 None => pos,
442 };
443 if idx >= expected_len {
444 return Err(anyhow!(
445 "Embed response 'index' {idx} out of range for {expected_len} input(s)"
446 ));
447 }
448 if out[idx].is_some() {
449 return Err(anyhow!("Embed response carried duplicate 'index' {idx}"));
450 }
451 let arr = item["embedding"].as_array().ok_or_else(|| {
452 anyhow!("Missing 'data[{pos}].embedding' in OpenAI-compatible embed response")
453 })?;
454 #[allow(clippy::cast_possible_truncation)]
455 let floats: Vec<f32> = arr
456 .iter()
457 .filter_map(|v| v.as_f64().map(|f| f as f32))
458 .collect();
459 if floats.is_empty() {
460 return Err(anyhow!("Empty embedding at index {idx} in embed response"));
461 }
462 out[idx] = Some(floats);
463 }
464 Ok(out.into_iter().flatten().collect())
467}
468
/// Prompt template for query expansion; `{query}` is replaced with the search query.
const QUERY_EXPANSION_PROMPT: &str = r"You are a search query expander. Given a search query, generate 5-8 additional search terms that are semantically related. Return ONLY the terms, one per line, no numbering or explanation.

Query: {query}";

/// Prompt template for summarization; `{memories}` is replaced with the formatted memory list.
const SUMMARIZE_PROMPT: &str = r"Summarize the following memories into a single concise paragraph. Preserve all key facts, decisions, and technical details.

{memories}";

/// Prompt template for tag generation; `{title}` and `{content}` are replaced with the memory's fields.
const AUTO_TAG_PROMPT: &str = r"Generate 3-5 short tags for categorizing this memory. Return ONLY the tags, one per line, lowercase, no symbols.

Title: {title}
Content: {content}";

// NOTE(review): not referenced in the visible part of this file; the `{a}` and
// `{b}` placeholders are presumably substituted by a contradiction check
// defined further down — confirm.
const CONTRADICTION_PROMPT: &str = r#"Do these two statements contradict each other? Answer ONLY "yes" or "no".

Statement A: {a}
Statement B: {b}"#;
486
/// Bookkeeping for the client's circuit breaker.
#[derive(Debug)]
struct BreakerState {
    /// Failures recorded since the last success (saturating).
    consecutive_failures: u32,
    /// When the most recent failure was recorded; `None` after a success.
    last_failure_at: Option<Instant>,
}
499
500impl BreakerState {
501 const fn new() -> Self {
502 Self {
503 consecutive_failures: 0,
504 last_failure_at: None,
505 }
506 }
507
508 fn is_open(&self) -> bool {
510 if self.consecutive_failures < CIRCUIT_BREAKER_THRESHOLD {
511 return false;
512 }
513 match self.last_failure_at {
514 Some(t) => t.elapsed() < CIRCUIT_BREAKER_COOLDOWN,
515 None => false,
516 }
517 }
518
519 fn record_failure(&mut self) {
520 self.consecutive_failures = self.consecutive_failures.saturating_add(1);
521 self.last_failure_at = Some(Instant::now());
522 }
523
524 fn record_success(&mut self) {
525 self.consecutive_failures = 0;
526 self.last_failure_at = None;
527 }
528}
529
/// HTTP client for an LLM backend.
///
/// Despite the name, the `provider` field decides whether the native Ollama
/// API or an OpenAI-compatible API is spoken.
pub struct OllamaClient {
    /// Protocol selector; carries the API key for OpenAI-compatible backends.
    provider: LlmProvider,
    /// Endpoint root; the constructors in this file store it without trailing `/`.
    base_url: String,
    /// Default model name sent with chat requests.
    model: String,
    /// Shared `reqwest` client used for outgoing requests.
    client: reqwest::Client,
    /// Circuit-breaker state, guarded by a mutex.
    breaker: Mutex<BreakerState>,
    /// Optional `dimensions` value added to OpenAI-compatible embedding requests.
    embed_dimensions: Option<u32>,
}
565
566impl OllamaClient {
    /// Returns the default model name this client was constructed with.
    #[must_use]
    pub fn model_name(&self) -> &str {
        &self.model
    }
580
    /// Builds a native Ollama client for `model` at `DEFAULT_OLLAMA_URL`.
    ///
    /// # Errors
    /// Propagates any error from [`Self::new_with_url`].
    #[allow(dead_code)]
    pub fn new(model: &str) -> Result<Self> {
        Self::new_with_url(DEFAULT_OLLAMA_URL, model)
    }
587
588 #[cfg(test)]
594 pub fn new_for_testing(model: &str) -> Self {
595 Self {
596 provider: LlmProvider::Ollama,
597 base_url: DEFAULT_OLLAMA_URL.trim_end_matches('/').to_string(),
598 model: model.to_string(),
599 client: reqwest::Client::builder()
600 .timeout(GENERATE_TIMEOUT)
601 .connect_timeout(CONNECT_TIMEOUT)
602 .build()
603 .expect("test reqwest client builds"),
604 breaker: Mutex::new(BreakerState::new()),
605 embed_dimensions: None,
606 }
607 }
608
609 #[allow(clippy::too_many_lines)]
640 pub fn from_env() -> Result<Option<Self>> {
641 let backend = std::env::var("AI_MEMORY_LLM_BACKEND")
642 .ok()
643 .map(|s| s.trim().to_ascii_lowercase())
644 .unwrap_or_else(|| BACKEND_OLLAMA.to_string());
645
646 let model = std::env::var("AI_MEMORY_LLM_MODEL")
647 .ok()
648 .filter(|s| !s.trim().is_empty())
649 .unwrap_or_else(|| match backend.as_str() {
650 "xai" => "grok-4.3".to_string(),
651 "openai" => "gpt-5".to_string(),
652 "anthropic" => "claude-opus-4.7".to_string(),
653 "gemini" => "gemini-2.0-flash".to_string(),
654 "deepseek" => "deepseek-chat".to_string(),
655 "kimi" | "moonshot" => "moonshot-v1-8k".to_string(),
656 "qwen" | "dashscope" => "qwen-max".to_string(),
657 "mistral" => "mistral-large-latest".to_string(),
658 "groq" => "llama-3.3-70b-versatile".to_string(),
659 "together" => "meta-llama/Llama-3.3-70B-Instruct-Turbo".to_string(),
660 "cerebras" => "llama-3.3-70b".to_string(),
661 "openrouter" => "openai/gpt-5".to_string(),
662 "fireworks" => "accounts/fireworks/models/llama-v3p3-70b-instruct".to_string(),
663 "lmstudio" => "local-model".to_string(),
664 _ => "gemma3:4b".to_string(),
665 });
666
667 match backend.as_str() {
668 BACKEND_OLLAMA => {
669 let base_url = std::env::var("AI_MEMORY_LLM_BASE_URL")
670 .ok()
671 .or_else(|| std::env::var("OLLAMA_BASE_URL").ok())
672 .filter(|s| !s.trim().is_empty())
673 .unwrap_or_else(|| DEFAULT_OLLAMA_URL.to_string());
674 Self::new_with_url(&base_url, &model).map(Some)
675 }
676 "openai-compatible" => {
677 let base_url = std::env::var("AI_MEMORY_LLM_BASE_URL")
678 .ok()
679 .filter(|s| !s.trim().is_empty())
680 .ok_or_else(|| {
681 anyhow!(
682 "AI_MEMORY_LLM_BACKEND=openai-compatible requires \
683 AI_MEMORY_LLM_BASE_URL to be set (no default URL \
684 — operator must supply the vendor's endpoint)"
685 )
686 })?;
687 let api_key = std::env::var("AI_MEMORY_LLM_API_KEY")
688 .ok()
689 .filter(|s| !s.trim().is_empty())
690 .ok_or_else(|| {
691 anyhow!(
692 "AI_MEMORY_LLM_BACKEND=openai-compatible requires \
693 AI_MEMORY_LLM_API_KEY to be set"
694 )
695 })?;
696 Self::new_openai_compatible(&base_url, &model, &api_key).map(Some)
697 }
698 alias => {
699 let Some(default_url) = default_base_url_for_alias(alias) else {
700 return Err(anyhow!(
701 "AI_MEMORY_LLM_BACKEND={alias} is not a recognized \
702 backend alias. Valid values: ollama, openai-compatible, \
703 openai, xai, anthropic, gemini, deepseek, kimi, qwen, \
704 mistral, groq, together, cerebras, openrouter, \
705 fireworks, lmstudio"
706 ));
707 };
708 let base_url = std::env::var("AI_MEMORY_LLM_BASE_URL")
709 .ok()
710 .filter(|s| !s.trim().is_empty())
711 .unwrap_or_else(|| default_url.to_string());
712 let api_key = std::env::var("AI_MEMORY_LLM_API_KEY")
713 .ok()
714 .filter(|s| !s.trim().is_empty())
715 .or_else(|| {
716 alias_api_key_env_vars(alias).iter().find_map(|name| {
717 std::env::var(name).ok().filter(|s| !s.trim().is_empty())
718 })
719 })
720 .ok_or_else(|| {
721 anyhow!(
722 "AI_MEMORY_LLM_BACKEND={alias} requires an API key \
723 — set AI_MEMORY_LLM_API_KEY or one of the \
724 per-vendor env vars: {:?}",
725 alias_api_key_env_vars(alias)
726 )
727 })?;
728 Self::new_openai_compatible(&base_url, &model, &api_key).map(Some)
729 }
730 }
731 }
732
733 pub fn build_for_init(legacy_url: &str, legacy_model: &str) -> Result<Option<Self>> {
758 let backend_env = std::env::var("AI_MEMORY_LLM_BACKEND")
759 .ok()
760 .map(|s| s.trim().to_string())
761 .filter(|s| !s.is_empty());
762 if backend_env.is_some() {
763 return Self::from_env();
764 }
765 Self::new_with_url(legacy_url, legacy_model).map(Some)
766 }
767
768 pub fn build_from_resolved(resolved: &crate::config::ResolvedLlm) -> Result<Option<Self>> {
794 tracing::debug!(
797 "LLM client construction via #1146 resolver — backend={}, model={}, base_url={}, key_source={}, source={}",
798 resolved.backend,
799 resolved.model,
800 resolved.base_url,
801 resolved.api_key_source.as_str(),
802 resolved.source.as_str(),
803 );
804
805 if resolved.backend == BACKEND_OLLAMA {
806 return Self::new_with_url(&resolved.base_url, &resolved.model).map(Some);
807 }
808
809 let Some(api_key) = resolved.api_key() else {
813 return Err(anyhow!(
814 "LLM backend `{}` requires an API key but the resolver \
815 produced none. KeySource = {}. Configure either \
816 AI_MEMORY_LLM_API_KEY, a per-vendor env var (e.g. \
817 XAI_API_KEY), [llm].api_key_env, or [llm].api_key_file \
818 in config.toml. See \
819 https://github.com/alphaonedev/ai-memory-mcp/issues/1146",
820 resolved.backend,
821 resolved.api_key_source.as_str(),
822 ));
823 };
824
825 Self::new_openai_compatible(&resolved.base_url, &resolved.model, api_key).map(Some)
826 }
827
828 pub async fn build_from_resolved_async(
852 resolved: &crate::config::ResolvedLlm,
853 ) -> Result<Option<Self>> {
854 tracing::debug!(
855 "LLM client construction via #1146 resolver (async, FX-D1) — backend={}, model={}, base_url={}, key_source={}, source={}",
856 resolved.backend,
857 resolved.model,
858 resolved.base_url,
859 resolved.api_key_source.as_str(),
860 resolved.source.as_str(),
861 );
862
863 if resolved.backend == BACKEND_OLLAMA {
864 return Self::new_with_url_async(&resolved.base_url, &resolved.model)
865 .await
866 .map(Some);
867 }
868
869 let Some(api_key) = resolved.api_key() else {
870 return Err(anyhow!(
871 "LLM backend `{}` requires an API key but the resolver \
872 produced none. KeySource = {}. Configure either \
873 AI_MEMORY_LLM_API_KEY, a per-vendor env var (e.g. \
874 XAI_API_KEY), [llm].api_key_env, or [llm].api_key_file \
875 in config.toml. See \
876 https://github.com/alphaonedev/ai-memory-mcp/issues/1146",
877 resolved.backend,
878 resolved.api_key_source.as_str(),
879 ));
880 };
881
882 Self::new_openai_compatible(&resolved.base_url, &resolved.model, api_key).map(Some)
883 }
884
    /// Returns `true` when this client speaks the native Ollama API rather
    /// than the OpenAI-compatible one.
    #[must_use]
    pub fn is_ollama_native(&self) -> bool {
        matches!(self.provider, LlmProvider::Ollama)
    }
896
897 pub fn new_openai_compatible(base_url: &str, model: &str, api_key: &str) -> Result<Self> {
907 let client = reqwest::Client::builder()
908 .timeout(GENERATE_TIMEOUT)
909 .connect_timeout(CONNECT_TIMEOUT)
910 .build()
911 .context("Failed to build HTTP client")?;
912 Ok(Self {
913 provider: LlmProvider::OpenAiCompatible {
914 api_key: api_key.to_string(),
915 },
916 base_url: base_url.trim_end_matches('/').to_string(),
917 model: model.to_string(),
918 client,
919 breaker: Mutex::new(BreakerState::new()),
920 embed_dimensions: None,
921 })
922 }
923
    /// Builder-style setter for the optional embedding `dimensions` value.
    ///
    /// When `Some`, the value is sent as `dimensions` in OpenAI-compatible
    /// embedding requests; `None` omits the field.
    #[must_use]
    pub fn with_embed_dimensions(mut self, dims: Option<u32>) -> Self {
        self.embed_dimensions = dims;
        self
    }
932
    /// Blocking wrapper around [`Self::new_with_url_async`], driven through
    /// `block_on_local`.
    ///
    /// # Errors
    /// Propagates any error from the async constructor.
    pub fn new_with_url(base_url: &str, model: &str) -> Result<Self> {
        block_on_local(|| Self::new_with_url_async(base_url, model))
    }
956
957 pub async fn new_with_url_async(base_url: &str, model: &str) -> Result<Self> {
963 let instance = Self::new_with_url_no_health_check(base_url, model)?;
964
965 if !instance.is_available_async().await {
966 return Err(anyhow!(
967 "Ollama is not running or not reachable at {}. \
968 Start it with: ollama serve",
969 instance.base_url
970 ));
971 }
972
973 Ok(instance)
974 }
975
976 pub fn new_with_url_no_health_check(base_url: &str, model: &str) -> Result<Self> {
993 let client = reqwest::Client::builder()
994 .timeout(GENERATE_TIMEOUT)
995 .connect_timeout(CONNECT_TIMEOUT)
996 .build()
997 .context("Failed to build HTTP client")?;
998
999 Ok(Self {
1000 provider: LlmProvider::Ollama,
1001 base_url: base_url.trim_end_matches('/').to_string(),
1002 model: model.to_string(),
1003 client,
1004 breaker: Mutex::new(BreakerState::new()),
1005 embed_dimensions: None,
1006 })
1007 }
1008
1009 fn breaker_is_open(&self) -> bool {
1013 self.breaker.lock().map(|b| b.is_open()).unwrap_or(false)
1014 }
1015
1016 fn note_failure(&self) {
1017 if let Ok(mut b) = self.breaker.lock() {
1018 b.record_failure();
1019 }
1020 }
1021
1022 fn note_success(&self) {
1023 if let Ok(mut b) = self.breaker.lock() {
1024 b.record_success();
1025 }
1026 }
1027
    /// Public, doc-hidden view of the breaker state: `true` while the circuit
    /// breaker is open.
    #[doc(hidden)]
    pub fn circuit_breaker_open(&self) -> bool {
        self.breaker_is_open()
    }
1033
    /// Blocking wrapper around [`Self::is_available_async`], driven through
    /// `block_on_local`.
    pub fn is_available(&self) -> bool {
        block_on_local(|| self.is_available_async())
    }
1053
1054 pub async fn is_available_async(&self) -> bool {
1057 let (url, bearer) = match &self.provider {
1058 LlmProvider::Ollama => (ollama_tags_url(&self.base_url), None),
1059 LlmProvider::OpenAiCompatible { api_key } => {
1060 (format!("{}/models", self.base_url), Some(api_key.as_str()))
1061 }
1062 };
1063 let mut req = self.client.get(&url).timeout(HEALTH_TIMEOUT);
1064 if let Some(key) = bearer {
1065 req = req.bearer_auth(key);
1066 }
1067 match req.send().await {
1068 Ok(r) => r.status().is_success(),
1069 Err(_) => false,
1070 }
1071 }
1072
    /// Blocking wrapper around [`Self::ensure_model_async`], driven through
    /// `block_on_local`.
    ///
    /// # Errors
    /// Propagates any error from the async implementation.
    pub fn ensure_model(&self) -> Result<()> {
        block_on_local(|| self.ensure_model_async())
    }
1085
1086 pub async fn ensure_model_async(&self) -> Result<()> {
1094 if matches!(self.provider, LlmProvider::OpenAiCompatible { .. }) {
1095 return Ok(());
1096 }
1097 let url = ollama_tags_url(&self.base_url);
1098 let resp = self
1099 .client
1100 .get(&url)
1101 .timeout(Duration::from_secs(10))
1102 .send()
1103 .await
1104 .context("Failed to list Ollama models")?;
1105
1106 let body: Value = read_capped_json(resp)
1107 .await
1108 .context("Failed to parse /api/tags response")?;
1109
1110 let model_exists = body["models"].as_array().is_some_and(|models| {
1111 models.iter().any(|m| {
1112 let name = m["name"].as_str().unwrap_or("");
1113 let our_base = self.model.split(':').next().unwrap_or(&self.model);
1114 name == self.model
1115 || name.starts_with(&format!("{}:", self.model))
1116 || self.model == name.split(':').next().unwrap_or("")
1117 || name == our_base
1118 })
1119 });
1120
1121 if model_exists {
1122 return Ok(());
1123 }
1124
1125 tracing::info!(
1126 "Pulling Ollama model '{}' (this may take a while)...",
1127 self.model
1128 );
1129
1130 let pull_url = format!("{}/api/pull", self.base_url);
1131 let pull_client = reqwest::Client::builder()
1132 .timeout(PULL_TIMEOUT)
1133 .build()
1134 .context("Failed to build pull client")?;
1135
1136 let resp = pull_client
1137 .post(&pull_url)
1138 .json(&json!({ "name": self.model }))
1139 .send()
1140 .await
1141 .context("Failed to pull model from Ollama")?;
1142
1143 if !resp.status().is_success() {
1144 let status = resp.status();
1145 let text = read_capped_text(resp).await;
1146 return Err(anyhow!("Ollama pull failed ({status}): {text}"));
1147 }
1148
1149 tracing::info!("Model '{}' pulled successfully", self.model);
1150 Ok(())
1151 }
1152
    /// Blocking wrapper around [`Self::generate_async`], driven through
    /// `block_on_local`.
    ///
    /// # Errors
    /// Propagates any error from the async implementation.
    pub fn generate(&self, prompt: &str, system: Option<&str>) -> Result<String> {
        block_on_local(|| self.generate_async(prompt, system))
    }
1171
1172 pub async fn generate_async(&self, prompt: &str, system: Option<&str>) -> Result<String> {
1186 if self.breaker_is_open() {
1187 return Err(anyhow!(
1188 "Failed to send chat request: circuit breaker open \
1189 (last failure within {}s); LLM at {} is not responding",
1190 CIRCUIT_BREAKER_COOLDOWN.as_secs(),
1191 self.base_url,
1192 ));
1193 }
1194 self.check_outbound()?;
1196
1197 let (url, payload, bearer): (String, Value, Option<&str>) = match &self.provider {
1198 LlmProvider::Ollama => {
1199 let mut messages = Vec::new();
1200 if let Some(sys) = system {
1201 messages.push(json!({"role": "system", "content": sys}));
1202 }
1203 messages.push(json!({"role": "user", "content": prompt}));
1204 (
1205 format!("{}/api/chat", self.base_url),
1206 json!({
1207 "model": self.model,
1208 "messages": messages,
1209 "stream": false,
1210 }),
1211 None,
1212 )
1213 }
1214 LlmProvider::OpenAiCompatible { api_key } => {
1215 let mut messages = Vec::new();
1216 if let Some(sys) = system {
1217 messages.push(json!({"role": "system", "content": sys}));
1218 }
1219 messages.push(json!({"role": "user", "content": prompt}));
1220 (
1221 format!("{}/chat/completions", self.base_url),
1222 json!({
1223 "model": self.model,
1224 "messages": messages,
1225 "stream": false,
1226 }),
1227 Some(api_key.as_str()),
1228 )
1229 }
1230 };
1231
1232 let mut req = self
1233 .client
1234 .post(&url)
1235 .timeout(GENERATE_TIMEOUT)
1236 .json(&payload);
1237 if let Some(key) = bearer {
1238 req = req.bearer_auth(key);
1239 }
1240
1241 let resp = match req.send().await {
1242 Ok(r) => r,
1243 Err(e) => {
1244 self.note_failure();
1245 return Err(anyhow::Error::new(e).context("Failed to send chat request"));
1246 }
1247 };
1248
1249 if !resp.status().is_success() {
1250 let status = resp.status();
1251 if status.is_server_error() {
1252 self.note_failure();
1253 }
1254 let text = read_capped_text(resp).await;
1255 return Err(anyhow!("Chat generate failed ({status}): {text}"));
1256 }
1257
1258 let body: Value = match read_capped_json(resp).await {
1259 Ok(b) => b,
1260 Err(e) => {
1261 self.note_failure();
1262 return Err(e.context("Failed to parse chat response"));
1263 }
1264 };
1265
1266 let response_text = match &self.provider {
1267 LlmProvider::Ollama => body["message"]["content"]
1268 .as_str()
1269 .ok_or_else(|| anyhow!("Missing 'message.content' field in chat output"))?
1270 .to_string(),
1271 LlmProvider::OpenAiCompatible { .. } => body["choices"][0]["message"]["content"]
1272 .as_str()
1273 .ok_or_else(|| {
1274 anyhow!(
1275 "Missing 'choices[0].message.content' field in OpenAI-compatible \
1276 chat response; got: {body}"
1277 )
1278 })?
1279 .to_string(),
1280 };
1281
1282 self.note_success();
1283 Ok(response_text)
1284 }
1285
    /// Blocking wrapper around [`Self::expand_query_async`], driven through
    /// `block_on_local`.
    ///
    /// # Errors
    /// Propagates any error from the async implementation.
    pub fn expand_query(&self, query: &str) -> Result<Vec<String>> {
        block_on_local(|| self.expand_query_async(query))
    }
1290
1291 pub async fn expand_query_async(&self, query: &str) -> Result<Vec<String>> {
1299 let prompt = QUERY_EXPANSION_PROMPT.replace("{query}", query);
1300 let response = self.generate_async(&prompt, None).await?;
1301
1302 let terms: Vec<String> = response
1303 .lines()
1304 .map(|line| line.trim().to_string())
1305 .filter(|line| !line.is_empty())
1306 .collect();
1307
1308 Ok(terms)
1309 }
1310
    /// Blocking wrapper around [`Self::summarize_memories_async`], driven
    /// through `block_on_local`.
    ///
    /// # Errors
    /// Propagates any error from the async implementation.
    pub fn summarize_memories(&self, memories: &[(String, String)]) -> Result<String> {
        block_on_local(|| self.summarize_memories_async(memories))
    }
1315
1316 pub async fn summarize_memories_async(&self, memories: &[(String, String)]) -> Result<String> {
1323 let formatted = memories
1324 .iter()
1325 .enumerate()
1326 .map(|(i, (title, content))| {
1327 format!("--- Memory {} ---\nTitle: {}\n{}", i + 1, title, content)
1328 })
1329 .collect::<Vec<_>>()
1330 .join("\n\n");
1331
1332 let prompt = SUMMARIZE_PROMPT.replace("{memories}", &formatted);
1333 let response = self.generate_async(&prompt, None).await?;
1334
1335 Ok(response.trim().to_string())
1336 }
1337
    /// Blocking wrapper around [`Self::auto_tag_async`], driven through
    /// `block_on_local`.
    ///
    /// # Errors
    /// Propagates any error from the async implementation.
    pub fn auto_tag(
        &self,
        title: &str,
        content: &str,
        model_override: Option<&str>,
    ) -> Result<Vec<String>> {
        block_on_local(|| self.auto_tag_async(title, content, model_override))
    }
1355
1356 pub async fn auto_tag_async(
1363 &self,
1364 title: &str,
1365 content: &str,
1366 model_override: Option<&str>,
1367 ) -> Result<Vec<String>> {
1368 let prompt = AUTO_TAG_PROMPT
1369 .replace("{title}", title)
1370 .replace("{content}", content);
1371 let response = self
1372 .generate_with_model_override_async(&prompt, None, model_override)
1373 .await?;
1374 let tags: Vec<String> = response
1375 .lines()
1376 .map(|line| line.trim().to_lowercase())
1377 .filter(|line| !line.is_empty() && line.len() <= 64)
1378 .take(8)
1379 .collect();
1380 Ok(tags)
1381 }
1382
    /// Blocking wrapper around
    /// [`Self::generate_with_model_override_async`], driven through
    /// `block_on_local`.
    ///
    /// # Errors
    /// Propagates any error from the async implementation.
    #[allow(dead_code)]
    fn generate_with_model_override(
        &self,
        prompt: &str,
        system: Option<&str>,
        model_override: Option<&str>,
    ) -> Result<String> {
        block_on_local(|| self.generate_with_model_override_async(prompt, system, model_override))
    }
1400
1401 #[allow(clippy::too_many_lines)]
1409 pub async fn generate_with_model_override_async(
1410 &self,
1411 prompt: &str,
1412 system: Option<&str>,
1413 model_override: Option<&str>,
1414 ) -> Result<String> {
1415 if self.breaker_is_open() {
1416 return Err(anyhow!(
1417 "Failed to send chat request: circuit breaker open \
1418 (last failure within {}s); LLM at {} is not responding",
1419 CIRCUIT_BREAKER_COOLDOWN.as_secs(),
1420 self.base_url,
1421 ));
1422 }
1423 self.check_outbound()?;
1424 let model = model_override.unwrap_or(&self.model);
1425
1426 let (url, payload, bearer): (String, Value, Option<&str>) = match &self.provider {
1427 LlmProvider::Ollama => {
1428 let mut messages = Vec::new();
1429 if let Some(sys) = system {
1430 messages.push(json!({"role": "system", "content": sys}));
1431 }
1432 messages.push(json!({"role": "user", "content": prompt}));
1433 (
1434 format!("{}/api/chat", self.base_url),
1435 json!({"model": model, "messages": messages, "stream": false}),
1436 None,
1437 )
1438 }
1439 LlmProvider::OpenAiCompatible { api_key } => {
1440 let mut messages = Vec::new();
1441 if let Some(sys) = system {
1442 messages.push(json!({"role": "system", "content": sys}));
1443 }
1444 messages.push(json!({"role": "user", "content": prompt}));
1445 (
1446 format!("{}/chat/completions", self.base_url),
1447 json!({"model": model, "messages": messages, "stream": false}),
1448 Some(api_key.as_str()),
1449 )
1450 }
1451 };
1452
1453 let mut req = self
1454 .client
1455 .post(&url)
1456 .timeout(GENERATE_TIMEOUT)
1457 .json(&payload);
1458 if let Some(key) = bearer {
1459 req = req.bearer_auth(key);
1460 }
1461 let resp = match req.send().await {
1462 Ok(r) => r,
1463 Err(e) => {
1464 self.note_failure();
1465 return Err(anyhow::Error::new(e).context("Failed to send chat request"));
1466 }
1467 };
1468
1469 if !resp.status().is_success() {
1470 let status = resp.status();
1471 if status.is_server_error() {
1472 self.note_failure();
1473 }
1474 let text = read_capped_text(resp).await;
1475 return Err(anyhow!("Generate failed ({status}): {text}"));
1476 }
1477
1478 let body: Value = match read_capped_json(resp).await {
1479 Ok(b) => b,
1480 Err(e) => {
1481 self.note_failure();
1482 return Err(e.context("Failed to parse chat response"));
1483 }
1484 };
1485
1486 let response_text = match &self.provider {
1487 LlmProvider::Ollama => body["message"]["content"]
1488 .as_str()
1489 .ok_or_else(|| anyhow!("Missing 'message.content' in chat response"))?
1490 .to_string(),
1491 LlmProvider::OpenAiCompatible { .. } => body["choices"][0]["message"]["content"]
1492 .as_str()
1493 .ok_or_else(|| {
1494 anyhow!(
1495 "Missing 'choices[0].message.content' in OpenAI-compatible \
1496 chat response; got: {body}"
1497 )
1498 })?
1499 .to_string(),
1500 };
1501
1502 self.note_success();
1503 Ok(response_text)
1504 }
1505
1506 fn check_outbound(&self) -> Result<()> {
1524 let url = reqwest::Url::parse(&self.base_url).ok();
1525 let host = url
1526 .as_ref()
1527 .and_then(|u| u.host_str().map(str::to_string))
1528 .unwrap_or_else(|| self.base_url.clone());
1529 let scheme = url
1530 .as_ref()
1531 .map(|u| u.scheme().to_string())
1532 .unwrap_or_default();
1533 let action = crate::governance::agent_action::AgentAction::NetworkRequest {
1534 host: host.clone(),
1535 scheme,
1536 };
1537 crate::governance::wire_check::check_anyhow(&action)
1538 .with_context(|| format!("governance refused outbound to ollama at {host}"))
1539 }
1540
    /// Blocking wrapper around `generate_with_body_async`, driven through
    /// `block_on_local`.
    ///
    /// # Errors
    /// Propagates any error from the async implementation.
    #[allow(dead_code)]
    fn generate_with_body(&self, body: &Value) -> Result<String> {
        block_on_local(|| self.generate_with_body_async(body))
    }
1557
1558 #[allow(dead_code)]
1567 async fn generate_with_body_async(&self, body: &Value) -> Result<String> {
1568 if self.breaker_is_open() {
1569 return Err(anyhow!(
1570 "Failed to send generate request: circuit breaker open \
1571 (last failure within {}s); ollama at {} is not responding",
1572 CIRCUIT_BREAKER_COOLDOWN.as_secs(),
1573 self.base_url,
1574 ));
1575 }
1576 self.check_outbound()?;
1577 let url = format!("{}/api/generate", self.base_url);
1578 let resp = match self
1579 .client
1580 .post(&url)
1581 .timeout(GENERATE_TIMEOUT)
1582 .json(body)
1583 .send()
1584 .await
1585 {
1586 Ok(r) => r,
1587 Err(e) => {
1588 self.note_failure();
1589 return Err(anyhow::Error::new(e).context("Failed to send generate request"));
1590 }
1591 };
1592
1593 if !resp.status().is_success() {
1594 let status = resp.status();
1595 if status.is_server_error() {
1596 self.note_failure();
1597 }
1598 let text = read_capped_text(resp).await;
1599 return Err(anyhow!("Generate failed ({status}): {text}"));
1600 }
1601
1602 let parsed: Value = match read_capped_json(resp).await {
1603 Ok(v) => v,
1604 Err(e) => {
1605 self.note_failure();
1606 return Err(e.context("Failed to parse generate response"));
1607 }
1608 };
1609
1610 let response_text = parsed["response"]
1611 .as_str()
1612 .ok_or_else(|| anyhow!("Missing 'response' field in generate output"))?
1613 .to_string();
1614
1615 self.note_success();
1616 Ok(response_text)
1617 }
1618
    /// Blocking wrapper around [`Self::embed_text_async`], driven through
    /// `block_on_local`.
    ///
    /// # Errors
    /// Propagates any error from the async implementation.
    pub fn embed_text(&self, text: &str, embed_model: &str) -> Result<Vec<f32>> {
        block_on_local(|| self.embed_text_async(text, embed_model))
    }
1629
1630 pub async fn embed_text_async(&self, text: &str, embed_model: &str) -> Result<Vec<f32>> {
1643 if self.breaker_is_open() {
1644 return Err(anyhow!(
1645 "Failed to send embed request: circuit breaker open \
1646 (last failure within {}s); LLM at {} is not responding",
1647 CIRCUIT_BREAKER_COOLDOWN.as_secs(),
1648 self.base_url,
1649 ));
1650 }
1651 self.check_outbound()?;
1652
1653 let (url, payload, bearer): (String, Value, Option<&str>) = match &self.provider {
1654 LlmProvider::Ollama => (
1663 format!("{}/api/embed", self.base_url),
1664 json!({"model": embed_model, "input": text, "truncate": true}),
1665 None,
1666 ),
1667 LlmProvider::OpenAiCompatible { api_key } => (
1674 format!("{}{}", self.base_url, OPENAI_COMPAT_EMBEDDINGS_PATH),
1675 match self.embed_dimensions {
1676 Some(dims) => {
1677 json!({"model": embed_model, "input": text, "dimensions": dims})
1678 }
1679 None => json!({"model": embed_model, "input": text}),
1680 },
1681 Some(api_key.as_str()),
1682 ),
1683 };
1684
1685 let mut req = self
1686 .client
1687 .post(&url)
1688 .timeout(GENERATE_TIMEOUT)
1689 .json(&payload);
1690 if let Some(key) = bearer {
1691 req = req.bearer_auth(key);
1692 }
1693
1694 let resp = match req.send().await {
1695 Ok(r) => r,
1696 Err(e) => {
1697 self.note_failure();
1698 return Err(anyhow::Error::new(e).context("Failed to send embed request"));
1699 }
1700 };
1701
1702 if !resp.status().is_success() {
1703 let status = resp.status();
1704 if status.is_server_error() {
1705 self.note_failure();
1706 }
1707 let text = read_capped_text(resp).await;
1708 return Err(anyhow!("Embed failed ({status}): {text}"));
1709 }
1710
1711 let body: Value = match read_capped_json(resp).await {
1712 Ok(b) => b,
1713 Err(e) => {
1714 self.note_failure();
1715 return Err(e.context("Failed to parse embed response"));
1716 }
1717 };
1718
1719 let embedding_array = match &self.provider {
1720 LlmProvider::Ollama => body["embeddings"]
1721 .as_array()
1722 .and_then(|arr| arr.first())
1723 .and_then(|v| v.as_array())
1724 .ok_or_else(|| anyhow!("Missing 'embeddings[0]' in Ollama embed response"))?,
1725 LlmProvider::OpenAiCompatible { .. } => {
1726 body["data"][0]["embedding"].as_array().ok_or_else(|| {
1727 anyhow!(
1728 "Missing 'data[0].embedding' in OpenAI-compatible embed response; \
1729 got: {body}"
1730 )
1731 })?
1732 }
1733 };
1734
1735 #[allow(clippy::cast_possible_truncation)]
1736 let floats: Vec<f32> = embedding_array
1737 .iter()
1738 .filter_map(|v| v.as_f64().map(|f| f as f32))
1739 .collect();
1740
1741 if floats.is_empty() {
1742 return Err(anyhow!("Empty embedding returned from LLM"));
1743 }
1744
1745 self.note_success();
1746 Ok(floats)
1747 }
1748
1749 pub fn embed_texts(&self, texts: &[&str], embed_model: &str) -> Result<Vec<Vec<f32>>> {
1759 block_on_local(|| self.embed_texts_async(texts, embed_model))
1760 }
1761
1762 pub async fn embed_texts_async(
1791 &self,
1792 texts: &[&str],
1793 embed_model: &str,
1794 ) -> Result<Vec<Vec<f32>>> {
1795 if texts.is_empty() {
1796 return Ok(Vec::new());
1797 }
1798 if matches!(self.provider, LlmProvider::Ollama) {
1799 let mut out = Vec::with_capacity(texts.len());
1800 for t in texts {
1801 out.push(self.embed_text_async(t, embed_model).await?);
1802 }
1803 return Ok(out);
1804 }
1805
1806 let mut out: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
1807 let mut start = 0usize;
1808 while start < texts.len() {
1809 let mut end = start;
1813 let mut bytes = 0usize;
1814 while end < texts.len()
1815 && (end - start) < EMBED_BATCH_MAX_INPUTS
1816 && (end == start || bytes + texts[end].len() <= EMBED_BATCH_MAX_BYTES)
1817 {
1818 bytes += texts[end].len();
1819 end += 1;
1820 }
1821 let chunk = &texts[start..end];
1822 match self.embed_texts_one_request(chunk, embed_model).await {
1823 Ok(vecs) => out.extend(vecs),
1824 Err(batch_err) => {
1825 tracing::warn!(
1829 "batched embed of {} text(s) failed ({batch_err}); \
1830 falling back to per-text requests",
1831 chunk.len()
1832 );
1833 for t in chunk {
1834 out.push(self.embed_text_async(t, embed_model).await?);
1835 }
1836 }
1837 }
1838 start = end;
1839 }
1840 Ok(out)
1841 }
1842
1843 async fn embed_texts_one_request(
1848 &self,
1849 chunk: &[&str],
1850 embed_model: &str,
1851 ) -> Result<Vec<Vec<f32>>> {
1852 if self.breaker_is_open() {
1853 return Err(anyhow!(
1854 "Failed to send embed request: circuit breaker open \
1855 (last failure within {}s); LLM at {} is not responding",
1856 CIRCUIT_BREAKER_COOLDOWN.as_secs(),
1857 self.base_url,
1858 ));
1859 }
1860 self.check_outbound()?;
1861
1862 let LlmProvider::OpenAiCompatible { api_key } = &self.provider else {
1863 return Err(anyhow!(
1864 "embed_texts_one_request requires an OpenAI-compatible provider"
1865 ));
1866 };
1867
1868 let payload = match self.embed_dimensions {
1872 Some(dims) => {
1873 json!({"model": embed_model, "input": chunk, "dimensions": dims})
1874 }
1875 None => json!({"model": embed_model, "input": chunk}),
1876 };
1877
1878 let resp = match self
1879 .client
1880 .post(format!(
1881 "{}{}",
1882 self.base_url, OPENAI_COMPAT_EMBEDDINGS_PATH
1883 ))
1884 .timeout(GENERATE_TIMEOUT)
1885 .json(&payload)
1886 .bearer_auth(api_key)
1887 .send()
1888 .await
1889 {
1890 Ok(r) => r,
1891 Err(e) => {
1892 self.note_failure();
1893 return Err(anyhow::Error::new(e).context("Failed to send embed request"));
1894 }
1895 };
1896
1897 if !resp.status().is_success() {
1898 let status = resp.status();
1899 if status.is_server_error() {
1900 self.note_failure();
1901 }
1902 let text = read_capped_text(resp).await;
1903 return Err(anyhow!("Embed failed ({status}): {text}"));
1904 }
1905
1906 let body: Value = match read_capped_json(resp).await {
1907 Ok(b) => b,
1908 Err(e) => {
1909 self.note_failure();
1910 return Err(e.context("Failed to parse embed response"));
1911 }
1912 };
1913
1914 let parsed = parse_openai_embeddings_batch(&body, chunk.len())?;
1915 self.note_success();
1916 Ok(parsed)
1917 }
1918
1919 pub fn ensure_embed_model(&self, model: &str) -> Result<()> {
1925 block_on_local(|| self.ensure_embed_model_async(model))
1926 }
1927
1928 pub async fn ensure_embed_model_async(&self, model: &str) -> Result<()> {
1936 if matches!(self.provider, LlmProvider::OpenAiCompatible { .. }) {
1937 return Ok(());
1938 }
1939 let url = ollama_tags_url(&self.base_url);
1940 let resp = self
1941 .client
1942 .get(&url)
1943 .timeout(std::time::Duration::from_secs(10))
1944 .send()
1945 .await
1946 .context("Failed to list Ollama models")?;
1947
1948 let body: Value = read_capped_json(resp)
1949 .await
1950 .context("Failed to parse /api/tags response")?;
1951 let model_exists = body["models"].as_array().is_some_and(|models| {
1952 models.iter().any(|m| {
1953 let name = m["name"].as_str().unwrap_or("");
1954 name == model
1955 || name.starts_with(&format!("{model}:"))
1956 || model == name.split(':').next().unwrap_or("")
1957 })
1958 });
1959
1960 if model_exists {
1961 return Ok(());
1962 }
1963
1964 tracing::info!("Pulling Ollama embedding model '{}'...", model);
1965 let pull_url = format!("{}/api/pull", self.base_url);
1966 let pull_client = reqwest::Client::builder()
1967 .timeout(PULL_TIMEOUT)
1968 .build()
1969 .context("Failed to build pull client")?;
1970 let resp = pull_client
1971 .post(&pull_url)
1972 .json(&json!({ "name": model }))
1973 .send()
1974 .await
1975 .context("Failed to pull embedding model from Ollama")?;
1976
1977 if !resp.status().is_success() {
1978 let status = resp.status();
1979 let text = read_capped_text(resp).await;
1980 return Err(anyhow!("Ollama embed model pull failed ({status}): {text}"));
1981 }
1982
1983 tracing::info!("Embedding model '{}' pulled successfully", model);
1984 Ok(())
1985 }
1986
1987 pub fn detect_contradiction(&self, mem_a: &str, mem_b: &str) -> Result<bool> {
1989 block_on_local(|| self.detect_contradiction_async(mem_a, mem_b))
1990 }
1991
1992 pub async fn detect_contradiction_async(&self, mem_a: &str, mem_b: &str) -> Result<bool> {
2000 let prompt = CONTRADICTION_PROMPT
2001 .replace("{a}", mem_a)
2002 .replace("{b}", mem_b);
2003
2004 let response = self.generate_async(&prompt, None).await?;
2005 let answer = response.trim().to_lowercase();
2006
2007 Ok(answer.starts_with("yes"))
2008 }
2009}
2010
#[cfg(test)]
mod tests {
    use super::*;

    /// Each prompt template still contains the placeholders that get substituted.
    #[test]
    fn test_prompt_templates_have_placeholders() {
        let required = [
            (QUERY_EXPANSION_PROMPT, "{query}"),
            (SUMMARIZE_PROMPT, "{memories}"),
            (AUTO_TAG_PROMPT, "{title}"),
            (AUTO_TAG_PROMPT, "{content}"),
            (CONTRADICTION_PROMPT, "{a}"),
            (CONTRADICTION_PROMPT, "{b}"),
        ];
        for (template, placeholder) in required {
            assert!(template.contains(placeholder));
        }
    }

    /// Pins the default Ollama endpoint.
    #[test]
    fn test_default_url() {
        assert_eq!("http://localhost:11434", DEFAULT_OLLAMA_URL);
    }

    /// Batch entries are reordered by their `index`; entries without an
    /// `index` keep their positional order.
    #[test]
    fn parse_openai_embeddings_batch_orders_by_index_1603() {
        let shuffled = serde_json::json!({"data": [
            {"index": 1, "embedding": [2.0, 2.0]},
            {"index": 0, "embedding": [1.0, 1.0]},
        ]});
        let by_index = parse_openai_embeddings_batch(&shuffled, 2).expect("parse");
        assert_eq!(by_index, vec![vec![1.0, 1.0], vec![2.0, 2.0]]);

        let unindexed = serde_json::json!({"data": [
            {"embedding": [1.0]},
            {"embedding": [2.0]},
        ]});
        let positional = parse_openai_embeddings_batch(&unindexed, 2).expect("positional parse");
        assert_eq!(positional, vec![vec![1.0], vec![2.0]]);
    }

    /// Every malformed batch shape listed below must be rejected.
    #[test]
    fn parse_openai_embeddings_batch_rejects_malformed_1603() {
        // (response body, number of inputs requested, assertion label)
        let rejected = [
            (
                serde_json::json!({"data": [{"index": 0, "embedding": [1.0]}]}),
                2,
                "count mismatch",
            ),
            (
                serde_json::json!({"data": [
                    {"index": 0, "embedding": [1.0]},
                    {"index": 0, "embedding": [2.0]},
                ]}),
                2,
                "duplicate index",
            ),
            (
                serde_json::json!({"data": [
                    {"index": 0, "embedding": [1.0]},
                    {"index": 9, "embedding": [2.0]},
                ]}),
                2,
                "out-of-range index",
            ),
            (
                serde_json::json!({"data": [{"index": 0}]}),
                1,
                "missing embedding",
            ),
            (
                serde_json::json!({"data": [{"index": 0, "embedding": []}]}),
                1,
                "empty vector",
            ),
            (serde_json::json!({"object": "list"}), 1, "missing data"),
        ];
        for (body, requested, reason) in &rejected {
            assert!(
                parse_openai_embeddings_batch(body, *requested).is_err(),
                "{reason}"
            );
        }
    }

    /// Pins the default base URL (or its absence) for each provider alias.
    #[test]
    fn default_base_url_for_alias_covers_all_15_aliases_1067() {
        let table: &[(&str, Option<&str>)] = &[
            ("openai", Some("https://api.openai.com/v1")),
            ("xai", Some("https://api.x.ai/v1")),
            ("anthropic", Some("https://api.anthropic.com/v1")),
            (
                "gemini",
                Some("https://generativelanguage.googleapis.com/v1beta/openai"),
            ),
            ("deepseek", Some("https://api.deepseek.com/v1")),
            ("kimi", Some("https://api.moonshot.cn/v1")),
            ("moonshot", Some("https://api.moonshot.cn/v1")),
            (
                "qwen",
                Some("https://dashscope.aliyuncs.com/compatible-mode/v1"),
            ),
            (
                "dashscope",
                Some("https://dashscope.aliyuncs.com/compatible-mode/v1"),
            ),
            ("mistral", Some("https://api.mistral.ai/v1")),
            ("groq", Some("https://api.groq.com/openai/v1")),
            ("together", Some("https://api.together.xyz/v1")),
            ("cerebras", Some("https://api.cerebras.ai/v1")),
            ("openrouter", Some("https://openrouter.ai/api/v1")),
            ("fireworks", Some("https://api.fireworks.ai/inference/v1")),
            ("lmstudio", Some("http://localhost:1234/v1")),
            ("openai-compatible", None),
            ("totally-unknown-vendor", None),
        ];
        for &(alias, expected) in table {
            let got = default_base_url_for_alias(alias);
            assert_eq!(
                got, expected,
                "#1067: alias `{alias}` must resolve to {expected:?}; got {got:?}"
            );
        }
    }

    /// Pins the ordered API-key environment variable list for each alias.
    #[test]
    fn alias_api_key_env_vars_per_alias_pins_1067() {
        let table: &[(&str, &[&str])] = &[
            ("openai", &["OPENAI_API_KEY"]),
            ("xai", &["XAI_API_KEY"]),
            ("anthropic", &["ANTHROPIC_API_KEY"]),
            ("gemini", &["GEMINI_API_KEY", "GOOGLE_API_KEY"]),
            ("deepseek", &["DEEPSEEK_API_KEY"]),
            ("kimi", &["MOONSHOT_API_KEY", "KIMI_API_KEY"]),
            ("moonshot", &["MOONSHOT_API_KEY", "KIMI_API_KEY"]),
            ("qwen", &["DASHSCOPE_API_KEY", "QWEN_API_KEY"]),
            ("dashscope", &["DASHSCOPE_API_KEY", "QWEN_API_KEY"]),
            ("mistral", &["MISTRAL_API_KEY"]),
            ("groq", &["GROQ_API_KEY"]),
            ("together", &["TOGETHER_API_KEY"]),
            ("cerebras", &["CEREBRAS_API_KEY"]),
            ("openrouter", &["OPENROUTER_API_KEY"]),
            ("fireworks", &["FIREWORKS_API_KEY"]),
            (BACKEND_OLLAMA, &[]),
            ("lmstudio", &[]),
            ("openai-compatible", &[]),
            ("totally-unknown-vendor", &[]),
        ];
        for &(alias, expected) in table {
            let got = alias_api_key_env_vars(alias);
            assert_eq!(
                got, expected,
                "#1067: alias `{alias}` env-var preference list must be {expected:?}; got {got:?}"
            );
        }
    }
}
2173
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports,
    clippy::doc_markdown
)]
pub mod test_support {
    use super::*;

    /// Failure modes a [`MockOllamaClient`] can be configured to simulate.
    pub enum MockFailure {
        ModelNotFound,
        Timeout,
        MalformedResponse,
        /// Carries the message embedded in the simulated API error.
        ApiError(String),
        EmptyResponse,
        NetworkError,
    }

    /// In-memory stand-in for the LLM client: every method returns canned
    /// data, or a canned error when `fail_with` is set.
    pub struct MockOllamaClient {
        pub base_url: String,
        pub model: String,
        pub fail_with: Option<MockFailure>,
    }

    impl MockOllamaClient {
        /// Builds a mock that never fails; trailing `/` characters are
        /// stripped from `base_url`.
        pub fn new_with_url(base_url: &str, model: &str) -> Result<Self> {
            Ok(Self {
                base_url: base_url.trim_end_matches('/').to_string(),
                model: model.to_string(),
                fail_with: None,
            })
        }

        /// Builds a mock whose methods report `failure`; trailing `/`
        /// characters are stripped from `base_url`.
        pub fn with_failure(base_url: &str, model: &str, failure: MockFailure) -> Result<Self> {
            Ok(Self {
                base_url: base_url.trim_end_matches('/').to_string(),
                model: model.to_string(),
                fail_with: Some(failure),
            })
        }

        /// The configured failure, if any.
        fn should_fail(&self) -> Option<&MockFailure> {
            self.fail_with.as_ref()
        }

        /// Error shared by `expand_query`, `summarize_memories`, `auto_tag`
        /// and `detect_contradiction` when a failure is configured.
        fn chat_failure(&self) -> Option<anyhow::Error> {
            self.should_fail().map(|failure| match failure {
                MockFailure::Timeout => {
                    anyhow!("Failed to send chat request: operation timed out")
                }
                MockFailure::MalformedResponse => {
                    anyhow!("Failed to parse chat response: invalid JSON")
                }
                MockFailure::ApiError(msg) => anyhow!("Chat generate failed (500): {}", msg),
                MockFailure::ModelNotFound
                | MockFailure::EmptyResponse
                | MockFailure::NetworkError => anyhow!("Generate failed"),
            })
        }

        /// `false` only when the mock simulates a network error.
        pub fn is_available(&self) -> bool {
            !matches!(self.should_fail(), Some(MockFailure::NetworkError))
        }

        /// Succeeds unless the configured failure is one of the four mapped below.
        pub fn ensure_model(&self) -> Result<()> {
            match self.should_fail() {
                Some(MockFailure::ModelNotFound) => Err(anyhow!(
                    "Model 'unknown-model' not found in Ollama registry"
                )),
                Some(MockFailure::Timeout) => {
                    Err(anyhow!("Failed to list Ollama models: operation timed out"))
                }
                Some(MockFailure::ApiError(msg)) => {
                    Err(anyhow!("Ollama pull failed (404): {}", msg))
                }
                Some(MockFailure::NetworkError) => Err(anyhow!(
                    "Failed to pull model from Ollama: connection refused"
                )),
                _ => Ok(()),
            }
        }

        /// Embedding-model variant of [`Self::ensure_model`]; `_model` is ignored.
        pub fn ensure_embed_model(&self, _model: &str) -> Result<()> {
            match self.should_fail() {
                Some(MockFailure::ModelNotFound) => Err(anyhow!("Embedding model not found")),
                Some(MockFailure::Timeout) => {
                    Err(anyhow!("Failed to list Ollama models: operation timed out"))
                }
                Some(MockFailure::ApiError(msg)) => {
                    Err(anyhow!("Ollama embed model pull failed (404): {}", msg))
                }
                Some(MockFailure::NetworkError) => Err(anyhow!(
                    "Failed to pull embedding model from Ollama: connection refused"
                )),
                _ => Ok(()),
            }
        }

        /// Returns a canned reply chosen by keywords found in `prompt`, or the
        /// configured failure. `ModelNotFound` does not affect this method.
        pub fn generate(&self, prompt: &str, _system: Option<&str>) -> Result<String> {
            match self.should_fail() {
                Some(MockFailure::Timeout) => {
                    return Err(anyhow!("Failed to send chat request: operation timed out"));
                }
                Some(MockFailure::MalformedResponse) => {
                    return Err(anyhow!("Failed to parse chat response: invalid JSON"));
                }
                Some(MockFailure::EmptyResponse) => {
                    return Err(anyhow!("Missing 'message.content' field in chat output"));
                }
                Some(MockFailure::ApiError(msg)) => {
                    return Err(anyhow!("Chat generate failed (500): {}", msg));
                }
                Some(MockFailure::NetworkError) => {
                    return Err(anyhow!("Failed to send chat request: connection refused"));
                }
                _ => {}
            }

            if prompt.contains("expand") || prompt.contains("search") {
                Ok("semantic search\nquery terms\nvector retrieval\ninformation retrieval\nsimilarity matching"
                    .to_string())
            } else if prompt.contains("Summarize") {
                Ok("This is a consolidated summary of multiple memories covering key facts and decisions."
                    .to_string())
            } else if prompt.contains("tags") {
                Ok("important\nkey-fact\nstatus-update\ntechnical".to_string())
            } else if prompt.contains("contradict") {
                if prompt.contains("yes") || prompt.contains("true") {
                    Ok("yes".to_string())
                } else {
                    Ok("no".to_string())
                }
            } else {
                // Echo at most the first 50 bytes, backing up to a char
                // boundary: slicing at a fixed byte offset panics when that
                // offset falls inside a multi-byte character.
                let mut cut = prompt.len().min(50);
                while !prompt.is_char_boundary(cut) {
                    cut -= 1;
                }
                Ok("Mock response for: ".to_string() + &prompt[..cut])
            }
        }

        /// Returns five canned expansion terms, two of them derived from `query`.
        pub fn expand_query(&self, query: &str) -> Result<Vec<String>> {
            if let Some(err) = self.chat_failure() {
                return Err(err);
            }
            Ok(vec![
                format!("{}-related", query),
                format!("{}-expanded", query),
                "semantic-search".to_string(),
                "vector-expansion".to_string(),
                "query-variants".to_string(),
            ])
        }

        /// Returns a canned summary naming the number of memories; an empty
        /// list is rejected before the configured failure is consulted.
        pub fn summarize_memories(&self, memories: &[(String, String)]) -> Result<String> {
            if memories.is_empty() {
                return Err(anyhow!("Cannot summarize empty memories list"));
            }
            if let Some(err) = self.chat_failure() {
                return Err(err);
            }
            let count = memories.len();
            Ok(format!(
                "Summary of {count} memories: consolidated facts and key decisions preserved"
            ))
        }

        /// Returns three canned tags; the middle one is built from the first
        /// whitespace-separated word of `title` (`"data"` when there is none).
        pub fn auto_tag(
            &self,
            title: &str,
            _content: &str,
            _model_override: Option<&str>,
        ) -> Result<Vec<String>> {
            if let Some(err) = self.chat_failure() {
                return Err(err);
            }
            let tags: Vec<String> = vec![
                "important".to_string(),
                format!("{}-tag", title.split_whitespace().next().unwrap_or("data")),
                "memory".to_string(),
            ];
            Ok(tags)
        }

        /// Returns a deterministic 768-element vector derived from
        /// `text.len()`, or an error for any configured failure.
        pub fn embed_text(&self, text: &str, _embed_model: &str) -> Result<Vec<f32>> {
            match self.should_fail() {
                Some(MockFailure::Timeout) => {
                    return Err(anyhow!(
                        "Failed to send embed request to Ollama: operation timed out"
                    ));
                }
                Some(MockFailure::MalformedResponse) => {
                    return Err(anyhow!(
                        "Failed to parse Ollama embed response: invalid JSON"
                    ));
                }
                Some(MockFailure::EmptyResponse) => {
                    return Err(anyhow!("Missing embeddings in Ollama response"));
                }
                Some(MockFailure::ApiError(msg)) => {
                    return Err(anyhow!("Ollama embed failed (500): {}", msg));
                }
                Some(MockFailure::NetworkError) => {
                    return Err(anyhow!(
                        "Failed to send embed request to Ollama: connection refused"
                    ));
                }
                Some(MockFailure::ModelNotFound) => {
                    return Err(anyhow!("Ollama embed failed (404): model not found"));
                }
                _ => {}
            }
            let base_val = (text.len() % 10) as f32 / 100.0;
            let embedding: Vec<f32> = (0..768).map(|i| base_val + (i as f32) * 0.0001).collect();
            Ok(embedding)
        }

        /// Keyword heuristic: reports a contradiction when more than one of
        /// the listed keywords occurs in the lower-cased, concatenated inputs.
        pub fn detect_contradiction(&self, mem_a: &str, mem_b: &str) -> Result<bool> {
            if let Some(err) = self.chat_failure() {
                return Err(err);
            }
            let combined = format!("{mem_a} {mem_b}").to_lowercase();
            let contradictory_keywords = &["not", "never", "always", "contradiction", "opposite"];
            let count = contradictory_keywords
                .iter()
                .filter(|&&kw| combined.contains(kw))
                .count();
            Ok(count > 1)
        }
    }
}
2444
2445#[cfg(test)]
2446mod mock_tests {
2447 use super::test_support::MockOllamaClient;
2448 use super::{AUTO_TAG_PROMPT, CONTRADICTION_PROMPT, QUERY_EXPANSION_PROMPT, SUMMARIZE_PROMPT};
2449
2450 #[test]
2451 fn test_mock_new_with_url() {
2452 let client = MockOllamaClient::new_with_url("http://localhost:11434", "test-model");
2453 assert!(client.is_ok());
2454 let client = client.unwrap();
2455 assert_eq!(client.base_url, "http://localhost:11434");
2456 assert_eq!(client.model, "test-model");
2457 }
2458
2459 #[test]
2460 fn test_mock_new_with_url_trailing_slash() {
2461 let client = MockOllamaClient::new_with_url("http://localhost:11434/", "test-model");
2462 assert!(client.is_ok());
2463 let client = client.unwrap();
2464 assert_eq!(client.base_url, "http://localhost:11434");
2465 }
2466
2467 #[test]
2468 fn test_mock_is_available() {
2469 let client =
2470 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2471 assert!(client.is_available());
2472 }
2473
2474 #[test]
2475 fn test_mock_ensure_model() {
2476 let client =
2477 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2478 assert!(client.ensure_model().is_ok());
2479 }
2480
2481 #[test]
2482 fn test_mock_ensure_embed_model() {
2483 let client =
2484 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2485 assert!(client.ensure_embed_model("nomic-embed-text").is_ok());
2486 }
2487
2488 #[test]
2489 fn test_mock_generate_query_expansion() {
2490 let client =
2491 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2492 let prompt = QUERY_EXPANSION_PROMPT.replace("{query}", "search test");
2493 let result = client.generate(&prompt, None);
2494 assert!(result.is_ok());
2495 let response = result.unwrap();
2496 assert!(!response.is_empty());
2497 }
2498
2499 #[test]
2500 fn test_mock_expand_query() {
2501 let client =
2502 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2503 let result = client.expand_query("test query");
2504 assert!(result.is_ok());
2505 let terms = result.unwrap();
2506 assert!(!terms.is_empty());
2507 assert!(terms.len() >= 3);
2508 }
2509
2510 #[test]
2511 fn test_mock_summarize_memories() {
2512 let client =
2513 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2514 let memories = vec![
2515 ("Title 1".to_string(), "Content 1".to_string()),
2516 ("Title 2".to_string(), "Content 2".to_string()),
2517 ];
2518 let result = client.summarize_memories(&memories);
2519 assert!(result.is_ok());
2520 let summary = result.unwrap();
2521 assert!(summary.contains('2'));
2522 }
2523
2524 #[test]
2525 fn test_mock_auto_tag() {
2526 let client =
2527 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2528 let result = client.auto_tag("Test Title", "test content", None);
2529 assert!(result.is_ok());
2530 let tags = result.unwrap();
2531 assert!(!tags.is_empty());
2532 assert!(tags.len() >= 2);
2533 }
2534
2535 #[test]
2536 fn test_mock_embed_text() {
2537 let client =
2538 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2539 let result = client.embed_text("test text", "nomic-embed-text");
2540 assert!(result.is_ok());
2541 let embedding = result.unwrap();
2542 assert_eq!(embedding.len(), 768);
2543 assert!(embedding.iter().all(|&x| x >= 0.0));
2544 }
2545
2546 #[test]
2547 fn test_mock_embed_text_deterministic() {
2548 let client =
2549 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2550 let result1 = client.embed_text("same text", "nomic-embed-text");
2551 let result2 = client.embed_text("same text", "nomic-embed-text");
2552 assert!(result1.is_ok());
2553 assert!(result2.is_ok());
2554 assert_eq!(result1.unwrap(), result2.unwrap());
2555 }
2556
2557 #[test]
2558 fn test_mock_detect_contradiction_true() {
2559 let client =
2560 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2561 let result = client.detect_contradiction(
2562 "The system always works",
2563 "The system never works correctly",
2564 );
2565 assert!(result.is_ok());
2566 let is_contradiction = result.unwrap();
2567 assert!(is_contradiction);
2568 }
2569
2570 #[test]
2571 fn test_mock_detect_contradiction_false() {
2572 let client =
2573 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2574 let result = client.detect_contradiction(
2575 "The memory is about search",
2576 "Additional details about the same search",
2577 );
2578 assert!(result.is_ok());
2579 }
2580
2581 #[test]
2582 fn test_mock_generate_summarize_prompt() {
2583 let client =
2584 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2585 let prompt = SUMMARIZE_PROMPT.replace(
2586 "{memories}",
2587 "--- Memory 1 ---\nTitle: Test\nThis is a test",
2588 );
2589 let result = client.generate(&prompt, None);
2590 assert!(result.is_ok());
2591 let response = result.unwrap();
2592 assert!(response.contains("summary") || response.contains("Summary"));
2593 }
2594
2595 #[test]
2596 fn test_mock_generate_auto_tag_prompt() {
2597 let client =
2598 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2599 let prompt = AUTO_TAG_PROMPT
2600 .replace("{title}", "Important Update")
2601 .replace("{content}", "Some content");
2602 let result = client.generate(&prompt, None);
2603 assert!(result.is_ok());
2604 let response = result.unwrap();
2605 assert!(!response.is_empty());
2606 }
2607
2608 #[test]
2609 fn test_mock_generate_contradiction_prompt() {
2610 let client =
2611 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2612 let prompt = CONTRADICTION_PROMPT
2613 .replace("{a}", "Statement A")
2614 .replace("{b}", "Statement B");
2615 let result = client.generate(&prompt, None);
2616 assert!(result.is_ok());
2617 let response = result.unwrap();
2618 assert!(!response.is_empty());
2619 }
2620
2621 #[test]
2624 fn test_mock_ensure_model_returns_not_found_error() {
2625 let client = MockOllamaClient::with_failure(
2626 "http://localhost:11434",
2627 "unknown-model",
2628 super::test_support::MockFailure::ModelNotFound,
2629 )
2630 .unwrap();
2631 let result = client.ensure_model();
2632 assert!(result.is_err());
2633 let err_msg = result.unwrap_err().to_string();
2634 assert!(err_msg.contains("not found"));
2635 }
2636
2637 #[test]
2638 fn test_mock_ensure_model_returns_timeout_error() {
2639 let client = MockOllamaClient::with_failure(
2640 "http://localhost:11434",
2641 "test-model",
2642 super::test_support::MockFailure::Timeout,
2643 )
2644 .unwrap();
2645 let result = client.ensure_model();
2646 assert!(result.is_err());
2647 let err_msg = result.unwrap_err().to_string();
2648 assert!(err_msg.contains("timed out"));
2649 }
2650
2651 #[test]
2652 fn test_mock_ensure_model_returns_network_error() {
2653 let client = MockOllamaClient::with_failure(
2654 "http://localhost:11434",
2655 "test-model",
2656 super::test_support::MockFailure::NetworkError,
2657 )
2658 .unwrap();
2659 let result = client.ensure_model();
2660 assert!(result.is_err());
2661 let err_msg = result.unwrap_err().to_string();
2662 assert!(err_msg.contains("connection"));
2663 }
2664
2665 #[test]
2666 fn test_mock_ensure_embed_model_returns_not_found_error() {
2667 let client = MockOllamaClient::with_failure(
2668 "http://localhost:11434",
2669 "test-model",
2670 super::test_support::MockFailure::ModelNotFound,
2671 )
2672 .unwrap();
2673 let result = client.ensure_embed_model("unknown-embed-model");
2674 assert!(result.is_err());
2675 }
2676
2677 #[test]
2678 fn test_mock_generate_returns_timeout_error() {
2679 let client = MockOllamaClient::with_failure(
2680 "http://localhost:11434",
2681 "test-model",
2682 super::test_support::MockFailure::Timeout,
2683 )
2684 .unwrap();
2685 let result = client.generate("test prompt", None);
2686 assert!(result.is_err());
2687 let err_msg = result.unwrap_err().to_string();
2688 assert!(err_msg.contains("timed out"));
2689 }
2690
2691 #[test]
2692 fn test_mock_generate_handles_malformed_json() {
2693 let client = MockOllamaClient::with_failure(
2694 "http://localhost:11434",
2695 "test-model",
2696 super::test_support::MockFailure::MalformedResponse,
2697 )
2698 .unwrap();
2699 let result = client.generate("test prompt", None);
2700 assert!(result.is_err());
2701 }
2702
2703 #[test]
2704 fn test_mock_generate_handles_empty_response() {
2705 let client = MockOllamaClient::with_failure(
2706 "http://localhost:11434",
2707 "test-model",
2708 super::test_support::MockFailure::EmptyResponse,
2709 )
2710 .unwrap();
2711 let result = client.generate("test prompt", None);
2712 assert!(result.is_err());
2713 }
2714
2715 #[test]
2716 fn test_mock_generate_handles_api_error() {
2717 let client = MockOllamaClient::with_failure(
2718 "http://localhost:11434",
2719 "test-model",
2720 super::test_support::MockFailure::ApiError("Internal Error".to_string()),
2721 )
2722 .unwrap();
2723 let result = client.generate("test prompt", None);
2724 assert!(result.is_err());
2725 }
2726
2727 #[test]
2728 fn test_mock_expand_query_passes_through_generate_error() {
2729 let client = MockOllamaClient::with_failure(
2730 "http://localhost:11434",
2731 "test-model",
2732 super::test_support::MockFailure::Timeout,
2733 )
2734 .unwrap();
2735 let result = client.expand_query("test query");
2736 assert!(result.is_err());
2737 }
2738
2739 #[test]
2740 fn test_mock_summarize_memories_handles_empty_input() {
2741 let client =
2742 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2743 let empty_memories: Vec<(String, String)> = vec![];
2744 let result = client.summarize_memories(&empty_memories);
2745 assert!(result.is_err());
2746 }
2747
2748 #[test]
2749 fn test_mock_summarize_memories_handles_timeout() {
2750 let client = MockOllamaClient::with_failure(
2751 "http://localhost:11434",
2752 "test-model",
2753 super::test_support::MockFailure::Timeout,
2754 )
2755 .unwrap();
2756 let memories = vec![("Title".to_string(), "Content".to_string())];
2757 let result = client.summarize_memories(&memories);
2758 assert!(result.is_err());
2759 }
2760
2761 #[test]
2762 fn test_mock_auto_tag_handles_special_characters() {
2763 let client =
2764 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2765 let result = client.auto_tag("Title @#$%", "content", None);
2766 assert!(result.is_ok());
2767 }
2768
2769 #[test]
2770 fn test_mock_auto_tag_timeout() {
2771 let client = MockOllamaClient::with_failure(
2772 "http://localhost:11434",
2773 "test-model",
2774 super::test_support::MockFailure::Timeout,
2775 )
2776 .unwrap();
2777 let result = client.auto_tag("Test", "content", None);
2778 assert!(result.is_err());
2779 }
2780
2781 #[test]
2782 fn test_mock_embed_text_returns_768_dim() {
2783 let client =
2784 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2785 let result = client.embed_text("test", "nomic-embed-text-v1.5");
2786 assert!(result.is_ok());
2787 assert_eq!(result.unwrap().len(), 768);
2788 }
2789
2790 #[test]
2791 fn test_mock_embed_text_timeout() {
2792 let client = MockOllamaClient::with_failure(
2793 "http://localhost:11434",
2794 "test-model",
2795 super::test_support::MockFailure::Timeout,
2796 )
2797 .unwrap();
2798 let result = client.embed_text("test", "nomic-embed-text");
2799 assert!(result.is_err());
2800 }
2801
2802 #[test]
2803 fn test_mock_embed_text_malformed() {
2804 let client = MockOllamaClient::with_failure(
2805 "http://localhost:11434",
2806 "test-model",
2807 super::test_support::MockFailure::MalformedResponse,
2808 )
2809 .unwrap();
2810 let result = client.embed_text("test", "nomic-embed-text");
2811 assert!(result.is_err());
2812 }
2813
2814 #[test]
2815 fn test_mock_embed_text_empty_response() {
2816 let client = MockOllamaClient::with_failure(
2817 "http://localhost:11434",
2818 "test-model",
2819 super::test_support::MockFailure::EmptyResponse,
2820 )
2821 .unwrap();
2822 let result = client.embed_text("test", "nomic-embed-text");
2823 assert!(result.is_err());
2824 }
2825
2826 #[test]
2827 fn test_mock_embed_text_model_not_found() {
2828 let client = MockOllamaClient::with_failure(
2829 "http://localhost:11434",
2830 "test-model",
2831 super::test_support::MockFailure::ModelNotFound,
2832 )
2833 .unwrap();
2834 let result = client.embed_text("test", "unknown");
2835 assert!(result.is_err());
2836 }
2837
2838 #[test]
2839 fn test_mock_embed_text_network_error() {
2840 let client = MockOllamaClient::with_failure(
2841 "http://localhost:11434",
2842 "test-model",
2843 super::test_support::MockFailure::NetworkError,
2844 )
2845 .unwrap();
2846 let result = client.embed_text("test", "nomic-embed-text");
2847 assert!(result.is_err());
2848 }
2849
2850 #[test]
2851 fn test_mock_detect_contradiction_yes_case() {
2852 let client =
2853 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2854 let result =
2855 client.detect_contradiction("The system always works", "The system never works");
2856 assert!(result.is_ok());
2857 assert!(result.unwrap());
2858 }
2859
2860 #[test]
2861 fn test_mock_detect_contradiction_no_case() {
2862 let client =
2863 MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
2864 let result =
2865 client.detect_contradiction("Consistent statement A", "Consistent statement B");
2866 assert!(result.is_ok());
2867 }
2868
2869 #[test]
2870 fn test_mock_detect_contradiction_timeout() {
2871 let client = MockOllamaClient::with_failure(
2872 "http://localhost:11434",
2873 "test-model",
2874 super::test_support::MockFailure::Timeout,
2875 )
2876 .unwrap();
2877 let result = client.detect_contradiction("A", "B");
2878 assert!(result.is_err());
2879 }
2880
2881 #[test]
2882 fn test_mock_is_available_network_error() {
2883 let client = MockOllamaClient::with_failure(
2884 "http://localhost:11434",
2885 "test-model",
2886 super::test_support::MockFailure::NetworkError,
2887 )
2888 .unwrap();
2889 assert!(!client.is_available());
2890 }
2891
2892 #[test]
2893 fn test_mock_with_failure_creates_client_that_fails() {
2894 let client = MockOllamaClient::with_failure(
2895 "http://localhost:11434",
2896 "test-model",
2897 super::test_support::MockFailure::Timeout,
2898 )
2899 .unwrap();
2900 let result = client.generate("any", None);
2901 assert!(result.is_err());
2902 }
2903
2904 #[test]
2905 fn test_mock_api_error_variant() {
2906 let client = MockOllamaClient::with_failure(
2907 "http://localhost:11434",
2908 "test-model",
2909 super::test_support::MockFailure::ApiError("Custom msg".to_string()),
2910 )
2911 .unwrap();
2912 let result = client.generate("test", None);
2913 assert!(result.is_err());
2914 assert!(result.unwrap_err().to_string().contains("Custom msg"));
2915 }
2916}
2917
2918#[cfg(test)]
2944#[allow(clippy::too_many_lines, clippy::similar_names)]
2945mod wiremock_tests {
2946 use super::OllamaClient;
2947 use serde_json::json;
2948 use std::net::TcpListener;
2949 use wiremock::matchers::{body_partial_json, method, path};
2950 use wiremock::{Mock, MockServer, ResponseTemplate};
2951
2952 async fn mount_tags_ok(server: &MockServer, models: serde_json::Value) {
2955 Mock::given(method("GET"))
2956 .and(path("/api/tags"))
2957 .respond_with(ResponseTemplate::new(200).set_body_json(models))
2958 .mount(server)
2959 .await;
2960 }
2961
2962 #[tokio::test(flavor = "multi_thread")]
2967 async fn read_capped_bytes_rejects_oversize_1459() {
2968 use super::read_capped_bytes_inner;
2969 let server = MockServer::start().await;
2970 Mock::given(method("GET"))
2971 .and(path("/big"))
2972 .respond_with(ResponseTemplate::new(200).set_body_string("x".repeat(4096)))
2973 .mount(&server)
2974 .await;
2975 let url = format!("{}/big", server.uri());
2976 let resp = reqwest::Client::new().get(&url).send().await.unwrap();
2977 let err = read_capped_bytes_inner(resp, 64)
2978 .await
2979 .expect_err("oversize body MUST be rejected by the cap");
2980 let msg = err.to_string();
2981 assert!(
2982 msg.contains("exceeds cap") || msg.contains("exceeded cap"),
2983 "rejection must name the cap: {msg}"
2984 );
2985 }
2986
2987 #[tokio::test(flavor = "multi_thread")]
2989 async fn read_capped_json_parses_small_body_1459() {
2990 use super::read_capped_json;
2991 let server = MockServer::start().await;
2992 Mock::given(method("GET"))
2993 .and(path("/ok"))
2994 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"hello": "world"})))
2995 .mount(&server)
2996 .await;
2997 let url = format!("{}/ok", server.uri());
2998 let resp = reqwest::Client::new().get(&url).send().await.unwrap();
2999 let v = read_capped_json(resp).await.unwrap();
3000 assert_eq!(v["hello"], "world");
3001 }
3002
3003 #[tokio::test(flavor = "multi_thread")]
3006 async fn perf_12_new_with_url_no_health_check_skips_probe() {
3007 let url = tokio::task::spawn_blocking(|| {
3019 let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
3020 let port = listener.local_addr().unwrap().port();
3021 drop(listener);
3022 format!("http://127.0.0.1:{port}")
3023 })
3024 .await
3025 .unwrap();
3026
3027 let (constructed_ok, is_available_after) = tokio::task::spawn_blocking(move || {
3028 let client = OllamaClient::new_with_url_no_health_check(&url, "test-model")
3031 .expect("PERF-12: new_with_url_no_health_check must not probe");
3032 let avail = client.is_available();
3035 (true, avail)
3036 })
3037 .await
3038 .unwrap();
3039
3040 assert!(constructed_ok);
3041 assert!(
3042 !is_available_after,
3043 "PERF-12: lazy is_available() must return false for an unreachable endpoint",
3044 );
3045 }
3046
3047 #[tokio::test(flavor = "multi_thread")]
3050 async fn test_is_available_returns_false_on_connection_refused() {
3051 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
3055 let port = listener.local_addr().unwrap().port();
3056 drop(listener);
3057 let url = format!("http://127.0.0.1:{port}");
3058
3059 let result = tokio::task::spawn_blocking(move || {
3064 let client = reqwest::blocking::Client::builder()
3067 .timeout(std::time::Duration::from_secs(5))
3068 .build()
3069 .unwrap();
3070 let probe = format!("{url}/api/tags");
3071 client
3072 .get(&probe)
3073 .send()
3074 .is_ok_and(|r| r.status().is_success())
3075 })
3076 .await
3077 .unwrap();
3078
3079 assert!(
3080 !result,
3081 "is_available should return false when nothing is listening"
3082 );
3083 }
3084
3085 #[tokio::test(flavor = "multi_thread")]
3086 async fn test_is_available_returns_false_on_500_response() {
3087 let server = MockServer::start().await;
3088 Mock::given(method("GET"))
3089 .and(path("/api/tags"))
3090 .respond_with(ResponseTemplate::new(500))
3091 .mount(&server)
3092 .await;
3093
3094 let uri = server.uri();
3095 let result = tokio::task::spawn_blocking(move || {
3096 OllamaClient::new_with_url(&uri, "test-model")
3099 })
3100 .await
3101 .unwrap();
3102
3103 let err = match result {
3106 Ok(_) => panic!("client construction should fail on 500"),
3107 Err(e) => e.to_string(),
3108 };
3109 assert!(
3110 err.contains("not running") || err.contains("not reachable"),
3111 "expected unreachable-style error, got: {err}"
3112 );
3113 }
3114
3115 #[tokio::test(flavor = "multi_thread")]
3116 async fn test_is_available_returns_true_on_200_with_json_body() {
3117 let server = MockServer::start().await;
3118 mount_tags_ok(&server, json!({"models": []})).await;
3119
3120 let uri = server.uri();
3121 let available = tokio::task::spawn_blocking(move || {
3122 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3123 client.is_available()
3124 })
3125 .await
3126 .unwrap();
3127 assert!(available);
3128 }
3129
3130 #[tokio::test(flavor = "multi_thread")]
3133 async fn test_pull_if_missing_skips_pull_if_model_already_in_tags() {
3134 let server = MockServer::start().await;
3135 Mock::given(method("GET"))
3137 .and(path("/api/tags"))
3138 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3139 "models": [
3140 {"name": "test-model:latest"},
3141 ]
3142 })))
3143 .mount(&server)
3144 .await;
3145
3146 Mock::given(method("POST"))
3150 .and(path("/api/pull"))
3151 .respond_with(ResponseTemplate::new(200))
3152 .expect(0)
3153 .mount(&server)
3154 .await;
3155
3156 let uri = server.uri();
3157 let result = tokio::task::spawn_blocking(move || {
3158 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3159 client.ensure_model()
3160 })
3161 .await
3162 .unwrap();
3163 assert!(
3164 result.is_ok(),
3165 "ensure_model should succeed; got {result:?}"
3166 );
3167 }
3168
3169 #[tokio::test(flavor = "multi_thread")]
3170 async fn test_pull_if_missing_initiates_pull_if_not() {
3171 let server = MockServer::start().await;
3172 Mock::given(method("GET"))
3174 .and(path("/api/tags"))
3175 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
3176 .mount(&server)
3177 .await;
3178 Mock::given(method("POST"))
3180 .and(path("/api/pull"))
3181 .and(body_partial_json(json!({"name": "test-model"})))
3182 .respond_with(ResponseTemplate::new(200).set_body_string(""))
3183 .expect(1)
3184 .mount(&server)
3185 .await;
3186
3187 let uri = server.uri();
3188 let result = tokio::task::spawn_blocking(move || {
3189 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3190 client.ensure_model()
3191 })
3192 .await
3193 .unwrap();
3194 assert!(
3195 result.is_ok(),
3196 "ensure_model should succeed; got {result:?}"
3197 );
3198 }
3200
3201 #[tokio::test(flavor = "multi_thread")]
3204 async fn test_generate_parses_success_response() {
3205 let server = MockServer::start().await;
3206 mount_tags_ok(&server, json!({"models": []})).await;
3207 Mock::given(method("POST"))
3210 .and(path("/api/chat"))
3211 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3212 "message": {"role": "assistant", "content": "hello"},
3213 "done": true,
3214 })))
3215 .mount(&server)
3216 .await;
3217
3218 let uri = server.uri();
3219 let result = tokio::task::spawn_blocking(move || {
3220 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3221 client.generate("ping", None)
3222 })
3223 .await
3224 .unwrap();
3225
3226 assert_eq!(result.unwrap(), "hello");
3227 }
3228
3229 #[tokio::test(flavor = "multi_thread")]
3230 async fn test_generate_returns_error_on_malformed_json() {
3231 let server = MockServer::start().await;
3232 mount_tags_ok(&server, json!({"models": []})).await;
3233 Mock::given(method("POST"))
3234 .and(path("/api/chat"))
3235 .respond_with(
3236 ResponseTemplate::new(200)
3237 .set_body_string("{not valid json")
3238 .insert_header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON),
3239 )
3240 .mount(&server)
3241 .await;
3242
3243 let uri = server.uri();
3244 let result = tokio::task::spawn_blocking(move || {
3245 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3246 client.generate("ping", None)
3247 })
3248 .await
3249 .unwrap();
3250
3251 assert!(result.is_err(), "malformed JSON should surface an error");
3252 let err = result.unwrap_err().to_string();
3253 assert!(
3254 err.contains("parse") || err.to_lowercase().contains("json"),
3255 "expected a parse error, got: {err}"
3256 );
3257 }
3258
3259 #[tokio::test(flavor = "multi_thread")]
3260 async fn test_generate_returns_error_on_500() {
3261 let server = MockServer::start().await;
3262 mount_tags_ok(&server, json!({"models": []})).await;
3263 Mock::given(method("POST"))
3264 .and(path("/api/chat"))
3265 .respond_with(ResponseTemplate::new(500).set_body_string("internal boom"))
3266 .mount(&server)
3267 .await;
3268
3269 let uri = server.uri();
3270 let result = tokio::task::spawn_blocking(move || {
3271 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3272 client.generate("ping", None)
3273 })
3274 .await
3275 .unwrap();
3276
3277 assert!(result.is_err());
3278 let err = result.unwrap_err().to_string();
3279 assert!(err.contains("500") || err.contains("Chat generate failed"));
3280 }
3281
3282 #[tokio::test(flavor = "multi_thread")]
3283 async fn test_generate_passes_system_prompt_when_provided() {
3284 let server = MockServer::start().await;
3288 mount_tags_ok(&server, json!({"models": []})).await;
3289 Mock::given(method("POST"))
3290 .and(path("/api/chat"))
3291 .and(body_partial_json(json!({
3292 "messages": [
3293 {"role": "system", "content": "be terse"},
3294 {"role": "user", "content": "hi"},
3295 ],
3296 "stream": false,
3297 })))
3298 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3299 "message": {"role": "assistant", "content": "ok"},
3300 })))
3301 .mount(&server)
3302 .await;
3303
3304 let uri = server.uri();
3305 let out = tokio::task::spawn_blocking(move || {
3306 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3307 client.generate("hi", Some("be terse"))
3308 })
3309 .await
3310 .unwrap();
3311 assert_eq!(out.unwrap(), "ok");
3312 }
3313
3314 #[tokio::test(flavor = "multi_thread")]
3317 async fn test_embed_parses_embedding_array() {
3318 let server = MockServer::start().await;
3319 mount_tags_ok(&server, json!({"models": []})).await;
3320 Mock::given(method("POST"))
3322 .and(path("/api/embed"))
3323 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3324 "embeddings": [[0.1_f32, 0.2_f32, 0.3_f32]],
3325 })))
3326 .mount(&server)
3327 .await;
3328
3329 let uri = server.uri();
3330 let vec = tokio::task::spawn_blocking(move || {
3331 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3332 client.embed_text("hello", "nomic-embed-text-v1.5")
3333 })
3334 .await
3335 .unwrap();
3336
3337 let v = vec.unwrap();
3338 assert_eq!(v.len(), 3);
3339 assert!((v[0] - 0.1_f32).abs() < 1e-5);
3340 assert!((v[1] - 0.2_f32).abs() < 1e-5);
3341 assert!((v[2] - 0.3_f32).abs() < 1e-5);
3342 }
3343
3344 #[tokio::test(flavor = "multi_thread")]
3351 async fn ollama_embed_payload_sets_truncate_1595() {
3352 let server = MockServer::start().await;
3353 mount_tags_ok(&server, json!({"models": []})).await;
3354 Mock::given(method("POST"))
3355 .and(path("/api/embed"))
3356 .and(body_partial_json(json!({
3357 "model": "nomic-embed-text",
3358 "input": "hello",
3359 "truncate": true,
3360 })))
3361 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3362 "embeddings": [[0.5_f32, 0.25_f32]],
3363 })))
3364 .mount(&server)
3365 .await;
3366
3367 let uri = server.uri();
3368 let vec = tokio::task::spawn_blocking(move || {
3369 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3370 client.embed_text("hello", "nomic-embed-text")
3371 })
3372 .await
3373 .unwrap();
3374 assert_eq!(vec.unwrap().len(), 2);
3375 }
3376
3377 #[tokio::test(flavor = "multi_thread")]
3381 async fn openai_embed_payload_omits_truncate_1595() {
3382 let server = MockServer::start().await;
3383 Mock::given(method("POST"))
3384 .and(path("/embeddings"))
3385 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3386 "data": [{"embedding": [0.5_f32, 0.25_f32]}],
3387 })))
3388 .mount(&server)
3389 .await;
3390
3391 let uri = server.uri();
3392 let vec = tokio::task::spawn_blocking(move || {
3393 let client =
3394 OllamaClient::new_openai_compatible(&uri, "test-model", "fake-key").unwrap();
3395 client.embed_text("hello", "test-model")
3396 })
3397 .await
3398 .unwrap();
3399 assert_eq!(vec.unwrap().len(), 2);
3400
3401 let requests = server
3402 .received_requests()
3403 .await
3404 .expect("request recording enabled");
3405 let embed_req = requests
3406 .iter()
3407 .find(|r| r.url.path() == "/embeddings")
3408 .expect("embed request recorded");
3409 let body: serde_json::Value = serde_json::from_slice(&embed_req.body).expect("json body");
3410 assert!(
3411 body.get("truncate").is_none(),
3412 "OpenAI-compatible embed payload must not carry the \
3413 Ollama-native truncate key, got: {body}"
3414 );
3415 }
3416
3417 #[tokio::test(flavor = "multi_thread")]
3418 async fn test_embed_returns_error_on_wrong_shape() {
3419 let server = MockServer::start().await;
3420 mount_tags_ok(&server, json!({"models": []})).await;
3421 Mock::given(method("POST"))
3424 .and(path("/api/embed"))
3425 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3426 "embedding": 0.5,
3427 })))
3428 .mount(&server)
3429 .await;
3430
3431 let uri = server.uri();
3432 let result = tokio::task::spawn_blocking(move || {
3433 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3434 client.embed_text("hi", "nomic-embed-text")
3435 })
3436 .await
3437 .unwrap();
3438 assert!(result.is_err());
3439 let err = result.unwrap_err().to_string();
3440 assert!(
3441 err.contains("Missing embeddings") || err.to_lowercase().contains("embed"),
3442 "expected missing-embeddings error, got: {err}"
3443 );
3444 }
3445
3446 #[tokio::test(flavor = "multi_thread")]
3447 async fn test_embed_returns_error_on_500() {
3448 let server = MockServer::start().await;
3449 mount_tags_ok(&server, json!({"models": []})).await;
3450 Mock::given(method("POST"))
3451 .and(path("/api/embed"))
3452 .respond_with(ResponseTemplate::new(500).set_body_string("nope"))
3453 .mount(&server)
3454 .await;
3455
3456 let uri = server.uri();
3457 let result = tokio::task::spawn_blocking(move || {
3458 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3459 client.embed_text("hi", "nomic-embed-text")
3460 })
3461 .await
3462 .unwrap();
3463 assert!(result.is_err());
3464 assert!(result.unwrap_err().to_string().contains("500"));
3465 }
3466
3467 #[tokio::test(flavor = "multi_thread")]
3470 async fn test_expand_query_returns_parsed_terms_one_per_line() {
3471 let server = MockServer::start().await;
3472 mount_tags_ok(&server, json!({"models": []})).await;
3473 Mock::given(method("POST"))
3474 .and(path("/api/chat"))
3475 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3476 "message": {"content": "term1\nterm2\nterm3\n\n"},
3478 })))
3479 .mount(&server)
3480 .await;
3481
3482 let uri = server.uri();
3483 let terms = tokio::task::spawn_blocking(move || {
3484 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3485 client.expand_query("anything")
3486 })
3487 .await
3488 .unwrap();
3489 assert_eq!(
3490 terms.unwrap(),
3491 vec![
3492 "term1".to_string(),
3493 "term2".to_string(),
3494 "term3".to_string()
3495 ]
3496 );
3497 }
3498
3499 #[tokio::test(flavor = "multi_thread")]
3500 async fn test_auto_tag_returns_parsed_tags() {
3501 let server = MockServer::start().await;
3502 mount_tags_ok(&server, json!({"models": []})).await;
3503 Mock::given(method("POST"))
3511 .and(path("/api/chat"))
3512 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3513 "message": {"content": "Tag1\nTAG2\ntag3"},
3514 })))
3515 .mount(&server)
3516 .await;
3517
3518 let uri = server.uri();
3519 let tags = tokio::task::spawn_blocking(move || {
3520 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3521 client.auto_tag("Title", "content", None)
3522 })
3523 .await
3524 .unwrap();
3525 assert_eq!(
3526 tags.unwrap(),
3527 vec!["tag1".to_string(), "tag2".to_string(), "tag3".to_string()]
3528 );
3529 }
3530
3531 #[tokio::test(flavor = "multi_thread")]
3532 async fn test_detect_contradiction_parses_yes_no() {
3533 let server = MockServer::start().await;
3537 mount_tags_ok(&server, json!({"models": []})).await;
3538 Mock::given(method("POST"))
3539 .and(path("/api/chat"))
3540 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3541 "message": {"content": "yes\n"},
3542 })))
3543 .mount(&server)
3544 .await;
3545
3546 let uri_yes = server.uri();
3547 let yes = tokio::task::spawn_blocking(move || {
3548 let client = OllamaClient::new_with_url(&uri_yes, "test-model").unwrap();
3549 client.detect_contradiction("a", "b")
3550 })
3551 .await
3552 .unwrap();
3553 assert!(yes.unwrap(), "'yes' should be detected as contradiction");
3554
3555 let server_no = MockServer::start().await;
3558 mount_tags_ok(&server_no, json!({"models": []})).await;
3559 Mock::given(method("POST"))
3560 .and(path("/api/chat"))
3561 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3562 "message": {"content": "no"},
3563 })))
3564 .mount(&server_no)
3565 .await;
3566 let uri_no = server_no.uri();
3567 let no = tokio::task::spawn_blocking(move || {
3568 let client = OllamaClient::new_with_url(&uri_no, "test-model").unwrap();
3569 client.detect_contradiction("a", "b")
3570 })
3571 .await
3572 .unwrap();
3573 assert!(!no.unwrap(), "'no' should NOT be detected as contradiction");
3574
3575 let server_garbage = MockServer::start().await;
3577 mount_tags_ok(&server_garbage, json!({"models": []})).await;
3578 Mock::given(method("POST"))
3579 .and(path("/api/chat"))
3580 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3581 "message": {"content": "definitely-not-yes-or-no"},
3582 })))
3583 .mount(&server_garbage)
3584 .await;
3585 let uri_g = server_garbage.uri();
3586 let garbage = tokio::task::spawn_blocking(move || {
3587 let client = OllamaClient::new_with_url(&uri_g, "test-model").unwrap();
3588 client.detect_contradiction("a", "b")
3589 })
3590 .await
3591 .unwrap();
3592 assert!(
3593 !garbage.unwrap(),
3594 "garbage answer should default to non-contradiction"
3595 );
3596 }
3597
3598 #[tokio::test(flavor = "multi_thread")]
3601 async fn test_ensure_embed_model_skips_pull_if_present() {
3602 let server = MockServer::start().await;
3603 Mock::given(method("GET"))
3604 .and(path("/api/tags"))
3605 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3606 "models": [{"name": "nomic-embed-text:latest"}]
3607 })))
3608 .mount(&server)
3609 .await;
3610 Mock::given(method("POST"))
3611 .and(path("/api/pull"))
3612 .respond_with(ResponseTemplate::new(200))
3613 .expect(0)
3614 .mount(&server)
3615 .await;
3616
3617 let uri = server.uri();
3618 let r = tokio::task::spawn_blocking(move || {
3619 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3620 client.ensure_embed_model("nomic-embed-text")
3621 })
3622 .await
3623 .unwrap();
3624 assert!(r.is_ok());
3625 }
3626
3627 #[tokio::test(flavor = "multi_thread")]
3636 async fn auto_tag_model_override_takes_precedence_l15() {
3637 let server = MockServer::start().await;
3638 mount_tags_ok(&server, json!({"models": []})).await;
3639 Mock::given(method("POST"))
3644 .and(path("/api/chat"))
3645 .and(body_partial_json(json!({"model": "gemma3:4b"})))
3646 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3647 "message": {"content": "alpha\nbeta\ngamma"},
3648 })))
3649 .expect(1)
3650 .mount(&server)
3651 .await;
3652
3653 let uri = server.uri();
3654 let tags = tokio::task::spawn_blocking(move || {
3655 let client = OllamaClient::new_with_url(&uri, "gemma4:e2b").unwrap();
3658 client.auto_tag("Title", "content", Some("gemma3:4b"))
3659 })
3660 .await
3661 .unwrap();
3662 let tags = tags.expect("auto_tag with override should succeed");
3663 assert_eq!(
3664 tags,
3665 vec!["alpha".to_string(), "beta".to_string(), "gamma".to_string()]
3666 );
3667 }
3668
3669 #[tokio::test(flavor = "multi_thread")]
3677 async fn auto_tag_chat_shape_post_1067() {
3678 let server = MockServer::start().await;
3679 mount_tags_ok(&server, json!({"models": []})).await;
3680 Mock::given(method("POST"))
3681 .and(path("/api/chat"))
3682 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
3683 "message": {"content": "one\ntwo"},
3684 })))
3685 .expect(1)
3686 .mount(&server)
3687 .await;
3688
3689 let uri = server.uri();
3690 let tags = tokio::task::spawn_blocking(move || {
3691 let client = OllamaClient::new_with_url(&uri, "any-model").unwrap();
3692 client.auto_tag("Title", "content", None)
3693 })
3694 .await
3695 .unwrap();
3696 let tags = tags.expect("auto_tag should succeed");
3697 assert_eq!(tags, vec!["one".to_string(), "two".to_string()]);
3698 }
3699
3700 pub(super) static ENV_GUARD_1143: std::sync::Mutex<()> = std::sync::Mutex::new(());
3715
3716 pub(super) fn lock_env_1143() -> std::sync::MutexGuard<'static, ()> {
3717 ENV_GUARD_1143
3718 .lock()
3719 .unwrap_or_else(std::sync::PoisonError::into_inner)
3720 }
3721
3722 pub(super) fn clear_llm_env_1143() {
3726 for k in [
3727 "AI_MEMORY_LLM_BACKEND",
3728 "AI_MEMORY_LLM_MODEL",
3729 "AI_MEMORY_LLM_BASE_URL",
3730 "AI_MEMORY_LLM_API_KEY",
3731 "OLLAMA_BASE_URL",
3732 "XAI_API_KEY",
3733 "OPENAI_API_KEY",
3734 "ANTHROPIC_API_KEY",
3735 "GEMINI_API_KEY",
3736 "GOOGLE_API_KEY",
3737 ] {
3738 unsafe { std::env::remove_var(k) };
3739 }
3740 }
3741
3742 #[test]
3743 fn is_ollama_native_true_for_ollama_client_1143() {
3744 let client = OllamaClient::new_for_testing("gemma4:e4b");
3747 assert!(
3748 client.is_ollama_native(),
3749 "#1143: Ollama-provider client must report is_ollama_native()=true"
3750 );
3751 }
3752
3753 #[test]
3754 fn is_ollama_native_false_for_openai_compatible_1143() {
3755 let client =
3760 OllamaClient::new_openai_compatible("https://api.x.ai/v1", "grok-4.3", "fake-key")
3761 .expect("openai-compatible client builds");
3762 assert!(
3763 !client.is_ollama_native(),
3764 "#1143: OpenAI-compatible client must report is_ollama_native()=false"
3765 );
3766 }
3767
3768 #[tokio::test(flavor = "multi_thread")]
3769 async fn build_for_init_legacy_arm_when_env_unset_1143() {
3770 let _g = lock_env_1143();
3771 clear_llm_env_1143();
3772
3773 let server = MockServer::start().await;
3774 mount_tags_ok(&server, json!({"models": []})).await;
3775 let uri = server.uri();
3776
3777 let result =
3780 tokio::task::spawn_blocking(move || OllamaClient::build_for_init(&uri, "gemma4:e4b"))
3781 .await
3782 .unwrap();
3783
3784 let client = match result {
3785 Ok(Some(c)) => c,
3786 Ok(None) => panic!("#1143: legacy arm must yield Ok(Some(client)); got Ok(None)"),
3787 Err(e) => panic!("#1143: legacy arm must yield Ok(Some(client)); got Err({e})"),
3788 };
3789 assert!(
3790 client.is_ollama_native(),
3791 "#1143: legacy arm constructs an Ollama-provider client"
3792 );
3793 assert_eq!(client.model, "gemma4:e4b");
3794 }
3795
3796 #[tokio::test(flavor = "multi_thread")]
3797 async fn build_for_init_env_arm_routes_to_from_env_1143() {
3798 let _g = lock_env_1143();
3799 clear_llm_env_1143();
3800
3801 unsafe { std::env::set_var("AI_MEMORY_LLM_BACKEND", "xai") };
3805 unsafe { std::env::set_var("AI_MEMORY_LLM_API_KEY", "fake-xai-key") };
3806 unsafe { std::env::set_var("AI_MEMORY_LLM_MODEL", "grok-4.3") };
3807
3808 let result = tokio::task::spawn_blocking(|| {
3812 OllamaClient::build_for_init("http://127.0.0.1:1", "ignored-legacy-model")
3813 })
3814 .await
3815 .unwrap();
3816
3817 clear_llm_env_1143();
3818
3819 let client = match result {
3820 Ok(Some(c)) => c,
3821 Ok(None) => panic!(
3822 "#1143: env arm with AI_MEMORY_LLM_BACKEND=xai must yield \
3823 Ok(Some(client)); got Ok(None)"
3824 ),
3825 Err(e) => panic!(
3826 "#1143: env arm with AI_MEMORY_LLM_BACKEND=xai must yield \
3827 Ok(Some(client)); got Err({e})"
3828 ),
3829 };
3830 assert!(
3831 !client.is_ollama_native(),
3832 "#1143: xai backend yields an OpenAI-compatible (non-Ollama) client"
3833 );
3834 assert_eq!(
3835 client.model, "grok-4.3",
3836 "#1143: AI_MEMORY_LLM_MODEL must override the legacy model arg"
3837 );
3838 assert_eq!(
3839 client.base_url, "https://api.x.ai/v1",
3840 "#1143: xai default base URL must override the legacy URL arg"
3841 );
3842 }
3843
3844 #[tokio::test(flavor = "multi_thread")]
3845 async fn build_for_init_env_arm_unknown_alias_errors_1143() {
3846 let _g = lock_env_1143();
3847 clear_llm_env_1143();
3848 unsafe { std::env::set_var("AI_MEMORY_LLM_BACKEND", "totally-bogus-vendor") };
3849
3850 let result = tokio::task::spawn_blocking(|| {
3851 OllamaClient::build_for_init("http://127.0.0.1:1", "ignored")
3852 })
3853 .await
3854 .unwrap();
3855
3856 clear_llm_env_1143();
3857 assert!(
3858 result.is_err(),
3859 "#1143: unknown backend alias must surface the error \
3860 instead of silently falling through to the legacy arm"
3861 );
3862 }
3863
3864 #[tokio::test(flavor = "multi_thread")]
3865 async fn build_for_init_env_arm_empty_string_falls_back_to_legacy_1143() {
3866 let _g = lock_env_1143();
3867 clear_llm_env_1143();
3868 unsafe { std::env::set_var("AI_MEMORY_LLM_BACKEND", " ") };
3873
3874 let server = MockServer::start().await;
3875 mount_tags_ok(&server, json!({"models": []})).await;
3876 let uri = server.uri();
3877
3878 let result =
3879 tokio::task::spawn_blocking(move || OllamaClient::build_for_init(&uri, "gemma4:e2b"))
3880 .await
3881 .unwrap();
3882
3883 clear_llm_env_1143();
3884 let client = result
3885 .expect("legacy arm should not error on whitespace env")
3886 .expect("legacy arm yields Some(client)");
3887 assert!(client.is_ollama_native());
3888 assert_eq!(client.model, "gemma4:e2b");
3889 }
3890}
3891
3892#[cfg(test)]
3905#[allow(clippy::too_many_lines)]
3906mod c5_breaker_tests {
3907 use super::OllamaClient;
3908 use serde_json::json;
3909 use wiremock::matchers::{method, path};
3910 use wiremock::{Mock, MockServer, ResponseTemplate};
3911
3912 async fn mount_tags_ok(server: &MockServer) {
3913 Mock::given(method("GET"))
3914 .and(path("/api/tags"))
3915 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
3916 .mount(server)
3917 .await;
3918 }
3919
3920 #[tokio::test(flavor = "multi_thread")]
3923 async fn generate_fast_fails_after_breaker_trips() {
3924 let server = MockServer::start().await;
3925 mount_tags_ok(&server).await;
3926 Mock::given(method("POST"))
3927 .and(path("/api/chat"))
3928 .respond_with(ResponseTemplate::new(500).set_body_string("upstream sick"))
3929 .mount(&server)
3930 .await;
3931
3932 let uri = server.uri();
3933 let outcome = tokio::task::spawn_blocking(move || {
3934 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3935 assert!(
3937 !client.circuit_breaker_open(),
3938 "breaker open before any failure"
3939 );
3940
3941 for _ in 0..super::CIRCUIT_BREAKER_THRESHOLD {
3943 let _ = client.generate("ping", None); }
3945 assert!(
3946 client.circuit_breaker_open(),
3947 "breaker should be open after {} consecutive 5xx",
3948 super::CIRCUIT_BREAKER_THRESHOLD
3949 );
3950
3951 let err = client
3953 .generate("ping", None)
3954 .expect_err("breaker-open path must Err");
3955 err.to_string()
3956 })
3957 .await
3958 .unwrap();
3959 assert!(
3960 outcome.contains("circuit breaker open"),
3961 "expected breaker-open envelope, got: {outcome}"
3962 );
3963 }
3964
3965 #[tokio::test(flavor = "multi_thread")]
3967 async fn embed_text_fast_fails_after_breaker_trips() {
3968 let server = MockServer::start().await;
3969 mount_tags_ok(&server).await;
3970 Mock::given(method("POST"))
3973 .and(path("/api/chat"))
3974 .respond_with(ResponseTemplate::new(500))
3975 .mount(&server)
3976 .await;
3977
3978 let uri = server.uri();
3979 let outcome = tokio::task::spawn_blocking(move || {
3980 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
3981 for _ in 0..super::CIRCUIT_BREAKER_THRESHOLD {
3982 let _ = client.generate("ping", None);
3983 }
3984 assert!(client.circuit_breaker_open());
3985 client
3987 .embed_text("hello", "nomic-embed-text")
3988 .expect_err("embed_text must fast-fail when breaker open")
3989 .to_string()
3990 })
3991 .await
3992 .unwrap();
3993 assert!(
3994 outcome.contains("circuit breaker open"),
3995 "expected breaker-open envelope on embed_text, got: {outcome}"
3996 );
3997 }
3998
3999 #[tokio::test(flavor = "multi_thread")]
4002 async fn circuit_breaker_open_starts_closed() {
4003 let server = MockServer::start().await;
4004 mount_tags_ok(&server).await;
4005 let uri = server.uri();
4006 let closed = tokio::task::spawn_blocking(move || {
4007 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
4008 client.circuit_breaker_open()
4009 })
4010 .await
4011 .unwrap();
4012 assert!(
4013 !closed,
4014 "freshly-constructed client must have closed breaker"
4015 );
4016 }
4017
4018 #[tokio::test(flavor = "multi_thread")]
4024 async fn breaker_stays_closed_under_threshold() {
4025 let server = MockServer::start().await;
4026 mount_tags_ok(&server).await;
4027 Mock::given(method("POST"))
4028 .and(path("/api/chat"))
4029 .respond_with(ResponseTemplate::new(500))
4030 .mount(&server)
4031 .await;
4032 let uri = server.uri();
4033 let still_closed = tokio::task::spawn_blocking(move || {
4034 let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
4035 for _ in 0..(super::CIRCUIT_BREAKER_THRESHOLD - 1) {
4037 let _ = client.generate("ping", None);
4038 }
4039 client.circuit_breaker_open()
4040 })
4041 .await
4042 .unwrap();
4043 assert!(
4044 !still_closed,
4045 "breaker must stay closed strictly below the threshold"
4046 );
4047 }
4048}
4049
4050#[cfg(test)]
4062#[allow(clippy::too_many_lines, clippy::similar_names)]
4063mod perf9_async_tests {
4064 use super::OllamaClient;
4065 use serde_json::json;
4066 use std::net::TcpListener;
4067 use wiremock::matchers::{body_partial_json, method, path};
4068 use wiremock::{Mock, MockServer, ResponseTemplate};
4069
4070 async fn mount_tags_ok(server: &MockServer) {
4071 Mock::given(method("GET"))
4072 .and(path("/api/tags"))
4073 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
4074 .mount(server)
4075 .await;
4076 }
4077
4078 #[tokio::test(flavor = "multi_thread")]
4081 async fn new_with_url_async_succeeds_against_healthy_endpoint() {
4082 let server = MockServer::start().await;
4083 mount_tags_ok(&server).await;
4084 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4085 .await
4086 .expect("constructor succeeds against healthy /api/tags");
4087 assert!(client.is_ollama_native());
4088 }
4089
4090 #[tokio::test(flavor = "multi_thread")]
4091 async fn new_with_url_async_errors_when_endpoint_500s() {
4092 let server = MockServer::start().await;
4093 Mock::given(method("GET"))
4094 .and(path("/api/tags"))
4095 .respond_with(ResponseTemplate::new(500))
4096 .mount(&server)
4097 .await;
4098 let msg = match OllamaClient::new_with_url_async(&server.uri(), "test-model").await {
4099 Ok(_) => panic!("constructor must fail on 500"),
4100 Err(e) => e.to_string(),
4101 };
4102 assert!(
4103 msg.contains("not running") || msg.contains("not reachable"),
4104 "expected unreachable-style error, got: {msg}"
4105 );
4106 }
4107
4108 #[tokio::test(flavor = "multi_thread")]
4109 async fn new_with_url_async_errors_when_nothing_listening() {
4110 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
4111 let port = listener.local_addr().unwrap().port();
4112 drop(listener);
4113 let url = format!("http://127.0.0.1:{port}");
4114 let msg = match OllamaClient::new_with_url_async(&url, "test-model").await {
4115 Ok(_) => panic!("connect-refused must surface an error"),
4116 Err(e) => e.to_string(),
4117 };
4118 assert!(msg.contains("not running") || msg.contains("not reachable"));
4119 }
4120
4121 #[tokio::test(flavor = "multi_thread")]
4124 async fn is_available_async_true_on_200() {
4125 let server = MockServer::start().await;
4126 mount_tags_ok(&server).await;
4127 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4128 .await
4129 .unwrap();
4130 assert!(client.is_available_async().await);
4131 }
4132
4133 #[tokio::test(flavor = "multi_thread")]
4134 async fn is_available_async_false_on_500_after_construction() {
4135 let server = MockServer::start().await;
4136 mount_tags_ok(&server).await;
4138 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4139 .await
4140 .unwrap();
4141 drop(server);
4146 let server500 = MockServer::start().await;
4147 Mock::given(method("GET"))
4148 .and(path("/api/tags"))
4149 .respond_with(ResponseTemplate::new(500))
4150 .mount(&server500)
4151 .await;
4152 let mut client500 = OllamaClient::new_for_testing("test-model");
4157 client500.base_url = server500.uri().trim_end_matches('/').to_string();
4160 let _ = client; assert!(!client500.is_available_async().await);
4162 }
4163
4164 #[tokio::test(flavor = "multi_thread")]
4165 async fn is_available_async_false_on_network_error() {
4166 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
4167 let port = listener.local_addr().unwrap().port();
4168 drop(listener);
4169 let mut client = OllamaClient::new_for_testing("test-model");
4170 client.base_url = format!("http://127.0.0.1:{port}");
4171 assert!(!client.is_available_async().await);
4172 }
4173
4174 #[tokio::test(flavor = "multi_thread")]
4175 async fn is_available_async_openai_compatible_path_hits_models() {
4176 let server = MockServer::start().await;
4177 Mock::given(method("GET"))
4179 .and(path("/models"))
4180 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"data": []})))
4181 .mount(&server)
4182 .await;
4183 let client = OllamaClient::new_openai_compatible(&server.uri(), "test-model", "fake-key")
4184 .expect("OpenAI-compat client builds");
4185 assert!(client.is_available_async().await);
4186 }
4187
4188 #[tokio::test(flavor = "multi_thread")]
4189 async fn is_available_async_openai_compatible_false_on_401() {
4190 let server = MockServer::start().await;
4191 Mock::given(method("GET"))
4192 .and(path("/models"))
4193 .respond_with(ResponseTemplate::new(401))
4194 .mount(&server)
4195 .await;
4196 let client =
4197 OllamaClient::new_openai_compatible(&server.uri(), "test-model", "fake-key").unwrap();
4198 assert!(!client.is_available_async().await);
4200 }
4201
4202 #[tokio::test(flavor = "multi_thread")]
4205 async fn ensure_model_async_noop_on_openai_compatible() {
4206 let server = MockServer::start().await;
4207 drop(server);
4211 let client =
4212 OllamaClient::new_openai_compatible("http://127.0.0.1:1", "any-model", "fake-key")
4213 .unwrap();
4214 client
4215 .ensure_model_async()
4216 .await
4217 .expect("OpenAI-compatible ensure_model_async is a no-op");
4218 }
4219
4220 #[tokio::test(flavor = "multi_thread")]
4221 async fn ensure_model_async_skips_pull_when_model_present() {
4222 let server = MockServer::start().await;
4223 Mock::given(method("GET"))
4224 .and(path("/api/tags"))
4225 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4226 "models": [{"name": "test-model:latest"}]
4227 })))
4228 .mount(&server)
4229 .await;
4230 Mock::given(method("POST"))
4231 .and(path("/api/pull"))
4232 .respond_with(ResponseTemplate::new(200))
4233 .expect(0)
4234 .mount(&server)
4235 .await;
4236
4237 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4238 .await
4239 .unwrap();
4240 client.ensure_model_async().await.expect("no pull needed");
4241 }
4242
4243 #[tokio::test(flavor = "multi_thread")]
4244 async fn ensure_model_async_pulls_when_missing() {
4245 let server = MockServer::start().await;
4246 Mock::given(method("GET"))
4247 .and(path("/api/tags"))
4248 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
4249 .mount(&server)
4250 .await;
4251 Mock::given(method("POST"))
4252 .and(path("/api/pull"))
4253 .and(body_partial_json(json!({"name": "test-model"})))
4254 .respond_with(ResponseTemplate::new(200).set_body_string(""))
4255 .expect(1)
4256 .mount(&server)
4257 .await;
4258
4259 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4260 .await
4261 .unwrap();
4262 client.ensure_model_async().await.expect("pull succeeds");
4263 }
4264
4265 #[tokio::test(flavor = "multi_thread")]
4266 async fn ensure_model_async_surfaces_pull_failure() {
4267 let server = MockServer::start().await;
4268 Mock::given(method("GET"))
4269 .and(path("/api/tags"))
4270 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
4271 .mount(&server)
4272 .await;
4273 Mock::given(method("POST"))
4274 .and(path("/api/pull"))
4275 .respond_with(ResponseTemplate::new(500).set_body_string("upstream sick"))
4276 .mount(&server)
4277 .await;
4278
4279 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4280 .await
4281 .unwrap();
4282 let err = client
4283 .ensure_model_async()
4284 .await
4285 .expect_err("500 on pull must surface");
4286 assert!(err.to_string().contains("Ollama pull failed"));
4287 }
4288
4289 #[tokio::test(flavor = "multi_thread")]
4290 async fn ensure_model_async_errors_on_malformed_tags_response() {
4291 let server = MockServer::start().await;
4292 Mock::given(method("GET"))
4294 .and(path("/api/tags"))
4295 .respond_with(
4296 ResponseTemplate::new(200)
4297 .set_body_string("{not json")
4298 .insert_header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON),
4299 )
4300 .mount(&server)
4301 .await;
4302 let mut client = OllamaClient::new_for_testing("test-model");
4303 client.base_url = server.uri().trim_end_matches('/').to_string();
4304 let err = client
4305 .ensure_model_async()
4306 .await
4307 .expect_err("malformed tags must surface");
4308 assert!(
4309 err.to_string().contains("parse") || err.to_string().to_lowercase().contains("json")
4310 );
4311 }
4312
4313 #[tokio::test(flavor = "multi_thread")]
4316 async fn generate_async_happy_path() {
4317 let server = MockServer::start().await;
4318 mount_tags_ok(&server).await;
4319 Mock::given(method("POST"))
4320 .and(path("/api/chat"))
4321 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4322 "message": {"role": "assistant", "content": "hello world"},
4323 })))
4324 .mount(&server)
4325 .await;
4326 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4327 .await
4328 .unwrap();
4329 let out = client.generate_async("ping", None).await.unwrap();
4330 assert_eq!(out, "hello world");
4331 }
4332
4333 #[tokio::test(flavor = "multi_thread")]
4334 async fn generate_async_with_system_prompt() {
4335 let server = MockServer::start().await;
4336 mount_tags_ok(&server).await;
4337 Mock::given(method("POST"))
4338 .and(path("/api/chat"))
4339 .and(body_partial_json(json!({
4340 "messages": [
4341 {"role": "system", "content": "be terse"},
4342 {"role": "user", "content": "hi"},
4343 ],
4344 })))
4345 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4346 "message": {"content": "ok"},
4347 })))
4348 .mount(&server)
4349 .await;
4350 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4351 .await
4352 .unwrap();
4353 let out = client.generate_async("hi", Some("be terse")).await.unwrap();
4354 assert_eq!(out, "ok");
4355 }
4356
4357 #[tokio::test(flavor = "multi_thread")]
4358 async fn generate_async_returns_error_on_500() {
4359 let server = MockServer::start().await;
4360 mount_tags_ok(&server).await;
4361 Mock::given(method("POST"))
4362 .and(path("/api/chat"))
4363 .respond_with(ResponseTemplate::new(500).set_body_string("upstream sick"))
4364 .mount(&server)
4365 .await;
4366 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4367 .await
4368 .unwrap();
4369 let err = client.generate_async("ping", None).await.unwrap_err();
4370 assert!(
4371 err.to_string().contains("500") || err.to_string().contains("Chat generate failed")
4372 );
4373 }
4374
4375 #[tokio::test(flavor = "multi_thread")]
4376 async fn generate_async_returns_error_on_400() {
4377 let server = MockServer::start().await;
4380 mount_tags_ok(&server).await;
4381 Mock::given(method("POST"))
4382 .and(path("/api/chat"))
4383 .respond_with(ResponseTemplate::new(400).set_body_string("bad request"))
4384 .mount(&server)
4385 .await;
4386 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4387 .await
4388 .unwrap();
4389 for _ in 0..(super::CIRCUIT_BREAKER_THRESHOLD + 1) {
4392 let _ = client.generate_async("ping", None).await;
4393 }
4394 assert!(
4395 !client.circuit_breaker_open(),
4396 "4xx must not trip the circuit breaker"
4397 );
4398 }
4399
4400 #[tokio::test(flavor = "multi_thread")]
4401 async fn generate_async_returns_error_on_malformed_json() {
4402 let server = MockServer::start().await;
4403 mount_tags_ok(&server).await;
4404 Mock::given(method("POST"))
4405 .and(path("/api/chat"))
4406 .respond_with(
4407 ResponseTemplate::new(200)
4408 .set_body_string("{not valid json")
4409 .insert_header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON),
4410 )
4411 .mount(&server)
4412 .await;
4413 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4414 .await
4415 .unwrap();
4416 let err = client.generate_async("ping", None).await.unwrap_err();
4417 assert!(
4418 err.to_string().contains("parse") || err.to_string().to_lowercase().contains("json")
4419 );
4420 }
4421
4422 #[tokio::test(flavor = "multi_thread")]
4423 async fn generate_async_errors_when_message_content_missing() {
4424 let server = MockServer::start().await;
4425 mount_tags_ok(&server).await;
4426 Mock::given(method("POST"))
4429 .and(path("/api/chat"))
4430 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"done": true})))
4431 .mount(&server)
4432 .await;
4433 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4434 .await
4435 .unwrap();
4436 let err = client.generate_async("ping", None).await.unwrap_err();
4437 assert!(err.to_string().contains("Missing 'message.content'"));
4438 }
4439
4440 #[tokio::test(flavor = "multi_thread")]
4441 async fn generate_async_breaker_open_short_circuits() {
4442 let server = MockServer::start().await;
4443 mount_tags_ok(&server).await;
4444 Mock::given(method("POST"))
4445 .and(path("/api/chat"))
4446 .respond_with(ResponseTemplate::new(500))
4447 .mount(&server)
4448 .await;
4449 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4450 .await
4451 .unwrap();
4452 for _ in 0..super::CIRCUIT_BREAKER_THRESHOLD {
4453 let _ = client.generate_async("x", None).await;
4454 }
4455 assert!(client.circuit_breaker_open(), "breaker should be tripped");
4456 let err = client
4457 .generate_async("y", None)
4458 .await
4459 .expect_err("breaker-open path Errs");
4460 assert!(err.to_string().contains("circuit breaker open"));
4461 }
4462
4463 #[tokio::test(flavor = "multi_thread")]
4466 async fn generate_async_openai_compatible_happy_path() {
4467 let server = MockServer::start().await;
4468 Mock::given(method("POST"))
4469 .and(path("/chat/completions"))
4470 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4471 "choices": [{"message": {"role": "assistant", "content": "hi from openai"}}]
4472 })))
4473 .mount(&server)
4474 .await;
4475 let client =
4476 OllamaClient::new_openai_compatible(&server.uri(), "test-model", "fake-key").unwrap();
4477 let out = client.generate_async("ping", None).await.unwrap();
4478 assert_eq!(out, "hi from openai");
4479 }
4480
4481 #[tokio::test(flavor = "multi_thread")]
4482 async fn generate_async_openai_compatible_missing_choices() {
4483 let server = MockServer::start().await;
4484 Mock::given(method("POST"))
4485 .and(path("/chat/completions"))
4486 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"data": "wrong shape"})))
4487 .mount(&server)
4488 .await;
4489 let client =
4490 OllamaClient::new_openai_compatible(&server.uri(), "test-model", "fake-key").unwrap();
4491 let err = client.generate_async("ping", None).await.unwrap_err();
4492 assert!(
4493 err.to_string()
4494 .contains("Missing 'choices[0].message.content'")
4495 );
4496 }
4497
4498 #[tokio::test(flavor = "multi_thread")]
4501 async fn embed_text_async_happy_path() {
4502 let server = MockServer::start().await;
4503 mount_tags_ok(&server).await;
4504 Mock::given(method("POST"))
4505 .and(path("/api/embed"))
4506 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4507 "embeddings": [[0.1_f32, 0.2_f32, 0.3_f32]],
4508 })))
4509 .mount(&server)
4510 .await;
4511 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4512 .await
4513 .unwrap();
4514 let v = client
4515 .embed_text_async("hello", "nomic-embed-text")
4516 .await
4517 .unwrap();
4518 assert_eq!(v.len(), 3);
4519 }
4520
4521 #[tokio::test(flavor = "multi_thread")]
4522 async fn embed_text_async_500_trips_breaker_after_threshold() {
4523 let server = MockServer::start().await;
4524 mount_tags_ok(&server).await;
4525 Mock::given(method("POST"))
4526 .and(path("/api/embed"))
4527 .respond_with(ResponseTemplate::new(500))
4528 .mount(&server)
4529 .await;
4530 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4531 .await
4532 .unwrap();
4533 for _ in 0..super::CIRCUIT_BREAKER_THRESHOLD {
4534 let _ = client.embed_text_async("hello", "m").await;
4535 }
4536 assert!(
4537 client.circuit_breaker_open(),
4538 "3× 5xx must trip the breaker on embed_text_async"
4539 );
4540 }
4541
4542 #[tokio::test(flavor = "multi_thread")]
4543 async fn embed_text_async_400_does_not_trip_breaker() {
4544 let server = MockServer::start().await;
4545 mount_tags_ok(&server).await;
4546 Mock::given(method("POST"))
4547 .and(path("/api/embed"))
4548 .respond_with(ResponseTemplate::new(400))
4549 .mount(&server)
4550 .await;
4551 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4552 .await
4553 .unwrap();
4554 for _ in 0..(super::CIRCUIT_BREAKER_THRESHOLD + 1) {
4555 let _ = client.embed_text_async("hello", "m").await;
4556 }
4557 assert!(!client.circuit_breaker_open());
4558 }
4559
4560 #[tokio::test(flavor = "multi_thread")]
4561 async fn embed_text_async_empty_vec_errors() {
4562 let server = MockServer::start().await;
4563 mount_tags_ok(&server).await;
4564 Mock::given(method("POST"))
4565 .and(path("/api/embed"))
4566 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"embeddings": [[]]})))
4567 .mount(&server)
4568 .await;
4569 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4570 .await
4571 .unwrap();
4572 let err = client
4573 .embed_text_async("hello", "m")
4574 .await
4575 .expect_err("empty vector must error");
4576 assert!(err.to_string().contains("Empty embedding"));
4577 }
4578
4579 #[tokio::test(flavor = "multi_thread")]
4580 async fn embed_text_async_malformed_json_errors() {
4581 let server = MockServer::start().await;
4582 mount_tags_ok(&server).await;
4583 Mock::given(method("POST"))
4584 .and(path("/api/embed"))
4585 .respond_with(
4586 ResponseTemplate::new(200)
4587 .set_body_string("{bad json")
4588 .insert_header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON),
4589 )
4590 .mount(&server)
4591 .await;
4592 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4593 .await
4594 .unwrap();
4595 let err = client.embed_text_async("hi", "m").await.unwrap_err();
4596 assert!(
4597 err.to_string().contains("parse") || err.to_string().to_lowercase().contains("json")
4598 );
4599 }
4600
4601 #[tokio::test(flavor = "multi_thread")]
4602 async fn embed_text_async_openai_compatible_happy_path() {
4603 let server = MockServer::start().await;
4604 Mock::given(method("POST"))
4605 .and(path("/embeddings"))
4606 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4607 "data": [{"embedding": [0.5_f32, 0.6_f32]}]
4608 })))
4609 .mount(&server)
4610 .await;
4611 let client =
4612 OllamaClient::new_openai_compatible(&server.uri(), "test-model", "fake-key").unwrap();
4613 let v = client
4614 .embed_text_async("hello", "nomic-embed-text")
4615 .await
4616 .unwrap();
4617 assert_eq!(v.len(), 2);
4618 }
4619
4620 #[tokio::test(flavor = "multi_thread")]
4621 async fn embed_text_async_openai_compatible_missing_data_errors() {
4622 let server = MockServer::start().await;
4623 Mock::given(method("POST"))
4624 .and(path("/embeddings"))
4625 .respond_with(ResponseTemplate::new(200).set_body_json(json!({})))
4626 .mount(&server)
4627 .await;
4628 let client =
4629 OllamaClient::new_openai_compatible(&server.uri(), "test-model", "fake-key").unwrap();
4630 let err = client.embed_text_async("hi", "m").await.unwrap_err();
4631 assert!(err.to_string().contains("Missing 'data[0].embedding'"));
4632 }
4633
4634 #[tokio::test(flavor = "multi_thread")]
4635 async fn embed_text_async_breaker_open_short_circuits() {
4636 let server = MockServer::start().await;
4637 mount_tags_ok(&server).await;
4638 Mock::given(method("POST"))
4639 .and(path("/api/embed"))
4640 .respond_with(ResponseTemplate::new(500))
4641 .mount(&server)
4642 .await;
4643 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4644 .await
4645 .unwrap();
4646 for _ in 0..super::CIRCUIT_BREAKER_THRESHOLD {
4647 let _ = client.embed_text_async("x", "m").await;
4648 }
4649 let err = client.embed_text_async("y", "m").await.unwrap_err();
4650 assert!(err.to_string().contains("circuit breaker open"));
4651 }
4652
4653 #[tokio::test(flavor = "multi_thread")]
4656 async fn ensure_embed_model_async_noop_on_openai_compatible() {
4657 let client =
4658 OllamaClient::new_openai_compatible("http://127.0.0.1:1", "any-model", "fake-key")
4659 .unwrap();
4660 client.ensure_embed_model_async("any").await.expect("no-op");
4661 }
4662
4663 #[tokio::test(flavor = "multi_thread")]
4664 async fn ensure_embed_model_async_skips_when_present() {
4665 let server = MockServer::start().await;
4666 Mock::given(method("GET"))
4667 .and(path("/api/tags"))
4668 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4669 "models": [{"name": "nomic-embed-text:latest"}]
4670 })))
4671 .mount(&server)
4672 .await;
4673 Mock::given(method("POST"))
4674 .and(path("/api/pull"))
4675 .respond_with(ResponseTemplate::new(200))
4676 .expect(0)
4677 .mount(&server)
4678 .await;
4679 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4680 .await
4681 .unwrap();
4682 client
4683 .ensure_embed_model_async("nomic-embed-text")
4684 .await
4685 .unwrap();
4686 }
4687
4688 #[tokio::test(flavor = "multi_thread")]
4689 async fn ensure_embed_model_async_pulls_when_missing() {
4690 let server = MockServer::start().await;
4691 Mock::given(method("GET"))
4692 .and(path("/api/tags"))
4693 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
4694 .mount(&server)
4695 .await;
4696 Mock::given(method("POST"))
4697 .and(path("/api/pull"))
4698 .and(body_partial_json(json!({"name": "nomic-embed-text"})))
4699 .respond_with(ResponseTemplate::new(200))
4700 .expect(1)
4701 .mount(&server)
4702 .await;
4703 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4704 .await
4705 .unwrap();
4706 client
4707 .ensure_embed_model_async("nomic-embed-text")
4708 .await
4709 .unwrap();
4710 }
4711
4712 #[tokio::test(flavor = "multi_thread")]
4713 async fn ensure_embed_model_async_pull_failure_surfaces() {
4714 let server = MockServer::start().await;
4715 Mock::given(method("GET"))
4716 .and(path("/api/tags"))
4717 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
4718 .mount(&server)
4719 .await;
4720 Mock::given(method("POST"))
4721 .and(path("/api/pull"))
4722 .respond_with(ResponseTemplate::new(500))
4723 .mount(&server)
4724 .await;
4725 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4726 .await
4727 .unwrap();
4728 let err = client
4729 .ensure_embed_model_async("nomic-embed-text")
4730 .await
4731 .unwrap_err();
4732 assert!(err.to_string().contains("Ollama embed model pull failed"));
4733 }
4734
4735 #[tokio::test(flavor = "multi_thread")]
4738 async fn expand_query_async_parses_lines() {
4739 let server = MockServer::start().await;
4740 mount_tags_ok(&server).await;
4741 Mock::given(method("POST"))
4742 .and(path("/api/chat"))
4743 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4744 "message": {"content": "one\ntwo\n\nthree"},
4745 })))
4746 .mount(&server)
4747 .await;
4748 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4749 .await
4750 .unwrap();
4751 let terms = client.expand_query_async("anything").await.unwrap();
4752 assert_eq!(
4753 terms,
4754 vec!["one".to_string(), "two".to_string(), "three".to_string()]
4755 );
4756 }
4757
4758 #[tokio::test(flavor = "multi_thread")]
4759 async fn summarize_memories_async_renders_prompt_and_returns_summary() {
4760 let server = MockServer::start().await;
4761 mount_tags_ok(&server).await;
4762 Mock::given(method("POST"))
4763 .and(path("/api/chat"))
4764 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4765 "message": {"content": "summarized"},
4766 })))
4767 .mount(&server)
4768 .await;
4769 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4770 .await
4771 .unwrap();
4772 let s = client
4773 .summarize_memories_async(&[
4774 ("t1".to_string(), "c1".to_string()),
4775 ("t2".to_string(), "c2".to_string()),
4776 ])
4777 .await
4778 .unwrap();
4779 assert_eq!(s, "summarized");
4780 }
4781
4782 #[tokio::test(flavor = "multi_thread")]
4783 async fn auto_tag_async_normalises_lines_and_caps_at_8() {
4784 let server = MockServer::start().await;
4785 mount_tags_ok(&server).await;
4786 Mock::given(method("POST"))
4788 .and(path("/api/chat"))
4789 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4790 "message": {"content": "A\nB\nC\nD\nE\nF\nG\nH\nI\nJ"},
4791 })))
4792 .mount(&server)
4793 .await;
4794 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4795 .await
4796 .unwrap();
4797 let tags = client
4798 .auto_tag_async("title", "content", None)
4799 .await
4800 .unwrap();
4801 assert_eq!(tags.len(), 8);
4802 for t in &tags {
4803 assert_eq!(t.to_lowercase(), *t);
4804 }
4805 }
4806
4807 #[tokio::test(flavor = "multi_thread")]
4808 async fn auto_tag_async_model_override_stamps_body() {
4809 let server = MockServer::start().await;
4810 mount_tags_ok(&server).await;
4811 Mock::given(method("POST"))
4812 .and(path("/api/chat"))
4813 .and(body_partial_json(json!({"model": "fast-model"})))
4814 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4815 "message": {"content": "a\nb\nc"},
4816 })))
4817 .expect(1)
4818 .mount(&server)
4819 .await;
4820 let client = OllamaClient::new_with_url_async(&server.uri(), "primary-model")
4821 .await
4822 .unwrap();
4823 let tags = client
4824 .auto_tag_async("t", "c", Some("fast-model"))
4825 .await
4826 .unwrap();
4827 assert_eq!(
4828 tags,
4829 vec!["a".to_string(), "b".to_string(), "c".to_string()]
4830 );
4831 }
4832
4833 #[tokio::test(flavor = "multi_thread")]
4834 async fn detect_contradiction_async_parses_yes() {
4835 let server = MockServer::start().await;
4836 mount_tags_ok(&server).await;
4837 Mock::given(method("POST"))
4838 .and(path("/api/chat"))
4839 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4840 "message": {"content": "Yes."},
4841 })))
4842 .mount(&server)
4843 .await;
4844 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4845 .await
4846 .unwrap();
4847 assert!(client.detect_contradiction_async("a", "b").await.unwrap());
4848 }
4849
4850 #[tokio::test(flavor = "multi_thread")]
4851 async fn detect_contradiction_async_parses_no() {
4852 let server = MockServer::start().await;
4853 mount_tags_ok(&server).await;
4854 Mock::given(method("POST"))
4855 .and(path("/api/chat"))
4856 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4857 "message": {"content": "no, they don't"},
4858 })))
4859 .mount(&server)
4860 .await;
4861 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4862 .await
4863 .unwrap();
4864 assert!(!client.detect_contradiction_async("a", "b").await.unwrap());
4865 }
4866
4867 #[tokio::test(flavor = "multi_thread")]
4868 async fn detect_contradiction_async_propagates_generate_error() {
4869 let server = MockServer::start().await;
4870 mount_tags_ok(&server).await;
4871 Mock::given(method("POST"))
4872 .and(path("/api/chat"))
4873 .respond_with(ResponseTemplate::new(500))
4874 .mount(&server)
4875 .await;
4876 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4877 .await
4878 .unwrap();
4879 assert!(client.detect_contradiction_async("a", "b").await.is_err());
4880 }
4881
4882 #[tokio::test(flavor = "multi_thread")]
4885 async fn generate_with_model_override_async_breaker_open_short_circuits() {
4886 let server = MockServer::start().await;
4887 mount_tags_ok(&server).await;
4888 Mock::given(method("POST"))
4889 .and(path("/api/chat"))
4890 .respond_with(ResponseTemplate::new(500))
4891 .mount(&server)
4892 .await;
4893 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4894 .await
4895 .unwrap();
4896 for _ in 0..super::CIRCUIT_BREAKER_THRESHOLD {
4897 let _ = client
4898 .generate_with_model_override_async("p", None, Some("m"))
4899 .await;
4900 }
4901 let err = client
4902 .generate_with_model_override_async("p", None, Some("m"))
4903 .await
4904 .unwrap_err();
4905 assert!(err.to_string().contains("circuit breaker open"));
4906 }
4907
4908 #[tokio::test(flavor = "multi_thread")]
4911 async fn sync_wrapper_runs_under_block_in_place_path() {
4912 let server = MockServer::start().await;
4918 mount_tags_ok(&server).await;
4919 Mock::given(method("POST"))
4920 .and(path("/api/chat"))
4921 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
4922 "message": {"content": "bridge ok"},
4923 })))
4924 .mount(&server)
4925 .await;
4926 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
4927 .await
4928 .unwrap();
4929 let out = client.generate("p", None).expect("sync wrapper ok");
4930 assert_eq!(out, "bridge ok");
4931 }
4932
4933 #[test]
4936 fn llm_provider_debug_redacts_api_key() {
4937 let p_ollama = super::LlmProvider::Ollama;
4940 let p_oai = super::LlmProvider::OpenAiCompatible {
4941 api_key: "secret-token-do-not-leak".to_string(),
4942 };
4943 let s_ollama = format!("{p_ollama:?}");
4944 let s_oai = format!("{p_oai:?}");
4945 assert!(s_ollama.contains("Ollama"));
4946 assert!(s_oai.contains("OpenAiCompatible"));
4947 assert!(s_oai.contains("<redacted>"));
4948 assert!(
4949 !s_oai.contains("secret-token-do-not-leak"),
4950 "Debug impl must not leak the api_key"
4951 );
4952 }
4953
4954 #[test]
4955 fn model_name_returns_resolved_model() {
4956 let client = OllamaClient::new_for_testing("gemma-test-model");
4957 assert_eq!(client.model_name(), "gemma-test-model");
4958 }
4959
4960 #[test]
4961 fn llm_provider_zeroize_secrets_is_idempotent() {
4962 let mut p = super::LlmProvider::OpenAiCompatible {
4963 api_key: "abcdef".to_string(),
4964 };
4965 p.zeroize_secrets();
4966 let super::LlmProvider::OpenAiCompatible { api_key } = &p else {
4967 unreachable!()
4968 };
4969 assert!(api_key.is_empty() || api_key.bytes().all(|b| b == 0));
4970 p.zeroize_secrets();
4971 }
4972
4973 #[test]
4974 fn llm_provider_zeroize_secrets_noop_on_ollama() {
4975 let mut p = super::LlmProvider::Ollama;
4976 p.zeroize_secrets();
4977 assert!(matches!(p, super::LlmProvider::Ollama));
4978 }
4979
4980 #[test]
4981 fn breaker_state_is_open_returns_false_when_last_failure_none() {
4982 let s = super::BreakerState::new();
4983 assert!(!s.is_open(), "fresh breaker must be closed");
4984 }
4985
4986 #[tokio::test(flavor = "multi_thread")]
4987 async fn new_convenience_constructor_routes_to_default_url() {
4988 let res = tokio::task::spawn_blocking(|| OllamaClient::new("test-model"))
4992 .await
4993 .unwrap();
4994 match res {
4995 Ok(_) => { }
4996 Err(e) => {
4997 let msg = e.to_string();
4998 assert!(
4999 msg.contains("not running") || msg.contains("not reachable"),
5000 "expected an unreachable-style error, got: {msg}"
5001 );
5002 }
5003 }
5004 }
5005
5006 #[tokio::test(flavor = "multi_thread")]
5009 async fn sync_wrapper_path_is_available() {
5010 let server = MockServer::start().await;
5011 mount_tags_ok(&server).await;
5012 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5013 .await
5014 .unwrap();
5015 assert!(client.is_available());
5016 }
5017
5018 #[tokio::test(flavor = "multi_thread")]
5019 async fn sync_wrapper_path_embed_text() {
5020 let server = MockServer::start().await;
5021 mount_tags_ok(&server).await;
5022 Mock::given(method("POST"))
5023 .and(path("/api/embed"))
5024 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
5025 "embeddings": [[0.42_f32]],
5026 })))
5027 .mount(&server)
5028 .await;
5029 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5030 .await
5031 .unwrap();
5032 let v = client.embed_text("hi", "m").unwrap();
5033 assert_eq!(v.len(), 1);
5034 }
5035
5036 #[tokio::test(flavor = "multi_thread")]
5037 async fn sync_wrapper_path_expand_query() {
5038 let server = MockServer::start().await;
5039 mount_tags_ok(&server).await;
5040 Mock::given(method("POST"))
5041 .and(path("/api/chat"))
5042 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
5043 "message": {"content": "a\nb"},
5044 })))
5045 .mount(&server)
5046 .await;
5047 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5048 .await
5049 .unwrap();
5050 let terms = client.expand_query("q").unwrap();
5051 assert_eq!(terms, vec!["a".to_string(), "b".to_string()]);
5052 }
5053
5054 #[tokio::test(flavor = "multi_thread")]
5055 async fn sync_wrapper_path_summarize_memories() {
5056 let server = MockServer::start().await;
5057 mount_tags_ok(&server).await;
5058 Mock::given(method("POST"))
5059 .and(path("/api/chat"))
5060 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
5061 "message": {"content": "compacted"},
5062 })))
5063 .mount(&server)
5064 .await;
5065 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5066 .await
5067 .unwrap();
5068 let s = client
5069 .summarize_memories(&[("t".to_string(), "c".to_string())])
5070 .unwrap();
5071 assert_eq!(s, "compacted");
5072 }
5073
5074 #[tokio::test(flavor = "multi_thread")]
5075 async fn sync_wrapper_path_auto_tag() {
5076 let server = MockServer::start().await;
5077 mount_tags_ok(&server).await;
5078 Mock::given(method("POST"))
5079 .and(path("/api/chat"))
5080 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
5081 "message": {"content": "x\ny\nz"},
5082 })))
5083 .mount(&server)
5084 .await;
5085 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5086 .await
5087 .unwrap();
5088 let tags = client.auto_tag("t", "c", None).unwrap();
5089 assert_eq!(
5090 tags,
5091 vec!["x".to_string(), "y".to_string(), "z".to_string()]
5092 );
5093 }
5094
5095 #[tokio::test(flavor = "multi_thread")]
5096 async fn sync_wrapper_path_detect_contradiction() {
5097 let server = MockServer::start().await;
5098 mount_tags_ok(&server).await;
5099 Mock::given(method("POST"))
5100 .and(path("/api/chat"))
5101 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
5102 "message": {"content": "yes"},
5103 })))
5104 .mount(&server)
5105 .await;
5106 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5107 .await
5108 .unwrap();
5109 assert!(client.detect_contradiction("a", "b").unwrap());
5110 }
5111
5112 #[tokio::test(flavor = "multi_thread")]
5113 async fn sync_wrapper_path_ensure_model() {
5114 let server = MockServer::start().await;
5115 Mock::given(method("GET"))
5116 .and(path("/api/tags"))
5117 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
5118 "models": [{"name": "test-model:latest"}]
5119 })))
5120 .mount(&server)
5121 .await;
5122 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5123 .await
5124 .unwrap();
5125 client.ensure_model().unwrap();
5126 }
5127
5128 #[tokio::test(flavor = "multi_thread")]
5129 async fn sync_wrapper_path_ensure_embed_model() {
5130 let server = MockServer::start().await;
5131 Mock::given(method("GET"))
5132 .and(path("/api/tags"))
5133 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
5134 "models": [{"name": "nomic-embed-text:latest"}]
5135 })))
5136 .mount(&server)
5137 .await;
5138 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5139 .await
5140 .unwrap();
5141 client.ensure_embed_model("nomic-embed-text").unwrap();
5142 }
5143
5144 #[tokio::test(flavor = "multi_thread")]
5147 async fn generate_with_body_async_happy_path() {
5148 let server = MockServer::start().await;
5149 mount_tags_ok(&server).await;
5150 Mock::given(method("POST"))
5151 .and(path("/api/generate"))
5152 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
5153 "response": "legacy text",
5154 })))
5155 .mount(&server)
5156 .await;
5157 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5158 .await
5159 .unwrap();
5160 let body = json!({"model": "test-model", "prompt": "p", "stream": false});
5161 let out = client.generate_with_body_async(&body).await.unwrap();
5162 assert_eq!(out, "legacy text");
5163 }
5164
5165 #[tokio::test(flavor = "multi_thread")]
5166 async fn generate_with_body_async_returns_error_on_500() {
5167 let server = MockServer::start().await;
5168 mount_tags_ok(&server).await;
5169 Mock::given(method("POST"))
5170 .and(path("/api/generate"))
5171 .respond_with(ResponseTemplate::new(500).set_body_string("bad"))
5172 .mount(&server)
5173 .await;
5174 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5175 .await
5176 .unwrap();
5177 let body = json!({"model": "test-model"});
5178 let err = client.generate_with_body_async(&body).await.unwrap_err();
5179 assert!(err.to_string().contains("500") || err.to_string().contains("Generate failed"));
5180 }
5181
5182 #[tokio::test(flavor = "multi_thread")]
5183 async fn generate_with_body_async_returns_error_on_malformed_json() {
5184 let server = MockServer::start().await;
5185 mount_tags_ok(&server).await;
5186 Mock::given(method("POST"))
5187 .and(path("/api/generate"))
5188 .respond_with(
5189 ResponseTemplate::new(200)
5190 .set_body_string("{bad json")
5191 .insert_header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON),
5192 )
5193 .mount(&server)
5194 .await;
5195 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5196 .await
5197 .unwrap();
5198 let body = json!({"model": "test-model"});
5199 let err = client.generate_with_body_async(&body).await.unwrap_err();
5200 assert!(
5201 err.to_string().contains("parse") || err.to_string().to_lowercase().contains("json")
5202 );
5203 }
5204
5205 #[tokio::test(flavor = "multi_thread")]
5206 async fn generate_with_body_async_breaker_open_short_circuits() {
5207 let server = MockServer::start().await;
5208 mount_tags_ok(&server).await;
5209 Mock::given(method("POST"))
5210 .and(path("/api/generate"))
5211 .respond_with(ResponseTemplate::new(500))
5212 .mount(&server)
5213 .await;
5214 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5215 .await
5216 .unwrap();
5217 let body = json!({"model": "test-model"});
5218 for _ in 0..super::CIRCUIT_BREAKER_THRESHOLD {
5219 let _ = client.generate_with_body_async(&body).await;
5220 }
5221 let err = client.generate_with_body_async(&body).await.unwrap_err();
5222 assert!(err.to_string().contains("circuit breaker open"));
5223 }
5224
5225 #[tokio::test(flavor = "multi_thread")]
5226 async fn generate_with_body_async_missing_response_field_errors() {
5227 let server = MockServer::start().await;
5228 mount_tags_ok(&server).await;
5229 Mock::given(method("POST"))
5230 .and(path("/api/generate"))
5231 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"done": true})))
5232 .mount(&server)
5233 .await;
5234 let client = OllamaClient::new_with_url_async(&server.uri(), "test-model")
5235 .await
5236 .unwrap();
5237 let body = json!({});
5238 let err = client.generate_with_body_async(&body).await.unwrap_err();
5239 assert!(err.to_string().contains("Missing 'response'"));
5240 }
5241
5242 #[tokio::test(flavor = "multi_thread")]
5245 async fn from_env_openai_compatible_requires_base_url() {
5246 let _g = super::wiremock_tests::lock_env_1143();
5247 super::wiremock_tests::clear_llm_env_1143();
5248 unsafe { std::env::set_var("AI_MEMORY_LLM_BACKEND", "openai-compatible") };
5249 unsafe { std::env::set_var("AI_MEMORY_LLM_API_KEY", "k") };
5250 let res = OllamaClient::from_env();
5251 super::wiremock_tests::clear_llm_env_1143();
5252 let err = match res {
5253 Ok(_) => panic!("openai-compatible without base_url must error"),
5254 Err(e) => e.to_string(),
5255 };
5256 assert!(err.contains("AI_MEMORY_LLM_BASE_URL"));
5257 }
5258
5259 #[tokio::test(flavor = "multi_thread")]
5260 async fn from_env_openai_compatible_requires_api_key() {
5261 let _g = super::wiremock_tests::lock_env_1143();
5262 super::wiremock_tests::clear_llm_env_1143();
5263 unsafe { std::env::set_var("AI_MEMORY_LLM_BACKEND", "openai-compatible") };
5264 unsafe { std::env::set_var("AI_MEMORY_LLM_BASE_URL", "https://example.test/v1") };
5265 let res = OllamaClient::from_env();
5266 super::wiremock_tests::clear_llm_env_1143();
5267 let err = match res {
5268 Ok(_) => panic!("openai-compatible without key must error"),
5269 Err(e) => e.to_string(),
5270 };
5271 assert!(err.contains("AI_MEMORY_LLM_API_KEY"));
5272 }
5273
5274 #[tokio::test(flavor = "multi_thread")]
5275 async fn from_env_alias_requires_api_key_when_none_resolvable() {
5276 let _g = super::wiremock_tests::lock_env_1143();
5277 super::wiremock_tests::clear_llm_env_1143();
5278 unsafe { std::env::set_var("AI_MEMORY_LLM_BACKEND", "xai") };
5279 let res = OllamaClient::from_env();
5280 super::wiremock_tests::clear_llm_env_1143();
5281 let err = match res {
5282 Ok(_) => panic!("xai without key must error"),
5283 Err(e) => e.to_string(),
5284 };
5285 assert!(err.contains("API key"));
5286 }
5287
/// Drives the blocking client API from a plain scoped `std` thread.
///
/// A dedicated multi-thread runtime starts the mock server and stays alive
/// for the whole test. The client is built and used on the spawned thread,
/// which is never entered into that runtime.
#[test]
fn sync_wrapper_outside_runtime_constructs_ephemeral() {
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap();
    let server = runtime.block_on(async {
        let mock = MockServer::start().await;
        mount_tags_ok(&mock).await;
        let reply = ResponseTemplate::new(200).set_body_json(json!({
            "message": {"content": "no-rt bridge ok"},
        }));
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(reply)
            .mount(&mock)
            .await;
        mock
    });
    let uri = server.uri();
    std::thread::scope(|scope| {
        let worker = scope.spawn(|| {
            let client =
                OllamaClient::new_with_url(&uri, "test-model").expect("sync new_with_url ok");
            let out = client.generate("ping", None).expect("sync generate ok");
            assert_eq!(out, "no-rt bridge ok");
        });
        // Propagate any assertion failure from the worker thread.
        worker.join().unwrap();
    });
}
5326}