use std::{collections::HashMap, sync::Arc};
use crate::{
ast::*,
term_arena::{self, TermArena},
};
/// Something that can intern symbol names, handing out one [`Sym`] per distinct name.
pub trait SymbolStorage {
    /// Returns the symbol already associated with `name`, allocating and
    /// recording a new one when `name` has not been seen before.
    fn get_or_insert_named(&mut self, name: &str) -> Sym;

    /// Interns the name of every `(name, value)` pair and gathers the values
    /// into a map keyed by the resulting symbols.
    ///
    /// When several pairs share a name, the value from the last such pair wins.
    fn build_sym_map<'a, T>(
        &mut self,
        pairs: impl IntoIterator<Item = (&'a str, T)>,
    ) -> HashMap<Sym, T> {
        let pairs = pairs.into_iter();
        let mut by_sym = HashMap::with_capacity(pairs.size_hint().0);
        for (name, value) in pairs {
            let sym = self.get_or_insert_named(name);
            by_sym.insert(sym, value);
        }
        by_sym
    }
}
/// Lets a mutable borrow of any storage be used wherever a `SymbolStorage`
/// is taken by value; every call is forwarded to the borrowed storage.
impl<S> SymbolStorage for &mut S
where
    S: SymbolStorage,
{
    fn get_or_insert_named(&mut self, name: &str) -> Sym {
        // `self` is `&mut &mut S`; reborrow the inner storage and delegate to its impl.
        let inner: &mut S = self;
        inner.get_or_insert_named(name)
    }
}
/// Read-only view of a set of symbols and their (optional) names.
pub trait Symbols {
    /// Total number of symbols that have been allocated, named or not.
    fn num_symbols(&self) -> usize;

    /// Returns the name attached to `sym`, or `None` when the symbol has no
    /// name or is not known to this view.
    fn get_symbol_name(&self, sym: Sym) -> Option<&str>;
}
/// Interning table mapping symbol names to [`Sym`] values and back.
///
/// Every new symbol, named or unnamed, is numbered with the current length of
/// `name_by_sym`, so ords are dense and follow insertion order.
#[derive(Debug, Clone)]
pub struct SymbolStore {
    /// Lookup from a name to the symbol it was interned as. Unnamed symbols have no entry here.
    sym_by_name: HashMap<Arc<str>, Sym>,
    /// Indexed by `Sym::ord()`; `None` for symbols created by `insert_unnamed`.
    /// The `Arc<str>` is shared with the matching key in `sym_by_name`.
    name_by_sym: Vec<Option<Arc<str>>>,
}
impl SymbolStore {
pub fn new() -> Self {
Self {
sym_by_name: HashMap::new(),
name_by_sym: Vec::new(),
}
}
pub fn insert_unnamed(&mut self) -> Sym {
let sym = Sym::from_ord(self.name_by_sym.len());
self.name_by_sym.push(None);
sym
}
}
impl Symbols for SymbolStore {
    /// Looks up the name slot for `sym`. Returns `None` both for unnamed
    /// symbols and for ords past the end of the store.
    fn get_symbol_name(&self, sym: Sym) -> Option<&str> {
        let slot = self.name_by_sym.get(sym.ord())?;
        slot.as_deref()
    }

    /// Counts every allocated symbol, including unnamed ones.
    fn num_symbols(&self) -> usize {
        let slots = &self.name_by_sym;
        slots.len()
    }
}
impl SymbolStorage for SymbolStore {
    /// Returns the existing symbol for `name`. Otherwise it appends a new named
    /// symbol whose ord is the current symbol count.
    fn get_or_insert_named(&mut self, name: &str) -> Sym {
        if let Some(&existing) = self.sym_by_name.get(name) {
            return existing;
        }
        let fresh = Sym::from_ord(self.name_by_sym.len());
        // A single allocation of the name is shared by both directions of the table.
        let interned: Arc<str> = Arc::from(name);
        self.name_by_sym.push(Some(Arc::clone(&interned)));
        self.sym_by_name.insert(interned, fresh);
        fresh
    }
}
impl Default for SymbolStore {
fn default() -> Self {
Self::new()
}
}
/// A growable symbol view layered over a borrowed, read-only [`SymbolStore`].
///
/// Names already present in the base store resolve to their base symbols.
/// New names are interned in a private overlay store and numbered after all of
/// the base store's symbols. The base store itself is only read.
#[derive(Debug, Clone)]
pub struct SymbolOverlay<'a> {
    /// The shared base store. It is held by shared reference and never mutated through the overlay.
    symbols: &'a SymbolStore,
    /// Symbols added through this overlay. Their ords here are relative to the
    /// end of `symbols`, so overlay ord `i` corresponds to global ord
    /// `symbols.num_symbols() + i`.
    overlay: SymbolStore,
}
impl<'a> SymbolOverlay<'a> {
pub fn new(symbols: &'a SymbolStore) -> Self {
Self {
symbols,
overlay: Default::default(),
}
}
}
impl SymbolStorage for SymbolOverlay<'_> {
    /// Resolves `name` against the base store first. Only unknown names are
    /// interned in the overlay, and their ords are shifted past the base store's symbols.
    fn get_or_insert_named(&mut self, name: &str) -> Sym {
        if let Some(&base_sym) = self.symbols.sym_by_name.get(name) {
            return base_sym;
        }
        let offset = self.symbols.num_symbols();
        let local = self.overlay.get_or_insert_named(name);
        Sym::from_ord(offset + local.ord())
    }
}
impl Symbols for SymbolOverlay<'_> {
    /// Ords below the base store's symbol count are answered by the base store.
    /// Higher ords are shifted down and answered by the overlay store.
    fn get_symbol_name(&self, sym: Sym) -> Option<&str> {
        let base_len = self.symbols.num_symbols();
        let ord = sym.ord();
        if ord < base_len {
            self.symbols.get_symbol_name(sym)
        } else {
            self.overlay.get_symbol_name(Sym::from_ord(ord - base_len))
        }
    }

    /// Base symbols plus every symbol added through the overlay.
    fn num_symbols(&self) -> usize {
        self.symbols.num_symbols() + self.overlay.num_symbols()
    }
}
/// A [`Rule`] whose head and tail terms have been inserted into [`TermArena`]
/// blueprints, stored together with the original, uncompiled rule.
#[derive(Debug, Clone)]
pub struct CompiledRule {
    /// Arena holding only the rule's head term.
    head_blueprint: TermArena,
    /// Id of the head term inside `head_blueprint`.
    head: term_arena::TermId,
    /// Arena holding every tail term of the rule.
    tail_blueprint: TermArena,
    /// Ids of the tail terms inside `tail_blueprint`, in the same order as `original.tail`.
    tail: Vec<term_arena::TermId>,
    /// Largest `count_var_slots()` among the head and all tail terms.
    /// Presumably the number of variable slots an instantiation needs (TODO confirm).
    var_slots: usize,
    /// The rule this was compiled from.
    original: Rule,
}
impl CompiledRule {
    /// Compiles `rule` by inserting its head into one [`TermArena`] and all of
    /// its tail terms, in order, into a second one. The original rule is kept alongside.
    ///
    /// `var_slots` is the largest `count_var_slots()` reported by the head or
    /// any tail term. When the tail is empty, it is the head's count.
    pub fn new(rule: Rule) -> CompiledRule {
        // One scratch buffer is threaded through every insertion below.
        let mut scratch = Vec::new();
        let mut head_blueprint = TermArena::new();
        let mut tail_blueprint = TermArena::new();

        let head = head_blueprint.insert_ast_appterm(&mut scratch, &rule.head);
        let tail = rule
            .tail
            .iter()
            .map(|goal| tail_blueprint.insert_ast_term(&mut scratch, goal))
            .collect::<Vec<_>>();

        // Seed with the head's count so an empty tail falls back to it.
        let var_slots = rule
            .tail
            .iter()
            .fold(rule.head.count_var_slots(), |widest, goal| {
                widest.max(goal.count_var_slots())
            });

        CompiledRule {
            head_blueprint,
            head,
            tail_blueprint,
            tail,
            var_slots,
            original: rule,
        }
    }

    /// The head term's id together with the arena it lives in.
    #[inline(always)]
    pub fn head(&self) -> (term_arena::TermId, &TermArena) {
        let Self {
            head,
            head_blueprint,
            ..
        } = self;
        (*head, head_blueprint)
    }

    /// The tail terms' ids, in rule order, together with the arena they live in.
    #[inline(always)]
    pub fn tail(&self) -> (&[term_arena::TermId], &TermArena) {
        (self.tail.as_slice(), &self.tail_blueprint)
    }

    /// The largest variable-slot count computed in [`CompiledRule::new`].
    #[inline(always)]
    pub fn var_slots(&self) -> usize {
        self.var_slots
    }

    /// The rule this was compiled from.
    pub fn original(&self) -> &Rule {
        &self.original
    }
}
/// Compiled rules grouped into buckets by the functor symbol of their head.
#[derive(Debug)]
pub struct RuleSet {
    /// Indexed by `Sym::ord()` of a rule's head functor. Each bucket holds the
    /// compiled rules for that functor in insertion order.
    rules_by_head: Vec<Vec<CompiledRule>>,
}
impl RuleSet {
pub fn new() -> Self {
Self {
rules_by_head: Vec::new(),
}
}
pub fn with_capacity(capacity: usize) -> Self {
Self {
rules_by_head: vec![Vec::new(); capacity],
}
}
pub fn insert(&mut self, rule: Rule) {
let head = rule.head.functor;
self.ensure_capacity(head);
let compiled = CompiledRule::new(rule);
self.rules_by_head[head.ord()].push(compiled);
}
#[inline(always)]
pub fn rules_by_head(&self, head: Sym) -> &[CompiledRule] {
if head.ord() < self.rules_by_head.len() {
&self.rules_by_head[head.ord()]
} else {
&[]
}
}
fn ensure_capacity(&mut self, sym: Sym) {
if sym.ord() >= self.rules_by_head.len() {
self.rules_by_head.resize(sym.ord() + 1, Vec::new());
}
}
}
impl Default for RuleSet {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod test {
    use crate::{ast::Sym, Symbols};

    use super::{SymbolOverlay, SymbolStorage, SymbolStore};

    /// Checks that a `SymbolOverlay` over a snapshot of a store hands out the
    /// same symbols and names as continuing to insert into the store itself.
    #[test]
    fn overlay() {
        let mut plain = SymbolStore::new();
        // Three unnamed symbols (ords 0..=2) followed by "a" (ord 3).
        plain.insert_unnamed();
        plain.insert_unnamed();
        plain.insert_unnamed();
        let a = plain.get_or_insert_named("a");

        let overlay_copy = plain.clone();
        let mut overlaid = SymbolOverlay::new(&overlay_copy);

        // A name already in the base resolves to the base symbol and does not grow the overlay.
        assert_eq!(overlaid.get_or_insert_named("a"), a);
        assert_eq!(plain.num_symbols(), overlaid.num_symbols());

        // New names get ords that continue after the base store's symbols.
        let b = plain.get_or_insert_named("b");
        assert_eq!(b, overlaid.get_or_insert_named("b"));
        let c = plain.get_or_insert_named("c");
        assert_eq!(c, overlaid.get_or_insert_named("c"));

        // Looking up an overlay name again returns the same symbol.
        assert_eq!(overlaid.get_or_insert_named("b"), b);
        assert_eq!(plain.num_symbols(), overlaid.num_symbols());

        // Unnamed base, named base, overlay, just-out-of-range and far-out-of-range
        // ords all report the same name (or lack of one) through both views.
        for ord in [0, 3, 4, 5, 6, 99] {
            assert_eq!(
                plain.get_symbol_name(Sym::from_ord(ord)),
                overlaid.get_symbol_name(Sym::from_ord(ord)),
                "name mismatch for ord {}",
                ord
            );
        }
        assert_eq!(overlaid.get_symbol_name(Sym::from_ord(5)), Some("c"));
    }
}