use crate::literal::{Lit, Var};
#[allow(unused_imports)]
use crate::prelude::*;
/// Counters and running averages describing how an [`MLBranching`] predictor has performed.
#[derive(Debug, Default, Clone)]
pub struct MLBranchingStats {
/// Number of `MLBranching::predict_branch` calls that returned a prediction
/// (exploratory or score-based); calls with no candidates are not counted.
pub predictions: usize,
/// Feedback calls (`MLBranching::learn_from_conflict`) that reported the decision as correct.
pub correct_predictions: usize,
/// Feedback calls that reported the decision as incorrect.
pub incorrect_predictions: usize,
/// Total number of feedback calls received.
pub learning_updates: usize,
/// Running mean of the last recorded prediction confidence, sampled once per feedback call.
pub avg_confidence: f64,
/// Number of times `MLBranching::predict_restart` recommended a restart.
pub restart_predictions: usize,
}
impl MLBranchingStats {
    /// Fraction of evaluated predictions that were correct, in `[0, 1]`.
    ///
    /// A prediction counts as evaluated once its outcome has been recorded in either
    /// `correct_predictions` or `incorrect_predictions`. In this module,
    /// `MLBranching::learn_from_conflict` increments exactly one of the two per call.
    /// Predictions that never received feedback do not drag the accuracy down.
    /// Returns `0.0` when nothing has been evaluated yet.
    pub fn accuracy(&self) -> f64 {
        // `predictions` counts every branch suggestion, including ones that never get
        // feedback, so it is the wrong denominator for a correctness ratio.
        let evaluated = self
            .correct_predictions
            .saturating_add(self.incorrect_predictions);
        if evaluated == 0 {
            0.0
        } else {
            self.correct_predictions as f64 / evaluated as f64
        }
    }

    /// Renders a multi-line, human-readable summary of the statistics.
    ///
    /// Accuracy is printed as a percentage with two decimals and the average
    /// confidence with four decimals.
    pub fn display(&self) -> String {
        format!(
            "ML Branching Stats:\n\
             - Predictions: {}\n\
             - Accuracy: {:.2}%\n\
             - Correct: {} / Incorrect: {}\n\
             - Learning updates: {}\n\
             - Avg confidence: {:.4}\n\
             - Restart predictions: {}",
            self.predictions,
            self.accuracy() * 100.0,
            self.correct_predictions,
            self.incorrect_predictions,
            self.learning_updates,
            self.avg_confidence,
            self.restart_predictions,
        )
    }
}
/// Per-variable features learned by [`MLBranching`]. Every field is updated as an
/// exponential moving average whose step size is `MLBranchingConfig::learning_rate`.
#[derive(Debug, Clone)]
struct VariableFeatures {
/// Moves towards 1.0 each time the variable occurs in a reported conflict clause;
/// multiplied by the discount factor in `decay_features`.
conflict_rate: f64,
/// Moves towards 1.0 / 0.0 on successful / unsuccessful `update_propagation` reports.
propagation_rate: f64,
/// Moving average of `depth / max_depth` reported through `update_depth_preference`.
depth_preference: f64,
/// Moves towards 1.0 / 0.0 when a decision on this variable is reported correct /
/// incorrect; multiplied by the discount factor in `decay_features`.
activity: f64,
/// Updated with the same correctness reward as `activity` (but not decayed) and read by
/// `predict_polarity`. NOTE(review): despite the name it does not track the assigned
/// phase — confirm this is the intended polarity signal.
phase_consistency: f64,
}
impl Default for VariableFeatures {
fn default() -> Self {
Self {
conflict_rate: 0.5,
propagation_rate: 0.5,
depth_preference: 0.5,
activity: 0.0,
phase_consistency: 0.5,
}
}
}
/// Tunable parameters for [`MLBranching`].
#[derive(Debug, Clone)]
pub struct MLBranchingConfig {
/// Step size of the exponential moving averages used to update variable features.
pub learning_rate: f64,
/// Factor by which `MLBranching::decay_features` multiplies each variable's activity
/// and conflict rate.
pub discount_factor: f64,
/// Approximate fraction of `predict_branch` calls that pick a candidate
/// pseudo-randomly instead of by score.
pub exploration_rate: f64,
/// NOTE(review): not read anywhere in this module — presumably intended as a
/// confidence cut-off; confirm or remove.
pub min_confidence: f64,
/// When true, variable scores are passed through the logistic function before being
/// compared.
pub normalize_features: bool,
}
impl Default for MLBranchingConfig {
fn default() -> Self {
Self {
learning_rate: 0.1,
discount_factor: 0.95,
exploration_rate: 0.1,
min_confidence: 0.6,
normalize_features: true,
}
}
}
/// Heuristic branching predictor: scores candidate variables from learned per-variable
/// features and adapts those features from feedback (conflicts, propagation and depth
/// reports).
pub struct MLBranching {
/// Tunable parameters.
config: MLBranchingConfig,
/// Learned features keyed by variable; variables missing here are scored with default
/// features.
features: HashMap<Var, VariableFeatures>,
/// Performance counters exposed through `stats()`.
stats: MLBranchingStats,
/// Most recent conflict clauses, oldest first; used for restart prediction.
conflict_history: Vec<Vec<Lit>>,
/// Maximum length of `conflict_history` (set to 100 by `new`).
max_history: usize,
/// Last confidence recorded by `predict_branch` (0.5 initially); folded into
/// `stats.avg_confidence` on each feedback call.
current_confidence: f64,
}
impl MLBranching {
    /// Confidence reported for exploratory picks. It is also the initial value of
    /// `current_confidence` and the value `clear` restores.
    const NEUTRAL_CONFIDENCE: f64 = 0.5;

    /// Creates a predictor with `config`, no learned features, zeroed statistics and an
    /// empty conflict history capped at 100 clauses.
    pub fn new(config: MLBranchingConfig) -> Self {
        Self {
            config,
            features: HashMap::new(),
            stats: MLBranchingStats::default(),
            conflict_history: Vec::new(),
            max_history: 100,
            current_confidence: Self::NEUTRAL_CONFIDENCE,
        }
    }

    /// Creates a predictor using `MLBranchingConfig::default()`.
    pub fn default_config() -> Self {
        Self::new(MLBranchingConfig::default())
    }

    /// Chooses a variable from `candidates` together with a polarity and a confidence.
    ///
    /// Returns `None` if `candidates` is empty. Otherwise it returns
    /// `Some((var, polarity, confidence))` with `confidence` in `[0, 1]`:
    ///
    /// * On an exploration step, the variable is picked by a pseudo-random index and the
    ///   confidence is `0.5`. Exploration steps are a deterministic, counter-driven
    ///   fraction of calls controlled by `exploration_rate`.
    /// * Otherwise the highest-scoring candidate wins (ties go to the earliest one) and
    ///   the confidence is derived from its score.
    ///
    /// Every call that returns `Some` increments `stats().predictions` and records the
    /// returned confidence. `learn_from_conflict` later averages that value.
    pub fn predict_branch(&mut self, candidates: &[Var]) -> Option<(Var, bool, f64)> {
        if candidates.is_empty() {
            return None;
        }
        if self.should_explore() {
            let idx = self.random_index(candidates.len());
            let var = candidates[idx];
            let polarity = self.predict_polarity(var);
            // Record the confidence we actually report, so the running average in
            // `learn_from_conflict` does not reuse a stale value from an earlier greedy
            // prediction.
            self.current_confidence = Self::NEUTRAL_CONFIDENCE;
            self.stats.predictions += 1;
            return Some((var, polarity, Self::NEUTRAL_CONFIDENCE));
        }
        // Strict `>` keeps the earliest candidate when scores tie.
        let mut best_var = candidates[0];
        let mut best_score = f64::NEG_INFINITY;
        for &var in candidates {
            let score = self.compute_variable_score(var);
            if score > best_score {
                best_score = score;
                best_var = var;
            }
        }
        let polarity = self.predict_polarity(best_var);
        let confidence = self.score_to_confidence(best_score);
        self.current_confidence = confidence;
        self.stats.predictions += 1;
        Some((best_var, polarity, confidence))
    }

    /// Scores `var` as a fixed weighted sum of its learned features.
    ///
    /// Variables that were never updated are scored with `VariableFeatures::default()`.
    /// When `normalize_features` is set, the weighted sum is passed through the logistic
    /// function `1 / (1 + e^-x)`.
    fn compute_variable_score(&self, var: Var) -> f64 {
        const W_CONFLICT: f64 = 0.3;
        const W_PROPAGATION: f64 = 0.2;
        const W_DEPTH: f64 = 0.15;
        const W_ACTIVITY: f64 = 0.25;
        const W_CONSISTENCY: f64 = 0.1;

        let features = self.features.get(&var).cloned().unwrap_or_default();
        let score = W_CONFLICT * features.conflict_rate
            + W_PROPAGATION * features.propagation_rate
            + W_DEPTH * features.depth_preference
            + W_ACTIVITY * features.activity
            + W_CONSISTENCY * features.phase_consistency;
        if self.config.normalize_features {
            1.0 / (1.0 + libm::exp(-score))
        } else {
            score
        }
    }

    /// Suggests a polarity for `var`. It returns `true` when the variable's
    /// `phase_consistency` exceeds 0.5, and `true` for variables with no learned
    /// features.
    ///
    /// NOTE(review): `phase_consistency` is driven by the `was_correct` feedback in
    /// `update_decision_features`, not by the phase that was assigned. Confirm this is
    /// the intended polarity signal.
    fn predict_polarity(&self, var: Var) -> bool {
        self.features
            .get(&var)
            .map(|f| f.phase_consistency > 0.5)
            .unwrap_or(true)
    }

    /// Maps a score to a confidence in `[0, 1]` with a logistic curve of slope 5,
    /// centred at 0.5.
    fn score_to_confidence(&self, score: f64) -> f64 {
        let normalized: f64 = 1.0 / (1.0 + libm::exp(-5.0 * (score - 0.5)));
        normalized.clamp(0.0, 1.0)
    }

    /// Feeds back the outcome of a decision on `decision_var` that led to
    /// `conflict_clause`. It does the following:
    ///
    /// * Updates the correct/incorrect counters from `was_correct`.
    /// * Appends the clause to the bounded conflict history.
    /// * Moves the conflict rate of every variable in the clause towards 1.
    /// * Updates `decision_var`'s activity and phase consistency.
    /// * Folds the last recorded prediction confidence into `avg_confidence`.
    pub fn learn_from_conflict(
        &mut self,
        conflict_clause: &[Lit],
        decision_var: Var,
        was_correct: bool,
    ) {
        self.stats.learning_updates += 1;
        if was_correct {
            self.stats.correct_predictions += 1;
        } else {
            self.stats.incorrect_predictions += 1;
        }

        self.conflict_history.push(conflict_clause.to_vec());
        // Keep only the newest `max_history` clauses (oldest are at the front).
        if self.conflict_history.len() > self.max_history {
            let excess = self.conflict_history.len() - self.max_history;
            self.conflict_history.drain(..excess);
        }

        for &lit in conflict_clause {
            self.update_conflict_rate(lit.var(), 1.0);
        }
        self.update_decision_features(decision_var, was_correct);

        // Incremental mean over all feedback calls received so far.
        let total = self.stats.correct_predictions + self.stats.incorrect_predictions;
        if total > 0 {
            self.stats.avg_confidence = (self.stats.avg_confidence * (total - 1) as f64
                + self.current_confidence)
                / total as f64;
        }
    }

    /// Moves `var`'s conflict rate towards `reward` by one moving-average step.
    fn update_conflict_rate(&mut self, var: Var, reward: f64) {
        let alpha = self.config.learning_rate;
        let features = self.features.entry(var).or_default();
        features.conflict_rate = (1.0 - alpha) * features.conflict_rate + alpha * reward;
    }

    /// Moves `var`'s activity and phase consistency towards 1.0 if the decision was
    /// correct, otherwise towards 0.0.
    fn update_decision_features(&mut self, var: Var, was_correct: bool) {
        let alpha = self.config.learning_rate;
        let reward = if was_correct { 1.0 } else { 0.0 };
        let features = self.features.entry(var).or_default();
        features.activity = (1.0 - alpha) * features.activity + alpha * reward;
        features.phase_consistency = (1.0 - alpha) * features.phase_consistency + alpha * reward;
    }

    /// Moves `var`'s propagation rate towards 1.0 when `success` is true, otherwise
    /// towards 0.0.
    pub fn update_propagation(&mut self, var: Var, success: bool) {
        let alpha = self.config.learning_rate;
        let reward = if success { 1.0 } else { 0.0 };
        let features = self.features.entry(var).or_default();
        features.propagation_rate = (1.0 - alpha) * features.propagation_rate + alpha * reward;
    }

    /// Moves `var`'s depth preference towards `depth / max_depth`.
    ///
    /// The call is ignored when `max_depth` is 0. The ratio is capped at 1.0 so that a
    /// `depth` larger than `max_depth` cannot push the preference above 1.
    pub fn update_depth_preference(&mut self, var: Var, depth: usize, max_depth: usize) {
        if max_depth == 0 {
            return;
        }
        let alpha = self.config.learning_rate;
        let normalized_depth = (depth as f64 / max_depth as f64).min(1.0);
        let features = self.features.entry(var).or_default();
        features.depth_preference =
            (1.0 - alpha) * features.depth_preference + alpha * normalized_depth;
    }

    /// Predicts whether a restart should happen now.
    ///
    /// The restart score is the mean similarity of recent conflict clauses. When
    /// `lbd_avg` is positive, the score is divided by `lbd_avg`. A restart is suggested
    /// when the score exceeds a threshold that starts at 0.5 and grows by 0.3 per 1000
    /// conflicts since the last restart. A positive prediction increments
    /// `stats().restart_predictions`.
    pub fn predict_restart(&mut self, conflicts_since_restart: usize, lbd_avg: f64) -> bool {
        let pattern_score = self.analyze_conflict_patterns();
        let lbd_factor = if lbd_avg > 0.0 { 1.0 / lbd_avg } else { 1.0 };
        let restart_score = pattern_score * lbd_factor;
        let threshold = 0.5 + (conflicts_since_restart as f64 / 1000.0) * 0.3;
        let should_restart = restart_score > threshold;
        if should_restart {
            self.stats.restart_predictions += 1;
        }
        should_restart
    }

    /// Mean pairwise similarity (see `clause_similarity`) of the last ten conflict
    /// clauses, or fewer if the history is shorter. Returns 0.0 when the history holds
    /// fewer than two clauses.
    fn analyze_conflict_patterns(&self) -> f64 {
        if self.conflict_history.len() < 2 {
            return 0.0;
        }
        let recent = &self.conflict_history[self.conflict_history.len().saturating_sub(10)..];
        let mut total_similarity = 0.0;
        let mut comparisons = 0;
        for (i, first) in recent.iter().enumerate() {
            for second in &recent[i + 1..] {
                total_similarity += self.clause_similarity(first, second);
                comparisons += 1;
            }
        }
        if comparisons > 0 {
            total_similarity / comparisons as f64
        } else {
            0.0
        }
    }

    /// Jaccard similarity `|A ∩ B| / |A ∪ B|` of the variable sets of two clauses.
    /// Polarity is ignored. Returns 0.0 when both clauses are empty.
    fn clause_similarity(&self, clause1: &[Lit], clause2: &[Lit]) -> f64 {
        let set1: crate::prelude::HashSet<_> = clause1.iter().map(|l| l.var()).collect();
        let set2: crate::prelude::HashSet<_> = clause2.iter().map(|l| l.var()).collect();
        let intersection = set1.intersection(&set2).count();
        let union = set1.union(&set2).count();
        if union > 0 {
            intersection as f64 / union as f64
        } else {
            0.0
        }
    }

    /// Multiplies every variable's activity and conflict rate by `discount_factor`, so
    /// that older evidence fades.
    pub fn decay_features(&mut self) {
        let decay = self.config.discount_factor;
        for features in self.features.values_mut() {
            features.activity *= decay;
            features.conflict_rate *= decay;
        }
    }

    /// Deterministic exploration decision derived from the prediction counter.
    ///
    /// The counter is hashed multiplicatively into `0..1000`, and the step explores when
    /// the resulting fraction is below `exploration_rate`. Wrapping arithmetic is
    /// required: the product overflows `usize` (on 32-bit targets already at a counter
    /// of 2), and plain `*` would panic in debug builds.
    fn should_explore(&self) -> bool {
        let pseudo_rand = self.stats.predictions.wrapping_mul(2_654_435_761) % 1000;
        (pseudo_rand as f64 / 1000.0) < self.config.exploration_rate
    }

    /// Deterministic pseudo-random index in `0..max`, or 0 when `max` is 0. It is
    /// derived from the prediction counter with LCG-style constants, using wrapping
    /// arithmetic so large counters cannot overflow-panic.
    fn random_index(&self, max: usize) -> usize {
        if max == 0 {
            return 0;
        }
        self.stats
            .predictions
            .wrapping_mul(1_103_515_245)
            .wrapping_add(12_345)
            % max
    }

    /// Current statistics.
    pub fn stats(&self) -> &MLBranchingStats {
        &self.stats
    }

    /// Zeroes the statistics, keeping learned features and conflict history.
    pub fn reset_stats(&mut self) {
        self.stats = MLBranchingStats::default();
    }

    /// Forgets all learned features and conflict history and zeroes the statistics. It
    /// also resets the recorded confidence to its initial value. The configuration is
    /// kept.
    pub fn clear(&mut self) {
        self.features.clear();
        self.conflict_history.clear();
        self.stats = MLBranchingStats::default();
        self.current_confidence = Self::NEUTRAL_CONFIDENCE;
    }

    /// Number of variables that have learned features.
    pub fn num_learned_vars(&self) -> usize {
        self.features.len()
    }

    /// Returns `(var, score)` for every variable with learned features, in the feature
    /// map's iteration order.
    pub fn export_features(&self) -> Vec<(Var, f64)> {
        self.features
            .keys()
            .map(|&var| (var, self.compute_variable_score(var)))
            .collect()
    }
}
impl Default for MLBranching {
fn default() -> Self {
Self::default_config()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ml_branching_creation() {
        let ml = MLBranching::default();
        assert_eq!(ml.stats().predictions, 0);
        assert_eq!(ml.num_learned_vars(), 0);
    }

    #[test]
    fn test_predict_branch() {
        let mut ml = MLBranching::default();
        let candidates = vec![Var(0), Var(1), Var(2)];
        let result = ml.predict_branch(&candidates);
        assert!(result.is_some());
        let (var, _polarity, confidence) =
            result.expect("Prediction must succeed with valid candidates");
        assert!(candidates.contains(&var));
        assert!((0.0..=1.0).contains(&confidence));
        assert_eq!(ml.stats().predictions, 1);
    }

    #[test]
    fn test_predict_branch_greedy_prefers_higher_score() {
        let mut ml = MLBranching::default();
        // Only Var(1) gets a better-than-default propagation rate, so it scores highest.
        ml.update_propagation(Var(1), true);
        // A prediction counter of 1 does not trigger the deterministic exploration step,
        // so this call takes the score-based path.
        ml.stats.predictions = 1;
        let (var, polarity, confidence) = ml
            .predict_branch(&[Var(0), Var(1), Var(2)])
            .expect("Prediction must succeed with valid candidates");
        assert!(var == Var(1));
        // Var(1)'s phase_consistency is still 0.5, which is not > 0.5.
        assert!(!polarity);
        assert!(confidence > 0.5 && confidence <= 1.0);
        assert_eq!(ml.stats().predictions, 2);
    }

    #[test]
    fn test_empty_candidates() {
        let mut ml = MLBranching::default();
        let candidates: Vec<Var> = Vec::new();
        let result = ml.predict_branch(&candidates);
        assert!(result.is_none());
        // An empty call is not counted as a prediction.
        assert_eq!(ml.stats().predictions, 0);
    }

    #[test]
    fn test_learn_from_conflict() {
        let mut ml = MLBranching::default();
        let v0 = Var(0);
        let v1 = Var(1);
        let conflict = vec![Lit::pos(v0), Lit::neg(v1)];
        ml.learn_from_conflict(&conflict, v0, true);
        assert_eq!(ml.stats().learning_updates, 1);
        assert_eq!(ml.stats().correct_predictions, 1);
        assert_eq!(ml.stats().incorrect_predictions, 0);
        // Both clause variables get a conflict-rate entry.
        assert_eq!(ml.num_learned_vars(), 2);
    }

    #[test]
    fn test_update_propagation() {
        let mut ml = MLBranching::default();
        let v0 = Var(0);
        ml.update_propagation(v0, true);
        ml.update_propagation(v0, true);
        ml.update_propagation(v0, false);
        let rate = ml
            .features
            .get(&v0)
            .expect("Features must exist for v0")
            .propagation_rate;
        // 0.5 -> 0.55 -> 0.595 -> 0.5355 with learning_rate 0.1.
        assert!((rate - 0.5355).abs() < 1e-12);
    }

    #[test]
    fn test_update_depth_preference() {
        let mut ml = MLBranching::default();
        let v0 = Var(0);
        ml.update_depth_preference(v0, 5, 10);
        let features = ml.features.get(&v0).expect("Features must exist for v0");
        assert!(features.depth_preference > 0.0);
        assert!((features.depth_preference - 0.5).abs() < 1e-12);
        // A zero max_depth is ignored entirely.
        ml.update_depth_preference(Var(1), 3, 0);
        assert!(!ml.features.contains_key(&Var(1)));
    }

    #[test]
    fn test_predict_restart() {
        let mut ml = MLBranching::default();
        // Chain-shaped clauses: only neighbours share a variable, so the mean pairwise
        // similarity is low and no restart should be suggested.
        for i in 0..10 {
            let conflict = vec![Lit::pos(Var(i)), Lit::neg(Var(i + 1))];
            ml.conflict_history.push(conflict);
        }
        let should_restart = ml.predict_restart(100, 3.0);
        assert!(!should_restart);
        assert_eq!(ml.stats().restart_predictions, 0);
    }

    #[test]
    fn test_predict_restart_on_repeated_conflicts() {
        let mut ml = MLBranching::default();
        // Identical clauses are maximally similar, exceeding the base threshold of 0.5.
        for _ in 0..5 {
            ml.conflict_history
                .push(vec![Lit::pos(Var(0)), Lit::neg(Var(1))]);
        }
        assert!(ml.predict_restart(0, 1.0));
        assert_eq!(ml.stats().restart_predictions, 1);
    }

    #[test]
    fn test_decay_features() {
        let mut ml = MLBranching::default();
        let v0 = Var(0);
        {
            let features = ml.features.entry(v0).or_default();
            features.activity = 1.0;
            features.conflict_rate = 0.8;
        }
        ml.decay_features();
        let decay = ml.config.discount_factor;
        let features = ml
            .features
            .get(&v0)
            .expect("Features must exist for v0 after decay");
        assert!(features.activity < 1.0);
        assert!(features.conflict_rate < 0.8);
        assert!((features.activity - decay).abs() < 1e-12);
        assert!((features.conflict_rate - 0.8 * decay).abs() < 1e-12);
        // Other features are not decayed.
        assert_eq!(features.propagation_rate, 0.5);
    }

    #[test]
    fn test_clause_similarity() {
        let ml = MLBranching::default();
        let v0 = Var(0);
        let v1 = Var(1);
        let v2 = Var(2);
        let clause1 = vec![Lit::pos(v0), Lit::pos(v1)];
        let clause2 = vec![Lit::pos(v0), Lit::pos(v2)];
        let similarity = ml.clause_similarity(&clause1, &clause2);
        assert!(similarity > 0.0 && similarity < 1.0);
        // One shared variable out of three distinct ones.
        assert!((similarity - 1.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_identical_clauses_similarity() {
        let ml = MLBranching::default();
        let v0 = Var(0);
        let v1 = Var(1);
        let clause1 = vec![Lit::pos(v0), Lit::pos(v1)];
        let clause2 = vec![Lit::pos(v0), Lit::pos(v1)];
        let similarity = ml.clause_similarity(&clause1, &clause2);
        assert_eq!(similarity, 1.0);
    }

    #[test]
    fn test_stats_accuracy() {
        let mut stats = MLBranchingStats::default();
        assert_eq!(stats.accuracy(), 0.0);
        stats.predictions = 10;
        stats.correct_predictions = 7;
        stats.incorrect_predictions = 3;
        assert_eq!(stats.accuracy(), 0.7);
    }

    #[test]
    fn test_clear() {
        let mut ml = MLBranching::default();
        let v0 = Var(0);
        ml.update_propagation(v0, true);
        ml.learn_from_conflict(&[Lit::pos(v0)], v0, true);
        ml.stats.predictions = 10;
        ml.clear();
        assert_eq!(ml.num_learned_vars(), 0);
        assert_eq!(ml.stats().predictions, 0);
        assert!(ml.conflict_history.is_empty());
    }

    #[test]
    fn test_export_features() {
        let mut ml = MLBranching::default();
        let v0 = Var(0);
        let v1 = Var(1);
        ml.update_propagation(v0, true);
        ml.update_propagation(v1, true);
        let exported = ml.export_features();
        assert_eq!(exported.len(), 2);
        assert!(exported.iter().any(|&(v, _)| v == v0));
        assert!(exported.iter().any(|&(v, _)| v == v1));
    }

    #[test]
    fn test_config_default() {
        let config = MLBranchingConfig::default();
        assert!(config.learning_rate > 0.0);
        assert!(config.discount_factor > 0.0 && config.discount_factor < 1.0);
        assert!(config.exploration_rate >= 0.0 && config.exploration_rate <= 1.0);
    }

    #[test]
    fn test_stats_display() {
        let stats = MLBranchingStats {
            predictions: 100,
            correct_predictions: 75,
            incorrect_predictions: 25,
            learning_updates: 150,
            avg_confidence: 0.85,
            restart_predictions: 10,
        };
        let display = stats.display();
        assert!(display.contains("100"));
        assert!(display.contains("75.00%"));
    }
}