1use crate::{errors::LoadStateError, sql_migration::SqlMigrations};
2use async_trait::async_trait;
3use futures::TryFutureExt;
4use serde::{de::DeserializeOwned, Serialize};
5use sqlx::{
6 self,
7 sqlite::{SqlitePoolOptions, SqliteRow},
8 Row, SqlitePool,
9};
10
11use super::{StateLoader, StateSaver};
12
13pub struct SqliteStateMigrations {}
14
15impl SqlMigrations for SqliteStateMigrations {
16 fn queries() -> Vec<String> {
17 let migration_001 = include_str!("./migrations/0001-sqlite-init.sql");
18 vec![migration_001.to_string()]
19 }
20}
21
22#[derive(Debug)]
23pub struct SqliteState {
24 pool: SqlitePool,
25}
26
27impl SqliteState {
28 pub fn pool() -> SqlitePoolOptions {
29 SqlitePoolOptions::new()
30 }
31
32 pub fn new(pool: SqlitePool) -> Self {
33 Self { pool }
34 }
35
36 pub async fn migrate(&self) {
37 let mut transaction = self.pool.begin().await.unwrap();
38 let queries = SqliteStateMigrations::queries();
39 for query in queries {
40 sqlx::query(&query)
41 .execute(&mut *transaction)
42 .await
43 .unwrap();
44 }
45 transaction.commit().await.unwrap();
46 }
47}
48
49#[async_trait]
50impl<T: DeserializeOwned> StateLoader<T> for SqliteState {
51 async fn prepare(&self) {
52 self.migrate().await;
53 }
54
55 async fn load(
56 &self,
57 object_kind: &str,
58 object_id: &str,
59 state_type: &str,
60 ) -> Result<T, LoadStateError> {
61 let items = sqlx::query(
62 r#"
63 SELECT serialized_state
64 FROM state_provider_object_state
65 WHERE object_kind=$1 AND object_id=$2 AND state_type = $3
66 "#,
67 )
68 .bind(object_kind)
69 .bind(object_id)
70 .bind(state_type)
71 .map(|x: SqliteRow| -> String {
72 let tmp = x.get::<Vec<u8>, _>("serialized_state");
73 String::from_utf8(tmp).expect("TODO")
74 })
75 .fetch_one(&self.pool)
76 .map_err(|_| LoadStateError::ObjectNotFound)
77 .await?;
78 let data = serde_json::from_str(&items).expect("TODO");
79 Ok(data)
80 }
81}
82
83#[async_trait]
84impl<T: Serialize + Send + Sync> StateSaver<T> for SqliteState {
85 async fn prepare(&self) {
86 self.migrate().await;
87 }
88
89 async fn save(
90 &self,
91 object_kind: &str,
92 object_id: &str,
93 state_type: &str,
94 data: &T,
95 ) -> Result<(), LoadStateError> {
96 let serialized_data = serde_json::to_string(&data).expect("TODO");
97 sqlx::query(
98 r#"
99 INSERT INTO
100 state_provider_object_state(object_kind, object_id, state_type, serialized_state)
101 VALUES ($1, $2, $3, $4)
102 ON CONFLICT(object_kind, object_id, state_type) DO UPDATE SET serialized_state=$4
103 "#,
104 )
105 .bind(object_kind)
106 .bind(object_id)
107 .bind(state_type)
108 .bind(serialized_data.bytes().collect::<Vec<_>>())
109 .execute(&self.pool)
110 .map_err(|e| {
111 eprintln!("{:?}", e);
112 LoadStateError::Unknown
113 })
114 .await
115 .map(|_| ())
116 }
117}