use smol_str::SmolStr;
use thiserror::Error;
mod view;
use super::{Literal, RegionKind, Visibility, ast};
pub use view::View;
/// A package in table form: an ordered collection of [`Module`]s.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct Package<'a> {
/// The modules contained in the package.
pub modules: Vec<Module<'a>>,
}
impl Package<'_> {
    /// Converts every module of the package into its AST form.
    ///
    /// Modules are converted in order with [`Module::as_ast`]. Collecting into an `Option`
    /// stops at the first module that cannot be converted, in which case the whole result is
    /// `None`.
    #[must_use]
    pub fn as_ast(&self) -> Option<ast::Package> {
        Some(ast::Package {
            modules: self
                .modules
                .iter()
                .map(|module| module.as_ast())
                .collect::<Option<_>>()?,
        })
    }
}
/// A module in table form.
///
/// Nodes, regions and terms are stored in flat vectors. [`NodeId`], [`RegionId`] and
/// [`TermId`] are indices into those vectors (see `Module::get_node`, `Module::get_region`
/// and `Module::get_term`). With `Default`, `root` is the invalid sentinel `RegionId`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct Module<'a> {
/// The root region of the module; `Module::as_ast` starts its conversion from here.
pub root: RegionId,
/// Node table, indexed by [`NodeId`].
pub nodes: Vec<Node<'a>>,
/// Region table, indexed by [`RegionId`].
pub regions: Vec<Region<'a>>,
/// Term table, indexed by [`TermId`].
pub terms: Vec<Term<'a>>,
}
impl<'a> Module<'a> {
    /// Returns the node stored under `node_id`, or `None` if the table has no such entry.
    #[inline]
    #[must_use]
    pub fn get_node(&self, node_id: NodeId) -> Option<&Node<'a>> {
        let slot = node_id.index();
        self.nodes.get(slot)
    }

    /// Mutable counterpart of [`Module::get_node`].
    #[inline]
    pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut Node<'a>> {
        let slot = node_id.index();
        self.nodes.get_mut(slot)
    }

    /// Appends `node` to the node table and returns the id it was stored under.
    ///
    /// # Panics
    ///
    /// Panics (via [`NodeId::new`]) if the table is too large to be addressed by a `NodeId`.
    /// The id is computed before the push, so in that case the node is not inserted.
    pub fn insert_node(&mut self, node: Node<'a>) -> NodeId {
        let fresh = NodeId::new(self.nodes.len());
        self.nodes.push(node);
        fresh
    }

    /// Returns the term stored under `term_id`.
    ///
    /// An invalid id (see `TermId::is_valid`) is never looked up: it always resolves to
    /// [`Term::Wildcard`]. A valid id past the end of the table yields `None`.
    #[inline]
    #[must_use]
    pub fn get_term(&self, term_id: TermId) -> Option<&Term<'a>> {
        if !term_id.is_valid() {
            return Some(&Term::Wildcard);
        }
        self.terms.get(term_id.index())
    }

    /// Mutable counterpart of [`Module::get_term`].
    ///
    /// Unlike `get_term`, an invalid id is not mapped to a wildcard. It is looked up as a
    /// plain index like any other id.
    #[inline]
    pub fn get_term_mut(&mut self, term_id: TermId) -> Option<&mut Term<'a>> {
        let slot = term_id.index();
        self.terms.get_mut(slot)
    }

    /// Appends `term` to the term table and returns the id it was stored under.
    ///
    /// # Panics
    ///
    /// Panics (via [`TermId::new`]) if the table is too large to be addressed by a `TermId`.
    /// The term is not inserted in that case.
    pub fn insert_term(&mut self, term: Term<'a>) -> TermId {
        let fresh = TermId::new(self.terms.len());
        self.terms.push(term);
        fresh
    }

    /// Returns the region stored under `region_id`, or `None` if the table has no such entry.
    #[inline]
    #[must_use]
    pub fn get_region(&self, region_id: RegionId) -> Option<&Region<'a>> {
        let slot = region_id.index();
        self.regions.get(slot)
    }

    /// Mutable counterpart of [`Module::get_region`].
    #[inline]
    pub fn get_region_mut(&mut self, region_id: RegionId) -> Option<&mut Region<'a>> {
        let slot = region_id.index();
        self.regions.get_mut(slot)
    }

    /// Appends `region` to the region table and returns the id it was stored under.
    ///
    /// # Panics
    ///
    /// Panics (via [`RegionId::new`]) if the table is too large to be addressed by a
    /// `RegionId`. The region is not inserted in that case.
    pub fn insert_region(&mut self, region: Region<'a>) -> RegionId {
        let fresh = RegionId::new(self.regions.len());
        self.regions.push(region);
        fresh
    }

    /// Builds a value of type `V` from `src` using `V`'s [`View`] implementation.
    pub fn view<S, V: View<'a, S>>(&'a self, src: S) -> Option<V> {
        <V as View<'a, S>>::view(self, src)
    }

    /// Converts the module into its AST form by viewing the root region.
    ///
    /// Returns `None` if the root region cannot be viewed.
    #[must_use]
    pub fn as_ast(&self) -> Option<ast::Module> {
        Some(ast::Module {
            root: self.view(self.root)?,
        })
    }
}
/// A node in a [`Module`]'s node table.
///
/// The slice fields are borrowed with lifetime `'a`; the node does not own them.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub struct Node<'a> {
/// The operation of this node.
pub operation: Operation<'a>,
/// Link indices for the node's inputs.
pub inputs: &'a [LinkIndex],
/// Link indices for the node's outputs.
pub outputs: &'a [LinkIndex],
/// Ids of the node's regions.
pub regions: &'a [RegionId],
/// Metadata terms attached to the node.
pub meta: &'a [TermId],
/// Optional term id for the node's signature.
pub signature: Option<TermId>,
}
/// The operation of a [`Node`].
///
/// Defaults to [`Operation::Invalid`]. The variants that carry a [`Symbol`], together with
/// [`Operation::Import`], report a name through `Operation::symbol`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum Operation<'a> {
/// Placeholder operation; this is the `Default` variant.
#[default]
Invalid,
Dfg,
Cfg,
Block,
/// Function definition, described by its symbol.
DefineFunc(&'a Symbol<'a>),
/// Function declaration, described by its symbol.
DeclareFunc(&'a Symbol<'a>),
/// Custom operation given by a term.
/// NOTE(review): presumably the term describes the operation being applied — confirm.
Custom(TermId),
/// Alias definition: its symbol plus a term.
/// NOTE(review): the `TermId` is presumably the aliased value — confirm.
DefineAlias(&'a Symbol<'a>, TermId),
/// Alias declaration, described by its symbol.
DeclareAlias(&'a Symbol<'a>),
TailLoop,
Conditional,
/// Constructor declaration, described by its symbol.
DeclareConstructor(&'a Symbol<'a>),
/// Operation declaration, described by its symbol.
DeclareOperation(&'a Symbol<'a>),
/// Import of a name; `name` is what `Operation::symbol` returns for this variant.
Import {
name: &'a str,
},
}
impl<'a> Operation<'a> {
    /// Returns the name associated with this operation, if it has one.
    ///
    /// Variants that carry a [`Symbol`] report the symbol's `name`, and
    /// [`Operation::Import`] reports its imported name. Every other variant returns `None`.
    #[must_use]
    pub fn symbol(&self) -> Option<&'a str> {
        match self {
            Operation::DefineFunc(symbol)
            | Operation::DeclareFunc(symbol)
            | Operation::DefineAlias(symbol, _)
            | Operation::DeclareAlias(symbol)
            | Operation::DeclareConstructor(symbol)
            | Operation::DeclareOperation(symbol) => Some(symbol.name),
            Operation::Import { name } => Some(*name),
            // Listed explicitly rather than with `_`, so that adding a variant to
            // `Operation` is a compile error here instead of a silent `None`.
            Operation::Invalid
            | Operation::Dfg
            | Operation::Cfg
            | Operation::Block
            | Operation::Custom(_)
            | Operation::TailLoop
            | Operation::Conditional => None,
        }
    }
}
/// A region in a [`Module`]'s region table.
///
/// The slice fields are borrowed with lifetime `'a`; the region does not own them.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub struct Region<'a> {
/// The kind of this region.
pub kind: RegionKind,
/// Link indices for the region's sources.
pub sources: &'a [LinkIndex],
/// Link indices for the region's targets.
pub targets: &'a [LinkIndex],
/// Ids of the nodes listed as children of this region.
pub children: &'a [NodeId],
/// Metadata terms attached to the region.
pub meta: &'a [TermId],
/// Optional term id for the region's signature.
pub signature: Option<TermId>,
/// Scope information, present only for regions that have one (see [`RegionScope`]).
pub scope: Option<RegionScope>,
}
/// Extra information stored for regions that carry a scope (`Region::scope`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RegionScope {
/// NOTE(review): presumably the number of links in the scope — confirm.
pub links: u32,
/// NOTE(review): presumably the number of ports in the scope — confirm.
pub ports: u32,
}
/// A named symbol, as carried by the symbol-bearing [`Operation`] variants.
///
/// Every field is a reference or an id, so the struct is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Symbol<'a> {
/// Optional visibility, held by reference.
pub visibility: &'a Option<Visibility>,
/// The symbol's name; this is what `Operation::symbol` returns.
pub name: &'a str,
/// The symbol's parameters.
pub params: &'a [Param<'a>],
/// Constraint terms attached to the symbol.
pub constraints: &'a [TermId],
/// Term id for the symbol's signature.
pub signature: TermId,
}
/// Index of a variable; paired with a [`NodeId`] to form a [`VarId`].
pub type VarIndex = u16;
/// A term in a [`Module`]'s term table.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum Term<'a> {
/// Wildcard term. This is the `Default` variant, and `Module::get_term` returns it for
/// invalid [`TermId`]s.
#[default]
Wildcard,
/// Reference to a variable.
Var(VarId),
/// A node applied to a list of argument terms.
/// NOTE(review): presumably the node declares the symbol being applied — confirm.
Apply(NodeId, &'a [TermId]),
/// List built from the given parts (see [`SeqPart`]).
List(&'a [SeqPart]),
/// A literal value.
Literal(Literal),
/// A term referring to a region.
Func(RegionId),
/// Tuple built from the given parts (see [`SeqPart`]).
Tuple(&'a [SeqPart]),
}
impl From<Literal> for Term<'_> {
fn from(value: Literal) -> Self {
Self::Literal(value)
}
}
/// One part of the sequence in a [`Term::List`] or [`Term::Tuple`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SeqPart {
/// A single element.
Item(TermId),
/// A term spliced into the sequence.
/// NOTE(review): presumably the term must itself be a sequence whose elements are inlined — confirm.
Splice(TermId),
}
/// A named, typed parameter of a [`Symbol`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Param<'a> {
/// The parameter's name.
pub name: &'a str,
/// Term id for the parameter's type (written `r#type` because `type` is a keyword).
pub r#type: TermId,
}
macro_rules! define_index {
    ($(#[$meta:meta])* $vis:vis struct $name:ident(pub u32);) => {
        #[repr(transparent)]
        $(#[$meta])*
        $vis struct $name(pub u32);

        impl $name {
            /// Creates an index from a `usize` position.
            ///
            /// # Panics
            ///
            /// Panics with "index out of bounds" if `index` does not fit in a `u32`, or if it
            /// equals `u32::MAX`, which is reserved for the invalid (default) index.
            #[must_use]
            pub fn new(index: usize) -> Self {
                // `try_from` replaces a silently truncating `as` cast. The guard rejects the
                // sentinel, so `new` never yields an index for which `is_valid` is false.
                match u32::try_from(index) {
                    Ok(raw) if raw != u32::MAX => Self(raw),
                    _ => panic!("index out of bounds"),
                }
            }

            /// Returns `true` unless this is the sentinel `u32::MAX` (the `Default` value).
            #[inline]
            #[must_use]
            pub fn is_valid(self) -> bool {
                self.0 < u32::MAX
            }

            /// Returns the raw index widened to `usize`.
            #[inline]
            #[must_use]
            pub fn index(self) -> usize {
                self.0 as usize
            }

            /// Reinterprets a slice of indices as a slice of raw `u32` values without copying.
            #[must_use]
            pub fn unwrap_slice(slice: &[Self]) -> &[u32] {
                // SAFETY: this macro emits `Self` as `#[repr(transparent)]` over a single `u32`
                // field, so `Self` has exactly the size and alignment of `u32`. The pointer and
                // length come from a valid `&[Self]`. The returned slice carries the same
                // lifetime as the input borrow, so it cannot outlive the underlying memory.
                unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<u32>(), slice.len()) }
            }

            /// Reinterprets a slice of raw `u32` values as a slice of indices without copying.
            ///
            /// No validation is performed: a `u32::MAX` entry becomes an invalid index.
            #[must_use]
            pub fn wrap_slice(slice: &[u32]) -> &[Self] {
                // SAFETY: the layout argument from `unwrap_slice` applies unchanged. In addition,
                // every `u32` bit pattern is a valid `Self`, because the wrapper adds no invariant
                // beyond its single `u32` field.
                unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<Self>(), slice.len()) }
            }
        }

        impl Default for $name {
            /// The default index is the invalid sentinel `u32::MAX`.
            fn default() -> Self {
                Self(u32::MAX)
            }
        }
    };
}
define_index! {
/// Index of a node in [`Module::nodes`]. The `Default` value is the invalid sentinel `u32::MAX`.
#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NodeId(pub u32);
}
define_index! {
/// Index of a link; paired with a [`RegionId`] in [`LinkId`].
/// The `Default` value is the invalid sentinel `u32::MAX`.
#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LinkIndex(pub u32);
}
define_index! {
/// Index of a region in [`Module::regions`]. The `Default` value is the invalid sentinel `u32::MAX`.
#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RegionId(pub u32);
}
define_index! {
/// Index of a term in [`Module::terms`]. The `Default` (invalid) id is resolved to
/// [`Term::Wildcard`] by `Module::get_term`.
#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TermId(pub u32);
}
/// A [`LinkIndex`] paired with a [`RegionId`], displayed as `{region}#{link}`.
/// NOTE(review): presumably the region is the scope the link index is relative to — confirm.
#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[display("{_0}#{_1}")]
pub struct LinkId(pub RegionId, pub LinkIndex);
/// A [`VarIndex`] paired with a [`NodeId`], displayed as `{node}#{index}`.
/// NOTE(review): presumably the node is the one whose parameters the variable refers to — confirm.
#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[display("{_0}#{_1}")]
pub struct VarId(pub NodeId, pub VarIndex);
/// Errors reported for problems found in a table-form model.
///
/// Marked `#[non_exhaustive]`: code outside this crate must include a wildcard arm when
/// matching on it.
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum ModelError {
/// A referenced node id has no entry.
#[error("node not found: {0}")]
NodeNotFound(NodeId),
/// A referenced term id has no entry.
#[error("term not found: {0}")]
TermNotFound(TermId),
/// A referenced region id has no entry.
#[error("region not found: {0}")]
RegionNotFound(RegionId),
/// A variable reference is invalid.
#[error("variable {0} invalid")]
InvalidVar(VarId),
/// A symbol reference to the given node is invalid.
#[error("symbol reference {0} invalid")]
InvalidSymbol(NodeId),
/// The node's operation was not the one expected.
#[error("unexpected operation on node: {0}")]
UnexpectedOperation(NodeId),
/// A term failed type checking.
#[error("type error in term: {0}")]
TypeError(TermId),
/// The node's regions are invalid.
#[error("node has invalid regions: {0}")]
InvalidRegions(NodeId),
/// A name is malformed; the offending name is included.
#[error("malformed name: {0}")]
MalformedName(SmolStr),
/// A condition node is malformed.
#[error("condition node is malformed: {0}")]
MalformedCondition(NodeId),
/// The node's operation is invalid.
#[error("invalid operation on node: {0}")]
InvalidOperation(NodeId),
}