1use super::{M2_BERT_80M_8K_RETRIEVAL, completion::CompletionModel, embedding::EmbeddingModel};
2use crate::{
3 client::{EmbeddingsClient, ProviderClient, VerifyClient, VerifyError, impl_conversion_traits},
4 http_client::{self, HttpClientExt},
5};
6use bytes::Bytes;
7use rig::client::CompletionClient;
8
/// Default base URL for the Together AI API, used by every `ClientBuilder` unless
/// overridden via `ClientBuilder::base_url`.
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
13
14pub struct ClientBuilder<'a, T = reqwest::Client> {
15 api_key: &'a str,
16 base_url: &'a str,
17 http_client: T,
18}
19
20impl<'a, T> ClientBuilder<'a, T>
21where
22 T: Default,
23{
24 pub fn new(api_key: &'a str) -> Self {
25 Self {
26 api_key,
27 base_url: TOGETHER_AI_BASE_URL,
28 http_client: Default::default(),
29 }
30 }
31}
32
33impl<'a, T> ClientBuilder<'a, T> {
34 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
35 Self {
36 api_key,
37 base_url: TOGETHER_AI_BASE_URL,
38 http_client,
39 }
40 }
41
42 pub fn base_url(mut self, base_url: &'a str) -> Self {
43 self.base_url = base_url;
44 self
45 }
46
47 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
48 ClientBuilder {
49 api_key: self.api_key,
50 base_url: self.base_url,
51 http_client,
52 }
53 }
54
55 pub fn build(self) -> Client<T> {
56 let mut default_headers = reqwest::header::HeaderMap::new();
57 default_headers.insert(
58 reqwest::header::CONTENT_TYPE,
59 "application/json".parse().unwrap(),
60 );
61
62 Client {
63 base_url: self.base_url.to_string(),
64 api_key: self.api_key.to_string(),
65 default_headers,
66 http_client: self.http_client,
67 }
68 }
69}
70#[derive(Clone)]
71pub struct Client<T = reqwest::Client> {
72 base_url: String,
73 default_headers: reqwest::header::HeaderMap,
74 api_key: String,
75 pub http_client: T,
76}
77
78impl<T> std::fmt::Debug for Client<T>
79where
80 T: std::fmt::Debug,
81{
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("Client")
84 .field("base_url", &self.base_url)
85 .field("http_client", &self.http_client)
86 .field("default_headers", &self.default_headers)
87 .field("api_key", &"<REDACTED>")
88 .finish()
89 }
90}
91
92impl<T> Client<T>
93where
94 T: HttpClientExt,
95{
96 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
97 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
98
99 tracing::debug!("POST {}", url);
100
101 let mut req = http_client::Request::post(url);
102
103 if let Some(hs) = req.headers_mut() {
104 *hs = self.default_headers.clone();
105 }
106
107 http_client::with_bearer_auth(req, &self.api_key)
108 }
109
110 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
111 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
112
113 tracing::debug!("GET {}", url);
114
115 let mut req = http_client::Request::get(url);
116
117 if let Some(hs) = req.headers_mut() {
118 *hs = self.default_headers.clone();
119 }
120
121 http_client::with_bearer_auth(req, &self.api_key)
122 }
123
124 pub(crate) async fn send<U, R>(
125 &self,
126 req: http_client::Request<U>,
127 ) -> http_client::Result<http::Response<http_client::LazyBody<R>>>
128 where
129 U: Into<Bytes> + Send,
130 R: From<Bytes> + Send + 'static,
131 {
132 self.http_client.send(req).await
133 }
134}
135
136impl Client<reqwest::Client> {
137 pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
138 ClientBuilder::new(api_key)
139 }
140
141 pub fn new(api_key: &str) -> Self {
142 Self::builder(api_key).build()
143 }
144
145 pub fn from_env() -> Self {
146 <Self as ProviderClient>::from_env()
147 }
148}
149
150impl<T> ProviderClient for Client<T>
151where
152 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
153{
154 fn from_env() -> Self {
157 let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
158 ClientBuilder::<T>::new(&api_key).build()
159 }
160
161 fn from_val(input: crate::client::ProviderValue) -> Self {
162 let crate::client::ProviderValue::Simple(api_key) = input else {
163 panic!("Incorrect provider value type")
164 };
165 ClientBuilder::<T>::new(&api_key).build()
166 }
167}
168
169impl<T> CompletionClient for Client<T>
170where
171 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
172{
173 type CompletionModel = CompletionModel<T>;
174
175 fn completion_model(&self, model: &str) -> Self::CompletionModel {
177 CompletionModel::new(self.clone(), model)
178 }
179}
180
181impl<T> EmbeddingsClient for Client<T>
182where
183 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
184{
185 type EmbeddingModel = EmbeddingModel<T>;
186
187 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
201 let ndims = match model {
202 M2_BERT_80M_8K_RETRIEVAL => 8192,
203 _ => 0,
204 };
205 EmbeddingModel::new(self.clone(), model, ndims)
206 }
207
208 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
221 EmbeddingModel::new(self.clone(), model, ndims)
222 }
223}
224
225impl<T> VerifyClient for Client<T>
226where
227 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
228{
229 #[cfg_attr(feature = "worker", worker::send)]
230 async fn verify(&self) -> Result<(), VerifyError> {
231 let req = self
232 .get("/models")?
233 .body(http_client::NoBody)
234 .map_err(|e| VerifyError::HttpError(e.into()))?;
235
236 let response = HttpClientExt::send(&self.http_client, req).await?;
237
238 match response.status() {
239 reqwest::StatusCode::OK => Ok(()),
240 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
241 reqwest::StatusCode::INTERNAL_SERVER_ERROR | reqwest::StatusCode::GATEWAY_TIMEOUT => {
242 let text = http_client::text(response).await?;
243 Err(VerifyError::ProviderError(text))
244 }
245 _ => {
246 Ok(())
248 }
249 }
250 }
251}
252
// Invokes `impl_conversion_traits!` for `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` on `Client<T>`.
// NOTE(review): this file implements only completion and embeddings, so these are
// presumably the "capability not supported" defaults. Confirm against `crate::client`.
impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client<T>);
254
255pub mod together_ai_api_types {
256 use serde::Deserialize;
257
258 impl ApiErrorResponse {
259 pub fn message(&self) -> String {
260 format!("Code `{}`: {}", self.code, self.error)
261 }
262 }
263
264 #[derive(Debug, Deserialize)]
265 pub struct ApiErrorResponse {
266 pub error: String,
267 pub code: String,
268 }
269
270 #[derive(Debug, Deserialize)]
271 #[serde(untagged)]
272 pub enum ApiResponse<T> {
273 Ok(T),
274 Error(ApiErrorResponse),
275 }
276}