use crate::{Arena, IntoTerm, Term, TermError, View, atom, func, list};
use indexmap::IndexMap;
use smartstring::alias::String;
use std::collections::HashSet;
use std::fmt;
use std::str::FromStr;
/// Early-returns `Err(TermError::OperDef(..))` from the enclosing function or
/// closure, using `format!`-style arguments to build the message.
macro_rules! bail {
    ($($fmt:tt)*) => {{
        let message = format!($($fmt)*);
        return Err(crate::TermError::OperDef(String::from(message)));
    }};
}
/// Syntactic role a name is defined for.
///
/// The explicit discriminants are used as indices into `Fixity::STRS` and
/// into the per-name definition table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Fixity {
    /// Plain function (functor) syntax; not an operator.
    Fun = 0,
    /// Operator written before its operand.
    Prefix = 1,
    /// Operator written between its two operands.
    Infix = 2,
    /// Operator written after its operand.
    Postfix = 3,
}

impl Fixity {
    /// Number of variants.
    pub const COUNT: usize = 4;
    /// Keyword for each variant, indexed by its discriminant.
    pub const STRS: &[&str] = &["fun", "prefix", "infix", "postfix"];
}

/// Converts to the keyword (`"fun"`, `"prefix"`, `"infix"` or `"postfix"`).
impl From<Fixity> for String {
    fn from(f: Fixity) -> Self {
        Fixity::STRS[usize::from(f)].into()
    }
}

/// Converts to the discriminant, usable as a table index.
impl From<Fixity> for usize {
    fn from(f: Fixity) -> Self {
        f as usize
    }
}

impl fmt::Display for Fixity {
    /// Writes the keyword through `Formatter::pad`, so width, fill and
    /// alignment flags (e.g. `{:<8}`) are honoured and no intermediate string
    /// is built.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(Fixity::STRS[usize::from(*self)])
    }
}

impl TryFrom<&str> for Fixity {
    type Error = TermError;
    /// Parses a keyword; see the `FromStr` impl.
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        s.parse()
    }
}

impl TryFrom<String> for Fixity {
    type Error = TermError;
    /// Parses a keyword; see the `FromStr` impl.
    fn try_from(s: String) -> Result<Self, Self::Error> {
        s.as_str().parse()
    }
}

impl FromStr for Fixity {
    type Err = TermError;
    /// Parses one of the exact keywords in `Fixity::STRS`.
    ///
    /// # Errors
    ///
    /// `TermError::InvalidFixity` carrying the rejected text.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "fun" => Ok(Fixity::Fun),
            "prefix" => Ok(Fixity::Prefix),
            "infix" => Ok(Fixity::Infix),
            "postfix" => Ok(Fixity::Postfix),
            other => Err(TermError::InvalidFixity(String::from(other))),
        }
    }
}
/// Associativity of an operator definition.
///
/// The explicit discriminants index into `Assoc::STRS`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Assoc {
    /// Not associative.
    None = 0,
    /// Left associative.
    Left = 1,
    /// Right associative.
    Right = 2,
}

impl Assoc {
    /// Number of variants.
    pub const COUNT: usize = 3;
    /// Keyword for each variant, indexed by its discriminant.
    pub const STRS: &[&str] = &["none", "left", "right"];
}

/// Converts to the keyword (`"none"`, `"left"` or `"right"`).
impl From<Assoc> for String {
    fn from(a: Assoc) -> Self {
        Assoc::STRS[usize::from(a)].into()
    }
}

/// Converts to the discriminant.
impl From<Assoc> for usize {
    fn from(a: Assoc) -> Self {
        a as usize
    }
}

impl fmt::Display for Assoc {
    /// Writes the keyword through `Formatter::pad`, so width, fill and
    /// alignment flags are honoured and no intermediate string is built.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(Assoc::STRS[usize::from(*self)])
    }
}

impl FromStr for Assoc {
    type Err = TermError;
    /// Parses one of the exact keywords in `Assoc::STRS`.
    ///
    /// # Errors
    ///
    /// `TermError::InvalidAssoc` carrying the rejected text.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "none" => Ok(Assoc::None),
            "left" => Ok(Assoc::Left),
            "right" => Ok(Assoc::Right),
            other => Err(TermError::InvalidAssoc(String::from(other))),
        }
    }
}

impl TryFrom<&str> for Assoc {
    type Error = TermError;
    /// Parses a keyword; see the `FromStr` impl.
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        s.parse()
    }
}

impl TryFrom<String> for Assoc {
    type Error = TermError;
    /// Parses a keyword; see the `FromStr` impl.
    fn try_from(s: String) -> Result<Self, Self::Error> {
        s.as_str().parse()
    }
}
/// One declared argument of an operator or function definition.
#[derive(Debug, Clone)]
pub struct OperArg {
    /// Argument name; also the key matched by `name = value` arguments
    /// during normalization.
    pub name: String,
    /// Value used when the argument is omitted. With `None`, omitting the
    /// argument is a normalization error.
    pub default: Option<Term>,
}
/// Precedence that `fun` definitions are required to use.
pub const NON_OPER_PREC: i64 = 0;
/// Lowest precedence accepted by `Arena::define_oper`.
pub const MIN_OPER_PREC: i64 = 0;
/// Highest precedence accepted by `Arena::define_oper`.
pub const MAX_OPER_PREC: i64 = 1200;
/// A definition of one name for one fixity.
#[derive(Debug, Clone)]
pub struct OperDef {
    /// Syntactic role this definition applies to.
    pub fixity: Fixity,
    /// Precedence; `Arena::define_oper` only accepts values in
    /// `MIN_OPER_PREC..=MAX_OPER_PREC`.
    pub prec: i64,
    /// Associativity.
    pub assoc: Assoc,
    /// Declared arguments in positional order. `Arena::define_oper` rejects
    /// defaults on the first `OperDef::required_arity(fixity)` of them.
    pub args: Vec<OperArg>,
    /// When set, normalization uses this term as the functor instead of the
    /// original name.
    pub rename_to: Option<Term>,
    /// When true, normalization inserts the fixity keyword atom as the first
    /// argument.
    pub embed_fixity: bool,
}
/// All definitions of a single name, one optional slot per `Fixity`
/// (indexed by the fixity's discriminant).
#[derive(Debug, Clone)]
pub struct OperDefTab {
    tab: [Option<OperDef>; Fixity::COUNT],
}
/// Definition tables keyed by name; `IndexMap` keeps insertion order, and
/// positions in it are the indices handed out by `Arena::lookup_oper`.
#[derive(Debug, Clone, Default)]
pub struct OperDefs {
    map: IndexMap<String, OperDefTab>,
}
/// Shared empty table that `Arena::get_oper` returns for a missing or
/// out-of-range index.
static EMPTY_OPER_DEF_TAB: OperDefTab = OperDefTab::new();
impl OperDef {
    /// Number of positional operands a definition of `fixity` must accept:
    /// none for plain functions, one for prefix/postfix, two for infix.
    pub fn required_arity(fixity: Fixity) -> usize {
        match fixity {
            Fixity::Fun => 0,
            Fixity::Prefix | Fixity::Postfix => 1,
            Fixity::Infix => 2,
        }
    }
}
impl OperDefTab {
    /// Creates a table with every fixity slot empty.
    ///
    /// This is a `const fn` so it can initialise statics.
    pub const fn new() -> Self {
        const EMPTY_SLOT: Option<OperDef> = None;
        Self {
            tab: [EMPTY_SLOT; Fixity::COUNT],
        }
    }

    /// True when the name has a plain-function (`Fixity::Fun`) definition.
    pub fn is_fun(&self) -> bool {
        self.tab[Fixity::Fun as usize].is_some()
    }

    /// True when the name has at least one prefix, infix or postfix
    /// definition (every slot after `Fun`).
    pub fn is_oper(&self) -> bool {
        self.tab.iter().skip(1).any(Option::is_some)
    }

    /// Borrows the definition for `fixity`, if present.
    pub fn get_op_def(&self, fixity: Fixity) -> Option<&OperDef> {
        self.tab[fixity as usize].as_ref()
    }
}
/// `tab[fixity]` borrows the definition slot for that fixity.
impl std::ops::Index<Fixity> for OperDefTab {
    type Output = Option<OperDef>;
    fn index(&self, slot: Fixity) -> &Option<OperDef> {
        &self.tab[usize::from(slot)]
    }
}

/// `tab[fixity] = ...` replaces the definition slot for that fixity.
impl std::ops::IndexMut<Fixity> for OperDefTab {
    fn index_mut(&mut self, slot: Fixity) -> &mut Option<OperDef> {
        &mut self.tab[usize::from(slot)]
    }
}
impl OperDefs {
    /// Creates an empty name-to-definitions map.
    pub fn new() -> Self {
        let map = IndexMap::new();
        Self { map }
    }
}
impl Arena {
pub fn lookup_oper(&self, name: &str) -> Option<usize> {
self.opers.map.get_index_of(name)
}
pub fn get_oper(&self, index: Option<usize>) -> &OperDefTab {
match index {
Some(index) => match self.opers.map.get_index(index) {
Some((_, tab)) => tab,
None => &EMPTY_OPER_DEF_TAB,
},
None => &EMPTY_OPER_DEF_TAB,
}
}
pub fn opers_len(&self) -> usize {
self.opers.map.len()
}
pub fn define_oper(&mut self, op: Term) -> Result<(), TermError> {
const BOOLS: &[&str] = &["false", "true"];
let (_, [oper, fixity, prec, assoc, rename_to, embed_fixity]) =
op.unpack_func(self, &["op"])?;
let (functor, args) = oper.unpack_func_any(self, &[])?;
let name = String::from(functor.atom_name(self)?);
let fixity = Fixity::try_from(fixity.unpack_atom(self, Fixity::STRS)?)?;
let prec = prec.unpack_int(self)?;
if prec < MIN_OPER_PREC || prec > MAX_OPER_PREC {
bail!(
"precedence {} out of range {}..={}",
prec,
MIN_OPER_PREC,
MAX_OPER_PREC
);
}
let assoc = Assoc::try_from(assoc.unpack_atom(self, Assoc::STRS)?)?;
let embed_fixity = embed_fixity.unpack_atom(self, BOOLS)? == "true";
let args = args
.into_iter()
.map(|arg| {
Ok(match arg.view(self)? {
View::Atom(name) => OperArg {
name: String::from(name),
default: None,
},
View::Func(ar, _, _) => {
let (_, [name, term]) = arg.unpack_func(ar, &["="])?;
OperArg {
name: String::from(name.atom_name(ar)?),
default: Some(term),
}
}
_ => bail!("oper arg must be an atom or =(atom, term) in {:?}", name),
})
})
.collect::<Result<Vec<_>, TermError>>()?;
let required_arity = OperDef::required_arity(fixity);
if args.len() < required_arity {
bail!(
"operator {:?} requires at least {} argument(s)",
name,
required_arity
);
}
if args[..required_arity].iter().any(|x| x.default.is_some()) {
bail!("defaults are not allowed for required operator arguments");
}
let unique_arg_names: HashSet<_> = args.iter().map(|x| &x.name).cloned().collect();
if unique_arg_names.len() != args.len() {
bail!("duplicate arguments in {:?}", name);
}
let rename_to = match rename_to.view(self)? {
View::Atom("none") => None,
View::Func(ar, _, _) => {
let (_, [rename_to]) = rename_to.unpack_func(ar, &["some"])?;
Some(rename_to)
}
_ => bail!("rename_to must be 'none' | some(atom)"),
};
if matches!(fixity, Fixity::Fun) && prec != NON_OPER_PREC {
bail!("{:?} must be assigned precedence 0", name);
}
if !matches!(fixity, Fixity::Fun) && (prec < MIN_OPER_PREC || prec > MAX_OPER_PREC) {
bail!(
"precedence {} is out of range for operator {:?} with type {:?} (expected {}–{})",
prec,
name,
fixity,
MIN_OPER_PREC,
MAX_OPER_PREC,
);
}
if matches!((fixity, assoc), (Fixity::Prefix, Assoc::Left))
|| matches!((fixity, assoc), (Fixity::Postfix, Assoc::Right))
{
bail!(
"operator {:?} with type {:?} cannot have associativity {:?}",
name,
fixity,
assoc
);
}
#[cfg(false)]
if matches!((fixity, assoc), (Fixity::Fun, Assoc::Left | Assoc::Right)) {
bail!(
"{:?} with type {:?} cannot have associativity {:?}",
name,
fixity,
assoc
);
}
let tab = self
.opers
.map
.entry(name.clone())
.or_insert_with(OperDefTab::new);
if matches!(fixity, Fixity::Fun) && tab.is_oper() {
bail!(
"cannot define {:?} with type {:?}; it is already defined as an operator with a different type",
name,
fixity,
);
}
if matches!(fixity, Fixity::Prefix | Fixity::Infix | Fixity::Postfix)
&& tab.tab[Into::<usize>::into(Fixity::Fun)].is_some()
{
bail!(
"cannot define {:?} as an operator with type {:?}; it is already defined with type Fun",
name,
fixity,
);
}
if tab[fixity].is_some() {
bail!("cannot re-define {:?}", name);
}
tab[fixity] = Some(OperDef {
fixity,
prec,
assoc,
rename_to,
embed_fixity,
args,
});
Ok(())
}
pub fn define_opers(&mut self, term: Term) -> Result<(), TermError> {
let ts = match term.view(self)? {
View::List(_, ts, _) => ts.to_vec(),
_ => {
vec![term]
}
};
for t in ts {
self.define_oper(t)?;
}
Ok(())
}
pub fn clear_opers(&mut self) {
self.opers.map.clear();
}
pub fn normalize_term(
&mut self,
term: Term,
fixity: Fixity,
op_tab_index: Option<usize>,
) -> Result<Term, TermError> {
match self.get_oper(op_tab_index)[fixity] {
Some(ref op_def) => {
let (functor, vs) = match term.view(self)? {
View::Atom(_) => (term, &[] as &[Term]),
View::Func(_, functor, args) => {
if args.is_empty() {
bail!("invalid Func");
}
(*functor, args)
}
_ => {
return Ok(term);
}
};
let name = functor.atom_name(self)?;
let n_required_args = OperDef::required_arity(fixity);
if vs.len() < n_required_args {
bail!(
"missing {} required arguments in term {:?}",
n_required_args - vs.len(),
name
);
}
let args = &op_def.args;
let mut xs: Vec<Option<Term>> = vec![None; args.len()];
for (i, value) in vs.iter().enumerate() {
if i < n_required_args {
xs[i] = Some(*value);
} else {
match value.view(self)? {
View::Func(ar, functor, vs)
if vs.len() == 2 && functor.atom_name(ar)? == "=" =>
{
let arg_name = vs[0].atom_name(self)?;
if let Some(pos) = args.iter().position(|x| x.name == arg_name) {
if xs[pos].is_none() {
xs[pos] = Some(vs[1]);
} else {
bail!(
"cannot redefine argument {:?} at position {} in {:?}",
arg_name,
pos,
name
);
}
} else {
bail!("invalid argument name {:?} in {:?}", arg_name, name);
}
}
_ => {
if xs[i].is_none() {
xs[i] = Some(*value);
} else {
bail!(
"cannot redefine argument {:?} at position {} in {:?}",
args[i].name,
i,
name
);
}
}
}
}
}
let vs: Option<Vec<_>> = xs
.into_iter()
.enumerate()
.map(|(i, x)| x.or(args[i].default))
.collect();
let mut vs = match vs {
Some(vs) => vs,
None => bail!("missing arguments in {:?}", name),
};
let rename_to = match op_def.rename_to {
Some(rename_to) => rename_to,
None => functor,
};
if op_def.embed_fixity {
vs.insert(0, self.atom(String::from(fixity)));
}
if vs.is_empty() {
Ok(rename_to)
} else {
Ok(self.funcv(std::iter::once(&rename_to).chain(vs.iter()))?)
}
}
None => match fixity {
Fixity::Fun => Ok(term),
_ => bail!("missing opdef for fixity {:?}", fixity),
},
}
}
pub fn define_default_opers(&mut self) -> Result<(), TermError> {
let term = list![
func!(
"op";
func!("-"; atom!("x")),
atom!("prefix"),
800,
atom!("right"),
atom!("none"),
atom!("false"),
),
func!(
"op";
func!("++"; atom!("x"), atom!("y")),
atom!("infix"),
500,
atom!("left"),
atom!("none"),
atom!("false"),
),
func!(
"op";
func!("="; atom!("x"), atom!("y")),
atom!("infix"),
100,
atom!("right"),
atom!("none"),
atom!("false"),
),
func!(
"op";
func!(
"op";
atom!("f"),
func!("="; atom!("type"), atom!("fun")),
func!("="; atom!("prec"), 0),
func!("="; atom!("assoc"), atom!("none")),
func!("="; atom!("rename_to"), atom!("none")),
func!("="; atom!("embed_type"), atom!("false")),
),
atom!("fun"),
0,
atom!("none"),
atom!("none"),
atom!("false"),
),
=> self
];
self.define_opers(term)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// (keyword, variant, table index) for every fixity.
    const FIXITY_CASES: [(&str, Fixity, usize); 4] = [
        ("fun", Fixity::Fun, 0),
        ("prefix", Fixity::Prefix, 1),
        ("infix", Fixity::Infix, 2),
        ("postfix", Fixity::Postfix, 3),
    ];

    /// (keyword, variant, index) for every associativity.
    const ASSOC_CASES: [(&str, Assoc, usize); 3] = [
        ("none", Assoc::None, 0),
        ("left", Assoc::Left, 1),
        ("right", Assoc::Right, 2),
    ];

    #[test]
    fn fixity_from_str_valid() {
        for (keyword, variant, _) in FIXITY_CASES {
            assert_eq!(keyword.parse::<Fixity>().unwrap(), variant);
        }
    }

    #[test]
    fn fixity_from_str_invalid() {
        let message = "pre_fix".parse::<Fixity>().unwrap_err().to_string();
        assert_eq!(message, "invalid fixity: pre_fix");
    }

    #[test]
    fn fixity_display_and_string_from() {
        for (keyword, variant, _) in FIXITY_CASES {
            assert_eq!(variant.to_string(), keyword);
        }
        let converted: smartstring::alias::String = Fixity::Infix.into();
        assert_eq!(converted.as_str(), "infix");
    }

    #[test]
    fn fixity_into_usize_indices() {
        for (_, variant, index) in FIXITY_CASES {
            assert_eq!(usize::from(variant), index);
        }
        assert_eq!(Fixity::STRS.len(), Fixity::COUNT);
    }

    #[test]
    fn assoc_from_str_valid() {
        for (keyword, variant, _) in ASSOC_CASES {
            assert_eq!(keyword.parse::<Assoc>().unwrap(), variant);
        }
    }

    #[test]
    fn assoc_from_str_invalid() {
        let message = "center".parse::<Assoc>().unwrap_err().to_string();
        assert_eq!(message, "invalid associativity: center");
    }

    #[test]
    fn assoc_display_and_string_from() {
        for (keyword, variant, _) in ASSOC_CASES {
            assert_eq!(variant.to_string(), keyword);
        }
        let converted: smartstring::alias::String = Assoc::Right.into();
        assert_eq!(converted.as_str(), "right");
    }

    #[test]
    fn assoc_into_usize_indices() {
        for (_, variant, index) in ASSOC_CASES {
            assert_eq!(usize::from(variant), index);
        }
        assert_eq!(Assoc::STRS.len(), Assoc::COUNT);
    }

    #[test]
    fn required_arity_matches_fixity() {
        let expected: [(Fixity, usize); 4] = [
            (Fixity::Fun, 0),
            (Fixity::Prefix, 1),
            (Fixity::Infix, 2),
            (Fixity::Postfix, 1),
        ];
        for (fixity, arity) in expected {
            assert_eq!(OperDef::required_arity(fixity), arity);
        }
    }

    /// Builds an argument-less definition with no renaming or fixity embedding.
    fn minimal_def(fixity: Fixity, prec: i64, assoc: Assoc) -> OperDef {
        OperDef {
            fixity,
            prec,
            assoc,
            args: Vec::new(),
            rename_to: None,
            embed_fixity: false,
        }
    }

    #[test]
    fn oper_def_tab_new_is_empty() {
        let tab = OperDefTab::new();
        assert!(!tab.is_fun());
        assert!(!tab.is_oper());
        for (_, fixity, _) in FIXITY_CASES {
            assert!(tab.get_op_def(fixity).is_none());
        }
    }

    #[test]
    fn oper_def_tab_flags_update_correctly() {
        let mut tab = OperDefTab::new();

        tab[Fixity::Fun] = Some(minimal_def(Fixity::Fun, 0, Assoc::None));
        assert!(tab.is_fun());
        assert!(!tab.is_oper());

        tab[Fixity::Infix] = Some(minimal_def(Fixity::Infix, 500, Assoc::Left));
        assert!(tab.is_oper());

        let infix = tab
            .get_op_def(Fixity::Infix)
            .expect("infix slot was just filled");
        assert_eq!(infix.fixity, Fixity::Infix);
        assert_eq!(infix.prec, 500);
        assert_eq!(infix.assoc, Assoc::Left);
    }

    #[test]
    fn oper_defs_empty_behavior() {
        let arena = Arena::new();
        assert_eq!(arena.opers_len(), 0);
        assert_eq!(arena.lookup_oper("nope"), None);
        // Both "no index" and a dangling index fall back to the empty table.
        for empty in [arena.get_oper(None), arena.get_oper(Some(0))] {
            assert!(!empty.is_fun());
            assert!(!empty.is_oper());
        }
    }

    #[test]
    fn oper_defs_with_one_entry() {
        let operand = |name: &str| OperArg {
            name: name.into(),
            default: None,
        };
        let mut tab = OperDefTab::new();
        tab[Fixity::Infix] = Some(OperDef {
            fixity: Fixity::Infix,
            prec: 500,
            assoc: Assoc::Left,
            args: vec![operand("lhs"), operand("rhs")],
            rename_to: None,
            embed_fixity: false,
        });

        let mut arena = Arena::new();
        arena.opers.map.insert("+".into(), tab);
        assert_eq!(arena.opers_len(), 1);

        let index = arena.lookup_oper("+").unwrap();
        let stored = arena.get_oper(Some(index));
        assert!(stored.is_oper());
        assert!(!stored.is_fun());

        let infix = stored.get_op_def(Fixity::Infix).unwrap();
        assert_eq!(infix.fixity, Fixity::Infix);
        assert_eq!(infix.prec, 500);
        assert_eq!(infix.assoc, Assoc::Left);
        let names: Vec<&str> = infix.args.iter().map(|arg| arg.name.as_str()).collect();
        assert_eq!(names, ["lhs", "rhs"]);
    }
}