1use serde::{Deserialize, Serialize};
4use std::borrow::Cow;
5use std::fmt;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
9pub enum SortOrder {
10 Asc,
12 Desc,
14}
15
16impl SortOrder {
17 pub fn as_sql(&self) -> &'static str {
19 match self {
20 Self::Asc => "ASC",
21 Self::Desc => "DESC",
22 }
23 }
24}
25
26impl fmt::Display for SortOrder {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 write!(f, "{}", self.as_sql())
29 }
30}
31
32impl Default for SortOrder {
33 fn default() -> Self {
34 Self::Asc
35 }
36}
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
40pub enum NullsOrder {
41 First,
43 Last,
45}
46
47impl NullsOrder {
48 pub fn as_sql(&self) -> &'static str {
50 match self {
51 Self::First => "NULLS FIRST",
52 Self::Last => "NULLS LAST",
53 }
54 }
55}
56
57#[derive(Debug, Clone, PartialEq, Eq)]
59pub struct OrderByField {
60 pub column: Cow<'static, str>,
62 pub order: SortOrder,
64 pub nulls: Option<NullsOrder>,
66}
67
68impl OrderByField {
69 pub fn new(column: impl Into<Cow<'static, str>>, order: SortOrder) -> Self {
71 Self {
72 column: column.into(),
73 order,
74 nulls: None,
75 }
76 }
77
78 #[inline]
80 pub const fn new_static(column: &'static str, order: SortOrder) -> Self {
81 Self {
82 column: Cow::Borrowed(column),
83 order,
84 nulls: None,
85 }
86 }
87
88 pub fn nulls(mut self, nulls: NullsOrder) -> Self {
90 self.nulls = Some(nulls);
91 self
92 }
93
94 pub fn asc(column: impl Into<Cow<'static, str>>) -> Self {
96 Self::new(column, SortOrder::Asc)
97 }
98
99 pub fn desc(column: impl Into<Cow<'static, str>>) -> Self {
101 Self::new(column, SortOrder::Desc)
102 }
103
104 #[inline]
106 pub const fn asc_static(column: &'static str) -> Self {
107 Self::new_static(column, SortOrder::Asc)
108 }
109
110 #[inline]
112 pub const fn desc_static(column: &'static str) -> Self {
113 Self::new_static(column, SortOrder::Desc)
114 }
115
116 pub fn to_sql(&self) -> String {
120 let cap = self.column.len() + 5 + if self.nulls.is_some() { 12 } else { 0 };
122 let mut sql = String::with_capacity(cap);
123 self.write_sql(&mut sql);
124 sql
125 }
126
127 #[inline]
141 pub fn write_sql(&self, buffer: &mut String) {
142 buffer.push_str(&self.column);
143 buffer.push(' ');
144 buffer.push_str(self.order.as_sql());
145 if let Some(nulls) = self.nulls {
146 buffer.push(' ');
147 buffer.push_str(nulls.as_sql());
148 }
149 }
150
151 #[inline]
153 pub fn estimated_len(&self) -> usize {
154 self.column.len() + 5 + if self.nulls.is_some() { 12 } else { 0 }
155 }
156}
157
158#[derive(Debug, Clone, PartialEq, Eq)]
160pub enum OrderBy {
161 Field(OrderByField),
163 Fields(Box<[OrderByField]>),
165}
166
167impl OrderBy {
168 pub fn none() -> Self {
170 Self::Fields(Box::new([]))
171 }
172
173 pub fn is_empty(&self) -> bool {
175 match self {
176 Self::Field(_) => false,
177 Self::Fields(fields) => fields.is_empty(),
178 }
179 }
180
181 pub fn then(self, field: OrderByField) -> Self {
183 match self {
184 Self::Field(existing) => Self::Fields(vec![existing, field].into_boxed_slice()),
185 Self::Fields(existing) => {
186 let mut fields: Vec<_> = existing.into_vec();
187 fields.push(field);
188 Self::Fields(fields.into_boxed_slice())
189 }
190 }
191 }
192
193 pub fn from_fields(fields: impl IntoIterator<Item = OrderByField>) -> Self {
195 let fields: Vec<_> = fields.into_iter().collect();
196 match fields.len() {
197 0 => Self::none(),
198 1 => Self::Field(fields.into_iter().next().unwrap()),
199 _ => Self::Fields(fields.into_boxed_slice()),
200 }
201 }
202
203 pub fn to_sql(&self) -> String {
207 match self {
208 Self::Field(field) => field.to_sql(),
209 Self::Fields(fields) if fields.is_empty() => String::new(),
210 Self::Fields(fields) => {
211 let cap: usize = fields.iter().map(|f| f.estimated_len() + 2).sum();
213 let mut sql = String::with_capacity(cap);
214 self.write_sql(&mut sql);
215 sql
216 }
217 }
218 }
219
220 #[inline]
237 pub fn write_sql(&self, buffer: &mut String) {
238 match self {
239 Self::Field(field) => field.write_sql(buffer),
240 Self::Fields(fields) => {
241 for (i, field) in fields.iter().enumerate() {
242 if i > 0 {
243 buffer.push_str(", ");
244 }
245 field.write_sql(buffer);
246 }
247 }
248 }
249 }
250
251 #[inline]
253 pub fn field_count(&self) -> usize {
254 match self {
255 Self::Field(_) => 1,
256 Self::Fields(fields) => fields.len(),
257 }
258 }
259}
260
261impl From<OrderByField> for OrderBy {
262 fn from(field: OrderByField) -> Self {
263 Self::Field(field)
264 }
265}
266
267impl From<Vec<OrderByField>> for OrderBy {
268 fn from(fields: Vec<OrderByField>) -> Self {
269 match fields.len() {
270 0 => Self::none(),
271 1 => Self::Field(fields.into_iter().next().unwrap()),
272 _ => Self::Fields(fields.into_boxed_slice()),
273 }
274 }
275}
276
277#[derive(Debug)]
279pub struct OrderByBuilder {
280 fields: Vec<OrderByField>,
281}
282
283impl OrderByBuilder {
284 #[inline]
286 pub fn with_capacity(capacity: usize) -> Self {
287 Self {
288 fields: Vec::with_capacity(capacity),
289 }
290 }
291
292 #[inline]
294 pub fn push(mut self, field: OrderByField) -> Self {
295 self.fields.push(field);
296 self
297 }
298
299 #[inline]
301 pub fn asc(self, column: impl Into<Cow<'static, str>>) -> Self {
302 self.push(OrderByField::asc(column))
303 }
304
305 #[inline]
307 pub fn desc(self, column: impl Into<Cow<'static, str>>) -> Self {
308 self.push(OrderByField::desc(column))
309 }
310
311 #[inline]
313 pub fn build(self) -> OrderBy {
314 OrderBy::from(self.fields)
315 }
316}
317
318pub mod order_patterns {
320 use super::*;
321
322 pub const CREATED_AT_DESC: OrderByField = OrderByField::desc_static("created_at");
324
325 pub const CREATED_AT_ASC: OrderByField = OrderByField::asc_static("created_at");
327
328 pub const UPDATED_AT_DESC: OrderByField = OrderByField::desc_static("updated_at");
330
331 pub const UPDATED_AT_ASC: OrderByField = OrderByField::asc_static("updated_at");
333
334 pub const ID_ASC: OrderByField = OrderByField::asc_static("id");
336
337 pub const ID_DESC: OrderByField = OrderByField::desc_static("id");
339
340 pub const NAME_ASC: OrderByField = OrderByField::asc_static("name");
342
343 pub const NAME_DESC: OrderByField = OrderByField::desc_static("name");
345
346 pub const PRICE_ASC: OrderByField = OrderByField::asc_static("price");
348
349 pub const PRICE_DESC: OrderByField = OrderByField::desc_static("price");
351}
352
/// The projection list of a `SELECT` statement, excluding the `SELECT` keyword.
///
/// Column names are written into the SQL verbatim — they are neither quoted nor
/// escaped — so they must not be derived from untrusted input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Select {
    /// Every column; rendered as `*`.
    All,
    /// Several columns, rendered comma-separated in the given order.
    Fields(Vec<String>),
    /// A single column.
    Field(String),
}

impl Select {
    /// Selects every column (`*`).
    pub fn all() -> Self {
        Self::All
    }

    /// Selects the given columns, in order.
    pub fn fields(fields: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self::Fields(fields.into_iter().map(Into::into).collect())
    }

    /// Selects a single column.
    pub fn field(field: impl Into<String>) -> Self {
        Self::Field(field.into())
    }

    /// Returns `true` for [`Select::All`].
    pub fn is_all(&self) -> bool {
        matches!(self, Self::All)
    }

    /// Returns the selected column names; [`Select::All`] yields `["*"]`.
    pub fn field_names(&self) -> Vec<&str> {
        match self {
            Self::All => vec!["*"],
            Self::Fields(fields) => fields.iter().map(String::as_str).collect(),
            Self::Field(field) => vec![field.as_str()],
        }
    }

    /// Renders the projection as SQL, e.g. `id, name, email`.
    ///
    /// An empty [`Select::Fields`] list renders as the empty string.
    pub fn to_sql(&self) -> String {
        match self {
            Self::All => "*".to_string(),
            // `join` places separators only between items and handles the sizing.
            Self::Fields(fields) => fields.join(", "),
            Self::Field(field) => field.clone(),
        }
    }

    /// Appends the projection SQL to `buffer`; the text matches [`Select::to_sql`].
    #[inline]
    pub fn write_sql(&self, buffer: &mut String) {
        match self {
            Self::All => buffer.push('*'),
            Self::Fields(fields) => {
                for (i, field) in fields.iter().enumerate() {
                    if i > 0 {
                        buffer.push_str(", ");
                    }
                    buffer.push_str(field);
                }
            }
            Self::Field(field) => buffer.push_str(field),
        }
    }
}

impl Default for Select {
    /// Defaults to [`Select::All`].
    fn default() -> Self {
        Self::All
    }
}
432
433#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
435pub enum SetParam<T> {
436 Set(T),
438 Unset,
440}
441
442impl<T> SetParam<T> {
443 pub fn is_set(&self) -> bool {
445 matches!(self, Self::Set(_))
446 }
447
448 pub fn get(&self) -> Option<&T> {
450 match self {
451 Self::Set(v) => Some(v),
452 Self::Unset => None,
453 }
454 }
455
456 pub fn take(self) -> Option<T> {
458 match self {
459 Self::Set(v) => Some(v),
460 Self::Unset => None,
461 }
462 }
463
464 pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> SetParam<U> {
466 match self {
467 Self::Set(v) => SetParam::Set(f(v)),
468 Self::Unset => SetParam::Unset,
469 }
470 }
471}
472
473impl<T> Default for SetParam<T> {
474 fn default() -> Self {
475 Self::Unset
476 }
477}
478
479impl<T> From<T> for SetParam<T> {
480 fn from(value: T) -> Self {
481 Self::Set(value)
482 }
483}
484
485impl<T> From<Option<T>> for SetParam<T> {
486 fn from(opt: Option<T>) -> Self {
487 match opt {
488 Some(v) => Self::Set(v),
489 None => Self::Unset,
490 }
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sort_order() {
        assert_eq!(SortOrder::Asc.as_sql(), "ASC");
        assert_eq!(SortOrder::Desc.as_sql(), "DESC");
    }

    #[test]
    fn test_sort_order_display_and_default() {
        assert_eq!(SortOrder::Desc.to_string(), "DESC");
        assert_eq!(format!("{}", SortOrder::Asc), "ASC");
        assert_eq!(SortOrder::default(), SortOrder::Asc);
    }

    #[test]
    fn test_nulls_order() {
        assert_eq!(NullsOrder::First.as_sql(), "NULLS FIRST");
        assert_eq!(NullsOrder::Last.as_sql(), "NULLS LAST");
    }

    #[test]
    fn test_order_by_field() {
        let field = OrderByField::desc("created_at");
        assert_eq!(field.to_sql(), "created_at DESC");

        let field_with_nulls = OrderByField::asc("name").nulls(NullsOrder::Last);
        assert_eq!(field_with_nulls.to_sql(), "name ASC NULLS LAST");
    }

    #[test]
    fn test_order_by_field_static() {
        let field = OrderByField::desc_static("created_at");
        assert_eq!(field.to_sql(), "created_at DESC");

        let field = OrderByField::asc_static("id");
        assert_eq!(field.to_sql(), "id ASC");
    }

    #[test]
    fn test_order_by_field_write_sql() {
        let field = OrderByField::desc("created_at");
        let mut buffer = String::with_capacity(32);
        field.write_sql(&mut buffer);
        assert_eq!(buffer, "created_at DESC");

        let field = OrderByField::asc("name").nulls(NullsOrder::First);
        let mut buffer = String::with_capacity(32);
        field.write_sql(&mut buffer);
        assert_eq!(buffer, "name ASC NULLS FIRST");
    }

    #[test]
    fn test_order_by_field_estimated_len_covers_output() {
        let fields = [
            OrderByField::asc("id"),
            OrderByField::desc("created_at"),
            OrderByField::asc("name").nulls(NullsOrder::First),
            OrderByField::desc("name").nulls(NullsOrder::Last),
        ];
        for field in &fields {
            assert!(field.estimated_len() >= field.to_sql().len());
        }
    }

    #[test]
    fn test_order_by_multiple() {
        let order = OrderBy::Field(OrderByField::desc("created_at"))
            .then(OrderByField::asc("name"));
        assert_eq!(order.to_sql(), "created_at DESC, name ASC");
    }

    #[test]
    fn test_order_by_from_fields() {
        let order = OrderBy::from_fields([
            OrderByField::desc("created_at"),
            OrderByField::asc("id"),
        ]);
        assert_eq!(order.to_sql(), "created_at DESC, id ASC");
        assert_eq!(order.field_count(), 2);
    }

    #[test]
    fn test_order_by_empty() {
        let order = OrderBy::none();
        assert!(order.is_empty());
        assert_eq!(order.field_count(), 0);
        assert_eq!(order.to_sql(), "");
        assert_eq!(OrderBy::from(Vec::<OrderByField>::new()), OrderBy::none());
        assert_eq!(OrderBy::from_fields(Vec::new()), OrderBy::none());
        assert!(OrderByBuilder::with_capacity(0).build().is_empty());
    }

    #[test]
    fn test_order_by_single_field_is_normalized() {
        let order = OrderBy::from_fields([OrderByField::asc("id")]);
        assert_eq!(order, OrderBy::Field(OrderByField::asc("id")));
        assert!(!order.is_empty());
        assert_eq!(order.field_count(), 1);
        assert_eq!(OrderBy::none().then(OrderByField::asc("id")).to_sql(), "id ASC");
    }

    #[test]
    fn test_order_by_write_sql() {
        let order = OrderBy::from_fields([
            OrderByField::desc("created_at"),
            OrderByField::asc("id"),
        ]);
        let mut buffer = String::with_capacity(64);
        buffer.push_str("ORDER BY ");
        order.write_sql(&mut buffer);
        assert_eq!(buffer, "ORDER BY created_at DESC, id ASC");
    }

    #[test]
    fn test_order_by_builder() {
        let order = OrderByBuilder::with_capacity(3)
            .desc("created_at")
            .asc("name")
            .asc("id")
            .build();
        assert_eq!(order.to_sql(), "created_at DESC, name ASC, id ASC");
        assert_eq!(order.field_count(), 3);
    }

    #[test]
    fn test_order_patterns() {
        assert_eq!(order_patterns::CREATED_AT_DESC.to_sql(), "created_at DESC");
        assert_eq!(order_patterns::ID_ASC.to_sql(), "id ASC");
        assert_eq!(order_patterns::NAME_ASC.to_sql(), "name ASC");
    }

    #[test]
    fn test_select() {
        assert_eq!(Select::all().to_sql(), "*");
        assert_eq!(Select::field("id").to_sql(), "id");
        assert_eq!(Select::fields(["id", "name", "email"]).to_sql(), "id, name, email");
    }

    #[test]
    fn test_select_write_sql() {
        let select = Select::fields(["id", "name", "email"]);
        let mut buffer = String::with_capacity(32);
        buffer.push_str("SELECT ");
        select.write_sql(&mut buffer);
        assert_eq!(buffer, "SELECT id, name, email");
    }

    #[test]
    fn test_select_empty_and_default() {
        assert_eq!(Select::fields(Vec::<String>::new()).to_sql(), "");
        assert_eq!(Select::default(), Select::All);
        assert!(Select::default().is_all());
        assert_eq!(Select::all().field_names(), vec!["*"]);
        assert_eq!(Select::fields(["id", "name"]).field_names(), vec!["id", "name"]);
    }

    #[test]
    fn test_set_param() {
        let set: SetParam<i32> = SetParam::Set(42);
        assert!(set.is_set());
        assert_eq!(set.get(), Some(&42));

        let unset: SetParam<i32> = SetParam::Unset;
        assert!(!unset.is_set());
        assert_eq!(unset.get(), None);
    }

    #[test]
    fn test_set_param_conversions() {
        assert_eq!(SetParam::<i32>::from(7), SetParam::Set(7));
        assert_eq!(SetParam::<i32>::from(Some(7)), SetParam::Set(7));
        assert_eq!(SetParam::<i32>::from(None), SetParam::Unset);
        assert_eq!(SetParam::<i32>::default(), SetParam::Unset);
        assert_eq!(SetParam::<i32>::Set(2).map(|v| v * 10), SetParam::Set(20));
        assert_eq!(SetParam::<i32>::Unset.map(|v| v * 10), SetParam::Unset);
        assert_eq!(SetParam::Set("x").take(), Some("x"));
        assert_eq!(SetParam::<&str>::Unset.take(), None);
    }
}
609