Skip to main content

brainwires_datasets/format/
together.rs

1use serde_json::json;
2
3use super::FormatConverter;
4use crate::error::{DatasetError, DatasetResult};
5use crate::types::{TrainingExample, TrainingMessage, TrainingRole};
6
/// Together AI fine-tuning format.
///
/// Two JSON record layouts are handled:
///
/// * **Chat** (preferred) — the OpenAI-compatible object `{"messages": [...]}`.
/// * **Text** — the whole conversation rendered into one instruction-template
///   string: `{"text": "<s>[INST] ... [/INST] ...</s>"}`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TogetherFormat {
    /// If true, use chat messages format (OpenAI-compatible). If false, use text template.
    pub use_chat_format: bool,
}
17
18impl Default for TogetherFormat {
19    fn default() -> Self {
20        Self {
21            use_chat_format: true,
22        }
23    }
24}
25
26impl TogetherFormat {
27    /// Create a Together format using chat messages (OpenAI-compatible).
28    pub fn chat() -> Self {
29        Self {
30            use_chat_format: true,
31        }
32    }
33
34    /// Create a Together format using text template wrapping.
35    pub fn text() -> Self {
36        Self {
37            use_chat_format: false,
38        }
39    }
40
41    fn messages_to_text(messages: &[TrainingMessage]) -> String {
42        let mut text = String::new();
43        for msg in messages {
44            match msg.role {
45                TrainingRole::System => {
46                    text.push_str(&format!("<<SYS>>\n{}\n<</SYS>>\n\n", msg.content));
47                }
48                TrainingRole::User => {
49                    text.push_str(&format!("[INST] {} [/INST] ", msg.content));
50                }
51                TrainingRole::Assistant => {
52                    text.push_str(&format!("{}\n", msg.content));
53                }
54                TrainingRole::Tool => {
55                    text.push_str(&format!("[TOOL] {} [/TOOL] ", msg.content));
56                }
57            }
58        }
59        format!("<s>{}</s>", text.trim())
60    }
61}
62
63impl FormatConverter for TogetherFormat {
64    fn name(&self) -> &str {
65        "together"
66    }
67
68    fn to_json(&self, example: &TrainingExample) -> DatasetResult<serde_json::Value> {
69        if self.use_chat_format {
70            // Same as OpenAI format
71            let messages: Vec<serde_json::Value> = example
72                .messages
73                .iter()
74                .map(|msg| {
75                    json!({
76                        "role": msg.role.to_string(),
77                        "content": msg.content,
78                    })
79                })
80                .collect();
81            Ok(json!({ "messages": messages }))
82        } else {
83            let text = Self::messages_to_text(&example.messages);
84            Ok(json!({ "text": text }))
85        }
86    }
87
88    fn parse_json(&self, value: &serde_json::Value) -> DatasetResult<TrainingExample> {
89        // Prefer chat format parsing
90        if let Some(messages) = value.get("messages") {
91            let arr = messages
92                .as_array()
93                .ok_or_else(|| DatasetError::FormatConversion {
94                    message: "'messages' must be an array".to_string(),
95                })?;
96            let mut msgs = Vec::new();
97            for msg in arr {
98                let role = match msg.get("role").and_then(|v| v.as_str()) {
99                    Some("system") => TrainingRole::System,
100                    Some("user") => TrainingRole::User,
101                    Some("assistant") => TrainingRole::Assistant,
102                    Some("tool") => TrainingRole::Tool,
103                    _ => {
104                        return Err(DatasetError::FormatConversion {
105                            message: "Invalid or missing role".to_string(),
106                        });
107                    }
108                };
109                let content = msg
110                    .get("content")
111                    .and_then(|v| v.as_str())
112                    .unwrap_or("")
113                    .to_string();
114                msgs.push(TrainingMessage::new(role, content));
115            }
116            Ok(TrainingExample::new(msgs))
117        } else if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
118            // Basic text format parsing — extract user/assistant turns
119            let mut messages = Vec::new();
120            let text = text
121                .trim_start_matches("<s>")
122                .trim_end_matches("</s>")
123                .trim();
124
125            // Extract system message if present
126            if let Some(sys_start) = text.find("<<SYS>>")
127                && let Some(sys_end) = text.find("<</SYS>>")
128            {
129                let system_content = text[sys_start + 7..sys_end].trim().to_string();
130                messages.push(TrainingMessage::system(system_content));
131            }
132
133            // Extract [INST]...[/INST] pairs
134            let mut remaining = text;
135            while let Some(inst_start) = remaining.find("[INST]") {
136                if let Some(inst_end) = remaining.find("[/INST]") {
137                    let user_content = remaining[inst_start + 6..inst_end].trim().to_string();
138                    messages.push(TrainingMessage::user(user_content));
139
140                    remaining = &remaining[inst_end + 7..];
141                    // Everything until next [INST] or end is assistant
142                    let assistant_end = remaining.find("[INST]").unwrap_or(remaining.len());
143                    let assistant_content = remaining[..assistant_end].trim().to_string();
144                    if !assistant_content.is_empty() {
145                        messages.push(TrainingMessage::assistant(assistant_content));
146                    }
147                    remaining = &remaining[assistant_end..];
148                } else {
149                    break;
150                }
151            }
152
153            if messages.is_empty() {
154                return Err(DatasetError::FormatConversion {
155                    message: "Could not parse Together text format".to_string(),
156                });
157            }
158
159            Ok(TrainingExample::new(messages))
160        } else {
161            Err(DatasetError::FormatConversion {
162                message: "Expected 'messages' or 'text' field".to_string(),
163            })
164        }
165    }
166}
167
168use super::PreferenceConverter;
169use crate::types::PreferencePair;
170
171impl PreferenceConverter for TogetherFormat {
172    fn name(&self) -> &str {
173        "together"
174    }
175
176    fn preference_to_json(&self, pair: &PreferencePair) -> DatasetResult<serde_json::Value> {
177        let to_msgs = |msgs: &[TrainingMessage]| -> Vec<serde_json::Value> {
178            msgs.iter()
179                .map(|msg| json!({ "role": msg.role.to_string(), "content": msg.content }))
180                .collect()
181        };
182
183        let mut result = json!({
184            "prompt": to_msgs(&pair.prompt),
185            "chosen": to_msgs(&pair.chosen),
186            "rejected": to_msgs(&pair.rejected),
187        });
188
189        if !pair.metadata.is_empty() {
190            result["metadata"] = json!(pair.metadata);
191        }
192
193        Ok(result)
194    }
195
196    fn parse_preference_json(&self, value: &serde_json::Value) -> DatasetResult<PreferencePair> {
197        let parse_msgs = |key: &str| -> DatasetResult<Vec<TrainingMessage>> {
198            let arr = value.get(key).and_then(|v| v.as_array()).ok_or_else(|| {
199                DatasetError::FormatConversion {
200                    message: format!("Missing or invalid '{}' field", key),
201                }
202            })?;
203            let mut msgs = Vec::new();
204            for msg in arr {
205                let role = match msg.get("role").and_then(|v| v.as_str()) {
206                    Some("system") => TrainingRole::System,
207                    Some("user") => TrainingRole::User,
208                    Some("assistant") => TrainingRole::Assistant,
209                    Some("tool") => TrainingRole::Tool,
210                    _ => {
211                        return Err(DatasetError::FormatConversion {
212                            message: format!("Invalid role in '{}' messages", key),
213                        });
214                    }
215                };
216                let content = msg
217                    .get("content")
218                    .and_then(|v| v.as_str())
219                    .unwrap_or("")
220                    .to_string();
221                msgs.push(TrainingMessage::new(role, content));
222            }
223            Ok(msgs)
224        };
225
226        let prompt = parse_msgs("prompt")?;
227        let chosen = parse_msgs("chosen")?;
228        let rejected = parse_msgs("rejected")?;
229
230        let mut pair = PreferencePair::new(prompt, chosen, rejected);
231        if let Some(meta) = value.get("metadata").and_then(|v| v.as_object()) {
232            for (k, v) in meta {
233                pair.metadata.insert(k.clone(), v.clone());
234            }
235        }
236
237        Ok(pair)
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Chat layout: serializing and parsing back must preserve order, roles and content.
    #[test]
    fn test_together_chat_roundtrip() {
        let format = TogetherFormat::chat();
        let example = TrainingExample::new(vec![
            TrainingMessage::user("Hello"),
            TrainingMessage::assistant("Hi!"),
        ]);

        let json = format.to_json(&example).unwrap();
        assert!(json["messages"].is_array());

        let parsed = format.parse_json(&json).unwrap();
        assert_eq!(parsed.messages.len(), 2);
        assert!(matches!(parsed.messages[0].role, TrainingRole::User));
        assert_eq!(parsed.messages[0].content, "Hello");
        assert!(matches!(parsed.messages[1].role, TrainingRole::Assistant));
        assert_eq!(parsed.messages[1].content, "Hi!");
    }

    /// Text layout: checks the exact rendered template and that it parses back
    /// into the original system / user / assistant turns.
    #[test]
    fn test_together_text_format() {
        let format = TogetherFormat::text();
        let example = TrainingExample::new(vec![
            TrainingMessage::system("Be helpful"),
            TrainingMessage::user("Hello"),
            TrainingMessage::assistant("Hi!"),
        ]);

        let json = format.to_json(&example).unwrap();
        let text = json["text"].as_str().unwrap();
        assert!(text.starts_with("<s>"));
        assert!(text.ends_with("</s>"));
        assert!(text.contains("<<SYS>>"));
        assert!(text.contains("[INST]"));
        assert_eq!(
            text,
            "<s><<SYS>>\nBe helpful\n<</SYS>>\n\n[INST] Hello [/INST] Hi!</s>"
        );

        let parsed = format.parse_json(&json).unwrap();
        assert_eq!(parsed.messages.len(), 3);
        assert!(matches!(parsed.messages[0].role, TrainingRole::System));
        assert_eq!(parsed.messages[0].content, "Be helpful");
        assert!(matches!(parsed.messages[1].role, TrainingRole::User));
        assert_eq!(parsed.messages[1].content, "Hello");
        assert!(matches!(parsed.messages[2].role, TrainingRole::Assistant));
        assert_eq!(parsed.messages[2].content, "Hi!");
    }

    /// A record with neither `messages` nor `text` must be rejected.
    #[test]
    fn test_together_rejects_unknown_shape() {
        let format = TogetherFormat::chat();
        let value = serde_json::json!({ "other": 1 });
        assert!(format.parse_json(&value).is_err());
    }

    /// Preference pairs: all three message lists must survive a round trip.
    #[test]
    fn test_together_preference_roundtrip() {
        use super::PreferenceConverter;
        use crate::types::PreferencePair;
        let format = TogetherFormat::chat();
        let pair = PreferencePair::new(
            vec![TrainingMessage::user("Q")],
            vec![TrainingMessage::assistant("Good")],
            vec![TrainingMessage::assistant("Bad")],
        );
        let json = format.preference_to_json(&pair).unwrap();
        let parsed = format.parse_preference_json(&json).unwrap();
        assert_eq!(parsed.prompt.len(), 1);
        assert_eq!(parsed.prompt[0].content, "Q");
        assert_eq!(parsed.chosen[0].content, "Good");
        assert_eq!(parsed.rejected[0].content, "Bad");
    }
}