limbo_macros/lib.rs
1// UPSTREAM: vendored Limbo fork — allow upstream style
2#![allow(
3 rustdoc::bare_urls,
4 rustdoc::invalid_html_tags,
5 rustdoc::invalid_rust_codeblocks
6)]
7#![allow(clippy::collapsible_match)]
8
9mod ext;
10extern crate proc_macro;
11use proc_macro::{token_stream::IntoIter, Group, TokenStream, TokenTree};
12use std::collections::HashMap;
13
14/// A procedural macro that derives a `Description` trait for enums.
15/// This macro extracts documentation comments (specified with `/// Description...`) for enum variants
16/// and generates an implementation for `get_description`, which returns the associated description.
17#[proc_macro_derive(Description, attributes(desc))]
18pub fn derive_description_from_doc(item: TokenStream) -> TokenStream {
19 // Convert the TokenStream into an iterator of TokenTree
20 let mut tokens = item.into_iter();
21
22 let mut enum_name = String::new();
23
24 // Vector to store enum variants and their associated payloads (if any)
25 let mut enum_variants: Vec<(String, Option<String>)> = Vec::<(String, Option<String>)>::new();
26
27 // HashMap to store descriptions associated with each enum variant
28 let mut variant_description_map: HashMap<String, String> = HashMap::new();
29
30 // Parses the token stream to extract the enum name and its variants
31 while let Some(token) = tokens.next() {
32 match token {
33 TokenTree::Ident(ident) if ident.to_string() == "enum" => {
34 // Get the enum name
35 if let Some(TokenTree::Ident(name)) = tokens.next() {
36 enum_name = name.to_string();
37 }
38 }
39 TokenTree::Group(group) => {
40 let mut group_tokens_iter: IntoIter = group.stream().into_iter();
41
42 let mut last_seen_desc: Option<String> = None;
43 while let Some(token) = group_tokens_iter.next() {
44 match token {
45 TokenTree::Punct(punct) => {
46 if punct.to_string() == "#" {
47 last_seen_desc = process_description(&mut group_tokens_iter);
48 }
49 }
50 TokenTree::Ident(ident) => {
51 // Capture the enum variant name and associate it with its description
52 let ident_str = ident.to_string();
53 if let Some(desc) = &last_seen_desc {
54 variant_description_map.insert(ident_str.clone(), desc.clone());
55 }
56 enum_variants.push((ident_str, None));
57 last_seen_desc = None;
58 }
59 TokenTree::Group(group) => {
60 // Capture payload information for the current enum variant
61 if let Some(last_variant) = enum_variants.last_mut() {
62 last_variant.1 = Some(process_payload(group));
63 }
64 }
65 _ => {}
66 }
67 }
68 }
69 _ => {}
70 }
71 }
72 generate_get_description(enum_name, &variant_description_map, enum_variants)
73}
74
/// Extracts a description literal from the attribute that follows a `#`.
///
/// `token_iter` must be positioned right after the `#`; the next token (the
/// bracketed attribute body) is always consumed. The literal is returned exactly
/// as written in the source, surrounding quotes included, when the attribute has
/// the shape `doc = "..."` (what a `/// ...` comment expands to) or
/// `desc = "..."`. Every other attribute yields `None`, so attributes such as
/// `#[deprecated = "..."]` are not mistaken for a description.
fn process_description(token_iter: &mut IntoIter) -> Option<String> {
    let attr_group = match token_iter.next() {
        Some(TokenTree::Group(group)) => group,
        _ => return None,
    };
    let mut attr_tokens = attr_group.stream().into_iter();
    // Tuple fields are evaluated left to right: name, `=`, literal.
    match (attr_tokens.next(), attr_tokens.next(), attr_tokens.next()) {
        (
            Some(TokenTree::Ident(name)),
            Some(TokenTree::Punct(eq)),
            Some(TokenTree::Literal(description)),
        ) if eq.as_char() == '=' && matches!(name.to_string().as_str(), "doc" | "desc") => {
            Some(description.to_string())
        }
        _ => None,
    }
}
88
/// Builds the match-pattern suffix for an enum variant that carries data.
///
/// The generated arms only return a description and never read a variant's
/// fields, so a rest pattern is sufficient: `(..)` for a tuple variant and
/// `{ .. }` for a struct-like one. Listing the field names instead mis-handles
/// tuple variants, `pub` fields, and commas inside generic arguments such as
/// `HashMap<K, V>`.
fn process_payload(payload_group: Group) -> String {
    match payload_group.delimiter() {
        proc_macro::Delimiter::Parenthesis => "(..)".to_string(),
        _ => "{ .. }".to_string(),
    }
}
/// Builds the `get_description` implementation for the scanned enum.
///
/// One match arm is emitted per entry of `enum_variants`, in order: the variant
/// name, its payload pattern (empty when there is none), and either
/// `Some(<literal>)` when `variant_description_map` has an entry for the variant
/// or `None` otherwise. The assembled source text is then parsed into a
/// `TokenStream`.
///
/// # Panics
///
/// Panics if the assembled source text cannot be parsed as a token stream.
fn generate_get_description(
    enum_name: String,
    variant_description_map: &HashMap<String, String>,
    enum_variants: Vec<(String, Option<String>)>,
) -> TokenStream {
    let match_arms: String = enum_variants
        .into_iter()
        .map(|(variant, payload)| {
            let pattern = payload.unwrap_or("".to_string());
            let description = match variant_description_map.get(&variant) {
                Some(literal) => format!("Some({})", literal),
                None => "None".to_string(),
            };
            format!(
                "{}::{} {} => {},\n",
                enum_name, variant, pattern, description
            )
        })
        .collect();

    let source = format!(
        "impl {} {{
 pub fn get_description(&self) -> Option<&str> {{
 match self {{
 {}
 }}
 }}
 }}",
        enum_name, match_arms
    );
    source
        .parse()
        .expect("generated enum impl should be valid Rust token stream")
}
147
/// Registers your extension with 'core' by providing the relevant functions.
///
/// The macro input is passed as-is to `ext::register_extension`.
///
/// ```ignore
/// use limbo_ext::{register_extension, scalar, Value, AggregateDerive, AggFunc};
///
/// register_extension!{ scalars: { return_one }, aggregates: { SumPlusOne } }
///
/// #[scalar(name = "one")]
/// fn return_one(args: &[Value]) -> Value {
///     return Value::from_integer(1);
/// }
///
/// #[derive(AggregateDerive)]
/// struct SumPlusOne;
///
/// impl AggFunc for SumPlusOne {
///     type State = i64;
///     const NAME: &'static str = "sum_plus_one";
///     const ARGS: i32 = 1;
///
///     fn step(state: &mut Self::State, args: &[Value]) {
///         let Some(val) = args[0].to_integer() else {
///             return;
///         };
///         *state += val;
///     }
///
///     fn finalize(state: Self::State) -> Value {
///         Value::from_integer(state + 1)
///     }
/// }
/// ```
#[proc_macro]
pub fn register_extension(input: TokenStream) -> TokenStream {
    ext::register_extension(input)
}
184
/// Declares a scalar function for your extension.
///
/// The attribute requires the name you wish to call your function with, e.g.
/// `#[scalar(name = "example")]`; an alias may be given as well. Both the
/// attribute arguments and the annotated item are passed as-is to `ext::scalar`.
///
/// ```text
/// use limbo_ext::{scalar, Value, ValueType};
///
/// #[scalar(name = "double", alias = "twice")] // the alias is optional
/// fn double(args: &[Value]) -> Value {
///     let arg = args.get(0).unwrap();
///     match arg.value_type() {
///         ValueType::Float => {
///             let val = arg.to_float().unwrap();
///             Value::from_float(val * 2.0)
///         }
///         ValueType::Integer => {
///             let val = arg.to_integer().unwrap();
///             Value::from_integer(val * 2)
///         }
///         _ => Value::null(),
///     }
/// }
/// ```
#[proc_macro_attribute]
pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream {
    ext::scalar(attr, input)
}
211
/// Defines an aggregate function for your extension.
///
/// Derive `AggregateDerive` on a struct that implements the `AggFunc` trait.
/// The derive input is passed as-is to `ext::derive_agg_func`.
///
/// ```ignore
/// use limbo_ext::{register_extension, Value, AggregateDerive, AggFunc};
///
/// #[derive(AggregateDerive)]
/// struct SumPlusOne;
///
/// impl AggFunc for SumPlusOne {
///     type State = i64;
///     type Error = &'static str;
///     const NAME: &'static str = "sum_plus_one";
///     const ARGS: i32 = 1;
///
///     fn step(state: &mut Self::State, args: &[Value]) {
///         let Some(val) = args[0].to_integer() else {
///             return;
///         };
///         *state += val;
///     }
///
///     fn finalize(state: Self::State) -> Result<Value, Self::Error> {
///         Ok(Value::from_integer(state + 1))
///     }
/// }
/// ```
#[proc_macro_derive(AggregateDerive)]
pub fn derive_agg_func(input: TokenStream) -> TokenStream {
    ext::derive_agg_func(input)
}
240
/// Macro to derive a VTabModule for your extension.
///
/// This macro will generate the necessary functions to register your module
/// with core. You must implement the VTabModule, VTable, and VTabCursor traits.
/// The derive input is passed as-is to `ext::derive_vtab_module`.
///
/// ```ignore
/// #[derive(Debug, VTabModuleDerive)]
/// struct CsvVTabModule;
///
/// impl VTabModule for CsvVTabModule {
///     type Table = CsvTable;
///     const NAME: &'static str = "csv_data";
///     const VTAB_KIND: VTabKind = VTabKind::VirtualTable;
///
///     /// Declare your virtual table and its schema
///     fn create(args: &[Value]) -> Result<(String, Self::Table), ResultCode> {
///         let schema = "CREATE TABLE csv_data(
///             name TEXT,
///             age TEXT,
///             city TEXT
///         )".into();
///         Ok((schema, CsvTable {}))
///     }
/// }
///
/// struct CsvTable {}
///
/// // Implement the VTable trait for your virtual table
/// impl VTable for CsvTable {
///     type Cursor = CsvCursor;
///     type Error = &'static str;
///
///     /// Open the virtual table and return a cursor
///     fn open(&self) -> Result<Self::Cursor, Self::Error> {
///         let csv_content = fs::read_to_string("data.csv").unwrap_or_default();
///         let rows: Vec<Vec<String>> = csv_content
///             .lines()
///             .skip(1)
///             .map(|line| {
///                 line.split(',')
///                     .map(|s| s.trim().to_string())
///                     .collect()
///             })
///             .collect();
///         Ok(CsvCursor { rows, index: 0 })
///     }
///
///     /// **Optional** methods for non-readonly tables:
///
///     /// Update the row with the provided values, return the new rowid
///     fn update(&mut self, rowid: i64, args: &[Value]) -> Result<Option<i64>, Self::Error> {
///         Ok(None) // return Ok(None) for read-only
///     }
///
///     /// Insert a new row with the provided values
///     fn insert(&mut self, args: &[Value]) -> Result<(), Self::Error> {
///         Ok(())
///     }
///
///     /// Delete the row with the provided rowid
///     fn delete(&mut self, rowid: i64) -> Result<(), Self::Error> {
///         Ok(())
///     }
///
///     /// Destroy the virtual table. Any cleanup logic for when the table is deleted goes here
///     fn destroy(&mut self) -> Result<(), Self::Error> {
///         Ok(())
///     }
/// }
///
/// #[derive(Debug)]
/// struct CsvCursor {
///     rows: Vec<Vec<String>>,
///     index: usize,
/// }
///
/// impl CsvCursor {
///     /// Returns the value for a given column index.
///     fn column(&self, idx: u32) -> Result<Value, &'static str> {
///         let row = &self.rows[self.index];
///         if (idx as usize) < row.len() {
///             Ok(Value::from_text(&row[idx as usize]))
///         } else {
///             Ok(Value::null())
///         }
///     }
/// }
///
/// // Implement the VTabCursor trait for your virtual cursor
/// impl VTabCursor for CsvCursor {
///     type Error = &'static str;
///
///     /// Filter the virtual table based on arguments (omitted here for simplicity)
///     fn filter(&mut self, _args: &[Value], _idx_info: Option<(&str, i32)>) -> ResultCode {
///         ResultCode::OK
///     }
///
///     /// Move the cursor to the next row
///     fn next(&mut self) -> ResultCode {
///         if self.index < self.rows.len() - 1 {
///             self.index += 1;
///             ResultCode::OK
///         } else {
///             ResultCode::EOF
///         }
///     }
///
///     fn eof(&self) -> bool {
///         self.index >= self.rows.len()
///     }
///
///     /// Return the value for a given column index
///     fn column(&self, idx: u32) -> Result<Value, Self::Error> {
///         self.column(idx)
///     }
///
///     fn rowid(&self) -> i64 {
///         self.index as i64
///     }
/// }
/// ```
#[proc_macro_derive(VTabModuleDerive)]
pub fn derive_vtab_module(input: TokenStream) -> TokenStream {
    ext::derive_vtab_module(input)
}
364
/// Macro to derive a VFS module for your extension.
///
/// Derive `VfsDerive` on a struct that implements `VfsExtension`; the struct
/// must also implement `Default`. The derive input is passed as-is to
/// `ext::derive_vfs_module`.
///
/// ```text
/// use limbo_ext::{ExtResult as Result, VfsDerive, VfsExtension, VfsFile};
///
/// // Your struct must also impl Default
/// #[derive(VfsDerive, Default)]
/// struct ExampleFS;
///
/// struct ExampleFile {
///     file: std::fs::File,
/// }
///
/// impl VfsExtension for ExampleFS {
///     /// The name of your vfs module
///     const NAME: &'static str = "example";
///
///     type File = ExampleFile;
///
///     fn open(&self, path: &str, flags: i32, _direct: bool) -> Result<Self::File> {
///         let file = OpenOptions::new()
///             .read(true)
///             .write(true)
///             .create(flags & 1 != 0)
///             .open(path)
///             .map_err(|_| ResultCode::Error)?;
///         Ok(ExampleFile { file })
///     }
///
///     fn run_once(&self) -> Result<()> {
///         // (optional) method to cycle/advance IO, if your extension is asynchronous
///         Ok(())
///     }
///
///     fn close(&self, file: Self::File) -> Result<()> {
///         // (optional) method to close or drop the file
///         Ok(())
///     }
///
///     fn generate_random_number(&self) -> i64 {
///         // (optional) method to generate random number. Used for testing
///         let mut buf = [0u8; 8];
///         getrandom::fill(&mut buf).unwrap();
///         i64::from_ne_bytes(buf)
///     }
///
///     fn get_current_time(&self) -> String {
///         // (optional) method to return the current time as a formatted string
///         chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string()
///     }
/// }
///
/// impl VfsFile for ExampleFile {
///     fn read(
///         &mut self,
///         buf: &mut [u8],
///         count: usize,
///         offset: i64,
///     ) -> Result<i32> {
///         if self.file.seek(SeekFrom::Start(offset as u64)).is_err() {
///             return Err(ResultCode::Error);
///         }
///         self.file
///             .read(&mut buf[..count])
///             .map_err(|_| ResultCode::Error)
///             .map(|n| n as i32)
///     }
///
///     fn write(&mut self, buf: &[u8], count: usize, offset: i64) -> Result<i32> {
///         if self.file.seek(SeekFrom::Start(offset as u64)).is_err() {
///             return Err(ResultCode::Error);
///         }
///         self.file
///             .write(&buf[..count])
///             .map_err(|_| ResultCode::Error)
///             .map(|n| n as i32)
///     }
///
///     fn sync(&self) -> Result<()> {
///         self.file.sync_all().map_err(|_| ResultCode::Error)
///     }
///
///     fn size(&self) -> i64 {
///         self.file.metadata().map(|m| m.len() as i64).unwrap_or(-1)
///     }
/// }
/// ```
#[proc_macro_derive(VfsDerive)]
pub fn derive_vfs_module(input: TokenStream) -> TokenStream {
    ext::derive_vfs_module(input)
}
455}