rig_dyn/
client.rs

1use crate::traits::{CompletionModel, EmbeddingModel};
2use rig::providers;
3
4#[derive(Clone)]
5pub enum Client {
6    Anthropic(providers::anthropic::Client),
7    Azure(providers::azure::Client),
8    Cohere(providers::cohere::Client),
9    DeepSeek(providers::deepseek::Client),
10    Galadriel(providers::galadriel::Client),
11    Gemini(providers::gemini::Client),
12    Groq(providers::groq::Client),
13    Hyperbolic(providers::hyperbolic::Client),
14    Moonshot(providers::moonshot::Client),
15    OpenAI(providers::openai::Client),
16    Ollama(providers::ollama::Client),
17    Perplexity(providers::perplexity::Client),
18    Xai(providers::xai::Client),
19}
20
/// Expands to an exhaustive `match` over `Client` that boxes the completion
/// model produced by the wrapped provider client.
///
/// For every variant named in the braces, the generated arm binds the inner
/// client, calls `.completion_model($model_name)` on it and wraps the result
/// in `Box::new`. No `as` cast is emitted: the unsizing to a trait object
/// comes from the expected type at the expansion site (e.g. a function
/// returning `Box<dyn CompletionModel>`).
macro_rules! completion_model {
    ($target:expr, $model_name:expr, { $($provider_variant:ident),* }) => {
        match $target {
            $(
                Client::$provider_variant(provider) =>
                    Box::new(provider.completion_model($model_name)),
            )*
        }
    };
}
30
/// Expands to an exhaustive `match` over `Client` yielding an optional boxed
/// embedding model.
///
/// * Variants in the first brace list call `.embedding_model($model_name)` on
///   the inner client and produce `Some(Box::new(..))`.
/// * Variants in the second brace list produce `None`.
/// * `Client::Cohere` always gets its own arm, which calls `$cohere_handler`
///   with the bound inner Cohere client, so Cohere should not appear in
///   either list.
///
/// `$ignored_input_type` is part of the invocation syntax but is never
/// expanded; any use of the input type has to live inside `$cohere_handler`.
macro_rules! embedding_model {
    (
        $target:expr,
        $model_name:expr,
        $ignored_input_type:expr,
        { $($with_embeddings:ident),* },
        { $($without_embeddings:ident),* },
        $cohere_handler:expr
    ) => {
        match $target {
            $(
                Client::$with_embeddings(provider) =>
                    Some(Box::new(provider.embedding_model($model_name))),
            )*
            $(
                Client::$without_embeddings(_) => None,
            )*
            Client::Cohere(provider) => $cohere_handler(provider),
        }
    };
}
47
/// Like `embedding_model!`, but forwards an explicit `$dimensions` value.
///
/// * Variants in the first brace list call
///   `.embedding_model_with_ndims($model_name, $dimensions)` on the inner
///   client and produce `Some(Box::new(..))`.
/// * Variants in the second brace list produce `None`.
/// * `Client::Cohere` always gets its own arm, which calls `$cohere_handler`
///   with the bound inner Cohere client, so Cohere should not appear in
///   either list.
///
/// `$ignored_input_type` is part of the invocation syntax but is never
/// expanded; any use of the input type has to live inside `$cohere_handler`.
macro_rules! embedding_model_with_ndims {
    (
        $target:expr,
        $model_name:expr,
        $dimensions:expr,
        $ignored_input_type:expr,
        { $($with_embeddings:ident),* },
        { $($without_embeddings:ident),* },
        $cohere_handler:expr
    ) => {
        match $target {
            $(
                Client::$with_embeddings(provider) => Some(Box::new(
                    provider.embedding_model_with_ndims($model_name, $dimensions),
                )),
            )*
            $(
                Client::$without_embeddings(_) => None,
            )*
            Client::Cohere(provider) => $cohere_handler(provider),
        }
    };
}
66
67impl Client {
68    /// Returns a completion model wrapper for the given provider and model name.
69    pub async fn completion_model(&self, model: &str) -> Box<dyn CompletionModel> {
70        completion_model!(
71            self, model,
72            {
73                Anthropic, Azure, Cohere, DeepSeek,
74                Galadriel, Gemini, Groq, Hyperbolic,
75                Moonshot, OpenAI, Ollama, Perplexity, Xai
76            }
77        )
78    }
79
80    /// Returns an embedding model wrapper for the given provider and model name.
81    /// Returns `None` if the provider does not support embeddings or
82    /// if improper input type is provided (cohere requires a input type).
83    pub async fn embedding_model(
84        &self,
85        model: &str,
86        input_type: Option<&str>,
87    ) -> Option<Box<dyn EmbeddingModel>> {
88        embedding_model!(
89            self, model, input_type,
90            {
91                Azure, Gemini, OpenAI, Xai, Ollama
92            },
93            {
94                Anthropic, DeepSeek, Galadriel,
95                Groq, Hyperbolic, Moonshot, Perplexity
96            },
97            |client: &providers::cohere::Client| input_type.map(|input_type| {
98                Box::new(
99                    client.embedding_model(model, input_type)
100                ) as Box<dyn EmbeddingModel>
101            })
102        )
103    }
104
105    pub async fn embedding_model_with_ndims(
106        &self,
107        model: &str,
108        ndims: usize,
109        input_type: Option<&str>,
110    ) -> Option<Box<dyn EmbeddingModel>> {
111        embedding_model_with_ndims!(
112            self, model, ndims, input_type,
113            {
114                Azure, Gemini, OpenAI, Xai, Ollama
115            },
116            {
117                Anthropic, DeepSeek, Galadriel,
118                Groq, Hyperbolic, Moonshot, Perplexity
119            },
120            |client: &providers::cohere::Client| input_type.map(|input_type| {
121                Box::new(
122                    client.embedding_model_with_ndims(model, input_type, ndims)
123                ) as Box<dyn EmbeddingModel>
124            })
125        )
126    }
127}