Skip to main content

litex/rational_expression/
evaluate.rs

1use crate::prelude::*;
2use crate::rational_expression::evaluate_div::safe_div;
3
4impl Obj {
5    pub fn evaluate_to_normalized_decimal_number(&self) -> Option<Number> {
6        let result = match self {
7            Obj::Number(number) => Some(number.clone()),
8            Obj::Add(add) => {
9                let left_number = add.left.evaluate_to_normalized_decimal_number();
10                let right_number = add.right.evaluate_to_normalized_decimal_number();
11                if let (Some(left_number), Some(right_number)) = (left_number, right_number) {
12                    let a = &left_number.normalized_value;
13                    let b = &right_number.normalized_value;
14                    let sum = if normalized_decimal_str_is_non_negative(a)
15                        && normalized_decimal_str_is_non_negative(b)
16                    {
17                        // `add_decimal_str_and_normalize` 仅适用于两操作数非负(见函数注释)
18                        add_decimal_str_and_normalize(a, b)
19                    } else {
20                        add_signed_decimal_str(a, b)
21                    };
22                    Some(Number::new(sum))
23                } else {
24                    None
25                }
26            }
27            Obj::Sub(sub) => {
28                let left_number = sub.left.evaluate_to_normalized_decimal_number();
29                let right_number = sub.right.evaluate_to_normalized_decimal_number();
30                if let (Some(left_number), Some(right_number)) = (left_number, right_number) {
31                    let a = &left_number.normalized_value;
32                    let b = &right_number.normalized_value;
33                    let diff = if normalized_decimal_str_is_non_negative(a)
34                        && normalized_decimal_str_is_non_negative(b)
35                    {
36                        // `sub_decimal_str_and_normalize` 的竖式比较同样按非负量设计
37                        sub_decimal_str_and_normalize(a, b)
38                    } else {
39                        sub_signed_decimal_str(a, b)
40                    };
41                    Some(Number::new(diff))
42                } else {
43                    None
44                }
45            }
46            Obj::Mul(mul) => {
47                let left_number = mul.left.evaluate_to_normalized_decimal_number();
48                let right_number = mul.right.evaluate_to_normalized_decimal_number();
49                if let (Some(left_number), Some(right_number)) = (left_number, right_number) {
50                    Some(Number::new(mul_signed_decimal_str(
51                        &left_number.normalized_value,
52                        &right_number.normalized_value,
53                    )))
54                } else {
55                    None
56                }
57            }
58            Obj::Mod(mod_obj) => {
59                let left_number = mod_obj.left.evaluate_to_normalized_decimal_number();
60                let right_number = mod_obj.right.evaluate_to_normalized_decimal_number();
61                if let (Some(left_number), Some(right_number)) = (left_number, right_number) {
62                    Some(Number::new(mod_decimal_str_and_normalize(
63                        &left_number.normalized_value,
64                        &right_number.normalized_value,
65                    )))
66                } else {
67                    None
68                }
69            }
70            Obj::Pow(pow_obj) => {
71                let base_number = pow_obj.base.evaluate_to_normalized_decimal_number();
72                let exponent_number = pow_obj.exponent.evaluate_to_normalized_decimal_number();
73                if let (Some(base_number), Some(exponent_number)) = (base_number, exponent_number) {
74                    Some(Number::new(pow_decimal_str_and_normalize(
75                        &base_number.normalized_value,
76                        &exponent_number.normalized_value,
77                    )))
78                } else {
79                    None
80                }
81            }
82            Obj::Div(div) => {
83                let left_number = div.left.evaluate_to_normalized_decimal_number();
84                let right_number = div.right.evaluate_to_normalized_decimal_number();
85                if let (Some(left_number), Some(right_number)) = (left_number, right_number) {
86                    let exact_quotient_string = safe_div(
87                        &left_number.normalized_value,
88                        &right_number.normalized_value,
89                    );
90
91                    if let Some(exact_quotient_string) = exact_quotient_string {
92                        Some(Number::new(exact_quotient_string))
93                    } else {
94                        None
95                    }
96                } else {
97                    return None;
98                }
99            }
100            Obj::CartDim(cart_dim) => match &*cart_dim.set {
101                Obj::Cart(cart) => Some(Number::new(cart.args.len().to_string())),
102                _ => None,
103            },
104            Obj::TupleDim(tuple_dim) => match &*tuple_dim.arg {
105                Obj::Tuple(tuple) => Some(Number::new(tuple.args.len().to_string())),
106                _ => None,
107            },
108            Obj::Count(count) => match &*count.set {
109                Obj::ListSet(list_set) => Some(Number::new(list_set.list.len().to_string())),
110                _ => None,
111            },
112            _ => None,
113        };
114
115        match result {
116            Some(number) => Some(number),
117            None => None,
118        }
119    }
120
121    pub fn two_objs_can_be_calculated_and_equal_by_calculation(&self, other: &Obj) -> bool {
122        match (
123            self.evaluate_to_normalized_decimal_number(),
124            other.evaluate_to_normalized_decimal_number(),
125        ) {
126            (Some(left_number), Some(right_number)) => {
127                return left_number.normalized_value == right_number.normalized_value;
128            }
129            _ => return false,
130        }
131    }
132}
133
/// Reports whether a normalized decimal string is non-negative, i.e. its first
/// non-whitespace character is not `-`. A `-0` already normalized to `0` therefore
/// counts as non-negative, and so does the empty string.
fn normalized_decimal_str_is_non_negative(s: &str) -> bool {
    let leading_char = s.trim().chars().next();
    leading_char != Some('-')
}
138
/// Splits a decimal string into `(is_negative, magnitude)`.
///
/// Surrounding whitespace is trimmed. At most one leading `-` is consumed, and any
/// whitespace after it is trimmed as well. The magnitude is returned as an owned
/// string and is not otherwise normalized.
fn split_sign_and_magnitude(number_string: &str) -> (bool, String) {
    let trimmed = number_string.trim();
    match trimmed.strip_prefix('-') {
        Some(unsigned_part) => (true, unsigned_part.trim().to_string()),
        None => (false, trimmed.to_string()),
    }
}
147
148pub fn mul_signed_decimal_str(left_number_string: &str, right_number_string: &str) -> String {
149    let (left_is_negative, left_magnitude_number_string) =
150        split_sign_and_magnitude(left_number_string);
151    let (right_is_negative, right_magnitude_number_string) =
152        split_sign_and_magnitude(right_number_string);
153    let multiplied_magnitude_number_string = mul_decimal_str_and_normalize(
154        &left_magnitude_number_string,
155        &right_magnitude_number_string,
156    );
157    let multiplied_magnitude_is_zero = multiplied_magnitude_number_string == "0";
158    let multiplied_result_is_negative = left_is_negative ^ right_is_negative;
159    if multiplied_result_is_negative && !multiplied_magnitude_is_zero {
160        normalize_decimal_number_string(&format!("-{}", multiplied_magnitude_number_string))
161    } else {
162        normalize_decimal_number_string(&multiplied_magnitude_number_string)
163    }
164}
165
166/// 带符号加法 a + b(系数合并用;`add_decimal_str_and_normalize` 仅适用于非负操作数)
167pub fn add_signed_decimal_str(a: &str, b: &str) -> String {
168    let (a_neg, a_mag) = split_sign_and_magnitude(a);
169    let (b_neg, b_mag) = split_sign_and_magnitude(b);
170    match (a_neg, b_neg) {
171        (false, false) => add_decimal_str_and_normalize(&a_mag, &b_mag),
172        (true, true) => {
173            let sum_mag = add_decimal_str_and_normalize(&a_mag, &b_mag);
174            if sum_mag == "0" {
175                "0".to_string()
176            } else {
177                normalize_decimal_number_string(&format!("-{}", sum_mag))
178            }
179        }
180        (false, true) => sub_decimal_str_and_normalize(&a_mag, &b_mag),
181        (true, false) => sub_decimal_str_and_normalize(&b_mag, &a_mag),
182    }
183}
184
185/// 带符号减法 a - b
186pub fn sub_signed_decimal_str(a: &str, b: &str) -> String {
187    add_signed_decimal_str(a, &mul_signed_decimal_str(b, "-1"))
188}
189
190impl Obj {
191    pub fn replace_with_numeric_result_if_can_be_calculated(&self) -> (Obj, bool) {
192        if let Some(calculated_number) = self.evaluate_to_normalized_decimal_number() {
193            (Obj::Number(calculated_number), true)
194        } else {
195            (self.clone(), false)
196        }
197    }
198}
199
200/// 竖式加法:两个表示非负数的数字串(可含小数点),返回和的字符串
201pub fn add_decimal_str_and_normalize(a: &str, b: &str) -> String {
202    let (mut int_a, mut frac_a) = parse_decimal_parts(a);
203    let (mut int_b, mut frac_b) = parse_decimal_parts(b);
204    let frac_len = frac_a.len().max(frac_b.len());
205    frac_a.resize(frac_len, 0);
206    frac_b.resize(frac_len, 0);
207    let int_len = int_a.len().max(int_b.len());
208    int_a.reverse();
209    int_b.reverse();
210    int_a.resize(int_len, 0);
211    int_b.resize(int_len, 0);
212
213    let mut out_frac = vec![0u8; frac_len];
214    let mut carry = 0u8;
215    for i in (0..frac_len).rev() {
216        let sum = frac_a[i] + frac_b[i] + carry;
217        out_frac[i] = sum % 10;
218        carry = sum / 10;
219    }
220    let mut out_int = Vec::with_capacity(int_len + 1);
221    for i in 0..int_len {
222        let sum = int_a[i] + int_b[i] + carry;
223        out_int.push(sum % 10);
224        carry = sum / 10;
225    }
226    if carry > 0 {
227        out_int.push(carry);
228    }
229    out_int.reverse();
230
231    let int_str: String = out_int.iter().map(|&d| (b'0' + d) as char).collect();
232    let frac_str: String = out_frac.iter().map(|&d| (b'0' + d) as char).collect();
233    let result = if frac_str.is_empty() || out_frac.iter().all(|&d| d == 0) {
234        int_str
235    } else {
236        format!("{}.{}", int_str, frac_str.trim_end_matches('0'))
237    };
238    normalize_decimal_number_string(&result)
239}
240
/// Column subtraction `a - b` on decimal strings (an optional fractional part is allowed).
///
/// Both operands are read as magnitudes: `parse_decimal_parts` skips non-digit characters,
/// so a leading `-` on either input is ignored (use `sub_signed_decimal_str` for signed
/// operands). If `a >= b` the normalized non-negative difference is returned; otherwise the
/// result is `"-"` followed by `b - a`, normalized.
pub fn sub_decimal_str_and_normalize(a: &str, b: &str) -> String {
    let (int_a, frac_a) = parse_decimal_parts(a);
    let (int_b, frac_b) = parse_decimal_parts(b);
    // Right-pad both fractional parts with zeros to a common width.
    let frac_len = frac_a.len().max(frac_b.len());
    let mut fa: Vec<u8> = frac_a.iter().cloned().collect();
    let mut fb: Vec<u8> = frac_b.iter().cloned().collect();
    fa.resize(frac_len, 0);
    fb.resize(frac_len, 0);
    // Store integer digits least-significant first, then zero-pad the high end to a common
    // width so both numbers line up column by column (the layout compare_decimal_parts reads).
    let int_len = int_a.len().max(int_b.len());
    let mut ia: Vec<u8> = int_a.iter().cloned().collect();
    let mut ib: Vec<u8> = int_b.iter().cloned().collect();
    ia.reverse();
    ib.reverse();
    ia.resize(int_len, 0);
    ib.resize(int_len, 0);

    let cmp = compare_decimal_parts(&ia, &fa, &ib, &fb);
    let (top_int, top_frac, bot_int, bot_frac) = if cmp >= 0 {
        (ia, fa, ib, fb)
    } else {
        // |a| < |b|: compute b - a (which takes the branch above) and negate the result.
        let inner = sub_decimal_str_and_normalize(b, a);
        return normalize_decimal_number_string(&format!("-{}", inner));
    };

    // Fractional columns, rightmost first; a borrow propagates into the units column.
    let mut out_frac = vec![0u8; frac_len];
    let mut borrow: i16 = 0;
    for i in (0..frac_len).rev() {
        let mut d = top_frac[i] as i16 - bot_frac[i] as i16 - borrow;
        borrow = 0;
        if d < 0 {
            d += 10;
            borrow = 1;
        }
        out_frac[i] = d as u8;
    }
    // Integer columns, units first (top >= bottom, so the final borrow is zero).
    let mut out_int = Vec::with_capacity(int_len);
    for i in 0..int_len {
        let mut d = top_int[i] as i16 - bot_int[i] as i16 - borrow;
        borrow = 0;
        if d < 0 {
            d += 10;
            borrow = 1;
        }
        out_int.push(d as u8);
    }
    // Back to most-significant-first; strip leading zeros but keep at least one digit.
    out_int.reverse();
    let start = out_int
        .iter()
        .position(|&d| d != 0)
        .unwrap_or(out_int.len().saturating_sub(1));
    let out_int = out_int[start..].to_vec();

    let int_str: String = if out_int.is_empty() {
        "0".to_string()
    } else {
        out_int.iter().map(|&d| (b'0' + d) as char).collect()
    };
    // Drop trailing fractional zeros; omit the `.` entirely when nothing remains.
    let frac_str: String = out_frac.iter().map(|&d| (b'0' + d) as char).collect();
    let frac_trim = frac_str.trim_end_matches('0');
    let result = if frac_trim.is_empty() {
        int_str
    } else {
        format!("{}.{}", int_str, frac_trim)
    };
    normalize_decimal_number_string(&result)
}
308
/// Compares two decimal numbers given as (integer digits, fractional digits).
///
/// Integer digits are expected least-significant first and zero-padded to a common length.
/// If the lengths still differ, the length difference is returned. Fractional digits are
/// most-significant first; missing positions count as 0. Returns a negative value, zero or
/// a positive value when `a` is less than, equal to or greater than `b`.
fn compare_decimal_parts(int_a: &[u8], frac_a: &[u8], int_b: &[u8], frac_b: &[u8]) -> i32 {
    if int_a.len() != int_b.len() {
        return int_a.len() as i32 - int_b.len() as i32;
    }

    // Most significant integer digit sits at the end of the slices.
    let first_int_difference = int_a
        .iter()
        .rev()
        .zip(int_b.iter().rev())
        .find(|(x, y)| x != y);
    if let Some((&da, &db)) = first_int_difference {
        return i32::from(da) - i32::from(db);
    }

    let frac_width = frac_a.len().max(frac_b.len());
    (0..frac_width)
        .map(|i| {
            (
                frac_a.get(i).copied().unwrap_or(0),
                frac_b.get(i).copied().unwrap_or(0),
            )
        })
        .find(|(x, y)| x != y)
        .map_or(0, |(x, y)| i32::from(x) - i32::from(y))
}
335
/// Schoolbook multiplication of two decimal strings; returns the normalized product.
///
/// Both operands are read as magnitudes (`parse_decimal_parts` skips non-digit characters,
/// so a leading `-` is ignored; see `mul_signed_decimal_str` for signed operands). Each
/// operand's integer and fractional digits are concatenated and multiplied as integers.
/// The decimal point is then reinserted `frac_places` digits from the right.
pub fn mul_decimal_str_and_normalize(a: &str, b: &str) -> String {
    let (int_a, frac_a) = parse_decimal_parts(a);
    let (int_b, frac_b) = parse_decimal_parts(b);
    // Number of fractional digits in the exact product.
    let frac_places = frac_a.len() + frac_b.len();
    // Operand digits as one integer, most significant first (decimal point dropped).
    let digits_a: Vec<u8> = int_a
        .iter()
        .cloned()
        .chain(frac_a.iter().cloned())
        .collect();
    let digits_b: Vec<u8> = int_b
        .iter()
        .cloned()
        .chain(frac_b.iter().cloned())
        .collect();
    let len_a = digits_a.len();
    let len_b = digits_b.len();
    // product[p] accumulates the coefficient of 10^p of the scaled integer product
    // (product[0] is its least significant digit); carries are resolved afterwards.
    let mut product = vec![0u32; len_a + len_b];
    for (i, &da) in digits_a.iter().enumerate() {
        for (j, &db) in digits_b.iter().enumerate() {
            let place = (len_a - 1 - i) + (len_b - 1 - j);
            product[place] += da as u32 * db as u32;
        }
    }
    // Propagate carries so every entry becomes a single digit.
    let mut carry = 0u32;
    for p in product.iter_mut() {
        *p += carry;
        carry = *p / 10;
        *p %= 10;
    }
    while carry > 0 {
        product.push(carry % 10);
        carry /= 10;
    }
    // Digits at index >= frac_places form the integer part (printed high to low).
    let total_len = product.len();
    let int_part: String = if frac_places >= total_len {
        "0".to_string()
    } else {
        product[frac_places..]
            .iter()
            .rev()
            .map(|&d| (b'0' + d as u8) as char)
            .collect::<String>()
            .trim_start_matches('0')
            .to_string()
    };
    // The lowest frac_places digits form the fractional part; trailing zeros are dropped.
    let frac_part: String = if frac_places == 0 {
        String::new()
    } else {
        product[..frac_places.min(total_len)]
            .iter()
            .rev()
            .map(|&d| (b'0' + d as u8) as char)
            .collect::<String>()
            .trim_end_matches('0')
            .to_string()
    };
    let int_str = if int_part.is_empty() { "0" } else { &int_part };
    let result = if frac_part.is_empty() {
        int_str.to_string()
    } else {
        format!("{}.{}", int_str, frac_part)
    };
    normalize_decimal_number_string(&result)
}
401
/// Long-division remainder `a mod b` on decimal strings; returns the remainder as a string.
///
/// Contract: `b` should be a non-zero integer. Only the integer parts of both operands are
/// used; fractional digits are discarded. Signs are ignored, because `parse_decimal_parts`
/// skips non-digit characters. Special cases visible below:
/// - an `a` whose integer part is zero yields `"0"`;
/// - a zero `b` also yields `"0"` (callers must reject that case themselves);
/// - when `a < b`, `a`'s integer digits are returned unchanged.
pub fn mod_decimal_str_and_normalize(a: &str, b: &str) -> String {
    let (int_a, _) = parse_decimal_parts(a);
    let (int_b, _) = parse_decimal_parts(b);
    let a_digits = trim_leading_zeros(&int_a);
    let b_digits = trim_leading_zeros(&int_b);
    if a_digits.is_empty() {
        return "0".to_string();
    }
    if b_digits.is_empty() || (b_digits.len() == 1 && b_digits[0] == 0) {
        return "0".to_string();
    }
    if compare_digits(&a_digits, &b_digits) == std::cmp::Ordering::Less {
        return digits_to_string(&a_digits);
    }
    // Running remainder: bring down one digit of `a` at a time (most significant first).
    let mut current: Vec<u8> = vec![];
    for &da in &a_digits {
        current.push(da);
        current = trim_leading_zeros(&current);
        // Subtract the largest multiple d * b (d from 9 down to 0) that fits in `current`.
        // d = 0 always fits, so exactly one subtraction happens per digit; the `d == 0`
        // break below is only a defensive loop bound.
        let mut d = 9u8;
        loop {
            let product = mul_digit(&b_digits, d);
            if compare_digits(&current, &product) != std::cmp::Ordering::Less {
                current = sub_digits(&current, &product);
                break;
            }
            if d == 0 {
                break;
            }
            d -= 1;
        }
    }
    normalize_decimal_number_string(&digits_to_string(&current))
}
436
437/// 仅支持非负整数指数:base^exp,exp 必须为整数(如 "3" 或 "0"),返回字符串
438pub fn pow_decimal_str_and_normalize(base: &str, exp: &str) -> String {
439    let (exp_int, exp_frac) = parse_decimal_parts(exp);
440    if exp_frac.iter().any(|&d| d != 0) {
441        unreachable!("幂运算仅支持整数指数");
442    }
443    let mut n = 0usize;
444    for &d in &exp_int {
445        n = n.saturating_mul(10).saturating_add(d as usize);
446    }
447    if n == 0 {
448        return "1".to_string();
449    }
450    let mut acc = "1".to_string();
451    let mut b = base.to_string();
452    let mut e = n;
453    while e > 0 {
454        if e % 2 == 1 {
455            acc = mul_decimal_str_and_normalize(&acc, &b);
456        }
457        b = mul_decimal_str_and_normalize(&b, &b);
458        e /= 2;
459    }
460    normalize_decimal_number_string(&acc)
461}
462
/// Returns a copy of the digit sequence (most significant first) without its leading
/// zeros. An all-zero or empty input yields an empty vector.
fn trim_leading_zeros(d: &[u8]) -> Vec<u8> {
    d.iter().copied().skip_while(|&digit| digit == 0).collect()
}
467
468/// 数字序列转字符串(高位在前)
469fn digits_to_string(d: &[u8]) -> String {
470    let t = trim_leading_zeros(d);
471    if t.is_empty() {
472        return "0".to_string();
473    }
474    t.iter().map(|&x| (b'0' + x) as char).collect()
475}
476
477/// 大数乘一位数:b * d,0 <= d <= 9,返回各位(高位在前)
478fn mul_digit(b: &[u8], d: u8) -> Vec<u8> {
479    if d == 0 {
480        return vec![0];
481    }
482    let mut b = b.to_vec();
483    b.reverse();
484    let mut carry = 0u16;
485    for x in b.iter_mut() {
486        let p = *x as u16 * d as u16 + carry;
487        *x = (p % 10) as u8;
488        carry = p / 10;
489    }
490    while carry > 0 {
491        b.push((carry % 10) as u8);
492        carry /= 10;
493    }
494    b.reverse();
495    trim_leading_zeros(&b)
496}
497
498/// 比较两个“整数”数字序列(高位在前)
499fn compare_digits(a: &[u8], b: &[u8]) -> std::cmp::Ordering {
500    let a = trim_leading_zeros(a);
501    let b = trim_leading_zeros(b);
502    if a.len() != b.len() {
503        return a.len().cmp(&b.len());
504    }
505    for (x, y) in a.iter().zip(b.iter()) {
506        if x != y {
507            return x.cmp(y);
508        }
509    }
510    std::cmp::Ordering::Equal
511}
512
513/// 大数减法:要求 a >= b,返回 a - b 的各位(高位在前)
514fn sub_digits(a: &[u8], b: &[u8]) -> Vec<u8> {
515    let mut a = a.to_vec();
516    let mut b = b.to_vec();
517    let len = a.len().max(b.len());
518    a.reverse();
519    b.reverse();
520    a.resize(len, 0);
521    b.resize(len, 0);
522    let mut borrow: i16 = 0;
523    let mut out = Vec::with_capacity(len);
524    for i in 0..len {
525        let mut d = a[i] as i16 - b[i] as i16 - borrow;
526        borrow = 0;
527        if d < 0 {
528            d += 10;
529            borrow = 1;
530        }
531        out.push(d as u8);
532    }
533    out.reverse();
534    trim_leading_zeros(&out)
535}
536
/// Normalizes a decimal string:
/// - repeated leading minus signs collapse by parity (`---1.1` -> `-1.1`);
/// - any representation of zero (`0.0`, `-0`, empty input) becomes `0`;
/// - leading integer zeros and trailing fractional zeros are stripped (`1.000` -> `1`,
///   `007.50` -> `7.5`), and a bare fraction gets a `0` integer part (`.5` -> `0.5`).
pub fn normalize_decimal_number_string(s: &str) -> String {
    let trimmed = s.trim();
    if trimmed.is_empty() {
        return "0".to_string();
    }

    // '-' is a single byte, so the byte-length difference equals the number of signs.
    let unsigned = trimmed.trim_start_matches('-');
    let minus_count = trimmed.len() - unsigned.len();
    let is_negative = minus_count % 2 == 1;
    let body = unsigned.trim();

    let magnitude = match body.split_once('.') {
        Some((int_str, frac_str)) => {
            let int_digits = match int_str.trim_start_matches('0') {
                "" | "." => "0",
                other => other,
            };
            match frac_str.trim_end_matches('0') {
                "" => int_digits.to_string(),
                frac_digits => format!("{}.{}", int_digits, frac_digits),
            }
        }
        None => match body.trim_start_matches('0') {
            "" => "0".to_string(),
            digits => digits.to_string(),
        },
    };

    let represents_zero = magnitude == "0"
        || magnitude
            .strip_prefix("0.")
            .map_or(false, |frac| frac.chars().all(|c| c == '0'));
    if represents_zero {
        "0".to_string()
    } else if is_negative {
        format!("-{}", magnitude)
    } else {
        magnitude
    }
}
576
/// Parses a decimal string into `(integer digits, fractional digits)`, both most
/// significant first. Accepts forms such as `"123.45"`, `"123"`, `".5"` and `"0.5"`.
///
/// Non-digit characters (including a sign) are skipped. The integer part is never empty:
/// it defaults to `[0]`. The fractional part may be empty.
fn parse_decimal_parts(s: &str) -> (Vec<u8>, Vec<u8>) {
    fn ascii_digits(text: &str) -> Vec<u8> {
        text.bytes()
            .filter(u8::is_ascii_digit)
            .map(|byte| byte - b'0')
            .collect()
    }

    let s = s.trim();
    let (int_str, frac_str) = s.split_once('.').unwrap_or((s, ""));
    let mut int_digits = match int_str {
        "" | "-" => Vec::new(),
        other => ascii_digits(other),
    };
    if int_digits.is_empty() {
        int_digits.push(0);
    }
    (int_digits, ascii_digits(frac_str))
}