1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
use serde::Serialize; use std::fmt::Debug; use thiserror::Error; use sqlparser::ast::{Assignment, Ident}; use super::context::FilterContext; use super::evaluate::{evaluate, Evaluated}; use crate::data::{Row, Value}; use crate::result::Result; use crate::store::Store; #[derive(Error, Serialize, Debug, PartialEq)] pub enum UpdateError { #[error("column not found {0}")] ColumnNotFound(String), #[error("conflict on schema, row data does not fit to schema")] ConflictOnSchema, #[error("unreachable")] Unreachable, } pub struct Update<'a, T: 'static + Debug> { storage: &'a dyn Store<T>, table_name: &'a str, fields: &'a [Assignment], columns: &'a [Ident], } impl<'a, T: 'static + Debug> Update<'a, T> { pub fn new( storage: &'a dyn Store<T>, table_name: &'a str, fields: &'a [Assignment], columns: &'a [Ident], ) -> Result<Self> { for assignment in fields.iter() { let Assignment { id, .. } = assignment; if columns.iter().all(|column| column.value != id.value) { return Err(UpdateError::ColumnNotFound(id.value.to_string()).into()); } } Ok(Self { storage, table_name, fields, columns, }) } fn find(&self, row: &Row, column: &Ident) -> Option<Result<Value>> { let context = FilterContext::new(self.table_name, self.columns, row, None); let context = Some(&context); self.fields .iter() .find(|assignment| assignment.id.value == column.value) .map(|assignment| { let Assignment { id, value } = &assignment; let index = self .columns .iter() .position(|column| column.value == id.value) .ok_or_else(|| UpdateError::Unreachable)?; let evaluated = evaluate(self.storage, context, None, value)?; let Row(values) = &row; let value = &values[index]; match evaluated { Evaluated::LiteralRef(v) => value.clone_by(v), Evaluated::Literal(v) => value.clone_by(&v), Evaluated::StringRef(v) => Ok(Value::Str(v.to_string())), Evaluated::ValueRef(v) => Ok(v.clone()), Evaluated::Value(v) => Ok(v), } }) } pub fn apply(&self, row: Row) -> Result<Row> { let Row(values) = &row; values .clone() .into_iter() .enumerate() .map(|(i, value)| { let column = &self .columns .get(i) .ok_or_else(|| UpdateError::ConflictOnSchema)?; self.find(&row, column).unwrap_or(Ok(value)) }) .collect::<Result<_>>() .map(Row) } }