// houseflow_db/postgres/mod.rs

1use crate::Error;
2
3use async_trait::async_trait;
4use deadpool_postgres::Pool;
5use houseflow_config::postgres::Config;
6use houseflow_types::{Device, DeviceID, User, UserID, UserStructure};
7use semver::Version;
8use tokio_postgres::NoTls;
9
10use refinery::embed_migrations;
// Embeds the SQL migrations from the `migrations` directory via refinery's macro; it generates
// the `migrations` module whose `runner()` is invoked by `Database::new`.
embed_migrations!("migrations");
12
/// Errors specific to the Postgres backend.
#[derive(Debug, thiserror::Error)]
pub enum InternalError {
    /// Wraps an error returned by `tokio_postgres` (e.g. while preparing or running a query).
    #[error("Error when sending query: `{0}`")]
    QueryError(#[from] tokio_postgres::Error),

    /// Wraps an error returned by the `deadpool_postgres` connection pool.
    #[error("pool error: {0}")]
    PoolError(#[from] deadpool_postgres::PoolError),

    /// A column was read successfully but its value could not be converted into the domain
    /// type (for example a `hw_version`/`sw_version` string that is not a valid semver version).
    #[error("Column `{column}` is invalid: `{error}`")]
    InvalidColumn {
        // Name of the offending column.
        column: &'static str,
        // Underlying conversion error.
        error: Box<dyn std::error::Error + Send + Sync>,
    },

    /// Wraps an error returned by refinery while running the embedded migrations.
    #[error("Error when running migrations: `{0}`")]
    MigrationError(#[from] refinery::Error),
}
30
31use crate::DatabaseInternalError;
32
// Marker impls of `crate::DatabaseInternalError`. NOTE(review): presumably this trait is what
// lets these error types convert into `crate::Error` — the `?` operators in this file rely on
// such a conversion for pool, query, migration and `InternalError` errors; confirm against the
// trait definition in the crate root.
impl DatabaseInternalError for InternalError {}
impl DatabaseInternalError for deadpool_postgres::PoolError {}
impl DatabaseInternalError for tokio_postgres::Error {}
impl DatabaseInternalError for refinery::Error {}
37
/// Postgres-backed implementation of [`crate::Database`].
#[derive(Clone)]
pub struct Database {
    /// `deadpool_postgres` pool; every operation checks a connection out of it.
    pool: Pool,
}
42
43impl Database {
44    fn get_pool_config(cfg: &Config) -> deadpool_postgres::Config {
45        let mut dpcfg = deadpool_postgres::Config::new();
46        dpcfg.user = Some(cfg.user.to_string());
47        dpcfg.password = Some(cfg.password.to_string());
48        dpcfg.host = Some(cfg.address.ip().to_string());
49        dpcfg.port = Some(cfg.address.port());
50        dpcfg.dbname = Some(cfg.database_name.to_string());
51        dpcfg
52    }
53
54    /// This function connect with database and runs migrations on it, after doing so it's fully
55    /// ready for operations
56    pub async fn new(opts: &Config) -> Result<Self, Error> {
57        use std::ops::DerefMut;
58
59        let pool_config = Self::get_pool_config(&opts);
60        let pool = pool_config
61            .create_pool(NoTls)
62            .expect("invalid pool configuration");
63        let mut obj = pool.get().await?;
64        let client = obj.deref_mut().deref_mut();
65        migrations::runner().run_async(client).await?;
66        Ok(Self { pool })
67    }
68}
69
70#[async_trait]
71impl crate::Database for Database {
72    async fn add_structure(&self, structure: &houseflow_types::Structure) -> Result<(), Error> {
73        let connection = self.pool.get().await?;
74        let insert_statement = connection
75            .prepare(
76                r#"
77            INSERT INTO structures (id, name) 
78            VALUES ($1, $2)
79            "#,
80            )
81            .await?;
82
83        let n = connection
84            .execute(&insert_statement, &[&structure.id, &structure.name])
85            .await?;
86
87        match n {
88            0 => Err(Error::NotModified),
89            1 => Ok(()),
90            _ => unreachable!(),
91        }
92    }
93
94    async fn add_room(&self, room: &houseflow_types::Room) -> Result<(), Error> {
95        let connection = self.pool.get().await?;
96        let insert_statement = connection
97            .prepare(
98                r#"
99            INSERT INTO rooms (id, structure_id, name) 
100            VALUES ($1, $2, $3)
101            "#,
102            )
103            .await?;
104
105        let n = connection
106            .execute(
107                &insert_statement,
108                &[&room.id, &room.structure_id, &room.name],
109            )
110            .await?;
111
112        match n {
113            0 => Err(Error::NotModified),
114            1 => Ok(()),
115            _ => unreachable!(),
116        }
117    }
118
119    async fn add_device(&self, device: &Device) -> Result<(), Error> {
120        let connection = self.pool.get().await?;
121        let insert_statement = connection.prepare(
122            r#"
123            INSERT INTO devices(
124                id, room_id, password_hash, type, traits, name, will_push_state, model, hw_version, sw_version, attributes
125            ) 
126            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
127            "#,
128        ).await?;
129
130        let n = connection
131            .execute(
132                &insert_statement,
133                &[
134                    &device.id,
135                    &device.room_id,
136                    &device.password_hash,
137                    &device.device_type.to_string(),
138                    &device
139                        .traits
140                        .iter()
141                        .map(|t| t.to_string())
142                        .collect::<Vec<String>>(),
143                    &device.name,
144                    &device.will_push_state,
145                    &device.model,
146                    &device.hw_version.to_string(),
147                    &device.sw_version.to_string(),
148                    &device.attributes,
149                ],
150            )
151            .await?;
152
153        match n {
154            0 => Err(Error::NotModified),
155            1 => Ok(()),
156            _ => unreachable!(),
157        }
158    }
159
160    async fn add_user_structure(&self, user_structure: &UserStructure) -> Result<(), Error> {
161        let connection = self.pool.get().await?;
162        let insert_statement = connection
163            .prepare(
164                r#"
165            INSERT INTO user_structures (structure_id, user_id, is_manager) 
166            VALUES ($1, $2, $3)
167            "#,
168            )
169            .await?;
170
171        let n = connection
172            .execute(
173                &insert_statement,
174                &[
175                    &user_structure.structure_id,
176                    &user_structure.user_id,
177                    &user_structure.is_manager,
178                ],
179            )
180            .await?;
181
182        match n {
183            0 => Err(Error::NotModified),
184            1 => Ok(()),
185            _ => unreachable!(),
186        }
187    }
188
189    async fn add_user(&self, user: &User) -> Result<(), Error> {
190        let connection = self.pool.get().await?;
191        let check_exists_statement = connection.prepare(
192            r#"
193            SELECT 1
194            FROM users 
195            WHERE email = $1
196            OR username = $2
197            "#,
198        );
199
200        let insert_statement = connection.prepare(
201            r#"
202            INSERT INTO users(id, username, email, password_hash) 
203            VALUES ($1, $2, $3, $4)
204            "#,
205        );
206
207        let (check_exists_statement, insert_statement) =
208            tokio::join!(check_exists_statement, insert_statement);
209
210        let (check_exists_statement, insert_statement) =
211            (check_exists_statement?, insert_statement?);
212
213        let exists = connection
214            .query_opt(&check_exists_statement, &[&user.email, &user.username])
215            .await?
216            .is_some();
217
218        if exists {
219            return Err(Error::AlreadyExists);
220        }
221
222        let n = connection
223            .execute(
224                &insert_statement,
225                &[&user.id, &user.username, &user.email, &user.password_hash],
226            )
227            .await?;
228
229        match n {
230            0 => Err(Error::NotModified),
231            1 => Ok(()),
232            _ => unreachable!(),
233        }
234    }
235
236    async fn get_device(&self, device_id: &DeviceID) -> Result<Option<Device>, Error> {
237        const QUERY: &str = "
238            SELECT * 
239            FROM devices 
240            WHERE id = $1";
241        let connection = self.pool.get().await?;
242        let row = match connection.query_opt(QUERY, &[&device_id]).await? {
243            Some(row) => row,
244            None => return Ok(None),
245        };
246
247        let device = Device {
248            id: row.try_get("id")?,
249            password_hash: row.try_get("password_hash")?,
250            device_type: row.try_get("type")?,
251            traits: row.try_get("traits")?,
252            name: row.try_get("name")?,
253            will_push_state: row.try_get("will_push_state")?,
254            room_id: row.try_get("room_id")?,
255            model: row.try_get("model")?,
256            hw_version: Version::parse(row.try_get("hw_version")?).map_err(|err| {
257                InternalError::InvalidColumn {
258                    column: "hw_version",
259                    error: Box::new(err),
260                }
261            })?,
262            sw_version: Version::parse(row.try_get("sw_version")?).map_err(|err| {
263                InternalError::InvalidColumn {
264                    column: "sw_version",
265                    error: Box::new(err),
266                }
267            })?,
268            attributes: row.try_get("attributes")?,
269        };
270
271        Ok(Some(device))
272    }
273
274    async fn get_user_devices(&self, user_id: &UserID) -> Result<Vec<Device>, Error> {
275        let connection = self.pool.get().await?;
276        let query_statement = connection
277            .prepare(
278                r#"
279            SELECT *
280            FROM devices
281            WHERE room_id = (
282                SELECT id 
283                FROM rooms 
284                WHERE structure_id = (
285                    SELECT structure_id
286                    FROM user_structures
287                    WHERE user_id = $1
288                )
289            )
290            "#,
291            )
292            .await?;
293        let row = connection.query(&query_statement, &[&user_id]).await?;
294        let devices = row.iter().map(|row| {
295            Ok::<Device, Error>(Device {
296                id: row.try_get("id")?,
297                room_id: row.try_get("room_id")?,
298                password_hash: row.try_get("password_hash")?,
299                device_type: row.try_get("type")?,
300                traits: row.try_get("traits")?,
301                name: row.try_get("name")?,
302                will_push_state: row.try_get("will_push_state")?,
303                model: row.try_get("model")?,
304                hw_version: Version::parse(row.try_get("hw_version")?).map_err(|err| {
305                    InternalError::InvalidColumn {
306                        column: "hw_version",
307                        error: Box::new(err),
308                    }
309                })?,
310                sw_version: Version::parse(row.try_get("sw_version")?).map_err(|err| {
311                    InternalError::InvalidColumn {
312                        column: "sw_version",
313                        error: Box::new(err),
314                    }
315                })?,
316                attributes: row.try_get("attributes")?,
317            })
318        });
319        let devices: Result<Vec<Device>, Error> = devices.collect();
320        devices
321    }
322
323    async fn get_user(&self, user_id: &UserID) -> Result<Option<User>, Error> {
324        const QUERY: &str = "SELECT * FROM users WHERE id = $1";
325        let connection = self.pool.get().await?;
326        let row = match connection.query_opt(QUERY, &[&user_id]).await? {
327            Some(row) => row,
328            None => return Ok(None),
329        };
330        let user = User {
331            id: row.try_get("id")?,
332            username: row.try_get("username")?,
333            email: row.try_get("email")?,
334            password_hash: row.try_get("password_hash")?,
335        };
336
337        Ok(Some(user))
338    }
339
340    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>, Error> {
341        const QUERY: &str = "SELECT * FROM users WHERE email = $1";
342        let connection = self.pool.get().await?;
343        let row = match connection.query_opt(QUERY, &[&email.to_string()]).await? {
344            Some(row) => row,
345            None => return Ok(None),
346        };
347        let user = User {
348            id: row.try_get("id")?,
349            username: row.try_get("username")?,
350            email: row.try_get("email")?,
351            password_hash: row.try_get("password_hash")?,
352        };
353
354        Ok(Some(user))
355    }
356
357    async fn check_user_device_access(
358        &self,
359        user_id: &UserID,
360        device_id: &DeviceID,
361    ) -> Result<bool, Error> {
362        let connection = self.pool.get().await?;
363        let query_statement = connection
364            .prepare(
365                r#"
366            SELECT 1
367            FROM devices
368            WHERE id = $1
369            AND room_id = ( 
370                SELECT id 
371                FROM rooms 
372                WHERE structure_id = (
373                    SELECT structure_id
374                    FROM user_structures
375                    WHERE user_id = $2
376                )
377            )
378            "#,
379            )
380            .await?;
381        let result = connection
382            .query_opt(&query_statement, &[&device_id, &user_id])
383            .await?;
384
385        Ok(result.is_some())
386    }
387
388    async fn check_user_device_manager_access(
389        &self,
390        user_id: &UserID,
391        device_id: &DeviceID,
392    ) -> Result<bool, Error> {
393        let connection = self.pool.get().await?;
394        let query_statement = connection
395            .prepare(
396                r#"
397            SELECT 1
398            FROM devices
399            WHERE id = $1
400            AND room_id = ( 
401                SELECT id 
402                FROM rooms 
403                WHERE structure_id = (
404                    SELECT structure_id
405                    FROM user_structures
406                    WHERE user_id = $2
407                    AND is_manager = true
408                )
409            )
410            "#,
411            )
412            .await?;
413        let result = connection
414            .query_opt(&query_statement, &[&device_id, &user_id])
415            .await?;
416
417        Ok(result.is_some())
418    }
419
420    async fn check_user_admin(&self, user_id: &UserID) -> Result<bool, Error> {
421        let connection = self.pool.get().await?;
422        let query_statement = connection
423            .prepare(
424                r#"
425            SELECT 1
426            FROM admins
427            WHERE user_id = $1
428            "#,
429            )
430            .await?;
431
432        let result = connection.query_opt(&query_statement, &[&user_id]).await?;
433
434        Ok(result.is_some())
435    }
436
437    async fn delete_user(&self, user_id: &UserID) -> Result<(), Error> {
438        const QUERY: &str = "DELETE FROM users WHERE id = $1";
439        let connection = self.pool.get().await?;
440        let n = connection.execute(QUERY, &[&user_id]).await?;
441        match n {
442            0 => Err(Error::NotModified),
443            1 => Ok(()),
444            _ => unreachable!(),
445        }
446    }
447}