Skip to main content

rdbi_codegen/codegen/
type_resolver.rs

1//! MySQL to Rust type mapping
2
3use super::naming;
4use crate::parser::ColumnMetadata;
5
/// Represents a Rust type for code generation.
///
/// Values are rendered into generated source with [`RustType::to_type_string`] (owned
/// field/return types) or [`RustType::to_param_type_string`] (borrowed parameter types).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RustType {
    /// `bool`
    Bool,
    /// `i8`
    I8,
    /// `i16`
    I16,
    /// `i32`
    I32,
    /// `i64`
    I64,
    /// `u8`
    U8,
    /// `u16`
    U16,
    /// `u32`
    U32,
    /// `u64`
    U64,
    /// `f32`
    F32,
    /// `f64`
    F64,
    /// `String` (borrowed as `&str` in parameters)
    String,
    /// `Vec<u8>` (borrowed as `&[u8]` in parameters)
    Bytes,
    /// `rust_decimal::Decimal`
    Decimal,
    /// `chrono::NaiveDate`
    NaiveDate,
    /// `chrono::NaiveDateTime`
    NaiveDateTime,
    /// `chrono::NaiveTime`
    NaiveTime,
    /// `serde_json::Value`
    Json,
    /// Custom enum type with the enum name
    Enum(String),
    /// Optional wrapper
    Option(Box<RustType>),
}

impl RustType {
    /// Get the type string for code generation (owned form, e.g. `String`, `Option<i32>`).
    pub fn to_type_string(&self) -> String {
        let primitive = match self {
            RustType::Bool => "bool",
            RustType::I8 => "i8",
            RustType::I16 => "i16",
            RustType::I32 => "i32",
            RustType::I64 => "i64",
            RustType::U8 => "u8",
            RustType::U16 => "u16",
            RustType::U32 => "u32",
            RustType::U64 => "u64",
            RustType::F32 => "f32",
            RustType::F64 => "f64",
            RustType::String => "String",
            RustType::Bytes => "Vec<u8>",
            RustType::Decimal => "rust_decimal::Decimal",
            RustType::NaiveDate => "chrono::NaiveDate",
            RustType::NaiveDateTime => "chrono::NaiveDateTime",
            RustType::NaiveTime => "chrono::NaiveTime",
            RustType::Json => "serde_json::Value",
            RustType::Enum(name) => return name.clone(),
            RustType::Option(inner) => return format!("Option<{}>", inner.to_type_string()),
        };
        primitive.to_string()
    }

    /// Get the type string for function parameters.
    ///
    /// `String` and `Bytes` are borrowed (`&str`, `&[u8]`), including directly inside a
    /// single `Option`. Every other type, including nested options, uses the owned form
    /// from [`RustType::to_type_string`].
    pub fn to_param_type_string(&self) -> String {
        match self {
            RustType::String => "&str".to_string(),
            RustType::Bytes => "&[u8]".to_string(),
            RustType::Option(inner) if matches!(**inner, RustType::String | RustType::Bytes) => {
                format!("Option<{}>", inner.to_param_type_string())
            }
            _ => self.to_type_string(),
        }
    }

    /// Strip every `Option` layer and return the innermost type.
    ///
    /// Shared by the `needs_*` checks so that arbitrarily nested options are inspected
    /// all the way down.
    fn innermost(&self) -> &RustType {
        let mut ty = self;
        while let RustType::Option(inner) = ty {
            ty = inner.as_ref();
        }
        ty
    }

    /// Check if this type needs the chrono crate.
    pub fn needs_chrono(&self) -> bool {
        matches!(
            self.innermost(),
            RustType::NaiveDate | RustType::NaiveDateTime | RustType::NaiveTime
        )
    }

    /// Check if this type needs the rust_decimal crate.
    pub fn needs_decimal(&self) -> bool {
        matches!(self.innermost(), RustType::Decimal)
    }

    /// Check if this type needs the serde_json crate.
    pub fn needs_serde_json(&self) -> bool {
        matches!(self.innermost(), RustType::Json)
    }

    /// Get the type wrapped by one `Option` layer, or `self` if this is not an `Option`.
    pub fn inner_type(&self) -> &RustType {
        match self {
            RustType::Option(inner) => inner,
            _ => self,
        }
    }

    /// Check if this is an Option type.
    pub fn is_optional(&self) -> bool {
        matches!(self, RustType::Option(_))
    }
}

115/// Resolve MySQL data types to Rust types
116pub struct TypeResolver;
117
118impl TypeResolver {
119    /// Get the Rust type for a column
120    pub fn resolve(column: &ColumnMetadata, table_name: &str) -> RustType {
121        let base_type = Self::resolve_base_type(column, table_name);
122
123        if column.nullable {
124            RustType::Option(Box::new(base_type))
125        } else {
126            base_type
127        }
128    }
129
130    /// Resolve the base type (without Option wrapper)
131    fn resolve_base_type(column: &ColumnMetadata, table_name: &str) -> RustType {
132        let data_type = column.data_type.to_uppercase();
133
134        // Check for enum first
135        if column.is_enum() {
136            let enum_name = naming::to_enum_name(table_name, &column.name);
137            return RustType::Enum(enum_name);
138        }
139
140        // Parse the data type
141        let data_type_lower = data_type.to_lowercase();
142
143        // Boolean types
144        if Self::is_boolean_type(&data_type_lower, &column.data_type) {
145            return RustType::Bool;
146        }
147
148        // Integer types
149        if data_type_lower.starts_with("tinyint") {
150            return if column.is_unsigned {
151                RustType::U8
152            } else {
153                RustType::I8
154            };
155        }
156        if data_type_lower.starts_with("smallint") {
157            return if column.is_unsigned {
158                RustType::U16
159            } else {
160                RustType::I16
161            };
162        }
163        if data_type_lower.starts_with("mediumint") || data_type_lower.starts_with("int") {
164            return if column.is_unsigned {
165                RustType::U32
166            } else {
167                RustType::I32
168            };
169        }
170        if data_type_lower.starts_with("bigint") {
171            return if column.is_unsigned {
172                RustType::U64
173            } else {
174                RustType::I64
175            };
176        }
177
178        // Float types
179        if data_type_lower.starts_with("float") {
180            return RustType::F32;
181        }
182        if data_type_lower.starts_with("double") || data_type_lower.starts_with("real") {
183            return RustType::F64;
184        }
185
186        // Decimal types
187        if data_type_lower.starts_with("decimal") || data_type_lower.starts_with("numeric") {
188            return RustType::Decimal;
189        }
190
191        // String types
192        if data_type_lower.starts_with("char")
193            || data_type_lower.starts_with("varchar")
194            || data_type_lower.contains("text")
195            || data_type_lower.starts_with("enum")
196            || data_type_lower.starts_with("set")
197        {
198            return RustType::String;
199        }
200
201        // Binary types
202        if data_type_lower.starts_with("binary")
203            || data_type_lower.starts_with("varbinary")
204            || data_type_lower.contains("blob")
205            || (data_type_lower.starts_with("bit") && !Self::is_bit_1(&data_type_lower))
206        {
207            return RustType::Bytes;
208        }
209
210        // Date/time types
211        if data_type_lower == "date" {
212            return RustType::NaiveDate;
213        }
214        if data_type_lower.starts_with("datetime") || data_type_lower.starts_with("timestamp") {
215            return RustType::NaiveDateTime;
216        }
217        if data_type_lower == "time" {
218            return RustType::NaiveTime;
219        }
220
221        // JSON type
222        if data_type_lower == "json" {
223            return RustType::Json;
224        }
225
226        // Spatial types -> bytes
227        if data_type_lower.starts_with("geometry")
228            || data_type_lower.starts_with("point")
229            || data_type_lower.starts_with("linestring")
230            || data_type_lower.starts_with("polygon")
231            || data_type_lower.starts_with("multi")
232            || data_type_lower.starts_with("geometrycollection")
233        {
234            return RustType::Bytes;
235        }
236
237        // Default fallback
238        RustType::String
239    }
240
241    /// Check if the type represents a boolean
242    fn is_boolean_type(data_type_lower: &str, original: &str) -> bool {
243        // BOOL or BOOLEAN
244        if data_type_lower == "bool" || data_type_lower == "boolean" {
245            return true;
246        }
247
248        // TINYINT(1) is typically used as boolean in MySQL
249        if data_type_lower.starts_with("tinyint") {
250            // Check for (1) specifically
251            if original.contains("(1)") || data_type_lower.contains("(1)") {
252                return true;
253            }
254        }
255
256        // BIT(1)
257        if Self::is_bit_1(data_type_lower) {
258            return true;
259        }
260
261        false
262    }
263
264    /// Check if this is a BIT(1) type
265    fn is_bit_1(data_type_lower: &str) -> bool {
266        data_type_lower.starts_with("bit") && data_type_lower.contains("(1)")
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal `ColumnMetadata` for resolver tests.
    fn make_column(name: &str, data_type: &str, nullable: bool, unsigned: bool) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable,
            default_value: None,
            is_auto_increment: false,
            is_unsigned: unsigned,
            enum_values: None,
            comment: None,
        }
    }

    /// Resolve a non-nullable, signed-unless-specified column on the `users` table.
    fn resolve_type(data_type: &str, unsigned: bool) -> RustType {
        TypeResolver::resolve(&make_column("col", data_type, false, unsigned), "users")
    }

    #[test]
    fn test_integer_types() {
        let col = make_column("id", "BIGINT", false, false);
        assert_eq!(TypeResolver::resolve(&col, "users"), RustType::I64);

        let col = make_column("id", "BIGINT", false, true);
        assert_eq!(TypeResolver::resolve(&col, "users"), RustType::U64);

        let col = make_column("count", "INT", false, false);
        assert_eq!(TypeResolver::resolve(&col, "users"), RustType::I32);
    }

    #[test]
    fn test_small_and_unsigned_integer_types() {
        assert_eq!(resolve_type("TINYINT", true), RustType::U8);
        assert_eq!(resolve_type("TINYINT(4)", false), RustType::I8);
        assert_eq!(resolve_type("SMALLINT", true), RustType::U16);
        assert_eq!(resolve_type("MEDIUMINT", false), RustType::I32);
        assert_eq!(resolve_type("INTEGER", true), RustType::U32);
    }

    #[test]
    fn test_boolean_type() {
        let col = make_column("active", "TINYINT(1)", false, false);
        assert_eq!(TypeResolver::resolve(&col, "users"), RustType::Bool);

        let col = make_column("flag", "BOOL", false, false);
        assert_eq!(TypeResolver::resolve(&col, "users"), RustType::Bool);
    }

    #[test]
    fn test_bit_types() {
        assert_eq!(resolve_type("BIT(1)", false), RustType::Bool);
        assert_eq!(resolve_type("BIT(8)", false), RustType::Bytes);
    }

    #[test]
    fn test_numeric_types() {
        assert_eq!(resolve_type("FLOAT", false), RustType::F32);
        assert_eq!(resolve_type("DOUBLE", false), RustType::F64);
        assert_eq!(resolve_type("DECIMAL(10,2)", false), RustType::Decimal);
        assert_eq!(resolve_type("NUMERIC", false), RustType::Decimal);
    }

    #[test]
    fn test_string_types() {
        let col = make_column("name", "VARCHAR(255)", false, false);
        assert_eq!(TypeResolver::resolve(&col, "users"), RustType::String);

        let col = make_column("bio", "TEXT", true, false);
        assert_eq!(
            TypeResolver::resolve(&col, "users"),
            RustType::Option(Box::new(RustType::String))
        );
    }

    #[test]
    fn test_binary_json_and_spatial_types() {
        assert_eq!(resolve_type("BLOB", false), RustType::Bytes);
        assert_eq!(resolve_type("VARBINARY(16)", false), RustType::Bytes);
        assert_eq!(resolve_type("JSON", false), RustType::Json);
        assert_eq!(resolve_type("POINT", false), RustType::Bytes);
        assert_eq!(resolve_type("MULTIPOLYGON", false), RustType::Bytes);
    }

    #[test]
    fn test_datetime_types() {
        let col = make_column("created_at", "DATETIME", true, false);
        assert_eq!(
            TypeResolver::resolve(&col, "users"),
            RustType::Option(Box::new(RustType::NaiveDateTime))
        );

        let col = make_column("birth_date", "DATE", false, false);
        assert_eq!(TypeResolver::resolve(&col, "users"), RustType::NaiveDate);
    }

    #[test]
    fn test_time_and_timestamp_types() {
        assert_eq!(resolve_type("TIME", false), RustType::NaiveTime);
        assert_eq!(resolve_type("TIMESTAMP", false), RustType::NaiveDateTime);
        assert_eq!(resolve_type("DATETIME(6)", false), RustType::NaiveDateTime);
    }

    #[test]
    fn test_unknown_type_falls_back_to_string() {
        assert_eq!(resolve_type("UNKNOWN_TYPE", false), RustType::String);
    }

    #[test]
    fn test_enum_type() {
        let mut col = make_column("status", "ENUM", false, false);
        col.enum_values = Some(vec!["ACTIVE".to_string(), "INACTIVE".to_string()]);
        assert_eq!(
            TypeResolver::resolve(&col, "users"),
            RustType::Enum("UsersStatus".to_string())
        );
    }

    #[test]
    fn test_nullable() {
        let col = make_column("optional", "INT", true, false);
        assert_eq!(
            TypeResolver::resolve(&col, "users"),
            RustType::Option(Box::new(RustType::I32))
        );
    }

    #[test]
    fn test_type_string() {
        assert_eq!(RustType::I64.to_type_string(), "i64");
        assert_eq!(
            RustType::Option(Box::new(RustType::String)).to_type_string(),
            "Option<String>"
        );
        assert_eq!(
            RustType::NaiveDateTime.to_type_string(),
            "chrono::NaiveDateTime"
        );
    }

    #[test]
    fn test_param_type_string() {
        assert_eq!(RustType::String.to_param_type_string(), "&str");
        assert_eq!(
            RustType::Option(Box::new(RustType::String)).to_param_type_string(),
            "Option<&str>"
        );
        assert_eq!(RustType::I64.to_param_type_string(), "i64");
        assert_eq!(RustType::Bytes.to_param_type_string(), "&[u8]");
        assert_eq!(
            RustType::Option(Box::new(RustType::Bytes)).to_param_type_string(),
            "Option<&[u8]>"
        );
        assert_eq!(
            RustType::Option(Box::new(RustType::I32)).to_param_type_string(),
            "Option<i32>"
        );
    }

    #[test]
    fn test_needs_crate_helpers() {
        let date = RustType::Option(Box::new(RustType::NaiveDate));
        assert!(date.needs_chrono());
        assert!(!date.needs_decimal());
        assert!(RustType::Decimal.needs_decimal());
        assert!(RustType::Option(Box::new(RustType::Json)).needs_serde_json());
        assert!(!RustType::I32.needs_serde_json());
    }

    #[test]
    fn test_inner_type_and_is_optional() {
        let opt = RustType::Option(Box::new(RustType::U16));
        assert!(opt.is_optional());
        assert_eq!(opt.inner_type(), &RustType::U16);
        assert!(!RustType::U16.is_optional());
        assert_eq!(RustType::U16.inner_type(), &RustType::U16);
    }
}