Skip to main content

ries_rs/
symbol_table.rs

//! Per-run symbol configuration table
//!
//! This module provides `SymbolTable`, an immutable configuration container that
//! stores symbol weights and display names for a single search run. It replaces
//! the process-global mutable state pattern with per-run configuration, enabling:
//!
//! - Concurrent searches with different profiles
//! - Library usage without side effects
//! - Reproducible results regardless of process state
//!
//! # Thread Safety
//!
//! `SymbolTable` is immutable after construction and can be freely shared across
//! threads via `Arc<SymbolTable>`. Each search run should construct its own table
//! from the relevant profile.

17use std::sync::Arc;
18
19use crate::profile::{Profile, UserConstant};
20use crate::symbol::Symbol;
21use crate::udf::UserFunction;
22
/// Size of the per-symbol lookup tables: one slot for every value a `u8`
/// symbol byte can take (0..=255).
const SYMBOL_COUNT: usize = u8::MAX as usize + 1;
25
26/// Immutable per-run symbol configuration
27///
28/// Stores weights and display names for all symbols used in a search.
29/// Built from a profile and user-defined constants/functions, then
30/// passed through the search pipeline for consistent behavior.
31#[derive(Clone, Debug)]
32pub struct SymbolTable {
33    /// Complexity weights for each symbol (indexed by symbol byte value)
34    weights: [u32; SYMBOL_COUNT],
35    /// Display names for each symbol (indexed by symbol byte value)
36    names: [String; SYMBOL_COUNT],
37}
38
39impl Default for SymbolTable {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl SymbolTable {
46    /// Create a new symbol table with default weights and names
47    pub fn new() -> Self {
48        let mut weights = [0u32; SYMBOL_COUNT];
49        let mut names: [String; SYMBOL_COUNT] = std::array::from_fn(|_| String::new());
50
51        // Initialize all symbols with their default values
52        for &sym in Symbol::constants()
53            .iter()
54            .chain(Symbol::unary_ops().iter())
55            .chain(Symbol::binary_ops().iter())
56        {
57            let idx = sym as usize;
58            weights[idx] = sym.default_weight();
59            names[idx] = sym.name().to_string();
60        }
61
62        // Initialize user constant placeholders
63        for (i, sym) in [
64            Symbol::UserConstant0,
65            Symbol::UserConstant1,
66            Symbol::UserConstant2,
67            Symbol::UserConstant3,
68            Symbol::UserConstant4,
69            Symbol::UserConstant5,
70            Symbol::UserConstant6,
71            Symbol::UserConstant7,
72            Symbol::UserConstant8,
73            Symbol::UserConstant9,
74            Symbol::UserConstant10,
75            Symbol::UserConstant11,
76            Symbol::UserConstant12,
77            Symbol::UserConstant13,
78            Symbol::UserConstant14,
79            Symbol::UserConstant15,
80        ]
81        .iter()
82        .enumerate()
83        {
84            let idx = *sym as usize;
85            weights[idx] = (*sym).default_weight();
86            names[idx] = format!("u{}", i);
87        }
88
89        // Initialize user function placeholders
90        for (i, sym) in [
91            Symbol::UserFunction0,
92            Symbol::UserFunction1,
93            Symbol::UserFunction2,
94            Symbol::UserFunction3,
95            Symbol::UserFunction4,
96            Symbol::UserFunction5,
97            Symbol::UserFunction6,
98            Symbol::UserFunction7,
99            Symbol::UserFunction8,
100            Symbol::UserFunction9,
101            Symbol::UserFunction10,
102            Symbol::UserFunction11,
103            Symbol::UserFunction12,
104            Symbol::UserFunction13,
105            Symbol::UserFunction14,
106            Symbol::UserFunction15,
107        ]
108        .iter()
109        .enumerate()
110        {
111            let idx = *sym as usize;
112            weights[idx] = (*sym).default_weight();
113            names[idx] = format!("f{}", i);
114        }
115
116        // Initialize X (variable) - not included in constants()
117        weights[Symbol::X as usize] = Symbol::X.default_weight();
118        names[Symbol::X as usize] = Symbol::X.name().to_string();
119
120        Self { weights, names }
121    }
122
123    /// Build a symbol table from a profile
124    ///
125    /// Applies:
126    /// - Profile weight overrides
127    /// - Profile name overrides
128    /// - User constant names and weights
129    /// - User function names and weights
130    pub fn from_profile(profile: &Profile) -> Self {
131        let mut table = Self::new();
132
133        // Apply profile weight overrides
134        for (&sym, &weight) in &profile.symbol_weights {
135            let idx = sym as usize;
136            if idx < SYMBOL_COUNT {
137                table.weights[idx] = weight;
138            }
139        }
140
141        // Apply profile name overrides
142        for (&sym, name) in &profile.symbol_names {
143            let idx = sym as usize;
144            if idx < SYMBOL_COUNT {
145                table.names[idx] = name.clone();
146            }
147        }
148
149        // Apply user constant names and weights
150        for (i, uc) in profile.constants.iter().enumerate() {
151            if i >= 16 {
152                break;
153            }
154            let sym = match i {
155                0 => Symbol::UserConstant0,
156                1 => Symbol::UserConstant1,
157                2 => Symbol::UserConstant2,
158                3 => Symbol::UserConstant3,
159                4 => Symbol::UserConstant4,
160                5 => Symbol::UserConstant5,
161                6 => Symbol::UserConstant6,
162                7 => Symbol::UserConstant7,
163                8 => Symbol::UserConstant8,
164                9 => Symbol::UserConstant9,
165                10 => Symbol::UserConstant10,
166                11 => Symbol::UserConstant11,
167                12 => Symbol::UserConstant12,
168                13 => Symbol::UserConstant13,
169                14 => Symbol::UserConstant14,
170                15 => Symbol::UserConstant15,
171                _ => continue,
172            };
173            let idx = sym as usize;
174            table.weights[idx] = uc.weight;
175            table.names[idx] = uc.name.clone();
176        }
177
178        // Apply user function names and weights
179        for (i, uf) in profile.functions.iter().enumerate() {
180            if i >= 16 {
181                break;
182            }
183            let sym = match i {
184                0 => Symbol::UserFunction0,
185                1 => Symbol::UserFunction1,
186                2 => Symbol::UserFunction2,
187                3 => Symbol::UserFunction3,
188                4 => Symbol::UserFunction4,
189                5 => Symbol::UserFunction5,
190                6 => Symbol::UserFunction6,
191                7 => Symbol::UserFunction7,
192                8 => Symbol::UserFunction8,
193                9 => Symbol::UserFunction9,
194                10 => Symbol::UserFunction10,
195                11 => Symbol::UserFunction11,
196                12 => Symbol::UserFunction12,
197                13 => Symbol::UserFunction13,
198                14 => Symbol::UserFunction14,
199                15 => Symbol::UserFunction15,
200                _ => continue,
201            };
202            let idx = sym as usize;
203            table.weights[idx] = uf.weight as u32;
204            table.names[idx] = uf.name.clone();
205        }
206
207        table
208    }
209
210    /// Build from profile with explicit user constants and functions
211    ///
212    /// This is useful when user constants/functions come from CLI args
213    /// rather than a profile file.
214    pub fn from_parts(
215        profile: &Profile,
216        user_constants: &[UserConstant],
217        user_functions: &[UserFunction],
218    ) -> Self {
219        let mut table = Self::new();
220
221        // Apply profile weight overrides
222        for (&sym, &weight) in &profile.symbol_weights {
223            let idx = sym as usize;
224            if idx < SYMBOL_COUNT {
225                table.weights[idx] = weight;
226            }
227        }
228
229        // Apply profile name overrides
230        for (&sym, name) in &profile.symbol_names {
231            let idx = sym as usize;
232            if idx < SYMBOL_COUNT {
233                table.names[idx] = name.clone();
234            }
235        }
236
237        // Apply user constant names and weights
238        for (i, uc) in user_constants.iter().enumerate() {
239            if i >= 16 {
240                break;
241            }
242            let sym = user_constant_symbol(i);
243            let idx = sym as usize;
244            table.weights[idx] = uc.weight;
245            table.names[idx] = uc.name.clone();
246        }
247
248        // Apply user function names and weights
249        for (i, uf) in user_functions.iter().enumerate() {
250            if i >= 16 {
251                break;
252            }
253            let sym = user_function_symbol(i);
254            let idx = sym as usize;
255            table.weights[idx] = uf.weight as u32;
256            table.names[idx] = uf.name.clone();
257        }
258
259        table
260    }
261
262    /// Get the weight for a symbol
263    #[inline]
264    pub fn weight(&self, sym: Symbol) -> u32 {
265        self.weights[sym as usize]
266    }
267
268    /// Get the display name for a symbol
269    #[inline]
270    pub fn name(&self, sym: Symbol) -> &str {
271        &self.names[sym as usize]
272    }
273
274    /// Wrap this table in an Arc for sharing
275    pub fn into_shared(self) -> Arc<Self> {
276        Arc::new(self)
277    }
278}
279
280/// Get the user constant symbol for an index (0-15)
281///
282/// # Panics
283///
284/// Panics if index >= 16. Use `user_constant_symbol_opt` for a non-panicking version.
285#[inline]
286pub fn user_constant_symbol(index: usize) -> Symbol {
287    user_constant_symbol_opt(index)
288        .unwrap_or_else(|| panic!("User constant index out of bounds: {}", index))
289}
290
291/// Get the user constant symbol for an index (0-15), returning None if out of bounds
292#[inline]
293pub fn user_constant_symbol_opt(index: usize) -> Option<Symbol> {
294    match index {
295        0 => Some(Symbol::UserConstant0),
296        1 => Some(Symbol::UserConstant1),
297        2 => Some(Symbol::UserConstant2),
298        3 => Some(Symbol::UserConstant3),
299        4 => Some(Symbol::UserConstant4),
300        5 => Some(Symbol::UserConstant5),
301        6 => Some(Symbol::UserConstant6),
302        7 => Some(Symbol::UserConstant7),
303        8 => Some(Symbol::UserConstant8),
304        9 => Some(Symbol::UserConstant9),
305        10 => Some(Symbol::UserConstant10),
306        11 => Some(Symbol::UserConstant11),
307        12 => Some(Symbol::UserConstant12),
308        13 => Some(Symbol::UserConstant13),
309        14 => Some(Symbol::UserConstant14),
310        15 => Some(Symbol::UserConstant15),
311        _ => None,
312    }
313}
314
315/// Get the user function symbol for an index (0-15)
316///
317/// # Panics
318///
319/// Panics if index >= 16. Use `user_function_symbol_opt` for a non-panicking version.
320#[inline]
321pub fn user_function_symbol(index: usize) -> Symbol {
322    user_function_symbol_opt(index)
323        .unwrap_or_else(|| panic!("User function index out of bounds: {}", index))
324}
325
326/// Get the user function symbol for an index (0-15), returning None if out of bounds
327#[inline]
328pub fn user_function_symbol_opt(index: usize) -> Option<Symbol> {
329    match index {
330        0 => Some(Symbol::UserFunction0),
331        1 => Some(Symbol::UserFunction1),
332        2 => Some(Symbol::UserFunction2),
333        3 => Some(Symbol::UserFunction3),
334        4 => Some(Symbol::UserFunction4),
335        5 => Some(Symbol::UserFunction5),
336        6 => Some(Symbol::UserFunction6),
337        7 => Some(Symbol::UserFunction7),
338        8 => Some(Symbol::UserFunction8),
339        9 => Some(Symbol::UserFunction9),
340        10 => Some(Symbol::UserFunction10),
341        11 => Some(Symbol::UserFunction11),
342        12 => Some(Symbol::UserFunction12),
343        13 => Some(Symbol::UserFunction13),
344        14 => Some(Symbol::UserFunction14),
345        15 => Some(Symbol::UserFunction15),
346        _ => None,
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_table() {
        let table = SymbolTable::new();

        // Check some default weights (matching original RIES calibration)
        assert_eq!(table.weight(Symbol::One), 10);
        assert_eq!(table.weight(Symbol::Pi), 14);
        assert_eq!(table.weight(Symbol::Add), 4);

        // Check some default names
        assert_eq!(table.name(Symbol::One), "1");
        assert_eq!(table.name(Symbol::Pi), "pi");
        assert_eq!(table.name(Symbol::Add), "+");

        // User slots get placeholder names
        assert_eq!(table.name(Symbol::UserConstant0), "u0");
        assert_eq!(table.name(Symbol::UserConstant15), "u15");
        assert_eq!(table.name(Symbol::UserFunction0), "f0");
        assert_eq!(table.name(Symbol::UserFunction15), "f15");
    }

    #[test]
    fn test_profile_overrides() {
        let mut profile = Profile::new();
        profile.symbol_weights.insert(Symbol::Pi, 20);
        profile.symbol_names.insert(Symbol::Pi, "π".to_string());

        let table = SymbolTable::from_profile(&profile);

        assert_eq!(table.weight(Symbol::Pi), 20);
        assert_eq!(table.name(Symbol::Pi), "π");
    }

    #[test]
    fn test_user_constant_overrides() {
        let mut profile = Profile::new();
        profile.constants.push(UserConstant {
            weight: 15,
            name: "myconst".to_string(),
            description: "My constant".to_string(),
            value: 1.234,
            num_type: crate::symbol::NumType::Transcendental,
        });

        let table = SymbolTable::from_profile(&profile);

        assert_eq!(table.weight(Symbol::UserConstant0), 15);
        assert_eq!(table.name(Symbol::UserConstant0), "myconst");
    }

    #[test]
    fn test_from_parts_uses_explicit_constants() {
        let mut profile = Profile::new();
        profile.symbol_weights.insert(Symbol::Pi, 20);

        let constants = [UserConstant {
            weight: 7,
            name: "k".to_string(),
            description: String::new(),
            value: 2.5,
            num_type: crate::symbol::NumType::Transcendental,
        }];

        let table = SymbolTable::from_parts(&profile, &constants, &[]);

        // Profile overrides still apply
        assert_eq!(table.weight(Symbol::Pi), 20);
        // Explicit constant fills slot 0; later slots keep placeholders
        assert_eq!(table.weight(Symbol::UserConstant0), 7);
        assert_eq!(table.name(Symbol::UserConstant0), "k");
        assert_eq!(table.name(Symbol::UserConstant1), "u1");
        assert_eq!(table.name(Symbol::UserFunction0), "f0");
    }

    #[test]
    fn test_extra_user_constants_are_ignored() {
        let mut profile = Profile::new();
        for i in 0..20 {
            profile.constants.push(UserConstant {
                weight: 10 + i,
                name: format!("c{}", i),
                description: String::new(),
                value: 1.0,
                num_type: crate::symbol::NumType::Transcendental,
            });
        }

        // Must not panic; only the first 16 constants map to slots.
        let table = SymbolTable::from_profile(&profile);

        assert_eq!(table.name(Symbol::UserConstant0), "c0");
        assert_eq!(table.name(Symbol::UserConstant15), "c15");
        assert_eq!(table.weight(Symbol::UserConstant15), 25);
    }

    #[test]
    fn test_concurrent_tables_dont_interfere() {
        // Create two tables with different configurations
        let mut profile1 = Profile::new();
        profile1
            .symbol_names
            .insert(Symbol::Pi, "pi_one".to_string());

        let mut profile2 = Profile::new();
        profile2
            .symbol_names
            .insert(Symbol::Pi, "pi_two".to_string());

        let table1 = SymbolTable::from_profile(&profile1);
        let table2 = SymbolTable::from_profile(&profile2);

        // Verify they have different names for Pi
        assert_eq!(table1.name(Symbol::Pi), "pi_one");
        assert_eq!(table2.name(Symbol::Pi), "pi_two");

        // Verify they still work independently
        assert_eq!(table1.name(Symbol::E), "e");
        assert_eq!(table2.name(Symbol::E), "e");
    }

    #[test]
    fn test_shared_table() {
        let table = SymbolTable::new().into_shared();

        // Can clone Arc cheaply
        let table2 = Arc::clone(&table);

        assert_eq!(table.weight(Symbol::One), table2.weight(Symbol::One));
        assert_eq!(table.name(Symbol::Pi), table2.name(Symbol::Pi));
    }

    #[test]
    fn test_expression_formatting_with_different_tables() {
        use crate::expr::Expression;

        // Create two tables with different names for pi
        let mut profile2 = Profile::new();
        profile2.symbol_names.insert(Symbol::Pi, "PI".to_string());

        let table1 = SymbolTable::new();
        let table2 = SymbolTable::from_profile(&profile2);

        // Build expression using table1: x + pi (postfix: X Pi Add)
        let mut expr = Expression::new();
        expr.push_with_table(Symbol::X, &table1);
        expr.push_with_table(Symbol::Pi, &table1);
        expr.push_with_table(Symbol::Add, &table1);

        // The same expression can be formatted differently based on the table used
        let formatted1 = expr.to_infix_with_table(&table1);
        let formatted2 = expr.to_infix_with_table(&table2);

        // With the default table, pi is "pi"; with table2, pi is "PI"
        assert!(formatted1.contains("pi"), "got {formatted1:?}");
        assert!(!formatted1.contains("PI"), "got {formatted1:?}");
        assert!(formatted2.contains("PI"), "got {formatted2:?}");

        // Verify the tables are independent - no global state pollution
        assert_ne!(formatted1, formatted2);
    }

    #[test]
    fn test_complexity_with_different_tables() {
        use crate::expr::Expression;

        // Create two tables with different weights for Pi
        let mut profile2 = Profile::new();
        profile2.symbol_weights.insert(Symbol::Pi, 20); // Heavier weight

        let table1 = SymbolTable::new(); // Default weights
        let table2 = SymbolTable::from_profile(&profile2);

        // Verify the tables have different weights for Pi
        assert_eq!(table1.weight(Symbol::Pi), 14); // default = original RIES value
        assert_eq!(table2.weight(Symbol::Pi), 20); // overridden

        // Build expressions using each table
        // Total: 15 (x) + 14 (pi) + 4 (+) = 33
        let mut expr1 = Expression::new();
        expr1.push_with_table(Symbol::X, &table1);
        expr1.push_with_table(Symbol::Pi, &table1);
        expr1.push_with_table(Symbol::Add, &table1);

        // Total: 15 (x) + 20 (pi) + 4 (+) = 39
        let mut expr2 = Expression::new();
        expr2.push_with_table(Symbol::X, &table2);
        expr2.push_with_table(Symbol::Pi, &table2);
        expr2.push_with_table(Symbol::Add, &table2);

        // Same symbols, different complexity due to different tables
        assert_eq!(expr1.complexity(), 33);
        assert_eq!(expr2.complexity(), 39);
    }

    #[test]
    fn test_user_constant_symbol_out_of_bounds() {
        // Test that out-of-bounds indices return None instead of panicking
        let result = user_constant_symbol_opt(16);
        assert!(result.is_none(), "Index 16 should return None");

        let result = user_constant_symbol_opt(100);
        assert!(result.is_none(), "Index 100 should return None");

        // Valid indices should work
        let result = user_constant_symbol_opt(0);
        assert_eq!(result, Some(Symbol::UserConstant0));

        let result = user_constant_symbol_opt(15);
        assert_eq!(result, Some(Symbol::UserConstant15));
    }

    #[test]
    fn test_user_function_symbol_out_of_bounds() {
        // Test that out-of-bounds indices return None instead of panicking
        let result = user_function_symbol_opt(16);
        assert!(result.is_none(), "Index 16 should return None");

        let result = user_function_symbol_opt(100);
        assert!(result.is_none(), "Index 100 should return None");

        // Valid indices should work
        let result = user_function_symbol_opt(0);
        assert_eq!(result, Some(Symbol::UserFunction0));

        let result = user_function_symbol_opt(15);
        assert_eq!(result, Some(Symbol::UserFunction15));
    }

    #[test]
    #[should_panic(expected = "User constant index out of bounds: 16")]
    fn test_user_constant_symbol_panics_out_of_bounds() {
        let _ = user_constant_symbol(16);
    }

    #[test]
    #[should_panic(expected = "User function index out of bounds: 16")]
    fn test_user_function_symbol_panics_out_of_bounds() {
        let _ = user_function_symbol(16);
    }
}