Skip to main content

ai_memory/
llm.rs

1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
4use anyhow::{Context, Result, anyhow};
5use serde_json::{Value, json};
6use std::time::Duration;
7
/// Base URL used when the caller does not supply one.
const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";

/// Timeout applied to chat and embedding requests.
const GENERATE_TIMEOUT: Duration = Duration::from_secs(30);
/// Timeout applied to model pull requests.
const PULL_TIMEOUT: Duration = Duration::from_secs(120);

/// Prompt template for query expansion; `{query}` is replaced with the search query.
const QUERY_EXPANSION_PROMPT: &str = "You are a search query expander. Given a search query, \
    generate 5-8 additional search terms that are semantically related. Return ONLY the terms, \
    one per line, no numbering or explanation.\n\nQuery: {query}";

/// Prompt template for summarization; `{memories}` is replaced with the formatted memories.
const SUMMARIZE_PROMPT: &str = "Summarize the following memories into a single concise paragraph. \
    Preserve all key facts, decisions, and technical details.\n\n{memories}";

/// Prompt template for tag generation; takes `{title}` and `{content}`.
const AUTO_TAG_PROMPT: &str = "Generate 3-5 short tags for categorizing this memory. \
    Return ONLY the tags, one per line, lowercase, no symbols.\n\n\
    Title: {title}\nContent: {content}";

/// Prompt template for contradiction detection; takes the two statements as `{a}` and `{b}`.
const CONTRADICTION_PROMPT: &str = "Do these two statements contradict each other? \
    Answer ONLY \"yes\" or \"no\".\n\nStatement A: {a}\nStatement B: {b}";
30
31pub struct OllamaClient {
32    base_url: String,
33    model: String,
34    client: reqwest::blocking::Client,
35}
36
37impl OllamaClient {
38    /// Creates a new `OllamaClient` with the default Ollama URL (<http://localhost:11434>).
39    /// Checks that Ollama is reachable before returning.
40    #[allow(dead_code)]
41    pub fn new(model: &str) -> Result<Self> {
42        Self::new_with_url(DEFAULT_OLLAMA_URL, model)
43    }
44
45    /// Creates a new `OllamaClient` with a custom base URL.
46    /// Checks that Ollama is reachable before returning.
47    pub fn new_with_url(base_url: &str, model: &str) -> Result<Self> {
48        let client = reqwest::blocking::Client::builder()
49            .timeout(GENERATE_TIMEOUT)
50            .build()
51            .context("Failed to build HTTP client")?;
52
53        let instance = Self {
54            base_url: base_url.trim_end_matches('/').to_string(),
55            model: model.to_string(),
56            client,
57        };
58
59        if !instance.is_available() {
60            return Err(anyhow!(
61                "Ollama is not running or not reachable at {}. \
62                 Start it with: ollama serve",
63                instance.base_url
64            ));
65        }
66
67        Ok(instance)
68    }
69
70    /// Quick health check -- returns true if Ollama responds to GET /api/tags.
71    pub fn is_available(&self) -> bool {
72        let url = format!("{}/api/tags", self.base_url);
73        self.client
74            .get(&url)
75            .timeout(Duration::from_secs(5))
76            .send()
77            .is_ok_and(|r| r.status().is_success())
78    }
79
80    /// Checks if the configured model is already pulled. If not, pulls it.
81    pub fn ensure_model(&self) -> Result<()> {
82        // Check if model exists by listing tags
83        let url = format!("{}/api/tags", self.base_url);
84        let resp = self
85            .client
86            .get(&url)
87            .timeout(Duration::from_secs(10))
88            .send()
89            .context("Failed to list Ollama models")?;
90
91        let body: Value = resp.json().context("Failed to parse /api/tags response")?;
92
93        let model_exists = body["models"].as_array().is_some_and(|models| {
94            models.iter().any(|m| {
95                let name = m["name"].as_str().unwrap_or("");
96                // Match "model" or "model:tag" against our model string
97                // Also match when our model base (before ':') matches the served name
98                let our_base = self.model.split(':').next().unwrap_or(&self.model);
99                name == self.model
100                    || name.starts_with(&format!("{}:", self.model))
101                    || self.model == name.split(':').next().unwrap_or("")
102                    || name == our_base
103            })
104        });
105
106        if model_exists {
107            return Ok(());
108        }
109
110        // Pull the model
111        tracing::info!(
112            "Pulling Ollama model '{}' (this may take a while)...",
113            self.model
114        );
115
116        let pull_url = format!("{}/api/pull", self.base_url);
117        let pull_client = reqwest::blocking::Client::builder()
118            .timeout(PULL_TIMEOUT)
119            .build()
120            .context("Failed to build pull client")?;
121
122        let resp = pull_client
123            .post(&pull_url)
124            .json(&json!({ "name": self.model }))
125            .send()
126            .context("Failed to pull model from Ollama")?;
127
128        if !resp.status().is_success() {
129            let status = resp.status();
130            let text = resp.text().unwrap_or_default();
131            return Err(anyhow!("Ollama pull failed ({status}): {text}"));
132        }
133
134        tracing::info!("Model '{}' pulled successfully", self.model);
135        Ok(())
136    }
137
138    /// Generates a completion using the /api/chat endpoint (Ollama chat format).
139    /// This is compatible with both Ollama and vMLX/OpenAI-compatible servers.
140    /// Returns the response text.
141    pub fn generate(&self, prompt: &str, system: Option<&str>) -> Result<String> {
142        let url = format!("{}/api/chat", self.base_url);
143
144        let mut messages = Vec::new();
145        if let Some(sys) = system {
146            messages.push(json!({"role": "system", "content": sys}));
147        }
148        messages.push(json!({"role": "user", "content": prompt}));
149
150        let payload = json!({
151            "model": self.model,
152            "messages": messages,
153            "stream": false,
154        });
155
156        let resp = self
157            .client
158            .post(&url)
159            .timeout(GENERATE_TIMEOUT)
160            .json(&payload)
161            .send()
162            .context("Failed to send chat request")?;
163
164        if !resp.status().is_success() {
165            let status = resp.status();
166            let text = resp.text().unwrap_or_default();
167            return Err(anyhow!("Chat generate failed ({status}): {text}"));
168        }
169
170        let body: Value = resp.json().context("Failed to parse chat response")?;
171
172        // Ollama /api/chat returns {"message": {"content": "..."}}
173        let response_text = body["message"]["content"]
174            .as_str()
175            .ok_or_else(|| anyhow!("Missing 'message.content' field in chat output"))?
176            .to_string();
177
178        Ok(response_text)
179    }
180
181    /// Uses the LLM to expand a search query into additional search terms.
182    pub fn expand_query(&self, query: &str) -> Result<Vec<String>> {
183        let prompt = QUERY_EXPANSION_PROMPT.replace("{query}", query);
184        let response = self.generate(&prompt, None)?;
185
186        let terms: Vec<String> = response
187            .lines()
188            .map(|line| line.trim().to_string())
189            .filter(|line| !line.is_empty())
190            .collect();
191
192        Ok(terms)
193    }
194
195    /// Takes (title, content) pairs and returns a consolidated summary.
196    pub fn summarize_memories(&self, memories: &[(String, String)]) -> Result<String> {
197        let formatted = memories
198            .iter()
199            .enumerate()
200            .map(|(i, (title, content))| {
201                format!("--- Memory {} ---\nTitle: {}\n{}", i + 1, title, content)
202            })
203            .collect::<Vec<_>>()
204            .join("\n\n");
205
206        let prompt = SUMMARIZE_PROMPT.replace("{memories}", &formatted);
207        let response = self.generate(&prompt, None)?;
208
209        Ok(response.trim().to_string())
210    }
211
212    /// Generates suggested tags for a memory.
213    pub fn auto_tag(&self, title: &str, content: &str) -> Result<Vec<String>> {
214        let prompt = AUTO_TAG_PROMPT
215            .replace("{title}", title)
216            .replace("{content}", content);
217
218        let response = self.generate(&prompt, None)?;
219
220        let tags: Vec<String> = response
221            .lines()
222            .map(|line| line.trim().to_lowercase())
223            .filter(|line| !line.is_empty())
224            .collect();
225
226        Ok(tags)
227    }
228
229    /// Generate an embedding vector via Ollama's /api/embed endpoint.
230    ///
231    /// Used for nomic-embed-text-v1.5 on smart/autonomous tiers.
232    pub fn embed_text(&self, text: &str, embed_model: &str) -> Result<Vec<f32>> {
233        let url = format!("{}/api/embed", self.base_url);
234        let payload = json!({
235            "model": embed_model,
236            "input": text,
237        });
238
239        let resp = self
240            .client
241            .post(&url)
242            .timeout(GENERATE_TIMEOUT)
243            .json(&payload)
244            .send()
245            .context("Failed to send embed request to Ollama")?;
246
247        if !resp.status().is_success() {
248            let status = resp.status();
249            let text = resp.text().unwrap_or_default();
250            return Err(anyhow!("Ollama embed failed ({status}): {text}"));
251        }
252
253        let body: Value = resp
254            .json()
255            .context("Failed to parse Ollama embed response")?;
256
257        // Ollama returns {"embeddings": [[...], ...]} — take the first one
258        let embedding = body["embeddings"]
259            .as_array()
260            .and_then(|arr| arr.first())
261            .and_then(|v| v.as_array())
262            .ok_or_else(|| anyhow!("Missing embeddings in Ollama response"))?;
263
264        #[allow(clippy::cast_possible_truncation)]
265        let floats: Vec<f32> = embedding
266            .iter()
267            .filter_map(|v| v.as_f64().map(|f| f as f32))
268            .collect();
269
270        if floats.is_empty() {
271            return Err(anyhow!("Empty embedding returned from Ollama"));
272        }
273
274        Ok(floats)
275    }
276
277    /// Ensure an embedding model is pulled in Ollama.
278    pub fn ensure_embed_model(&self, model: &str) -> Result<()> {
279        let url = format!("{}/api/tags", self.base_url);
280        let resp = self
281            .client
282            .get(&url)
283            .timeout(std::time::Duration::from_secs(10))
284            .send()
285            .context("Failed to list Ollama models")?;
286
287        let body: Value = resp.json().context("Failed to parse /api/tags response")?;
288        let model_exists = body["models"].as_array().is_some_and(|models| {
289            models.iter().any(|m| {
290                let name = m["name"].as_str().unwrap_or("");
291                name == model
292                    || name.starts_with(&format!("{model}:"))
293                    || model == name.split(':').next().unwrap_or("")
294            })
295        });
296
297        if model_exists {
298            return Ok(());
299        }
300
301        tracing::info!("Pulling Ollama embedding model '{}'...", model);
302        let pull_url = format!("{}/api/pull", self.base_url);
303        let pull_client = reqwest::blocking::Client::builder()
304            .timeout(PULL_TIMEOUT)
305            .build()
306            .context("Failed to build pull client")?;
307        let resp = pull_client
308            .post(&pull_url)
309            .json(&json!({ "name": model }))
310            .send()
311            .context("Failed to pull embedding model from Ollama")?;
312
313        if !resp.status().is_success() {
314            let status = resp.status();
315            let text = resp.text().unwrap_or_default();
316            return Err(anyhow!("Ollama embed model pull failed ({status}): {text}"));
317        }
318
319        tracing::info!("Embedding model '{}' pulled successfully", model);
320        Ok(())
321    }
322
323    /// Returns true if two memory contents contradict each other.
324    pub fn detect_contradiction(&self, mem_a: &str, mem_b: &str) -> Result<bool> {
325        let prompt = CONTRADICTION_PROMPT
326            .replace("{a}", mem_a)
327            .replace("{b}", mem_b);
328
329        let response = self.generate(&prompt, None)?;
330        let answer = response.trim().to_lowercase();
331
332        Ok(answer.starts_with("yes"))
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Each prompt template must still contain the placeholders its caller substitutes.
    #[test]
    fn test_prompt_templates_have_placeholders() {
        let required = [
            (QUERY_EXPANSION_PROMPT, "{query}"),
            (SUMMARIZE_PROMPT, "{memories}"),
            (AUTO_TAG_PROMPT, "{title}"),
            (AUTO_TAG_PROMPT, "{content}"),
            (CONTRADICTION_PROMPT, "{a}"),
            (CONTRADICTION_PROMPT, "{b}"),
        ];
        for (template, placeholder) in required {
            assert!(template.contains(placeholder));
        }
    }

    /// The default endpoint is the local Ollama port.
    #[test]
    fn test_default_url() {
        let expected = "http://localhost:11434";
        assert_eq!(DEFAULT_OLLAMA_URL, expected);
    }
}
355
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports,
    clippy::doc_markdown
)]
pub mod test_support {
    use super::*;

    /// Failure modes a [`MockOllamaClient`] can be configured to simulate.
    pub enum MockFailure {
        ModelNotFound,
        Timeout,
        MalformedResponse,
        ApiError(String),
        EmptyResponse,
        NetworkError,
    }

    /// Mock Ollama client for testing without a running Ollama daemon.
    /// Returns deterministic, canned responses for each public method.
    pub struct MockOllamaClient {
        pub base_url: String,
        pub model: String,
        /// When set, methods return the error associated with this failure mode.
        pub fail_with: Option<MockFailure>,
    }

    impl MockOllamaClient {
        /// Create a mock client with the given URL and model name.
        pub fn new_with_url(base_url: &str, model: &str) -> Result<Self> {
            Ok(Self {
                base_url: base_url.trim_end_matches('/').to_string(),
                model: model.to_string(),
                fail_with: None,
            })
        }

        /// Create a mock client that will fail with the specified failure mode.
        pub fn with_failure(base_url: &str, model: &str, failure: MockFailure) -> Result<Self> {
            Ok(Self {
                base_url: base_url.trim_end_matches('/').to_string(),
                model: model.to_string(),
                fail_with: Some(failure),
            })
        }

        /// Check if this client is configured to fail
        fn should_fail(&self) -> Option<&MockFailure> {
            self.fail_with.as_ref()
        }

        /// Error shared by `expand_query`, `summarize_memories`, `auto_tag` and
        /// `detect_contradiction` when a failure mode is configured; `None`
        /// when the client is healthy.
        fn helper_failure(&self) -> Option<anyhow::Error> {
            self.should_fail().map(|failure| match failure {
                MockFailure::Timeout => {
                    anyhow!("Failed to send chat request: operation timed out")
                }
                MockFailure::MalformedResponse => {
                    anyhow!("Failed to parse chat response: invalid JSON")
                }
                MockFailure::ApiError(msg) => anyhow!("Chat generate failed (500): {}", msg),
                MockFailure::ModelNotFound
                | MockFailure::EmptyResponse
                | MockFailure::NetworkError => anyhow!("Generate failed"),
            })
        }

        /// Mock health check — returns false if NetworkError, true otherwise.
        pub fn is_available(&self) -> bool {
            !matches!(self.should_fail(), Some(MockFailure::NetworkError))
        }

        /// Mock `ensure_model` — fails for ModelNotFound, Timeout, ApiError and
        /// NetworkError; succeeds otherwise.
        pub fn ensure_model(&self) -> Result<()> {
            match self.should_fail() {
                Some(MockFailure::ModelNotFound) => Err(anyhow!(
                    "Model 'unknown-model' not found in Ollama registry"
                )),
                Some(MockFailure::Timeout) => {
                    Err(anyhow!("Failed to list Ollama models: operation timed out"))
                }
                Some(MockFailure::ApiError(msg)) => {
                    Err(anyhow!("Ollama pull failed (404): {}", msg))
                }
                Some(MockFailure::NetworkError) => Err(anyhow!(
                    "Failed to pull model from Ollama: connection refused"
                )),
                Some(MockFailure::MalformedResponse | MockFailure::EmptyResponse) | None => Ok(()),
            }
        }

        /// Mock `ensure_embed_model` — same failure modes as `ensure_model`,
        /// with embedding-specific messages.
        pub fn ensure_embed_model(&self, _model: &str) -> Result<()> {
            match self.should_fail() {
                Some(MockFailure::ModelNotFound) => Err(anyhow!("Embedding model not found")),
                Some(MockFailure::Timeout) => {
                    Err(anyhow!("Failed to list Ollama models: operation timed out"))
                }
                Some(MockFailure::ApiError(msg)) => {
                    Err(anyhow!("Ollama embed model pull failed (404): {}", msg))
                }
                Some(MockFailure::NetworkError) => Err(anyhow!(
                    "Failed to pull embedding model from Ollama: connection refused"
                )),
                Some(MockFailure::MalformedResponse | MockFailure::EmptyResponse) | None => Ok(()),
            }
        }

        /// Mock generate — returns errors or deterministic responses based on failure mode.
        ///
        /// With no failure (or `ModelNotFound`, which this method ignores) the
        /// reply is chosen by keyword: the first matching branch among
        /// "expand"/"search", "Summarize", "tags" and "contradict" wins;
        /// anything else is echoed back, truncated to at most 50 bytes.
        pub fn generate(&self, prompt: &str, _system: Option<&str>) -> Result<String> {
            match self.should_fail() {
                Some(MockFailure::Timeout) => {
                    return Err(anyhow!("Failed to send chat request: operation timed out"));
                }
                Some(MockFailure::MalformedResponse) => {
                    return Err(anyhow!("Failed to parse chat response: invalid JSON"));
                }
                Some(MockFailure::EmptyResponse) => {
                    return Err(anyhow!("Missing 'message.content' field in chat output"));
                }
                Some(MockFailure::ApiError(msg)) => {
                    return Err(anyhow!("Chat generate failed (500): {}", msg));
                }
                Some(MockFailure::NetworkError) => {
                    return Err(anyhow!("Failed to send chat request: connection refused"));
                }
                Some(MockFailure::ModelNotFound) | None => {}
            }

            // Normal response logic
            if prompt.contains("expand") || prompt.contains("search") {
                Ok("semantic search\nquery terms\nvector retrieval\ninformation retrieval\nsimilarity matching"
                    .to_string())
            } else if prompt.contains("Summarize") {
                Ok("This is a consolidated summary of multiple memories covering key facts and decisions."
                    .to_string())
            } else if prompt.contains("tags") {
                Ok("important\nkey-fact\nstatus-update\ntechnical".to_string())
            } else if prompt.contains("contradict") {
                if prompt.contains("yes") || prompt.contains("true") {
                    Ok("yes".to_string())
                } else {
                    Ok("no".to_string())
                }
            } else {
                // Back up to a char boundary so a multi-byte character that
                // straddles byte 50 cannot make the slice panic.
                let mut end = prompt.len().min(50);
                while !prompt.is_char_boundary(end) {
                    end -= 1;
                }
                Ok("Mock response for: ".to_string() + &prompt[..end])
            }
        }

        /// Mock `expand_query` — returns error or synthetic expansion.
        pub fn expand_query(&self, query: &str) -> Result<Vec<String>> {
            if let Some(err) = self.helper_failure() {
                return Err(err);
            }
            Ok(vec![
                format!("{}-related", query),
                format!("{}-expanded", query),
                "semantic-search".to_string(),
                "vector-expansion".to_string(),
                "query-variants".to_string(),
            ])
        }

        /// Mock `summarize_memories` — fails if no memories.
        ///
        /// The empty-input check runs before the failure-mode check.
        pub fn summarize_memories(&self, memories: &[(String, String)]) -> Result<String> {
            if memories.is_empty() {
                return Err(anyhow!("Cannot summarize empty memories list"));
            }
            if let Some(err) = self.helper_failure() {
                return Err(err);
            }
            let count = memories.len();
            Ok(format!(
                "Summary of {count} memories: consolidated facts and key decisions preserved"
            ))
        }

        /// Mock `auto_tag` — returns three tags, the second derived from the
        /// first whitespace-separated word of `title` (or "data" if none).
        pub fn auto_tag(&self, title: &str, _content: &str) -> Result<Vec<String>> {
            if let Some(err) = self.helper_failure() {
                return Err(err);
            }
            let tags: Vec<String> = vec![
                "important".to_string(),
                format!("{}-tag", title.split_whitespace().next().unwrap_or("data")),
                "memory".to_string(),
            ];
            Ok(tags)
        }

        /// Mock `embed_text` — returns 768-dim vector or error.
        ///
        /// The vector depends only on `text.len() % 10`, so equal inputs give
        /// equal embeddings.
        pub fn embed_text(&self, text: &str, _embed_model: &str) -> Result<Vec<f32>> {
            match self.should_fail() {
                Some(MockFailure::Timeout) => {
                    return Err(anyhow!(
                        "Failed to send embed request to Ollama: operation timed out"
                    ));
                }
                Some(MockFailure::MalformedResponse) => {
                    return Err(anyhow!(
                        "Failed to parse Ollama embed response: invalid JSON"
                    ));
                }
                Some(MockFailure::EmptyResponse) => {
                    return Err(anyhow!("Missing embeddings in Ollama response"));
                }
                Some(MockFailure::ApiError(msg)) => {
                    return Err(anyhow!("Ollama embed failed (500): {}", msg));
                }
                Some(MockFailure::NetworkError) => {
                    return Err(anyhow!(
                        "Failed to send embed request to Ollama: connection refused"
                    ));
                }
                Some(MockFailure::ModelNotFound) => {
                    return Err(anyhow!("Ollama embed failed (404): model not found"));
                }
                None => {}
            }
            let base_val = (text.len() % 10) as f32 / 100.0;
            let embedding: Vec<f32> = (0..768).map(|i| base_val + (i as f32) * 0.0001).collect();
            Ok(embedding)
        }

        /// Mock `detect_contradiction` — keyword heuristic, no model involved.
        ///
        /// Reports a contradiction when more than one of the keywords "not",
        /// "never", "always", "contradiction", "opposite" occurs in the
        /// lower-cased concatenation of both inputs.
        pub fn detect_contradiction(&self, mem_a: &str, mem_b: &str) -> Result<bool> {
            if let Some(err) = self.helper_failure() {
                return Err(err);
            }
            let combined = format!("{mem_a} {mem_b}").to_lowercase();
            let contradictory_keywords = &["not", "never", "always", "contradiction", "opposite"];
            let count = contradictory_keywords
                .iter()
                .filter(|&&kw| combined.contains(kw))
                .count();
            Ok(count > 1)
        }
    }
}
617
618#[cfg(test)]
619mod mock_tests {
620    use super::test_support::MockOllamaClient;
621    use super::{AUTO_TAG_PROMPT, CONTRADICTION_PROMPT, QUERY_EXPANSION_PROMPT, SUMMARIZE_PROMPT};
622
623    #[test]
624    fn test_mock_new_with_url() {
625        let client = MockOllamaClient::new_with_url("http://localhost:11434", "test-model");
626        assert!(client.is_ok());
627        let client = client.unwrap();
628        assert_eq!(client.base_url, "http://localhost:11434");
629        assert_eq!(client.model, "test-model");
630    }
631
632    #[test]
633    fn test_mock_new_with_url_trailing_slash() {
634        let client = MockOllamaClient::new_with_url("http://localhost:11434/", "test-model");
635        assert!(client.is_ok());
636        let client = client.unwrap();
637        assert_eq!(client.base_url, "http://localhost:11434");
638    }
639
640    #[test]
641    fn test_mock_is_available() {
642        let client =
643            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
644        assert!(client.is_available());
645    }
646
647    #[test]
648    fn test_mock_ensure_model() {
649        let client =
650            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
651        assert!(client.ensure_model().is_ok());
652    }
653
654    #[test]
655    fn test_mock_ensure_embed_model() {
656        let client =
657            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
658        assert!(client.ensure_embed_model("nomic-embed-text").is_ok());
659    }
660
661    #[test]
662    fn test_mock_generate_query_expansion() {
663        let client =
664            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
665        let prompt = QUERY_EXPANSION_PROMPT.replace("{query}", "search test");
666        let result = client.generate(&prompt, None);
667        assert!(result.is_ok());
668        let response = result.unwrap();
669        assert!(!response.is_empty());
670    }
671
672    #[test]
673    fn test_mock_expand_query() {
674        let client =
675            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
676        let result = client.expand_query("test query");
677        assert!(result.is_ok());
678        let terms = result.unwrap();
679        assert!(!terms.is_empty());
680        assert!(terms.len() >= 3);
681    }
682
683    #[test]
684    fn test_mock_summarize_memories() {
685        let client =
686            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
687        let memories = vec![
688            ("Title 1".to_string(), "Content 1".to_string()),
689            ("Title 2".to_string(), "Content 2".to_string()),
690        ];
691        let result = client.summarize_memories(&memories);
692        assert!(result.is_ok());
693        let summary = result.unwrap();
694        assert!(summary.contains('2'));
695    }
696
697    #[test]
698    fn test_mock_auto_tag() {
699        let client =
700            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
701        let result = client.auto_tag("Test Title", "test content");
702        assert!(result.is_ok());
703        let tags = result.unwrap();
704        assert!(!tags.is_empty());
705        assert!(tags.len() >= 2);
706    }
707
708    #[test]
709    fn test_mock_embed_text() {
710        let client =
711            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
712        let result = client.embed_text("test text", "nomic-embed-text");
713        assert!(result.is_ok());
714        let embedding = result.unwrap();
715        assert_eq!(embedding.len(), 768);
716        assert!(embedding.iter().all(|&x| x >= 0.0));
717    }
718
719    #[test]
720    fn test_mock_embed_text_deterministic() {
721        let client =
722            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
723        let result1 = client.embed_text("same text", "nomic-embed-text");
724        let result2 = client.embed_text("same text", "nomic-embed-text");
725        assert!(result1.is_ok());
726        assert!(result2.is_ok());
727        assert_eq!(result1.unwrap(), result2.unwrap());
728    }
729
730    #[test]
731    fn test_mock_detect_contradiction_true() {
732        let client =
733            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
734        let result = client.detect_contradiction(
735            "The system always works",
736            "The system never works correctly",
737        );
738        assert!(result.is_ok());
739        let is_contradiction = result.unwrap();
740        assert!(is_contradiction);
741    }
742
743    #[test]
744    fn test_mock_detect_contradiction_false() {
745        let client =
746            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
747        let result = client.detect_contradiction(
748            "The memory is about search",
749            "Additional details about the same search",
750        );
751        assert!(result.is_ok());
752    }
753
754    #[test]
755    fn test_mock_generate_summarize_prompt() {
756        let client =
757            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
758        let prompt = SUMMARIZE_PROMPT.replace(
759            "{memories}",
760            "--- Memory 1 ---\nTitle: Test\nThis is a test",
761        );
762        let result = client.generate(&prompt, None);
763        assert!(result.is_ok());
764        let response = result.unwrap();
765        assert!(response.contains("summary") || response.contains("Summary"));
766    }
767
768    #[test]
769    fn test_mock_generate_auto_tag_prompt() {
770        let client =
771            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
772        let prompt = AUTO_TAG_PROMPT
773            .replace("{title}", "Important Update")
774            .replace("{content}", "Some content");
775        let result = client.generate(&prompt, None);
776        assert!(result.is_ok());
777        let response = result.unwrap();
778        assert!(!response.is_empty());
779    }
780
781    #[test]
782    fn test_mock_generate_contradiction_prompt() {
783        let client =
784            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
785        let prompt = CONTRADICTION_PROMPT
786            .replace("{a}", "Statement A")
787            .replace("{b}", "Statement B");
788        let result = client.generate(&prompt, None);
789        assert!(result.is_ok());
790        let response = result.unwrap();
791        assert!(!response.is_empty());
792    }
793
794    // ===== ERROR PATH TESTS (Agent C: llm.rs 47% → 75% coverage) =====
795
796    #[test]
797    fn test_mock_ensure_model_returns_not_found_error() {
798        let client = MockOllamaClient::with_failure(
799            "http://localhost:11434",
800            "unknown-model",
801            super::test_support::MockFailure::ModelNotFound,
802        )
803        .unwrap();
804        let result = client.ensure_model();
805        assert!(result.is_err());
806        let err_msg = result.unwrap_err().to_string();
807        assert!(err_msg.contains("not found"));
808    }
809
810    #[test]
811    fn test_mock_ensure_model_returns_timeout_error() {
812        let client = MockOllamaClient::with_failure(
813            "http://localhost:11434",
814            "test-model",
815            super::test_support::MockFailure::Timeout,
816        )
817        .unwrap();
818        let result = client.ensure_model();
819        assert!(result.is_err());
820        let err_msg = result.unwrap_err().to_string();
821        assert!(err_msg.contains("timed out"));
822    }
823
824    #[test]
825    fn test_mock_ensure_model_returns_network_error() {
826        let client = MockOllamaClient::with_failure(
827            "http://localhost:11434",
828            "test-model",
829            super::test_support::MockFailure::NetworkError,
830        )
831        .unwrap();
832        let result = client.ensure_model();
833        assert!(result.is_err());
834        let err_msg = result.unwrap_err().to_string();
835        assert!(err_msg.contains("connection"));
836    }
837
838    #[test]
839    fn test_mock_ensure_embed_model_returns_not_found_error() {
840        let client = MockOllamaClient::with_failure(
841            "http://localhost:11434",
842            "test-model",
843            super::test_support::MockFailure::ModelNotFound,
844        )
845        .unwrap();
846        let result = client.ensure_embed_model("unknown-embed-model");
847        assert!(result.is_err());
848    }
849
850    #[test]
851    fn test_mock_generate_returns_timeout_error() {
852        let client = MockOllamaClient::with_failure(
853            "http://localhost:11434",
854            "test-model",
855            super::test_support::MockFailure::Timeout,
856        )
857        .unwrap();
858        let result = client.generate("test prompt", None);
859        assert!(result.is_err());
860        let err_msg = result.unwrap_err().to_string();
861        assert!(err_msg.contains("timed out"));
862    }
863
864    #[test]
865    fn test_mock_generate_handles_malformed_json() {
866        let client = MockOllamaClient::with_failure(
867            "http://localhost:11434",
868            "test-model",
869            super::test_support::MockFailure::MalformedResponse,
870        )
871        .unwrap();
872        let result = client.generate("test prompt", None);
873        assert!(result.is_err());
874    }
875
876    #[test]
877    fn test_mock_generate_handles_empty_response() {
878        let client = MockOllamaClient::with_failure(
879            "http://localhost:11434",
880            "test-model",
881            super::test_support::MockFailure::EmptyResponse,
882        )
883        .unwrap();
884        let result = client.generate("test prompt", None);
885        assert!(result.is_err());
886    }
887
888    #[test]
889    fn test_mock_generate_handles_api_error() {
890        let client = MockOllamaClient::with_failure(
891            "http://localhost:11434",
892            "test-model",
893            super::test_support::MockFailure::ApiError("Internal Error".to_string()),
894        )
895        .unwrap();
896        let result = client.generate("test prompt", None);
897        assert!(result.is_err());
898    }
899
900    #[test]
901    fn test_mock_expand_query_passes_through_generate_error() {
902        let client = MockOllamaClient::with_failure(
903            "http://localhost:11434",
904            "test-model",
905            super::test_support::MockFailure::Timeout,
906        )
907        .unwrap();
908        let result = client.expand_query("test query");
909        assert!(result.is_err());
910    }
911
912    #[test]
913    fn test_mock_summarize_memories_handles_empty_input() {
914        let client =
915            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
916        let empty_memories: Vec<(String, String)> = vec![];
917        let result = client.summarize_memories(&empty_memories);
918        assert!(result.is_err());
919    }
920
921    #[test]
922    fn test_mock_summarize_memories_handles_timeout() {
923        let client = MockOllamaClient::with_failure(
924            "http://localhost:11434",
925            "test-model",
926            super::test_support::MockFailure::Timeout,
927        )
928        .unwrap();
929        let memories = vec![("Title".to_string(), "Content".to_string())];
930        let result = client.summarize_memories(&memories);
931        assert!(result.is_err());
932    }
933
934    #[test]
935    fn test_mock_auto_tag_handles_special_characters() {
936        let client =
937            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
938        let result = client.auto_tag("Title @#$%", "content");
939        assert!(result.is_ok());
940    }
941
942    #[test]
943    fn test_mock_auto_tag_timeout() {
944        let client = MockOllamaClient::with_failure(
945            "http://localhost:11434",
946            "test-model",
947            super::test_support::MockFailure::Timeout,
948        )
949        .unwrap();
950        let result = client.auto_tag("Test", "content");
951        assert!(result.is_err());
952    }
953
954    #[test]
955    fn test_mock_embed_text_returns_768_dim() {
956        let client =
957            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
958        let result = client.embed_text("test", "nomic-embed-text-v1.5");
959        assert!(result.is_ok());
960        assert_eq!(result.unwrap().len(), 768);
961    }
962
963    #[test]
964    fn test_mock_embed_text_timeout() {
965        let client = MockOllamaClient::with_failure(
966            "http://localhost:11434",
967            "test-model",
968            super::test_support::MockFailure::Timeout,
969        )
970        .unwrap();
971        let result = client.embed_text("test", "nomic-embed-text");
972        assert!(result.is_err());
973    }
974
975    #[test]
976    fn test_mock_embed_text_malformed() {
977        let client = MockOllamaClient::with_failure(
978            "http://localhost:11434",
979            "test-model",
980            super::test_support::MockFailure::MalformedResponse,
981        )
982        .unwrap();
983        let result = client.embed_text("test", "nomic-embed-text");
984        assert!(result.is_err());
985    }
986
987    #[test]
988    fn test_mock_embed_text_empty_response() {
989        let client = MockOllamaClient::with_failure(
990            "http://localhost:11434",
991            "test-model",
992            super::test_support::MockFailure::EmptyResponse,
993        )
994        .unwrap();
995        let result = client.embed_text("test", "nomic-embed-text");
996        assert!(result.is_err());
997    }
998
999    #[test]
1000    fn test_mock_embed_text_model_not_found() {
1001        let client = MockOllamaClient::with_failure(
1002            "http://localhost:11434",
1003            "test-model",
1004            super::test_support::MockFailure::ModelNotFound,
1005        )
1006        .unwrap();
1007        let result = client.embed_text("test", "unknown");
1008        assert!(result.is_err());
1009    }
1010
1011    #[test]
1012    fn test_mock_embed_text_network_error() {
1013        let client = MockOllamaClient::with_failure(
1014            "http://localhost:11434",
1015            "test-model",
1016            super::test_support::MockFailure::NetworkError,
1017        )
1018        .unwrap();
1019        let result = client.embed_text("test", "nomic-embed-text");
1020        assert!(result.is_err());
1021    }
1022
1023    #[test]
1024    fn test_mock_detect_contradiction_yes_case() {
1025        let client =
1026            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
1027        let result =
1028            client.detect_contradiction("The system always works", "The system never works");
1029        assert!(result.is_ok());
1030        assert!(result.unwrap());
1031    }
1032
1033    #[test]
1034    fn test_mock_detect_contradiction_no_case() {
1035        let client =
1036            MockOllamaClient::new_with_url("http://localhost:11434", "test-model").unwrap();
1037        let result =
1038            client.detect_contradiction("Consistent statement A", "Consistent statement B");
1039        assert!(result.is_ok());
1040    }
1041
1042    #[test]
1043    fn test_mock_detect_contradiction_timeout() {
1044        let client = MockOllamaClient::with_failure(
1045            "http://localhost:11434",
1046            "test-model",
1047            super::test_support::MockFailure::Timeout,
1048        )
1049        .unwrap();
1050        let result = client.detect_contradiction("A", "B");
1051        assert!(result.is_err());
1052    }
1053
1054    #[test]
1055    fn test_mock_is_available_network_error() {
1056        let client = MockOllamaClient::with_failure(
1057            "http://localhost:11434",
1058            "test-model",
1059            super::test_support::MockFailure::NetworkError,
1060        )
1061        .unwrap();
1062        assert!(!client.is_available());
1063    }
1064
1065    #[test]
1066    fn test_mock_with_failure_creates_client_that_fails() {
1067        let client = MockOllamaClient::with_failure(
1068            "http://localhost:11434",
1069            "test-model",
1070            super::test_support::MockFailure::Timeout,
1071        )
1072        .unwrap();
1073        let result = client.generate("any", None);
1074        assert!(result.is_err());
1075    }
1076
1077    #[test]
1078    fn test_mock_api_error_variant() {
1079        let client = MockOllamaClient::with_failure(
1080            "http://localhost:11434",
1081            "test-model",
1082            super::test_support::MockFailure::ApiError("Custom msg".to_string()),
1083        )
1084        .unwrap();
1085        let result = client.generate("test", None);
1086        assert!(result.is_err());
1087        assert!(result.unwrap_err().to_string().contains("Custom msg"));
1088    }
1089}
1090
1091// =====================================================================
1092// W10 — wiremock-driven HTTP integration tests for the *real* OllamaClient
1093//
1094// These exercise the blocking reqwest call paths inside `OllamaClient`
1095// against an in-process HTTP mock that speaks the Ollama API surface
1096// (`/api/tags`, `/api/chat`, `/api/embed`, `/api/pull`). No real Ollama
1097// daemon is started, no network egress, and the tests stay deterministic.
1098//
1099// The OllamaClient is blocking (reqwest::blocking) but wiremock is async,
1100// so each test uses `#[tokio::test(flavor = "multi_thread")]` and runs
1101// the client via `tokio::task::spawn_blocking` to avoid blocking the
1102// runtime that's hosting the mock server.
1103//
1104// Design notes:
1105//   - `OllamaClient::new_with_url` performs a `/api/tags` GET as a health
1106//     check before returning, so every test that constructs a client
1107//     first wires up a permissive `/api/tags` responder. Tests that want
1108//     to drive specific `/api/tags` behaviour mount the precise matcher
1109//     ahead of any other route so it wins the dispatch.
1110//   - "is_available_returns_false_on_connection_refused" finds a free
1111//     port by briefly binding a TcpListener, captures the address, then
1112//     drops the listener — there is a small race window but the
1113//     `is_available()` health check is wrapped in a 5s timeout so the
1114//     worst-case flake is a slow test, not a wrong assertion.
1115// =====================================================================
1116#[cfg(test)]
1117#[allow(clippy::too_many_lines, clippy::similar_names)]
1118mod wiremock_tests {
1119    use super::OllamaClient;
1120    use serde_json::json;
1121    use std::net::TcpListener;
1122    use wiremock::matchers::{body_partial_json, method, path};
1123    use wiremock::{Mock, MockServer, ResponseTemplate};
1124
1125    /// Mount a default permissive `/api/tags` responder so `new_with_url`'s
1126    /// embedded `is_available()` health check succeeds.
1127    async fn mount_tags_ok(server: &MockServer, models: serde_json::Value) {
1128        Mock::given(method("GET"))
1129            .and(path("/api/tags"))
1130            .respond_with(ResponseTemplate::new(200).set_body_json(models))
1131            .mount(server)
1132            .await;
1133    }
1134
1135    /// Build a real OllamaClient pointed at the supplied mock server.
1136    /// Runs the blocking constructor on the spawn_blocking pool so it
1137    /// doesn't deadlock the test's tokio runtime.
1138    async fn build_client(uri: String, model: &'static str) -> OllamaClient {
1139        tokio::task::spawn_blocking(move || OllamaClient::new_with_url(&uri, model).unwrap())
1140            .await
1141            .unwrap()
1142    }
1143
1144    // ---------------- is_available ----------------
1145
1146    #[tokio::test(flavor = "multi_thread")]
1147    async fn test_is_available_returns_false_on_connection_refused() {
1148        // Reserve a free port, then drop the listener so connecting is
1149        // (almost certainly) refused. The 5s health-check timeout caps
1150        // the worst-case flake.
1151        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1152        let port = listener.local_addr().unwrap().port();
1153        drop(listener);
1154        let url = format!("http://127.0.0.1:{port}");
1155
1156        // Can't go through `new_with_url` — its constructor would error
1157        // out before returning. Instead, build a client by hand by going
1158        // through reqwest directly and asserting the health-probe path
1159        // returns false.
1160        let result = tokio::task::spawn_blocking(move || {
1161            // Use the same builder OllamaClient uses internally so the
1162            // assertion exercises the same code path semantically.
1163            let client = reqwest::blocking::Client::builder()
1164                .timeout(std::time::Duration::from_secs(5))
1165                .build()
1166                .unwrap();
1167            let probe = format!("{url}/api/tags");
1168            client
1169                .get(&probe)
1170                .send()
1171                .is_ok_and(|r| r.status().is_success())
1172        })
1173        .await
1174        .unwrap();
1175
1176        assert!(
1177            !result,
1178            "is_available should return false when nothing is listening"
1179        );
1180    }
1181
1182    #[tokio::test(flavor = "multi_thread")]
1183    async fn test_is_available_returns_false_on_500_response() {
1184        let server = MockServer::start().await;
1185        Mock::given(method("GET"))
1186            .and(path("/api/tags"))
1187            .respond_with(ResponseTemplate::new(500))
1188            .mount(&server)
1189            .await;
1190
1191        let uri = server.uri();
1192        let result = tokio::task::spawn_blocking(move || {
1193            // Constructor will fail (since is_available returns false)
1194            // — verify that path explicitly.
1195            OllamaClient::new_with_url(&uri, "test-model")
1196        })
1197        .await
1198        .unwrap();
1199
1200        // Avoid `unwrap_err()` here because `OllamaClient` doesn't impl
1201        // Debug — match on the Result and pull the message out manually.
1202        let err = match result {
1203            Ok(_) => panic!("client construction should fail on 500"),
1204            Err(e) => e.to_string(),
1205        };
1206        assert!(
1207            err.contains("not running") || err.contains("not reachable"),
1208            "expected unreachable-style error, got: {err}"
1209        );
1210    }
1211
1212    #[tokio::test(flavor = "multi_thread")]
1213    async fn test_is_available_returns_true_on_200_with_json_body() {
1214        let server = MockServer::start().await;
1215        mount_tags_ok(&server, json!({"models": []})).await;
1216
1217        let uri = server.uri();
1218        let available = tokio::task::spawn_blocking(move || {
1219            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1220            client.is_available()
1221        })
1222        .await
1223        .unwrap();
1224        assert!(available);
1225    }
1226
1227    // ---------------- ensure_model (a.k.a. pull_if_missing) ----------------
1228
1229    #[tokio::test(flavor = "multi_thread")]
1230    async fn test_pull_if_missing_skips_pull_if_model_already_in_tags() {
1231        let server = MockServer::start().await;
1232        // /api/tags returns the model already present.
1233        Mock::given(method("GET"))
1234            .and(path("/api/tags"))
1235            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1236                "models": [
1237                    {"name": "test-model:latest"},
1238                ]
1239            })))
1240            .mount(&server)
1241            .await;
1242
1243        // No /api/pull route is mounted. If ensure_model erroneously
1244        // POSTed to /api/pull, wiremock would return 404 and the call
1245        // would fail — `expect(0)` makes that assertion explicit.
1246        Mock::given(method("POST"))
1247            .and(path("/api/pull"))
1248            .respond_with(ResponseTemplate::new(200))
1249            .expect(0)
1250            .mount(&server)
1251            .await;
1252
1253        let uri = server.uri();
1254        let result = tokio::task::spawn_blocking(move || {
1255            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1256            client.ensure_model()
1257        })
1258        .await
1259        .unwrap();
1260        assert!(
1261            result.is_ok(),
1262            "ensure_model should succeed; got {result:?}"
1263        );
1264    }
1265
1266    #[tokio::test(flavor = "multi_thread")]
1267    async fn test_pull_if_missing_initiates_pull_if_not() {
1268        let server = MockServer::start().await;
1269        // /api/tags returns no models.
1270        Mock::given(method("GET"))
1271            .and(path("/api/tags"))
1272            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
1273            .mount(&server)
1274            .await;
1275        // /api/pull is expected to be called exactly once with our model.
1276        Mock::given(method("POST"))
1277            .and(path("/api/pull"))
1278            .and(body_partial_json(json!({"name": "test-model"})))
1279            .respond_with(ResponseTemplate::new(200).set_body_string(""))
1280            .expect(1)
1281            .mount(&server)
1282            .await;
1283
1284        let uri = server.uri();
1285        let result = tokio::task::spawn_blocking(move || {
1286            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1287            client.ensure_model()
1288        })
1289        .await
1290        .unwrap();
1291        assert!(
1292            result.is_ok(),
1293            "ensure_model should succeed; got {result:?}"
1294        );
1295        // wiremock's drop checks the .expect() invariants.
1296    }
1297
1298    // ---------------- generate ----------------
1299
1300    #[tokio::test(flavor = "multi_thread")]
1301    async fn test_generate_parses_success_response() {
1302        let server = MockServer::start().await;
1303        mount_tags_ok(&server, json!({"models": []})).await;
1304        // OllamaClient::generate hits /api/chat (Ollama's chat surface),
1305        // not /api/generate, and reads `message.content`.
1306        Mock::given(method("POST"))
1307            .and(path("/api/chat"))
1308            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1309                "message": {"role": "assistant", "content": "hello"},
1310                "done": true,
1311            })))
1312            .mount(&server)
1313            .await;
1314
1315        let uri = server.uri();
1316        let result = tokio::task::spawn_blocking(move || {
1317            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1318            client.generate("ping", None)
1319        })
1320        .await
1321        .unwrap();
1322
1323        assert_eq!(result.unwrap(), "hello");
1324    }
1325
1326    #[tokio::test(flavor = "multi_thread")]
1327    async fn test_generate_returns_error_on_malformed_json() {
1328        let server = MockServer::start().await;
1329        mount_tags_ok(&server, json!({"models": []})).await;
1330        Mock::given(method("POST"))
1331            .and(path("/api/chat"))
1332            .respond_with(
1333                ResponseTemplate::new(200)
1334                    .set_body_string("{not valid json")
1335                    .insert_header("content-type", "application/json"),
1336            )
1337            .mount(&server)
1338            .await;
1339
1340        let uri = server.uri();
1341        let result = tokio::task::spawn_blocking(move || {
1342            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1343            client.generate("ping", None)
1344        })
1345        .await
1346        .unwrap();
1347
1348        assert!(result.is_err(), "malformed JSON should surface an error");
1349        let err = result.unwrap_err().to_string();
1350        assert!(
1351            err.contains("parse") || err.to_lowercase().contains("json"),
1352            "expected a parse error, got: {err}"
1353        );
1354    }
1355
1356    #[tokio::test(flavor = "multi_thread")]
1357    async fn test_generate_returns_error_on_500() {
1358        let server = MockServer::start().await;
1359        mount_tags_ok(&server, json!({"models": []})).await;
1360        Mock::given(method("POST"))
1361            .and(path("/api/chat"))
1362            .respond_with(ResponseTemplate::new(500).set_body_string("internal boom"))
1363            .mount(&server)
1364            .await;
1365
1366        let uri = server.uri();
1367        let result = tokio::task::spawn_blocking(move || {
1368            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1369            client.generate("ping", None)
1370        })
1371        .await
1372        .unwrap();
1373
1374        assert!(result.is_err());
1375        let err = result.unwrap_err().to_string();
1376        assert!(err.contains("500") || err.contains("Chat generate failed"));
1377    }
1378
1379    #[tokio::test(flavor = "multi_thread")]
1380    async fn test_generate_passes_system_prompt_when_provided() {
1381        // Sanity-check that providing a system prompt still hits the
1382        // chat surface and yields the parsed response — covers the
1383        // `if let Some(sys)` branch of generate().
1384        let server = MockServer::start().await;
1385        mount_tags_ok(&server, json!({"models": []})).await;
1386        Mock::given(method("POST"))
1387            .and(path("/api/chat"))
1388            .and(body_partial_json(json!({
1389                "messages": [
1390                    {"role": "system", "content": "be terse"},
1391                    {"role": "user", "content": "hi"},
1392                ],
1393                "stream": false,
1394            })))
1395            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1396                "message": {"role": "assistant", "content": "ok"},
1397            })))
1398            .mount(&server)
1399            .await;
1400
1401        let uri = server.uri();
1402        let out = tokio::task::spawn_blocking(move || {
1403            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1404            client.generate("hi", Some("be terse"))
1405        })
1406        .await
1407        .unwrap();
1408        assert_eq!(out.unwrap(), "ok");
1409    }
1410
1411    // ---------------- embed_text ----------------
1412
1413    #[tokio::test(flavor = "multi_thread")]
1414    async fn test_embed_parses_embedding_array() {
1415        let server = MockServer::start().await;
1416        mount_tags_ok(&server, json!({"models": []})).await;
1417        // Ollama's /api/embed returns {"embeddings": [[...], ...]}.
1418        Mock::given(method("POST"))
1419            .and(path("/api/embed"))
1420            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1421                "embeddings": [[0.1_f32, 0.2_f32, 0.3_f32]],
1422            })))
1423            .mount(&server)
1424            .await;
1425
1426        let uri = server.uri();
1427        let vec = tokio::task::spawn_blocking(move || {
1428            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1429            client.embed_text("hello", "nomic-embed-text-v1.5")
1430        })
1431        .await
1432        .unwrap();
1433
1434        let v = vec.unwrap();
1435        assert_eq!(v.len(), 3);
1436        assert!((v[0] - 0.1_f32).abs() < 1e-5);
1437        assert!((v[1] - 0.2_f32).abs() < 1e-5);
1438        assert!((v[2] - 0.3_f32).abs() < 1e-5);
1439    }
1440
1441    #[tokio::test(flavor = "multi_thread")]
1442    async fn test_embed_returns_error_on_wrong_shape() {
1443        let server = MockServer::start().await;
1444        mount_tags_ok(&server, json!({"models": []})).await;
1445        // Wrong shape: top-level key is "embedding" (singular, scalar)
1446        // — code expects "embeddings" array-of-arrays.
1447        Mock::given(method("POST"))
1448            .and(path("/api/embed"))
1449            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1450                "embedding": 0.5,
1451            })))
1452            .mount(&server)
1453            .await;
1454
1455        let uri = server.uri();
1456        let result = tokio::task::spawn_blocking(move || {
1457            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1458            client.embed_text("hi", "nomic-embed-text")
1459        })
1460        .await
1461        .unwrap();
1462        assert!(result.is_err());
1463        let err = result.unwrap_err().to_string();
1464        assert!(
1465            err.contains("Missing embeddings") || err.to_lowercase().contains("embed"),
1466            "expected missing-embeddings error, got: {err}"
1467        );
1468    }
1469
1470    #[tokio::test(flavor = "multi_thread")]
1471    async fn test_embed_returns_error_on_500() {
1472        let server = MockServer::start().await;
1473        mount_tags_ok(&server, json!({"models": []})).await;
1474        Mock::given(method("POST"))
1475            .and(path("/api/embed"))
1476            .respond_with(ResponseTemplate::new(500).set_body_string("nope"))
1477            .mount(&server)
1478            .await;
1479
1480        let uri = server.uri();
1481        let result = tokio::task::spawn_blocking(move || {
1482            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1483            client.embed_text("hi", "nomic-embed-text")
1484        })
1485        .await
1486        .unwrap();
1487        assert!(result.is_err());
1488        assert!(result.unwrap_err().to_string().contains("500"));
1489    }
1490
1491    // ---------------- higher-level helpers ----------------
1492
1493    #[tokio::test(flavor = "multi_thread")]
1494    async fn test_expand_query_returns_parsed_terms_one_per_line() {
1495        let server = MockServer::start().await;
1496        mount_tags_ok(&server, json!({"models": []})).await;
1497        Mock::given(method("POST"))
1498            .and(path("/api/chat"))
1499            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1500                // Trailing newline + blank line should be filtered out.
1501                "message": {"content": "term1\nterm2\nterm3\n\n"},
1502            })))
1503            .mount(&server)
1504            .await;
1505
1506        let uri = server.uri();
1507        let terms = tokio::task::spawn_blocking(move || {
1508            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1509            client.expand_query("anything")
1510        })
1511        .await
1512        .unwrap();
1513        assert_eq!(
1514            terms.unwrap(),
1515            vec![
1516                "term1".to_string(),
1517                "term2".to_string(),
1518                "term3".to_string()
1519            ]
1520        );
1521    }
1522
1523    #[tokio::test(flavor = "multi_thread")]
1524    async fn test_auto_tag_returns_parsed_tags() {
1525        let server = MockServer::start().await;
1526        mount_tags_ok(&server, json!({"models": []})).await;
1527        // The auto_tag prompt asks for "one per line, lowercase". The
1528        // module also lowercases each line itself so we verify casing
1529        // is normalised by sending mixed case.
1530        Mock::given(method("POST"))
1531            .and(path("/api/chat"))
1532            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1533                "message": {"content": "Tag1\nTAG2\ntag3"},
1534            })))
1535            .mount(&server)
1536            .await;
1537
1538        let uri = server.uri();
1539        let tags = tokio::task::spawn_blocking(move || {
1540            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1541            client.auto_tag("Title", "content")
1542        })
1543        .await
1544        .unwrap();
1545        assert_eq!(
1546            tags.unwrap(),
1547            vec!["tag1".to_string(), "tag2".to_string(), "tag3".to_string()]
1548        );
1549    }
1550
1551    #[tokio::test(flavor = "multi_thread")]
1552    async fn test_detect_contradiction_parses_yes_no() {
1553        // Verify three branches in one test: "yes" → true,
1554        // "no" → false, garbage → false (default behaviour falls out
1555        // of `starts_with("yes")`).
1556        let server = MockServer::start().await;
1557        mount_tags_ok(&server, json!({"models": []})).await;
1558        Mock::given(method("POST"))
1559            .and(path("/api/chat"))
1560            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1561                "message": {"content": "yes\n"},
1562            })))
1563            .mount(&server)
1564            .await;
1565
1566        let uri_yes = server.uri();
1567        let yes = tokio::task::spawn_blocking(move || {
1568            let client = OllamaClient::new_with_url(&uri_yes, "test-model").unwrap();
1569            client.detect_contradiction("a", "b")
1570        })
1571        .await
1572        .unwrap();
1573        assert!(yes.unwrap(), "'yes' should be detected as contradiction");
1574
1575        // Stand up a fresh server to swap the response — wiremock mounts
1576        // are additive and we want a single deterministic responder.
1577        let server_no = MockServer::start().await;
1578        mount_tags_ok(&server_no, json!({"models": []})).await;
1579        Mock::given(method("POST"))
1580            .and(path("/api/chat"))
1581            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1582                "message": {"content": "no"},
1583            })))
1584            .mount(&server_no)
1585            .await;
1586        let uri_no = server_no.uri();
1587        let no = tokio::task::spawn_blocking(move || {
1588            let client = OllamaClient::new_with_url(&uri_no, "test-model").unwrap();
1589            client.detect_contradiction("a", "b")
1590        })
1591        .await
1592        .unwrap();
1593        assert!(!no.unwrap(), "'no' should NOT be detected as contradiction");
1594
1595        // Garbage input should fall through `starts_with("yes")` → false.
1596        let server_garbage = MockServer::start().await;
1597        mount_tags_ok(&server_garbage, json!({"models": []})).await;
1598        Mock::given(method("POST"))
1599            .and(path("/api/chat"))
1600            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1601                "message": {"content": "definitely-not-yes-or-no"},
1602            })))
1603            .mount(&server_garbage)
1604            .await;
1605        let uri_g = server_garbage.uri();
1606        let garbage = tokio::task::spawn_blocking(move || {
1607            let client = OllamaClient::new_with_url(&uri_g, "test-model").unwrap();
1608            client.detect_contradiction("a", "b")
1609        })
1610        .await
1611        .unwrap();
1612        assert!(
1613            !garbage.unwrap(),
1614            "garbage answer should default to non-contradiction"
1615        );
1616    }
1617
1618    // ---------------- ensure_embed_model ----------------
1619
1620    #[tokio::test(flavor = "multi_thread")]
1621    async fn test_ensure_embed_model_skips_pull_if_present() {
1622        let server = MockServer::start().await;
1623        Mock::given(method("GET"))
1624            .and(path("/api/tags"))
1625            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1626                "models": [{"name": "nomic-embed-text:latest"}]
1627            })))
1628            .mount(&server)
1629            .await;
1630        Mock::given(method("POST"))
1631            .and(path("/api/pull"))
1632            .respond_with(ResponseTemplate::new(200))
1633            .expect(0)
1634            .mount(&server)
1635            .await;
1636
1637        let uri = server.uri();
1638        let r = tokio::task::spawn_blocking(move || {
1639            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
1640            client.ensure_embed_model("nomic-embed-text")
1641        })
1642        .await
1643        .unwrap();
1644        assert!(r.is_ok());
1645    }
1646}