Skip to main content

prax_query/mem_optimize/
arena.rs

//! Typed arena allocation for query builder chains.
//!
//! This module provides arena-based allocation for efficient query construction
//! with minimal heap allocations.
//!
//! # Benefits
//!
//! - **Batch deallocation**: All allocations are freed at once when the arena is
//!   reset or dropped (not when an individual scope closure returns)
//! - **Cache-friendly**: Contiguous memory allocation
//! - **Fast allocation**: O(1) bump pointer allocation
//! - **No fragmentation**: No individual deallocation overhead
//!
//! # Example
//!
//! ```rust
//! use prax_query::mem_optimize::arena::QueryArena;
//!
//! let arena = QueryArena::new();
//!
//! // Build query within arena scope
//! let sql = arena.scope(|scope| {
//!     let filter = scope.eq("status", "active");
//!     let filter = scope.and(vec![
//!         filter,
//!         scope.gt("age", 18),
//!     ]);
//!     scope.build_select("users", filter)
//! });
//!
//! // `sql` is an owned String; the arena keeps its memory until reset or drop.
//! assert_eq!(sql, "SELECT * FROM users WHERE (status = $1 AND age > $2)");
//! ```
32
33use bumpalo::Bump;
34use std::cell::Cell;
35use std::fmt::Write;
36
37use super::interning::InternedStr;
38
39// ============================================================================
40// Query Arena
41// ============================================================================
42
43/// Arena allocator for query building.
44///
45/// Provides fast allocation with batch deallocation when the scope ends.
46pub struct QueryArena {
47    bump: Bump,
48    stats: Cell<ArenaStats>,
49}
50
51impl QueryArena {
52    /// Create a new query arena with default capacity.
53    pub fn new() -> Self {
54        Self {
55            bump: Bump::new(),
56            stats: Cell::new(ArenaStats::default()),
57        }
58    }
59
60    /// Create an arena with specified initial capacity.
61    pub fn with_capacity(capacity: usize) -> Self {
62        Self {
63            bump: Bump::with_capacity(capacity),
64            stats: Cell::new(ArenaStats::default()),
65        }
66    }
67
68    /// Execute a closure with an arena scope.
69    ///
70    /// The scope provides allocation methods. All allocations are valid
71    /// within the closure and freed when it returns.
72    pub fn scope<F, R>(&self, f: F) -> R
73    where
74        F: FnOnce(&ArenaScope<'_>) -> R,
75    {
76        let scope = ArenaScope::new(&self.bump, &self.stats);
77        f(&scope)
78    }
79
80    /// Reset the arena for reuse.
81    ///
82    /// This is O(1) - just resets the bump pointer.
83    pub fn reset(&mut self) {
84        self.bump.reset();
85        self.stats.set(ArenaStats::default());
86    }
87
88    /// Get the number of bytes currently allocated.
89    pub fn allocated_bytes(&self) -> usize {
90        self.bump.allocated_bytes()
91    }
92
93    /// Get arena statistics.
94    pub fn stats(&self) -> ArenaStats {
95        self.stats.get()
96    }
97}
98
99impl Default for QueryArena {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105// ============================================================================
106// Arena Scope
107// ============================================================================
108
109/// A scope for allocating within an arena.
110///
111/// All allocations made through this scope are freed when the scope ends.
112pub struct ArenaScope<'a> {
113    bump: &'a Bump,
114    stats: &'a Cell<ArenaStats>,
115}
116
117impl<'a> ArenaScope<'a> {
118    fn new(bump: &'a Bump, stats: &'a Cell<ArenaStats>) -> Self {
119        Self { bump, stats }
120    }
121
122    fn record_alloc(&self, bytes: usize) {
123        let mut s = self.stats.get();
124        s.allocations += 1;
125        s.total_bytes += bytes;
126        self.stats.set(s);
127    }
128
129    /// Allocate a string in the arena.
130    #[inline]
131    pub fn alloc_str(&self, s: &str) -> &'a str {
132        self.record_alloc(s.len());
133        self.bump.alloc_str(s)
134    }
135
136    /// Allocate a slice in the arena.
137    #[inline]
138    pub fn alloc_slice<T: Copy>(&self, slice: &[T]) -> &'a [T] {
139        self.record_alloc(std::mem::size_of_val(slice));
140        self.bump.alloc_slice_copy(slice)
141    }
142
143    /// Allocate a slice from an iterator.
144    #[inline]
145    pub fn alloc_slice_iter<T, I>(&self, iter: I) -> &'a [T]
146    where
147        I: IntoIterator<Item = T>,
148        I::IntoIter: ExactSizeIterator,
149    {
150        let iter = iter.into_iter();
151        self.record_alloc(iter.len() * std::mem::size_of::<T>());
152        self.bump.alloc_slice_fill_iter(iter)
153    }
154
155    /// Allocate a single value in the arena.
156    #[inline]
157    pub fn alloc<T>(&self, value: T) -> &'a T {
158        self.record_alloc(std::mem::size_of::<T>());
159        self.bump.alloc(value)
160    }
161
162    /// Allocate a mutable value in the arena.
163    #[inline]
164    pub fn alloc_mut<T>(&self, value: T) -> &'a mut T {
165        self.record_alloc(std::mem::size_of::<T>());
166        self.bump.alloc(value)
167    }
168
169    // ========================================================================
170    // Filter Construction
171    // ========================================================================
172
173    /// Create an equality filter.
174    #[inline]
175    pub fn eq<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
176        ScopedFilter::Equals(self.alloc_str(field), value.into())
177    }
178
179    /// Create a not-equals filter.
180    #[inline]
181    pub fn ne<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
182        ScopedFilter::NotEquals(self.alloc_str(field), value.into())
183    }
184
185    /// Create a less-than filter.
186    #[inline]
187    pub fn lt<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
188        ScopedFilter::Lt(self.alloc_str(field), value.into())
189    }
190
191    /// Create a less-than-or-equal filter.
192    #[inline]
193    pub fn lte<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
194        ScopedFilter::Lte(self.alloc_str(field), value.into())
195    }
196
197    /// Create a greater-than filter.
198    #[inline]
199    pub fn gt<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
200        ScopedFilter::Gt(self.alloc_str(field), value.into())
201    }
202
203    /// Create a greater-than-or-equal filter.
204    #[inline]
205    pub fn gte<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
206        ScopedFilter::Gte(self.alloc_str(field), value.into())
207    }
208
209    /// Create an IN filter.
210    #[inline]
211    pub fn is_in(&self, field: &str, values: Vec<ScopedValue<'a>>) -> ScopedFilter<'a> {
212        ScopedFilter::In(self.alloc_str(field), self.alloc_slice_iter(values))
213    }
214
215    /// Create a NOT IN filter.
216    #[inline]
217    pub fn not_in(&self, field: &str, values: Vec<ScopedValue<'a>>) -> ScopedFilter<'a> {
218        ScopedFilter::NotIn(self.alloc_str(field), self.alloc_slice_iter(values))
219    }
220
221    /// Create a CONTAINS filter.
222    #[inline]
223    pub fn contains<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
224        ScopedFilter::Contains(self.alloc_str(field), value.into())
225    }
226
227    /// Create a STARTS WITH filter.
228    #[inline]
229    pub fn starts_with<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
230        ScopedFilter::StartsWith(self.alloc_str(field), value.into())
231    }
232
233    /// Create an ENDS WITH filter.
234    #[inline]
235    pub fn ends_with<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
236        ScopedFilter::EndsWith(self.alloc_str(field), value.into())
237    }
238
239    /// Create an IS NULL filter.
240    #[inline]
241    pub fn is_null(&self, field: &str) -> ScopedFilter<'a> {
242        ScopedFilter::IsNull(self.alloc_str(field))
243    }
244
245    /// Create an IS NOT NULL filter.
246    #[inline]
247    pub fn is_not_null(&self, field: &str) -> ScopedFilter<'a> {
248        ScopedFilter::IsNotNull(self.alloc_str(field))
249    }
250
251    /// Combine filters with AND.
252    #[inline]
253    pub fn and(&self, filters: Vec<ScopedFilter<'a>>) -> ScopedFilter<'a> {
254        // Filter out None filters
255        let filters: Vec<_> = filters
256            .into_iter()
257            .filter(|f| !matches!(f, ScopedFilter::None))
258            .collect();
259
260        match filters.len() {
261            0 => ScopedFilter::None,
262            1 => filters.into_iter().next().unwrap(),
263            _ => ScopedFilter::And(self.alloc_slice_iter(filters)),
264        }
265    }
266
267    /// Combine filters with OR.
268    #[inline]
269    pub fn or(&self, filters: Vec<ScopedFilter<'a>>) -> ScopedFilter<'a> {
270        let filters: Vec<_> = filters
271            .into_iter()
272            .filter(|f| !matches!(f, ScopedFilter::None))
273            .collect();
274
275        match filters.len() {
276            0 => ScopedFilter::None,
277            1 => filters.into_iter().next().unwrap(),
278            _ => ScopedFilter::Or(self.alloc_slice_iter(filters)),
279        }
280    }
281
282    /// Negate a filter.
283    #[inline]
284    pub fn not(&self, filter: ScopedFilter<'a>) -> ScopedFilter<'a> {
285        if matches!(filter, ScopedFilter::None) {
286            return ScopedFilter::None;
287        }
288        ScopedFilter::Not(self.alloc(filter))
289    }
290
291    // ========================================================================
292    // Query Building
293    // ========================================================================
294
295    /// Build a SELECT query string.
296    pub fn build_select(&self, table: &str, filter: ScopedFilter<'a>) -> String {
297        let mut sql = String::with_capacity(128);
298        sql.push_str("SELECT * FROM ");
299        sql.push_str(table);
300
301        if !matches!(filter, ScopedFilter::None) {
302            sql.push_str(" WHERE ");
303            filter.write_sql(&mut sql, &mut 1);
304        }
305
306        sql
307    }
308
309    /// Build a SELECT query with specific columns.
310    pub fn build_select_columns(
311        &self,
312        table: &str,
313        columns: &[&str],
314        filter: ScopedFilter<'a>,
315    ) -> String {
316        let mut sql = String::with_capacity(128);
317        sql.push_str("SELECT ");
318
319        for (i, col) in columns.iter().enumerate() {
320            if i > 0 {
321                sql.push_str(", ");
322            }
323            sql.push_str(col);
324        }
325
326        sql.push_str(" FROM ");
327        sql.push_str(table);
328
329        if !matches!(filter, ScopedFilter::None) {
330            sql.push_str(" WHERE ");
331            filter.write_sql(&mut sql, &mut 1);
332        }
333
334        sql
335    }
336
337    /// Build a complete query with all parts.
338    pub fn build_query(&self, query: &ScopedQuery<'a>) -> String {
339        let mut sql = String::with_capacity(256);
340
341        // SELECT
342        sql.push_str("SELECT ");
343        if query.columns.is_empty() {
344            sql.push('*');
345        } else {
346            for (i, col) in query.columns.iter().enumerate() {
347                if i > 0 {
348                    sql.push_str(", ");
349                }
350                sql.push_str(col);
351            }
352        }
353
354        // FROM
355        sql.push_str(" FROM ");
356        sql.push_str(query.table);
357
358        // WHERE
359        if !matches!(query.filter, ScopedFilter::None) {
360            sql.push_str(" WHERE ");
361            query.filter.write_sql(&mut sql, &mut 1);
362        }
363
364        // ORDER BY
365        if !query.order_by.is_empty() {
366            sql.push_str(" ORDER BY ");
367            for (i, (col, dir)) in query.order_by.iter().enumerate() {
368                if i > 0 {
369                    sql.push_str(", ");
370                }
371                sql.push_str(col);
372                sql.push(' ');
373                sql.push_str(dir);
374            }
375        }
376
377        // LIMIT
378        if let Some(limit) = query.limit {
379            write!(sql, " LIMIT {}", limit).unwrap();
380        }
381
382        // OFFSET
383        if let Some(offset) = query.offset {
384            write!(sql, " OFFSET {}", offset).unwrap();
385        }
386
387        sql
388    }
389
390    /// Create a new query builder.
391    pub fn query(&self, table: &str) -> ScopedQuery<'a> {
392        ScopedQuery {
393            table: self.alloc_str(table),
394            columns: &[],
395            filter: ScopedFilter::None,
396            order_by: &[],
397            limit: None,
398            offset: None,
399        }
400    }
401}
402
403// ============================================================================
404// Scoped Filter
405// ============================================================================
406
407/// A filter allocated within an arena scope.
408#[derive(Debug, Clone)]
409pub enum ScopedFilter<'a> {
410    /// No filter.
411    None,
412    /// Equality.
413    Equals(&'a str, ScopedValue<'a>),
414    /// Not equals.
415    NotEquals(&'a str, ScopedValue<'a>),
416    /// Less than.
417    Lt(&'a str, ScopedValue<'a>),
418    /// Less than or equal.
419    Lte(&'a str, ScopedValue<'a>),
420    /// Greater than.
421    Gt(&'a str, ScopedValue<'a>),
422    /// Greater than or equal.
423    Gte(&'a str, ScopedValue<'a>),
424    /// In list.
425    In(&'a str, &'a [ScopedValue<'a>]),
426    /// Not in list.
427    NotIn(&'a str, &'a [ScopedValue<'a>]),
428    /// Contains.
429    Contains(&'a str, ScopedValue<'a>),
430    /// Starts with.
431    StartsWith(&'a str, ScopedValue<'a>),
432    /// Ends with.
433    EndsWith(&'a str, ScopedValue<'a>),
434    /// Is null.
435    IsNull(&'a str),
436    /// Is not null.
437    IsNotNull(&'a str),
438    /// And.
439    And(&'a [ScopedFilter<'a>]),
440    /// Or.
441    Or(&'a [ScopedFilter<'a>]),
442    /// Not.
443    Not(&'a ScopedFilter<'a>),
444}
445
446impl<'a> ScopedFilter<'a> {
447    /// Write SQL to a string buffer.
448    pub fn write_sql(&self, buf: &mut String, param_idx: &mut usize) {
449        match self {
450            ScopedFilter::None => {}
451            ScopedFilter::Equals(field, _) => {
452                write!(buf, "{} = ${}", field, param_idx).unwrap();
453                *param_idx += 1;
454            }
455            ScopedFilter::NotEquals(field, _) => {
456                write!(buf, "{} != ${}", field, param_idx).unwrap();
457                *param_idx += 1;
458            }
459            ScopedFilter::Lt(field, _) => {
460                write!(buf, "{} < ${}", field, param_idx).unwrap();
461                *param_idx += 1;
462            }
463            ScopedFilter::Lte(field, _) => {
464                write!(buf, "{} <= ${}", field, param_idx).unwrap();
465                *param_idx += 1;
466            }
467            ScopedFilter::Gt(field, _) => {
468                write!(buf, "{} > ${}", field, param_idx).unwrap();
469                *param_idx += 1;
470            }
471            ScopedFilter::Gte(field, _) => {
472                write!(buf, "{} >= ${}", field, param_idx).unwrap();
473                *param_idx += 1;
474            }
475            ScopedFilter::In(field, values) => {
476                write!(buf, "{} IN (", field).unwrap();
477                for (i, _) in values.iter().enumerate() {
478                    if i > 0 {
479                        buf.push_str(", ");
480                    }
481                    write!(buf, "${}", param_idx).unwrap();
482                    *param_idx += 1;
483                }
484                buf.push(')');
485            }
486            ScopedFilter::NotIn(field, values) => {
487                write!(buf, "{} NOT IN (", field).unwrap();
488                for (i, _) in values.iter().enumerate() {
489                    if i > 0 {
490                        buf.push_str(", ");
491                    }
492                    write!(buf, "${}", param_idx).unwrap();
493                    *param_idx += 1;
494                }
495                buf.push(')');
496            }
497            ScopedFilter::Contains(field, _) => {
498                write!(buf, "{} LIKE ${}", field, param_idx).unwrap();
499                *param_idx += 1;
500            }
501            ScopedFilter::StartsWith(field, _) => {
502                write!(buf, "{} LIKE ${}", field, param_idx).unwrap();
503                *param_idx += 1;
504            }
505            ScopedFilter::EndsWith(field, _) => {
506                write!(buf, "{} LIKE ${}", field, param_idx).unwrap();
507                *param_idx += 1;
508            }
509            ScopedFilter::IsNull(field) => {
510                write!(buf, "{} IS NULL", field).unwrap();
511            }
512            ScopedFilter::IsNotNull(field) => {
513                write!(buf, "{} IS NOT NULL", field).unwrap();
514            }
515            ScopedFilter::And(filters) => {
516                buf.push('(');
517                for (i, filter) in filters.iter().enumerate() {
518                    if i > 0 {
519                        buf.push_str(" AND ");
520                    }
521                    filter.write_sql(buf, param_idx);
522                }
523                buf.push(')');
524            }
525            ScopedFilter::Or(filters) => {
526                buf.push('(');
527                for (i, filter) in filters.iter().enumerate() {
528                    if i > 0 {
529                        buf.push_str(" OR ");
530                    }
531                    filter.write_sql(buf, param_idx);
532                }
533                buf.push(')');
534            }
535            ScopedFilter::Not(filter) => {
536                buf.push_str("NOT (");
537                filter.write_sql(buf, param_idx);
538                buf.push(')');
539            }
540        }
541    }
542}
543
544// ============================================================================
545// Scoped Value
546// ============================================================================
547
548/// A value allocated within an arena scope.
549#[derive(Debug, Clone)]
550pub enum ScopedValue<'a> {
551    /// Null.
552    Null,
553    /// Boolean.
554    Bool(bool),
555    /// Integer.
556    Int(i64),
557    /// Float.
558    Float(f64),
559    /// String (borrowed from arena).
560    String(&'a str),
561    /// Interned string (shared reference).
562    Interned(InternedStr),
563}
564
565impl<'a> From<bool> for ScopedValue<'a> {
566    fn from(v: bool) -> Self {
567        ScopedValue::Bool(v)
568    }
569}
570
571impl<'a> From<i32> for ScopedValue<'a> {
572    fn from(v: i32) -> Self {
573        ScopedValue::Int(v as i64)
574    }
575}
576
577impl<'a> From<i64> for ScopedValue<'a> {
578    fn from(v: i64) -> Self {
579        ScopedValue::Int(v)
580    }
581}
582
583impl<'a> From<f64> for ScopedValue<'a> {
584    fn from(v: f64) -> Self {
585        ScopedValue::Float(v)
586    }
587}
588
589impl<'a> From<&'a str> for ScopedValue<'a> {
590    fn from(v: &'a str) -> Self {
591        ScopedValue::String(v)
592    }
593}
594
595impl<'a> From<InternedStr> for ScopedValue<'a> {
596    fn from(v: InternedStr) -> Self {
597        ScopedValue::Interned(v)
598    }
599}
600
601// ============================================================================
602// Scoped Query
603// ============================================================================
604
605/// A query being built within an arena scope.
606#[derive(Debug, Clone)]
607pub struct ScopedQuery<'a> {
608    /// Table name.
609    pub table: &'a str,
610    /// Columns to select.
611    pub columns: &'a [&'a str],
612    /// Filter.
613    pub filter: ScopedFilter<'a>,
614    /// Order by clauses.
615    pub order_by: &'a [(&'a str, &'a str)],
616    /// Limit.
617    pub limit: Option<usize>,
618    /// Offset.
619    pub offset: Option<usize>,
620}
621
622impl<'a> ScopedQuery<'a> {
623    /// Set columns to select.
624    pub fn select(mut self, columns: &'a [&'a str]) -> Self {
625        self.columns = columns;
626        self
627    }
628
629    /// Set filter.
630    pub fn filter(mut self, filter: ScopedFilter<'a>) -> Self {
631        self.filter = filter;
632        self
633    }
634
635    /// Set order by.
636    pub fn order_by(mut self, order_by: &'a [(&'a str, &'a str)]) -> Self {
637        self.order_by = order_by;
638        self
639    }
640
641    /// Set limit.
642    pub fn limit(mut self, limit: usize) -> Self {
643        self.limit = Some(limit);
644        self
645    }
646
647    /// Set offset.
648    pub fn offset(mut self, offset: usize) -> Self {
649        self.offset = Some(offset);
650        self
651    }
652}
653
// ============================================================================
// Statistics
// ============================================================================

/// Allocation counters for a [`QueryArena`].
///
/// Every `ArenaScope` allocation helper adds to these counters, and
/// `QueryArena::reset` clears them. `total_bytes` sums the payload sizes that
/// were requested. It excludes allocator padding and unused chunk capacity, so
/// it may be smaller than `QueryArena::allocated_bytes`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ArenaStats {
    /// Number of allocations since creation or the last reset.
    pub allocations: usize,
    /// Sum of requested payload bytes across those allocations.
    pub total_bytes: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arena_basic_filter() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| scope.build_select("users", scope.eq("id", 42)));

        assert_eq!(sql, "SELECT * FROM users WHERE id = $1");
    }

    #[test]
    fn test_arena_complex_filter() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let filter = scope.and(vec![
                scope.eq("active", true),
                scope.or(vec![scope.gt("age", 18), scope.is_not_null("verified_at")]),
            ]);
            scope.build_select("users", filter)
        });

        // Placeholders are numbered continuously across nested groups.
        assert_eq!(
            sql,
            "SELECT * FROM users WHERE (active = $1 AND (age > $2 OR verified_at IS NOT NULL))"
        );
    }

    #[test]
    fn test_and_or_not_simplification() {
        let arena = QueryArena::new();

        // `None` entries vanish and a lone survivor is not parenthesised.
        let sql = arena.scope(|scope| {
            let filter = scope.and(vec![ScopedFilter::None, scope.eq("a", 1)]);
            scope.build_select("t", filter)
        });
        assert_eq!(sql, "SELECT * FROM t WHERE a = $1");

        // Nothing left after dropping `None`: no WHERE clause at all.
        let sql = arena.scope(|scope| {
            let filter = scope.or(vec![ScopedFilter::None]);
            scope.build_select("t", filter)
        });
        assert_eq!(sql, "SELECT * FROM t");

        let sql = arena.scope(|scope| {
            scope.build_select("t", scope.not(scope.is_null("deleted_at")))
        });
        assert_eq!(sql, "SELECT * FROM t WHERE NOT (deleted_at IS NULL)");
    }

    #[test]
    fn test_arena_reset() {
        let mut arena = QueryArena::with_capacity(1024);

        // Use arena
        let _ = arena.scope(|scope| scope.build_select("users", scope.eq("id", 1)));
        let bytes1 = arena.allocated_bytes();

        // Reset clears the counters as well as the memory.
        arena.reset();
        let cleared = arena.stats();
        assert_eq!((cleared.allocations, cleared.total_bytes), (0, 0));

        // Use again
        let _ = arena.scope(|scope| scope.build_select("posts", scope.eq("id", 2)));
        let bytes2 = arena.allocated_bytes();

        // Should be similar (reusing memory)
        assert!(bytes2 <= bytes1 * 2);

        // Only the field name "id" was copied into the arena.
        let stats = arena.stats();
        assert_eq!((stats.allocations, stats.total_bytes), (1, 2));
    }

    #[test]
    fn test_arena_query_builder() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let query = scope
                .query("users")
                .filter(scope.eq("active", true))
                .limit(10)
                .offset(20);
            scope.build_query(&query)
        });

        assert_eq!(sql, "SELECT * FROM users WHERE active = $1 LIMIT 10 OFFSET 20");
    }

    #[test]
    fn test_arena_query_columns_and_order() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let query = scope
                .query("users")
                .select(&["id", "name"])
                .order_by(&[("name", "ASC"), ("id", "DESC")]);
            scope.build_query(&query)
        });
        assert_eq!(sql, "SELECT id, name FROM users ORDER BY name ASC, id DESC");

        let sql = arena.scope(|scope| {
            scope.build_select_columns("users", &["id", "name"], scope.eq("id", 7))
        });
        assert_eq!(sql, "SELECT id, name FROM users WHERE id = $1");
    }

    #[test]
    fn test_arena_in_filter() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let filter = scope.is_in(
                "status",
                vec!["pending".into(), "processing".into(), "completed".into()],
            );
            scope.build_select("orders", filter)
        });

        assert_eq!(sql, "SELECT * FROM orders WHERE status IN ($1, $2, $3)");
    }

    #[test]
    fn test_arena_stats() {
        let arena = QueryArena::new();

        arena.scope(|scope| {
            let _ = scope.alloc_str("test string");
            let _ = scope.alloc_str("another string");
        });

        let stats = arena.stats();
        assert_eq!(stats.allocations, 2);
        assert_eq!(stats.total_bytes, "test string".len() + "another string".len());
    }
}
761
762        let stats = arena.stats();
763        assert_eq!(stats.allocations, 2);
764    }
765}