rig/providers/together/
client.rs1use super::{M2_BERT_80M_8K_RETRIEVAL, completion::CompletionModel, embedding::EmbeddingModel};
2use crate::client::{ClientBuilderError, EmbeddingsClient, ProviderClient, impl_conversion_traits};
3use rig::client::CompletionClient;
4
5const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
9
10pub struct ClientBuilder<'a> {
11 api_key: &'a str,
12 base_url: &'a str,
13 http_client: Option<reqwest::Client>,
14}
15
16impl<'a> ClientBuilder<'a> {
17 pub fn new(api_key: &'a str) -> Self {
18 Self {
19 api_key,
20 base_url: TOGETHER_AI_BASE_URL,
21 http_client: None,
22 }
23 }
24
25 pub fn base_url(mut self, base_url: &'a str) -> Self {
26 self.base_url = base_url;
27 self
28 }
29
30 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
31 self.http_client = Some(client);
32 self
33 }
34
35 pub fn build(self) -> Result<Client, ClientBuilderError> {
36 let mut default_headers = reqwest::header::HeaderMap::new();
37 default_headers.insert(
38 reqwest::header::CONTENT_TYPE,
39 "application/json".parse().unwrap(),
40 );
41
42 let http_client = if let Some(http_client) = self.http_client {
43 http_client
44 } else {
45 reqwest::Client::builder().build()?
46 };
47
48 Ok(Client {
49 base_url: self.base_url.to_string(),
50 api_key: self.api_key.to_string(),
51 default_headers,
52 http_client,
53 })
54 }
55}
56#[derive(Clone)]
57pub struct Client {
58 base_url: String,
59 default_headers: reqwest::header::HeaderMap,
60 api_key: String,
61 http_client: reqwest::Client,
62}
63
64impl std::fmt::Debug for Client {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("Client")
67 .field("base_url", &self.base_url)
68 .field("http_client", &self.http_client)
69 .field("default_headers", &self.default_headers)
70 .field("api_key", &"<REDACTED>")
71 .finish()
72 }
73}
74
75impl Client {
76 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
87 ClientBuilder::new(api_key)
88 }
89
90 pub fn new(api_key: &str) -> Self {
95 Self::builder(api_key)
96 .build()
97 .expect("Together AI client should build")
98 }
99
100 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
101 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
102
103 tracing::debug!("POST {}", url);
104 self.http_client
105 .post(url)
106 .bearer_auth(&self.api_key)
107 .headers(self.default_headers.clone())
108 }
109}
110
111impl ProviderClient for Client {
112 fn from_env() -> Self {
115 let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
116 Self::new(&api_key)
117 }
118
119 fn from_val(input: crate::client::ProviderValue) -> Self {
120 let crate::client::ProviderValue::Simple(api_key) = input else {
121 panic!("Incorrect provider value type")
122 };
123 Self::new(&api_key)
124 }
125}
126
127impl CompletionClient for Client {
128 type CompletionModel = CompletionModel;
129
130 fn completion_model(&self, model: &str) -> CompletionModel {
132 CompletionModel::new(self.clone(), model)
133 }
134}
135
136impl EmbeddingsClient for Client {
137 type EmbeddingModel = EmbeddingModel;
138
139 fn embedding_model(&self, model: &str) -> EmbeddingModel {
153 let ndims = match model {
154 M2_BERT_80M_8K_RETRIEVAL => 8192,
155 _ => 0,
156 };
157 EmbeddingModel::new(self.clone(), model, ndims)
158 }
159
160 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
173 EmbeddingModel::new(self.clone(), model, ndims)
174 }
175}
176
// Conversion-trait impls for `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` on `Client`, generated by the `impl_conversion_traits!`
// macro imported from `crate::client`.
// NOTE(review): this file defines no transcription, image-generation or
// audio-generation client impls, so these presumably mark capabilities that
// Together AI does not provide. Confirm against the macro definition.
impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client);

179pub mod together_ai_api_types {
180 use serde::Deserialize;
181
182 impl ApiErrorResponse {
183 pub fn message(&self) -> String {
184 format!("Code `{}`: {}", self.code, self.error)
185 }
186 }
187
188 #[derive(Debug, Deserialize)]
189 pub struct ApiErrorResponse {
190 pub error: String,
191 pub code: String,
192 }
193
194 #[derive(Debug, Deserialize)]
195 #[serde(untagged)]
196 pub enum ApiResponse<T> {
197 Ok(T),
198 Error(ApiErrorResponse),
199 }
200}