rig/providers/together/
client.rs

1use super::{M2_BERT_80M_8K_RETRIEVAL, completion::CompletionModel, embedding::EmbeddingModel};
2use crate::{
3    client::{EmbeddingsClient, ProviderClient, VerifyClient, VerifyError, impl_conversion_traits},
4    http_client::{self, HttpClientExt},
5};
6use bytes::Bytes;
7use rig::client::CompletionClient;
8
// ----------------------------------------------------------------
// Together AI client
// ----------------------------------------------------------------

/// Default root URL of the hosted Together AI REST API.
///
/// Endpoint paths (e.g. `/v1/...`) are appended to this when requests are built.
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
13
14pub struct ClientBuilder<'a, T = reqwest::Client> {
15    api_key: &'a str,
16    base_url: &'a str,
17    http_client: T,
18}
19
20impl<'a, T> ClientBuilder<'a, T>
21where
22    T: Default,
23{
24    pub fn new(api_key: &'a str) -> Self {
25        Self {
26            api_key,
27            base_url: TOGETHER_AI_BASE_URL,
28            http_client: Default::default(),
29        }
30    }
31}
32
33impl<'a, T> ClientBuilder<'a, T> {
34    pub fn base_url(mut self, base_url: &'a str) -> Self {
35        self.base_url = base_url;
36        self
37    }
38
39    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
40        ClientBuilder {
41            api_key: self.api_key,
42            base_url: self.base_url,
43            http_client,
44        }
45    }
46
47    pub fn build(self) -> Client<T> {
48        let mut default_headers = reqwest::header::HeaderMap::new();
49        default_headers.insert(
50            reqwest::header::CONTENT_TYPE,
51            "application/json".parse().unwrap(),
52        );
53
54        Client {
55            base_url: self.base_url.to_string(),
56            api_key: self.api_key.to_string(),
57            default_headers,
58            http_client: self.http_client,
59        }
60    }
61}
62#[derive(Clone)]
63pub struct Client<T = reqwest::Client> {
64    base_url: String,
65    default_headers: reqwest::header::HeaderMap,
66    api_key: String,
67    http_client: T,
68}
69
70impl<T> std::fmt::Debug for Client<T>
71where
72    T: std::fmt::Debug,
73{
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        f.debug_struct("Client")
76            .field("base_url", &self.base_url)
77            .field("http_client", &self.http_client)
78            .field("default_headers", &self.default_headers)
79            .field("api_key", &"<REDACTED>")
80            .finish()
81    }
82}
83
84impl<T> Client<T>
85where
86    T: Default,
87{
88    /// Create a new Together AI client builder.
89    ///
90    /// # Example
91    /// ```
92    /// use rig::providers::together_ai::{ClientBuilder, self};
93    ///
94    /// // Initialize the Together AI client
95    /// let together_ai = Client::builder("your-together-ai-api-key")
96    ///    .build()
97    /// ```
98    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
99        ClientBuilder::new(api_key)
100    }
101
102    /// Create a new Together AI client. For more control, use the `builder` method.
103    ///
104    /// # Panics
105    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
106    pub fn new(api_key: &str) -> Self {
107        Self::builder(api_key).build()
108    }
109}
110
111impl<T> Client<T>
112where
113    T: HttpClientExt,
114{
115    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
116        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
117
118        tracing::debug!("POST {}", url);
119
120        let mut req = http_client::Request::post(url);
121
122        if let Some(hs) = req.headers_mut() {
123            *hs = self.default_headers.clone();
124        }
125
126        http_client::with_bearer_auth(req, &self.api_key)
127    }
128
129    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
130        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
131
132        tracing::debug!("GET {}", url);
133
134        let mut req = http_client::Request::get(url);
135
136        if let Some(hs) = req.headers_mut() {
137            *hs = self.default_headers.clone();
138        }
139
140        http_client::with_bearer_auth(req, &self.api_key)
141    }
142
143    pub(crate) async fn send<U, R>(
144        &self,
145        req: http_client::Request<U>,
146    ) -> http_client::Result<http::Response<http_client::LazyBody<R>>>
147    where
148        U: Into<Bytes> + Send,
149        R: From<Bytes> + Send + 'static,
150    {
151        self.http_client.send(req).await
152    }
153}
154
155impl Client<reqwest::Client> {
156    pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
157        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
158
159        tracing::debug!("POST {}", url);
160
161        self.http_client
162            .post(url)
163            .bearer_auth(&self.api_key)
164            .headers(self.default_headers.clone())
165    }
166}
167
168impl ProviderClient for Client<reqwest::Client> {
169    /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable.
170    /// Panics if the environment variable is not set.
171    fn from_env() -> Self {
172        let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
173        Self::new(&api_key)
174    }
175
176    fn from_val(input: crate::client::ProviderValue) -> Self {
177        let crate::client::ProviderValue::Simple(api_key) = input else {
178            panic!("Incorrect provider value type")
179        };
180        Self::new(&api_key)
181    }
182}
183
184impl CompletionClient for Client<reqwest::Client> {
185    type CompletionModel = CompletionModel<reqwest::Client>;
186
187    /// Create a completion model with the given name.
188    fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
189        CompletionModel::new(self.clone(), model)
190    }
191}
192
193impl EmbeddingsClient for Client<reqwest::Client> {
194    type EmbeddingModel = EmbeddingModel<reqwest::Client>;
195
196    /// Create an embedding model with the given name.
197    /// Note: default embedding dimension of 0 will be used if model is not known.
198    /// If this is the case, it's better to use function `embedding_model_with_ndims`
199    ///
200    /// # Example
201    /// ```
202    /// use rig::providers::together_ai::{Client, self};
203    ///
204    /// // Initialize the Together AI client
205    /// let together_ai = Client::new("your-together-ai-api-key");
206    ///
207    /// let embedding_model = together_ai.embedding_model(together_ai::embedding::EMBEDDING_V1);
208    /// ```
209    fn embedding_model(&self, model: &str) -> EmbeddingModel<reqwest::Client> {
210        let ndims = match model {
211            M2_BERT_80M_8K_RETRIEVAL => 8192,
212            _ => 0,
213        };
214        EmbeddingModel::new(self.clone(), model, ndims)
215    }
216
217    /// Create an embedding model with the given name and the number of dimensions in the embedding
218    /// generated by the model.
219    ///
220    /// # Example
221    /// ```
222    /// use rig::providers::together_ai::{Client, self};
223    ///
224    /// // Initialize the Together AI client
225    /// let together_ai = Client::new("your-together-ai-api-key");
226    ///
227    /// let embedding_model = together_ai.embedding_model_with_ndims("model-unknown-to-rig", 1024);
228    /// ```
229    fn embedding_model_with_ndims(
230        &self,
231        model: &str,
232        ndims: usize,
233    ) -> EmbeddingModel<reqwest::Client> {
234        EmbeddingModel::new(self.clone(), model, ndims)
235    }
236}
237
238impl VerifyClient for Client<reqwest::Client> {
239    #[cfg_attr(feature = "worker", worker::send)]
240    async fn verify(&self) -> Result<(), VerifyError> {
241        let req = self
242            .get("/models")?
243            .body(http_client::NoBody)
244            .map_err(|e| VerifyError::HttpError(e.into()))?;
245
246        let response = HttpClientExt::send(&self.http_client, req).await?;
247
248        match response.status() {
249            reqwest::StatusCode::OK => Ok(()),
250            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
251            reqwest::StatusCode::INTERNAL_SERVER_ERROR | reqwest::StatusCode::GATEWAY_TIMEOUT => {
252                let text = http_client::text(response).await?;
253                Err(VerifyError::ProviderError(text))
254            }
255            _ => {
256                //response.error_for_status()?;
257                Ok(())
258            }
259        }
260    }
261}
262
// Invoke the crate's `impl_conversion_traits!` macro for the listed capability traits
// (`AsTranscription`, `AsImageGeneration`, `AsAudioGeneration`) on `Client<T>`.
// NOTE(review): presumably this marks transcription, image generation and audio generation
// as capabilities this provider does not offer — confirm against the macro definition.
impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client<T>);
264
265pub mod together_ai_api_types {
266    use serde::Deserialize;
267
268    impl ApiErrorResponse {
269        pub fn message(&self) -> String {
270            format!("Code `{}`: {}", self.code, self.error)
271        }
272    }
273
274    #[derive(Debug, Deserialize)]
275    pub struct ApiErrorResponse {
276        pub error: String,
277        pub code: String,
278    }
279
280    #[derive(Debug, Deserialize)]
281    #[serde(untagged)]
282    pub enum ApiResponse<T> {
283        Ok(T),
284        Error(ApiErrorResponse),
285    }
286}