// mathhook_core/core/symbol.rs

use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};

use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::core::commutativity::Commutativity;

8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
13pub enum SymbolType {
14 #[default]
19 Scalar,
20
21 Matrix,
26
27 Operator,
32
33 Quaternion,
38}
39
/// Process-wide interning table mapping a symbol name to its shared `Arc<str>`.
///
/// It starts as `None` because `HashMap::new` is not a `const fn`. The map is
/// created on first use by `Symbol::intern_symbol` via `get_or_insert_with`.
static SYMBOL_CACHE: Mutex<Option<HashMap<String, Arc<str>>>> = Mutex::new(None);

43#[derive(Debug, Clone, PartialEq, Eq, Hash)]
45pub struct Symbol {
46 pub name: Arc<str>,
47 symbol_type: SymbolType,
48}
49
50impl Symbol {
51 #[inline]
65 pub fn new<S: AsRef<str>>(name: S) -> Self {
66 Self::scalar(name)
67 }
68
69 pub fn scalar<S: AsRef<str>>(name: S) -> Self {
81 let name_str = name.as_ref();
82
83 let interned_name = match name_str {
84 "x" | "y" | "z" | "a" | "b" | "c" | "t" | "n" | "i" | "j" | "k" => match name_str {
85 "x" => {
86 static X_SYMBOL: std::sync::OnceLock<Arc<str>> = std::sync::OnceLock::new();
87 X_SYMBOL.get_or_init(|| "x".into()).clone()
88 }
89 "y" => {
90 static Y_SYMBOL: std::sync::OnceLock<Arc<str>> = std::sync::OnceLock::new();
91 Y_SYMBOL.get_or_init(|| "y".into()).clone()
92 }
93 "z" => {
94 static Z_SYMBOL: std::sync::OnceLock<Arc<str>> = std::sync::OnceLock::new();
95 Z_SYMBOL.get_or_init(|| "z".into()).clone()
96 }
97 _ => Self::intern_symbol(name_str),
98 },
99 _ => Self::intern_symbol(name_str),
100 };
101
102 Self {
103 name: interned_name,
104 symbol_type: SymbolType::Scalar,
105 }
106 }
107
108 pub fn matrix<S: AsRef<str>>(name: S) -> Self {
120 let name_str = name.as_ref();
121 Self {
122 name: Self::intern_symbol(name_str),
123 symbol_type: SymbolType::Matrix,
124 }
125 }
126
127 pub fn operator<S: AsRef<str>>(name: S) -> Self {
139 let name_str = name.as_ref();
140 Self {
141 name: Self::intern_symbol(name_str),
142 symbol_type: SymbolType::Operator,
143 }
144 }
145
146 pub fn quaternion<S: AsRef<str>>(name: S) -> Self {
158 let name_str = name.as_ref();
159 Self {
160 name: Self::intern_symbol(name_str),
161 symbol_type: SymbolType::Quaternion,
162 }
163 }
164
165 fn intern_symbol(name: &str) -> Arc<str> {
167 let mut cache_guard = SYMBOL_CACHE
168 .lock()
169 .expect("BUG: Symbol cache lock poisoned - indicates panic during symbol interning in another thread");
170 let cache = cache_guard.get_or_insert_with(HashMap::new);
171
172 if let Some(existing) = cache.get(name) {
173 existing.clone()
174 } else {
175 let arc_str: Arc<str> = name.into();
176 cache.insert(name.to_owned(), arc_str.clone());
177 arc_str
178 }
179 }
180
181 #[inline]
192 pub fn name(&self) -> &str {
193 &self.name
194 }
195
196 #[inline]
210 pub fn symbol_type(&self) -> SymbolType {
211 self.symbol_type
212 }
213
214 #[inline]
229 pub fn commutativity(&self) -> Commutativity {
230 match self.symbol_type {
231 SymbolType::Scalar => Commutativity::Commutative,
232 SymbolType::Matrix | SymbolType::Operator | SymbolType::Quaternion => {
233 Commutativity::Noncommutative
234 }
235 }
236 }
237}
238
239impl Serialize for Symbol {
240 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
241 where
242 S: Serializer,
243 {
244 use serde::ser::SerializeStruct;
245 let mut state = serializer.serialize_struct("Symbol", 2)?;
246 state.serialize_field("name", &*self.name)?;
247 state.serialize_field("symbol_type", &self.symbol_type)?;
248 state.end()
249 }
250}
251
252impl<'de> Deserialize<'de> for Symbol {
253 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
254 where
255 D: Deserializer<'de>,
256 {
257 use serde::de::{self, MapAccess, Visitor};
258
259 struct SymbolVisitor;
260
261 impl<'de> Visitor<'de> for SymbolVisitor {
262 type Value = Symbol;
263
264 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
265 formatter.write_str("a Symbol struct or string")
266 }
267
268 fn visit_str<E>(self, value: &str) -> Result<Symbol, E>
269 where
270 E: de::Error,
271 {
272 Ok(Symbol::new(value))
273 }
274
275 fn visit_map<M>(self, mut map: M) -> Result<Symbol, M::Error>
276 where
277 M: MapAccess<'de>,
278 {
279 let mut name: Option<String> = None;
280 let mut symbol_type: Option<SymbolType> = None;
281
282 while let Some(key) = map.next_key::<String>()? {
283 match key.as_str() {
284 "name" => {
285 name = Some(map.next_value()?);
286 }
287 "symbol_type" => {
288 symbol_type = Some(map.next_value()?);
289 }
290 _ => {
291 let _: serde::de::IgnoredAny = map.next_value()?;
292 }
293 }
294 }
295
296 let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
297 let symbol_type = symbol_type.unwrap_or_default();
298
299 let interned_name = Symbol::intern_symbol(&name);
300 Ok(Symbol {
301 name: interned_name,
302 symbol_type,
303 })
304 }
305 }
306
307 deserializer.deserialize_any(SymbolVisitor)
308 }
309}
310
311impl From<&str> for Symbol {
312 fn from(name: &str) -> Self {
313 Self::new(name)
314 }
315}
316
317impl From<String> for Symbol {
318 fn from(name: String) -> Self {
319 Self::new(name)
320 }
321}
322
#[cfg(test)]
mod symbol_type_tests {
    use super::*;

    #[test]
    fn test_scalar_is_commutative() {
        let x = Symbol::scalar("x");
        assert_eq!(x.symbol_type(), SymbolType::Scalar);
        assert_eq!(x.commutativity(), Commutativity::Commutative);
    }

    #[test]
    fn test_matrix_is_noncommutative() {
        let a = Symbol::matrix("A");
        assert_eq!(a.symbol_type(), SymbolType::Matrix);
        assert_eq!(a.commutativity(), Commutativity::Noncommutative);
    }

    #[test]
    fn test_operator_is_noncommutative() {
        let p = Symbol::operator("p");
        assert_eq!(p.symbol_type(), SymbolType::Operator);
        assert_eq!(p.commutativity(), Commutativity::Noncommutative);
    }

    #[test]
    fn test_quaternion_is_noncommutative() {
        let i = Symbol::quaternion("i");
        assert_eq!(i.symbol_type(), SymbolType::Quaternion);
        assert_eq!(i.commutativity(), Commutativity::Noncommutative);
    }

    #[test]
    fn test_default_symbol_type_is_scalar() {
        assert_eq!(SymbolType::default(), SymbolType::Scalar);
    }

    #[test]
    fn test_backward_compatibility() {
        let x = Symbol::new("x");
        assert_eq!(x.symbol_type(), SymbolType::Scalar);
        assert_eq!(x.commutativity(), Commutativity::Commutative);
    }

    /// Both `From` impls must agree with `Symbol::new` and yield scalars.
    #[test]
    fn test_from_conversions_produce_scalars() {
        let from_str = Symbol::from("y");
        let from_string = Symbol::from(String::from("y"));
        assert_eq!(from_str, from_string);
        assert_eq!(from_str, Symbol::new("y"));
        assert_eq!(from_string.symbol_type(), SymbolType::Scalar);
    }

    /// Names going through the interning cache share one allocation.
    #[test]
    fn test_names_are_interned() {
        let first = Symbol::new("interning_probe");
        let second = Symbol::new("interning_probe");
        assert!(Arc::ptr_eq(&first.name, &second.name));
        assert_eq!(first.name(), "interning_probe");
    }

    /// The `x`/`y`/`z` fast path also hands out a single shared allocation.
    #[test]
    fn test_fast_path_names_are_shared() {
        assert!(Arc::ptr_eq(&Symbol::new("x").name, &Symbol::new("x").name));
    }

    /// The symbol type is part of identity, not just the name.
    #[test]
    fn test_type_participates_in_equality() {
        assert_ne!(Symbol::scalar("B"), Symbol::matrix("B"));
        assert_eq!(Symbol::matrix("B"), Symbol::matrix("B"));
    }
}