sqlstr/expr/
expression.rs

1//! Operators and functions.
2
3pub mod math;
4
5use super::{separator, separator_optional};
6use crate::{ArgumentBuffer, SqlExpr, WriteSql};
7
8pub trait BinaryOperator: private::Sealed {
9    fn push_operator<Sql, Arg>(&self, sql: &mut Sql)
10    where
11        Sql: WriteSql<Arg>;
12}
13
14pub trait UnaryOperator: private::Sealed {
15    fn push_operator<Sql, Arg>(&self, sql: &mut Sql)
16    where
17        Sql: WriteSql<Arg>;
18}
19
// Sealing module: `Sealed` is public but its module is not, so only this
// module and its children can implement `Sealed` (and therefore the operator
// traits that require it).
mod private {
    /// Marker supertrait for the sealed operator traits.
    pub trait Sealed {}
}
23
/// Comparison operators.
///
/// Each variant's documentation shows the SQL token produced by
/// [`Cmp::as_str`].
#[cfg_attr(any(feature = "fmt", test, debug_assertions), derive(Debug))]
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum Cmp {
    /// Equal, written as `=`.
    Eq,
    /// Not equal, written as `<>` (the standard SQL spelling of `!=`).
    Neq,
    /// Greater than, written as `>`.
    Gt,
    /// Greater than or equal, written as `>=`.
    Gte,
    /// Less than, written as `<`.
    Lt,
    /// Less than or equal, written as `<=`.
    Lte,
}

impl Cmp {
    /// Returns the SQL token for this comparison.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Eq => "=",
            Self::Neq => "<>",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Lt => "<",
            Self::Lte => "<=",
        }
    }
}
54
55impl private::Sealed for Cmp {}
56
57impl BinaryOperator for Cmp {
58    fn push_operator<Sql, Arg>(&self, sql: &mut Sql)
59    where
60        Sql: WriteSql<Arg>,
61    {
62        sql.push_cmd(self.as_str())
63    }
64}
65
/// Logical operators that join two conditions.
#[derive(Clone, Copy, PartialEq, Eq)]
#[cfg_attr(any(feature = "fmt", test, debug_assertions), derive(Debug))]
pub enum LogicBi {
    /// Conjunction, written as `AND`.
    And,
    /// Disjunction, written as `OR`.
    Or,
}

impl LogicBi {
    /// Returns the SQL keyword for this operator.
    pub const fn as_str(&self) -> &'static str {
        match self {
            LogicBi::And => "AND",
            LogicBi::Or => "OR",
        }
    }
}
82
83impl private::Sealed for LogicBi {}
84
85impl BinaryOperator for LogicBi {
86    fn push_operator<Sql, Arg>(&self, sql: &mut Sql)
87    where
88        Sql: WriteSql<Arg>,
89    {
90        sql.push_cmd(self.as_str())
91    }
92}
93
/// Logical operators that apply to a single condition.
#[derive(Clone, Copy, PartialEq, Eq)]
#[cfg_attr(any(feature = "fmt", test, debug_assertions), derive(Debug))]
pub enum LogicUn {
    /// Negation, written as `NOT`.
    Not,
}

impl LogicUn {
    /// Returns the SQL keyword for this operator.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Not => "NOT",
        }
    }
}
108
109impl private::Sealed for LogicUn {}
110
111impl UnaryOperator for LogicUn {
112    fn push_operator<Sql, Arg>(&self, sql: &mut Sql)
113    where
114        Sql: WriteSql<Arg>,
115    {
116        sql.push_cmd(self.as_str())
117    }
118}
119
120pub fn continue_condition<Sql, Arg>(sql: &mut Sql, op: LogicBi)
121where
122    Sql: WriteSql<Arg>,
123{
124    if sql.as_command().is_empty() {
125        return;
126    }
127
128    let end = sql.as_command().trim_end_matches(' ');
129    // WHERE | ON | HAVING | ( <open group>
130    if end.ends_with("WHERE")
131        || end.ends_with("ON")
132        || end.ends_with("HAVING")
133        || end.ends_with('(')
134        || sql.as_command().is_empty()
135    {
136        return;
137    }
138
139    separator_optional(sql);
140    sql.push_cmd(op.as_str());
141}
142
143pub fn lhs_binary_rhs<Sql, Arg, BOp, Lhs, Rhs>(
144    sql: &mut Sql,
145    lhs: SqlExpr<Lhs>,
146    op: BOp,
147    rhs: SqlExpr<Rhs>,
148) -> Result<(), <Arg as ArgumentBuffer<Lhs>>::Error>
149where
150    Sql: WriteSql<Arg>,
151    BOp: BinaryOperator,
152    Arg: ArgumentBuffer<Lhs>,
153    Arg: ArgumentBuffer<Rhs, Error = <Arg as ArgumentBuffer<Lhs>>::Error>,
154{
155    separator_optional(sql);
156
157    sql.push_expr(lhs)?;
158    separator(sql);
159    op.push_operator(sql);
160    separator(sql);
161    sql.push_expr(rhs)
162}
163
164pub fn binary_rhs<Sql, Arg, BOp, Rhs>(
165    sql: &mut Sql,
166    op: BOp,
167    rhs: SqlExpr<Rhs>,
168) -> Result<(), Arg::Error>
169where
170    Sql: WriteSql<Arg>,
171    BOp: BinaryOperator,
172    Arg: ArgumentBuffer<Rhs>,
173{
174    separator_optional(sql);
175
176    op.push_operator(sql);
177    separator(sql);
178    sql.push_expr(rhs)
179}
180
181pub fn unary_rhs<Sql, Arg, UOp, Rhs>(
182    sql: &mut Sql,
183    op: UOp,
184    rhs: SqlExpr<Rhs>,
185) -> Result<(), Arg::Error>
186where
187    Sql: WriteSql<Arg>,
188    UOp: UnaryOperator,
189    Arg: ArgumentBuffer<Rhs>,
190{
191    separator_optional(sql);
192
193    op.push_operator(sql);
194    separator(sql);
195    sql.push_expr(rhs)
196}
197
#[cfg(test)]
mod test {
    use super::{Cmp, LogicBi};
    use crate::{
        expr::{
            binary_rhs, continue_condition, lhs_binary_rhs, math::MathBi, unary_rhs, Group, LogicUn,
        },
        sqlexpr, sqlvalue,
        test::TestArgs,
        SqlCommand, SqlExpr,
    };

    /// Pins the SQL token emitted for every operator variant.
    #[test]
    fn operator_tokens() {
        assert_eq!(Cmp::Eq.as_str(), "=");
        assert_eq!(Cmp::Neq.as_str(), "<>");
        assert_eq!(Cmp::Gt.as_str(), ">");
        assert_eq!(Cmp::Gte.as_str(), ">=");
        assert_eq!(Cmp::Lt.as_str(), "<");
        assert_eq!(Cmp::Lte.as_str(), "<=");
        assert_eq!(LogicBi::And.as_str(), "AND");
        assert_eq!(LogicBi::Or.as_str(), "OR");
        assert_eq!(LogicUn::Not.as_str(), "NOT");
    }

    /// Continuing a condition on an empty command must not write anything.
    #[test]
    fn condition_on_empty_command() {
        let mut sql: SqlCommand<TestArgs> = SqlCommand::default();
        continue_condition(&mut sql, LogicBi::Or);

        assert_eq!(sql.as_command(), "");
        assert_eq!(sql.arguments.as_str(), "");
    }

    #[test]
    fn condition_comparison() {
        let mut sql: SqlCommand<TestArgs> = SqlCommand::default();
        continue_condition(&mut sql, LogicBi::And);
        {
            let mut group = Group::open(&mut sql);
            lhs_binary_rhs(
                &mut group,
                sqlexpr::<&str>("user.id"),
                Cmp::Eq,
                sqlvalue(32),
            )
            .unwrap();
            continue_condition(&mut group, LogicBi::And);
            lhs_binary_rhs(
                &mut group,
                sqlexpr::<&str>("access.created"),
                Cmp::Gte,
                sqlvalue(2040),
            )
            .unwrap();
        }

        assert_eq!(sql.as_command(), "(user.id = $1 AND access.created >= $2)");
        assert_eq!(sql.arguments.as_str(), "32;2040;");
    }

    #[test]
    fn condition_math() {
        let mut sql: SqlCommand<TestArgs> = SqlCommand::default();
        continue_condition(&mut sql, LogicBi::And);

        {
            let mut group = Group::open(&mut sql);
            lhs_binary_rhs(
                &mut group,
                sqlexpr::<u8>("column1"),
                MathBi::Add,
                SqlExpr::Value(30),
            )
            .unwrap();
            binary_rhs(&mut group, Cmp::Gt, sqlexpr::<u8>("column2")).unwrap();

            continue_condition(&mut group, LogicBi::And);

            lhs_binary_rhs(
                &mut group,
                sqlexpr::<u8>("column1"),
                Cmp::Eq,
                sqlexpr::<u8>("column3"),
            )
            .unwrap();
        }

        continue_condition(&mut sql, LogicBi::Or);

        unary_rhs(&mut sql, LogicUn::Not, sqlexpr::<u8>("column4")).unwrap();
        binary_rhs(&mut sql, Cmp::Lt, SqlExpr::Value(10)).unwrap();

        assert_eq!(
            sql.as_command(),
            "(column1 + $1 > column2 AND column1 = column3) OR NOT column4 < $2"
        );
        // Only the two values are bound; column expressions are inlined.
        assert_eq!(sql.arguments.as_str(), "30;10;");
    }
}