Skip to main content

entrenar/finetune/
instruct_corpus.rs

1//! Instruction-following corpus loader for generative fine-tuning (GH-371)
2//!
3//! Loads JSONL files with `{"instruction": "...", "response": "..."}` format
4//! for causal language model fine-tuning.
5//!
6//! # Contract
7//!
8//! - F-INST-001: Each sample must have non-empty instruction and response
9//! - F-INST-002: Total token count (prompt + response) must fit max_seq_len
10
11use serde::Deserialize;
12use std::path::Path;
13
14/// A single instruction-response training sample.
15#[derive(Debug, Clone, Deserialize)]
16pub struct InstructSample {
17    /// The instruction/prompt text
18    pub instruction: String,
19    /// The expected response/completion
20    pub response: String,
21    /// Optional system prompt override
22    #[serde(default)]
23    pub system: Option<String>,
24    /// Optional metadata (source corpus, complexity, etc.)
25    #[serde(default)]
26    pub metadata: Option<InstructMetadata>,
27}
28
29/// Optional metadata for an instruction sample.
30#[derive(Debug, Clone, Default, Deserialize)]
31pub struct InstructMetadata {
32    /// Source corpus name
33    #[serde(default)]
34    pub source: Option<String>,
35    /// Libraries used in the response
36    #[serde(default)]
37    pub libraries: Vec<String>,
38    /// Estimated complexity (1-10)
39    #[serde(default)]
40    pub complexity: Option<u32>,
41}
42
43/// Format an instruction sample as a Qwen chat prompt.
44///
45/// Uses the `<|im_start|>` / `<|im_end|>` template that Qwen2.5 models expect.
46///
47/// Returns (prompt_text, response_text) where:
48/// - prompt_text includes system + user + assistant prefix
49/// - response_text is the completion + `<|im_end|>`
50#[must_use]
51pub fn format_chat_prompt(sample: &InstructSample) -> (String, String) {
52    let system = sample.system.as_deref().unwrap_or(
53        "You are a helpful programming assistant. Write clean, correct, well-documented code.",
54    );
55
56    let prompt = format!(
57        "<|im_start|>system\n{system}<|im_end|>\n\
58         <|im_start|>user\n{}<|im_end|>\n\
59         <|im_start|>assistant\n",
60        sample.instruction
61    );
62
63    let response = format!("{}<|im_end|>", sample.response);
64
65    (prompt, response)
66}
67
/// Summary figures describing a set of instruction samples.
#[derive(Clone, Debug)]
pub struct InstructCorpusStats {
    /// Number of samples in the corpus.
    pub total: usize,
    /// Mean instruction length (chars).
    pub avg_instruction_len: usize,
    /// Mean response length (chars).
    pub avg_response_len: usize,
    /// How many samples carry their own system prompt.
    pub with_system: usize,
    /// Distinct source corpus names found in sample metadata.
    pub sources: Vec<String>,
}
82
83/// Load instruction corpus from JSONL file.
84///
85/// Each line is `{"instruction": "...", "response": "..."}`.
86///
87/// # Contract (F-INST-001)
88/// All samples must have non-empty instruction and response.
89///
90/// # Errors
91/// Returns error if file cannot be read or contains invalid samples.
92pub fn load_instruct_corpus(path: &Path) -> crate::Result<Vec<InstructSample>> {
93    let content = std::fs::read_to_string(path)
94        .map_err(|e| crate::Error::Io(format!("Corpus file not found: {}: {e}", path.display())))?;
95
96    let mut samples = Vec::new();
97    for (line_num, line) in content.lines().enumerate() {
98        let line = line.trim();
99        if line.is_empty() {
100            continue;
101        }
102        let sample: InstructSample = serde_json::from_str(line).map_err(|e| {
103            crate::Error::ConfigError(format!("Invalid JSONL at line {}: {e}", line_num + 1))
104        })?;
105
106        // F-INST-001: non-empty validation
107        if sample.instruction.trim().is_empty() {
108            return Err(crate::Error::ConfigError(format!(
109                "F-INST-001: empty instruction at line {}",
110                line_num + 1,
111            )));
112        }
113        if sample.response.trim().is_empty() {
114            return Err(crate::Error::ConfigError(format!(
115                "F-INST-001: empty response at line {}",
116                line_num + 1,
117            )));
118        }
119
120        samples.push(sample);
121    }
122
123    Ok(samples)
124}
125
126/// Compute corpus statistics.
127pub fn instruct_corpus_stats(samples: &[InstructSample]) -> InstructCorpusStats {
128    if samples.is_empty() {
129        return InstructCorpusStats {
130            total: 0,
131            avg_instruction_len: 0,
132            avg_response_len: 0,
133            with_system: 0,
134            sources: Vec::new(),
135        };
136    }
137
138    let total_inst_len: usize = samples.iter().map(|s| s.instruction.len()).sum();
139    let total_resp_len: usize = samples.iter().map(|s| s.response.len()).sum();
140    let with_system = samples.iter().filter(|s| s.system.is_some()).count();
141
142    let mut sources: Vec<String> =
143        samples.iter().filter_map(|s| s.metadata.as_ref()?.source.clone()).collect();
144    sources.sort();
145    sources.dedup();
146
147    InstructCorpusStats {
148        total: samples.len(),
149        avg_instruction_len: total_inst_len / samples.len(),
150        avg_response_len: total_resp_len / samples.len(),
151        with_system,
152        sources,
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file holding `lines`, each terminated by a newline.
    fn corpus_file(lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().expect("valid");
        for line in lines {
            writeln!(file, "{line}").expect("valid");
        }
        file
    }

    /// Builds a metadata-free sample with an optional system prompt.
    fn sample(instruction: &str, response: &str, system: Option<&str>) -> InstructSample {
        InstructSample {
            instruction: instruction.to_string(),
            response: response.to_string(),
            system: system.map(str::to_string),
            metadata: None,
        }
    }

    /// Asserts that loading `file` fails with an F-INST-001 error.
    fn assert_rejected_by_contract(file: &NamedTempFile) {
        let outcome = load_instruct_corpus(file.path());
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("F-INST-001"));
    }

    #[test]
    fn test_load_instruct_corpus() {
        let file = corpus_file(&[
            r#"{"instruction": "Write hello world", "response": "print('hello world')"}"#,
            r#"{"instruction": "Sort a list", "response": "sorted(lst)"}"#,
        ]);

        let loaded = load_instruct_corpus(file.path()).expect("valid");
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].instruction, "Write hello world");
        assert_eq!(loaded[1].response, "sorted(lst)");
    }

    #[test]
    fn test_empty_instruction_rejected() {
        let file = corpus_file(&[r#"{"instruction": "", "response": "some code"}"#]);
        assert_rejected_by_contract(&file);
    }

    #[test]
    fn test_empty_response_rejected() {
        // A whitespace-only response counts as empty.
        let file = corpus_file(&[r#"{"instruction": "Do something", "response": "  "}"#]);
        assert_rejected_by_contract(&file);
    }

    #[test]
    fn test_format_chat_prompt() {
        let input = sample(
            "Write a sort function",
            "def sort(lst):\n    return sorted(lst)",
            None,
        );

        let (prompt, response) = format_chat_prompt(&input);
        assert!(prompt.contains("<|im_start|>system"));
        assert!(prompt.contains("<|im_start|>user"));
        assert!(prompt.contains("Write a sort function"));
        assert!(prompt.ends_with("<|im_start|>assistant\n"));
        assert!(response.contains("def sort(lst)"));
        assert!(response.ends_with("<|im_end|>"));
    }

    #[test]
    fn test_format_chat_prompt_custom_system() {
        let input = sample("test", "ok", Some("You are a Python expert."));

        let (prompt, _) = format_chat_prompt(&input);
        assert!(prompt.contains("You are a Python expert."));
    }

    #[test]
    fn test_instruct_corpus_stats() {
        let mut tagged = sample("hello", "world", Some("sys"));
        tagged.metadata = Some(InstructMetadata {
            source: Some("test".to_string()),
            ..Default::default()
        });
        let corpus = vec![tagged, sample("foo", "bar", None)];

        let stats = instruct_corpus_stats(&corpus);
        assert_eq!(stats.total, 2);
        assert_eq!(stats.with_system, 1);
        assert_eq!(stats.sources, vec!["test".to_string()]);
    }

    #[test]
    fn test_skip_empty_lines() {
        // The middle entry produces a blank line between two valid records.
        let file = corpus_file(&[
            r#"{"instruction": "a", "response": "b"}"#,
            "",
            r#"{"instruction": "c", "response": "d"}"#,
        ]);

        let loaded = load_instruct_corpus(file.path()).expect("valid");
        assert_eq!(loaded.len(), 2);
    }

    #[test]
    fn test_invalid_json_rejected() {
        let file = corpus_file(&["not json"]);

        let outcome = load_instruct_corpus(file.path());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_corpus_stats_empty() {
        let stats = instruct_corpus_stats(&[]);
        assert_eq!(stats.total, 0);
        assert_eq!(stats.avg_instruction_len, 0);
    }
}
283
284/// A DPO preference training sample (prompt + chosen + rejected).
285/// Contract: dpo-alignment-v1 / preference_data_valid
286#[derive(Debug, Clone, serde::Deserialize)]
287pub struct PreferenceSample {
288    /// The prompt/instruction text
289    pub prompt: String,
290    /// The preferred (chosen) response
291    pub chosen: String,
292    /// The rejected response
293    pub rejected: String,
294}
295
296/// Load preference pairs from JSONL file.
297pub fn load_preference_pairs(path: &std::path::Path) -> Result<Vec<PreferenceSample>, String> {
298    let file = std::fs::File::open(path).map_err(|e| format!("Open {}: {e}", path.display()))?;
299    let reader = std::io::BufReader::new(file);
300    let mut samples = Vec::new();
301    for (i, line) in std::io::BufRead::lines(reader).enumerate() {
302        let line = line.map_err(|e| format!("Line {i}: {e}"))?;
303        if line.trim().is_empty() {
304            continue;
305        }
306        let sample: PreferenceSample =
307            serde_json::from_str(&line).map_err(|e| format!("Line {i}: {e}"))?;
308        // FALSIFY-DPO-002: validate non-empty fields
309        if sample.prompt.is_empty() || sample.chosen.is_empty() || sample.rejected.is_empty() {
310            return Err(format!("Line {i}: empty prompt/chosen/rejected"));
311        }
312        samples.push(sample);
313    }
314    Ok(samples)
315}