use crate::learning::engine::StateId;
use crate::observation::Observation;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Common interface for state-value estimators used by the learning engine.
pub trait ValueFunction {
    /// Current value estimate for `state`.
    fn evaluate(&self, state: &StateId) -> f64;
    /// Move the estimate for `state` toward `target`, scaled by `learning_rate`.
    fn update(&mut self, state: &StateId, target: f64, learning_rate: f64);
    /// Discard everything learned so far.
    fn reset(&mut self);
    /// Number of learned entries/parameters.
    fn size(&self) -> usize;
}
/// Lookup-table value function: one independently learned value per discrete state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TabularValueFunction {
    // Learned value for each state seen so far.
    values: HashMap<StateId, f64>,
    // Value reported for states that have never been updated.
    default_value: f64,
}
impl TabularValueFunction {
    /// Create an empty table; unseen states evaluate to `default_value`.
    pub fn new(default_value: f64) -> Self {
        let values = HashMap::new();
        Self { values, default_value }
    }

    /// Borrow the complete state -> value table.
    pub fn get_all_values(&self) -> &HashMap<StateId, f64> {
        &self.values
    }

    /// Overwrite the stored value for `state` directly (no learning-rate blend).
    pub fn set_value(&mut self, state: &StateId, value: f64) {
        self.values.insert(state.clone(), value);
    }
}
impl Default for TabularValueFunction {
fn default() -> Self {
Self::new(0.0)
}
}
impl ValueFunction for TabularValueFunction {
    /// Stored value for `state`, or `default_value` for unseen states.
    fn evaluate(&self, state: &StateId) -> f64 {
        self.values.get(state).copied().unwrap_or(self.default_value)
    }

    /// TD-style blend: v <- v + learning_rate * (target - v).
    ///
    /// Uses the entry API so the key is hashed and located once; the
    /// previous version called `evaluate` and then `insert`, performing
    /// two lookups per update (clippy `map_entry`). Results are identical.
    fn update(&mut self, state: &StateId, target: f64, learning_rate: f64) {
        let v = self
            .values
            .entry(state.clone())
            .or_insert(self.default_value);
        *v += learning_rate * (target - *v);
    }

    /// Drop every learned entry; unseen states fall back to the default.
    fn reset(&mut self) {
        self.values.clear();
    }

    /// Number of states with a learned value.
    fn size(&self) -> usize {
        self.values.len()
    }
}
/// Boxed function mapping an observation to a feature vector; `Send + Sync`
/// so extractors can be stored and shared across threads.
pub type FeatureExtractor = Box<dyn Fn(&Observation) -> Vec<f64> + Send + Sync>;
/// Linear function approximator: value = dot(weights, features).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearValueFunction {
    // One weight per feature dimension.
    weights: Vec<f64>,
    // Expected length of incoming feature vectors.
    num_features: usize,
    // Fallback value reported on feature-length mismatch and by the
    // state-id-based `ValueFunction` impl below.
    default_value: f64,
    // Per-weight update counters (diagnostics only).
    // NOTE(review): `#[serde(skip)]` means a deserialized instance gets an
    // EMPTY vec here (`Vec::default()`), not `vec![0; num_features]`, so its
    // length will not match `num_features`; any code indexing it must
    // tolerate that — confirm against `update_features`.
    #[serde(skip)]
    update_counts: Vec<u64>,
}
impl LinearValueFunction {
    /// New approximator with `num_features` zero weights.
    pub fn new(num_features: usize, default_value: f64) -> Self {
        Self {
            weights: vec![0.0; num_features],
            num_features,
            default_value,
            update_counts: vec![0; num_features],
        }
    }

    /// Build from pre-trained weights; the feature count is inferred.
    pub fn with_weights(weights: Vec<f64>) -> Self {
        let num_features = weights.len();
        Self {
            weights,
            num_features,
            default_value: 0.0,
            update_counts: vec![0; num_features],
        }
    }

    /// Dot product of the weights with `features`.
    ///
    /// Returns `default_value` when the feature vector has the wrong length.
    pub fn evaluate_features(&self, features: &[f64]) -> f64 {
        if features.len() != self.num_features {
            return self.default_value;
        }
        self.weights
            .iter()
            .zip(features.iter())
            .map(|(w, f)| w * f)
            .sum()
    }

    /// Gradient-style update: w_i += learning_rate * (target - prediction) * f_i.
    ///
    /// Wrong-length feature vectors are silently ignored.
    pub fn update_features(&mut self, features: &[f64], target: f64, learning_rate: f64) {
        if features.len() != self.num_features {
            return;
        }
        // BUGFIX: `update_counts` is `#[serde(skip)]`, so a deserialized
        // instance arrives with an empty counter vec. Without this guard,
        // `self.update_counts[i]` below panics out of bounds on the first
        // post-deserialization update. Resize lazily to the expected length.
        if self.update_counts.len() != self.num_features {
            self.update_counts.resize(self.num_features, 0);
        }
        let prediction = self.evaluate_features(features);
        let error = target - prediction;
        for (i, &feature) in features.iter().enumerate() {
            self.weights[i] += learning_rate * error * feature;
            self.update_counts[i] += 1;
        }
    }

    /// Current weight vector.
    pub fn get_weights(&self) -> &[f64] {
        &self.weights
    }

    /// How many times each weight has been updated.
    pub fn get_update_counts(&self) -> &[u64] {
        &self.update_counts
    }

    /// Expected feature-vector length.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}
impl Default for LinearValueFunction {
fn default() -> Self {
Self::new(10, 0.0)
}
}
impl ValueFunction for LinearValueFunction {
    /// State-id-based evaluation is not meaningful for a linear
    /// approximator (it needs a feature vector, not a discrete state id),
    /// so this always reports the configured fallback value.
    fn evaluate(&self, _state: &StateId) -> f64 {
        self.default_value
    }

    /// Intentional no-op; use `update_features` with an extracted
    /// feature vector instead.
    fn update(&mut self, _state: &StateId, _target: f64, _learning_rate: f64) {}

    /// Zero every weight and update counter in place.
    fn reset(&mut self) {
        for w in self.weights.iter_mut() {
            *w = 0.0;
        }
        for c in self.update_counts.iter_mut() {
            *c = 0;
        }
    }

    /// Parameter count (one weight per feature).
    fn size(&self) -> usize {
        self.num_features
    }
}
/// Ready-made feature extractors for use with `LinearValueFunction`.
pub mod feature_extractors {
    use super::*;

    /// Single feature: the observation's numeric value (0.0 if non-numeric).
    pub fn numeric_value(obs: &Observation) -> Vec<f64> {
        let x = obs.value.as_f64().unwrap_or(0.0);
        vec![x]
    }

    /// Two features: numeric value and confidence.
    pub fn value_and_confidence(obs: &Observation) -> Vec<f64> {
        let x = obs.value.as_f64().unwrap_or(0.0);
        let conf = obs.confidence.value() as f64;
        vec![x, conf]
    }

    /// Three features: numeric value, confidence, and age in seconds.
    pub fn value_confidence_age(obs: &Observation) -> Vec<f64> {
        let x = obs.value.as_f64().unwrap_or(0.0);
        let conf = obs.confidence.value() as f64;
        let age = obs.age_secs() as f64;
        vec![x, conf, age]
    }

    /// Cubic polynomial basis: [1, x, x^2, x^3].
    pub fn polynomial_features(obs: &Observation) -> Vec<f64> {
        let x = obs.value.as_f64().unwrap_or(0.0);
        // Multiplication order matches x * x and (x * x) * x exactly,
        // preserving bit-identical float results.
        let x2 = x * x;
        vec![1.0, x, x2, x2 * x]
    }

    /// Min/max normalization of the value into [0, 1]; a degenerate
    /// range (max <= min) always yields 0.0.
    pub fn normalized_value(min: f64, max: f64) -> impl Fn(&Observation) -> Vec<f64> {
        move |obs: &Observation| {
            let x = obs.value.as_f64().unwrap_or(0.0);
            let norm = if max > min {
                ((x - min) / (max - min)).clamp(0.0, 1.0)
            } else {
                0.0
            };
            vec![norm]
        }
    }

    /// Binary indicator: 1.0 when the value is strictly above `threshold`.
    pub fn threshold_feature(threshold: f64) -> impl Fn(&Observation) -> Vec<f64> {
        move |obs: &Observation| {
            let above = obs.value.as_f64().unwrap_or(0.0) > threshold;
            vec![if above { 1.0 } else { 0.0 }]
        }
    }

    /// One binary indicator per threshold, in the given order.
    pub fn multi_threshold_features(thresholds: Vec<f64>) -> impl Fn(&Observation) -> Vec<f64> {
        move |obs: &Observation| {
            let x = obs.value.as_f64().unwrap_or(0.0);
            thresholds
                .iter()
                .map(|&t| if x > t { 1.0 } else { 0.0 })
                .collect()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::observation::Observation;

    /// Convenience helper for building state ids in tests.
    fn create_state_id(name: &str) -> StateId {
        StateId::from_string(name.to_string())
    }

    #[test]
    fn test_tabular_value_function() {
        let mut vf = TabularValueFunction::new(0.0);
        let state = create_state_id("state1");
        // An unseen state reports the default value.
        assert_eq!(vf.evaluate(&state), 0.0);
        // Each update nudges the estimate toward the target.
        vf.update(&state, 1.0, 0.1);
        assert!(vf.evaluate(&state) > 0.0);
        vf.update(&state, 1.0, 0.1);
        assert!(vf.evaluate(&state) > 0.0);
        // Only one state has been learned.
        assert_eq!(vf.size(), 1);
    }

    #[test]
    fn test_tabular_value_function_reset() {
        let mut vf = TabularValueFunction::new(0.0);
        let state = create_state_id("state1");
        vf.update(&state, 1.0, 0.1);
        assert!(vf.size() > 0);
        // Reset empties the table and restores the default estimate.
        vf.reset();
        assert_eq!(vf.size(), 0);
        assert_eq!(vf.evaluate(&state), 0.0);
    }

    #[test]
    fn test_linear_value_function() {
        let mut vf = LinearValueFunction::new(3, 0.0);
        let feats = vec![1.0, 2.0, 3.0];
        // Zero weights => zero prediction.
        assert_eq!(vf.evaluate_features(&feats), 0.0);
        // One gradient step moves the prediction toward the target.
        vf.update_features(&feats, 10.0, 0.1);
        let predicted = vf.evaluate_features(&feats);
        assert!(predicted > 0.0);
        assert_eq!(vf.num_features(), 3);
    }

    #[test]
    fn test_linear_value_function_with_weights() {
        let vf = LinearValueFunction::with_weights(vec![1.0, 2.0, 3.0]);
        let feats = vec![1.0, 1.0, 1.0];
        // dot([1,2,3], [1,1,1]) = 6.
        assert_eq!(vf.evaluate_features(&feats), 6.0);
    }

    #[test]
    fn test_feature_extractor_numeric() {
        let obs = Observation::sensor("temp", 25.5);
        let feats = feature_extractors::numeric_value(&obs);
        assert_eq!(feats.len(), 1);
        assert_eq!(feats[0], 25.5);
    }

    #[test]
    fn test_feature_extractor_value_confidence() {
        let obs = Observation::sensor("temp", 25.0).with_confidence(0.9);
        let feats = feature_extractors::value_and_confidence(&obs);
        assert_eq!(feats.len(), 2);
        assert_eq!(feats[0], 25.0);
        assert!((feats[1] - 0.9).abs() < 0.001);
    }

    #[test]
    fn test_feature_extractor_polynomial() {
        let obs = Observation::sensor("temp", 2.0);
        let feats = feature_extractors::polynomial_features(&obs);
        assert_eq!(feats.len(), 4);
        // [1, x, x^2, x^3] for x = 2.
        assert_eq!(feats[0], 1.0);
        assert_eq!(feats[1], 2.0);
        assert_eq!(feats[2], 4.0);
        assert_eq!(feats[3], 8.0);
    }

    #[test]
    fn test_feature_extractor_normalized() {
        let extractor = feature_extractors::normalized_value(0.0, 100.0);
        let obs = Observation::sensor("temp", 50.0);
        let feats = extractor(&obs);
        assert_eq!(feats.len(), 1);
        // 50 in [0, 100] normalizes to 0.5 exactly.
        assert_eq!(feats[0], 0.5);
    }

    #[test]
    fn test_feature_extractor_threshold() {
        let extractor = feature_extractors::threshold_feature(30.0);
        // Below the threshold -> 0.0.
        let low = Observation::sensor("temp", 20.0);
        assert_eq!(extractor(&low)[0], 0.0);
        // Above the threshold -> 1.0.
        let high = Observation::sensor("temp", 40.0);
        assert_eq!(extractor(&high)[0], 1.0);
    }

    #[test]
    fn test_feature_extractor_multi_threshold() {
        let extractor = feature_extractors::multi_threshold_features(vec![20.0, 30.0, 40.0]);
        let obs = Observation::sensor("temp", 35.0);
        let feats = extractor(&obs);
        assert_eq!(feats.len(), 3);
        // 35 exceeds 20 and 30, but not 40.
        assert_eq!(feats[0], 1.0);
        assert_eq!(feats[1], 1.0);
        assert_eq!(feats[2], 0.0);
    }
}