1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use super::traits::Prompt;
use crate::PromptTemplate;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use std::fmt;
/// Identifies who authored a [`ChatMessage`].
///
/// The `Display` implementation renders the three fixed variants as `User`,
/// `Assistant` and `System`, and renders [`ChatRole::Other`] as the wrapped string.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum ChatRole {
    /// Role attached to messages added through `ChatPromptBuilder::user`.
    User,
    /// Role attached to messages added through `ChatPromptBuilder::assistant`.
    Assistant,
    /// Role attached to messages added through `ChatPromptBuilder::system`.
    System,
    /// Any role not covered by the fixed variants, identified by a free-form name.
    Other(String),
}

impl fmt::Display for ChatRole {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ChatRole::User => write!(f, "User"),
            ChatRole::Assistant => write!(f, "Assistant"),
            ChatRole::System => write!(f, "System"),
            ChatRole::Other(s) => write!(f, "{}", s),
        }
    }
}

/// A single chat message: a [`ChatRole`] paired with its content as a
/// [`PromptTemplate`].
///
/// `Builder` is derived with `setter(into)`, which asks `derive_builder` to generate
/// setters that accept any value convertible `Into` the field's type.
#[derive(Debug, Builder, Clone, Serialize, Deserialize)]
#[builder(setter(into))]
pub struct ChatMessage {
    /// Who authored the message.
    role: ChatRole,
    /// The message body, stored as a prompt template.
    content: PromptTemplate,
}

impl ChatMessage {
    /// Creates a new `ChatMessage` from a role and a string.
    ///
    /// The string is wrapped with `PromptTemplate::tera` when the `tera` feature is
    /// enabled and with `PromptTemplate::legacy` otherwise.
    pub fn new<S: Into<String>>(role: ChatRole, content: S) -> Self {
        // Exactly one of these two bindings survives conditional compilation.
        #[cfg(feature = "tera")]
        let template = PromptTemplate::tera(content.into());
        #[cfg(not(feature = "tera"))]
        let template = PromptTemplate::legacy(content.into());

        Self {
            role,
            content: template,
        }
    }

    /// Creates a new `ChatMessage` from a role and an already-built prompt template.
    pub fn from_template(role: ChatRole, content: PromptTemplate) -> Self {
        Self { role, content }
    }

    /// Returns a clone of the message's role.
    pub fn role(&self) -> ChatRole {
        let Self { role, .. } = self;
        role.clone()
    }

    /// Returns a clone of the message's content template.
    pub fn content(&self) -> PromptTemplate {
        let Self { content, .. } = self;
        content.clone()
    }
}

/// A chat prompt: a sequence of [`ChatMessage`]s.
///
/// Deriving `Builder` provides the `ChatPromptBuilder` type used by
/// `ChatPrompt::builder` and `ChatPrompt::to_builder`.
#[derive(Debug, Builder, Clone, Serialize, Deserialize)]
pub struct ChatPrompt {
    /// The messages that make up the prompt, in the order they were added.
    messages: Vec<ChatMessage>,
}

impl ChatPrompt {
    /// Returns a new `ChatPromptBuilder` for building a `ChatPrompt`.
    pub fn builder() -> ChatPromptBuilder {
        ChatPromptBuilder::default()
    }
    pub fn to_builder(&self) -> ChatPromptBuilder {
        let mut cpb = ChatPromptBuilder::default();
        cpb.messages(self.messages.clone());
        cpb
    }
}

impl Prompt for ChatPrompt {
    /// Returns an owned copy of every message, in stored order.
    fn as_chat_prompt(&self) -> Vec<ChatMessage> {
        self.messages.iter().cloned().collect()
    }

    /// Always yields `None`: this type holds only chat messages, not a text template.
    fn as_text_prompt(&self) -> Option<&PromptTemplate> {
        Option::None
    }
}

impl fmt::Display for ChatPrompt {
    /// Writes one `role: content` line per message, each terminated by a newline.
    ///
    /// Stops at, and returns, the first formatting error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.messages
            .iter()
            .try_for_each(|entry| writeln!(f, "{}: {}", entry.role, entry.content))
    }
}

// Convenience methods layered on the derived `ChatPromptBuilder` for appending user,
// assistant, and system messages in a chained style.
impl ChatPromptBuilder {
    /// Creates an empty builder; equivalent to `ChatPromptBuilder::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends a chat message to the prompt being built.
    ///
    /// Consumes the builder and returns it so calls can be chained. The message
    /// list is created on first use.
    pub fn add_message(mut self, message: ChatMessage) -> Self {
        // Pull the current list out (or start an empty one), extend it, put it back.
        let mut collected = self.messages.take().unwrap_or_default();
        collected.push(message);
        self.messages = Some(collected);
        self
    }

    /// Appends a message with the `ChatRole::User` role.
    pub fn user<S: Into<String>>(self, message: S) -> Self {
        let entry = ChatMessage::new(ChatRole::User, message);
        self.add_message(entry)
    }

    /// Appends a message with the `ChatRole::Assistant` role.
    pub fn assistant<S: Into<String>>(self, message: S) -> Self {
        let entry = ChatMessage::new(ChatRole::Assistant, message);
        self.add_message(entry)
    }

    /// Appends a message with the `ChatRole::System` role.
    pub fn system<S: Into<String>>(self, message: S) -> Self {
        let entry = ChatMessage::new(ChatRole::System, message);
        self.add_message(entry)
    }
}