hawk_data/
filter.rs

1use serde_json::Value;
2
3use crate::{Error, print_data_info, value_to_string};
4
5pub fn apply_simple_filter(data: Vec<Value>, filter: &str) -> Result<Vec<Value>, Error> {
6    if filter.starts_with("select(") && filter.ends_with(")") {
7        // "select(.age > 30)" から ".age > 30" を抽出
8        let condition = &filter[7..filter.len() - 1];
9
10        // 条件をパース
11        // Parse conditions
12        let (field_path, operator, value) = parse_condition(condition)?;
13
14        // フィルタリングを実行
15        // Execute filtering
16        let filtered: Vec<Value> = data
17            .into_iter()
18            .filter(|item| evaluate_condition(item, &field_path, &operator, &value))
19            .collect();
20
21        Ok(filtered)
22    } else {
23        Err(Error::InvalidQuery(format!(
24            "Unsupported filter: {}",
25            filter
26        )))
27    }
28}
29
30pub fn apply_pipeline_operation(data: Vec<Value>, operation: &str) -> Result<Vec<Value>, Error> {
31    if operation.starts_with("select(") && operation.ends_with(")") {
32        // フィルタリング操作
33        // Filtering operations
34        apply_simple_filter(data, operation)
35    } else if operation == "count" {
36        // カウント操作
37        // Count operation
38        if is_grouped_data(&data) {
39            apply_aggregation_to_groups(data, "count", "")
40        } else {
41            let count = data.len();
42            let count_value = Value::Number(serde_json::Number::from(count));
43            Ok(vec![count_value])
44        }
45    } else if operation.starts_with("select_fields(") && operation.ends_with(")") {
46        // **新規追加: 複数フィールド選択**
47        let fields_str = &operation[14..operation.len() - 1]; // "name,age,department"
48        let field_list: Vec<String> = fields_str
49            .split(',')
50            .map(|s| s.trim().to_string())
51            .collect();
52
53        apply_field_selection(data, field_list)
54    } else if operation == "info" {
55        // info操作
56        // info operation
57        print_data_info(&data);
58        Ok(vec![]) // Return empty vector
59    } else if operation.starts_with("sum(") && operation.ends_with(")") {
60        // sum(.field) の処理
61        // Processing of sum(.field)
62        let field = &operation[4..operation.len() - 1];
63        let field_name = field.trim_start_matches('.');
64
65        if is_grouped_data(&data) {
66            apply_aggregation_to_groups(data, "sum", field_name)
67        } else {
68            let sum: f64 = data
69                .iter()
70                .filter_map(|item| item.get(field_name))
71                .filter_map(|val| val.as_f64())
72                .sum();
73
74            let round_sum = if sum.fract() == 0.0 {
75                sum
76            } else {
77                (sum * 10.0).round() / 10.0
78            };
79            let sum_value = Value::Number(serde_json::Number::from_f64(round_sum).unwrap());
80            Ok(vec![sum_value])
81        }
82    } else if operation.starts_with("avg(") && operation.ends_with(")") {
83        // avg(.field) の処理
84        // Processing of avg(.field)
85        let field = &operation[4..operation.len() - 1];
86        let field_name = field.trim_start_matches('.');
87
88        if is_grouped_data(&data) {
89            apply_aggregation_to_groups(data, "avg", field_name)
90        } else {
91            let values: Vec<f64> = data
92                .iter()
93                .filter_map(|item| item.get(field_name))
94                .filter_map(|val| val.as_f64())
95                .collect();
96
97            if values.is_empty() {
98                Ok(vec![Value::Null])
99            } else {
100                let avg = values.iter().sum::<f64>() / values.len() as f64;
101                let round_avg = (avg * 10.0).round() / 10.0;
102                let avg_value = Value::Number(serde_json::Number::from_f64(round_avg).unwrap());
103                Ok(vec![avg_value])
104            }
105        }
106    } else if operation.starts_with("min(") && operation.ends_with(")") {
107        // min(.field) の処理
108        // Processing of min(.field)
109        let field = &operation[4..operation.len() - 1];
110        let field_name = field.trim_start_matches('.');
111
112        if is_grouped_data(&data) {
113            apply_aggregation_to_groups(data, "min", field_name)
114        } else {
115            let min_val = data
116                .iter()
117                .filter_map(|item| item.get(field_name))
118                .filter_map(|val| val.as_f64())
119                .fold(f64::INFINITY, f64::min);
120
121            if min_val == f64::INFINITY {
122                Ok(vec![Value::Null])
123            } else {
124                let min_value = Value::Number(serde_json::Number::from_f64(min_val).unwrap());
125                Ok(vec![min_value])
126            }
127        }
128    } else if operation.starts_with("max(") && operation.ends_with(")") {
129        // max(.field) の処理
130        // Processing of max(.field)
131        let field = &operation[4..operation.len() - 1];
132        let field_name = field.trim_start_matches('.');
133
134        if is_grouped_data(&data) {
135            apply_aggregation_to_groups(data, "max", field_name)
136        } else {
137            let max_val = data
138                .iter()
139                .filter_map(|item| item.get(field_name))
140                .filter_map(|val| val.as_f64())
141                .fold(f64::NEG_INFINITY, f64::max);
142
143            if max_val == f64::NEG_INFINITY {
144                Ok(vec![Value::Null])
145            } else {
146                let max_value = Value::Number(serde_json::Number::from_f64(max_val).unwrap());
147                Ok(vec![max_value])
148            }
149        }
150    } else if operation.starts_with("group_by(") && operation.ends_with(")") {
151        // group_by(.department) の処理
152        // Processing of group_by(.field)
153        let field = &operation[9..operation.len() - 1];
154        let field_name = field.trim_start_matches('.');
155
156        let grouped = group_data_by_field(data, field_name)?;
157        Ok(grouped)
158    } else {
159        Err(Error::InvalidQuery(format!(
160            "Unsupported operation: {}",
161            operation
162        )))
163    }
164}
165
166fn apply_field_selection(data: Vec<Value>, field_list: Vec<String>) -> Result<Vec<Value>, Error> {
167    let mut results = Vec::new();
168
169    for item in data {
170        if let Value::Object(obj) = item {
171            let mut selected_obj = serde_json::Map::new();
172
173            // 指定されたフィールドのみを抽出
174            for field_name in &field_list {
175                if let Some(value) = obj.get(field_name) {
176                    selected_obj.insert(field_name.clone(), value.clone());
177                }
178            }
179
180            results.push(Value::Object(selected_obj));
181        } else {
182            // オブジェクト以外は無視するか、エラーにする
183            return Err(Error::InvalidQuery(
184                "select_fields can only be applied to objects".into(),
185            ));
186        }
187    }
188
189    Ok(results)
190}
191
192fn group_data_by_field(data: Vec<Value>, field_name: &str) -> Result<Vec<Value>, Error> {
193    use std::collections::HashMap;
194
195    let mut groups: HashMap<String, Vec<Value>> = HashMap::new();
196
197    // データをフィールド値でグルーピング
198    // Group data by field values
199    for item in data {
200        if let Some(field_value) = item.get(field_name) {
201            let key = value_to_string(field_value);
202            groups.entry(key).or_insert_with(Vec::new).push(item);
203        }
204    }
205
206    // グループを配列として返す
207    // Return the group as an array
208    let result: Vec<Value> = groups
209        .into_iter()
210        .map(|(group_name, group_items)| {
211            let mut group_obj = serde_json::Map::new();
212            group_obj.insert("group".to_string(), Value::String(group_name));
213            group_obj.insert("items".to_string(), Value::Array(group_items));
214            Value::Object(group_obj)
215        })
216        .collect();
217
218    Ok(result)
219}
220
221fn parse_condition(condition: &str) -> Result<(String, String, String), Error> {
222    // ".age > 30" のような条件をパース
223    // Parse conditions such as “.age > 30”
224    let condition = condition.trim();
225
226    // 演算子を検出
227    // Detect operators
228    if let Some(pos) = condition.find(" > ") {
229        let field = condition[..pos].trim().to_string();
230        let value = condition[pos + 3..].trim().to_string();
231        return Ok((field, ">".to_string(), value));
232    }
233
234    if let Some(pos) = condition.find(" < ") {
235        let field = condition[..pos].trim().to_string();
236        let value = condition[pos + 3..].trim().to_string();
237        return Ok((field, "<".to_string(), value));
238    }
239
240    if let Some(pos) = condition.find(" == ") {
241        let field = condition[..pos].trim().to_string();
242        let value = condition[pos + 4..].trim().to_string();
243        return Ok((field, "==".to_string(), value));
244    }
245
246    if let Some(pos) = condition.find(" != ") {
247        let field = condition[..pos].trim().to_string();
248        let value = condition[pos + 4..].trim().to_string();
249        return Ok((field, "!=".to_string(), value));
250    }
251
252    Err(Error::InvalidQuery("Invalid condition format".into()))
253}
254
255fn evaluate_condition(item: &Value, field_path: &str, operator: &str, value: &str) -> bool {
256    // フィールドパスから値を取得 (.age -> age)
257    // Get the value from the field path (.age -> age)
258    let field_name = if field_path.starts_with('.') {
259        &field_path[1..]
260    } else {
261        field_path
262    };
263
264    let field_value = match item.get(field_name) {
265        Some(val) => val,
266        None => return false, // false if the field does not exist
267    };
268
269    match operator {
270        ">" => compare_greater(field_value, value),
271        "<" => compare_less(field_value, value),
272        "==" => compare_equal(field_value, value),
273        "!=" => !compare_equal(field_value, value),
274        _ => false,
275    }
276}
277
278fn compare_greater(field_value: &Value, target: &str) -> bool {
279    match field_value {
280        Value::Number(n) => {
281            if let Ok(target_num) = target.parse::<f64>() {
282                n.as_f64().unwrap_or(0.0) > target_num
283            } else {
284                false
285            }
286        }
287        _ => false,
288    }
289}
290
291fn compare_less(field_value: &Value, target: &str) -> bool {
292    match field_value {
293        Value::Number(n) => {
294            if let Ok(target_num) = target.parse::<f64>() {
295                n.as_f64().unwrap_or(0.0) < target_num
296            } else {
297                false
298            }
299        }
300        _ => false,
301    }
302}
303
304fn compare_equal(field_value: &Value, target: &str) -> bool {
305    match field_value {
306        Value::String(s) => {
307            // 文字列比較(引用符を除去)
308            // String comparison (remove quotation marks)
309            let target_clean = target.trim_matches('"');
310            s == target_clean
311        }
312        Value::Number(n) => {
313            if let Ok(target_num) = target.parse::<f64>() {
314                n.as_f64().unwrap_or(0.0) == target_num
315            } else {
316                false
317            }
318        }
319        Value::Bool(b) => match target {
320            "true" => *b,
321            "false" => !*b,
322            _ => false,
323        },
324        _ => false,
325    }
326}
327
328fn is_grouped_data(data: &[Value]) -> bool {
329    data.iter().all(|item| {
330        if let Value::Object(obj) = item {
331            obj.contains_key("group") && obj.contains_key("items")
332        } else {
333            false
334        }
335    })
336}
337
338fn apply_aggregation_to_groups(
339    data: Vec<Value>,
340    operation: &str,
341    field_name: &str,
342) -> Result<Vec<Value>, Error> {
343    let mut results = Vec::new();
344
345    for group_data in data {
346        if let Value::Object(group_obj) = group_data {
347            let group_name = group_obj.get("group").unwrap();
348            let items = group_obj.get("items").and_then(|v| v.as_array()).unwrap();
349
350            // 各グループのitemsに対して集約を実行
351            // Perform aggregation on items in each group
352            let aggregated_value = match operation {
353                "avg" => calculate_avg(items, field_name)?,
354                "sum" => calculate_sum(items, field_name)?,
355                "count" => Value::Number(serde_json::Number::from(items.len())),
356                "min" => calculate_min(items, field_name)?,
357                "max" => calculate_max(items, field_name)?,
358                _ => Value::Null,
359            };
360
361            // 結果オブジェクトを作成
362            // Create result object
363            let mut result_obj = serde_json::Map::new();
364            result_obj.insert("group".to_string(), group_name.clone());
365            result_obj.insert(operation.to_string(), aggregated_value);
366            results.push(Value::Object(result_obj));
367        }
368    }
369
370    Ok(results)
371}
372
373fn calculate_avg(items: &[Value], field_name: &str) -> Result<Value, Error> {
374    let values: Vec<f64> = items
375        .iter()
376        .filter_map(|item| item.get(field_name))
377        .filter_map(|val| val.as_f64())
378        .collect();
379
380    if values.is_empty() {
381        Ok(Value::Null)
382    } else {
383        let avg = values.iter().sum::<f64>() / values.len() as f64;
384        let rounded_avg = (avg * 10.0).round() / 10.0;
385        Ok(Value::Number(
386            serde_json::Number::from_f64(rounded_avg).unwrap(),
387        ))
388    }
389}
390
391fn calculate_sum(items: &[Value], field_name: &str) -> Result<Value, Error> {
392    let sum: f64 = items
393        .iter()
394        .filter_map(|item| item.get(field_name))
395        .filter_map(|val| val.as_f64())
396        .sum();
397
398    let rounded_sum = if sum.fract() == 0.0 {
399        sum
400    } else {
401        (sum * 10.0).round() / 10.0
402    };
403
404    Ok(Value::Number(
405        serde_json::Number::from_f64(rounded_sum).unwrap(),
406    ))
407}
408
409fn calculate_min(items: &[Value], field_name: &str) -> Result<Value, Error> {
410    let min_val = items
411        .iter()
412        .filter_map(|item| item.get(field_name))
413        .filter_map(|val| val.as_f64())
414        .fold(f64::INFINITY, f64::min);
415
416    if min_val == f64::INFINITY {
417        Ok(Value::Null)
418    } else {
419        Ok(Value::Number(
420            serde_json::Number::from_f64(min_val).unwrap(),
421        ))
422    }
423}
424
425fn calculate_max(items: &[Value], field_name: &str) -> Result<Value, Error> {
426    let max_val = items
427        .iter()
428        .filter_map(|item| item.get(field_name))
429        .filter_map(|val| val.as_f64())
430        .fold(f64::NEG_INFINITY, f64::max);
431
432    if max_val == f64::NEG_INFINITY {
433        Ok(Value::Null)
434    } else {
435        Ok(Value::Number(
436            serde_json::Number::from_f64(max_val).unwrap(),
437        ))
438    }
439}