use std::{fmt, str::FromStr};
use arithmetic_parser::{
grammars::{NumGrammar, NumLiteral, Parse},
BinaryOp, InputSpan, NomResult,
};
use arithmetic_typing::{
arith::*, defs::Prelude, error::OpErrors, visit::Visit, Annotated, PrimitiveType, Type,
TypeEnvironment,
};
#[derive(Debug, Clone)]
enum NumOrBytes {
Num(u32),
Bytes(Vec<u8>),
}
impl NumLiteral for NumOrBytes {
fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
use arithmetic_parser::ErrorKind;
use nom::{
branch::alt,
bytes::complete::{tag, take_while},
combinator::{cut, map, map_res},
sequence::preceded,
};
let parse_hex = preceded(
tag("0x"),
cut(map_res(
take_while(|c: char| c.is_ascii_hexdigit()),
|digits: InputSpan<'_>| hex::decode(digits.fragment()).map_err(ErrorKind::literal),
)),
);
alt((map(parse_hex, Self::Bytes), map(u32::parse, Self::Num)))(input)
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum NumOrBytesType {
Num,
Bytes,
Bool,
}
impl fmt::Display for NumOrBytesType {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(match self {
Self::Num => "Num",
Self::Bytes => "Bytes",
Self::Bool => "Bool",
})
}
}
impl FromStr for NumOrBytesType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Num" => Ok(Self::Num),
"Bytes" => Ok(Self::Bytes),
"Bool" => Ok(Self::Bool),
_ => Err(anyhow::anyhow!("Expected `Num`, `Bytes` or `Bool`")),
}
}
}
impl WithBoolean for NumOrBytesType {
const BOOL: Self = Self::Bool;
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Constraints {
Summable,
Lin,
}
impl fmt::Display for Constraints {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(match self {
Self::Lin => "Lin",
Self::Summable => "Sum",
})
}
}
impl Constraint<NumOrBytesType> for Constraints {
fn visitor<'r>(
&self,
substitutions: &'r mut Substitutions<NumOrBytesType>,
errors: OpErrors<'r, NumOrBytesType>,
) -> Box<dyn Visit<NumOrBytesType> + 'r> {
let predicate: fn(&NumOrBytesType) -> bool = match self {
Self::Summable => |prim| matches!(prim, NumOrBytesType::Bytes | NumOrBytesType::Num),
Self::Lin => |prim| *prim == NumOrBytesType::Num,
};
StructConstraint::new(*self, predicate)
.deny_dyn_slices()
.visitor(substitutions, errors)
}
fn clone_boxed(&self) -> Box<dyn Constraint<NumOrBytesType>> {
Box::new(*self)
}
}
impl PrimitiveType for NumOrBytesType {}
#[derive(Debug, Clone, Copy)]
struct NumOrBytesArithmetic;
impl MapPrimitiveType<NumOrBytes> for NumOrBytesArithmetic {
type Prim = NumOrBytesType;
fn type_of_literal(&self, lit: &NumOrBytes) -> Self::Prim {
match lit {
NumOrBytes::Num(_) => NumOrBytesType::Num,
NumOrBytes::Bytes(_) => NumOrBytesType::Bytes,
}
}
}
impl TypeArithmetic<NumOrBytesType> for NumOrBytesArithmetic {
fn process_unary_op<'a>(
&self,
substitutions: &mut Substitutions<NumOrBytesType>,
context: &UnaryOpContext<NumOrBytesType>,
errors: OpErrors<'a, NumOrBytesType>,
) -> Type<NumOrBytesType> {
NumArithmetic::process_unary_op(substitutions, context, errors, &Constraints::Lin)
}
fn process_binary_op<'a>(
&self,
substitutions: &mut Substitutions<NumOrBytesType>,
context: &BinaryOpContext<NumOrBytesType>,
errors: OpErrors<'a, NumOrBytesType>,
) -> Type<NumOrBytesType> {
let op_constraints = if let BinaryOp::Add = context.op {
Constraints::Summable
} else {
Constraints::Lin
};
let settings = OpConstraintSettings {
lin: &Constraints::Lin,
ops: &op_constraints,
};
let comparable_type = Some(NumOrBytesType::Num);
NumArithmetic::process_binary_op(substitutions, context, errors, comparable_type, settings)
}
}
type Parser = Annotated<NumGrammar<NumOrBytes>>;
fn main() -> anyhow::Result<()> {
let code = r#"
x = 1 + 3;
y = 0x1234 + 0x56;
z = (1, 0x12) + (x, y);
sum = |xs, init| xs.fold(init, |acc, x| acc + x);
sum_of_bytes = (0xdead, 0xbe, 0xef).sum(0x);
sum_of_tuples = ((1, 2), (3, 4)).sum((0, 0));
product = |xs, init| xs.fold(init, |acc, x| acc * x);
product_of_ints = (1, 5, 2).product(1);
"#;
let ast = Parser::parse_statements(code)?;
let mut env = TypeEnvironment::<NumOrBytesType>::new();
env.insert("fold", Prelude::Fold);
env.process_with_arithmetic(&NumOrBytesArithmetic, &ast)?;
assert_eq!(env["x"], Type::Prim(NumOrBytesType::Num));
assert_eq!(env["y"].to_string(), "Bytes");
assert_eq!(env["z"].to_string(), "(Num, Bytes)");
assert_eq!(env["sum"].to_string(), "for<'T: Sum> (['T; N], 'T) -> 'T");
assert_eq!(env["sum_of_bytes"].to_string(), "Bytes");
assert_eq!(env["sum_of_tuples"].to_string(), "(Num, Num)");
assert_eq!(
env["product"].to_string(),
"for<'T: Lin> (['T; N], 'T) -> 'T"
);
assert_eq!(env["product_of_ints"].to_string(), "Num");
let bogus_code = "(0xca, 0xfe).product(0x)";
let bogus_ast = Parser::parse_statements(bogus_code)?;
let err = env
.process_with_arithmetic(&NumOrBytesArithmetic, &bogus_ast)
.unwrap_err();
assert_eq!(err.to_string(), "1:2: Type `Bytes` fails constraint `Lin`");
let bogus_code = "(|x| x + 1, |x| x - 5).sum(|x| x)";
let bogus_ast = Parser::parse_statements(bogus_code)?;
let err = env
.process_with_arithmetic(&NumOrBytesArithmetic, &bogus_ast)
.unwrap_err();
assert_eq!(
err.to_string(),
"1:2: Type `(Num) -> Num` fails constraint `Sum`"
);
env.insert("true", Type::BOOL);
let bogus_code = "((1, true), (2, true)).sum((3, true))";
let bogus_ast = Parser::parse_statements(bogus_code)?;
let err = env
.process_with_arithmetic(&NumOrBytesArithmetic, &bogus_ast)
.unwrap_err();
assert_eq!(err.to_string(), "1:6: Type `Bool` fails constraint `Sum`");
Ok(())
}