1#![cfg_attr(docsrs, feature(doc_cfg))]
12use std::collections::HashMap;
73use std::io::{self, ErrorKind};
74use std::path::Path;
75
/// Serde adapter for `Option<[u8; 256]>`, used via `#[serde(with = "auth_key_serde")]`.
///
/// The key is written as an optional byte sequence and read back through a
/// `Vec<u8>` whose length must be exactly 256.
#[cfg(feature = "serde")]
mod auth_key_serde {
    use serde::{Deserialize, Deserializer, Serializer};

    /// Writes `Some(key)` as an optional byte slice and `None` as serde's "none".
    pub fn serialize<S>(value: &Option<[u8; 256]>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(key) = value {
            s.serialize_some(&key[..])
        } else {
            s.serialize_none()
        }
    }

    /// Reads an optional byte sequence and converts it into a fixed 256-byte array.
    ///
    /// # Errors
    /// Fails with a custom error when a key is present but is not exactly 256 bytes.
    pub fn deserialize<'de, D>(d: D) -> Result<Option<[u8; 256]>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw: Option<Vec<u8>> = Option::deserialize(d)?;
        raw.map(|bytes| {
            <[u8; 256]>::try_from(bytes).map_err(|_| {
                <D::Error as serde::de::Error>::custom("auth_key must be exactly 256 bytes")
            })
        })
        .transpose()
    }
}
106
/// Bit set of per-datacenter flags packed into a single byte.
///
/// Each associated constant occupies one bit; values can be combined with `|`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// No bits set.
    pub const NONE: DcFlags = DcFlags(0);
    /// Bit 0.
    pub const IPV6: DcFlags = DcFlags(0b0000_0001);
    /// Bit 1.
    pub const MEDIA_ONLY: DcFlags = DcFlags(0b0000_0010);
    /// Bit 2.
    pub const TCPO_ONLY: DcFlags = DcFlags(0b0000_0100);
    /// Bit 3.
    pub const CDN: DcFlags = DcFlags(0b0000_1000);
    /// Bit 4.
    pub const STATIC: DcFlags = DcFlags(0b0001_0000);

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// `NONE` is therefore contained in every value.
    pub fn contains(self, other: DcFlags) -> bool {
        let DcFlags(wanted) = other;
        // No wanted bit may be missing from `self`.
        (wanted & !self.0) == 0
    }

    /// Turns on every bit of `flag` in place; bits already set are left alone.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = Self;

    /// Union of the two bit sets.
    fn bitor(self, rhs: Self) -> Self {
        Self(self.0 | rhs.0)
    }
}
137
138#[derive(Clone, Debug)]
140#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
141pub struct DcEntry {
142 pub dc_id: i32,
143 pub addr: String,
144 #[cfg_attr(feature = "serde", serde(with = "auth_key_serde"))]
145 pub auth_key: Option<[u8; 256]>,
146 pub first_salt: i64,
147 pub time_offset: i32,
148 pub flags: DcFlags,
150}
151
152impl DcEntry {
153 #[inline]
155 pub fn is_ipv6(&self) -> bool {
156 self.flags.contains(DcFlags::IPV6)
157 }
158
159 pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
166 self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
167 io::Error::new(
168 io::ErrorKind::InvalidData,
169 format!("invalid DC address: {:?}", self.addr),
170 )
171 })
172 }
173
174 pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
187 let addr = if ip.contains(':') {
189 format!("[{ip}]:{port}")
190 } else {
191 format!("{ip}:{port}")
192 };
193 Self {
194 dc_id,
195 addr,
196 auth_key: None,
197 first_salt: 0,
198 time_offset: 0,
199 flags,
200 }
201 }
202}
203
/// Snapshot of the update counters kept for a session.
///
/// The global counters are stored as plain integers; per-channel counters are
/// kept as `(channel_id, pts)` pairs in insertion order.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Primary counter; a positive value marks the snapshot as initialised.
    pub pts: i32,
    /// Secondary counter.
    pub qts: i32,
    /// Date value stored with the state.
    pub date: i32,
    /// Sequence value stored with the state.
    pub seq: i32,
    /// Per-channel `(channel_id, pts)` pairs.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// `true` once `pts` is strictly positive.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        self.pts >= 1
    }

    /// Stores `pts` for `channel_id`, overwriting the first existing pair for
    /// that channel or appending a new pair when none exists.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self.channels.iter().position(|&(id, _)| id == channel_id) {
            Some(idx) => self.channels[idx].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Returns the stored `pts` for `channel_id`, or `0` when the channel is unknown.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        for &(id, pts) in &self.channels {
            if id == channel_id {
                return pts;
            }
        }
        0
    }
}
246
/// A cached peer: its id, access hash and kind.
///
/// In the binary format the kind is one byte: `2` when `is_chat` is set,
/// otherwise `1` when `is_channel` is set, otherwise `0`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Peer identifier.
    pub id: i64,
    /// Access hash stored for the peer.
    pub access_hash: i64,
    /// Set when the peer is a channel.
    pub is_channel: bool,
    /// Set when the peer is a chat; takes precedence over `is_channel` when serialised.
    pub is_chat: bool,
}
263
/// A cached "min" peer record tying a user to a peer and a message id.
///
/// NOTE(review): presumably identifies a user only reachable through a message
/// seen in `peer_id`; confirm against the code that consumes this cache.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedMinPeer {
    /// User identifier.
    pub user_id: i64,
    /// Peer identifier associated with the user.
    pub peer_id: i64,
    /// Message identifier associated with the user.
    pub msg_id: i32,
}
277
/// Everything that is persisted for a session: home DC, known DC endpoints,
/// update counters and the peer caches.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PersistedSession {
    /// Identifier of the home datacenter (`0` in a default session).
    pub home_dc_id: i32,
    /// Known datacenter entries; several entries may share one `dc_id`.
    pub dcs: Vec<DcEntry>,
    /// Update counters.
    pub updates_state: UpdatesStateSnap,
    /// Cached peers.
    pub peers: Vec<CachedPeer>,
    /// Cached min peers (only present in binary format version 4 and later).
    pub min_peers: Vec<CachedMinPeer>,
}
293
294impl PersistedSession {
295 pub fn to_bytes(&self) -> Vec<u8> {
297 let mut b = Vec::with_capacity(512);
298
299 b.push(0x05u8); b.extend_from_slice(&self.home_dc_id.to_le_bytes());
302
303 b.push(self.dcs.len() as u8);
304 for d in &self.dcs {
305 b.extend_from_slice(&d.dc_id.to_le_bytes());
306 match &d.auth_key {
307 Some(k) => {
308 b.push(1);
309 b.extend_from_slice(k);
310 }
311 None => {
312 b.push(0);
313 }
314 }
315 b.extend_from_slice(&d.first_salt.to_le_bytes());
316 b.extend_from_slice(&d.time_offset.to_le_bytes());
317 let ab = d.addr.as_bytes();
318 b.push(ab.len() as u8);
319 b.extend_from_slice(ab);
320 b.push(d.flags.0);
321 }
322
323 b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
324 b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
325 b.extend_from_slice(&self.updates_state.date.to_le_bytes());
326 b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
327 let ch = &self.updates_state.channels;
328 b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
329 for &(cid, cpts) in ch {
330 b.extend_from_slice(&cid.to_le_bytes());
331 b.extend_from_slice(&cpts.to_le_bytes());
332 }
333
334 b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
336 for p in &self.peers {
337 b.extend_from_slice(&p.id.to_le_bytes());
338 b.extend_from_slice(&p.access_hash.to_le_bytes());
339 let peer_type: u8 = if p.is_chat {
340 2
341 } else if p.is_channel {
342 1
343 } else {
344 0
345 };
346 b.push(peer_type);
347 }
348
349 b.extend_from_slice(&(self.min_peers.len() as u16).to_le_bytes());
350 for m in &self.min_peers {
351 b.extend_from_slice(&m.user_id.to_le_bytes());
352 b.extend_from_slice(&m.peer_id.to_le_bytes());
353 b.extend_from_slice(&m.msg_id.to_le_bytes());
354 }
355
356 b
357 }
358
359 pub fn save(&self, path: &Path) -> io::Result<()> {
366 use std::sync::atomic::{AtomicU64, Ordering};
367 static SEQ: AtomicU64 = AtomicU64::new(0);
368 let n = SEQ.fetch_add(1, Ordering::Relaxed);
369 let tmp = path.with_extension(format!("{n}.tmp"));
370 std::fs::write(&tmp, self.to_bytes())?;
371 std::fs::rename(&tmp, path).inspect_err(|_e| {
372 let _ = std::fs::remove_file(&tmp);
373 })
374 }
375
376 pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
378 if buf.is_empty() {
379 return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
380 }
381
382 let mut p = 0usize;
383
384 macro_rules! r {
385 ($n:expr) => {{
386 if p + $n > buf.len() {
387 return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
388 }
389 let s = &buf[p..p + $n];
390 p += $n;
391 s
392 }};
393 }
394 macro_rules! r_i32 {
395 () => {
396 i32::from_le_bytes(r!(4).try_into().unwrap())
397 };
398 }
399 macro_rules! r_i64 {
400 () => {
401 i64::from_le_bytes(r!(8).try_into().unwrap())
402 };
403 }
404 macro_rules! r_u8 {
405 () => {
406 r!(1)[0]
407 };
408 }
409 macro_rules! r_u16 {
410 () => {
411 u16::from_le_bytes(r!(2).try_into().unwrap())
412 };
413 }
414
415 let first_byte = r_u8!();
416
417 let (home_dc_id, version) = if first_byte == 0x05 {
418 (r_i32!(), 5u8)
419 } else if first_byte == 0x04 {
420 (r_i32!(), 4u8)
421 } else if first_byte == 0x03 {
422 (r_i32!(), 3u8)
423 } else if first_byte == 0x02 {
424 (r_i32!(), 2u8)
425 } else {
426 let rest = r!(3);
427 let mut bytes = [0u8; 4];
428 bytes[0] = first_byte;
429 bytes[1..4].copy_from_slice(rest);
430 (i32::from_le_bytes(bytes), 1u8)
431 };
432
433 let dc_count = r_u8!() as usize;
434 let mut dcs = Vec::with_capacity(dc_count);
435 for _ in 0..dc_count {
436 let dc_id = r_i32!();
437 let has_key = r_u8!();
438 let auth_key = if has_key == 1 {
439 let mut k = [0u8; 256];
440 k.copy_from_slice(r!(256));
441 Some(k)
442 } else {
443 None
444 };
445 let first_salt = r_i64!();
446 let time_offset = r_i32!();
447 let al = r_u8!() as usize;
448 let addr = String::from_utf8_lossy(r!(al)).into_owned();
449 let flags = if version >= 3 {
450 DcFlags(r_u8!())
451 } else {
452 DcFlags::NONE
453 };
454 dcs.push(DcEntry {
455 dc_id,
456 addr,
457 auth_key,
458 first_salt,
459 time_offset,
460 flags,
461 });
462 }
463
464 if version < 2 {
465 return Ok(Self {
466 home_dc_id,
467 dcs,
468 updates_state: UpdatesStateSnap::default(),
469 peers: Vec::new(),
470 min_peers: Vec::new(),
471 });
472 }
473
474 let pts = r_i32!();
475 let qts = r_i32!();
476 let date = r_i32!();
477 let seq = r_i32!();
478 let ch_count = r_u16!() as usize;
479 let mut channels = Vec::with_capacity(ch_count);
480 for _ in 0..ch_count {
481 let cid = r_i64!();
482 let cpts = r_i32!();
483 channels.push((cid, cpts));
484 }
485
486 let peer_count = r_u16!() as usize;
487 let mut peers = Vec::with_capacity(peer_count);
488 for _ in 0..peer_count {
489 let id = r_i64!();
490 let access_hash = r_i64!();
491 let peer_type = r_u8!();
493 let is_channel = peer_type == 1;
494 let is_chat = peer_type == 2;
495 peers.push(CachedPeer {
496 id,
497 access_hash,
498 is_channel,
499 is_chat,
500 });
501 }
502
503 let min_peers = if version >= 4 {
505 let count = r_u16!() as usize;
506 let mut v = Vec::with_capacity(count);
507 for _ in 0..count {
508 let user_id = r_i64!();
509 let peer_id = r_i64!();
510 let msg_id = r_i32!();
511 v.push(CachedMinPeer {
512 user_id,
513 peer_id,
514 msg_id,
515 });
516 }
517 v
518 } else {
519 Vec::new()
520 };
521
522 Ok(Self {
523 home_dc_id,
524 dcs,
525 updates_state: UpdatesStateSnap {
526 pts,
527 qts,
528 date,
529 seq,
530 channels,
531 },
532 peers,
533 min_peers,
534 })
535 }
536
537 pub fn from_string(s: &str) -> io::Result<Self> {
539 use base64::Engine as _;
540 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
541 .decode(s.trim())
542 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
543 Self::from_bytes(&bytes)
544 }
545
546 pub fn load(path: &Path) -> io::Result<Self> {
547 let buf = std::fs::read(path)?;
548 Self::from_bytes(&buf)
549 }
550
551 pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
563 let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
564 candidates.peek()?;
565 let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
567 cands
569 .iter()
570 .copied()
571 .find(|d| d.is_ipv6() == prefer_ipv6)
572 .or_else(|| cands.first().copied())
573 }
574
575 pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
580 self.dcs.iter().filter(move |d| d.dc_id == dc_id)
581 }
582}
583
584impl std::fmt::Display for PersistedSession {
585 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
586 use base64::Engine as _;
587 f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
588 }
589}
590
/// Returns the built-in `dc_id -> "ip:port"` table for datacenters 1 through 5.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const DEFAULTS: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(DEFAULTS.len());
    for (dc_id, addr) in DEFAULTS {
        table.insert(dc_id, addr.to_owned());
    }
    table
}
604
605use std::path::PathBuf;
610
611pub trait SessionBackend: Send + Sync {
619 fn save(&self, session: &PersistedSession) -> io::Result<()>;
620 fn load(&self) -> io::Result<Option<PersistedSession>>;
621 fn delete(&self) -> io::Result<()>;
622
623 fn name(&self) -> &str;
625
626 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
640 let mut s = self.load()?.unwrap_or_default();
641 if let Some(existing) = s.dcs.iter_mut().find(|d| d.dc_id == entry.dc_id) {
643 *existing = entry.clone();
644 } else {
645 s.dcs.push(entry.clone());
646 }
647 self.save(&s)
648 }
649
650 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
657 let mut s = self.load()?.unwrap_or_default();
658 s.home_dc_id = dc_id;
659 self.save(&s)
660 }
661
662 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
668 let mut s = self.load()?.unwrap_or_default();
669 update.apply_to(&mut s.updates_state);
670 self.save(&s)
671 }
672
673 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
680 let mut s = self.load()?.unwrap_or_default();
681 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
682 *existing = peer.clone();
683 } else {
684 s.peers.push(peer.clone());
685 }
686 self.save(&s)
687 }
688}
689
/// A partial or full change to an [`UpdatesStateSnap`], applied with
/// [`UpdateStateChange::apply_to`].
#[derive(Debug, Clone)]
pub enum UpdateStateChange {
    /// Replace the whole snapshot.
    All(UpdatesStateSnap),
    /// Overwrite `pts`, `date` and `seq`; `qts` and channels are untouched.
    Primary { pts: i32, date: i32, seq: i32 },
    /// Overwrite `qts` only.
    Secondary { qts: i32 },
    /// Set the pts of one channel, adding the channel if it is not tracked yet.
    Channel { id: i64, pts: i32 },
}
714
715impl UpdateStateChange {
716 pub fn apply_to(&self, snap: &mut UpdatesStateSnap) {
718 match self {
719 Self::All(new_snap) => *snap = new_snap.clone(),
720 Self::Primary { pts, date, seq } => {
721 snap.pts = *pts;
722 snap.date = *date;
723 snap.seq = *seq;
724 }
725 Self::Secondary { qts } => {
726 snap.qts = *qts;
727 }
728 Self::Channel { id, pts } => {
729 if let Some(existing) = snap.channels.iter_mut().find(|c| c.0 == *id) {
731 existing.1 = *pts;
732 } else {
733 snap.channels.push((*id, *pts));
734 }
735 }
736 }
737 }
738}
739
/// Backend that stores the session as a single binary file.
pub struct BinaryFileBackend {
    /// Location of the session file.
    path: PathBuf,
    /// Held while saving so saves through this backend do not overlap.
    write_lock: std::sync::Mutex<()>,
}

impl BinaryFileBackend {
    /// Creates a backend for the file at `path`; nothing is touched on disk.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path: PathBuf = path.into();
        let write_lock = std::sync::Mutex::default();
        Self { path, write_lock }
    }

    /// The session file location this backend was created with.
    pub fn path(&self) -> &std::path::Path {
        self.path.as_path()
    }
}
764
765impl SessionBackend for BinaryFileBackend {
766 fn save(&self, session: &PersistedSession) -> io::Result<()> {
767 let _guard = self.write_lock.lock().unwrap();
768 session.save(&self.path)
769 }
770
771 fn load(&self) -> io::Result<Option<PersistedSession>> {
772 if !self.path.exists() {
773 return Ok(None);
774 }
775 match PersistedSession::load(&self.path) {
776 Ok(s) => Ok(Some(s)),
777 Err(e) => {
778 let bak = self.path.with_extension("bak");
779 tracing::warn!(
780 "[ferogram] Session file {:?} is corrupt ({e}); \
781 renaming to {:?} and starting fresh",
782 self.path,
783 bak
784 );
785 let _ = std::fs::rename(&self.path, &bak);
786 Ok(None)
787 }
788 }
789 }
790
791 fn delete(&self) -> io::Result<()> {
792 if self.path.exists() {
793 std::fs::remove_file(&self.path)?;
794 }
795 Ok(())
796 }
797
798 fn name(&self) -> &str {
799 "binary-file"
800 }
801
802 }
805
806#[derive(Default)]
814pub struct InMemoryBackend {
815 data: std::sync::Mutex<Option<PersistedSession>>,
816}
817
818impl InMemoryBackend {
819 pub fn new() -> Self {
820 Self::default()
821 }
822
823 pub fn snapshot(&self) -> Option<PersistedSession> {
825 self.data.lock().unwrap().clone()
826 }
827}
828
829impl SessionBackend for InMemoryBackend {
830 fn save(&self, s: &PersistedSession) -> io::Result<()> {
831 *self.data.lock().unwrap() = Some(s.clone());
832 Ok(())
833 }
834
835 fn load(&self) -> io::Result<Option<PersistedSession>> {
836 Ok(self.data.lock().unwrap().clone())
837 }
838
839 fn delete(&self) -> io::Result<()> {
840 *self.data.lock().unwrap() = None;
841 Ok(())
842 }
843
844 fn name(&self) -> &str {
845 "in-memory"
846 }
847
848 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
851 let mut guard = self.data.lock().unwrap();
852 let s = guard.get_or_insert_with(PersistedSession::default);
853 if let Some(existing) = s.dcs.iter_mut().find(|d| d.dc_id == entry.dc_id) {
854 *existing = entry.clone();
855 } else {
856 s.dcs.push(entry.clone());
857 }
858 Ok(())
859 }
860
861 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
862 let mut guard = self.data.lock().unwrap();
863 let s = guard.get_or_insert_with(PersistedSession::default);
864 s.home_dc_id = dc_id;
865 Ok(())
866 }
867
868 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
869 let mut guard = self.data.lock().unwrap();
870 let s = guard.get_or_insert_with(PersistedSession::default);
871 update.apply_to(&mut s.updates_state);
872 Ok(())
873 }
874
875 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
876 let mut guard = self.data.lock().unwrap();
877 let s = guard.get_or_insert_with(PersistedSession::default);
878 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
879 *existing = peer.clone();
880 } else {
881 s.peers.push(peer.clone());
882 }
883 Ok(())
884 }
885}
886
/// Backend that keeps the session as its encoded text form in memory.
pub struct StringSessionBackend {
    /// Current encoded session text; empty means "no session".
    data: std::sync::Mutex<String>,
}

impl StringSessionBackend {
    /// Creates a backend seeded with the given session text (may be empty).
    pub fn new(s: impl Into<String>) -> Self {
        let initial: String = s.into();
        Self {
            data: std::sync::Mutex::new(initial),
        }
    }

    /// Returns a copy of the currently held session text.
    pub fn current(&self) -> String {
        let guard = self.data.lock().unwrap();
        String::clone(&guard)
    }
}
905
906impl SessionBackend for StringSessionBackend {
907 fn save(&self, session: &PersistedSession) -> io::Result<()> {
908 *self.data.lock().unwrap() = session.to_string();
909 Ok(())
910 }
911
912 fn load(&self) -> io::Result<Option<PersistedSession>> {
913 let s = self.data.lock().unwrap().clone();
914 if s.trim().is_empty() {
915 return Ok(None);
916 }
917 PersistedSession::from_string(&s).map(Some)
918 }
919
920 fn delete(&self) -> io::Result<()> {
921 *self.data.lock().unwrap() = String::new();
922 Ok(())
923 }
924
925 fn name(&self) -> &str {
926 "string-session"
927 }
928}
929
// Unit tests for the in-memory backend, update-state changes, DC entry
// helpers and the binary round trip.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an IPv4-style entry (no flags) whose address embeds the DC id.
    fn make_dc(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("1.2.3.{id}:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::NONE,
        }
    }

    // Builds a peer that is neither a channel nor a chat.
    fn make_peer(id: i64, hash: i64) -> CachedPeer {
        CachedPeer {
            id,
            access_hash: hash,
            is_channel: false,
            is_chat: false,
        }
    }

    // A fresh in-memory backend holds no session.
    #[test]
    fn inmemory_load_returns_none_when_empty() {
        let b = InMemoryBackend::new();
        assert!(b.load().unwrap().is_none());
    }

    // What is saved can be loaded back.
    #[test]
    fn inmemory_save_then_load_round_trips() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 3;
        s.dcs.push(make_dc(3));
        b.save(&s).unwrap();

        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.home_dc_id, 3);
        assert_eq!(loaded.dcs.len(), 1);
    }

    // delete() returns the backend to the empty state.
    #[test]
    fn inmemory_delete_clears_state() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        b.save(&s).unwrap();
        b.delete().unwrap();
        assert!(b.load().unwrap().is_none());
    }

    // update_dc on an empty backend creates a session containing the entry.
    #[test]
    fn inmemory_update_dc_inserts_new() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(4)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].dc_id, 4);
    }

    // update_dc with the same dc_id and flags overwrites the stored entry.
    #[test]
    fn inmemory_update_dc_replaces_existing() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        let mut updated = make_dc(2);
        updated.addr = "9.9.9.9:443".to_string();
        b.update_dc(&updated).unwrap();

        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].addr, "9.9.9.9:443");
    }

    // set_home_dc on an empty backend creates a session with that home DC.
    #[test]
    fn inmemory_set_home_dc() {
        let b = InMemoryBackend::new();
        b.set_home_dc(5).unwrap();
        assert_eq!(b.snapshot().unwrap().home_dc_id, 5);
    }

    // cache_peer adds a previously unknown peer.
    #[test]
    fn inmemory_cache_peer_inserts() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 0xdeadbeef)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].id, 100);
    }

    // cache_peer with a known id replaces the cached record.
    #[test]
    fn inmemory_cache_peer_updates_existing() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 111)).unwrap();
        b.cache_peer(&make_peer(100, 222)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].access_hash, 222);
    }

    // Primary overwrites pts/date/seq and leaves qts alone.
    #[test]
    fn update_state_primary() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Primary {
            pts: 10,
            date: 20,
            seq: 30,
        }
        .apply_to(&mut snap);
        assert_eq!(snap.pts, 10);
        assert_eq!(snap.date, 20);
        assert_eq!(snap.seq, 30);
        assert_eq!(snap.qts, 0);
    }

    // Secondary overwrites qts and leaves pts alone.
    #[test]
    fn update_state_secondary() {
        let mut snap = UpdatesStateSnap {
            pts: 5,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Secondary { qts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.qts, 99);
        assert_eq!(snap.pts, 5);
    }

    // Channel appends a pair for an untracked channel.
    #[test]
    fn update_state_channel_inserts() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Channel { id: 12345, pts: 42 }.apply_to(&mut snap);
        assert_eq!(snap.channels, vec![(12345, 42)]);
    }

    // Channel updates the matching pair and leaves other channels untouched.
    #[test]
    fn update_state_channel_updates_existing() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![(12345, 10), (67890, 5)],
        };
        UpdateStateChange::Channel { id: 12345, pts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.channels[0], (12345, 99));
        assert_eq!(snap.channels[1], (67890, 5));
    }

    // apply_update_state on an empty backend creates a session carrying the change.
    #[test]
    fn apply_update_state_via_backend() {
        let b = InMemoryBackend::new();
        b.apply_update_state(UpdateStateChange::Primary {
            pts: 7,
            date: 8,
            seq: 9,
        })
        .unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.updates_state.pts, 7);
    }

    // update_dc works through a `dyn SessionBackend`.
    // NOTE(review): InMemoryBackend overrides update_dc, so despite the name
    // this exercises the override via dynamic dispatch, not the trait default.
    #[test]
    fn default_update_dc_via_trait_object() {
        let b: Box<dyn SessionBackend> = Box::new(InMemoryBackend::new());
        b.update_dc(&make_dc(1)).unwrap();
        b.update_dc(&make_dc(2)).unwrap();
        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.dcs.len(), 2);
    }

    // Builds an IPv6 entry (IPV6 flag set) whose address embeds the DC id.
    fn make_dc_v6(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("[2001:b28:f23d:f00{}::a]:443", id),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::IPV6,
        }
    }

    // from_parts leaves an IPv4 address unbracketed and parseable.
    #[test]
    fn dc_entry_from_parts_ipv4() {
        let dc = DcEntry::from_parts(1, "149.154.175.53", 443, DcFlags::NONE);
        assert_eq!(dc.addr, "149.154.175.53:443");
        assert!(!dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    // from_parts brackets an IPv6 address so it parses as a socket address.
    #[test]
    fn dc_entry_from_parts_ipv6() {
        let dc = DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6);
        assert_eq!(dc.addr, "[2001:b28:f23d:f001::a]:443");
        assert!(dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    // dc_for returns the entry of the requested family when both exist.
    #[test]
    fn persisted_session_dc_for_prefers_ipv6() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(2));
        s.dcs.push(make_dc_v6(2));
        let v6 = s.dc_for(2, true).unwrap();
        assert!(v6.is_ipv6());

        let v4 = s.dc_for(2, false).unwrap();
        assert!(!v4.is_ipv6());
    }

    // dc_for falls back to the only available entry when the preferred family is absent.
    #[test]
    fn persisted_session_dc_for_falls_back_when_only_ipv4() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(3));
        let dc = s.dc_for(3, true).unwrap();
        assert!(!dc.is_ipv6());
    }

    // all_dcs_for yields every entry for the DC and nothing for unknown DCs.
    #[test]
    fn persisted_session_all_dcs_for_returns_both() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(1));
        s.dcs.push(make_dc_v6(1));
        s.dcs.push(make_dc(2));

        assert_eq!(s.all_dcs_for(1).count(), 2);
        assert_eq!(s.all_dcs_for(2).count(), 1);
        assert_eq!(s.all_dcs_for(5).count(), 0);
    }

    // Asserts that update_dc keeps separate IPv4 and IPv6 entries for one dc_id,
    // i.e. that entries are not identified by dc_id alone.
    #[test]
    fn inmemory_ipv4_and_ipv6_coexist() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        b.update_dc(&make_dc_v6(2)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.iter().filter(|d| d.dc_id == 2).count(), 2);
    }

    // Flags survive a to_bytes / from_bytes round trip.
    #[test]
    fn binary_roundtrip_ipv4_and_ipv6() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        s.dcs.push(make_dc(2));
        s.dcs.push(make_dc_v6(2));

        let bytes = s.to_bytes();
        let loaded = PersistedSession::from_bytes(&bytes).unwrap();
        assert_eq!(loaded.dcs.len(), 2);
        assert_eq!(loaded.dcs.iter().filter(|d| d.is_ipv6()).count(), 1);
        assert_eq!(loaded.dcs.iter().filter(|d| !d.is_ipv6()).count(), 1);
    }
}
1218
/// Session backend stored in a SQLite database (feature `sqlite-session`).
#[cfg(feature = "sqlite-session")]
pub struct SqliteBackend {
    /// The single connection; every operation locks it for its duration.
    conn: std::sync::Mutex<rusqlite::Connection>,
    /// Text returned by `name()`: the database path, or `:memory:`.
    label: String,
}
1248
1249#[cfg(feature = "sqlite-session")]
1250impl SqliteBackend {
    /// Pragmas and `CREATE TABLE IF NOT EXISTS` statements run on every open.
    ///
    /// Tables: `meta` (key / integer value), `dcs` keyed by `(dc_id, flags)`,
    /// a single-row `update_state` (`id` is forced to 1), `channel_pts` and
    /// `peers`.
    ///
    /// NOTE(review): `peers` has no column for `CachedPeer::is_chat`, and there
    /// is no table for min peers, so neither is stored by this schema.
    const SCHEMA: &'static str = "
 PRAGMA journal_mode = WAL;
 PRAGMA synchronous = NORMAL;

 CREATE TABLE IF NOT EXISTS meta (
 key TEXT PRIMARY KEY,
 value INTEGER NOT NULL DEFAULT 0
 );

 CREATE TABLE IF NOT EXISTS dcs (
 dc_id INTEGER NOT NULL,
 flags INTEGER NOT NULL DEFAULT 0,
 addr TEXT NOT NULL,
 auth_key BLOB,
 first_salt INTEGER NOT NULL DEFAULT 0,
 time_offset INTEGER NOT NULL DEFAULT 0,
 PRIMARY KEY (dc_id, flags)
 );

 CREATE TABLE IF NOT EXISTS update_state (
 id INTEGER PRIMARY KEY CHECK (id = 1),
 pts INTEGER NOT NULL DEFAULT 0,
 qts INTEGER NOT NULL DEFAULT 0,
 date INTEGER NOT NULL DEFAULT 0,
 seq INTEGER NOT NULL DEFAULT 0
 );

 CREATE TABLE IF NOT EXISTS channel_pts (
 channel_id INTEGER PRIMARY KEY,
 pts INTEGER NOT NULL
 );

 CREATE TABLE IF NOT EXISTS peers (
 id INTEGER PRIMARY KEY,
 access_hash INTEGER NOT NULL,
 is_channel INTEGER NOT NULL DEFAULT 0
 );
 ";
1289
1290 pub fn open(path: impl Into<PathBuf>) -> io::Result<Self> {
1292 let path = path.into();
1293 let label = path.display().to_string();
1294 let conn = rusqlite::Connection::open(&path).map_err(io::Error::other)?;
1295 conn.execute_batch(Self::SCHEMA).map_err(io::Error::other)?;
1296 Ok(Self {
1297 conn: std::sync::Mutex::new(conn),
1298 label,
1299 })
1300 }
1301
1302 pub fn in_memory() -> io::Result<Self> {
1304 let conn = rusqlite::Connection::open_in_memory().map_err(io::Error::other)?;
1305 conn.execute_batch(Self::SCHEMA).map_err(io::Error::other)?;
1306 Ok(Self {
1307 conn: std::sync::Mutex::new(conn),
1308 label: ":memory:".into(),
1309 })
1310 }
1311
1312 fn map_err(e: rusqlite::Error) -> io::Error {
1313 io::Error::other(e)
1314 }
1315
    /// Reads the whole session from the database.
    ///
    /// Reading is lenient: a missing or unreadable `home_dc_id` becomes `0`, a
    /// missing or unreadable `update_state` row becomes the default state, and
    /// individual rows that fail to decode are skipped. Only statement
    /// preparation and query start-up errors are returned.
    ///
    /// `is_chat` is always `false` and `min_peers` is always empty, because
    /// nothing here reads them from the database.
    fn read_session(conn: &rusqlite::Connection) -> io::Result<PersistedSession> {
        // Any failure (including "no such row") falls back to 0.
        let home_dc_id: i32 = conn
            .query_row("SELECT value FROM meta WHERE key = 'home_dc_id'", [], |r| {
                r.get(0)
            })
            .unwrap_or(0);

        let mut stmt = conn
            .prepare("SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs")
            .map_err(Self::map_err)?;
        let dcs = stmt
            .query_map([], |row| {
                let dc_id: i32 = row.get(0)?;
                let flags_raw: u8 = row.get(1)?;
                let addr: String = row.get(2)?;
                let key_blob: Option<Vec<u8>> = row.get(3)?;
                let first_salt: i64 = row.get(4)?;
                let time_offset: i32 = row.get(5)?;
                Ok((dc_id, addr, key_blob, first_salt, time_offset, flags_raw))
            })
            .map_err(Self::map_err)?
            // Rows that fail to decode are dropped silently.
            .filter_map(|r| r.ok())
            .map(
                |(dc_id, addr, key_blob, first_salt, time_offset, flags_raw)| {
                    // A key blob of any length other than 256 is treated as "no key".
                    let auth_key = key_blob.and_then(|b| {
                        if b.len() == 256 {
                            let mut k = [0u8; 256];
                            k.copy_from_slice(&b);
                            Some(k)
                        } else {
                            None
                        }
                    });
                    DcEntry {
                        dc_id,
                        addr,
                        auth_key,
                        first_salt,
                        time_offset,
                        flags: DcFlags(flags_raw),
                    }
                },
            )
            .collect();

        // Global counters; channels are filled in from their own table below.
        let updates_state = conn
            .query_row(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                [],
                |r| {
                    Ok(UpdatesStateSnap {
                        pts: r.get(0)?,
                        qts: r.get(1)?,
                        date: r.get(2)?,
                        seq: r.get(3)?,
                        channels: vec![],
                    })
                },
            )
            .unwrap_or_default();

        let mut ch_stmt = conn
            .prepare("SELECT channel_id, pts FROM channel_pts")
            .map_err(Self::map_err)?;
        let channels: Vec<(i64, i32)> = ch_stmt
            .query_map([], |r| Ok((r.get::<_, i64>(0)?, r.get::<_, i32>(1)?)))
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        let mut peer_stmt = conn
            .prepare("SELECT id, access_hash, is_channel FROM peers")
            .map_err(Self::map_err)?;
        let peers: Vec<CachedPeer> = peer_stmt
            .query_map([], |r| {
                Ok(CachedPeer {
                    id: r.get(0)?,
                    access_hash: r.get(1)?,
                    is_channel: r.get::<_, i32>(2)? != 0,
                    // Not stored in the peers table.
                    is_chat: false,
                })
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            // Not stored by this backend.
            min_peers: Vec::new(),
        })
    }
1419
1420 fn write_session(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
1422 conn.execute_batch("BEGIN IMMEDIATE")
1423 .map_err(Self::map_err)?;
1424
1425 conn.execute(
1426 "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
1427 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
1428 rusqlite::params![s.home_dc_id],
1429 )
1430 .map_err(Self::map_err)?;
1431
1432 conn.execute("DELETE FROM dcs", []).map_err(Self::map_err)?;
1434 for d in &s.dcs {
1435 conn.execute(
1436 "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
1437 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
1438 rusqlite::params![
1439 d.dc_id,
1440 d.flags.0,
1441 d.addr,
1442 d.auth_key.as_ref().map(|k| k.as_ref()),
1443 d.first_salt,
1444 d.time_offset,
1445 ],
1446 )
1447 .map_err(Self::map_err)?;
1448 }
1449
1450 let us = &s.updates_state;
1454 conn.execute(
1455 "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1, ?1, ?2, ?3, ?4)
1456 ON CONFLICT(id) DO UPDATE SET
1457 pts = MAX(excluded.pts, update_state.pts),
1458 qts = MAX(excluded.qts, update_state.qts),
1459 date = excluded.date,
1460 seq = excluded.seq",
1461 rusqlite::params![us.pts, us.qts, us.date, us.seq],
1462 )
1463 .map_err(Self::map_err)?;
1464
1465 conn.execute("DELETE FROM channel_pts", [])
1466 .map_err(Self::map_err)?;
1467 for &(cid, cpts) in &us.channels {
1468 conn.execute(
1469 "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
1470 rusqlite::params![cid, cpts],
1471 )
1472 .map_err(Self::map_err)?;
1473 }
1474
1475 conn.execute("DELETE FROM peers", [])
1477 .map_err(Self::map_err)?;
1478 for p in &s.peers {
1479 conn.execute(
1480 "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1, ?2, ?3)",
1481 rusqlite::params![p.id, p.access_hash, p.is_channel as i32],
1482 )
1483 .map_err(Self::map_err)?;
1484 }
1485
1486 conn.execute_batch("COMMIT").map_err(Self::map_err)
1487 }
1488}
1489
1490#[cfg(feature = "sqlite-session")]
1491impl SessionBackend for SqliteBackend {
1492 fn save(&self, session: &PersistedSession) -> io::Result<()> {
1493 let conn = self.conn.lock().unwrap();
1494 Self::write_session(&conn, session)
1495 }
1496
1497 fn load(&self) -> io::Result<Option<PersistedSession>> {
1498 let conn = self.conn.lock().unwrap();
1499 let count: i64 = conn
1501 .query_row("SELECT COUNT(*) FROM meta", [], |r| r.get(0))
1502 .map_err(Self::map_err)?;
1503 if count == 0 {
1504 return Ok(None);
1505 }
1506 Self::read_session(&conn).map(Some)
1507 }
1508
1509 fn delete(&self) -> io::Result<()> {
1510 let conn = self.conn.lock().unwrap();
1511 conn.execute_batch(
1512 "BEGIN IMMEDIATE;
1513 DELETE FROM meta;
1514 DELETE FROM dcs;
1515 DELETE FROM update_state;
1516 DELETE FROM channel_pts;
1517 DELETE FROM peers;
1518 COMMIT;",
1519 )
1520 .map_err(Self::map_err)
1521 }
1522
1523 fn name(&self) -> &str {
1524 &self.label
1525 }
1526
1527 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
1530 let conn = self.conn.lock().unwrap();
1531 conn.execute(
1532 "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
1533 VALUES (?1, ?6, ?2, ?3, ?4, ?5)
1534 ON CONFLICT(dc_id, flags) DO UPDATE SET
1535 addr = excluded.addr,
1536 auth_key = excluded.auth_key,
1537 first_salt = excluded.first_salt,
1538 time_offset = excluded.time_offset",
1539 rusqlite::params![
1540 entry.dc_id,
1541 entry.addr,
1542 entry.auth_key.as_ref().map(|k| k.as_ref()),
1543 entry.first_salt,
1544 entry.time_offset,
1545 entry.flags.0,
1546 ],
1547 )
1548 .map(|_| ())
1549 .map_err(Self::map_err)
1550 }
1551
1552 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
1553 let conn = self.conn.lock().unwrap();
1554 conn.execute(
1555 "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
1556 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
1557 rusqlite::params![dc_id],
1558 )
1559 .map(|_| ())
1560 .map_err(Self::map_err)
1561 }
1562
1563 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
1564 let conn = self.conn.lock().unwrap();
1565 match update {
1566 UpdateStateChange::All(snap) => {
1567 conn.execute(
1568 "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
1569 ON CONFLICT(id) DO UPDATE SET
1570 pts=excluded.pts, qts=excluded.qts,
1571 date=excluded.date, seq=excluded.seq",
1572 rusqlite::params![snap.pts, snap.qts, snap.date, snap.seq],
1573 )
1574 .map_err(Self::map_err)?;
1575 conn.execute("DELETE FROM channel_pts", [])
1576 .map_err(Self::map_err)?;
1577 for &(cid, cpts) in &snap.channels {
1578 conn.execute(
1579 "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
1580 rusqlite::params![cid, cpts],
1581 )
1582 .map_err(Self::map_err)?;
1583 }
1584 Ok(())
1585 }
1586 UpdateStateChange::Primary { pts, date, seq } => conn
1587 .execute(
1588 "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,0,?2,?3)
1589 ON CONFLICT(id) DO UPDATE SET pts=excluded.pts, date=excluded.date,
1590 seq=excluded.seq",
1591 rusqlite::params![pts, date, seq],
1592 )
1593 .map(|_| ())
1594 .map_err(Self::map_err),
1595 UpdateStateChange::Secondary { qts } => conn
1596 .execute(
1597 "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,0,?1,0,0)
1598 ON CONFLICT(id) DO UPDATE SET qts = excluded.qts",
1599 rusqlite::params![qts],
1600 )
1601 .map(|_| ())
1602 .map_err(Self::map_err),
1603 UpdateStateChange::Channel { id, pts } => conn
1604 .execute(
1605 "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)
1606 ON CONFLICT(channel_id) DO UPDATE SET pts = excluded.pts",
1607 rusqlite::params![id, pts],
1608 )
1609 .map(|_| ())
1610 .map_err(Self::map_err),
1611 }
1612 }
1613
1614 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
1615 let conn = self.conn.lock().unwrap();
1616 conn.execute(
1617 "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1, ?2, ?3)
1618 ON CONFLICT(id) DO UPDATE SET
1619 access_hash = excluded.access_hash,
1620 is_channel = excluded.is_channel",
1621 rusqlite::params![peer.id, peer.access_hash, peer.is_channel as i32],
1622 )
1623 .map(|_| ())
1624 .map_err(Self::map_err)
1625 }
1626}
1627
1628#[cfg(feature = "libsql-session")]
1648pub struct LibSqlBackend {
1649 conn: libsql::Connection,
1650 label: String,
1651}
1652
1653#[cfg(feature = "libsql-session")]
1654impl LibSqlBackend {
    /// DDL batch run by `apply_schema` from every constructor.
    ///
    /// All statements are `CREATE TABLE IF NOT EXISTS`, so running the batch
    /// against a database that already has the tables does not fail.
    ///
    /// Tables:
    /// * `meta`: integer values by text key (`home_dc_id` is stored here).
    /// * `dcs`: one row per `(dc_id, flags)` pair; `auth_key` is a nullable BLOB.
    /// * `update_state`: at most one row, enforced by `CHECK (id = 1)`.
    /// * `channel_pts`: `pts` per channel id.
    /// * `peers`: access hash and channel flag per peer id.
    const SCHEMA: &'static str = "
        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );
        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );
        CREATE TABLE IF NOT EXISTS peers (
            id          INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel  INTEGER NOT NULL DEFAULT 0
        );
    ";
1686
1687 fn block<F, T>(fut: F) -> io::Result<T>
1688 where
1689 F: std::future::Future<Output = Result<T, libsql::Error>>,
1690 {
1691 tokio::runtime::Handle::current()
1692 .block_on(fut)
1693 .map_err(io::Error::other)
1694 }
1695
1696 async fn apply_schema(conn: &libsql::Connection) -> Result<(), libsql::Error> {
1697 conn.execute_batch(Self::SCHEMA).await
1698 }
1699
1700 pub fn open_local(path: impl Into<PathBuf>) -> io::Result<Self> {
1702 let path = path.into();
1703 let label = path.display().to_string();
1704 let db = Self::block(async { libsql::Builder::new_local(path).build().await })?;
1705 let conn = Self::block(async { db.connect() }).map_err(io::Error::other)?;
1706 Self::block(Self::apply_schema(&conn))?;
1707 Ok(Self {
1708 conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
1709 label,
1710 })
1711 }
1712
1713 pub fn in_memory() -> io::Result<Self> {
1715 let db = Self::block(async { libsql::Builder::new_local(":memory:").build().await })?;
1716 let conn = Self::block(async { db.connect() }).map_err(io::Error::other)?;
1717 Self::block(Self::apply_schema(&conn))?;
1718 Ok(Self {
1719 conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
1720 label: ":memory:".into(),
1721 })
1722 }
1723
1724 pub fn open_remote(url: impl Into<String>, auth_token: impl Into<String>) -> io::Result<Self> {
1726 let url = url.into();
1727 let label = url.clone();
1728 let db = Self::block(async {
1729 libsql::Builder::new_remote(url, auth_token.into())
1730 .build()
1731 .await
1732 })?;
1733 let conn = Self::block(async { db.connect() }).map_err(io::Error::other)?;
1734 Self::block(Self::apply_schema(&conn))?;
1735 Ok(Self {
1736 conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
1737 label,
1738 })
1739 }
1740
1741 pub fn open_replica(
1743 path: impl Into<PathBuf>,
1744 url: impl Into<String>,
1745 auth_token: impl Into<String>,
1746 ) -> io::Result<Self> {
1747 let path = path.into();
1748 let label = format!("{} (replica of {})", path.display(), url.into());
1749 let db = Self::block(async {
1750 libsql::Builder::new_remote_replica(path, url.into(), auth_token.into())
1751 .build()
1752 .await
1753 })?;
1754 let conn = Self::block(async { db.connect() }).map_err(io::Error::other)?;
1755 Self::block(Self::apply_schema(&conn))?;
1756 Ok(Self {
1757 conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
1758 label,
1759 })
1760 }
1761
1762 async fn read_session_async(
1763 conn: &libsql::Connection,
1764 ) -> Result<PersistedSession, libsql::Error> {
1765 use libsql::de;
1766
1767 let home_dc_id: i32 = conn
1769 .query("SELECT value FROM meta WHERE key = 'home_dc_id'", ())
1770 .await?
1771 .next()
1772 .await?
1773 .map(|r| r.get::<i32>(0))
1774 .transpose()?
1775 .unwrap_or(0);
1776
1777 let mut rows = conn
1779 .query(
1780 "SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs",
1781 (),
1782 )
1783 .await?;
1784 let mut dcs = Vec::new();
1785 while let Some(row) = rows.next().await? {
1786 let dc_id: i32 = row.get(0)?;
1787 let flags_raw: u8 = row.get::<i64>(1)? as u8;
1788 let addr: String = row.get(2)?;
1789 let key_blob: Option<Vec<u8>> = row.get(3)?;
1790 let first_salt: i64 = row.get(4)?;
1791 let time_offset: i32 = row.get(5)?;
1792 let auth_key = match key_blob {
1793 Some(b) if b.len() == 256 => {
1794 let mut k = [0u8; 256];
1795 k.copy_from_slice(&b);
1796 Some(k)
1797 }
1798 Some(b) => {
1799 return Err(libsql::Error::Misuse(format!(
1800 "auth_key blob must be 256 bytes, got {}",
1801 b.len()
1802 )));
1803 }
1804 None => None,
1805 };
1806 dcs.push(DcEntry {
1807 dc_id,
1808 addr,
1809 auth_key,
1810 first_salt,
1811 time_offset,
1812 flags: DcFlags(flags_raw),
1813 });
1814 }
1815
1816 let mut us_row = conn
1818 .query(
1819 "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
1820 (),
1821 )
1822 .await?;
1823 let updates_state = if let Some(r) = us_row.next().await? {
1824 UpdatesStateSnap {
1825 pts: r.get(0)?,
1826 qts: r.get(1)?,
1827 date: r.get(2)?,
1828 seq: r.get(3)?,
1829 channels: vec![],
1830 }
1831 } else {
1832 UpdatesStateSnap::default()
1833 };
1834
1835 let mut ch_rows = conn
1837 .query("SELECT channel_id, pts FROM channel_pts", ())
1838 .await?;
1839 let mut channels = Vec::new();
1840 while let Some(r) = ch_rows.next().await? {
1841 channels.push((r.get::<i64>(0)?, r.get::<i32>(1)?));
1842 }
1843
1844 let mut peer_rows = conn
1846 .query("SELECT id, access_hash, is_channel FROM peers", ())
1847 .await?;
1848 let mut peers = Vec::new();
1849 while let Some(r) = peer_rows.next().await? {
1850 peers.push(CachedPeer {
1851 id: r.get(0)?,
1852 access_hash: r.get(1)?,
1853 is_channel: r.get::<i32>(2)? != 0,
1854 is_chat: false,
1855 });
1856 }
1857
1858 Ok(PersistedSession {
1859 home_dc_id,
1860 dcs,
1861 updates_state: UpdatesStateSnap {
1862 channels,
1863 ..updates_state
1864 },
1865 peers,
1866 min_peers: Vec::new(),
1867 })
1868 }
1869
1870 async fn write_session_async(
1871 conn: &libsql::Connection,
1872 s: &PersistedSession,
1873 ) -> Result<(), libsql::Error> {
1874 conn.execute_batch("BEGIN IMMEDIATE").await?;
1875
1876 conn.execute(
1877 "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
1878 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
1879 libsql::params![s.home_dc_id],
1880 )
1881 .await?;
1882
1883 conn.execute("DELETE FROM dcs", ()).await?;
1884 for d in &s.dcs {
1885 conn.execute(
1886 "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
1887 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
1888 libsql::params![
1889 d.dc_id,
1890 d.flags.0 as i64,
1891 d.addr.clone(),
1892 d.auth_key.map(|k| k.to_vec()),
1893 d.first_salt,
1894 d.time_offset,
1895 ],
1896 )
1897 .await?;
1898 }
1899
1900 let us = &s.updates_state;
1901 conn.execute(
1902 "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
1903 ON CONFLICT(id) DO UPDATE SET
1904 pts = MAX(excluded.pts, update_state.pts),
1905 qts = MAX(excluded.qts, update_state.qts),
1906 date = excluded.date,
1907 seq = excluded.seq",
1908 libsql::params![us.pts, us.qts, us.date, us.seq],
1909 )
1910 .await?;
1911
1912 conn.execute("DELETE FROM channel_pts", ()).await?;
1913 for &(cid, cpts) in &us.channels {
1914 conn.execute(
1915 "INSERT INTO channel_pts (channel_id, pts) VALUES (?1,?2)",
1916 libsql::params![cid, cpts],
1917 )
1918 .await?;
1919 }
1920
1921 conn.execute("DELETE FROM peers", ()).await?;
1922 for p in &s.peers {
1923 conn.execute(
1924 "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1,?2,?3)",
1925 libsql::params![p.id, p.access_hash, p.is_channel as i32],
1926 )
1927 .await?;
1928 }
1929
1930 conn.execute_batch("COMMIT").await
1931 }
1932}
1933
1934#[cfg(feature = "libsql-session")]
1935impl SessionBackend for LibSqlBackend {
1936 fn save(&self, session: &PersistedSession) -> io::Result<()> {
1937 let conn = self.conn.clone();
1938 let session = session.clone();
1939 Self::block(async move {
1940 let conn = conn.lock().await;
1941 Self::write_session_async(&conn, session).await
1942 })
1943 }
1944
1945 fn load(&self) -> io::Result<Option<PersistedSession>> {
1946 let conn = self.conn.clone();
1947 let count: i64 = Self::block(async move {
1948 let conn = conn.lock().await;
1949 let mut rows = conn.query("SELECT COUNT(*) FROM meta", ()).await?;
1950 Ok::<i64, libsql::Error>(rows.next().await?.and_then(|r| r.get(0).ok()).unwrap_or(0))
1951 })?;
1952 if count == 0 {
1953 return Ok(None);
1954 }
1955 let conn = self.conn.clone();
1956 Self::block(async move {
1957 let conn = conn.lock().await;
1958 Self::read_session_async(&conn).await
1959 })
1960 .map(Some)
1961 }
1962
1963 fn delete(&self) -> io::Result<()> {
1964 let conn = self.conn.clone();
1965 Self::block(async move {
1966 let conn = conn.lock().await;
1967 conn.execute_batch(
1968 "BEGIN IMMEDIATE;
1969 DELETE FROM meta;
1970 DELETE FROM dcs;
1971 DELETE FROM update_state;
1972 DELETE FROM channel_pts;
1973 DELETE FROM peers;
1974 COMMIT;",
1975 )
1976 .await
1977 })
1978 }
1979
1980 fn name(&self) -> &str {
1981 &self.label
1982 }
1983
1984 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
1987 let conn = self.conn.clone();
1988 let (dc_id, addr, key, salt, off, flags) = (
1989 entry.dc_id,
1990 entry.addr.clone(),
1991 entry.auth_key.map(|k| k.to_vec()),
1992 entry.first_salt,
1993 entry.time_offset,
1994 entry.flags.0 as i64,
1995 );
1996 Self::block(async move {
1997 let conn = conn.lock().await;
1998 conn.execute(
1999 "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
2000 VALUES (?1,?6,?2,?3,?4,?5)
2001 ON CONFLICT(dc_id, flags) DO UPDATE SET
2002 addr=excluded.addr, auth_key=excluded.auth_key,
2003 first_salt=excluded.first_salt, time_offset=excluded.time_offset",
2004 libsql::params![dc_id, addr, key, salt, off, flags],
2005 )
2006 .await
2007 .map(|_| ())
2008 })
2009 }
2010
2011 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
2012 let conn = self.conn.clone();
2013 Self::block(async move {
2014 let conn = conn.lock().await;
2015 conn.execute(
2016 "INSERT INTO meta (key, value) VALUES ('home_dc_id',?1)
2017 ON CONFLICT(key) DO UPDATE SET value=excluded.value",
2018 libsql::params![dc_id],
2019 )
2020 .await
2021 .map(|_| ())
2022 })
2023 }
2024
2025 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
2026 let conn = self.conn.clone();
2027 Self::block(async move {
2028 let conn = conn.lock().await;
2029 match update {
2030 UpdateStateChange::All(snap) => {
2031 conn.execute(
2032 "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,?2,?3,?4)
2033 ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,qts=excluded.qts,
2034 date=excluded.date,seq=excluded.seq",
2035 libsql::params![snap.pts, snap.qts, snap.date, snap.seq],
2036 )
2037 .await?;
2038 conn.execute("DELETE FROM channel_pts", ()).await?;
2039 for &(cid, cpts) in &snap.channels {
2040 conn.execute(
2041 "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)",
2042 libsql::params![cid, cpts],
2043 )
2044 .await?;
2045 }
2046 Ok(())
2047 }
2048 UpdateStateChange::Primary { pts, date, seq } => conn
2049 .execute(
2050 "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,0,?2,?3)
2051 ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,date=excluded.date,
2052 seq=excluded.seq",
2053 libsql::params![pts, date, seq],
2054 )
2055 .await
2056 .map(|_| ()),
2057 UpdateStateChange::Secondary { qts } => conn
2058 .execute(
2059 "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,0,?1,0,0)
2060 ON CONFLICT(id) DO UPDATE SET qts=excluded.qts",
2061 libsql::params![qts],
2062 )
2063 .await
2064 .map(|_| ()),
2065 UpdateStateChange::Channel { id, pts } => conn
2066 .execute(
2067 "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)
2068 ON CONFLICT(channel_id) DO UPDATE SET pts=excluded.pts",
2069 libsql::params![id, pts],
2070 )
2071 .await
2072 .map(|_| ()),
2073 }
2074 })
2075 }
2076
2077 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
2078 let conn = self.conn.clone();
2079 let (id, hash, is_ch) = (peer.id, peer.access_hash, peer.is_channel as i32);
2080 Self::block(async move {
2081 let conn = conn.lock().await;
2082 conn.execute(
2083 "INSERT INTO peers (id,access_hash,is_channel) VALUES (?1,?2,?3)
2084 ON CONFLICT(id) DO UPDATE SET
2085 access_hash=excluded.access_hash,
2086 is_channel=excluded.is_channel",
2087 libsql::params![id, hash, is_ch],
2088 )
2089 .await
2090 .map(|_| ())
2091 })
2092 }
2093}