// chamber_core/
// postgres.rs

1use crate::core::Database;
2use crate::errors::DatabaseError;
3use crate::secrets::{EncryptedSecret, Secret};
4use crate::users::User;
5
6use sqlx::types::BigDecimal;
7use sqlx::PgPool;
8
9use crate::secrets::SecretInfo;
10
11#[derive(Clone)]
12pub struct Postgres(pub PgPool);
13
14impl Postgres {
15    pub fn from_pool(pool: PgPool) -> Self {
16        Self(pool)
17    }
18}
19
20#[async_trait::async_trait]
21impl Database for Postgres {
22    async fn create_secret(&self, new_secret: EncryptedSecret) -> Result<(), DatabaseError> {
23        // you might need to convert to Vec<u8> here for the Nonce
24        sqlx::query(
25            "INSERT INTO SECRETS 
26                    (key, nonce, ciphertext, tags, access_level, role_whitelist)
27                    VALUES
28                    ($1, $2, $3, $4, $5, $6)",
29        )
30        .bind(new_secret.key())
31        .bind(BigDecimal::from(new_secret.nonce.0 - 1))
32        .bind(new_secret.ciphertext())
33        .bind(new_secret.tags())
34        .bind(new_secret.access_level())
35        .bind(new_secret.role_whitelist())
36        .execute(&self.0)
37        .await?;
38
39        Ok(())
40    }
41
42    async fn view_all_secrets_admin(
43        &self,
44    ) -> Result<Vec<EncryptedSecret>, DatabaseError> {
45        let retrieved_keys = sqlx::query_as::<_, EncryptedSecret>(
46            "SELECT 
47            key, nonce, ciphertext, tags, access_level, role_whitelist
48            FROM secrets
49                ",
50        )
51        .fetch_all(&self.0)
52        .await?;
53
54        Ok(retrieved_keys)
55    }
56
57    async fn view_all_secrets(
58        &self,
59        user: User,
60        tag: Option<String>,
61    ) -> Result<Vec<SecretInfo>, DatabaseError> {
62        let retrieved_keys = sqlx::query_as::<_, SecretInfo>(
63            "SELECT 
64            key, tags FROM secrets WHERE (
65                    case when $1 is not null 
66                    then $1 = ANY(tags)
67                    else 1=1 
68                    end)
69                    AND $2 >= access_level
70                ",
71        )
72        .bind(tag)
73        .bind(user.access_level())
74        .fetch_all(&self.0)
75        .await?;
76
77        Ok(retrieved_keys)
78    }
79
80    async fn update_secret(
81        &self,
82        key: String,
83        secret: EncryptedSecret,
84    ) -> Result<(), DatabaseError> {
85        // Might need to convert back from Vec<u8> to Nonce<U12>
86        sqlx::query("UPDATE secrets SET tags = $1 WHERE key = $2")
87            .bind(secret.tags())
88            .bind(key)
89            .execute(&self.0)
90            .await?;
91
92        Ok(())
93    }
94
95    async fn rekey_all_secrets(&self, secrets: Vec<EncryptedSecret>) -> Result<(), DatabaseError> {
96        let transaction = self.0.try_begin().await?.unwrap();
97
98        for secret in secrets {
99            if let Err(e) = sqlx::query("UPDATE secrets SET ciphertext = $1 WHERE key = $2")
100            .bind(secret.ciphertext())
101            .bind(secret.key())
102            .execute(&self.0)
103            .await {
104                transaction.rollback().await?;
105                return Err(DatabaseError::SQLError(e));
106            }
107        }
108
109        transaction.commit().await?;
110        Ok(())
111
112    }
113
114    async fn view_secret(&self, user: User, key: String) -> Result<EncryptedSecret, DatabaseError> {
115        let retrieved_key = sqlx::query_as::<_, EncryptedSecret>(
116            "SELECT key, nonce, ciphertext, tags, access_level, role_whitelist FROM secrets WHERE key = $1 AND $2 >= access_level",
117        )
118        .bind(key)
119        .bind(user.access_level())
120        .fetch_one(&self.0)
121        .await?;
122
123        Ok(retrieved_key)
124    }
125
126    async fn view_secret_decrypted(
127        &self,
128        user: User,
129        key: String,
130    ) -> Result<Secret, DatabaseError> {
131        let retrieved_key = sqlx::query_as::<_, Secret>(
132            "SELECT nonce, ciphertext FROM secrets WHERE 
133            key = $1 
134            AND $2 >= access_level 
135            AND ( CASE 
136            WHEN ARRAY_LENGTH(role_whitelist, 1) > 0 
137            then role_whitelist && $3
138            else 1=1 end
139            )
140            ",
141        )
142        .bind(key)
143        .bind(user.access_level())
144        .bind(user.roles())
145        .fetch_one(&self.0)
146        .await?;
147
148        Ok(retrieved_key)
149    }
150
151    async fn delete_secret(&self, key: String) -> Result<(), DatabaseError> {
152        sqlx::query("DELETE FROM secrets WHERE key = $1")
153            .bind(key)
154            .execute(&self.0)
155            .await?;
156
157        Ok(())
158    }
159    async fn view_users(&self) -> Result<Vec<User>, DatabaseError> {
160        let query = sqlx::query_as::<_, User>("SELECT username, role FROM USERS")
161            .fetch_all(&self.0)
162            .await?;
163
164        Ok(query)
165    }
166
167    async fn get_user_from_name(&self, username: String) -> Result<User, DatabaseError> {
168        let query = sqlx::query_as::<_, User>(
169            "SELECT username, password, access_level, roles FROM USERS WHERE USERNAME = $1",
170        )
171        .bind(username)
172        .fetch_one(&self.0)
173        .await?;
174
175        Ok(query)
176    }
177
178    async fn get_user_from_password(&self, password: String) -> Result<User, DatabaseError> {
179        let query = sqlx::query_as::<_, User>(
180            "SELECT username, password, access_level, roles FROM users WHERE PASSWORD = $1",
181        )
182        .bind(password)
183        .fetch_one(&self.0)
184        .await?;
185
186        Ok(query)
187    }
188
189    async fn create_user(&self, user: User) -> Result<String, DatabaseError> {
190        let query = sqlx::query_as::<_, SingleValue>(
191            "INSERT INTO users
192            (username, password)
193            VALUES
194            ($1, $2) RETURNING PASSWORD",
195        )
196        .bind(user.username)
197        .bind(user.password)
198        .fetch_one(&self.0)
199        .await?;
200
201        Ok(query.0)
202    }
203
204    async fn update_user(&self, user: User) -> Result<(), DatabaseError> {
205        sqlx::query(
206            "
207            UPDATE users SET
208            access_level = $1,
209            roles = $2
210            where username = $3
211            ",
212        )
213        .bind(user.access_level())
214        .bind(user.clone().roles())
215        .bind(user.username)
216        .execute(&self.0)
217        .await?;
218
219        Ok(())
220    }
221
222    async fn delete_user(&self, name: String) -> Result<(), DatabaseError> {
223        sqlx::query("DELETE FROM USERS WHERE USERNAME = $1")
224            .bind(name)
225            .execute(&self.0)
226            .await?;
227
228        Ok(())
229    }
230}
231
232#[derive(sqlx::FromRow)]
233pub struct SingleValue(String);