1use super::naming;
4use crate::parser::ColumnMetadata;
5
/// A Rust type that a database column can be mapped to in generated code.
///
/// Each variant renders to Rust source text through
/// [`RustType::to_type_string`] (owned form) and
/// [`RustType::to_param_type_string`] (borrowed form for parameters).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RustType {
    Bool,
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    F32,
    F64,
    /// Rendered as `String` (`&str` in parameter position).
    String,
    /// Rendered as `Vec<u8>` (`&[u8]` in parameter position).
    Bytes,
    /// Rendered as `rust_decimal::Decimal`.
    Decimal,
    /// Rendered as `chrono::NaiveDate`.
    NaiveDate,
    /// Rendered as `chrono::NaiveDateTime`.
    NaiveDateTime,
    /// Rendered as `chrono::NaiveTime`.
    NaiveTime,
    /// Rendered as `serde_json::Value`.
    Json,
    /// A named enum type; the payload is the type name, rendered verbatim.
    Enum(String),
    /// An optional value wrapping another type, rendered as `Option<…>`.
    Option(Box<RustType>),
}

impl RustType {
    /// Returns the owned form of this type as Rust source text,
    /// e.g. `"i64"`, `"Vec<u8>"` or `"Option<String>"`.
    pub fn to_type_string(&self) -> String {
        let fixed = match self {
            // The two variants whose text depends on their payload.
            RustType::Enum(name) => return name.clone(),
            RustType::Option(inner) => return format!("Option<{}>", inner.to_type_string()),
            RustType::Bool => "bool",
            RustType::I8 => "i8",
            RustType::I16 => "i16",
            RustType::I32 => "i32",
            RustType::I64 => "i64",
            RustType::U8 => "u8",
            RustType::U16 => "u16",
            RustType::U32 => "u32",
            RustType::U64 => "u64",
            RustType::F32 => "f32",
            RustType::F64 => "f64",
            RustType::String => "String",
            RustType::Bytes => "Vec<u8>",
            RustType::Decimal => "rust_decimal::Decimal",
            RustType::NaiveDate => "chrono::NaiveDate",
            RustType::NaiveDateTime => "chrono::NaiveDateTime",
            RustType::NaiveTime => "chrono::NaiveTime",
            RustType::Json => "serde_json::Value",
        };
        fixed.to_string()
    }

    /// Returns the form of this type to use in a function parameter.
    ///
    /// `String` becomes `&str` and `Bytes` becomes `&[u8]`; every other
    /// non-optional type is rendered exactly as by [`RustType::to_type_string`].
    /// The substitution is applied through any number of `Option` layers, so a
    /// nested `Option<Option<String>>` is rendered as `Option<Option<&str>>`.
    pub fn to_param_type_string(&self) -> String {
        match self {
            RustType::String => "&str".to_string(),
            RustType::Bytes => "&[u8]".to_string(),
            // Recurse so the borrowed form is used at every nesting depth.
            RustType::Option(inner) => format!("Option<{}>", inner.to_param_type_string()),
            _ => self.to_type_string(),
        }
    }

    /// Returns `true` if this type, looking through `Option` layers, is one of
    /// the `chrono::` date/time types.
    pub fn needs_chrono(&self) -> bool {
        match self {
            RustType::Option(inner) => inner.needs_chrono(),
            other => matches!(
                other,
                RustType::NaiveDate | RustType::NaiveDateTime | RustType::NaiveTime
            ),
        }
    }

    /// Returns `true` if this type, looking through `Option` layers, is
    /// [`RustType::Decimal`].
    pub fn needs_decimal(&self) -> bool {
        match self {
            RustType::Option(inner) => inner.needs_decimal(),
            other => matches!(other, RustType::Decimal),
        }
    }

    /// Returns `true` if this type, looking through `Option` layers, is
    /// [`RustType::Json`].
    pub fn needs_serde_json(&self) -> bool {
        match self {
            RustType::Option(inner) => inner.needs_serde_json(),
            other => matches!(other, RustType::Json),
        }
    }

    /// Returns the type wrapped by an outer `Option`, or `self` when the type
    /// is not optional. Only a single `Option` layer is removed.
    pub fn inner_type(&self) -> &RustType {
        match self {
            RustType::Option(inner) => inner,
            _ => self,
        }
    }

    /// Returns `true` if the outermost type is an `Option`.
    pub fn is_optional(&self) -> bool {
        matches!(self, RustType::Option(_))
    }
}
114
115pub struct TypeResolver;
117
118impl TypeResolver {
119 pub fn resolve(column: &ColumnMetadata, table_name: &str) -> RustType {
121 let base_type = Self::resolve_base_type(column, table_name);
122
123 if column.nullable {
124 RustType::Option(Box::new(base_type))
125 } else {
126 base_type
127 }
128 }
129
130 fn resolve_base_type(column: &ColumnMetadata, table_name: &str) -> RustType {
132 let data_type = column.data_type.to_uppercase();
133
134 if column.is_enum() {
136 let enum_name = naming::to_enum_name(table_name, &column.name);
137 return RustType::Enum(enum_name);
138 }
139
140 let data_type_lower = data_type.to_lowercase();
142
143 if Self::is_boolean_type(&data_type_lower, &column.data_type) {
145 return RustType::Bool;
146 }
147
148 if data_type_lower.starts_with("tinyint") {
150 return if column.is_unsigned {
151 RustType::U8
152 } else {
153 RustType::I8
154 };
155 }
156 if data_type_lower.starts_with("smallint") {
157 return if column.is_unsigned {
158 RustType::U16
159 } else {
160 RustType::I16
161 };
162 }
163 if data_type_lower.starts_with("mediumint") || data_type_lower.starts_with("int") {
164 return if column.is_unsigned {
165 RustType::U32
166 } else {
167 RustType::I32
168 };
169 }
170 if data_type_lower.starts_with("bigint") {
171 return if column.is_unsigned {
172 RustType::U64
173 } else {
174 RustType::I64
175 };
176 }
177
178 if data_type_lower.starts_with("float") {
180 return RustType::F32;
181 }
182 if data_type_lower.starts_with("double") || data_type_lower.starts_with("real") {
183 return RustType::F64;
184 }
185
186 if data_type_lower.starts_with("decimal") || data_type_lower.starts_with("numeric") {
188 return RustType::Decimal;
189 }
190
191 if data_type_lower.starts_with("char")
193 || data_type_lower.starts_with("varchar")
194 || data_type_lower.contains("text")
195 || data_type_lower.starts_with("enum")
196 || data_type_lower.starts_with("set")
197 {
198 return RustType::String;
199 }
200
201 if data_type_lower.starts_with("binary")
203 || data_type_lower.starts_with("varbinary")
204 || data_type_lower.contains("blob")
205 || (data_type_lower.starts_with("bit") && !Self::is_bit_1(&data_type_lower))
206 {
207 return RustType::Bytes;
208 }
209
210 if data_type_lower == "date" {
212 return RustType::NaiveDate;
213 }
214 if data_type_lower.starts_with("datetime") || data_type_lower.starts_with("timestamp") {
215 return RustType::NaiveDateTime;
216 }
217 if data_type_lower == "time" {
218 return RustType::NaiveTime;
219 }
220
221 if data_type_lower == "json" {
223 return RustType::Json;
224 }
225
226 if data_type_lower.starts_with("geometry")
228 || data_type_lower.starts_with("point")
229 || data_type_lower.starts_with("linestring")
230 || data_type_lower.starts_with("polygon")
231 || data_type_lower.starts_with("multi")
232 || data_type_lower.starts_with("geometrycollection")
233 {
234 return RustType::Bytes;
235 }
236
237 RustType::String
239 }
240
241 fn is_boolean_type(data_type_lower: &str, original: &str) -> bool {
243 if data_type_lower == "bool" || data_type_lower == "boolean" {
245 return true;
246 }
247
248 if data_type_lower.starts_with("tinyint") {
250 if original.contains("(1)") || data_type_lower.contains("(1)") {
252 return true;
253 }
254 }
255
256 if Self::is_bit_1(data_type_lower) {
258 return true;
259 }
260
261 false
262 }
263
264 fn is_bit_1(data_type_lower: &str) -> bool {
266 data_type_lower.starts_with("bit") && data_type_lower.contains("(1)")
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ColumnMetadata` with the given name, data type, nullability
    /// and signedness; every other field is `None` or `false`.
    fn make_column(name: &str, data_type: &str, nullable: bool, unsigned: bool) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_owned(),
            data_type: data_type.to_owned(),
            nullable,
            is_unsigned: unsigned,
            is_auto_increment: false,
            default_value: None,
            enum_values: None,
            comment: None,
        }
    }

    /// Resolves `column` against the `users` table used by every test here.
    fn resolve_for_users(column: &ColumnMetadata) -> RustType {
        TypeResolver::resolve(column, "users")
    }

    /// Shorthand for `RustType::Option(Box::new(inner))`.
    fn optional(inner: RustType) -> RustType {
        RustType::Option(Box::new(inner))
    }

    #[test]
    fn test_integer_types() {
        // (column name, data type, unsigned, expected type)
        let cases = [
            ("id", "BIGINT", false, RustType::I64),
            ("id", "BIGINT", true, RustType::U64),
            ("count", "INT", false, RustType::I32),
        ];

        for (name, data_type, unsigned, expected) in cases {
            let column = make_column(name, data_type, false, unsigned);
            assert_eq!(resolve_for_users(&column), expected);
        }
    }

    #[test]
    fn test_boolean_type() {
        for (name, data_type) in [("active", "TINYINT(1)"), ("flag", "BOOL")] {
            let column = make_column(name, data_type, false, false);
            assert_eq!(resolve_for_users(&column), RustType::Bool);
        }
    }

    #[test]
    fn test_string_types() {
        let name = make_column("name", "VARCHAR(255)", false, false);
        assert_eq!(resolve_for_users(&name), RustType::String);

        let bio = make_column("bio", "TEXT", true, false);
        assert_eq!(resolve_for_users(&bio), optional(RustType::String));
    }

    #[test]
    fn test_datetime_types() {
        let created_at = make_column("created_at", "DATETIME", true, false);
        assert_eq!(
            resolve_for_users(&created_at),
            optional(RustType::NaiveDateTime)
        );

        let birth_date = make_column("birth_date", "DATE", false, false);
        assert_eq!(resolve_for_users(&birth_date), RustType::NaiveDate);
    }

    #[test]
    fn test_enum_type() {
        let status = ColumnMetadata {
            enum_values: Some(vec!["ACTIVE".to_string(), "INACTIVE".to_string()]),
            ..make_column("status", "ENUM", false, false)
        };

        assert_eq!(
            resolve_for_users(&status),
            RustType::Enum("UsersStatus".to_string())
        );
    }

    #[test]
    fn test_nullable() {
        let column = make_column("optional", "INT", true, false);
        assert_eq!(resolve_for_users(&column), optional(RustType::I32));
    }

    #[test]
    fn test_type_string() {
        let cases = [
            (RustType::I64, "i64"),
            (optional(RustType::String), "Option<String>"),
            (RustType::NaiveDateTime, "chrono::NaiveDateTime"),
        ];

        for (rust_type, expected) in cases {
            assert_eq!(rust_type.to_type_string(), expected);
        }
    }

    #[test]
    fn test_param_type_string() {
        let cases = [
            (RustType::String, "&str"),
            (optional(RustType::String), "Option<&str>"),
            (RustType::I64, "i64"),
        ];

        for (rust_type, expected) in cases {
            assert_eq!(rust_type.to_param_type_string(), expected);
        }
    }
}