rig/providers/together/
completion.rs

1// ================================================================
2//! Together AI Completion Integration
3//! From [Together AI Reference](https://docs.together.ai/docs/chat-overview)
4// ================================================================
5
6use crate::{
7    completion::{self, CompletionError},
8    json_utils,
9    providers::openai,
10};
11
12use serde_json::json;
13
14use super::client::{together_ai_api_types::ApiResponse, Client};
15
16// ================================================================
17// Together Completion Models
18// ================================================================
19
20pub const YI_34B_CHAT: &str = "zero-one-ai/Yi-34B-Chat";
21pub const OLMO_7B_INSTRUCT: &str = "allenai/OLMo-7B-Instruct";
22pub const CHRONOS_HERMES_13B: &str = "Austism/chronos-hermes-13b";
23pub const ML318BR: &str = "carson/ml318br";
24pub const DOLPHIN_2_5_MIXTRAL_8X7B: &str = "cognitivecomputations/dolphin-2.5-mixtral-8x7b";
25pub const DBRX_INSTRUCT: &str = "databricks/dbrx-instruct";
26pub const DEEPSEEK_LLM_67B_CHAT: &str = "deepseek-ai/deepseek-llm-67b-chat";
27pub const DEEPSEEK_CODER_33B_INSTRUCT: &str = "deepseek-ai/deepseek-coder-33b-instruct";
28pub const PLATYPUS2_70B_INSTRUCT: &str = "garage-bAInd/Platypus2-70B-instruct";
29pub const GEMMA_2_9B_IT: &str = "google/gemma-2-9b-it";
30pub const GEMMA_2B_IT: &str = "google/gemma-2b-it";
31pub const GEMMA_2_27B_IT: &str = "google/gemma-2-27b-it";
32pub const GEMMA_7B_IT: &str = "google/gemma-7b-it";
33pub const LLAMA_3_70B_INSTRUCT_GRADIENT_1048K: &str =
34    "gradientai/Llama-3-70B-Instruct-Gradient-1048k";
35pub const MYTHOMAX_L2_13B: &str = "Gryphe/MythoMax-L2-13b";
36pub const MYTHOMAX_L2_13B_LITE: &str = "Gryphe/MythoMax-L2-13b-Lite";
37pub const LLAVA_NEXT_MISTRAL_7B: &str = "llava-hf/llava-v1.6-mistral-7b-hf";
38pub const ZEPHYR_7B_BETA: &str = "HuggingFaceH4/zephyr-7b-beta";
39pub const KOALA_7B: &str = "togethercomputer/Koala-7B";
40pub const VICUNA_7B_V1_3: &str = "lmsys/vicuna-7b-v1.3";
41pub const VICUNA_13B_V1_5_16K: &str = "lmsys/vicuna-13b-v1.5-16k";
42pub const VICUNA_13B_V1_5: &str = "lmsys/vicuna-13b-v1.5";
43pub const VICUNA_13B_V1_3: &str = "lmsys/vicuna-13b-v1.3";
44pub const KOALA_13B: &str = "togethercomputer/Koala-13B";
45pub const VICUNA_7B_V1_5: &str = "lmsys/vicuna-7b-v1.5";
46pub const CODE_LLAMA_34B_INSTRUCT: &str = "codellama/CodeLlama-34b-Instruct-hf";
47pub const LLAMA_3_8B_CHAT_HF_INT4: &str = "togethercomputer/Llama-3-8b-chat-hf-int4";
48pub const LLAMA_3_2_90B_VISION_INSTRUCT_TURBO: &str =
49    "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo";
50pub const LLAMA_3_2_11B_VISION_INSTRUCT_TURBO: &str =
51    "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo";
52pub const LLAMA_3_2_3B_INSTRUCT_TURBO: &str = "meta-llama/Llama-3.2-3B-Instruct-Turbo";
53pub const LLAMA_3_8B_CHAT_HF_INT8: &str = "togethercomputer/Llama-3-8b-chat-hf-int8";
54pub const LLAMA_3_1_70B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo";
55pub const LLAMA_2_13B_CHAT: &str = "meta-llama/Llama-2-13b-chat-hf";
56pub const LLAMA_3_70B_INSTRUCT_LITE: &str = "meta-llama/Meta-Llama-3-70B-Instruct-Lite";
57pub const LLAMA_3_8B_CHAT_HF: &str = "meta-llama/Llama-3-8b-chat-hf";
58pub const LLAMA_3_70B_CHAT_HF: &str = "meta-llama/Llama-3-70b-chat-hf";
59pub const LLAMA_3_8B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3-8B-Instruct-Turbo";
60pub const LLAMA_3_8B_INSTRUCT_LITE: &str = "meta-llama/Meta-Llama-3-8B-Instruct-Lite";
61pub const LLAMA_3_1_405B_INSTRUCT_LITE_PRO: &str =
62    "meta-llama/Meta-Llama-3.1-405B-Instruct-Lite-Pro";
63pub const LLAMA_2_7B_CHAT: &str = "meta-llama/Llama-2-7b-chat-hf";
64pub const LLAMA_3_1_405B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo";
65pub const LLAMA_VISION_FREE: &str = "meta-llama/Llama-Vision-Free";
66pub const LLAMA_3_70B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3-70B-Instruct-Turbo";
67pub const LLAMA_3_1_8B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo";
68pub const CODE_LLAMA_7B_INSTRUCT_TOGETHER: &str = "togethercomputer/CodeLlama-7b-Instruct";
69pub const CODE_LLAMA_34B_INSTRUCT_TOGETHER: &str = "togethercomputer/CodeLlama-34b-Instruct";
70pub const CODE_LLAMA_13B_INSTRUCT: &str = "codellama/CodeLlama-13b-Instruct-hf";
71pub const CODE_LLAMA_13B_INSTRUCT_TOGETHER: &str = "togethercomputer/CodeLlama-13b-Instruct";
72pub const LLAMA_2_13B_CHAT_TOGETHER: &str = "togethercomputer/llama-2-13b-chat";
73pub const LLAMA_2_7B_CHAT_TOGETHER: &str = "togethercomputer/llama-2-7b-chat";
74pub const LLAMA_3_8B_INSTRUCT: &str = "meta-llama/Meta-Llama-3-8B-Instruct";
75pub const LLAMA_3_70B_INSTRUCT: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
76pub const CODE_LLAMA_70B_INSTRUCT: &str = "codellama/CodeLlama-70b-Instruct-hf";
77pub const LLAMA_2_70B_CHAT_TOGETHER: &str = "togethercomputer/llama-2-70b-chat";
78pub const LLAMA_3_1_8B_INSTRUCT_REFERENCE: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Reference";
79pub const LLAMA_3_1_70B_INSTRUCT_REFERENCE: &str =
80    "meta-llama/Meta-Llama-3.1-70B-Instruct-Reference";
81pub const WIZARDLM_2_8X22B: &str = "microsoft/WizardLM-2-8x22B";
82pub const MISTRAL_7B_INSTRUCT_V0_1: &str = "mistralai/Mistral-7B-Instruct-v0.1";
83pub const MISTRAL_7B_INSTRUCT_V0_2: &str = "mistralai/Mistral-7B-Instruct-v0.2";
84pub const MISTRAL_7B_INSTRUCT_V0_3: &str = "mistralai/Mistral-7B-Instruct-v0.3";
85pub const MIXTRAL_8X7B_INSTRUCT_V0_1: &str = "mistralai/Mixtral-8x7B-Instruct-v0.1";
86pub const MIXTRAL_8X22B_INSTRUCT_V0_1: &str = "mistralai/Mixtral-8x22B-Instruct-v0.1";
87pub const NOUS_HERMES_2_MIXTRAL_8X7B_DPO: &str = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO";
88pub const NOUS_HERMES_LLAMA2_70B: &str = "NousResearch/Nous-Hermes-Llama2-70b";
89pub const NOUS_HERMES_2_MIXTRAL_8X7B_SFT: &str = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT";
90pub const NOUS_HERMES_LLAMA2_13B: &str = "NousResearch/Nous-Hermes-Llama2-13b";
91pub const NOUS_HERMES_2_MISTRAL_DPO: &str = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO";
92pub const NOUS_HERMES_LLAMA2_7B: &str = "NousResearch/Nous-Hermes-llama-2-7b";
93pub const NOUS_CAPYBARA_V1_9: &str = "NousResearch/Nous-Capybara-7B-V1p9";
94pub const HERMES_2_THETA_LLAMA_3_70B: &str = "NousResearch/Hermes-2-Theta-Llama-3-70B";
95pub const OPENCHAT_3_5: &str = "openchat/openchat-3.5-1210";
96pub const OPENORCA_MISTRAL_7B_8K: &str = "Open-Orca/Mistral-7B-OpenOrca";
97pub const QWEN_2_72B_INSTRUCT: &str = "Qwen/Qwen2-72B-Instruct";
98pub const QWEN2_5_72B_INSTRUCT_TURBO: &str = "Qwen/Qwen2.5-72B-Instruct-Turbo";
99pub const QWEN2_5_7B_INSTRUCT_TURBO: &str = "Qwen/Qwen2.5-7B-Instruct-Turbo";
100pub const QWEN1_5_110B_CHAT: &str = "Qwen/Qwen1.5-110B-Chat";
101pub const QWEN1_5_72B_CHAT: &str = "Qwen/Qwen1.5-72B-Chat";
102pub const QWEN_2_1_5B_INSTRUCT: &str = "Qwen/Qwen2-1.5B-Instruct";
103pub const QWEN_2_7B_INSTRUCT: &str = "Qwen/Qwen2-7B-Instruct";
104pub const QWEN1_5_14B_CHAT: &str = "Qwen/Qwen1.5-14B-Chat";
105pub const QWEN1_5_1_8B_CHAT: &str = "Qwen/Qwen1.5-1.8B-Chat";
106pub const QWEN1_5_32B_CHAT: &str = "Qwen/Qwen1.5-32B-Chat";
107pub const QWEN1_5_7B_CHAT: &str = "Qwen/Qwen1.5-7B-Chat";
108pub const QWEN1_5_0_5B_CHAT: &str = "Qwen/Qwen1.5-0.5B-Chat";
109pub const QWEN1_5_4B_CHAT: &str = "Qwen/Qwen1.5-4B-Chat";
110pub const SNORKEL_MISTRAL_PAIRRM_DPO: &str = "snorkelai/Snorkel-Mistral-PairRM-DPO";
111pub const SNOWFLAKE_ARCTIC_INSTRUCT: &str = "Snowflake/snowflake-arctic-instruct";
112pub const ALPACA_7B: &str = "togethercomputer/alpaca-7b";
113pub const OPENHERMES_2_MISTRAL_7B: &str = "teknium/OpenHermes-2-Mistral-7B";
114pub const OPENHERMES_2_5_MISTRAL_7B: &str = "teknium/OpenHermes-2p5-Mistral-7B";
115pub const GUANACO_65B: &str = "togethercomputer/guanaco-65b";
116pub const GUANACO_13B: &str = "togethercomputer/guanaco-13b";
117pub const GUANACO_33B: &str = "togethercomputer/guanaco-33b";
118pub const GUANACO_7B: &str = "togethercomputer/guanaco-7b";
119pub const REMM_SLERP_L2_13B: &str = "Undi95/ReMM-SLERP-L2-13B";
120pub const TOPPY_M_7B: &str = "Undi95/Toppy-M-7B";
121pub const SOLAR_10_7B_INSTRUCT_V1: &str = "upstage/SOLAR-10.7B-Instruct-v1.0";
122pub const SOLAR_10_7B_INSTRUCT_V1_INT4: &str = "togethercomputer/SOLAR-10.7B-Instruct-v1.0-int4";
123pub const WIZARDLM_13B_V1_2: &str = "WizardLM/WizardLM-13B-V1.2";
124
125// =================================================================
126// Rig Implementation Types
127// =================================================================
128
/// A Together AI chat-completion model: pairs the HTTP client used to
/// reach the Together API with the model identifier sent on each request.
#[derive(Clone)]
pub struct CompletionModel {
    // Client used to issue requests against the Together AI API.
    client: Client,
    /// Model identifier (e.g. one of the constants in this module) sent as
    /// the `model` field of each completion request.
    pub model: String,
}
134
135impl CompletionModel {
136    pub fn new(client: Client, model: &str) -> Self {
137        Self {
138            client,
139            model: model.to_string(),
140        }
141    }
142}
143
144impl completion::CompletionModel for CompletionModel {
145    type Response = openai::CompletionResponse;
146
147    #[cfg_attr(feature = "worker", worker::send)]
148    async fn completion(
149        &self,
150        completion_request: completion::CompletionRequest,
151    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
152        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
153            Some(preamble) => vec![openai::Message::system(preamble)],
154            None => vec![],
155        };
156
157        // Convert prompt to user message
158        let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
159
160        // Convert existing chat history
161        let chat_history: Vec<openai::Message> = completion_request
162            .chat_history
163            .into_iter()
164            .map(|message| message.try_into())
165            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
166            .into_iter()
167            .flatten()
168            .collect();
169
170        // Combine all messages into a single history
171        full_history.extend(chat_history);
172        full_history.extend(prompt);
173
174        let mut request = if completion_request.tools.is_empty() {
175            json!({
176                "model": self.model,
177                "messages": full_history,
178                "temperature": completion_request.temperature,
179            })
180        } else {
181            json!({
182                "model": self.model,
183                "messages": full_history,
184                "temperature": completion_request.temperature,
185                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
186                "tool_choice": "auto",
187            })
188        };
189
190        request = if let Some(params) = completion_request.additional_params {
191            json_utils::merge(request, params)
192        } else {
193            request
194        };
195
196        let response = self
197            .client
198            .post("/v1/chat/completions")
199            .json(&request)
200            .send()
201            .await?;
202
203        if response.status().is_success() {
204            let t = response.text().await?;
205            tracing::debug!(target: "rig", "Together completion error: {}", t);
206
207            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
208                ApiResponse::Ok(response) => {
209                    tracing::info!(target: "rig",
210                        "Together completion token usage: {:?}",
211                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
212                    );
213                    response.try_into()
214                }
215                ApiResponse::Error(err) => Err(CompletionError::ProviderError(err.error)),
216            }
217        } else {
218            Err(CompletionError::ProviderError(response.text().await?))
219        }
220    }
221}