rquery_orm/
query.rs

1use std::marker::PhantomData;
2
3use std::sync::Arc;
4
5use crate::db::DatabaseRef;
6use crate::mapping::{Entity, FromRowWithPrefix};
7use anyhow::Result;
8use futures::TryStreamExt;
9
/// How bound parameters are spelled in generated SQL.
///
/// The query builders in this module also use the style to pick the row-limiting syntax.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlaceholderStyle {
    /// `@P1`, `@P2`, … placeholders, with `SELECT TOP(n)` row limiting (SQL Server style).
    AtP,
    /// `$1`, `$2`, … placeholders, with a trailing `LIMIT n` (PostgreSQL style).
    Dollar,
}
15
16#[derive(Clone, Debug, PartialEq)]
17pub enum SqlParam {
18    I32(i32),
19    I64(i64),
20    Bool(bool),
21    Text(String),
22    Uuid(uuid::Uuid),
23    Decimal(rust_decimal::Decimal),
24    DateTime(chrono::NaiveDateTime),
25    Bytes(Vec<u8>),
26    Null,
27}
28
29pub trait ToParam {
30    fn to_param(self) -> SqlParam;
31}
32
33impl ToParam for i32 {
34    fn to_param(self) -> SqlParam {
35        SqlParam::I32(self)
36    }
37}
38impl ToParam for i64 {
39    fn to_param(self) -> SqlParam {
40        SqlParam::I64(self)
41    }
42}
43impl ToParam for bool {
44    fn to_param(self) -> SqlParam {
45        SqlParam::Bool(self)
46    }
47}
48impl ToParam for String {
49    fn to_param(self) -> SqlParam {
50        SqlParam::Text(self)
51    }
52}
53impl<'a> ToParam for &'a str {
54    fn to_param(self) -> SqlParam {
55        SqlParam::Text(self.to_string())
56    }
57}
58impl ToParam for uuid::Uuid {
59    fn to_param(self) -> SqlParam {
60        SqlParam::Uuid(self)
61    }
62}
63impl ToParam for rust_decimal::Decimal {
64    fn to_param(self) -> SqlParam {
65        SqlParam::Decimal(self)
66    }
67}
68impl ToParam for chrono::NaiveDateTime {
69    fn to_param(self) -> SqlParam {
70        SqlParam::DateTime(self)
71    }
72}
73impl ToParam for Vec<u8> {
74    fn to_param(self) -> SqlParam {
75        SqlParam::Bytes(self)
76    }
77}
78
79impl<T: ToParam> ToParam for Option<T> {
80    fn to_param(self) -> SqlParam {
81        match self {
82            Some(v) => v.to_param(),
83            None => SqlParam::Null,
84        }
85    }
86}
87
88#[derive(Clone, Debug)]
89pub enum Expr {
90    Col(String),
91    Param(SqlParam),
92    Binary {
93        left: Box<Expr>,
94        op: &'static str,
95        right: Box<Expr>,
96    },
97    Like {
98        left: Box<Expr>,
99        right: SqlParam,
100    },
101    InList {
102        left: Box<Expr>,
103        list: Vec<SqlParam>,
104    },
105    Group(Box<Expr>),
106}
107
108impl Expr {
109    pub fn eq(self, rhs: Expr) -> Expr {
110        Expr::Binary {
111            left: Box::new(self),
112            op: "=",
113            right: Box::new(rhs),
114        }
115    }
116    pub fn ne(self, rhs: Expr) -> Expr {
117        Expr::Binary {
118            left: Box::new(self),
119            op: "<>",
120            right: Box::new(rhs),
121        }
122    }
123    pub fn gt(self, rhs: Expr) -> Expr {
124        Expr::Binary {
125            left: Box::new(self),
126            op: ">",
127            right: Box::new(rhs),
128        }
129    }
130    pub fn ge(self, rhs: Expr) -> Expr {
131        Expr::Binary {
132            left: Box::new(self),
133            op: ">=",
134            right: Box::new(rhs),
135        }
136    }
137    pub fn lt(self, rhs: Expr) -> Expr {
138        Expr::Binary {
139            left: Box::new(self),
140            op: "<",
141            right: Box::new(rhs),
142        }
143    }
144    pub fn le(self, rhs: Expr) -> Expr {
145        Expr::Binary {
146            left: Box::new(self),
147            op: "<=",
148            right: Box::new(rhs),
149        }
150    }
151    pub fn and(self, rhs: Expr) -> Expr {
152        Expr::Binary {
153            left: Box::new(self),
154            op: "AND",
155            right: Box::new(rhs),
156        }
157    }
158    pub fn or(self, rhs: Expr) -> Expr {
159        Expr::Binary {
160            left: Box::new(self),
161            op: "OR",
162            right: Box::new(rhs),
163        }
164    }
165    pub fn like(self, pattern: Expr) -> Expr {
166        match pattern {
167            Expr::Param(p) => Expr::Like {
168                left: Box::new(self),
169                right: p,
170            },
171            other => panic!("like expects Expr::Param but received {:?}", other),
172        }
173    }
174    pub fn in_list(self, list: Vec<Expr>) -> Expr {
175        let mut ps = Vec::new();
176        for e in list {
177            match e {
178                Expr::Param(p) => ps.push(p),
179                other => panic!("in_list expects Expr::Param items but received {:?}", other),
180            }
181        }
182        Expr::InList {
183            left: Box::new(self),
184            list: ps,
185        }
186    }
187    pub fn group(self) -> Expr {
188        Expr::Group(Box::new(self))
189    }
190
191    pub fn to_sql_with(&self, style: PlaceholderStyle, params: &mut Vec<SqlParam>) -> String {
192        match self {
193            Expr::Col(c) => c.clone(),
194            Expr::Param(p) => {
195                params.push(p.clone());
196                let idx = params.len();
197                match style {
198                    PlaceholderStyle::AtP => format!("@P{}", idx),
199                    PlaceholderStyle::Dollar => format!("${}", idx),
200                }
201            }
202            Expr::Binary { left, op, right } => {
203                let l = left.to_sql_with(style, params);
204                let r = right.to_sql_with(style, params);
205                if *op == "AND" || *op == "OR" {
206                    format!("{} {} {}", l, op, r)
207                } else {
208                    format!("({} {} {})", l, op, r)
209                }
210            }
211            Expr::Like { left, right } => {
212                params.push(right.clone());
213                let idx = params.len();
214                let ph = match style {
215                    PlaceholderStyle::AtP => format!("@P{}", idx),
216                    PlaceholderStyle::Dollar => format!("${}", idx),
217                };
218                format!("({} LIKE {})", left.to_sql_with(style, params), ph)
219            }
220            Expr::InList { left, list } => {
221                let mut phs = Vec::new();
222                for p in list {
223                    params.push(p.clone());
224                    let idx = params.len();
225                    phs.push(match style {
226                        PlaceholderStyle::AtP => format!("@P{}", idx),
227                        PlaceholderStyle::Dollar => format!("${}", idx),
228                    });
229                }
230                format!(
231                    "{} IN ({})",
232                    left.to_sql_with(style, params),
233                    phs.join(", ")
234                )
235            }
236            Expr::Group(e) => format!("({})", e.to_sql_with(style, params)),
237        }
238    }
239}
240
/// Builds an `Expr::Col` from anything with a `to_string()` method, e.g. `col!("users.id")`.
///
/// The resulting text is inserted into the SQL verbatim.
#[macro_export]
macro_rules! col {
    ($column:expr) => {
        $crate::query::Expr::Col($column.to_string())
    };
}
247
/// Wraps any `ToParam` value as a bound-parameter `Expr::Param`, e.g. `val!(42)`.
#[macro_export]
macro_rules! val {
    ($value:expr) => {
        $crate::query::Expr::Param($crate::query::ToParam::to_param($value))
    };
}
254
/// Internal helper shared by `on!` and `condition!`; not part of the public API.
///
/// Builds `<Type::TABLE>.<Type::field> = …`. `Type` must expose a `TABLE` item and an item
/// named after the field, both implementing `Display`. The right-hand side is either another
/// `(Type::field)` column or any expression implementing `ToParam`.
// `format!` is fully qualified because an exported macro's unqualified macro calls resolve
// at the call site, where a caller-defined `format!` could shadow the std one.
#[doc(hidden)]
#[macro_export]
macro_rules! __col_eq {
    // Compare two columns
    (($lt:ident :: $lf:ident), ($rt:ident :: $rf:ident)) => {
        $crate::query::Expr::Col(::std::format!("{}.{}", $lt::TABLE, $lt::$lf))
            .eq($crate::query::Expr::Col(::std::format!("{}.{}", $rt::TABLE, $rt::$rf)))
    };
    // Compare column to value
    (($lt:ident :: $lf:ident), $rv:expr) => {
        $crate::query::Expr::Col(::std::format!("{}.{}", $lt::TABLE, $lt::$lf))
            .eq($crate::query::Expr::Param($crate::query::ToParam::to_param($rv)))
    };
}
269
/// Builds an equality `Expr` for a join `ON` clause.
///
/// - `on!(A::x == B::y)` compares two table-qualified columns.
/// - `on!(A::x == value)` compares a table-qualified column with a bound parameter.
///
/// Columns are rendered from each type's `TABLE` item and field item via `__col_eq!`.
#[macro_export]
macro_rules! on {
    ($left_ty:ident :: $left_field:ident == $right_ty:ident :: $right_field:ident) => {
        $crate::__col_eq!(($left_ty :: $left_field), ($right_ty :: $right_field))
    };
    ($left_ty:ident :: $left_field:ident == $value:expr) => {
        $crate::__col_eq!(($left_ty :: $left_field), $value)
    };
}
279
/// Builds an equality `Expr` for a `WHERE` clause.
///
/// - `condition!("column" == value)` compares a literal column name with a bound parameter.
/// - `condition!(A::x == B::y)` compares two table-qualified columns.
/// - `condition!(A::x == value)` compares a table-qualified column with a bound parameter.
#[macro_export]
macro_rules! condition {
    ($column:literal == $value:expr) => {{
        $crate::query::Expr::Col($column.to_string())
            .eq($crate::query::Expr::Param($crate::query::ToParam::to_param($value)))
    }};
    ($left_ty:ident :: $left_field:ident == $right_ty:ident :: $right_field:ident) => {
        $crate::__col_eq!(($left_ty :: $left_field), ($right_ty :: $right_field))
    };
    ($left_ty:ident :: $left_field:ident == $value:expr) => {
        $crate::__col_eq!(($left_ty :: $left_field), $value)
    };
}
293
/// The kind of SQL join used to combine two tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// `INNER JOIN`: only rows matched on both sides.
    Inner,
    /// `LEFT JOIN`: every left row, matched right rows or NULLs.
    Left,
    /// `RIGHT JOIN`: every right row, matched left rows or NULLs.
    Right,
    /// `FULL JOIN`: every row from both sides.
    Full,
}
301
302impl JoinType {
303    fn to_sql(self) -> &'static str {
304        match self {
305            JoinType::Inner => "INNER JOIN",
306            JoinType::Left => "LEFT JOIN",
307            JoinType::Right => "RIGHT JOIN",
308            JoinType::Full => "FULL JOIN",
309        }
310    }
311}
312
313pub struct DualQuery<T, U>
314where
315    T: Entity + FromRowWithPrefix,
316    U: Entity + FromRowWithPrefix,
317{
318    style: PlaceholderStyle,
319    db: Option<Arc<DatabaseRef>>,
320    filters: Vec<Expr>,
321    order_by: Option<String>,
322    top: Option<i64>,
323    join: Option<(JoinType, Expr)>,
324    _t: PhantomData<T>,
325    _u: PhantomData<U>,
326}
327
328impl<T, U> DualQuery<T, U>
329where
330    T: Entity + FromRowWithPrefix,
331    U: Entity + FromRowWithPrefix,
332{
333    pub fn new(style: PlaceholderStyle) -> Self {
334        Self {
335            style,
336            db: None,
337            filters: Vec::new(),
338            order_by: None,
339            top: None,
340            join: None,
341            _t: PhantomData,
342            _u: PhantomData,
343        }
344    }
345
346    pub fn with_db(mut self, db: Arc<DatabaseRef>) -> Self {
347        self.db = Some(db);
348        self
349    }
350
351    pub fn Join(mut self, join_type: JoinType, on_expr: Expr) -> Self {
352        self.join = Some((join_type, on_expr));
353        self
354    }
355
356    pub fn Where(mut self, expr: Expr) -> Self {
357        self.filters.push(expr);
358        self
359    }
360
361    pub fn OrderBy(mut self, ob: &str) -> Self {
362        self.order_by = Some(ob.to_string());
363        self
364    }
365
366    pub fn Top(mut self, n: i64) -> Self {
367        self.top = Some(n);
368        self
369    }
370
371    pub fn to_sql(&self) -> (String, Vec<SqlParam>) {
372        let mut params = Vec::new();
373        let tname = T::table().name;
374        let uname = U::table().name;
375        let mut cols = Vec::new();
376        for c in T::table().columns {
377            cols.push(format!("{}.{} AS t_{}", tname, c.name, c.name));
378        }
379        for c in U::table().columns {
380            cols.push(format!("{}.{} AS u_{}", uname, c.name, c.name));
381        }
382        let mut sql = String::new();
383        match self.style {
384            PlaceholderStyle::AtP => {
385                if let Some(n) = self.top {
386                    sql.push_str(&format!("SELECT TOP({}) {} FROM {}", n, cols.join(", "), tname));
387                } else {
388                    sql.push_str(&format!("SELECT {} FROM {}", cols.join(", "), tname));
389                }
390            }
391            PlaceholderStyle::Dollar => {
392                sql.push_str(&format!("SELECT {} FROM {}", cols.join(", "), tname));
393            }
394        }
395        if let Some((jt, on)) = &self.join {
396            sql.push(' ');
397            sql.push_str(jt.to_sql());
398            sql.push(' ');
399            sql.push_str(uname);
400            sql.push_str(" ON ");
401            sql.push_str(&on.to_sql_with(self.style, &mut params));
402        }
403        if !self.filters.is_empty() {
404            let mut it = self.filters.iter();
405            if let Some(first) = it.next() {
406                sql.push_str(" WHERE ");
407                sql.push_str(&first.to_sql_with(self.style, &mut params));
408                for f in it {
409                    sql.push_str(" AND ");
410                    sql.push_str(&f.to_sql_with(self.style, &mut params));
411                }
412            }
413        }
414        if let Some(ob) = &self.order_by {
415            sql.push_str(" ORDER BY ");
416            sql.push_str(ob);
417        }
418        if let Some(n) = self.top {
419            if self.style == PlaceholderStyle::Dollar {
420                sql.push_str(&format!(" LIMIT {}", n));
421            }
422        }
423        (sql, params)
424    }
425
426    pub async fn to_list_async(self) -> Result<Vec<(T, U)>> {
427        let db = self.db.clone().expect("database reference not set");
428        let (sql, params) = self.to_sql();
429        match db.as_ref() {
430            DatabaseRef::Mssql(conn) => {
431                let mut guard = conn.lock().await;
432                let mut boxed: Vec<Box<dyn tiberius::ToSql + Send + Sync>> = Vec::new();
433                for p in &params {
434                    let b: Box<dyn tiberius::ToSql + Send + Sync> = match p {
435                        SqlParam::I32(v) => Box::new(*v),
436                        SqlParam::I64(v) => Box::new(*v),
437                        SqlParam::Bool(v) => Box::new(*v),
438                        SqlParam::Text(v) => Box::new(v.clone()),
439                        SqlParam::Uuid(v) => Box::new(*v),
440                        SqlParam::Decimal(v) => Box::new(v.to_string()),
441                        SqlParam::DateTime(v) => Box::new(*v),
442                        SqlParam::Bytes(v) => Box::new(v.clone()),
443                        SqlParam::Null => Box::new(Option::<i32>::None),
444                    };
445                    boxed.push(b);
446                }
447                let refs: Vec<&dyn tiberius::ToSql> =
448                    boxed.iter().map(|b| &**b as &dyn tiberius::ToSql).collect();
449                let mut stream = guard.query(&sql, &refs[..]).await?;
450                let mut out = Vec::new();
451                while let Some(item) = stream.try_next().await? {
452                    if let Some(row) = item.into_row() {
453                        let left = T::from_row_ms_with(&row, "t")?;
454                        let right = U::from_row_ms_with(&row, "u")?;
455                        out.push((left, right));
456                    }
457                }
458                Ok(out)
459            }
460            DatabaseRef::Postgres(pg) => {
461                let mut boxed: Vec<Box<dyn tokio_postgres::types::ToSql + Send + Sync>> =
462                    Vec::new();
463                for p in &params {
464                    let b: Box<dyn tokio_postgres::types::ToSql + Send + Sync> = match p {
465                        SqlParam::I32(v) => Box::new(*v),
466                        SqlParam::I64(v) => Box::new(*v),
467                        SqlParam::Bool(v) => Box::new(*v),
468                        SqlParam::Text(v) => Box::new(v.clone()),
469                        SqlParam::Uuid(v) => Box::new(*v),
470                        SqlParam::Decimal(v) => Box::new(v.to_string()),
471                        SqlParam::DateTime(v) => Box::new(*v),
472                        SqlParam::Bytes(v) => Box::new(v.clone()),
473                        SqlParam::Null => Box::new(Option::<i32>::None),
474                    };
475                    boxed.push(b);
476                }
477                let refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
478                    boxed.iter().map(|b| &**b as _).collect();
479                let rows = pg.query(&sql, &refs[..]).await?;
480                let mut out = Vec::new();
481                for row in rows {
482                    let left = T::from_row_pg_with(&row, "t")?;
483                    let right = U::from_row_pg_with(&row, "u")?;
484                    out.push((left, right));
485                }
486                Ok(out)
487            }
488        }
489    }
490}
491
492struct JoinClause {
493    join_type: JoinType,
494    table: String,
495    on: Expr,
496}
497
498#[allow(non_snake_case)]
499pub struct Query<T>
500where
501    T: Entity + crate::mapping::FromRowNamed,
502{
503    table: String,
504    style: PlaceholderStyle,
505    db: Option<Arc<DatabaseRef>>,
506    joins: Vec<JoinClause>,
507    filters: Vec<Expr>,
508    order_by: Option<String>,
509    top: Option<i64>,
510    _t: PhantomData<T>,
511}
512
513#[allow(non_snake_case)]
514impl<T> Query<T>
515where
516    T: Entity + crate::mapping::FromRowNamed,
517{
518    pub fn new(table: &str, style: PlaceholderStyle) -> Self {
519        Self {
520            table: table.to_string(),
521            style,
522            db: None,
523            joins: Vec::new(),
524            filters: Vec::new(),
525            order_by: None,
526            top: None,
527            _t: PhantomData,
528        }
529    }
530
531    pub fn with_db(mut self, db: Arc<DatabaseRef>) -> Self {
532        self.db = Some(db);
533        self
534    }
535
536    pub fn Where(mut self, expr: Expr) -> Self {
537        self.filters.push(expr);
538        self
539    }
540
541    pub fn Join(mut self, join_type: JoinType, table: &str, on_expr: Expr) -> Self {
542        self.joins.push(JoinClause {
543            join_type,
544            table: table.to_string(),
545            on: on_expr,
546        });
547        self
548    }
549
550    pub fn OrderBy(mut self, ob: &str) -> Self {
551        self.order_by = Some(ob.to_string());
552        self
553    }
554
555    pub fn Top(mut self, n: i64) -> Self {
556        self.top = Some(n);
557        self
558    }
559
560    pub fn to_sql(&self) -> (String, Vec<SqlParam>) {
561        let mut params = Vec::new();
562        let mut sql = String::new();
563        match self.style {
564            PlaceholderStyle::AtP => {
565                if let Some(n) = self.top {
566                    sql.push_str(&format!("SELECT TOP({}) * FROM {}", n, self.table));
567                } else {
568                    sql.push_str(&format!("SELECT * FROM {}", self.table));
569                }
570            }
571            PlaceholderStyle::Dollar => {
572                sql.push_str(&format!("SELECT * FROM {}", self.table));
573            }
574        }
575        for j in &self.joins {
576            sql.push(' ');
577            sql.push_str(j.join_type.to_sql());
578            sql.push(' ');
579            sql.push_str(&j.table);
580            sql.push_str(" ON ");
581            sql.push_str(&j.on.to_sql_with(self.style, &mut params));
582        }
583        if !self.filters.is_empty() {
584            let mut it = self.filters.iter();
585            if let Some(first) = it.next() {
586                sql.push_str(" WHERE ");
587                sql.push_str(&first.to_sql_with(self.style, &mut params));
588                for f in it {
589                    sql.push_str(" AND ");
590                    sql.push_str(&f.to_sql_with(self.style, &mut params));
591                }
592            }
593        }
594        if let Some(ob) = &self.order_by {
595            sql.push_str(" ORDER BY ");
596            sql.push_str(ob);
597        }
598        if let Some(n) = self.top {
599            if self.style == PlaceholderStyle::Dollar {
600                sql.push_str(&format!(" LIMIT {}", n));
601            }
602        }
603        (sql, params)
604    }
605
606    pub async fn to_list_async(self) -> Result<Vec<T>> {
607        let db = self.db.clone().expect("database reference not set");
608        let (sql, params) = self.to_sql();
609        match db.as_ref() {
610            DatabaseRef::Mssql(conn) => {
611                let mut guard = conn.lock().await;
612                let mut boxed: Vec<Box<dyn tiberius::ToSql + Send + Sync>> = Vec::new();
613                for p in &params {
614                    let b: Box<dyn tiberius::ToSql + Send + Sync> = match p {
615                        SqlParam::I32(v) => Box::new(*v),
616                        SqlParam::I64(v) => Box::new(*v),
617                        SqlParam::Bool(v) => Box::new(*v),
618                        SqlParam::Text(v) => Box::new(v.clone()),
619                        SqlParam::Uuid(v) => Box::new(*v),
620                        SqlParam::Decimal(v) => Box::new(v.to_string()),
621                        SqlParam::DateTime(v) => Box::new(*v),
622                        SqlParam::Bytes(v) => Box::new(v.clone()),
623                        SqlParam::Null => Box::new(Option::<i32>::None),
624                    };
625                    boxed.push(b);
626                }
627                let refs: Vec<&dyn tiberius::ToSql> =
628                    boxed.iter().map(|b| &**b as &dyn tiberius::ToSql).collect();
629                let mut stream = guard.query(&sql, &refs[..]).await?;
630                let mut out = Vec::new();
631                while let Some(item) = stream.try_next().await? {
632                    if let Some(row) = item.into_row() {
633                        out.push(T::from_row_ms(&row)?);
634                    }
635                }
636                Ok(out)
637            }
638            DatabaseRef::Postgres(pg) => {
639                let mut boxed: Vec<Box<dyn tokio_postgres::types::ToSql + Send + Sync>> =
640                    Vec::new();
641                for p in &params {
642                    let b: Box<dyn tokio_postgres::types::ToSql + Send + Sync> = match p {
643                        SqlParam::I32(v) => Box::new(*v),
644                        SqlParam::I64(v) => Box::new(*v),
645                        SqlParam::Bool(v) => Box::new(*v),
646                        SqlParam::Text(v) => Box::new(v.clone()),
647                        SqlParam::Uuid(v) => Box::new(*v),
648                        SqlParam::Decimal(v) => Box::new(v.to_string()),
649                        SqlParam::DateTime(v) => Box::new(*v),
650                        SqlParam::Bytes(v) => Box::new(v.clone()),
651                        SqlParam::Null => Box::new(Option::<i32>::None),
652                    };
653                    boxed.push(b);
654                }
655                let refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
656                    boxed.iter().map(|b| &**b as _).collect();
657                let rows = pg.query(&sql, &refs[..]).await?;
658                let mut out = Vec::new();
659                for row in rows {
660                    out.push(T::from_row_pg(&row)?);
661                }
662                Ok(out)
663            }
664        }
665    }
666
667    pub async fn to_single_async(self) -> Result<Option<T>> {
668        let mut list = self.Top(1).to_list_async().await?;
669        Ok(list.pop())
670    }
671
672    pub async fn ToDictionaryKeyIntAsync(
673        self,
674    ) -> anyhow::Result<std::collections::HashMap<i32, T>> {
675        unimplemented!("execution not implemented");
676    }
677
678    pub async fn ToDictionaryKeyGuidAsync(
679        self,
680    ) -> anyhow::Result<std::collections::HashMap<uuid::Uuid, T>> {
681        unimplemented!("execution not implemented");
682    }
683
684    pub async fn ToDictionaryKeyStringAsync(
685        self,
686    ) -> anyhow::Result<std::collections::HashMap<String, T>> {
687        unimplemented!("execution not implemented");
688    }
689}