1#![deny(unsafe_code)]
14#![cfg_attr(docsrs, feature(doc_cfg))]
15use std::collections::HashMap;
76use std::io::{self, ErrorKind};
77use std::path::Path;
78
/// Serde adapter for the optional 256-byte `auth_key` field of `DcEntry`.
///
/// A present key is written as `some(<byte sequence>)`; on the way back in, the
/// byte sequence is only accepted when it is exactly 256 bytes long.
#[cfg(feature = "serde")]
mod auth_key_serde {
    use serde::{Deserialize, Deserializer, Serializer};

    /// Writes `Some(key)` as an optional byte slice and `None` as an empty option.
    pub fn serialize<S: Serializer>(value: &Option<[u8; 256]>, s: S) -> Result<S::Ok, S::Error> {
        if let Some(key) = value {
            s.serialize_some(&key[..])
        } else {
            s.serialize_none()
        }
    }

    /// Reads an optional byte vector and converts it into a fixed 256-byte key.
    ///
    /// # Errors
    /// Fails with a custom deserialisation error when the vector is present but its
    /// length is not exactly 256.
    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<[u8; 256]>, D::Error> {
        let raw = Option::<Vec<u8>>::deserialize(d)?;
        raw.map(|bytes| {
            <[u8; 256]>::try_from(bytes).map_err(|_| {
                <D::Error as serde::de::Error>::custom("auth_key must be exactly 256 bytes")
            })
        })
        .transpose()
    }
}
109
/// Bit set of properties attached to a `DcEntry` address.
///
/// The raw byte is public and is persisted verbatim: as a single byte in the binary
/// session format and in the `flags` column of the SQLite backend. Named bits are
/// exposed as associated constants such as `DcFlags::IPV6`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);
116
117impl DcFlags {
118 pub const NONE: DcFlags = DcFlags(0);
119 pub const IPV6: DcFlags = DcFlags(1 << 0);
120 pub const MEDIA_ONLY: DcFlags = DcFlags(1 << 1);
121 pub const TCPO_ONLY: DcFlags = DcFlags(1 << 2);
122 pub const CDN: DcFlags = DcFlags(1 << 3);
123 pub const STATIC: DcFlags = DcFlags(1 << 4);
124
125 pub fn contains(self, other: DcFlags) -> bool {
126 self.0 & other.0 == other.0
127 }
128
129 pub fn set(&mut self, flag: DcFlags) {
130 self.0 |= flag.0;
131 }
132}
133
134impl std::ops::BitOr for DcFlags {
135 type Output = DcFlags;
136 fn bitor(self, rhs: DcFlags) -> DcFlags {
137 DcFlags(self.0 | rhs.0)
138 }
139}
140
/// One address of a data centre (DC), together with the per-DC state persisted for it.
///
/// The same `dc_id` may appear several times with different `flags` (for example an
/// IPv4 and an IPv6 entry); `PersistedSession::dc_for` chooses between them.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcEntry {
    /// Data-centre identifier.
    pub dc_id: i32,
    /// `"host:port"` socket address. IPv6 hosts are bracketed (`"[::1]:443"`) so that
    /// `socket_addr` can parse the string.
    pub addr: String,
    /// 256-byte authorization key for this DC, or `None` when none is stored.
    #[cfg_attr(feature = "serde", serde(with = "auth_key_serde"))]
    pub auth_key: Option<[u8; 256]>,
    /// Salt persisted alongside the key. NOTE(review): presumably the first server salt
    /// after key exchange — confirm with the caller.
    pub first_salt: i64,
    /// Clock offset persisted alongside the key. NOTE(review): units (seconds?) are not
    /// visible here — confirm.
    pub time_offset: i32,
    /// Properties of this address (see `DcFlags`).
    pub flags: DcFlags,
}
154
155impl DcEntry {
156 #[inline]
158 pub fn is_ipv6(&self) -> bool {
159 self.flags.contains(DcFlags::IPV6)
160 }
161
162 pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
169 self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
170 io::Error::new(
171 io::ErrorKind::InvalidData,
172 format!("invalid DC address: {:?}", self.addr),
173 )
174 })
175 }
176
177 pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
190 let addr = if ip.contains(':') {
192 format!("[{ip}]:{port}")
193 } else {
194 format!("{ip}:{port}")
195 };
196 Self {
197 dc_id,
198 addr,
199 auth_key: None,
200 first_salt: 0,
201 time_offset: 0,
202 flags,
203 }
204 }
205}
206
/// Snapshot of the update-sequence counters plus the per-channel `pts` values.
///
/// A snapshot counts as initialised once `pts > 0` (see `is_initialised`).
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Main `pts` counter.
    pub pts: i32,
    /// `qts` counter.
    pub qts: i32,
    /// `date` of the last known state.
    pub date: i32,
    /// `seq` counter.
    pub seq: i32,
    /// `(channel_id, pts)` pairs, at most one per channel when maintained through
    /// `set_channel_pts`.
    pub channels: Vec<(i64, i32)>,
}
223
224impl UpdatesStateSnap {
225 #[inline]
227 pub fn is_initialised(&self) -> bool {
228 self.pts > 0
229 }
230
231 pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
233 if let Some(entry) = self.channels.iter_mut().find(|c| c.0 == channel_id) {
234 entry.1 = pts;
235 } else {
236 self.channels.push((channel_id, pts));
237 }
238 }
239
240 pub fn channel_pts(&self, channel_id: i64) -> i32 {
242 self.channels
243 .iter()
244 .find(|c| c.0 == channel_id)
245 .map(|c| c.1)
246 .unwrap_or(0)
247 }
248}
249
/// A peer whose access hash has been cached.
///
/// `is_channel` and `is_chat` are independent booleans. The binary format stores one
/// type byte in which `is_chat` wins over `is_channel`, and a peer with both flags false
/// is stored as type 0. NOTE(review): type 0 presumably means "user" — confirm.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Peer identifier; the backends keep at most one cached entry per id.
    pub id: i64,
    /// Access hash cached for this peer.
    pub access_hash: i64,
    /// Peer is a channel.
    pub is_channel: bool,
    /// Peer is a (basic) chat.
    pub is_chat: bool,
}
266
/// Cached `(user_id, peer_id, msg_id)` triple.
///
/// NOTE(review): judging by the name, this presumably records where a "min" user (one
/// seen without a usable access hash) appeared, so that it can later be referenced
/// through that message — confirm with the caller.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedMinPeer {
    /// The user being referenced.
    pub user_id: i64,
    /// Peer (chat/channel) in which the user was seen.
    pub peer_id: i64,
    /// Message in `peer_id` that mentions the user.
    pub msg_id: i32,
}
280
/// Session state persisted by the backends in this module.
///
/// Serialised to bytes by `to_bytes` / `from_bytes`, and to text (URL-safe, unpadded
/// base64 of those bytes) by `Display` / `from_string`.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PersistedSession {
    /// The account's home DC id (0 in a `Default` session).
    pub home_dc_id: i32,
    /// Known DC addresses and keys, possibly several entries per `dc_id`.
    pub dcs: Vec<DcEntry>,
    /// Last known update state.
    pub updates_state: UpdatesStateSnap,
    /// Cached peers with their access hashes.
    pub peers: Vec<CachedPeer>,
    /// Cached min-peer references (only persisted by format version 4 and later).
    pub min_peers: Vec<CachedMinPeer>,
}
296
297impl PersistedSession {
298 pub fn to_bytes(&self) -> Vec<u8> {
300 let mut b = Vec::with_capacity(512);
301
302 b.push(0x05u8); b.extend_from_slice(&self.home_dc_id.to_le_bytes());
305
306 b.push(self.dcs.len() as u8);
307 for d in &self.dcs {
308 b.extend_from_slice(&d.dc_id.to_le_bytes());
309 match &d.auth_key {
310 Some(k) => {
311 b.push(1);
312 b.extend_from_slice(k);
313 }
314 None => {
315 b.push(0);
316 }
317 }
318 b.extend_from_slice(&d.first_salt.to_le_bytes());
319 b.extend_from_slice(&d.time_offset.to_le_bytes());
320 let ab = d.addr.as_bytes();
321 b.push(ab.len() as u8);
322 b.extend_from_slice(ab);
323 b.push(d.flags.0);
324 }
325
326 b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
327 b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
328 b.extend_from_slice(&self.updates_state.date.to_le_bytes());
329 b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
330 let ch = &self.updates_state.channels;
331 b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
332 for &(cid, cpts) in ch {
333 b.extend_from_slice(&cid.to_le_bytes());
334 b.extend_from_slice(&cpts.to_le_bytes());
335 }
336
337 b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
339 for p in &self.peers {
340 b.extend_from_slice(&p.id.to_le_bytes());
341 b.extend_from_slice(&p.access_hash.to_le_bytes());
342 let peer_type: u8 = if p.is_chat {
343 2
344 } else if p.is_channel {
345 1
346 } else {
347 0
348 };
349 b.push(peer_type);
350 }
351
352 b.extend_from_slice(&(self.min_peers.len() as u16).to_le_bytes());
353 for m in &self.min_peers {
354 b.extend_from_slice(&m.user_id.to_le_bytes());
355 b.extend_from_slice(&m.peer_id.to_le_bytes());
356 b.extend_from_slice(&m.msg_id.to_le_bytes());
357 }
358
359 b
360 }
361
362 pub fn save(&self, path: &Path) -> io::Result<()> {
369 use std::sync::atomic::{AtomicU64, Ordering};
370 static SEQ: AtomicU64 = AtomicU64::new(0);
371 let n = SEQ.fetch_add(1, Ordering::Relaxed);
372 let tmp = path.with_extension(format!("{n}.tmp"));
373 std::fs::write(&tmp, self.to_bytes())?;
374 std::fs::rename(&tmp, path).inspect_err(|_e| {
375 let _ = std::fs::remove_file(&tmp);
376 })
377 }
378
379 pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
381 if buf.is_empty() {
382 return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
383 }
384
385 let mut p = 0usize;
386
387 macro_rules! r {
388 ($n:expr) => {{
389 if p + $n > buf.len() {
390 return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
391 }
392 let s = &buf[p..p + $n];
393 p += $n;
394 s
395 }};
396 }
397 macro_rules! r_i32 {
398 () => {
399 i32::from_le_bytes(r!(4).try_into().unwrap())
400 };
401 }
402 macro_rules! r_i64 {
403 () => {
404 i64::from_le_bytes(r!(8).try_into().unwrap())
405 };
406 }
407 macro_rules! r_u8 {
408 () => {
409 r!(1)[0]
410 };
411 }
412 macro_rules! r_u16 {
413 () => {
414 u16::from_le_bytes(r!(2).try_into().unwrap())
415 };
416 }
417
418 let first_byte = r_u8!();
419
420 let (home_dc_id, version) = if first_byte == 0x05 {
421 (r_i32!(), 5u8)
422 } else if first_byte == 0x04 {
423 (r_i32!(), 4u8)
424 } else if first_byte == 0x03 {
425 (r_i32!(), 3u8)
426 } else if first_byte == 0x02 {
427 (r_i32!(), 2u8)
428 } else {
429 let rest = r!(3);
430 let mut bytes = [0u8; 4];
431 bytes[0] = first_byte;
432 bytes[1..4].copy_from_slice(rest);
433 (i32::from_le_bytes(bytes), 1u8)
434 };
435
436 let dc_count = r_u8!() as usize;
437 let mut dcs = Vec::with_capacity(dc_count);
438 for _ in 0..dc_count {
439 let dc_id = r_i32!();
440 let has_key = r_u8!();
441 let auth_key = if has_key == 1 {
442 let mut k = [0u8; 256];
443 k.copy_from_slice(r!(256));
444 Some(k)
445 } else {
446 None
447 };
448 let first_salt = r_i64!();
449 let time_offset = r_i32!();
450 let al = r_u8!() as usize;
451 let addr = String::from_utf8_lossy(r!(al)).into_owned();
452 let flags = if version >= 3 {
453 DcFlags(r_u8!())
454 } else {
455 DcFlags::NONE
456 };
457 dcs.push(DcEntry {
458 dc_id,
459 addr,
460 auth_key,
461 first_salt,
462 time_offset,
463 flags,
464 });
465 }
466
467 if version < 2 {
468 return Ok(Self {
469 home_dc_id,
470 dcs,
471 updates_state: UpdatesStateSnap::default(),
472 peers: Vec::new(),
473 min_peers: Vec::new(),
474 });
475 }
476
477 let pts = r_i32!();
478 let qts = r_i32!();
479 let date = r_i32!();
480 let seq = r_i32!();
481 let ch_count = r_u16!() as usize;
482 let mut channels = Vec::with_capacity(ch_count);
483 for _ in 0..ch_count {
484 let cid = r_i64!();
485 let cpts = r_i32!();
486 channels.push((cid, cpts));
487 }
488
489 let peer_count = r_u16!() as usize;
490 let mut peers = Vec::with_capacity(peer_count);
491 for _ in 0..peer_count {
492 let id = r_i64!();
493 let access_hash = r_i64!();
494 let peer_type = r_u8!();
496 let is_channel = peer_type == 1;
497 let is_chat = peer_type == 2;
498 peers.push(CachedPeer {
499 id,
500 access_hash,
501 is_channel,
502 is_chat,
503 });
504 }
505
506 let min_peers = if version >= 4 {
508 let count = r_u16!() as usize;
509 let mut v = Vec::with_capacity(count);
510 for _ in 0..count {
511 let user_id = r_i64!();
512 let peer_id = r_i64!();
513 let msg_id = r_i32!();
514 v.push(CachedMinPeer {
515 user_id,
516 peer_id,
517 msg_id,
518 });
519 }
520 v
521 } else {
522 Vec::new()
523 };
524
525 Ok(Self {
526 home_dc_id,
527 dcs,
528 updates_state: UpdatesStateSnap {
529 pts,
530 qts,
531 date,
532 seq,
533 channels,
534 },
535 peers,
536 min_peers,
537 })
538 }
539
540 pub fn from_string(s: &str) -> io::Result<Self> {
542 use base64::Engine as _;
543 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
544 .decode(s.trim())
545 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
546 Self::from_bytes(&bytes)
547 }
548
549 pub fn load(path: &Path) -> io::Result<Self> {
550 let buf = std::fs::read(path)?;
551 Self::from_bytes(&buf)
552 }
553
554 pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
566 let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
567 candidates.peek()?;
568 let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
570 cands
572 .iter()
573 .copied()
574 .find(|d| d.is_ipv6() == prefer_ipv6)
575 .or_else(|| cands.first().copied())
576 }
577
578 pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
583 self.dcs.iter().filter(move |d| d.dc_id == dc_id)
584 }
585}
586
587impl std::fmt::Display for PersistedSession {
588 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
589 use base64::Engine as _;
590 f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
591 }
592}
593
/// Returns the hard-coded `"ip:port"` addresses for DCs 1 through 5, keyed by DC id.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const BUILTIN: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(BUILTIN.len());
    for (dc_id, addr) in BUILTIN {
        table.insert(dc_id, String::from(addr));
    }
    table
}
607
608use std::path::PathBuf;
613
614pub trait SessionBackend: Send + Sync {
622 fn save(&self, session: &PersistedSession) -> io::Result<()>;
623 fn load(&self) -> io::Result<Option<PersistedSession>>;
624 fn delete(&self) -> io::Result<()>;
625
626 fn name(&self) -> &str;
628
629 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
642 let mut s = self.load()?.unwrap_or_default();
643 if let Some(existing) = s
645 .dcs
646 .iter_mut()
647 .find(|d| d.dc_id == entry.dc_id && d.is_ipv6() == entry.is_ipv6())
648 {
649 *existing = entry.clone();
650 } else {
651 s.dcs.push(entry.clone());
652 }
653 self.save(&s)
654 }
655
656 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
662 let mut s = self.load()?.unwrap_or_default();
663 s.home_dc_id = dc_id;
664 self.save(&s)
665 }
666
667 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
672 let mut s = self.load()?.unwrap_or_default();
673 update.apply_to(&mut s.updates_state);
674 self.save(&s)
675 }
676
677 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
683 let mut s = self.load()?.unwrap_or_default();
684 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
685 *existing = peer.clone();
686 } else {
687 s.peers.push(peer.clone());
688 }
689 self.save(&s)
690 }
691}
692
/// A change to apply to an `UpdatesStateSnap` (see `UpdateStateChange::apply_to`).
#[derive(Debug, Clone)]
pub enum UpdateStateChange {
    /// Replace the whole snapshot, including channel entries.
    All(UpdatesStateSnap),
    /// Overwrite `pts`, `date` and `seq`; `qts` and channels are left untouched.
    Primary { pts: i32, date: i32, seq: i32 },
    /// Overwrite `qts` only.
    Secondary { qts: i32 },
    /// Insert or update the `pts` of a single channel.
    Channel { id: i64, pts: i32 },
}
717
718impl UpdateStateChange {
719 pub fn apply_to(&self, snap: &mut UpdatesStateSnap) {
721 match self {
722 Self::All(new_snap) => *snap = new_snap.clone(),
723 Self::Primary { pts, date, seq } => {
724 snap.pts = *pts;
725 snap.date = *date;
726 snap.seq = *seq;
727 }
728 Self::Secondary { qts } => {
729 snap.qts = *qts;
730 }
731 Self::Channel { id, pts } => {
732 if let Some(existing) = snap.channels.iter_mut().find(|c| c.0 == *id) {
734 existing.1 = *pts;
735 } else {
736 snap.channels.push((*id, *pts));
737 }
738 }
739 }
740 }
741}
742
/// `SessionBackend` that stores the session as a single file in the binary
/// `PersistedSession::to_bytes` format.
pub struct BinaryFileBackend {
    /// Location of the session file.
    path: PathBuf,
    /// Held for the duration of each `save`, so saves from this value do not overlap.
    write_lock: std::sync::Mutex<()>,
}
754
755impl BinaryFileBackend {
756 pub fn new(path: impl Into<PathBuf>) -> Self {
757 Self {
758 path: path.into(),
759 write_lock: std::sync::Mutex::new(()),
760 }
761 }
762
763 pub fn path(&self) -> &std::path::Path {
764 &self.path
765 }
766}
767
768impl SessionBackend for BinaryFileBackend {
769 fn save(&self, session: &PersistedSession) -> io::Result<()> {
770 let _guard = self.write_lock.lock().unwrap();
771 session.save(&self.path)
772 }
773
774 fn load(&self) -> io::Result<Option<PersistedSession>> {
775 if !self.path.exists() {
776 return Ok(None);
777 }
778 match PersistedSession::load(&self.path) {
779 Ok(s) => Ok(Some(s)),
780 Err(e) => {
781 let bak = self.path.with_extension("bak");
782 tracing::warn!(
783 "[ferogram] Session file {:?} is corrupt ({e}); \
784 renaming to {:?} and starting fresh",
785 self.path,
786 bak
787 );
788 let _ = std::fs::rename(&self.path, &bak);
789 Ok(None)
790 }
791 }
792 }
793
794 fn delete(&self) -> io::Result<()> {
795 if self.path.exists() {
796 std::fs::remove_file(&self.path)?;
797 }
798 Ok(())
799 }
800
801 fn name(&self) -> &str {
802 "binary-file"
803 }
804
805 }
808
/// `SessionBackend` that keeps the session only inside this value.
///
/// `None` means nothing has been saved yet, or that the session was deleted.
#[derive(Default)]
pub struct InMemoryBackend {
    data: std::sync::Mutex<Option<PersistedSession>>,
}
820
821impl InMemoryBackend {
822 pub fn new() -> Self {
823 Self::default()
824 }
825
826 pub fn snapshot(&self) -> Option<PersistedSession> {
828 self.data.lock().unwrap().clone()
829 }
830}
831
832impl SessionBackend for InMemoryBackend {
833 fn save(&self, s: &PersistedSession) -> io::Result<()> {
834 *self.data.lock().unwrap() = Some(s.clone());
835 Ok(())
836 }
837
838 fn load(&self) -> io::Result<Option<PersistedSession>> {
839 Ok(self.data.lock().unwrap().clone())
840 }
841
842 fn delete(&self) -> io::Result<()> {
843 *self.data.lock().unwrap() = None;
844 Ok(())
845 }
846
847 fn name(&self) -> &str {
848 "in-memory"
849 }
850
851 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
854 let mut guard = self.data.lock().unwrap();
855 let s = guard.get_or_insert_with(PersistedSession::default);
856 if let Some(existing) = s
857 .dcs
858 .iter_mut()
859 .find(|d| d.dc_id == entry.dc_id && d.is_ipv6() == entry.is_ipv6())
860 {
861 *existing = entry.clone();
862 } else {
863 s.dcs.push(entry.clone());
864 }
865 Ok(())
866 }
867
868 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
869 let mut guard = self.data.lock().unwrap();
870 let s = guard.get_or_insert_with(PersistedSession::default);
871 s.home_dc_id = dc_id;
872 Ok(())
873 }
874
875 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
876 let mut guard = self.data.lock().unwrap();
877 let s = guard.get_or_insert_with(PersistedSession::default);
878 update.apply_to(&mut s.updates_state);
879 Ok(())
880 }
881
882 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
883 let mut guard = self.data.lock().unwrap();
884 let s = guard.get_or_insert_with(PersistedSession::default);
885 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
886 *existing = peer.clone();
887 } else {
888 s.peers.push(peer.clone());
889 }
890 Ok(())
891 }
892}
893
/// `SessionBackend` that keeps the session as its base64 string form (the `Display`
/// output of `PersistedSession`); an empty or all-whitespace string means "no session".
pub struct StringSessionBackend {
    data: std::sync::Mutex<String>,
}
900
901impl StringSessionBackend {
902 pub fn new(s: impl Into<String>) -> Self {
903 Self {
904 data: std::sync::Mutex::new(s.into()),
905 }
906 }
907
908 pub fn current(&self) -> String {
909 self.data.lock().unwrap().clone()
910 }
911}
912
913impl SessionBackend for StringSessionBackend {
914 fn save(&self, session: &PersistedSession) -> io::Result<()> {
915 *self.data.lock().unwrap() = session.to_string();
916 Ok(())
917 }
918
919 fn load(&self) -> io::Result<Option<PersistedSession>> {
920 let s = self.data.lock().unwrap().clone();
921 if s.trim().is_empty() {
922 return Ok(None);
923 }
924 PersistedSession::from_string(&s).map(Some)
925 }
926
927 fn delete(&self) -> io::Result<()> {
928 *self.data.lock().unwrap() = String::new();
929 Ok(())
930 }
931
932 fn name(&self) -> &str {
933 "string-session"
934 }
935}
936
937pub mod string_session;
945pub use string_session::{FullSession, Session, StringSession, StringSessionError};
946
#[cfg(test)]
mod tests {
    use super::*;

    fn make_dc(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("1.2.3.{id}:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::NONE,
        }
    }

    fn make_dc_v6(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("[2001:b28:f23d:f00{}::a]:443", id),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::IPV6,
        }
    }

    fn make_peer(id: i64, hash: i64) -> CachedPeer {
        CachedPeer {
            id,
            access_hash: hash,
            is_channel: false,
            is_chat: false,
        }
    }

    /// Returns a path in the OS temp dir that no other test (or run) uses.
    fn unique_temp_path(tag: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        std::env::temp_dir().join(format!(
            "ferogram-{tag}-{}-{nanos}.session",
            std::process::id()
        ))
    }

    // ---- InMemoryBackend -------------------------------------------------------

    #[test]
    fn inmemory_load_returns_none_when_empty() {
        let b = InMemoryBackend::new();
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_save_then_load_round_trips() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 3;
        s.dcs.push(make_dc(3));
        b.save(&s).unwrap();

        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.home_dc_id, 3);
        assert_eq!(loaded.dcs.len(), 1);
    }

    #[test]
    fn inmemory_delete_clears_state() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        b.save(&s).unwrap();
        b.delete().unwrap();
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_update_dc_inserts_new() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(4)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].dc_id, 4);
    }

    #[test]
    fn inmemory_update_dc_replaces_existing() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        let mut updated = make_dc(2);
        updated.addr = "9.9.9.9:443".to_string();
        b.update_dc(&updated).unwrap();

        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].addr, "9.9.9.9:443");
    }

    #[test]
    fn inmemory_set_home_dc() {
        let b = InMemoryBackend::new();
        b.set_home_dc(5).unwrap();
        assert_eq!(b.snapshot().unwrap().home_dc_id, 5);
    }

    #[test]
    fn inmemory_cache_peer_inserts() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 0xdeadbeef)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].id, 100);
    }

    #[test]
    fn inmemory_cache_peer_updates_existing() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 111)).unwrap();
        b.cache_peer(&make_peer(100, 222)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].access_hash, 222);
    }

    #[test]
    fn inmemory_ipv4_and_ipv6_coexist() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        b.update_dc(&make_dc_v6(2)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.iter().filter(|d| d.dc_id == 2).count(), 2);
    }

    // ---- Update state ----------------------------------------------------------

    #[test]
    fn update_state_primary() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Primary {
            pts: 10,
            date: 20,
            seq: 30,
        }
        .apply_to(&mut snap);
        assert_eq!(snap.pts, 10);
        assert_eq!(snap.date, 20);
        assert_eq!(snap.seq, 30);
        assert_eq!(snap.qts, 0);
    }

    #[test]
    fn update_state_secondary() {
        let mut snap = UpdatesStateSnap {
            pts: 5,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Secondary { qts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.qts, 99);
        assert_eq!(snap.pts, 5);
    }

    #[test]
    fn update_state_channel_inserts() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Channel { id: 12345, pts: 42 }.apply_to(&mut snap);
        assert_eq!(snap.channels, vec![(12345, 42)]);
    }

    #[test]
    fn update_state_channel_updates_existing() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![(12345, 10), (67890, 5)],
        };
        UpdateStateChange::Channel { id: 12345, pts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.channels[0], (12345, 99));
        assert_eq!(snap.channels[1], (67890, 5));
    }

    #[test]
    fn update_state_all_replaces_snapshot() {
        let mut snap = UpdatesStateSnap {
            pts: 1,
            qts: 1,
            date: 1,
            seq: 1,
            channels: vec![(1, 1)],
        };
        let replacement = UpdatesStateSnap {
            pts: 9,
            qts: 8,
            date: 7,
            seq: 6,
            channels: vec![],
        };
        UpdateStateChange::All(replacement).apply_to(&mut snap);
        assert_eq!((snap.pts, snap.qts, snap.date, snap.seq), (9, 8, 7, 6));
        assert!(snap.channels.is_empty());
    }

    #[test]
    fn channel_pts_helpers() {
        let mut snap = UpdatesStateSnap::default();
        assert!(!snap.is_initialised());
        assert_eq!(snap.channel_pts(7), 0);
        snap.set_channel_pts(7, 3);
        snap.set_channel_pts(7, 4);
        assert_eq!(snap.channel_pts(7), 4);
        assert_eq!(snap.channels.len(), 1);
    }

    #[test]
    fn apply_update_state_via_backend() {
        let b = InMemoryBackend::new();
        b.apply_update_state(UpdateStateChange::Primary {
            pts: 7,
            date: 8,
            seq: 9,
        })
        .unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.updates_state.pts, 7);
    }

    #[test]
    fn default_update_dc_via_trait_object() {
        let b: Box<dyn SessionBackend> = Box::new(InMemoryBackend::new());
        b.update_dc(&make_dc(1)).unwrap();
        b.update_dc(&make_dc(2)).unwrap();
        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.dcs.len(), 2);
    }

    // ---- DcFlags / DcEntry -----------------------------------------------------

    #[test]
    fn dc_flags_contains_and_bitor() {
        let f = DcFlags::IPV6 | DcFlags::STATIC;
        assert!(f.contains(DcFlags::IPV6));
        assert!(f.contains(DcFlags::STATIC));
        assert!(!f.contains(DcFlags::CDN));
        assert!(f.contains(DcFlags::NONE));
        let mut g = DcFlags::NONE;
        g.set(DcFlags::TCPO_ONLY);
        assert_eq!(g, DcFlags::TCPO_ONLY);
    }

    #[test]
    fn dc_entry_from_parts_ipv4() {
        let dc = DcEntry::from_parts(1, "149.154.175.53", 443, DcFlags::NONE);
        assert_eq!(dc.addr, "149.154.175.53:443");
        assert!(!dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    #[test]
    fn dc_entry_from_parts_ipv6() {
        let dc = DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6);
        assert_eq!(dc.addr, "[2001:b28:f23d:f001::a]:443");
        assert!(dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    // ---- PersistedSession lookups ----------------------------------------------

    #[test]
    fn persisted_session_dc_for_prefers_ipv6() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(2));
        s.dcs.push(make_dc_v6(2));

        let v6 = s.dc_for(2, true).unwrap();
        assert!(v6.is_ipv6());

        let v4 = s.dc_for(2, false).unwrap();
        assert!(!v4.is_ipv6());
    }

    #[test]
    fn persisted_session_dc_for_falls_back_when_only_ipv4() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(3));
        let dc = s.dc_for(3, true).unwrap();
        assert!(!dc.is_ipv6());
    }

    #[test]
    fn persisted_session_dc_for_falls_back_when_only_ipv6() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc_v6(3));
        let dc = s.dc_for(3, false).unwrap();
        assert!(dc.is_ipv6());
    }

    #[test]
    fn persisted_session_dc_for_unknown_dc_is_none() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(1));
        assert!(s.dc_for(2, false).is_none());
        assert!(s.dc_for(2, true).is_none());
    }

    #[test]
    fn persisted_session_all_dcs_for_returns_both() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(1));
        s.dcs.push(make_dc_v6(1));
        s.dcs.push(make_dc(2));

        assert_eq!(s.all_dcs_for(1).count(), 2);
        assert_eq!(s.all_dcs_for(2).count(), 1);
        assert_eq!(s.all_dcs_for(5).count(), 0);
    }

    // ---- Binary format ---------------------------------------------------------

    #[test]
    fn binary_roundtrip_ipv4_and_ipv6() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        s.dcs.push(make_dc(2));
        s.dcs.push(make_dc_v6(2));

        let bytes = s.to_bytes();
        let loaded = PersistedSession::from_bytes(&bytes).unwrap();
        assert_eq!(loaded.dcs.len(), 2);
        assert_eq!(loaded.dcs.iter().filter(|d| d.is_ipv6()).count(), 1);
        assert_eq!(loaded.dcs.iter().filter(|d| !d.is_ipv6()).count(), 1);
    }

    #[test]
    fn binary_roundtrip_full_session() {
        let mut dc = make_dc(4);
        dc.auth_key = Some([7u8; 256]);
        dc.first_salt = -42;
        dc.time_offset = 13;
        dc.flags = DcFlags::MEDIA_ONLY | DcFlags::CDN;

        let mut s = PersistedSession::default();
        s.home_dc_id = 4;
        s.dcs.push(dc);
        s.updates_state = UpdatesStateSnap {
            pts: 1,
            qts: 2,
            date: 3,
            seq: 4,
            channels: vec![(-100, 5), (200, 6)],
        };
        s.peers.push(make_peer(1, 11));
        s.peers.push(CachedPeer {
            id: 2,
            access_hash: 22,
            is_channel: true,
            is_chat: false,
        });
        s.peers.push(CachedPeer {
            id: 3,
            access_hash: 33,
            is_channel: false,
            is_chat: true,
        });
        s.min_peers.push(CachedMinPeer {
            user_id: 9,
            peer_id: 8,
            msg_id: 7,
        });

        let loaded = PersistedSession::from_bytes(&s.to_bytes()).unwrap();
        assert_eq!(loaded.home_dc_id, 4);

        let d = &loaded.dcs[0];
        assert_eq!(d.dc_id, 4);
        assert_eq!(d.addr, "1.2.3.4:443");
        assert_eq!(d.auth_key, Some([7u8; 256]));
        assert_eq!(d.first_salt, -42);
        assert_eq!(d.time_offset, 13);
        assert_eq!(d.flags, DcFlags::MEDIA_ONLY | DcFlags::CDN);

        let us = &loaded.updates_state;
        assert_eq!((us.pts, us.qts, us.date, us.seq), (1, 2, 3, 4));
        assert_eq!(us.channels, vec![(-100, 5), (200, 6)]);

        let peers: Vec<(i64, i64, bool, bool)> = loaded
            .peers
            .iter()
            .map(|p| (p.id, p.access_hash, p.is_channel, p.is_chat))
            .collect();
        assert_eq!(
            peers,
            vec![(1, 11, false, false), (2, 22, true, false), (3, 33, false, true)]
        );

        assert_eq!(loaded.min_peers.len(), 1);
        let m = &loaded.min_peers[0];
        assert_eq!((m.user_id, m.peer_id, m.msg_id), (9, 8, 7));
    }

    #[test]
    fn from_bytes_rejects_empty_and_truncated_input() {
        let err = PersistedSession::from_bytes(&[]).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);

        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(2));
        let bytes = s.to_bytes();
        for cut in [1, 5, bytes.len() - 1] {
            let err = PersistedSession::from_bytes(&bytes[..cut]).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }
    }

    #[test]
    fn from_bytes_reads_legacy_v1_layout() {
        // v1 has no version byte: it starts directly with the LE home_dc_id.
        let mut bytes = 1i32.to_le_bytes().to_vec();
        bytes.push(0); // dc_count
        let s = PersistedSession::from_bytes(&bytes).unwrap();
        assert_eq!(s.home_dc_id, 1);
        assert!(s.dcs.is_empty());
        assert!(!s.updates_state.is_initialised());
        assert!(s.peers.is_empty());
        assert!(s.min_peers.is_empty());
    }

    // ---- String form -----------------------------------------------------------

    #[test]
    fn display_and_from_string_round_trip() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 5;
        s.dcs.push(make_dc_v6(5));
        let encoded = s.to_string();
        let decoded = PersistedSession::from_string(&format!("  {encoded}\n")).unwrap();
        assert_eq!(decoded.home_dc_id, 5);
        assert!(decoded.dcs[0].is_ipv6());
    }

    #[test]
    fn string_backend_save_load_delete() {
        let b = StringSessionBackend::new("");
        assert!(b.load().unwrap().is_none());

        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        b.save(&s).unwrap();
        assert_eq!(b.current(), s.to_string());
        assert_eq!(b.load().unwrap().unwrap().home_dc_id, 2);

        b.delete().unwrap();
        assert!(b.current().is_empty());
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn string_backend_rejects_garbage() {
        let b = StringSessionBackend::new("not base64 !!");
        assert!(b.load().is_err());
    }

    // ---- BinaryFileBackend -----------------------------------------------------

    #[test]
    fn binary_file_backend_round_trip_and_corrupt_recovery() {
        let path = unique_temp_path("binary");
        let bak = path.with_extension("bak");
        let b = BinaryFileBackend::new(path.clone());

        assert!(b.load().unwrap().is_none());

        let mut s = PersistedSession::default();
        s.home_dc_id = 3;
        s.dcs.push(make_dc(3));
        b.save(&s).unwrap();
        assert_eq!(b.load().unwrap().unwrap().home_dc_id, 3);

        // A version byte with nothing after it is unparsable, so it gets quarantined.
        std::fs::write(&path, [0x05u8]).unwrap();
        assert!(b.load().unwrap().is_none());
        assert!(bak.exists());
        assert!(!path.exists());

        // Deleting an absent file is not an error.
        b.delete().unwrap();
        let _ = std::fs::remove_file(&bak);
    }
}
1233
/// `SessionBackend` that stores the session in an SQLite database, with one table per
/// section (`meta`, `dcs`, `update_state`, `channel_pts`, `peers`, `min_peers`).
#[cfg(feature = "sqlite-session")]
pub struct SqliteBackend {
    /// The single connection used for all reads and writes; each backend operation
    /// locks it for the duration of the call.
    conn: std::sync::Mutex<rusqlite::Connection>,
    /// Value returned by `name()`: the database path as text, or `":memory:"`.
    label: String,
}
1264
#[cfg(feature = "sqlite-session")]
impl SqliteBackend {
    /// Schema applied every time a connection is opened.
    ///
    /// It sets WAL journaling with `synchronous = NORMAL`. Every table is created with
    /// `IF NOT EXISTS`, so re-applying it to an existing database changes nothing.
    const SCHEMA: &'static str = "
        PRAGMA journal_mode = WAL;
        PRAGMA synchronous = NORMAL;

        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );

        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );

        CREATE TABLE IF NOT EXISTS peers (
            id          INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel  INTEGER NOT NULL DEFAULT 0,
            is_chat     INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS min_peers (
            user_id INTEGER PRIMARY KEY,
            peer_id INTEGER NOT NULL,
            msg_id  INTEGER NOT NULL
        );
    ";

    /// Opens the database at `path` (created by SQLite if missing) and prepares the schema.
    ///
    /// # Errors
    /// SQLite failures while opening, creating tables or migrating are returned as
    /// `io::Error`s of kind `Other`.
    pub fn open(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let conn = rusqlite::Connection::open(&path).map_err(Self::map_err)?;
        Self::from_connection(conn, label)
    }

    /// Opens a private in-memory database; its contents are lost when the backend is dropped.
    ///
    /// # Errors
    /// Same as `open`.
    pub fn in_memory() -> io::Result<Self> {
        let conn = rusqlite::Connection::open_in_memory().map_err(Self::map_err)?;
        Self::from_connection(conn, ":memory:".into())
    }

    /// Shared tail of `open` / `in_memory`: apply the schema, migrate, and wrap.
    fn from_connection(conn: rusqlite::Connection, label: String) -> io::Result<Self> {
        conn.execute_batch(Self::SCHEMA).map_err(Self::map_err)?;
        Self::migrate_legacy_sqlite_schema(&conn)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label,
        })
    }

    /// Converts an SQLite error into an `io::Error` of kind `Other`.
    fn map_err(e: rusqlite::Error) -> io::Error {
        io::Error::other(e)
    }

    /// Upgrades databases created by older versions of this backend.
    ///
    /// Adds `peers.is_chat` when the column is missing and ensures the `min_peers`
    /// table exists. Running it on an up-to-date database changes nothing.
    fn migrate_legacy_sqlite_schema(conn: &rusqlite::Connection) -> io::Result<()> {
        let mut stmt = conn
            .prepare("PRAGMA table_info(peers)")
            .map_err(Self::map_err)?;
        // Column 1 of `table_info` is the column name; unreadable rows are skipped.
        let has_is_chat = stmt
            .query_map([], |row| row.get::<_, String>(1))
            .map_err(Self::map_err)?
            .flatten()
            .any(|name| name == "is_chat");
        drop(stmt);

        if !has_is_chat {
            conn.execute_batch("ALTER TABLE peers ADD COLUMN is_chat INTEGER NOT NULL DEFAULT 0;")
                .map_err(Self::map_err)?;
        }
        conn.execute_batch(
            "CREATE TABLE IF NOT EXISTS min_peers (
                user_id INTEGER PRIMARY KEY,
                peer_id INTEGER NOT NULL,
                msg_id  INTEGER NOT NULL
            );",
        )
        .map_err(Self::map_err)?;
        Ok(())
    }

    /// Reads every table back into a `PersistedSession`.
    ///
    /// Reading is best effort per row. A missing or unreadable `home_dc_id` yields 0, a
    /// missing `update_state` row yields the default counters, and rows that fail to
    /// decode are skipped. Errors preparing or running a query are returned.
    fn read_session(conn: &rusqlite::Connection) -> io::Result<PersistedSession> {
        let home_dc_id: i32 = conn
            .query_row("SELECT value FROM meta WHERE key = 'home_dc_id'", [], |r| {
                r.get(0)
            })
            .unwrap_or(0);

        let mut dc_stmt = conn
            .prepare("SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs")
            .map_err(Self::map_err)?;
        let dcs: Vec<DcEntry> = dc_stmt
            .query_map([], |row| {
                let key_blob: Option<Vec<u8>> = row.get(3)?;
                Ok(DcEntry {
                    dc_id: row.get(0)?,
                    addr: row.get(2)?,
                    // A stored blob of any length other than 256 is treated as "no key".
                    auth_key: key_blob.and_then(|blob| <[u8; 256]>::try_from(blob).ok()),
                    first_salt: row.get(4)?,
                    time_offset: row.get(5)?,
                    flags: DcFlags(row.get(1)?),
                })
            })
            .map_err(Self::map_err)?
            .flatten()
            .collect();

        let updates_state = conn
            .query_row(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                [],
                |r| {
                    Ok(UpdatesStateSnap {
                        pts: r.get(0)?,
                        qts: r.get(1)?,
                        date: r.get(2)?,
                        seq: r.get(3)?,
                        channels: vec![],
                    })
                },
            )
            .unwrap_or_default();

        let mut ch_stmt = conn
            .prepare("SELECT channel_id, pts FROM channel_pts")
            .map_err(Self::map_err)?;
        let channels: Vec<(i64, i32)> = ch_stmt
            .query_map([], |r| Ok((r.get::<_, i64>(0)?, r.get::<_, i32>(1)?)))
            .map_err(Self::map_err)?
            .flatten()
            .collect();

        let mut peer_stmt = conn
            .prepare("SELECT id, access_hash, is_channel, is_chat FROM peers")
            .map_err(Self::map_err)?;
        let peers: Vec<CachedPeer> = peer_stmt
            .query_map([], |r| {
                Ok(CachedPeer {
                    id: r.get(0)?,
                    access_hash: r.get(1)?,
                    is_channel: r.get::<_, i32>(2)? != 0,
                    is_chat: r.get::<_, i32>(3)? != 0,
                })
            })
            .map_err(Self::map_err)?
            .flatten()
            .collect();

        let mut min_stmt = conn
            .prepare("SELECT user_id, peer_id, msg_id FROM min_peers")
            .map_err(Self::map_err)?;
        let min_peers: Vec<CachedMinPeer> = min_stmt
            .query_map([], |r| {
                Ok(CachedMinPeer {
                    user_id: r.get(0)?,
                    peer_id: r.get(1)?,
                    msg_id: r.get(2)?,
                })
            })
            .map_err(Self::map_err)?
            .flatten()
            .collect();

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers,
        })
    }

    /// Replaces the stored session with `s` inside one `BEGIN IMMEDIATE` transaction.
    ///
    /// If any statement or the final COMMIT fails, the transaction is rolled back
    /// (best effort) before the error is returned. Otherwise the connection would stay
    /// inside an open transaction, and every later `BEGIN` on it would fail.
    fn write_session(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        conn.execute_batch("BEGIN IMMEDIATE")
            .map_err(Self::map_err)?;

        let outcome = Self::write_session_rows(conn, s)
            .and_then(|()| conn.execute_batch("COMMIT").map_err(Self::map_err));
        if outcome.is_err() {
            // Keep the original error. ROLLBACK itself may fail harmlessly if SQLite has
            // already ended the transaction.
            let _ = conn.execute_batch("ROLLBACK");
        }
        outcome
    }

    /// Writes every section of `s`; must run inside the transaction opened by
    /// `write_session`.
    ///
    /// * `dcs`, `channel_pts`, `peers` and `min_peers` are cleared and rewritten. Rows are
    ///   inserted with `OR REPLACE`, so when `s` holds duplicates of a table's primary
    ///   key the last one wins instead of the whole save failing.
    /// * `update_state.pts` / `qts` never move backwards (`MAX` of old and new).
    ///   `date` and `seq` are overwritten.
    fn write_session_rows(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![s.home_dc_id],
        )
        .map_err(Self::map_err)?;

        conn.execute("DELETE FROM dcs", []).map_err(Self::map_err)?;
        let mut insert_dc = conn
            .prepare(
                "INSERT OR REPLACE INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
            )
            .map_err(Self::map_err)?;
        for d in &s.dcs {
            insert_dc
                .execute(rusqlite::params![
                    d.dc_id,
                    d.flags.0,
                    d.addr,
                    d.auth_key.as_ref().map(|k| &k[..]),
                    d.first_salt,
                    d.time_offset,
                ])
                .map_err(Self::map_err)?;
        }

        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1, ?1, ?2, ?3, ?4)
             ON CONFLICT(id) DO UPDATE SET
                 pts  = MAX(excluded.pts, update_state.pts),
                 qts  = MAX(excluded.qts, update_state.qts),
                 date = excluded.date,
                 seq  = excluded.seq",
            rusqlite::params![us.pts, us.qts, us.date, us.seq],
        )
        .map_err(Self::map_err)?;

        conn.execute("DELETE FROM channel_pts", [])
            .map_err(Self::map_err)?;
        let mut insert_channel = conn
            .prepare("INSERT OR REPLACE INTO channel_pts (channel_id, pts) VALUES (?1, ?2)")
            .map_err(Self::map_err)?;
        for &(cid, cpts) in &us.channels {
            insert_channel
                .execute(rusqlite::params![cid, cpts])
                .map_err(Self::map_err)?;
        }

        conn.execute("DELETE FROM peers", [])
            .map_err(Self::map_err)?;
        let mut insert_peer = conn
            .prepare(
                "INSERT OR REPLACE INTO peers (id, access_hash, is_channel, is_chat)
                 VALUES (?1, ?2, ?3, ?4)",
            )
            .map_err(Self::map_err)?;
        for p in &s.peers {
            insert_peer
                .execute(rusqlite::params![
                    p.id,
                    p.access_hash,
                    p.is_channel as i32,
                    p.is_chat as i32
                ])
                .map_err(Self::map_err)?;
        }

        conn.execute("DELETE FROM min_peers", [])
            .map_err(Self::map_err)?;
        let mut insert_min = conn
            .prepare(
                "INSERT OR REPLACE INTO min_peers (user_id, peer_id, msg_id) VALUES (?1, ?2, ?3)",
            )
            .map_err(Self::map_err)?;
        for m in &s.min_peers {
            insert_min
                .execute(rusqlite::params![m.user_id, m.peer_id, m.msg_id])
                .map_err(Self::map_err)?;
        }

        Ok(())
    }
}
1573
/// Runs `body` inside a `BEGIN IMMEDIATE` transaction on `conn`.
///
/// Commits when `body` succeeds. When `body` or the `COMMIT` itself fails, a
/// best-effort `ROLLBACK` is issued and the first error is returned, so the
/// connection is never left inside an open transaction (which would make every
/// later `BEGIN` on it fail and silently absorb subsequent single-statement
/// writes into a transaction that never commits).
#[cfg(feature = "sqlite-session")]
fn sqlite_immediate_tx<T>(
    conn: &rusqlite::Connection,
    body: impl FnOnce(&rusqlite::Connection) -> rusqlite::Result<T>,
) -> rusqlite::Result<T> {
    conn.execute_batch("BEGIN IMMEDIATE")?;
    let result = body(conn).and_then(|value| conn.execute_batch("COMMIT").map(|()| value));
    if result.is_err() {
        // The original error is the one worth reporting; ROLLBACK fails
        // harmlessly if SQLite already ended the transaction itself.
        let _ = conn.execute_batch("ROLLBACK");
    }
    result
}

#[cfg(feature = "sqlite-session")]
impl SessionBackend for SqliteBackend {
    /// Persists `session`, replacing whatever was stored (via `write_session`).
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        Self::write_session(&conn, session)
    }

    /// Loads the stored session; returns `Ok(None)` when the `meta` table has
    /// no rows.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = self.conn.lock().unwrap();
        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM meta", [], |r| r.get(0))
            .map_err(Self::map_err)?;
        if count == 0 {
            return Ok(None);
        }
        Self::read_session(&conn).map(Some)
    }

    /// Clears every session table inside a single transaction; if any DELETE
    /// or the COMMIT fails, the transaction is rolled back.
    fn delete(&self) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        sqlite_immediate_tx(&conn, |c| {
            c.execute_batch(
                "DELETE FROM meta;
                 DELETE FROM dcs;
                 DELETE FROM update_state;
                 DELETE FROM channel_pts;
                 DELETE FROM peers;
                 DELETE FROM min_peers;",
            )
        })
        .map_err(Self::map_err)
    }

    /// Returns the label this backend was opened with.
    fn name(&self) -> &str {
        &self.label
    }

    /// Inserts or updates one DC row, keyed by `(dc_id, flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
             ON CONFLICT(dc_id, flags) DO UPDATE SET
                 addr = excluded.addr,
                 auth_key = excluded.auth_key,
                 first_salt = excluded.first_salt,
                 time_offset = excluded.time_offset",
            rusqlite::params![
                entry.dc_id,
                entry.flags.0,
                entry.addr,
                entry.auth_key.as_ref().map(|k| k.as_slice()),
                entry.first_salt,
                entry.time_offset,
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Upserts the `home_dc_id` entry in `meta`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![dc_id],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Applies one incremental change to the stored update state.
    ///
    /// `All` overwrites the global counters and replaces the whole
    /// `channel_pts` table atomically; the other variants touch only the
    /// columns/rows they carry.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        match update {
            UpdateStateChange::All(snap) => sqlite_immediate_tx(&conn, |c| {
                c.execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
                     ON CONFLICT(id) DO UPDATE SET
                         pts=excluded.pts, qts=excluded.qts,
                         date=excluded.date, seq=excluded.seq",
                    rusqlite::params![snap.pts, snap.qts, snap.date, snap.seq],
                )?;
                // Delete + reinsert must happen together, otherwise a failure
                // in between would lose every channel's pts.
                c.execute("DELETE FROM channel_pts", [])?;
                for &(cid, cpts) in &snap.channels {
                    c.execute(
                        "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
                        rusqlite::params![cid, cpts],
                    )?;
                }
                Ok(())
            })
            .map_err(Self::map_err),
            UpdateStateChange::Primary { pts, date, seq } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,0,?2,?3)
                     ON CONFLICT(id) DO UPDATE SET pts=excluded.pts, date=excluded.date,
                         seq=excluded.seq",
                    rusqlite::params![pts, date, seq],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Secondary { qts } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,0,?1,0,0)
                     ON CONFLICT(id) DO UPDATE SET qts = excluded.qts",
                    rusqlite::params![qts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Channel { id, pts } => conn
                .execute(
                    "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)
                     ON CONFLICT(channel_id) DO UPDATE SET pts = excluded.pts",
                    rusqlite::params![id, pts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
        }
    }

    /// Inserts or updates one cached peer, keyed by `id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO peers (id, access_hash, is_channel, is_chat) VALUES (?1, ?2, ?3, ?4)
             ON CONFLICT(id) DO UPDATE SET
                 access_hash = excluded.access_hash,
                 is_channel = excluded.is_channel,
                 is_chat = excluded.is_chat",
            rusqlite::params![
                peer.id,
                peer.access_hash,
                peer.is_channel as i32,
                peer.is_chat as i32
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }
}
1718
/// [`SessionBackend`] implementation backed by a libSQL database.
///
/// Constructed with `open_local`, `in_memory`, `open_remote` or
/// `open_replica`. The single connection is shared behind an async
/// [`tokio::sync::Mutex`]; the synchronous `SessionBackend` methods lock it
/// while driving the async libSQL calls to completion on the current Tokio
/// runtime.
#[cfg(feature = "libsql-session")]
pub struct LibSqlBackend {
    /// The one libSQL connection. Every constructor stores it as
    /// `Arc<tokio::sync::Mutex<_>>`, and every `SessionBackend` method locks
    /// it with `.lock().await`, so the field must have exactly this type.
    conn: std::sync::Arc<tokio::sync::Mutex<libsql::Connection>>,
    /// Human-readable description returned by `SessionBackend::name`
    /// (file path, URL, `:memory:`, or `"<path> (replica of <url>)"`).
    label: String,
}
1743
#[cfg(feature = "libsql-session")]
impl LibSqlBackend {
    /// Idempotent DDL (`CREATE TABLE IF NOT EXISTS …`) applied by every
    /// constructor.
    ///
    /// `dcs` is keyed by `(dc_id, flags)`, so one DC id may hold several rows
    /// that differ only in flags; `update_state` holds at most one row, pinned
    /// to `id = 1` by its `CHECK` constraint.
    const SCHEMA: &'static str = "
        CREATE TABLE IF NOT EXISTS meta (
            key TEXT PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS dcs (
            dc_id INTEGER NOT NULL,
            flags INTEGER NOT NULL DEFAULT 0,
            addr TEXT NOT NULL,
            auth_key BLOB,
            first_salt INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );
        CREATE TABLE IF NOT EXISTS update_state (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            pts INTEGER NOT NULL DEFAULT 0,
            qts INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts INTEGER NOT NULL
        );
        CREATE TABLE IF NOT EXISTS peers (
            id INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel INTEGER NOT NULL DEFAULT 0,
            is_chat INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS min_peers (
            user_id INTEGER PRIMARY KEY,
            peer_id INTEGER NOT NULL,
            msg_id INTEGER NOT NULL
        );
    ";

    /// Runs `fut` to completion on the Tokio runtime the calling thread is
    /// attached to, converting `libsql::Error` into `io::Error`.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` when the calling thread has no Tokio runtime
    /// context (instead of panicking, as `Handle::current` would), and
    /// propagates any error produced by `fut`.
    ///
    /// # Panics
    ///
    /// `Handle::block_on` panics when invoked from within an asynchronous
    /// execution context (e.g. directly inside a task running on the
    /// runtime). Callers in async code should go through a thread that may
    /// block, such as `tokio::task::spawn_blocking`.
    fn block<F, T>(fut: F) -> io::Result<T>
    where
        F: std::future::Future<Output = Result<T, libsql::Error>>,
    {
        let handle = tokio::runtime::Handle::try_current().map_err(io::Error::other)?;
        handle.block_on(fut).map_err(io::Error::other)
    }

    /// Executes [`Self::SCHEMA`] on `conn`.
    async fn apply_schema(conn: &libsql::Connection) -> Result<(), libsql::Error> {
        conn.execute_batch(Self::SCHEMA).await
    }

    /// Connects to an already-built `db`, applies the schema and wraps the
    /// connection for shared use. Shared tail of every constructor.
    fn from_database(db: libsql::Database, label: String) -> io::Result<Self> {
        // `connect` is synchronous; no need to route it through `block`.
        let conn = db.connect().map_err(io::Error::other)?;
        Self::block(Self::apply_schema(&conn))?;
        Ok(Self {
            conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
            label,
        })
    }

    /// Opens (or creates) a local database file at `path`.
    ///
    /// # Errors
    ///
    /// Fails if no Tokio runtime is available, or if opening, connecting or
    /// applying the schema fails.
    pub fn open_local(path: impl Into<std::path::PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let db = Self::block(libsql::Builder::new_local(path).build())?;
        Self::from_database(db, label)
    }

    /// Opens a private in-memory database (labelled `:memory:`).
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::open_local`].
    pub fn in_memory() -> io::Result<Self> {
        let db = Self::block(libsql::Builder::new_local(":memory:").build())?;
        Self::from_database(db, ":memory:".into())
    }

    /// Connects to a remote libSQL server at `url` using `auth_token`.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::open_local`].
    pub fn open_remote(url: impl Into<String>, auth_token: impl Into<String>) -> io::Result<Self> {
        let url = url.into();
        let label = url.clone();
        let db = Self::block(libsql::Builder::new_remote(url, auth_token.into()).build())?;
        Self::from_database(db, label)
    }

    /// Opens an embedded replica stored at `path` that mirrors the remote
    /// database at `url`.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::open_local`].
    pub fn open_replica(
        path: impl Into<std::path::PathBuf>,
        url: impl Into<String>,
        auth_token: impl Into<String>,
    ) -> io::Result<Self> {
        let path = path.into();
        // Convert once: `url` is used both for the label and for the builder.
        let url = url.into();
        let label = format!("{} (replica of {})", path.display(), url);
        let db = Self::block(
            libsql::Builder::new_remote_replica(path, url, auth_token.into()).build(),
        )?;
        Self::from_database(db, label)
    }

    /// Reads every session table into a [`PersistedSession`].
    ///
    /// A missing `home_dc_id` row yields `0`, and a missing `update_state` row
    /// yields `UpdatesStateSnap::default()`. Channel pts come from
    /// `channel_pts`.
    ///
    /// # Errors
    ///
    /// Propagates libSQL query/decoding errors, and returns
    /// `libsql::Error::Misuse` when a stored `auth_key` is not exactly 256
    /// bytes or a stored `flags` value does not fit in a `u8`.
    async fn read_session_async(
        conn: &libsql::Connection,
    ) -> Result<PersistedSession, libsql::Error> {
        let home_dc_id: i32 = conn
            .query("SELECT value FROM meta WHERE key = 'home_dc_id'", ())
            .await?
            .next()
            .await?
            .map(|r| r.get::<i32>(0))
            .transpose()?
            .unwrap_or(0);

        let mut rows = conn
            .query(
                "SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs",
                (),
            )
            .await?;
        let mut dcs = Vec::new();
        while let Some(row) = rows.next().await? {
            let dc_id: i32 = row.get(0)?;
            let raw_flags: i64 = row.get(1)?;
            // Reject rather than truncate: a wrapped value would silently
            // alias another `(dc_id, flags)` key.
            let flags = u8::try_from(raw_flags).map_err(|_| {
                libsql::Error::Misuse(format!("dc flags must fit in a u8, got {raw_flags}"))
            })?;
            let addr: String = row.get(2)?;
            let key_blob: Option<Vec<u8>> = row.get(3)?;
            let first_salt: i64 = row.get(4)?;
            let time_offset: i32 = row.get(5)?;
            let auth_key = key_blob
                .map(|blob| {
                    <[u8; 256]>::try_from(blob).map_err(|blob| {
                        libsql::Error::Misuse(format!(
                            "auth_key blob must be 256 bytes, got {}",
                            blob.len()
                        ))
                    })
                })
                .transpose()?;
            dcs.push(DcEntry {
                dc_id,
                addr,
                auth_key,
                first_salt,
                time_offset,
                flags: DcFlags(flags),
            });
        }

        let mut us_row = conn
            .query("SELECT pts, qts, date, seq FROM update_state WHERE id = 1", ())
            .await?;
        let updates_state = match us_row.next().await? {
            Some(r) => UpdatesStateSnap {
                pts: r.get(0)?,
                qts: r.get(1)?,
                date: r.get(2)?,
                seq: r.get(3)?,
                channels: vec![],
            },
            None => UpdatesStateSnap::default(),
        };

        let mut ch_rows = conn
            .query("SELECT channel_id, pts FROM channel_pts", ())
            .await?;
        let mut channels = Vec::new();
        while let Some(r) = ch_rows.next().await? {
            channels.push((r.get::<i64>(0)?, r.get::<i32>(1)?));
        }

        let mut peer_rows = conn
            .query("SELECT id, access_hash, is_channel, is_chat FROM peers", ())
            .await?;
        let mut peers = Vec::new();
        while let Some(r) = peer_rows.next().await? {
            peers.push(CachedPeer {
                id: r.get(0)?,
                access_hash: r.get(1)?,
                is_channel: r.get::<i32>(2)? != 0,
                is_chat: r.get::<i32>(3)? != 0,
            });
        }

        let mut min_rows = conn
            .query("SELECT user_id, peer_id, msg_id FROM min_peers", ())
            .await?;
        let mut min_peers = Vec::new();
        while let Some(r) = min_rows.next().await? {
            min_peers.push(CachedMinPeer {
                user_id: r.get(0)?,
                peer_id: r.get(1)?,
                msg_id: r.get(2)?,
            });
        }

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers,
        })
    }

    /// Replaces the entire persisted session with `s` inside one
    /// `BEGIN IMMEDIATE` transaction.
    ///
    /// # Errors
    ///
    /// Returns the first libSQL error. On any failure after `BEGIN` —
    /// including a failed `COMMIT` — a best-effort `ROLLBACK` is issued so
    /// the shared connection is not left inside an open transaction.
    async fn write_session_async(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute_batch("BEGIN IMMEDIATE").await?;
        let result = match Self::write_session_rows(conn, s).await {
            Ok(()) => conn.execute_batch("COMMIT").await.map(|_| ()),
            Err(e) => Err(e),
        };
        if result.is_err() {
            // The original error is the useful one; ROLLBACK fails harmlessly
            // when no transaction is active any more.
            let _ = conn.execute_batch("ROLLBACK").await;
        }
        result
    }

    /// Writes every table of `s`; must run inside a transaction opened by
    /// [`Self::write_session_async`].
    ///
    /// `meta.home_dc_id` is upserted; `dcs`, `channel_pts`, `peers` and
    /// `min_peers` are cleared and refilled. The `update_state` row keeps the
    /// larger of stored and new `pts`/`qts`, and `date`/`seq` are overwritten.
    async fn write_session_rows(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            libsql::params![s.home_dc_id],
        )
        .await?;

        conn.execute("DELETE FROM dcs", ()).await?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                libsql::params![
                    d.dc_id,
                    i64::from(d.flags.0),
                    d.addr.clone(),
                    d.auth_key.map(|k| k.to_vec()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .await?;
        }

        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
             ON CONFLICT(id) DO UPDATE SET
                 pts = MAX(excluded.pts, update_state.pts),
                 qts = MAX(excluded.qts, update_state.qts),
                 date = excluded.date,
                 seq = excluded.seq",
            libsql::params![us.pts, us.qts, us.date, us.seq],
        )
        .await?;

        conn.execute("DELETE FROM channel_pts", ()).await?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1,?2)",
                libsql::params![cid, cpts],
            )
            .await?;
        }

        conn.execute("DELETE FROM peers", ()).await?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel, is_chat) VALUES (?1,?2,?3,?4)",
                libsql::params![p.id, p.access_hash, p.is_channel as i32, p.is_chat as i32],
            )
            .await?;
        }

        conn.execute("DELETE FROM min_peers", ()).await?;
        for m in &s.min_peers {
            conn.execute(
                "INSERT INTO min_peers (user_id, peer_id, msg_id) VALUES (?1,?2,?3)",
                libsql::params![m.user_id, m.peer_id, m.msg_id],
            )
            .await?;
        }

        Ok(())
    }
}
2052
/// Finishes a transaction that the caller opened with `BEGIN IMMEDIATE` on
/// `conn`.
///
/// If `body` (the outcome of the statements run inside the transaction) is
/// `Ok`, the transaction is committed. If `body` failed, or the `COMMIT`
/// itself fails, a best-effort `ROLLBACK` is issued so the connection is not
/// left inside an open transaction, and the first error is returned.
#[cfg(feature = "libsql-session")]
async fn libsql_finish_tx(
    conn: &libsql::Connection,
    body: Result<(), libsql::Error>,
) -> Result<(), libsql::Error> {
    let result = match body {
        Ok(()) => conn.execute_batch("COMMIT").await.map(|_| ()),
        Err(e) => Err(e),
    };
    if result.is_err() {
        // The original error is the useful one; ROLLBACK fails harmlessly when
        // no transaction is active any more.
        let _ = conn.execute_batch("ROLLBACK").await;
    }
    result
}

#[cfg(feature = "libsql-session")]
impl SessionBackend for LibSqlBackend {
    /// Persists `session`, replacing whatever was stored.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        // `block_on` imposes no `'static` bound, so the session can be
        // borrowed instead of cloned.
        Self::block(async {
            let conn = self.conn.lock().await;
            Self::write_session_async(&conn, session).await
        })
    }

    /// Loads the stored session; returns `Ok(None)` when the `meta` table has
    /// no rows. The count and the read happen under one lock acquisition.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        Self::block(async {
            let conn = self.conn.lock().await;
            let mut rows = conn.query("SELECT COUNT(*) FROM meta", ()).await?;
            // Propagate decode errors instead of treating them as "no session".
            let count: i64 = match rows.next().await? {
                Some(row) => row.get(0)?,
                None => 0,
            };
            drop(rows);
            if count == 0 {
                Ok(None)
            } else {
                Self::read_session_async(&conn).await.map(Some)
            }
        })
    }

    /// Clears every session table inside a single transaction; if any DELETE
    /// or the COMMIT fails, the transaction is rolled back.
    fn delete(&self) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute_batch("BEGIN IMMEDIATE").await?;
            let body = conn
                .execute_batch(
                    "DELETE FROM meta;
                     DELETE FROM dcs;
                     DELETE FROM update_state;
                     DELETE FROM channel_pts;
                     DELETE FROM peers;
                     DELETE FROM min_peers;",
                )
                .await
                .map(|_| ());
            libsql_finish_tx(&conn, body).await
        })
    }

    /// Returns the label this backend was opened with.
    fn name(&self) -> &str {
        &self.label
    }

    /// Inserts or updates one DC row, keyed by `(dc_id, flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1,?2,?3,?4,?5,?6)
                 ON CONFLICT(dc_id, flags) DO UPDATE SET
                     addr=excluded.addr, auth_key=excluded.auth_key,
                     first_salt=excluded.first_salt, time_offset=excluded.time_offset",
                libsql::params![
                    entry.dc_id,
                    i64::from(entry.flags.0),
                    entry.addr.clone(),
                    entry.auth_key.map(|k| k.to_vec()),
                    entry.first_salt,
                    entry.time_offset,
                ],
            )
            .await
            .map(|_| ())
        })
    }

    /// Upserts the `home_dc_id` entry in `meta`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO meta (key, value) VALUES ('home_dc_id',?1)
                 ON CONFLICT(key) DO UPDATE SET value=excluded.value",
                libsql::params![dc_id],
            )
            .await
            .map(|_| ())
        })
    }

    /// Applies one incremental change to the stored update state.
    ///
    /// `All` overwrites the global counters and replaces the whole
    /// `channel_pts` table atomically; the other variants touch only the
    /// columns/rows they carry.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        Self::block(async move {
            let conn = self.conn.lock().await;
            match update {
                UpdateStateChange::All(snap) => {
                    conn.execute_batch("BEGIN IMMEDIATE").await?;
                    let body = async {
                        conn.execute(
                            "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,?2,?3,?4)
                             ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,qts=excluded.qts,
                             date=excluded.date,seq=excluded.seq",
                            libsql::params![snap.pts, snap.qts, snap.date, snap.seq],
                        )
                        .await?;
                        // Delete + reinsert must happen together, otherwise a
                        // failure in between would lose every channel's pts.
                        conn.execute("DELETE FROM channel_pts", ()).await?;
                        for &(cid, cpts) in &snap.channels {
                            conn.execute(
                                "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)",
                                libsql::params![cid, cpts],
                            )
                            .await?;
                        }
                        Ok::<(), libsql::Error>(())
                    }
                    .await;
                    libsql_finish_tx(&conn, body).await
                }
                UpdateStateChange::Primary { pts, date, seq } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,0,?2,?3)
                         ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,date=excluded.date,
                         seq=excluded.seq",
                        libsql::params![pts, date, seq],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Secondary { qts } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,0,?1,0,0)
                         ON CONFLICT(id) DO UPDATE SET qts=excluded.qts",
                        libsql::params![qts],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Channel { id, pts } => conn
                    .execute(
                        "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)
                         ON CONFLICT(channel_id) DO UPDATE SET pts=excluded.pts",
                        libsql::params![id, pts],
                    )
                    .await
                    .map(|_| ()),
            }
        })
    }

    /// Inserts or updates one cached peer, keyed by `id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO peers (id,access_hash,is_channel,is_chat) VALUES (?1,?2,?3,?4)
                 ON CONFLICT(id) DO UPDATE SET
                     access_hash=excluded.access_hash,
                     is_channel=excluded.is_channel,
                     is_chat=excluded.is_chat",
                libsql::params![
                    peer.id,
                    peer.access_hash,
                    peer.is_channel as i32,
                    peer.is_chat as i32
                ],
            )
            .await
            .map(|_| ())
        })
    }
}
2219}