use serde::Deserialize;
use std::path::Path;
/// One instruction-tuning example, deserialised from a single JSONL line
/// (see `load_instruct_corpus`).
///
/// `instruction` and `response` are required JSON keys; `system` and
/// `metadata` become `None` when the keys are absent.
#[derive(Debug, Clone, Deserialize)]
pub struct InstructSample {
/// User-side request text; rendered as the `user` turn by `format_chat_prompt`.
/// `load_instruct_corpus` rejects it when empty or whitespace-only.
pub instruction: String,
/// Expected answer text; `format_chat_prompt` appends `<|im_end|>` to it.
/// `load_instruct_corpus` rejects it when empty or whitespace-only.
pub response: String,
/// Optional per-sample system prompt; `format_chat_prompt` falls back to a
/// built-in default when this is `None`.
#[serde(default)]
pub system: Option<String>,
/// Optional annotations; only `metadata.source` is read in this file
/// (by `instruct_corpus_stats`).
#[serde(default)]
pub metadata: Option<InstructMetadata>,
}
/// Optional annotations attached to an [`InstructSample`].
///
/// Every field may be omitted in the JSON; missing fields take their
/// `Default` value (`None` / empty `Vec`).
#[derive(Debug, Clone, Default, Deserialize)]
pub struct InstructMetadata {
/// Free-form origin label; `instruct_corpus_stats` collects the distinct values.
#[serde(default)]
pub source: Option<String>,
/// List of library names — presumably libraries the sample's code uses;
/// NOTE(review): confirm with the data producer.
#[serde(default)]
pub libraries: Vec<String>,
/// Optional complexity score; its scale is not defined in this file —
/// NOTE(review): confirm the expected range.
#[serde(default)]
pub complexity: Option<u32>,
}
/// Renders an [`InstructSample`] as a ChatML-style `(prompt, completion)` pair.
///
/// The prompt contains a `system` turn (the sample's own system text, or a
/// built-in default when it has none), the `user` turn carrying
/// `instruction`, and an opened `assistant` turn. The completion is the
/// sample's `response` followed by the closing `<|im_end|>` marker.
///
/// Sample text is inserted verbatim; ChatML markers appearing inside it are
/// not escaped.
#[must_use]
pub fn format_chat_prompt(sample: &InstructSample) -> (String, String) {
    // Used only when the sample does not provide its own system prompt.
    const DEFAULT_SYSTEM: &str =
        "You are a helpful programming assistant. Write clean, correct, well-documented code.";

    let system_text = match sample.system.as_deref() {
        Some(custom) => custom,
        None => DEFAULT_SYSTEM,
    };

    let prompt = [
        format!("<|im_start|>system\n{system_text}<|im_end|>\n"),
        format!("<|im_start|>user\n{}<|im_end|>\n", sample.instruction),
        String::from("<|im_start|>assistant\n"),
    ]
    .concat();

    let completion = [sample.response.as_str(), "<|im_end|>"].concat();

    (prompt, completion)
}
/// Aggregate statistics over a slice of [`InstructSample`]s, as produced by
/// `instruct_corpus_stats`.
#[derive(Debug, Clone)]
pub struct InstructCorpusStats {
/// Number of samples.
pub total: usize,
/// Mean `instruction` length in bytes (`String::len`), using integer
/// division; 0 for an empty corpus.
pub avg_instruction_len: usize,
/// Mean `response` length in bytes (`String::len`), using integer
/// division; 0 for an empty corpus.
pub avg_response_len: usize,
/// Number of samples whose `system` field is `Some`.
pub with_system: usize,
/// Distinct `metadata.source` values, sorted ascending.
pub sources: Vec<String>,
}
pub fn load_instruct_corpus(path: &Path) -> crate::Result<Vec<InstructSample>> {
let content = std::fs::read_to_string(path)
.map_err(|e| crate::Error::Io(format!("Corpus file not found: {}: {e}", path.display())))?;
let mut samples = Vec::new();
for (line_num, line) in content.lines().enumerate() {
let line = line.trim();
if line.is_empty() {
continue;
}
let sample: InstructSample = serde_json::from_str(line).map_err(|e| {
crate::Error::ConfigError(format!("Invalid JSONL at line {}: {e}", line_num + 1))
})?;
if sample.instruction.trim().is_empty() {
return Err(crate::Error::ConfigError(format!(
"F-INST-001: empty instruction at line {}",
line_num + 1,
)));
}
if sample.response.trim().is_empty() {
return Err(crate::Error::ConfigError(format!(
"F-INST-001: empty response at line {}",
line_num + 1,
)));
}
samples.push(sample);
}
Ok(samples)
}
/// Summarises a slice of [`InstructSample`]s into an [`InstructCorpusStats`].
///
/// Lengths are measured in bytes (`String::len`) and averaged with integer
/// division. An empty slice yields all-zero counters and no sources.
/// `sources` holds every distinct `metadata.source` value in ascending order.
pub fn instruct_corpus_stats(samples: &[InstructSample]) -> InstructCorpusStats {
    let count = samples.len();

    // One pass accumulating instruction bytes, response bytes, and the number
    // of samples carrying their own system prompt.
    let (inst_bytes, resp_bytes, with_system) =
        samples
            .iter()
            .fold((0usize, 0usize, 0usize), |(inst, resp, sys), sample| {
                (
                    inst + sample.instruction.len(),
                    resp + sample.response.len(),
                    sys + usize::from(sample.system.is_some()),
                )
            });

    // A BTreeSet deduplicates and orders the labels in one step.
    let sources: Vec<String> = samples
        .iter()
        .filter_map(|sample| sample.metadata.as_ref().and_then(|meta| meta.source.clone()))
        .collect::<std::collections::BTreeSet<String>>()
        .into_iter()
        .collect();

    // `checked_div` is `None` only when `count == 0`, giving the empty-corpus zeros.
    InstructCorpusStats {
        total: count,
        avg_instruction_len: inst_bytes.checked_div(count).unwrap_or(0),
        avg_response_len: resp_bytes.checked_div(count).unwrap_or(0),
        with_system,
        sources,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_load_instruct_corpus() {
        let mut f = NamedTempFile::new().expect("valid");
        writeln!(
            f,
            r#"{{"instruction": "Write hello world", "response": "print('hello world')"}}"#
        )
        .expect("valid");
        writeln!(f, r#"{{"instruction": "Sort a list", "response": "sorted(lst)"}}"#)
            .expect("valid");
        let samples = load_instruct_corpus(f.path()).expect("valid");
        assert_eq!(samples.len(), 2);
        assert_eq!(samples[0].instruction, "Write hello world");
        assert_eq!(samples[1].response, "sorted(lst)");
        // Optional keys that are absent deserialize to `None`.
        assert!(samples[0].system.is_none());
        assert!(samples[0].metadata.is_none());
    }

    #[test]
    fn test_load_instruct_corpus_with_metadata() {
        let mut f = NamedTempFile::new().expect("valid");
        writeln!(
            f,
            r#"{{"instruction": "a", "response": "b", "system": "sys", "metadata": {{"source": "gh", "libraries": ["numpy"], "complexity": 3}}}}"#
        )
        .expect("valid");
        let samples = load_instruct_corpus(f.path()).expect("valid");
        assert_eq!(samples.len(), 1);
        assert_eq!(samples[0].system.as_deref(), Some("sys"));
        let meta = samples[0].metadata.as_ref().expect("metadata present");
        assert_eq!(meta.source.as_deref(), Some("gh"));
        assert_eq!(meta.libraries, vec!["numpy".to_string()]);
        assert_eq!(meta.complexity, Some(3));
    }

    #[test]
    fn test_empty_instruction_rejected() {
        let mut f = NamedTempFile::new().expect("valid");
        writeln!(f, r#"{{"instruction": "", "response": "some code"}}"#).expect("valid");
        let result = load_instruct_corpus(f.path());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("F-INST-001"));
    }

    #[test]
    fn test_empty_response_rejected() {
        let mut f = NamedTempFile::new().expect("valid");
        writeln!(f, r#"{{"instruction": "Do something", "response": " "}}"#).expect("valid");
        let result = load_instruct_corpus(f.path());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("F-INST-001"));
    }

    #[test]
    fn test_format_chat_prompt() {
        let sample = InstructSample {
            instruction: "Write a sort function".to_string(),
            response: "def sort(lst):\n return sorted(lst)".to_string(),
            system: None,
            metadata: None,
        };
        let (prompt, response) = format_chat_prompt(&sample);
        assert!(prompt.contains("<|im_start|>system"));
        assert!(prompt.contains("<|im_start|>user"));
        assert!(prompt.contains("Write a sort function"));
        assert!(prompt.ends_with("<|im_start|>assistant\n"));
        assert!(response.contains("def sort(lst)"));
        assert!(response.ends_with("<|im_end|>"));
    }

    #[test]
    fn test_format_chat_prompt_custom_system() {
        let sample = InstructSample {
            instruction: "test".to_string(),
            response: "ok".to_string(),
            system: Some("You are a Python expert.".to_string()),
            metadata: None,
        };
        let (prompt, _) = format_chat_prompt(&sample);
        assert!(prompt.contains("You are a Python expert."));
    }

    #[test]
    fn test_instruct_corpus_stats() {
        let samples = vec![
            InstructSample {
                instruction: "hello".to_string(),
                response: "world".to_string(),
                system: Some("sys".to_string()),
                metadata: Some(InstructMetadata {
                    source: Some("test".to_string()),
                    ..Default::default()
                }),
            },
            InstructSample {
                instruction: "foo".to_string(),
                response: "bar".to_string(),
                system: None,
                metadata: None,
            },
        ];
        let stats = instruct_corpus_stats(&samples);
        assert_eq!(stats.total, 2);
        // (5 + 3) / 2 bytes for both fields.
        assert_eq!(stats.avg_instruction_len, 4);
        assert_eq!(stats.avg_response_len, 4);
        assert_eq!(stats.with_system, 1);
        assert_eq!(stats.sources, vec!["test".to_string()]);
    }

    #[test]
    fn test_skip_empty_lines() {
        let mut f = NamedTempFile::new().expect("valid");
        writeln!(f, r#"{{"instruction": "a", "response": "b"}}"#).expect("valid");
        writeln!(f).expect("valid");
        // Whitespace-only lines are skipped as well.
        writeln!(f, "   ").expect("valid");
        writeln!(f, r#"{{"instruction": "c", "response": "d"}}"#).expect("valid");
        let samples = load_instruct_corpus(f.path()).expect("valid");
        assert_eq!(samples.len(), 2);
        assert_eq!(samples[1].instruction, "c");
    }

    #[test]
    fn test_invalid_json_rejected() {
        let mut f = NamedTempFile::new().expect("valid");
        writeln!(f, "not json").expect("valid");
        let result = load_instruct_corpus(f.path());
        assert!(result.is_err());
        // Line numbers in the message are 1-based.
        assert!(result.unwrap_err().to_string().contains("line 1"));
    }

    #[test]
    fn test_corpus_stats_empty() {
        let stats = instruct_corpus_stats(&[]);
        assert_eq!(stats.total, 0);
        assert_eq!(stats.avg_instruction_len, 0);
        assert_eq!(stats.avg_response_len, 0);
        assert_eq!(stats.with_system, 0);
        assert!(stats.sources.is_empty());
    }

    #[test]
    fn test_load_preference_pairs() {
        let mut f = NamedTempFile::new().expect("valid");
        writeln!(f, r#"{{"prompt": "p", "chosen": "good", "rejected": "bad"}}"#)
            .expect("valid");
        writeln!(f).expect("valid");
        let pairs = load_preference_pairs(f.path()).expect("valid");
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0].prompt, "p");
        assert_eq!(pairs[0].chosen, "good");
        assert_eq!(pairs[0].rejected, "bad");
    }

    #[test]
    fn test_preference_empty_field_rejected() {
        let mut f = NamedTempFile::new().expect("valid");
        writeln!(f, r#"{{"prompt": "p", "chosen": "", "rejected": "bad"}}"#).expect("valid");
        let err = load_preference_pairs(f.path()).expect_err("empty chosen must be rejected");
        assert!(err.contains("empty prompt/chosen/rejected"));
    }
}
/// One preference pair, deserialised from a single JSONL line by
/// `load_preference_pairs`. All three keys are required in the JSON.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct PreferenceSample {
/// Prompt shared by both candidate responses.
pub prompt: String,
/// Response labelled as chosen (the preferred one).
pub chosen: String,
/// Response labelled as rejected (the dispreferred one).
pub rejected: String,
}
/// Loads preference pairs from a JSONL file, one [`PreferenceSample`] per
/// non-blank line.
///
/// Blank and whitespace-only lines are skipped, and a UTF-8 byte-order mark at
/// the start of the file is tolerated. Line numbers in error messages are
/// 1-based, matching `load_instruct_corpus`.
///
/// # Errors
///
/// Returns a human-readable `String` if:
/// - the file cannot be opened or a line cannot be read;
/// - a line is not valid JSON for [`PreferenceSample`];
/// - any of `prompt`, `chosen` or `rejected` is empty or whitespace-only.
pub fn load_preference_pairs(path: &std::path::Path) -> Result<Vec<PreferenceSample>, String> {
    let file = std::fs::File::open(path).map_err(|e| format!("Open {}: {e}", path.display()))?;
    let reader = std::io::BufReader::new(file);
    let mut samples = Vec::new();

    for (idx, line) in std::io::BufRead::lines(reader).enumerate() {
        // Report 1-based line numbers, as editors and `load_instruct_corpus` do.
        let line_no = idx + 1;
        let line = line.map_err(|e| format!("Line {line_no}: {e}"))?;

        // A BOM can only legitimately appear at the very start of the file.
        let text = if idx == 0 {
            line.strip_prefix('\u{feff}').unwrap_or(line.as_str())
        } else {
            line.as_str()
        };
        if text.trim().is_empty() {
            continue;
        }

        let sample: PreferenceSample =
            serde_json::from_str(text).map_err(|e| format!("Line {line_no}: {e}"))?;

        // Whitespace-only values are rejected too, consistent with the
        // F-INST-001 check in `load_instruct_corpus`.
        let blank_field = [
            ("prompt", &sample.prompt),
            ("chosen", &sample.chosen),
            ("rejected", &sample.rejected),
        ]
        .iter()
        .find(|(_, value)| value.trim().is_empty())
        .map(|(name, _)| *name);

        if let Some(field) = blank_field {
            return Err(format!(
                "Line {line_no}: empty prompt/chosen/rejected (`{field}` is blank)"
            ));
        }

        samples.push(sample);
    }
    Ok(samples)
}