1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
use serde::Serialize;
use std::fmt::Debug;
use thiserror::Error;

use sqlparser::ast::{Assignment, Ident};

use super::context::FilterContext;
use super::evaluate::{evaluate, Evaluated};
use crate::data::{Row, Value};
use crate::result::Result;
use crate::store::Store;

#[derive(Error, Serialize, Debug, PartialEq)]
pub enum UpdateError {
    #[error("column not found {0}")]
    ColumnNotFound(String),

    #[error("conflict on schema, row data does not fit to schema")]
    ConflictOnSchema,

    #[error("unreachable")]
    Unreachable,
}

/// Applies an assignment list (`SET col = expr, ...`) to the rows of one table.
///
/// Everything is borrowed for `'a`. `columns` gives the table's column order;
/// row values are paired with it by position.
pub struct Update<'a, T>
where
    T: 'static + Debug,
{
    table_name: &'a str,
    columns: &'a [Ident],
    fields: &'a [Assignment],
    storage: &'a dyn Store<T>,
}

impl<'a, T: 'static + Debug> Update<'a, T> {
    /// Builds an updater for `table_name`, checking up front that every
    /// assignment in `fields` targets a column listed in `columns`.
    ///
    /// # Errors
    ///
    /// Returns [`UpdateError::ColumnNotFound`] (converted into the crate error)
    /// for the first assignment whose target is not in `columns`.
    pub fn new(
        storage: &'a dyn Store<T>,
        table_name: &'a str,
        fields: &'a [Assignment],
        columns: &'a [Ident],
    ) -> Result<Self> {
        for Assignment { id, .. } in fields {
            if !columns.iter().any(|column| column.value == id.value) {
                return Err(UpdateError::ColumnNotFound(id.value.to_string()).into());
            }
        }

        Ok(Self {
            storage,
            table_name,
            fields,
            columns,
        })
    }

    /// Evaluates the assignment targeting `column`, if there is one.
    ///
    /// `current` is the cell's existing value in `row`. It is passed in by the
    /// caller, which already knows the cell's position, so no second lookup
    /// over `columns` is needed.
    ///
    /// Returns `None` when no assignment targets `column`.
    fn find(&self, row: &Row, column: &Ident, current: &Value) -> Option<Result<Value>> {
        let assignment = self
            .fields
            .iter()
            .find(|assignment| assignment.id.value == column.value)?;

        Some(self.evaluate_assignment(row, assignment, current))
    }

    /// Evaluates `assignment`'s expression against the original `row` and
    /// turns the result into an owned [`Value`].
    ///
    /// Literal results are converted with `current.clone_by(..)`, i.e. using
    /// the cell's existing value as the conversion template.
    fn evaluate_assignment(
        &self,
        row: &Row,
        assignment: &Assignment,
        current: &Value,
    ) -> Result<Value> {
        // Built only for cells that are actually assigned; untouched columns
        // never pay for a context.
        let context = FilterContext::new(self.table_name, self.columns, row, None);
        let context = Some(&context);

        let evaluated = evaluate(self.storage, context, None, &assignment.value)?;

        match evaluated {
            Evaluated::LiteralRef(v) => current.clone_by(v),
            Evaluated::Literal(v) => current.clone_by(&v),
            Evaluated::StringRef(v) => Ok(Value::Str(v.to_string())),
            Evaluated::ValueRef(v) => Ok(v.clone()),
            Evaluated::Value(v) => Ok(v),
        }
    }

    /// Returns a copy of `row` in which every assigned column is replaced by
    /// its evaluated value. Unassigned columns keep their current value.
    ///
    /// Row values are paired with `columns` by position. Every assignment is
    /// evaluated against the original `row`, so assignments never observe each
    /// other's results.
    ///
    /// # Errors
    ///
    /// - [`UpdateError::ConflictOnSchema`] if the row has more values than
    ///   there are columns.
    /// - Any error raised while evaluating an assignment or converting its
    ///   result.
    pub fn apply(&self, row: Row) -> Result<Row> {
        let Row(values) = &row;

        values
            .iter()
            .enumerate()
            .map(|(i, current)| {
                let column = self
                    .columns
                    .get(i)
                    .ok_or(UpdateError::ConflictOnSchema)?;

                // Clone only the cells that are kept as-is.
                self.find(&row, column, current)
                    .unwrap_or_else(|| Ok(current.clone()))
            })
            .collect::<Result<_>>()
            .map(Row)
    }
}