progscrape-application 0.0.3

progscrape.com application logic
use std::path::Path;

use super::*;
use keepcalm::Shared;
use serde::{Serialize, de::DeserializeOwned};

pub struct DB {
    connection: Shared<rusqlite::Connection>,
}

impl DB {
    pub fn open<P: AsRef<Path>>(location: P) -> Result<Self, PersistError> {
        let db = rusqlite::Connection::open(location)?;
        let connection = Shared::new_unsync(db);
        Ok(Self { connection })
    }

    pub fn table_for<T: Serialize>() -> &'static str {
        std::any::type_name::<T>().rsplit_once(':').unwrap().1
    }

    fn schema_for(schema: Option<&str>) -> &str {
        schema.unwrap_or("main")
    }

    pub fn create_table<T: Serialize + Default>(&self) -> Result<(), PersistError> {
        self.create_table_schema::<T>(None)
    }

    pub fn create_table_schema<T: Serialize + Default>(
        &self,
        schema: Option<&str>,
    ) -> Result<(), PersistError> {
        use rusqlite::types::ToSqlOutput::*;
        use rusqlite::types::*;
        let t = T::default();
        let params = serde_rusqlite::to_params_named(&t)?;
        let mut types = vec![];
        for (name, value) in params.to_slice() {
            let value = value.to_sql()?;
            types.push(format!(
                "{} {} not null",
                name.trim_start_matches(':'),
                match value.to_sql()? {
                    Owned(Value::Integer(..)) | Borrowed(ValueRef::Integer(..)) => {
                        "int"
                    }
                    Owned(Value::Text(..)) | Borrowed(ValueRef::Text(..)) => {
                        "text"
                    }
                    Owned(Value::Real(..)) | Borrowed(ValueRef::Real(..)) => {
                        "real"
                    }
                    Owned(Value::Blob(..)) | Borrowed(ValueRef::Blob(..)) => {
                        "blob"
                    }
                    _ => {
                        return Err(PersistError::UnexpectedError(
                            "Could not map this column type".into(),
                        ));
                    }
                }
            ));
        }
        let sql = format!(
            "create table if not exists {}.{}({})",
            Self::schema_for(schema),
            Self::table_for::<T>(),
            types.join(",")
        );
        self.connection.read().execute(&sql, ())?;
        Ok(())
    }

    pub fn create_unique_index<T: Serialize + Default>(
        &self,
        name: &str,
        keys: &[&str],
    ) -> Result<(), PersistError> {
        self.create_unique_index_schema::<T>(None, name, keys)
    }

    pub fn create_unique_index_schema<T: Serialize + Default>(
        &self,
        _schema: Option<&str>,
        name: &str,
        keys: &[&str],
    ) -> Result<(), PersistError> {
        let t = T::default();
        let _params = serde_rusqlite::to_params_named(&t)?;
        // TODO: check keys
        let sql = format!(
            "create unique index if not exists {} on {}({})",
            name,
            Self::table_for::<T>(),
            keys.join(",")
        );
        self.connection.read().execute(&sql, ())?;
        Ok(())
    }

    pub fn store_batch<T: Serialize>(&self, t: Vec<T>) -> Result<(), PersistError> {
        self.store_batch_schema(None, t)
    }

    pub fn store_batch_schema<T: Serialize>(
        &self,
        schema: Option<&str>,
        t: Vec<T>,
    ) -> Result<(), PersistError> {
        if let Some(first) = t.first() {
            let params = serde_rusqlite::to_params_named(first)?;
            let params_slice = params.to_slice();
            let columns = params_slice
                .iter()
                .map(|f| f.0.trim_start_matches(':'))
                .collect::<Vec<_>>()
                .join(",");
            let values = params_slice
                .iter()
                .map(|f| f.0)
                .collect::<Vec<_>>()
                .join(",");
            let sql = format!(
                "insert or replace into {}.{}({}) values ({})",
                Self::schema_for(schema),
                Self::table_for::<T>(),
                columns,
                values
            );

            let conn = self.connection.read();
            let mut prep = conn.prepare(&sql)?;
            for t in t {
                let params = serde_rusqlite::to_params_named(t)?;
                prep.execute(params.to_slice().as_slice())?;
            }
            prep.finalize()?;
        }
        Ok(())
    }

    pub fn store<T: Serialize>(&self, t: &T) -> Result<(), PersistError> {
        self.store_schema(None, t)
    }

    pub fn store_schema<T: Serialize>(
        &self,
        schema: Option<&str>,
        t: &T,
    ) -> Result<(), PersistError> {
        let params = serde_rusqlite::to_params_named(t)?;
        let params_slice = params.to_slice();
        let columns = params_slice
            .iter()
            .map(|f| f.0.trim_start_matches(':'))
            .collect::<Vec<_>>()
            .join(",");
        let values = params_slice
            .iter()
            .map(|f| f.0)
            .collect::<Vec<_>>()
            .join(",");
        let sql = format!(
            "insert or replace into {}.{}({}) values ({})",
            Self::schema_for(schema),
            Self::table_for::<T>(),
            columns,
            values
        );
        self.connection
            .read()
            .execute(&sql, params_slice.as_slice())?;
        Ok(())
    }

    pub fn load<T: Serialize + DeserializeOwned>(
        &self,
        id: String,
    ) -> Result<Option<T>, PersistError> {
        self.load_schema(None, id)
    }

    pub fn load_schema<T: Serialize + DeserializeOwned>(
        &self,
        schema: Option<&str>,
        id: String,
    ) -> Result<Option<T>, PersistError> {
        let sql = format!(
            "select * from {}.{} where id = ?",
            Self::schema_for(schema),
            Self::table_for::<T>()
        );
        match self
            .connection
            .read()
            .query_row_and_then(&sql, [id], |row| serde_rusqlite::from_row::<T>(row))
        {
            Err(serde_rusqlite::Error::Rusqlite(rusqlite::Error::QueryReturnedNoRows)) => Ok(None),
            Err(x) => Err(PersistError::SerdeError(x)),
            Ok(x) => Ok(Some(x)),
        }
    }

    pub fn execute_raw(&self, sql: &str) -> Result<(), PersistError> {
        self.connection.read().execute_batch(sql)?;
        Ok(())
    }

    pub fn query_raw<T: Serialize + DeserializeOwned>(
        &self,
        sql: &str,
    ) -> Result<Vec<T>, PersistError> {
        let mut v = vec![];
        self.query_raw_callback(sql, |x| {
            v.push(x);
            Ok(())
        })?;
        Ok(v)
    }

    pub fn query_raw_callback<
        T: Serialize + DeserializeOwned,
        F: FnMut(T) -> Result<(), PersistError>,
    >(
        &self,
        sql: &str,
        mut f: F,
    ) -> Result<(), PersistError> {
        let db = self.connection.read();
        let mut stmt = db.prepare(sql)?;
        let mut res = stmt.query([])?;
        loop {
            match res.next() {
                Ok(Some(row)) => {
                    f(serde_rusqlite::from_row::<T>(row)?)?;
                }
                Ok(None) => {
                    break;
                }
                x @ Err(_) => {
                    x?;
                }
            }
        }
        Ok(())
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Deserialize, Serialize, Debug, Default, Eq, PartialEq)]
    struct TestSerialize {
        id: String,
        integer: u32,
        string: String,
    }

    #[test]
    fn load_raw() {
        let db = DB::open(":memory:").unwrap();
        let out = db
            .query_raw::<TestSerialize>("select 'x' as id, 1 as integer, 'y' as string")
            .unwrap();
        assert_eq!(
            out[0],
            TestSerialize {
                id: "x".into(),
                integer: 1,
                string: "y".into()
            }
        );
    }

    #[test]
    fn load_store() {
        let db = DB::open(":memory:").unwrap();
        db.create_table::<TestSerialize>().unwrap();
        let input = TestSerialize {
            id: "a".into(),
            integer: 123,
            string: "hello".into(),
        };
        db.store(&input).unwrap();
        let output = db.load("a".into()).unwrap();
        assert_eq!(Some(input), output);
    }

    #[test]
    fn load_store_missing() {
        let db = DB::open(":memory:").unwrap();
        db.create_table::<TestSerialize>().unwrap();
        let output = db.load::<TestSerialize>("a".into()).unwrap();
        assert_eq!(None, output);
    }
}