Skip to main content

grammers_session/storages/
sqlite.rs

// Copyright 2020 - developers of the `grammers` project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

9use std::collections::HashMap;
10use std::fmt;
11use std::net::AddrParseError;
12use std::path::Path;
13use std::sync::Mutex;
14
15use libsql::{named_params, params};
16use tokio::sync::Mutex as AsyncMutex;
17
18use crate::types::{
19    ChannelKind, ChannelState, DcOption, PeerAuth, PeerId, PeerInfo, PeerKind, UpdateState,
20    UpdatesState,
21};
22use crate::{BoxFuture, DEFAULT_DC, KNOWN_DC_OPTIONS, Session};
23
24const VERSION: i64 = 1;
25
26struct Database(libsql::Connection);
27
28struct Cache {
29    pub home_dc: i32,
30    pub dc_options: HashMap<i32, DcOption>,
31}
32
33/// SQLite-based storage. This is the recommended option.
34pub struct SqliteSession {
35    database: AsyncMutex<Database>,
36    cache: Mutex<Cache>,
37}
38
39#[derive(Debug)]
40pub enum SqliteSessionError {
41    Poisoned,
42    AddrParse(std::net::AddrParseError),
43    Sql(libsql::Error),
44    InvalidAuthKeyLength(usize),
45}
46
47impl std::error::Error for SqliteSessionError {}
48
49impl fmt::Display for SqliteSessionError {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        match self {
52            SqliteSessionError::Poisoned => write!(f, "session lock is poisoned"),
53            SqliteSessionError::AddrParse(_) => write!(f, "invalid socket address syntax"),
54            SqliteSessionError::Sql(err) => write!(f, "{err}"),
55            SqliteSessionError::InvalidAuthKeyLength(actual) => {
56                write!(f, "invalid auth_key length: expected 256, got {actual}")
57            }
58        }
59    }
60}
61
62impl From<AddrParseError> for SqliteSessionError {
63    fn from(x: AddrParseError) -> Self {
64        Self::AddrParse(x)
65    }
66}
67
68impl From<libsql::Error> for SqliteSessionError {
69    fn from(x: libsql::Error) -> Self {
70        Self::Sql(x)
71    }
72}
73
/// Bit flags stored in the `subtype` column of `peer_info`.
///
/// User flags occupy the low two bits and channel flags the next two; combined
/// variants are the bitwise OR of their parts (`UserSelfBot = UserSelf | UserBot`,
/// `Gigagroup = Megagroup | Broadcast`).
#[repr(u8)]
enum PeerSubtype {
    UserSelf = 0b0001,
    UserBot = 0b0010,
    UserSelfBot = 0b0011,
    Megagroup = 0b0100,
    Broadcast = 0b1000,
    Gigagroup = 0b1100,
}
83
84impl Database {
85    async fn init(&self) -> libsql::Result<()> {
86        let mut user_version: i64 = self
87            .fetch_one("PRAGMA user_version", params![], |row| row.get(0))
88            .await?
89            .unwrap_or(0);
90        if user_version == VERSION {
91            return Ok(());
92        }
93
94        if user_version == 0 {
95            self.migrate_v0_to_v1().await?;
96            user_version += 1;
97        }
98        if user_version == VERSION {
99            // Can't bind PRAGMA parameters, but `VERSION` is not user-controlled input.
100            self.0
101                .execute(&format!("PRAGMA user_version = {VERSION}"), params![])
102                .await?;
103        }
104        Ok(())
105    }
106
107    async fn migrate_v0_to_v1(&self) -> libsql::Result<()> {
108        let transaction = self.begin_transaction().await?;
109        transaction
110            .execute(
111                "CREATE TABLE dc_home (
112                dc_id INTEGER NOT NULL,
113                PRIMARY KEY(dc_id))",
114                params![],
115            )
116            .await?;
117        transaction
118            .execute(
119                "CREATE TABLE dc_option (
120                dc_id INTEGER NOT NULL,
121                ipv4 TEXT NOT NULL,
122                ipv6 TEXT NOT NULL,
123                auth_key BLOB,
124                PRIMARY KEY (dc_id))",
125                params![],
126            )
127            .await?;
128        transaction
129            .execute(
130                "CREATE TABLE peer_info (
131                peer_id INTEGER NOT NULL,
132                hash INTEGER,
133                subtype INTEGER,
134                PRIMARY KEY (peer_id))",
135                params![],
136            )
137            .await?;
138        transaction
139            .execute(
140                "CREATE TABLE update_state (
141                pts INTEGER NOT NULL,
142                qts INTEGER NOT NULL,
143                date INTEGER NOT NULL,
144                seq INTEGER NOT NULL)",
145                params![],
146            )
147            .await?;
148        transaction
149            .execute(
150                "CREATE TABLE channel_state (
151                peer_id INTEGER NOT NULL,
152                pts INTEGER NOT NULL,
153                PRIMARY KEY (peer_id))",
154                params![],
155            )
156            .await?;
157
158        transaction.commit().await?;
159        Ok(())
160    }
161
162    async fn begin_transaction(&self) -> libsql::Result<libsql::Transaction> {
163        self.0.transaction().await
164    }
165
166    async fn fetch_one<
167        T,
168        P: libsql::params::IntoParams,
169        F: FnOnce(libsql::Row) -> libsql::Result<T>,
170    >(
171        &self,
172        statement: &str,
173        params: P,
174        select: F,
175    ) -> libsql::Result<Option<T>> {
176        let mut statement = self.0.prepare(statement).await?;
177        let result = statement.query_row(params).await;
178        match result {
179            Ok(value) => Ok(Some(select(value)?)),
180            Err(libsql::Error::QueryReturnedNoRows) => Ok(None),
181            Err(e) => Err(e),
182        }
183    }
184
185    async fn fetch_all<
186        T,
187        P: libsql::params::IntoParams,
188        F: FnMut(libsql::Row) -> Result<T, SqliteSessionError>,
189    >(
190        &self,
191        statement: &str,
192        params: P,
193        mut select: F,
194    ) -> Result<Vec<T>, SqliteSessionError> {
195        let statement = self.0.prepare(statement).await?;
196        let mut rows = statement.query(params).await?;
197        let mut result = Vec::new();
198        while let Some(row) = rows.next().await? {
199            result.push(select(row)?);
200        }
201        Ok(result)
202    }
203}
204
205impl SqliteSession {
206    /// Open a connection to the SQLite database at `path`,
207    /// creating one if it doesn't exist.
208    pub async fn open<P: AsRef<Path>>(path: P) -> Result<Self, SqliteSessionError> {
209        let conn = libsql::Builder::new_local(path).build().await?.connect()?;
210        let db = Database(conn);
211        db.init().await?;
212
213        let home_dc = db
214            .fetch_one("SELECT * FROM dc_home LIMIT 1", named_params![], |row| {
215                Ok(row.get::<i32>(0)?)
216            })
217            .await?
218            .unwrap_or(DEFAULT_DC);
219
220        let dc_options = db
221            .fetch_all("SELECT * FROM dc_option", named_params![], |row| {
222                Ok(DcOption {
223                    id: row.get::<i32>(0)?,
224                    ipv4: row.get::<String>(1)?.parse()?,
225                    ipv6: row.get::<String>(2)?.parse()?,
226                    auth_key: match row.get::<Option<Vec<u8>>>(3)? {
227                        None => None,
228                        Some(auth_key) => Some(auth_key.try_into().map_err(|v: Vec<u8>| {
229                            SqliteSessionError::InvalidAuthKeyLength(v.len())
230                        })?),
231                    },
232                })
233            })
234            .await?
235            .into_iter()
236            .map(|dc_option| (dc_option.id, dc_option))
237            .collect();
238
239        Ok(SqliteSession {
240            database: AsyncMutex::new(db),
241            cache: Mutex::new(Cache {
242                home_dc,
243                dc_options,
244            }),
245        })
246    }
247}
248
249impl Session for SqliteSession {
250    type Error = SqliteSessionError;
251
252    fn home_dc_id(&self) -> Result<i32, SqliteSessionError> {
253        Ok(self
254            .cache
255            .lock()
256            .map_err(|_| SqliteSessionError::Poisoned)?
257            .home_dc)
258    }
259
260    fn set_home_dc_id(&self, dc_id: i32) -> BoxFuture<'_, Result<(), SqliteSessionError>> {
261        let ok = match self.cache.lock() {
262            Err(_) => Err(SqliteSessionError::Poisoned),
263            Ok(mut x) => {
264                x.home_dc = dc_id;
265                Ok(())
266            }
267        };
268        Box::pin(async move {
269            ok?;
270
271            let transaction = self.database.lock().await.begin_transaction().await?;
272            transaction
273                .execute("DELETE FROM dc_home", params![])
274                .await?;
275            let stmt = transaction
276                .prepare("INSERT INTO dc_home VALUES (:dc_id)")
277                .await?;
278            stmt.execute(named_params! {":dc_id": dc_id}).await?;
279            transaction.commit().await?;
280            Ok(())
281        })
282    }
283
284    fn dc_option(&self, dc_id: i32) -> Result<Option<DcOption>, SqliteSessionError> {
285        Ok(self
286            .cache
287            .lock()
288            .map_err(|_| SqliteSessionError::Poisoned)?
289            .dc_options
290            .get(&dc_id)
291            .cloned()
292            .or_else(|| {
293                KNOWN_DC_OPTIONS
294                    .iter()
295                    .find(|dc_option| dc_option.id == dc_id)
296                    .cloned()
297            }))
298    }
299
300    fn set_dc_option(&self, dc_option: &DcOption) -> BoxFuture<'_, Result<(), SqliteSessionError>> {
301        let ok = match self.cache.lock() {
302            Err(_) => Err(SqliteSessionError::Poisoned),
303            Ok(mut x) => {
304                x.dc_options.insert(dc_option.id, dc_option.clone());
305                Ok(())
306            }
307        };
308
309        let dc_option = dc_option.clone();
310        Box::pin(async move {
311            ok?;
312
313            let db = self.database.lock().await;
314            db.0.execute(
315                "INSERT OR REPLACE INTO dc_option VALUES (:dc_id, :ipv4, :ipv6, :auth_key)",
316                named_params! {
317                    ":dc_id": dc_option.id,
318                    ":ipv4": dc_option.ipv4.to_string(),
319                    ":ipv6": dc_option.ipv6.to_string(),
320                    ":auth_key": dc_option.auth_key.map(|k| k.to_vec()),
321                },
322            )
323            .await?;
324            Ok(())
325        })
326    }
327
328    fn peer(&self, peer: PeerId) -> BoxFuture<'_, Result<Option<PeerInfo>, SqliteSessionError>> {
329        Box::pin(async move {
330            let db = self.database.lock().await;
331            let map_row = |row: libsql::Row| {
332                let subtype = row.get::<Option<i64>>(2)?.map(|s| s as u8);
333                Ok(match peer.kind() {
334                    PeerKind::User => PeerInfo::User {
335                        id: PeerId::user_unchecked(row.get::<i64>(0)?).bare_id_unchecked(),
336                        auth: row.get::<Option<i64>>(1)?.map(PeerAuth::from_hash),
337                        bot: subtype.map(|s| s & PeerSubtype::UserBot as u8 != 0),
338                        is_self: subtype.map(|s| s & PeerSubtype::UserSelf as u8 != 0),
339                    },
340                    PeerKind::Chat => PeerInfo::Chat {
341                        id: peer.bare_id_unchecked(),
342                    },
343                    PeerKind::Channel => PeerInfo::Channel {
344                        id: peer.bare_id_unchecked(),
345                        auth: row.get::<Option<i64>>(1)?.map(PeerAuth::from_hash),
346                        kind: subtype.and_then(|s| {
347                            if (s & PeerSubtype::Gigagroup as u8) == PeerSubtype::Gigagroup as u8 {
348                                Some(ChannelKind::Gigagroup)
349                            } else if s & PeerSubtype::Broadcast as u8 != 0 {
350                                Some(ChannelKind::Broadcast)
351                            } else if s & PeerSubtype::Megagroup as u8 != 0 {
352                                Some(ChannelKind::Megagroup)
353                            } else {
354                                None
355                            }
356                        }),
357                    },
358                })
359            };
360
361            Ok(if let Some(peer_id) = peer.bot_api_dialog_id() {
362                db.fetch_one(
363                    "SELECT * FROM peer_info WHERE peer_id = :peer_id LIMIT 1",
364                    named_params! {":peer_id": peer_id},
365                    map_row,
366                )
367                .await?
368            } else {
369                db.fetch_one(
370                    "SELECT * FROM peer_info WHERE subtype & :type LIMIT 1",
371                    named_params! {":type": PeerSubtype::UserSelf as i64},
372                    map_row,
373                )
374                .await?
375            })
376        })
377    }
378
379    fn cache_peer(&self, peer: &PeerInfo) -> BoxFuture<'_, Result<(), SqliteSessionError>> {
380        let peer = peer.clone();
381        Box::pin(async move {
382            let peer = if let Some(mut existing_peer) = self.peer(peer.id()).await? {
383                existing_peer.extend_info(&peer);
384                existing_peer
385            } else {
386                peer
387            };
388
389            let db = self.database.lock().await;
390            let stmt =
391                db.0.prepare("INSERT OR REPLACE INTO peer_info VALUES (:peer_id, :hash, :subtype)")
392                    .await?;
393            let subtype = match peer {
394                PeerInfo::User { bot, is_self, .. } => {
395                    match (bot.unwrap_or_default(), is_self.unwrap_or_default()) {
396                        (true, true) => Some(PeerSubtype::UserSelfBot),
397                        (true, false) => Some(PeerSubtype::UserBot),
398                        (false, true) => Some(PeerSubtype::UserSelf),
399                        (false, false) => None,
400                    }
401                }
402                PeerInfo::Chat { .. } => None,
403                PeerInfo::Channel { kind, .. } => kind.map(|kind| match kind {
404                    ChannelKind::Megagroup => PeerSubtype::Megagroup,
405                    ChannelKind::Broadcast => PeerSubtype::Broadcast,
406                    ChannelKind::Gigagroup => PeerSubtype::Gigagroup,
407                }),
408            };
409            let mut params = vec![];
410            let peer_id = peer.id().bot_api_dialog_id_unchecked();
411            params.push((":peer_id".to_owned(), peer_id));
412            let hash = peer.auth().unwrap_or_default().hash();
413            if peer.auth().is_some() {
414                params.push((":hash".to_owned(), hash));
415            }
416            let subtype = subtype.map(|s| s as i64);
417            if subtype.is_some() {
418                params.push((":subtype".to_owned(), subtype.unwrap()));
419            }
420            stmt.execute(params).await?;
421            Ok(())
422        })
423    }
424
425    fn updates_state(&self) -> BoxFuture<'_, Result<UpdatesState, SqliteSessionError>> {
426        Box::pin(async move {
427            let db = self.database.lock().await;
428            let mut state = db
429                .fetch_one(
430                    "SELECT * FROM update_state LIMIT 1",
431                    named_params![],
432                    |row| {
433                        Ok(UpdatesState {
434                            pts: row.get(0)?,
435                            qts: row.get(1)?,
436                            date: row.get(2)?,
437                            seq: row.get(3)?,
438                            channels: Vec::new(),
439                        })
440                    },
441                )
442                .await?
443                .unwrap_or_default();
444            state.channels = db
445                .fetch_all("SELECT * FROM channel_state", named_params![], |row| {
446                    Ok(ChannelState {
447                        id: row.get(0)?,
448                        pts: row.get(1)?,
449                    })
450                })
451                .await?;
452            Ok(state)
453        })
454    }
455
456    fn set_update_state(
457        &self,
458        update: UpdateState,
459    ) -> BoxFuture<'_, Result<(), SqliteSessionError>> {
460        Box::pin(async move {
461            let db = self.database.lock().await;
462            let transaction = db.begin_transaction().await?;
463
464            match update {
465                UpdateState::All(updates_state) => {
466                    transaction
467                        .execute("DELETE FROM update_state", params![])
468                        .await?;
469                    transaction
470                        .execute(
471                            "INSERT INTO update_state VALUES (:pts, :qts, :date, :seq)",
472                            named_params! {
473                                ":pts": updates_state.pts,
474                                ":qts": updates_state.qts,
475                                ":date": updates_state.date,
476                                ":seq": updates_state.seq,
477                            },
478                        )
479                        .await?;
480
481                    transaction
482                        .execute("DELETE FROM channel_state", params![])
483                        .await?;
484                    for channel in updates_state.channels {
485                        transaction
486                            .execute(
487                                "INSERT INTO channel_state VALUES (:peer_id, :pts)",
488                                named_params! {
489                                    ":peer_id": channel.id,
490                                    ":pts": channel.pts,
491                                },
492                            )
493                            .await?;
494                    }
495                }
496                UpdateState::Primary { pts, date, seq } => {
497                    let previous = db
498                        .fetch_one(
499                            "SELECT * FROM update_state LIMIT 1",
500                            named_params![],
501                            |_| Ok(()),
502                        )
503                        .await?;
504
505                    if previous.is_some() {
506                        transaction
507                            .execute(
508                                "UPDATE update_state SET pts = :pts, date = :date, seq = :seq",
509                                named_params! {
510                                    ":pts": pts,
511                                    ":date": date,
512                                    ":seq": seq,
513                                },
514                            )
515                            .await?;
516                    } else {
517                        transaction
518                            .execute(
519                                "INSERT INTO update_state VALUES (:pts, 0, :date, :seq)",
520                                named_params! {
521                                    ":pts": pts,
522                                    ":date": date,
523                                    ":seq": seq,
524                                },
525                            )
526                            .await?;
527                    }
528                }
529                UpdateState::Secondary { qts } => {
530                    let previous = db
531                        .fetch_one(
532                            "SELECT * FROM update_state LIMIT 1",
533                            named_params![],
534                            |_| Ok(()),
535                        )
536                        .await?;
537
538                    if previous.is_some() {
539                        transaction
540                            .execute(
541                                "UPDATE update_state SET qts = :qts",
542                                named_params! {":qts": qts},
543                            )
544                            .await?;
545                    } else {
546                        transaction
547                            .execute(
548                                "INSERT INTO update_state VALUES (0, :qts, 0, 0)",
549                                named_params! {":qts": qts},
550                            )
551                            .await?;
552                    }
553                }
554                UpdateState::Channel { id, pts } => {
555                    transaction
556                        .execute(
557                            "INSERT OR REPLACE INTO channel_state VALUES (:peer_id, :pts)",
558                            named_params! {
559                                ":peer_id": id,
560                                ":pts": pts,
561                            },
562                        )
563                        .await?;
564                }
565            }
566
567            transaction.commit().await?;
568            Ok(())
569        })
570    }
571}
572
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    // All session types and constants (`DcOption`, `KNOWN_DC_OPTIONS`, `PeerInfo`,
    // `Session`, `UpdateState`, `DEFAULT_DC`, ...) come from the parent module's imports.
    use super::*;

    /// Drive the async exercise on a single-threaded Tokio runtime.
    #[test]
    fn exercise_sqlite_session() {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(do_exercise_sqlite_session());
    }

    /// Round-trip every `Session` operation against an in-memory database.
    async fn do_exercise_sqlite_session() {
        let session = SqliteSession::open(":memory:").await.unwrap();

        assert_eq!(session.home_dc_id().unwrap(), DEFAULT_DC);
        session.set_home_dc_id(DEFAULT_DC + 1).await.unwrap();
        assert_eq!(session.home_dc_id().unwrap(), DEFAULT_DC + 1);

        assert_eq!(
            session.dc_option(KNOWN_DC_OPTIONS[0].id).unwrap(),
            Some(KNOWN_DC_OPTIONS[0].clone())
        );
        let new_dc_option = DcOption {
            id: KNOWN_DC_OPTIONS
                .iter()
                .map(|dc_option| dc_option.id)
                .max()
                .unwrap()
                + 1,
            ipv4: SocketAddrV4::new(Ipv4Addr::from_bits(0), 1),
            ipv6: SocketAddrV6::new(Ipv6Addr::from_bits(0), 1, 0, 0),
            auth_key: Some([1; 256]),
        };
        assert_eq!(session.dc_option(new_dc_option.id).unwrap(), None);
        session.set_dc_option(&new_dc_option).await.unwrap();
        assert_eq!(
            session.dc_option(new_dc_option.id).unwrap(),
            Some(new_dc_option)
        );

        assert_eq!(session.peer(PeerId::self_user()).await.unwrap(), None);
        assert_eq!(session.peer(PeerId::user_unchecked(1)).await.unwrap(), None);
        let peer = PeerInfo::User {
            id: 1,
            auth: None,
            bot: Some(true),
            is_self: Some(true),
        };
        session.cache_peer(&peer).await.unwrap();
        assert_eq!(
            session.peer(PeerId::self_user()).await.unwrap(),
            Some(peer.clone())
        );
        assert_eq!(
            session.peer(PeerId::user_unchecked(1)).await.unwrap(),
            Some(peer)
        );

        assert_eq!(
            session.peer(PeerId::channel_unchecked(1)).await.unwrap(),
            None
        );
        let peer = PeerInfo::Channel {
            id: 1,
            auth: Some(PeerAuth::from_hash(-1)),
            kind: Some(ChannelKind::Broadcast),
        };
        session.cache_peer(&peer).await.unwrap();
        assert_eq!(
            session.peer(PeerId::channel_unchecked(1)).await.unwrap(),
            Some(peer)
        );

        assert_eq!(
            session.updates_state().await.unwrap(),
            UpdatesState::default()
        );
        session
            .set_update_state(UpdateState::All(UpdatesState {
                pts: 1,
                qts: 2,
                date: 3,
                seq: 4,
                channels: vec![
                    ChannelState { id: 5, pts: 6 },
                    ChannelState { id: 7, pts: 8 },
                ],
            }))
            .await
            .unwrap();
        session
            .set_update_state(UpdateState::Primary {
                pts: 2,
                date: 4,
                seq: 5,
            })
            .await
            .unwrap();
        session
            .set_update_state(UpdateState::Secondary { qts: 3 })
            .await
            .unwrap();
        session
            .set_update_state(UpdateState::Channel { id: 7, pts: 9 })
            .await
            .unwrap();
        assert_eq!(
            session.updates_state().await.unwrap(),
            UpdatesState {
                pts: 2,
                qts: 3,
                date: 4,
                seq: 5,
                channels: vec![
                    ChannelState { id: 5, pts: 6 },
                    ChannelState { id: 7, pts: 9 },
                ],
            }
        );
    }
}