1use std::collections::HashMap;
30use std::io::{self, ErrorKind};
31use std::path::Path;
32
#[cfg(feature = "serde")]
mod auth_key_serde {
    //! Serde adapter for `Option<[u8; 256]>` auth keys.
    //!
    //! The key is written as an optional byte sequence and its length is
    //! re-validated when it is read back.
    use serde::{Deserialize, Deserializer, Serializer};

    /// Writes `Some(key)` as `some(<256 bytes>)` and `None` as serde `none`.
    pub fn serialize<S>(value: &Option<[u8; 256]>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(key) = value {
            s.serialize_some(&key[..])
        } else {
            s.serialize_none()
        }
    }

    /// Reads an optional byte sequence and requires it to be exactly 256 bytes.
    ///
    /// # Errors
    /// A custom deserialization error when a present key has any other length.
    pub fn deserialize<'de, D>(d: D) -> Result<Option<[u8; 256]>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw: Option<Vec<u8>> = Option::deserialize(d)?;
        raw.map(|bytes| -> Result<[u8; 256], D::Error> {
            bytes
                .try_into()
                .map_err(|_| serde::de::Error::custom("auth_key must be exactly 256 bytes"))
        })
        .transpose()
    }
}
63
/// Bit set describing properties of a DC endpoint.
///
/// The wrapped `u8` is the raw bit pattern; it is written unchanged by both
/// the binary session format and the SQLite backend.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// No bits set (also the `Default`).
    pub const NONE: DcFlags = DcFlags(0);
    /// The endpoint address is IPv6 (checked by `DcEntry::is_ipv6`).
    pub const IPV6: DcFlags = DcFlags(0b0000_0001);
    pub const MEDIA_ONLY: DcFlags = DcFlags(0b0000_0010);
    pub const TCPO_ONLY: DcFlags = DcFlags(0b0000_0100);
    pub const CDN: DcFlags = DcFlags(0b0000_1000);
    pub const STATIC: DcFlags = DcFlags(0b0001_0000);

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// Consequently `x.contains(DcFlags::NONE)` is always `true`.
    pub fn contains(self, other: DcFlags) -> bool {
        (self.0 & other.0) == other.0
    }

    /// Turns on every bit of `flag`; bits that are already set stay set.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = DcFlags;

    /// Union of the two bit sets.
    fn bitor(self, rhs: DcFlags) -> DcFlags {
        let DcFlags(lhs) = self;
        DcFlags(lhs | rhs.0)
    }
}
94
95#[derive(Clone, Debug)]
97#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
98pub struct DcEntry {
99 pub dc_id: i32,
100 pub addr: String,
101 #[cfg_attr(feature = "serde", serde(with = "auth_key_serde"))]
102 pub auth_key: Option<[u8; 256]>,
103 pub first_salt: i64,
104 pub time_offset: i32,
105 pub flags: DcFlags,
107}
108
109impl DcEntry {
110 #[inline]
112 pub fn is_ipv6(&self) -> bool {
113 self.flags.contains(DcFlags::IPV6)
114 }
115
116 pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
123 self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
124 io::Error::new(
125 io::ErrorKind::InvalidData,
126 format!("invalid DC address: {:?}", self.addr),
127 )
128 })
129 }
130
131 pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
144 let addr = if ip.contains(':') {
146 format!("[{ip}]:{port}")
147 } else {
148 format!("{ip}:{port}")
149 };
150 Self {
151 dc_id,
152 addr,
153 auth_key: None,
154 first_salt: 0,
155 time_offset: 0,
156 flags,
157 }
158 }
159}
160
/// Snapshot of the update-sequence counters plus per-channel pts values.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    pub pts: i32,
    pub qts: i32,
    pub date: i32,
    pub seq: i32,
    /// `(channel_id, pts)` pairs. When maintained via [`Self::set_channel_pts`]
    /// each channel id appears at most once.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// `true` once a strictly positive `pts` has been recorded.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        self.pts > 0
    }

    /// Stores `pts` for `channel_id`. The first matching entry is overwritten;
    /// if there is none, a new entry is appended.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self.channels.iter().position(|&(id, _)| id == channel_id) {
            Some(idx) => self.channels[idx].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Pts of the first entry for `channel_id`, or `0` for an unknown channel.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        for &(id, pts) in &self.channels {
            if id == channel_id {
                return pts;
            }
        }
        0
    }
}
203
/// A peer id together with its access hash, as kept in the session's peer cache.
///
/// The binary session format stores the two booleans as a single type byte
/// (`2` = chat, `1` = channel, `0` = neither), with `is_chat` taking
/// precedence when both are set.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    pub id: i64,
    pub access_hash: i64,
    pub is_channel: bool,
    pub is_chat: bool,
}
220
/// A `(user_id, peer_id, msg_id)` triple kept in the session.
///
/// NOTE(review): presumably this records the message in `peer_id` through
/// which `user_id` was seen, so the user can be referenced later; confirm
/// against the caller.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedMinPeer {
    pub user_id: i64,
    pub peer_id: i64,
    pub msg_id: i32,
}
234
235#[derive(Clone, Debug, Default)]
237#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
238pub struct PersistedSession {
239 pub home_dc_id: i32,
240 pub dcs: Vec<DcEntry>,
241 pub updates_state: UpdatesStateSnap,
243 pub peers: Vec<CachedPeer>,
246 pub min_peers: Vec<CachedMinPeer>,
249}
250
251impl PersistedSession {
252 pub fn to_bytes(&self) -> Vec<u8> {
254 let mut b = Vec::with_capacity(512);
255
256 b.push(0x05u8); b.extend_from_slice(&self.home_dc_id.to_le_bytes());
259
260 b.push(self.dcs.len() as u8);
261 for d in &self.dcs {
262 b.extend_from_slice(&d.dc_id.to_le_bytes());
263 match &d.auth_key {
264 Some(k) => {
265 b.push(1);
266 b.extend_from_slice(k);
267 }
268 None => {
269 b.push(0);
270 }
271 }
272 b.extend_from_slice(&d.first_salt.to_le_bytes());
273 b.extend_from_slice(&d.time_offset.to_le_bytes());
274 let ab = d.addr.as_bytes();
275 b.push(ab.len() as u8);
276 b.extend_from_slice(ab);
277 b.push(d.flags.0);
278 }
279
280 b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
281 b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
282 b.extend_from_slice(&self.updates_state.date.to_le_bytes());
283 b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
284 let ch = &self.updates_state.channels;
285 b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
286 for &(cid, cpts) in ch {
287 b.extend_from_slice(&cid.to_le_bytes());
288 b.extend_from_slice(&cpts.to_le_bytes());
289 }
290
291 b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
293 for p in &self.peers {
294 b.extend_from_slice(&p.id.to_le_bytes());
295 b.extend_from_slice(&p.access_hash.to_le_bytes());
296 let peer_type: u8 = if p.is_chat {
297 2
298 } else if p.is_channel {
299 1
300 } else {
301 0
302 };
303 b.push(peer_type);
304 }
305
306 b.extend_from_slice(&(self.min_peers.len() as u16).to_le_bytes());
307 for m in &self.min_peers {
308 b.extend_from_slice(&m.user_id.to_le_bytes());
309 b.extend_from_slice(&m.peer_id.to_le_bytes());
310 b.extend_from_slice(&m.msg_id.to_le_bytes());
311 }
312
313 b
314 }
315
316 pub fn save(&self, path: &Path) -> io::Result<()> {
323 use std::sync::atomic::{AtomicU64, Ordering};
324 static SEQ: AtomicU64 = AtomicU64::new(0);
325 let n = SEQ.fetch_add(1, Ordering::Relaxed);
326 let tmp = path.with_extension(format!("{n}.tmp"));
327 std::fs::write(&tmp, self.to_bytes())?;
328 std::fs::rename(&tmp, path).inspect_err(|_e| {
329 let _ = std::fs::remove_file(&tmp);
330 })
331 }
332
333 pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
335 if buf.is_empty() {
336 return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
337 }
338
339 let mut p = 0usize;
340
341 macro_rules! r {
342 ($n:expr) => {{
343 if p + $n > buf.len() {
344 return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
345 }
346 let s = &buf[p..p + $n];
347 p += $n;
348 s
349 }};
350 }
351 macro_rules! r_i32 {
352 () => {
353 i32::from_le_bytes(r!(4).try_into().unwrap())
354 };
355 }
356 macro_rules! r_i64 {
357 () => {
358 i64::from_le_bytes(r!(8).try_into().unwrap())
359 };
360 }
361 macro_rules! r_u8 {
362 () => {
363 r!(1)[0]
364 };
365 }
366 macro_rules! r_u16 {
367 () => {
368 u16::from_le_bytes(r!(2).try_into().unwrap())
369 };
370 }
371
372 let first_byte = r_u8!();
373
374 let (home_dc_id, version) = if first_byte == 0x05 {
375 (r_i32!(), 5u8)
376 } else if first_byte == 0x04 {
377 (r_i32!(), 4u8)
378 } else if first_byte == 0x03 {
379 (r_i32!(), 3u8)
380 } else if first_byte == 0x02 {
381 (r_i32!(), 2u8)
382 } else {
383 let rest = r!(3);
384 let mut bytes = [0u8; 4];
385 bytes[0] = first_byte;
386 bytes[1..4].copy_from_slice(rest);
387 (i32::from_le_bytes(bytes), 1u8)
388 };
389
390 let dc_count = r_u8!() as usize;
391 let mut dcs = Vec::with_capacity(dc_count);
392 for _ in 0..dc_count {
393 let dc_id = r_i32!();
394 let has_key = r_u8!();
395 let auth_key = if has_key == 1 {
396 let mut k = [0u8; 256];
397 k.copy_from_slice(r!(256));
398 Some(k)
399 } else {
400 None
401 };
402 let first_salt = r_i64!();
403 let time_offset = r_i32!();
404 let al = r_u8!() as usize;
405 let addr = String::from_utf8_lossy(r!(al)).into_owned();
406 let flags = if version >= 3 {
407 DcFlags(r_u8!())
408 } else {
409 DcFlags::NONE
410 };
411 dcs.push(DcEntry {
412 dc_id,
413 addr,
414 auth_key,
415 first_salt,
416 time_offset,
417 flags,
418 });
419 }
420
421 if version < 2 {
422 return Ok(Self {
423 home_dc_id,
424 dcs,
425 updates_state: UpdatesStateSnap::default(),
426 peers: Vec::new(),
427 min_peers: Vec::new(),
428 });
429 }
430
431 let pts = r_i32!();
432 let qts = r_i32!();
433 let date = r_i32!();
434 let seq = r_i32!();
435 let ch_count = r_u16!() as usize;
436 let mut channels = Vec::with_capacity(ch_count);
437 for _ in 0..ch_count {
438 let cid = r_i64!();
439 let cpts = r_i32!();
440 channels.push((cid, cpts));
441 }
442
443 let peer_count = r_u16!() as usize;
444 let mut peers = Vec::with_capacity(peer_count);
445 for _ in 0..peer_count {
446 let id = r_i64!();
447 let access_hash = r_i64!();
448 let peer_type = r_u8!();
450 let is_channel = peer_type == 1;
451 let is_chat = peer_type == 2;
452 peers.push(CachedPeer {
453 id,
454 access_hash,
455 is_channel,
456 is_chat,
457 });
458 }
459
460 let min_peers = if version >= 4 {
462 let count = r_u16!() as usize;
463 let mut v = Vec::with_capacity(count);
464 for _ in 0..count {
465 let user_id = r_i64!();
466 let peer_id = r_i64!();
467 let msg_id = r_i32!();
468 v.push(CachedMinPeer {
469 user_id,
470 peer_id,
471 msg_id,
472 });
473 }
474 v
475 } else {
476 Vec::new()
477 };
478
479 Ok(Self {
480 home_dc_id,
481 dcs,
482 updates_state: UpdatesStateSnap {
483 pts,
484 qts,
485 date,
486 seq,
487 channels,
488 },
489 peers,
490 min_peers,
491 })
492 }
493
494 pub fn from_string(s: &str) -> io::Result<Self> {
496 use base64::Engine as _;
497 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
498 .decode(s.trim())
499 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
500 Self::from_bytes(&bytes)
501 }
502
503 pub fn load(path: &Path) -> io::Result<Self> {
504 let buf = std::fs::read(path)?;
505 Self::from_bytes(&buf)
506 }
507
508 pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
520 let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
521 candidates.peek()?;
522 let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
524 cands
526 .iter()
527 .copied()
528 .find(|d| d.is_ipv6() == prefer_ipv6)
529 .or_else(|| cands.first().copied())
530 }
531
532 pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
537 self.dcs.iter().filter(move |d| d.dc_id == dc_id)
538 }
539}
540
541impl std::fmt::Display for PersistedSession {
542 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
543 use base64::Engine as _;
544 f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
545 }
546}
547
/// Built-in `ip:port` addresses for DCs 1 through 5, keyed by DC id.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const DEFAULTS: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut map = HashMap::with_capacity(DEFAULTS.len());
    for (id, addr) in DEFAULTS {
        map.insert(id, String::from(addr));
    }
    map
}
561
562use std::path::PathBuf;
567
568pub trait SessionBackend: Send + Sync {
576 fn save(&self, session: &PersistedSession) -> io::Result<()>;
577 fn load(&self) -> io::Result<Option<PersistedSession>>;
578 fn delete(&self) -> io::Result<()>;
579
580 fn name(&self) -> &str;
582
583 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
597 let mut s = self.load()?.unwrap_or_default();
598 if let Some(existing) = s.dcs.iter_mut().find(|d| d.dc_id == entry.dc_id) {
600 *existing = entry.clone();
601 } else {
602 s.dcs.push(entry.clone());
603 }
604 self.save(&s)
605 }
606
607 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
614 let mut s = self.load()?.unwrap_or_default();
615 s.home_dc_id = dc_id;
616 self.save(&s)
617 }
618
619 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
625 let mut s = self.load()?.unwrap_or_default();
626 update.apply_to(&mut s.updates_state);
627 self.save(&s)
628 }
629
630 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
637 let mut s = self.load()?.unwrap_or_default();
638 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
639 *existing = peer.clone();
640 } else {
641 s.peers.push(peer.clone());
642 }
643 self.save(&s)
644 }
645}
646
647#[derive(Debug, Clone)]
661pub enum UpdateStateChange {
662 All(UpdatesStateSnap),
664 Primary { pts: i32, date: i32, seq: i32 },
666 Secondary { qts: i32 },
668 Channel { id: i64, pts: i32 },
670}
671
672impl UpdateStateChange {
673 pub fn apply_to(&self, snap: &mut UpdatesStateSnap) {
675 match self {
676 Self::All(new_snap) => *snap = new_snap.clone(),
677 Self::Primary { pts, date, seq } => {
678 snap.pts = *pts;
679 snap.date = *date;
680 snap.seq = *seq;
681 }
682 Self::Secondary { qts } => {
683 snap.qts = *qts;
684 }
685 Self::Channel { id, pts } => {
686 if let Some(existing) = snap.channels.iter_mut().find(|c| c.0 == *id) {
688 existing.1 = *pts;
689 } else {
690 snap.channels.push((*id, *pts));
691 }
692 }
693 }
694 }
695}
696
/// [`SessionBackend`] that stores the session in one file, using the binary
/// format of `PersistedSession::to_bytes`.
pub struct BinaryFileBackend {
    path: PathBuf,
    /// Serialises `save` calls made through this backend instance.
    write_lock: std::sync::Mutex<()>,
}

impl BinaryFileBackend {
    /// Creates a backend for `path`. Nothing is read or written until the
    /// first `load`/`save`.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path: PathBuf = path.into();
        Self {
            write_lock: std::sync::Mutex::default(),
            path,
        }
    }

    /// Location of the session file.
    pub fn path(&self) -> &std::path::Path {
        self.path.as_path()
    }
}
721
722impl SessionBackend for BinaryFileBackend {
723 fn save(&self, session: &PersistedSession) -> io::Result<()> {
724 let _guard = self.write_lock.lock().unwrap();
725 session.save(&self.path)
726 }
727
728 fn load(&self) -> io::Result<Option<PersistedSession>> {
729 if !self.path.exists() {
730 return Ok(None);
731 }
732 match PersistedSession::load(&self.path) {
733 Ok(s) => Ok(Some(s)),
734 Err(e) => {
735 let bak = self.path.with_extension("bak");
736 tracing::warn!(
737 "[ferogram] Session file {:?} is corrupt ({e}); \
738 renaming to {:?} and starting fresh",
739 self.path,
740 bak
741 );
742 let _ = std::fs::rename(&self.path, &bak);
743 Ok(None)
744 }
745 }
746 }
747
748 fn delete(&self) -> io::Result<()> {
749 if self.path.exists() {
750 std::fs::remove_file(&self.path)?;
751 }
752 Ok(())
753 }
754
755 fn name(&self) -> &str {
756 "binary-file"
757 }
758
759 }
762
763#[derive(Default)]
771pub struct InMemoryBackend {
772 data: std::sync::Mutex<Option<PersistedSession>>,
773}
774
775impl InMemoryBackend {
776 pub fn new() -> Self {
777 Self::default()
778 }
779
780 pub fn snapshot(&self) -> Option<PersistedSession> {
782 self.data.lock().unwrap().clone()
783 }
784}
785
786impl SessionBackend for InMemoryBackend {
787 fn save(&self, s: &PersistedSession) -> io::Result<()> {
788 *self.data.lock().unwrap() = Some(s.clone());
789 Ok(())
790 }
791
792 fn load(&self) -> io::Result<Option<PersistedSession>> {
793 Ok(self.data.lock().unwrap().clone())
794 }
795
796 fn delete(&self) -> io::Result<()> {
797 *self.data.lock().unwrap() = None;
798 Ok(())
799 }
800
801 fn name(&self) -> &str {
802 "in-memory"
803 }
804
805 fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
808 let mut guard = self.data.lock().unwrap();
809 let s = guard.get_or_insert_with(PersistedSession::default);
810 if let Some(existing) = s.dcs.iter_mut().find(|d| d.dc_id == entry.dc_id) {
811 *existing = entry.clone();
812 } else {
813 s.dcs.push(entry.clone());
814 }
815 Ok(())
816 }
817
818 fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
819 let mut guard = self.data.lock().unwrap();
820 let s = guard.get_or_insert_with(PersistedSession::default);
821 s.home_dc_id = dc_id;
822 Ok(())
823 }
824
825 fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
826 let mut guard = self.data.lock().unwrap();
827 let s = guard.get_or_insert_with(PersistedSession::default);
828 update.apply_to(&mut s.updates_state);
829 Ok(())
830 }
831
832 fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
833 let mut guard = self.data.lock().unwrap();
834 let s = guard.get_or_insert_with(PersistedSession::default);
835 if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
836 *existing = peer.clone();
837 } else {
838 s.peers.push(peer.clone());
839 }
840 Ok(())
841 }
842}
843
/// [`SessionBackend`] that holds the session as the base64 text produced by
/// `PersistedSession`'s `Display` impl.
pub struct StringSessionBackend {
    data: std::sync::Mutex<String>,
}

impl StringSessionBackend {
    /// Wraps an existing session string. A blank string means "no session".
    pub fn new(s: impl Into<String>) -> Self {
        let initial: String = s.into();
        Self {
            data: std::sync::Mutex::new(initial),
        }
    }

    /// Returns a copy of the current session string (rewritten on every `save`).
    pub fn current(&self) -> String {
        let guard = self.data.lock().unwrap();
        guard.as_str().to_owned()
    }
}
862
863impl SessionBackend for StringSessionBackend {
864 fn save(&self, session: &PersistedSession) -> io::Result<()> {
865 *self.data.lock().unwrap() = session.to_string();
866 Ok(())
867 }
868
869 fn load(&self) -> io::Result<Option<PersistedSession>> {
870 let s = self.data.lock().unwrap().clone();
871 if s.trim().is_empty() {
872 return Ok(None);
873 }
874 PersistedSession::from_string(&s).map(Some)
875 }
876
877 fn delete(&self) -> io::Result<()> {
878 *self.data.lock().unwrap() = String::new();
879 Ok(())
880 }
881
882 fn name(&self) -> &str {
883 "string-session"
884 }
885}
886
#[cfg(test)]
mod tests {
    use super::*;

    /// IPv4 DC entry `1.2.3.<id>:443` with no key, zero salt/offset and no flags.
    fn make_dc(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("1.2.3.{id}:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::NONE,
        }
    }

    /// IPv6-flagged counterpart of [`make_dc`] for the same DC id.
    fn make_dc_v6(id: i32) -> DcEntry {
        DcEntry {
            addr: format!("[2001:b28:f23d:f00{id}::a]:443"),
            flags: DcFlags::IPV6,
            ..make_dc(id)
        }
    }

    /// Plain (neither channel nor chat) peer.
    fn make_peer(id: i64, hash: i64) -> CachedPeer {
        CachedPeer {
            id,
            access_hash: hash,
            is_channel: false,
            is_chat: false,
        }
    }

    // ---- InMemoryBackend ----

    #[test]
    fn inmemory_load_returns_none_when_empty() {
        assert!(InMemoryBackend::new().load().unwrap().is_none());
    }

    #[test]
    fn inmemory_save_then_load_round_trips() {
        let backend = InMemoryBackend::new();
        let session = PersistedSession {
            home_dc_id: 3,
            dcs: vec![make_dc(3)],
            ..Default::default()
        };
        backend.save(&session).unwrap();

        let loaded = backend.load().unwrap().expect("session was just saved");
        assert_eq!(loaded.home_dc_id, 3);
        assert_eq!(loaded.dcs.len(), 1);
    }

    #[test]
    fn inmemory_delete_clears_state() {
        let backend = InMemoryBackend::new();
        let session = PersistedSession {
            home_dc_id: 2,
            ..Default::default()
        };
        backend.save(&session).unwrap();
        backend.delete().unwrap();
        assert!(backend.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_update_dc_inserts_new() {
        let backend = InMemoryBackend::new();
        backend.update_dc(&make_dc(4)).unwrap();
        let ids: Vec<i32> = backend
            .snapshot()
            .unwrap()
            .dcs
            .iter()
            .map(|d| d.dc_id)
            .collect();
        assert_eq!(ids, vec![4]);
    }

    #[test]
    fn inmemory_update_dc_replaces_existing() {
        let backend = InMemoryBackend::new();
        backend.update_dc(&make_dc(2)).unwrap();
        let updated = DcEntry {
            addr: "9.9.9.9:443".to_string(),
            ..make_dc(2)
        };
        backend.update_dc(&updated).unwrap();

        let dcs = backend.snapshot().unwrap().dcs;
        assert_eq!(dcs.len(), 1);
        assert_eq!(dcs[0].addr, "9.9.9.9:443");
    }

    #[test]
    fn inmemory_set_home_dc() {
        let backend = InMemoryBackend::new();
        backend.set_home_dc(5).unwrap();
        let snap = backend.snapshot().unwrap();
        assert_eq!(snap.home_dc_id, 5);
    }

    #[test]
    fn inmemory_cache_peer_inserts() {
        let backend = InMemoryBackend::new();
        backend.cache_peer(&make_peer(100, 0xdeadbeef)).unwrap();
        let peers = backend.snapshot().unwrap().peers;
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].id, 100);
    }

    #[test]
    fn inmemory_cache_peer_updates_existing() {
        let backend = InMemoryBackend::new();
        for hash in [111, 222] {
            backend.cache_peer(&make_peer(100, hash)).unwrap();
        }
        let peers = backend.snapshot().unwrap().peers;
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].access_hash, 222);
    }

    // ---- UpdateStateChange ----

    #[test]
    fn update_state_primary() {
        let mut snap = UpdatesStateSnap::default();
        UpdateStateChange::Primary {
            pts: 10,
            date: 20,
            seq: 30,
        }
        .apply_to(&mut snap);
        assert_eq!((snap.pts, snap.date, snap.seq), (10, 20, 30));
        // qts belongs to the secondary sequence and must be left alone.
        assert_eq!(snap.qts, 0);
    }

    #[test]
    fn update_state_secondary() {
        let mut snap = UpdatesStateSnap {
            pts: 5,
            ..Default::default()
        };
        UpdateStateChange::Secondary { qts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.qts, 99);
        // pts must be left alone.
        assert_eq!(snap.pts, 5);
    }

    #[test]
    fn update_state_channel_inserts() {
        let mut snap = UpdatesStateSnap::default();
        UpdateStateChange::Channel { id: 12345, pts: 42 }.apply_to(&mut snap);
        assert_eq!(snap.channels, vec![(12345, 42)]);
    }

    #[test]
    fn update_state_channel_updates_existing() {
        let mut snap = UpdatesStateSnap {
            channels: vec![(12345, 10), (67890, 5)],
            ..Default::default()
        };
        UpdateStateChange::Channel { id: 12345, pts: 99 }.apply_to(&mut snap);
        // The other channel keeps its value and its position.
        assert_eq!(&snap.channels[..2], &[(12345, 99), (67890, 5)]);
    }

    #[test]
    fn apply_update_state_via_backend() {
        let backend = InMemoryBackend::new();
        backend
            .apply_update_state(UpdateStateChange::Primary {
                pts: 7,
                date: 8,
                seq: 9,
            })
            .unwrap();
        assert_eq!(backend.snapshot().unwrap().updates_state.pts, 7);
    }

    #[test]
    fn default_update_dc_via_trait_object() {
        let backend: Box<dyn SessionBackend> = Box::new(InMemoryBackend::new());
        for id in 1..=2 {
            backend.update_dc(&make_dc(id)).unwrap();
        }
        let loaded = backend.load().unwrap().unwrap();
        assert_eq!(loaded.dcs.len(), 2);
    }

    // ---- DcEntry / PersistedSession ----

    #[test]
    fn dc_entry_from_parts_ipv4() {
        let dc = DcEntry::from_parts(1, "149.154.175.53", 443, DcFlags::NONE);
        assert_eq!(dc.addr, "149.154.175.53:443");
        assert!(!dc.is_ipv6());
        assert_eq!(dc.socket_addr().unwrap().port(), 443);
    }

    #[test]
    fn dc_entry_from_parts_ipv6() {
        let dc = DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6);
        assert_eq!(dc.addr, "[2001:b28:f23d:f001::a]:443");
        assert!(dc.is_ipv6());
        assert_eq!(dc.socket_addr().unwrap().port(), 443);
    }

    #[test]
    fn persisted_session_dc_for_prefers_ipv6() {
        let s = PersistedSession {
            dcs: vec![make_dc(2), make_dc_v6(2)],
            ..Default::default()
        };
        assert!(s.dc_for(2, true).unwrap().is_ipv6());
        assert!(!s.dc_for(2, false).unwrap().is_ipv6());
    }

    #[test]
    fn persisted_session_dc_for_falls_back_when_only_ipv4() {
        let s = PersistedSession {
            dcs: vec![make_dc(3)],
            ..Default::default()
        };
        assert!(!s.dc_for(3, true).unwrap().is_ipv6());
    }

    #[test]
    fn persisted_session_all_dcs_for_returns_both() {
        let s = PersistedSession {
            dcs: vec![make_dc(1), make_dc_v6(1), make_dc(2)],
            ..Default::default()
        };
        let counts: Vec<usize> = [1, 2, 5]
            .iter()
            .map(|&id| s.all_dcs_for(id).count())
            .collect();
        assert_eq!(counts, vec![2, 1, 0]);
    }

    #[test]
    fn inmemory_ipv4_and_ipv6_coexist() {
        let backend = InMemoryBackend::new();
        backend.update_dc(&make_dc(2)).unwrap();
        backend.update_dc(&make_dc_v6(2)).unwrap();
        let s = backend.snapshot().unwrap();
        assert_eq!(s.all_dcs_for(2).count(), 2);
    }

    #[test]
    fn binary_roundtrip_ipv4_and_ipv6() {
        let s = PersistedSession {
            home_dc_id: 2,
            dcs: vec![make_dc(2), make_dc_v6(2)],
            ..Default::default()
        };
        let loaded = PersistedSession::from_bytes(&s.to_bytes()).unwrap();
        assert_eq!(loaded.dcs.len(), 2);
        let v6_count = loaded.dcs.iter().filter(|d| d.is_ipv6()).count();
        assert_eq!(v6_count, 1);
        assert_eq!(loaded.dcs.len() - v6_count, 1);
    }
}
1175
/// [`SessionBackend`] backed by a SQLite database through `rusqlite`.
#[cfg(feature = "sqlite-session")]
pub struct SqliteBackend {
    /// Single connection shared by every call; the mutex serialises access.
    conn: std::sync::Mutex<rusqlite::Connection>,
    /// Value reported by `name()`: the database path, or `":memory:"`.
    label: String,
}
1205
#[cfg(feature = "sqlite-session")]
impl SqliteBackend {
    /// Schema applied on every open. Every statement is idempotent (pragmas,
    /// `CREATE TABLE IF NOT EXISTS`), so reopening an existing database is safe.
    ///
    /// NOTE(review): `peers` has no column for `CachedPeer::is_chat`, and no
    /// table holds `PersistedSession::min_peers`. Both are therefore lost on a
    /// round trip through this backend, and fixing that needs a schema migration.
    const SCHEMA: &'static str = "
        PRAGMA journal_mode = WAL;
        PRAGMA synchronous = NORMAL;

        CREATE TABLE IF NOT EXISTS meta (
            key TEXT PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS dcs (
            dc_id INTEGER NOT NULL,
            flags INTEGER NOT NULL DEFAULT 0,
            addr TEXT NOT NULL,
            auth_key BLOB,
            first_salt INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );

        CREATE TABLE IF NOT EXISTS update_state (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            pts INTEGER NOT NULL DEFAULT 0,
            qts INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts INTEGER NOT NULL
        );

        CREATE TABLE IF NOT EXISTS peers (
            id INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel INTEGER NOT NULL DEFAULT 0
        );
    ";

    /// Opens (creating if necessary) the database at `path` and applies [`Self::SCHEMA`].
    ///
    /// # Errors
    /// Any SQLite error from opening the file or applying the schema,
    /// wrapped with `io::Error::other`.
    pub fn open(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let conn = rusqlite::Connection::open(&path).map_err(io::Error::other)?;
        conn.execute_batch(Self::SCHEMA).map_err(io::Error::other)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label,
        })
    }

    /// Opens a private in-memory database with the schema applied.
    ///
    /// # Errors
    /// Any SQLite error from opening the database or applying the schema.
    pub fn in_memory() -> io::Result<Self> {
        let conn = rusqlite::Connection::open_in_memory().map_err(io::Error::other)?;
        conn.execute_batch(Self::SCHEMA).map_err(io::Error::other)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label: ":memory:".into(),
        })
    }

    /// Wraps a SQLite error as an `io::Error` of kind `Other`.
    fn map_err(e: rusqlite::Error) -> io::Error {
        io::Error::other(e)
    }

    /// Reads the full session from all tables.
    ///
    /// A missing `home_dc_id` row yields `0`, and a missing `update_state` row
    /// yields default counters. Every other SQLite error is returned instead
    /// of being silently replaced by a default. Rows in `dcs`, `channel_pts`
    /// or `peers` that fail to decode are skipped. Stored auth keys that are
    /// not exactly 256 bytes are treated as absent.
    fn read_session(conn: &rusqlite::Connection) -> io::Result<PersistedSession> {
        let home_dc_id: i32 = match conn.query_row(
            "SELECT value FROM meta WHERE key = 'home_dc_id'",
            [],
            |r| r.get(0),
        ) {
            Ok(id) => id,
            Err(rusqlite::Error::QueryReturnedNoRows) => 0,
            Err(e) => return Err(Self::map_err(e)),
        };

        let mut stmt = conn
            .prepare("SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs")
            .map_err(Self::map_err)?;
        let dcs: Vec<DcEntry> = stmt
            .query_map([], |row| {
                let dc_id: i32 = row.get(0)?;
                let flags_raw: u8 = row.get(1)?;
                let addr: String = row.get(2)?;
                let key_blob: Option<Vec<u8>> = row.get(3)?;
                let first_salt: i64 = row.get(4)?;
                let time_offset: i32 = row.get(5)?;
                Ok((dc_id, addr, key_blob, first_salt, time_offset, flags_raw))
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .map(
                |(dc_id, addr, key_blob, first_salt, time_offset, flags_raw)| {
                    let auth_key = key_blob.and_then(|b| {
                        if b.len() == 256 {
                            let mut k = [0u8; 256];
                            k.copy_from_slice(&b);
                            Some(k)
                        } else {
                            None
                        }
                    });
                    DcEntry {
                        dc_id,
                        addr,
                        auth_key,
                        first_salt,
                        time_offset,
                        flags: DcFlags(flags_raw),
                    }
                },
            )
            .collect();

        let updates_state = match conn.query_row(
            "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
            [],
            |r| {
                Ok(UpdatesStateSnap {
                    pts: r.get(0)?,
                    qts: r.get(1)?,
                    date: r.get(2)?,
                    seq: r.get(3)?,
                    channels: vec![],
                })
            },
        ) {
            Ok(snap) => snap,
            Err(rusqlite::Error::QueryReturnedNoRows) => UpdatesStateSnap::default(),
            Err(e) => return Err(Self::map_err(e)),
        };

        let mut ch_stmt = conn
            .prepare("SELECT channel_id, pts FROM channel_pts")
            .map_err(Self::map_err)?;
        let channels: Vec<(i64, i32)> = ch_stmt
            .query_map([], |r| Ok((r.get::<_, i64>(0)?, r.get::<_, i32>(1)?)))
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        let mut peer_stmt = conn
            .prepare("SELECT id, access_hash, is_channel FROM peers")
            .map_err(Self::map_err)?;
        let peers: Vec<CachedPeer> = peer_stmt
            .query_map([], |r| {
                Ok(CachedPeer {
                    id: r.get(0)?,
                    access_hash: r.get(1)?,
                    is_channel: r.get::<_, i32>(2)? != 0,
                    is_chat: false,
                })
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers: Vec::new(),
        })
    }

    /// Replaces the stored session with `s` inside one `BEGIN IMMEDIATE` transaction.
    ///
    /// If any statement or the final `COMMIT` fails, the transaction is rolled
    /// back before the error is returned. Otherwise the connection would stay
    /// inside an open transaction and every later `BEGIN IMMEDIATE` on it
    /// would fail.
    fn write_session(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        conn.execute_batch("BEGIN IMMEDIATE")
            .map_err(Self::map_err)?;

        let outcome = Self::write_session_rows(conn, s)
            .and_then(|()| conn.execute_batch("COMMIT").map_err(Self::map_err));
        if outcome.is_err() {
            // Best effort: the transaction may already have been rolled back
            // automatically, in which case ROLLBACK itself errors harmlessly.
            let _ = conn.execute_batch("ROLLBACK");
        }
        outcome
    }

    /// Writes every table's rows for `s`. The caller must already have opened
    /// a transaction.
    ///
    /// `pts`/`qts` are merged with `MAX`, so a full save never lowers a stored
    /// counter. `date`/`seq` are overwritten.
    fn write_session_rows(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![s.home_dc_id],
        )
        .map_err(Self::map_err)?;

        conn.execute("DELETE FROM dcs", []).map_err(Self::map_err)?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                rusqlite::params![
                    d.dc_id,
                    d.flags.0,
                    d.addr,
                    d.auth_key.as_ref().map(|k| k.as_ref()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .map_err(Self::map_err)?;
        }

        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1, ?1, ?2, ?3, ?4)
             ON CONFLICT(id) DO UPDATE SET
                 pts = MAX(excluded.pts, update_state.pts),
                 qts = MAX(excluded.qts, update_state.qts),
                 date = excluded.date,
                 seq = excluded.seq",
            rusqlite::params![us.pts, us.qts, us.date, us.seq],
        )
        .map_err(Self::map_err)?;

        conn.execute("DELETE FROM channel_pts", [])
            .map_err(Self::map_err)?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
                rusqlite::params![cid, cpts],
            )
            .map_err(Self::map_err)?;
        }

        conn.execute("DELETE FROM peers", [])
            .map_err(Self::map_err)?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1, ?2, ?3)",
                rusqlite::params![p.id, p.access_hash, i32::from(p.is_channel)],
            )
            .map_err(Self::map_err)?;
        }

        Ok(())
    }
}
1446
#[cfg(feature = "sqlite-session")]
impl SessionBackend for SqliteBackend {
    /// Persists the full session snapshot through `Self::write_session`.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        Self::write_session(&conn, session)
    }

    /// Loads the stored session.
    ///
    /// Returns `Ok(None)` when the `meta` table is empty, which is treated
    /// as "nothing has been saved yet".
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = self.conn.lock().unwrap();
        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM meta", [], |r| r.get(0))
            .map_err(Self::map_err)?;
        if count == 0 {
            return Ok(None);
        }
        Self::read_session(&conn).map(Some)
    }

    /// Removes every persisted row from all session tables atomically.
    fn delete(&self) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        // A `Transaction` rolls back when dropped without a successful
        // `commit`. A failing DELETE or COMMIT (e.g. SQLITE_BUSY) therefore
        // cannot leave the shared connection stuck inside an open
        // transaction. The previous single `execute_batch("BEGIN …; COMMIT;")`
        // stopped at the first error with the transaction still open.
        let tx = rusqlite::Transaction::new_unchecked(
            &conn,
            rusqlite::TransactionBehavior::Immediate,
        )
        .map_err(Self::map_err)?;
        tx.execute_batch(
            "DELETE FROM meta;
             DELETE FROM dcs;
             DELETE FROM update_state;
             DELETE FROM channel_pts;
             DELETE FROM peers;",
        )
        .map_err(Self::map_err)?;
        tx.commit().map_err(Self::map_err)
    }

    /// Human-readable label for this backend.
    fn name(&self) -> &str {
        &self.label
    }

    /// Inserts or updates one DC row, keyed by `(dc_id, flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
               VALUES (?1, ?6, ?2, ?3, ?4, ?5)
               ON CONFLICT(dc_id, flags) DO UPDATE SET
                 addr        = excluded.addr,
                 auth_key    = excluded.auth_key,
                 first_salt  = excluded.first_salt,
                 time_offset = excluded.time_offset",
            rusqlite::params![
                entry.dc_id,
                entry.addr,
                entry.auth_key.as_ref().map(|k| k.as_ref()),
                entry.first_salt,
                entry.time_offset,
                entry.flags.0,
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Stores the home DC id in the `meta` table under key `home_dc_id`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
               ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![dc_id],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Applies one incremental update-state change.
    ///
    /// `All` replaces the singleton `update_state` row and the whole
    /// `channel_pts` table in a single transaction. The other variants each
    /// touch one row with a single UPSERT.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        match update {
            UpdateStateChange::All(snap) => {
                // Without a transaction, a failure after the DELETE would
                // leave `channel_pts` partially wiped. Dropping `tx` on any
                // early `?` return rolls everything back.
                let tx = rusqlite::Transaction::new_unchecked(
                    &conn,
                    rusqlite::TransactionBehavior::Immediate,
                )
                .map_err(Self::map_err)?;
                tx.execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
                       ON CONFLICT(id) DO UPDATE SET
                         pts=excluded.pts, qts=excluded.qts,
                         date=excluded.date, seq=excluded.seq",
                    rusqlite::params![snap.pts, snap.qts, snap.date, snap.seq],
                )
                .map_err(Self::map_err)?;
                tx.execute("DELETE FROM channel_pts", [])
                    .map_err(Self::map_err)?;
                for &(cid, cpts) in &snap.channels {
                    tx.execute(
                        "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
                        rusqlite::params![cid, cpts],
                    )
                    .map_err(Self::map_err)?;
                }
                tx.commit().map_err(Self::map_err)
            }
            UpdateStateChange::Primary { pts, date, seq } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,0,?2,?3)
                       ON CONFLICT(id) DO UPDATE SET pts=excluded.pts, date=excluded.date,
                                                     seq=excluded.seq",
                    rusqlite::params![pts, date, seq],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Secondary { qts } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,0,?1,0,0)
                       ON CONFLICT(id) DO UPDATE SET qts = excluded.qts",
                    rusqlite::params![qts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Channel { id, pts } => conn
                .execute(
                    "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)
                       ON CONFLICT(channel_id) DO UPDATE SET pts = excluded.pts",
                    rusqlite::params![id, pts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
        }
    }

    /// Inserts or updates one cached peer keyed by `id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1, ?2, ?3)
               ON CONFLICT(id) DO UPDATE SET
                 access_hash = excluded.access_hash,
                 is_channel  = excluded.is_channel",
            rusqlite::params![peer.id, peer.access_hash, peer.is_channel as i32],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }
}
1584
#[cfg(feature = "libsql-session")]
/// Session backend that stores session state in a libsql database.
///
/// The database can be a local file, in-memory, remote, or an embedded replica.
pub struct LibSqlBackend {
    // Every constructor stores `Arc::new(tokio::sync::Mutex::new(conn))`, and
    // every `SessionBackend` method locks it with `.lock().await`. The field
    // type must match that; a bare `libsql::Connection` does not compile.
    // The async mutex serialises all statements issued on the one connection.
    conn: std::sync::Arc<tokio::sync::Mutex<libsql::Connection>>,
    /// Description of the storage location, returned by `name()`.
    label: String,
}
1609
#[cfg(feature = "libsql-session")]
impl LibSqlBackend {
    /// DDL applied on every open. Every statement is `IF NOT EXISTS`, so
    /// re-applying it to an existing database changes nothing.
    const SCHEMA: &'static str = "
        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT    PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );
        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );
        CREATE TABLE IF NOT EXISTS peers (
            id          INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel  INTEGER NOT NULL DEFAULT 0
        );
    ";

    /// Drives `fut` to completion on the current Tokio runtime and converts
    /// its `libsql::Error` into an `io::Error`.
    ///
    /// # Errors
    /// Returns an error, instead of panicking as `Handle::current()` would,
    /// when no Tokio runtime is available on this thread. Also returns an
    /// error when `fut` itself fails.
    ///
    /// NOTE(review): `Handle::block_on` still panics if called from inside an
    /// async task. Callers on a runtime worker presumably need to wrap
    /// `SessionBackend` calls in `spawn_blocking` / `block_in_place`. Confirm
    /// against the call sites.
    fn block<F, T>(fut: F) -> io::Result<T>
    where
        F: std::future::Future<Output = Result<T, libsql::Error>>,
    {
        let handle = tokio::runtime::Handle::try_current().map_err(io::Error::other)?;
        handle.block_on(fut).map_err(io::Error::other)
    }

    /// Creates all session tables if they do not exist yet.
    async fn apply_schema(conn: &libsql::Connection) -> Result<(), libsql::Error> {
        conn.execute_batch(Self::SCHEMA).await.map(|_| ())
    }

    /// Shared tail of every constructor.
    ///
    /// Opens a connection on `db`, applies [`Self::SCHEMA`], and wraps the
    /// connection for shared use.
    fn from_database(db: libsql::Database, label: String) -> io::Result<Self> {
        // `Database::connect` is synchronous, so no runtime round-trip is needed.
        let conn = db.connect().map_err(io::Error::other)?;
        Self::block(Self::apply_schema(&conn))?;
        Ok(Self {
            conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
            label,
        })
    }

    /// Opens (or creates) a local database file at `path`.
    pub fn open_local(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let db = Self::block(libsql::Builder::new_local(path).build())?;
        Self::from_database(db, label)
    }

    /// Opens a fresh in-memory database. Nothing survives the process.
    pub fn in_memory() -> io::Result<Self> {
        let db = Self::block(libsql::Builder::new_local(":memory:").build())?;
        Self::from_database(db, ":memory:".into())
    }

    /// Connects to a remote libsql server at `url` using `auth_token`.
    pub fn open_remote(url: impl Into<String>, auth_token: impl Into<String>) -> io::Result<Self> {
        let url = url.into();
        let label = url.clone();
        let db = Self::block(libsql::Builder::new_remote(url, auth_token.into()).build())?;
        Self::from_database(db, label)
    }

    /// Opens an embedded replica stored at `path` that mirrors the remote
    /// database at `url`.
    pub fn open_replica(
        path: impl Into<PathBuf>,
        url: impl Into<String>,
        auth_token: impl Into<String>,
    ) -> io::Result<Self> {
        let path = path.into();
        // Convert once: the old code called `url.into()` twice. That used
        // `url` after it was moved, and the type inside `format!` was ambiguous.
        let url: String = url.into();
        let label = format!("{} (replica of {})", path.display(), url);
        let db = Self::block(
            libsql::Builder::new_remote_replica(path, url, auth_token.into()).build(),
        )?;
        Self::from_database(db, label)
    }

    /// Reads every session table into a [`PersistedSession`].
    ///
    /// Missing `home_dc_id` / `update_state` rows fall back to `0` /
    /// `UpdatesStateSnap::default()`.
    ///
    /// # Errors
    /// Fails on any query error, or when a stored `auth_key` is not exactly
    /// 256 bytes.
    async fn read_session_async(
        conn: &libsql::Connection,
    ) -> Result<PersistedSession, libsql::Error> {
        let home_dc_id: i32 = {
            let mut rows = conn
                .query("SELECT value FROM meta WHERE key = 'home_dc_id'", ())
                .await?;
            match rows.next().await? {
                Some(row) => row.get::<i32>(0)?,
                None => 0,
            }
        };

        let mut rows = conn
            .query(
                "SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs",
                (),
            )
            .await?;
        let mut dcs = Vec::new();
        while let Some(row) = rows.next().await? {
            let key_blob: Option<Vec<u8>> = row.get(3)?;
            let auth_key = key_blob
                .map(|blob| {
                    <[u8; 256]>::try_from(blob).map_err(|blob| {
                        libsql::Error::Misuse(format!(
                            "auth_key blob must be 256 bytes, got {}",
                            blob.len()
                        ))
                    })
                })
                .transpose()?;
            dcs.push(DcEntry {
                dc_id: row.get(0)?,
                addr: row.get(2)?,
                auth_key,
                first_salt: row.get(4)?,
                time_offset: row.get(5)?,
                // Flags are written from a `u8`, so the narrowing cast only
                // truncates if the row was modified outside this backend.
                flags: DcFlags(row.get::<i64>(1)? as u8),
            });
        }

        let mut us_rows = conn
            .query(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                (),
            )
            .await?;
        let updates_state = match us_rows.next().await? {
            Some(r) => UpdatesStateSnap {
                pts: r.get(0)?,
                qts: r.get(1)?,
                date: r.get(2)?,
                seq: r.get(3)?,
                channels: vec![],
            },
            None => UpdatesStateSnap::default(),
        };

        let mut ch_rows = conn
            .query("SELECT channel_id, pts FROM channel_pts", ())
            .await?;
        let mut channels = Vec::new();
        while let Some(r) = ch_rows.next().await? {
            channels.push((r.get::<i64>(0)?, r.get::<i32>(1)?));
        }

        let mut peer_rows = conn
            .query("SELECT id, access_hash, is_channel FROM peers", ())
            .await?;
        let mut peers = Vec::new();
        while let Some(r) = peer_rows.next().await? {
            peers.push(CachedPeer {
                id: r.get(0)?,
                access_hash: r.get(1)?,
                is_channel: r.get::<i32>(2)? != 0,
                is_chat: false,
            });
        }

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers: Vec::new(),
        })
    }

    /// Atomically replaces the stored session with `s`.
    ///
    /// All writes run inside one `BEGIN IMMEDIATE` transaction. The
    /// transaction is rolled back if any write or the COMMIT fails, so the
    /// connection is never left with a dangling open transaction.
    async fn write_session_async(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute_batch("BEGIN IMMEDIATE").await?;
        let body = Self::write_session_rows(conn, s).await;
        Self::finish_tx(conn, body).await
    }

    /// Writes every table of `s`. Must run inside a transaction opened by
    /// the caller.
    async fn write_session_rows(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
               ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            libsql::params![s.home_dc_id],
        )
        .await?;

        conn.execute("DELETE FROM dcs", ()).await?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                   VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                libsql::params![
                    d.dc_id,
                    d.flags.0 as i64,
                    d.addr.clone(),
                    d.auth_key.map(|k| k.to_vec()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .await?;
        }

        // pts/qts never move backwards on save (MAX with the stored value).
        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
               ON CONFLICT(id) DO UPDATE SET
                 pts  = MAX(excluded.pts, update_state.pts),
                 qts  = MAX(excluded.qts, update_state.qts),
                 date = excluded.date,
                 seq  = excluded.seq",
            libsql::params![us.pts, us.qts, us.date, us.seq],
        )
        .await?;

        conn.execute("DELETE FROM channel_pts", ()).await?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1,?2)",
                libsql::params![cid, cpts],
            )
            .await?;
        }

        conn.execute("DELETE FROM peers", ()).await?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1,?2,?3)",
                libsql::params![p.id, p.access_hash, p.is_channel as i32],
            )
            .await?;
        }

        Ok(())
    }

    /// Ends a transaction opened with `BEGIN`.
    ///
    /// Commits when `body` succeeded. Otherwise, or if the COMMIT itself
    /// fails, issues a best-effort ROLLBACK and returns the original error.
    async fn finish_tx(
        conn: &libsql::Connection,
        body: Result<(), libsql::Error>,
    ) -> Result<(), libsql::Error> {
        let result = match body {
            Ok(()) => conn.execute_batch("COMMIT").await.map(|_| ()),
            Err(e) => Err(e),
        };
        if result.is_err() {
            // Ignore the ROLLBACK outcome: the earlier error is the useful one,
            // and ROLLBACK may legitimately fail with "no transaction is active".
            let _ = conn.execute_batch("ROLLBACK").await;
        }
        result
    }
}
1890
#[cfg(feature = "libsql-session")]
impl SessionBackend for LibSqlBackend {
    /// Persists the full session snapshot through `write_session_async`.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        // `block` runs the future to completion on this thread, so borrowing
        // `self` and `session` is fine. No Arc clone or deep copy of the
        // session is needed. The old code also passed the session by value
        // where `&PersistedSession` is expected, which did not compile.
        Self::block(async {
            let conn = self.conn.lock().await;
            Self::write_session_async(&conn, session).await
        })
    }

    /// Loads the stored session.
    ///
    /// Returns `Ok(None)` when the `meta` table is empty. The emptiness check
    /// and the read happen under one lock acquisition, so a concurrent
    /// `delete`/`save` cannot slip in between them.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        Self::block(async {
            let conn = self.conn.lock().await;
            let count: i64 = {
                let mut rows = conn.query("SELECT COUNT(*) FROM meta", ()).await?;
                match rows.next().await? {
                    // Propagate decode errors instead of silently treating
                    // them as an empty session.
                    Some(row) => row.get::<i64>(0)?,
                    None => 0,
                }
            };
            if count == 0 {
                return Ok(None);
            }
            Self::read_session_async(&conn).await.map(Some)
        })
    }

    /// Removes every persisted row from all session tables atomically.
    fn delete(&self) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute_batch("BEGIN IMMEDIATE").await?;
            let body = conn
                .execute_batch(
                    "DELETE FROM meta;
                     DELETE FROM dcs;
                     DELETE FROM update_state;
                     DELETE FROM channel_pts;
                     DELETE FROM peers;",
                )
                .await
                .map(|_| ());
            // Commit on success. Roll back on any failure, including a failed
            // COMMIT, so the connection is not left inside an open transaction.
            let result = match body {
                Ok(()) => conn.execute_batch("COMMIT").await.map(|_| ()),
                Err(e) => Err(e),
            };
            if result.is_err() {
                let _ = conn.execute_batch("ROLLBACK").await;
            }
            result
        })
    }

    /// Human-readable label for this backend.
    fn name(&self) -> &str {
        &self.label
    }

    /// Inserts or updates one DC row, keyed by `(dc_id, flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                   VALUES (?1,?6,?2,?3,?4,?5)
                   ON CONFLICT(dc_id, flags) DO UPDATE SET
                     addr=excluded.addr, auth_key=excluded.auth_key,
                     first_salt=excluded.first_salt, time_offset=excluded.time_offset",
                libsql::params![
                    entry.dc_id,
                    entry.addr.clone(),
                    entry.auth_key.map(|k| k.to_vec()),
                    entry.first_salt,
                    entry.time_offset,
                    entry.flags.0 as i64,
                ],
            )
            .await
            .map(|_| ())
        })
    }

    /// Stores the home DC id in the `meta` table under key `home_dc_id`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO meta (key, value) VALUES ('home_dc_id',?1)
                   ON CONFLICT(key) DO UPDATE SET value=excluded.value",
                libsql::params![dc_id],
            )
            .await
            .map(|_| ())
        })
    }

    /// Applies one incremental update-state change.
    ///
    /// `All` replaces the singleton `update_state` row and the whole
    /// `channel_pts` table inside one transaction. The other variants each
    /// touch one row with a single UPSERT.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        Self::block(async move {
            let conn = self.conn.lock().await;
            match update {
                UpdateStateChange::All(snap) => {
                    // Without a transaction, a failure after the DELETE would
                    // leave `channel_pts` partially wiped.
                    conn.execute_batch("BEGIN IMMEDIATE").await?;
                    let body = async {
                        conn.execute(
                            "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,?2,?3,?4)
                               ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,qts=excluded.qts,
                                                             date=excluded.date,seq=excluded.seq",
                            libsql::params![snap.pts, snap.qts, snap.date, snap.seq],
                        )
                        .await?;
                        conn.execute("DELETE FROM channel_pts", ()).await?;
                        for &(cid, cpts) in &snap.channels {
                            conn.execute(
                                "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)",
                                libsql::params![cid, cpts],
                            )
                            .await?;
                        }
                        Ok::<(), libsql::Error>(())
                    }
                    .await;
                    let result = match body {
                        Ok(()) => conn.execute_batch("COMMIT").await.map(|_| ()),
                        Err(e) => Err(e),
                    };
                    if result.is_err() {
                        let _ = conn.execute_batch("ROLLBACK").await;
                    }
                    result
                }
                UpdateStateChange::Primary { pts, date, seq } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,0,?2,?3)
                           ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,date=excluded.date,
                                                         seq=excluded.seq",
                        libsql::params![pts, date, seq],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Secondary { qts } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,0,?1,0,0)
                           ON CONFLICT(id) DO UPDATE SET qts=excluded.qts",
                        libsql::params![qts],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Channel { id, pts } => conn
                    .execute(
                        "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)
                           ON CONFLICT(channel_id) DO UPDATE SET pts=excluded.pts",
                        libsql::params![id, pts],
                    )
                    .await
                    .map(|_| ()),
            }
        })
    }

    /// Inserts or updates one cached peer keyed by `id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO peers (id,access_hash,is_channel) VALUES (?1,?2,?3)
                   ON CONFLICT(id) DO UPDATE SET
                     access_hash=excluded.access_hash,
                     is_channel=excluded.is_channel",
                libsql::params![peer.id, peer.access_hash, peer.is_channel as i32],
            )
            .await
            .map(|_| ())
        })
    }
}