1use anyhow::{Context, Result, anyhow};
5use serde_json::{Value, json};
6use std::time::Duration;
7
/// Server address used by [`OllamaClient::new`] when no explicit URL is supplied.
const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";

/// Default timeout of the shared HTTP client; also set explicitly on chat and embed calls.
const GENERATE_TIMEOUT: Duration = Duration::from_secs(30);
/// Upper bound for a model pull request.
const PULL_TIMEOUT: Duration = Duration::from_secs(120);

/// Asks for 5-8 related search terms, one per line. Placeholder: `{query}`.
const QUERY_EXPANSION_PROMPT: &str = "You are a search query expander. Given a search query, \
    generate 5-8 additional search terms that are semantically related. Return ONLY the terms, \
    one per line, no numbering or explanation.\n\nQuery: {query}";

/// Asks for a one-paragraph summary of several memories. Placeholder: `{memories}`.
const SUMMARIZE_PROMPT: &str = "Summarize the following memories into a single concise \
    paragraph. Preserve all key facts, decisions, and technical details.\n\n{memories}";

/// Asks for 3-5 lowercase tags, one per line. Placeholders: `{title}`, `{content}`.
const AUTO_TAG_PROMPT: &str = "Generate 3-5 short tags for categorizing this memory. \
    Return ONLY the tags, one per line, lowercase, no symbols.\n\n\
    Title: {title}\nContent: {content}";

/// Asks for a bare yes/no contradiction verdict. Placeholders: `{a}`, `{b}`.
const CONTRADICTION_PROMPT: &str = "Do these two statements contradict each other? \
    Answer ONLY \"yes\" or \"no\".\n\nStatement A: {a}\nStatement B: {b}";
30
31pub struct OllamaClient {
32 base_url: String,
33 model: String,
34 client: reqwest::blocking::Client,
35}
36
37impl OllamaClient {
38 #[allow(dead_code)]
41 pub fn new(model: &str) -> Result<Self> {
42 Self::new_with_url(DEFAULT_OLLAMA_URL, model)
43 }
44
45 pub fn new_with_url(base_url: &str, model: &str) -> Result<Self> {
48 let client = reqwest::blocking::Client::builder()
49 .timeout(GENERATE_TIMEOUT)
50 .build()
51 .context("Failed to build HTTP client")?;
52
53 let instance = Self {
54 base_url: base_url.trim_end_matches('/').to_string(),
55 model: model.to_string(),
56 client,
57 };
58
59 if !instance.is_available() {
60 return Err(anyhow!(
61 "Ollama is not running or not reachable at {}. \
62 Start it with: ollama serve",
63 instance.base_url
64 ));
65 }
66
67 Ok(instance)
68 }
69
70 pub fn is_available(&self) -> bool {
72 let url = format!("{}/api/tags", self.base_url);
73 self.client
74 .get(&url)
75 .timeout(Duration::from_secs(5))
76 .send()
77 .is_ok_and(|r| r.status().is_success())
78 }
79
80 pub fn ensure_model(&self) -> Result<()> {
82 let url = format!("{}/api/tags", self.base_url);
84 let resp = self
85 .client
86 .get(&url)
87 .timeout(Duration::from_secs(10))
88 .send()
89 .context("Failed to list Ollama models")?;
90
91 let body: Value = resp.json().context("Failed to parse /api/tags response")?;
92
93 let model_exists = body["models"].as_array().is_some_and(|models| {
94 models.iter().any(|m| {
95 let name = m["name"].as_str().unwrap_or("");
96 let our_base = self.model.split(':').next().unwrap_or(&self.model);
99 name == self.model
100 || name.starts_with(&format!("{}:", self.model))
101 || self.model == name.split(':').next().unwrap_or("")
102 || name == our_base
103 })
104 });
105
106 if model_exists {
107 return Ok(());
108 }
109
110 tracing::info!(
112 "Pulling Ollama model '{}' (this may take a while)...",
113 self.model
114 );
115
116 let pull_url = format!("{}/api/pull", self.base_url);
117 let pull_client = reqwest::blocking::Client::builder()
118 .timeout(PULL_TIMEOUT)
119 .build()
120 .context("Failed to build pull client")?;
121
122 let resp = pull_client
123 .post(&pull_url)
124 .json(&json!({ "name": self.model }))
125 .send()
126 .context("Failed to pull model from Ollama")?;
127
128 if !resp.status().is_success() {
129 let status = resp.status();
130 let text = resp.text().unwrap_or_default();
131 return Err(anyhow!("Ollama pull failed ({status}): {text}"));
132 }
133
134 tracing::info!("Model '{}' pulled successfully", self.model);
135 Ok(())
136 }
137
138 pub fn generate(&self, prompt: &str, system: Option<&str>) -> Result<String> {
142 let url = format!("{}/api/chat", self.base_url);
143
144 let mut messages = Vec::new();
145 if let Some(sys) = system {
146 messages.push(json!({"role": "system", "content": sys}));
147 }
148 messages.push(json!({"role": "user", "content": prompt}));
149
150 let payload = json!({
151 "model": self.model,
152 "messages": messages,
153 "stream": false,
154 });
155
156 let resp = self
157 .client
158 .post(&url)
159 .timeout(GENERATE_TIMEOUT)
160 .json(&payload)
161 .send()
162 .context("Failed to send chat request")?;
163
164 if !resp.status().is_success() {
165 let status = resp.status();
166 let text = resp.text().unwrap_or_default();
167 return Err(anyhow!("Chat generate failed ({status}): {text}"));
168 }
169
170 let body: Value = resp.json().context("Failed to parse chat response")?;
171
172 let response_text = body["message"]["content"]
174 .as_str()
175 .ok_or_else(|| anyhow!("Missing 'message.content' field in chat output"))?
176 .to_string();
177
178 Ok(response_text)
179 }
180
181 pub fn expand_query(&self, query: &str) -> Result<Vec<String>> {
183 let prompt = QUERY_EXPANSION_PROMPT.replace("{query}", query);
184 let response = self.generate(&prompt, None)?;
185
186 let terms: Vec<String> = response
187 .lines()
188 .map(|line| line.trim().to_string())
189 .filter(|line| !line.is_empty())
190 .collect();
191
192 Ok(terms)
193 }
194
195 pub fn summarize_memories(&self, memories: &[(String, String)]) -> Result<String> {
197 let formatted = memories
198 .iter()
199 .enumerate()
200 .map(|(i, (title, content))| {
201 format!("--- Memory {} ---\nTitle: {}\n{}", i + 1, title, content)
202 })
203 .collect::<Vec<_>>()
204 .join("\n\n");
205
206 let prompt = SUMMARIZE_PROMPT.replace("{memories}", &formatted);
207 let response = self.generate(&prompt, None)?;
208
209 Ok(response.trim().to_string())
210 }
211
212 pub fn auto_tag(&self, title: &str, content: &str) -> Result<Vec<String>> {
214 let prompt = AUTO_TAG_PROMPT
215 .replace("{title}", title)
216 .replace("{content}", content);
217
218 let response = self.generate(&prompt, None)?;
219
220 let tags: Vec<String> = response
221 .lines()
222 .map(|line| line.trim().to_lowercase())
223 .filter(|line| !line.is_empty())
224 .collect();
225
226 Ok(tags)
227 }
228
229 pub fn embed_text(&self, text: &str, embed_model: &str) -> Result<Vec<f32>> {
233 let url = format!("{}/api/embed", self.base_url);
234 let payload = json!({
235 "model": embed_model,
236 "input": text,
237 });
238
239 let resp = self
240 .client
241 .post(&url)
242 .timeout(GENERATE_TIMEOUT)
243 .json(&payload)
244 .send()
245 .context("Failed to send embed request to Ollama")?;
246
247 if !resp.status().is_success() {
248 let status = resp.status();
249 let text = resp.text().unwrap_or_default();
250 return Err(anyhow!("Ollama embed failed ({status}): {text}"));
251 }
252
253 let body: Value = resp
254 .json()
255 .context("Failed to parse Ollama embed response")?;
256
257 let embedding = body["embeddings"]
259 .as_array()
260 .and_then(|arr| arr.first())
261 .and_then(|v| v.as_array())
262 .ok_or_else(|| anyhow!("Missing embeddings in Ollama response"))?;
263
264 #[allow(clippy::cast_possible_truncation)]
265 let floats: Vec<f32> = embedding
266 .iter()
267 .filter_map(|v| v.as_f64().map(|f| f as f32))
268 .collect();
269
270 if floats.is_empty() {
271 return Err(anyhow!("Empty embedding returned from Ollama"));
272 }
273
274 Ok(floats)
275 }
276
277 pub fn ensure_embed_model(&self, model: &str) -> Result<()> {
279 let url = format!("{}/api/tags", self.base_url);
280 let resp = self
281 .client
282 .get(&url)
283 .timeout(std::time::Duration::from_secs(10))
284 .send()
285 .context("Failed to list Ollama models")?;
286
287 let body: Value = resp.json().context("Failed to parse /api/tags response")?;
288 let model_exists = body["models"].as_array().is_some_and(|models| {
289 models.iter().any(|m| {
290 let name = m["name"].as_str().unwrap_or("");
291 name == model
292 || name.starts_with(&format!("{model}:"))
293 || model == name.split(':').next().unwrap_or("")
294 })
295 });
296
297 if model_exists {
298 return Ok(());
299 }
300
301 tracing::info!("Pulling Ollama embedding model '{}'...", model);
302 let pull_url = format!("{}/api/pull", self.base_url);
303 let pull_client = reqwest::blocking::Client::builder()
304 .timeout(PULL_TIMEOUT)
305 .build()
306 .context("Failed to build pull client")?;
307 let resp = pull_client
308 .post(&pull_url)
309 .json(&json!({ "name": model }))
310 .send()
311 .context("Failed to pull embedding model from Ollama")?;
312
313 if !resp.status().is_success() {
314 let status = resp.status();
315 let text = resp.text().unwrap_or_default();
316 return Err(anyhow!("Ollama embed model pull failed ({status}): {text}"));
317 }
318
319 tracing::info!("Embedding model '{}' pulled successfully", model);
320 Ok(())
321 }
322
323 pub fn detect_contradiction(&self, mem_a: &str, mem_b: &str) -> Result<bool> {
325 let prompt = CONTRADICTION_PROMPT
326 .replace("{a}", mem_a)
327 .replace("{b}", mem_b);
328
329 let response = self.generate(&prompt, None)?;
330 let answer = response.trim().to_lowercase();
331
332 Ok(answer.starts_with("yes"))
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Each prompt template must contain the placeholders its caller substitutes.
    #[test]
    fn test_prompt_templates_have_placeholders() {
        let required = [
            (QUERY_EXPANSION_PROMPT, "{query}"),
            (SUMMARIZE_PROMPT, "{memories}"),
            (AUTO_TAG_PROMPT, "{title}"),
            (AUTO_TAG_PROMPT, "{content}"),
            (CONTRADICTION_PROMPT, "{a}"),
            (CONTRADICTION_PROMPT, "{b}"),
        ];
        for &(template, placeholder) in &required {
            assert!(template.contains(placeholder));
        }
    }

    /// Pins the fallback endpoint used by `OllamaClient::new`.
    #[test]
    fn test_default_url() {
        let expected = "http://localhost:11434";
        assert_eq!(DEFAULT_OLLAMA_URL, expected);
    }
}
355
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports,
    clippy::doc_markdown
)]
pub mod test_support {
    //! Offline stand-in for `OllamaClient`. Every method returns canned data, or a
    //! canned error whose text mirrors the real client's messages when a
    //! [`MockFailure`] is configured.
    use super::*;

    /// Failure mode a [`MockOllamaClient`] simulates.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum MockFailure {
        ModelNotFound,
        Timeout,
        MalformedResponse,
        ApiError(String),
        EmptyResponse,
        NetworkError,
    }

    /// Network-free double exposing the same method names as `OllamaClient`.
    #[derive(Debug)]
    pub struct MockOllamaClient {
        pub base_url: String,
        pub model: String,
        pub fail_with: Option<MockFailure>,
    }

    impl MockOllamaClient {
        /// Builds a mock that succeeds on every call.
        pub fn new_with_url(base_url: &str, model: &str) -> Result<Self> {
            Ok(Self::build(base_url, model, None))
        }

        /// Builds a mock that simulates `failure` wherever a method supports it.
        pub fn with_failure(base_url: &str, model: &str, failure: MockFailure) -> Result<Self> {
            Ok(Self::build(base_url, model, Some(failure)))
        }

        /// Shared constructor; strips a trailing `/` like the real client does.
        fn build(base_url: &str, model: &str, fail_with: Option<MockFailure>) -> Self {
            Self {
                base_url: base_url.trim_end_matches('/').to_string(),
                model: model.to_string(),
                fail_with,
            }
        }

        fn should_fail(&self) -> Option<&MockFailure> {
            self.fail_with.as_ref()
        }

        /// Error returned by the prompt-level helpers (`expand_query`,
        /// `summarize_memories`, `auto_tag`, `detect_contradiction`) when any failure
        /// is configured; `Ok(())` otherwise.
        fn prompt_helper_failure(&self) -> Result<()> {
            match self.should_fail() {
                None => Ok(()),
                Some(MockFailure::Timeout) => Err(anyhow!(
                    "Failed to send chat request: operation timed out"
                )),
                Some(MockFailure::MalformedResponse) => Err(anyhow!(
                    "Failed to parse chat response: invalid JSON"
                )),
                Some(MockFailure::ApiError(msg)) => {
                    Err(anyhow!("Chat generate failed (500): {}", msg))
                }
                Some(_) => Err(anyhow!("Generate failed")),
            }
        }

        pub fn is_available(&self) -> bool {
            !matches!(self.should_fail(), Some(MockFailure::NetworkError))
        }

        pub fn ensure_model(&self) -> Result<()> {
            match self.should_fail() {
                Some(MockFailure::ModelNotFound) => Err(anyhow!(
                    "Model 'unknown-model' not found in Ollama registry"
                )),
                Some(MockFailure::Timeout) => {
                    Err(anyhow!("Failed to list Ollama models: operation timed out"))
                }
                Some(MockFailure::ApiError(msg)) => {
                    Err(anyhow!("Ollama pull failed (404): {}", msg))
                }
                Some(MockFailure::NetworkError) => Err(anyhow!(
                    "Failed to pull model from Ollama: connection refused"
                )),
                _ => Ok(()),
            }
        }

        pub fn ensure_embed_model(&self, _model: &str) -> Result<()> {
            match self.should_fail() {
                Some(MockFailure::ModelNotFound) => Err(anyhow!("Embedding model not found")),
                Some(MockFailure::Timeout) => {
                    Err(anyhow!("Failed to list Ollama models: operation timed out"))
                }
                Some(MockFailure::ApiError(msg)) => {
                    Err(anyhow!("Ollama embed model pull failed (404): {}", msg))
                }
                Some(MockFailure::NetworkError) => Err(anyhow!(
                    "Failed to pull embedding model from Ollama: connection refused"
                )),
                _ => Ok(()),
            }
        }

        /// Returns a canned reply chosen by keywords in `prompt`.
        ///
        /// NOTE: `CONTRADICTION_PROMPT` itself contains the word "yes", so prompts built
        /// from it always take the "yes" branch below.
        pub fn generate(&self, prompt: &str, _system: Option<&str>) -> Result<String> {
            match self.should_fail() {
                Some(MockFailure::Timeout) => {
                    return Err(anyhow!("Failed to send chat request: operation timed out"));
                }
                Some(MockFailure::MalformedResponse) => {
                    return Err(anyhow!("Failed to parse chat response: invalid JSON"));
                }
                Some(MockFailure::EmptyResponse) => {
                    return Err(anyhow!("Missing 'message.content' field in chat output"));
                }
                Some(MockFailure::ApiError(msg)) => {
                    return Err(anyhow!("Chat generate failed (500): {}", msg));
                }
                Some(MockFailure::NetworkError) => {
                    return Err(anyhow!("Failed to send chat request: connection refused"));
                }
                _ => {}
            }

            if prompt.contains("expand") || prompt.contains("search") {
                Ok("semantic search\nquery terms\nvector retrieval\ninformation retrieval\nsimilarity matching"
                    .to_string())
            } else if prompt.contains("Summarize") {
                Ok("This is a consolidated summary of multiple memories covering key facts and decisions."
                    .to_string())
            } else if prompt.contains("tags") {
                Ok("important\nkey-fact\nstatus-update\ntechnical".to_string())
            } else if prompt.contains("contradict") {
                if prompt.contains("yes") || prompt.contains("true") {
                    Ok("yes".to_string())
                } else {
                    Ok("no".to_string())
                }
            } else {
                // Echo at most 50 bytes, backing off to a char boundary so a multi-byte
                // character straddling byte 50 cannot make the slice panic.
                let mut end = prompt.len().min(50);
                while !prompt.is_char_boundary(end) {
                    end -= 1;
                }
                Ok(format!("Mock response for: {}", &prompt[..end]))
            }
        }

        pub fn expand_query(&self, query: &str) -> Result<Vec<String>> {
            self.prompt_helper_failure()?;
            Ok(vec![
                format!("{}-related", query),
                format!("{}-expanded", query),
                "semantic-search".to_string(),
                "vector-expansion".to_string(),
                "query-variants".to_string(),
            ])
        }

        pub fn summarize_memories(&self, memories: &[(String, String)]) -> Result<String> {
            // Input validation takes precedence over configured failures.
            if memories.is_empty() {
                return Err(anyhow!("Cannot summarize empty memories list"));
            }
            self.prompt_helper_failure()?;
            let count = memories.len();
            Ok(format!(
                "Summary of {count} memories: consolidated facts and key decisions preserved"
            ))
        }

        pub fn auto_tag(&self, title: &str, _content: &str) -> Result<Vec<String>> {
            self.prompt_helper_failure()?;
            Ok(vec![
                "important".to_string(),
                format!("{}-tag", title.split_whitespace().next().unwrap_or("data")),
                "memory".to_string(),
            ])
        }

        /// Returns a deterministic 768-dimensional vector derived from `text.len()`.
        pub fn embed_text(&self, text: &str, _embed_model: &str) -> Result<Vec<f32>> {
            match self.should_fail() {
                Some(MockFailure::Timeout) => {
                    return Err(anyhow!(
                        "Failed to send embed request to Ollama: operation timed out"
                    ));
                }
                Some(MockFailure::MalformedResponse) => {
                    return Err(anyhow!(
                        "Failed to parse Ollama embed response: invalid JSON"
                    ));
                }
                Some(MockFailure::EmptyResponse) => {
                    return Err(anyhow!("Missing embeddings in Ollama response"));
                }
                Some(MockFailure::ApiError(msg)) => {
                    return Err(anyhow!("Ollama embed failed (500): {}", msg));
                }
                Some(MockFailure::NetworkError) => {
                    return Err(anyhow!(
                        "Failed to send embed request to Ollama: connection refused"
                    ));
                }
                Some(MockFailure::ModelNotFound) => {
                    return Err(anyhow!("Ollama embed failed (404): model not found"));
                }
                None => {}
            }
            let base_val = (text.len() % 10) as f32 / 100.0;
            let embedding: Vec<f32> = (0..768).map(|i| base_val + (i as f32) * 0.0001).collect();
            Ok(embedding)
        }

        /// Reports a contradiction when more than one "contradictory" keyword appears
        /// in the lowercased concatenation of both memories.
        pub fn detect_contradiction(&self, mem_a: &str, mem_b: &str) -> Result<bool> {
            self.prompt_helper_failure()?;
            let combined = format!("{mem_a} {mem_b}").to_lowercase();
            let contradictory_keywords = &["not", "never", "always", "contradiction", "opposite"];
            let count = contradictory_keywords
                .iter()
                .filter(|&&kw| combined.contains(kw))
                .count();
            Ok(count > 1)
        }
    }

    /// Regression: the fallback reply used to slice at byte 50 unconditionally and
    /// panicked when that byte fell inside a multi-byte character.
    #[test]
    fn generate_fallback_truncates_on_char_boundary() {
        let client = MockOllamaClient::new_with_url("http://localhost:11434", "m").unwrap();
        // 49 ASCII bytes then a 2-byte 'é', so byte 50 is not a char boundary.
        let prompt = format!("{}é tail", "x".repeat(49));
        let reply = client.generate(&prompt, None).unwrap();
        assert_eq!(reply, format!("Mock response for: {}", "x".repeat(49)));
    }
}
617
#[cfg(test)]
mod mock_tests {
    //! Tests for `MockOllamaClient` on the happy path and for each `MockFailure`.
    use super::test_support::{MockFailure, MockOllamaClient};
    use super::{AUTO_TAG_PROMPT, CONTRADICTION_PROMPT, QUERY_EXPANSION_PROMPT, SUMMARIZE_PROMPT};

    const URL: &str = "http://localhost:11434";
    const MODEL: &str = "test-model";

    /// A mock that succeeds on every call.
    fn healthy_client() -> MockOllamaClient {
        MockOllamaClient::new_with_url(URL, MODEL).unwrap()
    }

    /// A mock configured to simulate `failure`.
    fn failing_client(failure: MockFailure) -> MockOllamaClient {
        MockOllamaClient::with_failure(URL, MODEL, failure).unwrap()
    }

    /// Unwraps the error of `result` and returns its message.
    fn err_text<T: std::fmt::Debug>(result: anyhow::Result<T>) -> String {
        result.unwrap_err().to_string()
    }

    #[test]
    fn test_mock_new_with_url() {
        let client = MockOllamaClient::new_with_url(URL, MODEL);
        assert!(client.is_ok());
        let client = client.unwrap();
        assert_eq!(client.base_url, URL);
        assert_eq!(client.model, MODEL);
    }

    #[test]
    fn test_mock_new_with_url_trailing_slash() {
        let client = MockOllamaClient::new_with_url("http://localhost:11434/", MODEL);
        assert!(client.is_ok());
        assert_eq!(client.unwrap().base_url, URL);
    }

    #[test]
    fn test_mock_is_available() {
        assert!(healthy_client().is_available());
    }

    #[test]
    fn test_mock_ensure_model() {
        assert!(healthy_client().ensure_model().is_ok());
    }

    #[test]
    fn test_mock_ensure_embed_model() {
        assert!(healthy_client().ensure_embed_model("nomic-embed-text").is_ok());
    }

    #[test]
    fn test_mock_generate_query_expansion() {
        let prompt = QUERY_EXPANSION_PROMPT.replace("{query}", "search test");
        let response = healthy_client().generate(&prompt, None).unwrap();
        assert!(!response.is_empty());
    }

    #[test]
    fn test_mock_expand_query() {
        let terms = healthy_client().expand_query("test query").unwrap();
        assert!(terms.len() >= 3);
        assert!(terms.iter().any(|t| t == "test query-related"));
    }

    #[test]
    fn test_mock_summarize_memories() {
        let memories = vec![
            ("Title 1".to_string(), "Content 1".to_string()),
            ("Title 2".to_string(), "Content 2".to_string()),
        ];
        let summary = healthy_client().summarize_memories(&memories).unwrap();
        assert!(summary.contains('2'));
    }

    #[test]
    fn test_mock_auto_tag() {
        let tags = healthy_client().auto_tag("Test Title", "test content").unwrap();
        assert!(tags.len() >= 2);
        assert!(tags.iter().any(|t| t == "Test-tag"));
    }

    #[test]
    fn test_mock_embed_text() {
        let embedding = healthy_client()
            .embed_text("test text", "nomic-embed-text")
            .unwrap();
        assert_eq!(embedding.len(), 768);
        assert!(embedding.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn test_mock_embed_text_deterministic() {
        let client = healthy_client();
        let first = client.embed_text("same text", "nomic-embed-text").unwrap();
        let second = client.embed_text("same text", "nomic-embed-text").unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn test_mock_detect_contradiction_true() {
        let verdict = healthy_client()
            .detect_contradiction("The system always works", "The system never works correctly")
            .unwrap();
        assert!(verdict);
    }

    #[test]
    fn test_mock_detect_contradiction_false() {
        let verdict = healthy_client()
            .detect_contradiction(
                "The memory is about search",
                "Additional details about the same search",
            )
            .unwrap();
        // Previously only `is_ok()` was checked; pin the actual verdict.
        assert!(!verdict);
    }

    #[test]
    fn test_mock_generate_summarize_prompt() {
        let prompt = SUMMARIZE_PROMPT.replace(
            "{memories}",
            "--- Memory 1 ---\nTitle: Test\nThis is a test",
        );
        let response = healthy_client().generate(&prompt, None).unwrap();
        assert!(response.contains("summary") || response.contains("Summary"));
    }

    #[test]
    fn test_mock_generate_auto_tag_prompt() {
        let prompt = AUTO_TAG_PROMPT
            .replace("{title}", "Important Update")
            .replace("{content}", "Some content");
        let response = healthy_client().generate(&prompt, None).unwrap();
        assert!(!response.is_empty());
    }

    #[test]
    fn test_mock_generate_contradiction_prompt() {
        let prompt = CONTRADICTION_PROMPT
            .replace("{a}", "Statement A")
            .replace("{b}", "Statement B");
        let response = healthy_client().generate(&prompt, None).unwrap();
        assert!(!response.is_empty());
    }

    #[test]
    fn test_mock_ensure_model_returns_not_found_error() {
        let client =
            MockOllamaClient::with_failure(URL, "unknown-model", MockFailure::ModelNotFound)
                .unwrap();
        assert!(err_text(client.ensure_model()).contains("not found"));
    }

    #[test]
    fn test_mock_ensure_model_returns_timeout_error() {
        let client = failing_client(MockFailure::Timeout);
        assert!(err_text(client.ensure_model()).contains("timed out"));
    }

    #[test]
    fn test_mock_ensure_model_returns_network_error() {
        let client = failing_client(MockFailure::NetworkError);
        assert!(err_text(client.ensure_model()).contains("connection"));
    }

    #[test]
    fn test_mock_ensure_embed_model_returns_not_found_error() {
        let client = failing_client(MockFailure::ModelNotFound);
        assert!(err_text(client.ensure_embed_model("unknown-embed-model")).contains("not found"));
    }

    #[test]
    fn test_mock_generate_returns_timeout_error() {
        let client = failing_client(MockFailure::Timeout);
        assert!(err_text(client.generate("test prompt", None)).contains("timed out"));
    }

    #[test]
    fn test_mock_generate_handles_malformed_json() {
        let client = failing_client(MockFailure::MalformedResponse);
        assert!(err_text(client.generate("test prompt", None)).contains("invalid JSON"));
    }

    #[test]
    fn test_mock_generate_handles_empty_response() {
        let client = failing_client(MockFailure::EmptyResponse);
        assert!(err_text(client.generate("test prompt", None)).contains("message.content"));
    }

    #[test]
    fn test_mock_generate_handles_api_error() {
        let client = failing_client(MockFailure::ApiError("Internal Error".to_string()));
        assert!(err_text(client.generate("test prompt", None)).contains("Internal Error"));
    }

    #[test]
    fn test_mock_expand_query_passes_through_generate_error() {
        let client = failing_client(MockFailure::Timeout);
        assert!(err_text(client.expand_query("test query")).contains("timed out"));
    }

    #[test]
    fn test_mock_summarize_memories_handles_empty_input() {
        let empty_memories: Vec<(String, String)> = vec![];
        let message = err_text(healthy_client().summarize_memories(&empty_memories));
        assert!(message.contains("empty"));
    }

    #[test]
    fn test_mock_summarize_memories_handles_timeout() {
        let client = failing_client(MockFailure::Timeout);
        let memories = vec![("Title".to_string(), "Content".to_string())];
        assert!(err_text(client.summarize_memories(&memories)).contains("timed out"));
    }

    #[test]
    fn test_mock_auto_tag_handles_special_characters() {
        assert!(healthy_client().auto_tag("Title @#$%", "content").is_ok());
    }

    #[test]
    fn test_mock_auto_tag_timeout() {
        let client = failing_client(MockFailure::Timeout);
        assert!(err_text(client.auto_tag("Test", "content")).contains("timed out"));
    }

    #[test]
    fn test_mock_embed_text_returns_768_dim() {
        let embedding = healthy_client()
            .embed_text("test", "nomic-embed-text-v1.5")
            .unwrap();
        assert_eq!(embedding.len(), 768);
    }

    #[test]
    fn test_mock_embed_text_timeout() {
        let client = failing_client(MockFailure::Timeout);
        assert!(err_text(client.embed_text("test", "nomic-embed-text")).contains("timed out"));
    }

    #[test]
    fn test_mock_embed_text_malformed() {
        let client = failing_client(MockFailure::MalformedResponse);
        assert!(client.embed_text("test", "nomic-embed-text").is_err());
    }

    #[test]
    fn test_mock_embed_text_empty_response() {
        let client = failing_client(MockFailure::EmptyResponse);
        assert!(client.embed_text("test", "nomic-embed-text").is_err());
    }

    #[test]
    fn test_mock_embed_text_model_not_found() {
        let client = failing_client(MockFailure::ModelNotFound);
        assert!(err_text(client.embed_text("test", "unknown")).contains("not found"));
    }

    #[test]
    fn test_mock_embed_text_network_error() {
        let client = failing_client(MockFailure::NetworkError);
        assert!(err_text(client.embed_text("test", "nomic-embed-text")).contains("connection"));
    }

    #[test]
    fn test_mock_detect_contradiction_yes_case() {
        let verdict = healthy_client()
            .detect_contradiction("The system always works", "The system never works")
            .unwrap();
        assert!(verdict);
    }

    #[test]
    fn test_mock_detect_contradiction_no_case() {
        let verdict = healthy_client()
            .detect_contradiction("Consistent statement A", "Consistent statement B")
            .unwrap();
        // Previously only `is_ok()` was checked; pin the actual verdict.
        assert!(!verdict);
    }

    #[test]
    fn test_mock_detect_contradiction_timeout() {
        let client = failing_client(MockFailure::Timeout);
        assert!(err_text(client.detect_contradiction("A", "B")).contains("timed out"));
    }

    #[test]
    fn test_mock_is_available_network_error() {
        assert!(!failing_client(MockFailure::NetworkError).is_available());
    }

    #[test]
    fn test_mock_with_failure_creates_client_that_fails() {
        let client = failing_client(MockFailure::Timeout);
        assert!(client.generate("any", None).is_err());
    }

    #[test]
    fn test_mock_api_error_variant() {
        let client = failing_client(MockFailure::ApiError("Custom msg".to_string()));
        assert!(err_text(client.generate("test", None)).contains("Custom msg"));
    }
}
1090
/// HTTP-level tests for `OllamaClient`, run against a `wiremock` mock server.
///
/// Every interaction with `OllamaClient` happens inside `spawn_blocking`: the
/// client wraps `reqwest::blocking::Client`, and reqwest documents that its
/// blocking API must not be driven from inside an async runtime.
#[cfg(test)]
#[allow(clippy::too_many_lines, clippy::similar_names)]
mod wiremock_tests {
    use super::OllamaClient;
    use serde_json::json;
    use std::net::TcpListener;
    use wiremock::matchers::{body_partial_json, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Model name every client in this module is constructed with.
    const TEST_MODEL: &str = "test-model";

    /// Mounts `GET /api/tags` answering 200 with `models` as the JSON body.
    ///
    /// `OllamaClient::new_with_url` probes this endpoint and refuses to build a
    /// client unless it succeeds, so any test that needs a client mounts it.
    async fn mount_tags_ok(server: &MockServer, models: serde_json::Value) {
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(200).set_body_json(models))
            .mount(server)
            .await;
    }

    /// Mounts `POST /api/chat` answering 200 with `{"message": {"content": content}}`.
    async fn mount_chat_reply(server: &MockServer, content: &str) {
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "message": {"content": content},
            })))
            .mount(server)
            .await;
    }

    /// Builds an `OllamaClient` for `uri` and runs `f` with it.
    ///
    /// The client's whole lifetime (construction, use and drop) stays on a
    /// blocking worker thread, so the blocking reqwest client never touches the
    /// async runtime. Only `f`'s owned result crosses back.
    ///
    /// # Panics
    /// If client construction fails (e.g. `/api/tags` is not mounted) or `f` panics.
    async fn with_client<T, F>(uri: String, f: F) -> T
    where
        F: FnOnce(&OllamaClient) -> T + Send + 'static,
        T: Send + 'static,
    {
        tokio::task::spawn_blocking(move || {
            let client = OllamaClient::new_with_url(&uri, TEST_MODEL).unwrap();
            f(&client)
        })
        .await
        .unwrap()
    }

    /// Attempts to build an `OllamaClient` for `uri` and returns the error message.
    ///
    /// # Panics
    /// If construction unexpectedly succeeds. The stray client is dropped on the
    /// blocking thread, not in the async test body.
    async fn construction_error(uri: String) -> String {
        tokio::task::spawn_blocking(move || match OllamaClient::new_with_url(&uri, TEST_MODEL) {
            Ok(_) => panic!("client construction should fail for {uri}"),
            Err(e) => e.to_string(),
        })
        .await
        .unwrap()
    }

    /// Runs `detect_contradiction("a", "b")` against a fresh mock server whose
    /// chat endpoint replies with `reply`, and returns the parsed verdict.
    async fn contradiction_verdict(reply: &str) -> bool {
        let server = MockServer::start().await;
        mount_tags_ok(&server, json!({"models": []})).await;
        mount_chat_reply(&server, reply).await;
        with_client(server.uri(), |client| client.detect_contradiction("a", "b"))
            .await
            .unwrap()
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_is_available_returns_false_on_connection_refused() {
        // Grab an ephemeral port and release it right away so nothing listens there.
        // Another process could re-bind it in the gap; the window is tiny.
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            listener.local_addr().unwrap().port()
        };

        // `new_with_url` errors exactly when `is_available()` returns false, so this
        // exercises the client's real probe rather than a hand-rolled copy of it.
        let err = construction_error(format!("http://127.0.0.1:{port}")).await;
        assert!(
            err.contains("not running") || err.contains("not reachable"),
            "is_available should return false when nothing is listening, got: {err}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_is_available_returns_false_on_500_response() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&server)
            .await;

        let err = construction_error(server.uri()).await;
        assert!(
            err.contains("not running") || err.contains("not reachable"),
            "expected unreachable-style error, got: {err}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_is_available_returns_true_on_200_with_json_body() {
        let server = MockServer::start().await;
        mount_tags_ok(&server, json!({"models": []})).await;

        let available = with_client(server.uri(), |client| client.is_available()).await;
        assert!(available);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_pull_if_missing_skips_pull_if_model_already_in_tags() {
        let server = MockServer::start().await;
        // The tags listing already contains the model (as `test-model:latest`),
        // so `/api/pull` must never be hit; wiremock verifies `expect(0)` when
        // `server` is dropped.
        mount_tags_ok(&server, json!({"models": [{"name": "test-model:latest"}]})).await;
        Mock::given(method("POST"))
            .and(path("/api/pull"))
            .respond_with(ResponseTemplate::new(200))
            .expect(0)
            .mount(&server)
            .await;

        let result = with_client(server.uri(), |client| client.ensure_model()).await;
        assert!(
            result.is_ok(),
            "ensure_model should succeed; got {result:?}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_pull_if_missing_initiates_pull_if_not() {
        let server = MockServer::start().await;
        mount_tags_ok(&server, json!({"models": []})).await;
        // Exactly one pull naming the configured model is expected.
        Mock::given(method("POST"))
            .and(path("/api/pull"))
            .and(body_partial_json(json!({"name": TEST_MODEL})))
            .respond_with(ResponseTemplate::new(200).set_body_string(""))
            .expect(1)
            .mount(&server)
            .await;

        let result = with_client(server.uri(), |client| client.ensure_model()).await;
        assert!(
            result.is_ok(),
            "ensure_model should succeed; got {result:?}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_generate_parses_success_response() {
        let server = MockServer::start().await;
        mount_tags_ok(&server, json!({"models": []})).await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "message": {"role": "assistant", "content": "hello"},
                "done": true,
            })))
            .mount(&server)
            .await;

        let result = with_client(server.uri(), |client| client.generate("ping", None)).await;
        assert_eq!(result.unwrap(), "hello");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_generate_returns_error_on_malformed_json() {
        let server = MockServer::start().await;
        mount_tags_ok(&server, json!({"models": []})).await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string("{not valid json")
                    .insert_header("content-type", "application/json"),
            )
            .mount(&server)
            .await;

        let result = with_client(server.uri(), |client| client.generate("ping", None)).await;
        assert!(result.is_err(), "malformed JSON should surface an error");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("parse") || err.to_lowercase().contains("json"),
            "expected a parse error, got: {err}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_generate_returns_error_on_500() {
        let server = MockServer::start().await;
        mount_tags_ok(&server, json!({"models": []})).await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(500).set_body_string("internal boom"))
            .mount(&server)
            .await;

        let result = with_client(server.uri(), |client| client.generate("ping", None)).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("500") || err.contains("Chat generate failed"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_generate_passes_system_prompt_when_provided() {
        let server = MockServer::start().await;
        mount_tags_ok(&server, json!({"models": []})).await;
        // Only a request carrying the system message before the user message
        // (and `stream: false`) matches; anything else gets wiremock's 404.
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .and(body_partial_json(json!({
                "messages": [
                    {"role": "system", "content": "be terse"},
                    {"role": "user", "content": "hi"},
                ],
                "stream": false,
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "message": {"role": "assistant", "content": "ok"},
            })))
            .mount(&server)
            .await;

        let out = with_client(server.uri(), |client| client.generate("hi", Some("be terse"))).await;
        assert_eq!(out.unwrap(), "ok");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_embed_parses_embedding_array() {
        let server = MockServer::start().await;
        mount_tags_ok(&server, json!({"models": []})).await;
        Mock::given(method("POST"))
            .and(path("/api/embed"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "embeddings": [[0.1_f32, 0.2_f32, 0.3_f32]],
            })))
            .mount(&server)
            .await;

        let embedding = with_client(server.uri(), |client| {
            client.embed_text("hello", "nomic-embed-text-v1.5")
        })
        .await
        .unwrap();

        assert_eq!(embedding.len(), 3);
        assert!((embedding[0] - 0.1_f32).abs() < 1e-5);
        assert!((embedding[1] - 0.2_f32).abs() < 1e-5);
        assert!((embedding[2] - 0.3_f32).abs() < 1e-5);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_embed_returns_error_on_wrong_shape() {
        let server = MockServer::start().await;
        mount_tags_ok(&server, json!({"models": []})).await;
        // A scalar under the wrong key instead of an `embeddings` array.
        Mock::given(method("POST"))
            .and(path("/api/embed"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "embedding": 0.5,
            })))
            .mount(&server)
            .await;

        let result = with_client(server.uri(), |client| {
            client.embed_text("hi", "nomic-embed-text")
        })
        .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        // NOTE(review): the `contains("embed")` fallback is loose and would also
        // accept unrelated embed errors; tighten once the exact message is fixed.
        assert!(
            err.contains("Missing embeddings") || err.to_lowercase().contains("embed"),
            "expected missing-embeddings error, got: {err}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_embed_returns_error_on_500() {
        let server = MockServer::start().await;
        mount_tags_ok(&server, json!({"models": []})).await;
        Mock::given(method("POST"))
            .and(path("/api/embed"))
            .respond_with(ResponseTemplate::new(500).set_body_string("nope"))
            .mount(&server)
            .await;

        let result = with_client(server.uri(), |client| {
            client.embed_text("hi", "nomic-embed-text")
        })
        .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("500"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_expand_query_returns_parsed_terms_one_per_line() {
        let server = MockServer::start().await;
        mount_tags_ok(&server, json!({"models": []})).await;
        // Trailing blank lines must not turn into empty terms.
        mount_chat_reply(&server, "term1\nterm2\nterm3\n\n").await;

        let terms = with_client(server.uri(), |client| client.expand_query("anything")).await;
        assert_eq!(
            terms.unwrap(),
            vec![
                "term1".to_string(),
                "term2".to_string(),
                "term3".to_string()
            ]
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_auto_tag_returns_parsed_tags() {
        let server = MockServer::start().await;
        mount_tags_ok(&server, json!({"models": []})).await;
        // Mixed-case model output must come back lowercased.
        mount_chat_reply(&server, "Tag1\nTAG2\ntag3").await;

        let tags = with_client(server.uri(), |client| client.auto_tag("Title", "content")).await;
        assert_eq!(
            tags.unwrap(),
            vec!["tag1".to_string(), "tag2".to_string(), "tag3".to_string()]
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_detect_contradiction_parses_yes_no() {
        // Each verdict gets its own server so the chat replies cannot interfere.
        assert!(
            contradiction_verdict("yes\n").await,
            "'yes' should be detected as contradiction"
        );
        assert!(
            !contradiction_verdict("no").await,
            "'no' should NOT be detected as contradiction"
        );
        assert!(
            !contradiction_verdict("definitely-not-yes-or-no").await,
            "garbage answer should default to non-contradiction"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_ensure_embed_model_skips_pull_if_present() {
        let server = MockServer::start().await;
        // The embed model is already listed, so `/api/pull` must not be called
        // (`expect(0)` is verified when `server` is dropped).
        mount_tags_ok(&server, json!({"models": [{"name": "nomic-embed-text:latest"}]})).await;
        Mock::given(method("POST"))
            .and(path("/api/pull"))
            .respond_with(ResponseTemplate::new(200))
            .expect(0)
            .mount(&server)
            .await;

        let result = with_client(server.uri(), |client| {
            client.ensure_embed_model("nomic-embed-text")
        })
        .await;
        assert!(
            result.is_ok(),
            "ensure_embed_model should succeed; got {result:?}"
        );
    }
}