use crate::{Rule, RuleAtom};
use anyhow::{anyhow, Result};
use scirs2_core::metrics::{Counter, Gauge, Timer};
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// The reasoning algorithm families an [`AdaptiveStrategySelector`] can choose between.
///
/// The type is `Copy` and hashable, so it can be passed by value and used as a map key.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum ReasoningStrategy {
    /// Forward strategy.
    Forward,
    /// Backward strategy.
    Backward,
    /// Mix of forward and backward evaluation.
    Hybrid,
    /// RETE-network based strategy.
    RETE,
    /// GPU-accelerated strategy.
    GPU,
    /// Strategy spread across several nodes.
    Distributed,
    /// User-defined strategy; `AdaptiveReasoningEngine` currently runs it with its
    /// forward executor.
    Custom,
}
/// Summary statistics of a fact base and rule set, used as input to strategy selection.
#[derive(Debug)]
#[derive(Clone)]
pub struct DatasetCharacteristics {
    /// Number of facts analysed.
    pub fact_count: usize,
    /// Number of rules.
    pub rule_count: usize,
    /// Mean number of body atoms per rule.
    pub avg_rule_body_size: f64,
    /// Ratio of facts to possible directed edges between distinct nodes.
    pub density: f64,
    /// Mean number of fact endpoints per node.
    pub avg_degree: f64,
    /// Number of distinct predicates.
    pub predicate_count: usize,
    /// Estimated rule selectivity.
    pub rule_selectivity: f64,
    /// Predicate skew indicator (see `AdaptiveStrategySelector::analyze_dataset`).
    pub data_skew: f64,
}
/// Description of a query workload.
///
/// NOTE(review): not currently read by `AdaptiveStrategySelector`; the value ranges of the
/// fields are not established anywhere in this file.
#[derive(Debug)]
#[derive(Clone)]
pub struct QueryPattern {
    /// How often the query is issued.
    pub frequency: f64,
    /// Relative complexity of the query.
    pub complexity: f64,
    /// Relative selectivity of the query.
    pub selectivity: f64,
    /// Temporal locality of accesses.
    pub temporal_locality: f64,
    /// Spatial locality of accesses.
    pub spatial_locality: f64,
}
/// Hardware resources available to the reasoner.
///
/// `AdaptiveStrategySelector::recommend_for_workload` consults `gpu_available` and
/// `distributed_nodes`; the remaining fields are informational.
#[derive(Debug)]
#[derive(Clone)]
pub struct SystemResources {
    /// Number of CPU cores.
    pub cpu_cores: usize,
    /// Available memory — presumably in bytes (the tests use `8 * 1024 * 1024 * 1024`);
    /// confirm with callers.
    pub memory_available: usize,
    /// Whether a GPU can be used.
    pub gpu_available: bool,
    /// Number of nodes available for distributed execution.
    pub distributed_nodes: usize,
    /// Current CPU utilisation.
    pub cpu_utilization: f64,
    /// Current memory utilisation.
    pub memory_utilization: f64,
}
/// One observed run of a reasoning strategy.
#[derive(Debug)]
#[derive(Clone)]
pub struct StrategyMetrics {
    /// Elapsed time of the run.
    pub execution_time: Duration,
    /// Reported memory usage of the run (units not fixed here; the engine records 0).
    pub memory_usage: usize,
    /// Results produced per second.
    pub throughput: f64,
    /// Quality score of the run.
    pub quality: f64,
    /// Number of executions this sample represents.
    pub execution_count: usize,
}
/// Chooses a [`ReasoningStrategy`] from dataset statistics.
///
/// Selection uses a heuristic cost model blended with recorded per-strategy performance.
/// A configurable exploration rate allows pseudo-random picks.
pub struct AdaptiveStrategySelector {
    /// Registered strategies, in registration order.
    strategies: Vec<ReasoningStrategy>,
    /// Recorded performance samples, keyed by registered strategy.
    performance_history: HashMap<ReasoningStrategy, Vec<StrategyMetrics>>,
    /// Most recent selection, if any.
    current_strategy: Option<ReasoningStrategy>,
    // Initialised in `new` but not read anywhere yet, hence the allow.
    #[allow(dead_code)]
    learning_rate: f64,
    /// Weights that combine the cost-model components.
    cost_weights: CostModelWeights,
    /// Selection counters and timer.
    metrics: SelectorMetrics,
    /// State of the pseudo-random generator used for exploration.
    random_state: u64,
    /// Probability of exploring instead of taking the cheapest strategy
    /// (clamped to `[0, 1]` by `set_exploration_rate`).
    exploration_rate: f64,
}
/// Relative weights of the strategy cost components.
///
/// `AdaptiveStrategySelector::set_cost_weights` normalises them so that they sum to 1.
#[derive(Debug)]
#[derive(Clone)]
struct CostModelWeights {
    /// Weight of the execution-time component.
    time_weight: f64,
    /// Weight of the memory component.
    memory_weight: f64,
    // Stored but not yet used by the cost estimate, hence the allow.
    #[allow(dead_code)]
    throughput_weight: f64,
    // Stored but not yet used by the cost estimate, hence the allow.
    #[allow(dead_code)]
    quality_weight: f64,
}
/// Instrumentation handles owned by an [`AdaptiveStrategySelector`].
pub struct SelectorMetrics {
    /// Incremented on every `select_strategy` call, including calls that return an error.
    total_selections: Counter,
    // Not updated anywhere yet, hence the allow.
    #[allow(dead_code)]
    correct_predictions: Counter,
    /// Incremented when a selection differs from the previous current strategy.
    strategy_switches: Counter,
    /// Times each `select_strategy` call.
    selection_timer: Timer,
    // Not updated anywhere yet, hence the allow.
    #[allow(dead_code)]
    current_cost: Gauge,
}
impl SelectorMetrics {
    /// Creates the selector's metric handles, each named with an `adaptive_` prefix.
    fn new() -> Self {
        let total_selections = Counter::new(String::from("adaptive_total_selections"));
        let correct_predictions = Counter::new(String::from("adaptive_correct_predictions"));
        let strategy_switches = Counter::new(String::from("adaptive_strategy_switches"));
        let selection_timer = Timer::new(String::from("adaptive_selection_time"));
        let current_cost = Gauge::new(String::from("adaptive_current_cost"));
        Self {
            total_selections,
            correct_predictions,
            strategy_switches,
            selection_timer,
            current_cost,
        }
    }
}
impl AdaptiveStrategySelector {
pub fn new() -> Self {
Self {
strategies: Vec::new(),
performance_history: HashMap::new(),
current_strategy: None,
learning_rate: 0.1,
cost_weights: CostModelWeights {
time_weight: 0.4,
memory_weight: 0.2,
throughput_weight: 0.3,
quality_weight: 0.1,
},
metrics: SelectorMetrics::new(),
random_state: 42,
exploration_rate: 0.1,
}
}
pub fn register_strategy(&mut self, strategy: ReasoningStrategy) {
if !self.strategies.contains(&strategy) {
self.strategies.push(strategy);
self.performance_history.insert(strategy, Vec::new());
}
}
pub fn analyze_dataset(&self, facts: &[RuleAtom]) -> DatasetCharacteristics {
let fact_count = facts.len();
let mut subjects = std::collections::HashSet::new();
let mut predicates = std::collections::HashSet::new();
let mut objects = std::collections::HashSet::new();
for fact in facts {
if let RuleAtom::Triple {
subject,
predicate,
object,
} = fact
{
subjects.insert(format!("{:?}", subject));
predicates.insert(format!("{:?}", predicate));
objects.insert(format!("{:?}", object));
}
}
let predicate_count = predicates.len();
let node_count = subjects.len() + objects.len();
let possible_edges = if node_count > 1 {
node_count * (node_count - 1)
} else {
1
};
let density = fact_count as f64 / possible_edges as f64;
let avg_degree = if node_count > 0 {
(2.0 * fact_count as f64) / node_count as f64
} else {
0.0
};
let data_skew = if predicate_count > 0 {
1.0 / predicate_count as f64
} else {
1.0
};
DatasetCharacteristics {
fact_count,
rule_count: 0, avg_rule_body_size: 0.0,
density,
avg_degree,
predicate_count,
rule_selectivity: 0.0,
data_skew,
}
}
pub fn analyze_rules(&self, rules: &[Rule]) -> (f64, f64) {
if rules.is_empty() {
return (0.0, 0.0);
}
let total_body_size: usize = rules.iter().map(|r| r.body.len()).sum();
let avg_body_size = total_body_size as f64 / rules.len() as f64;
let avg_selectivity = 0.5;
(avg_body_size, avg_selectivity)
}
pub fn select_strategy(
&mut self,
characteristics: &DatasetCharacteristics,
) -> Result<ReasoningStrategy> {
self.metrics.total_selections.inc();
let _timer = self.metrics.selection_timer.start();
if self.strategies.is_empty() {
return Err(anyhow!("No strategies registered"));
}
self.random_state = self
.random_state
.wrapping_mul(1103515245)
.wrapping_add(12345);
let rand_val = (self.random_state >> 16) as f64 / 65536.0;
let explore = rand_val < self.exploration_rate;
let strategy = if explore {
let idx = (self.random_state as usize) % self.strategies.len();
self.strategies[idx]
} else {
self.select_best_strategy(characteristics)?
};
if let Some(current) = self.current_strategy {
if current != strategy {
self.metrics.strategy_switches.inc();
}
}
self.current_strategy = Some(strategy);
Ok(strategy)
}
fn select_best_strategy(
&self,
characteristics: &DatasetCharacteristics,
) -> Result<ReasoningStrategy> {
let mut best_strategy = self.strategies[0];
let mut best_cost = f64::INFINITY;
for &strategy in &self.strategies {
let cost = self.estimate_cost(strategy, characteristics);
if cost < best_cost {
best_cost = cost;
best_strategy = strategy;
}
}
Ok(best_strategy)
}
fn estimate_cost(
&self,
strategy: ReasoningStrategy,
characteristics: &DatasetCharacteristics,
) -> f64 {
let (time_cost, memory_cost) = match strategy {
ReasoningStrategy::Forward => {
let time = (characteristics.fact_count * characteristics.rule_count) as f64;
let memory = (characteristics.fact_count * 2) as f64; (time, memory)
}
ReasoningStrategy::Backward => {
let time = (characteristics.rule_count * 10) as f64; let memory = (characteristics.rule_count * 5) as f64; (time, memory)
}
ReasoningStrategy::Hybrid => {
let time = (characteristics.fact_count * characteristics.rule_count / 2) as f64;
let memory = characteristics.fact_count as f64 * 1.5;
(time, memory)
}
ReasoningStrategy::RETE => {
let time = (characteristics.fact_count + characteristics.rule_count * 10) as f64;
let memory = (characteristics.rule_count * 20) as f64; (time, memory)
}
ReasoningStrategy::GPU => {
let time = (characteristics.fact_count / 100) as f64; let memory = (characteristics.fact_count * 3) as f64; (time, memory)
}
ReasoningStrategy::Distributed => {
let time = (characteristics.fact_count / 10) as f64; let memory = characteristics.fact_count as f64 * 1.2; (time, memory)
}
ReasoningStrategy::Custom => (1000.0, 1000.0),
};
let learned_cost = if let Some(history) = self.performance_history.get(&strategy) {
if !history.is_empty() {
let avg_time: f64 = history
.iter()
.map(|m| m.execution_time.as_secs_f64())
.sum::<f64>()
/ history.len() as f64;
let avg_memory: f64 = history.iter().map(|m| m.memory_usage as f64).sum::<f64>()
/ history.len() as f64;
Some((avg_time * 1e6, avg_memory)) } else {
None
}
} else {
None
};
let (final_time, final_memory) = if let Some((learned_time, learned_mem)) = learned_cost {
let t = 0.7 * learned_time + 0.3 * time_cost;
let m = 0.7 * learned_mem + 0.3 * memory_cost;
(t, m)
} else {
(time_cost, memory_cost)
};
let norm_time = final_time / 1e6; let norm_memory = final_memory / 1e6;
self.cost_weights.time_weight * norm_time + self.cost_weights.memory_weight * norm_memory
}
pub fn record_performance(&mut self, strategy: ReasoningStrategy, metrics: StrategyMetrics) {
if let Some(history) = self.performance_history.get_mut(&strategy) {
history.push(metrics);
if history.len() > 100 {
history.remove(0);
}
}
}
pub fn current_strategy(&self) -> Option<ReasoningStrategy> {
self.current_strategy
}
pub fn set_cost_weights(&mut self, time: f64, memory: f64, throughput: f64, quality: f64) {
let total = time + memory + throughput + quality;
self.cost_weights = CostModelWeights {
time_weight: time / total,
memory_weight: memory / total,
throughput_weight: throughput / total,
quality_weight: quality / total,
};
}
pub fn set_exploration_rate(&mut self, rate: f64) {
self.exploration_rate = rate.clamp(0.0, 1.0);
}
pub fn get_performance_history(
&self,
strategy: ReasoningStrategy,
) -> Option<&Vec<StrategyMetrics>> {
self.performance_history.get(&strategy)
}
pub fn get_metrics(&self) -> &SelectorMetrics {
&self.metrics
}
pub fn recommend_for_workload(
&self,
characteristics: &DatasetCharacteristics,
resources: &SystemResources,
) -> ReasoningStrategy {
if characteristics.fact_count < 100 {
return ReasoningStrategy::Backward;
}
if characteristics.fact_count > 10000 && resources.gpu_available {
return ReasoningStrategy::GPU;
}
if resources.distributed_nodes > 1 && characteristics.fact_count > 5000 {
return ReasoningStrategy::Distributed;
}
if characteristics.density > 0.1 && characteristics.rule_count > 50 {
return ReasoningStrategy::RETE;
}
if characteristics.density < 0.01 {
return ReasoningStrategy::Backward;
}
ReasoningStrategy::Hybrid
}
}
impl Default for AdaptiveStrategySelector {
fn default() -> Self {
Self::new()
}
}
/// Reasoning front-end whose execution strategy is periodically re-chosen by an
/// [`AdaptiveStrategySelector`].
pub struct AdaptiveReasoningEngine {
    /// Chooses and tracks the active strategy.
    selector: AdaptiveStrategySelector,
    /// Facts added via `add_facts`.
    facts: Vec<RuleAtom>,
    /// Rules added via `add_rules`.
    rules: Vec<Rule>,
    // Initialised to `None` and not read anywhere yet, hence the allow.
    #[allow(dead_code)]
    current_metrics: Option<StrategyMetrics>,
    /// Number of `reason` calls between strategy re-selections.
    adaptation_interval: usize,
    /// Number of `reason` calls made so far.
    query_count: usize,
}
impl AdaptiveReasoningEngine {
    /// Creates an engine with no facts or rules.
    ///
    /// The selector has `Forward`, `Backward`, `Hybrid` and `RETE` registered. The strategy
    /// is re-selected every 100 calls to [`Self::reason`].
    pub fn new() -> Self {
        let mut selector = AdaptiveStrategySelector::new();
        selector.register_strategy(ReasoningStrategy::Forward);
        selector.register_strategy(ReasoningStrategy::Backward);
        selector.register_strategy(ReasoningStrategy::Hybrid);
        selector.register_strategy(ReasoningStrategy::RETE);
        Self {
            selector,
            facts: Vec::new(),
            rules: Vec::new(),
            current_metrics: None,
            adaptation_interval: 100,
            query_count: 0,
        }
    }

    /// Appends `facts` to the engine's fact base.
    pub fn add_facts(&mut self, facts: Vec<RuleAtom>) {
        self.facts.extend(facts);
    }

    /// Appends `rules` to the engine's rule set.
    pub fn add_rules(&mut self, rules: Vec<Rule>) {
        self.rules.extend(rules);
    }

    /// Runs one reasoning pass with the current strategy and returns its results.
    ///
    /// Every `adaptation_interval`-th call first re-selects the strategy from the current
    /// facts and rules. An interval of 0 disables this re-selection; the modulo would
    /// otherwise divide by zero.
    ///
    /// Until a strategy has been selected, `Hybrid` is executed and no performance sample
    /// is recorded.
    ///
    /// # Errors
    /// Propagates errors from strategy selection or execution.
    pub fn reason(&mut self) -> Result<Vec<RuleAtom>> {
        self.query_count += 1;
        if self.query_count.checked_rem(self.adaptation_interval) == Some(0) {
            self.adapt_strategy()?;
        }
        let start = Instant::now();
        let results = self.execute_current_strategy()?;
        let duration = start.elapsed();
        let elapsed_secs = duration.as_secs_f64();
        // A zero elapsed time (coarse clocks, trivial work) would otherwise give NaN (0/0) or
        // infinity; report 0 when throughput cannot be measured.
        let throughput = if elapsed_secs > 0.0 {
            results.len() as f64 / elapsed_secs
        } else {
            0.0
        };
        let metrics = StrategyMetrics {
            execution_time: duration,
            // NOTE(review): memory usage is not measured; 0 is recorded.
            memory_usage: 0,
            throughput,
            quality: 1.0,
            execution_count: 1,
        };
        if let Some(strategy) = self.selector.current_strategy() {
            self.selector.record_performance(strategy, metrics);
        }
        Ok(results)
    }

    /// Recomputes dataset and rule statistics and asks the selector for a new strategy.
    fn adapt_strategy(&mut self) -> Result<()> {
        let mut characteristics = self.selector.analyze_dataset(&self.facts);
        characteristics.rule_count = self.rules.len();
        let (avg_body_size, selectivity) = self.selector.analyze_rules(&self.rules);
        characteristics.avg_rule_body_size = avg_body_size;
        characteristics.rule_selectivity = selectivity;
        self.selector.select_strategy(&characteristics)?;
        Ok(())
    }

    /// Dispatches to the executor for the selector's current strategy.
    ///
    /// Defaults to `Hybrid` when nothing has been selected yet. `Custom` uses the forward
    /// executor.
    fn execute_current_strategy(&self) -> Result<Vec<RuleAtom>> {
        let strategy = self
            .selector
            .current_strategy()
            .unwrap_or(ReasoningStrategy::Hybrid);
        match strategy {
            ReasoningStrategy::Forward => self.execute_forward(),
            ReasoningStrategy::Backward => self.execute_backward(),
            ReasoningStrategy::Hybrid => self.execute_hybrid(),
            ReasoningStrategy::RETE => self.execute_rete(),
            ReasoningStrategy::GPU => self.execute_gpu(),
            ReasoningStrategy::Distributed => self.execute_distributed(),
            ReasoningStrategy::Custom => self.execute_forward(),
        }
    }

    /// Forward executor: currently returns a clone of the stored facts (no inference yet).
    fn execute_forward(&self) -> Result<Vec<RuleAtom>> {
        Ok(self.facts.clone())
    }

    /// Backward executor: currently returns a clone of the stored facts (no inference yet).
    fn execute_backward(&self) -> Result<Vec<RuleAtom>> {
        Ok(self.facts.clone())
    }

    /// Hybrid executor: currently returns a clone of the stored facts (no inference yet).
    fn execute_hybrid(&self) -> Result<Vec<RuleAtom>> {
        Ok(self.facts.clone())
    }

    /// RETE executor: currently returns a clone of the stored facts (no inference yet).
    fn execute_rete(&self) -> Result<Vec<RuleAtom>> {
        Ok(self.facts.clone())
    }

    /// GPU executor: currently returns a clone of the stored facts (no inference yet).
    fn execute_gpu(&self) -> Result<Vec<RuleAtom>> {
        Ok(self.facts.clone())
    }

    /// Distributed executor: currently returns a clone of the stored facts (no inference yet).
    fn execute_distributed(&self) -> Result<Vec<RuleAtom>> {
        Ok(self.facts.clone())
    }

    /// Sets how many `reason` calls pass between strategy re-selections.
    ///
    /// 0 disables periodic re-selection.
    pub fn set_adaptation_interval(&mut self, interval: usize) {
        self.adaptation_interval = interval;
    }

    /// Returns the strategy currently selected by the engine's selector, if any.
    pub fn current_strategy(&self) -> Option<ReasoningStrategy> {
        self.selector.current_strategy()
    }
}
impl Default for AdaptiveReasoningEngine {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Rule, RuleAtom, Term};

    /// Builds the ground triple `(s, p, o)` from string constants.
    fn create_test_fact(s: &str, p: &str, o: &str) -> RuleAtom {
        let constant = |text: &str| Term::Constant(text.to_string());
        RuleAtom::Triple {
            subject: constant(s),
            predicate: constant(p),
            object: constant(o),
        }
    }

    /// Selector with `strategies` registered in the given order.
    fn selector_with(strategies: &[ReasoningStrategy]) -> AdaptiveStrategySelector {
        let mut selector = AdaptiveStrategySelector::new();
        for &strategy in strategies {
            selector.register_strategy(strategy);
        }
        selector
    }

    /// Medium-sized workload shared by the selection tests.
    fn medium_characteristics() -> DatasetCharacteristics {
        DatasetCharacteristics {
            fact_count: 100,
            rule_count: 10,
            avg_rule_body_size: 2.0,
            density: 0.05,
            avg_degree: 3.0,
            predicate_count: 5,
            rule_selectivity: 0.5,
            data_skew: 0.2,
        }
    }

    /// A single 100 ms run sample used by the history tests.
    fn sample_metrics() -> StrategyMetrics {
        StrategyMetrics {
            execution_time: Duration::from_millis(100),
            memory_usage: 1024,
            throughput: 100.0,
            quality: 0.95,
            execution_count: 1,
        }
    }

    #[test]
    fn test_selector_creation() {
        let selector = AdaptiveStrategySelector::new();
        assert!(selector.strategies.is_empty());
        assert_eq!(selector.current_strategy, None);
    }

    #[test]
    fn test_register_strategy() {
        let selector = selector_with(&[ReasoningStrategy::Forward, ReasoningStrategy::Backward]);
        assert_eq!(selector.strategies.len(), 2);
    }

    #[test]
    fn test_analyze_dataset() {
        let selector = AdaptiveStrategySelector::new();
        let facts: Vec<RuleAtom> = [("a", "b"), ("b", "c"), ("c", "d")]
            .iter()
            .map(|&(s, o)| create_test_fact(s, "p", o))
            .collect();
        let characteristics = selector.analyze_dataset(&facts);
        assert_eq!(characteristics.fact_count, 3);
        assert!(characteristics.density > 0.0);
    }

    #[test]
    fn test_select_strategy() -> Result<(), Box<dyn std::error::Error>> {
        let mut selector =
            selector_with(&[ReasoningStrategy::Forward, ReasoningStrategy::Backward]);
        let strategy = selector.select_strategy(&medium_characteristics())?;
        assert!(matches!(
            strategy,
            ReasoningStrategy::Forward | ReasoningStrategy::Backward
        ));
        Ok(())
    }

    #[test]
    fn test_cost_estimation() {
        let selector = AdaptiveStrategySelector::new();
        let characteristics = DatasetCharacteristics {
            fact_count: 1000,
            rule_count: 50,
            avg_rule_body_size: 3.0,
            density: 0.1,
            avg_degree: 5.0,
            predicate_count: 20,
            rule_selectivity: 0.5,
            data_skew: 0.1,
        };
        for &strategy in &[ReasoningStrategy::Forward, ReasoningStrategy::Backward] {
            assert!(selector.estimate_cost(strategy, &characteristics) > 0.0);
        }
    }

    #[test]
    fn test_performance_recording() {
        let mut selector = selector_with(&[ReasoningStrategy::Forward]);
        selector.record_performance(ReasoningStrategy::Forward, sample_metrics());
        let recorded = selector
            .get_performance_history(ReasoningStrategy::Forward)
            .map(Vec::len);
        assert_eq!(recorded, Some(1));
    }

    #[test]
    fn test_cost_weights() {
        let mut selector = AdaptiveStrategySelector::new();
        selector.set_cost_weights(0.5, 0.3, 0.1, 0.1);
        let weights = &selector.cost_weights;
        assert!((weights.time_weight - 0.5).abs() < 1e-6);
        assert!((weights.memory_weight - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_exploration_rate() {
        let mut selector = AdaptiveStrategySelector::new();
        selector.set_exploration_rate(0.2);
        assert!((selector.exploration_rate - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_recommend_for_workload() {
        let selector = AdaptiveStrategySelector::new();
        let small_chars = DatasetCharacteristics {
            fact_count: 50,
            rule_count: 5,
            avg_rule_body_size: 2.0,
            density: 0.01,
            avg_degree: 2.0,
            predicate_count: 3,
            rule_selectivity: 0.3,
            data_skew: 0.1,
        };
        let resources = SystemResources {
            cpu_cores: 4,
            memory_available: 8 * 1024 * 1024 * 1024,
            gpu_available: false,
            distributed_nodes: 1,
            cpu_utilization: 0.5,
            memory_utilization: 0.6,
        };
        assert_eq!(
            selector.recommend_for_workload(&small_chars, &resources),
            ReasoningStrategy::Backward
        );
    }

    #[test]
    fn test_recommend_gpu_for_large_dataset() {
        let selector = AdaptiveStrategySelector::new();
        let large_chars = DatasetCharacteristics {
            fact_count: 15000,
            rule_count: 100,
            avg_rule_body_size: 3.0,
            density: 0.05,
            avg_degree: 10.0,
            predicate_count: 50,
            rule_selectivity: 0.5,
            data_skew: 0.2,
        };
        let resources = SystemResources {
            cpu_cores: 8,
            memory_available: 16 * 1024 * 1024 * 1024,
            gpu_available: true,
            distributed_nodes: 1,
            cpu_utilization: 0.3,
            memory_utilization: 0.4,
        };
        assert_eq!(
            selector.recommend_for_workload(&large_chars, &resources),
            ReasoningStrategy::GPU
        );
    }

    #[test]
    fn test_adaptive_engine_creation() {
        let engine = AdaptiveReasoningEngine::new();
        assert!(engine.facts.is_empty());
        assert!(engine.rules.is_empty());
    }

    #[test]
    fn test_adaptive_engine_add_facts() {
        let mut engine = AdaptiveReasoningEngine::new();
        engine.add_facts(vec![create_test_fact("a", "p", "b")]);
        assert_eq!(engine.facts.len(), 1);
    }

    #[test]
    fn test_adaptive_engine_add_rules() {
        let mut engine = AdaptiveReasoningEngine::new();
        let rule = Rule {
            name: "test".to_string(),
            body: vec![],
            head: vec![],
        };
        engine.add_rules(vec![rule]);
        assert_eq!(engine.rules.len(), 1);
    }

    #[test]
    fn test_adaptive_engine_reason() -> Result<(), Box<dyn std::error::Error>> {
        let mut engine = AdaptiveReasoningEngine::new();
        engine.add_facts(vec![create_test_fact("a", "p", "b")]);
        let results = engine.reason()?;
        assert!(!results.is_empty());
        Ok(())
    }

    #[test]
    fn test_adaptation_interval() {
        let mut engine = AdaptiveReasoningEngine::new();
        engine.set_adaptation_interval(50);
        assert_eq!(engine.adaptation_interval, 50);
    }

    #[test]
    fn test_strategy_switching() -> Result<(), Box<dyn std::error::Error>> {
        let mut selector =
            selector_with(&[ReasoningStrategy::Forward, ReasoningStrategy::Backward]);
        let characteristics = medium_characteristics();
        for _ in 0..2 {
            selector.select_strategy(&characteristics)?;
        }
        Ok(())
    }

    #[test]
    fn test_performance_history_limit() {
        let mut selector = selector_with(&[ReasoningStrategy::Forward]);
        (0..150).for_each(|_| {
            selector.record_performance(ReasoningStrategy::Forward, sample_metrics())
        });
        let retained = selector
            .get_performance_history(ReasoningStrategy::Forward)
            .map(Vec::len);
        assert_eq!(retained, Some(100));
    }
}