1#![deny(unsafe_code)]
14#![cfg_attr(docsrs, feature(doc_cfg))]
15use std::collections::HashMap;
76use std::io::{self, ErrorKind};
77use std::path::Path;
78
#[cfg(feature = "serde")]
mod auth_key_serde {
    //! `serde(with = ...)` adapter for `Option<[u8; 256]>`.
    //!
    //! The key is written as an optional byte sequence and read back as an
    //! optional `Vec<u8>` that must be exactly 256 bytes long.
    use serde::{Deserialize, Deserializer, Serializer};

    /// Writes `Some(key)` as an optional byte slice, `None` as a serde none.
    pub fn serialize<S>(value: &Option<[u8; 256]>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(key) = value {
            s.serialize_some(&key[..])
        } else {
            s.serialize_none()
        }
    }

    /// Reads an optional byte sequence and converts it into a fixed 256-byte
    /// key, failing with a custom error for any other length.
    pub fn deserialize<'de, D>(d: D) -> Result<Option<[u8; 256]>, D::Error>
    where
        D: Deserializer<'de>,
    {
        Option::<Vec<u8>>::deserialize(d)?
            .map(|bytes| {
                <[u8; 256]>::try_from(bytes).map_err(|_| {
                    <D::Error as serde::de::Error>::custom("auth_key must be exactly 256 bytes")
                })
            })
            .transpose()
    }
}
109
/// Bit set describing properties of a DC address entry.
///
/// The raw byte (`.0`) is stored as-is; it is the flags byte the binary
/// session format writes for each DC entry.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// No bits set.
    pub const NONE: DcFlags = DcFlags(0b0000_0000);
    /// The entry's address is IPv6 (checked by `DcEntry::is_ipv6`).
    pub const IPV6: DcFlags = DcFlags(0b0000_0001);
    /// NOTE(review): presumably marks a media-only DC address — confirm.
    pub const MEDIA_ONLY: DcFlags = DcFlags(0b0000_0010);
    /// NOTE(review): presumably marks an address usable only via obfuscated TCP — confirm.
    pub const TCPO_ONLY: DcFlags = DcFlags(0b0000_0100);
    /// NOTE(review): presumably marks a CDN DC — confirm.
    pub const CDN: DcFlags = DcFlags(0b0000_1000);
    /// NOTE(review): presumably marks a static (built-in) address — confirm.
    pub const STATIC: DcFlags = DcFlags(0b0001_0000);

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// Consequently `contains(DcFlags::NONE)` is always `true`.
    pub fn contains(self, other: DcFlags) -> bool {
        let shared = self.0 & other.0;
        shared == other.0
    }

    /// Turns on every bit of `flag`, leaving the other bits unchanged.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = Self;

    /// Union of both bit sets.
    fn bitor(self, rhs: Self) -> Self::Output {
        Self(self.0 | rhs.0)
    }
}
140
141#[derive(Clone, Debug)]
143#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
144pub struct DcEntry {
145 pub dc_id: i32,
146 pub addr: String,
147 #[cfg_attr(feature = "serde", serde(with = "auth_key_serde"))]
148 pub auth_key: Option<[u8; 256]>,
149 pub first_salt: i64,
150 pub time_offset: i32,
151 pub flags: DcFlags,
153}
154
155impl DcEntry {
156 #[inline]
158 pub fn is_ipv6(&self) -> bool {
159 self.flags.contains(DcFlags::IPV6)
160 }
161
162 pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
169 self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
170 io::Error::new(
171 io::ErrorKind::InvalidData,
172 format!("invalid DC address: {:?}", self.addr),
173 )
174 })
175 }
176
177 pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
190 let addr = if ip.contains(':') {
192 format!("[{ip}]:{port}")
193 } else {
194 format!("{ip}:{port}")
195 };
196 Self {
197 dc_id,
198 addr,
199 auth_key: None,
200 first_salt: 0,
201 time_offset: 0,
202 flags,
203 }
204 }
205}
206
/// Snapshot of the update-sequence counters stored with the session.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Main `pts` counter; `0` means "never initialised".
    pub pts: i32,
    /// `qts` counter.
    pub qts: i32,
    /// `date` counter.
    pub date: i32,
    /// `seq` counter.
    pub seq: i32,
    /// Per-channel pts as `(channel_id, pts)` pairs, at most one per channel
    /// when maintained through [`UpdatesStateSnap::set_channel_pts`].
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// Returns `true` once `pts` has been set to a positive value.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        0 < self.pts
    }

    /// Records `pts` for `channel_id`, overwriting the first existing entry
    /// for that channel or appending a new one.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        let slot = self.channels.iter().position(|&(id, _)| id == channel_id);
        match slot {
            Some(idx) => self.channels[idx].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Returns the recorded pts for `channel_id`, or `0` if none is stored.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        self.channels
            .iter()
            .find_map(|&(id, pts)| (id == channel_id).then_some(pts))
            .unwrap_or(0)
    }
}
249
/// A peer whose access hash is cached in the session.
///
/// In the binary format (`PersistedSession::to_bytes`) the two flags are
/// stored as one type byte: `2` if `is_chat`, else `1` if `is_channel`,
/// else `0`. If both flags are set, the peer therefore reloads as a chat only.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Peer identifier; backends use it as the de-duplication key.
    pub id: i64,
    /// Access hash cached for `id`.
    /// NOTE(review): presumably the value required to reference this peer in API calls — confirm.
    pub access_hash: i64,
    /// Peer is a channel.
    pub is_channel: bool,
    /// Peer is a (basic) chat.
    pub is_chat: bool,
}
266
/// Cached reference to a user together with a peer and message id.
///
/// NOTE(review): presumably a "min" user (one seen without an access hash)
/// together with the message in which it was seen, so it can be resolved
/// later — confirm against the callers.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedMinPeer {
    /// User identifier.
    pub user_id: i64,
    /// Peer (chat/channel) identifier associated with the user.
    pub peer_id: i64,
    /// Message identifier associated with the user.
    pub msg_id: i32,
}
280
/// Everything a session backend stores: DC addresses and keys, update
/// counters, and the peer caches.
///
/// `Default` yields an empty session with `home_dc_id == 0`.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PersistedSession {
    /// DC the session is homed on.
    pub home_dc_id: i32,
    /// Known DC entries; a DC may have one entry per address family.
    pub dcs: Vec<DcEntry>,
    /// Update-sequence counters.
    pub updates_state: UpdatesStateSnap,
    /// Cached peers with their access hashes.
    pub peers: Vec<CachedPeer>,
    /// Cached min-peer references (present in binary format version 4+).
    pub min_peers: Vec<CachedMinPeer>,
}
296
297impl PersistedSession {
298 pub fn to_bytes(&self) -> Vec<u8> {
300 let mut b = Vec::with_capacity(512);
301
302 b.push(0x05u8); b.extend_from_slice(&self.home_dc_id.to_le_bytes());
305
306 b.push(self.dcs.len() as u8);
307 for d in &self.dcs {
308 b.extend_from_slice(&d.dc_id.to_le_bytes());
309 match &d.auth_key {
310 Some(k) => {
311 b.push(1);
312 b.extend_from_slice(k);
313 }
314 None => {
315 b.push(0);
316 }
317 }
318 b.extend_from_slice(&d.first_salt.to_le_bytes());
319 b.extend_from_slice(&d.time_offset.to_le_bytes());
320 let ab = d.addr.as_bytes();
321 b.push(ab.len() as u8);
322 b.extend_from_slice(ab);
323 b.push(d.flags.0);
324 }
325
326 b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
327 b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
328 b.extend_from_slice(&self.updates_state.date.to_le_bytes());
329 b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
330 let ch = &self.updates_state.channels;
331 b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
332 for &(cid, cpts) in ch {
333 b.extend_from_slice(&cid.to_le_bytes());
334 b.extend_from_slice(&cpts.to_le_bytes());
335 }
336
337 b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
339 for p in &self.peers {
340 b.extend_from_slice(&p.id.to_le_bytes());
341 b.extend_from_slice(&p.access_hash.to_le_bytes());
342 let peer_type: u8 = if p.is_chat {
343 2
344 } else if p.is_channel {
345 1
346 } else {
347 0
348 };
349 b.push(peer_type);
350 }
351
352 b.extend_from_slice(&(self.min_peers.len() as u16).to_le_bytes());
353 for m in &self.min_peers {
354 b.extend_from_slice(&m.user_id.to_le_bytes());
355 b.extend_from_slice(&m.peer_id.to_le_bytes());
356 b.extend_from_slice(&m.msg_id.to_le_bytes());
357 }
358
359 b
360 }
361
362 pub fn save(&self, path: &Path) -> io::Result<()> {
369 use std::sync::atomic::{AtomicU64, Ordering};
370 static SEQ: AtomicU64 = AtomicU64::new(0);
371 let n = SEQ.fetch_add(1, Ordering::Relaxed);
372 let tmp = path.with_extension(format!("{n}.tmp"));
373 std::fs::write(&tmp, self.to_bytes())?;
374 std::fs::rename(&tmp, path).inspect_err(|_e| {
375 let _ = std::fs::remove_file(&tmp);
376 })
377 }
378
379 pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
381 if buf.is_empty() {
382 return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
383 }
384
385 let mut p = 0usize;
386
387 macro_rules! r {
388 ($n:expr) => {{
389 if p + $n > buf.len() {
390 return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
391 }
392 let s = &buf[p..p + $n];
393 p += $n;
394 s
395 }};
396 }
397 macro_rules! r_i32 {
398 () => {
399 i32::from_le_bytes(r!(4).try_into().unwrap())
400 };
401 }
402 macro_rules! r_i64 {
403 () => {
404 i64::from_le_bytes(r!(8).try_into().unwrap())
405 };
406 }
407 macro_rules! r_u8 {
408 () => {
409 r!(1)[0]
410 };
411 }
412 macro_rules! r_u16 {
413 () => {
414 u16::from_le_bytes(r!(2).try_into().unwrap())
415 };
416 }
417
418 let first_byte = r_u8!();
419
420 let (home_dc_id, version) = if first_byte == 0x05 {
421 (r_i32!(), 5u8)
422 } else if first_byte == 0x04 {
423 (r_i32!(), 4u8)
424 } else if first_byte == 0x03 {
425 (r_i32!(), 3u8)
426 } else if first_byte == 0x02 {
427 (r_i32!(), 2u8)
428 } else {
429 let rest = r!(3);
430 let mut bytes = [0u8; 4];
431 bytes[0] = first_byte;
432 bytes[1..4].copy_from_slice(rest);
433 (i32::from_le_bytes(bytes), 1u8)
434 };
435
436 let dc_count = r_u8!() as usize;
437 let mut dcs = Vec::with_capacity(dc_count);
438 for _ in 0..dc_count {
439 let dc_id = r_i32!();
440 let has_key = r_u8!();
441 let auth_key = if has_key == 1 {
442 let mut k = [0u8; 256];
443 k.copy_from_slice(r!(256));
444 Some(k)
445 } else {
446 None
447 };
448 let first_salt = r_i64!();
449 let time_offset = r_i32!();
450 let al = r_u8!() as usize;
451 let addr = String::from_utf8_lossy(r!(al)).into_owned();
452 let flags = if version >= 3 {
453 DcFlags(r_u8!())
454 } else {
455 DcFlags::NONE
456 };
457 dcs.push(DcEntry {
458 dc_id,
459 addr,
460 auth_key,
461 first_salt,
462 time_offset,
463 flags,
464 });
465 }
466
467 if version < 2 {
468 return Ok(Self {
469 home_dc_id,
470 dcs,
471 updates_state: UpdatesStateSnap::default(),
472 peers: Vec::new(),
473 min_peers: Vec::new(),
474 });
475 }
476
477 let pts = r_i32!();
478 let qts = r_i32!();
479 let date = r_i32!();
480 let seq = r_i32!();
481 let ch_count = r_u16!() as usize;
482 let mut channels = Vec::with_capacity(ch_count);
483 for _ in 0..ch_count {
484 let cid = r_i64!();
485 let cpts = r_i32!();
486 channels.push((cid, cpts));
487 }
488
489 let peer_count = r_u16!() as usize;
490 let mut peers = Vec::with_capacity(peer_count);
491 for _ in 0..peer_count {
492 let id = r_i64!();
493 let access_hash = r_i64!();
494 let peer_type = r_u8!();
496 let is_channel = peer_type == 1;
497 let is_chat = peer_type == 2;
498 peers.push(CachedPeer {
499 id,
500 access_hash,
501 is_channel,
502 is_chat,
503 });
504 }
505
506 let min_peers = if version >= 4 {
508 let count = r_u16!() as usize;
509 let mut v = Vec::with_capacity(count);
510 for _ in 0..count {
511 let user_id = r_i64!();
512 let peer_id = r_i64!();
513 let msg_id = r_i32!();
514 v.push(CachedMinPeer {
515 user_id,
516 peer_id,
517 msg_id,
518 });
519 }
520 v
521 } else {
522 Vec::new()
523 };
524
525 Ok(Self {
526 home_dc_id,
527 dcs,
528 updates_state: UpdatesStateSnap {
529 pts,
530 qts,
531 date,
532 seq,
533 channels,
534 },
535 peers,
536 min_peers,
537 })
538 }
539
540 pub fn from_string(s: &str) -> io::Result<Self> {
542 use base64::Engine as _;
543 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
544 .decode(s.trim())
545 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
546 Self::from_bytes(&bytes)
547 }
548
549 pub fn load(path: &Path) -> io::Result<Self> {
550 let buf = std::fs::read(path)?;
551 Self::from_bytes(&buf)
552 }
553
554 pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
566 let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
567 candidates.peek()?;
568 let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
570 cands
572 .iter()
573 .copied()
574 .find(|d| d.is_ipv6() == prefer_ipv6)
575 .or_else(|| cands.first().copied())
576 }
577
578 pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
583 self.dcs.iter().filter(move |d| d.dc_id == dc_id)
584 }
585}
586
587impl std::fmt::Display for PersistedSession {
588 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
589 use base64::Engine as _;
590 f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
591 }
592}
593
/// Built-in fallback IPv4 addresses (port 443) for DCs 1 through 5,
/// keyed by DC id.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const DEFAULTS: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(DEFAULTS.len());
    for (dc_id, addr) in DEFAULTS {
        table.insert(dc_id, String::from(addr));
    }
    table
}
607
608use std::path::PathBuf;
613
614pub trait SessionBackend: Send + Sync {
622 fn save(&self, session: &PersistedSession) -> io::Result<()>;
623 fn load(&self) -> io::Result<Option<PersistedSession>>;
624 fn delete(&self) -> io::Result<()>;
625
626 fn name(&self) -> &str;
628
629 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
642 let mut s = self.load()?.unwrap_or_default();
643 if let Some(existing) = s
645 .dcs
646 .iter_mut()
647 .find(|d| d.dc_id == entry.dc_id && d.is_ipv6() == entry.is_ipv6())
648 {
649 *existing = entry.clone();
650 } else {
651 s.dcs.push(entry.clone());
652 }
653 self.save(&s)
654 }
655
656 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
662 let mut s = self.load()?.unwrap_or_default();
663 s.home_dc_id = dc_id;
664 self.save(&s)
665 }
666
667 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
672 let mut s = self.load()?.unwrap_or_default();
673 update.apply_to(&mut s.updates_state);
674 self.save(&s)
675 }
676
677 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
683 let mut s = self.load()?.unwrap_or_default();
684 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
685 *existing = peer.clone();
686 } else {
687 s.peers.push(peer.clone());
688 }
689 self.save(&s)
690 }
691}
692
693#[derive(Debug, Clone)]
707pub enum UpdateStateChange {
708 All(UpdatesStateSnap),
710 Primary { pts: i32, date: i32, seq: i32 },
712 Secondary { qts: i32 },
714 Channel { id: i64, pts: i32 },
716}
717
718impl UpdateStateChange {
719 pub fn apply_to(&self, snap: &mut UpdatesStateSnap) {
721 match self {
722 Self::All(new_snap) => *snap = new_snap.clone(),
723 Self::Primary { pts, date, seq } => {
724 snap.pts = *pts;
725 snap.date = *date;
726 snap.seq = *seq;
727 }
728 Self::Secondary { qts } => {
729 snap.qts = *qts;
730 }
731 Self::Channel { id, pts } => {
732 if let Some(existing) = snap.channels.iter_mut().find(|c| c.0 == *id) {
734 existing.1 = *pts;
735 } else {
736 snap.channels.push((*id, *pts));
737 }
738 }
739 }
740 }
741}
742
743pub struct BinaryFileBackend {
747 path: PathBuf,
748 write_lock: std::sync::Mutex<()>,
753}
754
755impl BinaryFileBackend {
756 pub fn new(path: impl Into<PathBuf>) -> Self {
757 Self {
758 path: path.into(),
759 write_lock: std::sync::Mutex::new(()),
760 }
761 }
762
763 pub fn path(&self) -> &std::path::Path {
764 &self.path
765 }
766}
767
768impl SessionBackend for BinaryFileBackend {
769 fn save(&self, session: &PersistedSession) -> io::Result<()> {
770 let _guard = self.write_lock.lock().unwrap();
771 session.save(&self.path)
772 }
773
774 fn load(&self) -> io::Result<Option<PersistedSession>> {
775 if !self.path.exists() {
776 return Ok(None);
777 }
778 match PersistedSession::load(&self.path) {
779 Ok(s) => Ok(Some(s)),
780 Err(e) => {
781 let bak = self.path.with_extension("bak");
782 tracing::warn!(
783 "[ferogram] Session file {:?} is corrupt ({e}); \
784 renaming to {:?} and starting fresh",
785 self.path,
786 bak
787 );
788 let _ = std::fs::rename(&self.path, &bak);
789 Ok(None)
790 }
791 }
792 }
793
794 fn delete(&self) -> io::Result<()> {
795 if self.path.exists() {
796 std::fs::remove_file(&self.path)?;
797 }
798 Ok(())
799 }
800
801 fn name(&self) -> &str {
802 "binary-file"
803 }
804
805 }
808
809#[derive(Default)]
817pub struct InMemoryBackend {
818 data: std::sync::Mutex<Option<PersistedSession>>,
819}
820
821impl InMemoryBackend {
822 pub fn new() -> Self {
823 Self::default()
824 }
825
826 pub fn snapshot(&self) -> Option<PersistedSession> {
828 self.data.lock().unwrap().clone()
829 }
830}
831
832impl SessionBackend for InMemoryBackend {
833 fn save(&self, s: &PersistedSession) -> io::Result<()> {
834 *self.data.lock().unwrap() = Some(s.clone());
835 Ok(())
836 }
837
838 fn load(&self) -> io::Result<Option<PersistedSession>> {
839 Ok(self.data.lock().unwrap().clone())
840 }
841
842 fn delete(&self) -> io::Result<()> {
843 *self.data.lock().unwrap() = None;
844 Ok(())
845 }
846
847 fn name(&self) -> &str {
848 "in-memory"
849 }
850
851 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
854 let mut guard = self.data.lock().unwrap();
855 let s = guard.get_or_insert_with(PersistedSession::default);
856 if let Some(existing) = s
857 .dcs
858 .iter_mut()
859 .find(|d| d.dc_id == entry.dc_id && d.is_ipv6() == entry.is_ipv6())
860 {
861 *existing = entry.clone();
862 } else {
863 s.dcs.push(entry.clone());
864 }
865 Ok(())
866 }
867
868 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
869 let mut guard = self.data.lock().unwrap();
870 let s = guard.get_or_insert_with(PersistedSession::default);
871 s.home_dc_id = dc_id;
872 Ok(())
873 }
874
875 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
876 let mut guard = self.data.lock().unwrap();
877 let s = guard.get_or_insert_with(PersistedSession::default);
878 update.apply_to(&mut s.updates_state);
879 Ok(())
880 }
881
882 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
883 let mut guard = self.data.lock().unwrap();
884 let s = guard.get_or_insert_with(PersistedSession::default);
885 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
886 *existing = peer.clone();
887 } else {
888 s.peers.push(peer.clone());
889 }
890 Ok(())
891 }
892}
893
894pub struct StringSessionBackend {
898 data: std::sync::Mutex<String>,
899}
900
901impl StringSessionBackend {
902 pub fn new(s: impl Into<String>) -> Self {
903 Self {
904 data: std::sync::Mutex::new(s.into()),
905 }
906 }
907
908 pub fn current(&self) -> String {
909 self.data.lock().unwrap().clone()
910 }
911}
912
913impl SessionBackend for StringSessionBackend {
914 fn save(&self, session: &PersistedSession) -> io::Result<()> {
915 *self.data.lock().unwrap() = session.to_string();
916 Ok(())
917 }
918
919 fn load(&self) -> io::Result<Option<PersistedSession>> {
920 let s = self.data.lock().unwrap().clone();
921 if s.trim().is_empty() {
922 return Ok(None);
923 }
924 PersistedSession::from_string(&s).map(Some)
925 }
926
927 fn delete(&self) -> io::Result<()> {
928 *self.data.lock().unwrap() = String::new();
929 Ok(())
930 }
931
932 fn name(&self) -> &str {
933 "string-session"
934 }
935}
936
#[cfg(test)]
mod tests {
    use super::*;

    /// IPv4 entry `1.2.3.<id>:443` with no key and no flags.
    fn make_dc(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("1.2.3.{id}:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::NONE,
        }
    }

    /// IPv6 entry for `id` with the `IPV6` flag set.
    fn make_dc_v6(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("[2001:b28:f23d:f00{id}::a]:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::IPV6,
        }
    }

    /// Peer that is neither a channel nor a chat.
    fn make_peer(id: i64, hash: i64) -> CachedPeer {
        CachedPeer {
            id,
            access_hash: hash,
            is_channel: false,
            is_chat: false,
        }
    }

    // ---- InMemoryBackend ----

    #[test]
    fn inmemory_load_returns_none_when_empty() {
        let b = InMemoryBackend::new();
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_save_then_load_round_trips() {
        let b = InMemoryBackend::new();
        let s = PersistedSession {
            home_dc_id: 3,
            dcs: vec![make_dc(3)],
            ..Default::default()
        };
        b.save(&s).unwrap();

        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.home_dc_id, 3);
        assert_eq!(loaded.dcs.len(), 1);
    }

    #[test]
    fn inmemory_delete_clears_state() {
        let b = InMemoryBackend::new();
        let s = PersistedSession {
            home_dc_id: 2,
            ..Default::default()
        };
        b.save(&s).unwrap();
        b.delete().unwrap();
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_update_dc_inserts_new() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(4)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].dc_id, 4);
    }

    #[test]
    fn inmemory_update_dc_replaces_existing() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        let mut updated = make_dc(2);
        updated.addr = "9.9.9.9:443".to_string();
        b.update_dc(&updated).unwrap();

        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].addr, "9.9.9.9:443");
    }

    #[test]
    fn inmemory_set_home_dc() {
        let b = InMemoryBackend::new();
        b.set_home_dc(5).unwrap();
        assert_eq!(b.snapshot().unwrap().home_dc_id, 5);
    }

    #[test]
    fn inmemory_cache_peer_inserts() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 0xdeadbeef)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].id, 100);
    }

    #[test]
    fn inmemory_cache_peer_updates_existing() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 111)).unwrap();
        b.cache_peer(&make_peer(100, 222)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].access_hash, 222);
    }

    #[test]
    fn inmemory_ipv4_and_ipv6_coexist() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        b.update_dc(&make_dc_v6(2)).unwrap();
        let s = b.snapshot().unwrap();
        // Same dc_id, different address family: both entries are kept.
        assert_eq!(s.dcs.iter().filter(|d| d.dc_id == 2).count(), 2);
    }

    // ---- UpdateStateChange ----

    #[test]
    fn update_state_primary() {
        let mut snap = UpdatesStateSnap::default();
        UpdateStateChange::Primary {
            pts: 10,
            date: 20,
            seq: 30,
        }
        .apply_to(&mut snap);
        assert_eq!(snap.pts, 10);
        assert_eq!(snap.date, 20);
        assert_eq!(snap.seq, 30);
        assert_eq!(snap.qts, 0); // untouched
    }

    #[test]
    fn update_state_secondary() {
        let mut snap = UpdatesStateSnap {
            pts: 5,
            ..Default::default()
        };
        UpdateStateChange::Secondary { qts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.qts, 99);
        assert_eq!(snap.pts, 5); // untouched
    }

    #[test]
    fn update_state_channel_inserts() {
        let mut snap = UpdatesStateSnap::default();
        UpdateStateChange::Channel { id: 12345, pts: 42 }.apply_to(&mut snap);
        assert_eq!(snap.channels, vec![(12345, 42)]);
    }

    #[test]
    fn update_state_channel_updates_existing() {
        let mut snap = UpdatesStateSnap {
            channels: vec![(12345, 10), (67890, 5)],
            ..Default::default()
        };
        UpdateStateChange::Channel { id: 12345, pts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.channels[0], (12345, 99));
        // Other channels are untouched.
        assert_eq!(snap.channels[1], (67890, 5));
    }

    #[test]
    fn apply_update_state_via_backend() {
        let b = InMemoryBackend::new();
        b.apply_update_state(UpdateStateChange::Primary {
            pts: 7,
            date: 8,
            seq: 9,
        })
        .unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.updates_state.pts, 7);
    }

    #[test]
    fn default_update_dc_via_trait_object() {
        let b: Box<dyn SessionBackend> = Box::new(InMemoryBackend::new());
        b.update_dc(&make_dc(1)).unwrap();
        b.update_dc(&make_dc(2)).unwrap();
        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.dcs.len(), 2);
    }

    // ---- DcEntry / DC selection ----

    #[test]
    fn dc_entry_from_parts_ipv4() {
        let dc = DcEntry::from_parts(1, "149.154.175.53", 443, DcFlags::NONE);
        assert_eq!(dc.addr, "149.154.175.53:443");
        assert!(!dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    #[test]
    fn dc_entry_from_parts_ipv6() {
        let dc = DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6);
        assert_eq!(dc.addr, "[2001:b28:f23d:f001::a]:443");
        assert!(dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    #[test]
    fn persisted_session_dc_for_prefers_ipv6() {
        let s = PersistedSession {
            dcs: vec![make_dc(2), make_dc_v6(2)],
            ..Default::default()
        };
        let v6 = s.dc_for(2, true).unwrap();
        assert!(v6.is_ipv6());

        let v4 = s.dc_for(2, false).unwrap();
        assert!(!v4.is_ipv6());
    }

    #[test]
    fn persisted_session_dc_for_falls_back_when_only_ipv4() {
        let s = PersistedSession {
            dcs: vec![make_dc(3)],
            ..Default::default()
        };
        let dc = s.dc_for(3, true).unwrap();
        assert!(!dc.is_ipv6());
    }

    #[test]
    fn persisted_session_dc_for_unknown_dc_is_none() {
        assert!(PersistedSession::default().dc_for(1, false).is_none());
    }

    #[test]
    fn persisted_session_all_dcs_for_returns_both() {
        let s = PersistedSession {
            dcs: vec![make_dc(1), make_dc_v6(1), make_dc(2)],
            ..Default::default()
        };

        assert_eq!(s.all_dcs_for(1).count(), 2);
        assert_eq!(s.all_dcs_for(2).count(), 1);
        assert_eq!(s.all_dcs_for(5).count(), 0);
    }

    // ---- Binary format ----

    #[test]
    fn binary_roundtrip_ipv4_and_ipv6() {
        let s = PersistedSession {
            home_dc_id: 2,
            dcs: vec![make_dc(2), make_dc_v6(2)],
            ..Default::default()
        };

        let bytes = s.to_bytes();
        let loaded = PersistedSession::from_bytes(&bytes).unwrap();
        assert_eq!(loaded.dcs.len(), 2);
        assert_eq!(loaded.dcs.iter().filter(|d| d.is_ipv6()).count(), 1);
        assert_eq!(loaded.dcs.iter().filter(|d| !d.is_ipv6()).count(), 1);
    }

    #[test]
    fn binary_roundtrip_preserves_every_field() {
        let mut dc = make_dc(4);
        dc.auth_key = Some([7u8; 256]);
        dc.first_salt = -42;
        dc.time_offset = 13;
        dc.flags = DcFlags::MEDIA_ONLY | DcFlags::CDN;
        let s = PersistedSession {
            home_dc_id: 4,
            dcs: vec![dc],
            updates_state: UpdatesStateSnap {
                pts: 1,
                qts: 2,
                date: 3,
                seq: 4,
                channels: vec![(-100, 5), (200, 6)],
            },
            peers: vec![
                make_peer(1, 11),
                CachedPeer {
                    id: 2,
                    access_hash: 22,
                    is_channel: true,
                    is_chat: false,
                },
                CachedPeer {
                    id: 3,
                    access_hash: 0,
                    is_channel: false,
                    is_chat: true,
                },
            ],
            min_peers: vec![CachedMinPeer {
                user_id: 9,
                peer_id: 8,
                msg_id: 7,
            }],
        };

        let loaded = PersistedSession::from_bytes(&s.to_bytes()).unwrap();
        assert_eq!(loaded.home_dc_id, 4);

        let d = &loaded.dcs[0];
        assert_eq!(d.dc_id, 4);
        assert_eq!(d.addr, "1.2.3.4:443");
        assert_eq!(d.auth_key, Some([7u8; 256]));
        assert_eq!(d.first_salt, -42);
        assert_eq!(d.time_offset, 13);
        assert_eq!(d.flags, DcFlags::MEDIA_ONLY | DcFlags::CDN);

        let us = &loaded.updates_state;
        assert_eq!((us.pts, us.qts, us.date, us.seq), (1, 2, 3, 4));
        assert_eq!(us.channels, vec![(-100, 5), (200, 6)]);

        let kinds: Vec<(i64, i64, bool, bool)> = loaded
            .peers
            .iter()
            .map(|p| (p.id, p.access_hash, p.is_channel, p.is_chat))
            .collect();
        assert_eq!(
            kinds,
            vec![(1, 11, false, false), (2, 22, true, false), (3, 0, false, true)]
        );

        assert_eq!(loaded.min_peers.len(), 1);
        let m = &loaded.min_peers[0];
        assert_eq!((m.user_id, m.peer_id, m.msg_id), (9, 8, 7));
    }

    #[test]
    fn from_bytes_rejects_empty_and_truncated_input() {
        let empty: &[u8] = &[];
        assert_eq!(
            PersistedSession::from_bytes(empty).unwrap_err().kind(),
            std::io::ErrorKind::InvalidData
        );

        let bytes = PersistedSession {
            home_dc_id: 2,
            dcs: vec![make_dc(2)],
            ..Default::default()
        }
        .to_bytes();
        let err = PersistedSession::from_bytes(&bytes[..bytes.len() - 1]).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    // ---- String and file backends ----

    #[test]
    fn string_session_backend_round_trips() {
        let b = StringSessionBackend::new("");
        assert!(b.load().unwrap().is_none());

        let s = PersistedSession {
            home_dc_id: 5,
            dcs: vec![make_dc(5)],
            ..Default::default()
        };
        b.save(&s).unwrap();

        let reloaded = StringSessionBackend::new(b.current());
        let loaded = reloaded.load().unwrap().unwrap();
        assert_eq!(loaded.home_dc_id, 5);
        assert_eq!(loaded.dcs.len(), 1);

        b.delete().unwrap();
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn binary_file_backend_save_load_delete() {
        let path = std::env::temp_dir().join(format!(
            "ferogram-session-test-{}.bin",
            std::process::id()
        ));
        let b = BinaryFileBackend::new(path.clone());
        let _ = b.delete();
        assert!(b.load().unwrap().is_none());

        b.set_home_dc(3).unwrap();
        b.update_dc(&make_dc(3)).unwrap();
        b.cache_peer(&make_peer(42, 4242)).unwrap();

        let s = b.load().unwrap().unwrap();
        assert_eq!(s.home_dc_id, 3);
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.peers[0].access_hash, 4242);

        b.delete().unwrap();
        assert!(b.load().unwrap().is_none());
        assert!(!path.exists());
    }
}
1225
/// Session backend stored in a SQLite database (file-backed or in-memory).
///
/// All access goes through the single connection in `conn`, guarded by a
/// mutex.
///
/// NOTE(review): `min_peers` are not stored by this backend; `read_session`
/// always returns an empty list.
#[cfg(feature = "sqlite-session")]
pub struct SqliteBackend {
    conn: std::sync::Mutex<rusqlite::Connection>,
    label: String,
}

#[cfg(feature = "sqlite-session")]
impl SqliteBackend {
    /// Schema for fresh databases. Columns added after the first release
    /// are also added to existing databases by [`Self::migrate`].
    const SCHEMA: &'static str = "
        PRAGMA journal_mode = WAL;
        PRAGMA synchronous  = NORMAL;

        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );

        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );

        CREATE TABLE IF NOT EXISTS peers (
            id          INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel  INTEGER NOT NULL DEFAULT 0,
            is_chat     INTEGER NOT NULL DEFAULT 0
        );
    ";

    /// Opens (creating if needed) the database at `path`, then applies the
    /// schema and migrations.
    pub fn open(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let conn = rusqlite::Connection::open(&path).map_err(Self::map_err)?;
        Self::init(&conn)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label,
        })
    }

    /// Opens a private in-memory database, labelled `":memory:"`.
    pub fn in_memory() -> io::Result<Self> {
        let conn = rusqlite::Connection::open_in_memory().map_err(Self::map_err)?;
        Self::init(&conn)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label: ":memory:".into(),
        })
    }

    /// Applies [`Self::SCHEMA`] followed by [`Self::migrate`].
    fn init(conn: &rusqlite::Connection) -> io::Result<()> {
        conn.execute_batch(Self::SCHEMA).map_err(Self::map_err)?;
        Self::migrate(conn).map_err(Self::map_err)
    }

    /// Brings databases created by older schema versions up to date.
    ///
    /// Currently this adds `peers.is_chat`; without it, chats are reloaded
    /// as plain users.
    fn migrate(conn: &rusqlite::Connection) -> rusqlite::Result<()> {
        let has_is_chat = {
            let mut stmt = conn.prepare("PRAGMA table_info(peers)")?;
            // Column 1 of `table_info` is the column name. Collect fully so
            // the statement is finished before the schema is altered.
            let names = stmt
                .query_map([], |r| r.get::<_, String>(1))?
                .collect::<rusqlite::Result<Vec<String>>>()?;
            let found = names.iter().any(|n| n == "is_chat");
            found
        };
        if !has_is_chat {
            conn.execute_batch("ALTER TABLE peers ADD COLUMN is_chat INTEGER NOT NULL DEFAULT 0")?;
        }
        Ok(())
    }

    /// Wraps a SQLite error as an `io::Error` of kind `Other`.
    fn map_err(e: rusqlite::Error) -> io::Error {
        io::Error::other(e)
    }

    /// Reads every table into a [`PersistedSession`].
    ///
    /// Missing `meta`/`update_state` rows fall back to `0` / defaults.
    /// Individual `dcs`, `channel_pts` and `peers` rows that fail to convert
    /// are skipped (best effort), and an auth-key blob that is not exactly
    /// 256 bytes is read as "no key". Query failures are returned as errors.
    fn read_session(conn: &rusqlite::Connection) -> io::Result<PersistedSession> {
        use rusqlite::OptionalExtension as _;

        let home_dc_id: i32 = conn
            .query_row("SELECT value FROM meta WHERE key = 'home_dc_id'", [], |r| {
                r.get(0)
            })
            .optional()
            .map_err(Self::map_err)?
            .unwrap_or(0);

        let mut stmt = conn
            .prepare("SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs")
            .map_err(Self::map_err)?;
        let dcs = stmt
            .query_map([], |row| {
                let key_blob: Option<Vec<u8>> = row.get(3)?;
                Ok(DcEntry {
                    dc_id: row.get(0)?,
                    flags: DcFlags(row.get(1)?),
                    addr: row.get(2)?,
                    auth_key: key_blob.and_then(|b| <[u8; 256]>::try_from(b).ok()),
                    first_salt: row.get(4)?,
                    time_offset: row.get(5)?,
                })
            })
            .map_err(Self::map_err)?
            .filter_map(Result::ok)
            .collect();

        let mut updates_state = conn
            .query_row(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                [],
                |r| {
                    Ok(UpdatesStateSnap {
                        pts: r.get(0)?,
                        qts: r.get(1)?,
                        date: r.get(2)?,
                        seq: r.get(3)?,
                        channels: vec![],
                    })
                },
            )
            .optional()
            .map_err(Self::map_err)?
            .unwrap_or_default();

        let mut ch_stmt = conn
            .prepare("SELECT channel_id, pts FROM channel_pts")
            .map_err(Self::map_err)?;
        updates_state.channels = ch_stmt
            .query_map([], |r| Ok((r.get::<_, i64>(0)?, r.get::<_, i32>(1)?)))
            .map_err(Self::map_err)?
            .filter_map(Result::ok)
            .collect();

        let mut peer_stmt = conn
            .prepare("SELECT id, access_hash, is_channel, is_chat FROM peers")
            .map_err(Self::map_err)?;
        let peers: Vec<CachedPeer> = peer_stmt
            .query_map([], |r| {
                Ok(CachedPeer {
                    id: r.get(0)?,
                    access_hash: r.get(1)?,
                    is_channel: r.get::<_, i32>(2)? != 0,
                    is_chat: r.get::<_, i32>(3)? != 0,
                })
            })
            .map_err(Self::map_err)?
            .filter_map(Result::ok)
            .collect();

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state,
            peers,
            min_peers: Vec::new(),
        })
    }

    /// Writes `s` to all tables inside one `BEGIN IMMEDIATE` transaction.
    ///
    /// If any statement (or the `COMMIT`) fails, the transaction is rolled
    /// back before the error is returned. Otherwise it would stay open on
    /// the shared connection, and every later `BEGIN` would fail.
    fn write_session(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        conn.execute_batch("BEGIN IMMEDIATE").map_err(Self::map_err)?;
        let outcome =
            Self::write_session_rows(conn, s).and_then(|()| conn.execute_batch("COMMIT"));
        if let Err(e) = outcome {
            // Best effort: the original error is the one worth reporting.
            let _ = conn.execute_batch("ROLLBACK");
            return Err(Self::map_err(e));
        }
        Ok(())
    }

    /// Statement body of [`Self::write_session`]; must run inside a transaction.
    ///
    /// `dcs`, `channel_pts` and `peers` are fully replaced. In `update_state`,
    /// `pts` and `qts` never decrease (`MAX` of old and new), while `date` and
    /// `seq` are overwritten.
    fn write_session_rows(
        conn: &rusqlite::Connection,
        s: &PersistedSession,
    ) -> rusqlite::Result<()> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![s.home_dc_id],
        )?;

        conn.execute("DELETE FROM dcs", [])?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                rusqlite::params![
                    d.dc_id,
                    d.flags.0,
                    d.addr,
                    d.auth_key.as_ref().map(|k| k.as_slice()),
                    d.first_salt,
                    d.time_offset,
                ],
            )?;
        }

        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1, ?1, ?2, ?3, ?4)
             ON CONFLICT(id) DO UPDATE SET
               pts  = MAX(excluded.pts,  update_state.pts),
               qts  = MAX(excluded.qts,  update_state.qts),
               date = excluded.date,
               seq  = excluded.seq",
            rusqlite::params![us.pts, us.qts, us.date, us.seq],
        )?;

        conn.execute("DELETE FROM channel_pts", [])?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
                rusqlite::params![cid, cpts],
            )?;
        }

        conn.execute("DELETE FROM peers", [])?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel, is_chat) VALUES (?1, ?2, ?3, ?4)",
                rusqlite::params![p.id, p.access_hash, p.is_channel as i32, p.is_chat as i32],
            )?;
        }

        Ok(())
    }
}
1496
#[cfg(feature = "sqlite-session")]
impl SessionBackend for SqliteBackend {
    /// Persists the complete session snapshot via `write_session`.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        Self::write_session(&conn, session)
    }

    /// Loads the persisted session.
    ///
    /// Returns `Ok(None)` when the `meta` table has no rows, i.e. nothing has
    /// been saved yet or the session was deleted.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = self.conn.lock().unwrap();
        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM meta", [], |r| r.get(0))
            .map_err(Self::map_err)?;
        if count == 0 {
            return Ok(None);
        }
        Self::read_session(&conn).map(Some)
    }

    /// Removes every row from all session tables in one IMMEDIATE transaction.
    ///
    /// The `Transaction` guard rolls back when dropped. A failing DELETE
    /// therefore never leaves the connection stuck inside an open transaction,
    /// which would make every later `BEGIN` fail.
    fn delete(&self) -> io::Result<()> {
        let mut conn = self.conn.lock().unwrap();
        let tx = conn
            .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
            .map_err(Self::map_err)?;
        tx.execute_batch(
            "DELETE FROM meta;
             DELETE FROM dcs;
             DELETE FROM update_state;
             DELETE FROM channel_pts;
             DELETE FROM peers;",
        )
        .map_err(Self::map_err)?;
        tx.commit().map_err(Self::map_err)
    }

    /// Human-readable label identifying this backend.
    fn name(&self) -> &str {
        &self.label
    }

    /// Inserts or replaces the row for `(entry.dc_id, entry.flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
             VALUES (?1, ?6, ?2, ?3, ?4, ?5)
             ON CONFLICT(dc_id, flags) DO UPDATE SET
                addr        = excluded.addr,
                auth_key    = excluded.auth_key,
                first_salt  = excluded.first_salt,
                time_offset = excluded.time_offset",
            rusqlite::params![
                entry.dc_id,
                entry.addr,
                entry.auth_key.as_ref().map(|k| k.as_ref()),
                entry.first_salt,
                entry.time_offset,
                entry.flags.0,
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Upserts the `home_dc_id` key in `meta`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![dc_id],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Applies an incremental change to the persisted update state.
    ///
    /// `All` overwrites the scalar state and replaces the whole `channel_pts`
    /// table. That work runs in one IMMEDIATE transaction so a failure
    /// part-way cannot leave the channel table emptied or half-written.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        let mut conn = self.conn.lock().unwrap();
        match update {
            UpdateStateChange::All(snap) => {
                let tx = conn
                    .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
                    .map_err(Self::map_err)?;
                tx.execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
                     ON CONFLICT(id) DO UPDATE SET
                        pts=excluded.pts, qts=excluded.qts,
                        date=excluded.date, seq=excluded.seq",
                    rusqlite::params![snap.pts, snap.qts, snap.date, snap.seq],
                )
                .map_err(Self::map_err)?;
                tx.execute("DELETE FROM channel_pts", [])
                    .map_err(Self::map_err)?;
                for &(cid, cpts) in &snap.channels {
                    tx.execute(
                        "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
                        rusqlite::params![cid, cpts],
                    )
                    .map_err(Self::map_err)?;
                }
                // Dropping `tx` without committing (early `?` above) rolls back.
                tx.commit().map_err(Self::map_err)
            }
            UpdateStateChange::Primary { pts, date, seq } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,0,?2,?3)
                     ON CONFLICT(id) DO UPDATE SET pts=excluded.pts, date=excluded.date,
                        seq=excluded.seq",
                    rusqlite::params![pts, date, seq],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Secondary { qts } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,0,?1,0,0)
                     ON CONFLICT(id) DO UPDATE SET qts = excluded.qts",
                    rusqlite::params![qts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Channel { id, pts } => conn
                .execute(
                    "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)
                     ON CONFLICT(channel_id) DO UPDATE SET pts = excluded.pts",
                    rusqlite::params![id, pts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
        }
    }

    /// Inserts or refreshes the access hash and channel flag of `peer`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1, ?2, ?3)
             ON CONFLICT(id) DO UPDATE SET
                access_hash = excluded.access_hash,
                is_channel  = excluded.is_channel",
            rusqlite::params![peer.id, peer.access_hash, peer.is_channel as i32],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }
}
1634
#[cfg(feature = "libsql-session")]
/// Session backend backed by a libSQL database: a local file, an in-memory
/// database, a remote server, or an embedded replica.
pub struct LibSqlBackend {
    // Shared behind a Tokio mutex. Every `SessionBackend` call holds
    // `.lock().await` for the duration of its statements, so multi-statement
    // transactions on this connection do not interleave.
    conn: std::sync::Arc<tokio::sync::Mutex<libsql::Connection>>,
    // Human-readable description returned by `SessionBackend::name`.
    label: String,
}
1659
#[cfg(feature = "libsql-session")]
impl LibSqlBackend {
    /// DDL applied on every open. Every statement is `IF NOT EXISTS`, so
    /// re-applying it to an existing database changes nothing.
    const SCHEMA: &'static str = "
        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );
        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );
        CREATE TABLE IF NOT EXISTS peers (
            id          INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel  INTEGER NOT NULL DEFAULT 0
        );
    ";

    /// Drives `fut` to completion on the ambient Tokio runtime and converts
    /// `libsql` errors into `io::Error`.
    ///
    /// # Errors
    /// Returns an error, instead of panicking, when no Tokio runtime is active
    /// on the calling thread. Also returns any error `fut` resolves to.
    ///
    /// # Panics
    /// `Handle::block_on` panics when called from within an asynchronous
    /// execution context. Callers must reach the synchronous
    /// `SessionBackend` methods from a blocking context.
    fn block<F, T>(fut: F) -> io::Result<T>
    where
        F: std::future::Future<Output = Result<T, libsql::Error>>,
    {
        let handle = tokio::runtime::Handle::try_current().map_err(io::Error::other)?;
        handle.block_on(fut).map_err(io::Error::other)
    }

    /// Creates every session table that does not exist yet.
    async fn apply_schema(conn: &libsql::Connection) -> Result<(), libsql::Error> {
        conn.execute_batch(Self::SCHEMA).await?;
        Ok(())
    }

    /// Shared constructor tail. Opens a connection on `db`, ensures the schema
    /// exists and wraps the connection for shared use.
    fn from_database(db: libsql::Database, label: String) -> io::Result<Self> {
        // `connect` is synchronous; no need to route it through `block`.
        let conn = db.connect().map_err(io::Error::other)?;
        Self::block(Self::apply_schema(&conn))?;
        Ok(Self {
            conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
            label,
        })
    }

    /// Opens (or creates) a local libSQL database file at `path`.
    pub fn open_local(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let db = Self::block(async { libsql::Builder::new_local(path).build().await })?;
        Self::from_database(db, label)
    }

    /// Opens a fresh in-memory database. Its contents are lost when the
    /// backend is dropped.
    pub fn in_memory() -> io::Result<Self> {
        let db = Self::block(async { libsql::Builder::new_local(":memory:").build().await })?;
        Self::from_database(db, ":memory:".into())
    }

    /// Connects to a remote libSQL server at `url` using `auth_token`.
    pub fn open_remote(url: impl Into<String>, auth_token: impl Into<String>) -> io::Result<Self> {
        let url = url.into();
        let label = url.clone();
        let auth_token = auth_token.into();
        let db = Self::block(async {
            libsql::Builder::new_remote(url, auth_token)
                .build()
                .await
        })?;
        Self::from_database(db, label)
    }

    /// Opens an embedded replica stored at `path` that replicates from `url`.
    pub fn open_replica(
        path: impl Into<PathBuf>,
        url: impl Into<String>,
        auth_token: impl Into<String>,
    ) -> io::Result<Self> {
        let path = path.into();
        // Convert once: `url` is consumed by `into()`, so it cannot be
        // converted a second time for the builder.
        let url = url.into();
        let label = format!("{} (replica of {})", path.display(), url);
        let auth_token = auth_token.into();
        let db = Self::block(async {
            libsql::Builder::new_remote_replica(path, url, auth_token)
                .build()
                .await
        })?;
        Self::from_database(db, label)
    }

    /// Reads every session table into a `PersistedSession`.
    ///
    /// A missing `home_dc_id` defaults to 0, and a missing `update_state` row
    /// defaults to `UpdatesStateSnap::default()`.
    ///
    /// # Errors
    /// Returns an error for any query failure, or when a stored `auth_key`
    /// blob is not exactly 256 bytes.
    async fn read_session_async(
        conn: &libsql::Connection,
    ) -> Result<PersistedSession, libsql::Error> {
        let home_dc_id: i32 = conn
            .query("SELECT value FROM meta WHERE key = 'home_dc_id'", ())
            .await?
            .next()
            .await?
            .map(|r| r.get::<i32>(0))
            .transpose()?
            .unwrap_or(0);

        let mut rows = conn
            .query(
                "SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs",
                (),
            )
            .await?;
        let mut dcs = Vec::new();
        while let Some(row) = rows.next().await? {
            let dc_id: i32 = row.get(0)?;
            // Stored from `DcFlags.0` (a u8) by the writers, so truncation is lossless.
            let flags_raw: u8 = row.get::<i64>(1)? as u8;
            let addr: String = row.get(2)?;
            let key_blob: Option<Vec<u8>> = row.get(3)?;
            let first_salt: i64 = row.get(4)?;
            let time_offset: i32 = row.get(5)?;
            let auth_key = match key_blob {
                Some(b) if b.len() == 256 => {
                    let mut k = [0u8; 256];
                    k.copy_from_slice(&b);
                    Some(k)
                }
                Some(b) => {
                    return Err(libsql::Error::Misuse(format!(
                        "auth_key blob must be 256 bytes, got {}",
                        b.len()
                    )));
                }
                None => None,
            };
            dcs.push(DcEntry {
                dc_id,
                addr,
                auth_key,
                first_salt,
                time_offset,
                flags: DcFlags(flags_raw),
            });
        }

        let mut us_row = conn
            .query(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                (),
            )
            .await?;
        let updates_state = if let Some(r) = us_row.next().await? {
            UpdatesStateSnap {
                pts: r.get(0)?,
                qts: r.get(1)?,
                date: r.get(2)?,
                seq: r.get(3)?,
                channels: vec![],
            }
        } else {
            UpdatesStateSnap::default()
        };

        let mut ch_rows = conn
            .query("SELECT channel_id, pts FROM channel_pts", ())
            .await?;
        let mut channels = Vec::new();
        while let Some(r) = ch_rows.next().await? {
            channels.push((r.get::<i64>(0)?, r.get::<i32>(1)?));
        }

        let mut peer_rows = conn
            .query("SELECT id, access_hash, is_channel FROM peers", ())
            .await?;
        let mut peers = Vec::new();
        while let Some(r) = peer_rows.next().await? {
            peers.push(CachedPeer {
                id: r.get(0)?,
                access_hash: r.get(1)?,
                is_channel: r.get::<i32>(2)? != 0,
                is_chat: false,
            });
        }

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers: Vec::new(),
        })
    }

    /// Writes the full session `s` inside one IMMEDIATE transaction.
    ///
    /// On any failure, including a failed COMMIT, the transaction is rolled
    /// back explicitly. Otherwise the connection would stay inside an open
    /// transaction and every later `BEGIN IMMEDIATE` would fail.
    async fn write_session_async(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute_batch("BEGIN IMMEDIATE").await?;

        let result = match Self::write_session_body(conn, s).await {
            Ok(()) => conn.execute_batch("COMMIT").await.map(|_| ()),
            Err(e) => Err(e),
        };
        if result.is_err() {
            // Best effort: if SQLite already rolled back, this errors harmlessly.
            let _ = conn.execute_batch("ROLLBACK").await;
        }
        result
    }

    /// Statements of `write_session_async`. They must run inside a transaction
    /// that the caller owns.
    async fn write_session_body(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            libsql::params![s.home_dc_id],
        )
        .await?;

        conn.execute("DELETE FROM dcs", ()).await?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                libsql::params![
                    d.dc_id,
                    d.flags.0 as i64,
                    d.addr.clone(),
                    d.auth_key.map(|k| k.to_vec()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .await?;
        }

        // pts/qts never move backwards on a full save; date/seq are overwritten.
        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
             ON CONFLICT(id) DO UPDATE SET
                pts  = MAX(excluded.pts, update_state.pts),
                qts  = MAX(excluded.qts, update_state.qts),
                date = excluded.date,
                seq  = excluded.seq",
            libsql::params![us.pts, us.qts, us.date, us.seq],
        )
        .await?;

        conn.execute("DELETE FROM channel_pts", ()).await?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1,?2)",
                libsql::params![cid, cpts],
            )
            .await?;
        }

        conn.execute("DELETE FROM peers", ()).await?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1,?2,?3)",
                libsql::params![p.id, p.access_hash, p.is_channel as i32],
            )
            .await?;
        }

        Ok(())
    }
}
1940
#[cfg(feature = "libsql-session")]
impl SessionBackend for LibSqlBackend {
    /// Persists the complete session snapshot via `write_session_async`.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        // `block_on` accepts non-'static futures, so the session is borrowed
        // rather than deep-cloned. `write_session_async` takes `&PersistedSession`.
        Self::block(async {
            let conn = self.conn.lock().await;
            Self::write_session_async(&conn, session).await
        })
    }

    /// Loads the persisted session, or `Ok(None)` when `meta` is empty.
    ///
    /// The emptiness check and the read happen under a single lock
    /// acquisition, so a concurrent `delete` cannot slip in between them.
    /// Decoding errors on the count are propagated rather than being treated
    /// as "no session".
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        Self::block(async {
            let conn = self.conn.lock().await;
            let count: i64 = {
                let mut rows = conn.query("SELECT COUNT(*) FROM meta", ()).await?;
                match rows.next().await? {
                    Some(row) => row.get(0)?,
                    None => 0,
                }
            };
            if count == 0 {
                return Ok(None);
            }
            Self::read_session_async(&conn).await.map(Some)
        })
    }

    /// Removes every row from all session tables in one IMMEDIATE transaction.
    ///
    /// If any statement fails, the transaction is rolled back explicitly so
    /// the connection is not left inside an open transaction.
    fn delete(&self) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            let result = conn
                .execute_batch(
                    "BEGIN IMMEDIATE;
                     DELETE FROM meta;
                     DELETE FROM dcs;
                     DELETE FROM update_state;
                     DELETE FROM channel_pts;
                     DELETE FROM peers;
                     COMMIT;",
                )
                .await
                .map(|_| ());
            if result.is_err() {
                // Best effort: errors harmlessly if no transaction is active.
                let _ = conn.execute_batch("ROLLBACK").await;
            }
            result
        })
    }

    /// Human-readable label identifying this backend.
    fn name(&self) -> &str {
        &self.label
    }

    /// Inserts or replaces the row for `(entry.dc_id, entry.flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let (dc_id, addr, key, salt, off, flags) = (
            entry.dc_id,
            entry.addr.clone(),
            entry.auth_key.map(|k| k.to_vec()),
            entry.first_salt,
            entry.time_offset,
            entry.flags.0 as i64,
        );
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1,?6,?2,?3,?4,?5)
                 ON CONFLICT(dc_id, flags) DO UPDATE SET
                    addr=excluded.addr, auth_key=excluded.auth_key,
                    first_salt=excluded.first_salt, time_offset=excluded.time_offset",
                libsql::params![dc_id, addr, key, salt, off, flags],
            )
            .await
            .map(|_| ())
        })
    }

    /// Upserts the `home_dc_id` key in `meta`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO meta (key, value) VALUES ('home_dc_id',?1)
                 ON CONFLICT(key) DO UPDATE SET value=excluded.value",
                libsql::params![dc_id],
            )
            .await
            .map(|_| ())
        })
    }

    /// Applies an incremental change to the persisted update state.
    ///
    /// `All` overwrites the scalar state and replaces the whole `channel_pts`
    /// table. That work runs in one IMMEDIATE transaction and is rolled back
    /// on failure, so the channel table is never left emptied or half-written.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        Self::block(async move {
            let conn = self.conn.lock().await;
            match update {
                UpdateStateChange::All(snap) => {
                    conn.execute_batch("BEGIN IMMEDIATE").await?;
                    let result: Result<(), libsql::Error> = async {
                        conn.execute(
                            "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,?2,?3,?4)
                             ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,qts=excluded.qts,
                                date=excluded.date,seq=excluded.seq",
                            libsql::params![snap.pts, snap.qts, snap.date, snap.seq],
                        )
                        .await?;
                        conn.execute("DELETE FROM channel_pts", ()).await?;
                        for &(cid, cpts) in &snap.channels {
                            conn.execute(
                                "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)",
                                libsql::params![cid, cpts],
                            )
                            .await?;
                        }
                        conn.execute_batch("COMMIT").await?;
                        Ok(())
                    }
                    .await;
                    if result.is_err() {
                        // Best effort: errors harmlessly if SQLite already rolled back.
                        let _ = conn.execute_batch("ROLLBACK").await;
                    }
                    result
                }
                UpdateStateChange::Primary { pts, date, seq } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,0,?2,?3)
                         ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,date=excluded.date,
                            seq=excluded.seq",
                        libsql::params![pts, date, seq],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Secondary { qts } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,0,?1,0,0)
                         ON CONFLICT(id) DO UPDATE SET qts=excluded.qts",
                        libsql::params![qts],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Channel { id, pts } => conn
                    .execute(
                        "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)
                         ON CONFLICT(channel_id) DO UPDATE SET pts=excluded.pts",
                        libsql::params![id, pts],
                    )
                    .await
                    .map(|_| ()),
            }
        })
    }

    /// Inserts or refreshes the access hash and channel flag of `peer`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let (id, hash, is_ch) = (peer.id, peer.access_hash, peer.is_channel as i32);
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO peers (id,access_hash,is_channel) VALUES (?1,?2,?3)
                 ON CONFLICT(id) DO UPDATE SET
                    access_hash=excluded.access_hash,
                    is_channel=excluded.is_channel",
                libsql::params![id, hash, is_ch],
            )
            .await
            .map(|_| ())
        })
    }
}