rig_dyn/
client.rs

1use crate::traits::{CompletionModel, EmbeddingModel};
2use rig::providers;
3
4#[derive(Clone)]
5pub enum Client {
6    Anthropic(providers::anthropic::Client),
7    Azure(providers::azure::Client),
8    Cohere(providers::cohere::Client),
9    DeepSeek(providers::deepseek::Client),
10    Galadriel(providers::galadriel::Client),
11    Gemini(providers::gemini::Client),
12    Groq(providers::groq::Client),
13    HuggingFace(providers::huggingface::Client),
14    Hyperbolic(providers::hyperbolic::Client),
15    Mira(providers::mira::Client),
16    Moonshot(providers::moonshot::Client),
17    OpenAI(providers::openai::Client),
18    OpenRouter(providers::openrouter::Client),
19    Ollama(providers::ollama::Client),
20    Perplexity(providers::perplexity::Client),
21    Xai(providers::xai::Client),
22}
23
/// Expands to a `match` over `$self` (a `Client` or `&Client`) that calls
/// `completion_model($model)` on the wrapped provider client and boxes it.
///
/// One arm is generated per identifier in the brace list, so the list must
/// name every `Client` variant or the generated `match` is non-exhaustive.
/// The arms produce `Box<ConcreteModel>`. The caller's expected type (e.g. a
/// `Box<dyn CompletionModel>` return type) performs the unsizing coercion.
macro_rules! completion_model {
    ($self:expr, $model:expr, {$($variant:ident),*}) => {
        match $self {
            $(
                Client::$variant(inner) => {
                    let built = inner.completion_model($model);
                    Box::new(built)
                }
            )*
        }
    };
}
33
/// Expands to a `match` over `$self` that yields an optional boxed embedding
/// model:
///
/// * `Client::Cohere` is handed to `$cohere_expr`, a callable invoked with the
///   Cohere client reference; its return value is used as the arm's value.
/// * Variants in the second brace list yield `None`.
/// * Variants in the first brace list call `embedding_model($model)` and yield
///   `Some` of the boxed result.
///
/// `$input_type` is accepted for call-site symmetry but is not referenced by
/// the expansion. Any input-type handling lives inside `$cohere_expr`.
macro_rules! embedding_model {
    ($self:expr, $model:expr, $input_type:expr,
     {$($some_variant:ident),*},
     {$($none_variant:ident),*},
     $cohere_expr:expr) => {
        match $self {
            Client::Cohere(cohere) => ($cohere_expr)(cohere),
            $(
                Client::$none_variant(_) => None,
            )*
            $(
                Client::$some_variant(inner) => {
                    let built = inner.embedding_model($model);
                    Some(Box::new(built))
                }
            )*
        }
    };
}
50
/// Expands to a `match` over `$self` that yields an optional boxed embedding
/// model built with an explicit dimension count:
///
/// * `Client::Cohere` is handed to `$cohere_expr`, a callable invoked with the
///   Cohere client reference; its return value is used as the arm's value.
/// * Variants in the second brace list yield `None`.
/// * Variants in the first brace list call
///   `embedding_model_with_ndims($model, $ndims)` and yield `Some` of the
///   boxed result.
///
/// `$input_type` is accepted for call-site symmetry but is not referenced by
/// the expansion. Any input-type handling lives inside `$cohere_expr`.
macro_rules! embedding_model_with_ndims {
    ($self:expr, $model:expr, $ndims:expr, $input_type:expr,
     {$($some_variant:ident),*},
     {$($none_variant:ident),*},
     $cohere_expr:expr) => {
        match $self {
            Client::Cohere(cohere) => ($cohere_expr)(cohere),
            $(
                Client::$none_variant(_) => None,
            )*
            $(
                Client::$some_variant(inner) => {
                    let built = inner.embedding_model_with_ndims($model, $ndims);
                    Some(Box::new(built))
                }
            )*
        }
    };
}
69
70impl Client {
71    /// Returns a completion model wrapper for the given provider and model name.
72    pub async fn completion_model(&self, model: &str) -> Box<dyn CompletionModel> {
73        completion_model!(
74            self, model,
75            {
76                Anthropic, Azure, Cohere, DeepSeek,
77                Galadriel, Gemini, Groq, Hyperbolic,
78                Moonshot, OpenAI, Ollama, Perplexity, Xai,
79                HuggingFace, OpenRouter, Mira
80            }
81        )
82    }
83
84    /// Returns an embedding model wrapper for the given provider and model name.
85    /// Returns `None` if the provider does not support embeddings or
86    /// if improper input type is provided (cohere requires a input type).
87    pub async fn embedding_model(
88        &self,
89        model: &str,
90        input_type: Option<&str>,
91    ) -> Option<Box<dyn EmbeddingModel>> {
92        embedding_model!(
93            self, model, input_type,
94            {
95                Azure, Gemini, OpenAI, Xai, Ollama
96            },
97            {
98                Anthropic, DeepSeek, Galadriel,
99                Groq, Hyperbolic, Moonshot, Perplexity,
100                Mira, HuggingFace, OpenRouter
101            },
102            |client: &providers::cohere::Client| input_type.map(|input_type| {
103                Box::new(
104                    client.embedding_model(model, input_type)
105                ) as Box<dyn EmbeddingModel>
106            })
107        )
108    }
109
110    pub async fn embedding_model_with_ndims(
111        &self,
112        model: &str,
113        ndims: usize,
114        input_type: Option<&str>,
115    ) -> Option<Box<dyn EmbeddingModel>> {
116        embedding_model_with_ndims!(
117            self, model, ndims, input_type,
118            {
119                Azure, Gemini, OpenAI, Xai, Ollama
120            },
121            {
122                Anthropic, DeepSeek, Galadriel,
123                Groq, Hyperbolic, Moonshot, Perplexity,
124                Mira, HuggingFace, OpenRouter
125            },
126            |client: &providers::cohere::Client| input_type.map(|input_type| {
127                Box::new(
128                    client.embedding_model_with_ndims(model, input_type, ndims)
129                ) as Box<dyn EmbeddingModel>
130            })
131        )
132    }
133}