rig/providers/together/
client.rs

1use super::{completion::CompletionModel, embedding::EmbeddingModel, M2_BERT_80M_8K_RETRIEVAL};
2use crate::client::{impl_conversion_traits, EmbeddingsClient, ProviderClient};
3use rig::client::CompletionClient;
4
// ================================================================
// Together AI Client
// ================================================================
/// Base URL of the Together AI API; `Client::new` builds clients that target it.
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
9
/// Client for the Together AI API.
///
/// Holds the base URL that request paths are appended to and the `reqwest`
/// client used to issue the requests.
#[derive(Debug, Clone)]
pub struct Client {
    /// Base URL that `post` joins request paths onto.
    base_url: String,
    /// HTTP client; `from_url` configures it with default `Content-Type` and
    /// `Authorization` headers.
    http_client: reqwest::Client,
}
15
16impl Client {
17    /// Create a new Together AI client with the given API key.
18    pub fn new(api_key: &str) -> Self {
19        Self::from_url(api_key, TOGETHER_AI_BASE_URL)
20    }
21
22    fn from_url(api_key: &str, base_url: &str) -> Self {
23        Self {
24            base_url: base_url.to_string(),
25            http_client: reqwest::Client::builder()
26                .default_headers({
27                    let mut headers = reqwest::header::HeaderMap::new();
28                    headers.insert(
29                        reqwest::header::CONTENT_TYPE,
30                        "application/json".parse().unwrap(),
31                    );
32                    headers.insert(
33                        "Authorization",
34                        format!("Bearer {api_key}")
35                            .parse()
36                            .expect("Bearer token should parse"),
37                    );
38                    headers
39                })
40                .build()
41                .expect("Together AI reqwest client should build"),
42        }
43    }
44
45    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
46        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
47
48        tracing::debug!("POST {}", url);
49        self.http_client.post(url)
50    }
51}
52
53impl ProviderClient for Client {
54    /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable.
55    /// Panics if the environment variable is not set.
56    fn from_env() -> Self {
57        let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
58        Self::new(&api_key)
59    }
60}
61
62impl CompletionClient for Client {
63    type CompletionModel = CompletionModel;
64
65    /// Create a completion model with the given name.
66    fn completion_model(&self, model: &str) -> CompletionModel {
67        CompletionModel::new(self.clone(), model)
68    }
69}
70
71impl EmbeddingsClient for Client {
72    type EmbeddingModel = EmbeddingModel;
73
74    /// Create an embedding model with the given name.
75    /// Note: default embedding dimension of 0 will be used if model is not known.
76    /// If this is the case, it's better to use function `embedding_model_with_ndims`
77    ///
78    /// # Example
79    /// ```
80    /// use rig::providers::together_ai::{Client, self};
81    ///
82    /// // Initialize the Together AI client
83    /// let together_ai = Client::new("your-together-ai-api-key");
84    ///
85    /// let embedding_model = together_ai.embedding_model(together_ai::embedding::EMBEDDING_V1);
86    /// ```
87    fn embedding_model(&self, model: &str) -> EmbeddingModel {
88        let ndims = match model {
89            M2_BERT_80M_8K_RETRIEVAL => 8192,
90            _ => 0,
91        };
92        EmbeddingModel::new(self.clone(), model, ndims)
93    }
94
95    /// Create an embedding model with the given name and the number of dimensions in the embedding
96    /// generated by the model.
97    ///
98    /// # Example
99    /// ```
100    /// use rig::providers::together_ai::{Client, self};
101    ///
102    /// // Initialize the Together AI client
103    /// let together_ai = Client::new("your-together-ai-api-key");
104    ///
105    /// let embedding_model = together_ai.embedding_model_with_ndims("model-unknown-to-rig", 1024);
106    /// ```
107    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
108        EmbeddingModel::new(self.clone(), model, ndims)
109    }
110}
111
// NOTE(review): presumably generates the `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` conversion-trait impls for `Client` (capabilities this file does
// not implement itself) — confirm against the `impl_conversion_traits!` definition.
impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client);
113
114pub mod together_ai_api_types {
115    use serde::Deserialize;
116
117    impl ApiErrorResponse {
118        pub fn message(&self) -> String {
119            format!("Code `{}`: {}", self.code, self.error)
120        }
121    }
122
123    #[derive(Debug, Deserialize)]
124    pub struct ApiErrorResponse {
125        pub error: String,
126        pub code: String,
127    }
128
129    #[derive(Debug, Deserialize)]
130    #[serde(untagged)]
131    pub enum ApiResponse<T> {
132        Ok(T),
133        Error(ApiErrorResponse),
134    }
135}