1pub mod chat;
6pub mod error;
7pub mod utils;
8
9use clap::ValueEnum;
10use endpoints::chat::ChatCompletionRequestMessage;
11use serde::{Deserialize, Serialize};
12use std::str::FromStr;
13
/// The chat prompt template types known to this crate.
///
/// The string in each variant's `#[value(name = "...")]` attribute is the name clap
/// uses for that variant when this type is parsed as a `ValueEnum`.
///
/// NOTE(review): there are no `#[serde(rename...)]` attributes on this enum, so the
/// derived `Serialize`/`Deserialize` presumably use the Rust variant identifiers
/// (e.g. `Llama2Chat`) rather than the kebab-style names above — confirm that this
/// divergence from the clap names is intended.
///
/// Variant order is significant: the derived `ValueEnum` reports variants in
/// declaration order, and serde's derived impls carry the variant index (used by
/// non-self-describing formats). Do not reorder existing variants.
#[derive(Clone, Debug, Copy, PartialEq, Eq, Serialize, Deserialize, ValueEnum)]
pub enum PromptTemplateType {
    #[value(name = "llama-2-chat")]
    Llama2Chat,
    #[value(name = "llama-3-chat")]
    Llama3Chat,
    #[value(name = "llama-3-tool")]
    Llama3Tool,
    #[value(name = "mistral-instruct")]
    MistralInstruct,
    #[value(name = "mistral-tool")]
    MistralTool,
    #[value(name = "mistrallite")]
    MistralLite,
    #[value(name = "mistral-small-chat")]
    MistralSmallChat,
    #[value(name = "mistral-small-tool")]
    MistralSmallTool,
    #[value(name = "openchat")]
    OpenChat,
    #[value(name = "codellama-instruct")]
    CodeLlama,
    #[value(name = "codellama-super-instruct")]
    CodeLlamaSuper,
    /// Named `human-assistant`; this module's `FromStr` also accepts the alias
    /// `belle-llama-2-chat` for this variant.
    #[value(name = "human-assistant")]
    HumanAssistant,
    #[value(name = "vicuna-1.0-chat")]
    VicunaChat,
    #[value(name = "vicuna-1.1-chat")]
    Vicuna11Chat,
    #[value(name = "vicuna-llava")]
    VicunaLlava,
    #[value(name = "chatml")]
    ChatML,
    #[value(name = "chatml-tool")]
    ChatMLTool,
    #[value(name = "internlm-2-tool")]
    InternLM2Tool,
    #[value(name = "baichuan-2")]
    Baichuan2,
    #[value(name = "wizard-coder")]
    WizardCoder,
    #[value(name = "zephyr")]
    Zephyr,
    #[value(name = "stablelm-zephyr")]
    StableLMZephyr,
    #[value(name = "intel-neural")]
    IntelNeural,
    #[value(name = "deepseek-chat")]
    DeepseekChat,
    #[value(name = "deepseek-coder")]
    DeepseekCoder,
    #[value(name = "deepseek-chat-2")]
    DeepseekChat2,
    #[value(name = "deepseek-chat-25")]
    DeepseekChat25,
    #[value(name = "deepseek-chat-3")]
    DeepseekChat3,
    #[value(name = "solar-instruct")]
    SolarInstruct,
    #[value(name = "phi-2-chat")]
    Phi2Chat,
    #[value(name = "phi-2-instruct")]
    Phi2Instruct,
    #[value(name = "phi-3-chat")]
    Phi3Chat,
    #[value(name = "phi-3-instruct")]
    Phi3Instruct,
    #[value(name = "phi-4-chat")]
    Phi4Chat,
    #[value(name = "gemma-instruct")]
    GemmaInstruct,
    #[value(name = "gemma-3")]
    Gemma3,
    #[value(name = "octopus")]
    Octopus,
    #[value(name = "glm-4-chat")]
    Glm4Chat,
    #[value(name = "groq-llama3-tool")]
    GroqLlama3Tool,
    /// Named `mediatek-breeze` (the identifier and the name differ).
    #[value(name = "mediatek-breeze")]
    BreezeInstruct,
    #[value(name = "nemotron-chat")]
    NemotronChat,
    #[value(name = "nemotron-tool")]
    NemotronTool,
    #[value(name = "functionary-32")]
    FunctionaryV32,
    #[value(name = "functionary-31")]
    FunctionaryV31,
    #[value(name = "minicpmv")]
    MiniCPMV,
    #[value(name = "moxin-chat")]
    MoxinChat,
    #[value(name = "falcon3")]
    Falcon3,
    #[value(name = "megrez")]
    Megrez,
    #[value(name = "qwen2-vision")]
    Qwen2vl,
    #[value(name = "embedding")]
    Embedding,
    #[value(name = "tts")]
    Tts,
    /// Named `none`.
    #[value(name = "none")]
    Null,
}
122impl PromptTemplateType {
123 pub fn has_system_prompt(&self) -> bool {
125 match self {
126 PromptTemplateType::Llama2Chat
127 | PromptTemplateType::Llama3Chat
128 | PromptTemplateType::Llama3Tool
129 | PromptTemplateType::CodeLlama
130 | PromptTemplateType::CodeLlamaSuper
131 | PromptTemplateType::VicunaChat
132 | PromptTemplateType::VicunaLlava
133 | PromptTemplateType::ChatML
134 | PromptTemplateType::ChatMLTool
135 | PromptTemplateType::InternLM2Tool
136 | PromptTemplateType::Baichuan2
137 | PromptTemplateType::WizardCoder
138 | PromptTemplateType::Zephyr
139 | PromptTemplateType::IntelNeural
140 | PromptTemplateType::DeepseekCoder
141 | PromptTemplateType::DeepseekChat2
142 | PromptTemplateType::DeepseekChat3
143 | PromptTemplateType::Octopus
144 | PromptTemplateType::Phi3Chat
145 | PromptTemplateType::Phi4Chat
146 | PromptTemplateType::Glm4Chat
147 | PromptTemplateType::GroqLlama3Tool
148 | PromptTemplateType::BreezeInstruct
149 | PromptTemplateType::DeepseekChat25
150 | PromptTemplateType::NemotronChat
151 | PromptTemplateType::NemotronTool
152 | PromptTemplateType::MiniCPMV
153 | PromptTemplateType::MoxinChat
154 | PromptTemplateType::Falcon3
155 | PromptTemplateType::Megrez
156 | PromptTemplateType::Qwen2vl
157 | PromptTemplateType::MistralSmallChat
158 | PromptTemplateType::MistralSmallTool => true,
159 PromptTemplateType::MistralInstruct
160 | PromptTemplateType::MistralTool
161 | PromptTemplateType::MistralLite
162 | PromptTemplateType::HumanAssistant
163 | PromptTemplateType::DeepseekChat
164 | PromptTemplateType::GemmaInstruct
165 | PromptTemplateType::Gemma3
166 | PromptTemplateType::OpenChat
167 | PromptTemplateType::Phi2Chat
168 | PromptTemplateType::Phi2Instruct
169 | PromptTemplateType::Phi3Instruct
170 | PromptTemplateType::SolarInstruct
171 | PromptTemplateType::Vicuna11Chat
172 | PromptTemplateType::StableLMZephyr
173 | PromptTemplateType::FunctionaryV32
174 | PromptTemplateType::FunctionaryV31
175 | PromptTemplateType::Embedding
176 | PromptTemplateType::Tts
177 | PromptTemplateType::Null => false,
178 }
179 }
180
181 pub fn is_image_supported(&self) -> bool {
183 matches!(
184 self,
185 PromptTemplateType::MiniCPMV
186 | PromptTemplateType::Qwen2vl
187 | PromptTemplateType::VicunaLlava
188 )
189 }
190}
191impl FromStr for PromptTemplateType {
192 type Err = error::PromptError;
193
194 fn from_str(template: &str) -> std::result::Result<Self, Self::Err> {
195 match template {
196 "llama-2-chat" => Ok(PromptTemplateType::Llama2Chat),
197 "llama-3-chat" => Ok(PromptTemplateType::Llama3Chat),
198 "llama-3-tool" => Ok(PromptTemplateType::Llama3Tool),
199 "mistral-instruct" => Ok(PromptTemplateType::MistralInstruct),
200 "mistral-tool" => Ok(PromptTemplateType::MistralTool),
201 "mistrallite" => Ok(PromptTemplateType::MistralLite),
202 "mistral-small-chat" => Ok(PromptTemplateType::MistralSmallChat),
203 "mistral-small-tool" => Ok(PromptTemplateType::MistralSmallTool),
204 "codellama-instruct" => Ok(PromptTemplateType::CodeLlama),
205 "codellama-super-instruct" => Ok(PromptTemplateType::CodeLlamaSuper),
206 "belle-llama-2-chat" => Ok(PromptTemplateType::HumanAssistant),
207 "human-assistant" => Ok(PromptTemplateType::HumanAssistant),
208 "vicuna-1.0-chat" => Ok(PromptTemplateType::VicunaChat),
209 "vicuna-1.1-chat" => Ok(PromptTemplateType::Vicuna11Chat),
210 "vicuna-llava" => Ok(PromptTemplateType::VicunaLlava),
211 "chatml" => Ok(PromptTemplateType::ChatML),
212 "chatml-tool" => Ok(PromptTemplateType::ChatMLTool),
213 "internlm-2-tool" => Ok(PromptTemplateType::InternLM2Tool),
214 "openchat" => Ok(PromptTemplateType::OpenChat),
215 "baichuan-2" => Ok(PromptTemplateType::Baichuan2),
216 "wizard-coder" => Ok(PromptTemplateType::WizardCoder),
217 "zephyr" => Ok(PromptTemplateType::Zephyr),
218 "stablelm-zephyr" => Ok(PromptTemplateType::StableLMZephyr),
219 "intel-neural" => Ok(PromptTemplateType::IntelNeural),
220 "deepseek-chat" => Ok(PromptTemplateType::DeepseekChat),
221 "deepseek-coder" => Ok(PromptTemplateType::DeepseekCoder),
222 "deepseek-chat-2" => Ok(PromptTemplateType::DeepseekChat2),
223 "deepseek-chat-25" => Ok(PromptTemplateType::DeepseekChat25),
224 "deepseek-chat-3" => Ok(PromptTemplateType::DeepseekChat3),
225 "solar-instruct" => Ok(PromptTemplateType::SolarInstruct),
226 "phi-2-chat" => Ok(PromptTemplateType::Phi2Chat),
227 "phi-2-instruct" => Ok(PromptTemplateType::Phi2Instruct),
228 "phi-3-chat" => Ok(PromptTemplateType::Phi3Chat),
229 "phi-3-instruct" => Ok(PromptTemplateType::Phi3Instruct),
230 "phi-4-chat" => Ok(PromptTemplateType::Phi4Chat),
231 "gemma-instruct" => Ok(PromptTemplateType::GemmaInstruct),
232 "gemma-3" => Ok(PromptTemplateType::Gemma3),
233 "octopus" => Ok(PromptTemplateType::Octopus),
234 "glm-4-chat" => Ok(PromptTemplateType::Glm4Chat),
235 "groq-llama3-tool" => Ok(PromptTemplateType::GroqLlama3Tool),
236 "mediatek-breeze" => Ok(PromptTemplateType::BreezeInstruct),
237 "nemotron-chat" => Ok(PromptTemplateType::NemotronChat),
238 "nemotron-tool" => Ok(PromptTemplateType::NemotronTool),
239 "functionary-32" => Ok(PromptTemplateType::FunctionaryV32),
240 "functionary-31" => Ok(PromptTemplateType::FunctionaryV31),
241 "minicpmv" => Ok(PromptTemplateType::MiniCPMV),
242 "moxin-chat" => Ok(PromptTemplateType::MoxinChat),
243 "falcon3" => Ok(PromptTemplateType::Falcon3),
244 "megrez" => Ok(PromptTemplateType::Megrez),
245 "qwen2-vision" => Ok(PromptTemplateType::Qwen2vl),
246 "embedding" => Ok(PromptTemplateType::Embedding),
247 "tts" => Ok(PromptTemplateType::Tts),
248 "none" => Ok(PromptTemplateType::Null),
249 _ => Err(error::PromptError::UnknownPromptTemplateType(
250 template.to_string(),
251 )),
252 }
253 }
254}
255impl std::fmt::Display for PromptTemplateType {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 match self {
258 PromptTemplateType::Llama2Chat => write!(f, "llama-2-chat"),
259 PromptTemplateType::Llama3Chat => write!(f, "llama-3-chat"),
260 PromptTemplateType::Llama3Tool => write!(f, "llama-3-tool"),
261 PromptTemplateType::MistralInstruct => write!(f, "mistral-instruct"),
262 PromptTemplateType::MistralTool => write!(f, "mistral-tool"),
263 PromptTemplateType::MistralLite => write!(f, "mistrallite"),
264 PromptTemplateType::MistralSmallChat => write!(f, "mistral-small-chat"),
265 PromptTemplateType::MistralSmallTool => write!(f, "mistral-small-tool"),
266 PromptTemplateType::OpenChat => write!(f, "openchat"),
267 PromptTemplateType::CodeLlama => write!(f, "codellama-instruct"),
268 PromptTemplateType::HumanAssistant => write!(f, "human-asistant"),
269 PromptTemplateType::VicunaChat => write!(f, "vicuna-1.0-chat"),
270 PromptTemplateType::Vicuna11Chat => write!(f, "vicuna-1.1-chat"),
271 PromptTemplateType::VicunaLlava => write!(f, "vicuna-llava"),
272 PromptTemplateType::ChatML => write!(f, "chatml"),
273 PromptTemplateType::ChatMLTool => write!(f, "chatml-tool"),
274 PromptTemplateType::InternLM2Tool => write!(f, "internlm-2-tool"),
275 PromptTemplateType::Baichuan2 => write!(f, "baichuan-2"),
276 PromptTemplateType::WizardCoder => write!(f, "wizard-coder"),
277 PromptTemplateType::Zephyr => write!(f, "zephyr"),
278 PromptTemplateType::StableLMZephyr => write!(f, "stablelm-zephyr"),
279 PromptTemplateType::IntelNeural => write!(f, "intel-neural"),
280 PromptTemplateType::DeepseekChat => write!(f, "deepseek-chat"),
281 PromptTemplateType::DeepseekCoder => write!(f, "deepseek-coder"),
282 PromptTemplateType::DeepseekChat2 => write!(f, "deepseek-chat-2"),
283 PromptTemplateType::DeepseekChat25 => write!(f, "deepseek-chat-25"),
284 PromptTemplateType::DeepseekChat3 => write!(f, "deepseek-chat-3"),
285 PromptTemplateType::SolarInstruct => write!(f, "solar-instruct"),
286 PromptTemplateType::Phi2Chat => write!(f, "phi-2-chat"),
287 PromptTemplateType::Phi2Instruct => write!(f, "phi-2-instruct"),
288 PromptTemplateType::Phi3Chat => write!(f, "phi-3-chat"),
289 PromptTemplateType::Phi3Instruct => write!(f, "phi-3-instruct"),
290 PromptTemplateType::Phi4Chat => write!(f, "phi-4-chat"),
291 PromptTemplateType::CodeLlamaSuper => write!(f, "codellama-super-instruct"),
292 PromptTemplateType::GemmaInstruct => write!(f, "gemma-instruct"),
293 PromptTemplateType::Gemma3 => write!(f, "gemma-3"),
294 PromptTemplateType::Octopus => write!(f, "octopus"),
295 PromptTemplateType::Glm4Chat => write!(f, "glm-4-chat"),
296 PromptTemplateType::GroqLlama3Tool => write!(f, "groq-llama3-tool"),
297 PromptTemplateType::BreezeInstruct => write!(f, "mediatek-breeze"),
298 PromptTemplateType::NemotronChat => write!(f, "nemotron-chat"),
299 PromptTemplateType::NemotronTool => write!(f, "nemotron-tool"),
300 PromptTemplateType::FunctionaryV32 => write!(f, "functionary-32"),
301 PromptTemplateType::FunctionaryV31 => write!(f, "functionary-31"),
302 PromptTemplateType::MiniCPMV => write!(f, "minicpmv"),
303 PromptTemplateType::MoxinChat => write!(f, "moxin-chat"),
304 PromptTemplateType::Falcon3 => write!(f, "falcon3"),
305 PromptTemplateType::Megrez => write!(f, "megrez"),
306 PromptTemplateType::Qwen2vl => write!(f, "qwen2-vision"),
307 PromptTemplateType::Embedding => write!(f, "embedding"),
308 PromptTemplateType::Tts => write!(f, "tts"),
309 PromptTemplateType::Null => write!(f, "none"),
310 }
311 }
312}
313
314pub trait MergeRagContext: Send {
316 fn build(
330 messages: &mut Vec<endpoints::chat::ChatCompletionRequestMessage>,
331 context: &[String],
332 has_system_prompt: bool,
333 policy: MergeRagContextPolicy,
334 ) -> error::Result<()> {
335 if (policy == MergeRagContextPolicy::SystemMessage) && has_system_prompt {
336 if messages.is_empty() {
337 return Err(error::PromptError::NoMessages);
338 }
339
340 if context.is_empty() {
341 return Err(error::PromptError::Operation(
342 "No context provided.".to_string(),
343 ));
344 }
345
346 let context = context[0].trim_end();
347
348 match messages[0] {
350 ChatCompletionRequestMessage::System(ref message) => {
351 let content = format!("{original_system_message}\nUse the following pieces of context to answer the user's question.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\n{context}", original_system_message=message.content().trim(), context=context.trim_end());
353 let system_message = ChatCompletionRequestMessage::new_system_message(
355 content,
356 messages[0].name().cloned(),
357 );
358 messages[0] = system_message;
360 }
361 _ => {
362 let content = format!("Use the following pieces of context to answer the user's question.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\n{}", context.trim_end());
364
365 let system_message = ChatCompletionRequestMessage::new_system_message(
367 content,
368 messages[0].name().cloned(),
369 );
370 messages.insert(0, system_message);
372 }
373 };
374 }
375
376 Ok(())
377 }
378}
379
380#[derive(Clone, Debug, Copy, Default, PartialEq, Eq, Serialize, Deserialize, ValueEnum)]
382pub enum MergeRagContextPolicy {
383 #[default]
387 #[serde(rename = "system-message")]
388 SystemMessage,
389 #[serde(rename = "last-user-message")]
391 LastUserMessage,
392}
393impl std::fmt::Display for MergeRagContextPolicy {
394 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395 match self {
396 MergeRagContextPolicy::SystemMessage => write!(f, "system-message"),
397 MergeRagContextPolicy::LastUserMessage => write!(f, "last-user-message"),
398 }
399 }
400}
401impl FromStr for MergeRagContextPolicy {
402 type Err = error::PromptError;
403
404 fn from_str(policy: &str) -> std::result::Result<Self, Self::Err> {
405 Ok(match policy {
406 "system-message" => MergeRagContextPolicy::SystemMessage,
407 "last-user-message" => MergeRagContextPolicy::LastUserMessage,
408 _ => {
409 return Err(error::PromptError::UnknownMergeRagContextPolicy(
410 policy.to_string(),
411 ))
412 }
413 })
414 }
415}