// rig/providers/together/completion.rs

// ================================================================
//! Together AI Completion Integration
//! From [Together AI Reference](https://docs.together.ai/docs/chat-overview)
// ================================================================

use crate::{
    completion::{self, CompletionError},
    http_client::HttpClientExt,
    providers::openai,
};

use super::client::{Client, together_ai_api_types::ApiResponse};
use crate::completion::CompletionRequest;
use crate::streaming::StreamingCompletionResponse;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use tracing::{Instrument, Level, enabled, info_span};

// ================================================================
// Together Completion Models
// ================================================================

pub const YI_34B_CHAT: &str = "zero-one-ai/Yi-34B-Chat";
pub const OLMO_7B_INSTRUCT: &str = "allenai/OLMo-7B-Instruct";
pub const CHRONOS_HERMES_13B: &str = "Austism/chronos-hermes-13b";
pub const ML318BR: &str = "carson/ml318br";
pub const DOLPHIN_2_5_MIXTRAL_8X7B: &str = "cognitivecomputations/dolphin-2.5-mixtral-8x7b";
pub const DBRX_INSTRUCT: &str = "databricks/dbrx-instruct";
pub const DEEPSEEK_LLM_67B_CHAT: &str = "deepseek-ai/deepseek-llm-67b-chat";
pub const DEEPSEEK_CODER_33B_INSTRUCT: &str = "deepseek-ai/deepseek-coder-33b-instruct";
pub const PLATYPUS2_70B_INSTRUCT: &str = "garage-bAInd/Platypus2-70B-instruct";
pub const GEMMA_2_9B_IT: &str = "google/gemma-2-9b-it";
pub const GEMMA_2B_IT: &str = "google/gemma-2b-it";
pub const GEMMA_2_27B_IT: &str = "google/gemma-2-27b-it";
pub const GEMMA_7B_IT: &str = "google/gemma-7b-it";
pub const LLAMA_3_70B_INSTRUCT_GRADIENT_1048K: &str =
    "gradientai/Llama-3-70B-Instruct-Gradient-1048k";
pub const MYTHOMAX_L2_13B: &str = "Gryphe/MythoMax-L2-13b";
pub const MYTHOMAX_L2_13B_LITE: &str = "Gryphe/MythoMax-L2-13b-Lite";
pub const LLAVA_NEXT_MISTRAL_7B: &str = "llava-hf/llava-v1.6-mistral-7b-hf";
pub const ZEPHYR_7B_BETA: &str = "HuggingFaceH4/zephyr-7b-beta";
pub const KOALA_7B: &str = "togethercomputer/Koala-7B";
pub const VICUNA_7B_V1_3: &str = "lmsys/vicuna-7b-v1.3";
pub const VICUNA_13B_V1_5_16K: &str = "lmsys/vicuna-13b-v1.5-16k";
pub const VICUNA_13B_V1_5: &str = "lmsys/vicuna-13b-v1.5";
pub const VICUNA_13B_V1_3: &str = "lmsys/vicuna-13b-v1.3";
pub const KOALA_13B: &str = "togethercomputer/Koala-13B";
pub const VICUNA_7B_V1_5: &str = "lmsys/vicuna-7b-v1.5";
pub const CODE_LLAMA_34B_INSTRUCT: &str = "codellama/CodeLlama-34b-Instruct-hf";
pub const LLAMA_3_8B_CHAT_HF_INT4: &str = "togethercomputer/Llama-3-8b-chat-hf-int4";
pub const LLAMA_3_2_90B_VISION_INSTRUCT_TURBO: &str =
    "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo";
pub const LLAMA_3_2_11B_VISION_INSTRUCT_TURBO: &str =
    "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo";
pub const LLAMA_3_2_3B_INSTRUCT_TURBO: &str = "meta-llama/Llama-3.2-3B-Instruct-Turbo";
pub const LLAMA_3_8B_CHAT_HF_INT8: &str = "togethercomputer/Llama-3-8b-chat-hf-int8";
pub const LLAMA_3_1_70B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo";
pub const LLAMA_2_13B_CHAT: &str = "meta-llama/Llama-2-13b-chat-hf";
pub const LLAMA_3_70B_INSTRUCT_LITE: &str = "meta-llama/Meta-Llama-3-70B-Instruct-Lite";
pub const LLAMA_3_8B_CHAT_HF: &str = "meta-llama/Llama-3-8b-chat-hf";
pub const LLAMA_3_70B_CHAT_HF: &str = "meta-llama/Llama-3-70b-chat-hf";
pub const LLAMA_3_8B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3-8B-Instruct-Turbo";
pub const LLAMA_3_8B_INSTRUCT_LITE: &str = "meta-llama/Meta-Llama-3-8B-Instruct-Lite";
pub const LLAMA_3_1_405B_INSTRUCT_LITE_PRO: &str =
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Lite-Pro";
pub const LLAMA_2_7B_CHAT: &str = "meta-llama/Llama-2-7b-chat-hf";
pub const LLAMA_3_1_405B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo";
pub const LLAMA_VISION_FREE: &str = "meta-llama/Llama-Vision-Free";
pub const LLAMA_3_70B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3-70B-Instruct-Turbo";
pub const LLAMA_3_1_8B_INSTRUCT_TURBO: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo";
pub const CODE_LLAMA_7B_INSTRUCT_TOGETHER: &str = "togethercomputer/CodeLlama-7b-Instruct";
pub const CODE_LLAMA_34B_INSTRUCT_TOGETHER: &str = "togethercomputer/CodeLlama-34b-Instruct";
pub const CODE_LLAMA_13B_INSTRUCT: &str = "codellama/CodeLlama-13b-Instruct-hf";
pub const CODE_LLAMA_13B_INSTRUCT_TOGETHER: &str = "togethercomputer/CodeLlama-13b-Instruct";
pub const LLAMA_2_13B_CHAT_TOGETHER: &str = "togethercomputer/llama-2-13b-chat";
pub const LLAMA_2_7B_CHAT_TOGETHER: &str = "togethercomputer/llama-2-7b-chat";
pub const LLAMA_3_8B_INSTRUCT: &str = "meta-llama/Meta-Llama-3-8B-Instruct";
pub const LLAMA_3_70B_INSTRUCT: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
pub const CODE_LLAMA_70B_INSTRUCT: &str = "codellama/CodeLlama-70b-Instruct-hf";
pub const LLAMA_2_70B_CHAT_TOGETHER: &str = "togethercomputer/llama-2-70b-chat";
pub const LLAMA_3_1_8B_INSTRUCT_REFERENCE: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Reference";
pub const LLAMA_3_1_70B_INSTRUCT_REFERENCE: &str =
    "meta-llama/Meta-Llama-3.1-70B-Instruct-Reference";
pub const WIZARDLM_2_8X22B: &str = "microsoft/WizardLM-2-8x22B";
pub const MISTRAL_7B_INSTRUCT_V0_1: &str = "mistralai/Mistral-7B-Instruct-v0.1";
pub const MISTRAL_7B_INSTRUCT_V0_2: &str = "mistralai/Mistral-7B-Instruct-v0.2";
pub const MISTRAL_7B_INSTRUCT_V0_3: &str = "mistralai/Mistral-7B-Instruct-v0.3";
pub const MIXTRAL_8X7B_INSTRUCT_V0_1: &str = "mistralai/Mixtral-8x7B-Instruct-v0.1";
pub const MIXTRAL_8X22B_INSTRUCT_V0_1: &str = "mistralai/Mixtral-8x22B-Instruct-v0.1";
pub const NOUS_HERMES_2_MIXTRAL_8X7B_DPO: &str = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO";
pub const NOUS_HERMES_LLAMA2_70B: &str = "NousResearch/Nous-Hermes-Llama2-70b";
pub const NOUS_HERMES_2_MIXTRAL_8X7B_SFT: &str = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT";
pub const NOUS_HERMES_LLAMA2_13B: &str = "NousResearch/Nous-Hermes-Llama2-13b";
pub const NOUS_HERMES_2_MISTRAL_DPO: &str = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO";
pub const NOUS_HERMES_LLAMA2_7B: &str = "NousResearch/Nous-Hermes-llama-2-7b";
pub const NOUS_CAPYBARA_V1_9: &str = "NousResearch/Nous-Capybara-7B-V1p9";
pub const HERMES_2_THETA_LLAMA_3_70B: &str = "NousResearch/Hermes-2-Theta-Llama-3-70B";
pub const OPENCHAT_3_5: &str = "openchat/openchat-3.5-1210";
pub const OPENORCA_MISTRAL_7B_8K: &str = "Open-Orca/Mistral-7B-OpenOrca";
pub const QWEN_2_72B_INSTRUCT: &str = "Qwen/Qwen2-72B-Instruct";
pub const QWEN2_5_72B_INSTRUCT_TURBO: &str = "Qwen/Qwen2.5-72B-Instruct-Turbo";
pub const QWEN2_5_7B_INSTRUCT_TURBO: &str = "Qwen/Qwen2.5-7B-Instruct-Turbo";
pub const QWEN1_5_110B_CHAT: &str = "Qwen/Qwen1.5-110B-Chat";
pub const QWEN1_5_72B_CHAT: &str = "Qwen/Qwen1.5-72B-Chat";
pub const QWEN_2_1_5B_INSTRUCT: &str = "Qwen/Qwen2-1.5B-Instruct";
pub const QWEN_2_7B_INSTRUCT: &str = "Qwen/Qwen2-7B-Instruct";
pub const QWEN1_5_14B_CHAT: &str = "Qwen/Qwen1.5-14B-Chat";
pub const QWEN1_5_1_8B_CHAT: &str = "Qwen/Qwen1.5-1.8B-Chat";
pub const QWEN1_5_32B_CHAT: &str = "Qwen/Qwen1.5-32B-Chat";
pub const QWEN1_5_7B_CHAT: &str = "Qwen/Qwen1.5-7B-Chat";
pub const QWEN1_5_0_5B_CHAT: &str = "Qwen/Qwen1.5-0.5B-Chat";
pub const QWEN1_5_4B_CHAT: &str = "Qwen/Qwen1.5-4B-Chat";
pub const SNORKEL_MISTRAL_PAIRRM_DPO: &str = "snorkelai/Snorkel-Mistral-PairRM-DPO";
pub const SNOWFLAKE_ARCTIC_INSTRUCT: &str = "Snowflake/snowflake-arctic-instruct";
pub const ALPACA_7B: &str = "togethercomputer/alpaca-7b";
pub const OPENHERMES_2_MISTRAL_7B: &str = "teknium/OpenHermes-2-Mistral-7B";
pub const OPENHERMES_2_5_MISTRAL_7B: &str = "teknium/OpenHermes-2p5-Mistral-7B";
pub const GUANACO_65B: &str = "togethercomputer/guanaco-65b";
pub const GUANACO_13B: &str = "togethercomputer/guanaco-13b";
pub const GUANACO_33B: &str = "togethercomputer/guanaco-33b";
pub const GUANACO_7B: &str = "togethercomputer/guanaco-7b";
pub const REMM_SLERP_L2_13B: &str = "Undi95/ReMM-SLERP-L2-13B";
pub const TOPPY_M_7B: &str = "Undi95/Toppy-M-7B";
pub const SOLAR_10_7B_INSTRUCT_V1: &str = "upstage/SOLAR-10.7B-Instruct-v1.0";
pub const SOLAR_10_7B_INSTRUCT_V1_INT4: &str = "togethercomputer/SOLAR-10.7B-Instruct-v1.0-int4";
pub const WIZARDLM_13B_V1_2: &str = "WizardLM/WizardLM-13B-V1.2";
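
// Any of the model ids above can stand in for the `model` argument when
// constructing the completion model defined below, e.g.
// `CompletionModel::new(client, MIXTRAL_8X7B_INSTRUCT_V0_1)`.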

// =================================================================
// Rig Implementation Types
// =================================================================

/// Request body for Together's OpenAI-compatible `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct TogetherAICompletionRequest {
    model: String,
    pub messages: Vec<openai::Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<crate::providers::openai::completion::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
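
// Serialized shape sketch (illustrative values; follows the serde attributes
// above, with `additional_params` flattened into the top level and empty or
// unset fields omitted):
//
// {
//   "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
//   "messages": [{"role": "user", "content": "Hello"}],
//   "temperature": 0.7
// }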

impl TryFrom<(&str, CompletionRequest)> for TogetherAICompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.output_schema.is_some() {
            tracing::warn!("Structured outputs are currently not supported for TogetherAI");
        }
        // Prefer a per-request model override, falling back to the default model id.
        let model = req.model.clone().unwrap_or_else(|| model.to_string());

        // Assemble the full history: the preamble (as a system message) first,
        // then any attached documents, then the chat history itself.
        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };
        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        let chat_history: Vec<openai::Message> = req
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        if full_history.is_empty() {
            return Err(CompletionError::RequestError(
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Together request has no provider-compatible messages after conversion",
                )
                .into(),
            ));
        }

        let tool_choice = req
            .tool_choice
            .clone()
            .map(ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model,
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(crate::providers::openai::completion::ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
        })
    }
}

/// A Together AI completion model, generic over the HTTP client used to reach
/// the API (defaults to [`reqwest::Client`]).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    /// Creates a completion model from a Together [`Client`] and a model id
    /// (any of the constants above, or another id that Together serves).
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        // Reuse the caller's span if one is active; otherwise open a new
        // OpenTelemetry-style GenAI span for this request.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "together",
                gen_ai.request.model = self.model.to_string(),
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request =
            TogetherAICompletionRequest::try_from((self.model.as_str(), completion_request))?;

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "TogetherAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(|x| CompletionError::HttpError(x.into()))?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", &response.id);
                        span.record("gen_ai.response.model", &response.model);
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            // Output tokens are derived as total minus prompt tokens.
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                        }
                        if enabled!(Level::TRACE) {
                            tracing::trace!(
                                target: "rig::completions",
                                "TogetherAI completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Error(err) => Err(CompletionError::ProviderError(err.error)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        // Delegates to this provider's dedicated streaming implementation.
        CompletionModel::stream(self, request).await
    }
}
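
// Invocation sketch (assumes an async context, a constructed Together
// `client`, and a prepared `completion::CompletionRequest` named `request`):
//
//   use rig::completion::CompletionModel as _;
//
//   let model = CompletionModel::new(client, MIXTRAL_8X7B_INSTRUCT_V0_1);
//   let response = model.completion(request).await?;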

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged, rename_all = "snake_case")]
pub enum ToolChoice {
    None,
    Auto,
    Function(Vec<ToolChoiceFunctionKind>),
}

impl TryFrom<crate::message::ToolChoice> for ToolChoice {
    type Error = CompletionError;

    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
        let res = match value {
            crate::message::ToolChoice::None => Self::None,
            crate::message::ToolChoice::Auto => Self::Auto,
            crate::message::ToolChoice::Specific { function_names } => {
                let vec: Vec<ToolChoiceFunctionKind> = function_names
                    .into_iter()
                    .map(|name| ToolChoiceFunctionKind::Function { name })
                    .collect();

                Self::Function(vec)
            }
            choice => {
                return Err(CompletionError::ProviderError(format!(
                    "Unsupported tool choice type: {choice:?}"
                )));
            }
        };

        Ok(res)
    }
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "function")]
pub enum ToolChoiceFunctionKind {
    Function { name: String },
}
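
// Serialization sketch, per the serde attributes above ("get_weather" is an
// illustrative name): `ToolChoice::Function` is untagged, so a specific choice
// serializes as an array of adjacently tagged entries, e.g.
//
//   [{"type": "Function", "function": {"name": "get_weather"}}]
//
// With `untagged`, the unit variants `None` and `Auto` serialize as JSON
// `null` rather than the strings "none"/"auto".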

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{OneOrMany, message};

    #[test]
    fn together_request_conversion_errors_when_all_messages_are_filtered() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            model: None,
            output_schema: None,
        };

        let result = TogetherAICompletionRequest::try_from(("meta-llama/test-model", request));
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
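
    // Happy-path sketch: assumes `message::UserContent::text` mirrors the
    // other content constructors used above (e.g. `AssistantContent::reasoning`).
    #[test]
    fn together_request_conversion_keeps_user_messages() {
        let request = CompletionRequest {
            preamble: Some("You are a test assistant.".to_string()),
            chat_history: OneOrMany::one(message::Message::User {
                content: OneOrMany::one(message::UserContent::text("hello")),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            model: None,
            output_schema: None,
        };

        let result = TogetherAICompletionRequest::try_from(("meta-llama/test-model", request))
            .expect("conversion should succeed");
        assert_eq!(result.model, "meta-llama/test-model");
        // The preamble becomes a leading system message, followed by the user turn.
        assert_eq!(result.messages.len(), 2);
    }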
}