1use crate::mcp::McpClient;
27use crate::util::STREAM_CHUNK_TIMEOUT;
28use futures::StreamExt;
29use std::time::Instant;
30
31#[cfg(feature = "native-inference")]
33use crate::provider::native::InferenceBackend;
34use rig::client::{CompletionClient, ProviderClient};
35use rig::completion::{CompletionModel as _, GetTokenUsage, Prompt, PromptError, ToolDefinition};
36use rig::providers::{anthropic, deepseek, gemini, groq, mistral, openai, xai};
37use rig::streaming::StreamedAssistantContent;
38use rig::tool::{ToolDyn, ToolError};
39use std::future::Future;
40use std::pin::Pin;
41use std::sync::Arc;
42use tokio::sync::mpsc;
43use tokio::time::timeout;
44
/// Error for MCP tool invocations: a [`McpToolErrorKind`] category paired with
/// a free-form, human-readable message.
///
/// Built through the kind-specific constructors (`invalid_args`,
/// `not_configured`, `call_failed`, `serialization`); `Display` renders it as
/// `[<Kind>] <message>`.
#[derive(Debug)]
pub struct McpToolError {
    /// Category of the failure.
    kind: McpToolErrorKind,
    /// Detail text supplied by the constructor.
    message: String,
}
57
/// Category of an [`McpToolError`].
///
/// Each variant's name is used verbatim as the bracketed label in
/// `McpToolError`'s `Display` output (e.g. `[CallFailed] ...`).
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare kinds with `==`,
/// use them in `matches!`-free assertions, or key maps/sets on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum McpToolErrorKind {
    /// Built by `McpToolError::invalid_args`.
    InvalidArguments,
    /// Built by `McpToolError::not_configured`.
    NotConfigured,
    /// Built by `McpToolError::call_failed`.
    CallFailed,
    /// Built by `McpToolError::serialization`.
    SerializationError,
}
70
71impl McpToolError {
72 pub fn invalid_args(msg: impl Into<String>) -> Self {
74 Self {
75 kind: McpToolErrorKind::InvalidArguments,
76 message: msg.into(),
77 }
78 }
79
80 pub fn not_configured(msg: impl Into<String>) -> Self {
82 Self {
83 kind: McpToolErrorKind::NotConfigured,
84 message: msg.into(),
85 }
86 }
87
88 pub fn call_failed(msg: impl Into<String>) -> Self {
90 Self {
91 kind: McpToolErrorKind::CallFailed,
92 message: msg.into(),
93 }
94 }
95
96 pub fn serialization(msg: impl Into<String>) -> Self {
98 Self {
99 kind: McpToolErrorKind::SerializationError,
100 message: msg.into(),
101 }
102 }
103}
104
105impl std::fmt::Display for McpToolError {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 let kind_str = match self.kind {
108 McpToolErrorKind::InvalidArguments => "InvalidArguments",
109 McpToolErrorKind::NotConfigured => "NotConfigured",
110 McpToolErrorKind::CallFailed => "CallFailed",
111 McpToolErrorKind::SerializationError => "SerializationError",
112 };
113 write!(f, "[{}] {}", kind_str, self.message)
114 }
115}
116
/// Marker impl so `McpToolError` can be used as a standard error (e.g. boxed
/// as `dyn Error`). It relies on the default `source()`, so there is no
/// wrapped cause; the text comes from the `Display` impl.
impl std::error::Error for McpToolError {}
118
/// Per-call overrides consumed by `RigProvider::infer_with_options`.
///
/// Every field is optional. `None` means the provider default is used: the
/// provider's `default_model()`, a token cap of 8192 for cloud providers, no
/// system preamble, and no explicit temperature.
#[derive(Debug, Clone, Default)]
pub struct InferOptions {
    /// Model id. `None` selects `RigProvider::default_model()`.
    pub model: Option<String>,
    /// Sampling temperature. It is dropped, with a warning, for models
    /// matched by `is_reasoning_model`.
    pub temperature: Option<f64>,
    /// Completion token cap. Cloud providers fall back to 8192 when `None`;
    /// the native backend receives the `Option` unchanged.
    pub max_tokens: Option<u32>,
    /// System prompt, applied as the agent preamble for cloud providers.
    /// NOTE(review): the native backend branch does not read this field.
    pub system: Option<String>,
}
133
/// Returns `true` when `model_id` names a reasoning model. The comparison is
/// case-insensitive.
///
/// Matches the exact ids `o1`, `o3`, `o4`, `gpt-5` and `deepseek-reasoner`,
/// plus any id that begins with `o1-`, `o3-`, `o4-` or `gpt-5-` (for example
/// `o3-mini`, `o4-mini` or `gpt-5-mini`). `RigProvider::infer_with_options`
/// uses this check to drop `temperature`, which these models do not support.
pub fn is_reasoning_model(model_id: &str) -> bool {
    const EXACT_IDS: [&str; 5] = ["o1", "o3", "o4", "gpt-5", "deepseek-reasoner"];
    const FAMILY_PREFIXES: [&str; 4] = ["o1-", "o3-", "o4-", "gpt-5-"];

    let id = model_id.to_lowercase();
    EXACT_IDS.contains(&id.as_str())
        || FAMILY_PREFIXES
            .iter()
            .any(|&prefix| id.starts_with(prefix))
}
158
/// A configured LLM backend. There is one variant per supported rig provider
/// client, plus an optional in-process native runtime.
///
/// Construct one with `RigProvider::from_name`, `RigProvider::auto`, or the
/// per-provider constructors (`claude()`, `openai()`, ...).
#[derive(Debug, Clone)]
pub enum RigProvider {
    /// Anthropic (Claude) client.
    Claude(anthropic::Client),
    /// OpenAI client.
    OpenAI(openai::Client),
    /// Mistral client.
    Mistral(mistral::Client),
    /// Groq client.
    Groq(groq::Client),
    /// DeepSeek client. The vision entry points reject this variant.
    DeepSeek(deepseek::Client),
    /// Google Gemini client.
    Gemini(gemini::Client),
    /// xAI (Grok) client.
    XAi(xai::Client),
    /// In-process native runtime. Only compiled with the `native-inference`
    /// feature.
    #[cfg(feature = "native-inference")]
    Native(super::native::NativeRuntime),
}
185
186impl RigProvider {
187 pub fn from_name(name: &str) -> Result<Self, crate::error::NikaError> {
197 let provider = crate::core::find_provider(name).ok_or_else(|| {
198 crate::error::NikaError::ProviderNotConfigured {
199 provider: name.to_string(),
200 }
201 })?;
202
203 if provider.requires_key && !provider.has_env_key() {
205 return Err(crate::error::NikaError::MissingApiKey {
206 provider: provider.id.to_string(),
207 });
208 }
209
210 match provider.id {
211 "anthropic" => Ok(Self::claude()),
212 "openai" => Ok(Self::openai()),
213 "mistral" => Ok(Self::mistral()),
214 "groq" => Ok(Self::groq()),
215 "deepseek" => Ok(Self::deepseek()),
216 "gemini" => Ok(Self::gemini()),
217 "xai" => Ok(Self::xai()),
218 #[cfg(feature = "native-inference")]
219 "native" => Ok(Self::native()),
220 _ => Err(crate::error::NikaError::ProviderNotConfigured {
221 provider: name.to_string(),
222 }),
223 }
224 }
225
226 pub fn claude() -> Self {
228 let client = anthropic::Client::from_env();
229 RigProvider::Claude(client)
230 }
231
232 pub fn openai() -> Self {
234 let client = openai::Client::from_env();
235 RigProvider::OpenAI(client)
236 }
237
238 pub fn mistral() -> Self {
240 let client = mistral::Client::from_env();
241 RigProvider::Mistral(client)
242 }
243
244 pub fn groq() -> Self {
246 let client = groq::Client::from_env();
247 RigProvider::Groq(client)
248 }
249
250 pub fn deepseek() -> Self {
252 let client = deepseek::Client::from_env();
253 RigProvider::DeepSeek(client)
254 }
255
256 pub fn gemini() -> Self {
258 let client = gemini::Client::from_env();
259 RigProvider::Gemini(client)
260 }
261
262 pub fn xai() -> Self {
264 let client = xai::Client::from_env();
265 RigProvider::XAi(client)
266 }
267
268 #[cfg(feature = "native-inference")]
277 pub fn native() -> Self {
278 RigProvider::Native(super::native::NativeRuntime::new())
279 }
280
281 #[cfg(feature = "native-inference")]
289 pub async fn load_native_model(
290 &mut self,
291 model_path: impl Into<std::path::PathBuf>,
292 config: Option<super::native::LoadConfig>,
293 ) -> Result<(), RigInferError> {
294 match self {
295 RigProvider::Native(runtime) => runtime
296 .load(model_path.into(), config.unwrap_or_default())
297 .await
298 .map_err(|e: super::native::NativeError| RigInferError::PromptError(e.to_string())),
299 _ => Err(RigInferError::PromptError(
300 "load_native_model only valid for Native provider".to_string(),
301 )),
302 }
303 }
304
305 #[cfg(feature = "native-inference")]
307 pub fn is_native_loaded(&self) -> bool {
308 match self {
309 RigProvider::Native(runtime) => runtime.is_loaded(),
310 _ => false,
311 }
312 }
313
314 pub fn name(&self) -> &'static str {
316 match self {
317 RigProvider::Claude(_) => "claude",
318 RigProvider::OpenAI(_) => "openai",
319 RigProvider::Mistral(_) => "mistral",
320 RigProvider::Groq(_) => "groq",
321 RigProvider::DeepSeek(_) => "deepseek",
322 RigProvider::Gemini(_) => "gemini",
323 RigProvider::XAi(_) => "xai",
324 #[cfg(feature = "native-inference")]
325 RigProvider::Native(_) => "native",
326 }
327 }
328
329 pub fn default_model(&self) -> &'static str {
341 match self {
342 RigProvider::Claude(_) => "claude-sonnet-4-6",
345 RigProvider::OpenAI(_) => openai::GPT_4O,
346 RigProvider::Mistral(_) => mistral::MISTRAL_LARGE,
347 RigProvider::Groq(_) => "llama-3.3-70b-versatile",
348 RigProvider::DeepSeek(_) => "deepseek-chat",
349 RigProvider::Gemini(_) => "gemini-2.0-flash",
350 RigProvider::XAi(_) => "grok-3-fast",
351 #[cfg(feature = "native-inference")]
353 RigProvider::Native(_) => "native-model",
354 }
355 }
356
357 pub async fn infer(&self, prompt: &str, model: Option<&str>) -> Result<String, RigInferError> {
366 const INFER_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
369
370 let model_id = model.unwrap_or_else(|| self.default_model());
371
372 match self {
373 RigProvider::Claude(client) => {
374 let agent = client.agent(model_id).max_tokens(8192).build();
376 timeout(INFER_TIMEOUT, agent.prompt(prompt))
377 .await
378 .map_err(|_| RigInferError::Timeout {
379 duration_ms: INFER_TIMEOUT.as_millis() as u64,
380 })?
381 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
382 }
383 RigProvider::OpenAI(client) => {
384 let agent = client.agent(model_id).max_tokens(8192).build();
385 timeout(INFER_TIMEOUT, agent.prompt(prompt))
386 .await
387 .map_err(|_| RigInferError::Timeout {
388 duration_ms: INFER_TIMEOUT.as_millis() as u64,
389 })?
390 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
391 }
392 RigProvider::Mistral(client) => {
393 let agent = client.agent(model_id).max_tokens(8192).build();
394 timeout(INFER_TIMEOUT, agent.prompt(prompt))
395 .await
396 .map_err(|_| RigInferError::Timeout {
397 duration_ms: INFER_TIMEOUT.as_millis() as u64,
398 })?
399 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
400 }
401 RigProvider::Groq(client) => {
402 let agent = client.agent(model_id).max_tokens(8192).build();
403 timeout(INFER_TIMEOUT, agent.prompt(prompt))
404 .await
405 .map_err(|_| RigInferError::Timeout {
406 duration_ms: INFER_TIMEOUT.as_millis() as u64,
407 })?
408 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
409 }
410 RigProvider::DeepSeek(client) => {
411 let agent = client.agent(model_id).max_tokens(8192).build();
412 timeout(INFER_TIMEOUT, agent.prompt(prompt))
413 .await
414 .map_err(|_| RigInferError::Timeout {
415 duration_ms: INFER_TIMEOUT.as_millis() as u64,
416 })?
417 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
418 }
419 RigProvider::Gemini(client) => {
420 let agent = client.agent(model_id).max_tokens(8192).build();
421 timeout(INFER_TIMEOUT, agent.prompt(prompt))
422 .await
423 .map_err(|_| RigInferError::Timeout {
424 duration_ms: INFER_TIMEOUT.as_millis() as u64,
425 })?
426 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
427 }
428 RigProvider::XAi(client) => {
429 let agent = client.agent(model_id).max_tokens(8192).build();
430 timeout(INFER_TIMEOUT, agent.prompt(prompt))
431 .await
432 .map_err(|_| RigInferError::Timeout {
433 duration_ms: INFER_TIMEOUT.as_millis() as u64,
434 })?
435 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
436 }
437 #[cfg(feature = "native-inference")]
438 RigProvider::Native(runtime) => {
439 timeout(
442 INFER_TIMEOUT,
443 runtime.infer(prompt, super::native::ChatOptions::default()),
444 )
445 .await
446 .map_err(|_| RigInferError::Timeout {
447 duration_ms: INFER_TIMEOUT.as_millis() as u64,
448 })?
449 .map(|r| r.message.content)
450 .map_err(|e: super::native::NativeError| RigInferError::PromptError(e.to_string()))
451 }
452 }
453 }
454
455 pub async fn infer_vision(
471 &self,
472 user_content: Vec<rig::completion::message::UserContent>,
473 model: Option<&str>,
474 system: Option<&str>,
475 max_tokens: Option<u32>,
476 ) -> Result<String, RigInferError> {
477 use rig::completion::message::Message;
478 use rig::OneOrMany;
479
480 if matches!(self, RigProvider::DeepSeek(_)) {
482 return Err(RigInferError::VisionNotSupported(
483 "DeepSeek does not support vision/multimodal content".to_string(),
484 ));
485 }
486
487 #[cfg(feature = "native-inference")]
489 if let RigProvider::Native(runtime) = self {
490 if !runtime.supports_vision() {
491 return Err(RigInferError::VisionNotSupported(
492 "Native model does not support vision. Load a vision model via \
493 NativeModelKind::VisionHf (e.g., `nika model vision <model_id> --isq Q4K`)"
494 .to_string(),
495 ));
496 }
497 let (prompt_text, vision_images) = extract_native_vision_parts(&user_content)?;
498 let options = super::native::ChatOptions {
499 max_tokens,
500 ..Default::default()
501 };
502 let response = runtime
503 .infer_vision(&prompt_text, vision_images, options)
504 .await
505 .map_err(|e: super::native::NativeError| {
506 RigInferError::PromptError(e.to_string())
507 })?;
508 return Ok(response.message.content);
509 }
510
511 let model_id = model.unwrap_or_else(|| self.default_model());
512 let max_tok = max_tokens.map(u64::from).unwrap_or(8192);
513
514 let message = Message::User {
515 content: OneOrMany::many(user_content).map_err(|_| {
516 RigInferError::VisionNotSupported("content parts list is empty".to_string())
517 })?,
518 };
519
520 macro_rules! vision_prompt {
521 ($client:expr) => {{
522 let mut builder = $client.agent(model_id).max_tokens(max_tok);
523 if let Some(sys) = system {
524 builder = builder.preamble(sys);
525 }
526 let agent = builder.build();
527 agent
528 .prompt(message)
529 .await
530 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
531 }};
532 }
533
534 match self {
535 RigProvider::Claude(client) => vision_prompt!(client),
536 RigProvider::OpenAI(client) => vision_prompt!(client),
537 RigProvider::Mistral(client) => vision_prompt!(client),
538 RigProvider::Groq(client) => vision_prompt!(client),
539 RigProvider::Gemini(client) => vision_prompt!(client),
540 RigProvider::XAi(client) => vision_prompt!(client),
541 RigProvider::DeepSeek(_) => unreachable!("DeepSeek handled above"),
543 #[cfg(feature = "native-inference")]
544 RigProvider::Native(_) => unreachable!("Native handled above"),
545 }
546 }
547
548 pub async fn infer_vision_stream(
553 &self,
554 user_content: Vec<rig::completion::message::UserContent>,
555 tx: mpsc::Sender<StreamChunk>,
556 model: Option<&str>,
557 system: Option<&str>,
558 max_tokens: Option<u32>,
559 ) -> Result<StreamResult, RigInferError> {
560 use rig::completion::message::Message;
561 use rig::OneOrMany;
562
563 if matches!(self, RigProvider::DeepSeek(_)) {
565 return Err(RigInferError::VisionNotSupported(
566 "DeepSeek does not support vision/multimodal content".to_string(),
567 ));
568 }
569
570 #[cfg(feature = "native-inference")]
574 if let RigProvider::Native(runtime) = self {
575 if !runtime.supports_vision() {
576 return Err(RigInferError::VisionNotSupported(
577 "Native model does not support vision. Load a vision model via \
578 NativeModelKind::VisionHf (e.g., `nika model vision <model_id> --isq Q4K`)"
579 .to_string(),
580 ));
581 }
582 let (prompt_text, vision_images) = extract_native_vision_parts(&user_content)?;
583 let options = super::native::ChatOptions {
584 max_tokens,
585 ..Default::default()
586 };
587 let response = runtime
588 .infer_vision(&prompt_text, vision_images, options)
589 .await
590 .map_err(|e: super::native::NativeError| {
591 RigInferError::PromptError(e.to_string())
592 })?;
593 let text = response.message.content;
595 if let Err(e) = tx.send(StreamChunk::Done(text.clone())).await {
596 tracing::warn!(error = %e, "Vision result channel closed — TUI may not show output");
597 }
598 return Ok(StreamResult {
599 text,
600 ..Default::default()
601 });
602 }
603
604 let model_id = model.unwrap_or_else(|| self.default_model());
605 let max_tok = max_tokens.map(u64::from).unwrap_or(8192);
606
607 let message = Message::User {
608 content: OneOrMany::many(user_content).map_err(|_| {
609 RigInferError::VisionNotSupported("content parts list is empty".to_string())
610 })?,
611 };
612
613 let mut response_parts: Vec<String> = Vec::new();
614 let mut result = StreamResult::default();
615
616 macro_rules! vision_stream {
617 ($client:expr, $is_anthropic:expr) => {{
618 let model = $client.completion_model(model_id);
619 let mut builder = model.completion_request(message).max_tokens(max_tok);
620 if let Some(sys) = system {
621 builder = builder.preamble(sys.to_string());
622 }
623 let request = builder.build();
624 let stream_start = Instant::now();
625 let mut stream = model
626 .stream(request)
627 .await
628 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
629 consume_rig_stream(
630 &mut stream,
631 &tx,
632 &mut response_parts,
633 &mut result,
634 $is_anthropic,
635 stream_start,
636 )
637 .await?;
638 }};
639 }
640
641 match self {
642 RigProvider::Claude(client) => vision_stream!(client, true),
643 RigProvider::OpenAI(client) => vision_stream!(client, false),
644 RigProvider::Mistral(client) => vision_stream!(client, false),
645 RigProvider::Groq(client) => vision_stream!(client, false),
646 RigProvider::Gemini(client) => vision_stream!(client, false),
647 RigProvider::XAi(client) => vision_stream!(client, false),
648 RigProvider::DeepSeek(_) => unreachable!("DeepSeek handled above"),
650 #[cfg(feature = "native-inference")]
651 RigProvider::Native(_) => unreachable!("Native handled above"),
652 }
653
654 result.text = response_parts.join("");
655 Ok(result)
656 }
657
658 pub async fn infer_with_tools(
673 &self,
674 prompt: &str,
675 tools: Vec<Box<dyn ToolDyn>>,
676 model: Option<&str>,
677 max_tokens: Option<u32>,
678 system: Option<&str>,
679 ) -> Result<String, RigInferError> {
680 use rig::agent::AgentBuilder;
681 use rig::message::ToolChoice as RigToolChoice;
682
683 let model_id = model.unwrap_or_else(|| self.default_model());
684 let max_tok = max_tokens.map(|v| v as u64).unwrap_or(8192);
685
686 macro_rules! build_agent_with_tools {
687 ($client:expr) => {{
688 let mut builder = AgentBuilder::new($client.completion_model(model_id))
689 .tools(tools)
690 .tool_choice(RigToolChoice::Required)
691 .max_tokens(max_tok);
692 if let Some(sys) = system {
693 builder = builder.preamble(sys);
694 }
695 let agent = builder.build();
696 agent
697 .prompt(prompt)
698 .await
699 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
700 }};
701 }
702
703 match self {
704 RigProvider::Claude(client) => build_agent_with_tools!(client),
705 RigProvider::OpenAI(client) => build_agent_with_tools!(client),
706 RigProvider::Mistral(client) => build_agent_with_tools!(client),
707 RigProvider::Groq(client) => build_agent_with_tools!(client),
708 RigProvider::DeepSeek(client) => build_agent_with_tools!(client),
709 RigProvider::Gemini(client) => build_agent_with_tools!(client),
710 RigProvider::XAi(client) => build_agent_with_tools!(client),
711 #[cfg(feature = "native-inference")]
712 RigProvider::Native(_) => {
713 Err(RigInferError::PromptError(
715 "Native inference does not support tool-based structured output".to_string(),
716 ))
717 }
718 }
719 }
720
721 pub async fn infer_with_options(
741 &self,
742 prompt: &str,
743 options: &InferOptions,
744 ) -> Result<String, RigInferError> {
745 let model_id = options
746 .model
747 .as_deref()
748 .unwrap_or_else(|| self.default_model());
749 let max_tokens = options.max_tokens.unwrap_or(8192);
750
751 let effective_temperature = if options.temperature.is_some() && is_reasoning_model(model_id)
753 {
754 tracing::warn!(
755 model = %model_id,
756 "temperature ignored for reasoning model '{}' (not supported)",
757 model_id
758 );
759 None
760 } else {
761 options.temperature
762 };
763
764 let user_prompt = prompt.to_string();
766
767 match self {
768 RigProvider::Claude(client) => {
769 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
770 if let Some(system) = &options.system {
771 builder = builder.preamble(system);
772 }
773 if let Some(temp) = effective_temperature {
774 builder = builder.temperature(temp);
775 }
776 let agent = builder.build();
777 agent
778 .prompt(&user_prompt)
779 .await
780 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
781 }
782 RigProvider::OpenAI(client) => {
783 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
784 if let Some(system) = &options.system {
785 builder = builder.preamble(system);
786 }
787 if let Some(temp) = effective_temperature {
788 builder = builder.temperature(temp);
789 }
790 let agent = builder.build();
791 agent
792 .prompt(&user_prompt)
793 .await
794 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
795 }
796 RigProvider::Mistral(client) => {
797 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
798 if let Some(system) = &options.system {
799 builder = builder.preamble(system);
800 }
801 if let Some(temp) = effective_temperature {
802 builder = builder.temperature(temp);
803 }
804 let agent = builder.build();
805 agent
806 .prompt(&user_prompt)
807 .await
808 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
809 }
810 RigProvider::Groq(client) => {
811 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
812 if let Some(system) = &options.system {
813 builder = builder.preamble(system);
814 }
815 if let Some(temp) = effective_temperature {
816 builder = builder.temperature(temp);
817 }
818 let agent = builder.build();
819 agent
820 .prompt(&user_prompt)
821 .await
822 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
823 }
824 RigProvider::DeepSeek(client) => {
825 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
826 if let Some(system) = &options.system {
827 builder = builder.preamble(system);
828 }
829 if let Some(temp) = effective_temperature {
830 builder = builder.temperature(temp);
831 }
832 let agent = builder.build();
833 agent
834 .prompt(&user_prompt)
835 .await
836 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
837 }
838 RigProvider::Gemini(client) => {
839 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
840 if let Some(system) = &options.system {
841 builder = builder.preamble(system);
842 }
843 if let Some(temp) = effective_temperature {
844 builder = builder.temperature(temp);
845 }
846 let agent = builder.build();
847 agent
848 .prompt(&user_prompt)
849 .await
850 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
851 }
852 RigProvider::XAi(client) => {
853 let mut builder = client.agent(model_id).max_tokens(max_tokens as u64);
854 if let Some(system) = &options.system {
855 builder = builder.preamble(system);
856 }
857 if let Some(temp) = effective_temperature {
858 builder = builder.temperature(temp);
859 }
860 let agent = builder.build();
861 agent
862 .prompt(&user_prompt)
863 .await
864 .map_err(|e: PromptError| RigInferError::PromptError(e.to_string()))
865 }
866 #[cfg(feature = "native-inference")]
867 RigProvider::Native(runtime) => {
868 let chat_options = super::native::ChatOptions {
870 temperature: effective_temperature.map(|t| t as f32),
871 max_tokens: options.max_tokens,
872 ..Default::default()
873 };
874 runtime
875 .infer(&user_prompt, chat_options)
876 .await
877 .map(|r| r.message.content)
878 .map_err(|e: super::native::NativeError| {
879 RigInferError::PromptError(e.to_string())
880 })
881 }
882 }
883 }
884
885 pub fn auto() -> Option<Self> {
899 use crate::core::providers::{ProviderCategory, KNOWN_PROVIDERS};
900
901 for p in KNOWN_PROVIDERS.iter() {
903 if p.category == ProviderCategory::Llm && p.has_env_key() {
904 return match p.id {
905 "anthropic" => Some(Self::claude()),
906 "openai" => Some(Self::openai()),
907 "mistral" => Some(Self::mistral()),
908 "groq" => Some(Self::groq()),
909 "deepseek" => Some(Self::deepseek()),
910 "gemini" => Some(Self::gemini()),
911 "xai" => Some(Self::xai()),
912 _ => continue,
913 };
914 }
915 }
916 #[cfg(feature = "native-inference")]
918 if std::env::var("NIKA_NATIVE_MODEL").is_ok_and(|v| !v.trim().is_empty()) {
919 return Some(Self::native());
920 }
921 None
922 }
923
924 pub async fn verify(&self) -> Result<ProviderVerifyResult, ProviderVerifyError> {
938 use std::time::Instant;
939
940 let start = Instant::now();
941
942 let test_prompt = "Hi";
944
945 match self.infer(test_prompt, None).await {
946 Ok(_) => Ok(ProviderVerifyResult {
947 provider: self.name().to_string(),
948 latency: start.elapsed(),
949 model: self.default_model().to_string(),
950 }),
951 Err(e) => {
952 let error_msg = e.to_string().to_lowercase();
953
954 if error_msg.contains("401")
956 || error_msg.contains("unauthorized")
957 || error_msg.contains("invalid api key")
958 || error_msg.contains("authentication")
959 {
960 Err(ProviderVerifyError::InvalidApiKey {
961 provider: self.name().to_string(),
962 })
963 } else if error_msg.contains("rate limit")
964 || error_msg.contains("429")
965 || error_msg.contains("too many requests")
966 {
967 Err(ProviderVerifyError::RateLimited {
968 provider: self.name().to_string(),
969 })
970 } else if error_msg.contains("timeout")
971 || error_msg.contains("timed out")
972 || error_msg.contains("deadline")
973 {
974 Err(ProviderVerifyError::Timeout {
975 provider: self.name().to_string(),
976 })
977 } else if error_msg.contains("connection")
978 || error_msg.contains("network")
979 || error_msg.contains("dns")
980 || error_msg.contains("refused")
981 {
982 Err(ProviderVerifyError::NetworkError {
983 provider: self.name().to_string(),
984 details: e.to_string(),
985 })
986 } else {
987 Err(ProviderVerifyError::ProviderError {
988 provider: self.name().to_string(),
989 details: e.to_string(),
990 })
991 }
992 }
993 }
994 }
995
996 pub fn is_configured(&self) -> bool {
1001 let has_key = |key: &str| std::env::var(key).is_ok_and(|v| !v.trim().is_empty());
1002
1003 match self {
1004 RigProvider::Claude(_) => has_key("ANTHROPIC_API_KEY"),
1005 RigProvider::OpenAI(_) => has_key("OPENAI_API_KEY"),
1006 RigProvider::Mistral(_) => has_key("MISTRAL_API_KEY"),
1007 RigProvider::Groq(_) => has_key("GROQ_API_KEY"),
1008 RigProvider::DeepSeek(_) => has_key("DEEPSEEK_API_KEY"),
1009 RigProvider::Gemini(_) => has_key("GEMINI_API_KEY"),
1010 RigProvider::XAi(_) => has_key("XAI_API_KEY"),
1011 #[cfg(feature = "native-inference")]
1012 RigProvider::Native(_) => {
1013 true
1016 }
1017 }
1018 }
1019}
1020
/// Successful outcome of `RigProvider::verify`.
#[derive(Debug, Clone)]
pub struct ProviderVerifyResult {
    /// Provider short name, as returned by `RigProvider::name`.
    pub provider: String,
    /// Wall-clock time of the verification prompt round trip.
    pub latency: std::time::Duration,
    /// Model used for the probe: the provider's `default_model()`.
    pub model: String,
}
1035
/// Categorised failure from `RigProvider::verify`.
///
/// The category is chosen heuristically from keywords in the underlying
/// inference error's message. `suggestion()` gives a user-facing hint for each
/// category.
#[derive(Debug, Clone, thiserror::Error)]
pub enum ProviderVerifyError {
    /// Message mentioned `401`, `unauthorized`, `invalid api key`, or
    /// `authentication`.
    #[error("Invalid API key for {provider}")]
    InvalidApiKey { provider: String },

    /// Message mentioned `rate limit`, `429`, or `too many requests`.
    #[error("Rate limited by {provider}")]
    RateLimited { provider: String },

    /// Message mentioned `timeout`, `timed out`, or `deadline`.
    #[error("Connection timeout to {provider}")]
    Timeout { provider: String },

    /// Message mentioned `connection`, `network`, `dns`, or `refused`;
    /// `details` holds the original error text.
    #[error("Network error connecting to {provider}: {details}")]
    NetworkError { provider: String, details: String },

    /// Any other failure; `details` holds the original error text.
    #[error("Provider error from {provider}: {details}")]
    ProviderError { provider: String, details: String },
}
1054
1055impl ProviderVerifyError {
1056 pub fn suggestion(&self) -> &'static str {
1058 match self {
1059 ProviderVerifyError::InvalidApiKey { .. } => {
1060 "Check your API key in environment variables"
1061 }
1062 ProviderVerifyError::RateLimited { .. } => {
1063 "Wait a moment and try again, or check your plan limits"
1064 }
1065 ProviderVerifyError::Timeout { .. } => "Check your network connection or try again",
1066 ProviderVerifyError::NetworkError { .. } => {
1067 "Check your internet connection and firewall settings"
1068 }
1069 ProviderVerifyError::ProviderError { .. } => {
1070 "The provider service may be experiencing issues"
1071 }
1072 }
1073 }
1074}
1075
/// Errors returned by the `RigProvider` inference methods.
#[derive(Debug, thiserror::Error)]
pub enum RigInferError {
    /// Provider or runtime failure, carrying the underlying error's text.
    /// Also used for API misuse, e.g. `load_native_model` on a non-native
    /// provider, or tool calls on the native backend.
    #[error("Completion error: {0}")]
    PromptError(String),

    /// A deadline elapsed. Produced both for the per-chunk stall timeout in
    /// `consume_rig_stream` and for the whole-request timeout in
    /// `RigProvider::infer`.
    ///
    /// NOTE(review): the "Stream timeout: no chunk received" wording is
    /// misleading in the whole-request case. `RigProvider::verify` relies on
    /// the message containing "timeout", so keep that word if it is reworded.
    #[error("Stream timeout: no chunk received for {duration_ms}ms")]
    Timeout { duration_ms: u64 },

    /// Vision input was rejected: the provider (DeepSeek) or the loaded native
    /// model lacks vision support, or the content-parts list was empty.
    #[error("Vision not supported: {0}")]
    VisionNotSupported(String),
}
1090
/// Progress and output events sent over a `tokio::sync::mpsc` channel to a
/// consumer. The warning in `infer_vision_stream` names a TUI as one such
/// consumer.
///
/// In this section only `Token`, `Thinking`, `Done` and `Error` are produced.
/// NOTE(review): the remaining variants are presumably emitted elsewhere in the
/// crate; their field semantics are not established here.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A text delta from a streaming completion.
    Token(String),
    /// A reasoning delta; forwarded only when thinking capture is enabled.
    Thinking(String),
    /// Complete response text (here: sent once by the native path of
    /// `infer_vision_stream`).
    Done(String),
    /// Error description, e.g. a stream error or a chunk-stall timeout.
    Error(String),
    /// Token usage counts.
    Metrics {
        input_tokens: u64,
        output_tokens: u64,
    },
    // --- MCP server / tool-call events ---
    McpConnected(String),
    McpError { server_name: String, error: String },
    McpCallStart {
        tool: String,
        server: String,
        params: String,
    },
    McpCallComplete { result: String },
    McpCallFailed { error: String },
    // --- Inference progress events ---
    InferStart {
        model: String,
        prompt: String,
        prompt_tokens: u32,
        max_tokens: u32,
    },
    InferTokens { output_tokens: u32 },
    InferComplete,
    // --- Exec / fetch / agent task events ---
    ExecStart { command: String },
    ExecComplete,
    FetchStart { url: String, method: String },
    FetchComplete,
    AgentStart { goal: String },
    AgentComplete,
    // --- Provider verification and MCP ping events ---
    ProviderVerifying { provider: String, model: String },
    ProviderVerified {
        provider: String,
        model: String,
        latency_ms: u64,
    },
    ProviderVerifyFailed { provider: String, error: String },
    ProviderNotConfigured { provider: String },
    McpPinging { server: String },
    McpPinged {
        server: String,
        latency_ms: u64,
        tool_count: usize,
    },
    ProviderVerificationTimeout,
    // --- Native model management events ---
    NativeModelPullStarted { model: String },
    NativeModelPullProgress {
        model: String,
        status: String,
        completed: u64,
        total: u64,
    },
    NativeModelPulled {
        model: String,
        path: String,
        size: u64,
    },
    NativeModelPullFailed { model: String, error: String },
    NativeModelDeleted { model: String },
    NativeModelDeleteFailed { model: String, error: String },
    NativeModelsRefreshed { count: usize },
}
1207
/// Aggregate outcome of a streamed inference.
///
/// `Default` gives empty text, zero counters, and `None` for the optional
/// fields.
#[derive(Debug, Clone, Default)]
pub struct StreamResult {
    /// Full response text; streaming callers join all text deltas into it.
    pub text: String,
    /// Prompt tokens reported by the provider's final response, if any.
    pub input_tokens: u64,
    /// Completion tokens reported by the provider's final response, if any.
    pub output_tokens: u64,
    /// Total tokens reported by the provider's final response, if any.
    pub total_tokens: u64,
    /// Cached prompt tokens reported by the provider's final response, if any.
    pub cached_input_tokens: u64,
    /// Milliseconds from stream start to the first text delta.
    pub ttft_ms: Option<u64>,
    /// Provider request identifier, when known (not set by `consume_rig_stream`).
    pub request_id: Option<String>,
}
1230
1231impl StreamResult {
1232 pub fn from_text(text: impl Into<String>) -> Self {
1234 Self {
1235 text: text.into(),
1236 ..Default::default()
1237 }
1238 }
1239}
1240
1241async fn consume_rig_stream<R>(
1251 stream: &mut rig::streaming::StreamingCompletionResponse<R>,
1252 tx: &mpsc::Sender<StreamChunk>,
1253 response_parts: &mut Vec<String>,
1254 result: &mut StreamResult,
1255 capture_thinking: bool,
1256 stream_start: Instant,
1257) -> Result<(), RigInferError>
1258where
1259 R: Clone + Unpin + GetTokenUsage + serde::Serialize + serde::de::DeserializeOwned,
1260{
1261 loop {
1262 let chunk_result = match timeout(STREAM_CHUNK_TIMEOUT, stream.next()).await {
1263 Ok(Some(result)) => result,
1264 Ok(None) => break,
1265 Err(_elapsed) => {
1266 let _ = tx.try_send(StreamChunk::Error(format!(
1267 "Stream timeout: no chunk received for {}s",
1268 STREAM_CHUNK_TIMEOUT.as_secs()
1269 )));
1270 return Err(RigInferError::Timeout {
1271 duration_ms: STREAM_CHUNK_TIMEOUT.as_millis() as u64,
1272 });
1273 }
1274 };
1275
1276 match chunk_result {
1277 Ok(content) => match content {
1278 StreamedAssistantContent::Text(text) => {
1279 if result.ttft_ms.is_none() {
1281 result.ttft_ms = Some(stream_start.elapsed().as_millis() as u64);
1282 }
1283 response_parts.push(text.text.clone());
1284 let _ = tx.try_send(StreamChunk::Token(text.text));
1285 }
1286 StreamedAssistantContent::ReasoningDelta { reasoning, .. } if capture_thinking => {
1287 let _ = tx.try_send(StreamChunk::Thinking(reasoning));
1288 }
1289 StreamedAssistantContent::Final(response) => {
1290 if let Some(usage) = response.token_usage() {
1291 result.input_tokens = usage.input_tokens;
1292 result.output_tokens = usage.output_tokens;
1293 result.total_tokens = usage.total_tokens;
1294 result.cached_input_tokens = usage.cached_input_tokens;
1295 }
1296 }
1297 _ => {}
1298 },
1299 Err(e) => {
1300 let _ = tx.try_send(StreamChunk::Error(e.to_string()));
1301 return Err(RigInferError::PromptError(e.to_string()));
1302 }
1303 }
1304 }
1305 Ok(())
1306}
1307
1308impl RigProvider {
1309 pub async fn infer_stream(
1321 &self,
1322 prompt: &str,
1323 tx: mpsc::Sender<StreamChunk>,
1324 model: Option<&str>,
1325 ) -> Result<StreamResult, RigInferError> {
1326 let model_id = model.unwrap_or_else(|| self.default_model());
1327 let mut response_parts: Vec<String> = Vec::new();
1328 let mut result = StreamResult::default();
1329
1330 match self {
1331 RigProvider::Claude(client) => {
1332 let model = client.completion_model(model_id);
1333 let request = model.completion_request(prompt).max_tokens(8192).build();
1334 let stream_start = Instant::now();
1335 let mut stream = model
1336 .stream(request)
1337 .await
1338 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1339 consume_rig_stream(
1340 &mut stream,
1341 &tx,
1342 &mut response_parts,
1343 &mut result,
1344 true,
1345 stream_start,
1346 )
1347 .await?;
1348 }
1349 RigProvider::OpenAI(client) => {
1350 let model = client.completion_model(model_id);
1351 let request = model.completion_request(prompt).max_tokens(8192).build();
1352 let stream_start = Instant::now();
1353 let mut stream = model
1354 .stream(request)
1355 .await
1356 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1357 consume_rig_stream(
1358 &mut stream,
1359 &tx,
1360 &mut response_parts,
1361 &mut result,
1362 false,
1363 stream_start,
1364 )
1365 .await?;
1366 }
1367 RigProvider::Mistral(client) => {
1368 let model = client.completion_model(model_id);
1369 let request = model.completion_request(prompt).max_tokens(8192).build();
1370 let stream_start = Instant::now();
1371 let mut stream = model
1372 .stream(request)
1373 .await
1374 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1375 consume_rig_stream(
1376 &mut stream,
1377 &tx,
1378 &mut response_parts,
1379 &mut result,
1380 false,
1381 stream_start,
1382 )
1383 .await?;
1384 }
1385 RigProvider::Groq(client) => {
1386 let model = client.completion_model(model_id);
1387 let request = model.completion_request(prompt).max_tokens(8192).build();
1388 let stream_start = Instant::now();
1389 let mut stream = model
1390 .stream(request)
1391 .await
1392 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1393 consume_rig_stream(
1394 &mut stream,
1395 &tx,
1396 &mut response_parts,
1397 &mut result,
1398 false,
1399 stream_start,
1400 )
1401 .await?;
1402 }
1403 RigProvider::DeepSeek(client) => {
1404 let model = client.completion_model(model_id);
1405 let request = model.completion_request(prompt).max_tokens(8192).build();
1406 let stream_start = Instant::now();
1407 let mut stream = model
1408 .stream(request)
1409 .await
1410 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1411 consume_rig_stream(
1412 &mut stream,
1413 &tx,
1414 &mut response_parts,
1415 &mut result,
1416 false,
1417 stream_start,
1418 )
1419 .await?;
1420 }
1421 RigProvider::Gemini(client) => {
1422 let model = client.completion_model(model_id);
1423 let request = model.completion_request(prompt).max_tokens(8192).build();
1424 let stream_start = Instant::now();
1425 let mut stream = model
1426 .stream(request)
1427 .await
1428 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1429 consume_rig_stream(
1430 &mut stream,
1431 &tx,
1432 &mut response_parts,
1433 &mut result,
1434 false,
1435 stream_start,
1436 )
1437 .await?;
1438 }
1439 RigProvider::XAi(client) => {
1440 let model = client.completion_model(model_id);
1441 let request = model.completion_request(prompt).max_tokens(8192).build();
1442 let stream_start = Instant::now();
1443 let mut stream = model
1444 .stream(request)
1445 .await
1446 .map_err(|e| RigInferError::PromptError(e.to_string()))?;
1447 consume_rig_stream(
1448 &mut stream,
1449 &tx,
1450 &mut response_parts,
1451 &mut result,
1452 false,
1453 stream_start,
1454 )
1455 .await?;
1456 }
1457 #[cfg(feature = "native-inference")]
1459 RigProvider::Native(runtime) => {
1460 use futures::StreamExt;
1461 use std::pin::pin;
1462
1463 let stream = runtime
1465 .infer_stream(prompt, super::native::ChatOptions::default())
1466 .await
1467 .map_err(|e: super::native::NativeError| {
1468 RigInferError::PromptError(e.to_string())
1469 })?;
1470
1471 let mut stream = pin!(stream);
1473
1474 while let Some(result) = stream.next().await {
1476 match result {
1477 Ok(token) => {
1478 response_parts.push(token.clone());
1479 let _ = tx.try_send(StreamChunk::Token(token));
1480 }
1481 Err(e) => {
1482 let _ = tx.try_send(StreamChunk::Error(e.to_string()));
1483 return Err(RigInferError::PromptError(e.to_string()));
1484 }
1485 }
1486 }
1487
1488 }
1491 }
1492
1493 let complete_response = response_parts.concat();
1494 let _ = tx.try_send(StreamChunk::Done(complete_response.clone()));
1495
1496 let _ = tx.try_send(StreamChunk::Metrics {
1498 input_tokens: result.input_tokens,
1499 output_tokens: result.output_tokens,
1500 });
1501
1502 result.text = complete_response;
1503 Ok(result)
1504 }
1505
1506 pub async fn infer_stream_with_options(
1519 &self,
1520 prompt: &str,
1521 tx: mpsc::Sender<StreamChunk>,
1522 options: &InferOptions,
1523 ) -> Result<StreamResult, RigInferError> {
1524 let model_id = options
1525 .model
1526 .as_deref()
1527 .unwrap_or_else(|| self.default_model());
1528 let max_tokens = options.max_tokens.unwrap_or(8192);
1529 let mut response_parts: Vec<String> = Vec::new();
1530 let mut result = StreamResult::default();
1531
1532 let effective_temperature = if options.temperature.is_some() && is_reasoning_model(model_id)
1534 {
1535 tracing::warn!(
1536 model = %model_id,
1537 "temperature ignored for reasoning model '{}' (not supported)",
1538 model_id
1539 );
1540 None
1541 } else {
1542 options.temperature
1543 };
1544
1545 macro_rules! build_request_with_options {
1549 ($client:expr) => {{
1550 let model = $client.completion_model(model_id);
1551 let mut rb = model
1552 .completion_request(prompt)
1553 .max_tokens(max_tokens as u64);
1554 if let Some(ref system) = options.system {
1555 rb = rb.preamble(system.clone());
1556 }
1557 if let Some(temp) = effective_temperature {
1558 rb = rb.temperature(temp);
1559 }
1560 model
1561 .stream(rb.build())
1562 .await
1563 .map_err(|e| RigInferError::PromptError(e.to_string()))?
1564 }};
1565 }
1566
1567 match self {
1568 RigProvider::Claude(client) => {
1569 let stream_start = Instant::now();
1570 let mut stream = build_request_with_options!(client);
1571 consume_rig_stream(
1572 &mut stream,
1573 &tx,
1574 &mut response_parts,
1575 &mut result,
1576 true,
1577 stream_start,
1578 )
1579 .await?;
1580 }
1581 RigProvider::OpenAI(client) => {
1582 let stream_start = Instant::now();
1583 let mut stream = build_request_with_options!(client);
1584 consume_rig_stream(
1585 &mut stream,
1586 &tx,
1587 &mut response_parts,
1588 &mut result,
1589 false,
1590 stream_start,
1591 )
1592 .await?;
1593 }
1594 RigProvider::Mistral(client) => {
1595 let stream_start = Instant::now();
1596 let mut stream = build_request_with_options!(client);
1597 consume_rig_stream(
1598 &mut stream,
1599 &tx,
1600 &mut response_parts,
1601 &mut result,
1602 false,
1603 stream_start,
1604 )
1605 .await?;
1606 }
1607 RigProvider::Groq(client) => {
1608 let stream_start = Instant::now();
1609 let mut stream = build_request_with_options!(client);
1610 consume_rig_stream(
1611 &mut stream,
1612 &tx,
1613 &mut response_parts,
1614 &mut result,
1615 false,
1616 stream_start,
1617 )
1618 .await?;
1619 }
1620 RigProvider::DeepSeek(client) => {
1621 let stream_start = Instant::now();
1622 let mut stream = build_request_with_options!(client);
1623 consume_rig_stream(
1624 &mut stream,
1625 &tx,
1626 &mut response_parts,
1627 &mut result,
1628 false,
1629 stream_start,
1630 )
1631 .await?;
1632 }
1633 RigProvider::Gemini(client) => {
1634 let stream_start = Instant::now();
1635 let mut stream = build_request_with_options!(client);
1636 consume_rig_stream(
1637 &mut stream,
1638 &tx,
1639 &mut response_parts,
1640 &mut result,
1641 false,
1642 stream_start,
1643 )
1644 .await?;
1645 }
1646 RigProvider::XAi(client) => {
1647 let stream_start = Instant::now();
1648 let mut stream = build_request_with_options!(client);
1649 consume_rig_stream(
1650 &mut stream,
1651 &tx,
1652 &mut response_parts,
1653 &mut result,
1654 false,
1655 stream_start,
1656 )
1657 .await?;
1658 }
1659 #[cfg(feature = "native-inference")]
1661 RigProvider::Native(runtime) => {
1662 use futures::StreamExt;
1663 use std::pin::pin;
1664
1665 let native_prompt = if let Some(ref system) = options.system {
1667 format!("{}\n\n{}", system, prompt)
1668 } else {
1669 prompt.to_string()
1670 };
1671 let chat_options = super::native::ChatOptions {
1672 temperature: effective_temperature.map(|t| t as f32),
1673 max_tokens: options.max_tokens,
1674 ..Default::default()
1675 };
1676 let stream = runtime
1677 .infer_stream(&native_prompt, chat_options)
1678 .await
1679 .map_err(|e: super::native::NativeError| {
1680 RigInferError::PromptError(e.to_string())
1681 })?;
1682
1683 let mut stream = pin!(stream);
1685
1686 while let Some(result) = stream.next().await {
1688 match result {
1689 Ok(token) => {
1690 response_parts.push(token.clone());
1691 let _ = tx.try_send(StreamChunk::Token(token));
1692 }
1693 Err(e) => {
1694 let _ = tx.try_send(StreamChunk::Error(e.to_string()));
1695 return Err(RigInferError::PromptError(e.to_string()));
1696 }
1697 }
1698 }
1699
1700 }
1703 }
1704
1705 let complete_response = response_parts.concat();
1706 let _ = tx.try_send(StreamChunk::Done(complete_response.clone()));
1707
1708 let _ = tx.try_send(StreamChunk::Metrics {
1709 input_tokens: result.input_tokens,
1710 output_tokens: result.output_tokens,
1711 });
1712
1713 result.text = complete_response;
1714 Ok(result)
1715 }
1716}
1717
1718#[derive(Debug, Clone)]
1727pub struct NikaMcpToolDef {
1728 pub name: String,
1730 pub description: String,
1732 pub input_schema: serde_json::Value,
1734}
1735
1736pub type AgentMediaStaging = Arc<dashmap::DashMap<String, Vec<crate::mcp::types::ContentBlock>>>;
1743
1744#[derive(Debug, Clone)]
1752pub struct NikaMcpTool {
1753 definition: NikaMcpToolDef,
1754 client: Option<Arc<McpClient>>,
1756 media_staging: Option<AgentMediaStaging>,
1758}
1759
1760impl NikaMcpTool {
1761 pub fn new(definition: NikaMcpToolDef) -> Self {
1763 Self {
1764 definition,
1765 client: None,
1766 media_staging: None,
1767 }
1768 }
1769
1770 pub fn with_client(definition: NikaMcpToolDef, client: Arc<McpClient>) -> Self {
1772 Self {
1773 definition,
1774 client: Some(client),
1775 media_staging: None,
1776 }
1777 }
1778
1779 pub fn with_media_staging(
1781 definition: NikaMcpToolDef,
1782 client: Arc<McpClient>,
1783 staging: AgentMediaStaging,
1784 ) -> Self {
1785 Self {
1786 definition,
1787 client: Some(client),
1788 media_staging: Some(staging),
1789 }
1790 }
1791
1792 pub fn tool_name(&self) -> &str {
1794 &self.definition.name
1795 }
1796}
1797
/// Heap-allocated, pinned, `Send` future borrowing for `'a`.
///
/// This is the return shape required by rig's `ToolDyn` methods.
type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
1800
1801impl ToolDyn for NikaMcpTool {
1802 fn name(&self) -> String {
1803 self.definition.name.clone()
1804 }
1805
1806 fn definition(&self, _prompt: String) -> BoxFuture<'_, ToolDefinition> {
1807 let def = ToolDefinition {
1808 name: self.definition.name.clone(),
1809 description: self.definition.description.clone(),
1810 parameters: self.definition.input_schema.clone(),
1811 };
1812 Box::pin(async move { def })
1813 }
1814
1815 fn call(&self, args: String) -> BoxFuture<'_, Result<String, ToolError>> {
1816 let tool_name = self.definition.name.clone();
1817 let client = self.client.clone();
1818
1819 Box::pin(async move {
1820 let params: serde_json::Value = serde_json::from_str(&args).map_err(|e| {
1822 ToolError::ToolCallError(Box::new(McpToolError::invalid_args(format!(
1823 "Invalid JSON arguments: {}",
1824 e
1825 ))))
1826 })?;
1827
1828 let client = client.ok_or_else(|| {
1830 ToolError::ToolCallError(Box::new(McpToolError::not_configured(
1831 "No MCP client configured for this tool",
1832 )))
1833 })?;
1834
1835 let result = client.call_tool(&tool_name, params).await.map_err(|e| {
1837 ToolError::ToolCallError(Box::new(McpToolError::call_failed(format!(
1838 "MCP tool call failed: {}",
1839 e
1840 ))))
1841 })?;
1842
1843 if result.has_media() {
1845 if let Some(ref staging) = self.media_staging {
1846 let media_blocks: Vec<_> = result.media_blocks().into_iter().cloned().collect();
1847 if !media_blocks.is_empty() {
1848 tracing::debug!(
1849 tool = %tool_name,
1850 media_count = media_blocks.len(),
1851 "agent: staging binary content from tool call"
1852 );
1853 staging
1854 .entry(tool_name.clone())
1855 .or_default()
1856 .extend(media_blocks);
1857 }
1858 } else {
1859 tracing::warn!(
1860 tool = %tool_name,
1861 media_count = result.media_blocks().len(),
1862 "agent: tool returned binary content but no media staging configured — data will be lost"
1863 );
1864 }
1865 }
1866
1867 let output = result.text();
1869
1870 if output.is_empty() {
1871 serde_json::to_string(&result).map_err(|e| {
1873 ToolError::ToolCallError(Box::new(McpToolError::serialization(format!(
1874 "Failed to serialize result: {}",
1875 e
1876 ))))
1877 })
1878 } else {
1879 Ok(output)
1880 }
1881 })
1882 }
1883}
1884
#[cfg(feature = "native-inference")]
/// Splits a user message into text and decoded images for native vision.
///
/// Text parts are joined with `"\n"`. Image parts are decoded into raw bytes
/// with a MIME type; a missing or unrecognised media type falls back to
/// `image/png`. All other content kinds are ignored.
///
/// # Errors
///
/// - `RigInferError::PromptError` if base64 image data fails to decode, or
///   for any image source other than base64, raw bytes, or URL.
/// - `RigInferError::VisionNotSupported` for URL images, which must be
///   pre-fetched.
fn extract_native_vision_parts(
    user_content: &[rig::completion::message::UserContent],
) -> Result<(String, Vec<crate::core::backend::VisionImage>), RigInferError> {
    use base64::Engine as _;
    use rig::completion::message::{DocumentSourceKind, Image, ImageMediaType, UserContent};

    let mut texts: Vec<String> = Vec::new();
    let mut decoded: Vec<crate::core::backend::VisionImage> = Vec::new();

    for part in user_content {
        let Image {
            data, media_type, ..
        } = match part {
            UserContent::Text(text) => {
                texts.push(text.text.clone());
                continue;
            }
            UserContent::Image(image) => image,
            _ => continue,
        };

        let bytes = match data {
            DocumentSourceKind::Base64(b64) => base64::engine::general_purpose::STANDARD
                .decode(b64)
                .map_err(|e| {
                    RigInferError::PromptError(format!(
                        "Failed to decode base64 image for native vision: {}",
                        e
                    ))
                })?,
            DocumentSourceKind::Raw(raw) => raw.clone(),
            DocumentSourceKind::Url(url) => {
                return Err(RigInferError::VisionNotSupported(format!(
                    "Native vision does not support URL images. Pre-fetch the image: {}",
                    url
                )));
            }
            _ => {
                return Err(RigInferError::PromptError(
                    "Unsupported image source kind for native vision".to_string(),
                ));
            }
        };

        let mime = match media_type {
            Some(ImageMediaType::JPEG) => "image/jpeg",
            Some(ImageMediaType::GIF) => "image/gif",
            Some(ImageMediaType::WEBP) => "image/webp",
            // PNG, any other media type, or none at all.
            _ => "image/png",
        };

        decoded.push(crate::core::backend::VisionImage::new(bytes, mime));
    }

    Ok((texts.join("\n"), decoded))
}
1959
1960#[cfg(test)]
1961mod tests {
1962 use super::*;
1963 use serial_test::serial;
1964
1965 #[test]
1970 fn stream_result_from_text_has_zero_tokens() {
1971 let result = StreamResult::from_text("hello world");
1972 assert_eq!(result.text, "hello world");
1973 assert_eq!(result.input_tokens, 0);
1974 assert_eq!(result.output_tokens, 0);
1975 assert_eq!(result.total_tokens, 0);
1976 assert_eq!(result.cached_input_tokens, 0);
1977 }
1978
1979 #[test]
1980 fn stream_result_default_is_empty() {
1981 let result = StreamResult::default();
1982 assert_eq!(result.text, "");
1983 assert_eq!(result.total_tokens, 0);
1984 }
1985
1986 #[test]
1987 fn stream_result_with_tokens() {
1988 let result = StreamResult {
1989 text: "response".to_string(),
1990 input_tokens: 100,
1991 output_tokens: 50,
1992 total_tokens: 150,
1993 cached_input_tokens: 20,
1994 ttft_ms: None,
1995 request_id: None,
1996 };
1997 assert_eq!(
1998 result.total_tokens,
1999 result.input_tokens + result.output_tokens
2000 );
2001 assert_eq!(result.cached_input_tokens, 20);
2002 }
2003
2004 #[test]
2005 #[serial]
2006 fn test_rig_provider_claude_returns_claude_variant() {
2007 std::env::set_var("ANTHROPIC_API_KEY", "test-key-for-unit-test");
2013 let provider = RigProvider::claude();
2014
2015 assert_eq!(provider.name(), "claude");
2016 assert!(matches!(provider, RigProvider::Claude(_)));
2017 }
2018
2019 #[test]
2020 #[serial]
2021 fn test_rig_provider_openai_returns_openai_variant() {
2022 std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
2023 let provider = RigProvider::openai();
2024
2025 assert_eq!(provider.name(), "openai");
2026 assert!(matches!(provider, RigProvider::OpenAI(_)));
2027 }
2028
2029 #[test]
2030 #[serial]
2031 fn test_rig_provider_default_model_claude() {
2032 std::env::set_var("ANTHROPIC_API_KEY", "test-key-for-unit-test");
2033 let provider = RigProvider::claude();
2034
2035 assert_eq!(provider.default_model(), "claude-sonnet-4-6");
2038 }
2039
2040 #[test]
2041 #[serial]
2042 fn test_rig_provider_default_model_openai() {
2043 std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
2044 let provider = RigProvider::openai();
2045
2046 assert_eq!(provider.default_model(), openai::GPT_4O);
2047 }
2048
2049 #[test]
2050 fn test_rig_infer_error_display() {
2051 let err = RigInferError::PromptError("Test error message".to_string());
2052 assert_eq!(err.to_string(), "Completion error: Test error message");
2053 }
2054
2055 #[test]
2056 fn test_rig_infer_error_timeout_display() {
2057 let err = RigInferError::Timeout { duration_ms: 60000 };
2059 assert_eq!(
2060 err.to_string(),
2061 "Stream timeout: no chunk received for 60000ms"
2062 );
2063 }
2064
2065 #[test]
2070 #[serial]
2071 fn test_rig_provider_mistral_returns_mistral_variant() {
2072 std::env::set_var("MISTRAL_API_KEY", "test-key-for-unit-test");
2073 let provider = RigProvider::mistral();
2074
2075 assert_eq!(provider.name(), "mistral");
2076 assert!(matches!(provider, RigProvider::Mistral(_)));
2077 }
2078
2079 #[test]
2080 #[serial]
2081 fn test_rig_provider_groq_returns_groq_variant() {
2082 std::env::set_var("GROQ_API_KEY", "test-key-for-unit-test");
2083 let provider = RigProvider::groq();
2084
2085 assert_eq!(provider.name(), "groq");
2086 assert!(matches!(provider, RigProvider::Groq(_)));
2087 }
2088
2089 #[test]
2090 #[serial]
2091 fn test_rig_provider_deepseek_returns_deepseek_variant() {
2092 std::env::set_var("DEEPSEEK_API_KEY", "test-key-for-unit-test");
2093 let provider = RigProvider::deepseek();
2094
2095 assert_eq!(provider.name(), "deepseek");
2096 assert!(matches!(provider, RigProvider::DeepSeek(_)));
2097 }
2098
2099 #[test]
2100 #[serial]
2101 fn test_rig_provider_default_models_v06() {
2102 std::env::set_var("MISTRAL_API_KEY", "test");
2104 std::env::set_var("GROQ_API_KEY", "test");
2105 std::env::set_var("DEEPSEEK_API_KEY", "test");
2106
2107 assert_eq!(
2108 RigProvider::mistral().default_model(),
2109 mistral::MISTRAL_LARGE
2110 );
2111 assert_eq!(
2112 RigProvider::groq().default_model(),
2113 "llama-3.3-70b-versatile"
2114 );
2115 assert_eq!(RigProvider::deepseek().default_model(), "deepseek-chat");
2116 }
2117
2118 #[test]
2119 #[serial]
2120 fn test_rig_provider_auto_detects_claude() {
2121 std::env::remove_var("OPENAI_API_KEY");
2123 std::env::remove_var("MISTRAL_API_KEY");
2124 std::env::remove_var("GROQ_API_KEY");
2125 std::env::remove_var("DEEPSEEK_API_KEY");
2126 std::env::set_var("ANTHROPIC_API_KEY", "test-key");
2127
2128 let provider = RigProvider::auto();
2129 assert!(provider.is_some());
2130 assert_eq!(provider.unwrap().name(), "claude");
2131 }
2132
2133 #[test]
2134 #[serial]
2135 fn test_rig_provider_auto_returns_none_when_no_keys() {
2136 clear_all_provider_env_vars();
2138
2139 let provider = RigProvider::auto();
2140 assert!(provider.is_none());
2141 }
2142
2143 fn clear_all_provider_env_vars() {
2149 std::env::remove_var("ANTHROPIC_API_KEY");
2150 std::env::remove_var("OPENAI_API_KEY");
2151 std::env::remove_var("MISTRAL_API_KEY");
2152 std::env::remove_var("GROQ_API_KEY");
2153 std::env::remove_var("DEEPSEEK_API_KEY");
2154 std::env::remove_var("GEMINI_API_KEY");
2155 }
2156
2157 #[test]
2158 #[serial]
2159 fn test_auto_fallback_to_openai() {
2160 clear_all_provider_env_vars();
2162 std::env::set_var("OPENAI_API_KEY", "test-key");
2163
2164 let provider = RigProvider::auto();
2166
2167 assert!(provider.is_some());
2169 assert_eq!(provider.unwrap().name(), "openai");
2170 }
2171
2172 #[test]
2173 #[serial]
2174 fn test_auto_fallback_to_mistral() {
2175 clear_all_provider_env_vars();
2177 std::env::set_var("MISTRAL_API_KEY", "test-key");
2178
2179 let provider = RigProvider::auto();
2181
2182 assert!(provider.is_some());
2184 assert_eq!(provider.unwrap().name(), "mistral");
2185 }
2186
2187 #[test]
2188 #[serial]
2189 fn test_auto_fallback_to_groq() {
2190 clear_all_provider_env_vars();
2192 std::env::set_var("GROQ_API_KEY", "test-key");
2193
2194 let provider = RigProvider::auto();
2196
2197 assert!(provider.is_some());
2199 assert_eq!(provider.unwrap().name(), "groq");
2200 }
2201
2202 #[test]
2203 #[serial]
2204 fn test_auto_fallback_to_deepseek() {
2205 clear_all_provider_env_vars();
2207 std::env::set_var("DEEPSEEK_API_KEY", "test-key");
2208
2209 let provider = RigProvider::auto();
2211
2212 assert!(provider.is_some());
2214 assert_eq!(provider.unwrap().name(), "deepseek");
2215 }
2216
2217 #[test]
2218 #[serial]
2219 fn test_auto_fallback_to_gemini() {
2220 clear_all_provider_env_vars();
2222 std::env::set_var("GEMINI_API_KEY", "test-key");
2223
2224 let provider = RigProvider::auto();
2226
2227 assert!(provider.is_some());
2229 assert_eq!(provider.unwrap().name(), "gemini");
2230 }
2231
2232 #[test]
2233 #[serial]
2234 fn test_auto_priority_claude_over_openai() {
2235 clear_all_provider_env_vars();
2237 std::env::set_var("ANTHROPIC_API_KEY", "claude-key");
2238 std::env::set_var("OPENAI_API_KEY", "openai-key");
2239
2240 let provider = RigProvider::auto();
2242
2243 assert!(provider.is_some());
2245 assert_eq!(provider.unwrap().name(), "claude");
2246 }
2247
2248 #[test]
2249 #[serial]
2250 fn test_auto_priority_openai_over_mistral() {
2251 clear_all_provider_env_vars();
2253 std::env::set_var("OPENAI_API_KEY", "openai-key");
2254 std::env::set_var("MISTRAL_API_KEY", "mistral-key");
2255
2256 let provider = RigProvider::auto();
2258
2259 assert!(provider.is_some());
2261 assert_eq!(provider.unwrap().name(), "openai");
2262 }
2263
2264 #[test]
2265 #[serial]
2266 fn test_auto_empty_env_var_treated_as_unset() {
2267 clear_all_provider_env_vars();
2269 std::env::set_var("ANTHROPIC_API_KEY", ""); std::env::set_var("OPENAI_API_KEY", "valid-key");
2271
2272 let provider = RigProvider::auto();
2274
2275 assert!(provider.is_some());
2277 assert_eq!(provider.unwrap().name(), "openai");
2278 }
2279
2280 #[test]
2281 #[serial]
2282 fn test_auto_whitespace_env_var_treated_as_unset() {
2283 clear_all_provider_env_vars();
2285 std::env::set_var("ANTHROPIC_API_KEY", " "); let provider = RigProvider::auto();
2289
2290 assert!(
2293 provider.is_none(),
2294 "Whitespace-only API key should be treated as unset"
2295 );
2296 }
2297
2298 #[test]
2303 fn test_nika_mcp_tool_implements_tool_dyn() {
2304 let tool_def = NikaMcpToolDef {
2306 name: "novanet_context".to_string(),
2307 description: "Generate native content for an entity".to_string(),
2308 input_schema: serde_json::json!({
2309 "type": "object",
2310 "properties": {
2311 "entity": { "type": "string" },
2312 "locale": { "type": "string" }
2313 },
2314 "required": ["entity", "locale"]
2315 }),
2316 };
2317
2318 let tool = NikaMcpTool::new(tool_def);
2320
2321 assert_eq!(tool.tool_name(), "novanet_context");
2323 }
2324
2325 #[test]
2326 fn test_nika_mcp_tool_definition_returns_correct_schema() {
2327 use rig::tool::ToolDyn;
2328
2329 let tool_def = NikaMcpToolDef {
2331 name: "novanet_describe".to_string(),
2332 description: "Describe an entity from the knowledge graph".to_string(),
2333 input_schema: serde_json::json!({
2334 "type": "object",
2335 "properties": {
2336 "entity_key": { "type": "string" }
2337 },
2338 "required": ["entity_key"]
2339 }),
2340 };
2341 let tool = NikaMcpTool::new(tool_def);
2342
2343 let name = tool.name();
2345
2346 assert_eq!(name, "novanet_describe");
2348 }
2349
2350 #[tokio::test]
2355 async fn test_nika_mcp_tool_call_uses_mcp_client() {
2356 use crate::mcp::McpClient;
2357 use rig::tool::ToolDyn;
2358 use std::sync::Arc;
2359
2360 let client = Arc::new(McpClient::mock("novanet"));
2362
2363 let tool_def = NikaMcpToolDef {
2365 name: "novanet_describe".to_string(),
2366 description: "Describe an entity".to_string(),
2367 input_schema: serde_json::json!({
2368 "type": "object",
2369 "properties": {
2370 "entity_key": { "type": "string" }
2371 },
2372 "required": ["entity_key"]
2373 }),
2374 };
2375 let tool = NikaMcpTool::with_client(tool_def, client);
2376
2377 let args = r#"{"entity_key": "qr-code"}"#.to_string();
2379 let result = tool.call(args).await;
2380
2381 assert!(result.is_ok(), "Tool call should succeed with mock client");
2383 let output = result.unwrap();
2384 assert!(!output.is_empty(), "Tool should return non-empty output");
2385 }
2386
2387 #[tokio::test]
2393 async fn test_usecase_novanet_context_entity_locale() {
2394 use crate::mcp::McpClient;
2395 use rig::tool::ToolDyn;
2396 use std::sync::Arc;
2397
2398 let client = Arc::new(McpClient::mock("novanet"));
2400
2401 let tool_def = NikaMcpToolDef {
2403 name: "novanet_context".to_string(),
2404 description: "Full RLM-on-KG context assembly for generation".to_string(),
2405 input_schema: serde_json::json!({
2406 "type": "object",
2407 "properties": {
2408 "focus_key": { "type": "string", "description": "Entity key to generate for" },
2409 "locale": { "type": "string", "description": "BCP-47 locale code" },
2410 "mode": { "type": "string", "enum": ["block", "page"], "default": "block" },
2411 "token_budget": { "type": "integer", "default": 4000 },
2412 "spreading_depth": { "type": "integer", "default": 2 },
2413 "forms": {
2414 "type": "array",
2415 "items": { "type": "string", "enum": ["text", "title", "abbrev", "url"] }
2416 }
2417 },
2418 "required": ["focus_key", "locale"]
2419 }),
2420 };
2421 let tool = NikaMcpTool::with_client(tool_def, client);
2422
2423 let args = serde_json::json!({
2425 "focus_key": "qr-code",
2426 "locale": "fr-FR",
2427 "mode": "page",
2428 "forms": ["text", "title", "abbrev"]
2429 })
2430 .to_string();
2431
2432 let result = tool.call(args).await;
2433
2434 assert!(
2436 result.is_ok(),
2437 "novanet_context should succeed: {:?}",
2438 result
2439 );
2440 let output = result.unwrap();
2441 assert!(!output.is_empty(), "Should return generation context");
2442 }
2443
2444 #[tokio::test]
2446 async fn test_usecase_novanet_describe_entity() {
2447 use crate::mcp::McpClient;
2448 use rig::tool::ToolDyn;
2449 use std::sync::Arc;
2450
2451 let client = Arc::new(McpClient::mock("novanet"));
2452
2453 let tool_def = NikaMcpToolDef {
2454 name: "novanet_describe".to_string(),
2455 description: "Bootstrap agent understanding of the knowledge graph".to_string(),
2456 input_schema: serde_json::json!({
2457 "type": "object",
2458 "properties": {
2459 "describe": {
2460 "type": "string",
2461 "enum": ["schema", "entity", "category", "relations", "locales", "stats"]
2462 },
2463 "entity_key": { "type": "string" },
2464 "category_key": { "type": "string" }
2465 },
2466 "required": ["describe"]
2467 }),
2468 };
2469 let tool = NikaMcpTool::with_client(tool_def, client);
2470
2471 let args = serde_json::json!({
2473 "describe": "schema"
2474 })
2475 .to_string();
2476
2477 let result = tool.call(args).await;
2478 assert!(result.is_ok(), "novanet_describe should succeed");
2479 }
2480
2481 #[tokio::test]
2483 async fn test_usecase_novanet_search_walk_graph() {
2484 use crate::mcp::McpClient;
2485 use rig::tool::ToolDyn;
2486 use std::sync::Arc;
2487
2488 let client = Arc::new(McpClient::mock("novanet"));
2489
2490 let tool_def = NikaMcpToolDef {
2491 name: "novanet_search".to_string(),
2492 description: "Graph traversal with configurable depth and filters".to_string(),
2493 input_schema: serde_json::json!({
2494 "type": "object",
2495 "properties": {
2496 "start_key": { "type": "string" },
2497 "max_depth": { "type": "integer", "default": 2 },
2498 "direction": { "type": "string", "enum": ["outgoing", "incoming", "both"] },
2499 "arc_families": { "type": "array", "items": { "type": "string" } },
2500 "target_kinds": { "type": "array", "items": { "type": "string" } }
2501 },
2502 "required": ["start_key"]
2503 }),
2504 };
2505 let tool = NikaMcpTool::with_client(tool_def, client);
2506
2507 let args = serde_json::json!({
2509 "start_key": "qr-code",
2510 "max_depth": 2,
2511 "direction": "outgoing",
2512 "arc_families": ["ownership", "localization"]
2513 })
2514 .to_string();
2515
2516 let result = tool.call(args).await;
2517 assert!(result.is_ok(), "novanet_search walk should succeed");
2518 }
2519
2520 #[tokio::test]
2522 async fn test_usecase_novanet_search_hybrid() {
2523 use crate::mcp::McpClient;
2524 use rig::tool::ToolDyn;
2525 use std::sync::Arc;
2526
2527 let client = Arc::new(McpClient::mock("novanet"));
2528
2529 let tool_def = NikaMcpToolDef {
2530 name: "novanet_search".to_string(),
2531 description: "Fulltext + property search with hybrid mode".to_string(),
2532 input_schema: serde_json::json!({
2533 "type": "object",
2534 "properties": {
2535 "query": { "type": "string" },
2536 "mode": { "type": "string", "enum": ["fulltext", "property", "hybrid"] },
2537 "kinds": { "type": "array", "items": { "type": "string" } },
2538 "realm": { "type": "string", "enum": ["shared", "org"] },
2539 "limit": { "type": "integer", "default": 10 }
2540 },
2541 "required": ["query"]
2542 }),
2543 };
2544 let tool = NikaMcpTool::with_client(tool_def, client);
2545
2546 let args = serde_json::json!({
2548 "query": "QR code generator",
2549 "mode": "hybrid",
2550 "kinds": ["Entity", "Page"],
2551 "limit": 5
2552 })
2553 .to_string();
2554
2555 let result = tool.call(args).await;
2556 assert!(result.is_ok(), "novanet_search should succeed");
2557 }
2558
2559 #[tokio::test]
2561 async fn test_usecase_novanet_audit_locale() {
2562 use crate::mcp::McpClient;
2563 use rig::tool::ToolDyn;
2564 use std::sync::Arc;
2565
2566 let client = Arc::new(McpClient::mock("novanet"));
2567
2568 let tool_def = NikaMcpToolDef {
2569 name: "novanet_audit".to_string(),
2570 description: "Retrieve knowledge atoms for a specific locale".to_string(),
2571 input_schema: serde_json::json!({
2572 "type": "object",
2573 "properties": {
2574 "locale": { "type": "string" },
2575 "atom_type": {
2576 "type": "string",
2577 "enum": ["term", "expression", "pattern", "cultureref", "taboo", "audiencetrait", "all"]
2578 },
2579 "domain": { "type": "string" }
2580 },
2581 "required": ["locale"]
2582 }),
2583 };
2584 let tool = NikaMcpTool::with_client(tool_def, client);
2585
2586 let args = serde_json::json!({
2588 "locale": "fr-FR",
2589 "atom_type": "term",
2590 "domain": "qr-code"
2591 })
2592 .to_string();
2593
2594 let result = tool.call(args).await;
2595 assert!(result.is_ok(), "novanet_audit should succeed");
2596 }
2597
2598 #[tokio::test]
2600 async fn test_usecase_novanet_batch_context() {
2601 use crate::mcp::McpClient;
2602 use rig::tool::ToolDyn;
2603 use std::sync::Arc;
2604
2605 let client = Arc::new(McpClient::mock("novanet"));
2606
2607 let tool_def = NikaMcpToolDef {
2608 name: "novanet_batch".to_string(),
2609 description: "Assemble context for LLM generation (token-aware)".to_string(),
2610 input_schema: serde_json::json!({
2611 "type": "object",
2612 "properties": {
2613 "focus_key": { "type": "string" },
2614 "locale": { "type": "string" },
2615 "token_budget": { "type": "integer", "default": 4000 },
2616 "strategy": {
2617 "type": "string",
2618 "enum": ["breadth", "depth", "relevance", "custom"]
2619 }
2620 },
2621 "required": ["focus_key", "locale"]
2622 }),
2623 };
2624 let tool = NikaMcpTool::with_client(tool_def, client);
2625
2626 let args = serde_json::json!({
2628 "focus_key": "qr-code",
2629 "locale": "es-MX",
2630 "token_budget": 3000,
2631 "strategy": "relevance"
2632 })
2633 .to_string();
2634
2635 let result = tool.call(args).await;
2636 assert!(result.is_ok(), "novanet_batch should succeed");
2637 }
2638
2639 #[tokio::test]
2645 async fn test_error_no_client_configured() {
2646 use rig::tool::ToolDyn;
2647
2648 let tool_def = NikaMcpToolDef {
2650 name: "novanet_describe".to_string(),
2651 description: "Test tool".to_string(),
2652 input_schema: serde_json::json!({"type": "object"}),
2653 };
2654 let tool = NikaMcpTool::new(tool_def); let args = r#"{"entity_key": "test"}"#.to_string();
2658 let result = tool.call(args).await;
2659
2660 assert!(result.is_err(), "Should fail without client");
2662 let err = result.unwrap_err();
2663 let err_str = err.to_string();
2664 assert!(
2665 err_str.contains("No MCP client") || err_str.contains("NotConnected"),
2666 "Error should mention missing client: {}",
2667 err_str
2668 );
2669 }
2670
2671 #[tokio::test]
2673 async fn test_error_invalid_json_arguments() {
2674 use crate::mcp::McpClient;
2675 use rig::tool::ToolDyn;
2676 use std::sync::Arc;
2677
2678 let client = Arc::new(McpClient::mock("novanet"));
2679 let tool_def = NikaMcpToolDef {
2680 name: "novanet_describe".to_string(),
2681 description: "Test tool".to_string(),
2682 input_schema: serde_json::json!({"type": "object"}),
2683 };
2684 let tool = NikaMcpTool::with_client(tool_def, client);
2685
2686 let args = "not valid json {{{".to_string();
2688 let result = tool.call(args).await;
2689
2690 assert!(result.is_err(), "Should fail with invalid JSON");
2692 let err = result.unwrap_err();
2693 let err_str = err.to_string();
2694 assert!(
2695 err_str.contains("Invalid JSON") || err_str.contains("JSON"),
2696 "Error should mention JSON parsing: {}",
2697 err_str
2698 );
2699 }
2700
2701 #[tokio::test]
2703 async fn test_empty_json_object_is_valid() {
2704 use crate::mcp::McpClient;
2705 use rig::tool::ToolDyn;
2706 use std::sync::Arc;
2707
2708 let client = Arc::new(McpClient::mock("novanet"));
2709 let tool_def = NikaMcpToolDef {
2710 name: "novanet_describe".to_string(),
2711 description: "Test tool".to_string(),
2712 input_schema: serde_json::json!({"type": "object"}),
2713 };
2714 let tool = NikaMcpTool::with_client(tool_def, client);
2715
2716 let args = "{}".to_string();
2718 let result = tool.call(args).await;
2719
2720 assert!(result.is_ok(), "Empty JSON object should be valid");
2722 }
2723
2724 #[tokio::test]
2730 async fn test_tool_definition_async() {
2731 use rig::tool::ToolDyn;
2732
2733 let input_schema = serde_json::json!({
2734 "type": "object",
2735 "properties": {
2736 "entity_key": { "type": "string" },
2737 "locale": { "type": "string" }
2738 },
2739 "required": ["entity_key"]
2740 });
2741
2742 let tool_def = NikaMcpToolDef {
2743 name: "test_tool".to_string(),
2744 description: "A test tool for verification".to_string(),
2745 input_schema: input_schema.clone(),
2746 };
2747 let tool = NikaMcpTool::new(tool_def);
2748
2749 let definition = tool.definition("some prompt".to_string()).await;
2751
2752 assert_eq!(definition.name, "test_tool");
2754 assert_eq!(definition.description, "A test tool for verification");
2755 assert_eq!(definition.parameters, input_schema);
2756 }
2757
2758 #[test]
2760 fn test_multiple_tools_independent() {
2761 let tool1 = NikaMcpTool::new(NikaMcpToolDef {
2763 name: "novanet_context".to_string(),
2764 description: "Generate content".to_string(),
2765 input_schema: serde_json::json!({"type": "object"}),
2766 });
2767
2768 let tool2 = NikaMcpTool::new(NikaMcpToolDef {
2769 name: "novanet_describe".to_string(),
2770 description: "Describe entity".to_string(),
2771 input_schema: serde_json::json!({"type": "object"}),
2772 });
2773
2774 let tool3 = NikaMcpTool::new(NikaMcpToolDef {
2775 name: "novanet_search".to_string(),
2776 description: "Traverse graph".to_string(),
2777 input_schema: serde_json::json!({"type": "object"}),
2778 });
2779
2780 assert_eq!(tool1.tool_name(), "novanet_context");
2782 assert_eq!(tool2.tool_name(), "novanet_describe");
2783 assert_eq!(tool3.tool_name(), "novanet_search");
2784 }
2785
2786 #[tokio::test]
2788 async fn test_tool_clone_works() {
2789 use crate::mcp::McpClient;
2790 use rig::tool::ToolDyn;
2791 use std::sync::Arc;
2792
2793 let client = Arc::new(McpClient::mock("novanet"));
2794 let tool_def = NikaMcpToolDef {
2795 name: "novanet_describe".to_string(),
2796 description: "Test tool".to_string(),
2797 input_schema: serde_json::json!({"type": "object"}),
2798 };
2799 let tool = NikaMcpTool::with_client(tool_def, client);
2800
2801 let cloned_tool = tool.clone();
2803
2804 let args = r#"{"entity_key": "test"}"#.to_string();
2806 let result1 = tool.call(args.clone()).await;
2807 let result2 = cloned_tool.call(args).await;
2808
2809 assert!(result1.is_ok(), "Original tool should work");
2810 assert!(result2.is_ok(), "Cloned tool should work");
2811 }
2812
2813 #[tokio::test]
2819 async fn test_multi_locale_generation_workflow() {
2820 use crate::mcp::McpClient;
2821 use rig::tool::ToolDyn;
2822 use std::sync::Arc;
2823
2824 let client = Arc::new(McpClient::mock("novanet"));
2825 let tool_def = NikaMcpToolDef {
2826 name: "novanet_context".to_string(),
2827 description: "Generate native content".to_string(),
2828 input_schema: serde_json::json!({
2829 "type": "object",
2830 "properties": {
2831 "focus_key": { "type": "string" },
2832 "locale": { "type": "string" },
2833 "forms": { "type": "array", "items": { "type": "string" } }
2834 },
2835 "required": ["focus_key", "locale"]
2836 }),
2837 };
2838 let tool = NikaMcpTool::with_client(tool_def, client);
2839
2840 let locales = ["fr-FR", "es-MX", "de-DE", "ja-JP", "zh-CN"];
2842 let mut results = Vec::new();
2843
2844 for locale in locales {
2845 let args = serde_json::json!({
2846 "focus_key": "qr-code",
2847 "locale": locale,
2848 "forms": ["text", "title"]
2849 })
2850 .to_string();
2851
2852 let result = tool.call(args).await;
2853 results.push((locale, result.is_ok()));
2854 }
2855
2856 for (locale, success) in &results {
2858 assert!(success, "Generation for {} should succeed", locale);
2859 }
2860 assert_eq!(results.len(), 5, "Should process all 5 locales");
2861 }
2862
2863 #[test]
2868 fn test_provider_verify_error_types() {
2869 let invalid_key = ProviderVerifyError::InvalidApiKey {
2871 provider: "claude".to_string(),
2872 };
2873 assert!(invalid_key.to_string().contains("Invalid API key"));
2874 assert!(invalid_key.suggestion().contains("API key"));
2875
2876 let rate_limited = ProviderVerifyError::RateLimited {
2877 provider: "openai".to_string(),
2878 };
2879 assert!(rate_limited.to_string().contains("Rate limited"));
2880
2881 let timeout = ProviderVerifyError::Timeout {
2882 provider: "mistral".to_string(),
2883 };
2884 assert!(timeout.to_string().contains("timeout"));
2885
2886 let network = ProviderVerifyError::NetworkError {
2887 provider: "groq".to_string(),
2888 details: "connection refused".to_string(),
2889 };
2890 assert!(network.to_string().contains("Network error"));
2891
2892 let provider_err = ProviderVerifyError::ProviderError {
2893 provider: "deepseek".to_string(),
2894 details: "server down".to_string(),
2895 };
2896 assert!(provider_err.to_string().contains("server down"));
2897 }
2898
2899 #[test]
2900 fn test_provider_verify_result_fields() {
2901 let result = ProviderVerifyResult {
2902 provider: "claude".to_string(),
2903 latency: std::time::Duration::from_millis(150),
2904 model: "claude-sonnet-4-6".to_string(),
2905 };
2906
2907 assert_eq!(result.provider, "claude");
2908 assert_eq!(result.latency.as_millis(), 150);
2909 assert_eq!(result.model, "claude-sonnet-4-6");
2910 }
2911
2912 #[test]
2913 #[serial]
2914 fn test_is_configured_with_api_key() {
2915 std::env::set_var("ANTHROPIC_API_KEY", "test-key");
2916 let provider = RigProvider::claude();
2917 assert!(provider.is_configured());
2918 }
2919
2920 #[test]
2921 #[serial]
2922 fn test_is_configured_returns_true_for_all_providers_with_keys() {
2923 std::env::set_var("ANTHROPIC_API_KEY", "test");
2925 std::env::set_var("OPENAI_API_KEY", "test");
2926 std::env::set_var("MISTRAL_API_KEY", "test");
2927 std::env::set_var("GROQ_API_KEY", "test");
2928 std::env::set_var("DEEPSEEK_API_KEY", "test");
2929
2930 assert!(RigProvider::claude().is_configured());
2931 assert!(RigProvider::openai().is_configured());
2932 assert!(RigProvider::mistral().is_configured());
2933 assert!(RigProvider::groq().is_configured());
2934 assert!(RigProvider::deepseek().is_configured());
2935 }
2936
2937 #[test]
2942 fn test_infer_options_default() {
2943 let opts = InferOptions::default();
2944 assert!(opts.model.is_none());
2945 assert!(opts.temperature.is_none());
2946 assert!(opts.max_tokens.is_none());
2947 assert!(opts.system.is_none());
2948 }
2949
2950 #[test]
2951 fn test_infer_options_with_all_fields() {
2952 let opts = InferOptions {
2953 model: Some("gpt-4o".to_string()),
2954 temperature: Some(0.7),
2955 max_tokens: Some(2000),
2956 system: Some("You are a helpful assistant.".to_string()),
2957 };
2958 assert_eq!(opts.model.as_deref(), Some("gpt-4o"));
2959 assert_eq!(opts.temperature, Some(0.7));
2960 assert_eq!(opts.max_tokens, Some(2000));
2961 assert_eq!(opts.system.as_deref(), Some("You are a helpful assistant."));
2962 }
2963
2964 #[test]
2965 fn test_infer_options_partial_fields() {
2966 let opts = InferOptions {
2967 temperature: Some(0.5),
2968 ..Default::default()
2969 };
2970 assert!(opts.model.is_none());
2971 assert_eq!(opts.temperature, Some(0.5));
2972 assert!(opts.max_tokens.is_none());
2973 assert!(opts.system.is_none());
2974 }
2975
2976 #[test]
2977 fn test_infer_options_temperature_zero() {
2978 let opts = InferOptions {
2979 temperature: Some(0.0),
2980 ..Default::default()
2981 };
2982 assert_eq!(opts.temperature, Some(0.0));
2983 }
2984
2985 #[test]
2986 fn test_infer_options_max_tokens_small() {
2987 let opts = InferOptions {
2988 max_tokens: Some(1),
2989 ..Default::default()
2990 };
2991 assert_eq!(opts.max_tokens, Some(1));
2992 }
2993
2994 #[test]
2995 fn test_infer_options_system_empty_string() {
2996 let opts = InferOptions {
2997 system: Some(String::new()),
2998 ..Default::default()
2999 };
3000 assert_eq!(opts.system.as_deref(), Some(""));
3001 }
3002
3003 #[test]
3004 fn test_infer_options_clone() {
3005 let opts = InferOptions {
3006 model: Some("test-model".to_string()),
3007 temperature: Some(0.8),
3008 max_tokens: Some(1000),
3009 system: Some("Test system".to_string()),
3010 };
3011 let cloned = opts.clone();
3012 assert_eq!(opts.model, cloned.model);
3013 assert_eq!(opts.temperature, cloned.temperature);
3014 assert_eq!(opts.max_tokens, cloned.max_tokens);
3015 assert_eq!(opts.system, cloned.system);
3016 }
3017
3018 #[test]
3023 fn vision_not_supported_error_display() {
3024 let err = RigInferError::VisionNotSupported("DeepSeek no vision".to_string());
3025 assert!(err.to_string().contains("Vision not supported"));
3026 assert!(err.to_string().contains("DeepSeek no vision"));
3027 }
3028
3029 #[tokio::test]
3031 async fn infer_vision_deepseek_returns_error() {
3032 if std::env::var("DEEPSEEK_API_KEY").is_err() {
3033 let err = RigInferError::VisionNotSupported("DeepSeek".to_string());
3035 assert!(err.to_string().contains("Vision not supported"));
3036 return;
3037 }
3038 let provider = RigProvider::deepseek();
3039 let content = vec![rig::completion::message::UserContent::text("hello")];
3040 let result = provider.infer_vision(content, None, None, None).await;
3041 assert!(result.is_err());
3042 assert!(matches!(
3043 result.unwrap_err(),
3044 RigInferError::VisionNotSupported(_)
3045 ));
3046 }
3047
3048 #[test]
3049 fn infer_vision_empty_content_builds_error() {
3050 use rig::OneOrMany;
3052 let content: Vec<rig::completion::message::UserContent> = vec![];
3053 let result = OneOrMany::many(content);
3054 assert!(result.is_err(), "empty content should fail");
3055 }
3056
3057 #[test]
3058 fn build_vision_user_content_text_only() {
3059 let content = [rig::completion::message::UserContent::text("Describe this")];
3060 assert_eq!(content.len(), 1);
3061 }
3062
3063 #[test]
3064 fn build_vision_user_content_with_image() {
3065 use rig::completion::message::{ImageMediaType, UserContent};
3066 let content = [
3067 UserContent::text("What is in this image?"),
3068 UserContent::image_base64(
3069 "iVBORw0KGgo=", Some(ImageMediaType::PNG),
3071 None,
3072 ),
3073 ];
3074 assert_eq!(content.len(), 2);
3075 }
3076
3077 #[test]
3078 fn build_vision_message_from_content() {
3079 use rig::completion::message::{ImageMediaType, Message, UserContent};
3080 use rig::OneOrMany;
3081
3082 let parts = vec![
3083 UserContent::text("Describe this image"),
3084 UserContent::image_base64("iVBORw0KGgo=", Some(ImageMediaType::PNG), None),
3085 ];
3086 let msg = Message::User {
3087 content: OneOrMany::many(parts).unwrap(),
3088 };
3089 assert!(matches!(msg, Message::User { .. }));
3090 }
3091
3092 #[test]
3097 fn reasoning_model_o_series() {
3098 assert!(is_reasoning_model("o1"));
3099 assert!(is_reasoning_model("o1-mini"));
3100 assert!(is_reasoning_model("o1-pro"));
3101 assert!(is_reasoning_model("o3"));
3102 assert!(is_reasoning_model("o3-mini"));
3103 assert!(is_reasoning_model("o3-pro"));
3104 assert!(is_reasoning_model("o4"));
3105 assert!(is_reasoning_model("o4-mini"));
3106 assert!(is_reasoning_model("o1-2024-12-17"));
3107 }
3108
3109 #[test]
3110 fn reasoning_model_gpt5() {
3111 assert!(is_reasoning_model("gpt-5"));
3112 assert!(is_reasoning_model("gpt-5-turbo"));
3113 }
3114
3115 #[test]
3116 fn reasoning_model_deepseek() {
3117 assert!(is_reasoning_model("deepseek-reasoner"));
3118 }
3119
3120 #[test]
3121 fn reasoning_model_case_insensitive() {
3122 assert!(is_reasoning_model("O1"));
3123 assert!(is_reasoning_model("GPT-5"));
3124 }
3125
3126 #[test]
3127 fn non_reasoning_models() {
3128 assert!(!is_reasoning_model("gpt-4o"));
3129 assert!(!is_reasoning_model("gpt-4o-mini"));
3130 assert!(!is_reasoning_model("claude-sonnet-4"));
3131 assert!(!is_reasoning_model("deepseek-chat"));
3132 assert!(!is_reasoning_model("gemini-2.0-flash"));
3133 assert!(!is_reasoning_model("grok-3"));
3134 }
3135}