1use super::naming;
4use crate::parser::ColumnMetadata;
5
/// A Rust type that can be emitted for a column.
///
/// Every variant corresponds to one concrete spelling, produced by
/// [`RustType::to_type_string`].
#[derive(Debug, Clone, PartialEq)]
pub enum RustType {
    /// Rendered as `bool`.
    Bool,
    /// Rendered as `i8`.
    I8,
    /// Rendered as `i16`.
    I16,
    /// Rendered as `i32`.
    I32,
    /// Rendered as `i64`.
    I64,
    /// Rendered as `u8`.
    U8,
    /// Rendered as `u16`.
    U16,
    /// Rendered as `u32`.
    U32,
    /// Rendered as `u64`.
    U64,
    /// Rendered as `f32`.
    F32,
    /// Rendered as `f64`.
    F64,
    /// Rendered as `String` (`&str` in parameter position).
    String,
    /// Rendered as `Vec<u8>` (`&[u8]` in parameter position).
    Bytes,
    /// Rendered as `rust_decimal::Decimal`.
    Decimal,
    /// Rendered as `chrono::NaiveDate`.
    NaiveDate,
    /// Rendered as `chrono::NaiveDateTime`.
    NaiveDateTime,
    /// Rendered as `chrono::NaiveTime`.
    NaiveTime,
    /// Rendered as `serde_json::Value`.
    Json,
    /// A named enum type; the payload is the type name, rendered verbatim.
    Enum(String),
    /// `Option<T>` around another type.
    Option(Box<RustType>),
}

impl RustType {
    /// Renders the owned spelling of this type, e.g. `Option<String>`.
    pub fn to_type_string(&self) -> String {
        // The two variants that need an allocation of their own return early;
        // everything else is a fixed spelling converted once at the end.
        let spelling: &str = match self {
            RustType::Enum(name) => return name.clone(),
            RustType::Option(inner) => return format!("Option<{}>", inner.to_type_string()),
            RustType::Bool => "bool",
            RustType::I8 => "i8",
            RustType::I16 => "i16",
            RustType::I32 => "i32",
            RustType::I64 => "i64",
            RustType::U8 => "u8",
            RustType::U16 => "u16",
            RustType::U32 => "u32",
            RustType::U64 => "u64",
            RustType::F32 => "f32",
            RustType::F64 => "f64",
            RustType::String => "String",
            RustType::Bytes => "Vec<u8>",
            RustType::Decimal => "rust_decimal::Decimal",
            RustType::NaiveDate => "chrono::NaiveDate",
            RustType::NaiveDateTime => "chrono::NaiveDateTime",
            RustType::NaiveTime => "chrono::NaiveTime",
            RustType::Json => "serde_json::Value",
        };
        spelling.to_string()
    }

    /// Renders the spelling used in parameter position.
    ///
    /// `String` and `Bytes` become the borrowed `&str` / `&[u8]`, also when
    /// wrapped in exactly one `Option`. Every other type is rendered exactly
    /// as [`RustType::to_type_string`] renders it.
    pub fn to_param_type_string(&self) -> String {
        // Peel at most one `Option` layer, remembering whether to re-wrap.
        let (target, wrapped) = match self {
            RustType::Option(inner) => (&**inner, true),
            other => (other, false),
        };

        let rendered = match target {
            RustType::String => "&str".to_string(),
            RustType::Bytes => "&[u8]".to_string(),
            other => other.to_type_string(),
        };

        if wrapped {
            format!("Option<{}>", rendered)
        } else {
            rendered
        }
    }

    /// Returns `true` if this type, looking through any `Option` layers, is
    /// one of the `chrono::` date/time types.
    pub fn needs_chrono(&self) -> bool {
        if let RustType::Option(inner) = self {
            return inner.needs_chrono();
        }
        matches!(
            self,
            RustType::NaiveDate | RustType::NaiveDateTime | RustType::NaiveTime
        )
    }

    /// Returns `true` if this type, looking through any `Option` layers, is
    /// `Decimal`.
    pub fn needs_decimal(&self) -> bool {
        if let RustType::Option(inner) = self {
            return inner.needs_decimal();
        }
        matches!(self, RustType::Decimal)
    }

    /// Returns `true` if this type, looking through any `Option` layers, is
    /// `Json`.
    pub fn needs_serde_json(&self) -> bool {
        if let RustType::Option(inner) = self {
            return inner.needs_serde_json();
        }
        matches!(self, RustType::Json)
    }

    /// Returns the type wrapped by an `Option`, or `self` when this is not an
    /// `Option`. Only a single layer is removed.
    pub fn inner_type(&self) -> &RustType {
        if let RustType::Option(inner) = self {
            &**inner
        } else {
            self
        }
    }

    /// Returns `true` if this is the `Option` variant.
    pub fn is_optional(&self) -> bool {
        matches!(self, RustType::Option(_))
    }

    /// Reports whether the type is treated as `Copy`.
    ///
    /// `String`, `Bytes` and `Json` are not; an `Option` follows the type it
    /// wraps; every other variant is reported as `Copy`.
    pub fn is_copy(&self) -> bool {
        match self {
            RustType::Option(inner) => inner.is_copy(),
            other => !matches!(
                other,
                RustType::String | RustType::Bytes | RustType::Json
            ),
        }
    }
}
126
127pub struct TypeResolver;
129
130impl TypeResolver {
131 pub fn resolve(column: &ColumnMetadata, table_name: &str) -> RustType {
133 let base_type = Self::resolve_base_type(column, table_name);
134
135 if column.nullable {
136 RustType::Option(Box::new(base_type))
137 } else {
138 base_type
139 }
140 }
141
142 fn resolve_base_type(column: &ColumnMetadata, table_name: &str) -> RustType {
144 let data_type = column.data_type.to_uppercase();
145
146 if column.is_enum() {
148 let enum_name = naming::to_enum_name(table_name, &column.name);
149 return RustType::Enum(enum_name);
150 }
151
152 let data_type_lower = data_type.to_lowercase();
154
155 if Self::is_boolean_type(&data_type_lower, &column.data_type) {
157 return RustType::Bool;
158 }
159
160 if data_type_lower.starts_with("tinyint") {
162 return if column.is_unsigned {
163 RustType::U8
164 } else {
165 RustType::I8
166 };
167 }
168 if data_type_lower.starts_with("smallint") {
169 return if column.is_unsigned {
170 RustType::U16
171 } else {
172 RustType::I16
173 };
174 }
175 if data_type_lower.starts_with("mediumint") || data_type_lower.starts_with("int") {
176 return if column.is_unsigned {
177 RustType::U32
178 } else {
179 RustType::I32
180 };
181 }
182 if data_type_lower.starts_with("bigint") {
183 return if column.is_unsigned {
184 RustType::U64
185 } else {
186 RustType::I64
187 };
188 }
189
190 if data_type_lower.starts_with("float") {
192 return RustType::F32;
193 }
194 if data_type_lower.starts_with("double") || data_type_lower.starts_with("real") {
195 return RustType::F64;
196 }
197
198 if data_type_lower.starts_with("decimal") || data_type_lower.starts_with("numeric") {
200 return RustType::Decimal;
201 }
202
203 if data_type_lower.starts_with("char")
205 || data_type_lower.starts_with("varchar")
206 || data_type_lower.contains("text")
207 || data_type_lower.starts_with("enum")
208 || data_type_lower.starts_with("set")
209 {
210 return RustType::String;
211 }
212
213 if data_type_lower.starts_with("binary")
215 || data_type_lower.starts_with("varbinary")
216 || data_type_lower.contains("blob")
217 || (data_type_lower.starts_with("bit") && !Self::is_bit_1(&data_type_lower))
218 {
219 return RustType::Bytes;
220 }
221
222 if data_type_lower == "date" {
224 return RustType::NaiveDate;
225 }
226 if data_type_lower.starts_with("datetime") || data_type_lower.starts_with("timestamp") {
227 return RustType::NaiveDateTime;
228 }
229 if data_type_lower == "time" {
230 return RustType::NaiveTime;
231 }
232
233 if data_type_lower == "json" {
235 return RustType::Json;
236 }
237
238 if data_type_lower.starts_with("geometry")
240 || data_type_lower.starts_with("point")
241 || data_type_lower.starts_with("linestring")
242 || data_type_lower.starts_with("polygon")
243 || data_type_lower.starts_with("multi")
244 || data_type_lower.starts_with("geometrycollection")
245 {
246 return RustType::Bytes;
247 }
248
249 RustType::String
251 }
252
253 fn is_boolean_type(data_type_lower: &str, original: &str) -> bool {
255 if data_type_lower == "bool" || data_type_lower == "boolean" {
257 return true;
258 }
259
260 if data_type_lower.starts_with("tinyint") {
262 if original.contains("(1)") || data_type_lower.contains("(1)") {
264 return true;
265 }
266 }
267
268 if Self::is_bit_1(data_type_lower) {
270 return true;
271 }
272
273 false
274 }
275
276 fn is_bit_1(data_type_lower: &str) -> bool {
278 data_type_lower.starts_with("bit") && data_type_lower.contains("(1)")
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a column with the given name, SQL type and flags; every other
    /// field is left unset / `false`.
    fn make_column(name: &str, data_type: &str, nullable: bool, unsigned: bool) -> ColumnMetadata {
        ColumnMetadata {
            name: String::from(name),
            data_type: String::from(data_type),
            nullable,
            default_value: None,
            is_auto_increment: false,
            is_unsigned: unsigned,
            enum_values: None,
            comment: None,
        }
    }

    /// Resolves a freshly built column against the `users` table.
    fn resolve_in_users(name: &str, data_type: &str, nullable: bool, unsigned: bool) -> RustType {
        TypeResolver::resolve(&make_column(name, data_type, nullable, unsigned), "users")
    }

    /// Shorthand for `RustType::Option(Box::new(inner))`.
    fn optional(inner: RustType) -> RustType {
        RustType::Option(Box::new(inner))
    }

    #[test]
    fn test_integer_types() {
        // (column name, SQL type, unsigned flag, expected type)
        let cases = [
            ("id", "BIGINT", false, RustType::I64),
            ("id", "BIGINT", true, RustType::U64),
            ("count", "INT", false, RustType::I32),
        ];

        for (name, data_type, unsigned, expected) in cases.iter() {
            assert_eq!(
                resolve_in_users(name, data_type, false, *unsigned),
                *expected
            );
        }
    }

    #[test]
    fn test_boolean_type() {
        for (name, data_type) in [("active", "TINYINT(1)"), ("flag", "BOOL")].iter() {
            assert_eq!(
                resolve_in_users(name, data_type, false, false),
                RustType::Bool
            );
        }
    }

    #[test]
    fn test_string_types() {
        assert_eq!(
            resolve_in_users("name", "VARCHAR(255)", false, false),
            RustType::String
        );
        assert_eq!(
            resolve_in_users("bio", "TEXT", true, false),
            optional(RustType::String)
        );
    }

    #[test]
    fn test_datetime_types() {
        assert_eq!(
            resolve_in_users("created_at", "DATETIME", true, false),
            optional(RustType::NaiveDateTime)
        );
        assert_eq!(
            resolve_in_users("birth_date", "DATE", false, false),
            RustType::NaiveDate
        );
    }

    #[test]
    fn test_enum_type() {
        let status = ColumnMetadata {
            enum_values: Some(vec!["ACTIVE".to_string(), "INACTIVE".to_string()]),
            ..make_column("status", "ENUM", false, false)
        };

        assert_eq!(
            TypeResolver::resolve(&status, "users"),
            RustType::Enum("UsersStatus".to_string())
        );
    }

    #[test]
    fn test_nullable() {
        assert_eq!(
            resolve_in_users("optional", "INT", true, false),
            optional(RustType::I32)
        );
    }

    #[test]
    fn test_type_string() {
        assert_eq!(RustType::I64.to_type_string(), "i64");
        assert_eq!(
            optional(RustType::String).to_type_string(),
            "Option<String>"
        );
        assert_eq!(
            RustType::NaiveDateTime.to_type_string(),
            "chrono::NaiveDateTime"
        );
    }

    #[test]
    fn test_param_type_string() {
        assert_eq!(RustType::String.to_param_type_string(), "&str");
        assert_eq!(
            optional(RustType::String).to_param_type_string(),
            "Option<&str>"
        );
        assert_eq!(RustType::I64.to_param_type_string(), "i64");
    }
}