use std::{
cmp::Ordering,
fmt::Display,
iter::Sum,
ops::{Add, Div, Mul, Rem, Sub},
};
use anyhow::Result;
use crate::prelude::{Deserialize, Serialize};
use super::{constants::NULL, Value};
const MAX_U32_AS_I128: i128 = u32::MAX as i128;
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum Primitive {
Int(i128),
Bool(bool),
Null,
Double(f64),
String(String),
Array(Vec<Primitive>),
Error(String),
Function { parameters: Vec<String>, exprs: Vec<Value> },
Unit,
NoReturn,
EarlyReturn(Box<Primitive>),
}
pub trait ToNumber {
fn to_int(&self) -> Self;
fn to_double(&self) -> Self;
}
pub trait Pow {
fn pow(&self, n: Self) -> Self;
}
pub trait And {
fn and(&self, n: Self) -> Self;
}
pub trait Or {
fn or(&self, n: Self) -> Self;
}
pub trait Sqrt {
fn sqrt(&self) -> Self;
}
pub trait Abs {
fn abs(&self) -> Self;
}
pub trait Logarithm {
fn log(&self) -> Self;
fn ln(&self) -> Self;
}
pub trait Sin {
fn sin(&self) -> Self;
}
pub trait Array {
fn index_at(&self, rhs: &Self) -> Self;
fn len(&self) -> Primitive;
fn swap_mem(&mut self, rhs: &mut Self, index: &Primitive) -> Primitive;
}
pub trait Cos {
fn cos(&self) -> Self;
}
pub trait Tan {
fn tan(&self) -> Self;
}
#[allow(dead_code)]
impl Primitive {
pub fn is_greater_than(&self, other: &Primitive) -> Primitive {
match self.partial_cmp(other) {
Some(Ordering::Greater) => Primitive::Bool(true),
Some(Ordering::Less) => Primitive::Bool(false),
Some(Ordering::Equal) => Primitive::Bool(false),
None => Primitive::Error(
format!("call to is_greater_than() for two different types {self} => {other}")
),
}
}
pub fn is_greater_or_equal(&self, other: &Primitive) -> Primitive {
match self.partial_cmp(other) {
Some(Ordering::Greater) | Some(Ordering::Equal) => {
Primitive::Bool(true)
}
Some(Ordering::Less) => Primitive::Bool(false),
None => Primitive::Error(
format!("call to is_greater_or_equal() for two different types {self} => {other}")
),
}
}
pub fn is_less_than(&self, other: &Primitive) -> Primitive {
match self.partial_cmp(other) {
Some(Ordering::Less) => Primitive::Bool(true),
Some(Ordering::Greater) => Primitive::Bool(false),
Some(Ordering::Equal) => Primitive::Bool(false),
None => Primitive::Error(
format!("call to is_less_than() for two different types {self} => {other}")
),
}
}
pub fn is_less_or_equal(&self, other: &Primitive) -> Primitive {
match self.partial_cmp(other) {
Some(Ordering::Less) | Some(Ordering::Equal) => {
Primitive::Bool(true)
}
Some(Ordering::Greater) => Primitive::Bool(false),
None => Primitive::Error(
format!("call to is_less_or_equal() for two different types {self} => {other}")
),
}
}
pub fn is_equal(&self, other: &Primitive) -> Primitive {
match self.partial_cmp(other) {
Some(Ordering::Equal) => Primitive::Bool(true),
Some(Ordering::Less) | Some(Ordering::Greater) => {
Primitive::Bool(false)
}
None => match (self, other) {
(Primitive::Null, _) | (_, Primitive::Null) => {
Primitive::Bool(false)
}
_ => Primitive::Error(
format!("call to is_equal() for two different types {self} => {other}"),
),
},
}
}
pub fn as_ref_ok(&self) -> Result<&Primitive> {
match self {
Primitive::Error(msg) => Err(anyhow::Error::msg(msg.to_string())),
_ => Ok(self),
}
}
}
impl Display for Primitive {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Primitive::Int(i) => write!(f, "{i}"),
Primitive::Double(d) => write!(f, "{d}"),
Primitive::Bool(b) => write!(f, "{b}"),
Primitive::Error(e) => write!(f, "Err: {e}"),
Primitive::String(s) => write!(f, "{s}"),
Primitive::Unit => Ok(()),
Primitive::Array(arr) => {
let joined_arr = arr
.iter()
.map(|p| match p {
Primitive::String(s) => format!(r#""{s}""#),
_ => p.to_string(),
})
.collect::<Vec<_>>();
write!(f, "[{}]", joined_arr[..].join(", "))
}
Primitive::Function { parameters, exprs: _ } => {
write!(f, "({}) => {{...}}", parameters.join(", "))
}
Primitive::NoReturn => write!(f, "!"),
Primitive::Null => write!(f, "{NULL}"),
Primitive::EarlyReturn(p) => write!(f, "{p}"),
}
}
}
impl Sin for Primitive {
fn sin(&self) -> Self {
match self {
Primitive::Int(i) => Primitive::Double((*i as f64).sin()),
Primitive::Double(d) => Primitive::Double(d.sin()),
Primitive::Error(e) => panic!("call to sin() on an error. {e}"),
_ => Primitive::Error(format!("illegal call to sin() => {self}")),
}
}
}
impl Cos for Primitive {
fn cos(&self) -> Self {
match self {
Primitive::Int(i) => Primitive::Double((*i as f64).cos()),
Primitive::Double(d) => Primitive::Double(d.cos()),
Primitive::Error(e) => panic!("call to cos() on an error. {e}"),
_ => Primitive::Error(format!("illegal call to cos() => {self}")),
}
}
}
impl Tan for Primitive {
fn tan(&self) -> Self {
match self {
Primitive::Int(i) => Primitive::Double((*i as f64).tan()),
Primitive::Double(d) => Primitive::Double(d.tan()),
Primitive::Error(e) => panic!("call to tan() on an error. {e}"),
_ => Primitive::Error(format!("illegal call to tan() => {self}")),
}
}
}
impl Logarithm for Primitive {
fn log(&self) -> Self {
match self {
Primitive::Int(i) => Primitive::Double((*i as f64).log10()),
Primitive::Double(d) => Primitive::Double(d.log10()),
Primitive::Error(e) => panic!("call to log() on an error. {e}"),
_ => Primitive::Error(format!("illegal call to log() => {self}")),
}
}
fn ln(&self) -> Self {
match self {
Primitive::Int(i) => Primitive::Double((*i as f64).ln()),
Primitive::Double(d) => Primitive::Double(d.ln()),
Primitive::Error(e) => panic!("call to ln() on an error. {e}"),
_ => Primitive::Error(format!("illegal call to ln() => {self}")),
}
}
}
impl Sqrt for Primitive {
fn sqrt(&self) -> Self {
match self {
Primitive::Int(i) => Primitive::Double((*i as f64).sqrt()),
Primitive::Double(d) => Primitive::Double(d.sqrt()),
Primitive::Error(e) => panic!("call to sqrt() on an error. {e}"),
_ => Primitive::Error(format!("illegal call to sqrt() => {self}")),
}
}
}
impl Abs for Primitive {
fn abs(&self) -> Self {
match self {
Primitive::Int(i) => Primitive::Int(i.abs()),
Primitive::Double(d) => Primitive::Double(d.abs()),
Primitive::Error(e) => panic!("call to abs() on an error. {e}"),
_ => Primitive::Error(format!("illegal call to abs() => {self}")),
}
}
}
impl Pow for Primitive {
fn pow(&self, rhs: Self) -> Self {
match (self, rhs) {
#[allow(clippy::manual_range_contains)]
(Primitive::Int(l), Primitive::Int(r))
if r >= 0 && r <= MAX_U32_AS_I128 =>
{
Primitive::Int(l.pow(r as u32))
}
(Primitive::Int(l), Primitive::Int(r)) => {
Primitive::Double((*l as f64).powf(r as f64))
}
(Primitive::Int(l), Primitive::Double(r)) => {
Primitive::Double((*l as f64).powf(r))
}
(Primitive::Double(l), Primitive::Int(r)) => {
Primitive::Double(l.powf(r as f64))
}
(Primitive::Double(l), Primitive::Double(r)) => {
Primitive::Double(l.powf(r))
}
(l, r) => Primitive::Error(format!(
"illegal call to pow() => left: {l} right: {r}"
)),
}
}
}
impl Sum for Primitive {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
let mut first = Primitive::Int(0);
for next in iter {
first = first + next;
}
first
}
}
impl Add for Primitive {
type Output = Primitive;
fn add(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Primitive::Int(l), Primitive::Int(r)) => Primitive::Int(l + r),
(Primitive::Int(l), Primitive::Double(r)) => {
Primitive::Double(l as f64 + r)
}
(Primitive::Int(l), Primitive::String(s)) => {
Primitive::String(format!("{l}{s}"))
}
(Primitive::Double(l), Primitive::Int(r)) => {
Primitive::Double(l + r as f64)
}
(Primitive::Double(l), Primitive::Double(r)) => {
Primitive::Double(l + r)
}
(l, Primitive::String(s)) => Primitive::String(format!("{l}{s}")),
(Primitive::String(s), r) => Primitive::String(format!("{s}{r}")),
(Primitive::Array(mut l), Primitive::Array(mut r)) => {
l.append(&mut r);
Primitive::Array(l)
}
(l, r) => Primitive::Error(format!(
"illegal call to add() => left: {l} right: {r}"
)),
}
}
}
impl Sub for Primitive {
type Output = Primitive;
fn sub(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Primitive::Int(l), Primitive::Int(r)) => Primitive::Int(l - r),
(Primitive::Int(l), Primitive::Double(r)) => {
Primitive::Double(l as f64 - r)
}
(Primitive::Double(l), Primitive::Int(r)) => {
Primitive::Double(l - r as f64)
}
(Primitive::Double(l), Primitive::Double(r)) => {
Primitive::Double(l - r)
}
(l, r) => Primitive::Error(format!(
"illegal call to sub() => left: {l} right: {r}"
)),
}
}
}
impl Rem for Primitive {
type Output = Primitive;
fn rem(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Primitive::Int(l), Primitive::Int(r)) if r != 0 => {
Primitive::Int(l % r)
}
(Primitive::Int(l), Primitive::Double(r)) => {
Primitive::Double(l as f64 % r)
}
(Primitive::Int(_), _) => Primitive::Double(f64::NAN),
(Primitive::Double(l), Primitive::Int(r)) => {
Primitive::Double(l % r as f64)
}
(Primitive::Double(l), Primitive::Double(r)) => {
Primitive::Double(l % r)
}
(l, r) => Primitive::Error(format!(
"illegal call to rem() => left: {l} right: {r}"
)),
}
}
}
impl Mul for Primitive {
type Output = Primitive;
fn mul(self, rhs: Self) -> Self::Output {
fn multiply_array(arr: Vec<Primitive>, n: i128) -> Vec<Primitive> {
let arr_size = arr.len();
arr.into_iter().cycle().take(n as usize * arr_size).collect()
}
match (self, rhs) {
(Primitive::Int(l), Primitive::Int(r)) => {
Primitive::Int(l.wrapping_mul(r))
}
(Primitive::Int(l), Primitive::Double(r)) => {
Primitive::Double(l as f64 * r)
}
(Primitive::Int(l), Primitive::Array(r)) => {
Primitive::Array(multiply_array(r, l))
}
(Primitive::Double(l), Primitive::Int(r)) => {
Primitive::Double(l * r as f64)
}
(Primitive::Double(l), Primitive::Double(r)) => {
Primitive::Double(l * r)
}
(Primitive::String(l), Primitive::Int(r)) => {
Primitive::String(l.repeat(r as usize))
}
(Primitive::Array(l), Primitive::Int(n)) => {
Primitive::Array(multiply_array(l, n))
}
(l, r) => Primitive::Error(format!(
"illegal call to mul() => left: {l} right: {r}"
)),
}
}
}
impl Div for Primitive {
type Output = Primitive;
fn div(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Primitive::Int(l), Primitive::Int(r)) if r != 0 => {
Primitive::Int(l / r)
}
(Primitive::Int(l), Primitive::Double(r)) => {
Primitive::Double(l as f64 / r)
}
(Primitive::Int(l), Primitive::Int(_)) if l >= 1 => {
Primitive::Double(f64::INFINITY)
}
(Primitive::Int(_), _) => Primitive::Double(f64::NAN),
(Primitive::Double(l), Primitive::Int(r)) => {
Primitive::Double(l / r as f64)
}
(Primitive::Double(l), Primitive::Double(r)) => {
Primitive::Double(l / r)
}
(l, r) => Primitive::Error(format!(
"illegal call to div() => left: {l} right: {r}"
)),
}
}
}
impl std::ops::Neg for Primitive {
type Output = Primitive;
fn neg(self) -> Self::Output {
match self {
Primitive::Int(n) => Primitive::Int(-n),
Primitive::Double(n) => Primitive::Double(-n),
_ => Primitive::Error(format!("invalid call to neg() {self}")),
}
}
}
impl std::ops::Not for Primitive {
type Output = Primitive;
fn not(self) -> Self::Output {
match self {
Primitive::Bool(b) => Primitive::Bool(!b),
_ => Primitive::Error(format!("invalid call to not() {self}")),
}
}
}
impl ToNumber for Primitive {
fn to_int(&self) -> Self {
match self {
v @ Primitive::Int(_) => v.clone(),
Primitive::Bool(false) => Primitive::Int(0),
Primitive::Bool(true) => Primitive::Int(1),
Primitive::Double(d) => Primitive::Int(*d as i128),
Primitive::String(s) => match s.parse::<i128>() {
Ok(number) => Primitive::Int(number),
Err(e) => Primitive::Error(format!(
"invalid cast to int: {self}, {e}"
)),
},
_ => Primitive::Error(format!("invalid cast to int: {self}")),
}
}
fn to_double(&self) -> Self {
match self {
Primitive::Int(d) => Primitive::Double(*d as f64),
v @ Primitive::Double(_) => v.clone(),
Primitive::String(s) => match s.parse::<f64>() {
Ok(number) => Primitive::Double(number),
Err(e) => Primitive::Error(format!(
"invalid cast to double: {self}, {e}"
)),
},
_ => Primitive::Error(format!("invalid cast to double: {self}")),
}
}
}
impl Or for Primitive {
fn or(&self, rhs: Self) -> Self {
if let &Primitive::Bool(true) = &self {
return Primitive::Bool(true);
}
if !matches!((self, &rhs), (Primitive::Bool(_), Primitive::Bool(_))) {
return Primitive::Error(format!(
"illegal call to 'or' => left: {self} right: {rhs}"
));
}
rhs
}
}
impl And for Primitive {
fn and(&self, rhs: Self) -> Self {
if let &Primitive::Bool(false) = &self {
return Primitive::Bool(false);
}
if !matches!((self, &rhs), (Primitive::Bool(_), Primitive::Bool(_))) {
return Primitive::Error(format!(
"illegal call to 'and' => left: {self} right: {rhs}"
));
}
rhs
}
}
impl PartialOrd for Primitive {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match (self, other) {
(Primitive::Int(l), Primitive::Int(r)) => l.partial_cmp(r),
(Primitive::Int(l), Primitive::Double(r)) => {
(*l as f64).partial_cmp(r)
}
(Primitive::Double(l), Primitive::Int(r)) => {
l.partial_cmp(&(*r as f64))
}
(Primitive::Double(l), Primitive::Double(r)) => l.partial_cmp(r),
(Primitive::Bool(a), Primitive::Bool(b)) => a.partial_cmp(b),
(Primitive::String(l), Primitive::String(r)) => l.partial_cmp(r),
(Primitive::Unit, Primitive::Unit) => Some(Ordering::Equal),
(Primitive::Array(l), Primitive::Array(r)) => l.partial_cmp(r),
(
l @ Primitive::Function { parameters: _, exprs: _ },
r @ Primitive::Function { parameters: _, exprs: _ },
) => l.partial_cmp(r),
(Primitive::Null, Primitive::Null) => Some(Ordering::Equal),
(Primitive::Int(_), _) => None,
(Primitive::Bool(_), _) => None,
(Primitive::Double(_), _) => None,
(Primitive::String(_), _) => None,
(Primitive::NoReturn, _) => None,
(Primitive::EarlyReturn(_), _) => None,
(Primitive::Null, _) => None,
(Primitive::Array(_), _) => None,
(Primitive::Error(_), _) => None,
(Primitive::Unit, _) => None,
(Primitive::Function { parameters: _, exprs: _ }, _) => None,
}
}
}
impl Array for Primitive {
fn index_at(&self, rhs: &Primitive) -> Primitive {
match (self, rhs) {
(Primitive::Array(arr), Primitive::Int(idx)) => {
let idx = *idx as usize;
if idx < arr.len() {
arr[idx].clone()
} else {
Primitive::Error("index out of range".to_string())
}
}
(Primitive::String(s), Primitive::Int(idx)) => {
let idx = *idx as usize;
if idx < s.len() {
let s: String = s.chars().skip(idx).take(1).collect();
Primitive::String(s)
} else {
Primitive::Error("index out of range".to_string())
}
}
_ => Primitive::Error("illegal access to array!!!".to_string()),
}
}
fn len(&self) -> Primitive {
match self {
Primitive::String(s) => Primitive::Int(s.len() as i128),
Primitive::Array(a) => Primitive::Int(a.len() as i128),
_ => Primitive::Error(format!(
"call to len() on a non array value => {self}"
)),
}
}
fn swap_mem(
&mut self,
rhs: &mut Primitive,
index: &Primitive,
) -> Primitive {
match (self, index) {
(Primitive::Array(arr), Primitive::Int(idx)) => {
let idx = *idx as usize;
if !matches!(rhs, Primitive::Error(_) | Primitive::Unit)
&& idx < arr.len()
{
std::mem::swap(&mut arr[idx], rhs);
arr[idx].clone()
} else {
Primitive::Error("index out of range".to_string())
}
}
(Primitive::String(s), Primitive::Int(idx)) => {
let idx = *idx as usize;
if !matches!(rhs, Primitive::Error(_) | Primitive::Unit)
&& idx < s.len()
{
s.remove(idx);
s.insert_str(idx, &rhs.to_string());
rhs.clone()
} else {
Primitive::Error("index out of range".to_string())
}
}
_ => Primitive::Error("invalid call to swap_mem()".to_string()),
}
}
}
#[cfg(test)]
mod test {
use super::Primitive;
#[test]
fn test_add_valid() {
let l = Primitive::Int(1);
let r = Primitive::Int(2);
assert_eq!(l + r, Primitive::Int(3));
let l = Primitive::Int(1);
let r = Primitive::Double(2.);
assert_eq!(l + r, Primitive::Double(3.));
let l = Primitive::Double(1.);
let r = Primitive::Int(2);
assert_eq!(l + r, Primitive::Double(3.));
let l = Primitive::Double(1.);
let r = Primitive::Double(2.);
assert_eq!(l + r, Primitive::Double(3.));
}
}