// rig/providers/together/client.rs

1use super::{M2_BERT_80M_8K_RETRIEVAL, completion::CompletionModel, embedding::EmbeddingModel};
2use crate::client::{ClientBuilderError, EmbeddingsClient, ProviderClient, impl_conversion_traits};
3use rig::client::CompletionClient;
4
// ----------------------------------------------------------------
// Together AI client
// ----------------------------------------------------------------

/// Base URL used when a builder does not override it via `ClientBuilder::base_url`.
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
9
10pub struct ClientBuilder<'a> {
11    api_key: &'a str,
12    base_url: &'a str,
13    http_client: Option<reqwest::Client>,
14}
15
16impl<'a> ClientBuilder<'a> {
17    pub fn new(api_key: &'a str) -> Self {
18        Self {
19            api_key,
20            base_url: TOGETHER_AI_BASE_URL,
21            http_client: None,
22        }
23    }
24
25    pub fn base_url(mut self, base_url: &'a str) -> Self {
26        self.base_url = base_url;
27        self
28    }
29
30    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
31        self.http_client = Some(client);
32        self
33    }
34
35    pub fn build(self) -> Result<Client, ClientBuilderError> {
36        let mut default_headers = reqwest::header::HeaderMap::new();
37        default_headers.insert(
38            reqwest::header::CONTENT_TYPE,
39            "application/json".parse().unwrap(),
40        );
41
42        let http_client = if let Some(http_client) = self.http_client {
43            http_client
44        } else {
45            reqwest::Client::builder().build()?
46        };
47
48        Ok(Client {
49            base_url: self.base_url.to_string(),
50            api_key: self.api_key.to_string(),
51            default_headers,
52            http_client,
53        })
54    }
55}
56#[derive(Clone)]
57pub struct Client {
58    base_url: String,
59    default_headers: reqwest::header::HeaderMap,
60    api_key: String,
61    http_client: reqwest::Client,
62}
63
64impl std::fmt::Debug for Client {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("Client")
67            .field("base_url", &self.base_url)
68            .field("http_client", &self.http_client)
69            .field("default_headers", &self.default_headers)
70            .field("api_key", &"<REDACTED>")
71            .finish()
72    }
73}
74
75impl Client {
76    /// Create a new Together AI client builder.
77    ///
78    /// # Example
79    /// ```
80    /// use rig::providers::together_ai::{ClientBuilder, self};
81    ///
82    /// // Initialize the Together AI client
83    /// let together_ai = Client::builder("your-together-ai-api-key")
84    ///    .build()
85    /// ```
86    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
87        ClientBuilder::new(api_key)
88    }
89
90    /// Create a new Together AI client. For more control, use the `builder` method.
91    ///
92    /// # Panics
93    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
94    pub fn new(api_key: &str) -> Self {
95        Self::builder(api_key)
96            .build()
97            .expect("Together AI client should build")
98    }
99
100    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
101        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
102
103        tracing::debug!("POST {}", url);
104        self.http_client
105            .post(url)
106            .bearer_auth(&self.api_key)
107            .headers(self.default_headers.clone())
108    }
109}
110
111impl ProviderClient for Client {
112    /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable.
113    /// Panics if the environment variable is not set.
114    fn from_env() -> Self {
115        let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
116        Self::new(&api_key)
117    }
118
119    fn from_val(input: crate::client::ProviderValue) -> Self {
120        let crate::client::ProviderValue::Simple(api_key) = input else {
121            panic!("Incorrect provider value type")
122        };
123        Self::new(&api_key)
124    }
125}
126
127impl CompletionClient for Client {
128    type CompletionModel = CompletionModel;
129
130    /// Create a completion model with the given name.
131    fn completion_model(&self, model: &str) -> CompletionModel {
132        CompletionModel::new(self.clone(), model)
133    }
134}
135
136impl EmbeddingsClient for Client {
137    type EmbeddingModel = EmbeddingModel;
138
139    /// Create an embedding model with the given name.
140    /// Note: default embedding dimension of 0 will be used if model is not known.
141    /// If this is the case, it's better to use function `embedding_model_with_ndims`
142    ///
143    /// # Example
144    /// ```
145    /// use rig::providers::together_ai::{Client, self};
146    ///
147    /// // Initialize the Together AI client
148    /// let together_ai = Client::new("your-together-ai-api-key");
149    ///
150    /// let embedding_model = together_ai.embedding_model(together_ai::embedding::EMBEDDING_V1);
151    /// ```
152    fn embedding_model(&self, model: &str) -> EmbeddingModel {
153        let ndims = match model {
154            M2_BERT_80M_8K_RETRIEVAL => 8192,
155            _ => 0,
156        };
157        EmbeddingModel::new(self.clone(), model, ndims)
158    }
159
160    /// Create an embedding model with the given name and the number of dimensions in the embedding
161    /// generated by the model.
162    ///
163    /// # Example
164    /// ```
165    /// use rig::providers::together_ai::{Client, self};
166    ///
167    /// // Initialize the Together AI client
168    /// let together_ai = Client::new("your-together-ai-api-key");
169    ///
170    /// let embedding_model = together_ai.embedding_model_with_ndims("model-unknown-to-rig", 1024);
171    /// ```
172    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
173        EmbeddingModel::new(self.clone(), model, ndims)
174    }
175}
176
177impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client);
178
179pub mod together_ai_api_types {
180    use serde::Deserialize;
181
182    impl ApiErrorResponse {
183        pub fn message(&self) -> String {
184            format!("Code `{}`: {}", self.code, self.error)
185        }
186    }
187
188    #[derive(Debug, Deserialize)]
189    pub struct ApiErrorResponse {
190        pub error: String,
191        pub code: String,
192    }
193
194    #[derive(Debug, Deserialize)]
195    #[serde(untagged)]
196    pub enum ApiResponse<T> {
197        Ok(T),
198        Error(ApiErrorResponse),
199    }
200}