Skip to main content

ferogram_session/
lib.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2//
3// ferogram: async Telegram MTProto client in Rust
4// https://github.com/ankit-chaubey/ferogram
5//
6// Licensed under either the MIT License or the Apache License 2.0.
7// See the LICENSE-MIT or LICENSE-APACHE file in this repository:
8// https://github.com/ankit-chaubey/ferogram
9//
10// Feel free to use, modify, and share this code.
11// Please keep this notice when redistributing.
12
13#![deny(unsafe_code)]
14#![cfg_attr(docsrs, feature(doc_cfg))]
15#![doc(html_root_url = "https://docs.rs/ferogram-session/0.6.3")]
16//! Session persistence types and storage backends for ferogram.
17//!
18//! This crate is part of [ferogram](https://crates.io/crates/ferogram), an async Rust
19//! MTProto client built by [Ankit Chaubey](https://github.com/ankit-chaubey).
20//!
21//! - Channel: [t.me/Ferogram](https://t.me/Ferogram)
22//! - Chat: [t.me/FerogramChat](https://t.me/FerogramChat)
23//!
24//! Most users do not use this crate directly. The `ferogram` crate wires it up
25//! automatically when you call `Client::builder().session(...)` or
26//! `.session_string(...)`.
27//!
28//! # What's in here
29//!
30//! - [`PersistedSession`]: the serializable session struct. Holds the DC table
31//!   (one `AuthKey` + salt + time offset per DC), update sequence counters
32//!   (PTS, QTS, date, seq), and the peer access-hash cache.
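//! - [`SessionBackend`]: the trait all backends implement. Required methods
//!   are `save(&PersistedSession)`, `load() -> io::Result<Option<PersistedSession>>`,
//!   `delete()` and `name()`; granular helpers such as `update_dc` and
//!   `cache_peer` come with default load-modify-save implementations.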
35//! - [`BinaryFileBackend`]: stores the session as a binary file on disk.
36//!   Good for bots and scripts. No extra dependencies.
37//! - [`InMemoryBackend`]: keeps everything in memory. Nothing survives process
38//!   exit. Used for tests and ephemeral tasks.
39//! - [`StringSessionBackend`]: serializes the session to a base64 string.
40//!   Useful for serverless environments where you store state in an env var or
41//!   database column. Load it with `Client::builder().session_string(s)`.
42//! - `SqliteBackend`: SQLite-backed storage via rusqlite. Behind the
43//!   `sqlite-session` feature flag. Good for local multi-account tooling.
44//! - `LibSqlBackend`: libSQL / Turso backend. Behind `libsql-session`.
45//!   For distributed or edge-hosted session storage.
46//!
47//! You can also implement `SessionBackend` yourself for Redis, PostgreSQL, or
48//! anything else.
49//!
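//! # Binary format
//!
//! The binary encoding (used for session files and, base64-encoded, for
//! session strings) starts with a version byte:
//! - `0x02`: DC table + update state + peer cache.
//! - `0x03`: adds a [`DcFlags`] byte to each DC entry.
//! - `0x04`: adds min-user message contexts.
//! - `0x05`: the peer type byte can also mark regular group chats.
//! - `0x06`: adds per-channel `ChannelKind` byte to each peer entry.
//!
//! The legacy v1 layout (DC table only, no update state or peer cache) has
//! no version byte; any other leading byte is read as the start of v1 data.
//!
//! `load()` handles all versions. `save()` always writes v6.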
58//!
59//! # Example: export and re-import a session
60//!
61//! ```rust,ignore
62//! # async fn example(client: ferogram::Client) -> anyhow::Result<()> {
63//! // Export
64//! let s = client.export_session_string().await?;
65//!
66//! // Later, in another process or after a restart:
67//! let (client, _) = ferogram::Client::builder()
68//!     .api_id(12345)
69//!     .api_hash("api_hash")
70//!     .session_string(s)
71//!     .connect()
72//!     .await?;
73//! # Ok(())
74//! # }
75//! ```
76
77use std::collections::HashMap;
78use std::io::{self, ErrorKind};
79use std::path::Path;
80
#[cfg(feature = "serde")]
mod auth_key_serde {
    use serde::{Deserialize, Deserializer, Serializer};

    /// Serializer half of the `serde(with = "auth_key_serde")` adapter.
    ///
    /// `[u8; 256]` has no built-in `Serialize` impl, so a present key is
    /// emitted as a plain sequence of bytes and an absent key as `none`.
    pub fn serialize<S>(value: &Option<[u8; 256]>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(key) = value {
            s.serialize_some(&key[..])
        } else {
            s.serialize_none()
        }
    }

    /// Deserializer half of the adapter: reads an optional byte sequence and
    /// turns it back into the fixed-size key array.
    ///
    /// # Errors
    /// Fails with a custom serde error when the sequence is not exactly
    /// 256 bytes long.
    pub fn deserialize<'de, D>(d: D) -> Result<Option<[u8; 256]>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let Some(bytes) = Option::<Vec<u8>>::deserialize(d)? else {
            return Ok(None);
        };
        <[u8; 256]>::try_from(bytes).map(Some).map_err(|_| {
            <D::Error as serde::de::Error>::custom("auth_key must be exactly 256 bytes")
        })
    }
}
115
/// Bit set of per-DC option flags (IPv6, media-only, CDN, ...).
///
/// Persisted with every DC entry from session format v3 onwards, so media
/// DCs and other special endpoints are remembered across restarts.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// Empty set: no flag bits.
    pub const NONE: DcFlags = DcFlags(0);
    /// Bit 0: the endpoint is an IPv6 address.
    pub const IPV6: DcFlags = DcFlags(0b0000_0001);
    /// Bit 1: media-only endpoint.
    pub const MEDIA_ONLY: DcFlags = DcFlags(0b0000_0010);
    /// Bit 2: `tcpo_only` endpoint.
    pub const TCPO_ONLY: DcFlags = DcFlags(0b0000_0100);
    /// Bit 3: CDN endpoint.
    pub const CDN: DcFlags = DcFlags(0b0000_1000);
    /// Bit 4: static endpoint.
    pub const STATIC: DcFlags = DcFlags(0b0001_0000);

    /// Returns `true` if `self` has every bit that `other` has.
    ///
    /// `x.contains(DcFlags::NONE)` is therefore always `true`.
    pub fn contains(self, other: DcFlags) -> bool {
        // No bit of `other` may be missing from `self`.
        other.0 & !self.0 == 0
    }

    /// Adds the bit(s) of `flag` to `self`, keeping the bits already set.
    pub fn set(&mut self, flag: DcFlags) {
        *self = DcFlags(self.0 | flag.0);
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = DcFlags;

    /// Union of two flag sets.
    fn bitor(self, rhs: DcFlags) -> DcFlags {
        let DcFlags(lhs) = self;
        DcFlags(lhs | rhs.0)
    }
}
148
149/// One entry in the DC address table.
150#[derive(Clone, Debug)]
151#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
152pub struct DcEntry {
153    pub dc_id: i32,
154    pub addr: String,
155    #[cfg_attr(feature = "serde", serde(with = "auth_key_serde"))]
156    pub auth_key: Option<[u8; 256]>,
157    pub first_salt: i64,
158    pub time_offset: i32,
159    /// DC capability flags (IPv6, media-only, CDN, ...).
160    pub flags: DcFlags,
161}
162
163impl DcEntry {
164    /// Returns `true` when this entry represents an IPv6 address.
165    #[inline]
166    pub fn is_ipv6(&self) -> bool {
167        self.flags.contains(DcFlags::IPV6)
168    }
169
170    /// Parse the stored `"ip:port"` / `"[ipv6]:port"` address into a
171    /// [`std::net::SocketAddr`].
172    ///
173    /// Both formats are valid:
174    /// - IPv4: `"149.154.175.53:443"`
175    /// - IPv6: `"[2001:b28:f23d:f001::a]:443"`
176    pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
177        self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
178            io::Error::new(
179                io::ErrorKind::InvalidData,
180                format!("invalid DC address: {:?}", self.addr),
181            )
182        })
183    }
184
185    /// Construct a `DcEntry` from separate IP string, port, and flags.
186    ///
187    /// IPv6 addresses are automatically wrapped in brackets so that
188    /// `socket_addr()` can round-trip them correctly:
189    ///
190    /// ```text
191    /// DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6)
192    /// // addr = "[2001:b28:f23d:f001::a]:443"
193    /// ```
194    ///
195    /// This is the preferred constructor when processing `help.getConfig`
196    /// `DcOption` objects from the Telegram API.
197    pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
198        // IPv6 addresses contain colons; wrap in brackets for SocketAddr compat.
199        let addr = if ip.contains(':') {
200            format!("[{ip}]:{port}")
201        } else {
202            format!("{ip}:{port}")
203        };
204        Self {
205            dc_id,
206            addr,
207            auth_key: None,
208            first_salt: 0,
209            time_offset: 0,
210            flags,
211        }
212    }
213}
214
/// Persisted snapshot of the MTProto update-sequence state.
///
/// Saved so that `catch_up: true` can call `updates.getDifference` starting
/// from the *pre-shutdown* pts.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Main persistence counter (messages, non-channel updates).
    pub pts: i32,
    /// Secondary counter for secret chats.
    pub qts: i32,
    /// Date of the last received update (Unix timestamp).
    pub date: i32,
    /// Combined-container sequence number.
    pub seq: i32,
    /// Per-channel persistence counters as `(channel_id, pts)` pairs.
    /// [`Self::set_channel_pts`] keeps at most one pair per channel.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// Returns `true` when we have a real state from the server (`pts > 0`).
    #[inline]
    pub fn is_initialised(&self) -> bool {
        self.pts > 0
    }

    /// Set (or insert) the stored pts of a channel.
    ///
    /// The new value always replaces the old one, even when it is smaller;
    /// callers that need monotonic pts must compare before calling.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self.channels.iter().position(|&(id, _)| id == channel_id) {
            Some(idx) => self.channels[idx].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Stored pts of a channel, or 0 when the channel is unknown.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        self.channels
            .iter()
            .find_map(|&(id, pts)| (id == channel_id).then_some(pts))
            .unwrap_or(0)
    }
}
257
/// Kind of a Telegram channel.
///
/// Persisted in the session so that `is_megagroup()` / `is_broadcast()` give
/// the right answer after a restart. The explicit discriminants are the byte
/// values written to the session.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum ChannelKind {
    /// A broadcast channel (posts only).
    Broadcast = 0,
    /// A supergroup / megagroup (all members can post).
    Megagroup = 1,
    /// A gigagroup / broadcast group (large public broadcast supergroup).
    Gigagroup = 2,
}

impl ChannelKind {
    /// Every variant, in discriminant order.
    const ALL: [ChannelKind; 3] = [Self::Broadcast, Self::Megagroup, Self::Gigagroup];

    /// Decode a stored byte. Unknown values yield `None` instead of a guess.
    pub fn from_byte(b: u8) -> Option<Self> {
        Self::ALL.iter().copied().find(|kind| kind.to_byte() == b)
    }

    /// Byte stored in the session for this kind; the inverse of
    /// [`Self::from_byte`].
    pub fn to_byte(self) -> u8 {
        self as u8
    }
}
289
290/// A cached access-hash entry so that the peer can be addressed across restarts
291/// without re-resolving it from Telegram.
292#[derive(Clone, Debug)]
293#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
294pub struct CachedPeer {
295    /// Bare Telegram peer ID (always positive).
296    pub id: i64,
297    /// Access hash bound to the current session.
298    /// Always 0 for regular group chats (they need no access_hash).
299    pub access_hash: i64,
300    /// `true` → channel / supergroup.  `false` → user or regular group.
301    pub is_channel: bool,
302    /// `true` → regular group chat (Chat::Chat / ChatForbidden).
303    /// When true, access_hash is meaningless (groups need no hash).
304    pub is_chat: bool,
305    /// For channel peers: the kind (broadcast / megagroup / gigagroup).
306    /// `None` for users, regular groups, and channels loaded from a pre-v6 session.
307    pub channel_kind: Option<ChannelKind>,
308}
309
/// Context for a user that was only seen with `min=true`.
///
/// The access hash of such a user cannot be used directly, so the peer and
/// message in which the user appeared are stored instead, allowing an
/// `InputPeerUserFromMessage` to be built after a restart.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedMinPeer {
    /// ID of the min user.
    pub user_id: i64,
    /// Channel / chat / user ID of the peer whose message mentioned the user.
    pub peer_id: i64,
    /// ID of that message within `peer_id`.
    pub msg_id: i32,
}
323
324/// Everything that needs to survive a process restart.
325#[derive(Clone, Debug, Default)]
326#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
327pub struct PersistedSession {
328    pub home_dc_id: i32,
329    pub dcs: Vec<DcEntry>,
330    /// Update counters to enable reliable catch-up after a disconnect.
331    pub updates_state: UpdatesStateSnap,
332    /// Peer access-hash cache so that the client can reach out to any previously
333    /// seen user or channel without re-resolving them.
334    pub peers: Vec<CachedPeer>,
335    /// Min-user message contexts: users seen with `min=true` that can only be
336    /// addressed via `InputPeerUserFromMessage`.
337    pub min_peers: Vec<CachedMinPeer>,
338}
339
340impl PersistedSession {
341    /// Encode the session to raw bytes (v2 binary format).
342    pub fn to_bytes(&self) -> Vec<u8> {
343        let mut b = Vec::with_capacity(512);
344
345        b.push(0x06u8); // version
346
347        b.extend_from_slice(&self.home_dc_id.to_le_bytes());
348
349        b.push(self.dcs.len() as u8);
350        for d in &self.dcs {
351            b.extend_from_slice(&d.dc_id.to_le_bytes());
352            match &d.auth_key {
353                Some(k) => {
354                    b.push(1);
355                    b.extend_from_slice(k);
356                }
357                None => {
358                    b.push(0);
359                }
360            }
361            b.extend_from_slice(&d.first_salt.to_le_bytes());
362            b.extend_from_slice(&d.time_offset.to_le_bytes());
363            let ab = d.addr.as_bytes();
364            b.push(ab.len() as u8);
365            b.extend_from_slice(ab);
366            b.push(d.flags.0);
367        }
368
369        b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
370        b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
371        b.extend_from_slice(&self.updates_state.date.to_le_bytes());
372        b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
373        let ch = &self.updates_state.channels;
374        b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
375        for &(cid, cpts) in ch {
376            b.extend_from_slice(&cid.to_le_bytes());
377            b.extend_from_slice(&cpts.to_le_bytes());
378        }
379
380        // v6 peer encoding: peer_type byte (0=user,1=channel,2=chat) + channel_kind byte
381        b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
382        for p in &self.peers {
383            b.extend_from_slice(&p.id.to_le_bytes());
384            b.extend_from_slice(&p.access_hash.to_le_bytes());
385            let peer_type: u8 = if p.is_chat {
386                2
387            } else if p.is_channel {
388                1
389            } else {
390                0
391            };
392            b.push(peer_type);
393            // channel_kind byte: 0xFF = absent, otherwise ChannelKind::to_byte()
394            b.push(p.channel_kind.map(|k| k.to_byte()).unwrap_or(0xFF));
395        }
396
397        b.extend_from_slice(&(self.min_peers.len() as u16).to_le_bytes());
398        for m in &self.min_peers {
399            b.extend_from_slice(&m.user_id.to_le_bytes());
400            b.extend_from_slice(&m.peer_id.to_le_bytes());
401            b.extend_from_slice(&m.msg_id.to_le_bytes());
402        }
403
404        b
405    }
406
407    /// Atomically save the session to `path`.
408    ///
409    /// Writes to `<path>.<seq>.tmp` (unique per call) then renames into place.
410    /// A fixed `.tmp` extension causes OS error 2 (ERROR_FILE_NOT_FOUND) on
411    /// Windows when two concurrent persist_state calls race: thread A renames
412    /// `.tmp` away while thread B's rename finds the source gone.
413    pub fn save(&self, path: &Path) -> io::Result<()> {
414        use std::sync::atomic::{AtomicU64, Ordering};
415        static SEQ: AtomicU64 = AtomicU64::new(0);
416        let n = SEQ.fetch_add(1, Ordering::Relaxed);
417        let tmp = path.with_extension(format!("{n}.tmp"));
418        std::fs::write(&tmp, self.to_bytes())?;
419        std::fs::rename(&tmp, path).inspect_err(|_e| {
420            let _ = std::fs::remove_file(&tmp);
421        })
422    }
423
424    /// Decode a session from raw bytes (v1 or v2 binary format).
425    pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
426        if buf.is_empty() {
427            return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
428        }
429
430        let mut p = 0usize;
431
432        macro_rules! r {
433            ($n:expr) => {{
434                if p + $n > buf.len() {
435                    return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
436                }
437                let s = &buf[p..p + $n];
438                p += $n;
439                s
440            }};
441        }
442        macro_rules! r_i32 {
443            () => {
444                i32::from_le_bytes(r!(4).try_into().unwrap())
445            };
446        }
447        macro_rules! r_i64 {
448            () => {
449                i64::from_le_bytes(r!(8).try_into().unwrap())
450            };
451        }
452        macro_rules! r_u8 {
453            () => {
454                r!(1)[0]
455            };
456        }
457        macro_rules! r_u16 {
458            () => {
459                u16::from_le_bytes(r!(2).try_into().unwrap())
460            };
461        }
462
463        let first_byte = r_u8!();
464
465        let (home_dc_id, version) = if first_byte == 0x06 {
466            (r_i32!(), 6u8)
467        } else if first_byte == 0x05 {
468            (r_i32!(), 5u8)
469        } else if first_byte == 0x04 {
470            (r_i32!(), 4u8)
471        } else if first_byte == 0x03 {
472            (r_i32!(), 3u8)
473        } else if first_byte == 0x02 {
474            (r_i32!(), 2u8)
475        } else {
476            let rest = r!(3);
477            let mut bytes = [0u8; 4];
478            bytes[0] = first_byte;
479            bytes[1..4].copy_from_slice(rest);
480            (i32::from_le_bytes(bytes), 1u8)
481        };
482
483        let dc_count = r_u8!() as usize;
484        let mut dcs = Vec::with_capacity(dc_count);
485        for _ in 0..dc_count {
486            let dc_id = r_i32!();
487            let has_key = r_u8!();
488            let auth_key = if has_key == 1 {
489                let mut k = [0u8; 256];
490                k.copy_from_slice(r!(256));
491                Some(k)
492            } else {
493                None
494            };
495            let first_salt = r_i64!();
496            let time_offset = r_i32!();
497            let al = r_u8!() as usize;
498            let addr = String::from_utf8_lossy(r!(al)).into_owned();
499            let flags = if version >= 3 {
500                DcFlags(r_u8!())
501            } else {
502                DcFlags::NONE
503            };
504            dcs.push(DcEntry {
505                dc_id,
506                addr,
507                auth_key,
508                first_salt,
509                time_offset,
510                flags,
511            });
512        }
513
514        if version < 2 {
515            return Ok(Self {
516                home_dc_id,
517                dcs,
518                updates_state: UpdatesStateSnap::default(),
519                peers: Vec::new(),
520                min_peers: Vec::new(),
521            });
522        }
523
524        let pts = r_i32!();
525        let qts = r_i32!();
526        let date = r_i32!();
527        let seq = r_i32!();
528        let ch_count = r_u16!() as usize;
529        let mut channels = Vec::with_capacity(ch_count);
530        for _ in 0..ch_count {
531            let cid = r_i64!();
532            let cpts = r_i32!();
533            channels.push((cid, cpts));
534        }
535
536        let peer_count = r_u16!() as usize;
537        let mut peers = Vec::with_capacity(peer_count);
538        for _ in 0..peer_count {
539            let id = r_i64!();
540            let access_hash = r_i64!();
541            // v5+: type byte 0=user, 1=channel, 2=chat; v2-v4: 0=user, 1=channel
542            let peer_type = r_u8!();
543            let is_channel = peer_type == 1;
544            let is_chat = peer_type == 2;
545            // v6+: channel_kind byte (0xFF = absent)
546            let channel_kind = if version >= 6 {
547                let kb = r_u8!();
548                if kb == 0xFF {
549                    None
550                } else {
551                    ChannelKind::from_byte(kb)
552                }
553            } else {
554                None
555            };
556            peers.push(CachedPeer {
557                id,
558                access_hash,
559                is_channel,
560                is_chat,
561                channel_kind,
562            });
563        }
564
565        // v4+: min-user contexts
566        let min_peers = if version >= 4 {
567            let count = r_u16!() as usize;
568            let mut v = Vec::with_capacity(count);
569            for _ in 0..count {
570                let user_id = r_i64!();
571                let peer_id = r_i64!();
572                let msg_id = r_i32!();
573                v.push(CachedMinPeer {
574                    user_id,
575                    peer_id,
576                    msg_id,
577                });
578            }
579            v
580        } else {
581            Vec::new()
582        };
583
584        Ok(Self {
585            home_dc_id,
586            dcs,
587            updates_state: UpdatesStateSnap {
588                pts,
589                qts,
590                date,
591                seq,
592                channels,
593            },
594            peers,
595            min_peers,
596        })
597    }
598
599    /// Decode a session from a URL-safe base64 string produced by `to_string`.
600    pub fn from_string(s: &str) -> io::Result<Self> {
601        use base64::Engine as _;
602        let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
603            .decode(s.trim())
604            .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
605        Self::from_bytes(&bytes)
606    }
607
608    /// Read and decode a session from `path` (the binary v2 format written
609    /// by [`Self::to_bytes`]).
610    pub fn load(path: &Path) -> io::Result<Self> {
611        let buf = std::fs::read(path)?;
612        Self::from_bytes(&buf)
613    }
614
615    // DC address helpers
616
617    /// Find the best DC entry for a given DC ID.
618    ///
619    /// When `prefer_ipv6` is `true`, returns the IPv6 entry if one is
620    /// stored, falling back to IPv4.  When `false`, returns IPv4,
621    /// falling back to IPv6.  Returns `None` only when the DC ID is
622    /// completely unknown.
623    ///
624    /// This correctly handles the case where both an IPv4 and an IPv6
625    /// `DcEntry` exist for the same `dc_id` (different `flags` bitmask).
626    pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
627        let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
628        candidates.peek()?;
629        // Collect so we can search twice
630        let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
631        // Preferred family first, fall back to whatever is available
632        cands
633            .iter()
634            .copied()
635            .find(|d| d.is_ipv6() == prefer_ipv6)
636            .or_else(|| cands.first().copied())
637    }
638
639    /// Iterate over every stored DC entry for a given DC ID.
640    ///
641    /// Typically yields one IPv4 and one IPv6 entry per DC ID once
642    /// `help.getConfig` has been applied.
643    pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
644        self.dcs.iter().filter(move |d| d.dc_id == dc_id)
645    }
646}
647
648impl std::fmt::Display for PersistedSession {
649    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
650        use base64::Engine as _;
651        f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
652    }
653}
654
/// Built-in `"ip:port"` addresses for DCs 1-5, keyed by DC ID.
///
/// Used as the bootstrap fallback when `help.getConfig` fails.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const BOOTSTRAP: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(BOOTSTRAP.len());
    for (dc_id, addr) in BOOTSTRAP {
        table.insert(dc_id, addr.to_owned());
    }
    table
}
668
669// session_backend
670//
671// Pluggable session storage backend.
672
673use std::path::PathBuf;
674
675// Core trait (unchanged)
676
677/// Synchronous snapshot backend: saves and loads the full session at once.
678///
679/// All built-in backends implement this. Higher-level code should prefer the
680/// extension methods below (`update_dc`, `set_home_dc`, `update_state`) which
681/// avoid unnecessary full-snapshot writes.
682pub trait SessionBackend: Send + Sync {
683    fn save(&self, session: &PersistedSession) -> io::Result<()>;
684    fn load(&self) -> io::Result<Option<PersistedSession>>;
685    fn delete(&self) -> io::Result<()>;
686
687    /// Human-readable name for logging/debug output.
688    fn name(&self) -> &str;
689
690    // Granular helpers (default: load → mutate → save)
691    //
692    // These default implementations are correct but not optimal.
693    // Backends that store data in a database (SQLite, libsql, Redis) should
694    // override them to issue single-row UPDATE statements instead.
695
696    /// Update a single DC entry without rewriting the entire session.
697    ///
698    /// Typically called after:
699    /// - completing a DH handshake on a new DC (to persist its auth key)
700    /// - receiving updated DC addresses from `help.getConfig`
701    ///
702    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
703        let mut s = self.load()?.unwrap_or_default();
704        // Replace existing entry or append
705        if let Some(existing) = s
706            .dcs
707            .iter_mut()
708            .find(|d| d.dc_id == entry.dc_id && d.is_ipv6() == entry.is_ipv6())
709        {
710            *existing = entry.clone();
711        } else {
712            s.dcs.push(entry.clone());
713        }
714        self.save(&s)
715    }
716
717    /// Change the home DC without touching any other session data.
718    ///
719    /// Called after a successful `*_MIGRATE` redirect: the user's account
720    /// now lives on a different DC.
721    ///
722    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
723        let mut s = self.load()?.unwrap_or_default();
724        s.home_dc_id = dc_id;
725        self.save(&s)
726    }
727
728    /// Apply a single update-sequence change without a full save/load.
729    ///
730    ///
731    /// `update` is the new partial or full state to merge in.
732    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
733        let mut s = self.load()?.unwrap_or_default();
734        update.apply_to(&mut s.updates_state);
735        self.save(&s)
736    }
737
738    /// Cache a peer access hash without a full session save.
739    ///
740    /// This is lossy-on-default (full round-trip) but correct.
741    /// Override in SQL backends to issue a single `INSERT OR REPLACE`.
742    ///
743    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
744        let mut s = self.load()?.unwrap_or_default();
745        if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
746            *existing = peer.clone();
747        } else {
748            s.peers.push(peer.clone());
749        }
750        self.save(&s)
751    }
752}
753
754// UpdateStateChange (mirrors  UpdateState enum)
755
756/// A single update-sequence change, applied via [`SessionBackend::apply_update_state`].
757///
758/// Single variant constructed:
759/// ```text
760/// UpdateState::All(updates_state)
761/// UpdateState::Primary { pts, date, seq }
762/// UpdateState::Secondary { qts }
763/// UpdateState::Channel { id, pts }
764/// ```
765///
766/// `All` carries a full [`UpdatesStateSnap`]; the other variants update one
767/// field of the existing snapshot in place.
768#[derive(Debug, Clone)]
769pub enum UpdateStateChange {
770    /// Replace the entire state snapshot.
771    All(UpdatesStateSnap),
772    /// Update main sequence counters only (non-channel).
773    Primary { pts: i32, date: i32, seq: i32 },
774    /// Update the QTS counter (secret chats).
775    Secondary { qts: i32 },
776    /// Update the PTS for a specific channel.
777    Channel { id: i64, pts: i32 },
778}
779
780impl UpdateStateChange {
781    /// Apply `self` to `snap` in-place.
782    pub fn apply_to(&self, snap: &mut UpdatesStateSnap) {
783        match self {
784            Self::All(new_snap) => *snap = new_snap.clone(),
785            Self::Primary { pts, date, seq } => {
786                snap.pts = *pts;
787                snap.date = *date;
788                snap.seq = *seq;
789            }
790            Self::Secondary { qts } => {
791                snap.qts = *qts;
792            }
793            Self::Channel { id, pts } => {
794                // Replace or insert per-channel pts
795                if let Some(existing) = snap.channels.iter_mut().find(|c| c.0 == *id) {
796                    existing.1 = *pts;
797                } else {
798                    snap.channels.push((*id, *pts));
799                }
800            }
801        }
802    }
803}
804
805// BinaryFileBackend
806
/// Stores the session in a single compact binary file, written with
/// [`PersistedSession::save`] (current format, v6) and read with
/// [`PersistedSession::load`], which also accepts older versions.
pub struct BinaryFileBackend {
    path: PathBuf,
    /// Serialises `save()` calls made through this backend within one
    /// process. `PersistedSession::save` already writes every call to a
    /// uniquely named temporary file, so this is belt-and-suspenders.
    /// `load()` does not take this lock; readers rely on the final rename in
    /// `PersistedSession::save` replacing the file as a whole.
    write_lock: std::sync::Mutex<()>,
}

impl BinaryFileBackend {
    /// Create a backend for `path`. Nothing is touched on disk here: the
    /// file does not need to exist yet and is created by the first save.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        Self {
            path: path.into(),
            write_lock: std::sync::Mutex::new(()),
        }
    }

    /// The file path this backend reads from and writes to.
    pub fn path(&self) -> &std::path::Path {
        &self.path
    }
}
832
833impl SessionBackend for BinaryFileBackend {
834    fn save(&self, session: &PersistedSession) -> io::Result<()> {
835        let _guard = self.write_lock.lock().unwrap();
836        session.save(&self.path)
837    }
838
839    fn load(&self) -> io::Result<Option<PersistedSession>> {
840        if !self.path.exists() {
841            return Ok(None);
842        }
843        match PersistedSession::load(&self.path) {
844            Ok(s) => Ok(Some(s)),
845            Err(e) => {
846                let bak = self.path.with_extension("bak");
847                tracing::warn!(
848                    "[ferogram::session] session file {:?} could not be read ({e}); \
849                     backing it up to {:?} and starting a new session",
850                    self.path,
851                    bak
852                );
853                let _ = std::fs::rename(&self.path, &bak);
854                Ok(None)
855            }
856        }
857    }
858
859    fn delete(&self) -> io::Result<()> {
860        if self.path.exists() {
861            std::fs::remove_file(&self.path)?;
862        }
863        Ok(())
864    }
865
866    fn name(&self) -> &str {
867        "binary-file"
868    }
869
870    // BinaryFileBackend: the default granular impls (load→mutate→save) are
871    // fine since the format is a single compact binary blob. No override needed.
872}
873
874// InMemoryBackend
875
876/// Ephemeral in-process session: nothing persisted to disk.
877///
878/// Override the granular methods to skip the clone overhead of the full
879/// snapshot path (we're already in memory, so direct field mutations are
880/// cheaper than clone→mutate→replace).
881#[derive(Default)]
882pub struct InMemoryBackend {
883    data: std::sync::Mutex<Option<PersistedSession>>,
884}
885
886impl InMemoryBackend {
887    /// An empty in-memory backend; equivalent to [`Default::default`].
888    pub fn new() -> Self {
889        Self::default()
890    }
891
892    /// Test helper: get a snapshot of the current in-memory state.
893    pub fn snapshot(&self) -> Option<PersistedSession> {
894        self.data.lock().unwrap().clone()
895    }
896}
897
898impl SessionBackend for InMemoryBackend {
899    fn save(&self, s: &PersistedSession) -> io::Result<()> {
900        *self.data.lock().unwrap() = Some(s.clone());
901        Ok(())
902    }
903
904    fn load(&self) -> io::Result<Option<PersistedSession>> {
905        Ok(self.data.lock().unwrap().clone())
906    }
907
908    fn delete(&self) -> io::Result<()> {
909        *self.data.lock().unwrap() = None;
910        Ok(())
911    }
912
913    fn name(&self) -> &str {
914        "in-memory"
915    }
916
917    // Granular overrides: cheaper than load→clone→save
918
919    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
920        let mut guard = self.data.lock().unwrap();
921        let s = guard.get_or_insert_with(PersistedSession::default);
922        if let Some(existing) = s
923            .dcs
924            .iter_mut()
925            .find(|d| d.dc_id == entry.dc_id && d.is_ipv6() == entry.is_ipv6())
926        {
927            *existing = entry.clone();
928        } else {
929            s.dcs.push(entry.clone());
930        }
931        Ok(())
932    }
933
934    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
935        let mut guard = self.data.lock().unwrap();
936        let s = guard.get_or_insert_with(PersistedSession::default);
937        s.home_dc_id = dc_id;
938        Ok(())
939    }
940
941    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
942        let mut guard = self.data.lock().unwrap();
943        let s = guard.get_or_insert_with(PersistedSession::default);
944        update.apply_to(&mut s.updates_state);
945        Ok(())
946    }
947
948    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
949        let mut guard = self.data.lock().unwrap();
950        let s = guard.get_or_insert_with(PersistedSession::default);
951        if let Some(existing) = s.peers.iter_mut().find(|p| p.id == peer.id) {
952            *existing = peer.clone();
953        } else {
954            s.peers.push(peer.clone());
955        }
956        Ok(())
957    }
958}
959
960// StringSessionBackend
961
962/// Portable base64 string session backend.
963pub struct StringSessionBackend {
964    data: std::sync::Mutex<String>,
965}
966
967impl StringSessionBackend {
968    /// Seed the backend with an existing session string (as produced by
969    /// [`Self::current`]), or an empty string to start fresh.
970    pub fn new(s: impl Into<String>) -> Self {
971        Self {
972            data: std::sync::Mutex::new(s.into()),
973        }
974    }
975
976    /// The current session, base64-encoded. Updates after every [`save`](SessionBackend::save).
977    pub fn current(&self) -> String {
978        self.data.lock().unwrap().clone()
979    }
980}
981
982impl SessionBackend for StringSessionBackend {
983    fn save(&self, session: &PersistedSession) -> io::Result<()> {
984        *self.data.lock().unwrap() = session.to_string();
985        Ok(())
986    }
987
988    fn load(&self) -> io::Result<Option<PersistedSession>> {
989        let s = self.data.lock().unwrap().clone();
990        if s.trim().is_empty() {
991            return Ok(None);
992        }
993        PersistedSession::from_string(&s).map(Some)
994    }
995
996    fn delete(&self) -> io::Result<()> {
997        *self.data.lock().unwrap() = String::new();
998        Ok(())
999    }
1000
1001    fn name(&self) -> &str {
1002        "string-session"
1003    }
1004}
1005
// String-session encoding module
1007
1008/// Portable compact string-session encoding/decoding (V1/V2 binary base64).
1009///
1010/// Use via `Client::builder().session_string("ABC...")` which auto-detects
1011/// the format. Use this module directly only when you need to inspect or
1012/// construct a [`StringSession`] manually.
1013pub mod string_session;
1014pub use string_session::{FullSession, Session, StringSession, StringSessionError};
1015
#[cfg(test)]
mod tests {
    use super::*;

    fn make_dc(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("1.2.3.{id}:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::NONE,
        }
    }

    fn make_peer(id: i64, hash: i64) -> CachedPeer {
        CachedPeer {
            id,
            access_hash: hash,
            is_channel: false,
            is_chat: false,
            channel_kind: None,
        }
    }

    /// Per-process temp path for file-backend tests. `tag` keeps tests that
    /// run in parallel from sharing a file.
    fn temp_session_path(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "ferogram-session-test-{}-{tag}.session",
            std::process::id()
        ))
    }

    // InMemoryBackend: basic save/load

    #[test]
    fn inmemory_load_returns_none_when_empty() {
        let b = InMemoryBackend::new();
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn inmemory_save_then_load_round_trips() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 3;
        s.dcs.push(make_dc(3));
        b.save(&s).unwrap();

        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.home_dc_id, 3);
        assert_eq!(loaded.dcs.len(), 1);
    }

    #[test]
    fn inmemory_delete_clears_state() {
        let b = InMemoryBackend::new();
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        b.save(&s).unwrap();
        b.delete().unwrap();
        assert!(b.load().unwrap().is_none());
    }

    // InMemoryBackend: granular methods

    #[test]
    fn inmemory_update_dc_inserts_new() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(4)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].dc_id, 4);
    }

    #[test]
    fn inmemory_update_dc_replaces_existing() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap();
        let mut updated = make_dc(2);
        updated.addr = "9.9.9.9:443".to_string();
        b.update_dc(&updated).unwrap();

        let s = b.snapshot().unwrap();
        assert_eq!(s.dcs.len(), 1);
        assert_eq!(s.dcs[0].addr, "9.9.9.9:443");
    }

    #[test]
    fn inmemory_set_home_dc() {
        let b = InMemoryBackend::new();
        b.set_home_dc(5).unwrap();
        assert_eq!(b.snapshot().unwrap().home_dc_id, 5);
    }

    #[test]
    fn inmemory_cache_peer_inserts() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 0xdeadbeef)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].id, 100);
    }

    #[test]
    fn inmemory_cache_peer_updates_existing() {
        let b = InMemoryBackend::new();
        b.cache_peer(&make_peer(100, 111)).unwrap();
        b.cache_peer(&make_peer(100, 222)).unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.peers.len(), 1);
        assert_eq!(s.peers[0].access_hash, 222);
    }

    // UpdateStateChange

    #[test]
    fn update_state_primary() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Primary {
            pts: 10,
            date: 20,
            seq: 30,
        }
        .apply_to(&mut snap);
        assert_eq!(snap.pts, 10);
        assert_eq!(snap.date, 20);
        assert_eq!(snap.seq, 30);
        assert_eq!(snap.qts, 0); // untouched
    }

    #[test]
    fn update_state_secondary() {
        let mut snap = UpdatesStateSnap {
            pts: 5,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Secondary { qts: 99 }.apply_to(&mut snap);
        assert_eq!(snap.qts, 99);
        assert_eq!(snap.pts, 5); // untouched
    }

    #[test]
    fn update_state_channel_inserts() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![],
        };
        UpdateStateChange::Channel { id: 12345, pts: 42 }.apply_to(&mut snap);
        assert_eq!(snap.channels, vec![(12345, 42)]);
    }

    #[test]
    fn update_state_channel_updates_existing() {
        let mut snap = UpdatesStateSnap {
            pts: 0,
            qts: 0,
            date: 0,
            seq: 0,
            channels: vec![(12345, 10), (67890, 5)],
        };
        UpdateStateChange::Channel { id: 12345, pts: 99 }.apply_to(&mut snap);
        // First channel updated, second untouched
        assert_eq!(snap.channels[0], (12345, 99));
        assert_eq!(snap.channels[1], (67890, 5));
    }

    #[test]
    fn apply_update_state_via_backend() {
        let b = InMemoryBackend::new();
        b.apply_update_state(UpdateStateChange::Primary {
            pts: 7,
            date: 8,
            seq: 9,
        })
        .unwrap();
        let s = b.snapshot().unwrap();
        assert_eq!(s.updates_state.pts, 7);
    }

    // Trait-object dispatch. InMemoryBackend overrides update_dc, so this
    // exercises those overrides through `dyn SessionBackend`, not the
    // trait's default load→mutate→save impls.

    #[test]
    fn update_dc_via_trait_object() {
        let b: Box<dyn SessionBackend> = Box::new(InMemoryBackend::new());
        b.update_dc(&make_dc(1)).unwrap();
        b.update_dc(&make_dc(2)).unwrap();
        // snapshot() isn't reachable through the trait object; load() must
        // observe both granular writes.
        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.dcs.len(), 2);
    }

    // StringSessionBackend

    #[test]
    fn string_session_empty_or_blank_loads_none() {
        assert!(StringSessionBackend::new("").load().unwrap().is_none());
        assert!(StringSessionBackend::new("  \n").load().unwrap().is_none());
    }

    #[test]
    fn string_session_save_then_load_round_trips() {
        let b = StringSessionBackend::new("");
        let mut s = PersistedSession::default();
        s.home_dc_id = 4;
        s.dcs.push(make_dc(4));
        b.save(&s).unwrap();
        assert!(!b.current().is_empty());

        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.home_dc_id, 4);
        assert_eq!(loaded.dcs.len(), 1);

        // The exported string can seed a brand-new backend.
        let reseeded = StringSessionBackend::new(b.current());
        assert_eq!(reseeded.load().unwrap().unwrap().home_dc_id, 4);
    }

    #[test]
    fn string_session_delete_clears_state() {
        let b = StringSessionBackend::new("");
        b.save(&PersistedSession::default()).unwrap();
        b.delete().unwrap();
        assert!(b.current().is_empty());
        assert!(b.load().unwrap().is_none());
    }

    // BinaryFileBackend

    #[test]
    fn binary_file_missing_file_loads_none() {
        let path = temp_session_path("missing");
        let _ = std::fs::remove_file(&path);
        let b = BinaryFileBackend::new(path.clone());
        assert!(b.load().unwrap().is_none());
    }

    #[test]
    fn binary_file_save_load_delete() {
        let path = temp_session_path("roundtrip");
        let b = BinaryFileBackend::new(path.clone());
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        s.dcs.push(make_dc(2));
        b.save(&s).unwrap();

        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.home_dc_id, 2);
        assert_eq!(loaded.dcs.len(), 1);

        b.delete().unwrap();
        assert!(!path.exists());
        // Deleting an already-deleted session is a no-op.
        b.delete().unwrap();
    }

    // SqliteBackend

    #[cfg(feature = "sqlite-session")]
    #[test]
    fn sqlite_in_memory_round_trip() {
        let b = SqliteBackend::in_memory().unwrap();
        assert!(b.load().unwrap().is_none());

        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        s.dcs.push(make_dc(2));
        s.dcs.push(make_dc_v6(2));
        s.peers.push(make_peer(7, 0x77));
        b.save(&s).unwrap();

        let loaded = b.load().unwrap().unwrap();
        assert_eq!(loaded.home_dc_id, 2);
        assert_eq!(loaded.dcs.len(), 2);
        assert_eq!(loaded.peers.len(), 1);
        assert_eq!(loaded.peers[0].access_hash, 0x77);

        b.delete().unwrap();
        assert!(b.load().unwrap().is_none());
        // The connection must still accept new transactions afterwards.
        b.save(&s).unwrap();
    }

    // IPv6 tests

    fn make_dc_v6(id: i32) -> DcEntry {
        DcEntry {
            dc_id: id,
            addr: format!("[2001:b28:f23d:f00{id}::a]:443"),
            auth_key: None,
            first_salt: 0,
            time_offset: 0,
            flags: DcFlags::IPV6,
        }
    }

    #[test]
    fn dc_entry_from_parts_ipv4() {
        let dc = DcEntry::from_parts(1, "149.154.175.53", 443, DcFlags::NONE);
        assert_eq!(dc.addr, "149.154.175.53:443");
        assert!(!dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    #[test]
    fn dc_entry_from_parts_ipv6() {
        let dc = DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6);
        assert_eq!(dc.addr, "[2001:b28:f23d:f001::a]:443");
        assert!(dc.is_ipv6());
        let sa = dc.socket_addr().unwrap();
        assert_eq!(sa.port(), 443);
    }

    #[test]
    fn persisted_session_dc_for_prefers_ipv6() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(2)); // IPv4
        s.dcs.push(make_dc_v6(2)); // IPv6

        let v6 = s.dc_for(2, true).unwrap();
        assert!(v6.is_ipv6());

        let v4 = s.dc_for(2, false).unwrap();
        assert!(!v4.is_ipv6());
    }

    #[test]
    fn persisted_session_dc_for_falls_back_when_only_ipv4() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(3)); // IPv4 only

        // Asking for IPv6 should fall back to IPv4
        let dc = s.dc_for(3, true).unwrap();
        assert!(!dc.is_ipv6());
    }

    #[test]
    fn persisted_session_all_dcs_for_returns_both() {
        let mut s = PersistedSession::default();
        s.dcs.push(make_dc(1));
        s.dcs.push(make_dc_v6(1));
        s.dcs.push(make_dc(2));

        assert_eq!(s.all_dcs_for(1).count(), 2);
        assert_eq!(s.all_dcs_for(2).count(), 1);
        assert_eq!(s.all_dcs_for(5).count(), 0);
    }

    #[test]
    fn inmemory_ipv4_and_ipv6_coexist() {
        let b = InMemoryBackend::new();
        b.update_dc(&make_dc(2)).unwrap(); // IPv4
        b.update_dc(&make_dc_v6(2)).unwrap(); // IPv6

        let s = b.snapshot().unwrap();
        // Both entries must survive: they differ in the IPv6 flag.
        assert_eq!(s.dcs.iter().filter(|d| d.dc_id == 2).count(), 2);
    }

    #[test]
    fn binary_roundtrip_ipv4_and_ipv6() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 2;
        s.dcs.push(make_dc(2));
        s.dcs.push(make_dc_v6(2));

        let bytes = s.to_bytes();
        let loaded = PersistedSession::from_bytes(&bytes).unwrap();
        assert_eq!(loaded.dcs.len(), 2);
        assert_eq!(loaded.dcs.iter().filter(|d| d.is_ipv6()).count(), 1);
        assert_eq!(loaded.dcs.iter().filter(|d| !d.is_ipv6()).count(), 1);
    }

    #[test]
    fn v6_channel_kind_roundtrip_all_variants() {
        let mut s = PersistedSession::default();
        s.home_dc_id = 1;
        s.peers.push(CachedPeer {
            id: 1001,
            access_hash: 0xaaaa,
            is_channel: true,
            is_chat: false,
            channel_kind: Some(ChannelKind::Broadcast),
        });
        s.peers.push(CachedPeer {
            id: 1002,
            access_hash: 0xbbbb,
            is_channel: true,
            is_chat: false,
            channel_kind: Some(ChannelKind::Megagroup),
        });
        s.peers.push(CachedPeer {
            id: 1003,
            access_hash: 0xcccc,
            is_channel: true,
            is_chat: false,
            channel_kind: Some(ChannelKind::Gigagroup),
        });
        s.peers.push(CachedPeer {
            id: 1004,
            access_hash: 0xdddd,
            is_channel: false,
            is_chat: false,
            channel_kind: None,
        });

        let bytes = s.to_bytes();
        assert_eq!(bytes[0], 0x06, "version byte must be 6");

        let loaded = PersistedSession::from_bytes(&bytes).unwrap();
        assert_eq!(loaded.peers.len(), 4);

        let p = &loaded.peers[0];
        assert_eq!(p.id, 1001);
        assert_eq!(p.channel_kind, Some(ChannelKind::Broadcast));

        let p = &loaded.peers[1];
        assert_eq!(p.id, 1002);
        assert_eq!(p.channel_kind, Some(ChannelKind::Megagroup));

        let p = &loaded.peers[2];
        assert_eq!(p.id, 1003);
        assert_eq!(p.channel_kind, Some(ChannelKind::Gigagroup));

        let p = &loaded.peers[3];
        assert_eq!(p.id, 1004);
        assert_eq!(p.channel_kind, None);
    }

    #[test]
    fn v6_channel_kind_absent_sentinel_roundtrip() {
        // A user peer (not a channel) must survive a v6 encode/decode with kind=None.
        let mut s = PersistedSession::default();
        s.home_dc_id = 1;
        s.peers.push(CachedPeer {
            id: 555,
            access_hash: 0x1234,
            is_channel: false,
            is_chat: false,
            channel_kind: None,
        });
        let loaded = PersistedSession::from_bytes(&s.to_bytes()).unwrap();
        assert_eq!(loaded.peers[0].channel_kind, None);
    }
}
1375
1376// SqliteBackend
1377
/// SQLite-backed session (via `rusqlite`).
///
/// Enabled with the `sqlite-session` Cargo feature.
///
/// # Schema
///
/// Six tables are created on first open (idempotent):
///
/// | Table          | Purpose                                                       |
/// |----------------|---------------------------------------------------------------|
/// | `meta`         | `home_dc_id` and future scalar values                         |
/// | `dcs`          | One row per (DC id, flags) pair: auth key, address, salt, ... |
/// | `update_state` | Single-row pts / qts / date / seq                             |
/// | `channel_pts`  | Per-channel pts                                               |
/// | `peers`        | Access-hash cache (plus is_channel, is_chat, channel_kind)    |
/// | `min_peers`    | Min-user message contexts                                     |
///
/// Databases written by older versions are upgraded in place when opened:
/// missing `peers.is_chat` / `peers.channel_kind` columns are added and the
/// `min_peers` table is created if absent.
///
/// # Granular writes
///
/// All [`SessionBackend`] extension methods (`update_dc`, `set_home_dc`,
/// `apply_update_state`, `cache_peer`) issue **single-row SQL statements**
/// instead of the default load-mutate-save round-trip, so they are safe to
/// call frequently (e.g. on every update batch) without performance concerns.
#[cfg(feature = "sqlite-session")]
pub struct SqliteBackend {
    // The single database connection. Every use of it goes through this
    // mutex, so calls on one backend instance never touch SQLite concurrently.
    conn: std::sync::Mutex<rusqlite::Connection>,
    // Identifier returned by `name()`: the database path as given to
    // `open`, or ":memory:" for `in_memory`.
    label: String,
}
1406
#[cfg(feature = "sqlite-session")]
impl SqliteBackend {
    /// Idempotent DDL executed on every open (WAL journal, `NORMAL` sync,
    /// and `CREATE TABLE IF NOT EXISTS` for every table).
    const SCHEMA: &'static str = "
        PRAGMA journal_mode = WAL;
        PRAGMA synchronous  = NORMAL;

        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT    PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );

        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );

        CREATE TABLE IF NOT EXISTS peers (
            id           INTEGER PRIMARY KEY,
            access_hash  INTEGER NOT NULL,
            is_channel   INTEGER NOT NULL DEFAULT 0,
            is_chat      INTEGER NOT NULL DEFAULT 0,
            channel_kind INTEGER
        );

        CREATE TABLE IF NOT EXISTS min_peers (
            user_id INTEGER PRIMARY KEY,
            peer_id INTEGER NOT NULL,
            msg_id  INTEGER NOT NULL
        );
    ";

    /// Open (or create) the SQLite database at `path`.
    ///
    /// Applies [`Self::SCHEMA`] and upgrades databases written by older
    /// versions before returning.
    ///
    /// # Errors
    /// Any SQLite error while opening or initialising the database, wrapped
    /// in an [`io::Error`].
    pub fn open(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let conn = rusqlite::Connection::open(&path).map_err(Self::map_err)?;
        Self::init_connection(conn, label)
    }

    /// Open an in-process SQLite database (useful for tests).
    pub fn in_memory() -> io::Result<Self> {
        let conn = rusqlite::Connection::open_in_memory().map_err(Self::map_err)?;
        Self::init_connection(conn, ":memory:".into())
    }

    /// Apply the schema plus legacy migrations to `conn` and wrap it.
    fn init_connection(conn: rusqlite::Connection, label: String) -> io::Result<Self> {
        conn.execute_batch(Self::SCHEMA).map_err(Self::map_err)?;
        Self::migrate_legacy_sqlite_schema(&conn)?;
        Ok(Self {
            conn: std::sync::Mutex::new(conn),
            label,
        })
    }

    fn map_err(e: rusqlite::Error) -> io::Error {
        io::Error::other(e)
    }

    /// Whether the `peers` table currently has a column named `column`.
    fn peers_has_column(conn: &rusqlite::Connection, column: &str) -> io::Result<bool> {
        let mut stmt = conn
            .prepare("PRAGMA table_info(peers)")
            .map_err(Self::map_err)?;
        // table_info yields one row per column; index 1 is the column name.
        let found = stmt
            .query_map([], |row| row.get::<_, String>(1))
            .map_err(Self::map_err)?
            .filter_map(Result::ok)
            .any(|name| name == column);
        Ok(found)
    }

    /// Migrate an older database that is missing the is_chat / channel_kind
    /// columns or the min_peers table. Safe to call on a fresh database; each
    /// step is a no-op if the schema is already current.
    fn migrate_legacy_sqlite_schema(conn: &rusqlite::Connection) -> io::Result<()> {
        if !Self::peers_has_column(conn, "is_chat")? {
            conn.execute_batch("ALTER TABLE peers ADD COLUMN is_chat INTEGER NOT NULL DEFAULT 0;")
                .map_err(Self::map_err)?;
        }
        // v6 migration: channel_kind column (NULL = absent/unknown).
        if !Self::peers_has_column(conn, "channel_kind")? {
            conn.execute_batch("ALTER TABLE peers ADD COLUMN channel_kind INTEGER;")
                .map_err(Self::map_err)?;
        }
        conn.execute_batch(
            "CREATE TABLE IF NOT EXISTS min_peers (
                user_id INTEGER PRIMARY KEY,
                peer_id INTEGER NOT NULL,
                msg_id  INTEGER NOT NULL
            );",
        )
        .map_err(Self::map_err)?;
        Ok(())
    }

    /// Read the full session out of the database.
    ///
    /// A missing `home_dc_id` row or `update_state` row falls back to `0` /
    /// the default snapshot. Any other SQLite error is returned instead of
    /// being silently turned into default values. Individual rows that fail
    /// to decode are skipped (best-effort) rather than failing the whole load.
    fn read_session(conn: &rusqlite::Connection) -> io::Result<PersistedSession> {
        use rusqlite::OptionalExtension;

        // home_dc_id
        let home_dc_id: i32 = conn
            .query_row("SELECT value FROM meta WHERE key = 'home_dc_id'", [], |r| {
                r.get(0)
            })
            .optional()
            .map_err(Self::map_err)?
            .unwrap_or(0);

        // dcs
        let mut stmt = conn
            .prepare("SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs")
            .map_err(Self::map_err)?;
        let dcs: Vec<DcEntry> = stmt
            .query_map([], |row| {
                let key_blob: Option<Vec<u8>> = row.get(3)?;
                // Anything other than exactly 256 bytes is not a usable auth key.
                let auth_key = key_blob.and_then(|b| {
                    if b.len() == 256 {
                        let mut k = [0u8; 256];
                        k.copy_from_slice(&b);
                        Some(k)
                    } else {
                        None
                    }
                });
                Ok(DcEntry {
                    dc_id: row.get(0)?,
                    addr: row.get(2)?,
                    auth_key,
                    first_salt: row.get(4)?,
                    time_offset: row.get(5)?,
                    flags: DcFlags(row.get::<_, u8>(1)?),
                })
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        // update_state
        let updates_state = conn
            .query_row(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                [],
                |r| {
                    Ok(UpdatesStateSnap {
                        pts: r.get(0)?,
                        qts: r.get(1)?,
                        date: r.get(2)?,
                        seq: r.get(3)?,
                        channels: vec![],
                    })
                },
            )
            .optional()
            .map_err(Self::map_err)?
            .unwrap_or_default();

        // channel_pts
        let mut ch_stmt = conn
            .prepare("SELECT channel_id, pts FROM channel_pts")
            .map_err(Self::map_err)?;
        let channels: Vec<(i64, i32)> = ch_stmt
            .query_map([], |r| Ok((r.get::<_, i64>(0)?, r.get::<_, i32>(1)?)))
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        // peers
        let mut peer_stmt = conn
            .prepare("SELECT id, access_hash, is_channel, is_chat, channel_kind FROM peers")
            .map_err(Self::map_err)?;
        let peers: Vec<CachedPeer> = peer_stmt
            .query_map([], |r| {
                let kind_raw: Option<i32> = r.get(4)?;
                // Values are written as ChannelKind::to_byte(). Anything outside
                // the u8 range is treated as unknown rather than truncated into
                // a possibly valid but wrong byte.
                let channel_kind = kind_raw
                    .filter(|k| (0..=i32::from(u8::MAX)).contains(k))
                    .and_then(|k| ChannelKind::from_byte(k as u8));
                Ok(CachedPeer {
                    id: r.get(0)?,
                    access_hash: r.get(1)?,
                    is_channel: r.get::<_, i32>(2)? != 0,
                    is_chat: r.get::<_, i32>(3)? != 0,
                    channel_kind,
                })
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        // min_peers
        let mut min_stmt = conn
            .prepare("SELECT user_id, peer_id, msg_id FROM min_peers")
            .map_err(Self::map_err)?;
        let min_peers: Vec<CachedMinPeer> = min_stmt
            .query_map([], |r| {
                Ok(CachedMinPeer {
                    user_id: r.get(0)?,
                    peer_id: r.get(1)?,
                    msg_id: r.get(2)?,
                })
            })
            .map_err(Self::map_err)?
            .filter_map(|r| r.ok())
            .collect();

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers,
        })
    }

    /// Write the full session into the database inside a single
    /// `BEGIN IMMEDIATE` transaction.
    ///
    /// If any statement, including the final `COMMIT`, fails, the transaction
    /// is rolled back before the error is returned. Otherwise the shared
    /// connection would stay inside an open transaction: every later save
    /// would fail at `BEGIN`, and granular writes would never be committed.
    fn write_session(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        conn.execute_batch("BEGIN IMMEDIATE")
            .map_err(Self::map_err)?;

        let outcome = Self::write_session_rows(conn, s)
            .and_then(|()| conn.execute_batch("COMMIT").map_err(Self::map_err));
        if outcome.is_err() {
            // Best-effort: if the failure already ended the transaction,
            // ROLLBACK itself errors ("no transaction is active"). Ignore that
            // and report the original error.
            let _ = conn.execute_batch("ROLLBACK");
        }
        outcome
    }

    /// Replace the contents of every table with `s`. Must only be called
    /// from `write_session`, inside its open transaction.
    fn write_session_rows(conn: &rusqlite::Connection, s: &PersistedSession) -> io::Result<()> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![s.home_dc_id],
        )
        .map_err(Self::map_err)?;

        // Replace all DCs
        conn.execute("DELETE FROM dcs", []).map_err(Self::map_err)?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                rusqlite::params![
                    d.dc_id,
                    d.flags.0,
                    d.addr,
                    d.auth_key.as_ref().map(|k| k.as_ref()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .map_err(Self::map_err)?;
        }

        // update_state: pts and qts are monotonic, so a full save must never
        // move them backwards. MAX() ensures a stale snapshot cannot overwrite
        // a fresher value committed by apply_update_state().
        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1, ?1, ?2, ?3, ?4)
             ON CONFLICT(id) DO UPDATE SET
               pts  = MAX(excluded.pts,  update_state.pts),
               qts  = MAX(excluded.qts,  update_state.qts),
               date = excluded.date,
               seq  = excluded.seq",
            rusqlite::params![us.pts, us.qts, us.date, us.seq],
        )
        .map_err(Self::map_err)?;

        // NOTE(review): unlike pts/qts above, per-channel pts are replaced
        // wholesale, so a stale snapshot can move them backwards. Confirm
        // this is intended.
        conn.execute("DELETE FROM channel_pts", [])
            .map_err(Self::map_err)?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
                rusqlite::params![cid, cpts],
            )
            .map_err(Self::map_err)?;
        }

        // peers
        conn.execute("DELETE FROM peers", [])
            .map_err(Self::map_err)?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel, is_chat, channel_kind) VALUES (?1, ?2, ?3, ?4, ?5)",
                rusqlite::params![
                    p.id,
                    p.access_hash,
                    p.is_channel as i32,
                    p.is_chat as i32,
                    p.channel_kind.map(|k| k.to_byte() as i32),
                ],
            )
            .map_err(Self::map_err)?;
        }

        // min_peers
        conn.execute("DELETE FROM min_peers", [])
            .map_err(Self::map_err)?;
        for m in &s.min_peers {
            conn.execute(
                "INSERT INTO min_peers (user_id, peer_id, msg_id) VALUES (?1, ?2, ?3)",
                rusqlite::params![m.user_id, m.peer_id, m.msg_id],
            )
            .map_err(Self::map_err)?;
        }

        Ok(())
    }
}
1742
#[cfg(feature = "sqlite-session")]
impl SessionBackend for SqliteBackend {
    /// Persist the whole session, replacing whatever was stored before.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        Self::write_session(&conn, session)
    }

    /// Load the stored session, or `Ok(None)` when nothing has been saved yet.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = self.conn.lock().unwrap();
        // If meta table is empty, no session has been saved yet.
        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM meta", [], |r| r.get(0))
            .map_err(Self::map_err)?;
        if count == 0 {
            return Ok(None);
        }
        Self::read_session(&conn).map(Some)
    }

    /// Remove every persisted row in a single `IMMEDIATE` transaction.
    ///
    /// If any statement fails, `tx` is dropped uncommitted, which rolls the
    /// transaction back. The connection is therefore never left mid-transaction.
    fn delete(&self) -> io::Result<()> {
        let mut conn = self.conn.lock().unwrap();
        let tx = conn
            .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
            .map_err(Self::map_err)?;
        tx.execute_batch(
            "DELETE FROM meta;
             DELETE FROM dcs;
             DELETE FROM update_state;
             DELETE FROM channel_pts;
             DELETE FROM peers;
             DELETE FROM min_peers;",
        )
        .map_err(Self::map_err)?;
        tx.commit().map_err(Self::map_err)
    }

    fn name(&self) -> &str {
        &self.label
    }

    // Granular overrides (single-row SQL, no full round-trip)

    /// Insert or update the row for `(entry.dc_id, entry.flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        // Note the placeholder order: `flags` is bound as ?6 but stored in the
        // second column.
        conn.execute(
            "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
             VALUES (?1, ?6, ?2, ?3, ?4, ?5)
             ON CONFLICT(dc_id, flags) DO UPDATE SET
               addr        = excluded.addr,
               auth_key    = excluded.auth_key,
               first_salt  = excluded.first_salt,
               time_offset = excluded.time_offset",
            rusqlite::params![
                entry.dc_id,
                entry.addr,
                entry.auth_key.as_ref().map(|k| k.as_ref()),
                entry.first_salt,
                entry.time_offset,
                entry.flags.0,
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Upsert the `home_dc_id` key in the `meta` table.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            rusqlite::params![dc_id],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }

    /// Apply one incremental update-state change.
    ///
    /// `All` replaces the snapshot row and the entire `channel_pts` table. It
    /// runs inside an `IMMEDIATE` transaction, so a failure partway through
    /// cannot leave a half-rewritten channel table behind. The other
    /// variants are single-statement upserts.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        // `mut` is required to open a transaction on the guarded connection.
        let mut conn = self.conn.lock().unwrap();
        match update {
            UpdateStateChange::All(snap) => {
                // An early `?` return drops `tx` uncommitted, which rolls back.
                let tx = conn
                    .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
                    .map_err(Self::map_err)?;
                tx.execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
                     ON CONFLICT(id) DO UPDATE SET
                       pts=excluded.pts, qts=excluded.qts,
                       date=excluded.date, seq=excluded.seq",
                    rusqlite::params![snap.pts, snap.qts, snap.date, snap.seq],
                )
                .map_err(Self::map_err)?;
                tx.execute("DELETE FROM channel_pts", [])
                    .map_err(Self::map_err)?;
                for &(cid, cpts) in &snap.channels {
                    tx.execute(
                        "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)",
                        rusqlite::params![cid, cpts],
                    )
                    .map_err(Self::map_err)?;
                }
                tx.commit().map_err(Self::map_err)
            }
            UpdateStateChange::Primary { pts, date, seq } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,0,?2,?3)
                     ON CONFLICT(id) DO UPDATE SET pts=excluded.pts, date=excluded.date,
                     seq=excluded.seq",
                    rusqlite::params![pts, date, seq],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Secondary { qts } => conn
                .execute(
                    "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,0,?1,0,0)
                     ON CONFLICT(id) DO UPDATE SET qts = excluded.qts",
                    rusqlite::params![qts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
            UpdateStateChange::Channel { id, pts } => conn
                .execute(
                    "INSERT INTO channel_pts (channel_id, pts) VALUES (?1, ?2)
                     ON CONFLICT(channel_id) DO UPDATE SET pts = excluded.pts",
                    rusqlite::params![id, pts],
                )
                .map(|_| ())
                .map_err(Self::map_err),
        }
    }

    /// Insert or overwrite a single cached peer keyed by `peer.id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO peers (id, access_hash, is_channel, is_chat, channel_kind) VALUES (?1, ?2, ?3, ?4, ?5)
             ON CONFLICT(id) DO UPDATE SET
               access_hash  = excluded.access_hash,
               is_channel   = excluded.is_channel,
               is_chat      = excluded.is_chat,
               channel_kind = excluded.channel_kind",
            rusqlite::params![
                peer.id,
                peer.access_hash,
                peer.is_channel as i32,
                peer.is_chat as i32,
                peer.channel_kind.map(|k| k.to_byte() as i32),
            ],
        )
        .map(|_| ())
        .map_err(Self::map_err)
    }
}
1889
1890// LibSqlBackend
1891
/// libSQL-backed session (Turso / embedded replica / in-process).
///
/// Enabled with the `libsql-session` Cargo feature.
///
/// # Blocking
///
/// The libSQL API is async, but [`SessionBackend`] methods are sync. Every
/// operation therefore blocks the calling thread on the current Tokio runtime
/// via `tokio::runtime::Handle::block_on(...)`. Callers must respect two
/// constraints:
///
/// - The calling thread must be inside a Tokio runtime context (the same
///   runtime as the rest of `ferogram`).
/// - It must **not** be executing inside an async task. `Handle::block_on`
///   panics when invoked from an asynchronous execution context. Call the
///   backend from a blocking context instead, e.g. a
///   `tokio::task::spawn_blocking` closure.
///
/// # Connecting
///
/// | Mode              | Constructor                                     |
/// |-------------------|-------------------------------------------------|
/// | Local file        | `LibSqlBackend::open_local(path)`               |
/// | In-memory         | `LibSqlBackend::in_memory()`                    |
/// | Turso remote      | `LibSqlBackend::open_remote(url, token)`        |
/// | Embedded replica  | `LibSqlBackend::open_replica(path, url, token)` |
#[cfg(feature = "libsql-session")]
pub struct LibSqlBackend {
    /// Connection used by every operation on this backend.
    conn: libsql::Connection,
    /// Human-readable description returned by [`SessionBackend::name`].
    label: String,
}
1914
#[cfg(feature = "libsql-session")]
impl LibSqlBackend {
    /// DDL applied to every newly opened connection.
    ///
    /// Every statement uses `IF NOT EXISTS`, so running it against an existing
    /// database leaves the existing tables untouched.
    const SCHEMA: &'static str = "
        CREATE TABLE IF NOT EXISTS meta (
            key   TEXT    PRIMARY KEY,
            value INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS dcs (
            dc_id       INTEGER NOT NULL,
            flags       INTEGER NOT NULL DEFAULT 0,
            addr        TEXT    NOT NULL,
            auth_key    BLOB,
            first_salt  INTEGER NOT NULL DEFAULT 0,
            time_offset INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (dc_id, flags)
        );
        CREATE TABLE IF NOT EXISTS update_state (
            id   INTEGER PRIMARY KEY CHECK (id = 1),
            pts  INTEGER NOT NULL DEFAULT 0,
            qts  INTEGER NOT NULL DEFAULT 0,
            date INTEGER NOT NULL DEFAULT 0,
            seq  INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS channel_pts (
            channel_id INTEGER PRIMARY KEY,
            pts        INTEGER NOT NULL
        );
        CREATE TABLE IF NOT EXISTS peers (
            id          INTEGER PRIMARY KEY,
            access_hash INTEGER NOT NULL,
            is_channel  INTEGER NOT NULL DEFAULT 0,
            is_chat     INTEGER NOT NULL DEFAULT 0,
            channel_kind INTEGER
        );
        CREATE TABLE IF NOT EXISTS min_peers (
            user_id INTEGER PRIMARY KEY,
            peer_id INTEGER NOT NULL,
            msg_id  INTEGER NOT NULL
        );
    ";

    /// Drive `fut` to completion on the current Tokio runtime, blocking the
    /// calling thread, and convert any libSQL error into an [`io::Error`].
    ///
    /// # Errors
    ///
    /// Returns an error if the calling thread is not inside a Tokio runtime
    /// context, or if `fut` resolves to a libSQL error.
    ///
    /// # Panics
    ///
    /// `Handle::block_on` panics when called from within an asynchronous
    /// execution context, i.e. directly from an async task. Call this from a
    /// blocking context such as `tokio::task::spawn_blocking`.
    fn block<F, T>(fut: F) -> io::Result<T>
    where
        F: std::future::Future<Output = Result<T, libsql::Error>>,
    {
        // `try_current` reports "no runtime" as an error instead of the panic
        // `Handle::current` would raise.
        let handle = tokio::runtime::Handle::try_current().map_err(io::Error::other)?;
        handle.block_on(fut).map_err(io::Error::other)
    }

    /// Create all tables and run the best-effort `channel_kind` migration.
    async fn apply_schema(conn: &libsql::Connection) -> Result<(), libsql::Error> {
        conn.execute_batch(Self::SCHEMA).await?;
        // v6 migration: add channel_kind column to existing databases.
        // `CREATE TABLE IF NOT EXISTS` above does not add columns to a table
        // that already exists, so older databases need this ALTER. It fails
        // when the column is already present (including every freshly created
        // database), so its result is deliberately ignored.
        // NOTE(review): ignoring the result also hides unrelated ALTER failures.
        let _ = conn
            .execute_batch("ALTER TABLE peers ADD COLUMN channel_kind INTEGER;")
            .await;
        Ok(())
    }

    /// Open a connection on `db`, apply the schema, and build the backend.
    ///
    /// Shared tail of every public constructor.
    fn from_database(db: libsql::Database, label: String) -> io::Result<Self> {
        // `Database::connect` is synchronous, so it does not need `block`.
        let conn = db.connect().map_err(io::Error::other)?;
        Self::block(Self::apply_schema(&conn))?;
        Ok(Self { conn, label })
    }

    /// Open a local file database.
    pub fn open_local(path: impl Into<PathBuf>) -> io::Result<Self> {
        let path = path.into();
        let label = path.display().to_string();
        let db = Self::block(async { libsql::Builder::new_local(path).build().await })?;
        Self::from_database(db, label)
    }

    /// Open an in-process in-memory database (useful for tests).
    pub fn in_memory() -> io::Result<Self> {
        let db = Self::block(async { libsql::Builder::new_local(":memory:").build().await })?;
        Self::from_database(db, ":memory:".into())
    }

    /// Connect to a remote Turso database.
    pub fn open_remote(url: impl Into<String>, auth_token: impl Into<String>) -> io::Result<Self> {
        let url = url.into();
        let auth_token = auth_token.into();
        let label = url.clone();
        let db = Self::block(async {
            libsql::Builder::new_remote(url, auth_token)
                .build()
                .await
        })?;
        Self::from_database(db, label)
    }

    /// Open an embedded replica (local file + Turso remote sync).
    pub fn open_replica(
        path: impl Into<PathBuf>,
        url: impl Into<String>,
        auth_token: impl Into<String>,
    ) -> io::Result<Self> {
        let path = path.into();
        let url = url.into();
        let auth_token = auth_token.into();
        let label = format!("{} (replica of {})", path.display(), url);
        let db = Self::block(async {
            libsql::Builder::new_remote_replica(path, url, auth_token)
                .build()
                .await
        })?;
        Self::from_database(db, label)
    }

    /// Read every table back into a [`PersistedSession`].
    ///
    /// A missing `home_dc_id` row reads as `0`, and a missing `update_state`
    /// row reads as `UpdatesStateSnap::default()`.
    ///
    /// # Errors
    ///
    /// Returns `libsql::Error::Misuse` if a stored `auth_key` blob is not
    /// exactly 256 bytes. Query and column-conversion errors are also returned.
    async fn read_session_async(
        conn: &libsql::Connection,
    ) -> Result<PersistedSession, libsql::Error> {
        // home_dc_id
        let home_dc_id: i32 = conn
            .query("SELECT value FROM meta WHERE key = 'home_dc_id'", ())
            .await?
            .next()
            .await?
            .map(|r| r.get::<i32>(0))
            .transpose()?
            .unwrap_or(0);

        // dcs
        let mut rows = conn
            .query(
                "SELECT dc_id, flags, addr, auth_key, first_salt, time_offset FROM dcs",
                (),
            )
            .await?;
        let mut dcs = Vec::new();
        while let Some(row) = rows.next().await? {
            let dc_id: i32 = row.get(0)?;
            let flags_raw: u8 = row.get::<i64>(1)? as u8;
            let addr: String = row.get(2)?;
            let key_blob: Option<Vec<u8>> = row.get(3)?;
            let first_salt: i64 = row.get(4)?;
            let time_offset: i32 = row.get(5)?;
            let auth_key = match key_blob {
                Some(b) if b.len() == 256 => {
                    let mut k = [0u8; 256];
                    k.copy_from_slice(&b);
                    Some(k)
                }
                Some(b) => {
                    return Err(libsql::Error::Misuse(format!(
                        "auth_key blob must be 256 bytes, got {}",
                        b.len()
                    )));
                }
                None => None,
            };
            dcs.push(DcEntry {
                dc_id,
                addr,
                auth_key,
                first_salt,
                time_offset,
                flags: DcFlags(flags_raw),
            });
        }

        // update_state
        let mut us_row = conn
            .query(
                "SELECT pts, qts, date, seq FROM update_state WHERE id = 1",
                (),
            )
            .await?;
        let updates_state = if let Some(r) = us_row.next().await? {
            UpdatesStateSnap {
                pts: r.get(0)?,
                qts: r.get(1)?,
                date: r.get(2)?,
                seq: r.get(3)?,
                channels: vec![],
            }
        } else {
            UpdatesStateSnap::default()
        };

        // channel_pts
        let mut ch_rows = conn
            .query("SELECT channel_id, pts FROM channel_pts", ())
            .await?;
        let mut channels = Vec::new();
        while let Some(r) = ch_rows.next().await? {
            channels.push((r.get::<i64>(0)?, r.get::<i32>(1)?));
        }

        // peers
        let mut peer_rows = conn
            .query(
                "SELECT id, access_hash, is_channel, is_chat, channel_kind FROM peers",
                (),
            )
            .await?;
        let mut peers = Vec::new();
        while let Some(r) = peer_rows.next().await? {
            // A NULL or unreadable channel_kind is treated as "unknown".
            let kind_raw: Option<i32> = r.get(4).ok();
            peers.push(CachedPeer {
                id: r.get(0)?,
                access_hash: r.get(1)?,
                is_channel: r.get::<i32>(2)? != 0,
                is_chat: r.get::<i32>(3)? != 0,
                channel_kind: kind_raw.and_then(|k| ChannelKind::from_byte(k as u8)),
            });
        }

        // min_peers
        let mut min_rows = conn
            .query("SELECT user_id, peer_id, msg_id FROM min_peers", ())
            .await?;
        let mut min_peers = Vec::new();
        while let Some(r) = min_rows.next().await? {
            min_peers.push(CachedMinPeer {
                user_id: r.get(0)?,
                peer_id: r.get(1)?,
                msg_id: r.get(2)?,
            });
        }

        Ok(PersistedSession {
            home_dc_id,
            dcs,
            updates_state: UpdatesStateSnap {
                channels,
                ..updates_state
            },
            peers,
            min_peers,
        })
    }

    /// Replace the stored session with `s` inside one `BEGIN IMMEDIATE`
    /// transaction.
    ///
    /// On any failure after `BEGIN`, including a failed `COMMIT`, the
    /// transaction is rolled back (best-effort). Otherwise the connection
    /// would stay inside an open transaction, and every later
    /// `BEGIN IMMEDIATE` would fail.
    async fn write_session_async(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute_batch("BEGIN IMMEDIATE").await.map(|_| ())?;

        let result = match Self::write_session_rows(conn, s).await {
            Ok(()) => conn.execute_batch("COMMIT").await.map(|_| ()),
            Err(e) => Err(e),
        };
        if result.is_err() {
            // The original error is more useful to the caller than a rollback
            // error (e.g. "no transaction is active" after a failed COMMIT).
            let _ = conn.execute_batch("ROLLBACK").await;
        }
        result
    }

    /// Write every table for `s`. Must run inside an open transaction.
    ///
    /// `update_state.pts`/`qts` never move backwards (`MAX` against the stored
    /// value). Every other table is fully replaced.
    async fn write_session_rows(
        conn: &libsql::Connection,
        s: &PersistedSession,
    ) -> Result<(), libsql::Error> {
        conn.execute(
            "INSERT INTO meta (key, value) VALUES ('home_dc_id', ?1)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            libsql::params![s.home_dc_id],
        )
        .await?;

        conn.execute("DELETE FROM dcs", ()).await?;
        for d in &s.dcs {
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                libsql::params![
                    d.dc_id,
                    d.flags.0 as i64,
                    d.addr.clone(),
                    d.auth_key.map(|k| k.to_vec()),
                    d.first_salt,
                    d.time_offset,
                ],
            )
            .await?;
        }

        let us = &s.updates_state;
        conn.execute(
            "INSERT INTO update_state (id, pts, qts, date, seq) VALUES (1,?1,?2,?3,?4)
             ON CONFLICT(id) DO UPDATE SET
               pts  = MAX(excluded.pts,  update_state.pts),
               qts  = MAX(excluded.qts,  update_state.qts),
               date = excluded.date,
               seq  = excluded.seq",
            libsql::params![us.pts, us.qts, us.date, us.seq],
        )
        .await?;

        conn.execute("DELETE FROM channel_pts", ()).await?;
        for &(cid, cpts) in &us.channels {
            conn.execute(
                "INSERT INTO channel_pts (channel_id, pts) VALUES (?1,?2)",
                libsql::params![cid, cpts],
            )
            .await?;
        }

        conn.execute("DELETE FROM peers", ()).await?;
        for p in &s.peers {
            conn.execute(
                "INSERT INTO peers (id, access_hash, is_channel, is_chat, channel_kind) VALUES (?1,?2,?3,?4,?5)",
                libsql::params![
                    p.id,
                    p.access_hash,
                    p.is_channel as i32,
                    p.is_chat as i32,
                    p.channel_kind.map(|k| k.to_byte() as i32),
                ],
            )
            .await?;
        }

        conn.execute("DELETE FROM min_peers", ()).await?;
        for m in &s.min_peers {
            conn.execute(
                "INSERT INTO min_peers (user_id, peer_id, msg_id) VALUES (?1,?2,?3)",
                libsql::params![m.user_id, m.peer_id, m.msg_id],
            )
            .await?;
        }

        Ok(())
    }
}
2233
#[cfg(feature = "libsql-session")]
impl SessionBackend for LibSqlBackend {
    /// Persist the whole session, replacing whatever was stored before.
    fn save(&self, session: &PersistedSession) -> io::Result<()> {
        // `block` drives the future to completion before returning, so the
        // future can borrow the connection and the session directly. No
        // deep copy of the session is needed.
        Self::block(Self::write_session_async(&self.conn, session))
    }

    /// Load the stored session, or `Ok(None)` when nothing has been saved yet.
    fn load(&self) -> io::Result<Option<PersistedSession>> {
        let conn = &self.conn;
        // An empty `meta` table means no session has been saved yet.
        let count: i64 = Self::block(async {
            let mut rows = conn.query("SELECT COUNT(*) FROM meta", ()).await?;
            Ok::<i64, libsql::Error>(rows.next().await?.and_then(|r| r.get(0).ok()).unwrap_or(0))
        })?;
        if count == 0 {
            return Ok(None);
        }
        Self::block(Self::read_session_async(conn)).map(Some)
    }

    /// Remove every persisted row in a single `IMMEDIATE` transaction.
    ///
    /// If a DELETE or the COMMIT fails, the transaction is rolled back
    /// (best-effort). The connection is therefore not left inside an open
    /// transaction.
    fn delete(&self) -> io::Result<()> {
        let conn = &self.conn;
        Self::block(async move {
            conn.execute_batch("BEGIN IMMEDIATE").await.map(|_| ())?;
            let result = async {
                conn.execute_batch(
                    "DELETE FROM meta;
                     DELETE FROM dcs;
                     DELETE FROM update_state;
                     DELETE FROM channel_pts;
                     DELETE FROM peers;
                     DELETE FROM min_peers;",
                )
                .await?;
                conn.execute_batch("COMMIT").await?;
                Ok::<(), libsql::Error>(())
            }
            .await;
            if result.is_err() {
                let _ = conn.execute_batch("ROLLBACK").await;
            }
            result
        })
    }

    fn name(&self) -> &str {
        &self.label
    }

    // Granular overrides

    /// Insert or update the row for `(entry.dc_id, entry.flags)`.
    fn update_dc(&self, entry: &DcEntry) -> io::Result<()> {
        let conn = &self.conn;
        let (dc_id, addr, key, salt, off, flags) = (
            entry.dc_id,
            entry.addr.clone(),
            entry.auth_key.map(|k| k.to_vec()),
            entry.first_salt,
            entry.time_offset,
            entry.flags.0 as i64,
        );
        Self::block(async move {
            // `flags` is bound as ?6 but stored in the second column.
            conn.execute(
                "INSERT INTO dcs (dc_id, flags, addr, auth_key, first_salt, time_offset)
                 VALUES (?1,?6,?2,?3,?4,?5)
                 ON CONFLICT(dc_id, flags) DO UPDATE SET
                   addr=excluded.addr, auth_key=excluded.auth_key,
                   first_salt=excluded.first_salt, time_offset=excluded.time_offset",
                libsql::params![dc_id, addr, key, salt, off, flags],
            )
            .await
            .map(|_| ())
        })
    }

    /// Upsert the `home_dc_id` key in the `meta` table.
    fn set_home_dc(&self, dc_id: i32) -> io::Result<()> {
        let conn = &self.conn;
        Self::block(async move {
            conn.execute(
                "INSERT INTO meta (key, value) VALUES ('home_dc_id',?1)
                 ON CONFLICT(key) DO UPDATE SET value=excluded.value",
                libsql::params![dc_id],
            )
            .await
            .map(|_| ())
        })
    }

    /// Apply one incremental update-state change.
    ///
    /// `All` replaces the snapshot row and the entire `channel_pts` table. It
    /// runs inside `BEGIN IMMEDIATE`/`COMMIT`, so a failure partway through
    /// cannot leave a half-rewritten channel table. On failure it rolls back
    /// (best-effort). The other variants are single-statement upserts.
    fn apply_update_state(&self, update: UpdateStateChange) -> io::Result<()> {
        let conn = &self.conn;
        Self::block(async move {
            match update {
                UpdateStateChange::All(snap) => {
                    conn.execute_batch("BEGIN IMMEDIATE").await.map(|_| ())?;
                    let result = async {
                        conn.execute(
                            "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,?2,?3,?4)
                             ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,qts=excluded.qts,
                             date=excluded.date,seq=excluded.seq",
                            libsql::params![snap.pts, snap.qts, snap.date, snap.seq],
                        )
                        .await?;
                        conn.execute("DELETE FROM channel_pts", ()).await?;
                        for &(cid, cpts) in &snap.channels {
                            conn.execute(
                                "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)",
                                libsql::params![cid, cpts],
                            )
                            .await?;
                        }
                        conn.execute_batch("COMMIT").await?;
                        Ok::<(), libsql::Error>(())
                    }
                    .await;
                    if result.is_err() {
                        let _ = conn.execute_batch("ROLLBACK").await;
                    }
                    result
                }
                UpdateStateChange::Primary { pts, date, seq } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,?1,0,?2,?3)
                         ON CONFLICT(id) DO UPDATE SET pts=excluded.pts,date=excluded.date,
                         seq=excluded.seq",
                        libsql::params![pts, date, seq],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Secondary { qts } => conn
                    .execute(
                        "INSERT INTO update_state (id,pts,qts,date,seq) VALUES (1,0,?1,0,0)
                         ON CONFLICT(id) DO UPDATE SET qts=excluded.qts",
                        libsql::params![qts],
                    )
                    .await
                    .map(|_| ()),
                UpdateStateChange::Channel { id, pts } => conn
                    .execute(
                        "INSERT INTO channel_pts (channel_id,pts) VALUES (?1,?2)
                         ON CONFLICT(channel_id) DO UPDATE SET pts=excluded.pts",
                        libsql::params![id, pts],
                    )
                    .await
                    .map(|_| ()),
            }
        })
    }

    /// Insert or overwrite a single cached peer keyed by `peer.id`.
    fn cache_peer(&self, peer: &CachedPeer) -> io::Result<()> {
        let conn = &self.conn;
        let (id, hash, is_ch, is_ct, kind) = (
            peer.id,
            peer.access_hash,
            peer.is_channel as i32,
            peer.is_chat as i32,
            peer.channel_kind.map(|k| k.to_byte() as i32),
        );
        Self::block(async move {
            conn.execute(
                "INSERT INTO peers (id,access_hash,is_channel,is_chat,channel_kind) VALUES (?1,?2,?3,?4,?5)
                 ON CONFLICT(id) DO UPDATE SET
                   access_hash=excluded.access_hash,
                   is_channel=excluded.is_channel,
                   is_chat=excluded.is_chat,
                   channel_kind=excluded.channel_kind",
                libsql::params![id, hash, is_ch, is_ct, kind],
            )
            .await
            .map(|_| ())
        })
    }
}