rig/providers/together/
completion.rs1use crate::{
7 completion::{self, CompletionError},
8 http_client::HttpClientExt,
9 json_utils,
10 providers::openai,
11};
12
13use super::client::{Client, together_ai_api_types::ApiResponse};
14use crate::completion::CompletionRequest;
15use crate::streaming::StreamingCompletionResponse;
16use bytes::Bytes;
17use serde::{Deserialize, Serialize};
18use serde_json::json;
19use tracing::{Instrument, info_span};
20
// Model identifiers accepted by the Together AI inference API.
// Pass one of these as the `model` argument when constructing a
// `CompletionModel`. Grouped loosely by vendor/model family.

// 01.AI / AllenAI / misc community models
pub const YI_34B_CHAT: &str = "zero-one-ai/Yi-34B-Chat";
pub const OLMO_7B_INSTRUCT: &str = "allenai/OLMo-7B-Instruct";
pub const CHRONOS_HERMES_13B: &str = "Austism/chronos-hermes-13b";
pub const ML318BR: &str = "carson/ml318br";
pub const DOLPHIN_2_5_MIXTRAL_8X7B: &str = "cognitivecomputations/dolphin-2.5-mixtral-8x7b";
pub const DBRX_INSTRUCT: &str = "databricks/dbrx-instruct";
// DeepSeek
pub const DEEPSEEK_LLM_67B_CHAT: &str = "deepseek-ai/deepseek-llm-67b-chat";
pub const DEEPSEEK_CODER_33B_INSTRUCT: &str = "deepseek-ai/deepseek-coder-33b-instruct";
pub const PLATYPUS2_70B_INSTRUCT: &str = "garage-bAInd/Platypus2-70B-instruct";
// Google Gemma
pub const GEMMA_2_9B_IT: &str = "google/gemma-2-9b-it";
pub const GEMMA_2B_IT: &str = "google/gemma-2b-it";
pub const GEMMA_2_27B_IT: &str = "google/gemma-2-27b-it";
pub const GEMMA_7B_IT: &str = "google/gemma-7b-it";
pub const LLAMA_3_70B_INSTRUCT_GRADIENT_1048K: &str =
    "gradientai/Llama-3-70B-Instruct-Gradient-1048k";
pub const MYTHOMAX_L2_13B: &str = "Gryphe/MythoMax-L2-13b";
pub const MYTHOMAX_L2_13B_LITE: &str = "Gryphe/MythoMax-L2-13b-Lite";
pub const LLAVA_NEXT_MISTRAL_7B: &str = "llava-hf/llava-v1.6-mistral-7b-hf";
pub const ZEPHYR_7B_BETA: &str = "HuggingFaceH4/zephyr-7b-beta";
// LMSYS Vicuna / Koala
pub const KOALA_7B: &str = "togethercomputer/Koala-7B";
pub const VICUNA_7B_V1_3: &str = "lmsys/vicuna-7b-v1.3";
pub const VICUNA_13B_V1_5_16K: &str = "lmsys/vicuna-13b-v1.5-16k";
pub const VICUNA_13B_V1_5: &str = "lmsys/vicuna-13b-v1.5";
pub const VICUNA_13B_V1_3: &str = "lmsys/vicuna-13b-v1.3";
pub const KOALA_13B: &str = "togethercomputer/Koala-13B";
pub const VICUNA_7B_V1_5: &str = "lmsys/vicuna-7b-v1.5";
// Meta Llama / Code Llama (including Together-hosted repackagings)
pub const CODE_LLAMA_34B_INSTRUCT: &str = "codellama/CodeLlama-34b-Instruct-hf";
pub const LLAMA_3_8B_CHAT_HF_INT4: &str = "togethercomputer/Llama-3-8b-chat-hf-int4";
pub const LLAMA_3_2_90B_VISION_INSTRUCT_TURBO: &str =
    "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo";
pub const LLAMA_3_2_11B_VISION_INSTRUCT_TURBO: &str =
    "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo";
pub const LLAMA_3_2_3B_INSTRUCT_TURBO: &str = "meta-llama/Llama-3.2-3B-Instruct-Turbo";
pub const LLAMA_3_8B_CHAT_HF_INT8: &str = "togethercomputer/Llama-3-8b-chat-hf-int8";
pub const LLAMA_3_1_70B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo";
pub const LLAMA_2_13B_CHAT: &str = "meta-llama/Llama-2-13b-chat-hf";
pub const LLAMA_3_70B_INSTRUCT_LITE: &str = "meta-llama/Meta-Llama-3-70B-Instruct-Lite";
pub const LLAMA_3_8B_CHAT_HF: &str = "meta-llama/Llama-3-8b-chat-hf";
pub const LLAMA_3_70B_CHAT_HF: &str = "meta-llama/Llama-3-70b-chat-hf";
pub const LLAMA_3_8B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3-8B-Instruct-Turbo";
pub const LLAMA_3_8B_INSTRUCT_LITE: &str = "meta-llama/Meta-Llama-3-8B-Instruct-Lite";
pub const LLAMA_3_1_405B_INSTRUCT_LITE_PRO: &str =
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Lite-Pro";
pub const LLAMA_2_7B_CHAT: &str = "meta-llama/Llama-2-7b-chat-hf";
pub const LLAMA_3_1_405B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo";
pub const LLAMA_VISION_FREE: &str = "meta-llama/Llama-Vision-Free";
pub const LLAMA_3_70B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3-70B-Instruct-Turbo";
pub const LLAMA_3_1_8B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo";
pub const CODE_LLAMA_7B_INSTRUCT_TOGETHER: &str = "togethercomputer/CodeLlama-7b-Instruct";
pub const CODE_LLAMA_34B_INSTRUCT_TOGETHER: &str = "togethercomputer/CodeLlama-34b-Instruct";
pub const CODE_LLAMA_13B_INSTRUCT: &str = "codellama/CodeLlama-13b-Instruct-hf";
pub const CODE_LLAMA_13B_INSTRUCT_TOGETHER: &str = "togethercomputer/CodeLlama-13b-Instruct";
pub const LLAMA_2_13B_CHAT_TOGETHER: &str = "togethercomputer/llama-2-13b-chat";
pub const LLAMA_2_7B_CHAT_TOGETHER: &str = "togethercomputer/llama-2-7b-chat";
pub const LLAMA_3_8B_INSTRUCT: &str = "meta-llama/Meta-Llama-3-8B-Instruct";
pub const LLAMA_3_70B_INSTRUCT: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
pub const CODE_LLAMA_70B_INSTRUCT: &str = "codellama/CodeLlama-70b-Instruct-hf";
pub const LLAMA_2_70B_CHAT_TOGETHER: &str = "togethercomputer/llama-2-70b-chat";
pub const LLAMA_3_1_8B_INSTRUCT_REFERENCE: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Reference";
pub const LLAMA_3_1_70B_INSTRUCT_REFERENCE: &str =
    "meta-llama/Meta-Llama-3.1-70B-Instruct-Reference";
// Microsoft / Mistral AI
pub const WIZARDLM_2_8X22B: &str = "microsoft/WizardLM-2-8x22B";
pub const MISTRAL_7B_INSTRUCT_V0_1: &str = "mistralai/Mistral-7B-Instruct-v0.1";
pub const MISTRAL_7B_INSTRUCT_V0_2: &str = "mistralai/Mistral-7B-Instruct-v0.2";
pub const MISTRAL_7B_INSTRUCT_V0_3: &str = "mistralai/Mistral-7B-Instruct-v0.3";
pub const MIXTRAL_8X7B_INSTRUCT_V0_1: &str = "mistralai/Mixtral-8x7B-Instruct-v0.1";
pub const MIXTRAL_8X22B_INSTRUCT_V0_1: &str = "mistralai/Mixtral-8x22B-Instruct-v0.1";
// Nous Research
pub const NOUS_HERMES_2_MIXTRAL_8X7B_DPO: &str = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO";
pub const NOUS_HERMES_LLAMA2_70B: &str = "NousResearch/Nous-Hermes-Llama2-70b";
pub const NOUS_HERMES_2_MIXTRAL_8X7B_SFT: &str = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT";
pub const NOUS_HERMES_LLAMA2_13B: &str = "NousResearch/Nous-Hermes-Llama2-13b";
pub const NOUS_HERMES_2_MISTRAL_DPO: &str = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO";
pub const NOUS_HERMES_LLAMA2_7B: &str = "NousResearch/Nous-Hermes-llama-2-7b";
pub const NOUS_CAPYBARA_V1_9: &str = "NousResearch/Nous-Capybara-7B-V1p9";
pub const HERMES_2_THETA_LLAMA_3_70B: &str = "NousResearch/Hermes-2-Theta-Llama-3-70B";
pub const OPENCHAT_3_5: &str = "openchat/openchat-3.5-1210";
pub const OPENORCA_MISTRAL_7B_8K: &str = "Open-Orca/Mistral-7B-OpenOrca";
// Qwen
pub const QWEN_2_72B_INSTRUCT: &str = "Qwen/Qwen2-72B-Instruct";
pub const QWEN2_5_72B_INSTRUCT_TURBO: &str = "Qwen/Qwen2.5-72B-Instruct-Turbo";
pub const QWEN2_5_7B_INSTRUCT_TURBO: &str = "Qwen/Qwen2.5-7B-Instruct-Turbo";
pub const QWEN1_5_110B_CHAT: &str = "Qwen/Qwen1.5-110B-Chat";
pub const QWEN1_5_72B_CHAT: &str = "Qwen/Qwen1.5-72B-Chat";
pub const QWEN_2_1_5B_INSTRUCT: &str = "Qwen/Qwen2-1.5B-Instruct";
pub const QWEN_2_7B_INSTRUCT: &str = "Qwen/Qwen2-7B-Instruct";
pub const QWEN1_5_14B_CHAT: &str = "Qwen/Qwen1.5-14B-Chat";
pub const QWEN1_5_1_8B_CHAT: &str = "Qwen/Qwen1.5-1.8B-Chat";
pub const QWEN1_5_32B_CHAT: &str = "Qwen/Qwen1.5-32B-Chat";
pub const QWEN1_5_7B_CHAT: &str = "Qwen/Qwen1.5-7B-Chat";
pub const QWEN1_5_0_5B_CHAT: &str = "Qwen/Qwen1.5-0.5B-Chat";
pub const QWEN1_5_4B_CHAT: &str = "Qwen/Qwen1.5-4B-Chat";
// Other vendors
pub const SNORKEL_MISTRAL_PAIRRM_DPO: &str = "snorkelai/Snorkel-Mistral-PairRM-DPO";
pub const SNOWFLAKE_ARCTIC_INSTRUCT: &str = "Snowflake/snowflake-arctic-instruct";
pub const ALPACA_7B: &str = "togethercomputer/alpaca-7b";
pub const OPENHERMES_2_MISTRAL_7B: &str = "teknium/OpenHermes-2-Mistral-7B";
pub const OPENHERMES_2_5_MISTRAL_7B: &str = "teknium/OpenHermes-2p5-Mistral-7B";
pub const GUANACO_65B: &str = "togethercomputer/guanaco-65b";
pub const GUANACO_13B: &str = "togethercomputer/guanaco-13b";
pub const GUANACO_33B: &str = "togethercomputer/guanaco-33b";
pub const GUANACO_7B: &str = "togethercomputer/guanaco-7b";
pub const REMM_SLERP_L2_13B: &str = "Undi95/ReMM-SLERP-L2-13B";
pub const TOPPY_M_7B: &str = "Undi95/Toppy-M-7B";
pub const SOLAR_10_7B_INSTRUCT_V1: &str = "upstage/SOLAR-10.7B-Instruct-v1.0";
pub const SOLAR_10_7B_INSTRUCT_V1_INT4: &str = "togethercomputer/SOLAR-10.7B-Instruct-v1.0-int4";
pub const WIZARDLM_13B_V1_2: &str = "WizardLM/WizardLM-13B-V1.2";
129
/// A Together AI completion model handle: a model identifier (see the
/// constants above) paired with the provider [`Client`] used to reach the API.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    // HTTP-capable provider client; crate-internal so requests always go
    // through this module's methods.
    pub(crate) client: Client<T>,
    /// Model identifier sent as the `model` field of each request,
    /// e.g. `"meta-llama/Meta-Llama-3-8B-Instruct"`.
    pub model: String,
}
139
140impl<T> CompletionModel<T> {
141 pub fn new(client: Client<T>, model: &str) -> Self {
142 Self {
143 client,
144 model: model.to_string(),
145 }
146 }
147
148 pub(crate) fn create_completion_request(
149 &self,
150 completion_request: completion::CompletionRequest,
151 ) -> Result<serde_json::Value, CompletionError> {
152 let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
153 Some(preamble) => vec![openai::Message::system(preamble)],
154 None => vec![],
155 };
156 if let Some(docs) = completion_request.normalized_documents() {
157 let docs: Vec<openai::Message> = docs.try_into()?;
158 full_history.extend(docs);
159 }
160 let chat_history: Vec<openai::Message> = completion_request
161 .chat_history
162 .into_iter()
163 .map(|message| message.try_into())
164 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
165 .into_iter()
166 .flatten()
167 .collect();
168
169 full_history.extend(chat_history);
170
171 let tool_choice = completion_request
172 .tool_choice
173 .map(ToolChoice::try_from)
174 .transpose()?;
175
176 let mut request = if completion_request.tools.is_empty() {
177 json!({
178 "model": self.model,
179 "messages": full_history,
180 "temperature": completion_request.temperature,
181 })
182 } else {
183 json!({
184 "model": self.model,
185 "messages": full_history,
186 "temperature": completion_request.temperature,
187 "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
188 "tool_choice": tool_choice,
189 })
190 };
191 request = if let Some(params) = completion_request.additional_params {
192 json_utils::merge(request, params)
193 } else {
194 request
195 };
196 Ok(request)
197 }
198}
199
200impl<T> completion::CompletionModel for CompletionModel<T>
201where
202 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
203{
204 type Response = openai::CompletionResponse;
205 type StreamingResponse = openai::StreamingCompletionResponse;
206
207 #[cfg_attr(feature = "worker", worker::send)]
208 async fn completion(
209 &self,
210 completion_request: completion::CompletionRequest,
211 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
212 let preamble = completion_request.preamble.clone();
213 let request = self.create_completion_request(completion_request)?;
214 let messages_as_json_string =
215 serde_json::to_string(request.get("messages").unwrap()).unwrap();
216
217 let span = if tracing::Span::current().is_disabled() {
218 info_span!(
219 target: "rig::completions",
220 "chat",
221 gen_ai.operation.name = "chat",
222 gen_ai.provider.name = "together",
223 gen_ai.request.model = self.model,
224 gen_ai.system_instructions = preamble,
225 gen_ai.response.id = tracing::field::Empty,
226 gen_ai.response.model = tracing::field::Empty,
227 gen_ai.usage.output_tokens = tracing::field::Empty,
228 gen_ai.usage.input_tokens = tracing::field::Empty,
229 gen_ai.input.messages = &messages_as_json_string,
230 gen_ai.output.messages = tracing::field::Empty,
231 )
232 } else {
233 tracing::Span::current()
234 };
235
236 tracing::debug!(target: "rig::completion", "TogetherAI completion request: {messages_as_json_string}");
237
238 let body = serde_json::to_vec(&request)?;
239
240 let req = self
241 .client
242 .post("/v1/chat/completions")?
243 .header("Content-Type", "application/json")
244 .body(body)
245 .map_err(|x| CompletionError::HttpError(x.into()))?;
246
247 async move {
248 let response = self.client.http_client.send::<_, Bytes>(req).await?;
249 let status = response.status();
250 let response_body = response.into_body().into_future().await?.to_vec();
251
252 if status.is_success() {
253 match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(&response_body)? {
254 ApiResponse::Ok(response) => {
255 let span = tracing::Span::current();
256 span.record(
257 "gen_ai.output.messages",
258 serde_json::to_string(&response.choices).unwrap(),
259 );
260 span.record("gen_ai.response.id", &response.id);
261 span.record("gen_ai.response.model_name", &response.model);
262 if let Some(ref usage) = response.usage {
263 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
264 span.record(
265 "gen_ai.usage.output_tokens",
266 usage.total_tokens - usage.prompt_tokens,
267 );
268 }
269 tracing::debug!(target: "rig::completion", "TogetherAI completion response: {}", serde_json::to_string_pretty(&response)?);
270 response.try_into()
271 }
272 ApiResponse::Error(err) => Err(CompletionError::ProviderError(err.error)),
273 }
274 } else {
275 Err(CompletionError::ProviderError(
276 String::from_utf8_lossy(&response_body).to_string()
277 ))
278 }
279 }
280 .instrument(span)
281 .await
282 }
283
284 #[cfg_attr(feature = "worker", worker::send)]
285 async fn stream(
286 &self,
287 request: CompletionRequest,
288 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
289 CompletionModel::stream(self, request).await
290 }
291}
292
293#[derive(Debug, Serialize, Deserialize)]
294#[serde(untagged, rename_all = "snake_case")]
295pub enum ToolChoice {
296 None,
297 Auto,
298 Function(Vec<ToolChoiceFunctionKind>),
299}
300
301impl TryFrom<crate::message::ToolChoice> for ToolChoice {
302 type Error = CompletionError;
303
304 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
305 let res = match value {
306 crate::message::ToolChoice::None => Self::None,
307 crate::message::ToolChoice::Auto => Self::Auto,
308 crate::message::ToolChoice::Specific { function_names } => {
309 let vec: Vec<ToolChoiceFunctionKind> = function_names
310 .into_iter()
311 .map(|name| ToolChoiceFunctionKind::Function { name })
312 .collect();
313
314 Self::Function(vec)
315 }
316 choice => {
317 return Err(CompletionError::ProviderError(format!(
318 "Unsupported tool choice type: {choice:?}"
319 )));
320 }
321 };
322
323 Ok(res)
324 }
325}
326
327#[derive(Debug, Serialize, Deserialize)]
328#[serde(tag = "type", content = "function")]
329pub enum ToolChoiceFunctionKind {
330 Function { name: String },
331}