1#![cfg_attr(docsrs, feature(doc_cfg))]
12use std::collections::HashMap;
73use std::io::{self, ErrorKind};
74use std::path::Path;
75
/// Serde adapter for `Option<[u8; 256]>` auth keys.
///
/// A present key is emitted as `Some(<byte sequence>)`; on input, the bytes
/// are collected into a `Vec<u8>` and must be exactly 256 long.
#[cfg(feature = "serde")]
mod auth_key_serde {
    use serde::{Deserialize, Deserializer, Serializer};

    /// Writes the key as an optional byte sequence.
    pub fn serialize<S>(value: &Option<[u8; 256]>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(key) = value {
            s.serialize_some(&key[..])
        } else {
            s.serialize_none()
        }
    }

    /// Reads an optional byte sequence and converts it back into a fixed-size key.
    ///
    /// # Errors
    /// A custom deserializer error if the sequence is not exactly 256 bytes.
    pub fn deserialize<'de, D>(d: D) -> Result<Option<[u8; 256]>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = Option::<Vec<u8>>::deserialize(d)?;
        let Some(bytes) = raw else {
            return Ok(None);
        };
        let key = <[u8; 256]>::try_from(bytes).map_err(|_| {
            <D::Error as serde::de::Error>::custom("auth_key must be exactly 256 bytes")
        })?;
        Ok(Some(key))
    }
}
106
/// Bit set describing properties of a DC endpoint, stored as a raw `u8`.
///
/// The constant names follow Telegram's `dcOption` flags (presumably the same
/// bit positions — confirm against the schema). Combine flags with `|`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// No bits set.
    pub const NONE: Self = Self(0b0000_0000);
    /// Endpoint address is IPv6.
    pub const IPV6: Self = Self(0b0000_0001);
    /// Endpoint is reserved for media traffic.
    pub const MEDIA_ONLY: Self = Self(0b0000_0010);
    /// Endpoint accepts only obfuscated TCP (name mirrors `tcpo_only`).
    pub const TCPO_ONLY: Self = Self(0b0000_0100);
    /// Endpoint is a CDN DC.
    pub const CDN: Self = Self(0b0000_1000);
    /// Endpoint is marked static.
    pub const STATIC: Self = Self(0b0001_0000);

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// `x.contains(DcFlags::NONE)` is therefore always `true`.
    pub fn contains(self, other: DcFlags) -> bool {
        let common = self.0 & other.0;
        common == other.0
    }

    /// Turns on every bit of `flag`; bits already set stay set.
    pub fn set(&mut self, flag: DcFlags) {
        self.0 = self.0 | flag.0;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = DcFlags;

    /// Union of both flag sets.
    fn bitor(self, rhs: DcFlags) -> DcFlags {
        Self(self.0 | rhs.0)
    }
}
137
138#[derive(Clone, Debug)]
140#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
141pub struct DcEntry {
142 pub dc_id: i32,
143 pub addr: String,
144 #[cfg_attr(feature = "serde", serde(with = "auth_key_serde"))]
145 pub auth_key: Option<[u8; 256]>,
146 pub first_salt: i64,
147 pub time_offset: i32,
148 pub flags: DcFlags,
150}
151
152impl DcEntry {
153 #[inline]
155 pub fn is_ipv6(&self) -> bool {
156 self.flags.contains(DcFlags::IPV6)
157 }
158
159 pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
166 self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
167 io::Error::new(
168 io::ErrorKind::InvalidData,
169 format!("invalid DC address: {:?}", self.addr),
170 )
171 })
172 }
173
174 pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
187 let addr = if ip.contains(':') {
189 format!("[{ip}]:{port}")
190 } else {
191 format!("{ip}:{port}")
192 };
193 Self {
194 dc_id,
195 addr,
196 auth_key: None,
197 first_salt: 0,
198 time_offset: 0,
199 flags,
200 }
201 }
202}
203
/// Snapshot of the update-sequence counters plus per-channel `pts` values.
///
/// Field meanings follow Telegram's `updates.state` (presumably — confirm).
/// `channels` holds `(channel_id, pts)` pairs with at most one pair per id
/// when maintained through [`UpdatesStateSnap::set_channel_pts`].
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    pub pts: i32,
    pub qts: i32,
    pub date: i32,
    pub seq: i32,
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// `true` once `pts` holds a positive value.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        0 < self.pts
    }

    /// Records `pts` for `channel_id`, overwriting the first existing pair for
    /// that id or appending a new pair otherwise.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        let slot = self.channels.iter().position(|&(id, _)| id == channel_id);
        match slot {
            Some(i) => self.channels[i].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Returns the stored `pts` for `channel_id`, or `0` when none is stored.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        self.channels
            .iter()
            .find_map(|&(id, pts)| (id == channel_id).then_some(pts))
            .unwrap_or(0)
    }
}
246
/// A peer whose `access_hash` has been cached for later API calls.
///
/// With both flags false the peer is a user. The binary session format stores
/// the kind as one byte: 0 = user, 1 = channel, 2 = chat.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Peer identifier.
    pub id: i64,
    /// Access hash paired with `id`.
    pub access_hash: i64,
    /// The peer is a channel.
    pub is_channel: bool,
    /// The peer is a (basic) chat.
    pub is_chat: bool,
}
263
/// Where a user was seen without a full access hash.
///
/// Presumably a "min" user that can be referenced through message `msg_id`
/// in peer `peer_id` — TODO confirm against the code that fills this in.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedMinPeer {
    /// The user that was seen.
    pub user_id: i64,
    /// The peer (chat/channel) where the user was seen.
    pub peer_id: i64,
    /// The message in `peer_id` that mentions the user.
    pub msg_id: i32,
}
277
278#[derive(Clone, Debug, Default)]
280#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
281pub struct PersistedSession {
282 pub home_dc_id: i32,
283 pub dcs: Vec<DcEntry>,
284 pub updates_state: UpdatesStateSnap,
286 pub peers: Vec<CachedPeer>,
289 pub min_peers: Vec<CachedMinPeer>,
292}
293
294impl PersistedSession {
295 pub fn to_bytes(&self) -> Vec<u8> {
297 let mut b = Vec::with_capacity(512);
298
299 b.push(0x05u8); b.extend_from_slice(&self.home_dc_id.to_le_bytes());
302
303 b.push(self.dcs.len() as u8);
304 for d in &self.dcs {
305 b.extend_from_slice(&d.dc_id.to_le_bytes());
306 match &d.auth_key {
307 Some(k) => {
308 b.push(1);
309 b.extend_from_slice(k);
310 }
311 None => {
312 b.push(0);
313 }
314 }
315 b.extend_from_slice(&d.first_salt.to_le_bytes());
316 b.extend_from_slice(&d.time_offset.to_le_bytes());
317 let ab = d.addr.as_bytes();
318 b.push(ab.len() as u8);
319 b.extend_from_slice(ab);
320 b.push(d.flags.0);
321 }
322
323 b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
324 b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
325 b.extend_from_slice(&self.updates_state.date.to_le_bytes());
326 b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
327 let ch = &self.updates_state.channels;
328 b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
329 for &(cid, cpts) in ch {
330 b.extend_from_slice(&cid.to_le_bytes());
331 b.extend_from_slice(&cpts.to_le_bytes());
332 }
333
334 b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
336 for p in &self.peers {
337 b.extend_from_slice(&p.id.to_le_bytes());
338 b.extend_from_slice(&p.access_hash.to_le_bytes());
339 let peer_type: u8 = if p.is_chat {
340 2
341 } else if p.is_channel {
342 1
343 } else {
344 0
345 };
346 b.push(peer_type);
347 }
348
349 b.extend_from_slice(&(self.min_peers.len() as u16).to_le_bytes());
350 for m in &self.min_peers {
351 b.extend_from_slice(&m.user_id.to_le_bytes());
352 b.extend_from_slice(&m.peer_id.to_le_bytes());
353 b.extend_from_slice(&m.msg_id.to_le_bytes());
354 }
355
356 b
357 }
358
359 pub fn save(&self, path: &Path) -> io::Result<()> {
366 use std::sync::atomic::{AtomicU64, Ordering};
367 static SEQ: AtomicU64 = AtomicU64::new(0);
368 let n = SEQ.fetch_add(1, Ordering::Relaxed);
369 let tmp = path.with_extension(format!("{n}.tmp"));
370 std::fs::write(&tmp, self.to_bytes())?;
371 std::fs::rename(&tmp, path).inspect_err(|_e| {
372 let _ = std::fs::remove_file(&tmp);
373 })
374 }
375
376 pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
378 if buf.is_empty() {
379 return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
380 }
381
382 let mut p = 0usize;
383
384 macro_rules! r {
385 ($n:expr) => {{
386 if p + $n > buf.len() {
387 return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
388 }
389 let s = &buf[p..p + $n];
390 p += $n;
391 s
392 }};
393 }
394 macro_rules! r_i32 {
395 () => {
396 i32::from_le_bytes(r!(4).try_into().unwrap())
397 };
398 }
399 macro_rules! r_i64 {
400 () => {
401 i64::from_le_bytes(r!(8).try_into().unwrap())
402 };
403 }
404 macro_rules! r_u8 {
405 () => {
406 r!(1)[0]
407 };
408 }
409 macro_rules! r_u16 {
410 () => {
411 u16::from_le_bytes(r!(2).try_into().unwrap())
412 };
413 }
414
415 let first_byte = r_u8!();
416
417 let (home_dc_id, version) = if first_byte == 0x05 {
418 (r_i32!(), 5u8)
419 } else if first_byte == 0x04 {
420 (r_i32!(), 4u8)
421 } else if first_byte == 0x03 {
422 (r_i32!(), 3u8)
423 } else if first_byte == 0x02 {
424 (r_i32!(), 2u8)
425 } else {
426 let rest = r!(3);
427 let mut bytes = [0u8; 4];
428 bytes[0] = first_byte;
429 bytes[1..4].copy_from_slice(rest);
430 (i32::from_le_bytes(bytes), 1u8)
431 };
432
433 let dc_count = r_u8!() as usize;
434 let mut dcs = Vec::with_capacity(dc_count);
435 for _ in 0..dc_count {
436 let dc_id = r_i32!();
437 let has_key = r_u8!();
438 let auth_key = if has_key == 1 {
439 let mut k = [0u8; 256];
440 k.copy_from_slice(r!(256));
441 Some(k)
442 } else {
443 None
444 };
445 let first_salt = r_i64!();
446 let time_offset = r_i32!();
447 let al = r_u8!() as usize;
448 let addr = String::from_utf8_lossy(r!(al)).into_owned();
449 let flags = if version >= 3 {
450 DcFlags(r_u8!())
451 } else {
452 DcFlags::NONE
453 };
454 dcs.push(DcEntry {
455 dc_id,
456 addr,
457 auth_key,
458 first_salt,
459 time_offset,
460 flags,
461 });
462 }
463
464 if version < 2 {
465 return Ok(Self {
466 home_dc_id,
467 dcs,
468 updates_state: UpdatesStateSnap::default(),
469 peers: Vec::new(),
470 min_peers: Vec::new(),
471 });
472 }
473
474 let pts = r_i32!();
475 let qts = r_i32!();
476 let date = r_i32!();
477 let seq = r_i32!();
478 let ch_count = r_u16!() as usize;
479 let mut channels = Vec::with_capacity(ch_count);
480 for _ in 0..ch_count {
481 let cid = r_i64!();
482 let cpts = r_i32!();
483 channels.push((cid, cpts));
484 }
485
486 let peer_count = r_u16!() as usize;
487 let mut peers = Vec::with_capacity(peer_count);
488 for _ in 0..peer_count {
489 let id = r_i64!();
490 let access_hash = r_i64!();
491 let peer_type = r_u8!();
493 let is_channel = peer_type == 1;
494 let is_chat = peer_type == 2;
495 peers.push(CachedPeer {
496 id,
497 access_hash,
498 is_channel,
499 is_chat,
500 });
501 }
502
503 let min_peers = if version >= 4 {
505 let count = r_u16!() as usize;
506 let mut v = Vec::with_capacity(count);
507 for _ in 0..count {
508 let user_id = r_i64!();
509 let peer_id = r_i64!();
510 let msg_id = r_i32!();
511 v.push(CachedMinPeer {
512 user_id,
513 peer_id,
514 msg_id,
515 });
516 }
517 v
518 } else {
519 Vec::new()
520 };
521
522 Ok(Self {
523 home_dc_id,
524 dcs,
525 updates_state: UpdatesStateSnap {
526 pts,
527 qts,
528 date,
529 seq,
530 channels,
531 },
532 peers,
533 min_peers,
534 })
535 }
536
537 pub fn from_string(s: &str) -> io::Result<Self> {
539 use base64::Engine as _;
540 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
541 .decode(s.trim())
542 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
543 Self::from_bytes(&bytes)
544 }
545
546 pub fn load(path: &Path) -> io::Result<Self> {
547 let buf = std::fs::read(path)?;
548 Self::from_bytes(&buf)
549 }
550
551 pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
563 let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
564 candidates.peek()?;
565 let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
567 cands
569 .iter()
570 .copied()
571 .find(|d| d.is_ipv6() == prefer_ipv6)
572 .or_else(|| cands.first().copied())
573 }
574
575 pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
580 self.dcs.iter().filter(move |d| d.dc_id == dc_id)
581 }
582}
583
584impl std::fmt::Display for PersistedSession {
585 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
586 use base64::Engine as _;
587 f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
588 }
589}
590
/// Built-in `ip:port` address for each of DC ids 1 through 5.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const DEFAULTS: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(DEFAULTS.len());
    for &(dc_id, addr) in DEFAULTS.iter() {
        table.insert(dc_id, addr.to_owned());
    }
    table
}
604
605use std::path::PathBuf;
610
611pub trait SessionBackend: Send + Sync {
619 fn save(&self, session: &PersistedSession) -> io::Result<()>;
620 fn load(&self) -> io::Result<Option<PersistedSession>>;
621 fn delete(&self) -> io::Result<()>;
622
623 fn name(&self) -> &str;
625
626 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
639 let mut s = self.load()?.unwrap_or_default();
640 if let Some(existing) = s.dcs.iter_mut().find(|d| d.dc_id == entry.dc_id) {
642 *existing = entry.clone();
643 } else {
644 s.dcs.push(entry.clone());
645 }
646 self.save(&s)
647 }
648
649 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
655 let mut s = self.load()?.unwrap_or_default();
656 s.home_dc_id = dc_id;
657 self.save(&s)
658 }
659
660 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
665 let mut s = self.load()?.unwrap_or_default();
666 update.apply_to(&mut s.updates_state);
667 self.save(&s)
668 }
669
670 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
676 let mut s = self.load()?.unwrap_or_default();
677 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
678 *existing = peer.clone();
679 } else {
680 s.peers.push(peer.clone());
681 }
682 self.save(&s)
683 }
684}
685
686#[derive(Debug, Clone)]
700pub enum UpdateStateChange {
701 All(UpdatesStateSnap),
703 Primary { pts: i32, date: i32, seq: i32 },
705 Secondary { qts: i32 },
707 Channel { id: i64, pts: i32 },
709}
710
711impl UpdateStateChange {
712 pub fn apply_to(&self, snap: &mut UpdatesStateSnap) {
714 match self {
715 Self::All(new_snap) => *snap = new_snap.clone(),
716 Self::Primary { pts, date, seq } => {
717 snap.pts = *pts;
718 snap.date = *date;
719 snap.seq = *seq;
720 }
721 Self::Secondary { qts } => {
722 snap.qts = *qts;
723 }
724 Self::Channel { id, pts } => {
725 if let Some(existing) = snap.channels.iter_mut().find(|c| c.0 == *id) {
727 existing.1 = *pts;
728 } else {
729 snap.channels.push((*id, *pts));
730 }
731 }
732 }
733 }
734}
735
736pub struct BinaryFileBackend {
740 path: PathBuf,
741 write_lock: std::sync::Mutex<()>,
746}
747
748impl BinaryFileBackend {
749 pub fn new(path: impl Into<PathBuf>) -> Self {
750 Self {
751 path: path.into(),
752 write_lock: std::sync::Mutex::new(()),
753 }
754 }
755
756 pub fn path(&self) -> &std::path::Path {
757 &self.path
758 }
759}
760
761impl SessionBackend for BinaryFileBackend {
762 fn save(&self, session: &PersistedSession) -> io::Result<()> {
763 let _guard = self.write_lock.lock().unwrap();
764 session.save(&self.path)
765 }
766
767 fn load(&self) -> io::Result<Option<PersistedSession>> {
768 if !self.path.exists() {
769 return Ok(None);
770 }
771 match PersistedSession::load(&self.path) {
772 Ok(s) => Ok(Some(s)),
773 Err(e) => {
774 let bak = self.path.with_extension("bak");
775 tracing::warn!(
776 "[ferogram] Session file {:?} is corrupt ({e}); \
777 renaming to {:?} and starting fresh",
778 self.path,
779 bak
780 );
781 let _ = std::fs::rename(&self.path, &bak);
782 Ok(None)
783 }
784 }
785 }
786
787 fn delete(&self) -> io::Result<()> {
788 if self.path.exists() {
789 std::fs::remove_file(&self.path)?;
790 }
791 Ok(())
792 }
793
794 fn name(&self) -> &str {
795 "binary-file"
796 }
797
798 }
801
802#[derive(Default)]
810pub struct InMemoryBackend {
811 data: std::sync::Mutex<Option<PersistedSession>>,
812}
813
814impl InMemoryBackend {
815 pub fn new() -> Self {
816 Self::default()
817 }
818
819 pub fn snapshot(&self) -> Option<PersistedSession> {
821 self.data.lock().unwrap().clone()
822 }
823}
824
825impl SessionBackend for InMemoryBackend {
826 fn save(&self, s: &PersistedSession) -> io::Result<()> {
827 *self.data.lock().unwrap() = Some(s.clone());
828 Ok(())
829 }
830
831 fn load(&self) -> io::Result<Option<PersistedSession>> {
832 Ok(self.data.lock().unwrap().clone())
833 }
834
835 fn delete(&self) -> io::Result<()> {
836 *self.data.lock().unwrap() = None;
837 Ok(())
838 }
839
840 fn name(&self) -> &str {
841 "in-memory"
842 }
843
844 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
847 let mut guard = self.data.lock().unwrap();
848 let s = guard.get_or_insert_with(PersistedSession::default);
849 if let Some(existing) = s.dcs.iter_mut().find(|d| d.dc_id == entry.dc_id) {
850 *existing = entry.clone();
851 } else {
852 s.dcs.push(entry.clone());
853 }
854 Ok(())
855 }
856
857 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
858 let mut guard = self.data.lock().unwrap();
859 let s = guard.get_or_insert_with(PersistedSession::default);
860 s.home_dc_id = dc_id;
861 Ok(())
862 }
863
864 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
865 let mut guard = self.data.lock().unwrap();
866 let s = guard.get_or_insert_with(PersistedSession::default);
867 update.apply_to(&mut s.updates_state);
868 Ok(())
869 }
870
871 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
872 let mut guard = self.data.lock().unwrap();
873 let s = guard.get_or_insert_with(PersistedSession::default);
874 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
875 *existing = peer.clone();
876 } else {
877 s.peers.push(peer.clone());
878 }
879 Ok(())
880 }
881}
882
/// [`SessionBackend`] that keeps the session as a base64 string.
///
/// The string can be read back with [`StringSessionBackend::current`] and
/// stored by the application however it likes.
pub struct StringSessionBackend {
    /// Encoded session; an empty (or all-whitespace) string means "no session".
    data: std::sync::Mutex<String>,
}

impl StringSessionBackend {
    /// Creates a backend seeded with `s`, which may be empty.
    pub fn new(s: impl Into<String>) -> Self {
        let initial: String = s.into();
        Self {
            data: std::sync::Mutex::new(initial),
        }
    }

    /// Returns a copy of the current encoded session string.
    ///
    /// # Panics
    /// If the internal mutex is poisoned.
    pub fn current(&self) -> String {
        let guard = self.data.lock().unwrap();
        guard.as_str().to_owned()
    }
}
901
902impl SessionBackend for StringSessionBackend {
903 fn save(&self, session: &PersistedSession) -> io::Result<()> {
904 *self.data.lock().unwrap() = session.to_string();
905 Ok(())
906 }
907
908 fn load(&self) -> io::Result<Option<PersistedSession>> {
909 let s = self.data.lock().unwrap().clone();
910 if s.trim().is_empty() {
911 return Ok(None);
912 }
913 PersistedSession::from_string(&s).map(Some)
914 }
915
916 fn delete(&self) -> io::Result<()> {
917 *self.data.lock().unwrap() = String::new();
918 Ok(())
919 }
920
921 fn name(&self) -> &str {
922 "string-session"
923 }
924}
925
#[cfg(test)]
mod tests {
    use super::*;

    // IPv4 DC entry with address `1.2.3.<id>:443`, no auth key and no flags.
    fn make_dc(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("1.2.3.{id}:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::NONE,
        }
    }

    // Plain user peer (neither channel nor chat).
    fn make_peer(id: i64, hash: i64) -> CachedPeer {
        CachedPeer {
            id,
            access_hash: hash,
            is_channel: false,
            is_chat: false,
        }
    }

    // A fresh in-memory backend has nothing stored.
    #[test]
    fn inmemory_load_returns_none_when_empty() {
        let b = InMemoryBackend::new();
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_save_then_load_round_trips() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 3;
        s.dcs.push(make_dc(3));
        b.save(&s).unwrap();

        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.home_dc_id, 3);
        assert_eq!(loaded.dcs.len(), 1);
    }

    #[test]
    fn inmemory_delete_clears_state() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        b.save(&s).unwrap();
        b.delete().unwrap();
        assert!(b.load().unwrap().is_none());
    }

    // update_dc on an empty backend creates the session and adds the DC.
    #[test]
    fn inmemory_update_dc_inserts_new() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(4)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].dc_id, 4);
    }

    // Same dc_id and same flags: the entry is replaced, not duplicated.
    #[test]
    fn inmemory_update_dc_replaces_existing() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        let mut updated = make_dc(2);
        updated.addr = "9.9.9.9:443".to_string();
        b.update_dc(&updated).unwrap();

        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].addr, "9.9.9.9:443");
    }

    #[test]
    fn inmemory_set_home_dc() {
        let b = InMemoryBackend::new();
        b.set_home_dc(5).unwrap();
        assert_eq!(b.snapshot().unwrap().home_dc_id, 5);
    }

    #[test]
    fn inmemory_cache_peer_inserts() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 0xdeadbeef)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].id, 100);
    }

    // Caching a peer with a known id overwrites its access hash.
    #[test]
    fn inmemory_cache_peer_updates_existing() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 111)).unwrap();
        b.cache_peer(&make_peer(100, 222)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].access_hash, 222);
    }

    // Primary touches pts/date/seq only.
    #[test]
    fn update_state_primary() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Primary {
            pts: 10,
            date: 20,
            seq: 30,
        }
        .apply_to(&mut snap);
        assert_eq!(snap.pts, 10);
        assert_eq!(snap.date, 20);
        assert_eq!(snap.seq, 30);
        // qts is left alone by Primary.
        assert_eq!(snap.qts, 0);
    }

    // Secondary touches qts only.
    #[test]
    fn update_state_secondary() {
        let mut snap = UpdatesStateSnap {
            pts: 5,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Secondary { qts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.qts, 99);
        // pts is left alone by Secondary.
        assert_eq!(snap.pts, 5);
    }

    #[test]
    fn update_state_channel_inserts() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Channel { id: 12345, pts: 42 }.apply_to(&mut snap);
        assert_eq!(snap.channels, vec![(12345, 42)]);
    }

    // Updating a known channel changes only that channel's pts, in place.
    #[test]
    fn update_state_channel_updates_existing() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![(12345, 10), (67890, 5)],
        };
        UpdateStateChange::Channel { id: 12345, pts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.channels[0], (12345, 99));
        // The other channel is untouched.
        assert_eq!(snap.channels[1], (67890, 5));
    }

    #[test]
    fn apply_update_state_via_backend() {
        let b = InMemoryBackend::new();
        b.apply_update_state(UpdateStateChange::Primary {
            pts: 7,
            date: 8,
            seq: 9,
        })
        .unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.updates_state.pts, 7);
    }

    // The trait is object-safe and usable through Box<dyn SessionBackend>.
    #[test]
    fn default_update_dc_via_trait_object() {
        let b: Box<dyn SessionBackend> = Box::new(InMemoryBackend::new());
        b.update_dc(&make_dc(1)).unwrap();
        b.update_dc(&make_dc(2)).unwrap();
        let loaded = b.load().unwrap().unwrap();
        // Distinct dc_ids are both kept.
        assert_eq!(loaded.dcs.len(), 2);
    }

    // IPv6 DC entry with a bracketed address and the IPV6 flag set.
    fn make_dc_v6(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("[2001:b28:f23d:f00{}::a]:443", id),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::IPV6,
        }
    }

    #[test]
    fn dc_entry_from_parts_ipv4() {
        let dc = DcEntry::from_parts(1, "149.154.175.53", 443, DcFlags::NONE);
        assert_eq!(dc.addr, "149.154.175.53:443");
        assert!(!dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    // An IPv6 host is bracketed so the address parses as a SocketAddr.
    #[test]
    fn dc_entry_from_parts_ipv6() {
        let dc = DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6);
        assert_eq!(dc.addr, "[2001:b28:f23d:f001::a]:443");
        assert!(dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    // With both endpoints stored, dc_for honours the preference either way.
    #[test]
    fn persisted_session_dc_for_prefers_ipv6() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(2));
        s.dcs.push(make_dc_v6(2));
        let v6 = s.dc_for(2, true).unwrap();
        assert!(v6.is_ipv6());

        let v4 = s.dc_for(2, false).unwrap();
        assert!(!v4.is_ipv6());
    }

    // Preferring IPv6 still returns the IPv4 endpoint when it is the only one.
    #[test]
    fn persisted_session_dc_for_falls_back_when_only_ipv4() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(3));
        let dc = s.dc_for(3, true).unwrap();
        assert!(!dc.is_ipv6());
    }

    #[test]
    fn persisted_session_all_dcs_for_returns_both() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(1));
        s.dcs.push(make_dc_v6(1));
        s.dcs.push(make_dc(2));

        assert_eq!(s.all_dcs_for(1).count(), 2);
        assert_eq!(s.all_dcs_for(2).count(), 1);
        assert_eq!(s.all_dcs_for(5).count(), 0);
    }

    // Same dc_id, different flags: update_dc must keep both endpoints.
    #[test]
    fn inmemory_ipv4_and_ipv6_coexist() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        b.update_dc(&make_dc_v6(2)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.iter().filter(|d| d.dc_id == 2).count(), 2);
    }

    // The binary format preserves DC flags, so IPv6-ness survives a round trip.
    #[test]
    fn binary_roundtrip_ipv4_and_ipv6() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        s.dcs.push(make_dc(2));
        s.dcs.push(make_dc_v6(2));

        let bytes = s.to_bytes();
        let loaded = PersistedSession::from_bytes(&bytes).unwrap();
        assert_eq!(loaded.dcs.len(), 2);
        assert_eq!(loaded.dcs.iter().filter(|d| d.is_ipv6()).count(), 1);
        assert_eq!(loaded.dcs.iter().filter(|d| !d.is_ipv6()).count(), 1);
    }
}
1214
#[cfg(feature = "sqlite-session")]
/// [`SessionBackend`] backed by a SQLite database, either a file or `:memory:`.
///
/// One connection is shared behind a mutex, so database access through this
/// value is serialised.
pub struct SqliteBackend {
    /// The shared connection; methods lock it for the duration of each operation.
    conn: std::sync::Mutex<rusqlite::Connection>,
    /// Name reported by `SessionBackend::name`: the database path or `":memory:"`.
    label: String,
}
1244
#[cfg(feature = "sqlite-session")]
impl SqliteBackend {
    /// Schema applied on every open. Each statement is safe to re-run
    /// (`CREATE TABLE IF NOT EXISTS`).
    ///
    /// NOTE: `peers` has no column for `CachedPeer::is_chat`, and there is no
    /// table for `min_peers`, so neither survives a round trip through this
    /// backend (`read_session` returns `is_chat = false` and no min peers).
    const SCHEMA: &'static str = "
        PRAGMA journal_mode = WAL;
        PRAGMA synchronous = NORMAL;

        CREATE TABLE IF NOT EXISTS meta (
            key TEXT PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS dcs (
            dc_id INTEGER NOT NULL,
            flags INTEGER NOT NULL DEFAULT 0,
            addr TEXT NOT NULL,
            auth_key BLOB,
            first_salt INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );

        CREATE TABLE IF NOT EXISTS update_state (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            pts INTEGER NOT NULL DEFAULT 0,
            qts INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts INTEGER NOT NULL
        );

        CREATE TABLE IF NOT EXISTS peers (
            id INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel INTEGER NOT NULL DEFAULT 0
        );
    ";

    /// Opens the database at `path` and applies [`Self::SCHEMA`].
    ///
    /// # Errors
    /// SQLite errors, wrapped with `io::Error::other`.
    pub fn open(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let conn = rusqlite::Connection::open(&path).map_err(io::Error::other)?;
        conn.execute_batch(Self::SCHEMA).map_err(io::Error::other)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label,
        })
    }

    /// Opens a private in-memory database (lost on drop) and applies the schema.
    pub fn in_memory() -> io::Result<Self> {
        let conn = rusqlite::Connection::open_in_memory().map_err(io::Error::other)?;
        conn.execute_batch(Self::SCHEMA).map_err(io::Error::other)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label: ":memory:".into(),
        })
    }

    /// Wraps a SQLite error as an `io::Error` of kind `Other`.
    fn map_err(e: rusqlite::Error) -> io::Error {
        io::Error::other(e)
    }

    /// Reads the whole session from the database.
    ///
    /// A missing `home_dc_id` row or `update_state` row yields the default (0 /
    /// zeroed state). Any other query error is returned, rather than being
    /// mistaken for "no data". Individual DC, channel and peer rows that fail
    /// to decode are skipped. An auth key blob that is not exactly 256 bytes
    /// is treated as absent.
    fn read_session(conn: &rusqlite::Connection) -> io::Result<PersistedSession> {
        let home_dc_id: i32 = match conn.query_row(
            "SELECT value FROM meta WHERE key = 'home_dc_id'",
            [],
            |r| r.get(0),
        ) {
            Ok(id) => id,
            Err(rusqlite::Error::QueryReturnedNoRows) => 0,
            Err(e) => return Err(Self::map_err(e)),
        };

        let mut stmt = conn
            .prepare("SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs")
            .map_err(Self::map_err)?;
        let dcs = stmt
            .query_map([], |row| {
                let dc_id: i32 = row.get(0)?;
                let flags_raw: u8 = row.get(1)?;
                let addr: String = row.get(2)?;
                let key_blob: Option<Vec<u8>> = row.get(3)?;
                let first_salt: i64 = row.get(4)?;
                let time_offset: i32 = row.get(5)?;
                Ok((dc_id, addr, key_blob, first_salt, time_offset, flags_raw))
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .map(
                |(dc_id, addr, key_blob, first_salt, time_offset, flags_raw)| {
                    let auth_key = key_blob.and_then(|b| {
                        if b.len() == 256 {
                            let mut k = [0u8; 256];
                            k.copy_from_slice(&b);
                            Some(k)
                        } else {
                            None
                        }
                    });
                    DcEntry {
                        dc_id,
                        addr,
                        auth_key,
                        first_salt,
                        time_offset,
                        flags: DcFlags(flags_raw),
                    }
                },
            )
            .collect();

        let updates_state = match conn.query_row(
            "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
            [],
            |r| {
                Ok(UpdatesStateSnap {
                    pts: r.get(0)?,
                    qts: r.get(1)?,
                    date: r.get(2)?,
                    seq: r.get(3)?,
                    channels: vec![],
                })
            },
        ) {
            Ok(snap) => snap,
            Err(rusqlite::Error::QueryReturnedNoRows) => UpdatesStateSnap::default(),
            Err(e) => return Err(Self::map_err(e)),
        };

        let mut ch_stmt = conn
            .prepare("SELECT channel_id, pts FROM channel_pts")
            .map_err(Self::map_err)?;
        let channels: Vec<(i64, i32)> = ch_stmt
            .query_map([], |r| Ok((r.get::<_, i64>(0)?, r.get::<_, i32>(1)?)))
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        let mut peer_stmt = conn
            .prepare("SELECT id, access_hash, is_channel FROM peers")
            .map_err(Self::map_err)?;
        let peers: Vec<CachedPeer> = peer_stmt
            .query_map([], |r| {
                Ok(CachedPeer {
                    id: r.get(0)?,
                    access_hash: r.get(1)?,
                    is_channel: r.get::<_, i32>(2)? != 0,
                    // Not stored by this schema.
                    is_chat: false,
                })
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers: Vec::new(),
        })
    }

    /// Replaces the stored session with `s` inside one `BEGIN IMMEDIATE`
    /// transaction.
    ///
    /// If any statement or the final `COMMIT` fails, the transaction is rolled
    /// back before the error is returned. Otherwise the connection would stay
    /// inside an open transaction, and every later `BEGIN` would fail.
    fn write_session(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        conn.execute_batch("BEGIN IMMEDIATE")
            .map_err(Self::map_err)?;

        let result = Self::write_session_rows(conn, s)
            .and_then(|()| conn.execute_batch("COMMIT").map_err(Self::map_err));
        if result.is_err() {
            // Best effort; the original error is the one worth reporting.
            let _ = conn.execute_batch("ROLLBACK");
        }
        result
    }

    /// Writes every table from `s`. Must run inside an open transaction.
    ///
    /// `dcs`, `channel_pts` and `peers` are cleared and rewritten. `pts`/`qts`
    /// in `update_state` never decrease (SQL `MAX` against the stored row),
    /// while `date`/`seq` are overwritten.
    fn write_session_rows(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![s.home_dc_id],
        )
        .map_err(Self::map_err)?;

        conn.execute("DELETE FROM dcs", []).map_err(Self::map_err)?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                rusqlite::params![
                    d.dc_id,
                    d.flags.0,
                    d.addr,
                    d.auth_key.as_ref().map(|k| k.as_ref()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .map_err(Self::map_err)?;
        }

        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1, ?1, ?2, ?3, ?4)
             ON CONFLICT(id) DO UPDATE SET
                pts = MAX(excluded.pts, update_state.pts),
                qts = MAX(excluded.qts, update_state.qts),
                date = excluded.date,
                seq = excluded.seq",
            rusqlite::params![us.pts, us.qts, us.date, us.seq],
        )
        .map_err(Self::map_err)?;

        conn.execute("DELETE FROM channel_pts", [])
            .map_err(Self::map_err)?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
                rusqlite::params![cid, cpts],
            )
            .map_err(Self::map_err)?;
        }

        conn.execute("DELETE FROM peers", [])
            .map_err(Self::map_err)?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1, ?2, ?3)",
                rusqlite::params![p.id, p.access_hash, p.is_channel as i32],
            )
            .map_err(Self::map_err)?;
        }

        Ok(())
    }
}
1485
#[cfg(feature = "sqlite-session")]
impl SessionBackend for SqliteBackend {
    /// Persists `session` through `SqliteBackend::write_session` while holding the
    /// connection lock.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        Self::write_session(&conn, session)
    }

    /// Loads the stored session.
    ///
    /// Returns `Ok(None)` when the `meta` table has no rows (e.g. after `delete`);
    /// otherwise reads every table via `SqliteBackend::read_session`.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = self.conn.lock().unwrap();
        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM meta", [], |r| r.get(0))
            .map_err(Self::map_err)?;
        if count == 0 {
            return Ok(None);
        }
        Self::read_session(&conn).map(Some)
    }

    /// Clears every session table inside one `IMMEDIATE` transaction.
    ///
    /// If any DELETE or the COMMIT fails, `tx` is dropped and rolls the transaction
    /// back. The connection is therefore never left inside an open transaction, which
    /// would make every later `BEGIN` on it fail.
    fn delete(&self) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        let tx = rusqlite::Transaction::new_unchecked(
            &conn,
            rusqlite::TransactionBehavior::Immediate,
        )
        .map_err(Self::map_err)?;
        tx.execute_batch(
            "DELETE FROM meta;
             DELETE FROM dcs;
             DELETE FROM update_state;
             DELETE FROM channel_pts;
             DELETE FROM peers;",
        )
        .map_err(Self::map_err)?;
        tx.commit().map_err(Self::map_err)
    }

    /// Human-readable label for this backend.
    fn name(&self) -> &str {
        &self.label
    }

    /// Inserts or updates the `dcs` row keyed by `(dc_id, flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
             VALUES (?1, ?6, ?2, ?3, ?4, ?5)
             ON CONFLICT(dc_id, flags) DO UPDATE SET
                addr        = excluded.addr,
                auth_key    = excluded.auth_key,
                first_salt  = excluded.first_salt,
                time_offset = excluded.time_offset",
            rusqlite::params![
                entry.dc_id,
                entry.addr,
                entry.auth_key.as_ref().map(|k| k.as_ref()),
                entry.first_salt,
                entry.time_offset,
                entry.flags.0,
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Upserts the `home_dc_id` key in `meta`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![dc_id],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Applies one incremental update-state change.
    ///
    /// - `All` overwrites the singleton `update_state` row and replaces every
    ///   `channel_pts` row. It runs in one transaction so a failure cannot leave the
    ///   channel table half-rewritten.
    /// - `Primary` updates pts/date/seq, `Secondary` updates qts, and `Channel`
    ///   upserts a single channel's pts.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        match update {
            UpdateStateChange::All(snap) => {
                // If this returns early, `tx` is dropped and rolls back.
                let tx = rusqlite::Transaction::new_unchecked(
                    &conn,
                    rusqlite::TransactionBehavior::Immediate,
                )
                .map_err(Self::map_err)?;
                tx.execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
                     ON CONFLICT(id) DO UPDATE SET
                       pts=excluded.pts, qts=excluded.qts,
                       date=excluded.date, seq=excluded.seq",
                    rusqlite::params![snap.pts, snap.qts, snap.date, snap.seq],
                )
                .map_err(Self::map_err)?;
                tx.execute("DELETE FROM channel_pts", [])
                    .map_err(Self::map_err)?;
                for &(cid, cpts) in &snap.channels {
                    tx.execute(
                        "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
                        rusqlite::params![cid, cpts],
                    )
                    .map_err(Self::map_err)?;
                }
                tx.commit().map_err(Self::map_err)
            }
            UpdateStateChange::Primary { pts, date, seq } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,0,?2,?3)
                     ON CONFLICT(id) DO UPDATE SET pts=excluded.pts, date=excluded.date,
                       seq=excluded.seq",
                    rusqlite::params![pts, date, seq],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Secondary { qts } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,0,?1,0,0)
                     ON CONFLICT(id) DO UPDATE SET qts = excluded.qts",
                    rusqlite::params![qts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Channel { id, pts } => conn
                .execute(
                    "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)
                     ON CONFLICT(channel_id) DO UPDATE SET pts = excluded.pts",
                    rusqlite::params![id, pts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
        }
    }

    /// Upserts a peer's access hash and channel flag, keyed by `id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1, ?2, ?3)
             ON CONFLICT(id) DO UPDATE SET
                access_hash = excluded.access_hash,
                is_channel  = excluded.is_channel",
            rusqlite::params![peer.id, peer.access_hash, peer.is_channel as i32],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }
}
1623
/// `SessionBackend` implementation backed by a libsql database.
///
/// The constructors on this type open it as a local file, in memory, remotely, or as
/// an embedded replica.
#[cfg(feature = "libsql-session")]
pub struct LibSqlBackend {
    /// Shared connection, guarded by an async mutex so the guard can be held across
    /// `.await` points inside the blocking bridge. Every constructor builds exactly
    /// `Arc<tokio::sync::Mutex<libsql::Connection>>`, and every method calls
    /// `.lock().await` on it.
    conn: std::sync::Arc<tokio::sync::Mutex<libsql::Connection>>,
    /// Human-readable description returned by `SessionBackend::name`.
    label: String,
}
1648
#[cfg(feature = "libsql-session")]
impl LibSqlBackend {
    /// Idempotent schema (`CREATE TABLE IF NOT EXISTS`), applied on every open.
    const SCHEMA: &'static str = "
        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );
        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );
        CREATE TABLE IF NOT EXISTS peers (
            id          INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel  INTEGER NOT NULL DEFAULT 0
        );
    ";

    /// Drives `fut` to completion on the ambient Tokio runtime and converts libsql
    /// errors into `io::Error`.
    ///
    /// # Errors
    /// Returns an error if no Tokio runtime is in context (instead of panicking), or if
    /// `fut` resolves to an error.
    ///
    /// # Panics
    /// `Handle::block_on` panics when called from inside an asynchronous execution
    /// context (directly within a Tokio task). Call the backend from a blocking
    /// context, e.g. `spawn_blocking`.
    fn block<F, T>(fut: F) -> io::Result<T>
    where
        F: std::future::Future<Output = Result<T, libsql::Error>>,
    {
        let handle = tokio::runtime::Handle::try_current().map_err(io::Error::other)?;
        handle.block_on(fut).map_err(io::Error::other)
    }

    /// Creates every session table that does not already exist.
    async fn apply_schema(conn: &libsql::Connection) -> Result<(), libsql::Error> {
        conn.execute_batch(Self::SCHEMA).await
    }

    /// Shared tail of every constructor: connect to `db`, apply the schema, and wrap
    /// the connection for shared async access.
    fn from_database(db: libsql::Database, label: String) -> io::Result<Self> {
        // `connect` is synchronous, so no runtime round-trip is needed here.
        let conn = db.connect().map_err(io::Error::other)?;
        Self::block(Self::apply_schema(&conn))?;
        Ok(Self {
            conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
            label,
        })
    }

    /// Opens (or creates) a local database file at `path`.
    pub fn open_local(path: impl Into<std::path::PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let db = Self::block(async { libsql::Builder::new_local(path).build().await })?;
        Self::from_database(db, label)
    }

    /// Opens a private in-memory database.
    pub fn in_memory() -> io::Result<Self> {
        let db = Self::block(async { libsql::Builder::new_local(":memory:").build().await })?;
        Self::from_database(db, ":memory:".into())
    }

    /// Connects to a remote libsql server at `url`, authenticating with `auth_token`.
    pub fn open_remote(url: impl Into<String>, auth_token: impl Into<String>) -> io::Result<Self> {
        let url = url.into();
        let label = url.clone();
        let db = Self::block(async {
            libsql::Builder::new_remote(url, auth_token.into())
                .build()
                .await
        })?;
        Self::from_database(db, label)
    }

    /// Opens an embedded replica stored at `path` that is synced from `url`.
    pub fn open_replica(
        path: impl Into<std::path::PathBuf>,
        url: impl Into<String>,
        auth_token: impl Into<String>,
    ) -> io::Result<Self> {
        let path = path.into();
        // Convert once: `url` is needed both for the label and for the builder.
        let url = url.into();
        let label = format!("{} (replica of {})", path.display(), url);
        let db = Self::block(async {
            libsql::Builder::new_remote_replica(path, url, auth_token.into())
                .build()
                .await
        })?;
        Self::from_database(db, label)
    }

    /// Reads the complete session snapshot from `conn`.
    ///
    /// Missing rows fall back to defaults:
    /// - `home_dc_id` is 0 when `meta` has no `home_dc_id` key.
    /// - The update state is `UpdatesStateSnap::default()` when `update_state` is empty.
    ///
    /// This schema does not store `is_chat` or `min_peers`, so every peer has
    /// `is_chat = false` and `min_peers` is empty.
    ///
    /// # Errors
    /// Propagates libsql query and decoding errors. Returns `Error::Misuse` for an
    /// `auth_key` blob that is not exactly 256 bytes, or for `flags` outside the `u8`
    /// range.
    async fn read_session_async(
        conn: &libsql::Connection,
    ) -> Result<PersistedSession, libsql::Error> {
        let home_dc_id: i32 = conn
            .query("SELECT value FROM meta WHERE key = 'home_dc_id'", ())
            .await?
            .next()
            .await?
            .map(|r| r.get::<i32>(0))
            .transpose()?
            .unwrap_or(0);

        let mut rows = conn
            .query(
                "SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs",
                (),
            )
            .await?;
        let mut dcs = Vec::new();
        while let Some(row) = rows.next().await? {
            let dc_id: i32 = row.get(0)?;
            let flags_raw: i64 = row.get(1)?;
            let addr: String = row.get(2)?;
            let key_blob: Option<Vec<u8>> = row.get(3)?;
            let first_salt: i64 = row.get(4)?;
            let time_offset: i32 = row.get(5)?;
            // Flags are written from a `u8`; report corruption instead of truncating.
            let flags = u8::try_from(flags_raw).map_err(|_| {
                libsql::Error::Misuse(format!("dc flags out of range: {flags_raw}"))
            })?;
            let auth_key = match key_blob {
                Some(b) if b.len() == 256 => {
                    let mut k = [0u8; 256];
                    k.copy_from_slice(&b);
                    Some(k)
                }
                Some(b) => {
                    return Err(libsql::Error::Misuse(format!(
                        "auth_key blob must be 256 bytes, got {}",
                        b.len()
                    )));
                }
                None => None,
            };
            dcs.push(DcEntry {
                dc_id,
                addr,
                auth_key,
                first_salt,
                time_offset,
                flags: DcFlags(flags),
            });
        }

        let mut us_row = conn
            .query(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                (),
            )
            .await?;
        let updates_state = if let Some(r) = us_row.next().await? {
            UpdatesStateSnap {
                pts: r.get(0)?,
                qts: r.get(1)?,
                date: r.get(2)?,
                seq: r.get(3)?,
                channels: vec![],
            }
        } else {
            UpdatesStateSnap::default()
        };

        let mut ch_rows = conn
            .query("SELECT channel_id, pts FROM channel_pts", ())
            .await?;
        let mut channels = Vec::new();
        while let Some(r) = ch_rows.next().await? {
            channels.push((r.get::<i64>(0)?, r.get::<i32>(1)?));
        }

        let mut peer_rows = conn
            .query("SELECT id, access_hash, is_channel FROM peers", ())
            .await?;
        let mut peers = Vec::new();
        while let Some(r) = peer_rows.next().await? {
            peers.push(CachedPeer {
                id: r.get(0)?,
                access_hash: r.get(1)?,
                is_channel: r.get::<i32>(2)? != 0,
                is_chat: false,
            });
        }

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers: Vec::new(),
        })
    }

    /// Replaces the stored session with `s` inside a `BEGIN IMMEDIATE` transaction.
    ///
    /// If any statement or the COMMIT fails, the transaction is rolled back
    /// (best-effort) before the original error is returned. Without the rollback the
    /// connection would stay inside an open transaction, and every later `BEGIN` on
    /// it would fail.
    async fn write_session_async(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute_batch("BEGIN IMMEDIATE").await?;
        let outcome = Self::write_session_and_commit(conn, s).await;
        if outcome.is_err() {
            // Ignore a rollback failure: the caller needs the error that caused it.
            let _ = conn.execute_batch("ROLLBACK").await;
        }
        outcome
    }

    /// Writes all rows of `s` and commits. It must run inside a transaction opened by
    /// `write_session_async`.
    ///
    /// `dcs`, `channel_pts` and `peers` are fully replaced. In `update_state`, pts and
    /// qts only ever increase (`MAX` with the stored value), while date and seq are
    /// overwritten.
    async fn write_session_and_commit(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            libsql::params![s.home_dc_id],
        )
        .await?;

        conn.execute("DELETE FROM dcs", ()).await?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                libsql::params![
                    d.dc_id,
                    i64::from(d.flags.0),
                    d.addr.clone(),
                    d.auth_key.map(|k| k.to_vec()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .await?;
        }

        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
             ON CONFLICT(id) DO UPDATE SET
               pts  = MAX(excluded.pts, update_state.pts),
               qts  = MAX(excluded.qts, update_state.qts),
               date = excluded.date,
               seq  = excluded.seq",
            libsql::params![us.pts, us.qts, us.date, us.seq],
        )
        .await?;

        conn.execute("DELETE FROM channel_pts", ()).await?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1,?2)",
                libsql::params![cid, cpts],
            )
            .await?;
        }

        conn.execute("DELETE FROM peers", ()).await?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1,?2,?3)",
                libsql::params![p.id, p.access_hash, p.is_channel as i32],
            )
            .await?;
        }

        conn.execute_batch("COMMIT").await?;
        Ok(())
    }
}
1929
#[cfg(feature = "libsql-session")]
impl SessionBackend for LibSqlBackend {
    /// Persists `session` via `LibSqlBackend::write_session_async`.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        // `block` has no `'static` bound, so the future can borrow `self` and
        // `session` directly. No clone of the session or the Arc is needed.
        Self::block(async {
            let conn = self.conn.lock().await;
            Self::write_session_async(&conn, session).await
        })
    }

    /// Loads the stored session, or `Ok(None)` when the `meta` table has no rows.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        Self::block(async {
            // Hold one lock across the emptiness check and the read, so a concurrent
            // `delete` or `save` cannot run between them.
            let conn = self.conn.lock().await;
            let count: i64 = {
                let mut rows = conn.query("SELECT COUNT(*) FROM meta", ()).await?;
                match rows.next().await? {
                    Some(row) => row.get(0)?,
                    None => 0,
                }
            };
            if count == 0 {
                return Ok(None);
            }
            Self::read_session_async(&conn).await.map(Some)
        })
    }

    /// Clears every session table inside one `IMMEDIATE` transaction.
    fn delete(&self) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            let result = conn
                .execute_batch(
                    "BEGIN IMMEDIATE;
                     DELETE FROM meta;
                     DELETE FROM dcs;
                     DELETE FROM update_state;
                     DELETE FROM channel_pts;
                     DELETE FROM peers;
                     COMMIT;",
                )
                .await;
            if result.is_err() {
                // A failure part-way through the batch can leave the transaction open.
                // Roll it back so later writes can BEGIN their own. Ignore a rollback
                // error (e.g. BEGIN itself failed and nothing is open) and report the
                // original one.
                let _ = conn.execute_batch("ROLLBACK").await;
            }
            result
        })
    }

    /// Human-readable label chosen by the constructor.
    fn name(&self) -> &str {
        &self.label
    }

    /// Inserts or updates the `dcs` row keyed by `(dc_id, flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let key = entry.auth_key.map(|k| k.to_vec());
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1,?6,?2,?3,?4,?5)
                 ON CONFLICT(dc_id, flags) DO UPDATE SET
                   addr=excluded.addr, auth_key=excluded.auth_key,
                   first_salt=excluded.first_salt, time_offset=excluded.time_offset",
                libsql::params![
                    entry.dc_id,
                    entry.addr.clone(),
                    key,
                    entry.first_salt,
                    entry.time_offset,
                    i64::from(entry.flags.0),
                ],
            )
            .await
            .map(|_| ())
        })
    }

    /// Upserts the `home_dc_id` key in `meta`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO meta (key, value) VALUES ('home_dc_id',?1)
                 ON CONFLICT(key) DO UPDATE SET value=excluded.value",
                libsql::params![dc_id],
            )
            .await
            .map(|_| ())
        })
    }

    /// Applies one incremental update-state change.
    ///
    /// - `All` overwrites the singleton `update_state` row and replaces every
    ///   `channel_pts` row, atomically, rolling back on failure.
    /// - `Primary` updates pts/date/seq, `Secondary` updates qts, and `Channel`
    ///   upserts a single channel's pts.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            match update {
                UpdateStateChange::All(snap) => {
                    conn.execute_batch("BEGIN IMMEDIATE").await?;
                    let replaced = async {
                        conn.execute(
                            "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,?2,?3,?4)
                             ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,qts=excluded.qts,
                               date=excluded.date,seq=excluded.seq",
                            libsql::params![snap.pts, snap.qts, snap.date, snap.seq],
                        )
                        .await?;
                        conn.execute("DELETE FROM channel_pts", ()).await?;
                        for &(cid, cpts) in &snap.channels {
                            conn.execute(
                                "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)",
                                libsql::params![cid, cpts],
                            )
                            .await?;
                        }
                        conn.execute_batch("COMMIT").await?;
                        Ok::<(), libsql::Error>(())
                    }
                    .await;
                    if replaced.is_err() {
                        // Do not leave the connection inside an open transaction.
                        let _ = conn.execute_batch("ROLLBACK").await;
                    }
                    replaced
                }
                UpdateStateChange::Primary { pts, date, seq } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,0,?2,?3)
                         ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,date=excluded.date,
                           seq=excluded.seq",
                        libsql::params![pts, date, seq],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Secondary { qts } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,0,?1,0,0)
                         ON CONFLICT(id) DO UPDATE SET qts=excluded.qts",
                        libsql::params![qts],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Channel { id, pts } => conn
                    .execute(
                        "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)
                         ON CONFLICT(channel_id) DO UPDATE SET pts=excluded.pts",
                        libsql::params![id, pts],
                    )
                    .await
                    .map(|_| ()),
            }
        })
    }

    /// Upserts a peer's access hash and channel flag, keyed by `id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let (id, hash, is_ch) = (peer.id, peer.access_hash, peer.is_channel as i32);
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO peers (id,access_hash,is_channel) VALUES (?1,?2,?3)
                 ON CONFLICT(id) DO UPDATE SET
                   access_hash=excluded.access_hash,
                   is_channel=excluded.is_channel",
                libsql::params![id, hash, is_ch],
            )
            .await
            .map(|_| ())
        })
    }
}
2089}