use conv::ConvUtil;
use eyre::Result;
use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedSub};
/// A single element of a [`Val`]: a number, a character, or a boxed nested value.
#[derive(Debug, Clone)]
pub enum Scalar {
/// A 64-bit floating-point number.
Float(f64),
/// A 64-bit signed integer.
Integer(i64),
/// A single Unicode scalar value.
Char(char),
/// A whole [`Val`] stored as one element; boxed because `Val` itself contains `Scalar`s.
Nested(Box<Val>),
}
impl Scalar {
fn promote_pair(a: &Scalar, b: &Scalar) -> (Scalar, Scalar) {
match (a, b) {
(Scalar::Integer(i), Scalar::Integer(j)) => (Scalar::Integer(*i), Scalar::Integer(*j)),
(Scalar::Float(f), Scalar::Float(g)) => (Scalar::Float(*f), Scalar::Float(*g)),
(Scalar::Integer(i), Scalar::Float(f)) => (Scalar::Float(*i as f64), Scalar::Float(*f)),
(Scalar::Float(f), Scalar::Integer(i)) => (Scalar::Float(*f), Scalar::Float(*i as f64)),
_ => (a.clone(), b.clone()),
}
}
}
/// Fallible conversion of a [`Scalar`] into a `usize`.
///
/// # Errors
///
/// Returns a static message when the scalar is a `Char` or `Nested`, when an
/// integer does not fit in `usize`, or when a float is negative, fractional,
/// NaN, or too large for `usize`.
impl TryFrom<Scalar> for usize {
    type Error = &'static str;

    fn try_from(value: Scalar) -> Result<Self, Self::Error> {
        match value {
            Scalar::Char(_) => Err("Cannot convert char to usize"),
            Scalar::Nested(_) => Err("Cannot convert nested value to usize"),
            Scalar::Integer(int) => {
                usize::try_from(int).map_err(|_| "Failed to convert i64 into usize")
            }
            // Only non-negative whole floats are candidates; NaN fails the `fract` test.
            Scalar::Float(float) if float.fract() == 0.0 && float >= 0.0 => float
                .approx_as::<usize>()
                .map_err(|_| "Failed to convert f64 into usize"),
            Scalar::Float(_) => Err("Float is not a whole number or is negative"),
        }
    }
}
/// Infallible conversion of a [`Scalar`] into an `f64`.
///
/// Integers are cast with `as f64`, characters become their Unicode code point,
/// and `Nested` values collapse to `0.0`.
impl From<Scalar> for f64 {
    fn from(value: Scalar) -> Self {
        match value {
            Scalar::Float(float) => float,
            Scalar::Integer(int) => int as f64,
            // char -> u32 -> f64 is lossless: every code point fits in an f64 exactly.
            Scalar::Char(ch) => f64::from(u32::from(ch)),
            Scalar::Nested(_) => 0.0,
        }
    }
}
impl PartialEq for Scalar {
fn eq(&self, other: &Scalar) -> bool {
match (self, other) {
(Scalar::Integer(i), Scalar::Integer(j)) => i == j,
(Scalar::Float(f), Scalar::Float(g)) => f == g,
(Scalar::Integer(i), Scalar::Float(f)) => *i as f64 == *f,
(Scalar::Float(f), Scalar::Integer(i)) => *f == *i as f64,
(Scalar::Char(a), Scalar::Char(b)) => a == b,
_ => false,
}
}
}
// Marker impl; `Eq` is a supertrait of `Ord`, which `Scalar` also implements.
// NOTE(review): the `PartialEq` impl is not reflexive for `Float(NaN)` (IEEE
// `NaN != NaN`) nor for `Nested` values (always unequal), so this marker
// promises more than `eq` delivers — confirm no caller relies on `x == x`.
impl Eq for Scalar {}
impl PartialOrd for Scalar {
    /// Always returns `Some`, delegating to the `Ord` implementation for `Scalar`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// Ordering over scalars: numbers < `Nested` < `Char`.
///
/// * Numbers compare by value across `Integer`/`Float`; a mixed pair is compared
///   after converting the integer with `as f64`. `NaN` sorts below every other
///   number and equal to itself.
/// * `Nested` values sort above every number and below every character. Two
///   `Nested` values compare as `Equal` (their contents are not inspected).
/// * Characters compare by code point and sort above everything else.
///
/// Previously a `Nested` value compared `Equal` to every number, which broke
/// transitivity (`Integer(1) == Nested == Integer(2)` while
/// `Integer(1) < Integer(2)`); sorting with a non-transitive `Ord` can produce
/// unspecified orderings or panic in the standard sort routines.
///
/// NOTE(review): mixed `Integer`/`Float` comparison still goes through `as f64`,
/// so integers beyond 2^53 in magnitude may compare equal to a neighbouring float.
impl Ord for Scalar {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        match (self, other) {
            (Scalar::Integer(i), Scalar::Integer(j)) => i.cmp(j),
            (Scalar::Float(f), Scalar::Float(g)) => match (f.is_nan(), g.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Less,
                (false, true) => Ordering::Greater,
                (false, false) => f
                    .partial_cmp(g)
                    .expect("partial_cmp is total when neither operand is NaN"),
            },
            // `partial_cmp` is `None` only when the float is NaN (an integer never
            // converts to NaN), and NaN sorts below every number.
            (Scalar::Integer(i), Scalar::Float(f)) => {
                (*i as f64).partial_cmp(f).unwrap_or(Ordering::Greater)
            }
            (Scalar::Float(f), Scalar::Integer(i)) => {
                f.partial_cmp(&(*i as f64)).unwrap_or(Ordering::Less)
            }
            (Scalar::Char(a), Scalar::Char(b)) => a.cmp(b),
            (Scalar::Char(_), _) => Ordering::Greater,
            (_, Scalar::Char(_)) => Ordering::Less,
            // Only numbers and `Nested` remain below this point.
            (Scalar::Nested(_), Scalar::Nested(_)) => Ordering::Equal,
            (Scalar::Nested(_), _) => Ordering::Greater,
            (_, Scalar::Nested(_)) => Ordering::Less,
        }
    }
}
/// Unchecked addition of two numeric scalars.
///
/// A mixed `Integer`/`Float` pair is added as floats. Integer addition uses the
/// plain `+` operator on `i64`, so overflow is not guarded against here.
///
/// # Panics
///
/// Panics when either operand is a `Char` or `Nested` value.
impl std::ops::Add for Scalar {
    type Output = Self;

    fn add(self, other: Self) -> Self::Output {
        let (lhs, rhs) = Self::promote_pair(&self, &other);
        match (lhs, rhs) {
            (Scalar::Float(x), Scalar::Float(y)) => Scalar::Float(x + y),
            (Scalar::Integer(x), Scalar::Integer(y)) => Scalar::Integer(x + y),
            _ => panic!("BUG: Unexpected type mismatch after promotion"),
        }
    }
}
impl CheckedAdd for Scalar {
    /// Adds two scalars, returning `None` on integer overflow or when either
    /// operand is not numeric. Float addition always yields `Some`.
    fn checked_add(&self, other: &Self) -> Option<Self> {
        let (lhs, rhs) = Self::promote_pair(self, other);
        match (lhs, rhs) {
            (Scalar::Float(x), Scalar::Float(y)) => Some(Scalar::Float(x + y)),
            (Scalar::Integer(x), Scalar::Integer(y)) => Some(Scalar::Integer(x.checked_add(y)?)),
            _ => None,
        }
    }
}
/// Unchecked subtraction of two numeric scalars.
///
/// A mixed `Integer`/`Float` pair is subtracted as floats. Integer subtraction
/// uses the plain `-` operator on `i64`, so overflow is not guarded against here.
///
/// # Panics
///
/// Panics when either operand is a `Char` or `Nested` value.
impl std::ops::Sub for Scalar {
    type Output = Self;

    fn sub(self, other: Self) -> Self::Output {
        let (lhs, rhs) = Self::promote_pair(&self, &other);
        match (lhs, rhs) {
            (Scalar::Float(x), Scalar::Float(y)) => Scalar::Float(x - y),
            (Scalar::Integer(x), Scalar::Integer(y)) => Scalar::Integer(x - y),
            _ => panic!("BUG: Unexpected type mismatch after promotion"),
        }
    }
}
impl CheckedSub for Scalar {
    /// Subtracts `other` from `self`, returning `None` on integer overflow or when
    /// either operand is not numeric. Float subtraction always yields `Some`.
    fn checked_sub(&self, other: &Self) -> Option<Self> {
        let (lhs, rhs) = Self::promote_pair(self, other);
        match (lhs, rhs) {
            (Scalar::Float(x), Scalar::Float(y)) => Some(Scalar::Float(x - y)),
            (Scalar::Integer(x), Scalar::Integer(y)) => Some(Scalar::Integer(x.checked_sub(y)?)),
            _ => None,
        }
    }
}
/// Unchecked multiplication of two numeric scalars.
///
/// A mixed `Integer`/`Float` pair is multiplied as floats. Integer multiplication
/// uses the plain `*` operator on `i64`, so overflow is not guarded against here.
///
/// # Panics
///
/// Panics when either operand is a `Char` or `Nested` value.
impl std::ops::Mul for Scalar {
    type Output = Self;

    fn mul(self, other: Self) -> Self::Output {
        let (lhs, rhs) = Self::promote_pair(&self, &other);
        match (lhs, rhs) {
            (Scalar::Float(x), Scalar::Float(y)) => Scalar::Float(x * y),
            (Scalar::Integer(x), Scalar::Integer(y)) => Scalar::Integer(x * y),
            _ => panic!("BUG: Unexpected type mismatch after promotion"),
        }
    }
}
impl CheckedMul for Scalar {
    /// Multiplies two scalars, returning `None` on integer overflow or when either
    /// operand is not numeric. Float multiplication always yields `Some`.
    fn checked_mul(&self, other: &Self) -> Option<Self> {
        let (lhs, rhs) = Self::promote_pair(self, other);
        match (lhs, rhs) {
            (Scalar::Float(x), Scalar::Float(y)) => Some(Scalar::Float(x * y)),
            (Scalar::Integer(x), Scalar::Integer(y)) => Some(Scalar::Integer(x.checked_mul(y)?)),
            _ => None,
        }
    }
}
/// Unchecked division of two numeric scalars.
///
/// The result is always a `Float`: integer operands are converted with `as f64`
/// before dividing, so a zero divisor produces an IEEE infinity or NaN rather
/// than an integer-division panic.
///
/// # Panics
///
/// Panics when either operand is a `Char` or `Nested` value.
impl std::ops::Div for Scalar {
    type Output = Self;

    fn div(self, other: Self) -> Self::Output {
        let (numerator, denominator) = match Self::promote_pair(&self, &other) {
            (Scalar::Integer(n), Scalar::Integer(d)) => (n as f64, d as f64),
            (Scalar::Float(n), Scalar::Float(d)) => (n, d),
            _ => panic!("BUG: Unexpected type mismatch after promotion"),
        };
        Scalar::Float(numerator / denominator)
    }
}
impl CheckedDiv for Scalar {
    /// Divides `self` by `other` as floats, returning `None` only when either
    /// operand is not numeric.
    ///
    /// NOTE(review): a zero divisor yields `Some(Float(inf))` or `Some(Float(NaN))`
    /// rather than `None` — confirm callers expect IEEE semantics here.
    fn checked_div(&self, other: &Self) -> Option<Self> {
        let (numerator, denominator) = match Self::promote_pair(self, other) {
            (Scalar::Integer(n), Scalar::Integer(d)) => (n as f64, d as f64),
            (Scalar::Float(n), Scalar::Float(d)) => (n, d),
            _ => return None,
        };
        Some(Scalar::Float(numerator / denominator))
    }
}
/// Exponentiation that signals failure by returning `None` instead of panicking.
pub trait CheckedPow: Sized {
/// Raises `self` to the non-negative integer exponent `power`.
fn checked_pow(&self, power: usize) -> Option<Self>;
/// Raises `self` to the floating-point exponent `other`.
fn checked_powf(&self, other: f64) -> Option<Self>;
}
impl CheckedPow for Scalar {
    /// Raises `self` to the integer exponent `other`.
    ///
    /// Returns `None` for `Char` and `Nested` values and when an integer result
    /// overflows `i64`. Float bases always yield `Some`.
    ///
    /// The exponent is converted to `u32` with a checked conversion. A plain
    /// `as u32` cast would silently truncate large exponents (e.g. `2^32` would
    /// become `0` and every base would yield `1`).
    fn checked_pow(&self, other: usize) -> Option<Self> {
        match self {
            Scalar::Integer(base) => {
                let result = match u32::try_from(other) {
                    Ok(exponent) => base.checked_pow(exponent)?,
                    // Exponent exceeds u32::MAX: only bases 0, 1 and -1 stay in
                    // range; every other base overflows i64.
                    Err(_) => match *base {
                        0 | 1 => *base,
                        -1 => {
                            if other % 2 == 0 {
                                1
                            } else {
                                -1
                            }
                        }
                        _ => return None,
                    },
                };
                Some(Scalar::Integer(result))
            }
            Scalar::Float(base) => Some(Scalar::Float(num_traits::pow::pow(*base, other))),
            Scalar::Char(_) | Scalar::Nested(_) => None,
        }
    }

    /// Raises `self` to the floating-point exponent `other`.
    ///
    /// Integers are converted with `as f64` first, so the result is always a
    /// `Float`. Returns `None` for `Char` and `Nested` values.
    fn checked_powf(&self, other: f64) -> Option<Self> {
        match self {
            Scalar::Integer(base) => Some(Scalar::Float((*base as f64).powf(other))),
            Scalar::Float(base) => Some(Scalar::Float(base.powf(other))),
            Scalar::Char(_) | Scalar::Nested(_) => None,
        }
    }
}
/// Logarithm with an explicit base that signals failure by returning `None`.
pub trait Log: Sized {
/// Computes the logarithm of `self` in the given `base`.
fn log(&self, base: &Self) -> Option<Self>;
}
impl Log for Scalar {
    /// Computes `log_base(self)` as a `Float`.
    ///
    /// Both operands are converted through `f64::from`, so integers are cast and
    /// characters contribute their code point. Returns `None` when either operand
    /// is `Nested`.
    ///
    /// NOTE(review): `Char` operands are accepted here although the other checked
    /// operations on `Scalar` reject them — confirm this is intended.
    fn log(&self, base: &Self) -> Option<Self> {
        if matches!(self, Scalar::Nested(_)) || matches!(base, Scalar::Nested(_)) {
            return None;
        }
        let value = f64::from(self.clone());
        let base_value = f64::from(base.clone());
        Some(Scalar::Float(value.log(base_value)))
    }
}
impl CheckedNeg for Scalar {
    /// Negates a numeric scalar.
    ///
    /// Returns `None` for `Char` and `Nested` values and when negating an integer
    /// overflows (`i64::MIN`). Floats always yield `Some`.
    fn checked_neg(&self) -> Option<Self> {
        match self {
            Scalar::Char(_) | Scalar::Nested(_) => None,
            Scalar::Float(value) => Some(Scalar::Float(-value)),
            Scalar::Integer(value) => value.checked_neg().map(Scalar::Integer),
        }
    }
}
/// An array value: a shape plus a flat list of [`Scalar`] elements.
#[derive(Debug, Clone)]
pub struct Val {
/// Extent of each axis; empty for a scalar, `[len]` for a vector.
pub shape: Vec<usize>,
/// The elements, stored flat. NOTE(review): the element order for shapes with
/// more than one axis is not established in this file — confirm against callers.
pub data: Vec<Scalar>,
}
impl Val {
    /// Wraps a single element as a scalar value (empty shape, one datum).
    pub fn scalar(s: Scalar) -> Self {
        Val {
            shape: vec![],
            data: vec![s],
        }
    }

    /// Builds a one-dimensional value whose shape is `[data.len()]`.
    pub fn vector(data: Vec<Scalar>) -> Self {
        let len = data.len();
        Val {
            shape: vec![len],
            data,
        }
    }

    /// Builds a value from an explicit shape and data.
    ///
    /// No validation is performed: the caller is responsible for `data` matching
    /// `shape`.
    pub fn new(shape: Vec<usize>, data: Vec<Scalar>) -> Self {
        Val { shape, data }
    }

    /// Returns `true` when the shape is empty.
    pub fn is_scalar(&self) -> bool {
        self.shape.is_empty()
    }

    /// Returns the nesting depth of this value.
    ///
    /// * A scalar holding a non-nested element has depth 0; a scalar holding a
    ///   `Nested` value has depth `1 + depth(inner)`.
    /// * A non-scalar has depth `1 + max(depth of its nested elements)`, or 1 when
    ///   it has no nested elements.
    ///
    /// A scalar-shaped value with empty `data` (constructible through
    /// [`Val::new`]) reports depth 0 instead of panicking on an out-of-bounds index.
    pub fn depth(&self) -> usize {
        if self.is_scalar() {
            return match self.data.first() {
                Some(Scalar::Nested(inner)) => 1 + inner.depth(),
                _ => 0,
            };
        }

        // Non-nested elements contribute depth 0, so only nested ones matter for
        // the maximum; with no nested elements the maximum defaults to 0.
        let deepest_element = self
            .data
            .iter()
            .filter_map(|element| match element {
                Scalar::Nested(inner) => Some(inner.depth()),
                _ => None,
            })
            .max()
            .unwrap_or(0);
        1 + deepest_element
    }

    /// Structural equality: same shape, same element count, and pairwise-matching
    /// elements.
    ///
    /// Two `Nested` elements are compared recursively; every other pair is
    /// compared with `==` on `Scalar`.
    pub fn matches_val(&self, other: &Val) -> bool {
        // The length check is required because `zip` stops at the shorter side.
        self.shape == other.shape
            && self.data.len() == other.data.len()
            && self
                .data
                .iter()
                .zip(other.data.iter())
                .all(|(a, b)| match (a, b) {
                    (Scalar::Nested(va), Scalar::Nested(vb)) => va.matches_val(vb),
                    _ => a == b,
                })
    }

    /// Builds a value from raw floats, storing whole numbers as integers.
    ///
    /// A float becomes `Integer` when it has no fractional part and its magnitude
    /// is below `i64::MAX as f64`; otherwise (including NaN and infinities) it
    /// stays `Float`. Exactly one input produces a scalar; any other count
    /// (including zero) produces a vector.
    pub fn from_f64s(values: &[f64]) -> Self {
        let data: Vec<Scalar> = values
            .iter()
            .map(|&v| {
                // The magnitude bound keeps the `as i64` cast from saturating.
                if v.fract() == 0.0 && v.abs() < i64::MAX as f64 {
                    Scalar::Integer(v as i64)
                } else {
                    Scalar::Float(v)
                }
            })
            .collect();
        if data.len() == 1 {
            let only = data
                .into_iter()
                .next()
                .expect("length was checked to be exactly one");
            Val::scalar(only)
        } else {
            Val::vector(data)
        }
    }
}