#![cfg_attr(not(test), allow(dead_code))]
use crate::{
FrameType,
spec::{BasisDef, BasisIndex},
};
use core::{cmp::Ordering, str::FromStr};
use itertools::Either;
use num_traits::ConstOne;
use proc_macro2::Span;
use std::{collections::HashSet, iter::FusedIterator};
/// A named basis alias paired with a span.
///
/// The comparison impls in this file (`PartialEq`, `Ord`, ...) look only at
/// `name`; `span` is carried along so that errors about the alias (see
/// `validate_aliases`) can be reported at that span.
#[derive(Debug, Clone)]
pub struct Alias {
/// The alias text, e.g. `e0` or `sigma`.
pub name: String,
/// Span attached to diagnostics that mention this alias.
pub span: Span,
}
// Equality and ordering for `Alias` consider the `name` only; the `span` is
// deliberately left out, so two aliases with the same text compare equal no
// matter which span they carry.
impl Ord for Alias {
    /// Orders aliases lexicographically by `name`.
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(self.name.as_str(), other.name.as_str())
    }
}

impl PartialOrd for Alias {
    /// Total order: always `Some`, delegating to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl PartialEq for Alias {
    /// Two aliases are equal when their names are equal.
    fn eq(&self, other: &Self) -> bool {
        self.name.as_str() == other.name.as_str()
    }
}

impl Eq for Alias {}
impl Alias {
pub fn new(name: impl Into<String>, span: Span) -> Self {
Self {
name: name.into(),
span,
}
}
}
impl From<&syn::Ident> for Alias {
fn from(value: &syn::Ident) -> Self {
Self::new(value.to_string(), value.span())
}
}
impl FromStr for Alias {
type Err = syn::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if is_valid_alias_name(s) {
Ok(Alias::new(s, Span::call_site()))
} else {
Err(syn::Error::new(
Span::call_site(),
format!(
"basis alias `{}` must match the pattern `[a-z]+[0-9A-Z]?`",
s
),
))
}
}
}
/// Checks a sequence of aliases for invalid names, duplicates and prefix
/// conflicts.
///
/// Every alias name must satisfy `is_valid_alias_name`. Each alias is then
/// compared with the one immediately before it in iteration order: equal
/// names are reported as duplicates, and an alias whose name starts with the
/// previous alias's name is reported as a prefix conflict.
///
/// Only neighbouring elements are compared, so conflicts between aliases that
/// are not adjacent in `aliases` are not detected. NOTE(review): callers
/// presumably pass the aliases in sorted order (e.g. from a `BTreeSet`), which
/// makes the adjacent check sufficient — confirm at the call sites.
///
/// # Errors
///
/// Returns the first problem found as a `syn::Error` located at the offending
/// alias's span.
pub fn validate_aliases<'a, I>(aliases: I) -> syn::Result<()>
where
    I: IntoIterator<Item = &'a Alias>,
{
    let mut previous: Option<&Alias> = None;
    for alias in aliases {
        if !is_valid_alias_name(&alias.name) {
            return Err(syn::Error::new(
                alias.span,
                format!(
                    "basis alias `{}` must match the pattern `[a-z]+[0-9A-Z]?`",
                    alias.name
                ),
            ));
        }
        if let Some(prev) = previous {
            if prev == alias {
                return Err(syn::Error::new(
                    alias.span,
                    format!("basis alias `{}` declared multiple times", alias.name),
                ));
            }
            if alias.name.starts_with(prev.name.as_str()) {
                // `prev` is the shorter name here, so it is the prefix of
                // `alias` (the arguments used to be passed the other way
                // round, producing a misleading message).
                return Err(syn::Error::new(
                    alias.span,
                    format!(
                        "basis alias `{}` conflicts with alias `{}`, as `{}` is a prefix of `{}`",
                        alias.name, prev.name, prev.name, alias.name
                    ),
                ));
            }
        }
        previous = Some(alias);
    }
    Ok(())
}
/// Returns `true` when `name` matches the pattern `[a-z]+[0-9A-Z]?`.
///
/// That is: one or more ASCII lowercase letters, optionally followed by a
/// single ASCII digit or ASCII uppercase letter. The empty string and any
/// non-ASCII input are rejected.
pub fn is_valid_alias_name(name: &str) -> bool {
    // Peel off the optional one-character postfix, then require the rest to
    // be a non-empty run of lowercase letters.
    let stem = name
        .strip_suffix(|ch: char| ch.is_ascii_digit() || ch.is_ascii_uppercase())
        .unwrap_or(name);
    !stem.is_empty() && stem.chars().all(|ch| ch.is_ascii_lowercase())
}
/// Splits a compound alias such as `e12` or `t0x012` into the basis aliases
/// it is made of.
///
/// `basis_aliases` is the set of known aliases, each expected to look like
/// `[a-z]+[0-9A-Z]?`. `name` is read left to right:
///
/// * a run of lowercase letters selects a family prefix (the shortest leading
///   run that is the prefix of some known alias);
/// * each following digit or uppercase letter is combined with the current
///   prefix, so `e12` yields `e1`, `e2`;
/// * a prefix with no postfix character after it stands for the bare alias,
///   so `txyz` yields `t`, `x`, `y`, `z`.
///
/// Returns the components in order of appearance, or `None` when `name` is
/// empty or any part of it does not resolve to a known alias. Input that
/// cannot be resolved (unknown prefixes, dangling prefixes, non-ASCII text)
/// is rejected instead of looping or panicking.
pub fn expand_compound_alias(name: &str, basis_aliases: &HashSet<String>) -> Option<Vec<String>> {
    if name.is_empty() {
        return None;
    }
    let aliases: HashSet<&str> = basis_aliases.iter().map(String::as_str).collect();
    // The family prefix of an alias is the alias minus its optional
    // one-character postfix.
    let prefixes: HashSet<&str> = basis_aliases
        .iter()
        .map(|alias| {
            let alias = alias.as_str();
            alias.strip_suffix(is_alias_suffix).unwrap_or(alias)
        })
        .collect();
    let mut expander = AliasExpander {
        aliases: &aliases,
        prefixes: &prefixes,
        prefix: "",
        tail: name,
        is_valid: true,
    };
    let components: Vec<String> = expander.by_ref().collect();
    expander.is_valid.then_some(components)
}

/// Returns `true` for the characters allowed as an alias postfix: an ASCII
/// digit or an ASCII uppercase letter.
fn is_alias_suffix(ch: char) -> bool {
    ch.is_ascii_digit() || ch.is_ascii_uppercase()
}

/// Cursor over a compound alias that yields one component alias per step.
///
/// After iteration ends, `is_valid` tells whether the whole input was
/// consumed successfully (`true`) or an unresolvable segment was hit
/// (`false`).
#[derive(Debug)]
struct AliasExpander<'a> {
    /// All known aliases.
    aliases: &'a HashSet<&'a str>,
    /// Family prefixes of the known aliases (alias minus optional postfix).
    prefixes: &'a HashSet<&'a str>,
    /// Most recently matched family prefix; reused by following postfixes.
    prefix: &'a str,
    /// Input that has not been consumed yet.
    tail: &'a str,
    /// Cleared as soon as an unresolvable segment is encountered.
    is_valid: bool,
}

impl AliasExpander<'_> {
    /// Marks the expansion as failed and ends the iteration.
    fn reject(&mut self) -> Option<String> {
        self.is_valid = false;
        None
    }
}

impl Iterator for AliasExpander<'_> {
    type Item = String;

    fn next(&mut self) -> Option<Self::Item> {
        if !self.is_valid {
            return None;
        }
        let tail = self.tail;
        // An exhausted tail means the whole input was consumed successfully.
        let alias = match tail.chars().next()? {
            head if head.is_ascii_lowercase() => {
                // Start of a new family: take the shortest known prefix from
                // the leading run of ASCII lowercase letters. Working on the
                // ASCII run keeps every slice on a char boundary.
                let prefixes = self.prefixes;
                let run = tail.bytes().take_while(u8::is_ascii_lowercase).count();
                let Some(len) = (1..=run).find(|&len| prefixes.contains(&tail[..len])) else {
                    return self.reject();
                };
                let (prefix, rest) = tail.split_at(len);
                self.prefix = prefix;
                match rest.chars().next() {
                    Some(next) if is_alias_suffix(next) => {
                        // Postfix characters are ASCII, i.e. exactly one byte.
                        let (suffix, rest) = rest.split_at(1);
                        self.tail = rest;
                        format!("{prefix}{suffix}")
                    }
                    // Anything other than another family or end of input.
                    Some(next) if !next.is_ascii_lowercase() => return self.reject(),
                    _ => {
                        self.tail = rest;
                        prefix.to_string()
                    }
                }
            }
            // Another postfix for the family selected earlier (`e12` -> `e2`).
            head if is_alias_suffix(head) && !self.prefix.is_empty() => {
                let (suffix, rest) = tail.split_at(1);
                self.tail = rest;
                format!("{}{}", self.prefix, suffix)
            }
            _ => return self.reject(),
        };
        if self.aliases.contains(alias.as_str()) {
            Some(alias)
        } else {
            self.reject()
        }
    }
}

// Once `next` returns `None` the tail is empty or `is_valid` is false, and
// both conditions keep producing `None`.
impl FusedIterator for AliasExpander<'_> {}
/// One term of a basis definition.
///
/// `Scalar` holds `(coefficient, span)`. The other variants hold
/// `(coefficient, repr, index, span)`, where `repr` is the identifier text
/// (e.g. `P0`) and `index` is the number parsed from its digits.
///
/// The `TryFrom<&BasisIndex>` impl in this file picks the variant from the
/// identifier's first letter and the term's sign: `P` gives `HypePos` /
/// `HypeNeg`, `N` gives `ImagPos` / `ImagNeg`, with the `*Neg` variants used
/// for negatively signed terms.
#[derive(Debug, Clone)]
pub enum BasisSlot {
/// A term without an identifier: just a (signed) coefficient.
Scalar(f64, Span),
/// `P`-family slot with a non-negative sign.
HypePos(f64, String, usize, Span),
/// `P`-family slot with a negative sign.
HypeNeg(f64, String, usize, Span),
/// `N`-family slot with a non-negative sign.
ImagPos(f64, String, usize, Span),
/// `N`-family slot with a negative sign.
ImagNeg(f64, String, usize, Span),
}
impl TryFrom<&BasisIndex> for BasisSlot {
type Error = syn::Error;
fn try_from(value: &BasisIndex) -> syn::Result<Self> {
let is_negative = matches!(value.sign, Either::Right(_));
let scalar_coeff = if let Some(ref lit) = value.scalar {
lit.base10_parse::<f64>().map_err(|_| {
syn::Error::new(lit.span(), "scalar coefficient must be a valid float")
})?
} else {
1.0
};
let Some(ref ident) = value.ident else {
let span = value
.scalar
.as_ref()
.map(|s| s.span())
.unwrap_or_else(|| proc_macro2::Span::call_site());
let coeff = if is_negative {
-scalar_coeff
} else {
scalar_coeff
};
return Ok(BasisSlot::Scalar(coeff, span));
};
let repr = ident.to_string();
let mut chars = repr.chars();
let Some(prefix) = chars.next() else {
return Err(syn::Error::new(
ident.span(),
"basis index must include a family prefix and numeric slot",
));
};
let digits: String = chars.collect();
if digits.is_empty() {
return Err(syn::Error::new(
ident.span(),
"basis index must end with digits",
));
}
let index = digits
.parse()
.map_err(|_| syn::Error::new(ident.span(), "basis index suffix must be an integer"))?;
let span = ident.span();
let coeff = if is_negative {
-scalar_coeff
} else {
scalar_coeff
};
let slot = match prefix {
'P' => {
if is_negative {
BasisSlot::HypeNeg(coeff, repr, index, span)
} else {
BasisSlot::HypePos(coeff, repr, index, span)
}
}
'N' => {
if is_negative {
BasisSlot::ImagNeg(coeff, repr, index, span)
} else {
BasisSlot::ImagPos(coeff, repr, index, span)
}
}
_ => {
return Err(syn::Error::new(
ident.span(),
"basis indices must start with `P` or `N`",
));
}
};
Ok(slot)
}
}
#[allow(dead_code)]
impl BasisSlot {
pub fn repr(&self) -> &str {
match self {
BasisSlot::Scalar(_, _) => "1",
BasisSlot::HypePos(_, repr, _, _)
| BasisSlot::HypeNeg(_, repr, _, _)
| BasisSlot::ImagPos(_, repr, _, _)
| BasisSlot::ImagNeg(_, repr, _, _) => repr,
}
}
pub fn index(&self) -> usize {
match self {
BasisSlot::Scalar(_, _) => usize::MAX,
BasisSlot::HypePos(_, _, idx, _)
| BasisSlot::HypeNeg(_, _, idx, _)
| BasisSlot::ImagPos(_, _, idx, _)
| BasisSlot::ImagNeg(_, _, idx, _) => *idx,
}
}
pub fn span(&self) -> Span {
match self {
BasisSlot::Scalar(_, span)
| BasisSlot::HypePos(_, _, _, span)
| BasisSlot::HypeNeg(_, _, _, span)
| BasisSlot::ImagPos(_, _, _, span)
| BasisSlot::ImagNeg(_, _, _, span) => *span,
}
}
fn is_hyperbolic(&self) -> bool {
matches!(self, BasisSlot::HypePos(..) | BasisSlot::HypeNeg(..))
}
fn is_imaginary(&self) -> bool {
matches!(self, BasisSlot::ImagPos(..) | BasisSlot::ImagNeg(..))
}
pub fn is_scalar(&self) -> bool {
matches!(self, BasisSlot::Scalar(..))
}
pub fn is_positive(&self) -> bool {
match self {
BasisSlot::Scalar(v, _) => *v >= 0.0,
BasisSlot::HypePos(v, ..) | BasisSlot::ImagPos(v, ..) => *v >= 0.0,
BasisSlot::HypeNeg(..) | BasisSlot::ImagNeg(..) => false,
}
}
pub fn is_negative(&self) -> bool {
!self.is_positive()
}
pub fn coeff(&self) -> f64 {
match self {
BasisSlot::Scalar(v, _)
| BasisSlot::HypePos(v, ..)
| BasisSlot::HypeNeg(v, ..)
| BasisSlot::ImagPos(v, ..)
| BasisSlot::ImagNeg(v, ..) => *v,
}
}
pub fn to_mask(&self, positive: usize, negative: usize) -> syn::Result<FrameType> {
if self.is_scalar() {
return Ok(0);
}
let total = positive + negative;
if total > FrameType::BITS as usize {
return Err(syn::Error::new(
self.span(),
format!(
"frame uses {total} slots but FrameType `{}` supports at most {}",
std::any::type_name::<FrameType>(),
FrameType::BITS
),
));
}
let idx = self.index();
if self.is_hyperbolic() {
if idx >= positive {
return Err(syn::Error::new(
self.span(),
format!(
"positive slot `{}` exceeds positive count ({positive})",
self.repr()
),
));
}
Ok(FrameType::ONE << idx)
} else {
if idx >= negative {
return Err(syn::Error::new(
self.span(),
format!(
"negative slot `{}` exceeds negative count ({negative})",
self.repr()
),
));
}
Ok(FrameType::ONE << (positive + idx))
}
}
}
/// An alias together with the basis terms it is defined as.
///
/// Built from a `BasisDef` by the `TryFrom<&BasisDef>` impl in this file.
#[derive(Debug, Clone)]
pub struct BasisAlias {
/// Name (and span) of the alias being defined.
pub alias: Alias,
/// The converted terms of the definition, in source order.
pub terms: Vec<BasisSlot>,
}
impl TryFrom<&BasisDef> for BasisAlias {
type Error = syn::Error;
fn try_from(def: &BasisDef) -> Result<Self, Self::Error> {
let alias = Alias::new(def.name.to_string(), def.name.span());
let mut terms = Vec::with_capacity(def.indices.len());
for index in def.indices.iter() {
terms.push(index.try_into()?);
}
Ok(Self { alias, terms })
}
}
// Unit tests for slot parsing/masking, alias validation and compound alias
// expansion.
#[cfg(test)]
mod tests {
use super::*;
use num_traits::ConstOne;
use proc_macro2::Span;
use std::collections::HashSet;
// `P2` in a frame with 3 positive slots maps to bit 2.
#[test]
fn basis_slot_to_mask_positive_family() {
let bi: BasisIndex = syn::parse_str("P2").expect("parse basis index");
let slot = BasisSlot::try_from(&bi).expect("parse slot");
let mask = slot.to_mask(3, 2).expect("mask");
assert_eq!(mask, FrameType::ONE << 2);
}
// `N1` is offset by the 3 positive slots, so it maps to bit 3 + 1 = 4.
#[test]
fn basis_slot_to_mask_negative_family() {
let bi: BasisIndex = syn::parse_str("N1").expect("parse basis index");
let slot = BasisSlot::try_from(&bi).expect("parse slot");
let mask = slot.to_mask(3, 2).expect("mask");
assert_eq!(mask, FrameType::ONE << 4);
}
// Index 3 is out of range when only 3 positive slots (0..=2) exist.
#[test]
fn basis_slot_to_mask_rejects_out_of_range() {
let bi: BasisIndex = syn::parse_str("P3").expect("parse basis index");
let slot = BasisSlot::try_from(&bi).expect("parse slot");
let err = slot.to_mask(3, 0).expect_err("expected failure");
assert!(err.to_string().contains("positive slot"));
}
// A definition with two terms yields two slots with the expected signs and
// families; alias equality ignores the span.
#[test]
fn basis_alias_collects_terms() {
let def: BasisDef = syn::parse_str("e0 = P0 - N1").expect("parse basis def");
let alias = BasisAlias::try_from(&def).expect("alias");
assert_eq!(alias.terms.len(), 2);
assert_eq!(alias.alias, Alias::new("e0", Span::call_site()));
assert!(alias.terms[0].is_positive());
assert!(alias.terms[1].is_negative());
assert!(matches!(alias.terms[0], BasisSlot::HypePos(..)));
assert!(matches!(alias.terms[1], BasisSlot::ImagNeg(..)));
}
// Helper: builds an `Alias` with a call-site span.
fn alias(name: &str) -> Alias {
Alias {
name: name.to_string(),
span: Span::call_site(),
}
}
// Well-formed, non-conflicting names pass validation.
#[test]
fn accepts_valid_aliases() {
let aliases = vec![
alias("e"),
alias("sigma0"),
alias("bladeA"),
alias("uvw"),
alias("gamma9"),
];
assert!(validate_aliases(&aliases).is_ok());
}
// Names must start with a lowercase letter.
#[test]
fn rejects_invalid_pattern() {
validate_aliases(&vec![alias("E0")]).expect_err("uppercase leading character should fail");
}
// The same name twice in a row is a duplicate.
#[test]
fn rejects_duplicate_aliases() {
validate_aliases(&vec![alias("foo"), alias("foo")]).expect_err("duplicates must fail");
}
// `foo` is a prefix of `foo0`, which is reported as a conflict.
#[test]
fn rejects_prefix_conflicts() {
validate_aliases(&vec![alias("foo"), alias("foo0")])
.expect_err("prefix conflict must fail");
}
// Direct checks of the `[a-z]+[0-9A-Z]?` name pattern.
#[test]
fn validates_alias_name_helper() {
for name in ["e", "sigma", "sigma0", "bladeA"] {
assert!(is_valid_alias_name(name), "{name} should be accepted");
}
for name in ["", "E", "foo_bar", "dualRotor", "abc1x", "9foo"] {
assert!(!is_valid_alias_name(name), "{name} should be rejected");
}
}
// Helper: builds the set of known alias names.
fn basis(names: &[&str]) -> HashSet<String> {
names.iter().map(|name| name.to_string()).collect()
}
// One prefix followed by several postfix digits.
#[test]
fn expands_simple_compound_alias() {
let aliases = basis(&["e0", "e1", "e2", "e3"]);
let expanded = expand_compound_alias("e12", &aliases).expect("compound alias");
assert_eq!(expanded, vec!["e1".to_string(), "e2".to_string()]);
}
// Components come out in the order they appear, not sorted.
#[test]
fn expands_descending_digits() {
let aliases = basis(&["e0", "e1", "e2", "e3"]);
let expanded = expand_compound_alias("e321", &aliases).expect("compound alias");
assert_eq!(expanded, vec!["e3", "e2", "e1"]);
}
// Several families can be mixed in one compound alias.
#[test]
fn expands_multi_prefix_alias() {
let aliases = basis(&["t0", "x0", "x1", "x2"]);
let expanded = expand_compound_alias("t0x012", &aliases).expect("compound alias");
assert_eq!(expanded, vec!["t0", "x0", "x1", "x2"]);
}
// Aliases without a postfix are concatenated directly.
#[test]
fn expands_prefix_without_suffix() {
let aliases = basis(&["t", "x", "y", "z"]);
let expanded = expand_compound_alias("txyz", &aliases).expect("compound alias");
assert_eq!(expanded, vec!["t", "x", "y", "z"]);
}
// Coefficient and sign handling for scaled, bare and scalar-only terms.
#[test]
fn test_basis_slot_scalar_coefficient() {
let bi: BasisIndex = syn::parse_str("+0.5 * P0").expect("parse basis index");
let slot = BasisSlot::try_from(&bi).unwrap();
assert_eq!(slot.coeff(), 0.5);
assert_eq!(slot.repr(), "P0");
assert!(slot.is_positive());
let bi: BasisIndex = syn::parse_str("-0.5 * N1").expect("parse basis index");
let slot = BasisSlot::try_from(&bi).unwrap();
assert_eq!(slot.coeff(), -0.5);
assert_eq!(slot.repr(), "N1");
assert!(slot.is_negative());
let bi: BasisIndex = syn::parse_str("0.5").expect("parse basis index");
let slot = BasisSlot::try_from(&bi).unwrap();
assert_eq!(slot.coeff(), 0.5);
assert!(slot.is_scalar());
let bi: BasisIndex = syn::parse_str("-0.5").expect("parse basis index");
let slot = BasisSlot::try_from(&bi).unwrap();
assert_eq!(slot.coeff(), -0.5);
assert!(slot.is_scalar());
let bi: BasisIndex = syn::parse_str("P0").expect("parse basis index");
let slot = BasisSlot::try_from(&bi).unwrap();
assert_eq!(slot.coeff(), 1.0);
assert_eq!(slot.repr(), "P0");
let bi: BasisIndex = syn::parse_str("-P0").expect("parse basis index");
let slot = BasisSlot::try_from(&bi).unwrap();
assert_eq!(slot.coeff(), -1.0);
assert_eq!(slot.repr(), "P0");
}
// A scalar term occupies no slot, so its mask is zero.
#[test]
fn test_basis_slot_to_mask_scalar() {
let bi: BasisIndex = syn::parse_str("0.5").expect("parse basis index");
let slot = BasisSlot::try_from(&bi).unwrap();
let mask = slot.to_mask(2, 2).unwrap();
assert_eq!(mask, 0);
}
}