orso_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    Attribute, Data, DeriveInput, Fields, Lit, parse_macro_input, punctuated::Punctuated,
5    token::Comma,
6};
7
8#[proc_macro_attribute]
9pub fn orso_column(_args: TokenStream, input: TokenStream) -> TokenStream {
10    input
11}
12
13// orso_table attribute (passthrough - only used for table naming)
14#[proc_macro_attribute]
15pub fn orso_table(_args: TokenStream, input: TokenStream) -> TokenStream {
16    input
17}
18
19// Derive macro for Orso trait
20#[proc_macro_derive(Orso, attributes(orso_table, orso_column))]
21pub fn derive_orso(input: TokenStream) -> TokenStream {
22    let input = parse_macro_input!(input as DeriveInput);
23    let name = input.ident;
24
25    // Extract table name from attributes or use default
26    let table_name =
27        extract_orso_table_name(&input.attrs).unwrap_or_else(|| name.to_string().to_lowercase());
28
29    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
30
31    // Extract field metadata
32    let (
33        field_names,
34        column_definitions,
35        mathilde_field_types,
36        nullable_flags,
37        primary_key_field,
38        created_at_field,
39        updated_at_field,
40        unique_fields,
41    ) = if let Data::Struct(data) = &input.data {
42        if let Fields::Named(fields) = &data.fields {
43            extract_field_metadata_original(&fields.named)
44        } else {
45            (vec![], vec![], vec![], vec![], None, None, None, vec![])
46        }
47    } else {
48        (vec![], vec![], vec![], vec![], None, None, None, vec![])
49    };
50
51    // Generate dynamic getters based on actual fields found
52    let primary_key_getter = if let Some(ref pk_field) = primary_key_field {
53        quote! {
54            match &self.#pk_field {
55                Some(pk) => Some(pk.to_string()),
56                None => None,
57            }
58        }
59    } else {
60        quote! { None }
61    };
62
63    let primary_key_setter = if let Some(ref pk_field) = primary_key_field {
64        quote! {
65            if let Ok(parsed_id) = id.parse() {
66                self.#pk_field = Some(parsed_id);
67            }
68        }
69    } else {
70        quote! { /* No primary key field found */ }
71    };
72
73    let created_at_getter = if let Some(ref ca_field) = created_at_field {
74        quote! { self.#ca_field }
75    } else {
76        quote! { None }
77    };
78
79    let updated_at_getter = if let Some(ref ua_field) = updated_at_field {
80        quote! { self.#ua_field }
81    } else {
82        quote! { None }
83    };
84
85    let updated_at_setter = if let Some(ref ua_field) = updated_at_field {
86        quote! { self.#ua_field = Some(updated_at); }
87    } else {
88        quote! { /* No updated_at field found */ }
89    };
90
91    // Generate field name constants
92    let primary_key_field_name = if let Some(ref pk_field) = primary_key_field {
93        quote! { stringify!(#pk_field) }
94    } else {
95        quote! { "id" }
96    };
97
98    let created_at_field_name = if let Some(ref ca_field) = created_at_field {
99        quote! { Some(stringify!(#ca_field)) }
100    } else {
101        quote! { None }
102    };
103
104    let updated_at_field_name = if let Some(ref ua_field) = updated_at_field {
105        quote! { Some(stringify!(#ua_field)) }
106    } else {
107        quote! { None }
108    };
109
110    // Generate unique fields list
111    let unique_field_names: Vec<proc_macro2::TokenStream> = unique_fields
112        .iter()
113        .map(|field| quote! { stringify!(#field) })
114        .collect();
115
116    // Generate only the trait implementation
117    let expanded = quote! {
118        impl #impl_generics orso::Orso for #name #ty_generics #where_clause {
119            fn table_name() -> &'static str {
120                #table_name
121            }
122
123            fn primary_key_field() -> &'static str {
124                #primary_key_field_name
125            }
126
127            fn created_at_field() -> Option<&'static str> {
128                #created_at_field_name
129            }
130
131            fn updated_at_field() -> Option<&'static str> {
132                #updated_at_field_name
133            }
134
135            fn unique_fields() -> Vec<&'static str> {
136                vec![#(#unique_field_names),*]
137            }
138
139            fn get_primary_key(&self) -> Option<String> {
140                #primary_key_getter
141            }
142
143            fn set_primary_key(&mut self, id: String) {
144                #primary_key_setter
145            }
146
147            fn get_created_at(&self) -> Option<chrono::DateTime<chrono::Utc>> {
148                #created_at_getter
149            }
150
151            fn get_updated_at(&self) -> Option<chrono::DateTime<chrono::Utc>> {
152                #updated_at_getter
153            }
154
155            fn set_updated_at(&mut self, updated_at: chrono::DateTime<chrono::Utc>) {
156                #updated_at_setter
157            }
158
159            fn field_names() -> Vec<&'static str> {
160                vec![#(#field_names),*]
161            }
162
163            fn field_types() -> Vec<orso::FieldType> {
164                vec![#(#mathilde_field_types),*]
165            }
166
167            fn field_nullable() -> Vec<bool> {
168                vec![#(#nullable_flags),*]
169            }
170
171            fn columns() -> Vec<&'static str> {
172                vec![#(#field_names),*]
173            }
174
175            fn migration_sql() -> String {
176                // Only generate columns for actual struct fields
177                let columns: Vec<String> = vec![#(#column_definitions),*];
178
179                format!(
180                    "CREATE TABLE IF NOT EXISTS {} (\n    {}\n)",
181                    Self::table_name(),
182                    columns.join(",\n    ")
183                )
184            }
185
186            fn to_map(&self) -> orso::Result<std::collections::HashMap<String, orso::Value>> {
187                use serde_json;
188                let json = serde_json::to_value(self)?;
189                let map: std::collections::HashMap<String, serde_json::Value> =
190                    serde_json::from_value(json)?;
191
192                let mut result = std::collections::HashMap::new();
193
194                // Get field names for auto-generated fields
195                let pk_field = Self::primary_key_field();
196                let created_field = Self::created_at_field();
197                let updated_field = Self::updated_at_field();
198
199                for (k, v) in map {
200                    // Skip auto-generated fields when they are null - let SQLite use DEFAULT values
201                    let should_skip = matches!(v, serde_json::Value::Null) && (
202                        k == pk_field ||
203                        (created_field.is_some() && k == created_field.unwrap()) ||
204                        (updated_field.is_some() && k == updated_field.unwrap())
205                    );
206
207                    if should_skip {
208                        continue;
209                    }
210
211                    let value = match v {
212                        serde_json::Value::Null => orso::Value::Null,
213                        serde_json::Value::Bool(b) => orso::Value::Boolean(b),
214                        serde_json::Value::Number(n) => {
215                            if let Some(i) = n.as_i64() {
216                                orso::Value::Integer(i)
217                            } else if let Some(f) = n.as_f64() {
218                                orso::Value::Real(f)
219                            } else {
220                                orso::Value::Text(n.to_string())
221                            }
222                        }
223                        serde_json::Value::String(s) => orso::Value::Text(s),
224                        serde_json::Value::Array(_) => orso::Value::Text(serde_json::to_string(&v)?),
225                        serde_json::Value::Object(_) => orso::Value::Text(serde_json::to_string(&v)?),
226                    };
227                    result.insert(k, value);
228                }
229                Ok(result)
230            }
231
232            fn from_map(mut map: std::collections::HashMap<String, orso::Value>) -> orso::Result<Self> {
233                use serde_json;
234                let mut json_map = serde_json::Map::new();
235
236                // Get field metadata for type-aware conversion
237                let field_names = Self::field_names();
238                let field_types = Self::field_types();
239
240                for (k, v) in &map {
241                    // Don't skip any fields when deserializing FROM database - we want all values
242
243                    let json_value = match v {
244                        orso::Value::Null => serde_json::Value::Null,
245                        orso::Value::Boolean(b) => serde_json::Value::Bool(*b),
246                        orso::Value::Integer(i) => {
247                            // Check if this field should be a boolean based on field type
248                            if let Some(pos) = field_names.iter().position(|&name| name == k) {
249                                if matches!(field_types.get(pos), Some(orso::FieldType::Boolean)) {
250                                    // This is a boolean field, convert 0/1 to bool
251                                    serde_json::Value::Bool(*i != 0)
252                                } else {
253                                    serde_json::Value::Number(serde_json::Number::from(*i))
254                                }
255                            } else {
256                                serde_json::Value::Number(serde_json::Number::from(*i))
257                            }
258                        },
259                        orso::Value::Real(f) => {
260                            if let Some(n) = serde_json::Number::from_f64(*f) {
261                                serde_json::Value::Number(n)
262                            } else {
263                                serde_json::Value::String(f.to_string())
264                            }
265                        }
266                        orso::Value::Text(s) => {
267                            // Check if this might be a SQLite datetime that needs conversion
268                            if s.len() == 19 && s.chars().nth(4) == Some('-') && s.chars().nth(7) == Some('-') && s.chars().nth(10) == Some(' ') {
269                                // This looks like SQLite datetime format: "2025-09-13 10:50:43"
270                                // Convert to RFC3339 format: "2025-09-13T10:50:43Z"
271                                let rfc3339_format = s.replace(' ', "T") + "Z";
272                                serde_json::Value::String(rfc3339_format)
273                            } else {
274                                serde_json::Value::String(s.clone())
275                            }
276                        },
277                        orso::Value::Blob(b) => {
278                            serde_json::Value::Array(
279                                b.iter()
280                                .map(|byte| serde_json::Value::Number(serde_json::Number::from(*byte)))
281                                .collect()
282                            )
283                        }
284                    };
285                    json_map.insert(k.clone(), json_value);
286                }
287
288                let json_value = serde_json::Value::Object(json_map);
289
290                match serde_json::from_value(json_value) {
291                    Ok(result) => Ok(result),
292                    Err(e) => Err(orso::Error::Serialization(e.to_string()))
293                }
294            }
295
296
297            // Utility methods
298            fn row_to_map(row: &libsql::Row) -> orso::Result<std::collections::HashMap<String, orso::Value>> {
299                let mut map = std::collections::HashMap::new();
300                for i in 0..row.column_count() {
301                    if let Some(column_name) = row.column_name(i) {
302                        let value = row.get_value(i).unwrap_or(libsql::Value::Null);
303                        map.insert(column_name.to_string(), Self::libsql_value_to_value(&value));
304                    }
305                }
306                Ok(map)
307            }
308
309            fn value_to_libsql_value(value: &orso::Value) -> libsql::Value {
310                match value {
311                    orso::Value::Null => libsql::Value::Null,
312                    orso::Value::Integer(i) => libsql::Value::Integer(*i),
313                    orso::Value::Real(f) => libsql::Value::Real(*f),
314                    orso::Value::Text(s) => libsql::Value::Text(s.clone()),
315                    orso::Value::Blob(b) => libsql::Value::Blob(b.clone()),
316                    orso::Value::Boolean(b) => libsql::Value::Integer(if *b { 1 } else { 0 }),
317                }
318            }
319
320            fn libsql_value_to_value(value: &libsql::Value) -> orso::Value {
321                match value {
322                    libsql::Value::Null => orso::Value::Null,
323                    libsql::Value::Integer(i) => {
324                        // SQLite stores booleans as integers 0/1
325                        // Check if this might be a boolean value
326                        if *i == 0 || *i == 1 {
327                            // This could be a boolean, but we don't have type context here
328                            // For now, keep as integer and let from_map handle the conversion
329                            orso::Value::Integer(*i)
330                        } else {
331                            orso::Value::Integer(*i)
332                        }
333                    },
334                    libsql::Value::Real(f) => orso::Value::Real(*f),
335                    libsql::Value::Text(s) => orso::Value::Text(s.clone()),
336                    libsql::Value::Blob(b) => orso::Value::Blob(b.clone()),
337                }
338            }
339        }
340    };
341
342    TokenStream::from(expanded)
343}
344
345// Parse field-level column definition with inline REFERENCES for maximum Turso compatibility
346fn parse_field_column_definition(field: &syn::Field) -> String {
347    let field_name = field.ident.as_ref().unwrap().to_string();
348
349    // Check for orso_column attributes
350    for attr in &field.attrs {
351        if attr.path().is_ident("orso_column") {
352            return parse_orso_column_attr(attr, &field_name, &field.ty);
353        }
354    }
355
356    // Default column definition based on field type
357    map_rust_type_to_sql_column(&field.ty, &field_name)
358}
359
360// Parse orso_column attribute with support for foreign keys
361fn parse_orso_column_attr(
362    attr: &syn::Attribute,
363    field_name: &str,
364    field_type: &syn::Type,
365) -> String {
366    let mut column_type = None;
367    let mut is_foreign_key = false;
368    let mut foreign_table = None;
369    let mut unique = false;
370    let mut primary_key = false;
371
372    let mut is_created_at = false;
373    let mut is_updated_at = false;
374
375    let _ = attr.parse_nested_meta(|meta| {
376        if meta.path.is_ident("ref") {
377            is_foreign_key = true;
378            if let Ok(value) = meta.value() {
379                let lit: Lit = value.parse()?;
380                if let Lit::Str(lit_str) = lit {
381                    foreign_table = Some(lit_str.value());
382                }
383            }
384        } else if meta.path.is_ident("type") {
385            if let Ok(value) = meta.value() {
386                let lit: Lit = value.parse()?;
387                if let Lit::Str(lit_str) = lit {
388                    column_type = Some(lit_str.value());
389                }
390            }
391        } else if meta.path.is_ident("unique") {
392            unique = true;
393        } else if meta.path.is_ident("primary_key") {
394            primary_key = true;
395        } else if meta.path.is_ident("created_at") {
396            is_created_at = true;
397        } else if meta.path.is_ident("updated_at") {
398            is_updated_at = true;
399        }
400        Ok(())
401    });
402
403    // Generate column definition
404    let base_type = if is_foreign_key {
405        "TEXT".to_string() // Foreign keys are always TEXT (UUID)
406    } else {
407        column_type.unwrap_or_else(|| map_rust_type_to_sql_type(field_type))
408    };
409
410    let mut column_def = format!("{} {}", field_name, base_type);
411
412    if primary_key {
413        column_def.push_str(" PRIMARY KEY");
414        // Add default for primary key if it's TEXT type
415        if base_type == "TEXT" {
416            column_def.push_str(" DEFAULT (lower(hex(randomblob(16))))");
417        }
418    }
419    // Add NOT NULL for non-Option types (except primary keys which are already handled)
420    if !is_option_type(field_type) && !primary_key {
421        column_def.push_str(" NOT NULL");
422    }
423    if unique {
424        column_def.push_str(" UNIQUE");
425    }
426    if let Some(ref_table) = foreign_table {
427        column_def.push_str(&format!(" REFERENCES {}(id)", ref_table));
428    }
429
430    // Add defaults for timestamp columns
431    if is_created_at || is_updated_at {
432        column_def.push_str(" DEFAULT (datetime('now'))");
433    }
434
435    column_def
436}
437
438// Map Rust types to SQL column definitions
439fn map_rust_type_to_sql_column(rust_type: &syn::Type, field_name: &str) -> String {
440    let sql_type = map_rust_type_to_sql_type(rust_type);
441    let mut column_def = format!("{} {}", field_name, sql_type);
442
443    // Add NOT NULL for non-Option types
444    if !is_option_type(rust_type) {
445        column_def.push_str(" NOT NULL");
446    }
447
448    column_def
449}
450
451// Map Rust types to SQL types
452fn map_rust_type_to_sql_type(rust_type: &syn::Type) -> String {
453    if let syn::Type::Path(type_path) = rust_type {
454        if let Some(segment) = type_path.path.segments.last() {
455            let type_name = segment.ident.to_string();
456            return match type_name.as_str() {
457                "String" => "TEXT".to_string(),
458                "i64" | "i32" | "i16" | "i8" => "INTEGER".to_string(),
459                "u64" | "u32" | "u16" | "u8" => "INTEGER".to_string(),
460                "f64" | "f32" => "REAL".to_string(),
461                "bool" => "INTEGER".to_string(), // SQLite stores booleans as integers
462                "Option" => {
463                    // Handle Option<T> types
464                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
465                        if let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() {
466                            return map_rust_type_to_sql_type(inner_type);
467                        }
468                    }
469                    "TEXT".to_string()
470                }
471                _ => "TEXT".to_string(),
472            };
473        }
474    }
475    "TEXT".to_string()
476}
477
478// Map field types to FieldType enum
479fn map_field_type(rust_type: &syn::Type, _field: &syn::Field) -> proc_macro2::TokenStream {
480    if let syn::Type::Path(type_path) = rust_type {
481        if let Some(segment) = type_path.path.segments.last() {
482            let type_name = segment.ident.to_string();
483            return match type_name.as_str() {
484                "String" => quote! { orso::FieldType::Text },
485                "i64" => quote! { orso::FieldType::BigInt },
486                "i32" | "i16" | "i8" => quote! { orso::FieldType::Integer },
487                "u64" => quote! { orso::FieldType::BigInt },
488                "u32" | "u16" | "u8" => quote! { orso::FieldType::Integer },
489                "f64" | "f32" => quote! { orso::FieldType::Numeric },
490                "bool" => quote! { orso::FieldType::Boolean },
491                "Option" => {
492                    // Handle Option<T> types - get the inner type
493                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
494                        if let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() {
495                            return map_field_type(inner_type, _field);
496                        }
497                    }
498                    quote! { orso::FieldType::Text }
499                }
500                _ => quote! { orso::FieldType::Text },
501            };
502        }
503    }
504    quote! { orso::FieldType::Text }
505}
506
507// Check if a type is Option<T>
508fn is_option_type(rust_type: &syn::Type) -> bool {
509    if let syn::Type::Path(type_path) = rust_type {
510        if let Some(segment) = type_path.path.segments.last() {
511            return segment.ident == "Option";
512        }
513    }
514    false
515}
516
517// Extract field metadata from all struct fields
518fn extract_field_metadata_original(
519    fields: &Punctuated<syn::Field, Comma>,
520) -> (
521    Vec<proc_macro2::TokenStream>,
522    Vec<proc_macro2::TokenStream>,
523    Vec<proc_macro2::TokenStream>,
524    Vec<bool>,
525    Option<proc_macro2::Ident>,
526    Option<proc_macro2::Ident>,
527    Option<proc_macro2::Ident>,
528    Vec<proc_macro2::Ident>,
529) {
530    let mut field_names = Vec::new();
531    let mut column_defs = Vec::new();
532    let mut field_types = Vec::new();
533    let mut nullable_flags = Vec::new();
534    let mut primary_key_field: Option<proc_macro2::Ident> = None;
535    let mut created_at_field: Option<proc_macro2::Ident> = None;
536    let mut updated_at_field: Option<proc_macro2::Ident> = None;
537    let mut unique_fields = Vec::new();
538
539    for field in fields {
540        if let Some(field_name) = &field.ident {
541            // Check for special attributes
542            let mut is_primary_key = false;
543            let mut is_created_at = false;
544            let mut is_updated_at = false;
545            let mut is_unique = false;
546
547            for attr in &field.attrs {
548                if attr.path().is_ident("orso_column") {
549                    let _ = attr.parse_nested_meta(|meta| {
550                        if meta.path.is_ident("primary_key") {
551                            is_primary_key = true;
552                            primary_key_field = Some(field_name.clone());
553                        } else if meta.path.is_ident("created_at") {
554                            is_created_at = true;
555                            created_at_field = Some(field_name.clone());
556                        } else if meta.path.is_ident("updated_at") {
557                            is_updated_at = true;
558                            updated_at_field = Some(field_name.clone());
559                        } else if meta.path.is_ident("unique") {
560                            is_unique = true;
561                        }
562                        Ok(())
563                    });
564                }
565            }
566
567            if is_unique {
568                unique_fields.push(field_name.clone());
569            }
570
571            // Process ALL fields - no skipping based on field names
572
573            let field_name_token = quote! { stringify!(#field_name) };
574            field_names.push(field_name_token);
575
576            // Parse column attributes for foreign key references (inline REFERENCES)
577            let column_def = parse_field_column_definition(field);
578            column_defs.push(quote! { #column_def.to_string() });
579
580            // Enhanced type mapping based on field type and attributes
581            let field_type = map_field_type(&field.ty, field);
582            field_types.push(field_type);
583
584            // Check if field is Option<T> (nullable)
585            let is_nullable = is_option_type(&field.ty);
586            nullable_flags.push(is_nullable);
587        }
588    }
589
590    (
591        field_names,
592        column_defs,
593        field_types,
594        nullable_flags,
595        primary_key_field,
596        created_at_field,
597        updated_at_field,
598        unique_fields,
599    )
600}
601
602// Extract table name from struct attributes
603fn extract_orso_table_name(attrs: &[Attribute]) -> Option<String> {
604    for attr in attrs {
605        if attr.path().is_ident("orso_table") {
606            if let Ok(Lit::Str(lit_str)) = attr.parse_args::<Lit>() {
607                return Some(lit_str.value());
608            }
609        }
610    }
611    None
612}