1#![deny(unsafe_code)]
14#![cfg_attr(docsrs, feature(doc_cfg))]
15use std::collections::HashMap;
77use std::io::{self, ErrorKind};
78use std::path::Path;
79
#[cfg(feature = "serde")]
mod auth_key_serde {
    //! Serde adapter for `Option<[u8; 256]>` auth keys.
    //!
    //! A present key is written as an optional byte sequence; on the way back
    //! the sequence must contain exactly 256 bytes or deserialisation fails.
    use serde::{Deserialize, Deserializer, Serializer};

    /// Writes `Some(key)` as an optional byte slice and `None` as a missing value.
    pub fn serialize<S>(value: &Option<[u8; 256]>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(key) = value {
            s.serialize_some(&key[..])
        } else {
            s.serialize_none()
        }
    }

    /// Reads an optional byte sequence and converts it to a fixed 256-byte key.
    ///
    /// # Errors
    /// Returns a custom deserialisation error when the sequence length is not 256.
    pub fn deserialize<'de, D>(d: D) -> Result<Option<[u8; 256]>, D::Error>
    where
        D: Deserializer<'de>,
    {
        Option::<Vec<u8>>::deserialize(d)?
            .map(|bytes| {
                <[u8; 256]>::try_from(bytes).map_err(|_| {
                    serde::de::Error::custom("auth_key must be exactly 256 bytes")
                })
            })
            .transpose()
    }
}
110
/// Bit set describing properties of a DC endpoint.
///
/// Flags are combined with `|` and tested with [`DcFlags::contains`]. The raw
/// byte (`.0`) is what the binary session format stores.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// No bits set.
    pub const NONE: DcFlags = DcFlags(0);
    /// The endpoint address is an IPv6 address.
    pub const IPV6: DcFlags = DcFlags(0b0000_0001);
    // The remaining bits are named after the corresponding DC-option flags;
    // nothing in this module interprets them beyond storing the byte.
    pub const MEDIA_ONLY: DcFlags = DcFlags(0b0000_0010);
    pub const TCPO_ONLY: DcFlags = DcFlags(0b0000_0100);
    pub const CDN: DcFlags = DcFlags(0b0000_1000);
    pub const STATIC: DcFlags = DcFlags(0b0001_0000);

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// Consequently `contains(DcFlags::NONE)` is always `true`.
    pub fn contains(self, other: DcFlags) -> bool {
        // No bit of `other` may fall outside `self`.
        other.0 & !self.0 == 0
    }

    /// Turns on every bit of `flag`, leaving the other bits unchanged.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = DcFlags;

    /// Union of both flag sets.
    fn bitor(self, rhs: DcFlags) -> DcFlags {
        let DcFlags(lhs) = self;
        DcFlags(lhs | rhs.0)
    }
}
141
142#[derive(Clone, Debug)]
144#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
145pub struct DcEntry {
146 pub dc_id: i32,
147 pub addr: String,
148 #[cfg_attr(feature = "serde", serde(with = "auth_key_serde"))]
149 pub auth_key: Option<[u8; 256]>,
150 pub first_salt: i64,
151 pub time_offset: i32,
152 pub flags: DcFlags,
154}
155
156impl DcEntry {
157 #[inline]
159 pub fn is_ipv6(&self) -> bool {
160 self.flags.contains(DcFlags::IPV6)
161 }
162
163 pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
170 self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
171 io::Error::new(
172 io::ErrorKind::InvalidData,
173 format!("invalid DC address: {:?}", self.addr),
174 )
175 })
176 }
177
178 pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
191 let addr = if ip.contains(':') {
193 format!("[{ip}]:{port}")
194 } else {
195 format!("{ip}:{port}")
196 };
197 Self {
198 dc_id,
199 addr,
200 auth_key: None,
201 first_salt: 0,
202 time_offset: 0,
203 flags,
204 }
205 }
206}
207
/// Snapshot of the update-sequence counters plus per-channel pts values.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    pub pts: i32,
    pub qts: i32,
    pub date: i32,
    pub seq: i32,
    /// `(channel_id, pts)` pairs; lookups use the first pair with a matching id.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// `true` once a positive `pts` has been recorded.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        self.pts > 0
    }

    /// Overwrites the pts stored for `channel_id`, appending a new pair if the
    /// channel is not tracked yet.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self.channels.iter().position(|&(id, _)| id == channel_id) {
            Some(idx) => self.channels[idx].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Returns the pts stored for `channel_id`, or `0` if it is not tracked.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        self.channels
            .iter()
            .find_map(|&(id, pts)| (id == channel_id).then_some(pts))
            .unwrap_or(0)
    }
}
250
/// Sub-type of a channel peer, stored as a single byte in the binary format.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum ChannelKind {
    Broadcast = 0,
    Megagroup = 1,
    Gigagroup = 2,
}

impl ChannelKind {
    /// Every variant, positioned at the index equal to its discriminant.
    const BY_BYTE: [ChannelKind; 3] = [Self::Broadcast, Self::Megagroup, Self::Gigagroup];

    /// Decodes a stored byte; any value other than 0, 1 or 2 yields `None`.
    pub fn from_byte(b: u8) -> Option<Self> {
        Self::BY_BYTE.get(usize::from(b)).copied()
    }

    /// Encodes the variant as its `repr(u8)` discriminant.
    pub fn to_byte(self) -> u8 {
        self as u8
    }
}
280
/// A peer whose access hash has been cached.
///
/// In the binary session format the two booleans collapse into one type byte
/// (2 when `is_chat`, else 1 when `is_channel`, else 0), followed by the
/// channel-kind byte.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Peer id; caches in this file treat it as the unique key.
    pub id: i64,
    /// Access hash cached for `id`.
    pub access_hash: i64,
    /// `true` when the peer is a channel.
    pub is_channel: bool,
    /// `true` when the peer is a chat; takes precedence over `is_channel` when serialised.
    pub is_chat: bool,
    /// Channel sub-type, if known; written as the sentinel byte `0xFF` when `None`.
    pub channel_kind: Option<ChannelKind>,
}

/// A "min" peer reference.
///
/// NOTE(review): presumably records that `user_id` was seen in message
/// `msg_id` of `peer_id` so it can be referenced without an access hash —
/// confirm against the code that populates it.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedMinPeer {
    /// User id; the SQLite schema uses it as the primary key.
    pub user_id: i64,
    /// Id of the peer the reference was observed in.
    pub peer_id: i64,
    /// Message id the reference was observed in.
    pub msg_id: i32,
}
314
315#[derive(Clone, Debug, Default)]
317#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
318pub struct PersistedSession {
319 pub home_dc_id: i32,
320 pub dcs: Vec<DcEntry>,
321 pub updates_state: UpdatesStateSnap,
323 pub peers: Vec<CachedPeer>,
326 pub min_peers: Vec<CachedMinPeer>,
329}
330
331impl PersistedSession {
332 pub fn to_bytes(&self) -> Vec<u8> {
334 let mut b = Vec::with_capacity(512);
335
336 b.push(0x06u8); b.extend_from_slice(&self.home_dc_id.to_le_bytes());
339
340 b.push(self.dcs.len() as u8);
341 for d in &self.dcs {
342 b.extend_from_slice(&d.dc_id.to_le_bytes());
343 match &d.auth_key {
344 Some(k) => {
345 b.push(1);
346 b.extend_from_slice(k);
347 }
348 None => {
349 b.push(0);
350 }
351 }
352 b.extend_from_slice(&d.first_salt.to_le_bytes());
353 b.extend_from_slice(&d.time_offset.to_le_bytes());
354 let ab = d.addr.as_bytes();
355 b.push(ab.len() as u8);
356 b.extend_from_slice(ab);
357 b.push(d.flags.0);
358 }
359
360 b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
361 b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
362 b.extend_from_slice(&self.updates_state.date.to_le_bytes());
363 b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
364 let ch = &self.updates_state.channels;
365 b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
366 for &(cid, cpts) in ch {
367 b.extend_from_slice(&cid.to_le_bytes());
368 b.extend_from_slice(&cpts.to_le_bytes());
369 }
370
371 b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
373 for p in &self.peers {
374 b.extend_from_slice(&p.id.to_le_bytes());
375 b.extend_from_slice(&p.access_hash.to_le_bytes());
376 let peer_type: u8 = if p.is_chat {
377 2
378 } else if p.is_channel {
379 1
380 } else {
381 0
382 };
383 b.push(peer_type);
384 b.push(p.channel_kind.map(|k| k.to_byte()).unwrap_or(0xFF));
386 }
387
388 b.extend_from_slice(&(self.min_peers.len() as u16).to_le_bytes());
389 for m in &self.min_peers {
390 b.extend_from_slice(&m.user_id.to_le_bytes());
391 b.extend_from_slice(&m.peer_id.to_le_bytes());
392 b.extend_from_slice(&m.msg_id.to_le_bytes());
393 }
394
395 b
396 }
397
398 pub fn save(&self, path: &Path) -> io::Result<()> {
405 use std::sync::atomic::{AtomicU64, Ordering};
406 static SEQ: AtomicU64 = AtomicU64::new(0);
407 let n = SEQ.fetch_add(1, Ordering::Relaxed);
408 let tmp = path.with_extension(format!("{n}.tmp"));
409 std::fs::write(&tmp, self.to_bytes())?;
410 std::fs::rename(&tmp, path).inspect_err(|_e| {
411 let _ = std::fs::remove_file(&tmp);
412 })
413 }
414
415 pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
417 if buf.is_empty() {
418 return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
419 }
420
421 let mut p = 0usize;
422
423 macro_rules! r {
424 ($n:expr) => {{
425 if p + $n > buf.len() {
426 return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
427 }
428 let s = &buf[p..p + $n];
429 p += $n;
430 s
431 }};
432 }
433 macro_rules! r_i32 {
434 () => {
435 i32::from_le_bytes(r!(4).try_into().unwrap())
436 };
437 }
438 macro_rules! r_i64 {
439 () => {
440 i64::from_le_bytes(r!(8).try_into().unwrap())
441 };
442 }
443 macro_rules! r_u8 {
444 () => {
445 r!(1)[0]
446 };
447 }
448 macro_rules! r_u16 {
449 () => {
450 u16::from_le_bytes(r!(2).try_into().unwrap())
451 };
452 }
453
454 let first_byte = r_u8!();
455
456 let (home_dc_id, version) = if first_byte == 0x06 {
457 (r_i32!(), 6u8)
458 } else if first_byte == 0x05 {
459 (r_i32!(), 5u8)
460 } else if first_byte == 0x04 {
461 (r_i32!(), 4u8)
462 } else if first_byte == 0x03 {
463 (r_i32!(), 3u8)
464 } else if first_byte == 0x02 {
465 (r_i32!(), 2u8)
466 } else {
467 let rest = r!(3);
468 let mut bytes = [0u8; 4];
469 bytes[0] = first_byte;
470 bytes[1..4].copy_from_slice(rest);
471 (i32::from_le_bytes(bytes), 1u8)
472 };
473
474 let dc_count = r_u8!() as usize;
475 let mut dcs = Vec::with_capacity(dc_count);
476 for _ in 0..dc_count {
477 let dc_id = r_i32!();
478 let has_key = r_u8!();
479 let auth_key = if has_key == 1 {
480 let mut k = [0u8; 256];
481 k.copy_from_slice(r!(256));
482 Some(k)
483 } else {
484 None
485 };
486 let first_salt = r_i64!();
487 let time_offset = r_i32!();
488 let al = r_u8!() as usize;
489 let addr = String::from_utf8_lossy(r!(al)).into_owned();
490 let flags = if version >= 3 {
491 DcFlags(r_u8!())
492 } else {
493 DcFlags::NONE
494 };
495 dcs.push(DcEntry {
496 dc_id,
497 addr,
498 auth_key,
499 first_salt,
500 time_offset,
501 flags,
502 });
503 }
504
505 if version < 2 {
506 return Ok(Self {
507 home_dc_id,
508 dcs,
509 updates_state: UpdatesStateSnap::default(),
510 peers: Vec::new(),
511 min_peers: Vec::new(),
512 });
513 }
514
515 let pts = r_i32!();
516 let qts = r_i32!();
517 let date = r_i32!();
518 let seq = r_i32!();
519 let ch_count = r_u16!() as usize;
520 let mut channels = Vec::with_capacity(ch_count);
521 for _ in 0..ch_count {
522 let cid = r_i64!();
523 let cpts = r_i32!();
524 channels.push((cid, cpts));
525 }
526
527 let peer_count = r_u16!() as usize;
528 let mut peers = Vec::with_capacity(peer_count);
529 for _ in 0..peer_count {
530 let id = r_i64!();
531 let access_hash = r_i64!();
532 let peer_type = r_u8!();
534 let is_channel = peer_type == 1;
535 let is_chat = peer_type == 2;
536 let channel_kind = if version >= 6 {
538 let kb = r_u8!();
539 if kb == 0xFF {
540 None
541 } else {
542 ChannelKind::from_byte(kb)
543 }
544 } else {
545 None
546 };
547 peers.push(CachedPeer {
548 id,
549 access_hash,
550 is_channel,
551 is_chat,
552 channel_kind,
553 });
554 }
555
556 let min_peers = if version >= 4 {
558 let count = r_u16!() as usize;
559 let mut v = Vec::with_capacity(count);
560 for _ in 0..count {
561 let user_id = r_i64!();
562 let peer_id = r_i64!();
563 let msg_id = r_i32!();
564 v.push(CachedMinPeer {
565 user_id,
566 peer_id,
567 msg_id,
568 });
569 }
570 v
571 } else {
572 Vec::new()
573 };
574
575 Ok(Self {
576 home_dc_id,
577 dcs,
578 updates_state: UpdatesStateSnap {
579 pts,
580 qts,
581 date,
582 seq,
583 channels,
584 },
585 peers,
586 min_peers,
587 })
588 }
589
590 pub fn from_string(s: &str) -> io::Result<Self> {
592 use base64::Engine as _;
593 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
594 .decode(s.trim())
595 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
596 Self::from_bytes(&bytes)
597 }
598
599 pub fn load(path: &Path) -> io::Result<Self> {
600 let buf = std::fs::read(path)?;
601 Self::from_bytes(&buf)
602 }
603
604 pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
616 let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
617 candidates.peek()?;
618 let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
620 cands
622 .iter()
623 .copied()
624 .find(|d| d.is_ipv6() == prefer_ipv6)
625 .or_else(|| cands.first().copied())
626 }
627
628 pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
633 self.dcs.iter().filter(move |d| d.dc_id == dc_id)
634 }
635}
636
637impl std::fmt::Display for PersistedSession {
638 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
639 use base64::Engine as _;
640 f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
641 }
642}
643
/// Built-in fallback `host:port` addresses for DCs 1–5, keyed by DC id.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const DEFAULTS: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];
    let mut map = HashMap::with_capacity(DEFAULTS.len());
    for (dc_id, addr) in DEFAULTS {
        map.insert(dc_id, String::from(addr));
    }
    map
}
657
658use std::path::PathBuf;
663
664pub trait SessionBackend: Send + Sync {
672 fn save(&self, session: &PersistedSession) -> io::Result<()>;
673 fn load(&self) -> io::Result<Option<PersistedSession>>;
674 fn delete(&self) -> io::Result<()>;
675
676 fn name(&self) -> &str;
678
679 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
692 let mut s = self.load()?.unwrap_or_default();
693 if let Some(existing) = s
695 .dcs
696 .iter_mut()
697 .find(|d| d.dc_id == entry.dc_id && d.is_ipv6() == entry.is_ipv6())
698 {
699 *existing = entry.clone();
700 } else {
701 s.dcs.push(entry.clone());
702 }
703 self.save(&s)
704 }
705
706 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
712 let mut s = self.load()?.unwrap_or_default();
713 s.home_dc_id = dc_id;
714 self.save(&s)
715 }
716
717 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
722 let mut s = self.load()?.unwrap_or_default();
723 update.apply_to(&mut s.updates_state);
724 self.save(&s)
725 }
726
727 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
733 let mut s = self.load()?.unwrap_or_default();
734 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
735 *existing = peer.clone();
736 } else {
737 s.peers.push(peer.clone());
738 }
739 self.save(&s)
740 }
741}
742
743#[derive(Debug, Clone)]
757pub enum UpdateStateChange {
758 All(UpdatesStateSnap),
760 Primary { pts: i32, date: i32, seq: i32 },
762 Secondary { qts: i32 },
764 Channel { id: i64, pts: i32 },
766}
767
768impl UpdateStateChange {
769 pub fn apply_to(&self, snap: &mut UpdatesStateSnap) {
771 match self {
772 Self::All(new_snap) => *snap = new_snap.clone(),
773 Self::Primary { pts, date, seq } => {
774 snap.pts = *pts;
775 snap.date = *date;
776 snap.seq = *seq;
777 }
778 Self::Secondary { qts } => {
779 snap.qts = *qts;
780 }
781 Self::Channel { id, pts } => {
782 if let Some(existing) = snap.channels.iter_mut().find(|c| c.0 == *id) {
784 existing.1 = *pts;
785 } else {
786 snap.channels.push((*id, *pts));
787 }
788 }
789 }
790 }
791}
792
793pub struct BinaryFileBackend {
797 path: PathBuf,
798 write_lock: std::sync::Mutex<()>,
803}
804
805impl BinaryFileBackend {
806 pub fn new(path: impl Into<PathBuf>) -> Self {
807 Self {
808 path: path.into(),
809 write_lock: std::sync::Mutex::new(()),
810 }
811 }
812
813 pub fn path(&self) -> &std::path::Path {
814 &self.path
815 }
816}
817
818impl SessionBackend for BinaryFileBackend {
819 fn save(&self, session: &PersistedSession) -> io::Result<()> {
820 let _guard = self.write_lock.lock().unwrap();
821 session.save(&self.path)
822 }
823
824 fn load(&self) -> io::Result<Option<PersistedSession>> {
825 if !self.path.exists() {
826 return Ok(None);
827 }
828 match PersistedSession::load(&self.path) {
829 Ok(s) => Ok(Some(s)),
830 Err(e) => {
831 let bak = self.path.with_extension("bak");
832 tracing::warn!(
833 "[ferogram::session] session file {:?} could not be read ({e}); \
834 backing it up to {:?} and starting a new session",
835 self.path,
836 bak
837 );
838 let _ = std::fs::rename(&self.path, &bak);
839 Ok(None)
840 }
841 }
842 }
843
844 fn delete(&self) -> io::Result<()> {
845 if self.path.exists() {
846 std::fs::remove_file(&self.path)?;
847 }
848 Ok(())
849 }
850
851 fn name(&self) -> &str {
852 "binary-file"
853 }
854
855 }
858
859#[derive(Default)]
867pub struct InMemoryBackend {
868 data: std::sync::Mutex<Option<PersistedSession>>,
869}
870
871impl InMemoryBackend {
872 pub fn new() -> Self {
873 Self::default()
874 }
875
876 pub fn snapshot(&self) -> Option<PersistedSession> {
878 self.data.lock().unwrap().clone()
879 }
880}
881
882impl SessionBackend for InMemoryBackend {
883 fn save(&self, s: &PersistedSession) -> io::Result<()> {
884 *self.data.lock().unwrap() = Some(s.clone());
885 Ok(())
886 }
887
888 fn load(&self) -> io::Result<Option<PersistedSession>> {
889 Ok(self.data.lock().unwrap().clone())
890 }
891
892 fn delete(&self) -> io::Result<()> {
893 *self.data.lock().unwrap() = None;
894 Ok(())
895 }
896
897 fn name(&self) -> &str {
898 "in-memory"
899 }
900
901 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
904 let mut guard = self.data.lock().unwrap();
905 let s = guard.get_or_insert_with(PersistedSession::default);
906 if let Some(existing) = s
907 .dcs
908 .iter_mut()
909 .find(|d| d.dc_id == entry.dc_id && d.is_ipv6() == entry.is_ipv6())
910 {
911 *existing = entry.clone();
912 } else {
913 s.dcs.push(entry.clone());
914 }
915 Ok(())
916 }
917
918 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
919 let mut guard = self.data.lock().unwrap();
920 let s = guard.get_or_insert_with(PersistedSession::default);
921 s.home_dc_id = dc_id;
922 Ok(())
923 }
924
925 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
926 let mut guard = self.data.lock().unwrap();
927 let s = guard.get_or_insert_with(PersistedSession::default);
928 update.apply_to(&mut s.updates_state);
929 Ok(())
930 }
931
932 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
933 let mut guard = self.data.lock().unwrap();
934 let s = guard.get_or_insert_with(PersistedSession::default);
935 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
936 *existing = peer.clone();
937 } else {
938 s.peers.push(peer.clone());
939 }
940 Ok(())
941 }
942}
943
944pub struct StringSessionBackend {
948 data: std::sync::Mutex<String>,
949}
950
951impl StringSessionBackend {
952 pub fn new(s: impl Into<String>) -> Self {
953 Self {
954 data: std::sync::Mutex::new(s.into()),
955 }
956 }
957
958 pub fn current(&self) -> String {
959 self.data.lock().unwrap().clone()
960 }
961}
962
963impl SessionBackend for StringSessionBackend {
964 fn save(&self, session: &PersistedSession) -> io::Result<()> {
965 *self.data.lock().unwrap() = session.to_string();
966 Ok(())
967 }
968
969 fn load(&self) -> io::Result<Option<PersistedSession>> {
970 let s = self.data.lock().unwrap().clone();
971 if s.trim().is_empty() {
972 return Ok(None);
973 }
974 PersistedSession::from_string(&s).map(Some)
975 }
976
977 fn delete(&self) -> io::Result<()> {
978 *self.data.lock().unwrap() = String::new();
979 Ok(())
980 }
981
982 fn name(&self) -> &str {
983 "string-session"
984 }
985}
986
987pub mod string_session;
995pub use string_session::{FullSession, Session, StringSession, StringSessionError};
996
#[cfg(test)]
mod tests {
    //! Unit tests for the in-memory backend, update-state application, DC
    //! selection by address family, and the binary (v6) round trip.
    use super::*;

    /// IPv4 DC fixture: address `1.2.3.{id}:443`, no key, no flags.
    fn make_dc(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("1.2.3.{id}:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::NONE,
        }
    }

    /// Peer fixture that is neither a channel nor a chat.
    fn make_peer(id: i64, hash: i64) -> CachedPeer {
        CachedPeer {
            id,
            access_hash: hash,
            is_channel: false,
            is_chat: false,
            channel_kind: None,
        }
    }

    // ---- InMemoryBackend ----

    #[test]
    fn inmemory_load_returns_none_when_empty() {
        let b = InMemoryBackend::new();
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_save_then_load_round_trips() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 3;
        s.dcs.push(make_dc(3));
        b.save(&s).unwrap();

        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.home_dc_id, 3);
        assert_eq!(loaded.dcs.len(), 1);
    }

    #[test]
    fn inmemory_delete_clears_state() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        b.save(&s).unwrap();
        b.delete().unwrap();
        assert!(b.load().unwrap().is_none());
    }

    // update_dc on an empty backend must create the session and insert the entry.
    #[test]
    fn inmemory_update_dc_inserts_new() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(4)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].dc_id, 4);
    }

    // Same dc_id and address family: replaced in place, not duplicated.
    #[test]
    fn inmemory_update_dc_replaces_existing() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        let mut updated = make_dc(2);
        updated.addr = "9.9.9.9:443".to_string();
        b.update_dc(&updated).unwrap();

        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].addr, "9.9.9.9:443");
    }

    #[test]
    fn inmemory_set_home_dc() {
        let b = InMemoryBackend::new();
        b.set_home_dc(5).unwrap();
        assert_eq!(b.snapshot().unwrap().home_dc_id, 5);
    }

    #[test]
    fn inmemory_cache_peer_inserts() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 0xdeadbeef)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].id, 100);
    }

    // Caching the same id twice keeps one entry with the latest access hash.
    #[test]
    fn inmemory_cache_peer_updates_existing() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 111)).unwrap();
        b.cache_peer(&make_peer(100, 222)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].access_hash, 222);
    }

    // ---- UpdateStateChange::apply_to ----

    #[test]
    fn update_state_primary() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Primary {
            pts: 10,
            date: 20,
            seq: 30,
        }
        .apply_to(&mut snap);
        assert_eq!(snap.pts, 10);
        assert_eq!(snap.date, 20);
        assert_eq!(snap.seq, 30);
        assert_eq!(snap.qts, 0); // Primary leaves qts untouched
    }

    #[test]
    fn update_state_secondary() {
        let mut snap = UpdatesStateSnap {
            pts: 5,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Secondary { qts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.qts, 99);
        assert_eq!(snap.pts, 5); // Secondary leaves pts untouched
    }

    #[test]
    fn update_state_channel_inserts() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Channel { id: 12345, pts: 42 }.apply_to(&mut snap);
        assert_eq!(snap.channels, vec![(12345, 42)]);
    }

    #[test]
    fn update_state_channel_updates_existing() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![(12345, 10), (67890, 5)],
        };
        UpdateStateChange::Channel { id: 12345, pts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.channels[0], (12345, 99));
        // Other channels are unaffected.
        assert_eq!(snap.channels[1], (67890, 5));
    }

    #[test]
    fn apply_update_state_via_backend() {
        let b = InMemoryBackend::new();
        b.apply_update_state(UpdateStateChange::Primary {
            pts: 7,
            date: 8,
            seq: 9,
        })
        .unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.updates_state.pts, 7);
    }

    // Calls through `dyn SessionBackend`. InMemoryBackend overrides update_dc,
    // so this exercises the override via dynamic dispatch.
    #[test]
    fn default_update_dc_via_trait_object() {
        let b: Box<dyn SessionBackend> = Box::new(InMemoryBackend::new());
        b.update_dc(&make_dc(1)).unwrap();
        b.update_dc(&make_dc(2)).unwrap();
        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.dcs.len(), 2);
    }

    /// IPv6 DC fixture: bracketed address, `DcFlags::IPV6` set.
    fn make_dc_v6(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("[2001:b28:f23d:f00{}::a]:443", id),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::IPV6,
        }
    }

    // ---- DcEntry / DC selection ----

    #[test]
    fn dc_entry_from_parts_ipv4() {
        let dc = DcEntry::from_parts(1, "149.154.175.53", 443, DcFlags::NONE);
        assert_eq!(dc.addr, "149.154.175.53:443");
        assert!(!dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    #[test]
    fn dc_entry_from_parts_ipv6() {
        let dc = DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6);
        assert_eq!(dc.addr, "[2001:b28:f23d:f001::a]:443");
        assert!(dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    #[test]
    fn persisted_session_dc_for_prefers_ipv6() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(2)); // IPv4 endpoint
        s.dcs.push(make_dc_v6(2)); // IPv6 endpoint for the same DC
        let v6 = s.dc_for(2, true).unwrap();
        assert!(v6.is_ipv6());

        let v4 = s.dc_for(2, false).unwrap();
        assert!(!v4.is_ipv6());
    }

    // With no IPv6 entry, dc_for falls back to the first entry for the DC.
    #[test]
    fn persisted_session_dc_for_falls_back_when_only_ipv4() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(3)); // only an IPv4 endpoint
        let dc = s.dc_for(3, true).unwrap();
        assert!(!dc.is_ipv6());
    }

    #[test]
    fn persisted_session_all_dcs_for_returns_both() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(1));
        s.dcs.push(make_dc_v6(1));
        s.dcs.push(make_dc(2));

        assert_eq!(s.all_dcs_for(1).count(), 2);
        assert_eq!(s.all_dcs_for(2).count(), 1);
        assert_eq!(s.all_dcs_for(5).count(), 0);
    }

    // update_dc keys on (dc_id, address family), so both entries are kept.
    #[test]
    fn inmemory_ipv4_and_ipv6_coexist() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap(); // IPv4
        b.update_dc(&make_dc_v6(2)).unwrap(); // IPv6, same DC
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.iter().filter(|d| d.dc_id == 2).count(), 2);
    }

    // ---- Binary format ----

    #[test]
    fn binary_roundtrip_ipv4_and_ipv6() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        s.dcs.push(make_dc(2));
        s.dcs.push(make_dc_v6(2));

        let bytes = s.to_bytes();
        let loaded = PersistedSession::from_bytes(&bytes).unwrap();
        assert_eq!(loaded.dcs.len(), 2);
        assert_eq!(loaded.dcs.iter().filter(|d| d.is_ipv6()).count(), 1);
        assert_eq!(loaded.dcs.iter().filter(|d| !d.is_ipv6()).count(), 1);
    }

    // Every ChannelKind variant plus the absent case survives a v6 round trip.
    #[test]
    fn v6_channel_kind_roundtrip_all_variants() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 1;
        s.peers.push(CachedPeer {
            id: 1001,
            access_hash: 0xaaaa,
            is_channel: true,
            is_chat: false,
            channel_kind: Some(ChannelKind::Broadcast),
        });
        s.peers.push(CachedPeer {
            id: 1002,
            access_hash: 0xbbbb,
            is_channel: true,
            is_chat: false,
            channel_kind: Some(ChannelKind::Megagroup),
        });
        s.peers.push(CachedPeer {
            id: 1003,
            access_hash: 0xcccc,
            is_channel: true,
            is_chat: false,
            channel_kind: Some(ChannelKind::Gigagroup),
        });
        s.peers.push(CachedPeer {
            id: 1004,
            access_hash: 0xdddd,
            is_channel: false,
            is_chat: false,
            channel_kind: None,
        });

        let bytes = s.to_bytes();
        assert_eq!(bytes[0], 0x06, "version byte must be 6");

        let loaded = PersistedSession::from_bytes(&bytes).unwrap();
        assert_eq!(loaded.peers.len(), 4);

        let p = &loaded.peers[0];
        assert_eq!(p.id, 1001);
        assert_eq!(p.channel_kind, Some(ChannelKind::Broadcast));

        let p = &loaded.peers[1];
        assert_eq!(p.id, 1002);
        assert_eq!(p.channel_kind, Some(ChannelKind::Megagroup));

        let p = &loaded.peers[2];
        assert_eq!(p.id, 1003);
        assert_eq!(p.channel_kind, Some(ChannelKind::Gigagroup));

        let p = &loaded.peers[3];
        assert_eq!(p.id, 1004);
        assert_eq!(p.channel_kind, None);
    }

    // `None` is written as the 0xFF sentinel byte and must decode back to `None`.
    #[test]
    fn v6_channel_kind_absent_sentinel_roundtrip() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 1;
        s.peers.push(CachedPeer {
            id: 555,
            access_hash: 0x1234,
            is_channel: false,
            is_chat: false,
            channel_kind: None,
        });
        let loaded = PersistedSession::from_bytes(&s.to_bytes()).unwrap();
        assert_eq!(loaded.peers[0].channel_kind, None);
    }
}
1356
#[cfg(feature = "sqlite-session")]
/// Session backend that stores the session in a SQLite database.
///
/// Only compiled with the `sqlite-session` feature.
pub struct SqliteBackend {
    /// The single database connection, shared behind a mutex.
    conn: std::sync::Mutex<rusqlite::Connection>,
    /// Human-readable identifier: the database path, or `":memory:"` for an
    /// in-memory database.
    label: String,
}
1387
1388#[cfg(feature = "sqlite-session")]
1389impl SqliteBackend {
1390 const SCHEMA: &'static str = "
1391 PRAGMA journal_mode = WAL;
1392 PRAGMA synchronous = NORMAL;
1393
1394 CREATE TABLE IF NOT EXISTS meta (
1395 key TEXT PRIMARY KEY,
1396 value INTEGER NOT NULL DEFAULT 0
1397 );
1398
1399 CREATE TABLE IF NOT EXISTS dcs (
1400 dc_id INTEGER NOT NULL,
1401 flags INTEGER NOT NULL DEFAULT 0,
1402 addr TEXT NOT NULL,
1403 auth_key BLOB,
1404 first_salt INTEGER NOT NULL DEFAULT 0,
1405 time_offset INTEGER NOT NULL DEFAULT 0,
1406 PRIMARY KEY (dc_id, flags)
1407 );
1408
1409 CREATE TABLE IF NOT EXISTS update_state (
1410 id INTEGER PRIMARY KEY CHECK (id = 1),
1411 pts INTEGER NOT NULL DEFAULT 0,
1412 qts INTEGER NOT NULL DEFAULT 0,
1413 date INTEGER NOT NULL DEFAULT 0,
1414 seq INTEGER NOT NULL DEFAULT 0
1415 );
1416
1417 CREATE TABLE IF NOT EXISTS channel_pts (
1418 channel_id INTEGER PRIMARY KEY,
1419 pts INTEGER NOT NULL
1420 );
1421
1422 CREATE TABLE IF NOT EXISTS peers (
1423 id INTEGER PRIMARY KEY,
1424 access_hash INTEGER NOT NULL,
1425 is_channel INTEGER NOT NULL DEFAULT 0,
1426 is_chat INTEGER NOT NULL DEFAULT 0,
1427 channel_kind INTEGER
1428 );
1429
1430 CREATE TABLE IF NOT EXISTS min_peers (
1431 user_id INTEGER PRIMARY KEY,
1432 peer_id INTEGER NOT NULL,
1433 msg_id INTEGER NOT NULL
1434 );
1435 ";
1436
1437 pub fn open(path: impl Into<PathBuf>) -> io::Result<Self> {
1439 let path = path.into();
1440 let label = path.display().to_string();
1441 let conn = rusqlite::Connection::open(&path).map_err(io::Error::other)?;
1442 conn.execute_batch(Self::SCHEMA).map_err(io::Error::other)?;
1443 Self::migrate_legacy_sqlite_schema(&conn)?;
1444 Ok(Self {
1445 conn: std::sync::Mutex::new(conn),
1446 label,
1447 })
1448 }
1449
1450 pub fn in_memory() -> io::Result<Self> {
1452 let conn = rusqlite::Connection::open_in_memory().map_err(io::Error::other)?;
1453 conn.execute_batch(Self::SCHEMA).map_err(io::Error::other)?;
1454 Self::migrate_legacy_sqlite_schema(&conn)?;
1455 Ok(Self {
1456 conn: std::sync::Mutex::new(conn),
1457 label: ":memory:".into(),
1458 })
1459 }
1460
1461 fn map_err(e: rusqlite::Error) -> io::Error {
1462 io::Error::other(e)
1463 }
1464
1465 fn migrate_legacy_sqlite_schema(conn: &rusqlite::Connection) -> io::Result<()> {
1469 let mut has_is_chat = false;
1470 let mut stmt = conn
1471 .prepare("PRAGMA table_info(peers)")
1472 .map_err(Self::map_err)?;
1473 let cols = stmt
1474 .query_map([], |row| row.get::<_, String>(1))
1475 .map_err(Self::map_err)?;
1476 for col in cols.filter_map(|r| r.ok()) {
1477 if col == "is_chat" {
1478 has_is_chat = true;
1479 break;
1480 }
1481 }
1482 if !has_is_chat {
1483 conn.execute_batch("ALTER TABLE peers ADD COLUMN is_chat INTEGER NOT NULL DEFAULT 0;")
1484 .map_err(Self::map_err)?;
1485 }
1486 let mut has_channel_kind = false;
1488 let mut stmt2 = conn
1489 .prepare("PRAGMA table_info(peers)")
1490 .map_err(Self::map_err)?;
1491 let cols2 = stmt2
1492 .query_map([], |row| row.get::<_, String>(1))
1493 .map_err(Self::map_err)?;
1494 for col in cols2.filter_map(|r| r.ok()) {
1495 if col == "channel_kind" {
1496 has_channel_kind = true;
1497 break;
1498 }
1499 }
1500 if !has_channel_kind {
1501 conn.execute_batch("ALTER TABLE peers ADD COLUMN channel_kind INTEGER;")
1502 .map_err(Self::map_err)?;
1503 }
1504 conn.execute_batch(
1505 "CREATE TABLE IF NOT EXISTS min_peers (
1506 user_id INTEGER PRIMARY KEY,
1507 peer_id INTEGER NOT NULL,
1508 msg_id INTEGER NOT NULL
1509 );",
1510 )
1511 .map_err(Self::map_err)?;
1512 Ok(())
1513 }
1514
1515 fn read_session(conn: &rusqlite::Connection) -> io::Result<PersistedSession> {
1517 let home_dc_id: i32 = conn
1519 .query_row("SELECT value FROM meta WHERE key = 'home_dc_id'", [], |r| {
1520 r.get(0)
1521 })
1522 .unwrap_or(0);
1523
1524 let mut stmt = conn
1526 .prepare("SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs")
1527 .map_err(Self::map_err)?;
1528 let dcs = stmt
1529 .query_map([], |row| {
1530 let dc_id: i32 = row.get(0)?;
1531 let flags_raw: u8 = row.get(1)?;
1532 let addr: String = row.get(2)?;
1533 let key_blob: Option<Vec<u8>> = row.get(3)?;
1534 let first_salt: i64 = row.get(4)?;
1535 let time_offset: i32 = row.get(5)?;
1536 Ok((dc_id, addr, key_blob, first_salt, time_offset, flags_raw))
1537 })
1538 .map_err(Self::map_err)?
1539 .filter_map(|r| r.ok())
1540 .map(
1541 |(dc_id, addr, key_blob, first_salt, time_offset, flags_raw)| {
1542 let auth_key = key_blob.and_then(|b| {
1543 if b.len() == 256 {
1544 let mut k = [0u8; 256];
1545 k.copy_from_slice(&b);
1546 Some(k)
1547 } else {
1548 None
1549 }
1550 });
1551 DcEntry {
1552 dc_id,
1553 addr,
1554 auth_key,
1555 first_salt,
1556 time_offset,
1557 flags: DcFlags(flags_raw),
1558 }
1559 },
1560 )
1561 .collect();
1562
1563 let updates_state = conn
1565 .query_row(
1566 "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
1567 [],
1568 |r| {
1569 Ok(UpdatesStateSnap {
1570 pts: r.get(0)?,
1571 qts: r.get(1)?,
1572 date: r.get(2)?,
1573 seq: r.get(3)?,
1574 channels: vec![],
1575 })
1576 },
1577 )
1578 .unwrap_or_default();
1579
1580 let mut ch_stmt = conn
1582 .prepare("SELECT channel_id, pts FROM channel_pts")
1583 .map_err(Self::map_err)?;
1584 let channels: Vec<(i64, i32)> = ch_stmt
1585 .query_map([], |r| Ok((r.get::<_, i64>(0)?, r.get::<_, i32>(1)?)))
1586 .map_err(Self::map_err)?
1587 .filter_map(|r| r.ok())
1588 .collect();
1589
1590 let mut peer_stmt = conn
1592 .prepare("SELECT id, access_hash, is_channel, is_chat, channel_kind FROM peers")
1593 .map_err(Self::map_err)?;
1594 let peers: Vec<CachedPeer> = peer_stmt
1595 .query_map([], |r| {
1596 let kind_raw: Option<i32> = r.get(4)?;
1597 Ok(CachedPeer {
1598 id: r.get(0)?,
1599 access_hash: r.get(1)?,
1600 is_channel: r.get::<_, i32>(2)? != 0,
1601 is_chat: r.get::<_, i32>(3)? != 0,
1602 channel_kind: kind_raw.and_then(|k| ChannelKind::from_byte(k as u8)),
1603 })
1604 })
1605 .map_err(Self::map_err)?
1606 .filter_map(|r| r.ok())
1607 .collect();
1608
1609 let mut min_stmt = conn
1611 .prepare("SELECT user_id, peer_id, msg_id FROM min_peers")
1612 .map_err(Self::map_err)?;
1613 let min_peers: Vec<CachedMinPeer> = min_stmt
1614 .query_map([], |r| {
1615 Ok(CachedMinPeer {
1616 user_id: r.get(0)?,
1617 peer_id: r.get(1)?,
1618 msg_id: r.get(2)?,
1619 })
1620 })
1621 .map_err(Self::map_err)?
1622 .filter_map(|r| r.ok())
1623 .collect();
1624
1625 Ok(PersistedSession {
1626 home_dc_id,
1627 dcs,
1628 updates_state: UpdatesStateSnap {
1629 channels,
1630 ..updates_state
1631 },
1632 peers,
1633 min_peers,
1634 })
1635 }
1636
1637 fn write_session(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
1639 conn.execute_batch("BEGIN IMMEDIATE")
1640 .map_err(Self::map_err)?;
1641
1642 conn.execute(
1643 "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
1644 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
1645 rusqlite::params![s.home_dc_id],
1646 )
1647 .map_err(Self::map_err)?;
1648
1649 conn.execute("DELETE FROM dcs", []).map_err(Self::map_err)?;
1651 for d in &s.dcs {
1652 conn.execute(
1653 "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
1654 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
1655 rusqlite::params![
1656 d.dc_id,
1657 d.flags.0,
1658 d.addr,
1659 d.auth_key.as_ref().map(|k| k.as_ref()),
1660 d.first_salt,
1661 d.time_offset,
1662 ],
1663 )
1664 .map_err(Self::map_err)?;
1665 }
1666
1667 let us = &s.updates_state;
1671 conn.execute(
1672 "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1, ?1, ?2, ?3, ?4)
1673 ON CONFLICT(id) DO UPDATE SET
1674 pts = MAX(excluded.pts, update_state.pts),
1675 qts = MAX(excluded.qts, update_state.qts),
1676 date = excluded.date,
1677 seq = excluded.seq",
1678 rusqlite::params![us.pts, us.qts, us.date, us.seq],
1679 )
1680 .map_err(Self::map_err)?;
1681
1682 conn.execute("DELETE FROM channel_pts", [])
1683 .map_err(Self::map_err)?;
1684 for &(cid, cpts) in &us.channels {
1685 conn.execute(
1686 "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
1687 rusqlite::params![cid, cpts],
1688 )
1689 .map_err(Self::map_err)?;
1690 }
1691
1692 conn.execute("DELETE FROM peers", [])
1694 .map_err(Self::map_err)?;
1695 for p in &s.peers {
1696 conn.execute(
1697 "INSERT INTO peers (id, access_hash, is_channel, is_chat, channel_kind) VALUES (?1, ?2, ?3, ?4, ?5)",
1698 rusqlite::params![
1699 p.id,
1700 p.access_hash,
1701 p.is_channel as i32,
1702 p.is_chat as i32,
1703 p.channel_kind.map(|k| k.to_byte() as i32),
1704 ],
1705 )
1706 .map_err(Self::map_err)?;
1707 }
1708
1709 conn.execute("DELETE FROM min_peers", [])
1711 .map_err(Self::map_err)?;
1712 for m in &s.min_peers {
1713 conn.execute(
1714 "INSERT INTO min_peers (user_id, peer_id, msg_id) VALUES (?1, ?2, ?3)",
1715 rusqlite::params![m.user_id, m.peer_id, m.msg_id],
1716 )
1717 .map_err(Self::map_err)?;
1718 }
1719
1720 conn.execute_batch("COMMIT").map_err(Self::map_err)
1721 }
1722}
1723
#[cfg(feature = "sqlite-session")]
impl SessionBackend for SqliteBackend {
    /// Persists the full session; see `SqliteBackend::write_session`.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        Self::write_session(&conn, session)
    }

    /// Loads the session, or `Ok(None)` when the `meta` table is empty.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = self.conn.lock().unwrap();
        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM meta", [], |r| r.get(0))
            .map_err(Self::map_err)?;
        if count == 0 {
            return Ok(None);
        }
        Self::read_session(&conn).map(Some)
    }

    /// Empties every session table in one transaction.
    ///
    /// On failure the transaction is rolled back, so the connection is never
    /// left inside an open transaction.
    fn delete(&self) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute_batch("BEGIN IMMEDIATE")
            .map_err(Self::map_err)?;
        let outcome = conn
            .execute_batch(
                "DELETE FROM meta;
                 DELETE FROM dcs;
                 DELETE FROM update_state;
                 DELETE FROM channel_pts;
                 DELETE FROM peers;
                 DELETE FROM min_peers;",
            )
            .and_then(|()| conn.execute_batch("COMMIT"));
        if let Err(e) = outcome {
            // Ignored: fails only when SQLite has already ended the transaction.
            let _ = conn.execute_batch("ROLLBACK");
            return Err(Self::map_err(e));
        }
        Ok(())
    }

    fn name(&self) -> &str {
        &self.label
    }

    /// Inserts or replaces the row for `(dc_id, flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
             ON CONFLICT(dc_id, flags) DO UPDATE SET
                 addr        = excluded.addr,
                 auth_key    = excluded.auth_key,
                 first_salt  = excluded.first_salt,
                 time_offset = excluded.time_offset",
            rusqlite::params![
                entry.dc_id,
                entry.flags.0,
                entry.addr,
                entry.auth_key.as_ref().map(|k| k.as_ref()),
                entry.first_salt,
                entry.time_offset,
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Upserts the `home_dc_id` entry in `meta`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![dc_id],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Applies one incremental change to the stored update state.
    ///
    /// `All` overwrites the primary row and replaces every `channel_pts` row.
    /// It runs in a single transaction, so a failure part-way cannot leave the
    /// channel table half-rewritten. The other variants are single statements.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        match update {
            UpdateStateChange::All(snap) => {
                conn.execute_batch("BEGIN IMMEDIATE")
                    .map_err(Self::map_err)?;
                let outcome = (|| -> rusqlite::Result<()> {
                    conn.execute(
                        "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
                         ON CONFLICT(id) DO UPDATE SET
                             pts=excluded.pts, qts=excluded.qts,
                             date=excluded.date, seq=excluded.seq",
                        rusqlite::params![snap.pts, snap.qts, snap.date, snap.seq],
                    )?;
                    conn.execute("DELETE FROM channel_pts", [])?;
                    let mut insert = conn
                        .prepare("INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)")?;
                    for &(cid, cpts) in &snap.channels {
                        insert.execute(rusqlite::params![cid, cpts])?;
                    }
                    Ok(())
                })()
                .and_then(|()| conn.execute_batch("COMMIT"));
                if let Err(e) = outcome {
                    let _ = conn.execute_batch("ROLLBACK");
                    return Err(Self::map_err(e));
                }
                Ok(())
            }
            UpdateStateChange::Primary { pts, date, seq } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,0,?2,?3)
                     ON CONFLICT(id) DO UPDATE SET pts=excluded.pts, date=excluded.date,
                     seq=excluded.seq",
                    rusqlite::params![pts, date, seq],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Secondary { qts } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,0,?1,0,0)
                     ON CONFLICT(id) DO UPDATE SET qts = excluded.qts",
                    rusqlite::params![qts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Channel { id, pts } => conn
                .execute(
                    "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)
                     ON CONFLICT(channel_id) DO UPDATE SET pts = excluded.pts",
                    rusqlite::params![id, pts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
        }
    }

    /// Inserts or replaces the cached peer keyed by `peer.id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO peers (id, access_hash, is_channel, is_chat, channel_kind) VALUES (?1, ?2, ?3, ?4, ?5)
             ON CONFLICT(id) DO UPDATE SET
                 access_hash = excluded.access_hash,
                 is_channel = excluded.is_channel,
                 is_chat = excluded.is_chat,
                 channel_kind = excluded.channel_kind",
            rusqlite::params![
                peer.id,
                peer.access_hash,
                peer.is_channel as i32,
                peer.is_chat as i32,
                peer.channel_kind.map(|k| k.to_byte() as i32),
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }
}
1870
#[cfg(feature = "libsql-session")]
/// Session backend that persists into a libSQL database.
///
/// Instances are built by `open_local`, `in_memory`, `open_remote` or
/// `open_replica`.
pub struct LibSqlBackend {
    /// Description of the storage location; this is what `name()` returns.
    label: String,
    /// Connection used for every read and write of the session tables.
    conn: libsql::Connection,
}
1895
#[cfg(feature = "libsql-session")]
impl LibSqlBackend {
    /// Table definitions; every statement uses `IF NOT EXISTS`, so this is
    /// safe to run on every open.
    const SCHEMA: &'static str = "
        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT    PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );
        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );
        CREATE TABLE IF NOT EXISTS peers (
            id           INTEGER PRIMARY KEY,
            access_hash  INTEGER NOT NULL,
            is_channel   INTEGER NOT NULL DEFAULT 0,
            is_chat      INTEGER NOT NULL DEFAULT 0,
            channel_kind INTEGER
        );
        CREATE TABLE IF NOT EXISTS min_peers (
            user_id INTEGER PRIMARY KEY,
            peer_id INTEGER NOT NULL,
            msg_id  INTEGER NOT NULL
        );
    ";

    /// Columns added to `peers` after its first release, with the DDL that adds
    /// each one to an older database.
    const PEER_MIGRATIONS: [(&'static str, &'static str); 2] = [
        (
            "is_chat",
            "ALTER TABLE peers ADD COLUMN is_chat INTEGER NOT NULL DEFAULT 0;",
        ),
        (
            "channel_kind",
            "ALTER TABLE peers ADD COLUMN channel_kind INTEGER;",
        ),
    ];

    /// Drives `fut` to completion on the current Tokio runtime and converts the
    /// libSQL error into `io::Error`.
    ///
    /// Returns an error instead of panicking when the calling thread is not in a
    /// Tokio runtime context.
    ///
    /// NOTE: per Tokio's documentation `Handle::block_on` panics when called from
    /// inside an asynchronous execution context. The synchronous callers of this
    /// helper must therefore not be invoked directly from async code (e.g. wrap
    /// them in `spawn_blocking`).
    fn block<F, T>(fut: F) -> io::Result<T>
    where
        F: std::future::Future<Output = Result<T, libsql::Error>>,
    {
        let handle = tokio::runtime::Handle::try_current().map_err(io::Error::other)?;
        handle.block_on(fut).map_err(io::Error::other)
    }

    /// Creates any missing tables, then adds `peers` columns that an older
    /// database may lack.
    ///
    /// Column presence is checked with `PRAGMA table_info`, the same approach as
    /// the rusqlite backend. A failing `ALTER` is therefore a real error and is
    /// propagated rather than ignored.
    async fn apply_schema(conn: &libsql::Connection) -> Result<(), libsql::Error> {
        conn.execute_batch(Self::SCHEMA).await?;

        // Scoped so the PRAGMA result set is released before any ALTER runs.
        let columns: Vec<String> = {
            let mut rows = conn.query("PRAGMA table_info(peers)", ()).await?;
            let mut names = Vec::new();
            while let Some(row) = rows.next().await? {
                // Column 1 of `table_info` is the column name.
                names.push(row.get::<String>(1)?);
            }
            names
        };

        for (column, ddl) in Self::PEER_MIGRATIONS {
            if !columns.iter().any(|c| c == column) {
                conn.execute_batch(ddl).await?;
            }
        }
        Ok(())
    }

    /// Connects to `db`, applies the schema, and builds the backend.
    fn from_database(db: libsql::Database, label: String) -> io::Result<Self> {
        let conn = Self::block(async { db.connect() })?;
        Self::block(Self::apply_schema(&conn))?;
        Ok(Self { conn, label })
    }

    /// Opens (or creates) a local database file at `path`.
    pub fn open_local(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let db = Self::block(async { libsql::Builder::new_local(path).build().await })?;
        Self::from_database(db, label)
    }

    /// Opens a fresh in-memory database.
    pub fn in_memory() -> io::Result<Self> {
        let db = Self::block(async { libsql::Builder::new_local(":memory:").build().await })?;
        Self::from_database(db, ":memory:".into())
    }

    /// Connects to a remote libSQL server at `url` using `auth_token`.
    pub fn open_remote(url: impl Into<String>, auth_token: impl Into<String>) -> io::Result<Self> {
        let url = url.into();
        let label = url.clone();
        let db = Self::block(async {
            libsql::Builder::new_remote(url, auth_token.into())
                .build()
                .await
        })?;
        Self::from_database(db, label)
    }

    /// Opens an embedded replica stored at `path` that mirrors the remote
    /// database at `url`.
    pub fn open_replica(
        path: impl Into<PathBuf>,
        url: impl Into<String>,
        auth_token: impl Into<String>,
    ) -> io::Result<Self> {
        let path = path.into();
        let url = url.into();
        let auth_token = auth_token.into();
        let label = format!("{} (replica of {})", path.display(), url);
        let db = Self::block(async {
            libsql::Builder::new_remote_replica(path, url, auth_token)
                .build()
                .await
        })?;
        Self::from_database(db, label)
    }

    /// Reads the complete persisted session.
    ///
    /// A missing `home_dc_id` becomes `0` and a missing `update_state` row
    /// becomes the default snapshot. An `auth_key` blob that is not exactly 256
    /// bytes is an error.
    async fn read_session_async(
        conn: &libsql::Connection,
    ) -> Result<PersistedSession, libsql::Error> {
        let home_dc_id: i32 = conn
            .query("SELECT value FROM meta WHERE key = 'home_dc_id'", ())
            .await?
            .next()
            .await?
            .map(|r| r.get::<i32>(0))
            .transpose()?
            .unwrap_or(0);

        let mut rows = conn
            .query(
                "SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs",
                (),
            )
            .await?;
        let mut dcs = Vec::new();
        while let Some(row) = rows.next().await? {
            let dc_id: i32 = row.get(0)?;
            let flags_raw: u8 = row.get::<i64>(1)? as u8;
            let addr: String = row.get(2)?;
            let key_blob: Option<Vec<u8>> = row.get(3)?;
            let first_salt: i64 = row.get(4)?;
            let time_offset: i32 = row.get(5)?;
            let auth_key = match key_blob {
                Some(b) if b.len() == 256 => {
                    let mut k = [0u8; 256];
                    k.copy_from_slice(&b);
                    Some(k)
                }
                Some(b) => {
                    return Err(libsql::Error::Misuse(format!(
                        "auth_key blob must be 256 bytes, got {}",
                        b.len()
                    )));
                }
                None => None,
            };
            dcs.push(DcEntry {
                dc_id,
                addr,
                auth_key,
                first_salt,
                time_offset,
                flags: DcFlags(flags_raw),
            });
        }

        let mut us_row = conn
            .query(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                (),
            )
            .await?;
        let updates_state = if let Some(r) = us_row.next().await? {
            UpdatesStateSnap {
                pts: r.get(0)?,
                qts: r.get(1)?,
                date: r.get(2)?,
                seq: r.get(3)?,
                channels: vec![],
            }
        } else {
            UpdatesStateSnap::default()
        };

        let mut ch_rows = conn
            .query("SELECT channel_id, pts FROM channel_pts", ())
            .await?;
        let mut channels = Vec::new();
        while let Some(r) = ch_rows.next().await? {
            channels.push((r.get::<i64>(0)?, r.get::<i32>(1)?));
        }

        let mut peer_rows = conn
            .query(
                "SELECT id, access_hash, is_channel, is_chat, channel_kind FROM peers",
                (),
            )
            .await?;
        let mut peers = Vec::new();
        while let Some(r) = peer_rows.next().await? {
            // A NULL (or otherwise undecodable) channel_kind is treated as "none".
            let kind_raw: Option<i32> = r.get(4).ok();
            peers.push(CachedPeer {
                id: r.get(0)?,
                access_hash: r.get(1)?,
                is_channel: r.get::<i32>(2)? != 0,
                is_chat: r.get::<i32>(3)? != 0,
                channel_kind: kind_raw.and_then(|k| ChannelKind::from_byte(k as u8)),
            });
        }

        let mut min_rows = conn
            .query("SELECT user_id, peer_id, msg_id FROM min_peers", ())
            .await?;
        let mut min_peers = Vec::new();
        while let Some(r) = min_rows.next().await? {
            min_peers.push(CachedMinPeer {
                user_id: r.get(0)?,
                peer_id: r.get(1)?,
                msg_id: r.get(2)?,
            });
        }

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers,
        })
    }

    /// Replaces the stored session with `s` inside one `BEGIN IMMEDIATE`
    /// transaction.
    ///
    /// If any write or the `COMMIT` fails, the transaction is rolled back before
    /// the error is returned. Otherwise the shared connection would remain inside
    /// an open transaction, and every later `BEGIN IMMEDIATE` would fail.
    async fn write_session_async(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute_batch("BEGIN IMMEDIATE").await?;
        let outcome = match Self::write_rows_async(conn, s).await {
            Ok(()) => conn.execute_batch("COMMIT").await.map(|_| ()),
            Err(e) => Err(e),
        };
        if outcome.is_err() {
            // Ignored: fails only if the transaction has already ended.
            let _ = conn.execute_batch("ROLLBACK").await;
        }
        outcome
    }

    /// Issues every write for `s`; transaction control is left to the caller.
    ///
    /// `dcs`, `channel_pts`, `peers` and `min_peers` are cleared and rewritten.
    /// `update_state` keeps `MAX(new, stored)` for `pts`/`qts`.
    async fn write_rows_async(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            libsql::params![s.home_dc_id],
        )
        .await?;

        conn.execute("DELETE FROM dcs", ()).await?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                libsql::params![
                    d.dc_id,
                    d.flags.0 as i64,
                    d.addr.clone(),
                    d.auth_key.map(|k| k.to_vec()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .await?;
        }

        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
             ON CONFLICT(id) DO UPDATE SET
                 pts  = MAX(excluded.pts, update_state.pts),
                 qts  = MAX(excluded.qts, update_state.qts),
                 date = excluded.date,
                 seq  = excluded.seq",
            libsql::params![us.pts, us.qts, us.date, us.seq],
        )
        .await?;

        conn.execute("DELETE FROM channel_pts", ()).await?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1,?2)",
                libsql::params![cid, cpts],
            )
            .await?;
        }

        conn.execute("DELETE FROM peers", ()).await?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel, is_chat, channel_kind) VALUES (?1,?2,?3,?4,?5)",
                libsql::params![
                    p.id,
                    p.access_hash,
                    p.is_channel as i32,
                    p.is_chat as i32,
                    p.channel_kind.map(|k| k.to_byte() as i32),
                ],
            )
            .await?;
        }

        conn.execute("DELETE FROM min_peers", ()).await?;
        for m in &s.min_peers {
            conn.execute(
                "INSERT INTO min_peers (user_id, peer_id, msg_id) VALUES (?1,?2,?3)",
                libsql::params![m.user_id, m.peer_id, m.msg_id],
            )
            .await?;
        }
        Ok(())
    }
}
2214
#[cfg(feature = "libsql-session")]
impl SessionBackend for LibSqlBackend {
    /// Persists the full session.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        // `block` places no `'static` bound on the future, so the session is
        // borrowed rather than deep-cloned (peer lists can be large).
        Self::block(Self::write_session_async(&self.conn, session))
    }

    /// Loads the session, or `Ok(None)` when the `meta` table is empty.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = &self.conn;
        let count: i64 = Self::block(async move {
            let mut rows = conn.query("SELECT COUNT(*) FROM meta", ()).await?;
            match rows.next().await? {
                Some(row) => row.get::<i64>(0),
                None => Ok(0),
            }
        })?;
        if count == 0 {
            return Ok(None);
        }
        Self::block(Self::read_session_async(conn)).map(Some)
    }

    /// Empties every session table in one transaction, rolling back on failure
    /// so the connection is never left inside an open transaction.
    fn delete(&self) -> io::Result<()> {
        let conn = &self.conn;
        Self::block(async move {
            conn.execute_batch("BEGIN IMMEDIATE").await?;
            let outcome = match conn
                .execute_batch(
                    "DELETE FROM meta;
                     DELETE FROM dcs;
                     DELETE FROM update_state;
                     DELETE FROM channel_pts;
                     DELETE FROM peers;
                     DELETE FROM min_peers;",
                )
                .await
            {
                Ok(_) => conn.execute_batch("COMMIT").await.map(|_| ()),
                Err(e) => Err(e),
            };
            if outcome.is_err() {
                let _ = conn.execute_batch("ROLLBACK").await;
            }
            outcome
        })
    }

    fn name(&self) -> &str {
        &self.label
    }

    /// Inserts or replaces the row for `(dc_id, flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = &self.conn;
        let (dc_id, flags, addr, key, salt, off) = (
            entry.dc_id,
            entry.flags.0 as i64,
            entry.addr.clone(),
            entry.auth_key.map(|k| k.to_vec()),
            entry.first_salt,
            entry.time_offset,
        );
        Self::block(async move {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1,?2,?3,?4,?5,?6)
                 ON CONFLICT(dc_id, flags) DO UPDATE SET
                     addr=excluded.addr, auth_key=excluded.auth_key,
                     first_salt=excluded.first_salt, time_offset=excluded.time_offset",
                libsql::params![dc_id, flags, addr, key, salt, off],
            )
            .await
            .map(|_| ())
        })
    }

    /// Upserts the `home_dc_id` entry in `meta`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = &self.conn;
        Self::block(async move {
            conn.execute(
                "INSERT INTO meta (key, value) VALUES ('home_dc_id',?1)
                 ON CONFLICT(key) DO UPDATE SET value=excluded.value",
                libsql::params![dc_id],
            )
            .await
            .map(|_| ())
        })
    }

    /// Applies one incremental change to the stored update state.
    ///
    /// `All` overwrites the primary row and replaces every `channel_pts` row
    /// inside a single transaction, rolled back on failure. The other variants
    /// are single statements.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        let conn = &self.conn;
        Self::block(async move {
            match update {
                UpdateStateChange::All(snap) => {
                    conn.execute_batch("BEGIN IMMEDIATE").await?;
                    let written = async {
                        conn.execute(
                            "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,?2,?3,?4)
                             ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,qts=excluded.qts,
                             date=excluded.date,seq=excluded.seq",
                            libsql::params![snap.pts, snap.qts, snap.date, snap.seq],
                        )
                        .await?;
                        conn.execute("DELETE FROM channel_pts", ()).await?;
                        for &(cid, cpts) in &snap.channels {
                            conn.execute(
                                "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)",
                                libsql::params![cid, cpts],
                            )
                            .await?;
                        }
                        Ok::<(), libsql::Error>(())
                    }
                    .await;
                    let outcome = match written {
                        Ok(()) => conn.execute_batch("COMMIT").await.map(|_| ()),
                        Err(e) => Err(e),
                    };
                    if outcome.is_err() {
                        let _ = conn.execute_batch("ROLLBACK").await;
                    }
                    outcome
                }
                UpdateStateChange::Primary { pts, date, seq } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,0,?2,?3)
                         ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,date=excluded.date,
                         seq=excluded.seq",
                        libsql::params![pts, date, seq],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Secondary { qts } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,0,?1,0,0)
                         ON CONFLICT(id) DO UPDATE SET qts=excluded.qts",
                        libsql::params![qts],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Channel { id, pts } => conn
                    .execute(
                        "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)
                         ON CONFLICT(channel_id) DO UPDATE SET pts=excluded.pts",
                        libsql::params![id, pts],
                    )
                    .await
                    .map(|_| ()),
            }
        })
    }

    /// Inserts or replaces the cached peer keyed by `peer.id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = &self.conn;
        let (id, hash, is_ch, is_ct, kind) = (
            peer.id,
            peer.access_hash,
            peer.is_channel as i32,
            peer.is_chat as i32,
            peer.channel_kind.map(|k| k.to_byte() as i32),
        );
        Self::block(async move {
            conn.execute(
                "INSERT INTO peers (id,access_hash,is_channel,is_chat,channel_kind) VALUES (?1,?2,?3,?4,?5)
                 ON CONFLICT(id) DO UPDATE SET
                     access_hash=excluded.access_hash,
                     is_channel=excluded.is_channel,
                     is_chat=excluded.is_chat,
                     channel_kind=excluded.channel_kind",
                libsql::params![id, hash, is_ch, is_ct, kind],
            )
            .await
            .map(|_| ())
        })
    }
}
2371}