Skip to main content

rdbi_codegen/codegen/
type_resolver.rs

1//! MySQL to Rust type mapping
2
3use super::naming;
4use crate::parser::ColumnMetadata;
5
/// A Rust type that generated code can refer to.
///
/// Every variant stands for exactly one spelled-out Rust type; the spelling
/// used in generated source is produced by [`RustType::to_type_string`].
#[derive(Debug, Clone, PartialEq)]
pub enum RustType {
    Bool,
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    F32,
    F64,
    String,
    Bytes,
    Decimal,
    NaiveDate,
    NaiveDateTime,
    NaiveTime,
    Json,
    /// A generated enum, carried by its type name
    Enum(String),
    /// `Option<T>` around another type
    Option(Box<RustType>),
}

impl RustType {
    /// Strip every surrounding `Option` layer and return the type underneath.
    fn peel_options(&self) -> &RustType {
        let mut current = self;
        while let RustType::Option(inner) = current {
            current = &**inner;
        }
        current
    }

    /// Render the owned form of this type as it is written in generated code.
    pub fn to_type_string(&self) -> String {
        // The two data-carrying variants build their string dynamically;
        // everything else maps onto a fixed name.
        let fixed_name = match self {
            RustType::Enum(name) => return name.clone(),
            RustType::Option(inner) => return format!("Option<{}>", inner.to_type_string()),
            RustType::Bool => "bool",
            RustType::I8 => "i8",
            RustType::I16 => "i16",
            RustType::I32 => "i32",
            RustType::I64 => "i64",
            RustType::U8 => "u8",
            RustType::U16 => "u16",
            RustType::U32 => "u32",
            RustType::U64 => "u64",
            RustType::F32 => "f32",
            RustType::F64 => "f64",
            RustType::String => "String",
            RustType::Bytes => "Vec<u8>",
            RustType::Decimal => "rust_decimal::Decimal",
            RustType::NaiveDate => "chrono::NaiveDate",
            RustType::NaiveDateTime => "chrono::NaiveDateTime",
            RustType::NaiveTime => "chrono::NaiveTime",
            RustType::Json => "serde_json::Value",
        };
        fixed_name.to_owned()
    }

    /// Render this type as a function parameter.
    ///
    /// `String` and `Bytes` are passed as borrowed slices (`&str`, `&[u8]`),
    /// also when wrapped directly in a single `Option`. Every other type is
    /// rendered exactly like [`RustType::to_type_string`].
    pub fn to_param_type_string(&self) -> String {
        match self {
            RustType::String => "&str".to_owned(),
            RustType::Bytes => "&[u8]".to_owned(),
            RustType::Option(inner) => {
                // Only the directly wrapped type is borrowed; deeper nesting
                // keeps the owned spelling.
                let wrapped = match &**inner {
                    RustType::String => "&str".to_owned(),
                    RustType::Bytes => "&[u8]".to_owned(),
                    other => other.to_type_string(),
                };
                format!("Option<{}>", wrapped)
            }
            other => other.to_type_string(),
        }
    }

    /// Whether the rendered type refers to the `chrono` crate.
    pub fn needs_chrono(&self) -> bool {
        matches!(
            self.peel_options(),
            RustType::NaiveDate | RustType::NaiveDateTime | RustType::NaiveTime
        )
    }

    /// Whether the rendered type refers to the `rust_decimal` crate.
    pub fn needs_decimal(&self) -> bool {
        matches!(self.peel_options(), RustType::Decimal)
    }

    /// Whether the rendered type refers to the `serde_json` crate.
    pub fn needs_serde_json(&self) -> bool {
        matches!(self.peel_options(), RustType::Json)
    }

    /// Return the type wrapped by an `Option`, or `self` when not optional.
    ///
    /// Only one `Option` layer is removed.
    pub fn inner_type(&self) -> &RustType {
        if let RustType::Option(inner) = self {
            &**inner
        } else {
            self
        }
    }

    /// Whether this type is an `Option`.
    pub fn is_optional(&self) -> bool {
        matches!(self, RustType::Option(_))
    }

    /// Whether the rendered type is treated as `Copy`.
    ///
    /// `String`, `Bytes` (`Vec<u8>`) and `Json` (`serde_json::Value`) are the
    /// only non-`Copy` types; an `Option` is `Copy` exactly when the type it
    /// wraps is.
    pub fn is_copy(&self) -> bool {
        !matches!(
            self.peel_options(),
            RustType::String | RustType::Bytes | RustType::Json
        )
    }
}
126
127/// Resolve MySQL data types to Rust types
128pub struct TypeResolver;
129
130impl TypeResolver {
131    /// Get the Rust type for a column
132    pub fn resolve(column: &ColumnMetadata, table_name: &str) -> RustType {
133        let base_type = Self::resolve_base_type(column, table_name);
134
135        if column.nullable {
136            RustType::Option(Box::new(base_type))
137        } else {
138            base_type
139        }
140    }
141
142    /// Resolve the base type (without Option wrapper)
143    fn resolve_base_type(column: &ColumnMetadata, table_name: &str) -> RustType {
144        let data_type = column.data_type.to_uppercase();
145
146        // Check for enum first
147        if column.is_enum() {
148            let enum_name = naming::to_enum_name(table_name, &column.name);
149            return RustType::Enum(enum_name);
150        }
151
152        // Parse the data type
153        let data_type_lower = data_type.to_lowercase();
154
155        // Boolean types
156        if Self::is_boolean_type(&data_type_lower, &column.data_type) {
157            return RustType::Bool;
158        }
159
160        // Integer types
161        if data_type_lower.starts_with("tinyint") {
162            return if column.is_unsigned {
163                RustType::U8
164            } else {
165                RustType::I8
166            };
167        }
168        if data_type_lower.starts_with("smallint") {
169            return if column.is_unsigned {
170                RustType::U16
171            } else {
172                RustType::I16
173            };
174        }
175        if data_type_lower.starts_with("mediumint") || data_type_lower.starts_with("int") {
176            return if column.is_unsigned {
177                RustType::U32
178            } else {
179                RustType::I32
180            };
181        }
182        if data_type_lower.starts_with("bigint") {
183            return if column.is_unsigned {
184                RustType::U64
185            } else {
186                RustType::I64
187            };
188        }
189
190        // Float types
191        if data_type_lower.starts_with("float") {
192            return RustType::F32;
193        }
194        if data_type_lower.starts_with("double") || data_type_lower.starts_with("real") {
195            return RustType::F64;
196        }
197
198        // Decimal types
199        if data_type_lower.starts_with("decimal") || data_type_lower.starts_with("numeric") {
200            return RustType::Decimal;
201        }
202
203        // String types
204        if data_type_lower.starts_with("char")
205            || data_type_lower.starts_with("varchar")
206            || data_type_lower.contains("text")
207            || data_type_lower.starts_with("enum")
208            || data_type_lower.starts_with("set")
209        {
210            return RustType::String;
211        }
212
213        // Binary types
214        if data_type_lower.starts_with("binary")
215            || data_type_lower.starts_with("varbinary")
216            || data_type_lower.contains("blob")
217            || (data_type_lower.starts_with("bit") && !Self::is_bit_1(&data_type_lower))
218        {
219            return RustType::Bytes;
220        }
221
222        // Date/time types
223        if data_type_lower == "date" {
224            return RustType::NaiveDate;
225        }
226        if data_type_lower.starts_with("datetime") || data_type_lower.starts_with("timestamp") {
227            return RustType::NaiveDateTime;
228        }
229        if data_type_lower == "time" {
230            return RustType::NaiveTime;
231        }
232
233        // JSON type
234        if data_type_lower == "json" {
235            return RustType::Json;
236        }
237
238        // Spatial types -> bytes
239        if data_type_lower.starts_with("geometry")
240            || data_type_lower.starts_with("point")
241            || data_type_lower.starts_with("linestring")
242            || data_type_lower.starts_with("polygon")
243            || data_type_lower.starts_with("multi")
244            || data_type_lower.starts_with("geometrycollection")
245        {
246            return RustType::Bytes;
247        }
248
249        // Default fallback
250        RustType::String
251    }
252
253    /// Check if the type represents a boolean
254    fn is_boolean_type(data_type_lower: &str, original: &str) -> bool {
255        // BOOL or BOOLEAN
256        if data_type_lower == "bool" || data_type_lower == "boolean" {
257            return true;
258        }
259
260        // TINYINT(1) is typically used as boolean in MySQL
261        if data_type_lower.starts_with("tinyint") {
262            // Check for (1) specifically
263            if original.contains("(1)") || data_type_lower.contains("(1)") {
264                return true;
265            }
266        }
267
268        // BIT(1)
269        if Self::is_bit_1(data_type_lower) {
270            return true;
271        }
272
273        false
274    }
275
276    /// Check if this is a BIT(1) type
277    fn is_bit_1(data_type_lower: &str) -> bool {
278        data_type_lower.starts_with("bit") && data_type_lower.contains("(1)")
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a column with the given name, SQL type, nullability and
    /// signedness; every other field is left empty or disabled.
    fn make_column(name: &str, data_type: &str, nullable: bool, unsigned: bool) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_owned(),
            data_type: data_type.to_owned(),
            nullable,
            is_unsigned: unsigned,
            is_auto_increment: false,
            default_value: None,
            enum_values: None,
            comment: None,
        }
    }

    /// Resolve a freshly built column as if it belonged to the `users` table.
    fn resolve_users(name: &str, data_type: &str, nullable: bool, unsigned: bool) -> RustType {
        TypeResolver::resolve(&make_column(name, data_type, nullable, unsigned), "users")
    }

    /// Shorthand for `RustType::Option(Box::new(..))`.
    fn optional(inner: RustType) -> RustType {
        RustType::Option(Box::new(inner))
    }

    #[test]
    fn test_integer_types() {
        assert_eq!(resolve_users("id", "BIGINT", false, false), RustType::I64);
        assert_eq!(resolve_users("id", "BIGINT", false, true), RustType::U64);
        assert_eq!(resolve_users("count", "INT", false, false), RustType::I32);
    }

    #[test]
    fn test_boolean_type() {
        assert_eq!(
            resolve_users("active", "TINYINT(1)", false, false),
            RustType::Bool
        );
        assert_eq!(resolve_users("flag", "BOOL", false, false), RustType::Bool);
    }

    #[test]
    fn test_string_types() {
        assert_eq!(
            resolve_users("name", "VARCHAR(255)", false, false),
            RustType::String
        );
        assert_eq!(
            resolve_users("bio", "TEXT", true, false),
            optional(RustType::String)
        );
    }

    #[test]
    fn test_datetime_types() {
        assert_eq!(
            resolve_users("created_at", "DATETIME", true, false),
            optional(RustType::NaiveDateTime)
        );
        assert_eq!(
            resolve_users("birth_date", "DATE", false, false),
            RustType::NaiveDate
        );
    }

    #[test]
    fn test_enum_type() {
        // Same column as the other tests build, but with enum values attached.
        let status = ColumnMetadata {
            enum_values: Some(vec!["ACTIVE".to_string(), "INACTIVE".to_string()]),
            ..make_column("status", "ENUM", false, false)
        };
        let expected = RustType::Enum("UsersStatus".to_string());
        assert_eq!(TypeResolver::resolve(&status, "users"), expected);
    }

    #[test]
    fn test_nullable() {
        assert_eq!(
            resolve_users("optional", "INT", true, false),
            optional(RustType::I32)
        );
    }

    #[test]
    fn test_type_string() {
        assert_eq!(RustType::I64.to_type_string(), "i64");
        assert_eq!(
            optional(RustType::String).to_type_string(),
            "Option<String>"
        );
        assert_eq!(
            RustType::NaiveDateTime.to_type_string(),
            "chrono::NaiveDateTime"
        );
    }

    #[test]
    fn test_param_type_string() {
        assert_eq!(RustType::String.to_param_type_string(), "&str");
        assert_eq!(
            optional(RustType::String).to_param_type_string(),
            "Option<&str>"
        );
        assert_eq!(RustType::I64.to_param_type_string(), "i64");
    }
}
385}