rdb_pagination_core/sql/
join.rs

1use std::{
2    error::Error,
3    fmt,
4    fmt::{Display, Formatter},
5};
6
7use crate::{ColumnName, TableColumnAttributes, TableName};
8
9/// Struct for generating the `JOIN` clause.
10#[derive(Debug, Clone, Eq, PartialEq)]
11pub struct SqlJoin {
12    pub other_table_name:  TableName,
13    pub other_column_name: ColumnName,
14    pub real_table_name:   Option<TableName>,
15    pub using_table_name:  TableName,
16    pub using_column_name: ColumnName,
17}
18
19impl SqlJoin {
20    #[doc(hidden)]
21    #[inline]
22    pub fn from_table_column_attributes(table_column_attributes: &TableColumnAttributes) -> Self {
23        Self {
24            other_table_name:  table_column_attributes.table_name.clone(),
25            other_column_name: table_column_attributes.column_name.clone(),
26            real_table_name:   table_column_attributes.real_table_name.clone(),
27            using_table_name:  table_column_attributes.foreign_table_name.clone(),
28            using_column_name: table_column_attributes.foreign_column_name.clone(),
29        }
30    }
31}
32
#[cfg(any(feature = "mysql", feature = "sqlite"))]
impl SqlJoin {
    /// Append this join to `s` as a backtick-quoted `LEFT JOIN` clause and return the text that
    /// was appended.
    ///
    /// When `real_table_name` is set the clause is `LEFT JOIN <real> AS <other> ON ...`,
    /// otherwise `LEFT JOIN <other> ON ...`. Names are interpolated through their `Display`
    /// implementations; no additional escaping is applied here.
    ///
    /// # Panics
    ///
    /// Panics if formatting one of the names reports an error.
    fn to_sql_join_clause<'a>(&self, s: &'a mut String) -> &'a str {
        use std::fmt::Write;

        // Everything from this offset on is what this call appends.
        let start = s.len();

        let Self {
            other_table_name,
            other_column_name,
            real_table_name,
            using_table_name,
            using_column_name,
        } = self;

        let written = match real_table_name {
            Some(real_table_name) => write!(
                s,
                "LEFT JOIN `{real_table_name}` AS `{other_table_name}` ON \
                 `{other_table_name}`.`{other_column_name}` = \
                 `{using_table_name}`.`{using_column_name}`"
            ),
            None => write!(
                s,
                "LEFT JOIN `{other_table_name}` ON `{other_table_name}`.`{other_column_name}` = \
                 `{using_table_name}`.`{using_column_name}`"
            ),
        };

        written.expect("appending to a `String` fails only if a `Display` impl reports an error");

        // `start` was the full length of `s` before appending, so it is a valid char boundary
        // and safe slicing is enough; no `unsafe` is required.
        &s[start..]
    }

    /// Append the clauses of all `joins` to `s`, separated by `\n`, and return the text that was
    /// appended (an empty string when `joins` is empty).
    fn format_sql_join_clauses<'a>(joins: &[SqlJoin], s: &'a mut String) -> &'a str {
        let start = s.len();

        for (index, join) in joins.iter().enumerate() {
            // Write the separator before every clause except the first, so nothing has to be
            // trimmed afterwards.
            if index > 0 {
                s.push('\n');
            }

            join.to_sql_join_clause(s);
        }

        &s[start..]
    }
}
89
#[cfg(feature = "mysql")]
impl SqlJoin {
    /// Append the MySQL `LEFT JOIN` clause for this join to `s` and return the appended text.
    ///
    /// With a `real_table_name`, the joined table is aliased:
    ///
    /// ```sql
    /// LEFT JOIN `<real_table_name>` AS `<other_table_name>` ON `<other_table_name>`.`<other_column_name>` = `<using_table_name>`.`<using_column_name>`
    /// ```
    ///
    /// Without it:
    ///
    /// ```sql
    /// LEFT JOIN `<other_table_name>` ON `<other_table_name>`.`<other_column_name>` = `<using_table_name>`.`<using_column_name>`
    /// ```
    #[inline]
    pub fn to_mysql_join_clause<'a>(&self, s: &'a mut String) -> &'a str {
        Self::to_sql_join_clause(self, s)
    }

    /// Append the MySQL `LEFT JOIN` clauses of all `joins` to `s` and return the appended text.
    ///
    /// The clauses are separated by `\n`; an empty `joins` yields an empty string.
    #[inline]
    pub fn format_mysql_join_clauses<'a>(joins: &[SqlJoin], s: &'a mut String) -> &'a str {
        SqlJoin::format_sql_join_clauses(joins, s)
    }
}
118
#[cfg(feature = "sqlite")]
impl SqlJoin {
    /// Append the SQLite `LEFT JOIN` clause for this join to `s` and return the appended text.
    ///
    /// With a `real_table_name`, the joined table is aliased:
    ///
    /// ```sql
    /// LEFT JOIN `<real_table_name>` AS `<other_table_name>` ON `<other_table_name>`.`<other_column_name>` = `<using_table_name>`.`<using_column_name>`
    /// ```
    ///
    /// Without it:
    ///
    /// ```sql
    /// LEFT JOIN `<other_table_name>` ON `<other_table_name>`.`<other_column_name>` = `<using_table_name>`.`<using_column_name>`
    /// ```
    #[inline]
    pub fn to_sqlite_join_clause<'a>(&self, s: &'a mut String) -> &'a str {
        Self::to_sql_join_clause(self, s)
    }

    /// Append the SQLite `LEFT JOIN` clauses of all `joins` to `s` and return the appended text.
    ///
    /// The clauses are separated by `\n`; an empty `joins` yields an empty string.
    #[inline]
    pub fn format_sqlite_join_clauses<'a>(joins: &[SqlJoin], s: &'a mut String) -> &'a str {
        SqlJoin::format_sql_join_clauses(joins, s)
    }
}
147
#[cfg(any(feature = "mssql", feature = "mssql2008"))]
impl SqlJoin {
    /// Append this join to `s` as a bracket-quoted `LEFT JOIN` clause and return the text that
    /// was appended.
    ///
    /// When `real_table_name` is set the clause is `LEFT JOIN <real> AS <other> ON ...`,
    /// otherwise `LEFT JOIN <other> ON ...`. Names are interpolated through their `Display`
    /// implementations; no additional escaping is applied here.
    ///
    /// # Panics
    ///
    /// Panics if formatting one of the names reports an error.
    fn to_sql_join_clause_ms<'a>(&self, s: &'a mut String) -> &'a str {
        use std::fmt::Write;

        // Everything from this offset on is what this call appends.
        let start = s.len();

        let Self {
            other_table_name,
            other_column_name,
            real_table_name,
            using_table_name,
            using_column_name,
        } = self;

        let written = match real_table_name {
            Some(real_table_name) => write!(
                s,
                "LEFT JOIN [{real_table_name}] AS [{other_table_name}] ON \
                 [{other_table_name}].[{other_column_name}] = \
                 [{using_table_name}].[{using_column_name}]"
            ),
            None => write!(
                s,
                "LEFT JOIN [{other_table_name}] ON [{other_table_name}].[{other_column_name}] = \
                 [{using_table_name}].[{using_column_name}]"
            ),
        };

        written.expect("appending to a `String` fails only if a `Display` impl reports an error");

        // `start` was the full length of `s` before appending, so it is a valid char boundary
        // and safe slicing is enough; no `unsafe` is required.
        &s[start..]
    }

    /// Append the clauses of all `joins` to `s`, separated by `\n`, and return the text that was
    /// appended (an empty string when `joins` is empty).
    fn format_sql_join_clauses_ms<'a>(joins: &[SqlJoin], s: &'a mut String) -> &'a str {
        let start = s.len();

        for (index, join) in joins.iter().enumerate() {
            // Write the separator before every clause except the first, so nothing has to be
            // trimmed afterwards.
            if index > 0 {
                s.push('\n');
            }

            join.to_sql_join_clause_ms(s);
        }

        &s[start..]
    }
}
204
#[cfg(any(feature = "mssql", feature = "mssql2008"))]
impl SqlJoin {
    /// Append the Microsoft SQL Server `LEFT JOIN` clause for this join to `s` and return the
    /// appended text.
    ///
    /// With a `real_table_name`, the joined table is aliased:
    ///
    /// ```sql
    /// LEFT JOIN [<real_table_name>] AS [<other_table_name>] ON [<other_table_name>].[<other_column_name>] = [<using_table_name>].[<using_column_name>]
    /// ```
    ///
    /// Without it:
    ///
    /// ```sql
    /// LEFT JOIN [<other_table_name>] ON [<other_table_name>].[<other_column_name>] = [<using_table_name>].[<using_column_name>]
    /// ```
    #[inline]
    pub fn to_mssql_join_clause<'a>(&self, s: &'a mut String) -> &'a str {
        Self::to_sql_join_clause_ms(self, s)
    }

    /// Append the Microsoft SQL Server `LEFT JOIN` clauses of all `joins` to `s` and return the
    /// appended text.
    ///
    /// The clauses are separated by `\n`; an empty `joins` yields an empty string.
    #[inline]
    pub fn format_mssql_join_clauses<'a>(joins: &[SqlJoin], s: &'a mut String) -> &'a str {
        SqlJoin::format_sql_join_clauses_ms(joins, s)
    }
}
233
234/// Operators for `SqlJoin`s.
235pub trait SqlJoinsOps {
236    /// Insert a `SqlJoin` if it does not exist. Return `Ok(true)` if a new `SqlJoin` has been pushed.
237    fn add_join(&mut self, join: SqlJoin) -> Result<bool, SqlJoinsInsertError>;
238}
239
240impl SqlJoinsOps for Vec<SqlJoin> {
241    #[inline]
242    fn add_join(&mut self, join: SqlJoin) -> Result<bool, SqlJoinsInsertError> {
243        if let Some(existing_join) = self
244            .iter()
245            .find(|existing_join| existing_join.other_table_name == join.other_table_name)
246        {
247            if existing_join.other_column_name != join.other_column_name
248                || existing_join.real_table_name != join.real_table_name
249                || existing_join.using_table_name != join.using_table_name
250                || existing_join.using_column_name != join.using_column_name
251            {
252                Err(SqlJoinsInsertError::OtherTableNameConflict)
253            } else {
254                Ok(false)
255            }
256        } else {
257            self.push(join);
258
259            Ok(true)
260        }
261    }
262}
263
/// Error returned when a `SqlJoin` cannot be added to a collection of joins.
///
/// Implements `PartialEq`/`Eq` so results can be compared directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum SqlJoinsInsertError {
    /// A join with the same `other_table_name` already exists, but the rest of its clause
    /// differs from the one being added.
    OtherTableNameConflict,
}

impl Display for SqlJoinsInsertError {
    /// Write the human-readable description of the error.
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Exhaustive on purpose: a new variant must get its own message.
        let message = match self {
            Self::OtherTableNameConflict => {
                "other_table_name exists but the join clauses are not exactly the same"
            },
        };

        f.write_str(message)
    }
}

impl Error for SqlJoinsInsertError {}