//! kardo_core/embeddings/mod.rs — embedding generation via a local Ollama server.
pub mod storage;
7
8use serde::{Deserialize, Serialize};
9
/// Errors that can occur while requesting embeddings from Ollama.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    /// The underlying HTTP request to the Ollama server failed.
    #[error("HTTP request failed: {0}")]
    Http(#[from] reqwest::Error),
    /// Ollama responded with a non-success status (model likely not pulled).
    #[error("Model unavailable: {0}")]
    ModelUnavailable(String),
    /// The response body could not be decoded into the expected shape.
    #[error("Parse error: {0}")]
    Parse(String),
}
20
/// Lets `EmbeddingError` be propagated with `?` in functions that report
/// errors as plain strings, using the error's `Display` text.
impl From<EmbeddingError> for String {
    fn from(e: EmbeddingError) -> Self {
        e.to_string()
    }
}
26
/// Base URL of the local Ollama HTTP API.
const OLLAMA_BASE_URL: &str = "http://localhost:11434";
/// Default model used for embedding requests.
const EMBEDDING_MODEL: &str = "nomic-embed-text";
/// Dimensionality of the vectors the embedding model produces
/// (768 for nomic-embed-text).
pub const EMBEDDING_DIM: usize = 768;
31
/// Request body for Ollama's `/api/embed` endpoint.
#[derive(Serialize)]
struct OllamaEmbedRequest {
    model: String,
    // Either a single JSON string or an array of strings — `/api/embed`
    // accepts both, which is why this is a raw `serde_json::Value`.
    input: serde_json::Value,
}
38
/// Response body from `/api/embed`: one vector per input text.
#[derive(Deserialize)]
struct OllamaEmbedResponse {
    embeddings: Vec<Vec<f32>>,
}
44
/// Response body from `/api/tags`, listing locally installed models.
#[derive(Deserialize)]
struct OllamaTagsResponse {
    models: Vec<OllamaModelInfo>,
}
50
/// A single entry from `/api/tags`, e.g. `"nomic-embed-text:latest"`.
#[derive(Deserialize)]
struct OllamaModelInfo {
    name: String,
}
55
/// HTTP client for generating text embeddings via a local Ollama server.
pub struct EmbeddingClient {
    // Server root, e.g. "http://localhost:11434"; overridable in tests.
    base_url: String,
    // Name of the embedding model to request.
    model: String,
    client: reqwest::Client,
}
62
63impl EmbeddingClient {
64 pub fn new() -> Self {
66 Self {
67 base_url: OLLAMA_BASE_URL.to_string(),
68 model: EMBEDDING_MODEL.to_string(),
69 client: reqwest::Client::builder()
70 .timeout(std::time::Duration::from_secs(60))
71 .build()
72 .unwrap_or_default(),
73 }
74 }
75
76 #[cfg(test)]
78 pub fn with_base_url(mut self, url: &str) -> Self {
79 self.base_url = url.to_string();
80 self
81 }
82
83 pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
85 let url = format!("{}/api/embed", self.base_url);
86
87 let request = OllamaEmbedRequest {
88 model: self.model.clone(),
89 input: serde_json::Value::String(text.to_string()),
90 };
91
92 let resp = self
93 .client
94 .post(&url)
95 .json(&request)
96 .send()
97 .await?;
98
99 if !resp.status().is_success() {
100 return Err(EmbeddingError::ModelUnavailable(format!(
101 "Ollama returned status {}",
102 resp.status()
103 )));
104 }
105
106 let embed_resp: OllamaEmbedResponse = resp
107 .json()
108 .await
109 .map_err(|e| EmbeddingError::Parse(format!("Failed to parse embedding response: {}", e)))?;
110
111 embed_resp
112 .embeddings
113 .into_iter()
114 .next()
115 .ok_or_else(|| EmbeddingError::Parse("No embedding returned".to_string()))
116 }
117
118 pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
120 if texts.is_empty() {
121 return Ok(vec![]);
122 }
123
124 let url = format!("{}/api/embed", self.base_url);
125
126 let input_array: Vec<serde_json::Value> = texts
127 .iter()
128 .map(|t| serde_json::Value::String(t.to_string()))
129 .collect();
130
131 let request = OllamaEmbedRequest {
132 model: self.model.clone(),
133 input: serde_json::Value::Array(input_array),
134 };
135
136 let resp = self
137 .client
138 .post(&url)
139 .json(&request)
140 .send()
141 .await?;
142
143 if !resp.status().is_success() {
144 return Err(EmbeddingError::ModelUnavailable(format!(
145 "Ollama returned status {}",
146 resp.status()
147 )));
148 }
149
150 let embed_resp: OllamaEmbedResponse = resp
151 .json()
152 .await
153 .map_err(|e| EmbeddingError::Parse(format!(
154 "Failed to parse batch embedding response: {}",
155 e
156 )))?;
157
158 Ok(embed_resp.embeddings)
159 }
160
161 pub async fn check_model_available(&self) -> bool {
163 let url = format!("{}/api/tags", self.base_url);
164
165 match self.client.get(&url).send().await {
166 Ok(resp) => {
167 if let Ok(tags) = resp.json::<OllamaTagsResponse>().await {
168 tags.models
169 .iter()
170 .any(|m| m.name.starts_with(&self.model) || m.name == self.model)
171 } else {
172 false
173 }
174 }
175 Err(_) => false,
176 }
177 }
178}
179
/// `Default` simply delegates to [`EmbeddingClient::new`].
impl Default for EmbeddingClient {
    fn default() -> Self {
        Self::new()
    }
}
185
/// Splits markdown `content` into chunks suitable for embedding.
///
/// A new chunk begins at every level-2/level-3 heading (`## ` / `### `),
/// and a chunk is also cut once it would grow past `MAX_CHUNK_CHARS` bytes.
/// A single line that is itself oversized is packed word by word into
/// multiple chunks. Fragments shorter than 20 bytes are discarded.
pub fn chunk_markdown(content: &str) -> Vec<String> {
    const MAX_CHUNK_CHARS: usize = 2048;

    // Pushes the trimmed buffer onto `out` (unless blank) and clears it.
    fn flush(buf: &mut String, out: &mut Vec<String>) {
        let piece = buf.trim();
        if !piece.is_empty() {
            out.push(piece.to_owned());
        }
        buf.clear();
    }

    let mut out: Vec<String> = Vec::new();
    let mut buf = String::new();

    for line in content.lines() {
        // Section boundary: start a fresh chunk at each ##/### heading.
        if (line.starts_with("## ") || line.starts_with("### ")) && !buf.trim().is_empty() {
            flush(&mut buf, &mut out);
        }

        // Size cap: cut the running chunk before it outgrows the limit.
        if !buf.is_empty() && buf.len() + line.len() + 1 >= MAX_CHUNK_CHARS {
            flush(&mut buf, &mut out);
        }

        if line.len() < MAX_CHUNK_CHARS {
            buf.push_str(line);
            buf.push('\n');
        } else {
            // The line alone exceeds the cap: fall back to word-level packing.
            for word in line.split_whitespace() {
                if !buf.is_empty() && buf.len() + word.len() + 1 >= MAX_CHUNK_CHARS {
                    flush(&mut buf, &mut out);
                }
                if !buf.is_empty() {
                    buf.push(' ');
                }
                buf.push_str(word);
            }
            buf.push('\n');
        }
    }

    flush(&mut buf, &mut out);

    // Drop fragments too short to carry meaning.
    out.retain(|c| c.len() >= 20);
    out
}
244
#[cfg(test)]
mod tests {
    use super::*;

    // Headings should open new chunks: intro + each ##/### section.
    #[test]
    fn test_chunk_markdown_by_headings() {
        let content = r#"# Main Title

Some intro text that is long enough to keep.

## Section One

Content for section one with enough text to be meaningful.

## Section Two

Content for section two with enough text to be meaningful.

### Subsection 2.1

Detailed content for subsection two point one.
"#;

        let chunks = chunk_markdown(content);
        assert!(chunks.len() >= 3, "Expected at least 3 chunks, got {}", chunks.len());
        assert!(chunks[0].contains("Main Title"));
        assert!(chunks[1].contains("Section One"));
    }

    // Empty input yields no chunks at all.
    #[test]
    fn test_chunk_markdown_empty() {
        let chunks = chunk_markdown("");
        assert!(chunks.is_empty());
    }

    // A heading-free document collapses into a single chunk.
    #[test]
    fn test_chunk_markdown_no_headings() {
        let content = "This is a simple paragraph with enough content to be considered a valid chunk by the chunker.";
        let chunks = chunk_markdown(content);
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].contains("simple paragraph"));
    }

    // A section larger than the chunk cap must be split into several chunks.
    #[test]
    fn test_chunk_markdown_long_section() {
        let long_line = "a ".repeat(1100);
        let content = format!("## Long Section\n\n{}", long_line);
        let chunks = chunk_markdown(&content);
        assert!(chunks.len() >= 2, "Long content should be split into multiple chunks");
    }

    // Fragments under the 20-byte minimum are dropped by the final filter.
    #[test]
    fn test_chunk_markdown_filters_short() {
        let content = "## A\n\nok\n\n## B\n\nThis section has enough content to pass the minimum length filter easily.\n";
        let chunks = chunk_markdown(content);
        for chunk in &chunks {
            assert!(chunk.len() >= 20, "Short chunks should be filtered: '{}'", chunk);
        }
    }

    // The default constructor wires in the documented base URL and model.
    #[test]
    fn test_embedding_client_creation() {
        let client = EmbeddingClient::new();
        assert_eq!(client.base_url, "http://localhost:11434");
        assert_eq!(client.model, "nomic-embed-text");
    }

    // Single-text requests serialize `input` as a plain JSON string.
    #[test]
    fn test_embed_request_format_single() {
        let req = OllamaEmbedRequest {
            model: "nomic-embed-text".to_string(),
            input: serde_json::Value::String("hello world".to_string()),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["model"], "nomic-embed-text");
        assert_eq!(json["input"], "hello world");
    }

    // Batch requests serialize `input` as a JSON array of strings.
    #[test]
    fn test_embed_request_format_batch() {
        let texts = ["hello", "world"];
        let input_array: Vec<serde_json::Value> = texts
            .iter()
            .map(|t| serde_json::Value::String(t.to_string()))
            .collect();
        let req = OllamaEmbedRequest {
            model: "nomic-embed-text".to_string(),
            input: serde_json::Value::Array(input_array),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["model"], "nomic-embed-text");
        assert!(json["input"].is_array());
        assert_eq!(json["input"].as_array().unwrap().len(), 2);
    }

    // `/api/embed` responses decode into one Vec<f32> per input.
    #[test]
    fn test_embed_response_parsing() {
        let json_str = r#"{"embeddings":[[0.1, 0.2, 0.3],[0.4, 0.5, 0.6]]}"#;
        let resp: OllamaEmbedResponse = serde_json::from_str(json_str).unwrap();
        assert_eq!(resp.embeddings.len(), 2);
        assert_eq!(resp.embeddings[0], vec![0.1, 0.2, 0.3]);
        assert_eq!(resp.embeddings[1], vec![0.4, 0.5, 0.6]);
    }

    // `/api/tags` responses decode into a list of model names.
    #[test]
    fn test_tags_response_parsing() {
        let json_str = r#"{"models":[{"name":"nomic-embed-text:latest"},{"name":"qwen3:0.6b"}]}"#;
        let resp: OllamaTagsResponse = serde_json::from_str(json_str).unwrap();
        assert_eq!(resp.models.len(), 2);
        assert!(resp.models.iter().any(|m| m.name.starts_with("nomic-embed-text")));
    }
}