Skip to main content

ferogram_session/
lib.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3//
4// ferogram-session: session persistence types and backends for ferogram
5// https://github.com/ankit-chaubey/ferogram
6
7//! Session persistence types and pluggable storage backends.
8//!
9//! Saves auth key, salt, time offset, DC table, update sequence counters,
10//! and peer access-hash cache.
11//!
12//! # Crate contents
13//!
14//! - [`PersistedSession`]: the serializable session struct (DC table,
15//!   auth keys, update counters, peer access-hash cache).
16//! - [`SessionBackend`]: the sync snapshot trait that all backends implement.
17//! - Built-in backends: [`BinaryFileBackend`], [`InMemoryBackend`],
18//!   [`StringSessionBackend`], and optionally [`SqliteBackend`] and
19//!   [`LibSqlBackend`] behind feature flags.
20//!
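//! ## Binary format versioning
//!
//! Files written by this crate start with a single **version byte**:
//! - `0x02`: DC table + update state + peer cache.
//! - `0x03`: adds per-DC [`DcFlags`].
//! - `0x04`: adds min-user message contexts ([`CachedMinPeer`]).
//! - `0x05` (current): each cached peer carries a type byte that also marks
//!   regular group chats.
//!
//! Legacy v1 data has no version byte. It starts directly with the
//! little-endian home DC id and holds only the DC table.
//!
//! `load()` reads all of these versions. `save()` always writes v5.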
28
29use std::collections::HashMap;
30use std::io::{self, ErrorKind};
31use std::path::Path;
32
/// Bitmask of per-DC option flags.
///
/// Persisted in session format v3 and later, so these flags survive a
/// restart together with the DC address they describe.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// No bits set.
    pub const NONE: DcFlags = DcFlags(0b0000_0000);
    /// Bit 0: the address is IPv6.
    pub const IPV6: DcFlags = DcFlags(0b0000_0001);
    /// Bit 1: media-only DC.
    pub const MEDIA_ONLY: DcFlags = DcFlags(0b0000_0010);
    /// Bit 2: TCPO-only DC.
    pub const TCPO_ONLY: DcFlags = DcFlags(0b0000_0100);
    /// Bit 3: CDN DC.
    pub const CDN: DcFlags = DcFlags(0b0000_1000);
    /// Bit 4: static DC.
    pub const STATIC: DcFlags = DcFlags(0b0001_0000);

    /// Returns `true` when every bit set in `other` is also set in `self`.
    ///
    /// Consequently [`DcFlags::NONE`] is contained in every value.
    pub fn contains(self, other: DcFlags) -> bool {
        // No bit of `other` may fall outside `self`.
        (other.0 & !self.0) == 0
    }

    /// Turns on every bit of `flag` in `self`; other bits are left as they are.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = DcFlags;

    /// Union of both flag sets.
    fn bitor(self, rhs: DcFlags) -> DcFlags {
        let DcFlags(lhs_bits) = self;
        DcFlags(lhs_bits | rhs.0)
    }
}
63
64/// One entry in the DC address table.
65#[derive(Clone, Debug)]
66#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
67pub struct DcEntry {
68    pub dc_id: i32,
69    pub addr: String,
70    pub auth_key: Option<[u8; 256]>,
71    pub first_salt: i64,
72    pub time_offset: i32,
73    /// DC capability flags (IPv6, media-only, CDN, ...).
74    pub flags: DcFlags,
75}
76
77impl DcEntry {
78    /// Returns `true` when this entry represents an IPv6 address.
79    #[inline]
80    pub fn is_ipv6(&self) -> bool {
81        self.flags.contains(DcFlags::IPV6)
82    }
83
84    /// Parse the stored `"ip:port"` / `"[ipv6]:port"` address into a
85    /// [`std::net::SocketAddr`].
86    ///
87    /// Both formats are valid:
88    /// - IPv4: `"149.154.175.53:443"`
89    /// - IPv6: `"[2001:b28:f23d:f001::a]:443"`
90    pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
91        self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
92            io::Error::new(
93                io::ErrorKind::InvalidData,
94                format!("invalid DC address: {:?}", self.addr),
95            )
96        })
97    }
98
99    /// Construct a `DcEntry` from separate IP string, port, and flags.
100    ///
101    /// IPv6 addresses are automatically wrapped in brackets so that
102    /// `socket_addr()` can round-trip them correctly:
103    ///
104    /// ```text
105    /// DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6)
106    /// // addr = "[2001:b28:f23d:f001::a]:443"
107    /// ```
108    ///
109    /// This is the preferred constructor when processing `help.getConfig`
110    /// `DcOption` objects from the Telegram API.
111    pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
112        // IPv6 addresses contain colons; wrap in brackets for SocketAddr compat.
113        let addr = if ip.contains(':') {
114            format!("[{ip}]:{port}")
115        } else {
116            format!("{ip}:{port}")
117        };
118        Self {
119            dc_id,
120            addr,
121            auth_key: None,
122            first_salt: 0,
123            time_offset: 0,
124            flags,
125        }
126    }
127}
128
/// Persisted copy of the MTProto update-sequence counters.
///
/// Saved so that `catch_up: true` can call `updates.getDifference` starting
/// from the pts the client had *before* it shut down.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Primary pts counter (messages and other non-channel updates).
    pub pts: i32,
    /// qts counter used for secret chats.
    pub qts: i32,
    /// Unix timestamp of the most recently received update.
    pub date: i32,
    /// Sequence number of update containers.
    pub seq: i32,
    /// Per-channel pts values as `(channel_id, pts)` pairs.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// `true` once a server-provided state has been recorded, i.e. `pts > 0`.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        self.pts > 0
    }

    /// Store `pts` for `channel_id`.
    ///
    /// If the channel is already known its value is overwritten (the new
    /// value is not compared against the old one); otherwise a new
    /// `(channel_id, pts)` pair is appended.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self.channels.iter().position(|&(id, _)| id == channel_id) {
            Some(idx) => self.channels[idx].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Stored pts for `channel_id`, or `0` if the channel has never been seen.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        self.channels
            .iter()
            .find_map(|&(id, stored)| (id == channel_id).then_some(stored))
            .unwrap_or(0)
    }
}
171
/// A cached access-hash entry so that the peer can be addressed across restarts
/// without re-resolving it from Telegram.
///
/// `is_channel` and `is_chat` are independent booleans. The binary session
/// format stores them as a single type byte (0 = user, 1 = channel,
/// 2 = regular group chat), in which `is_chat` takes precedence if both are set.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Bare Telegram peer ID (always positive).
    pub id: i64,
    /// Access hash bound to the current session.
    /// Always 0 for regular group chats (they need no access_hash).
    pub access_hash: i64,
    /// `true` → channel / supergroup.  `false` → user or regular group.
    pub is_channel: bool,
    /// `true` → regular group chat (Chat::Chat / ChatForbidden).
    /// When true, access_hash is meaningless (groups need no hash).
    pub is_chat: bool,
}
188
/// A min-user context entry: the user was seen with `min=true` (access_hash
/// not usable directly) so we store the peer+message where they appeared so
/// that `InputPeerUserFromMessage` can be constructed on restart.
///
/// Only binary session format v4 and later stores these; older files load
/// with an empty list.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedMinPeer {
    /// The min user's ID.
    pub user_id: i64,
    /// The channel/chat/user ID of the peer that contained the message.
    pub peer_id: i64,
    /// The message ID within that peer.
    pub msg_id: i32,
}
202
203/// Everything that needs to survive a process restart.
204#[derive(Clone, Debug, Default)]
205#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
206pub struct PersistedSession {
207    pub home_dc_id: i32,
208    pub dcs: Vec<DcEntry>,
209    /// Update counters to enable reliable catch-up after a disconnect.
210    pub updates_state: UpdatesStateSnap,
211    /// Peer access-hash cache so that the client can reach out to any previously
212    /// seen user or channel without re-resolving them.
213    pub peers: Vec<CachedPeer>,
214    /// Min-user message contexts: users seen with `min=true` that can only be
215    /// addressed via `InputPeerUserFromMessage`.
216    pub min_peers: Vec<CachedMinPeer>,
217}
218
219impl PersistedSession {
220    /// Encode the session to raw bytes (v2 binary format).
221    pub fn to_bytes(&self) -> Vec<u8> {
222        let mut b = Vec::with_capacity(512);
223
224        b.push(0x05u8); // version
225
226        b.extend_from_slice(&self.home_dc_id.to_le_bytes());
227
228        b.push(self.dcs.len() as u8);
229        for d in &self.dcs {
230            b.extend_from_slice(&d.dc_id.to_le_bytes());
231            match &d.auth_key {
232                Some(k) => {
233                    b.push(1);
234                    b.extend_from_slice(k);
235                }
236                None => {
237                    b.push(0);
238                }
239            }
240            b.extend_from_slice(&d.first_salt.to_le_bytes());
241            b.extend_from_slice(&d.time_offset.to_le_bytes());
242            let ab = d.addr.as_bytes();
243            b.push(ab.len() as u8);
244            b.extend_from_slice(ab);
245            b.push(d.flags.0);
246        }
247
248        b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
249        b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
250        b.extend_from_slice(&self.updates_state.date.to_le_bytes());
251        b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
252        let ch = &self.updates_state.channels;
253        b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
254        for &(cid, cpts) in ch {
255            b.extend_from_slice(&cid.to_le_bytes());
256            b.extend_from_slice(&cpts.to_le_bytes());
257        }
258
259        // v5 peer type: 0=user, 1=channel, 2=regular-group-chat
260        b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
261        for p in &self.peers {
262            b.extend_from_slice(&p.id.to_le_bytes());
263            b.extend_from_slice(&p.access_hash.to_le_bytes());
264            let peer_type: u8 = if p.is_chat {
265                2
266            } else if p.is_channel {
267                1
268            } else {
269                0
270            };
271            b.push(peer_type);
272        }
273
274        b.extend_from_slice(&(self.min_peers.len() as u16).to_le_bytes());
275        for m in &self.min_peers {
276            b.extend_from_slice(&m.user_id.to_le_bytes());
277            b.extend_from_slice(&m.peer_id.to_le_bytes());
278            b.extend_from_slice(&m.msg_id.to_le_bytes());
279        }
280
281        b
282    }
283
284    /// Atomically save the session to `path`.
285    ///
286    /// Writes to `<path>.tmp` first, then renames into place so a crash
287    /// mid-write never corrupts the existing session file.
288    pub fn save(&self, path: &Path) -> io::Result<()> {
289        let tmp = path.with_extension("tmp");
290        std::fs::write(&tmp, self.to_bytes())?;
291        std::fs::rename(&tmp, path)
292    }
293
294    /// Decode a session from raw bytes (v1 or v2 binary format).
295    pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
296        if buf.is_empty() {
297            return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
298        }
299
300        let mut p = 0usize;
301
302        macro_rules! r {
303            ($n:expr) => {{
304                if p + $n > buf.len() {
305                    return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
306                }
307                let s = &buf[p..p + $n];
308                p += $n;
309                s
310            }};
311        }
312        macro_rules! r_i32 {
313            () => {
314                i32::from_le_bytes(r!(4).try_into().unwrap())
315            };
316        }
317        macro_rules! r_i64 {
318            () => {
319                i64::from_le_bytes(r!(8).try_into().unwrap())
320            };
321        }
322        macro_rules! r_u8 {
323            () => {
324                r!(1)[0]
325            };
326        }
327        macro_rules! r_u16 {
328            () => {
329                u16::from_le_bytes(r!(2).try_into().unwrap())
330            };
331        }
332
333        let first_byte = r_u8!();
334
335        let (home_dc_id, version) = if first_byte == 0x05 {
336            (r_i32!(), 5u8)
337        } else if first_byte == 0x04 {
338            (r_i32!(), 4u8)
339        } else if first_byte == 0x03 {
340            (r_i32!(), 3u8)
341        } else if first_byte == 0x02 {
342            (r_i32!(), 2u8)
343        } else {
344            let rest = r!(3);
345            let mut bytes = [0u8; 4];
346            bytes[0] = first_byte;
347            bytes[1..4].copy_from_slice(rest);
348            (i32::from_le_bytes(bytes), 1u8)
349        };
350
351        let dc_count = r_u8!() as usize;
352        let mut dcs = Vec::with_capacity(dc_count);
353        for _ in 0..dc_count {
354            let dc_id = r_i32!();
355            let has_key = r_u8!();
356            let auth_key = if has_key == 1 {
357                let mut k = [0u8; 256];
358                k.copy_from_slice(r!(256));
359                Some(k)
360            } else {
361                None
362            };
363            let first_salt = r_i64!();
364            let time_offset = r_i32!();
365            let al = r_u8!() as usize;
366            let addr = String::from_utf8_lossy(r!(al)).into_owned();
367            let flags = if version >= 3 {
368                DcFlags(r_u8!())
369            } else {
370                DcFlags::NONE
371            };
372            dcs.push(DcEntry {
373                dc_id,
374                addr,
375                auth_key,
376                first_salt,
377                time_offset,
378                flags,
379            });
380        }
381
382        if version < 2 {
383            return Ok(Self {
384                home_dc_id,
385                dcs,
386                updates_state: UpdatesStateSnap::default(),
387                peers: Vec::new(),
388                min_peers: Vec::new(),
389            });
390        }
391
392        let pts = r_i32!();
393        let qts = r_i32!();
394        let date = r_i32!();
395        let seq = r_i32!();
396        let ch_count = r_u16!() as usize;
397        let mut channels = Vec::with_capacity(ch_count);
398        for _ in 0..ch_count {
399            let cid = r_i64!();
400            let cpts = r_i32!();
401            channels.push((cid, cpts));
402        }
403
404        let peer_count = r_u16!() as usize;
405        let mut peers = Vec::with_capacity(peer_count);
406        for _ in 0..peer_count {
407            let id = r_i64!();
408            let access_hash = r_i64!();
409            // v5: type byte 0=user, 1=channel, 2=chat; v2-v4: 0=user, 1=channel
410            let peer_type = r_u8!();
411            let is_channel = peer_type == 1;
412            let is_chat = peer_type == 2;
413            peers.push(CachedPeer {
414                id,
415                access_hash,
416                is_channel,
417                is_chat,
418            });
419        }
420
421        // v4+: min-user contexts
422        let min_peers = if version >= 4 {
423            let count = r_u16!() as usize;
424            let mut v = Vec::with_capacity(count);
425            for _ in 0..count {
426                let user_id = r_i64!();
427                let peer_id = r_i64!();
428                let msg_id = r_i32!();
429                v.push(CachedMinPeer {
430                    user_id,
431                    peer_id,
432                    msg_id,
433                });
434            }
435            v
436        } else {
437            Vec::new()
438        };
439
440        Ok(Self {
441            home_dc_id,
442            dcs,
443            updates_state: UpdatesStateSnap {
444                pts,
445                qts,
446                date,
447                seq,
448                channels,
449            },
450            peers,
451            min_peers,
452        })
453    }
454
455    /// Decode a session from a URL-safe base64 string produced by [`to_string`].
456    pub fn from_string(s: &str) -> io::Result<Self> {
457        use base64::Engine as _;
458        let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
459            .decode(s.trim())
460            .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
461        Self::from_bytes(&bytes)
462    }
463
464    pub fn load(path: &Path) -> io::Result<Self> {
465        let buf = std::fs::read(path)?;
466        Self::from_bytes(&buf)
467    }
468
469    // DC address helpers
470
471    /// Find the best DC entry for a given DC ID.
472    ///
473    /// When `prefer_ipv6` is `true`, returns the IPv6 entry if one is
474    /// stored, falling back to IPv4.  When `false`, returns IPv4,
475    /// falling back to IPv6.  Returns `None` only when the DC ID is
476    /// completely unknown.
477    ///
478    /// This correctly handles the case where both an IPv4 and an IPv6
479    /// `DcEntry` exist for the same `dc_id` (different `flags` bitmask).
480    pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
481        let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
482        candidates.peek()?;
483        // Collect so we can search twice
484        let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
485        // Preferred family first, fall back to whatever is available
486        cands
487            .iter()
488            .copied()
489            .find(|d| d.is_ipv6() == prefer_ipv6)
490            .or_else(|| cands.first().copied())
491    }
492
493    /// Iterate over every stored DC entry for a given DC ID.
494    ///
495    /// Typically yields one IPv4 and one IPv6 entry per DC ID once
496    /// `help.getConfig` has been applied.
497    pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
498        self.dcs.iter().filter(move |d| d.dc_id == dc_id)
499    }
500}
501
502impl std::fmt::Display for PersistedSession {
503    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504        use base64::Engine as _;
505        f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
506    }
507}
508
/// Hard-coded bootstrap address for DCs 1 through 5, used as a fallback when
/// `GetConfig` fails.
///
/// Keys are DC ids; values are `"ip:port"` strings.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const BOOTSTRAP: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(BOOTSTRAP.len());
    for (dc_id, addr) in BOOTSTRAP {
        table.insert(dc_id, String::from(addr));
    }
    table
}
522
523// session_backend
524//
525// Pluggable session storage backend.
526
527use std::path::PathBuf;
528
529// Core trait (unchanged)
530
531/// Synchronous snapshot backend: saves and loads the full session at once.
532///
533/// All built-in backends implement this. Higher-level code should prefer the
534/// extension methods below (`update_dc`, `set_home_dc`, `update_state`) which
535/// avoid unnecessary full-snapshot writes.
536pub trait SessionBackend: Send + Sync {
537    fn save(&self, session: &PersistedSession) -> io::Result<()>;
538    fn load(&self) -> io::Result<Option<PersistedSession>>;
539    fn delete(&self) -> io::Result<()>;
540
541    /// Human-readable name for logging/debug output.
542    fn name(&self) -> &str;
543
544    // Granular helpers (default: load → mutate → save)
545    //
546    // These default implementations are correct but not optimal.
547    // Backends that store data in a database (SQLite, libsql, Redis) should
548    // override them to issue single-row UPDATE statements instead.
549
550    /// Update a single DC entry without rewriting the entire session.
551    ///
552    /// Typically called after:
553    /// - completing a DH handshake on a new DC (to persist its auth key)
554    /// - receiving updated DC addresses from `help.getConfig`
555    ///
556    /// Ported from  `Session::set_dc_option`.
557    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
558        let mut s = self.load()?.unwrap_or_default();
559        // Replace existing entry or append
560        if let Some(existing) = s.dcs.iter_mut().find(|d| d.dc_id == entry.dc_id) {
561            *existing = entry.clone();
562        } else {
563            s.dcs.push(entry.clone());
564        }
565        self.save(&s)
566    }
567
568    /// Change the home DC without touching any other session data.
569    ///
570    /// Called after a successful `*_MIGRATE` redirect: the user's account
571    /// now lives on a different DC.
572    ///
573    /// Ported from  `Session::set_home_dc_id`.
574    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
575        let mut s = self.load()?.unwrap_or_default();
576        s.home_dc_id = dc_id;
577        self.save(&s)
578    }
579
580    /// Apply a single update-sequence change without a full save/load.
581    ///
582    /// Ported from  `Session::set_update_state(UpdateState)`.
583    ///
584    /// `update` is the new partial or full state to merge in.
585    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
586        let mut s = self.load()?.unwrap_or_default();
587        update.apply_to(&mut s.updates_state);
588        self.save(&s)
589    }
590
591    /// Cache a peer access hash without a full session save.
592    ///
593    /// This is lossy-on-default (full round-trip) but correct.
594    /// Override in SQL backends to issue a single `INSERT OR REPLACE`.
595    ///
596    /// Ported from  `Session::cache_peer`.
597    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
598        let mut s = self.load()?.unwrap_or_default();
599        if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
600            *existing = peer.clone();
601        } else {
602            s.peers.push(peer.clone());
603        }
604        self.save(&s)
605    }
606}
607
608// UpdateStateChange (mirrors  UpdateState enum)
609
610/// A single update-sequence change, applied via [`SessionBackend::apply_update_state`].
611///
612///uses:
613/// ```text
614/// UpdateState::All(updates_state)
615/// UpdateState::Primary { pts, date, seq }
616/// UpdateState::Secondary { qts }
617/// UpdateState::Channel { id, pts }
618/// ```
619///
620/// We map this 1-to-1 to layer's `UpdatesStateSnap`.
621#[derive(Debug, Clone)]
622pub enum UpdateStateChange {
623    /// Replace the entire state snapshot.
624    All(UpdatesStateSnap),
625    /// Update main sequence counters only (non-channel).
626    Primary { pts: i32, date: i32, seq: i32 },
627    /// Update the QTS counter (secret chats).
628    Secondary { qts: i32 },
629    /// Update the PTS for a specific channel.
630    Channel { id: i64, pts: i32 },
631}
632
633impl UpdateStateChange {
634    /// Apply `self` to `snap` in-place.
635    pub fn apply_to(&self, snap: &mut UpdatesStateSnap) {
636        match self {
637            Self::All(new_snap) => *snap = new_snap.clone(),
638            Self::Primary { pts, date, seq } => {
639                snap.pts = *pts;
640                snap.date = *date;
641                snap.seq = *seq;
642            }
643            Self::Secondary { qts } => {
644                snap.qts = *qts;
645            }
646            Self::Channel { id, pts } => {
647                // Replace or insert per-channel pts
648                if let Some(existing) = snap.channels.iter_mut().find(|c| c.0 == *id) {
649                    existing.1 = *pts;
650                } else {
651                    snap.channels.push((*id, *pts));
652                }
653            }
654        }
655    }
656}
657
658// BinaryFileBackend
659
/// Session backend that keeps the whole session in one compact binary file,
/// encoded by `PersistedSession::to_bytes`.
pub struct BinaryFileBackend {
    path: PathBuf,
}

impl BinaryFileBackend {
    /// Create a backend bound to `path`. Only the path is stored here;
    /// the file is first touched by `save`/`load`/`delete`.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path: PathBuf = path.into();
        BinaryFileBackend { path }
    }

    /// Location of the session file.
    pub fn path(&self) -> &std::path::Path {
        self.path.as_path()
    }
}
674
675impl SessionBackend for BinaryFileBackend {
676    fn save(&self, session: &PersistedSession) -> io::Result<()> {
677        session.save(&self.path)
678    }
679
680    fn load(&self) -> io::Result<Option<PersistedSession>> {
681        if !self.path.exists() {
682            return Ok(None);
683        }
684        match PersistedSession::load(&self.path) {
685            Ok(s) => Ok(Some(s)),
686            Err(e) => {
687                let bak = self.path.with_extension("bak");
688                tracing::warn!(
689                    "[ferogram] Session file {:?} is corrupt ({e}); \
690                     renaming to {:?} and starting fresh",
691                    self.path,
692                    bak
693                );
694                let _ = std::fs::rename(&self.path, &bak);
695                Ok(None)
696            }
697        }
698    }
699
700    fn delete(&self) -> io::Result<()> {
701        if self.path.exists() {
702            std::fs::remove_file(&self.path)?;
703        }
704        Ok(())
705    }
706
707    fn name(&self) -> &str {
708        "binary-file"
709    }
710
711    // BinaryFileBackend: the default granular impls (load→mutate→save) are
712    // fine since the format is a single compact binary blob. No override needed.
713}
714
715// InMemoryBackend
716
717/// Ephemeral in-process session: nothing persisted to disk.
718///
719/// Override the granular methods to skip the clone overhead of the full
720/// snapshot path (we're already in memory, so direct field mutations are
721/// cheaper than clone→mutate→replace).
722#[derive(Default)]
723pub struct InMemoryBackend {
724    data: std::sync::Mutex<Option<PersistedSession>>,
725}
726
727impl InMemoryBackend {
728    pub fn new() -> Self {
729        Self::default()
730    }
731
732    /// Test helper: get a snapshot of the current in-memory state.
733    pub fn snapshot(&self) -> Option<PersistedSession> {
734        self.data.lock().unwrap().clone()
735    }
736}
737
738impl SessionBackend for InMemoryBackend {
739    fn save(&self, s: &PersistedSession) -> io::Result<()> {
740        *self.data.lock().unwrap() = Some(s.clone());
741        Ok(())
742    }
743
744    fn load(&self) -> io::Result<Option<PersistedSession>> {
745        Ok(self.data.lock().unwrap().clone())
746    }
747
748    fn delete(&self) -> io::Result<()> {
749        *self.data.lock().unwrap() = None;
750        Ok(())
751    }
752
753    fn name(&self) -> &str {
754        "in-memory"
755    }
756
757    // Granular overrides: cheaper than load→clone→save
758
759    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
760        let mut guard = self.data.lock().unwrap();
761        let s = guard.get_or_insert_with(PersistedSession::default);
762        if let Some(existing) = s.dcs.iter_mut().find(|d| d.dc_id == entry.dc_id) {
763            *existing = entry.clone();
764        } else {
765            s.dcs.push(entry.clone());
766        }
767        Ok(())
768    }
769
770    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
771        let mut guard = self.data.lock().unwrap();
772        let s = guard.get_or_insert_with(PersistedSession::default);
773        s.home_dc_id = dc_id;
774        Ok(())
775    }
776
777    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
778        let mut guard = self.data.lock().unwrap();
779        let s = guard.get_or_insert_with(PersistedSession::default);
780        update.apply_to(&mut s.updates_state);
781        Ok(())
782    }
783
784    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
785        let mut guard = self.data.lock().unwrap();
786        let s = guard.get_or_insert_with(PersistedSession::default);
787        if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
788            *existing = peer.clone();
789        } else {
790            s.peers.push(peer.clone());
791        }
792        Ok(())
793    }
794}
795
796// StringSessionBackend
797
/// Session backend that holds the session as a portable string
/// (URL-safe, unpadded base64 of the binary session format).
pub struct StringSessionBackend {
    data: std::sync::Mutex<String>,
}

impl StringSessionBackend {
    /// Wrap an existing session string; pass an empty string to start
    /// without a session.
    pub fn new(s: impl Into<String>) -> Self {
        let initial: String = s.into();
        StringSessionBackend {
            data: std::sync::Mutex::new(initial),
        }
    }

    /// Copy of the current session string (the last one saved, or the one
    /// passed to `new`).
    pub fn current(&self) -> String {
        let guard = self.data.lock().unwrap();
        guard.as_str().to_owned()
    }
}
814
815impl SessionBackend for StringSessionBackend {
816    fn save(&self, session: &PersistedSession) -> io::Result<()> {
817        *self.data.lock().unwrap() = session.to_string();
818        Ok(())
819    }
820
821    fn load(&self) -> io::Result<Option<PersistedSession>> {
822        let s = self.data.lock().unwrap().clone();
823        if s.trim().is_empty() {
824            return Ok(None);
825        }
826        PersistedSession::from_string(&s).map(Some)
827    }
828
829    fn delete(&self) -> io::Result<()> {
830        *self.data.lock().unwrap() = String::new();
831        Ok(())
832    }
833
834    fn name(&self) -> &str {
835        "string-session"
836    }
837}
838
839// Tests
840
841#[cfg(test)]
842mod tests {
843    use super::*;
844
845    fn make_dc(id: i32) -> DcEntry {
846        DcEntry {
847            dc_id: id,
848            addr: format!("1.2.3.{id}:443"),
849            auth_key: None,
850            first_salt: 0,
851            time_offset: 0,
852            flags: DcFlags::NONE,
853        }
854    }
855
856    fn make_peer(id: i64, hash: i64) -> CachedPeer {
857        CachedPeer {
858            id,
859            access_hash: hash,
860            is_channel: false,
861            is_chat: false,
862        }
863    }
864
865    // InMemoryBackend: basic save/load
866
867    #[test]
868    fn inmemory_load_returns_none_when_empty() {
869        let b = InMemoryBackend::new();
870        assert!(b.load().unwrap().is_none());
871    }
872
873    #[test]
874    fn inmemory_save_then_load_round_trips() {
875        let b = InMemoryBackend::new();
876        let mut s = PersistedSession::default();
877        s.home_dc_id = 3;
878        s.dcs.push(make_dc(3));
879        b.save(&s).unwrap();
880
881        let loaded = b.load().unwrap().unwrap();
882        assert_eq!(loaded.home_dc_id, 3);
883        assert_eq!(loaded.dcs.len(), 1);
884    }
885
886    #[test]
887    fn inmemory_delete_clears_state() {
888        let b = InMemoryBackend::new();
889        let mut s = PersistedSession::default();
890        s.home_dc_id = 2;
891        b.save(&s).unwrap();
892        b.delete().unwrap();
893        assert!(b.load().unwrap().is_none());
894    }
895
896    // InMemoryBackend: granular methods
897
898    #[test]
899    fn inmemory_update_dc_inserts_new() {
900        let b = InMemoryBackend::new();
901        b.update_dc(&make_dc(4)).unwrap();
902        let s = b.snapshot().unwrap();
903        assert_eq!(s.dcs.len(), 1);
904        assert_eq!(s.dcs[0].dc_id, 4);
905    }
906
907    #[test]
908    fn inmemory_update_dc_replaces_existing() {
909        let b = InMemoryBackend::new();
910        b.update_dc(&make_dc(2)).unwrap();
911        let mut updated = make_dc(2);
912        updated.addr = "9.9.9.9:443".to_string();
913        b.update_dc(&updated).unwrap();
914
915        let s = b.snapshot().unwrap();
916        assert_eq!(s.dcs.len(), 1);
917        assert_eq!(s.dcs[0].addr, "9.9.9.9:443");
918    }
919
920    #[test]
921    fn inmemory_set_home_dc() {
922        let b = InMemoryBackend::new();
923        b.set_home_dc(5).unwrap();
924        assert_eq!(b.snapshot().unwrap().home_dc_id, 5);
925    }
926
927    #[test]
928    fn inmemory_cache_peer_inserts() {
929        let b = InMemoryBackend::new();
930        b.cache_peer(&make_peer(100, 0xdeadbeef)).unwrap();
931        let s = b.snapshot().unwrap();
932        assert_eq!(s.peers.len(), 1);
933        assert_eq!(s.peers[0].id, 100);
934    }
935
936    #[test]
937    fn inmemory_cache_peer_updates_existing() {
938        let b = InMemoryBackend::new();
939        b.cache_peer(&make_peer(100, 111)).unwrap();
940        b.cache_peer(&make_peer(100, 222)).unwrap();
941        let s = b.snapshot().unwrap();
942        assert_eq!(s.peers.len(), 1);
943        assert_eq!(s.peers[0].access_hash, 222);
944    }
945
946    // UpdateStateChange
947
948    #[test]
949    fn update_state_primary() {
950        let mut snap = UpdatesStateSnap {
951            pts: 0,
952            qts: 0,
953            date: 0,
954            seq: 0,
955            channels: vec![],
956        };
957        UpdateStateChange::Primary {
958            pts: 10,
959            date: 20,
960            seq: 30,
961        }
962        .apply_to(&mut snap);
963        assert_eq!(snap.pts, 10);
964        assert_eq!(snap.date, 20);
965        assert_eq!(snap.seq, 30);
966        assert_eq!(snap.qts, 0); // untouched
967    }
968
969    #[test]
970    fn update_state_secondary() {
971        let mut snap = UpdatesStateSnap {
972            pts: 5,
973            qts: 0,
974            date: 0,
975            seq: 0,
976            channels: vec![],
977        };
978        UpdateStateChange::Secondary { qts: 99 }.apply_to(&mut snap);
979        assert_eq!(snap.qts, 99);
980        assert_eq!(snap.pts, 5); // untouched
981    }
982
983    #[test]
984    fn update_state_channel_inserts() {
985        let mut snap = UpdatesStateSnap {
986            pts: 0,
987            qts: 0,
988            date: 0,
989            seq: 0,
990            channels: vec![],
991        };
992        UpdateStateChange::Channel { id: 12345, pts: 42 }.apply_to(&mut snap);
993        assert_eq!(snap.channels, vec![(12345, 42)]);
994    }
995
996    #[test]
997    fn update_state_channel_updates_existing() {
998        let mut snap = UpdatesStateSnap {
999            pts: 0,
1000            qts: 0,
1001            date: 0,
1002            seq: 0,
1003            channels: vec![(12345, 10), (67890, 5)],
1004        };
1005        UpdateStateChange::Channel { id: 12345, pts: 99 }.apply_to(&mut snap);
1006        // First channel updated, second untouched
1007        assert_eq!(snap.channels[0], (12345, 99));
1008        assert_eq!(snap.channels[1], (67890, 5));
1009    }
1010
1011    #[test]
1012    fn apply_update_state_via_backend() {
1013        let b = InMemoryBackend::new();
1014        b.apply_update_state(UpdateStateChange::Primary {
1015            pts: 7,
1016            date: 8,
1017            seq: 9,
1018        })
1019        .unwrap();
1020        let s = b.snapshot().unwrap();
1021        assert_eq!(s.updates_state.pts, 7);
1022    }
1023
1024    // Default impl (BinaryFileBackend trait shape via InMemory smoke)
1025
1026    #[test]
1027    fn default_update_dc_via_trait_object() {
1028        let b: Box<dyn SessionBackend> = Box::new(InMemoryBackend::new());
1029        b.update_dc(&make_dc(1)).unwrap();
1030        b.update_dc(&make_dc(2)).unwrap();
1031        // Can't call snapshot() on trait object, but save/load must be consistent
1032        let loaded = b.load().unwrap().unwrap();
1033        assert_eq!(loaded.dcs.len(), 2);
1034    }
1035
1036    // IPv6 tests
1037
1038    fn make_dc_v6(id: i32) -> DcEntry {
1039        DcEntry {
1040            dc_id: id,
1041            addr: format!("[2001:b28:f23d:f00{}::a]:443", id),
1042            auth_key: None,
1043            first_salt: 0,
1044            time_offset: 0,
1045            flags: DcFlags::IPV6,
1046        }
1047    }
1048
1049    #[test]
1050    fn dc_entry_from_parts_ipv4() {
1051        let dc = DcEntry::from_parts(1, "149.154.175.53", 443, DcFlags::NONE);
1052        assert_eq!(dc.addr, "149.154.175.53:443");
1053        assert!(!dc.is_ipv6());
1054        let sa = dc.socket_addr().unwrap();
1055        assert_eq!(sa.port(), 443);
1056    }
1057
1058    #[test]
1059    fn dc_entry_from_parts_ipv6() {
1060        let dc = DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6);
1061        assert_eq!(dc.addr, "[2001:b28:f23d:f001::a]:443");
1062        assert!(dc.is_ipv6());
1063        let sa = dc.socket_addr().unwrap();
1064        assert_eq!(sa.port(), 443);
1065    }
1066
1067    #[test]
1068    fn persisted_session_dc_for_prefers_ipv6() {
1069        let mut s = PersistedSession::default();
1070        s.dcs.push(make_dc(2)); // IPv4
1071        s.dcs.push(make_dc_v6(2)); // IPv6
1072
1073        let v6 = s.dc_for(2, true).unwrap();
1074        assert!(v6.is_ipv6());
1075
1076        let v4 = s.dc_for(2, false).unwrap();
1077        assert!(!v4.is_ipv6());
1078    }
1079
1080    #[test]
1081    fn persisted_session_dc_for_falls_back_when_only_ipv4() {
1082        let mut s = PersistedSession::default();
1083        s.dcs.push(make_dc(3)); // IPv4 only
1084
1085        // Asking for IPv6 should fall back to IPv4
1086        let dc = s.dc_for(3, true).unwrap();
1087        assert!(!dc.is_ipv6());
1088    }
1089
1090    #[test]
1091    fn persisted_session_all_dcs_for_returns_both() {
1092        let mut s = PersistedSession::default();
1093        s.dcs.push(make_dc(1));
1094        s.dcs.push(make_dc_v6(1));
1095        s.dcs.push(make_dc(2));
1096
1097        assert_eq!(s.all_dcs_for(1).count(), 2);
1098        assert_eq!(s.all_dcs_for(2).count(), 1);
1099        assert_eq!(s.all_dcs_for(5).count(), 0);
1100    }
1101
1102    #[test]
1103    fn inmemory_ipv4_and_ipv6_coexist() {
1104        let b = InMemoryBackend::new();
1105        b.update_dc(&make_dc(2)).unwrap(); // IPv4
1106        b.update_dc(&make_dc_v6(2)).unwrap(); // IPv6
1107
1108        let s = b.snapshot().unwrap();
1109        // Both entries must survive they have different flags
1110        assert_eq!(s.dcs.iter().filter(|d| d.dc_id == 2).count(), 2);
1111    }
1112
1113    #[test]
1114    fn binary_roundtrip_ipv4_and_ipv6() {
1115        let mut s = PersistedSession::default();
1116        s.home_dc_id = 2;
1117        s.dcs.push(make_dc(2));
1118        s.dcs.push(make_dc_v6(2));
1119
1120        let bytes = s.to_bytes();
1121        let loaded = PersistedSession::from_bytes(&bytes).unwrap();
1122        assert_eq!(loaded.dcs.len(), 2);
1123        assert_eq!(loaded.dcs.iter().filter(|d| d.is_ipv6()).count(), 1);
1124        assert_eq!(loaded.dcs.iter().filter(|d| !d.is_ipv6()).count(), 1);
1125    }
1126}
1127
1128// SqliteBackend
1129
/// SQLite-backed session (via `rusqlite`).
///
/// Enabled with the `sqlite-session` Cargo feature.
///
/// # Schema
///
/// Five tables are created when the database is opened (every statement is
/// `CREATE TABLE IF NOT EXISTS`, so re-opening is safe):
///
/// | Table          | Purpose                                                     |
/// |----------------|-------------------------------------------------------------|
/// | `meta`         | `home_dc_id` and future scalar values                       |
/// | `dcs`          | One row per `(dc_id, flags)` pair (IPv4/IPv6 entries coexist) |
/// | `update_state` | Single-row pts / qts / date / seq                           |
/// | `channel_pts`  | Per-channel pts                                             |
/// | `peers`        | Access-hash cache (`id`, `access_hash`, `is_channel`)       |
///
/// The `peers` table has no `is_chat` column, so loaded peers always have
/// `is_chat = false`. `min_peers` is not persisted by this backend.
///
/// # Granular writes
///
/// The [`SessionBackend`] extension methods (`update_dc`, `set_home_dc`,
/// `apply_update_state`, `cache_peer`) write only the affected rows instead of
/// the default load-mutate-save round-trip. Most are a single upsert.
/// `apply_update_state(UpdateStateChange::All(..))` is the exception: it
/// rewrites the `update_state` row and the whole `channel_pts` table.
///
/// The connection is held behind a `std::sync::Mutex`, so calls on one
/// backend are serialized.
#[cfg(feature = "sqlite-session")]
pub struct SqliteBackend {
    conn: std::sync::Mutex<rusqlite::Connection>,
    label: String,
}
1157
#[cfg(feature = "sqlite-session")]
impl SqliteBackend {
    /// Schema applied every time a connection is opened.
    ///
    /// Every table is created with `IF NOT EXISTS`, so applying it to an
    /// existing database leaves the tables untouched. The PRAGMAs select WAL
    /// journaling and `synchronous = NORMAL`.
    const SCHEMA: &'static str = "
        PRAGMA journal_mode = WAL;
        PRAGMA synchronous  = NORMAL;

        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT    PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );

        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );

        CREATE TABLE IF NOT EXISTS peers (
            id           INTEGER PRIMARY KEY,
            access_hash  INTEGER NOT NULL,
            is_channel   INTEGER NOT NULL DEFAULT 0
        );
    ";

    /// Open (or create) the SQLite database at `path`.
    ///
    /// # Errors
    /// Returns an `io::Error` wrapping the `rusqlite` error if the file cannot
    /// be opened or the schema cannot be applied.
    pub fn open(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let conn = rusqlite::Connection::open(&path).map_err(Self::map_err)?;
        Self::from_connection(conn, label)
    }

    /// Open an in-process SQLite database (useful for tests).
    pub fn in_memory() -> io::Result<Self> {
        let conn = rusqlite::Connection::open_in_memory().map_err(Self::map_err)?;
        Self::from_connection(conn, ":memory:".into())
    }

    /// Apply [`Self::SCHEMA`] to a freshly opened connection and wrap it.
    fn from_connection(conn: rusqlite::Connection, label: String) -> io::Result<Self> {
        conn.execute_batch(Self::SCHEMA).map_err(Self::map_err)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label,
        })
    }

    fn map_err(e: rusqlite::Error) -> io::Error {
        io::Error::other(e)
    }

    /// Convert a stored `auth_key` blob into a fixed-size key.
    ///
    /// A NULL blob, or one whose length is not exactly 256 bytes, yields `None`
    /// (the key is treated as missing rather than failing the whole load).
    fn decode_auth_key(blob: Option<Vec<u8>>) -> Option<[u8; 256]> {
        let bytes = blob?;
        if bytes.len() != 256 {
            return None;
        }
        let mut key = [0u8; 256];
        key.copy_from_slice(&bytes);
        Some(key)
    }

    /// Read the full session out of the database.
    ///
    /// A missing `meta` row yields `home_dc_id = 0`, and a missing
    /// `update_state` row yields `UpdatesStateSnap::default()`. Any other
    /// database or column-decoding error is returned rather than silently
    /// dropped. The schema has no `is_chat` column, so every loaded peer has
    /// `is_chat = false`, and `min_peers` is always empty.
    fn read_session(conn: &rusqlite::Connection) -> io::Result<PersistedSession> {
        use rusqlite::OptionalExtension;

        // home_dc_id: "no row yet" is expected; real errors are not.
        let home_dc_id: i32 = conn
            .query_row("SELECT value FROM meta WHERE key = 'home_dc_id'", [], |r| {
                r.get(0)
            })
            .optional()
            .map_err(Self::map_err)?
            .unwrap_or(0);

        // dcs
        let mut dc_stmt = conn
            .prepare("SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs")
            .map_err(Self::map_err)?;
        let dcs = dc_stmt
            .query_map([], |row| {
                Ok(DcEntry {
                    dc_id: row.get(0)?,
                    flags: DcFlags(row.get(1)?),
                    addr: row.get(2)?,
                    auth_key: Self::decode_auth_key(row.get(3)?),
                    first_salt: row.get(4)?,
                    time_offset: row.get(5)?,
                })
            })
            .map_err(Self::map_err)?
            .collect::<Result<Vec<_>, _>>()
            .map_err(Self::map_err)?;

        // update_state
        let updates_state = conn
            .query_row(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                [],
                |r| {
                    Ok(UpdatesStateSnap {
                        pts: r.get(0)?,
                        qts: r.get(1)?,
                        date: r.get(2)?,
                        seq: r.get(3)?,
                        channels: vec![],
                    })
                },
            )
            .optional()
            .map_err(Self::map_err)?
            .unwrap_or_default();

        // channel_pts
        let mut ch_stmt = conn
            .prepare("SELECT channel_id, pts FROM channel_pts")
            .map_err(Self::map_err)?;
        let channels: Vec<(i64, i32)> = ch_stmt
            .query_map([], |r| Ok((r.get::<_, i64>(0)?, r.get::<_, i32>(1)?)))
            .map_err(Self::map_err)?
            .collect::<Result<Vec<_>, _>>()
            .map_err(Self::map_err)?;

        // peers
        let mut peer_stmt = conn
            .prepare("SELECT id, access_hash, is_channel FROM peers")
            .map_err(Self::map_err)?;
        let peers: Vec<CachedPeer> = peer_stmt
            .query_map([], |r| {
                Ok(CachedPeer {
                    id: r.get(0)?,
                    access_hash: r.get(1)?,
                    is_channel: r.get::<_, i32>(2)? != 0,
                    is_chat: false,
                })
            })
            .map_err(Self::map_err)?
            .collect::<Result<Vec<_>, _>>()
            .map_err(Self::map_err)?;

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers: Vec::new(),
        })
    }

    /// Write the full session into the database inside a single
    /// `IMMEDIATE` transaction.
    ///
    /// `home_dc_id` is upserted into `meta`. `dcs`, `channel_pts` and
    /// `peers` are replaced wholesale. The `update_state` row is upserted
    /// without moving `pts`/`qts` backwards.
    ///
    /// The transaction rolls back when dropped without `commit()`. An error
    /// part-way through (including a failed COMMIT) therefore leaves the
    /// database unchanged, and does not leave the shared connection stuck
    /// inside an open transaction.
    fn write_session(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        let tx = rusqlite::Transaction::new_unchecked(
            conn,
            rusqlite::TransactionBehavior::Immediate,
        )
        .map_err(Self::map_err)?;

        tx.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![s.home_dc_id],
        )
        .map_err(Self::map_err)?;

        // Replace all DCs.
        tx.execute("DELETE FROM dcs", []).map_err(Self::map_err)?;
        {
            let mut insert_dc = tx
                .prepare(
                    "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                     VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                )
                .map_err(Self::map_err)?;
            for d in &s.dcs {
                insert_dc
                    .execute(rusqlite::params![
                        d.dc_id,
                        d.flags.0,
                        d.addr,
                        d.auth_key.as_ref().map(|k| k.as_ref()),
                        d.first_salt,
                        d.time_offset,
                    ])
                    .map_err(Self::map_err)?;
            }
        }

        // update_state: pts and qts are monotonic, so write_session() must never
        // move them backwards. MAX() ensures a stale snapshot cannot overwrite
        // a fresher value committed by apply_update_state().
        let us = &s.updates_state;
        tx.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1, ?1, ?2, ?3, ?4)
             ON CONFLICT(id) DO UPDATE SET
               pts  = MAX(excluded.pts,  update_state.pts),
               qts  = MAX(excluded.qts,  update_state.qts),
               date = excluded.date,
               seq  = excluded.seq",
            rusqlite::params![us.pts, us.qts, us.date, us.seq],
        )
        .map_err(Self::map_err)?;

        tx.execute("DELETE FROM channel_pts", [])
            .map_err(Self::map_err)?;
        {
            let mut insert_channel = tx
                .prepare("INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)")
                .map_err(Self::map_err)?;
            for &(cid, cpts) in &us.channels {
                insert_channel
                    .execute(rusqlite::params![cid, cpts])
                    .map_err(Self::map_err)?;
            }
        }

        // peers
        tx.execute("DELETE FROM peers", [])
            .map_err(Self::map_err)?;
        {
            let mut insert_peer = tx
                .prepare("INSERT INTO peers (id, access_hash, is_channel) VALUES (?1, ?2, ?3)")
                .map_err(Self::map_err)?;
            for p in &s.peers {
                insert_peer
                    .execute(rusqlite::params![p.id, p.access_hash, p.is_channel as i32])
                    .map_err(Self::map_err)?;
            }
        }

        tx.commit().map_err(Self::map_err)
    }
}
1398
#[cfg(feature = "sqlite-session")]
impl SessionBackend for SqliteBackend {
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        Self::write_session(&conn, session)
    }

    /// Load the stored session, or `None` if nothing has been persisted.
    ///
    /// Every table is checked, not just `meta`. The granular writers
    /// (`update_dc`, `cache_peer`, `apply_update_state`) never touch `meta`,
    /// so a session built only through them would otherwise be reported as
    /// absent.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = self.conn.lock().unwrap();
        let has_data: bool = conn
            .query_row(
                "SELECT EXISTS(SELECT 1 FROM meta)
                     OR EXISTS(SELECT 1 FROM dcs)
                     OR EXISTS(SELECT 1 FROM update_state)
                     OR EXISTS(SELECT 1 FROM channel_pts)
                     OR EXISTS(SELECT 1 FROM peers)",
                [],
                |r| r.get(0),
            )
            .map_err(Self::map_err)?;
        if !has_data {
            return Ok(None);
        }
        Self::read_session(&conn).map(Some)
    }

    /// Remove every persisted row in one transaction.
    ///
    /// The transaction rolls back on drop, so a failing DELETE leaves the data
    /// intact and the connection free of an open transaction.
    fn delete(&self) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        let tx = rusqlite::Transaction::new_unchecked(
            &conn,
            rusqlite::TransactionBehavior::Immediate,
        )
        .map_err(Self::map_err)?;
        tx.execute_batch(
            "DELETE FROM meta;
             DELETE FROM dcs;
             DELETE FROM update_state;
             DELETE FROM channel_pts;
             DELETE FROM peers;",
        )
        .map_err(Self::map_err)?;
        tx.commit().map_err(Self::map_err)
    }

    fn name(&self) -> &str {
        &self.label
    }

    // Granular overrides (targeted SQL, no full round-trip)

    /// Upsert one DC row keyed by `(dc_id, flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
             ON CONFLICT(dc_id, flags) DO UPDATE SET
               addr        = excluded.addr,
               auth_key    = excluded.auth_key,
               first_salt  = excluded.first_salt,
               time_offset = excluded.time_offset",
            rusqlite::params![
                entry.dc_id,
                entry.flags.0,
                entry.addr,
                entry.auth_key.as_ref().map(|k| k.as_ref()),
                entry.first_salt,
                entry.time_offset,
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![dc_id],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        match update {
            UpdateStateChange::All(snap) => {
                // Several statements: run them in one transaction (rolled back on
                // drop) so a failure cannot leave `channel_pts` half-rewritten.
                let tx = rusqlite::Transaction::new_unchecked(
                    &conn,
                    rusqlite::TransactionBehavior::Immediate,
                )
                .map_err(Self::map_err)?;
                tx.execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
                     ON CONFLICT(id) DO UPDATE SET
                       pts=excluded.pts, qts=excluded.qts,
                       date=excluded.date, seq=excluded.seq",
                    rusqlite::params![snap.pts, snap.qts, snap.date, snap.seq],
                )
                .map_err(Self::map_err)?;
                tx.execute("DELETE FROM channel_pts", [])
                    .map_err(Self::map_err)?;
                {
                    let mut insert_channel = tx
                        .prepare("INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)")
                        .map_err(Self::map_err)?;
                    for &(cid, cpts) in &snap.channels {
                        insert_channel
                            .execute(rusqlite::params![cid, cpts])
                            .map_err(Self::map_err)?;
                    }
                }
                tx.commit().map_err(Self::map_err)
            }
            UpdateStateChange::Primary { pts, date, seq } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,0,?2,?3)
                     ON CONFLICT(id) DO UPDATE SET pts=excluded.pts, date=excluded.date,
                     seq=excluded.seq",
                    rusqlite::params![pts, date, seq],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Secondary { qts } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,0,?1,0,0)
                     ON CONFLICT(id) DO UPDATE SET qts = excluded.qts",
                    rusqlite::params![qts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Channel { id, pts } => conn
                .execute(
                    "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)
                     ON CONFLICT(channel_id) DO UPDATE SET pts = excluded.pts",
                    rusqlite::params![id, pts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
        }
    }

    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1, ?2, ?3)
             ON CONFLICT(id) DO UPDATE SET
               access_hash = excluded.access_hash,
               is_channel  = excluded.is_channel",
            rusqlite::params![peer.id, peer.access_hash, peer.is_channel as i32],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }
}
1536
1537// LibSqlBackend
1538
/// libSQL-backed session (Turso / embedded replica / in-process).
///
/// Enabled with the `libsql-session` Cargo feature.
///
/// The libSQL API is async while [`SessionBackend`] methods are sync. Each call
/// is therefore driven to completion with `tokio::runtime::Handle::block_on`,
/// which requires the calling thread to have a Tokio runtime context.
///
/// NOTE(review): per Tokio's documentation, `Handle::block_on` panics when it is
/// invoked from within an asynchronous execution context (for example directly
/// inside an `async fn` running on a runtime worker). Call this backend from
/// `tokio::task::spawn_blocking`, or from another non-async thread that has
/// entered the runtime.
///
/// # Connecting
///
/// | Mode             | Constructor                                     |
/// |------------------|-------------------------------------------------|
/// | Local file       | `LibSqlBackend::open_local(path)`               |
/// | In-memory        | `LibSqlBackend::in_memory()`                    |
/// | Turso remote     | `LibSqlBackend::open_remote(url, token)`        |
/// | Embedded replica | `LibSqlBackend::open_replica(path, url, token)` |
#[cfg(feature = "libsql-session")]
pub struct LibSqlBackend {
    /// Shared connection behind an async mutex. This matches what every
    /// constructor stores (`Arc::new(tokio::sync::Mutex::new(conn))`).
    conn: std::sync::Arc<tokio::sync::Mutex<libsql::Connection>>,
    label: String,
}
1561
1562#[cfg(feature = "libsql-session")]
1563impl LibSqlBackend {
    /// Schema applied to every new connection by the constructors (via
    /// `apply_schema`). Every table is `CREATE TABLE IF NOT EXISTS`, so applying
    /// it to an existing database is safe.
    ///
    /// Table-for-table it matches `SqliteBackend::SCHEMA`, but it omits the
    /// `journal_mode` / `synchronous` PRAGMAs.
    /// NOTE(review): presumably omitted because they do not apply to remote
    /// Turso connections; confirm before adding them here.
    const SCHEMA: &'static str = "
        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT    PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );
        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );
        CREATE TABLE IF NOT EXISTS peers (
            id          INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel  INTEGER NOT NULL DEFAULT 0
        );
    ";
1595
1596    fn block<F, T>(fut: F) -> io::Result<T>
1597    where
1598        F: std::future::Future<Output = Result<T, libsql::Error>>,
1599    {
1600        tokio::runtime::Handle::current()
1601            .block_on(fut)
1602            .map_err(io::Error::other)
1603    }
1604
1605    async fn apply_schema(conn: &libsql::Connection) -> Result<(), libsql::Error> {
1606        conn.execute_batch(Self::SCHEMA).await
1607    }
1608
1609    /// Open a local file database.
1610    pub fn open_local(path: impl Into<PathBuf>) -> io::Result<Self> {
1611        let path = path.into();
1612        let label = path.display().to_string();
1613        let db = Self::block(async { libsql::Builder::new_local(path).build().await })?;
1614        let conn = Self::block(async { db.connect() }).map_err(io::Error::other)?;
1615        Self::block(Self::apply_schema(&conn))?;
1616        Ok(Self {
1617            conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
1618            label,
1619        })
1620    }
1621
1622    /// Open an in-process in-memory database (useful for tests).
1623    pub fn in_memory() -> io::Result<Self> {
1624        let db = Self::block(async { libsql::Builder::new_local(":memory:").build().await })?;
1625        let conn = Self::block(async { db.connect() }).map_err(io::Error::other)?;
1626        Self::block(Self::apply_schema(&conn))?;
1627        Ok(Self {
1628            conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
1629            label: ":memory:".into(),
1630        })
1631    }
1632
1633    /// Connect to a remote Turso database.
1634    pub fn open_remote(url: impl Into<String>, auth_token: impl Into<String>) -> io::Result<Self> {
1635        let url = url.into();
1636        let label = url.clone();
1637        let db = Self::block(async {
1638            libsql::Builder::new_remote(url, auth_token.into())
1639                .build()
1640                .await
1641        })?;
1642        let conn = Self::block(async { db.connect() }).map_err(io::Error::other)?;
1643        Self::block(Self::apply_schema(&conn))?;
1644        Ok(Self {
1645            conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
1646            label,
1647        })
1648    }
1649
1650    /// Open an embedded replica (local file + Turso remote sync).
1651    pub fn open_replica(
1652        path: impl Into<PathBuf>,
1653        url: impl Into<String>,
1654        auth_token: impl Into<String>,
1655    ) -> io::Result<Self> {
1656        let path = path.into();
1657        let label = format!("{} (replica of {})", path.display(), url.into());
1658        let db = Self::block(async {
1659            libsql::Builder::new_remote_replica(path, url.into(), auth_token.into())
1660                .build()
1661                .await
1662        })?;
1663        let conn = Self::block(async { db.connect() }).map_err(io::Error::other)?;
1664        Self::block(Self::apply_schema(&conn))?;
1665        Ok(Self {
1666            conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
1667            label,
1668        })
1669    }
1670
1671    async fn read_session_async(
1672        conn: &libsql::Connection,
1673    ) -> Result<PersistedSession, libsql::Error> {
1674        use libsql::de;
1675
1676        // home_dc_id
1677        let home_dc_id: i32 = conn
1678            .query("SELECT value FROM meta WHERE key = 'home_dc_id'", ())
1679            .await?
1680            .next()
1681            .await?
1682            .map(|r| r.get::<i32>(0))
1683            .transpose()?
1684            .unwrap_or(0);
1685
1686        // dcs
1687        let mut rows = conn
1688            .query(
1689                "SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs",
1690                (),
1691            )
1692            .await?;
1693        let mut dcs = Vec::new();
1694        while let Some(row) = rows.next().await? {
1695            let dc_id: i32 = row.get(0)?;
1696            let flags_raw: u8 = row.get::<i64>(1)? as u8;
1697            let addr: String = row.get(2)?;
1698            let key_blob: Option<Vec<u8>> = row.get(3)?;
1699            let first_salt: i64 = row.get(4)?;
1700            let time_offset: i32 = row.get(5)?;
1701            let auth_key = match key_blob {
1702                Some(b) if b.len() == 256 => {
1703                    let mut k = [0u8; 256];
1704                    k.copy_from_slice(&b);
1705                    Some(k)
1706                }
1707                Some(b) => {
1708                    return Err(libsql::Error::Misuse(format!(
1709                        "auth_key blob must be 256 bytes, got {}",
1710                        b.len()
1711                    )));
1712                }
1713                None => None,
1714            };
1715            dcs.push(DcEntry {
1716                dc_id,
1717                addr,
1718                auth_key,
1719                first_salt,
1720                time_offset,
1721                flags: DcFlags(flags_raw),
1722            });
1723        }
1724
1725        // update_state
1726        let mut us_row = conn
1727            .query(
1728                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
1729                (),
1730            )
1731            .await?;
1732        let updates_state = if let Some(r) = us_row.next().await? {
1733            UpdatesStateSnap {
1734                pts: r.get(0)?,
1735                qts: r.get(1)?,
1736                date: r.get(2)?,
1737                seq: r.get(3)?,
1738                channels: vec![],
1739            }
1740        } else {
1741            UpdatesStateSnap::default()
1742        };
1743
1744        // channel_pts
1745        let mut ch_rows = conn
1746            .query("SELECT channel_id, pts FROM channel_pts", ())
1747            .await?;
1748        let mut channels = Vec::new();
1749        while let Some(r) = ch_rows.next().await? {
1750            channels.push((r.get::<i64>(0)?, r.get::<i32>(1)?));
1751        }
1752
1753        // peers
1754        let mut peer_rows = conn
1755            .query("SELECT id, access_hash, is_channel FROM peers", ())
1756            .await?;
1757        let mut peers = Vec::new();
1758        while let Some(r) = peer_rows.next().await? {
1759            peers.push(CachedPeer {
1760                id: r.get(0)?,
1761                access_hash: r.get(1)?,
1762                is_channel: r.get::<i32>(2)? != 0,
1763                is_chat: false,
1764            });
1765        }
1766
1767        Ok(PersistedSession {
1768            home_dc_id,
1769            dcs,
1770            updates_state: UpdatesStateSnap {
1771                channels,
1772                ..updates_state
1773            },
1774            peers,
1775            min_peers: Vec::new(),
1776        })
1777    }
1778
1779    async fn write_session_async(
1780        conn: &libsql::Connection,
1781        s: &PersistedSession,
1782    ) -> Result<(), libsql::Error> {
1783        conn.execute_batch("BEGIN IMMEDIATE").await?;
1784
1785        conn.execute(
1786            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
1787             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
1788            libsql::params![s.home_dc_id],
1789        )
1790        .await?;
1791
1792        conn.execute("DELETE FROM dcs", ()).await?;
1793        for d in &s.dcs {
1794            conn.execute(
1795                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
1796                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
1797                libsql::params![
1798                    d.dc_id,
1799                    d.flags.0 as i64,
1800                    d.addr.clone(),
1801                    d.auth_key.map(|k| k.to_vec()),
1802                    d.first_salt,
1803                    d.time_offset,
1804                ],
1805            )
1806            .await?;
1807        }
1808
1809        let us = &s.updates_state;
1810        conn.execute(
1811            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
1812             ON CONFLICT(id) DO UPDATE SET
1813               pts  = MAX(excluded.pts,  update_state.pts),
1814               qts  = MAX(excluded.qts,  update_state.qts),
1815               date = excluded.date,
1816               seq  = excluded.seq",
1817            libsql::params![us.pts, us.qts, us.date, us.seq],
1818        )
1819        .await?;
1820
1821        conn.execute("DELETE FROM channel_pts", ()).await?;
1822        for &(cid, cpts) in &us.channels {
1823            conn.execute(
1824                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1,?2)",
1825                libsql::params![cid, cpts],
1826            )
1827            .await?;
1828        }
1829
1830        conn.execute("DELETE FROM peers", ()).await?;
1831        for p in &s.peers {
1832            conn.execute(
1833                "INSERT INTO peers (id, access_hash, is_channel) VALUES (?1,?2,?3)",
1834                libsql::params![p.id, p.access_hash, p.is_channel as i32],
1835            )
1836            .await?;
1837        }
1838
1839        conn.execute_batch("COMMIT").await
1840    }
1841}
1842
#[cfg(feature = "libsql-session")]
impl SessionBackend for LibSqlBackend {
    /// Persist a full snapshot of `session`, replacing whatever is stored.
    ///
    /// The session is cloned and moved into the async block passed to
    /// `Self::block`. The write itself goes through `write_session_async`,
    /// which wraps all statements in one transaction.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        let conn = self.conn.clone();
        let session = session.clone();
        Self::block(async move {
            let conn = conn.lock().await;
            Self::write_session_async(&conn, &session).await
        })
    }

    /// Load the stored session, or `Ok(None)` if nothing has been saved.
    ///
    /// "Nothing saved" means the `meta` table is empty. The emptiness check
    /// and the full read run under a single acquisition of the connection
    /// lock. That way a `delete`/`save` issued through this backend cannot
    /// interleave between them.
    ///
    /// # Errors
    ///
    /// Any query or column-decoding error is propagated. It is not treated as
    /// "no session", which could otherwise lead a caller to start over and
    /// overwrite valid stored data.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = self.conn.clone();
        Self::block(async move {
            let conn = conn.lock().await;

            let meta_rows: i64 = {
                let mut rows = conn.query("SELECT COUNT(*) FROM meta", ()).await?;
                match rows.next().await? {
                    Some(row) => row.get(0)?,
                    None => 0,
                }
            };
            if meta_rows == 0 {
                return Ok(None);
            }

            Self::read_session_async(&conn).await.map(Some)
        })
    }

    /// Remove every row from all session tables in one transaction.
    ///
    /// If the batch fails part-way, a best-effort `ROLLBACK` is issued so the
    /// connection is not left inside the `BEGIN IMMEDIATE` transaction.
    fn delete(&self) -> io::Result<()> {
        let conn = self.conn.clone();
        Self::block(async move {
            let conn = conn.lock().await;
            let outcome = conn.execute_batch(
                "BEGIN IMMEDIATE;
                 DELETE FROM meta;
                 DELETE FROM dcs;
                 DELETE FROM update_state;
                 DELETE FROM channel_pts;
                 DELETE FROM peers;
                 COMMIT;",
            )
            .await;
            if outcome.is_err() {
                // The ROLLBACK error is ignored: it only fails when no
                // transaction is active (e.g. BEGIN itself failed).
                let _ = conn.execute_batch("ROLLBACK").await;
            }
            outcome
        })
    }

    fn name(&self) -> &str {
        &self.label
    }

    // Granular overrides

    /// Insert or update a single DC row keyed by `(dc_id, flags)`.
    ///
    /// This relies on the `dcs` schema declaring a UNIQUE/PRIMARY KEY over
    /// `(dc_id, flags)`, which the `ON CONFLICT` target requires.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = self.conn.clone();
        // Bound in column order so `?1..?6` line up with the column list.
        let (dc_id, flags, addr, key, salt, off) = (
            entry.dc_id,
            i64::from(entry.flags.0),
            entry.addr.clone(),
            entry.auth_key.map(|k| k.to_vec()),
            entry.first_salt,
            entry.time_offset,
        );
        Self::block(async move {
            let conn = conn.lock().await;
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1,?2,?3,?4,?5,?6)
                 ON CONFLICT(dc_id, flags) DO UPDATE SET
                   addr=excluded.addr, auth_key=excluded.auth_key,
                   first_salt=excluded.first_salt, time_offset=excluded.time_offset",
                libsql::params![dc_id, flags, addr, key, salt, off],
            )
            .await
            .map(|_| ())
        })
    }

    /// Upsert `meta.home_dc_id`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = self.conn.clone();
        Self::block(async move {
            let conn = conn.lock().await;
            conn.execute(
                "INSERT INTO meta (key, value) VALUES ('home_dc_id',?1)
                 ON CONFLICT(key) DO UPDATE SET value=excluded.value",
                libsql::params![dc_id],
            )
            .await
            .map(|_| ())
        })
    }

    /// Apply one incremental change to the stored update state.
    ///
    /// - `All`: overwrite the `update_state` row and replace every
    ///   `channel_pts` row. This runs inside one `BEGIN IMMEDIATE`
    ///   transaction that is rolled back on failure, so channel pts are never
    ///   left half-deleted.
    /// - `Primary`: set `pts`, `date` and `seq`. A freshly inserted row gets
    ///   `qts = 0`.
    /// - `Secondary`: set `qts`. A freshly inserted row gets zeros for the
    ///   other counters.
    /// - `Channel`: upsert the `pts` of one channel.
    ///
    /// Unlike `write_session_async`, these upserts overwrite `pts`/`qts`
    /// directly instead of keeping the maximum.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        let conn = self.conn.clone();
        Self::block(async move {
            let conn = conn.lock().await;
            match update {
                UpdateStateChange::All(snap) => {
                    conn.execute_batch("BEGIN IMMEDIATE").await?;
                    // The writes run in their own future so that any failure
                    // can be followed by a ROLLBACK. Otherwise `?` would return
                    // early and leave the transaction open.
                    let writes = async {
                        conn.execute(
                            "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,?2,?3,?4)
                             ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,qts=excluded.qts,
                             date=excluded.date,seq=excluded.seq",
                            libsql::params![snap.pts, snap.qts, snap.date, snap.seq],
                        )
                        .await?;
                        conn.execute("DELETE FROM channel_pts", ()).await?;
                        for &(cid, cpts) in &snap.channels {
                            conn.execute(
                                "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)",
                                libsql::params![cid, cpts],
                            )
                            .await?;
                        }
                        Ok::<(), libsql::Error>(())
                    };
                    let outcome = match writes.await {
                        Ok(()) => conn.execute_batch("COMMIT").await,
                        Err(e) => Err(e),
                    };
                    if outcome.is_err() {
                        // Best effort; fails harmlessly if no transaction is active.
                        let _ = conn.execute_batch("ROLLBACK").await;
                    }
                    outcome
                }
                UpdateStateChange::Primary { pts, date, seq } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,0,?2,?3)
                         ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,date=excluded.date,
                         seq=excluded.seq",
                        libsql::params![pts, date, seq],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Secondary { qts } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,0,?1,0,0)
                         ON CONFLICT(id) DO UPDATE SET qts=excluded.qts",
                        libsql::params![qts],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Channel { id, pts } => conn
                    .execute(
                        "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)
                         ON CONFLICT(channel_id) DO UPDATE SET pts=excluded.pts",
                        libsql::params![id, pts],
                    )
                    .await
                    .map(|_| ()),
            }
        })
    }

    /// Insert or refresh one peer's access hash.
    ///
    /// Only `id`, `access_hash` and `is_channel` are written to the `peers`
    /// table; `is_chat` is not stored by this backend.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = self.conn.clone();
        let (id, hash, is_ch) = (peer.id, peer.access_hash, i32::from(peer.is_channel));
        Self::block(async move {
            let conn = conn.lock().await;
            conn.execute(
                "INSERT INTO peers (id,access_hash,is_channel) VALUES (?1,?2,?3)
                 ON CONFLICT(id) DO UPDATE SET
                   access_hash=excluded.access_hash,
                   is_channel=excluded.is_channel",
                libsql::params![id, hash, is_ch],
            )
            .await
            .map(|_| ())
        })
    }
}