Skip to main content

leann_core/embedding/
ollama.rs

1use anyhow::{Context, Result};
2use ndarray::Array2;
3use serde::{Deserialize, Serialize};
4use std::sync::Arc;
5use tracing::{info, warn};
6
7use super::EmbeddingProvider;
8use crate::settings::resolve_ollama_host;
9
/// Upper bound on embedding requests in flight to Ollama at the same time:
/// enough to keep the server busy between requests without flooding it.
const MAX_CONCURRENT: usize = 3;
13
14/// Shared state cloned into each spawned batch task.
15#[derive(Clone)]
16struct OllamaHandle {
17    model: String,
18    host: String,
19    client: reqwest::Client, // internally Arc'd, cheap to clone
20}
21
22/// Ollama embedding API client with pipelined batch dispatch.
23pub struct OllamaEmbedding {
24    handle: OllamaHandle,
25    dimensions: usize,
26}
27
28#[derive(Serialize)]
29struct OllamaEmbeddingRequest {
30    model: String,
31    input: Vec<String>,
32}
33
34#[derive(Deserialize)]
35struct OllamaEmbeddingResponse {
36    embeddings: Vec<Vec<f32>>,
37}
38
39impl OllamaHandle {
40    /// Send a single batch to the Ollama /api/embed endpoint.
41    async fn embed_batch(&self, batch: &[String]) -> Result<Vec<Vec<f32>>, OllamaBatchError> {
42        let request = OllamaEmbeddingRequest {
43            model: self.model.clone(),
44            input: batch.to_vec(),
45        };
46
47        let response = self
48            .client
49            .post(format!("{}/api/embed", self.host))
50            .json(&request)
51            .send()
52            .await
53            .map_err(|e| {
54                OllamaBatchError::Other(
55                    anyhow::anyhow!(e).context("sending embedding request to Ollama"),
56                )
57            })?;
58
59        let status = response.status();
60        if !status.is_success() {
61            let body = response.text().await.unwrap_or_default();
62            if status.as_u16() == 400 && body.contains("context length") {
63                return Err(OllamaBatchError::ContextLength);
64            }
65            return Err(OllamaBatchError::Other(anyhow::anyhow!(
66                "Ollama API error ({}): {}",
67                status,
68                body
69            )));
70        }
71
72        let resp: OllamaEmbeddingResponse = response.json().await.map_err(|e| {
73            OllamaBatchError::Other(anyhow::anyhow!(e).context("parsing Ollama embedding response"))
74        })?;
75
76        Ok(resp.embeddings)
77    }
78
79    /// Embed a slice of chunks, halving the sub-batch size on context-length errors.
80    #[allow(clippy::type_complexity)]
81    fn embed_with_backoff<'a>(
82        &'a self,
83        chunks: &'a [String],
84        batch_size: usize,
85    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<Vec<f32>>>> + Send + 'a>>
86    {
87        Box::pin(async move {
88            match self.embed_batch(chunks).await {
89                Ok(embeddings) => Ok(embeddings),
90                Err(OllamaBatchError::ContextLength) => {
91                    if chunks.len() == 1 {
92                        return self.truncate_single_chunk(&chunks[0]).await;
93                    }
94                    let smaller = batch_size / 2;
95                    warn!(
96                        "Batch of {} chunks exceeded context length, retrying with batch size {}",
97                        chunks.len(),
98                        smaller
99                    );
100                    let mut results = Vec::with_capacity(chunks.len());
101                    for sub_batch in chunks.chunks(smaller.max(1)) {
102                        results.extend(self.embed_with_backoff(sub_batch, smaller).await?);
103                    }
104                    Ok(results)
105                }
106                Err(OllamaBatchError::Other(e)) => Err(e),
107            }
108        })
109    }
110
111    /// Binary-search for the largest prefix of an oversized chunk that fits.
112    async fn truncate_single_chunk(&self, chunk: &str) -> Result<Vec<Vec<f32>>> {
113        let original_len = chunk.len();
114        let mut lo = 0usize;
115        let mut hi = original_len;
116        let mut last_good = None;
117
118        while lo < hi {
119            let mid = (lo + hi) / 2;
120            let truncated = vec![chunk[..mid].to_string()];
121            match self.embed_batch(&truncated).await {
122                Ok(emb) => {
123                    last_good = Some(emb);
124                    lo = mid + 1;
125                }
126                Err(OllamaBatchError::ContextLength) => {
127                    hi = mid;
128                }
129                Err(OllamaBatchError::Other(e)) => return Err(e),
130            }
131        }
132
133        if let Some(embeddings) = last_good {
134            warn!(
135                "Truncated oversized chunk from {} to ~{} chars to fit Ollama context",
136                original_len,
137                lo.saturating_sub(1)
138            );
139            return Ok(embeddings);
140        }
141
142        anyhow::bail!(
143            "Single chunk exceeds Ollama context length ({} chars) \
144             and could not be truncated to fit.",
145            original_len
146        );
147    }
148}
149
150impl OllamaEmbedding {
151    pub fn new(model: &str, host: Option<&str>) -> Self {
152        Self {
153            handle: OllamaHandle {
154                model: model.to_string(),
155                host: resolve_ollama_host(host),
156                client: reqwest::Client::new(),
157            },
158            dimensions: 768, // Default, will be updated on first compute
159        }
160    }
161
162    /// Dispatch all batches concurrently (up to MAX_CONCURRENT in-flight) so the
163    /// GPU never idles waiting for the next request.
164    async fn compute_embeddings_async(&self, chunks: &[String]) -> Result<Array2<f32>> {
165        if chunks.is_empty() {
166            return Ok(Array2::zeros((0, self.dimensions)));
167        }
168
169        let batch_size: usize = 128;
170        let num_batches = chunks.len().div_ceil(batch_size);
171
172        info!(
173            "Ollama embedding: {} chunks in {} batches (concurrency={})",
174            chunks.len(),
175            num_batches,
176            MAX_CONCURRENT
177        );
178
179        let semaphore = Arc::new(tokio::sync::Semaphore::new(MAX_CONCURRENT));
180        let mut handles = Vec::with_capacity(num_batches);
181
182        for (i, batch) in chunks.chunks(batch_size).enumerate() {
183            let sem = semaphore.clone();
184            let handle = self.handle.clone();
185            let batch_owned: Vec<String> = batch.to_vec();
186            let batch_idx = i;
187            let total = num_batches;
188
189            handles.push(tokio::spawn(async move {
190                let _permit = sem.acquire().await.expect("semaphore closed");
191                info!(
192                    "Ollama embedding batch {}/{} ({} chunks)",
193                    batch_idx + 1,
194                    total,
195                    batch_owned.len()
196                );
197                let result = handle.embed_with_backoff(&batch_owned, batch_size).await;
198                (batch_idx, result)
199            }));
200        }
201
202        // Collect results in order.
203        let mut indexed_results: Vec<(usize, Vec<Vec<f32>>)> = Vec::with_capacity(num_batches);
204        for h in handles {
205            let (idx, result) = h.await.context("embedding task panicked")?;
206            indexed_results.push((idx, result?));
207        }
208        indexed_results.sort_by_key(|(idx, _)| *idx);
209
210        let all_embeddings: Vec<Vec<f32>> = indexed_results
211            .into_iter()
212            .flat_map(|(_, embs)| embs)
213            .collect();
214
215        if all_embeddings.is_empty() {
216            return Ok(Array2::zeros((0, self.dimensions)));
217        }
218
219        let n = all_embeddings.len();
220        let d = all_embeddings[0].len();
221        let flat: Vec<f32> = all_embeddings.into_iter().flatten().collect();
222
223        Array2::from_shape_vec((n, d), flat).context("reshaping Ollama embeddings")
224    }
225}
226
227enum OllamaBatchError {
228    ContextLength,
229    Other(anyhow::Error),
230}
231
232impl EmbeddingProvider for OllamaEmbedding {
233    fn compute_embeddings(
234        &self,
235        chunks: &[String],
236        _progress: Option<&dyn crate::hnsw::IndexProgress>,
237    ) -> Result<Array2<f32>> {
238        // Enter async context - reuse existing runtime if available,
239        // otherwise create one.
240        match tokio::runtime::Handle::try_current() {
241            Ok(handle) => {
242                // We're inside a tokio runtime already but in a sync call.
243                // Spawn a thread to avoid blocking the runtime with block_on.
244                std::thread::scope(|s| {
245                    s.spawn(|| handle.block_on(self.compute_embeddings_async(chunks)))
246                        .join()
247                        .expect("embedding thread panicked")
248                })
249            }
250            Err(_) => {
251                let rt = tokio::runtime::Runtime::new()?;
252                rt.block_on(self.compute_embeddings_async(chunks))
253            }
254        }
255    }
256
257    fn dimensions(&self) -> usize {
258        self.dimensions
259    }
260
261    fn name(&self) -> &str {
262        "ollama"
263    }
264}