use indexmap::IndexMap;
use super::value::{EvalError, EvalResult, Value};
/// A stack of lexical scopes, searched from the innermost (last element) to
/// the outermost (first element) when resolving a name.
///
/// Both constructors create exactly one scope and `Scopes::exit` refuses to
/// pop the last remaining one, so through this API the stack is never empty.
#[derive(Clone, Debug)]
pub struct Scopes {
    /// Index 0 is the outermost scope; the last element is the innermost.
    stack: Vec<Scope>,
}
/// A single lexical scope: a name-to-value map that preserves the order in
/// which names were first bound.
#[derive(Clone, Debug, Default)]
pub struct Scope {
    /// Bindings of this scope, in first-definition order.
    bindings: IndexMap<String, Value>,
}
impl Scope {
pub fn new() -> Self {
Self {
bindings: IndexMap::new(),
}
}
pub fn from_captures(captures: IndexMap<String, Value>) -> Self {
Self { bindings: captures }
}
pub fn define(&mut self, name: impl Into<String>, value: Value) {
self.bindings.insert(name.into(), value);
}
pub fn get(&self, name: &str) -> Option<&Value> {
self.bindings.get(name)
}
pub fn get_mut(&mut self, name: &str) -> Option<&mut Value> {
self.bindings.get_mut(name)
}
pub fn contains(&self, name: &str) -> bool {
self.bindings.contains_key(name)
}
pub fn bindings(&self) -> &IndexMap<String, Value> {
&self.bindings
}
pub fn into_bindings(self) -> IndexMap<String, Value> {
self.bindings
}
}
impl Scopes {
pub fn new() -> Self {
Self {
stack: vec![Scope::new()],
}
}
pub fn with_stdlib(stdlib: IndexMap<String, Value>) -> Self {
Self {
stack: vec![Scope::from_captures(stdlib)],
}
}
pub fn enter(&mut self) {
self.stack.push(Scope::new());
}
pub fn enter_with_captures(&mut self, captures: IndexMap<String, Value>) {
self.stack.push(Scope::from_captures(captures));
}
pub fn exit(&mut self) -> Option<Scope> {
if self.stack.len() > 1 {
self.stack.pop()
} else {
None
}
}
pub fn define(&mut self, name: impl Into<String>, value: Value) {
if let Some(scope) = self.stack.last_mut() {
scope.define(name, value);
}
}
pub fn get(&self, name: &str) -> Option<&Value> {
for scope in self.stack.iter().rev() {
if let Some(value) = scope.get(name) {
return Some(value);
}
}
None
}
pub fn get_or_err(&self, name: &str) -> EvalResult<&Value> {
self.get(name)
.ok_or_else(|| EvalError::undefined(name.to_string()))
}
pub fn get_mut(&mut self, name: &str) -> Option<&mut Value> {
for scope in self.stack.iter_mut().rev() {
if scope.contains(name) {
return scope.get_mut(name);
}
}
None
}
pub fn assign(&mut self, name: &str, value: Value) -> EvalResult<()> {
for scope in self.stack.iter_mut().rev() {
if scope.contains(name) {
scope.define(name, value);
return Ok(());
}
}
Err(EvalError::undefined(name.to_string()))
}
pub fn contains(&self, name: &str) -> bool {
self.get(name).is_some()
}
pub fn depth(&self) -> usize {
self.stack.len()
}
pub fn current(&self) -> Option<&Scope> {
self.stack.last()
}
pub fn current_mut(&mut self) -> Option<&mut Scope> {
self.stack.last_mut()
}
pub fn top_bindings(&self) -> IndexMap<String, Value> {
self.stack
.last()
.map(|s| s.bindings.clone())
.unwrap_or_default()
}
pub fn capture_all(&self) -> IndexMap<String, Value> {
let mut captures = IndexMap::new();
for scope in &self.stack {
for (name, value) in scope.bindings() {
captures.insert(name.clone(), value.clone());
}
}
captures
}
pub fn capture(&self, names: &[&str]) -> IndexMap<String, Value> {
let mut captures = IndexMap::new();
for name in names {
if let Some(value) = self.get(name) {
captures.insert((*name).to_string(), value.clone());
}
}
captures
}
}
impl Default for Scopes {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scope_define_and_get() {
        let mut scopes = Scopes::new();
        scopes.define("x", Value::Int(42));
        assert_eq!(scopes.get("x"), Some(&Value::Int(42)));
        assert_eq!(scopes.get("y"), None);
    }

    #[test]
    fn test_scope_shadowing() {
        let mut scopes = Scopes::new();
        scopes.define("x", Value::Int(1));
        scopes.enter();
        scopes.define("x", Value::Int(2));
        assert_eq!(scopes.get("x"), Some(&Value::Int(2)));
        scopes.exit();
        assert_eq!(scopes.get("x"), Some(&Value::Int(1)));
    }

    #[test]
    fn test_scope_nested_lookup() {
        let mut scopes = Scopes::new();
        scopes.define("outer", Value::Int(1));
        scopes.enter();
        scopes.define("inner", Value::Int(2));
        assert_eq!(scopes.get("outer"), Some(&Value::Int(1)));
        assert_eq!(scopes.get("inner"), Some(&Value::Int(2)));
        scopes.exit();
        assert_eq!(scopes.get("outer"), Some(&Value::Int(1)));
        assert_eq!(scopes.get("inner"), None);
    }

    #[test]
    fn test_capture_all() {
        let mut scopes = Scopes::new();
        scopes.define("a", Value::Int(1));
        scopes.enter();
        scopes.define("b", Value::Int(2));
        scopes.define("a", Value::Int(10));
        let captures = scopes.capture_all();
        assert_eq!(captures.get("a"), Some(&Value::Int(10)));
        assert_eq!(captures.get("b"), Some(&Value::Int(2)));
        // A shadowed name keeps the position of its outermost definition.
        let keys: Vec<&str> = captures.keys().map(String::as_str).collect();
        assert_eq!(keys, ["a", "b"]);
    }

    #[test]
    fn test_assign_updates_innermost_existing_binding() {
        let mut scopes = Scopes::new();
        scopes.define("x", Value::Int(1));
        scopes.enter();
        assert!(scopes.assign("x", Value::Int(5)).is_ok());
        // `assign` must update the outer binding, not create an inner one.
        assert!(!scopes
            .current()
            .expect("stack is never empty")
            .contains("x"));
        scopes.exit();
        assert_eq!(scopes.get("x"), Some(&Value::Int(5)));
    }

    #[test]
    fn test_assign_undefined_is_error() {
        let mut scopes = Scopes::new();
        assert!(scopes.assign("missing", Value::Int(1)).is_err());
        assert_eq!(scopes.get("missing"), None);
    }

    #[test]
    fn test_exit_keeps_root_scope() {
        let mut scopes = Scopes::new();
        assert_eq!(scopes.depth(), 1);
        assert!(scopes.exit().is_none());
        assert_eq!(scopes.depth(), 1);
        scopes.define("x", Value::Int(1));
        assert_eq!(scopes.get("x"), Some(&Value::Int(1)));
    }

    #[test]
    fn test_get_mut_modifies_innermost() {
        let mut scopes = Scopes::new();
        scopes.define("x", Value::Int(1));
        scopes.enter();
        scopes.define("x", Value::Int(2));
        if let Some(slot) = scopes.get_mut("x") {
            *slot = Value::Int(3);
        }
        assert_eq!(scopes.get("x"), Some(&Value::Int(3)));
        scopes.exit();
        assert_eq!(scopes.get("x"), Some(&Value::Int(1)));
    }
}