1use std::collections::HashMap;
30use std::io::{self, ErrorKind};
31use std::path::Path;
32
#[cfg(feature = "serde")]
mod auth_key_serde {
    //! serde adapter for `Option<[u8; 256]>` auth keys.
    //!
    //! A present key is written as `Some(<256-byte slice>)`; on the way back the
    //! byte sequence must be exactly 256 bytes long or deserialisation fails.

    use serde::{Deserialize, Deserializer, Serializer};

    /// Writes `Some(key)` as an optional byte slice and `None` as a unit "none".
    pub fn serialize<S>(value: &Option<[u8; 256]>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(key) = value {
            s.serialize_some(&key[..])
        } else {
            s.serialize_none()
        }
    }

    /// Reads an optional byte sequence and converts it back into a 256-byte key.
    ///
    /// # Errors
    /// Fails with a custom error when the sequence length is not 256.
    pub fn deserialize<'de, D>(d: D) -> Result<Option<[u8; 256]>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw: Option<Vec<u8>> = Option::deserialize(d)?;
        raw.map(|bytes| {
            <[u8; 256]>::try_from(bytes).map_err(|_| {
                <D::Error as serde::de::Error>::custom("auth_key must be exactly 256 bytes")
            })
        })
        .transpose()
    }
}
63
/// Bit set of per-endpoint properties attached to a [`DcEntry`].
///
/// The individual bits presumably mirror the flags of Telegram's DC option
/// list (TODO confirm against the code that fills them in).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// No bits set.
    pub const NONE: DcFlags = DcFlags(0b0000_0000);
    /// The endpoint address is IPv6 (checked by `DcEntry::is_ipv6`).
    pub const IPV6: DcFlags = DcFlags(0b0000_0001);
    /// Media-only endpoint.
    pub const MEDIA_ONLY: DcFlags = DcFlags(0b0000_0010);
    /// TCP-obfuscated-only endpoint.
    pub const TCPO_ONLY: DcFlags = DcFlags(0b0000_0100);
    /// CDN endpoint.
    pub const CDN: DcFlags = DcFlags(0b0000_1000);
    /// Static endpoint.
    pub const STATIC: DcFlags = DcFlags(0b0001_0000);

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// Consequently `x.contains(DcFlags::NONE)` is always `true`.
    pub fn contains(self, other: DcFlags) -> bool {
        let DcFlags(mine) = self;
        let DcFlags(wanted) = other;
        (mine & wanted) == wanted
    }

    /// Turns on every bit of `flag` in `self`, leaving other bits unchanged.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = DcFlags;

    /// Union of both bit sets.
    fn bitor(self, rhs: DcFlags) -> DcFlags {
        let DcFlags(lhs) = self;
        Self(lhs | rhs.0)
    }
}
94
95#[derive(Clone, Debug)]
97#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
98pub struct DcEntry {
99 pub dc_id: i32,
100 pub addr: String,
101 #[cfg_attr(feature = "serde", serde(with = "auth_key_serde"))]
102 pub auth_key: Option<[u8; 256]>,
103 pub first_salt: i64,
104 pub time_offset: i32,
105 pub flags: DcFlags,
107}
108
109impl DcEntry {
110 #[inline]
112 pub fn is_ipv6(&self) -> bool {
113 self.flags.contains(DcFlags::IPV6)
114 }
115
116 pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
123 self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
124 io::Error::new(
125 io::ErrorKind::InvalidData,
126 format!("invalid DC address: {:?}", self.addr),
127 )
128 })
129 }
130
131 pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
144 let addr = if ip.contains(':') {
146 format!("[{ip}]:{port}")
147 } else {
148 format!("{ip}:{port}")
149 };
150 Self {
151 dc_id,
152 addr,
153 auth_key: None,
154 first_salt: 0,
155 time_offset: 0,
156 flags,
157 }
158 }
159}
160
/// Persisted snapshot of the update-sequence counters: the common
/// `pts`/`qts`/`date`/`seq` plus one `pts` per channel.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    pub pts: i32,
    pub qts: i32,
    pub date: i32,
    pub seq: i32,
    /// `(channel_id, pts)` pairs; [`UpdatesStateSnap::set_channel_pts`] keeps
    /// at most one pair per channel.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// `true` once a positive `pts` has been recorded.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        0 < self.pts
    }

    /// Sets the `pts` of `channel_id`, updating the existing pair in place or
    /// appending a new one at the end.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self.channels.iter().position(|&(id, _)| id == channel_id) {
            Some(idx) => self.channels[idx].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Returns the stored `pts` of `channel_id`, or `0` when it is unknown.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        for &(id, pts) in &self.channels {
            if id == channel_id {
                return pts;
            }
        }
        0
    }
}
203
/// A peer id and its access hash, cached in the session.
///
/// The peer kind is carried by two flags. The binary encoding
/// (`PersistedSession::to_bytes`) writes type 2 when `is_chat` is set (even if
/// `is_channel` is also set), 1 for `is_channel`, and 0 when neither is set
/// (presumably a user — TODO confirm).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    pub id: i64,
    /// Access hash associated with `id`.
    pub access_hash: i64,
    pub is_channel: bool,
    pub is_chat: bool,
}
220
/// A user id cached together with a peer id and a message id.
///
/// NOTE(review): presumably records where a "min" user was seen (the chat and
/// message that referenced it) so it can be resolved later — confirm against
/// the callers that populate `PersistedSession::min_peers`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedMinPeer {
    pub user_id: i64,
    pub peer_id: i64,
    pub msg_id: i32,
}
234
/// Everything persisted between runs: the home DC, known DC endpoints (with
/// their auth keys), the update state and the peer caches.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PersistedSession {
    pub home_dc_id: i32,
    /// May contain several entries for the same `dc_id` (e.g. IPv4 and IPv6);
    /// see `PersistedSession::dc_for` / `all_dcs_for`.
    pub dcs: Vec<DcEntry>,
    pub updates_state: UpdatesStateSnap,
    pub peers: Vec<CachedPeer>,
    pub min_peers: Vec<CachedMinPeer>,
}
250
251impl PersistedSession {
252 pub fn to_bytes(&self) -> Vec<u8> {
254 let mut b = Vec::with_capacity(512);
255
256 b.push(0x05u8); b.extend_from_slice(&self.home_dc_id.to_le_bytes());
259
260 b.push(self.dcs.len() as u8);
261 for d in &self.dcs {
262 b.extend_from_slice(&d.dc_id.to_le_bytes());
263 match &d.auth_key {
264 Some(k) => {
265 b.push(1);
266 b.extend_from_slice(k);
267 }
268 None => {
269 b.push(0);
270 }
271 }
272 b.extend_from_slice(&d.first_salt.to_le_bytes());
273 b.extend_from_slice(&d.time_offset.to_le_bytes());
274 let ab = d.addr.as_bytes();
275 b.push(ab.len() as u8);
276 b.extend_from_slice(ab);
277 b.push(d.flags.0);
278 }
279
280 b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
281 b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
282 b.extend_from_slice(&self.updates_state.date.to_le_bytes());
283 b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
284 let ch = &self.updates_state.channels;
285 b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
286 for &(cid, cpts) in ch {
287 b.extend_from_slice(&cid.to_le_bytes());
288 b.extend_from_slice(&cpts.to_le_bytes());
289 }
290
291 b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
293 for p in &self.peers {
294 b.extend_from_slice(&p.id.to_le_bytes());
295 b.extend_from_slice(&p.access_hash.to_le_bytes());
296 let peer_type: u8 = if p.is_chat {
297 2
298 } else if p.is_channel {
299 1
300 } else {
301 0
302 };
303 b.push(peer_type);
304 }
305
306 b.extend_from_slice(&(self.min_peers.len() as u16).to_le_bytes());
307 for m in &self.min_peers {
308 b.extend_from_slice(&m.user_id.to_le_bytes());
309 b.extend_from_slice(&m.peer_id.to_le_bytes());
310 b.extend_from_slice(&m.msg_id.to_le_bytes());
311 }
312
313 b
314 }
315
316 pub fn save(&self, path: &Path) -> io::Result<()> {
321 let tmp = path.with_extension("tmp");
322 std::fs::write(&tmp, self.to_bytes())?;
323 std::fs::rename(&tmp, path)
324 }
325
326 pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
328 if buf.is_empty() {
329 return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
330 }
331
332 let mut p = 0usize;
333
334 macro_rules! r {
335 ($n:expr) => {{
336 if p + $n > buf.len() {
337 return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
338 }
339 let s = &buf[p..p + $n];
340 p += $n;
341 s
342 }};
343 }
344 macro_rules! r_i32 {
345 () => {
346 i32::from_le_bytes(r!(4).try_into().unwrap())
347 };
348 }
349 macro_rules! r_i64 {
350 () => {
351 i64::from_le_bytes(r!(8).try_into().unwrap())
352 };
353 }
354 macro_rules! r_u8 {
355 () => {
356 r!(1)[0]
357 };
358 }
359 macro_rules! r_u16 {
360 () => {
361 u16::from_le_bytes(r!(2).try_into().unwrap())
362 };
363 }
364
365 let first_byte = r_u8!();
366
367 let (home_dc_id, version) = if first_byte == 0x05 {
368 (r_i32!(), 5u8)
369 } else if first_byte == 0x04 {
370 (r_i32!(), 4u8)
371 } else if first_byte == 0x03 {
372 (r_i32!(), 3u8)
373 } else if first_byte == 0x02 {
374 (r_i32!(), 2u8)
375 } else {
376 let rest = r!(3);
377 let mut bytes = [0u8; 4];
378 bytes[0] = first_byte;
379 bytes[1..4].copy_from_slice(rest);
380 (i32::from_le_bytes(bytes), 1u8)
381 };
382
383 let dc_count = r_u8!() as usize;
384 let mut dcs = Vec::with_capacity(dc_count);
385 for _ in 0..dc_count {
386 let dc_id = r_i32!();
387 let has_key = r_u8!();
388 let auth_key = if has_key == 1 {
389 let mut k = [0u8; 256];
390 k.copy_from_slice(r!(256));
391 Some(k)
392 } else {
393 None
394 };
395 let first_salt = r_i64!();
396 let time_offset = r_i32!();
397 let al = r_u8!() as usize;
398 let addr = String::from_utf8_lossy(r!(al)).into_owned();
399 let flags = if version >= 3 {
400 DcFlags(r_u8!())
401 } else {
402 DcFlags::NONE
403 };
404 dcs.push(DcEntry {
405 dc_id,
406 addr,
407 auth_key,
408 first_salt,
409 time_offset,
410 flags,
411 });
412 }
413
414 if version < 2 {
415 return Ok(Self {
416 home_dc_id,
417 dcs,
418 updates_state: UpdatesStateSnap::default(),
419 peers: Vec::new(),
420 min_peers: Vec::new(),
421 });
422 }
423
424 let pts = r_i32!();
425 let qts = r_i32!();
426 let date = r_i32!();
427 let seq = r_i32!();
428 let ch_count = r_u16!() as usize;
429 let mut channels = Vec::with_capacity(ch_count);
430 for _ in 0..ch_count {
431 let cid = r_i64!();
432 let cpts = r_i32!();
433 channels.push((cid, cpts));
434 }
435
436 let peer_count = r_u16!() as usize;
437 let mut peers = Vec::with_capacity(peer_count);
438 for _ in 0..peer_count {
439 let id = r_i64!();
440 let access_hash = r_i64!();
441 let peer_type = r_u8!();
443 let is_channel = peer_type == 1;
444 let is_chat = peer_type == 2;
445 peers.push(CachedPeer {
446 id,
447 access_hash,
448 is_channel,
449 is_chat,
450 });
451 }
452
453 let min_peers = if version >= 4 {
455 let count = r_u16!() as usize;
456 let mut v = Vec::with_capacity(count);
457 for _ in 0..count {
458 let user_id = r_i64!();
459 let peer_id = r_i64!();
460 let msg_id = r_i32!();
461 v.push(CachedMinPeer {
462 user_id,
463 peer_id,
464 msg_id,
465 });
466 }
467 v
468 } else {
469 Vec::new()
470 };
471
472 Ok(Self {
473 home_dc_id,
474 dcs,
475 updates_state: UpdatesStateSnap {
476 pts,
477 qts,
478 date,
479 seq,
480 channels,
481 },
482 peers,
483 min_peers,
484 })
485 }
486
487 pub fn from_string(s: &str) -> io::Result<Self> {
489 use base64::Engine as _;
490 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
491 .decode(s.trim())
492 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
493 Self::from_bytes(&bytes)
494 }
495
496 pub fn load(path: &Path) -> io::Result<Self> {
497 let buf = std::fs::read(path)?;
498 Self::from_bytes(&buf)
499 }
500
501 pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
513 let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
514 candidates.peek()?;
515 let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
517 cands
519 .iter()
520 .copied()
521 .find(|d| d.is_ipv6() == prefer_ipv6)
522 .or_else(|| cands.first().copied())
523 }
524
525 pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
530 self.dcs.iter().filter(move |d| d.dc_id == dc_id)
531 }
532}
533
534impl std::fmt::Display for PersistedSession {
535 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536 use base64::Engine as _;
537 f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
538 }
539}
540
/// Built-in `ip:port` addresses for DCs 1 to 5, keyed by DC id.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const DEFAULTS: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(DEFAULTS.len());
    for (dc_id, address) in DEFAULTS {
        table.insert(dc_id, address.to_owned());
    }
    table
}
554
555use std::path::PathBuf;
560
561pub trait SessionBackend: Send + Sync {
569 fn save(&self, session: &PersistedSession) -> io::Result<()>;
570 fn load(&self) -> io::Result<Option<PersistedSession>>;
571 fn delete(&self) -> io::Result<()>;
572
573 fn name(&self) -> &str;
575
576 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
590 let mut s = self.load()?.unwrap_or_default();
591 if let Some(existing) = s.dcs.iter_mut().find(|d| d.dc_id == entry.dc_id) {
593 *existing = entry.clone();
594 } else {
595 s.dcs.push(entry.clone());
596 }
597 self.save(&s)
598 }
599
600 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
607 let mut s = self.load()?.unwrap_or_default();
608 s.home_dc_id = dc_id;
609 self.save(&s)
610 }
611
612 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
618 let mut s = self.load()?.unwrap_or_default();
619 update.apply_to(&mut s.updates_state);
620 self.save(&s)
621 }
622
623 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
630 let mut s = self.load()?.unwrap_or_default();
631 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
632 *existing = peer.clone();
633 } else {
634 s.peers.push(peer.clone());
635 }
636 self.save(&s)
637 }
638}
639
640#[derive(Debug, Clone)]
654pub enum UpdateStateChange {
655 All(UpdatesStateSnap),
657 Primary { pts: i32, date: i32, seq: i32 },
659 Secondary { qts: i32 },
661 Channel { id: i64, pts: i32 },
663}
664
665impl UpdateStateChange {
666 pub fn apply_to(&self, snap: &mut UpdatesStateSnap) {
668 match self {
669 Self::All(new_snap) => *snap = new_snap.clone(),
670 Self::Primary { pts, date, seq } => {
671 snap.pts = *pts;
672 snap.date = *date;
673 snap.seq = *seq;
674 }
675 Self::Secondary { qts } => {
676 snap.qts = *qts;
677 }
678 Self::Channel { id, pts } => {
679 if let Some(existing) = snap.channels.iter_mut().find(|c| c.0 == *id) {
681 existing.1 = *pts;
682 } else {
683 snap.channels.push((*id, *pts));
684 }
685 }
686 }
687 }
688}
689
690pub struct BinaryFileBackend {
694 path: PathBuf,
695}
696
697impl BinaryFileBackend {
698 pub fn new(path: impl Into<PathBuf>) -> Self {
699 Self { path: path.into() }
700 }
701
702 pub fn path(&self) -> &std::path::Path {
703 &self.path
704 }
705}
706
707impl SessionBackend for BinaryFileBackend {
708 fn save(&self, session: &PersistedSession) -> io::Result<()> {
709 session.save(&self.path)
710 }
711
712 fn load(&self) -> io::Result<Option<PersistedSession>> {
713 if !self.path.exists() {
714 return Ok(None);
715 }
716 match PersistedSession::load(&self.path) {
717 Ok(s) => Ok(Some(s)),
718 Err(e) => {
719 let bak = self.path.with_extension("bak");
720 tracing::warn!(
721 "[ferogram] Session file {:?} is corrupt ({e}); \
722 renaming to {:?} and starting fresh",
723 self.path,
724 bak
725 );
726 let _ = std::fs::rename(&self.path, &bak);
727 Ok(None)
728 }
729 }
730 }
731
732 fn delete(&self) -> io::Result<()> {
733 if self.path.exists() {
734 std::fs::remove_file(&self.path)?;
735 }
736 Ok(())
737 }
738
739 fn name(&self) -> &str {
740 "binary-file"
741 }
742
743 }
746
747#[derive(Default)]
755pub struct InMemoryBackend {
756 data: std::sync::Mutex<Option<PersistedSession>>,
757}
758
759impl InMemoryBackend {
760 pub fn new() -> Self {
761 Self::default()
762 }
763
764 pub fn snapshot(&self) -> Option<PersistedSession> {
766 self.data.lock().unwrap().clone()
767 }
768}
769
770impl SessionBackend for InMemoryBackend {
771 fn save(&self, s: &PersistedSession) -> io::Result<()> {
772 *self.data.lock().unwrap() = Some(s.clone());
773 Ok(())
774 }
775
776 fn load(&self) -> io::Result<Option<PersistedSession>> {
777 Ok(self.data.lock().unwrap().clone())
778 }
779
780 fn delete(&self) -> io::Result<()> {
781 *self.data.lock().unwrap() = None;
782 Ok(())
783 }
784
785 fn name(&self) -> &str {
786 "in-memory"
787 }
788
789 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
792 let mut guard = self.data.lock().unwrap();
793 let s = guard.get_or_insert_with(PersistedSession::default);
794 if let Some(existing) = s.dcs.iter_mut().find(|d| d.dc_id == entry.dc_id) {
795 *existing = entry.clone();
796 } else {
797 s.dcs.push(entry.clone());
798 }
799 Ok(())
800 }
801
802 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
803 let mut guard = self.data.lock().unwrap();
804 let s = guard.get_or_insert_with(PersistedSession::default);
805 s.home_dc_id = dc_id;
806 Ok(())
807 }
808
809 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
810 let mut guard = self.data.lock().unwrap();
811 let s = guard.get_or_insert_with(PersistedSession::default);
812 update.apply_to(&mut s.updates_state);
813 Ok(())
814 }
815
816 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
817 let mut guard = self.data.lock().unwrap();
818 let s = guard.get_or_insert_with(PersistedSession::default);
819 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
820 *existing = peer.clone();
821 } else {
822 s.peers.push(peer.clone());
823 }
824 Ok(())
825 }
826}
827
828pub struct StringSessionBackend {
832 data: std::sync::Mutex<String>,
833}
834
835impl StringSessionBackend {
836 pub fn new(s: impl Into<String>) -> Self {
837 Self {
838 data: std::sync::Mutex::new(s.into()),
839 }
840 }
841
842 pub fn current(&self) -> String {
843 self.data.lock().unwrap().clone()
844 }
845}
846
847impl SessionBackend for StringSessionBackend {
848 fn save(&self, session: &PersistedSession) -> io::Result<()> {
849 *self.data.lock().unwrap() = session.to_string();
850 Ok(())
851 }
852
853 fn load(&self) -> io::Result<Option<PersistedSession>> {
854 let s = self.data.lock().unwrap().clone();
855 if s.trim().is_empty() {
856 return Ok(None);
857 }
858 PersistedSession::from_string(&s).map(Some)
859 }
860
861 fn delete(&self) -> io::Result<()> {
862 *self.data.lock().unwrap() = String::new();
863 Ok(())
864 }
865
866 fn name(&self) -> &str {
867 "string-session"
868 }
869}
870
#[cfg(test)]
mod tests {
    use super::*;

    /// IPv4 entry for DC `id` at `1.2.3.<id>:443`, without an auth key.
    fn make_dc(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("1.2.3.{id}:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::NONE,
        }
    }

    /// IPv6 counterpart of `make_dc`, flagged with `DcFlags::IPV6`.
    fn make_dc_v6(id: i32) -> DcEntry {
        DcEntry {
            addr: format!("[2001:b28:f23d:f00{}::a]:443", id),
            flags: DcFlags::IPV6,
            ..make_dc(id)
        }
    }

    /// Peer that is neither a channel nor a chat.
    fn make_peer(id: i64, hash: i64) -> CachedPeer {
        CachedPeer {
            id,
            access_hash: hash,
            is_channel: false,
            is_chat: false,
        }
    }

    /// Session with the given home DC and DC list; everything else default.
    fn session_with(home_dc_id: i32, dcs: Vec<DcEntry>) -> PersistedSession {
        PersistedSession {
            home_dc_id,
            dcs,
            ..PersistedSession::default()
        }
    }

    // ---- InMemoryBackend ---------------------------------------------------

    #[test]
    fn inmemory_load_returns_none_when_empty() {
        let backend = InMemoryBackend::new();
        assert!(backend.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_save_then_load_round_trips() {
        let backend = InMemoryBackend::new();
        backend.save(&session_with(3, vec![make_dc(3)])).unwrap();

        let loaded = backend.load().unwrap().expect("a session was just saved");
        assert_eq!(loaded.home_dc_id, 3);
        assert_eq!(loaded.dcs.len(), 1);
    }

    #[test]
    fn inmemory_delete_clears_state() {
        let backend = InMemoryBackend::new();
        backend.save(&session_with(2, Vec::new())).unwrap();
        backend.delete().unwrap();
        assert!(backend.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_update_dc_inserts_new() {
        let backend = InMemoryBackend::new();
        backend.update_dc(&make_dc(4)).unwrap();

        let ids: Vec<i32> = backend.snapshot().unwrap().dcs.iter().map(|d| d.dc_id).collect();
        assert_eq!(ids, [4]);
    }

    #[test]
    fn inmemory_update_dc_replaces_existing() {
        let backend = InMemoryBackend::new();
        backend.update_dc(&make_dc(2)).unwrap();
        let moved = DcEntry {
            addr: "9.9.9.9:443".to_string(),
            ..make_dc(2)
        };
        backend.update_dc(&moved).unwrap();

        let snap = backend.snapshot().unwrap();
        assert_eq!(snap.dcs.len(), 1);
        assert_eq!(snap.dcs[0].addr, "9.9.9.9:443");
    }

    #[test]
    fn inmemory_set_home_dc() {
        let backend = InMemoryBackend::new();
        backend.set_home_dc(5).unwrap();
        assert_eq!(backend.snapshot().unwrap().home_dc_id, 5);
    }

    #[test]
    fn inmemory_cache_peer_inserts() {
        let backend = InMemoryBackend::new();
        backend.cache_peer(&make_peer(100, 0xdeadbeef)).unwrap();

        let peers = backend.snapshot().unwrap().peers;
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].id, 100);
    }

    #[test]
    fn inmemory_cache_peer_updates_existing() {
        let backend = InMemoryBackend::new();
        for hash in [111, 222] {
            backend.cache_peer(&make_peer(100, hash)).unwrap();
        }

        let peers = backend.snapshot().unwrap().peers;
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].access_hash, 222);
    }

    #[test]
    fn inmemory_ipv4_and_ipv6_coexist() {
        let backend = InMemoryBackend::new();
        backend.update_dc(&make_dc(2)).unwrap();
        backend.update_dc(&make_dc_v6(2)).unwrap();

        let snap = backend.snapshot().unwrap();
        assert_eq!(snap.all_dcs_for(2).count(), 2);
    }

    // ---- UpdateStateChange -------------------------------------------------

    #[test]
    fn update_state_primary() {
        let mut snap = UpdatesStateSnap::default();
        UpdateStateChange::Primary {
            pts: 10,
            date: 20,
            seq: 30,
        }
        .apply_to(&mut snap);

        assert_eq!((snap.pts, snap.date, snap.seq), (10, 20, 30));
        // `qts` is not part of the primary sequence.
        assert_eq!(snap.qts, 0);
    }

    #[test]
    fn update_state_secondary() {
        let mut snap = UpdatesStateSnap {
            pts: 5,
            ..UpdatesStateSnap::default()
        };
        UpdateStateChange::Secondary { qts: 99 }.apply_to(&mut snap);

        assert_eq!(snap.qts, 99);
        // `pts` must survive a secondary update.
        assert_eq!(snap.pts, 5);
    }

    #[test]
    fn update_state_channel_inserts() {
        let mut snap = UpdatesStateSnap::default();
        UpdateStateChange::Channel { id: 12345, pts: 42 }.apply_to(&mut snap);
        assert_eq!(snap.channels, vec![(12345, 42)]);
    }

    #[test]
    fn update_state_channel_updates_existing() {
        let mut snap = UpdatesStateSnap {
            channels: vec![(12345, 10), (67890, 5)],
            ..UpdatesStateSnap::default()
        };
        UpdateStateChange::Channel { id: 12345, pts: 99 }.apply_to(&mut snap);

        // Updated in place; the other channel keeps its value and position.
        assert_eq!(snap.channels[0], (12345, 99));
        assert_eq!(snap.channels[1], (67890, 5));
    }

    #[test]
    fn apply_update_state_via_backend() {
        let backend = InMemoryBackend::new();
        backend
            .apply_update_state(UpdateStateChange::Primary {
                pts: 7,
                date: 8,
                seq: 9,
            })
            .unwrap();
        assert_eq!(backend.snapshot().unwrap().updates_state.pts, 7);
    }

    #[test]
    fn default_update_dc_via_trait_object() {
        let backend: Box<dyn SessionBackend> = Box::new(InMemoryBackend::new());
        for id in [1, 2] {
            backend.update_dc(&make_dc(id)).unwrap();
        }

        let loaded = backend.load().unwrap().unwrap();
        assert_eq!(loaded.dcs.len(), 2);
    }

    // ---- DcEntry / PersistedSession ----------------------------------------

    #[test]
    fn dc_entry_from_parts_ipv4() {
        let dc = DcEntry::from_parts(1, "149.154.175.53", 443, DcFlags::NONE);
        assert_eq!(dc.addr, "149.154.175.53:443");
        assert!(!dc.is_ipv6());
        assert_eq!(dc.socket_addr().unwrap().port(), 443);
    }

    #[test]
    fn dc_entry_from_parts_ipv6() {
        let dc = DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6);
        assert_eq!(dc.addr, "[2001:b28:f23d:f001::a]:443");
        assert!(dc.is_ipv6());
        assert_eq!(dc.socket_addr().unwrap().port(), 443);
    }

    #[test]
    fn persisted_session_dc_for_prefers_ipv6() {
        let s = session_with(0, vec![make_dc(2), make_dc_v6(2)]);
        assert!(s.dc_for(2, true).unwrap().is_ipv6());
        assert!(!s.dc_for(2, false).unwrap().is_ipv6());
    }

    #[test]
    fn persisted_session_dc_for_falls_back_when_only_ipv4() {
        let s = session_with(0, vec![make_dc(3)]);
        let dc = s.dc_for(3, true).expect("falls back to the IPv4 entry");
        assert!(!dc.is_ipv6());
    }

    #[test]
    fn persisted_session_all_dcs_for_returns_both() {
        let s = session_with(0, vec![make_dc(1), make_dc_v6(1), make_dc(2)]);
        let counts: Vec<usize> = [1, 2, 5]
            .iter()
            .map(|&id| s.all_dcs_for(id).count())
            .collect();
        assert_eq!(counts, [2, 1, 0]);
    }

    #[test]
    fn binary_roundtrip_ipv4_and_ipv6() {
        let s = session_with(2, vec![make_dc(2), make_dc_v6(2)]);
        let loaded = PersistedSession::from_bytes(&s.to_bytes()).unwrap();

        let v6 = loaded.dcs.iter().filter(|d| d.is_ipv6()).count();
        assert_eq!(loaded.dcs.len(), 2);
        assert_eq!(v6, 1);
        assert_eq!(loaded.dcs.len() - v6, 1);
    }
}
1159
/// [`SessionBackend`] that stores the session in an SQLite database, one table
/// per part of [`PersistedSession`].
///
/// NOTE(review): `min_peers` is not persisted by this backend; loads always
/// return an empty list.
#[cfg(feature = "sqlite-session")]
pub struct SqliteBackend {
    conn: std::sync::Mutex<rusqlite::Connection>,
    label: String,
}

#[cfg(feature = "sqlite-session")]
impl SqliteBackend {
    /// Schema applied on every open. Every statement is idempotent
    /// (`CREATE TABLE IF NOT EXISTS`, PRAGMAs).
    const SCHEMA: &'static str = "
        PRAGMA journal_mode = WAL;
        PRAGMA synchronous = NORMAL;

        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );

        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );

        CREATE TABLE IF NOT EXISTS peers (
            id          INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel  INTEGER NOT NULL DEFAULT 0,
            is_chat     INTEGER NOT NULL DEFAULT 0
        );
    ";

    /// Opens (creating it if necessary) the database at `path` and brings its
    /// schema up to date. The displayed path becomes this backend's `name()`.
    pub fn open(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let conn = rusqlite::Connection::open(&path).map_err(Self::map_err)?;
        Self::init_schema(&conn)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label,
        })
    }

    /// Creates a backend over a fresh in-memory database named `:memory:`.
    pub fn in_memory() -> io::Result<Self> {
        let conn = rusqlite::Connection::open_in_memory().map_err(Self::map_err)?;
        Self::init_schema(&conn)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label: ":memory:".into(),
        })
    }

    /// Applies [`Self::SCHEMA`] and migrates databases created by older versions.
    fn init_schema(conn: &rusqlite::Connection) -> io::Result<()> {
        conn.execute_batch(Self::SCHEMA).map_err(Self::map_err)?;
        // Databases created before `peers.is_chat` existed lack that column, and
        // `CREATE TABLE IF NOT EXISTS` never alters an existing table. Preparing
        // a statement that names a missing column fails; that failure is how the
        // missing column is detected.
        if conn.prepare("SELECT is_chat FROM peers LIMIT 0").is_err() {
            conn.execute_batch("ALTER TABLE peers ADD COLUMN is_chat INTEGER NOT NULL DEFAULT 0")
                .map_err(Self::map_err)?;
        }
        Ok(())
    }

    fn map_err(e: rusqlite::Error) -> io::Error {
        io::Error::other(e)
    }

    /// Reads the whole session.
    ///
    /// Reading is best effort: a missing home DC or update-state row yields
    /// defaults, and individual rows that fail to decode are skipped.
    fn read_session(conn: &rusqlite::Connection) -> io::Result<PersistedSession> {
        let home_dc_id: i32 = conn
            .query_row("SELECT value FROM meta WHERE key = 'home_dc_id'", [], |r| {
                r.get(0)
            })
            .unwrap_or(0);

        let mut stmt = conn
            .prepare("SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs")
            .map_err(Self::map_err)?;
        let dcs = stmt
            .query_map([], |row| {
                let key_blob: Option<Vec<u8>> = row.get(3)?;
                Ok(DcEntry {
                    dc_id: row.get(0)?,
                    flags: DcFlags(row.get(1)?),
                    addr: row.get(2)?,
                    // A blob of any length other than 256 is treated as "no key".
                    auth_key: key_blob.and_then(|b| <[u8; 256]>::try_from(b).ok()),
                    first_salt: row.get(4)?,
                    time_offset: row.get(5)?,
                })
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        let updates_state = conn
            .query_row(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                [],
                |r| {
                    Ok(UpdatesStateSnap {
                        pts: r.get(0)?,
                        qts: r.get(1)?,
                        date: r.get(2)?,
                        seq: r.get(3)?,
                        channels: vec![],
                    })
                },
            )
            .unwrap_or_default();

        let mut ch_stmt = conn
            .prepare("SELECT channel_id, pts FROM channel_pts")
            .map_err(Self::map_err)?;
        let channels: Vec<(i64, i32)> = ch_stmt
            .query_map([], |r| Ok((r.get::<_, i64>(0)?, r.get::<_, i32>(1)?)))
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        let mut peer_stmt = conn
            .prepare("SELECT id, access_hash, is_channel, is_chat FROM peers")
            .map_err(Self::map_err)?;
        let peers: Vec<CachedPeer> = peer_stmt
            .query_map([], |r| {
                Ok(CachedPeer {
                    id: r.get(0)?,
                    access_hash: r.get(1)?,
                    is_channel: r.get::<_, i32>(2)? != 0,
                    is_chat: r.get::<_, i32>(3)? != 0,
                })
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers: Vec::new(),
        })
    }

    /// Replaces the stored session with `s` in a single `IMMEDIATE` transaction.
    ///
    /// If any statement (or the `COMMIT` itself) fails, the transaction is
    /// rolled back before the error is returned. Otherwise the connection would
    /// stay inside an open transaction, and every later `BEGIN IMMEDIATE` would
    /// fail with "cannot start a transaction within a transaction".
    fn write_session(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        conn.execute_batch("BEGIN IMMEDIATE")
            .map_err(Self::map_err)?;

        let result = Self::write_session_rows(conn, s)
            .and_then(|()| conn.execute_batch("COMMIT").map_err(Self::map_err));
        if result.is_err() {
            // If COMMIT already ended the transaction, this ROLLBACK fails
            // harmlessly; the original error is what matters to the caller.
            let _ = conn.execute_batch("ROLLBACK");
        }
        result
    }

    /// Writes every table for `s`. Must run inside a transaction opened by
    /// [`Self::write_session`].
    fn write_session_rows(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![s.home_dc_id],
        )
        .map_err(Self::map_err)?;

        // `OR REPLACE` below: the in-memory representation may hold duplicate
        // keys (it is a plain Vec), and a duplicate should make the last entry
        // win instead of aborting the whole save.
        conn.execute("DELETE FROM dcs", []).map_err(Self::map_err)?;
        for d in &s.dcs {
            conn.execute(
                "INSERT OR REPLACE INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                rusqlite::params![
                    d.dc_id,
                    d.flags.0,
                    d.addr,
                    d.auth_key.as_ref().map(|k| k.as_slice()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .map_err(Self::map_err)?;
        }

        // NOTE(review): a full save never lowers `pts`/`qts` (MAX) but always
        // overwrites `date`/`seq`. By contrast, `apply_update_state(All)`
        // overwrites all four. Confirm that this asymmetry is intended.
        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1, ?1, ?2, ?3, ?4)
             ON CONFLICT(id) DO UPDATE SET
                 pts  = MAX(excluded.pts, update_state.pts),
                 qts  = MAX(excluded.qts, update_state.qts),
                 date = excluded.date,
                 seq  = excluded.seq",
            rusqlite::params![us.pts, us.qts, us.date, us.seq],
        )
        .map_err(Self::map_err)?;

        conn.execute("DELETE FROM channel_pts", [])
            .map_err(Self::map_err)?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT OR REPLACE INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
                rusqlite::params![cid, cpts],
            )
            .map_err(Self::map_err)?;
        }

        conn.execute("DELETE FROM peers", [])
            .map_err(Self::map_err)?;
        for p in &s.peers {
            conn.execute(
                "INSERT OR REPLACE INTO peers (id, access_hash, is_channel, is_chat)
                 VALUES (?1, ?2, ?3, ?4)",
                rusqlite::params![p.id, p.access_hash, p.is_channel as i32, p.is_chat as i32],
            )
            .map_err(Self::map_err)?;
        }

        Ok(())
    }
}
1430
#[cfg(feature = "sqlite-session")]
impl SessionBackend for SqliteBackend {
    /// Persists `session`, replacing everything previously stored.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        Self::write_session(&conn, session)
    }

    /// Loads the stored session.
    ///
    /// Returns `Ok(None)` when the `meta` table has no rows (nothing saved yet, or the
    /// session was deleted).
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = self.conn.lock().unwrap();
        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM meta", [], |r| r.get(0))
            .map_err(Self::map_err)?;
        if count == 0 {
            return Ok(None);
        }
        Self::read_session(&conn).map(Some)
    }

    /// Removes every row from all session tables (the tables themselves are kept).
    ///
    /// All deletes run in one `IMMEDIATE` transaction. The `Transaction` guard rolls
    /// back when dropped without `commit`, so a failing statement can no longer leave a
    /// transaction open on the shared connection (which would make later `BEGIN`s fail).
    fn delete(&self) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        let tx = rusqlite::Transaction::new_unchecked(
            &conn,
            rusqlite::TransactionBehavior::Immediate,
        )
        .map_err(Self::map_err)?;
        tx.execute_batch(
            "DELETE FROM meta;
             DELETE FROM dcs;
             DELETE FROM update_state;
             DELETE FROM channel_pts;
             DELETE FROM peers;",
        )
        .map_err(Self::map_err)?;
        tx.commit().map_err(Self::map_err)
    }

    /// Human-readable description of this backend.
    fn name(&self) -> &str {
        &self.label
    }

    /// Inserts or updates the row keyed by `(entry.dc_id, entry.flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        // Note: placeholder ?6 is `flags`; the VALUES list is in column order.
        conn.execute(
            "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
             VALUES (?1, ?6, ?2, ?3, ?4, ?5)
             ON CONFLICT(dc_id, flags) DO UPDATE SET
                 addr = excluded.addr,
                 auth_key = excluded.auth_key,
                 first_salt = excluded.first_salt,
                 time_offset = excluded.time_offset",
            rusqlite::params![
                entry.dc_id,
                entry.addr,
                entry.auth_key.as_ref().map(|k| k.as_slice()),
                entry.first_salt,
                entry.time_offset,
                entry.flags.0,
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Upserts the `home_dc_id` entry in the `meta` table.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![dc_id],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Applies one incremental update-state change to the single `update_state` row
    /// (id = 1) or to `channel_pts`.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        match update {
            UpdateStateChange::All(snap) => {
                // `channel_pts` is cleared and rebuilt from the snapshot; doing it in one
                // transaction means a failure part-way through cannot leave the table
                // empty or half-filled. Dropping `tx` without `commit` rolls back.
                let tx = rusqlite::Transaction::new_unchecked(
                    &conn,
                    rusqlite::TransactionBehavior::Immediate,
                )
                .map_err(Self::map_err)?;
                tx.execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
                     ON CONFLICT(id) DO UPDATE SET
                         pts=excluded.pts, qts=excluded.qts,
                         date=excluded.date, seq=excluded.seq",
                    rusqlite::params![snap.pts, snap.qts, snap.date, snap.seq],
                )
                .map_err(Self::map_err)?;
                tx.execute("DELETE FROM channel_pts", [])
                    .map_err(Self::map_err)?;
                for &(cid, cpts) in &snap.channels {
                    tx.execute(
                        "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
                        rusqlite::params![cid, cpts],
                    )
                    .map_err(Self::map_err)?;
                }
                tx.commit().map_err(Self::map_err)
            }
            UpdateStateChange::Primary { pts, date, seq } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,0,?2,?3)
                     ON CONFLICT(id) DO UPDATE SET pts=excluded.pts, date=excluded.date,
                         seq=excluded.seq",
                    rusqlite::params![pts, date, seq],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Secondary { qts } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,0,?1,0,0)
                     ON CONFLICT(id) DO UPDATE SET qts = excluded.qts",
                    rusqlite::params![qts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Channel { id, pts } => conn
                .execute(
                    "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)
                     ON CONFLICT(channel_id) DO UPDATE SET pts = excluded.pts",
                    rusqlite::params![id, pts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
        }
    }

    /// Inserts or updates a cached peer keyed by `peer.id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1, ?2, ?3)
             ON CONFLICT(id) DO UPDATE SET
                 access_hash = excluded.access_hash,
                 is_channel = excluded.is_channel",
            rusqlite::params![peer.id, peer.access_hash, peer.is_channel as i32],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }
}
1568
/// Session backend that stores state in a libsql database: a local file, in-memory,
/// remote, or embedded-replica database depending on which constructor is used.
#[cfg(feature = "libsql-session")]
pub struct LibSqlBackend {
    /// The single connection, shared behind an async mutex. Each `SessionBackend`
    /// call holds the lock for its whole duration, including multi-statement
    /// transactions. The constructors build exactly this type.
    conn: std::sync::Arc<tokio::sync::Mutex<libsql::Connection>>,
    /// Human-readable description returned by `SessionBackend::name`.
    label: String,
}
1593
#[cfg(feature = "libsql-session")]
impl LibSqlBackend {
    /// DDL applied on every open. Every statement is `IF NOT EXISTS`, so re-applying
    /// it to an existing database leaves the schema unchanged.
    const SCHEMA: &'static str = "
        CREATE TABLE IF NOT EXISTS meta (
            key TEXT PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS dcs (
            dc_id INTEGER NOT NULL,
            flags INTEGER NOT NULL DEFAULT 0,
            addr TEXT NOT NULL,
            auth_key BLOB,
            first_salt INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );
        CREATE TABLE IF NOT EXISTS update_state (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            pts INTEGER NOT NULL DEFAULT 0,
            qts INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts INTEGER NOT NULL
        );
        CREATE TABLE IF NOT EXISTS peers (
            id INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel INTEGER NOT NULL DEFAULT 0
        );
    ";

    /// Drives `fut` to completion on the ambient Tokio runtime and converts a
    /// `libsql::Error` into an `io::Error`.
    ///
    /// # Errors
    /// Returns an `io::Error` when no Tokio runtime is entered on the calling thread,
    /// instead of panicking as `Handle::current` would. Also returns an error when
    /// `fut` fails.
    ///
    /// # Panics
    /// `Handle::block_on` panics when invoked from within an async execution context.
    /// Async callers must reach this backend through something like `spawn_blocking`.
    fn block<F, T>(fut: F) -> io::Result<T>
    where
        F: std::future::Future<Output = Result<T, libsql::Error>>,
    {
        let handle = tokio::runtime::Handle::try_current().map_err(io::Error::other)?;
        handle.block_on(fut).map_err(io::Error::other)
    }

    /// Creates all session tables that do not exist yet.
    async fn apply_schema(conn: &libsql::Connection) -> Result<(), libsql::Error> {
        conn.execute_batch(Self::SCHEMA).await
    }

    /// Shared tail of every constructor. It connects to `db`, applies the schema and
    /// wraps the connection for shared use.
    fn from_database(db: libsql::Database, label: String) -> io::Result<Self> {
        // `connect` is synchronous, so there is no need to go through `block`.
        let conn = db.connect().map_err(io::Error::other)?;
        Self::block(Self::apply_schema(&conn))?;
        Ok(Self {
            conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
            label,
        })
    }

    /// Opens a local libsql database at `path` and ensures the schema exists.
    /// The label is the displayed path.
    pub fn open_local(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let db = Self::block(libsql::Builder::new_local(path).build())?;
        Self::from_database(db, label)
    }

    /// Opens an in-memory database (`":memory:"`) with the schema applied.
    pub fn in_memory() -> io::Result<Self> {
        let db = Self::block(libsql::Builder::new_local(":memory:").build())?;
        Self::from_database(db, ":memory:".into())
    }

    /// Opens a remote libsql database at `url`, authenticating with `auth_token`.
    /// The label is the URL.
    pub fn open_remote(url: impl Into<String>, auth_token: impl Into<String>) -> io::Result<Self> {
        let url = url.into();
        let label = url.clone();
        let db = Self::block(libsql::Builder::new_remote(url, auth_token.into()).build())?;
        Self::from_database(db, label)
    }

    /// Opens an embedded replica stored at `path` that replicates from `url`.
    pub fn open_replica(
        path: impl Into<PathBuf>,
        url: impl Into<String>,
        auth_token: impl Into<String>,
    ) -> io::Result<Self> {
        let path = path.into();
        // Convert `url` once: the generic argument is consumed by `into()`, and it is
        // needed both for the label and for the builder.
        let url = url.into();
        let label = format!("{} (replica of {})", path.display(), url);
        let db = Self::block(
            libsql::Builder::new_remote_replica(path, url, auth_token.into()).build(),
        )?;
        Self::from_database(db, label)
    }

    /// Reads the complete persisted session from `conn`.
    ///
    /// Missing rows fall back to defaults:
    /// - no `home_dc_id` key yields `0`;
    /// - no `update_state` row yields `UpdatesStateSnap::default()`.
    ///
    /// # Errors
    /// Propagates query and decoding errors. Returns `libsql::Error::Misuse` when a
    /// stored `auth_key` blob is not exactly 256 bytes.
    async fn read_session_async(
        conn: &libsql::Connection,
    ) -> Result<PersistedSession, libsql::Error> {
        let home_dc_id: i32 = conn
            .query("SELECT value FROM meta WHERE key = 'home_dc_id'", ())
            .await?
            .next()
            .await?
            .map(|r| r.get::<i32>(0))
            .transpose()?
            .unwrap_or(0);

        let mut dc_rows = conn
            .query(
                "SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs",
                (),
            )
            .await?;
        let mut dcs = Vec::new();
        while let Some(row) = dc_rows.next().await? {
            let auth_key = match row.get::<Option<Vec<u8>>>(3)? {
                None => None,
                Some(blob) => Some(<[u8; 256]>::try_from(blob.as_slice()).map_err(|_| {
                    libsql::Error::Misuse(format!(
                        "auth_key blob must be 256 bytes, got {}",
                        blob.len()
                    ))
                })?),
            };
            dcs.push(DcEntry {
                dc_id: row.get(0)?,
                // Writers store `DcFlags.0` (a `u8`), so this narrowing cast is lossless
                // for rows written by this backend.
                flags: DcFlags(row.get::<i64>(1)? as u8),
                addr: row.get(2)?,
                auth_key,
                first_salt: row.get(4)?,
                time_offset: row.get(5)?,
            });
        }

        let mut state_rows = conn
            .query(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                (),
            )
            .await?;
        let updates_state = match state_rows.next().await? {
            Some(r) => UpdatesStateSnap {
                pts: r.get(0)?,
                qts: r.get(1)?,
                date: r.get(2)?,
                seq: r.get(3)?,
                channels: vec![],
            },
            None => UpdatesStateSnap::default(),
        };

        let mut ch_rows = conn
            .query("SELECT channel_id, pts FROM channel_pts", ())
            .await?;
        let mut channels = Vec::new();
        while let Some(r) = ch_rows.next().await? {
            channels.push((r.get::<i64>(0)?, r.get::<i32>(1)?));
        }

        let mut peer_rows = conn
            .query("SELECT id, access_hash, is_channel FROM peers", ())
            .await?;
        let mut peers = Vec::new();
        while let Some(r) = peer_rows.next().await? {
            peers.push(CachedPeer {
                id: r.get(0)?,
                access_hash: r.get(1)?,
                is_channel: r.get::<i32>(2)? != 0,
                is_chat: false,
            });
        }

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers: Vec::new(),
        })
    }

    /// Atomically replaces the stored session with `s`.
    ///
    /// The rows are written inside a `BEGIN IMMEDIATE` transaction. On any failure,
    /// including a failed `COMMIT`, a `ROLLBACK` is issued. This keeps the shared
    /// connection from being left inside an open transaction, where every later
    /// `BEGIN` would fail.
    async fn write_session_async(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute_batch("BEGIN IMMEDIATE").await?;
        let outcome = match Self::write_session_rows(conn, s).await {
            Ok(()) => conn.execute_batch("COMMIT").await,
            Err(e) => Err(e),
        };
        if outcome.is_err() {
            // Best effort: report the original error, not a rollback failure (e.g.
            // "no transaction is active" if SQLite already rolled back).
            let _ = conn.execute_batch("ROLLBACK").await;
        }
        outcome
    }

    /// Writes every table of `s`. Must run inside a transaction opened by the caller.
    async fn write_session_rows(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            libsql::params![s.home_dc_id],
        )
        .await?;

        conn.execute("DELETE FROM dcs", ()).await?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                libsql::params![
                    d.dc_id,
                    d.flags.0 as i64,
                    d.addr.clone(),
                    // Borrow the key; avoid copying the 256-byte array before `to_vec`.
                    d.auth_key.as_ref().map(|k| k.to_vec()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .await?;
        }

        // pts/qts only move forward: an older snapshot never lowers the stored values.
        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
             ON CONFLICT(id) DO UPDATE SET
                 pts = MAX(excluded.pts, update_state.pts),
                 qts = MAX(excluded.qts, update_state.qts),
                 date = excluded.date,
                 seq = excluded.seq",
            libsql::params![us.pts, us.qts, us.date, us.seq],
        )
        .await?;

        conn.execute("DELETE FROM channel_pts", ()).await?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1,?2)",
                libsql::params![cid, cpts],
            )
            .await?;
        }

        conn.execute("DELETE FROM peers", ()).await?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1,?2,?3)",
                libsql::params![p.id, p.access_hash, p.is_channel as i32],
            )
            .await?;
        }

        Ok(())
    }
}
1874
#[cfg(feature = "libsql-session")]
impl SessionBackend for LibSqlBackend {
    /// Persists `session`, replacing everything previously stored.
    ///
    /// The async block only borrows `self` and `session`. `Self::block` drives it to
    /// completion before returning, so no clone of the session is needed.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            Self::write_session_async(&conn, session).await
        })
    }

    /// Loads the stored session, or `Ok(None)` when the `meta` table is empty.
    ///
    /// The emptiness check and the read happen under one lock, so a concurrent `delete`
    /// cannot slip in between them. A failure to decode the count is reported, not
    /// silently treated as "no session".
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        Self::block(async {
            let conn = self.conn.lock().await;
            let count: i64 = {
                let mut rows = conn.query("SELECT COUNT(*) FROM meta", ()).await?;
                match rows.next().await? {
                    Some(row) => row.get(0)?,
                    None => 0,
                }
            };
            if count == 0 {
                return Ok(None);
            }
            Self::read_session_async(&conn).await.map(Some)
        })
    }

    /// Removes every row from all session tables (the tables themselves are kept).
    ///
    /// Runs in one `IMMEDIATE` transaction. If any step fails, a `ROLLBACK` is issued
    /// so the shared connection is not left inside an open transaction.
    fn delete(&self) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute_batch("BEGIN IMMEDIATE").await?;
            let outcome = match conn
                .execute_batch(
                    "DELETE FROM meta;
                     DELETE FROM dcs;
                     DELETE FROM update_state;
                     DELETE FROM channel_pts;
                     DELETE FROM peers;",
                )
                .await
            {
                Ok(_) => conn.execute_batch("COMMIT").await,
                Err(e) => Err(e),
            };
            if outcome.is_err() {
                // Best effort: report the original error rather than a rollback failure.
                let _ = conn.execute_batch("ROLLBACK").await;
            }
            outcome
        })
    }

    /// Human-readable description of this backend.
    fn name(&self) -> &str {
        &self.label
    }

    /// Inserts or updates the row keyed by `(entry.dc_id, entry.flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let (dc_id, addr, key, salt, off, flags) = (
            entry.dc_id,
            entry.addr.clone(),
            // Borrow the key; avoid copying the 256-byte array before `to_vec`.
            entry.auth_key.as_ref().map(|k| k.to_vec()),
            entry.first_salt,
            entry.time_offset,
            entry.flags.0 as i64,
        );
        Self::block(async {
            let conn = self.conn.lock().await;
            // Note: placeholder ?6 is `flags`; the VALUES list is in column order.
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1,?6,?2,?3,?4,?5)
                 ON CONFLICT(dc_id, flags) DO UPDATE SET
                 addr=excluded.addr, auth_key=excluded.auth_key,
                 first_salt=excluded.first_salt, time_offset=excluded.time_offset",
                libsql::params![dc_id, addr, key, salt, off, flags],
            )
            .await
            .map(|_| ())
        })
    }

    /// Upserts the `home_dc_id` entry in the `meta` table.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO meta (key, value) VALUES ('home_dc_id',?1)
                 ON CONFLICT(key) DO UPDATE SET value=excluded.value",
                libsql::params![dc_id],
            )
            .await
            .map(|_| ())
        })
    }

    /// Applies one incremental update-state change to the single `update_state` row
    /// (id = 1) or to `channel_pts`.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        Self::block(async move {
            let conn = self.conn.lock().await;
            match update {
                UpdateStateChange::All(snap) => {
                    // `channel_pts` is cleared and rebuilt; do it atomically so a failure
                    // part-way through cannot leave the table empty or half-filled.
                    conn.execute_batch("BEGIN IMMEDIATE").await?;
                    let written = async {
                        conn.execute(
                            "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,?2,?3,?4)
                             ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,qts=excluded.qts,
                             date=excluded.date,seq=excluded.seq",
                            libsql::params![snap.pts, snap.qts, snap.date, snap.seq],
                        )
                        .await?;
                        conn.execute("DELETE FROM channel_pts", ()).await?;
                        for &(cid, cpts) in &snap.channels {
                            conn.execute(
                                "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)",
                                libsql::params![cid, cpts],
                            )
                            .await?;
                        }
                        Ok::<(), libsql::Error>(())
                    }
                    .await;
                    let outcome = match written {
                        Ok(()) => conn.execute_batch("COMMIT").await,
                        Err(e) => Err(e),
                    };
                    if outcome.is_err() {
                        // Best effort: report the original error rather than a rollback failure.
                        let _ = conn.execute_batch("ROLLBACK").await;
                    }
                    outcome
                }
                UpdateStateChange::Primary { pts, date, seq } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,0,?2,?3)
                         ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,date=excluded.date,
                         seq=excluded.seq",
                        libsql::params![pts, date, seq],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Secondary { qts } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,0,?1,0,0)
                         ON CONFLICT(id) DO UPDATE SET qts=excluded.qts",
                        libsql::params![qts],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Channel { id, pts } => conn
                    .execute(
                        "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)
                         ON CONFLICT(channel_id) DO UPDATE SET pts=excluded.pts",
                        libsql::params![id, pts],
                    )
                    .await
                    .map(|_| ()),
            }
        })
    }

    /// Inserts or updates a cached peer keyed by `peer.id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let (id, hash, is_ch) = (peer.id, peer.access_hash, peer.is_channel as i32);
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO peers (id,access_hash,is_channel) VALUES (?1,?2,?3)
                 ON CONFLICT(id) DO UPDATE SET
                 access_hash=excluded.access_hash,
                 is_channel=excluded.is_channel",
                libsql::params![id, hash, is_ch],
            )
            .await
            .map(|_| ())
        })
    }
}
2034}