1use serde_json::Value;
2
3use crate::{Error, print_data_info, value_to_string};
4
5pub fn apply_simple_filter(data: Vec<Value>, filter: &str) -> Result<Vec<Value>, Error> {
6 if filter.starts_with("select(") && filter.ends_with(")") {
7 let condition = &filter[7..filter.len() - 1];
9
10 let (field_path, operator, value) = parse_condition(condition)?;
13
14 let filtered: Vec<Value> = data
17 .into_iter()
18 .filter(|item| evaluate_condition(item, &field_path, &operator, &value))
19 .collect();
20
21 Ok(filtered)
22 } else {
23 Err(Error::InvalidQuery(format!(
24 "Unsupported filter: {}",
25 filter
26 )))
27 }
28}
29
30pub fn apply_pipeline_operation(data: Vec<Value>, operation: &str) -> Result<Vec<Value>, Error> {
31 if operation.starts_with("select(") && operation.ends_with(")") {
32 apply_simple_filter(data, operation)
35 } else if operation == "count" {
36 if is_grouped_data(&data) {
39 apply_aggregation_to_groups(data, "count", "")
40 } else {
41 let count = data.len();
42 let count_value = Value::Number(serde_json::Number::from(count));
43 Ok(vec![count_value])
44 }
45 } else if operation.starts_with("select_fields(") && operation.ends_with(")") {
46 let fields_str = &operation[14..operation.len() - 1]; let field_list: Vec<String> = fields_str
49 .split(',')
50 .map(|s| s.trim().to_string())
51 .collect();
52
53 apply_field_selection(data, field_list)
54 } else if operation == "info" {
55 print_data_info(&data);
58 Ok(vec![]) } else if operation.starts_with("sum(") && operation.ends_with(")") {
60 let field = &operation[4..operation.len() - 1];
63 let field_name = field.trim_start_matches('.');
64
65 if is_grouped_data(&data) {
66 apply_aggregation_to_groups(data, "sum", field_name)
67 } else {
68 let sum: f64 = data
69 .iter()
70 .filter_map(|item| item.get(field_name))
71 .filter_map(|val| val.as_f64())
72 .sum();
73
74 let round_sum = if sum.fract() == 0.0 {
75 sum
76 } else {
77 (sum * 10.0).round() / 10.0
78 };
79 let sum_value = Value::Number(serde_json::Number::from_f64(round_sum).unwrap());
80 Ok(vec![sum_value])
81 }
82 } else if operation.starts_with("avg(") && operation.ends_with(")") {
83 let field = &operation[4..operation.len() - 1];
86 let field_name = field.trim_start_matches('.');
87
88 if is_grouped_data(&data) {
89 apply_aggregation_to_groups(data, "avg", field_name)
90 } else {
91 let values: Vec<f64> = data
92 .iter()
93 .filter_map(|item| item.get(field_name))
94 .filter_map(|val| val.as_f64())
95 .collect();
96
97 if values.is_empty() {
98 Ok(vec![Value::Null])
99 } else {
100 let avg = values.iter().sum::<f64>() / values.len() as f64;
101 let round_avg = (avg * 10.0).round() / 10.0;
102 let avg_value = Value::Number(serde_json::Number::from_f64(round_avg).unwrap());
103 Ok(vec![avg_value])
104 }
105 }
106 } else if operation.starts_with("min(") && operation.ends_with(")") {
107 let field = &operation[4..operation.len() - 1];
110 let field_name = field.trim_start_matches('.');
111
112 if is_grouped_data(&data) {
113 apply_aggregation_to_groups(data, "min", field_name)
114 } else {
115 let min_val = data
116 .iter()
117 .filter_map(|item| item.get(field_name))
118 .filter_map(|val| val.as_f64())
119 .fold(f64::INFINITY, f64::min);
120
121 if min_val == f64::INFINITY {
122 Ok(vec![Value::Null])
123 } else {
124 let min_value = Value::Number(serde_json::Number::from_f64(min_val).unwrap());
125 Ok(vec![min_value])
126 }
127 }
128 } else if operation.starts_with("max(") && operation.ends_with(")") {
129 let field = &operation[4..operation.len() - 1];
132 let field_name = field.trim_start_matches('.');
133
134 if is_grouped_data(&data) {
135 apply_aggregation_to_groups(data, "max", field_name)
136 } else {
137 let max_val = data
138 .iter()
139 .filter_map(|item| item.get(field_name))
140 .filter_map(|val| val.as_f64())
141 .fold(f64::NEG_INFINITY, f64::max);
142
143 if max_val == f64::NEG_INFINITY {
144 Ok(vec![Value::Null])
145 } else {
146 let max_value = Value::Number(serde_json::Number::from_f64(max_val).unwrap());
147 Ok(vec![max_value])
148 }
149 }
150 } else if operation.starts_with("group_by(") && operation.ends_with(")") {
151 let field = &operation[9..operation.len() - 1];
154 let field_name = field.trim_start_matches('.');
155
156 let grouped = group_data_by_field(data, field_name)?;
157 Ok(grouped)
158 } else {
159 Err(Error::InvalidQuery(format!(
160 "Unsupported operation: {}",
161 operation
162 )))
163 }
164}
165
166fn apply_field_selection(data: Vec<Value>, field_list: Vec<String>) -> Result<Vec<Value>, Error> {
167 let mut results = Vec::new();
168
169 for item in data {
170 if let Value::Object(obj) = item {
171 let mut selected_obj = serde_json::Map::new();
172
173 for field_name in &field_list {
175 if let Some(value) = obj.get(field_name) {
176 selected_obj.insert(field_name.clone(), value.clone());
177 }
178 }
179
180 results.push(Value::Object(selected_obj));
181 } else {
182 return Err(Error::InvalidQuery(
184 "select_fields can only be applied to objects".into(),
185 ));
186 }
187 }
188
189 Ok(results)
190}
191
192fn group_data_by_field(data: Vec<Value>, field_name: &str) -> Result<Vec<Value>, Error> {
193 use std::collections::HashMap;
194
195 let mut groups: HashMap<String, Vec<Value>> = HashMap::new();
196
197 for item in data {
200 if let Some(field_value) = item.get(field_name) {
201 let key = value_to_string(field_value);
202 groups.entry(key).or_insert_with(Vec::new).push(item);
203 }
204 }
205
206 let result: Vec<Value> = groups
209 .into_iter()
210 .map(|(group_name, group_items)| {
211 let mut group_obj = serde_json::Map::new();
212 group_obj.insert("group".to_string(), Value::String(group_name));
213 group_obj.insert("items".to_string(), Value::Array(group_items));
214 Value::Object(group_obj)
215 })
216 .collect();
217
218 Ok(result)
219}
220
221fn parse_condition(condition: &str) -> Result<(String, String, String), Error> {
222 let condition = condition.trim();
225
226 if let Some(pos) = condition.find(" > ") {
229 let field = condition[..pos].trim().to_string();
230 let value = condition[pos + 3..].trim().to_string();
231 return Ok((field, ">".to_string(), value));
232 }
233
234 if let Some(pos) = condition.find(" < ") {
235 let field = condition[..pos].trim().to_string();
236 let value = condition[pos + 3..].trim().to_string();
237 return Ok((field, "<".to_string(), value));
238 }
239
240 if let Some(pos) = condition.find(" == ") {
241 let field = condition[..pos].trim().to_string();
242 let value = condition[pos + 4..].trim().to_string();
243 return Ok((field, "==".to_string(), value));
244 }
245
246 if let Some(pos) = condition.find(" != ") {
247 let field = condition[..pos].trim().to_string();
248 let value = condition[pos + 4..].trim().to_string();
249 return Ok((field, "!=".to_string(), value));
250 }
251
252 Err(Error::InvalidQuery("Invalid condition format".into()))
253}
254
255fn evaluate_condition(item: &Value, field_path: &str, operator: &str, value: &str) -> bool {
256 let field_name = if field_path.starts_with('.') {
259 &field_path[1..]
260 } else {
261 field_path
262 };
263
264 let field_value = match item.get(field_name) {
265 Some(val) => val,
266 None => return false, };
268
269 match operator {
270 ">" => compare_greater(field_value, value),
271 "<" => compare_less(field_value, value),
272 "==" => compare_equal(field_value, value),
273 "!=" => !compare_equal(field_value, value),
274 _ => false,
275 }
276}
277
278fn compare_greater(field_value: &Value, target: &str) -> bool {
279 match field_value {
280 Value::Number(n) => {
281 if let Ok(target_num) = target.parse::<f64>() {
282 n.as_f64().unwrap_or(0.0) > target_num
283 } else {
284 false
285 }
286 }
287 _ => false,
288 }
289}
290
291fn compare_less(field_value: &Value, target: &str) -> bool {
292 match field_value {
293 Value::Number(n) => {
294 if let Ok(target_num) = target.parse::<f64>() {
295 n.as_f64().unwrap_or(0.0) < target_num
296 } else {
297 false
298 }
299 }
300 _ => false,
301 }
302}
303
304fn compare_equal(field_value: &Value, target: &str) -> bool {
305 match field_value {
306 Value::String(s) => {
307 let target_clean = target.trim_matches('"');
310 s == target_clean
311 }
312 Value::Number(n) => {
313 if let Ok(target_num) = target.parse::<f64>() {
314 n.as_f64().unwrap_or(0.0) == target_num
315 } else {
316 false
317 }
318 }
319 Value::Bool(b) => match target {
320 "true" => *b,
321 "false" => !*b,
322 _ => false,
323 },
324 _ => false,
325 }
326}
327
328fn is_grouped_data(data: &[Value]) -> bool {
329 data.iter().all(|item| {
330 if let Value::Object(obj) = item {
331 obj.contains_key("group") && obj.contains_key("items")
332 } else {
333 false
334 }
335 })
336}
337
338fn apply_aggregation_to_groups(
339 data: Vec<Value>,
340 operation: &str,
341 field_name: &str,
342) -> Result<Vec<Value>, Error> {
343 let mut results = Vec::new();
344
345 for group_data in data {
346 if let Value::Object(group_obj) = group_data {
347 let group_name = group_obj.get("group").unwrap();
348 let items = group_obj.get("items").and_then(|v| v.as_array()).unwrap();
349
350 let aggregated_value = match operation {
353 "avg" => calculate_avg(items, field_name)?,
354 "sum" => calculate_sum(items, field_name)?,
355 "count" => Value::Number(serde_json::Number::from(items.len())),
356 "min" => calculate_min(items, field_name)?,
357 "max" => calculate_max(items, field_name)?,
358 _ => Value::Null,
359 };
360
361 let mut result_obj = serde_json::Map::new();
364 result_obj.insert("group".to_string(), group_name.clone());
365 result_obj.insert(operation.to_string(), aggregated_value);
366 results.push(Value::Object(result_obj));
367 }
368 }
369
370 Ok(results)
371}
372
373fn calculate_avg(items: &[Value], field_name: &str) -> Result<Value, Error> {
374 let values: Vec<f64> = items
375 .iter()
376 .filter_map(|item| item.get(field_name))
377 .filter_map(|val| val.as_f64())
378 .collect();
379
380 if values.is_empty() {
381 Ok(Value::Null)
382 } else {
383 let avg = values.iter().sum::<f64>() / values.len() as f64;
384 let rounded_avg = (avg * 10.0).round() / 10.0;
385 Ok(Value::Number(
386 serde_json::Number::from_f64(rounded_avg).unwrap(),
387 ))
388 }
389}
390
391fn calculate_sum(items: &[Value], field_name: &str) -> Result<Value, Error> {
392 let sum: f64 = items
393 .iter()
394 .filter_map(|item| item.get(field_name))
395 .filter_map(|val| val.as_f64())
396 .sum();
397
398 let rounded_sum = if sum.fract() == 0.0 {
399 sum
400 } else {
401 (sum * 10.0).round() / 10.0
402 };
403
404 Ok(Value::Number(
405 serde_json::Number::from_f64(rounded_sum).unwrap(),
406 ))
407}
408
409fn calculate_min(items: &[Value], field_name: &str) -> Result<Value, Error> {
410 let min_val = items
411 .iter()
412 .filter_map(|item| item.get(field_name))
413 .filter_map(|val| val.as_f64())
414 .fold(f64::INFINITY, f64::min);
415
416 if min_val == f64::INFINITY {
417 Ok(Value::Null)
418 } else {
419 Ok(Value::Number(
420 serde_json::Number::from_f64(min_val).unwrap(),
421 ))
422 }
423}
424
425fn calculate_max(items: &[Value], field_name: &str) -> Result<Value, Error> {
426 let max_val = items
427 .iter()
428 .filter_map(|item| item.get(field_name))
429 .filter_map(|val| val.as_f64())
430 .fold(f64::NEG_INFINITY, f64::max);
431
432 if max_val == f64::NEG_INFINITY {
433 Ok(Value::Null)
434 } else {
435 Ok(Value::Number(
436 serde_json::Number::from_f64(max_val).unwrap(),
437 ))
438 }
439}