memvid_cli/
nvidia_embeddings.rs

1//! NVIDIA Embeddings Provider
2//!
3//! Implements high-quality embeddings via NVIDIA Integrate API (`/v1/embeddings`).
4//!
5//! ## Environment Variables
6//! - `NVIDIA_API_KEY`: Required
7//! - `NVIDIA_BASE_URL`: Optional (default: https://integrate.api.nvidia.com)
8//! - `NVIDIA_EMBEDDING_MODEL`: Optional (default: nvidia/nv-embed-v1)
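//! - `NVIDIA_EMBEDDING_BATCH_SIZE`: Optional (default: 64, clamped to 1..=256)
//! - `NVIDIA_EMBEDDING_TRUNCATE`: Optional (default: NONE); sent as the request's `truncate` field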
10//!
11//! The NVIDIA API supports different `input_type` values. Memvid uses:
12//! - `passage` for document/chunk embeddings
13//! - `query` for query embeddings
14
15use anyhow::{anyhow, bail, Result};
16use reqwest::blocking::Client;
17use reqwest::StatusCode;
18use serde::{Deserialize, Serialize};
19use std::time::Duration;
20use tracing::warn;
21
/// Endpoint root used when `NVIDIA_BASE_URL` is not set.
const DEFAULT_NVIDIA_BASE_URL: &str = "https://integrate.api.nvidia.com";
/// Model used when neither a non-blank explicit override nor a non-blank
/// `NVIDIA_EMBEDDING_MODEL` is supplied.
const DEFAULT_NVIDIA_EMBEDDING_MODEL: &str = "nvidia/nv-embed-v1";

/// Texts per request when `NVIDIA_EMBEDDING_BATCH_SIZE` is unset or unparsable.
const DEFAULT_BATCH_SIZE: usize = 64;
/// Upper bound that any configured batch size is clamped to.
const MAX_BATCH_SIZE: usize = 256;

/// Timeout passed to `crate::http::blocking_client` when building the HTTP client.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
29
/// Truncates `text` to at most `max_chars` **bytes**, backing off to the nearest
/// preceding UTF-8 character boundary so a multi-byte character is never split.
///
/// Despite the name, the limit is a byte budget: `text.len()` counts bytes, and
/// callers size `max_chars` from it. Text that already fits is returned as an
/// unchanged copy; a budget of 0 yields an empty string.
fn truncate_to_chars(text: &str, max_chars: usize) -> String {
    if text.len() <= max_chars {
        return text.to_string();
    }

    // Slicing at an arbitrary byte offset panics when it lands inside a
    // multi-byte character, so walk back to the closest boundary first.
    // Offset 0 is always a boundary, so this loop terminates.
    let mut end = max_chars;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    text[..end].to_string()
}
44
45fn extract_error_message(body: &str) -> Option<String> {
46    let value: serde_json::Value = serde_json::from_str(body).ok()?;
47    if let Some(message) = value.get("error").and_then(|v| v.as_str()) {
48        return Some(message.to_string());
49    }
50    if let Some(message) = value.get("message").and_then(|v| v.as_str()) {
51        return Some(message.to_string());
52    }
53    None
54}
55
/// Recognises NVIDIA's "input too long" error and extracts `(actual, max)` token counts.
///
/// Returns `None` unless `message` contains (case-insensitively) the phrase
/// `"exceeds maximum allowed token size"`.
///
/// The numbers are taken relative to that phrase: `actual` is the last integer
/// before it and `max` is the first integer after it. This keeps unrelated
/// leading numbers (e.g. an HTTP status code) from being mistaken for token
/// counts. If either side has no integer, it falls back to the first two
/// integers anywhere in the message. Returns `None` if fewer than two integers
/// are available.
fn parse_token_limit_error(message: &str) -> Option<(usize, usize)> {
    const PHRASE: &str = "exceeds maximum allowed token size";

    // ASCII lowercasing never changes byte lengths or touches non-ASCII bytes,
    // so an offset found in `lowered` is also a valid char boundary in `message`.
    let lowered = message.to_ascii_lowercase();
    let phrase_start = lowered.find(PHRASE)?;
    let phrase_end = phrase_start + PHRASE.len();

    let before = parse_digit_runs(&message[..phrase_start]);
    let after = parse_digit_runs(&message[phrase_end..]);
    if let (Some(&actual), Some(&max)) = (before.last(), after.first()) {
        return Some((actual, max));
    }

    match parse_digit_runs(message).as_slice() {
        [actual, max, ..] => Some((*actual, *max)),
        _ => None,
    }
}

/// Returns every maximal run of ASCII digits in `text` that fits in a `usize`,
/// in order of appearance. Runs that overflow are skipped.
fn parse_digit_runs(text: &str) -> Vec<usize> {
    text.split(|c: char| !c.is_ascii_digit())
        .filter(|run| !run.is_empty())
        .filter_map(|run| run.parse::<usize>().ok())
        .collect()
}
88
89#[derive(Debug, Serialize)]
90struct NvidiaEmbeddingRequest<'a> {
91    input: Vec<&'a str>,
92    model: &'a str,
93    #[serde(rename = "input_type")]
94    input_type: &'a str,
95    #[serde(rename = "encoding_format")]
96    encoding_format: &'a str,
97    truncate: &'a str,
98}
99
/// Subset of the `/v1/embeddings` response body that this module reads.
///
/// No `deny_unknown_fields` is set, so any other JSON fields are ignored by serde.
#[derive(Debug, Deserialize)]
struct NvidiaEmbeddingResponse {
    /// One entry per input text; the caller sorts these by `index` before use.
    data: Vec<NvidiaEmbeddingData>,
}
104
/// A single embedding vector returned by the API.
#[derive(Debug, Deserialize)]
struct NvidiaEmbeddingData {
    /// Vector components (the request asks for `encoding_format = "float"`).
    embedding: Vec<f32>,
    /// Presumably the position of the matching text in the request's `input`
    /// list; the caller sorts by it to restore input order.
    index: usize,
}
110
111#[derive(Clone, Debug)]
112pub struct NvidiaEmbeddingProvider {
113    api_key: String,
114    base_url: String,
115    model: String,
116    batch_size: usize,
117    document_input_type: String,
118    query_input_type: String,
119    encoding_format: String,
120    truncate: String,
121    client: Client,
122}
123
124impl NvidiaEmbeddingProvider {
125    pub fn from_env(explicit_model_override: Option<&str>) -> Result<Self> {
126        let api_key = std::env::var("NVIDIA_API_KEY").map_err(|_| {
127            anyhow!("NVIDIA_API_KEY environment variable is required for NVIDIA embeddings")
128        })?;
129        if api_key.trim().is_empty() {
130            bail!("NVIDIA_API_KEY cannot be empty");
131        }
132
133        let base_url = std::env::var("NVIDIA_BASE_URL")
134            .unwrap_or_else(|_| DEFAULT_NVIDIA_BASE_URL.to_string());
135        let base_url = base_url.trim().trim_end_matches('/').to_string();
136        if base_url.is_empty() {
137            bail!("NVIDIA_BASE_URL cannot be empty");
138        }
139
140        let model = explicit_model_override
141            .and_then(|value| {
142                let trimmed = value.trim();
143                (!trimmed.is_empty()).then_some(trimmed.to_string())
144            })
145            .or_else(|| {
146                std::env::var("NVIDIA_EMBEDDING_MODEL")
147                    .ok()
148                    .map(|s| s.trim().to_string())
149            })
150            .filter(|value| !value.is_empty())
151            .unwrap_or_else(|| DEFAULT_NVIDIA_EMBEDDING_MODEL.to_string());
152
153        let batch_size = std::env::var("NVIDIA_EMBEDDING_BATCH_SIZE")
154            .ok()
155            .and_then(|value| value.trim().parse::<usize>().ok())
156            .unwrap_or(DEFAULT_BATCH_SIZE)
157            .clamp(1, MAX_BATCH_SIZE);
158
159        let client = crate::http::blocking_client(REQUEST_TIMEOUT)
160            .map_err(|err| anyhow!("Failed to create HTTP client: {err}"))?;
161
162        let truncate = std::env::var("NVIDIA_EMBEDDING_TRUNCATE")
163            .ok()
164            .map(|value| value.trim().to_string())
165            .filter(|value| !value.is_empty())
166            .unwrap_or_else(|| "NONE".to_string());
167
168        Ok(Self {
169            api_key,
170            base_url,
171            model,
172            batch_size,
173            document_input_type: "passage".to_string(),
174            query_input_type: "query".to_string(),
175            encoding_format: "float".to_string(),
176            truncate,
177            client,
178        })
179    }
180
181    pub fn kind(&self) -> &'static str {
182        "nvidia"
183    }
184
185    pub fn model(&self) -> &str {
186        &self.model
187    }
188
189    pub fn embed_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
190        self.embed_batch_with_retry(&self.document_input_type, texts, 3)
191    }
192
193    pub fn embed_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
194        self.embed_batch_with_retry(&self.query_input_type, texts, 3)
195    }
196
197    pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
198        let mut out = self.embed_passages(&[text])?;
199        out.pop()
200            .ok_or_else(|| anyhow!("NVIDIA embeddings API returned no embedding output"))
201    }
202
203    pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
204        let mut out = self.embed_queries(&[text])?;
205        out.pop()
206            .ok_or_else(|| anyhow!("NVIDIA embeddings API returned no embedding output"))
207    }
208
209    fn embeddings_url(&self) -> String {
210        format!("{}/v1/embeddings", self.base_url)
211    }
212
213    fn embed_batch_with_retry(
214        &self,
215        input_type: &str,
216        texts: &[&str],
217        max_retries: usize,
218    ) -> Result<Vec<Vec<f32>>> {
219        if texts.is_empty() {
220            return Ok(Vec::new());
221        }
222
223        let mut all_embeddings = Vec::with_capacity(texts.len());
224        for chunk in texts.chunks(self.batch_size) {
225            let embeddings = self.call_nvidia_with_retry(input_type, chunk, max_retries)?;
226            all_embeddings.extend(embeddings);
227        }
228
229        Ok(all_embeddings)
230    }
231
232    fn call_nvidia_with_retry(
233        &self,
234        input_type: &str,
235        texts: &[&str],
236        max_retries: usize,
237    ) -> Result<Vec<Vec<f32>>> {
238        let url = self.embeddings_url();
239        let request = NvidiaEmbeddingRequest {
240            input: texts.to_vec(),
241            model: &self.model,
242            input_type,
243            encoding_format: &self.encoding_format,
244            truncate: &self.truncate,
245        };
246
247        let mut attempt = 0usize;
248        let mut backoff = Duration::from_millis(500);
249        let max_backoff = Duration::from_secs(8);
250
251        loop {
252            attempt += 1;
253            let response = self
254                .client
255                .post(&url)
256                .bearer_auth(&self.api_key)
257                .json(&request)
258                .send();
259
260            match response {
261                Ok(resp) => {
262                    let status = resp.status();
263                    let body = resp.text().unwrap_or_default();
264
265                    if status.is_success() {
266                        let mut decoded: NvidiaEmbeddingResponse = serde_json::from_str(&body)
267                            .map_err(|err| {
268                                anyhow!("failed to decode NVIDIA embeddings response: {err}")
269                            })?;
270                        decoded.data.sort_by_key(|item| item.index);
271
272                        if decoded.data.len() != texts.len() {
273                            bail!(
274                                "NVIDIA embeddings API returned {} embeddings for {} inputs",
275                                decoded.data.len(),
276                                texts.len()
277                            );
278                        }
279
280                        let embeddings: Vec<Vec<f32>> = decoded
281                            .data
282                            .into_iter()
283                            .map(|item| item.embedding)
284                            .collect();
285
286                        if embeddings.iter().any(|emb| emb.is_empty()) {
287                            bail!("NVIDIA embeddings API returned an empty embedding vector");
288                        }
289
290                        return Ok(embeddings);
291                    }
292
293                    let retryable =
294                        status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error();
295                    if retryable && attempt <= max_retries {
296                        warn!(
297                            "NVIDIA embeddings API returned {status} (attempt {attempt}/{max_attempts}); retrying in {backoff:?}: {body}",
298                            max_attempts = max_retries + 1
299                        );
300                        std::thread::sleep(backoff);
301                        backoff = (backoff * 2).min(max_backoff);
302                        continue;
303                    }
304
305                    if status == StatusCode::BAD_REQUEST {
306                        if let Some(message) = extract_error_message(&body) {
307                            if let Some((actual, max)) = parse_token_limit_error(&message) {
308                                let mut factor =
309                                    (max as f64 / actual.max(1) as f64).clamp(0.05, 0.95) * 0.95;
310                                warn!(
311                                    "NVIDIA embeddings input exceeds token limit ({actual} > {max}); retrying with automatic truncation"
312                                );
313
314                                for _ in 0..3 {
315                                    let owned: Vec<String> = texts
316                                        .iter()
317                                        .map(|text| {
318                                            let target =
319                                                ((text.len() as f64) * factor).floor() as usize;
320                                            truncate_to_chars(text, target.max(256))
321                                        })
322                                        .collect();
323                                    let refs: Vec<&str> =
324                                        owned.iter().map(|text| text.as_str()).collect();
325                                    let request = NvidiaEmbeddingRequest {
326                                        input: refs,
327                                        model: &self.model,
328                                        input_type,
329                                        encoding_format: &self.encoding_format,
330                                        truncate: &self.truncate,
331                                    };
332
333                                    let resp = self
334                                        .client
335                                        .post(&url)
336                                        .bearer_auth(&self.api_key)
337                                        .json(&request)
338                                        .send()
339                                        .map_err(|err| {
340                                            anyhow!("NVIDIA embeddings request failed: {err}")
341                                        })?;
342
343                                    let status = resp.status();
344                                    let body = resp.text().unwrap_or_default();
345                                    if status.is_success() {
346                                        let mut decoded: NvidiaEmbeddingResponse =
347                                            serde_json::from_str(&body).map_err(|err| {
348                                                anyhow!(
349                                                    "failed to decode NVIDIA embeddings response: {err}"
350                                                )
351                                            })?;
352                                        decoded.data.sort_by_key(|item| item.index);
353
354                                        if decoded.data.len() != texts.len() {
355                                            bail!(
356                                                "NVIDIA embeddings API returned {} embeddings for {} inputs",
357                                                decoded.data.len(),
358                                                texts.len()
359                                            );
360                                        }
361
362                                        let embeddings: Vec<Vec<f32>> = decoded
363                                            .data
364                                            .into_iter()
365                                            .map(|item| item.embedding)
366                                            .collect();
367
368                                        if embeddings.iter().any(|emb| emb.is_empty()) {
369                                            bail!(
370                                                "NVIDIA embeddings API returned an empty embedding vector"
371                                            );
372                                        }
373
374                                        return Ok(embeddings);
375                                    }
376
377                                    if status == StatusCode::BAD_REQUEST {
378                                        if let Some(message) = extract_error_message(&body) {
379                                            if parse_token_limit_error(&message).is_some() {
380                                                factor = (factor * 0.85).clamp(0.02, 0.8);
381                                                continue;
382                                            }
383                                        }
384                                    }
385
386                                    bail!(
387                                        "NVIDIA embeddings API returned error status {status}: {body}"
388                                    );
389                                }
390
391                                bail!(
392                                    "NVIDIA embeddings input exceeds token limit and could not be truncated automatically.\n\
393                                     Try enabling smaller chunks (or disable contextual prefixes) and retry. You can also set NVIDIA_EMBEDDING_TRUNCATE=END if your model supports server-side truncation."
394                                );
395                            }
396                        }
397                    }
398
399                    bail!("NVIDIA embeddings API returned error status {status}: {body}");
400                }
401                Err(err) => {
402                    let retryable = err.is_timeout() || err.is_connect();
403                    if retryable && attempt <= max_retries {
404                        warn!(
405                            "NVIDIA embeddings request failed (attempt {attempt}/{max_attempts}); retrying in {backoff:?}: {err}",
406                            max_attempts = max_retries + 1
407                        );
408                        std::thread::sleep(backoff);
409                        backoff = (backoff * 2).min(max_backoff);
410                        continue;
411                    }
412
413                    bail!("NVIDIA embeddings request failed: {err}");
414                }
415            }
416        }
417    }
418}