rsrl 0.8.1

A fast, extensible reinforcement learning framework in Rust
Documentation
//! Actor-critic algorithms.
use crate::{
    domains::Transition,
    fa::StateActionUpdate,
    policies::Policy,
    Function,
    Handler,
};

/// A source of scalar learning signals computed from individual transitions.
///
/// Implementors map a borrowed [`Transition`] to a single `f64`. Within this
/// module that value is scaled by `ActorCritic::alpha` and forwarded to the
/// policy as the `error` field of a `StateActionUpdate`.
pub trait Critic<'t, S: 't, A: 't> {
    /// Computes the critic's target value for the transition `t`.
    fn target(&self, t: &'t Transition<S, A>) -> f64;
}

/// Lets any callable of the form `Fn(&Transition<S, A>) -> f64` act as a
/// [`Critic`]; the target is simply whatever the callable returns.
impl<'t, F, S, A> Critic<'t, S, A> for F
where
    F: Fn(&'t Transition<S, A>) -> f64,
    S: 't,
    A: 't,
{
    /// Invokes the callable on `t` and returns its result unchanged.
    fn target(&self, t: &'t Transition<S, A>) -> f64 {
        self(t)
    }
}

/// Critic that wraps an action-value function.
///
/// The wrapped function is evaluated on the `(state, action)` pair a
/// transition started from; see the [`Critic`] implementation for this type.
///
/// The derives mirror those on `ActorCritic`, so that an
/// `ActorCritic<QCritic<Q>, P>` can itself be cloned, debug-printed and (with
/// the `serde` feature) serialised whenever `Q` supports it.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct QCritic<Q>(pub Q);

impl<'t, Q, S: 't, A: 't> Critic<'t, S, A> for QCritic<Q>
where Q: Function<(&'t S, &'t A), Output = f64>,
{
    fn target(&self, t: &'t Transition<S, A>) -> f64 {
        self.0.evaluate((t.from.state(), &t.action))
    }
}

/// Critic built from a state-value function and a discount factor.
///
/// See the [`Critic`] implementation for this type for how the two fields are
/// combined into a target.
///
/// The derives mirror those on `ActorCritic`, so that an
/// `ActorCritic<TDCritic<V>, P>` can itself be cloned, debug-printed and (with
/// the `serde` feature) serialised whenever `V` supports it.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct TDCritic<V> {
    /// Factor applied to the successor state's value when forming the target.
    pub gamma: f64,
    /// State-value function evaluated on the transition's states.
    pub v_func: V,
}

/// Produces the one-step temporal-difference error of the wrapped value
/// function as the critic target.
///
/// * Non-terminal transition: `reward + gamma * V(to) - V(from)`.
/// * Terminal transition: `reward - V(from)`; there is no successor to
///   bootstrap from, so the successor state is not evaluated at all.
impl<'t, V, S: 't, A: 't> Critic<'t, S, A> for TDCritic<V>
where V: Function<(&'t S,), Output = f64>,
{
    /// Returns the TD error for `t`, measured against the value of the state
    /// the transition started from.
    fn target(&self, t: &'t Transition<S, A>) -> f64 {
        // The baseline is always the value of the state being left; using the
        // successor's value here would not be a TD error for `from`.
        let v = self.v_func.evaluate((t.from.state(),));

        if t.terminated() {
            // No bootstrap term once the episode has ended.
            t.reward - v
        } else {
            let nv = self.v_func.evaluate((t.to.state(),));

            t.reward + self.gamma * nv - v
        }
    }
}

/// Actor-critic agent pairing a critic with a policy (the actor).
///
/// When handling a transition, the critic's target is multiplied by `alpha`
/// and passed to the policy as the `error` of a `StateActionUpdate`.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct ActorCritic<C, P> {
    /// Critic supplying the target for each handled transition.
    pub critic: C,
    /// Policy that receives the resulting state-action update.
    pub policy: P,

    /// Multiplier applied to the critic's target before it reaches the policy.
    pub alpha: f64,
}

impl<C, P> ActorCritic<C, P> {
    pub fn new(critic: C, policy: P, alpha: f64) -> Self {
        ActorCritic {
            critic,
            policy,

            alpha,
        }
    }
}

impl<Q, P> ActorCritic<QCritic<Q>, P> {
    /// Builds an agent whose critic is `q_func` wrapped in a [`QCritic`].
    ///
    /// `policy` and `alpha` are stored as given.
    pub fn qac(q_func: Q, policy: P, alpha: f64) -> Self {
        let critic = QCritic(q_func);

        Self { critic, policy, alpha }
    }
}

impl<V, P> ActorCritic<TDCritic<V>, P> {
    /// Builds an agent whose critic is a [`TDCritic`] made from `v_func` and
    /// `gamma`.
    ///
    /// `policy` and `alpha` are stored as given.
    pub fn tdac(v_func: V, policy: P, alpha: f64, gamma: f64) -> Self {
        let critic = TDCritic { gamma, v_func };

        Self { critic, policy, alpha }
    }
}

impl<'m, S, C, P> Handler<&'m Transition<S, P::Action>> for ActorCritic<C, P>
where
    C: Critic<'m, S, P::Action>,
    P: Policy<&'m S> + Handler<StateActionUpdate<&'m S, &'m <P as Policy<&'m S>>::Action, f64>>,
{
    type Response = P::Response;
    type Error = P::Error;

    fn handle(&mut self, t: &'m Transition<S, P::Action>) -> Result<Self::Response, Self::Error> {
        self.policy.handle(StateActionUpdate {
            state: t.from.state(),
            action: &t.action,
            error: self.alpha * self.critic.target(&t),
        })
    }
}