1use std::collections::HashMap;
10use std::path::Path;
11use std::sync::Mutex;
12
13use futures_core::future::BoxFuture;
14use libsql::{named_params, params};
15use tokio::sync::Mutex as AsyncMutex;
16
17use crate::types::{
18 ChannelKind, ChannelState, DcOption, PeerAuth, PeerId, PeerInfo, PeerKind, UpdateState,
19 UpdatesState,
20};
21use crate::{DEFAULT_DC, KNOWN_DC_OPTIONS, Session};
22
/// Schema version this module migrates databases to; it is recorded in (and
/// compared against) SQLite's `PRAGMA user_version`.
const VERSION: i64 = 1;
24
/// Wrapper around a single libsql connection that adds schema migration and
/// small query helpers (`fetch_one` / `fetch_all`).
struct Database(libsql::Connection);
26
/// In-memory copy of the values read by the synchronous `Session` getters
/// (`home_dc_id`, `dc_option`), so those never have to query the database.
struct Cache {
    // Home datacenter id; `DEFAULT_DC` when none has been stored yet.
    pub home_dc: i32,
    // Stored datacenter options, keyed by datacenter id.
    pub dc_options: HashMap<i32, DcOption>,
}
31
/// [`Session`] implementation persisted in a local SQLite database through libsql.
///
/// Every write goes to the database. The home DC and DC options are also
/// mirrored in memory so they can be read synchronously.
pub struct SqliteSession {
    // Async mutex: its guard is held across `.await` points while queries run.
    database: AsyncMutex<Database>,
    // Sync mutex: locked only briefly, from synchronous code.
    cache: Mutex<Cache>,
}
37
/// Bit flags stored in the `subtype` column of `peer_info`.
///
/// User flags occupy the low two bits and channel flags the next two, so the
/// combined variants are simply the union of their parts:
/// `UserSelfBot = UserSelf | UserBot` and `Gigagroup = Megagroup | Broadcast`.
#[repr(u8)]
enum PeerSubtype {
    UserSelf = 0b0001,
    UserBot = 0b0010,
    UserSelfBot = 0b0011,
    Megagroup = 0b0100,
    Broadcast = 0b1000,
    Gigagroup = 0b1100,
}
47
48impl Database {
49 async fn init(&self) -> libsql::Result<()> {
50 let mut user_version: i64 = self
51 .fetch_one("PRAGMA user_version", params![], |row| row.get(0))
52 .await?
53 .unwrap_or(0);
54 if user_version == VERSION {
55 return Ok(());
56 }
57
58 if user_version == 0 {
59 self.migrate_v0_to_v1().await?;
60 user_version += 1;
61 }
62 if user_version == VERSION {
63 self.0
65 .execute(&format!("PRAGMA user_version = {VERSION}"), params![])
66 .await?;
67 }
68 Ok(())
69 }
70
71 async fn migrate_v0_to_v1(&self) -> libsql::Result<()> {
72 let transaction = self.begin_transaction().await?;
73 transaction
74 .execute(
75 "CREATE TABLE dc_home (
76 dc_id INTEGER NOT NULL,
77 PRIMARY KEY(dc_id))",
78 params![],
79 )
80 .await?;
81 transaction
82 .execute(
83 "CREATE TABLE dc_option (
84 dc_id INTEGER NOT NULL,
85 ipv4 TEXT NOT NULL,
86 ipv6 TEXT NOT NULL,
87 auth_key BLOB,
88 PRIMARY KEY (dc_id))",
89 params![],
90 )
91 .await?;
92 transaction
93 .execute(
94 "CREATE TABLE peer_info (
95 peer_id INTEGER NOT NULL,
96 hash INTEGER,
97 subtype INTEGER,
98 PRIMARY KEY (peer_id))",
99 params![],
100 )
101 .await?;
102 transaction
103 .execute(
104 "CREATE TABLE update_state (
105 pts INTEGER NOT NULL,
106 qts INTEGER NOT NULL,
107 date INTEGER NOT NULL,
108 seq INTEGER NOT NULL)",
109 params![],
110 )
111 .await?;
112 transaction
113 .execute(
114 "CREATE TABLE channel_state (
115 peer_id INTEGER NOT NULL,
116 pts INTEGER NOT NULL,
117 PRIMARY KEY (peer_id))",
118 params![],
119 )
120 .await?;
121
122 transaction.commit().await?;
123 Ok(())
124 }
125
126 async fn begin_transaction(&self) -> libsql::Result<libsql::Transaction> {
127 self.0.transaction().await
128 }
129
130 async fn fetch_one<
131 T,
132 P: libsql::params::IntoParams,
133 F: FnOnce(libsql::Row) -> libsql::Result<T>,
134 >(
135 &self,
136 statement: &str,
137 params: P,
138 select: F,
139 ) -> libsql::Result<Option<T>> {
140 let mut statement = self.0.prepare(statement).await?;
141 let result = statement.query_row(params).await;
142 match result {
143 Ok(value) => Ok(Some(select(value)?)),
144 Err(libsql::Error::QueryReturnedNoRows) => Ok(None),
145 Err(e) => Err(e),
146 }
147 }
148
149 async fn fetch_all<
150 T,
151 P: libsql::params::IntoParams,
152 F: FnMut(libsql::Row) -> libsql::Result<T>,
153 >(
154 &self,
155 statement: &str,
156 params: P,
157 mut select: F,
158 ) -> libsql::Result<Vec<T>> {
159 let statement = self.0.prepare(statement).await?;
160 let mut rows = statement.query(params).await?;
161 let mut result = Vec::new();
162 while let Some(row) = rows.next().await? {
163 result.push(select(row)?);
164 }
165 Ok(result)
166 }
167}
168
169impl SqliteSession {
170 pub async fn open<P: AsRef<Path>>(path: P) -> libsql::Result<Self> {
173 let conn = libsql::Builder::new_local(path).build().await?.connect()?;
174 let db = Database(conn);
175 db.init().await?;
176
177 let home_dc = db
178 .fetch_one("SELECT * FROM dc_home LIMIT 1", named_params![], |row| {
179 Ok(row.get::<i32>(0)?)
180 })
181 .await?
182 .unwrap_or(DEFAULT_DC);
183
184 let dc_options = db
185 .fetch_all("SELECT * FROM dc_option", named_params![], |row| {
186 Ok(DcOption {
187 id: row.get::<i32>(0)?,
188 ipv4: row.get::<String>(1)?.parse().unwrap(),
189 ipv6: row.get::<String>(2)?.parse().unwrap(),
190 auth_key: row
191 .get::<Option<Vec<u8>>>(3)?
192 .map(|auth_key| auth_key.try_into().unwrap()),
193 })
194 })
195 .await?
196 .into_iter()
197 .map(|dc_option| (dc_option.id, dc_option))
198 .collect();
199
200 Ok(SqliteSession {
201 database: AsyncMutex::new(db),
202 cache: Mutex::new(Cache {
203 home_dc,
204 dc_options,
205 }),
206 })
207 }
208}
209
210impl Session for SqliteSession {
211 fn home_dc_id(&self) -> i32 {
212 self.cache.lock().unwrap().home_dc
213 }
214
215 fn set_home_dc_id(&self, dc_id: i32) -> BoxFuture<'_, ()> {
216 self.cache.lock().unwrap().home_dc = dc_id;
217 Box::pin(async move {
218 let transaction = self
219 .database
220 .lock()
221 .await
222 .begin_transaction()
223 .await
224 .unwrap();
225 transaction
226 .execute("DELETE FROM dc_home", params![])
227 .await
228 .unwrap();
229 let stmt = transaction
230 .prepare("INSERT INTO dc_home VALUES (:dc_id)")
231 .await
232 .unwrap();
233 stmt.execute(named_params! {":dc_id": dc_id}).await.unwrap();
234 transaction.commit().await.unwrap();
235 })
236 }
237
238 fn dc_option(&self, dc_id: i32) -> Option<DcOption> {
239 self.cache
240 .lock()
241 .unwrap()
242 .dc_options
243 .get(&dc_id)
244 .cloned()
245 .or_else(|| {
246 KNOWN_DC_OPTIONS
247 .iter()
248 .find(|dc_option| dc_option.id == dc_id)
249 .cloned()
250 })
251 }
252
253 fn set_dc_option(&self, dc_option: &DcOption) -> BoxFuture<'_, ()> {
254 self.cache
255 .lock()
256 .unwrap()
257 .dc_options
258 .insert(dc_option.id, dc_option.clone());
259
260 let dc_option = dc_option.clone();
261 Box::pin(async move {
262 let db = self.database.lock().await;
263 db.0.execute(
264 "INSERT OR REPLACE INTO dc_option VALUES (:dc_id, :ipv4, :ipv6, :auth_key)",
265 named_params! {
266 ":dc_id": dc_option.id,
267 ":ipv4": dc_option.ipv4.to_string(),
268 ":ipv6": dc_option.ipv6.to_string(),
269 ":auth_key": dc_option.auth_key.map(|k| k.to_vec()),
270 },
271 )
272 .await
273 .unwrap();
274 })
275 }
276
277 fn peer(&self, peer: PeerId) -> BoxFuture<'_, Option<PeerInfo>> {
278 Box::pin(async move {
279 let db = self.database.lock().await;
280 let map_row = |row: libsql::Row| {
281 let subtype = row.get::<Option<i64>>(2)?.map(|s| s as u8);
282 Ok(match peer.kind() {
283 PeerKind::User | PeerKind::UserSelf => PeerInfo::User {
284 id: PeerId::user_unchecked(row.get::<i64>(0)?).bare_id(),
285 auth: row.get::<Option<i64>>(1)?.map(PeerAuth::from_hash),
286 bot: subtype.map(|s| s & PeerSubtype::UserBot as u8 != 0),
287 is_self: subtype.map(|s| s & PeerSubtype::UserSelf as u8 != 0),
288 },
289 PeerKind::Chat => PeerInfo::Chat { id: peer.bare_id() },
290 PeerKind::Channel => PeerInfo::Channel {
291 id: peer.bare_id(),
292 auth: row.get::<Option<i64>>(1)?.map(PeerAuth::from_hash),
293 kind: subtype.and_then(|s| {
294 if (s & PeerSubtype::Gigagroup as u8) == PeerSubtype::Gigagroup as u8 {
295 Some(ChannelKind::Gigagroup)
296 } else if s & PeerSubtype::Broadcast as u8 != 0 {
297 Some(ChannelKind::Broadcast)
298 } else if s & PeerSubtype::Megagroup as u8 != 0 {
299 Some(ChannelKind::Megagroup)
300 } else {
301 None
302 }
303 }),
304 },
305 })
306 };
307
308 if peer.kind() == PeerKind::UserSelf {
309 db.fetch_one(
310 "SELECT * FROM peer_info WHERE subtype & :type LIMIT 1",
311 named_params! {":type": PeerSubtype::UserSelf as i64},
312 map_row,
313 )
314 .await
315 .unwrap()
316 } else {
317 db.fetch_one(
318 "SELECT * FROM peer_info WHERE peer_id = :peer_id LIMIT 1",
319 named_params! {":peer_id": peer.bot_api_dialog_id()},
320 map_row,
321 )
322 .await
323 .unwrap()
324 }
325 })
326 }
327
328 fn cache_peer(&self, peer: &PeerInfo) -> BoxFuture<'_, ()> {
329 let peer = peer.clone();
330 Box::pin(async move {
331 let db = self.database.lock().await;
332 let stmt =
333 db.0.prepare("INSERT OR REPLACE INTO peer_info VALUES (:peer_id, :hash, :subtype)")
334 .await
335 .unwrap();
336 let subtype = match peer {
337 PeerInfo::User { bot, is_self, .. } => {
338 match (bot.unwrap_or_default(), is_self.unwrap_or_default()) {
339 (true, true) => Some(PeerSubtype::UserSelfBot),
340 (true, false) => Some(PeerSubtype::UserBot),
341 (false, true) => Some(PeerSubtype::UserSelf),
342 (false, false) => None,
343 }
344 }
345 PeerInfo::Chat { .. } => None,
346 PeerInfo::Channel { kind, .. } => kind.map(|kind| match kind {
347 ChannelKind::Megagroup => PeerSubtype::Megagroup,
348 ChannelKind::Broadcast => PeerSubtype::Broadcast,
349 ChannelKind::Gigagroup => PeerSubtype::Gigagroup,
350 }),
351 };
352 let mut params = vec![];
353 let peer_id = peer.id().bot_api_dialog_id();
354 params.push((":peer_id".to_owned(), peer_id));
355 let hash = peer.auth().unwrap_or_default().hash();
356 if peer.auth().is_some() {
357 params.push((":hash".to_owned(), hash));
358 }
359 let subtype = subtype.map(|s| s as i64);
360 if subtype.is_some() {
361 params.push((":subtype".to_owned(), subtype.unwrap()));
362 }
363 stmt.execute(params).await.unwrap();
364 })
365 }
366
367 fn updates_state(&self) -> BoxFuture<'_, UpdatesState> {
368 Box::pin(async move {
369 let db = self.database.lock().await;
370 let mut state = db
371 .fetch_one(
372 "SELECT * FROM update_state LIMIT 1",
373 named_params![],
374 |row| {
375 Ok(UpdatesState {
376 pts: row.get(0)?,
377 qts: row.get(1)?,
378 date: row.get(2)?,
379 seq: row.get(3)?,
380 channels: Vec::new(),
381 })
382 },
383 )
384 .await
385 .unwrap()
386 .unwrap_or_default();
387 state.channels = db
388 .fetch_all("SELECT * FROM channel_state", named_params![], |row| {
389 Ok(ChannelState {
390 id: row.get(0)?,
391 pts: row.get(1)?,
392 })
393 })
394 .await
395 .unwrap();
396 state
397 })
398 }
399
400 fn set_update_state(&self, update: UpdateState) -> BoxFuture<'_, ()> {
401 Box::pin(async move {
402 let db = self.database.lock().await;
403 let transaction = db.begin_transaction().await.unwrap();
404
405 match update {
406 UpdateState::All(updates_state) => {
407 transaction
408 .execute("DELETE FROM update_state", params![])
409 .await
410 .unwrap();
411 transaction
412 .execute(
413 "INSERT INTO update_state VALUES (:pts, :qts, :date, :seq)",
414 named_params! {
415 ":pts": updates_state.pts,
416 ":qts": updates_state.qts,
417 ":date": updates_state.date,
418 ":seq": updates_state.seq,
419 },
420 )
421 .await
422 .unwrap();
423
424 transaction
425 .execute("DELETE FROM channel_state", params![])
426 .await
427 .unwrap();
428 for channel in updates_state.channels {
429 transaction
430 .execute(
431 "INSERT INTO channel_state VALUES (:peer_id, :pts)",
432 named_params! {
433 ":peer_id": channel.id,
434 ":pts": channel.pts,
435 },
436 )
437 .await
438 .unwrap();
439 }
440 }
441 UpdateState::Primary { pts, date, seq } => {
442 let previous = db
443 .fetch_one(
444 "SELECT * FROM update_state LIMIT 1",
445 named_params![],
446 |_| Ok(()),
447 )
448 .await
449 .unwrap();
450
451 if previous.is_some() {
452 transaction
453 .execute(
454 "UPDATE update_state SET pts = :pts, date = :date, seq = :seq",
455 named_params! {
456 ":pts": pts,
457 ":date": date,
458 ":seq": seq,
459 },
460 )
461 .await
462 .unwrap();
463 } else {
464 transaction
465 .execute(
466 "INSERT INTO update_state VALUES (:pts, 0, :date, :seq)",
467 named_params! {
468 ":pts": pts,
469 ":date": date,
470 ":seq": seq,
471 },
472 )
473 .await
474 .unwrap();
475 }
476 }
477 UpdateState::Secondary { qts } => {
478 let previous = db
479 .fetch_one(
480 "SELECT * FROM update_state LIMIT 1",
481 named_params![],
482 |_| Ok(()),
483 )
484 .await
485 .unwrap();
486
487 if previous.is_some() {
488 transaction
489 .execute(
490 "UPDATE update_state SET qts = :qts",
491 named_params! {":qts": qts},
492 )
493 .await
494 .unwrap();
495 } else {
496 transaction
497 .execute(
498 "INSERT INTO update_state VALUES (0, :qts, 0, 0)",
499 named_params! {":qts": qts},
500 )
501 .await
502 .unwrap();
503 }
504 }
505 UpdateState::Channel { id, pts } => {
506 transaction
507 .execute(
508 "INSERT OR REPLACE INTO channel_state VALUES (:peer_id, :pts)",
509 named_params! {
510 ":peer_id": id,
511 ":pts": pts,
512 },
513 )
514 .await
515 .unwrap();
516 }
517 }
518
519 transaction.commit().await.unwrap();
520 })
521 }
522}
523
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    // Every other name used below (`SqliteSession`, `Session`, `DcOption`,
    // `PeerInfo`, `UpdateState`, `KNOWN_DC_OPTIONS`, ...) is already imported
    // by the parent module and arrives through this glob.
    use super::*;

    #[test]
    fn exercise_sqlite_session() {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(do_exercise_sqlite_session());
    }

    /// Round-trips every `Session` operation through an in-memory database.
    async fn do_exercise_sqlite_session() {
        let session = SqliteSession::open(":memory:").await.unwrap();

        assert_eq!(session.home_dc_id(), DEFAULT_DC);
        session.set_home_dc_id(DEFAULT_DC + 1).await;
        assert_eq!(session.home_dc_id(), DEFAULT_DC + 1);

        assert_eq!(
            session.dc_option(KNOWN_DC_OPTIONS[0].id),
            Some(KNOWN_DC_OPTIONS[0].clone())
        );
        let new_dc_option = DcOption {
            id: KNOWN_DC_OPTIONS
                .iter()
                .map(|dc_option| dc_option.id)
                .max()
                .unwrap()
                + 1,
            ipv4: SocketAddrV4::new(Ipv4Addr::from_bits(0), 1),
            ipv6: SocketAddrV6::new(Ipv6Addr::from_bits(0), 1, 0, 0),
            auth_key: Some([1; 256]),
        };
        assert_eq!(session.dc_option(new_dc_option.id), None);
        session.set_dc_option(&new_dc_option).await;
        assert_eq!(session.dc_option(new_dc_option.id), Some(new_dc_option));

        assert_eq!(session.peer(PeerId::self_user()).await, None);
        assert_eq!(session.peer(PeerId::user_unchecked(1)).await, None);
        let peer = PeerInfo::User {
            id: 1,
            auth: None,
            bot: Some(true),
            is_self: Some(true),
        };
        session.cache_peer(&peer).await;
        assert_eq!(session.peer(PeerId::self_user()).await, Some(peer.clone()));
        assert_eq!(session.peer(PeerId::user_unchecked(1)).await, Some(peer));

        assert_eq!(session.peer(PeerId::channel_unchecked(1)).await, None);
        let peer = PeerInfo::Channel {
            id: 1,
            auth: Some(PeerAuth::from_hash(-1)),
            kind: Some(ChannelKind::Broadcast),
        };
        session.cache_peer(&peer).await;
        assert_eq!(session.peer(PeerId::channel_unchecked(1)).await, Some(peer));

        assert_eq!(session.updates_state().await, UpdatesState::default());
        session
            .set_update_state(UpdateState::All(UpdatesState {
                pts: 1,
                qts: 2,
                date: 3,
                seq: 4,
                channels: vec![
                    ChannelState { id: 5, pts: 6 },
                    ChannelState { id: 7, pts: 8 },
                ],
            }))
            .await;
        session
            .set_update_state(UpdateState::Primary {
                pts: 2,
                date: 4,
                seq: 5,
            })
            .await;
        session
            .set_update_state(UpdateState::Secondary { qts: 3 })
            .await;
        session
            .set_update_state(UpdateState::Channel { id: 7, pts: 9 })
            .await;
        assert_eq!(
            session.updates_state().await,
            UpdatesState {
                pts: 2,
                qts: 3,
                date: 4,
                seq: 5,
                channels: vec![
                    ChannelState { id: 5, pts: 6 },
                    ChannelState { id: 7, pts: 9 },
                ],
            }
        );

        // `Secondary` and `Primary` must create the state row when none exists yet.
        let fresh = SqliteSession::open(":memory:").await.unwrap();
        fresh
            .set_update_state(UpdateState::Secondary { qts: 3 })
            .await;
        assert_eq!(
            fresh.updates_state().await,
            UpdatesState {
                pts: 0,
                qts: 3,
                date: 0,
                seq: 0,
                channels: Vec::new(),
            }
        );

        let fresh = SqliteSession::open(":memory:").await.unwrap();
        fresh
            .set_update_state(UpdateState::Primary {
                pts: 2,
                date: 4,
                seq: 5,
            })
            .await;
        assert_eq!(
            fresh.updates_state().await,
            UpdatesState {
                pts: 2,
                qts: 0,
                date: 4,
                seq: 5,
                channels: Vec::new(),
            }
        );
    }
}