use sqlparser::{ast::*, dialect::Dialect, parser::Parser};
#[derive(Debug)]
pub struct ObserverPostgresSqlDialect {}
impl Dialect for ObserverPostgresSqlDialect {
fn is_identifier_start(&self, ch: char) -> bool {
('a'..='z').contains(&ch)
|| ('A'..='Z').contains(&ch)
|| (ch == '@')
|| ch == '$'
|| ch == '_'
}
fn is_identifier_part(&self, ch: char) -> bool {
('a'..='z').contains(&ch)
|| ('A'..='Z').contains(&ch)
|| ('0'..='9').contains(&ch)
|| (ch == '@')
|| ch == '$'
|| ch == '_'
}
}
#[allow(dead_code)]
fn split_query_by_where(query: &str) -> String {
let query = query.to_lowercase();
let sql: Vec<&str> = query.split("where").collect::<Vec<&str>>();
(*sql.first().unwrap_or(&query.as_str())).to_string()
}
#[allow(dead_code)]
pub fn parse_sql(sql: &str) -> (String, String) {
match Parser::parse_sql(
&ObserverPostgresSqlDialect {},
split_query_by_where(sql).as_str(),
) {
Ok(ast) =>
{
#[allow(clippy::never_loop)]
for x in ast {
match x {
Statement::Query(query) => {
match query.body {
SetExpr::Select(select) => {
let mut table_name = vec![];
for x in select.from {
table_name.push(x.relation.to_string());
for join in x.joins {
table_name.push(join.relation.to_string());
}
}
return ("select".to_string(), table_name.join("__"));
}
_ => return ("unknown4".to_string(), "unknown4".to_string()),
};
}
Statement::Update { table_name, .. } => {
return ("update".to_string(), table_name.to_string());
}
Statement::Insert { table_name, .. } => {
return ("insert".to_string(), table_name.to_string());
}
Statement::Copy { table_name, .. } => {
return ("copy".to_string(), table_name.to_string());
}
Statement::Delete { table_name, .. } => {
return ("delete".to_string(), table_name.to_string());
}
Statement::CreateView { name, .. } => {
return ("create_view".to_string(), name.to_string());
}
Statement::CreateTable { name, .. } => {
return ("create_table".to_string(), name.to_string());
}
Statement::AlterTable { name, .. } => {
return ("alter".to_string(), name.to_string());
}
Statement::Drop { names, .. } => {
return (
"drop".to_string(),
names
.iter()
.map(|x| x.to_string())
.collect::<Vec<String>>()
.join("__"),
);
}
_ => {
return ("unknown3".to_string(), "unknown3".to_string());
}
}
}
}
Err(_err) => {
return ("unknown".to_string(), "unknown".to_string());
}
};
("unknown2".to_string(), "unknown2".to_string())
}
#[cfg(test)]
mod tests {
use super::parse_sql;
#[test]
fn parse_test() {
assert_eq!(
parse_sql("select abc from employee, abc1 where name = asgief"),
("select".to_string(), "employee, abc1".to_string())
);
assert_eq!(
parse_sql("select * from supplier join orders on supplier.id=orders.id;"),
("select".to_string(), "supplier, orders".to_string())
);
assert_eq!(
parse_sql(
r#"
SELECT customer.customer_id FROM customer
INNER JOIN payment ON payment.customer_id = customer.customer_id
INNER JOIN payment1 ON payment1.customer_id = customer.customer_id;
"#
),
(
"select".to_string(),
"customer, payment, payment1".to_string()
)
);
assert_eq!(
parse_sql("update employee set name = asgief"),
("update".to_string(), "employee".to_string())
);
assert_eq!(
parse_sql("insert into employee(id, name) values(1, 23)"),
("insert".to_string(), "employee".to_string())
);
assert_eq!(
parse_sql("delete from employee where name = asgief"),
("delete".to_string(), "employee".to_string())
);
assert_eq!(
parse_sql(
r#"
CREATE TABLE account(
user_id serial PRIMARY KEY,
username VARCHAR (50) UNIQUE NOT NULL,
password VARCHAR (50) NOT NULL,
email VARCHAR (355) UNIQUE NOT NULL,
created_on TIMESTAMP NOT NULL,
last_login TIMESTAMP);
"#
),
("create table".to_string(), "account".to_string())
);
assert_eq!(
parse_sql("drop table employee, employee1;"),
("drop".to_string(), "employee, employee1".to_string())
);
assert_eq!(
parse_sql(
r#" SELECT "users_skill"."id", "users_skill"."name", "users_skill"."description",
"users_skill"."allocation_logic" FROM "users_skill" WHERE "users_skill"."id" > $1"#
),
("select".to_string(), "\"users_skill\"".to_string())
);
assert_eq!(
parse_sql(
r#"
SELECT "ackore_policy"."id", "ackore_policy"."data", "ackore_policy"."created_on",
"ackore_policy"."updated_on", "ackore_policy"."plan_id", "ackore_policy"."user_id",
"ackore_policy"."output", "ackore_policy"."sort_on", "ackore_policy"."payment_id",
"ackore_policy"."insurance_data", "ackore_policy"."intermediary_id",
"ackore_policy"."policy_number", "ackore_policy"."refund_id"
FROM "ackore_policy" WHERE "ackore_policy"."id" = $1 LIMIT $2 -- binds: [143343871, 1]
"#
),
("select".to_string(), "\"ackore_policy\"".to_string())
);
assert_eq!(
parse_sql(
r#"
SELECT customer.customer_id FROM customer
INNER JOIN payment ON payment.customer_id = customer.customer_id
INNER JOIN payment1 ON payment1.customer_id = customer.customer_id;
"#
),
(
"select".to_string(),
"customer, payment, payment1".to_string()
)
);
}
#[test]
fn split_sql_test1() {
let sql = r#"
SELECT "ackore_policy"."id", "ackore_policy"."data", "ackore_policy"."created_on",
"ackore_policy"."updated_on", "ackore_policy"."plan_id", "ackore_policy"."user_id",
"ackore_policy"."output", "ackore_policy"."sort_on", "ackore_policy"."payment_id",
"ackore_policy"."insurance_data", "ackore_policy"."intermediary_id",
"ackore_policy"."policy_number", "ackore_policy"."refund_id"
FROM "ackore_policy" WHERE "ackore_policy"."id" = $1 LIMIT $2 -- binds: [143343871, 1]
"#;
assert_eq!(
r#"
select "ackore_policy"."id", "ackore_policy"."data", "ackore_policy"."created_on",
"ackore_policy"."updated_on", "ackore_policy"."plan_id", "ackore_policy"."user_id",
"ackore_policy"."output", "ackore_policy"."sort_on", "ackore_policy"."payment_id",
"ackore_policy"."insurance_data", "ackore_policy"."intermediary_id",
"ackore_policy"."policy_number", "ackore_policy"."refund_id"
from "ackore_policy" "#,
super::split_query_by_where(sql)
);
}
#[test]
fn split_sql_test2() {
let sql = r#"
SELECT customer.customer_id FROM customer
INNER JOIN payment ON payment.customer_id = customer.customer_id
INNER JOIN payment1 ON payment1.customer_id = customer.customer_id;
"#;
assert_eq!(sql.to_lowercase(), super::split_query_by_where(sql));
}
}