Skip to main content

rio_rs/state/
sqlite.rs

1use crate::{errors::LoadStateError, sql_migration::SqlMigrations};
2use async_trait::async_trait;
3use futures::TryFutureExt;
4use serde::{de::DeserializeOwned, Serialize};
5use sqlx::{
6    self,
7    sqlite::{SqlitePoolOptions, SqliteRow},
8    Row, SqlitePool,
9};
10
11use super::{StateLoader, StateSaver};
12
13pub struct SqliteStateMigrations {}
14
15impl SqlMigrations for SqliteStateMigrations {
16    fn queries() -> Vec<String> {
17        let migration_001 = include_str!("./migrations/0001-sqlite-init.sql");
18        vec![migration_001.to_string()]
19    }
20}
21
22#[derive(Debug)]
23pub struct SqliteState {
24    pool: SqlitePool,
25}
26
27impl SqliteState {
28    pub fn pool() -> SqlitePoolOptions {
29        SqlitePoolOptions::new()
30    }
31
32    pub fn new(pool: SqlitePool) -> Self {
33        Self { pool }
34    }
35
36    pub async fn migrate(&self) {
37        let mut transaction = self.pool.begin().await.unwrap();
38        let queries = SqliteStateMigrations::queries();
39        for query in queries {
40            sqlx::query(&query)
41                .execute(&mut *transaction)
42                .await
43                .unwrap();
44        }
45        transaction.commit().await.unwrap();
46    }
47}
48
49#[async_trait]
50impl<T: DeserializeOwned> StateLoader<T> for SqliteState {
51    async fn prepare(&self) {
52        self.migrate().await;
53    }
54
55    async fn load(
56        &self,
57        object_kind: &str,
58        object_id: &str,
59        state_type: &str,
60    ) -> Result<T, LoadStateError> {
61        let items = sqlx::query(
62            r#"
63            SELECT serialized_state
64            FROM state_provider_object_state
65            WHERE object_kind=$1 AND object_id=$2 AND state_type = $3
66            "#,
67        )
68        .bind(object_kind)
69        .bind(object_id)
70        .bind(state_type)
71        .map(|x: SqliteRow| -> String {
72            let tmp = x.get::<Vec<u8>, _>("serialized_state");
73            String::from_utf8(tmp).expect("TODO")
74        })
75        .fetch_one(&self.pool)
76        .map_err(|_| LoadStateError::ObjectNotFound)
77        .await?;
78        let data = serde_json::from_str(&items).expect("TODO");
79        Ok(data)
80    }
81}
82
83#[async_trait]
84impl<T: Serialize + Send + Sync> StateSaver<T> for SqliteState {
85    async fn prepare(&self) {
86        self.migrate().await;
87    }
88
89    async fn save(
90        &self,
91        object_kind: &str,
92        object_id: &str,
93        state_type: &str,
94        data: &T,
95    ) -> Result<(), LoadStateError> {
96        let serialized_data = serde_json::to_string(&data).expect("TODO");
97        sqlx::query(
98            r#"
99            INSERT INTO
100                state_provider_object_state(object_kind, object_id, state_type, serialized_state)
101            VALUES ($1, $2, $3, $4)
102            ON CONFLICT(object_kind, object_id, state_type) DO UPDATE SET serialized_state=$4
103            "#,
104        )
105        .bind(object_kind)
106        .bind(object_id)
107        .bind(state_type)
108        .bind(serialized_data.bytes().collect::<Vec<_>>())
109        .execute(&self.pool)
110        .map_err(|e| {
111            eprintln!("{:?}", e);
112            LoadStateError::Unknown
113        })
114        .await
115        .map(|_| ())
116    }
117}