1mod definition_lookup;
4use definition_lookup::{DefinitionLookup, DefinitionLookupError};
5
6use crate::constants::Nullable;
7use crate::engine::objects::{JoinType, SqlTuple};
8
9use super::io::VisibleRowManager;
10use super::objects::types::{BaseSqlTypes, BaseSqlTypesError, SqlTypeDefinition};
11use super::objects::{
12 Attribute, CommandType, ParseExpression, ParseTree, QueryTree, RangeRelation,
13 RangeRelationTable, RawInsertCommand, RawSelectCommand, Table,
14};
15use super::transactions::TransactionId;
16use std::collections::HashMap;
17use std::sync::Arc;
18use thiserror::Error;
19
20#[derive(Clone)]
21pub struct Analyzer {
22 dl: DefinitionLookup,
23}
24
25impl Analyzer {
26 pub fn new(vis_row_man: VisibleRowManager) -> Analyzer {
27 Analyzer {
28 dl: DefinitionLookup::new(vis_row_man),
29 }
30 }
31
32 pub async fn analyze(
33 &self,
34 tran_id: TransactionId,
35 parse_tree: ParseTree,
36 ) -> Result<QueryTree, AnalyzerError> {
37 match parse_tree {
38 ParseTree::Insert(i) => self.insert_processing(tran_id, i).await,
39 ParseTree::Select(i) => self.select_processing(tran_id, i).await,
40 _ => Err(AnalyzerError::NotImplemented()),
41 }
42 }
43
44 async fn insert_processing(
45 &self,
46 tran_id: TransactionId,
47 raw_insert: RawInsertCommand,
48 ) -> Result<QueryTree, AnalyzerError> {
49 let definition = self
50 .dl
51 .get_definition(tran_id, raw_insert.table_name)
52 .await?;
53
54 let (output_type, val_cols) = Analyzer::validate_columns(
55 definition.clone(),
56 raw_insert.provided_columns,
57 raw_insert.provided_values,
58 )?;
59
60 let anon_tbl = RangeRelation::AnonymousTable(Arc::new(vec![val_cols]));
61 let target_tbl = RangeRelation::Table(RangeRelationTable {
62 alias: None,
63 table: definition,
64 });
65
66 Ok(QueryTree {
68 command_type: CommandType::Insert,
69 targets: Arc::new(output_type),
71 range_tables: vec![target_tbl.clone(), anon_tbl.clone()],
72 joins: vec![(JoinType::Inner, target_tbl, anon_tbl)],
73 })
74 }
75
76 async fn select_processing(
77 &self,
78 tran_id: TransactionId,
79 raw_select: RawSelectCommand,
80 ) -> Result<QueryTree, AnalyzerError> {
81 let definition = self.dl.get_definition(tran_id, raw_select.table).await?;
82
83 let mut targets = vec![];
85 'outer: for rcol in raw_select.columns {
86 for c in definition.attributes.as_slice() {
87 if rcol == c.name {
88 targets.push((c.name.clone(), c.sql_type.clone()));
89 continue 'outer;
90 }
91 }
92 return Err(AnalyzerError::UnknownColumn(rcol));
93 }
94
95 Ok(QueryTree {
97 command_type: CommandType::Select,
98 targets: Arc::new(SqlTypeDefinition(targets)),
99 range_tables: vec![RangeRelation::Table(RangeRelationTable {
100 table: definition,
101 alias: None,
102 })],
103 joins: vec![],
104 })
105 }
106
107 fn validate_columns(
109 table: Arc<Table>,
110 provided_columns: Option<Vec<String>>,
111 provided_values: Vec<ParseExpression>,
112 ) -> Result<(SqlTypeDefinition, SqlTuple), AnalyzerError> {
113 let columns = match provided_columns {
114 Some(pc) => {
115 let mut provided_pair: HashMap<String, ParseExpression> =
117 pc.into_iter().zip(provided_values).collect();
118 let mut result = vec![];
119 for a in table.attributes.clone() {
120 match provided_pair.get(&a.name) {
121 Some(ppv) => {
122 result.push((a.clone(), Some(ppv.clone())));
123 provided_pair.remove(&a.name);
124 }
125 None => match a.nullable {
126 Nullable::NotNull => return Err(AnalyzerError::MissingColumn(a)),
127 Nullable::Null => result.push((a, None)),
128 },
129 }
130 }
131
132 if !provided_pair.is_empty() {
133 return Err(AnalyzerError::UnknownColumns(
134 provided_pair.keys().cloned().collect(),
135 ));
136 }
137
138 result
139 }
140 None => {
141 table
143 .attributes
144 .clone()
145 .into_iter()
146 .zip(provided_values)
147 .map(|(a, s)| (a, Some(s)))
148 .collect()
149 }
150 };
151
152 Analyzer::convert_into_types(columns)
153 }
154
155 fn convert_into_types(
156 provided: Vec<(Attribute, Option<ParseExpression>)>,
157 ) -> Result<(SqlTypeDefinition, SqlTuple), AnalyzerError> {
158 let mut tbl_cols = vec![];
159 let mut val_cols = vec![];
160 for (a, s) in provided {
161 match s {
162 Some(s2) => match s2 {
163 ParseExpression::String(s3) => {
164 tbl_cols.push((a.name, a.sql_type.clone()));
165 val_cols.push(Some(BaseSqlTypes::parse(a.sql_type, &s3)?));
166 }
167 ParseExpression::Null() => {
168 tbl_cols.push((a.name, a.sql_type));
169 val_cols.push(None);
170 }
171 },
172 None => {
173 tbl_cols.push((a.name, a.sql_type));
174 val_cols.push(None);
175 }
176 }
177 }
178 Ok((SqlTypeDefinition(tbl_cols), SqlTuple(val_cols)))
179 }
180}
181
182#[derive(Debug, Error)]
183pub enum AnalyzerError {
184 #[error(transparent)]
185 DefinitionLookupError(#[from] DefinitionLookupError),
186 #[error(transparent)]
187 BaseSqlTypesError(#[from] BaseSqlTypesError),
188 #[error("Provided columns {0:?} does not match the underlying table columns {1:?}")]
189 ColumnVsColumnMismatch(Vec<String>, Vec<String>),
190 #[error("Provided value count {0} does not match the underlying table column count {1}")]
191 ValueVsColumnMismatch(usize, usize),
192 #[error("Missing required column {0}")]
193 MissingColumn(Attribute),
194 #[error("Unknown column received {0}")]
195 UnknownColumn(String),
196 #[error("Unknown columns received {0:?}")]
197 UnknownColumns(Vec<String>),
198 #[error("Not implemented")]
199 NotImplemented(),
200}