rdb_pagination_core/sql/join.rs

1use std::{
2    error::Error,
3    fmt,
4    fmt::{Display, Formatter},
5};
6
7use crate::{ColumnName, TableColumnAttributes, TableName};
8
9/// Struct for generating the `JOIN` clause.
10#[derive(Debug, Clone, Eq, PartialEq)]
11pub struct SqlJoin {
12    pub other_table_name:  TableName,
13    pub other_column_name: ColumnName,
14    pub real_table_name:   Option<TableName>,
15    pub using_table_name:  TableName,
16    pub using_column_name: ColumnName,
17}
18
19impl SqlJoin {
20    #[doc(hidden)]
21    #[inline]
22    pub fn from_table_column_attributes(table_column_attributes: &TableColumnAttributes) -> Self {
23        Self {
24            other_table_name:  table_column_attributes.table_name.clone(),
25            other_column_name: table_column_attributes.column_name.clone(),
26            real_table_name:   table_column_attributes.real_table_name.clone(),
27            using_table_name:  table_column_attributes.foreign_table_name.clone(),
28            using_column_name: table_column_attributes.foreign_column_name.clone(),
29        }
30    }
31}
32
#[cfg(any(feature = "mysql", feature = "sqlite"))]
impl SqlJoin {
    /// Append one `LEFT JOIN` clause to `s` and return the text that was appended.
    ///
    /// Identifiers are quoted with backticks. The existing contents of `s` are kept.
    /// With `real_table_name` set, the output is
    ///
    /// ```sql
    /// LEFT JOIN `<real_table_name>` AS `<other_table_name>` ON `<other_table_name>`.`<other_column_name>` = `<using_table_name>`.`<using_column_name>`
    /// ```
    ///
    /// otherwise
    ///
    /// ```sql
    /// LEFT JOIN `<other_table_name>` ON `<other_table_name>`.`<other_column_name>` = `<using_table_name>`.`<using_column_name>`
    /// ```
    ///
    /// NOTE(review): names are written verbatim between backticks. A name that itself contains
    /// a backtick would break the quoting unless `TableName`/`ColumnName` reject such names.
    /// Confirm.
    fn to_sql_join_clause<'a>(&self, s: &'a mut String) -> &'a str {
        use std::fmt::Write;

        let start = s.len();

        s.push_str("LEFT JOIN ");

        // When a real table name is present, `other_table_name` is only an alias for it.
        if let Some(real_table_name) = &self.real_table_name {
            write!(s, "`{real_table_name}` AS ")
                .expect("formatting a table name into a String should not fail");
        }

        write!(
            s,
            "`{other_table_name}` ON `{other_table_name}`.`{other_column_name}` = \
             `{using_table_name}`.`{using_column_name}`",
            other_table_name = self.other_table_name,
            other_column_name = self.other_column_name,
            using_table_name = self.using_table_name,
            using_column_name = self.using_column_name,
        )
        .expect("formatting identifiers into a String should not fail");

        // `start` is the length `s` had before appending, so it is a char boundary and this
        // slice cannot panic; no unchecked UTF-8 conversion is needed.
        &s[start..]
    }

    /// Append one `LEFT JOIN` clause per element of `joins` to `s` and return the appended text.
    ///
    /// Clauses are separated by `\n`, with no trailing newline. When `joins` is empty, nothing
    /// is appended and `""` is returned.
    fn format_sql_join_clauses<'a>(joins: &[SqlJoin], s: &'a mut String) -> &'a str {
        let start = s.len();

        for (index, join) in joins.iter().enumerate() {
            // Write the separator before every clause except the first, so no trailing
            // newline ever has to be stripped.
            if index > 0 {
                s.push('\n');
            }

            join.to_sql_join_clause(s);
        }

        &s[start..]
    }
}
89
#[cfg(feature = "mysql")]
impl SqlJoin {
    /// Generate a `LEFT JOIN` clause for MySQL, append it to `s`, and return the appended text.
    ///
    /// The existing contents of `s` are kept. If `real_table_name` exists,
    ///
    /// ```sql
    /// LEFT JOIN `<real_table_name>` AS `<other_table_name>` ON `<other_table_name>`.`<other_column_name>` = `<using_table_name>`.`<using_column_name>`
    /// ```
    ///
    /// otherwise,
    ///
    /// ```sql
    /// LEFT JOIN `<other_table_name>` ON `<other_table_name>`.`<other_column_name>` = `<using_table_name>`.`<using_column_name>`
    /// ```
    #[inline]
    pub fn to_mysql_join_clause<'a>(&self, s: &'a mut String) -> &'a str {
        self.to_sql_join_clause(s)
    }

    /// Generate `LEFT JOIN` clauses for MySQL, append them to `s`, and return the appended text.
    ///
    /// The clauses are separated by `\n`, with no trailing newline. When `joins` is empty,
    /// nothing is appended and `""` is returned.
    #[inline]
    pub fn format_mysql_join_clauses<'a>(joins: &[SqlJoin], s: &'a mut String) -> &'a str {
        Self::format_sql_join_clauses(joins, s)
    }
}
118
#[cfg(feature = "sqlite")]
impl SqlJoin {
    /// Generate a `LEFT JOIN` clause for SQLite, append it to `s`, and return the appended text.
    ///
    /// The existing contents of `s` are kept. If `real_table_name` exists,
    ///
    /// ```sql
    /// LEFT JOIN `<real_table_name>` AS `<other_table_name>` ON `<other_table_name>`.`<other_column_name>` = `<using_table_name>`.`<using_column_name>`
    /// ```
    ///
    /// otherwise,
    ///
    /// ```sql
    /// LEFT JOIN `<other_table_name>` ON `<other_table_name>`.`<other_column_name>` = `<using_table_name>`.`<using_column_name>`
    /// ```
    #[inline]
    pub fn to_sqlite_join_clause<'a>(&self, s: &'a mut String) -> &'a str {
        self.to_sql_join_clause(s)
    }

    /// Generate `LEFT JOIN` clauses for SQLite, append them to `s`, and return the appended text.
    ///
    /// The clauses are separated by `\n`, with no trailing newline. When `joins` is empty,
    /// nothing is appended and `""` is returned.
    #[inline]
    pub fn format_sqlite_join_clauses<'a>(joins: &[SqlJoin], s: &'a mut String) -> &'a str {
        Self::format_sql_join_clauses(joins, s)
    }
}
147
148/// Operators for `SqlJoin`s.
149pub trait SqlJoinsOps {
150    /// Insert a `SqlJoin` if it does not exist. Return `Ok(true)` if a new `SqlJoin` has been pushed.
151    fn add_join(&mut self, join: SqlJoin) -> Result<bool, SqlJoinsInsertError>;
152}
153
154impl SqlJoinsOps for Vec<SqlJoin> {
155    #[inline]
156    fn add_join(&mut self, join: SqlJoin) -> Result<bool, SqlJoinsInsertError> {
157        if let Some(existing_join) = self
158            .iter()
159            .find(|existing_join| existing_join.other_table_name == join.other_table_name)
160        {
161            if existing_join.other_column_name != join.other_column_name
162                || existing_join.real_table_name != join.real_table_name
163                || existing_join.using_table_name != join.using_table_name
164                || existing_join.using_column_name != join.using_column_name
165            {
166                Err(SqlJoinsInsertError::OtherTableNameConflict)
167            } else {
168                Ok(false)
169            }
170        } else {
171            self.push(join);
172
173            Ok(true)
174        }
175    }
176}
177
/// Error returned when inserting a `SqlJoin` into a collection of joins fails.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum SqlJoinsInsertError {
    /// A join with the same `other_table_name` already exists, but at least one of its other
    /// fields differs from the join being inserted.
    OtherTableNameConflict,
}

impl Display for SqlJoinsInsertError {
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::OtherTableNameConflict => {
                f.write_str("other_table_name exists but the join clauses are not exactly the same")
            },
        }
    }
}

impl Error for SqlJoinsInsertError {}