// dsq_functions/builtin/array_push.rs

1use dsq_shared::value::{value_from_any_value, Value};
2use dsq_shared::Result;
3use polars::prelude::*;
4
5use crate::inventory;
6use crate::FunctionRegistration;
7
8pub fn builtin_array_push(args: &[Value]) -> Result<Value> {
9    if args.len() < 2 {
10        return Err(dsq_shared::error::operation_error(
11            "array_push() expects at least 2 arguments",
12        ));
13    }
14
15    match &args[0] {
16        Value::Array(arr) => {
17            let mut new_arr = arr.clone();
18            new_arr.extend_from_slice(&args[1..]);
19            Ok(Value::Array(new_arr))
20        }
21        Value::Series(series) => {
22            if matches!(series.dtype(), DataType::List(_)) {
23                let list_chunked = series.list().unwrap();
24                if series.len() == 1 {
25                    match list_chunked.get_as_series(0) {
26                        Some(list_series) => {
27                            let mut arr = Vec::new();
28                            for i in 0..list_series.len() {
29                                if let Ok(val) = list_series.get(i) {
30                                    let value = value_from_any_value(val).unwrap_or(Value::Null);
31                                    arr.push(value);
32                                }
33                            }
34                            arr.extend_from_slice(&args[1..]);
35                            Ok(Value::Array(arr))
36                        }
37                        _ => Ok(Value::Array(args[1..].to_vec())),
38                    }
39                } else {
40                    Err(dsq_shared::error::operation_error(format!(
41                        "array_push() on series with {} elements not supported",
42                        series.len()
43                    )))
44                }
45            } else {
46                Err(dsq_shared::error::operation_error(
47                    "array_push() requires an array or list series",
48                ))
49            }
50        }
51        Value::DataFrame(df) => {
52            let value_to_push = &args[1];
53            let any_value = match value_to_push {
54                Value::Int(i) => AnyValue::Int64(*i),
55                Value::Float(f) => AnyValue::Float64(*f),
56                Value::String(s) => AnyValue::String(s),
57                Value::Bool(b) => AnyValue::Boolean(*b),
58                Value::Null => AnyValue::Null,
59                _ => AnyValue::Null, // For complex types
60            };
61            let mut new_series_vec = Vec::new();
62            for col_name in df.get_column_names() {
63                if let Ok(series) = df.column(col_name) {
64                    if matches!(series.dtype(), DataType::List(_)) {
65                        let list_chunked = series.list().unwrap();
66                        let mut new_lists = Vec::new();
67                        for i in 0..df.height() {
68                            match list_chunked.get_as_series(i) {
69                                Some(list_series) => {
70                                    let mut values = vec![];
71                                    for j in 0..list_series.len() {
72                                        values.push(list_series.get(j).unwrap());
73                                    }
74                                    values.push(any_value.clone());
75                                    new_lists.push(Series::new("".into(), values));
76                                }
77                                _ => {
78                                    new_lists.push(Series::new("".into(), vec![any_value.clone()]));
79                                }
80                            }
81                        }
82                        let new_list_series = Series::new(col_name.clone(), new_lists);
83                        new_series_vec.push(new_list_series.into());
84                    } else {
85                        let mut s = series.clone();
86                        s.rename(col_name.clone());
87                        new_series_vec.push(s);
88                    }
89                }
90            }
91            match DataFrame::new(new_series_vec) {
92                Ok(new_df) => Ok(Value::DataFrame(new_df)),
93                Err(e) => Err(dsq_shared::error::operation_error(format!(
94                    "array_push() failed on DataFrame: {}",
95                    e
96                ))),
97            }
98        }
99        _ => Err(dsq_shared::error::operation_error(format!(
100            "array_push() first argument must be an array, list series, or DataFrame, got {}",
101            args[0].type_name()
102        ))),
103    }
104}
105
106inventory::submit! {
107    FunctionRegistration {
108        name: "array_push",
109        func: builtin_array_push,
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use dsq_shared::value::Value;

    #[test]
    fn test_builtin_array_push_array() {
        let arr = vec![Value::Int(1), Value::Int(2)];
        let result = builtin_array_push(&[Value::Array(arr), Value::Int(3)]).unwrap();
        match result {
            Value::Array(pushed) => {
                assert_eq!(pushed.len(), 3);
                assert_eq!(pushed[0], Value::Int(1));
                assert_eq!(pushed[1], Value::Int(2));
                assert_eq!(pushed[2], Value::Int(3));
            }
            _ => panic!("Expected Array"),
        }
    }

    #[test]
    fn test_builtin_array_push_dataframe() {
        let s1 = Series::new(PlSmallStr::from(""), &[1i64, 2i64]);
        let s2 = Series::new(PlSmallStr::from(""), &[3i64]);
        let list_series = Series::new(PlSmallStr::from("list_col"), &[s1, s2]).into();
        let df = DataFrame::new(vec![list_series]).unwrap();
        let result = builtin_array_push(&[Value::DataFrame(df), Value::Int(4)]).unwrap();
        match result {
            Value::DataFrame(new_df) => {
                let list_col = new_df.column("list_col").unwrap().list().unwrap();
                let first_list = list_col.get_as_series(0).unwrap();
                assert_eq!(first_list.get(0).unwrap(), AnyValue::Int64(1));
                assert_eq!(first_list.get(1).unwrap(), AnyValue::Int64(2));
                assert_eq!(first_list.get(2).unwrap(), AnyValue::Int64(4));
                let second_list = list_col.get_as_series(1).unwrap();
                assert_eq!(second_list.get(0).unwrap(), AnyValue::Int64(3));
                assert_eq!(second_list.get(1).unwrap(), AnyValue::Int64(4));
            }
            _ => panic!("Expected DataFrame"),
        }
    }

    #[test]
    fn test_builtin_array_push_dataframe_keeps_non_list_columns() {
        let s1 = Series::new(PlSmallStr::from(""), &[1i64]);
        let s2 = Series::new(PlSmallStr::from(""), &[2i64]);
        let list_col = Series::new(PlSmallStr::from("list_col"), &[s1, s2]).into();
        let id_col = Series::new(PlSmallStr::from("id"), &[10i64, 20i64]).into();
        let df = DataFrame::new(vec![list_col, id_col]).unwrap();
        let result = builtin_array_push(&[Value::DataFrame(df), Value::Int(9)]).unwrap();
        match result {
            Value::DataFrame(new_df) => {
                assert_eq!(new_df.width(), 2);
                assert_eq!(new_df.height(), 2);
                let ids = new_df.column("id").unwrap().i64().unwrap();
                assert_eq!(ids.get(0), Some(10));
                assert_eq!(ids.get(1), Some(20));
                let list_col = new_df.column("list_col").unwrap().list().unwrap();
                let first_list = list_col.get_as_series(0).unwrap();
                assert_eq!(first_list.get(1).unwrap(), AnyValue::Int64(9));
            }
            _ => panic!("Expected DataFrame"),
        }
    }

    #[test]
    fn test_builtin_array_push_single_row_list_series() {
        let inner = Series::new(PlSmallStr::from(""), &[1i64, 2i64]);
        let list_series = Series::new(PlSmallStr::from("lists"), &[inner]);
        let result =
            builtin_array_push(&[Value::Series(list_series), Value::Int(3)]).unwrap();
        match result {
            Value::Array(pushed) => {
                assert_eq!(pushed.len(), 3);
                assert_eq!(pushed[2], Value::Int(3));
            }
            _ => panic!("Expected Array"),
        }
    }

    #[test]
    fn test_builtin_array_push_series_requires_list_dtype() {
        let series = Series::new(PlSmallStr::from("nums"), &[1i64, 2i64]);
        let result = builtin_array_push(&[Value::Series(series), Value::Int(3)]);
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("requires an array or list series"));
    }

    #[test]
    fn test_builtin_array_push_series_multi_row_unsupported() {
        let s1 = Series::new(PlSmallStr::from(""), &[1i64]);
        let s2 = Series::new(PlSmallStr::from(""), &[2i64]);
        let list_series = Series::new(PlSmallStr::from("lists"), &[s1, s2]);
        let result = builtin_array_push(&[Value::Series(list_series), Value::Int(3)]);
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("on series with 2 elements not supported"));
    }

    #[test]
    fn test_builtin_array_push_multiple_values() {
        let arr = vec![Value::Int(1)];
        let result = builtin_array_push(&[
            Value::Array(arr),
            Value::Int(2),
            Value::String("three".to_string()),
        ])
        .unwrap();
        match result {
            Value::Array(pushed) => {
                assert_eq!(pushed.len(), 3);
                assert_eq!(pushed[0], Value::Int(1));
                assert_eq!(pushed[1], Value::Int(2));
                assert_eq!(pushed[2], Value::String("three".to_string()));
            }
            _ => panic!("Expected Array"),
        }
    }

    #[test]
    fn test_builtin_array_push_empty_array() {
        let arr = vec![];
        let result = builtin_array_push(&[Value::Array(arr), Value::Int(1)]).unwrap();
        match result {
            Value::Array(pushed) => {
                assert_eq!(pushed.len(), 1);
                assert_eq!(pushed[0], Value::Int(1));
            }
            _ => panic!("Expected Array"),
        }
    }

    #[test]
    fn test_builtin_array_push_error_too_few_args() {
        let result = builtin_array_push(&[Value::Array(vec![Value::Int(1)])]);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("expects at least 2 arguments"));
    }

    #[test]
    fn test_builtin_array_push_error_invalid_type() {
        let result = builtin_array_push(&[Value::Int(1), Value::Int(2)]);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("first argument must be an array"));
    }

    #[test]
    fn test_array_push_registered_via_inventory() {
        use crate::BuiltinRegistry;
        let registry = BuiltinRegistry::new();
        assert!(registry.functions.contains_key("array_push"));
    }
}