1use std::fmt;
2use std::str;
3
4use column::Column;
5use condition::ConditionExpression;
6use nom::branch::alt;
7use nom::bytes::complete::tag_no_case;
8use nom::combinator::map;
9use nom::IResult;
10use select::{JoinClause, SelectStatement};
11use table::Table;
12
/// The right-hand side of a join clause: the thing being joined against.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum JoinRightSide {
    /// A single table.
    Table(Table),
    /// Several tables on the right-hand side of one join.
    /// NOTE(review): the SQL form this is parsed from is defined outside this
    /// file (possibly a parenthesized table list) — confirm against the parser.
    Tables(Vec<Table>),
    /// A nested `SELECT` statement, together with an optional alias.
    NestedSelect(Box<SelectStatement>, Option<String>),
    /// A nested join clause.
    NestedJoin(Box<JoinClause>),
}
24
25impl fmt::Display for JoinRightSide {
26 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
27 match *self {
28 JoinRightSide::Table(ref t) => write!(f, "{}", t)?,
29 JoinRightSide::NestedSelect(ref q, ref a) => {
30 write!(f, "({})", q)?;
31 if a.is_some() {
32 write!(f, " AS {}", a.as_ref().unwrap())?;
33 }
34 }
35 JoinRightSide::NestedJoin(ref jc) => write!(f, "({})", jc)?,
36 _ => unimplemented!(),
37 }
38 Ok(())
39 }
40}
41
/// The join keyword(s) introducing a join clause.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum JoinOperator {
    /// Bare `JOIN`.
    Join,
    /// `LEFT JOIN`.
    LeftJoin,
    /// `LEFT OUTER JOIN`.
    LeftOuterJoin,
    /// `INNER JOIN`.
    InnerJoin,
    /// `CROSS JOIN`.
    CrossJoin,
    /// The straight join operator; `join_operator` parses it from `straight_join`.
    StraightJoin,
}
51
52impl fmt::Display for JoinOperator {
53 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54 match *self {
55 JoinOperator::Join => write!(f, "JOIN")?,
56 JoinOperator::LeftJoin => write!(f, "LEFT JOIN")?,
57 JoinOperator::LeftOuterJoin => write!(f, "LEFT OUTER JOIN")?,
58 JoinOperator::InnerJoin => write!(f, "INNER JOIN")?,
59 JoinOperator::CrossJoin => write!(f, "CROSS JOIN")?,
60 JoinOperator::StraightJoin => write!(f, "STRAIGHT JOIN")?,
61 }
62 Ok(())
63 }
64}
65
/// The constraint deciding how rows from the two sides of a join are matched.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum JoinConstraint {
    /// An `ON <condition>` constraint.
    On(ConditionExpression),
    /// A `USING (<column>, ...)` constraint listing the shared columns.
    Using(Vec<Column>),
}
71
72impl fmt::Display for JoinConstraint {
73 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
74 match *self {
75 JoinConstraint::On(ref ce) => write!(f, "ON {}", ce)?,
76 JoinConstraint::Using(ref columns) => write!(
77 f,
78 "USING ({})",
79 columns
80 .iter()
81 .map(|c| format!("{}", c))
82 .collect::<Vec<_>>()
83 .join(", ")
84 )?,
85 }
86 Ok(())
87 }
88}
89
90pub fn join_operator(i: &[u8]) -> IResult<&[u8], JoinOperator> {
92 alt((
93 map(tag_no_case("join"), |_| JoinOperator::Join),
94 map(tag_no_case("left join"), |_| JoinOperator::LeftJoin),
95 map(tag_no_case("left outer join"), |_| {
96 JoinOperator::LeftOuterJoin
97 }),
98 map(tag_no_case("inner join"), |_| JoinOperator::InnerJoin),
99 map(tag_no_case("cross join"), |_| JoinOperator::CrossJoin),
100 map(tag_no_case("straight_join"), |_| JoinOperator::StraightJoin),
101 ))(i)
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use common::{FieldDefinitionExpression, Operator};
    use condition::ConditionBase::*;
    use condition::ConditionExpression::{self, *};
    use condition::ConditionTree;
    use select::{selection, JoinClause, SelectStatement};

    /// An `INNER JOIN ... ON` query parses into the expected AST and formats
    /// back to exactly the input text.
    #[test]
    fn inner_join() {
        let sql = "SELECT tags.* FROM tags \
                   INNER JOIN taggings ON tags.id = taggings.tag_id";

        // Expected ON condition: `tags.id = taggings.tag_id`.
        let on_condition = ConditionExpression::ComparisonOp(ConditionTree {
            left: Box::new(Base(Field(Column::from("tags.id")))),
            right: Box::new(Base(Field(Column::from("taggings.tag_id")))),
            operator: Operator::Equal,
        });

        let expected = SelectStatement {
            tables: vec![Table::from("tags")],
            fields: vec![FieldDefinitionExpression::AllInTable("tags".into())],
            join: vec![JoinClause {
                operator: JoinOperator::InnerJoin,
                right: JoinRightSide::Table(Table::from("taggings")),
                constraint: JoinConstraint::On(on_condition),
            }],
            ..Default::default()
        };

        let (_, parsed) = selection(sql.as_bytes()).unwrap();
        assert_eq!(parsed, expected);
        // Round-trip: formatting the parsed statement must reproduce the input.
        assert_eq!(sql, format!("{}", parsed));
    }
}