1use bumpalo::Bump;
34use std::cell::Cell;
35use std::fmt::Write;
36
37use super::interning::InternedStr;
38
39pub struct QueryArena {
47 bump: Bump,
48 stats: Cell<ArenaStats>,
49}
50
51impl QueryArena {
52 pub fn new() -> Self {
54 Self {
55 bump: Bump::new(),
56 stats: Cell::new(ArenaStats::default()),
57 }
58 }
59
60 pub fn with_capacity(capacity: usize) -> Self {
62 Self {
63 bump: Bump::with_capacity(capacity),
64 stats: Cell::new(ArenaStats::default()),
65 }
66 }
67
68 pub fn scope<F, R>(&self, f: F) -> R
73 where
74 F: FnOnce(&ArenaScope<'_>) -> R,
75 {
76 let scope = ArenaScope::new(&self.bump, &self.stats);
77 f(&scope)
78 }
79
80 pub fn reset(&mut self) {
84 self.bump.reset();
85 self.stats.set(ArenaStats::default());
86 }
87
88 pub fn allocated_bytes(&self) -> usize {
90 self.bump.allocated_bytes()
91 }
92
93 pub fn stats(&self) -> ArenaStats {
95 self.stats.get()
96 }
97}
98
99impl Default for QueryArena {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105pub struct ArenaScope<'a> {
113 bump: &'a Bump,
114 stats: &'a Cell<ArenaStats>,
115}
116
117impl<'a> ArenaScope<'a> {
118 fn new(bump: &'a Bump, stats: &'a Cell<ArenaStats>) -> Self {
119 Self { bump, stats }
120 }
121
122 fn record_alloc(&self, bytes: usize) {
123 let mut s = self.stats.get();
124 s.allocations += 1;
125 s.total_bytes += bytes;
126 self.stats.set(s);
127 }
128
129 #[inline]
131 pub fn alloc_str(&self, s: &str) -> &'a str {
132 self.record_alloc(s.len());
133 self.bump.alloc_str(s)
134 }
135
136 #[inline]
138 pub fn alloc_slice<T: Copy>(&self, slice: &[T]) -> &'a [T] {
139 self.record_alloc(std::mem::size_of_val(slice));
140 self.bump.alloc_slice_copy(slice)
141 }
142
143 #[inline]
145 pub fn alloc_slice_iter<T, I>(&self, iter: I) -> &'a [T]
146 where
147 I: IntoIterator<Item = T>,
148 I::IntoIter: ExactSizeIterator,
149 {
150 let iter = iter.into_iter();
151 self.record_alloc(iter.len() * std::mem::size_of::<T>());
152 self.bump.alloc_slice_fill_iter(iter)
153 }
154
155 #[inline]
157 pub fn alloc<T>(&self, value: T) -> &'a T {
158 self.record_alloc(std::mem::size_of::<T>());
159 self.bump.alloc(value)
160 }
161
162 #[inline]
164 pub fn alloc_mut<T>(&self, value: T) -> &'a mut T {
165 self.record_alloc(std::mem::size_of::<T>());
166 self.bump.alloc(value)
167 }
168
169 #[inline]
175 pub fn eq<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
176 ScopedFilter::Equals(self.alloc_str(field), value.into())
177 }
178
179 #[inline]
181 pub fn ne<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
182 ScopedFilter::NotEquals(self.alloc_str(field), value.into())
183 }
184
185 #[inline]
187 pub fn lt<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
188 ScopedFilter::Lt(self.alloc_str(field), value.into())
189 }
190
191 #[inline]
193 pub fn lte<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
194 ScopedFilter::Lte(self.alloc_str(field), value.into())
195 }
196
197 #[inline]
199 pub fn gt<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
200 ScopedFilter::Gt(self.alloc_str(field), value.into())
201 }
202
203 #[inline]
205 pub fn gte<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
206 ScopedFilter::Gte(self.alloc_str(field), value.into())
207 }
208
209 #[inline]
211 pub fn is_in(&self, field: &str, values: Vec<ScopedValue<'a>>) -> ScopedFilter<'a> {
212 ScopedFilter::In(self.alloc_str(field), self.alloc_slice_iter(values))
213 }
214
215 #[inline]
217 pub fn not_in(&self, field: &str, values: Vec<ScopedValue<'a>>) -> ScopedFilter<'a> {
218 ScopedFilter::NotIn(self.alloc_str(field), self.alloc_slice_iter(values))
219 }
220
221 #[inline]
223 pub fn contains<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
224 ScopedFilter::Contains(self.alloc_str(field), value.into())
225 }
226
227 #[inline]
229 pub fn starts_with<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
230 ScopedFilter::StartsWith(self.alloc_str(field), value.into())
231 }
232
233 #[inline]
235 pub fn ends_with<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
236 ScopedFilter::EndsWith(self.alloc_str(field), value.into())
237 }
238
239 #[inline]
241 pub fn is_null(&self, field: &str) -> ScopedFilter<'a> {
242 ScopedFilter::IsNull(self.alloc_str(field))
243 }
244
245 #[inline]
247 pub fn is_not_null(&self, field: &str) -> ScopedFilter<'a> {
248 ScopedFilter::IsNotNull(self.alloc_str(field))
249 }
250
251 #[inline]
253 pub fn and(&self, filters: Vec<ScopedFilter<'a>>) -> ScopedFilter<'a> {
254 let filters: Vec<_> = filters
256 .into_iter()
257 .filter(|f| !matches!(f, ScopedFilter::None))
258 .collect();
259
260 match filters.len() {
261 0 => ScopedFilter::None,
262 1 => filters.into_iter().next().unwrap(),
263 _ => ScopedFilter::And(self.alloc_slice_iter(filters)),
264 }
265 }
266
267 #[inline]
269 pub fn or(&self, filters: Vec<ScopedFilter<'a>>) -> ScopedFilter<'a> {
270 let filters: Vec<_> = filters
271 .into_iter()
272 .filter(|f| !matches!(f, ScopedFilter::None))
273 .collect();
274
275 match filters.len() {
276 0 => ScopedFilter::None,
277 1 => filters.into_iter().next().unwrap(),
278 _ => ScopedFilter::Or(self.alloc_slice_iter(filters)),
279 }
280 }
281
282 #[inline]
284 pub fn not(&self, filter: ScopedFilter<'a>) -> ScopedFilter<'a> {
285 if matches!(filter, ScopedFilter::None) {
286 return ScopedFilter::None;
287 }
288 ScopedFilter::Not(self.alloc(filter))
289 }
290
291 pub fn build_select(&self, table: &str, filter: ScopedFilter<'a>) -> String {
297 let mut sql = String::with_capacity(128);
298 sql.push_str("SELECT * FROM ");
299 sql.push_str(table);
300
301 if !matches!(filter, ScopedFilter::None) {
302 sql.push_str(" WHERE ");
303 filter.write_sql(&mut sql, &mut 1);
304 }
305
306 sql
307 }
308
309 pub fn build_select_columns(
311 &self,
312 table: &str,
313 columns: &[&str],
314 filter: ScopedFilter<'a>,
315 ) -> String {
316 let mut sql = String::with_capacity(128);
317 sql.push_str("SELECT ");
318
319 for (i, col) in columns.iter().enumerate() {
320 if i > 0 {
321 sql.push_str(", ");
322 }
323 sql.push_str(col);
324 }
325
326 sql.push_str(" FROM ");
327 sql.push_str(table);
328
329 if !matches!(filter, ScopedFilter::None) {
330 sql.push_str(" WHERE ");
331 filter.write_sql(&mut sql, &mut 1);
332 }
333
334 sql
335 }
336
337 pub fn build_query(&self, query: &ScopedQuery<'a>) -> String {
339 let mut sql = String::with_capacity(256);
340
341 sql.push_str("SELECT ");
343 if query.columns.is_empty() {
344 sql.push('*');
345 } else {
346 for (i, col) in query.columns.iter().enumerate() {
347 if i > 0 {
348 sql.push_str(", ");
349 }
350 sql.push_str(col);
351 }
352 }
353
354 sql.push_str(" FROM ");
356 sql.push_str(query.table);
357
358 if !matches!(query.filter, ScopedFilter::None) {
360 sql.push_str(" WHERE ");
361 query.filter.write_sql(&mut sql, &mut 1);
362 }
363
364 if !query.order_by.is_empty() {
366 sql.push_str(" ORDER BY ");
367 for (i, (col, dir)) in query.order_by.iter().enumerate() {
368 if i > 0 {
369 sql.push_str(", ");
370 }
371 sql.push_str(col);
372 sql.push(' ');
373 sql.push_str(dir);
374 }
375 }
376
377 if let Some(limit) = query.limit {
379 write!(sql, " LIMIT {}", limit).unwrap();
380 }
381
382 if let Some(offset) = query.offset {
384 write!(sql, " OFFSET {}", offset).unwrap();
385 }
386
387 sql
388 }
389
390 pub fn query(&self, table: &str) -> ScopedQuery<'a> {
392 ScopedQuery {
393 table: self.alloc_str(table),
394 columns: &[],
395 filter: ScopedFilter::None,
396 order_by: &[],
397 limit: None,
398 offset: None,
399 }
400 }
401}
402
403#[derive(Debug, Clone)]
409pub enum ScopedFilter<'a> {
410 None,
412 Equals(&'a str, ScopedValue<'a>),
414 NotEquals(&'a str, ScopedValue<'a>),
416 Lt(&'a str, ScopedValue<'a>),
418 Lte(&'a str, ScopedValue<'a>),
420 Gt(&'a str, ScopedValue<'a>),
422 Gte(&'a str, ScopedValue<'a>),
424 In(&'a str, &'a [ScopedValue<'a>]),
426 NotIn(&'a str, &'a [ScopedValue<'a>]),
428 Contains(&'a str, ScopedValue<'a>),
430 StartsWith(&'a str, ScopedValue<'a>),
432 EndsWith(&'a str, ScopedValue<'a>),
434 IsNull(&'a str),
436 IsNotNull(&'a str),
438 And(&'a [ScopedFilter<'a>]),
440 Or(&'a [ScopedFilter<'a>]),
442 Not(&'a ScopedFilter<'a>),
444}
445
446impl<'a> ScopedFilter<'a> {
447 pub fn write_sql(&self, buf: &mut String, param_idx: &mut usize) {
449 match self {
450 ScopedFilter::None => {}
451 ScopedFilter::Equals(field, _) => {
452 write!(buf, "{} = ${}", field, param_idx).unwrap();
453 *param_idx += 1;
454 }
455 ScopedFilter::NotEquals(field, _) => {
456 write!(buf, "{} != ${}", field, param_idx).unwrap();
457 *param_idx += 1;
458 }
459 ScopedFilter::Lt(field, _) => {
460 write!(buf, "{} < ${}", field, param_idx).unwrap();
461 *param_idx += 1;
462 }
463 ScopedFilter::Lte(field, _) => {
464 write!(buf, "{} <= ${}", field, param_idx).unwrap();
465 *param_idx += 1;
466 }
467 ScopedFilter::Gt(field, _) => {
468 write!(buf, "{} > ${}", field, param_idx).unwrap();
469 *param_idx += 1;
470 }
471 ScopedFilter::Gte(field, _) => {
472 write!(buf, "{} >= ${}", field, param_idx).unwrap();
473 *param_idx += 1;
474 }
475 ScopedFilter::In(field, values) => {
476 write!(buf, "{} IN (", field).unwrap();
477 for (i, _) in values.iter().enumerate() {
478 if i > 0 {
479 buf.push_str(", ");
480 }
481 write!(buf, "${}", param_idx).unwrap();
482 *param_idx += 1;
483 }
484 buf.push(')');
485 }
486 ScopedFilter::NotIn(field, values) => {
487 write!(buf, "{} NOT IN (", field).unwrap();
488 for (i, _) in values.iter().enumerate() {
489 if i > 0 {
490 buf.push_str(", ");
491 }
492 write!(buf, "${}", param_idx).unwrap();
493 *param_idx += 1;
494 }
495 buf.push(')');
496 }
497 ScopedFilter::Contains(field, _) => {
498 write!(buf, "{} LIKE ${}", field, param_idx).unwrap();
499 *param_idx += 1;
500 }
501 ScopedFilter::StartsWith(field, _) => {
502 write!(buf, "{} LIKE ${}", field, param_idx).unwrap();
503 *param_idx += 1;
504 }
505 ScopedFilter::EndsWith(field, _) => {
506 write!(buf, "{} LIKE ${}", field, param_idx).unwrap();
507 *param_idx += 1;
508 }
509 ScopedFilter::IsNull(field) => {
510 write!(buf, "{} IS NULL", field).unwrap();
511 }
512 ScopedFilter::IsNotNull(field) => {
513 write!(buf, "{} IS NOT NULL", field).unwrap();
514 }
515 ScopedFilter::And(filters) => {
516 buf.push('(');
517 for (i, filter) in filters.iter().enumerate() {
518 if i > 0 {
519 buf.push_str(" AND ");
520 }
521 filter.write_sql(buf, param_idx);
522 }
523 buf.push(')');
524 }
525 ScopedFilter::Or(filters) => {
526 buf.push('(');
527 for (i, filter) in filters.iter().enumerate() {
528 if i > 0 {
529 buf.push_str(" OR ");
530 }
531 filter.write_sql(buf, param_idx);
532 }
533 buf.push(')');
534 }
535 ScopedFilter::Not(filter) => {
536 buf.push_str("NOT (");
537 filter.write_sql(buf, param_idx);
538 buf.push(')');
539 }
540 }
541 }
542}
543
544#[derive(Debug, Clone)]
550pub enum ScopedValue<'a> {
551 Null,
553 Bool(bool),
555 Int(i64),
557 Float(f64),
559 String(&'a str),
561 Interned(InternedStr),
563}
564
565impl<'a> From<bool> for ScopedValue<'a> {
566 fn from(v: bool) -> Self {
567 ScopedValue::Bool(v)
568 }
569}
570
571impl<'a> From<i32> for ScopedValue<'a> {
572 fn from(v: i32) -> Self {
573 ScopedValue::Int(v as i64)
574 }
575}
576
577impl<'a> From<i64> for ScopedValue<'a> {
578 fn from(v: i64) -> Self {
579 ScopedValue::Int(v)
580 }
581}
582
583impl<'a> From<f64> for ScopedValue<'a> {
584 fn from(v: f64) -> Self {
585 ScopedValue::Float(v)
586 }
587}
588
589impl<'a> From<&'a str> for ScopedValue<'a> {
590 fn from(v: &'a str) -> Self {
591 ScopedValue::String(v)
592 }
593}
594
595impl<'a> From<InternedStr> for ScopedValue<'a> {
596 fn from(v: InternedStr) -> Self {
597 ScopedValue::Interned(v)
598 }
599}
600
601#[derive(Debug, Clone)]
607pub struct ScopedQuery<'a> {
608 pub table: &'a str,
610 pub columns: &'a [&'a str],
612 pub filter: ScopedFilter<'a>,
614 pub order_by: &'a [(&'a str, &'a str)],
616 pub limit: Option<usize>,
618 pub offset: Option<usize>,
620}
621
622impl<'a> ScopedQuery<'a> {
623 pub fn select(mut self, columns: &'a [&'a str]) -> Self {
625 self.columns = columns;
626 self
627 }
628
629 pub fn filter(mut self, filter: ScopedFilter<'a>) -> Self {
631 self.filter = filter;
632 self
633 }
634
635 pub fn order_by(mut self, order_by: &'a [(&'a str, &'a str)]) -> Self {
637 self.order_by = order_by;
638 self
639 }
640
641 pub fn limit(mut self, limit: usize) -> Self {
643 self.limit = Some(limit);
644 self
645 }
646
647 pub fn offset(mut self, offset: usize) -> Self {
649 self.offset = Some(offset);
650 self
651 }
652}
653
/// Running allocation counters for a `QueryArena`.
///
/// Every `ArenaScope` allocation helper updates these counters, and `QueryArena::reset`
/// zeroes them. Equality is derived so that snapshots can be compared directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ArenaStats {
    /// Number of allocation calls made through arena scopes.
    pub allocations: usize,
    /// Sum of the byte sizes requested by those calls; alignment padding and allocator
    /// chunk overhead are not included.
    pub total_bytes: usize,
}
666
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arena_basic_filter() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| scope.build_select("users", scope.eq("id", 42)));

        assert_eq!(sql, "SELECT * FROM users WHERE id = $1");
    }

    #[test]
    fn test_arena_complex_filter() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let filter = scope.and(vec![
                scope.eq("active", true),
                scope.or(vec![scope.gt("age", 18), scope.is_not_null("verified_at")]),
            ]);
            scope.build_select("users", filter)
        });

        // Placeholders are numbered left to right across nested groups.
        assert_eq!(
            sql,
            "SELECT * FROM users WHERE (active = $1 AND (age > $2 OR verified_at IS NOT NULL))"
        );
    }

    #[test]
    fn test_arena_reset() {
        let mut arena = QueryArena::with_capacity(1024);

        let _ = arena.scope(|scope| scope.build_select("users", scope.eq("id", 1)));
        let bytes1 = arena.allocated_bytes();
        assert!(arena.stats().allocations > 0);

        arena.reset();
        assert_eq!(arena.stats().allocations, 0);
        assert_eq!(arena.stats().total_bytes, 0);

        let _ = arena.scope(|scope| scope.build_select("posts", scope.eq("id", 2)));
        let bytes2 = arena.allocated_bytes();

        // Reusing the reset arena should not make it grow unboundedly.
        assert!(bytes2 <= bytes1 * 2);
    }

    #[test]
    fn test_arena_query_builder() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let query = scope
                .query("users")
                .filter(scope.eq("active", true))
                .limit(10)
                .offset(20);
            scope.build_query(&query)
        });

        assert_eq!(sql, "SELECT * FROM users WHERE active = $1 LIMIT 10 OFFSET 20");
    }

    #[test]
    fn test_arena_query_columns_and_order() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let query = scope
                .query("users")
                .select(&["id", "email"])
                .order_by(&[("created_at", "DESC"), ("id", "ASC")])
                .limit(5);
            scope.build_query(&query)
        });

        assert_eq!(
            sql,
            "SELECT id, email FROM users ORDER BY created_at DESC, id ASC LIMIT 5"
        );
    }

    #[test]
    fn test_arena_select_columns() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            scope.build_select_columns("users", &["id", "name"], scope.is_null("deleted_at"))
        });

        assert_eq!(sql, "SELECT id, name FROM users WHERE deleted_at IS NULL");
    }

    #[test]
    fn test_arena_in_filter() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let filter = scope.is_in(
                "status",
                vec!["pending".into(), "processing".into(), "completed".into()],
            );
            scope.build_select("orders", filter)
        });

        assert_eq!(sql, "SELECT * FROM orders WHERE status IN ($1, $2, $3)");
    }

    #[test]
    fn test_arena_none_filters_collapse() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let filter = scope.and(vec![ScopedFilter::None, scope.not(ScopedFilter::None)]);
            scope.build_select("users", filter)
        });

        // No constraint left, so no WHERE clause.
        assert_eq!(sql, "SELECT * FROM users");
    }

    #[test]
    fn test_arena_single_filter_unwrapped_and_negated() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let filter = scope.not(scope.or(vec![scope.eq("id", 7)]));
            scope.build_select("users", filter)
        });

        assert_eq!(sql, "SELECT * FROM users WHERE NOT (id = $1)");
    }

    #[test]
    fn test_arena_stats() {
        let arena = QueryArena::new();

        arena.scope(|scope| {
            let _ = scope.alloc_str("test string");
            let _ = scope.alloc_str("another string");
        });

        let stats = arena.stats();
        assert_eq!(stats.allocations, 2);
        assert_eq!(stats.total_bytes, "test string".len() + "another string".len());
    }
}