1use std::collections::HashMap;
30use std::io::{self, ErrorKind};
31use std::path::Path;
32
/// Bit set of properties attached to one DC endpoint (see the associated constants
/// such as [`DcFlags::IPV6`]).
///
/// The raw `u8` is stored verbatim: it is the trailing `flags` byte of each DC record in
/// the binary session format and the `flags` column of the SQLite `dcs` table.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);
39
40impl DcFlags {
41 pub const NONE: DcFlags = DcFlags(0);
42 pub const IPV6: DcFlags = DcFlags(1 << 0);
43 pub const MEDIA_ONLY: DcFlags = DcFlags(1 << 1);
44 pub const TCPO_ONLY: DcFlags = DcFlags(1 << 2);
45 pub const CDN: DcFlags = DcFlags(1 << 3);
46 pub const STATIC: DcFlags = DcFlags(1 << 4);
47
48 pub fn contains(self, other: DcFlags) -> bool {
49 self.0 & other.0 == other.0
50 }
51
52 pub fn set(&mut self, flag: DcFlags) {
53 self.0 |= flag.0;
54 }
55}
56
57impl std::ops::BitOr for DcFlags {
58 type Output = DcFlags;
59 fn bitor(self, rhs: DcFlags) -> DcFlags {
60 DcFlags(self.0 | rhs.0)
61 }
62}
63
64#[derive(Clone, Debug)]
66#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
67pub struct DcEntry {
68 pub dc_id: i32,
69 pub addr: String,
70 pub auth_key: Option<[u8; 256]>,
71 pub first_salt: i64,
72 pub time_offset: i32,
73 pub flags: DcFlags,
75}
76
77impl DcEntry {
78 #[inline]
80 pub fn is_ipv6(&self) -> bool {
81 self.flags.contains(DcFlags::IPV6)
82 }
83
84 pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
91 self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
92 io::Error::new(
93 io::ErrorKind::InvalidData,
94 format!("invalid DC address: {:?}", self.addr),
95 )
96 })
97 }
98
99 pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
112 let addr = if ip.contains(':') {
114 format!("[{ip}]:{port}")
115 } else {
116 format!("{ip}:{port}")
117 };
118 Self {
119 dc_id,
120 addr,
121 auth_key: None,
122 first_salt: 0,
123 time_offset: 0,
124 flags,
125 }
126 }
127}
128
/// Snapshot of the update-sequence state that is persisted with the session.
///
/// All-zero (`Default`) means "no state recorded yet"; `is_initialised` checks `pts > 0`.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Primary `pts` counter (NOTE(review): presumably the account-wide message box pts).
    pub pts: i32,
    /// Secondary `qts` counter (NOTE(review): presumably the secret-chat/bot queue).
    pub qts: i32,
    /// `date` of the last applied state (NOTE(review): presumably a Unix timestamp).
    pub date: i32,
    /// `seq` counter of the last applied state.
    pub seq: i32,
    /// `(channel_id, pts)` pairs; helpers read and update the first entry per channel id.
    pub channels: Vec<(i64, i32)>,
}
145
146impl UpdatesStateSnap {
147 #[inline]
149 pub fn is_initialised(&self) -> bool {
150 self.pts > 0
151 }
152
153 pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
155 if let Some(entry) = self.channels.iter_mut().find(|c| c.0 == channel_id) {
156 entry.1 = pts;
157 } else {
158 self.channels.push((channel_id, pts));
159 }
160 }
161
162 pub fn channel_pts(&self, channel_id: i64) -> i32 {
164 self.channels
165 .iter()
166 .find(|c| c.0 == channel_id)
167 .map(|c| c.1)
168 .unwrap_or(0)
169 }
170}
171
/// A cached peer id together with its access hash.
///
/// `is_channel` / `is_chat` select the peer kind; when neither is set the peer is treated
/// as the remaining kind (NOTE(review): presumably a user). If both are set, the binary
/// session format stores the peer as a chat (`is_chat` wins).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Peer identifier; the backends use it as the cache key.
    pub id: i64,
    /// Access hash stored for this peer.
    pub access_hash: i64,
    /// Peer is a channel.
    pub is_channel: bool,
    /// Peer is a (basic) chat.
    pub is_chat: bool,
}
188
/// Reference to a user known only through a message in another peer.
///
/// NOTE(review): presumably used to build a "from message" input peer for users whose
/// access hash is unknown — confirm against the caller.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedMinPeer {
    /// The user being referenced.
    pub user_id: i64,
    /// The peer in which the referencing message was seen.
    pub peer_id: i64,
    /// Id of that message inside `peer_id`.
    pub msg_id: i32,
}
202
/// Everything a session backend stores: home DC, per-DC keys/addresses, update state and
/// the peer caches.
///
/// `Default` is the empty session that backends start from when nothing has been stored.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PersistedSession {
    /// DC the account is homed on (`0` when unknown).
    pub home_dc_id: i32,
    /// Known DC endpoints; several entries may share a `dc_id` (see `dc_for`).
    pub dcs: Vec<DcEntry>,
    /// Update-sequence state to resume from.
    pub updates_state: UpdatesStateSnap,
    /// Cached peers with their access hashes.
    pub peers: Vec<CachedPeer>,
    /// Cached "min" peer references (binary format v4 and later).
    pub min_peers: Vec<CachedMinPeer>,
}
218
219impl PersistedSession {
220 pub fn to_bytes(&self) -> Vec<u8> {
222 let mut b = Vec::with_capacity(512);
223
224 b.push(0x05u8); b.extend_from_slice(&self.home_dc_id.to_le_bytes());
227
228 b.push(self.dcs.len() as u8);
229 for d in &self.dcs {
230 b.extend_from_slice(&d.dc_id.to_le_bytes());
231 match &d.auth_key {
232 Some(k) => {
233 b.push(1);
234 b.extend_from_slice(k);
235 }
236 None => {
237 b.push(0);
238 }
239 }
240 b.extend_from_slice(&d.first_salt.to_le_bytes());
241 b.extend_from_slice(&d.time_offset.to_le_bytes());
242 let ab = d.addr.as_bytes();
243 b.push(ab.len() as u8);
244 b.extend_from_slice(ab);
245 b.push(d.flags.0);
246 }
247
248 b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
249 b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
250 b.extend_from_slice(&self.updates_state.date.to_le_bytes());
251 b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
252 let ch = &self.updates_state.channels;
253 b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
254 for &(cid, cpts) in ch {
255 b.extend_from_slice(&cid.to_le_bytes());
256 b.extend_from_slice(&cpts.to_le_bytes());
257 }
258
259 b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
261 for p in &self.peers {
262 b.extend_from_slice(&p.id.to_le_bytes());
263 b.extend_from_slice(&p.access_hash.to_le_bytes());
264 let peer_type: u8 = if p.is_chat {
265 2
266 } else if p.is_channel {
267 1
268 } else {
269 0
270 };
271 b.push(peer_type);
272 }
273
274 b.extend_from_slice(&(self.min_peers.len() as u16).to_le_bytes());
275 for m in &self.min_peers {
276 b.extend_from_slice(&m.user_id.to_le_bytes());
277 b.extend_from_slice(&m.peer_id.to_le_bytes());
278 b.extend_from_slice(&m.msg_id.to_le_bytes());
279 }
280
281 b
282 }
283
284 pub fn save(&self, path: &Path) -> io::Result<()> {
289 let tmp = path.with_extension("tmp");
290 std::fs::write(&tmp, self.to_bytes())?;
291 std::fs::rename(&tmp, path)
292 }
293
294 pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
296 if buf.is_empty() {
297 return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
298 }
299
300 let mut p = 0usize;
301
302 macro_rules! r {
303 ($n:expr) => {{
304 if p + $n > buf.len() {
305 return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
306 }
307 let s = &buf[p..p + $n];
308 p += $n;
309 s
310 }};
311 }
312 macro_rules! r_i32 {
313 () => {
314 i32::from_le_bytes(r!(4).try_into().unwrap())
315 };
316 }
317 macro_rules! r_i64 {
318 () => {
319 i64::from_le_bytes(r!(8).try_into().unwrap())
320 };
321 }
322 macro_rules! r_u8 {
323 () => {
324 r!(1)[0]
325 };
326 }
327 macro_rules! r_u16 {
328 () => {
329 u16::from_le_bytes(r!(2).try_into().unwrap())
330 };
331 }
332
333 let first_byte = r_u8!();
334
335 let (home_dc_id, version) = if first_byte == 0x05 {
336 (r_i32!(), 5u8)
337 } else if first_byte == 0x04 {
338 (r_i32!(), 4u8)
339 } else if first_byte == 0x03 {
340 (r_i32!(), 3u8)
341 } else if first_byte == 0x02 {
342 (r_i32!(), 2u8)
343 } else {
344 let rest = r!(3);
345 let mut bytes = [0u8; 4];
346 bytes[0] = first_byte;
347 bytes[1..4].copy_from_slice(rest);
348 (i32::from_le_bytes(bytes), 1u8)
349 };
350
351 let dc_count = r_u8!() as usize;
352 let mut dcs = Vec::with_capacity(dc_count);
353 for _ in 0..dc_count {
354 let dc_id = r_i32!();
355 let has_key = r_u8!();
356 let auth_key = if has_key == 1 {
357 let mut k = [0u8; 256];
358 k.copy_from_slice(r!(256));
359 Some(k)
360 } else {
361 None
362 };
363 let first_salt = r_i64!();
364 let time_offset = r_i32!();
365 let al = r_u8!() as usize;
366 let addr = String::from_utf8_lossy(r!(al)).into_owned();
367 let flags = if version >= 3 {
368 DcFlags(r_u8!())
369 } else {
370 DcFlags::NONE
371 };
372 dcs.push(DcEntry {
373 dc_id,
374 addr,
375 auth_key,
376 first_salt,
377 time_offset,
378 flags,
379 });
380 }
381
382 if version < 2 {
383 return Ok(Self {
384 home_dc_id,
385 dcs,
386 updates_state: UpdatesStateSnap::default(),
387 peers: Vec::new(),
388 min_peers: Vec::new(),
389 });
390 }
391
392 let pts = r_i32!();
393 let qts = r_i32!();
394 let date = r_i32!();
395 let seq = r_i32!();
396 let ch_count = r_u16!() as usize;
397 let mut channels = Vec::with_capacity(ch_count);
398 for _ in 0..ch_count {
399 let cid = r_i64!();
400 let cpts = r_i32!();
401 channels.push((cid, cpts));
402 }
403
404 let peer_count = r_u16!() as usize;
405 let mut peers = Vec::with_capacity(peer_count);
406 for _ in 0..peer_count {
407 let id = r_i64!();
408 let access_hash = r_i64!();
409 let peer_type = r_u8!();
411 let is_channel = peer_type == 1;
412 let is_chat = peer_type == 2;
413 peers.push(CachedPeer {
414 id,
415 access_hash,
416 is_channel,
417 is_chat,
418 });
419 }
420
421 let min_peers = if version >= 4 {
423 let count = r_u16!() as usize;
424 let mut v = Vec::with_capacity(count);
425 for _ in 0..count {
426 let user_id = r_i64!();
427 let peer_id = r_i64!();
428 let msg_id = r_i32!();
429 v.push(CachedMinPeer {
430 user_id,
431 peer_id,
432 msg_id,
433 });
434 }
435 v
436 } else {
437 Vec::new()
438 };
439
440 Ok(Self {
441 home_dc_id,
442 dcs,
443 updates_state: UpdatesStateSnap {
444 pts,
445 qts,
446 date,
447 seq,
448 channels,
449 },
450 peers,
451 min_peers,
452 })
453 }
454
455 pub fn from_string(s: &str) -> io::Result<Self> {
457 use base64::Engine as _;
458 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
459 .decode(s.trim())
460 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
461 Self::from_bytes(&bytes)
462 }
463
464 pub fn load(path: &Path) -> io::Result<Self> {
465 let buf = std::fs::read(path)?;
466 Self::from_bytes(&buf)
467 }
468
469 pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
481 let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
482 candidates.peek()?;
483 let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
485 cands
487 .iter()
488 .copied()
489 .find(|d| d.is_ipv6() == prefer_ipv6)
490 .or_else(|| cands.first().copied())
491 }
492
493 pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
498 self.dcs.iter().filter(move |d| d.dc_id == dc_id)
499 }
500}
501
502impl std::fmt::Display for PersistedSession {
503 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504 use base64::Engine as _;
505 f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
506 }
507}
508
/// Built-in IPv4 `ip:port` endpoints for DCs 1–5, keyed by DC id.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const BUILTIN: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(BUILTIN.len());
    for &(dc_id, addr) in BUILTIN.iter() {
        table.insert(dc_id, String::from(addr));
    }
    table
}
522
523use std::path::PathBuf;
528
529pub trait SessionBackend: Send + Sync {
537 fn save(&self, session: &PersistedSession) -> io::Result<()>;
538 fn load(&self) -> io::Result<Option<PersistedSession>>;
539 fn delete(&self) -> io::Result<()>;
540
541 fn name(&self) -> &str;
543
544 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
558 let mut s = self.load()?.unwrap_or_default();
559 if let Some(existing) = s.dcs.iter_mut().find(|d| d.dc_id == entry.dc_id) {
561 *existing = entry.clone();
562 } else {
563 s.dcs.push(entry.clone());
564 }
565 self.save(&s)
566 }
567
568 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
575 let mut s = self.load()?.unwrap_or_default();
576 s.home_dc_id = dc_id;
577 self.save(&s)
578 }
579
580 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
586 let mut s = self.load()?.unwrap_or_default();
587 update.apply_to(&mut s.updates_state);
588 self.save(&s)
589 }
590
591 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
598 let mut s = self.load()?.unwrap_or_default();
599 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
600 *existing = peer.clone();
601 } else {
602 s.peers.push(peer.clone());
603 }
604 self.save(&s)
605 }
606}
607
/// Incremental change to an [`UpdatesStateSnap`], applied with `apply_to` or by
/// `SessionBackend::apply_update_state`.
#[derive(Debug, Clone)]
pub enum UpdateStateChange {
    /// Replace the whole snapshot, including the per-channel list.
    All(UpdatesStateSnap),
    /// Overwrite `pts`, `date` and `seq`; `qts` and channels are untouched.
    Primary { pts: i32, date: i32, seq: i32 },
    /// Overwrite `qts` only.
    Secondary { qts: i32 },
    /// Set the `pts` of a single channel.
    Channel { id: i64, pts: i32 },
}
632
633impl UpdateStateChange {
634 pub fn apply_to(&self, snap: &mut UpdatesStateSnap) {
636 match self {
637 Self::All(new_snap) => *snap = new_snap.clone(),
638 Self::Primary { pts, date, seq } => {
639 snap.pts = *pts;
640 snap.date = *date;
641 snap.seq = *seq;
642 }
643 Self::Secondary { qts } => {
644 snap.qts = *qts;
645 }
646 Self::Channel { id, pts } => {
647 if let Some(existing) = snap.channels.iter_mut().find(|c| c.0 == *id) {
649 existing.1 = *pts;
650 } else {
651 snap.channels.push((*id, *pts));
652 }
653 }
654 }
655 }
656}
657
/// [`SessionBackend`] that stores the session in a single file using the binary format of
/// `PersistedSession::to_bytes`.
pub struct BinaryFileBackend {
    // Location of the session file; `.tmp` / `.bak` siblings are derived from it.
    path: PathBuf,
}
664
665impl BinaryFileBackend {
666 pub fn new(path: impl Into<PathBuf>) -> Self {
667 Self { path: path.into() }
668 }
669
670 pub fn path(&self) -> &std::path::Path {
671 &self.path
672 }
673}
674
675impl SessionBackend for BinaryFileBackend {
676 fn save(&self, session: &PersistedSession) -> io::Result<()> {
677 session.save(&self.path)
678 }
679
680 fn load(&self) -> io::Result<Option<PersistedSession>> {
681 if !self.path.exists() {
682 return Ok(None);
683 }
684 match PersistedSession::load(&self.path) {
685 Ok(s) => Ok(Some(s)),
686 Err(e) => {
687 let bak = self.path.with_extension("bak");
688 tracing::warn!(
689 "[ferogram] Session file {:?} is corrupt ({e}); \
690 renaming to {:?} and starting fresh",
691 self.path,
692 bak
693 );
694 let _ = std::fs::rename(&self.path, &bak);
695 Ok(None)
696 }
697 }
698 }
699
700 fn delete(&self) -> io::Result<()> {
701 if self.path.exists() {
702 std::fs::remove_file(&self.path)?;
703 }
704 Ok(())
705 }
706
707 fn name(&self) -> &str {
708 "binary-file"
709 }
710
711 }
714
/// [`SessionBackend`] that keeps the session only in process memory behind a mutex;
/// nothing survives the process.
#[derive(Default)]
pub struct InMemoryBackend {
    // `None` until something has been saved (or after `delete`).
    data: std::sync::Mutex<Option<PersistedSession>>,
}
726
727impl InMemoryBackend {
728 pub fn new() -> Self {
729 Self::default()
730 }
731
732 pub fn snapshot(&self) -> Option<PersistedSession> {
734 self.data.lock().unwrap().clone()
735 }
736}
737
738impl SessionBackend for InMemoryBackend {
739 fn save(&self, s: &PersistedSession) -> io::Result<()> {
740 *self.data.lock().unwrap() = Some(s.clone());
741 Ok(())
742 }
743
744 fn load(&self) -> io::Result<Option<PersistedSession>> {
745 Ok(self.data.lock().unwrap().clone())
746 }
747
748 fn delete(&self) -> io::Result<()> {
749 *self.data.lock().unwrap() = None;
750 Ok(())
751 }
752
753 fn name(&self) -> &str {
754 "in-memory"
755 }
756
757 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
760 let mut guard = self.data.lock().unwrap();
761 let s = guard.get_or_insert_with(PersistedSession::default);
762 if let Some(existing) = s.dcs.iter_mut().find(|d| d.dc_id == entry.dc_id) {
763 *existing = entry.clone();
764 } else {
765 s.dcs.push(entry.clone());
766 }
767 Ok(())
768 }
769
770 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
771 let mut guard = self.data.lock().unwrap();
772 let s = guard.get_or_insert_with(PersistedSession::default);
773 s.home_dc_id = dc_id;
774 Ok(())
775 }
776
777 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
778 let mut guard = self.data.lock().unwrap();
779 let s = guard.get_or_insert_with(PersistedSession::default);
780 update.apply_to(&mut s.updates_state);
781 Ok(())
782 }
783
784 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
785 let mut guard = self.data.lock().unwrap();
786 let s = guard.get_or_insert_with(PersistedSession::default);
787 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
788 *existing = peer.clone();
789 } else {
790 s.peers.push(peer.clone());
791 }
792 Ok(())
793 }
794}
795
/// [`SessionBackend`] that keeps the session as a base64 string (the `Display` form of
/// [`PersistedSession`]), e.g. for sessions supplied via configuration.
pub struct StringSessionBackend {
    // Current session string; empty or whitespace-only means "no session".
    data: std::sync::Mutex<String>,
}
802
803impl StringSessionBackend {
804 pub fn new(s: impl Into<String>) -> Self {
805 Self {
806 data: std::sync::Mutex::new(s.into()),
807 }
808 }
809
810 pub fn current(&self) -> String {
811 self.data.lock().unwrap().clone()
812 }
813}
814
815impl SessionBackend for StringSessionBackend {
816 fn save(&self, session: &PersistedSession) -> io::Result<()> {
817 *self.data.lock().unwrap() = session.to_string();
818 Ok(())
819 }
820
821 fn load(&self) -> io::Result<Option<PersistedSession>> {
822 let s = self.data.lock().unwrap().clone();
823 if s.trim().is_empty() {
824 return Ok(None);
825 }
826 PersistedSession::from_string(&s).map(Some)
827 }
828
829 fn delete(&self) -> io::Result<()> {
830 *self.data.lock().unwrap() = String::new();
831 Ok(())
832 }
833
834 fn name(&self) -> &str {
835 "string-session"
836 }
837}
838
// Unit tests for the in-memory backend, update-state changes, DC endpoint helpers and the
// binary round trip.
#[cfg(test)]
mod tests {
    use super::*;

    /// IPv4 DC entry `1.2.3.<id>:443` with no key and no flags.
    fn make_dc(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("1.2.3.{id}:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::NONE,
        }
    }

    /// Peer with neither `is_channel` nor `is_chat` set.
    fn make_peer(id: i64, hash: i64) -> CachedPeer {
        CachedPeer {
            id,
            access_hash: hash,
            is_channel: false,
            is_chat: false,
        }
    }

    // A fresh in-memory backend has nothing stored.
    #[test]
    fn inmemory_load_returns_none_when_empty() {
        let b = InMemoryBackend::new();
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_save_then_load_round_trips() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 3;
        s.dcs.push(make_dc(3));
        b.save(&s).unwrap();

        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.home_dc_id, 3);
        assert_eq!(loaded.dcs.len(), 1);
    }

    #[test]
    fn inmemory_delete_clears_state() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        b.save(&s).unwrap();
        b.delete().unwrap();
        assert!(b.load().unwrap().is_none());
    }

    // update_dc on an empty backend creates the session and inserts the entry.
    #[test]
    fn inmemory_update_dc_inserts_new() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(4)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].dc_id, 4);
    }

    // Same dc_id and same flags: the entry is replaced, not duplicated.
    #[test]
    fn inmemory_update_dc_replaces_existing() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        let mut updated = make_dc(2);
        updated.addr = "9.9.9.9:443".to_string();
        b.update_dc(&updated).unwrap();

        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].addr, "9.9.9.9:443");
    }

    #[test]
    fn inmemory_set_home_dc() {
        let b = InMemoryBackend::new();
        b.set_home_dc(5).unwrap();
        assert_eq!(b.snapshot().unwrap().home_dc_id, 5);
    }

    #[test]
    fn inmemory_cache_peer_inserts() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 0xdeadbeef)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].id, 100);
    }

    // Caching the same peer id twice keeps one entry with the latest hash.
    #[test]
    fn inmemory_cache_peer_updates_existing() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 111)).unwrap();
        b.cache_peer(&make_peer(100, 222)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].access_hash, 222);
    }

    // Primary overwrites pts/date/seq only.
    #[test]
    fn update_state_primary() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Primary {
            pts: 10,
            date: 20,
            seq: 30,
        }
        .apply_to(&mut snap);
        assert_eq!(snap.pts, 10);
        assert_eq!(snap.date, 20);
        assert_eq!(snap.seq, 30);
        assert_eq!(snap.qts, 0); // qts is untouched by Primary
    }

    #[test]
    fn update_state_secondary() {
        let mut snap = UpdatesStateSnap {
            pts: 5,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Secondary { qts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.qts, 99);
        assert_eq!(snap.pts, 5); // pts is untouched by Secondary
    }

    #[test]
    fn update_state_channel_inserts() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Channel { id: 12345, pts: 42 }.apply_to(&mut snap);
        assert_eq!(snap.channels, vec![(12345, 42)]);
    }

    // Updating one channel leaves the order and the other channels intact.
    #[test]
    fn update_state_channel_updates_existing() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![(12345, 10), (67890, 5)],
        };
        UpdateStateChange::Channel { id: 12345, pts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.channels[0], (12345, 99));
        assert_eq!(snap.channels[1], (67890, 5));
    }

    #[test]
    fn apply_update_state_via_backend() {
        let b = InMemoryBackend::new();
        b.apply_update_state(UpdateStateChange::Primary {
            pts: 7,
            date: 8,
            seq: 9,
        })
        .unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.updates_state.pts, 7);
    }

    // Trait-object dispatch reaches InMemoryBackend's update_dc.
    #[test]
    fn default_update_dc_via_trait_object() {
        let b: Box<dyn SessionBackend> = Box::new(InMemoryBackend::new());
        b.update_dc(&make_dc(1)).unwrap();
        b.update_dc(&make_dc(2)).unwrap();
        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.dcs.len(), 2);
    }

    /// IPv6 DC entry `[2001:b28:f23d:f00<id>::a]:443` carrying the `IPV6` flag.
    fn make_dc_v6(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("[2001:b28:f23d:f00{}::a]:443", id),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::IPV6,
        }
    }

    #[test]
    fn dc_entry_from_parts_ipv4() {
        let dc = DcEntry::from_parts(1, "149.154.175.53", 443, DcFlags::NONE);
        assert_eq!(dc.addr, "149.154.175.53:443");
        assert!(!dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    // A host containing ':' is bracketed so the address parses.
    #[test]
    fn dc_entry_from_parts_ipv6() {
        let dc = DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6);
        assert_eq!(dc.addr, "[2001:b28:f23d:f001::a]:443");
        assert!(dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    #[test]
    fn persisted_session_dc_for_prefers_ipv6() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(2)); // IPv4 endpoint
        s.dcs.push(make_dc_v6(2)); // IPv6 endpoint of the same DC
        let v6 = s.dc_for(2, true).unwrap();
        assert!(v6.is_ipv6());

        let v4 = s.dc_for(2, false).unwrap();
        assert!(!v4.is_ipv6());
    }

    #[test]
    fn persisted_session_dc_for_falls_back_when_only_ipv4() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(3)); // only an IPv4 endpoint exists
        let dc = s.dc_for(3, true).unwrap();
        assert!(!dc.is_ipv6());
    }

    #[test]
    fn persisted_session_all_dcs_for_returns_both() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(1));
        s.dcs.push(make_dc_v6(1));
        s.dcs.push(make_dc(2));

        assert_eq!(s.all_dcs_for(1).count(), 2);
        assert_eq!(s.all_dcs_for(2).count(), 1);
        assert_eq!(s.all_dcs_for(5).count(), 0);
    }

    // Relies on update_dc telling entries apart by flags, not by dc_id alone.
    #[test]
    fn inmemory_ipv4_and_ipv6_coexist() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        b.update_dc(&make_dc_v6(2)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.iter().filter(|d| d.dc_id == 2).count(), 2);
    }

    // The per-DC flags byte survives the binary round trip.
    #[test]
    fn binary_roundtrip_ipv4_and_ipv6() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        s.dcs.push(make_dc(2));
        s.dcs.push(make_dc_v6(2));

        let bytes = s.to_bytes();
        let loaded = PersistedSession::from_bytes(&bytes).unwrap();
        assert_eq!(loaded.dcs.len(), 2);
        assert_eq!(loaded.dcs.iter().filter(|d| d.is_ipv6()).count(), 1);
        assert_eq!(loaded.dcs.iter().filter(|d| !d.is_ipv6()).count(), 1);
    }
}
1127
/// [`SessionBackend`] backed by a SQLite database (feature `sqlite-session`), with one
/// table per part of the session so incremental updates touch only the rows they change.
#[cfg(feature = "sqlite-session")]
pub struct SqliteBackend {
    // Single connection shared by all calls; the mutex serialises access to it.
    conn: std::sync::Mutex<rusqlite::Connection>,
    // Value returned by `name()`: the database path, or ":memory:".
    label: String,
}
1157
#[cfg(feature = "sqlite-session")]
impl SqliteBackend {
    // Schema notes:
    // * `dcs` rows are keyed by `(dc_id, flags)`, so e.g. the IPv4 and IPv6 endpoints of one
    //   DC are separate rows.
    // * `peers.is_channel` keeps its historical name but holds a peer-type code:
    //   0 = neither, 1 = channel, 2 = chat (the codes the binary format uses). Older
    //   versions only wrote 0/1, which decodes the same way, so no migration is needed.
    const SCHEMA: &'static str = "
        PRAGMA journal_mode = WAL;
        PRAGMA synchronous = NORMAL;

        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );

        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );

        CREATE TABLE IF NOT EXISTS peers (
            id          INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel  INTEGER NOT NULL DEFAULT 0
        );
    ";

    /// Opens (creating if needed) the database at `path` and ensures the schema exists.
    /// The path's display form becomes `name()`.
    pub fn open(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let conn = rusqlite::Connection::open(&path).map_err(io::Error::other)?;
        conn.execute_batch(Self::SCHEMA).map_err(io::Error::other)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label,
        })
    }

    /// Opens a private in-memory database with the schema applied (`name()` is ":memory:").
    pub fn in_memory() -> io::Result<Self> {
        let conn = rusqlite::Connection::open_in_memory().map_err(io::Error::other)?;
        conn.execute_batch(Self::SCHEMA).map_err(io::Error::other)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label: ":memory:".into(),
        })
    }

    fn map_err(e: rusqlite::Error) -> io::Error {
        io::Error::other(e)
    }

    /// Encodes a peer's kind for the `peers.is_channel` column (`is_chat` wins over
    /// `is_channel`, as in the binary format).
    fn peer_kind_code(p: &CachedPeer) -> i32 {
        if p.is_chat {
            2
        } else if p.is_channel {
            1
        } else {
            0
        }
    }

    /// Runs `body` inside a `BEGIN IMMEDIATE` transaction.
    ///
    /// Commits on success. Rolls back on any failure, including a failed COMMIT, so the
    /// shared connection is never left inside an open transaction. An open transaction
    /// would make every later `BEGIN` fail and keep later single-statement writes from
    /// ever being committed.
    fn in_transaction<T>(
        conn: &rusqlite::Connection,
        body: impl FnOnce() -> io::Result<T>,
    ) -> io::Result<T> {
        conn.execute_batch("BEGIN IMMEDIATE").map_err(Self::map_err)?;
        let result = body().and_then(|value| {
            conn.execute_batch("COMMIT").map_err(Self::map_err)?;
            Ok(value)
        });
        if result.is_err() {
            // Outcome ignored on purpose: if the transaction already ended there is
            // nothing left to roll back, and the original error is what matters.
            let _ = conn.execute_batch("ROLLBACK");
        }
        result
    }

    /// Reads every table into a [`PersistedSession`].
    ///
    /// A missing `home_dc_id` / `update_state` row yields the default value. Any other
    /// database error is returned instead of being silently dropped.
    fn read_session(conn: &rusqlite::Connection) -> io::Result<PersistedSession> {
        use rusqlite::OptionalExtension as _;

        let home_dc_id: i32 = conn
            .query_row("SELECT value FROM meta WHERE key = 'home_dc_id'", [], |r| {
                r.get(0)
            })
            .optional()
            .map_err(Self::map_err)?
            .unwrap_or(0);

        let mut dc_stmt = conn
            .prepare("SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs")
            .map_err(Self::map_err)?;
        let dcs = dc_stmt
            .query_map([], |row| {
                let key_blob: Option<Vec<u8>> = row.get(3)?;
                Ok(DcEntry {
                    dc_id: row.get(0)?,
                    flags: DcFlags(row.get(1)?),
                    addr: row.get(2)?,
                    // A blob of the wrong length cannot be a valid key; treat it as absent.
                    auth_key: key_blob.and_then(|b| <[u8; 256]>::try_from(b.as_slice()).ok()),
                    first_salt: row.get(4)?,
                    time_offset: row.get(5)?,
                })
            })
            .map_err(Self::map_err)?
            .collect::<rusqlite::Result<Vec<_>>>()
            .map_err(Self::map_err)?;

        let base_state = conn
            .query_row(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                [],
                |r| {
                    Ok(UpdatesStateSnap {
                        pts: r.get(0)?,
                        qts: r.get(1)?,
                        date: r.get(2)?,
                        seq: r.get(3)?,
                        channels: vec![],
                    })
                },
            )
            .optional()
            .map_err(Self::map_err)?
            .unwrap_or_default();

        let mut ch_stmt = conn
            .prepare("SELECT channel_id, pts FROM channel_pts")
            .map_err(Self::map_err)?;
        let channels = ch_stmt
            .query_map([], |r| Ok((r.get::<_, i64>(0)?, r.get::<_, i32>(1)?)))
            .map_err(Self::map_err)?
            .collect::<rusqlite::Result<Vec<_>>>()
            .map_err(Self::map_err)?;

        let mut peer_stmt = conn
            .prepare("SELECT id, access_hash, is_channel FROM peers")
            .map_err(Self::map_err)?;
        let peers = peer_stmt
            .query_map([], |r| {
                let kind: i64 = r.get(2)?;
                Ok(CachedPeer {
                    id: r.get(0)?,
                    access_hash: r.get(1)?,
                    is_channel: kind == 1,
                    is_chat: kind == 2,
                })
            })
            .map_err(Self::map_err)?
            .collect::<rusqlite::Result<Vec<_>>>()
            .map_err(Self::map_err)?;

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..base_state
            },
            peers,
            // Min-peers are not stored by this backend.
            min_peers: Vec::new(),
        })
    }

    /// Replaces the stored session with `s` in a single transaction.
    ///
    /// DCs, channel pts and peers are rewritten from scratch. `pts` / `qts` in
    /// `update_state` never move backwards: the stored value is kept when it is larger.
    fn write_session(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        Self::in_transaction(conn, || {
            conn.execute(
                "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
                 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
                rusqlite::params![s.home_dc_id],
            )
            .map_err(Self::map_err)?;

            conn.execute("DELETE FROM dcs", []).map_err(Self::map_err)?;
            let mut insert_dc = conn
                .prepare(
                    "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                     VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                )
                .map_err(Self::map_err)?;
            for d in &s.dcs {
                insert_dc
                    .execute(rusqlite::params![
                        d.dc_id,
                        d.flags.0,
                        d.addr,
                        d.auth_key.as_ref().map(|k| &k[..]),
                        d.first_salt,
                        d.time_offset,
                    ])
                    .map_err(Self::map_err)?;
            }

            let us = &s.updates_state;
            conn.execute(
                "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1, ?1, ?2, ?3, ?4)
                 ON CONFLICT(id) DO UPDATE SET
                     pts  = MAX(excluded.pts, update_state.pts),
                     qts  = MAX(excluded.qts, update_state.qts),
                     date = excluded.date,
                     seq  = excluded.seq",
                rusqlite::params![us.pts, us.qts, us.date, us.seq],
            )
            .map_err(Self::map_err)?;

            conn.execute("DELETE FROM channel_pts", [])
                .map_err(Self::map_err)?;
            let mut insert_channel = conn
                .prepare("INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)")
                .map_err(Self::map_err)?;
            for &(cid, cpts) in &us.channels {
                insert_channel
                    .execute(rusqlite::params![cid, cpts])
                    .map_err(Self::map_err)?;
            }

            conn.execute("DELETE FROM peers", [])
                .map_err(Self::map_err)?;
            let mut insert_peer = conn
                .prepare("INSERT INTO peers (id, access_hash, is_channel) VALUES (?1, ?2, ?3)")
                .map_err(Self::map_err)?;
            for p in &s.peers {
                insert_peer
                    .execute(rusqlite::params![p.id, p.access_hash, Self::peer_kind_code(p)])
                    .map_err(Self::map_err)?;
            }

            Ok(())
        })
    }
}

#[cfg(feature = "sqlite-session")]
impl SessionBackend for SqliteBackend {
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        Self::write_session(&conn, session)
    }

    /// Returns `None` only when every table is empty.
    ///
    /// The incremental writers (e.g. `update_dc`) can populate a table before
    /// `home_dc_id` is ever recorded in `meta`, and that data must not be reported as
    /// "no session".
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = self.conn.lock().unwrap();
        let has_data: bool = conn
            .query_row(
                "SELECT EXISTS(SELECT 1 FROM meta)
                     OR EXISTS(SELECT 1 FROM dcs)
                     OR EXISTS(SELECT 1 FROM update_state)
                     OR EXISTS(SELECT 1 FROM channel_pts)
                     OR EXISTS(SELECT 1 FROM peers)",
                [],
                |r| r.get(0),
            )
            .map_err(Self::map_err)?;
        if !has_data {
            return Ok(None);
        }
        Self::read_session(&conn).map(Some)
    }

    /// Empties every table in one transaction.
    fn delete(&self) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        Self::in_transaction(&conn, || {
            conn.execute_batch(
                "DELETE FROM meta;
                 DELETE FROM dcs;
                 DELETE FROM update_state;
                 DELETE FROM channel_pts;
                 DELETE FROM peers;",
            )
            .map_err(Self::map_err)
        })
    }

    fn name(&self) -> &str {
        &self.label
    }

    /// Upserts one DC row keyed by `(dc_id, flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
             ON CONFLICT(dc_id, flags) DO UPDATE SET
                 addr        = excluded.addr,
                 auth_key    = excluded.auth_key,
                 first_salt  = excluded.first_salt,
                 time_offset = excluded.time_offset",
            rusqlite::params![
                entry.dc_id,
                entry.flags.0,
                entry.addr,
                entry.auth_key.as_ref().map(|k| &k[..]),
                entry.first_salt,
                entry.time_offset,
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![dc_id],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Applies `update` directly to the affected rows.
    ///
    /// `All` overwrites the state (no `MAX` guard) and rewrites `channel_pts` atomically.
    /// `Primary` / `Secondary` create the state row with zeroes for the fields they do
    /// not carry if it does not exist yet.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        match update {
            UpdateStateChange::All(snap) => Self::in_transaction(&conn, || {
                conn.execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
                     ON CONFLICT(id) DO UPDATE SET
                         pts=excluded.pts, qts=excluded.qts,
                         date=excluded.date, seq=excluded.seq",
                    rusqlite::params![snap.pts, snap.qts, snap.date, snap.seq],
                )
                .map_err(Self::map_err)?;
                conn.execute("DELETE FROM channel_pts", [])
                    .map_err(Self::map_err)?;
                let mut insert_channel = conn
                    .prepare("INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)")
                    .map_err(Self::map_err)?;
                for &(cid, cpts) in &snap.channels {
                    insert_channel
                        .execute(rusqlite::params![cid, cpts])
                        .map_err(Self::map_err)?;
                }
                Ok(())
            }),
            UpdateStateChange::Primary { pts, date, seq } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,0,?2,?3)
                     ON CONFLICT(id) DO UPDATE SET pts=excluded.pts, date=excluded.date,
                         seq=excluded.seq",
                    rusqlite::params![pts, date, seq],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Secondary { qts } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,0,?1,0,0)
                     ON CONFLICT(id) DO UPDATE SET qts = excluded.qts",
                    rusqlite::params![qts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Channel { id, pts } => conn
                .execute(
                    "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)
                     ON CONFLICT(channel_id) DO UPDATE SET pts = excluded.pts",
                    rusqlite::params![id, pts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
        }
    }

    /// Upserts one peer row, storing its kind code (so chats keep `is_chat`).
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1, ?2, ?3)
             ON CONFLICT(id) DO UPDATE SET
                 access_hash = excluded.access_hash,
                 is_channel  = excluded.is_channel",
            rusqlite::params![peer.id, peer.access_hash, Self::peer_kind_code(peer)],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }
}
1536
/// Session backend that persists state through a libsql database.
///
/// The database can be a local file, in-memory, remote, or an embedded replica.
/// Every operation locks the connection and holds the guard across `.await`
/// points while the query runs. That is why the connection sits behind an async
/// (`tokio`) mutex rather than a `std` one.
#[cfg(feature = "libsql-session")]
pub struct LibSqlBackend {
    /// Shared connection. Constructors wrap it as `Arc<tokio::sync::Mutex<_>>`
    /// and all methods call `.lock().await` on it.
    conn: std::sync::Arc<tokio::sync::Mutex<libsql::Connection>>,
    /// Human-readable identifier (path, URL, or `:memory:`) returned by `name()`.
    label: String,
}
1561
#[cfg(feature = "libsql-session")]
impl LibSqlBackend {
    /// Idempotent DDL run on every freshly opened connection.
    ///
    /// - `meta`: key/value pairs, including `home_dc_id`.
    /// - `dcs`: one row per `(dc_id, flags)`.
    /// - `update_state`: a single row, enforced by `CHECK (id = 1)`.
    /// - `channel_pts`: per-channel pts.
    /// - `peers`: cached access hashes.
    const SCHEMA: &'static str = "
    CREATE TABLE IF NOT EXISTS meta (
        key TEXT PRIMARY KEY,
        value INTEGER NOT NULL DEFAULT 0
    );
    CREATE TABLE IF NOT EXISTS dcs (
        dc_id INTEGER NOT NULL,
        flags INTEGER NOT NULL DEFAULT 0,
        addr TEXT NOT NULL,
        auth_key BLOB,
        first_salt INTEGER NOT NULL DEFAULT 0,
        time_offset INTEGER NOT NULL DEFAULT 0,
        PRIMARY KEY (dc_id, flags)
    );
    CREATE TABLE IF NOT EXISTS update_state (
        id INTEGER PRIMARY KEY CHECK (id = 1),
        pts INTEGER NOT NULL DEFAULT 0,
        qts INTEGER NOT NULL DEFAULT 0,
        date INTEGER NOT NULL DEFAULT 0,
        seq INTEGER NOT NULL DEFAULT 0
    );
    CREATE TABLE IF NOT EXISTS channel_pts (
        channel_id INTEGER PRIMARY KEY,
        pts INTEGER NOT NULL
    );
    CREATE TABLE IF NOT EXISTS peers (
        id INTEGER PRIMARY KEY,
        access_hash INTEGER NOT NULL,
        is_channel INTEGER NOT NULL DEFAULT 0
    );
    ";

    /// Drives a libsql future to completion from synchronous code.
    ///
    /// # Errors
    /// - Returns an error if no Tokio runtime context is active on this thread.
    /// - Returns the future's own error, wrapped via [`io::Error::other`].
    ///
    /// # Panics
    /// Panics if called from inside an asynchronous execution context, because
    /// [`tokio::runtime::Handle::block_on`] cannot be nested there. Call the
    /// synchronous backend methods from `spawn_blocking` or from a thread that
    /// has entered the runtime.
    fn block<F, T>(fut: F) -> io::Result<T>
    where
        F: std::future::Future<Output = Result<T, libsql::Error>>,
    {
        let handle = tokio::runtime::Handle::try_current().map_err(io::Error::other)?;
        handle.block_on(fut).map_err(io::Error::other)
    }

    /// Applies [`Self::SCHEMA`] to `conn`.
    async fn apply_schema(conn: &libsql::Connection) -> Result<(), libsql::Error> {
        // Only success or failure matters; discard whatever the batch yields.
        conn.execute_batch(Self::SCHEMA).await.map(|_| ())
    }

    /// Connects to `db`, applies the schema, and wraps the connection.
    ///
    /// This is the constructor tail shared by every `open_*` function.
    fn from_database(db: libsql::Database, label: String) -> io::Result<Self> {
        // `connect` is synchronous. Wrap its error once (the previous code
        // wrapped it in `io::Error` twice).
        let conn = db.connect().map_err(io::Error::other)?;
        Self::block(Self::apply_schema(&conn))?;
        Ok(Self {
            conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
            label,
        })
    }

    /// Opens a local database file at `path` and applies the schema.
    pub fn open_local(path: impl Into<std::path::PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let db = Self::block(async { libsql::Builder::new_local(path).build().await })?;
        Self::from_database(db, label)
    }

    /// Opens a private in-memory database. Nothing is persisted after drop.
    pub fn in_memory() -> io::Result<Self> {
        let db = Self::block(async { libsql::Builder::new_local(":memory:").build().await })?;
        Self::from_database(db, ":memory:".into())
    }

    /// Connects to a remote libsql server at `url` using `auth_token`.
    pub fn open_remote(url: impl Into<String>, auth_token: impl Into<String>) -> io::Result<Self> {
        let url: String = url.into();
        let label = url.clone();
        let db = Self::block(async {
            libsql::Builder::new_remote(url, auth_token.into())
                .build()
                .await
        })?;
        Self::from_database(db, label)
    }

    /// Opens a local embedded replica at `path` that replicates from `url`.
    pub fn open_replica(
        path: impl Into<std::path::PathBuf>,
        url: impl Into<String>,
        auth_token: impl Into<String>,
    ) -> io::Result<Self> {
        let path = path.into();
        // Convert `url` exactly once. It is needed both for the label and for the builder.
        let url: String = url.into();
        let label = format!("{} (replica of {})", path.display(), url);
        let db = Self::block(async {
            libsql::Builder::new_remote_replica(path, url, auth_token.into())
                .build()
                .await
        })?;
        Self::from_database(db, label)
    }

    /// Decodes one `dcs` row.
    ///
    /// Expected column order: `dc_id, flags, addr, auth_key, first_salt, time_offset`.
    ///
    /// # Errors
    /// Returns `Misuse` if `flags` does not fit in a `u8`, or if a stored auth
    /// key is not exactly 256 bytes.
    fn dc_entry_from_row(row: &libsql::Row) -> Result<DcEntry, libsql::Error> {
        let flags_raw: i64 = row.get(1)?;
        // Reject values that cannot have been written from a `DcFlags(u8)`
        // instead of silently truncating them.
        let flags = u8::try_from(flags_raw).map_err(|_| {
            libsql::Error::Misuse(format!("dc flags out of range for u8: {flags_raw}"))
        })?;
        let key_blob: Option<Vec<u8>> = row.get(3)?;
        let auth_key = match key_blob {
            None => None,
            Some(blob) => Some(<[u8; 256]>::try_from(blob.as_slice()).map_err(|_| {
                libsql::Error::Misuse(format!(
                    "auth_key blob must be 256 bytes, got {}",
                    blob.len()
                ))
            })?),
        };
        Ok(DcEntry {
            dc_id: row.get(0)?,
            addr: row.get(2)?,
            auth_key,
            first_salt: row.get(4)?,
            time_offset: row.get(5)?,
            flags: DcFlags(flags),
        })
    }

    /// Reads the full persisted session from `conn`.
    ///
    /// - A missing `home_dc_id` or `update_state` row falls back to 0 or defaults.
    /// - `min_peers` is always empty, because nothing here stores it.
    async fn read_session_async(
        conn: &libsql::Connection,
    ) -> Result<PersistedSession, libsql::Error> {
        let home_dc_id: i32 = conn
            .query("SELECT value FROM meta WHERE key = 'home_dc_id'", ())
            .await?
            .next()
            .await?
            .map(|r| r.get::<i32>(0))
            .transpose()?
            .unwrap_or(0);

        let mut rows = conn
            .query(
                "SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs",
                (),
            )
            .await?;
        let mut dcs = Vec::new();
        while let Some(row) = rows.next().await? {
            dcs.push(Self::dc_entry_from_row(&row)?);
        }

        let mut us_row = conn
            .query(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                (),
            )
            .await?;
        let updates_state = if let Some(r) = us_row.next().await? {
            UpdatesStateSnap {
                pts: r.get(0)?,
                qts: r.get(1)?,
                date: r.get(2)?,
                seq: r.get(3)?,
                channels: vec![],
            }
        } else {
            UpdatesStateSnap::default()
        };

        let mut ch_rows = conn
            .query("SELECT channel_id, pts FROM channel_pts", ())
            .await?;
        let mut channels = Vec::new();
        while let Some(r) = ch_rows.next().await? {
            channels.push((r.get::<i64>(0)?, r.get::<i32>(1)?));
        }

        let mut peer_rows = conn
            .query("SELECT id, access_hash, is_channel FROM peers", ())
            .await?;
        let mut peers = Vec::new();
        while let Some(r) = peer_rows.next().await? {
            peers.push(CachedPeer {
                id: r.get(0)?,
                access_hash: r.get(1)?,
                is_channel: r.get::<i32>(2)? != 0,
                // NOTE(review): the `peers` table has no is_chat column, so this
                // flag is always false after a reload.
                is_chat: false,
            });
        }

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers: Vec::new(),
        })
    }

    /// Writes the whole session `s` to `conn` atomically.
    ///
    /// All writes run inside `BEGIN IMMEDIATE … COMMIT`. On any failure, including
    /// a failed COMMIT, a best-effort `ROLLBACK` is issued. The connection is
    /// therefore never left inside an open transaction, which would make every
    /// later `BEGIN` fail.
    async fn write_session_async(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute_batch("BEGIN IMMEDIATE").await?;
        let outcome = match Self::write_session_rows(conn, s).await {
            Ok(()) => conn.execute_batch("COMMIT").await.map(|_| ()),
            Err(e) => Err(e),
        };
        if outcome.is_err() {
            // Ignore the ROLLBACK error: if SQLite already ended the
            // transaction it reports "no transaction is active", which is harmless.
            let _ = conn.execute_batch("ROLLBACK").await;
        }
        outcome
    }

    /// Issues every statement of a session save.
    ///
    /// The caller must already have opened a transaction.
    async fn write_session_rows(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
            ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            libsql::params![s.home_dc_id],
        )
        .await?;

        conn.execute("DELETE FROM dcs", ()).await?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                libsql::params![
                    d.dc_id,
                    d.flags.0 as i64,
                    d.addr.clone(),
                    d.auth_key.map(|k| k.to_vec()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .await?;
        }

        // Save never moves pts/qts backwards (it keeps the MAX of stored and new),
        // while date/seq are overwritten.
        // NOTE(review): `apply_update_state(All)` overwrites pts/qts unconditionally.
        // Confirm the asymmetry is intended.
        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
            ON CONFLICT(id) DO UPDATE SET
                pts = MAX(excluded.pts, update_state.pts),
                qts = MAX(excluded.qts, update_state.qts),
                date = excluded.date,
                seq = excluded.seq",
            libsql::params![us.pts, us.qts, us.date, us.seq],
        )
        .await?;

        conn.execute("DELETE FROM channel_pts", ()).await?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1,?2)",
                libsql::params![cid, cpts],
            )
            .await?;
        }

        conn.execute("DELETE FROM peers", ()).await?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1,?2,?3)",
                libsql::params![p.id, p.access_hash, p.is_channel as i32],
            )
            .await?;
        }
        Ok(())
    }
}
1842
#[cfg(feature = "libsql-session")]
impl SessionBackend for LibSqlBackend {
    /// Persists the whole session atomically (see `write_session_async`).
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        // `block` has no `'static` bound, so the future can borrow `self` and
        // `session` directly. No Arc clone or session copy is needed.
        Self::block(async {
            let conn = self.conn.lock().await;
            Self::write_session_async(&conn, session).await
        })
    }

    /// Loads the stored session, or returns `Ok(None)` when `meta` is empty (nothing saved).
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        // Hold the lock across both the emptiness check and the read, so no
        // other call on this backend can run between them.
        Self::block(async {
            let conn = self.conn.lock().await;
            let meta_rows: i64 = {
                let mut rows = conn.query("SELECT COUNT(*) FROM meta", ()).await?;
                match rows.next().await? {
                    Some(row) => row.get::<i64>(0)?,
                    None => 0,
                }
            };
            if meta_rows == 0 {
                return Ok(None);
            }
            Self::read_session_async(&conn).await.map(Some)
        })
    }

    /// Deletes every persisted row in one transaction.
    fn delete(&self) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            let outcome = conn
                .execute_batch(
                    "BEGIN IMMEDIATE;
                    DELETE FROM meta;
                    DELETE FROM dcs;
                    DELETE FROM update_state;
                    DELETE FROM channel_pts;
                    DELETE FROM peers;
                    COMMIT;",
                )
                .await
                .map(|_| ());
            if outcome.is_err() {
                // The batch stops at the first failing statement and may leave
                // the BEGIN open. Roll back on a best-effort basis (an error here
                // just means no transaction was active).
                let _ = conn.execute_batch("ROLLBACK").await;
            }
            outcome
        })
    }

    /// Returns the label chosen by the constructor (path, URL, or `:memory:`).
    fn name(&self) -> &str {
        &self.label
    }

    /// Upserts one DC row keyed on `(dc_id, flags)`. Other DC rows are untouched.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let auth_key = entry.auth_key.map(|k| k.to_vec());
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                VALUES (?1, ?2, ?3, ?4, ?5, ?6)
                ON CONFLICT(dc_id, flags) DO UPDATE SET
                    addr=excluded.addr, auth_key=excluded.auth_key,
                    first_salt=excluded.first_salt, time_offset=excluded.time_offset",
                // Bound in column-list order so that ?N is the N-th column.
                libsql::params![
                    entry.dc_id,
                    entry.flags.0 as i64,
                    entry.addr.clone(),
                    auth_key,
                    entry.first_salt,
                    entry.time_offset,
                ],
            )
            .await
            .map(|_| ())
        })
    }

    /// Stores `dc_id` under the `home_dc_id` key of `meta`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO meta (key, value) VALUES ('home_dc_id',?1)
                ON CONFLICT(key) DO UPDATE SET value=excluded.value",
                libsql::params![dc_id],
            )
            .await
            .map(|_| ())
        })
    }

    /// Applies one incremental change to the persisted update state.
    ///
    /// `All` replaces both `update_state` and the whole `channel_pts` table
    /// inside one transaction. If anything fails, it rolls back, so channel
    /// pts are never left wiped.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            match update {
                UpdateStateChange::All(snap) => {
                    conn.execute_batch("BEGIN IMMEDIATE").await?;
                    let writes = async {
                        conn.execute(
                            "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,?2,?3,?4)
                            ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,qts=excluded.qts,
                                date=excluded.date,seq=excluded.seq",
                            libsql::params![snap.pts, snap.qts, snap.date, snap.seq],
                        )
                        .await?;
                        conn.execute("DELETE FROM channel_pts", ()).await?;
                        for &(cid, cpts) in &snap.channels {
                            conn.execute(
                                "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)",
                                libsql::params![cid, cpts],
                            )
                            .await?;
                        }
                        Ok::<(), libsql::Error>(())
                    }
                    .await;
                    let outcome = match writes {
                        Ok(()) => conn.execute_batch("COMMIT").await.map(|_| ()),
                        Err(e) => Err(e),
                    };
                    if outcome.is_err() {
                        // Best-effort: an error here means no transaction was active.
                        let _ = conn.execute_batch("ROLLBACK").await;
                    }
                    outcome
                }
                UpdateStateChange::Primary { pts, date, seq } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,0,?2,?3)
                        ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,date=excluded.date,
                            seq=excluded.seq",
                        libsql::params![pts, date, seq],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Secondary { qts } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,0,?1,0,0)
                        ON CONFLICT(id) DO UPDATE SET qts=excluded.qts",
                        libsql::params![qts],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Channel { id, pts } => conn
                    .execute(
                        "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)
                        ON CONFLICT(channel_id) DO UPDATE SET pts=excluded.pts",
                        libsql::params![id, pts],
                    )
                    .await
                    .map(|_| ()),
            }
        })
    }

    /// Upserts a peer's access hash and channel flag (stored as 0/1), keyed by peer id.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let (id, hash, is_ch) = (peer.id, peer.access_hash, peer.is_channel as i32);
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO peers (id,access_hash,is_channel) VALUES (?1,?2,?3)
                ON CONFLICT(id) DO UPDATE SET
                    access_hash=excluded.access_hash,
                    is_channel=excluded.is_channel",
                libsql::params![id, hash, is_ch],
            )
            .await
            .map(|_| ())
        })
    }
}