Skip to main content

openjd_expr/functions/
comparison.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5//! Comparison, containment, and slice operator implementations.
6
7use crate::error::ExpressionError;
8use crate::function_library::EvalContext;
9use crate::value::ExprValue;
10
11type R = Result<ExprValue, ExpressionError>;
12type Ctx<'a> = &'a mut dyn EvalContext;
13
14// ── Equality ──
15
16pub fn eq_generic(_: Ctx, a: &[ExprValue]) -> R {
17    Ok(ExprValue::Bool(a[0].equals(&a[1])))
18}
19
20pub fn ne_generic(_: Ctx, a: &[ExprValue]) -> R {
21    Ok(ExprValue::Bool(!a[0].equals(&a[1])))
22}
23
24// ── Ordering ──
25
26fn do_compare(op_str: &str, a: &[ExprValue]) -> Result<std::cmp::Ordering, ExpressionError> {
27    a[0].compare(&a[1]).map_err(|_| {
28        ExpressionError::type_error(format!(
29            "Cannot use '{}' operator with {} and {}",
30            op_str,
31            a[0].expr_type(),
32            a[1].expr_type()
33        ))
34    })
35}
36
37pub fn lt_generic(_: Ctx, a: &[ExprValue]) -> R {
38    Ok(ExprValue::Bool(do_compare("<", a)?.is_lt()))
39}
40
41pub fn le_generic(_: Ctx, a: &[ExprValue]) -> R {
42    Ok(ExprValue::Bool(do_compare("<=", a)?.is_le()))
43}
44
45pub fn gt_generic(_: Ctx, a: &[ExprValue]) -> R {
46    Ok(ExprValue::Bool(do_compare(">", a)?.is_gt()))
47}
48
49pub fn ge_generic(_: Ctx, a: &[ExprValue]) -> R {
50    Ok(ExprValue::Bool(do_compare(">=", a)?.is_ge()))
51}
52
53// ── Containment ──
54
55pub fn contains_list(ctx: Ctx, a: &[ExprValue]) -> R {
56    let len = a[0].list_len().unwrap_or(0);
57    ctx.count_ops(len)?;
58    let found = a[0]
59        .list_iter()
60        .map(|mut iter| iter.any(|e| a[1].equals(&e)))
61        .unwrap_or(false);
62    Ok(ExprValue::Bool(found))
63}
64
65pub fn not_contains_list(ctx: Ctx, a: &[ExprValue]) -> R {
66    let len = a[0].list_len().unwrap_or(0);
67    ctx.count_ops(len)?;
68    let found = a[0]
69        .list_iter()
70        .map(|mut iter| iter.any(|e| a[1].equals(&e)))
71        .unwrap_or(false);
72    Ok(ExprValue::Bool(!found))
73}
74
75pub fn contains_string(ctx: Ctx, a: &[ExprValue]) -> R {
76    match (&a[0], &a[1]) {
77        (ExprValue::String(haystack), ExprValue::String(needle)) => {
78            ctx.count_string_ops(haystack.len() + needle.len())?;
79            Ok(ExprValue::Bool(haystack.contains(needle.as_str())))
80        }
81        _ => Err(ExpressionError::type_error("type error")),
82    }
83}
84
85pub fn not_contains_string(ctx: Ctx, a: &[ExprValue]) -> R {
86    match (&a[0], &a[1]) {
87        (ExprValue::String(haystack), ExprValue::String(needle)) => {
88            ctx.count_string_ops(haystack.len() + needle.len())?;
89            Ok(ExprValue::Bool(!haystack.contains(needle.as_str())))
90        }
91        _ => Err(ExpressionError::type_error("type error")),
92    }
93}
94
95pub fn contains_range(_: Ctx, a: &[ExprValue]) -> R {
96    match (&a[0], &a[1]) {
97        (ExprValue::RangeExpr(r), ExprValue::Int(i)) => Ok(ExprValue::Bool(r.contains(*i))),
98        _ => Err(ExpressionError::type_error("type error")),
99    }
100}
101
102pub fn not_contains_range(_: Ctx, a: &[ExprValue]) -> R {
103    match (&a[0], &a[1]) {
104        (ExprValue::RangeExpr(r), ExprValue::Int(i)) => Ok(ExprValue::Bool(!r.contains(*i))),
105        _ => Err(ExpressionError::type_error("type error")),
106    }
107}
108
109// ── Slicing (4-arg __getitem__) ──
110
111fn extract_int_or_none(v: &ExprValue) -> Option<i64> {
112    match v {
113        ExprValue::Int(i) => Some(*i),
114        _ => None, // Null → None
115    }
116}
117
/// Normalizes optional slice bounds for a sequence of length `len`, following Python's
/// slice semantics (`PySlice_AdjustIndices`).
///
/// Negative bounds count from the end (`i + len`). Every bound is then clamped into the
/// range that makes sense for the walk direction:
/// - `step > 0`: `[0, len]`, with defaults `start = 0`, `stop = len`;
/// - `step < 0`: `[-1, len - 1]`, with defaults `start = len - 1`, `stop = -1`
///   (`-1` means "walk all the way past index 0").
///
/// Clamping reverse-slice bounds to `-1` matters. Without it, a bound such as
/// `-10**12` survived as a huge negative number. The index walk would then step
/// through every one of those (never selected) negative positions.
///
/// Returns the normalized `(start, stop)` pair.
fn compute_slice_indices(len: i64, start: Option<i64>, stop: Option<i64>, step: i64) -> (i64, i64) {
    let (lower, upper) = if step > 0 { (0, len) } else { (-1, len - 1) };
    let normalize = |i: i64| {
        let i = if i < 0 { i.saturating_add(len) } else { i };
        // max/min rather than `clamp` so a degenerate `len` can never trigger a panic.
        i.max(lower).min(upper)
    };
    let s = start
        .map(normalize)
        .unwrap_or(if step > 0 { lower } else { upper });
    let e = stop
        .map(normalize)
        .unwrap_or(if step > 0 { upper } else { lower });
    (s, e)
}
135
/// Expands a normalized `(start, stop, step)` triple into the positions it selects.
///
/// Walks `start, start + step, …` while strictly before `stop` in the direction of
/// `step`. Only non-negative positions are kept and returned as `usize`. A zero
/// `step` selects nothing; callers reject it before reaching here, but it must not
/// loop forever.
///
/// Negative positions are never emitted, so they are never visited one by one:
/// - a forward walk jumps straight to the first non-negative position;
/// - a reverse walk stops as soon as it drops below zero.
///
/// Additions use `checked_add`, so a huge `step` ends the walk instead of
/// overflowing.
fn collect_indices(start: i64, stop: i64, step: i64) -> Vec<usize> {
    let mut indices = Vec::new();
    if step > 0 {
        // Smallest non-negative value of the progression start + k*step (k >= 0).
        let mut idx = if start < 0 { start.rem_euclid(step) } else { start };
        while idx < stop {
            // idx is non-negative here.
            indices.push(idx as usize);
            match idx.checked_add(step) {
                Some(next) => idx = next,
                None => break,
            }
        }
    } else if step < 0 {
        let mut idx = start;
        // Positions only decrease, so once idx < 0 nothing else can be selected.
        while idx > stop && idx >= 0 {
            indices.push(idx as usize);
            match idx.checked_add(step) {
                Some(next) => idx = next,
                None => break,
            }
        }
    }
    indices
}
156
157pub fn slice_list(ctx: Ctx, a: &[ExprValue]) -> R {
158    let step = extract_int_or_none(&a[3]).unwrap_or(1);
159    if step == 0 {
160        return Err(ExpressionError::new("Slice step cannot be zero"));
161    }
162    let elem_type = a[0].list_elem_type().unwrap();
163    let len = a[0].list_len().unwrap() as i64;
164    let start = extract_int_or_none(&a[1]);
165    let stop = extract_int_or_none(&a[2]);
166    let (s, e) = compute_slice_indices(len, start, stop, step);
167    let result: Vec<ExprValue> = collect_indices(s, e, step)
168        .into_iter()
169        .filter_map(|i| a[0].list_get(i as i64))
170        .collect();
171    ctx.count_ops(result.len())?;
172    ExprValue::make_list_checked(ctx, result, elem_type.clone())
173}
174
175pub fn slice_string(ctx: Ctx, a: &[ExprValue]) -> R {
176    let s = match &a[0] {
177        ExprValue::String(s) => s.as_str(),
178        _ => return Err(ExpressionError::type_error("type error")),
179    };
180    ctx.count_string_ops(s.len())?;
181    let step = extract_int_or_none(&a[3]).unwrap_or(1);
182    if step == 0 {
183        return Err(ExpressionError::new("Slice step cannot be zero"));
184    }
185    let chars: Vec<char> = s.chars().collect();
186    let len = chars.len() as i64;
187    let start = extract_int_or_none(&a[1]);
188    let stop = extract_int_or_none(&a[2]);
189    let (sv, ev) = compute_slice_indices(len, start, stop, step);
190    let result: String = collect_indices(sv, ev, step)
191        .into_iter()
192        .filter(|&i| i < chars.len())
193        .map(|i| chars[i])
194        .collect();
195    Ok(ExprValue::String(result))
196}
197
198pub fn slice_range(ctx: Ctx, a: &[ExprValue]) -> R {
199    let r = match &a[0] {
200        ExprValue::RangeExpr(r) => r,
201        _ => return Err(ExpressionError::type_error("type error")),
202    };
203    let step = extract_int_or_none(&a[3]).unwrap_or(1);
204    if step == 0 {
205        return Err(ExpressionError::new("Slice step cannot be zero"));
206    }
207    let len = r.len() as i64;
208    let start = extract_int_or_none(&a[1]);
209    let stop = extract_int_or_none(&a[2]);
210    let (s, e) = compute_slice_indices(len, start, stop, step);
211    if step > 0 {
212        // Forward slice → return RangeExpr
213        Ok(ExprValue::RangeExpr(r.slice(s, e, step)?))
214    } else {
215        // Reverse slice → return list (RangeExpr can't represent descending order)
216        let result: Vec<ExprValue> = collect_indices(s, e, step)
217            .into_iter()
218            .filter_map(|i| r.get(i as i64).map(ExprValue::Int))
219            .collect();
220        ctx.count_ops(result.len())?;
221        Ok(ExprValue::make_list_checked(
222            ctx,
223            result,
224            crate::types::ExprType::INT,
225        )?)
226    }
227}