rig/providers/together/
client.rs1use super::{M2_BERT_80M_8K_RETRIEVAL, completion::CompletionModel, embedding::EmbeddingModel};
2use crate::client::{EmbeddingsClient, ProviderClient, impl_conversion_traits};
3use rig::client::CompletionClient;
4
/// Default Together AI API root; `Client::new` sends all requests relative to this URL.
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";

10#[derive(Clone)]
11pub struct Client {
12 base_url: String,
13 default_headers: reqwest::header::HeaderMap,
14 api_key: String,
15 http_client: reqwest::Client,
16}
17
18impl std::fmt::Debug for Client {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 f.debug_struct("Client")
21 .field("base_url", &self.base_url)
22 .field("http_client", &self.http_client)
23 .field("default_headers", &self.default_headers)
24 .field("api_key", &"<REDACTED>")
25 .finish()
26 }
27}
28
29impl Client {
30 pub fn new(api_key: &str) -> Self {
32 Self::from_url(api_key, TOGETHER_AI_BASE_URL)
33 }
34
35 fn from_url(api_key: &str, base_url: &str) -> Self {
36 let mut default_headers = reqwest::header::HeaderMap::new();
37 default_headers.insert(
38 reqwest::header::CONTENT_TYPE,
39 "application/json".parse().unwrap(),
40 );
41 Self {
42 base_url: base_url.to_string(),
43 api_key: api_key.to_string(),
44 default_headers,
45 http_client: reqwest::Client::builder()
46 .build()
47 .expect("Together AI reqwest client should build"),
48 }
49 }
50
51 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
54 self.http_client = client;
55
56 self
57 }
58
59 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
60 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
61
62 tracing::debug!("POST {}", url);
63 self.http_client
64 .post(url)
65 .bearer_auth(&self.api_key)
66 .headers(self.default_headers.clone())
67 }
68}
69
70impl ProviderClient for Client {
71 fn from_env() -> Self {
74 let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
75 Self::new(&api_key)
76 }
77
78 fn from_val(input: crate::client::ProviderValue) -> Self {
79 let crate::client::ProviderValue::Simple(api_key) = input else {
80 panic!("Incorrect provider value type")
81 };
82 Self::new(&api_key)
83 }
84}
85
86impl CompletionClient for Client {
87 type CompletionModel = CompletionModel;
88
89 fn completion_model(&self, model: &str) -> CompletionModel {
91 CompletionModel::new(self.clone(), model)
92 }
93}
94
95impl EmbeddingsClient for Client {
96 type EmbeddingModel = EmbeddingModel;
97
98 fn embedding_model(&self, model: &str) -> EmbeddingModel {
112 let ndims = match model {
113 M2_BERT_80M_8K_RETRIEVAL => 8192,
114 _ => 0,
115 };
116 EmbeddingModel::new(self.clone(), model, ndims)
117 }
118
119 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
132 EmbeddingModel::new(self.clone(), model, ndims)
133 }
134}
135
// NOTE(review): presumably generates default impls of `AsTranscription`,
// `AsImageGeneration` and `AsAudioGeneration` for `Client`. This file only
// implements completion and embedding clients, so confirm against the macro
// definition in `crate::client`.
impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client);
137
138pub mod together_ai_api_types {
139 use serde::Deserialize;
140
141 impl ApiErrorResponse {
142 pub fn message(&self) -> String {
143 format!("Code `{}`: {}", self.code, self.error)
144 }
145 }
146
147 #[derive(Debug, Deserialize)]
148 pub struct ApiErrorResponse {
149 pub error: String,
150 pub code: String,
151 }
152
153 #[derive(Debug, Deserialize)]
154 #[serde(untagged)]
155 pub enum ApiResponse<T> {
156 Ok(T),
157 Error(ApiErrorResponse),
158 }
159}