rig/providers/together/
client.rs

1use super::{M2_BERT_80M_8K_RETRIEVAL, completion::CompletionModel, embedding::EmbeddingModel};
2use crate::client::{EmbeddingsClient, ProviderClient, impl_conversion_traits};
3use rig::client::CompletionClient;
4
// ================================================================
// Together AI Client
// ================================================================
/// Base URL that [`Client::new`] hands to `Client::from_url` when building a client.
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
9
/// Client for the Together AI HTTP API.
///
/// Bundles everything `Client::post` needs to build an authenticated request.
#[derive(Clone)]
pub struct Client {
    /// Root URL that request paths are appended to.
    base_url: String,
    /// Headers attached to every request built by `post`.
    default_headers: reqwest::header::HeaderMap,
    /// API key sent as a bearer token; printed as `<REDACTED>` by the `Debug` impl.
    api_key: String,
    /// Underlying HTTP client used to create requests.
    http_client: reqwest::Client,
}
17
18impl std::fmt::Debug for Client {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        f.debug_struct("Client")
21            .field("base_url", &self.base_url)
22            .field("http_client", &self.http_client)
23            .field("default_headers", &self.default_headers)
24            .field("api_key", &"<REDACTED>")
25            .finish()
26    }
27}
28
29impl Client {
30    /// Create a new Together AI client with the given API key.
31    pub fn new(api_key: &str) -> Self {
32        Self::from_url(api_key, TOGETHER_AI_BASE_URL)
33    }
34
35    fn from_url(api_key: &str, base_url: &str) -> Self {
36        let mut default_headers = reqwest::header::HeaderMap::new();
37        default_headers.insert(
38            reqwest::header::CONTENT_TYPE,
39            "application/json".parse().unwrap(),
40        );
41        Self {
42            base_url: base_url.to_string(),
43            api_key: api_key.to_string(),
44            default_headers,
45            http_client: reqwest::Client::builder()
46                .build()
47                .expect("Together AI reqwest client should build"),
48        }
49    }
50
51    /// Use your own `reqwest::Client`.
52    /// The required headers will be automatically attached upon trying to make a request.
53    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
54        self.http_client = client;
55
56        self
57    }
58
59    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
60        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
61
62        tracing::debug!("POST {}", url);
63        self.http_client
64            .post(url)
65            .bearer_auth(&self.api_key)
66            .headers(self.default_headers.clone())
67    }
68}
69
70impl ProviderClient for Client {
71    /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable.
72    /// Panics if the environment variable is not set.
73    fn from_env() -> Self {
74        let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
75        Self::new(&api_key)
76    }
77
78    fn from_val(input: crate::client::ProviderValue) -> Self {
79        let crate::client::ProviderValue::Simple(api_key) = input else {
80            panic!("Incorrect provider value type")
81        };
82        Self::new(&api_key)
83    }
84}
85
86impl CompletionClient for Client {
87    type CompletionModel = CompletionModel;
88
89    /// Create a completion model with the given name.
90    fn completion_model(&self, model: &str) -> CompletionModel {
91        CompletionModel::new(self.clone(), model)
92    }
93}
94
95impl EmbeddingsClient for Client {
96    type EmbeddingModel = EmbeddingModel;
97
98    /// Create an embedding model with the given name.
99    /// Note: default embedding dimension of 0 will be used if model is not known.
100    /// If this is the case, it's better to use function `embedding_model_with_ndims`
101    ///
102    /// # Example
103    /// ```
104    /// use rig::providers::together_ai::{Client, self};
105    ///
106    /// // Initialize the Together AI client
107    /// let together_ai = Client::new("your-together-ai-api-key");
108    ///
109    /// let embedding_model = together_ai.embedding_model(together_ai::embedding::EMBEDDING_V1);
110    /// ```
111    fn embedding_model(&self, model: &str) -> EmbeddingModel {
112        let ndims = match model {
113            M2_BERT_80M_8K_RETRIEVAL => 8192,
114            _ => 0,
115        };
116        EmbeddingModel::new(self.clone(), model, ndims)
117    }
118
119    /// Create an embedding model with the given name and the number of dimensions in the embedding
120    /// generated by the model.
121    ///
122    /// # Example
123    /// ```
124    /// use rig::providers::together_ai::{Client, self};
125    ///
126    /// // Initialize the Together AI client
127    /// let together_ai = Client::new("your-together-ai-api-key");
128    ///
129    /// let embedding_model = together_ai.embedding_model_with_ndims("model-unknown-to-rig", 1024);
130    /// ```
131    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
132        EmbeddingModel::new(self.clone(), model, ndims)
133    }
134}
135
// Invokes the crate's `impl_conversion_traits!` macro for `Client` with the
// `AsTranscription`, `AsImageGeneration` and `AsAudioGeneration` trait names.
// NOTE(review): presumably this marks those capabilities as not provided by this
// client — confirm against the macro definition.
impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client);
137
138pub mod together_ai_api_types {
139    use serde::Deserialize;
140
141    impl ApiErrorResponse {
142        pub fn message(&self) -> String {
143            format!("Code `{}`: {}", self.code, self.error)
144        }
145    }
146
147    #[derive(Debug, Deserialize)]
148    pub struct ApiErrorResponse {
149        pub error: String,
150        pub code: String,
151    }
152
153    #[derive(Debug, Deserialize)]
154    #[serde(untagged)]
155    pub enum ApiResponse<T> {
156        Ok(T),
157        Error(ApiErrorResponse),
158    }
159}