rig/providers/together/
client.rs1use super::{completion::CompletionModel, embedding::EmbeddingModel, M2_BERT_80M_8K_RETRIEVAL};
2use crate::client::{impl_conversion_traits, EmbeddingsClient, ProviderClient};
3use rig::client::CompletionClient;
4
5const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
9
10#[derive(Debug, Clone)]
11pub struct Client {
12 base_url: String,
13 http_client: reqwest::Client,
14}
15
16impl Client {
17 pub fn new(api_key: &str) -> Self {
19 Self::from_url(api_key, TOGETHER_AI_BASE_URL)
20 }
21
22 fn from_url(api_key: &str, base_url: &str) -> Self {
23 Self {
24 base_url: base_url.to_string(),
25 http_client: reqwest::Client::builder()
26 .default_headers({
27 let mut headers = reqwest::header::HeaderMap::new();
28 headers.insert(
29 reqwest::header::CONTENT_TYPE,
30 "application/json".parse().unwrap(),
31 );
32 headers.insert(
33 "Authorization",
34 format!("Bearer {api_key}")
35 .parse()
36 .expect("Bearer token should parse"),
37 );
38 headers
39 })
40 .build()
41 .expect("Together AI reqwest client should build"),
42 }
43 }
44
45 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
46 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
47
48 tracing::debug!("POST {}", url);
49 self.http_client.post(url)
50 }
51}
52
53impl ProviderClient for Client {
54 fn from_env() -> Self {
57 let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
58 Self::new(&api_key)
59 }
60}
61
62impl CompletionClient for Client {
63 type CompletionModel = CompletionModel;
64
65 fn completion_model(&self, model: &str) -> CompletionModel {
67 CompletionModel::new(self.clone(), model)
68 }
69}
70
71impl EmbeddingsClient for Client {
72 type EmbeddingModel = EmbeddingModel;
73
74 fn embedding_model(&self, model: &str) -> EmbeddingModel {
88 let ndims = match model {
89 M2_BERT_80M_8K_RETRIEVAL => 8192,
90 _ => 0,
91 };
92 EmbeddingModel::new(self.clone(), model, ndims)
93 }
94
95 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
108 EmbeddingModel::new(self.clone(), model, ndims)
109 }
110}
111
112impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client);
113
114pub mod together_ai_api_types {
115 use serde::Deserialize;
116
117 impl ApiErrorResponse {
118 pub fn message(&self) -> String {
119 format!("Code `{}`: {}", self.code, self.error)
120 }
121 }
122
123 #[derive(Debug, Deserialize)]
124 pub struct ApiErrorResponse {
125 pub error: String,
126 pub code: String,
127 }
128
129 #[derive(Debug, Deserialize)]
130 #[serde(untagged)]
131 pub enum ApiResponse<T> {
132 Ok(T),
133 Error(ApiErrorResponse),
134 }
135}