Skip to main content

ferogram_session/
lib.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2//
3// ferogram: async Telegram MTProto client in Rust
4// https://github.com/ankit-chaubey/ferogram
5//
6// Licensed under either the MIT License or the Apache License 2.0.
7// See the LICENSE-MIT or LICENSE-APACHE file in this repository:
8// https://github.com/ankit-chaubey/ferogram
9//
10// Feel free to use, modify, and share this code.
11// Please keep this notice when redistributing.
12
13#![deny(unsafe_code)]
14#![cfg_attr(docsrs, feature(doc_cfg))]
15//! Session persistence types and storage backends for ferogram.
16//!
17//! This crate is part of [ferogram](https://crates.io/crates/ferogram), an async Rust
18//! MTProto client built by [Ankit Chaubey](https://github.com/ankit-chaubey).
19//!
20//! - Channel: [t.me/Ferogram](https://t.me/Ferogram)
21//! - Chat: [t.me/FerogramChat](https://t.me/FerogramChat)
22//!
23//! Most users do not use this crate directly. The `ferogram` crate wires it up
24//! automatically when you call `Client::builder().session(...)` or
25//! `.session_string(...)`.
26//!
27//! # What's in here
28//!
29//! - [`PersistedSession`]: the serializable session struct. Holds the DC table
30//!   (one `AuthKey` + salt + time offset per DC), update sequence counters
31//!   (PTS, QTS, date, seq), and the peer access-hash cache.
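//! - [`SessionBackend`]: the trait all backends implement. Required methods:
//!   `save(&PersistedSession)`, `load() -> io::Result<Option<PersistedSession>>`,
//!   `delete()` and `name()`; granular helpers such as `update_dc` and
//!   `cache_peer` come with default implementations.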
34//! - [`BinaryFileBackend`]: stores the session as a binary file on disk.
35//!   Good for bots and scripts. No extra dependencies.
36//! - [`InMemoryBackend`]: keeps everything in memory. Nothing survives process
37//!   exit. Used for tests and ephemeral tasks.
38//! - [`StringSessionBackend`]: serializes the session to a base64 string.
39//!   Useful for serverless environments where you store state in an env var or
40//!   database column. Load it with `Client::builder().session_string(s)`.
41//! - [`SqliteBackend`]: SQLite-backed storage via rusqlite. Behind the
42//!   `sqlite-session` feature flag. Good for local multi-account tooling.
43//! - [`LibSqlBackend`]: libSQL / Turso backend. Behind `libsql-session`.
44//!   For distributed or edge-hosted session storage.
45//!
46//! You can also implement `SessionBackend` yourself for Redis, PostgreSQL, or
47//! anything else.
48//!
49//! # Binary format
50//!
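//! The binary encoding normally starts with a version byte:
//! - `0x02`: DC table + update state + peer cache.
//! - `0x03`: additionally stores a flags byte per DC entry.
//! - `0x04`: additionally stores min-user message contexts.
//! - `0x05`: current; the peer type byte can also mark regular group chats.
//! - anything else: legacy v1 layout, which has no version byte (the data
//!   starts directly with the home DC ID; DC table only, no update state or
//!   peer cache).
//!
//! `load()` handles all of these. `save()` always writes v5.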
56//!
57//! # Example: export and re-import a session
58//!
59//! ```rust,ignore
60//! # async fn example(client: ferogram::Client) -> anyhow::Result<()> {
61//! // Export
62//! let s = client.export_session_string().await?;
63//!
64//! // Later, in another process or after a restart:
65//! let (client, _) = ferogram::Client::builder()
66//!     .api_id(12345)
67//!     .api_hash("api_hash")
68//!     .session_string(s)
69//!     .connect()
70//!     .await?;
71//! # Ok(())
72//! # }
73//! ```
74
75use std::collections::HashMap;
76use std::io::{self, ErrorKind};
77use std::path::Path;
78
/// Serde adapter for `Option<[u8; 256]>` auth keys: the key is written as a
/// byte slice and must come back as exactly 256 bytes.
#[cfg(feature = "serde")]
mod auth_key_serde {
    use serde::{Deserialize, Deserializer, Serializer};

    /// Serialize `Some(key)` as an optional byte slice and `None` as none.
    pub fn serialize<S>(value: &Option<[u8; 256]>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(key) = value {
            s.serialize_some(key.as_slice())
        } else {
            s.serialize_none()
        }
    }

    /// Deserialize an optional byte vector, rejecting any length other than 256.
    pub fn deserialize<'de, D>(d: D) -> Result<Option<[u8; 256]>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let Some(raw) = Option::<Vec<u8>>::deserialize(d)? else {
            return Ok(None);
        };
        <[u8; 256]>::try_from(raw)
            .map(Some)
            .map_err(|_| serde::de::Error::custom("auth_key must be exactly 256 bytes"))
    }
}
109
/// Bit set of per-DC option flags.
///
/// Written with every DC entry from session format v3 onwards so that the
/// flags (e.g. media-only) are still known after a restart.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// Empty set: no flag bits.
    pub const NONE: DcFlags = DcFlags(0x00);
    /// IPv6 flag (bit 0).
    pub const IPV6: DcFlags = DcFlags(0x01);
    /// Media-only flag (bit 1).
    pub const MEDIA_ONLY: DcFlags = DcFlags(0x02);
    /// TCPO-only flag (bit 2).
    pub const TCPO_ONLY: DcFlags = DcFlags(0x04);
    /// CDN flag (bit 3).
    pub const CDN: DcFlags = DcFlags(0x08);
    /// Static flag (bit 4).
    pub const STATIC: DcFlags = DcFlags(0x10);

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// An empty `other` is contained in every value.
    pub fn contains(self, other: DcFlags) -> bool {
        let Self(have) = self;
        let Self(wanted) = other;
        // No wanted bit may be missing from `have`.
        wanted & !have == 0
    }

    /// Turns on all bits of `flag`, leaving the already-set bits untouched.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = DcFlags;

    /// Union of both flag sets.
    fn bitor(self, rhs: DcFlags) -> DcFlags {
        let Self(left) = self;
        let Self(right) = rhs;
        Self(left | right)
    }
}
140
141/// One entry in the DC address table.
142#[derive(Clone, Debug)]
143#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
144pub struct DcEntry {
145    pub dc_id: i32,
146    pub addr: String,
147    #[cfg_attr(feature = "serde", serde(with = "auth_key_serde"))]
148    pub auth_key: Option<[u8; 256]>,
149    pub first_salt: i64,
150    pub time_offset: i32,
151    /// DC capability flags (IPv6, media-only, CDN, ...).
152    pub flags: DcFlags,
153}
154
155impl DcEntry {
156    /// Returns `true` when this entry represents an IPv6 address.
157    #[inline]
158    pub fn is_ipv6(&self) -> bool {
159        self.flags.contains(DcFlags::IPV6)
160    }
161
162    /// Parse the stored `"ip:port"` / `"[ipv6]:port"` address into a
163    /// [`std::net::SocketAddr`].
164    ///
165    /// Both formats are valid:
166    /// - IPv4: `"149.154.175.53:443"`
167    /// - IPv6: `"[2001:b28:f23d:f001::a]:443"`
168    pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
169        self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
170            io::Error::new(
171                io::ErrorKind::InvalidData,
172                format!("invalid DC address: {:?}", self.addr),
173            )
174        })
175    }
176
177    /// Construct a `DcEntry` from separate IP string, port, and flags.
178    ///
179    /// IPv6 addresses are automatically wrapped in brackets so that
180    /// `socket_addr()` can round-trip them correctly:
181    ///
182    /// ```text
183    /// DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6)
184    /// // addr = "[2001:b28:f23d:f001::a]:443"
185    /// ```
186    ///
187    /// This is the preferred constructor when processing `help.getConfig`
188    /// `DcOption` objects from the Telegram API.
189    pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
190        // IPv6 addresses contain colons; wrap in brackets for SocketAddr compat.
191        let addr = if ip.contains(':') {
192            format!("[{ip}]:{port}")
193        } else {
194            format!("{ip}:{port}")
195        };
196        Self {
197            dc_id,
198            addr,
199            auth_key: None,
200            first_salt: 0,
201            time_offset: 0,
202            flags,
203        }
204    }
205}
206
/// Persisted copy of the MTProto update-sequence counters, kept so that
/// `catch_up: true` can call `updates.getDifference` with the pts from
/// *before* the shutdown.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Main persistence counter (messages, non-channel updates).
    pub pts: i32,
    /// Secondary counter for secret chats.
    pub qts: i32,
    /// Date of the last received update (Unix timestamp).
    pub date: i32,
    /// Combined-container sequence number.
    pub seq: i32,
    /// Per-channel persistence counters as `(channel_id, pts)` pairs.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// `true` once a real server state has been recorded, i.e. `pts` is positive.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        self.pts >= 1
    }

    /// Store `pts` for `channel_id`: overwrites the existing pair, or appends
    /// a new one when the channel has not been seen yet.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self.channels.iter().position(|&(id, _)| id == channel_id) {
            Some(idx) => self.channels[idx].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Stored pts for `channel_id`, or `0` when the channel is unknown.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        for &(id, pts) in &self.channels {
            if id == channel_id {
                return pts;
            }
        }
        0
    }
}
249
/// Cached access hash for one peer, stored so the peer stays addressable
/// after a restart without resolving it through Telegram again.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Bare Telegram peer ID (always positive).
    pub id: i64,
    /// Access hash bound to the current session.
    /// Always 0 for regular group chats, which need no access hash.
    pub access_hash: i64,
    /// `true` for a channel / supergroup, `false` for a user or regular group.
    pub is_channel: bool,
    /// `true` for a regular group chat (Chat::Chat / ChatForbidden).
    /// When set, `access_hash` carries no meaning.
    pub is_chat: bool,
}
266
/// Message context for a user that was only seen with `min=true`.
///
/// Such a user's access hash cannot be used directly, so the peer and message
/// in which the user appeared are stored instead; this allows
/// `InputPeerUserFromMessage` to be built again after a restart.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedMinPeer {
    /// ID of the min user.
    pub user_id: i64,
    /// ID of the channel/chat/user peer that contained the message.
    pub peer_id: i64,
    /// ID of the message inside that peer.
    pub msg_id: i32,
}
280
281/// Everything that needs to survive a process restart.
282#[derive(Clone, Debug, Default)]
283#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
284pub struct PersistedSession {
285    pub home_dc_id: i32,
286    pub dcs: Vec<DcEntry>,
287    /// Update counters to enable reliable catch-up after a disconnect.
288    pub updates_state: UpdatesStateSnap,
289    /// Peer access-hash cache so that the client can reach out to any previously
290    /// seen user or channel without re-resolving them.
291    pub peers: Vec<CachedPeer>,
292    /// Min-user message contexts: users seen with `min=true` that can only be
293    /// addressed via `InputPeerUserFromMessage`.
294    pub min_peers: Vec<CachedMinPeer>,
295}
296
297impl PersistedSession {
298    /// Encode the session to raw bytes (v2 binary format).
299    pub fn to_bytes(&self) -> Vec<u8> {
300        let mut b = Vec::with_capacity(512);
301
302        b.push(0x05u8); // version
303
304        b.extend_from_slice(&self.home_dc_id.to_le_bytes());
305
306        b.push(self.dcs.len() as u8);
307        for d in &self.dcs {
308            b.extend_from_slice(&d.dc_id.to_le_bytes());
309            match &d.auth_key {
310                Some(k) => {
311                    b.push(1);
312                    b.extend_from_slice(k);
313                }
314                None => {
315                    b.push(0);
316                }
317            }
318            b.extend_from_slice(&d.first_salt.to_le_bytes());
319            b.extend_from_slice(&d.time_offset.to_le_bytes());
320            let ab = d.addr.as_bytes();
321            b.push(ab.len() as u8);
322            b.extend_from_slice(ab);
323            b.push(d.flags.0);
324        }
325
326        b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
327        b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
328        b.extend_from_slice(&self.updates_state.date.to_le_bytes());
329        b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
330        let ch = &self.updates_state.channels;
331        b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
332        for &(cid, cpts) in ch {
333            b.extend_from_slice(&cid.to_le_bytes());
334            b.extend_from_slice(&cpts.to_le_bytes());
335        }
336
337        // v5 peer type: 0=user, 1=channel, 2=regular-group-chat
338        b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
339        for p in &self.peers {
340            b.extend_from_slice(&p.id.to_le_bytes());
341            b.extend_from_slice(&p.access_hash.to_le_bytes());
342            let peer_type: u8 = if p.is_chat {
343                2
344            } else if p.is_channel {
345                1
346            } else {
347                0
348            };
349            b.push(peer_type);
350        }
351
352        b.extend_from_slice(&(self.min_peers.len() as u16).to_le_bytes());
353        for m in &self.min_peers {
354            b.extend_from_slice(&m.user_id.to_le_bytes());
355            b.extend_from_slice(&m.peer_id.to_le_bytes());
356            b.extend_from_slice(&m.msg_id.to_le_bytes());
357        }
358
359        b
360    }
361
362    /// Atomically save the session to `path`.
363    ///
364    /// Writes to `<path>.<seq>.tmp` (unique per call) then renames into place.
365    /// A fixed `.tmp` extension causes OS error 2 (ERROR_FILE_NOT_FOUND) on
366    /// Windows when two concurrent persist_state calls race: thread A renames
367    /// `.tmp` away while thread B's rename finds the source gone.
368    pub fn save(&self, path: &Path) -> io::Result<()> {
369        use std::sync::atomic::{AtomicU64, Ordering};
370        static SEQ: AtomicU64 = AtomicU64::new(0);
371        let n = SEQ.fetch_add(1, Ordering::Relaxed);
372        let tmp = path.with_extension(format!("{n}.tmp"));
373        std::fs::write(&tmp, self.to_bytes())?;
374        std::fs::rename(&tmp, path).inspect_err(|_e| {
375            let _ = std::fs::remove_file(&tmp);
376        })
377    }
378
379    /// Decode a session from raw bytes (v1 or v2 binary format).
380    pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
381        if buf.is_empty() {
382            return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
383        }
384
385        let mut p = 0usize;
386
387        macro_rules! r {
388            ($n:expr) => {{
389                if p + $n > buf.len() {
390                    return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
391                }
392                let s = &buf[p..p + $n];
393                p += $n;
394                s
395            }};
396        }
397        macro_rules! r_i32 {
398            () => {
399                i32::from_le_bytes(r!(4).try_into().unwrap())
400            };
401        }
402        macro_rules! r_i64 {
403            () => {
404                i64::from_le_bytes(r!(8).try_into().unwrap())
405            };
406        }
407        macro_rules! r_u8 {
408            () => {
409                r!(1)[0]
410            };
411        }
412        macro_rules! r_u16 {
413            () => {
414                u16::from_le_bytes(r!(2).try_into().unwrap())
415            };
416        }
417
418        let first_byte = r_u8!();
419
420        let (home_dc_id, version) = if first_byte == 0x05 {
421            (r_i32!(), 5u8)
422        } else if first_byte == 0x04 {
423            (r_i32!(), 4u8)
424        } else if first_byte == 0x03 {
425            (r_i32!(), 3u8)
426        } else if first_byte == 0x02 {
427            (r_i32!(), 2u8)
428        } else {
429            let rest = r!(3);
430            let mut bytes = [0u8; 4];
431            bytes[0] = first_byte;
432            bytes[1..4].copy_from_slice(rest);
433            (i32::from_le_bytes(bytes), 1u8)
434        };
435
436        let dc_count = r_u8!() as usize;
437        let mut dcs = Vec::with_capacity(dc_count);
438        for _ in 0..dc_count {
439            let dc_id = r_i32!();
440            let has_key = r_u8!();
441            let auth_key = if has_key == 1 {
442                let mut k = [0u8; 256];
443                k.copy_from_slice(r!(256));
444                Some(k)
445            } else {
446                None
447            };
448            let first_salt = r_i64!();
449            let time_offset = r_i32!();
450            let al = r_u8!() as usize;
451            let addr = String::from_utf8_lossy(r!(al)).into_owned();
452            let flags = if version >= 3 {
453                DcFlags(r_u8!())
454            } else {
455                DcFlags::NONE
456            };
457            dcs.push(DcEntry {
458                dc_id,
459                addr,
460                auth_key,
461                first_salt,
462                time_offset,
463                flags,
464            });
465        }
466
467        if version < 2 {
468            return Ok(Self {
469                home_dc_id,
470                dcs,
471                updates_state: UpdatesStateSnap::default(),
472                peers: Vec::new(),
473                min_peers: Vec::new(),
474            });
475        }
476
477        let pts = r_i32!();
478        let qts = r_i32!();
479        let date = r_i32!();
480        let seq = r_i32!();
481        let ch_count = r_u16!() as usize;
482        let mut channels = Vec::with_capacity(ch_count);
483        for _ in 0..ch_count {
484            let cid = r_i64!();
485            let cpts = r_i32!();
486            channels.push((cid, cpts));
487        }
488
489        let peer_count = r_u16!() as usize;
490        let mut peers = Vec::with_capacity(peer_count);
491        for _ in 0..peer_count {
492            let id = r_i64!();
493            let access_hash = r_i64!();
494            // v5: type byte 0=user, 1=channel, 2=chat; v2-v4: 0=user, 1=channel
495            let peer_type = r_u8!();
496            let is_channel = peer_type == 1;
497            let is_chat = peer_type == 2;
498            peers.push(CachedPeer {
499                id,
500                access_hash,
501                is_channel,
502                is_chat,
503            });
504        }
505
506        // v4+: min-user contexts
507        let min_peers = if version >= 4 {
508            let count = r_u16!() as usize;
509            let mut v = Vec::with_capacity(count);
510            for _ in 0..count {
511                let user_id = r_i64!();
512                let peer_id = r_i64!();
513                let msg_id = r_i32!();
514                v.push(CachedMinPeer {
515                    user_id,
516                    peer_id,
517                    msg_id,
518                });
519            }
520            v
521        } else {
522            Vec::new()
523        };
524
525        Ok(Self {
526            home_dc_id,
527            dcs,
528            updates_state: UpdatesStateSnap {
529                pts,
530                qts,
531                date,
532                seq,
533                channels,
534            },
535            peers,
536            min_peers,
537        })
538    }
539
540    /// Decode a session from a URL-safe base64 string produced by [`to_string`].
541    pub fn from_string(s: &str) -> io::Result<Self> {
542        use base64::Engine as _;
543        let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
544            .decode(s.trim())
545            .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
546        Self::from_bytes(&bytes)
547    }
548
549    pub fn load(path: &Path) -> io::Result<Self> {
550        let buf = std::fs::read(path)?;
551        Self::from_bytes(&buf)
552    }
553
554    // DC address helpers
555
556    /// Find the best DC entry for a given DC ID.
557    ///
558    /// When `prefer_ipv6` is `true`, returns the IPv6 entry if one is
559    /// stored, falling back to IPv4.  When `false`, returns IPv4,
560    /// falling back to IPv6.  Returns `None` only when the DC ID is
561    /// completely unknown.
562    ///
563    /// This correctly handles the case where both an IPv4 and an IPv6
564    /// `DcEntry` exist for the same `dc_id` (different `flags` bitmask).
565    pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
566        let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
567        candidates.peek()?;
568        // Collect so we can search twice
569        let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
570        // Preferred family first, fall back to whatever is available
571        cands
572            .iter()
573            .copied()
574            .find(|d| d.is_ipv6() == prefer_ipv6)
575            .or_else(|| cands.first().copied())
576    }
577
578    /// Iterate over every stored DC entry for a given DC ID.
579    ///
580    /// Typically yields one IPv4 and one IPv6 entry per DC ID once
581    /// `help.getConfig` has been applied.
582    pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
583        self.dcs.iter().filter(move |d| d.dc_id == dc_id)
584    }
585}
586
587impl std::fmt::Display for PersistedSession {
588    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
589        use base64::Engine as _;
590        f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
591    }
592}
593
/// Built-in DC address table, used as a fallback if GetConfig fails.
///
/// Maps DC ids 1 through 5 to an `"ip:port"` string each.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const BOOTSTRAP: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(BOOTSTRAP.len());
    for (dc_id, addr) in BOOTSTRAP {
        table.insert(dc_id, addr.to_owned());
    }
    table
}
607
608// session_backend
609//
610// Pluggable session storage backend.
611
612use std::path::PathBuf;
613
614// Core trait (unchanged)
615
616/// Synchronous snapshot backend: saves and loads the full session at once.
617///
618/// All built-in backends implement this. Higher-level code should prefer the
619/// extension methods below (`update_dc`, `set_home_dc`, `update_state`) which
620/// avoid unnecessary full-snapshot writes.
621pub trait SessionBackend: Send + Sync {
622    fn save(&self, session: &PersistedSession) -> io::Result<()>;
623    fn load(&self) -> io::Result<Option<PersistedSession>>;
624    fn delete(&self) -> io::Result<()>;
625
626    /// Human-readable name for logging/debug output.
627    fn name(&self) -> &str;
628
629    // Granular helpers (default: load → mutate → save)
630    //
631    // These default implementations are correct but not optimal.
632    // Backends that store data in a database (SQLite, libsql, Redis) should
633    // override them to issue single-row UPDATE statements instead.
634
635    /// Update a single DC entry without rewriting the entire session.
636    ///
637    /// Typically called after:
638    /// - completing a DH handshake on a new DC (to persist its auth key)
639    /// - receiving updated DC addresses from `help.getConfig`
640    ///
641    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
642        let mut s = self.load()?.unwrap_or_default();
643        // Replace existing entry or append
644        if let Some(existing) = s
645            .dcs
646            .iter_mut()
647            .find(|d| d.dc_id == entry.dc_id && d.is_ipv6() == entry.is_ipv6())
648        {
649            *existing = entry.clone();
650        } else {
651            s.dcs.push(entry.clone());
652        }
653        self.save(&s)
654    }
655
656    /// Change the home DC without touching any other session data.
657    ///
658    /// Called after a successful `*_MIGRATE` redirect: the user's account
659    /// now lives on a different DC.
660    ///
661    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
662        let mut s = self.load()?.unwrap_or_default();
663        s.home_dc_id = dc_id;
664        self.save(&s)
665    }
666
667    /// Apply a single update-sequence change without a full save/load.
668    ///
669    ///
670    /// `update` is the new partial or full state to merge in.
671    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
672        let mut s = self.load()?.unwrap_or_default();
673        update.apply_to(&mut s.updates_state);
674        self.save(&s)
675    }
676
677    /// Cache a peer access hash without a full session save.
678    ///
679    /// This is lossy-on-default (full round-trip) but correct.
680    /// Override in SQL backends to issue a single `INSERT OR REPLACE`.
681    ///
682    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
683        let mut s = self.load()?.unwrap_or_default();
684        if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
685            *existing = peer.clone();
686        } else {
687            s.peers.push(peer.clone());
688        }
689        self.save(&s)
690    }
691}
692
693// UpdateStateChange (mirrors  UpdateState enum)
694
695/// A single update-sequence change, applied via [`SessionBackend::apply_update_state`].
696///
697///uses:
698/// ```text
699/// UpdateState::All(updates_state)
700/// UpdateState::Primary { pts, date, seq }
701/// UpdateState::Secondary { qts }
702/// UpdateState::Channel { id, pts }
703/// ```
704///
705/// We map this 1-to-1 to layer's `UpdatesStateSnap`.
706#[derive(Debug, Clone)]
707pub enum UpdateStateChange {
708    /// Replace the entire state snapshot.
709    All(UpdatesStateSnap),
710    /// Update main sequence counters only (non-channel).
711    Primary { pts: i32, date: i32, seq: i32 },
712    /// Update the QTS counter (secret chats).
713    Secondary { qts: i32 },
714    /// Update the PTS for a specific channel.
715    Channel { id: i64, pts: i32 },
716}
717
718impl UpdateStateChange {
719    /// Apply `self` to `snap` in-place.
720    pub fn apply_to(&self, snap: &mut UpdatesStateSnap) {
721        match self {
722            Self::All(new_snap) => *snap = new_snap.clone(),
723            Self::Primary { pts, date, seq } => {
724                snap.pts = *pts;
725                snap.date = *date;
726                snap.seq = *seq;
727            }
728            Self::Secondary { qts } => {
729                snap.qts = *qts;
730            }
731            Self::Channel { id, pts } => {
732                // Replace or insert per-channel pts
733                if let Some(existing) = snap.channels.iter_mut().find(|c| c.0 == *id) {
734                    existing.1 = *pts;
735                } else {
736                    snap.channels.push((*id, *pts));
737                }
738            }
739        }
740    }
741}
742
743// BinaryFileBackend
744
745/// Stores the session in a compact binary file (v2 format).
746pub struct BinaryFileBackend {
747    path: PathBuf,
748    /// Serialises concurrent save() calls within the same process so they
749    /// don't interleave on the tmp file even if PersistedSession::save uses
750    /// unique names (belt-and-suspenders; also prevents torn reads of the
751    /// session file from a concurrent load + save).
752    write_lock: std::sync::Mutex<()>,
753}
754
755impl BinaryFileBackend {
756    pub fn new(path: impl Into<PathBuf>) -> Self {
757        Self {
758            path: path.into(),
759            write_lock: std::sync::Mutex::new(()),
760        }
761    }
762
763    pub fn path(&self) -> &std::path::Path {
764        &self.path
765    }
766}
767
768impl SessionBackend for BinaryFileBackend {
769    fn save(&self, session: &PersistedSession) -> io::Result<()> {
770        let _guard = self.write_lock.lock().unwrap();
771        session.save(&self.path)
772    }
773
774    fn load(&self) -> io::Result<Option<PersistedSession>> {
775        if !self.path.exists() {
776            return Ok(None);
777        }
778        match PersistedSession::load(&self.path) {
779            Ok(s) => Ok(Some(s)),
780            Err(e) => {
781                let bak = self.path.with_extension("bak");
782                tracing::warn!(
783                    "[ferogram] Session file {:?} is corrupt ({e}); \
784                     renaming to {:?} and starting fresh",
785                    self.path,
786                    bak
787                );
788                let _ = std::fs::rename(&self.path, &bak);
789                Ok(None)
790            }
791        }
792    }
793
794    fn delete(&self) -> io::Result<()> {
795        if self.path.exists() {
796            std::fs::remove_file(&self.path)?;
797        }
798        Ok(())
799    }
800
801    fn name(&self) -> &str {
802        "binary-file"
803    }
804
805    // BinaryFileBackend: the default granular impls (load→mutate→save) are
806    // fine since the format is a single compact binary blob. No override needed.
807}
808
809// InMemoryBackend
810
811/// Ephemeral in-process session: nothing persisted to disk.
812///
813/// Override the granular methods to skip the clone overhead of the full
814/// snapshot path (we're already in memory, so direct field mutations are
815/// cheaper than clone→mutate→replace).
816#[derive(Default)]
817pub struct InMemoryBackend {
818    data: std::sync::Mutex<Option<PersistedSession>>,
819}
820
821impl InMemoryBackend {
822    pub fn new() -> Self {
823        Self::default()
824    }
825
826    /// Test helper: get a snapshot of the current in-memory state.
827    pub fn snapshot(&self) -> Option<PersistedSession> {
828        self.data.lock().unwrap().clone()
829    }
830}
831
832impl SessionBackend for InMemoryBackend {
833    fn save(&self, s: &PersistedSession) -> io::Result<()> {
834        *self.data.lock().unwrap() = Some(s.clone());
835        Ok(())
836    }
837
838    fn load(&self) -> io::Result<Option<PersistedSession>> {
839        Ok(self.data.lock().unwrap().clone())
840    }
841
842    fn delete(&self) -> io::Result<()> {
843        *self.data.lock().unwrap() = None;
844        Ok(())
845    }
846
847    fn name(&self) -> &str {
848        "in-memory"
849    }
850
851    // Granular overrides: cheaper than load→clone→save
852
853    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
854        let mut guard = self.data.lock().unwrap();
855        let s = guard.get_or_insert_with(PersistedSession::default);
856        if let Some(existing) = s
857            .dcs
858            .iter_mut()
859            .find(|d| d.dc_id == entry.dc_id && d.is_ipv6() == entry.is_ipv6())
860        {
861            *existing = entry.clone();
862        } else {
863            s.dcs.push(entry.clone());
864        }
865        Ok(())
866    }
867
868    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
869        let mut guard = self.data.lock().unwrap();
870        let s = guard.get_or_insert_with(PersistedSession::default);
871        s.home_dc_id = dc_id;
872        Ok(())
873    }
874
875    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
876        let mut guard = self.data.lock().unwrap();
877        let s = guard.get_or_insert_with(PersistedSession::default);
878        update.apply_to(&mut s.updates_state);
879        Ok(())
880    }
881
882    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
883        let mut guard = self.data.lock().unwrap();
884        let s = guard.get_or_insert_with(PersistedSession::default);
885        if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
886            *existing = peer.clone();
887        } else {
888            s.peers.push(peer.clone());
889        }
890        Ok(())
891    }
892}
893
894// StringSessionBackend
895
/// Portable base64 string session backend.
///
/// The encoded session lives in memory behind a mutex; nothing is written to
/// disk. Read the latest encoded value back with [`StringSessionBackend::current`].
pub struct StringSessionBackend {
    /// Encoded session text; empty when no session is held.
    data: std::sync::Mutex<String>,
}

impl StringSessionBackend {
    /// Creates a backend seeded with the encoded session `s`.
    pub fn new(s: impl Into<String>) -> Self {
        let data = std::sync::Mutex::new(s.into());
        Self { data }
    }

    /// Returns a copy of the encoded session text currently held.
    pub fn current(&self) -> String {
        let guard = self.data.lock().unwrap();
        String::clone(&guard)
    }
}
912
913impl SessionBackend for StringSessionBackend {
914    fn save(&self, session: &PersistedSession) -> io::Result<()> {
915        *self.data.lock().unwrap() = session.to_string();
916        Ok(())
917    }
918
919    fn load(&self) -> io::Result<Option<PersistedSession>> {
920        let s = self.data.lock().unwrap().clone();
921        if s.trim().is_empty() {
922            return Ok(None);
923        }
924        PersistedSession::from_string(&s).map(Some)
925    }
926
927    fn delete(&self) -> io::Result<()> {
928        *self.data.lock().unwrap() = String::new();
929        Ok(())
930    }
931
932    fn name(&self) -> &str {
933        "string-session"
934    }
935}
936
// String-session module and re-exports (the unit tests follow below)
938
939/// Portable compact string-session encoding/decoding (V1/V2 binary base64).
940///
941/// Use via `Client::builder().session_string("ABC...")` which auto-detects
942/// the format. Use this module directly only when you need to inspect or
943/// construct a [`StringSession`] manually.
944pub mod string_session;
945pub use string_session::{FullSession, Session, StringSession, StringSessionError};
946
#[cfg(test)]
mod tests {
    use super::*;

    // Fixtures

    /// IPv4 DC entry without an auth key, addressed as `1.2.3.<id>:443`.
    fn make_dc(id: i32) -> DcEntry {
        DcEntry {
            flags: DcFlags::NONE,
            dc_id: id,
            addr: format!("1.2.3.{id}:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
        }
    }

    /// IPv6 DC entry without an auth key.
    fn make_dc_v6(id: i32) -> DcEntry {
        DcEntry {
            flags: DcFlags::IPV6,
            dc_id: id,
            addr: format!("[2001:b28:f23d:f00{id}::a]:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
        }
    }

    /// Cached peer that is neither a channel nor a chat.
    fn make_peer(id: i64, hash: i64) -> CachedPeer {
        CachedPeer {
            id,
            access_hash: hash,
            is_channel: false,
            is_chat: false,
        }
    }

    /// Default session with `dcs` appended to its DC table.
    fn session_with_dcs(dcs: impl IntoIterator<Item = DcEntry>) -> PersistedSession {
        let mut session = PersistedSession::default();
        session.dcs.extend(dcs);
        session
    }

    /// Update state with every counter at zero and no channels.
    fn zero_snap() -> UpdatesStateSnap {
        UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        }
    }

    // InMemoryBackend: basic save/load

    #[test]
    fn inmemory_load_returns_none_when_empty() {
        let backend = InMemoryBackend::new();
        let loaded = backend.load().unwrap();
        assert!(loaded.is_none());
    }

    #[test]
    fn inmemory_save_then_load_round_trips() {
        let backend = InMemoryBackend::new();
        let session = PersistedSession {
            home_dc_id: 3,
            ..session_with_dcs([make_dc(3)])
        };
        backend.save(&session).unwrap();

        let restored = backend.load().unwrap().unwrap();
        assert_eq!(restored.home_dc_id, 3);
        assert_eq!(restored.dcs.len(), 1);
    }

    #[test]
    fn inmemory_delete_clears_state() {
        let backend = InMemoryBackend::new();
        let session = PersistedSession {
            home_dc_id: 2,
            ..Default::default()
        };
        backend.save(&session).unwrap();
        backend.delete().unwrap();
        assert!(backend.load().unwrap().is_none());
    }

    // InMemoryBackend: granular methods

    #[test]
    fn inmemory_update_dc_inserts_new() {
        let backend = InMemoryBackend::new();
        backend.update_dc(&make_dc(4)).unwrap();

        let stored = backend.snapshot().unwrap();
        assert_eq!(stored.dcs.len(), 1);
        assert_eq!(stored.dcs[0].dc_id, 4);
    }

    #[test]
    fn inmemory_update_dc_replaces_existing() {
        let backend = InMemoryBackend::new();
        backend.update_dc(&make_dc(2)).unwrap();

        let replacement = DcEntry {
            addr: "9.9.9.9:443".to_string(),
            ..make_dc(2)
        };
        backend.update_dc(&replacement).unwrap();

        let stored = backend.snapshot().unwrap();
        assert_eq!(stored.dcs.len(), 1);
        assert_eq!(stored.dcs[0].addr, "9.9.9.9:443");
    }

    #[test]
    fn inmemory_set_home_dc() {
        let backend = InMemoryBackend::new();
        backend.set_home_dc(5).unwrap();

        let stored = backend.snapshot().unwrap();
        assert_eq!(stored.home_dc_id, 5);
    }

    #[test]
    fn inmemory_cache_peer_inserts() {
        let backend = InMemoryBackend::new();
        backend.cache_peer(&make_peer(100, 0xdeadbeef)).unwrap();

        let stored = backend.snapshot().unwrap();
        assert_eq!(stored.peers.len(), 1);
        assert_eq!(stored.peers[0].id, 100);
    }

    #[test]
    fn inmemory_cache_peer_updates_existing() {
        let backend = InMemoryBackend::new();
        for hash in [111, 222] {
            backend.cache_peer(&make_peer(100, hash)).unwrap();
        }

        let stored = backend.snapshot().unwrap();
        assert_eq!(stored.peers.len(), 1);
        assert_eq!(stored.peers[0].access_hash, 222);
    }

    // UpdateStateChange

    #[test]
    fn update_state_primary() {
        let mut snap = zero_snap();
        let change = UpdateStateChange::Primary {
            pts: 10,
            date: 20,
            seq: 30,
        };
        change.apply_to(&mut snap);

        assert_eq!(snap.pts, 10);
        assert_eq!(snap.date, 20);
        assert_eq!(snap.seq, 30);
        // qts is not part of a primary change and must stay as it was.
        assert_eq!(snap.qts, 0);
    }

    #[test]
    fn update_state_secondary() {
        let mut snap = UpdatesStateSnap {
            pts: 5,
            ..zero_snap()
        };
        UpdateStateChange::Secondary { qts: 99 }.apply_to(&mut snap);

        assert_eq!(snap.qts, 99);
        // pts is not part of a secondary change and must stay as it was.
        assert_eq!(snap.pts, 5);
    }

    #[test]
    fn update_state_channel_inserts() {
        let mut snap = zero_snap();
        UpdateStateChange::Channel { id: 12345, pts: 42 }.apply_to(&mut snap);
        assert_eq!(snap.channels, vec![(12345, 42)]);
    }

    #[test]
    fn update_state_channel_updates_existing() {
        let mut snap = UpdatesStateSnap {
            channels: vec![(12345, 10), (67890, 5)],
            ..zero_snap()
        };
        UpdateStateChange::Channel { id: 12345, pts: 99 }.apply_to(&mut snap);

        // The first channel is updated; the second one is left alone.
        assert_eq!(snap.channels[0], (12345, 99));
        assert_eq!(snap.channels[1], (67890, 5));
    }

    #[test]
    fn apply_update_state_via_backend() {
        let backend = InMemoryBackend::new();
        let change = UpdateStateChange::Primary {
            pts: 7,
            date: 8,
            seq: 9,
        };
        backend.apply_update_state(change).unwrap();

        let stored = backend.snapshot().unwrap();
        assert_eq!(stored.updates_state.pts, 7);
    }

    // Granular methods through a trait object (InMemory used as a smoke test)

    #[test]
    fn default_update_dc_via_trait_object() {
        let backend: Box<dyn SessionBackend> = Box::new(InMemoryBackend::new());
        for id in [1, 2] {
            backend.update_dc(&make_dc(id)).unwrap();
        }

        // snapshot() is not available on the trait object, so go through load().
        let loaded = backend.load().unwrap().unwrap();
        assert_eq!(loaded.dcs.len(), 2);
    }

    // IPv6 tests

    #[test]
    fn dc_entry_from_parts_ipv4() {
        let entry = DcEntry::from_parts(1, "149.154.175.53", 443, DcFlags::NONE);
        assert_eq!(entry.addr, "149.154.175.53:443");
        assert!(!entry.is_ipv6());

        let socket = entry.socket_addr().unwrap();
        assert_eq!(socket.port(), 443);
    }

    #[test]
    fn dc_entry_from_parts_ipv6() {
        let entry = DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6);
        assert_eq!(entry.addr, "[2001:b28:f23d:f001::a]:443");
        assert!(entry.is_ipv6());

        let socket = entry.socket_addr().unwrap();
        assert_eq!(socket.port(), 443);
    }

    #[test]
    fn persisted_session_dc_for_prefers_ipv6() {
        // One IPv4 and one IPv6 entry for the same DC.
        let session = session_with_dcs([make_dc(2), make_dc_v6(2)]);

        let wants_v6 = session.dc_for(2, true).unwrap();
        assert!(wants_v6.is_ipv6());

        let wants_v4 = session.dc_for(2, false).unwrap();
        assert!(!wants_v4.is_ipv6());
    }

    #[test]
    fn persisted_session_dc_for_falls_back_when_only_ipv4() {
        let session = session_with_dcs([make_dc(3)]);

        // Only an IPv4 entry exists, so asking for IPv6 must fall back to it.
        let picked = session.dc_for(3, true).unwrap();
        assert!(!picked.is_ipv6());
    }

    #[test]
    fn persisted_session_all_dcs_for_returns_both() {
        let session = session_with_dcs([make_dc(1), make_dc_v6(1), make_dc(2)]);

        for (dc_id, expected) in [(1, 2), (2, 1), (5, 0)] {
            assert_eq!(session.all_dcs_for(dc_id).count(), expected);
        }
    }

    #[test]
    fn inmemory_ipv4_and_ipv6_coexist() {
        let backend = InMemoryBackend::new();
        backend.update_dc(&make_dc(2)).unwrap();
        backend.update_dc(&make_dc_v6(2)).unwrap();

        let stored = backend.snapshot().unwrap();
        // Both entries must survive: they share a dc_id but have different flags.
        let same_dc = stored.dcs.iter().filter(|d| d.dc_id == 2).count();
        assert_eq!(same_dc, 2);
    }

    #[test]
    fn binary_roundtrip_ipv4_and_ipv6() {
        let session = PersistedSession {
            home_dc_id: 2,
            ..session_with_dcs([make_dc(2), make_dc_v6(2)])
        };

        let encoded = session.to_bytes();
        let decoded = PersistedSession::from_bytes(&encoded).unwrap();

        assert_eq!(decoded.dcs.len(), 2);
        assert_eq!(decoded.dcs.iter().filter(|d| d.is_ipv6()).count(), 1);
        assert_eq!(decoded.dcs.iter().filter(|d| !d.is_ipv6()).count(), 1);
    }
}
1233
1234// SqliteBackend
1235
/// SQLite-backed session (via `rusqlite`).
///
/// Enabled with the `sqlite-session` Cargo feature.
///
/// # Schema
///
/// Six tables are created on first open (idempotent):
///
/// | Table          | Purpose                                          |
/// |----------------|--------------------------------------------------|
/// | `meta`         | `home_dc_id` and future scalar values            |
/// | `dcs`          | One row per DC (auth key, address, flags, ...)   |
/// | `update_state` | Single-row pts / qts / date / seq                |
/// | `channel_pts`  | Per-channel pts                                  |
/// | `peers`        | Access-hash cache (includes is_chat flag)        |
/// | `min_peers`    | Min-user message contexts                        |
///
/// # Granular writes
///
/// All [`SessionBackend`] extension methods (`update_dc`, `set_home_dc`,
/// `apply_update_state`, `cache_peer`) issue **targeted SQL statements**
/// (a single upsert in most cases; the `UpdateStateChange::All` variant also
/// rewrites `channel_pts`) instead of the default load-mutate-save
/// round-trip, so they are safe to call frequently (e.g. on every update
/// batch) without performance concerns.
#[cfg(feature = "sqlite-session")]
pub struct SqliteBackend {
    /// The database connection. Every trait method locks this mutex for the
    /// duration of the call.
    conn: std::sync::Mutex<rusqlite::Connection>,
    /// Text returned by `name()`: the database path as displayed, or
    /// `":memory:"` for an in-memory database.
    label: String,
}
1264
#[cfg(feature = "sqlite-session")]
impl SqliteBackend {
    /// Pragmas plus idempotent table definitions, executed every time a
    /// database is opened.
    const SCHEMA: &'static str = "
        PRAGMA journal_mode = WAL;
        PRAGMA synchronous  = NORMAL;

        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT    PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );

        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );

        CREATE TABLE IF NOT EXISTS peers (
            id           INTEGER PRIMARY KEY,
            access_hash  INTEGER NOT NULL,
            is_channel   INTEGER NOT NULL DEFAULT 0,
            is_chat      INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS min_peers (
            user_id INTEGER PRIMARY KEY,
            peer_id INTEGER NOT NULL,
            msg_id  INTEGER NOT NULL
        );
    ";

    /// Open (or create) the SQLite database at `path`.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` wrapping the `rusqlite` error if the database
    /// cannot be opened or the schema cannot be applied or migrated.
    pub fn open(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let conn = rusqlite::Connection::open(&path).map_err(Self::map_err)?;
        Self::from_connection(conn, label)
    }

    /// Open an in-process SQLite database (useful for tests).
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` wrapping the `rusqlite` error if the database
    /// cannot be created or the schema cannot be applied.
    pub fn in_memory() -> io::Result<Self> {
        let conn = rusqlite::Connection::open_in_memory().map_err(Self::map_err)?;
        Self::from_connection(conn, ":memory:".into())
    }

    /// Shared tail of the constructors: applies the schema, runs the legacy
    /// migration and wraps the connection.
    fn from_connection(conn: rusqlite::Connection, label: String) -> io::Result<Self> {
        conn.execute_batch(Self::SCHEMA).map_err(Self::map_err)?;
        Self::migrate_legacy_sqlite_schema(&conn)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label,
        })
    }

    /// Wraps a `rusqlite` error in an `io::Error`.
    fn map_err(e: rusqlite::Error) -> io::Error {
        io::Error::other(e)
    }

    /// Migrate an older database that is missing the is_chat column or the
    /// min_peers table. Safe to call on a fresh database; both operations
    /// are no-ops if the schema is already current.
    fn migrate_legacy_sqlite_schema(conn: &rusqlite::Connection) -> io::Result<()> {
        // The PRAGMA statement lives in its own scope so it is finalised
        // before any DDL below runs on the same connection.
        let has_is_chat = {
            let mut stmt = conn
                .prepare("PRAGMA table_info(peers)")
                .map_err(Self::map_err)?;
            // Column 1 of `table_info` is the column name.
            let cols = stmt
                .query_map([], |row| row.get::<_, String>(1))
                .map_err(Self::map_err)?;
            let found = cols.filter_map(|r| r.ok()).any(|col| col == "is_chat");
            found
        };
        if !has_is_chat {
            conn.execute_batch("ALTER TABLE peers ADD COLUMN is_chat INTEGER NOT NULL DEFAULT 0;")
                .map_err(Self::map_err)?;
        }
        conn.execute_batch(
            "CREATE TABLE IF NOT EXISTS min_peers (
                user_id INTEGER PRIMARY KEY,
                peer_id INTEGER NOT NULL,
                msg_id  INTEGER NOT NULL
            );",
        )
        .map_err(Self::map_err)?;
        Ok(())
    }

    /// Read the full session out of the database.
    ///
    /// Reading is best-effort at row level: a row whose columns cannot be
    /// decoded is skipped, a missing or unreadable `home_dc_id` becomes `0`,
    /// and a missing or unreadable `update_state` row becomes the default
    /// state. Only a failure to prepare or start a query is an error.
    fn read_session(conn: &rusqlite::Connection) -> io::Result<PersistedSession> {
        // home_dc_id
        let home_dc_id: i32 = conn
            .query_row("SELECT value FROM meta WHERE key = 'home_dc_id'", [], |r| {
                r.get(0)
            })
            .unwrap_or(0);

        // dcs
        let mut stmt = conn
            .prepare("SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs")
            .map_err(Self::map_err)?;
        let dcs = stmt
            .query_map([], |row| {
                let dc_id: i32 = row.get(0)?;
                let flags_raw: u8 = row.get(1)?;
                let addr: String = row.get(2)?;
                let key_blob: Option<Vec<u8>> = row.get(3)?;
                let first_salt: i64 = row.get(4)?;
                let time_offset: i32 = row.get(5)?;
                Ok((dc_id, addr, key_blob, first_salt, time_offset, flags_raw))
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .map(
                |(dc_id, addr, key_blob, first_salt, time_offset, flags_raw)| {
                    // A blob of any length other than 256 bytes is treated
                    // as "no auth key".
                    let auth_key = key_blob.and_then(|b| {
                        if b.len() == 256 {
                            let mut k = [0u8; 256];
                            k.copy_from_slice(&b);
                            Some(k)
                        } else {
                            None
                        }
                    });
                    DcEntry {
                        dc_id,
                        addr,
                        auth_key,
                        first_salt,
                        time_offset,
                        flags: DcFlags(flags_raw),
                    }
                },
            )
            .collect();

        // update_state (channels are filled in from channel_pts below)
        let updates_state = conn
            .query_row(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                [],
                |r| {
                    Ok(UpdatesStateSnap {
                        pts: r.get(0)?,
                        qts: r.get(1)?,
                        date: r.get(2)?,
                        seq: r.get(3)?,
                        channels: vec![],
                    })
                },
            )
            .unwrap_or_default();

        // channel_pts
        let mut ch_stmt = conn
            .prepare("SELECT channel_id, pts FROM channel_pts")
            .map_err(Self::map_err)?;
        let channels: Vec<(i64, i32)> = ch_stmt
            .query_map([], |r| Ok((r.get::<_, i64>(0)?, r.get::<_, i32>(1)?)))
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        // peers
        let mut peer_stmt = conn
            .prepare("SELECT id, access_hash, is_channel, is_chat FROM peers")
            .map_err(Self::map_err)?;
        let peers: Vec<CachedPeer> = peer_stmt
            .query_map([], |r| {
                Ok(CachedPeer {
                    id: r.get(0)?,
                    access_hash: r.get(1)?,
                    is_channel: r.get::<_, i32>(2)? != 0,
                    is_chat: r.get::<_, i32>(3)? != 0,
                })
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        // min_peers
        let mut min_stmt = conn
            .prepare("SELECT user_id, peer_id, msg_id FROM min_peers")
            .map_err(Self::map_err)?;
        let min_peers: Vec<CachedMinPeer> = min_stmt
            .query_map([], |r| {
                Ok(CachedMinPeer {
                    user_id: r.get(0)?,
                    peer_id: r.get(1)?,
                    msg_id: r.get(2)?,
                })
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers,
        })
    }

    /// Write the full session into the database inside a single transaction.
    ///
    /// If any statement (including `COMMIT`) fails, the transaction is
    /// rolled back before the error is returned, so the connection is never
    /// left inside a half-finished transaction and the previous contents
    /// stay intact.
    fn write_session(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        let outcome = conn
            .execute_batch("BEGIN IMMEDIATE")
            .and_then(|()| Self::write_session_rows(conn, s))
            .and_then(|()| conn.execute_batch("COMMIT"));
        if outcome.is_err() {
            // Best effort: SQLite may already have rolled back on its own,
            // in which case this fails harmlessly and is ignored.
            let _ = conn.execute_batch("ROLLBACK");
        }
        outcome.map_err(Self::map_err)
    }

    /// Issues the statements that make up a full save. The caller is
    /// responsible for the surrounding transaction.
    fn write_session_rows(
        conn: &rusqlite::Connection,
        s: &PersistedSession,
    ) -> rusqlite::Result<()> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![s.home_dc_id],
        )?;

        // Replace all DCs. Each bulk INSERT is prepared once and reused.
        conn.execute("DELETE FROM dcs", [])?;
        let mut insert_dc = conn.prepare(
            "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
        )?;
        for d in &s.dcs {
            insert_dc.execute(rusqlite::params![
                d.dc_id,
                d.flags.0,
                d.addr,
                d.auth_key.as_ref().map(|k| &k[..]),
                d.first_salt,
                d.time_offset,
            ])?;
        }

        // update_state: pts and qts are monotonic, so a full save must never
        // move them backwards. MAX() ensures a stale snapshot cannot
        // overwrite a fresher value committed by apply_update_state().
        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1, ?1, ?2, ?3, ?4)
             ON CONFLICT(id) DO UPDATE SET
               pts  = MAX(excluded.pts,  update_state.pts),
               qts  = MAX(excluded.qts,  update_state.qts),
               date = excluded.date,
               seq  = excluded.seq",
            rusqlite::params![us.pts, us.qts, us.date, us.seq],
        )?;

        // channel_pts
        conn.execute("DELETE FROM channel_pts", [])?;
        let mut insert_channel =
            conn.prepare("INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)")?;
        for &(cid, cpts) in &us.channels {
            insert_channel.execute(rusqlite::params![cid, cpts])?;
        }

        // peers
        conn.execute("DELETE FROM peers", [])?;
        let mut insert_peer = conn.prepare(
            "INSERT INTO peers (id, access_hash, is_channel, is_chat) VALUES (?1, ?2, ?3, ?4)",
        )?;
        for p in &s.peers {
            insert_peer.execute(rusqlite::params![
                p.id,
                p.access_hash,
                p.is_channel as i32,
                p.is_chat as i32
            ])?;
        }

        // min_peers
        conn.execute("DELETE FROM min_peers", [])?;
        let mut insert_min_peer =
            conn.prepare("INSERT INTO min_peers (user_id, peer_id, msg_id) VALUES (?1, ?2, ?3)")?;
        for m in &s.min_peers {
            insert_min_peer.execute(rusqlite::params![m.user_id, m.peer_id, m.msg_id])?;
        }

        Ok(())
    }
}
1573
#[cfg(feature = "sqlite-session")]
impl SessionBackend for SqliteBackend {
    /// Writes the whole session in one transaction.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        Self::write_session(&conn, session)
    }

    /// Reads the whole session back.
    ///
    /// Returns `Ok(None)` only when every session table is empty. Checking
    /// all tables (not just `meta`) means data written solely through the
    /// granular methods, e.g. `update_dc` before any `set_home_dc` or
    /// `save`, is still returned.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = self.conn.lock().unwrap();
        let has_data: i64 = conn
            .query_row(
                "SELECT EXISTS(SELECT 1 FROM meta)
                     OR EXISTS(SELECT 1 FROM dcs)
                     OR EXISTS(SELECT 1 FROM update_state)
                     OR EXISTS(SELECT 1 FROM channel_pts)
                     OR EXISTS(SELECT 1 FROM peers)
                     OR EXISTS(SELECT 1 FROM min_peers)",
                [],
                |r| r.get(0),
            )
            .map_err(Self::map_err)?;
        if has_data == 0 {
            return Ok(None);
        }
        Self::read_session(&conn).map(Some)
    }

    /// Empties every session table in one transaction.
    ///
    /// On failure the transaction is rolled back so the connection is not
    /// left inside an open transaction.
    fn delete(&self) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        let outcome = conn.execute_batch(
            "BEGIN IMMEDIATE;
             DELETE FROM meta;
             DELETE FROM dcs;
             DELETE FROM update_state;
             DELETE FROM channel_pts;
             DELETE FROM peers;
             DELETE FROM min_peers;
             COMMIT;",
        );
        if outcome.is_err() {
            // Best effort; fails harmlessly if no transaction is active.
            let _ = conn.execute_batch("ROLLBACK");
        }
        outcome.map_err(Self::map_err)
    }

    fn name(&self) -> &str {
        &self.label
    }

    // Granular overrides (targeted SQL, no full round-trip)

    /// Upserts one DC row keyed by `(dc_id, flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
             ON CONFLICT(dc_id, flags) DO UPDATE SET
               addr        = excluded.addr,
               auth_key    = excluded.auth_key,
               first_salt  = excluded.first_salt,
               time_offset = excluded.time_offset",
            rusqlite::params![
                entry.dc_id,
                entry.flags.0,
                entry.addr,
                entry.auth_key.as_ref().map(|k| &k[..]),
                entry.first_salt,
                entry.time_offset,
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Upserts the `home_dc_id` row in `meta`.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![dc_id],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Persists one update-state change.
    ///
    /// `Primary`, `Secondary` and `Channel` are single upserts. `All`
    /// overwrites the counters and replaces every `channel_pts` row; those
    /// statements run in one transaction so a failure part-way through
    /// cannot leave the channel table emptied or partially written.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        match update {
            UpdateStateChange::All(snap) => {
                let outcome = (|| -> rusqlite::Result<()> {
                    conn.execute_batch("BEGIN IMMEDIATE")?;
                    conn.execute(
                        "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
                         ON CONFLICT(id) DO UPDATE SET
                           pts=excluded.pts, qts=excluded.qts,
                           date=excluded.date, seq=excluded.seq",
                        rusqlite::params![snap.pts, snap.qts, snap.date, snap.seq],
                    )?;
                    conn.execute("DELETE FROM channel_pts", [])?;
                    for &(cid, cpts) in &snap.channels {
                        conn.execute(
                            "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
                            rusqlite::params![cid, cpts],
                        )?;
                    }
                    conn.execute_batch("COMMIT")
                })();
                if outcome.is_err() {
                    // Best effort; fails harmlessly if no transaction is active.
                    let _ = conn.execute_batch("ROLLBACK");
                }
                outcome.map_err(Self::map_err)
            }
            UpdateStateChange::Primary { pts, date, seq } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,0,?2,?3)
                     ON CONFLICT(id) DO UPDATE SET pts=excluded.pts, date=excluded.date,
                     seq=excluded.seq",
                    rusqlite::params![pts, date, seq],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Secondary { qts } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,0,?1,0,0)
                     ON CONFLICT(id) DO UPDATE SET qts = excluded.qts",
                    rusqlite::params![qts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Channel { id, pts } => conn
                .execute(
                    "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)
                     ON CONFLICT(channel_id) DO UPDATE SET pts = excluded.pts",
                    rusqlite::params![id, pts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
        }
    }

    /// Upserts one peer row keyed by `id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO peers (id, access_hash, is_channel, is_chat) VALUES (?1, ?2, ?3, ?4)
             ON CONFLICT(id) DO UPDATE SET
               access_hash = excluded.access_hash,
               is_channel  = excluded.is_channel,
               is_chat     = excluded.is_chat",
            rusqlite::params![
                peer.id,
                peer.access_hash,
                peer.is_channel as i32,
                peer.is_chat as i32
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }
}
1718
1719// LibSqlBackend
1720
/// libSQL-backed session (Turso / embedded replica / in-process).
///
/// Enabled with the `libsql-session` Cargo feature.
///
/// The libSQL API is async; since [`SessionBackend`] methods are sync we
/// block via `tokio::runtime::Handle::current().block_on(...)`.  Always
/// call from inside a Tokio runtime (i.e. the same runtime as the rest of
/// `ferogram`).
///
/// NOTE(review): per the Tokio documentation, `Handle::current()` panics
/// outside a runtime context and `Handle::block_on` panics when called from
/// within an asynchronous execution context. Callers presumably need to be
/// on a runtime-aware blocking thread (e.g. `spawn_blocking`) rather than
/// directly inside an async task; confirm against the call sites.
///
/// # Connecting
///
/// | Mode              | Constructor                        |
/// |-------------------|------------------------------------|
/// | Local file        | `LibSqlBackend::open_local(path)`  |
/// | In-memory         | `LibSqlBackend::in_memory()`       |
/// | Turso remote      | `LibSqlBackend::open_remote(url, token)` |
/// | Embedded replica  | `LibSqlBackend::open_replica(path, url, token)` |
#[cfg(feature = "libsql-session")]
pub struct LibSqlBackend {
    /// Connection that the backend's statements are executed on.
    conn: libsql::Connection,
    /// Human-readable identifier for this database (`open_local` uses the
    /// displayed file path).
    label: String,
}
1743
#[cfg(feature = "libsql-session")]
impl LibSqlBackend {
    /// DDL applied every time a database is opened. Every statement is
    /// `CREATE TABLE IF NOT EXISTS`, so re-applying it to an existing
    /// database leaves the stored data untouched.
    const SCHEMA: &'static str = "
        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT    PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );
        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );
        CREATE TABLE IF NOT EXISTS peers (
            id          INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel  INTEGER NOT NULL DEFAULT 0,
            is_chat     INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS min_peers (
            user_id INTEGER PRIMARY KEY,
            peer_id INTEGER NOT NULL,
            msg_id  INTEGER NOT NULL
        );
    ";

    /// Drives `fut` to completion on the current Tokio runtime and converts a
    /// libSQL error into an [`io::Error`].
    ///
    /// # Errors
    ///
    /// Fails if the calling thread has no Tokio runtime context, or if `fut`
    /// itself resolves to an error.
    ///
    /// # Panics
    ///
    /// `Handle::block_on` is documented by Tokio to panic when invoked from
    /// within an asynchronous execution context, so callers must reach this
    /// from a blocking context.
    fn block<F, T>(fut: F) -> io::Result<T>
    where
        F: std::future::Future<Output = Result<T, libsql::Error>>,
    {
        // `try_current` reports a missing runtime as an error rather than
        // panicking the way `Handle::current()` does.
        let handle = tokio::runtime::Handle::try_current().map_err(io::Error::other)?;
        handle.block_on(fut).map_err(io::Error::other)
    }

    /// Creates all tables used by this backend if they do not exist yet.
    async fn apply_schema(conn: &libsql::Connection) -> Result<(), libsql::Error> {
        conn.execute_batch(Self::SCHEMA).await
    }

    /// Shared tail of every constructor: connects to `db`, applies the
    /// schema, and wraps the connection for shared use.
    fn from_database(db: libsql::Database, label: String) -> io::Result<Self> {
        let conn = Self::block(async { db.connect() })?;
        Self::block(Self::apply_schema(&conn))?;
        Ok(Self {
            conn: std::sync::Arc::new(tokio::sync::Mutex::new(conn)),
            label,
        })
    }

    /// Open a local file database.
    ///
    /// # Errors
    ///
    /// Returns an error if the database cannot be built, connected to, or
    /// initialised with the schema.
    pub fn open_local(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let db = Self::block(async { libsql::Builder::new_local(path).build().await })?;
        Self::from_database(db, label)
    }

    /// Open an in-process in-memory database (useful for tests).
    ///
    /// # Errors
    ///
    /// Returns an error if the database cannot be built, connected to, or
    /// initialised with the schema.
    pub fn in_memory() -> io::Result<Self> {
        let db = Self::block(async { libsql::Builder::new_local(":memory:").build().await })?;
        Self::from_database(db, ":memory:".into())
    }

    /// Connect to a remote Turso database.
    ///
    /// # Errors
    ///
    /// Returns an error if the database cannot be built, connected to, or
    /// initialised with the schema.
    pub fn open_remote(url: impl Into<String>, auth_token: impl Into<String>) -> io::Result<Self> {
        let url = url.into();
        let label = url.clone();
        let db = Self::block(async {
            libsql::Builder::new_remote(url, auth_token.into())
                .build()
                .await
        })?;
        Self::from_database(db, label)
    }

    /// Open an embedded replica (local file + Turso remote sync).
    ///
    /// # Errors
    ///
    /// Returns an error if the database cannot be built, connected to, or
    /// initialised with the schema.
    pub fn open_replica(
        path: impl Into<PathBuf>,
        url: impl Into<String>,
        auth_token: impl Into<String>,
    ) -> io::Result<Self> {
        let path = path.into();
        // Convert once: `url` is consumed by `into()`, and the value is
        // needed both for the label and for the builder.
        let url = url.into();
        let label = format!("{} (replica of {})", path.display(), url);
        let db = Self::block(async {
            libsql::Builder::new_remote_replica(path, url, auth_token.into())
                .build()
                .await
        })?;
        Self::from_database(db, label)
    }

    /// Reads every table and assembles a [`PersistedSession`].
    ///
    /// A missing `home_dc_id` row yields `0`, and a missing `update_state`
    /// row yields `UpdatesStateSnap::default()`.
    ///
    /// # Errors
    ///
    /// Returns any query/decoding error, and `libsql::Error::Misuse` when a
    /// stored `auth_key` is not exactly 256 bytes or a `flags` value does not
    /// fit in a `u8`.
    async fn read_session_async(
        conn: &libsql::Connection,
    ) -> Result<PersistedSession, libsql::Error> {
        // home_dc_id
        let home_dc_id: i32 = conn
            .query("SELECT value FROM meta WHERE key = 'home_dc_id'", ())
            .await?
            .next()
            .await?
            .map(|r| r.get::<i32>(0))
            .transpose()?
            .unwrap_or(0);

        // dcs
        let mut rows = conn
            .query(
                "SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs",
                (),
            )
            .await?;
        let mut dcs = Vec::new();
        while let Some(row) = rows.next().await? {
            let dc_id: i32 = row.get(0)?;
            // Reject out-of-range values instead of silently truncating them.
            let flags_raw = u8::try_from(row.get::<i64>(1)?).map_err(|_| {
                libsql::Error::Misuse("dc flags value does not fit in u8".to_string())
            })?;
            let addr: String = row.get(2)?;
            let key_blob: Option<Vec<u8>> = row.get(3)?;
            let first_salt: i64 = row.get(4)?;
            let time_offset: i32 = row.get(5)?;
            let auth_key = match key_blob {
                Some(b) if b.len() == 256 => {
                    let mut k = [0u8; 256];
                    k.copy_from_slice(&b);
                    Some(k)
                }
                Some(b) => {
                    return Err(libsql::Error::Misuse(format!(
                        "auth_key blob must be 256 bytes, got {}",
                        b.len()
                    )));
                }
                None => None,
            };
            dcs.push(DcEntry {
                dc_id,
                addr,
                auth_key,
                first_salt,
                time_offset,
                flags: DcFlags(flags_raw),
            });
        }

        // update_state (single row with id = 1)
        let mut us_row = conn
            .query(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                (),
            )
            .await?;
        let mut updates_state = if let Some(r) = us_row.next().await? {
            UpdatesStateSnap {
                pts: r.get(0)?,
                qts: r.get(1)?,
                date: r.get(2)?,
                seq: r.get(3)?,
                channels: Vec::new(),
            }
        } else {
            UpdatesStateSnap::default()
        };

        // channel_pts
        let mut ch_rows = conn
            .query("SELECT channel_id, pts FROM channel_pts", ())
            .await?;
        let mut channels = Vec::new();
        while let Some(r) = ch_rows.next().await? {
            channels.push((r.get::<i64>(0)?, r.get::<i32>(1)?));
        }
        updates_state.channels = channels;

        // peers
        let mut peer_rows = conn
            .query("SELECT id, access_hash, is_channel, is_chat FROM peers", ())
            .await?;
        let mut peers = Vec::new();
        while let Some(r) = peer_rows.next().await? {
            peers.push(CachedPeer {
                id: r.get(0)?,
                access_hash: r.get(1)?,
                is_channel: r.get::<i32>(2)? != 0,
                is_chat: r.get::<i32>(3)? != 0,
            });
        }

        // min_peers
        let mut min_rows = conn
            .query("SELECT user_id, peer_id, msg_id FROM min_peers", ())
            .await?;
        let mut min_peers = Vec::new();
        while let Some(r) = min_rows.next().await? {
            min_peers.push(CachedMinPeer {
                user_id: r.get(0)?,
                peer_id: r.get(1)?,
                msg_id: r.get(2)?,
            });
        }

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state,
            peers,
            min_peers,
        })
    }

    /// Persists the whole session inside one `BEGIN IMMEDIATE` transaction.
    ///
    /// `dcs`, `channel_pts`, `peers` and `min_peers` are replaced wholesale;
    /// `pts`/`qts` are only ever moved forward (`MAX` of stored and new).
    ///
    /// # Errors
    ///
    /// Returns the first statement error. In that case a `ROLLBACK` is
    /// attempted so the connection is not left inside an open transaction.
    async fn write_session_async(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute_batch("BEGIN IMMEDIATE").await?;

        let outcome = Self::write_rows_and_commit(conn, s).await;
        if outcome.is_err() {
            // Best effort: the original error is what the caller needs to
            // see, and the engine may already have rolled back by itself.
            let _ = conn.execute_batch("ROLLBACK").await;
        }
        outcome
    }

    /// Body of the save transaction: writes every table, then commits.
    /// Must be called with a transaction already open on `conn`.
    async fn write_rows_and_commit(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            libsql::params![s.home_dc_id],
        )
        .await?;

        conn.execute("DELETE FROM dcs", ()).await?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                libsql::params![
                    d.dc_id,
                    d.flags.0 as i64,
                    d.addr.clone(),
                    d.auth_key.map(|k| k.to_vec()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .await?;
        }

        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
             ON CONFLICT(id) DO UPDATE SET
               pts  = MAX(excluded.pts,  update_state.pts),
               qts  = MAX(excluded.qts,  update_state.qts),
               date = excluded.date,
               seq  = excluded.seq",
            libsql::params![us.pts, us.qts, us.date, us.seq],
        )
        .await?;

        conn.execute("DELETE FROM channel_pts", ()).await?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1,?2)",
                libsql::params![cid, cpts],
            )
            .await?;
        }

        conn.execute("DELETE FROM peers", ()).await?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel, is_chat) VALUES (?1,?2,?3,?4)",
                libsql::params![p.id, p.access_hash, p.is_channel as i32, p.is_chat as i32],
            )
            .await?;
        }

        conn.execute("DELETE FROM min_peers", ()).await?;
        for m in &s.min_peers {
            conn.execute(
                "INSERT INTO min_peers (user_id, peer_id, msg_id) VALUES (?1,?2,?3)",
                libsql::params![m.user_id, m.peer_id, m.msg_id],
            )
            .await?;
        }

        conn.execute_batch("COMMIT").await?;
        Ok(())
    }
}
2052
#[cfg(feature = "libsql-session")]
impl SessionBackend for LibSqlBackend {
    /// Writes the complete session in a single transaction.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        // `block` does not need a `'static` future, so the session is
        // borrowed rather than cloned.
        Self::block(async {
            let conn = self.conn.lock().await;
            Self::write_session_async(&conn, session).await
        })
    }

    /// Loads the stored session, or `Ok(None)` when the `meta` table is
    /// empty (nothing has been saved yet, or the session was deleted).
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        Self::block(async {
            // One lock acquisition covers both the emptiness probe and the
            // read, so a concurrent `delete` cannot slip in between them.
            let conn = self.conn.lock().await;
            let count = {
                let mut rows = conn.query("SELECT COUNT(*) FROM meta", ()).await?;
                match rows.next().await? {
                    Some(row) => row.get::<i64>(0)?,
                    None => 0,
                }
            };
            if count == 0 {
                Ok(None)
            } else {
                Self::read_session_async(&conn).await.map(Some)
            }
        })
    }

    /// Removes every stored row from all session tables atomically.
    fn delete(&self) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            let outcome = conn
                .execute_batch(
                    "BEGIN IMMEDIATE;
                 DELETE FROM meta;
                 DELETE FROM dcs;
                 DELETE FROM update_state;
                 DELETE FROM channel_pts;
                 DELETE FROM peers;
                 DELETE FROM min_peers;
                 COMMIT;",
                )
                .await;
            if outcome.is_err() {
                // Best effort: do not leave the shared connection stuck
                // inside a half-finished transaction.
                let _ = conn.execute_batch("ROLLBACK").await;
            }
            outcome
        })
    }

    /// Returns the label chosen when the backend was opened.
    fn name(&self) -> &str {
        &self.label
    }

    // Granular overrides

    /// Inserts or updates the row for `(entry.dc_id, entry.flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let (dc_id, addr, key, salt, off, flags) = (
            entry.dc_id,
            entry.addr.clone(),
            entry.auth_key.map(|k| k.to_vec()),
            entry.first_salt,
            entry.time_offset,
            entry.flags.0 as i64,
        );
        Self::block(async {
            let conn = self.conn.lock().await;
            // Placeholder numbers follow the order of `params!` below, which
            // is why `flags` is bound as ?6 although it is the second column.
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1,?6,?2,?3,?4,?5)
                 ON CONFLICT(dc_id, flags) DO UPDATE SET
                   addr=excluded.addr, auth_key=excluded.auth_key,
                   first_salt=excluded.first_salt, time_offset=excluded.time_offset",
                libsql::params![dc_id, addr, key, salt, off, flags],
            )
            .await
            .map(|_| ())
        })
    }

    /// Stores `dc_id` as the home DC in the `meta` table.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO meta (key, value) VALUES ('home_dc_id',?1)
                 ON CONFLICT(key) DO UPDATE SET value=excluded.value",
                libsql::params![dc_id],
            )
            .await
            .map(|_| ())
        })
    }

    /// Applies a partial or full update-state change.
    ///
    /// `All` replaces the counters and the whole `channel_pts` table inside a
    /// transaction; the other variants are single upserts touching only the
    /// columns they carry.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        Self::block(async {
            let conn = self.conn.lock().await;
            match update {
                UpdateStateChange::All(snap) => {
                    // Several statements: run them atomically so a failure
                    // cannot leave `channel_pts` emptied but not refilled.
                    conn.execute_batch("BEGIN IMMEDIATE").await?;
                    let applied = async {
                        conn.execute(
                            "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,?2,?3,?4)
                         ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,qts=excluded.qts,
                         date=excluded.date,seq=excluded.seq",
                            libsql::params![snap.pts, snap.qts, snap.date, snap.seq],
                        )
                        .await?;
                        conn.execute("DELETE FROM channel_pts", ()).await?;
                        for &(cid, cpts) in &snap.channels {
                            conn.execute(
                                "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)",
                                libsql::params![cid, cpts],
                            )
                            .await?;
                        }
                        conn.execute_batch("COMMIT").await?;
                        Ok::<(), libsql::Error>(())
                    }
                    .await;
                    if applied.is_err() {
                        // Best effort; the original error is returned.
                        let _ = conn.execute_batch("ROLLBACK").await;
                    }
                    applied
                }
                UpdateStateChange::Primary { pts, date, seq } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,0,?2,?3)
                         ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,date=excluded.date,
                         seq=excluded.seq",
                        libsql::params![pts, date, seq],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Secondary { qts } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,0,?1,0,0)
                         ON CONFLICT(id) DO UPDATE SET qts=excluded.qts",
                        libsql::params![qts],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Channel { id, pts } => conn
                    .execute(
                        "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)
                         ON CONFLICT(channel_id) DO UPDATE SET pts=excluded.pts",
                        libsql::params![id, pts],
                    )
                    .await
                    .map(|_| ()),
            }
        })
    }

    /// Inserts or updates a single cached peer, keyed by `peer.id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let (id, hash, is_ch, is_ct) = (
            peer.id,
            peer.access_hash,
            peer.is_channel as i32,
            peer.is_chat as i32,
        );
        Self::block(async {
            let conn = self.conn.lock().await;
            conn.execute(
                "INSERT INTO peers (id,access_hash,is_channel,is_chat) VALUES (?1,?2,?3,?4)
                 ON CONFLICT(id) DO UPDATE SET
                   access_hash=excluded.access_hash,
                   is_channel=excluded.is_channel,
                   is_chat=excluded.is_chat",
                libsql::params![id, hash, is_ch, is_ct],
            )
            .await
            .map(|_| ())
        })
    }
}