rig/providers/together/
client.rs

1use super::{M2_BERT_80M_8K_RETRIEVAL, completion::CompletionModel, embedding::EmbeddingModel};
2use crate::client::{
3    ClientBuilderError, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError,
4    impl_conversion_traits,
5};
6use rig::client::CompletionClient;
7
// ================================================================
// Together AI Client
// ================================================================

/// Default base URL of the Together AI REST API.
///
/// `ClientBuilder::new` starts from this value; `ClientBuilder::base_url` overrides it.
/// Request paths passed to `Client::post` / `Client::get` are appended to it.
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
12
13pub struct ClientBuilder<'a> {
14    api_key: &'a str,
15    base_url: &'a str,
16    http_client: Option<reqwest::Client>,
17}
18
19impl<'a> ClientBuilder<'a> {
20    pub fn new(api_key: &'a str) -> Self {
21        Self {
22            api_key,
23            base_url: TOGETHER_AI_BASE_URL,
24            http_client: None,
25        }
26    }
27
28    pub fn base_url(mut self, base_url: &'a str) -> Self {
29        self.base_url = base_url;
30        self
31    }
32
33    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
34        self.http_client = Some(client);
35        self
36    }
37
38    pub fn build(self) -> Result<Client, ClientBuilderError> {
39        let mut default_headers = reqwest::header::HeaderMap::new();
40        default_headers.insert(
41            reqwest::header::CONTENT_TYPE,
42            "application/json".parse().unwrap(),
43        );
44
45        let http_client = if let Some(http_client) = self.http_client {
46            http_client
47        } else {
48            reqwest::Client::builder().build()?
49        };
50
51        Ok(Client {
52            base_url: self.base_url.to_string(),
53            api_key: self.api_key.to_string(),
54            default_headers,
55            http_client,
56        })
57    }
58}
59#[derive(Clone)]
60pub struct Client {
61    base_url: String,
62    default_headers: reqwest::header::HeaderMap,
63    api_key: String,
64    http_client: reqwest::Client,
65}
66
67impl std::fmt::Debug for Client {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        f.debug_struct("Client")
70            .field("base_url", &self.base_url)
71            .field("http_client", &self.http_client)
72            .field("default_headers", &self.default_headers)
73            .field("api_key", &"<REDACTED>")
74            .finish()
75    }
76}
77
78impl Client {
79    /// Create a new Together AI client builder.
80    ///
81    /// # Example
82    /// ```
83    /// use rig::providers::together_ai::{ClientBuilder, self};
84    ///
85    /// // Initialize the Together AI client
86    /// let together_ai = Client::builder("your-together-ai-api-key")
87    ///    .build()
88    /// ```
89    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
90        ClientBuilder::new(api_key)
91    }
92
93    /// Create a new Together AI client. For more control, use the `builder` method.
94    ///
95    /// # Panics
96    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
97    pub fn new(api_key: &str) -> Self {
98        Self::builder(api_key)
99            .build()
100            .expect("Together AI client should build")
101    }
102
103    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
104        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
105
106        tracing::debug!("POST {}", url);
107        self.http_client
108            .post(url)
109            .bearer_auth(&self.api_key)
110            .headers(self.default_headers.clone())
111    }
112
113    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
114        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
115
116        tracing::debug!("GET {}", url);
117        self.http_client
118            .get(url)
119            .bearer_auth(&self.api_key)
120            .headers(self.default_headers.clone())
121    }
122}
123
124impl ProviderClient for Client {
125    /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable.
126    /// Panics if the environment variable is not set.
127    fn from_env() -> Self {
128        let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
129        Self::new(&api_key)
130    }
131
132    fn from_val(input: crate::client::ProviderValue) -> Self {
133        let crate::client::ProviderValue::Simple(api_key) = input else {
134            panic!("Incorrect provider value type")
135        };
136        Self::new(&api_key)
137    }
138}
139
140impl CompletionClient for Client {
141    type CompletionModel = CompletionModel;
142
143    /// Create a completion model with the given name.
144    fn completion_model(&self, model: &str) -> CompletionModel {
145        CompletionModel::new(self.clone(), model)
146    }
147}
148
149impl EmbeddingsClient for Client {
150    type EmbeddingModel = EmbeddingModel;
151
152    /// Create an embedding model with the given name.
153    /// Note: default embedding dimension of 0 will be used if model is not known.
154    /// If this is the case, it's better to use function `embedding_model_with_ndims`
155    ///
156    /// # Example
157    /// ```
158    /// use rig::providers::together_ai::{Client, self};
159    ///
160    /// // Initialize the Together AI client
161    /// let together_ai = Client::new("your-together-ai-api-key");
162    ///
163    /// let embedding_model = together_ai.embedding_model(together_ai::embedding::EMBEDDING_V1);
164    /// ```
165    fn embedding_model(&self, model: &str) -> EmbeddingModel {
166        let ndims = match model {
167            M2_BERT_80M_8K_RETRIEVAL => 8192,
168            _ => 0,
169        };
170        EmbeddingModel::new(self.clone(), model, ndims)
171    }
172
173    /// Create an embedding model with the given name and the number of dimensions in the embedding
174    /// generated by the model.
175    ///
176    /// # Example
177    /// ```
178    /// use rig::providers::together_ai::{Client, self};
179    ///
180    /// // Initialize the Together AI client
181    /// let together_ai = Client::new("your-together-ai-api-key");
182    ///
183    /// let embedding_model = together_ai.embedding_model_with_ndims("model-unknown-to-rig", 1024);
184    /// ```
185    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
186        EmbeddingModel::new(self.clone(), model, ndims)
187    }
188}
189
190impl VerifyClient for Client {
191    #[cfg_attr(feature = "worker", worker::send)]
192    async fn verify(&self) -> Result<(), VerifyError> {
193        let response = self.get("/models").send().await?;
194        match response.status() {
195            reqwest::StatusCode::OK => Ok(()),
196            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
197            reqwest::StatusCode::INTERNAL_SERVER_ERROR | reqwest::StatusCode::GATEWAY_TIMEOUT => {
198                Err(VerifyError::ProviderError(response.text().await?))
199            }
200            _ => {
201                response.error_for_status()?;
202                Ok(())
203            }
204        }
205    }
206}
207
208impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client);
209
210pub mod together_ai_api_types {
211    use serde::Deserialize;
212
213    impl ApiErrorResponse {
214        pub fn message(&self) -> String {
215            format!("Code `{}`: {}", self.code, self.error)
216        }
217    }
218
219    #[derive(Debug, Deserialize)]
220    pub struct ApiErrorResponse {
221        pub error: String,
222        pub code: String,
223    }
224
225    #[derive(Debug, Deserialize)]
226    #[serde(untagged)]
227    pub enum ApiResponse<T> {
228        Ok(T),
229        Error(ApiErrorResponse),
230    }
231}