use std::{borrow::Cow, collections::BTreeSet, hash::BuildHasherDefault};
use indexmap::IndexMap;
use rustc_hash::FxHasher;
use thiserror::Error;
use crate::v0::table::{NodeId, RegionId};
/// An insertion-ordered [`IndexMap`] hashed with the fast, non-cryptographic
/// [`FxHasher`].
///
/// Insertion order matters here: the symbol table pops entries off the end of
/// these maps to use them as stacks.
type FxIndexMap<Key, Value> = IndexMap<Key, Value, BuildHasherDefault<FxHasher>>;
/// The textual name of a symbol, borrowed for the lifetime `'src`.
pub type SymbolName<'src> = &'src str;
/// Scoped table mapping symbol names, optionally qualified by a version, to the
/// nodes that bind them.
///
/// Scopes are opened and closed in stack order. Closing a scope removes every
/// binding introduced inside it and re-exposes any bindings those had shadowed.
#[derive(Clone, Debug, Default)]
pub struct SymbolTable<'a> {
    /// Maps each bound `(name, version)` key to the index in `bindings` of the
    /// binding currently visible under that key.
    symbols: FxIndexMap<SymbolKey<'a>, BindingIndex>,
    /// Stack of live bindings, keyed by the node that introduced each one.
    bindings: FxIndexMap<NodeId, Binding>,
    /// Stack of open scopes, keyed by the region each scope belongs to.
    scopes: FxIndexMap<RegionId, Scope>,
    /// For each name, the set of versions it is currently bound at. Used to pick
    /// the latest version when a lookup does not specify one.
    latest_versioned: FxIndexMap<SymbolName<'a>, BTreeSet<semver::Version>>,
}
impl<'a> SymbolTable<'a> {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn enter(&mut self, region: RegionId) {
self.scopes.insert(
region,
Scope {
binding_stack: self.bindings.len(),
},
);
}
pub fn exit(&mut self) {
let (_, scope) = self.scopes.pop().unwrap();
for _ in scope.binding_stack..self.bindings.len() {
let (_, binding) = self.bindings.pop().unwrap();
let key = self
.symbols
.get_index(binding.symbol_index)
.expect("Symbol must be present in version table")
.0
.clone();
if let Some(shadows) = binding.shadows {
self.symbols[binding.symbol_index] = shadows;
} else {
let last = self.symbols.pop();
debug_assert_eq!(last.unwrap().1, self.bindings.len());
self.remove_latest_versioned(&key);
}
}
}
pub fn insert(
&mut self,
name: SymbolName<'a>,
version: Option<&semver::Version>,
node: NodeId,
) -> Result<(), DuplicateSymbolError<'_>> {
self.insert_binding(name, version, node)
}
pub fn insert_import(
&mut self,
name: SymbolName<'a>,
version: Option<&semver::Version>,
node: NodeId,
) -> Result<(), DuplicateSymbolError<'_>> {
self.insert_binding(name, version, node)
}
fn insert_binding(
&mut self,
name: SymbolName<'a>,
version: Option<&semver::Version>,
node: NodeId,
) -> Result<(), DuplicateSymbolError<'_>> {
let key = SymbolKey::new(name, version);
let scope_depth = self.scopes.len() as u16 - 1;
let (symbol_index, shadowed) = self.symbols.insert_full(key.clone(), self.bindings.len());
if let Some(shadowed) = shadowed {
let (shadowed_node, shadowed_binding) = self.bindings.get_index(shadowed).unwrap();
if shadowed_binding.scope_depth == scope_depth {
self.symbols.insert(key, shadowed);
return Err(DuplicateSymbolError(name.into(), node, *shadowed_node));
}
}
self.insert_latest_versioned(&key);
self.bindings.insert(
node,
Binding {
scope_depth,
shadows: shadowed,
symbol_index,
},
);
Ok(())
}
#[must_use]
pub fn is_visible(&self, node: NodeId) -> bool {
let Some(binding) = self.bindings.get(&node) else {
return false;
};
self.symbols[binding.symbol_index] == binding.symbol_index
}
pub fn resolve(
&self,
name: SymbolName<'a>,
version: Option<&semver::Version>,
) -> Result<NodeId, UnknownSymbolError<'_>> {
let index = match version {
Some(version) => self
.symbols
.get(&SymbolKey::new(name, Some(version)))
.copied(),
None => self
.symbols
.get(&SymbolKey::new(name, None))
.copied()
.or_else(|| self.latest_versioned(name)),
}
.ok_or(UnknownSymbolError(name.into()))?;
let (node, _) = self.bindings.get_index(index).unwrap();
Ok(*node)
}
fn latest_versioned(&self, name: SymbolName<'a>) -> Option<BindingIndex> {
let latest = self.latest_versioned.get(name)?.last()?;
self.symbols
.get(&SymbolKey::new(name, Some(latest)))
.copied()
}
fn insert_latest_versioned(&mut self, key: &SymbolKey<'a>) {
let Some(version) = &key.version else {
return;
};
self.latest_versioned
.entry(key.name)
.or_default()
.insert(version.clone());
}
fn remove_latest_versioned(&mut self, key: &SymbolKey<'a>) {
let Some(version) = &key.version else {
return;
};
let Some(versions) = self.latest_versioned.get_mut(key.name) else {
return;
};
versions.remove(version);
if versions.is_empty() {
self.latest_versioned.swap_remove(key.name);
}
}
#[must_use]
pub fn region_to_depth(&self, region: RegionId) -> Option<ScopeDepth> {
Some(self.scopes.get_index_of(®ion)? as _)
}
#[must_use]
pub fn depth_to_region(&self, depth: ScopeDepth) -> Option<RegionId> {
let (region, _) = self.scopes.get_index(depth as _)?;
Some(*region)
}
pub fn clear(&mut self) {
self.symbols.clear();
self.bindings.clear();
self.scopes.clear();
self.latest_versioned.clear();
}
}
/// Lookup key of the symbol table: a name plus an optional version.
///
/// Keys with the same name but different versions (or no version) are distinct.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct SymbolKey<'a> {
    /// The symbol's name.
    name: SymbolName<'a>,
    /// The symbol's version, or `None` for an unversioned symbol.
    version: Option<semver::Version>,
}
impl<'a> SymbolKey<'a> {
    /// Builds a key from a borrowed name and an optional borrowed version.
    ///
    /// The version, if present, is cloned so that the key owns it.
    fn new(name: SymbolName<'a>, version: Option<&semver::Version>) -> Self {
        let owned_version = version.map(semver::Version::clone);
        Self {
            version: owned_version,
            name,
        }
    }
}
/// A single entry on the binding stack of a [`SymbolTable`].
#[derive(Clone, Copy, Debug)]
struct Binding {
    /// Depth of the scope in which this binding was introduced.
    scope_depth: ScopeDepth,
    /// Index (in the bindings stack) of the binding this one hides for the same
    /// key, restored when this binding is removed. `None` if this binding
    /// introduced the key.
    shadows: Option<BindingIndex>,
    /// Index of this binding's key in the symbols map.
    symbol_index: SymbolIndex,
}
/// Bookkeeping for one open scope of a [`SymbolTable`].
#[derive(Clone, Copy, Debug)]
struct Scope {
    /// Height of the binding stack when the scope was entered. Bindings at or
    /// above this height belong to the scope.
    binding_stack: usize,
}
/// Position of a binding in the table's binding stack.
type BindingIndex = usize;
/// Position of a key in the table's symbol map.
type SymbolIndex = usize;
/// Nesting depth of a scope; the outermost open scope has depth 0.
pub type ScopeDepth = u16;
#[derive(Debug, Clone, Error)]
#[error("symbol name `{0}` not found in this scope")]
pub struct UnknownSymbolError<'a>(pub Cow<'a, str>);
#[derive(Debug, Clone, Error)]
#[error("symbol `{0}` is already defined in this scope")]
pub struct DuplicateSymbolError<'a>(pub Cow<'a, str>, pub NodeId, pub NodeId);
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::{fixture, rstest};

    /// Symbol name shared by every test below.
    const ADD: &str = "std.int.add";

    /// Which public insertion method a parameterised test exercises.
    #[derive(Debug, Clone, Copy)]
    enum BindingSource {
        Import,
        Symbol,
    }

    impl BindingSource {
        /// Inserts `node` under `name`/`version` through the selected method,
        /// panicking if the insertion is rejected.
        fn insert(
            self,
            symbols: &mut SymbolTable<'static>,
            name: &'static str,
            version: Option<&semver::Version>,
            node: NodeId,
        ) {
            let outcome = match self {
                Self::Import => symbols.insert_import(name, version, node),
                Self::Symbol => symbols.insert(name, version, node),
            };
            outcome.unwrap();
        }
    }

    /// A table with the root region `RegionId(0)` already entered.
    #[fixture]
    fn symbols() -> SymbolTable<'static> {
        let mut table = SymbolTable::new();
        table.enter(RegionId(0));
        table
    }

    /// Parses a semantic version, panicking on malformed input.
    fn version(text: &str) -> semver::Version {
        semver::Version::parse(text).unwrap()
    }

    /// Binds `ADD@outer` at the root and `ADD@inner` inside `RegionId(1)`, then
    /// checks that an unversioned lookup sees the inner binding until that region
    /// is exited, and the outer binding afterwards.
    fn assert_exit_restores_outer(
        source: BindingSource,
        mut symbols: SymbolTable<'static>,
        outer: &semver::Version,
        inner: &semver::Version,
    ) {
        source.insert(&mut symbols, ADD, Some(outer), NodeId(0));
        symbols.enter(RegionId(1));
        source.insert(&mut symbols, ADD, Some(inner), NodeId(1));
        assert_eq!(symbols.resolve(ADD, None).unwrap(), NodeId(1));

        symbols.exit();
        assert_eq!(symbols.resolve(ADD, None).unwrap(), NodeId(0));
    }

    #[rstest]
    #[case::imports(BindingSource::Import)]
    #[case::symbols(BindingSource::Symbol)]
    fn latest_version(
        #[case] source: BindingSource,
        #[from(symbols)] mut symbols: SymbolTable<'static>,
    ) {
        let (older, newer) = (version("1.2.3"), version("1.3.0"));
        source.insert(&mut symbols, ADD, Some(&older), NodeId(0));
        source.insert(&mut symbols, ADD, Some(&newer), NodeId(1));

        // Unversioned lookups pick the greatest bound version...
        assert_eq!(symbols.resolve(ADD, None).unwrap(), NodeId(1));
        // ...while versioned lookups are exact.
        for (wanted, expected) in [(&older, NodeId(0)), (&newer, NodeId(1))] {
            assert_eq!(symbols.resolve(ADD, Some(wanted)).unwrap(), expected);
        }
    }

    #[rstest]
    #[case::imports(BindingSource::Import)]
    #[case::symbols(BindingSource::Symbol)]
    fn single_version(
        #[case] source: BindingSource,
        #[from(symbols)] mut symbols: SymbolTable<'static>,
    ) {
        let only = version("1.2.3");
        source.insert(&mut symbols, ADD, Some(&only), NodeId(0));
        assert_eq!(symbols.resolve(ADD, None).unwrap(), NodeId(0));
    }

    #[rstest]
    fn unversioned_is_exact(#[from(symbols)] mut symbols: SymbolTable<'static>) {
        let v = version("1.2.3");
        symbols.insert(ADD, None, NodeId(0)).unwrap();
        symbols.insert_import(ADD, Some(&v), NodeId(1)).unwrap();

        // An unversioned binding wins over versioned ones for unversioned lookups.
        assert_eq!(symbols.resolve(ADD, None).unwrap(), NodeId(0));
        assert_eq!(symbols.resolve(ADD, Some(&v)).unwrap(), NodeId(1));
    }

    #[rstest]
    #[case::imports(BindingSource::Import)]
    #[case::symbols(BindingSource::Symbol)]
    fn latest_after_exit(
        #[case] source: BindingSource,
        #[from(symbols)] symbols: SymbolTable<'static>,
    ) {
        assert_exit_restores_outer(source, symbols, &version("1.2.3"), &version("1.3.0"));
    }

    #[rstest]
    #[case::imports(BindingSource::Import)]
    #[case::symbols(BindingSource::Symbol)]
    fn latest_shadowed(
        #[case] source: BindingSource,
        #[from(symbols)] symbols: SymbolTable<'static>,
    ) {
        let v = version("1.2.3");
        assert_exit_restores_outer(source, symbols, &v, &v);
    }
}