grammers_session/storages/
sqlite.rs

1// Copyright 2020 - developers of the `grammers` project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use crate::defs::{
10    ChannelKind, ChannelState, DcOption, PeerAuth, PeerId, PeerInfo, PeerKind, UpdateState,
11    UpdatesState,
12};
13use crate::{DEFAULT_DC, KNOWN_DC_OPTIONS, Session};
14use std::path::Path;
15use std::sync::Mutex;
16
17const VERSION: i64 = 1;
18
19struct Database(sqlite::Connection);
20
21struct TransactionGuard<'c>(&'c sqlite::Connection);
22
23/// SQLite-based storage. This is the recommended option.
24pub struct SqliteSession {
25    database: Mutex<Database>,
26}
27
/// Bit flags persisted in the `peer_info.subtype` column.
///
/// The user flags (`UserSelf`, `UserBot`) and the channel flags (`Megagroup`,
/// `Broadcast`) occupy disjoint bits. Each combined variant is the bitwise OR
/// of its two components.
#[repr(u8)]
enum PeerSubtype {
    UserSelf = 0b0000_0001,
    UserBot = 0b0000_0010,
    /// `UserSelf | UserBot`.
    UserSelfBot = 0b0000_0011,
    Megagroup = 0b0000_0100,
    Broadcast = 0b0000_1000,
    /// `Megagroup | Broadcast`.
    Gigagroup = 0b0000_1100,
}
37
38impl Database {
39    fn init(&self) -> sqlite::Result<()> {
40        let mut user_version = self
41            .fetch_one("PRAGMA user_version", &[], |stmt| stmt.read::<i64, _>(0))?
42            .unwrap_or(0);
43        if user_version == VERSION {
44            return Ok(());
45        }
46
47        if user_version == 0 {
48            self.migrate_v0_to_v1()?;
49            user_version += 1;
50        }
51        if user_version == VERSION {
52            // Can't bind PRAGMA parameters, but `VERSION` is not user-controlled input.
53            self.0.execute(format!("PRAGMA user_version = {VERSION}"))?;
54        }
55        Ok(())
56    }
57
58    fn migrate_v0_to_v1(&self) -> sqlite::Result<()> {
59        let _transaction = self.begin_transaction()?;
60        self.0.execute(
61            "CREATE TABLE dc_home (
62                dc_id INTEGER NOT NULL,
63                PRIMARY KEY(dc_id))",
64        )?;
65        self.0.execute(
66            "CREATE TABLE dc_option (
67                dc_id INTEGER NOT NULL,
68                ipv4 TEXT NOT NULL,
69                ipv6 TEXT NOT NULL,
70                auth_key BLOB,
71                PRIMARY KEY (dc_id))",
72        )?;
73        self.0.execute(
74            "CREATE TABLE peer_info (
75                peer_id INTEGER NOT NULL,
76                hash INTEGER,
77                subtype INTEGER,
78                PRIMARY KEY (peer_id))",
79        )?;
80        self.0.execute(
81            "CREATE TABLE update_state (
82                pts INTEGER NOT NULL,
83                qts INTEGER NOT NULL,
84                date INTEGER NOT NULL,
85                seq INTEGER NOT NULL)",
86        )?;
87        self.0.execute(
88            "CREATE TABLE channel_state (
89                peer_id INTEGER NOT NULL,
90                pts INTEGER NOT NULL,
91                PRIMARY KEY (peer_id))",
92        )?;
93
94        Ok(())
95    }
96
97    fn begin_transaction(&self) -> sqlite::Result<TransactionGuard<'_>> {
98        self.0.execute("BEGIN TRANSACTION")?;
99        Ok(TransactionGuard(&self.0))
100    }
101
102    fn fetch_one<T, F: FnOnce(sqlite::Statement) -> sqlite::Result<T>>(
103        &self,
104        statement: &str,
105        bindings: &[(&str, sqlite::Value)],
106        select: F,
107    ) -> sqlite::Result<Option<T>> {
108        let mut statement = self.0.prepare(statement)?;
109        statement.bind(bindings)?;
110        let result = match statement.next()? {
111            sqlite::State::Row => Some(select(statement)?),
112            sqlite::State::Done => None,
113        };
114        Ok(result)
115    }
116
117    fn fetch_all<T, F: FnMut(&sqlite::Statement) -> sqlite::Result<T>>(
118        &self,
119        statement: &str,
120        bindings: &[(&str, sqlite::Value)],
121        mut select: F,
122    ) -> sqlite::Result<Vec<T>> {
123        let mut result = Vec::new();
124        let mut statement = self.0.prepare(statement)?;
125        statement.bind(bindings)?;
126        while statement.next()? == sqlite::State::Row {
127            result.push(select(&statement)?);
128        }
129        Ok(result)
130    }
131}
132
133impl Drop for TransactionGuard<'_> {
134    fn drop(&mut self) {
135        self.0.execute("COMMIT").unwrap();
136    }
137}
138
139impl SqliteSession {
140    /// Open a connection to the SQLite database at `path`,
141    /// creating one if it doesn't exist.
142    pub fn open<P: AsRef<Path>>(path: P) -> sqlite::Result<Self> {
143        let database = Database(sqlite::Connection::open(path)?);
144        database.init()?;
145        Ok(SqliteSession {
146            database: Mutex::new(database),
147        })
148    }
149}
150
151impl Session for SqliteSession {
152    fn home_dc_id(&self) -> i32 {
153        let db = self.database.lock().unwrap();
154        db.fetch_one("SELECT * FROM dc_home LIMIT 1", &[], |stmt| {
155            Ok(stmt.read::<i64, _>("dc_id")? as i32)
156        })
157        .unwrap()
158        .unwrap_or(DEFAULT_DC)
159    }
160
161    fn set_home_dc_id(&self, dc_id: i32) {
162        let db = self.database.lock().unwrap();
163        let _transaction = db.begin_transaction().unwrap();
164        db.0.execute("DELETE FROM dc_home").unwrap();
165        let mut stmt = db.0.prepare("INSERT INTO dc_home VALUES (:dc_id)").unwrap();
166        stmt.bind((":dc_id", dc_id as i64)).unwrap();
167        stmt.next().unwrap();
168    }
169
170    fn dc_option(&self, dc_id: i32) -> Option<DcOption> {
171        let db = self.database.lock().unwrap();
172        db.fetch_one(
173            "SELECT * FROM dc_option WHERE dc_id = :dc_id LIMIT 1",
174            &[(":dc_id", sqlite::Value::Integer(dc_id as _))],
175            |stmt| {
176                Ok(DcOption {
177                    id: stmt.read::<i64, _>("dc_id")? as _,
178                    ipv4: stmt.read::<String, _>("ipv4")?.parse().unwrap(),
179                    ipv6: stmt.read::<String, _>("ipv6")?.parse().unwrap(),
180                    auth_key: stmt
181                        .read::<Option<Vec<u8>>, _>("auth_key")?
182                        .map(|auth_key| auth_key.try_into().unwrap()),
183                })
184            },
185        )
186        .unwrap()
187        .or_else(|| {
188            KNOWN_DC_OPTIONS
189                .iter()
190                .find(|dc_option| dc_option.id == dc_id)
191                .cloned()
192        })
193    }
194
195    fn set_dc_option(&self, dc_option: &DcOption) {
196        let db = self.database.lock().unwrap();
197        let mut stmt = db
198            .0
199            .prepare("INSERT OR REPLACE INTO dc_option VALUES (:dc_id, :ipv4, :ipv6, :auth_key)")
200            .unwrap();
201        stmt.bind((":dc_id", dc_option.id as i64)).unwrap();
202        stmt.bind((":ipv4", dc_option.ipv4.to_string().as_str()))
203            .unwrap();
204        stmt.bind((":ipv6", dc_option.ipv6.to_string().as_str()))
205            .unwrap();
206        if let Some(auth_key) = dc_option.auth_key {
207            stmt.bind((":auth_key", auth_key.as_slice())).unwrap();
208        }
209        stmt.next().unwrap();
210    }
211
212    fn peer(&self, peer: PeerId) -> Option<PeerInfo> {
213        let db = self.database.lock().unwrap();
214        let map_stmt = |stmt: sqlite::Statement| {
215            let subtype = stmt.read::<Option<i64>, _>("subtype")?.map(|s| s as u8);
216            Ok(match peer.kind() {
217                PeerKind::User | PeerKind::UserSelf => PeerInfo::User {
218                    id: PeerId::user(stmt.read::<i64, _>("peer_id")?).bare_id(),
219                    auth: stmt
220                        .read::<Option<i64>, _>("hash")?
221                        .map(PeerAuth::from_hash),
222                    bot: subtype.map(|s| s & PeerSubtype::UserBot as u8 != 0),
223                    is_self: subtype.map(|s| s & PeerSubtype::UserSelf as u8 != 0),
224                },
225                PeerKind::Chat => PeerInfo::Chat { id: peer.bare_id() },
226                PeerKind::Channel => PeerInfo::Channel {
227                    id: peer.bare_id(),
228                    auth: stmt
229                        .read::<Option<i64>, _>("hash")?
230                        .map(PeerAuth::from_hash),
231                    kind: subtype.and_then(|s| {
232                        if (s & PeerSubtype::Gigagroup as u8) == PeerSubtype::Gigagroup as _ {
233                            Some(ChannelKind::Gigagroup)
234                        } else if s & PeerSubtype::Broadcast as u8 != 0 {
235                            Some(ChannelKind::Broadcast)
236                        } else if s & PeerSubtype::Megagroup as u8 != 0 {
237                            Some(ChannelKind::Megagroup)
238                        } else {
239                            None
240                        }
241                    }),
242                },
243            })
244        };
245
246        if peer.kind() == PeerKind::UserSelf {
247            db.fetch_one(
248                "SELECT * FROM peer_info WHERE subtype & :type LIMIT 1",
249                &[(":type", sqlite::Value::Integer(PeerSubtype::UserSelf as _))],
250                map_stmt,
251            )
252            .unwrap()
253        } else {
254            db.fetch_one(
255                "SELECT * FROM peer_info WHERE peer_id = :peer_id LIMIT 1",
256                &[(":peer_id", sqlite::Value::Integer(peer.bot_api_dialog_id()))],
257                map_stmt,
258            )
259            .unwrap()
260        }
261    }
262
263    fn cache_peer(&self, peer: &PeerInfo) {
264        let db = self.database.lock().unwrap();
265        let mut stmt =
266            db.0.prepare("INSERT OR REPLACE INTO peer_info VALUES (:peer_id, :hash, :subtype)")
267                .unwrap();
268        stmt.bind((":peer_id", peer.id().bot_api_dialog_id()))
269            .unwrap();
270        if peer.auth() != PeerAuth::default() {
271            stmt.bind((":hash", peer.auth().hash())).unwrap();
272        }
273        let subtype = match peer {
274            PeerInfo::User { bot, is_self, .. } => {
275                match (bot.unwrap_or_default(), is_self.unwrap_or_default()) {
276                    (true, true) => Some(PeerSubtype::UserSelfBot),
277                    (true, false) => Some(PeerSubtype::UserBot),
278                    (false, true) => Some(PeerSubtype::UserSelf),
279                    (false, false) => None,
280                }
281            }
282            PeerInfo::Chat { .. } => None,
283            PeerInfo::Channel { kind, .. } => kind.map(|kind| match kind {
284                ChannelKind::Megagroup => PeerSubtype::Megagroup,
285                ChannelKind::Broadcast => PeerSubtype::Broadcast,
286                ChannelKind::Gigagroup => PeerSubtype::Gigagroup,
287            }),
288        };
289        if let Some(subtype) = subtype {
290            stmt.bind((":subtype", subtype as i64)).unwrap();
291        }
292        stmt.next().unwrap();
293    }
294
295    fn updates_state(&self) -> UpdatesState {
296        let db = self.database.lock().unwrap();
297        let mut state = db
298            .fetch_one("SELECT * FROM update_state LIMIT 1", &[], |stmt| {
299                Ok(UpdatesState {
300                    pts: stmt.read::<i64, _>("pts")? as _,
301                    qts: stmt.read::<i64, _>("qts")? as _,
302                    date: stmt.read::<i64, _>("date")? as _,
303                    seq: stmt.read::<i64, _>("seq")? as _,
304                    channels: Vec::new(),
305                })
306            })
307            .unwrap()
308            .unwrap_or_default();
309        state.channels = db
310            .fetch_all("SELECT * FROM channel_state", &[], |stmt| {
311                Ok(ChannelState {
312                    id: stmt.read::<i64, _>("peer_id")?,
313                    pts: stmt.read::<i64, _>("pts")? as _,
314                })
315            })
316            .unwrap();
317        state
318    }
319
320    fn set_update_state(&self, update: UpdateState) {
321        let db = self.database.lock().unwrap();
322        let _transaction = db.begin_transaction().unwrap();
323
324        match update {
325            UpdateState::All(updates_state) => {
326                db.0.execute("DELETE FROM update_state").unwrap();
327                let mut stmt =
328                    db.0.prepare("INSERT INTO update_state VALUES (:pts, :qts, :date, :seq)")
329                        .unwrap();
330                stmt.bind((":pts", updates_state.pts as i64)).unwrap();
331                stmt.bind((":qts", updates_state.qts as i64)).unwrap();
332                stmt.bind((":date", updates_state.date as i64)).unwrap();
333                stmt.bind((":seq", updates_state.seq as i64)).unwrap();
334                stmt.next().unwrap();
335
336                db.0.execute("DELETE FROM channel_state").unwrap();
337                for channel in updates_state.channels {
338                    let mut stmt =
339                        db.0.prepare("INSERT INTO channel_state VALUES (:peer_id, :pts)")
340                            .unwrap();
341                    stmt.bind((":peer_id", channel.id as i64)).unwrap();
342                    stmt.bind((":pts", channel.pts as i64)).unwrap();
343                    stmt.next().unwrap();
344                }
345            }
346            UpdateState::Primary { pts, date, seq } => {
347                let previous = db
348                    .fetch_one("SELECT * FROM update_state LIMIT 1", &[], |_| Ok(()))
349                    .unwrap();
350
351                let mut stmt = if previous.is_some() {
352                    db.0.prepare("UPDATE update_state SET pts = :pts, date = :date, seq = :seq")
353                        .unwrap()
354                } else {
355                    db.0.prepare("INSERT INTO update_state VALUES (:pts, 0, :date, :seq)")
356                        .unwrap()
357                };
358                stmt.bind((":pts", pts as i64)).unwrap();
359                stmt.bind((":date", date as i64)).unwrap();
360                stmt.bind((":seq", seq as i64)).unwrap();
361                stmt.next().unwrap();
362            }
363            UpdateState::Secondary { qts } => {
364                let previous = db
365                    .fetch_one("SELECT * FROM update_state LIMIT 1", &[], |_| Ok(()))
366                    .unwrap();
367
368                let mut stmt = if previous.is_some() {
369                    db.0.prepare("UPDATE update_state SET qts = :qts").unwrap()
370                } else {
371                    db.0.prepare("INSERT INTO update_state VALUES (0, :qts, 0, 0)")
372                        .unwrap()
373                };
374                stmt.bind((":qts", qts as i64)).unwrap();
375                stmt.next().unwrap();
376            }
377            UpdateState::Channel { id, pts } => {
378                let mut stmt =
379                    db.0.prepare("INSERT OR REPLACE INTO channel_state VALUES (:peer_id, :pts)")
380                        .unwrap();
381                stmt.bind((":peer_id", id)).unwrap();
382                stmt.bind((":pts", pts as i64)).unwrap();
383                stmt.next().unwrap();
384            }
385        }
386    }
387}
388
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    // `DcOption`, `KNOWN_DC_OPTIONS`, `PeerInfo`, `Session`, `UpdateState`,
    // `PeerId`, `DEFAULT_DC`, ... all come in through the parent module's imports.
    use super::*;

    #[test]
    fn exercise_sqlite_session() {
        let session = SqliteSession::open(":memory:").unwrap();

        assert_eq!(session.home_dc_id(), DEFAULT_DC);
        session.set_home_dc_id(DEFAULT_DC + 1);
        assert_eq!(session.home_dc_id(), DEFAULT_DC + 1);

        assert_eq!(
            session.dc_option(KNOWN_DC_OPTIONS[0].id),
            Some(KNOWN_DC_OPTIONS[0].clone())
        );
        let new_dc_option = DcOption {
            id: KNOWN_DC_OPTIONS
                .iter()
                .map(|dc_option| dc_option.id)
                .max()
                .unwrap()
                + 1,
            ipv4: SocketAddrV4::new(Ipv4Addr::from_bits(0), 1),
            ipv6: SocketAddrV6::new(Ipv6Addr::from_bits(0), 1, 0, 0),
            auth_key: Some([1; 256]),
        };
        assert_eq!(session.dc_option(new_dc_option.id), None);
        session.set_dc_option(&new_dc_option);
        assert_eq!(session.dc_option(new_dc_option.id), Some(new_dc_option));

        assert_eq!(session.peer(PeerId::self_user()), None);
        assert_eq!(session.peer(PeerId::user(1)), None);
        let peer = PeerInfo::User {
            id: 1,
            auth: None,
            bot: Some(true),
            is_self: Some(true),
        };
        session.cache_peer(&peer);
        assert_eq!(session.peer(PeerId::self_user()), Some(peer.clone()));
        assert_eq!(session.peer(PeerId::user(1)), Some(peer));

        assert_eq!(session.peer(PeerId::channel(1)), None);
        let peer = PeerInfo::Channel {
            id: 1,
            auth: Some(PeerAuth::from_hash(-1)),
            kind: Some(ChannelKind::Broadcast),
        };
        session.cache_peer(&peer);
        assert_eq!(session.peer(PeerId::channel(1)), Some(peer));

        assert_eq!(session.updates_state(), UpdatesState::default());
        session.set_update_state(UpdateState::All(UpdatesState {
            pts: 1,
            qts: 2,
            date: 3,
            seq: 4,
            channels: vec![
                ChannelState { id: 5, pts: 6 },
                ChannelState { id: 7, pts: 8 },
            ],
        }));
        session.set_update_state(UpdateState::Primary {
            pts: 2,
            date: 4,
            seq: 5,
        });
        session.set_update_state(UpdateState::Secondary { qts: 3 });
        session.set_update_state(UpdateState::Channel { id: 7, pts: 9 });
        assert_eq!(
            session.updates_state(),
            UpdatesState {
                pts: 2,
                qts: 3,
                date: 4,
                seq: 5,
                channels: vec![
                    ChannelState { id: 5, pts: 6 },
                    ChannelState { id: 7, pts: 9 },
                ],
            }
        );
    }

    #[test]
    fn schema_is_versioned_and_init_is_idempotent() {
        let session = SqliteSession::open(":memory:").unwrap();
        let db = session.database.lock().unwrap();

        let user_version = db
            .fetch_one("PRAGMA user_version", &[], |stmt| stmt.read::<i64, _>(0))
            .unwrap();
        assert_eq!(user_version, Some(VERSION));

        // Re-running init on an up-to-date database must be a no-op rather than
        // a second migration (which would fail on `CREATE TABLE`).
        db.init().unwrap();
    }
}