Skip to main content

brainwires_datasets/
types.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4/// Role in a training conversation.
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6#[serde(rename_all = "lowercase")]
7pub enum TrainingRole {
8    /// System prompt role.
9    System,
10    /// User message role.
11    User,
12    /// Assistant response role.
13    Assistant,
14    /// Tool output role.
15    Tool,
16}
17
18impl std::fmt::Display for TrainingRole {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        match self {
21            Self::System => write!(f, "system"),
22            Self::User => write!(f, "user"),
23            Self::Assistant => write!(f, "assistant"),
24            Self::Tool => write!(f, "tool"),
25        }
26    }
27}
28
29/// A single message in a training conversation.
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct TrainingMessage {
32    /// Role of the message sender.
33    pub role: TrainingRole,
34    /// Text content of the message.
35    pub content: String,
36    /// Optional tool calls made by the assistant.
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub tool_calls: Option<Vec<serde_json::Value>>,
39    /// ID of the tool call this message responds to.
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub tool_call_id: Option<String>,
42    /// Optional name of the sender.
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub name: Option<String>,
45}
46
47impl TrainingMessage {
48    /// Create a new training message with the given role and content.
49    pub fn new(role: TrainingRole, content: impl Into<String>) -> Self {
50        Self {
51            role,
52            content: content.into(),
53            tool_calls: None,
54            tool_call_id: None,
55            name: None,
56        }
57    }
58
59    /// Create a system message.
60    pub fn system(content: impl Into<String>) -> Self {
61        Self::new(TrainingRole::System, content)
62    }
63
64    /// Create a user message.
65    pub fn user(content: impl Into<String>) -> Self {
66        Self::new(TrainingRole::User, content)
67    }
68
69    /// Create an assistant message.
70    pub fn assistant(content: impl Into<String>) -> Self {
71        Self::new(TrainingRole::Assistant, content)
72    }
73
74    /// Create a tool response message.
75    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
76        Self {
77            role: TrainingRole::Tool,
78            content: content.into(),
79            tool_calls: None,
80            tool_call_id: Some(tool_call_id.into()),
81            name: None,
82        }
83    }
84
85    /// Estimated token count (rough: ~4 chars per token).
86    pub fn estimated_tokens(&self) -> usize {
87        self.content.len() / 4 + 1
88    }
89}
90
91/// A training example consisting of a multi-turn conversation.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct TrainingExample {
94    /// Unique identifier for this example.
95    #[serde(default = "generate_id")]
96    pub id: String,
97    /// Ordered list of messages in the conversation.
98    pub messages: Vec<TrainingMessage>,
99    /// Arbitrary metadata attached to this example.
100    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
101    pub metadata: HashMap<String, serde_json::Value>,
102}
103
104fn generate_id() -> String {
105    uuid::Uuid::new_v4().to_string()
106}
107
108impl TrainingExample {
109    /// Create a new training example with an auto-generated ID.
110    pub fn new(messages: Vec<TrainingMessage>) -> Self {
111        Self {
112            id: generate_id(),
113            messages,
114            metadata: HashMap::new(),
115        }
116    }
117
118    /// Create a new training example with a specific ID.
119    pub fn with_id(id: impl Into<String>, messages: Vec<TrainingMessage>) -> Self {
120        Self {
121            id: id.into(),
122            messages,
123            metadata: HashMap::new(),
124        }
125    }
126
127    /// Total estimated token count across all messages.
128    pub fn estimated_tokens(&self) -> usize {
129        self.messages.iter().map(|m| m.estimated_tokens()).sum()
130    }
131
132    /// Check if this example has a system message.
133    pub fn has_system_message(&self) -> bool {
134        self.messages.iter().any(|m| m.role == TrainingRole::System)
135    }
136
137    /// Check if the last message is from the assistant (completion target).
138    pub fn ends_with_assistant(&self) -> bool {
139        self.messages
140            .last()
141            .is_some_and(|m| m.role == TrainingRole::Assistant)
142    }
143}
144
145/// A preference pair for DPO/ORPO training.
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct PreferencePair {
148    /// Unique identifier for this preference pair.
149    #[serde(default = "generate_id")]
150    pub id: String,
151    /// The shared prompt messages.
152    pub prompt: Vec<TrainingMessage>,
153    /// The preferred (chosen) response messages.
154    pub chosen: Vec<TrainingMessage>,
155    /// The rejected response messages.
156    pub rejected: Vec<TrainingMessage>,
157    /// Arbitrary metadata attached to this pair.
158    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
159    pub metadata: HashMap<String, serde_json::Value>,
160}
161
162impl PreferencePair {
163    /// Create a new preference pair with an auto-generated ID.
164    pub fn new(
165        prompt: Vec<TrainingMessage>,
166        chosen: Vec<TrainingMessage>,
167        rejected: Vec<TrainingMessage>,
168    ) -> Self {
169        Self {
170            id: generate_id(),
171            prompt,
172            chosen,
173            rejected,
174            metadata: HashMap::new(),
175        }
176    }
177
178    /// Total estimated tokens for prompt + chosen + rejected.
179    pub fn estimated_tokens(&self) -> usize {
180        let prompt_tokens: usize = self.prompt.iter().map(|m| m.estimated_tokens()).sum();
181        let chosen_tokens: usize = self.chosen.iter().map(|m| m.estimated_tokens()).sum();
182        let rejected_tokens: usize = self.rejected.iter().map(|m| m.estimated_tokens()).sum();
183        prompt_tokens + chosen_tokens + rejected_tokens
184    }
185}
186
187/// Supported data formats for import/export.
188#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
189#[serde(rename_all = "lowercase")]
190pub enum DataFormat {
191    /// OpenAI fine-tuning format.
192    OpenAI,
193    /// Together AI fine-tuning format.
194    Together,
195    /// Alpaca instruction-following format.
196    Alpaca,
197    /// ShareGPT conversation format.
198    ShareGpt,
199    /// ChatML template format.
200    ChatMl,
201}
202
203impl std::fmt::Display for DataFormat {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        match self {
206            Self::OpenAI => write!(f, "openai"),
207            Self::Together => write!(f, "together"),
208            Self::Alpaca => write!(f, "alpaca"),
209            Self::ShareGpt => write!(f, "sharegpt"),
210            Self::ChatMl => write!(f, "chatml"),
211        }
212    }
213}
214
/// Unit tests for the dataset types: constructors, helper predicates,
/// `Display` output and the serde representation.
#[cfg(test)]
mod tests {
    use super::*;

    /// `system` sets role and content and leaves every optional field unset.
    #[test]
    fn test_training_message_creation() {
        let msg = TrainingMessage::system("You are a helpful assistant.");
        assert_eq!(msg.role, TrainingRole::System);
        assert_eq!(msg.content, "You are a helpful assistant.");
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());
        assert!(msg.name.is_none());
    }

    /// `tool` records the call ID it answers and uses the tool role.
    #[test]
    fn test_tool_message_creation() {
        let msg = TrainingMessage::tool("42", "call_1");
        assert_eq!(msg.role, TrainingRole::Tool);
        assert_eq!(msg.content, "42");
        assert_eq!(msg.tool_call_id.as_deref(), Some("call_1"));
        assert!(msg.tool_calls.is_none());
        assert!(msg.name.is_none());
    }

    /// The estimate is `len / 4 + 1`, measured in bytes of `content`.
    #[test]
    fn test_message_estimated_tokens() {
        assert_eq!(TrainingMessage::user("").estimated_tokens(), 1);
        assert_eq!(TrainingMessage::user("abc").estimated_tokens(), 1);
        assert_eq!(TrainingMessage::user("abcd").estimated_tokens(), 2);
    }

    #[test]
    fn test_training_example() {
        let example = TrainingExample::new(vec![
            TrainingMessage::system("You are helpful."),
            TrainingMessage::user("Hello"),
            TrainingMessage::assistant("Hi there!"),
        ]);
        assert_eq!(example.messages.len(), 3);
        assert!(!example.id.is_empty());
        assert!(example.metadata.is_empty());
        assert!(example.has_system_message());
        assert!(example.ends_with_assistant());
        assert!(example.estimated_tokens() > 0);
    }

    /// An example without messages has no system message, does not end with
    /// the assistant and estimates to zero tokens.
    #[test]
    fn test_training_example_with_id_and_empty_messages() {
        let example = TrainingExample::with_id("ex-1", vec![]);
        assert_eq!(example.id, "ex-1");
        assert!(!example.has_system_message());
        assert!(!example.ends_with_assistant());
        assert_eq!(example.estimated_tokens(), 0);
    }

    /// A conversation that ends on a user turn is not a completion target.
    #[test]
    fn test_training_example_not_ending_with_assistant() {
        let example = TrainingExample::new(vec![
            TrainingMessage::assistant("Hi"),
            TrainingMessage::user("Hello again"),
        ]);
        assert!(!example.has_system_message());
        assert!(!example.ends_with_assistant());
    }

    #[test]
    fn test_preference_pair() {
        let pair = PreferencePair::new(
            vec![TrainingMessage::user("What is 2+2?")],
            vec![TrainingMessage::assistant("4")],
            vec![TrainingMessage::assistant("22")],
        );
        assert_eq!(pair.prompt.len(), 1);
        assert_eq!(pair.chosen.len(), 1);
        assert_eq!(pair.rejected.len(), 1);
        assert!(!pair.id.is_empty());
        assert!(pair.metadata.is_empty());
        // 12 bytes -> 4, 1 byte -> 1, 2 bytes -> 1.
        assert_eq!(pair.estimated_tokens(), 6);
    }

    #[test]
    fn test_training_role_display() {
        assert_eq!(TrainingRole::System.to_string(), "system");
        assert_eq!(TrainingRole::User.to_string(), "user");
        assert_eq!(TrainingRole::Assistant.to_string(), "assistant");
        assert_eq!(TrainingRole::Tool.to_string(), "tool");
    }

    #[test]
    fn test_data_format_display() {
        assert_eq!(DataFormat::OpenAI.to_string(), "openai");
        assert_eq!(DataFormat::Together.to_string(), "together");
        assert_eq!(DataFormat::Alpaca.to_string(), "alpaca");
        assert_eq!(DataFormat::ShareGpt.to_string(), "sharegpt");
        assert_eq!(DataFormat::ChatMl.to_string(), "chatml");
    }

    /// The serde name of every enum variant equals its `Display` output.
    #[test]
    fn test_serde_names_match_display() {
        for role in [
            TrainingRole::System,
            TrainingRole::User,
            TrainingRole::Assistant,
            TrainingRole::Tool,
        ] {
            let json = serde_json::to_string(&role).unwrap();
            assert_eq!(json, format!("\"{}\"", role));
            let parsed: TrainingRole = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, role);
        }

        for format in [
            DataFormat::OpenAI,
            DataFormat::Together,
            DataFormat::Alpaca,
            DataFormat::ShareGpt,
            DataFormat::ChatMl,
        ] {
            let json = serde_json::to_string(&format).unwrap();
            assert_eq!(json, format!("\"{}\"", format));
            let parsed: DataFormat = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, format);
        }
    }

    #[test]
    fn test_training_message_serialization() {
        let msg = TrainingMessage::assistant("Hello world");
        let json = serde_json::to_string(&msg).unwrap();
        // Unset optional fields must not appear in the output.
        assert!(!json.contains("tool_calls"));
        assert!(!json.contains("tool_call_id"));
        assert!(!json.contains("\"name\""));
        let parsed: TrainingMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.role, TrainingRole::Assistant);
        assert_eq!(parsed.content, "Hello world");
        assert!(parsed.tool_calls.is_none());
        assert!(parsed.tool_call_id.is_none());
        assert!(parsed.name.is_none());
    }

    /// Empty metadata is omitted on output; a missing ID and missing metadata
    /// are filled in on input.
    #[test]
    fn test_training_example_serde_defaults() {
        let example = TrainingExample::with_id("ex-1", vec![]);
        let json = serde_json::to_string(&example).unwrap();
        assert_eq!(json, r#"{"id":"ex-1","messages":[]}"#);

        let parsed: TrainingExample =
            serde_json::from_str(r#"{"messages":[{"role":"user","content":"hi"}]}"#).unwrap();
        assert!(!parsed.id.is_empty());
        assert!(parsed.metadata.is_empty());
        assert_eq!(parsed.messages.len(), 1);
        assert_eq!(parsed.messages[0].role, TrainingRole::User);
        assert_eq!(parsed.messages[0].content, "hi");
    }
}