rest_api/database/
query.rs

1use std::collections::HashMap;
2use std::fmt::Display;
3
4use hyper::body::to_bytes;
5use hyper::{Request, Body, Method};
6
7use sqlite3::{Cursor, Connection};
8use sqlite3::Result as SqlResult;
9use sqlite3::Value as SqlValue;
10use sqlite3::Error as SqlError;
11
12use super::super::api_http_server::routing::split_uri_args;
13use super::table_schema::SqlTableSchema;
14
15use json::parse;
16
/// HTTP verbs the REST API maps onto SQL statements.
///
/// Variant names are upper-case because they are part of the public API.
#[allow(clippy::upper_case_acronyms)] // public variant names; renaming would break callers
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HttpMethod {
    /// Read rows (`SELECT`), filtered by the URI query arguments.
    GET,
    /// Insert a row (`INSERT`) from the JSON body's `columns` object.
    POST,
    /// Delete rows (`DELETE`), filtered by the URI query arguments.
    DELETE,
    /// Update rows (`UPDATE`) from the JSON body's `columns` and `filters`.
    PATCH,
    /// Any other HTTP method; such requests are rejected.
    INVALID
}
25
/// Error produced while turning an HTTP request into a query.
///
/// Field `0` is a human-readable description. Field `1` is `true` when the
/// failure is the server's fault rather than the client's.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryErr (
    pub String,  // description
    pub bool,  // server fault?
);

impl Display for QueryErr {
    /// Writes only the description; the server-fault flag is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.0, f)
    }
}

// Lets `QueryErr` be boxed as `dyn Error` and propagated with `?`.
impl std::error::Error for QueryErr {}
37
38// Used to convert the incoming HTTP request to a SQL statement
39#[async_trait::async_trait]
40pub trait Query<'a, T, A> {
41    async fn from_request(request: &mut Request<Body>, table: &'a SqlTableSchema) -> Result<Self, QueryErr>
42        where Self: Sized;
43    fn execute_sql(&'a self, connection: T) -> A;
44}
45
46
47
48pub struct Sqlite3Query<'a> {
49    pub method: HttpMethod,
50    pub table_schema: &'a SqlTableSchema,
51    pub fields_data: HashMap<String, String>,
52    pub filter: HashMap<String, String>,
53}
54
55#[async_trait::async_trait]
56impl<'a> Query<'a, &'a Connection, SqlResult<Cursor<'a>>> for Sqlite3Query<'a> {
57    
58    async fn from_request(request: &mut Request<Body>, table: &'a SqlTableSchema) -> Result<Self, QueryErr> {
59        let method = match request.method().clone() {
60            Method::GET => HttpMethod::GET,
61            Method::PATCH => HttpMethod::PATCH,
62            Method::DELETE => HttpMethod::DELETE,
63            Method::POST => HttpMethod::POST,
64            _ => HttpMethod::INVALID,
65        };
66
67        if method == HttpMethod::INVALID {
68            return Err(QueryErr("Invalid Method".to_string(), false))
69        }
70
71        if method == HttpMethod::GET || method == HttpMethod::DELETE {
72            // GET and DELETE are constructed from uri args
73
74            let (_, uri_args) = split_uri_args(request.uri().to_string());
75
76            let uri_args = uri_args.to_ascii_lowercase();
77
78            let mut uri_args_parsed: HashMap<String, String> = HashMap::new();
79            for arg in uri_args.split('&') {
80                let res = arg.split_once('=');
81
82                if res.is_none() {
83                    continue;
84                }
85
86                let (left, right) = res.unwrap();
87                let left = left.to_lowercase();
88
89                let right_with_space = right.replace('+', " ");
90
91                if table.field_exists(&left) {
92                    uri_args_parsed.insert(left, right_with_space.to_string());
93                }
94            }
95
96            return Ok(Self {
97                method,
98                table_schema: &table,
99                fields_data: HashMap::new(),
100                filter: uri_args_parsed,
101            })
102        }
103
104        // TODO: possible vunerability in to_bytes
105        let body_read_result = to_bytes(request.body_mut()).await;
106        if body_read_result.is_err() {
107            return Err(QueryErr("Error reading request body".to_string(), true))
108        }
109        let body = String::from_utf8(body_read_result.unwrap().into_iter().collect());
110        if body.is_err() {
111            return Err(QueryErr("Error creating string from request body bytes".to_string(), true))
112        }
113
114        let body = body.unwrap();
115        let parsed = parse(
116            &body
117        );
118        if parsed.is_err() {
119            let error = parsed.err().unwrap();
120            return Err(QueryErr(format!("Error parsing json ( {} ): {}", body, error), false))
121        }
122        
123        let mut content = parsed.unwrap();
124        let columns = content.remove("columns");
125
126        if columns.is_null() {
127            return Err(QueryErr("Error getting 'columns' from json".to_string(), false));
128        }
129        
130        let mut data_hashmap = HashMap::new();
131        if columns.is_object() {
132            for col in columns.entries() {
133                let col_as_str = col.1.as_str();
134
135                if col_as_str.is_none() {
136                    return Err(QueryErr("Columns json contains non-string".to_string(), false))
137                }
138
139                // prevent sql injection by only allowing valid field names
140                if table.field_exists(&col.0.to_lowercase()) {
141                    data_hashmap.insert(col.0.to_string(), col_as_str.unwrap().to_string());
142                }
143            }
144        } else if !columns.is_null() {
145            // null means keep empty columns hashmap, if not null, it is wrong type
146            return Err(QueryErr("'columns' in json is wrong type".to_string(), false))
147        }
148
149        let filters = content.remove("filters");
150        let mut filters_hashmap = HashMap::new();
151
152        if filters.is_object() {
153            for filter in filters.entries() {
154                let filter_val = filter.1.as_str();
155                if filter_val.is_none() {
156                    return Err(QueryErr("Filters json contains non-string".to_string(), false))
157                }
158
159                // prevent sql injection by only allowing valid field names
160                if table.field_exists(filter.0) {
161                    filters_hashmap.insert(filter.0.to_string(), filter_val.unwrap().to_string());
162                }
163            }
164        } else if !filters.is_null() {
165            // null means keep empty filters hashmap, if not null, it is wrong type
166            return Err(QueryErr("'filters' in json is wrong type".to_string(), false))
167        }
168
169        Ok(Self {
170            method,
171            table_schema: &table,
172            fields_data: data_hashmap,
173            filter: filters_hashmap
174        })
175    }
176
177    fn execute_sql(&'a self, connection: &'a Connection) -> SqlResult<Cursor<'a>> {
178        match self.method {
179            HttpMethod::GET => self.construct_get_sql(connection),
180            HttpMethod::POST => self.construct_post_sql(connection),
181            HttpMethod::DELETE => self.construct_delete_sql(connection),
182            HttpMethod::PATCH => self.construct_patch_sql(connection),
183            _  => SqlResult::Err(
184                SqlError {code: None, message: Some("Invalid method".to_string())}
185            )
186        }
187    }
188}
189
190
191
192
193impl<'a> Sqlite3Query<'a> {
194    fn construct_get_sql(&'a self, connection: &'a Connection) -> SqlResult<Cursor> {
195        let mut bindings: Vec<SqlValue> = Vec::new();
196        let mut select_builder = "SELECT *".to_string();
197
198        select_builder.push_str(&format!(" FROM {}", self.table_schema.name));
199
200        if self.filter.len() > 0 {
201            select_builder.push_str(" WHERE ");
202
203            for filter in &self.filter {
204                // fields MUST be checked to be valid for the table when constructing query object
205                // or vulnerable to SQL injection
206                select_builder.push_str( &format!("{}=? AND ", filter.0) );
207
208                bindings.push(SqlValue::String(filter.1.clone()));
209            }
210
211            select_builder.remove(select_builder.len()-1);
212            select_builder.remove(select_builder.len()-1);
213            select_builder.remove(select_builder.len()-1);
214            select_builder.remove(select_builder.len()-1);
215            select_builder.remove(select_builder.len()-1);
216        }
217
218        let statement = connection.prepare(select_builder);
219        
220        if statement.is_err() {
221            let error = statement.err().unwrap();
222            return Err(error)
223        }
224
225        let mut bound = statement.unwrap().cursor();
226        let _res = bound.bind(bindings.as_slice());
227
228        Ok(bound)
229    }
230
231    fn construct_post_sql(&'a self, connection: &'a Connection) -> SqlResult<Cursor> {
232        let mut insert_builder = "INSERT INTO ".to_string();
233        insert_builder.push_str(&self.table_schema.name.clone());
234        // null for pk autoincrement col
235        insert_builder.push_str(" VALUES (Null, ");
236
237        let mut bindings: Vec<SqlValue> = Vec::new();
238
239        if self.fields_data.len() == 0 {
240            return Err(SqlError {message: Some("No parsed data in POST body".to_string()), code: None})
241        }
242        
243        // iterate over every field and find corresponding value to insert
244        // TODO: test that fields are in correct order consistently
245        for field in &self.table_schema.fields {
246            let field_value = self.fields_data.get(field.0);
247            if field_value.is_none() {
248                return Err(SqlError {message: Some(format!("Missing field value {}", field.0)), code: None})
249            }
250            let v = field_value.unwrap();
251            insert_builder.push_str("?,");
252            bindings.push(SqlValue::String(v.clone()))
253        }
254
255        insert_builder.remove(insert_builder.len()-1);
256
257        insert_builder.push_str(")");
258
259        // execute the INSERT statement
260        {
261            let post_statement = connection.prepare(insert_builder);
262
263            if post_statement.is_err() {
264                let error = post_statement.err().unwrap();
265                return Err(error)
266            }
267            
268            let mut bound = post_statement.unwrap().cursor();
269            let _res = bound.bind(bindings.as_slice());
270
271            let success = bound.next();
272            if success.is_err() {
273                return Err(success.err().unwrap())
274            }
275        }
276
277        // return a cursor for the new values
278        let select_statement = connection.prepare(
279            format!("SELECT * FROM {} ORDER BY id DESC LIMIT 1", self.table_schema.name)
280        );
281
282        if select_statement.is_err() {
283            let error = select_statement.err().unwrap();
284            return Err(error)
285        }
286
287        Ok(select_statement.unwrap().cursor())
288    }
289    
290    fn construct_delete_sql(&'a self, connection: &'a Connection) -> SqlResult<Cursor> {
291        let mut bindings: Vec<SqlValue> = Vec::new();
292        let mut delete_builder = format!("DELETE FROM {}", self.table_schema.name);
293
294        if self.filter.len() > 0 {
295            delete_builder.push_str(" WHERE ");
296
297            for filter in &self.filter {
298                // fields MUST be checked to be valid for the table when constructing query object
299                // or vulnerable to SQL injection
300                delete_builder.push_str( &format!("{}=? AND ", filter.0) );
301
302                bindings.push(SqlValue::String(filter.1.clone()));
303            }
304
305            // remove last AND
306            delete_builder.remove(delete_builder.len()-1);
307            delete_builder.remove(delete_builder.len()-1);
308            delete_builder.remove(delete_builder.len()-1);
309            delete_builder.remove(delete_builder.len()-1);
310            delete_builder.remove(delete_builder.len()-1);
311        }
312        let statement = connection.prepare(delete_builder);
313        
314        if statement.is_err() {
315            let error = statement.err().unwrap();
316            return Err(error)
317        }
318
319        let mut bound = statement.unwrap().cursor();
320        let _res = bound.bind(bindings.as_slice());
321
322        Ok(bound)
323    }
324
325    fn construct_patch_sql(&'a self, connection: &'a Connection) -> SqlResult<Cursor> {
326        let mut patch_builder = format!("UPDATE {} SET ", self.table_schema.name);
327
328        let mut bindings: Vec<SqlValue> = Vec::new();
329
330        if self.fields_data.len() == 0 {
331            return Err(SqlError {message: Some("No parsed data in PATCH body".to_string()), code: None})
332        }
333
334        for field in &self.table_schema.fields {
335            let field_value = self.fields_data.get(field.0);
336            if field_value.is_none() {
337                continue
338            }
339
340            let v = field_value.unwrap();
341            patch_builder.push_str(&format!("{}=?,", field.0));
342            bindings.push(SqlValue::String(v.clone()))
343        }
344
345        patch_builder.remove(patch_builder.len()-1);
346
347        if self.filter.len() > 0 {
348            patch_builder.push_str(" WHERE");
349
350            for filter in &self.filter {
351                patch_builder.push_str(&format!(" {}=? AND", filter.0));
352                bindings.push(SqlValue::String(filter.1.clone()));
353            }
354
355            patch_builder.remove(patch_builder.len()-1);
356            patch_builder.remove(patch_builder.len()-1);
357            patch_builder.remove(patch_builder.len()-1);
358            patch_builder.remove(patch_builder.len()-1);
359        }
360
361        // execute the update statement
362        let patch_statement = connection.prepare(patch_builder);
363
364        if patch_statement.is_err() {
365            let error = patch_statement.err().unwrap();
366            return Err(error)
367        }
368        
369        let mut bound = patch_statement.unwrap().cursor();
370        let _res = bound.bind(bindings.as_slice());
371
372        return Ok(bound)
373    }
374}