use crate::utils::error::OctopusError;
use ndarray::{Array, Array1, Dimension, Ix1};
use rand::{Rng, rng};
use std::collections::HashMap;
use std::hash::Hash; use std::ops::{Deref, DerefMut};
/// Something that can be chosen, identified by a numeric id and carrying a payload.
///
/// Implementors must be cloneable, comparable, hashable and shareable across
/// threads (`Send + Sync + 'static`).
pub trait Action: Clone + Eq + Hash + Send + Sync + 'static {
    /// Type of the payload returned by [`Action::value`].
    type ValueType;

    /// Numeric identifier of this action.
    fn id(&self) -> u32;

    /// Returns the payload carried by this action.
    fn value(&self) -> Self::ValueType;

    /// Human-readable label. Defaults to `"Action-<id>"`.
    fn name(&self) -> String {
        let id = self.id();
        format!("Action-{}", id)
    }
}
/// A concrete action holding an id, a payload value and a display name.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the
/// type itself, so `NumericAction<T>` can be named for any `T`. The derived
/// traits (`Eq`, `Hash`, ...) compare or hash all three fields and are only
/// available when `T` supports them.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct NumericAction<T> {
    /// Identifier returned by `Action::id`.
    id: u32,
    /// Payload returned by `Action::value`.
    value: T,
    /// Label returned by `Action::name`.
    name: String,
}
impl<T: Copy + Eq + Hash + Send + Sync + 'static> NumericAction<T> {
    /// Creates an action whose id is a freshly drawn random `u32`.
    ///
    /// Uniqueness is not checked, so two actions created this way may share
    /// an id. Use [`NumericAction::with_id`] when ids must be controlled.
    pub fn new(value: T, name: &str) -> Self {
        let random_id: u32 = rng().random();
        Self::with_id(random_id, value, name)
    }

    /// Creates an action with an explicitly chosen id.
    pub fn with_id(id: u32, value: T, name: &str) -> Self {
        let name = name.to_owned();
        Self { id, value, name }
    }
}
/// Exposes the stored id, value and name through the [`Action`] trait.
impl<T: Copy + Eq + Hash + Send + Sync + 'static> Action for NumericAction<T> {
    type ValueType = T;

    /// Returns a clone of the stored name (overrides the `"Action-<id>"` default).
    fn name(&self) -> String {
        self.name.to_owned()
    }

    /// Returns the stored payload by copy.
    fn value(&self) -> T {
        let Self { value, .. } = self;
        *value
    }

    /// Returns the stored id.
    fn id(&self) -> u32 {
        let Self { id, .. } = self;
        *id
    }
}
/// A collection of actions stored in a `HashMap` keyed by `u32` action id.
#[derive(Clone, Debug)]
pub struct ActionStorage<A>(HashMap<u32, A>)
where
    A: Action;
impl<A: Action> ActionStorage<A> {
    /// Builds a storage from a slice of actions, keyed by [`Action::id`].
    ///
    /// Every action is cloned into the map. If several actions share an id,
    /// the one appearing last in `initial_actions` is kept, because later
    /// inserts overwrite earlier ones.
    ///
    /// # Errors
    ///
    /// This constructor currently has no failure path and always returns `Ok`.
    pub fn new(initial_actions: &[A]) -> Result<Self, OctopusError> {
        // `Action: Clone`, so no extra bound is needed. `.iter()` makes it
        // explicit that we borrow the slice.
        let actions = initial_actions
            .iter()
            .map(|action| (action.id(), action.clone()))
            .collect();
        Ok(Self(actions))
    }

    /// Returns clones of all stored actions.
    ///
    /// The order follows `HashMap` iteration and is unspecified.
    pub fn get_all_actions(&self) -> Vec<A> {
        self.0.values().cloned().collect()
    }
}
/// Gives read access to the underlying `HashMap<u32, A>` (e.g. `get`, `len`, `iter`).
impl<A> Deref for ActionStorage<A>
where
    A: Action,
{
    type Target = HashMap<u32, A>;

    fn deref(&self) -> &HashMap<u32, A> {
        let Self(map) = self;
        map
    }
}
impl<A: Action> DerefMut for ActionStorage<A> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
/// A reward signal that can be reduced to a single `f64`.
pub trait Reward
where
    Self: Clone + Send + Sync + 'static,
{
    /// Numeric value of this reward.
    fn value(&self) -> f64;
}
/// Input features that can be exported as an `f64` ndarray.
pub trait Context
where
    Self: Clone + Send + Sync + 'static,
{
    /// Dimensionality of the array produced by [`Context::to_ndarray`].
    type DimType: Dimension;

    /// Converts this context into an owned `f64` array of dimension `DimType`.
    fn to_ndarray(&self) -> Array<f64, Self::DimType>;
}
/// Placeholder context type that carries no data.
#[derive(Clone, Debug, PartialEq)]
pub struct DummyContext;
impl Context for DummyContext {
    type DimType = Ix1;

    /// Always yields the single-element 1-D array `[0.0]`.
    fn to_ndarray(&self) -> Array<f64, Ix1> {
        Array1::from_elem(1, 0.0)
    }
}