nom_sql/
join.rs

use std::fmt;
use std::str;

use column::Column;
use condition::ConditionExpression;
use nom::branch::alt;
use nom::bytes::complete::tag_no_case;
use nom::character::complete::multispace1;
use nom::combinator::map;
use nom::sequence::tuple;
use nom::IResult;
use select::{JoinClause, SelectStatement};
use table::Table;
12
13#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
14pub enum JoinRightSide {
15    /// A single table.
16    Table(Table),
17    /// A comma-separated (and implicitly joined) sequence of tables.
18    Tables(Vec<Table>),
19    /// A nested selection, represented as (query, alias).
20    NestedSelect(Box<SelectStatement>, Option<String>),
21    /// A nested join clause.
22    NestedJoin(Box<JoinClause>),
23}
24
25impl fmt::Display for JoinRightSide {
26    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
27        match *self {
28            JoinRightSide::Table(ref t) => write!(f, "{}", t)?,
29            JoinRightSide::NestedSelect(ref q, ref a) => {
30                write!(f, "({})", q)?;
31                if a.is_some() {
32                    write!(f, " AS {}", a.as_ref().unwrap())?;
33                }
34            }
35            JoinRightSide::NestedJoin(ref jc) => write!(f, "({})", jc)?,
36            _ => unimplemented!(),
37        }
38        Ok(())
39    }
40}
41
42#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
43pub enum JoinOperator {
44    Join,
45    LeftJoin,
46    LeftOuterJoin,
47    InnerJoin,
48    CrossJoin,
49    StraightJoin,
50}
51
52impl fmt::Display for JoinOperator {
53    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54        match *self {
55            JoinOperator::Join => write!(f, "JOIN")?,
56            JoinOperator::LeftJoin => write!(f, "LEFT JOIN")?,
57            JoinOperator::LeftOuterJoin => write!(f, "LEFT OUTER JOIN")?,
58            JoinOperator::InnerJoin => write!(f, "INNER JOIN")?,
59            JoinOperator::CrossJoin => write!(f, "CROSS JOIN")?,
60            JoinOperator::StraightJoin => write!(f, "STRAIGHT JOIN")?,
61        }
62        Ok(())
63    }
64}
65
66#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
67pub enum JoinConstraint {
68    On(ConditionExpression),
69    Using(Vec<Column>),
70}
71
72impl fmt::Display for JoinConstraint {
73    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
74        match *self {
75            JoinConstraint::On(ref ce) => write!(f, "ON {}", ce)?,
76            JoinConstraint::Using(ref columns) => write!(
77                f,
78                "USING ({})",
79                columns
80                    .iter()
81                    .map(|c| format!("{}", c))
82                    .collect::<Vec<_>>()
83                    .join(", ")
84            )?,
85        }
86        Ok(())
87    }
88}
89
90// Parse binary comparison operators
91pub fn join_operator(i: &[u8]) -> IResult<&[u8], JoinOperator> {
92    alt((
93        map(tag_no_case("join"), |_| JoinOperator::Join),
94        map(tag_no_case("left join"), |_| JoinOperator::LeftJoin),
95        map(tag_no_case("left outer join"), |_| {
96            JoinOperator::LeftOuterJoin
97        }),
98        map(tag_no_case("inner join"), |_| JoinOperator::InnerJoin),
99        map(tag_no_case("cross join"), |_| JoinOperator::CrossJoin),
100        map(tag_no_case("straight_join"), |_| JoinOperator::StraightJoin),
101    ))(i)
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use common::{FieldDefinitionExpression, Operator};
    use condition::ConditionBase::*;
    use condition::ConditionExpression::{self, *};
    use condition::ConditionTree;
    use select::{selection, JoinClause, SelectStatement};

    /// Parses an `INNER JOIN ... ON` query, checks the resulting AST, and checks
    /// that printing the AST reproduces the original query text exactly.
    #[test]
    fn inner_join() {
        let sql = "SELECT tags.* FROM tags INNER JOIN taggings ON tags.id = taggings.tag_id";

        let on_condition = ConditionExpression::ComparisonOp(ConditionTree {
            operator: Operator::Equal,
            left: Box::new(Base(Field(Column::from("tags.id")))),
            right: Box::new(Base(Field(Column::from("taggings.tag_id")))),
        });

        let expected = SelectStatement {
            tables: vec![Table::from("tags")],
            fields: vec![FieldDefinitionExpression::AllInTable("tags".into())],
            join: vec![JoinClause {
                operator: JoinOperator::InnerJoin,
                right: JoinRightSide::Table(Table::from("taggings")),
                constraint: JoinConstraint::On(on_condition),
            }],
            ..Default::default()
        };

        let (_, parsed) = selection(sql.as_bytes()).unwrap();
        assert_eq!(parsed, expected);
        assert_eq!(sql, parsed.to_string());
    }
}