1use std::collections::HashMap;
10use std::fmt;
11use std::net::AddrParseError;
12use std::path::Path;
13use std::sync::Mutex;
14
15use libsql::{named_params, params};
16use tokio::sync::Mutex as AsyncMutex;
17
18use crate::types::{
19 ChannelKind, ChannelState, DcOption, PeerAuth, PeerId, PeerInfo, PeerKind, UpdateState,
20 UpdatesState,
21};
22use crate::{BoxFuture, DEFAULT_DC, KNOWN_DC_OPTIONS, Session};
23
24const VERSION: i64 = 1;
25
26struct Database(libsql::Connection);
27
28struct Cache {
29 pub home_dc: i32,
30 pub dc_options: HashMap<i32, DcOption>,
31}
32
33pub struct SqliteSession {
35 database: AsyncMutex<Database>,
36 cache: Mutex<Cache>,
37}
38
39#[derive(Debug)]
40pub enum SqliteSessionError {
41 Poisoned,
42 AddrParse(std::net::AddrParseError),
43 Sql(libsql::Error),
44 InvalidAuthKeyLength(usize),
45}
46
47impl std::error::Error for SqliteSessionError {}
48
49impl fmt::Display for SqliteSessionError {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 match self {
52 SqliteSessionError::Poisoned => write!(f, "session lock is poisoned"),
53 SqliteSessionError::AddrParse(_) => write!(f, "invalid socket address syntax"),
54 SqliteSessionError::Sql(err) => write!(f, "{err}"),
55 SqliteSessionError::InvalidAuthKeyLength(actual) => {
56 write!(f, "invalid auth_key length: expected 256, got {actual}")
57 }
58 }
59 }
60}
61
62impl From<AddrParseError> for SqliteSessionError {
63 fn from(x: AddrParseError) -> Self {
64 Self::AddrParse(x)
65 }
66}
67
68impl From<libsql::Error> for SqliteSessionError {
69 fn from(x: libsql::Error) -> Self {
70 Self::Sql(x)
71 }
72}
73
/// Bit flags stored in the `subtype` column of `peer_info`.
///
/// `UserSelfBot` is `UserSelf | UserBot` and `Gigagroup` is `Broadcast | Megagroup`, so
/// readers test individual bits (checking the full `Gigagroup` mask before its parts).
#[repr(u8)]
enum PeerSubtype {
    UserSelf = 0b0000_0001,
    UserBot = 0b0000_0010,
    UserSelfBot = 0b0000_0011,
    Megagroup = 0b0000_0100,
    Broadcast = 0b0000_1000,
    Gigagroup = 0b0000_1100,
}
83
84impl Database {
85 async fn init(&self) -> libsql::Result<()> {
86 let mut user_version: i64 = self
87 .fetch_one("PRAGMA user_version", params![], |row| row.get(0))
88 .await?
89 .unwrap_or(0);
90 if user_version == VERSION {
91 return Ok(());
92 }
93
94 if user_version == 0 {
95 self.migrate_v0_to_v1().await?;
96 user_version += 1;
97 }
98 if user_version == VERSION {
99 self.0
101 .execute(&format!("PRAGMA user_version = {VERSION}"), params![])
102 .await?;
103 }
104 Ok(())
105 }
106
107 async fn migrate_v0_to_v1(&self) -> libsql::Result<()> {
108 let transaction = self.begin_transaction().await?;
109 transaction
110 .execute(
111 "CREATE TABLE dc_home (
112 dc_id INTEGER NOT NULL,
113 PRIMARY KEY(dc_id))",
114 params![],
115 )
116 .await?;
117 transaction
118 .execute(
119 "CREATE TABLE dc_option (
120 dc_id INTEGER NOT NULL,
121 ipv4 TEXT NOT NULL,
122 ipv6 TEXT NOT NULL,
123 auth_key BLOB,
124 PRIMARY KEY (dc_id))",
125 params![],
126 )
127 .await?;
128 transaction
129 .execute(
130 "CREATE TABLE peer_info (
131 peer_id INTEGER NOT NULL,
132 hash INTEGER,
133 subtype INTEGER,
134 PRIMARY KEY (peer_id))",
135 params![],
136 )
137 .await?;
138 transaction
139 .execute(
140 "CREATE TABLE update_state (
141 pts INTEGER NOT NULL,
142 qts INTEGER NOT NULL,
143 date INTEGER NOT NULL,
144 seq INTEGER NOT NULL)",
145 params![],
146 )
147 .await?;
148 transaction
149 .execute(
150 "CREATE TABLE channel_state (
151 peer_id INTEGER NOT NULL,
152 pts INTEGER NOT NULL,
153 PRIMARY KEY (peer_id))",
154 params![],
155 )
156 .await?;
157
158 transaction.commit().await?;
159 Ok(())
160 }
161
162 async fn begin_transaction(&self) -> libsql::Result<libsql::Transaction> {
163 self.0.transaction().await
164 }
165
166 async fn fetch_one<
167 T,
168 P: libsql::params::IntoParams,
169 F: FnOnce(libsql::Row) -> libsql::Result<T>,
170 >(
171 &self,
172 statement: &str,
173 params: P,
174 select: F,
175 ) -> libsql::Result<Option<T>> {
176 let mut statement = self.0.prepare(statement).await?;
177 let result = statement.query_row(params).await;
178 match result {
179 Ok(value) => Ok(Some(select(value)?)),
180 Err(libsql::Error::QueryReturnedNoRows) => Ok(None),
181 Err(e) => Err(e),
182 }
183 }
184
185 async fn fetch_all<
186 T,
187 P: libsql::params::IntoParams,
188 F: FnMut(libsql::Row) -> Result<T, SqliteSessionError>,
189 >(
190 &self,
191 statement: &str,
192 params: P,
193 mut select: F,
194 ) -> Result<Vec<T>, SqliteSessionError> {
195 let statement = self.0.prepare(statement).await?;
196 let mut rows = statement.query(params).await?;
197 let mut result = Vec::new();
198 while let Some(row) = rows.next().await? {
199 result.push(select(row)?);
200 }
201 Ok(result)
202 }
203}
204
205impl SqliteSession {
206 pub async fn open<P: AsRef<Path>>(path: P) -> Result<Self, SqliteSessionError> {
209 let conn = libsql::Builder::new_local(path).build().await?.connect()?;
210 let db = Database(conn);
211 db.init().await?;
212
213 let home_dc = db
214 .fetch_one("SELECT * FROM dc_home LIMIT 1", named_params![], |row| {
215 Ok(row.get::<i32>(0)?)
216 })
217 .await?
218 .unwrap_or(DEFAULT_DC);
219
220 let dc_options = db
221 .fetch_all("SELECT * FROM dc_option", named_params![], |row| {
222 Ok(DcOption {
223 id: row.get::<i32>(0)?,
224 ipv4: row.get::<String>(1)?.parse()?,
225 ipv6: row.get::<String>(2)?.parse()?,
226 auth_key: match row.get::<Option<Vec<u8>>>(3)? {
227 None => None,
228 Some(auth_key) => Some(auth_key.try_into().map_err(|v: Vec<u8>| {
229 SqliteSessionError::InvalidAuthKeyLength(v.len())
230 })?),
231 },
232 })
233 })
234 .await?
235 .into_iter()
236 .map(|dc_option| (dc_option.id, dc_option))
237 .collect();
238
239 Ok(SqliteSession {
240 database: AsyncMutex::new(db),
241 cache: Mutex::new(Cache {
242 home_dc,
243 dc_options,
244 }),
245 })
246 }
247}
248
249impl Session for SqliteSession {
250 type Error = SqliteSessionError;
251
252 fn home_dc_id(&self) -> Result<i32, SqliteSessionError> {
253 Ok(self
254 .cache
255 .lock()
256 .map_err(|_| SqliteSessionError::Poisoned)?
257 .home_dc)
258 }
259
260 fn set_home_dc_id(&self, dc_id: i32) -> BoxFuture<'_, Result<(), SqliteSessionError>> {
261 let ok = match self.cache.lock() {
262 Err(_) => Err(SqliteSessionError::Poisoned),
263 Ok(mut x) => {
264 x.home_dc = dc_id;
265 Ok(())
266 }
267 };
268 Box::pin(async move {
269 ok?;
270
271 let transaction = self.database.lock().await.begin_transaction().await?;
272 transaction
273 .execute("DELETE FROM dc_home", params![])
274 .await?;
275 let stmt = transaction
276 .prepare("INSERT INTO dc_home VALUES (:dc_id)")
277 .await?;
278 stmt.execute(named_params! {":dc_id": dc_id}).await?;
279 transaction.commit().await?;
280 Ok(())
281 })
282 }
283
284 fn dc_option(&self, dc_id: i32) -> Result<Option<DcOption>, SqliteSessionError> {
285 Ok(self
286 .cache
287 .lock()
288 .map_err(|_| SqliteSessionError::Poisoned)?
289 .dc_options
290 .get(&dc_id)
291 .cloned()
292 .or_else(|| {
293 KNOWN_DC_OPTIONS
294 .iter()
295 .find(|dc_option| dc_option.id == dc_id)
296 .cloned()
297 }))
298 }
299
300 fn set_dc_option(&self, dc_option: &DcOption) -> BoxFuture<'_, Result<(), SqliteSessionError>> {
301 let ok = match self.cache.lock() {
302 Err(_) => Err(SqliteSessionError::Poisoned),
303 Ok(mut x) => {
304 x.dc_options.insert(dc_option.id, dc_option.clone());
305 Ok(())
306 }
307 };
308
309 let dc_option = dc_option.clone();
310 Box::pin(async move {
311 ok?;
312
313 let db = self.database.lock().await;
314 db.0.execute(
315 "INSERT OR REPLACE INTO dc_option VALUES (:dc_id, :ipv4, :ipv6, :auth_key)",
316 named_params! {
317 ":dc_id": dc_option.id,
318 ":ipv4": dc_option.ipv4.to_string(),
319 ":ipv6": dc_option.ipv6.to_string(),
320 ":auth_key": dc_option.auth_key.map(|k| k.to_vec()),
321 },
322 )
323 .await?;
324 Ok(())
325 })
326 }
327
328 fn peer(&self, peer: PeerId) -> BoxFuture<'_, Result<Option<PeerInfo>, SqliteSessionError>> {
329 Box::pin(async move {
330 let db = self.database.lock().await;
331 let map_row = |row: libsql::Row| {
332 let subtype = row.get::<Option<i64>>(2)?.map(|s| s as u8);
333 Ok(match peer.kind() {
334 PeerKind::User => PeerInfo::User {
335 id: PeerId::user_unchecked(row.get::<i64>(0)?).bare_id_unchecked(),
336 auth: row.get::<Option<i64>>(1)?.map(PeerAuth::from_hash),
337 bot: subtype.map(|s| s & PeerSubtype::UserBot as u8 != 0),
338 is_self: subtype.map(|s| s & PeerSubtype::UserSelf as u8 != 0),
339 },
340 PeerKind::Chat => PeerInfo::Chat {
341 id: peer.bare_id_unchecked(),
342 },
343 PeerKind::Channel => PeerInfo::Channel {
344 id: peer.bare_id_unchecked(),
345 auth: row.get::<Option<i64>>(1)?.map(PeerAuth::from_hash),
346 kind: subtype.and_then(|s| {
347 if (s & PeerSubtype::Gigagroup as u8) == PeerSubtype::Gigagroup as u8 {
348 Some(ChannelKind::Gigagroup)
349 } else if s & PeerSubtype::Broadcast as u8 != 0 {
350 Some(ChannelKind::Broadcast)
351 } else if s & PeerSubtype::Megagroup as u8 != 0 {
352 Some(ChannelKind::Megagroup)
353 } else {
354 None
355 }
356 }),
357 },
358 })
359 };
360
361 Ok(if let Some(peer_id) = peer.bot_api_dialog_id() {
362 db.fetch_one(
363 "SELECT * FROM peer_info WHERE peer_id = :peer_id LIMIT 1",
364 named_params! {":peer_id": peer_id},
365 map_row,
366 )
367 .await?
368 } else {
369 db.fetch_one(
370 "SELECT * FROM peer_info WHERE subtype & :type LIMIT 1",
371 named_params! {":type": PeerSubtype::UserSelf as i64},
372 map_row,
373 )
374 .await?
375 })
376 })
377 }
378
379 fn cache_peer(&self, peer: &PeerInfo) -> BoxFuture<'_, Result<(), SqliteSessionError>> {
380 let peer = peer.clone();
381 Box::pin(async move {
382 let peer = if let Some(mut existing_peer) = self.peer(peer.id()).await? {
383 existing_peer.extend_info(&peer);
384 existing_peer
385 } else {
386 peer
387 };
388
389 let db = self.database.lock().await;
390 let stmt =
391 db.0.prepare("INSERT OR REPLACE INTO peer_info VALUES (:peer_id, :hash, :subtype)")
392 .await?;
393 let subtype = match peer {
394 PeerInfo::User { bot, is_self, .. } => {
395 match (bot.unwrap_or_default(), is_self.unwrap_or_default()) {
396 (true, true) => Some(PeerSubtype::UserSelfBot),
397 (true, false) => Some(PeerSubtype::UserBot),
398 (false, true) => Some(PeerSubtype::UserSelf),
399 (false, false) => None,
400 }
401 }
402 PeerInfo::Chat { .. } => None,
403 PeerInfo::Channel { kind, .. } => kind.map(|kind| match kind {
404 ChannelKind::Megagroup => PeerSubtype::Megagroup,
405 ChannelKind::Broadcast => PeerSubtype::Broadcast,
406 ChannelKind::Gigagroup => PeerSubtype::Gigagroup,
407 }),
408 };
409 let mut params = vec![];
410 let peer_id = peer.id().bot_api_dialog_id_unchecked();
411 params.push((":peer_id".to_owned(), peer_id));
412 let hash = peer.auth().unwrap_or_default().hash();
413 if peer.auth().is_some() {
414 params.push((":hash".to_owned(), hash));
415 }
416 let subtype = subtype.map(|s| s as i64);
417 if subtype.is_some() {
418 params.push((":subtype".to_owned(), subtype.unwrap()));
419 }
420 stmt.execute(params).await?;
421 Ok(())
422 })
423 }
424
425 fn updates_state(&self) -> BoxFuture<'_, Result<UpdatesState, SqliteSessionError>> {
426 Box::pin(async move {
427 let db = self.database.lock().await;
428 let mut state = db
429 .fetch_one(
430 "SELECT * FROM update_state LIMIT 1",
431 named_params![],
432 |row| {
433 Ok(UpdatesState {
434 pts: row.get(0)?,
435 qts: row.get(1)?,
436 date: row.get(2)?,
437 seq: row.get(3)?,
438 channels: Vec::new(),
439 })
440 },
441 )
442 .await?
443 .unwrap_or_default();
444 state.channels = db
445 .fetch_all("SELECT * FROM channel_state", named_params![], |row| {
446 Ok(ChannelState {
447 id: row.get(0)?,
448 pts: row.get(1)?,
449 })
450 })
451 .await?;
452 Ok(state)
453 })
454 }
455
456 fn set_update_state(
457 &self,
458 update: UpdateState,
459 ) -> BoxFuture<'_, Result<(), SqliteSessionError>> {
460 Box::pin(async move {
461 let db = self.database.lock().await;
462 let transaction = db.begin_transaction().await?;
463
464 match update {
465 UpdateState::All(updates_state) => {
466 transaction
467 .execute("DELETE FROM update_state", params![])
468 .await?;
469 transaction
470 .execute(
471 "INSERT INTO update_state VALUES (:pts, :qts, :date, :seq)",
472 named_params! {
473 ":pts": updates_state.pts,
474 ":qts": updates_state.qts,
475 ":date": updates_state.date,
476 ":seq": updates_state.seq,
477 },
478 )
479 .await?;
480
481 transaction
482 .execute("DELETE FROM channel_state", params![])
483 .await?;
484 for channel in updates_state.channels {
485 transaction
486 .execute(
487 "INSERT INTO channel_state VALUES (:peer_id, :pts)",
488 named_params! {
489 ":peer_id": channel.id,
490 ":pts": channel.pts,
491 },
492 )
493 .await?;
494 }
495 }
496 UpdateState::Primary { pts, date, seq } => {
497 let previous = db
498 .fetch_one(
499 "SELECT * FROM update_state LIMIT 1",
500 named_params![],
501 |_| Ok(()),
502 )
503 .await?;
504
505 if previous.is_some() {
506 transaction
507 .execute(
508 "UPDATE update_state SET pts = :pts, date = :date, seq = :seq",
509 named_params! {
510 ":pts": pts,
511 ":date": date,
512 ":seq": seq,
513 },
514 )
515 .await?;
516 } else {
517 transaction
518 .execute(
519 "INSERT INTO update_state VALUES (:pts, 0, :date, :seq)",
520 named_params! {
521 ":pts": pts,
522 ":date": date,
523 ":seq": seq,
524 },
525 )
526 .await?;
527 }
528 }
529 UpdateState::Secondary { qts } => {
530 let previous = db
531 .fetch_one(
532 "SELECT * FROM update_state LIMIT 1",
533 named_params![],
534 |_| Ok(()),
535 )
536 .await?;
537
538 if previous.is_some() {
539 transaction
540 .execute(
541 "UPDATE update_state SET qts = :qts",
542 named_params! {":qts": qts},
543 )
544 .await?;
545 } else {
546 transaction
547 .execute(
548 "INSERT INTO update_state VALUES (0, :qts, 0, 0)",
549 named_params! {":qts": qts},
550 )
551 .await?;
552 }
553 }
554 UpdateState::Channel { id, pts } => {
555 transaction
556 .execute(
557 "INSERT OR REPLACE INTO channel_state VALUES (:peer_id, :pts)",
558 named_params! {
559 ":peer_id": id,
560 ":pts": pts,
561 },
562 )
563 .await?;
564 }
565 }
566
567 transaction.commit().await?;
568 Ok(())
569 })
570 }
571}
572
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    // `DcOption`, `KNOWN_DC_OPTIONS`, `PeerInfo`, `Session`, `UpdateState` and the other
    // session types all come in through the parent module's imports.
    use super::*;

    #[test]
    fn exercise_sqlite_session() {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(do_exercise_sqlite_session());
    }

    /// Walks an in-memory session through every `Session` method and checks the round trips.
    async fn do_exercise_sqlite_session() {
        let session = SqliteSession::open(":memory:").await.unwrap();

        assert_eq!(session.home_dc_id().unwrap(), DEFAULT_DC);
        session.set_home_dc_id(DEFAULT_DC + 1).await.unwrap();
        assert_eq!(session.home_dc_id().unwrap(), DEFAULT_DC + 1);

        assert_eq!(
            session.dc_option(KNOWN_DC_OPTIONS[0].id).unwrap(),
            Some(KNOWN_DC_OPTIONS[0].clone())
        );
        let new_dc_option = DcOption {
            id: KNOWN_DC_OPTIONS
                .iter()
                .map(|dc_option| dc_option.id)
                .max()
                .unwrap()
                + 1,
            ipv4: SocketAddrV4::new(Ipv4Addr::from_bits(0), 1),
            ipv6: SocketAddrV6::new(Ipv6Addr::from_bits(0), 1, 0, 0),
            auth_key: Some([1; 256]),
        };
        assert_eq!(session.dc_option(new_dc_option.id).unwrap(), None);
        session.set_dc_option(&new_dc_option).await.unwrap();
        assert_eq!(
            session.dc_option(new_dc_option.id).unwrap(),
            Some(new_dc_option)
        );

        assert_eq!(session.peer(PeerId::self_user()).await.unwrap(), None);
        assert_eq!(session.peer(PeerId::user_unchecked(1)).await.unwrap(), None);
        let peer = PeerInfo::User {
            id: 1,
            auth: None,
            bot: Some(true),
            is_self: Some(true),
        };
        session.cache_peer(&peer).await.unwrap();
        assert_eq!(
            session.peer(PeerId::self_user()).await.unwrap(),
            Some(peer.clone())
        );
        assert_eq!(
            session.peer(PeerId::user_unchecked(1)).await.unwrap(),
            Some(peer)
        );

        assert_eq!(
            session.peer(PeerId::channel_unchecked(1)).await.unwrap(),
            None
        );
        let peer = PeerInfo::Channel {
            id: 1,
            auth: Some(PeerAuth::from_hash(-1)),
            kind: Some(ChannelKind::Broadcast),
        };
        session.cache_peer(&peer).await.unwrap();
        assert_eq!(
            session.peer(PeerId::channel_unchecked(1)).await.unwrap(),
            Some(peer)
        );

        // A channel with neither access hash nor kind round-trips with those fields unset.
        assert_eq!(
            session.peer(PeerId::channel_unchecked(2)).await.unwrap(),
            None
        );
        let peer = PeerInfo::Channel {
            id: 2,
            auth: None,
            kind: None,
        };
        session.cache_peer(&peer).await.unwrap();
        assert_eq!(
            session.peer(PeerId::channel_unchecked(2)).await.unwrap(),
            Some(peer)
        );

        assert_eq!(
            session.updates_state().await.unwrap(),
            UpdatesState::default()
        );
        session
            .set_update_state(UpdateState::All(UpdatesState {
                pts: 1,
                qts: 2,
                date: 3,
                seq: 4,
                channels: vec![
                    ChannelState { id: 5, pts: 6 },
                    ChannelState { id: 7, pts: 8 },
                ],
            }))
            .await
            .unwrap();
        session
            .set_update_state(UpdateState::Primary {
                pts: 2,
                date: 4,
                seq: 5,
            })
            .await
            .unwrap();
        session
            .set_update_state(UpdateState::Secondary { qts: 3 })
            .await
            .unwrap();
        session
            .set_update_state(UpdateState::Channel { id: 7, pts: 9 })
            .await
            .unwrap();
        assert_eq!(
            session.updates_state().await.unwrap(),
            UpdatesState {
                pts: 2,
                qts: 3,
                date: 4,
                seq: 5,
                channels: vec![
                    ChannelState { id: 5, pts: 6 },
                    ChannelState { id: 7, pts: 9 },
                ],
            }
        );
    }
}