use crate::eval::Eval;
use pest::error::InputLocation;
use std::fmt::{Debug, Formatter};
use std::ops::{Deref, DerefMut};
/// A node of an expression tree over values of type `T`.
///
/// Operation nodes are boxed because [`BinaryOp`] and [`UnaryOp`] themselves
/// contain `Expression` operands.
#[derive(Debug, Clone)]
pub enum Expression<T>
where
    T: Eval,
{
    /// A value of type `T` stored together with a [`Span`].
    Literal(Spanned<T>),
    /// An operation with two operands.
    BinaryOp(Box<BinaryOp<T>>),
    /// An operation with a single operand.
    UnaryOp(Box<UnaryOp<T>>),
}
impl<T: Eval> Expression<T> {
    /// Evaluates this expression to a value of type `T`.
    ///
    /// A literal yields a clone of its stored value; operation nodes delegate
    /// to [`BinaryOp::eval`] or [`UnaryOp::eval`].
    ///
    /// # Errors
    ///
    /// Returns the `T::ErrEval` produced by the delegated operation, if any.
    /// The literal arm always returns `Ok`.
    pub fn eval(&self) -> Result<T, T::ErrEval> {
        match self {
            Self::Literal(literal) => {
                // Deref coercion turns `&Spanned<T>` into `&T` before cloning.
                let value: &T = literal;
                Ok(value.clone())
            }
            Self::BinaryOp(binary) => binary.eval(),
            Self::UnaryOp(unary) => unary.eval(),
        }
    }

    /// Returns the span of this node.
    ///
    /// For a literal this is the span stored alongside the value; for an
    /// operation it is whatever that operation's own `span()` reports.
    pub fn span(&self) -> Span {
        match self {
            Self::Literal(literal) => literal.span(),
            Self::BinaryOp(binary) => binary.span(),
            Self::UnaryOp(unary) => unary.span(),
        }
    }
}
impl<T: Eval + PartialEq> PartialEq for Expression<T> {
    /// Two expressions are equal when they are the same variant and their
    /// payloads compare equal; different variants are never equal.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Literal(lhs), Self::Literal(rhs)) => lhs == rhs,
            (Self::BinaryOp(lhs), Self::BinaryOp(rhs)) => lhs == rhs,
            (Self::UnaryOp(lhs), Self::UnaryOp(rhs)) => lhs == rhs,
            // Any remaining pairing mixes two different variants.
            _ => false,
        }
    }
}
/// A binary operation: two operand expressions and the operator joining them.
#[derive(Debug, Clone)]
pub struct BinaryOp<T>
where
    T: Eval,
{
    /// First operand; its value is the receiver of the `Eval` method call
    /// made in [`BinaryOp::eval`].
    pub operand1: Expression<T>,
    /// Second operand; its value is passed as the argument of that call.
    pub operand2: Expression<T>,
    /// The operation to apply, stored together with its own span.
    pub operator: Spanned<BinaryOpType>,
}
impl<T: Eval> BinaryOp<T> {
pub fn eval(&self) -> Result<T, T::ErrEval> {
let val1 = self.operand1.eval()?;
let val2 = self.operand2.eval()?;
match self.operator.deref() {
BinaryOpType::Eq => val1.eq(val2),
BinaryOpType::Neq => val1.neq(val2),
BinaryOpType::Gte => val1.gte(val2),
BinaryOpType::Gt => val1.gt(val2),
BinaryOpType::Lte => val1.lte(val2),
BinaryOpType::Lt => val1.lt(val2),
BinaryOpType::And => val1.and(val2),
BinaryOpType::Or => val1.or(val2),
BinaryOpType::BitAnd => val1.bit_and(val2),
BinaryOpType::BitOr => val1.bit_or(val2),
BinaryOpType::Add => val1.add(val2),
BinaryOpType::Sub => val1.sub(val2),
BinaryOpType::Mul => val1.mul(val2),
BinaryOpType::Div => val1.div(val2),
BinaryOpType::Mod => val1.rem(val2),
BinaryOpType::Exp => val1.exp(val2),
}
}
pub fn span(&self) -> Span {
Span::combine(
[
self.operand1.span(),
self.operand2.span(),
self.operator.span(),
]
.as_ref(),
)
.unwrap()
}
}
impl<T: Eval + PartialEq> PartialEq for BinaryOp<T> {
    /// Compares the operator (including its span) first, then the first and
    /// second operands, stopping at the first difference.
    fn eq(&self, other: &Self) -> bool {
        if self.operator != other.operator {
            return false;
        }
        if self.operand1 != other.operand1 {
            return false;
        }
        self.operand2 == other.operand2
    }
}
/// The operator of a [`BinaryOp`].
///
/// Each variant selects which `Eval` method [`BinaryOp::eval`] calls on the
/// evaluated first operand, with the evaluated second operand as argument.
#[derive(Debug, Clone, PartialEq)]
pub enum BinaryOpType {
    /// Dispatches to `Eval::eq`.
    Eq,
    /// Dispatches to `Eval::neq`.
    Neq,
    /// Dispatches to `Eval::gte`.
    Gte,
    /// Dispatches to `Eval::gt`.
    Gt,
    /// Dispatches to `Eval::lte`.
    Lte,
    /// Dispatches to `Eval::lt`.
    Lt,
    /// Dispatches to `Eval::and`.
    And,
    /// Dispatches to `Eval::or`.
    Or,
    /// Dispatches to `Eval::bit_and`.
    BitAnd,
    /// Dispatches to `Eval::bit_or`.
    BitOr,
    /// Dispatches to `Eval::add`.
    Add,
    /// Dispatches to `Eval::sub`.
    Sub,
    /// Dispatches to `Eval::mul`.
    Mul,
    /// Dispatches to `Eval::div`.
    Div,
    /// Dispatches to `Eval::rem` (the method name differs from the variant name).
    Mod,
    /// Dispatches to `Eval::exp`.
    Exp,
}
/// A unary operation: one operand expression and the operator applied to it.
#[derive(Debug, Clone)]
pub struct UnaryOp<T>
where
    T: Eval,
{
    /// The expression whose value the operator is applied to.
    pub operand: Expression<T>,
    /// The operation to apply, stored together with its own span.
    pub operator: Spanned<UnaryOpType>,
}
impl<T: Eval> UnaryOp<T> {
pub fn eval(&self) -> Result<T, T::ErrEval> {
let val = self.operand.eval()?;
match self.operator.deref() {
UnaryOpType::Plus => val.plus(),
UnaryOpType::Minus => val.minus(),
UnaryOpType::Not => val.not(),
UnaryOpType::BitNot => val.bit_not(),
}
}
pub fn span(&self) -> Span {
Span::combine([self.operand.span(), self.operator.span()].as_ref()).unwrap()
}
}
impl<T: Eval + PartialEq> PartialEq for UnaryOp<T> {
    /// Compares the operator (including its span) first; the operands are
    /// only compared when the operators match.
    fn eq(&self, other: &Self) -> bool {
        let same_operator = self.operator == other.operator;
        same_operator && self.operand == other.operand
    }
}
/// The operator of a [`UnaryOp`].
///
/// Each variant selects which `Eval` method [`UnaryOp::eval`] calls on the
/// evaluated operand.
#[derive(Debug, Clone, PartialEq)]
pub enum UnaryOpType {
    /// Dispatches to `Eval::plus`.
    Plus,
    /// Dispatches to `Eval::minus`.
    Minus,
    /// Dispatches to `Eval::not`.
    Not,
    /// Dispatches to `Eval::bit_not`.
    BitNot,
}
/// A `start..end` pair of positions.
///
/// [`Span::new`] only builds spans with `start < end`, and [`Span::len`] is
/// `end - start`. The fields are private; read them through the accessors.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Span {
    /// Lower position of the span.
    start: usize,
    /// Upper position of the span; greater than `start` for spans built by
    /// [`Span::new`].
    end: usize,
}
impl Span {
    /// Builds a span from `start` to `end`.
    ///
    /// Returns `None` unless `start < end`, so empty and reversed ranges are
    /// rejected.
    pub fn new(start: usize, end: usize) -> Option<Span> {
        if start < end {
            Some(Span { start, end })
        } else {
            None
        }
    }

    /// Returns the smallest span that covers every span in `spans`.
    ///
    /// The result runs from the smallest `start` to the largest `end`.
    /// Returns `None` when `spans` is empty.
    pub fn combine<'a>(spans: impl Into<&'a [Span]>) -> Option<Span> {
        let spans: &[Span] = spans.into();
        // `min`/`max` yield `None` for an empty slice, which `?` passes on.
        let start = spans.iter().map(|span| span.start).min()?;
        let end = spans.iter().map(|span| span.end).max()?;
        Some(Span { start, end })
    }

    /// Returns the lower position of the span.
    pub fn start(&self) -> usize {
        let Span { start, .. } = *self;
        start
    }

    /// Returns the upper position of the span.
    pub fn end(&self) -> usize {
        let Span { end, .. } = *self;
        end
    }

    /// Returns `end - start`.
    pub fn len(&self) -> usize {
        let Span { start, end } = *self;
        end - start
    }
}
impl From<InputLocation> for Span {
    /// Converts a pest input location into a [`Span`].
    ///
    /// `Pos(p)` becomes `p..p + 1` and `Span((a, b))` becomes `a..b + 1`.
    ///
    /// NOTE(review): the `+ 1` on the range end makes the result one position
    /// longer than the `(a, b)` pair. Confirm against pest's definition of
    /// `InputLocation::Span` whether that is intended.
    ///
    /// # Panics
    ///
    /// Panics if `Span::new` rejects the resulting range, which happens for a
    /// `Span((a, b))` with `b < a`.
    fn from(value: InputLocation) -> Self {
        let (start, last) = match value {
            InputLocation::Pos(pos) => (pos, pos),
            InputLocation::Span((from, to)) => (from, to),
        };
        Span::new(start, last + 1).unwrap()
    }
}
/// A value of type `T` stored together with a [`Span`].
///
/// Both fields are private: the span is read with `span()`, and the value is
/// reached through `Deref`/`DerefMut` or `into_inner()`.
pub struct Spanned<T> {
    /// The span attached to `inner`.
    span: Span,
    /// The wrapped value.
    inner: T,
}
impl<T: Sized> Spanned<T> {
pub fn new(inner: T, span: Span) -> Self {
Self { span, inner }
}
pub fn into_inner(self) -> T {
self.inner
}
pub fn span(&self) -> Span {
self.span
}
}
impl<T: Sized + Clone> Clone for Spanned<T> {
    /// Returns a new `Spanned` holding a clone of the inner value and the
    /// same span.
    fn clone(&self) -> Self {
        Self {
            // `Span` is `Copy`, so it is copied directly instead of going
            // through `.clone()`.
            span: self.span,
            inner: self.inner.clone(),
        }
    }
}
impl<T: PartialEq> PartialEq for Spanned<T> {
    /// Two `Spanned` values are equal when both their spans and their inner
    /// values are equal; the spans are compared first.
    fn eq(&self, other: &Self) -> bool {
        if self.span != other.span {
            return false;
        }
        self.inner == other.inner
    }
}
impl<T: Sized + Debug> Debug for Spanned<T> {
    /// Formats the value as `Spanned(<span:?>, <inner:?>)`.
    ///
    /// The output is always the compact single-line form, including when the
    /// caller asks for `{:#?}`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Write straight into the formatter instead of building a temporary
        // `String` with `format!` and copying it over.
        write!(f, "Spanned({:?}, {:?})", self.span, self.inner)
    }
}
impl<T: Sized> Deref for Spanned<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T: Sized> DerefMut for Spanned<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl From<pest::Span<'_>> for Span {
    /// Converts a pest span into a [`Span`] with the same `start` and `end`.
    ///
    /// # Panics
    ///
    /// Panics with "could not convert span" when `Span::new` rejects the
    /// range, i.e. when `value.end() <= value.start()`. An empty pest span
    /// therefore triggers the panic.
    fn from(value: pest::Span<'_>) -> Self {
        let (start, end) = (value.start(), value.end());
        Span::new(start, end).expect("could not convert span")
    }
}