atuin_server_database/
lib.rs

1#![forbid(unsafe_code)]
2
3pub mod calendar;
4pub mod models;
5
6use std::{
7    collections::HashMap,
8    fmt::{Debug, Display},
9    ops::Range,
10};
11
12use self::{
13    calendar::{TimePeriod, TimePeriodInfo},
14    models::{History, NewHistory, NewSession, NewUser, Session, User},
15};
16use async_trait::async_trait;
17use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
18use serde::{Deserialize, Serialize};
19use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset};
20use tracing::instrument;
21
22#[derive(Debug)]
23pub enum DbError {
24    NotFound,
25    Other(eyre::Report),
26}
27
28impl Display for DbError {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        write!(f, "{self:?}")
31    }
32}
33
34impl<T: std::error::Error + Into<time::error::Error>> From<T> for DbError {
35    fn from(value: T) -> Self {
36        DbError::Other(value.into().into())
37    }
38}
39
40impl std::error::Error for DbError {}
41
42pub type DbResult<T> = Result<T, DbError>;
43
/// Kind of database backend, as classified from the scheme of a connection URI
/// by `DbSettings::db_type`.
///
/// This is a plain fieldless tag, so it is cheap to copy and usable as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DbType {
    /// A PostgreSQL connection URI.
    Postgres,
    /// A SQLite connection URI.
    Sqlite,
    /// The URI scheme was not recognised.
    Unknown,
}
50
51#[derive(Clone, Deserialize, Serialize)]
52pub struct DbSettings {
53    pub db_uri: String,
54}
55
56impl DbSettings {
57    pub fn db_type(&self) -> DbType {
58        if self.db_uri.starts_with("postgres://") || self.db_uri.starts_with("postgresql://") {
59            DbType::Postgres
60        } else if self.db_uri.starts_with("sqlite://") {
61            DbType::Sqlite
62        } else {
63            DbType::Unknown
64        }
65    }
66}
67
68// Do our best to redact passwords so they're not logged in the event of an error.
69impl Debug for DbSettings {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        if self.db_type() == DbType::Postgres {
72            let redacted_uri = url::Url::parse(&self.db_uri)
73                .map(|mut url| {
74                    let _ = url.set_password(Some("****"));
75                    url.to_string()
76                })
77                .unwrap_or(self.db_uri.clone());
78            f.debug_struct("DbSettings")
79                .field("db_uri", &redacted_uri)
80                .finish()
81        } else {
82            f.debug_struct("DbSettings")
83                .field("db_uri", &self.db_uri)
84                .finish()
85        }
86    }
87}
88
89#[async_trait]
90pub trait Database: Sized + Clone + Send + Sync + 'static {
91    async fn new(settings: &DbSettings) -> DbResult<Self>;
92
93    async fn get_session(&self, token: &str) -> DbResult<Session>;
94    async fn get_session_user(&self, token: &str) -> DbResult<User>;
95    async fn add_session(&self, session: &NewSession) -> DbResult<()>;
96
97    async fn get_user(&self, username: &str) -> DbResult<User>;
98    async fn get_user_session(&self, u: &User) -> DbResult<Session>;
99    async fn add_user(&self, user: &NewUser) -> DbResult<i64>;
100
101    async fn user_verified(&self, id: i64) -> DbResult<bool>;
102    async fn verify_user(&self, id: i64) -> DbResult<()>;
103    async fn user_verification_token(&self, id: i64) -> DbResult<String>;
104
105    async fn update_user_password(&self, u: &User) -> DbResult<()>;
106
107    async fn total_history(&self) -> DbResult<i64>;
108    async fn count_history(&self, user: &User) -> DbResult<i64>;
109    async fn count_history_cached(&self, user: &User) -> DbResult<i64>;
110
111    async fn delete_user(&self, u: &User) -> DbResult<()>;
112    async fn delete_history(&self, user: &User, id: String) -> DbResult<()>;
113    async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>;
114    async fn delete_store(&self, user: &User) -> DbResult<()>;
115
116    async fn add_records(&self, user: &User, record: &[Record<EncryptedData>]) -> DbResult<()>;
117    async fn next_records(
118        &self,
119        user: &User,
120        host: HostId,
121        tag: String,
122        start: Option<RecordIdx>,
123        count: u64,
124    ) -> DbResult<Vec<Record<EncryptedData>>>;
125
126    // Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
127    async fn status(&self, user: &User) -> DbResult<RecordStatus>;
128
129    async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>)
130    -> DbResult<i64>;
131
132    async fn list_history(
133        &self,
134        user: &User,
135        created_after: OffsetDateTime,
136        since: OffsetDateTime,
137        host: &str,
138        page_size: i64,
139    ) -> DbResult<Vec<History>>;
140
141    async fn add_history(&self, history: &[NewHistory]) -> DbResult<()>;
142
143    async fn oldest_history(&self, user: &User) -> DbResult<History>;
144
145    #[instrument(skip_all)]
146    async fn calendar(
147        &self,
148        user: &User,
149        period: TimePeriod,
150        tz: UtcOffset,
151    ) -> DbResult<HashMap<u64, TimePeriodInfo>> {
152        let mut ret = HashMap::new();
153        let iter: Box<dyn Iterator<Item = DbResult<(u64, Range<Date>)>> + Send> = match period {
154            TimePeriod::Year => {
155                // First we need to work out how far back to calculate. Get the
156                // oldest history item
157                let oldest = self
158                    .oldest_history(user)
159                    .await?
160                    .timestamp
161                    .to_offset(tz)
162                    .year();
163                let current_year = OffsetDateTime::now_utc().to_offset(tz).year();
164
165                // All the years we need to get data for
166                // The upper bound is exclusive, so include current +1
167                let years = oldest..current_year + 1;
168
169                Box::new(years.map(|year| {
170                    let start = Date::from_calendar_date(year, time::Month::January, 1)?;
171                    let end = Date::from_calendar_date(year + 1, time::Month::January, 1)?;
172
173                    Ok((year as u64, start..end))
174                }))
175            }
176
177            TimePeriod::Month { year } => {
178                let months =
179                    std::iter::successors(Some(Month::January), |m| Some(m.next())).take(12);
180
181                Box::new(months.map(move |month| {
182                    let start = Date::from_calendar_date(year, month, 1)?;
183                    let days = start.month().length(year);
184                    let end = start + Duration::days(days as i64);
185
186                    Ok((month as u64, start..end))
187                }))
188            }
189
190            TimePeriod::Day { year, month } => {
191                let days = 1..month.length(year);
192                Box::new(days.map(move |day| {
193                    let start = Date::from_calendar_date(year, month, day)?;
194                    let end = start
195                        .next_day()
196                        .ok_or_else(|| DbError::Other(eyre::eyre!("no next day?")))?;
197
198                    Ok((day as u64, start..end))
199                }))
200            }
201        };
202
203        for x in iter {
204            let (index, range) = x?;
205
206            let start = range.start.with_time(Time::MIDNIGHT).assume_offset(tz);
207            let end = range.end.with_time(Time::MIDNIGHT).assume_offset(tz);
208
209            let count = self.count_history_range(user, start..end).await?;
210
211            ret.insert(
212                index,
213                TimePeriodInfo {
214                    count: count as u64,
215                    hash: "".to_string(),
216                },
217            );
218        }
219
220        Ok(ret)
221    }
222}