Skip to main content

kardo_core/embeddings/
mod.rs

1//! Embedding generation and semantic search via Ollama.
2//!
//! Uses the `nomic-embed-text` model (768-dimensional embeddings) through
4//! the Ollama `/api/embed` HTTP endpoint.
5
6pub mod storage;
7
8use serde::{Deserialize, Serialize};
9
/// Typed errors for embedding operations.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    /// Transport-level failure from `reqwest` (connect error, timeout, etc.).
    #[error("HTTP request failed: {0}")]
    Http(#[from] reqwest::Error),
    /// Ollama answered with a non-success HTTP status.
    #[error("Model unavailable: {0}")]
    ModelUnavailable(String),
    /// The response body could not be deserialized, or contained no embedding.
    #[error("Parse error: {0}")]
    Parse(String),
}
20
21impl From<EmbeddingError> for String {
22    fn from(e: EmbeddingError) -> Self {
23        e.to_string()
24    }
25}
26
// Root URL of the local Ollama HTTP API.
const OLLAMA_BASE_URL: &str = "http://localhost:11434";
// Model name sent with every embedding request.
const EMBEDDING_MODEL: &str = "nomic-embed-text";
/// Default embedding dimension for nomic-embed-text.
pub const EMBEDDING_DIM: usize = 768;
31
/// Request body for Ollama /api/embed endpoint.
#[derive(Serialize)]
struct OllamaEmbedRequest {
    model: String,
    // Either `Value::String` (single text) or `Value::Array` of strings (batch) —
    // the two input shapes the endpoint accepts.
    input: serde_json::Value,
}
38
/// Response from Ollama /api/embed endpoint.
#[derive(Deserialize)]
struct OllamaEmbedResponse {
    // One embedding vector per input text.
    embeddings: Vec<Vec<f32>>,
}
44
/// Tags response for checking available models.
#[derive(Deserialize)]
struct OllamaTagsResponse {
    // Models currently pulled into the local Ollama instance.
    models: Vec<OllamaModelInfo>,
}
50
/// A single entry from the /api/tags model list.
#[derive(Deserialize)]
struct OllamaModelInfo {
    // Full model name, typically including a tag (e.g. "nomic-embed-text:latest").
    name: String,
}
55
/// Client for generating embeddings via Ollama.
pub struct EmbeddingClient {
    // Ollama server root, e.g. "http://localhost:11434" (no trailing slash).
    base_url: String,
    // Embedding model name sent with every request.
    model: String,
    // Shared reqwest client; `new` configures it with a 60s request timeout.
    client: reqwest::Client,
}
62
63impl EmbeddingClient {
64    /// Create a new embedding client with default settings.
65    pub fn new() -> Self {
66        Self {
67            base_url: OLLAMA_BASE_URL.to_string(),
68            model: EMBEDDING_MODEL.to_string(),
69            client: reqwest::Client::builder()
70                .timeout(std::time::Duration::from_secs(60))
71                .build()
72                .unwrap_or_default(),
73        }
74    }
75
76    /// Create an embedding client with a custom base URL (for testing).
77    #[cfg(test)]
78    pub fn with_base_url(mut self, url: &str) -> Self {
79        self.base_url = url.to_string();
80        self
81    }
82
83    /// Generate an embedding for a single text.
84    pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
85        let url = format!("{}/api/embed", self.base_url);
86
87        let request = OllamaEmbedRequest {
88            model: self.model.clone(),
89            input: serde_json::Value::String(text.to_string()),
90        };
91
92        let resp = self
93            .client
94            .post(&url)
95            .json(&request)
96            .send()
97            .await?;
98
99        if !resp.status().is_success() {
100            return Err(EmbeddingError::ModelUnavailable(format!(
101                "Ollama returned status {}",
102                resp.status()
103            )));
104        }
105
106        let embed_resp: OllamaEmbedResponse = resp
107            .json()
108            .await
109            .map_err(|e| EmbeddingError::Parse(format!("Failed to parse embedding response: {}", e)))?;
110
111        embed_resp
112            .embeddings
113            .into_iter()
114            .next()
115            .ok_or_else(|| EmbeddingError::Parse("No embedding returned".to_string()))
116    }
117
118    /// Generate embeddings for multiple texts in a single request.
119    pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
120        if texts.is_empty() {
121            return Ok(vec![]);
122        }
123
124        let url = format!("{}/api/embed", self.base_url);
125
126        let input_array: Vec<serde_json::Value> = texts
127            .iter()
128            .map(|t| serde_json::Value::String(t.to_string()))
129            .collect();
130
131        let request = OllamaEmbedRequest {
132            model: self.model.clone(),
133            input: serde_json::Value::Array(input_array),
134        };
135
136        let resp = self
137            .client
138            .post(&url)
139            .json(&request)
140            .send()
141            .await?;
142
143        if !resp.status().is_success() {
144            return Err(EmbeddingError::ModelUnavailable(format!(
145                "Ollama returned status {}",
146                resp.status()
147            )));
148        }
149
150        let embed_resp: OllamaEmbedResponse = resp
151            .json()
152            .await
153            .map_err(|e| EmbeddingError::Parse(format!(
154                "Failed to parse batch embedding response: {}",
155                e
156            )))?;
157
158        Ok(embed_resp.embeddings)
159    }
160
161    /// Check if the embedding model is available in Ollama.
162    pub async fn check_model_available(&self) -> bool {
163        let url = format!("{}/api/tags", self.base_url);
164
165        match self.client.get(&url).send().await {
166            Ok(resp) => {
167                if let Ok(tags) = resp.json::<OllamaTagsResponse>().await {
168                    tags.models
169                        .iter()
170                        .any(|m| m.name.starts_with(&self.model) || m.name == self.model)
171                } else {
172                    false
173                }
174            }
175            Err(_) => false,
176        }
177    }
178}
179
/// `Default` delegates to [`EmbeddingClient::new`].
impl Default for EmbeddingClient {
    fn default() -> Self {
        Self::new()
    }
}
185
/// Split markdown content into chunks at heading boundaries.
///
/// A new chunk starts at each `##` or `###` heading. Chunks are capped at
/// roughly 512 tokens (approximated as ~4 chars per token = 2048 chars);
/// a single oversized line is broken apart on whitespace. Fragments shorter
/// than 20 characters are dropped from the result.
pub fn chunk_markdown(content: &str) -> Vec<String> {
    const MAX_CHUNK_CHARS: usize = 2048; // ~512 tokens at ~4 chars/token
    const MIN_CHUNK_CHARS: usize = 20;

    let mut out: Vec<String> = Vec::new();
    let mut buf = String::new();

    for line in content.lines() {
        // A `##`/`###` heading opens a fresh chunk, unless the buffer holds
        // nothing but whitespace so far.
        if (line.starts_with("## ") || line.starts_with("### ")) && !buf.trim().is_empty() {
            out.push(std::mem::take(&mut buf).trim().to_string());
        }

        // Flush before overflowing the size budget (+1 accounts for the '\n').
        if !buf.is_empty() && buf.len() + line.len() + 1 >= MAX_CHUNK_CHARS {
            out.push(std::mem::take(&mut buf).trim().to_string());
        }

        if line.len() >= MAX_CHUNK_CHARS {
            // A single line longer than the cap: break it apart word by word.
            for word in line.split_whitespace() {
                if !buf.is_empty() && buf.len() + word.len() + 1 >= MAX_CHUNK_CHARS {
                    out.push(std::mem::take(&mut buf).trim().to_string());
                }
                if !buf.is_empty() {
                    buf.push(' ');
                }
                buf.push_str(word);
            }
            buf.push('\n');
        } else {
            buf.push_str(line);
            buf.push('\n');
        }
    }

    // Emit whatever remains in the buffer.
    if !buf.trim().is_empty() {
        out.push(buf.trim().to_string());
    }

    // Discard fragments too small to be useful.
    out.retain(|c| c.len() >= MIN_CHUNK_CHARS);

    out
}
244
#[cfg(test)]
mod tests {
    use super::*;

    // Heading-based chunking: each `##`/`###` section becomes its own chunk.
    #[test]
    fn test_chunk_markdown_by_headings() {
        let content = r#"# Main Title

Some intro text that is long enough to keep.

## Section One

Content for section one with enough text to be meaningful.

## Section Two

Content for section two with enough text to be meaningful.

### Subsection 2.1

Detailed content for subsection two point one.
"#;

        let chunks = chunk_markdown(content);
        assert!(chunks.len() >= 3, "Expected at least 3 chunks, got {}", chunks.len());
        assert!(chunks[0].contains("Main Title"));
        assert!(chunks[1].contains("Section One"));
    }

    // Empty input must produce no chunks at all.
    #[test]
    fn test_chunk_markdown_empty() {
        let chunks = chunk_markdown("");
        assert!(chunks.is_empty());
    }

    // Content without headings stays in a single chunk.
    #[test]
    fn test_chunk_markdown_no_headings() {
        let content = "This is a simple paragraph with enough content to be considered a valid chunk by the chunker.";
        let chunks = chunk_markdown(content);
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].contains("simple paragraph"));
    }

    // A section exceeding MAX_CHUNK_CHARS is split into multiple chunks.
    #[test]
    fn test_chunk_markdown_long_section() {
        // Create content that exceeds MAX_CHUNK_CHARS
        let long_line = "a ".repeat(1100); // ~2200 chars
        let content = format!("## Long Section\n\n{}", long_line);
        let chunks = chunk_markdown(&content);
        assert!(chunks.len() >= 2, "Long content should be split into multiple chunks");
    }

    // Chunks below the 20-char minimum are dropped from the result.
    #[test]
    fn test_chunk_markdown_filters_short() {
        let content = "## A\n\nok\n\n## B\n\nThis section has enough content to pass the minimum length filter easily.\n";
        let chunks = chunk_markdown(content);
        // The "## A\n\nok" chunk is < 20 chars, should be filtered
        // "## B\n\n..." should remain
        for chunk in &chunks {
            assert!(chunk.len() >= 20, "Short chunks should be filtered: '{}'", chunk);
        }
    }

    // Constructor defaults: local Ollama URL and the nomic embedding model.
    #[test]
    fn test_embedding_client_creation() {
        let client = EmbeddingClient::new();
        assert_eq!(client.base_url, "http://localhost:11434");
        assert_eq!(client.model, "nomic-embed-text");
    }

    // Single-text requests serialize `input` as a plain JSON string.
    #[test]
    fn test_embed_request_format_single() {
        // Verify that the request serialization is correct
        let req = OllamaEmbedRequest {
            model: "nomic-embed-text".to_string(),
            input: serde_json::Value::String("hello world".to_string()),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["model"], "nomic-embed-text");
        assert_eq!(json["input"], "hello world");
    }

    // Batch requests serialize `input` as a JSON array of strings.
    #[test]
    fn test_embed_request_format_batch() {
        let texts = ["hello", "world"];
        let input_array: Vec<serde_json::Value> = texts
            .iter()
            .map(|t| serde_json::Value::String(t.to_string()))
            .collect();
        let req = OllamaEmbedRequest {
            model: "nomic-embed-text".to_string(),
            input: serde_json::Value::Array(input_array),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["model"], "nomic-embed-text");
        assert!(json["input"].is_array());
        assert_eq!(json["input"].as_array().unwrap().len(), 2);
    }

    // /api/embed responses deserialize into per-input f32 vectors.
    #[test]
    fn test_embed_response_parsing() {
        let json_str = r#"{"embeddings":[[0.1, 0.2, 0.3],[0.4, 0.5, 0.6]]}"#;
        let resp: OllamaEmbedResponse = serde_json::from_str(json_str).unwrap();
        assert_eq!(resp.embeddings.len(), 2);
        assert_eq!(resp.embeddings[0], vec![0.1, 0.2, 0.3]);
        assert_eq!(resp.embeddings[1], vec![0.4, 0.5, 0.6]);
    }

    // /api/tags responses carry tagged model names (e.g. ":latest" suffixes).
    #[test]
    fn test_tags_response_parsing() {
        let json_str = r#"{"models":[{"name":"nomic-embed-text:latest"},{"name":"qwen3:0.6b"}]}"#;
        let resp: OllamaTagsResponse = serde_json::from_str(json_str).unwrap();
        assert_eq!(resp.models.len(), 2);
        assert!(resp.models.iter().any(|m| m.name.starts_with("nomic-embed-text")));
    }
}