// abu_base/chat/message.rs

1use std::fmt::Display;
2
3use crate::common::Role;
4
5use super::tool::*;
6use serde::{Deserialize, Serialize};
7use strum::{EnumMessage, EnumVariantNames};
8
9#[derive(Debug, Clone, Serialize, EnumVariantNames, EnumMessage)]
10#[serde(rename_all = "snake_case", tag = "role")]
11pub enum ChatMessage {
12    /// A message from a system.
13    System(SystemMessage),
14    /// A message from a human.
15    User(UserMessage),
16    /// A message from the assistant.
17    Assistant(AssistantMessage),
18    /// A message from a tool.
19    Tool(ToolMessage),
20}
21
22impl Into<ChatMessage> for SystemMessage {
23    fn into(self) -> ChatMessage {
24        ChatMessage::System(self)
25    }
26}
27
28impl Into<ChatMessage> for UserMessage {
29    fn into(self) -> ChatMessage {
30        ChatMessage::User(self)
31    }
32}
33
34impl Into<ChatMessage> for AssistantMessage {
35    fn into(self) -> ChatMessage {
36        ChatMessage::Assistant(self)
37    }
38}
39
40impl Into<ChatMessage> for ToolMessage {
41    fn into(self) -> ChatMessage {
42        ChatMessage::Tool(self)
43    }
44}
45
46#[derive(Debug, Clone, Serialize)]
47pub struct SystemMessage {
48    /// The contents of the system message.
49    pub content: String,
50    /// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub name: Option<String>,
53}
54
55#[derive(Debug, Clone, Serialize)]
56pub struct UserMessage {
57    /// The contents of the user message.
58    pub content: String,
59    /// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub name: Option<String>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct AssistantMessage {
66    /// The contents of the system message.
67    #[serde(default)]
68    pub content: String,
69    /// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
70    #[serde(skip_serializing_if = "Option::is_none", default)]
71    pub name: Option<String>,    
72    /// The tool calls generated by the model, such as function calls
73    #[serde(skip_serializing_if = "Vec::is_empty", default)]
74    pub tool_calls: Vec<ToolCall>,
75}
76
77#[derive(Debug, Clone, Serialize)]
78pub struct ToolMessage {
79    pub content: String,
80    pub tool_call_id: String,
81}
82
83impl ChatMessage {
84    pub fn role(&self) -> Role {
85        match self {
86            Self::Assistant(_) => Role::Assistant,
87            Self::User(_) => Role::User,
88            Self::Tool(_) => Role::Tool,
89            Self::System(_) => Role::System,
90        }
91    }
92
93    pub fn content(&self) -> &str {
94        match self {
95            Self::Assistant(m) => &m.content,
96            Self::User(m) => &m.content,
97            Self::Tool(m) => &m.content,
98            Self::System(m) => &m.content,
99        }
100    }
101
102    pub fn system(content: impl Into<String>) -> Self {
103        Self::System(SystemMessage {
104            content: content.into(),
105            name: None,
106        })
107    }
108
109    pub fn user(content: impl Into<String>) -> Self {
110        Self::User(UserMessage {
111            content: content.into(),
112            name: None,
113        })
114    }
115
116    pub fn assistant(content: impl Into<String>, tool_calls: impl Into<Vec<ToolCall>>) -> Self {
117        Self::Assistant(AssistantMessage {
118            content: content.into(),
119            name: None,
120            tool_calls: tool_calls.into(),
121        })
122    }
123
124    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
125        Self::Tool(ToolMessage {
126            content: content.into(),
127            tool_call_id: tool_call_id.into()
128        })
129    }
130
131    pub fn is_system(&self) -> bool {
132        matches!(self, Self::System(_))
133    }
134
135    pub fn is_user(&self) -> bool {
136        matches!(self, Self::User(_))
137    }
138
139    pub fn is_assistant(&self) -> bool {
140        matches!(self, Self::Assistant(_))
141    }
142
143    pub fn is_tool(&self) -> bool {
144        matches!(self, Self::Tool(_))
145    }
146
147    pub fn as_system(&self) -> Option<&SystemMessage> {
148        if let Self::System(msg) = self {
149            Some(msg)
150        } else {
151            None
152        }
153    }
154
155    pub fn as_user(&self) -> Option<&UserMessage> {
156        if let Self::User(msg) = self {
157            Some(msg)
158        } else {
159            None
160        }
161    }
162
163    pub fn as_assistant(&self) -> Option<&AssistantMessage> {
164        if let Self::Assistant(msg) = self {
165            Some(msg)
166        } else {
167            None
168        }
169    }
170
171    pub fn as_tool(&self) -> Option<&ToolMessage> {
172        if let Self::Tool(msg) = self {
173            Some(msg)
174        } else {
175            None
176        }
177    }
178}
179
180impl Display for Role {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        match self {
183            Role::System => write!(f, "system"),
184            Role::User => write!(f, "user"),
185            Role::Assistant => write!(f, "assiatant"),
186            Role::Tool => write!(f, "tool"),
187        }
188    }
189}