use super::{Context, Expr, ExprRef, GetNode, StringRef};
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::iter::Enumerate;
/// Kind of a signal tracked by a [`TransitionSystem`].
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum SignalKind {
    /// Any signal that is neither an input nor a state.
    Node,
    /// A signal registered through [`TransitionSystem::add_input`].
    Input,
    /// A signal registered through [`TransitionSystem::add_state`].
    State,
}

impl SignalKind {
    /// Returns the lowercase keyword for this kind: `"node"`, `"input"` or `"state"`.
    #[inline]
    pub fn to_string(&self) -> &'static str {
        match *self {
            Self::Node => "node",
            Self::Input => "input",
            Self::State => "state",
        }
    }
}

impl Display for SignalKind {
    /// Writes the keyword returned by [`SignalKind::to_string`] (width/fill flags are ignored).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The inherent `to_string` takes priority over `ToString::to_string`, so no recursion.
        f.write_str(self.to_string())
    }
}
/// Bit set of the labels a signal can carry: `output`, `bad`, `constraint` and `fair`.
///
/// Each label occupies one bit of a `u8`; the `Default` value carries no labels.
#[derive(Copy, Clone, Default, Eq, PartialEq, Hash)]
pub struct SignalLabels(u8);

impl SignalLabels {
    const OUTPUT_BIT: usize = 0;
    const BAD_BIT: usize = 1;
    const CONSTRAINT_BIT: usize = 2;
    const FAIR_BIT: usize = 3;

    /// A label set containing only `output`.
    pub fn output() -> Self {
        Self::set(Self::OUTPUT_BIT)
    }

    /// Returns `true` if the `output` label is set.
    pub fn is_output(&self) -> bool {
        self.get(Self::OUTPUT_BIT)
    }

    /// A label set containing only `bad`.
    pub fn bad() -> Self {
        Self::set(Self::BAD_BIT)
    }

    /// Returns `true` if the `bad` label is set.
    pub fn is_bad(&self) -> bool {
        self.get(Self::BAD_BIT)
    }

    /// A label set containing only `constraint`.
    pub fn constraint() -> Self {
        Self::set(Self::CONSTRAINT_BIT)
    }

    /// Returns `true` if the `constraint` label is set.
    pub fn is_constraint(&self) -> bool {
        self.get(Self::CONSTRAINT_BIT)
    }

    /// A label set containing only `fair`.
    pub fn fair() -> Self {
        Self::set(Self::FAIR_BIT)
    }

    /// Returns `true` if the `fair` label is set.
    pub fn is_fair(&self) -> bool {
        self.get(Self::FAIR_BIT)
    }

    /// Tests bit `pos`.
    #[inline]
    fn get(&self, pos: usize) -> bool {
        (self.0 >> pos) & 1 == 1
    }

    /// Builds a set with only bit `pos` set.
    #[inline]
    fn set(pos: usize) -> Self {
        Self(1 << pos)
    }

    /// Returns `true` if no label is set.
    pub fn is_none(&self) -> bool {
        self.0 == 0
    }

    /// Returns a new set containing every label of `self` and of `other`.
    ///
    /// `self` is not modified.
    #[must_use]
    pub fn union(&self, other: &Self) -> Self {
        Self(self.0 | other.0)
    }

    /// Returns a copy of `self` with every label of `other` removed.
    ///
    /// `self` is not modified. Despite the name, this is not an in-place operation, so the
    /// result must be used.
    #[must_use]
    pub fn clear(&self, other: &Self) -> Self {
        Self(self.0 & !other.0)
    }
}

impl std::str::FromStr for SignalLabels {
    type Err = ();

    /// Parses a single label keyword (`"output"`, `"bad"`, `"constraint"`, `"fair"`).
    ///
    /// Any other string yields `Err(())`.
    fn from_str(label: &str) -> Result<Self, Self::Err> {
        match label {
            "output" => Ok(SignalLabels::output()),
            "bad" => Ok(SignalLabels::bad()),
            "constraint" => Ok(SignalLabels::constraint()),
            "fair" => Ok(SignalLabels::fair()),
            _ => Err(()),
        }
    }
}

impl std::fmt::Debug for SignalLabels {
    /// Formats as `SignalLabels(...)`, listing the set labels separated by single spaces.
    ///
    /// For example, `SignalLabels(output bad)`, or `SignalLabels()` when empty.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let labels = [
            (self.is_output(), "output"),
            (self.is_bad(), "bad"),
            (self.is_constraint(), "constraint"),
            (self.is_fair(), "fair"),
        ];
        f.write_str("SignalLabels(")?;
        // Emit the separator before every label except the first, so there is no trailing space.
        let mut sep = "";
        for (_, name) in labels.iter().filter(|(is_set, _)| *is_set) {
            f.write_str(sep)?;
            f.write_str(name)?;
            sep = " ";
        }
        f.write_str(")")
    }
}
/// Metadata that a [`TransitionSystem`] attaches to an expression acting as a signal.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SignalInfo {
    /// Optional name of the signal.
    pub name: Option<StringRef>,
    /// Whether the signal is a node, an input or a state.
    pub kind: SignalKind,
    /// Set of labels (output/bad/constraint/fair) carried by the signal.
    pub labels: SignalLabels,
}

impl SignalInfo {
    /// Returns `true` when the signal's kind is [`SignalKind::Input`].
    pub fn is_input(&self) -> bool {
        matches!(self.kind, SignalKind::Input)
    }

    /// Returns `true` when the signal's kind is [`SignalKind::State`].
    pub fn is_state(&self) -> bool {
        matches!(self.kind, SignalKind::State)
    }

    /// Returns `true` when the signal carries the `output` label.
    pub fn is_output(&self) -> bool {
        let SignalInfo { labels, .. } = self;
        labels.is_output()
    }
}
/// A state variable of a [`TransitionSystem`].
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct State {
    /// Symbol expression identifying the state.
    pub symbol: ExprRef,
    /// Optional `init` expression (presumably the initial value — confirm with users).
    pub init: Option<ExprRef>,
    /// Optional `next` expression (presumably the next-state function — confirm with users).
    pub next: Option<ExprRef>,
}

impl State {
    /// Returns `true` when `next` is set and is the state's own `symbol` (a self-loop).
    ///
    /// Returns `false` when `next` is `None`.
    pub fn is_const(&self) -> bool {
        self.next == Some(self.symbol)
    }
}
/// Handle to a state of a [`TransitionSystem`]: its position in the system's state list.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub struct StateRef(usize);

impl StateRef {
    /// Returns the position of the referenced state in the state list.
    pub fn to_index(&self) -> usize {
        let StateRef(position) = *self;
        position
    }
}

/// Newtype around a `usize` index, intended as an input handle.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub struct InputRef(usize);
/// A named transition system: a list of states plus signal metadata keyed by expression.
#[derive(PartialEq, Eq, Debug)]
pub struct TransitionSystem {
    /// Name of the system.
    pub name: String,
    /// State variables; a [`StateRef`] is a position in this vector.
    pub(crate) states: Vec<State>,
    /// Signal metadata indexed by `ExprRef::index()`; `None` where no signal is registered.
    signals: Vec<Option<SignalInfo>>,
}
impl TransitionSystem {
    /// Creates an empty transition system called `name`, with no states and no signals.
    pub fn new(name: String) -> Self {
        TransitionSystem {
            name,
            states: Vec::default(),
            signals: Vec::default(),
        }
    }

    /// Grows the signal table so that `id` is a valid index, filling new slots with `None`.
    fn ensure_signal_slot(&mut self, id: usize) {
        if self.signals.len() <= id {
            self.signals.resize(id + 1, None);
        }
    }

    /// Registers (or overwrites) the signal information attached to `expr`.
    ///
    /// The signal table is indexed by `expr.index()` and grows on demand.
    pub fn add_signal(
        &mut self,
        expr: ExprRef,
        kind: SignalKind,
        labels: SignalLabels,
        name: Option<StringRef>,
    ) {
        let id = expr.index();
        self.ensure_signal_slot(id);
        self.signals[id] = Some(SignalInfo { name, kind, labels });
    }

    /// Returns the signal information for `expr`, or `None` if none is registered.
    pub fn get_signal(&self, expr: ExprRef) -> Option<&SignalInfo> {
        self.signals.get(expr.index()).and_then(Option::as_ref)
    }

    /// Clears the signal information attached to `expr`.
    ///
    /// # Panics
    /// Panics if `expr.index()` lies beyond the signal table. A slot that is in range but
    /// already empty is left empty without panicking.
    pub fn remove_signal(&mut self, expr: ExprRef) {
        *self
            .signals
            .get_mut(expr.index())
            .expect("trying to remove non-existing signal") = None;
    }

    /// Moves the signal information registered for `old` over to `new`.
    ///
    /// If `new` already has signal information, the two are combined with
    /// [`merge_signal_info`], using `old`'s info as the original. Afterwards `old` has no
    /// signal information.
    ///
    /// Does nothing if `old == new` or if no signal is registered for `old`. This includes
    /// the case where `old` lies beyond the signal table.
    pub fn update_signal_expr(&mut self, old: ExprRef, new: ExprRef) {
        if old == new {
            return;
        }
        // `get_mut` instead of indexing: an `old` beyond the table simply has no signal.
        // `take` empties the old slot, which is exactly what the move requires.
        let old_info = match self.signals.get_mut(old.index()).and_then(Option::take) {
            Some(info) => info,
            None => return,
        };
        let new_id = new.index();
        self.ensure_signal_slot(new_id);
        let merged = match &self.signals[new_id] {
            Some(new_info) => merge_signal_info(&old_info, new_info),
            None => old_info,
        };
        self.signals[new_id] = Some(merged);
    }

    /// Registers `symbol` as an unlabelled input signal, named after the symbol.
    ///
    /// # Panics
    /// Panics if `symbol` is not a symbol expression.
    pub fn add_input(&mut self, ctx: &impl GetNode<Expr, ExprRef>, symbol: ExprRef) {
        assert!(symbol.is_symbol(ctx));
        let name = symbol.get_symbol_name_ref(ctx);
        self.add_signal(symbol, SignalKind::Input, SignalLabels::default(), name);
    }

    /// Registers `symbol` as an unlabelled state signal named after the symbol.
    ///
    /// Also appends a new [`State`] with no `init`/`next`, and returns its reference.
    ///
    /// # Panics
    /// Panics if `symbol` is not a symbol expression.
    pub fn add_state(&mut self, ctx: &impl GetNode<Expr, ExprRef>, symbol: ExprRef) -> StateRef {
        assert!(symbol.is_symbol(ctx));
        let name = symbol.get_symbol_name_ref(ctx);
        self.add_signal(symbol, SignalKind::State, SignalLabels::default(), name);
        let id = self.states.len();
        self.states.push(State {
            symbol,
            init: None,
            next: None,
        });
        StateRef(id)
    }

    /// Returns the first state whose symbol is named `name`, if any.
    ///
    /// States whose `symbol` has no symbol name are skipped rather than causing a panic.
    pub fn get_state_by_name(&self, ctx: &Context, name: &str) -> Option<&State> {
        self.states
            .iter()
            .find(|s| s.symbol.get_symbol_name(ctx).map_or(false, |n| n == name))
    }

    /// Applies `modify` to the state referenced by `reference`.
    ///
    /// # Panics
    /// Panics if `reference` does not point to an existing state.
    pub fn modify_state<F>(&mut self, reference: StateRef, modify: F)
    where
        F: FnOnce(&mut State),
    {
        let state = self
            .states
            .get_mut(reference.0)
            .expect("invalid state reference");
        modify(state)
    }

    /// Iterates over all states together with their references, in storage order.
    pub fn states(&self) -> StateIter<'_> {
        StateIter {
            underlying: self.states.iter().enumerate(),
        }
    }

    /// Maps each state's symbol to the state.
    ///
    /// If several states share a symbol, the later one wins.
    pub fn state_map(&self) -> HashMap<ExprRef, &State> {
        self.states.iter().map(|s| (s.symbol, s)).collect()
    }

    /// Removes and returns the state referenced by `state`.
    ///
    /// Every later state moves down by one position, so outstanding [`StateRef`]s past
    /// `state` now refer to different states.
    ///
    /// # Panics
    /// Panics if `state` is out of range.
    pub fn remove_state(&mut self, state: StateRef) -> State {
        self.states.remove(state.0)
    }

    /// Returns `(expression, info)` pairs for every registered signal accepted by `filter`.
    ///
    /// Pairs are ordered by expression index.
    pub fn get_signals(&self, filter: fn(&SignalInfo) -> bool) -> Vec<(ExprRef, SignalInfo)> {
        self.signals
            .iter()
            .enumerate()
            .filter_map(|(index, slot)| {
                slot.as_ref()
                    .filter(|info| filter(info))
                    .map(|info| (ExprRef::from_index(index), info.clone()))
            })
            .collect()
    }

    /// Returns all signals carrying the `constraint` label.
    pub fn constraints(&self) -> Vec<(ExprRef, SignalInfo)> {
        self.get_signals(|info| info.labels.is_constraint())
    }

    /// Returns all signals carrying the `bad` label.
    pub fn bad_states(&self) -> Vec<(ExprRef, SignalInfo)> {
        self.get_signals(|info| info.labels.is_bad())
    }

    /// Builds a map from names to expressions for every registered signal.
    ///
    /// Unlabelled [`SignalKind::Node`] signals are excluded. For each included signal, the
    /// signal's own name and, if the expression is a symbol, its symbol name are both
    /// recorded. When names collide, the entry with the higher expression index wins.
    pub fn generate_name_to_ref(&self, ctx: &Context) -> HashMap<String, ExprRef> {
        let mut out = HashMap::new();
        for (idx, maybe_signal) in self.signals.iter().enumerate() {
            if let Some(signal) = maybe_signal {
                let skip = signal.kind == SignalKind::Node && signal.labels.is_none();
                if !skip {
                    let expr_ref = ExprRef::from_index(idx);
                    if let Some(name) = signal.name {
                        let name_str = ctx.get(name).to_string();
                        out.insert(name_str, expr_ref);
                    }
                    if let Some(name) = expr_ref.get_symbol_name(ctx) {
                        out.insert(name.to_string(), expr_ref);
                    }
                }
            }
        }
        out
    }
}
/// Iterator over the states of a [`TransitionSystem`], each paired with its [`StateRef`].
///
/// Created by [`TransitionSystem::states`].
pub struct StateIter<'a> {
    underlying: Enumerate<std::slice::Iter<'a, State>>,
}

impl<'a> Iterator for StateIter<'a> {
    type Item = (StateRef, &'a State);

    fn next(&mut self) -> Option<Self::Item> {
        self.underlying.next().map(|(i, s)| (StateRef(i), s))
    }

    // Forward the exact remaining length so `collect` and friends can preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.underlying.size_hint()
    }
}

impl<'a> DoubleEndedIterator for StateIter<'a> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.underlying.next_back().map(|(i, s)| (StateRef(i), s))
    }
}

// `Enumerate<slice::Iter>` reports an exact size, so `len()` is correct.
impl ExactSizeIterator for StateIter<'_> {}

// `Enumerate<slice::Iter>` is fused: once exhausted it keeps returning `None`.
impl std::iter::FusedIterator for StateIter<'_> {}
/// Combines the info of a signal (`original`) with the info of an alias for it.
///
/// - Name: whichever side has one. If both do, `original`'s name is kept when it carries
///   the `output` label; otherwise `alias`'s name is used.
/// - Kind: `alias`'s kind if `original` is a [`SignalKind::Node`], else `original`'s kind.
/// - Labels: the union of both label sets.
pub fn merge_signal_info(original: &SignalInfo, alias: &SignalInfo) -> SignalInfo {
    let name = match (original.name, alias.name) {
        (Some(kept), Some(_)) if original.labels.is_output() => Some(kept),
        (from_original, from_alias) => from_alias.or(from_original),
    };
    let kind = if original.kind == SignalKind::Node {
        alias.kind
    } else {
        original.kind
    };
    SignalInfo {
        name,
        kind,
        labels: alias.labels.union(&original.labels),
    }
}
impl GetNode<SignalInfo, ExprRef> for TransitionSystem {
fn get(&self, reference: ExprRef) -> &SignalInfo {
self.signals[reference.index()].as_ref().unwrap()
}
}
impl GetNode<State, StateRef> for TransitionSystem {
fn get(&self, reference: StateRef) -> &State {
&self.states[reference.0]
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::size_of;

    /// Pins the in-memory sizes of the signal metadata types so layout regressions are caught.
    #[test]
    fn ir_type_size() {
        assert_eq!(size_of::<SignalKind>(), 1);
        assert_eq!(size_of::<SignalLabels>(), 1);
        assert_eq!(size_of::<SignalInfo>(), 8);
        // `Option<SignalInfo>` must not grow beyond `SignalInfo` itself.
        assert_eq!(size_of::<Option<SignalInfo>>(), 8);
    }
}