use std::{
borrow::Borrow,
cell::RefCell,
collections::HashMap,
fmt,
hash::{Hash, Hasher},
iter::once,
mem::{discriminant, swap, take},
ops::Deref,
slice::{self, SliceIndex},
sync::Arc,
};
use ecow::{EcoString, EcoVec, eco_vec};
use indexmap::IndexSet;
use rapidhash::quality::RapidHasher;
use serde::*;
use crate::{
AestheticHash, Assembly, BindingKind, DynamicFunction, Function, ImplPrimitive, Primitive,
Purity, Signature, Value,
check::SigCheckError,
compile::invert::{InversionError, InversionResult},
};
// The variant list of `Node`, expanded by the `node!` macro defined at the bottom of this file.
//
// - Fields named `span` are skipped by `Node`'s `Hash`/`PartialEq` and returned by
//   `Node::span`/`Node::span_mut`.
// - A field named `val` is hashed through `AestheticHash` rather than plain `Hash`.
// - Each `(#[serde(untagged)] rep),` line attaches `#[serde(untagged)]` to the `NodeRep`
//   variant of the node that FOLLOWS it (so `Push` is tagged, `Prim` onward are untagged).
//   serde requires untagged variants to come after all tagged ones, so keep them last.
// - `Node` is `#[repr(u8)]` and its discriminant is hashed, so reordering variants changes
//   node hashes.
node!(
    Array {
        len: usize,
        inner: Arc<Node>,
        boxed: bool,
        allow_ext: bool,
        prim: Option<Primitive>,
        span: usize
    },
    CallGlobal(index(usize), sig(Signature)),
    CallMacro { index: usize, sig: Signature, span: usize },
    BindGlobal { index: usize, span: usize },
    Label(label(EcoString), span(usize)),
    RemoveLabel(label(Option<EcoString>), span(usize)),
    Format(parts(EcoVec<EcoString>), span(usize)),
    MatchFormatPattern(parts(EcoVec<EcoString>), span(usize)),
    CustomInverse(cust(Arc<CustomInverse>), span(usize)),
    Switch { branches: Ops, sig: Signature, under_cond: bool, span: usize },
    Unpack {
        count: usize,
        unbox: bool,
        allow_ext: bool,
        prim: Option<Primitive>,
        span: usize
    },
    SetOutputComment { i: usize, n: usize },
    Dynamic(func(DynamicFunction)),
    PushUnder(n(usize), span(usize)),
    CopyToUnder(n(usize), span(usize)),
    PopUnder(n(usize), span(usize)),
    NoInline(inner(Arc<Node>)),
    TrackCaller(inner(Arc<SigNode>)),
    // `val` is hashed via `AestheticHash`
    Push(val(Value)),
    // From here on, every variant is untagged in `NodeRep`
    (#[serde(untagged)] rep),
    Prim(prim(Primitive), span(usize)),
    (#[serde(untagged)] rep),
    ImplPrim(prim(ImplPrimitive), span(usize)),
    (#[serde(untagged)] rep),
    Mod(prim(Primitive), args(Ops), span(usize)),
    (#[serde(untagged)] rep),
    ImplMod(prim(ImplPrimitive), args(Ops), span(usize)),
    (#[serde(untagged)] rep),
    Call(func(Function), span(usize)),
    (#[serde(untagged)] rep),
    // A sequence of nodes; the empty `Run` is the empty node
    Run(nodes(EcoVec<Node>)),
);
/// A [`Node`] paired with the stack [`Signature`] associated with it
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub struct SigNode {
    /// The wrapped node
    pub node: Node,
    /// The signature associated with `node`
    pub sig: Signature,
}
impl SigNode {
    /// Create a new [`SigNode`] from anything convertible to a signature and a node
    pub fn new(sig: impl Into<Signature>, node: impl Into<Node>) -> Self {
        Self {
            node: node.into(),
            sig: sig.into(),
        }
    }
    /// Repeat this node `n` times over successive argument groups
    ///
    /// The result's args, outputs, under args, and under outputs are all multiplied by `n`.
    /// `n == 0` yields the empty node and `n == 1` yields the node unchanged. Otherwise the
    /// node is wrapped in `floor(log2(n))` nested `Both`s, and any remaining copies are added
    /// by dipping the accumulated node below the inner node's args and appending the inner
    /// node.
    pub fn on_all(self, n: usize, span: usize) -> SigNode {
        let sig = self.sig;
        let mut sig = sig.with_under(sig.under_args() * n, sig.under_outputs() * n);
        sig.update_args_outputs(|a, o| (a * n, o * n));
        let node = match n {
            0 => Node::empty(),
            1 => self.node,
            n => {
                let mut sn = self;
                let inner = sn.clone();
                // Largest `k` with `2^k <= n`. Integer `ilog2` is exact for every `n`, whereas
                // `(n as f64).log2()` can round up for large `n` and make `remain` underflow.
                let prev_pow_2 = n.ilog2();
                for _ in 0..prev_pow_2 {
                    let mut sig = sn.sig;
                    sig.update_args_outputs(|a, o| (a * 2, o * 2));
                    let node = Node::Mod(Primitive::Both, eco_vec![sn], span);
                    sn = SigNode::new(sig, node);
                }
                // Copies not covered by the power-of-two `Both` nesting
                let remain = n - 2usize.pow(prev_pow_2);
                if remain > 0 {
                    let SigNode { mut sig, node } = sn;
                    let args = inner.sig.args();
                    let mut both = node.clone();
                    for _ in 0..remain {
                        // Run the accumulated node below the next copy's arguments...
                        for _ in 0..args {
                            let inner = SigNode::new(sig, both);
                            both = Node::Mod(Primitive::Dip, eco_vec![inner], span);
                        }
                        // ...then run one more copy on top
                        both.push(inner.node.clone());
                        sig.update_args_outputs(|a, o| {
                            (a + args, o + args + inner.sig.outputs() - inner.sig.args())
                        });
                    }
                    both
                } else {
                    sn.node
                }
            }
        };
        SigNode::new(sig, node)
    }
    /// Run this node below the top `depth` stack values
    ///
    /// Uses `Dip` for a depth of 1 and `DipN` for larger depths. Args and outputs both
    /// grow by `depth`.
    pub fn dipped(self, depth: usize, span: usize) -> SigNode {
        let mut sig = self.sig;
        sig.update_args_outputs(|a, o| (a + depth, o + depth));
        let node = match depth {
            0 => self.node,
            1 => Node::Mod(Primitive::Dip, eco_vec![self], span),
            n => Node::ImplMod(ImplPrimitive::DipN(n), eco_vec![self], span),
        };
        SigNode::new(sig, node)
    }
    /// Discard the top `depth` stack values before running this node
    ///
    /// Prepends `depth` `Pop`s and widens the signature's args by `depth`.
    pub fn gapped(mut self, depth: usize, span: usize) -> SigNode {
        // Update the node's own signature. Previously this was applied to a local copy that
        // was then dropped, so the returned node kept its old, too-narrow signature.
        self.sig.update_args(|a| a + depth);
        for _ in 0..depth {
            self.node.prepend(Node::Prim(Primitive::Pop, span));
        }
        self
    }
}
impl From<SigNode> for Node {
fn from(sn: SigNode) -> Self {
sn.node
}
}
impl From<Arc<Node>> for Node {
fn from(node: Arc<Node>) -> Self {
Arc::unwrap_or_clone(node)
}
}
impl From<SigNode> for (Node, Signature) {
fn from(sn: SigNode) -> Self {
(sn.node, sn.sig)
}
}
impl Serialize for SigNode {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
(self.sig.args(), self.sig.outputs(), &self.node).serialize(serializer)
}
}
impl<'de> Deserialize<'de> for SigNode {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let (args, outputs, node) = <(usize, usize, Node)>::deserialize(deserializer)?;
Ok(SigNode::new(Signature::new(args, outputs), node))
}
}
pub(crate) type Ops = EcoVec<SigNode>;
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(default)]
pub struct CustomInverse {
pub normal: InversionResult<SigNode>,
#[serde(skip_serializing_if = "Option::is_none")]
pub un: Option<SigNode>,
#[serde(skip_serializing_if = "Option::is_none")]
pub under: Option<(SigNode, SigNode)>,
#[serde(skip_serializing_if = "Option::is_none")]
pub anti: Option<SigNode>,
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub is_obverse: bool,
}
impl Default for CustomInverse {
fn default() -> Self {
Self {
normal: Ok(SigNode::default()),
un: None,
under: None,
anti: None,
is_obverse: false,
}
}
}
impl From<InversionError> for CustomInverse {
fn from(e: InversionError) -> Self {
Self {
normal: Err(e),
..Default::default()
}
}
}
impl fmt::Debug for CustomInverse {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut s = if self.is_obverse {
f.debug_struct("obverse")
} else {
f.debug_struct("custom inverse")
};
if let Ok(normal) = &self.normal {
s.field("normal", &normal.node);
} else {
s.field("normal", &self.normal.as_ref().map(|sn| &sn.node));
}
if let Some(un) = &self.un {
s.field("un", &un.node);
}
if let Some(anti) = &self.anti {
s.field("anti", &anti.node);
}
if let Some(under) = &self.under {
s.field("do", &under.0.node);
s.field("undo", &under.1.node);
}
s.finish()
}
}
impl CustomInverse {
pub fn sig(&self) -> Result<Signature, SigCheckError> {
Ok(match &self.normal {
Ok(n) => n.sig,
Err(e) => {
if let Some(un) = &self.un {
un.sig.inverse()
} else if let Some(anti) = &self.anti {
anti.sig
.anti()
.ok_or_else(|| SigCheckError::from(e.to_string()).no_inverse())?
} else {
return match e {
InversionError::Signature(e) => Err(e.clone()),
e => Err(SigCheckError::from(e.to_string()).no_inverse()),
};
}
}
})
}
pub fn nodes(&self) -> impl Iterator<Item = &SigNode> {
(self.normal.as_ref().into_iter())
.chain(self.un.as_ref())
.chain(self.anti.as_ref())
.chain(self.under.as_ref().into_iter().flat_map(|(b, a)| [a, b]))
}
pub fn nodes_mut(&mut self) -> impl Iterator<Item = &mut SigNode> {
(self.normal.as_mut().into_iter())
.chain(self.un.as_mut())
.chain(self.anti.as_mut())
.chain(self.under.as_mut().into_iter().flat_map(|(b, a)| [a, b]))
}
}
impl Default for Node {
fn default() -> Self {
Self::empty()
}
}
impl Node {
    /// The empty node: a `Run` with no children
    pub const fn empty() -> Self {
        Self::Run(EcoVec::new())
    }
    /// Create a `Push` node
    ///
    /// The value has `try_shrink` and `derive_sortedness` applied before it is stored.
    pub fn new_push(val: impl Into<Value>) -> Self {
        let mut value = val.into();
        value.try_shrink();
        value.derive_sortedness();
        Self::Push(value)
    }
    /// View the node as a slice
    ///
    /// A `Run` yields its children; any other node yields a one-element slice of itself.
    pub fn as_slice(&self) -> &[Node] {
        if let Node::Run(nodes) = self {
            nodes
        } else {
            slice::from_ref(self)
        }
    }
    /// Mutable counterpart of [`Node::as_slice`]
    ///
    /// For a `Run` this goes through `EcoVec::make_mut`, which may clone shared storage.
    pub fn as_mut_slice(&mut self) -> &mut [Node] {
        match self {
            Node::Run(nodes) => nodes.make_mut(),
            other => slice::from_mut(other),
        }
    }
    /// Build a new node from a range of this node's slice view
    ///
    /// Panics if `range` is out of bounds of [`Node::as_slice`].
    pub fn slice<R>(&self, range: R) -> Self
    where
        R: SliceIndex<[Node], Output = [Node]>,
    {
        Self::from_iter(self.as_slice()[range].iter().cloned())
    }
    /// Get the children as a mutable `EcoVec`
    ///
    /// A non-`Run` node is first converted into a one-element `Run` containing the old node.
    pub fn as_vec(&mut self) -> &mut EcoVec<Node> {
        match self {
            Node::Run(nodes) => nodes,
            other => {
                // `take` leaves `Node::default()`, which is an empty `Run`, in its place
                let first = take(other);
                let Node::Run(nodes) = other else {
                    unreachable!()
                };
                nodes.push(first);
                nodes
            }
        }
    }
    /// Convert into a vector of nodes: a `Run`'s children, or a one-element vector
    pub fn into_vec(self) -> EcoVec<Node> {
        if let Node::Run(nodes) = self {
            nodes
        } else {
            eco_vec![self]
        }
    }
    /// Unwrap single-element `Run`s and flatten directly nested `Run`s, repeating until
    /// neither applies
    ///
    /// Only `Run` structure at the top is normalized; nodes inside other variants
    /// (modifier args, arrays, etc.) are not visited.
    pub fn normalize(&mut self) {
        if let Node::Run(nodes) = self {
            if nodes.len() == 1 {
                *self = take(nodes).remove(0);
                self.normalize();
            } else if nodes.iter().any(|node| matches!(node, Node::Run(_))) {
                // Each child is flattened via `IntoIterator for Node` and re-joined via `push`
                *self = take(nodes).into_iter().flatten().collect();
                self.normalize();
            }
        }
    }
    /// Keep only the first `len` nodes
    ///
    /// A `Run` left with exactly one node is unwrapped to that node. A non-`Run` node is
    /// cleared when `len == 0` and otherwise left unchanged.
    pub fn truncate(&mut self, len: usize) {
        if let Node::Run(nodes) = self {
            nodes.truncate(len);
            if nodes.len() == 1 {
                *self = take(nodes).remove(0);
            }
        } else if len == 0 {
            *self = Node::default();
        }
    }
    /// Split at `index`, returning the nodes from `index` on and keeping the rest in `self`
    ///
    /// Unlike [`Node::truncate`], single-element `Run`s are not unwrapped on either side.
    ///
    /// # Panics
    /// Panics if `index` is greater than the node's slice length.
    #[track_caller]
    pub fn split_off(&mut self, index: usize) -> Self {
        if let Node::Run(nodes) = self {
            let removed = EcoVec::from(&nodes[index..]);
            nodes.truncate(index);
            Node::Run(removed)
        } else if index == 0 {
            take(self)
        } else if index == 1 {
            Node::empty()
        } else {
            panic!(
                "Index {index} out of bounds of node with length {}",
                self.len()
            );
        }
    }
    /// Iterate mutably over the node's slice view (see [`Node::as_mut_slice`])
    pub fn iter_mut(&mut self) -> slice::IterMut<Self> {
        self.as_mut_slice().iter_mut()
    }
    /// Append a node
    ///
    /// An empty `self` is replaced by `node`. A `Run` being appended is spliced in
    /// element-wise, and appending an empty `Run` to a non-`Run` node does nothing.
    pub fn push(&mut self, mut node: Node) {
        if let Node::Run(nodes) = self {
            if nodes.is_empty() {
                *self = node;
            } else {
                match node {
                    Node::Run(other) => nodes.extend(other),
                    node => nodes.push(node),
                }
            }
        } else if let Node::Run(nodes) = &node {
            if !nodes.is_empty() {
                // `self` becomes the incoming `Run`, and the old `self` goes in front of it
                swap(self, &mut node);
                self.as_vec().insert(0, node);
            }
        } else {
            self.as_vec().push(node);
        }
    }
    /// Prepend a node; the mirror image of [`Node::push`]
    pub fn prepend(&mut self, mut node: Node) {
        if let Node::Run(nodes) = self {
            if nodes.is_empty() {
                *self = node;
            } else {
                match node {
                    Node::Run(mut other) => {
                        // Put the incoming nodes first, then re-append the originals
                        swap(nodes, &mut other);
                        nodes.extend(other)
                    }
                    node => nodes.insert(0, node),
                }
            }
        } else if let Node::Run(nodes) = &node {
            if !nodes.is_empty() {
                // `self` becomes the incoming `Run`, and the old `self` goes after it
                swap(self, &mut node);
                self.as_vec().push(node);
            }
        } else {
            self.as_vec().insert(0, node);
        }
    }
    /// Remove and return the last node
    ///
    /// A `Run` left with exactly one node is unwrapped. A non-`Run` node is taken whole,
    /// leaving `self` empty. Returns `None` only for an empty `Run`.
    pub fn pop(&mut self) -> Option<Node> {
        match self {
            Node::Run(nodes) => {
                let res = nodes.pop();
                if nodes.len() == 1 {
                    *self = take(nodes).remove(0);
                }
                res
            }
            node => Some(take(node)),
        }
    }
    /// Replace the node with the empty node
    pub fn clear(&mut self) {
        *self = Node::empty();
    }
    /// If this node is a lone `Prim`, or exactly `Flip` followed by a `Prim`, return that
    /// primitive and whether it was flipped
    pub(crate) fn as_flipped_primitive(&self) -> Option<(Primitive, bool)> {
        match self {
            Node::Prim(prim, _) => Some((*prim, false)),
            Node::Run(nodes) => match nodes.as_slice() {
                [Node::Prim(Primitive::Flip, _), Node::Prim(prim, _)] => Some((*prim, true)),
                _ => None,
            },
            _ => None,
        }
    }
    /// If this node is a lone, unflipped `Prim`, return the primitive
    pub(crate) fn as_primitive(&self) -> Option<Primitive> {
        self.as_flipped_primitive()
            .filter(|(_, flipped)| !flipped)
            .map(|(prim, _)| prim)
    }
    /// If this node is a lone `ImplPrim`, or exactly `Flip` followed by an `ImplPrim`,
    /// return that primitive and whether it was flipped
    pub(crate) fn as_flipped_impl_primitive(&self) -> Option<(ImplPrimitive, bool)> {
        match self {
            Node::ImplPrim(prim, _) => Some((*prim, false)),
            Node::Run(nodes) => match nodes.as_slice() {
                [Node::Prim(Primitive::Flip, _), Node::ImplPrim(prim, _)] => Some((*prim, true)),
                _ => None,
            },
            _ => None,
        }
    }
    /// If this node is a lone, unflipped `ImplPrim`, return the primitive
    pub(crate) fn as_impl_primitive(&self) -> Option<ImplPrimitive> {
        self.as_flipped_impl_primitive()
            .filter(|(_, flipped)| !flipped)
            .map(|(prim, _)| prim)
    }
    /// Apply `f` to the last node reachable from `self`, and return its result
    ///
    /// Searches backwards through a top-level `Run` and follows `Call`s into their function
    /// bodies in `asm`, which may be mutated via `make_mut`. Returns `None` if nothing is
    /// found, for example when a function body is empty.
    ///
    /// NOTE(review): nodes are inspected through `inner()`, which sees through
    /// `CustomInverse`, but mutated through `inner_mut()`, which does not. There is also no
    /// cycle guard if a function's last node calls that same function. Confirm callers rule
    /// both cases out.
    pub fn last_mut_recursive<T>(
        &mut self,
        asm: &mut Assembly,
        f: impl Fn(&mut Node) -> T + Copy,
    ) -> Option<T> {
        // `target`: function index to look in (`None` means `node` itself);
        // `sub`: index of the child within that node's `Run`
        fn recurse<T>(
            target: Option<usize>,
            sub: Option<usize>,
            node: &mut Node,
            asm: &mut Assembly,
            f: impl FnOnce(&mut Node) -> T + Copy,
        ) -> Option<T> {
            let mut this_node = match target {
                Some(i) => asm.functions[i].inner(),
                None => node.inner(),
            };
            if let Some(i) = sub {
                this_node = this_node.get(i)?;
            }
            match this_node {
                Node::Run(nodes) if sub.is_none() => {
                    for i in (0..nodes.len()).rev() {
                        if let Some(res) = recurse(target, Some(i), node, asm, f) {
                            return Some(res);
                        }
                    }
                    None
                }
                Node::Call(func, _) => recurse(Some(func.index), None, node, asm, f),
                _ => {
                    // Re-derive the same location mutably; it could not be held across the
                    // immutable inspection above
                    let mut node = match target {
                        Some(i) => asm.functions.make_mut()[i].inner_mut(),
                        None => node.inner_mut(),
                    };
                    if let Some(i) = sub {
                        node = node.as_mut_slice().get_mut(i)?.inner_mut();
                    }
                    Some(f(node))
                }
            }
        }
        recurse(None, None, self, asm, f)
    }
    /// See through `NoInline` and `TrackCaller` wrappers, and through a `CustomInverse`
    /// whose normal form is available
    pub(crate) fn inner(&self) -> &Node {
        match self {
            Node::NoInline(inner) => inner.inner(),
            Node::TrackCaller(inner) => inner.node.inner(),
            Node::CustomInverse(cust, ..) => {
                if let Ok(sn) = cust.normal.as_ref() {
                    sn.node.inner()
                } else {
                    self
                }
            }
            node => node,
        }
    }
    /// Mutable counterpart of [`Node::inner`] for `NoInline` and `TrackCaller` only
    ///
    /// `CustomInverse` is NOT seen through here. Uses `Arc::make_mut`, which clones the
    /// wrapped value if the `Arc` is shared.
    fn inner_mut(&mut self) -> &mut Node {
        match self {
            Node::NoInline(inner) => Arc::make_mut(inner).inner_mut(),
            Node::TrackCaller(inner) => Arc::make_mut(inner).node.inner_mut(),
            node => node,
        }
    }
    /// Hash the node and then its span, if any (the `Hash` impl itself skips `span` fields)
    pub(crate) fn hash_with_span(&self, hasher: &mut impl Hasher) {
        self.hash(hasher);
        if let Some(span) = self.span() {
            span.hash(hasher);
        }
    }
    /// Combine nodes with `Bracket`
    ///
    /// Zero nodes give the empty node and one node is returned as-is. More than two are
    /// nested to the right: `bracket(a, bracket(b, ...))`.
    /// NOTE(review): the `unwrap` assumes `sig_node` (defined elsewhere) always succeeds for
    /// a nested bracket. Confirm.
    pub fn bracket<I>(nodes: I, span: usize) -> Node
    where
        I: IntoIterator<Item = SigNode>,
        I::IntoIter: ExactSizeIterator,
    {
        let mut nodes = nodes.into_iter();
        let size = nodes.len();
        match size {
            0 => Node::empty(),
            1 => nodes.next().unwrap().node,
            2 => Node::Mod(Primitive::Bracket, nodes.collect(), span),
            _ => {
                let first = nodes.next().unwrap();
                let second = Self::bracket(nodes, span).sig_node().unwrap();
                Node::Mod(Primitive::Bracket, eco_vec![first, second], span)
            }
        }
    }
    /// Iterate over the direct child nodes: `Run` elements, `Array`/`NoInline`/`TrackCaller`
    /// inner nodes, modifier args, custom-inverse forms, and switch branches
    pub fn sub_nodes(&self) -> Box<dyn Iterator<Item = &Node> + '_> {
        match self {
            Node::Run(nodes) => Box::new(nodes.iter()),
            Node::Array { inner, .. } | Node::NoInline(inner) => Box::new(once(&**inner)),
            Node::TrackCaller(inner) => Box::new(once(&inner.node)),
            Node::Mod(_, args, _) | Node::ImplMod(_, args, _) => {
                Box::new(args.iter().map(|sn| &sn.node))
            }
            Node::CustomInverse(cust, _) => Box::new(cust.nodes().map(|sn| &sn.node)),
            Node::Switch { branches, .. } => Box::new(branches.iter().map(|sn| &sn.node)),
            _ => Box::new([].into_iter()),
        }
    }
    /// Mutable counterpart of [`Node::sub_nodes`]
    ///
    /// Uses `make_mut` on the `EcoVec`/`Arc` storage, which may clone shared data.
    pub fn sub_nodes_mut(&mut self) -> Box<dyn Iterator<Item = &mut Node> + '_> {
        match self {
            Node::Run(nodes) => Box::new(nodes.make_mut().iter_mut()),
            Node::Array { inner, .. } | Node::NoInline(inner) => {
                Box::new(once(Arc::make_mut(inner)))
            }
            Node::TrackCaller(inner) => Box::new(once(&mut Arc::make_mut(inner).node)),
            Node::Mod(_, args, _) | Node::ImplMod(_, args, _) => {
                Box::new(args.make_mut().iter_mut().map(|sn| &mut sn.node))
            }
            Node::CustomInverse(cust, _) => {
                Box::new(Arc::make_mut(cust).nodes_mut().map(|sn| &mut sn.node))
            }
            Node::Switch { branches, .. } => {
                Box::new(branches.make_mut().iter_mut().map(|sn| &mut sn.node))
            }
            _ => Box::new([].into_iter()),
        }
    }
}
/// Join a slice of nodes into one node, cloning each element
impl From<&[Node]> for Node {
    fn from(nodes: &[Node]) -> Self {
        nodes.iter().cloned().collect()
    }
}
/// Join an array of nodes into one node
impl<const N: usize> From<[Node; N]> for Node {
    fn from(nodes: [Node; N]) -> Self {
        nodes.into_iter().collect()
    }
}
/// Wrap a vector of nodes in a `Run`, unwrapping it if it holds exactly one node
///
/// Unlike the iterator-based conversions, nested `Run`s are not flattened.
impl From<EcoVec<Node>> for Node {
    fn from(nodes: EcoVec<Node>) -> Self {
        if nodes.len() != 1 {
            return Node::Run(nodes);
        }
        nodes.into_iter().next().unwrap()
    }
}
/// Join nodes one at a time with [`Node::push`], starting from the first
impl FromIterator<Node> for Node {
    fn from_iter<T: IntoIterator<Item = Node>>(iter: T) -> Self {
        let mut rest = iter.into_iter();
        match rest.next() {
            Some(first) => rest.fold(first, |mut joined, next| {
                joined.push(next);
                joined
            }),
            None => Node::default(),
        }
    }
}
/// Append each node's slice elements with [`Node::push`]
impl Extend<Node> for Node {
    fn extend<T: IntoIterator<Item = Node>>(&mut self, iter: T) {
        iter.into_iter()
            .flatten()
            .for_each(|node| self.push(node));
    }
}
impl fmt::Debug for Node {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Node::Run(eco_vec) => {
let mut tuple = f.debug_tuple("");
for node in eco_vec {
tuple.field(node);
}
tuple.finish()
}
Node::Push(value) => write!(f, "push {value}"),
Node::Prim(prim, _) => write!(f, "{prim}"),
Node::ImplPrim(impl_prim, _) => write!(f, "{impl_prim:?}"),
Node::Mod(prim, args, _) => {
let mut tuple = f.debug_tuple(&prim.to_string());
for sn in args {
tuple.field(&sn.node);
}
tuple.finish()
}
Node::ImplMod(impl_prim, args, _) => {
let mut tuple = f.debug_tuple(&format!("{impl_prim:?}"));
for sn in args {
tuple.field(&sn.node);
}
tuple.finish()
}
Node::Array {
len,
inner,
boxed: true,
..
} => {
write!(f, "{}{}{{", Primitive::Len, len)?;
inner.fmt(f)?;
write!(f, "}}")
}
Node::Array {
len,
inner,
boxed: false,
..
} => {
write!(f, "{}{}[", Primitive::Len, len)?;
inner.fmt(f)?;
write!(f, "]")
}
Node::Call(func, _) => write!(f, "call {}", func.id),
Node::CallGlobal(index, _) => write!(f, "<call global {index}>"),
Node::CallMacro { index, .. } => write!(f, "<call macro {index}>"),
Node::BindGlobal { index, .. } => write!(f, "<bind global {index}>"),
Node::Label(label, _) => write!(f, "${label}"),
Node::RemoveLabel(..) => write!(f, "remove label"),
Node::Format(parts, _) => {
write!(f, "$\"")?;
for (i, part) in parts.iter().enumerate() {
if i > 0 {
write!(f, "_")?
}
write!(f, "{part}")?
}
write!(f, "\"")
}
Node::MatchFormatPattern(parts, _) => {
write!(f, "°$\"")?;
for (i, part) in parts.iter().enumerate() {
if i > 0 {
write!(f, "_")?
}
write!(f, "{part}")?
}
write!(f, "\"")
}
Node::Switch {
branches,
sig,
under_cond,
..
} => {
write!(f, "⨬{sig}")?;
if *under_cond || sig.under() != (0, 0) {
write!(f, "⍜({})", sig.under())?;
}
write!(f, "(")?;
for (i, br) in branches.iter().enumerate() {
if i > 0 {
write!(f, "|")?;
}
br.node.fmt(f)?;
}
write!(f, ")")
}
Node::CustomInverse(cust, _) => cust.fmt(f),
Node::Unpack {
count,
unbox: false,
..
} => write!(f, "<unpack {count}>"),
Node::Unpack {
count, unbox: true, ..
} => write!(f, "<unpack (unbox) {count}>"),
Node::SetOutputComment { i, n, .. } => write!(f, "<set output comment {i}({n})>"),
Node::Dynamic(func) => write!(f, "<dynamic function {}>", func.index),
Node::PushUnder(count, _) => write!(f, "push-u-{count}"),
Node::CopyToUnder(count, _) => write!(f, "copy-u-{count}"),
Node::PopUnder(count, _) => write!(f, "pop-u-{count}"),
Node::NoInline(inner) => f.debug_tuple("no-inline").field(inner.as_ref()).finish(),
Node::TrackCaller(inner) => {
f.debug_tuple("track-caller").field(inner.as_ref()).finish()
}
}
}
}
impl Node {
    /// Whether the node is at least [`Purity::Pure`]; see [`Node::is_min_purity`]
    pub fn is_pure<'a>(&'a self, asm: &'a Assembly) -> bool {
        self.is_min_purity(Purity::Pure, asm)
    }
    /// Whether every primitive, modifier, and called function reachable from this node has
    /// purity of at least `min_purity`
    ///
    /// The following count as failing:
    /// - calls to external bindings
    /// - `Dynamic` nodes
    /// - unknown globals, and globals that are neither constants nor functions
    /// - a custom inverse with neither a normal nor an `un` form
    ///
    /// A function already on the current call path also yields `false`, so recursion is
    /// treated conservatively as failing.
    ///
    /// Results are memoized in a thread-local cache keyed by `(hash of self, min_purity)`.
    /// NOTE(review): the key ignores `asm`, so a result computed against one assembly is
    /// reused for an equal-hashing node in another (e.g. a `CallGlobal` whose index names a
    /// different binding). Keys are 64-bit hashes, so collisions are possible, and the cache
    /// is never cleared. Confirm this is acceptable.
    pub fn is_min_purity<'a>(&'a self, min_purity: Purity, asm: &'a Assembly) -> bool {
        // `visited` holds the functions on the current call path; truncating back to `len`
        // on return makes it a path stack rather than a global visited set
        fn recurse<'a>(
            node: &'a Node,
            purity: Purity,
            asm: &'a Assembly,
            visited: &mut IndexSet<&'a Function>,
        ) -> bool {
            let len = visited.len();
            let is = match node {
                Node::Run(nodes) => nodes.iter().all(|node| recurse(node, purity, asm, visited)),
                Node::Prim(prim, _) => prim.purity() >= purity,
                Node::ImplPrim(prim, _) => prim.purity() >= purity,
                Node::Mod(prim, args, _) => {
                    prim.purity() >= purity
                        && args
                            .iter()
                            .all(|arg| recurse(&arg.node, purity, asm, visited))
                }
                Node::ImplMod(prim, args, _) => {
                    prim.purity() >= purity
                        && args
                            .iter()
                            .all(|arg| recurse(&arg.node, purity, asm, visited))
                }
                Node::Array { inner, .. } => recurse(inner, purity, asm, visited),
                Node::Call(func, _) => {
                    if (func.origin.binding()).is_some_and(|i| asm.bindings[i].meta.external) {
                        false
                    } else {
                        // `insert` returns false if `func` is already on the path (recursion)
                        visited.insert(func) && recurse(&asm[func], purity, asm, visited)
                    }
                }
                Node::CallGlobal(index, _) => {
                    if let Some(binding) = asm.bindings.get(*index) {
                        match &binding.kind {
                            BindingKind::Const(_) => true,
                            BindingKind::Func(f) => {
                                visited.insert(f) && recurse(&asm[f], purity, asm, visited)
                            }
                            _ => false,
                        }
                    } else {
                        false
                    }
                }
                Node::Switch { branches, .. } => branches
                    .iter()
                    .all(|br| recurse(&br.node, purity, asm, visited)),
                // Judged by the normal form, falling back to `un`
                Node::CustomInverse(cust, _) => (cust.normal.as_ref().ok())
                    .or(cust.un.as_ref())
                    .is_some_and(|sn| recurse(&sn.node, purity, asm, visited)),
                Node::TrackCaller(sn) => recurse(&sn.node, purity, asm, visited),
                Node::NoInline(n) => recurse(n, purity, asm, visited),
                Node::Dynamic(_) => false,
                _ => true,
            };
            visited.truncate(len);
            is
        }
        thread_local! {
            static CACHE: RefCell<HashMap<(u64, Purity), bool>> = RefCell::default();
        }
        let mut hasher = RapidHasher::new(1);
        self.hash(&mut hasher);
        let hash = hasher.finish();
        CACHE.with(|cache| {
            *(cache.borrow_mut().entry((hash, min_purity)))
                .or_insert_with(|| recurse(self, min_purity, asm, &mut IndexSet::new()))
        })
    }
    /// Whether the node avoids everything this check treats as unbounded, recursively
    ///
    /// That covers `Send`/`Recv`, system ops with purity at most `Mutating`, recursion,
    /// unknown globals, and globals other than computed constants (`Const(Some(_))`) or
    /// functions.
    /// NOTE(review): presumably "bounded" means safe to run under an execution limit.
    /// Confirm with callers.
    pub fn is_limit_bounded<'a>(&'a self, asm: &'a Assembly) -> bool {
        fn recurse<'a>(
            node: &'a Node,
            asm: &'a Assembly,
            visited: &mut IndexSet<&'a Function>,
        ) -> bool {
            let len = visited.len();
            let is = match node {
                Node::Run(nodes) => nodes.iter().all(|node| recurse(node, asm, visited)),
                Node::Prim(Primitive::Send | Primitive::Recv, _) => false,
                Node::Prim(Primitive::Sys(op), _) if op.purity() <= Purity::Mutating => false,
                Node::Mod(_, args, _) | Node::ImplMod(_, args, _) => {
                    args.iter().all(|arg| recurse(&arg.node, asm, visited))
                }
                Node::Array { inner, .. } => recurse(inner, asm, visited),
                // A function already on the call path (recursion) makes this false
                Node::Call(func, _) => visited.insert(func) && recurse(&asm[func], asm, visited),
                Node::CallGlobal(index, _) => {
                    if let Some(binding) = asm.bindings.get(*index) {
                        match &binding.kind {
                            BindingKind::Const(Some(_)) => true,
                            BindingKind::Func(f) => {
                                visited.insert(f) && recurse(&asm[f], asm, visited)
                            }
                            _ => false,
                        }
                    } else {
                        false
                    }
                }
                Node::Switch { branches, .. } => {
                    branches.iter().all(|br| recurse(&br.node, asm, visited))
                }
                Node::CustomInverse(cust, _) => (cust.normal.as_ref().ok())
                    .or(cust.un.as_ref())
                    .is_some_and(|sn| recurse(&sn.node, asm, visited)),
                Node::TrackCaller(sn) => recurse(&sn.node, asm, visited),
                Node::NoInline(n) => recurse(n, asm, visited),
                _ => true,
            };
            visited.truncate(len);
            is
        }
        recurse(self, asm, &mut IndexSet::new())
    }
    /// Whether some `Call` reachable from this node re-enters a function already on the
    /// current call path
    pub fn is_recursive(&self, asm: &Assembly) -> bool {
        fn recurse<'a>(
            node: &'a Node,
            asm: &'a Assembly,
            visited: &mut IndexSet<&'a Function>,
        ) -> bool {
            let len = visited.len();
            let is = match node {
                Node::Run(nodes) => nodes.iter().any(|node| recurse(node, asm, visited)),
                Node::Mod(_, args, _) | Node::ImplMod(_, args, _) => {
                    args.iter().any(|sn| recurse(&sn.node, asm, visited))
                }
                Node::Call(f, _) => !visited.insert(f) || recurse(&asm[f], asm, visited),
                Node::Switch { branches, .. } => {
                    branches.iter().any(|br| recurse(&br.node, asm, visited))
                }
                Node::CustomInverse(cust, _) => (cust.normal.as_ref().ok())
                    .or(cust.un.as_ref())
                    .is_some_and(|sn| recurse(&sn.node, asm, visited)),
                Node::Array { inner, .. } => recurse(inner, asm, visited),
                Node::TrackCaller(sn) => recurse(&sn.node, asm, visited),
                Node::NoInline(n) => recurse(n, asm, visited),
                _ => false,
            };
            visited.truncate(len);
            is
        }
        recurse(self, asm, &mut IndexSet::new())
    }
    /// Find a reachable `CustomInverse` whose normal (forward) form is an error
    ///
    /// On failure, returns three things:
    /// - the stored error
    /// - the innermost called function containing the custom inverse, if any
    /// - a span list: the custom inverse's span first, then each enclosing `Call`'s span
    ///   from innermost to outermost
    ///
    /// The search goes through runs, calls (each function at most once per path), switch
    /// branches, arrays, and wrappers.
    /// NOTE(review): modifier (`Mod`/`ImplMod`) arguments are NOT searched, unlike the other
    /// analyses here. Confirm this is intended.
    pub fn check_callability<'a>(
        &'a self,
        asm: &'a Assembly,
    ) -> Result<(), (InversionError, Option<&'a Function>, Vec<usize>)> {
        fn recurse<'a>(
            node: &'a Node,
            asm: &'a Assembly,
            spans: &mut Vec<usize>,
            visited: &mut IndexSet<&'a Function>,
        ) -> Option<(InversionError, Option<&'a Function>)> {
            let len = visited.len();
            let e = match node {
                Node::Run(nodes) => nodes.iter().find_map(|n| recurse(n, asm, spans, visited)),
                Node::Call(f, span) => {
                    if visited.insert(f) {
                        recurse(&asm[f], asm, spans, visited).map(|(e, mut func)| {
                            spans.push(*span);
                            // Only the innermost function is recorded
                            func.get_or_insert(f);
                            (e, func)
                        })
                    } else {
                        None
                    }
                }
                Node::CustomInverse(cust, span) => cust.normal.as_ref().err().cloned().map(|e| {
                    spans.push(*span);
                    (e, None)
                }),
                Node::Switch { branches, .. } => branches
                    .iter()
                    .find_map(|br| recurse(&br.node, asm, spans, visited)),
                Node::Array { inner, .. } => recurse(inner, asm, spans, visited),
                Node::TrackCaller(sn) => recurse(&sn.node, asm, spans, visited),
                Node::NoInline(n) => recurse(n, asm, spans, visited),
                _ => None,
            };
            visited.truncate(len);
            e
        }
        let mut spans = Vec::new();
        if let Some((e, func)) = recurse(self, asm, &mut spans, &mut IndexSet::new()) {
            Err((e, func, spans))
        } else {
            Ok(())
        }
    }
    /// Heuristically detect whether this node never returns normally
    ///
    /// Looks at the nodes just before the first `Assert` in the node's slice view and matches
    /// a few patterns that push a value other than 1 as the asserted value. Otherwise it
    /// recurses into the first argument of `Dip`/`Gap`/`On`/`By`/`With`/`Off` and into called
    /// functions. Recursion into a function already on the path yields `false`.
    pub fn is_noreturn<'a>(&'a self, asm: &'a Assembly) -> bool {
        use Primitive::*;
        fn recurse<'a>(
            node: &'a Node,
            asm: &'a Assembly,
            visited: &mut IndexSet<&'a Function>,
        ) -> bool {
            let len = visited.len();
            let res = 'blk: {
                if let Some(i) =
                    (node.as_slice().iter()).position(|n| matches!(n, Node::Prim(Assert, _)))
                {
                    // The nodes that run before the assertion
                    let init = &node.as_slice()[..i];
                    let noreturn = match init {
                        [.., Node::Push(val), Node::Prim(Dup | Flip, _)] if *val != 1 => true,
                        [
                            ..,
                            Node::Format(..) | Node::Prim(Couple | Join, _) | Node::Array { .. },
                            Node::Prim(Dup, _),
                        ] => true,
                        [.., Node::Push(val), Node::Push(_)] if *val != 1 => true,
                        [.., Node::Mod(Dip, args, _)]
                            if args.len() == 1
                                && matches!(&args[0].node, Node::Push(val) if *val != 1) =>
                        {
                            true
                        }
                        _ => false,
                    };
                    if noreturn {
                        break 'blk true;
                    }
                }
                match node {
                    // Assumes these modifiers always have at least one argument
                    Node::Mod(Dip | Gap | On | By | With | Off, inner, _) => {
                        recurse(&inner[0].node, asm, visited)
                    }
                    Node::Call(f, _) => visited.insert(f) && recurse(&asm[f], asm, visited),
                    _ => false,
                }
            };
            visited.truncate(len);
            res
        }
        recurse(self, asm, &mut IndexSet::new())
    }
}
impl Deref for Node {
type Target = [Node];
fn deref(&self) -> &Self::Target {
self.as_slice()
}
}
impl AsRef<[Node]> for Node {
fn as_ref(&self) -> &[Node] {
self.as_slice()
}
}
impl Borrow<[Node]> for Node {
fn borrow(&self) -> &[Node] {
self.as_slice()
}
}
impl<'a> IntoIterator for &'a Node {
type Item = &'a Node;
type IntoIter = slice::Iter<'a, Node>;
fn into_iter(self) -> Self::IntoIter {
self.as_slice().iter()
}
}
impl<'a> IntoIterator for &'a mut Node {
type Item = &'a mut Node;
type IntoIter = slice::IterMut<'a, Node>;
fn into_iter(self) -> Self::IntoIter {
self.as_mut_slice().iter_mut()
}
}
impl IntoIterator for Node {
type Item = Node;
type IntoIter = ecow::vec::IntoIter<Node>;
fn into_iter(self) -> Self::IntoIter {
self.into_vec().into_iter()
}
}
// Generate `Node` and its boilerplate from a compact variant list (invoked near the top of
// this file).
//
// Each variant is written either as `Name(field(Type), ...)` or as `Name { field: Type, ... }`.
// The first form is a tuple variant whose fields get names for use in the generated match
// arms. An optional `(#[attr] rep),` before a variant attaches `attr` to the matching
// `NodeRep` variant.
//
// Generated items:
// - `enum Node` (`#[repr(u8)]`), serialized via `NodeRep`
// - `Node::span`/`Node::span_mut`, which return the field named `span`, if any
// - `Hash` for `Node`, which hashes the discriminant and every field except `span`, and
//   hashes `val` via `AestheticHash`
// - `PartialEq`/`Eq` for `Node`, defined as equality of 64-bit `RapidHasher` hashes
// - `NodeRep`, with an extra `Empty` variant (serialized as `"e"`) for the empty `Run`,
//   plus conversions both ways
macro_rules! node {
    ($(
        $(#[$attr:meta])*
        $((#[$rep_attr:meta] rep),)?
        $name:ident
        $(($($tup_name:ident($tup_type:ty)),* $(,)?))?
        $({$($field_name:ident : $field_type:ty),* $(,)?})?
    ),* $(,)?) => {
        #[derive(Clone, Serialize, Deserialize)]
        #[repr(u8)]
        #[allow(missing_docs)]
        #[serde(from = "NodeRep", into = "NodeRep")]
        pub enum Node {
            $(
                $(#[$attr])*
                $name $(($($tup_type),*))? $({$($field_name : $field_type),*})?,
            )*
        }
        // `field_span!(name, binding)` expands to `return Some(binding)` only when the field
        // is named `span`; any other field expands to nothing
        macro_rules! field_span {
            (span, $sp:ident) => {
                return Some($sp)
            };
            ($sp:ident, $sp2:ident) => {};
        }
        impl Node {
            // The span of this node, if it has one. For a `Run`, this is the first span found
            // among its children.
            #[allow(unreachable_code, unused)]
            pub fn span(&self) -> Option<usize> {
                if let Node::Run(nodes) = &self {
                    return nodes.iter().find_map(Node::span);
                }
                // The closure gives `field_span!`'s `return` a target of type
                // `Option<&usize>`, which is then copied out
                (|| match self {
                    $(
                        Self::$name $(($($tup_name),*))? $({$($field_name),*})? => {
                            $($(field_span!($tup_name, $tup_name);)*)*
                            $($(field_span!($field_name, $field_name);)*)*
                            return None;
                        },
                    )*
                })().copied()
            }
            // Mutable access to this node's own `span` field.
            // NOTE(review): unlike `span`, this does not look inside `Run` children.
            #[allow(unreachable_code, unused)]
            pub fn span_mut(&mut self) -> Option<&mut usize> {
                match self {
                    $(
                        Self::$name $(($($tup_name),*))? $({$($field_name),*})? => {
                            $($(field_span!($tup_name, $tup_name);)*)*
                            $($(field_span!($field_name, $field_name);)*)*
                            return None;
                        },
                    )*
                }
            }
        }
        // Equality is defined as equality of the 64-bit hashes, so spans are ignored.
        // NOTE(review): a hash collision makes two different nodes compare equal.
        impl PartialEq for Node {
            #[allow(unused_variables)]
            fn eq(&self, other: &Self) -> bool {
                let mut hasher = RapidHasher::new(1);
                self.hash(&mut hasher);
                let hash = hasher.finish();
                let mut other_hasher = RapidHasher::new(1);
                other.hash(&mut other_hasher);
                let other_hash = other_hasher.finish();
                hash == other_hash
            }
        }
        impl Eq for Node {}
        impl Hash for Node {
            #[allow(unused_variables)]
            fn hash<H: Hasher>(&self, state: &mut H) {
                // Skip `span` fields, hash `val` fields aesthetically, hash everything else
                macro_rules! hash_field {
                    (span span) => {};
                    (val $val:ident) => {Hash::hash(&AestheticHash($val), state)};
                    ($nm:ident $_nm:ident) => {Hash::hash($nm, state)};
                }
                match self {
                    $(
                        Self::$name $(($($tup_name),*))? $({$($field_name),*})? => {
                            discriminant(self).hash(state);
                            $($(hash_field!($field_name $field_name);)*)?
                            $($(hash_field!($tup_name $tup_name);)*)?
                        }
                    )*
                }
            }
        }
        // Serialization form of `Node`: every variant becomes a tuple variant, and the empty
        // `Run` gets its own short `Empty` variant
        #[derive(Serialize, Deserialize)]
        #[serde(rename_all = "snake_case")]
        pub(crate) enum NodeRep {
            #[serde(rename = "e")]
            Empty(),
            $(
                $(#[$rep_attr])?
                $name(
                    $($($tup_type),*)?
                    $($($field_type),*)?
                ),
            )*
        }
        impl From<NodeRep> for Node {
            fn from(rep: NodeRep) -> Self {
                match rep {
                    NodeRep::Empty() => Self::empty(),
                    $(
                        NodeRep::$name (
                            $($($tup_name,)*)?
                            $($($field_name,)*)?
                        ) => Self::$name $(($($tup_name),*))? $({$($field_name),*})?,
                    )*
                }
            }
        }
        impl From<Node> for NodeRep {
            fn from(instr: Node) -> Self {
                match instr {
                    Node::Run(nodes) if nodes.is_empty() => NodeRep::Empty(),
                    $(
                        Node::$name $(($($tup_name),*))? $({$($field_name),*})? => NodeRep::$name (
                            $($($tup_name),*)?
                            $($($field_name),*)?
                        ),
                    )*
                }
            }
        }
    };
}
// Make the macro nameable by path so it can be invoked above its definition
use node;