use crate::{
fa::{GradientUpdate, ScaledGradientUpdate, StateActionUpdate},
params::*,
policies::{sample_probs_with_rng, Policy},
utils::argmax_first,
Differentiable,
Enumerable,
Function,
Handler,
};
use ndarray::{Array2, ArrayBase, Data, Ix2};
use rand::Rng;
use std::{f64, iter::FromIterator};
fn softmax<C: FromIterator<f64>>(values: &[f64], tau: f64, c: f64) -> C {
let mut z = 0.0;
let ps: Vec<f64> = values
.into_iter()
.map(|v| {
let v = ((v - c) / tau).exp();
z += v;
v
})
.collect();
ps.into_iter().map(|v| (v / z).min(f64::MAX)).collect()
}
fn softmax_stable<C: FromIterator<f64>>(values: &[f64], tau: f64) -> C {
let max_v = values
.into_iter()
.fold(f64::NAN, |acc, &v| f64::max(acc, v));
softmax(values, tau, max_v)
}
pub type Gibbs<F> = Softmax<F>;
#[derive(Clone, Copy, Debug)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct Error;
#[derive(Clone, Debug, Parameterised)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct Softmax<F> {
#[weights]
fa: F,
pub tau: f64,
}
impl<F> Softmax<F> {
pub fn new(fa: F, tau: f64) -> Self {
if tau.abs() < 1e-7 {
panic!("Tau parameter in Softmax must be non-zero.");
}
Softmax { fa, tau }
}
pub fn standard(fa: F) -> Self { Self::new(fa, 1.0) }
}
impl<'s, S, F: Function<(&'s S,), Output = Vec<f64>>> Function<(&'s S,)> for Softmax<F> {
type Output = Vec<f64>;
fn evaluate(&self, (s,): (&'s S,)) -> Vec<f64> {
let values = self.fa.evaluate((s,));
softmax_stable(&values, self.tau)
}
}
impl<'s, S, A, F> Function<(&'s S, A)> for Softmax<F>
where
A: std::borrow::Borrow<usize>,
F: Function<(&'s S, usize), Output = f64>,
{
type Output = f64;
fn evaluate(&self, (s, a): (&'s S, A)) -> f64 { self.fa.evaluate((s, *a.borrow())) }
}
impl<'s, S, F> Enumerable<(&'s S,)> for Softmax<F>
where F: Enumerable<(&'s S,), Output = Vec<f64>>
{
fn evaluate_index(&self, (s,): (&'s S,), index: usize) -> f64 {
self.fa.evaluate_index((s,), index)
}
}
impl<'s, S, A, F> Differentiable<(&'s S, A)> for Softmax<F>
where
A: std::borrow::Borrow<usize>,
F: Function<(&'s S, usize), Output = f64> + Parameterised,
F: Enumerable<(&'s S,), Output = Vec<f64>>,
F: Differentiable<(&'s S, usize)>,
{
type Jacobian = Array2<f64>;
fn grad(&self, _: (&'s S, A)) -> Array2<f64> { unimplemented!() }
fn grad_log(&self, (s, a): (&'s S, A)) -> Array2<f64> {
let a = *a.borrow();
let mut scale_factors = self.evaluate((s,));
scale_factors[a] = scale_factors[a] - 1.0;
let mut jac = Array2::zeros(self.weights_dim());
for (col, sf) in scale_factors.into_iter().enumerate() {
jac.scaled_add(-sf, &self.fa.grad((&s, col)).into_dense());
}
jac
}
}
impl<'s, S, F> Policy<&'s S> for Softmax<F>
where
F: Function<(&'s S, usize), Output = f64> + Parameterised,
F: Enumerable<(&'s S,), Output = Vec<f64>>,
{
type Action = usize;
fn sample<R: Rng + ?Sized>(&self, rng: &mut R, s: &'s S) -> usize {
sample_probs_with_rng(rng, &self.evaluate((s,)))
}
fn mode(&self, s: &'s S) -> usize { argmax_first(self.evaluate((s,))).0 }
}
impl<'s, S, A, F> Handler<StateActionUpdate<&'s S, A>> for Softmax<F>
where
A: std::borrow::Borrow<usize>,
F: Handler<ScaledGradientUpdate<<Self as Differentiable<(&'s S, A)>>::Jacobian>>,
Self: Differentiable<(&'s S, A)>,
{
type Response = F::Response;
type Error = F::Error;
fn handle(&mut self, msg: StateActionUpdate<&'s S, A>) -> Result<Self::Response, Self::Error> {
let jac = self.grad_log((msg.state, msg.action));
self.fa.handle(ScaledGradientUpdate {
alpha: msg.error,
jacobian: jac,
})
}
}
impl<D, F> Handler<GradientUpdate<ArrayBase<D, Ix2>>> for Softmax<F>
where
F: Parameterised,
D: Data<Elem = f64>,
{
type Response = ();
type Error = Error;
fn handle(&mut self, msg: GradientUpdate<ArrayBase<D, Ix2>>) -> Result<Self::Response, Self::Error> {
self.handle(GradientUpdate(&msg.0))
}
}
impl<'m, D, F> Handler<GradientUpdate<&'m ArrayBase<D, Ix2>>> for Softmax<F>
where
F: Parameterised,
D: Data<Elem = f64>,
{
type Response = ();
type Error = Error;
fn handle(&mut self, msg: GradientUpdate<&'m ArrayBase<D, Ix2>>) -> Result<Self::Response, Self::Error> {
msg.0.addto(&mut self.fa.weights_view_mut());
Ok(())
}
}
impl<F, D> Handler<ScaledGradientUpdate<ArrayBase<D, Ix2>>> for Softmax<F>
where
F: Parameterised,
D: Data<Elem = f64>,
{
type Response = ();
type Error = Error;
fn handle(&mut self, msg: ScaledGradientUpdate<ArrayBase<D, Ix2>>) -> Result<Self::Response, Self::Error> {
self.handle(ScaledGradientUpdate {
alpha: msg.alpha,
jacobian: &msg.jacobian,
})
}
}
impl<'m, F, D> Handler<ScaledGradientUpdate<&'m ArrayBase<D, Ix2>>> for Softmax<F>
where
F: Parameterised,
D: Data<Elem = f64>,
{
type Response = ();
type Error = Error;
fn handle(&mut self, msg: ScaledGradientUpdate<&'m ArrayBase<D, Ix2>>) -> Result<Self::Response, Self::Error> {
msg.jacobian
.scaled_addto(msg.alpha, &mut self.fa.weights_view_mut());
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
fa::{
linear::{
basis::{Basis, Polynomial},
optim::SGD,
LFA,
},
mocking::MockQ,
},
};
use rand::thread_rng;
use std::f64::consts::E;
#[test]
#[should_panic]
fn test_0d() {
let p = Softmax::new(MockQ::new_shared(None), 1.0);
p.sample(&mut thread_rng(), &vec![]);
}
#[test]
fn test_1d() {
let p = Softmax::new(MockQ::new_shared(None), 1.0);
let mut rng = thread_rng();
for i in 1..100 {
assert_eq!(p.sample(&mut rng, &vec![i as f64]), 0);
}
}
}