use serde::{Deserialize, Serialize};
/// Tags which layout a `TrainingData` record follows.
///
/// No serde rename attribute is applied, so the variants are written and read
/// using their Rust names. `Sft` is the `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum TrainingFormat {
/// Single-response sample, as produced by `TrainingData::sft` and
/// `TrainingData::sft_simple` (presumably "supervised fine-tuning").
#[default]
Sft,
/// Preference pair with both a chosen and a rejected response, as produced
/// by `TrainingData::dpo` and `TrainingData::dpo_with_system` (presumably
/// "direct preference optimization").
Dpo,
}
/// Optional bookkeeping attached to a training sample.
///
/// Every field is optional; `Default` yields all `None` values and an empty
/// `custom` map. The exact meaning of each value is decided by whoever
/// produces the sample — this type only stores them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingMetadata {
/// Episode identifier, if one was recorded (set via `with_episode_id`).
pub episode_id: Option<String>,
/// Score associated with the sample, if any (set via `with_outcome_score`).
/// NOTE(review): scale and range are not defined in this file — confirm with producers.
pub outcome_score: Option<f64>,
/// Model name, if any (set via `with_model`).
pub model: Option<String>,
/// LoRA name, if any (set via `with_lora`, which may also clear it).
pub lora: Option<String>,
/// Strategy name, if any (set via `with_strategy`).
pub strategy_name: Option<String>,
/// Scenario name, if any (set via `with_scenario`).
pub scenario_name: Option<String>,
/// Free-form string key/value pairs. Left out of serialized output when
/// empty, and falls back to an empty map when absent from the input.
#[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
pub custom: std::collections::HashMap<String, String>,
}
impl TrainingMetadata {
    /// Creates metadata with nothing recorded yet.
    ///
    /// Equivalent to the `Default` value: every optional field is `None` and
    /// the `custom` map is empty.
    pub fn new() -> Self {
        Default::default()
    }
}
/// One training sample: a prompt, the preferred response, and — for
/// preference-pair samples — a rejected response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingData {
/// Optional system prompt. Left out of serialized output when `None` and
/// treated as `None` when absent from the input.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub system: Option<String>,
/// The prompt text. `is_valid` requires it to be non-empty.
pub prompt: String,
/// The preferred response. The SFT constructors store their single
/// `response` argument here. `is_valid` requires it to be non-empty.
pub chosen: String,
/// The rejected response; only the DPO constructors populate it. Left out
/// of serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub rejected: Option<String>,
/// Which layout this sample follows. No serde default is attached here, so
/// the field must be present when deserializing.
pub format: TrainingFormat,
/// Bookkeeping for the sample; falls back to `TrainingMetadata::default()`
/// when absent from the input.
#[serde(default)]
pub metadata: TrainingMetadata,
}
impl TrainingData {
    /// Creates an SFT sample that carries a system prompt.
    ///
    /// `response` is stored in `chosen`, `rejected` stays `None`, and the
    /// metadata starts at its default. The system prompt is stored as given,
    /// even when it is an empty string.
    pub fn sft(system: &str, prompt: &str, response: &str) -> Self {
        Self {
            system: Some(system.to_owned()),
            ..Self::sft_simple(prompt, response)
        }
    }

    /// Creates an SFT sample without a system prompt.
    pub fn sft_simple(prompt: &str, response: &str) -> Self {
        Self {
            format: TrainingFormat::Sft,
            system: None,
            prompt: prompt.to_owned(),
            chosen: response.to_owned(),
            rejected: None,
            metadata: TrainingMetadata::default(),
        }
    }

    /// Creates a DPO preference pair without a system prompt.
    pub fn dpo(prompt: &str, chosen: &str, rejected: &str) -> Self {
        Self {
            format: TrainingFormat::Dpo,
            system: None,
            prompt: prompt.to_owned(),
            chosen: chosen.to_owned(),
            rejected: Some(rejected.to_owned()),
            metadata: TrainingMetadata::default(),
        }
    }

    /// Creates a DPO preference pair that carries a system prompt.
    pub fn dpo_with_system(system: &str, prompt: &str, chosen: &str, rejected: &str) -> Self {
        Self {
            system: Some(system.to_owned()),
            ..Self::dpo(prompt, chosen, rejected)
        }
    }

    /// Records the episode identifier in the metadata.
    pub fn with_episode_id(mut self, episode_id: String) -> Self {
        let meta = &mut self.metadata;
        meta.episode_id = Some(episode_id);
        self
    }

    /// Records the outcome score in the metadata.
    pub fn with_outcome_score(mut self, score: f64) -> Self {
        let meta = &mut self.metadata;
        meta.outcome_score = Some(score);
        self
    }

    /// Records the model name in the metadata.
    pub fn with_model(mut self, model: &str) -> Self {
        let meta = &mut self.metadata;
        meta.model = Some(String::from(model));
        self
    }

    /// Sets the LoRA name in the metadata; passing `None` clears it.
    pub fn with_lora(mut self, lora: Option<String>) -> Self {
        let meta = &mut self.metadata;
        meta.lora = lora;
        self
    }

    /// Records the strategy name in the metadata.
    pub fn with_strategy(mut self, strategy: &str) -> Self {
        let meta = &mut self.metadata;
        meta.strategy_name = Some(String::from(strategy));
        self
    }

    /// Records the scenario name in the metadata.
    pub fn with_scenario(mut self, scenario: &str) -> Self {
        let meta = &mut self.metadata;
        meta.scenario_name = Some(String::from(scenario));
        self
    }

    /// Adds a custom key/value pair to the metadata, replacing any value
    /// already stored under the same key.
    pub fn with_custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let name: String = key.into();
        let text: String = value.into();
        self.metadata.custom.insert(name, text);
        self
    }

    /// Returns `true` when the sample is tagged as SFT.
    pub fn is_sft(&self) -> bool {
        self.format == TrainingFormat::Sft
    }

    /// Returns `true` when the sample is tagged as DPO.
    pub fn is_dpo(&self) -> bool {
        self.format == TrainingFormat::Dpo
    }

    /// Returns `true` when both `prompt` and `chosen` are non-empty.
    ///
    /// The format tag and the `rejected` field are not consulted.
    pub fn is_valid(&self) -> bool {
        !(self.prompt.is_empty() || self.chosen.is_empty())
    }

    /// Returns `true` when the sample passes `is_valid` and also has a
    /// non-empty `rejected` response.
    ///
    /// The format tag is not consulted.
    pub fn is_valid_dpo(&self) -> bool {
        if !self.is_valid() {
            return false;
        }
        matches!(self.rejected.as_deref(), Some(text) if !text.is_empty())
    }
}
/// A sample expressed as an ordered list of role-tagged turns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationData {
/// The turns, in order. The conversion from `&TrainingData` emits an
/// optional system turn, then the user turn, then the assistant turn.
pub conversations: Vec<ConversationTurn>,
/// Optional bookkeeping. Left out of serialized output when `None` and
/// treated as `None` when absent from the input.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<TrainingMetadata>,
}
/// A single message in a `ConversationData`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationTurn {
/// Who the message is attributed to.
pub role: ConversationRole,
/// The message text.
pub content: String,
}
/// Speaker of a `ConversationTurn`.
///
/// Because of `rename_all = "lowercase"`, the variants are written and read
/// as `"system"`, `"user"` and `"assistant"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ConversationRole {
/// Turn holding the system prompt.
System,
/// Turn holding the prompt.
User,
/// Turn holding the response.
Assistant,
}
impl From<&TrainingData> for ConversationData {
fn from(data: &TrainingData) -> Self {
let mut conversations = Vec::new();
if let Some(system) = &data.system {
conversations.push(ConversationTurn {
role: ConversationRole::System,
content: system.clone(),
});
}
conversations.push(ConversationTurn {
role: ConversationRole::User,
content: data.prompt.clone(),
});
conversations.push(ConversationTurn {
role: ConversationRole::Assistant,
content: data.chosen.clone(),
});
Self {
conversations,
metadata: Some(data.metadata.clone()),
}
}
}
impl TrainingData {
pub fn to_conversation(&self) -> ConversationData {
ConversationData::from(self)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `sft_simple` fills only prompt and response and yields a valid SFT sample.
    #[test]
    fn test_sft_simple() {
        let sample = TrainingData::sft_simple("What action?", "CheckStatus");

        assert_eq!(sample.prompt.as_str(), "What action?");
        assert_eq!(sample.chosen.as_str(), "CheckStatus");
        assert_eq!(sample.system, None);
        assert_eq!(sample.rejected, None);
        assert!(sample.is_sft());
        assert!(sample.is_valid());
    }

    /// `sft` additionally stores the system prompt.
    #[test]
    fn test_sft_with_system() {
        let sample = TrainingData::sft("You are an agent.", "What to do?", "CheckStatus");

        assert_eq!(sample.system.as_deref(), Some("You are an agent."));
        assert_eq!(sample.prompt.as_str(), "What to do?");
        assert_eq!(sample.chosen.as_str(), "CheckStatus");
        assert!(sample.is_sft());
    }

    /// `dpo` stores both responses and passes the DPO validity check.
    #[test]
    fn test_dpo() {
        let pair = TrainingData::dpo("What action?", "CheckStatus", "InvalidAction");

        assert_eq!(pair.chosen.as_str(), "CheckStatus");
        assert_eq!(pair.rejected.as_deref(), Some("InvalidAction"));
        assert!(pair.is_dpo());
        assert!(pair.is_valid_dpo());
    }

    /// Each builder writes to its own metadata slot.
    #[test]
    fn test_builder_methods() {
        let sample = TrainingData::sft_simple("prompt", "response")
            .with_episode_id("ep_001".to_string())
            .with_outcome_score(0.85)
            .with_model("qwen2.5")
            .with_lora(Some("my_lora".to_string()))
            .with_strategy("worker_action")
            .with_scenario("troubleshooting")
            .with_custom("key", "value");
        let meta = &sample.metadata;

        assert_eq!(meta.episode_id.as_deref(), Some("ep_001"));
        assert_eq!(meta.outcome_score, Some(0.85));
        assert_eq!(meta.model.as_deref(), Some("qwen2.5"));
        assert_eq!(meta.lora.as_deref(), Some("my_lora"));
        assert_eq!(meta.strategy_name.as_deref(), Some("worker_action"));
        assert_eq!(meta.scenario_name.as_deref(), Some("troubleshooting"));
        assert_eq!(meta.custom.get("key").map(String::as_str), Some("value"));
    }

    /// With a system prompt the conversation has system, user, assistant turns.
    #[test]
    fn test_to_conversation() {
        let sample = TrainingData::sft("System prompt", "User prompt", "Assistant response");
        let conv = sample.to_conversation();
        let turns: Vec<_> = conv
            .conversations
            .iter()
            .map(|turn| (turn.role, turn.content.as_str()))
            .collect();

        assert_eq!(
            turns,
            vec![
                (ConversationRole::System, "System prompt"),
                (ConversationRole::User, "User prompt"),
                (ConversationRole::Assistant, "Assistant response"),
            ]
        );
    }

    /// Without a system prompt only the user and assistant turns are emitted.
    #[test]
    fn test_to_conversation_no_system() {
        let sample = TrainingData::sft_simple("prompt", "response");
        let conv = sample.to_conversation();
        let roles: Vec<_> = conv.conversations.iter().map(|turn| turn.role).collect();

        assert_eq!(
            roles,
            vec![ConversationRole::User, ConversationRole::Assistant]
        );
    }

    /// A sample survives a JSON round trip.
    #[test]
    fn test_serialization() {
        let original =
            TrainingData::sft_simple("prompt", "response").with_episode_id("ep_001".to_string());

        let json = serde_json::to_string(&original).unwrap();
        let restored: TrainingData = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.prompt, original.prompt);
        assert_eq!(restored.chosen, original.chosen);
        assert_eq!(restored.metadata.episode_id, original.metadata.episode_id);
    }

    /// Conversation JSON uses the expected keys and lowercase role names.
    #[test]
    fn test_conversation_serialization() {
        let conv = TrainingData::sft("System", "User", "Assistant").to_conversation();
        let json = serde_json::to_string(&conv).unwrap();

        for needle in [
            "\"conversations\"",
            "\"role\"",
            "\"system\"",
            "\"user\"",
            "\"assistant\"",
        ] {
            assert!(json.contains(needle));
        }
    }
}