rig/providers/together/
client.rs

1use crate::{
2    agent::AgentBuilder,
3    embeddings::{self},
4    extractor::ExtractorBuilder,
5    Embed,
6};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9
10use super::{completion::CompletionModel, embedding::EmbeddingModel, M2_BERT_80M_8K_RETRIEVAL};
11
// ================================================================
// Together AI Client
// ================================================================

/// Root URL of the Together AI API; [`Client::new`] points clients here.
const TOGETHER_AI_BASE_URL: &str = concat!("https://", "api.together.xyz");
16
17#[derive(Clone)]
18pub struct Client {
19    base_url: String,
20    http_client: reqwest::Client,
21}
22
23impl Client {
24    /// Create a new Together AI client with the given API key.
25    pub fn new(api_key: &str) -> Self {
26        Self::from_url(api_key, TOGETHER_AI_BASE_URL)
27    }
28
29    fn from_url(api_key: &str, base_url: &str) -> Self {
30        Self {
31            base_url: base_url.to_string(),
32            http_client: reqwest::Client::builder()
33                .default_headers({
34                    let mut headers = reqwest::header::HeaderMap::new();
35                    headers.insert(
36                        reqwest::header::CONTENT_TYPE,
37                        "application/json".parse().unwrap(),
38                    );
39                    headers.insert(
40                        "Authorization",
41                        format!("Bearer {}", api_key)
42                            .parse()
43                            .expect("Bearer token should parse"),
44                    );
45                    headers
46                })
47                .build()
48                .expect("Together AI reqwest client should build"),
49        }
50    }
51
52    /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable.
53    /// Panics if the environment variable is not set.
54    pub fn from_env() -> Self {
55        let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
56        Self::new(&api_key)
57    }
58
59    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
60        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
61
62        tracing::debug!("POST {}", url);
63        self.http_client.post(url)
64    }
65
66    /// Create an embedding model with the given name.
67    /// Note: default embedding dimension of 0 will be used if model is not known.
68    /// If this is the case, it's better to use function `embedding_model_with_ndims`
69    ///
70    /// # Example
71    /// ```
72    /// use rig::providers::together_ai::{Client, self};
73    ///
74    /// // Initialize the Together AI client
75    /// let together_ai = Client::new("your-together-ai-api-key");
76    ///
77    /// let embedding_model = together_ai.embedding_model(together_ai::embedding::EMBEDDING_V1);
78    /// ```
79    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
80        let ndims = match model {
81            M2_BERT_80M_8K_RETRIEVAL => 8192,
82            _ => 0,
83        };
84        EmbeddingModel::new(self.clone(), model, ndims)
85    }
86
87    /// Create an embedding model with the given name and the number of dimensions in the embedding
88    /// generated by the model.
89    ///
90    /// # Example
91    /// ```
92    /// use rig::providers::together_ai::{Client, self};
93    ///
94    /// // Initialize the Together AI client
95    /// let together_ai = Client::new("your-together-ai-api-key");
96    ///
97    /// let embedding_model = together_ai.embedding_model_with_ndims("model-unknown-to-rig", 1024);
98    /// ```
99    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
100        EmbeddingModel::new(self.clone(), model, ndims)
101    }
102
103    /// Create an embedding builder with the given embedding model.
104    ///
105    /// # Example
106    /// ```
107    /// use rig::providers::together_ai::{Client, self};
108    ///
109    /// // Initialize the Together AI client
110    /// let together_ai = Client::new("your-together-ai-api-key");
111    ///
112    /// let embeddings = together_ai.embeddings(together_ai::embedding::EMBEDDING_V1)
113    ///     .simple_document("doc0", "Hello, world!")
114    ///     .simple_document("doc1", "Goodbye, world!")
115    ///     .build()
116    ///     .await
117    ///     .expect("Failed to embed documents");
118    /// ```
119    pub fn embeddings<D: Embed>(
120        &self,
121        model: &str,
122    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
123        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
124    }
125
126    /// Create a completion model with the given name.
127    pub fn completion_model(&self, model: &str) -> CompletionModel {
128        CompletionModel::new(self.clone(), model)
129    }
130
131    /// Create an agent builder with the given completion model.
132    /// # Example
133    /// ```
134    /// use rig::providers::together_ai::{Client, self};
135    ///
136    /// // Initialize the Together AI client
137    /// let together_ai = Client::new("your-together-ai-api-key");
138    ///
139    /// let agent = together_ai.agent(together_ai::completion::MODEL_NAME)
140    ///    .preamble("You are comedian AI with a mission to make people laugh.")
141    ///    .temperature(0.0)
142    ///    .build();
143    /// ```
144    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
145        AgentBuilder::new(self.completion_model(model))
146    }
147
148    /// Create an extractor builder with the given completion model.
149    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
150        &self,
151        model: &str,
152    ) -> ExtractorBuilder<T, CompletionModel> {
153        ExtractorBuilder::new(self.completion_model(model))
154    }
155}
156
157pub mod together_ai_api_types {
158    use serde::Deserialize;
159
160    impl ApiErrorResponse {
161        pub fn message(&self) -> String {
162            format!("Code `{}`: {}", self.code, self.error)
163        }
164    }
165
166    #[derive(Debug, Deserialize)]
167    pub struct ApiErrorResponse {
168        pub error: String,
169        pub code: String,
170    }
171
172    #[derive(Debug, Deserialize)]
173    #[serde(untagged)]
174    pub enum ApiResponse<T> {
175        Ok(T),
176        Error(ApiErrorResponse),
177    }
178}