Skip to main content

rig_dyn/
client.rs

1use crate::traits::{CompletionModel, EmbeddingModel};
2use rig::providers;
3
4#[derive(Clone)]
5pub enum Client {
6    Anthropic(providers::anthropic::Client),
7    Azure(providers::azure::Client),
8    Cohere(providers::cohere::Client),
9    DeepSeek(providers::deepseek::Client),
10    Galadriel(providers::galadriel::Client),
11    Gemini(providers::gemini::Client),
12    Groq(providers::groq::Client),
13    HuggingFace(providers::huggingface::Client),
14    Hyperbolic(providers::hyperbolic::Client),
15    Mira(providers::mira::Client),
16    Moonshot(providers::moonshot::Client),
17    OpenAI(providers::openai::Client),
18    OpenRouter(providers::openrouter::Client),
19    Ollama(providers::ollama::Client),
20    Perplexity(providers::perplexity::Client),
21    Together(providers::together::Client),
22    Xai(providers::xai::Client),
23}
24
// Expands to a `match` on `$target` that calls `completion_model($model_name)`
// on whichever provider client the `Client` value wraps and boxes the result.
//
// Every arm has the same shape; the braced list must name every `Client`
// variant or the generated `match` is non-exhaustive. The arms produce
// differently-typed boxes, so the expansion must sit where a
// `Box<dyn CompletionModel>` is expected for them to coerce to one type.
macro_rules! completion_model {
    ($target:expr, $model_name:expr, { $($provider:ident),* }) => {
        match $target {
            $( Client::$provider(inner) => Box::new(inner.completion_model($model_name)), )*
        }
    };
}
34
// Expands to a `match` on `$target` that selects an embedding model per
// provider variant:
//   * variants in the first `{...}` list  -> `Some(Box::new(inner.embedding_model($model_name)))`
//   * variants in the second `{...}` list -> `None`
//   * `Client::Cohere`                     -> `$cohere_build(inner)`, returned as-is
//
// NOTE(review): `$input_type` is matched but never used in the expansion; the
// only consumer of an input type is the Cohere callable, which must capture
// it itself. The Cohere arm is always emitted, so `Cohere` should not appear
// in either list. The expansion must sit where an
// `Option<Box<dyn EmbeddingModel>>` is expected so the boxed arms coerce.
macro_rules! embedding_model {
    (
        $target:expr,
        $model_name:expr,
        $input_type:expr,
        { $($embeds:ident),* },
        { $($no_embeds:ident),* },
        $cohere_build:expr
    ) => {
        match $target {
            $(
                Client::$embeds(inner) => Some(Box::new(inner.embedding_model($model_name))),
            )*
            $(
                Client::$no_embeds(_) => None,
            )*
            Client::Cohere(inner) => $cohere_build(inner),
        }
    };
}
51
// Expands to a `match` on `$target` that selects a dimension-sized embedding
// model per provider variant:
//   * variants in the first `{...}` list  ->
//       `Some(Box::new(inner.embedding_model_with_ndims($model_name, $dims)))`
//   * variants in the second `{...}` list -> `None`
//   * `Client::Cohere`                     -> `$cohere_build(inner)`, returned as-is
//
// NOTE(review): `$input_type` is matched but never used in the expansion; the
// Cohere callable must capture any input type (and `$dims`) itself. The
// Cohere arm is always emitted, so `Cohere` should not appear in either list.
macro_rules! embedding_model_with_ndims {
    (
        $target:expr,
        $model_name:expr,
        $dims:expr,
        $input_type:expr,
        { $($embeds:ident),* },
        { $($no_embeds:ident),* },
        $cohere_build:expr
    ) => {
        match $target {
            $(
                Client::$embeds(inner) => {
                    Some(Box::new(inner.embedding_model_with_ndims($model_name, $dims)))
                }
            )*
            $(
                Client::$no_embeds(_) => None,
            )*
            Client::Cohere(inner) => $cohere_build(inner),
        }
    };
}
70
71impl Client {
72    /// Returns a completion model wrapper for the given provider and model name.
73    pub async fn completion_model(&self, model: &str) -> Box<dyn CompletionModel> {
74        completion_model!(
75            self, model,
76            {
77                Anthropic, Azure, Cohere, DeepSeek,
78                Galadriel, Gemini, Groq, Hyperbolic,
79                Moonshot, OpenAI, Ollama, Perplexity, Xai,
80                HuggingFace, OpenRouter, Mira, Together
81            }
82        )
83    }
84
85    /// Returns an embedding model wrapper for the given provider and model name.
86    /// Returns `None` if the provider does not support embeddings or
87    /// if improper input type is provided (cohere requires a input type).
88    pub async fn embedding_model(
89        &self,
90        model: &str,
91        input_type: Option<&str>,
92    ) -> Option<Box<dyn EmbeddingModel>> {
93        embedding_model!(
94            self, model, input_type,
95            {
96                Azure, Gemini, OpenAI, Xai, Ollama, Together
97            },
98            {
99                Anthropic, DeepSeek, Galadriel,
100                Groq, Hyperbolic, Moonshot, Perplexity,
101                Mira, HuggingFace, OpenRouter
102            },
103            |client: &providers::cohere::Client| input_type.map(|input_type| {
104                Box::new(
105                    client.embedding_model(model, input_type)
106                ) as Box<dyn EmbeddingModel>
107            })
108        )
109    }
110
111    pub async fn embedding_model_with_ndims(
112        &self,
113        model: &str,
114        ndims: usize,
115        input_type: Option<&str>,
116    ) -> Option<Box<dyn EmbeddingModel>> {
117        embedding_model_with_ndims!(
118            self, model, ndims, input_type,
119            {
120                Azure, Gemini, OpenAI, Xai, Ollama, Together
121            },
122            {
123                Anthropic, DeepSeek, Galadriel,
124                Groq, Hyperbolic, Moonshot, Perplexity,
125                Mira, HuggingFace, OpenRouter
126            },
127            |client: &providers::cohere::Client| input_type.map(|input_type| {
128                Box::new(
129                    client.embedding_model_with_ndims(model, input_type, ndims)
130                ) as Box<dyn EmbeddingModel>
131            })
132        )
133    }
134}