// chat_prompts/lib.rs

1//! `chat-prompts` is part of [LlamaEdge API Server](https://github.com/LlamaEdge/LlamaEdge/tree/main/api-server) project. It provides a collection of prompt templates that are used to generate prompts for the LLMs (See models in [huggingface.co/second-state](https://huggingface.co/second-state)).
2//!
3//! For the details of available prompt templates, see [README.md](https://github.com/LlamaEdge/LlamaEdge/tree/main/api-server/chat-prompts).
4
5pub mod chat;
6pub mod error;
7pub mod utils;
8
9use clap::ValueEnum;
10use endpoints::chat::ChatCompletionRequestMessage;
11use serde::{Deserialize, Serialize};
12use std::str::FromStr;
13
/// Define the chat prompt template types.
///
/// Each variant names one concrete prompt format. The clap
/// `#[value(name = "...")]` attribute is the kebab-case spelling accepted on
/// the command line; the same spellings are used by this type's `FromStr` and
/// `Display` implementations.
///
/// NOTE(review): `Serialize`/`Deserialize` are derived without
/// `#[serde(rename = "...")]`, so serde uses the PascalCase variant names
/// (e.g. "Llama2Chat") rather than the kebab-case clap names — confirm this
/// asymmetry with `MergeRagContextPolicy` (which does rename) is intentional.
#[derive(Clone, Debug, Copy, PartialEq, Eq, Serialize, Deserialize, ValueEnum)]
pub enum PromptTemplateType {
    // Llama family
    #[value(name = "llama-2-chat")]
    Llama2Chat,
    #[value(name = "llama-3-chat")]
    Llama3Chat,
    #[value(name = "llama-3-tool")]
    Llama3Tool,
    // Mistral family
    #[value(name = "mistral-instruct")]
    MistralInstruct,
    #[value(name = "mistral-tool")]
    MistralTool,
    #[value(name = "mistrallite")]
    MistralLite,
    #[value(name = "mistral-small-chat")]
    MistralSmallChat,
    #[value(name = "mistral-small-tool")]
    MistralSmallTool,
    #[value(name = "openchat")]
    OpenChat,
    // CodeLlama family
    #[value(name = "codellama-instruct")]
    CodeLlama,
    #[value(name = "codellama-super-instruct")]
    CodeLlamaSuper,
    #[value(name = "human-assistant")]
    HumanAssistant,
    // Vicuna family
    #[value(name = "vicuna-1.0-chat")]
    VicunaChat,
    #[value(name = "vicuna-1.1-chat")]
    Vicuna11Chat,
    #[value(name = "vicuna-llava")]
    VicunaLlava,
    // ChatML-style templates
    #[value(name = "chatml")]
    ChatML,
    #[value(name = "chatml-tool")]
    ChatMLTool,
    #[value(name = "internlm-2-tool")]
    InternLM2Tool,
    #[value(name = "baichuan-2")]
    Baichuan2,
    #[value(name = "wizard-coder")]
    WizardCoder,
    #[value(name = "zephyr")]
    Zephyr,
    #[value(name = "stablelm-zephyr")]
    StableLMZephyr,
    #[value(name = "intel-neural")]
    IntelNeural,
    // Deepseek family
    #[value(name = "deepseek-chat")]
    DeepseekChat,
    #[value(name = "deepseek-coder")]
    DeepseekCoder,
    #[value(name = "deepseek-chat-2")]
    DeepseekChat2,
    #[value(name = "deepseek-chat-25")]
    DeepseekChat25,
    #[value(name = "deepseek-chat-3")]
    DeepseekChat3,
    #[value(name = "solar-instruct")]
    SolarInstruct,
    // Phi family
    #[value(name = "phi-2-chat")]
    Phi2Chat,
    #[value(name = "phi-2-instruct")]
    Phi2Instruct,
    #[value(name = "phi-3-chat")]
    Phi3Chat,
    #[value(name = "phi-3-instruct")]
    Phi3Instruct,
    #[value(name = "phi-4-chat")]
    Phi4Chat,
    // Gemma family
    #[value(name = "gemma-instruct")]
    GemmaInstruct,
    #[value(name = "gemma-3")]
    Gemma3,
    #[value(name = "octopus")]
    Octopus,
    #[value(name = "glm-4-chat")]
    Glm4Chat,
    #[value(name = "groq-llama3-tool")]
    GroqLlama3Tool,
    #[value(name = "mediatek-breeze")]
    BreezeInstruct,
    #[value(name = "nemotron-chat")]
    NemotronChat,
    #[value(name = "nemotron-tool")]
    NemotronTool,
    // Functionary (tool-calling) templates
    #[value(name = "functionary-32")]
    FunctionaryV32,
    #[value(name = "functionary-31")]
    FunctionaryV31,
    #[value(name = "minicpmv")]
    MiniCPMV,
    #[value(name = "moxin-chat")]
    MoxinChat,
    #[value(name = "falcon3")]
    Falcon3,
    #[value(name = "megrez")]
    Megrez,
    #[value(name = "qwen2-vision")]
    Qwen2vl,
    // Non-chat pseudo-templates
    #[value(name = "embedding")]
    Embedding,
    #[value(name = "tts")]
    Tts,
    /// No template ("none").
    #[value(name = "none")]
    Null,
}
122impl PromptTemplateType {
123    /// Check if the prompt template has a system prompt.
124    pub fn has_system_prompt(&self) -> bool {
125        match self {
126            PromptTemplateType::Llama2Chat
127            | PromptTemplateType::Llama3Chat
128            | PromptTemplateType::Llama3Tool
129            | PromptTemplateType::CodeLlama
130            | PromptTemplateType::CodeLlamaSuper
131            | PromptTemplateType::VicunaChat
132            | PromptTemplateType::VicunaLlava
133            | PromptTemplateType::ChatML
134            | PromptTemplateType::ChatMLTool
135            | PromptTemplateType::InternLM2Tool
136            | PromptTemplateType::Baichuan2
137            | PromptTemplateType::WizardCoder
138            | PromptTemplateType::Zephyr
139            | PromptTemplateType::IntelNeural
140            | PromptTemplateType::DeepseekCoder
141            | PromptTemplateType::DeepseekChat2
142            | PromptTemplateType::DeepseekChat3
143            | PromptTemplateType::Octopus
144            | PromptTemplateType::Phi3Chat
145            | PromptTemplateType::Phi4Chat
146            | PromptTemplateType::Glm4Chat
147            | PromptTemplateType::GroqLlama3Tool
148            | PromptTemplateType::BreezeInstruct
149            | PromptTemplateType::DeepseekChat25
150            | PromptTemplateType::NemotronChat
151            | PromptTemplateType::NemotronTool
152            | PromptTemplateType::MiniCPMV
153            | PromptTemplateType::MoxinChat
154            | PromptTemplateType::Falcon3
155            | PromptTemplateType::Megrez
156            | PromptTemplateType::Qwen2vl
157            | PromptTemplateType::MistralSmallChat
158            | PromptTemplateType::MistralSmallTool => true,
159            PromptTemplateType::MistralInstruct
160            | PromptTemplateType::MistralTool
161            | PromptTemplateType::MistralLite
162            | PromptTemplateType::HumanAssistant
163            | PromptTemplateType::DeepseekChat
164            | PromptTemplateType::GemmaInstruct
165            | PromptTemplateType::Gemma3
166            | PromptTemplateType::OpenChat
167            | PromptTemplateType::Phi2Chat
168            | PromptTemplateType::Phi2Instruct
169            | PromptTemplateType::Phi3Instruct
170            | PromptTemplateType::SolarInstruct
171            | PromptTemplateType::Vicuna11Chat
172            | PromptTemplateType::StableLMZephyr
173            | PromptTemplateType::FunctionaryV32
174            | PromptTemplateType::FunctionaryV31
175            | PromptTemplateType::Embedding
176            | PromptTemplateType::Tts
177            | PromptTemplateType::Null => false,
178        }
179    }
180
181    /// Check if the prompt template supports image input.
182    pub fn is_image_supported(&self) -> bool {
183        matches!(
184            self,
185            PromptTemplateType::MiniCPMV
186                | PromptTemplateType::Qwen2vl
187                | PromptTemplateType::VicunaLlava
188        )
189    }
190}
191impl FromStr for PromptTemplateType {
192    type Err = error::PromptError;
193
194    fn from_str(template: &str) -> std::result::Result<Self, Self::Err> {
195        match template {
196            "llama-2-chat" => Ok(PromptTemplateType::Llama2Chat),
197            "llama-3-chat" => Ok(PromptTemplateType::Llama3Chat),
198            "llama-3-tool" => Ok(PromptTemplateType::Llama3Tool),
199            "mistral-instruct" => Ok(PromptTemplateType::MistralInstruct),
200            "mistral-tool" => Ok(PromptTemplateType::MistralTool),
201            "mistrallite" => Ok(PromptTemplateType::MistralLite),
202            "mistral-small-chat" => Ok(PromptTemplateType::MistralSmallChat),
203            "mistral-small-tool" => Ok(PromptTemplateType::MistralSmallTool),
204            "codellama-instruct" => Ok(PromptTemplateType::CodeLlama),
205            "codellama-super-instruct" => Ok(PromptTemplateType::CodeLlamaSuper),
206            "belle-llama-2-chat" => Ok(PromptTemplateType::HumanAssistant),
207            "human-assistant" => Ok(PromptTemplateType::HumanAssistant),
208            "vicuna-1.0-chat" => Ok(PromptTemplateType::VicunaChat),
209            "vicuna-1.1-chat" => Ok(PromptTemplateType::Vicuna11Chat),
210            "vicuna-llava" => Ok(PromptTemplateType::VicunaLlava),
211            "chatml" => Ok(PromptTemplateType::ChatML),
212            "chatml-tool" => Ok(PromptTemplateType::ChatMLTool),
213            "internlm-2-tool" => Ok(PromptTemplateType::InternLM2Tool),
214            "openchat" => Ok(PromptTemplateType::OpenChat),
215            "baichuan-2" => Ok(PromptTemplateType::Baichuan2),
216            "wizard-coder" => Ok(PromptTemplateType::WizardCoder),
217            "zephyr" => Ok(PromptTemplateType::Zephyr),
218            "stablelm-zephyr" => Ok(PromptTemplateType::StableLMZephyr),
219            "intel-neural" => Ok(PromptTemplateType::IntelNeural),
220            "deepseek-chat" => Ok(PromptTemplateType::DeepseekChat),
221            "deepseek-coder" => Ok(PromptTemplateType::DeepseekCoder),
222            "deepseek-chat-2" => Ok(PromptTemplateType::DeepseekChat2),
223            "deepseek-chat-25" => Ok(PromptTemplateType::DeepseekChat25),
224            "deepseek-chat-3" => Ok(PromptTemplateType::DeepseekChat3),
225            "solar-instruct" => Ok(PromptTemplateType::SolarInstruct),
226            "phi-2-chat" => Ok(PromptTemplateType::Phi2Chat),
227            "phi-2-instruct" => Ok(PromptTemplateType::Phi2Instruct),
228            "phi-3-chat" => Ok(PromptTemplateType::Phi3Chat),
229            "phi-3-instruct" => Ok(PromptTemplateType::Phi3Instruct),
230            "phi-4-chat" => Ok(PromptTemplateType::Phi4Chat),
231            "gemma-instruct" => Ok(PromptTemplateType::GemmaInstruct),
232            "gemma-3" => Ok(PromptTemplateType::Gemma3),
233            "octopus" => Ok(PromptTemplateType::Octopus),
234            "glm-4-chat" => Ok(PromptTemplateType::Glm4Chat),
235            "groq-llama3-tool" => Ok(PromptTemplateType::GroqLlama3Tool),
236            "mediatek-breeze" => Ok(PromptTemplateType::BreezeInstruct),
237            "nemotron-chat" => Ok(PromptTemplateType::NemotronChat),
238            "nemotron-tool" => Ok(PromptTemplateType::NemotronTool),
239            "functionary-32" => Ok(PromptTemplateType::FunctionaryV32),
240            "functionary-31" => Ok(PromptTemplateType::FunctionaryV31),
241            "minicpmv" => Ok(PromptTemplateType::MiniCPMV),
242            "moxin-chat" => Ok(PromptTemplateType::MoxinChat),
243            "falcon3" => Ok(PromptTemplateType::Falcon3),
244            "megrez" => Ok(PromptTemplateType::Megrez),
245            "qwen2-vision" => Ok(PromptTemplateType::Qwen2vl),
246            "embedding" => Ok(PromptTemplateType::Embedding),
247            "tts" => Ok(PromptTemplateType::Tts),
248            "none" => Ok(PromptTemplateType::Null),
249            _ => Err(error::PromptError::UnknownPromptTemplateType(
250                template.to_string(),
251            )),
252        }
253    }
254}
255impl std::fmt::Display for PromptTemplateType {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        match self {
258            PromptTemplateType::Llama2Chat => write!(f, "llama-2-chat"),
259            PromptTemplateType::Llama3Chat => write!(f, "llama-3-chat"),
260            PromptTemplateType::Llama3Tool => write!(f, "llama-3-tool"),
261            PromptTemplateType::MistralInstruct => write!(f, "mistral-instruct"),
262            PromptTemplateType::MistralTool => write!(f, "mistral-tool"),
263            PromptTemplateType::MistralLite => write!(f, "mistrallite"),
264            PromptTemplateType::MistralSmallChat => write!(f, "mistral-small-chat"),
265            PromptTemplateType::MistralSmallTool => write!(f, "mistral-small-tool"),
266            PromptTemplateType::OpenChat => write!(f, "openchat"),
267            PromptTemplateType::CodeLlama => write!(f, "codellama-instruct"),
268            PromptTemplateType::HumanAssistant => write!(f, "human-asistant"),
269            PromptTemplateType::VicunaChat => write!(f, "vicuna-1.0-chat"),
270            PromptTemplateType::Vicuna11Chat => write!(f, "vicuna-1.1-chat"),
271            PromptTemplateType::VicunaLlava => write!(f, "vicuna-llava"),
272            PromptTemplateType::ChatML => write!(f, "chatml"),
273            PromptTemplateType::ChatMLTool => write!(f, "chatml-tool"),
274            PromptTemplateType::InternLM2Tool => write!(f, "internlm-2-tool"),
275            PromptTemplateType::Baichuan2 => write!(f, "baichuan-2"),
276            PromptTemplateType::WizardCoder => write!(f, "wizard-coder"),
277            PromptTemplateType::Zephyr => write!(f, "zephyr"),
278            PromptTemplateType::StableLMZephyr => write!(f, "stablelm-zephyr"),
279            PromptTemplateType::IntelNeural => write!(f, "intel-neural"),
280            PromptTemplateType::DeepseekChat => write!(f, "deepseek-chat"),
281            PromptTemplateType::DeepseekCoder => write!(f, "deepseek-coder"),
282            PromptTemplateType::DeepseekChat2 => write!(f, "deepseek-chat-2"),
283            PromptTemplateType::DeepseekChat25 => write!(f, "deepseek-chat-25"),
284            PromptTemplateType::DeepseekChat3 => write!(f, "deepseek-chat-3"),
285            PromptTemplateType::SolarInstruct => write!(f, "solar-instruct"),
286            PromptTemplateType::Phi2Chat => write!(f, "phi-2-chat"),
287            PromptTemplateType::Phi2Instruct => write!(f, "phi-2-instruct"),
288            PromptTemplateType::Phi3Chat => write!(f, "phi-3-chat"),
289            PromptTemplateType::Phi3Instruct => write!(f, "phi-3-instruct"),
290            PromptTemplateType::Phi4Chat => write!(f, "phi-4-chat"),
291            PromptTemplateType::CodeLlamaSuper => write!(f, "codellama-super-instruct"),
292            PromptTemplateType::GemmaInstruct => write!(f, "gemma-instruct"),
293            PromptTemplateType::Gemma3 => write!(f, "gemma-3"),
294            PromptTemplateType::Octopus => write!(f, "octopus"),
295            PromptTemplateType::Glm4Chat => write!(f, "glm-4-chat"),
296            PromptTemplateType::GroqLlama3Tool => write!(f, "groq-llama3-tool"),
297            PromptTemplateType::BreezeInstruct => write!(f, "mediatek-breeze"),
298            PromptTemplateType::NemotronChat => write!(f, "nemotron-chat"),
299            PromptTemplateType::NemotronTool => write!(f, "nemotron-tool"),
300            PromptTemplateType::FunctionaryV32 => write!(f, "functionary-32"),
301            PromptTemplateType::FunctionaryV31 => write!(f, "functionary-31"),
302            PromptTemplateType::MiniCPMV => write!(f, "minicpmv"),
303            PromptTemplateType::MoxinChat => write!(f, "moxin-chat"),
304            PromptTemplateType::Falcon3 => write!(f, "falcon3"),
305            PromptTemplateType::Megrez => write!(f, "megrez"),
306            PromptTemplateType::Qwen2vl => write!(f, "qwen2-vision"),
307            PromptTemplateType::Embedding => write!(f, "embedding"),
308            PromptTemplateType::Tts => write!(f, "tts"),
309            PromptTemplateType::Null => write!(f, "none"),
310        }
311    }
312}
313
/// Trait for merging RAG context into chat messages
pub trait MergeRagContext: Send {
    /// Merge RAG context into chat messages.
    ///
    /// Note that the default implementation simply merges the RAG context into the system message. That is, to use the default implementation, `has_system_prompt` should be set to `true` and `policy` set to `MergeRagContextPolicy::SystemMessage`.
    ///
    /// # Arguments
    ///
    /// * `messages` - The chat messages to merge the context into.
    ///
    /// * `context` - The RAG context to merge into the chat messages.
    ///
    /// * `has_system_prompt` - Whether the chat template has a system prompt.
    ///
    /// * `policy` - The policy for merging RAG context into chat messages.
    ///
    /// # Errors
    ///
    /// The default implementation returns `PromptError::NoMessages` when
    /// `messages` is empty, and `PromptError::Operation` when `context` is
    /// empty. With any other `policy`, or when `has_system_prompt` is false,
    /// it is a no-op and returns `Ok(())`.
    fn build(
        messages: &mut Vec<endpoints::chat::ChatCompletionRequestMessage>,
        context: &[String],
        has_system_prompt: bool,
        policy: MergeRagContextPolicy,
    ) -> error::Result<()> {
        // Only the default (system-message) policy is handled here; custom
        // implementors are expected to handle `LastUserMessage` themselves.
        if (policy == MergeRagContextPolicy::SystemMessage) && has_system_prompt {
            if messages.is_empty() {
                return Err(error::PromptError::NoMessages);
            }

            if context.is_empty() {
                return Err(error::PromptError::Operation(
                    "No context provided.".to_string(),
                ));
            }

            // NOTE(review): only the first context chunk is used; any further
            // elements of `context` are silently ignored — presumably callers
            // pre-join their chunks into `context[0]`. TODO confirm.
            let context = context[0].trim_end();

            // update or insert system message
            match messages[0] {
                ChatCompletionRequestMessage::System(ref message) => {
                    // compose new system message content
                    // (the second `trim_end()` on `context` below is redundant:
                    // the string was already right-trimmed above)
                    let content = format!("{original_system_message}\nUse the following pieces of context to answer the user's question.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\n{context}", original_system_message=message.content().trim(), context=context.trim_end());
                    // create system message, preserving the original message's name
                    let system_message = ChatCompletionRequestMessage::new_system_message(
                        content,
                        messages[0].name().cloned(),
                    );
                    // replace the original system message
                    messages[0] = system_message;
                }
                _ => {
                    // no system message present: synthesize one from the context alone
                    let content = format!("Use the following pieces of context to answer the user's question.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\n{}", context.trim_end());

                    // create system message; note the name is borrowed from the
                    // first (non-system) message in the conversation
                    let system_message = ChatCompletionRequestMessage::new_system_message(
                        content,
                        messages[0].name().cloned(),
                    );
                    // insert system message at the front of the conversation
                    messages.insert(0, system_message);
                }
            };
        }

        Ok(())
    }
}
379
/// Define the strategy for merging RAG context into chat messages.
///
/// The `#[serde(rename)]` attributes keep serde's wire format in sync with
/// the kebab-case names used by `Display`/`FromStr` and clap.
#[derive(Clone, Debug, Copy, Default, PartialEq, Eq, Serialize, Deserialize, ValueEnum)]
pub enum MergeRagContextPolicy {
    /// Merge RAG context into the system message.
    ///
    /// Note that this policy is only applicable when the chat template has a system message.
    #[default]
    #[serde(rename = "system-message")]
    SystemMessage,
    /// Merge RAG context into the last user message.
    #[serde(rename = "last-user-message")]
    LastUserMessage,
}
393impl std::fmt::Display for MergeRagContextPolicy {
394    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395        match self {
396            MergeRagContextPolicy::SystemMessage => write!(f, "system-message"),
397            MergeRagContextPolicy::LastUserMessage => write!(f, "last-user-message"),
398        }
399    }
400}
401impl FromStr for MergeRagContextPolicy {
402    type Err = error::PromptError;
403
404    fn from_str(policy: &str) -> std::result::Result<Self, Self::Err> {
405        Ok(match policy {
406            "system-message" => MergeRagContextPolicy::SystemMessage,
407            "last-user-message" => MergeRagContextPolicy::LastUserMessage,
408            _ => {
409                return Err(error::PromptError::UnknownMergeRagContextPolicy(
410                    policy.to_string(),
411                ))
412            }
413        })
414    }
415}