remdb_macros/
lib.rs

1mod codegen;
2mod ddl_parser;
3
4use proc_macro::TokenStream;
5use syn::parse_macro_input;
6use quote::quote;
7
8#[proc_macro]
9pub fn define_schema(input: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(input as syn::LitStr);
11    let schema = input.value();
12    
13    match ddl_parser::parse_ddl(&schema) {
14        Ok(table_defs) => {
15            let output = codegen::generate_code(table_defs);
16            output.into()
17        },
18        Err(e) => {
19            panic!("Failed to parse DDL: {}", e);
20        }
21    }
22}
23
24#[proc_macro_derive(MemdbTable, attributes(memdb_schema))]
25pub fn derive_memdb_table(input: TokenStream) -> TokenStream {
26    let derive_input = parse_macro_input!(input as syn::DeriveInput);
27    
28    // 查找memdb_schema属性
29    let mut ddl = String::new();
30    
31    for attr in &derive_input.attrs {
32        if attr.path().is_ident("memdb_schema") {
33            // 使用正确的syn 2.0 API解析属性
34            attr.parse_nested_meta(|meta| {
35                if meta.path.is_ident("ddl") {
36                    let lit = meta.value()?;
37                    let lit_str = lit.parse::<syn::LitStr>()?;
38                    ddl = lit_str.value();
39                }
40                Ok(())
41            }).unwrap();
42        }
43    }
44    
45    if ddl.is_empty() {
46        panic!("memdb_schema attribute with ddl parameter is required");
47    }
48    
49    // 解析DDL并生成代码
50    match ddl_parser::parse_ddl(&ddl) {
51        Ok(table_defs) => {
52            let generated_code = codegen::generate_code(table_defs);
53            generated_code.into()
54        },
55        Err(e) => {
56            panic!("Failed to parse DDL: {}", e);
57        }
58    }
59}
60
61use syn::{LitInt, Ident, Token};
62use syn::parse::{Parse, ParseStream};
63
64// 字段定义
65struct Field {
66    name: Ident,
67    colon: Token![:],
68    // 自定义类型解析,支持 str(32) 这种语法
69    type_name: Ident,
70    type_params: Option<LitInt>,
71}
72
73impl Parse for Field {
74    fn parse(input: ParseStream) -> syn::Result<Self> {
75        let name = input.parse()?;
76        let colon = input.parse()?;
77        
78        // 解析类型名称
79        let type_name = input.parse()?;
80        
81        // 检查是否有括号参数,如 str(32)
82        let type_params = if input.peek(syn::token::Paren) {
83            let content; 
84            syn::parenthesized!(content in input);
85            let params = content.parse()?;
86            Some(params)
87        } else {
88            None
89        };
90        
91        Ok(Self {
92            name,
93            colon,
94            type_name,
95            type_params,
96        })
97    }
98}
99
100// 表定义结构
101struct TableArgs {
102    name: Ident,
103    max_records: LitInt,
104    primary_key: Ident,
105    secondary_index: Option<Ident>,
106    secondary_index_type: Option<Ident>,
107    fields: Vec<Field>,
108}
109
110impl Parse for TableArgs {
111    fn parse(input: ParseStream) -> syn::Result<Self> {
112        // 解析表名
113        let name = input.parse()?;
114        
115        // 解析逗号
116        let _comma1: Token![,] = input.parse()?;
117        
118        // 解析最大记录数
119        let max_records = input.parse()?;
120        
121        // 解析逗号
122        let _comma2: Token![,] = input.parse()?;
123        
124        // 解析primary_key
125        let primary_key_keyword: Ident = input.parse()?;
126        let _colon1: Token![:] = input.parse()?;
127        let primary_key = input.parse()?;
128        
129        // 解析secondary_index(可选)
130        let mut secondary_index = None;
131        let mut secondary_index_type = None;
132        
133        // 检查primary_key之后是否有逗号
134        if input.peek(Token![,]) {
135            let _comma3: Token![,] = input.parse()?;
136        }
137        
138        // 解析secondary_index、secondary_index_type和fields关键字
139        loop {
140            // 检查下一个标记
141            let next = input.lookahead1();
142            if next.peek(Ident) {
143                let param_name = input.parse::<Ident>()?;
144                if param_name == "secondary_index" {
145                    let _colon: Token![:] = input.parse()?;
146                    secondary_index = Some(input.parse()?);
147                    
148                    // 解析逗号
149                    if input.peek(Token![,]) {
150                        let _comma4: Token![,] = input.parse()?;
151                    }
152                } else if param_name == "secondary_index_type" {
153                    let _colon: Token![:] = input.parse()?;
154                    secondary_index_type = Some(input.parse()?);
155                    
156                    // 解析逗号
157                    if input.peek(Token![,]) {
158                        let _comma5: Token![,] = input.parse()?;
159                    }
160                } else if param_name == "fields" {
161                    let _colon_fields: Token![:] = input.parse()?;
162                    break;
163                } else {
164                    return Err(syn::Error::new(param_name.span(), format!("expected 'secondary_index', 'secondary_index_type' or 'fields' keyword, got '{}'", param_name)));
165                }
166            } else {
167                return Err(next.error());
168            }
169        }
170        
171        // 解析fields块
172        let content; 
173        syn::braced!(content in input);
174        
175        // 解析fields块内的内容
176        let mut fields = Vec::new();
177        while !content.is_empty() {
178            // 解析字段
179            let field = content.parse::<Field>()?;
180            fields.push(field);
181            
182            // 如果还有逗号,解析它
183            if content.peek(Token![,]) {
184                content.parse::<Token![,]>()?;
185            }
186        }
187        
188        Ok(Self {
189            name,
190            max_records,
191            primary_key,
192            secondary_index,
193            secondary_index_type,
194            fields,
195        })
196    }
197}
198
199// 数据库定义结构,解析数据库名和表列表
200struct DatabaseArgs {
201    name: Ident,
202    tables: Vec<Ident>,
203    low_power: bool,
204    low_power_max_records: Option<usize>,
205}
206
207impl Parse for DatabaseArgs {
208    fn parse(input: ParseStream) -> syn::Result<Self> {
209        // 解析数据库名
210        let name = input.parse()?;
211        
212        // 解析逗号
213        let _comma: Token![,] = input.parse()?;
214        
215        // 解析tables关键字
216        let _tables: Ident = input.parse()?;
217        
218        // 解析冒号
219        let _colon: Token![:] = input.parse()?;
220        
221        // 解析表列表
222        let content; 
223        syn::bracketed!(content in input);
224        
225        let mut tables = Vec::new();
226        while !content.is_empty() {
227            // 解析表名
228            let table = content.parse::<Ident>()?;
229            tables.push(table);
230            
231            // 如果还有逗号,解析它
232            if content.peek(Token![,]) {
233                content.parse::<Token![,]>()?;
234            }
235        }
236        
237        // 解析可选的low_power参数
238        let mut low_power = false;
239        let mut low_power_max_records = None;
240        
241        // 检查是否还有更多参数
242        while !input.is_empty() {
243            // 解析逗号
244            let _comma: Token![,] = input.parse()?;
245            
246            // 解析参数名
247            let param_name = input.parse::<Ident>()?;
248            
249            // 解析冒号
250            let _colon: Token![:] = input.parse()?;
251            
252            if param_name == "low_power" {
253                // 解析布尔值
254                let lit_bool = input.parse::<syn::LitBool>()?;
255                low_power = lit_bool.value;
256            } else if param_name == "low_power_max_records" {
257                // 解析数字
258                let lit_int = input.parse::<syn::LitInt>()?;
259                low_power_max_records = Some(lit_int.base10_parse().unwrap_or(0));
260            }
261        }
262        
263        Ok(Self {
264            name,
265            tables,
266            low_power,
267            low_power_max_records,
268        })
269    }
270}
271
272#[proc_macro]
273pub fn table(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
274    // 解析输入参数
275    let args = parse_macro_input!(input as TableArgs);
276    let name = &args.name;
277    let max_records = &args.max_records;
278    let primary_key = &args.primary_key;
279    let secondary_index = &args.secondary_index;
280    let secondary_index_type = &args.secondary_index_type;
281    let fields = &args.fields;
282    
283    // 生成字段定义
284    let mut offset = 0;
285    let mut field_defs = Vec::new();
286    let mut record_size = 0;
287    let mut primary_key_index = 0usize;
288    let mut secondary_key_index: Option<usize> = None;
289    
290    for (i, field) in fields.iter().enumerate() {
291        let field_name = &field.name;
292        let type_name = &field.type_name;
293        let type_params = &field.type_params;
294        
295        // 确定数据类型和大小
296        let (data_type, size_val) = if type_name == "i32" {
297            (quote!(remdb::types::DataType::Int32), 4)
298        } else if type_name == "i8" {
299            (quote!(remdb::types::DataType::Int8), 1)
300        } else if type_name == "u64" {
301            (quote!(remdb::types::DataType::UInt64), 8)
302        } else if type_name == "f64" {
303            (quote!(remdb::types::DataType::Float64), 8)
304        } else if type_name == "bool" {
305            (quote!(remdb::types::DataType::Bool), 1)
306        } else if type_name == "str" {
307            // 处理str(32)这样的类型
308            let str_size = if let Some(params) = type_params {
309                params.base10_parse().unwrap_or(32)
310            } else {
311                32
312            };
313            (quote!(remdb::types::DataType::String), str_size)
314        } else {
315            (quote!(remdb::types::DataType::Int32), 4)
316        };
317        
318        // 计算对齐要求
319        let alignment = if type_name == "u64" || type_name == "f64" || type_name == "i64" {
320            8
321        } else if type_name == "i32" || type_name == "u32" || type_name == "f32" {
322            4
323        } else if type_name == "i16" || type_name == "u16" {
324            2
325        } else {
326            1
327        };
328        
329        // 调整偏移量以满足对齐要求
330        offset = ((offset + alignment - 1) / alignment) * alignment;
331        
332        // 生成字段定义
333        let field_def = quote! {
334            remdb::types::FieldDef {
335                name: stringify!(#field_name),
336                data_type: #data_type,
337                size: #size_val as usize, // 确保是usize类型
338                offset: #offset as usize,  // 确保是usize类型
339            }
340        };
341        
342        field_defs.push(field_def);
343        
344        // 确定主键和二级索引的字段索引
345        if field_name == primary_key {
346            primary_key_index = i;
347        }
348        
349        if let Some(secondary_field) = secondary_index {
350            if field_name == secondary_field {
351                secondary_key_index = Some(i);
352            }
353        }
354        
355        // 更新偏移量和记录大小
356        offset += size_val;
357        record_size = offset;
358    }
359    
360    // 确保整个记录满足最大对齐要求(8字节对齐)
361    let max_alignment = 8;
362    record_size = ((record_size + max_alignment - 1) / max_alignment) * max_alignment;
363    
364    // 将max_records转换为usize
365    let max_records_usize = max_records.base10_parse::<usize>().unwrap_or(100);
366    
367    // 确定索引类型
368    let index_type = match secondary_index_type.as_ref() {
369        Some(ty) if ty == "btree" => quote!(remdb::types::IndexType::BTree),
370        Some(ty) if ty == "hash" => quote!(remdb::types::IndexType::Hash),
371        Some(ty) if ty == "ttree" => quote!(remdb::types::IndexType::TTree),
372        Some(ty) if ty == "sortedarray" => quote!(remdb::types::IndexType::SortedArray),
373        _ => quote!(remdb::types::IndexType::BTree),
374    };
375    
376    // 生成secondary_index代码
377    let secondary_index_code = match secondary_key_index {
378        Some(index) => quote! { Some(#index) },
379        None => quote! { None },
380    };
381    
382    // 生成代码:返回一个TableDef静态变量
383    let output = quote! {
384        #[allow(non_upper_case_globals)]
385        pub static #name: remdb::types::TableDef = remdb::types::TableDef {
386            id: 0,
387            name: stringify!(#name),
388            fields: &[#(#field_defs,)*],
389            primary_key: #primary_key_index as usize,
390            secondary_index: #secondary_index_code,
391            secondary_index_type: #index_type,
392            record_size: #record_size as usize,
393            max_records: #max_records_usize,
394        };
395    };
396    
397    output.into()
398}
399
400#[proc_macro]
401pub fn database(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
402    // 解析输入参数
403    let args = parse_macro_input!(input as DatabaseArgs);
404    let name = &args.name;
405    let tables = &args.tables;
406    let low_power = args.low_power;
407    
408    // 处理low_power_max_records,转换为Option<usize>
409    let low_power_max_records = match args.low_power_max_records {
410        Some(val) => quote! { Some(#val) },
411        None => quote! { None }
412    };
413    
414    // 生成代码:返回一个DbConfig静态变量
415    let output = quote! {
416        #[allow(non_upper_case_globals)]
417        pub static #name: remdb::config::DbConfig = remdb::config::DbConfig {
418            tables: &[#(#tables),*],
419            total_memory: 65536,
420            low_power_mode_supported: #low_power,
421            low_power_max_records: #low_power_max_records,
422            memory_allocator: unsafe {
423                // 使用默认的内存分配器实现,这里返回一个空指针的静态引用
424                static mut DEFAULT_ALLOCATOR: remdb::config::DefaultMemoryAllocator = remdb::config::DefaultMemoryAllocator;
425                &mut DEFAULT_ALLOCATOR
426            },
427        };
428    };
429    
430    output.into()
431}