atuin_server_database/
lib.rs

1#![forbid(unsafe_code)]
2
3pub mod calendar;
4pub mod models;
5
6use std::{
7    collections::HashMap,
8    fmt::{Debug, Display},
9    ops::Range,
10};
11
12use self::{
13    calendar::{TimePeriod, TimePeriodInfo},
14    models::{History, NewHistory, NewSession, NewUser, Session, User},
15};
16use async_trait::async_trait;
17use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
18use serde::{de::DeserializeOwned, Serialize};
19use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset};
20use tracing::instrument;
21
22#[derive(Debug)]
23pub enum DbError {
24    NotFound,
25    Other(eyre::Report),
26}
27
28impl Display for DbError {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        write!(f, "{self:?}")
31    }
32}
33
34impl<T: std::error::Error + Into<time::error::Error>> From<T> for DbError {
35    fn from(value: T) -> Self {
36        DbError::Other(value.into().into())
37    }
38}
39
40impl std::error::Error for DbError {}
41
42pub type DbResult<T> = Result<T, DbError>;
43
44#[async_trait]
45pub trait Database: Sized + Clone + Send + Sync + 'static {
46    type Settings: Debug + Clone + DeserializeOwned + Serialize + Send + Sync + 'static;
47    async fn new(settings: &Self::Settings) -> DbResult<Self>;
48
49    async fn get_session(&self, token: &str) -> DbResult<Session>;
50    async fn get_session_user(&self, token: &str) -> DbResult<User>;
51    async fn add_session(&self, session: &NewSession) -> DbResult<()>;
52
53    async fn get_user(&self, username: &str) -> DbResult<User>;
54    async fn get_user_session(&self, u: &User) -> DbResult<Session>;
55    async fn add_user(&self, user: &NewUser) -> DbResult<i64>;
56
57    async fn user_verified(&self, id: i64) -> DbResult<bool>;
58    async fn verify_user(&self, id: i64) -> DbResult<()>;
59    async fn user_verification_token(&self, id: i64) -> DbResult<String>;
60
61    async fn update_user_password(&self, u: &User) -> DbResult<()>;
62
63    async fn total_history(&self) -> DbResult<i64>;
64    async fn count_history(&self, user: &User) -> DbResult<i64>;
65    async fn count_history_cached(&self, user: &User) -> DbResult<i64>;
66
67    async fn delete_user(&self, u: &User) -> DbResult<()>;
68    async fn delete_history(&self, user: &User, id: String) -> DbResult<()>;
69    async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>;
70    async fn delete_store(&self, user: &User) -> DbResult<()>;
71
72    async fn add_records(&self, user: &User, record: &[Record<EncryptedData>]) -> DbResult<()>;
73    async fn next_records(
74        &self,
75        user: &User,
76        host: HostId,
77        tag: String,
78        start: Option<RecordIdx>,
79        count: u64,
80    ) -> DbResult<Vec<Record<EncryptedData>>>;
81
82    // Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
83    async fn status(&self, user: &User) -> DbResult<RecordStatus>;
84
85    async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>)
86        -> DbResult<i64>;
87
88    async fn list_history(
89        &self,
90        user: &User,
91        created_after: OffsetDateTime,
92        since: OffsetDateTime,
93        host: &str,
94        page_size: i64,
95    ) -> DbResult<Vec<History>>;
96
97    async fn add_history(&self, history: &[NewHistory]) -> DbResult<()>;
98
99    async fn oldest_history(&self, user: &User) -> DbResult<History>;
100
101    #[instrument(skip_all)]
102    async fn calendar(
103        &self,
104        user: &User,
105        period: TimePeriod,
106        tz: UtcOffset,
107    ) -> DbResult<HashMap<u64, TimePeriodInfo>> {
108        let mut ret = HashMap::new();
109        let iter: Box<dyn Iterator<Item = DbResult<(u64, Range<Date>)>> + Send> = match period {
110            TimePeriod::Year => {
111                // First we need to work out how far back to calculate. Get the
112                // oldest history item
113                let oldest = self
114                    .oldest_history(user)
115                    .await?
116                    .timestamp
117                    .to_offset(tz)
118                    .year();
119                let current_year = OffsetDateTime::now_utc().to_offset(tz).year();
120
121                // All the years we need to get data for
122                // The upper bound is exclusive, so include current +1
123                let years = oldest..current_year + 1;
124
125                Box::new(years.map(|year| {
126                    let start = Date::from_calendar_date(year, time::Month::January, 1)?;
127                    let end = Date::from_calendar_date(year + 1, time::Month::January, 1)?;
128
129                    Ok((year as u64, start..end))
130                }))
131            }
132
133            TimePeriod::Month { year } => {
134                let months =
135                    std::iter::successors(Some(Month::January), |m| Some(m.next())).take(12);
136
137                Box::new(months.map(move |month| {
138                    let start = Date::from_calendar_date(year, month, 1)?;
139                    let days = time::util::days_in_year_month(year, month);
140                    let end = start + Duration::days(days as i64);
141
142                    Ok((month as u64, start..end))
143                }))
144            }
145
146            TimePeriod::Day { year, month } => {
147                let days = 1..time::util::days_in_year_month(year, month);
148                Box::new(days.map(move |day| {
149                    let start = Date::from_calendar_date(year, month, day)?;
150                    let end = start
151                        .next_day()
152                        .ok_or_else(|| DbError::Other(eyre::eyre!("no next day?")))?;
153
154                    Ok((day as u64, start..end))
155                }))
156            }
157        };
158
159        for x in iter {
160            let (index, range) = x?;
161
162            let start = range.start.with_time(Time::MIDNIGHT).assume_offset(tz);
163            let end = range.end.with_time(Time::MIDNIGHT).assume_offset(tz);
164
165            let count = self.count_history_range(user, start..end).await?;
166
167            ret.insert(
168                index,
169                TimePeriodInfo {
170                    count: count as u64,
171                    hash: "".to_string(),
172                },
173            );
174        }
175
176        Ok(ret)
177    }
178}