specta-swift 0.0.2

Export your Rust types to Swift
Documentation
//! Swift language exporter configuration and main export functionality.

use std::{borrow::Cow, path::Path};

use specta::{
    ResolvedTypes, Types,
    datatype::{DataType, Fields, Reference},
};

use crate::error::Result;
use crate::primitives::{export_type, is_duration_struct};

/// Configuration for exporting Specta types as Swift source code.
///
/// Build one with [`Swift::new`] / [`Default`] and adjust it with the builder
/// methods, then call `export` or `export_to`.
#[derive(Debug, Clone)]
pub struct Swift {
    /// Comment text written as the first line of generated output.
    pub header: Cow<'static, str>,
    /// How generated code is indented.
    pub indent: IndentStyle,
    /// Casing rules applied to generated identifiers.
    pub naming: NamingConvention,
    /// How generic parameters are rendered.
    pub generics: GenericStyle,
    /// How optional types are rendered.
    pub optionals: OptionalStyle,
    /// Extra protocols the generated types should conform to.
    pub protocols: Vec<Cow<'static, str>>,
}

/// How each indentation level of the generated Swift code is written.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndentStyle {
    /// Indent with the given number of space characters per level.
    Spaces(usize),
    /// Indent with one tab character per level.
    Tabs,
}

impl Default for IndentStyle {
    fn default() -> Self {
        Self::Spaces(4)
    }
}

/// Casing convention applied to generated Swift identifiers.
///
/// Defaults to [`NamingConvention::PascalCase`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NamingConvention {
    /// `PascalCase` (the usual casing for Swift type names).
    #[default]
    PascalCase,
    /// `camelCase`.
    CamelCase,
    /// `snake_case`.
    SnakeCase,
}

/// How generic type parameters are rendered in Swift.
///
/// Defaults to [`GenericStyle::Protocol`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GenericStyle {
    /// Inline protocol constraints, e.g. `<T: Codable>`.
    #[default]
    Protocol,
    /// Constraints in a `where` clause, e.g. `<T> where T: Codable`.
    ///
    /// NOTE(review): the variant name suggests typealias-based output while the
    /// description says `where` clause — confirm against the type exporter.
    Typealias,
}

/// How optional values are spelled in generated Swift.
///
/// Defaults to [`OptionalStyle::QuestionMark`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptionalStyle {
    /// Sugared form, e.g. `String?`.
    #[default]
    QuestionMark,
    /// Explicit generic form, e.g. `Optional<String>`.
    Optional,
}

impl Default for Swift {
    fn default() -> Self {
        Self {
            header: "// This file has been generated by Specta. DO NOT EDIT.".into(),
            indent: IndentStyle::default(),
            naming: NamingConvention::default(),
            generics: GenericStyle::default(),
            optionals: OptionalStyle::default(),
            protocols: vec![],
        }
    }
}

impl Swift {
    /// Create a new Swift exporter with the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the header comment written at the top of generated files.
    ///
    /// An empty header suppresses the header line entirely.
    pub fn header(mut self, header: impl Into<Cow<'static, str>>) -> Self {
        self.header = header.into();
        self
    }

    /// Set the indentation style.
    pub fn indent(mut self, style: IndentStyle) -> Self {
        self.indent = style;
        self
    }

    /// Set the naming convention.
    pub fn naming(mut self, convention: NamingConvention) -> Self {
        self.naming = convention;
        self
    }

    /// Set the generic type style.
    pub fn generics(mut self, style: GenericStyle) -> Self {
        self.generics = style;
        self
    }

    /// Set the optional type style.
    pub fn optionals(mut self, style: OptionalStyle) -> Self {
        self.optionals = style;
        self
    }

    /// Add a protocol that all generated types should conform to.
    pub fn add_protocol(mut self, protocol: impl Into<Cow<'static, str>>) -> Self {
        self.protocols.push(protocol.into());
        self
    }

    /// Export `types` to a Swift source string.
    ///
    /// The output contains, in order:
    /// - the header comment (omitted when empty);
    /// - `import Foundation`;
    /// - the `RustDuration` helper, when `needs_duration_helper` reports one is needed;
    /// - every type, in sorted order, each followed by a blank line.
    ///
    /// Types whose export is empty are skipped.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while exporting an individual type.
    pub fn export(&self, types: &ResolvedTypes) -> Result<String> {
        let raw_types = types.as_types();
        let mut result = String::new();

        if !self.header.is_empty() {
            result.push_str(&self.header);
            result.push('\n');
        }

        // `self.protocols` holds protocol *conformances* (e.g. `Hashable`), not
        // modules, so they must not be emitted as `import` statements: Swift
        // rejects `import Hashable` with "no such module". This exporter only
        // needs Foundation.
        result.push_str("import Foundation\n\n");

        if needs_duration_helper(raw_types) {
            result.push_str(&generate_duration_helper());
        }

        for ndt in raw_types.into_sorted_iter() {
            let exported = export_type(self, raw_types, ndt)?;
            // Skip types that export to nothing so they leave no stray blank lines.
            if !exported.is_empty() {
                result.push_str(&exported);
                result.push_str("\n\n");
            }
        }

        Ok(result)
    }

    /// Export `types` and write the resulting Swift source to `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if exporting fails or if the file cannot be written.
    pub fn export_to(&self, path: impl AsRef<Path>, types: &ResolvedTypes) -> Result<()> {
        let content = self.export(types)?;
        std::fs::write(path, content)?;
        Ok(())
    }
}

impl NamingConvention {
    /// Convert `name` to this naming convention.
    pub fn convert(&self, name: &str) -> String {
        match self {
            Self::PascalCase => self.to_pascal_case(name),
            Self::CamelCase => self.to_camel_case(name),
            Self::SnakeCase => self.to_snake_case(name),
        }
    }

    /// Convert `name` to camelCase, regardless of the selected convention.
    pub fn convert_to_camel_case(&self, name: &str) -> String {
        self.to_camel_case(name)
    }

    /// Convert a field name according to this convention.
    ///
    /// `PascalCase` maps to camelCase here, because Swift properties are
    /// lowerCamelCase even when type names are PascalCase.
    pub fn convert_field(&self, name: &str) -> String {
        match self {
            Self::PascalCase => self.to_camel_case(name),
            Self::CamelCase => self.to_camel_case(name),
            Self::SnakeCase => self.to_snake_case(name),
        }
    }

    /// Convert an enum case name according to this convention.
    ///
    /// `PascalCase` maps to camelCase here, because Swift enum cases are
    /// lowerCamelCase.
    pub fn convert_enum_case(&self, name: &str) -> String {
        match self {
            Self::PascalCase => self.to_camel_case(name),
            Self::CamelCase => self.to_camel_case(name),
            Self::SnakeCase => self.to_snake_case(name),
        }
    }

    /// Convert snake_case, SCREAMING_CASE or PascalCase input to camelCase.
    // `to_*` takes `&self` only so the helpers live on the enum; the lint expects `self`.
    #[allow(clippy::wrong_self_convention)]
    fn to_camel_case(&self, name: &str) -> String {
        if name.contains('_') {
            // snake_case: the first non-empty word is lowercased entirely and every
            // later word is capitalized. Empty segments (leading, trailing or doubled
            // `_`) are skipped. "First word" is tracked by output, not by segment
            // index, so `_foo` becomes `foo` rather than `Foo`.
            let mut result = String::with_capacity(name.len());
            for part in name.split('_').filter(|p| !p.is_empty()) {
                if result.is_empty() {
                    result.push_str(&part.to_lowercase());
                } else {
                    let mut chars = part.chars();
                    if let Some(first) = chars.next() {
                        // `extend` keeps multi-char mappings such as `ß` -> `SS`.
                        result.extend(first.to_uppercase());
                        result.extend(chars.flat_map(char::to_lowercase));
                    }
                }
            }
            return result;
        }

        // All-caps identifiers such as `ID` or `URL` become fully lowercase.
        if name.chars().any(|c| c.is_ascii_alphabetic())
            && name
                .chars()
                .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_uppercase())
        {
            return name.to_ascii_lowercase();
        }

        // PascalCase: lowercase the leading capital, or a whole leading acronym
        // (`HTTPServer` -> `httpServer`, `IOError` -> `ioError`). When the acronym is
        // directly followed by a lowercase letter, its last capital starts the next
        // word. Everything after the lowered prefix is kept as-is.
        let chars: Vec<char> = name.chars().collect();
        let upper_run = chars.iter().take_while(|c| c.is_uppercase()).count();
        let lower_len = match chars.get(upper_run) {
            Some(next) if upper_run > 1 && next.is_lowercase() => upper_run - 1,
            _ => upper_run.max(1),
        };

        let mut result = String::with_capacity(name.len());
        for (i, c) in chars.into_iter().enumerate() {
            if i < lower_len {
                result.extend(c.to_lowercase());
            } else {
                result.push(c);
            }
        }
        result
    }

    /// Convert snake_case input to PascalCase by capitalizing each `_`-separated word.
    ///
    /// Characters after the first of each word are kept as-is.
    // `to_*` takes `&self` only so the helpers live on the enum; the lint expects `self`.
    #[allow(clippy::wrong_self_convention)]
    fn to_pascal_case(&self, name: &str) -> String {
        name.split('_')
            .map(|part| {
                let mut chars = part.chars();
                match chars.next() {
                    None => String::new(),
                    Some(first) => first.to_uppercase().chain(chars).collect(),
                }
            })
            .collect()
    }

    /// Convert camelCase/PascalCase input to snake_case.
    ///
    /// An `_` is inserted before an uppercase letter that starts a new word:
    /// - after a lowercase letter, digit or other non-uppercase character;
    /// - at the end of an acronym run, when the capital is followed by a lowercase
    ///   letter (`HTTPServer` -> `http_server`, `UserID` -> `user_id`).
    ///
    /// No separator is added after an existing `_`, so `Foo_Bar` -> `foo_bar`.
    // `to_*` takes `&self` only so the helpers live on the enum; the lint expects `self`.
    #[allow(clippy::wrong_self_convention)]
    fn to_snake_case(&self, name: &str) -> String {
        let chars: Vec<char> = name.chars().collect();
        let mut result = String::with_capacity(name.len() + name.len() / 2);

        for (i, &c) in chars.iter().enumerate() {
            if c.is_uppercase() && i > 0 {
                let prev = chars[i - 1];
                let starts_word = if prev == '_' {
                    // Already separated; avoid a doubled `__`.
                    false
                } else if prev.is_uppercase() {
                    // Inside an acronym only the capital that begins a lowercase
                    // word starts a new segment.
                    chars.get(i + 1).is_some_and(|next| next.is_lowercase())
                } else {
                    true
                };
                if starts_word {
                    result.push('_');
                }
            }
            result.extend(c.to_lowercase());
        }

        result
    }
}

/// Report whether the generated Swift needs the `RustDuration` helper.
///
/// Returns `true` in either of these cases:
/// - a named type in `types` is itself called `Duration`;
/// - a struct with named fields has a field that references a type named
///   `Duration`, or whose type is an inline struct accepted by `is_duration_struct`.
///
/// NOTE(review): only the direct field types of named-field structs are inspected.
/// A `Duration` wrapped in another type, or used in a tuple struct or enum variant,
/// is not detected here. Confirm whether `export_type` can emit `RustDuration` in
/// those positions.
fn needs_duration_helper(types: &Types) -> bool {
    types.into_sorted_iter().any(|ndt| {
        if ndt.name() == "Duration" {
            return true;
        }

        let DataType::Struct(strct) = ndt.ty() else {
            return false;
        };
        let Fields::Named(named) = strct.fields() else {
            return false;
        };

        named
            .fields()
            .into_iter()
            .filter_map(|(_, field)| field.ty())
            .any(|field_ty| match field_ty {
                DataType::Reference(Reference::Named(reference)) => reference
                    .get(types)
                    .is_some_and(|target| target.name() == "Duration"),
                DataType::Struct(inline) => is_duration_struct(inline),
                _ => false,
            })
    })
}

/// Build the Swift source for the `RustDuration` helper struct.
///
/// The struct decodes Rust's `{"secs", "nanos"}` Duration representation and
/// exposes it as a `TimeInterval`. The text ends with a
/// `// MARK: - Generated Types` marker that introduces the exported types.
fn generate_duration_helper() -> String {
    const SWIFT_SOURCE: &str = concat!(
        "// MARK: - Duration Helper\n",
        "/// Helper struct to decode Rust Duration format {\"secs\": u64, \"nanos\": u32}\n",
        "public struct RustDuration: Codable {\n",
        "    public let secs: UInt64\n",
        "    public let nanos: UInt32\n",
        "    \n",
        "    public var timeInterval: TimeInterval {\n",
        "        return Double(secs) + Double(nanos) / 1_000_000_000.0\n",
        "    }\n",
        "}\n\n",
        "// MARK: - Generated Types\n\n",
    );
    String::from(SWIFT_SOURCE)
}