use crate::parser::{HExprParser, Rule};
use pest::Parser;
/// An expression tree built from operations, Frobenius spiders and the two
/// combinators `Composition` and `Tensor`.
///
/// Text round-trips through [`std::str::FromStr`] (delegating to the crate
/// parser) and [`std::fmt::Display`]. All payloads are `Eq + Hash`, so
/// expressions can be compared for total equality and used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Hexpr {
    /// A sequence of sub-expressions; displayed as `(e1 e2 ...)`.
    Composition(Vec<Hexpr>),
    /// A group of sub-expressions; displayed as `{e1 e2 ...}`.
    Tensor(Vec<Hexpr>),
    /// Variables on either side of the `.`; displayed as `[s1 s2 . t1 t2]`,
    /// or `[]` when both lists are empty.
    Frobenius {
        sources: Vec<Variable>,
        targets: Vec<Variable>,
    },
    /// A single named operation; displayed as its name.
    Operation(Operation),
}
/// The name of an operation, as held by [`Hexpr::Operation`].
///
/// The inner string is `pub(crate)`, so code outside this crate obtains an
/// `Operation` by parsing text (see the `FromStr` impl).
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Operation(pub(crate) String);

impl Operation {
    /// Borrows the operation's name as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// A variable name, as listed in the `sources`/`targets` of
/// [`Hexpr::Frobenius`].
///
/// The inner string is `pub(crate)`, so code outside this crate obtains a
/// `Variable` by parsing text (see the `FromStr` impl). Derives the same
/// traits as [`Operation`], including a lexicographic `Ord`, so variables can
/// be sorted or used as `BTreeMap`/`BTreeSet` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Variable(pub(crate) String);
impl std::str::FromStr for Variable {
type Err = Box<pest::error::Error<Rule>>;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let pairs = HExprParser::parse(Rule::variable, s)?;
let variable_pair = pairs.into_iter().next().unwrap();
Ok(Variable(variable_pair.as_str().to_string()))
}
}
impl std::fmt::Display for Variable {
    /// Writes the variable's name exactly as stored. Formatter options such as
    /// width or alignment are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
impl std::str::FromStr for Operation {
type Err = Box<pest::error::Error<Rule>>;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let pairs = HExprParser::parse(Rule::operation, s)?;
let operation_pair = pairs.into_iter().next().unwrap();
Ok(Operation(operation_pair.as_str().to_string()))
}
}
impl std::fmt::Display for Operation {
    /// Emits the operation name verbatim; width/alignment options on the
    /// formatter are not honoured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name: &str = self.0.as_str();
        f.write_str(name)
    }
}
impl std::str::FromStr for Hexpr {
type Err = crate::parser::ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
crate::parser::HExprParser::parse_hexpr(s)
}
}
impl std::fmt::Display for Hexpr {
    /// Renders the expression in bracketed text form:
    /// `(a b)` for compositions, `{a b}` for tensors, `[x y . z]` for
    /// Frobenius nodes (`[]` when both sides are empty), and the bare name for
    /// operations. Items inside brackets are separated by single spaces.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Writes `items` separated by single spaces, with no leading or
        /// trailing space; writes nothing for an empty slice.
        fn write_spaced<T: std::fmt::Display>(
            f: &mut std::fmt::Formatter<'_>,
            items: &[T],
        ) -> std::fmt::Result {
            let mut rest = items.iter();
            if let Some(head) = rest.next() {
                write!(f, "{}", head)?;
                for item in rest {
                    write!(f, " {}", item)?;
                }
            }
            Ok(())
        }

        match self {
            Hexpr::Composition(children) => {
                f.write_str("(")?;
                write_spaced(f, children)?;
                f.write_str(")")
            }
            Hexpr::Tensor(children) => {
                f.write_str("{")?;
                write_spaced(f, children)?;
                f.write_str("}")
            }
            Hexpr::Frobenius { sources, targets } if sources.is_empty() && targets.is_empty() => {
                f.write_str("[]")
            }
            Hexpr::Frobenius { sources, targets } => {
                f.write_str("[")?;
                write_spaced(f, sources)?;
                // The separator is always emitted once either side is
                // non-empty, yielding e.g. `[ . z]` or `[x . ]`.
                f.write_str(" . ")?;
                write_spaced(f, targets)?;
                f.write_str("]")
            }
            Hexpr::Operation(op) => write!(f, "{}", op),
        }
    }
}