1use crate::citl::{SuggestionApplicability, TrainingSource};
13use crate::classifier::{DefectCategory, RuleBasedClassifier};
14use crate::git::CommitInfo;
15use anyhow::{anyhow, Result};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18
/// One labelled training example derived from a defect-fix commit.
///
/// Examples built by `TrainingDataExtractor::extract_training_data` leave the
/// diagnostic fields (`error_code`, `clippy_lint`, `has_suggestion`,
/// `suggestion_applicability`) empty and set `source` to
/// `TrainingSource::CommitMessage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    /// Full commit message the example was built from.
    pub message: String,
    /// Defect category assigned by the rule-based classifier.
    pub label: DefectCategory,
    /// Classifier confidence for `label`.
    pub confidence: f32,
    /// Hash of the source commit.
    pub commit_hash: String,
    /// Commit author, copied from `CommitInfo`.
    pub author: String,
    /// Commit timestamp, copied from `CommitInfo`
    /// (presumably Unix epoch seconds — TODO confirm against the git layer).
    pub timestamp: i64,
    /// Number of lines the commit added.
    pub lines_added: usize,
    /// Number of lines the commit removed.
    pub lines_removed: usize,
    /// Number of files the commit touched.
    pub files_changed: usize,

    // Every field below is `#[serde(default)]`, so serialized examples that
    // lack these fields still deserialize, falling back to `Default` values.
    /// Compiler error code attached to the example, if any
    /// (presumably a rustc code such as `E0382` — confirm with the CITL module).
    #[serde(default)]
    pub error_code: Option<String>,
    /// Clippy lint attached to the example, if any
    /// (presumably the lint name — confirm with the CITL module).
    #[serde(default)]
    pub clippy_lint: Option<String>,
    /// Whether the originating diagnostic carried a suggestion.
    #[serde(default)]
    pub has_suggestion: bool,
    /// Applicability of that suggestion, when one exists.
    #[serde(default)]
    pub suggestion_applicability: Option<SuggestionApplicability>,
    /// Where this example came from (commit message, diagnostic, ...).
    #[serde(default)]
    pub source: TrainingSource,
}
60
/// A dataset partitioned into train / validation / test splits, together with
/// summary metadata. Produced by `TrainingDataExtractor::create_splits`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDataset {
    /// Training split: the first ~70% of the input examples, in input order.
    pub train: Vec<TrainingExample>,
    /// Validation split: the next ~15% of the input examples.
    pub validation: Vec<TrainingExample>,
    /// Test split: all remaining input examples.
    pub test: Vec<TrainingExample>,
    /// Summary statistics computed over the whole (unsplit) input.
    pub metadata: DatasetMetadata,
}
73
/// Summary statistics describing a [`TrainingDataset`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetMetadata {
    /// Total number of examples across all three splits.
    pub total_examples: usize,
    /// Number of examples in the training split.
    pub train_size: usize,
    /// Number of examples in the validation split.
    pub validation_size: usize,
    /// Number of examples in the test split.
    pub test_size: usize,
    /// Example count per category over the whole dataset, keyed by the
    /// category's `Display` text.
    pub class_distribution: HashMap<String, usize>,
    /// Mean classifier confidence over the whole dataset.
    pub avg_confidence: f32,
    /// The extractor's configured confidence threshold (not the minimum
    /// confidence actually observed in the data).
    pub min_confidence: f32,
    /// Repository names supplied by the caller when the splits were created.
    pub repositories: Vec<String>,
}
94
/// Turns commit history into labelled training examples and splits them into
/// train / validation / test datasets.
pub struct TrainingDataExtractor {
    /// Assigns a defect category and a confidence to commit messages.
    classifier: RuleBasedClassifier,
    /// Classifications with a confidence below this value are discarded.
    min_confidence: f32,
}
100
101impl TrainingDataExtractor {
102 pub fn new(min_confidence: f32) -> Self {
116 Self {
117 classifier: RuleBasedClassifier::new(),
118 min_confidence,
119 }
120 }
121
122 pub fn extract_training_data(
159 &self,
160 commits: &[CommitInfo],
161 _repository_name: &str,
162 ) -> Result<Vec<TrainingExample>> {
163 let mut examples = Vec::new();
164
165 for commit in commits {
166 if !self.is_defect_fix_commit(&commit.message) {
168 continue;
169 }
170
171 if let Some(classification) = self.classifier.classify_from_message(&commit.message) {
173 if classification.confidence >= self.min_confidence {
175 examples.push(TrainingExample {
176 message: commit.message.clone(),
177 label: classification.category,
178 confidence: classification.confidence,
179 commit_hash: commit.hash.clone(),
180 author: commit.author.clone(),
181 timestamp: commit.timestamp,
182 lines_added: commit.lines_added,
183 lines_removed: commit.lines_removed,
184 files_changed: commit.files_changed,
185 error_code: None,
187 clippy_lint: None,
188 has_suggestion: false,
189 suggestion_applicability: None,
190 source: TrainingSource::CommitMessage,
191 });
192 }
193 }
194 }
195
196 Ok(examples)
197 }
198
199 fn is_defect_fix_commit(&self, message: &str) -> bool {
206 let lower = message.to_lowercase();
207
208 if lower.starts_with("merge")
210 || lower.starts_with("revert")
211 || lower.contains("wip")
212 || lower.contains("work in progress")
213 {
214 return false;
215 }
216
217 lower.starts_with("fix:")
219 || lower.starts_with("bug:")
220 || lower.starts_with("patch:")
221 || lower.contains("fix ")
222 || lower.contains("bug ")
223 || lower.contains("error")
224 || lower.contains("crash")
225 || lower.contains("issue")
226 }
227
228 pub fn create_splits(
273 &self,
274 examples: &[TrainingExample],
275 repositories: &[String],
276 ) -> Result<TrainingDataset> {
277 if examples.is_empty() {
278 return Err(anyhow!("Cannot create splits from empty dataset"));
279 }
280
281 let total = examples.len();
282
283 let train_size = (total as f32 * 0.70) as usize;
285 let validation_size = (total as f32 * 0.15) as usize;
286 let test_size = total - train_size - validation_size;
287
288 let train = examples[0..train_size].to_vec();
290 let validation = examples[train_size..train_size + validation_size].to_vec();
291 let test = examples[train_size + validation_size..].to_vec();
292
293 let mut class_distribution = HashMap::new();
295 for example in examples {
296 let category_name = format!("{}", example.label);
297 *class_distribution.entry(category_name).or_insert(0) += 1;
298 }
299
300 let avg_confidence =
302 examples.iter().map(|e| e.confidence).sum::<f32>() / examples.len() as f32;
303
304 let metadata = DatasetMetadata {
305 total_examples: total,
306 train_size,
307 validation_size,
308 test_size,
309 class_distribution,
310 avg_confidence,
311 min_confidence: self.min_confidence,
312 repositories: repositories.to_vec(),
313 };
314
315 Ok(TrainingDataset {
316 train,
317 validation,
318 test,
319 metadata,
320 })
321 }
322
323 pub fn get_statistics(&self, examples: &[TrainingExample]) -> String {
333 if examples.is_empty() {
334 return "No examples extracted".to_string();
335 }
336
337 let mut category_counts: HashMap<DefectCategory, usize> = HashMap::new();
338 let mut confidence_sum = 0.0_f32;
339
340 for example in examples {
341 *category_counts.entry(example.label).or_insert(0) += 1;
342 confidence_sum += example.confidence;
343 }
344
345 let avg_confidence = confidence_sum / examples.len() as f32;
346
347 let mut stats = "Training Data Statistics:\n".to_string();
348 stats.push_str(&format!(" Total examples: {}\n", examples.len()));
349 stats.push_str(&format!(" Avg confidence: {:.2}\n", avg_confidence));
350 stats.push_str(&format!(
351 " Min confidence threshold: {:.2}\n",
352 self.min_confidence
353 ));
354 stats.push_str("\nClass Distribution:\n");
355
356 let mut sorted_categories: Vec<_> = category_counts.iter().collect();
357 sorted_categories.sort_by_key(|(_, count)| std::cmp::Reverse(*count));
358
359 for (category, count) in sorted_categories {
360 let percentage = (*count as f32 / examples.len() as f32) * 100.0;
361 stats.push_str(&format!(
362 " {:?}: {} ({:.1}%)\n",
363 category, count, percentage
364 ));
365 }
366
367 stats
368 }
369}
370
371impl Default for TrainingDataExtractor {
372 fn default() -> Self {
373 Self::new(0.75) }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a commit-message example numbered `index` with the given label
    /// and confidence; all other fields get fixed placeholder values.
    fn make_example(index: usize, label: DefectCategory, confidence: f32) -> TrainingExample {
        TrainingExample {
            message: format!("fix: bug {}", index),
            label,
            confidence,
            commit_hash: format!("hash{}", index),
            author: "dev".to_string(),
            timestamp: 123 + index as i64,
            lines_added: 5,
            lines_removed: 2,
            files_changed: 1,
            error_code: None,
            clippy_lint: None,
            has_suggestion: false,
            suggestion_applicability: None,
            source: TrainingSource::CommitMessage,
        }
    }

    /// Builds `count` memory-safety examples with confidence 0.85.
    fn make_examples(count: usize) -> Vec<TrainingExample> {
        (0..count)
            .map(|i| make_example(i, DefectCategory::MemorySafety, 0.85))
            .collect()
    }

    #[test]
    fn test_extractor_creation() {
        let extractor = TrainingDataExtractor::new(0.80);
        assert_eq!(extractor.min_confidence, 0.80);
    }

    #[test]
    fn test_is_defect_fix_commit() {
        let extractor = TrainingDataExtractor::new(0.75);

        assert!(extractor.is_defect_fix_commit("fix: null pointer"));
        assert!(extractor.is_defect_fix_commit("bug: race condition"));
        assert!(extractor.is_defect_fix_commit("patch: memory leak"));
        assert!(extractor.is_defect_fix_commit("fix memory leak in parser"));

        assert!(!extractor.is_defect_fix_commit("Merge branch 'main'"));
        assert!(!extractor.is_defect_fix_commit("Revert commit abc123"));
        assert!(!extractor.is_defect_fix_commit("feat: add new feature"));
        assert!(!extractor.is_defect_fix_commit("docs: update README"));
        assert!(!extractor.is_defect_fix_commit("WIP: working on feature"));
    }

    #[test]
    fn test_extract_training_data() {
        let extractor = TrainingDataExtractor::new(0.70);

        let commits = vec![
            CommitInfo {
                hash: "abc123".to_string(),
                message: "fix: null pointer dereference in parser".to_string(),
                author: "dev@example.com".to_string(),
                timestamp: 1234567890,
                files_changed: 2,
                lines_added: 10,
                lines_removed: 5,
            },
            CommitInfo {
                hash: "def456".to_string(),
                message: "feat: add new feature".to_string(),
                author: "dev@example.com".to_string(),
                timestamp: 1234567891,
                files_changed: 5,
                lines_added: 100,
                lines_removed: 0,
            },
            CommitInfo {
                hash: "ghi789".to_string(),
                message: "fix: race condition in mutex lock".to_string(),
                author: "dev@example.com".to_string(),
                timestamp: 1234567892,
                files_changed: 1,
                lines_added: 5,
                lines_removed: 3,
            },
        ];

        let examples = extractor
            .extract_training_data(&commits, "test-repo")
            .unwrap();

        assert_eq!(examples.len(), 2);
        assert_eq!(
            examples[0].message,
            "fix: null pointer dereference in parser"
        );
        assert_eq!(examples[1].message, "fix: race condition in mutex lock");
        // Commit metadata is carried over and input order is preserved.
        assert_eq!(examples[0].commit_hash, "abc123");
        assert_eq!(examples[1].commit_hash, "ghi789");
        assert_eq!(examples[0].lines_added, 10);
        assert!(examples.iter().all(|e| e.error_code.is_none()));
    }

    #[test]
    fn test_create_splits() {
        let extractor = TrainingDataExtractor::new(0.75);
        let examples = make_examples(100);

        let dataset = extractor
            .create_splits(&examples, &["repo1".to_string()])
            .unwrap();

        assert_eq!(dataset.train.len(), 70);
        assert_eq!(dataset.validation.len(), 15);
        assert_eq!(dataset.test.len(), 15);
        assert_eq!(dataset.metadata.total_examples, 100);
        assert_eq!(dataset.metadata.train_size, 70);
        assert_eq!(dataset.metadata.validation_size, 15);
        assert_eq!(dataset.metadata.test_size, 15);

        // Splits are contiguous and preserve input order.
        assert_eq!(dataset.train[0].commit_hash, "hash0");
        assert_eq!(dataset.validation[0].commit_hash, "hash70");
        assert_eq!(dataset.test[0].commit_hash, "hash85");
        assert_eq!(dataset.test[14].commit_hash, "hash99");

        // Metadata covers the whole dataset.
        assert_eq!(
            dataset.metadata.class_distribution.values().sum::<usize>(),
            100
        );
        assert!((dataset.metadata.avg_confidence - 0.85).abs() < 1e-4);
        assert_eq!(dataset.metadata.min_confidence, 0.75);
        assert_eq!(dataset.metadata.repositories, vec!["repo1".to_string()]);
    }

    #[test]
    fn test_create_splits_small_datasets() {
        let extractor = TrainingDataExtractor::new(0.75);

        // A single example ends up entirely in the test split.
        let single = extractor.create_splits(&make_examples(1), &[]).unwrap();
        assert_eq!(single.train.len(), 0);
        assert_eq!(single.validation.len(), 0);
        assert_eq!(single.test.len(), 1);

        // 10 examples: 7 / 1 / 2, and nothing is lost.
        let ten = extractor.create_splits(&make_examples(10), &[]).unwrap();
        assert_eq!(ten.train.len(), 7);
        assert_eq!(ten.validation.len(), 1);
        assert_eq!(ten.test.len(), 2);
    }

    #[test]
    fn test_empty_dataset_error() {
        let extractor = TrainingDataExtractor::new(0.75);
        let examples: Vec<TrainingExample> = vec![];

        let result = extractor.create_splits(&examples, &[]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("empty dataset"));
    }

    #[test]
    fn test_get_statistics() {
        let extractor = TrainingDataExtractor::new(0.75);

        let examples = vec![
            make_example(1, DefectCategory::MemorySafety, 0.85),
            make_example(2, DefectCategory::ConcurrencyBugs, 0.90),
        ];

        let stats = extractor.get_statistics(&examples);
        assert!(stats.contains("Total examples: 2"));
        assert!(stats.contains("Avg confidence:"));
        assert!(stats.contains("Min confidence threshold: 0.75"));
        assert!(stats.contains("Class Distribution:"));
        // Each of the two categories holds half of the examples.
        assert_eq!(stats.matches("1 (50.0%)").count(), 2);
    }

    #[test]
    fn test_get_statistics_empty() {
        let extractor = TrainingDataExtractor::new(0.75);
        assert_eq!(extractor.get_statistics(&[]), "No examples extracted");
    }

    #[test]
    fn test_confidence_threshold_filtering() {
        let extractor = TrainingDataExtractor::new(0.90);
        let commits = vec![CommitInfo {
            hash: "abc".to_string(),
            message: "fix: memory leak".to_string(),
            author: "dev".to_string(),
            timestamp: 123,
            files_changed: 1,
            lines_added: 5,
            lines_removed: 2,
        }];

        let examples = extractor
            .extract_training_data(&commits, "test-repo")
            .unwrap();

        assert!(examples.len() <= 1);
        // Whatever survives must meet the configured threshold; the length
        // check alone cannot fail for a single input commit.
        assert!(examples.iter().all(|e| e.confidence >= 0.90));
    }

    #[test]
    fn test_default_extractor() {
        let extractor = TrainingDataExtractor::default();
        assert_eq!(extractor.min_confidence, 0.75);
    }
}