1use serde_json::Value;
2
3use crate::{apply_stats_operation, Error, print_data_info, value_to_string, string_ops};
4
5pub fn apply_simple_filter(data: Vec<Value>, filter: &str) -> Result<Vec<Value>, Error> {
6 if filter.starts_with("select(") && filter.ends_with(")") {
7 let condition = &filter[7..filter.len() - 1];
9
10 if condition.contains(" | ") {
12 apply_filter_with_string_operations(data, condition)
13 } else {
14 apply_existing_simple_filter(data, condition)
15 }
16 } else {
17 Err(Error::InvalidQuery(format!(
18 "Unsupported filter: {}",
19 filter
20 )))
21 }
22}
23
24fn apply_filter_with_string_operations(data: Vec<Value>, condition: &str) -> Result<Vec<Value>, Error> {
26 let parts: Vec<&str> = condition.split(" | ").map(|s| s.trim()).collect();
27
28 if parts.len() < 2 {
29 return Err(Error::InvalidQuery("Invalid filter condition".to_string()));
30 }
31
32 let field_access = parts[0];
33 let string_operations: Vec<&str> = parts[1..].to_vec();
34
35 let last_operation = string_operations.last().ok_or_else(|| {
37 Error::InvalidQuery("Missing comparison operation".to_string())
38 })?;
39
40 if !is_comparison_operation(last_operation) {
41 return Err(Error::InvalidQuery("Last operation must be a comparison".to_string()));
42 }
43
44 let mut results = Vec::new();
45
46 for item in data {
47 let field_value = extract_field_value(&item, field_access)?;
49
50 let final_value = string_ops::apply_string_pipeline(&field_value, &string_operations)?;
52
53 if let Value::Bool(true) = final_value {
55 results.push(item);
56 }
57 }
58
59 Ok(results)
60}
61
/// Reports whether `operation` is one of the comparison steps allowed at the
/// end of a filter pipeline: `contains(`, `starts_with(`, `ends_with(`, or an
/// `==` / `!=` comparison (bare, or followed by a space and an operand).
fn is_comparison_operation(operation: &str) -> bool {
    const PREFIXES: [&str; 5] = ["contains(", "starts_with(", "ends_with(", "== ", "!= "];

    match operation {
        "==" | "!=" => true,
        other => PREFIXES.iter().any(|&prefix| other.starts_with(prefix)),
    }
}
72
73fn apply_existing_simple_filter(data: Vec<Value>, condition: &str) -> Result<Vec<Value>, Error> {
75 let (field_path, operator, value) = parse_condition(condition)?;
77
78 let filtered: Vec<Value> = data
80 .into_iter()
81 .filter(|item| evaluate_condition(item, &field_path, &operator, &value))
82 .collect();
83
84 Ok(filtered)
85}
86
87pub fn apply_pipeline_operation(data: Vec<Value>, operation: &str) -> Result<Vec<Value>, Error> {
88 let trimmed_op = operation.trim();
89
90 if trimmed_op.starts_with("select(") && trimmed_op.ends_with(")") {
91 apply_simple_filter(data, trimmed_op)
93 } else if trimmed_op == "count" {
94 if is_grouped_data(&data) {
96 apply_aggregation_to_groups(data, "count", "")
97 } else {
98 let count = data.len();
99 let count_value = Value::Number(serde_json::Number::from(count));
100 Ok(vec![count_value])
101 }
102 } else if trimmed_op.starts_with("map(") && trimmed_op.ends_with(")") {
103 apply_map_operation(data, trimmed_op)
104 } else if trimmed_op.starts_with("select_fields(") && trimmed_op.ends_with(")") {
105 let fields_str = &trimmed_op[14..trimmed_op.len() - 1]; let field_list: Vec<String> = fields_str
108 .split(',')
109 .map(|s| s.trim().to_string())
110 .collect();
111
112 apply_field_selection(data, field_list)
113 } else if trimmed_op == "info" {
114 print_data_info(&data);
116 Ok(vec![]) } else if trimmed_op.starts_with("sum(") && trimmed_op.ends_with(")") {
118 let field = &trimmed_op[4..trimmed_op.len() - 1];
120 let field_name = field.trim_start_matches('.');
121
122 if is_grouped_data(&data) {
123 apply_aggregation_to_groups(data, "sum", field_name)
124 } else {
125 let sum: f64 = data
126 .iter()
127 .filter_map(|item| item.get(field_name))
128 .filter_map(|val| val.as_f64())
129 .sum();
130
131 let round_sum = if sum.fract() == 0.0 {
132 sum
133 } else {
134 (sum * 10.0).round() / 10.0
135 };
136 let sum_value = Value::Number(serde_json::Number::from_f64(round_sum).unwrap());
137 Ok(vec![sum_value])
138 }
139 } else if trimmed_op.starts_with("avg(") && trimmed_op.ends_with(")") {
140 let field = &trimmed_op[4..trimmed_op.len() - 1];
142 let field_name = field.trim_start_matches('.');
143
144 if is_grouped_data(&data) {
145 apply_aggregation_to_groups(data, "avg", field_name)
146 } else {
147 let values: Vec<f64> = data
148 .iter()
149 .filter_map(|item| item.get(field_name))
150 .filter_map(|val| val.as_f64())
151 .collect();
152
153 if values.is_empty() {
154 Ok(vec![Value::Null])
155 } else {
156 let avg = values.iter().sum::<f64>() / values.len() as f64;
157 let round_avg = (avg * 10.0).round() / 10.0;
158 let avg_value = Value::Number(serde_json::Number::from_f64(round_avg).unwrap());
159 Ok(vec![avg_value])
160 }
161 }
162 } else if trimmed_op.starts_with("min(") && trimmed_op.ends_with(")") {
163 let field = &trimmed_op[4..trimmed_op.len() - 1];
165 let field_name = field.trim_start_matches('.');
166
167 if is_grouped_data(&data) {
168 apply_aggregation_to_groups(data, "min", field_name)
169 } else {
170 let min_val = data
171 .iter()
172 .filter_map(|item| item.get(field_name))
173 .filter_map(|val| val.as_f64())
174 .fold(f64::INFINITY, f64::min);
175
176 if min_val == f64::INFINITY {
177 Ok(vec![Value::Null])
178 } else {
179 let min_value = Value::Number(serde_json::Number::from_f64(min_val).unwrap());
180 Ok(vec![min_value])
181 }
182 }
183 } else if trimmed_op.starts_with("max(") && trimmed_op.ends_with(")") {
184 let field = &trimmed_op[4..trimmed_op.len() - 1];
186 let field_name = field.trim_start_matches('.');
187
188 if is_grouped_data(&data) {
189 apply_aggregation_to_groups(data, "max", field_name)
190 } else {
191 let max_val = data
192 .iter()
193 .filter_map(|item| item.get(field_name))
194 .filter_map(|val| val.as_f64())
195 .fold(f64::NEG_INFINITY, f64::max);
196
197 if max_val == f64::NEG_INFINITY {
198 Ok(vec![Value::Null])
199 } else {
200 let max_value = Value::Number(serde_json::Number::from_f64(max_val).unwrap());
201 Ok(vec![max_value])
202 }
203 }
204 } else if trimmed_op.starts_with("group_by(") && trimmed_op.ends_with(")") {
205 let field = &trimmed_op[9..trimmed_op.len() - 1];
207 let field_name = field.trim_start_matches('.');
208
209 let grouped = group_data_by_field(data, field_name)?;
210 Ok(grouped)
211 } else if trimmed_op == "unique" {
212 let result = apply_stats_operation(&data, "unique", None)?;
214 if let Value::Array(arr) = result {
215 Ok(arr)
216 } else {
217 Ok(vec![result])
218 }
219 } else if trimmed_op == "sort" {
220 let result = apply_stats_operation(&data, "sort", None)?;
222 if let Value::Array(arr) = result {
223 Ok(arr)
224 } else {
225 Ok(vec![result])
226 }
227 } else if trimmed_op == "length" {
228 let result = apply_stats_operation(&data, "length", None)?;
230 Ok(vec![result])
231 } else if trimmed_op == "median" {
232 let result = apply_stats_operation(&data, "median", None)?;
234 Ok(vec![result])
235 } else if trimmed_op == "stddev" {
236 let result = apply_stats_operation(&data, "stddev", None)?;
238 Ok(vec![result])
239 } else if trimmed_op.starts_with("unique(") && trimmed_op.ends_with(")") {
240 let field = &trimmed_op[7..trimmed_op.len() - 1];
242 let field_name = field.trim_start_matches('.');
243 let result = apply_stats_operation(&data, "unique", Some(field_name))?;
244 if let Value::Array(arr) = result {
245 Ok(arr)
246 } else {
247 Ok(vec![result])
248 }
249 } else if trimmed_op.starts_with("sort(") && trimmed_op.ends_with(")") {
250 let field = &trimmed_op[5..trimmed_op.len() - 1];
252 let field_name = field.trim_start_matches('.');
253 let result = apply_stats_operation(&data, "sort", Some(field_name))?;
254 if let Value::Array(arr) = result {
255 Ok(arr)
256 } else {
257 Ok(vec![result])
258 }
259 } else if trimmed_op.starts_with("median(") && trimmed_op.ends_with(")") {
260 let field = &trimmed_op[7..trimmed_op.len() - 1];
262 let field_name = field.trim_start_matches('.');
263 let result = apply_stats_operation(&data, "median", Some(field_name))?;
264 Ok(vec![result])
265 } else if trimmed_op.starts_with("stddev(") && trimmed_op.ends_with(")") {
266 let field = &trimmed_op[7..trimmed_op.len() - 1];
268 let field_name = field.trim_start_matches('.');
269 let result = apply_stats_operation(&data, "stddev", Some(field_name))?;
270 Ok(vec![result])
271 } else {
272 Err(Error::InvalidQuery(format!(
274 "Unsupported operation: '{}' (length: {}, starts with 'map(': {}, ends with ')': {})",
275 trimmed_op,
276 trimmed_op.len(),
277 trimmed_op.starts_with("map("),
278 trimmed_op.ends_with(")")
279 )))
280 }
281}
282
283fn apply_map_operation(data: Vec<Value>, operation: &str) -> Result<Vec<Value>, Error> {
285 let content = &operation[4..operation.len() - 1]; let (field_access, string_operations) = parse_map_content(content)?;
289
290 let mut results = Vec::new();
291
292 for item in data {
293 let field_value = extract_field_value(&item, &field_access)?;
295
296 let transformed_value = apply_string_operations(&field_value, &string_operations)?;
298
299 let result = update_or_create_value(&item, &field_access, transformed_value)?;
301 results.push(result);
302 }
303
304 Ok(results)
305}
306
307fn parse_map_content(content: &str) -> Result<(String, Vec<String>), Error> {
309 let parts: Vec<&str> = content.split('|').map(|s| s.trim()).collect();
310
311 if parts.is_empty() {
312 return Err(Error::InvalidQuery("Empty map operation".to_string()));
313 }
314
315 let field_access = parts[0].to_string();
317
318 let string_operations: Vec<String> = parts[1..].iter().map(|s| s.to_string()).collect();
320
321 Ok((field_access, string_operations))
322}
323
324fn extract_field_value(item: &Value, field_access: &str) -> Result<Value, Error> {
326 if field_access == "." {
327 return Ok(item.clone());
329 }
330
331 if field_access.starts_with('.') {
332 let field_name = &field_access[1..]; if let Some(value) = item.get(field_name) {
335 Ok(value.clone())
336 } else {
337 Err(Error::InvalidQuery(format!("Field '{}' not found", field_name)))
338 }
339 } else {
340 Err(Error::InvalidQuery(format!("Invalid field access: {}", field_access)))
341 }
342}
343
344fn apply_string_operations(value: &Value, operations: &[String]) -> Result<Value, Error> {
346 if operations.is_empty() {
347 return Ok(value.clone());
348 }
349
350 let operations_str: Vec<&str> = operations.iter().map(|s| s.as_str()).collect();
351 string_ops::apply_string_pipeline(value, &operations_str)
352}
353
354fn update_or_create_value(original: &Value, field_access: &str, new_value: Value) -> Result<Value, Error> {
356 if field_access == "." {
357 Ok(new_value)
359 } else if field_access.starts_with('.') {
360 let field_name = &field_access[1..];
361
362 if let Value::Object(mut obj) = original.clone() {
364 obj.insert(field_name.to_string(), new_value);
365 Ok(Value::Object(obj))
366 } else {
367 let mut new_obj = serde_json::Map::new();
369 new_obj.insert(field_name.to_string(), new_value);
370 Ok(Value::Object(new_obj))
371 }
372 } else {
373 Err(Error::InvalidQuery(format!("Invalid field access: {}", field_access)))
374 }
375}
376
377fn apply_field_selection(data: Vec<Value>, field_list: Vec<String>) -> Result<Vec<Value>, Error> {
378 let mut results = Vec::new();
379
380 for item in data {
381 if let Value::Object(obj) = item {
382 let mut selected_obj = serde_json::Map::new();
383
384 for field_name in &field_list {
386 if let Some(value) = obj.get(field_name) {
387 selected_obj.insert(field_name.clone(), value.clone());
388 }
389 }
390
391 results.push(Value::Object(selected_obj));
392 } else {
393 return Err(Error::InvalidQuery(
395 "select_fields can only be applied to objects".into(),
396 ));
397 }
398 }
399
400 Ok(results)
401}
402
403fn group_data_by_field(data: Vec<Value>, field_name: &str) -> Result<Vec<Value>, Error> {
404 use std::collections::HashMap;
405
406 let mut groups: HashMap<String, Vec<Value>> = HashMap::new();
407
408 for item in data {
410 if let Some(field_value) = item.get(field_name) {
411 let key = value_to_string(field_value);
412 groups.entry(key).or_default().push(item);
413 }
414 }
415
416 let result: Vec<Value> = groups
418 .into_iter()
419 .map(|(group_name, group_items)| {
420 let mut group_obj = serde_json::Map::new();
421 group_obj.insert("group".to_string(), Value::String(group_name));
422 group_obj.insert("items".to_string(), Value::Array(group_items));
423 Value::Object(group_obj)
424 })
425 .collect();
426
427 Ok(result)
428}
429
430fn parse_condition(condition: &str) -> Result<(String, String, String), Error> {
431 let condition = condition.trim();
433
434 if let Some(pos) = condition.find(" > ") {
436 let field = condition[..pos].trim().to_string();
437 let value = condition[pos + 3..].trim().to_string();
438 return Ok((field, ">".to_string(), value));
439 }
440
441 if let Some(pos) = condition.find(" < ") {
442 let field = condition[..pos].trim().to_string();
443 let value = condition[pos + 3..].trim().to_string();
444 return Ok((field, "<".to_string(), value));
445 }
446
447 if let Some(pos) = condition.find(" == ") {
448 let field = condition[..pos].trim().to_string();
449 let value = condition[pos + 4..].trim().to_string();
450 return Ok((field, "==".to_string(), value));
451 }
452
453 if let Some(pos) = condition.find(" != ") {
454 let field = condition[..pos].trim().to_string();
455 let value = condition[pos + 4..].trim().to_string();
456 return Ok((field, "!=".to_string(), value));
457 }
458
459 Err(Error::InvalidQuery("Invalid condition format".into()))
460}
461
462fn evaluate_condition(item: &Value, field_path: &str, operator: &str, value: &str) -> bool {
463 let field_name = if field_path.starts_with('.') {
465 &field_path[1..]
466 } else {
467 field_path
468 };
469
470 let field_value = match item.get(field_name) {
471 Some(val) => val,
472 None => return false, };
474
475 match operator {
476 ">" => compare_greater(field_value, value),
477 "<" => compare_less(field_value, value),
478 "==" => compare_equal(field_value, value),
479 "!=" => !compare_equal(field_value, value),
480 _ => false,
481 }
482}
483
484fn compare_greater(field_value: &Value, target: &str) -> bool {
485 match field_value {
486 Value::Number(n) => {
487 if let Ok(target_num) = target.parse::<f64>() {
488 n.as_f64().unwrap_or(0.0) > target_num
489 } else {
490 false
491 }
492 }
493 _ => false,
494 }
495}
496
497fn compare_less(field_value: &Value, target: &str) -> bool {
498 match field_value {
499 Value::Number(n) => {
500 if let Ok(target_num) = target.parse::<f64>() {
501 n.as_f64().unwrap_or(0.0) < target_num
502 } else {
503 false
504 }
505 }
506 _ => false,
507 }
508}
509
510fn compare_equal(field_value: &Value, target: &str) -> bool {
511 match field_value {
512 Value::String(s) => {
513 let target_clean = target.trim_matches('"');
515 s == target_clean
516 }
517 Value::Number(n) => {
518 if let Ok(target_num) = target.parse::<f64>() {
519 n.as_f64().unwrap_or(0.0) == target_num
520 } else {
521 false
522 }
523 }
524 Value::Bool(b) => match target {
525 "true" => *b,
526 "false" => !*b,
527 _ => false,
528 },
529 _ => false,
530 }
531}
532
533fn is_grouped_data(data: &[Value]) -> bool {
534 data.iter().all(|item| {
535 if let Value::Object(obj) = item {
536 obj.contains_key("group") && obj.contains_key("items")
537 } else {
538 false
539 }
540 })
541}
542
543fn apply_aggregation_to_groups(
544 data: Vec<Value>,
545 operation: &str,
546 field_name: &str,
547) -> Result<Vec<Value>, Error> {
548 let mut results = Vec::new();
549
550 for group_data in data {
551 if let Value::Object(group_obj) = group_data {
552 let group_name = group_obj.get("group").unwrap();
553 let items = group_obj.get("items").and_then(|v| v.as_array()).unwrap();
554
555 let aggregated_value = match operation {
557 "avg" => calculate_avg(items, field_name)?,
558 "sum" => calculate_sum(items, field_name)?,
559 "count" => Value::Number(serde_json::Number::from(items.len())),
560 "min" => calculate_min(items, field_name)?,
561 "max" => calculate_max(items, field_name)?,
562 _ => Value::Null,
563 };
564
565 let mut result_obj = serde_json::Map::new();
567 result_obj.insert("group".to_string(), group_name.clone());
568 result_obj.insert(operation.to_string(), aggregated_value);
569 results.push(Value::Object(result_obj));
570 }
571 }
572
573 Ok(results)
574}
575
576fn calculate_avg(items: &[Value], field_name: &str) -> Result<Value, Error> {
577 let values: Vec<f64> = items
578 .iter()
579 .filter_map(|item| item.get(field_name))
580 .filter_map(|val| val.as_f64())
581 .collect();
582
583 if values.is_empty() {
584 Ok(Value::Null)
585 } else {
586 let avg = values.iter().sum::<f64>() / values.len() as f64;
587 let rounded_avg = (avg * 10.0).round() / 10.0;
588 Ok(Value::Number(
589 serde_json::Number::from_f64(rounded_avg).unwrap(),
590 ))
591 }
592}
593
594fn calculate_sum(items: &[Value], field_name: &str) -> Result<Value, Error> {
595 let sum: f64 = items
596 .iter()
597 .filter_map(|item| item.get(field_name))
598 .filter_map(|val| val.as_f64())
599 .sum();
600
601 let rounded_sum = if sum.fract() == 0.0 {
602 sum
603 } else {
604 (sum * 10.0).round() / 10.0
605 };
606
607 Ok(Value::Number(
608 serde_json::Number::from_f64(rounded_sum).unwrap(),
609 ))
610}
611
612fn calculate_min(items: &[Value], field_name: &str) -> Result<Value, Error> {
613 let min_val = items
614 .iter()
615 .filter_map(|item| item.get(field_name))
616 .filter_map(|val| val.as_f64())
617 .fold(f64::INFINITY, f64::min);
618
619 if min_val == f64::INFINITY {
620 Ok(Value::Null)
621 } else {
622 Ok(Value::Number(
623 serde_json::Number::from_f64(min_val).unwrap(),
624 ))
625 }
626}
627
628fn calculate_max(items: &[Value], field_name: &str) -> Result<Value, Error> {
629 let max_val = items
630 .iter()
631 .filter_map(|item| item.get(field_name))
632 .filter_map(|val| val.as_f64())
633 .fold(f64::NEG_INFINITY, f64::max);
634
635 if max_val == f64::NEG_INFINITY {
636 Ok(Value::Null)
637 } else {
638 Ok(Value::Number(
639 serde_json::Number::from_f64(max_val).unwrap(),
640 ))
641 }
642}