Skip to main content

ferogram/
session.rs

// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
// SPDX-License-Identifier: MIT OR Apache-2.0
//
// ferogram: async Telegram MTProto client in Rust
// https://github.com/ankit-chaubey/ferogram
//
// Based on layer: https://github.com/ankit-chaubey/layer
// Follows official Telegram client behaviour (tdesktop, TDLib).
//
// If you use or modify this code, keep this notice at the top of your file
// and include the LICENSE-MIT or LICENSE-APACHE file from this repository:
// https://github.com/ankit-chaubey/ferogram

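//! Session persistence: saves auth key, salt, time offset, DC table,
//! update sequence counters (pts/qts/seq/date/per-channel pts), and
//! peer access-hash cache.
//!
//! ## Binary format versioning
//!
//! Versioned files start with a single **version byte**:
//! - legacy (v1): no version byte; the file begins directly with the
//!   little-endian `i32` home DC id and contains the DC table only
//!   (no update state or peers).
//! - `0x02`: DC table + update state + peer cache.
//! - `0x03`: current format; like `0x02`, plus one DC-flags byte per DC entry.
//!
//! `load()` handles all three.  `save()` always writes v3.
//!
//! Note: a legacy file whose home DC id is 2 or 3 also begins with byte
//! `0x02` / `0x03`, so it cannot be told apart from a versioned file by its
//! first byte alone.
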
26use std::collections::HashMap;
27use std::io::{self, ErrorKind};
28use std::path::Path;
29
// DcFlags

/// Bitmask of per-DC option flags, modelled on tDesktop's `DcOption` flags.
///
/// Persisted by session format v3 and later, so media DCs survive restarts.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// Empty mask.
    pub const NONE: DcFlags = DcFlags(0b0000_0000);
    /// The entry's address is IPv6.
    pub const IPV6: DcFlags = DcFlags(0b0000_0001);
    /// Mirrors the `DcOption` media-only flag.
    pub const MEDIA_ONLY: DcFlags = DcFlags(0b0000_0010);
    /// Mirrors the `DcOption` tcpo-only flag.
    pub const TCPO_ONLY: DcFlags = DcFlags(0b0000_0100);
    /// Mirrors the `DcOption` CDN flag.
    pub const CDN: DcFlags = DcFlags(0b0000_1000);
    /// Mirrors the `DcOption` static flag.
    pub const STATIC: DcFlags = DcFlags(0b0001_0000);

    /// `true` when every bit set in `other` is also set in `self`
    /// (so `contains(NONE)` is always `true`).
    pub fn contains(self, other: DcFlags) -> bool {
        // No bit of `other` may be missing from `self`.
        (other.0 & !self.0) == 0
    }

    /// Turns on all bits of `flag`, leaving the other bits untouched.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = Self;

    /// Union of both masks.
    fn bitor(self, rhs: Self) -> Self::Output {
        let DcFlags(lhs_bits) = self;
        let DcFlags(rhs_bits) = rhs;
        Self(lhs_bits | rhs_bits)
    }
}

63// DcEntry
64
65/// One entry in the DC address table.
66#[derive(Clone, Debug)]
67#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
68pub struct DcEntry {
69    pub dc_id: i32,
70    pub addr: String,
71    pub auth_key: Option<[u8; 256]>,
72    pub first_salt: i64,
73    pub time_offset: i32,
74    /// DC capability flags (IPv6, media-only, CDN, …).
75    pub flags: DcFlags,
76}
77
78impl DcEntry {
79    /// Returns `true` when this entry represents an IPv6 address.
80    #[inline]
81    pub fn is_ipv6(&self) -> bool {
82        self.flags.contains(DcFlags::IPV6)
83    }
84
85    /// Parse the stored `"ip:port"` / `"[ipv6]:port"` address into a
86    /// [`std::net::SocketAddr`].
87    ///
88    /// Both formats are valid:
89    /// - IPv4: `"149.154.175.53:443"`
90    /// - IPv6: `"[2001:b28:f23d:f001::a]:443"`
91    pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
92        self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
93            io::Error::new(
94                io::ErrorKind::InvalidData,
95                format!("invalid DC address: {:?}", self.addr),
96            )
97        })
98    }
99
100    /// Construct a `DcEntry` from separate IP string, port, and flags.
101    ///
102    /// IPv6 addresses are automatically wrapped in brackets so that
103    /// `socket_addr()` can round-trip them correctly:
104    ///
105    /// ```text
106    /// DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6)
107    /// // addr = "[2001:b28:f23d:f001::a]:443"
108    /// ```
109    ///
110    /// This is the preferred constructor when processing `help.getConfig`
111    /// `DcOption` objects from the Telegram API.
112    pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
113        // IPv6 addresses contain colons; wrap in brackets for SocketAddr compat.
114        let addr = if ip.contains(':') {
115            format!("[{ip}]:{port}")
116        } else {
117            format!("{ip}:{port}")
118        };
119        Self {
120            dc_id,
121            addr,
122            auth_key: None,
123            first_salt: 0,
124            time_offset: 0,
125            flags,
126        }
127    }
128}
129
// UpdatesStateSnap

/// Persisted copy of the MTProto update-sequence counters, kept so that
/// `catch_up: true` can call `updates.getDifference` starting from the
/// *pre-shutdown* pts.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Primary persistence counter (messages and other non-channel updates).
    pub pts: i32,
    /// Secondary counter for secret chats.
    pub qts: i32,
    /// Unix timestamp of the most recently received update.
    pub date: i32,
    /// Sequence number of combined update containers.
    pub seq: i32,
    /// Per-channel persistence counters as `(channel_id, pts)` pairs.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// `true` once a real server state has been recorded (`pts` is positive).
    #[inline]
    pub fn is_initialised(&self) -> bool {
        self.pts > 0
    }

    /// Store `pts` for `channel_id`.
    ///
    /// The first existing pair for the channel is overwritten unconditionally
    /// (the value is not required to increase); unknown channels are appended.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self.channels.iter().position(|&(id, _)| id == channel_id) {
            Some(idx) => self.channels[idx].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Stored pts for `channel_id`, or `0` when the channel is unknown.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        self.channels
            .iter()
            .filter(|&&(id, _)| id == channel_id)
            .map(|&(_, stored)| stored)
            .next()
            .unwrap_or(0)
    }
}

// CachedPeer

/// Access-hash cache entry that lets a previously seen peer be addressed
/// after a restart without resolving it again through Telegram.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Bare Telegram peer ID (always positive).
    pub id: i64,
    /// Access hash bound to the current session.
    pub access_hash: i64,
    /// Peer kind: `true` for a channel / supergroup, `false` for a user.
    pub is_channel: bool,
}

190// PersistedSession
191
192/// Everything that needs to survive a process restart.
193#[derive(Clone, Debug, Default)]
194#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
195pub struct PersistedSession {
196    pub home_dc_id: i32,
197    pub dcs: Vec<DcEntry>,
198    /// Update counters to enable reliable catch-up after a disconnect.
199    pub updates_state: UpdatesStateSnap,
200    /// Peer access-hash cache so that the client can reach out to any previously
201    /// seen user or channel without re-resolving them.
202    pub peers: Vec<CachedPeer>,
203}
204
205impl PersistedSession {
206    // Serialise (v2)
207
208    /// Encode the session to raw bytes (v2 binary format).
209    pub fn to_bytes(&self) -> Vec<u8> {
210        let mut b = Vec::with_capacity(512);
211
212        b.push(0x03u8); // version
213
214        b.extend_from_slice(&self.home_dc_id.to_le_bytes());
215
216        b.push(self.dcs.len() as u8);
217        for d in &self.dcs {
218            b.extend_from_slice(&d.dc_id.to_le_bytes());
219            match &d.auth_key {
220                Some(k) => {
221                    b.push(1);
222                    b.extend_from_slice(k);
223                }
224                None => {
225                    b.push(0);
226                }
227            }
228            b.extend_from_slice(&d.first_salt.to_le_bytes());
229            b.extend_from_slice(&d.time_offset.to_le_bytes());
230            let ab = d.addr.as_bytes();
231            b.push(ab.len() as u8);
232            b.extend_from_slice(ab);
233            b.push(d.flags.0); // v3: DC flags byte
234        }
235
236        // update state
237        b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
238        b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
239        b.extend_from_slice(&self.updates_state.date.to_le_bytes());
240        b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
241        let ch = &self.updates_state.channels;
242        b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
243        for &(cid, cpts) in ch {
244            b.extend_from_slice(&cid.to_le_bytes());
245            b.extend_from_slice(&cpts.to_le_bytes());
246        }
247
248        // peer cache
249        b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
250        for p in &self.peers {
251            b.extend_from_slice(&p.id.to_le_bytes());
252            b.extend_from_slice(&p.access_hash.to_le_bytes());
253            b.push(p.is_channel as u8);
254        }
255
256        b
257    }
258
259    /// Atomically save the session to `path`.
260    ///
261    /// Writes to `<path>.tmp` first, then renames into place so a crash
262    /// mid-write never corrupts the existing session file.
263    pub fn save(&self, path: &Path) -> io::Result<()> {
264        let tmp = path.with_extension("tmp");
265        std::fs::write(&tmp, self.to_bytes())?;
266        std::fs::rename(&tmp, path)
267    }
268
269    // Deserialise (v1 + v2)
270
271    /// Decode a session from raw bytes (v1 or v2 binary format).
272    pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
273        if buf.is_empty() {
274            return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
275        }
276
277        let mut p = 0usize;
278
279        macro_rules! r {
280            ($n:expr) => {{
281                if p + $n > buf.len() {
282                    return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
283                }
284                let s = &buf[p..p + $n];
285                p += $n;
286                s
287            }};
288        }
289        macro_rules! r_i32 {
290            () => {
291                i32::from_le_bytes(r!(4).try_into().unwrap())
292            };
293        }
294        macro_rules! r_i64 {
295            () => {
296                i64::from_le_bytes(r!(8).try_into().unwrap())
297            };
298        }
299        macro_rules! r_u8 {
300            () => {
301                r!(1)[0]
302            };
303        }
304        macro_rules! r_u16 {
305            () => {
306                u16::from_le_bytes(r!(2).try_into().unwrap())
307            };
308        }
309
310        let first_byte = r_u8!();
311
312        let (home_dc_id, version) = if first_byte == 0x03 {
313            (r_i32!(), 3u8)
314        } else if first_byte == 0x02 {
315            (r_i32!(), 2u8)
316        } else {
317            let rest = r!(3);
318            let mut bytes = [0u8; 4];
319            bytes[0] = first_byte;
320            bytes[1..4].copy_from_slice(rest);
321            (i32::from_le_bytes(bytes), 1u8)
322        };
323
324        let dc_count = r_u8!() as usize;
325        let mut dcs = Vec::with_capacity(dc_count);
326        for _ in 0..dc_count {
327            let dc_id = r_i32!();
328            let has_key = r_u8!();
329            let auth_key = if has_key == 1 {
330                let mut k = [0u8; 256];
331                k.copy_from_slice(r!(256));
332                Some(k)
333            } else {
334                None
335            };
336            let first_salt = r_i64!();
337            let time_offset = r_i32!();
338            let al = r_u8!() as usize;
339            let addr = String::from_utf8_lossy(r!(al)).into_owned();
340            let flags = if version >= 3 {
341                DcFlags(r_u8!())
342            } else {
343                DcFlags::NONE
344            };
345            dcs.push(DcEntry {
346                dc_id,
347                addr,
348                auth_key,
349                first_salt,
350                time_offset,
351                flags,
352            });
353        }
354
355        if version < 2 {
356            return Ok(Self {
357                home_dc_id,
358                dcs,
359                updates_state: UpdatesStateSnap::default(),
360                peers: Vec::new(),
361            });
362        }
363
364        let pts = r_i32!();
365        let qts = r_i32!();
366        let date = r_i32!();
367        let seq = r_i32!();
368        let ch_count = r_u16!() as usize;
369        let mut channels = Vec::with_capacity(ch_count);
370        for _ in 0..ch_count {
371            let cid = r_i64!();
372            let cpts = r_i32!();
373            channels.push((cid, cpts));
374        }
375
376        let peer_count = r_u16!() as usize;
377        let mut peers = Vec::with_capacity(peer_count);
378        for _ in 0..peer_count {
379            let id = r_i64!();
380            let access_hash = r_i64!();
381            let is_channel = r_u8!() != 0;
382            peers.push(CachedPeer {
383                id,
384                access_hash,
385                is_channel,
386            });
387        }
388
389        Ok(Self {
390            home_dc_id,
391            dcs,
392            updates_state: UpdatesStateSnap {
393                pts,
394                qts,
395                date,
396                seq,
397                channels,
398            },
399            peers,
400        })
401    }
402
403    /// Decode a session from a URL-safe base64 string produced by [`to_string`].
404    pub fn from_string(s: &str) -> io::Result<Self> {
405        use base64::Engine as _;
406        let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
407            .decode(s.trim())
408            .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
409        Self::from_bytes(&bytes)
410    }
411
412    pub fn load(path: &Path) -> io::Result<Self> {
413        let buf = std::fs::read(path)?;
414        Self::from_bytes(&buf)
415    }
416
417    // DC address helpers
418
419    /// Find the best DC entry for a given DC ID.
420    ///
421    /// When `prefer_ipv6` is `true`, returns the IPv6 entry if one is
422    /// stored, falling back to IPv4.  When `false`, returns IPv4,
423    /// falling back to IPv6.  Returns `None` only when the DC ID is
424    /// completely unknown.
425    ///
426    /// This correctly handles the case where both an IPv4 and an IPv6
427    /// `DcEntry` exist for the same `dc_id` (different `flags` bitmask).
428    pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
429        let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
430        candidates.peek()?;
431        // Collect so we can search twice
432        let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
433        // Preferred family first, fall back to whatever is available
434        cands
435            .iter()
436            .copied()
437            .find(|d| d.is_ipv6() == prefer_ipv6)
438            .or_else(|| cands.first().copied())
439    }
440
441    /// Iterate over every stored DC entry for a given DC ID.
442    ///
443    /// Typically yields one IPv4 and one IPv6 entry per DC ID once
444    /// `help.getConfig` has been applied.
445    pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
446        self.dcs.iter().filter(move |d| d.dc_id == dc_id)
447    }
448}
449
450impl std::fmt::Display for PersistedSession {
451    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
452        use base64::Engine as _;
453        f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
454    }
455}
456
// Bootstrap DC table

/// Hard-coded DC address table, used as a fallback when `GetConfig` fails.
///
/// Maps each production DC id (1–5) to its `"ip:port"` address.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const BOOTSTRAP: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(BOOTSTRAP.len());
    for &(dc_id, addr) in BOOTSTRAP.iter() {
        table.insert(dc_id, addr.to_owned());
    }
    table
}