remdb_macros/
lib.rs

1mod codegen;
2mod ddl_parser;
3
4use proc_macro::TokenStream;
5use syn::parse_macro_input;
6use quote::quote;
7
8#[proc_macro]
9pub fn define_schema(input: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(input as syn::LitStr);
11    let schema = input.value();
12    
13    match ddl_parser::parse_ddl(&schema) {
14        Ok(table_defs) => {
15            codegen::generate_code(table_defs)
16        },
17        Err(e) => {
18            panic!("Failed to parse DDL: {}", e);
19        }
20    }
21}
22
23#[proc_macro_derive(MemdbTable, attributes(memdb_schema))]
24pub fn derive_memdb_table(input: TokenStream) -> TokenStream {
25    let derive_input = parse_macro_input!(input as syn::DeriveInput);
26    
27    // 查找memdb_schema属性
28    let mut ddl = String::new();
29    
30    for attr in &derive_input.attrs {
31        if attr.path().is_ident("memdb_schema") {
32            // 使用正确的syn 2.0 API解析属性
33            attr.parse_nested_meta(|meta| {
34                if meta.path.is_ident("ddl") {
35                    let lit = meta.value()?;
36                    let lit_str = lit.parse::<syn::LitStr>()?;
37                    ddl = lit_str.value();
38                }
39                Ok(())
40            }).unwrap();
41        }
42    }
43    
44    if ddl.is_empty() {
45        panic!("memdb_schema attribute with ddl parameter is required");
46    }
47    
48    // 解析DDL并生成代码
49    match ddl_parser::parse_ddl(&ddl) {
50        Ok(table_defs) => {
51            codegen::generate_code(table_defs)
52        },
53        Err(e) => {
54            panic!("Failed to parse DDL: {}", e);
55        }
56    }
57}
58
59use syn::{LitInt, Ident, Token};
60use syn::parse::{Parse, ParseStream};
61
62// 字段定义
63struct Field {
64    name: Ident,
65    #[allow(dead_code)]
66    colon: Token![:],
67    // 自定义类型解析,支持 str(32) 这种语法
68    type_name: Ident,
69    type_params: Option<LitInt>,
70}
71
72impl Parse for Field {
73    fn parse(input: ParseStream) -> syn::Result<Self> {
74        let name = input.parse()?;
75        let colon = input.parse()?;
76        
77        // 解析类型名称
78        let type_name = input.parse()?;
79        
80        // 检查是否有括号参数,如 str(32)
81        let type_params = if input.peek(syn::token::Paren) {
82            let content; 
83            syn::parenthesized!(content in input);
84            let params = content.parse()?;
85            Some(params)
86        } else {
87            None
88        };
89        
90        Ok(Self {
91            name,
92            colon,
93            type_name,
94            type_params,
95        })
96    }
97}
98
99// 表定义结构
100struct TableArgs {
101    name: Ident,
102    max_records: LitInt,
103    primary_key: Ident,
104    secondary_index: Option<Ident>,
105    secondary_index_type: Option<Ident>,
106    fields: Vec<Field>,
107}
108
109impl Parse for TableArgs {
110    fn parse(input: ParseStream) -> syn::Result<Self> {
111        // 解析表名
112        let name = input.parse()?;
113        
114        // 解析逗号
115        let _comma1: Token![,] = input.parse()?;
116        
117        // 解析最大记录数
118        let max_records = input.parse()?;
119        
120        // 解析逗号
121        let _comma2: Token![,] = input.parse()?;
122        
123        // 解析primary_key
124        let _primary_key_keyword: Ident = input.parse()?;
125        let _colon1: Token![:] = input.parse()?;
126        let primary_key = input.parse()?;
127        
128        // 解析secondary_index(可选)
129        let mut secondary_index = None;
130        let mut secondary_index_type = None;
131        
132        // 检查primary_key之后是否有逗号
133        if input.peek(Token![,]) {
134            let _comma3: Token![,] = input.parse()?;
135        }
136        
137        // 解析secondary_index、secondary_index_type和fields关键字
138        loop {
139            // 检查下一个标记
140            let next = input.lookahead1();
141            if next.peek(Ident) {
142                let param_name = input.parse::<Ident>()?;
143                if param_name == "secondary_index" {
144                    let _colon: Token![:] = input.parse()?;
145                    secondary_index = Some(input.parse()?);
146                    
147                    // 解析逗号
148                    if input.peek(Token![,]) {
149                        let _comma4: Token![,] = input.parse()?;
150                    }
151                } else if param_name == "secondary_index_type" {
152                    let _colon: Token![:] = input.parse()?;
153                    secondary_index_type = Some(input.parse()?);
154                    
155                    // 解析逗号
156                    if input.peek(Token![,]) {
157                        let _comma5: Token![,] = input.parse()?;
158                    }
159                } else if param_name == "fields" {
160                    let _colon_fields: Token![:] = input.parse()?;
161                    break;
162                } else {
163                    return Err(syn::Error::new(param_name.span(), format!("expected 'secondary_index', 'secondary_index_type' or 'fields' keyword, got '{}'", param_name)));
164                }
165            } else {
166                return Err(next.error());
167            }
168        }
169        
170        // 解析fields块
171        let content; 
172        syn::braced!(content in input);
173        
174        // 解析fields块内的内容
175        let mut fields = Vec::new();
176        while !content.is_empty() {
177            // 解析字段
178            let field = content.parse::<Field>()?;
179            fields.push(field);
180            
181            // 如果还有逗号,解析它
182            if content.peek(Token![,]) {
183                content.parse::<Token![,]>()?;
184            }
185        }
186        
187        Ok(Self {
188            name,
189            max_records,
190            primary_key,
191            secondary_index,
192            secondary_index_type,
193            fields,
194        })
195    }
196}
197
198// 数据库定义结构,解析数据库名和表列表
199struct DatabaseArgs {
200    name: Ident,
201    tables: Vec<Ident>,
202    low_power: bool,
203    low_power_max_records: Option<usize>,
204    default_max_records: usize,
205}
206
207impl Parse for DatabaseArgs {
208    fn parse(input: ParseStream) -> syn::Result<Self> {
209        // 解析数据库名
210        let name = input.parse()?;
211        
212        // 解析逗号
213        let _comma: Token![,] = input.parse()?;
214        
215        // 解析tables关键字
216        let _tables: Ident = input.parse()?;
217        
218        // 解析冒号
219        let _colon: Token![:] = input.parse()?;
220        
221        // 解析表列表
222        let content; 
223        syn::bracketed!(content in input);
224        
225        let mut tables = Vec::new();
226        while !content.is_empty() {
227            // 解析表名
228            let table = content.parse::<Ident>()?;
229            tables.push(table);
230            
231            // 如果还有逗号,解析它
232            if content.peek(Token![,]) {
233                content.parse::<Token![,]>()?;
234            }
235        }
236        
237        // 解析可选的low_power参数
238        let mut low_power = false;
239        let mut low_power_max_records = None;
240        let mut default_max_records = 100000; // 默认值
241        
242        // 检查是否还有更多参数
243        while !input.is_empty() {
244            // 解析逗号
245            let _comma: Token![,] = input.parse()?;
246            
247            // 解析参数名
248            let param_name = input.parse::<Ident>()?;
249            
250            // 解析冒号
251            let _colon: Token![:] = input.parse()?;
252            
253            if param_name == "low_power" {
254                // 解析布尔值
255                let lit_bool = input.parse::<syn::LitBool>()?;
256                low_power = lit_bool.value;
257            } else if param_name == "low_power_max_records" {
258                // 解析数字
259                let lit_int = input.parse::<syn::LitInt>()?;
260                low_power_max_records = Some(lit_int.base10_parse().unwrap_or(0));
261            } else if param_name == "default_max_records" {
262                // 解析数字
263                let lit_int = input.parse::<syn::LitInt>()?;
264                default_max_records = lit_int.base10_parse().unwrap_or(100000);
265            }
266        }
267        
268        Ok(Self {
269            name,
270            tables,
271            low_power,
272            low_power_max_records,
273            default_max_records,
274        })
275    }
276}
277
278#[proc_macro]
279pub fn table(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
280    // 解析输入参数
281    let args = parse_macro_input!(input as TableArgs);
282    let name = &args.name;
283    let max_records = &args.max_records;
284    let primary_key = &args.primary_key;
285    let secondary_index = &args.secondary_index;
286    let secondary_index_type = &args.secondary_index_type;
287    let fields = &args.fields;
288    
289    // 生成字段定义
290    let mut offset = 0;
291    let mut field_defs = Vec::new();
292    let mut record_size = 0;
293    let mut primary_key_index = 0usize;
294    let mut secondary_key_index: Option<usize> = None;
295    
296    for (i, field) in fields.iter().enumerate() {
297        let field_name = &field.name;
298        let type_name = &field.type_name;
299        let type_params = &field.type_params;
300        
301        // 确定数据类型和大小
302        let (data_type, size_val) = if type_name == "i32" {
303            (quote!(remdb::types::DataType::Int32), 4)
304        } else if type_name == "i8" {
305            (quote!(remdb::types::DataType::Int8), 1)
306        } else if type_name == "u64" {
307            (quote!(remdb::types::DataType::UInt64), 8)
308        } else if type_name == "f64" {
309            (quote!(remdb::types::DataType::Float64), 8)
310        } else if type_name == "bool" {
311            (quote!(remdb::types::DataType::Bool), 1)
312        } else if type_name == "str" {
313            // 处理str(32)这样的类型
314            let str_size = if let Some(params) = type_params {
315                params.base10_parse().unwrap_or(32)
316            } else {
317                32
318            };
319            (quote!(remdb::types::DataType::String), str_size)
320        } else {
321            (quote!(remdb::types::DataType::Int32), 4)
322        };
323        
324        // 计算对齐要求
325        let alignment = if type_name == "u64" || type_name == "f64" || type_name == "i64" {
326            8
327        } else if type_name == "i32" || type_name == "u32" || type_name == "f32" {
328            4
329        } else if type_name == "i16" || type_name == "u16" {
330            2
331        } else {
332            1
333        };
334        
335        // 调整偏移量以满足对齐要求
336        offset = ((offset + alignment - 1) / alignment) * alignment;
337        
338        // 确定约束字段值
339        let is_primary_key = field_name == primary_key;
340        let primary_key_val = is_primary_key;
341        let not_null_val = is_primary_key; // 主键字段默认为非空
342        let unique_val = is_primary_key;
343        
344        // 检查是否为自增主键:
345        // 1. 整数主键默认自增
346        // 2. 可以显式指定AUTOINCREMENT
347        let is_integer_type = type_name == "i32" || type_name == "i64" || type_name == "u32" || type_name == "u64";
348        let auto_increment_val = is_primary_key && is_integer_type;
349        
350        // 生成字段定义
351        let field_def = quote! {
352            remdb::types::FieldDef {
353                name: stringify!(#field_name),
354                data_type: #data_type,
355                size: #size_val as usize, // 确保是usize类型
356                offset: #offset as usize,  // 确保是usize类型
357                primary_key: #primary_key_val,
358                not_null: #not_null_val,
359                unique: #unique_val,
360                auto_increment: #auto_increment_val,
361                default_value: None,
362            }
363        };
364        
365        field_defs.push(field_def);
366        
367        // 确定主键和二级索引的字段索引
368        if field_name == primary_key {
369            primary_key_index = i;
370        }
371        
372        if let Some(secondary_field) = secondary_index {
373            if field_name == secondary_field {
374                secondary_key_index = Some(i);
375            }
376        }
377        
378        // 更新偏移量和记录大小
379        offset += size_val;
380        record_size = offset;
381    }
382    
383    // 确保整个记录满足最大对齐要求(8字节对齐)
384    let max_alignment = 8;
385    record_size = ((record_size + max_alignment - 1) / max_alignment) * max_alignment;
386    
387    // 将max_records转换为usize
388    let max_records_usize = max_records.base10_parse::<usize>().unwrap_or(100);
389    
390    // 确定索引类型
391    let index_type = match secondary_index_type.as_ref() {
392        Some(ty) if ty == "btree" => quote!(remdb::types::IndexType::BTree),
393        Some(ty) if ty == "hash" => quote!(remdb::types::IndexType::Hash),
394        Some(ty) if ty == "ttree" => quote!(remdb::types::IndexType::TTree),
395        Some(ty) if ty == "sortedarray" => quote!(remdb::types::IndexType::SortedArray),
396        _ => quote!(remdb::types::IndexType::BTree),
397    };
398    
399    // 生成secondary_index代码
400    let secondary_index_code = match secondary_key_index {
401        Some(index) => quote! { Some(#index) },
402        None => quote! { None },
403    };
404    
405    // 生成代码:返回一个TableDef静态变量
406    let output = quote! {
407        #[allow(non_upper_case_globals)]
408        pub static #name: remdb::types::TableDef = remdb::types::TableDef {
409            id: 0,
410            name: stringify!(#name),
411            fields: &[#(#field_defs,)*],
412            primary_key: #primary_key_index as usize,
413            secondary_index: #secondary_index_code,
414            secondary_index_type: #index_type,
415            record_size: #record_size as usize,
416            max_records: #max_records_usize,
417        };
418    };
419    
420    output.into()
421}
422
423#[proc_macro]
424pub fn database(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
425    // 解析输入参数
426    let args = parse_macro_input!(input as DatabaseArgs);
427    let name = &args.name;
428    let tables = &args.tables;
429    let low_power = args.low_power;
430    let default_max_records = args.default_max_records;
431    
432    // 处理low_power_max_records,转换为Option<usize>
433    let low_power_max_records = match args.low_power_max_records {
434        Some(val) => quote! { Some(#val) },
435        None => quote! { None }
436    };
437    
438    // 生成代码:返回一个DbConfig静态变量
439    let output = quote! {
440        #[allow(non_upper_case_globals)]
441        pub static #name: remdb::config::DbConfig = remdb::config::DbConfig {
442            tables: &[#(#tables),*],
443            total_memory: 65536,
444            low_power_mode_supported: #low_power,
445            low_power_max_records: #low_power_max_records,
446            default_max_records: #default_max_records,
447            memory_allocator: unsafe {
448                // 使用默认的内存分配器实现,这里返回一个空指针的静态引用
449                static mut DEFAULT_ALLOCATOR: remdb::config::DefaultMemoryAllocator = remdb::config::DefaultMemoryAllocator;
450                &mut DEFAULT_ALLOCATOR
451            },
452            // 日志相关配置
453            wal_config: remdb::config::WALConfig {
454                log_path: "remdb.wal",
455                log_mode: remdb::config::LogMode::Sync,
456                checkpoint_interval_ms: 60000, // 默认60秒
457                log_file_size_limit: 16 * 1024 * 1024, // 默认16MB
458                log_prealloc_size: 1 * 1024 * 1024, // 默认1MB预分配
459                log_segment_size: 16 * 1024 * 1024, // 默认16MB分段
460                retained_checkpoints: 3, // 保留3个检查点
461            },
462            // 时序数据默认配置
463            time_series_defaults: remdb::time_series::TimeSeriesConfig::DEFAULT,
464            // PubSub配置(可选)
465            #[cfg(feature = "pubsub")]
466            pubsub_config: None,
467            // HA相关配置(可选)
468            #[cfg(feature = "ha")]
469            ha_config: Some(remdb::ha::HAConfig {
470                node_id: 1, // 默认节点ID为1
471                ha_role: remdb::ha::HARole::Auto,
472                replication_mode: remdb::ha::ReplicationMode::Async,
473                heartbeat_interval_ms: 1000, // 默认1秒
474                failure_detection_ms: 3000, // 默认3秒
475                sync_timeout_ms: 2000, // 默认2秒
476                master_address: None,
477                master_port: None,
478                replication_port: 5556,
479            }),
480        };
481    };
482    
483    output.into()
484}