1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6#[serde(rename_all = "lowercase")]
7pub enum TrainingRole {
8 System,
10 User,
12 Assistant,
14 Tool,
16}
17
18impl std::fmt::Display for TrainingRole {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 match self {
21 Self::System => write!(f, "system"),
22 Self::User => write!(f, "user"),
23 Self::Assistant => write!(f, "assistant"),
24 Self::Tool => write!(f, "tool"),
25 }
26 }
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct TrainingMessage {
32 pub role: TrainingRole,
34 pub content: String,
36 #[serde(skip_serializing_if = "Option::is_none")]
38 pub tool_calls: Option<Vec<serde_json::Value>>,
39 #[serde(skip_serializing_if = "Option::is_none")]
41 pub tool_call_id: Option<String>,
42 #[serde(skip_serializing_if = "Option::is_none")]
44 pub name: Option<String>,
45}
46
47impl TrainingMessage {
48 pub fn new(role: TrainingRole, content: impl Into<String>) -> Self {
50 Self {
51 role,
52 content: content.into(),
53 tool_calls: None,
54 tool_call_id: None,
55 name: None,
56 }
57 }
58
59 pub fn system(content: impl Into<String>) -> Self {
61 Self::new(TrainingRole::System, content)
62 }
63
64 pub fn user(content: impl Into<String>) -> Self {
66 Self::new(TrainingRole::User, content)
67 }
68
69 pub fn assistant(content: impl Into<String>) -> Self {
71 Self::new(TrainingRole::Assistant, content)
72 }
73
74 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
76 Self {
77 role: TrainingRole::Tool,
78 content: content.into(),
79 tool_calls: None,
80 tool_call_id: Some(tool_call_id.into()),
81 name: None,
82 }
83 }
84
85 pub fn estimated_tokens(&self) -> usize {
87 self.content.len() / 4 + 1
88 }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct TrainingExample {
94 #[serde(default = "generate_id")]
96 pub id: String,
97 pub messages: Vec<TrainingMessage>,
99 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
101 pub metadata: HashMap<String, serde_json::Value>,
102}
103
104fn generate_id() -> String {
105 uuid::Uuid::new_v4().to_string()
106}
107
108impl TrainingExample {
109 pub fn new(messages: Vec<TrainingMessage>) -> Self {
111 Self {
112 id: generate_id(),
113 messages,
114 metadata: HashMap::new(),
115 }
116 }
117
118 pub fn with_id(id: impl Into<String>, messages: Vec<TrainingMessage>) -> Self {
120 Self {
121 id: id.into(),
122 messages,
123 metadata: HashMap::new(),
124 }
125 }
126
127 pub fn estimated_tokens(&self) -> usize {
129 self.messages.iter().map(|m| m.estimated_tokens()).sum()
130 }
131
132 pub fn has_system_message(&self) -> bool {
134 self.messages.iter().any(|m| m.role == TrainingRole::System)
135 }
136
137 pub fn ends_with_assistant(&self) -> bool {
139 self.messages
140 .last()
141 .is_some_and(|m| m.role == TrainingRole::Assistant)
142 }
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct PreferencePair {
148 #[serde(default = "generate_id")]
150 pub id: String,
151 pub prompt: Vec<TrainingMessage>,
153 pub chosen: Vec<TrainingMessage>,
155 pub rejected: Vec<TrainingMessage>,
157 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
159 pub metadata: HashMap<String, serde_json::Value>,
160}
161
162impl PreferencePair {
163 pub fn new(
165 prompt: Vec<TrainingMessage>,
166 chosen: Vec<TrainingMessage>,
167 rejected: Vec<TrainingMessage>,
168 ) -> Self {
169 Self {
170 id: generate_id(),
171 prompt,
172 chosen,
173 rejected,
174 metadata: HashMap::new(),
175 }
176 }
177
178 pub fn estimated_tokens(&self) -> usize {
180 let prompt_tokens: usize = self.prompt.iter().map(|m| m.estimated_tokens()).sum();
181 let chosen_tokens: usize = self.chosen.iter().map(|m| m.estimated_tokens()).sum();
182 let rejected_tokens: usize = self.rejected.iter().map(|m| m.estimated_tokens()).sum();
183 prompt_tokens + chosen_tokens + rejected_tokens
184 }
185}
186
187#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
189#[serde(rename_all = "lowercase")]
190pub enum DataFormat {
191 OpenAI,
193 Together,
195 Alpaca,
197 ShareGpt,
199 ChatMl,
201}
202
203impl std::fmt::Display for DataFormat {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 match self {
206 Self::OpenAI => write!(f, "openai"),
207 Self::Together => write!(f, "together"),
208 Self::Alpaca => write!(f, "alpaca"),
209 Self::ShareGpt => write!(f, "sharegpt"),
210 Self::ChatMl => write!(f, "chatml"),
211 }
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_training_message_creation() {
        let msg = TrainingMessage::system("You are a helpful assistant.");
        assert_eq!(msg.role, TrainingRole::System);
        assert_eq!(msg.content, "You are a helpful assistant.");
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_tool_message_creation() {
        let msg = TrainingMessage::tool("42", "call_1");
        assert_eq!(msg.role, TrainingRole::Tool);
        assert_eq!(msg.content, "42");
        assert_eq!(msg.tool_call_id.as_deref(), Some("call_1"));
        assert!(msg.tool_calls.is_none());
        assert!(msg.name.is_none());
    }

    #[test]
    fn test_training_example() {
        let example = TrainingExample::new(vec![
            TrainingMessage::system("You are helpful."),
            TrainingMessage::user("Hello"),
            TrainingMessage::assistant("Hi there!"),
        ]);
        assert_eq!(example.messages.len(), 3);
        assert!(example.has_system_message());
        assert!(example.ends_with_assistant());
        assert!(example.estimated_tokens() > 0);
    }

    #[test]
    fn test_empty_training_example() {
        let example = TrainingExample::new(Vec::new());
        assert!(!example.id.is_empty());
        assert!(!example.has_system_message());
        assert!(!example.ends_with_assistant());
        assert_eq!(example.estimated_tokens(), 0);
    }

    #[test]
    fn test_training_example_with_id() {
        let example = TrainingExample::with_id("ex-1", vec![TrainingMessage::user("Hi")]);
        assert_eq!(example.id, "ex-1");
        assert!(example.metadata.is_empty());
        assert!(!example.ends_with_assistant());
    }

    #[test]
    fn test_preference_pair() {
        let pair = PreferencePair::new(
            vec![TrainingMessage::user("What is 2+2?")],
            vec![TrainingMessage::assistant("4")],
            vec![TrainingMessage::assistant("22")],
        );
        assert_eq!(pair.prompt.len(), 1);
        assert_eq!(pair.chosen.len(), 1);
        assert_eq!(pair.rejected.len(), 1);
        // 12 bytes -> 12/4 + 1 = 4; "4" -> 1; "22" -> 1.
        assert_eq!(pair.estimated_tokens(), 6);
    }

    #[test]
    fn test_training_role_display() {
        assert_eq!(TrainingRole::System.to_string(), "system");
        assert_eq!(TrainingRole::User.to_string(), "user");
        assert_eq!(TrainingRole::Assistant.to_string(), "assistant");
        assert_eq!(TrainingRole::Tool.to_string(), "tool");
    }

    #[test]
    fn test_data_format_display() {
        assert_eq!(DataFormat::OpenAI.to_string(), "openai");
        assert_eq!(DataFormat::Together.to_string(), "together");
        assert_eq!(DataFormat::Alpaca.to_string(), "alpaca");
        assert_eq!(DataFormat::ShareGpt.to_string(), "sharegpt");
        assert_eq!(DataFormat::ChatMl.to_string(), "chatml");
    }

    #[test]
    fn test_display_matches_serde_names() {
        // `Display` and `#[serde(rename_all = "lowercase")]` must stay in sync.
        for role in [
            TrainingRole::System,
            TrainingRole::User,
            TrainingRole::Assistant,
            TrainingRole::Tool,
        ] {
            let json = serde_json::to_string(&role).unwrap();
            assert_eq!(json, format!("\"{role}\""));
        }
        for data_format in [
            DataFormat::OpenAI,
            DataFormat::Together,
            DataFormat::Alpaca,
            DataFormat::ShareGpt,
            DataFormat::ChatMl,
        ] {
            let json = serde_json::to_string(&data_format).unwrap();
            assert_eq!(json, format!("\"{data_format}\""));
        }
    }

    #[test]
    fn test_training_message_serialization() {
        let msg = TrainingMessage::assistant("Hello world");
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: TrainingMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.role, TrainingRole::Assistant);
        assert_eq!(parsed.content, "Hello world");
    }

    #[test]
    fn test_optional_fields_skipped_when_serializing() {
        let json = serde_json::to_string(&TrainingMessage::user("Hi")).unwrap();
        assert_eq!(json, r#"{"role":"user","content":"Hi"}"#);
    }

    #[test]
    fn test_training_example_deserialize_defaults() {
        let example: TrainingExample =
            serde_json::from_str(r#"{"messages":[{"role":"user","content":"Hi"}]}"#).unwrap();
        assert!(!example.id.is_empty());
        assert!(example.metadata.is_empty());
        assert_eq!(example.messages.len(), 1);
        assert_eq!(example.messages[0].role, TrainingRole::User);
    }
}
275}