1use std::fmt;
6
/// Operand-count class of a [`Symbol`], as assigned by `Symbol::seft`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Seft {
    /// Constants: digits, named constants, `x` and user constants (no operands).
    A,
    /// Unary operators and user functions (one operand).
    B,
    /// Binary operators (two operands).
    C,
}
17
/// Number class of a value, ordered from the weakest guarantee to the strongest.
///
/// The derived ordering follows the discriminants:
/// `Transcendental` (0) < `Liouvillian` < `Elementary` < `Algebraic` <
/// `Constructible` < `Rational` < `Integer` (6). `NumType::combine` keeps the
/// lower of two classes. `Symbol::result_type` also uses `Transcendental` as the
/// class for values with no stronger guarantee (for example `x` and user constants).
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[repr(u8)]
pub enum NumType {
    /// Transcendental, or no stronger class is known.
    Transcendental = 0,
    /// Liouvillian number.
    Liouvillian = 1,
    /// Elementary number.
    Elementary = 2,
    /// Algebraic number.
    Algebraic = 3,
    /// Constructible number.
    Constructible = 4,
    /// Rational number.
    Rational = 5,
    /// Integer (the strongest class).
    Integer = 6,
}
37
38impl NumType {
39 #[inline]
41 pub fn combine(self, other: Self) -> Self {
42 std::cmp::min(self, other)
43 }
44
45 #[inline]
47 pub fn is_at_least(self, required: Self) -> bool {
48 self >= required
49 }
50}
51
/// A single token of an expression.
///
/// Each variant's discriminant is its byte encoding (the value `u8::from` yields and
/// `Symbol::from_byte` accepts). Built-in constants and operators use printable ASCII
/// characters. User constants occupy `0x80..=0x8F` and user functions `0x90..=0x9F`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Symbol {
    // Digits.
    One = b'1',
    Two = b'2',
    Three = b'3',
    Four = b'4',
    Five = b'5',
    Six = b'6',
    Seven = b'7',
    Eight = b'8',
    Nine = b'9',

    // Named constants.
    Pi = b'p',
    E = b'e',
    Phi = b'f',
    Gamma = b'g',
    Plastic = b'P',
    Apery = b'z',
    Catalan = b'G',

    /// The variable `x`.
    X = b'x',

    // User-defined constants, 0x80..=0x8F.
    UserConstant0 = 0x80,
    UserConstant1 = 0x81,
    UserConstant2 = 0x82,
    UserConstant3 = 0x83,
    UserConstant4 = 0x84,
    UserConstant5 = 0x85,
    UserConstant6 = 0x86,
    UserConstant7 = 0x87,
    UserConstant8 = 0x88,
    UserConstant9 = 0x89,
    UserConstant10 = 0x8A,
    UserConstant11 = 0x8B,
    UserConstant12 = 0x8C,
    UserConstant13 = 0x8D,
    UserConstant14 = 0x8E,
    UserConstant15 = 0x8F,

    // User-defined unary functions, 0x90..=0x9F.
    UserFunction0 = 0x90,
    UserFunction1 = 0x91,
    UserFunction2 = 0x92,
    UserFunction3 = 0x93,
    UserFunction4 = 0x94,
    UserFunction5 = 0x95,
    UserFunction6 = 0x96,
    UserFunction7 = 0x97,
    UserFunction8 = 0x98,
    UserFunction9 = 0x99,
    UserFunction10 = 0x9A,
    UserFunction11 = 0x9B,
    UserFunction12 = 0x9C,
    UserFunction13 = 0x9D,
    UserFunction14 = 0x9E,
    UserFunction15 = 0x9F,

    // Built-in unary operators.
    Neg = b'n',
    Recip = b'r',
    Sqrt = b'q',
    Square = b's',
    Ln = b'l',
    Exp = b'E',
    SinPi = b'S',
    CosPi = b'C',
    TanPi = b'T',
    LambertW = b'W',

    // Built-in binary operators.
    Add = b'+',
    Sub = b'-',
    Mul = b'*',
    Div = b'/',
    Pow = b'^',
    Root = b'v',
    Log = b'L',
    Atan2 = b'A',
}
140
141impl Symbol {
142 #[inline]
144 pub const fn seft(self) -> Seft {
145 use Symbol::*;
146 match self {
147 One | Two | Three | Four | Five | Six | Seven | Eight | Nine | Pi | E | Phi | Gamma
148 | Plastic | Apery | Catalan | X | UserConstant0 | UserConstant1 | UserConstant2
149 | UserConstant3 | UserConstant4 | UserConstant5 | UserConstant6 | UserConstant7
150 | UserConstant8 | UserConstant9 | UserConstant10 | UserConstant11 | UserConstant12
151 | UserConstant13 | UserConstant14 | UserConstant15 => Seft::A,
152
153 Neg | Recip | Sqrt | Square | Ln | Exp | SinPi | CosPi | TanPi | LambertW
154 | UserFunction0 | UserFunction1 | UserFunction2 | UserFunction3 | UserFunction4
155 | UserFunction5 | UserFunction6 | UserFunction7 | UserFunction8 | UserFunction9
156 | UserFunction10 | UserFunction11 | UserFunction12 | UserFunction13
157 | UserFunction14 | UserFunction15 => Seft::B,
158
159 Add | Sub | Mul | Div | Pow | Root | Log | Atan2 => Seft::C,
160 }
161 }
162
163 #[inline]
235 pub fn weight(self) -> u32 {
236 self.default_weight()
237 }
238
239 #[inline]
243 pub const fn default_weight(self) -> u32 {
244 use Symbol::*;
245 match self {
246 One => 10,
250 Two => 13,
251 Three => 15,
252 Four => 16,
253 Five => 17,
254 Six => 18,
255 Seven => 18,
256 Eight => 19,
257 Nine => 19,
258
259 Pi => 14,
261 E => 16,
262 Phi => 18,
263
264 Gamma => 20, Plastic => 20, Apery => 22, Catalan => 20, X => 15,
273
274 UserConstant0 | UserConstant1 | UserConstant2 | UserConstant3 | UserConstant4
276 | UserConstant5 | UserConstant6 | UserConstant7 | UserConstant8 | UserConstant9
277 | UserConstant10 | UserConstant11 | UserConstant12 | UserConstant13
278 | UserConstant14 | UserConstant15 => 16,
279
280 Neg => 7,
282 Recip => 7,
283 Sqrt => 9,
284 Square => 9,
285 Ln => 13,
286 Exp => 13,
287 SinPi => 13,
288 CosPi => 13,
289 TanPi => 16,
290 LambertW => 20, UserFunction0 | UserFunction1 | UserFunction2 | UserFunction3 | UserFunction4
294 | UserFunction5 | UserFunction6 | UserFunction7 | UserFunction8 | UserFunction9
295 | UserFunction10 | UserFunction11 | UserFunction12 | UserFunction13
296 | UserFunction14 | UserFunction15 => 16,
297
298 Add => 4,
300 Sub => 5,
301 Mul => 4,
302 Div => 5,
303 Pow => 6,
304 Root => 7,
305 Log => 9,
306 Atan2 => 9,
307 }
308 }
309
310 #[inline]
316 pub fn legacy_parity_weight(self) -> i32 {
317 use Symbol::*;
318 match self {
319 One => 0,
321 Two => 3,
322 Three => 5,
323 Four => 6,
324 Five => 7,
325 Six => 8,
326 Seven => 8,
327 Eight => 9,
328 Nine => 9,
329 Pi => 4,
330 E => 6,
331 Phi => 8,
332 X => 5,
333
334 Neg => -3,
336 Recip => -3,
337 Square => -1,
338 Sqrt => -1,
339 Ln => 3,
340 Exp => 3,
341 SinPi => 3,
342 CosPi => 3,
343 TanPi => 6,
344 LambertW => 5,
345
346 Add => -6,
348 Sub => -5,
349 Mul => -6,
350 Div => -5,
351 Pow => -4,
352 Root => -3,
353 Log => -1,
354 Atan2 => -1,
355
356 _ => self.weight() as i32,
358 }
359 }
360
361 pub fn result_type(self, arg_types: &[NumType]) -> NumType {
363 use NumType::*;
364 use Symbol::*;
365
366 match self {
367 One | Two | Three | Four | Five | Six | Seven | Eight | Nine => Integer,
369
370 Pi | E => Transcendental,
372
373 Phi => Algebraic,
375
376 Gamma => Transcendental,
379 Plastic => Algebraic,
381 Apery => Transcendental,
383 Catalan => Transcendental,
385
386 X => Transcendental,
388
389 UserConstant0 | UserConstant1 | UserConstant2 | UserConstant3 | UserConstant4
391 | UserConstant5 | UserConstant6 | UserConstant7 | UserConstant8 | UserConstant9
392 | UserConstant10 | UserConstant11 | UserConstant12 | UserConstant13
393 | UserConstant14 | UserConstant15 => Transcendental,
394
395 Neg | Add | Sub | Mul => {
397 if arg_types.iter().all(|t| *t == Integer) {
398 Integer
399 } else {
400 arg_types.iter().copied().fold(Integer, NumType::combine)
401 }
402 }
403
404 Div | Recip => {
406 let base = arg_types.iter().copied().fold(Integer, NumType::combine);
407 if base == Integer {
408 Rational
409 } else {
410 base
411 }
412 }
413
414 Sqrt => {
416 let base = arg_types.iter().copied().fold(Integer, NumType::combine);
417 if base.is_at_least(Constructible) {
418 Constructible
419 } else if base.is_at_least(Algebraic) {
420 Algebraic
421 } else {
422 base
423 }
424 }
425
426 Square => arg_types.iter().copied().fold(Integer, NumType::combine),
428
429 Root => Algebraic,
431
432 Pow => {
434 if arg_types.len() >= 2 && arg_types[1] == Integer {
438 arg_types[0]
439 } else {
440 Transcendental
441 }
442 }
443
444 Ln | Exp | SinPi | CosPi | TanPi | Log | LambertW | Atan2 => Transcendental,
446
447 UserFunction0 | UserFunction1 | UserFunction2 | UserFunction3 | UserFunction4
449 | UserFunction5 | UserFunction6 | UserFunction7 | UserFunction8 | UserFunction9
450 | UserFunction10 | UserFunction11 | UserFunction12 | UserFunction13
451 | UserFunction14 | UserFunction15 => Transcendental,
452 }
453 }
454
455 pub const fn inherent_type(self) -> NumType {
458 use NumType::*;
459 use Symbol::*;
460
461 match self {
462 One | Two | Three | Four | Five | Six | Seven | Eight | Nine => Integer,
464
465 Pi | E => Transcendental,
467
468 Phi => Algebraic,
470
471 Gamma => Transcendental,
473 Plastic => Algebraic,
474 Apery => Transcendental,
475 Catalan => Transcendental,
476
477 X => Transcendental,
479
480 UserConstant0 | UserConstant1 | UserConstant2 | UserConstant3 | UserConstant4
482 | UserConstant5 | UserConstant6 | UserConstant7 | UserConstant8 | UserConstant9
483 | UserConstant10 | UserConstant11 | UserConstant12 | UserConstant13
484 | UserConstant14 | UserConstant15 => Transcendental,
485
486 _ => Transcendental,
489 }
490 }
491
492 pub const fn name(self) -> &'static str {
494 use Symbol::*;
495 match self {
496 One => "1",
497 Two => "2",
498 Three => "3",
499 Four => "4",
500 Five => "5",
501 Six => "6",
502 Seven => "7",
503 Eight => "8",
504 Nine => "9",
505 Pi => "pi",
506 E => "e",
507 Phi => "phi",
508 Gamma => "gamma",
509 Plastic => "plastic",
510 Apery => "apery",
511 Catalan => "catalan",
512 X => "x",
513 Neg => "-",
514 Recip => "1/",
515 Sqrt => "sqrt",
516 Square => "^2",
517 Ln => "ln",
518 Exp => "e^",
519 SinPi => "sinpi",
520 CosPi => "cospi",
521 TanPi => "tanpi",
522 LambertW => "W",
523 Add => "+",
524 Sub => "-",
525 Mul => "*",
526 Div => "/",
527 Pow => "^",
528 Root => "\"/",
529 Log => "log_",
530 Atan2 => "atan2",
531 UserConstant0 => "u0",
533 UserConstant1 => "u1",
534 UserConstant2 => "u2",
535 UserConstant3 => "u3",
536 UserConstant4 => "u4",
537 UserConstant5 => "u5",
538 UserConstant6 => "u6",
539 UserConstant7 => "u7",
540 UserConstant8 => "u8",
541 UserConstant9 => "u9",
542 UserConstant10 => "u10",
543 UserConstant11 => "u11",
544 UserConstant12 => "u12",
545 UserConstant13 => "u13",
546 UserConstant14 => "u14",
547 UserConstant15 => "u15",
548 UserFunction0 => "f0",
550 UserFunction1 => "f1",
551 UserFunction2 => "f2",
552 UserFunction3 => "f3",
553 UserFunction4 => "f4",
554 UserFunction5 => "f5",
555 UserFunction6 => "f6",
556 UserFunction7 => "f7",
557 UserFunction8 => "f8",
558 UserFunction9 => "f9",
559 UserFunction10 => "f10",
560 UserFunction11 => "f11",
561 UserFunction12 => "f12",
562 UserFunction13 => "f13",
563 UserFunction14 => "f14",
564 UserFunction15 => "f15",
565 }
566 }
567
568 pub fn display_name(self) -> String {
572 self.name().to_string()
573 }
574
575 pub fn from_byte(b: u8) -> Option<Self> {
577 use Symbol::*;
578 Some(match b {
579 b'1' => One,
580 b'2' => Two,
581 b'3' => Three,
582 b'4' => Four,
583 b'5' => Five,
584 b'6' => Six,
585 b'7' => Seven,
586 b'8' => Eight,
587 b'9' => Nine,
588 b'p' => Pi,
589 b'e' => E,
590 b'f' => Phi,
591 b'x' => X,
592 b'g' => Gamma,
593 b'P' => Plastic,
594 b'z' => Apery,
595 b'G' => Catalan,
596 b'n' => Neg,
597 b'r' => Recip,
598 b'q' => Sqrt,
599 b's' => Square,
600 b'l' => Ln,
601 b'E' => Exp,
602 b'S' => SinPi,
603 b'C' => CosPi,
604 b'T' => TanPi,
605 b'W' => LambertW,
606 b'+' => Add,
607 b'-' => Sub,
608 b'*' => Mul,
609 b'/' => Div,
610 b'^' => Pow,
611 b'v' => Root,
612 b'L' => Log,
613 b'A' => Atan2,
614 128 => UserConstant0,
616 129 => UserConstant1,
617 130 => UserConstant2,
618 131 => UserConstant3,
619 132 => UserConstant4,
620 133 => UserConstant5,
621 134 => UserConstant6,
622 135 => UserConstant7,
623 136 => UserConstant8,
624 137 => UserConstant9,
625 138 => UserConstant10,
626 139 => UserConstant11,
627 140 => UserConstant12,
628 141 => UserConstant13,
629 142 => UserConstant14,
630 143 => UserConstant15,
631 144 => UserFunction0,
635 145 => UserFunction1,
636 146 => UserFunction2,
637 147 => UserFunction3,
638 148 => UserFunction4,
639 149 => UserFunction5,
640 150 => UserFunction6,
641 151 => UserFunction7,
642 152 => UserFunction8,
643 153 => UserFunction9,
644 154 => UserFunction10,
645 155 => UserFunction11,
646 156 => UserFunction12,
647 157 => UserFunction13,
648 158 => UserFunction14,
649 159 => UserFunction15,
650 b'H' => UserFunction0,
653 b'I' => UserFunction1,
654 b'J' => UserFunction2,
655 b'K' => UserFunction3,
656 b'M' => UserFunction4,
657 b'N' => UserFunction5,
658 b'O' => UserFunction6,
659 b'Q' => UserFunction7,
660 b'R' => UserFunction8,
661 b'U' => UserFunction9,
662 b'V' => UserFunction10,
663 b'Y' => UserFunction11,
664 b'Z' => UserFunction12,
665 b'B' => UserFunction13,
666 b'D' => UserFunction14,
667 b'F' => UserFunction15,
668 _ => return None,
669 })
670 }
671
672 pub fn user_constant_index(self) -> Option<u8> {
674 use Symbol::*;
675 match self {
676 UserConstant0 => Some(0),
677 UserConstant1 => Some(1),
678 UserConstant2 => Some(2),
679 UserConstant3 => Some(3),
680 UserConstant4 => Some(4),
681 UserConstant5 => Some(5),
682 UserConstant6 => Some(6),
683 UserConstant7 => Some(7),
684 UserConstant8 => Some(8),
685 UserConstant9 => Some(9),
686 UserConstant10 => Some(10),
687 UserConstant11 => Some(11),
688 UserConstant12 => Some(12),
689 UserConstant13 => Some(13),
690 UserConstant14 => Some(14),
691 UserConstant15 => Some(15),
692 _ => None,
693 }
694 }
695
696 pub fn user_function_index(self) -> Option<u8> {
698 use Symbol::*;
699 match self {
700 UserFunction0 => Some(0),
701 UserFunction1 => Some(1),
702 UserFunction2 => Some(2),
703 UserFunction3 => Some(3),
704 UserFunction4 => Some(4),
705 UserFunction5 => Some(5),
706 UserFunction6 => Some(6),
707 UserFunction7 => Some(7),
708 UserFunction8 => Some(8),
709 UserFunction9 => Some(9),
710 UserFunction10 => Some(10),
711 UserFunction11 => Some(11),
712 UserFunction12 => Some(12),
713 UserFunction13 => Some(13),
714 UserFunction14 => Some(14),
715 UserFunction15 => Some(15),
716 _ => None,
717 }
718 }
719
720 pub fn constants() -> &'static [Symbol] {
722 use Symbol::*;
723 &[
724 One, Two, Three, Four, Five, Six, Seven, Eight, Nine, Pi, E, Phi, Gamma, Plastic,
725 Apery, Catalan,
726 ]
727 }
728
729 pub fn unary_ops() -> &'static [Symbol] {
731 use Symbol::*;
732 &[
733 Neg, Recip, Sqrt, Square, Ln, Exp, SinPi, CosPi, TanPi, LambertW,
734 ]
735 }
736
737 pub fn binary_ops() -> &'static [Symbol] {
739 use Symbol::*;
740 &[Add, Sub, Mul, Div, Pow, Root, Log, Atan2]
741 }
742}
743
744impl fmt::Display for Symbol {
745 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
746 write!(f, "{}", self.name())
747 }
748}
749
750impl From<Symbol> for u8 {
751 fn from(s: Symbol) -> u8 {
752 s as u8
753 }
754}
755
#[cfg(test)]
mod tests {
    use super::*;

    /// Default weights must match the original RIES weight table.
    #[test]
    fn test_weights_match_original_ries() {
        use Symbol::*;
        let table: &[(Symbol, u32)] = &[
            (One, 10),
            (Two, 13),
            (Three, 15),
            (Four, 16),
            (Five, 17),
            (Six, 18),
            (Seven, 18),
            (Eight, 19),
            (Nine, 19),
            (Pi, 14),
            (E, 16),
            (Phi, 18),
            (X, 15),
            (Add, 4),
            (Mul, 4),
            (Sub, 5),
            (Div, 5),
            (Pow, 6),
            (Root, 7),
            (Atan2, 9),
            (Log, 9),
            (Neg, 7),
            (Recip, 7),
            (Sqrt, 9),
            (Square, 9),
            (Ln, 13),
            (Exp, 13),
            (SinPi, 13),
            (CosPi, 13),
            (TanPi, 16),
        ];
        for &(sym, expected) in table {
            assert_eq!(sym.default_weight(), expected, "unexpected weight for {:?}", sym);
        }
    }

    /// Every built-in symbol decodes back from its own byte encoding.
    #[test]
    fn test_symbol_roundtrip() {
        let every_symbol = Symbol::constants()
            .iter()
            .chain(Symbol::unary_ops())
            .chain(Symbol::binary_ops())
            .copied();
        for sym in every_symbol {
            assert_eq!(Symbol::from_byte(u8::from(sym)), Some(sym));
        }
    }

    /// Number classes are ordered from strongest (Integer) downwards.
    #[test]
    fn test_num_type_ordering() {
        let descending = [
            NumType::Integer,
            NumType::Rational,
            NumType::Algebraic,
            NumType::Transcendental,
        ];
        for pair in descending.windows(2) {
            assert!(pair[0] > pair[1]);
        }
    }

    /// One representative symbol per operand-count class.
    #[test]
    fn test_seft() {
        let cases = [
            (Symbol::Pi, Seft::A),
            (Symbol::Sqrt, Seft::B),
            (Symbol::Add, Seft::C),
        ];
        for &(sym, expected) in &cases {
            assert_eq!(sym.seft(), expected);
        }
    }

    /// Result class of `base ^ exponent`.
    fn pow_type(base: NumType, exponent: NumType) -> NumType {
        Symbol::Pow.result_type(&[base, exponent])
    }

    #[test]
    fn test_pow_result_type_algebraic_base_integer_exponent() {
        assert_eq!(
            pow_type(NumType::Algebraic, NumType::Integer),
            NumType::Algebraic,
            "Algebraic^Integer should be Algebraic"
        );
    }

    #[test]
    fn test_pow_result_type_integer_base_algebraic_exponent() {
        assert_eq!(
            pow_type(NumType::Integer, NumType::Algebraic),
            NumType::Transcendental,
            "Integer^Algebraic should be Transcendental"
        );
    }

    #[test]
    fn test_pow_result_type_integer_base_integer_exponent() {
        assert_eq!(
            pow_type(NumType::Integer, NumType::Integer),
            NumType::Integer,
            "Integer^Integer should be Integer"
        );
    }
}
867}