// mathhook_core/core/symbol.rs

//! Symbol type for variables and identifiers

use crate::core::commutativity::Commutativity;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock, PoisonError};

8/// Type of symbol (determines commutativity)
9///
10/// Symbols can represent different mathematical objects with different algebraic properties.
11/// The symbol type determines whether operations involving this symbol are commutative.
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
13pub enum SymbolType {
14    /// Scalar variable (default) - commutative
15    ///
16    /// Examples: x, y, z, theta
17    /// Properties: x*y = y*x
18    #[default]
19    Scalar,
20
21    /// Matrix variable - noncommutative
22    ///
23    /// Examples: A, B, M (typically uppercase)
24    /// Properties: A*B ≠ B*A in general
25    Matrix,
26
27    /// Quantum operator - noncommutative
28    ///
29    /// Examples: x, p, H (position, momentum, Hamiltonian)
30    /// Properties: `[x,p]` = xp - px ≠ 0
31    Operator,
32
33    /// Quaternion - noncommutative
34    ///
35    /// Examples: i, j, k
36    /// Properties: ij = k, ji = -k
37    Quaternion,
38}
39
/// Process-wide intern table: symbol name -> the shared `Arc<str>` handed out for it.
///
/// It starts out empty (`None`) and the map is created on first use, which keeps the
/// static's initializer usable in a `const` context.
static SYMBOL_CACHE: Mutex<Option<HashMap<String, Arc<str>>>> = Mutex::new(Option::None);

43/// Mathematical symbol/variable with efficient string sharing
44#[derive(Debug, Clone, PartialEq, Eq, Hash)]
45pub struct Symbol {
46    pub name: Arc<str>,
47    symbol_type: SymbolType,
48}
49
50impl Symbol {
51    /// Create a new scalar symbol (default behavior, backward compatible)
52    ///
53    /// **Note**: Prefer using `symbol!(x)` macro in application code.
54    /// This method is kept for backward compatibility and internal use.
55    ///
56    /// # Examples
57    ///
58    /// ```rust
59    /// use mathhook_core::symbol;
60    ///
61    /// let x = symbol!(x);
62    /// let alpha = symbol!(alpha);
63    /// ```
64    #[inline]
65    pub fn new<S: AsRef<str>>(name: S) -> Self {
66        Self::scalar(name)
67    }
68
69    /// Create a scalar symbol (commutative)
70    ///
71    /// # Examples
72    ///
73    /// ```
74    /// use mathhook_core::core::symbol::Symbol;
75    /// use mathhook_core::core::commutativity::Commutativity;
76    ///
77    /// let x = Symbol::scalar("x");
78    /// assert_eq!(x.commutativity(), Commutativity::Commutative);
79    /// ```
80    pub fn scalar<S: AsRef<str>>(name: S) -> Self {
81        let name_str = name.as_ref();
82
83        let interned_name = match name_str {
84            "x" | "y" | "z" | "a" | "b" | "c" | "t" | "n" | "i" | "j" | "k" => match name_str {
85                "x" => {
86                    static X_SYMBOL: std::sync::OnceLock<Arc<str>> = std::sync::OnceLock::new();
87                    X_SYMBOL.get_or_init(|| "x".into()).clone()
88                }
89                "y" => {
90                    static Y_SYMBOL: std::sync::OnceLock<Arc<str>> = std::sync::OnceLock::new();
91                    Y_SYMBOL.get_or_init(|| "y".into()).clone()
92                }
93                "z" => {
94                    static Z_SYMBOL: std::sync::OnceLock<Arc<str>> = std::sync::OnceLock::new();
95                    Z_SYMBOL.get_or_init(|| "z".into()).clone()
96                }
97                _ => Self::intern_symbol(name_str),
98            },
99            _ => Self::intern_symbol(name_str),
100        };
101
102        Self {
103            name: interned_name,
104            symbol_type: SymbolType::Scalar,
105        }
106    }
107
108    /// Create a matrix symbol (noncommutative)
109    ///
110    /// # Examples
111    ///
112    /// ```
113    /// use mathhook_core::core::symbol::Symbol;
114    /// use mathhook_core::core::commutativity::Commutativity;
115    ///
116    /// let A = Symbol::matrix("A");
117    /// assert_eq!(A.commutativity(), Commutativity::Noncommutative);
118    /// ```
119    pub fn matrix<S: AsRef<str>>(name: S) -> Self {
120        let name_str = name.as_ref();
121        Self {
122            name: Self::intern_symbol(name_str),
123            symbol_type: SymbolType::Matrix,
124        }
125    }
126
127    /// Create an operator symbol (noncommutative)
128    ///
129    /// # Examples
130    ///
131    /// ```
132    /// use mathhook_core::core::symbol::Symbol;
133    /// use mathhook_core::core::commutativity::Commutativity;
134    ///
135    /// let p = Symbol::operator("p");
136    /// assert_eq!(p.commutativity(), Commutativity::Noncommutative);
137    /// ```
138    pub fn operator<S: AsRef<str>>(name: S) -> Self {
139        let name_str = name.as_ref();
140        Self {
141            name: Self::intern_symbol(name_str),
142            symbol_type: SymbolType::Operator,
143        }
144    }
145
146    /// Create a quaternion symbol (noncommutative)
147    ///
148    /// # Examples
149    ///
150    /// ```
151    /// use mathhook_core::core::symbol::Symbol;
152    /// use mathhook_core::core::commutativity::Commutativity;
153    ///
154    /// let i = Symbol::quaternion("i");
155    /// assert_eq!(i.commutativity(), Commutativity::Noncommutative);
156    /// ```
157    pub fn quaternion<S: AsRef<str>>(name: S) -> Self {
158        let name_str = name.as_ref();
159        Self {
160            name: Self::intern_symbol(name_str),
161            symbol_type: SymbolType::Quaternion,
162        }
163    }
164
165    /// Internal method to intern symbols using the global cache
166    fn intern_symbol(name: &str) -> Arc<str> {
167        let mut cache_guard = SYMBOL_CACHE
168            .lock()
169            .expect("BUG: Symbol cache lock poisoned - indicates panic during symbol interning in another thread");
170        let cache = cache_guard.get_or_insert_with(HashMap::new);
171
172        if let Some(existing) = cache.get(name) {
173            existing.clone()
174        } else {
175            let arc_str: Arc<str> = name.into();
176            cache.insert(name.to_owned(), arc_str.clone());
177            arc_str
178        }
179    }
180
181    /// Get the symbol name
182    ///
183    /// # Examples
184    ///
185    /// ```rust
186    /// use mathhook_core::symbol;
187    ///
188    /// let x = symbol!(x);
189    /// assert_eq!(x.name(), "x");
190    /// ```
191    #[inline]
192    pub fn name(&self) -> &str {
193        &self.name
194    }
195
196    /// Get the type of this symbol
197    ///
198    /// # Examples
199    ///
200    /// ```
201    /// use mathhook_core::core::symbol::{Symbol, SymbolType};
202    ///
203    /// let x = Symbol::scalar("x");
204    /// assert_eq!(x.symbol_type(), SymbolType::Scalar);
205    ///
206    /// let A = Symbol::matrix("A");
207    /// assert_eq!(A.symbol_type(), SymbolType::Matrix);
208    /// ```
209    #[inline]
210    pub fn symbol_type(&self) -> SymbolType {
211        self.symbol_type
212    }
213
214    /// Get commutativity of this symbol
215    ///
216    /// # Examples
217    ///
218    /// ```
219    /// use mathhook_core::core::symbol::Symbol;
220    /// use mathhook_core::core::commutativity::Commutativity;
221    ///
222    /// let x = Symbol::scalar("x");
223    /// assert_eq!(x.commutativity(), Commutativity::Commutative);
224    ///
225    /// let A = Symbol::matrix("A");
226    /// assert_eq!(A.commutativity(), Commutativity::Noncommutative);
227    /// ```
228    #[inline]
229    pub fn commutativity(&self) -> Commutativity {
230        match self.symbol_type {
231            SymbolType::Scalar => Commutativity::Commutative,
232            SymbolType::Matrix | SymbolType::Operator | SymbolType::Quaternion => {
233                Commutativity::Noncommutative
234            }
235        }
236    }
237}
238
239impl Serialize for Symbol {
240    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
241    where
242        S: Serializer,
243    {
244        use serde::ser::SerializeStruct;
245        let mut state = serializer.serialize_struct("Symbol", 2)?;
246        state.serialize_field("name", &*self.name)?;
247        state.serialize_field("symbol_type", &self.symbol_type)?;
248        state.end()
249    }
250}
251
252impl<'de> Deserialize<'de> for Symbol {
253    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
254    where
255        D: Deserializer<'de>,
256    {
257        use serde::de::{self, MapAccess, Visitor};
258
259        struct SymbolVisitor;
260
261        impl<'de> Visitor<'de> for SymbolVisitor {
262            type Value = Symbol;
263
264            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
265                formatter.write_str("a Symbol struct or string")
266            }
267
268            fn visit_str<E>(self, value: &str) -> Result<Symbol, E>
269            where
270                E: de::Error,
271            {
272                Ok(Symbol::new(value))
273            }
274
275            fn visit_map<M>(self, mut map: M) -> Result<Symbol, M::Error>
276            where
277                M: MapAccess<'de>,
278            {
279                let mut name: Option<String> = None;
280                let mut symbol_type: Option<SymbolType> = None;
281
282                while let Some(key) = map.next_key::<String>()? {
283                    match key.as_str() {
284                        "name" => {
285                            name = Some(map.next_value()?);
286                        }
287                        "symbol_type" => {
288                            symbol_type = Some(map.next_value()?);
289                        }
290                        _ => {
291                            let _: serde::de::IgnoredAny = map.next_value()?;
292                        }
293                    }
294                }
295
296                let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
297                let symbol_type = symbol_type.unwrap_or_default();
298
299                let interned_name = Symbol::intern_symbol(&name);
300                Ok(Symbol {
301                    name: interned_name,
302                    symbol_type,
303                })
304            }
305        }
306
307        deserializer.deserialize_any(SymbolVisitor)
308    }
309}
310
311impl From<&str> for Symbol {
312    fn from(name: &str) -> Self {
313        Self::new(name)
314    }
315}
316
317impl From<String> for Symbol {
318    fn from(name: String) -> Self {
319        Self::new(name)
320    }
321}
322
#[cfg(test)]
mod symbol_type_tests {
    use super::*;

    #[test]
    fn test_scalar_is_commutative() {
        let x = Symbol::scalar("x");
        assert_eq!(x.symbol_type(), SymbolType::Scalar);
        assert_eq!(x.commutativity(), Commutativity::Commutative);
    }

    #[test]
    fn test_matrix_is_noncommutative() {
        let a = Symbol::matrix("A");
        assert_eq!(a.symbol_type(), SymbolType::Matrix);
        assert_eq!(a.commutativity(), Commutativity::Noncommutative);
    }

    #[test]
    fn test_operator_is_noncommutative() {
        let p = Symbol::operator("p");
        assert_eq!(p.symbol_type(), SymbolType::Operator);
        assert_eq!(p.commutativity(), Commutativity::Noncommutative);
    }

    #[test]
    fn test_quaternion_is_noncommutative() {
        let i = Symbol::quaternion("i");
        assert_eq!(i.symbol_type(), SymbolType::Quaternion);
        assert_eq!(i.commutativity(), Commutativity::Noncommutative);
    }

    #[test]
    fn test_default_symbol_type_is_scalar() {
        assert_eq!(SymbolType::default(), SymbolType::Scalar);
    }

    #[test]
    fn test_backward_compatibility() {
        let x = Symbol::new("x");
        assert_eq!(x.symbol_type(), SymbolType::Scalar);
        assert_eq!(x.commutativity(), Commutativity::Commutative);
    }

    /// Equality must take the symbol kind into account, not just the name.
    #[test]
    fn test_symbols_differing_only_in_type_are_not_equal() {
        assert_ne!(Symbol::scalar("q"), Symbol::matrix("q"));
        assert_ne!(Symbol::operator("q"), Symbol::quaternion("q"));
    }

    /// Repeated construction of the same scalar name must reuse one allocation,
    /// both for a fast-path name (`x`) and for a name that goes through the cache.
    #[test]
    fn test_repeated_names_share_one_allocation() {
        for name in ["x", "alpha_interning_test"] {
            let first = Symbol::scalar(name);
            let second = Symbol::new(name);
            assert!(std::sync::Arc::ptr_eq(&first.name, &second.name));
        }
    }

    #[test]
    fn test_name_and_from_conversions() {
        assert_eq!(Symbol::operator("H").name(), "H");
        assert_eq!(Symbol::from("y"), Symbol::new("y"));
        assert_eq!(Symbol::from(String::from("theta")), Symbol::scalar("theta"));
    }
}