1use std::sync::Arc;
18
19use crate::profile::{Profile, UserConstant};
20use crate::symbol::Symbol;
21use crate::udf::UserFunction;
22
/// Number of slots in a [`SymbolTable`]; every symbol is stored at index `sym as usize`.
const SYMBOL_COUNT: usize = 256;

/// Per-symbol weight and display-name lookup table.
///
/// Each table is an independent value, so tables built from different profiles never
/// affect one another. Use `SymbolTable::into_shared` to share one table cheaply via `Arc`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SymbolTable {
    /// Weight of each symbol, indexed by `sym as usize`.
    weights: [u32; SYMBOL_COUNT],
    /// Display name of each symbol, indexed by `sym as usize`; empty for unassigned slots.
    names: [String; SYMBOL_COUNT],
}
38
39impl Default for SymbolTable {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl SymbolTable {
46 pub fn new() -> Self {
48 let mut weights = [0u32; SYMBOL_COUNT];
49 let mut names: [String; SYMBOL_COUNT] = std::array::from_fn(|_| String::new());
50
51 for &sym in Symbol::constants()
53 .iter()
54 .chain(Symbol::unary_ops().iter())
55 .chain(Symbol::binary_ops().iter())
56 {
57 let idx = sym as usize;
58 weights[idx] = sym.default_weight();
59 names[idx] = sym.name().to_string();
60 }
61
62 for (i, sym) in [
64 Symbol::UserConstant0,
65 Symbol::UserConstant1,
66 Symbol::UserConstant2,
67 Symbol::UserConstant3,
68 Symbol::UserConstant4,
69 Symbol::UserConstant5,
70 Symbol::UserConstant6,
71 Symbol::UserConstant7,
72 Symbol::UserConstant8,
73 Symbol::UserConstant9,
74 Symbol::UserConstant10,
75 Symbol::UserConstant11,
76 Symbol::UserConstant12,
77 Symbol::UserConstant13,
78 Symbol::UserConstant14,
79 Symbol::UserConstant15,
80 ]
81 .iter()
82 .enumerate()
83 {
84 let idx = *sym as usize;
85 weights[idx] = (*sym).default_weight();
86 names[idx] = format!("u{}", i);
87 }
88
89 for (i, sym) in [
91 Symbol::UserFunction0,
92 Symbol::UserFunction1,
93 Symbol::UserFunction2,
94 Symbol::UserFunction3,
95 Symbol::UserFunction4,
96 Symbol::UserFunction5,
97 Symbol::UserFunction6,
98 Symbol::UserFunction7,
99 Symbol::UserFunction8,
100 Symbol::UserFunction9,
101 Symbol::UserFunction10,
102 Symbol::UserFunction11,
103 Symbol::UserFunction12,
104 Symbol::UserFunction13,
105 Symbol::UserFunction14,
106 Symbol::UserFunction15,
107 ]
108 .iter()
109 .enumerate()
110 {
111 let idx = *sym as usize;
112 weights[idx] = (*sym).default_weight();
113 names[idx] = format!("f{}", i);
114 }
115
116 weights[Symbol::X as usize] = Symbol::X.default_weight();
118 names[Symbol::X as usize] = Symbol::X.name().to_string();
119
120 Self { weights, names }
121 }
122
123 pub fn from_profile(profile: &Profile) -> Self {
131 let mut table = Self::new();
132
133 for (&sym, &weight) in &profile.symbol_weights {
135 let idx = sym as usize;
136 if idx < SYMBOL_COUNT {
137 table.weights[idx] = weight;
138 }
139 }
140
141 for (&sym, name) in &profile.symbol_names {
143 let idx = sym as usize;
144 if idx < SYMBOL_COUNT {
145 table.names[idx] = name.clone();
146 }
147 }
148
149 for (i, uc) in profile.constants.iter().enumerate() {
151 if i >= 16 {
152 break;
153 }
154 let sym = match i {
155 0 => Symbol::UserConstant0,
156 1 => Symbol::UserConstant1,
157 2 => Symbol::UserConstant2,
158 3 => Symbol::UserConstant3,
159 4 => Symbol::UserConstant4,
160 5 => Symbol::UserConstant5,
161 6 => Symbol::UserConstant6,
162 7 => Symbol::UserConstant7,
163 8 => Symbol::UserConstant8,
164 9 => Symbol::UserConstant9,
165 10 => Symbol::UserConstant10,
166 11 => Symbol::UserConstant11,
167 12 => Symbol::UserConstant12,
168 13 => Symbol::UserConstant13,
169 14 => Symbol::UserConstant14,
170 15 => Symbol::UserConstant15,
171 _ => continue,
172 };
173 let idx = sym as usize;
174 table.weights[idx] = uc.weight;
175 table.names[idx] = uc.name.clone();
176 }
177
178 for (i, uf) in profile.functions.iter().enumerate() {
180 if i >= 16 {
181 break;
182 }
183 let sym = match i {
184 0 => Symbol::UserFunction0,
185 1 => Symbol::UserFunction1,
186 2 => Symbol::UserFunction2,
187 3 => Symbol::UserFunction3,
188 4 => Symbol::UserFunction4,
189 5 => Symbol::UserFunction5,
190 6 => Symbol::UserFunction6,
191 7 => Symbol::UserFunction7,
192 8 => Symbol::UserFunction8,
193 9 => Symbol::UserFunction9,
194 10 => Symbol::UserFunction10,
195 11 => Symbol::UserFunction11,
196 12 => Symbol::UserFunction12,
197 13 => Symbol::UserFunction13,
198 14 => Symbol::UserFunction14,
199 15 => Symbol::UserFunction15,
200 _ => continue,
201 };
202 let idx = sym as usize;
203 table.weights[idx] = uf.weight as u32;
204 table.names[idx] = uf.name.clone();
205 }
206
207 table
208 }
209
210 pub fn from_parts(
215 profile: &Profile,
216 user_constants: &[UserConstant],
217 user_functions: &[UserFunction],
218 ) -> Self {
219 let mut table = Self::new();
220
221 for (&sym, &weight) in &profile.symbol_weights {
223 let idx = sym as usize;
224 if idx < SYMBOL_COUNT {
225 table.weights[idx] = weight;
226 }
227 }
228
229 for (&sym, name) in &profile.symbol_names {
231 let idx = sym as usize;
232 if idx < SYMBOL_COUNT {
233 table.names[idx] = name.clone();
234 }
235 }
236
237 for (i, uc) in user_constants.iter().enumerate() {
239 if i >= 16 {
240 break;
241 }
242 let sym = user_constant_symbol(i);
243 let idx = sym as usize;
244 table.weights[idx] = uc.weight;
245 table.names[idx] = uc.name.clone();
246 }
247
248 for (i, uf) in user_functions.iter().enumerate() {
250 if i >= 16 {
251 break;
252 }
253 let sym = user_function_symbol(i);
254 let idx = sym as usize;
255 table.weights[idx] = uf.weight as u32;
256 table.names[idx] = uf.name.clone();
257 }
258
259 table
260 }
261
262 #[inline]
264 pub fn weight(&self, sym: Symbol) -> u32 {
265 self.weights[sym as usize]
266 }
267
268 #[inline]
270 pub fn name(&self, sym: Symbol) -> &str {
271 &self.names[sym as usize]
272 }
273
274 pub fn into_shared(self) -> Arc<Self> {
276 Arc::new(self)
277 }
278}
279
280#[inline]
286pub fn user_constant_symbol(index: usize) -> Symbol {
287 user_constant_symbol_opt(index)
288 .unwrap_or_else(|| panic!("User constant index out of bounds: {}", index))
289}
290
291#[inline]
293pub fn user_constant_symbol_opt(index: usize) -> Option<Symbol> {
294 match index {
295 0 => Some(Symbol::UserConstant0),
296 1 => Some(Symbol::UserConstant1),
297 2 => Some(Symbol::UserConstant2),
298 3 => Some(Symbol::UserConstant3),
299 4 => Some(Symbol::UserConstant4),
300 5 => Some(Symbol::UserConstant5),
301 6 => Some(Symbol::UserConstant6),
302 7 => Some(Symbol::UserConstant7),
303 8 => Some(Symbol::UserConstant8),
304 9 => Some(Symbol::UserConstant9),
305 10 => Some(Symbol::UserConstant10),
306 11 => Some(Symbol::UserConstant11),
307 12 => Some(Symbol::UserConstant12),
308 13 => Some(Symbol::UserConstant13),
309 14 => Some(Symbol::UserConstant14),
310 15 => Some(Symbol::UserConstant15),
311 _ => None,
312 }
313}
314
315#[inline]
321pub fn user_function_symbol(index: usize) -> Symbol {
322 user_function_symbol_opt(index)
323 .unwrap_or_else(|| panic!("User function index out of bounds: {}", index))
324}
325
326#[inline]
328pub fn user_function_symbol_opt(index: usize) -> Option<Symbol> {
329 match index {
330 0 => Some(Symbol::UserFunction0),
331 1 => Some(Symbol::UserFunction1),
332 2 => Some(Symbol::UserFunction2),
333 3 => Some(Symbol::UserFunction3),
334 4 => Some(Symbol::UserFunction4),
335 5 => Some(Symbol::UserFunction5),
336 6 => Some(Symbol::UserFunction6),
337 7 => Some(Symbol::UserFunction7),
338 8 => Some(Symbol::UserFunction8),
339 9 => Some(Symbol::UserFunction9),
340 10 => Some(Symbol::UserFunction10),
341 11 => Some(Symbol::UserFunction11),
342 12 => Some(Symbol::UserFunction12),
343 13 => Some(Symbol::UserFunction13),
344 14 => Some(Symbol::UserFunction14),
345 15 => Some(Symbol::UserFunction15),
346 _ => None,
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `UserConstant` with placeholder description/value for table tests.
    fn user_constant(weight: u32, name: &str) -> UserConstant {
        UserConstant {
            weight,
            name: name.to_string(),
            description: String::new(),
            value: 1.0,
            num_type: crate::symbol::NumType::Transcendental,
        }
    }

    #[test]
    fn test_default_table() {
        let table = SymbolTable::new();

        assert_eq!(table.weight(Symbol::One), 10);
        assert_eq!(table.weight(Symbol::Pi), 14);
        assert_eq!(table.weight(Symbol::Add), 4);

        assert_eq!(table.name(Symbol::One), "1");
        assert_eq!(table.name(Symbol::Pi), "pi");
        assert_eq!(table.name(Symbol::Add), "+");

        // User slots get positional placeholder names.
        assert_eq!(table.name(Symbol::UserConstant0), "u0");
        assert_eq!(table.name(Symbol::UserConstant15), "u15");
        assert_eq!(table.name(Symbol::UserFunction0), "f0");
        assert_eq!(table.name(Symbol::UserFunction15), "f15");
    }

    #[test]
    fn test_profile_overrides() {
        let mut profile = Profile::new();
        profile.symbol_weights.insert(Symbol::Pi, 20);
        profile.symbol_names.insert(Symbol::Pi, "π".to_string());

        let table = SymbolTable::from_profile(&profile);

        assert_eq!(table.weight(Symbol::Pi), 20);
        assert_eq!(table.name(Symbol::Pi), "π");
    }

    #[test]
    fn test_user_constant_overrides() {
        let mut profile = Profile::new();
        profile.constants.push(UserConstant {
            weight: 15,
            name: "myconst".to_string(),
            description: "My constant".to_string(),
            value: 1.234,
            num_type: crate::symbol::NumType::Transcendental,
        });

        let table = SymbolTable::from_profile(&profile);

        assert_eq!(table.weight(Symbol::UserConstant0), 15);
        assert_eq!(table.name(Symbol::UserConstant0), "myconst");
    }

    #[test]
    fn test_excess_user_constants_are_ignored() {
        let mut profile = Profile::new();
        for i in 0..20u32 {
            profile.constants.push(user_constant(i, &format!("c{i}")));
        }

        // Must not panic; only the first 16 constants are placed.
        let table = SymbolTable::from_profile(&profile);

        assert_eq!(table.name(Symbol::UserConstant0), "c0");
        assert_eq!(table.weight(Symbol::UserConstant15), 15);
        assert_eq!(table.name(Symbol::UserConstant15), "c15");
    }

    #[test]
    fn test_from_parts_uses_explicit_user_slots() {
        let mut profile = Profile::new();
        profile.symbol_names.insert(Symbol::Pi, "π".to_string());
        profile.constants.push(user_constant(3, "from_profile"));

        let explicit = [user_constant(7, "explicit")];
        let table = SymbolTable::from_parts(&profile, &explicit, &[]);

        // Per-symbol overrides still come from the profile...
        assert_eq!(table.name(Symbol::Pi), "π");
        // ...but the user slots come only from the explicit slices.
        assert_eq!(table.weight(Symbol::UserConstant0), 7);
        assert_eq!(table.name(Symbol::UserConstant0), "explicit");
        assert_eq!(table.name(Symbol::UserFunction0), "f0");
    }

    #[test]
    fn test_concurrent_tables_dont_interfere() {
        let mut profile1 = Profile::new();
        profile1
            .symbol_names
            .insert(Symbol::Pi, "pi_one".to_string());

        let mut profile2 = Profile::new();
        profile2
            .symbol_names
            .insert(Symbol::Pi, "pi_two".to_string());

        let table1 = SymbolTable::from_profile(&profile1);
        let table2 = SymbolTable::from_profile(&profile2);

        assert_eq!(table1.name(Symbol::Pi), "pi_one");
        assert_eq!(table2.name(Symbol::Pi), "pi_two");

        assert_eq!(table1.name(Symbol::E), "e");
        assert_eq!(table2.name(Symbol::E), "e");
    }

    #[test]
    fn test_shared_table() {
        let table = SymbolTable::new().into_shared();
        let table2 = Arc::clone(&table);

        assert_eq!(table.weight(Symbol::One), table2.weight(Symbol::One));
        assert_eq!(table.name(Symbol::Pi), table2.name(Symbol::Pi));
    }

    #[test]
    fn test_expression_formatting_with_different_tables() {
        use crate::expr::Expression;

        let mut profile2 = Profile::new();
        profile2.symbol_names.insert(Symbol::Pi, "PI".to_string());

        let table1 = SymbolTable::new();
        let table2 = SymbolTable::from_profile(&profile2);

        let mut expr = Expression::new();
        expr.push_with_table(Symbol::X, &table1);
        expr.push_with_table(Symbol::Pi, &table1);
        expr.push_with_table(Symbol::Add, &table1);

        let formatted1 = expr.to_infix_with_table(&table1);
        let formatted2 = expr.to_infix_with_table(&table2);

        // Each rendering must use its own table's name for Pi.
        assert!(formatted1.contains("pi"));
        assert!(formatted2.contains("PI"));
        assert_ne!(formatted1, formatted2);
    }

    #[test]
    fn test_complexity_with_different_tables() {
        use crate::expr::Expression;

        let mut profile2 = Profile::new();
        profile2.symbol_weights.insert(Symbol::Pi, 20);

        let table1 = SymbolTable::new();
        let table2 = SymbolTable::from_profile(&profile2);

        assert_eq!(table1.weight(Symbol::Pi), 14);
        assert_eq!(table2.weight(Symbol::Pi), 20);

        let mut expr1 = Expression::new();
        expr1.push_with_table(Symbol::X, &table1);
        expr1.push_with_table(Symbol::Pi, &table1);
        expr1.push_with_table(Symbol::Add, &table1);

        let mut expr2 = Expression::new();
        expr2.push_with_table(Symbol::X, &table2);
        expr2.push_with_table(Symbol::Pi, &table2);
        expr2.push_with_table(Symbol::Add, &table2);

        assert_eq!(expr1.complexity(), 33);
        assert_eq!(expr2.complexity(), 39);
    }

    #[test]
    fn test_user_constant_symbol_out_of_bounds() {
        let result = user_constant_symbol_opt(16);
        assert!(result.is_none(), "Index 16 should return None");

        let result = user_constant_symbol_opt(100);
        assert!(result.is_none(), "Index 100 should return None");

        let result = user_constant_symbol_opt(0);
        assert_eq!(result, Some(Symbol::UserConstant0));

        let result = user_constant_symbol_opt(15);
        assert_eq!(result, Some(Symbol::UserConstant15));
    }

    #[test]
    fn test_user_function_symbol_out_of_bounds() {
        let result = user_function_symbol_opt(16);
        assert!(result.is_none(), "Index 16 should return None");

        let result = user_function_symbol_opt(100);
        assert!(result.is_none(), "Index 100 should return None");

        let result = user_function_symbol_opt(0);
        assert_eq!(result, Some(Symbol::UserFunction0));

        let result = user_function_symbol_opt(15);
        assert_eq!(result, Some(Symbol::UserFunction15));
    }

    #[test]
    #[should_panic(expected = "User constant index out of bounds: 16")]
    fn test_user_constant_symbol_panics_out_of_bounds() {
        let _ = user_constant_symbol(16);
    }

    #[test]
    #[should_panic(expected = "User function index out of bounds: 16")]
    fn test_user_function_symbol_panics_out_of_bounds() {
        let _ = user_function_symbol(16);
    }
}