use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::debug;
use crate::error::{LlmError, Result};
use crate::model_config::{
ModelCapabilities, ModelCard, ModelType, ProviderConfig, ProviderType as ConfigProviderType,
};
use crate::providers::openai_compatible::OpenAICompatibleProvider;
use crate::traits::{
ChatMessage, ChatRole, CompletionOptions, EmbeddingProvider, LLMProvider, LLMResponse,
StreamChunk, ToolChoice, ToolDefinition,
};
const NVIDIA_BASE_URL: &str = "https://integrate.api.nvidia.com/v1";
const NVIDIA_DEFAULT_MODEL: &str = "nvidia/llama-3.3-nemotron-super-49b-v1";
const NVIDIA_PROVIDER_NAME: &str = "nvidia";
const NVIDIA_TIMEOUT_SECS: u64 = 300;
const NVIDIA_NVCF_STATUS_URL: &str = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/status";
const NVIDIA_REQID_HEADER: &str = "NVCF-REQID";
const NVIDIA_POLL_INTERVAL_MS: u64 = 500;
const NVIDIA_MAX_POLL_ATTEMPTS: u32 = 600;
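// Static chat-model catalog. Tuple fields, in order:
// (model_id, display_name, context_length, supports_vision, supports_thinking, is_free).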
const NVIDIA_CHAT_MODELS: &[(&str, &str, usize, bool, bool, bool)] = &[
(
"nvidia/llama-3.3-nemotron-super-49b-v1",
"Nemotron Super 49B v1 (128K, thinking, free)",
131_072,
false,
true,
true,
),
(
"nvidia/llama-3.3-nemotron-super-49b-v1.5",
"Nemotron Super 49B v1.5 (128K, thinking, free)",
131_072,
false,
true,
true,
),
(
"nvidia/llama-3.1-nemotron-ultra-253b-v1",
"Nemotron Ultra 253B v1 (128K, thinking)",
131_072,
false,
true,
false,
),
(
"nvidia/llama-3.1-nemotron-nano-8b-v1",
"Nemotron Nano 8B v1 (128K, thinking, free)",
131_072,
false,
true,
true,
),
(
"nvidia/llama-3.1-nemotron-nano-4b-v1_1",
"Nemotron Nano 4B v1.1 (128K, thinking, free)",
131_072,
false,
true,
true,
),
(
"nvidia/nemotron-3-nano-30b-a3b",
"Nemotron 3 Nano 30B-A3B MoE (1M, thinking, free)",
1_000_000,
false,
true,
true,
),
(
"nvidia/nemotron-3-super-120b-a12b",
"Nemotron 3 Super 120B-A12B MoE (1M, thinking, free)",
1_000_000,
false,
true,
true,
),
(
"nvidia/nemotron-mini-4b-instruct",
"Nemotron Mini 4B Instruct (4K)",
4_096,
false,
false,
false,
),
(
"nvidia/nvidia-nemotron-nano-9b-v2",
"Nemotron Nano 9B v2 (128K)",
131_072,
false,
false,
false,
),
(
"deepseek-ai/deepseek-v4-flash",
"DeepSeek V4 Flash (64K, thinking via reasoning_effort, free)",
65_536,
false,
true,
true,
),
(
"deepseek-ai/deepseek-v4-pro",
"DeepSeek V4 Pro (64K, thinking via reasoning_effort)",
65_536,
false,
true,
false,
),
(
"deepseek-ai/deepseek-v3.2",
"DeepSeek V3.2 (128K)",
131_072,
false,
false,
false,
),
(
"deepseek-ai/deepseek-v3.1-terminus",
"DeepSeek V3.1 Terminus (128K)",
131_072,
false,
false,
false,
),
(
"deepseek-ai/deepseek-r1",
"DeepSeek R1 (128K, thinking)",
131_072,
false,
true,
false,
),
(
"meta/llama-3.3-70b-instruct",
"Llama 3.3 70B Instruct (128K, free)",
131_072,
false,
false,
true,
),
(
"meta/llama-3.1-405b-instruct",
"Llama 3.1 405B Instruct (128K)",
131_072,
false,
false,
false,
),
(
"meta/llama-3.1-70b-instruct",
"Llama 3.1 70B Instruct (128K, free)",
131_072,
false,
false,
true,
),
(
"meta/llama-3.1-8b-instruct",
"Llama 3.1 8B Instruct (128K, free)",
131_072,
false,
false,
true,
),
(
"meta/llama-3.2-3b-instruct",
"Llama 3.2 3B Instruct (128K, free)",
131_072,
false,
false,
true,
),
(
"meta/llama-3.2-1b-instruct",
"Llama 3.2 1B Instruct (128K, free)",
131_072,
false,
false,
true,
),
(
"meta/llama-4-maverick-17b-128e-instruct",
"Llama 4 Maverick 17B 128E (1M, vision, free)",
1_000_000,
true,
false,
true,
),
(
"meta/llama-3.2-11b-vision-instruct",
"Llama 3.2 11B Vision (128K, vision)",
131_072,
true,
false,
false,
),
(
"meta/llama-3.2-90b-vision-instruct",
"Llama 3.2 90B Vision (128K, vision)",
131_072,
true,
false,
false,
),
(
"microsoft/phi-4-mini-instruct",
"Phi-4 Mini Instruct (128K, free)",
131_072,
false,
false,
true,
),
(
"microsoft/phi-4-mini-flash-reasoning",
"Phi-4 Mini Flash Reasoning (128K, thinking, free)",
131_072,
false,
true,
true,
),
(
"microsoft/phi-4-multimodal-instruct",
"Phi-4 Multimodal Instruct (128K, vision)",
131_072,
true,
false,
false,
),
(
"microsoft/phi-3.5-mini",
"Phi-3.5 Mini Instruct (128K)",
131_072,
false,
false,
false,
),
(
"microsoft/phi-3.5-vision-instruct",
"Phi-3.5 Vision Instruct (128K, vision)",
131_072,
true,
false,
false,
),
(
"microsoft/phi-3-mini-128k-instruct",
"Phi-3 Mini 128K Instruct (128K)",
131_072,
false,
false,
false,
),
(
"microsoft/phi-3-mini-4k-instruct",
"Phi-3 Mini 4K Instruct (4K)",
4_096,
false,
false,
false,
),
(
"microsoft/phi-3-small-128k-instruct",
"Phi-3 Small 128K Instruct (128K)",
131_072,
false,
false,
false,
),
(
"microsoft/phi-3-medium-128k-instruct",
"Phi-3 Medium 128K Instruct (128K)",
131_072,
false,
false,
false,
),
(
"mistralai/mistral-nemotron",
"Mistral Nemotron (128K)",
131_072,
false,
false,
false,
),
(
"mistralai/mistral-small-24b-instruct",
"Mistral Small 24B Instruct (128K)",
131_072,
false,
false,
false,
),
(
"mistralai/mistral-large-2-instruct",
"Mistral Large 2 Instruct (128K)",
131_072,
false,
false,
false,
),
(
"mistralai/mistral-7b-instruct-v0.3",
"Mistral 7B Instruct v0.3 (32K, free)",
32_768,
false,
false,
true,
),
(
"mistralai/mixtral-8x7b-instruct",
"Mixtral 8x7B Instruct (32K)",
32_768,
false,
false,
false,
),
(
"mistralai/mixtral-8x22b-instruct",
"Mixtral 8x22B Instruct (65K)",
65_536,
false,
false,
false,
),
(
"mistralai/magistral-small-2506",
"Magistral Small 2506 (128K, thinking)",
131_072,
false,
true,
false,
),
(
"qwen/qwen2.5-7b-instruct",
"Qwen 2.5 7B Instruct (128K, free)",
131_072,
false,
false,
true,
),
(
"qwen/qwen2.5-coder-7b-instruct",
"Qwen 2.5 Coder 7B Instruct (128K)",
131_072,
false,
false,
false,
),
(
"qwen/qwen2.5-coder-32b-instruct",
"Qwen 2.5 Coder 32B Instruct (128K)",
131_072,
false,
false,
false,
),
(
"qwen/qwq-32b",
"QwQ 32B (128K, thinking, free)",
131_072,
false,
true,
true,
),
(
"qwen/qwen3-coder-480b-a35b-instruct",
"Qwen3 Coder 480B MoE (128K)",
131_072,
false,
false,
false,
),
(
"qwen/qwen3-next-80b-a3b-instruct",
"Qwen3 Next 80B-A3B Instruct (128K)",
131_072,
false,
false,
false,
),
(
"qwen/qwen3-next-80b-a3b-thinking",
"Qwen3 Next 80B-A3B Thinking (128K, thinking)",
131_072,
false,
true,
false,
),
(
"moonshotai/kimi-k2-instruct",
"Kimi K2 Instruct (128K, free)",
131_072,
false,
false,
true,
),
(
"moonshotai/kimi-k2-thinking",
"Kimi K2 Thinking (128K, thinking)",
131_072,
false,
true,
false,
),
(
"moonshotai/kimi-k2-instruct-0905",
"Kimi K2 Instruct 0905 (128K)",
131_072,
false,
false,
false,
),
(
"minimaxai/minimax-m2.5",
"MiniMax M2.5 (128K)",
131_072,
false,
false,
false,
),
(
"google/gemma-2-9b-it",
"Gemma 2 9B IT (8K, free)",
8_192,
false,
false,
true,
),
(
"google/gemma-2-27b-it",
"Gemma 2 27B IT (8K)",
8_192,
false,
false,
false,
),
(
"marin/marin-8b-instruct",
"Marin 8B Instruct (128K, free)",
131_072,
false,
false,
true,
),
(
"databricks/dbrx-instruct",
"DBRX Instruct (32K)",
32_768,
false,
false,
false,
),
(
"snowflake/arctic",
"Snowflake Arctic (4K)",
4_096,
false,
false,
false,
),
(
"upstage/solar-10.7b-instruct",
"Solar 10.7B Instruct (4K, free)",
4_096,
false,
false,
true,
),
(
"z-ai/glm4.7",
"GLM-4.7 (128K)",
131_072,
false,
false,
false,
),
(
"openai/gpt-oss-20b",
"OpenAI OSS 20B (128K)",
131_072,
false,
false,
false,
),
(
"openai/gpt-oss-120b",
"OpenAI OSS 120B (128K)",
131_072,
false,
false,
false,
),
(
"bytedance/seed-oss-36b-instruct",
"Seed OSS 36B Instruct (128K)",
131_072,
false,
false,
false,
),
(
"stepfun-ai/step-3-5-flash",
"Step-3.5 Flash (128K)",
131_072,
false,
false,
false,
),
];
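// Models currently usable on the free tier. This list must stay a subset of
// NVIDIA_CHAT_MODELS; `test_catalog_integrity` below enforces that invariant.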
const NVIDIA_FREE_MODELS: &[&str] = &[
"nvidia/llama-3.3-nemotron-super-49b-v1",
"nvidia/llama-3.3-nemotron-super-49b-v1.5",
"nvidia/llama-3.1-nemotron-nano-8b-v1",
"nvidia/llama-3.1-nemotron-nano-4b-v1_1",
"nvidia/nemotron-3-nano-30b-a3b",
"nvidia/nemotron-3-super-120b-a12b",
"deepseek-ai/deepseek-v4-flash",
"meta/llama-3.3-70b-instruct",
"meta/llama-3.1-70b-instruct",
"meta/llama-3.1-8b-instruct",
"meta/llama-3.2-3b-instruct",
"meta/llama-3.2-1b-instruct",
"meta/llama-4-maverick-17b-128e-instruct",
"microsoft/phi-4-mini-instruct",
"microsoft/phi-4-mini-flash-reasoning",
"mistralai/mistral-7b-instruct-v0.3",
"qwen/qwen2.5-7b-instruct",
"qwen/qwq-32b",
"moonshotai/kimi-k2-instruct",
"google/gemma-2-9b-it",
"marin/marin-8b-instruct",
"upstage/solar-10.7b-instruct",
];
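// Message content is either a plain JSON string or, for vision inputs, an
// array of `{"type": "text" | "image_url", ...}` parts (see `build_messages`).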
type NvidiaMsgContent = serde_json::Value;
#[derive(Debug, Serialize)]
struct NvidiaMessageReq {
role: String,
content: NvidiaMsgContent,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<NvidiaToolCallReq>>,
}
#[derive(Debug, Serialize)]
struct NvidiaToolCallReq {
id: String,
#[serde(rename = "type")]
call_type: String,
function: NvidiaFnCallReq,
}
#[derive(Debug, Serialize)]
struct NvidiaFnCallReq {
name: String,
arguments: String,
}
#[derive(Debug, Serialize)]
struct NvidiaChatReq<'a> {
model: &'a str,
messages: Vec<NvidiaMessageReq>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
presence_penalty: Option<f32>,
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ToolDefinition>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
response_format: Option<NvidiaRespFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_effort: Option<String>,
}
#[derive(Debug, Serialize)]
struct NvidiaRespFormat {
#[serde(rename = "type")]
format_type: String,
}
#[derive(Debug, Deserialize)]
struct NvidiaChatCompletion {
#[serde(default)]
id: Option<String>,
#[serde(default)]
model: Option<String>,
choices: Vec<NvidiaCompletionChoice>,
#[serde(default)]
usage: Option<NvidiaCompletionUsage>,
}
#[derive(Debug, Deserialize)]
struct NvidiaCompletionChoice {
message: NvidiaCompletionMessage,
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct NvidiaCompletionMessage {
#[serde(default)]
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
#[serde(default)]
tool_calls: Vec<NvidiaToolCallResp>,
}
#[derive(Debug, Deserialize)]
struct NvidiaToolCallResp {
id: String,
#[serde(rename = "type", default)]
call_type: String,
function: NvidiaFnCallResp,
}
#[derive(Debug, Deserialize)]
struct NvidiaFnCallResp {
name: String,
arguments: String,
}
#[derive(Debug, Deserialize, Default)]
struct NvidiaCompletionUsage {
#[serde(default)]
prompt_tokens: usize,
#[serde(default)]
completion_tokens: usize,
}
#[derive(Debug, Deserialize)]
pub struct NvidiaModelsResponse {
pub object: String,
pub data: Vec<NvidiaModelInfo>,
}
#[derive(Debug, Deserialize)]
pub struct NvidiaModelInfo {
pub id: String,
pub object: String,
#[serde(default)]
pub created: i64,
#[serde(default)]
pub owned_by: String,
#[serde(skip)]
pub is_free: bool,
}
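/// Chat provider for NVIDIA NIM (`integrate.api.nvidia.com`).
///
/// Streaming requests are delegated to the wrapped [`OpenAICompatibleProvider`];
/// non-streaming requests use this provider's own HTTP client so that NVCF
/// HTTP 202 (async-queued inference) responses can be polled to completion.
///
/// A minimal usage sketch (illustrative only: import paths depend on the
/// consuming crate, and a Tokio runtime plus a valid `NVIDIA_API_KEY` are
/// assumed):
///
/// ```ignore
/// let provider = NvidiaProvider::from_env()?
///     .with_model("meta/llama-3.1-8b-instruct");
/// // `complete` comes from the `LLMProvider` trait.
/// let response = provider.complete("Say hello in one sentence.").await?;
/// ```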
#[derive(Debug)]
pub struct NvidiaProvider {
inner: OpenAICompatibleProvider,
model: String,
api_key: String,
base_url: String,
client: Client,
}
impl NvidiaProvider {
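/// Builds a provider from the environment: `NVIDIA_API_KEY` is required,
/// while `NVIDIA_MODEL` and `NVIDIA_BASE_URL` override the defaults.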
pub fn from_env() -> Result<Self> {
let api_key = std::env::var("NVIDIA_API_KEY").map_err(|_| {
LlmError::ConfigError(
"NVIDIA_API_KEY environment variable not set. \
Get your free API key from https://build.nvidia.com"
.to_string(),
)
})?;
if api_key.is_empty() {
return Err(LlmError::ConfigError(
"NVIDIA_API_KEY is empty. Please set a valid API key from https://build.nvidia.com"
.to_string(),
));
}
let model =
std::env::var("NVIDIA_MODEL").unwrap_or_else(|_| NVIDIA_DEFAULT_MODEL.to_string());
let base_url =
std::env::var("NVIDIA_BASE_URL").unwrap_or_else(|_| NVIDIA_BASE_URL.to_string());
Self::new(api_key, model, Some(base_url))
}
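/// Builds a provider from a [`ProviderConfig`]. The config must name the
/// API-key environment variable via `api_key_env`.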
pub fn from_config(config: &ProviderConfig) -> Result<Self> {
let api_key = if let Some(env_var) = &config.api_key_env {
std::env::var(env_var).map_err(|_| {
LlmError::ConfigError(format!(
"API key environment variable '{}' not set for NVIDIA provider.",
env_var
))
})?
} else {
return Err(LlmError::ConfigError(
"NVIDIA provider requires api_key_env to be set.".to_string(),
));
};
if api_key.is_empty() {
return Err(LlmError::ConfigError(
"NVIDIA API key is empty.".to_string(),
));
}
let model = config
.default_llm_model
.clone()
.unwrap_or_else(|| NVIDIA_DEFAULT_MODEL.to_string());
let base_url = config
.base_url
.clone()
.unwrap_or_else(|| NVIDIA_BASE_URL.to_string());
Self::new(api_key, model, Some(base_url))
}
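/// Core constructor. Also builds the dedicated HTTP client used for
/// non-streaming requests and NVCF status polling (timeout:
/// `NVIDIA_TIMEOUT_SECS`).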
pub fn new(api_key: String, model: String, base_url: Option<String>) -> Result<Self> {
let base_url = base_url.unwrap_or_else(|| NVIDIA_BASE_URL.to_string());
let config = Self::build_provider_config(&api_key, &model, &base_url);
let inner = OpenAICompatibleProvider::from_config(config)?;
let client = Client::builder()
.timeout(Duration::from_secs(NVIDIA_TIMEOUT_SECS))
.build()
.map_err(|e| LlmError::ConfigError(format!("Failed to build HTTP client: {}", e)))?;
debug!(
provider = NVIDIA_PROVIDER_NAME,
model = %model,
base_url = %base_url,
"Created NVIDIA NIM provider"
);
Ok(Self {
inner,
model,
api_key,
base_url,
client,
})
}
pub fn with_model(mut self, model: &str) -> Self {
self.model = model.to_string();
self.inner = self.inner.with_model(model);
self
}
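// Catalog lookups. Unknown models fall back to conservative defaults:
// a 32K context window, name-based vision heuristics, and no thinking support.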
pub fn context_length(model: &str) -> usize {
NVIDIA_CHAT_MODELS
.iter()
.find(|(id, _, _, _, _, _)| *id == model)
.map(|(_, _, ctx, _, _, _)| *ctx)
.unwrap_or(32_768)
}
pub fn supports_vision(model: &str) -> bool {
NVIDIA_CHAT_MODELS
.iter()
.find(|(id, _, _, _, _, _)| *id == model)
.map(|(_, _, _, vision, _, _)| *vision)
.unwrap_or_else(|| {
model.contains("vision")
|| model.contains("vl")
|| model.contains("multimodal")
|| model.contains("maverick")
})
}
pub fn supports_thinking(model: &str) -> bool {
NVIDIA_CHAT_MODELS
.iter()
.find(|(id, _, _, _, _, _)| *id == model)
.map(|(_, _, _, _, thinking, _)| *thinking)
.unwrap_or(false)
}
pub fn is_free_model(model: &str) -> bool {
NVIDIA_FREE_MODELS.contains(&model)
}
pub fn available_models() -> Vec<(&'static str, &'static str, usize, bool, bool, bool)> {
NVIDIA_CHAT_MODELS.to_vec()
}
pub fn free_models() -> Vec<(&'static str, &'static str, usize)> {
NVIDIA_CHAT_MODELS
.iter()
.filter(|(_, _, _, _, _, free)| *free)
.map(|(id, name, ctx, _, _, _)| (*id, *name, *ctx))
.collect()
}
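/// Fetches the live model list from `{base_url}/models` and annotates each
/// entry with `is_free` from the local catalog (the field is not part of the
/// wire response; note the `#[serde(skip)]` on `NvidiaModelInfo::is_free`).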
pub async fn list_models(&self) -> Result<NvidiaModelsResponse> {
let url = format!("{}/models", self.base_url.trim_end_matches('/'));
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Accept", "application/json")
.send()
.await
.map_err(|e| LlmError::NetworkError(format!("Failed to list NVIDIA models: {}", e)))?;
let status = response.status();
if status == reqwest::StatusCode::ACCEPTED {
let body = response.text().await.unwrap_or_default();
return Err(LlmError::ApiError(format!(
"NVIDIA returned 202 Accepted for model listing (unexpected): {}",
body
)));
}
let body = response.text().await.map_err(|e| {
LlmError::NetworkError(format!("Failed to read NVIDIA models response: {}", e))
})?;
if !status.is_success() {
return Err(LlmError::ApiError(format!(
"NVIDIA models list failed ({status}): {body}"
)));
}
let mut resp: NvidiaModelsResponse = serde_json::from_str(&body)
.map_err(|e| LlmError::ApiError(format!("Failed to parse models response: {e}")))?;
for model in &mut resp.data {
model.is_free = NVIDIA_FREE_MODELS.contains(&model.id.as_str());
}
Ok(resp)
}
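// Converts internal chat messages to the OpenAI-style wire format. Content
// stays a plain string unless the message carries images, in which case it
// becomes a multipart array of text and image_url parts.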
fn build_messages(messages: &[ChatMessage]) -> Vec<NvidiaMessageReq> {
messages
.iter()
.map(|msg| {
let role = match msg.role {
ChatRole::System => "system",
ChatRole::User => "user",
ChatRole::Assistant => "assistant",
ChatRole::Tool | ChatRole::Function => "tool",
};
let content: serde_json::Value =
if msg.images.as_ref().is_some_and(|imgs| !imgs.is_empty()) {
let mut parts: Vec<serde_json::Value> = Vec::new();
if !msg.content.is_empty() {
parts.push(serde_json::json!({"type": "text", "text": &msg.content}));
}
if let Some(ref images) = msg.images {
for img in images {
let mut img_obj =
serde_json::json!({"type": "image_url", "image_url": {"url": img.to_data_uri()}});
if let Some(ref detail) = img.detail {
img_obj["image_url"]["detail"] =
serde_json::Value::String(detail.clone());
}
parts.push(img_obj);
}
}
serde_json::Value::Array(parts)
} else {
serde_json::Value::String(msg.content.clone())
};
let tool_calls = msg.tool_calls.as_ref().map(|tcs| {
tcs.iter()
.map(|tc| NvidiaToolCallReq {
id: tc.id.clone(),
call_type: tc.call_type.clone(),
function: NvidiaFnCallReq {
name: tc.function.name.clone(),
arguments: tc.function.arguments.clone(),
},
})
.collect::<Vec<_>>()
});
NvidiaMessageReq {
role: role.to_string(),
content,
name: msg.name.clone(),
tool_call_id: msg.tool_call_id.clone(),
tool_calls,
}
})
.collect()
}
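// Maps ToolChoice onto the OpenAI-style `tool_choice` field: simple modes
// serialize as bare strings ("auto", "none", "required"); a named function
// serializes as a {"type": "function", ...} object.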
fn tool_choice_to_json(choice: &ToolChoice) -> serde_json::Value {
match choice {
ToolChoice::Auto(s) | ToolChoice::Required(s) => serde_json::json!(s),
ToolChoice::Function { function, .. } => serde_json::json!({
"type": "function",
"function": {"name": function.name}
}),
}
}
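// Shared non-streaming path: build the request, POST it, resolve a possible
// NVCF 202 response, then parse the completion.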
async fn execute_non_streaming_chat(
&self,
messages: &[ChatMessage],
options: Option<&CompletionOptions>,
tools: Option<(&[ToolDefinition], Option<ToolChoice>)>,
) -> Result<LLMResponse> {
let opts = options.cloned().unwrap_or_default();
let use_json_mode = opts
.response_format
.as_ref()
.map(|f| f == "json_object" || f == "json")
.unwrap_or(false);
let (api_tools, api_tool_choice) = match tools {
Some((defs, choice)) if !defs.is_empty() => {
let tc = choice.as_ref().map(Self::tool_choice_to_json);
(Some(defs.to_vec()), tc)
}
_ => (None, None),
};
let request = NvidiaChatReq {
model: &self.model,
messages: Self::build_messages(messages),
temperature: opts.temperature,
top_p: opts.top_p,
max_tokens: opts.max_tokens,
stop: opts.stop.clone(),
frequency_penalty: opts.frequency_penalty,
presence_penalty: opts.presence_penalty,
stream: false,
tools: api_tools,
tool_choice: api_tool_choice,
response_format: if use_json_mode {
Some(NvidiaRespFormat {
format_type: "json_object".to_string(),
})
} else {
None
},
reasoning_effort: opts.reasoning_effort.clone(),
};
let chat_url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
debug!(
provider = NVIDIA_PROVIDER_NAME,
model = %self.model,
url = %chat_url,
"Sending NVIDIA non-streaming chat request"
);
let response = self
.client
.post(&chat_url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.json(&request)
.send()
.await
.map_err(|e| LlmError::NetworkError(format!("NVIDIA chat request failed: {e}")))?;
let final_response = self.resolve_202(response).await?;
self.parse_completion_response(final_response).await
}
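// NVCF long-running inference: when a request is queued, the API answers
// HTTP 202 with an `NVCF-REQID` header. We poll the status endpoint at a
// fixed interval until it returns anything other than 202, then hand that
// response back to the normal parsing path.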
async fn resolve_202(&self, response: reqwest::Response) -> Result<reqwest::Response> {
if response.status() != reqwest::StatusCode::ACCEPTED {
return Ok(response);
}
let nvcf_reqid = response
.headers()
.get(NVIDIA_REQID_HEADER)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.ok_or_else(|| {
LlmError::ApiError(
"NVIDIA returned HTTP 202 (async-queued inference) but the required \
'NVCF-REQID' response header is missing — cannot poll for the result."
.to_string(),
)
})?;
debug!(
provider = NVIDIA_PROVIDER_NAME,
nvcf_reqid = %nvcf_reqid,
poll_interval_ms = NVIDIA_POLL_INTERVAL_MS,
max_attempts = NVIDIA_MAX_POLL_ATTEMPTS,
"NVIDIA returned HTTP 202 — starting async polling"
);
let poll_url = format!("{}/{}", NVIDIA_NVCF_STATUS_URL, nvcf_reqid);
for attempt in 0..NVIDIA_MAX_POLL_ATTEMPTS {
tokio::time::sleep(Duration::from_millis(NVIDIA_POLL_INTERVAL_MS)).await;
let poll_resp = self
.client
.get(&poll_url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Accept", "application/json")
.send()
.await
.map_err(|e| {
LlmError::NetworkError(format!(
"NVIDIA async poll #{attempt} failed for NVCF-REQID={nvcf_reqid}: {e}"
))
})?;
if poll_resp.status() == reqwest::StatusCode::ACCEPTED {
debug!(
attempt,
nvcf_reqid = %nvcf_reqid,
"NVIDIA async inference still pending"
);
continue;
}
debug!(
attempt,
nvcf_reqid = %nvcf_reqid,
status = poll_resp.status().as_u16(),
"NVIDIA async inference resolved"
);
return Ok(poll_resp);
}
Err(LlmError::ApiError(format!(
"NVIDIA async inference timed out after {NVIDIA_MAX_POLL_ATTEMPTS} poll attempts \
({NVIDIA_POLL_INTERVAL_MS} ms interval, {:.0} s total) for NVCF-REQID='{nvcf_reqid}'",
f64::from(NVIDIA_MAX_POLL_ATTEMPTS) * NVIDIA_POLL_INTERVAL_MS as f64 / 1000.0
)))
}
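// Parses a chat-completion body, surfacing the structured `/error/message`
// field when the API provides one.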
async fn parse_completion_response(&self, response: reqwest::Response) -> Result<LLMResponse> {
let status = response.status();
let body = response.text().await.map_err(|e| {
LlmError::NetworkError(format!("Failed to read NVIDIA response body: {e}"))
})?;
if !status.is_success() {
if let Ok(val) = serde_json::from_str::<serde_json::Value>(&body) {
if let Some(msg) = val.pointer("/error/message").and_then(|v| v.as_str()) {
return Err(LlmError::ApiError(format!(
"NVIDIA API error ({status}): {msg}"
)));
}
}
// Truncate on a char boundary to avoid panicking on multi-byte UTF-8 bodies.
let preview: String = body.chars().take(1000).collect();
return Err(LlmError::ApiError(format!(
"NVIDIA API error ({status}): {preview}"
)));
}
let completion: NvidiaChatCompletion = serde_json::from_str(&body).map_err(|e| {
// Char-boundary-safe preview of the unparseable body.
let preview: String = body.chars().take(500).collect();
LlmError::ApiError(format!(
"Failed to parse NVIDIA completion response: {e} | body preview: {preview}"
))
})?;
let choice = completion
.choices
.into_iter()
.next()
.ok_or_else(|| LlmError::ApiError("NVIDIA response has no choices".to_string()))?;
let content = choice.message.content.unwrap_or_default();
let thinking = choice.message.reasoning_content.filter(|s| !s.is_empty());
let (prompt_tokens, completion_tokens) = completion
.usage
.map(|u| (u.prompt_tokens, u.completion_tokens))
.unwrap_or((0, 0));
let model_name = completion.model.unwrap_or_else(|| self.model.clone());
let finish_reason = choice.finish_reason.unwrap_or_else(|| "stop".to_string());
let tool_calls: Vec<crate::traits::ToolCall> = choice
.message
.tool_calls
.into_iter()
.map(|tc| crate::traits::ToolCall {
id: tc.id,
call_type: if tc.call_type.is_empty() {
"function".to_string()
} else {
tc.call_type
},
function: crate::traits::FunctionCall {
name: tc.function.name,
arguments: tc.function.arguments,
},
thought_signature: None,
})
.collect();
let mut resp = LLMResponse::new(content, &model_name)
.with_usage(prompt_tokens, completion_tokens)
.with_finish_reason(finish_reason);
if let Some(id) = completion.id {
resp = resp.with_metadata("id", serde_json::Value::String(id));
}
if let Some(thinking_content) = thinking {
resp = resp.with_thinking_content(thinking_content);
}
if !tool_calls.is_empty() {
resp.tool_calls = tool_calls;
}
Ok(resp)
}
fn build_provider_config(api_key: &str, model: &str, base_url: &str) -> ProviderConfig {
let models: Vec<ModelCard> = NVIDIA_CHAT_MODELS
.iter()
.map(|(id, display, ctx, vision, thinking, _free)| ModelCard {
name: id.to_string(),
display_name: display.to_string(),
model_type: ModelType::Llm,
capabilities: ModelCapabilities {
context_length: *ctx,
supports_function_calling: true,
supports_json_mode: true,
supports_streaming: true,
supports_system_message: true,
supports_vision: *vision,
supports_thinking: *thinking,
..Default::default()
},
..Default::default()
})
.collect();
ProviderConfig {
name: NVIDIA_PROVIDER_NAME.to_string(),
display_name: "NVIDIA NIM".to_string(),
provider_type: ConfigProviderType::OpenAICompatible,
api_key: Some(api_key.to_string()),
api_key_env: Some("NVIDIA_API_KEY".to_string()),
base_url: Some(base_url.to_string()),
base_url_env: Some("NVIDIA_BASE_URL".to_string()),
default_llm_model: Some(model.to_string()),
default_embedding_model: None,
models,
headers: std::collections::HashMap::new(),
enabled: true,
timeout_seconds: NVIDIA_TIMEOUT_SECS,
..Default::default()
}
}
}
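// Non-streaming calls use the NVCF-aware path above; streaming calls are
// delegated to the inner OpenAI-compatible provider.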
#[async_trait]
impl LLMProvider for NvidiaProvider {
fn name(&self) -> &str {
NVIDIA_PROVIDER_NAME
}
fn model(&self) -> &str {
&self.model
}
fn max_context_length(&self) -> usize {
Self::context_length(&self.model)
}
async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
self.complete_with_options(prompt, &CompletionOptions::default())
.await
}
async fn complete_with_options(
&self,
prompt: &str,
options: &CompletionOptions,
) -> Result<LLMResponse> {
let mut messages = Vec::new();
if let Some(ref sys) = options.system_prompt {
messages.push(ChatMessage::system(sys.clone()));
}
messages.push(ChatMessage::user(prompt));
self.execute_non_streaming_chat(&messages, Some(options), None)
.await
}
async fn chat(
&self,
messages: &[ChatMessage],
options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
self.execute_non_streaming_chat(messages, options, None)
.await
}
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
tools: &[ToolDefinition],
tool_choice: Option<ToolChoice>,
options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
self.execute_non_streaming_chat(messages, options, Some((tools, tool_choice)))
.await
}
async fn chat_with_tools_stream(
&self,
messages: &[ChatMessage],
tools: &[ToolDefinition],
tool_choice: Option<ToolChoice>,
options: Option<&CompletionOptions>,
) -> Result<BoxStream<'static, Result<StreamChunk>>> {
self.inner
.chat_with_tools_stream(messages, tools, tool_choice, options)
.await
}
async fn stream(&self, prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
self.inner.stream(prompt).await
}
fn supports_function_calling(&self) -> bool {
self.inner.supports_function_calling()
}
fn supports_tool_streaming(&self) -> bool {
self.inner.supports_tool_streaming()
}
}
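// Embeddings are deliberately unsupported here: `dimension` and `max_tokens`
// return 0 as placeholders and `embed` always errors with guidance.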
#[async_trait]
impl EmbeddingProvider for NvidiaProvider {
fn name(&self) -> &str {
NVIDIA_PROVIDER_NAME
}
fn model(&self) -> &str {
"none"
}
fn dimension(&self) -> usize {
0
}
fn max_tokens(&self) -> usize {
0
}
async fn embed(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
Err(LlmError::ConfigError(
"NVIDIA NIM embeddings are not supported via NvidiaProvider in this release. \
Use OpenAICompatibleProvider with base_url=https://integrate.api.nvidia.com/v1 \
and an embedding model ID such as 'nvidia/nv-embedqa-e5-v5'."
.to_string(),
))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_context_length_known_model() {
assert_eq!(
NvidiaProvider::context_length("meta/llama-3.1-8b-instruct"),
131_072
);
assert_eq!(
NvidiaProvider::context_length("nvidia/nemotron-3-nano-30b-a3b"),
1_000_000
);
assert_eq!(
NvidiaProvider::context_length("deepseek-ai/deepseek-v4-flash"),
65_536
);
}
#[test]
fn test_context_length_unknown_model_fallback() {
assert_eq!(NvidiaProvider::context_length("unknown/model-xyz"), 32_768);
}
#[test]
fn test_supports_vision() {
assert!(NvidiaProvider::supports_vision(
"meta/llama-4-maverick-17b-128e-instruct"
));
assert!(NvidiaProvider::supports_vision(
"meta/llama-3.2-11b-vision-instruct"
));
assert!(!NvidiaProvider::supports_vision(
"meta/llama-3.3-70b-instruct"
));
}
#[test]
fn test_supports_thinking() {
assert!(NvidiaProvider::supports_thinking(
"nvidia/llama-3.3-nemotron-super-49b-v1"
));
assert!(NvidiaProvider::supports_thinking(
"deepseek-ai/deepseek-v4-flash"
));
assert!(!NvidiaProvider::supports_thinking(
"meta/llama-3.3-70b-instruct"
));
}
#[test]
fn test_is_free_model() {
assert!(NvidiaProvider::is_free_model("meta/llama-3.1-8b-instruct"));
assert!(NvidiaProvider::is_free_model(
"nvidia/llama-3.3-nemotron-super-49b-v1"
));
assert!(!NvidiaProvider::is_free_model(
"deepseek-ai/deepseek-v4-pro"
));
assert!(!NvidiaProvider::is_free_model(
"nvidia/llama-3.1-nemotron-ultra-253b-v1"
));
}
#[test]
fn test_free_models_list() {
let free = NvidiaProvider::free_models();
assert!(!free.is_empty());
let ids: Vec<&str> = free.iter().map(|(id, _, _)| *id).collect();
assert!(ids.contains(&"meta/llama-3.1-8b-instruct"));
assert!(ids.contains(&"qwen/qwq-32b"));
}
#[test]
fn test_available_models_not_empty() {
let models = NvidiaProvider::available_models();
assert!(!models.is_empty());
assert!(models.len() >= 20);
}
#[test]
fn test_catalog_integrity() {
for free_id in NVIDIA_FREE_MODELS {
let found = NVIDIA_CHAT_MODELS
.iter()
.any(|(id, _, _, _, _, _)| id == free_id);
assert!(
found,
"Free model '{}' is not in NVIDIA_CHAT_MODELS catalog — add it or remove it from NVIDIA_FREE_MODELS",
free_id
);
}
}
#[test]
fn test_catalog_no_duplicate_ids() {
let mut ids: Vec<&str> = NVIDIA_CHAT_MODELS
.iter()
.map(|(id, _, _, _, _, _)| *id)
.collect();
let original_len = ids.len();
ids.sort_unstable();
ids.dedup();
assert_eq!(
ids.len(),
original_len,
"NVIDIA_CHAT_MODELS contains duplicate model IDs"
);
}
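// NOTE: the env-var tests below mutate process-wide state. The save/restore
// pattern keeps each test self-contained, but running
// `cargo test -- --test-threads=1` avoids races if other tests read
// NVIDIA_API_KEY concurrently.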
#[test]
fn test_from_env_missing_key() {
let saved = std::env::var("NVIDIA_API_KEY").ok();
std::env::remove_var("NVIDIA_API_KEY");
let result = NvidiaProvider::from_env();
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("NVIDIA_API_KEY"),
"Error should mention NVIDIA_API_KEY, got: {}",
err
);
if let Some(key) = saved {
std::env::set_var("NVIDIA_API_KEY", key);
}
}
#[test]
fn test_from_env_empty_key() {
let saved = std::env::var("NVIDIA_API_KEY").ok();
std::env::set_var("NVIDIA_API_KEY", "");
let result = NvidiaProvider::from_env();
assert!(result.is_err());
if let Some(key) = saved {
std::env::set_var("NVIDIA_API_KEY", key);
} else {
std::env::remove_var("NVIDIA_API_KEY");
}
}
#[test]
fn test_new_creates_provider() {
let provider = NvidiaProvider::new(
"nvapi-test-key".to_string(),
"meta/llama-3.3-70b-instruct".to_string(),
None,
);
assert!(provider.is_ok());
let p = provider.unwrap();
assert_eq!(LLMProvider::name(&p), "nvidia");
assert_eq!(LLMProvider::model(&p), "meta/llama-3.3-70b-instruct");
}
#[test]
fn test_with_model_changes_model() {
let provider = NvidiaProvider::new(
"nvapi-test-key".to_string(),
NVIDIA_DEFAULT_MODEL.to_string(),
None,
)
.unwrap();
let provider2 = provider.with_model("deepseek-ai/deepseek-v4-flash");
assert_eq!(
LLMProvider::model(&provider2),
"deepseek-ai/deepseek-v4-flash"
);
}
#[test]
fn test_max_context_length() {
let provider = NvidiaProvider::new(
"nvapi-test-key".to_string(),
"nvidia/nemotron-3-nano-30b-a3b".to_string(),
None,
)
.unwrap();
assert_eq!(provider.max_context_length(), 1_000_000);
}
#[test]
fn test_provider_name() {
let provider = NvidiaProvider::new(
"nvapi-test-key".to_string(),
NVIDIA_DEFAULT_MODEL.to_string(),
None,
)
.unwrap();
assert_eq!(LLMProvider::name(&provider), "nvidia");
}
#[test]
fn test_embed_returns_error() {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let provider = NvidiaProvider::new(
"nvapi-test-key".to_string(),
NVIDIA_DEFAULT_MODEL.to_string(),
None,
)
.unwrap();
let result = provider.embed(&["hello world".to_string()]).await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("embeddings are not supported"), "Got: {}", err);
});
}
#[test]
fn test_polling_constants_are_sane() {
use std::hint::black_box;
let poll_interval_ms = black_box(NVIDIA_POLL_INTERVAL_MS);
let max_poll_attempts = black_box(NVIDIA_MAX_POLL_ATTEMPTS);
let poll_url = black_box(NVIDIA_NVCF_STATUS_URL);
let reqid_header = black_box(NVIDIA_REQID_HEADER);
assert!(poll_interval_ms > 0, "Poll interval must be > 0 ms");
let total_ms = u64::from(max_poll_attempts) * poll_interval_ms;
assert!(
total_ms >= 60_000,
"Total polling window {total_ms} ms is less than 60 s — increase limits"
);
assert!(!poll_url.is_empty(), "NVCF polling URL must not be empty");
assert!(
!reqid_header.is_empty(),
"NVCF-REQID header name must not be empty"
);
}
#[test]
fn test_build_messages_plain_text() {
let messages = vec![
ChatMessage::system("You are a helpful assistant."),
ChatMessage::user("Hello, world!"),
];
let reqs = NvidiaProvider::build_messages(&messages);
assert_eq!(reqs.len(), 2);
assert_eq!(reqs[0].role, "system");
assert_eq!(
reqs[0].content,
serde_json::json!("You are a helpful assistant.")
);
assert_eq!(reqs[1].role, "user");
assert_eq!(reqs[1].content, serde_json::json!("Hello, world!"));
assert!(reqs[0].tool_calls.is_none());
assert!(reqs[0].tool_call_id.is_none());
}
#[test]
fn test_build_messages_tool_role() {
let mut tool_msg = ChatMessage::user("tool result here");
tool_msg.role = ChatRole::Tool;
tool_msg.tool_call_id = Some("call_abc".to_string());
let reqs = NvidiaProvider::build_messages(&[tool_msg]);
assert_eq!(reqs[0].role, "tool");
assert_eq!(reqs[0].tool_call_id.as_deref(), Some("call_abc"));
}
#[test]
fn test_build_messages_with_images() {
use crate::traits::ImageData;
let mut msg = ChatMessage::user("What is in this image?");
msg.images = Some(vec![ImageData::new("aGVsbG8=", "image/png")]);
let reqs = NvidiaProvider::build_messages(&[msg]);
assert!(
reqs[0].content.is_array(),
"Expected multipart array, got: {:?}",
reqs[0].content
);
let parts = reqs[0].content.as_array().unwrap();
assert_eq!(parts.len(), 2, "Expected text + image_url part");
assert_eq!(parts[0]["type"], "text");
assert_eq!(parts[1]["type"], "image_url");
let url = parts[1]["image_url"]["url"].as_str().unwrap();
assert!(url.starts_with("data:image/png;base64,"), "URL: {url}");
}
#[test]
fn test_tool_choice_to_json_auto() {
let choice = ToolChoice::auto();
let val = NvidiaProvider::tool_choice_to_json(&choice);
assert_eq!(val, serde_json::json!("auto"));
}
#[test]
fn test_tool_choice_to_json_none() {
let choice = ToolChoice::none();
let val = NvidiaProvider::tool_choice_to_json(&choice);
assert_eq!(val, serde_json::json!("none"));
}
#[test]
fn test_tool_choice_to_json_required() {
let choice = ToolChoice::required();
let val = NvidiaProvider::tool_choice_to_json(&choice);
assert_eq!(val, serde_json::json!("required"));
}
#[test]
fn test_tool_choice_to_json_function() {
let choice = ToolChoice::function("get_weather");
let val = NvidiaProvider::tool_choice_to_json(&choice);
assert_eq!(val["type"], "function");
assert_eq!(val["function"]["name"], "get_weather");
}
#[test]
fn test_nvidia_chat_req_serialises_stream_false() {
let provider = NvidiaProvider::new(
"nvapi-key".to_string(),
"meta/llama-3.1-8b-instruct".to_string(),
None,
)
.unwrap();
let req = NvidiaChatReq {
model: &provider.model,
messages: vec![],
temperature: Some(0.7),
top_p: None,
max_tokens: Some(512),
stop: None,
frequency_penalty: None,
presence_penalty: None,
stream: false,
tools: None,
tool_choice: None,
response_format: None,
reasoning_effort: None,
};
let json = serde_json::to_value(&req).unwrap();
assert_eq!(json["stream"], serde_json::json!(false));
assert_eq!(json["model"], "meta/llama-3.1-8b-instruct");
assert_eq!(json["max_tokens"], 512);
assert!((json["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
assert!(json.get("top_p").is_none());
}
#[test]
fn test_nvidia_chat_req_json_mode() {
let req = NvidiaChatReq {
model: "deepseek-ai/deepseek-v4-flash",
messages: vec![],
temperature: None,
top_p: None,
max_tokens: None,
stop: None,
frequency_penalty: None,
presence_penalty: None,
stream: false,
tools: None,
tool_choice: None,
response_format: Some(NvidiaRespFormat {
format_type: "json_object".to_string(),
}),
reasoning_effort: Some("high".to_string()),
};
let json = serde_json::to_value(&req).unwrap();
assert_eq!(json["response_format"]["type"], "json_object");
assert_eq!(json["reasoning_effort"], "high");
}
#[test]
fn test_nvcf_reqid_header_name() {
assert_eq!(NVIDIA_REQID_HEADER, "NVCF-REQID");
}
#[test]
fn test_nvcf_polling_url_format() {
assert!(
NVIDIA_NVCF_STATUS_URL.contains("api.nvcf.nvidia.com"),
"Polling URL should target api.nvcf.nvidia.com, got: {NVIDIA_NVCF_STATUS_URL}"
);
assert!(
NVIDIA_NVCF_STATUS_URL.contains("pexec/status"),
"Polling URL should contain 'pexec/status', got: {NVIDIA_NVCF_STATUS_URL}"
);
let sample_id = "abc-123-def-456";
let full_url = format!("{NVIDIA_NVCF_STATUS_URL}/{sample_id}");
assert_eq!(
full_url,
format!("https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/{sample_id}")
);
}
}