// grammers_session/storages/sqlite.rs

// Copyright 2020 - developers of the `grammers` project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::collections::HashMap;
use std::path::Path;
use std::sync::Mutex;

use futures_core::future::BoxFuture;
use libsql::{named_params, params};
use tokio::sync::Mutex as AsyncMutex;

use crate::types::{
    ChannelKind, ChannelState, DcOption, PeerAuth, PeerId, PeerInfo, PeerKind, UpdateState,
    UpdatesState,
};
use crate::{DEFAULT_DC, KNOWN_DC_OPTIONS, Session};

/// Schema version this code expects the session database to be at.
///
/// It is stored in SQLite's `PRAGMA user_version`. `Database::init` skips all
/// migrations when the stored value already equals this constant.
const VERSION: i64 = 1;

25struct Database(libsql::Connection);
26
27struct Cache {
28    pub home_dc: i32,
29    pub dc_options: HashMap<i32, DcOption>,
30}
31
32/// SQLite-based storage. This is the recommended option.
33pub struct SqliteSession {
34    database: AsyncMutex<Database>,
35    cache: Mutex<Cache>,
36}
37
/// Bit flags stored in the `subtype` column of the `peer_info` table.
///
/// Users use bit 0 (self) and bit 1 (bot). Channels use bit 2 (megagroup) and
/// bit 3 (broadcast), and both set together mean gigagroup. The column is left
/// `NULL` when no flag applies.
#[repr(u8)]
enum PeerSubtype {
    UserSelf = 1,
    UserBot = 2,
    /// `UserSelf | UserBot`.
    UserSelfBot = 3,
    Megagroup = 4,
    Broadcast = 8,
    /// `Megagroup | Broadcast`; must be tested before either bit alone.
    Gigagroup = 12,
}

48impl Database {
49    async fn init(&self) -> libsql::Result<()> {
50        let mut user_version: i64 = self
51            .fetch_one("PRAGMA user_version", params![], |row| row.get(0))
52            .await?
53            .unwrap_or(0);
54        if user_version == VERSION {
55            return Ok(());
56        }
57
58        if user_version == 0 {
59            self.migrate_v0_to_v1().await?;
60            user_version += 1;
61        }
62        if user_version == VERSION {
63            // Can't bind PRAGMA parameters, but `VERSION` is not user-controlled input.
64            self.0
65                .execute(&format!("PRAGMA user_version = {VERSION}"), params![])
66                .await?;
67        }
68        Ok(())
69    }
70
71    async fn migrate_v0_to_v1(&self) -> libsql::Result<()> {
72        let transaction = self.begin_transaction().await?;
73        transaction
74            .execute(
75                "CREATE TABLE dc_home (
76                dc_id INTEGER NOT NULL,
77                PRIMARY KEY(dc_id))",
78                params![],
79            )
80            .await?;
81        transaction
82            .execute(
83                "CREATE TABLE dc_option (
84                dc_id INTEGER NOT NULL,
85                ipv4 TEXT NOT NULL,
86                ipv6 TEXT NOT NULL,
87                auth_key BLOB,
88                PRIMARY KEY (dc_id))",
89                params![],
90            )
91            .await?;
92        transaction
93            .execute(
94                "CREATE TABLE peer_info (
95                peer_id INTEGER NOT NULL,
96                hash INTEGER,
97                subtype INTEGER,
98                PRIMARY KEY (peer_id))",
99                params![],
100            )
101            .await?;
102        transaction
103            .execute(
104                "CREATE TABLE update_state (
105                pts INTEGER NOT NULL,
106                qts INTEGER NOT NULL,
107                date INTEGER NOT NULL,
108                seq INTEGER NOT NULL)",
109                params![],
110            )
111            .await?;
112        transaction
113            .execute(
114                "CREATE TABLE channel_state (
115                peer_id INTEGER NOT NULL,
116                pts INTEGER NOT NULL,
117                PRIMARY KEY (peer_id))",
118                params![],
119            )
120            .await?;
121
122        transaction.commit().await?;
123        Ok(())
124    }
125
126    async fn begin_transaction(&self) -> libsql::Result<libsql::Transaction> {
127        self.0.transaction().await
128    }
129
130    async fn fetch_one<
131        T,
132        P: libsql::params::IntoParams,
133        F: FnOnce(libsql::Row) -> libsql::Result<T>,
134    >(
135        &self,
136        statement: &str,
137        params: P,
138        select: F,
139    ) -> libsql::Result<Option<T>> {
140        let mut statement = self.0.prepare(statement).await?;
141        let result = statement.query_row(params).await;
142        match result {
143            Ok(value) => Ok(Some(select(value)?)),
144            Err(libsql::Error::QueryReturnedNoRows) => Ok(None),
145            Err(e) => Err(e),
146        }
147    }
148
149    async fn fetch_all<
150        T,
151        P: libsql::params::IntoParams,
152        F: FnMut(libsql::Row) -> libsql::Result<T>,
153    >(
154        &self,
155        statement: &str,
156        params: P,
157        mut select: F,
158    ) -> libsql::Result<Vec<T>> {
159        let statement = self.0.prepare(statement).await?;
160        let mut rows = statement.query(params).await?;
161        let mut result = Vec::new();
162        while let Some(row) = rows.next().await? {
163            result.push(select(row)?);
164        }
165        Ok(result)
166    }
167}
168
169impl SqliteSession {
170    /// Open a connection to the SQLite database at `path`,
171    /// creating one if it doesn't exist.
172    pub async fn open<P: AsRef<Path>>(path: P) -> libsql::Result<Self> {
173        let conn = libsql::Builder::new_local(path).build().await?.connect()?;
174        let db = Database(conn);
175        db.init().await?;
176
177        let home_dc = db
178            .fetch_one("SELECT * FROM dc_home LIMIT 1", named_params![], |row| {
179                Ok(row.get::<i32>(0)?)
180            })
181            .await?
182            .unwrap_or(DEFAULT_DC);
183
184        let dc_options = db
185            .fetch_all("SELECT * FROM dc_option", named_params![], |row| {
186                Ok(DcOption {
187                    id: row.get::<i32>(0)?,
188                    ipv4: row.get::<String>(1)?.parse().unwrap(),
189                    ipv6: row.get::<String>(2)?.parse().unwrap(),
190                    auth_key: row
191                        .get::<Option<Vec<u8>>>(3)?
192                        .map(|auth_key| auth_key.try_into().unwrap()),
193                })
194            })
195            .await?
196            .into_iter()
197            .map(|dc_option| (dc_option.id, dc_option))
198            .collect();
199
200        Ok(SqliteSession {
201            database: AsyncMutex::new(db),
202            cache: Mutex::new(Cache {
203                home_dc,
204                dc_options,
205            }),
206        })
207    }
208}
209
210impl Session for SqliteSession {
211    fn home_dc_id(&self) -> i32 {
212        self.cache.lock().unwrap().home_dc
213    }
214
215    fn set_home_dc_id(&self, dc_id: i32) -> BoxFuture<'_, ()> {
216        self.cache.lock().unwrap().home_dc = dc_id;
217        Box::pin(async move {
218            let transaction = self
219                .database
220                .lock()
221                .await
222                .begin_transaction()
223                .await
224                .unwrap();
225            transaction
226                .execute("DELETE FROM dc_home", params![])
227                .await
228                .unwrap();
229            let stmt = transaction
230                .prepare("INSERT INTO dc_home VALUES (:dc_id)")
231                .await
232                .unwrap();
233            stmt.execute(named_params! {":dc_id": dc_id}).await.unwrap();
234            transaction.commit().await.unwrap();
235        })
236    }
237
238    fn dc_option(&self, dc_id: i32) -> Option<DcOption> {
239        self.cache
240            .lock()
241            .unwrap()
242            .dc_options
243            .get(&dc_id)
244            .cloned()
245            .or_else(|| {
246                KNOWN_DC_OPTIONS
247                    .iter()
248                    .find(|dc_option| dc_option.id == dc_id)
249                    .cloned()
250            })
251    }
252
253    fn set_dc_option(&self, dc_option: &DcOption) -> BoxFuture<'_, ()> {
254        self.cache
255            .lock()
256            .unwrap()
257            .dc_options
258            .insert(dc_option.id, dc_option.clone());
259
260        let dc_option = dc_option.clone();
261        Box::pin(async move {
262            let db = self.database.lock().await;
263            db.0.execute(
264                "INSERT OR REPLACE INTO dc_option VALUES (:dc_id, :ipv4, :ipv6, :auth_key)",
265                named_params! {
266                    ":dc_id": dc_option.id,
267                    ":ipv4": dc_option.ipv4.to_string(),
268                    ":ipv6": dc_option.ipv6.to_string(),
269                    ":auth_key": dc_option.auth_key.map(|k| k.to_vec()),
270                },
271            )
272            .await
273            .unwrap();
274        })
275    }
276
277    fn peer(&self, peer: PeerId) -> BoxFuture<'_, Option<PeerInfo>> {
278        Box::pin(async move {
279            let db = self.database.lock().await;
280            let map_row = |row: libsql::Row| {
281                let subtype = row.get::<Option<i64>>(2)?.map(|s| s as u8);
282                Ok(match peer.kind() {
283                    PeerKind::User | PeerKind::UserSelf => PeerInfo::User {
284                        id: PeerId::user_unchecked(row.get::<i64>(0)?).bare_id(),
285                        auth: row.get::<Option<i64>>(1)?.map(PeerAuth::from_hash),
286                        bot: subtype.map(|s| s & PeerSubtype::UserBot as u8 != 0),
287                        is_self: subtype.map(|s| s & PeerSubtype::UserSelf as u8 != 0),
288                    },
289                    PeerKind::Chat => PeerInfo::Chat { id: peer.bare_id() },
290                    PeerKind::Channel => PeerInfo::Channel {
291                        id: peer.bare_id(),
292                        auth: row.get::<Option<i64>>(1)?.map(PeerAuth::from_hash),
293                        kind: subtype.and_then(|s| {
294                            if (s & PeerSubtype::Gigagroup as u8) == PeerSubtype::Gigagroup as u8 {
295                                Some(ChannelKind::Gigagroup)
296                            } else if s & PeerSubtype::Broadcast as u8 != 0 {
297                                Some(ChannelKind::Broadcast)
298                            } else if s & PeerSubtype::Megagroup as u8 != 0 {
299                                Some(ChannelKind::Megagroup)
300                            } else {
301                                None
302                            }
303                        }),
304                    },
305                })
306            };
307
308            if peer.kind() == PeerKind::UserSelf {
309                db.fetch_one(
310                    "SELECT * FROM peer_info WHERE subtype & :type LIMIT 1",
311                    named_params! {":type": PeerSubtype::UserSelf as i64},
312                    map_row,
313                )
314                .await
315                .unwrap()
316            } else {
317                db.fetch_one(
318                    "SELECT * FROM peer_info WHERE peer_id = :peer_id LIMIT 1",
319                    named_params! {":peer_id": peer.bot_api_dialog_id()},
320                    map_row,
321                )
322                .await
323                .unwrap()
324            }
325        })
326    }
327
328    fn cache_peer(&self, peer: &PeerInfo) -> BoxFuture<'_, ()> {
329        let peer = peer.clone();
330        Box::pin(async move {
331            let db = self.database.lock().await;
332            let stmt =
333                db.0.prepare("INSERT OR REPLACE INTO peer_info VALUES (:peer_id, :hash, :subtype)")
334                    .await
335                    .unwrap();
336            let subtype = match peer {
337                PeerInfo::User { bot, is_self, .. } => {
338                    match (bot.unwrap_or_default(), is_self.unwrap_or_default()) {
339                        (true, true) => Some(PeerSubtype::UserSelfBot),
340                        (true, false) => Some(PeerSubtype::UserBot),
341                        (false, true) => Some(PeerSubtype::UserSelf),
342                        (false, false) => None,
343                    }
344                }
345                PeerInfo::Chat { .. } => None,
346                PeerInfo::Channel { kind, .. } => kind.map(|kind| match kind {
347                    ChannelKind::Megagroup => PeerSubtype::Megagroup,
348                    ChannelKind::Broadcast => PeerSubtype::Broadcast,
349                    ChannelKind::Gigagroup => PeerSubtype::Gigagroup,
350                }),
351            };
352            let mut params = vec![];
353            let peer_id = peer.id().bot_api_dialog_id();
354            params.push((":peer_id".to_owned(), peer_id));
355            let hash = peer.auth().unwrap_or_default().hash();
356            if peer.auth().is_some() {
357                params.push((":hash".to_owned(), hash));
358            }
359            let subtype = subtype.map(|s| s as i64);
360            if subtype.is_some() {
361                params.push((":subtype".to_owned(), subtype.unwrap()));
362            }
363            stmt.execute(params).await.unwrap();
364        })
365    }
366
367    fn updates_state(&self) -> BoxFuture<'_, UpdatesState> {
368        Box::pin(async move {
369            let db = self.database.lock().await;
370            let mut state = db
371                .fetch_one(
372                    "SELECT * FROM update_state LIMIT 1",
373                    named_params![],
374                    |row| {
375                        Ok(UpdatesState {
376                            pts: row.get(0)?,
377                            qts: row.get(1)?,
378                            date: row.get(2)?,
379                            seq: row.get(3)?,
380                            channels: Vec::new(),
381                        })
382                    },
383                )
384                .await
385                .unwrap()
386                .unwrap_or_default();
387            state.channels = db
388                .fetch_all("SELECT * FROM channel_state", named_params![], |row| {
389                    Ok(ChannelState {
390                        id: row.get(0)?,
391                        pts: row.get(1)?,
392                    })
393                })
394                .await
395                .unwrap();
396            state
397        })
398    }
399
400    fn set_update_state(&self, update: UpdateState) -> BoxFuture<'_, ()> {
401        Box::pin(async move {
402            let db = self.database.lock().await;
403            let transaction = db.begin_transaction().await.unwrap();
404
405            match update {
406                UpdateState::All(updates_state) => {
407                    transaction
408                        .execute("DELETE FROM update_state", params![])
409                        .await
410                        .unwrap();
411                    transaction
412                        .execute(
413                            "INSERT INTO update_state VALUES (:pts, :qts, :date, :seq)",
414                            named_params! {
415                                ":pts": updates_state.pts,
416                                ":qts": updates_state.qts,
417                                ":date": updates_state.date,
418                                ":seq": updates_state.seq,
419                            },
420                        )
421                        .await
422                        .unwrap();
423
424                    transaction
425                        .execute("DELETE FROM channel_state", params![])
426                        .await
427                        .unwrap();
428                    for channel in updates_state.channels {
429                        transaction
430                            .execute(
431                                "INSERT INTO channel_state VALUES (:peer_id, :pts)",
432                                named_params! {
433                                    ":peer_id": channel.id,
434                                    ":pts": channel.pts,
435                                },
436                            )
437                            .await
438                            .unwrap();
439                    }
440                }
441                UpdateState::Primary { pts, date, seq } => {
442                    let previous = db
443                        .fetch_one(
444                            "SELECT * FROM update_state LIMIT 1",
445                            named_params![],
446                            |_| Ok(()),
447                        )
448                        .await
449                        .unwrap();
450
451                    if previous.is_some() {
452                        transaction
453                            .execute(
454                                "UPDATE update_state SET pts = :pts, date = :date, seq = :seq",
455                                named_params! {
456                                    ":pts": pts,
457                                    ":date": date,
458                                    ":seq": seq,
459                                },
460                            )
461                            .await
462                            .unwrap();
463                    } else {
464                        transaction
465                            .execute(
466                                "INSERT INTO update_state VALUES (:pts, 0, :date, :seq)",
467                                named_params! {
468                                    ":pts": pts,
469                                    ":date": date,
470                                    ":seq": seq,
471                                },
472                            )
473                            .await
474                            .unwrap();
475                    }
476                }
477                UpdateState::Secondary { qts } => {
478                    let previous = db
479                        .fetch_one(
480                            "SELECT * FROM update_state LIMIT 1",
481                            named_params![],
482                            |_| Ok(()),
483                        )
484                        .await
485                        .unwrap();
486
487                    if previous.is_some() {
488                        transaction
489                            .execute(
490                                "UPDATE update_state SET qts = :qts",
491                                named_params! {":qts": qts},
492                            )
493                            .await
494                            .unwrap();
495                    } else {
496                        transaction
497                            .execute(
498                                "INSERT INTO update_state VALUES (0, :qts, 0, 0)",
499                                named_params! {":qts": qts},
500                            )
501                            .await
502                            .unwrap();
503                    }
504                }
505                UpdateState::Channel { id, pts } => {
506                    transaction
507                        .execute(
508                            "INSERT OR REPLACE INTO channel_state VALUES (:peer_id, :pts)",
509                            named_params! {
510                                ":peer_id": id,
511                                ":pts": pts,
512                            },
513                        )
514                        .await
515                        .unwrap();
516                }
517            }
518
519            transaction.commit().await.unwrap();
520        })
521    }
522}
523
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    // Brings in every type and constant used below (DcOption, PeerInfo, Session,
    // UpdateState, KNOWN_DC_OPTIONS, ...) from the parent module.
    use super::*;

    /// Drives `future` to completion on a fresh single-threaded Tokio runtime.
    fn block_on<F: Future>(future: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(future)
    }

    #[test]
    fn exercise_sqlite_session() {
        block_on(do_exercise_sqlite_session());
    }

    async fn do_exercise_sqlite_session() {
        let session = SqliteSession::open(":memory:").await.unwrap();

        assert_eq!(session.home_dc_id(), DEFAULT_DC);
        session.set_home_dc_id(DEFAULT_DC + 1).await;
        assert_eq!(session.home_dc_id(), DEFAULT_DC + 1);

        assert_eq!(
            session.dc_option(KNOWN_DC_OPTIONS[0].id),
            Some(KNOWN_DC_OPTIONS[0].clone())
        );
        let new_dc_option = DcOption {
            id: KNOWN_DC_OPTIONS
                .iter()
                .map(|dc_option| dc_option.id)
                .max()
                .unwrap()
                + 1,
            ipv4: SocketAddrV4::new(Ipv4Addr::from_bits(0), 1),
            ipv6: SocketAddrV6::new(Ipv6Addr::from_bits(0), 1, 0, 0),
            auth_key: Some([1; 256]),
        };
        assert_eq!(session.dc_option(new_dc_option.id), None);
        session.set_dc_option(&new_dc_option).await;
        assert_eq!(session.dc_option(new_dc_option.id), Some(new_dc_option));

        assert_eq!(session.peer(PeerId::self_user()).await, None);
        assert_eq!(session.peer(PeerId::user_unchecked(1)).await, None);
        let peer = PeerInfo::User {
            id: 1,
            auth: None,
            bot: Some(true),
            is_self: Some(true),
        };
        session.cache_peer(&peer).await;
        assert_eq!(session.peer(PeerId::self_user()).await, Some(peer.clone()));
        assert_eq!(session.peer(PeerId::user_unchecked(1)).await, Some(peer));

        assert_eq!(session.peer(PeerId::channel_unchecked(1)).await, None);
        let peer = PeerInfo::Channel {
            id: 1,
            auth: Some(PeerAuth::from_hash(-1)),
            kind: Some(ChannelKind::Broadcast),
        };
        session.cache_peer(&peer).await;
        assert_eq!(session.peer(PeerId::channel_unchecked(1)).await, Some(peer));

        assert_eq!(session.updates_state().await, UpdatesState::default());
        session
            .set_update_state(UpdateState::All(UpdatesState {
                pts: 1,
                qts: 2,
                date: 3,
                seq: 4,
                channels: vec![
                    ChannelState { id: 5, pts: 6 },
                    ChannelState { id: 7, pts: 8 },
                ],
            }))
            .await;
        session
            .set_update_state(UpdateState::Primary {
                pts: 2,
                date: 4,
                seq: 5,
            })
            .await;
        session
            .set_update_state(UpdateState::Secondary { qts: 3 })
            .await;
        session
            .set_update_state(UpdateState::Channel { id: 7, pts: 9 })
            .await;
        assert_eq!(
            session.updates_state().await,
            UpdatesState {
                pts: 2,
                qts: 3,
                date: 4,
                seq: 5,
                channels: vec![
                    ChannelState { id: 5, pts: 6 },
                    ChannelState { id: 7, pts: 9 },
                ],
            }
        );
    }

    /// Every channel kind must survive a `cache_peer` / `peer` round-trip,
    /// including gigagroup, whose subtype sets both channel bits.
    #[test]
    fn channel_kinds_round_trip() {
        block_on(async {
            let session = SqliteSession::open(":memory:").await.unwrap();

            let megagroup = PeerInfo::Channel {
                id: 1,
                auth: None,
                kind: Some(ChannelKind::Megagroup),
            };
            session.cache_peer(&megagroup).await;
            assert_eq!(
                session.peer(PeerId::channel_unchecked(1)).await,
                Some(megagroup)
            );

            let gigagroup = PeerInfo::Channel {
                id: 2,
                auth: Some(PeerAuth::from_hash(42)),
                kind: Some(ChannelKind::Gigagroup),
            };
            session.cache_peer(&gigagroup).await;
            assert_eq!(
                session.peer(PeerId::channel_unchecked(2)).await,
                Some(gigagroup)
            );
        });
    }

    /// Partial update states must create the `update_state` row when none
    /// exists yet, with the fields they don't carry left at 0.
    #[test]
    fn partial_update_state_creates_missing_row() {
        block_on(async {
            let session = SqliteSession::open(":memory:").await.unwrap();
            session
                .set_update_state(UpdateState::Secondary { qts: 7 })
                .await;
            assert_eq!(
                session.updates_state().await,
                UpdatesState {
                    pts: 0,
                    qts: 7,
                    date: 0,
                    seq: 0,
                    channels: Vec::new(),
                }
            );

            let session = SqliteSession::open(":memory:").await.unwrap();
            session
                .set_update_state(UpdateState::Primary {
                    pts: 1,
                    date: 2,
                    seq: 3,
                })
                .await;
            assert_eq!(
                session.updates_state().await,
                UpdatesState {
                    pts: 1,
                    qts: 0,
                    date: 2,
                    seq: 3,
                    channels: Vec::new(),
                }
            );
        });
    }
}