1#![deny(unsafe_code)]
14#![cfg_attr(docsrs, feature(doc_cfg))]
15#![doc(html_root_url = "https://docs.rs/ferogram-session/0.6.3")]
16use std::collections::HashMap;
78use std::io::{self, ErrorKind};
79use std::path::Path;
80
/// Serde adapter for an optional 256-byte auth key.
///
/// The key is written as an optional byte sequence and read back through a
/// `Vec<u8>` whose length is validated before conversion to a fixed array.
#[cfg(feature = "serde")]
mod auth_key_serde {
    use serde::{Deserialize, Deserializer, Serializer};

    /// Serialises `Some(key)` as the key's bytes and `None` as "none".
    pub fn serialize<S>(value: &Option<[u8; 256]>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(key) = value {
            s.serialize_some(&key[..])
        } else {
            s.serialize_none()
        }
    }

    /// Deserialises an optional byte sequence, rejecting any length other
    /// than exactly 256 bytes.
    pub fn deserialize<'de, D>(d: D) -> Result<Option<[u8; 256]>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let Some(raw) = Option::<Vec<u8>>::deserialize(d)? else {
            return Ok(None);
        };
        <[u8; 256]>::try_from(raw)
            .map(Some)
            .map_err(|_| serde::de::Error::custom("auth_key must be exactly 256 bytes"))
    }
}
115
/// Bit set describing properties of a data-centre entry, stored in one byte.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// No bits set.
    pub const NONE: Self = Self(0);
    /// Bit 0; tested by `DcEntry::is_ipv6`.
    pub const IPV6: Self = Self(1 << 0);
    /// Bit 1.
    pub const MEDIA_ONLY: Self = Self(1 << 1);
    /// Bit 2.
    pub const TCPO_ONLY: Self = Self(1 << 2);
    /// Bit 3.
    pub const CDN: Self = Self(1 << 3);
    /// Bit 4.
    pub const STATIC: Self = Self(1 << 4);

    /// Returns `true` when every bit set in `other` is also set in `self`.
    ///
    /// An empty `other` is therefore contained in every value.
    pub fn contains(self, other: DcFlags) -> bool {
        let Self(wanted) = other;
        (self.0 & wanted) == wanted
    }

    /// Turns on all bits of `flag`, leaving the other bits untouched.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = Self;

    /// Union of the two bit sets.
    fn bitor(self, rhs: Self) -> Self {
        let (Self(left), Self(right)) = (self, rhs);
        Self(left | right)
    }
}
148
149#[derive(Clone, Debug)]
151#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
152pub struct DcEntry {
153 pub dc_id: i32,
154 pub addr: String,
155 #[cfg_attr(feature = "serde", serde(with = "auth_key_serde"))]
156 pub auth_key: Option<[u8; 256]>,
157 pub first_salt: i64,
158 pub time_offset: i32,
159 pub flags: DcFlags,
161}
162
163impl DcEntry {
164 #[inline]
166 pub fn is_ipv6(&self) -> bool {
167 self.flags.contains(DcFlags::IPV6)
168 }
169
170 pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
177 self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
178 io::Error::new(
179 io::ErrorKind::InvalidData,
180 format!("invalid DC address: {:?}", self.addr),
181 )
182 })
183 }
184
185 pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
198 let addr = if ip.contains(':') {
200 format!("[{ip}]:{port}")
201 } else {
202 format!("{ip}:{port}")
203 };
204 Self {
205 dc_id,
206 addr,
207 auth_key: None,
208 first_salt: 0,
209 time_offset: 0,
210 flags,
211 }
212 }
213}
214
/// Snapshot of the update counters kept in a persisted session.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Primary counter; a positive value marks the snapshot as initialised.
    pub pts: i32,
    /// Secondary counter.
    pub qts: i32,
    /// Stored date value.
    pub date: i32,
    /// Stored sequence value.
    pub seq: i32,
    /// `(channel_id, pts)` pairs, at most one per channel when maintained
    /// through [`UpdatesStateSnap::set_channel_pts`].
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// `true` once `pts` is greater than zero.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        self.pts >= 1
    }

    /// Stores `pts` for `channel_id`, overwriting the first existing pair for
    /// that channel or appending a new pair.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self
            .channels
            .iter_mut()
            .find(|entry| entry.0 == channel_id)
        {
            Some((_, stored)) => *stored = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Returns the stored pts for `channel_id`, or `0` when none is recorded.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        self.channels
            .iter()
            .find_map(|&(id, pts)| (id == channel_id).then_some(pts))
            .unwrap_or(0)
    }
}
257
/// Kind of a channel peer, with a stable one-byte encoding.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum ChannelKind {
    /// Encoded as `0`.
    Broadcast = 0,
    /// Encoded as `1`.
    Megagroup = 1,
    /// Encoded as `2`.
    Gigagroup = 2,
}

impl ChannelKind {
    /// Decodes the byte produced by [`ChannelKind::to_byte`].
    ///
    /// Returns `None` for any byte that is not a known discriminant.
    pub fn from_byte(b: u8) -> Option<Self> {
        [Self::Broadcast, Self::Megagroup, Self::Gigagroup]
            .into_iter()
            .find(|kind| *kind as u8 == b)
    }

    /// Encodes the variant as its discriminant byte.
    pub fn to_byte(self) -> u8 {
        match self {
            Self::Broadcast => 0,
            Self::Megagroup => 1,
            Self::Gigagroup => 2,
        }
    }
}
289
/// A cached peer record stored in the session.
///
/// In the binary format the two booleans are folded into one type byte
/// (`2` when `is_chat`, else `1` when `is_channel`, else `0`).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Peer identifier; used as the lookup key when caching peers.
    pub id: i64,
    /// Access hash stored alongside the peer id.
    pub access_hash: i64,
    /// Set when the peer type byte is `1`.
    pub is_channel: bool,
    /// Set when the peer type byte is `2`; takes precedence over
    /// `is_channel` when serialising.
    pub is_chat: bool,
    /// Optional channel kind; written as `0xFF` in the binary format when
    /// absent.
    pub channel_kind: Option<ChannelKind>,
}
309
/// A cached "min" peer record: three identifiers persisted together.
///
/// NOTE(review): the exact relationship between the three fields is not
/// visible in this file — presumably a user seen only via a message in some
/// peer; confirm against the caller.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedMinPeer {
    /// User identifier.
    pub user_id: i64,
    /// Peer identifier.
    pub peer_id: i64,
    /// Message identifier.
    pub msg_id: i32,
}
323
/// Everything that is persisted for a session: the home DC, the known DC
/// endpoints, the update counters and the peer caches.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PersistedSession {
    /// Identifier of the home data centre.
    pub home_dc_id: i32,
    /// Known data-centre endpoints; several entries may share a `dc_id`.
    pub dcs: Vec<DcEntry>,
    /// Update counters snapshot.
    pub updates_state: UpdatesStateSnap,
    /// Cached peers.
    pub peers: Vec<CachedPeer>,
    /// Cached min peers (present in binary format version 4 and later).
    pub min_peers: Vec<CachedMinPeer>,
}
339
340impl PersistedSession {
341 pub fn to_bytes(&self) -> Vec<u8> {
343 let mut b = Vec::with_capacity(512);
344
345 b.push(0x06u8); b.extend_from_slice(&self.home_dc_id.to_le_bytes());
348
349 b.push(self.dcs.len() as u8);
350 for d in &self.dcs {
351 b.extend_from_slice(&d.dc_id.to_le_bytes());
352 match &d.auth_key {
353 Some(k) => {
354 b.push(1);
355 b.extend_from_slice(k);
356 }
357 None => {
358 b.push(0);
359 }
360 }
361 b.extend_from_slice(&d.first_salt.to_le_bytes());
362 b.extend_from_slice(&d.time_offset.to_le_bytes());
363 let ab = d.addr.as_bytes();
364 b.push(ab.len() as u8);
365 b.extend_from_slice(ab);
366 b.push(d.flags.0);
367 }
368
369 b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
370 b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
371 b.extend_from_slice(&self.updates_state.date.to_le_bytes());
372 b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
373 let ch = &self.updates_state.channels;
374 b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
375 for &(cid, cpts) in ch {
376 b.extend_from_slice(&cid.to_le_bytes());
377 b.extend_from_slice(&cpts.to_le_bytes());
378 }
379
380 b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
382 for p in &self.peers {
383 b.extend_from_slice(&p.id.to_le_bytes());
384 b.extend_from_slice(&p.access_hash.to_le_bytes());
385 let peer_type: u8 = if p.is_chat {
386 2
387 } else if p.is_channel {
388 1
389 } else {
390 0
391 };
392 b.push(peer_type);
393 b.push(p.channel_kind.map(|k| k.to_byte()).unwrap_or(0xFF));
395 }
396
397 b.extend_from_slice(&(self.min_peers.len() as u16).to_le_bytes());
398 for m in &self.min_peers {
399 b.extend_from_slice(&m.user_id.to_le_bytes());
400 b.extend_from_slice(&m.peer_id.to_le_bytes());
401 b.extend_from_slice(&m.msg_id.to_le_bytes());
402 }
403
404 b
405 }
406
407 pub fn save(&self, path: &Path) -> io::Result<()> {
414 use std::sync::atomic::{AtomicU64, Ordering};
415 static SEQ: AtomicU64 = AtomicU64::new(0);
416 let n = SEQ.fetch_add(1, Ordering::Relaxed);
417 let tmp = path.with_extension(format!("{n}.tmp"));
418 std::fs::write(&tmp, self.to_bytes())?;
419 std::fs::rename(&tmp, path).inspect_err(|_e| {
420 let _ = std::fs::remove_file(&tmp);
421 })
422 }
423
    /// Decodes a session from its binary representation.
    ///
    /// Formats 1 through 6 are accepted. A leading byte of `0x02..=0x06` is
    /// the format version and is followed by `home_dc_id`; any other leading
    /// byte is treated as the first byte of a version-1 stream, which starts
    /// directly with the 4-byte little-endian `home_dc_id`.
    ///
    /// Version-dependent parts:
    /// * DC flags byte: version 3 and later (otherwise [`DcFlags::NONE`]).
    /// * Update state, channels and peers: version 2 and later.
    /// * Min peers: version 4 and later.
    /// * Per-peer channel-kind byte: version 6 and later.
    ///
    /// # Errors
    /// Returns `InvalidData` when `buf` is empty or ends before all expected
    /// fields have been read.
    pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
        if buf.is_empty() {
            return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
        }

        // Read cursor into `buf`; only the `r!` macro advances it.
        let mut p = 0usize;

        // Returns the next `$n` bytes and advances the cursor, or returns a
        // "truncated session" error from the enclosing function.
        macro_rules! r {
            ($n:expr) => {{
                if p + $n > buf.len() {
                    return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
                }
                let s = &buf[p..p + $n];
                p += $n;
                s
            }};
        }
        // The `unwrap`s below cannot fail: `r!` returns exactly the
        // requested number of bytes.
        macro_rules! r_i32 {
            () => {
                i32::from_le_bytes(r!(4).try_into().unwrap())
            };
        }
        macro_rules! r_i64 {
            () => {
                i64::from_le_bytes(r!(8).try_into().unwrap())
            };
        }
        macro_rules! r_u8 {
            () => {
                r!(1)[0]
            };
        }
        macro_rules! r_u16 {
            () => {
                u16::from_le_bytes(r!(2).try_into().unwrap())
            };
        }

        let first_byte = r_u8!();

        // Known version markers are followed by the home DC id; anything
        // else is the low byte of a version-1 home DC id.
        let (home_dc_id, version) = if first_byte == 0x06 {
            (r_i32!(), 6u8)
        } else if first_byte == 0x05 {
            (r_i32!(), 5u8)
        } else if first_byte == 0x04 {
            (r_i32!(), 4u8)
        } else if first_byte == 0x03 {
            (r_i32!(), 3u8)
        } else if first_byte == 0x02 {
            (r_i32!(), 2u8)
        } else {
            let rest = r!(3);
            let mut bytes = [0u8; 4];
            bytes[0] = first_byte;
            bytes[1..4].copy_from_slice(rest);
            (i32::from_le_bytes(bytes), 1u8)
        };

        let dc_count = r_u8!() as usize;
        let mut dcs = Vec::with_capacity(dc_count);
        for _ in 0..dc_count {
            let dc_id = r_i32!();
            // Only the value 1 means "a 256-byte key follows".
            let has_key = r_u8!();
            let auth_key = if has_key == 1 {
                let mut k = [0u8; 256];
                k.copy_from_slice(r!(256));
                Some(k)
            } else {
                None
            };
            let first_salt = r_i64!();
            let time_offset = r_i32!();
            // Address is length-prefixed with a single byte; invalid UTF-8
            // is replaced rather than rejected.
            let al = r_u8!() as usize;
            let addr = String::from_utf8_lossy(r!(al)).into_owned();
            let flags = if version >= 3 {
                DcFlags(r_u8!())
            } else {
                DcFlags::NONE
            };
            dcs.push(DcEntry {
                dc_id,
                addr,
                auth_key,
                first_salt,
                time_offset,
                flags,
            });
        }

        // Version 1 ends after the DC list.
        if version < 2 {
            return Ok(Self {
                home_dc_id,
                dcs,
                updates_state: UpdatesStateSnap::default(),
                peers: Vec::new(),
                min_peers: Vec::new(),
            });
        }

        let pts = r_i32!();
        let qts = r_i32!();
        let date = r_i32!();
        let seq = r_i32!();
        let ch_count = r_u16!() as usize;
        let mut channels = Vec::with_capacity(ch_count);
        for _ in 0..ch_count {
            let cid = r_i64!();
            let cpts = r_i32!();
            channels.push((cid, cpts));
        }

        let peer_count = r_u16!() as usize;
        let mut peers = Vec::with_capacity(peer_count);
        for _ in 0..peer_count {
            let id = r_i64!();
            let access_hash = r_i64!();
            // Type byte: 1 = channel, 2 = chat, anything else = neither.
            let peer_type = r_u8!();
            let is_channel = peer_type == 1;
            let is_chat = peer_type == 2;
            let channel_kind = if version >= 6 {
                // 0xFF is the "absent" sentinel; unknown bytes also decode
                // to `None` via `ChannelKind::from_byte`.
                let kb = r_u8!();
                if kb == 0xFF {
                    None
                } else {
                    ChannelKind::from_byte(kb)
                }
            } else {
                None
            };
            peers.push(CachedPeer {
                id,
                access_hash,
                is_channel,
                is_chat,
                channel_kind,
            });
        }

        let min_peers = if version >= 4 {
            let count = r_u16!() as usize;
            let mut v = Vec::with_capacity(count);
            for _ in 0..count {
                let user_id = r_i64!();
                let peer_id = r_i64!();
                let msg_id = r_i32!();
                v.push(CachedMinPeer {
                    user_id,
                    peer_id,
                    msg_id,
                });
            }
            v
        } else {
            Vec::new()
        };

        Ok(Self {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                pts,
                qts,
                date,
                seq,
                channels,
            },
            peers,
            min_peers,
        })
    }
598
599 pub fn from_string(s: &str) -> io::Result<Self> {
601 use base64::Engine as _;
602 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
603 .decode(s.trim())
604 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
605 Self::from_bytes(&bytes)
606 }
607
608 pub fn load(path: &Path) -> io::Result<Self> {
611 let buf = std::fs::read(path)?;
612 Self::from_bytes(&buf)
613 }
614
615 pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
627 let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
628 candidates.peek()?;
629 let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
631 cands
633 .iter()
634 .copied()
635 .find(|d| d.is_ipv6() == prefer_ipv6)
636 .or_else(|| cands.first().copied())
637 }
638
    /// Iterates over every stored entry whose `dc_id` matches, in storage
    /// order.
    pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
        self.dcs.iter().filter(move |d| d.dc_id == dc_id)
    }
646}
647
648impl std::fmt::Display for PersistedSession {
649 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
650 use base64::Engine as _;
651 f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
652 }
653}
654
/// Returns the built-in table mapping DC ids 1 through 5 to their default
/// `ip:port` addresses.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const DEFAULTS: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(DEFAULTS.len());
    for (dc_id, addr) in DEFAULTS {
        table.insert(dc_id, addr.to_owned());
    }
    table
}
668
669use std::path::PathBuf;
674
675pub trait SessionBackend: Send + Sync {
683 fn save(&self, session: &PersistedSession) -> io::Result<()>;
684 fn load(&self) -> io::Result<Option<PersistedSession>>;
685 fn delete(&self) -> io::Result<()>;
686
687 fn name(&self) -> &str;
689
690 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
703 let mut s = self.load()?.unwrap_or_default();
704 if let Some(existing) = s
706 .dcs
707 .iter_mut()
708 .find(|d| d.dc_id == entry.dc_id && d.is_ipv6() == entry.is_ipv6())
709 {
710 *existing = entry.clone();
711 } else {
712 s.dcs.push(entry.clone());
713 }
714 self.save(&s)
715 }
716
717 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
723 let mut s = self.load()?.unwrap_or_default();
724 s.home_dc_id = dc_id;
725 self.save(&s)
726 }
727
728 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
733 let mut s = self.load()?.unwrap_or_default();
734 update.apply_to(&mut s.updates_state);
735 self.save(&s)
736 }
737
738 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
744 let mut s = self.load()?.unwrap_or_default();
745 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
746 *existing = peer.clone();
747 } else {
748 s.peers.push(peer.clone());
749 }
750 self.save(&s)
751 }
752}
753
754#[derive(Debug, Clone)]
769pub enum UpdateStateChange {
770 All(UpdatesStateSnap),
772 Primary { pts: i32, date: i32, seq: i32 },
774 Secondary { qts: i32 },
776 Channel { id: i64, pts: i32 },
778}
779
780impl UpdateStateChange {
781 pub fn apply_to(&self, snap: &mut UpdatesStateSnap) {
783 match self {
784 Self::All(new_snap) => *snap = new_snap.clone(),
785 Self::Primary { pts, date, seq } => {
786 snap.pts = *pts;
787 snap.date = *date;
788 snap.seq = *seq;
789 }
790 Self::Secondary { qts } => {
791 snap.qts = *qts;
792 }
793 Self::Channel { id, pts } => {
794 if let Some(existing) = snap.channels.iter_mut().find(|c| c.0 == *id) {
796 existing.1 = *pts;
797 } else {
798 snap.channels.push((*id, *pts));
799 }
800 }
801 }
802 }
803}
804
/// Session backend that stores the binary session format in a single file.
pub struct BinaryFileBackend {
    /// Location of the session file.
    path: PathBuf,
    /// Serialises saves issued through this backend instance.
    write_lock: std::sync::Mutex<()>,
}

impl BinaryFileBackend {
    /// Creates a backend for the file at `path`; nothing is touched on disk.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path = path.into();
        let write_lock = std::sync::Mutex::new(());
        Self { path, write_lock }
    }

    /// Path of the session file this backend reads and writes.
    pub fn path(&self) -> &std::path::Path {
        self.path.as_path()
    }
}
832
833impl SessionBackend for BinaryFileBackend {
834 fn save(&self, session: &PersistedSession) -> io::Result<()> {
835 let _guard = self.write_lock.lock().unwrap();
836 session.save(&self.path)
837 }
838
839 fn load(&self) -> io::Result<Option<PersistedSession>> {
840 if !self.path.exists() {
841 return Ok(None);
842 }
843 match PersistedSession::load(&self.path) {
844 Ok(s) => Ok(Some(s)),
845 Err(e) => {
846 let bak = self.path.with_extension("bak");
847 tracing::warn!(
848 "[ferogram::session] session file {:?} could not be read ({e}); \
849 backing it up to {:?} and starting a new session",
850 self.path,
851 bak
852 );
853 let _ = std::fs::rename(&self.path, &bak);
854 Ok(None)
855 }
856 }
857 }
858
859 fn delete(&self) -> io::Result<()> {
860 if self.path.exists() {
861 std::fs::remove_file(&self.path)?;
862 }
863 Ok(())
864 }
865
866 fn name(&self) -> &str {
867 "binary-file"
868 }
869
870 }
873
874#[derive(Default)]
882pub struct InMemoryBackend {
883 data: std::sync::Mutex<Option<PersistedSession>>,
884}
885
886impl InMemoryBackend {
887 pub fn new() -> Self {
889 Self::default()
890 }
891
892 pub fn snapshot(&self) -> Option<PersistedSession> {
894 self.data.lock().unwrap().clone()
895 }
896}
897
898impl SessionBackend for InMemoryBackend {
899 fn save(&self, s: &PersistedSession) -> io::Result<()> {
900 *self.data.lock().unwrap() = Some(s.clone());
901 Ok(())
902 }
903
904 fn load(&self) -> io::Result<Option<PersistedSession>> {
905 Ok(self.data.lock().unwrap().clone())
906 }
907
908 fn delete(&self) -> io::Result<()> {
909 *self.data.lock().unwrap() = None;
910 Ok(())
911 }
912
913 fn name(&self) -> &str {
914 "in-memory"
915 }
916
917 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
920 let mut guard = self.data.lock().unwrap();
921 let s = guard.get_or_insert_with(PersistedSession::default);
922 if let Some(existing) = s
923 .dcs
924 .iter_mut()
925 .find(|d| d.dc_id == entry.dc_id && d.is_ipv6() == entry.is_ipv6())
926 {
927 *existing = entry.clone();
928 } else {
929 s.dcs.push(entry.clone());
930 }
931 Ok(())
932 }
933
934 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
935 let mut guard = self.data.lock().unwrap();
936 let s = guard.get_or_insert_with(PersistedSession::default);
937 s.home_dc_id = dc_id;
938 Ok(())
939 }
940
941 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
942 let mut guard = self.data.lock().unwrap();
943 let s = guard.get_or_insert_with(PersistedSession::default);
944 update.apply_to(&mut s.updates_state);
945 Ok(())
946 }
947
948 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
949 let mut guard = self.data.lock().unwrap();
950 let s = guard.get_or_insert_with(PersistedSession::default);
951 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
952 *existing = peer.clone();
953 } else {
954 s.peers.push(peer.clone());
955 }
956 Ok(())
957 }
958}
959
/// Session backend that keeps the session as its base64 string form.
pub struct StringSessionBackend {
    /// Current encoded session; empty when no session is stored.
    data: std::sync::Mutex<String>,
}

impl StringSessionBackend {
    /// Creates a backend seeded with the given encoded session string.
    pub fn new(s: impl Into<String>) -> Self {
        let initial: String = s.into();
        Self {
            data: std::sync::Mutex::new(initial),
        }
    }

    /// Returns a copy of the currently stored string.
    pub fn current(&self) -> String {
        let held = self.data.lock().unwrap();
        String::clone(&held)
    }
}
981
982impl SessionBackend for StringSessionBackend {
983 fn save(&self, session: &PersistedSession) -> io::Result<()> {
984 *self.data.lock().unwrap() = session.to_string();
985 Ok(())
986 }
987
988 fn load(&self) -> io::Result<Option<PersistedSession>> {
989 let s = self.data.lock().unwrap().clone();
990 if s.trim().is_empty() {
991 return Ok(None);
992 }
993 PersistedSession::from_string(&s).map(Some)
994 }
995
996 fn delete(&self) -> io::Result<()> {
997 *self.data.lock().unwrap() = String::new();
998 Ok(())
999 }
1000
1001 fn name(&self) -> &str {
1002 "string-session"
1003 }
1004}
1005
1006pub mod string_session;
1014pub use string_session::{FullSession, Session, StringSession, StringSessionError};
1015
#[cfg(test)]
mod tests {
    use super::*;

    /// IPv4 entry for `id` with no key and no flags.
    fn make_dc(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("1.2.3.{id}:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::NONE,
        }
    }

    /// Plain (non-channel, non-chat) peer.
    fn make_peer(id: i64, hash: i64) -> CachedPeer {
        CachedPeer {
            id,
            access_hash: hash,
            is_channel: false,
            is_chat: false,
            channel_kind: None,
        }
    }

    /// Backend that implements only the required trait methods, so every
    /// granular update goes through the trait's default load-modify-save
    /// implementations (`InMemoryBackend` itself overrides them).
    struct RequiredOnlyBackend(InMemoryBackend);

    impl SessionBackend for RequiredOnlyBackend {
        fn save(&self, session: &PersistedSession) -> std::io::Result<()> {
            self.0.save(session)
        }

        fn load(&self) -> std::io::Result<Option<PersistedSession>> {
            self.0.load()
        }

        fn delete(&self) -> std::io::Result<()> {
            self.0.delete()
        }

        fn name(&self) -> &str {
            "required-only"
        }
    }

    #[test]
    fn inmemory_load_returns_none_when_empty() {
        let b = InMemoryBackend::new();
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_save_then_load_round_trips() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 3;
        s.dcs.push(make_dc(3));
        b.save(&s).unwrap();

        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.home_dc_id, 3);
        assert_eq!(loaded.dcs.len(), 1);
    }

    #[test]
    fn inmemory_delete_clears_state() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        b.save(&s).unwrap();
        b.delete().unwrap();
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_update_dc_inserts_new() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(4)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].dc_id, 4);
    }

    #[test]
    fn inmemory_update_dc_replaces_existing() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        let mut updated = make_dc(2);
        updated.addr = "9.9.9.9:443".to_string();
        b.update_dc(&updated).unwrap();

        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].addr, "9.9.9.9:443");
    }

    #[test]
    fn inmemory_set_home_dc() {
        let b = InMemoryBackend::new();
        b.set_home_dc(5).unwrap();
        assert_eq!(b.snapshot().unwrap().home_dc_id, 5);
    }

    #[test]
    fn inmemory_cache_peer_inserts() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 0xdeadbeef)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].id, 100);
    }

    #[test]
    fn inmemory_cache_peer_updates_existing() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 111)).unwrap();
        b.cache_peer(&make_peer(100, 222)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].access_hash, 222);
    }

    #[test]
    fn update_state_primary() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Primary {
            pts: 10,
            date: 20,
            seq: 30,
        }
        .apply_to(&mut snap);
        assert_eq!(snap.pts, 10);
        assert_eq!(snap.date, 20);
        assert_eq!(snap.seq, 30);
        // `qts` is not part of a primary change.
        assert_eq!(snap.qts, 0);
    }

    #[test]
    fn update_state_secondary() {
        let mut snap = UpdatesStateSnap {
            pts: 5,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Secondary { qts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.qts, 99);
        // `pts` is not part of a secondary change.
        assert_eq!(snap.pts, 5);
    }

    #[test]
    fn update_state_channel_inserts() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Channel { id: 12345, pts: 42 }.apply_to(&mut snap);
        assert_eq!(snap.channels, vec![(12345, 42)]);
    }

    #[test]
    fn update_state_channel_updates_existing() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![(12345, 10), (67890, 5)],
        };
        UpdateStateChange::Channel { id: 12345, pts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.channels[0], (12345, 99));
        assert_eq!(snap.channels[1], (67890, 5));
    }

    #[test]
    fn apply_update_state_via_backend() {
        let b = InMemoryBackend::new();
        b.apply_update_state(UpdateStateChange::Primary {
            pts: 7,
            date: 8,
            seq: 9,
        })
        .unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.updates_state.pts, 7);
    }

    #[test]
    fn default_update_dc_via_trait_object() {
        // `RequiredOnlyBackend` does not override `update_dc`, so this really
        // runs the trait's default implementation through dynamic dispatch.
        let b: Box<dyn SessionBackend> = Box::new(RequiredOnlyBackend(InMemoryBackend::new()));
        b.update_dc(&make_dc(1)).unwrap();
        b.update_dc(&make_dc(2)).unwrap();
        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.dcs.len(), 2);

        // Updating an existing (dc_id, family) slot must replace, not append.
        let mut moved = make_dc(2);
        moved.addr = "9.9.9.9:443".to_string();
        b.update_dc(&moved).unwrap();
        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.dcs.len(), 2);
        assert_eq!(loaded.dcs[1].addr, "9.9.9.9:443");
    }

    /// IPv6 entry for `id` with the IPV6 flag set.
    fn make_dc_v6(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("[2001:b28:f23d:f00{}::a]:443", id),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::IPV6,
        }
    }

    #[test]
    fn dc_entry_from_parts_ipv4() {
        let dc = DcEntry::from_parts(1, "149.154.175.53", 443, DcFlags::NONE);
        assert_eq!(dc.addr, "149.154.175.53:443");
        assert!(!dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    #[test]
    fn dc_entry_from_parts_ipv6() {
        let dc = DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6);
        assert_eq!(dc.addr, "[2001:b28:f23d:f001::a]:443");
        assert!(dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    #[test]
    fn persisted_session_dc_for_prefers_ipv6() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(2));
        s.dcs.push(make_dc_v6(2));

        let v6 = s.dc_for(2, true).unwrap();
        assert!(v6.is_ipv6());

        let v4 = s.dc_for(2, false).unwrap();
        assert!(!v4.is_ipv6());
    }

    #[test]
    fn persisted_session_dc_for_falls_back_when_only_ipv4() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(3));

        let dc = s.dc_for(3, true).unwrap();
        assert!(!dc.is_ipv6());
    }

    #[test]
    fn persisted_session_all_dcs_for_returns_both() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(1));
        s.dcs.push(make_dc_v6(1));
        s.dcs.push(make_dc(2));

        assert_eq!(s.all_dcs_for(1).count(), 2);
        assert_eq!(s.all_dcs_for(2).count(), 1);
        assert_eq!(s.all_dcs_for(5).count(), 0);
    }

    #[test]
    fn inmemory_ipv4_and_ipv6_coexist() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        b.update_dc(&make_dc_v6(2)).unwrap();

        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.iter().filter(|d| d.dc_id == 2).count(), 2);
    }

    #[test]
    fn binary_roundtrip_ipv4_and_ipv6() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        s.dcs.push(make_dc(2));
        s.dcs.push(make_dc_v6(2));

        let bytes = s.to_bytes();
        let loaded = PersistedSession::from_bytes(&bytes).unwrap();
        assert_eq!(loaded.dcs.len(), 2);
        assert_eq!(loaded.dcs.iter().filter(|d| d.is_ipv6()).count(), 1);
        assert_eq!(loaded.dcs.iter().filter(|d| !d.is_ipv6()).count(), 1);
    }

    #[test]
    fn v6_channel_kind_roundtrip_all_variants() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 1;
        s.peers.push(CachedPeer {
            id: 1001,
            access_hash: 0xaaaa,
            is_channel: true,
            is_chat: false,
            channel_kind: Some(ChannelKind::Broadcast),
        });
        s.peers.push(CachedPeer {
            id: 1002,
            access_hash: 0xbbbb,
            is_channel: true,
            is_chat: false,
            channel_kind: Some(ChannelKind::Megagroup),
        });
        s.peers.push(CachedPeer {
            id: 1003,
            access_hash: 0xcccc,
            is_channel: true,
            is_chat: false,
            channel_kind: Some(ChannelKind::Gigagroup),
        });
        s.peers.push(CachedPeer {
            id: 1004,
            access_hash: 0xdddd,
            is_channel: false,
            is_chat: false,
            channel_kind: None,
        });

        let bytes = s.to_bytes();
        assert_eq!(bytes[0], 0x06, "version byte must be 6");

        let loaded = PersistedSession::from_bytes(&bytes).unwrap();
        assert_eq!(loaded.peers.len(), 4);

        let p = &loaded.peers[0];
        assert_eq!(p.id, 1001);
        assert_eq!(p.channel_kind, Some(ChannelKind::Broadcast));

        let p = &loaded.peers[1];
        assert_eq!(p.id, 1002);
        assert_eq!(p.channel_kind, Some(ChannelKind::Megagroup));

        let p = &loaded.peers[2];
        assert_eq!(p.id, 1003);
        assert_eq!(p.channel_kind, Some(ChannelKind::Gigagroup));

        let p = &loaded.peers[3];
        assert_eq!(p.id, 1004);
        assert_eq!(p.channel_kind, None);
    }

    #[test]
    fn v6_channel_kind_absent_sentinel_roundtrip() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 1;
        s.peers.push(CachedPeer {
            id: 555,
            access_hash: 0x1234,
            is_channel: false,
            is_chat: false,
            channel_kind: None,
        });
        let loaded = PersistedSession::from_bytes(&s.to_bytes()).unwrap();
        assert_eq!(loaded.peers[0].channel_kind, None);
    }
}
1375
/// Session backend backed by a SQLite database (requires the
/// `sqlite-session` feature).
#[cfg(feature = "sqlite-session")]
pub struct SqliteBackend {
    /// The database connection, shared behind a mutex.
    conn: std::sync::Mutex<rusqlite::Connection>,
    /// Display form of the database path, or `":memory:"` for an in-memory
    /// database.
    label: String,
}
1406
1407#[cfg(feature = "sqlite-session")]
1408impl SqliteBackend {
    /// SQL batch executed when a database is opened.
    ///
    /// Sets WAL journaling and `synchronous = NORMAL`, then creates the
    /// `meta`, `dcs`, `update_state`, `channel_pts`, `peers` and `min_peers`
    /// tables if they do not exist yet. `dcs` is keyed by `(dc_id, flags)`
    /// and `update_state` is constrained to a single row with `id = 1`.
    const SCHEMA: &'static str = "
        PRAGMA journal_mode = WAL;
        PRAGMA synchronous = NORMAL;

        CREATE TABLE IF NOT EXISTS meta (
            key TEXT PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS dcs (
            dc_id INTEGER NOT NULL,
            flags INTEGER NOT NULL DEFAULT 0,
            addr TEXT NOT NULL,
            auth_key BLOB,
            first_salt INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );

        CREATE TABLE IF NOT EXISTS update_state (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            pts INTEGER NOT NULL DEFAULT 0,
            qts INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts INTEGER NOT NULL
        );

        CREATE TABLE IF NOT EXISTS peers (
            id INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel INTEGER NOT NULL DEFAULT 0,
            is_chat INTEGER NOT NULL DEFAULT 0,
            channel_kind INTEGER
        );

        CREATE TABLE IF NOT EXISTS min_peers (
            user_id INTEGER PRIMARY KEY,
            peer_id INTEGER NOT NULL,
            msg_id INTEGER NOT NULL
        );
    ";
1455
1456 pub fn open(path: impl Into<PathBuf>) -> io::Result<Self> {
1458 let path = path.into();
1459 let label = path.display().to_string();
1460 let conn = rusqlite::Connection::open(&path).map_err(io::Error::other)?;
1461 conn.execute_batch(Self::SCHEMA).map_err(io::Error::other)?;
1462 Self::migrate_legacy_sqlite_schema(&conn)?;
1463 Ok(Self {
1464 conn: std::sync::Mutex::new(conn),
1465 label,
1466 })
1467 }
1468
1469 pub fn in_memory() -> io::Result<Self> {
1471 let conn = rusqlite::Connection::open_in_memory().map_err(io::Error::other)?;
1472 conn.execute_batch(Self::SCHEMA).map_err(io::Error::other)?;
1473 Self::migrate_legacy_sqlite_schema(&conn)?;
1474 Ok(Self {
1475 conn: std::sync::Mutex::new(conn),
1476 label: ":memory:".into(),
1477 })
1478 }
1479
    /// Wraps a SQLite error in an `io::Error` via `io::Error::other`.
    fn map_err(e: rusqlite::Error) -> io::Error {
        io::Error::other(e)
    }
1483
1484 fn migrate_legacy_sqlite_schema(conn: &rusqlite::Connection) -> io::Result<()> {
1488 let mut has_is_chat = false;
1489 let mut stmt = conn
1490 .prepare("PRAGMA table_info(peers)")
1491 .map_err(Self::map_err)?;
1492 let cols = stmt
1493 .query_map([], |row| row.get::<_, String>(1))
1494 .map_err(Self::map_err)?;
1495 for col in cols.filter_map(|r| r.ok()) {
1496 if col == "is_chat" {
1497 has_is_chat = true;
1498 break;
1499 }
1500 }
1501 if !has_is_chat {
1502 conn.execute_batch("ALTER TABLE peers ADD COLUMN is_chat INTEGER NOT NULL DEFAULT 0;")
1503 .map_err(Self::map_err)?;
1504 }
1505 let mut has_channel_kind = false;
1507 let mut stmt2 = conn
1508 .prepare("PRAGMA table_info(peers)")
1509 .map_err(Self::map_err)?;
1510 let cols2 = stmt2
1511 .query_map([], |row| row.get::<_, String>(1))
1512 .map_err(Self::map_err)?;
1513 for col in cols2.filter_map(|r| r.ok()) {
1514 if col == "channel_kind" {
1515 has_channel_kind = true;
1516 break;
1517 }
1518 }
1519 if !has_channel_kind {
1520 conn.execute_batch("ALTER TABLE peers ADD COLUMN channel_kind INTEGER;")
1521 .map_err(Self::map_err)?;
1522 }
1523 conn.execute_batch(
1524 "CREATE TABLE IF NOT EXISTS min_peers (
1525 user_id INTEGER PRIMARY KEY,
1526 peer_id INTEGER NOT NULL,
1527 msg_id INTEGER NOT NULL
1528 );",
1529 )
1530 .map_err(Self::map_err)?;
1531 Ok(())
1532 }
1533
1534 fn read_session(conn: &rusqlite::Connection) -> io::Result<PersistedSession> {
1536 let home_dc_id: i32 = conn
1538 .query_row("SELECT value FROM meta WHERE key = 'home_dc_id'", [], |r| {
1539 r.get(0)
1540 })
1541 .unwrap_or(0);
1542
1543 let mut stmt = conn
1545 .prepare("SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs")
1546 .map_err(Self::map_err)?;
1547 let dcs = stmt
1548 .query_map([], |row| {
1549 let dc_id: i32 = row.get(0)?;
1550 let flags_raw: u8 = row.get(1)?;
1551 let addr: String = row.get(2)?;
1552 let key_blob: Option<Vec<u8>> = row.get(3)?;
1553 let first_salt: i64 = row.get(4)?;
1554 let time_offset: i32 = row.get(5)?;
1555 Ok((dc_id, addr, key_blob, first_salt, time_offset, flags_raw))
1556 })
1557 .map_err(Self::map_err)?
1558 .filter_map(|r| r.ok())
1559 .map(
1560 |(dc_id, addr, key_blob, first_salt, time_offset, flags_raw)| {
1561 let auth_key = key_blob.and_then(|b| {
1562 if b.len() == 256 {
1563 let mut k = [0u8; 256];
1564 k.copy_from_slice(&b);
1565 Some(k)
1566 } else {
1567 None
1568 }
1569 });
1570 DcEntry {
1571 dc_id,
1572 addr,
1573 auth_key,
1574 first_salt,
1575 time_offset,
1576 flags: DcFlags(flags_raw),
1577 }
1578 },
1579 )
1580 .collect();
1581
1582 let updates_state = conn
1584 .query_row(
1585 "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
1586 [],
1587 |r| {
1588 Ok(UpdatesStateSnap {
1589 pts: r.get(0)?,
1590 qts: r.get(1)?,
1591 date: r.get(2)?,
1592 seq: r.get(3)?,
1593 channels: vec![],
1594 })
1595 },
1596 )
1597 .unwrap_or_default();
1598
1599 let mut ch_stmt = conn
1601 .prepare("SELECT channel_id, pts FROM channel_pts")
1602 .map_err(Self::map_err)?;
1603 let channels: Vec<(i64, i32)> = ch_stmt
1604 .query_map([], |r| Ok((r.get::<_, i64>(0)?, r.get::<_, i32>(1)?)))
1605 .map_err(Self::map_err)?
1606 .filter_map(|r| r.ok())
1607 .collect();
1608
1609 let mut peer_stmt = conn
1611 .prepare("SELECT id, access_hash, is_channel, is_chat, channel_kind FROM peers")
1612 .map_err(Self::map_err)?;
1613 let peers: Vec<CachedPeer> = peer_stmt
1614 .query_map([], |r| {
1615 let kind_raw: Option<i32> = r.get(4)?;
1616 Ok(CachedPeer {
1617 id: r.get(0)?,
1618 access_hash: r.get(1)?,
1619 is_channel: r.get::<_, i32>(2)? != 0,
1620 is_chat: r.get::<_, i32>(3)? != 0,
1621 channel_kind: kind_raw.and_then(|k| ChannelKind::from_byte(k as u8)),
1622 })
1623 })
1624 .map_err(Self::map_err)?
1625 .filter_map(|r| r.ok())
1626 .collect();
1627
1628 let mut min_stmt = conn
1630 .prepare("SELECT user_id, peer_id, msg_id FROM min_peers")
1631 .map_err(Self::map_err)?;
1632 let min_peers: Vec<CachedMinPeer> = min_stmt
1633 .query_map([], |r| {
1634 Ok(CachedMinPeer {
1635 user_id: r.get(0)?,
1636 peer_id: r.get(1)?,
1637 msg_id: r.get(2)?,
1638 })
1639 })
1640 .map_err(Self::map_err)?
1641 .filter_map(|r| r.ok())
1642 .collect();
1643
1644 Ok(PersistedSession {
1645 home_dc_id,
1646 dcs,
1647 updates_state: UpdatesStateSnap {
1648 channels,
1649 ..updates_state
1650 },
1651 peers,
1652 min_peers,
1653 })
1654 }
1655
1656 fn write_session(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
1658 conn.execute_batch("BEGIN IMMEDIATE")
1659 .map_err(Self::map_err)?;
1660
1661 conn.execute(
1662 "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
1663 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
1664 rusqlite::params![s.home_dc_id],
1665 )
1666 .map_err(Self::map_err)?;
1667
1668 conn.execute("DELETE FROM dcs", []).map_err(Self::map_err)?;
1670 for d in &s.dcs {
1671 conn.execute(
1672 "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
1673 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
1674 rusqlite::params![
1675 d.dc_id,
1676 d.flags.0,
1677 d.addr,
1678 d.auth_key.as_ref().map(|k| k.as_ref()),
1679 d.first_salt,
1680 d.time_offset,
1681 ],
1682 )
1683 .map_err(Self::map_err)?;
1684 }
1685
1686 let us = &s.updates_state;
1690 conn.execute(
1691 "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1, ?1, ?2, ?3, ?4)
1692 ON CONFLICT(id) DO UPDATE SET
1693 pts = MAX(excluded.pts, update_state.pts),
1694 qts = MAX(excluded.qts, update_state.qts),
1695 date = excluded.date,
1696 seq = excluded.seq",
1697 rusqlite::params![us.pts, us.qts, us.date, us.seq],
1698 )
1699 .map_err(Self::map_err)?;
1700
1701 conn.execute("DELETE FROM channel_pts", [])
1702 .map_err(Self::map_err)?;
1703 for &(cid, cpts) in &us.channels {
1704 conn.execute(
1705 "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
1706 rusqlite::params![cid, cpts],
1707 )
1708 .map_err(Self::map_err)?;
1709 }
1710
1711 conn.execute("DELETE FROM peers", [])
1713 .map_err(Self::map_err)?;
1714 for p in &s.peers {
1715 conn.execute(
1716 "INSERT INTO peers (id, access_hash, is_channel, is_chat, channel_kind) VALUES (?1, ?2, ?3, ?4, ?5)",
1717 rusqlite::params![
1718 p.id,
1719 p.access_hash,
1720 p.is_channel as i32,
1721 p.is_chat as i32,
1722 p.channel_kind.map(|k| k.to_byte() as i32),
1723 ],
1724 )
1725 .map_err(Self::map_err)?;
1726 }
1727
1728 conn.execute("DELETE FROM min_peers", [])
1730 .map_err(Self::map_err)?;
1731 for m in &s.min_peers {
1732 conn.execute(
1733 "INSERT INTO min_peers (user_id, peer_id, msg_id) VALUES (?1, ?2, ?3)",
1734 rusqlite::params![m.user_id, m.peer_id, m.msg_id],
1735 )
1736 .map_err(Self::map_err)?;
1737 }
1738
1739 conn.execute_batch("COMMIT").map_err(Self::map_err)
1740 }
1741}
1742
1743#[cfg(feature = "sqlite-session")]
1744impl SessionBackend for SqliteBackend {
1745 fn save(&self, session: &PersistedSession) -> io::Result<()> {
1746 let conn = self.conn.lock().unwrap();
1747 Self::write_session(&conn, session)
1748 }
1749
1750 fn load(&self) -> io::Result<Option<PersistedSession>> {
1751 let conn = self.conn.lock().unwrap();
1752 let count: i64 = conn
1754 .query_row("SELECT COUNT(*) FROM meta", [], |r| r.get(0))
1755 .map_err(Self::map_err)?;
1756 if count == 0 {
1757 return Ok(None);
1758 }
1759 Self::read_session(&conn).map(Some)
1760 }
1761
1762 fn delete(&self) -> io::Result<()> {
1763 let conn = self.conn.lock().unwrap();
1764 conn.execute_batch(
1765 "BEGIN IMMEDIATE;
1766 DELETE FROM meta;
1767 DELETE FROM dcs;
1768 DELETE FROM update_state;
1769 DELETE FROM channel_pts;
1770 DELETE FROM peers;
1771 DELETE FROM min_peers;
1772 COMMIT;",
1773 )
1774 .map_err(Self::map_err)
1775 }
1776
1777 fn name(&self) -> &str {
1778 &self.label
1779 }
1780
1781 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
1784 let conn = self.conn.lock().unwrap();
1785 conn.execute(
1786 "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
1787 VALUES (?1, ?6, ?2, ?3, ?4, ?5)
1788 ON CONFLICT(dc_id, flags) DO UPDATE SET
1789 addr = excluded.addr,
1790 auth_key = excluded.auth_key,
1791 first_salt = excluded.first_salt,
1792 time_offset = excluded.time_offset",
1793 rusqlite::params![
1794 entry.dc_id,
1795 entry.addr,
1796 entry.auth_key.as_ref().map(|k| k.as_ref()),
1797 entry.first_salt,
1798 entry.time_offset,
1799 entry.flags.0,
1800 ],
1801 )
1802 .map(|_| ())
1803 .map_err(Self::map_err)
1804 }
1805
1806 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
1807 let conn = self.conn.lock().unwrap();
1808 conn.execute(
1809 "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
1810 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
1811 rusqlite::params![dc_id],
1812 )
1813 .map(|_| ())
1814 .map_err(Self::map_err)
1815 }
1816
1817 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
1818 let conn = self.conn.lock().unwrap();
1819 match update {
1820 UpdateStateChange::All(snap) => {
1821 conn.execute(
1822 "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
1823 ON CONFLICT(id) DO UPDATE SET
1824 pts=excluded.pts, qts=excluded.qts,
1825 date=excluded.date, seq=excluded.seq",
1826 rusqlite::params![snap.pts, snap.qts, snap.date, snap.seq],
1827 )
1828 .map_err(Self::map_err)?;
1829 conn.execute("DELETE FROM channel_pts", [])
1830 .map_err(Self::map_err)?;
1831 for &(cid, cpts) in &snap.channels {
1832 conn.execute(
1833 "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
1834 rusqlite::params![cid, cpts],
1835 )
1836 .map_err(Self::map_err)?;
1837 }
1838 Ok(())
1839 }
1840 UpdateStateChange::Primary { pts, date, seq } => conn
1841 .execute(
1842 "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,0,?2,?3)
1843 ON CONFLICT(id) DO UPDATE SET pts=excluded.pts, date=excluded.date,
1844 seq=excluded.seq",
1845 rusqlite::params![pts, date, seq],
1846 )
1847 .map(|_| ())
1848 .map_err(Self::map_err),
1849 UpdateStateChange::Secondary { qts } => conn
1850 .execute(
1851 "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,0,?1,0,0)
1852 ON CONFLICT(id) DO UPDATE SET qts = excluded.qts",
1853 rusqlite::params![qts],
1854 )
1855 .map(|_| ())
1856 .map_err(Self::map_err),
1857 UpdateStateChange::Channel { id, pts } => conn
1858 .execute(
1859 "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)
1860 ON CONFLICT(channel_id) DO UPDATE SET pts = excluded.pts",
1861 rusqlite::params![id, pts],
1862 )
1863 .map(|_| ())
1864 .map_err(Self::map_err),
1865 }
1866 }
1867
1868 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
1869 let conn = self.conn.lock().unwrap();
1870 conn.execute(
1871 "INSERT INTO peers (id, access_hash, is_channel, is_chat, channel_kind) VALUES (?1, ?2, ?3, ?4, ?5)
1872 ON CONFLICT(id) DO UPDATE SET
1873 access_hash = excluded.access_hash,
1874 is_channel = excluded.is_channel,
1875 is_chat = excluded.is_chat,
1876 channel_kind = excluded.channel_kind",
1877 rusqlite::params![
1878 peer.id,
1879 peer.access_hash,
1880 peer.is_channel as i32,
1881 peer.is_chat as i32,
1882 peer.channel_kind.map(|k| k.to_byte() as i32),
1883 ],
1884 )
1885 .map(|_| ())
1886 .map_err(Self::map_err)
1887 }
1888}
1889
#[cfg(feature = "libsql-session")]
/// Session backend that stores its data through a `libsql::Connection`.
///
/// Only compiled when the `libsql-session` feature is enabled.
pub struct LibSqlBackend {
    /// Connection used for every query issued by this backend.
    conn: libsql::Connection,
    /// Human-readable identifier of the database; returned by `name()`.
    label: String,
}
1914
1915#[cfg(feature = "libsql-session")]
1916impl LibSqlBackend {
/// DDL batch executed by `apply_schema`.
///
/// Every statement is `CREATE TABLE IF NOT EXISTS`. The tables are `meta`
/// (integer values keyed by text), `dcs` (keyed by `(dc_id, flags)`),
/// `update_state` (a single row, enforced by `CHECK (id = 1)`),
/// `channel_pts`, `peers` and `min_peers`.
const SCHEMA: &'static str = "
    CREATE TABLE IF NOT EXISTS meta (
        key TEXT PRIMARY KEY,
        value INTEGER NOT NULL DEFAULT 0
    );
    CREATE TABLE IF NOT EXISTS dcs (
        dc_id INTEGER NOT NULL,
        flags INTEGER NOT NULL DEFAULT 0,
        addr TEXT NOT NULL,
        auth_key BLOB,
        first_salt INTEGER NOT NULL DEFAULT 0,
        time_offset INTEGER NOT NULL DEFAULT 0,
        PRIMARY KEY (dc_id, flags)
    );
    CREATE TABLE IF NOT EXISTS update_state (
        id INTEGER PRIMARY KEY CHECK (id = 1),
        pts INTEGER NOT NULL DEFAULT 0,
        qts INTEGER NOT NULL DEFAULT 0,
        date INTEGER NOT NULL DEFAULT 0,
        seq INTEGER NOT NULL DEFAULT 0
    );
    CREATE TABLE IF NOT EXISTS channel_pts (
        channel_id INTEGER PRIMARY KEY,
        pts INTEGER NOT NULL
    );
    CREATE TABLE IF NOT EXISTS peers (
        id INTEGER PRIMARY KEY,
        access_hash INTEGER NOT NULL,
        is_channel INTEGER NOT NULL DEFAULT 0,
        is_chat INTEGER NOT NULL DEFAULT 0,
        channel_kind INTEGER
    );
    CREATE TABLE IF NOT EXISTS min_peers (
        user_id INTEGER PRIMARY KEY,
        peer_id INTEGER NOT NULL,
        msg_id INTEGER NOT NULL
    );
";
1955
1956 fn block<F, T>(fut: F) -> io::Result<T>
1957 where
1958 F: std::future::Future<Output = Result<T, libsql::Error>>,
1959 {
1960 tokio::runtime::Handle::current()
1961 .block_on(fut)
1962 .map_err(io::Error::other)
1963 }
1964
1965 async fn apply_schema(conn: &libsql::Connection) -> Result<(), libsql::Error> {
1966 conn.execute_batch(Self::SCHEMA).await?;
1967 let _ = conn
1971 .execute_batch("ALTER TABLE peers ADD COLUMN channel_kind INTEGER;")
1972 .await;
1973 Ok(())
1974 }
1975
1976 pub fn open_local(path: impl Into<PathBuf>) -> io::Result<Self> {
1978 let path = path.into();
1979 let label = path.display().to_string();
1980 let db = Self::block(async { libsql::Builder::new_local(path).build().await })?;
1981 let conn = Self::block(async { db.connect() }).map_err(io::Error::other)?;
1982 Self::block(Self::apply_schema(&conn))?;
1983 Ok(Self { conn, label })
1984 }
1985
1986 pub fn in_memory() -> io::Result<Self> {
1988 let db = Self::block(async { libsql::Builder::new_local(":memory:").build().await })?;
1989 let conn = Self::block(async { db.connect() }).map_err(io::Error::other)?;
1990 Self::block(Self::apply_schema(&conn))?;
1991 Ok(Self {
1992 conn,
1993 label: ":memory:".into(),
1994 })
1995 }
1996
1997 pub fn open_remote(url: impl Into<String>, auth_token: impl Into<String>) -> io::Result<Self> {
1999 let url = url.into();
2000 let label = url.clone();
2001 let db = Self::block(async {
2002 libsql::Builder::new_remote(url, auth_token.into())
2003 .build()
2004 .await
2005 })?;
2006 let conn = Self::block(async { db.connect() }).map_err(io::Error::other)?;
2007 Self::block(Self::apply_schema(&conn))?;
2008 Ok(Self { conn, label })
2009 }
2010
2011 pub fn open_replica(
2013 path: impl Into<PathBuf>,
2014 url: impl Into<String>,
2015 auth_token: impl Into<String>,
2016 ) -> io::Result<Self> {
2017 let path = path.into();
2018 let url = url.into();
2019 let auth_token = auth_token.into();
2020 let label = format!("{} (replica of {})", path.display(), url);
2021 let db = Self::block(async {
2022 libsql::Builder::new_remote_replica(path, url, auth_token)
2023 .build()
2024 .await
2025 })?;
2026 let conn = Self::block(async { db.connect() }).map_err(io::Error::other)?;
2027 Self::block(Self::apply_schema(&conn))?;
2028 Ok(Self { conn, label })
2029 }
2030
2031 async fn read_session_async(
2032 conn: &libsql::Connection,
2033 ) -> Result<PersistedSession, libsql::Error> {
2034 let home_dc_id: i32 = conn
2036 .query("SELECT value FROM meta WHERE key = 'home_dc_id'", ())
2037 .await?
2038 .next()
2039 .await?
2040 .map(|r| r.get::<i32>(0))
2041 .transpose()?
2042 .unwrap_or(0);
2043
2044 let mut rows = conn
2046 .query(
2047 "SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs",
2048 (),
2049 )
2050 .await?;
2051 let mut dcs = Vec::new();
2052 while let Some(row) = rows.next().await? {
2053 let dc_id: i32 = row.get(0)?;
2054 let flags_raw: u8 = row.get::<i64>(1)? as u8;
2055 let addr: String = row.get(2)?;
2056 let key_blob: Option<Vec<u8>> = row.get(3)?;
2057 let first_salt: i64 = row.get(4)?;
2058 let time_offset: i32 = row.get(5)?;
2059 let auth_key = match key_blob {
2060 Some(b) if b.len() == 256 => {
2061 let mut k = [0u8; 256];
2062 k.copy_from_slice(&b);
2063 Some(k)
2064 }
2065 Some(b) => {
2066 return Err(libsql::Error::Misuse(format!(
2067 "auth_key blob must be 256 bytes, got {}",
2068 b.len()
2069 )));
2070 }
2071 None => None,
2072 };
2073 dcs.push(DcEntry {
2074 dc_id,
2075 addr,
2076 auth_key,
2077 first_salt,
2078 time_offset,
2079 flags: DcFlags(flags_raw),
2080 });
2081 }
2082
2083 let mut us_row = conn
2085 .query(
2086 "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
2087 (),
2088 )
2089 .await?;
2090 let updates_state = if let Some(r) = us_row.next().await? {
2091 UpdatesStateSnap {
2092 pts: r.get(0)?,
2093 qts: r.get(1)?,
2094 date: r.get(2)?,
2095 seq: r.get(3)?,
2096 channels: vec![],
2097 }
2098 } else {
2099 UpdatesStateSnap::default()
2100 };
2101
2102 let mut ch_rows = conn
2104 .query("SELECT channel_id, pts FROM channel_pts", ())
2105 .await?;
2106 let mut channels = Vec::new();
2107 while let Some(r) = ch_rows.next().await? {
2108 channels.push((r.get::<i64>(0)?, r.get::<i32>(1)?));
2109 }
2110
2111 let mut peer_rows = conn
2113 .query(
2114 "SELECT id, access_hash, is_channel, is_chat, channel_kind FROM peers",
2115 (),
2116 )
2117 .await?;
2118 let mut peers = Vec::new();
2119 while let Some(r) = peer_rows.next().await? {
2120 let kind_raw: Option<i32> = r.get(4).ok();
2121 peers.push(CachedPeer {
2122 id: r.get(0)?,
2123 access_hash: r.get(1)?,
2124 is_channel: r.get::<i32>(2)? != 0,
2125 is_chat: r.get::<i32>(3)? != 0,
2126 channel_kind: kind_raw.and_then(|k| ChannelKind::from_byte(k as u8)),
2127 });
2128 }
2129
2130 let mut min_rows = conn
2132 .query("SELECT user_id, peer_id, msg_id FROM min_peers", ())
2133 .await?;
2134 let mut min_peers = Vec::new();
2135 while let Some(r) = min_rows.next().await? {
2136 min_peers.push(CachedMinPeer {
2137 user_id: r.get(0)?,
2138 peer_id: r.get(1)?,
2139 msg_id: r.get(2)?,
2140 });
2141 }
2142
2143 Ok(PersistedSession {
2144 home_dc_id,
2145 dcs,
2146 updates_state: UpdatesStateSnap {
2147 channels,
2148 ..updates_state
2149 },
2150 peers,
2151 min_peers,
2152 })
2153 }
2154
2155 async fn write_session_async(
2156 conn: &libsql::Connection,
2157 s: &PersistedSession,
2158 ) -> Result<(), libsql::Error> {
2159 conn.execute_batch("BEGIN IMMEDIATE").await.map(|_| ())?;
2160
2161 conn.execute(
2162 "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
2163 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
2164 libsql::params![s.home_dc_id],
2165 )
2166 .await?;
2167
2168 conn.execute("DELETE FROM dcs", ()).await?;
2169 for d in &s.dcs {
2170 conn.execute(
2171 "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
2172 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
2173 libsql::params![
2174 d.dc_id,
2175 d.flags.0 as i64,
2176 d.addr.clone(),
2177 d.auth_key.map(|k| k.to_vec()),
2178 d.first_salt,
2179 d.time_offset,
2180 ],
2181 )
2182 .await?;
2183 }
2184
2185 let us = &s.updates_state;
2186 conn.execute(
2187 "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
2188 ON CONFLICT(id) DO UPDATE SET
2189 pts = MAX(excluded.pts, update_state.pts),
2190 qts = MAX(excluded.qts, update_state.qts),
2191 date = excluded.date,
2192 seq = excluded.seq",
2193 libsql::params![us.pts, us.qts, us.date, us.seq],
2194 )
2195 .await?;
2196
2197 conn.execute("DELETE FROM channel_pts", ()).await?;
2198 for &(cid, cpts) in &us.channels {
2199 conn.execute(
2200 "INSERT INTO channel_pts (channel_id, pts) VALUES (?1,?2)",
2201 libsql::params![cid, cpts],
2202 )
2203 .await?;
2204 }
2205
2206 conn.execute("DELETE FROM peers", ()).await?;
2207 for p in &s.peers {
2208 conn.execute(
2209 "INSERT INTO peers (id, access_hash, is_channel, is_chat, channel_kind) VALUES (?1,?2,?3,?4,?5)",
2210 libsql::params![
2211 p.id,
2212 p.access_hash,
2213 p.is_channel as i32,
2214 p.is_chat as i32,
2215 p.channel_kind.map(|k| k.to_byte() as i32),
2216 ],
2217 )
2218 .await?;
2219 }
2220
2221 conn.execute("DELETE FROM min_peers", ()).await?;
2222 for m in &s.min_peers {
2223 conn.execute(
2224 "INSERT INTO min_peers (user_id, peer_id, msg_id) VALUES (?1,?2,?3)",
2225 libsql::params![m.user_id, m.peer_id, m.msg_id],
2226 )
2227 .await?;
2228 }
2229
2230 conn.execute_batch("COMMIT").await.map(|_| ())
2231 }
2232}
2233
2234#[cfg(feature = "libsql-session")]
2235impl SessionBackend for LibSqlBackend {
2236 fn save(&self, session: &PersistedSession) -> io::Result<()> {
2237 let conn = self.conn.clone();
2238 let session = session.clone();
2239 Self::block(async move { Self::write_session_async(&conn, &session).await })
2240 }
2241
2242 fn load(&self) -> io::Result<Option<PersistedSession>> {
2243 let conn = self.conn.clone();
2244 let count: i64 = Self::block(async move {
2245 let mut rows = conn.query("SELECT COUNT(*) FROM meta", ()).await?;
2246 Ok::<i64, libsql::Error>(rows.next().await?.and_then(|r| r.get(0).ok()).unwrap_or(0))
2247 })?;
2248 if count == 0 {
2249 return Ok(None);
2250 }
2251 let conn = self.conn.clone();
2252 Self::block(async move { Self::read_session_async(&conn).await }).map(Some)
2253 }
2254
2255 fn delete(&self) -> io::Result<()> {
2256 let conn = self.conn.clone();
2257 Self::block(async move {
2258 conn.execute_batch(
2259 "BEGIN IMMEDIATE;
2260 DELETE FROM meta;
2261 DELETE FROM dcs;
2262 DELETE FROM update_state;
2263 DELETE FROM channel_pts;
2264 DELETE FROM peers;
2265 DELETE FROM min_peers;
2266 COMMIT;",
2267 )
2268 .await
2269 .map(|_| ())
2270 })
2271 }
2272
2273 fn name(&self) -> &str {
2274 &self.label
2275 }
2276
2277 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
2280 let conn = self.conn.clone();
2281 let (dc_id, addr, key, salt, off, flags) = (
2282 entry.dc_id,
2283 entry.addr.clone(),
2284 entry.auth_key.map(|k| k.to_vec()),
2285 entry.first_salt,
2286 entry.time_offset,
2287 entry.flags.0 as i64,
2288 );
2289 Self::block(async move {
2290 conn.execute(
2291 "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
2292 VALUES (?1,?6,?2,?3,?4,?5)
2293 ON CONFLICT(dc_id, flags) DO UPDATE SET
2294 addr=excluded.addr, auth_key=excluded.auth_key,
2295 first_salt=excluded.first_salt, time_offset=excluded.time_offset",
2296 libsql::params![dc_id, addr, key, salt, off, flags],
2297 )
2298 .await
2299 .map(|_| ())
2300 })
2301 }
2302
2303 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
2304 let conn = self.conn.clone();
2305 Self::block(async move {
2306 conn.execute(
2307 "INSERT INTO meta (key, value) VALUES ('home_dc_id',?1)
2308 ON CONFLICT(key) DO UPDATE SET value=excluded.value",
2309 libsql::params![dc_id],
2310 )
2311 .await
2312 .map(|_| ())
2313 })
2314 }
2315
2316 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
2317 let conn = self.conn.clone();
2318 Self::block(async move {
2319 match update {
2320 UpdateStateChange::All(snap) => {
2321 conn.execute(
2322 "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,?2,?3,?4)
2323 ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,qts=excluded.qts,
2324 date=excluded.date,seq=excluded.seq",
2325 libsql::params![snap.pts, snap.qts, snap.date, snap.seq],
2326 )
2327 .await?;
2328 conn.execute("DELETE FROM channel_pts", ()).await?;
2329 for &(cid, cpts) in &snap.channels {
2330 conn.execute(
2331 "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)",
2332 libsql::params![cid, cpts],
2333 )
2334 .await?;
2335 }
2336 Ok(())
2337 }
2338 UpdateStateChange::Primary { pts, date, seq } => conn
2339 .execute(
2340 "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,0,?2,?3)
2341 ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,date=excluded.date,
2342 seq=excluded.seq",
2343 libsql::params![pts, date, seq],
2344 )
2345 .await
2346 .map(|_| ()),
2347 UpdateStateChange::Secondary { qts } => conn
2348 .execute(
2349 "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,0,?1,0,0)
2350 ON CONFLICT(id) DO UPDATE SET qts=excluded.qts",
2351 libsql::params![qts],
2352 )
2353 .await
2354 .map(|_| ()),
2355 UpdateStateChange::Channel { id, pts } => conn
2356 .execute(
2357 "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)
2358 ON CONFLICT(channel_id) DO UPDATE SET pts=excluded.pts",
2359 libsql::params![id, pts],
2360 )
2361 .await
2362 .map(|_| ()),
2363 }
2364 })
2365 }
2366
2367 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
2368 let conn = self.conn.clone();
2369 let (id, hash, is_ch, is_ct, kind) = (
2370 peer.id,
2371 peer.access_hash,
2372 peer.is_channel as i32,
2373 peer.is_chat as i32,
2374 peer.channel_kind.map(|k| k.to_byte() as i32),
2375 );
2376 Self::block(async move {
2377 conn.execute(
2378 "INSERT INTO peers (id,access_hash,is_channel,is_chat,channel_kind) VALUES (?1,?2,?3,?4,?5)
2379 ON CONFLICT(id) DO UPDATE SET
2380 access_hash=excluded.access_hash,
2381 is_channel=excluded.is_channel,
2382 is_chat=excluded.is_chat,
2383 channel_kind=excluded.channel_kind",
2384 libsql::params![id, hash, is_ch, is_ct, kind],
2385 )
2386 .await
2387 .map(|_| ())
2388 })
2389 }
2390}