use crate::error::OracleError;
use entrenar::citl::{CompilationOutcome, DecisionCITL, DecisionTrace};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Running counters maintained by `CorpusCITL::ingest_pair`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct IngestionStats {
    /// Number of C/Rust pairs recorded.
    pub total_pairs: usize,
    /// Pairs recorded with a Rust translation (`rust_code` was `Some`).
    pub success_pairs: usize,
    /// Pairs recorded without a Rust translation (`rust_code` was `None`).
    pub failed_pairs: usize,
    /// Number of distinct `CFeature` values seen across all recorded pairs.
    pub unique_features: usize,
}
/// A C-language construct used to tag a transpilation pair.
///
/// Each variant maps to a decision-type string via `CFeature::as_decision`
/// and is recovered from one via `CFeature::from_decision`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CFeature {
    /// Heap allocation / deallocation (`malloc`, `free`).
    MallocFree,
    /// Pointer increment / offset arithmetic.
    PointerArithmetic,
    /// `const char*` parameters or variables.
    ConstCharPointer,
    /// Untyped `void*` pointers.
    VoidPointer,
    /// Loops that walk a string by dereferencing a pointer (e.g. `while (*p)`).
    StringIteration,
    /// An array or pointer passed together with an explicit length.
    ArrayWithLength,
    /// A `struct` involving pointers.
    StructWithPointers,
    /// Function-pointer declarations or calls.
    FunctionPointer,
    /// A global variable declaration.
    GlobalVariable,
    /// A `static` variable.
    StaticVariable,
    /// Any other feature, identified by its raw decision string.
    Custom(String),
}
impl CFeature {
pub fn as_decision(&self) -> String {
match self {
CFeature::MallocFree => "malloc_free".to_string(),
CFeature::PointerArithmetic => "pointer_arithmetic".to_string(),
CFeature::ConstCharPointer => "const_char_pointer".to_string(),
CFeature::VoidPointer => "void_pointer".to_string(),
CFeature::StringIteration => "string_iteration".to_string(),
CFeature::ArrayWithLength => "array_with_length".to_string(),
CFeature::StructWithPointers => "struct_with_pointers".to_string(),
CFeature::FunctionPointer => "function_pointer".to_string(),
CFeature::GlobalVariable => "global_variable".to_string(),
CFeature::StaticVariable => "static_variable".to_string(),
CFeature::Custom(s) => s.clone(),
}
}
pub fn from_decision(s: &str) -> Self {
match s {
"malloc_free" => CFeature::MallocFree,
"pointer_arithmetic" => CFeature::PointerArithmetic,
"const_char_pointer" => CFeature::ConstCharPointer,
"void_pointer" => CFeature::VoidPointer,
"string_iteration" => CFeature::StringIteration,
"array_with_length" => CFeature::ArrayWithLength,
"struct_with_pointers" => CFeature::StructWithPointers,
"function_pointer" => CFeature::FunctionPointer,
"global_variable" => CFeature::GlobalVariable,
"static_variable" => CFeature::StaticVariable,
other => CFeature::Custom(other.to_string()),
}
}
}
/// Corpus-level CITL analysis over C → Rust transpilation pairs.
///
/// Each pair passed to `CorpusCITL::ingest_pair` is fed to an `entrenar`
/// `DecisionCITL` store as one session (one trace per `CFeature`, with a
/// success or failure outcome), and `CorpusCITL::top_suspicious` reports the
/// decision types the store ranks as most suspicious.
pub struct CorpusCITL {
    /// Underlying decision store that receives one session per ingested pair.
    citl: DecisionCITL,
    /// Running counters returned by `CorpusCITL::stats`.
    stats: IngestionStats,
    /// Distinct features seen so far; its size is mirrored in `stats.unique_features`.
    seen_features: HashSet<CFeature>,
}
impl CorpusCITL {
    /// Creates an empty corpus backed by a fresh `DecisionCITL` store.
    ///
    /// # Errors
    ///
    /// Returns `OracleError::PatternStoreError` carrying the store's error
    /// message if `DecisionCITL::new` fails.
    pub fn new() -> Result<Self, OracleError> {
        let citl =
            DecisionCITL::new().map_err(|e| OracleError::PatternStoreError(e.to_string()))?;
        Ok(Self { citl, stats: IngestionStats::default(), seen_features: HashSet::new() })
    }

    /// Records one C → Rust transpilation attempt as a CITL session.
    ///
    /// * `_c_code` – the C source; currently unused.
    /// * `rust_code` – `Some` if the pair transpiled successfully, `None` otherwise.
    /// * `features` – C features present in the source; each becomes one
    ///   `DecisionTrace` with id `feature_<index>`.
    ///
    /// The session outcome is `CompilationOutcome::success()` for `Some` and a
    /// fixed `transpilation_failed` failure for `None`. Statistics and the
    /// seen-feature set are updated only after the store accepts the session,
    /// so an error leaves [`CorpusCITL::stats`] unchanged.
    ///
    /// # Errors
    ///
    /// Returns `OracleError::PatternStoreError` if `DecisionCITL::ingest_session` fails.
    pub fn ingest_pair(
        &mut self,
        _c_code: &str,
        rust_code: Option<&str>,
        features: &[CFeature],
    ) -> Result<(), OracleError> {
        let success = rust_code.is_some();

        let traces: Vec<DecisionTrace> = features
            .iter()
            .enumerate()
            .map(|(i, f)| {
                // Build the decision string once and reuse it for the description.
                let decision = f.as_decision();
                let description = format!("C feature: {decision}");
                DecisionTrace::new(format!("feature_{i}"), decision, description)
            })
            .collect();

        let outcome = if success {
            CompilationOutcome::success()
        } else {
            CompilationOutcome::failure(
                vec!["transpilation_failed".to_string()],
                vec![],
                vec!["Rust compilation failed".to_string()],
            )
        };

        self.citl
            .ingest_session(traces, outcome, None)
            .map_err(|e| OracleError::PatternStoreError(e.to_string()))?;

        // Commit bookkeeping only once the store has accepted the session, so the
        // counters never include a pair the store rejected.
        self.stats.total_pairs += 1;
        if success {
            self.stats.success_pairs += 1;
        } else {
            self.stats.failed_pairs += 1;
        }
        self.seen_features.extend(features.iter().cloned());
        self.stats.unique_features = self.seen_features.len();
        Ok(())
    }

    /// Returns the decision types produced by `DecisionCITL::top_suspicious_types(k)`,
    /// mapped back to `CFeature` (unknown names become `CFeature::Custom`) with
    /// their scores widened to `f64`.
    pub fn top_suspicious(&self, k: usize) -> Vec<(CFeature, f64)> {
        self.citl
            .top_suspicious_types(k)
            .into_iter()
            .map(|(decision, score)| (CFeature::from_decision(decision), f64::from(score)))
            .collect()
    }

    /// Returns the counters accumulated by [`CorpusCITL::ingest_pair`].
    pub fn stats(&self) -> &IngestionStats {
        &self.stats
    }

    /// Tags `c_code` with the C features it appears to use.
    ///
    /// Detection is a set of plain substring heuristics, not a parse, so it can
    /// report false positives (e.g. `"+1"` also matches `"+10"`) and miss unusual
    /// spellings. Each feature appears at most once, in `CFeature` declaration
    /// order. `self` is not consulted.
    pub fn extract_features(&self, c_code: &str) -> Vec<CFeature> {
        let mut features = Vec::new();

        if c_code.contains("malloc") || c_code.contains("free(") {
            features.push(CFeature::MallocFree);
        }

        // Pointer arithmetic: something that looks like stepping (`++` alongside a
        // `*`, or an explicit `+ 1` / `+1`) AND a pointer-typed declaration or
        // dereference. Parentheses make the `&&`-binds-tighter grouping explicit.
        let looks_like_step = (c_code.contains("++") && c_code.contains('*'))
            || c_code.contains("+ 1")
            || c_code.contains("+1");
        let has_pointer = c_code.contains("char*")
            || c_code.contains("int*")
            || c_code.contains("void*")
            || c_code.contains("*p");
        if looks_like_step && has_pointer {
            features.push(CFeature::PointerArithmetic);
        }

        if c_code.contains("const char*") || c_code.contains("const char *") {
            features.push(CFeature::ConstCharPointer);
        }
        if c_code.contains("void*") || c_code.contains("void *") {
            features.push(CFeature::VoidPointer);
        }
        if c_code.contains("while(*") || c_code.contains("while (*") {
            features.push(CFeature::StringIteration);
        }

        // A length-like parameter (`int n` or `size_t`) next to an array or `* arr` pointer.
        if (c_code.contains("int n") || c_code.contains("size_t"))
            && (c_code.contains("[]") || c_code.contains("* arr"))
        {
            features.push(CFeature::ArrayWithLength);
        }

        if c_code.contains("struct") && c_code.contains('*') && c_code.contains('{') {
            features.push(CFeature::StructWithPointers);
        }
        if c_code.contains("(*") && c_code.contains(")(") {
            features.push(CFeature::FunctionPointer);
        }

        // Top-level `int`/`char` declaration: the snippet begins with the type keyword
        // (leading whitespace/newlines ignored) and contains no `(`, which excludes
        // function definitions and prototypes.
        let head = c_code.trim_start();
        if (head.starts_with("int ") || head.starts_with("char ")) && !c_code.contains('(') {
            features.push(CFeature::GlobalVariable);
        }

        if c_code.contains("static ") {
            features.push(CFeature::StaticVariable);
        }
        features
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every built-in variant paired with its expected decision string.
    fn builtin_decisions() -> Vec<(CFeature, &'static str)> {
        vec![
            (CFeature::MallocFree, "malloc_free"),
            (CFeature::PointerArithmetic, "pointer_arithmetic"),
            (CFeature::ConstCharPointer, "const_char_pointer"),
            (CFeature::VoidPointer, "void_pointer"),
            (CFeature::StringIteration, "string_iteration"),
            (CFeature::ArrayWithLength, "array_with_length"),
            (CFeature::StructWithPointers, "struct_with_pointers"),
            (CFeature::FunctionPointer, "function_pointer"),
            (CFeature::GlobalVariable, "global_variable"),
            (CFeature::StaticVariable, "static_variable"),
        ]
    }

    #[test]
    fn test_corpus_citl_creation() {
        let citl = CorpusCITL::new();
        assert!(citl.is_ok(), "CorpusCITL creation should succeed");
    }

    #[test]
    fn test_ingest_successful_pair() {
        let mut citl = CorpusCITL::new().unwrap();
        let c_code = "int add(int a, int b) { return a + b; }";
        let rust_code = "fn add(a: i32, b: i32) -> i32 { a + b }";
        citl.ingest_pair(c_code, Some(rust_code), &[]).unwrap();
        let stats = citl.stats();
        assert_eq!(stats.total_pairs, 1);
        assert_eq!(stats.success_pairs, 1);
        assert_eq!(stats.failed_pairs, 0);
        assert_eq!(stats.unique_features, 0);
    }

    #[test]
    fn test_ingest_failed_pair() {
        let mut citl = CorpusCITL::new().unwrap();
        let c_code = "void* ptr = malloc(10); ptr++;";
        let features = [CFeature::MallocFree, CFeature::PointerArithmetic];
        citl.ingest_pair(c_code, None, &features).unwrap();
        let stats = citl.stats();
        assert_eq!(stats.total_pairs, 1);
        assert_eq!(stats.failed_pairs, 1);
        assert_eq!(stats.success_pairs, 0);
        assert_eq!(stats.unique_features, 2);
    }

    #[test]
    fn test_top_suspicious_features() {
        let mut citl = CorpusCITL::new().unwrap();
        for _ in 0..10 {
            citl.ingest_pair(
                "void* p = malloc(10); p++;",
                None,
                &[CFeature::MallocFree, CFeature::PointerArithmetic],
            )
            .unwrap();
        }
        for _ in 0..10 {
            citl.ingest_pair(
                "int add(int a, int b) { return a + b; }",
                Some("fn add(a: i32, b: i32) -> i32 { a + b }"),
                &[],
            )
            .unwrap();
        }
        let suspicious = citl.top_suspicious(5);
        assert!(!suspicious.is_empty());
        assert!(suspicious[0].1 > 0.5, "Top feature should have high suspiciousness");
    }

    #[test]
    fn test_extract_features_malloc() {
        let citl = CorpusCITL::new().unwrap();
        let features = citl.extract_features("int* p = malloc(sizeof(int)); free(p);");
        assert!(features.contains(&CFeature::MallocFree));
    }

    #[test]
    fn test_extract_features_pointer_arithmetic() {
        let citl = CorpusCITL::new().unwrap();
        let features = citl.extract_features("char* p = str; while(*p) { p++; }");
        assert!(features.contains(&CFeature::PointerArithmetic));
        assert!(features.contains(&CFeature::StringIteration));
    }

    #[test]
    fn test_extract_features_const_char() {
        let citl = CorpusCITL::new().unwrap();
        let features =
            citl.extract_features("void print(const char* msg) { printf(\"%s\", msg); }");
        assert!(features.contains(&CFeature::ConstCharPointer));
    }

    #[test]
    fn test_extract_features_empty_input() {
        let citl = CorpusCITL::new().unwrap();
        assert!(citl.extract_features("").is_empty());
    }

    #[test]
    fn test_extract_features_void_pointer() {
        let citl = CorpusCITL::new().unwrap();
        let features = citl.extract_features("void* data = NULL;");
        assert!(features.contains(&CFeature::VoidPointer));
    }

    #[test]
    fn test_extract_features_function_pointer() {
        let citl = CorpusCITL::new().unwrap();
        let features = citl.extract_features("int (*cb)(int, int);");
        assert!(features.contains(&CFeature::FunctionPointer));
        assert!(!features.contains(&CFeature::GlobalVariable));
    }

    #[test]
    fn test_extract_features_global_vs_static() {
        let citl = CorpusCITL::new().unwrap();
        let global = citl.extract_features("int counter = 0;");
        assert!(global.contains(&CFeature::GlobalVariable));
        assert!(!global.contains(&CFeature::StaticVariable));

        let local_static = citl.extract_features("static int counter = 0;");
        assert!(local_static.contains(&CFeature::StaticVariable));
        assert!(!local_static.contains(&CFeature::GlobalVariable));
    }

    #[test]
    fn test_cfeature_as_decision() {
        for (feature, expected) in builtin_decisions() {
            assert_eq!(feature.as_decision(), expected);
        }
        assert_eq!(CFeature::Custom("my_feature".to_string()).as_decision(), "my_feature");
    }

    #[test]
    fn test_cfeature_roundtrip() {
        let mut features: Vec<CFeature> =
            builtin_decisions().into_iter().map(|(f, _)| f).collect();
        features.push(CFeature::Custom("my_custom_feature".to_string()));
        for f in features {
            let decision = f.as_decision();
            let recovered = CFeature::from_decision(&decision);
            assert_eq!(f, recovered, "Feature should roundtrip through decision string");
        }
    }

    #[test]
    fn test_ingestion_stats_default() {
        let stats = IngestionStats::default();
        assert_eq!(stats.total_pairs, 0);
        assert_eq!(stats.success_pairs, 0);
        assert_eq!(stats.failed_pairs, 0);
        assert_eq!(stats.unique_features, 0);
    }

    #[test]
    fn test_unique_features_tracking() {
        let mut citl = CorpusCITL::new().unwrap();
        citl.ingest_pair("code1", None, &[CFeature::MallocFree, CFeature::VoidPointer])
            .unwrap();
        citl.ingest_pair("code2", None, &[CFeature::MallocFree, CFeature::PointerArithmetic])
            .unwrap();
        assert_eq!(citl.stats().unique_features, 3);

        // Re-seeing an already-known feature must not increase the count.
        citl.ingest_pair("code3", None, &[CFeature::MallocFree]).unwrap();
        assert_eq!(citl.stats().unique_features, 3);
    }
}