use crate::mcp::McpClient;
#[cfg(feature = "native-inference")]
use crate::provider::native::InferenceBackend;
use crate::util::STREAM_CHUNK_TIMEOUT;
use futures::StreamExt;
use rig::client::{CompletionClient, ProviderClient};
use rig::completion::{CompletionModel as _, GetTokenUsage, Prompt, PromptError, ToolDefinition};
use rig::providers::{anthropic, deepseek, gemini, groq, mistral, openai, xai};
use rig::streaming::StreamedAssistantContent;
use rig::tool::{ToolDyn, ToolError};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::time::timeout;
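/// Error raised by [`NikaMcpTool`] calls, carrying a coarse [`McpToolErrorKind`]
/// so callers can distinguish bad arguments from missing configuration,
/// transport failures, and serialization failures.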
#[derive(Debug)]
pub struct McpToolError {
kind: McpToolErrorKind,
message: String,
}
#[derive(Debug, Clone, Copy)]
pub enum McpToolErrorKind {
InvalidArguments,
NotConfigured,
CallFailed,
SerializationError,
}
impl McpToolError {
pub fn invalid_args(msg: impl Into<String>) -> Self {
Self {
kind: McpToolErrorKind::InvalidArguments,
message: msg.into(),
}
}
pub fn not_configured(msg: impl Into<String>) -> Self {
Self {
kind: McpToolErrorKind::NotConfigured,
message: msg.into(),
}
}
pub fn call_failed(msg: impl Into<String>) -> Self {
Self {
kind: McpToolErrorKind::CallFailed,
message: msg.into(),
}
}
pub fn serialization(msg: impl Into<String>) -> Self {
Self {
kind: McpToolErrorKind::SerializationError,
message: msg.into(),
}
}
}
impl std::fmt::Display for McpToolError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let kind_str = match self.kind {
McpToolErrorKind::InvalidArguments => "InvalidArguments",
McpToolErrorKind::NotConfigured => "NotConfigured",
McpToolErrorKind::CallFailed => "CallFailed",
McpToolErrorKind::SerializationError => "SerializationError",
};
write!(f, "[{}] {}", kind_str, self.message)
}
}
impl std::error::Error for McpToolError {}
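/// Per-request overrides for inference: model id, sampling temperature, output
/// token cap, and system preamble. Any field left `None` falls back to the
/// provider default (see [`RigProvider::default_model`]; max tokens defaults
/// to 8192). A minimal sketch:
///
/// ```ignore
/// // Illustrative only; assumes OPENAI_API_KEY is set and an async context.
/// let opts = InferOptions {
///     temperature: Some(0.2),
///     system: Some("Answer in one sentence.".to_string()),
///     ..Default::default()
/// };
/// let text = RigProvider::openai()
///     .infer_with_options("Why is the sky blue?", &opts)
///     .await?;
/// ```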
#[derive(Debug, Clone, Default)]
pub struct InferOptions {
pub model: Option<String>,
pub temperature: Option<f64>,
pub max_tokens: Option<u32>,
pub system: Option<String>,
}
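/// A configured LLM provider backed by a `rig` client (or, behind the
/// `native-inference` feature, a local runtime).
///
/// Constructors read credentials from the environment, so the matching
/// `*_API_KEY` variable must be set before calling them. A hedged sketch;
/// the `nika::provider::rig` module path is an assumption for illustration:
///
/// ```ignore
/// use nika::provider::rig::RigProvider; // hypothetical path
///
/// // Requires ANTHROPIC_API_KEY in the environment.
/// let provider = RigProvider::claude();
/// let answer = provider.infer("What is 2 + 2?", None).await?;
/// println!("{answer}");
/// ```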
#[derive(Debug, Clone)]
pub enum RigProvider {
Claude(anthropic::Client),
OpenAI(openai::Client),
Mistral(mistral::Client),
Groq(groq::Client),
DeepSeek(deepseek::Client),
Gemini(gemini::Client),
XAi(xai::Client),
#[cfg(feature = "native-inference")]
Native(super::native::NativeRuntime),
}
impl RigProvider {
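    /// Resolves a provider by its registry name, verifying that any required
    /// API key is present in the environment before constructing the client.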
pub fn from_name(name: &str) -> Result<Self, crate::error::NikaError> {
let provider = crate::core::find_provider(name).ok_or_else(|| {
crate::error::NikaError::ProviderNotConfigured {
provider: name.to_string(),
}
})?;
if provider.requires_key && !provider.has_env_key() {
return Err(crate::error::NikaError::MissingApiKey {
provider: provider.id.to_string(),
});
}
match provider.id {
"anthropic" => Ok(Self::claude()),
"openai" => Ok(Self::openai()),
"mistral" => Ok(Self::mistral()),
"groq" => Ok(Self::groq()),
"deepseek" => Ok(Self::deepseek()),
"gemini" => Ok(Self::gemini()),
"xai" => Ok(Self::xai()),
#[cfg(feature = "native-inference")]
"native" => Ok(Self::native()),
_ => Err(crate::error::NikaError::ProviderNotConfigured {
provider: name.to_string(),
}),
}
}
pub fn claude() -> Self {
let client = anthropic::Client::from_env();
RigProvider::Claude(client)
}
pub fn openai() -> Self {
let client = openai::Client::from_env();
RigProvider::OpenAI(client)
}
pub fn mistral() -> Self {
let client = mistral::Client::from_env();
RigProvider::Mistral(client)
}
pub fn groq() -> Self {
let client = groq::Client::from_env();
RigProvider::Groq(client)
}
pub fn deepseek() -> Self {
let client = deepseek::Client::from_env();
RigProvider::DeepSeek(client)
}
pub fn gemini() -> Self {
let client = gemini::Client::from_env();
RigProvider::Gemini(client)
}
pub fn xai() -> Self {
let client = xai::Client::from_env();
RigProvider::XAi(client)
}
#[cfg(feature = "native-inference")]
pub fn native() -> Self {
RigProvider::Native(super::native::NativeRuntime::new())
}
#[cfg(feature = "native-inference")]
pub async fn load_native_model(
&mut self,
model_path: impl Into<std::path::PathBuf>,
config: Option<super::native::LoadConfig>,
) -> Result<(), RigInferError> {
match self {
RigProvider::Native(runtime) => runtime
.load(model_path.into(), config.unwrap_or_default())
.await
.map_err(|e: super::native::NativeError| RigInferError::PromptError(e.to_string())),
_ => Err(RigInferError::PromptError(
"load_native_model only valid for Native provider".to_string(),
)),
}
}
#[cfg(feature = "native-inference")]
pub fn is_native_loaded(&self) -> bool {
match self {
RigProvider::Native(runtime) => runtime.is_loaded(),
_ => false,
}
}
pub fn name(&self) -> &'static str {
match self {
RigProvider::Claude(_) => "claude",
RigProvider::OpenAI(_) => "openai",
RigProvider::Mistral(_) => "mistral",
RigProvider::Groq(_) => "groq",
RigProvider::DeepSeek(_) => "deepseek",
RigProvider::Gemini(_) => "gemini",
RigProvider::XAi(_) => "xai",
#[cfg(feature = "native-inference")]
RigProvider::Native(_) => "native",
}
}
pub fn default_model(&self) -> &'static str {
match self {
RigProvider::Claude(_) => "claude-sonnet-4-6",
RigProvider::OpenAI(_) => openai::GPT_4O,
RigProvider::Mistral(_) => mistral::MISTRAL_LARGE,
RigProvider::Groq(_) => "llama-3.3-70b-versatile",
RigProvider::DeepSeek(_) => "deepseek-chat",
RigProvider::Gemini(_) => "gemini-2.0-flash",
RigProvider::XAi(_) => "grok-3-fast",
#[cfg(feature = "native-inference")]
RigProvider::Native(_) => "native-model",
}
}
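    /// One-shot, non-streaming completion against `model` (or the provider
    /// default), capped at 8192 output tokens.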
    pub async fn infer(&self, prompt: &str, model: Option<&str>) -> Result<String, RigInferError> {
        let model_id = model.unwrap_or_else(|| self.default_model());
        macro_rules! simple_prompt {
            ($client:expr) => {{
                let agent = $client.agent(model_id).max_tokens(8192).build();
                agent
                    .prompt(prompt)
                    .await
                    .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
            }};
        }
        match self {
            RigProvider::Claude(client) => simple_prompt!(client),
            RigProvider::OpenAI(client) => simple_prompt!(client),
            RigProvider::Mistral(client) => simple_prompt!(client),
            RigProvider::Groq(client) => simple_prompt!(client),
            RigProvider::DeepSeek(client) => simple_prompt!(client),
            RigProvider::Gemini(client) => simple_prompt!(client),
            RigProvider::XAi(client) => simple_prompt!(client),
            #[cfg(feature = "native-inference")]
            RigProvider::Native(runtime) => runtime
                .infer(prompt, super::native::ChatOptions::default())
                .await
                .map(|r| r.message.content)
                .map_err(|e: super::native::NativeError| {
                    RigInferError::PromptError(e.to_string())
                }),
        }
}
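    /// Non-streaming multimodal completion. DeepSeek is rejected up front
    /// (no vision support), and native models are routed through the local
    /// vision pipeline when one is loaded.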
pub async fn infer_vision(
&self,
user_content: Vec<rig::completion::message::UserContent>,
model: Option<&str>,
system: Option<&str>,
max_tokens: Option<u32>,
) -> Result<String, RigInferError> {
use rig::completion::message::Message;
use rig::OneOrMany;
if matches!(self, RigProvider::DeepSeek(_)) {
return Err(RigInferError::VisionNotSupported(
"DeepSeek does not support vision/multimodal content".to_string(),
));
}
#[cfg(feature = "native-inference")]
if let RigProvider::Native(runtime) = self {
if !runtime.supports_vision() {
return Err(RigInferError::VisionNotSupported(
"Native model does not support vision. Load a vision model via \
NativeModelKind::VisionHf (e.g., `nika model vision <model_id> --isq Q4K`)"
.to_string(),
));
}
let (prompt_text, vision_images) = extract_native_vision_parts(&user_content)?;
let options = super::native::ChatOptions {
max_tokens,
..Default::default()
};
let response = runtime
.infer_vision(&prompt_text, vision_images, options)
.await
.map_err(|e: super::native::NativeError| {
RigInferError::PromptError(e.to_string())
})?;
return Ok(response.message.content);
}
let model_id = model.unwrap_or_else(|| self.default_model());
let max_tok = max_tokens.map(u64::from).unwrap_or(8192);
let message = Message::User {
content: OneOrMany::many(user_content).map_err(|_| {
RigInferError::VisionNotSupported("content parts list is empty".to_string())
})?,
};
macro_rules! vision_prompt {
($client:expr) => {{
let mut builder = $client.agent(model_id).max_tokens(max_tok);
if let Some(sys) = system {
builder = builder.preamble(sys);
}
let agent = builder.build();
agent
.prompt(message)
.await
.map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
}};
}
match self {
RigProvider::Claude(client) => vision_prompt!(client),
RigProvider::OpenAI(client) => vision_prompt!(client),
RigProvider::Mistral(client) => vision_prompt!(client),
RigProvider::Groq(client) => vision_prompt!(client),
RigProvider::Gemini(client) => vision_prompt!(client),
RigProvider::XAi(client) => vision_prompt!(client),
RigProvider::DeepSeek(_) => unreachable!("DeepSeek handled above"),
#[cfg(feature = "native-inference")]
RigProvider::Native(_) => unreachable!("Native handled above"),
}
}
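    /// Streaming variant of [`Self::infer_vision`]: chunks are forwarded on
    /// `tx` as they arrive, and the assembled text plus token usage is
    /// returned at the end. The native path does not stream; it sends a
    /// single `Done` chunk with the full response.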
pub async fn infer_vision_stream(
&self,
user_content: Vec<rig::completion::message::UserContent>,
tx: mpsc::Sender<StreamChunk>,
model: Option<&str>,
system: Option<&str>,
max_tokens: Option<u32>,
) -> Result<StreamResult, RigInferError> {
use rig::completion::message::Message;
use rig::OneOrMany;
if matches!(self, RigProvider::DeepSeek(_)) {
return Err(RigInferError::VisionNotSupported(
"DeepSeek does not support vision/multimodal content".to_string(),
));
}
#[cfg(feature = "native-inference")]
if let RigProvider::Native(runtime) = self {
if !runtime.supports_vision() {
return Err(RigInferError::VisionNotSupported(
"Native model does not support vision. Load a vision model via \
NativeModelKind::VisionHf (e.g., `nika model vision <model_id> --isq Q4K`)"
.to_string(),
));
}
let (prompt_text, vision_images) = extract_native_vision_parts(&user_content)?;
let options = super::native::ChatOptions {
max_tokens,
..Default::default()
};
let response = runtime
.infer_vision(&prompt_text, vision_images, options)
.await
.map_err(|e: super::native::NativeError| {
RigInferError::PromptError(e.to_string())
})?;
let text = response.message.content;
let _ = tx.send(StreamChunk::Done(text.clone())).await;
return Ok(StreamResult {
text,
..Default::default()
});
}
let model_id = model.unwrap_or_else(|| self.default_model());
let max_tok = max_tokens.map(u64::from).unwrap_or(8192);
let message = Message::User {
content: OneOrMany::many(user_content).map_err(|_| {
RigInferError::VisionNotSupported("content parts list is empty".to_string())
})?,
};
let mut response_parts: Vec<String> = Vec::new();
let mut result = StreamResult::default();
macro_rules! vision_stream {
($client:expr, $is_anthropic:expr) => {{
let model = $client.completion_model(model_id);
let mut builder = model.completion_request(message).max_tokens(max_tok);
if let Some(sys) = system {
builder = builder.preamble(sys.to_string());
}
let request = builder.build();
let mut stream = model
.stream(request)
.await
.map_err(|e| RigInferError::PromptError(e.to_string()))?;
consume_rig_stream(
&mut stream,
&tx,
&mut response_parts,
&mut result,
$is_anthropic,
)
.await?;
}};
}
match self {
RigProvider::Claude(client) => vision_stream!(client, true),
RigProvider::OpenAI(client) => vision_stream!(client, false),
RigProvider::Mistral(client) => vision_stream!(client, false),
RigProvider::Groq(client) => vision_stream!(client, false),
RigProvider::Gemini(client) => vision_stream!(client, false),
RigProvider::XAi(client) => vision_stream!(client, false),
RigProvider::DeepSeek(_) => unreachable!("DeepSeek handled above"),
#[cfg(feature = "native-inference")]
RigProvider::Native(_) => unreachable!("Native handled above"),
}
result.text = response_parts.join("");
Ok(result)
}
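    /// Runs a prompt with `tool_choice = Required`, letting the model invoke
    /// the supplied `rig` tools (e.g. [`NikaMcpTool`]) before answering.
    /// Not available for native inference.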
pub async fn infer_with_tools(
&self,
prompt: &str,
tools: Vec<Box<dyn ToolDyn>>,
model: Option<&str>,
max_tokens: Option<u32>,
system: Option<&str>,
) -> Result<String, RigInferError> {
use rig::agent::AgentBuilder;
use rig::message::ToolChoice as RigToolChoice;
let model_id = model.unwrap_or_else(|| self.default_model());
        let max_tok = max_tokens.map(u64::from).unwrap_or(8192);
macro_rules! build_agent_with_tools {
($client:expr) => {{
let mut builder = AgentBuilder::new($client.completion_model(model_id))
.tools(tools)
.tool_choice(RigToolChoice::Required)
.max_tokens(max_tok);
if let Some(sys) = system {
builder = builder.preamble(sys);
}
let agent = builder.build();
agent
.prompt(prompt)
.await
.map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
}};
}
match self {
RigProvider::Claude(client) => build_agent_with_tools!(client),
RigProvider::OpenAI(client) => build_agent_with_tools!(client),
RigProvider::Mistral(client) => build_agent_with_tools!(client),
RigProvider::Groq(client) => build_agent_with_tools!(client),
RigProvider::DeepSeek(client) => build_agent_with_tools!(client),
RigProvider::Gemini(client) => build_agent_with_tools!(client),
RigProvider::XAi(client) => build_agent_with_tools!(client),
#[cfg(feature = "native-inference")]
RigProvider::Native(_) => {
Err(RigInferError::PromptError(
"Native inference does not support tool-based structured output".to_string(),
))
}
}
}
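    /// Non-streaming completion honoring the overrides in [`InferOptions`].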
    pub async fn infer_with_options(
        &self,
        prompt: &str,
        options: &InferOptions,
    ) -> Result<String, RigInferError> {
        let model_id = options
            .model
            .as_deref()
            .unwrap_or_else(|| self.default_model());
        let max_tokens = options.max_tokens.unwrap_or(8192);
        let user_prompt = prompt.to_string();
        macro_rules! prompt_with_options {
            ($client:expr) => {{
                let mut builder = $client.agent(model_id).max_tokens(max_tokens as u64);
                if let Some(system) = &options.system {
                    builder = builder.preamble(system);
                }
                if let Some(temp) = options.temperature {
                    builder = builder.temperature(temp);
                }
                let agent = builder.build();
                agent
                    .prompt(&user_prompt)
                    .await
                    .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
            }};
        }
        match self {
            RigProvider::Claude(client) => prompt_with_options!(client),
            RigProvider::OpenAI(client) => prompt_with_options!(client),
            RigProvider::Mistral(client) => prompt_with_options!(client),
            RigProvider::Groq(client) => prompt_with_options!(client),
            RigProvider::DeepSeek(client) => prompt_with_options!(client),
            RigProvider::Gemini(client) => prompt_with_options!(client),
            RigProvider::XAi(client) => prompt_with_options!(client),
            #[cfg(feature = "native-inference")]
            RigProvider::Native(runtime) => {
                let chat_options = super::native::ChatOptions {
                    temperature: options.temperature.map(|t| t as f32),
                    max_tokens: options.max_tokens,
                    ..Default::default()
                };
                runtime
                    .infer(&user_prompt, chat_options)
                    .await
                    .map(|r| r.message.content)
                    .map_err(|e: super::native::NativeError| {
                        RigInferError::PromptError(e.to_string())
                    })
            }
        }
}
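    /// Picks the first LLM provider from `KNOWN_PROVIDERS` whose API key is
    /// set and non-blank, falling back to the native runtime when
    /// `NIKA_NATIVE_MODEL` names a model. Returns `None` if nothing is
    /// configured.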
pub fn auto() -> Option<Self> {
use crate::core::providers::{ProviderCategory, KNOWN_PROVIDERS};
for p in KNOWN_PROVIDERS.iter() {
if p.category == ProviderCategory::Llm && p.has_env_key() {
return match p.id {
"anthropic" => Some(Self::claude()),
"openai" => Some(Self::openai()),
"mistral" => Some(Self::mistral()),
"groq" => Some(Self::groq()),
"deepseek" => Some(Self::deepseek()),
"gemini" => Some(Self::gemini()),
"xai" => Some(Self::xai()),
_ => continue,
};
}
}
#[cfg(feature = "native-inference")]
if std::env::var("NIKA_NATIVE_MODEL").is_ok_and(|v| !v.trim().is_empty()) {
return Some(Self::native());
}
None
}
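    /// Sends a tiny test prompt and classifies any failure into a
    /// [`ProviderVerifyError`] by substring-matching the error text
    /// (401/auth, 429/rate limit, timeout, network), which is heuristic by
    /// design.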
pub async fn verify(&self) -> Result<ProviderVerifyResult, ProviderVerifyError> {
use std::time::Instant;
let start = Instant::now();
let test_prompt = "Hi";
match self.infer(test_prompt, None).await {
Ok(_) => Ok(ProviderVerifyResult {
provider: self.name().to_string(),
latency: start.elapsed(),
model: self.default_model().to_string(),
}),
Err(e) => {
let error_msg = e.to_string().to_lowercase();
if error_msg.contains("401")
|| error_msg.contains("unauthorized")
|| error_msg.contains("invalid api key")
|| error_msg.contains("authentication")
{
Err(ProviderVerifyError::InvalidApiKey {
provider: self.name().to_string(),
})
} else if error_msg.contains("rate limit")
|| error_msg.contains("429")
|| error_msg.contains("too many requests")
{
Err(ProviderVerifyError::RateLimited {
provider: self.name().to_string(),
})
} else if error_msg.contains("timeout")
|| error_msg.contains("timed out")
|| error_msg.contains("deadline")
{
Err(ProviderVerifyError::Timeout {
provider: self.name().to_string(),
})
} else if error_msg.contains("connection")
|| error_msg.contains("network")
|| error_msg.contains("dns")
|| error_msg.contains("refused")
{
Err(ProviderVerifyError::NetworkError {
provider: self.name().to_string(),
details: e.to_string(),
})
} else {
Err(ProviderVerifyError::ProviderError {
provider: self.name().to_string(),
details: e.to_string(),
})
}
}
}
}
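    /// Returns `true` if the provider's API key is present and non-blank.
    /// This only inspects the environment; it does not validate the key
    /// against the provider (use [`Self::verify`] for that).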
pub fn is_configured(&self) -> bool {
let has_key = |key: &str| std::env::var(key).is_ok_and(|v| !v.trim().is_empty());
match self {
RigProvider::Claude(_) => has_key("ANTHROPIC_API_KEY"),
RigProvider::OpenAI(_) => has_key("OPENAI_API_KEY"),
RigProvider::Mistral(_) => has_key("MISTRAL_API_KEY"),
RigProvider::Groq(_) => has_key("GROQ_API_KEY"),
RigProvider::DeepSeek(_) => has_key("DEEPSEEK_API_KEY"),
RigProvider::Gemini(_) => has_key("GEMINI_API_KEY"),
RigProvider::XAi(_) => has_key("XAI_API_KEY"),
#[cfg(feature = "native-inference")]
            // The native runtime runs locally and needs no API key.
            RigProvider::Native(_) => true,
}
}
}
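/// Outcome of a successful [`RigProvider::verify`] round trip.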
#[derive(Debug, Clone)]
pub struct ProviderVerifyResult {
pub provider: String,
pub latency: std::time::Duration,
pub model: String,
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum ProviderVerifyError {
#[error("Invalid API key for {provider}")]
InvalidApiKey { provider: String },
#[error("Rate limited by {provider}")]
RateLimited { provider: String },
#[error("Connection timeout to {provider}")]
Timeout { provider: String },
#[error("Network error connecting to {provider}: {details}")]
NetworkError { provider: String, details: String },
#[error("Provider error from {provider}: {details}")]
ProviderError { provider: String, details: String },
}
impl ProviderVerifyError {
pub fn suggestion(&self) -> &'static str {
match self {
ProviderVerifyError::InvalidApiKey { .. } => {
"Check your API key in environment variables"
}
ProviderVerifyError::RateLimited { .. } => {
"Wait a moment and try again, or check your plan limits"
}
ProviderVerifyError::Timeout { .. } => "Check your network connection or try again",
ProviderVerifyError::NetworkError { .. } => {
"Check your internet connection and firewall settings"
}
ProviderVerifyError::ProviderError { .. } => {
"The provider service may be experiencing issues"
}
}
}
}
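/// Errors surfaced by the `infer*` family of methods.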
#[derive(Debug, thiserror::Error)]
pub enum RigInferError {
#[error("Completion error: {0}")]
PromptError(String),
#[error("Stream timeout: no chunk received for {duration_ms}ms")]
Timeout { duration_ms: u64 },
#[error("Vision not supported: {0}")]
VisionNotSupported(String),
}
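/// Events emitted over the streaming channel. Beyond raw `Token`/`Done`
/// text, this carries progress notifications for MCP calls, provider
/// verification, and native model management so consumers can render live
/// status.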
#[derive(Debug, Clone)]
pub enum StreamChunk {
Token(String),
Thinking(String),
Done(String),
Error(String),
Metrics {
input_tokens: u64,
output_tokens: u64,
},
McpConnected(String),
McpError { server_name: String, error: String },
McpCallStart {
tool: String,
server: String,
params: String,
},
McpCallComplete { result: String },
McpCallFailed { error: String },
InferStart {
model: String,
prompt: String,
prompt_tokens: u32,
max_tokens: u32,
},
InferTokens { output_tokens: u32 },
InferComplete,
ExecStart { command: String },
ExecComplete,
FetchStart { url: String, method: String },
FetchComplete,
AgentStart { goal: String },
AgentComplete,
ProviderVerifying { provider: String, model: String },
ProviderVerified {
provider: String,
model: String,
latency_ms: u64,
},
ProviderVerifyFailed { provider: String, error: String },
ProviderNotConfigured { provider: String },
McpPinging { server: String },
McpPinged {
server: String,
latency_ms: u64,
tool_count: usize,
},
ProviderVerificationTimeout,
NativeModelPullStarted { model: String },
NativeModelPullProgress {
model: String,
status: String,
completed: u64,
total: u64,
},
NativeModelPulled {
model: String,
path: String,
size: u64,
},
NativeModelPullFailed { model: String, error: String },
NativeModelDeleted { model: String },
NativeModelDeleteFailed { model: String, error: String },
NativeModelsRefreshed { count: usize },
}
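/// Final accumulation of a streamed completion: full text plus token usage
/// (zeroed when the provider reports none).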
#[derive(Debug, Clone, Default)]
pub struct StreamResult {
pub text: String,
pub input_tokens: u64,
pub output_tokens: u64,
pub total_tokens: u64,
pub cached_input_tokens: u64,
}
impl StreamResult {
pub fn from_text(text: impl Into<String>) -> Self {
Self {
text: text.into(),
..Default::default()
}
}
}
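/// Drains a `rig` streaming response into `response_parts`/`result`,
/// forwarding chunks on `tx`. Each chunk must arrive within
/// `STREAM_CHUNK_TIMEOUT` or the stream is abandoned with
/// [`RigInferError::Timeout`]. `capture_thinking` forwards reasoning deltas
/// (currently enabled only for Anthropic).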
async fn consume_rig_stream<R>(
stream: &mut rig::streaming::StreamingCompletionResponse<R>,
tx: &mpsc::Sender<StreamChunk>,
response_parts: &mut Vec<String>,
result: &mut StreamResult,
capture_thinking: bool,
) -> Result<(), RigInferError>
where
R: Clone + Unpin + GetTokenUsage + serde::Serialize + serde::de::DeserializeOwned,
{
loop {
let chunk_result = match timeout(STREAM_CHUNK_TIMEOUT, stream.next()).await {
Ok(Some(result)) => result,
Ok(None) => break,
Err(_elapsed) => {
let _ = tx.try_send(StreamChunk::Error(format!(
"Stream timeout: no chunk received for {}s",
STREAM_CHUNK_TIMEOUT.as_secs()
)));
return Err(RigInferError::Timeout {
duration_ms: STREAM_CHUNK_TIMEOUT.as_millis() as u64,
});
}
};
match chunk_result {
Ok(content) => match content {
StreamedAssistantContent::Text(text) => {
response_parts.push(text.text.clone());
let _ = tx.try_send(StreamChunk::Token(text.text));
}
StreamedAssistantContent::ReasoningDelta { reasoning, .. } if capture_thinking => {
let _ = tx.try_send(StreamChunk::Thinking(reasoning));
}
StreamedAssistantContent::Final(response) => {
if let Some(usage) = response.token_usage() {
result.input_tokens = usage.input_tokens;
result.output_tokens = usage.output_tokens;
result.total_tokens = usage.total_tokens;
result.cached_input_tokens = usage.cached_input_tokens;
}
}
_ => {}
},
Err(e) => {
let _ = tx.try_send(StreamChunk::Error(e.to_string()));
return Err(RigInferError::PromptError(e.to_string()));
}
}
}
Ok(())
}
impl RigProvider {
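    /// Streaming completion: each token is forwarded on `tx` as a
    /// [`StreamChunk::Token`], followed by a final `Done` and `Metrics` pair.
    /// A minimal consumer sketch (channel capacity and the `provider` binding
    /// are illustrative, not prescribed):
    ///
    /// ```ignore
    /// let (tx, mut rx) = tokio::sync::mpsc::channel(64);
    /// let printer = tokio::spawn(async move {
    ///     while let Some(chunk) = rx.recv().await {
    ///         if let StreamChunk::Token(token) = chunk {
    ///             print!("{token}");
    ///         }
    ///     }
    /// });
    /// let result = provider.infer_stream("Tell me a story", tx, None).await?;
    /// printer.await?;
    /// println!("\n{} output tokens", result.output_tokens);
    /// ```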
    pub async fn infer_stream(
        &self,
        prompt: &str,
        tx: mpsc::Sender<StreamChunk>,
        model: Option<&str>,
    ) -> Result<StreamResult, RigInferError> {
        let model_id = model.unwrap_or_else(|| self.default_model());
        let mut response_parts: Vec<String> = Vec::new();
        let mut result = StreamResult::default();
        macro_rules! stream_prompt {
            ($client:expr, $capture_thinking:expr) => {{
                let model = $client.completion_model(model_id);
                let request = model.completion_request(prompt).max_tokens(8192).build();
                let mut stream = model
                    .stream(request)
                    .await
                    .map_err(|e| RigInferError::PromptError(e.to_string()))?;
                consume_rig_stream(
                    &mut stream,
                    &tx,
                    &mut response_parts,
                    &mut result,
                    $capture_thinking,
                )
                .await?;
            }};
        }
        match self {
            RigProvider::Claude(client) => stream_prompt!(client, true),
            RigProvider::OpenAI(client) => stream_prompt!(client, false),
            RigProvider::Mistral(client) => stream_prompt!(client, false),
            RigProvider::Groq(client) => stream_prompt!(client, false),
            RigProvider::DeepSeek(client) => stream_prompt!(client, false),
            RigProvider::Gemini(client) => stream_prompt!(client, false),
            RigProvider::XAi(client) => stream_prompt!(client, false),
            #[cfg(feature = "native-inference")]
            RigProvider::Native(runtime) => {
                use std::pin::pin;
                let stream = runtime
                    .infer_stream(prompt, super::native::ChatOptions::default())
                    .await
                    .map_err(|e: super::native::NativeError| {
                        RigInferError::PromptError(e.to_string())
                    })?;
                let mut stream = pin!(stream);
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(token) => {
                            response_parts.push(token.clone());
                            let _ = tx.try_send(StreamChunk::Token(token));
                        }
                        Err(e) => {
                            let _ = tx.try_send(StreamChunk::Error(e.to_string()));
                            return Err(RigInferError::PromptError(e.to_string()));
                        }
                    }
                }
            }
        }
        let complete_response = response_parts.concat();
        let _ = tx.try_send(StreamChunk::Done(complete_response.clone()));
        let _ = tx.try_send(StreamChunk::Metrics {
            input_tokens: result.input_tokens,
            output_tokens: result.output_tokens,
        });
        result.text = complete_response;
        Ok(result)
}
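    /// Streaming variant of [`Self::infer_with_options`]. For the native
    /// runtime, the system preamble is folded into the prompt text since the
    /// local chat API takes a single string.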
pub async fn infer_stream_with_options(
&self,
prompt: &str,
tx: mpsc::Sender<StreamChunk>,
options: &InferOptions,
) -> Result<StreamResult, RigInferError> {
let model_id = options
.model
.as_deref()
.unwrap_or_else(|| self.default_model());
let max_tokens = options.max_tokens.unwrap_or(8192);
let mut response_parts: Vec<String> = Vec::new();
let mut result = StreamResult::default();
macro_rules! build_request_with_options {
($client:expr) => {{
let model = $client.completion_model(model_id);
let mut rb = model
.completion_request(prompt)
.max_tokens(max_tokens as u64);
if let Some(ref system) = options.system {
rb = rb.preamble(system.clone());
}
if let Some(temp) = options.temperature {
rb = rb.temperature(temp);
}
model
.stream(rb.build())
.await
.map_err(|e| RigInferError::PromptError(e.to_string()))?
}};
}
match self {
RigProvider::Claude(client) => {
let mut stream = build_request_with_options!(client);
consume_rig_stream(&mut stream, &tx, &mut response_parts, &mut result, true)
.await?;
}
RigProvider::OpenAI(client) => {
let mut stream = build_request_with_options!(client);
consume_rig_stream(&mut stream, &tx, &mut response_parts, &mut result, false)
.await?;
}
RigProvider::Mistral(client) => {
let mut stream = build_request_with_options!(client);
consume_rig_stream(&mut stream, &tx, &mut response_parts, &mut result, false)
.await?;
}
RigProvider::Groq(client) => {
let mut stream = build_request_with_options!(client);
consume_rig_stream(&mut stream, &tx, &mut response_parts, &mut result, false)
.await?;
}
RigProvider::DeepSeek(client) => {
let mut stream = build_request_with_options!(client);
consume_rig_stream(&mut stream, &tx, &mut response_parts, &mut result, false)
.await?;
}
RigProvider::Gemini(client) => {
let mut stream = build_request_with_options!(client);
consume_rig_stream(&mut stream, &tx, &mut response_parts, &mut result, false)
.await?;
}
RigProvider::XAi(client) => {
let mut stream = build_request_with_options!(client);
consume_rig_stream(&mut stream, &tx, &mut response_parts, &mut result, false)
.await?;
}
#[cfg(feature = "native-inference")]
RigProvider::Native(runtime) => {
use futures::StreamExt;
use std::pin::pin;
let native_prompt = if let Some(ref system) = options.system {
format!("{}\n\n{}", system, prompt)
} else {
prompt.to_string()
};
let chat_options = super::native::ChatOptions {
temperature: options.temperature.map(|t| t as f32),
max_tokens: options.max_tokens,
..Default::default()
};
let stream = runtime
.infer_stream(&native_prompt, chat_options)
.await
.map_err(|e: super::native::NativeError| {
RigInferError::PromptError(e.to_string())
})?;
let mut stream = pin!(stream);
                while let Some(chunk) = stream.next().await {
                    match chunk {
Ok(token) => {
response_parts.push(token.clone());
let _ = tx.try_send(StreamChunk::Token(token));
}
Err(e) => {
let _ = tx.try_send(StreamChunk::Error(e.to_string()));
return Err(RigInferError::PromptError(e.to_string()));
}
}
}
}
}
let complete_response = response_parts.concat();
let _ = tx.try_send(StreamChunk::Done(complete_response.clone()));
let _ = tx.try_send(StreamChunk::Metrics {
input_tokens: result.input_tokens,
output_tokens: result.output_tokens,
});
result.text = complete_response;
Ok(result)
}
}
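/// Wire-level description of an MCP tool: name, human-readable description,
/// and a JSON Schema for its arguments, as advertised by the server.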
#[derive(Debug, Clone)]
pub struct NikaMcpToolDef {
pub name: String,
pub description: String,
pub input_schema: serde_json::Value,
}
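/// Shared staging area mapping tool name to the binary content blocks that
/// tool produced, so media survives the text-only `ToolDyn` return path.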
pub type AgentMediaStaging = Arc<dashmap::DashMap<String, Vec<crate::mcp::types::ContentBlock>>>;
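/// Adapter exposing a remote MCP tool to `rig` agents. Without a client it
/// can only describe itself; calls fail with `NotConfigured`.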
#[derive(Debug, Clone)]
pub struct NikaMcpTool {
definition: NikaMcpToolDef,
client: Option<Arc<McpClient>>,
media_staging: Option<AgentMediaStaging>,
}
impl NikaMcpTool {
pub fn new(definition: NikaMcpToolDef) -> Self {
Self {
definition,
client: None,
media_staging: None,
}
}
pub fn with_client(definition: NikaMcpToolDef, client: Arc<McpClient>) -> Self {
Self {
definition,
client: Some(client),
media_staging: None,
}
}
pub fn with_media_staging(
definition: NikaMcpToolDef,
client: Arc<McpClient>,
staging: AgentMediaStaging,
) -> Self {
Self {
definition,
client: Some(client),
media_staging: Some(staging),
}
}
pub fn tool_name(&self) -> &str {
&self.definition.name
}
}
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
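// Bridges NikaMcpTool into rig's dynamic tool interface: `definition` replays
// the stored schema, and `call` parses the JSON arguments, invokes the MCP
// client, stages any binary content, and returns the textual result.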
impl ToolDyn for NikaMcpTool {
fn name(&self) -> String {
self.definition.name.clone()
}
fn definition(&self, _prompt: String) -> BoxFuture<'_, ToolDefinition> {
let def = ToolDefinition {
name: self.definition.name.clone(),
description: self.definition.description.clone(),
parameters: self.definition.input_schema.clone(),
};
Box::pin(async move { def })
}
fn call(&self, args: String) -> BoxFuture<'_, Result<String, ToolError>> {
let tool_name = self.definition.name.clone();
let client = self.client.clone();
Box::pin(async move {
let params: serde_json::Value = serde_json::from_str(&args).map_err(|e| {
ToolError::ToolCallError(Box::new(McpToolError::invalid_args(format!(
"Invalid JSON arguments: {}",
e
))))
})?;
let client = client.ok_or_else(|| {
ToolError::ToolCallError(Box::new(McpToolError::not_configured(
"No MCP client configured for this tool",
)))
})?;
let result = client.call_tool(&tool_name, params).await.map_err(|e| {
ToolError::ToolCallError(Box::new(McpToolError::call_failed(format!(
"MCP tool call failed: {}",
e
))))
})?;
if result.has_media() {
if let Some(ref staging) = self.media_staging {
let media_blocks: Vec<_> = result.media_blocks().into_iter().cloned().collect();
if !media_blocks.is_empty() {
tracing::debug!(
tool = %tool_name,
media_count = media_blocks.len(),
"agent: staging binary content from tool call"
);
staging
.entry(tool_name.clone())
.or_default()
.extend(media_blocks);
}
} else {
tracing::warn!(
tool = %tool_name,
media_count = result.media_blocks().len(),
"agent: tool returned binary content but no media staging configured — data will be lost"
);
}
}
let output = result.text();
if output.is_empty() {
serde_json::to_string(&result).map_err(|e| {
ToolError::ToolCallError(Box::new(McpToolError::serialization(format!(
"Failed to serialize result: {}",
e
))))
})
} else {
Ok(output)
}
})
}
}
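/// Splits rig `UserContent` parts into a plain prompt string and decoded
/// image bytes for the native vision pipeline. URL images are rejected;
/// callers must pre-fetch them.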
#[cfg(feature = "native-inference")]
fn extract_native_vision_parts(
user_content: &[rig::completion::message::UserContent],
) -> Result<(String, Vec<crate::core::backend::VisionImage>), RigInferError> {
use base64::Engine as _;
use rig::completion::message::{DocumentSourceKind, Image, UserContent};
let mut text_parts: Vec<String> = Vec::new();
let mut images: Vec<crate::core::backend::VisionImage> = Vec::new();
for part in user_content {
match part {
UserContent::Text(text) => {
text_parts.push(text.text.clone());
}
UserContent::Image(Image {
data, media_type, ..
}) => {
let bytes = match data {
DocumentSourceKind::Base64(b64) => base64::engine::general_purpose::STANDARD
.decode(b64)
.map_err(|e| {
RigInferError::PromptError(format!(
"Failed to decode base64 image for native vision: {}",
e
))
})?,
DocumentSourceKind::Raw(raw) => raw.clone(),
DocumentSourceKind::Url(url) => {
return Err(RigInferError::VisionNotSupported(format!(
"Native vision does not support URL images. Pre-fetch the image: {}",
url
)));
}
_ => {
return Err(RigInferError::PromptError(
"Unsupported image source kind for native vision".to_string(),
));
}
};
let mime = media_type
.as_ref()
.map(|mt| match mt {
rig::completion::message::ImageMediaType::JPEG => "image/jpeg",
rig::completion::message::ImageMediaType::PNG => "image/png",
rig::completion::message::ImageMediaType::GIF => "image/gif",
rig::completion::message::ImageMediaType::WEBP => "image/webp",
_ => "image/png", })
.unwrap_or("image/png");
images.push(crate::core::backend::VisionImage::new(bytes, mime));
}
_ => {}
}
}
Ok((text_parts.join("\n"), images))
}
#[cfg(test)]
mod tests {
use super::*;
use serial_test::serial;
#[test]
fn stream_result_from_text_has_zero_tokens() {
let result = StreamResult::from_text("hello world");
assert_eq!(result.text, "hello world");
assert_eq!(result.input_tokens, 0);
assert_eq!(result.output_tokens, 0);
assert_eq!(result.total_tokens, 0);
assert_eq!(result.cached_input_tokens, 0);
}
#[test]
fn stream_result_default_is_empty() {
let result = StreamResult::default();
assert_eq!(result.text, "");
assert_eq!(result.total_tokens, 0);
}
#[test]
fn stream_result_with_tokens() {
let result = StreamResult {
text: "response".to_string(),
input_tokens: 100,
output_tokens: 50,
total_tokens: 150,
cached_input_tokens: 20,
};
assert_eq!(
result.total_tokens,
result.input_tokens + result.output_tokens
);
assert_eq!(result.cached_input_tokens, 20);
}
#[test]
#[serial]
fn test_rig_provider_claude_returns_claude_variant() {
std::env::set_var("ANTHROPIC_API_KEY", "test-key-for-unit-test");
let provider = RigProvider::claude();
assert_eq!(provider.name(), "claude");
assert!(matches!(provider, RigProvider::Claude(_)));
}
#[test]
#[serial]
fn test_rig_provider_openai_returns_openai_variant() {
std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
let provider = RigProvider::openai();
assert_eq!(provider.name(), "openai");
assert!(matches!(provider, RigProvider::OpenAI(_)));
}
#[test]
#[serial]
fn test_rig_provider_default_model_claude() {
std::env::set_var("ANTHROPIC_API_KEY", "test-key-for-unit-test");
let provider = RigProvider::claude();
assert_eq!(provider.default_model(), "claude-sonnet-4-6");
}
#[test]
#[serial]
fn test_rig_provider_default_model_openai() {
std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
let provider = RigProvider::openai();
assert_eq!(provider.default_model(), openai::GPT_4O);
}
#[test]
fn test_rig_infer_error_display() {
let err = RigInferError::PromptError("Test error message".to_string());
assert_eq!(err.to_string(), "Completion error: Test error message");
}
#[test]
fn test_rig_infer_error_timeout_display() {
let err = RigInferError::Timeout { duration_ms: 60000 };
assert_eq!(
err.to_string(),
"Stream timeout: no chunk received for 60000ms"
);
}
#[test]
#[serial]
fn test_rig_provider_mistral_returns_mistral_variant() {
std::env::set_var("MISTRAL_API_KEY", "test-key-for-unit-test");
let provider = RigProvider::mistral();
assert_eq!(provider.name(), "mistral");
assert!(matches!(provider, RigProvider::Mistral(_)));
}
#[test]
#[serial]
fn test_rig_provider_groq_returns_groq_variant() {
std::env::set_var("GROQ_API_KEY", "test-key-for-unit-test");
let provider = RigProvider::groq();
assert_eq!(provider.name(), "groq");
assert!(matches!(provider, RigProvider::Groq(_)));
}
#[test]
#[serial]
fn test_rig_provider_deepseek_returns_deepseek_variant() {
std::env::set_var("DEEPSEEK_API_KEY", "test-key-for-unit-test");
let provider = RigProvider::deepseek();
assert_eq!(provider.name(), "deepseek");
assert!(matches!(provider, RigProvider::DeepSeek(_)));
}
#[test]
#[serial]
fn test_rig_provider_default_models_v06() {
std::env::set_var("MISTRAL_API_KEY", "test");
std::env::set_var("GROQ_API_KEY", "test");
std::env::set_var("DEEPSEEK_API_KEY", "test");
assert_eq!(
RigProvider::mistral().default_model(),
mistral::MISTRAL_LARGE
);
assert_eq!(
RigProvider::groq().default_model(),
"llama-3.3-70b-versatile"
);
assert_eq!(RigProvider::deepseek().default_model(), "deepseek-chat");
}
#[test]
#[serial]
fn test_rig_provider_auto_detects_claude() {
std::env::remove_var("OPENAI_API_KEY");
std::env::remove_var("MISTRAL_API_KEY");
std::env::remove_var("GROQ_API_KEY");
std::env::remove_var("DEEPSEEK_API_KEY");
std::env::set_var("ANTHROPIC_API_KEY", "test-key");
let provider = RigProvider::auto();
assert!(provider.is_some());
assert_eq!(provider.unwrap().name(), "claude");
}
#[test]
#[serial]
fn test_rig_provider_auto_returns_none_when_no_keys() {
clear_all_provider_env_vars();
let provider = RigProvider::auto();
assert!(provider.is_none());
}
    fn clear_all_provider_env_vars() {
        std::env::remove_var("ANTHROPIC_API_KEY");
        std::env::remove_var("OPENAI_API_KEY");
        std::env::remove_var("MISTRAL_API_KEY");
        std::env::remove_var("GROQ_API_KEY");
        std::env::remove_var("DEEPSEEK_API_KEY");
        std::env::remove_var("GEMINI_API_KEY");
        std::env::remove_var("XAI_API_KEY");
    }
#[test]
#[serial]
fn test_auto_fallback_to_openai() {
clear_all_provider_env_vars();
std::env::set_var("OPENAI_API_KEY", "test-key");
let provider = RigProvider::auto();
assert!(provider.is_some());
assert_eq!(provider.unwrap().name(), "openai");
}
#[test]
#[serial]
fn test_auto_fallback_to_mistral() {
clear_all_provider_env_vars();
std::env::set_var("MISTRAL_API_KEY", "test-key");
let provider = RigProvider::auto();
assert!(provider.is_some());
assert_eq!(provider.unwrap().name(), "mistral");
}
#[test]
#[serial]
fn test_auto_fallback_to_groq() {
clear_all_provider_env_vars();
std::env::set_var("GROQ_API_KEY", "test-key");
let provider = RigProvider::auto();
assert!(provider.is_some());
assert_eq!(provider.unwrap().name(), "groq");
}
#[test]
#[serial]
fn test_auto_fallback_to_deepseek() {
clear_all_provider_env_vars();
std::env::set_var("DEEPSEEK_API_KEY", "test-key");
let provider = RigProvider::auto();
assert!(provider.is_some());
assert_eq!(provider.unwrap().name(), "deepseek");
}
#[test]
#[serial]
fn test_auto_fallback_to_gemini() {
clear_all_provider_env_vars();
std::env::set_var("GEMINI_API_KEY", "test-key");
let provider = RigProvider::auto();
assert!(provider.is_some());
assert_eq!(provider.unwrap().name(), "gemini");
}
#[test]
#[serial]
fn test_auto_priority_claude_over_openai() {
clear_all_provider_env_vars();
std::env::set_var("ANTHROPIC_API_KEY", "claude-key");
std::env::set_var("OPENAI_API_KEY", "openai-key");
let provider = RigProvider::auto();
assert!(provider.is_some());
assert_eq!(provider.unwrap().name(), "claude");
}
#[test]
#[serial]
fn test_auto_priority_openai_over_mistral() {
clear_all_provider_env_vars();
std::env::set_var("OPENAI_API_KEY", "openai-key");
std::env::set_var("MISTRAL_API_KEY", "mistral-key");
let provider = RigProvider::auto();
assert!(provider.is_some());
assert_eq!(provider.unwrap().name(), "openai");
}
#[test]
#[serial]
fn test_auto_empty_env_var_treated_as_unset() {
clear_all_provider_env_vars();
std::env::set_var("ANTHROPIC_API_KEY", ""); std::env::set_var("OPENAI_API_KEY", "valid-key");
let provider = RigProvider::auto();
assert!(provider.is_some());
assert_eq!(provider.unwrap().name(), "openai");
}
#[test]
#[serial]
fn test_auto_whitespace_env_var_treated_as_unset() {
clear_all_provider_env_vars();
std::env::set_var("ANTHROPIC_API_KEY", " ");
let provider = RigProvider::auto();
assert!(
provider.is_none(),
"Whitespace-only API key should be treated as unset"
);
}
#[test]
fn test_nika_mcp_tool_implements_tool_dyn() {
let tool_def = NikaMcpToolDef {
name: "novanet_context".to_string(),
description: "Generate native content for an entity".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"entity": { "type": "string" },
"locale": { "type": "string" }
},
"required": ["entity", "locale"]
}),
};
let tool = NikaMcpTool::new(tool_def);
assert_eq!(tool.tool_name(), "novanet_context");
}
#[test]
fn test_nika_mcp_tool_definition_returns_correct_schema() {
use rig::tool::ToolDyn;
let tool_def = NikaMcpToolDef {
name: "novanet_describe".to_string(),
description: "Describe an entity from the knowledge graph".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"entity_key": { "type": "string" }
},
"required": ["entity_key"]
}),
};
let tool = NikaMcpTool::new(tool_def);
let name = tool.name();
assert_eq!(name, "novanet_describe");
}
#[tokio::test]
async fn test_nika_mcp_tool_call_uses_mcp_client() {
use crate::mcp::McpClient;
use rig::tool::ToolDyn;
use std::sync::Arc;
let client = Arc::new(McpClient::mock("novanet"));
let tool_def = NikaMcpToolDef {
name: "novanet_describe".to_string(),
description: "Describe an entity".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"entity_key": { "type": "string" }
},
"required": ["entity_key"]
}),
};
let tool = NikaMcpTool::with_client(tool_def, client);
let args = r#"{"entity_key": "qr-code"}"#.to_string();
let result = tool.call(args).await;
assert!(result.is_ok(), "Tool call should succeed with mock client");
let output = result.unwrap();
assert!(!output.is_empty(), "Tool should return non-empty output");
}
#[tokio::test]
async fn test_usecase_novanet_context_entity_locale() {
use crate::mcp::McpClient;
use rig::tool::ToolDyn;
use std::sync::Arc;
let client = Arc::new(McpClient::mock("novanet"));
let tool_def = NikaMcpToolDef {
name: "novanet_context".to_string(),
description: "Full RLM-on-KG context assembly for generation".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"focus_key": { "type": "string", "description": "Entity key to generate for" },
"locale": { "type": "string", "description": "BCP-47 locale code" },
"mode": { "type": "string", "enum": ["block", "page"], "default": "block" },
"token_budget": { "type": "integer", "default": 4000 },
"spreading_depth": { "type": "integer", "default": 2 },
"forms": {
"type": "array",
"items": { "type": "string", "enum": ["text", "title", "abbrev", "url"] }
}
},
"required": ["focus_key", "locale"]
}),
};
let tool = NikaMcpTool::with_client(tool_def, client);
let args = serde_json::json!({
"focus_key": "qr-code",
"locale": "fr-FR",
"mode": "page",
"forms": ["text", "title", "abbrev"]
})
.to_string();
let result = tool.call(args).await;
assert!(
result.is_ok(),
"novanet_context should succeed: {:?}",
result
);
let output = result.unwrap();
assert!(!output.is_empty(), "Should return generation context");
}
#[tokio::test]
async fn test_usecase_novanet_describe_entity() {
use crate::mcp::McpClient;
use rig::tool::ToolDyn;
use std::sync::Arc;
let client = Arc::new(McpClient::mock("novanet"));
let tool_def = NikaMcpToolDef {
name: "novanet_describe".to_string(),
description: "Bootstrap agent understanding of the knowledge graph".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"describe": {
"type": "string",
"enum": ["schema", "entity", "category", "relations", "locales", "stats"]
},
"entity_key": { "type": "string" },
"category_key": { "type": "string" }
},
"required": ["describe"]
}),
};
let tool = NikaMcpTool::with_client(tool_def, client);
let args = serde_json::json!({
"describe": "schema"
})
.to_string();
let result = tool.call(args).await;
assert!(result.is_ok(), "novanet_describe should succeed");
}
#[tokio::test]
async fn test_usecase_novanet_search_walk_graph() {
use crate::mcp::McpClient;
use rig::tool::ToolDyn;
use std::sync::Arc;
let client = Arc::new(McpClient::mock("novanet"));
let tool_def = NikaMcpToolDef {
name: "novanet_search".to_string(),
description: "Graph traversal with configurable depth and filters".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"start_key": { "type": "string" },
"max_depth": { "type": "integer", "default": 2 },
"direction": { "type": "string", "enum": ["outgoing", "incoming", "both"] },
"arc_families": { "type": "array", "items": { "type": "string" } },
"target_kinds": { "type": "array", "items": { "type": "string" } }
},
"required": ["start_key"]
}),
};
let tool = NikaMcpTool::with_client(tool_def, client);
let args = serde_json::json!({
"start_key": "qr-code",
"max_depth": 2,
"direction": "outgoing",
"arc_families": ["ownership", "localization"]
})
.to_string();
let result = tool.call(args).await;
assert!(result.is_ok(), "novanet_search walk should succeed");
}
#[tokio::test]
async fn test_usecase_novanet_search_hybrid() {
use crate::mcp::McpClient;
use rig::tool::ToolDyn;
use std::sync::Arc;
let client = Arc::new(McpClient::mock("novanet"));
let tool_def = NikaMcpToolDef {
name: "novanet_search".to_string(),
description: "Fulltext + property search with hybrid mode".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"query": { "type": "string" },
"mode": { "type": "string", "enum": ["fulltext", "property", "hybrid"] },
"kinds": { "type": "array", "items": { "type": "string" } },
"realm": { "type": "string", "enum": ["shared", "org"] },
"limit": { "type": "integer", "default": 10 }
},
"required": ["query"]
}),
};
let tool = NikaMcpTool::with_client(tool_def, client);
let args = serde_json::json!({
"query": "QR code generator",
"mode": "hybrid",
"kinds": ["Entity", "Page"],
"limit": 5
})
.to_string();
let result = tool.call(args).await;
assert!(result.is_ok(), "novanet_search should succeed");
}
#[tokio::test]
async fn test_usecase_novanet_audit_locale() {
use crate::mcp::McpClient;
use rig::tool::ToolDyn;
use std::sync::Arc;
let client = Arc::new(McpClient::mock("novanet"));
let tool_def = NikaMcpToolDef {
name: "novanet_audit".to_string(),
description: "Retrieve knowledge atoms for a specific locale".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"locale": { "type": "string" },
"atom_type": {
"type": "string",
"enum": ["term", "expression", "pattern", "cultureref", "taboo", "audiencetrait", "all"]
},
"domain": { "type": "string" }
},
"required": ["locale"]
}),
};
let tool = NikaMcpTool::with_client(tool_def, client);
let args = serde_json::json!({
"locale": "fr-FR",
"atom_type": "term",
"domain": "qr-code"
})
.to_string();
let result = tool.call(args).await;
assert!(result.is_ok(), "novanet_audit should succeed");
}
#[tokio::test]
async fn test_usecase_novanet_batch_context() {
use crate::mcp::McpClient;
use rig::tool::ToolDyn;
use std::sync::Arc;
let client = Arc::new(McpClient::mock("novanet"));
let tool_def = NikaMcpToolDef {
name: "novanet_batch".to_string(),
description: "Assemble context for LLM generation (token-aware)".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"focus_key": { "type": "string" },
"locale": { "type": "string" },
"token_budget": { "type": "integer", "default": 4000 },
"strategy": {
"type": "string",
"enum": ["breadth", "depth", "relevance", "custom"]
}
},
"required": ["focus_key", "locale"]
}),
};
let tool = NikaMcpTool::with_client(tool_def, client);
let args = serde_json::json!({
"focus_key": "qr-code",
"locale": "es-MX",
"token_budget": 3000,
"strategy": "relevance"
})
.to_string();
let result = tool.call(args).await;
assert!(result.is_ok(), "novanet_batch should succeed");
}
#[tokio::test]
async fn test_error_no_client_configured() {
use rig::tool::ToolDyn;
let tool_def = NikaMcpToolDef {
name: "novanet_describe".to_string(),
description: "Test tool".to_string(),
input_schema: serde_json::json!({"type": "object"}),
};
let tool = NikaMcpTool::new(tool_def);
let args = r#"{"entity_key": "test"}"#.to_string();
let result = tool.call(args).await;
assert!(result.is_err(), "Should fail without client");
let err = result.unwrap_err();
let err_str = err.to_string();
assert!(
err_str.contains("No MCP client") || err_str.contains("NotConnected"),
"Error should mention missing client: {}",
err_str
);
}
#[tokio::test]
async fn test_error_invalid_json_arguments() {
use crate::mcp::McpClient;
use rig::tool::ToolDyn;
use std::sync::Arc;
let client = Arc::new(McpClient::mock("novanet"));
let tool_def = NikaMcpToolDef {
name: "novanet_describe".to_string(),
description: "Test tool".to_string(),
input_schema: serde_json::json!({"type": "object"}),
};
let tool = NikaMcpTool::with_client(tool_def, client);
let args = "not valid json {{{".to_string();
let result = tool.call(args).await;
assert!(result.is_err(), "Should fail with invalid JSON");
let err = result.unwrap_err();
let err_str = err.to_string();
assert!(
err_str.contains("Invalid JSON") || err_str.contains("JSON"),
"Error should mention JSON parsing: {}",
err_str
);
}
#[tokio::test]
async fn test_empty_json_object_is_valid() {
use crate::mcp::McpClient;
use rig::tool::ToolDyn;
use std::sync::Arc;
let client = Arc::new(McpClient::mock("novanet"));
let tool_def = NikaMcpToolDef {
name: "novanet_describe".to_string(),
description: "Test tool".to_string(),
input_schema: serde_json::json!({"type": "object"}),
};
let tool = NikaMcpTool::with_client(tool_def, client);
let args = "{}".to_string();
let result = tool.call(args).await;
assert!(result.is_ok(), "Empty JSON object should be valid");
}
#[tokio::test]
async fn test_tool_definition_async() {
use rig::tool::ToolDyn;
let input_schema = serde_json::json!({
"type": "object",
"properties": {
"entity_key": { "type": "string" },
"locale": { "type": "string" }
},
"required": ["entity_key"]
});
let tool_def = NikaMcpToolDef {
name: "test_tool".to_string(),
description: "A test tool for verification".to_string(),
input_schema: input_schema.clone(),
};
let tool = NikaMcpTool::new(tool_def);
let definition = tool.definition("some prompt".to_string()).await;
assert_eq!(definition.name, "test_tool");
assert_eq!(definition.description, "A test tool for verification");
assert_eq!(definition.parameters, input_schema);
}
#[test]
fn test_multiple_tools_independent() {
let tool1 = NikaMcpTool::new(NikaMcpToolDef {
name: "novanet_context".to_string(),
description: "Generate content".to_string(),
input_schema: serde_json::json!({"type": "object"}),
});
let tool2 = NikaMcpTool::new(NikaMcpToolDef {
name: "novanet_describe".to_string(),
description: "Describe entity".to_string(),
input_schema: serde_json::json!({"type": "object"}),
});
let tool3 = NikaMcpTool::new(NikaMcpToolDef {
name: "novanet_search".to_string(),
description: "Traverse graph".to_string(),
input_schema: serde_json::json!({"type": "object"}),
});
assert_eq!(tool1.tool_name(), "novanet_context");
assert_eq!(tool2.tool_name(), "novanet_describe");
assert_eq!(tool3.tool_name(), "novanet_search");
}
#[tokio::test]
async fn test_tool_clone_works() {
use crate::mcp::McpClient;
use rig::tool::ToolDyn;
use std::sync::Arc;
let client = Arc::new(McpClient::mock("novanet"));
let tool_def = NikaMcpToolDef {
name: "novanet_describe".to_string(),
description: "Test tool".to_string(),
input_schema: serde_json::json!({"type": "object"}),
};
let tool = NikaMcpTool::with_client(tool_def, client);
let cloned_tool = tool.clone();
let args = r#"{"entity_key": "test"}"#.to_string();
let result1 = tool.call(args.clone()).await;
let result2 = cloned_tool.call(args).await;
assert!(result1.is_ok(), "Original tool should work");
assert!(result2.is_ok(), "Cloned tool should work");
}
#[tokio::test]
async fn test_multi_locale_generation_workflow() {
use crate::mcp::McpClient;
use rig::tool::ToolDyn;
use std::sync::Arc;
let client = Arc::new(McpClient::mock("novanet"));
let tool_def = NikaMcpToolDef {
name: "novanet_context".to_string(),
description: "Generate native content".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"focus_key": { "type": "string" },
"locale": { "type": "string" },
"forms": { "type": "array", "items": { "type": "string" } }
},
"required": ["focus_key", "locale"]
}),
};
let tool = NikaMcpTool::with_client(tool_def, client);
let locales = ["fr-FR", "es-MX", "de-DE", "ja-JP", "zh-CN"];
let mut results = Vec::new();
for locale in locales {
let args = serde_json::json!({
"focus_key": "qr-code",
"locale": locale,
"forms": ["text", "title"]
})
.to_string();
let result = tool.call(args).await;
results.push((locale, result.is_ok()));
}
        for (locale, success) in &results {
            assert!(*success, "Generation for {} should succeed", locale);
        }
assert_eq!(results.len(), 5, "Should process all 5 locales");
}
#[test]
fn test_provider_verify_error_types() {
let invalid_key = ProviderVerifyError::InvalidApiKey {
provider: "claude".to_string(),
};
assert!(invalid_key.to_string().contains("Invalid API key"));
assert!(invalid_key.suggestion().contains("API key"));
let rate_limited = ProviderVerifyError::RateLimited {
provider: "openai".to_string(),
};
assert!(rate_limited.to_string().contains("Rate limited"));
let timeout = ProviderVerifyError::Timeout {
provider: "mistral".to_string(),
};
assert!(timeout.to_string().contains("timeout"));
let network = ProviderVerifyError::NetworkError {
provider: "groq".to_string(),
details: "connection refused".to_string(),
};
assert!(network.to_string().contains("Network error"));
let provider_err = ProviderVerifyError::ProviderError {
provider: "deepseek".to_string(),
details: "server down".to_string(),
};
assert!(provider_err.to_string().contains("server down"));
}
#[test]
fn test_provider_verify_result_fields() {
let result = ProviderVerifyResult {
provider: "claude".to_string(),
latency: std::time::Duration::from_millis(150),
model: "claude-sonnet-4-6".to_string(),
};
assert_eq!(result.provider, "claude");
assert_eq!(result.latency.as_millis(), 150);
assert_eq!(result.model, "claude-sonnet-4-6");
}
#[test]
#[serial]
fn test_is_configured_with_api_key() {
std::env::set_var("ANTHROPIC_API_KEY", "test-key");
let provider = RigProvider::claude();
assert!(provider.is_configured());
}
#[test]
#[serial]
fn test_is_configured_returns_true_for_all_providers_with_keys() {
std::env::set_var("ANTHROPIC_API_KEY", "test");
std::env::set_var("OPENAI_API_KEY", "test");
std::env::set_var("MISTRAL_API_KEY", "test");
std::env::set_var("GROQ_API_KEY", "test");
std::env::set_var("DEEPSEEK_API_KEY", "test");
assert!(RigProvider::claude().is_configured());
assert!(RigProvider::openai().is_configured());
assert!(RigProvider::mistral().is_configured());
assert!(RigProvider::groq().is_configured());
assert!(RigProvider::deepseek().is_configured());
}
#[test]
fn test_infer_options_default() {
let opts = InferOptions::default();
assert!(opts.model.is_none());
assert!(opts.temperature.is_none());
assert!(opts.max_tokens.is_none());
assert!(opts.system.is_none());
}
#[test]
fn test_infer_options_with_all_fields() {
let opts = InferOptions {
model: Some("gpt-4o".to_string()),
temperature: Some(0.7),
max_tokens: Some(2000),
system: Some("You are a helpful assistant.".to_string()),
};
assert_eq!(opts.model.as_deref(), Some("gpt-4o"));
assert_eq!(opts.temperature, Some(0.7));
assert_eq!(opts.max_tokens, Some(2000));
assert_eq!(opts.system.as_deref(), Some("You are a helpful assistant."));
}
#[test]
fn test_infer_options_partial_fields() {
let opts = InferOptions {
temperature: Some(0.5),
..Default::default()
};
assert!(opts.model.is_none());
assert_eq!(opts.temperature, Some(0.5));
assert!(opts.max_tokens.is_none());
assert!(opts.system.is_none());
}
#[test]
fn test_infer_options_temperature_zero() {
let opts = InferOptions {
temperature: Some(0.0),
..Default::default()
};
assert_eq!(opts.temperature, Some(0.0));
}
#[test]
fn test_infer_options_max_tokens_small() {
let opts = InferOptions {
max_tokens: Some(1),
..Default::default()
};
assert_eq!(opts.max_tokens, Some(1));
}
#[test]
fn test_infer_options_system_empty_string() {
let opts = InferOptions {
system: Some(String::new()),
..Default::default()
};
assert_eq!(opts.system.as_deref(), Some(""));
}
#[test]
fn test_infer_options_clone() {
let opts = InferOptions {
model: Some("test-model".to_string()),
temperature: Some(0.8),
max_tokens: Some(1000),
system: Some("Test system".to_string()),
};
let cloned = opts.clone();
assert_eq!(opts.model, cloned.model);
assert_eq!(opts.temperature, cloned.temperature);
assert_eq!(opts.max_tokens, cloned.max_tokens);
assert_eq!(opts.system, cloned.system);
}
#[test]
fn vision_not_supported_error_display() {
let err = RigInferError::VisionNotSupported("DeepSeek no vision".to_string());
assert!(err.to_string().contains("Vision not supported"));
assert!(err.to_string().contains("DeepSeek no vision"));
}
#[tokio::test]
async fn infer_vision_deepseek_returns_error() {
if std::env::var("DEEPSEEK_API_KEY").is_err() {
let err = RigInferError::VisionNotSupported("DeepSeek".to_string());
assert!(err.to_string().contains("Vision not supported"));
return;
}
let provider = RigProvider::deepseek();
let content = vec![rig::completion::message::UserContent::text("hello")];
let result = provider.infer_vision(content, None, None, None).await;
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
RigInferError::VisionNotSupported(_)
));
}
#[test]
fn infer_vision_empty_content_builds_error() {
use rig::OneOrMany;
let content: Vec<rig::completion::message::UserContent> = vec![];
let result = OneOrMany::many(content);
assert!(result.is_err(), "empty content should fail");
}
#[test]
fn build_vision_user_content_text_only() {
let content = [rig::completion::message::UserContent::text("Describe this")];
assert_eq!(content.len(), 1);
}
#[test]
fn build_vision_user_content_with_image() {
use rig::completion::message::{ImageMediaType, UserContent};
let content = [
UserContent::text("What is in this image?"),
UserContent::image_base64(
"iVBORw0KGgo=", Some(ImageMediaType::PNG),
None,
),
];
assert_eq!(content.len(), 2);
}
#[test]
fn build_vision_message_from_content() {
use rig::completion::message::{ImageMediaType, Message, UserContent};
use rig::OneOrMany;
let parts = vec![
UserContent::text("Describe this image"),
UserContent::image_base64("iVBORw0KGgo=", Some(ImageMediaType::PNG), None),
];
let msg = Message::User {
content: OneOrMany::many(parts).unwrap(),
};
assert!(matches!(msg, Message::User { .. }));
}
}