hawk_data/
filter.rs

1use serde_json::Value;
2
3use crate::{apply_stats_operation, Error, print_data_info, value_to_string, string_ops};
4
5pub fn apply_simple_filter(data: Vec<Value>, filter: &str) -> Result<Vec<Value>, Error> {
6    if filter.starts_with("select(") && filter.ends_with(")") {
7        // "select(.age > 30)" から ".age > 30" を抽出
8        let condition = &filter[7..filter.len() - 1];
9
10        // パイプラインがある場合の処理
11        if condition.contains(" | ") {
12            apply_filter_with_string_operations(data, condition)
13        } else {
14            apply_existing_simple_filter(data, condition)
15        }
16    } else {
17        Err(Error::InvalidQuery(format!(
18            "Unsupported filter: {}",
19            filter
20        )))
21    }
22}
23
24/// 文字列操作付きフィルタの適用
25fn apply_filter_with_string_operations(data: Vec<Value>, condition: &str) -> Result<Vec<Value>, Error> {
26    let parts: Vec<&str> = condition.split(" | ").map(|s| s.trim()).collect();
27    
28    if parts.len() < 2 {
29        return Err(Error::InvalidQuery("Invalid filter condition".to_string()));
30    }
31    
32    let field_access = parts[0];
33    let string_operations: Vec<&str> = parts[1..].to_vec();
34    
35    // 最後の操作は比較操作である必要がある
36    let last_operation = string_operations.last().ok_or_else(|| {
37        Error::InvalidQuery("Missing comparison operation".to_string())
38    })?;
39    
40    if !is_comparison_operation(last_operation) {
41        return Err(Error::InvalidQuery("Last operation must be a comparison".to_string()));
42    }
43    
44    let mut results = Vec::new();
45    
46    for item in data {
47        // フィールド値を取得
48        let field_value = extract_field_value(&item, field_access)?;
49        
50        // 文字列操作を適用(比較操作まで)
51        let final_value = string_ops::apply_string_pipeline(&field_value, &string_operations)?;
52        
53        // 比較結果が true の場合のみ追加
54        if let Value::Bool(true) = final_value {
55            results.push(item);
56        }
57    }
58    
59    Ok(results)
60}
61
/// Reports whether a pipeline stage is a comparison.
///
/// Recognised forms are calls starting with `contains(`, `starts_with(` or
/// `ends_with(`, and the operators `==` / `!=` either on their own or followed
/// by a space.
fn is_comparison_operation(operation: &str) -> bool {
    const CALL_PREFIXES: [&str; 3] = ["contains(", "starts_with(", "ends_with("];
    const EQUALITY_OPERATORS: [&str; 2] = ["==", "!="];

    if CALL_PREFIXES.iter().any(|prefix| operation.starts_with(*prefix)) {
        return true;
    }

    // An operator matches when nothing follows it or the next character is a space
    EQUALITY_OPERATORS.iter().any(|symbol| {
        operation
            .strip_prefix(*symbol)
            .map_or(false, |rest| rest.is_empty() || rest.starts_with(' '))
    })
}
72
73/// 既存のシンプルフィルタ処理
74fn apply_existing_simple_filter(data: Vec<Value>, condition: &str) -> Result<Vec<Value>, Error> {
75    // 条件をパース
76    let (field_path, operator, value) = parse_condition(condition)?;
77
78    // フィルタリングを実行
79    let filtered: Vec<Value> = data
80        .into_iter()
81        .filter(|item| evaluate_condition(item, &field_path, &operator, &value))
82        .collect();
83
84    Ok(filtered)
85}
86
87pub fn apply_pipeline_operation(data: Vec<Value>, operation: &str) -> Result<Vec<Value>, Error> {
88    let trimmed_op = operation.trim();
89    
90    if trimmed_op.starts_with("select(") && trimmed_op.ends_with(")") {
91        // フィルタリング操作
92        apply_simple_filter(data, trimmed_op)
93    } else if trimmed_op == "count" {
94        // カウント操作
95        if is_grouped_data(&data) {
96            apply_aggregation_to_groups(data, "count", "")
97        } else {
98            let count = data.len();
99            let count_value = Value::Number(serde_json::Number::from(count));
100            Ok(vec![count_value])
101        }
102    } else if trimmed_op.starts_with("map(") && trimmed_op.ends_with(")") {
103        apply_map_operation(data, trimmed_op)
104    } else if trimmed_op.starts_with("select_fields(") && trimmed_op.ends_with(")") {
105        // 複数フィールド選択
106        let fields_str = &trimmed_op[14..trimmed_op.len() - 1]; // "name,age,department"
107        let field_list: Vec<String> = fields_str
108            .split(',')
109            .map(|s| s.trim().to_string())
110            .collect();
111
112        apply_field_selection(data, field_list)
113    } else if trimmed_op == "info" {
114        // info操作
115        print_data_info(&data);
116        Ok(vec![]) // Return empty vector
117    } else if trimmed_op.starts_with("sum(") && trimmed_op.ends_with(")") {
118        // sum(.field) の処理
119        let field = &trimmed_op[4..trimmed_op.len() - 1];
120        let field_name = field.trim_start_matches('.');
121
122        if is_grouped_data(&data) {
123            apply_aggregation_to_groups(data, "sum", field_name)
124        } else {
125            let sum: f64 = data
126                .iter()
127                .filter_map(|item| item.get(field_name))
128                .filter_map(|val| val.as_f64())
129                .sum();
130
131            let round_sum = if sum.fract() == 0.0 {
132                sum
133            } else {
134                (sum * 10.0).round() / 10.0
135            };
136            let sum_value = Value::Number(serde_json::Number::from_f64(round_sum).unwrap());
137            Ok(vec![sum_value])
138        }
139    } else if trimmed_op.starts_with("avg(") && trimmed_op.ends_with(")") {
140        // avg(.field) の処理
141        let field = &trimmed_op[4..trimmed_op.len() - 1];
142        let field_name = field.trim_start_matches('.');
143
144        if is_grouped_data(&data) {
145            apply_aggregation_to_groups(data, "avg", field_name)
146        } else {
147            let values: Vec<f64> = data
148                .iter()
149                .filter_map(|item| item.get(field_name))
150                .filter_map(|val| val.as_f64())
151                .collect();
152
153            if values.is_empty() {
154                Ok(vec![Value::Null])
155            } else {
156                let avg = values.iter().sum::<f64>() / values.len() as f64;
157                let round_avg = (avg * 10.0).round() / 10.0;
158                let avg_value = Value::Number(serde_json::Number::from_f64(round_avg).unwrap());
159                Ok(vec![avg_value])
160            }
161        }
162    } else if trimmed_op.starts_with("min(") && trimmed_op.ends_with(")") {
163        // min(.field) の処理
164        let field = &trimmed_op[4..trimmed_op.len() - 1];
165        let field_name = field.trim_start_matches('.');
166
167        if is_grouped_data(&data) {
168            apply_aggregation_to_groups(data, "min", field_name)
169        } else {
170            let min_val = data
171                .iter()
172                .filter_map(|item| item.get(field_name))
173                .filter_map(|val| val.as_f64())
174                .fold(f64::INFINITY, f64::min);
175
176            if min_val == f64::INFINITY {
177                Ok(vec![Value::Null])
178            } else {
179                let min_value = Value::Number(serde_json::Number::from_f64(min_val).unwrap());
180                Ok(vec![min_value])
181            }
182        }
183    } else if trimmed_op.starts_with("max(") && trimmed_op.ends_with(")") {
184        // max(.field) の処理
185        let field = &trimmed_op[4..trimmed_op.len() - 1];
186        let field_name = field.trim_start_matches('.');
187
188        if is_grouped_data(&data) {
189            apply_aggregation_to_groups(data, "max", field_name)
190        } else {
191            let max_val = data
192                .iter()
193                .filter_map(|item| item.get(field_name))
194                .filter_map(|val| val.as_f64())
195                .fold(f64::NEG_INFINITY, f64::max);
196
197            if max_val == f64::NEG_INFINITY {
198                Ok(vec![Value::Null])
199            } else {
200                let max_value = Value::Number(serde_json::Number::from_f64(max_val).unwrap());
201                Ok(vec![max_value])
202            }
203        }
204    } else if trimmed_op.starts_with("group_by(") && trimmed_op.ends_with(")") {
205        // group_by(.department) の処理
206        let field = &trimmed_op[9..trimmed_op.len() - 1];
207        let field_name = field.trim_start_matches('.');
208
209        let grouped = group_data_by_field(data, field_name)?;
210        Ok(grouped)
211    } else if trimmed_op == "unique" {
212        // unique操作(重複除去)
213        let result = apply_stats_operation(&data, "unique", None)?;
214        if let Value::Array(arr) = result {
215            Ok(arr)
216        } else {
217            Ok(vec![result])
218        }
219    } else if trimmed_op == "sort" {
220        // sort操作
221        let result = apply_stats_operation(&data, "sort", None)?;
222        if let Value::Array(arr) = result {
223            Ok(arr)
224        } else {
225            Ok(vec![result])
226        }
227    } else if trimmed_op == "length" {
228        // length操作(配列の長さ)
229        let result = apply_stats_operation(&data, "length", None)?;
230        Ok(vec![result])
231    } else if trimmed_op == "median" {
232        // median操作(中央値)
233        let result = apply_stats_operation(&data, "median", None)?;
234        Ok(vec![result])
235    } else if trimmed_op == "stddev" {
236        // stddev操作(標準偏差)
237        let result = apply_stats_operation(&data, "stddev", None)?;
238        Ok(vec![result])
239    } else if trimmed_op.starts_with("unique(") && trimmed_op.ends_with(")") {
240        // unique(.field) - フィールド指定
241        let field = &trimmed_op[7..trimmed_op.len() - 1];
242        let field_name = field.trim_start_matches('.');
243        let result = apply_stats_operation(&data, "unique", Some(field_name))?;
244        if let Value::Array(arr) = result {
245            Ok(arr)
246        } else {
247            Ok(vec![result])
248        }
249    } else if trimmed_op.starts_with("sort(") && trimmed_op.ends_with(")") {
250        // sort(.field) - フィールド指定
251        let field = &trimmed_op[5..trimmed_op.len() - 1];
252        let field_name = field.trim_start_matches('.');
253        let result = apply_stats_operation(&data, "sort", Some(field_name))?;
254        if let Value::Array(arr) = result {
255            Ok(arr)
256        } else {
257            Ok(vec![result])
258        }
259    } else if trimmed_op.starts_with("median(") && trimmed_op.ends_with(")") {
260        // median(.field) - フィールド指定
261        let field = &trimmed_op[7..trimmed_op.len() - 1];
262        let field_name = field.trim_start_matches('.');
263        let result = apply_stats_operation(&data, "median", Some(field_name))?;
264        Ok(vec![result])
265    } else if trimmed_op.starts_with("stddev(") && trimmed_op.ends_with(")") {
266        // stddev(.field) - フィールド指定
267        let field = &trimmed_op[7..trimmed_op.len() - 1];
268        let field_name = field.trim_start_matches('.');
269        let result = apply_stats_operation(&data, "stddev", Some(field_name))?;
270        Ok(vec![result])
271    } else {
272        // より詳細なエラーメッセージ
273        Err(Error::InvalidQuery(format!(
274            "Unsupported operation: '{}' (length: {}, starts with 'map(': {}, ends with ')': {})",
275            trimmed_op, 
276            trimmed_op.len(),
277            trimmed_op.starts_with("map("),
278            trimmed_op.ends_with(")")
279        )))
280    }
281}
282
283/// map操作の実装
284fn apply_map_operation(data: Vec<Value>, operation: &str) -> Result<Vec<Value>, Error> {
285    // "map(.field | string_operation)" の解析
286    let content = &operation[4..operation.len() - 1]; // "map(" と ")" を除去
287    
288    let (field_access, string_operations) = parse_map_content(content)?;
289    
290    let mut results = Vec::new();
291    
292    for item in data {
293        // フィールドにアクセス
294        let field_value = extract_field_value(&item, &field_access)?;
295        
296        // 文字列操作を適用
297        let transformed_value = apply_string_operations(&field_value, &string_operations)?;
298        
299        // 元のオブジェクトを更新または新しい値を作成
300        let result = update_or_create_value(&item, &field_access, transformed_value)?;
301        results.push(result);
302    }
303    
304    Ok(results)
305}
306
307/// map操作の内容を解析(例: ".name | upper | trim")
308fn parse_map_content(content: &str) -> Result<(String, Vec<String>), Error> {
309    let parts: Vec<&str> = content.split('|').map(|s| s.trim()).collect();
310    
311    if parts.is_empty() {
312        return Err(Error::InvalidQuery("Empty map operation".to_string()));
313    }
314    
315    // 最初の部分はフィールドアクセス
316    let field_access = parts[0].to_string();
317    
318    // 残りは文字列操作
319    let string_operations: Vec<String> = parts[1..].iter().map(|s| s.to_string()).collect();
320    
321    Ok((field_access, string_operations))
322}
323
324/// フィールド値を抽出
325fn extract_field_value(item: &Value, field_access: &str) -> Result<Value, Error> {
326    if field_access == "." {
327        // ルート値(Text配列の場合の各行)
328        return Ok(item.clone());
329    }
330    
331    if field_access.starts_with('.') {
332        let field_name = &field_access[1..]; // '.' を除去
333        
334        if let Some(value) = item.get(field_name) {
335            Ok(value.clone())
336        } else {
337            Err(Error::InvalidQuery(format!("Field '{}' not found", field_name)))
338        }
339    } else {
340        Err(Error::InvalidQuery(format!("Invalid field access: {}", field_access)))
341    }
342}
343
344/// 文字列操作を順次適用
345fn apply_string_operations(value: &Value, operations: &[String]) -> Result<Value, Error> {
346    if operations.is_empty() {
347        return Ok(value.clone());
348    }
349    
350    let operations_str: Vec<&str> = operations.iter().map(|s| s.as_str()).collect();
351    string_ops::apply_string_pipeline(value, &operations_str)
352}
353
354/// 値を更新または新しい値を作成
355fn update_or_create_value(original: &Value, field_access: &str, new_value: Value) -> Result<Value, Error> {
356    if field_access == "." {
357        // ルート値の場合は直接置き換え
358        Ok(new_value)
359    } else if field_access.starts_with('.') {
360        let field_name = &field_access[1..];
361        
362        // オブジェクトの場合はフィールドを更新
363        if let Value::Object(mut obj) = original.clone() {
364            obj.insert(field_name.to_string(), new_value);
365            Ok(Value::Object(obj))
366        } else {
367            // オブジェクトでない場合は新しいオブジェクトを作成
368            let mut new_obj = serde_json::Map::new();
369            new_obj.insert(field_name.to_string(), new_value);
370            Ok(Value::Object(new_obj))
371        }
372    } else {
373        Err(Error::InvalidQuery(format!("Invalid field access: {}", field_access)))
374    }
375}
376
377fn apply_field_selection(data: Vec<Value>, field_list: Vec<String>) -> Result<Vec<Value>, Error> {
378    let mut results = Vec::new();
379
380    for item in data {
381        if let Value::Object(obj) = item {
382            let mut selected_obj = serde_json::Map::new();
383
384            // 指定されたフィールドのみを抽出
385            for field_name in &field_list {
386                if let Some(value) = obj.get(field_name) {
387                    selected_obj.insert(field_name.clone(), value.clone());
388                }
389            }
390
391            results.push(Value::Object(selected_obj));
392        } else {
393            // オブジェクト以外は無視するか、エラーにする
394            return Err(Error::InvalidQuery(
395                "select_fields can only be applied to objects".into(),
396            ));
397        }
398    }
399
400    Ok(results)
401}
402
403fn group_data_by_field(data: Vec<Value>, field_name: &str) -> Result<Vec<Value>, Error> {
404    use std::collections::HashMap;
405
406    let mut groups: HashMap<String, Vec<Value>> = HashMap::new();
407
408    // データをフィールド値でグルーピング
409    for item in data {
410        if let Some(field_value) = item.get(field_name) {
411            let key = value_to_string(field_value);
412            groups.entry(key).or_default().push(item);
413        }
414    }
415
416    // グループを配列として返す
417    let result: Vec<Value> = groups
418        .into_iter()
419        .map(|(group_name, group_items)| {
420            let mut group_obj = serde_json::Map::new();
421            group_obj.insert("group".to_string(), Value::String(group_name));
422            group_obj.insert("items".to_string(), Value::Array(group_items));
423            Value::Object(group_obj)
424        })
425        .collect();
426
427    Ok(result)
428}
429
430fn parse_condition(condition: &str) -> Result<(String, String, String), Error> {
431    // ".age > 30" のような条件をパース
432    let condition = condition.trim();
433
434    // 演算子を検出
435    if let Some(pos) = condition.find(" > ") {
436        let field = condition[..pos].trim().to_string();
437        let value = condition[pos + 3..].trim().to_string();
438        return Ok((field, ">".to_string(), value));
439    }
440
441    if let Some(pos) = condition.find(" < ") {
442        let field = condition[..pos].trim().to_string();
443        let value = condition[pos + 3..].trim().to_string();
444        return Ok((field, "<".to_string(), value));
445    }
446
447    if let Some(pos) = condition.find(" == ") {
448        let field = condition[..pos].trim().to_string();
449        let value = condition[pos + 4..].trim().to_string();
450        return Ok((field, "==".to_string(), value));
451    }
452
453    if let Some(pos) = condition.find(" != ") {
454        let field = condition[..pos].trim().to_string();
455        let value = condition[pos + 4..].trim().to_string();
456        return Ok((field, "!=".to_string(), value));
457    }
458
459    Err(Error::InvalidQuery("Invalid condition format".into()))
460}
461
462fn evaluate_condition(item: &Value, field_path: &str, operator: &str, value: &str) -> bool {
463    // フィールドパスから値を取得 (.age -> age)
464    let field_name = if field_path.starts_with('.') {
465        &field_path[1..]
466    } else {
467        field_path
468    };
469
470    let field_value = match item.get(field_name) {
471        Some(val) => val,
472        None => return false, // false if the field does not exist
473    };
474
475    match operator {
476        ">" => compare_greater(field_value, value),
477        "<" => compare_less(field_value, value),
478        "==" => compare_equal(field_value, value),
479        "!=" => !compare_equal(field_value, value),
480        _ => false,
481    }
482}
483
484fn compare_greater(field_value: &Value, target: &str) -> bool {
485    match field_value {
486        Value::Number(n) => {
487            if let Ok(target_num) = target.parse::<f64>() {
488                n.as_f64().unwrap_or(0.0) > target_num
489            } else {
490                false
491            }
492        }
493        _ => false,
494    }
495}
496
497fn compare_less(field_value: &Value, target: &str) -> bool {
498    match field_value {
499        Value::Number(n) => {
500            if let Ok(target_num) = target.parse::<f64>() {
501                n.as_f64().unwrap_or(0.0) < target_num
502            } else {
503                false
504            }
505        }
506        _ => false,
507    }
508}
509
510fn compare_equal(field_value: &Value, target: &str) -> bool {
511    match field_value {
512        Value::String(s) => {
513            // 文字列比較(引用符を除去)
514            let target_clean = target.trim_matches('"');
515            s == target_clean
516        }
517        Value::Number(n) => {
518            if let Ok(target_num) = target.parse::<f64>() {
519                n.as_f64().unwrap_or(0.0) == target_num
520            } else {
521                false
522            }
523        }
524        Value::Bool(b) => match target {
525            "true" => *b,
526            "false" => !*b,
527            _ => false,
528        },
529        _ => false,
530    }
531}
532
533fn is_grouped_data(data: &[Value]) -> bool {
534    data.iter().all(|item| {
535        if let Value::Object(obj) = item {
536            obj.contains_key("group") && obj.contains_key("items")
537        } else {
538            false
539        }
540    })
541}
542
543fn apply_aggregation_to_groups(
544    data: Vec<Value>,
545    operation: &str,
546    field_name: &str,
547) -> Result<Vec<Value>, Error> {
548    let mut results = Vec::new();
549
550    for group_data in data {
551        if let Value::Object(group_obj) = group_data {
552            let group_name = group_obj.get("group").unwrap();
553            let items = group_obj.get("items").and_then(|v| v.as_array()).unwrap();
554
555            // 各グループのitemsに対して集約を実行
556            let aggregated_value = match operation {
557                "avg" => calculate_avg(items, field_name)?,
558                "sum" => calculate_sum(items, field_name)?,
559                "count" => Value::Number(serde_json::Number::from(items.len())),
560                "min" => calculate_min(items, field_name)?,
561                "max" => calculate_max(items, field_name)?,
562                _ => Value::Null,
563            };
564
565            // 結果オブジェクトを作成
566            let mut result_obj = serde_json::Map::new();
567            result_obj.insert("group".to_string(), group_name.clone());
568            result_obj.insert(operation.to_string(), aggregated_value);
569            results.push(Value::Object(result_obj));
570        }
571    }
572
573    Ok(results)
574}
575
576fn calculate_avg(items: &[Value], field_name: &str) -> Result<Value, Error> {
577    let values: Vec<f64> = items
578        .iter()
579        .filter_map(|item| item.get(field_name))
580        .filter_map(|val| val.as_f64())
581        .collect();
582
583    if values.is_empty() {
584        Ok(Value::Null)
585    } else {
586        let avg = values.iter().sum::<f64>() / values.len() as f64;
587        let rounded_avg = (avg * 10.0).round() / 10.0;
588        Ok(Value::Number(
589            serde_json::Number::from_f64(rounded_avg).unwrap(),
590        ))
591    }
592}
593
594fn calculate_sum(items: &[Value], field_name: &str) -> Result<Value, Error> {
595    let sum: f64 = items
596        .iter()
597        .filter_map(|item| item.get(field_name))
598        .filter_map(|val| val.as_f64())
599        .sum();
600
601    let rounded_sum = if sum.fract() == 0.0 {
602        sum
603    } else {
604        (sum * 10.0).round() / 10.0
605    };
606
607    Ok(Value::Number(
608        serde_json::Number::from_f64(rounded_sum).unwrap(),
609    ))
610}
611
612fn calculate_min(items: &[Value], field_name: &str) -> Result<Value, Error> {
613    let min_val = items
614        .iter()
615        .filter_map(|item| item.get(field_name))
616        .filter_map(|val| val.as_f64())
617        .fold(f64::INFINITY, f64::min);
618
619    if min_val == f64::INFINITY {
620        Ok(Value::Null)
621    } else {
622        Ok(Value::Number(
623            serde_json::Number::from_f64(min_val).unwrap(),
624        ))
625    }
626}
627
628fn calculate_max(items: &[Value], field_name: &str) -> Result<Value, Error> {
629    let max_val = items
630        .iter()
631        .filter_map(|item| item.get(field_name))
632        .filter_map(|val| val.as_f64())
633        .fold(f64::NEG_INFINITY, f64::max);
634
635    if max_val == f64::NEG_INFINITY {
636        Ok(Value::Null)
637    } else {
638        Ok(Value::Number(
639            serde_json::Number::from_f64(max_val).unwrap(),
640        ))
641    }
642}