sos_database/entity/
account.rs

1use crate::{
2    entity::{FolderEntity, FolderRecord},
3    Error, Result,
4};
5use async_sqlite::{
6    rusqlite::{
7        Connection, Error as SqlError, OptionalExtension, Row, Transaction,
8    },
9    Client,
10};
11use sos_core::{AccountId, PublicIdentity, UtcDateTime};
12use sos_vault::Vault;
13use sql_query_builder as sql;
14use std::ops::Deref;
15
/// Raw row of the `accounts` table.
///
/// Timestamps and the account identifier are held as plain strings;
/// convert into an [`AccountRecord`] to obtain typed values.
#[doc(hidden)]
#[derive(Debug, Default)]
pub struct AccountRow {
    /// Primary key of the row.
    pub row_id: i64,
    /// Creation time as an RFC3339 string.
    created_at: String,
    /// Last modification time as an RFC3339 string.
    modified_at: String,
    /// String form of the account identifier.
    identifier: String,
    /// Name of the account.
    name: String,
}
31
32impl AccountRow {
33    /// Create an account row for insertion.
34    pub fn new_insert(account_id: &AccountId, name: String) -> Result<Self> {
35        Ok(AccountRow {
36            identifier: account_id.to_string(),
37            name,
38            created_at: UtcDateTime::default().to_rfc3339()?,
39            modified_at: UtcDateTime::default().to_rfc3339()?,
40            ..Default::default()
41        })
42    }
43}
44
45impl<'a> TryFrom<&Row<'a>> for AccountRow {
46    type Error = SqlError;
47    fn try_from(row: &Row<'a>) -> std::result::Result<Self, Self::Error> {
48        Ok(AccountRow {
49            row_id: row.get(0)?,
50            created_at: row.get(1)?,
51            modified_at: row.get(2)?,
52            identifier: row.get(3)?,
53            name: row.get(4)?,
54        })
55    }
56}
57
58/// Account record from the database.
59#[derive(Debug)]
60pub struct AccountRecord {
61    /// Row identifier.
62    pub row_id: i64,
63    /// Created date and time.
64    pub created_at: UtcDateTime,
65    /// Modified date and time.
66    pub modified_at: UtcDateTime,
67    /// Account identity.
68    pub identity: PublicIdentity,
69}
70
71impl TryFrom<AccountRow> for AccountRecord {
72    type Error = Error;
73
74    fn try_from(value: AccountRow) -> std::result::Result<Self, Self::Error> {
75        let created_at = UtcDateTime::parse_rfc3339(&value.created_at)?;
76        let modified_at = UtcDateTime::parse_rfc3339(&value.modified_at)?;
77        let account_id: AccountId = value.identifier.parse()?;
78        Ok(AccountRecord {
79            row_id: value.row_id,
80            created_at,
81            modified_at,
82            identity: PublicIdentity::new(account_id, value.name),
83        })
84    }
85}
86
87/// Account entity.
88pub struct AccountEntity<'conn, C>
89where
90    C: Deref<Target = Connection>,
91{
92    conn: &'conn C,
93}
94
95impl<'conn> AccountEntity<'conn, Box<Connection>> {
96    /// Liat all accounts.
97    pub async fn list_all_accounts(
98        client: &Client,
99    ) -> Result<Vec<AccountRecord>> {
100        let account_rows = client
101            .conn_and_then(move |conn| {
102                let account = AccountEntity::new(&conn);
103                account.list_accounts()
104            })
105            .await?;
106
107        let mut accounts = Vec::new();
108        for row in account_rows {
109            accounts.push(row.try_into()?);
110        }
111        Ok(accounts)
112    }
113
114    /// Find an account and login folder.
115    pub async fn find_account_with_login(
116        client: &Client,
117        account_id: &AccountId,
118    ) -> Result<(AccountRecord, FolderRecord)> {
119        let (account, folder_row) =
120            Self::find_account_with_login_optional(client, account_id)
121                .await?;
122
123        let account_id = account.row_id;
124        Ok((
125            account,
126            folder_row.ok_or_else(|| Error::NoLoginFolder(account_id))?,
127        ))
128    }
129
130    /// Find an account and optional login folder.
131    pub async fn find_account_with_login_optional(
132        client: &Client,
133        account_id: &AccountId,
134    ) -> Result<(AccountRecord, Option<FolderRecord>)> {
135        let account_id = *account_id;
136        let (account_row, folder_row) = client
137            .conn_and_then(move |conn| {
138                let account = AccountEntity::new(&conn);
139                let account_row = account.find_one(&account_id)?;
140                let folders = FolderEntity::new(&conn);
141                let folder_row =
142                    folders.find_login_folder_optional(account_row.row_id)?;
143                Ok::<_, Error>((account_row, folder_row))
144            })
145            .await?;
146
147        let login_folder = if let Some(folder_row) = folder_row {
148            Some(FolderRecord::from_row(folder_row).await?)
149        } else {
150            None
151        };
152        Ok((account_row.try_into()?, login_folder))
153    }
154}
155
156impl<'conn> AccountEntity<'conn, Transaction<'conn>> {
157    /// Upsert the login folder.
158    pub async fn upsert_login_folder(
159        client: &Client,
160        account_id: &AccountId,
161        vault: &Vault,
162    ) -> Result<(AccountRecord, i64)> {
163        // Check if we already have a login folder
164        let (account, folder) =
165            AccountEntity::find_account_with_login_optional(
166                client, account_id,
167            )
168            .await?;
169
170        // TODO: folder creation and join should be merged into a single
171        // TODO: transaction
172
173        // Create or update the folder and secrets
174        let (folder_row_id, _) = FolderEntity::upsert_folder_and_secrets(
175            client,
176            account.row_id,
177            vault,
178        )
179        .await?;
180
181        let account_row_id = account.row_id;
182
183        // Update or insert the join
184        if folder.is_some() {
185            client
186                .conn(move |conn| {
187                    let account_entity = AccountEntity::new(&conn);
188                    account_entity
189                        .update_login_folder(account_row_id, folder_row_id)
190                })
191                .await?;
192        } else {
193            client
194                .conn(move |conn| {
195                    let account_entity = AccountEntity::new(&conn);
196                    account_entity
197                        .insert_login_folder(account_row_id, folder_row_id)
198                })
199                .await?;
200        }
201
202        Ok((account, folder_row_id))
203    }
204}
205
206impl<'conn, C> AccountEntity<'conn, C>
207where
208    C: Deref<Target = Connection>,
209{
210    /// Create a new account entity.
211    pub fn new(conn: &'conn C) -> Self {
212        Self { conn }
213    }
214
215    fn account_select_columns(&self, sql: sql::Select) -> sql::Select {
216        sql.select(
217            r#"
218                account_id,
219                created_at,
220                modified_at,
221                identifier,
222                name
223            "#,
224        )
225    }
226
227    /// Find an account in the database.
228    pub fn find_one(
229        &self,
230        account_id: &AccountId,
231    ) -> std::result::Result<AccountRow, SqlError> {
232        let query = self
233            .account_select_columns(sql::Select::new())
234            .from("accounts")
235            .where_clause("identifier = ?1");
236        let mut stmt = self.conn.prepare_cached(&query.as_string())?;
237        Ok(stmt
238            .query_row([account_id.to_string()], |row| Ok(row.try_into()?))?)
239    }
240
241    /// Find an optional account in the database.
242    pub fn find_optional(
243        &self,
244        account_id: &AccountId,
245    ) -> std::result::Result<Option<AccountRow>, SqlError> {
246        let query = self
247            .account_select_columns(sql::Select::new())
248            .from("accounts")
249            .where_clause("identifier = ?1");
250        let mut stmt = self.conn.prepare_cached(&query.as_string())?;
251        Ok(stmt
252            .query_row([account_id.to_string()], |row| Ok(row.try_into()?))
253            .optional()?)
254    }
255
256    /// List accounts.
257    pub fn list_accounts(&self) -> Result<Vec<AccountRow>> {
258        let query = self
259            .account_select_columns(sql::Select::new())
260            .from("accounts");
261
262        let mut stmt = self.conn.prepare_cached(&query.as_string())?;
263
264        fn convert_row(row: &Row<'_>) -> Result<AccountRow> {
265            Ok(row.try_into()?)
266        }
267
268        let rows = stmt.query_and_then([], |row| {
269            Ok::<_, crate::Error>(convert_row(row)?)
270        })?;
271        let mut accounts = Vec::new();
272        for row in rows {
273            accounts.push(row?);
274        }
275        Ok(accounts)
276    }
277
278    /// Create the account entity in the database.
279    pub fn insert(
280        &self,
281        row: &AccountRow,
282    ) -> std::result::Result<i64, SqlError> {
283        let query = sql::Insert::new()
284            .insert_into(
285                "accounts (created_at, modified_at, identifier, name)",
286            )
287            .values("(?1, ?2, ?3, ?4)");
288        self.conn.execute(
289            &query.as_string(),
290            (
291                &row.created_at,
292                &row.modified_at,
293                &row.identifier,
294                &row.name,
295            ),
296        )?;
297        Ok(self.conn.last_insert_rowid())
298    }
299
300    /// Create the join for the account login folder.
301    pub fn insert_login_folder(
302        &self,
303        account_id: i64,
304        folder_id: i64,
305    ) -> std::result::Result<i64, SqlError> {
306        let query = sql::Insert::new()
307            .insert_into("account_login_folder (account_id, folder_id)")
308            .values("(?1, ?2)");
309        self.conn
310            .execute(&query.as_string(), [account_id, folder_id])?;
311        Ok(self.conn.last_insert_rowid())
312    }
313
314    /// Update the join for an account login folder.
315    pub fn update_login_folder(
316        &self,
317        account_id: i64,
318        folder_id: i64,
319    ) -> std::result::Result<(), SqlError> {
320        let query = sql::Update::new()
321            .update("account_login_folder")
322            .set("folder_id = ?2")
323            .where_clause("account_id = ?1");
324        self.conn
325            .execute(&query.as_string(), [account_id, folder_id])?;
326        Ok(())
327    }
328
329    /// Create the join for the account device folder.
330    pub fn insert_device_folder(
331        &self,
332        account_id: i64,
333        folder_id: i64,
334    ) -> std::result::Result<i64, SqlError> {
335        let query = sql::Insert::new()
336            .insert_into("account_device_folder (account_id, folder_id)")
337            .values("(?1, ?2)");
338        self.conn
339            .execute(&query.as_string(), [account_id, folder_id])?;
340        Ok(self.conn.last_insert_rowid())
341    }
342
343    /// Rename the account.
344    pub fn rename_account(&self, account_id: i64, name: &str) -> Result<()> {
345        let modified_at = UtcDateTime::default().to_rfc3339()?;
346        let query = sql::Update::new()
347            .update("accounts")
348            .set("name = ?1, modified_at = ?2")
349            .where_clause("account_id = ?3");
350        let mut stmt = self.conn.prepare_cached(&query.as_string())?;
351        stmt.execute((name, modified_at, account_id))?;
352        Ok(())
353    }
354
355    /// Delete the account from the database.
356    pub fn delete_account(
357        &self,
358        account_id: &AccountId,
359    ) -> std::result::Result<(), SqlError> {
360        let account_row = self.find_one(account_id)?;
361        let query = sql::Delete::new()
362            .delete_from("accounts")
363            .where_clause("account_id = ?1");
364        self.conn
365            .execute(&query.as_string(), [account_row.row_id])?;
366        Ok(())
367    }
368}