1use crate::error_domains::ProviderError;
27use crate::mcp::McpClient;
28use crate::util::STREAM_CHUNK_TIMEOUT;
29use futures::StreamExt;
30use std::time::Instant;
31
32#[cfg(feature = "native-inference")]
34use crate::provider::native::InferenceBackend;
35use rig::client::{CompletionClient, ProviderClient};
36use rig::completion::{CompletionModel as _, GetTokenUsage, Prompt, PromptError, ToolDefinition};
37use rig::providers::{anthropic, deepseek, gemini, groq, mistral, openai, xai};
38use rig::streaming::StreamedAssistantContent;
39use rig::tool::{ToolDyn, ToolError};
40use std::future::Future;
41use std::pin::Pin;
42use std::sync::Arc;
43use tokio::sync::mpsc;
44use tokio::time::timeout;
45
/// Error raised by an MCP-backed tool: a coarse category plus a free-form message.
///
/// Rendered by `Display` as `[<Kind>] <message>`.
#[derive(Debug)]
pub struct McpToolError {
    /// Category of the failure.
    kind: McpToolErrorKind,
    /// Human-readable detail supplied by the caller.
    message: String,
}

/// Category attached to a [`McpToolError`].
#[derive(Debug, Clone, Copy)]
pub enum McpToolErrorKind {
    /// Built by [`McpToolError::invalid_args`].
    InvalidArguments,
    /// Built by [`McpToolError::not_configured`].
    NotConfigured,
    /// Built by [`McpToolError::call_failed`].
    CallFailed,
    /// Built by [`McpToolError::serialization`].
    SerializationError,
}

impl McpToolErrorKind {
    /// Variant name exactly as it appears inside the brackets of the display text.
    fn label(self) -> &'static str {
        match self {
            Self::InvalidArguments => "InvalidArguments",
            Self::NotConfigured => "NotConfigured",
            Self::CallFailed => "CallFailed",
            Self::SerializationError => "SerializationError",
        }
    }
}

impl McpToolError {
    /// Shared constructor behind the public per-kind helpers.
    fn with_kind(kind: McpToolErrorKind, msg: impl Into<String>) -> Self {
        let message = msg.into();
        Self { kind, message }
    }

    /// Creates an error of kind `InvalidArguments`.
    pub fn invalid_args(msg: impl Into<String>) -> Self {
        Self::with_kind(McpToolErrorKind::InvalidArguments, msg)
    }

    /// Creates an error of kind `NotConfigured`.
    pub fn not_configured(msg: impl Into<String>) -> Self {
        Self::with_kind(McpToolErrorKind::NotConfigured, msg)
    }

    /// Creates an error of kind `CallFailed`.
    pub fn call_failed(msg: impl Into<String>) -> Self {
        Self::with_kind(McpToolErrorKind::CallFailed, msg)
    }

    /// Creates an error of kind `SerializationError`.
    pub fn serialization(msg: impl Into<String>) -> Self {
        Self::with_kind(McpToolErrorKind::SerializationError, msg)
    }
}

impl std::fmt::Display for McpToolError {
    /// Writes `[<Kind>] <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[{}] {}", self.kind.label(), self.message)
    }
}

impl std::error::Error for McpToolError {}
119
/// Optional knobs for `RigProvider::infer_with_options`.
///
/// Every field is optional; `Default` leaves them all as `None`.
#[derive(Debug, Clone, Default)]
pub struct InferOptions {
    /// Model identifier; when `None`, `infer_with_options` falls back to the
    /// provider's `default_model()`.
    pub model: Option<String>,
    /// Sampling temperature. `infer_with_options` drops it (with a warning) when
    /// `is_reasoning_model` returns true for the selected model.
    pub temperature: Option<f64>,
    /// Maximum tokens to generate. The rig-backed providers default to 8192 when
    /// this is `None`; the native runtime receives the `Option` unchanged.
    pub max_tokens: Option<u32>,
    /// System prompt, passed as the agent preamble for the rig-backed providers.
    pub system: Option<String>,
}
134
/// Returns `true` when `model_id` names a reasoning model.
///
/// Matching is case-insensitive. A model matches when it is `deepseek-reasoner`,
/// or when it is exactly one of the families `o1`, `o3`, `o4`, `gpt-5`, or one of
/// those followed by `-<suffix>` (for example `o3-mini`, `gpt-5-turbo`).
pub fn is_reasoning_model(model_id: &str) -> bool {
    const FAMILIES: [&str; 4] = ["o1", "o3", "o4", "gpt-5"];

    let normalized = model_id.to_lowercase();
    if normalized == "deepseek-reasoner" {
        return true;
    }

    // A family matches either exactly or as a dash-separated prefix, so "o1x"
    // and "gpt-50" are rejected.
    FAMILIES.iter().any(|family| {
        normalized
            .strip_prefix(family)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('-'))
    })
}
159
/// A configured LLM backend: one variant per supported provider, each holding
/// that provider's `rig` client (or the native runtime).
#[derive(Debug, Clone)]
pub enum RigProvider {
    /// Anthropic client; reported as `"claude"` by `name()`.
    Claude(anthropic::Client),
    /// OpenAI client.
    OpenAI(openai::Client),
    /// Mistral client.
    Mistral(mistral::Client),
    /// Groq client.
    Groq(groq::Client),
    /// DeepSeek client. The vision entry points reject this variant.
    DeepSeek(deepseek::Client),
    /// Gemini client.
    Gemini(gemini::Client),
    /// xAI client.
    XAi(xai::Client),
    /// Native inference runtime; only compiled with the `native-inference` feature.
    #[cfg(feature = "native-inference")]
    Native(super::native::NativeRuntime),
}
186
187impl RigProvider {
188 pub fn from_name(name: &str) -> Result<Self, crate::error::NikaError> {
198 let provider = crate::core::find_provider(name).ok_or(
199 ProviderError::NotConfigured {
200 provider: name.to_string(),
201 }
202 )?;
203
204 if provider.requires_key && !provider.has_env_key() {
206 return Err(ProviderError::MissingApiKey {
207 provider: provider.id.to_string(),
208 }.into());
209 }
210
211 match provider.id {
212 "anthropic" => Ok(Self::claude()),
213 "openai" => Ok(Self::openai()),
214 "mistral" => Ok(Self::mistral()),
215 "groq" => Ok(Self::groq()),
216 "deepseek" => Ok(Self::deepseek()),
217 "gemini" => Ok(Self::gemini()),
218 "xai" => Ok(Self::xai()),
219 #[cfg(feature = "native-inference")]
220 "native" => Ok(Self::native()),
221 _ => Err(ProviderError::NotConfigured {
222 provider: name.to_string(),
223 }.into()),
224 }
225 }
226
227 pub fn claude() -> Self {
229 let client = anthropic::Client::from_env();
230 RigProvider::Claude(client)
231 }
232
233 pub fn openai() -> Self {
235 let client = openai::Client::from_env();
236 RigProvider::OpenAI(client)
237 }
238
239 pub fn mistral() -> Self {
241 let client = mistral::Client::from_env();
242 RigProvider::Mistral(client)
243 }
244
245 pub fn groq() -> Self {
247 let client = groq::Client::from_env();
248 RigProvider::Groq(client)
249 }
250
251 pub fn deepseek() -> Self {
253 let client = deepseek::Client::from_env();
254 RigProvider::DeepSeek(client)
255 }
256
257 pub fn gemini() -> Self {
259 let client = gemini::Client::from_env();
260 RigProvider::Gemini(client)
261 }
262
263 pub fn xai() -> Self {
265 let client = xai::Client::from_env();
266 RigProvider::XAi(client)
267 }
268
269 #[cfg(feature = "native-inference")]
278 pub fn native() -> Self {
279 RigProvider::Native(super::native::NativeRuntime::new())
280 }
281
282 #[cfg(feature = "native-inference")]
290 pub async fn load_native_model(
291 &mut self,
292 model_path: impl Into<std::path::PathBuf>,
293 config: Option<super::native::LoadConfig>,
294 ) -> Result<(), RigInferError> {
295 match self {
296 RigProvider::Native(runtime) => runtime
297 .load(model_path.into(), config.unwrap_or_default())
298 .await
299 .map_err(|e: super::native::NativeError| RigInferError::PromptError(e.to_string())),
300 _ => Err(RigInferError::PromptError(
301 "load_native_model only valid for Native provider".to_string(),
302 )),
303 }
304 }
305
306 #[cfg(feature = "native-inference")]
308 pub fn is_native_loaded(&self) -> bool {
309 match self {
310 RigProvider::Native(runtime) => runtime.is_loaded(),
311 _ => false,
312 }
313 }
314
315 pub fn name(&self) -> &'static str {
317 match self {
318 RigProvider::Claude(_) => "claude",
319 RigProvider::OpenAI(_) => "openai",
320 RigProvider::Mistral(_) => "mistral",
321 RigProvider::Groq(_) => "groq",
322 RigProvider::DeepSeek(_) => "deepseek",
323 RigProvider::Gemini(_) => "gemini",
324 RigProvider::XAi(_) => "xai",
325 #[cfg(feature = "native-inference")]
326 RigProvider::Native(_) => "native",
327 }
328 }
329
330 pub fn default_model(&self) -> &'static str {
342 match self {
343 RigProvider::Claude(_) => "claude-sonnet-4-6",
346 RigProvider::OpenAI(_) => openai::GPT_4O,
347 RigProvider::Mistral(_) => mistral::MISTRAL_LARGE,
348 RigProvider::Groq(_) => "llama-3.3-70b-versatile",
349 RigProvider::DeepSeek(_) => "deepseek-chat",
350 RigProvider::Gemini(_) => "gemini-2.0-flash",
351 RigProvider::XAi(_) => "grok-3-fast",
352 #[cfg(feature = "native-inference")]
354 RigProvider::Native(_) => "native-model",
355 }
356 }
357
358 pub async fn infer(&self, prompt: &str, model: Option<&str>) -> Result<String, RigInferError> {
367 const INFER_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
370
371 let model_id = model.unwrap_or_else(|| self.default_model());
372
373 match self {
374 RigProvider::Claude(client) => {
375 let agent = client.agent(model_id).max_tokens(8192).build();
377 timeout(INFER_TIMEOUT, agent.prompt(prompt))
378 .await
379 .map_err(|_| RigInferError::Timeout {
380 duration_ms: INFER_TIMEOUT.as_millis() as u64,
381 })?
382 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
383 }
384 RigProvider::OpenAI(client) => {
385 let agent = client.agent(model_id).max_tokens(8192).build();
386 timeout(INFER_TIMEOUT, agent.prompt(prompt))
387 .await
388 .map_err(|_| RigInferError::Timeout {
389 duration_ms: INFER_TIMEOUT.as_millis() as u64,
390 })?
391 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
392 }
393 RigProvider::Mistral(client) => {
394 let agent = client.agent(model_id).max_tokens(8192).build();
395 timeout(INFER_TIMEOUT, agent.prompt(prompt))
396 .await
397 .map_err(|_| RigInferError::Timeout {
398 duration_ms: INFER_TIMEOUT.as_millis() as u64,
399 })?
400 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
401 }
402 RigProvider::Groq(client) => {
403 let agent = client.agent(model_id).max_tokens(8192).build();
404 timeout(INFER_TIMEOUT, agent.prompt(prompt))
405 .await
406 .map_err(|_| RigInferError::Timeout {
407 duration_ms: INFER_TIMEOUT.as_millis() as u64,
408 })?
409 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
410 }
411 RigProvider::DeepSeek(client) => {
412 let agent = client.agent(model_id).max_tokens(8192).build();
413 timeout(INFER_TIMEOUT, agent.prompt(prompt))
414 .await
415 .map_err(|_| RigInferError::Timeout {
416 duration_ms: INFER_TIMEOUT.as_millis() as u64,
417 })?
418 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
419 }
420 RigProvider::Gemini(client) => {
421 let agent = client.agent(model_id).max_tokens(8192).build();
422 timeout(INFER_TIMEOUT, agent.prompt(prompt))
423 .await
424 .map_err(|_| RigInferError::Timeout {
425 duration_ms: INFER_TIMEOUT.as_millis() as u64,
426 })?
427 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
428 }
429 RigProvider::XAi(client) => {
430 let agent = client.agent(model_id).max_tokens(8192).build();
431 timeout(INFER_TIMEOUT, agent.prompt(prompt))
432 .await
433 .map_err(|_| RigInferError::Timeout {
434 duration_ms: INFER_TIMEOUT.as_millis() as u64,
435 })?
436 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
437 }
438 #[cfg(feature = "native-inference")]
439 RigProvider::Native(runtime) => {
440 timeout(
443 INFER_TIMEOUT,
444 runtime.infer(prompt, super::native::ChatOptions::default()),
445 )
446 .await
447 .map_err(|_| RigInferError::Timeout {
448 duration_ms: INFER_TIMEOUT.as_millis() as u64,
449 })?
450 .map(|r| r.message.content)
451 .map_err(|e: super::native::NativeError| RigInferError::PromptError(e.to_string()))
452 }
453 }
454 }
455
456 pub async fn infer_vision(
472 &self,
473 user_content: Vec<rig::completion::message::UserContent>,
474 model: Option<&str>,
475 system: Option<&str>,
476 max_tokens: Option<u32>,
477 ) -> Result<String, RigInferError> {
478 use rig::completion::message::Message;
479 use rig::OneOrMany;
480
481 if matches!(self, RigProvider::DeepSeek(_)) {
483 return Err(RigInferError::VisionNotSupported(
484 "DeepSeek does not support vision/multimodal content".to_string(),
485 ));
486 }
487
488 #[cfg(feature = "native-inference")]
490 if let RigProvider::Native(runtime) = self {
491 if !runtime.supports_vision() {
492 return Err(RigInferError::VisionNotSupported(
493 "Native model does not support vision. Load a vision model via \
494 NativeModelKind::VisionHf (e.g., `nika model vision <model_id> --isq Q4K`)"
495 .to_string(),
496 ));
497 }
498 let (prompt_text, vision_images) = extract_native_vision_parts(&user_content)?;
499 let options = super::native::ChatOptions {
500 max_tokens,
501 ..Default::default()
502 };
503 let response = runtime
504 .infer_vision(&prompt_text, vision_images, options)
505 .await
506 .map_err(|e: super::native::NativeError| {
507 RigInferError::PromptError(e.to_string())
508 })?;
509 return Ok(response.message.content);
510 }
511
512 let model_id = model.unwrap_or_else(|| self.default_model());
513 let max_tok = max_tokens.map(u64::from).unwrap_or(8192);
514
515 let message = Message::User {
516 content: OneOrMany::many(user_content).map_err(|_| {
517 RigInferError::VisionNotSupported("content parts list is empty".to_string())
518 })?,
519 };
520
521 macro_rules! vision_prompt {
522 ($client:expr) => {{
523 let mut builder = $client.agent(model_id).max_tokens(max_tok);
524 if let Some(sys) = system {
525 builder = builder.preamble(sys);
526 }
527 let agent = builder.build();
528 agent
529 .prompt(message)
530 .await
531 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
532 }};
533 }
534
535 match self {
536 RigProvider::Claude(client) => vision_prompt!(client),
537 RigProvider::OpenAI(client) => vision_prompt!(client),
538 RigProvider::Mistral(client) => vision_prompt!(client),
539 RigProvider::Groq(client) => vision_prompt!(client),
540 RigProvider::Gemini(client) => vision_prompt!(client),
541 RigProvider::XAi(client) => vision_prompt!(client),
542 RigProvider::DeepSeek(_) => unreachable!("DeepSeek handled above"),
544 #[cfg(feature = "native-inference")]
545 RigProvider::Native(_) => unreachable!("Native handled above"),
546 }
547 }
548
    /// Streaming variant of `infer_vision`: forwards chunks on `tx` while the
    /// response is produced and returns the accumulated text plus usage data.
    ///
    /// * `user_content` — content parts of the single user message.
    /// * `tx` — channel that receives `StreamChunk` events.
    /// * `model` — model id, defaulting to `default_model()`.
    /// * `system` — optional preamble (not forwarded on the native path).
    /// * `max_tokens` — output cap, defaulting to 8192 for rig-backed providers.
    ///
    /// # Errors
    ///
    /// * `RigInferError::VisionNotSupported` for DeepSeek, for a native model
    ///   without vision support, or when `user_content` is empty.
    /// * `RigInferError::PromptError` when opening or reading the stream fails.
    /// * `RigInferError::Timeout` when `consume_rig_stream` sees no chunk within
    ///   `STREAM_CHUNK_TIMEOUT`.
    pub async fn infer_vision_stream(
        &self,
        user_content: Vec<rig::completion::message::UserContent>,
        tx: mpsc::Sender<StreamChunk>,
        model: Option<&str>,
        system: Option<&str>,
        max_tokens: Option<u32>,
    ) -> Result<StreamResult, RigInferError> {
        use rig::completion::message::Message;
        use rig::OneOrMany;

        // DeepSeek is rejected before any content is inspected.
        if matches!(self, RigProvider::DeepSeek(_)) {
            return Err(RigInferError::VisionNotSupported(
                "DeepSeek does not support vision/multimodal content".to_string(),
            ));
        }

        // Native path: not actually streamed. The full response is awaited and
        // then delivered as a single `Done` chunk.
        #[cfg(feature = "native-inference")]
        if let RigProvider::Native(runtime) = self {
            if !runtime.supports_vision() {
                return Err(RigInferError::VisionNotSupported(
                    "Native model does not support vision. Load a vision model via \
                     NativeModelKind::VisionHf (e.g., `nika model vision <model_id> --isq Q4K`)"
                        .to_string(),
                ));
            }
            let (prompt_text, vision_images) = extract_native_vision_parts(&user_content)?;
            let options = super::native::ChatOptions {
                max_tokens,
                ..Default::default()
            };
            let response = runtime
                .infer_vision(&prompt_text, vision_images, options)
                .await
                .map_err(|e: super::native::NativeError| {
                    RigInferError::PromptError(e.to_string())
                })?;
            let text = response.message.content;
            // A closed receiver is only logged; the text is still returned to the caller.
            if let Err(e) = tx.send(StreamChunk::Done(text.clone())).await {
                tracing::warn!(error = %e, "Vision result channel closed — TUI may not show output");
            }
            // Token counts and TTFT stay at their defaults on this path.
            return Ok(StreamResult {
                text,
                ..Default::default()
            });
        }

        let model_id = model.unwrap_or_else(|| self.default_model());
        let max_tok = max_tokens.map(u64::from).unwrap_or(8192);

        // `OneOrMany::many` fails on an empty list; surfaced as VisionNotSupported.
        let message = Message::User {
            content: OneOrMany::many(user_content).map_err(|_| {
                RigInferError::VisionNotSupported("content parts list is empty".to_string())
            })?,
        };

        // Filled in by `consume_rig_stream` through the macro below.
        let mut response_parts: Vec<String> = Vec::new();
        let mut result = StreamResult::default();

        // Shared body for every rig client type. `$is_anthropic` is passed as
        // `capture_thinking`, so only the Claude arm forwards reasoning deltas.
        macro_rules! vision_stream {
            ($client:expr, $is_anthropic:expr) => {{
                let model = $client.completion_model(model_id);
                let mut builder = model.completion_request(message).max_tokens(max_tok);
                if let Some(sys) = system {
                    builder = builder.preamble(sys.to_string());
                }
                let request = builder.build();
                // Started before the stream is opened; used as the TTFT baseline.
                let stream_start = Instant::now();
                let mut stream = model
                    .stream(request)
                    .await
                    .map_err(|e| RigInferError::PromptError(e.to_string()))?;
                consume_rig_stream(
                    &mut stream,
                    &tx,
                    &mut response_parts,
                    &mut result,
                    $is_anthropic,
                    stream_start,
                )
                .await?;
            }};
        }

        match self {
            RigProvider::Claude(client) => vision_stream!(client, true),
            RigProvider::OpenAI(client) => vision_stream!(client, false),
            RigProvider::Mistral(client) => vision_stream!(client, false),
            RigProvider::Groq(client) => vision_stream!(client, false),
            RigProvider::Gemini(client) => vision_stream!(client, false),
            RigProvider::XAi(client) => vision_stream!(client, false),
            RigProvider::DeepSeek(_) => unreachable!("DeepSeek handled above"),
            #[cfg(feature = "native-inference")]
            RigProvider::Native(_) => unreachable!("Native handled above"),
        }

        // NOTE(review): unlike the native path, no `Done` chunk is sent here —
        // confirm the receiver does not rely on one.
        result.text = response_parts.join("");
        Ok(result)
    }
658
659 pub async fn infer_with_tools(
674 &self,
675 prompt: &str,
676 tools: Vec<Box<dyn ToolDyn>>,
677 model: Option<&str>,
678 max_tokens: Option<u32>,
679 system: Option<&str>,
680 ) -> Result<String, RigInferError> {
681 use rig::agent::AgentBuilder;
682 use rig::message::ToolChoice as RigToolChoice;
683
684 let model_id = model.unwrap_or_else(|| self.default_model());
685 let max_tok = max_tokens.map(|v| v as u64).unwrap_or(8192);
686
687 macro_rules! build_agent_with_tools {
688 ($client:expr) => {{
689 let mut builder = AgentBuilder::new($client.completion_model(model_id))
690 .tools(tools)
691 .tool_choice(RigToolChoice::Required)
692 .max_tokens(max_tok);
693 if let Some(sys) = system {
694 builder = builder.preamble(sys);
695 }
696 let agent = builder.build();
697 agent
698 .prompt(prompt)
699 .await
700 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
701 }};
702 }
703
704 match self {
705 RigProvider::Claude(client) => build_agent_with_tools!(client),
706 RigProvider::OpenAI(client) => build_agent_with_tools!(client),
707 RigProvider::Mistral(client) => build_agent_with_tools!(client),
708 RigProvider::Groq(client) => build_agent_with_tools!(client),
709 RigProvider::DeepSeek(client) => build_agent_with_tools!(client),
710 RigProvider::Gemini(client) => build_agent_with_tools!(client),
711 RigProvider::XAi(client) => build_agent_with_tools!(client),
712 #[cfg(feature = "native-inference")]
713 RigProvider::Native(_) => {
714 Err(RigInferError::PromptError(
716 "Native inference does not support tool-based structured output".to_string(),
717 ))
718 }
719 }
720 }
721
722 pub async fn infer_with_options(
742 &self,
743 prompt: &str,
744 options: &InferOptions,
745 ) -> Result<String, RigInferError> {
746 let model_id = options
747 .model
748 .as_deref()
749 .unwrap_or_else(|| self.default_model());
750 let max_tokens = options.max_tokens.unwrap_or(8192);
751
752 let effective_temperature = if options.temperature.is_some() && is_reasoning_model(model_id)
754 {
755 tracing::warn!(
756 model = %model_id,
757 "temperature ignored for reasoning model '{}' (not supported)",
758 model_id
759 );
760 None
761 } else {
762 options.temperature
763 };
764
765 let user_prompt = prompt.to_string();
767
768 match self {
769 RigProvider::Claude(client) => {
770 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
771 if let Some(system) = &options.system {
772 builder = builder.preamble(system);
773 }
774 if let Some(temp) = effective_temperature {
775 builder = builder.temperature(temp);
776 }
777 let agent = builder.build();
778 agent
779 .prompt(&user_prompt)
780 .await
781 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
782 }
783 RigProvider::OpenAI(client) => {
784 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
785 if let Some(system) = &options.system {
786 builder = builder.preamble(system);
787 }
788 if let Some(temp) = effective_temperature {
789 builder = builder.temperature(temp);
790 }
791 let agent = builder.build();
792 agent
793 .prompt(&user_prompt)
794 .await
795 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
796 }
797 RigProvider::Mistral(client) => {
798 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
799 if let Some(system) = &options.system {
800 builder = builder.preamble(system);
801 }
802 if let Some(temp) = effective_temperature {
803 builder = builder.temperature(temp);
804 }
805 let agent = builder.build();
806 agent
807 .prompt(&user_prompt)
808 .await
809 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
810 }
811 RigProvider::Groq(client) => {
812 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
813 if let Some(system) = &options.system {
814 builder = builder.preamble(system);
815 }
816 if let Some(temp) = effective_temperature {
817 builder = builder.temperature(temp);
818 }
819 let agent = builder.build();
820 agent
821 .prompt(&user_prompt)
822 .await
823 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
824 }
825 RigProvider::DeepSeek(client) => {
826 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
827 if let Some(system) = &options.system {
828 builder = builder.preamble(system);
829 }
830 if let Some(temp) = effective_temperature {
831 builder = builder.temperature(temp);
832 }
833 let agent = builder.build();
834 agent
835 .prompt(&user_prompt)
836 .await
837 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
838 }
839 RigProvider::Gemini(client) => {
840 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
841 if let Some(system) = &options.system {
842 builder = builder.preamble(system);
843 }
844 if let Some(temp) = effective_temperature {
845 builder = builder.temperature(temp);
846 }
847 let agent = builder.build();
848 agent
849 .prompt(&user_prompt)
850 .await
851 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
852 }
853 RigProvider::XAi(client) => {
854 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
855 if let Some(system) = &options.system {
856 builder = builder.preamble(system);
857 }
858 if let Some(temp) = effective_temperature {
859 builder = builder.temperature(temp);
860 }
861 let agent = builder.build();
862 agent
863 .prompt(&user_prompt)
864 .await
865 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
866 }
867 #[cfg(feature = "native-inference")]
868 RigProvider::Native(runtime) => {
869 let chat_options = super::native::ChatOptions {
871 temperature: effective_temperature.map(|t| t as f32),
872 max_tokens: options.max_tokens,
873 ..Default::default()
874 };
875 runtime
876 .infer(&user_prompt, chat_options)
877 .await
878 .map(|r| r.message.content)
879 .map_err(|e: super::native::NativeError| {
880 RigInferError::PromptError(e.to_string())
881 })
882 }
883 }
884 }
885
886 pub fn auto() -> Option<Self> {
900 use crate::core::providers::{ProviderCategory, KNOWN_PROVIDERS};
901
902 for p in KNOWN_PROVIDERS.iter() {
904 if p.category == ProviderCategory::Llm && p.has_env_key() {
905 return match p.id {
906 "anthropic" => Some(Self::claude()),
907 "openai" => Some(Self::openai()),
908 "mistral" => Some(Self::mistral()),
909 "groq" => Some(Self::groq()),
910 "deepseek" => Some(Self::deepseek()),
911 "gemini" => Some(Self::gemini()),
912 "xai" => Some(Self::xai()),
913 _ => continue,
914 };
915 }
916 }
917 #[cfg(feature = "native-inference")]
919 if std::env::var("NIKA_NATIVE_MODEL").is_ok_and(|v| !v.trim().is_empty()) {
920 return Some(Self::native());
921 }
922 None
923 }
924
925 pub async fn verify(&self) -> Result<ProviderVerifyResult, ProviderVerifyError> {
939 use std::time::Instant;
940
941 let start = Instant::now();
942
943 let test_prompt = "Hi";
945
946 match self.infer(test_prompt, None).await {
947 Ok(_) => Ok(ProviderVerifyResult {
948 provider: self.name().to_string(),
949 latency: start.elapsed(),
950 model: self.default_model().to_string(),
951 }),
952 Err(e) => {
953 let error_msg = e.to_string().to_lowercase();
954
955 if error_msg.contains("401")
957 || error_msg.contains("unauthorized")
958 || error_msg.contains("invalid api key")
959 || error_msg.contains("authentication")
960 {
961 Err(ProviderVerifyError::InvalidApiKey {
962 provider: self.name().to_string(),
963 })
964 } else if error_msg.contains("rate limit")
965 || error_msg.contains("429")
966 || error_msg.contains("too many requests")
967 {
968 Err(ProviderVerifyError::RateLimited {
969 provider: self.name().to_string(),
970 })
971 } else if error_msg.contains("timeout")
972 || error_msg.contains("timed out")
973 || error_msg.contains("deadline")
974 {
975 Err(ProviderVerifyError::Timeout {
976 provider: self.name().to_string(),
977 })
978 } else if error_msg.contains("connection")
979 || error_msg.contains("network")
980 || error_msg.contains("dns")
981 || error_msg.contains("refused")
982 {
983 Err(ProviderVerifyError::NetworkError {
984 provider: self.name().to_string(),
985 details: e.to_string(),
986 })
987 } else {
988 Err(ProviderVerifyError::ProviderError {
989 provider: self.name().to_string(),
990 details: e.to_string(),
991 })
992 }
993 }
994 }
995 }
996
997 pub fn is_configured(&self) -> bool {
1002 let has_key = |key: &str| std::env::var(key).is_ok_and(|v| !v.trim().is_empty());
1003
1004 match self {
1005 RigProvider::Claude(_) => has_key("ANTHROPIC_API_KEY"),
1006 RigProvider::OpenAI(_) => has_key("OPENAI_API_KEY"),
1007 RigProvider::Mistral(_) => has_key("MISTRAL_API_KEY"),
1008 RigProvider::Groq(_) => has_key("GROQ_API_KEY"),
1009 RigProvider::DeepSeek(_) => has_key("DEEPSEEK_API_KEY"),
1010 RigProvider::Gemini(_) => has_key("GEMINI_API_KEY"),
1011 RigProvider::XAi(_) => has_key("XAI_API_KEY"),
1012 #[cfg(feature = "native-inference")]
1013 RigProvider::Native(_) => {
1014 true
1017 }
1018 }
1019 }
1020}
1021
/// Successful outcome of `RigProvider::verify`.
#[derive(Debug, Clone)]
pub struct ProviderVerifyResult {
    /// Provider identifier as returned by `RigProvider::name()`.
    pub provider: String,
    /// Time elapsed between starting the check and receiving the response.
    pub latency: std::time::Duration,
    /// Model that answered the check (the provider's default model).
    pub model: String,
}
1036
/// Failure categories produced by `RigProvider::verify`.
///
/// `verify` assigns a category by searching the lower-cased error text for
/// known substrings, so the classification is heuristic.
#[derive(Debug, Clone, thiserror::Error)]
pub enum ProviderVerifyError {
    /// Error text mentioned 401 / unauthorized / invalid api key / authentication.
    #[error("Invalid API key for {provider}")]
    InvalidApiKey { provider: String },

    /// Error text mentioned rate limit / 429 / too many requests.
    #[error("Rate limited by {provider}")]
    RateLimited { provider: String },

    /// Error text mentioned timeout / timed out / deadline.
    #[error("Connection timeout to {provider}")]
    Timeout { provider: String },

    /// Error text mentioned connection / network / dns / refused; `details`
    /// carries the original error string.
    #[error("Network error connecting to {provider}: {details}")]
    NetworkError { provider: String, details: String },

    /// Any other failure; `details` carries the original error string.
    #[error("Provider error from {provider}: {details}")]
    ProviderError { provider: String, details: String },
}
1055
1056impl ProviderVerifyError {
1057 pub fn suggestion(&self) -> &'static str {
1059 match self {
1060 ProviderVerifyError::InvalidApiKey { .. } => {
1061 "Check your API key in environment variables"
1062 }
1063 ProviderVerifyError::RateLimited { .. } => {
1064 "Wait a moment and try again, or check your plan limits"
1065 }
1066 ProviderVerifyError::Timeout { .. } => "Check your network connection or try again",
1067 ProviderVerifyError::NetworkError { .. } => {
1068 "Check your internet connection and firewall settings"
1069 }
1070 ProviderVerifyError::ProviderError { .. } => {
1071 "The provider service may be experiencing issues"
1072 }
1073 }
1074 }
1075}
1076
/// Errors returned by the `RigProvider` inference methods.
#[derive(Debug, thiserror::Error)]
pub enum RigInferError {
    /// A provider or runtime call failed; the payload is the stringified
    /// underlying error.
    #[error("Completion error: {0}")]
    PromptError(String),

    /// A deadline elapsed. Raised by `consume_rig_stream` when no chunk arrives
    /// within `STREAM_CHUNK_TIMEOUT`, and by `infer` when its whole-request
    /// timeout expires (the message text says "stream" in both cases).
    #[error("Stream timeout: no chunk received for {duration_ms}ms")]
    Timeout { duration_ms: u64 },

    /// The provider or model cannot handle multimodal content; also used when
    /// the content-parts list passed to a vision method is empty.
    #[error("Vision not supported: {0}")]
    VisionNotSupported(String),
}
1091
/// Event sent over the `mpsc` channel while work is in progress.
///
/// Only `Token`, `Thinking`, `Done` and `Error` are produced in the visible
/// part of this file; the remaining variants are emitted elsewhere, and the
/// group headings below are inferred from their names — confirm against the
/// senders.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// One text delta from the model (sent by `consume_rig_stream`).
    Token(String),
    /// One reasoning delta; forwarded only when thinking capture is enabled.
    Thinking(String),
    /// Complete response text (sent by the native vision path).
    Done(String),
    /// Stream failure or chunk timeout, as a display string.
    Error(String),
    /// Token usage counts.
    Metrics {
        input_tokens: u64,
        output_tokens: u64,
    },

    // --- MCP events (presumably server connection and tool-call lifecycle) ---
    McpConnected(String),
    McpError { server_name: String, error: String },
    McpCallStart {
        tool: String,
        server: String,
        params: String,
    },
    McpCallComplete { result: String },
    McpCallFailed { error: String },

    // --- Inference lifecycle events ---
    InferStart {
        model: String,
        prompt: String,
        prompt_tokens: u32,
        max_tokens: u32,
    },
    InferTokens { output_tokens: u32 },
    InferComplete,

    // --- Exec / fetch / agent lifecycle events ---
    ExecStart { command: String },
    ExecComplete,
    FetchStart { url: String, method: String },
    FetchComplete,
    AgentStart { goal: String },
    AgentComplete,

    // --- Provider verification events ---
    ProviderVerifying { provider: String, model: String },
    ProviderVerified {
        provider: String,
        model: String,
        latency_ms: u64,
    },
    ProviderVerifyFailed { provider: String, error: String },
    ProviderNotConfigured { provider: String },

    // --- MCP ping events ---
    McpPinging { server: String },
    McpPinged {
        server: String,
        latency_ms: u64,
        tool_count: usize,
    },
    ProviderVerificationTimeout,

    // --- Native model management events ---
    NativeModelPullStarted { model: String },
    NativeModelPullProgress {
        model: String,
        status: String,
        completed: u64,
        total: u64,
    },
    NativeModelPulled {
        model: String,
        path: String,
        size: u64,
    },
    NativeModelPullFailed { model: String, error: String },
    NativeModelDeleted { model: String },
    NativeModelDeleteFailed { model: String, error: String },
    NativeModelsRefreshed { count: usize },
}
1208
/// Aggregated outcome of a streamed completion.
///
/// `Default` yields empty text, zero counters and no optional metadata.
#[derive(Debug, Clone, Default)]
pub struct StreamResult {
    /// Full response text.
    pub text: String,
    /// Input (prompt) tokens reported by the provider.
    pub input_tokens: u64,
    /// Output (completion) tokens reported by the provider.
    pub output_tokens: u64,
    /// Total tokens reported by the provider.
    pub total_tokens: u64,
    /// Cached input tokens reported by the provider.
    pub cached_input_tokens: u64,
    /// Milliseconds from stream start to the first text chunk, when one arrived.
    pub ttft_ms: Option<u64>,
    /// Provider request identifier, when available.
    pub request_id: Option<String>,
}

impl StreamResult {
    /// Builds a result that carries only `text`; every other field keeps its default.
    pub fn from_text(text: impl Into<String>) -> Self {
        let text = text.into();
        Self {
            text,
            ..Self::default()
        }
    }
}
1241
1242async fn consume_rig_stream<R>(
1252 stream: &mut rig::streaming::StreamingCompletionResponse<R>,
1253 tx: &mpsc::Sender<StreamChunk>,
1254 response_parts: &mut Vec<String>,
1255 result: &mut StreamResult,
1256 capture_thinking: bool,
1257 stream_start: Instant,
1258) -> Result<(), RigInferError>
1259where
1260 R: Clone + Unpin + GetTokenUsage + serde::Serialize + serde::de::DeserializeOwned,
1261{
1262 loop {
1263 let chunk_result = match timeout(STREAM_CHUNK_TIMEOUT, stream.next()).await {
1264 Ok(Some(result)) => result,
1265 Ok(None) => break,
1266 Err(_elapsed) => {
1267 let _ = tx.try_send(StreamChunk::Error(format!(
1268 "Stream timeout: no chunk received for {}s",
1269 STREAM_CHUNK_TIMEOUT.as_secs()
1270 )));
1271 return Err(RigInferError::Timeout {
1272 duration_ms: STREAM_CHUNK_TIMEOUT.as_millis() as u64,
1273 });
1274 }
1275 };
1276
1277 match chunk_result {
1278 Ok(content) => match content {
1279 StreamedAssistantContent::Text(text) => {
1280 if result.ttft_ms.is_none() {
1282 result.ttft_ms = Some(stream_start.elapsed().as_millis() as u64);
1283 }
1284 response_parts.push(text.text.clone());
1285 let _ = tx.try_send(StreamChunk::Token(text.text));
1286 }
1287 StreamedAssistantContent::ReasoningDelta { reasoning, .. } if capture_thinking => {
1288 let _ = tx.try_send(StreamChunk::Thinking(reasoning));
1289 }
1290 StreamedAssistantContent::Final(response) => {
1291 if let Some(usage) = response.token_usage() {
1292 result.input_tokens = usage.input_tokens;
1293 result.output_tokens = usage.output_tokens;
1294 result.total_tokens = usage.total_tokens;
1295 result.cached_input_tokens = usage.cached_input_tokens;
1296 }
1297 }
1298 _ => {}
1299 },
1300 Err(e) => {
1301 let _ = tx.try_send(StreamChunk::Error(e.to_string()));
1302 return Err(RigInferError::PromptError(e.to_string()));
1303 }
1304 }
1305 }
1306 Ok(())
1307}
1308
1309impl RigProvider {
1310 pub async fn infer_stream(
1322 &self,
1323 prompt: &str,
1324 tx: mpsc::Sender<StreamChunk>,
1325 model: Option<&str>,
1326 ) -> Result<StreamResult, RigInferError> {
1327 let model_id = model.unwrap_or_else(|| self.default_model());
1328 let mut response_parts: Vec<String> = Vec::new();
1329 let mut result = StreamResult::default();
1330
1331 match self {
1332 RigProvider::Claude(client) => {
1333 let model = client.completion_model(model_id);
1334 let request = model.completion_request(prompt).max_tokens(8192).build();
1335 let stream_start = Instant::now();
1336 let mut stream = model
1337 .stream(request)
1338 .await
1339 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1340 consume_rig_stream(
1341 &mut stream,
1342 &tx,
1343 &mut response_parts,
1344 &mut result,
1345 true,
1346 stream_start,
1347 )
1348 .await?;
1349 }
1350 RigProvider::OpenAI(client) => {
1351 let model = client.completion_model(model_id);
1352 let request = model.completion_request(prompt).max_tokens(8192).build();
1353 let stream_start = Instant::now();
1354 let mut stream = model
1355 .stream(request)
1356 .await
1357 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1358 consume_rig_stream(
1359 &mut stream,
1360 &tx,
1361 &mut response_parts,
1362 &mut result,
1363 false,
1364 stream_start,
1365 )
1366 .await?;
1367 }
1368 RigProvider::Mistral(client) => {
1369 let model = client.completion_model(model_id);
1370 let request = model.completion_request(prompt).max_tokens(8192).build();
1371 let stream_start = Instant::now();
1372 let mut stream = model
1373 .stream(request)
1374 .await
1375 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1376 consume_rig_stream(
1377 &mut stream,
1378 &tx,
1379 &mut response_parts,
1380 &mut result,
1381 false,
1382 stream_start,
1383 )
1384 .await?;
1385 }
1386 RigProvider::Groq(client) => {
1387 let model = client.completion_model(model_id);
1388 let request = model.completion_request(prompt).max_tokens(8192).build();
1389 let stream_start = Instant::now();
1390 let mut stream = model
1391 .stream(request)
1392 .await
1393 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1394 consume_rig_stream(
1395 &mut stream,
1396 &tx,
1397 &mut response_parts,
1398 &mut result,
1399 false,
1400 stream_start,
1401 )
1402 .await?;
1403 }
1404 RigProvider::DeepSeek(client) => {
1405 let model = client.completion_model(model_id);
1406 let request = model.completion_request(prompt).max_tokens(8192).build();
1407 let stream_start = Instant::now();
1408 let mut stream = model
1409 .stream(request)
1410 .await
1411 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1412 consume_rig_stream(
1413 &mut stream,
1414 &tx,
1415 &mut response_parts,
1416 &mut result,
1417 false,
1418 stream_start,
1419 )
1420 .await?;
1421 }
1422 RigProvider::Gemini(client) => {
1423 let model = client.completion_model(model_id);
1424 let request = model.completion_request(prompt).max_tokens(8192).build();
1425 let stream_start = Instant::now();
1426 let mut stream = model
1427 .stream(request)
1428 .await
1429 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1430 consume_rig_stream(
1431 &mut stream,
1432 &tx,
1433 &mut response_parts,
1434 &mut result,
1435 false,
1436 stream_start,
1437 )
1438 .await?;
1439 }
1440 RigProvider::XAi(client) => {
1441 let model = client.completion_model(model_id);
1442 let request = model.completion_request(prompt).max_tokens(8192).build();
1443 let stream_start = Instant::now();
1444 let mut stream = model
1445 .stream(request)
1446 .await
1447 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1448 consume_rig_stream(
1449 &mut stream,
1450 &tx,
1451 &mut response_parts,
1452 &mut result,
1453 false,
1454 stream_start,
1455 )
1456 .await?;
1457 }
1458 #[cfg(feature = "native-inference")]
1460 RigProvider::Native(runtime) => {
1461 use futures::StreamExt;
1462 use std::pin::pin;
1463
1464 let stream = runtime
1466 .infer_stream(prompt, super::native::ChatOptions::default())
1467 .await
1468 .map_err(|e: super::native::NativeError| {
1469 RigInferError::PromptError(e.to_string())
1470 })?;
1471
1472 let mut stream = pin!(stream);
1474
1475 while let Some(result) = stream.next().await {
1477 match result {
1478 Ok(token) => {
1479 response_parts.push(token.clone());
1480 let _ = tx.try_send(StreamChunk::Token(token));
1481 }
1482 Err(e) => {
1483 let _ = tx.try_send(StreamChunk::Error(e.to_string()));
1484 return Err(RigInferError::PromptError(e.to_string()));
1485 }
1486 }
1487 }
1488
1489 }
1492 }
1493
1494 let complete_response = response_parts.concat();
1495 let _ = tx.try_send(StreamChunk::Done(complete_response.clone()));
1496
1497 let _ = tx.try_send(StreamChunk::Metrics {
1499 input_tokens: result.input_tokens,
1500 output_tokens: result.output_tokens,
1501 });
1502
1503 result.text = complete_response;
1504 Ok(result)
1505 }
1506
1507 pub async fn infer_stream_with_options(
1520 &self,
1521 prompt: &str,
1522 tx: mpsc::Sender<StreamChunk>,
1523 options: &InferOptions,
1524 ) -> Result<StreamResult, RigInferError> {
1525 let model_id = options
1526 .model
1527 .as_deref()
1528 .unwrap_or_else(|| self.default_model());
1529 let max_tokens = options.max_tokens.unwrap_or(8192);
1530 let mut response_parts: Vec<String> = Vec::new();
1531 let mut result = StreamResult::default();
1532
1533 let effective_temperature = if options.temperature.is_some() && is_reasoning_model(model_id)
1535 {
1536 tracing::warn!(
1537 model = %model_id,
1538 "temperature ignored for reasoning model '{}' (not supported)",
1539 model_id
1540 );
1541 None
1542 } else {
1543 options.temperature
1544 };
1545
1546 macro_rules! build_request_with_options {
1550 ($client:expr) => {{
1551 let model = $client.completion_model(model_id);
1552 let mut rb = model
1553 .completion_request(prompt)
1554 .max_tokens(max_tokens as u64);
1555 if let Some(ref system) = options.system {
1556 rb = rb.preamble(system.clone());
1557 }
1558 if let Some(temp) = effective_temperature {
1559 rb = rb.temperature(temp);
1560 }
1561 model
1562 .stream(rb.build())
1563 .await
1564 .map_err(|e| RigInferError::PromptError(e.to_string()))?
1565 }};
1566 }
1567
1568 match self {
1569 RigProvider::Claude(client) => {
1570 let stream_start = Instant::now();
1571 let mut stream = build_request_with_options!(client);
1572 consume_rig_stream(
1573 &mut stream,
1574 &tx,
1575 &mut response_parts,
1576 &mut result,
1577 true,
1578 stream_start,
1579 )
1580 .await?;
1581 }
1582 RigProvider::OpenAI(client) => {
1583 let stream_start = Instant::now();
1584 let mut stream = build_request_with_options!(client);
1585 consume_rig_stream(
1586 &mut stream,
1587 &tx,
1588 &mut response_parts,
1589 &mut result,
1590 false,
1591 stream_start,
1592 )
1593 .await?;
1594 }
1595 RigProvider::Mistral(client) => {
1596 let stream_start = Instant::now();
1597 let mut stream = build_request_with_options!(client);
1598 consume_rig_stream(
1599 &mut stream,
1600 &tx,
1601 &mut response_parts,
1602 &mut result,
1603 false,
1604 stream_start,
1605 )
1606 .await?;
1607 }
1608 RigProvider::Groq(client) => {
1609 let stream_start = Instant::now();
1610 let mut stream = build_request_with_options!(client);
1611 consume_rig_stream(
1612 &mut stream,
1613 &tx,
1614 &mut response_parts,
1615 &mut result,
1616 false,
1617 stream_start,
1618 )
1619 .await?;
1620 }
1621 RigProvider::DeepSeek(client) => {
1622 let stream_start = Instant::now();
1623 let mut stream = build_request_with_options!(client);
1624 consume_rig_stream(
1625 &mut stream,
1626 &tx,
1627 &mut response_parts,
1628 &mut result,
1629 false,
1630 stream_start,
1631 )
1632 .await?;
1633 }
1634 RigProvider::Gemini(client) => {
1635 let stream_start = Instant::now();
1636 let mut stream = build_request_with_options!(client);
1637 consume_rig_stream(
1638 &mut stream,
1639 &tx,
1640 &mut response_parts,
1641 &mut result,
1642 false,
1643 stream_start,
1644 )
1645 .await?;
1646 }
1647 RigProvider::XAi(client) => {
1648 let stream_start = Instant::now();
1649 let mut stream = build_request_with_options!(client);
1650 consume_rig_stream(
1651 &mut stream,
1652 &tx,
1653 &mut response_parts,
1654 &mut result,
1655 false,
1656 stream_start,
1657 )
1658 .await?;
1659 }
1660 #[cfg(feature = "native-inference")]
1662 RigProvider::Native(runtime) => {
1663 use futures::StreamExt;
1664 use std::pin::pin;
1665
1666 let native_prompt = if let Some(ref system) = options.system {
1668 format!("{}\n\n{}", system, prompt)
1669 } else {
1670 prompt.to_string()
1671 };
1672 let chat_options = super::native::ChatOptions {
1673 temperature: effective_temperature.map(|t| t as f32),
1674 max_tokens: options.max_tokens,
1675 ..Default::default()
1676 };
1677 let stream = runtime
1678 .infer_stream(&native_prompt, chat_options)
1679 .await
1680 .map_err(|e: super::native::NativeError| {
1681 RigInferError::PromptError(e.to_string())
1682 })?;
1683
1684 let mut stream = pin!(stream);
1686
1687 while let Some(result) = stream.next().await {
1689 match result {
1690 Ok(token) => {
1691 response_parts.push(token.clone());
1692 let _ = tx.try_send(StreamChunk::Token(token));
1693 }
1694 Err(e) => {
1695 let _ = tx.try_send(StreamChunk::Error(e.to_string()));
1696 return Err(RigInferError::PromptError(e.to_string()));
1697 }
1698 }
1699 }
1700
1701 }
1704 }
1705
1706 let complete_response = response_parts.concat();
1707 let _ = tx.try_send(StreamChunk::Done(complete_response.clone()));
1708
1709 let _ = tx.try_send(StreamChunk::Metrics {
1710 input_tokens: result.input_tokens,
1711 output_tokens: result.output_tokens,
1712 });
1713
1714 result.text = complete_response;
1715 Ok(result)
1716 }
1717}
1718
/// Static description of an MCP tool: the data needed to advertise it to a model.
#[derive(Debug, Clone)]
pub struct NikaMcpToolDef {
    /// Tool name; also the name passed to the MCP client when the tool is called.
    pub name: String,
    /// Human-readable description, copied into the rig `ToolDefinition`.
    pub description: String,
    /// Schema for the tool arguments, exposed as `parameters` of the rig `ToolDefinition`.
    pub input_schema: serde_json::Value,
}
1736
/// Shared concurrent map of content blocks, keyed by a `String`.
///
/// `NikaMcpTool::call` uses the tool name as the key and appends the media
/// blocks returned by that tool.
pub type AgentMediaStaging = Arc<dashmap::DashMap<String, Vec<crate::mcp::types::ContentBlock>>>;
1744
/// Adapter exposing one MCP tool through rig's `ToolDyn` interface.
#[derive(Debug, Clone)]
pub struct NikaMcpTool {
    /// Name, description and argument schema advertised for the tool.
    definition: NikaMcpToolDef,
    /// Client used to execute the tool; when `None`, `call` returns a "not configured" error.
    client: Option<Arc<McpClient>>,
    /// Optional staging map that receives media blocks returned by tool calls.
    media_staging: Option<AgentMediaStaging>,
}
1760
1761impl NikaMcpTool {
1762 pub fn new(definition: NikaMcpToolDef) -> Self {
1764 Self {
1765 definition,
1766 client: None,
1767 media_staging: None,
1768 }
1769 }
1770
1771 pub fn with_client(definition: NikaMcpToolDef, client: Arc<McpClient>) -> Self {
1773 Self {
1774 definition,
1775 client: Some(client),
1776 media_staging: None,
1777 }
1778 }
1779
1780 pub fn with_media_staging(
1782 definition: NikaMcpToolDef,
1783 client: Arc<McpClient>,
1784 staging: AgentMediaStaging,
1785 ) -> Self {
1786 Self {
1787 definition,
1788 client: Some(client),
1789 media_staging: Some(staging),
1790 }
1791 }
1792
1793 pub fn tool_name(&self) -> &str {
1795 &self.definition.name
1796 }
1797}
1798
/// Pinned, heap-allocated, `Send` future with lifetime `'a`; the return type of the `ToolDyn` methods below.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
1801
1802impl ToolDyn for NikaMcpTool {
1803 fn name(&self) -> String {
1804 self.definition.name.clone()
1805 }
1806
1807 fn definition(&self, _prompt: String) -> BoxFuture<'_, ToolDefinition> {
1808 let def = ToolDefinition {
1809 name: self.definition.name.clone(),
1810 description: self.definition.description.clone(),
1811 parameters: self.definition.input_schema.clone(),
1812 };
1813 Box::pin(async move { def })
1814 }
1815
1816 fn call(&self, args: String) -> BoxFuture<'_, Result<String, ToolError>> {
1817 let tool_name = self.definition.name.clone();
1818 let client = self.client.clone();
1819
1820 Box::pin(async move {
1821 let params: serde_json::Value = serde_json::from_str(&args).map_err(|e| {
1823 ToolError::ToolCallError(Box::new(McpToolError::invalid_args(format!(
1824 "Invalid JSON arguments: {}",
1825 e
1826 ))))
1827 })?;
1828
1829 let client = client.ok_or_else(|| {
1831 ToolError::ToolCallError(Box::new(McpToolError::not_configured(
1832 "No MCP client configured for this tool",
1833 )))
1834 })?;
1835
1836 let result = client.call_tool(&tool_name, params).await.map_err(|e| {
1838 ToolError::ToolCallError(Box::new(McpToolError::call_failed(format!(
1839 "MCP tool call failed: {}",
1840 e
1841 ))))
1842 })?;
1843
1844 if result.has_media() {
1846 if let Some(ref staging) = self.media_staging {
1847 let media_blocks: Vec<_> = result.media_blocks().into_iter().cloned().collect();
1848 if !media_blocks.is_empty() {
1849 tracing::debug!(
1850 tool = %tool_name,
1851 media_count = media_blocks.len(),
1852 "agent: staging binary content from tool call"
1853 );
1854 staging
1855 .entry(tool_name.clone())
1856 .or_default()
1857 .extend(media_blocks);
1858 }
1859 } else {
1860 tracing::warn!(
1861 tool = %tool_name,
1862 media_count = result.media_blocks().len(),
1863 "agent: tool returned binary content but no media staging configured — data will be lost"
1864 );
1865 }
1866 }
1867
1868 let output = result.text();
1870
1871 if output.is_empty() {
1872 serde_json::to_string(&result).map_err(|e| {
1874 ToolError::ToolCallError(Box::new(McpToolError::serialization(format!(
1875 "Failed to serialize result: {}",
1876 e
1877 ))))
1878 })
1879 } else {
1880 Ok(output)
1881 }
1882 })
1883 }
1884}
1885
/// Splits user content into prompt text and decoded images for the native vision backend.
///
/// Text parts are joined with `\n`. Image parts are turned into `VisionImage`
/// values holding raw bytes and a MIME string; content of any other kind is
/// ignored.
///
/// # Errors
///
/// * `RigInferError::PromptError` when base64 image data cannot be decoded or
///   the image source kind is not handled here.
/// * `RigInferError::VisionNotSupported` for URL image sources.
#[cfg(feature = "native-inference")]
fn extract_native_vision_parts(
    user_content: &[rig::completion::message::UserContent],
) -> Result<(String, Vec<crate::core::backend::VisionImage>), RigInferError> {
    use base64::Engine as _;
    use rig::completion::message::{DocumentSourceKind, Image, ImageMediaType, UserContent};

    let mut texts: Vec<String> = Vec::new();
    let mut images: Vec<crate::core::backend::VisionImage> = Vec::new();

    for part in user_content {
        let (data, media_type) = match part {
            UserContent::Text(text) => {
                texts.push(text.text.clone());
                continue;
            }
            UserContent::Image(Image {
                data, media_type, ..
            }) => (data, media_type),
            _ => continue,
        };

        let bytes = match data {
            DocumentSourceKind::Base64(b64) => base64::engine::general_purpose::STANDARD
                .decode(b64)
                .map_err(|e| {
                    RigInferError::PromptError(format!(
                        "Failed to decode base64 image for native vision: {}",
                        e
                    ))
                })?,
            DocumentSourceKind::Raw(raw) => raw.clone(),
            DocumentSourceKind::Url(url) => {
                return Err(RigInferError::VisionNotSupported(format!(
                    "Native vision does not support URL images. Pre-fetch the image: {}",
                    url
                )));
            }
            _ => {
                return Err(RigInferError::PromptError(
                    "Unsupported image source kind for native vision".to_string(),
                ));
            }
        };

        // A missing or unlisted media type is labelled as PNG.
        let mime = match media_type {
            Some(ImageMediaType::JPEG) => "image/jpeg",
            Some(ImageMediaType::PNG) => "image/png",
            Some(ImageMediaType::GIF) => "image/gif",
            Some(ImageMediaType::WEBP) => "image/webp",
            _ => "image/png",
        };

        images.push(crate::core::backend::VisionImage::new(bytes, mime));
    }

    Ok((texts.join("\n"), images))
}
1960
1961#[cfg(test)]
1962mod tests {
1963 use super::*;
1964 use serial_test::serial;
1965
1966 #[test]
1971 fn stream_result_from_text_has_zero_tokens() {
1972 let result = StreamResult::from_text("hello world");
1973 assert_eq!(result.text, "hello world");
1974 assert_eq!(result.input_tokens, 0);
1975 assert_eq!(result.output_tokens, 0);
1976 assert_eq!(result.total_tokens, 0);
1977 assert_eq!(result.cached_input_tokens, 0);
1978 }
1979
1980 #[test]
1981 fn stream_result_default_is_empty() {
1982 let result = StreamResult::default();
1983 assert_eq!(result.text, "");
1984 assert_eq!(result.total_tokens, 0);
1985 }
1986
1987 #[test]
1988 fn stream_result_with_tokens() {
1989 let result = StreamResult {
1990 text: "response".to_string(),
1991 input_tokens: 100,
1992 output_tokens: 50,
1993 total_tokens: 150,
1994 cached_input_tokens: 20,
1995 ttft_ms: None,
1996 request_id: None,
1997 };
1998 assert_eq!(
1999 result.total_tokens,
2000 result.input_tokens + result.output_tokens
2001 );
2002 assert_eq!(result.cached_input_tokens, 20);
2003 }
2004
2005 #[test]
2006 #[serial]
2007 fn test_rig_provider_claude_returns_claude_variant() {
2008 std::env::set_var("ANTHROPIC_API_KEY", "test-key-for-unit-test");
2014 let provider = RigProvider::claude();
2015
2016 assert_eq!(provider.name(), "claude");
2017 assert!(matches!(provider, RigProvider::Claude(_)));
2018 }
2019
2020 #[test]
2021 #[serial]
2022 fn test_rig_provider_openai_returns_openai_variant() {
2023 std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
2024 let provider = RigProvider::openai();
2025
2026 assert_eq!(provider.name(), "openai");
2027 assert!(matches!(provider, RigProvider::OpenAI(_)));
2028 }
2029
2030 #[test]
2031 #[serial]
2032 fn test_rig_provider_default_model_claude() {
2033 std::env::set_var("ANTHROPIC_API_KEY", "test-key-for-unit-test");
2034 let provider = RigProvider::claude();
2035
2036 assert_eq!(provider.default_model(), "claude-sonnet-4-6");
2039 }
2040
2041 #[test]
2042 #[serial]
2043 fn test_rig_provider_default_model_openai() {
2044 std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
2045 let provider = RigProvider::openai();
2046
2047 assert_eq!(provider.default_model(), openai::GPT_4O);
2048 }
2049
2050 #[test]
2051 fn test_rig_infer_error_display() {
2052 let err = RigInferError::PromptError("Test error message".to_string());
2053 assert_eq!(err.to_string(), "Completion error: Test error message");
2054 }
2055
2056 #[test]
2057 fn test_rig_infer_error_timeout_display() {
2058 let err = RigInferError::Timeout { duration_ms: 60000 };
2060 assert_eq!(
2061 err.to_string(),
2062 "Stream timeout: no chunk received for 60000ms"
2063 );
2064 }
2065
2066 #[test]
2071 #[serial]
2072 fn test_rig_provider_mistral_returns_mistral_variant() {
2073 std::env::set_var("MISTRAL_API_KEY", "test-key-for-unit-test");
2074 let provider = RigProvider::mistral();
2075
2076 assert_eq!(provider.name(), "mistral");
2077 assert!(matches!(provider, RigProvider::Mistral(_)));
2078 }
2079
2080 #[test]
2081 #[serial]
2082 fn test_rig_provider_groq_returns_groq_variant() {
2083 std::env::set_var("GROQ_API_KEY", "test-key-for-unit-test");
2084 let provider = RigProvider::groq();
2085
2086 assert_eq!(provider.name(), "groq");
2087 assert!(matches!(provider, RigProvider::Groq(_)));
2088 }
2089
2090 #[test]
2091 #[serial]
2092 fn test_rig_provider_deepseek_returns_deepseek_variant() {
2093 std::env::set_var("DEEPSEEK_API_KEY", "test-key-for-unit-test");
2094 let provider = RigProvider::deepseek();
2095
2096 assert_eq!(provider.name(), "deepseek");
2097 assert!(matches!(provider, RigProvider::DeepSeek(_)));
2098 }
2099
2100 #[test]
2101 #[serial]
2102 fn test_rig_provider_default_models_v06() {
2103 std::env::set_var("MISTRAL_API_KEY", "test");
2105 std::env::set_var("GROQ_API_KEY", "test");
2106 std::env::set_var("DEEPSEEK_API_KEY", "test");
2107
2108 assert_eq!(
2109 RigProvider::mistral().default_model(),
2110 mistral::MISTRAL_LARGE
2111 );
2112 assert_eq!(
2113 RigProvider::groq().default_model(),
2114 "llama-3.3-70b-versatile"
2115 );
2116 assert_eq!(RigProvider::deepseek().default_model(), "deepseek-chat");
2117 }
2118
2119 #[test]
2120 #[serial]
2121 fn test_rig_provider_auto_detects_claude() {
2122 std::env::remove_var("OPENAI_API_KEY");
2124 std::env::remove_var("MISTRAL_API_KEY");
2125 std::env::remove_var("GROQ_API_KEY");
2126 std::env::remove_var("DEEPSEEK_API_KEY");
2127 std::env::set_var("ANTHROPIC_API_KEY", "test-key");
2128
2129 let provider = RigProvider::auto();
2130 assert!(provider.is_some());
2131 assert_eq!(provider.unwrap().name(), "claude");
2132 }
2133
2134 #[test]
2135 #[serial]
2136 fn test_rig_provider_auto_returns_none_when_no_keys() {
2137 clear_all_provider_env_vars();
2139
2140 let provider = RigProvider::auto();
2141 assert!(provider.is_none());
2142 }
2143
2144 fn clear_all_provider_env_vars() {
2150 std::env::remove_var("ANTHROPIC_API_KEY");
2151 std::env::remove_var("OPENAI_API_KEY");
2152 std::env::remove_var("MISTRAL_API_KEY");
2153 std::env::remove_var("GROQ_API_KEY");
2154 std::env::remove_var("DEEPSEEK_API_KEY");
2155 std::env::remove_var("GEMINI_API_KEY");
2156 }
2157
2158 #[test]
2159 #[serial]
2160 fn test_auto_fallback_to_openai() {
2161 clear_all_provider_env_vars();
2163 std::env::set_var("OPENAI_API_KEY", "test-key");
2164
2165 let provider = RigProvider::auto();
2167
2168 assert!(provider.is_some());
2170 assert_eq!(provider.unwrap().name(), "openai");
2171 }
2172
2173 #[test]
2174 #[serial]
2175 fn test_auto_fallback_to_mistral() {
2176 clear_all_provider_env_vars();
2178 std::env::set_var("MISTRAL_API_KEY", "test-key");
2179
2180 let provider = RigProvider::auto();
2182
2183 assert!(provider.is_some());
2185 assert_eq!(provider.unwrap().name(), "mistral");
2186 }
2187
2188 #[test]
2189 #[serial]
2190 fn test_auto_fallback_to_groq() {
2191 clear_all_provider_env_vars();
2193 std::env::set_var("GROQ_API_KEY", "test-key");
2194
2195 let provider = RigProvider::auto();
2197
2198 assert!(provider.is_some());
2200 assert_eq!(provider.unwrap().name(), "groq");
2201 }
2202
2203 #[test]
2204 #[serial]
2205 fn test_auto_fallback_to_deepseek() {
2206 clear_all_provider_env_vars();
2208 std::env::set_var("DEEPSEEK_API_KEY", "test-key");
2209
2210 let provider = RigProvider::auto();
2212
2213 assert!(provider.is_some());
2215 assert_eq!(provider.unwrap().name(), "deepseek");
2216 }
2217
2218 #[test]
2219 #[serial]
2220 fn test_auto_fallback_to_gemini() {
2221 clear_all_provider_env_vars();
2223 std::env::set_var("GEMINI_API_KEY", "test-key");
2224
2225 let provider = RigProvider::auto();
2227
2228 assert!(provider.is_some());
2230 assert_eq!(provider.unwrap().name(), "gemini");
2231 }
2232
2233 #[test]
2234 #[serial]
2235 fn test_auto_priority_claude_over_openai() {
2236 clear_all_provider_env_vars();
2238 std::env::set_var("ANTHROPIC_API_KEY", "claude-key");
2239 std::env::set_var("OPENAI_API_KEY", "openai-key");
2240
2241 let provider = RigProvider::auto();
2243
2244 assert!(provider.is_some());
2246 assert_eq!(provider.unwrap().name(), "claude");
2247 }
2248
2249 #[test]
2250 #[serial]
2251 fn test_auto_priority_openai_over_mistral() {
2252 clear_all_provider_env_vars();
2254 std::env::set_var("OPENAI_API_KEY", "openai-key");
2255 std::env::set_var("MISTRAL_API_KEY", "mistral-key");
2256
2257 let provider = RigProvider::auto();
2259
2260 assert!(provider.is_some());
2262 assert_eq!(provider.unwrap().name(), "openai");
2263 }
2264
2265 #[test]
2266 #[serial]
2267 fn test_auto_empty_env_var_treated_as_unset() {
2268 clear_all_provider_env_vars();
2270 std::env::set_var("ANTHROPIC_API_KEY", ""); std::env::set_var("OPENAI_API_KEY", "valid-key");
2272
2273 let provider = RigProvider::auto();
2275
2276 assert!(provider.is_some());
2278 assert_eq!(provider.unwrap().name(), "openai");
2279 }
2280
2281 #[test]
2282 #[serial]
2283 fn test_auto_whitespace_env_var_treated_as_unset() {
2284 clear_all_provider_env_vars();
2286 std::env::set_var("ANTHROPIC_API_KEY", " "); let provider = RigProvider::auto();
2290
2291 assert!(
2294 provider.is_none(),
2295 "Whitespace-only API key should be treated as unset"
2296 );
2297 }
2298
2299 #[test]
2304 fn test_nika_mcp_tool_implements_tool_dyn() {
2305 let tool_def = NikaMcpToolDef {
2307 name: "novanet_context".to_string(),
2308 description: "Generate native content for an entity".to_string(),
2309 input_schema: serde_json::json!({
2310 "type": "object",
2311 "properties": {
2312 "entity": { "type": "string" },
2313 "locale": { "type": "string" }
2314 },
2315 "required": ["entity", "locale"]
2316 }),
2317 };
2318
2319 let tool = NikaMcpTool::new(tool_def);
2321
2322 assert_eq!(tool.tool_name(), "novanet_context");
2324 }
2325
2326 #[test]
2327 fn test_nika_mcp_tool_definition_returns_correct_schema() {
2328 use rig::tool::ToolDyn;
2329
2330 let tool_def = NikaMcpToolDef {
2332 name: "novanet_describe".to_string(),
2333 description: "Describe an entity from the knowledge graph".to_string(),
2334 input_schema: serde_json::json!({
2335 "type": "object",
2336 "properties": {
2337 "entity_key": { "type": "string" }
2338 },
2339 "required": ["entity_key"]
2340 }),
2341 };
2342 let tool = NikaMcpTool::new(tool_def);
2343
2344 let name = tool.name();
2346
2347 assert_eq!(name, "novanet_describe");
2349 }
2350
2351 #[tokio::test]
2356 async fn test_nika_mcp_tool_call_uses_mcp_client() {
2357 use crate::mcp::McpClient;
2358 use rig::tool::ToolDyn;
2359 use std::sync::Arc;
2360
2361 let client = Arc::new(McpClient::mock("novanet"));
2363
2364 let tool_def = NikaMcpToolDef {
2366 name: "novanet_describe".to_string(),
2367 description: "Describe an entity".to_string(),
2368 input_schema: serde_json::json!({
2369 "type": "object",
2370 "properties": {
2371 "entity_key": { "type": "string" }
2372 },
2373 "required": ["entity_key"]
2374 }),
2375 };
2376 let tool = NikaMcpTool::with_client(tool_def, client);
2377
2378 let args = r#"{"entity_key": "qr-code"}"#.to_string();
2380 let result = tool.call(args).await;
2381
2382 assert!(result.is_ok(), "Tool call should succeed with mock client");
2384 let output = result.unwrap();
2385 assert!(!output.is_empty(), "Tool should return non-empty output");
2386 }
2387
2388 #[tokio::test]
2394 async fn test_usecase_novanet_context_entity_locale() {
2395 use crate::mcp::McpClient;
2396 use rig::tool::ToolDyn;
2397 use std::sync::Arc;
2398
2399 let client = Arc::new(McpClient::mock("novanet"));
2401
2402 let tool_def = NikaMcpToolDef {
2404 name: "novanet_context".to_string(),
2405 description: "Full RLM-on-KG context assembly for generation".to_string(),
2406 input_schema: serde_json::json!({
2407 "type": "object",
2408 "properties": {
2409 "focus_key": { "type": "string", "description": "Entity key to generate for" },
2410 "locale": { "type": "string", "description": "BCP-47 locale code" },
2411 "mode": { "type": "string", "enum": ["block", "page"], "default": "block" },
2412 "token_budget": { "type": "integer", "default": 4000 },
2413 "spreading_depth": { "type": "integer", "default": 2 },
2414 "forms": {
2415 "type": "array",
2416 "items": { "type": "string", "enum": ["text", "title", "abbrev", "url"] }
2417 }
2418 },
2419 "required": ["focus_key", "locale"]
2420 }),
2421 };
2422 let tool = NikaMcpTool::with_client(tool_def, client);
2423
2424 let args = serde_json::json!({
2426 "focus_key": "qr-code",
2427 "locale": "fr-FR",
2428 "mode": "page",
2429 "forms": ["text", "title", "abbrev"]
2430 })
2431 .to_string();
2432
2433 let result = tool.call(args).await;
2434
2435 assert!(
2437 result.is_ok(),
2438 "novanet_context should succeed: {:?}",
2439 result
2440 );
2441 let output = result.unwrap();
2442 assert!(!output.is_empty(), "Should return generation context");
2443 }
2444
2445 #[tokio::test]
2447 async fn test_usecase_novanet_describe_entity() {
2448 use crate::mcp::McpClient;
2449 use rig::tool::ToolDyn;
2450 use std::sync::Arc;
2451
2452 let client = Arc::new(McpClient::mock("novanet"));
2453
2454 let tool_def = NikaMcpToolDef {
2455 name: "novanet_describe".to_string(),
2456 description: "Bootstrap agent understanding of the knowledge graph".to_string(),
2457 input_schema: serde_json::json!({
2458 "type": "object",
2459 "properties": {
2460 "describe": {
2461 "type": "string",
2462 "enum": ["schema", "entity", "category", "relations", "locales", "stats"]
2463 },
2464 "entity_key": { "type": "string" },
2465 "category_key": { "type": "string" }
2466 },
2467 "required": ["describe"]
2468 }),
2469 };
2470 let tool = NikaMcpTool::with_client(tool_def, client);
2471
2472 let args = serde_json::json!({
2474 "describe": "schema"
2475 })
2476 .to_string();
2477
2478 let result = tool.call(args).await;
2479 assert!(result.is_ok(), "novanet_describe should succeed");
2480 }
2481
2482 #[tokio::test]
2484 async fn test_usecase_novanet_search_walk_graph() {
2485 use crate::mcp::McpClient;
2486 use rig::tool::ToolDyn;
2487 use std::sync::Arc;
2488
2489 let client = Arc::new(McpClient::mock("novanet"));
2490
2491 let tool_def = NikaMcpToolDef {
2492 name: "novanet_search".to_string(),
2493 description: "Graph traversal with configurable depth and filters".to_string(),
2494 input_schema: serde_json::json!({
2495 "type": "object",
2496 "properties": {
2497 "start_key": { "type": "string" },
2498 "max_depth": { "type": "integer", "default": 2 },
2499 "direction": { "type": "string", "enum": ["outgoing", "incoming", "both"] },
2500 "arc_families": { "type": "array", "items": { "type": "string" } },
2501 "target_kinds": { "type": "array", "items": { "type": "string" } }
2502 },
2503 "required": ["start_key"]
2504 }),
2505 };
2506 let tool = NikaMcpTool::with_client(tool_def, client);
2507
2508 let args = serde_json::json!({
2510 "start_key": "qr-code",
2511 "max_depth": 2,
2512 "direction": "outgoing",
2513 "arc_families": ["ownership", "localization"]
2514 })
2515 .to_string();
2516
2517 let result = tool.call(args).await;
2518 assert!(result.is_ok(), "novanet_search walk should succeed");
2519 }
2520
2521 #[tokio::test]
2523 async fn test_usecase_novanet_search_hybrid() {
2524 use crate::mcp::McpClient;
2525 use rig::tool::ToolDyn;
2526 use std::sync::Arc;
2527
2528 let client = Arc::new(McpClient::mock("novanet"));
2529
2530 let tool_def = NikaMcpToolDef {
2531 name: "novanet_search".to_string(),
2532 description: "Fulltext + property search with hybrid mode".to_string(),
2533 input_schema: serde_json::json!({
2534 "type": "object",
2535 "properties": {
2536 "query": { "type": "string" },
2537 "mode": { "type": "string", "enum": ["fulltext", "property", "hybrid"] },
2538 "kinds": { "type": "array", "items": { "type": "string" } },
2539 "realm": { "type": "string", "enum": ["shared", "org"] },
2540 "limit": { "type": "integer", "default": 10 }
2541 },
2542 "required": ["query"]
2543 }),
2544 };
2545 let tool = NikaMcpTool::with_client(tool_def, client);
2546
2547 let args = serde_json::json!({
2549 "query": "QR code generator",
2550 "mode": "hybrid",
2551 "kinds": ["Entity", "Page"],
2552 "limit": 5
2553 })
2554 .to_string();
2555
2556 let result = tool.call(args).await;
2557 assert!(result.is_ok(), "novanet_search should succeed");
2558 }
2559
2560 #[tokio::test]
2562 async fn test_usecase_novanet_audit_locale() {
2563 use crate::mcp::McpClient;
2564 use rig::tool::ToolDyn;
2565 use std::sync::Arc;
2566
2567 let client = Arc::new(McpClient::mock("novanet"));
2568
2569 let tool_def = NikaMcpToolDef {
2570 name: "novanet_audit".to_string(),
2571 description: "Retrieve knowledge atoms for a specific locale".to_string(),
2572 input_schema: serde_json::json!({
2573 "type": "object",
2574 "properties": {
2575 "locale": { "type": "string" },
2576 "atom_type": {
2577 "type": "string",
2578 "enum": ["term", "expression", "pattern", "cultureref", "taboo", "audiencetrait", "all"]
2579 },
2580 "domain": { "type": "string" }
2581 },
2582 "required": ["locale"]
2583 }),
2584 };
2585 let tool = NikaMcpTool::with_client(tool_def, client);
2586
2587 let args = serde_json::json!({
2589 "locale": "fr-FR",
2590 "atom_type": "term",
2591 "domain": "qr-code"
2592 })
2593 .to_string();
2594
2595 let result = tool.call(args).await;
2596 assert!(result.is_ok(), "novanet_audit should succeed");
2597 }
2598
2599 #[tokio::test]
2601 async fn test_usecase_novanet_batch_context() {
2602 use crate::mcp::McpClient;
2603 use rig::tool::ToolDyn;
2604 use std::sync::Arc;
2605
2606 let client = Arc::new(McpClient::mock("novanet"));
2607
2608 let tool_def = NikaMcpToolDef {
2609 name: "novanet_batch".to_string(),
2610 description: "Assemble context for LLM generation (token-aware)".to_string(),
2611 input_schema: serde_json::json!({
2612 "type": "object",
2613 "properties": {
2614 "focus_key": { "type": "string" },
2615 "locale": { "type": "string" },
2616 "token_budget": { "type": "integer", "default": 4000 },
2617 "strategy": {
2618 "type": "string",
2619 "enum": ["breadth", "depth", "relevance", "custom"]
2620 }
2621 },
2622 "required": ["focus_key", "locale"]
2623 }),
2624 };
2625 let tool = NikaMcpTool::with_client(tool_def, client);
2626
2627 let args = serde_json::json!({
2629 "focus_key": "qr-code",
2630 "locale": "es-MX",
2631 "token_budget": 3000,
2632 "strategy": "relevance"
2633 })
2634 .to_string();
2635
2636 let result = tool.call(args).await;
2637 assert!(result.is_ok(), "novanet_batch should succeed");
2638 }
2639
2640 #[tokio::test]
2646 async fn test_error_no_client_configured() {
2647 use rig::tool::ToolDyn;
2648
2649 let tool_def = NikaMcpToolDef {
2651 name: "novanet_describe".to_string(),
2652 description: "Test tool".to_string(),
2653 input_schema: serde_json::json!({"type": "object"}),
2654 };
2655 let tool = NikaMcpTool::new(tool_def); let args = r#"{"entity_key": "test"}"#.to_string();
2659 let result = tool.call(args).await;
2660
2661 assert!(result.is_err(), "Should fail without client");
2663 let err = result.unwrap_err();
2664 let err_str = err.to_string();
2665 assert!(
2666 err_str.contains("No MCP client") || err_str.contains("NotConnected"),
2667 "Error should mention missing client: {}",
2668 err_str
2669 );
2670 }
2671
2672 #[tokio::test]
2674 async fn test_error_invalid_json_arguments() {
2675 use crate::mcp::McpClient;
2676 use rig::tool::ToolDyn;
2677 use std::sync::Arc;
2678
2679 let client = Arc::new(McpClient::mock("novanet"));
2680 let tool_def = NikaMcpToolDef {
2681 name: "novanet_describe".to_string(),
2682 description: "Test tool".to_string(),
2683 input_schema: serde_json::json!({"type": "object"}),
2684 };
2685 let tool = NikaMcpTool::with_client(tool_def, client);
2686
2687 let args = "not valid json {{{".to_string();
2689 let result = tool.call(args).await;
2690
2691 assert!(result.is_err(), "Should fail with invalid JSON");
2693 let err = result.unwrap_err();
2694 let err_str = err.to_string();
2695 assert!(
2696 err_str.contains("Invalid JSON") || err_str.contains("JSON"),
2697 "Error should mention JSON parsing: {}",
2698 err_str
2699 );
2700 }
2701
2702 #[tokio::test]
2704 async fn test_empty_json_object_is_valid() {
2705 use crate::mcp::McpClient;
2706 use rig::tool::ToolDyn;
2707 use std::sync::Arc;
2708
2709 let client = Arc::new(McpClient::mock("novanet"));
2710 let tool_def = NikaMcpToolDef {
2711 name: "novanet_describe".to_string(),
2712 description: "Test tool".to_string(),
2713 input_schema: serde_json::json!({"type": "object"}),
2714 };
2715 let tool = NikaMcpTool::with_client(tool_def, client);
2716
2717 let args = "{}".to_string();
2719 let result = tool.call(args).await;
2720
2721 assert!(result.is_ok(), "Empty JSON object should be valid");
2723 }
2724
2725 #[tokio::test]
2731 async fn test_tool_definition_async() {
2732 use rig::tool::ToolDyn;
2733
2734 let input_schema = serde_json::json!({
2735 "type": "object",
2736 "properties": {
2737 "entity_key": { "type": "string" },
2738 "locale": { "type": "string" }
2739 },
2740 "required": ["entity_key"]
2741 });
2742
2743 let tool_def = NikaMcpToolDef {
2744 name: "test_tool".to_string(),
2745 description: "A test tool for verification".to_string(),
2746 input_schema: input_schema.clone(),
2747 };
2748 let tool = NikaMcpTool::new(tool_def);
2749
2750 let definition = tool.definition("some prompt".to_string()).await;
2752
2753 assert_eq!(definition.name, "test_tool");
2755 assert_eq!(definition.description, "A test tool for verification");
2756 assert_eq!(definition.parameters, input_schema);
2757 }
2758
/// Three tools built side by side each report the name they were constructed with.
#[test]
fn test_multiple_tools_independent() {
    // Shared constructor: only the name and description vary between the tools.
    let build = |name: &str, description: &str| {
        NikaMcpTool::new(NikaMcpToolDef {
            name: name.to_string(),
            description: description.to_string(),
            input_schema: serde_json::json!({"type": "object"}),
        })
    };

    let tools = [
        build("novanet_context", "Generate content"),
        build("novanet_describe", "Describe entity"),
        build("novanet_search", "Traverse graph"),
    ];
    let expected_names = ["novanet_context", "novanet_describe", "novanet_search"];

    for (tool, expected) in tools.iter().zip(expected_names.iter()) {
        assert_eq!(tool.tool_name(), *expected);
    }
}
2786
2787 #[tokio::test]
2789 async fn test_tool_clone_works() {
2790 use crate::mcp::McpClient;
2791 use rig::tool::ToolDyn;
2792 use std::sync::Arc;
2793
2794 let client = Arc::new(McpClient::mock("novanet"));
2795 let tool_def = NikaMcpToolDef {
2796 name: "novanet_describe".to_string(),
2797 description: "Test tool".to_string(),
2798 input_schema: serde_json::json!({"type": "object"}),
2799 };
2800 let tool = NikaMcpTool::with_client(tool_def, client);
2801
2802 let cloned_tool = tool.clone();
2804
2805 let args = r#"{"entity_key": "test"}"#.to_string();
2807 let result1 = tool.call(args.clone()).await;
2808 let result2 = cloned_tool.call(args).await;
2809
2810 assert!(result1.is_ok(), "Original tool should work");
2811 assert!(result2.is_ok(), "Cloned tool should work");
2812 }
2813
2814 #[tokio::test]
2820 async fn test_multi_locale_generation_workflow() {
2821 use crate::mcp::McpClient;
2822 use rig::tool::ToolDyn;
2823 use std::sync::Arc;
2824
2825 let client = Arc::new(McpClient::mock("novanet"));
2826 let tool_def = NikaMcpToolDef {
2827 name: "novanet_context".to_string(),
2828 description: "Generate native content".to_string(),
2829 input_schema: serde_json::json!({
2830 "type": "object",
2831 "properties": {
2832 "focus_key": { "type": "string" },
2833 "locale": { "type": "string" },
2834 "forms": { "type": "array", "items": { "type": "string" } }
2835 },
2836 "required": ["focus_key", "locale"]
2837 }),
2838 };
2839 let tool = NikaMcpTool::with_client(tool_def, client);
2840
2841 let locales = ["fr-FR", "es-MX", "de-DE", "ja-JP", "zh-CN"];
2843 let mut results = Vec::new();
2844
2845 for locale in locales {
2846 let args = serde_json::json!({
2847 "focus_key": "qr-code",
2848 "locale": locale,
2849 "forms": ["text", "title"]
2850 })
2851 .to_string();
2852
2853 let result = tool.call(args).await;
2854 results.push((locale, result.is_ok()));
2855 }
2856
2857 for (locale, success) in &results {
2859 assert!(success, "Generation for {} should succeed", locale);
2860 }
2861 assert_eq!(results.len(), 5, "Should process all 5 locales");
2862 }
2863
/// Checks the `Display` text of each `ProviderVerifyError` variant, plus the suggestion
/// text of the invalid-key variant.
#[test]
fn test_provider_verify_error_types() {
    let bad_key = ProviderVerifyError::InvalidApiKey {
        provider: "claude".to_string(),
    };
    assert!(bad_key.to_string().contains("Invalid API key"));
    assert!(bad_key.suggestion().contains("API key"));

    // Remaining variants only need their rendered message checked.
    let display_cases = [
        (
            ProviderVerifyError::RateLimited {
                provider: "openai".to_string(),
            },
            "Rate limited",
        ),
        (
            ProviderVerifyError::Timeout {
                provider: "mistral".to_string(),
            },
            "timeout",
        ),
        (
            ProviderVerifyError::NetworkError {
                provider: "groq".to_string(),
                details: "connection refused".to_string(),
            },
            "Network error",
        ),
        (
            ProviderVerifyError::ProviderError {
                provider: "deepseek".to_string(),
                details: "server down".to_string(),
            },
            "server down",
        ),
    ];

    for (error, expected) in &display_cases {
        let rendered = error.to_string();
        assert!(rendered.contains(*expected));
    }
}
2899
/// A `ProviderVerifyResult` keeps the provider, latency and model it was built with.
#[test]
fn test_provider_verify_result_fields() {
    let ProviderVerifyResult {
        provider,
        latency,
        model,
    } = ProviderVerifyResult {
        provider: "claude".to_string(),
        latency: std::time::Duration::from_millis(150),
        model: "claude-sonnet-4-6".to_string(),
    };

    assert_eq!(provider, "claude");
    assert_eq!(latency.as_millis(), 150);
    assert_eq!(model, "claude-sonnet-4-6");
}
2912
/// With `ANTHROPIC_API_KEY` set, the provider from `RigProvider::claude()` reports itself
/// as configured.
///
/// The variable's previous value is captured first and put back afterwards, so the
/// placeholder key does not leak into tests that run later in the same process.
#[test]
#[serial]
fn test_is_configured_with_api_key() {
    const KEY: &str = "ANTHROPIC_API_KEY";

    let previous = std::env::var_os(KEY);
    std::env::set_var(KEY, "test-key");

    // Evaluate before asserting so a failing assertion cannot skip the restore below.
    let configured = RigProvider::claude().is_configured();

    match previous {
        Some(value) => std::env::set_var(KEY, value),
        None => std::env::remove_var(KEY),
    }

    assert!(configured);
}
2920
/// With API keys set for Anthropic, OpenAI, Mistral, Groq and DeepSeek, each matching
/// `RigProvider` constructor yields a provider that reports itself as configured.
///
/// The previous value of every variable is captured first and put back afterwards, so
/// the placeholder keys do not leak into tests that run later in the same process
/// (some of which branch on whether a key such as `DEEPSEEK_API_KEY` is present).
#[test]
#[serial]
fn test_is_configured_returns_true_for_all_providers_with_keys() {
    const KEYS: [&str; 5] = [
        "ANTHROPIC_API_KEY",
        "OPENAI_API_KEY",
        "MISTRAL_API_KEY",
        "GROQ_API_KEY",
        "DEEPSEEK_API_KEY",
    ];

    let saved: Vec<_> = KEYS
        .iter()
        .map(|&key| (key, std::env::var_os(key)))
        .collect();
    for &key in &KEYS {
        std::env::set_var(key, "test");
    }

    // Evaluate every provider before asserting so a failing assertion cannot skip the
    // restore step below.
    let configured = [
        ("claude", RigProvider::claude().is_configured()),
        ("openai", RigProvider::openai().is_configured()),
        ("mistral", RigProvider::mistral().is_configured()),
        ("groq", RigProvider::groq().is_configured()),
        ("deepseek", RigProvider::deepseek().is_configured()),
    ];

    for (key, previous) in saved {
        match previous {
            Some(value) => std::env::set_var(key, value),
            None => std::env::remove_var(key),
        }
    }

    for (provider, is_configured) in &configured {
        assert!(*is_configured, "{} should be configured", provider);
    }
}
2937
/// A default-constructed `InferOptions` leaves all four fields unset.
#[test]
fn test_infer_options_default() {
    let InferOptions {
        model,
        temperature,
        max_tokens,
        system,
    } = InferOptions::default();

    assert!(model.is_none());
    assert!(temperature.is_none());
    assert!(max_tokens.is_none());
    assert!(system.is_none());
}
2950
/// Every field of `InferOptions` keeps the value it was explicitly given.
#[test]
fn test_infer_options_with_all_fields() {
    let model_name = "gpt-4o";
    let system_prompt = "You are a helpful assistant.";

    let configured = InferOptions {
        model: Some(model_name.to_string()),
        temperature: Some(0.7),
        max_tokens: Some(2000),
        system: Some(system_prompt.to_string()),
    };

    assert_eq!(configured.model.as_deref(), Some(model_name));
    assert_eq!(configured.temperature, Some(0.7));
    assert_eq!(configured.max_tokens, Some(2000));
    assert_eq!(configured.system.as_deref(), Some(system_prompt));
}
2964
/// Setting only `temperature` leaves the other fields at their `None` defaults.
#[test]
fn test_infer_options_partial_fields() {
    let partial = InferOptions {
        temperature: Some(0.5),
        ..InferOptions::default()
    };

    // The one field that was set...
    assert_eq!(partial.temperature, Some(0.5));
    // ...and the three that were left to the default.
    assert!(partial.model.is_none());
    assert!(partial.max_tokens.is_none());
    assert!(partial.system.is_none());
}
2976
/// A temperature of `0.0` is stored as `Some(0.0)`.
#[test]
fn test_infer_options_temperature_zero() {
    let InferOptions { temperature, .. } = InferOptions {
        temperature: Some(0.0),
        ..Default::default()
    };

    assert_eq!(temperature, Some(0.0));
}
2985
/// A `max_tokens` value of `1` is stored unchanged.
#[test]
fn test_infer_options_max_tokens_small() {
    let single_token = Some(1);
    let tiny = InferOptions {
        max_tokens: single_token,
        ..InferOptions::default()
    };

    assert_eq!(tiny.max_tokens, single_token);
}
2994
/// An empty system prompt is kept as `Some("")` and not turned into `None`.
#[test]
fn test_infer_options_system_empty_string() {
    let blank_system = InferOptions {
        system: Some(String::new()),
        ..InferOptions::default()
    };

    let system_prompt = blank_system.system.as_deref();
    assert_eq!(system_prompt, Some(""));
}
3003
/// Cloning `InferOptions` produces a copy whose fields all equal the original's.
#[test]
fn test_infer_options_clone() {
    let original = InferOptions {
        model: Some(String::from("test-model")),
        temperature: Some(0.8),
        max_tokens: Some(1000),
        system: Some(String::from("Test system")),
    };

    let duplicate = original.clone();

    assert_eq!(original.model, duplicate.model);
    assert_eq!(original.temperature, duplicate.temperature);
    assert_eq!(original.max_tokens, duplicate.max_tokens);
    assert_eq!(original.system, duplicate.system);
}
3018
/// The rendered `VisionNotSupported` error contains both the fixed prefix and the
/// detail string it wraps.
#[test]
fn vision_not_supported_error_display() {
    let detail = "DeepSeek no vision";
    let rendered = RigInferError::VisionNotSupported(detail.to_string()).to_string();

    assert!(rendered.contains("Vision not supported"));
    assert!(rendered.contains(detail));
}
3029
3030 #[tokio::test]
3032 async fn infer_vision_deepseek_returns_error() {
3033 if std::env::var("DEEPSEEK_API_KEY").is_err() {
3034 let err = RigInferError::VisionNotSupported("DeepSeek".to_string());
3036 assert!(err.to_string().contains("Vision not supported"));
3037 return;
3038 }
3039 let provider = RigProvider::deepseek();
3040 let content = vec![rig::completion::message::UserContent::text("hello")];
3041 let result = provider.infer_vision(content, None, None, None).await;
3042 assert!(result.is_err());
3043 assert!(matches!(
3044 result.unwrap_err(),
3045 RigInferError::VisionNotSupported(_)
3046 ));
3047 }
3048
/// `OneOrMany::many` returns an error for an empty list of user content.
#[test]
fn infer_vision_empty_content_builds_error() {
    use rig::OneOrMany;

    let no_parts = Vec::<rig::completion::message::UserContent>::new();
    let built = OneOrMany::many(no_parts);

    assert!(built.is_err(), "empty content should fail");
}
3057
/// A single text part can be built with `UserContent::text`.
#[test]
fn build_vision_user_content_text_only() {
    use rig::completion::message::UserContent;

    let parts = [UserContent::text("Describe this")];

    assert_eq!(parts.len(), 1);
}
3063
/// A text part and a base64 PNG image part can be built side by side.
#[test]
fn build_vision_user_content_with_image() {
    use rig::completion::message::{ImageMediaType, UserContent};

    let question = UserContent::text("What is in this image?");
    let picture = UserContent::image_base64("iVBORw0KGgo=", Some(ImageMediaType::PNG), None);

    let parts = [question, picture];
    assert_eq!(parts.len(), 2);
}
3077
/// Text and image parts wrapped with `OneOrMany::many` can be placed in a `Message::User`.
#[test]
fn build_vision_message_from_content() {
    use rig::completion::message::{ImageMediaType, Message, UserContent};
    use rig::OneOrMany;

    let text_part = UserContent::text("Describe this image");
    let image_part = UserContent::image_base64("iVBORw0KGgo=", Some(ImageMediaType::PNG), None);

    let content = OneOrMany::many(vec![text_part, image_part]).unwrap();
    let message = Message::User { content };

    assert!(matches!(message, Message::User { .. }));
}
3092
/// The `o1`, `o3` and `o4` model names, including suffixed and dated variants, are all
/// classified as reasoning models.
#[test]
fn reasoning_model_o_series() {
    const O_SERIES: [&str; 9] = [
        "o1",
        "o1-mini",
        "o1-pro",
        "o3",
        "o3-mini",
        "o3-pro",
        "o4",
        "o4-mini",
        "o1-2024-12-17",
    ];

    for model in O_SERIES.iter().copied() {
        assert!(is_reasoning_model(model));
    }
}
3109
/// `gpt-5` and `gpt-5-turbo` are classified as reasoning models.
#[test]
fn reasoning_model_gpt5() {
    for model in ["gpt-5", "gpt-5-turbo"].iter().copied() {
        assert!(is_reasoning_model(model));
    }
}
3115
/// `deepseek-reasoner` is classified as a reasoning model.
#[test]
fn reasoning_model_deepseek() {
    let model = "deepseek-reasoner";
    assert!(is_reasoning_model(model));
}
3120
/// Upper-case spellings (`O1`, `GPT-5`) are still classified as reasoning models.
#[test]
fn reasoning_model_case_insensitive() {
    for model in ["O1", "GPT-5"].iter().copied() {
        assert!(is_reasoning_model(model));
    }
}
3126
/// None of the listed chat model names is classified as a reasoning model.
#[test]
fn non_reasoning_models() {
    const CHAT_MODELS: [&str; 6] = [
        "gpt-4o",
        "gpt-4o-mini",
        "claude-sonnet-4",
        "deepseek-chat",
        "gemini-2.0-flash",
        "grok-3",
    ];

    for model in CHAT_MODELS.iter().copied() {
        assert!(!is_reasoning_model(model));
    }
}
3136}