1use bumpalo::Bump;
34use std::cell::Cell;
35use std::fmt::Write;
36
37use super::interning::InternedStr;
38
39pub struct QueryArena {
47 bump: Bump,
48 stats: Cell<ArenaStats>,
49}
50
51impl QueryArena {
52 pub fn new() -> Self {
54 Self {
55 bump: Bump::new(),
56 stats: Cell::new(ArenaStats::default()),
57 }
58 }
59
60 pub fn with_capacity(capacity: usize) -> Self {
62 Self {
63 bump: Bump::with_capacity(capacity),
64 stats: Cell::new(ArenaStats::default()),
65 }
66 }
67
68 pub fn scope<F, R>(&self, f: F) -> R
73 where
74 F: FnOnce(&ArenaScope<'_>) -> R,
75 {
76 let scope = ArenaScope::new(&self.bump, &self.stats);
77 f(&scope)
78 }
79
80 pub fn reset(&mut self) {
84 self.bump.reset();
85 self.stats.set(ArenaStats::default());
86 }
87
88 pub fn allocated_bytes(&self) -> usize {
90 self.bump.allocated_bytes()
91 }
92
93 pub fn stats(&self) -> ArenaStats {
95 self.stats.get()
96 }
97}
98
99impl Default for QueryArena {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105pub struct ArenaScope<'a> {
113 bump: &'a Bump,
114 stats: &'a Cell<ArenaStats>,
115}
116
117impl<'a> ArenaScope<'a> {
118 fn new(bump: &'a Bump, stats: &'a Cell<ArenaStats>) -> Self {
119 Self { bump, stats }
120 }
121
122 fn record_alloc(&self, bytes: usize) {
123 let mut s = self.stats.get();
124 s.allocations += 1;
125 s.total_bytes += bytes;
126 self.stats.set(s);
127 }
128
129 #[inline]
131 pub fn alloc_str(&self, s: &str) -> &'a str {
132 self.record_alloc(s.len());
133 self.bump.alloc_str(s)
134 }
135
136 #[inline]
138 pub fn alloc_slice<T: Copy>(&self, slice: &[T]) -> &'a [T] {
139 self.record_alloc(std::mem::size_of_val(slice));
140 self.bump.alloc_slice_copy(slice)
141 }
142
143 #[inline]
145 pub fn alloc_slice_iter<T, I>(&self, iter: I) -> &'a [T]
146 where
147 I: IntoIterator<Item = T>,
148 I::IntoIter: ExactSizeIterator,
149 {
150 let iter = iter.into_iter();
151 self.record_alloc(iter.len() * std::mem::size_of::<T>());
152 self.bump.alloc_slice_fill_iter(iter)
153 }
154
155 #[inline]
157 pub fn alloc<T>(&self, value: T) -> &'a T {
158 self.record_alloc(std::mem::size_of::<T>());
159 self.bump.alloc(value)
160 }
161
162 #[inline]
164 pub fn alloc_mut<T>(&self, value: T) -> &'a mut T {
165 self.record_alloc(std::mem::size_of::<T>());
166 self.bump.alloc(value)
167 }
168
169 #[inline]
175 pub fn eq<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
176 ScopedFilter::Equals(self.alloc_str(field), value.into())
177 }
178
179 #[inline]
181 pub fn ne<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
182 ScopedFilter::NotEquals(self.alloc_str(field), value.into())
183 }
184
185 #[inline]
187 pub fn lt<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
188 ScopedFilter::Lt(self.alloc_str(field), value.into())
189 }
190
191 #[inline]
193 pub fn lte<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
194 ScopedFilter::Lte(self.alloc_str(field), value.into())
195 }
196
197 #[inline]
199 pub fn gt<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
200 ScopedFilter::Gt(self.alloc_str(field), value.into())
201 }
202
203 #[inline]
205 pub fn gte<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
206 ScopedFilter::Gte(self.alloc_str(field), value.into())
207 }
208
209 #[inline]
211 pub fn is_in(&self, field: &str, values: Vec<ScopedValue<'a>>) -> ScopedFilter<'a> {
212 ScopedFilter::In(self.alloc_str(field), self.alloc_slice_iter(values))
213 }
214
215 #[inline]
217 pub fn not_in(&self, field: &str, values: Vec<ScopedValue<'a>>) -> ScopedFilter<'a> {
218 ScopedFilter::NotIn(self.alloc_str(field), self.alloc_slice_iter(values))
219 }
220
221 #[inline]
223 pub fn contains<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
224 ScopedFilter::Contains(self.alloc_str(field), value.into())
225 }
226
227 #[inline]
229 pub fn starts_with<V: Into<ScopedValue<'a>>>(
230 &self,
231 field: &str,
232 value: V,
233 ) -> ScopedFilter<'a> {
234 ScopedFilter::StartsWith(self.alloc_str(field), value.into())
235 }
236
237 #[inline]
239 pub fn ends_with<V: Into<ScopedValue<'a>>>(&self, field: &str, value: V) -> ScopedFilter<'a> {
240 ScopedFilter::EndsWith(self.alloc_str(field), value.into())
241 }
242
243 #[inline]
245 pub fn is_null(&self, field: &str) -> ScopedFilter<'a> {
246 ScopedFilter::IsNull(self.alloc_str(field))
247 }
248
249 #[inline]
251 pub fn is_not_null(&self, field: &str) -> ScopedFilter<'a> {
252 ScopedFilter::IsNotNull(self.alloc_str(field))
253 }
254
255 #[inline]
257 pub fn and(&self, filters: Vec<ScopedFilter<'a>>) -> ScopedFilter<'a> {
258 let filters: Vec<_> = filters
260 .into_iter()
261 .filter(|f| !matches!(f, ScopedFilter::None))
262 .collect();
263
264 match filters.len() {
265 0 => ScopedFilter::None,
266 1 => filters.into_iter().next().unwrap(),
267 _ => ScopedFilter::And(self.alloc_slice_iter(filters)),
268 }
269 }
270
271 #[inline]
273 pub fn or(&self, filters: Vec<ScopedFilter<'a>>) -> ScopedFilter<'a> {
274 let filters: Vec<_> = filters
275 .into_iter()
276 .filter(|f| !matches!(f, ScopedFilter::None))
277 .collect();
278
279 match filters.len() {
280 0 => ScopedFilter::None,
281 1 => filters.into_iter().next().unwrap(),
282 _ => ScopedFilter::Or(self.alloc_slice_iter(filters)),
283 }
284 }
285
286 #[inline]
288 pub fn not(&self, filter: ScopedFilter<'a>) -> ScopedFilter<'a> {
289 if matches!(filter, ScopedFilter::None) {
290 return ScopedFilter::None;
291 }
292 ScopedFilter::Not(self.alloc(filter))
293 }
294
295 pub fn build_select(&self, table: &str, filter: ScopedFilter<'a>) -> String {
301 let mut sql = String::with_capacity(128);
302 sql.push_str("SELECT * FROM ");
303 sql.push_str(table);
304
305 if !matches!(filter, ScopedFilter::None) {
306 sql.push_str(" WHERE ");
307 filter.write_sql(&mut sql, &mut 1);
308 }
309
310 sql
311 }
312
313 pub fn build_select_columns(
315 &self,
316 table: &str,
317 columns: &[&str],
318 filter: ScopedFilter<'a>,
319 ) -> String {
320 let mut sql = String::with_capacity(128);
321 sql.push_str("SELECT ");
322
323 for (i, col) in columns.iter().enumerate() {
324 if i > 0 {
325 sql.push_str(", ");
326 }
327 sql.push_str(col);
328 }
329
330 sql.push_str(" FROM ");
331 sql.push_str(table);
332
333 if !matches!(filter, ScopedFilter::None) {
334 sql.push_str(" WHERE ");
335 filter.write_sql(&mut sql, &mut 1);
336 }
337
338 sql
339 }
340
341 pub fn build_query(&self, query: &ScopedQuery<'a>) -> String {
343 let mut sql = String::with_capacity(256);
344
345 sql.push_str("SELECT ");
347 if query.columns.is_empty() {
348 sql.push('*');
349 } else {
350 for (i, col) in query.columns.iter().enumerate() {
351 if i > 0 {
352 sql.push_str(", ");
353 }
354 sql.push_str(col);
355 }
356 }
357
358 sql.push_str(" FROM ");
360 sql.push_str(query.table);
361
362 if !matches!(query.filter, ScopedFilter::None) {
364 sql.push_str(" WHERE ");
365 query.filter.write_sql(&mut sql, &mut 1);
366 }
367
368 if !query.order_by.is_empty() {
370 sql.push_str(" ORDER BY ");
371 for (i, (col, dir)) in query.order_by.iter().enumerate() {
372 if i > 0 {
373 sql.push_str(", ");
374 }
375 sql.push_str(col);
376 sql.push(' ');
377 sql.push_str(dir);
378 }
379 }
380
381 if let Some(limit) = query.limit {
383 write!(sql, " LIMIT {}", limit).unwrap();
384 }
385
386 if let Some(offset) = query.offset {
388 write!(sql, " OFFSET {}", offset).unwrap();
389 }
390
391 sql
392 }
393
394 pub fn query(&self, table: &str) -> ScopedQuery<'a> {
396 ScopedQuery {
397 table: self.alloc_str(table),
398 columns: &[],
399 filter: ScopedFilter::None,
400 order_by: &[],
401 limit: None,
402 offset: None,
403 }
404 }
405}
406
407#[derive(Debug, Clone)]
413pub enum ScopedFilter<'a> {
414 None,
416 Equals(&'a str, ScopedValue<'a>),
418 NotEquals(&'a str, ScopedValue<'a>),
420 Lt(&'a str, ScopedValue<'a>),
422 Lte(&'a str, ScopedValue<'a>),
424 Gt(&'a str, ScopedValue<'a>),
426 Gte(&'a str, ScopedValue<'a>),
428 In(&'a str, &'a [ScopedValue<'a>]),
430 NotIn(&'a str, &'a [ScopedValue<'a>]),
432 Contains(&'a str, ScopedValue<'a>),
434 StartsWith(&'a str, ScopedValue<'a>),
436 EndsWith(&'a str, ScopedValue<'a>),
438 IsNull(&'a str),
440 IsNotNull(&'a str),
442 And(&'a [ScopedFilter<'a>]),
444 Or(&'a [ScopedFilter<'a>]),
446 Not(&'a ScopedFilter<'a>),
448}
449
450impl<'a> ScopedFilter<'a> {
451 pub fn write_sql(&self, buf: &mut String, param_idx: &mut usize) {
453 match self {
454 ScopedFilter::None => {}
455 ScopedFilter::Equals(field, _) => {
456 write!(buf, "{} = ${}", field, param_idx).unwrap();
457 *param_idx += 1;
458 }
459 ScopedFilter::NotEquals(field, _) => {
460 write!(buf, "{} != ${}", field, param_idx).unwrap();
461 *param_idx += 1;
462 }
463 ScopedFilter::Lt(field, _) => {
464 write!(buf, "{} < ${}", field, param_idx).unwrap();
465 *param_idx += 1;
466 }
467 ScopedFilter::Lte(field, _) => {
468 write!(buf, "{} <= ${}", field, param_idx).unwrap();
469 *param_idx += 1;
470 }
471 ScopedFilter::Gt(field, _) => {
472 write!(buf, "{} > ${}", field, param_idx).unwrap();
473 *param_idx += 1;
474 }
475 ScopedFilter::Gte(field, _) => {
476 write!(buf, "{} >= ${}", field, param_idx).unwrap();
477 *param_idx += 1;
478 }
479 ScopedFilter::In(field, values) => {
480 write!(buf, "{} IN (", field).unwrap();
481 for (i, _) in values.iter().enumerate() {
482 if i > 0 {
483 buf.push_str(", ");
484 }
485 write!(buf, "${}", param_idx).unwrap();
486 *param_idx += 1;
487 }
488 buf.push(')');
489 }
490 ScopedFilter::NotIn(field, values) => {
491 write!(buf, "{} NOT IN (", field).unwrap();
492 for (i, _) in values.iter().enumerate() {
493 if i > 0 {
494 buf.push_str(", ");
495 }
496 write!(buf, "${}", param_idx).unwrap();
497 *param_idx += 1;
498 }
499 buf.push(')');
500 }
501 ScopedFilter::Contains(field, _) => {
502 write!(buf, "{} LIKE ${}", field, param_idx).unwrap();
503 *param_idx += 1;
504 }
505 ScopedFilter::StartsWith(field, _) => {
506 write!(buf, "{} LIKE ${}", field, param_idx).unwrap();
507 *param_idx += 1;
508 }
509 ScopedFilter::EndsWith(field, _) => {
510 write!(buf, "{} LIKE ${}", field, param_idx).unwrap();
511 *param_idx += 1;
512 }
513 ScopedFilter::IsNull(field) => {
514 write!(buf, "{} IS NULL", field).unwrap();
515 }
516 ScopedFilter::IsNotNull(field) => {
517 write!(buf, "{} IS NOT NULL", field).unwrap();
518 }
519 ScopedFilter::And(filters) => {
520 buf.push('(');
521 for (i, filter) in filters.iter().enumerate() {
522 if i > 0 {
523 buf.push_str(" AND ");
524 }
525 filter.write_sql(buf, param_idx);
526 }
527 buf.push(')');
528 }
529 ScopedFilter::Or(filters) => {
530 buf.push('(');
531 for (i, filter) in filters.iter().enumerate() {
532 if i > 0 {
533 buf.push_str(" OR ");
534 }
535 filter.write_sql(buf, param_idx);
536 }
537 buf.push(')');
538 }
539 ScopedFilter::Not(filter) => {
540 buf.push_str("NOT (");
541 filter.write_sql(buf, param_idx);
542 buf.push(')');
543 }
544 }
545 }
546}
547
548#[derive(Debug, Clone)]
554pub enum ScopedValue<'a> {
555 Null,
557 Bool(bool),
559 Int(i64),
561 Float(f64),
563 String(&'a str),
565 Interned(InternedStr),
567}
568
569impl<'a> From<bool> for ScopedValue<'a> {
570 fn from(v: bool) -> Self {
571 ScopedValue::Bool(v)
572 }
573}
574
575impl<'a> From<i32> for ScopedValue<'a> {
576 fn from(v: i32) -> Self {
577 ScopedValue::Int(v as i64)
578 }
579}
580
581impl<'a> From<i64> for ScopedValue<'a> {
582 fn from(v: i64) -> Self {
583 ScopedValue::Int(v)
584 }
585}
586
587impl<'a> From<f64> for ScopedValue<'a> {
588 fn from(v: f64) -> Self {
589 ScopedValue::Float(v)
590 }
591}
592
593impl<'a> From<&'a str> for ScopedValue<'a> {
594 fn from(v: &'a str) -> Self {
595 ScopedValue::String(v)
596 }
597}
598
599impl<'a> From<InternedStr> for ScopedValue<'a> {
600 fn from(v: InternedStr) -> Self {
601 ScopedValue::Interned(v)
602 }
603}
604
605#[derive(Debug, Clone)]
611pub struct ScopedQuery<'a> {
612 pub table: &'a str,
614 pub columns: &'a [&'a str],
616 pub filter: ScopedFilter<'a>,
618 pub order_by: &'a [(&'a str, &'a str)],
620 pub limit: Option<usize>,
622 pub offset: Option<usize>,
624}
625
626impl<'a> ScopedQuery<'a> {
627 pub fn select(mut self, columns: &'a [&'a str]) -> Self {
629 self.columns = columns;
630 self
631 }
632
633 pub fn filter(mut self, filter: ScopedFilter<'a>) -> Self {
635 self.filter = filter;
636 self
637 }
638
639 pub fn order_by(mut self, order_by: &'a [(&'a str, &'a str)]) -> Self {
641 self.order_by = order_by;
642 self
643 }
644
645 pub fn limit(mut self, limit: usize) -> Self {
647 self.limit = Some(limit);
648 self
649 }
650
651 pub fn offset(mut self, offset: usize) -> Self {
653 self.offset = Some(offset);
654 self
655 }
656}
657
/// Running counters for allocations made through an [`ArenaScope`].
///
/// Both counters are zeroed by [`QueryArena::reset`]. Derives `PartialEq`/`Eq`/`Hash`
/// so snapshots can be compared directly, e.g. in tests.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ArenaStats {
    /// Number of allocation calls recorded.
    pub allocations: usize,
    /// Sum of the requested payload sizes in bytes (allocator overhead and
    /// alignment padding are not included).
    pub total_bytes: usize,
}
670
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arena_basic_filter() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| scope.build_select("users", scope.eq("id", 42)));

        assert_eq!(sql, "SELECT * FROM users WHERE id = $1");
    }

    #[test]
    fn test_arena_complex_filter() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let filter = scope.and(vec![
                scope.eq("active", true),
                scope.or(vec![scope.gt("age", 18), scope.is_not_null("verified_at")]),
            ]);
            scope.build_select("users", filter)
        });

        // Placeholders are numbered in traversal order across nested groups.
        assert_eq!(
            sql,
            "SELECT * FROM users WHERE (active = $1 AND (age > $2 OR verified_at IS NOT NULL))"
        );
    }

    #[test]
    fn test_arena_reset() {
        let mut arena = QueryArena::with_capacity(1024);

        let _ = arena.scope(|scope| scope.build_select("users", scope.eq("id", 1)));
        let bytes1 = arena.allocated_bytes();
        // Only the field name "id" is copied into the arena.
        assert_eq!(arena.stats().allocations, 1);
        assert_eq!(arena.stats().total_bytes, 2);

        arena.reset();
        // Reset must clear the counters as well as the memory.
        assert_eq!(arena.stats().allocations, 0);
        assert_eq!(arena.stats().total_bytes, 0);

        let _ = arena.scope(|scope| scope.build_select("posts", scope.eq("id", 2)));
        let bytes2 = arena.allocated_bytes();
        assert_eq!(arena.stats().allocations, 1);
        assert_eq!(arena.stats().total_bytes, 2);

        // Reusing the arena after a reset must not make its footprint balloon.
        assert!(bytes2 <= bytes1 * 2);
    }

    #[test]
    fn test_arena_query_builder() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let query = scope
                .query("users")
                .filter(scope.eq("active", true))
                .limit(10)
                .offset(20);
            scope.build_query(&query)
        });

        assert_eq!(sql, "SELECT * FROM users WHERE active = $1 LIMIT 10 OFFSET 20");
    }

    #[test]
    fn test_arena_in_filter() {
        let arena = QueryArena::new();

        let sql = arena.scope(|scope| {
            let filter = scope.is_in(
                "status",
                vec!["pending".into(), "processing".into(), "completed".into()],
            );
            scope.build_select("orders", filter)
        });

        assert_eq!(sql, "SELECT * FROM orders WHERE status IN ($1, $2, $3)");
    }

    #[test]
    fn test_arena_stats() {
        let arena = QueryArena::new();

        arena.scope(|scope| {
            let _ = scope.alloc_str("test string");
            let _ = scope.alloc_str("another string");
        });

        let stats = arena.stats();
        assert_eq!(stats.allocations, 2);
        // "test string" (11 bytes) + "another string" (14 bytes).
        assert_eq!(stats.total_bytes, 25);
    }
}
770