use crate::Scalar;
use std::ops::Range;
/// A transformation that turns one score into another.
///
/// Implementors are required to be `Send + Sync`.
pub trait ScoreMapping: Send + Sync {
    /// Produces the mapped value for `score`.
    fn remap(&self, score: Scalar) -> Scalar;

    /// Consumes `self` and `other` and packs them into a
    /// [`ChainedScoreMapping`], with `self` passed as the first mapping and
    /// `other` as the second.
    ///
    /// Only available on sized implementors.
    fn chain<T: ScoreMapping>(self, other: T) -> ChainedScoreMapping<Self, T>
    where
        Self: Sized,
    {
        ChainedScoreMapping::new(self, other)
    }
}
/// Identity mapping: every score is handed back untouched.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoScoreMapping;

impl ScoreMapping for NoScoreMapping {
    /// Returns `score` exactly as received.
    fn remap(&self, score: Scalar) -> Scalar {
        score
    }
}
/// Score mapping that delegates to a boxed, thread-safe closure.
pub struct ClosureScoreMapping(pub Box<dyn Fn(Scalar) -> Scalar + Send + Sync>);

impl ClosureScoreMapping {
    /// Boxes `f` and wraps it as a mapping.
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(Scalar) -> Scalar + Send + Sync + 'static,
    {
        let boxed: Box<dyn Fn(Scalar) -> Scalar + Send + Sync> = Box::new(f);
        Self(boxed)
    }
}

impl ScoreMapping for ClosureScoreMapping {
    /// Invokes the wrapped closure with `score` and returns its result.
    fn remap(&self, score: Scalar) -> Scalar {
        let Self(callback) = self;
        callback(score)
    }
}

impl std::fmt::Debug for ClosureScoreMapping {
    /// Writes only the type name; the closure itself is not printable.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ClosureScoreMapping")
    }
}
/// Two mappings applied back to back: `first`, then `second`.
///
/// The type itself carries no trait bounds; `A: ScoreMapping` and
/// `B: ScoreMapping` are required only by the [`ScoreMapping`] impl, so the
/// derives and the constructor stay usable in generic code without repeating
/// those bounds.
#[derive(Debug, Clone)]
pub struct ChainedScoreMapping<A, B> {
    /// Mapping applied to the incoming score.
    pub first: A,
    /// Mapping applied to the output of `first`.
    pub second: B,
}

impl<A, B> ChainedScoreMapping<A, B> {
    /// Builds a chain that runs `first` and feeds its result into `second`.
    pub fn new(first: A, second: B) -> Self {
        Self { first, second }
    }
}

impl<A, B> ScoreMapping for ChainedScoreMapping<A, B>
where
    A: ScoreMapping,
    B: ScoreMapping,
{
    /// Returns `second.remap(first.remap(score))`.
    fn remap(&self, score: Scalar) -> Scalar {
        let intermediate = self.first.remap(score);
        self.second.remap(intermediate)
    }
}
/// Linear remapping of scores from the `from` range onto the `to` range.
///
/// No clamping is performed: scores outside `from` are extrapolated along the
/// same line. Only `start` and `end` of each range are read, so a reversed
/// range (`start > end`) simply flips the direction.
#[derive(Debug, Clone)]
pub struct ScoreRemap {
    /// Source range that incoming scores are measured against.
    pub from: Range<Scalar>,
    /// Target range that results are projected onto.
    pub to: Range<Scalar>,
}

impl ScoreRemap {
    /// Creates a remapping from the `from` range onto the `to` range.
    pub fn new(from: Range<Scalar>, to: Range<Scalar>) -> Self {
        Self { from, to }
    }
}

impl ScoreMapping for ScoreRemap {
    /// Returns the position of `score` within `from`, projected onto `to`.
    ///
    /// When `from` has zero width the position is undefined; it is treated as
    /// `0`, so the result is `to.start` rather than NaN or an infinity.
    fn remap(&self, score: Scalar) -> Scalar {
        let source_width = self.from.end - self.from.start;
        // Guard the division: a zero-width source range would otherwise give
        // 0/0 = NaN or x/0 = +/-inf and poison every downstream score.
        let factor = if source_width == 0.0 {
            0.0
        } else {
            (score - self.from.start) / source_width
        };
        factor * (self.to.end - self.to.start) + self.to.start
    }
}
/// Mapping that subtracts the score from one.
#[derive(Clone, Copy, Debug, Default)]
pub struct ReverseScoreMapping;

impl ScoreMapping for ReverseScoreMapping {
    /// Returns `1.0 - score`.
    fn remap(&self, score: Scalar) -> Scalar {
        let one: Scalar = 1.0;
        one - score
    }
}
/// Mapping that takes the reciprocal of the score.
#[derive(Clone, Copy, Debug, Default)]
pub struct InverseScoreMapping;

impl ScoreMapping for InverseScoreMapping {
    /// Returns `1.0 / score`.
    ///
    /// Plain floating-point division: a zero score produces an infinity.
    fn remap(&self, score: Scalar) -> Scalar {
        score.recip()
    }
}
/// Sigmoid-shaped mapping computed as `score / (1 + |score|)`.
#[derive(Clone, Copy, Debug, Default)]
pub struct FastSigmoidScoreMapping;

impl ScoreMapping for FastSigmoidScoreMapping {
    /// Divides `score` by one plus its absolute value.
    fn remap(&self, score: Scalar) -> Scalar {
        let magnitude = score.abs();
        let denominator = 1.0 + magnitude;
        score / denominator
    }
}
/// Sigmoid-shaped mapping computed as `score / sqrt(1 + score^2)`.
///
/// Maps `0` to `0`; the magnitude of the result does not exceed `1`.
#[derive(Debug, Default, Copy, Clone)]
pub struct ApproxSigmoidScoreMapping;

impl ScoreMapping for ApproxSigmoidScoreMapping {
    /// Returns `score / sqrt(1 + score * score)`.
    fn remap(&self, score: Scalar) -> Scalar {
        // The square root must cover the whole `1 + score^2` term. Taking
        // `sqrt(score^2)` on its own is just `|score|`, which collapses this
        // mapping into `score / (1 + |score|)`, i.e. the fast sigmoid.
        score / (1.0 + score * score).sqrt()
    }
}
/// Mapping that floors the score at zero.
#[derive(Clone, Copy, Debug, Default)]
pub struct ReluScoreMapping;

impl ScoreMapping for ReluScoreMapping {
    /// Returns the larger of `score` and `0.0`.
    fn remap(&self, score: Scalar) -> Scalar {
        let floor: Scalar = 0.0;
        score.max(floor)
    }
}
/// Softplus mapping: `ln(1 + e^score)`.
#[derive(Debug, Default, Copy, Clone)]
pub struct SoftplusScoreMapping;

impl ScoreMapping for SoftplusScoreMapping {
    /// Returns `ln(1 + e^score)`.
    ///
    /// Evaluated as `max(score, 0) + ln(1 + e^(-|score|))`, which is the same
    /// function algebraically. The exponent is never positive, so `exp`
    /// cannot overflow and large scores map to roughly themselves instead of
    /// to infinity.
    fn remap(&self, score: Scalar) -> Scalar {
        // `ln_1p` keeps precision when the exponential term is tiny.
        let correction = (-score.abs()).exp().ln_1p();
        score.max(0.0) + correction
    }
}