use std::{borrow::Cow, fmt, path::Path};
use specta::{
Format, Types,
datatype::{DataType, Fields, Reference},
};
use crate::Error;
use crate::primitives::{export_type, is_duration_struct};
/// Configuration for exporting specta types as Swift source.
///
/// Build one with `Swift::new()` (or `Default`) and adjust it with the builder
/// methods, then call `export` / `export_to`.
#[derive(Clone)]
pub struct Swift {
    /// Text written as the first line of the output; omitted entirely when empty.
    pub header: Cow<'static, str>,
    /// Indentation setting (NOTE(review): read by the type exporter, not visible here).
    pub indent: IndentStyle,
    /// Casing applied to generated identifiers.
    pub naming: NamingConvention,
    /// Generic-rendering setting (NOTE(review): read by the type exporter, not visible here).
    pub generics: GenericStyle,
    /// Optional-rendering setting (NOTE(review): read by the type exporter, not visible here).
    pub optionals: OptionalStyle,
    /// Extra module names; each is written as an `import <name>` line after `import Foundation`.
    pub protocols: Vec<Cow<'static, str>>,
}
impl fmt::Debug for Swift {
    /// Formats the exporter as `Swift { header, indent, naming, generics, optionals, protocols }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to `Swift` stops this from compiling
        // until the new field is also included in the debug output.
        let Self {
            header,
            indent,
            naming,
            generics,
            optionals,
            protocols,
        } = self;

        let mut out = f.debug_struct("Swift");
        out.field("header", header);
        out.field("indent", indent);
        out.field("naming", naming);
        out.field("generics", generics);
        out.field("optionals", optionals);
        out.field("protocols", protocols);
        out.finish()
    }
}
/// Indentation setting for generated Swift code. The default is `Spaces(4)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndentStyle {
    /// Indent with spaces; the `usize` is presumably the width per level (TODO confirm in the exporter).
    Spaces(usize),
    /// Indent with tab characters (presumably one per level — TODO confirm).
    Tabs,
}
impl Default for IndentStyle {
fn default() -> Self {
Self::Spaces(4)
}
}
/// Casing applied to generated identifiers.
///
/// `convert` applies the selected casing directly. `convert_field` and
/// `convert_enum_case` produce camelCase for both `PascalCase` and `CamelCase`,
/// and snake_case for `SnakeCase`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NamingConvention {
    /// `PascalCase` via `convert` (the default).
    #[default]
    PascalCase,
    /// `camelCase` via `convert`.
    CamelCase,
    /// `snake_case` for every conversion.
    SnakeCase,
}
/// How generic types are rendered in Swift; `Protocol` is the default.
///
/// NOTE(review): the interpretation of each variant lives in the type exporter,
/// which is not visible here — confirm the exact output there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GenericStyle {
    /// Presumably expresses generics via Swift protocols (TODO confirm).
    #[default]
    Protocol,
    /// Presumably expresses generics via `typealias` declarations (TODO confirm).
    Typealias,
}
/// How optional types are spelled in Swift; `QuestionMark` is the default.
///
/// NOTE(review): the rendering is done by the type exporter, not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptionalStyle {
    /// Presumably the `T?` shorthand (TODO confirm).
    #[default]
    QuestionMark,
    /// Presumably the explicit `Optional<T>` form (TODO confirm).
    Optional,
}
impl Default for Swift {
fn default() -> Self {
Self {
header: "// This file has been generated by Specta. DO NOT EDIT.".into(),
indent: IndentStyle::default(),
naming: NamingConvention::default(),
generics: GenericStyle::default(),
optionals: OptionalStyle::default(),
protocols: vec![],
}
}
}
impl Swift {
    /// Creates an exporter with the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the header text written at the top of the output; an empty header is omitted.
    pub fn header(mut self, header: impl Into<Cow<'static, str>>) -> Self {
        self.header = header.into();
        self
    }

    /// Sets the indentation style.
    pub fn indent(mut self, style: IndentStyle) -> Self {
        self.indent = style;
        self
    }

    /// Sets the naming convention used for generated identifiers.
    pub fn naming(mut self, convention: NamingConvention) -> Self {
        self.naming = convention;
        self
    }

    /// Sets how generic types are rendered.
    pub fn generics(mut self, style: GenericStyle) -> Self {
        self.generics = style;
        self
    }

    /// Sets how optional types are rendered.
    pub fn optionals(mut self, style: OptionalStyle) -> Self {
        self.optionals = style;
        self
    }

    /// Appends `protocol` to the list of extra modules; `export` writes each one
    /// as an `import <protocol>` line directly after `import Foundation`.
    pub fn add_protocol(mut self, protocol: impl Into<Cow<'static, str>>) -> Self {
        self.protocols.push(protocol.into());
        self
    }

    /// Renders every type in `types` as Swift source.
    ///
    /// The output consists of:
    /// 1. the header line (skipped when the header is empty),
    /// 2. `import Foundation` and one `import` line per added protocol,
    /// 3. a blank line,
    /// 4. the `RustDuration` helper, when any type needs it,
    /// 5. each non-empty exported type, in the order yielded by `into_sorted_iter`,
    ///    followed by a blank line.
    ///
    /// # Errors
    ///
    /// Returns an error if `format` fails to map the type graph, or if exporting
    /// any individual type fails.
    pub fn export(&self, types: &Types, format: impl Format) -> Result<String, Error> {
        // Keep the formatter's result as a `Cow`. When it hands the graph back
        // borrowed, this avoids deep-copying the entire type graph.
        let formatted = format_types(types, &format)?;
        let raw_types: &Types = &formatted;

        let mut result = String::new();
        if !self.header.is_empty() {
            result.push_str(&self.header);
            result.push('\n');
        }
        result.push_str("import Foundation\n");
        for protocol in &self.protocols {
            result.push_str("import ");
            result.push_str(protocol);
            result.push('\n');
        }
        result.push('\n');

        if needs_duration_helper(raw_types) {
            result.push_str(&generate_duration_helper());
        }

        for ndt in raw_types.into_sorted_iter() {
            let exported = export_type(self, Some(&format), raw_types, ndt)?;
            // Some types export to nothing; don't leave stray blank lines for them.
            if !exported.is_empty() {
                result.push_str(&exported);
                result.push_str("\n\n");
            }
        }
        Ok(result)
    }

    /// Renders `types` with [`Swift::export`] and writes the result to `path`,
    /// creating or truncating the file.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Swift::export`], or the I/O error from writing the file.
    pub fn export_to(
        &self,
        path: impl AsRef<Path>,
        types: &Types,
        format: impl Format,
    ) -> Result<(), Error> {
        let content = self.export(types, format)?;
        std::fs::write(path, content)?;
        Ok(())
    }
}
fn format_types<'a>(types: &'a Types, format: &'a dyn Format) -> Result<Cow<'a, Types>, Error> {
format
.map_types(types)
.map_err(|err| Error::format("type graph formatter failed", err))
}
impl NamingConvention {
pub fn convert(&self, name: &str) -> String {
match self {
Self::PascalCase => self.to_pascal_case(name),
Self::CamelCase => self.to_camel_case(name),
Self::SnakeCase => self.to_snake_case(name),
}
}
pub fn convert_to_camel_case(&self, name: &str) -> String {
self.to_camel_case(name)
}
pub fn convert_field(&self, name: &str) -> String {
match self {
Self::PascalCase => self.to_camel_case(name), Self::CamelCase => self.to_camel_case(name),
Self::SnakeCase => self.to_snake_case(name),
}
}
pub fn convert_enum_case(&self, name: &str) -> String {
match self {
Self::PascalCase => self.to_camel_case(name), Self::CamelCase => self.to_camel_case(name),
Self::SnakeCase => self.to_snake_case(name),
}
}
#[allow(clippy::wrong_self_convention)]
fn to_camel_case(&self, name: &str) -> String {
if name.contains('_') {
let parts: Vec<&str> = name.split('_').collect();
if parts.is_empty() {
return name.to_string();
}
let mut result = String::new();
for (i, part) in parts.iter().enumerate() {
if i == 0 {
result.push_str(&part.to_lowercase());
} else {
let mut chars = part.chars();
match chars.next() {
None => continue,
Some(first) => {
result.push(first.to_uppercase().next().unwrap_or(first));
for c in chars {
result.extend(c.to_lowercase());
}
}
}
}
}
result
} else {
if name.chars().any(|c| c.is_ascii_alphabetic())
&& name
.chars()
.all(|c| !c.is_ascii_alphabetic() || c.is_ascii_uppercase())
{
return name.to_ascii_lowercase();
}
let mut chars = name.chars();
match chars.next() {
None => name.to_string(),
Some(first) => {
let mut result = String::new();
result.push(first.to_lowercase().next().unwrap_or(first));
for c in chars {
result.push(c); }
result
}
}
}
}
#[allow(clippy::wrong_self_convention)]
fn to_pascal_case(&self, name: &str) -> String {
name.split('_')
.map(|part| {
let mut chars = part.chars();
match chars.next() {
None => String::new(),
Some(first) => first.to_uppercase().chain(chars).collect(),
}
})
.collect()
}
#[allow(clippy::wrong_self_convention)]
fn to_snake_case(&self, name: &str) -> String {
let mut result = String::new();
let chars = name.chars();
for c in chars {
if c.is_uppercase() && !result.is_empty() {
result.push('_');
}
result.push(c.to_lowercase().next().unwrap_or(c));
}
result
}
}
fn needs_duration_helper(types: &Types) -> bool {
for ndt in types.into_sorted_iter() {
if ndt.name == "Duration" {
return true;
}
if let Some(DataType::Struct(s)) = &ndt.ty
&& let Fields::Named(fields) = &s.fields
{
for (_, field) in &fields.fields {
if let Some(ty) = field.ty.as_ref() {
if let DataType::Reference(Reference::Named(r)) = ty
&& let Some(referenced_ndt) = types.get(r)
&& referenced_ndt.name == "Duration"
{
return true;
}
if let DataType::Struct(struct_ty) = ty
&& is_duration_struct(struct_ty)
{
return true;
}
}
}
}
}
false
}
/// Returns the Swift source for the `RustDuration` decoding helper. The helper is
/// followed by the `// MARK: - Generated Types` marker that precedes the exported types.
fn generate_duration_helper() -> String {
    const HELPER: &str = concat!(
        "// MARK: - Duration Helper\n",
        "/// Helper struct to decode Rust Duration format {\"secs\": u64, \"nanos\": u32}\n",
        "public struct RustDuration: Codable {\n",
        " public let secs: UInt64\n",
        " public let nanos: UInt32\n",
        " \n",
        " public var timeInterval: TimeInterval {\n",
        " return Double(secs) + Double(nanos) / 1_000_000_000.0\n",
        " }\n",
        "}\n\n",
        "// MARK: - Generated Types\n\n",
    );
    HELPER.to_owned()
}