impl CiCdLearningManager {
    /// Creates a manager with the given configuration, a fresh (untrained)
    /// predictor, and no model version loaded.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn new(config: CiCdLearningConfig) -> Self {
        Self {
            config,
            predictor: SurvivabilityPredictor::new(),
            current_version: None,
        }
    }

    /// Turns a mutation run's results into a persisted training batch.
    ///
    /// Only mutants with a definitive label (`Killed` or `Survived`) become
    /// samples; other statuses carry no kill/survive signal. When
    /// `auto_train` is set and the on-disk sample count reaches
    /// `min_samples_for_training`, an incremental training pass runs.
    ///
    /// # Errors
    /// Fails if the batch cannot be written or auto-training fails.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn collect_training_data(
        &mut self,
        results: &[MutationResult],
        metadata: CiCdMetadata,
    ) -> Result<TrainingBatch> {
        let samples: Vec<TrainingData> = results
            .iter()
            .filter(|r| matches!(r.status, MutantStatus::Killed | MutantStatus::Survived))
            .map(|r| TrainingData {
                mutant: r.mutant.clone(),
                was_killed: r.status == MutantStatus::Killed,
                test_failures: r.test_failures.clone(),
                execution_time_ms: r.execution_time_ms,
            })
            .collect();
        let batch = TrainingBatch {
            // Millisecond resolution: with second-resolution ids, two
            // batches collected within the same second shared a filename
            // and the earlier batch was silently overwritten on disk.
            id: Utc::now().timestamp_millis().to_string(),
            metadata,
            samples,
            collected_at: Utc::now(),
        };
        self.save_training_batch(&batch).await?;
        if self.config.auto_train {
            let all_samples = self.load_all_training_data().await?;
            if all_samples.len() >= self.config.min_samples_for_training {
                self.train_incremental(&all_samples).await?;
            }
        }
        Ok(batch)
    }

    /// Retrains the predictor on up to `max_training_samples` of the given
    /// data (keeping the newest samples when truncating) and records a new
    /// model version.
    ///
    /// # Errors
    /// Fails if training fails or, with versioning enabled, the model
    /// cannot be persisted.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn train_incremental(
        &mut self,
        training_data: &[TrainingData],
    ) -> Result<ModelVersion> {
        // Bound the training window; the tail of the slice is assumed to be
        // the most recently collected data.
        let samples = if training_data.len() > self.config.max_training_samples {
            &training_data[training_data.len() - self.config.max_training_samples..]
        } else {
            training_data
        };
        self.predictor
            .train(samples)
            .context("Failed to train predictor")?;
        // Best-effort estimate: a failed cross-validation is recorded as
        // 0.0 accuracy rather than aborting the training run.
        let accuracy = self.predictor.cross_validate(samples, 5).unwrap_or(0.0);
        // Bug fix: the next version was previously derived only from the
        // in-memory `current_version`, so a freshly constructed manager
        // (one that never called `load_latest_model`) restarted at v1 and
        // overwrote model files already on disk. Fall back to the persisted
        // version list when nothing is held in memory.
        let version = match &self.current_version {
            Some(v) => v.version + 1,
            None => self.next_version_from_disk().await?,
        };
        let model_version = ModelVersion {
            version,
            trained_at: Utc::now(),
            sample_count: samples.len(),
            accuracy,
            file_path: self.get_model_path(version),
            metadata: None,
        };
        if self.config.versioning_enabled {
            self.save_model_version(&model_version).await?;
        }
        self.current_version = Some(model_version.clone());
        Ok(model_version)
    }

    /// Loads the highest-numbered persisted model (if any) into the
    /// predictor and makes it the current version.
    ///
    /// # Errors
    /// Fails if version metadata cannot be listed or the model file
    /// cannot be loaded.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn load_latest_model(&mut self) -> Result<Option<ModelVersion>> {
        // `list_model_versions` returns versions sorted ascending, so the
        // last entry is the newest; `pop` takes it by value, avoiding one
        // of the two clones the previous implementation performed.
        let mut versions = self.list_model_versions().await?;
        let Some(latest) = versions.pop() else {
            return Ok(None);
        };
        self.predictor = SurvivabilityPredictor::load(&latest.file_path)?;
        self.current_version = Some(latest.clone());
        Ok(Some(latest))
    }

    /// Read access to the underlying predictor.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn predictor(&self) -> &SurvivabilityPredictor {
        &self.predictor
    }

    /// The model version currently held in memory, if any.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn current_version(&self) -> Option<&ModelVersion> {
        self.current_version.as_ref()
    }

    /// Serializes `batch` as pretty JSON into `data_dir/batch_<id>.json`,
    /// creating the directory if needed.
    async fn save_training_batch(&self, batch: &TrainingBatch) -> Result<()> {
        tokio::fs::create_dir_all(&self.config.data_dir).await?;
        let file_path = self
            .config
            .data_dir
            .join(format!("batch_{}.json", batch.id));
        let json = serde_json::to_string_pretty(batch)?;
        tokio::fs::write(file_path, json)
            .await
            .context("Failed to save training batch")
    }

    /// Concatenates the samples of every parseable `*.json` batch in
    /// `data_dir`. Files that fail to deserialize are skipped silently so
    /// one corrupt batch cannot abort training.
    async fn load_all_training_data(&self) -> Result<Vec<TrainingData>> {
        let mut all_samples = Vec::new();
        if !self.config.data_dir.exists() {
            return Ok(all_samples);
        }
        let mut entries = tokio::fs::read_dir(&self.config.data_dir).await?;
        while let Some(entry) = entries.next_entry().await? {
            let path = entry.path();
            if path.extension().and_then(|s| s.to_str()) == Some("json") {
                let content = tokio::fs::read_to_string(&path).await?;
                if let Ok(batch) = serde_json::from_str::<TrainingBatch>(&content) {
                    all_samples.extend(batch.samples);
                }
            }
        }
        Ok(all_samples)
    }

    /// Persists the predictor's weights to `version.file_path` and the
    /// version metadata to `model_dir/version_<n>.json`.
    async fn save_model_version(&self, version: &ModelVersion) -> Result<()> {
        tokio::fs::create_dir_all(&self.config.model_dir).await?;
        self.predictor.save(&version.file_path)?;
        let metadata_path = self
            .config
            .model_dir
            .join(format!("version_{}.json", version.version));
        let json = serde_json::to_string_pretty(version)?;
        tokio::fs::write(metadata_path, json).await?;
        Ok(())
    }

    /// Reads every `version_*.json` metadata file in `model_dir` and
    /// returns them sorted ascending by version number. Unparseable files
    /// are skipped silently.
    async fn list_model_versions(&self) -> Result<Vec<ModelVersion>> {
        let mut versions = Vec::new();
        if !self.config.model_dir.exists() {
            return Ok(versions);
        }
        let mut entries = tokio::fs::read_dir(&self.config.model_dir).await?;
        while let Some(entry) = entries.next_entry().await? {
            let path = entry.path();
            if path
                .file_name()
                .and_then(|n| n.to_str())
                .map(|n| n.starts_with("version_") && n.ends_with(".json"))
                .unwrap_or(false)
            {
                let content = tokio::fs::read_to_string(&path).await?;
                if let Ok(version) = serde_json::from_str::<ModelVersion>(&content) {
                    versions.push(version);
                }
            }
        }
        versions.sort_by_key(|v| v.version);
        Ok(versions)
    }

    /// Next version number according to the metadata persisted on disk;
    /// 1 when no versions exist yet.
    async fn next_version_from_disk(&self) -> Result<u32> {
        Ok(self
            .list_model_versions()
            .await?
            .last()
            .map(|v| v.version + 1)
            .unwrap_or(1))
    }

    /// Next version number based solely on the in-memory current version
    /// (1 when none). NOTE(review): this does not consult versions on
    /// disk; prefer `next_version_from_disk` where persistence matters.
    #[allow(dead_code)] // kept for module-internal callers outside this view
    fn get_next_version(&self) -> u32 {
        self.current_version
            .as_ref()
            .map(|v| v.version + 1)
            .unwrap_or(1)
    }

    /// Path of the serialized model file for `version`.
    fn get_model_path(&self, version: u32) -> PathBuf {
        self.config
            .model_dir
            .join(format!("model_v{}.bin", version))
    }

    /// Deletes the oldest training batches in `data_dir`, keeping the
    /// newest `keep_batches`. Returns how many files were removed;
    /// individual deletion failures are skipped rather than aborting
    /// the sweep.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn cleanup_old_data(&self, keep_batches: usize) -> Result<usize> {
        if !self.config.data_dir.exists() {
            return Ok(0);
        }
        let mut batches = Vec::new();
        let mut entries = tokio::fs::read_dir(&self.config.data_dir).await?;
        while let Some(entry) = entries.next_entry().await? {
            let path = entry.path();
            if path.extension().and_then(|s| s.to_str()) == Some("json") {
                batches.push(path);
            }
        }
        // Oldest first, by the numeric epoch embedded in "batch_<epoch>"
        // file stems. A plain lexicographic sort mis-orders ids with
        // differing digit counts; files that don't match the pattern yield
        // `None`, sort first, and are therefore removed preferentially.
        batches.sort_by_key(|p| {
            p.file_stem()
                .and_then(|s| s.to_str())
                .and_then(|s| s.strip_prefix("batch_"))
                .and_then(|s| s.parse::<u64>().ok())
        });
        if batches.len() <= keep_batches {
            return Ok(0);
        }
        let to_remove = &batches[..batches.len() - keep_batches];
        let mut removed = 0;
        for path in to_remove {
            if tokio::fs::remove_file(path).await.is_ok() {
                removed += 1;
            }
        }
        Ok(removed)
    }
}