1use std::path::PathBuf;
20
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23
24use crate::EvalCaseResult;
25use crate::report::{Reporter, ReporterError, ReporterOutput};
26use crate::types::Invocation;
27
/// Errors produced while converting scored traces into a training dataset.
#[derive(Debug, Error)]
pub enum ExportError {
    /// A record could not be serialized to JSON.
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// The requested format is not fully implemented.
    /// NOTE(review): none of the exporters in this module return this variant.
    #[error("format not fully implemented: {0:?}")]
    NotImplemented(TrainingFormat),
    /// Nothing was exported at the given threshold. The DPO exporter also returns
    /// this when no case yields a chosen/rejected pair.
    #[error("no traces passed the quality threshold ({threshold})")]
    NothingToExport { threshold: f32 },
}
43
/// Output format for training-data export.
///
/// Serialized in `snake_case`: `chat_ml_sft`, `dpo_pairs`, `share_gpt`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum TrainingFormat {
    /// One `{"messages": [...]}` ChatML object per trace (supervised fine-tuning).
    ChatMlSft,
    /// One `{"case_id", "chosen", "rejected"}` object per case (preference tuning).
    DpoPairs,
    /// One `{"conversations": [...]}` ShareGPT object per trace.
    ShareGpt,
}
58
/// Settings that control which traces are exported and how they are encoded.
#[derive(Debug, Clone)]
pub struct ExportOptions {
    /// Output format; [`export_traces`] dispatches on it.
    pub format: TrainingFormat,
    /// Minimum trace score (inclusive, compared with `>=`) required for export.
    pub quality_threshold: f32,
    /// Whether each exported record carries a `metadata` object (case id, score, ...).
    pub include_metadata: bool,
}
73
74impl Default for ExportOptions {
75 fn default() -> Self {
76 Self {
77 format: TrainingFormat::ChatMlSft,
78 quality_threshold: 0.0,
79 include_metadata: true,
80 }
81 }
82}
83
84impl ExportOptions {
85 #[must_use]
87 pub fn chatml_sft(quality_threshold: f32) -> Self {
88 Self {
89 format: TrainingFormat::ChatMlSft,
90 quality_threshold,
91 include_metadata: true,
92 }
93 }
94
95 #[must_use]
97 pub fn dpo_pairs(quality_threshold: f32) -> Self {
98 Self {
99 format: TrainingFormat::DpoPairs,
100 quality_threshold,
101 include_metadata: true,
102 }
103 }
104
105 #[must_use]
107 pub fn sharegpt() -> Self {
108 Self {
109 format: TrainingFormat::ShareGpt,
110 quality_threshold: 0.0,
111 include_metadata: true,
112 }
113 }
114}
115
/// An agent invocation paired with the score used to filter it for export.
#[derive(Debug, Clone)]
pub struct ScoredTrace {
    /// The recorded invocation that is converted into a training record.
    pub invocation: Invocation,
    /// Quality score compared against `ExportOptions::quality_threshold`.
    /// `ScoredTrace::from_case_result` sets it to the mean metric score.
    pub score: f64,
    /// Eval case identifier; the DPO exporter groups traces by this value.
    pub case_id: String,
}
131
132impl ScoredTrace {
133 #[must_use]
137 pub fn from_case_result(result: &EvalCaseResult) -> Self {
138 let score = if result.metric_results.is_empty() {
139 0.0
140 } else {
141 let sum: f64 = result.metric_results.iter().map(|m| m.score.value).sum();
142 #[allow(clippy::cast_precision_loss)]
143 let mean = sum / result.metric_results.len() as f64;
144 mean
145 };
146 Self {
147 invocation: result.invocation.clone(),
148 score,
149 case_id: result.case_id.clone(),
150 }
151 }
152}
153
/// Converts scored traces into a serialized training dataset.
pub trait TrainingExporter: Send + Sync {
    /// Serializes the traces selected by `opts` and returns the encoded bytes.
    ///
    /// # Errors
    ///
    /// The implementations in this module return [`ExportError::NothingToExport`]
    /// when nothing qualifies and [`ExportError::Serialization`] on JSON errors.
    fn export(&self, traces: &[ScoredTrace], opts: &ExportOptions) -> Result<Vec<u8>, ExportError>;
}
168
/// Exporter for [`TrainingFormat::ChatMlSft`]: writes one JSON object with a
/// `messages` array per qualifying trace, one object per line.
#[derive(Debug, Default, Clone, Copy)]
pub struct ChatMlExporter;
190
191impl TrainingExporter for ChatMlExporter {
192 fn export(&self, traces: &[ScoredTrace], opts: &ExportOptions) -> Result<Vec<u8>, ExportError> {
193 let threshold = f64::from(opts.quality_threshold);
194 let qualified: Vec<&ScoredTrace> = traces.iter().filter(|t| t.score >= threshold).collect();
195
196 if qualified.is_empty() {
197 return Err(ExportError::NothingToExport {
198 threshold: opts.quality_threshold,
199 });
200 }
201
202 let mut out = Vec::new();
203 for trace in qualified {
204 let record = build_chatml_record(trace, opts);
205 serde_json::to_writer(&mut out, &record)?;
206 out.push(b'\n');
207 }
208 Ok(out)
209 }
210}
211
/// One ChatML training example: the conversation plus optional metadata.
#[derive(Serialize)]
struct ChatMlRecord<'a> {
    messages: Vec<ChatMlMessage>,
    // Omitted from the JSON entirely when metadata export is disabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<ChatMlMetadata<'a>>,
}
220
/// A single chat message (`system`, `user` or `assistant` in this module).
#[derive(Serialize)]
struct ChatMlMessage {
    role: &'static str,
    content: String,
    // Present only on assistant messages that issued at least one tool call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ChatMlToolCall>>,
}
228
/// A tool call attached to an assistant message.
#[derive(Serialize)]
struct ChatMlToolCall {
    id: String,
    // Serialized under the JSON key `type`; always "function" here.
    #[serde(rename = "type")]
    call_type: &'static str,
    function: ChatMlFunction,
}
236
/// The function name and arguments of a tool call.
#[derive(Serialize)]
struct ChatMlFunction {
    name: String,
    // Arguments rendered to a string (the original value's `to_string()`).
    arguments: String,
}
242
/// Provenance attached to a ChatML record when `include_metadata` is set.
#[derive(Serialize)]
struct ChatMlMetadata<'a> {
    case_id: &'a str,
    score: f64,
    model_id: String,
    // Number of turns in the source invocation.
    turns: usize,
}
250
251fn build_chatml_record<'a>(trace: &'a ScoredTrace, opts: &ExportOptions) -> ChatMlRecord<'a> {
252 let inv = &trace.invocation;
253 let mut messages: Vec<ChatMlMessage> = Vec::new();
254
255 messages.push(ChatMlMessage {
260 role: "system",
261 content: String::new(),
262 tool_calls: None,
263 });
264
265 for turn in &inv.turns {
266 if turn.turn_index == 0 {
271 messages.push(ChatMlMessage {
272 role: "user",
273 content: String::new(), tool_calls: None,
275 });
276 }
277
278 let content = extract_assistant_text(&turn.assistant_message);
280 let tool_calls: Vec<ChatMlToolCall> = turn
281 .tool_calls
282 .iter()
283 .map(|tc| ChatMlToolCall {
284 id: tc.id.clone(),
285 call_type: "function",
286 function: ChatMlFunction {
287 name: tc.name.clone(),
288 arguments: tc.arguments.to_string(),
289 },
290 })
291 .collect();
292
293 messages.push(ChatMlMessage {
294 role: "assistant",
295 content,
296 tool_calls: if tool_calls.is_empty() {
297 None
298 } else {
299 Some(tool_calls)
300 },
301 });
302 }
303
304 if let Some(response) = &inv.final_response {
307 let needs_patch = messages
308 .last()
309 .is_some_and(|last| last.role == "assistant" && last.content.is_empty());
310 let needs_append = messages.last().is_some_and(|last| last.role != "assistant");
311
312 if needs_patch && !response.is_empty() {
313 if let Some(last_mut) = messages.last_mut() {
314 last_mut.content.clone_from(response);
315 }
316 } else if needs_append {
317 messages.push(ChatMlMessage {
318 role: "assistant",
319 content: response.clone(),
320 tool_calls: None,
321 });
322 }
323 }
324
325 let metadata = if opts.include_metadata {
326 Some(ChatMlMetadata {
327 case_id: &trace.case_id,
328 score: trace.score,
329 model_id: inv.model.model_id.clone(),
330 turns: inv.turns.len(),
331 })
332 } else {
333 None
334 };
335
336 ChatMlRecord { messages, metadata }
337}
338
339fn extract_assistant_text(msg: &swink_agent::AssistantMessage) -> String {
340 use swink_agent::ContentBlock;
341 msg.content
342 .iter()
343 .filter_map(|block| {
344 if let ContentBlock::Text { text } = block {
345 Some(text.as_str())
346 } else {
347 None
348 }
349 })
350 .collect::<Vec<_>>()
351 .join("")
352}
353
/// Exporter for [`TrainingFormat::DpoPairs`]: groups qualifying traces by
/// `case_id` and, for cases with at least two of them, emits the highest-scoring
/// trace as `chosen` and the lowest-scoring one as `rejected` (each rendered as
/// a ChatML record), one JSON object per line.
#[derive(Debug, Default, Clone, Copy)]
pub struct DpoExporter;
369
/// One DPO training example: the preferred (`chosen`) and dispreferred
/// (`rejected`) conversations for the same case, each a serialized ChatML record.
#[derive(Serialize)]
struct DpoPairRecord {
    case_id: String,
    chosen: serde_json::Value,
    rejected: serde_json::Value,
}
377
378impl TrainingExporter for DpoExporter {
379 fn export(&self, traces: &[ScoredTrace], opts: &ExportOptions) -> Result<Vec<u8>, ExportError> {
380 let threshold = f64::from(opts.quality_threshold);
381 let qualified: Vec<&ScoredTrace> = traces.iter().filter(|t| t.score >= threshold).collect();
382
383 let mut by_case: std::collections::HashMap<&str, Vec<&ScoredTrace>> =
385 std::collections::HashMap::new();
386 for trace in &qualified {
387 by_case
388 .entry(trace.case_id.as_str())
389 .or_default()
390 .push(trace);
391 }
392
393 let mut pairs: Vec<DpoPairRecord> = Vec::new();
394 for (case_id, mut group) in by_case {
395 if group.len() < 2 {
396 continue;
397 }
398 group.sort_by(|a, b| {
400 b.score
401 .partial_cmp(&a.score)
402 .unwrap_or(std::cmp::Ordering::Equal)
403 });
404 let chosen_trace = group[0];
405 let rejected_trace = group[group.len() - 1];
406
407 let chosen_record = build_chatml_record(chosen_trace, opts);
408 let rejected_record = build_chatml_record(rejected_trace, opts);
409
410 pairs.push(DpoPairRecord {
411 case_id: case_id.to_string(),
412 chosen: serde_json::to_value(chosen_record)?,
413 rejected: serde_json::to_value(rejected_record)?,
414 });
415 }
416
417 if pairs.is_empty() {
418 return Err(ExportError::NothingToExport {
419 threshold: opts.quality_threshold,
420 });
421 }
422
423 let mut out = Vec::new();
424 for pair in &pairs {
425 serde_json::to_writer(&mut out, pair)?;
426 out.push(b'\n');
427 }
428 Ok(out)
429 }
430}
431
/// Exporter for [`TrainingFormat::ShareGpt`]: writes one JSON object with a
/// `conversations` array per qualifying trace, one object per line.
#[derive(Debug, Default, Clone, Copy)]
pub struct ShareGptExporter;
447
/// One ShareGPT training example: the conversation plus optional metadata.
#[derive(Serialize)]
struct ShareGptRecord {
    conversations: Vec<ShareGptTurn>,
    // Omitted from the JSON entirely when metadata export is disabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<serde_json::Value>,
}
454
/// One ShareGPT conversation entry; `from` is "system", "human" or "gpt" here.
#[derive(Serialize)]
struct ShareGptTurn {
    from: &'static str,
    value: String,
}
460
461impl TrainingExporter for ShareGptExporter {
462 fn export(&self, traces: &[ScoredTrace], opts: &ExportOptions) -> Result<Vec<u8>, ExportError> {
463 let threshold = f64::from(opts.quality_threshold);
464 let qualified: Vec<&ScoredTrace> = traces.iter().filter(|t| t.score >= threshold).collect();
465
466 if qualified.is_empty() {
467 return Err(ExportError::NothingToExport {
468 threshold: opts.quality_threshold,
469 });
470 }
471
472 let mut out = Vec::new();
473 for trace in qualified {
474 let record = build_sharegpt_record(trace, opts);
475 serde_json::to_writer(&mut out, &record)?;
476 out.push(b'\n');
477 }
478 Ok(out)
479 }
480}
481
482fn build_sharegpt_record(trace: &ScoredTrace, opts: &ExportOptions) -> ShareGptRecord {
483 let inv = &trace.invocation;
484 let mut conversations: Vec<ShareGptTurn> = Vec::new();
485
486 conversations.push(ShareGptTurn {
488 from: "system",
489 value: String::new(),
490 });
491
492 for turn in &inv.turns {
493 if turn.turn_index == 0 {
494 conversations.push(ShareGptTurn {
495 from: "human",
496 value: String::new(), });
498 }
499 let content = extract_assistant_text(&turn.assistant_message);
500 conversations.push(ShareGptTurn {
501 from: "gpt",
502 value: content,
503 });
504 }
505
506 if let Some(response) = &inv.final_response {
508 let needs_patch = conversations
509 .last()
510 .is_some_and(|last| last.from == "gpt" && last.value.is_empty());
511 let needs_append = conversations.last().is_some_and(|last| last.from != "gpt");
512
513 if needs_patch && !response.is_empty() {
514 if let Some(last_mut) = conversations.last_mut() {
515 last_mut.value.clone_from(response);
516 }
517 } else if needs_append {
518 conversations.push(ShareGptTurn {
519 from: "gpt",
520 value: response.clone(),
521 });
522 }
523 }
524
525 let metadata = if opts.include_metadata {
526 Some(serde_json::json!({
527 "case_id": trace.case_id,
528 "score": trace.score,
529 }))
530 } else {
531 None
532 };
533
534 ShareGptRecord {
535 conversations,
536 metadata,
537 }
538}
539
540pub fn export_traces(traces: &[ScoredTrace], opts: &ExportOptions) -> Result<Vec<u8>, ExportError> {
544 match opts.format {
545 TrainingFormat::ChatMlSft => ChatMlExporter.export(traces, opts),
546 TrainingFormat::DpoPairs => DpoExporter.export(traces, opts),
547 TrainingFormat::ShareGpt => ShareGptExporter.export(traces, opts),
548 }
549}
550
/// [`Reporter`] that turns an eval set's case results into a training-data
/// artifact destined for a fixed output path.
#[derive(Debug, Clone)]
pub struct TrainingReporter {
    opts: ExportOptions,
    output_path: PathBuf,
}
567
568impl TrainingReporter {
569 #[must_use]
571 pub fn new(opts: ExportOptions, output_path: impl Into<PathBuf>) -> Self {
572 Self {
573 opts,
574 output_path: output_path.into(),
575 }
576 }
577
578 #[must_use]
580 pub fn chatml_sft(quality_threshold: f32, output_path: impl Into<PathBuf>) -> Self {
581 Self::new(ExportOptions::chatml_sft(quality_threshold), output_path)
582 }
583
584 #[must_use]
586 pub fn dpo_pairs(quality_threshold: f32, output_path: impl Into<PathBuf>) -> Self {
587 Self::new(ExportOptions::dpo_pairs(quality_threshold), output_path)
588 }
589
590 #[must_use]
592 pub fn sharegpt(output_path: impl Into<PathBuf>) -> Self {
593 Self::new(ExportOptions::sharegpt(), output_path)
594 }
595}
596
597impl Reporter for TrainingReporter {
598 fn render(&self, result: &EvalSetResult) -> Result<ReporterOutput, ReporterError> {
599 let traces: Vec<ScoredTrace> = result
600 .case_results
601 .iter()
602 .map(ScoredTrace::from_case_result)
603 .collect();
604
605 let bytes =
606 export_traces(&traces, &self.opts).map_err(|e| ReporterError::Format(e.to_string()))?;
607
608 Ok(ReporterOutput::Artifact {
609 path: self.output_path.clone(),
610 bytes,
611 })
612 }
613}
614
615use crate::types::EvalSetResult;