// memvid_cli/nvidia_embeddings.rs

//! NVIDIA Embeddings Provider
//!
//! Implements embeddings via the NVIDIA Integrate API (`/v1/embeddings`).
//!
//! ## Environment Variables
//! - `NVIDIA_API_KEY`: Required
//! - `NVIDIA_BASE_URL`: Optional (default: <https://integrate.api.nvidia.com>)
//! - `NVIDIA_EMBEDDING_MODEL`: Optional (default: `nvidia/nv-embed-v1`)
//! - `NVIDIA_EMBEDDING_BATCH_SIZE`: Optional (default: 64, clamped to 1..=256)
//! - `NVIDIA_EMBEDDING_TRUNCATE`: Optional (default: `NONE`); sent as the request's `truncate` field
//!
//! The NVIDIA API supports different `input_type` values. Memvid uses:
//! - `passage` for document/chunk embeddings
//! - `query` for query embeddings
14
15use anyhow::{anyhow, bail, Result};
16use reqwest::blocking::Client;
17use reqwest::StatusCode;
18use serde::{Deserialize, Serialize};
19use std::time::Duration;
20use tracing::warn;
21
/// Endpoint root used when `NVIDIA_BASE_URL` is not set.
const DEFAULT_NVIDIA_BASE_URL: &str = "https://integrate.api.nvidia.com";
/// Model used when neither an explicit override nor `NVIDIA_EMBEDDING_MODEL` is provided.
const DEFAULT_NVIDIA_EMBEDDING_MODEL: &str = "nvidia/nv-embed-v1";

/// Inputs per request when `NVIDIA_EMBEDDING_BATCH_SIZE` is absent or unparsable.
const DEFAULT_BATCH_SIZE: usize = 64;
/// Upper bound any configured batch size is clamped to.
const MAX_BATCH_SIZE: usize = 256;

/// Timeout handed to the blocking HTTP client (60 seconds).
const REQUEST_TIMEOUT: Duration = Duration::from_millis(60_000);
29
/// Returns at most the first `max_chars` **bytes** of `text` as an owned `String`.
///
/// Despite the name, the limit is a byte count (callers derive it from `str::len`). When the
/// limit falls inside a multi-byte UTF-8 character, the cut moves back to the start of that
/// character, so the result is always valid UTF-8 and never longer than `max_chars` bytes.
/// Slicing directly at `max_chars` would panic in that situation.
fn truncate_to_chars(text: &str, max_chars: usize) -> String {
    if text.len() <= max_chars {
        return text.to_string();
    }

    // `str::floor_char_boundary` is not stable yet. A UTF-8 character is at most 4 bytes, so
    // this walks back at most 3 bytes, and index 0 is always a boundary.
    let mut end = max_chars;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    text[..end].to_string()
}
44
45fn extract_error_message(body: &str) -> Option<String> {
46    let value: serde_json::Value = serde_json::from_str(body).ok()?;
47    if let Some(message) = value.get("error").and_then(|v| v.as_str()) {
48        return Some(message.to_string());
49    }
50    if let Some(message) = value.get("message").and_then(|v| v.as_str()) {
51        return Some(message.to_string());
52    }
53    None
54}
55
/// Recognises the API's "input too long" error and extracts the first two integers in it.
///
/// Returns `Some((first, second))` only when `message` contains, case-insensitively,
/// "exceeds maximum allowed token size" and at least two decimal digit runs. The caller
/// treats these as (actual tokens, allowed tokens), presumably the order the API prints
/// them in. Digit runs too large for `usize` are skipped rather than failing the parse.
fn parse_token_limit_error(message: &str) -> Option<(usize, usize)> {
    const MARKER: &str = "exceeds maximum allowed token size";

    if !message.to_ascii_lowercase().contains(MARKER) {
        return None;
    }

    // Every maximal run of ASCII digits is one candidate number.
    let mut numbers = message
        .split(|c: char| !c.is_ascii_digit())
        .filter(|run| !run.is_empty())
        .filter_map(|run| run.parse::<usize>().ok());

    let first = numbers.next()?;
    let second = numbers.next()?;
    Some((first, second))
}
88
89#[derive(Debug, Serialize)]
90struct NvidiaEmbeddingRequest<'a> {
91    input: Vec<&'a str>,
92    model: &'a str,
93    #[serde(rename = "input_type")]
94    input_type: &'a str,
95    #[serde(rename = "encoding_format")]
96    encoding_format: &'a str,
97    truncate: &'a str,
98}
99
100#[derive(Debug, Deserialize)]
101struct NvidiaEmbeddingResponse {
102    data: Vec<NvidiaEmbeddingData>,
103}
104
105#[derive(Debug, Deserialize)]
106struct NvidiaEmbeddingData {
107    embedding: Vec<f32>,
108    index: usize,
109}
110
111#[derive(Clone, Debug)]
112pub struct NvidiaEmbeddingProvider {
113    api_key: String,
114    base_url: String,
115    model: String,
116    batch_size: usize,
117    document_input_type: String,
118    query_input_type: String,
119    encoding_format: String,
120    truncate: String,
121    client: Client,
122}
123
124impl NvidiaEmbeddingProvider {
125    pub fn from_env(explicit_model_override: Option<&str>) -> Result<Self> {
126        let api_key = std::env::var("NVIDIA_API_KEY")
127            .map_err(|_| anyhow!("NVIDIA_API_KEY environment variable is required for NVIDIA embeddings"))?;
128        if api_key.trim().is_empty() {
129            bail!("NVIDIA_API_KEY cannot be empty");
130        }
131
132        let base_url = std::env::var("NVIDIA_BASE_URL")
133            .unwrap_or_else(|_| DEFAULT_NVIDIA_BASE_URL.to_string());
134        let base_url = base_url.trim().trim_end_matches('/').to_string();
135        if base_url.is_empty() {
136            bail!("NVIDIA_BASE_URL cannot be empty");
137        }
138
139        let model = explicit_model_override
140            .and_then(|value| {
141                let trimmed = value.trim();
142                (!trimmed.is_empty()).then_some(trimmed.to_string())
143            })
144            .or_else(|| std::env::var("NVIDIA_EMBEDDING_MODEL").ok().map(|s| s.trim().to_string()))
145            .filter(|value| !value.is_empty())
146            .unwrap_or_else(|| DEFAULT_NVIDIA_EMBEDDING_MODEL.to_string());
147
148        let batch_size = std::env::var("NVIDIA_EMBEDDING_BATCH_SIZE")
149            .ok()
150            .and_then(|value| value.trim().parse::<usize>().ok())
151            .unwrap_or(DEFAULT_BATCH_SIZE)
152            .clamp(1, MAX_BATCH_SIZE);
153
154        let client = crate::http::blocking_client(REQUEST_TIMEOUT)
155            .map_err(|err| anyhow!("Failed to create HTTP client: {err}"))?;
156
157        let truncate = std::env::var("NVIDIA_EMBEDDING_TRUNCATE")
158            .ok()
159            .map(|value| value.trim().to_string())
160            .filter(|value| !value.is_empty())
161            .unwrap_or_else(|| "NONE".to_string());
162
163        Ok(Self {
164            api_key,
165            base_url,
166            model,
167            batch_size,
168            document_input_type: "passage".to_string(),
169            query_input_type: "query".to_string(),
170            encoding_format: "float".to_string(),
171            truncate,
172            client,
173        })
174    }
175
176    pub fn kind(&self) -> &'static str {
177        "nvidia"
178    }
179
180    pub fn model(&self) -> &str {
181        &self.model
182    }
183
184    pub fn embed_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
185        self.embed_batch_with_retry(&self.document_input_type, texts, 3)
186    }
187
188    pub fn embed_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
189        self.embed_batch_with_retry(&self.query_input_type, texts, 3)
190    }
191
192    pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
193        let mut out = self.embed_passages(&[text])?;
194        out.pop()
195            .ok_or_else(|| anyhow!("NVIDIA embeddings API returned no embedding output"))
196    }
197
198    pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
199        let mut out = self.embed_queries(&[text])?;
200        out.pop()
201            .ok_or_else(|| anyhow!("NVIDIA embeddings API returned no embedding output"))
202    }
203
204    fn embeddings_url(&self) -> String {
205        format!("{}/v1/embeddings", self.base_url)
206    }
207
208    fn embed_batch_with_retry(
209        &self,
210        input_type: &str,
211        texts: &[&str],
212        max_retries: usize,
213    ) -> Result<Vec<Vec<f32>>> {
214        if texts.is_empty() {
215            return Ok(Vec::new());
216        }
217
218        let mut all_embeddings = Vec::with_capacity(texts.len());
219        for chunk in texts.chunks(self.batch_size) {
220            let embeddings = self.call_nvidia_with_retry(input_type, chunk, max_retries)?;
221            all_embeddings.extend(embeddings);
222        }
223
224        Ok(all_embeddings)
225    }
226
227    fn call_nvidia_with_retry(
228        &self,
229        input_type: &str,
230        texts: &[&str],
231        max_retries: usize,
232    ) -> Result<Vec<Vec<f32>>> {
233        let url = self.embeddings_url();
234        let request = NvidiaEmbeddingRequest {
235            input: texts.to_vec(),
236            model: &self.model,
237            input_type,
238            encoding_format: &self.encoding_format,
239            truncate: &self.truncate,
240        };
241
242        let mut attempt = 0usize;
243        let mut backoff = Duration::from_millis(500);
244        let max_backoff = Duration::from_secs(8);
245
246        loop {
247            attempt += 1;
248            let response = self
249                .client
250                .post(&url)
251                .bearer_auth(&self.api_key)
252                .json(&request)
253                .send();
254
255            match response {
256                Ok(resp) => {
257                    let status = resp.status();
258                    let body = resp.text().unwrap_or_default();
259
260                    if status.is_success() {
261                        let mut decoded: NvidiaEmbeddingResponse =
262                            serde_json::from_str(&body).map_err(|err| {
263                                anyhow!("failed to decode NVIDIA embeddings response: {err}")
264                            })?;
265                        decoded.data.sort_by_key(|item| item.index);
266
267                        if decoded.data.len() != texts.len() {
268                            bail!(
269                                "NVIDIA embeddings API returned {} embeddings for {} inputs",
270                                decoded.data.len(),
271                                texts.len()
272                            );
273                        }
274
275                        let embeddings: Vec<Vec<f32>> = decoded
276                            .data
277                            .into_iter()
278                            .map(|item| item.embedding)
279                            .collect();
280
281                        if embeddings.iter().any(|emb| emb.is_empty()) {
282                            bail!("NVIDIA embeddings API returned an empty embedding vector");
283                        }
284
285                        return Ok(embeddings);
286                    }
287
288                    let retryable =
289                        status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error();
290                    if retryable && attempt <= max_retries {
291                        warn!(
292                            "NVIDIA embeddings API returned {status} (attempt {attempt}/{max_attempts}); retrying in {backoff:?}: {body}",
293                            max_attempts = max_retries + 1
294                        );
295                        std::thread::sleep(backoff);
296                        backoff = (backoff * 2).min(max_backoff);
297                        continue;
298                    }
299
300                    if status == StatusCode::BAD_REQUEST {
301                        if let Some(message) = extract_error_message(&body) {
302                            if let Some((actual, max)) = parse_token_limit_error(&message) {
303                                let mut factor =
304                                    (max as f64 / actual.max(1) as f64).clamp(0.05, 0.95) * 0.95;
305                                warn!(
306                                    "NVIDIA embeddings input exceeds token limit ({actual} > {max}); retrying with automatic truncation"
307                                );
308
309                                for _ in 0..3 {
310                                    let owned: Vec<String> = texts
311                                        .iter()
312                                        .map(|text| {
313                                            let target =
314                                                ((text.len() as f64) * factor).floor() as usize;
315                                            truncate_to_chars(text, target.max(256))
316                                        })
317                                        .collect();
318                                    let refs: Vec<&str> =
319                                        owned.iter().map(|text| text.as_str()).collect();
320                                    let request = NvidiaEmbeddingRequest {
321                                        input: refs,
322                                        model: &self.model,
323                                        input_type,
324                                        encoding_format: &self.encoding_format,
325                                        truncate: &self.truncate,
326                                    };
327
328                                    let resp = self
329                                        .client
330                                        .post(&url)
331                                        .bearer_auth(&self.api_key)
332                                        .json(&request)
333                                        .send()
334                                        .map_err(|err| {
335                                            anyhow!("NVIDIA embeddings request failed: {err}")
336                                        })?;
337
338                                    let status = resp.status();
339                                    let body = resp.text().unwrap_or_default();
340                                    if status.is_success() {
341                                        let mut decoded: NvidiaEmbeddingResponse =
342                                            serde_json::from_str(&body).map_err(|err| {
343                                                anyhow!(
344                                                    "failed to decode NVIDIA embeddings response: {err}"
345                                                )
346                                            })?;
347                                        decoded.data.sort_by_key(|item| item.index);
348
349                                        if decoded.data.len() != texts.len() {
350                                            bail!(
351                                                "NVIDIA embeddings API returned {} embeddings for {} inputs",
352                                                decoded.data.len(),
353                                                texts.len()
354                                            );
355                                        }
356
357                                        let embeddings: Vec<Vec<f32>> = decoded
358                                            .data
359                                            .into_iter()
360                                            .map(|item| item.embedding)
361                                            .collect();
362
363                                        if embeddings.iter().any(|emb| emb.is_empty()) {
364                                            bail!(
365                                                "NVIDIA embeddings API returned an empty embedding vector"
366                                            );
367                                        }
368
369                                        return Ok(embeddings);
370                                    }
371
372                                    if status == StatusCode::BAD_REQUEST {
373                                        if let Some(message) = extract_error_message(&body) {
374                                            if parse_token_limit_error(&message).is_some() {
375                                                factor = (factor * 0.85).clamp(0.02, 0.8);
376                                                continue;
377                                            }
378                                        }
379                                    }
380
381                                    bail!(
382                                        "NVIDIA embeddings API returned error status {status}: {body}"
383                                    );
384                                }
385
386                                bail!(
387                                    "NVIDIA embeddings input exceeds token limit and could not be truncated automatically.\n\
388                                     Try enabling smaller chunks (or disable contextual prefixes) and retry. You can also set NVIDIA_EMBEDDING_TRUNCATE=END if your model supports server-side truncation."
389                                );
390                            }
391                        }
392                    }
393
394                    bail!("NVIDIA embeddings API returned error status {status}: {body}");
395                }
396                Err(err) => {
397                    let retryable = err.is_timeout() || err.is_connect();
398                    if retryable && attempt <= max_retries {
399                        warn!(
400                            "NVIDIA embeddings request failed (attempt {attempt}/{max_attempts}); retrying in {backoff:?}: {err}",
401                            max_attempts = max_retries + 1
402                        );
403                        std::thread::sleep(backoff);
404                        backoff = (backoff * 2).min(max_backoff);
405                        continue;
406                    }
407
408                    bail!("NVIDIA embeddings request failed: {err}");
409                }
410            }
411        }
412    }
413}