use std::{any::TypeId, fmt::Debug, ops::Mul};
/// An asymptotic complexity model: a function of the input size `n`.
///
/// Implementors say how cost grows with `n` (e.g. `n`, `log(n)`, `n^2`)
/// and how the model is rendered as text.
pub trait Model {
    /// Evaluates the model at input size `n`.
    fn complexity(&self, n: f64) -> f64;

    /// Textual rendering of the model, e.g. `"n*log(n)"`.
    fn to_string(&self) -> String;

    /// Flags models whose value does not vary with `n`.
    ///
    /// Defaults to `false`; only implementors that override it report `true`.
    #[doc(hidden)]
    fn constant(&self) -> bool {
        false
    }
}
/// A type-erased, heap-allocated `Model` that remembers its concrete type.
pub(crate) struct BoxedModel {
    /// The erased model; every `Model` call on the box is forwarded here.
    storage: Box<dyn Model + 'static>,
    /// `TypeId` of the concrete model type the box was built from.
    kind: TypeId,
}

impl BoxedModel {
    /// Boxes `model` and records its concrete type.
    ///
    /// In debug builds this asserts that `model` is not itself a `BoxedModel`,
    /// so a model is never wrapped twice.
    pub(crate) fn new<M: Model + 'static>(model: M) -> Self {
        let type_id = TypeId::of::<M>();
        debug_assert_ne!(type_id, TypeId::of::<BoxedModel>());
        let storage: Box<dyn Model + 'static> = Box::new(model);
        BoxedModel {
            storage,
            kind: type_id,
        }
    }
}
impl Model for BoxedModel {
fn complexity(&self, n: f64) -> f64 {
self.storage.complexity(n)
}
fn to_string(&self) -> String {
self.storage.to_string()
}
fn constant(&self) -> bool {
self.storage.constant()
}
}
impl PartialEq for BoxedModel {
fn eq(&self, other: &Self) -> bool {
(self.kind == other.kind) && (self.to_string() == other.to_string())
}
}
impl Debug for BoxedModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("BoxedModel(")?;
f.write_str(&self.to_string())?;
f.write_str(")")?;
Ok(())
}
}
/// The constant model `O(1)`: its value is `1` for every input size.
#[derive(Clone, PartialEq)]
pub struct Constant;

impl Model for Constant {
    fn to_string(&self) -> String {
        String::from("1")
    }

    fn constant(&self) -> bool {
        true
    }

    fn complexity(&self, _: f64) -> f64 {
        1f64
    }
}

/// Multiplying by [`Constant`] is the identity: `Constant * m` yields `m` unchanged.
impl<B: Model> Mul<B> for Constant {
    type Output = B;

    fn mul(self, other: B) -> B {
        other
    }
}
/// Product of two models: `complexity(n) = m1(n) * m2(n)`.
///
/// Values are built with the `*` operator on model types; the fields are private.
#[derive(Clone, PartialEq)]
pub struct MultipliedModels<M1: Model, M2: Model> {
    /// Left factor.
    m1: M1,
    /// Right factor.
    m2: M2,
}

impl<M1: Model, M2: Model> Model for MultipliedModels<M1, M2> {
    fn complexity(&self, n: f64) -> f64 {
        let c1 = self.m1.complexity(n);
        let c2 = self.m2.complexity(n);
        c1 * c2
    }

    /// A product of two constant factors is itself constant. Without this
    /// override the trait default (`false`) would hide that from callers.
    fn constant(&self) -> bool {
        self.m1.constant() && self.m2.constant()
    }

    fn to_string(&self) -> String {
        format!("{}*{}", self.m1.to_string(), self.m2.to_string())
    }
}
// Implements `Mul<RHS>` for a model type so that `a * b` builds
// `MultipliedModels { m1: a, m2: b }`.
// First arm: plain types such as `Pow`.
// Second arm: generic types whose parameters are all `Model`s, such as `Log<M>`.
macro_rules! impl_mul {
    ($ty:ident) => {
        impl<RHS: Model> Mul<RHS> for $ty {
            type Output = MultipliedModels<$ty, RHS>;

            fn mul(self, other: RHS) -> MultipliedModels<$ty, RHS> {
                MultipliedModels {
                    m1: self,
                    m2: other,
                }
            }
        }
    };
    ($ty:ident<$($param:ident),+>) => {
        impl<RHS: Model, $($param: Model),+> Mul<RHS> for $ty<$($param),+> {
            type Output = MultipliedModels<$ty<$($param),+>, RHS>;

            fn mul(self, other: RHS) -> MultipliedModels<$ty<$($param),+>, RHS> {
                MultipliedModels {
                    m1: self,
                    m2: other,
                }
            }
        }
    };
}
/// Base-2 logarithm of an inner model, rendered as `log(<inner>)`.
#[derive(Clone, PartialEq)]
pub struct Log<M: Model>(pub M);

impl<M: Model> Model for Log<M> {
    fn to_string(&self) -> String {
        let inner = self.0.to_string();
        format!("log({})", inner)
    }

    fn complexity(&self, n: f64) -> f64 {
        let inner = self.0.complexity(n);
        f64::log2(inner)
    }
}
/// Square root of an inner model, rendered as `sqrt(<inner>)`.
#[derive(Clone, PartialEq)]
pub struct Sqrt<M: Model>(pub M);

impl<M: Model> Model for Sqrt<M> {
    fn to_string(&self) -> String {
        let inner = self.0.to_string();
        format!("sqrt({})", inner)
    }

    fn complexity(&self, n: f64) -> f64 {
        let inner = self.0.complexity(n);
        f64::sqrt(inner)
    }
}
/// Power model `n^k` for a real exponent `k`.
#[derive(Clone, PartialEq)]
pub struct Pow(pub f64);

impl Model for Pow {
    fn complexity(&self, n: f64) -> f64 {
        n.powf(self.0)
    }

    /// `n^0` is identically `1` (IEEE-754 `pow(x, 0)` is 1 for every `x`,
    /// including `0`), so it is reported as constant, like `Constant`.
    fn constant(&self) -> bool {
        self.0 == 0.0
    }

    /// Renders in LaTeX-like notation:
    /// - `n` for exponent 1;
    /// - `n^k` when the exponent prints as a single character (`n^2`, `n^9`);
    /// - `n^{k}` otherwise (`n^{10}`, `n^{4.5}`, `n^{-1}`).
    ///
    /// Multi-character exponents must be braced, or `n^-1` would only raise to `-`.
    fn to_string(&self) -> String {
        if self.0 == 1.0 {
            return "n".to_string();
        }
        let exponent = self.0.to_string();
        if exponent.chars().count() == 1 {
            format!("n^{}", exponent)
        } else {
            format!("n^{{{}}}", exponent)
        }
    }
}
// Give every composable model a `*` operator that builds `MultipliedModels`.
// `Constant` is deliberately absent: it has its own identity `Mul` impl.
impl_mul!(Pow);
impl_mul!(Log<M>);
impl_mul!(Sqrt<M>);
impl_mul!(MultipliedModels<C, D>);
/// Linear growth, `n`.
pub const N: Pow = Pow(1.);
/// Quadratic growth, `n^2`.
pub const N2: Pow = Pow(2.);
/// Cubic growth, `n^3`.
pub const N3: Pow = Pow(3.);
/// An ordered collection of candidate complexity models.
///
/// Start from the built-in set (`KnownModels::default()`) and extend it with
/// the `with` builder method.
#[derive(Debug)]
pub struct KnownModels {
    models: Vec<BoxedModel>,
}

impl KnownModels {
    /// Consumes the collection, yielding the boxed models in insertion order.
    pub(crate) fn into_iter(self) -> impl Iterator<Item = BoxedModel> {
        self.models.into_iter()
    }

    /// Appends `model` and returns the extended collection (builder style).
    ///
    /// `self` is taken by value, so discarding the return value drops the
    /// whole collection; `#[must_use]` makes the compiler flag that mistake.
    #[must_use]
    pub fn with<M: Model + 'static>(mut self, model: M) -> Self {
        self.models.push(BoxedModel::new(model));
        self
    }

    /// Creates an empty collection.
    #[must_use]
    pub(crate) fn new() -> Self {
        KnownModels { models: Vec::new() }
    }
}
impl Default for KnownModels {
fn default() -> Self {
KnownModels::new()
.with(Constant)
.with(N)
.with(Log(N))
.with(Sqrt(N))
.with(N * Log(N))
.with(Log(Log(N)))
.with(N2)
.with(N3)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `model.complexity(n) == expected` for every `(n, expected)` pair.
    fn check<M: Model>(model: &M, cases: &[(f64, f64)]) {
        for &(n, expected) in cases {
            assert_eq!(model.complexity(n), expected, "complexity at n = {}", n);
        }
    }

    #[test]
    fn equality() {
        let n = BoxedModel::new(N);
        let log_n = BoxedModel::new(Log(N));
        assert_eq!(&n, &n);
        assert_eq!(&log_n, &log_n);
        assert_ne!(&n, &log_n);
    }

    #[test]
    fn constant() {
        check(&Constant, &[(10.0, 1.0), (20.0, 1.0)]);
        check(&(Constant * N), &[(10.0, 10.0), (20.0, 20.0)]);
        assert_eq!(Model::to_string(&Constant), "1");
    }

    #[test]
    fn linear() {
        check(&N, &[(10.0, 10.0), (20.0, 20.0)]);
        check(&(N * N), &[(10.0, 100.0), (20.0, 400.0)]);
        assert_eq!(Model::to_string(&N), "n");
    }

    #[test]
    fn n2() {
        check(&N2, &[(10.0, 100.0), (20.0, 400.0)]);
        check(&(N2 * N), &[(10.0, 1000.0), (20.0, 8000.0)]);
        assert_eq!(Model::to_string(&N2), "n^2");
    }

    #[test]
    fn n3() {
        check(&N3, &[(10.0, 1000.0), (20.0, 8000.0)]);
        check(&(N3 * N), &[(2.0, 16.0)]);
        assert_eq!(Model::to_string(&N3), "n^3");
    }

    #[test]
    fn log() {
        check(&Log(N), &[(2.0, 1.0), (4.0, 2.0)]);
        check(&(Log(N) * N), &[(4.0, 8.0), (16.0, 64.0)]);
        assert_eq!(Model::to_string(&Log(N)), "log(n)");
    }

    #[test]
    fn sqrt() {
        check(&Sqrt(N), &[(100.0, 10.0), (4.0, 2.0)]);
        check(&(Sqrt(N) * N), &[(9.0, 27.0), (100.0, 1000.0)]);
        assert_eq!(Model::to_string(&Sqrt(N)), "sqrt(n)");
    }

    #[test]
    fn power() {
        check(&Pow(2.0), &[(11.0, 121.0)]);
        check(&(Pow(3.0) * N), &[(2.0, 16.0)]);
        let renderings = [
            (4.5, "n^{4.5}"),
            (10.0, "n^{10}"),
            (9.0, "n^9"),
            (4.0, "n^4"),
            (3.0, "n^3"),
            (2.0, "n^2"),
            (1.0, "n"),
        ];
        for &(exponent, expected) in renderings.iter() {
            assert_eq!(Model::to_string(&Pow(exponent)), expected);
        }
    }

    #[test]
    fn multiply_multiply() {
        check(&(N * N * N), &[(10.0, 1000.0), (3.0, 27.0)]);
        assert_eq!(Model::to_string(&(N * Log(N))), "n*log(n)");
    }

    #[test]
    fn known_models() {
        let gather = |models: KnownModels| models.into_iter().collect::<Vec<BoxedModel>>();

        assert_eq!(gather(KnownModels::new()), vec![]);
        assert_eq!(
            gather(KnownModels::new().with(N2).with(N3)),
            vec![BoxedModel::new(N2), BoxedModel::new(N3)]
        );
        assert_eq!(
            gather(KnownModels::default()),
            vec![
                BoxedModel::new(Constant),
                BoxedModel::new(N),
                BoxedModel::new(Log(N)),
                BoxedModel::new(Sqrt(N)),
                BoxedModel::new(N * Log(N)),
                BoxedModel::new(Log(Log(N))),
                BoxedModel::new(N2),
                BoxedModel::new(N3),
            ]
        );
    }
}