// brainwires_datasets/format/together.rs
use std::fmt::Write as _;

use serde_json::json;

use super::FormatConverter;
use crate::error::{DatasetError, DatasetResult};
use crate::types::{TrainingExample, TrainingMessage, TrainingRole};

/// Converter for the `together` dataset layout.
///
/// Two JSON record shapes are supported:
/// - chat: `{"messages": [{"role": ..., "content": ...}, ...]}`
/// - text: `{"text": "<s>...</s>"}` using `<<SYS>>` / `[INST]` markup
///
/// Parsing accepts either shape; this flag only selects the shape that is written.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TogetherFormat {
    /// When `true`, serialization emits the chat (`messages`) shape; when
    /// `false`, the conversation is flattened into a single `text` field.
    pub use_chat_format: bool,
}

18impl Default for TogetherFormat {
19 fn default() -> Self {
20 Self {
21 use_chat_format: true,
22 }
23 }
24}
25
26impl TogetherFormat {
27 pub fn chat() -> Self {
29 Self {
30 use_chat_format: true,
31 }
32 }
33
34 pub fn text() -> Self {
36 Self {
37 use_chat_format: false,
38 }
39 }
40
41 fn messages_to_text(messages: &[TrainingMessage]) -> String {
42 let mut text = String::new();
43 for msg in messages {
44 match msg.role {
45 TrainingRole::System => {
46 text.push_str(&format!("<<SYS>>\n{}\n<</SYS>>\n\n", msg.content));
47 }
48 TrainingRole::User => {
49 text.push_str(&format!("[INST] {} [/INST] ", msg.content));
50 }
51 TrainingRole::Assistant => {
52 text.push_str(&format!("{}\n", msg.content));
53 }
54 TrainingRole::Tool => {
55 text.push_str(&format!("[TOOL] {} [/TOOL] ", msg.content));
56 }
57 }
58 }
59 format!("<s>{}</s>", text.trim())
60 }
61}
62
63impl FormatConverter for TogetherFormat {
64 fn name(&self) -> &str {
65 "together"
66 }
67
68 fn to_json(&self, example: &TrainingExample) -> DatasetResult<serde_json::Value> {
69 if self.use_chat_format {
70 let messages: Vec<serde_json::Value> = example
72 .messages
73 .iter()
74 .map(|msg| {
75 json!({
76 "role": msg.role.to_string(),
77 "content": msg.content,
78 })
79 })
80 .collect();
81 Ok(json!({ "messages": messages }))
82 } else {
83 let text = Self::messages_to_text(&example.messages);
84 Ok(json!({ "text": text }))
85 }
86 }
87
88 fn parse_json(&self, value: &serde_json::Value) -> DatasetResult<TrainingExample> {
89 if let Some(messages) = value.get("messages") {
91 let arr = messages
92 .as_array()
93 .ok_or_else(|| DatasetError::FormatConversion {
94 message: "'messages' must be an array".to_string(),
95 })?;
96 let mut msgs = Vec::new();
97 for msg in arr {
98 let role = match msg.get("role").and_then(|v| v.as_str()) {
99 Some("system") => TrainingRole::System,
100 Some("user") => TrainingRole::User,
101 Some("assistant") => TrainingRole::Assistant,
102 Some("tool") => TrainingRole::Tool,
103 _ => {
104 return Err(DatasetError::FormatConversion {
105 message: "Invalid or missing role".to_string(),
106 });
107 }
108 };
109 let content = msg
110 .get("content")
111 .and_then(|v| v.as_str())
112 .unwrap_or("")
113 .to_string();
114 msgs.push(TrainingMessage::new(role, content));
115 }
116 Ok(TrainingExample::new(msgs))
117 } else if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
118 let mut messages = Vec::new();
120 let text = text
121 .trim_start_matches("<s>")
122 .trim_end_matches("</s>")
123 .trim();
124
125 if let Some(sys_start) = text.find("<<SYS>>")
127 && let Some(sys_end) = text.find("<</SYS>>")
128 {
129 let system_content = text[sys_start + 7..sys_end].trim().to_string();
130 messages.push(TrainingMessage::system(system_content));
131 }
132
133 let mut remaining = text;
135 while let Some(inst_start) = remaining.find("[INST]") {
136 if let Some(inst_end) = remaining.find("[/INST]") {
137 let user_content = remaining[inst_start + 6..inst_end].trim().to_string();
138 messages.push(TrainingMessage::user(user_content));
139
140 remaining = &remaining[inst_end + 7..];
141 let assistant_end = remaining.find("[INST]").unwrap_or(remaining.len());
143 let assistant_content = remaining[..assistant_end].trim().to_string();
144 if !assistant_content.is_empty() {
145 messages.push(TrainingMessage::assistant(assistant_content));
146 }
147 remaining = &remaining[assistant_end..];
148 } else {
149 break;
150 }
151 }
152
153 if messages.is_empty() {
154 return Err(DatasetError::FormatConversion {
155 message: "Could not parse Together text format".to_string(),
156 });
157 }
158
159 Ok(TrainingExample::new(messages))
160 } else {
161 Err(DatasetError::FormatConversion {
162 message: "Expected 'messages' or 'text' field".to_string(),
163 })
164 }
165 }
166}
167
168use super::PreferenceConverter;
169use crate::types::PreferencePair;
170
171impl PreferenceConverter for TogetherFormat {
172 fn name(&self) -> &str {
173 "together"
174 }
175
176 fn preference_to_json(&self, pair: &PreferencePair) -> DatasetResult<serde_json::Value> {
177 let to_msgs = |msgs: &[TrainingMessage]| -> Vec<serde_json::Value> {
178 msgs.iter()
179 .map(|msg| json!({ "role": msg.role.to_string(), "content": msg.content }))
180 .collect()
181 };
182
183 let mut result = json!({
184 "prompt": to_msgs(&pair.prompt),
185 "chosen": to_msgs(&pair.chosen),
186 "rejected": to_msgs(&pair.rejected),
187 });
188
189 if !pair.metadata.is_empty() {
190 result["metadata"] = json!(pair.metadata);
191 }
192
193 Ok(result)
194 }
195
196 fn parse_preference_json(&self, value: &serde_json::Value) -> DatasetResult<PreferencePair> {
197 let parse_msgs = |key: &str| -> DatasetResult<Vec<TrainingMessage>> {
198 let arr = value.get(key).and_then(|v| v.as_array()).ok_or_else(|| {
199 DatasetError::FormatConversion {
200 message: format!("Missing or invalid '{}' field", key),
201 }
202 })?;
203 let mut msgs = Vec::new();
204 for msg in arr {
205 let role = match msg.get("role").and_then(|v| v.as_str()) {
206 Some("system") => TrainingRole::System,
207 Some("user") => TrainingRole::User,
208 Some("assistant") => TrainingRole::Assistant,
209 Some("tool") => TrainingRole::Tool,
210 _ => {
211 return Err(DatasetError::FormatConversion {
212 message: format!("Invalid role in '{}' messages", key),
213 });
214 }
215 };
216 let content = msg
217 .get("content")
218 .and_then(|v| v.as_str())
219 .unwrap_or("")
220 .to_string();
221 msgs.push(TrainingMessage::new(role, content));
222 }
223 Ok(msgs)
224 };
225
226 let prompt = parse_msgs("prompt")?;
227 let chosen = parse_msgs("chosen")?;
228 let rejected = parse_msgs("rejected")?;
229
230 let mut pair = PreferencePair::new(prompt, chosen, rejected);
231 if let Some(meta) = value.get("metadata").and_then(|v| v.as_object()) {
232 for (k, v) in meta {
233 pair.metadata.insert(k.clone(), v.clone());
234 }
235 }
236
237 Ok(pair)
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_together_chat_roundtrip() {
        let format = TogetherFormat::chat();
        let example = TrainingExample::new(vec![
            TrainingMessage::user("Hello"),
            TrainingMessage::assistant("Hi!"),
        ]);

        let json = format.to_json(&example).unwrap();
        assert!(json.get("messages").is_some());
        let parsed = format.parse_json(&json).unwrap();
        assert_eq!(parsed.messages.len(), 2);
        assert_eq!(parsed.messages[0].content, "Hello");
        assert_eq!(parsed.messages[1].content, "Hi!");
    }

    #[test]
    fn test_together_text_format() {
        let format = TogetherFormat::text();
        let example = TrainingExample::new(vec![
            TrainingMessage::system("Be helpful"),
            TrainingMessage::user("Hello"),
            TrainingMessage::assistant("Hi!"),
        ]);

        let json = format.to_json(&example).unwrap();
        let text = json["text"].as_str().unwrap();
        assert!(text.starts_with("<s>"));
        assert!(text.ends_with("</s>"));
        assert!(text.contains("<<SYS>>"));
        assert!(text.contains("[INST]"));
    }

    /// The text layout must parse back into the same sequence of messages,
    /// including multi-turn conversations.
    #[test]
    fn test_together_text_roundtrip() {
        let format = TogetherFormat::text();
        let example = TrainingExample::new(vec![
            TrainingMessage::system("Be helpful"),
            TrainingMessage::user("Hello"),
            TrainingMessage::assistant("Hi!"),
            TrainingMessage::user("Bye"),
            TrainingMessage::assistant("See you"),
        ]);

        let json = format.to_json(&example).unwrap();
        let parsed = format.parse_json(&json).unwrap();

        let expected = ["Be helpful", "Hello", "Hi!", "Bye", "See you"];
        assert_eq!(parsed.messages.len(), expected.len());
        for (msg, want) in parsed.messages.iter().zip(expected) {
            assert_eq!(msg.content, want);
        }
        assert!(matches!(parsed.messages[0].role, TrainingRole::System));
        assert!(matches!(parsed.messages[1].role, TrainingRole::User));
        assert!(matches!(parsed.messages[2].role, TrainingRole::Assistant));
    }

    #[test]
    fn test_together_parse_rejects_malformed_input() {
        let format = TogetherFormat::chat();
        // Neither `messages` nor `text`.
        assert!(format.parse_json(&serde_json::json!({ "prompt": "x" })).is_err());
        // `messages` is not an array.
        assert!(format
            .parse_json(&serde_json::json!({ "messages": "not an array" }))
            .is_err());
        // Unknown role.
        assert!(format
            .parse_json(&serde_json::json!({
                "messages": [{ "role": "narrator", "content": "x" }]
            }))
            .is_err());
        // Text with no recognisable markup.
        assert!(format
            .parse_json(&serde_json::json!({ "text": "no markup here" }))
            .is_err());
    }

    #[test]
    fn test_together_preference_roundtrip() {
        let format = TogetherFormat::chat();
        let pair = PreferencePair::new(
            vec![TrainingMessage::user("Q")],
            vec![TrainingMessage::assistant("Good")],
            vec![TrainingMessage::assistant("Bad")],
        );
        let json = format.preference_to_json(&pair).unwrap();
        let parsed = format.parse_preference_json(&json).unwrap();
        assert_eq!(parsed.prompt[0].content, "Q");
        assert_eq!(parsed.chosen[0].content, "Good");
        assert_eq!(parsed.rejected[0].content, "Bad");
    }
}