vox-codegen 0.8.1

Language bindings codegen for vox
Documentation
//! Generate TypeScript wire protocol definitions.
//!
//! Walks Facet shapes and emits:
//! - type aliases (for transparent named wrappers), interfaces and tagged unions
//!   for all named types reachable from the configured wire types
//! - discriminant constants for all named enums (auto-derived)
//! - narrowed per-variant type aliases for all named enums (auto-derived)
//! - messageSchemasCbor: CBOR-encoded `Vec<Schema>` for the handshake
//! - messageRootRef / messageSchemaRegistry: canonical local message schema graph
//!
//! The three schema constants are derived from the first configured wire type only.

use std::collections::HashSet;

use facet_core::{ScalarType, Shape};
use vox_types::{
    EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant,
    extract_schemas, is_bytes,
};

use crate::targets::typescript::schema::{render_schema, render_type_ref};

/// A wire type to generate TypeScript definitions for.
pub struct WireType {
    /// The facet Shape to generate from.
    pub shape: &'static Shape,
}

pub struct WireTypeGenConfig {
    pub types: Vec<WireType>,
}

/// Generate a complete TypeScript module with wire protocol type definitions,
/// schema constants, and CBOR helpers. Everything is derived from the shapes —
/// nothing is hardcoded.
pub fn generate_wire(config: &WireTypeGenConfig) -> Result<String, Box<dyn std::error::Error>> {
    let mut out = String::new();
    out.push_str("// @generated by cargo xtask codegen --typescript\n");
    out.push_str("// DO NOT EDIT — regenerate with `cargo xtask codegen --typescript`\n\n");

    let named_types = collect_wire_named_types(&config.types);

    for (name, shape) in &named_types {
        if let Some((_, inner)) = transparent_named_alias(shape) {
            out.push_str(&format!(
                "export type {name} = {};\n\n",
                wire_ts_type(inner)
            ));
            continue;
        }

        match classify_shape(shape) {
            ShapeKind::Struct(StructInfo { fields, .. }) => {
                out.push_str(&format!("export interface {name} {{\n"));
                for field in fields {
                    out.push_str(&format!(
                        "  {}: {};\n",
                        field.name,
                        wire_ts_type(field.shape())
                    ));
                }
                out.push_str("}\n\n");
            }
            ShapeKind::Enum(EnumInfo { variants, .. }) => {
                out.push_str(&format!("export type {name} =\n"));
                for (i, variant) in variants.iter().enumerate() {
                    let variant_type = match classify_variant(variant) {
                        VariantKind::Unit => format!("{{ tag: \"{}\" }}", variant.name),
                        VariantKind::Newtype { inner } => {
                            format!(
                                "{{ tag: \"{}\"; value: {} }}",
                                variant.name,
                                wire_ts_type(inner)
                            )
                        }
                        VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
                            let field_strs = fields
                                .iter()
                                .map(|f| format!("{}: {}", f.name, wire_ts_type(f.shape())))
                                .collect::<Vec<_>>()
                                .join("; ");
                            format!("{{ tag: \"{}\"; {} }}", variant.name, field_strs)
                        }
                    };
                    let sep = if i < variants.len() - 1 { "" } else { ";" };
                    out.push_str(&format!("  | {variant_type}{sep}\n"));
                }
                out.push('\n');
            }
            _ => {}
        }
    }

    // Auto-generate discriminant constants for every named enum
    for (name, shape) in &named_types {
        if let ShapeKind::Enum(EnumInfo { variants, .. }) = classify_shape(shape) {
            out.push_str(&format!("export const {name}Discriminant = {{\n"));
            for (i, variant) in variants.iter().enumerate() {
                out.push_str(&format!("  {}: {},\n", variant.name, i));
            }
            out.push_str("} as const;\n\n");
        }
    }

    // Auto-generate narrowed per-variant type aliases for every named enum
    for (name, shape) in &named_types {
        if let ShapeKind::Enum(EnumInfo { variants, .. }) = classify_shape(shape) {
            for variant in variants {
                out.push_str(&format!(
                    "export type {}{} = Extract<{}, {{ tag: \"{}\" }}>;\n",
                    name, variant.name, name, variant.name,
                ));
            }
            out.push('\n');
        }
    }

    // messageSchemasCbor + local canonical message schema graph.
    // Derived from the first (root) type shape in config.
    if let Some(root) = config.types.first() {
        let extracted = extract_schemas(root.shape)?;
        out.push_str(
            "export const messageSchemaRegistry: import(\"@bearcove/vox-postcard\").SchemaRegistry = new Map<bigint, import(\"@bearcove/vox-postcard\").Schema>([\n",
        );
        for schema in &extracted.schemas {
            out.push_str(&format!(
                "  [{}n, {}],\n",
                schema.id.0,
                render_schema(schema)
            ));
        }
        out.push_str("]);\n\n");
        out.push_str(&format!(
            "export const messageRootRef: import(\"@bearcove/vox-postcard\").TypeRef = {};\n\n",
            render_type_ref(&extracted.root)
        ));
        let cbor_bytes = facet_cbor::to_vec(&extracted.schemas)?;
        let body = cbor_bytes
            .iter()
            .map(|b| b.to_string())
            .collect::<Vec<_>>()
            .join(", ");
        out.push_str(&format!(
            "export const messageSchemasCbor = new Uint8Array([{body}]);\n"
        ));
    }

    Ok(out)
}

/// Collect every named type reachable from `types`, in dependency order.
///
/// Roots are walked in the order given. Each name appears at most once, and a
/// type is listed after the named types it references, except where a
/// reference cycle has to be broken (see `visit`).
fn collect_wire_named_types(types: &[WireType]) -> Vec<(String, &'static Shape)> {
    let mut ordered = Vec::new();
    let mut names_seen = HashSet::new();
    types
        .iter()
        .for_each(|root| visit(root.shape, &mut names_seen, &mut ordered));
    ordered
}

fn visit(
    shape: &'static Shape,
    seen: &mut HashSet<String>,
    types: &mut Vec<(String, &'static Shape)>,
) {
    if let Some((name, inner)) = transparent_named_alias(shape) {
        if !seen.contains(name) {
            seen.insert(name.to_string());
            visit(inner, seen, types);
            types.push((name.to_string(), shape));
        }
        return;
    }

    match classify_shape(shape) {
        ShapeKind::Struct(StructInfo {
            name: Some(name),
            fields,
            ..
        }) if seen.insert(name.to_string()) => {
            for field in fields {
                visit(field.shape(), seen, types);
            }
            types.push((name.to_string(), shape));
        }
        ShapeKind::Enum(EnumInfo {
            name: Some(name),
            variants,
        }) if seen.insert(name.to_string()) => {
            for variant in variants {
                match classify_variant(variant) {
                    VariantKind::Newtype { inner } => visit(inner, seen, types),
                    VariantKind::Struct { fields } | VariantKind::Tuple { fields } => {
                        for field in fields {
                            visit(field.shape(), seen, types);
                        }
                    }
                    VariantKind::Unit => {}
                }
            }
            types.push((name.to_string(), shape));
        }
        ShapeKind::List { element } => visit(element, seen, types),
        ShapeKind::Option { inner } => visit(inner, seen, types),
        ShapeKind::Array { element, .. } => visit(element, seen, types),
        ShapeKind::Map { key, value } => {
            visit(key, seen, types);
            visit(value, seen, types);
        }
        ShapeKind::Set { element } => visit(element, seen, types),
        ShapeKind::Tuple { elements } => {
            for param in elements {
                visit(param.shape, seen, types);
            }
        }
        ShapeKind::Pointer { pointee } => visit(pointee, seen, types),
        ShapeKind::Result { ok, err } => {
            visit(ok, seen, types);
            visit(err, seen, types);
        }
        _ => {}
    }
}

/// Recognise a transparent wrapper that should be emitted as a named TypeScript alias.
///
/// Returns `Some((name, inner))` only when all three conditions hold:
/// - `shape` is marked transparent;
/// - its `type_identifier` is accepted by `extract_type_name`;
/// - it has an `inner` shape.
///
/// Otherwise returns `None`.
fn transparent_named_alias(shape: &'static Shape) -> Option<(&'static str, &'static Shape)> {
    if shape.is_transparent() {
        extract_type_name(shape.type_identifier).zip(shape.inner)
    } else {
        None
    }
}

/// Return `type_identifier` if it is usable as a TypeScript type name.
///
/// Empty identifiers are rejected, as are identifiers starting with `(` or `[`
/// (which look like tuple / array-or-slice identifiers). Any other identifier
/// is returned unchanged.
fn extract_type_name(type_identifier: &'static str) -> Option<&'static str> {
    match type_identifier.as_bytes().first() {
        None | Some(b'(') | Some(b'[') => None,
        Some(_) => Some(type_identifier),
    }
}

/// Convert a Shape to a TypeScript type string for wire protocol types.
///
/// Named structs/enums and transparent named wrappers are referenced by name;
/// their declarations are emitted separately by `generate_wire`. Anonymous
/// structs and enums are rendered inline. Shapes that `is_bytes` classifies as
/// byte sequences, and opaque shapes, become `Uint8Array`. Lists, slices and
/// fixed-size arrays become `T[]`, with union element types parenthesised. Shape
/// kinds without a mapping (e.g. `Result`) fall back to `unknown`.
fn wire_ts_type(shape: &'static Shape) -> String {
    if let Some((name, _)) = transparent_named_alias(shape) {
        return name.to_string();
    }

    match classify_shape(shape) {
        ShapeKind::Struct(StructInfo {
            name: Some(name), ..
        }) => name.to_string(),
        ShapeKind::Enum(EnumInfo {
            name: Some(name), ..
        }) => name.to_string(),

        ShapeKind::List { .. } if is_bytes(shape) => "Uint8Array".into(),
        ShapeKind::List { element } => wire_ts_array_of(element),
        ShapeKind::Option { inner } => format!("{} | null", wire_ts_type(inner)),
        ShapeKind::Scalar(scalar) => wire_ts_scalar_type(scalar),
        ShapeKind::Slice { .. } if is_bytes(shape) => "Uint8Array".into(),
        ShapeKind::Slice { element } => wire_ts_array_of(element),
        ShapeKind::Pointer { pointee } if is_bytes(pointee) => "Uint8Array".into(),
        ShapeKind::Pointer { pointee } => wire_ts_type(pointee),
        ShapeKind::Opaque => "Uint8Array".into(),

        ShapeKind::Struct(StructInfo {
            name: None, fields, ..
        }) => {
            let inner = fields
                .iter()
                .map(|f| format!("{}: {}", f.name, wire_ts_type(f.shape())))
                .collect::<Vec<_>>()
                .join("; ");
            format!("{{ {inner} }}")
        }
        // A variant-less enum is uninhabited; joining zero variants would yield
        // an empty (syntactically invalid) type.
        ShapeKind::Enum(EnumInfo {
            name: None,
            variants,
        }) if variants.is_empty() => "never".into(),
        ShapeKind::Enum(EnumInfo {
            name: None,
            variants,
        }) => variants
            .iter()
            .map(|v| match classify_variant(v) {
                VariantKind::Unit => format!("{{ tag: \"{}\" }}", v.name),
                VariantKind::Newtype { inner } => {
                    format!("{{ tag: \"{}\"; value: {} }}", v.name, wire_ts_type(inner))
                }
                VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
                    let field_strs = fields
                        .iter()
                        .map(|f| format!("{}: {}", f.name, wire_ts_type(f.shape())))
                        .collect::<Vec<_>>()
                        .join("; ");
                    format!("{{ tag: \"{}\"; {} }}", v.name, field_strs)
                }
            })
            .collect::<Vec<_>>()
            .join(" | "),

        ShapeKind::Tuple { elements } => {
            let inner = elements
                .iter()
                .map(|p| wire_ts_type(p.shape))
                .collect::<Vec<_>>()
                .join(", ");
            format!("[{inner}]")
        }
        ShapeKind::Map { key, value } => {
            format!("Map<{}, {}>", wire_ts_type(key), wire_ts_type(value))
        }
        ShapeKind::Set { element } => format!("Set<{}>", wire_ts_type(element)),
        // `[T; N]` is Rust syntax and does not parse as TypeScript; render as `T[]`.
        ShapeKind::Array { element, .. } => wire_ts_array_of(element),

        _ => "unknown".into(),
    }
}

/// Render `element[]`, parenthesising the element when it is a top-level union,
/// so `Vec<Option<u32>>` becomes `(number | null)[]` rather than `number | null[]`
/// (which TypeScript reads as "number, or an array of null").
fn wire_ts_array_of(element: &'static Shape) -> String {
    let element_ts = wire_ts_type(element);
    if has_top_level_union(&element_ts) {
        format!("({element_ts})[]")
    } else {
        format!("{element_ts}[]")
    }
}

/// Whether `ts` contains a `|` outside any `()`, `[]`, `{}` or `<>` nesting.
fn has_top_level_union(ts: &str) -> bool {
    let mut depth = 0usize;
    for c in ts.chars() {
        match c {
            '(' | '[' | '{' | '<' => depth += 1,
            ')' | ']' | '}' | '>' => depth = depth.saturating_sub(1),
            '|' if depth == 0 => return true,
            _ => {}
        }
    }
    false
}

/// Map a Facet scalar to the TypeScript type used for it on the wire.
///
/// - `bool` → `boolean`
/// - 8/16/32-bit integers, `f32` and `f64` → `number`
/// - 64-bit, 128-bit and pointer-sized integers → `bigint` (presumably because
///   their values can exceed JavaScript's safe-integer range)
/// - `char` and the string types → `string`
/// - `()` → `void`
/// - any other scalar → `unknown`
fn wire_ts_scalar_type(scalar: ScalarType) -> String {
    let ts_type = match scalar {
        ScalarType::Bool => "boolean",
        ScalarType::I8
        | ScalarType::U8
        | ScalarType::I16
        | ScalarType::U16
        | ScalarType::I32
        | ScalarType::U32
        | ScalarType::F32
        | ScalarType::F64 => "number",
        ScalarType::I64
        | ScalarType::U64
        | ScalarType::I128
        | ScalarType::U128
        | ScalarType::ISize
        | ScalarType::USize => "bigint",
        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => "string",
        ScalarType::Unit => "void",
        _ => "unknown",
    };
    ts_type.to_owned()
}