Skip to main content

rio_rs/object_placement/
sqlite.rs

//! SQLite implementation of the [ObjectPlacement] trait, storing object placements in a
//! relational database
//!
//! This uses [sqlx] under the hood
4
5use async_trait::async_trait;
6use sqlx::sqlite::SqlitePoolOptions;
7use sqlx::{self, Row, SqlitePool};
8
9use super::{ObjectPlacement, ObjectPlacementItem};
10use crate::sql_migration::SqlMigrations;
11use crate::ObjectId;
12
13pub struct SqliteObjectPlacementMigrations {}
14
15impl SqlMigrations for SqliteObjectPlacementMigrations {
16    fn queries() -> Vec<String> {
17        let migration_001 = include_str!("./migrations/0001-sqlite-init.sql");
18        vec![migration_001.to_string()]
19    }
20}
21
22#[derive(Clone, Debug)]
23pub struct SqliteObjectPlacement {
24    pool: SqlitePool,
25}
26
27impl SqliteObjectPlacement {
28    pub fn new(pool: SqlitePool) -> Self {
29        SqliteObjectPlacement { pool }
30    }
31
32    /// Pool builder, so one doesn't need to include sqlx as a dependency
33    ///
34    /// # Example
35    ///
36    /// ```
37    /// # use rio_rs::object_placement::sqlite::SqliteObjectPlacement;
38    /// # async fn test_fn() {
39    /// let pool = SqliteObjectPlacement::pool()
40    ///     .connect("sqlite::memory:")
41    ///     .await
42    ///     .expect("Connection failure");
43    /// let object_placement = SqliteObjectPlacement::new(pool);
44    /// # }
45    /// ```
46    pub fn pool() -> SqlitePoolOptions {
47        SqlitePoolOptions::new()
48    }
49}
50
51#[async_trait]
52impl ObjectPlacement for SqliteObjectPlacement {
53    /// Run the schema/data migrations for this membership storage.
54    ///
55    /// For now, the Rio server doesn't run this at start-up and it needs
56    /// to be invoked on manually in the server's setup.
57    async fn prepare(&self) {
58        let mut transaction = self.pool.begin().await.unwrap();
59        let queries = SqliteObjectPlacementMigrations::queries();
60        for query in queries {
61            sqlx::query(&query)
62                .execute(&mut *transaction)
63                .await
64                .unwrap();
65        }
66        transaction.commit().await.unwrap();
67    }
68
69    async fn update(&self, object_placement: ObjectPlacementItem) {
70        sqlx::query(
71            r#"
72            INSERT INTO
73            object_placement(struct_name, object_id, server_address)
74            VALUES ($1, $2, $3)
75            ON CONFLICT(struct_name, object_id) DO UPDATE SET server_address=$3"#,
76        )
77        .bind(&object_placement.object_id.0)
78        .bind(&object_placement.object_id.1)
79        .bind(&object_placement.server_address)
80        .execute(&self.pool)
81        .await
82        .unwrap();
83    }
84    async fn lookup(&self, object_id: &ObjectId) -> Option<String> {
85        let row = sqlx::query(
86            r#"
87            SELECT server_address
88            FROM object_placement
89            WHERE struct_name = $1 and object_id = $2
90            "#,
91        )
92        .bind(&object_id.0)
93        .bind(&object_id.1)
94        .fetch_one(&self.pool)
95        .await
96        .ok();
97        row.map(|row| row.get("server_address"))
98    }
99    async fn clean_server(&self, address: String) {
100        sqlx::query(
101            r#"
102            DELETE FROM object_placement
103            WHERE server_address = $1
104            "#,
105        )
106        .bind(&address)
107        .execute(&self.pool)
108        .await
109        .unwrap();
110    }
111
112    async fn remove(&self, object_id: &ObjectId) {
113        sqlx::query(
114            r#"
115            DELETE FROM object_placement
116            WHERE struct_name = $1 and object_id = $2
117            "#,
118        )
119        .bind(&object_id.0)
120        .bind(&object_id.1)
121        .execute(&self.pool)
122        .await
123        .unwrap();
124    }
125}
126
#[cfg(test)]
mod test {

    use super::*;

    /// In-memory SQLite pool used by every test in this module.
    async fn pool() -> SqlitePool {
        SqlitePoolOptions::new()
            .max_connections(5)
            .connect("sqlite::memory:")
            .await
            .expect("TODO: Connection failure")
    }

    /// Builds a provider over a fresh pool and runs its migrations.
    async fn object_placement_provider() -> (SqlitePool, impl ObjectPlacement) {
        let pool = pool().await;
        let object_placement_provider = SqliteObjectPlacement::new(pool.clone());
        object_placement_provider.prepare().await;
        (pool, object_placement_provider)
    }

    #[tokio::test]
    async fn test_sanity() {
        let (_, object_placement_provider) = object_placement_provider().await;
        let placement = object_placement_provider
            .lookup(&ObjectId::new("Test", "1"))
            .await;
        assert_eq!(placement, None);

        let object_placement =
            ObjectPlacementItem::new(ObjectId::new("Test", "1"), Some("0.0.0.0:5000".to_string()));
        object_placement_provider.update(object_placement).await;
        let placement = object_placement_provider
            .lookup(&ObjectId::new("Test", "1"))
            .await
            .unwrap();
        assert_eq!(placement, "0.0.0.0:5000");

        let object_placement =
            ObjectPlacementItem::new(ObjectId::new("Test", "1"), Some("0.0.0.0:5001".to_string()));
        object_placement_provider.update(object_placement).await;
        let placement = object_placement_provider
            .lookup(&ObjectId::new("Test", "1"))
            .await
            .unwrap();
        assert_eq!(placement, "0.0.0.0:5001");

        object_placement_provider
            .clean_server("0.0.0.0:5001".to_string())
            .await;
        let placement = object_placement_provider
            .lookup(&ObjectId::new("Test", "1"))
            .await;
        assert_eq!(placement, None);
    }

    #[tokio::test]
    async fn test_remove() {
        let (_, object_placement_provider) = object_placement_provider().await;
        let object_placement =
            ObjectPlacementItem::new(ObjectId::new("Test", "2"), Some("0.0.0.0:5000".to_string()));
        object_placement_provider.update(object_placement).await;
        let placement = object_placement_provider
            .lookup(&ObjectId::new("Test", "2"))
            .await;
        assert_eq!(placement.as_deref(), Some("0.0.0.0:5000"));

        object_placement_provider
            .remove(&ObjectId::new("Test", "2"))
            .await;
        let placement = object_placement_provider
            .lookup(&ObjectId::new("Test", "2"))
            .await;
        assert_eq!(placement, None);

        // Removing an absent placement must not fail.
        object_placement_provider
            .remove(&ObjectId::new("Test", "2"))
            .await;
    }

    #[tokio::test]
    async fn test_clean_server_keeps_other_servers() {
        let (_, object_placement_provider) = object_placement_provider().await;
        let on_a =
            ObjectPlacementItem::new(ObjectId::new("Test", "a"), Some("0.0.0.0:5000".to_string()));
        let on_b =
            ObjectPlacementItem::new(ObjectId::new("Test", "b"), Some("0.0.0.0:5001".to_string()));
        object_placement_provider.update(on_a).await;
        object_placement_provider.update(on_b).await;

        object_placement_provider
            .clean_server("0.0.0.0:5000".to_string())
            .await;

        let placement_a = object_placement_provider
            .lookup(&ObjectId::new("Test", "a"))
            .await;
        let placement_b = object_placement_provider
            .lookup(&ObjectId::new("Test", "b"))
            .await;
        assert_eq!(placement_a, None);
        assert_eq!(placement_b.as_deref(), Some("0.0.0.0:5001"));
    }
}