remdb_macros/
lib.rs

1mod codegen;
2mod ddl_parser;
3
4use proc_macro::TokenStream;
5use syn::parse_macro_input;
6use quote::quote;
7
8#[proc_macro]
9pub fn define_schema(input: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(input as syn::LitStr);
11    let schema = input.value();
12    
13    match ddl_parser::parse_ddl(&schema) {
14        Ok(table_defs) => {
15            let output = codegen::generate_code(table_defs);
16            output.into()
17        },
18        Err(e) => {
19            panic!("Failed to parse DDL: {}", e);
20        }
21    }
22}
23
24#[proc_macro_derive(MemdbTable, attributes(memdb_schema))]
25pub fn derive_memdb_table(input: TokenStream) -> TokenStream {
26    let derive_input = parse_macro_input!(input as syn::DeriveInput);
27    
28    // 查找memdb_schema属性
29    let mut ddl = String::new();
30    
31    for attr in &derive_input.attrs {
32        if attr.path().is_ident("memdb_schema") {
33            // 使用正确的syn 2.0 API解析属性
34            attr.parse_nested_meta(|meta| {
35                if meta.path.is_ident("ddl") {
36                    let lit = meta.value()?;
37                    let lit_str = lit.parse::<syn::LitStr>()?;
38                    ddl = lit_str.value();
39                }
40                Ok(())
41            }).unwrap();
42        }
43    }
44    
45    if ddl.is_empty() {
46        panic!("memdb_schema attribute with ddl parameter is required");
47    }
48    
49    // 解析DDL并生成代码
50    match ddl_parser::parse_ddl(&ddl) {
51        Ok(table_defs) => {
52            let generated_code = codegen::generate_code(table_defs);
53            generated_code.into()
54        },
55        Err(e) => {
56            panic!("Failed to parse DDL: {}", e);
57        }
58    }
59}
60
61use syn::{LitInt, Ident, Token};
62use syn::parse::{Parse, ParseStream};
63
64// 字段定义
65struct Field {
66    name: Ident,
67    colon: Token![:],
68    // 自定义类型解析,支持 str(32) 这种语法
69    type_name: Ident,
70    type_params: Option<LitInt>,
71}
72
73impl Parse for Field {
74    fn parse(input: ParseStream) -> syn::Result<Self> {
75        let name = input.parse()?;
76        let colon = input.parse()?;
77        
78        // 解析类型名称
79        let type_name = input.parse()?;
80        
81        // 检查是否有括号参数,如 str(32)
82        let type_params = if input.peek(syn::token::Paren) {
83            let content; 
84            syn::parenthesized!(content in input);
85            let params = content.parse()?;
86            Some(params)
87        } else {
88            None
89        };
90        
91        Ok(Self {
92            name,
93            colon,
94            type_name,
95            type_params,
96        })
97    }
98}
99
100// 表定义结构
101struct TableArgs {
102    name: Ident,
103    max_records: LitInt,
104    primary_key: Ident,
105    secondary_index: Option<Ident>,
106    secondary_index_type: Option<Ident>,
107    fields: Vec<Field>,
108}
109
110impl Parse for TableArgs {
111    fn parse(input: ParseStream) -> syn::Result<Self> {
112        // 解析表名
113        let name = input.parse()?;
114        
115        // 解析逗号
116        let _comma1: Token![,] = input.parse()?;
117        
118        // 解析最大记录数
119        let max_records = input.parse()?;
120        
121        // 解析逗号
122        let _comma2: Token![,] = input.parse()?;
123        
124        // 解析primary_key
125        let primary_key_keyword: Ident = input.parse()?;
126        let _colon1: Token![:] = input.parse()?;
127        let primary_key = input.parse()?;
128        
129        // 解析secondary_index(可选)
130        let mut secondary_index = None;
131        let mut secondary_index_type = None;
132        
133        // 检查primary_key之后是否有逗号
134        if input.peek(Token![,]) {
135            let _comma3: Token![,] = input.parse()?;
136        }
137        
138        // 解析secondary_index、secondary_index_type和fields关键字
139        loop {
140            // 检查下一个标记
141            let next = input.lookahead1();
142            if next.peek(Ident) {
143                let param_name = input.parse::<Ident>()?;
144                if param_name == "secondary_index" {
145                    let _colon: Token![:] = input.parse()?;
146                    secondary_index = Some(input.parse()?);
147                    
148                    // 解析逗号
149                    if input.peek(Token![,]) {
150                        let _comma4: Token![,] = input.parse()?;
151                    }
152                } else if param_name == "secondary_index_type" {
153                    let _colon: Token![:] = input.parse()?;
154                    secondary_index_type = Some(input.parse()?);
155                    
156                    // 解析逗号
157                    if input.peek(Token![,]) {
158                        let _comma5: Token![,] = input.parse()?;
159                    }
160                } else if param_name == "fields" {
161                    let _colon_fields: Token![:] = input.parse()?;
162                    break;
163                } else {
164                    return Err(syn::Error::new(param_name.span(), format!("expected 'secondary_index', 'secondary_index_type' or 'fields' keyword, got '{}'", param_name)));
165                }
166            } else {
167                return Err(next.error());
168            }
169        }
170        
171        // 解析fields块
172        let content; 
173        syn::braced!(content in input);
174        
175        // 解析fields块内的内容
176        let mut fields = Vec::new();
177        while !content.is_empty() {
178            // 解析字段
179            let field = content.parse::<Field>()?;
180            fields.push(field);
181            
182            // 如果还有逗号,解析它
183            if content.peek(Token![,]) {
184                content.parse::<Token![,]>()?;
185            }
186        }
187        
188        Ok(Self {
189            name,
190            max_records,
191            primary_key,
192            secondary_index,
193            secondary_index_type,
194            fields,
195        })
196    }
197}
198
199// 数据库定义结构,解析数据库名和表列表
200struct DatabaseArgs {
201    name: Ident,
202    tables: Vec<Ident>,
203    low_power: bool,
204    low_power_max_records: Option<usize>,
205    default_max_records: usize,
206}
207
208impl Parse for DatabaseArgs {
209    fn parse(input: ParseStream) -> syn::Result<Self> {
210        // 解析数据库名
211        let name = input.parse()?;
212        
213        // 解析逗号
214        let _comma: Token![,] = input.parse()?;
215        
216        // 解析tables关键字
217        let _tables: Ident = input.parse()?;
218        
219        // 解析冒号
220        let _colon: Token![:] = input.parse()?;
221        
222        // 解析表列表
223        let content; 
224        syn::bracketed!(content in input);
225        
226        let mut tables = Vec::new();
227        while !content.is_empty() {
228            // 解析表名
229            let table = content.parse::<Ident>()?;
230            tables.push(table);
231            
232            // 如果还有逗号,解析它
233            if content.peek(Token![,]) {
234                content.parse::<Token![,]>()?;
235            }
236        }
237        
238        // 解析可选的low_power参数
239        let mut low_power = false;
240        let mut low_power_max_records = None;
241        let mut default_max_records = 100000; // 默认值
242        
243        // 检查是否还有更多参数
244        while !input.is_empty() {
245            // 解析逗号
246            let _comma: Token![,] = input.parse()?;
247            
248            // 解析参数名
249            let param_name = input.parse::<Ident>()?;
250            
251            // 解析冒号
252            let _colon: Token![:] = input.parse()?;
253            
254            if param_name == "low_power" {
255                // 解析布尔值
256                let lit_bool = input.parse::<syn::LitBool>()?;
257                low_power = lit_bool.value;
258            } else if param_name == "low_power_max_records" {
259                // 解析数字
260                let lit_int = input.parse::<syn::LitInt>()?;
261                low_power_max_records = Some(lit_int.base10_parse().unwrap_or(0));
262            } else if param_name == "default_max_records" {
263                // 解析数字
264                let lit_int = input.parse::<syn::LitInt>()?;
265                default_max_records = lit_int.base10_parse().unwrap_or(100000);
266            }
267        }
268        
269        Ok(Self {
270            name,
271            tables,
272            low_power,
273            low_power_max_records,
274            default_max_records,
275        })
276    }
277}
278
279#[proc_macro]
280pub fn table(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
281    // 解析输入参数
282    let args = parse_macro_input!(input as TableArgs);
283    let name = &args.name;
284    let max_records = &args.max_records;
285    let primary_key = &args.primary_key;
286    let secondary_index = &args.secondary_index;
287    let secondary_index_type = &args.secondary_index_type;
288    let fields = &args.fields;
289    
290    // 生成字段定义
291    let mut offset = 0;
292    let mut field_defs = Vec::new();
293    let mut record_size = 0;
294    let mut primary_key_index = 0usize;
295    let mut secondary_key_index: Option<usize> = None;
296    
297    for (i, field) in fields.iter().enumerate() {
298        let field_name = &field.name;
299        let type_name = &field.type_name;
300        let type_params = &field.type_params;
301        
302        // 确定数据类型和大小
303        let (data_type, size_val) = if type_name == "i32" {
304            (quote!(remdb::types::DataType::Int32), 4)
305        } else if type_name == "i8" {
306            (quote!(remdb::types::DataType::Int8), 1)
307        } else if type_name == "u64" {
308            (quote!(remdb::types::DataType::UInt64), 8)
309        } else if type_name == "f64" {
310            (quote!(remdb::types::DataType::Float64), 8)
311        } else if type_name == "bool" {
312            (quote!(remdb::types::DataType::Bool), 1)
313        } else if type_name == "str" {
314            // 处理str(32)这样的类型
315            let str_size = if let Some(params) = type_params {
316                params.base10_parse().unwrap_or(32)
317            } else {
318                32
319            };
320            (quote!(remdb::types::DataType::String), str_size)
321        } else {
322            (quote!(remdb::types::DataType::Int32), 4)
323        };
324        
325        // 计算对齐要求
326        let alignment = if type_name == "u64" || type_name == "f64" || type_name == "i64" {
327            8
328        } else if type_name == "i32" || type_name == "u32" || type_name == "f32" {
329            4
330        } else if type_name == "i16" || type_name == "u16" {
331            2
332        } else {
333            1
334        };
335        
336        // 调整偏移量以满足对齐要求
337        offset = ((offset + alignment - 1) / alignment) * alignment;
338        
339        // 确定约束字段值
340        let is_primary_key = field_name == primary_key;
341        let primary_key_val = is_primary_key;
342        let not_null_val = is_primary_key; // 主键字段默认为非空
343        let unique_val = is_primary_key;
344        
345        // 检查是否为自增主键:
346        // 1. 整数主键默认自增
347        // 2. 可以显式指定AUTOINCREMENT
348        let is_integer_type = type_name == "i32" || type_name == "i64" || type_name == "u32" || type_name == "u64";
349        let auto_increment_val = is_primary_key && is_integer_type;
350        
351        // 生成字段定义
352        let field_def = quote! {
353            remdb::types::FieldDef {
354                name: stringify!(#field_name),
355                data_type: #data_type,
356                size: #size_val as usize, // 确保是usize类型
357                offset: #offset as usize,  // 确保是usize类型
358                primary_key: #primary_key_val,
359                not_null: #not_null_val,
360                unique: #unique_val,
361                auto_increment: #auto_increment_val,
362            }
363        };
364        
365        field_defs.push(field_def);
366        
367        // 确定主键和二级索引的字段索引
368        if field_name == primary_key {
369            primary_key_index = i;
370        }
371        
372        if let Some(secondary_field) = secondary_index {
373            if field_name == secondary_field {
374                secondary_key_index = Some(i);
375            }
376        }
377        
378        // 更新偏移量和记录大小
379        offset += size_val;
380        record_size = offset;
381    }
382    
383    // 确保整个记录满足最大对齐要求(8字节对齐)
384    let max_alignment = 8;
385    record_size = ((record_size + max_alignment - 1) / max_alignment) * max_alignment;
386    
387    // 将max_records转换为usize
388    let max_records_usize = max_records.base10_parse::<usize>().unwrap_or(100);
389    
390    // 确定索引类型
391    let index_type = match secondary_index_type.as_ref() {
392        Some(ty) if ty == "btree" => quote!(remdb::types::IndexType::BTree),
393        Some(ty) if ty == "hash" => quote!(remdb::types::IndexType::Hash),
394        Some(ty) if ty == "ttree" => quote!(remdb::types::IndexType::TTree),
395        Some(ty) if ty == "sortedarray" => quote!(remdb::types::IndexType::SortedArray),
396        _ => quote!(remdb::types::IndexType::BTree),
397    };
398    
399    // 生成secondary_index代码
400    let secondary_index_code = match secondary_key_index {
401        Some(index) => quote! { Some(#index) },
402        None => quote! { None },
403    };
404    
405    // 生成代码:返回一个TableDef静态变量
406    let output = quote! {
407        #[allow(non_upper_case_globals)]
408        pub static #name: remdb::types::TableDef = remdb::types::TableDef {
409            id: 0,
410            name: stringify!(#name),
411            fields: &[#(#field_defs,)*],
412            primary_key: #primary_key_index as usize,
413            secondary_index: #secondary_index_code,
414            secondary_index_type: #index_type,
415            record_size: #record_size as usize,
416            max_records: #max_records_usize,
417        };
418    };
419    
420    output.into()
421}
422
423#[proc_macro]
424pub fn database(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
425    // 解析输入参数
426    let args = parse_macro_input!(input as DatabaseArgs);
427    let name = &args.name;
428    let tables = &args.tables;
429    let low_power = args.low_power;
430    let default_max_records = args.default_max_records;
431    
432    // 处理low_power_max_records,转换为Option<usize>
433    let low_power_max_records = match args.low_power_max_records {
434        Some(val) => quote! { Some(#val) },
435        None => quote! { None }
436    };
437    
438    // 生成代码:返回一个DbConfig静态变量
439    let output = quote! {
440        #[allow(non_upper_case_globals)]
441        pub static #name: remdb::config::DbConfig = remdb::config::DbConfig {
442            tables: &[#(#tables),*],
443            total_memory: 65536,
444            low_power_mode_supported: #low_power,
445            low_power_max_records: #low_power_max_records,
446            default_max_records: #default_max_records,
447            memory_allocator: unsafe {
448                // 使用默认的内存分配器实现,这里返回一个空指针的静态引用
449                static mut DEFAULT_ALLOCATOR: remdb::config::DefaultMemoryAllocator = remdb::config::DefaultMemoryAllocator;
450                &mut DEFAULT_ALLOCATOR
451            },
452        };
453    };
454    
455    output.into()
456}