Skip to main content

shape_runtime/
chart_detect.rs

1//! Chart auto-detection from Arrow IPC table data.
2//!
3//! Inspects Arrow schemas to determine appropriate chart types and generates
4//! ECharts option JSON with embedded data. Also provides a channel-based
5//! `ChartSpec` output for unified rendering.
6
7use arrow_ipc::reader::StreamReader;
8use arrow_schema::{DataType, Schema};
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11use std::io::Cursor;
12use std::sync::Arc;
13
/// Column metadata extracted from Arrow IPC data.
///
/// One value is produced per schema field by `extract_columns`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnInfo {
    /// Field name, copied from the Arrow schema.
    pub name: String,
    /// Display label for the field's Arrow type as produced by `format_arrow_type`
    /// (for example `"Number"`, `"Integer"`, `"String"`, `"Bool"` or `"Timestamp"`).
    pub data_type: String,
}
20
/// Chart kind selected by inspecting an Arrow schema.
///
/// The enum is fieldless, so it derives `Copy` and `Eq` in addition to the
/// comparison/debug traits; values can be passed around and compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChartType {
    /// The schema has `open`, `high`, `low` and `close` columns.
    Candlestick,
    /// A timestamp/date column together with at least one numeric column.
    Line,
    /// A string (categorical) column together with at least one numeric column.
    Bar,
    /// At least two numeric columns and no better match.
    Scatter,
    /// No chartable pattern was found; the data should be shown as a table only.
    TableOnly,
}
30
31/// Extract column info from Arrow IPC bytes
32pub fn extract_columns(ipc_bytes: &[u8]) -> Vec<ColumnInfo> {
33    let schema = match read_schema(ipc_bytes) {
34        Some(s) => s,
35        None => return vec![],
36    };
37
38    schema
39        .fields()
40        .iter()
41        .map(|f| ColumnInfo {
42            name: f.name().clone(),
43            data_type: format_arrow_type(f.data_type()),
44        })
45        .collect()
46}
47
48/// Auto-detect chart type and generate ECharts option JSON from Arrow IPC bytes
49pub fn detect_chart(ipc_bytes: &[u8]) -> Option<Value> {
50    if ipc_bytes.is_empty() {
51        return None;
52    }
53
54    let (schema, data) = read_schema_and_data(ipc_bytes)?;
55    let chart_type = detect_chart_type(&schema);
56
57    if chart_type == ChartType::TableOnly {
58        return None;
59    }
60
61    Some(build_echart_option(&chart_type, &schema, &data))
62}
63
64/// Read just the Arrow schema from IPC bytes
65fn read_schema(ipc_bytes: &[u8]) -> Option<Arc<Schema>> {
66    let cursor = Cursor::new(ipc_bytes);
67    let reader = StreamReader::try_new(cursor, None).ok()?;
68    Some(reader.schema().clone())
69}
70
71/// Read schema and all data from Arrow IPC bytes
72fn read_schema_and_data(ipc_bytes: &[u8]) -> Option<(Arc<Schema>, Vec<Vec<Value>>)> {
73    let cursor = Cursor::new(ipc_bytes);
74    let reader = StreamReader::try_new(cursor, None).ok()?;
75    let schema = reader.schema().clone();
76    let num_cols = schema.fields().len();
77
78    // Collect all data as JSON arrays per column
79    let mut columns: Vec<Vec<Value>> = vec![vec![]; num_cols];
80
81    for batch_result in reader {
82        let batch = batch_result.ok()?;
83        for col_idx in 0..num_cols {
84            let array = batch.column(col_idx);
85            for row_idx in 0..batch.num_rows() {
86                let val = arrow_value_to_json(array, row_idx);
87                columns[col_idx].push(val);
88            }
89        }
90    }
91
92    Some((schema, columns))
93}
94
95/// Detect chart type from Arrow schema
96fn detect_chart_type(schema: &Schema) -> ChartType {
97    let field_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
98
99    // Check for OHLC candlestick pattern
100    let has_ohlc = ["open", "high", "low", "close"]
101        .iter()
102        .all(|name| field_names.iter().any(|f| f.eq_ignore_ascii_case(name)));
103
104    if has_ohlc {
105        return ChartType::Candlestick;
106    }
107
108    // Classify columns
109    let mut has_timestamp = false;
110    let mut numeric_count = 0;
111    let mut string_count = 0;
112
113    for field in schema.fields() {
114        match field.data_type() {
115            DataType::Timestamp(_, _) | DataType::Date32 | DataType::Date64 => {
116                has_timestamp = true;
117            }
118            DataType::Float16
119            | DataType::Float32
120            | DataType::Float64
121            | DataType::Int8
122            | DataType::Int16
123            | DataType::Int32
124            | DataType::Int64
125            | DataType::UInt8
126            | DataType::UInt16
127            | DataType::UInt32
128            | DataType::UInt64 => {
129                numeric_count += 1;
130            }
131            DataType::Utf8 | DataType::LargeUtf8 => {
132                string_count += 1;
133            }
134            _ => {}
135        }
136    }
137
138    // Timestamp + numeric → line chart
139    if has_timestamp && numeric_count >= 1 {
140        return ChartType::Line;
141    }
142
143    // Categorical (string) + numeric → bar chart
144    if string_count >= 1 && numeric_count >= 1 {
145        return ChartType::Bar;
146    }
147
148    // Two+ numeric columns → scatter
149    if numeric_count >= 2 {
150        return ChartType::Scatter;
151    }
152
153    ChartType::TableOnly
154}
155
156/// Build an ECharts option JSON from chart type and data
157fn build_echart_option(chart_type: &ChartType, schema: &Schema, columns: &[Vec<Value>]) -> Value {
158    match chart_type {
159        ChartType::Candlestick => build_candlestick(schema, columns),
160        ChartType::Line => build_line(schema, columns),
161        ChartType::Bar => build_bar(schema, columns),
162        ChartType::Scatter => build_scatter(schema, columns),
163        ChartType::TableOnly => json!(null),
164    }
165}
166
167fn build_candlestick(schema: &Schema, columns: &[Vec<Value>]) -> Value {
168    let find_col = |name: &str| -> Option<usize> {
169        schema
170            .fields()
171            .iter()
172            .position(|f| f.name().eq_ignore_ascii_case(name))
173    };
174
175    let open_idx = find_col("open").unwrap_or(0);
176    let close_idx = find_col("close").unwrap_or(1);
177    let low_idx = find_col("low").unwrap_or(2);
178    let high_idx = find_col("high").unwrap_or(3);
179
180    // Look for a timestamp/date column for x-axis
181    let x_idx = schema
182        .fields()
183        .iter()
184        .position(|f| {
185            matches!(
186                f.data_type(),
187                DataType::Timestamp(_, _) | DataType::Date32 | DataType::Date64
188            )
189        })
190        .or_else(|| find_col("timestamp"))
191        .or_else(|| find_col("date"));
192
193    let row_count = columns.first().map(|c| c.len()).unwrap_or(0);
194
195    let x_data: Vec<Value> = if let Some(xi) = x_idx {
196        columns[xi].clone()
197    } else {
198        (0..row_count).map(|i| json!(i)).collect()
199    };
200
201    // ECharts candlestick format: [open, close, low, high]
202    let ohlc_data: Vec<Value> = (0..row_count)
203        .map(|i| {
204            json!([
205                columns[open_idx].get(i).unwrap_or(&json!(0)),
206                columns[close_idx].get(i).unwrap_or(&json!(0)),
207                columns[low_idx].get(i).unwrap_or(&json!(0)),
208                columns[high_idx].get(i).unwrap_or(&json!(0)),
209            ])
210        })
211        .collect();
212
213    json!({
214        "xAxis": {
215            "type": "category",
216            "data": x_data,
217            "axisLine": { "lineStyle": { "color": "#8392A5" } }
218        },
219        "yAxis": {
220            "scale": true,
221            "splitArea": { "show": true }
222        },
223        "series": [{
224            "type": "candlestick",
225            "data": ohlc_data,
226            "itemStyle": {
227                "color": "#26a69a",
228                "color0": "#ef5350",
229                "borderColor": "#26a69a",
230                "borderColor0": "#ef5350"
231            }
232        }],
233        "tooltip": { "trigger": "axis", "axisPointer": { "type": "cross" } },
234        "dataZoom": [
235            { "type": "inside", "start": 0, "end": 100 },
236            { "type": "slider", "start": 0, "end": 100 }
237        ],
238        "grid": { "left": "10%", "right": "10%", "bottom": "15%" }
239    })
240}
241
242fn build_line(schema: &Schema, columns: &[Vec<Value>]) -> Value {
243    // Find timestamp column for x-axis
244    let x_idx = schema
245        .fields()
246        .iter()
247        .position(|f| {
248            matches!(
249                f.data_type(),
250                DataType::Timestamp(_, _) | DataType::Date32 | DataType::Date64
251            )
252        })
253        .unwrap_or(0);
254
255    let row_count = columns.first().map(|c| c.len()).unwrap_or(0);
256    let x_data: Vec<Value> = columns.get(x_idx).cloned().unwrap_or_default();
257
258    // All numeric columns become line series
259    let mut series = Vec::new();
260    for (i, field) in schema.fields().iter().enumerate() {
261        if i == x_idx {
262            continue;
263        }
264        if is_numeric_type(field.data_type()) {
265            let data: Vec<Value> = columns.get(i).cloned().unwrap_or_default();
266            series.push(json!({
267                "name": field.name(),
268                "type": "line",
269                "data": data,
270                "sampling": "lttb",
271                "smooth": false,
272                "symbol": if row_count > 100 { "none" } else { "circle" },
273            }));
274        }
275    }
276
277    json!({
278        "xAxis": {
279            "type": "category",
280            "data": x_data,
281            "axisLine": { "lineStyle": { "color": "#8392A5" } }
282        },
283        "yAxis": { "type": "value", "scale": true },
284        "series": series,
285        "tooltip": { "trigger": "axis" },
286        "legend": { "show": series.len() > 1 },
287        "dataZoom": [
288            { "type": "inside", "start": 0, "end": 100 },
289            { "type": "slider", "start": 0, "end": 100 }
290        ],
291        "grid": { "left": "10%", "right": "10%", "bottom": "15%" }
292    })
293}
294
295fn build_bar(schema: &Schema, columns: &[Vec<Value>]) -> Value {
296    // Find string column for categories
297    let cat_idx = schema
298        .fields()
299        .iter()
300        .position(|f| matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8))
301        .unwrap_or(0);
302
303    let categories: Vec<Value> = columns.get(cat_idx).cloned().unwrap_or_default();
304
305    let mut series = Vec::new();
306    for (i, field) in schema.fields().iter().enumerate() {
307        if i == cat_idx {
308            continue;
309        }
310        if is_numeric_type(field.data_type()) {
311            let data: Vec<Value> = columns.get(i).cloned().unwrap_or_default();
312            series.push(json!({
313                "name": field.name(),
314                "type": "bar",
315                "data": data,
316            }));
317        }
318    }
319
320    json!({
321        "xAxis": { "type": "category", "data": categories },
322        "yAxis": { "type": "value" },
323        "series": series,
324        "tooltip": { "trigger": "axis" },
325        "legend": { "show": series.len() > 1 },
326        "grid": { "left": "10%", "right": "10%", "bottom": "10%" }
327    })
328}
329
330fn build_scatter(schema: &Schema, columns: &[Vec<Value>]) -> Value {
331    // First two numeric columns become x and y
332    let numeric_indices: Vec<usize> = schema
333        .fields()
334        .iter()
335        .enumerate()
336        .filter(|(_, f)| is_numeric_type(f.data_type()))
337        .map(|(i, _)| i)
338        .collect();
339
340    let x_idx = numeric_indices.first().copied().unwrap_or(0);
341    let y_idx = numeric_indices.get(1).copied().unwrap_or(1);
342
343    let row_count = columns.first().map(|c| c.len()).unwrap_or(0);
344    let scatter_data: Vec<Value> = (0..row_count)
345        .map(|i| {
346            json!([
347                columns
348                    .get(x_idx)
349                    .and_then(|c| c.get(i))
350                    .unwrap_or(&json!(0)),
351                columns
352                    .get(y_idx)
353                    .and_then(|c| c.get(i))
354                    .unwrap_or(&json!(0)),
355            ])
356        })
357        .collect();
358
359    let x_name = schema
360        .fields()
361        .get(x_idx)
362        .map(|f| f.name().as_str())
363        .unwrap_or("x");
364    let y_name = schema
365        .fields()
366        .get(y_idx)
367        .map(|f| f.name().as_str())
368        .unwrap_or("y");
369
370    json!({
371        "xAxis": { "type": "value", "name": x_name, "scale": true },
372        "yAxis": { "type": "value", "name": y_name, "scale": true },
373        "series": [{
374            "type": "scatter",
375            "data": scatter_data,
376            "symbolSize": 5,
377        }],
378        "tooltip": { "trigger": "item" },
379        "grid": { "left": "10%", "right": "10%", "bottom": "10%" }
380    })
381}
382
383// ---------------------------------------------------------------------------
384// Helpers
385// ---------------------------------------------------------------------------
386
387fn is_numeric_type(dt: &DataType) -> bool {
388    matches!(
389        dt,
390        DataType::Float16
391            | DataType::Float32
392            | DataType::Float64
393            | DataType::Int8
394            | DataType::Int16
395            | DataType::Int32
396            | DataType::Int64
397            | DataType::UInt8
398            | DataType::UInt16
399            | DataType::UInt32
400            | DataType::UInt64
401    )
402}
403
404fn format_arrow_type(dt: &DataType) -> String {
405    match dt {
406        DataType::Float32 | DataType::Float64 | DataType::Float16 => "Number".to_string(),
407        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
408            "Integer".to_string()
409        }
410        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
411            "Integer".to_string()
412        }
413        DataType::Utf8 | DataType::LargeUtf8 => "String".to_string(),
414        DataType::Boolean => "Bool".to_string(),
415        DataType::Timestamp(_, _) | DataType::Date32 | DataType::Date64 => "Timestamp".to_string(),
416        other => format!("{:?}", other),
417    }
418}
419
420/// Extract a single value from an Arrow array at the given index as JSON
421fn arrow_value_to_json(array: &dyn arrow_array::Array, idx: usize) -> Value {
422    use arrow_array::*;
423
424    if array.is_null(idx) {
425        return Value::Null;
426    }
427
428    if let Some(a) = array.as_any().downcast_ref::<Float64Array>() {
429        return json!(a.value(idx));
430    }
431    if let Some(a) = array.as_any().downcast_ref::<Float32Array>() {
432        return json!(a.value(idx) as f64);
433    }
434    if let Some(a) = array.as_any().downcast_ref::<Int64Array>() {
435        return json!(a.value(idx));
436    }
437    if let Some(a) = array.as_any().downcast_ref::<Int32Array>() {
438        return json!(a.value(idx));
439    }
440    if let Some(a) = array.as_any().downcast_ref::<UInt64Array>() {
441        return json!(a.value(idx));
442    }
443    if let Some(a) = array.as_any().downcast_ref::<UInt32Array>() {
444        return json!(a.value(idx));
445    }
446    if let Some(a) = array.as_any().downcast_ref::<StringArray>() {
447        return json!(a.value(idx));
448    }
449    if let Some(a) = array.as_any().downcast_ref::<BooleanArray>() {
450        return json!(a.value(idx));
451    }
452    if let Some(a) = array.as_any().downcast_ref::<TimestampMillisecondArray>() {
453        return json!(a.value(idx));
454    }
455    if let Some(a) = array.as_any().downcast_ref::<TimestampMicrosecondArray>() {
456        return json!(a.value(idx) / 1000); // Convert to ms
457    }
458    if let Some(a) = array.as_any().downcast_ref::<TimestampNanosecondArray>() {
459        return json!(a.value(idx) / 1_000_000); // Convert to ms
460    }
461    if let Some(a) = array.as_any().downcast_ref::<Date32Array>() {
462        return json!(a.value(idx));
463    }
464    if let Some(a) = array.as_any().downcast_ref::<Date64Array>() {
465        return json!(a.value(idx));
466    }
467
468    // Fallback
469    json!(null)
470}
471
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{Field, TimeUnit};

    /// Build a schema of non-nullable fields from `(name, type)` pairs.
    fn schema_of(cols: Vec<(&str, DataType)>) -> Schema {
        let fields: Vec<Field> = cols
            .into_iter()
            .map(|(name, dt)| Field::new(name, dt, false))
            .collect();
        Schema::new(fields)
    }

    /// Millisecond-precision timestamp type without a time zone.
    fn ts_millis() -> DataType {
        DataType::Timestamp(TimeUnit::Millisecond, None)
    }

    #[test]
    fn test_detect_chart_type_ohlc() {
        let schema = schema_of(vec![
            ("timestamp", ts_millis()),
            ("open", DataType::Float64),
            ("high", DataType::Float64),
            ("low", DataType::Float64),
            ("close", DataType::Float64),
            ("volume", DataType::Float64),
        ]);
        assert_eq!(detect_chart_type(&schema), ChartType::Candlestick);
    }

    #[test]
    fn test_detect_chart_type_line() {
        let schema = schema_of(vec![("time", ts_millis()), ("value", DataType::Float64)]);
        assert_eq!(detect_chart_type(&schema), ChartType::Line);
    }

    #[test]
    fn test_detect_chart_type_bar() {
        let schema = schema_of(vec![
            ("category", DataType::Utf8),
            ("count", DataType::Int64),
        ]);
        assert_eq!(detect_chart_type(&schema), ChartType::Bar);
    }

    #[test]
    fn test_detect_chart_type_scatter() {
        let schema = schema_of(vec![("x", DataType::Float64), ("y", DataType::Float64)]);
        assert_eq!(detect_chart_type(&schema), ChartType::Scatter);
    }

    #[test]
    fn test_extract_columns_empty() {
        let cols = extract_columns(&[]);
        assert!(cols.is_empty());
    }

    #[test]
    fn test_detect_chart_empty() {
        assert!(detect_chart(&[]).is_none());
    }

    #[test]
    fn test_format_arrow_type() {
        let expectations = [
            (DataType::Float64, "Number"),
            (DataType::Int64, "Integer"),
            (DataType::Utf8, "String"),
            (DataType::Boolean, "Bool"),
        ];
        for (dt, label) in &expectations {
            assert_eq!(format_arrow_type(dt), *label);
        }
    }
}