1use super::{M2_BERT_80M_8K_RETRIEVAL, completion::CompletionModel, embedding::EmbeddingModel};
2use crate::{
3 client::{EmbeddingsClient, ProviderClient, VerifyClient, VerifyError, impl_conversion_traits},
4 http_client::{self, HttpClientExt},
5};
6use bytes::Bytes;
7use rig::client::CompletionClient;
8
9const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
13
14pub struct ClientBuilder<'a, T = reqwest::Client> {
15 api_key: &'a str,
16 base_url: &'a str,
17 http_client: T,
18}
19
20impl<'a, T> ClientBuilder<'a, T>
21where
22 T: Default,
23{
24 pub fn new(api_key: &'a str) -> Self {
25 Self {
26 api_key,
27 base_url: TOGETHER_AI_BASE_URL,
28 http_client: Default::default(),
29 }
30 }
31}
32
33impl<'a, T> ClientBuilder<'a, T> {
34 pub fn base_url(mut self, base_url: &'a str) -> Self {
35 self.base_url = base_url;
36 self
37 }
38
39 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
40 ClientBuilder {
41 api_key: self.api_key,
42 base_url: self.base_url,
43 http_client,
44 }
45 }
46
47 pub fn build(self) -> Client<T> {
48 let mut default_headers = reqwest::header::HeaderMap::new();
49 default_headers.insert(
50 reqwest::header::CONTENT_TYPE,
51 "application/json".parse().unwrap(),
52 );
53
54 Client {
55 base_url: self.base_url.to_string(),
56 api_key: self.api_key.to_string(),
57 default_headers,
58 http_client: self.http_client,
59 }
60 }
61}
62#[derive(Clone)]
63pub struct Client<T = reqwest::Client> {
64 base_url: String,
65 default_headers: reqwest::header::HeaderMap,
66 api_key: String,
67 http_client: T,
68}
69
70impl<T> std::fmt::Debug for Client<T>
71where
72 T: std::fmt::Debug,
73{
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 f.debug_struct("Client")
76 .field("base_url", &self.base_url)
77 .field("http_client", &self.http_client)
78 .field("default_headers", &self.default_headers)
79 .field("api_key", &"<REDACTED>")
80 .finish()
81 }
82}
83
84impl<T> Client<T>
85where
86 T: Default,
87{
88 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
99 ClientBuilder::new(api_key)
100 }
101
102 pub fn new(api_key: &str) -> Self {
107 Self::builder(api_key).build()
108 }
109}
110
111impl<T> Client<T>
112where
113 T: HttpClientExt,
114{
115 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
116 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
117
118 tracing::debug!("POST {}", url);
119
120 let mut req = http_client::Request::post(url);
121
122 if let Some(hs) = req.headers_mut() {
123 *hs = self.default_headers.clone();
124 }
125
126 http_client::with_bearer_auth(req, &self.api_key)
127 }
128
129 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
130 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
131
132 tracing::debug!("GET {}", url);
133
134 let mut req = http_client::Request::get(url);
135
136 if let Some(hs) = req.headers_mut() {
137 *hs = self.default_headers.clone();
138 }
139
140 http_client::with_bearer_auth(req, &self.api_key)
141 }
142
143 pub(crate) async fn send<U, R>(
144 &self,
145 req: http_client::Request<U>,
146 ) -> http_client::Result<http::Response<http_client::LazyBody<R>>>
147 where
148 U: Into<Bytes> + Send,
149 R: From<Bytes> + Send + 'static,
150 {
151 self.http_client.send(req).await
152 }
153}
154
155impl Client<reqwest::Client> {
156 pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
157 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
158
159 tracing::debug!("POST {}", url);
160
161 self.http_client
162 .post(url)
163 .bearer_auth(&self.api_key)
164 .headers(self.default_headers.clone())
165 }
166}
167
168impl ProviderClient for Client<reqwest::Client> {
169 fn from_env() -> Self {
172 let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
173 Self::new(&api_key)
174 }
175
176 fn from_val(input: crate::client::ProviderValue) -> Self {
177 let crate::client::ProviderValue::Simple(api_key) = input else {
178 panic!("Incorrect provider value type")
179 };
180 Self::new(&api_key)
181 }
182}
183
184impl CompletionClient for Client<reqwest::Client> {
185 type CompletionModel = CompletionModel<reqwest::Client>;
186
187 fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
189 CompletionModel::new(self.clone(), model)
190 }
191}
192
193impl EmbeddingsClient for Client<reqwest::Client> {
194 type EmbeddingModel = EmbeddingModel<reqwest::Client>;
195
196 fn embedding_model(&self, model: &str) -> EmbeddingModel<reqwest::Client> {
210 let ndims = match model {
211 M2_BERT_80M_8K_RETRIEVAL => 8192,
212 _ => 0,
213 };
214 EmbeddingModel::new(self.clone(), model, ndims)
215 }
216
217 fn embedding_model_with_ndims(
230 &self,
231 model: &str,
232 ndims: usize,
233 ) -> EmbeddingModel<reqwest::Client> {
234 EmbeddingModel::new(self.clone(), model, ndims)
235 }
236}
237
238impl VerifyClient for Client<reqwest::Client> {
239 #[cfg_attr(feature = "worker", worker::send)]
240 async fn verify(&self) -> Result<(), VerifyError> {
241 let req = self
242 .get("/models")?
243 .body(http_client::NoBody)
244 .map_err(|e| VerifyError::HttpError(e.into()))?;
245
246 let response = HttpClientExt::send(&self.http_client, req).await?;
247
248 match response.status() {
249 reqwest::StatusCode::OK => Ok(()),
250 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
251 reqwest::StatusCode::INTERNAL_SERVER_ERROR | reqwest::StatusCode::GATEWAY_TIMEOUT => {
252 let text = http_client::text(response).await?;
253 Err(VerifyError::ProviderError(text))
254 }
255 _ => {
256 Ok(())
258 }
259 }
260 }
261}
262
// Invokes `impl_conversion_traits!` (imported from `crate::client`) for
// `Client<T>` with the `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` traits.
// NOTE(review): presumably these are capabilities this provider does not
// implement natively; confirm against the macro definition.
impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client<T>);
264
265pub mod together_ai_api_types {
266 use serde::Deserialize;
267
268 impl ApiErrorResponse {
269 pub fn message(&self) -> String {
270 format!("Code `{}`: {}", self.code, self.error)
271 }
272 }
273
274 #[derive(Debug, Deserialize)]
275 pub struct ApiErrorResponse {
276 pub error: String,
277 pub code: String,
278 }
279
280 #[derive(Debug, Deserialize)]
281 #[serde(untagged)]
282 pub enum ApiResponse<T> {
283 Ok(T),
284 Error(ApiErrorResponse),
285 }
286}