rig/providers/together/client.rs
1use crate::{
2 agent::AgentBuilder,
3 embeddings::{self},
4 extractor::ExtractorBuilder,
5 Embed,
6};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9
10use super::{completion::CompletionModel, embedding::EmbeddingModel, M2_BERT_80M_8K_RETRIEVAL};
11
12// ================================================================
13// Together AI Client
14// ================================================================
15const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
16
17#[derive(Clone)]
18pub struct Client {
19 base_url: String,
20 http_client: reqwest::Client,
21}
22
23impl Client {
24 /// Create a new Together AI client with the given API key.
25 pub fn new(api_key: &str) -> Self {
26 Self::from_url(api_key, TOGETHER_AI_BASE_URL)
27 }
28
29 fn from_url(api_key: &str, base_url: &str) -> Self {
30 Self {
31 base_url: base_url.to_string(),
32 http_client: reqwest::Client::builder()
33 .default_headers({
34 let mut headers = reqwest::header::HeaderMap::new();
35 headers.insert(
36 reqwest::header::CONTENT_TYPE,
37 "application/json".parse().unwrap(),
38 );
39 headers.insert(
40 "Authorization",
41 format!("Bearer {}", api_key)
42 .parse()
43 .expect("Bearer token should parse"),
44 );
45 headers
46 })
47 .build()
48 .expect("Together AI reqwest client should build"),
49 }
50 }
51
52 /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable.
53 /// Panics if the environment variable is not set.
54 pub fn from_env() -> Self {
55 let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
56 Self::new(&api_key)
57 }
58
59 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
60 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
61
62 tracing::debug!("POST {}", url);
63 self.http_client.post(url)
64 }
65
66 /// Create an embedding model with the given name.
67 /// Note: default embedding dimension of 0 will be used if model is not known.
68 /// If this is the case, it's better to use function `embedding_model_with_ndims`
69 ///
70 /// # Example
71 /// ```
72 /// use rig::providers::together_ai::{Client, self};
73 ///
74 /// // Initialize the Together AI client
75 /// let together_ai = Client::new("your-together-ai-api-key");
76 ///
77 /// let embedding_model = together_ai.embedding_model(together_ai::embedding::EMBEDDING_V1);
78 /// ```
79 pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
80 let ndims = match model {
81 M2_BERT_80M_8K_RETRIEVAL => 8192,
82 _ => 0,
83 };
84 EmbeddingModel::new(self.clone(), model, ndims)
85 }
86
87 /// Create an embedding model with the given name and the number of dimensions in the embedding
88 /// generated by the model.
89 ///
90 /// # Example
91 /// ```
92 /// use rig::providers::together_ai::{Client, self};
93 ///
94 /// // Initialize the Together AI client
95 /// let together_ai = Client::new("your-together-ai-api-key");
96 ///
97 /// let embedding_model = together_ai.embedding_model_with_ndims("model-unknown-to-rig", 1024);
98 /// ```
99 pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
100 EmbeddingModel::new(self.clone(), model, ndims)
101 }
102
103 /// Create an embedding builder with the given embedding model.
104 ///
105 /// # Example
106 /// ```
107 /// use rig::providers::together_ai::{Client, self};
108 ///
109 /// // Initialize the Together AI client
110 /// let together_ai = Client::new("your-together-ai-api-key");
111 ///
112 /// let embeddings = together_ai.embeddings(together_ai::embedding::EMBEDDING_V1)
113 /// .simple_document("doc0", "Hello, world!")
114 /// .simple_document("doc1", "Goodbye, world!")
115 /// .build()
116 /// .await
117 /// .expect("Failed to embed documents");
118 /// ```
119 pub fn embeddings<D: Embed>(
120 &self,
121 model: &str,
122 ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
123 embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
124 }
125
126 /// Create a completion model with the given name.
127 pub fn completion_model(&self, model: &str) -> CompletionModel {
128 CompletionModel::new(self.clone(), model)
129 }
130
131 /// Create an agent builder with the given completion model.
132 /// # Example
133 /// ```
134 /// use rig::providers::together_ai::{Client, self};
135 ///
136 /// // Initialize the Together AI client
137 /// let together_ai = Client::new("your-together-ai-api-key");
138 ///
139 /// let agent = together_ai.agent(together_ai::completion::MODEL_NAME)
140 /// .preamble("You are comedian AI with a mission to make people laugh.")
141 /// .temperature(0.0)
142 /// .build();
143 /// ```
144 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
145 AgentBuilder::new(self.completion_model(model))
146 }
147
148 /// Create an extractor builder with the given completion model.
149 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
150 &self,
151 model: &str,
152 ) -> ExtractorBuilder<T, CompletionModel> {
153 ExtractorBuilder::new(self.completion_model(model))
154 }
155}
156
157pub mod together_ai_api_types {
158 use serde::Deserialize;
159
160 impl ApiErrorResponse {
161 pub fn message(&self) -> String {
162 format!("Code `{}`: {}", self.code, self.error)
163 }
164 }
165
166 #[derive(Debug, Deserialize)]
167 pub struct ApiErrorResponse {
168 pub error: String,
169 pub code: String,
170 }
171
172 #[derive(Debug, Deserialize)]
173 #[serde(untagged)]
174 pub enum ApiResponse<T> {
175 Ok(T),
176 Error(ApiErrorResponse),
177 }
178}