specta-swift 0.0.3

Export your Rust types to Swift
Documentation
//! Swift language exporter configuration and main export functionality.

use std::{borrow::Cow, fmt, path::Path};

use specta::{
    Format, Types,
    datatype::{DataType, Fields, Reference},
};

use crate::Error;
use crate::primitives::{export_type, is_duration_struct};

/// Configuration for exporting a Specta [`Types`] collection as Swift source code.
///
/// Construct it with [`Swift::new`] (or [`Default`]), tweak it with the builder-style
/// setters, then call [`Swift::export`] or [`Swift::export_to`].
#[derive(Clone)]
pub struct Swift {
    /// Comment written at the very top of every generated file (omitted when empty).
    pub header: Cow<'static, str>,
    /// How generated code is indented.
    pub indent: IndentStyle,
    /// Casing rules applied to generated identifiers.
    pub naming: NamingConvention,
    /// How generic parameters and their constraints are rendered.
    pub generics: GenericStyle,
    /// How optional types are spelled.
    pub optionals: OptionalStyle,
    /// Extra protocol names configured by the user.
    ///
    /// `Swift::export` writes an `import <name>` line for each entry.
    /// NOTE(review): presumably the type exporter also uses these for conformances; confirm
    /// in `crate::primitives`.
    pub protocols: Vec<Cow<'static, str>>,
}

impl fmt::Debug for Swift {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Swift")
            .field("header", &self.header)
            .field("indent", &self.indent)
            .field("naming", &self.naming)
            .field("generics", &self.generics)
            .field("optionals", &self.optionals)
            .field("protocols", &self.protocols)
            .finish()
    }
}

/// How each indentation level of generated Swift code is written.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndentStyle {
    /// Indent with the given number of space characters per level.
    Spaces(usize),
    /// Indent with one tab character per level.
    Tabs,
}

impl Default for IndentStyle {
    /// Four spaces per indentation level.
    fn default() -> Self {
        IndentStyle::Spaces(4)
    }
}

/// Casing convention applied to generated Swift identifiers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NamingConvention {
    /// `PascalCase` — the default, matching Swift's convention for type names.
    #[default]
    PascalCase,
    /// `camelCase`.
    CamelCase,
    /// `snake_case`.
    SnakeCase,
}

/// How generic parameters and their constraints are rendered in Swift.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GenericStyle {
    /// Inline protocol constraints, e.g. `<T: Codable>` (the default).
    #[default]
    Protocol,
    /// Constraints in a trailing where clause, e.g. `<T> where T: Codable`.
    ///
    /// NOTE(review): the variant name `Typealias` does not match this description; confirm
    /// what `crate::primitives` actually emits for it.
    Typealias,
}

/// How optional types are spelled in generated Swift.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptionalStyle {
    /// Sugared form, e.g. `String?` (the default).
    #[default]
    QuestionMark,
    /// Explicit generic form, e.g. `Optional<String>`.
    Optional,
}

impl Default for Swift {
    fn default() -> Self {
        Self {
            header: "// This file has been generated by Specta. DO NOT EDIT.".into(),
            indent: IndentStyle::default(),
            naming: NamingConvention::default(),
            generics: GenericStyle::default(),
            optionals: OptionalStyle::default(),
            protocols: vec![],
        }
    }
}

impl Swift {
    /// Create a new Swift exporter with the default configuration (same as [`Default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the header comment written at the top of generated files.
    ///
    /// An empty header suppresses the header line entirely.
    #[must_use]
    pub fn header(mut self, header: impl Into<Cow<'static, str>>) -> Self {
        self.header = header.into();
        self
    }

    /// Set the indentation style.
    #[must_use]
    pub fn indent(mut self, style: IndentStyle) -> Self {
        self.indent = style;
        self
    }

    /// Set the naming convention.
    #[must_use]
    pub fn naming(mut self, convention: NamingConvention) -> Self {
        self.naming = convention;
        self
    }

    /// Set the generic type style.
    #[must_use]
    pub fn generics(mut self, style: GenericStyle) -> Self {
        self.generics = style;
        self
    }

    /// Set the optional type style.
    #[must_use]
    pub fn optionals(mut self, style: OptionalStyle) -> Self {
        self.optionals = style;
        self
    }

    /// Append a protocol name to [`Swift::protocols`].
    #[must_use]
    pub fn add_protocol(mut self, protocol: impl Into<Cow<'static, str>>) -> Self {
        self.protocols.push(protocol.into());
        self
    }

    /// Export `types` to a Swift source string.
    ///
    /// The output consists of, in order:
    /// 1. the header (if non-empty);
    /// 2. `import Foundation`;
    /// 3. one `import` line per configured protocol, then a blank line;
    /// 4. the `RustDuration` helper, when a Duration type is detected;
    /// 5. every non-empty exported type, in sorted order, each followed by a blank line.
    ///
    /// # Errors
    ///
    /// Returns an error if `format` fails to map the type graph, or if exporting any
    /// individual type fails.
    pub fn export(&self, types: &Types, format: impl Format) -> Result<String, Error> {
        // Keep the `Cow` and borrow through it: when the formatter leaves the graph untouched
        // this avoids deep-cloning the whole `Types` collection.
        let formatted_types = format_types(types, &format)?;
        let raw_types: &Types = &formatted_types;

        let mut result = String::new();

        if !self.header.is_empty() {
            result.push_str(&self.header);
            result.push('\n');
        }

        result.push_str("import Foundation\n");
        // NOTE(review): `protocols` is documented as "protocols to conform to", yet each entry
        // is emitted as a module import here (e.g. `import Codable`, which is not a valid Swift
        // module). Behavior kept as-is; confirm the intent.
        for protocol in &self.protocols {
            result.push_str("import ");
            result.push_str(protocol);
            result.push('\n');
        }
        result.push('\n');

        if needs_duration_helper(raw_types) {
            result.push_str(&generate_duration_helper());
        }

        for ndt in raw_types.into_sorted_iter() {
            let exported = export_type(self, Some(&format), raw_types, ndt)?;
            // Types that render to nothing are skipped so they don't leave stray blank lines.
            if !exported.is_empty() {
                result.push_str(&exported);
                result.push_str("\n\n");
            }
        }

        Ok(result)
    }

    /// Export `types` and write the resulting Swift source to `path`, replacing any
    /// existing file.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Swift::export`], or the I/O error from writing the file.
    pub fn export_to(
        &self,
        path: impl AsRef<Path>,
        types: &Types,
        format: impl Format,
    ) -> Result<(), Error> {
        let content = self.export(types, format)?;
        std::fs::write(path, content)?;
        Ok(())
    }
}

fn format_types<'a>(types: &'a Types, format: &'a dyn Format) -> Result<Cow<'a, Types>, Error> {
    format
        .map_types(types)
        .map_err(|err| Error::format("type graph formatter failed", err))
}

impl NamingConvention {
    /// Convert `name` to this naming convention (used for type names).
    pub fn convert(&self, name: &str) -> String {
        match self {
            Self::PascalCase => self.to_pascal_case(name),
            Self::CamelCase => self.to_camel_case(name),
            Self::SnakeCase => self.to_snake_case(name),
        }
    }

    /// Convert `name` to camelCase, regardless of the selected convention.
    pub fn convert_to_camel_case(&self, name: &str) -> String {
        self.to_camel_case(name)
    }

    /// Convert a field name according to this convention.
    ///
    /// `PascalCase` still yields camelCase here, since fields should be camelCase even when
    /// types are PascalCase.
    pub fn convert_field(&self, name: &str) -> String {
        match self {
            Self::PascalCase | Self::CamelCase => self.to_camel_case(name),
            Self::SnakeCase => self.to_snake_case(name),
        }
    }

    /// Convert an enum case name according to this convention.
    ///
    /// `PascalCase` still yields camelCase, since enum cases should be camelCase.
    pub fn convert_enum_case(&self, name: &str) -> String {
        match self {
            Self::PascalCase | Self::CamelCase => self.to_camel_case(name),
            Self::SnakeCase => self.to_snake_case(name),
        }
    }

    /// Convert snake_case / SCREAMING_SNAKE_CASE / PascalCase input to camelCase.
    ///
    /// - Names containing `_`: the first non-empty segment is lowercased. Each later segment
    ///   gets an uppercase first character and a lowercase remainder. Empty segments (leading,
    ///   trailing or doubled underscores) are dropped, so `"_foo"` becomes `"foo"`.
    /// - Names whose ASCII letters are all uppercase (e.g. `"ID"`) are lowercased wholesale.
    /// - Anything else only has its first character lowercased (`"FooBar"` → `"fooBar"`).
    // The lint wants `to_*` on a `Copy` type to take `self` by value; the `&self` receiver
    // is kept so existing call sites are unaffected.
    #[allow(clippy::wrong_self_convention)]
    fn to_camel_case(&self, name: &str) -> String {
        if name.contains('_') {
            let mut result = String::with_capacity(name.len());
            let segments = name.split('_').filter(|segment| !segment.is_empty());
            for (index, segment) in segments.enumerate() {
                if index == 0 {
                    result.push_str(&segment.to_lowercase());
                    continue;
                }
                let mut chars = segment.chars();
                if let Some(first) = chars.next() {
                    // `extend` keeps multi-character case mappings intact (e.g. 'ß' -> "SS").
                    result.extend(first.to_uppercase());
                    result.extend(chars.flat_map(char::to_lowercase));
                }
            }
            return result;
        }

        if name.chars().any(|c| c.is_ascii_alphabetic())
            && name
                .chars()
                .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_uppercase())
        {
            return name.to_ascii_lowercase();
        }

        // PascalCase (or already camelCase): lowercase only the leading character.
        let mut chars = name.chars();
        match chars.next() {
            None => String::new(),
            Some(first) => {
                let mut result = String::with_capacity(name.len());
                result.extend(first.to_lowercase());
                result.push_str(chars.as_str());
                result
            }
        }
    }

    /// Convert snake_case input to PascalCase.
    ///
    /// Each `_`-separated segment has its first character uppercased; the rest of the segment
    /// is kept as-is. Empty segments are dropped.
    #[allow(clippy::wrong_self_convention)] // `&self` receiver kept for call-site stability.
    fn to_pascal_case(&self, name: &str) -> String {
        let mut result = String::with_capacity(name.len());
        for segment in name.split('_') {
            let mut chars = segment.chars();
            if let Some(first) = chars.next() {
                result.extend(first.to_uppercase());
                result.push_str(chars.as_str());
            }
        }
        result
    }

    /// Convert camelCase / PascalCase input to snake_case.
    ///
    /// A `_` is inserted before an uppercase character when one of these holds:
    /// - it follows a non-uppercase character other than `_` (`"fooBar"` → `"foo_bar"`);
    /// - it ends an uppercase run and is followed by a lowercase character
    ///   (`"HTTPServer"` → `"http_server"`).
    ///
    /// Existing underscores are never doubled (`"foo_Bar"` → `"foo_bar"`). All characters are
    /// lowercased using the full Unicode mapping.
    #[allow(clippy::wrong_self_convention)] // `&self` receiver kept for call-site stability.
    fn to_snake_case(&self, name: &str) -> String {
        let mut result = String::with_capacity(name.len() + name.len() / 4);
        let mut chars = name.chars().peekable();
        let mut previous: Option<char> = None;

        while let Some(c) = chars.next() {
            if c.is_uppercase() {
                let starts_word = match previous {
                    None | Some('_') => false,
                    // Inside an acronym: only the last capital before a lowercase letter
                    // starts a new word ("HTTPServer" -> "http" + "server").
                    Some(p) if p.is_uppercase() => chars.peek().is_some_and(|n| n.is_lowercase()),
                    Some(_) => true,
                };
                if starts_word {
                    result.push('_');
                }
            }
            result.extend(c.to_lowercase());
            previous = Some(c);
        }

        result
    }
}

/// Return `true` when the exported Swift needs the `RustDuration` helper struct.
///
/// The helper is considered needed when either of these holds:
/// - a named type in `types` is called `Duration`;
/// - a struct with named fields has a field whose type is a named reference to a type called
///   `Duration`, or an inline struct accepted by `is_duration_struct`.
///
/// NOTE(review): only the direct field types of named-field structs are inspected. Durations
/// nested inside containers, tuple structs or enum variants are not detected here unless a
/// `Duration` type is also present at the top level of `types`.
fn needs_duration_helper(types: &Types) -> bool {
    types.into_sorted_iter().any(|ndt| {
        if ndt.name == "Duration" {
            return true;
        }

        let Some(DataType::Struct(s)) = &ndt.ty else {
            return false;
        };
        let Fields::Named(fields) = &s.fields else {
            return false;
        };

        fields
            .fields
            .iter()
            .filter_map(|(_, field)| field.ty.as_ref())
            .any(|ty| match ty {
                DataType::Reference(Reference::Named(r)) => types
                    .get(r)
                    .is_some_and(|referenced| referenced.name == "Duration"),
                DataType::Struct(struct_ty) => is_duration_struct(struct_ty),
                _ => false,
            })
    })
}

/// Build the Swift source for the `RustDuration` helper.
///
/// The helper is a `Codable` struct with `secs: UInt64` and `nanos: UInt32` fields plus a
/// computed `timeInterval`. The text is followed by a `// MARK: - Generated Types` marker
/// that introduces the exported types.
fn generate_duration_helper() -> String {
    // Assembled at compile time; the blank line inside the struct deliberately keeps its four
    // spaces so the output stays byte-identical.
    const HELPER: &str = concat!(
        "// MARK: - Duration Helper\n",
        "/// Helper struct to decode Rust Duration format {\"secs\": u64, \"nanos\": u32}\n",
        "public struct RustDuration: Codable {\n",
        "    public let secs: UInt64\n",
        "    public let nanos: UInt32\n",
        "    \n",
        "    public var timeInterval: TimeInterval {\n",
        "        return Double(secs) + Double(nanos) / 1_000_000_000.0\n",
        "    }\n",
        "}\n\n",
        "// MARK: - Generated Types\n\n",
    );
    HELPER.to_owned()
}