use std::marker::PhantomData;
use num_traits::ToPrimitive;
use crate::agent::agent::update;
use crate::Stepper;
use super::{Agent, ArgBounds};
/// A greedy bandit agent: `action` always returns the arm chosen by `arg_max`
/// over the current estimates, so it never deliberately explores.
///
/// `T` is the reward type accepted by `Agent::step`. No `T` value is stored;
/// the type parameter is carried only through `PhantomData`.
pub struct GreedyAgent<T> {
    /// Current value estimate for each arm, indexed by arm number.
    q_star: Vec<f64>,
    /// Stepper handed to `update` on every `step` and reset by `reset`
    /// (presumably a step-size schedule — see the `Stepper` trait).
    stepper: Box<dyn Stepper>,
    /// Zero-sized marker binding the agent to its reward type `T`.
    phantom: PhantomData<T>,
}
impl<T: ToPrimitive> Agent<T> for GreedyAgent<T> {
    /// Exploits only: returns the index chosen by `arg_max` over the estimates
    /// (the highest-valued arm; tie-breaking is whatever `arg_max` defines).
    fn action(&self) -> usize {
        let estimates = &self.q_star;
        estimates.arg_max()
    }

    /// Number of arms, i.e. how many estimates the agent currently holds.
    fn arms(&self) -> usize {
        let estimates = &self.q_star;
        estimates.len()
    }

    /// Current estimate for `arm`.
    ///
    /// # Panics
    /// Panics if `arm` is out of range.
    fn current_estimate(&self, arm: usize) -> f64 {
        let estimates = &self.q_star;
        estimates[arm]
    }

    /// Replaces all estimates with a copy of `q_init` and resets the stepper.
    ///
    /// `q_init` may have a different length from the previous estimates; the
    /// arm count follows it.
    fn reset(&mut self, q_init: &[f64]) {
        // Refill the existing buffer instead of allocating a new Vec.
        self.q_star.clear();
        self.q_star.extend_from_slice(q_init);
        self.stepper.reset();
    }

    /// Folds one observed `reward` for `arm` into that arm's estimate.
    ///
    /// The increment is computed by `update` from the stepper, the current
    /// estimates, the arm and the reward, and is then added to the arm's estimate.
    fn step(&mut self, arm: usize, reward: T) {
        // Computed before indexing, matching `+=` evaluation order for primitives.
        let delta = update(&mut self.stepper, &self.q_star, arm, reward);
        self.q_star[arm] += delta;
    }
}
impl<T> GreedyAgent<T> {
    /// Creates a greedy agent from initial per-arm estimates and a stepper.
    ///
    /// `q_init` is stored as-is (one entry per arm). `stepper` is taken by value;
    /// it is later passed to `update` on every step and reset by `reset`.
    pub fn new(q_init: Vec<f64>, stepper: Box<dyn Stepper>) -> GreedyAgent<T> {
        let q_star = q_init;
        Self {
            q_star,
            stepper,
            phantom: PhantomData,
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::HarmonicStepper;
    use super::{Agent, GreedyAgent};

    /// Initial estimates shared by the tests; the maximum (0.7) is at index 2.
    /// A real `const` (instead of per-test locals named `Q_INIT`) avoids the
    /// `non_snake_case` lint.
    const Q_INIT: [f64; 5] = [0.5, 0.61, 0.7, 0.12, 0.37];

    /// Builds a greedy agent over `Q_INIT` with a fresh `HarmonicStepper`.
    fn make_agent() -> GreedyAgent<u32> {
        let stepper = HarmonicStepper::new(1, Q_INIT.len());
        GreedyAgent::new(Q_INIT.to_vec(), Box::new(stepper))
    }

    #[test]
    fn test_action() {
        let greedy = make_agent();
        assert_eq!(greedy.action(), 2);
    }

    #[test]
    fn test_q_star() {
        let greedy = make_agent();
        assert_eq!(greedy.q_star, Q_INIT);
    }

    #[test]
    fn test_arms_and_estimates() {
        let greedy = make_agent();
        assert_eq!(greedy.arms(), Q_INIT.len());
        for (arm, &expected) in Q_INIT.iter().enumerate() {
            assert_eq!(greedy.current_estimate(arm), expected);
        }
    }

    #[test]
    fn test_reset() {
        let mut greedy = make_agent();
        let new_q = [0.01, 0.86, 0.43, 0.65, 0.66];
        greedy.reset(&new_q);
        assert_eq!(greedy.q_star, new_q);
        // The greedy choice must follow the new estimates (max 0.86 at index 1).
        assert_eq!(greedy.action(), 1);
    }
}