// memvid_cli/nvidia_embeddings.rs
use anyhow::{anyhow, bail, Result};
16use reqwest::blocking::Client;
17use reqwest::StatusCode;
18use serde::{Deserialize, Serialize};
19use std::time::Duration;
20use tracing::warn;
21
22const DEFAULT_NVIDIA_BASE_URL: &str = "https://integrate.api.nvidia.com";
23const DEFAULT_NVIDIA_EMBEDDING_MODEL: &str = "nvidia/nv-embed-v1";
24
25const DEFAULT_BATCH_SIZE: usize = 64;
26const MAX_BATCH_SIZE: usize = 256;
27
28const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
29
/// Truncate `text` to at most `max_chars` bytes, ending on a UTF-8 char
/// boundary so the result is always valid UTF-8.
///
/// NOTE: despite the name, the limit is measured in *bytes* (callers compute
/// targets from `text.len()`, a byte count); the boundary search only
/// guarantees that no character is split in half.
///
/// Fixes a panic in the previous version: `&text[..max_chars]` panics when
/// `max_chars` falls inside a multi-byte character. We now walk back to the
/// nearest valid boundary before slicing.
fn truncate_to_chars(text: &str, max_chars: usize) -> String {
    if text.len() <= max_chars {
        return text.to_string();
    }

    // Walk back from the byte limit to the nearest char boundary; slicing
    // at a non-boundary index would panic.
    let mut end = max_chars;
    while end > 0 && !text.is_char_boundary(end) {
        end -= 1;
    }
    text[..end].to_string()
}
44
45fn extract_error_message(body: &str) -> Option<String> {
46 let value: serde_json::Value = serde_json::from_str(body).ok()?;
47 if let Some(message) = value.get("error").and_then(|v| v.as_str()) {
48 return Some(message.to_string());
49 }
50 if let Some(message) = value.get("message").and_then(|v| v.as_str()) {
51 return Some(message.to_string());
52 }
53 None
54}
55
/// Detect the NVIDIA "exceeds maximum allowed token size" error and extract
/// the first two integers from the message, returned as `(actual, max)`.
///
/// Returns `None` when the message is not a token-limit error or contains
/// fewer than two parseable numbers. The marker-phrase check is
/// case-insensitive.
fn parse_token_limit_error(message: &str) -> Option<(usize, usize)> {
    let is_token_limit = message
        .to_ascii_lowercase()
        .contains("exceeds maximum allowed token size");
    if !is_token_limit {
        return None;
    }

    // Split on every non-digit character and parse the surviving digit runs;
    // runs that overflow usize are simply skipped.
    let mut numbers = message
        .split(|c: char| !c.is_ascii_digit())
        .filter(|run| !run.is_empty())
        .filter_map(|run| run.parse::<usize>().ok());

    let actual = numbers.next()?;
    let max = numbers.next()?;
    Some((actual, max))
}
88
/// JSON request body for the NVIDIA `/v1/embeddings` endpoint.
///
/// Borrows all strings from the caller so building a request allocates only
/// the `input` vector, not the texts themselves.
#[derive(Debug, Serialize)]
struct NvidiaEmbeddingRequest<'a> {
    /// Batch of texts to embed.
    input: Vec<&'a str>,
    /// Model identifier, e.g. "nvidia/nv-embed-v1".
    model: &'a str,
    // The renames are redundant (field names already match the wire names)
    // but kept to make the serialized field names explicit.
    #[serde(rename = "input_type")]
    input_type: &'a str,
    #[serde(rename = "encoding_format")]
    encoding_format: &'a str,
    /// Server-side truncation mode (e.g. "NONE" or "END").
    truncate: &'a str,
}
99
/// Top-level response from the embeddings endpoint; only the `data` array
/// is consumed, any other fields are ignored by serde.
#[derive(Debug, Deserialize)]
struct NvidiaEmbeddingResponse {
    data: Vec<NvidiaEmbeddingData>,
}
104
/// One embedding entry in the response.
#[derive(Debug, Deserialize)]
struct NvidiaEmbeddingData {
    /// The embedding vector itself.
    embedding: Vec<f32>,
    /// Position of the corresponding input text; used to restore input
    /// order, since the API may return entries out of order.
    index: usize,
}
110
/// Blocking client for the NVIDIA embeddings API, configured from
/// environment variables via [`NvidiaEmbeddingProvider::from_env`].
#[derive(Clone, Debug)]
pub struct NvidiaEmbeddingProvider {
    // Bearer token from NVIDIA_API_KEY.
    api_key: String,
    // Base URL with any trailing '/' stripped; "/v1/embeddings" is appended.
    base_url: String,
    // Embedding model identifier.
    model: String,
    // Max texts per request, clamped to 1..=MAX_BATCH_SIZE.
    batch_size: usize,
    // input_type sent for document passages ("passage").
    document_input_type: String,
    // input_type sent for search queries ("query").
    query_input_type: String,
    // Wire encoding of embedding values ("float").
    encoding_format: String,
    // Server-side truncation mode (defaults to "NONE").
    truncate: String,
    // Shared blocking HTTP client with a fixed request timeout.
    client: Client,
}
123
124impl NvidiaEmbeddingProvider {
125 pub fn from_env(explicit_model_override: Option<&str>) -> Result<Self> {
126 let api_key = std::env::var("NVIDIA_API_KEY").map_err(|_| {
127 anyhow!("NVIDIA_API_KEY environment variable is required for NVIDIA embeddings")
128 })?;
129 if api_key.trim().is_empty() {
130 bail!("NVIDIA_API_KEY cannot be empty");
131 }
132
133 let base_url = std::env::var("NVIDIA_BASE_URL")
134 .unwrap_or_else(|_| DEFAULT_NVIDIA_BASE_URL.to_string());
135 let base_url = base_url.trim().trim_end_matches('/').to_string();
136 if base_url.is_empty() {
137 bail!("NVIDIA_BASE_URL cannot be empty");
138 }
139
140 let model = explicit_model_override
141 .and_then(|value| {
142 let trimmed = value.trim();
143 (!trimmed.is_empty()).then_some(trimmed.to_string())
144 })
145 .or_else(|| {
146 std::env::var("NVIDIA_EMBEDDING_MODEL")
147 .ok()
148 .map(|s| s.trim().to_string())
149 })
150 .filter(|value| !value.is_empty())
151 .unwrap_or_else(|| DEFAULT_NVIDIA_EMBEDDING_MODEL.to_string());
152
153 let batch_size = std::env::var("NVIDIA_EMBEDDING_BATCH_SIZE")
154 .ok()
155 .and_then(|value| value.trim().parse::<usize>().ok())
156 .unwrap_or(DEFAULT_BATCH_SIZE)
157 .clamp(1, MAX_BATCH_SIZE);
158
159 let client = crate::http::blocking_client(REQUEST_TIMEOUT)
160 .map_err(|err| anyhow!("Failed to create HTTP client: {err}"))?;
161
162 let truncate = std::env::var("NVIDIA_EMBEDDING_TRUNCATE")
163 .ok()
164 .map(|value| value.trim().to_string())
165 .filter(|value| !value.is_empty())
166 .unwrap_or_else(|| "NONE".to_string());
167
168 Ok(Self {
169 api_key,
170 base_url,
171 model,
172 batch_size,
173 document_input_type: "passage".to_string(),
174 query_input_type: "query".to_string(),
175 encoding_format: "float".to_string(),
176 truncate,
177 client,
178 })
179 }
180
181 pub fn kind(&self) -> &'static str {
182 "nvidia"
183 }
184
185 pub fn model(&self) -> &str {
186 &self.model
187 }
188
189 pub fn embed_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
190 self.embed_batch_with_retry(&self.document_input_type, texts, 3)
191 }
192
193 pub fn embed_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
194 self.embed_batch_with_retry(&self.query_input_type, texts, 3)
195 }
196
197 pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
198 let mut out = self.embed_passages(&[text])?;
199 out.pop()
200 .ok_or_else(|| anyhow!("NVIDIA embeddings API returned no embedding output"))
201 }
202
203 pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
204 let mut out = self.embed_queries(&[text])?;
205 out.pop()
206 .ok_or_else(|| anyhow!("NVIDIA embeddings API returned no embedding output"))
207 }
208
209 fn embeddings_url(&self) -> String {
210 format!("{}/v1/embeddings", self.base_url)
211 }
212
213 fn embed_batch_with_retry(
214 &self,
215 input_type: &str,
216 texts: &[&str],
217 max_retries: usize,
218 ) -> Result<Vec<Vec<f32>>> {
219 if texts.is_empty() {
220 return Ok(Vec::new());
221 }
222
223 let mut all_embeddings = Vec::with_capacity(texts.len());
224 for chunk in texts.chunks(self.batch_size) {
225 let embeddings = self.call_nvidia_with_retry(input_type, chunk, max_retries)?;
226 all_embeddings.extend(embeddings);
227 }
228
229 Ok(all_embeddings)
230 }
231
232 fn call_nvidia_with_retry(
233 &self,
234 input_type: &str,
235 texts: &[&str],
236 max_retries: usize,
237 ) -> Result<Vec<Vec<f32>>> {
238 let url = self.embeddings_url();
239 let request = NvidiaEmbeddingRequest {
240 input: texts.to_vec(),
241 model: &self.model,
242 input_type,
243 encoding_format: &self.encoding_format,
244 truncate: &self.truncate,
245 };
246
247 let mut attempt = 0usize;
248 let mut backoff = Duration::from_millis(500);
249 let max_backoff = Duration::from_secs(8);
250
251 loop {
252 attempt += 1;
253 let response = self
254 .client
255 .post(&url)
256 .bearer_auth(&self.api_key)
257 .json(&request)
258 .send();
259
260 match response {
261 Ok(resp) => {
262 let status = resp.status();
263 let body = resp.text().unwrap_or_default();
264
265 if status.is_success() {
266 let mut decoded: NvidiaEmbeddingResponse = serde_json::from_str(&body)
267 .map_err(|err| {
268 anyhow!("failed to decode NVIDIA embeddings response: {err}")
269 })?;
270 decoded.data.sort_by_key(|item| item.index);
271
272 if decoded.data.len() != texts.len() {
273 bail!(
274 "NVIDIA embeddings API returned {} embeddings for {} inputs",
275 decoded.data.len(),
276 texts.len()
277 );
278 }
279
280 let embeddings: Vec<Vec<f32>> = decoded
281 .data
282 .into_iter()
283 .map(|item| item.embedding)
284 .collect();
285
286 if embeddings.iter().any(|emb| emb.is_empty()) {
287 bail!("NVIDIA embeddings API returned an empty embedding vector");
288 }
289
290 return Ok(embeddings);
291 }
292
293 let retryable =
294 status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error();
295 if retryable && attempt <= max_retries {
296 warn!(
297 "NVIDIA embeddings API returned {status} (attempt {attempt}/{max_attempts}); retrying in {backoff:?}: {body}",
298 max_attempts = max_retries + 1
299 );
300 std::thread::sleep(backoff);
301 backoff = (backoff * 2).min(max_backoff);
302 continue;
303 }
304
305 if status == StatusCode::BAD_REQUEST {
306 if let Some(message) = extract_error_message(&body) {
307 if let Some((actual, max)) = parse_token_limit_error(&message) {
308 let mut factor =
309 (max as f64 / actual.max(1) as f64).clamp(0.05, 0.95) * 0.95;
310 warn!(
311 "NVIDIA embeddings input exceeds token limit ({actual} > {max}); retrying with automatic truncation"
312 );
313
314 for _ in 0..3 {
315 let owned: Vec<String> = texts
316 .iter()
317 .map(|text| {
318 let target =
319 ((text.len() as f64) * factor).floor() as usize;
320 truncate_to_chars(text, target.max(256))
321 })
322 .collect();
323 let refs: Vec<&str> =
324 owned.iter().map(|text| text.as_str()).collect();
325 let request = NvidiaEmbeddingRequest {
326 input: refs,
327 model: &self.model,
328 input_type,
329 encoding_format: &self.encoding_format,
330 truncate: &self.truncate,
331 };
332
333 let resp = self
334 .client
335 .post(&url)
336 .bearer_auth(&self.api_key)
337 .json(&request)
338 .send()
339 .map_err(|err| {
340 anyhow!("NVIDIA embeddings request failed: {err}")
341 })?;
342
343 let status = resp.status();
344 let body = resp.text().unwrap_or_default();
345 if status.is_success() {
346 let mut decoded: NvidiaEmbeddingResponse =
347 serde_json::from_str(&body).map_err(|err| {
348 anyhow!(
349 "failed to decode NVIDIA embeddings response: {err}"
350 )
351 })?;
352 decoded.data.sort_by_key(|item| item.index);
353
354 if decoded.data.len() != texts.len() {
355 bail!(
356 "NVIDIA embeddings API returned {} embeddings for {} inputs",
357 decoded.data.len(),
358 texts.len()
359 );
360 }
361
362 let embeddings: Vec<Vec<f32>> = decoded
363 .data
364 .into_iter()
365 .map(|item| item.embedding)
366 .collect();
367
368 if embeddings.iter().any(|emb| emb.is_empty()) {
369 bail!(
370 "NVIDIA embeddings API returned an empty embedding vector"
371 );
372 }
373
374 return Ok(embeddings);
375 }
376
377 if status == StatusCode::BAD_REQUEST {
378 if let Some(message) = extract_error_message(&body) {
379 if parse_token_limit_error(&message).is_some() {
380 factor = (factor * 0.85).clamp(0.02, 0.8);
381 continue;
382 }
383 }
384 }
385
386 bail!(
387 "NVIDIA embeddings API returned error status {status}: {body}"
388 );
389 }
390
391 bail!(
392 "NVIDIA embeddings input exceeds token limit and could not be truncated automatically.\n\
393 Try enabling smaller chunks (or disable contextual prefixes) and retry. You can also set NVIDIA_EMBEDDING_TRUNCATE=END if your model supports server-side truncation."
394 );
395 }
396 }
397 }
398
399 bail!("NVIDIA embeddings API returned error status {status}: {body}");
400 }
401 Err(err) => {
402 let retryable = err.is_timeout() || err.is_connect();
403 if retryable && attempt <= max_retries {
404 warn!(
405 "NVIDIA embeddings request failed (attempt {attempt}/{max_attempts}); retrying in {backoff:?}: {err}",
406 max_attempts = max_retries + 1
407 );
408 std::thread::sleep(backoff);
409 backoff = (backoff * 2).min(max_backoff);
410 continue;
411 }
412
413 bail!("NVIDIA embeddings request failed: {err}");
414 }
415 }
416 }
417 }
418}