1use std::str::FromStr;
2
3use async_trait::async_trait;
4use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
5use atuin_server_database::{
6 Database, DbError, DbResult, DbSettings,
7 models::{History, NewHistory, NewSession, NewUser, Session, User},
8};
9use futures_util::TryStreamExt;
10use sqlx::{
11 Row,
12 sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
13 types::Uuid,
14};
15use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
16use tracing::instrument;
17use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
18
19mod wrappers;
20
21#[derive(Clone)]
22pub struct Sqlite {
23 pool: sqlx::Pool<sqlx::sqlite::Sqlite>,
24}
25
26fn fix_error(error: sqlx::Error) -> DbError {
27 match error {
28 sqlx::Error::RowNotFound => DbError::NotFound,
29 error => DbError::Other(error.into()),
30 }
31}
32
33#[async_trait]
34impl Database for Sqlite {
35 async fn new(settings: &DbSettings) -> DbResult<Self> {
36 let opts = SqliteConnectOptions::from_str(&settings.db_uri)
37 .map_err(fix_error)?
38 .journal_mode(SqliteJournalMode::Wal)
39 .create_if_missing(true);
40
41 let pool = SqlitePoolOptions::new()
42 .connect_with(opts)
43 .await
44 .map_err(fix_error)?;
45
46 sqlx::migrate!("./migrations")
47 .run(&pool)
48 .await
49 .map_err(|error| DbError::Other(error.into()))?;
50
51 Ok(Self { pool })
52 }
53
54 #[instrument(skip_all)]
55 async fn get_session(&self, token: &str) -> DbResult<Session> {
56 sqlx::query_as("select id, user_id, token from sessions where token = $1")
57 .bind(token)
58 .fetch_one(&self.pool)
59 .await
60 .map_err(fix_error)
61 .map(|DbSession(session)| session)
62 }
63
64 #[instrument(skip_all)]
65 async fn get_session_user(&self, token: &str) -> DbResult<User> {
66 sqlx::query_as(
67 "select users.id, users.username, users.email, users.password from users
68 inner join sessions
69 on users.id = sessions.user_id
70 and sessions.token = $1",
71 )
72 .bind(token)
73 .fetch_one(&self.pool)
74 .await
75 .map_err(fix_error)
76 .map(|DbUser(user)| user)
77 }
78
79 #[instrument(skip_all)]
80 async fn add_session(&self, session: &NewSession) -> DbResult<()> {
81 let token: &str = &session.token;
82
83 sqlx::query(
84 "insert into sessions
85 (user_id, token)
86 values($1, $2)",
87 )
88 .bind(session.user_id)
89 .bind(token)
90 .execute(&self.pool)
91 .await
92 .map_err(fix_error)?;
93
94 Ok(())
95 }
96
97 #[instrument(skip_all)]
98 async fn get_user(&self, username: &str) -> DbResult<User> {
99 sqlx::query_as("select id, username, email, password from users where username = $1")
100 .bind(username)
101 .fetch_one(&self.pool)
102 .await
103 .map_err(fix_error)
104 .map(|DbUser(user)| user)
105 }
106
107 #[instrument(skip_all)]
108 async fn get_user_session(&self, u: &User) -> DbResult<Session> {
109 sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
110 .bind(u.id)
111 .fetch_one(&self.pool)
112 .await
113 .map_err(fix_error)
114 .map(|DbSession(session)| session)
115 }
116
117 #[instrument(skip_all)]
118 async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
119 let email: &str = &user.email;
120 let username: &str = &user.username;
121 let password: &str = &user.password;
122
123 let res: (i64,) = sqlx::query_as(
124 "insert into users
125 (username, email, password)
126 values($1, $2, $3)
127 returning id",
128 )
129 .bind(username)
130 .bind(email)
131 .bind(password)
132 .fetch_one(&self.pool)
133 .await
134 .map_err(fix_error)?;
135
136 Ok(res.0)
137 }
138
139 #[instrument(skip_all)]
140 async fn update_user_password(&self, user: &User) -> DbResult<()> {
141 sqlx::query(
142 "update users
143 set password = $1
144 where id = $2",
145 )
146 .bind(&user.password)
147 .bind(user.id)
148 .execute(&self.pool)
149 .await
150 .map_err(fix_error)?;
151
152 Ok(())
153 }
154
155 #[instrument(skip_all)]
156 async fn count_history(&self, user: &User) -> DbResult<i64> {
157 let res: (i64,) = sqlx::query_as(
162 "select count(1) from history
163 where user_id = $1",
164 )
165 .bind(user.id)
166 .fetch_one(&self.pool)
167 .await
168 .map_err(fix_error)?;
169
170 Ok(res.0)
171 }
172
173 #[instrument(skip_all)]
174 async fn count_history_cached(&self, _user: &User) -> DbResult<i64> {
175 Err(DbError::NotFound)
176 }
177
178 #[instrument(skip_all)]
179 async fn delete_user(&self, u: &User) -> DbResult<()> {
180 sqlx::query("delete from sessions where user_id = $1")
181 .bind(u.id)
182 .execute(&self.pool)
183 .await
184 .map_err(fix_error)?;
185
186 sqlx::query("delete from users where id = $1")
187 .bind(u.id)
188 .execute(&self.pool)
189 .await
190 .map_err(fix_error)?;
191
192 sqlx::query("delete from history where user_id = $1")
193 .bind(u.id)
194 .execute(&self.pool)
195 .await
196 .map_err(fix_error)?;
197
198 Ok(())
199 }
200
201 async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
202 sqlx::query(
203 "update history
204 set deleted_at = $3
205 where user_id = $1
206 and client_id = $2
207 and deleted_at is null", )
209 .bind(user.id)
210 .bind(id)
211 .bind(time::OffsetDateTime::now_utc())
212 .fetch_all(&self.pool)
213 .await
214 .map_err(fix_error)?;
215
216 Ok(())
217 }
218
219 #[instrument(skip_all)]
220 async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
221 let res = sqlx::query(
226 "select client_id from history
227 where user_id = $1
228 and deleted_at is not null",
229 )
230 .bind(user.id)
231 .fetch_all(&self.pool)
232 .await
233 .map_err(fix_error)?;
234
235 let res = res.iter().map(|row| row.get("client_id")).collect();
236
237 Ok(res)
238 }
239
240 async fn delete_store(&self, user: &User) -> DbResult<()> {
241 sqlx::query(
242 "delete from store
243 where user_id = $1",
244 )
245 .bind(user.id)
246 .execute(&self.pool)
247 .await
248 .map_err(fix_error)?;
249
250 Ok(())
251 }
252
253 #[instrument(skip_all)]
254 async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
255 let mut tx = self.pool.begin().await.map_err(fix_error)?;
256
257 for i in records {
258 let id = atuin_common::utils::uuid_v7();
259
260 sqlx::query(
261 "insert into store
262 (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
263 values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
264 on conflict do nothing
265 ",
266 )
267 .bind(id)
268 .bind(i.id)
269 .bind(i.host.id)
270 .bind(i.idx as i64)
271 .bind(i.timestamp as i64) .bind(&i.version)
273 .bind(&i.tag)
274 .bind(&i.data.data)
275 .bind(&i.data.content_encryption_key)
276 .bind(user.id)
277 .execute(&mut *tx)
278 .await
279 .map_err(fix_error)?;
280 }
281
282 tx.commit().await.map_err(fix_error)?;
283
284 Ok(())
285 }
286
287 #[instrument(skip_all)]
288 async fn next_records(
289 &self,
290 user: &User,
291 host: HostId,
292 tag: String,
293 start: Option<RecordIdx>,
294 count: u64,
295 ) -> DbResult<Vec<Record<EncryptedData>>> {
296 tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
297 let start = start.unwrap_or(0);
298
299 let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
300 "select client_id, host, idx, timestamp, version, tag, data, cek from store
301 where user_id = $1
302 and tag = $2
303 and host = $3
304 and idx >= $4
305 order by idx asc
306 limit $5",
307 )
308 .bind(user.id)
309 .bind(tag.clone())
310 .bind(host)
311 .bind(start as i64)
312 .bind(count as i64)
313 .fetch_all(&self.pool)
314 .await
315 .map_err(fix_error);
316
317 let ret = match records {
318 Ok(records) => {
319 let records: Vec<Record<EncryptedData>> = records
320 .into_iter()
321 .map(|f| {
322 let record: Record<EncryptedData> = f.into();
323 record
324 })
325 .collect();
326
327 records
328 }
329 Err(DbError::NotFound) => {
330 tracing::debug!("no records found in store: {:?}/{}", host, tag);
331 return Ok(vec![]);
332 }
333 Err(e) => return Err(e),
334 };
335
336 Ok(ret)
337 }
338
339 async fn status(&self, user: &User) -> DbResult<RecordStatus> {
340 const STATUS_SQL: &str =
341 "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
342
343 let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
344 .bind(user.id)
345 .fetch_all(&self.pool)
346 .await
347 .map_err(fix_error)?;
348
349 let mut status = RecordStatus::new();
350
351 for i in res {
352 status.set_raw(HostId(i.0), i.1, i.2 as u64);
353 }
354
355 Ok(status)
356 }
357
358 #[instrument(skip_all)]
359 async fn count_history_range(
360 &self,
361 user: &User,
362 range: std::ops::Range<time::OffsetDateTime>,
363 ) -> DbResult<i64> {
364 let res: (i64,) = sqlx::query_as(
365 "select count(1) from history
366 where user_id = $1
367 and timestamp >= $2::date
368 and timestamp < $3::date",
369 )
370 .bind(user.id)
371 .bind(into_utc(range.start))
372 .bind(into_utc(range.end))
373 .fetch_one(&self.pool)
374 .await
375 .map_err(fix_error)?;
376
377 Ok(res.0)
378 }
379
380 #[instrument(skip_all)]
381 async fn list_history(
382 &self,
383 user: &User,
384 created_after: time::OffsetDateTime,
385 since: time::OffsetDateTime,
386 host: &str,
387 page_size: i64,
388 ) -> DbResult<Vec<History>> {
389 let res = sqlx::query_as(
390 "select id, client_id, user_id, hostname, timestamp, data, created_at from history
391 where user_id = $1
392 and hostname != $2
393 and created_at >= $3
394 and timestamp >= $4
395 order by timestamp asc
396 limit $5",
397 )
398 .bind(user.id)
399 .bind(host)
400 .bind(into_utc(created_after))
401 .bind(into_utc(since))
402 .bind(page_size)
403 .fetch(&self.pool)
404 .map_ok(|DbHistory(h)| h)
405 .try_collect()
406 .await
407 .map_err(fix_error)?;
408
409 Ok(res)
410 }
411
412 #[instrument(skip_all)]
413 async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
414 let mut tx = self.pool.begin().await.map_err(fix_error)?;
415
416 for i in history {
417 let client_id: &str = &i.client_id;
418 let hostname: &str = &i.hostname;
419 let data: &str = &i.data;
420
421 sqlx::query(
422 "insert into history
423 (client_id, user_id, hostname, timestamp, data)
424 values ($1, $2, $3, $4, $5)
425 on conflict do nothing
426 ",
427 )
428 .bind(client_id)
429 .bind(i.user_id)
430 .bind(hostname)
431 .bind(i.timestamp)
432 .bind(data)
433 .execute(&mut *tx)
434 .await
435 .map_err(fix_error)?;
436 }
437
438 tx.commit().await.map_err(fix_error)?;
439
440 Ok(())
441 }
442
443 #[instrument(skip_all)]
444 async fn oldest_history(&self, user: &User) -> DbResult<History> {
445 sqlx::query_as(
446 "select id, client_id, user_id, hostname, timestamp, data, created_at from history
447 where user_id = $1
448 order by timestamp asc
449 limit 1",
450 )
451 .bind(user.id)
452 .fetch_one(&self.pool)
453 .await
454 .map_err(fix_error)
455 .map(|DbHistory(h)| h)
456 }
457}
458
459fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
460 let x = x.to_offset(UtcOffset::UTC);
461 PrimitiveDateTime::new(x.date(), x.time())
462}