Skip to main content

openjd_expr/
symbol_table.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5//! Hierarchical symbol table for expression evaluation.
6//!
7//! Mirrors Python `openjd.expr._symbol_table.SymbolTable`.
8//! Supports dotted key paths and nested tables.
9//!
10//! # Construction
11//!
12//! ```
13//! use openjd_expr::{symtab, SymbolTable, ExprValue, ExprType};
14//!
15//! // Macro (most concise):
16//! let st = symtab! {
17//!     "Param.Frame" => 42,
18//!     "Param.Name" => "test",
19//!     "Session.Dir" => ExprType::PATH,  // auto-wraps as unresolved
20//! };
21//!
22//! // Builder-style:
23//! let mut st = SymbolTable::new();
24//! st.set("Param.Frame", 42).unwrap();
25//! st.set("Param.Name", "test").unwrap();
26//!
27//! // From iterator:
28//! let st: SymbolTable = [
29//!     ("Param.Frame", ExprValue::from(42)),
30//!     ("Param.Name", "test".into()),
31//! ].into_iter().collect();
32//! ```
33
34use crate::types::ExprType;
35use crate::value::ExprValue;
36use std::collections::HashMap;
37
38/// Maximum number of entries permitted in a `SymbolTable` deserialized from
39/// the JSON transport format (`SerializedSymbolTable` or the serde
40/// `Deserialize` impl on `SymbolTable` itself).
41///
42/// The deserializer walks an untrusted JSON array and invokes
43/// `SymbolTable::set` for each entry. Real-world symbol tables carry a
44/// handful of job parameters, a handful of session variables, and a
45/// handful of path-mapping-derived bindings — well under a thousand
46/// entries in the aggregate. This cap rejects transport blobs that would
47/// produce a multi-million-entry table purely to inflate the worker's
48/// memory footprint before evaluation begins.
49///
50/// The cap applies **only** to the transport-deserialization paths. Direct
51/// in-process calls to `SymbolTable::set` or `set_table` are trusted
52/// (host code), and callers that legitimately need larger tables (e.g.
53/// a test builder) can construct them by hand without tripping the cap.
54///
55/// See `specs/expr/symbol-table.md` (Defensive caps) for rationale.
56pub const MAX_SYMBOL_TABLE_ENTRIES: usize = 100_000;
57
/// Error produced when [`SymbolTable::set`] cannot store a value because the
/// requested path collides with an entry that is already in the table.
///
/// For example, setting `"A.B.C"` when `"A.B"` is already a scalar value.
#[derive(Debug, Clone)]
pub struct SymbolTableError {
    /// The full dotted key the caller attempted to set.
    pub key: String,
    /// The path component at which the conflict was detected.
    pub conflict: String,
}

impl std::fmt::Display for SymbolTableError {
    /// Renders the error as `Cannot set '<key>': '<conflict>' is not a table`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { key, conflict } = self;
        write!(f, "Cannot set '{}': '{}' is not a table", key, conflict)
    }
}

// No underlying source error: the default `Error` methods are sufficient.
impl std::error::Error for SymbolTableError {}
78
79impl From<SymbolTableError> for crate::error::ExpressionError {
80    /// Bubble a symbol-table conflict up as an `ExpressionError::Other`.
81    ///
82    /// The full `SymbolTableError` message is preserved verbatim so callers
83    /// can still see which key was being set and what it conflicted with.
84    /// Call sites that have an AST node available can further attach
85    /// `.with_node(...)` / `.with_span(...)` for caret-annotated output.
86    fn from(e: SymbolTableError) -> Self {
87        crate::error::ExpressionError::new(e.to_string())
88    }
89}
90
91/// Entry in a symbol table: either a nested table or a value.
92#[derive(Debug, Clone)]
93pub enum SymbolTableEntry {
94    Table(SymbolTable),
95    Value(ExprValue),
96}
97
98/// Hierarchical symbol table mapping names to values or nested tables.
99///
100/// Supports dotted paths: `table.set("Param.Frame", 42)`
101/// creates a nested structure `Param -> Frame -> 42`.
102#[derive(Debug, Clone, Default)]
103pub struct SymbolTable {
104    pub(crate) table: HashMap<String, SymbolTableEntry>,
105}
106
107impl SymbolTable {
108    pub fn new() -> Self {
109        Self {
110            table: HashMap::new(),
111        }
112    }
113
114    /// Construct from a list of (dotted_key, value) pairs.
115    /// Build a `SymbolTable` from any iterable of `(dotted_key, value)` pairs.
116    ///
117    /// Accepts any `IntoIterator`, so callers can pass `Vec`, arrays, iterators
118    /// (e.g., from `map`/`filter` chains), or other containers without collecting first.
119    pub fn from_pairs<'a, I>(pairs: I) -> Result<Self, SymbolTableError>
120    where
121        I: IntoIterator<Item = (&'a str, ExprValue)>,
122    {
123        let mut st = Self::new();
124        for (k, v) in pairs {
125            st.set(k, v)?;
126        }
127        Ok(st)
128    }
129
130    /// Set a nested SymbolTable at a key (for dict-like nesting).
131    pub fn set_table(&mut self, key: &str, subtable: SymbolTable) {
132        self.table
133            .insert(key.to_string(), SymbolTableEntry::Table(subtable));
134    }
135
136    /// Get a subtable at a key, or None.
137    pub fn get_table(&self, key: &str) -> Option<&SymbolTable> {
138        match self.get(key) {
139            Some(SymbolTableEntry::Table(t)) => Some(t),
140            _ => None,
141        }
142    }
143
144    /// Set a value at a dotted path, creating intermediate tables as needed.
145    ///
146    /// Accepts anything convertible to `ExprValue` via `Into`:
147    /// - `i32`, `i64` → `ExprValue::Int`
148    /// - `bool` → `ExprValue::Bool`
149    /// - `&str`, `String` → `ExprValue::String`
150    /// - `ExprType` → `ExprValue::Unresolved` (for type-checking symbol tables)
151    /// - `ExprValue` → used directly
152    ///
153    /// For floats, construct `ExprValue::Float(Float64::new(v)?)` explicitly.
154    ///
155    /// Returns an error if an intermediate path component is already set to a
156    /// value (not a table). For example, setting `"A.B.C"` fails if `"A.B"` is
157    /// already a scalar value.
158    pub fn set(&mut self, key: &str, value: impl Into<ExprValue>) -> Result<(), SymbolTableError> {
159        self.set_value(key, value.into())
160    }
161
162    fn set_value(&mut self, key: &str, value: ExprValue) -> Result<(), SymbolTableError> {
163        let parts: Vec<&str> = key.split('.').collect();
164        if parts.len() == 1 {
165            if matches!(self.table.get(key), Some(SymbolTableEntry::Table(_))) {
166                return Err(SymbolTableError {
167                    key: key.to_string(),
168                    conflict: key.to_string(),
169                });
170            }
171            self.table
172                .insert(key.to_string(), SymbolTableEntry::Value(value));
173            return Ok(());
174        }
175        let mut current = self;
176        for &part in &parts[..parts.len() - 1] {
177            let entry = current
178                .table
179                .entry(part.to_string())
180                .or_insert_with(|| SymbolTableEntry::Table(SymbolTable::new()));
181            current = match entry {
182                SymbolTableEntry::Table(t) => t,
183                _ => {
184                    return Err(SymbolTableError {
185                        key: key.to_string(),
186                        conflict: part.to_string(),
187                    })
188                }
189            };
190        }
191        let last = parts.last().unwrap().to_string();
192        if matches!(current.table.get(&last), Some(SymbolTableEntry::Table(_))) {
193            return Err(SymbolTableError {
194                key: key.to_string(),
195                conflict: last,
196            });
197        }
198        current.table.insert(last, SymbolTableEntry::Value(value));
199        Ok(())
200    }
201
202    /// Set a string value at a dotted path (convenience).
203    pub fn set_string(&mut self, key: &str, value: &str) -> Result<(), SymbolTableError> {
204        self.set(key, ExprValue::String(value.to_string()))
205    }
206
207    /// Get an entry at a dotted path.
208    pub fn get(&self, key: &str) -> Option<&SymbolTableEntry> {
209        let parts: Vec<&str> = key.split('.').collect();
210        let mut current = self;
211        for (i, &part) in parts.iter().enumerate() {
212            match current.table.get(part) {
213                Some(SymbolTableEntry::Table(t)) if i < parts.len() - 1 => current = t,
214                Some(entry) if i == parts.len() - 1 => return Some(entry),
215                _ => return None,
216            }
217        }
218        None
219    }
220
221    /// Get a value at a dotted path, returning None if not found or if it's a table.
222    pub fn get_value(&self, key: &str) -> Option<&ExprValue> {
223        match self.get(key) {
224            Some(SymbolTableEntry::Value(v)) => Some(v),
225            _ => None,
226        }
227    }
228
229    /// Get a string value at a dotted path.
230    pub fn get_string(&self, key: &str) -> Option<&str> {
231        match self.get_value(key) {
232            Some(ExprValue::String(s)) => Some(s),
233            Some(ExprValue::Path { value, .. }) => Some(value),
234            _ => None,
235        }
236    }
237
238    pub fn contains(&self, key: &str) -> bool {
239        self.get(key).is_some()
240    }
241
242    /// Top-level keys.
243    pub fn keys(&self) -> impl Iterator<Item = &str> {
244        self.table.keys().map(|s| s.as_str())
245    }
246
247    /// Collect all leaf symbol paths (dotted names) in this table.
248    ///
249    /// If `prefix` is non-empty, each returned path is prefixed with
250    /// `"{prefix}."`. Use `""` for a top-level walk.
251    pub fn all_paths(&self, prefix: &str) -> Vec<String> {
252        let mut out = Vec::new();
253        self.collect_paths(prefix, &mut out);
254        out
255    }
256
257    fn collect_paths(&self, prefix: &str, out: &mut Vec<String>) {
258        for (key, entry) in &self.table {
259            let path = if prefix.is_empty() {
260                key.clone()
261            } else {
262                format!("{prefix}.{key}")
263            };
264            match entry {
265                SymbolTableEntry::Value(_) => out.push(path),
266                SymbolTableEntry::Table(sub) => sub.collect_paths(&path, out),
267            }
268        }
269    }
270
271    /// Merge all entries from `other` into this table, overwriting on conflict.
272    pub fn merge_from(&mut self, other: &SymbolTable) {
273        for (key, entry) in &other.table {
274            match entry {
275                SymbolTableEntry::Value(v) => {
276                    self.table
277                        .insert(key.clone(), SymbolTableEntry::Value(v.clone()));
278                }
279                SymbolTableEntry::Table(sub) => match self.table.get_mut(key) {
280                    Some(SymbolTableEntry::Table(existing)) => existing.merge_from(sub),
281                    _ => {
282                        self.table
283                            .insert(key.clone(), SymbolTableEntry::Table(sub.clone()));
284                    }
285                },
286            }
287        }
288    }
289}
290
291impl serde::Serialize for SymbolTable {
292    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
293        use serde::ser::SerializeSeq;
294        let paths = self.all_paths("");
295        // Collect only resolved values — skip Unresolved entries
296        let entries: Vec<_> = paths
297            .iter()
298            .filter_map(|p| {
299                self.get_value(p).and_then(|v| {
300                    if matches!(v, ExprValue::Unresolved(_)) {
301                        None
302                    } else {
303                        Some((p, v))
304                    }
305                })
306            })
307            .collect();
308        let mut seq = s.serialize_seq(Some(entries.len()))?;
309        for (path, value) in entries {
310            seq.serialize_element(&serde_json::json!({
311                "name": path,
312                "value": value.transport_value(),
313                "type": value.expr_type().to_string(),
314            }))?;
315        }
316        seq.end()
317    }
318}
319
320impl<'de> serde::Deserialize<'de> for SymbolTable {
321    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
322        let arr: Vec<serde_json::Value> = serde::Deserialize::deserialize(d)?;
323        if arr.len() > MAX_SYMBOL_TABLE_ENTRIES {
324            return Err(serde::de::Error::custom(format!(
325                "SymbolTable: too many entries ({}); maximum is {}",
326                arr.len(),
327                MAX_SYMBOL_TABLE_ENTRIES,
328            )));
329        }
330        let mut st = SymbolTable::new();
331        for entry in &arr {
332            let name = entry
333                .get("name")
334                .and_then(|n| n.as_str())
335                .ok_or_else(|| serde::de::Error::missing_field("name"))?;
336            let type_str = entry
337                .get("type")
338                .and_then(|t| t.as_str())
339                .ok_or_else(|| serde::de::Error::missing_field("type"))?;
340            let binding_type = ExprType::parse(type_str).map_err(serde::de::Error::custom)?;
341            let raw_value = entry
342                .get("value")
343                .ok_or_else(|| serde::de::Error::missing_field("value"))?;
344            let value = ExprValue::from_transport_value(
345                raw_value,
346                &binding_type,
347                crate::path_mapping::PathFormat::Posix,
348            )
349            .map_err(serde::de::Error::custom)?;
350            st.set(name, value).map_err(serde::de::Error::custom)?;
351        }
352        Ok(st)
353    }
354}
355
356/// Collect `(&str, ExprValue)` pairs into a `SymbolTable`.
357///
358/// # Panics
359///
360/// Panics if a dotted path conflicts with an existing non-table entry.
361/// Use [`SymbolTable::set`] directly if you need error handling.
362impl<'a> FromIterator<(&'a str, ExprValue)> for SymbolTable {
363    fn from_iter<I: IntoIterator<Item = (&'a str, ExprValue)>>(iter: I) -> Self {
364        let mut st = Self::new();
365        for (k, v) in iter {
366            st.set(k, v)
367                .expect("SymbolTable path conflict in FromIterator");
368        }
369        st
370    }
371}
372
373// ═══════════════════════════════════════════════════════════════
374// SerializedSymbolTable — boundary type between template and session scope
375// ═══════════════════════════════════════════════════════════════
376
377/// A symbol table in JSON transport format.
378///
379/// This is the boundary type between template scope (always Posix paths) and
380/// session scope (host-native paths). The session deserializes it with
381/// `PathFormat::host()`, ensuring path separators match the worker OS.
382///
383/// This mirrors the real-world flow where a scheduler serializes the symbol
384/// table to JSON and sends it to a worker that may be on a different OS.
385#[derive(Debug, Clone, serde::Serialize)]
386#[serde(transparent)]
387pub struct SerializedSymbolTable(serde_json::Value);
388
389impl SerializedSymbolTable {
390    /// Create from a `serde_json::Value` (already-parsed JSON array).
391    pub fn from_value(v: serde_json::Value) -> Self {
392        Self(v)
393    }
394
395    /// Create from a JSON string.
396    pub fn from_json_str(s: &str) -> Result<Self, serde_json::Error> {
397        Ok(Self(serde_json::from_str(s)?))
398    }
399
400    /// Serialize a `SymbolTable` into transport format.
401    pub fn from_symtab(st: &SymbolTable) -> Self {
402        Self(serde_json::to_value(st).expect("SymbolTable serialization cannot fail"))
403    }
404
405    /// Deserialize to a `SymbolTable` with the given path format.
406    ///
407    /// Path values in the transport format are plain strings; this method
408    /// reconstructs them as `ExprValue::Path` with separators normalized
409    /// to the specified format.
410    ///
411    /// Returns an error if the transport array holds more than
412    /// [`MAX_SYMBOL_TABLE_ENTRIES`] entries.
413    pub fn to_symtab(
414        &self,
415        path_format: crate::path_mapping::PathFormat,
416    ) -> Result<SymbolTable, String> {
417        let arr = self
418            .0
419            .as_array()
420            .ok_or("SerializedSymbolTable: expected JSON array")?;
421        if arr.len() > MAX_SYMBOL_TABLE_ENTRIES {
422            return Err(format!(
423                "SerializedSymbolTable: too many entries ({}); maximum is {}",
424                arr.len(),
425                MAX_SYMBOL_TABLE_ENTRIES,
426            ));
427        }
428        let mut st = SymbolTable::new();
429        for entry in arr {
430            let name = entry
431                .get("name")
432                .and_then(|n| n.as_str())
433                .ok_or("SerializedSymbolTable: missing 'name' field")?;
434            let type_str = entry
435                .get("type")
436                .and_then(|t| t.as_str())
437                .ok_or("SerializedSymbolTable: missing 'type' field")?;
438            let binding_type = ExprType::parse(type_str)
439                .map_err(|e| format!("SerializedSymbolTable: bad type '{type_str}': {e}"))?;
440            let raw_value = entry
441                .get("value")
442                .ok_or("SerializedSymbolTable: missing 'value' field")?;
443            let value = ExprValue::from_transport_value(raw_value, &binding_type, path_format)
444                .map_err(|e| format!("SerializedSymbolTable: {e}"))?;
445            st.set(name, value)
446                .map_err(|e| format!("SerializedSymbolTable: {e}"))?;
447        }
448        Ok(st)
449    }
450
451    /// Get the underlying JSON value.
452    pub fn as_value(&self) -> &serde_json::Value {
453        &self.0
454    }
455}
456
457impl<'de> serde::Deserialize<'de> for SerializedSymbolTable {
458    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
459        let v: serde_json::Value = serde::Deserialize::deserialize(d)?;
460        Ok(Self(v))
461    }
462}
463
/// Construct a [`SymbolTable`] from key-value pairs.
///
/// Each `key => value` pair is forwarded to [`SymbolTable::set`], so values
/// can be any type that implements `Into<ExprValue>`: integers, bools,
/// string literals, `ExprValue`, or `ExprType` (auto-wrapped as unresolved
/// for type checking). As documented on `set`, floats should be wrapped
/// explicitly as `ExprValue::Float(Float64::new(v)?)`.
///
/// A trailing comma is allowed, and `symtab! {}` yields an empty table.
///
/// # Panics
///
/// Panics if a dotted path conflicts with an existing non-table entry.
///
/// ```
/// use openjd_expr::{symtab, ExprType};
///
/// let st = symtab! {
///     "Param.Frame" => 42,
///     "Param.Name" => "test",
///     "Session.Dir" => ExprType::PATH,
/// };
/// ```
#[macro_export]
macro_rules! symtab {
    ($($key:expr => $val:expr),* $(,)?) => {{
        // With zero pairs nothing mutates `st`; allow the lint so an empty
        // invocation stays warning-free.
        #[allow(unused_mut)]
        let mut st = $crate::SymbolTable::new();
        $(st.set($key, $val).expect("symtab! path conflict");)*
        st
    }};
}