1use std::collections::HashMap;
2use std::ops::Range;
3
4use rand::Rng;
5
6use async_trait::async_trait;
7use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
8use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
9use atuin_server_database::{Database, DbError, DbResult, DbSettings};
10use futures_util::TryStreamExt;
11use sqlx::Row;
12use sqlx::postgres::PgPoolOptions;
13
14use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
15use tracing::instrument;
16use uuid::Uuid;
17use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
18
19mod wrappers;
20
21const MIN_PG_VERSION: u32 = 14;
22
23#[derive(Clone)]
24pub struct Postgres {
25 pool: sqlx::Pool<sqlx::postgres::Postgres>,
26 read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>,
28}
29
30impl Postgres {
31 fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> {
34 self.read_pool.as_ref().unwrap_or(&self.pool)
35 }
36}
37
38fn fix_error(error: sqlx::Error) -> DbError {
39 match error {
40 sqlx::Error::RowNotFound => DbError::NotFound,
41 error => DbError::Other(error.into()),
42 }
43}
44
45#[async_trait]
46impl Database for Postgres {
47 async fn new(settings: &DbSettings) -> DbResult<Self> {
48 let pool = PgPoolOptions::new()
49 .max_connections(100)
50 .connect(settings.db_uri.as_str())
51 .await
52 .map_err(fix_error)?;
53
54 let pg_major_version: u32 = pool
57 .acquire()
58 .await
59 .map_err(fix_error)?
60 .server_version_num()
61 .ok_or(DbError::Other(eyre::Report::msg(
62 "could not get PostgreSQL version",
63 )))?
64 / 10000;
65
66 if pg_major_version < MIN_PG_VERSION {
67 return Err(DbError::Other(eyre::Report::msg(format!(
68 "unsupported PostgreSQL version {pg_major_version}, minimum required is {MIN_PG_VERSION}"
69 ))));
70 }
71
72 sqlx::migrate!("./migrations")
73 .run(&pool)
74 .await
75 .map_err(|error| DbError::Other(error.into()))?;
76
77 let read_pool = if let Some(read_db_uri) = &settings.read_db_uri {
79 tracing::info!("Connecting to read replica database");
80 let read_pool = PgPoolOptions::new()
81 .max_connections(100)
82 .connect(read_db_uri.as_str())
83 .await
84 .map_err(fix_error)?;
85
86 let read_pg_major_version: u32 = read_pool
88 .acquire()
89 .await
90 .map_err(fix_error)?
91 .server_version_num()
92 .ok_or(DbError::Other(eyre::Report::msg(
93 "could not get PostgreSQL version from read replica",
94 )))?
95 / 10000;
96
97 if read_pg_major_version < MIN_PG_VERSION {
98 return Err(DbError::Other(eyre::Report::msg(format!(
99 "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}"
100 ))));
101 }
102
103 Some(read_pool)
104 } else {
105 None
106 };
107
108 Ok(Self { pool, read_pool })
109 }
110
111 #[instrument(skip_all)]
112 async fn get_session(&self, token: &str) -> DbResult<Session> {
113 sqlx::query_as("select id, user_id, token from sessions where token = $1")
114 .bind(token)
115 .fetch_one(self.read_pool())
116 .await
117 .map_err(fix_error)
118 .map(|DbSession(session)| session)
119 }
120
121 #[instrument(skip_all)]
122 async fn get_user(&self, username: &str) -> DbResult<User> {
123 sqlx::query_as("select id, username, email, password from users where username = $1")
124 .bind(username)
125 .fetch_one(self.read_pool())
126 .await
127 .map_err(fix_error)
128 .map(|DbUser(user)| user)
129 }
130
131 #[instrument(skip_all)]
132 async fn get_session_user(&self, token: &str) -> DbResult<User> {
133 sqlx::query_as(
134 "select users.id, users.username, users.email, users.password from users
135 inner join sessions
136 on users.id = sessions.user_id
137 and sessions.token = $1",
138 )
139 .bind(token)
140 .fetch_one(self.read_pool())
141 .await
142 .map_err(fix_error)
143 .map(|DbUser(user)| user)
144 }
145
146 #[instrument(skip_all)]
147 async fn count_history(&self, user: &User) -> DbResult<i64> {
148 let res: (i64,) = sqlx::query_as(
153 "select count(1) from history
154 where user_id = $1",
155 )
156 .bind(user.id)
157 .fetch_one(self.read_pool())
158 .await
159 .map_err(fix_error)?;
160
161 Ok(res.0)
162 }
163
164 #[instrument(skip_all)]
165 async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
166 let res: (i32,) = sqlx::query_as(
167 "select total from total_history_count_user
168 where user_id = $1",
169 )
170 .bind(user.id)
171 .fetch_one(self.read_pool())
172 .await
173 .map_err(fix_error)?;
174
175 Ok(res.0 as i64)
176 }
177
178 async fn delete_store(&self, user: &User) -> DbResult<()> {
179 let mut tx = self.pool.begin().await.map_err(fix_error)?;
180
181 sqlx::query(
182 "delete from store
183 where user_id = $1",
184 )
185 .bind(user.id)
186 .execute(&mut *tx)
187 .await
188 .map_err(fix_error)?;
189
190 sqlx::query(
191 "delete from store_idx_cache
192 where user_id = $1",
193 )
194 .bind(user.id)
195 .execute(&mut *tx)
196 .await
197 .map_err(fix_error)?;
198
199 tx.commit().await.map_err(fix_error)?;
200
201 Ok(())
202 }
203
204 async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
205 sqlx::query(
206 "update history
207 set deleted_at = $3
208 where user_id = $1
209 and client_id = $2
210 and deleted_at is null", )
212 .bind(user.id)
213 .bind(id)
214 .bind(OffsetDateTime::now_utc())
215 .fetch_all(&self.pool)
216 .await
217 .map_err(fix_error)?;
218
219 Ok(())
220 }
221
222 #[instrument(skip_all)]
223 async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
224 let res = sqlx::query(
229 "select client_id from history
230 where user_id = $1
231 and deleted_at is not null",
232 )
233 .bind(user.id)
234 .fetch_all(self.read_pool())
235 .await
236 .map_err(fix_error)?;
237
238 let res = res
239 .iter()
240 .map(|row| row.get::<String, _>("client_id"))
241 .collect();
242
243 Ok(res)
244 }
245
246 #[instrument(skip_all)]
247 async fn count_history_range(
248 &self,
249 user: &User,
250 range: Range<OffsetDateTime>,
251 ) -> DbResult<i64> {
252 let res: (i64,) = sqlx::query_as(
253 "select count(1) from history
254 where user_id = $1
255 and timestamp >= $2::date
256 and timestamp < $3::date",
257 )
258 .bind(user.id)
259 .bind(into_utc(range.start))
260 .bind(into_utc(range.end))
261 .fetch_one(self.read_pool())
262 .await
263 .map_err(fix_error)?;
264
265 Ok(res.0)
266 }
267
268 #[instrument(skip_all)]
269 async fn list_history(
270 &self,
271 user: &User,
272 created_after: OffsetDateTime,
273 since: OffsetDateTime,
274 host: &str,
275 page_size: i64,
276 ) -> DbResult<Vec<History>> {
277 let res = sqlx::query_as(
278 "select id, client_id, user_id, hostname, timestamp, data, created_at from history
279 where user_id = $1
280 and hostname != $2
281 and created_at >= $3
282 and timestamp >= $4
283 order by timestamp asc
284 limit $5",
285 )
286 .bind(user.id)
287 .bind(host)
288 .bind(into_utc(created_after))
289 .bind(into_utc(since))
290 .bind(page_size)
291 .fetch(self.read_pool())
292 .map_ok(|DbHistory(h)| h)
293 .try_collect()
294 .await
295 .map_err(fix_error)?;
296
297 Ok(res)
298 }
299
300 #[instrument(skip_all)]
301 async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
302 let mut tx = self.pool.begin().await.map_err(fix_error)?;
303
304 for i in history {
305 let client_id: &str = &i.client_id;
306 let hostname: &str = &i.hostname;
307 let data: &str = &i.data;
308
309 sqlx::query(
310 "insert into history
311 (client_id, user_id, hostname, timestamp, data)
312 values ($1, $2, $3, $4, $5)
313 on conflict do nothing
314 ",
315 )
316 .bind(client_id)
317 .bind(i.user_id)
318 .bind(hostname)
319 .bind(i.timestamp)
320 .bind(data)
321 .execute(&mut *tx)
322 .await
323 .map_err(fix_error)?;
324 }
325
326 tx.commit().await.map_err(fix_error)?;
327
328 Ok(())
329 }
330
331 #[instrument(skip_all)]
332 async fn delete_user(&self, u: &User) -> DbResult<()> {
333 sqlx::query("delete from sessions where user_id = $1")
334 .bind(u.id)
335 .execute(&self.pool)
336 .await
337 .map_err(fix_error)?;
338
339 sqlx::query("delete from history where user_id = $1")
340 .bind(u.id)
341 .execute(&self.pool)
342 .await
343 .map_err(fix_error)?;
344
345 sqlx::query("delete from store where user_id = $1")
346 .bind(u.id)
347 .execute(&self.pool)
348 .await
349 .map_err(fix_error)?;
350
351 sqlx::query("delete from total_history_count_user where user_id = $1")
352 .bind(u.id)
353 .execute(&self.pool)
354 .await
355 .map_err(fix_error)?;
356
357 sqlx::query("delete from users where id = $1")
358 .bind(u.id)
359 .execute(&self.pool)
360 .await
361 .map_err(fix_error)?;
362
363 Ok(())
364 }
365
366 #[instrument(skip_all)]
367 async fn update_user_password(&self, user: &User) -> DbResult<()> {
368 sqlx::query(
369 "update users
370 set password = $1
371 where id = $2",
372 )
373 .bind(&user.password)
374 .bind(user.id)
375 .execute(&self.pool)
376 .await
377 .map_err(fix_error)?;
378
379 Ok(())
380 }
381
382 #[instrument(skip_all)]
383 async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
384 let email: &str = &user.email;
385 let username: &str = &user.username;
386 let password: &str = &user.password;
387
388 let res: (i64,) = sqlx::query_as(
389 "insert into users
390 (username, email, password)
391 values($1, $2, $3)
392 returning id",
393 )
394 .bind(username)
395 .bind(email)
396 .bind(password)
397 .fetch_one(&self.pool)
398 .await
399 .map_err(fix_error)?;
400
401 Ok(res.0)
402 }
403
404 #[instrument(skip_all)]
405 async fn add_session(&self, session: &NewSession) -> DbResult<()> {
406 let token: &str = &session.token;
407
408 sqlx::query(
409 "insert into sessions
410 (user_id, token)
411 values($1, $2)",
412 )
413 .bind(session.user_id)
414 .bind(token)
415 .execute(&self.pool)
416 .await
417 .map_err(fix_error)?;
418
419 Ok(())
420 }
421
422 #[instrument(skip_all)]
423 async fn get_user_session(&self, u: &User) -> DbResult<Session> {
424 sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
425 .bind(u.id)
426 .fetch_one(self.read_pool())
427 .await
428 .map_err(fix_error)
429 .map(|DbSession(session)| session)
430 }
431
432 #[instrument(skip_all)]
433 async fn oldest_history(&self, user: &User) -> DbResult<History> {
434 sqlx::query_as(
435 "select id, client_id, user_id, hostname, timestamp, data, created_at from history
436 where user_id = $1
437 order by timestamp asc
438 limit 1",
439 )
440 .bind(user.id)
441 .fetch_one(self.read_pool())
442 .await
443 .map_err(fix_error)
444 .map(|DbHistory(h)| h)
445 }
446
447 #[instrument(skip_all)]
448 async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
449 let mut tx = self.pool.begin().await.map_err(fix_error)?;
450
451 let mut heads = HashMap::<(HostId, &str), u64>::new();
459
460 for i in records {
461 let id = atuin_common::utils::uuid_v7();
462
463 let result = sqlx::query(
464 "insert into store
465 (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
466 values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
467 on conflict do nothing
468 ",
469 )
470 .bind(id)
471 .bind(i.id)
472 .bind(i.host.id)
473 .bind(i.idx as i64)
474 .bind(i.timestamp as i64) .bind(&i.version)
476 .bind(&i.tag)
477 .bind(&i.data.data)
478 .bind(&i.data.content_encryption_key)
479 .bind(user.id)
480 .execute(&mut *tx)
481 .await
482 .map_err(fix_error)?;
483
484 if result.rows_affected() > 0 {
486 heads
487 .entry((i.host.id, &i.tag))
488 .and_modify(|e| {
489 if i.idx > *e {
490 *e = i.idx
491 }
492 })
493 .or_insert(i.idx);
494 }
495 }
496
497 for ((host, tag), idx) in heads {
499 sqlx::query(
500 "insert into store_idx_cache
501 (user_id, host, tag, idx)
502 values ($1, $2, $3, $4)
503 on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4)
504 ",
505 )
506 .bind(user.id)
507 .bind(host)
508 .bind(tag)
509 .bind(idx as i64)
510 .execute(&mut *tx)
511 .await
512 .map_err(fix_error)?;
513 }
514
515 tx.commit().await.map_err(fix_error)?;
516
517 Ok(())
518 }
519
520 #[instrument(skip_all)]
521 async fn next_records(
522 &self,
523 user: &User,
524 host: HostId,
525 tag: String,
526 start: Option<RecordIdx>,
527 count: u64,
528 ) -> DbResult<Vec<Record<EncryptedData>>> {
529 tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
530 let start = start.unwrap_or(0);
531
532 let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
533 "select client_id, host, idx, timestamp, version, tag, data, cek from store
534 where user_id = $1
535 and tag = $2
536 and host = $3
537 and idx >= $4
538 order by idx asc
539 limit $5",
540 )
541 .bind(user.id)
542 .bind(tag.clone())
543 .bind(host)
544 .bind(start as i64)
545 .bind(count as i64)
546 .fetch_all(self.read_pool())
547 .await
548 .map_err(fix_error);
549
550 let ret = match records {
551 Ok(records) => {
552 let records: Vec<Record<EncryptedData>> = records
553 .into_iter()
554 .map(|f| {
555 let record: Record<EncryptedData> = f.into();
556 record
557 })
558 .collect();
559
560 records
561 }
562 Err(DbError::NotFound) => {
563 tracing::debug!("no records found in store: {:?}/{}", host, tag);
564 return Ok(vec![]);
565 }
566 Err(e) => return Err(e),
567 };
568
569 Ok(ret)
570 }
571
572 async fn status(&self, user: &User) -> DbResult<RecordStatus> {
573 const STATUS_SQL: &str =
574 "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
575
576 let idx_cache_rollout = std::env::var("IDX_CACHE_ROLLOUT").unwrap_or("0".to_string());
583 let idx_cache_rollout = idx_cache_rollout.parse::<f64>().unwrap_or(0.0);
584 let use_idx_cache = rand::thread_rng().gen_bool(idx_cache_rollout / 100.0);
585
586 let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache {
587 tracing::debug!("using idx cache for user {}", user.id);
588 sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
589 .bind(user.id)
590 .fetch_all(self.read_pool())
591 .await
592 .map_err(fix_error)?
593 } else {
594 tracing::debug!("using aggregate query for user {}", user.id);
595 sqlx::query_as(STATUS_SQL)
596 .bind(user.id)
597 .fetch_all(self.read_pool())
598 .await
599 .map_err(fix_error)?
600 };
601
602 res.sort();
603
604 let mut status = RecordStatus::new();
605
606 for i in res.iter() {
607 status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64);
608 }
609
610 Ok(status)
611 }
612}
613
614fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
615 let x = x.to_offset(UtcOffset::UTC);
616 PrimitiveDateTime::new(x.date(), x.time())
617}
618
#[cfg(test)]
mod tests {
    use time::macros::datetime;

    use crate::into_utc;

    /// `into_utc` must shift to UTC and round-trip back to the same instant.
    #[test]
    fn utc() {
        let cases = [
            (
                datetime!(2023-09-26 15:11:02 +05:30),
                datetime!(2023-09-26 09:41:02),
            ),
            (
                datetime!(2023-09-26 15:11:02 -07:00),
                datetime!(2023-09-26 22:11:02),
            ),
            (
                datetime!(2023-09-26 15:11:02 +00:00),
                datetime!(2023-09-26 15:11:02),
            ),
        ];

        for (with_offset, expected) in cases {
            let naive = into_utc(with_offset);
            assert_eq!(naive, expected);
            assert_eq!(naive.assume_utc(), with_offset);
        }
    }
}
639}