// rig_dyn/client.rs
1use crate::traits::{CompletionModel, DynEmbeddingModel, RigCompletionModelAdapter};
2use rig::client::FinalCompletionResponse;
3use rig::client::{CompletionClient, EmbeddingsClient};
4use rig::completion::{self, CompletionError, CompletionRequest, CompletionResponse};
5use rig::providers;
6use rig::streaming::StreamingCompletionResponse;
7
8#[derive(Clone)]
9pub enum Client {
10    Anthropic(providers::anthropic::Client),
11    Azure(providers::azure::Client),
12    Cohere(providers::cohere::Client),
13    DeepSeek(providers::deepseek::Client),
14    Galadriel(providers::galadriel::Client),
15    Gemini(providers::gemini::Client),
16    Groq(providers::groq::Client),
17    HuggingFace(providers::huggingface::Client),
18    Hyperbolic(providers::hyperbolic::Client),
19    Mira(providers::mira::Client),
20    Moonshot(providers::moonshot::Client),
21    OpenAI(providers::openai::Client),
22    OpenRouter(providers::openrouter::Client),
23    Ollama(providers::ollama::Client),
24    Perplexity(providers::perplexity::Client),
25    Together(providers::together::Client),
26    Xai(providers::xai::Client),
27}
28
29#[derive(Clone)]
30pub struct RigClientCompletionModelAdapter {
31    client: Client,
32    model: String,
33}
34
35impl completion::CompletionModel for RigClientCompletionModelAdapter {
36    type Response = ();
37    type StreamingResponse = FinalCompletionResponse;
38    type Client = Client;
39
40    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
41        Self {
42            client: client.clone(),
43            model: model.into(),
44        }
45    }
46
47    fn completion(
48        &self,
49        request: CompletionRequest,
50    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
51           + rig::wasm_compat::WasmCompatSend {
52        let client = self.client.clone();
53        let model = self.model.clone();
54
55        async move {
56            let completion_model = Client::completion_model(&client, &model).await;
57            completion_model.completion(request).await
58        }
59    }
60
61    fn stream(
62        &self,
63        _request: CompletionRequest,
64    ) -> impl std::future::Future<
65        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
66    > + rig::wasm_compat::WasmCompatSend {
67        async {
68            Err(CompletionError::ResponseError(
69                "Streaming is not supported by rig_dyn::Client adapter".to_string(),
70            ))
71        }
72    }
73}
74
75impl CompletionClient for Client {
76    type CompletionModel = RigClientCompletionModelAdapter;
77}
78
/// Expands to an exhaustive `match` on a `Client` (resolved at the call site).
///
/// Each listed variant calls `completion_model($model)` on the wrapped provider
/// client and boxes the result. The expansion uses a bare `Box::new(..)`, so the
/// call site must supply the target trait-object type, e.g. a function return
/// type such as `Box<dyn CompletionModel>`. Presumably the provider method comes
/// from rig's `CompletionClient` trait, which must then be in scope.
macro_rules! completion_model {
    ($client:expr, $model:expr, {$($provider:ident),*}) => {
        match $client {
            $( Client::$provider(inner) => Box::new(inner.completion_model($model)), )*
        }
    };
}

/// Expands to an exhaustive `match` on a `Client` (resolved at the call site)
/// that yields an optional boxed embedding model:
///
/// * `$supported` variants call `embedding_model($model)` on the wrapped client
///   and yield `Some(Box::new(..))`.
/// * `$unsupported` variants yield `None`.
/// * `Client::Cohere` is handed to `$cohere`, a callable invoked with the wrapped
///   Cohere client.
///
/// `$input_type` is accepted so call sites mirror the public method's arguments,
/// but the expansion never references it. At the visible call site, the Cohere
/// callable captures the input type itself.
macro_rules! embedding_model {
    ($client:expr, $model:expr, $input_type:expr,
     {$($supported:ident),*},
     {$($unsupported:ident),*},
     $cohere:expr) => {
        match $client {
            $( Client::$supported(inner) => Some(Box::new(inner.embedding_model($model))), )*
            $( Client::$unsupported(_) => None, )*
            Client::Cohere(inner) => $cohere(inner),
        }
    };
}

/// Like `embedding_model!`, but each `$supported` variant builds its model with
/// an explicit embedding dimension via
/// `embedding_model_with_ndims($model, $ndims)`.
///
/// * `$unsupported` variants yield `None`.
/// * `Client::Cohere` is handed to `$cohere`, a callable invoked with the wrapped
///   Cohere client.
///
/// `$input_type` is accepted for call-site symmetry but is never referenced by
/// the expansion. At the visible call site, the Cohere callable captures it.
macro_rules! embedding_model_with_ndims {
    ($client:expr, $model:expr, $ndims:expr, $input_type:expr,
     {$($supported:ident),*},
     {$($unsupported:ident),*},
     $cohere:expr) => {
        match $client {
            $(
                Client::$supported(inner) =>
                    Some(Box::new(inner.embedding_model_with_ndims($model, $ndims))),
            )*
            $(
                Client::$unsupported(_) => None,
            )*
            Client::Cohere(inner) => $cohere(inner),
        }
    };
}

125impl Client {
126    /// Returns a completion model wrapper for the given provider and model name.
127    pub async fn completion_model(&self, model: &str) -> Box<dyn CompletionModel> {
128        completion_model!(
129            self, model,
130            {
131                Anthropic, Azure, Cohere, DeepSeek,
132                Galadriel, Gemini, Groq, Hyperbolic,
133                Moonshot, OpenAI, Ollama, Perplexity, Xai,
134                HuggingFace, OpenRouter, Mira, Together
135            }
136        )
137    }
138
139    /// Returns a completion model compatible with `rig::completion::CompletionModel`.
140    pub async fn rig_completion_model(&self, model: &str) -> RigCompletionModelAdapter {
141        RigCompletionModelAdapter::from(self.completion_model(model).await)
142    }
143
144    /// Returns an embedding model wrapper for the given provider and model name.
145    /// Returns `None` if the provider does not support embeddings or
146    /// if improper input type is provided (cohere requires a input type).
147    pub async fn embedding_model(
148        &self,
149        model: &str,
150        input_type: Option<&str>,
151    ) -> Option<Box<dyn DynEmbeddingModel>> {
152        embedding_model!(
153            self, model, input_type,
154            {
155                Azure, Gemini, OpenAI, Ollama, Together, OpenRouter
156            },
157            {
158                Anthropic, DeepSeek, Galadriel,
159                Groq, Hyperbolic, Moonshot, Perplexity,
160                Mira, HuggingFace, Xai
161            },
162            |client: &providers::cohere::Client| input_type.map(|input_type| {
163                Box::new(
164                    client.embedding_model(model, input_type)
165                ) as Box<dyn DynEmbeddingModel>
166            })
167        )
168    }
169
170    pub async fn embedding_model_with_ndims(
171        &self,
172        model: &str,
173        ndims: usize,
174        input_type: Option<&str>,
175    ) -> Option<Box<dyn DynEmbeddingModel>> {
176        embedding_model_with_ndims!(
177            self, model, ndims, input_type,
178            {
179                Azure, Gemini, OpenAI, Ollama, Together
180            },
181            {
182                Anthropic, DeepSeek, Galadriel,
183                Groq, Hyperbolic, Moonshot, Perplexity,
184                Mira, HuggingFace, OpenRouter, Xai
185            },
186            |client: &providers::cohere::Client| input_type.map(|input_type| {
187                Box::new(
188                    client.embedding_model_with_ndims(model, input_type, ndims)
189                ) as Box<dyn DynEmbeddingModel>
190            })
191        )
192    }
193}