Skip to main content

radiate_expr/expression/
projection.rs

1use crate::{AnyValue, Field, SelectExpr, expression::select::PathSegment};
2use std::collections::HashMap;
3
4pub trait ExprProjection {
5    fn project(&self, path: &SelectExpr) -> Option<AnyValue<'static>>;
6}
7
8impl<T> ExprProjection for Vec<T>
9where
10    T: Clone + Into<AnyValue<'static>>,
11{
12    fn project(&self, selector: &SelectExpr) -> Option<AnyValue<'static>> {
13        match selector {
14            SelectExpr::Path(path) => {
15                let mut current = AnyValue::Vector(self.iter().cloned().map(Into::into).collect());
16
17                for segment in path {
18                    current = match segment {
19                        PathSegment::Index(i) => current.get_index(*i)?,
20                        _ => return None,
21                    };
22                }
23
24                Some(current)
25            }
26            SelectExpr::Nth(n) => self.get(*n).cloned().map(Into::into),
27            SelectExpr::Element => Some(AnyValue::Vector(
28                self.iter().cloned().map(Into::into).collect(),
29            )),
30            _ => None,
31        }
32    }
33}
34
35impl<T> ExprProjection for HashMap<String, T>
36where
37    T: Clone + Into<AnyValue<'static>>,
38{
39    fn project(&self, selector: &SelectExpr) -> Option<AnyValue<'static>> {
40        match selector {
41            SelectExpr::Path(path) => {
42                let mut current = AnyValue::Struct(
43                    self.iter()
44                        .map(|(k, v)| {
45                            let cloned_value = v.clone().into();
46                            (Field::new(k.into(), cloned_value.dtype()), cloned_value)
47                        })
48                        .collect(),
49                );
50
51                for segment in path {
52                    current = match segment {
53                        PathSegment::Key(key) => current.get_key(key)?,
54                        PathSegment::Index(i) => current.get_index(*i)?,
55                        PathSegment::StructField(field) => current.get_field(field)?,
56                    };
57                }
58
59                Some(current)
60            }
61            SelectExpr::Field(key, field) => {
62                let value = self.get(&key.clone().into_string()?)?.clone().into();
63                value.get_field(field)
64            }
65            _ => None,
66        }
67    }
68}
69
70impl<'a> ExprProjection for AnyValue<'a> {
71    fn project(&self, selector: &SelectExpr) -> Option<AnyValue<'static>> {
72        match selector {
73            SelectExpr::Path(path) => {
74                let mut current = self.clone();
75
76                for segment in path {
77                    current = match segment {
78                        PathSegment::Key(key) => current.get_key(key)?,
79                        PathSegment::Index(i) => current.get_index(*i)?,
80                        PathSegment::StructField(field) => current.get_field(field)?,
81                    };
82                }
83
84                Some(current.into_static())
85            }
86            SelectExpr::Field(key, field) => {
87                let value = self.get_key(key)?.into_static();
88                value.get_field(field)
89            }
90            SelectExpr::Nth(n) => self.get_index(*n).map(|v| v.into_static()),
91            SelectExpr::Element => Some(self.clone().into_static()),
92        }
93    }
94}
95
96impl ExprProjection for f32 {
97    fn project(&self, _: &SelectExpr) -> Option<AnyValue<'static>> {
98        Some(AnyValue::Float32(*self))
99    }
100}
101
102impl ExprProjection for i32 {
103    fn project(&self, _: &SelectExpr) -> Option<AnyValue<'static>> {
104        Some(AnyValue::Int32(*self))
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AnyValue, Expr, ExprQuery, Field, expr, expression::select::PathBuilder};
    use std::collections::HashMap;

    /// Unwraps a projected value as an `f32`, panicking on a type mismatch.
    fn f32_of(value: AnyValue<'_>) -> f32 {
        value.extract::<f32>().unwrap()
    }

    /// Unwraps a projected value as an `i32`, panicking on a type mismatch.
    fn i32_of(value: AnyValue<'_>) -> i32 {
        value.extract::<i32>().unwrap()
    }

    /// Builds a `PathSegment::Key` from a string literal.
    fn key(name: &'static str) -> PathSegment {
        PathSegment::Key(AnyValue::from(name).into_static())
    }

    /// Packs a string-keyed map of values into an `AnyValue::Struct`.
    fn struct_of(fields: &HashMap<String, AnyValue<'static>>) -> AnyValue<'static> {
        AnyValue::Struct(
            fields
                .iter()
                .map(|(name, value)| (Field::new(name.clone().into(), value.dtype()), value.clone()))
                .collect(),
        )
    }

    #[test]
    fn vec_projection_supports_nth() {
        let numbers = vec![10i32, 20, 30];
        let mut selector = SelectExpr::Nth(1);

        let picked = selector.dispatch(&numbers).unwrap();
        assert_eq!(i32_of(picked), 20);
    }

    #[test]
    fn vec_projection_supports_path_index() {
        let numbers = vec![10i32, 20, 30];
        let mut selector = SelectExpr::Path(vec![PathSegment::Index(2)]);

        let picked = selector.dispatch(&numbers).unwrap();
        assert_eq!(i32_of(picked), 30);
    }

    #[test]
    fn vec_projection_invalid_path_returns_null() {
        let numbers = vec![10i32, 20, 30];
        let mut selector = SelectExpr::Path(vec![key("nope")]);

        let outcome = selector.dispatch(&numbers);
        assert!(matches!(outcome.unwrap_or(AnyValue::Null), AnyValue::Null));
    }

    #[test]
    fn hashmap_projection_supports_field() {
        let stats = HashMap::from([("mean".to_string(), 12.5f32)]);
        let mut selector: Expr = expr::path("mean").into();

        let picked = selector.dispatch(&stats).unwrap();
        assert_eq!(f32_of(picked), 12.5);
    }

    #[test]
    fn hashmap_projection_supports_path_key() {
        let metrics = HashMap::from([("accuracy".to_string(), 0.91f32)]);
        let mut selector = SelectExpr::Path(vec![key("accuracy")]);

        let picked = selector.dispatch(&metrics).unwrap();
        assert_eq!(f32_of(picked), 0.91);
    }

    #[test]
    fn hashmap_invalid_key_returns_null() {
        let metrics = HashMap::from([("accuracy".to_string(), 0.91f32)]);
        let mut selector = SelectExpr::Path(vec![key("missing")]);

        let outcome = selector.dispatch(&metrics);
        assert!(matches!(outcome.unwrap_or(AnyValue::Null), AnyValue::Null));
    }

    #[test]
    fn nested_hashmap_vec_hashmap_path_works() {
        let alice = HashMap::from([
            ("name".to_string(), AnyValue::from("alice").into_static()),
            ("score".to_string(), AnyValue::from(10.0f32).into_static()),
        ]);
        let bob = HashMap::from([
            ("name".to_string(), AnyValue::from("bob").into_static()),
            ("score".to_string(), AnyValue::from(25.0f32).into_static()),
        ]);

        let root = HashMap::from([(
            "users".to_string(),
            AnyValue::Vector(vec![struct_of(&alice), struct_of(&bob)]),
        )]);

        let mut selector =
            SelectExpr::Path(vec![key("users"), PathSegment::Index(1), key("name")]);

        match selector.dispatch(&root).unwrap() {
            AnyValue::Str(s) => assert_eq!(s, "bob"),
            AnyValue::StrOwned(s) => assert_eq!(s, "bob"),
            other => panic!("expected string, got {other:?}"),
        }
    }

    #[test]
    fn path_builder_builds_selector_expr() {
        let built: Expr = PathBuilder::default()
            .key("users")
            .index(0)
            .key("name")
            .into();

        let segments = match built {
            Expr::Selector(SelectExpr::Path(path)) => path,
            other => panic!("expected Expr::Selector(Path), got {other:?}"),
        };

        assert_eq!(segments.len(), 3);
        assert!(matches!(&segments[0], PathSegment::Key(_)));
        assert!(matches!(&segments[1], PathSegment::Index(0)));
        assert!(matches!(&segments[2], PathSegment::Key(_)));
    }

    #[test]
    fn nested_numeric_path_can_be_compared_through_expr_tree() {
        let root = HashMap::from([(
            "metric".to_string(),
            HashMap::from([("value".to_string(), 7.0f32)]),
        )]);

        let lookup: Expr = PathBuilder::default().key("metric").key("value").into();
        let mut comparison = lookup.gt(5.0);

        match comparison.dispatch(&root).unwrap() {
            AnyValue::Bool(flag) => assert!(flag),
            other => panic!("expected bool result, got {other:?}"),
        }
    }
}