use crate::predictive::{AnomalyDetector, StateEncoder, TransitionModel};
use crate::{Action, Observation, Timestamp};
use std::collections::{HashMap, VecDeque};
/// Model-based predictor: keeps a bounded rolling window of observed states
/// and learns transition / reward statistics from recorded
/// (state, action, reward, next-state) tuples.
pub struct PredictiveModel {
// Rolling window of recent snapshots; bounded by `config.history_size`.
state_history: VecDeque<StateSnapshot>,
// Learns discretized state-transition statistics.
transition_model: TransitionModel,
// Running-mean reward estimate per (state, action) pair.
reward_predictor: RewardPredictor,
// Flags observations that deviate from what has been seen so far.
anomaly_detector: AnomalyDetector,
// Capacities and thresholds; see `PredictiveConfig`.
config: PredictiveConfig,
}
/// Tunable parameters for [`PredictiveModel`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct PredictiveConfig {
/// Maximum number of snapshots retained in the state history.
pub history_size: usize,
/// Number of steps a trajectory prediction may look ahead.
/// NOTE(review): not read anywhere in this file — confirm callers use it.
pub prediction_horizon: usize,
/// Minimum confidence for a prediction to be trusted.
/// NOTE(review): not read anywhere in this file — confirm callers use it.
pub confidence_threshold: f64,
/// Threshold handed to `AnomalyDetector::new`; its interpretation is
/// defined by the detector.
pub anomaly_threshold: f64,
}
impl Default for PredictiveConfig {
fn default() -> Self {
Self {
history_size: 1000,
prediction_horizon: 10,
confidence_threshold: 0.5,
anomaly_threshold: 2.0, }
}
}
/// One recorded observation plus derived data, as stored in the history.
#[derive(Clone, Debug)]
pub struct StateSnapshot {
/// The raw observation as recorded.
pub observation: Observation,
/// When the snapshot was taken (capture time, via `Timestamp::now()`).
pub timestamp: Timestamp,
/// Numeric features extracted from the observation: optional scalar
/// payload, age in seconds, confidence (see `extract_features`).
pub features: Vec<f64>,
}
/// A single-step prediction with an attached confidence score.
#[derive(Clone, Debug)]
pub struct PredictedState {
/// The predicted observation.
pub observation: Observation,
/// Highest known transition probability for the queried (state, action)
/// pair; 0.0 when that pair has never been observed.
pub confidence: f64,
/// When the prediction was made.
pub timestamp: Timestamp,
}
/// A multi-step rollout produced by `PredictiveModel::predict_trajectory`.
#[derive(Clone, Debug)]
pub struct Trajectory {
/// Predicted state after each action, in order.
pub states: Vec<PredictedState>,
/// Sum of the predicted per-step rewards.
pub total_reward: f64,
/// Minimum (weakest-link) single-step confidence along the rollout.
pub confidence: f64,
}
impl PredictiveModel {
pub fn new(config: PredictiveConfig) -> Self {
let discretization_bins = 100;
Self {
state_history: VecDeque::with_capacity(config.history_size),
transition_model: TransitionModel::new(discretization_bins),
reward_predictor: RewardPredictor::new(),
anomaly_detector: AnomalyDetector::new(config.anomaly_threshold, config.history_size),
config,
}
}
pub fn with_default_config() -> Self {
Self::new(PredictiveConfig::default())
}
pub fn record(&mut self, obs: &Observation) {
let features = self.extract_features(obs);
let snapshot = StateSnapshot {
observation: obs.clone(),
timestamp: Timestamp::now(),
features,
};
if self.state_history.len() >= self.config.history_size {
self.state_history.pop_front();
}
self.state_history.push_back(snapshot);
self.anomaly_detector.update(obs);
}
pub fn record_transition(
&mut self,
obs: &Observation,
action: &Action,
reward: f64,
next_obs: &Observation,
) {
self.record(obs);
self.record(next_obs);
self.transition_model.record(obs, action, next_obs);
self.reward_predictor.record(obs, action, reward);
}
pub fn predict_next(&self, current: &Observation, action: &Action) -> PredictedState {
let state_key = self.transition_model.predict(current, action);
match state_key {
Some(_key) => {
let probs = self.transition_model.get_transition_probs(current, action);
let confidence = probs
.values()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.copied()
.unwrap_or(0.0);
PredictedState {
observation: current.clone(),
confidence,
timestamp: Timestamp::now(),
}
}
None => {
PredictedState {
observation: current.clone(),
confidence: 0.0,
timestamp: Timestamp::now(),
}
}
}
}
pub fn predict_trajectory(&self, start: &Observation, actions: &[Action]) -> Trajectory {
let mut states = Vec::with_capacity(actions.len() + 1);
let mut total_reward = 0.0;
let mut current = start.clone();
let mut min_confidence: f64 = 1.0;
for action in actions {
let predicted = self.predict_next(¤t, action);
let reward = self.predict_reward(¤t, action);
total_reward += reward;
min_confidence = min_confidence.min(predicted.confidence);
states.push(predicted.clone());
current = predicted.observation;
}
Trajectory {
states,
total_reward,
confidence: min_confidence,
}
}
pub fn predict_reward(&self, state: &Observation, action: &Action) -> f64 {
self.reward_predictor.predict(state, action)
}
pub fn is_anomaly(&self, obs: &Observation) -> bool {
self.anomaly_detector.is_anomaly(obs)
}
pub fn get_confidence(&self, prediction: &PredictedState) -> f64 {
prediction.confidence
}
pub fn get_uncertainty(&self, state: &Observation) -> f64 {
let anomaly_score = self.anomaly_detector.anomaly_score(state);
anomaly_score.min(1.0) }
pub fn learn(&mut self) {
}
pub fn history(&self) -> &VecDeque<StateSnapshot> {
&self.state_history
}
fn extract_features(&self, obs: &Observation) -> Vec<f64> {
let mut features = Vec::new();
if let Some(f) = obs.value.as_f64() {
features.push(f);
} else if let Some(i) = obs.value.as_i64() {
features.push(i as f64);
} else if let Some(b) = obs.value.as_bool() {
features.push(if b { 1.0 } else { 0.0 });
}
features.push(obs.age_secs() as f64);
features.push(obs.confidence.value() as f64);
features
}
}
/// Tracks running-mean rewards keyed by (encoded state, hashed action).
pub struct RewardPredictor {
// (state_key, action_key) -> (reward sum, sample count).
rewards: HashMap<(u64, u64), (f64, u64)>,
// Discretizes observations into u64 state keys.
encoder: StateEncoder,
}
impl RewardPredictor {
    /// Creates an empty predictor with a 100-bin state encoder.
    /// NOTE(review): the bin count mirrors `discretization_bins` in
    /// `PredictiveModel::new` — keep the two in sync.
    pub fn new() -> Self {
        Self {
            rewards: HashMap::new(),
            encoder: StateEncoder::new(100),
        }
    }

    /// Folds one reward sample into the running (sum, count) for the pair.
    pub fn record(&mut self, state: &Observation, action: &Action, reward: f64) {
        let key = (self.encoder.encode(state), self.hash_action(action));
        self.rewards
            .entry(key)
            .and_modify(|(sum, count)| {
                *sum += reward;
                *count += 1;
            })
            .or_insert((reward, 1));
    }

    /// Mean recorded reward for (state, action); 0.0 when unseen.
    pub fn predict(&self, state: &Observation, action: &Action) -> f64 {
        let key = (self.encoder.encode(state), self.hash_action(action));
        match self.rewards.get(&key) {
            Some(&(total, count)) => total / count as f64,
            None => 0.0,
        }
    }

    /// Hashes the action into a map key; only `action_type` participates,
    /// so actions differing in other fields share a key.
    fn hash_action(&self, action: &Action) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut h = DefaultHasher::new();
        action.action_type.hash(&mut h);
        h.finish()
    }
}
impl Default for RewardPredictor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ActionType;

    #[test]
    fn test_predictive_model_creation() {
        // A fresh model starts with an empty history.
        let model = PredictiveModel::with_default_config();
        assert_eq!(model.history().len(), 0);
    }

    #[test]
    fn test_record_observation() {
        let mut model = PredictiveModel::with_default_config();
        let temp = Observation::sensor("temp", 20.0);
        model.record(&temp);
        assert_eq!(model.history().len(), 1);
    }

    #[test]
    fn test_predict_next_state() {
        let mut model = PredictiveModel::with_default_config();
        let before = Observation::sensor("temp", 20.0);
        let after = Observation::sensor("temp", 21.0);
        let heat = Action::new(ActionType::Custom("heat".to_string()));
        model.record_transition(&before, &heat, 1.0, &after);
        model.learn();
        // A recorded transition must yield a non-negative confidence.
        let prediction = model.predict_next(&before, &heat);
        assert!(prediction.confidence >= 0.0);
    }

    #[test]
    fn test_predict_reward() {
        let mut model = PredictiveModel::with_default_config();
        let obs = Observation::sensor("temp", 20.0);
        let heat = Action::new(ActionType::Custom("heat".to_string()));
        model.record_transition(&obs, &heat, 5.0, &obs);
        // A single sample should give back (almost) exactly its reward.
        assert!((model.predict_reward(&obs, &heat) - 5.0).abs() < 0.1);
    }

    #[test]
    fn test_predict_trajectory() {
        let mut model = PredictiveModel::with_default_config();
        let obs = Observation::sensor("temp", 20.0);
        let heat = Action::new(ActionType::Custom("heat".to_string()));
        let wait = Action::new(ActionType::Wait);
        model.record_transition(&obs, &heat, 1.0, &obs);
        // One predicted state per action in the plan.
        let trajectory = model.predict_trajectory(&obs, &[heat, wait]);
        assert_eq!(trajectory.states.len(), 2);
    }

    #[test]
    fn test_history_size_limit() {
        // Recording 10 observations into a 5-slot history keeps only
        // the most recent 5.
        let config = PredictiveConfig {
            history_size: 5,
            ..Default::default()
        };
        let mut model = PredictiveModel::new(config);
        for i in 0..10 {
            model.record(&Observation::sensor("temp", i as f64));
        }
        assert_eq!(model.history().len(), 5);
    }
}