//! activecube_rs/schema/generator.rs
//!
//! Dynamic GraphQL schema generation: builds per-cube record/filter/orderBy
//! types and top-level Query fields from registered cube definitions.

1use std::collections::HashSet;
2use std::sync::Arc;
3use async_graphql::dynamic::*;
4use async_graphql::Value;
5
6use crate::compiler;
7use crate::compiler::ir::SqlValue;
8use crate::cube::definition::{CubeDefinition, DimType, DimensionNode};
9use crate::cube::registry::CubeRegistry;
10use crate::response::RowMap;
11use crate::schema::filter_types;
12use crate::sql::dialect::SqlDialect;
13use crate::stats::{QueryStats, StatsCallback};
14
/// Async function type that executes a compiled SQL query and returns rows.
/// The service layer provides this — the library never touches a database directly.
///
/// Arguments are the compiled SQL text and its positional bind values
/// (`SqlValue`s produced by the dialect compiler). On success the executor
/// yields the result set as `RowMap`s; on failure, a driver error string.
/// `Send + Sync` + `Arc` so the same executor can be cloned into every
/// per-cube resolver closure.
pub type QueryExecutor = Arc<
    dyn Fn(String, Vec<SqlValue>) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Vec<RowMap>, String>> + Send>,
    > + Send + Sync,
>;
22
/// Configuration for supported networks (chains) and optional stats collection.
pub struct SchemaConfig {
    // Exposed as the variants of the `Network` GraphQL enum.
    pub networks: Vec<String>,
    // NOTE(review): `build_schema` currently hard-codes "Query" as the root
    // type name and never reads this field — confirm whether it should be honored.
    pub root_query_name: String,
    /// Optional callback invoked after each cube query with execution metadata.
    /// Used by application layer for billing, observability, etc.
    pub stats_callback: Option<StatsCallback>,
}
31
32impl Default for SchemaConfig {
33    fn default() -> Self {
34        Self {
35            networks: vec!["sol", "eth", "bsc"]
36                .into_iter().map(String::from).collect(),
37            root_query_name: "ChainStream".to_string(),
38            stats_callback: None,
39        }
40    }
41}
42
/// Build a complete async-graphql dynamic schema from registry + dialect + executor.
///
/// Registers the shared types (Network enum, limit/order inputs, scalar filter
/// primitives), then one top-level Query field per cube whose resolver runs the
/// full pipeline (selection → IR → validation → SQL → executor → rows), plus an
/// internal `_cubeMetadata` introspection field. The registry is also stored in
/// schema data so resolvers can look cubes up at request time.
///
/// NOTE(review): `config.root_query_name` is not read here — the root type is
/// hard-coded as "Query". Confirm whether that field should be honored.
pub fn build_schema(
    registry: CubeRegistry,
    dialect: Arc<dyn SqlDialect>,
    executor: QueryExecutor,
    config: SchemaConfig,
) -> Result<Schema, SchemaError> {
    let mut builder = Schema::build("Query", None, None);

    // Network enum: one variant per configured chain (e.g. sol, eth, bsc).
    let mut network_enum = Enum::new("Network")
        .description("Blockchain network to query");
    for net in &config.networks {
        network_enum = network_enum.item(EnumItem::new(net));
    }
    builder = builder.register(network_enum);
    builder = builder.register(filter_types::build_limit_input());

    builder = builder.register(
        InputObject::new("LimitByInput")
            .description("Limit results per group (similar to ClickHouse LIMIT BY)")
            .field(InputValue::new("by", TypeRef::named_nn(TypeRef::STRING))
                .description("Comma-separated dimension names to group by"))
            .field(InputValue::new("count", TypeRef::named_nn(TypeRef::INT))
                .description("Maximum rows per group"))
            .field(InputValue::new("offset", TypeRef::named(TypeRef::INT))
                .description("Rows to skip per group")),
    );

    builder = builder.register(
        Enum::new("OrderDirection")
            .description("Sort direction")
            .item(EnumItem::new("ASC").description("Ascending"))
            .item(EnumItem::new("DESC").description("Descending")),
    );

    // Shared scalar filter inputs (StringFilter, IntFilter, ...).
    for input in filter_types::build_filter_primitives() {
        builder = builder.register(input);
    }

    // Cubes are top-level Query fields, each with a required `network` argument.
    // Query pattern: `query { DEXTrades(network: sol, limit: ...) { ... } }`
    let mut query = Object::new("Query");

    for cube in registry.cubes() {
        // Register this cube's generated record/filter/enum types.
        let types = build_cube_types(cube);
        for obj in types.objects { builder = builder.register(obj); }
        for inp in types.inputs { builder = builder.register(inp); }
        for en in types.enums { builder = builder.register(en); }

        // Clones captured by the resolver closure below (one closure per cube).
        let cube_name = cube.name.clone();
        let dialect_clone = dialect.clone();
        let executor_clone = executor.clone();
        let stats_cb = config.stats_callback.clone();

        let orderby_list_input_name = format!("{}OrderByInput", cube.name);

        let cube_description = cube.description.clone();
        let mut field = Field::new(
            &cube.name,
            TypeRef::named_nn_list_nn(format!("{}Record", cube.name)),
            move |ctx| {
                // Per-request clones: the outer closure is a Fn that may be
                // invoked for many concurrent requests.
                let cube_name = cube_name.clone();
                let dialect = dialect_clone.clone();
                let executor = executor_clone.clone();
                let stats_cb = stats_cb.clone();
                FieldFuture::new(async move {
                    let registry = ctx.ctx.data::<CubeRegistry>()?;
                    let network_val = ctx.args.try_get("network")?;
                    let network = network_val.enum_name()
                        .map_err(|_| async_graphql::Error::new("network must be a Network enum value"))?;

                    let cube_def = registry.get(&cube_name).ok_or_else(|| {
                        async_graphql::Error::new(format!("Unknown cube: {cube_name}"))
                    })?;

                    // Pipeline: GraphQL selection -> IR -> validated IR -> SQL.
                    let metric_requests = extract_metric_requests(&ctx, cube_def);
                    let requested = extract_requested_fields(&ctx, cube_def);
                    let ir = compiler::parser::parse_cube_query(
                        cube_def,
                        network,
                        &ctx.args,
                        &metric_requests,
                        Some(requested),
                    )?;
                    let validated = compiler::validator::validate(ir)?;
                    let result = dialect.compile(&validated);
                    let sql = result.sql;
                    let bindings = result.bindings;

                    let rows = executor(sql.clone(), bindings).await.map_err(|e| {
                        async_graphql::Error::new(format!("Query execution failed: {e}"))
                    })?;

                    // Remap aliased columns back to original names for resolvers
                    let rows = if result.alias_remap.is_empty() {
                        rows
                    } else {
                        rows.into_iter().map(|mut row| {
                            for (alias, original) in &result.alias_remap {
                                if let Some(val) = row.shift_remove(alias) {
                                    // or_insert: never clobber a column that
                                    // already exists under the original name.
                                    row.entry(original.clone()).or_insert(val);
                                }
                            }
                            row
                        }).collect()
                    };

                    // Stats: a request-scoped callback in context data takes
                    // precedence over the schema-wide one from SchemaConfig.
                    let effective_cb = ctx.ctx.data::<StatsCallback>().ok().cloned()
                        .or_else(|| stats_cb.clone());
                    if let Some(cb) = effective_cb {
                        let stats = QueryStats::from_ir(&validated, rows.len(), &sql);
                        cb(stats);
                    }

                    // Rows are handed to child field resolvers as downcastable
                    // owned values (see try_downcast_ref::<RowMap> in the
                    // per-dimension/metric resolvers).
                    let values: Vec<FieldValue> = rows.into_iter().map(FieldValue::owned_any).collect();
                    Ok(Some(FieldValue::list(values)))
                })
            },
        );
        if !cube_description.is_empty() {
            field = field.description(&cube_description);
        }
        field = field
            .argument(InputValue::new("network", TypeRef::named_nn("Network"))
                .description("Blockchain network to query"))
            .argument(InputValue::new("where", TypeRef::named(format!("{}Filter", cube.name)))
                .description("Filter conditions"))
            .argument(InputValue::new("limit", TypeRef::named("LimitInput"))
                .description("Pagination control"))
            .argument(InputValue::new("limitBy", TypeRef::named("LimitByInput"))
                .description("Per-group row limit"))
            .argument(InputValue::new("orderBy", TypeRef::named(format!("{}OrderBy", cube.name)))
                .description("Sort order (single column)"))
            .argument(InputValue::new("orderByList", TypeRef::named_list(&orderby_list_input_name))
                .description("Sort order (multiple columns)"));

        // Selector shorthands: extra top-level filter args for designated dimensions.
        for sel in &cube.selectors {
            let filter_type = dim_type_to_filter_name(&sel.dim_type);
            field = field.argument(InputValue::new(&sel.graphql_name, TypeRef::named(filter_type))
                .description(format!("Shorthand filter for {}", sel.graphql_name)));
        }

        query = query.field(field);
    }

    // `_cubeMetadata`: JSON dump of every cube definition, for tooling/clients.
    let metadata_registry = Arc::new(registry.clone());
    let metadata_field = Field::new(
        "_cubeMetadata",
        TypeRef::named_nn(TypeRef::STRING),
        move |_ctx| {
            let reg = metadata_registry.clone();
            FieldFuture::new(async move {
                let metadata: Vec<serde_json::Value> = reg.cubes().map(|cube| {
                    serde_json::json!({
                        "name": cube.name,
                        "description": cube.description,
                        "schema": cube.schema,
                        "tablePattern": cube.table_pattern,
                        "metrics": cube.metrics,
                        "selectors": cube.selectors.iter().map(|s| {
                            serde_json::json!({
                                "name": s.graphql_name,
                                "column": s.column,
                                "type": format!("{:?}", s.dim_type),
                            })
                        }).collect::<Vec<_>>(),
                        "dimensions": serialize_dims(&cube.dimensions),
                        "defaultLimit": cube.default_limit,
                        "maxLimit": cube.max_limit,
                    })
                }).collect();
                let json = serde_json::to_string(&metadata).unwrap_or_default();
                Ok(Some(FieldValue::value(Value::from(json))))
            })
        },
    )
    .description("Internal: returns JSON metadata about all cubes");
    query = query.field(metadata_field);

    builder = builder.register(query);
    // Stash the registry in schema data for request-time cube lookup.
    builder = builder.data(registry);

    builder.finish()
}
227
228fn serialize_dims(dims: &[DimensionNode]) -> serde_json::Value {
229    serde_json::Value::Array(dims.iter().map(|d| match d {
230        DimensionNode::Leaf(dim) => {
231            let mut obj = serde_json::json!({
232                "name": dim.graphql_name,
233                "column": dim.column,
234                "type": format!("{:?}", dim.dim_type),
235            });
236            if let Some(desc) = &dim.description {
237                obj["description"] = serde_json::Value::String(desc.clone());
238            }
239            obj
240        },
241        DimensionNode::Group { graphql_name, description, children } => {
242            let mut obj = serde_json::json!({
243                "name": graphql_name,
244                "children": serialize_dims(children),
245            });
246            if let Some(desc) = description {
247                obj["description"] = serde_json::Value::String(desc.clone());
248            }
249            obj
250        },
251    }).collect())
252}
253
254/// Extract metric requests from the GraphQL selection set by inspecting
255/// child fields. If a user selects `count(of: "Trade_Buy_Amount")`, we find
256/// the "count" field in the selection set and extract its `of` argument.
257fn extract_metric_requests(
258    ctx: &async_graphql::dynamic::ResolverContext,
259    cube: &CubeDefinition,
260) -> Vec<compiler::parser::MetricRequest> {
261    let mut requests = Vec::new();
262
263    for sub_field in ctx.ctx.field().selection_set() {
264        let name = sub_field.name();
265        if !cube.metrics.contains(&name.to_string()) {
266            continue;
267        }
268
269        let args = match sub_field.arguments() {
270            Ok(args) => args,
271            Err(_) => continue,
272        };
273
274        let of_dimension = args
275            .iter()
276            .find(|(k, _)| k.as_str() == "of")
277            .and_then(|(_, v)| match v {
278                async_graphql::Value::Enum(e) => Some(e.to_string()),
279                async_graphql::Value::String(s) => Some(s.clone()),
280                _ => None,
281            })
282            .unwrap_or_else(|| "*".to_string());
283
284        let select_where_value = args
285            .iter()
286            .find(|(k, _)| k.as_str() == "selectWhere")
287            .map(|(_, v)| v.clone());
288
289        let condition_filter = args
290            .iter()
291            .find(|(k, _)| k.as_str() == "if")
292            .and_then(|(_, v)| {
293                compiler::filter::parse_filter_from_value(v, &cube.dimensions).ok()
294                    .and_then(|f| if f.is_empty() { None } else { Some(f) })
295            });
296
297        requests.push(compiler::parser::MetricRequest {
298            function: name.to_string(),
299            of_dimension,
300            select_where_value,
301            condition_filter,
302        });
303    }
304
305    requests
306}
307
308fn extract_requested_fields(
309    ctx: &async_graphql::dynamic::ResolverContext,
310    cube: &CubeDefinition,
311) -> HashSet<String> {
312    let mut fields = HashSet::new();
313    collect_selection_paths(&ctx.ctx.field(), "", &mut fields, &cube.metrics);
314    fields
315}
316
317fn collect_selection_paths(
318    field: &async_graphql::SelectionField<'_>,
319    prefix: &str,
320    out: &mut HashSet<String>,
321    metrics: &[String],
322) {
323    for sub in field.selection_set() {
324        let name = sub.name();
325        if metrics.iter().any(|m| m == name) {
326            continue;
327        }
328        let path = if prefix.is_empty() {
329            name.to_string()
330        } else {
331            format!("{prefix}_{name}")
332        };
333        let has_children = sub.selection_set().next().is_some();
334        if has_children {
335            collect_selection_paths(&sub, &path, out, metrics);
336        } else {
337            out.insert(path);
338        }
339    }
340}
341
342// ---------------------------------------------------------------------------
343// Per-Cube GraphQL type generation
344// ---------------------------------------------------------------------------
345
/// All GraphQL types generated for a single cube, ready to be registered
/// on the schema builder by `build_schema`.
struct CubeTypes {
    // Record object plus any nested group record objects.
    objects: Vec<Object>,
    // Filter input, multi-column orderBy input, per-metric selectWhere
    // inputs, and nested group filter inputs.
    inputs: Vec<InputObject>,
    // Single-column orderBy enum, orderBy field enum, per-metric `of` enums.
    enums: Vec<Enum>,
}
351
352fn build_cube_types(cube: &CubeDefinition) -> CubeTypes {
353    let record_name = format!("{}Record", cube.name);
354    let filter_name = format!("{}Filter", cube.name);
355    let orderby_name = format!("{}OrderBy", cube.name);
356
357    let mut record_fields: Vec<Field> = Vec::new();
358    let mut filter_fields: Vec<InputValue> = Vec::new();
359    let mut orderby_items: Vec<String> = Vec::new();
360    let mut extra_objects: Vec<Object> = Vec::new();
361    let mut extra_inputs: Vec<InputObject> = Vec::new();
362
363    filter_fields.push(InputValue::new("any", TypeRef::named_list(&filter_name))
364        .description("OR combinator — matches if any sub-filter matches"));
365
366    {
367        let mut collector = DimCollector {
368            cube_name: &cube.name,
369            record_fields: &mut record_fields,
370            filter_fields: &mut filter_fields,
371            orderby_items: &mut orderby_items,
372            extra_objects: &mut extra_objects,
373            extra_inputs: &mut extra_inputs,
374        };
375        for node in &cube.dimensions {
376            collect_dimension_types(node, "", &mut collector);
377        }
378    }
379
380    let flat_dims = cube.flat_dimensions();
381    let mut metric_enums: Vec<Enum> = Vec::new();
382    let metric_descriptions: std::collections::HashMap<&str, &str> = [
383        ("count", "Count of rows or distinct values"),
384        ("sum", "Sum of values"),
385        ("avg", "Average of values"),
386        ("min", "Minimum value"),
387        ("max", "Maximum value"),
388        ("uniq", "Count of unique (distinct) values"),
389    ].into_iter().collect();
390
391    for metric in &cube.metrics {
392        let select_where_name = format!("{}_{}_SelectWhere", cube.name, metric);
393        extra_inputs.push(
394            InputObject::new(&select_where_name)
395                .description(format!("Post-aggregation filter for {} (HAVING clause)", metric))
396                .field(InputValue::new("gt", TypeRef::named(TypeRef::STRING)).description("Greater than"))
397                .field(InputValue::new("ge", TypeRef::named(TypeRef::STRING)).description("Greater than or equal to"))
398                .field(InputValue::new("lt", TypeRef::named(TypeRef::STRING)).description("Less than"))
399                .field(InputValue::new("le", TypeRef::named(TypeRef::STRING)).description("Less than or equal to"))
400                .field(InputValue::new("eq", TypeRef::named(TypeRef::STRING)).description("Equal to")),
401        );
402
403        let of_enum_name = format!("{}_{}_Of", cube.name, metric);
404        let mut of_enum = Enum::new(&of_enum_name)
405            .description(format!("Dimension to apply {} aggregation on", metric));
406        for (path, _) in &flat_dims { of_enum = of_enum.item(EnumItem::new(path)); }
407        metric_enums.push(of_enum);
408
409        let metric_clone = metric.clone();
410        let metric_desc = metric_descriptions.get(metric.as_str())
411            .copied()
412            .unwrap_or("Aggregate metric");
413        let metric_field = Field::new(metric, TypeRef::named(TypeRef::FLOAT), move |ctx| {
414            let metric_key = metric_clone.clone();
415            FieldFuture::new(async move {
416                let row = ctx.parent_value.try_downcast_ref::<RowMap>()?;
417                let key = format!("__{metric_key}");
418                let val = row.get(&key).cloned().unwrap_or(serde_json::Value::Null);
419                Ok(Some(FieldValue::value(json_to_gql_value(val))))
420            })
421        })
422        .description(metric_desc)
423        .argument(InputValue::new("of", TypeRef::named(&of_enum_name))
424            .description("Dimension to aggregate on (default: all rows)"))
425        .argument(InputValue::new("selectWhere", TypeRef::named(&select_where_name))
426            .description("Post-aggregation filter (HAVING)"))
427        .argument(InputValue::new("if", TypeRef::named(&filter_name))
428            .description("Conditional filter for this metric"));
429
430        record_fields.push(metric_field);
431    }
432
433    let mut record = Object::new(&record_name);
434    for f in record_fields { record = record.field(f); }
435
436    let mut filter = InputObject::new(&filter_name)
437        .description(format!("Filter conditions for {} query", cube.name));
438    for f in filter_fields { filter = filter.field(f); }
439
440    let mut orderby = Enum::new(&orderby_name)
441        .description(format!("Sort order for {} results (single column)", cube.name));
442    for item in &orderby_items { orderby = orderby.item(EnumItem::new(item)); }
443
444    // Multi-column orderBy: {Cube}OrderBy_Field enum + {Cube}OrderByInput
445    let field_enum_name = format!("{}_Field", orderby_name);
446    let orderby_input_name = format!("{}OrderByInput", cube.name);
447    let mut field_enum = Enum::new(&field_enum_name)
448        .description(format!("Available fields for {} multi-column sort", cube.name));
449    let flat_dims = cube.flat_dimensions();
450    for (path, _) in &flat_dims {
451        field_enum = field_enum.item(EnumItem::new(path));
452    }
453    let orderby_input = InputObject::new(&orderby_input_name)
454        .description(format!("Multi-column sort input for {}", cube.name))
455        .field(InputValue::new("field", TypeRef::named_nn(&field_enum_name))
456            .description("Field to sort by"))
457        .field(InputValue::new("direction", TypeRef::named("OrderDirection"))
458            .description("Sort direction (ASC or DESC)"));
459
460    let mut objects = vec![record]; objects.extend(extra_objects);
461    let mut inputs = vec![filter, orderby_input]; inputs.extend(extra_inputs);
462    let mut enums = vec![orderby, field_enum]; enums.extend(metric_enums);
463
464    CubeTypes { objects, inputs, enums }
465}
466
/// Mutable sinks threaded through the recursive dimension walk in
/// `collect_dimension_types`. Group nodes create child collectors with fresh
/// record/filter sinks but reuse the cube-global sinks.
struct DimCollector<'a> {
    cube_name: &'a str,
    // Fields of the record object currently being assembled (per level).
    record_fields: &'a mut Vec<Field>,
    // Fields of the filter input currently being assembled (per level).
    filter_fields: &'a mut Vec<InputValue>,
    // Flat `{path}_ASC` / `{path}_DESC` items for the orderBy enum (cube-global).
    orderby_items: &'a mut Vec<String>,
    // Nested group record objects (cube-global).
    extra_objects: &'a mut Vec<Object>,
    // Nested group filter inputs (cube-global).
    extra_inputs: &'a mut Vec<InputObject>,
}
475
476fn collect_dimension_types(node: &DimensionNode, prefix: &str, c: &mut DimCollector<'_>) {
477    match node {
478        DimensionNode::Leaf(dim) => {
479            let col = dim.column.clone();
480            let is_datetime = dim.dim_type == DimType::DateTime;
481            let mut leaf_field = Field::new(
482                &dim.graphql_name, dim_type_to_typeref(&dim.dim_type),
483                move |ctx| {
484                    let col = col.clone();
485                    FieldFuture::new(async move {
486                        let row = ctx.parent_value.try_downcast_ref::<RowMap>()?;
487                        let val = row.get(&col).cloned().unwrap_or(serde_json::Value::Null);
488                        let gql_val = if is_datetime {
489                            json_to_gql_datetime(val)
490                        } else {
491                            json_to_gql_value(val)
492                        };
493                        Ok(Some(FieldValue::value(gql_val)))
494                    })
495                },
496            );
497            if let Some(desc) = &dim.description {
498                leaf_field = leaf_field.description(desc);
499            }
500            c.record_fields.push(leaf_field);
501            c.filter_fields.push(InputValue::new(&dim.graphql_name, TypeRef::named(dim_type_to_filter_name(&dim.dim_type))));
502
503            let path = if prefix.is_empty() { dim.graphql_name.clone() } else { format!("{}_{}", prefix, dim.graphql_name) };
504            c.orderby_items.push(format!("{path}_ASC"));
505            c.orderby_items.push(format!("{path}_DESC"));
506        }
507        DimensionNode::Group { graphql_name, description, children } => {
508            let full_path = if prefix.is_empty() { graphql_name.clone() } else { format!("{prefix}_{graphql_name}") };
509            let nested_record_name = format!("{}_{full_path}_Record", c.cube_name);
510            let nested_filter_name = format!("{}_{full_path}_Filter", c.cube_name);
511
512            let mut child_record_fields: Vec<Field> = Vec::new();
513            let mut child_filter_fields: Vec<InputValue> = Vec::new();
514            let new_prefix = if prefix.is_empty() { graphql_name.clone() } else { format!("{prefix}_{graphql_name}") };
515
516            let mut child_collector = DimCollector {
517                cube_name: c.cube_name,
518                record_fields: &mut child_record_fields,
519                filter_fields: &mut child_filter_fields,
520                orderby_items: c.orderby_items,
521                extra_objects: c.extra_objects,
522                extra_inputs: c.extra_inputs,
523            };
524            for child in children {
525                collect_dimension_types(child, &new_prefix, &mut child_collector);
526            }
527
528            let mut nested_record = Object::new(&nested_record_name);
529            for f in child_record_fields { nested_record = nested_record.field(f); }
530
531            let nested_filter_desc = format!("Filter conditions for {}", graphql_name);
532            let mut nested_filter = InputObject::new(&nested_filter_name)
533                .description(nested_filter_desc);
534            for f in child_filter_fields { nested_filter = nested_filter.field(f); }
535
536            let mut group_field = Field::new(graphql_name, TypeRef::named_nn(&nested_record_name), |ctx| {
537                FieldFuture::new(async move {
538                    let row = ctx.parent_value.try_downcast_ref::<RowMap>()?;
539                    Ok(Some(FieldValue::owned_any(row.clone())))
540                })
541            });
542            if let Some(desc) = description {
543                group_field = group_field.description(desc);
544            }
545            c.record_fields.push(group_field);
546            c.filter_fields.push(InputValue::new(graphql_name, TypeRef::named(&nested_filter_name)));
547            c.extra_objects.push(nested_record);
548            c.extra_inputs.push(nested_filter);
549        }
550    }
551}
552
553fn dim_type_to_typeref(dt: &DimType) -> TypeRef {
554    match dt {
555        DimType::String | DimType::DateTime => TypeRef::named(TypeRef::STRING),
556        DimType::Int => TypeRef::named(TypeRef::INT),
557        DimType::Float => TypeRef::named(TypeRef::FLOAT),
558        DimType::Bool => TypeRef::named(TypeRef::BOOLEAN),
559    }
560}
561
562fn dim_type_to_filter_name(dt: &DimType) -> &'static str {
563    match dt {
564        DimType::String => "StringFilter",
565        DimType::Int => "IntFilter",
566        DimType::Float => "FloatFilter",
567        DimType::DateTime => "DateTimeFilter",
568        DimType::Bool => "BoolFilter",
569    }
570}
571
572pub fn json_to_gql_value(v: serde_json::Value) -> Value {
573    match v {
574        serde_json::Value::Null => Value::Null,
575        serde_json::Value::Bool(b) => Value::from(b),
576        serde_json::Value::Number(n) => {
577            if let Some(i) = n.as_i64() { Value::from(i) }
578            else if let Some(f) = n.as_f64() { Value::from(f) }
579            else { Value::from(n.to_string()) }
580        }
581        serde_json::Value::String(s) => Value::from(s),
582        _ => Value::from(v.to_string()),
583    }
584}
585
586/// Convert a ClickHouse DateTime value to ISO 8601 format.
587/// `"2026-03-27 19:06:41.000"` -> `"2026-03-27T19:06:41.000Z"`
588fn json_to_gql_datetime(v: serde_json::Value) -> Value {
589    match v {
590        serde_json::Value::String(s) => {
591            let iso = if s.contains('T') {
592                if s.ends_with('Z') || s.contains('+') { s } else { format!("{s}Z") }
593            } else {
594                let replaced = s.replacen(' ', "T", 1);
595                if replaced.ends_with('Z') { replaced } else { format!("{replaced}Z") }
596            };
597            Value::from(iso)
598        }
599        other => json_to_gql_value(other),
600    }
601}