1use std::str::FromStr;
2
3use async_trait::async_trait;
4use atuin_common::{
5 record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus},
6 utils::crypto_random_string,
7};
8use atuin_server_database::{
9 Database, DbError, DbResult, DbSettings,
10 models::{History, NewHistory, NewSession, NewUser, Session, User},
11};
12use futures_util::TryStreamExt;
13use sqlx::{
14 Row,
15 sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
16 types::Uuid,
17};
18use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
19use tracing::instrument;
20use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
21
22mod wrappers;
23
24#[derive(Clone)]
25pub struct Sqlite {
26 pool: sqlx::Pool<sqlx::sqlite::Sqlite>,
27}
28
29fn fix_error(error: sqlx::Error) -> DbError {
30 match error {
31 sqlx::Error::RowNotFound => DbError::NotFound,
32 error => DbError::Other(error.into()),
33 }
34}
35
36#[async_trait]
37impl Database for Sqlite {
38 async fn new(settings: &DbSettings) -> DbResult<Self> {
39 let opts = SqliteConnectOptions::from_str(&settings.db_uri)
40 .map_err(fix_error)?
41 .journal_mode(SqliteJournalMode::Wal)
42 .create_if_missing(true);
43
44 let pool = SqlitePoolOptions::new()
45 .connect_with(opts)
46 .await
47 .map_err(fix_error)?;
48
49 sqlx::migrate!("./migrations")
50 .run(&pool)
51 .await
52 .map_err(|error| DbError::Other(error.into()))?;
53
54 Ok(Self { pool })
55 }
56
57 #[instrument(skip_all)]
58 async fn get_session(&self, token: &str) -> DbResult<Session> {
59 sqlx::query_as("select id, user_id, token from sessions where token = $1")
60 .bind(token)
61 .fetch_one(&self.pool)
62 .await
63 .map_err(fix_error)
64 .map(|DbSession(session)| session)
65 }
66
67 #[instrument(skip_all)]
68 async fn get_session_user(&self, token: &str) -> DbResult<User> {
69 sqlx::query_as(
70 "select users.id, users.username, users.email, users.password, users.verified_at from users
71 inner join sessions
72 on users.id = sessions.user_id
73 and sessions.token = $1",
74 )
75 .bind(token)
76 .fetch_one(&self.pool)
77 .await
78 .map_err(fix_error)
79 .map(|DbUser(user)| user)
80 }
81
82 #[instrument(skip_all)]
83 async fn add_session(&self, session: &NewSession) -> DbResult<()> {
84 let token: &str = &session.token;
85
86 sqlx::query(
87 "insert into sessions
88 (user_id, token)
89 values($1, $2)",
90 )
91 .bind(session.user_id)
92 .bind(token)
93 .execute(&self.pool)
94 .await
95 .map_err(fix_error)?;
96
97 Ok(())
98 }
99
100 #[instrument(skip_all)]
101 async fn get_user(&self, username: &str) -> DbResult<User> {
102 sqlx::query_as(
103 "select id, username, email, password, verified_at from users where username = $1",
104 )
105 .bind(username)
106 .fetch_one(&self.pool)
107 .await
108 .map_err(fix_error)
109 .map(|DbUser(user)| user)
110 }
111
112 #[instrument(skip_all)]
113 async fn get_user_session(&self, u: &User) -> DbResult<Session> {
114 sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
115 .bind(u.id)
116 .fetch_one(&self.pool)
117 .await
118 .map_err(fix_error)
119 .map(|DbSession(session)| session)
120 }
121
122 #[instrument(skip_all)]
123 async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
124 let email: &str = &user.email;
125 let username: &str = &user.username;
126 let password: &str = &user.password;
127
128 let res: (i64,) = sqlx::query_as(
129 "insert into users
130 (username, email, password)
131 values($1, $2, $3)
132 returning id",
133 )
134 .bind(username)
135 .bind(email)
136 .bind(password)
137 .fetch_one(&self.pool)
138 .await
139 .map_err(fix_error)?;
140
141 Ok(res.0)
142 }
143
144 #[instrument(skip_all)]
145 async fn user_verified(&self, id: i64) -> DbResult<bool> {
146 let res: (bool,) =
147 sqlx::query_as("select verified_at is not null from users where id = $1")
148 .bind(id)
149 .fetch_one(&self.pool)
150 .await
151 .map_err(fix_error)?;
152
153 Ok(res.0)
154 }
155
156 #[instrument(skip_all)]
157 async fn verify_user(&self, id: i64) -> DbResult<()> {
158 sqlx::query(
159 "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
160 )
161 .bind(id)
162 .execute(&self.pool)
163 .await
164 .map_err(fix_error)?;
165
166 Ok(())
167 }
168
169 #[instrument(skip_all)]
170 async fn user_verification_token(&self, id: i64) -> DbResult<String> {
171 const TOKEN_VALID_MINUTES: i64 = 15;
172
173 let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
175 "select token, valid_until from user_verification_token where user_id = $1",
176 )
177 .bind(id)
178 .fetch_optional(&self.pool)
179 .await
180 .map_err(fix_error)?;
181
182 let token = if let Some((token, valid_until)) = token {
183 if valid_until > time::OffsetDateTime::now_utc() {
185 token
186 } else {
187 let token = crypto_random_string::<24>();
189
190 sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
191 .bind(id)
192 .bind(&token)
193 .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
194 .execute(&self.pool)
195 .await
196 .map_err(fix_error)?;
197
198 token
199 }
200 } else {
201 let token = crypto_random_string::<24>();
203
204 sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
205 .bind(id)
206 .bind(&token)
207 .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
208 .execute(&self.pool)
209 .await
210 .map_err(fix_error)?;
211
212 token
213 };
214
215 Ok(token)
216 }
217
218 #[instrument(skip_all)]
219 async fn update_user_password(&self, user: &User) -> DbResult<()> {
220 sqlx::query(
221 "update users
222 set password = $1
223 where id = $2",
224 )
225 .bind(&user.password)
226 .bind(user.id)
227 .execute(&self.pool)
228 .await
229 .map_err(fix_error)?;
230
231 Ok(())
232 }
233
234 #[instrument(skip_all)]
235 async fn total_history(&self) -> DbResult<i64> {
236 let res: (i64,) = sqlx::query_as("select count(1) from history")
237 .fetch_optional(&self.pool)
238 .await
239 .map_err(fix_error)?
240 .unwrap_or((0,));
241
242 Ok(res.0)
243 }
244
245 #[instrument(skip_all)]
246 async fn count_history(&self, user: &User) -> DbResult<i64> {
247 let res: (i64,) = sqlx::query_as(
252 "select count(1) from history
253 where user_id = $1",
254 )
255 .bind(user.id)
256 .fetch_one(&self.pool)
257 .await
258 .map_err(fix_error)?;
259
260 Ok(res.0)
261 }
262
263 #[instrument(skip_all)]
264 async fn count_history_cached(&self, _user: &User) -> DbResult<i64> {
265 Err(DbError::NotFound)
266 }
267
268 #[instrument(skip_all)]
269 async fn delete_user(&self, u: &User) -> DbResult<()> {
270 sqlx::query("delete from sessions where user_id = $1")
271 .bind(u.id)
272 .execute(&self.pool)
273 .await
274 .map_err(fix_error)?;
275
276 sqlx::query("delete from users where id = $1")
277 .bind(u.id)
278 .execute(&self.pool)
279 .await
280 .map_err(fix_error)?;
281
282 sqlx::query("delete from history where user_id = $1")
283 .bind(u.id)
284 .execute(&self.pool)
285 .await
286 .map_err(fix_error)?;
287
288 Ok(())
289 }
290
291 async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
292 sqlx::query(
293 "update history
294 set deleted_at = $3
295 where user_id = $1
296 and client_id = $2
297 and deleted_at is null", )
299 .bind(user.id)
300 .bind(id)
301 .bind(time::OffsetDateTime::now_utc())
302 .fetch_all(&self.pool)
303 .await
304 .map_err(fix_error)?;
305
306 Ok(())
307 }
308
309 #[instrument(skip_all)]
310 async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
311 let res = sqlx::query(
316 "select client_id from history
317 where user_id = $1
318 and deleted_at is not null",
319 )
320 .bind(user.id)
321 .fetch_all(&self.pool)
322 .await
323 .map_err(fix_error)?;
324
325 let res = res.iter().map(|row| row.get("client_id")).collect();
326
327 Ok(res)
328 }
329
330 async fn delete_store(&self, user: &User) -> DbResult<()> {
331 sqlx::query(
332 "delete from store
333 where user_id = $1",
334 )
335 .bind(user.id)
336 .execute(&self.pool)
337 .await
338 .map_err(fix_error)?;
339
340 Ok(())
341 }
342
343 #[instrument(skip_all)]
344 async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
345 let mut tx = self.pool.begin().await.map_err(fix_error)?;
346
347 for i in records {
348 let id = atuin_common::utils::uuid_v7();
349
350 sqlx::query(
351 "insert into store
352 (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
353 values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
354 on conflict do nothing
355 ",
356 )
357 .bind(id)
358 .bind(i.id)
359 .bind(i.host.id)
360 .bind(i.idx as i64)
361 .bind(i.timestamp as i64) .bind(&i.version)
363 .bind(&i.tag)
364 .bind(&i.data.data)
365 .bind(&i.data.content_encryption_key)
366 .bind(user.id)
367 .execute(&mut *tx)
368 .await
369 .map_err(fix_error)?;
370 }
371
372 tx.commit().await.map_err(fix_error)?;
373
374 Ok(())
375 }
376
377 #[instrument(skip_all)]
378 async fn next_records(
379 &self,
380 user: &User,
381 host: HostId,
382 tag: String,
383 start: Option<RecordIdx>,
384 count: u64,
385 ) -> DbResult<Vec<Record<EncryptedData>>> {
386 tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
387 let start = start.unwrap_or(0);
388
389 let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
390 "select client_id, host, idx, timestamp, version, tag, data, cek from store
391 where user_id = $1
392 and tag = $2
393 and host = $3
394 and idx >= $4
395 order by idx asc
396 limit $5",
397 )
398 .bind(user.id)
399 .bind(tag.clone())
400 .bind(host)
401 .bind(start as i64)
402 .bind(count as i64)
403 .fetch_all(&self.pool)
404 .await
405 .map_err(fix_error);
406
407 let ret = match records {
408 Ok(records) => {
409 let records: Vec<Record<EncryptedData>> = records
410 .into_iter()
411 .map(|f| {
412 let record: Record<EncryptedData> = f.into();
413 record
414 })
415 .collect();
416
417 records
418 }
419 Err(DbError::NotFound) => {
420 tracing::debug!("no records found in store: {:?}/{}", host, tag);
421 return Ok(vec![]);
422 }
423 Err(e) => return Err(e),
424 };
425
426 Ok(ret)
427 }
428
429 async fn status(&self, user: &User) -> DbResult<RecordStatus> {
430 const STATUS_SQL: &str =
431 "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
432
433 let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
434 .bind(user.id)
435 .fetch_all(&self.pool)
436 .await
437 .map_err(fix_error)?;
438
439 let mut status = RecordStatus::new();
440
441 for i in res {
442 status.set_raw(HostId(i.0), i.1, i.2 as u64);
443 }
444
445 Ok(status)
446 }
447
448 #[instrument(skip_all)]
449 async fn count_history_range(
450 &self,
451 user: &User,
452 range: std::ops::Range<time::OffsetDateTime>,
453 ) -> DbResult<i64> {
454 let res: (i64,) = sqlx::query_as(
455 "select count(1) from history
456 where user_id = $1
457 and timestamp >= $2::date
458 and timestamp < $3::date",
459 )
460 .bind(user.id)
461 .bind(into_utc(range.start))
462 .bind(into_utc(range.end))
463 .fetch_one(&self.pool)
464 .await
465 .map_err(fix_error)?;
466
467 Ok(res.0)
468 }
469
470 #[instrument(skip_all)]
471 async fn list_history(
472 &self,
473 user: &User,
474 created_after: time::OffsetDateTime,
475 since: time::OffsetDateTime,
476 host: &str,
477 page_size: i64,
478 ) -> DbResult<Vec<History>> {
479 let res = sqlx::query_as(
480 "select id, client_id, user_id, hostname, timestamp, data, created_at from history
481 where user_id = $1
482 and hostname != $2
483 and created_at >= $3
484 and timestamp >= $4
485 order by timestamp asc
486 limit $5",
487 )
488 .bind(user.id)
489 .bind(host)
490 .bind(into_utc(created_after))
491 .bind(into_utc(since))
492 .bind(page_size)
493 .fetch(&self.pool)
494 .map_ok(|DbHistory(h)| h)
495 .try_collect()
496 .await
497 .map_err(fix_error)?;
498
499 Ok(res)
500 }
501
502 #[instrument(skip_all)]
503 async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
504 let mut tx = self.pool.begin().await.map_err(fix_error)?;
505
506 for i in history {
507 let client_id: &str = &i.client_id;
508 let hostname: &str = &i.hostname;
509 let data: &str = &i.data;
510
511 sqlx::query(
512 "insert into history
513 (client_id, user_id, hostname, timestamp, data)
514 values ($1, $2, $3, $4, $5)
515 on conflict do nothing
516 ",
517 )
518 .bind(client_id)
519 .bind(i.user_id)
520 .bind(hostname)
521 .bind(i.timestamp)
522 .bind(data)
523 .execute(&mut *tx)
524 .await
525 .map_err(fix_error)?;
526 }
527
528 tx.commit().await.map_err(fix_error)?;
529
530 Ok(())
531 }
532
533 #[instrument(skip_all)]
534 async fn oldest_history(&self, user: &User) -> DbResult<History> {
535 sqlx::query_as(
536 "select id, client_id, user_id, hostname, timestamp, data, created_at from history
537 where user_id = $1
538 order by timestamp asc
539 limit 1",
540 )
541 .bind(user.id)
542 .fetch_one(&self.pool)
543 .await
544 .map_err(fix_error)
545 .map(|DbHistory(h)| h)
546 }
547}
548
549fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
550 let x = x.to_offset(UtcOffset::UTC);
551 PrimitiveDateTime::new(x.date(), x.time())
552}