1use crate::Error;
2
3use async_trait::async_trait;
4use deadpool_postgres::Pool;
5use houseflow_config::postgres::Config;
6use houseflow_types::{Device, DeviceID, User, UserID, UserStructure};
7use semver::Version;
8use tokio_postgres::NoTls;
9
10use refinery::embed_migrations;
11embed_migrations!("migrations");
12
13#[derive(Debug, thiserror::Error)]
14pub enum InternalError {
15 #[error("Error when sending query: `{0}`")]
16 QueryError(#[from] tokio_postgres::Error),
17
18 #[error("pool error: {0}")]
19 PoolError(#[from] deadpool_postgres::PoolError),
20
21 #[error("Column `{column}` is invalid: `{error}`")]
22 InvalidColumn {
23 column: &'static str,
24 error: Box<dyn std::error::Error + Send + Sync>,
25 },
26
27 #[error("Error when running migrations: `{0}`")]
28 MigrationError(#[from] refinery::Error),
29}
30
31use crate::DatabaseInternalError;
32
33impl DatabaseInternalError for InternalError {}
34impl DatabaseInternalError for deadpool_postgres::PoolError {}
35impl DatabaseInternalError for tokio_postgres::Error {}
36impl DatabaseInternalError for refinery::Error {}
37
38#[derive(Clone)]
39pub struct Database {
40 pool: Pool,
41}
42
43impl Database {
44 fn get_pool_config(cfg: &Config) -> deadpool_postgres::Config {
45 let mut dpcfg = deadpool_postgres::Config::new();
46 dpcfg.user = Some(cfg.user.to_string());
47 dpcfg.password = Some(cfg.password.to_string());
48 dpcfg.host = Some(cfg.address.ip().to_string());
49 dpcfg.port = Some(cfg.address.port());
50 dpcfg.dbname = Some(cfg.database_name.to_string());
51 dpcfg
52 }
53
54 pub async fn new(opts: &Config) -> Result<Self, Error> {
57 use std::ops::DerefMut;
58
59 let pool_config = Self::get_pool_config(&opts);
60 let pool = pool_config
61 .create_pool(NoTls)
62 .expect("invalid pool configuration");
63 let mut obj = pool.get().await?;
64 let client = obj.deref_mut().deref_mut();
65 migrations::runner().run_async(client).await?;
66 Ok(Self { pool })
67 }
68}
69
70#[async_trait]
71impl crate::Database for Database {
72 async fn add_structure(&self, structure: &houseflow_types::Structure) -> Result<(), Error> {
73 let connection = self.pool.get().await?;
74 let insert_statement = connection
75 .prepare(
76 r#"
77 INSERT INTO structures (id, name)
78 VALUES ($1, $2)
79 "#,
80 )
81 .await?;
82
83 let n = connection
84 .execute(&insert_statement, &[&structure.id, &structure.name])
85 .await?;
86
87 match n {
88 0 => Err(Error::NotModified),
89 1 => Ok(()),
90 _ => unreachable!(),
91 }
92 }
93
94 async fn add_room(&self, room: &houseflow_types::Room) -> Result<(), Error> {
95 let connection = self.pool.get().await?;
96 let insert_statement = connection
97 .prepare(
98 r#"
99 INSERT INTO rooms (id, structure_id, name)
100 VALUES ($1, $2, $3)
101 "#,
102 )
103 .await?;
104
105 let n = connection
106 .execute(
107 &insert_statement,
108 &[&room.id, &room.structure_id, &room.name],
109 )
110 .await?;
111
112 match n {
113 0 => Err(Error::NotModified),
114 1 => Ok(()),
115 _ => unreachable!(),
116 }
117 }
118
119 async fn add_device(&self, device: &Device) -> Result<(), Error> {
120 let connection = self.pool.get().await?;
121 let insert_statement = connection.prepare(
122 r#"
123 INSERT INTO devices(
124 id, room_id, password_hash, type, traits, name, will_push_state, model, hw_version, sw_version, attributes
125 )
126 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
127 "#,
128 ).await?;
129
130 let n = connection
131 .execute(
132 &insert_statement,
133 &[
134 &device.id,
135 &device.room_id,
136 &device.password_hash,
137 &device.device_type.to_string(),
138 &device
139 .traits
140 .iter()
141 .map(|t| t.to_string())
142 .collect::<Vec<String>>(),
143 &device.name,
144 &device.will_push_state,
145 &device.model,
146 &device.hw_version.to_string(),
147 &device.sw_version.to_string(),
148 &device.attributes,
149 ],
150 )
151 .await?;
152
153 match n {
154 0 => Err(Error::NotModified),
155 1 => Ok(()),
156 _ => unreachable!(),
157 }
158 }
159
160 async fn add_user_structure(&self, user_structure: &UserStructure) -> Result<(), Error> {
161 let connection = self.pool.get().await?;
162 let insert_statement = connection
163 .prepare(
164 r#"
165 INSERT INTO user_structures (structure_id, user_id, is_manager)
166 VALUES ($1, $2, $3)
167 "#,
168 )
169 .await?;
170
171 let n = connection
172 .execute(
173 &insert_statement,
174 &[
175 &user_structure.structure_id,
176 &user_structure.user_id,
177 &user_structure.is_manager,
178 ],
179 )
180 .await?;
181
182 match n {
183 0 => Err(Error::NotModified),
184 1 => Ok(()),
185 _ => unreachable!(),
186 }
187 }
188
189 async fn add_user(&self, user: &User) -> Result<(), Error> {
190 let connection = self.pool.get().await?;
191 let check_exists_statement = connection.prepare(
192 r#"
193 SELECT 1
194 FROM users
195 WHERE email = $1
196 OR username = $2
197 "#,
198 );
199
200 let insert_statement = connection.prepare(
201 r#"
202 INSERT INTO users(id, username, email, password_hash)
203 VALUES ($1, $2, $3, $4)
204 "#,
205 );
206
207 let (check_exists_statement, insert_statement) =
208 tokio::join!(check_exists_statement, insert_statement);
209
210 let (check_exists_statement, insert_statement) =
211 (check_exists_statement?, insert_statement?);
212
213 let exists = connection
214 .query_opt(&check_exists_statement, &[&user.email, &user.username])
215 .await?
216 .is_some();
217
218 if exists {
219 return Err(Error::AlreadyExists);
220 }
221
222 let n = connection
223 .execute(
224 &insert_statement,
225 &[&user.id, &user.username, &user.email, &user.password_hash],
226 )
227 .await?;
228
229 match n {
230 0 => Err(Error::NotModified),
231 1 => Ok(()),
232 _ => unreachable!(),
233 }
234 }
235
236 async fn get_device(&self, device_id: &DeviceID) -> Result<Option<Device>, Error> {
237 const QUERY: &str = "
238 SELECT *
239 FROM devices
240 WHERE id = $1";
241 let connection = self.pool.get().await?;
242 let row = match connection.query_opt(QUERY, &[&device_id]).await? {
243 Some(row) => row,
244 None => return Ok(None),
245 };
246
247 let device = Device {
248 id: row.try_get("id")?,
249 password_hash: row.try_get("password_hash")?,
250 device_type: row.try_get("type")?,
251 traits: row.try_get("traits")?,
252 name: row.try_get("name")?,
253 will_push_state: row.try_get("will_push_state")?,
254 room_id: row.try_get("room_id")?,
255 model: row.try_get("model")?,
256 hw_version: Version::parse(row.try_get("hw_version")?).map_err(|err| {
257 InternalError::InvalidColumn {
258 column: "hw_version",
259 error: Box::new(err),
260 }
261 })?,
262 sw_version: Version::parse(row.try_get("sw_version")?).map_err(|err| {
263 InternalError::InvalidColumn {
264 column: "sw_version",
265 error: Box::new(err),
266 }
267 })?,
268 attributes: row.try_get("attributes")?,
269 };
270
271 Ok(Some(device))
272 }
273
274 async fn get_user_devices(&self, user_id: &UserID) -> Result<Vec<Device>, Error> {
275 let connection = self.pool.get().await?;
276 let query_statement = connection
277 .prepare(
278 r#"
279 SELECT *
280 FROM devices
281 WHERE room_id = (
282 SELECT id
283 FROM rooms
284 WHERE structure_id = (
285 SELECT structure_id
286 FROM user_structures
287 WHERE user_id = $1
288 )
289 )
290 "#,
291 )
292 .await?;
293 let row = connection.query(&query_statement, &[&user_id]).await?;
294 let devices = row.iter().map(|row| {
295 Ok::<Device, Error>(Device {
296 id: row.try_get("id")?,
297 room_id: row.try_get("room_id")?,
298 password_hash: row.try_get("password_hash")?,
299 device_type: row.try_get("type")?,
300 traits: row.try_get("traits")?,
301 name: row.try_get("name")?,
302 will_push_state: row.try_get("will_push_state")?,
303 model: row.try_get("model")?,
304 hw_version: Version::parse(row.try_get("hw_version")?).map_err(|err| {
305 InternalError::InvalidColumn {
306 column: "hw_version",
307 error: Box::new(err),
308 }
309 })?,
310 sw_version: Version::parse(row.try_get("sw_version")?).map_err(|err| {
311 InternalError::InvalidColumn {
312 column: "sw_version",
313 error: Box::new(err),
314 }
315 })?,
316 attributes: row.try_get("attributes")?,
317 })
318 });
319 let devices: Result<Vec<Device>, Error> = devices.collect();
320 devices
321 }
322
323 async fn get_user(&self, user_id: &UserID) -> Result<Option<User>, Error> {
324 const QUERY: &str = "SELECT * FROM users WHERE id = $1";
325 let connection = self.pool.get().await?;
326 let row = match connection.query_opt(QUERY, &[&user_id]).await? {
327 Some(row) => row,
328 None => return Ok(None),
329 };
330 let user = User {
331 id: row.try_get("id")?,
332 username: row.try_get("username")?,
333 email: row.try_get("email")?,
334 password_hash: row.try_get("password_hash")?,
335 };
336
337 Ok(Some(user))
338 }
339
340 async fn get_user_by_email(&self, email: &str) -> Result<Option<User>, Error> {
341 const QUERY: &str = "SELECT * FROM users WHERE email = $1";
342 let connection = self.pool.get().await?;
343 let row = match connection.query_opt(QUERY, &[&email.to_string()]).await? {
344 Some(row) => row,
345 None => return Ok(None),
346 };
347 let user = User {
348 id: row.try_get("id")?,
349 username: row.try_get("username")?,
350 email: row.try_get("email")?,
351 password_hash: row.try_get("password_hash")?,
352 };
353
354 Ok(Some(user))
355 }
356
357 async fn check_user_device_access(
358 &self,
359 user_id: &UserID,
360 device_id: &DeviceID,
361 ) -> Result<bool, Error> {
362 let connection = self.pool.get().await?;
363 let query_statement = connection
364 .prepare(
365 r#"
366 SELECT 1
367 FROM devices
368 WHERE id = $1
369 AND room_id = (
370 SELECT id
371 FROM rooms
372 WHERE structure_id = (
373 SELECT structure_id
374 FROM user_structures
375 WHERE user_id = $2
376 )
377 )
378 "#,
379 )
380 .await?;
381 let result = connection
382 .query_opt(&query_statement, &[&device_id, &user_id])
383 .await?;
384
385 Ok(result.is_some())
386 }
387
388 async fn check_user_device_manager_access(
389 &self,
390 user_id: &UserID,
391 device_id: &DeviceID,
392 ) -> Result<bool, Error> {
393 let connection = self.pool.get().await?;
394 let query_statement = connection
395 .prepare(
396 r#"
397 SELECT 1
398 FROM devices
399 WHERE id = $1
400 AND room_id = (
401 SELECT id
402 FROM rooms
403 WHERE structure_id = (
404 SELECT structure_id
405 FROM user_structures
406 WHERE user_id = $2
407 AND is_manager = true
408 )
409 )
410 "#,
411 )
412 .await?;
413 let result = connection
414 .query_opt(&query_statement, &[&device_id, &user_id])
415 .await?;
416
417 Ok(result.is_some())
418 }
419
420 async fn check_user_admin(&self, user_id: &UserID) -> Result<bool, Error> {
421 let connection = self.pool.get().await?;
422 let query_statement = connection
423 .prepare(
424 r#"
425 SELECT 1
426 FROM admins
427 WHERE user_id = $1
428 "#,
429 )
430 .await?;
431
432 let result = connection.query_opt(&query_statement, &[&user_id]).await?;
433
434 Ok(result.is_some())
435 }
436
437 async fn delete_user(&self, user_id: &UserID) -> Result<(), Error> {
438 const QUERY: &str = "DELETE FROM users WHERE id = $1";
439 let connection = self.pool.get().await?;
440 let n = connection.execute(QUERY, &[&user_id]).await?;
441 match n {
442 0 => Err(Error::NotModified),
443 1 => Ok(()),
444 _ => unreachable!(),
445 }
446 }
447}