leann_core/embedding/
ollama.rs1use anyhow::{Context, Result};
2use ndarray::Array2;
3use serde::{Deserialize, Serialize};
4use std::sync::Arc;
5use tracing::{info, warn};
6
7use super::EmbeddingProvider;
8use crate::settings::resolve_ollama_host;
9
/// Upper bound on the number of embedding batches processed at the same time.
///
/// Used as the permit count of the semaphore that gates the spawned batch tasks.
const MAX_CONCURRENT: usize = 3;
13
14#[derive(Clone)]
16struct OllamaHandle {
17 model: String,
18 host: String,
19 client: reqwest::Client, }
21
22pub struct OllamaEmbedding {
24 handle: OllamaHandle,
25 dimensions: usize,
26}
27
28#[derive(Serialize)]
29struct OllamaEmbeddingRequest {
30 model: String,
31 input: Vec<String>,
32}
33
34#[derive(Deserialize)]
35struct OllamaEmbeddingResponse {
36 embeddings: Vec<Vec<f32>>,
37}
38
39impl OllamaHandle {
40 async fn embed_batch(&self, batch: &[String]) -> Result<Vec<Vec<f32>>, OllamaBatchError> {
42 let request = OllamaEmbeddingRequest {
43 model: self.model.clone(),
44 input: batch.to_vec(),
45 };
46
47 let response = self
48 .client
49 .post(format!("{}/api/embed", self.host))
50 .json(&request)
51 .send()
52 .await
53 .map_err(|e| {
54 OllamaBatchError::Other(
55 anyhow::anyhow!(e).context("sending embedding request to Ollama"),
56 )
57 })?;
58
59 let status = response.status();
60 if !status.is_success() {
61 let body = response.text().await.unwrap_or_default();
62 if status.as_u16() == 400 && body.contains("context length") {
63 return Err(OllamaBatchError::ContextLength);
64 }
65 return Err(OllamaBatchError::Other(anyhow::anyhow!(
66 "Ollama API error ({}): {}",
67 status,
68 body
69 )));
70 }
71
72 let resp: OllamaEmbeddingResponse = response.json().await.map_err(|e| {
73 OllamaBatchError::Other(anyhow::anyhow!(e).context("parsing Ollama embedding response"))
74 })?;
75
76 Ok(resp.embeddings)
77 }
78
79 #[allow(clippy::type_complexity)]
81 fn embed_with_backoff<'a>(
82 &'a self,
83 chunks: &'a [String],
84 batch_size: usize,
85 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<Vec<f32>>>> + Send + 'a>>
86 {
87 Box::pin(async move {
88 match self.embed_batch(chunks).await {
89 Ok(embeddings) => Ok(embeddings),
90 Err(OllamaBatchError::ContextLength) => {
91 if chunks.len() == 1 {
92 return self.truncate_single_chunk(&chunks[0]).await;
93 }
94 let smaller = batch_size / 2;
95 warn!(
96 "Batch of {} chunks exceeded context length, retrying with batch size {}",
97 chunks.len(),
98 smaller
99 );
100 let mut results = Vec::with_capacity(chunks.len());
101 for sub_batch in chunks.chunks(smaller.max(1)) {
102 results.extend(self.embed_with_backoff(sub_batch, smaller).await?);
103 }
104 Ok(results)
105 }
106 Err(OllamaBatchError::Other(e)) => Err(e),
107 }
108 })
109 }
110
111 async fn truncate_single_chunk(&self, chunk: &str) -> Result<Vec<Vec<f32>>> {
113 let original_len = chunk.len();
114 let mut lo = 0usize;
115 let mut hi = original_len;
116 let mut last_good = None;
117
118 while lo < hi {
119 let mid = (lo + hi) / 2;
120 let truncated = vec![chunk[..mid].to_string()];
121 match self.embed_batch(&truncated).await {
122 Ok(emb) => {
123 last_good = Some(emb);
124 lo = mid + 1;
125 }
126 Err(OllamaBatchError::ContextLength) => {
127 hi = mid;
128 }
129 Err(OllamaBatchError::Other(e)) => return Err(e),
130 }
131 }
132
133 if let Some(embeddings) = last_good {
134 warn!(
135 "Truncated oversized chunk from {} to ~{} chars to fit Ollama context",
136 original_len,
137 lo.saturating_sub(1)
138 );
139 return Ok(embeddings);
140 }
141
142 anyhow::bail!(
143 "Single chunk exceeds Ollama context length ({} chars) \
144 and could not be truncated to fit.",
145 original_len
146 );
147 }
148}
149
150impl OllamaEmbedding {
151 pub fn new(model: &str, host: Option<&str>) -> Self {
152 Self {
153 handle: OllamaHandle {
154 model: model.to_string(),
155 host: resolve_ollama_host(host),
156 client: reqwest::Client::new(),
157 },
158 dimensions: 768, }
160 }
161
162 async fn compute_embeddings_async(&self, chunks: &[String]) -> Result<Array2<f32>> {
165 if chunks.is_empty() {
166 return Ok(Array2::zeros((0, self.dimensions)));
167 }
168
169 let batch_size: usize = 128;
170 let num_batches = chunks.len().div_ceil(batch_size);
171
172 info!(
173 "Ollama embedding: {} chunks in {} batches (concurrency={})",
174 chunks.len(),
175 num_batches,
176 MAX_CONCURRENT
177 );
178
179 let semaphore = Arc::new(tokio::sync::Semaphore::new(MAX_CONCURRENT));
180 let mut handles = Vec::with_capacity(num_batches);
181
182 for (i, batch) in chunks.chunks(batch_size).enumerate() {
183 let sem = semaphore.clone();
184 let handle = self.handle.clone();
185 let batch_owned: Vec<String> = batch.to_vec();
186 let batch_idx = i;
187 let total = num_batches;
188
189 handles.push(tokio::spawn(async move {
190 let _permit = sem.acquire().await.expect("semaphore closed");
191 info!(
192 "Ollama embedding batch {}/{} ({} chunks)",
193 batch_idx + 1,
194 total,
195 batch_owned.len()
196 );
197 let result = handle.embed_with_backoff(&batch_owned, batch_size).await;
198 (batch_idx, result)
199 }));
200 }
201
202 let mut indexed_results: Vec<(usize, Vec<Vec<f32>>)> = Vec::with_capacity(num_batches);
204 for h in handles {
205 let (idx, result) = h.await.context("embedding task panicked")?;
206 indexed_results.push((idx, result?));
207 }
208 indexed_results.sort_by_key(|(idx, _)| *idx);
209
210 let all_embeddings: Vec<Vec<f32>> = indexed_results
211 .into_iter()
212 .flat_map(|(_, embs)| embs)
213 .collect();
214
215 if all_embeddings.is_empty() {
216 return Ok(Array2::zeros((0, self.dimensions)));
217 }
218
219 let n = all_embeddings.len();
220 let d = all_embeddings[0].len();
221 let flat: Vec<f32> = all_embeddings.into_iter().flatten().collect();
222
223 Array2::from_shape_vec((n, d), flat).context("reshaping Ollama embeddings")
224 }
225}
226
227enum OllamaBatchError {
228 ContextLength,
229 Other(anyhow::Error),
230}
231
232impl EmbeddingProvider for OllamaEmbedding {
233 fn compute_embeddings(
234 &self,
235 chunks: &[String],
236 _progress: Option<&dyn crate::hnsw::IndexProgress>,
237 ) -> Result<Array2<f32>> {
238 match tokio::runtime::Handle::try_current() {
241 Ok(handle) => {
242 std::thread::scope(|s| {
245 s.spawn(|| handle.block_on(self.compute_embeddings_async(chunks)))
246 .join()
247 .expect("embedding thread panicked")
248 })
249 }
250 Err(_) => {
251 let rt = tokio::runtime::Runtime::new()?;
252 rt.block_on(self.compute_embeddings_async(chunks))
253 }
254 }
255 }
256
257 fn dimensions(&self) -> usize {
258 self.dimensions
259 }
260
261 fn name(&self) -> &str {
262 "ollama"
263 }
264}