Skip to main content

limbo_macros/
lib.rs

1// UPSTREAM: vendored Limbo fork — allow upstream style
2#![allow(
3    rustdoc::bare_urls,
4    rustdoc::invalid_html_tags,
5    rustdoc::invalid_rust_codeblocks
6)]
7#![allow(clippy::collapsible_match)]
8
9mod ext;
10extern crate proc_macro;
11use proc_macro::{token_stream::IntoIter, Group, TokenStream, TokenTree};
12use std::collections::HashMap;
13
14/// A procedural macro that derives a `Description` trait for enums.
15/// This macro extracts documentation comments (specified with `/// Description...`) for enum variants
16/// and generates an implementation for `get_description`, which returns the associated description.
17#[proc_macro_derive(Description, attributes(desc))]
18pub fn derive_description_from_doc(item: TokenStream) -> TokenStream {
19    // Convert the TokenStream into an iterator of TokenTree
20    let mut tokens = item.into_iter();
21
22    let mut enum_name = String::new();
23
24    // Vector to store enum variants and their associated payloads (if any)
25    let mut enum_variants: Vec<(String, Option<String>)> = Vec::<(String, Option<String>)>::new();
26
27    // HashMap to store descriptions associated with each enum variant
28    let mut variant_description_map: HashMap<String, String> = HashMap::new();
29
30    // Parses the token stream to extract the enum name and its variants
31    while let Some(token) = tokens.next() {
32        match token {
33            TokenTree::Ident(ident) if ident.to_string() == "enum" => {
34                // Get the enum name
35                if let Some(TokenTree::Ident(name)) = tokens.next() {
36                    enum_name = name.to_string();
37                }
38            }
39            TokenTree::Group(group) => {
40                let mut group_tokens_iter: IntoIter = group.stream().into_iter();
41
42                let mut last_seen_desc: Option<String> = None;
43                while let Some(token) = group_tokens_iter.next() {
44                    match token {
45                        TokenTree::Punct(punct) => {
46                            if punct.to_string() == "#" {
47                                last_seen_desc = process_description(&mut group_tokens_iter);
48                            }
49                        }
50                        TokenTree::Ident(ident) => {
51                            // Capture the enum variant name and associate it with its description
52                            let ident_str = ident.to_string();
53                            if let Some(desc) = &last_seen_desc {
54                                variant_description_map.insert(ident_str.clone(), desc.clone());
55                            }
56                            enum_variants.push((ident_str, None));
57                            last_seen_desc = None;
58                        }
59                        TokenTree::Group(group) => {
60                            // Capture payload information for the current enum variant
61                            if let Some(last_variant) = enum_variants.last_mut() {
62                                last_variant.1 = Some(process_payload(group));
63                            }
64                        }
65                        _ => {}
66                    }
67                }
68            }
69            _ => {}
70        }
71    }
72    generate_get_description(enum_name, &variant_description_map, enum_variants)
73}
74
/// Extracts the description string from a variant attribute.
///
/// Expects to be called right after a `#` token has been consumed. It consumes the
/// following bracketed group and returns the attribute's string literal when the
/// attribute is one of:
///
/// * `#[doc = "..."]` (what a `/// ...` doc comment is turned into),
/// * `#[desc = "..."]`,
/// * `#[desc("...")]`.
///
/// The literal is returned in its source form, quotes included, so it can be pasted
/// into generated code verbatim. Any other attribute (e.g. `#[allow(...)]`,
/// `#[doc(hidden)]`, `#[deprecated = "..."]`) yields `None`.
fn process_description(token_iter: &mut IntoIter) -> Option<String> {
    let attr_group = match token_iter.next() {
        Some(TokenTree::Group(group)) => group,
        _ => return None,
    };
    let mut attr_tokens = attr_group.stream().into_iter();

    let attr_name = match attr_tokens.next() {
        Some(TokenTree::Ident(ident)) => ident.to_string(),
        _ => return None,
    };

    let literal = match (attr_name.as_str(), attr_tokens.next()) {
        // `doc = "..."` / `desc = "..."`
        ("doc" | "desc", Some(TokenTree::Punct(punct))) if punct.as_char() == '=' => {
            attr_tokens.next()
        }
        // `desc("...")`
        ("desc", Some(TokenTree::Group(args)))
            if args.delimiter() == proc_macro::Delimiter::Parenthesis =>
        {
            args.stream().into_iter().next()
        }
        _ => None,
    };

    match literal {
        Some(TokenTree::Literal(description)) => Some(description.to_string()),
        _ => None,
    }
}
88
/// Builds the pattern suffix that matches an enum variant's payload while ignoring it.
///
/// Returns `(..)` for tuple-like variants (`Variant(A, B)`) and `{ .. }` for
/// struct-like variants (`Variant { a: A, b: B }`). The generated `get_description`
/// never reads the fields, so nothing needs to be bound. Using the rest pattern also
/// keeps tuple variants matchable, and keeps field types that contain commas
/// (e.g. `HashMap<K, V>`) from being mis-read as extra field names.
fn process_payload(payload_group: Group) -> String {
    match payload_group.delimiter() {
        proc_macro::Delimiter::Parenthesis => "(..)".to_string(),
        _ => "{ .. }".to_string(),
    }
}
/// Generates the `get_description` implementation for the processed enum.
///
/// * `enum_name` — the enum the `impl` block is generated for.
/// * `variant_description_map` — variant name → description. The description is a
///   Rust string literal (quotes included) that is pasted verbatim into `Some(...)`.
/// * `enum_variants` — every variant in declaration order, paired with the pattern
///   suffix needed to match its payload, or `None` for unit variants.
///
/// # Panics
///
/// Panics if the generated source does not parse as Rust. That would indicate a bug
/// in how the derive input was parsed, not a user error.
fn generate_get_description(
    enum_name: String,
    variant_description_map: &HashMap<String, String>,
    enum_variants: Vec<(String, Option<String>)>,
) -> TokenStream {
    render_get_description(&enum_name, variant_description_map, &enum_variants)
        .parse()
        .expect("generated enum impl should be valid Rust token stream")
}

/// Renders the source text of the `get_description` impl.
///
/// This is kept free of `proc_macro` types so it can be exercised outside a macro
/// expansion.
fn render_get_description(
    enum_name: &str,
    variant_description_map: &HashMap<String, String>,
    enum_variants: &[(String, Option<String>)],
) -> String {
    let arms: String = enum_variants
        .iter()
        .map(|(variant, payload)| {
            let pattern = match payload {
                Some(payload) => format!("{}::{} {}", enum_name, variant, payload),
                None => format!("{}::{}", enum_name, variant),
            };
            let desc = match variant_description_map.get(variant) {
                Some(description) => format!("Some({})", description),
                None => "None".to_string(),
            };
            format!("            {} => {},\n", pattern, desc)
        })
        .collect();

    // A reference is never uninhabited, so `match self {}` on `&Empty` is rejected.
    // An enum without variants must match on the dereferenced value instead.
    let scrutinee = if enum_variants.is_empty() { "*self" } else { "self" };

    format!(
        "impl {} {{\n    pub fn get_description(&self) -> Option<&str> {{\n        match {} {{\n{}        }}\n    }}\n}}\n",
        enum_name, scrutinee, arms
    )
}
147
148/// Register your extension with 'core' by providing the relevant functions
149///```ignore
150///use limbo_ext::{register_extension, scalar, Value, AggregateDerive, AggFunc};
151///
152/// register_extension!{ scalars: { return_one }, aggregates: { SumPlusOne } }
153///
154///#[scalar(name = "one")]
155///fn return_one(args: &[Value]) -> Value {
156///  return Value::from_integer(1);
157///}
158///
159///#[derive(AggregateDerive)]
160///struct SumPlusOne;
161///
162///impl AggFunc for SumPlusOne {
163///   type State = i64;
164///   const NAME: &'static str = "sum_plus_one";
165///   const ARGS: i32 = 1;
166///
167///   fn step(state: &mut Self::State, args: &[Value]) {
168///      let Some(val) = args[0].to_integer() else {
169///        return;
170///      };
171///      *state += val;
172///     }
173///
174///     fn finalize(state: Self::State) -> Value {
175///        Value::from_integer(state + 1)
176///     }
177///}
178///
179/// ```
180#[proc_macro]
181pub fn register_extension(input: TokenStream) -> TokenStream {
182    ext::register_extension(input)
183}
184
185/// Declare a scalar function for your extension. This requires the name:
186/// #[scalar(name = "example")] of what you wish to call your function with.
187/// ```text
188/// use limbo_ext::{scalar, Value};
189/// #[scalar(name = "double", alias = "twice")] // you can provide an <optional> alias
190/// fn double(args: &[Value]) -> Value {
191///       let arg = args.get(0).unwrap();
192///       match arg.value_type() {
193///           ValueType::Float => {
194///               let val = arg.to_float().unwrap();
195///               Value::from_float(val * 2.0)
196///           }
197///           ValueType::Integer => {
198///               let val = arg.to_integer().unwrap();
199///               Value::from_integer(val * 2)
200///           }
201///       }
202///   } else {
203///       Value::null()
204///   }
205/// }
206/// ```
207#[proc_macro_attribute]
208pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream {
209    ext::scalar(attr, input)
210}
211
212/// Define an aggregate function for your extension by deriving
213/// AggregateDerive on a struct that implements the AggFunc trait.
214/// ```ignore
215/// use limbo_ext::{register_extension, Value, AggregateDerive, AggFunc};
216///
217///#[derive(AggregateDerive)]
218///struct SumPlusOne;
219///
220///impl AggFunc for SumPlusOne {
221///   type State = i64;
222///   type Error = &'static str;
223///   const NAME: &'static str = "sum_plus_one";
224///   const ARGS: i32 = 1;
225///   fn step(state: &mut Self::State, args: &[Value]) {
226///      let Some(val) = args[0].to_integer() else {
227///        return;
228///     };
229///     *state += val;
230///     }
231///     fn finalize(state: Self::State) -> Result<Value, Self::Error> {
232///        Ok(Value::from_integer(state + 1))
233///     }
234///}
235/// ```
236#[proc_macro_derive(AggregateDerive)]
237pub fn derive_agg_func(input: TokenStream) -> TokenStream {
238    ext::derive_agg_func(input)
239}
240
241/// Macro to derive a VTabModule for your extension. This macro will generate
242/// the necessary functions to register your module with core. You must implement
243/// the VTabModule, VTable, and VTabCursor traits.
244/// ```ignore
245/// #[derive(Debug, VTabModuleDerive)]
246/// struct CsvVTabModule;
247///
248/// impl VTabModule for CsvVTabModule {
249///  type Table = CsvTable;
250///  const NAME: &'static str = "csv_data";
251///  const VTAB_KIND: VTabKind = VTabKind::VirtualTable;
252///
253///   /// Declare your virtual table and its schema
254///  fn create(args: &[Value]) -> Result<(String, Self::Table), ResultCode> {
255///     let schema = "CREATE TABLE csv_data(
256///             name TEXT,
257///             age TEXT,
258///             city TEXT
259///         )".into();
260///     Ok((schema, CsvTable {}))
261///  }
262/// }
263///
264/// struct CsvTable {}
265///
266/// // Implement the VTable trait for your virtual table
267/// impl VTable for CsvTable {
268///  type Cursor = CsvCursor;
269///  type Error = &'static str;
270///
271///  /// Open the virtual table and return a cursor
272///  fn open(&self) -> Result<Self::Cursor, Self::Error> {
273///     let csv_content = fs::read_to_string("data.csv").unwrap_or_default();
274///     let rows: Vec<Vec<String>> = csv_content
275///         .lines()
276///         .skip(1)
277///         .map(|line| {
278///             line.split(',')
279///                 .map(|s| s.trim().to_string())
280///                 .collect()
281///         })
282///         .collect();
283///     Ok(CsvCursor { rows, index: 0 })
284///  }
285///
286/// /// **Optional** methods for non-readonly tables:
287///
288///  /// Update the row with the provided values, return the new rowid
289///  fn update(&mut self, rowid: i64, args: &[Value]) -> Result<Option<i64>, Self::Error> {
290///      Ok(None)// return Ok(None) for read-only
291///  }
292///
293///  /// Insert a new row with the provided values, return the new rowid
294///  fn insert(&mut self, args: &[Value]) -> Result<(), Self::Error> {
295///      Ok(()) //
296///  }
297///
298///  /// Delete the row with the provided rowid
299///  fn delete(&mut self, rowid: i64) -> Result<(), Self::Error> {
300///    Ok(())
301///  }
302///
303///  /// Destroy the virtual table. Any cleanup logic for when the table is deleted comes heres
304///  fn destroy(&mut self) -> Result<(), Self::Error> {
305///     Ok(())
306///  }
307/// }
308///
309///  #[derive(Debug)]
310/// struct CsvCursor {
311///   rows: Vec<Vec<String>>,
312///   index: usize,
313/// }
314///
315/// impl CsvCursor {
316///   /// Returns the value for a given column index.
317///   fn column(&self, idx: u32) -> Result<Value, Self::Error> {
318///       let row = &self.rows[self.index];
319///       if (idx as usize) < row.len() {
320///           Value::from_text(&row[idx as usize])
321///       } else {
322///           Value::null()
323///       }
324///   }
325/// }
326///
327/// // Implement the VTabCursor trait for your virtual cursor
328/// impl VTabCursor for CsvCursor {
329///  type Error = &'static str;
330///
331///  /// Filter the virtual table based on arguments (omitted here for simplicity)
332///  fn filter(&mut self, _args: &[Value], _idx_info: Option<(&str, i32)>) -> ResultCode {
333///      ResultCode::OK
334///  }
335///
336///  /// Move the cursor to the next row
337///  fn next(&mut self) -> ResultCode {
338///     if self.index < self.rows.len() - 1 {
339///         self.index += 1;
340///         ResultCode::OK
341///     } else {
342///         ResultCode::EOF
343///     }
344///  }
345///
346///  fn eof(&self) -> bool {
347///      self.index >= self.rows.len()
348///  }
349///
350///  /// Return the value for a given column index
351///  fn column(&self, idx: u32) -> Result<Value, Self::Error> {
352///      self.column(idx)
353///  }
354///
355///  fn rowid(&self) -> i64 {
356///      self.index as i64
357///  }
358/// }
359///
360#[proc_macro_derive(VTabModuleDerive)]
361pub fn derive_vtab_module(input: TokenStream) -> TokenStream {
362    ext::derive_vtab_module(input)
363}
364
365/// ```text
366/// use limbo_ext::{ExtResult as Result, VfsDerive, VfsExtension, VfsFile};
367///
368/// // Your struct must also impl Default
369/// #[derive(VfsDerive, Default)]
370/// struct ExampleFS;
371///
372///
373/// struct ExampleFile {
374///    file: std::fs::File,
375///
376///
377/// impl VfsExtension for ExampleFS {
378///    /// The name of your vfs module
379///    const NAME: &'static str = "example";
380///
381///    type File = ExampleFile;
382///
383///    fn open(&self, path: &str, flags: i32, _direct: bool) -> Result<Self::File> {
384///        let file = OpenOptions::new()
385///            .read(true)
386///            .write(true)
387///            .create(flags & 1 != 0)
388///            .open(path)
389///            .map_err(|_| ResultCode::Error)?;
390///        Ok(TestFile { file })
391///    }
392///
393///    fn run_once(&self) -> Result<()> {
394///    // (optional) method to cycle/advance IO, if your extension is asynchronous
395///        Ok(())
396///    }
397///
398///    fn close(&self, file: Self::File) -> Result<()> {
399///    // (optional) method to close or drop the file
400///        Ok(())
401///    }
402///
403///    fn generate_random_number(&self) -> i64 {
404///    // (optional) method to generate random number. Used for testing
405///        let mut buf = [0u8; 8];
406///        getrandom::fill(&mut buf).unwrap();
407///        i64::from_ne_bytes(buf)
408///    }
409///
410///   fn get_current_time(&self) -> String {
411///    // (optional) method to generate random number. Used for testing
412///        chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string()
413///    }
414///
415///
416/// impl VfsFile for ExampleFile {
417///    fn read(
418///        &mut self,
419///        buf: &mut [u8],
420///        count: usize,
421///        offset: i64,
422///    ) -> Result<i32> {
423///        if file.file.seek(SeekFrom::Start(offset as u64)).is_err() {
424///            return Err(ResultCode::Error);
425///        }
426///        file.file
427///            .read(&mut buf[..count])
428///            .map_err(|_| ResultCode::Error)
429///            .map(|n| n as i32)
430///    }
431///
432///    fn write(&mut self, buf: &[u8], count: usize, offset: i64) -> Result<i32> {
433///        if self.file.seek(SeekFrom::Start(offset as u64)).is_err() {
434///            return Err(ResultCode::Error);
435///        }
436///        self.file
437///            .write(&buf[..count])
438///            .map_err(|_| ResultCode::Error)
439///            .map(|n| n as i32)
440///    }
441///
442///    fn sync(&self) -> Result<()> {
443///        self.file.sync_all().map_err(|_| ResultCode::Error)
444///    }
445///
446///    fn size(&self) -> i64 {
447///      self.file.metadata().map(|m| m.len() as i64).unwrap_or(-1)
448///   }
449///}
450///
451///```
452#[proc_macro_derive(VfsDerive)]
453pub fn derive_vfs_module(input: TokenStream) -> TokenStream {
454    ext::derive_vfs_module(input)
455}