Skip to main content

clickhouse_kit/
codegen.rs

1//! Codegen — emit typed bindings for a ClickHouse table in two directions:
2//!
3//! - **ClickHouse → Rust rows** (`ch_type_to_rust` / `rust_row_struct`): the
4//!   TS→Rust bridge. TypeScript owns a (static) table; this turns its live/spec
5//!   columns into a Rust `#[derive(Row)]` struct so the Rust services get faithful,
6//!   drift-checked rows instead of hand-writing them. Pair with `introspect` +
7//!   `check_drift`. Temporal types map to `String` (the works-everywhere default
8//!   over the HTTP/RowBinary boundary; refine to `time`/`chrono` behind features).
9//! - **`TableSpec` → TS + Zod** (`emit_row_interface` / `emit_select_schema` /
10//!   `emit_insert_schema` / `emit_ts_module`): a `createSelectSchema`/`createInsertSchema`
11//!   style emitter (parity with smooai-postgres-kit). Useful for handing a runtime
12//!   (dynamic) table's shape to a TypeScript client as a row interface + Zod schemas.
13
14use crate::safety::{ColumnTypeSpec, ScalarType};
15use crate::table::{ColumnSpec, TableSpec};
16
17// ── ClickHouse → Rust rows ─────────────────────────────────────────────────────
18
/// Peel one `name(...)` layer off a ClickHouse type string.
///
/// Returns the text between the leading `name(` and the final `)`. Yields `None`
/// unless `t` both starts with `name(` and ends with `)`.
fn strip_wrapper<'a>(t: &'a str, name: &str) -> Option<&'a str> {
    let after_name = t.strip_prefix(name)?;
    let after_open = after_name.strip_prefix('(')?;
    after_open.strip_suffix(')')
}
25
/// Break the inside of a `Map(...)` at its first comma that sits outside any
/// parentheses, returning the two trimmed halves.
///
/// Commas nested inside `(...)` are skipped. Returns `None` when there is no
/// top-level comma.
fn split_top_comma(inner: &str) -> Option<(&str, &str)> {
    let mut open_parens = 0usize;
    let comma_at = inner.char_indices().find_map(|(idx, ch)| {
        match ch {
            '(' => open_parens += 1,
            // Saturate so a stray `)` cannot underflow the counter.
            ')' => open_parens = open_parens.saturating_sub(1),
            ',' if open_parens == 0 => return Some(idx),
            _ => {}
        }
        None
    })?;
    let (left, right) = inner.split_at(comma_at);
    // `right` still starts with the comma itself (one byte); skip it.
    Some((left.trim(), right[1..].trim()))
}
39
40/// Map a ClickHouse type string to the Rust type a `clickhouse`-crate row uses.
41/// Wrappers recurse; unknown scalars fall back to `String` (safe over the wire).
42pub fn ch_type_to_rust(ch_type: &str) -> String {
43    let t = ch_type.trim();
44    if let Some(inner) = strip_wrapper(t, "Nullable") {
45        return format!("Option<{}>", ch_type_to_rust(inner));
46    }
47    if let Some(inner) = strip_wrapper(t, "LowCardinality") {
48        return ch_type_to_rust(inner);
49    }
50    if let Some(inner) = strip_wrapper(t, "Array") {
51        return format!("Vec<{}>", ch_type_to_rust(inner));
52    }
53    if let Some(inner) = strip_wrapper(t, "Map") {
54        if let Some((k, v)) = split_top_comma(inner) {
55            return format!(
56                "std::collections::HashMap<{}, {}>",
57                ch_type_to_rust(k),
58                ch_type_to_rust(v)
59            );
60        }
61    }
62    // Scalar — match on the base type, ignoring any `(...)` parameters.
63    let base = t.split('(').next().unwrap_or(t).trim();
64    match base {
65        "Bool" => "bool",
66        "UInt8" => "u8",
67        "UInt16" => "u16",
68        "UInt32" => "u32",
69        "UInt64" => "u64",
70        "Int8" => "i8",
71        "Int16" => "i16",
72        "Int32" => "i32",
73        "Int64" => "i64",
74        "Float32" => "f32",
75        "Float64" => "f64",
76        // String, UUID, FixedString, Date*, DateTime*, IPv4/6, Enum*, JSON, and
77        // anything unrecognized → String (the safe over-the-wire default).
78        _ => "String",
79    }
80    .to_string()
81}
82
/// Turn a column name into a Rust field identifier that does not collide with a
/// Rust keyword.
///
/// - Strict and reserved keywords that are legal as raw identifiers are escaped as
///   `r#name` (e.g. `type` → `r#type`, `try` → `r#try`).
/// - `self`, `Self`, `super`, `crate` and a bare `_` can NOT be written as raw
///   identifiers, so they get a trailing underscore instead (e.g. `self` → `self_`).
/// - Any other name is returned unchanged.
///
/// Whenever the returned identifier differs from `name`, the caller must keep the
/// original column name for (de)serialization (e.g. via `#[serde(rename = ...)]`).
fn rust_field_ident(name: &str) -> String {
    // Strict + reserved keywords across editions (2015 through 2024) that the
    // language accepts behind `r#`. Escaping a word that is only a keyword in a
    // newer edition is harmless: `r#ident` is valid for non-keywords too.
    const RAW_ESCAPABLE: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen", "if",
        "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "static", "struct", "trait", "true", "try", "type",
        "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
    ];
    // Names the language forbids as raw identifiers (`r#self` etc. do not compile).
    const NOT_RAW_ESCAPABLE: &[&str] = &["crate", "self", "super", "Self", "_"];

    if NOT_RAW_ESCAPABLE.contains(&name) {
        format!("{name}_")
    } else if RAW_ESCAPABLE.contains(&name) {
        format!("r#{name}")
    } else {
        name.to_string()
    }
}
97
98/// Emit a Rust row struct for a table's columns — `(column_name, clickhouse_type)`
99/// pairs. Derives the `clickhouse` crate's `Row` + serde, so it deserializes
100/// straight from a query. The emitted source references `clickhouse::Row`
101/// (a dev/consumer dependency); this function only produces the string.
102pub fn rust_row_struct(struct_name: &str, columns: &[(String, String)]) -> String {
103    let mut out = String::new();
104    out.push_str(
105        "#[derive(Debug, Clone, clickhouse::Row, serde::Serialize, serde::Deserialize)]\n",
106    );
107    out.push_str(&format!("pub struct {struct_name} {{\n"));
108    for (name, ch_type) in columns {
109        let field = rust_field_ident(name);
110        // Preserve the exact column name for (de)serialization when the field was escaped.
111        if field != *name {
112            out.push_str(&format!("    #[serde(rename = \"{name}\")]\n"));
113        }
114        out.push_str(&format!("    pub {field}: {},\n", ch_type_to_rust(ch_type)));
115    }
116    out.push_str("}\n");
117    out
118}
119
120// ── TableSpec → TS + Zod ───────────────────────────────────────────────────────
121
/// Convert `snake_case` / `kebab-case` to `camelCase`
/// (e.g. `organization_id` → `organizationId`).
///
/// The input is cut at every `_` and `-`; empty pieces (leading, trailing or
/// doubled separators) are dropped. The first remaining piece is copied verbatim,
/// and every later piece has only its first character upper-cased.
fn to_camel_case(s: &str) -> String {
    let mut camel = String::with_capacity(s.len());
    let words = s
        .split(|c: char| c == '_' || c == '-')
        .filter(|word| !word.is_empty());

    for (idx, word) in words.enumerate() {
        let mut rest = word.chars();
        match rest.next() {
            Some(head) if idx > 0 => {
                camel.extend(head.to_uppercase());
                camel.push_str(rest.as_str());
            }
            // First word keeps its original casing.
            _ => camel.push_str(word),
        }
    }
    camel
}
142
143/// `snake_case` → `PascalCase` (e.g. `observability_traces` → `ObservabilityTraces`).
144fn to_pascal_case(s: &str) -> String {
145    let camel = to_camel_case(s);
146    let mut chars = camel.chars();
147    match chars.next() {
148        Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
149        None => camel,
150    }
151}
152
153fn scalar_ts(s: ScalarType) -> &'static str {
154    match s {
155        ScalarType::String
156        | ScalarType::Uuid
157        | ScalarType::Date
158        | ScalarType::DateTime
159        | ScalarType::DateTime64 => "string",
160        ScalarType::Bool => "boolean",
161        ScalarType::Int8
162        | ScalarType::Int16
163        | ScalarType::Int32
164        | ScalarType::Int64
165        | ScalarType::UInt8
166        | ScalarType::UInt16
167        | ScalarType::UInt32
168        | ScalarType::UInt64
169        | ScalarType::Float32
170        | ScalarType::Float64 => "number",
171        ScalarType::Json => "unknown",
172    }
173}
174
175fn scalar_zod(s: ScalarType) -> &'static str {
176    match s {
177        ScalarType::String
178        | ScalarType::Uuid
179        | ScalarType::Date
180        | ScalarType::DateTime
181        | ScalarType::DateTime64 => "z.string()",
182        ScalarType::Bool => "z.boolean()",
183        ScalarType::Int8
184        | ScalarType::Int16
185        | ScalarType::Int32
186        | ScalarType::Int64
187        | ScalarType::UInt8
188        | ScalarType::UInt16
189        | ScalarType::UInt32
190        | ScalarType::UInt64
191        | ScalarType::Float32
192        | ScalarType::Float64 => "z.number()",
193        ScalarType::Json => "z.unknown()",
194    }
195}
196
197/// The TS type for a column spec. `Nullable(T)` widens to `T | null`;
198/// `LowCardinality(T)` is transparent (renders as `T`).
199fn ts_type(spec: &ColumnTypeSpec) -> String {
200    match spec {
201        ColumnTypeSpec::Scalar(s) => scalar_ts(*s).to_string(),
202        ColumnTypeSpec::DateTime64 { .. } => "string".to_string(),
203        ColumnTypeSpec::Nullable { nullable } => format!("{} | null", ts_type(nullable)),
204        ColumnTypeSpec::LowCardinality { low_cardinality } => ts_type(low_cardinality),
205        ColumnTypeSpec::Array { .. } => "string[]".to_string(),
206        ColumnTypeSpec::Map { .. } => "Record<string, string>".to_string(),
207    }
208}
209
210/// The Zod expression for a column spec. `Nullable(T)` appends `.nullable()`;
211/// `LowCardinality(T)` is transparent (renders as the inner Zod).
212fn zod_type(spec: &ColumnTypeSpec) -> String {
213    match spec {
214        ColumnTypeSpec::Scalar(s) => scalar_zod(*s).to_string(),
215        ColumnTypeSpec::DateTime64 { .. } => "z.string()".to_string(),
216        ColumnTypeSpec::Nullable { nullable } => format!("{}.nullable()", zod_type(nullable)),
217        ColumnTypeSpec::LowCardinality { low_cardinality } => zod_type(low_cardinality),
218        ColumnTypeSpec::Array { .. } => "z.array(z.string())".to_string(),
219        ColumnTypeSpec::Map { .. } => "z.record(z.string(), z.string())".to_string(),
220    }
221}
222
223/// Whether a column is nullable (a `Nullable(...)` at the core, seen through any
224/// transparent `LowCardinality(...)` wrappers). Nullable columns become optional
225/// (`field?`) in the emitted interface.
226fn is_nullable(spec: &ColumnTypeSpec) -> bool {
227    match spec {
228        ColumnTypeSpec::Nullable { .. } => true,
229        ColumnTypeSpec::LowCardinality { low_cardinality } => is_nullable(low_cardinality),
230        _ => false,
231    }
232}
233
234/// The TS interface name for a table, e.g. `observability_traces` → `ObservabilityTracesRow`.
235pub fn row_type_name(table: &TableSpec) -> String {
236    format!("{}Row", to_pascal_case(&table.name))
237}
238
239/// The Zod select-schema const name, e.g. `observabilityTracesSelectSchema`.
240pub fn select_schema_name(table: &TableSpec) -> String {
241    format!("{}SelectSchema", to_camel_case(&table.name))
242}
243
244/// The Zod insert-schema const name, e.g. `observabilityTracesInsertSchema`.
245pub fn insert_schema_name(table: &TableSpec) -> String {
246    format!("{}InsertSchema", to_camel_case(&table.name))
247}
248
249/// Emit the TS row `interface` for a table (one field per column).
250pub fn emit_row_interface(table: &TableSpec) -> String {
251    let mut out = format!("export interface {} {{\n", row_type_name(table));
252    for c in &table.columns {
253        let optional = if is_nullable(&c.type_spec) { "?" } else { "" };
254        out.push_str(&format!(
255            "    {}{}: {};\n",
256            to_camel_case(&c.name),
257            optional,
258            ts_type(&c.type_spec)
259        ));
260    }
261    out.push('}');
262    out
263}
264
265fn emit_zod_object(name: &str, columns: &[ColumnSpec], insert: bool) -> String {
266    let mut out = format!("export const {name} = z.object({{\n");
267    for c in columns {
268        let mut zod = zod_type(&c.type_spec);
269        // Columns with a ClickHouse DEFAULT are optional on insert (the server fills them).
270        if insert && c.default.is_some() {
271            zod.push_str(".optional()");
272        }
273        out.push_str(&format!("    {}: {},\n", to_camel_case(&c.name), zod));
274    }
275    out.push_str("});");
276    out
277}
278
279/// Emit the Zod **select** schema (`z.object(...)`) for a table.
280pub fn emit_select_schema(table: &TableSpec) -> String {
281    emit_zod_object(&select_schema_name(table), &table.columns, false)
282}
283
284/// Emit the Zod **insert** schema — columns with a `DEFAULT` become `.optional()`.
285pub fn emit_insert_schema(table: &TableSpec) -> String {
286    emit_zod_object(&insert_schema_name(table), &table.columns, true)
287}
288
289/// Emit a full TS module for a table: the `zod` import, the row interface, and the
290/// select + insert schemas, separated by blank lines.
291pub fn emit_ts_module(table: &TableSpec) -> String {
292    format!(
293        "import {{ z }} from \"zod\";\n\n{}\n\n{}\n\n{}\n",
294        emit_row_interface(table),
295        emit_select_schema(table),
296        emit_insert_schema(table),
297    )
298}
299
// Unit tests for both codegen directions. The `golden_*` tests compare the emitted
// text against an exact expected string, so whitespace in those literals matters.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::safety::StringOnly;

    // ── ClickHouse → Rust ──
    /// Plain scalars map to Rust primitives; `UUID` and a parametrised
    /// `DateTime64(3)` fall back to `String`.
    #[test]
    fn maps_scalars() {
        assert_eq!(ch_type_to_rust("String"), "String");
        assert_eq!(ch_type_to_rust("UInt64"), "u64");
        assert_eq!(ch_type_to_rust("Int32"), "i32");
        assert_eq!(ch_type_to_rust("Float64"), "f64");
        assert_eq!(ch_type_to_rust("Bool"), "bool");
        assert_eq!(ch_type_to_rust("UUID"), "String");
        assert_eq!(ch_type_to_rust("DateTime64(3)"), "String");
    }

    /// `Nullable`, `LowCardinality`, `Array` and `Map` are unwrapped recursively,
    /// including when nested inside one another.
    #[test]
    fn maps_wrappers_and_containers() {
        assert_eq!(ch_type_to_rust("Nullable(String)"), "Option<String>");
        assert_eq!(ch_type_to_rust("LowCardinality(String)"), "String");
        assert_eq!(
            ch_type_to_rust("LowCardinality(Nullable(String))"),
            "Option<String>"
        );
        assert_eq!(ch_type_to_rust("Array(String)"), "Vec<String>");
        assert_eq!(ch_type_to_rust("Array(UInt32)"), "Vec<u32>");
        assert_eq!(
            ch_type_to_rust("Map(String, String)"),
            "std::collections::HashMap<String, String>"
        );
        assert_eq!(
            ch_type_to_rust("Map(String, Array(UInt8))"),
            "std::collections::HashMap<String, Vec<u8>>"
        );
    }

    /// A column named after a Rust keyword (`type`) is emitted as a raw identifier
    /// with a `serde(rename)` that preserves the original column name.
    #[test]
    fn emits_row_struct_with_keyword_escape() {
        let cols = vec![
            ("id".to_string(), "UUID".to_string()),
            ("count".to_string(), "UInt64".to_string()),
            ("type".to_string(), "LowCardinality(String)".to_string()),
            ("tags".to_string(), "Array(String)".to_string()),
        ];
        let src = rust_row_struct("EventRow", &cols);
        assert!(src.contains(
            "#[derive(Debug, Clone, clickhouse::Row, serde::Serialize, serde::Deserialize)]"
        ));
        assert!(src.contains("pub struct EventRow {"));
        assert!(src.contains("pub id: String,"));
        assert!(src.contains("pub count: u64,"));
        assert!(src.contains("#[serde(rename = \"type\")]"));
        assert!(src.contains("pub r#type: String,"));
        assert!(src.contains("pub tags: Vec<String>,"));
    }

    // ── TableSpec → TS / Zod ──
    /// Build a column with the given name and type and no default.
    fn col(name: &str, t: ColumnTypeSpec) -> ColumnSpec {
        ColumnSpec {
            name: name.into(),
            type_spec: t,
            default: None,
        }
    }

    /// Wrap a type spec in `LowCardinality(...)`.
    fn lc(inner: ColumnTypeSpec) -> ColumnTypeSpec {
        ColumnTypeSpec::LowCardinality {
            low_cardinality: Box::new(inner),
        }
    }

    /// Wrap a type spec in `Nullable(...)`.
    fn nullable(inner: ColumnTypeSpec) -> ColumnTypeSpec {
        ColumnTypeSpec::Nullable {
            nullable: Box::new(inner),
        }
    }

    /// Fixture table `events`: scalar, `LowCardinality`, nullable, array, map and
    /// JSON columns, plus one column (`ingested_at`) that carries a default.
    fn sample() -> TableSpec {
        TableSpec {
            name: "events".into(),
            columns: vec![
                col("id", ColumnTypeSpec::Scalar(ScalarType::Uuid)),
                col(
                    "occurred_at",
                    ColumnTypeSpec::Scalar(ScalarType::DateTime64),
                ),
                col("status", lc(ColumnTypeSpec::Scalar(ScalarType::String))),
                col(
                    "region",
                    lc(nullable(ColumnTypeSpec::Scalar(ScalarType::String))),
                ),
                col("score", ColumnTypeSpec::Scalar(ScalarType::Float64)),
                col("retry_count", ColumnTypeSpec::Scalar(ScalarType::UInt32)),
                col("is_error", ColumnTypeSpec::Scalar(ScalarType::Bool)),
                col(
                    "tags",
                    ColumnTypeSpec::Array {
                        array: StringOnly::String,
                    },
                ),
                col(
                    "attributes",
                    ColumnTypeSpec::Map {
                        map: (StringOnly::String, StringOnly::String),
                    },
                ),
                col("payload", ColumnTypeSpec::Scalar(ScalarType::Json)),
                // The only column with a default; drives the `.optional()` case.
                ColumnSpec {
                    name: "ingested_at".into(),
                    type_spec: ColumnTypeSpec::Scalar(ScalarType::DateTime),
                    default: Some("now()".into()),
                },
            ],
            engine: "MergeTree()".into(),
            order_by: vec!["id".into()],
            partition_by: None,
            ttl: None,
            indexes: vec![],
            settings: vec![],
        }
    }

    /// Interface and schema names are derived from the snake_case table name.
    #[test]
    fn names_are_derived_from_table_name() {
        let t = TableSpec {
            name: "observability_traces".into(),
            ..sample()
        };
        assert_eq!(row_type_name(&t), "ObservabilityTracesRow");
        assert_eq!(select_schema_name(&t), "observabilityTracesSelectSchema");
        assert_eq!(insert_schema_name(&t), "observabilityTracesInsertSchema");
    }

    /// Exact expected interface text for the fixture; the nullable `region` column
    /// is the only optional field.
    #[test]
    fn golden_row_interface() {
        let expected = "\
export interface EventsRow {
    id: string;
    occurredAt: string;
    status: string;
    region?: string | null;
    score: number;
    retryCount: number;
    isError: boolean;
    tags: string[];
    attributes: Record<string, string>;
    payload: unknown;
    ingestedAt: string;
}";
        assert_eq!(emit_row_interface(&sample()), expected);
    }

    /// Exact expected select schema for the fixture; no entry is `.optional()`.
    #[test]
    fn golden_select_schema() {
        let expected = "\
export const eventsSelectSchema = z.object({
    id: z.string(),
    occurredAt: z.string(),
    status: z.string(),
    region: z.string().nullable(),
    score: z.number(),
    retryCount: z.number(),
    isError: z.boolean(),
    tags: z.array(z.string()),
    attributes: z.record(z.string(), z.string()),
    payload: z.unknown(),
    ingestedAt: z.string(),
});";
        assert_eq!(emit_select_schema(&sample()), expected);
    }

    /// Exact expected insert schema for the fixture; only `ingestedAt` (the column
    /// with a default) is `.optional()`.
    #[test]
    fn golden_insert_schema_makes_default_columns_optional() {
        let expected = "\
export const eventsInsertSchema = z.object({
    id: z.string(),
    occurredAt: z.string(),
    status: z.string(),
    region: z.string().nullable(),
    score: z.number(),
    retryCount: z.number(),
    isError: z.boolean(),
    tags: z.array(z.string()),
    attributes: z.record(z.string(), z.string()),
    payload: z.unknown(),
    ingestedAt: z.string().optional(),
});";
        assert_eq!(emit_insert_schema(&sample()), expected);
    }

    /// A `datetime64` spec with precision and timezone, deserialized from its JSON
    /// form, renders as a string in both the interface and the select schema.
    #[test]
    fn parametrised_datetime64_maps_to_string() {
        let dt: ColumnTypeSpec =
            serde_json::from_str(r#"{"datetime64":{"precision":6,"timezone":"UTC"}}"#).unwrap();
        let t = TableSpec {
            name: "t".into(),
            columns: vec![col("occurred_at", dt)],
            ..sample()
        };
        assert!(emit_row_interface(&t).contains("occurredAt: string;"));
        assert!(emit_select_schema(&t).contains("occurredAt: z.string()"));
    }

    /// A bare `Nullable(String)` (no `LowCardinality` around it) is optional in the
    /// interface and `.nullable()` in the schema.
    #[test]
    fn nullable_scalar_without_low_cardinality_is_optional_and_nullable() {
        let t = TableSpec {
            name: "t".into(),
            columns: vec![col(
                "note",
                nullable(ColumnTypeSpec::Scalar(ScalarType::String)),
            )],
            ..sample()
        };
        assert!(emit_row_interface(&t).contains("note?: string | null;"));
        assert!(emit_select_schema(&t).contains("note: z.string().nullable(),"));
    }

    /// Case-conversion helpers, including a leading underscore being dropped
    /// without capitalising the first word.
    #[test]
    fn camel_case_helper() {
        assert_eq!(to_camel_case("organization_id"), "organizationId");
        assert_eq!(to_camel_case("started_at"), "startedAt");
        assert_eq!(to_camel_case("id"), "id");
        assert_eq!(to_camel_case("_leading"), "leading");
        assert_eq!(
            to_pascal_case("observability_traces"),
            "ObservabilityTraces"
        );
    }
}