use smallvec::{SmallVec, smallvec};
use typst_syntax::Spanned;
use typst_utils::{Numeric, default_math_class};
use unicode_math_class::MathClass;
use crate::diag::{At, HintedStrResult, StrResult, bail};
use crate::foundations::{
Array, Content, Dict, Fold, NoneValue, Resolve, Smart, StyleChain, Symbol, Value,
array, cast, dict, elem,
};
use crate::layout::{Abs, Em, HAlignment, Length, Rel};
use crate::math::Mathy;
use crate::visualize::Stroke;
// Default gap of 0.2 em. Used as the default for `VecElem::gap`,
// `MatElem::row_gap` and `CasesElem::gap`.
const DEFAULT_ROW_GAP: Em = Em::new(0.2);
// Default gap of 0.5 em. Used as the default for `MatElem::column_gap`.
const DEFAULT_COL_GAP: Em = Em::new(0.5);
/// A math element titled "Vector": a list of child contents together with a
/// delimiter pair, an alignment and a gap.
#[elem(title = "Vector", Mathy)]
pub struct VecElem {
/// The delimiter pair to use. Defaults to parentheses.
#[default(DelimiterPair::PAREN)]
pub delim: DelimiterPair,
/// The horizontal alignment setting. Defaults to centered.
#[default(HAlignment::Center)]
pub align: HAlignment,
/// The gap setting. Defaults to 0.2 em.
// NOTE(review): presumably the spacing between consecutive children; the
// layout code is not in this file, confirm there.
#[default(DEFAULT_ROW_GAP.into())]
pub gap: Rel<Length>,
/// The elements of the vector (variadic).
#[variadic]
pub children: Vec<Content>,
}
/// A math element titled "Matrix": rows of cell contents together with a
/// delimiter pair, an alignment, an optional augmentation and separate row
/// and column gaps.
#[elem(title = "Matrix", Mathy)]
pub struct MatElem {
/// The delimiter pair to use. Defaults to parentheses.
#[default(DelimiterPair::PAREN)]
pub delim: DelimiterPair,
/// The horizontal alignment setting. Defaults to centered.
#[default(HAlignment::Center)]
pub align: HAlignment,
/// Optional augmentation: horizontal and vertical line offsets plus a
/// stroke.
// Marked `#[fold]`; see the `Fold` implementation for `Augment` for how the
// stroke of an inner and an outer value are combined.
#[fold]
pub augment: Option<Augment>,
/// Shorthand that supplies both `row-gap` and `column-gap` when those are
/// not given explicitly.
// Marked `#[external]`; the `gap` argument itself is read by the `row_gap`
// parser below and reused by the `column_gap` parser.
#[external]
pub gap: Rel<Length>,
/// The row gap. Taken from `row-gap`, falling back to `gap`, and finally to
/// the default of 0.2 em.
#[parse(
// Bind `gap` here so that the `column_gap` parser can fall back to it too.
let gap = args.named("gap")?;
args.named("row-gap")?.or(gap)
)]
#[default(DEFAULT_ROW_GAP.into())]
pub row_gap: Rel<Length>,
/// The column gap. Taken from `column-gap`, falling back to `gap`, and
/// finally to the default of 0.5 em.
// Relies on the `gap` binding introduced by the `row_gap` parser above.
#[parse(args.named("column-gap")?.or(gap))]
#[default(DEFAULT_COL_GAP.into())]
pub column_gap: Rel<Length>,
/// The rows of the matrix, each a list of cell contents. All rows have the
/// same length after parsing.
#[variadic]
#[parse(
let mut rows = vec![];
let mut width = 0;
let values = args.all::<Spanned<Value>>()?;
// If at least one argument is an array, every argument is cast to an array
// (a failed cast is reported at that argument's span) and becomes one row.
if values.iter().any(|spanned| matches!(spanned.v, Value::Array(_))) {
for Spanned { v, span } in values {
let array = v.cast::<Array>().at(span)?;
let row: Vec<_> = array.into_iter().map(Value::display).collect();
width = width.max(row.len());
rows.push(row);
}
} else {
// No array among the arguments: all of them together form a single row.
rows = vec![values.into_iter().map(|spanned| spanned.v.display()).collect()];
}
// Pad shorter rows with empty content up to the widest row. `width` is only
// updated in the array branch, so a single flat row is never padded.
for row in &mut rows {
if row.len() < width {
row.resize(width, Content::empty());
}
}
rows
)]
pub rows: Vec<Vec<Content>>,
}
/// A math element holding a list of child contents together with a delimiter
/// pair, a reverse flag and a gap.
#[elem(Mathy)]
pub struct CasesElem {
/// The delimiter pair to use. Defaults to curly braces.
#[default(DelimiterPair::BRACE)]
pub delim: DelimiterPair,
/// The reverse flag. Defaults to `false`.
// NOTE(review): the effect of this flag is implemented outside this file
// (presumably it moves the delimiter to the other side); confirm there.
#[default(false)]
pub reverse: bool,
/// The gap setting. Defaults to 0.2 em.
#[default(DEFAULT_ROW_GAP.into())]
pub gap: Rel<Length>,
/// The child contents (variadic).
#[variadic]
pub children: Vec<Content>,
}
/// A single delimiter: either no delimiter (`None`) or one character.
///
/// `Delimiter::char` only accepts characters whose default math class is
/// opening, closing or fence. The `Default` value is the empty delimiter.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub struct Delimiter(Option<char>);
// Conversions between `Delimiter` and values.
cast! {
    Delimiter,
    // To a value: the wrapped `Option<char>` is converted as-is.
    self => self.0.into_value(),
    // `none` yields the empty delimiter.
    _: NoneValue => Self::none(),
    // A symbol has to parse as exactly one `char`, which is then validated.
    v: Symbol => {
        let c = v
            .get()
            .parse::<char>()
            .map_err(|_| "expected a single-codepoint symbol")?;
        Self::char(c)?
    },
    // A plain character only needs validation.
    v: char => Self::char(v)?,
}
impl Delimiter {
    /// Creates the empty delimiter.
    pub fn none() -> Self {
        Self(None)
    }

    /// Creates a delimiter from a character.
    ///
    /// # Errors
    ///
    /// Fails with `invalid delimiter: "<c>"` unless the default math class of
    /// `c` is opening, closing or fence.
    pub fn char(c: char) -> StrResult<Self> {
        match default_math_class(c) {
            Some(MathClass::Opening | MathClass::Closing | MathClass::Fence) => {
                Ok(Self(Some(c)))
            }
            _ => bail!("invalid delimiter: \"{}\"", c),
        }
    }

    /// Returns the wrapped character, or `None` for the empty delimiter.
    pub fn get(self) -> Option<char> {
        self.0
    }

    /// Returns the counterpart of this delimiter.
    ///
    /// The empty delimiter maps to itself. Square brackets and curly braces
    /// are paired explicitly. For any other character, an opening-class
    /// character maps to the next code point, a closing-class character to
    /// the previous one, and everything else maps to itself.
    pub fn find_matching(self) -> Self {
        Self(self.0.and_then(|c| match c {
            '[' => Some(']'),
            ']' => Some('['),
            '{' => Some('}'),
            '}' => Some('{'),
            _ => match default_math_class(c) {
                // Relies on opening/closing characters being adjacent code
                // points. If the neighbor is not a valid `char`, the result
                // is the empty delimiter.
                Some(MathClass::Opening) => char::from_u32(c as u32 + 1),
                Some(MathClass::Closing) => char::from_u32(c as u32 - 1),
                _ => Some(c),
            },
        }))
    }
}
/// A pair of delimiters: one opening and one closing.
///
/// Either side may be the empty delimiter.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct DelimiterPair {
// The opening delimiter.
open: Delimiter,
// The closing delimiter.
close: Delimiter,
}
// Conversions between `DelimiterPair` and values.
cast! {
    DelimiterPair,
    // To a value: a two-element array `(open, close)`.
    self => array![self.open, self.close].into_value(),
    // From an array: exactly two items, each cast to a `Delimiter`.
    v: Array => {
        if let [open, close] = v.as_slice() {
            Self {
                open: open.clone().cast()?,
                close: close.clone().cast()?,
            }
        } else {
            bail!("expected 2 delimiters, found {}", v.len())
        }
    },
    // From a single delimiter: the closing side is derived with
    // `Delimiter::find_matching`.
    v: Delimiter => Self { open: v, close: v.find_matching() }
}
impl DelimiterPair {
    /// A pair of parentheses.
    const PAREN: Self = Self::literal('(', ')');

    /// A pair of curly braces.
    const BRACE: Self = Self::literal('{', '}');

    /// Builds a pair directly from two characters, without the validation
    /// performed by `Delimiter::char`.
    const fn literal(open: char, close: char) -> Self {
        Self {
            open: Delimiter(Some(open)),
            close: Delimiter(Some(close)),
        }
    }

    /// The character of the opening delimiter, if any.
    pub fn open(self) -> Option<char> {
        let Self { open, .. } = self;
        open.get()
    }

    /// The character of the closing delimiter, if any.
    pub fn close(self) -> Option<char> {
        let Self { close, .. } = self;
        close.get()
    }
}
/// Augmentation settings: offsets for horizontal and vertical lines plus the
/// stroke to use.
///
/// `T` is the numeric type used by the stroke. It defaults to `Length`;
/// resolving an `Augment` yields an `Augment<Abs>`.
#[derive(Debug, Default, Clone, PartialEq, Hash)]
pub struct Augment<T: Numeric = Length> {
/// The offsets of the horizontal lines.
pub hline: AugmentOffsets,
/// The offsets of the vertical lines.
pub vline: AugmentOffsets,
/// The stroke, or `Smart::Auto` if none was specified.
pub stroke: Smart<Stroke<T>>,
}
impl<T: Numeric + Fold> Fold for Augment<T> {
    /// Combines this (inner) augmentation with an outer one.
    ///
    /// The line offsets are always taken from `self`. Only the stroke is
    /// combined with `outer`.
    fn fold(self, outer: Self) -> Self {
        let Self { hline, vline, stroke } = self;
        let stroke = match (stroke, outer.stroke) {
            // Both sides specify a stroke: fold them.
            (Smart::Custom(inner), Smart::Custom(outer)) => {
                Smart::Custom(inner.fold(outer))
            }
            // Otherwise the inner value wins unless it is `Auto`, in which
            // case the outer value is used.
            (inner, outer) => inner.or(outer),
        };
        Self { hline, vline, stroke }
    }
}
impl Resolve for Augment {
    type Output = Augment<Abs>;

    /// Resolves the stroke against `styles`. The line offsets are carried
    /// over unchanged.
    fn resolve(self, styles: StyleChain) -> Self::Output {
        let Self { hline, vline, stroke } = self;
        Augment { hline, vline, stroke: stroke.resolve(styles) }
    }
}
// Conversions between `Augment` (with a `Length`-based stroke) and values.
cast! {
    Augment,
    self => {
        // Compact form: with an automatic stroke, no horizontal lines and
        // exactly one vertical line, the value is just that one offset.
        let is_single_vline = self.stroke.is_auto()
            && self.hline.0.is_empty()
            && self.vline.0.len() == 1;
        if is_single_vline {
            return self.vline.0[0].into_value();
        }

        // General form: a dictionary with all three fields.
        dict! {
            "hline" => self.hline,
            "vline" => self.vline,
            "stroke" => self.stroke,
        }.into_value()
    },
    // A bare integer is a single vertical line with an automatic stroke.
    v: isize => Augment {
        hline: AugmentOffsets::default(),
        vline: AugmentOffsets(smallvec![v]),
        stroke: Smart::Auto,
    },
    // A dictionary may provide `hline`, `vline` and `stroke`; missing keys
    // fall back to no lines / an automatic stroke.
    // NOTE(review): keys other than these three are silently ignored.
    mut dict: Dict => {
        let mut offsets = |key| {
            dict.take(key).ok().map(AugmentOffsets::from_value).transpose()
        };
        let hline = offsets("hline")?.unwrap_or_default();
        let vline = offsets("vline")?.unwrap_or_default();
        let stroke = match dict.take("stroke") {
            Ok(value) => Smart::Custom(Stroke::from_value(value)?),
            Err(_) => Smart::Auto,
        };
        Augment { hline, vline, stroke }
    },
}
// Conversion of a resolved `Augment<Abs>` into a value.
//
// The previous body, `self.into_value()`, was called on an `Augment<Abs>` and
// therefore dispatched back to this very conversion, recursing without bound
// if it was ever invoked. Instead, turn the absolute stroke back into a
// `Length`-based one and delegate to the conversion of `Augment<Length>`.
cast! {
    Augment<Abs>,
    self => {
        let unresolved: Augment<Length> = Augment {
            hline: self.hline,
            vline: self.vline,
            stroke: self.stroke.map(|stroke| stroke.map(Length::from)),
        };
        unresolved.into_value()
    },
}
/// A list of signed line offsets, stored in a `SmallVec` with inline room for
/// one element. The `Default` value is the empty list.
// NOTE(review): how negative offsets are interpreted is decided by the code
// that consumes these values, which is not in this file.
#[derive(Debug, Default, Clone, Eq, PartialEq, Hash)]
pub struct AugmentOffsets(pub SmallVec<[isize; 1]>);
// Conversions between `AugmentOffsets` and values.
cast! {
    AugmentOffsets,
    // To a value: the underlying list of offsets.
    self => self.0.into_value(),
    // A single integer is a one-element list.
    v: isize => Self(smallvec![v]),
    // An array must consist of integers only; the first failing item aborts
    // the cast.
    v: Array => {
        let offsets: HintedStrResult<SmallVec<[isize; 1]>> =
            v.into_iter().map(Value::cast).collect();
        Self(offsets?)
    },
}