memvid_cli/
nvidia_embeddings.rs1use anyhow::{anyhow, bail, Result};
16use reqwest::blocking::Client;
17use reqwest::StatusCode;
18use serde::{Deserialize, Serialize};
19use std::time::Duration;
20use tracing::warn;
21
22const DEFAULT_NVIDIA_BASE_URL: &str = "https://integrate.api.nvidia.com";
23const DEFAULT_NVIDIA_EMBEDDING_MODEL: &str = "nvidia/nv-embed-v1";
24
25const DEFAULT_BATCH_SIZE: usize = 64;
26const MAX_BATCH_SIZE: usize = 256;
27
28const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
29
/// Truncates `text` to at most `max_chars` bytes, cutting only at a UTF-8
/// character boundary so the result is always valid UTF-8.
///
/// NOTE: despite the name, the limit is measured in *bytes* (callers derive
/// the target from `text.len()`), not in characters.
fn truncate_to_chars(text: &str, max_chars: usize) -> String {
    if text.len() <= max_chars {
        return text.to_string();
    }

    // `&text[..max_chars]` would panic if `max_chars` falls inside a
    // multi-byte character, so walk back to the nearest char boundary.
    let mut end = max_chars;
    while end > 0 && !text.is_char_boundary(end) {
        end -= 1;
    }
    text[..end].to_string()
}
44
45fn extract_error_message(body: &str) -> Option<String> {
46 let value: serde_json::Value = serde_json::from_str(body).ok()?;
47 if let Some(message) = value.get("error").and_then(|v| v.as_str()) {
48 return Some(message.to_string());
49 }
50 if let Some(message) = value.get("message").and_then(|v| v.as_str()) {
51 return Some(message.to_string());
52 }
53 None
54}
55
/// Detects an "exceeds maximum allowed token size" error message and pulls
/// out the first two numbers it contains, interpreted as
/// `(actual_tokens, max_tokens)`.
///
/// Returns `None` when the marker phrase is absent (case-insensitive) or
/// fewer than two numbers can be parsed from the message.
fn parse_token_limit_error(message: &str) -> Option<(usize, usize)> {
    let lowered = message.to_ascii_lowercase();
    if !lowered.contains("exceeds maximum allowed token size") {
        return None;
    }

    // Split on non-digit characters; each surviving run is a decimal number.
    // Runs that fail to parse (empty or overflowing usize) are skipped, which
    // matches the original digit-accumulator behavior.
    let mut numbers = message
        .split(|ch: char| !ch.is_ascii_digit())
        .filter_map(|run| run.parse::<usize>().ok());

    let actual = numbers.next()?;
    let max = numbers.next()?;
    Some((actual, max))
}
88
89#[derive(Debug, Serialize)]
90struct NvidiaEmbeddingRequest<'a> {
91 input: Vec<&'a str>,
92 model: &'a str,
93 #[serde(rename = "input_type")]
94 input_type: &'a str,
95 #[serde(rename = "encoding_format")]
96 encoding_format: &'a str,
97 truncate: &'a str,
98}
99
/// Top-level success response from the NVIDIA embeddings endpoint.
#[derive(Debug, Deserialize)]
struct NvidiaEmbeddingResponse {
    // One entry per input text; call sites re-sort this by `index` rather
    // than trusting the response order.
    data: Vec<NvidiaEmbeddingData>,
}
104
/// A single embedding entry in the API response.
#[derive(Debug, Deserialize)]
struct NvidiaEmbeddingData {
    // The embedding vector for one input text.
    embedding: Vec<f32>,
    // Position of the corresponding input within the request batch.
    index: usize,
}
110
/// Blocking client for the NVIDIA embeddings REST API.
#[derive(Clone, Debug)]
pub struct NvidiaEmbeddingProvider {
    // Bearer token taken from NVIDIA_API_KEY.
    api_key: String,
    // API root with any trailing slash stripped,
    // e.g. "https://integrate.api.nvidia.com".
    base_url: String,
    // Model identifier sent with every request.
    model: String,
    // Maximum texts per request, clamped to 1..=MAX_BATCH_SIZE.
    batch_size: usize,
    // input_type used when embedding documents ("passage").
    document_input_type: String,
    // input_type used when embedding search queries ("query").
    query_input_type: String,
    // Requested wire format for vectors ("float").
    encoding_format: String,
    // Server-side truncation mode; "NONE" unless overridden via
    // NVIDIA_EMBEDDING_TRUNCATE.
    truncate: String,
    // Shared blocking HTTP client with a fixed request timeout.
    client: Client,
}
123
124impl NvidiaEmbeddingProvider {
125 pub fn from_env(explicit_model_override: Option<&str>) -> Result<Self> {
126 let api_key = std::env::var("NVIDIA_API_KEY")
127 .map_err(|_| anyhow!("NVIDIA_API_KEY environment variable is required for NVIDIA embeddings"))?;
128 if api_key.trim().is_empty() {
129 bail!("NVIDIA_API_KEY cannot be empty");
130 }
131
132 let base_url = std::env::var("NVIDIA_BASE_URL")
133 .unwrap_or_else(|_| DEFAULT_NVIDIA_BASE_URL.to_string());
134 let base_url = base_url.trim().trim_end_matches('/').to_string();
135 if base_url.is_empty() {
136 bail!("NVIDIA_BASE_URL cannot be empty");
137 }
138
139 let model = explicit_model_override
140 .and_then(|value| {
141 let trimmed = value.trim();
142 (!trimmed.is_empty()).then_some(trimmed.to_string())
143 })
144 .or_else(|| std::env::var("NVIDIA_EMBEDDING_MODEL").ok().map(|s| s.trim().to_string()))
145 .filter(|value| !value.is_empty())
146 .unwrap_or_else(|| DEFAULT_NVIDIA_EMBEDDING_MODEL.to_string());
147
148 let batch_size = std::env::var("NVIDIA_EMBEDDING_BATCH_SIZE")
149 .ok()
150 .and_then(|value| value.trim().parse::<usize>().ok())
151 .unwrap_or(DEFAULT_BATCH_SIZE)
152 .clamp(1, MAX_BATCH_SIZE);
153
154 let client = crate::http::blocking_client(REQUEST_TIMEOUT)
155 .map_err(|err| anyhow!("Failed to create HTTP client: {err}"))?;
156
157 let truncate = std::env::var("NVIDIA_EMBEDDING_TRUNCATE")
158 .ok()
159 .map(|value| value.trim().to_string())
160 .filter(|value| !value.is_empty())
161 .unwrap_or_else(|| "NONE".to_string());
162
163 Ok(Self {
164 api_key,
165 base_url,
166 model,
167 batch_size,
168 document_input_type: "passage".to_string(),
169 query_input_type: "query".to_string(),
170 encoding_format: "float".to_string(),
171 truncate,
172 client,
173 })
174 }
175
176 pub fn kind(&self) -> &'static str {
177 "nvidia"
178 }
179
180 pub fn model(&self) -> &str {
181 &self.model
182 }
183
184 pub fn embed_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
185 self.embed_batch_with_retry(&self.document_input_type, texts, 3)
186 }
187
188 pub fn embed_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
189 self.embed_batch_with_retry(&self.query_input_type, texts, 3)
190 }
191
192 pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
193 let mut out = self.embed_passages(&[text])?;
194 out.pop()
195 .ok_or_else(|| anyhow!("NVIDIA embeddings API returned no embedding output"))
196 }
197
198 pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
199 let mut out = self.embed_queries(&[text])?;
200 out.pop()
201 .ok_or_else(|| anyhow!("NVIDIA embeddings API returned no embedding output"))
202 }
203
204 fn embeddings_url(&self) -> String {
205 format!("{}/v1/embeddings", self.base_url)
206 }
207
208 fn embed_batch_with_retry(
209 &self,
210 input_type: &str,
211 texts: &[&str],
212 max_retries: usize,
213 ) -> Result<Vec<Vec<f32>>> {
214 if texts.is_empty() {
215 return Ok(Vec::new());
216 }
217
218 let mut all_embeddings = Vec::with_capacity(texts.len());
219 for chunk in texts.chunks(self.batch_size) {
220 let embeddings = self.call_nvidia_with_retry(input_type, chunk, max_retries)?;
221 all_embeddings.extend(embeddings);
222 }
223
224 Ok(all_embeddings)
225 }
226
227 fn call_nvidia_with_retry(
228 &self,
229 input_type: &str,
230 texts: &[&str],
231 max_retries: usize,
232 ) -> Result<Vec<Vec<f32>>> {
233 let url = self.embeddings_url();
234 let request = NvidiaEmbeddingRequest {
235 input: texts.to_vec(),
236 model: &self.model,
237 input_type,
238 encoding_format: &self.encoding_format,
239 truncate: &self.truncate,
240 };
241
242 let mut attempt = 0usize;
243 let mut backoff = Duration::from_millis(500);
244 let max_backoff = Duration::from_secs(8);
245
246 loop {
247 attempt += 1;
248 let response = self
249 .client
250 .post(&url)
251 .bearer_auth(&self.api_key)
252 .json(&request)
253 .send();
254
255 match response {
256 Ok(resp) => {
257 let status = resp.status();
258 let body = resp.text().unwrap_or_default();
259
260 if status.is_success() {
261 let mut decoded: NvidiaEmbeddingResponse =
262 serde_json::from_str(&body).map_err(|err| {
263 anyhow!("failed to decode NVIDIA embeddings response: {err}")
264 })?;
265 decoded.data.sort_by_key(|item| item.index);
266
267 if decoded.data.len() != texts.len() {
268 bail!(
269 "NVIDIA embeddings API returned {} embeddings for {} inputs",
270 decoded.data.len(),
271 texts.len()
272 );
273 }
274
275 let embeddings: Vec<Vec<f32>> = decoded
276 .data
277 .into_iter()
278 .map(|item| item.embedding)
279 .collect();
280
281 if embeddings.iter().any(|emb| emb.is_empty()) {
282 bail!("NVIDIA embeddings API returned an empty embedding vector");
283 }
284
285 return Ok(embeddings);
286 }
287
288 let retryable =
289 status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error();
290 if retryable && attempt <= max_retries {
291 warn!(
292 "NVIDIA embeddings API returned {status} (attempt {attempt}/{max_attempts}); retrying in {backoff:?}: {body}",
293 max_attempts = max_retries + 1
294 );
295 std::thread::sleep(backoff);
296 backoff = (backoff * 2).min(max_backoff);
297 continue;
298 }
299
300 if status == StatusCode::BAD_REQUEST {
301 if let Some(message) = extract_error_message(&body) {
302 if let Some((actual, max)) = parse_token_limit_error(&message) {
303 let mut factor =
304 (max as f64 / actual.max(1) as f64).clamp(0.05, 0.95) * 0.95;
305 warn!(
306 "NVIDIA embeddings input exceeds token limit ({actual} > {max}); retrying with automatic truncation"
307 );
308
309 for _ in 0..3 {
310 let owned: Vec<String> = texts
311 .iter()
312 .map(|text| {
313 let target =
314 ((text.len() as f64) * factor).floor() as usize;
315 truncate_to_chars(text, target.max(256))
316 })
317 .collect();
318 let refs: Vec<&str> =
319 owned.iter().map(|text| text.as_str()).collect();
320 let request = NvidiaEmbeddingRequest {
321 input: refs,
322 model: &self.model,
323 input_type,
324 encoding_format: &self.encoding_format,
325 truncate: &self.truncate,
326 };
327
328 let resp = self
329 .client
330 .post(&url)
331 .bearer_auth(&self.api_key)
332 .json(&request)
333 .send()
334 .map_err(|err| {
335 anyhow!("NVIDIA embeddings request failed: {err}")
336 })?;
337
338 let status = resp.status();
339 let body = resp.text().unwrap_or_default();
340 if status.is_success() {
341 let mut decoded: NvidiaEmbeddingResponse =
342 serde_json::from_str(&body).map_err(|err| {
343 anyhow!(
344 "failed to decode NVIDIA embeddings response: {err}"
345 )
346 })?;
347 decoded.data.sort_by_key(|item| item.index);
348
349 if decoded.data.len() != texts.len() {
350 bail!(
351 "NVIDIA embeddings API returned {} embeddings for {} inputs",
352 decoded.data.len(),
353 texts.len()
354 );
355 }
356
357 let embeddings: Vec<Vec<f32>> = decoded
358 .data
359 .into_iter()
360 .map(|item| item.embedding)
361 .collect();
362
363 if embeddings.iter().any(|emb| emb.is_empty()) {
364 bail!(
365 "NVIDIA embeddings API returned an empty embedding vector"
366 );
367 }
368
369 return Ok(embeddings);
370 }
371
372 if status == StatusCode::BAD_REQUEST {
373 if let Some(message) = extract_error_message(&body) {
374 if parse_token_limit_error(&message).is_some() {
375 factor = (factor * 0.85).clamp(0.02, 0.8);
376 continue;
377 }
378 }
379 }
380
381 bail!(
382 "NVIDIA embeddings API returned error status {status}: {body}"
383 );
384 }
385
386 bail!(
387 "NVIDIA embeddings input exceeds token limit and could not be truncated automatically.\n\
388 Try enabling smaller chunks (or disable contextual prefixes) and retry. You can also set NVIDIA_EMBEDDING_TRUNCATE=END if your model supports server-side truncation."
389 );
390 }
391 }
392 }
393
394 bail!("NVIDIA embeddings API returned error status {status}: {body}");
395 }
396 Err(err) => {
397 let retryable = err.is_timeout() || err.is_connect();
398 if retryable && attempt <= max_retries {
399 warn!(
400 "NVIDIA embeddings request failed (attempt {attempt}/{max_attempts}); retrying in {backoff:?}: {err}",
401 max_attempts = max_retries + 1
402 );
403 std::thread::sleep(backoff);
404 backoff = (backoff * 2).min(max_backoff);
405 continue;
406 }
407
408 bail!("NVIDIA embeddings request failed: {err}");
409 }
410 }
411 }
412 }
413}