Skip to main content

streamling_state/
sqlite.rs

//! State backend backed by SQLite.
//!
//! State is kept in a single table (named `state` unless a different name is passed to
//! [`SqliteStateOperatorBackendFactory::new`]) with the following schema:
//!
//! ```sql
//! CREATE TABLE state (
//!   namespace TEXT,
//!   key TEXT,
//!   data TEXT NOT NULL,
//!   created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
//!   PRIMARY KEY(namespace, key)
//! );
//! ```
//!
//! `namespace` can be used to separate different applications or versions.
//! `key` identifies an individual state value (e.g. the state of one operator).
//! `data` holds the state value serialized as JSON.
17use crate::{
18    StateBackendError, StateBackendErrorKind, StateKey, StateOperatorBackend,
19    StateOperatorBackendFactory,
20};
21use async_trait::async_trait;
22use serde::{Deserialize, Serialize};
23use sqlx::pool::PoolOptions;
24use sqlx::sqlite::SqliteConnectOptions;
25use sqlx::{Row, SqlitePool};
26use std::fmt::Debug;
27use std::str::FromStr;
28use std::sync::Arc;
29use tracing::info;
30
/// Name of the state table used when the caller does not supply one.
const DEFAULT_TABLE_NAME: &str = "state";
/// Connection-pool size used when the caller does not supply `max_connections`.
const DEFAULT_MAX_CONNECTIONS: u32 = 10;
33
34pub struct SqliteStateOperatorBackendFactory {
35    pool: Arc<SqlitePool>,
36    state_table_name: String,
37}
38
39impl SqliteStateOperatorBackendFactory {
40    pub async fn new(
41        database_path: String,
42        max_connections: Option<u32>,
43        state_table_name: Option<String>,
44    ) -> Result<Self, StateBackendError> {
45        let state_table_name = state_table_name.unwrap_or_else(|| DEFAULT_TABLE_NAME.to_string());
46
47        let options = SqliteConnectOptions::from_str(format!("sqlite:{}", database_path).as_str())
48            .unwrap()
49            .create_if_missing(true);
50
51        let pool = PoolOptions::<sqlx::Sqlite>::new()
52            .max_connections(max_connections.unwrap_or(DEFAULT_MAX_CONNECTIONS))
53            .connect_with(options)
54            .await
55            .map_err(|e| {
56                StateBackendError::with_source(
57                    StateBackendErrorKind::Connection,
58                    "failed to create SQLite connection pool",
59                    e,
60                )
61            })?;
62
63        let pool = Arc::new(pool);
64
65        Self::initialize(pool.clone(), &state_table_name).await?;
66
67        Ok(Self {
68            pool,
69            state_table_name,
70        })
71    }
72
73    async fn initialize(
74        pool: Arc<SqlitePool>,
75        state_table_name: &str,
76    ) -> Result<(), StateBackendError> {
77        sqlx::query(
78            format!(
79                r#"
80                CREATE TABLE IF NOT EXISTS {} (
81                    namespace TEXT,
82                    key TEXT,
83                    data TEXT NOT NULL,
84                    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
85                    PRIMARY KEY(namespace, key)
86                );
87            "#,
88                state_table_name
89            )
90            .as_str(),
91        )
92        .execute(pool.as_ref())
93        .await
94        .map(|_| ())
95        .map_err(|e| {
96            StateBackendError::with_source(
97                StateBackendErrorKind::Initialization,
98                "failed to create state table",
99                e,
100            )
101        })
102    }
103}
104
105impl StateOperatorBackendFactory for SqliteStateOperatorBackendFactory {
106    fn create<V>(&self, namespace: &str) -> Arc<dyn StateOperatorBackend<V>>
107    where
108        V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Clone + Debug + 'static,
109    {
110        Arc::new(SqliteStateOperatorBackend::new(
111            self.pool.clone(),
112            self.state_table_name.clone(),
113            namespace,
114        ))
115    }
116}
117
118#[derive(Debug)]
119struct SqliteStateOperatorBackend {
120    pool: Arc<SqlitePool>,
121    state_table_name: String,
122    namespace: String,
123}
124
125impl SqliteStateOperatorBackend {
126    fn new(pool: Arc<SqlitePool>, state_table_name: String, namespace: &str) -> Self {
127        info!(
128            "Creating a new SQLite JSON state backend for namespace: {}",
129            namespace
130        );
131
132        Self {
133            pool,
134            state_table_name,
135            namespace: namespace.to_string(),
136        }
137    }
138}
139
140#[async_trait]
141impl<V> StateOperatorBackend<V> for SqliteStateOperatorBackend
142where
143    V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Debug + 'static,
144{
145    async fn get(&self, key: StateKey) -> Result<Option<V>, StateBackendError> {
146        let result = sqlx::query(
147            format!(
148                r#"
149                SELECT data
150                FROM {}
151                WHERE namespace = ? AND key = ?
152            "#,
153                self.state_table_name
154            )
155            .as_str(),
156        )
157        .bind(&self.namespace)
158        .bind(&key.0)
159        .fetch_optional(self.pool.as_ref())
160        .await
161        .map_err(|e| {
162            StateBackendError::with_source(StateBackendErrorKind::Query, "failed to fetch state", e)
163        })?;
164
165        if result.is_none() {
166            return Ok(None);
167        }
168
169        let data = result.unwrap();
170        let json_str: String = data.try_get(0).map_err(|e| {
171            StateBackendError::with_source(
172                StateBackendErrorKind::Query,
173                "failed to read data column",
174                e,
175            )
176        })?;
177
178        serde_json::from_str(&json_str).map(Some).map_err(|e| {
179            StateBackendError::with_source(
180                StateBackendErrorKind::Serialization,
181                "failed to deserialize state",
182                e,
183            )
184        })
185    }
186
187    async fn put(&self, key: StateKey, value: V) -> Result<(), StateBackendError> {
188        let json_str = serde_json::to_string(&value).unwrap();
189        sqlx::query(
190            format!(
191                r#"
192                INSERT INTO {} (namespace, key, data, created_at)
193                VALUES (?, ?, ?, CURRENT_TIMESTAMP)
194                ON CONFLICT(namespace, key) DO UPDATE SET data = excluded.data
195            "#,
196                self.state_table_name
197            )
198            .as_str(),
199        )
200        .bind(&self.namespace)
201        .bind(&key.0)
202        .bind(&json_str)
203        .execute(self.pool.as_ref())
204        .await
205        .map(|_| ())
206        .map_err(|e| {
207            StateBackendError::with_source(
208                StateBackendErrorKind::Query,
209                "failed to update state",
210                e,
211            )
212        })
213    }
214
215    async fn remove(&self, key: StateKey) -> Result<(), StateBackendError> {
216        sqlx::query(
217            format!(
218                r#"
219                DELETE FROM {}
220                WHERE namespace = ? AND key = ?
221            "#,
222                self.state_table_name
223            )
224            .as_str(),
225        )
226        .bind(&self.namespace)
227        .bind(&key.0)
228        .execute(self.pool.as_ref())
229        .await
230        .map(|_| ())
231        .map_err(|e| {
232            StateBackendError::with_source(
233                StateBackendErrorKind::Query,
234                "failed to remove state",
235                e,
236            )
237        })
238    }
239
240    async fn clear(&self) -> Result<(), StateBackendError> {
241        sqlx::query(
242            format!(
243                r#"
244                DELETE FROM {}
245                WHERE namespace = ?
246            "#,
247                self.state_table_name
248            )
249            .as_str(),
250        )
251        .bind(&self.namespace)
252        .execute(self.pool.as_ref())
253        .await
254        .map(|_| ())
255        .map_err(|e| {
256            StateBackendError::with_source(StateBackendErrorKind::Query, "failed to clear state", e)
257        })
258    }
259}