rustantic 0.2.0

Rust to Pydantic generator
Documentation
use std::collections::HashSet;

use itertools::{sorted, Itertools};

use crate::{
    collector::MetadataCollector,
    models::{DiscriminatedUnionMetadata, ItemMetadata, UnionVariantMetadata},
};

use super::{
    field_generator::FieldGenerator,
    generator_base::{
        GenerationResult, GeneratorConfig, PydanticCodeGenerator, PydanticCodeGeneratorFactory,
    },
};

/// Pydantic code generator for [`ItemMetadata::DiscriminatedUnion`] items.
///
/// The struct carries no state; all generation input is passed to its methods.
pub(crate) struct UnionCodeGenerator {}

impl PydanticCodeGeneratorFactory for UnionCodeGenerator {
    fn create() -> Box<dyn PydanticCodeGenerator> {
        Box::new(Self {})
    }
}

impl PydanticCodeGenerator for UnionCodeGenerator {
    /// Returns `true` only for [`ItemMetadata::DiscriminatedUnion`] items.
    fn is_item_supported(&self, meta: &ItemMetadata) -> bool {
        matches!(meta, ItemMetadata::DiscriminatedUnion(_))
    }

    /// Generates the Python source for a discriminated union.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when `meta` is any variant other than
    /// [`ItemMetadata::DiscriminatedUnion`].
    fn generate(
        &self,
        config: GeneratorConfig,
        collector: &MetadataCollector,
        meta: &ItemMetadata,
    ) -> Result<String, ()> {
        match meta {
            ItemMetadata::DiscriminatedUnion(union_md) => {
                let python_source = self.generate_code(config, collector, union_md);
                Ok(python_source)
            }
            _ => Err(()),
        }
    }
}

impl UnionCodeGenerator {
    /// Assembles the full Python source for one discriminated union.
    ///
    /// The output is, in order: a fixed header comment, the import lines, the
    /// discriminator enum, one model class per variant, and finally the
    /// `Union` alias with its `RootModel` wrapper.
    fn generate_code(
        &self,
        config: GeneratorConfig,
        collector: &MetadataCollector,
        meta: &DiscriminatedUnionMetadata,
    ) -> String {
        let discriminator = self.generate_discriminator(meta);
        // `config` is still needed below, so the variants pass gets its own copy.
        let variants = self.generate_union_variants(config.clone(), collector, meta);
        let definition = self.generate_type_definitions(&config, meta);
        let imports = self.generate_import(config.package_name, &variants.additional_imports);
        format!(
            "# Missing to_rs function\n{}\n\n{}\n{}\n{}",
            imports, discriminator, variants.code, definition
        )
    }

    /// Builds the import section of the generated module.
    ///
    /// The fixed imports (including `import <package_name>`) are merged with
    /// `additional_imports`; duplicates collapse because a set is used, and
    /// the lines are emitted in sorted order, one per line.
    fn generate_import(&self, package_name: &str, additional_imports: &HashSet<String>) -> String {
        let mut imports = additional_imports.clone();
        imports.extend([
            "import enum".to_owned(),
            "from typing import Literal, Union, Any".to_owned(),
            "from pydantic import BaseModel, Field, RootModel".to_owned(),
            format!("import {}", package_name),
        ]);

        sorted(imports).join("\n")
    }

    /// Generates the Python `enum.Enum` class used as the discriminator.
    ///
    /// Each union variant becomes one `enum.auto()` member; the class is
    /// followed by a blank line.
    fn generate_discriminator(&self, meta: &DiscriminatedUnionMetadata) -> String {
        let mut code = format!(
            "class {}(enum.Enum):\n",
            self.generate_discriminator_name(&meta.ident)
        );
        code.extend(
            meta.variants
                .iter()
                .map(|variant| format!("    {} = enum.auto()\n", &variant.ident)),
        );
        code.push('\n');
        code
    }

    /// Generates the model classes for every variant of the union.
    ///
    /// The per-variant code is joined with newlines and the imports requested
    /// by each variant are accumulated into the returned result.
    fn generate_union_variants(
        &self,
        config: GeneratorConfig,
        collector: &MetadataCollector,
        meta: &DiscriminatedUnionMetadata,
    ) -> GenerationResult {
        // `config` is owned and not used again here, so it is moved rather than cloned.
        let field_generator = FieldGenerator::new(config, collector.entities());
        let mut result = GenerationResult::default();
        let mut variants = vec![];
        for variant in meta.variants.iter() {
            let variant_code = self.generate_union_variant(&field_generator, meta, variant);
            result
                .additional_imports
                .extend(variant_code.additional_imports);
            variants.push(variant_code.code);
        }
        result.code = variants.join("\n");
        result
    }

    /// Generates the `BaseModel` subclass for a single variant.
    ///
    /// The class is named `<union ident><variant ident>` and has a frozen
    /// `kind` literal field plus a `value` field produced by `field_generator`.
    ///
    /// # Panics
    ///
    /// Panics when `variant.ty` is `None`.
    fn generate_union_variant(
        &self,
        field_generator: &FieldGenerator,
        meta: &DiscriminatedUnionMetadata,
        variant: &UnionVariantMetadata,
    ) -> GenerationResult {
        let field_gen = field_generator.generate(
            "value",
            variant.ty.as_ref().expect("Named enum not supported"),
        );
        let lines = [
            format!("class {0}{1}(BaseModel):", &meta.ident, &variant.ident),
            format!(
                "    kind: Literal[{0}.{1}] = Field(default={0}.{1}, init=False, frozen=True)",
                self.generate_discriminator_name(&meta.ident),
                &variant.ident
            ),
            format!("    {}", field_gen.code),
            // Joined with "\n" below, this leaves blank lines after the class.
            "\n".to_owned(),
        ];

        let mut result = GenerationResult::default();
        result.code = lines.join("\n");
        result
            .additional_imports
            .extend(field_gen.additional_imports);
        result
    }

    /// Generates the `<ident>Type` union alias and the `RootModel` wrapper
    /// class (including its `to_rs` method) for the union.
    fn generate_type_definitions(
        &self,
        config: &GeneratorConfig,
        meta: &DiscriminatedUnionMetadata,
    ) -> String {
        let variants: Vec<String> = meta
            .variants
            .iter()
            .map(|v| format!("{}{}", &meta.ident, &v.ident))
            .collect();

        [
            format!("{}Type = Union[{}]", &meta.ident, variants.join(",")),
            format!("\nclass {0}(RootModel[{0}Type]):", &meta.ident),
            format!(
                "    root: {}Type = Field(..., discriminator=\"kind\")",
                &meta.ident
            ),
            self.generate_to_pyo3(config, meta),
        ]
        .join("\n")
    }

    /// Returns the name of the discriminator enum: `<ident>Discriminator`.
    fn generate_discriminator_name(&self, ident: &str) -> String {
        format!("{}Discriminator", ident)
    }

    /// Generates the Python `to_rs` method of the `RootModel` wrapper.
    ///
    /// The method converts the wrapped value and then matches on `kind` to
    /// construct `<package>.<union ident>.<variant ident>(val)`.
    fn generate_to_pyo3(
        &self,
        config: &GeneratorConfig,
        meta: &DiscriminatedUnionMetadata,
    ) -> String {
        let mut code_sections = vec![
            "    def to_rs(self):".to_owned(),
            // Look `to_rs` up on the value's *type* so the result is the plain
            // function and the value can be passed explicitly. Looking it up
            // on the instance yields a bound method, and calling that with
            // the value as an extra argument raises `TypeError` in Python.
            "        inner_to_rs = getattr(type(self.root.value), \"to_rs\", lambda v: v)"
                .to_owned(),
            "        val: Any = inner_to_rs(self.root.value)".to_owned(),
            "        match self.root.kind:".to_owned(),
        ];
        for variant in meta.variants.iter() {
            code_sections.push(format!(
                "            case {}.{}:",
                self.generate_discriminator_name(&meta.ident),
                &variant.ident
            ));
            code_sections.push(format!(
                "                return {}.{}.{}(val)\n",
                config.package_name, &meta.ident, &variant.ident,
            ));
        }

        code_sections.join("\n")
    }
}