// rig/providers/together/client.rs

1use super::{M2_BERT_80M_8K_RETRIEVAL, completion::CompletionModel, embedding::EmbeddingModel};
2use crate::{
3    client::{EmbeddingsClient, ProviderClient, VerifyClient, VerifyError, impl_conversion_traits},
4    http_client::{self, HttpClientExt},
5};
6use bytes::Bytes;
7use rig::client::CompletionClient;
8
// ================================================================
// Together AI Client
// ================================================================

/// Default Together AI API endpoint; `ClientBuilder::base_url` can override it.
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
13
14pub struct ClientBuilder<'a, T = reqwest::Client> {
15    api_key: &'a str,
16    base_url: &'a str,
17    http_client: T,
18}
19
20impl<'a, T> ClientBuilder<'a, T>
21where
22    T: Default,
23{
24    pub fn new(api_key: &'a str) -> Self {
25        Self {
26            api_key,
27            base_url: TOGETHER_AI_BASE_URL,
28            http_client: Default::default(),
29        }
30    }
31}
32
33impl<'a, T> ClientBuilder<'a, T> {
34    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
35        Self {
36            api_key,
37            base_url: TOGETHER_AI_BASE_URL,
38            http_client,
39        }
40    }
41
42    pub fn base_url(mut self, base_url: &'a str) -> Self {
43        self.base_url = base_url;
44        self
45    }
46
47    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
48        ClientBuilder {
49            api_key: self.api_key,
50            base_url: self.base_url,
51            http_client,
52        }
53    }
54
55    pub fn build(self) -> Client<T> {
56        let mut default_headers = reqwest::header::HeaderMap::new();
57        default_headers.insert(
58            reqwest::header::CONTENT_TYPE,
59            "application/json".parse().unwrap(),
60        );
61
62        Client {
63            base_url: self.base_url.to_string(),
64            api_key: self.api_key.to_string(),
65            default_headers,
66            http_client: self.http_client,
67        }
68    }
69}
70#[derive(Clone)]
71pub struct Client<T = reqwest::Client> {
72    base_url: String,
73    default_headers: reqwest::header::HeaderMap,
74    api_key: String,
75    pub http_client: T,
76}
77
78impl<T> std::fmt::Debug for Client<T>
79where
80    T: std::fmt::Debug,
81{
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("Client")
84            .field("base_url", &self.base_url)
85            .field("http_client", &self.http_client)
86            .field("default_headers", &self.default_headers)
87            .field("api_key", &"<REDACTED>")
88            .finish()
89    }
90}
91
92impl<T> Client<T>
93where
94    T: HttpClientExt,
95{
96    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
97        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
98
99        tracing::debug!("POST {}", url);
100
101        let mut req = http_client::Request::post(url);
102
103        if let Some(hs) = req.headers_mut() {
104            *hs = self.default_headers.clone();
105        }
106
107        http_client::with_bearer_auth(req, &self.api_key)
108    }
109
110    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
111        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
112
113        tracing::debug!("GET {}", url);
114
115        let mut req = http_client::Request::get(url);
116
117        if let Some(hs) = req.headers_mut() {
118            *hs = self.default_headers.clone();
119        }
120
121        http_client::with_bearer_auth(req, &self.api_key)
122    }
123
124    pub(crate) async fn send<U, R>(
125        &self,
126        req: http_client::Request<U>,
127    ) -> http_client::Result<http::Response<http_client::LazyBody<R>>>
128    where
129        U: Into<Bytes> + Send,
130        R: From<Bytes> + Send + 'static,
131    {
132        self.http_client.send(req).await
133    }
134}
135
136impl Client<reqwest::Client> {
137    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
138        ClientBuilder::new(api_key)
139    }
140
141    pub fn new(api_key: &str) -> Self {
142        Self::builder(api_key).build()
143    }
144
145    pub fn from_env() -> Self {
146        <Self as ProviderClient>::from_env()
147    }
148}
149
150impl<T> ProviderClient for Client<T>
151where
152    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
153{
154    /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable.
155    /// Panics if the environment variable is not set.
156    fn from_env() -> Self {
157        let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
158        ClientBuilder::<T>::new(&api_key).build()
159    }
160
161    fn from_val(input: crate::client::ProviderValue) -> Self {
162        let crate::client::ProviderValue::Simple(api_key) = input else {
163            panic!("Incorrect provider value type")
164        };
165        ClientBuilder::<T>::new(&api_key).build()
166    }
167}
168
169impl<T> CompletionClient for Client<T>
170where
171    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
172{
173    type CompletionModel = CompletionModel<T>;
174
175    /// Create a completion model with the given name.
176    fn completion_model(&self, model: &str) -> Self::CompletionModel {
177        CompletionModel::new(self.clone(), model)
178    }
179}
180
181impl<T> EmbeddingsClient for Client<T>
182where
183    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
184{
185    type EmbeddingModel = EmbeddingModel<T>;
186
187    /// Create an embedding model with the given name.
188    /// Note: default embedding dimension of 0 will be used if model is not known.
189    /// If this is the case, it's better to use function `embedding_model_with_ndims`
190    ///
191    /// # Example
192    /// ```
193    /// use rig::providers::together_ai::{Client, self};
194    ///
195    /// // Initialize the Together AI client
196    /// let together_ai = Client::new("your-together-ai-api-key");
197    ///
198    /// let embedding_model = together_ai.embedding_model(together_ai::embedding::EMBEDDING_V1);
199    /// ```
200    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
201        let ndims = match model {
202            M2_BERT_80M_8K_RETRIEVAL => 8192,
203            _ => 0,
204        };
205        EmbeddingModel::new(self.clone(), model, ndims)
206    }
207
208    /// Create an embedding model with the given name and the number of dimensions in the embedding
209    /// generated by the model.
210    ///
211    /// # Example
212    /// ```
213    /// use rig::providers::together_ai::{Client, self};
214    ///
215    /// // Initialize the Together AI client
216    /// let together_ai = Client::new("your-together-ai-api-key");
217    ///
218    /// let embedding_model = together_ai.embedding_model_with_ndims("model-unknown-to-rig", 1024);
219    /// ```
220    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
221        EmbeddingModel::new(self.clone(), model, ndims)
222    }
223}
224
225impl<T> VerifyClient for Client<T>
226where
227    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
228{
229    #[cfg_attr(feature = "worker", worker::send)]
230    async fn verify(&self) -> Result<(), VerifyError> {
231        let req = self
232            .get("/models")?
233            .body(http_client::NoBody)
234            .map_err(|e| VerifyError::HttpError(e.into()))?;
235
236        let response = HttpClientExt::send(&self.http_client, req).await?;
237
238        match response.status() {
239            reqwest::StatusCode::OK => Ok(()),
240            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
241            reqwest::StatusCode::INTERNAL_SERVER_ERROR | reqwest::StatusCode::GATEWAY_TIMEOUT => {
242                let text = http_client::text(response).await?;
243                Err(VerifyError::ProviderError(text))
244            }
245            _ => {
246                //response.error_for_status()?;
247                Ok(())
248            }
249        }
250    }
251}
252
253impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client<T>);
254
255pub mod together_ai_api_types {
256    use serde::Deserialize;
257
258    impl ApiErrorResponse {
259        pub fn message(&self) -> String {
260            format!("Code `{}`: {}", self.code, self.error)
261        }
262    }
263
264    #[derive(Debug, Deserialize)]
265    pub struct ApiErrorResponse {
266        pub error: String,
267        pub code: String,
268    }
269
270    #[derive(Debug, Deserialize)]
271    #[serde(untagged)]
272    pub enum ApiResponse<T> {
273        Ok(T),
274        Error(ApiErrorResponse),
275    }
276}