Skip to main content

layer_client/
session.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4// NOTE:
5// The "Layer" project is no longer maintained or supported.
6// Its original purpose for personal SDK/APK experimentation and learning
7// has been fulfilled.
8//
9// Please use Ferogram instead:
10// https://github.com/ankit-chaubey/ferogram
11// Ferogram will receive future updates and development, although progress
12// may be slower.
13//
14// Ferogram is an async Telegram MTProto client library written in Rust.
15// Its implementation follows the behaviour of the official Telegram clients,
16// particularly Telegram Desktop and TDLib, and aims to provide a clean and
17// modern async interface for building Telegram clients and tools.
18
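//! Session persistence: saves auth key, salt, time offset, DC table,
//! update sequence counters (pts/qts/seq/date/per-channel pts), and
//! peer access-hash cache.
//!
//! ## Binary format versioning
//!
//! Tagged files start with a single **version byte**:
//! - `0x02`: DC table + update state + peer cache.
//! - `0x03`: current format; identical to `0x02` plus one DC-flags byte per
//!   DC entry.
//!
//! Legacy (v1) files, as read by `load()`, carry no version byte: they start
//! directly with the little-endian `home_dc_id` and contain only the DC table
//! (no update state or peers).  Any first byte other than `0x02`/`0x03` is
//! treated as the start of such a file.
//!
//! `load()` handles all three.  `save()` always writes the `0x03` format.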
30
31use std::collections::HashMap;
32use std::io::{self, ErrorKind};
33use std::path::Path;
34
35// DcFlags
36
/// Per-DC option flags mirroring the tDesktop `DcOption` bitmask.
///
/// Stored in the session (v3+) so media DCs survive restarts.  The raw bits
/// are public so the value can be written to / read from a single byte.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// Empty flag set.
    pub const NONE: DcFlags = DcFlags(0b0000_0000);
    /// The entry's address is an IPv6 address.
    pub const IPV6: DcFlags = DcFlags(0b0000_0001);
    pub const MEDIA_ONLY: DcFlags = DcFlags(0b0000_0010);
    pub const TCPO_ONLY: DcFlags = DcFlags(0b0000_0100);
    pub const CDN: DcFlags = DcFlags(0b0000_1000);
    pub const STATIC: DcFlags = DcFlags(0b0001_0000);

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// Because the empty set is a subset of everything,
    /// `x.contains(DcFlags::NONE)` is always `true`.
    pub fn contains(self, other: DcFlags) -> bool {
        let DcFlags(mine) = self;
        let DcFlags(wanted) = other;
        (mine & wanted) == wanted
    }

    /// Switches on every bit of `flag`, leaving all other bits untouched.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = DcFlags;

    /// Union of the two flag sets.
    fn bitor(self, rhs: DcFlags) -> DcFlags {
        let DcFlags(lhs_bits) = self;
        let DcFlags(rhs_bits) = rhs;
        DcFlags(lhs_bits | rhs_bits)
    }
}
67
68// DcEntry
69
70/// One entry in the DC address table.
71#[derive(Clone, Debug)]
72#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
73pub struct DcEntry {
74    pub dc_id: i32,
75    pub addr: String,
76    pub auth_key: Option<[u8; 256]>,
77    pub first_salt: i64,
78    pub time_offset: i32,
79    /// DC capability flags (IPv6, media-only, CDN, …).
80    pub flags: DcFlags,
81}
82
83impl DcEntry {
84    /// Returns `true` when this entry represents an IPv6 address.
85    #[inline]
86    pub fn is_ipv6(&self) -> bool {
87        self.flags.contains(DcFlags::IPV6)
88    }
89
90    /// Parse the stored `"ip:port"` / `"[ipv6]:port"` address into a
91    /// [`std::net::SocketAddr`].
92    ///
93    /// Both formats are valid:
94    /// - IPv4: `"149.154.175.53:443"`
95    /// - IPv6: `"[2001:b28:f23d:f001::a]:443"`
96    pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
97        self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
98            io::Error::new(
99                io::ErrorKind::InvalidData,
100                format!("invalid DC address: {:?}", self.addr),
101            )
102        })
103    }
104
105    /// Construct a `DcEntry` from separate IP string, port, and flags.
106    ///
107    /// IPv6 addresses are automatically wrapped in brackets so that
108    /// `socket_addr()` can round-trip them correctly:
109    ///
110    /// ```text
111    /// DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6)
112    /// // addr = "[2001:b28:f23d:f001::a]:443"
113    /// ```
114    ///
115    /// This is the preferred constructor when processing `help.getConfig`
116    /// `DcOption` objects from the Telegram API.
117    pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
118        // IPv6 addresses contain colons; wrap in brackets for SocketAddr compat.
119        let addr = if ip.contains(':') {
120            format!("[{ip}]:{port}")
121        } else {
122            format!("{ip}:{port}")
123        };
124        Self {
125            dc_id,
126            addr,
127            auth_key: None,
128            first_salt: 0,
129            time_offset: 0,
130            flags,
131        }
132    }
133}
134
135// UpdatesStateSnap
136
/// Snapshot of the MTProto update-sequence state that we persist so that
/// `catch_up: true` can call `updates.getDifference` with the *pre-shutdown* pts.
///
/// The `Default` value (all counters zero, no channels) means "no state yet".
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Main persistence counter (messages, non-channel updates).
    pub pts: i32,
    /// Secondary counter for secret chats.
    pub qts: i32,
    /// Date of the last received update (Unix timestamp).
    pub date: i32,
    /// Combined-container sequence number.
    pub seq: i32,
    /// Per-channel persistence counters.  `(channel_id, pts)`.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// `true` once a real (strictly positive) pts has been recorded.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        0 < self.pts
    }

    /// Store `pts` for `channel_id`, appending a new entry if the channel is
    /// not tracked yet.
    ///
    /// The stored value is overwritten unconditionally; no check is made that
    /// the new pts is larger than the old one.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self.channels.iter().position(|&(id, _)| id == channel_id) {
            Some(idx) => self.channels[idx].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Stored pts for `channel_id`, or `0` when the channel is unknown.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        for &(id, stored) in &self.channels {
            if id == channel_id {
                return stored;
            }
        }
        0
    }
}
179
180// CachedPeer
181
/// A cached access-hash entry so that the peer can be addressed across restarts
/// without re-resolving it from Telegram.
///
/// Persisted by the session encoder as `id` (i64 LE), `access_hash` (i64 LE)
/// and `is_channel` (one byte, non-zero = `true`).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Bare Telegram peer ID (expected to be positive; not validated here).
    pub id: i64,
    /// Access hash bound to the current session.
    pub access_hash: i64,
    /// `true` → channel / supergroup.  `false` → user.
    pub is_channel: bool,
}
194
195// PersistedSession
196
197/// Everything that needs to survive a process restart.
198#[derive(Clone, Debug, Default)]
199#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
200pub struct PersistedSession {
201    pub home_dc_id: i32,
202    pub dcs: Vec<DcEntry>,
203    /// Update counters to enable reliable catch-up after a disconnect.
204    pub updates_state: UpdatesStateSnap,
205    /// Peer access-hash cache so that the client can reach out to any previously
206    /// seen user or channel without re-resolving them.
207    pub peers: Vec<CachedPeer>,
208}
209
210impl PersistedSession {
211    // Serialise (v2)
212
213    /// Encode the session to raw bytes (v2 binary format).
214    pub fn to_bytes(&self) -> Vec<u8> {
215        let mut b = Vec::with_capacity(512);
216
217        b.push(0x03u8); // version
218
219        b.extend_from_slice(&self.home_dc_id.to_le_bytes());
220
221        b.push(self.dcs.len() as u8);
222        for d in &self.dcs {
223            b.extend_from_slice(&d.dc_id.to_le_bytes());
224            match &d.auth_key {
225                Some(k) => {
226                    b.push(1);
227                    b.extend_from_slice(k);
228                }
229                None => {
230                    b.push(0);
231                }
232            }
233            b.extend_from_slice(&d.first_salt.to_le_bytes());
234            b.extend_from_slice(&d.time_offset.to_le_bytes());
235            let ab = d.addr.as_bytes();
236            b.push(ab.len() as u8);
237            b.extend_from_slice(ab);
238            b.push(d.flags.0); // v3: DC flags byte
239        }
240
241        // update state
242        b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
243        b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
244        b.extend_from_slice(&self.updates_state.date.to_le_bytes());
245        b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
246        let ch = &self.updates_state.channels;
247        b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
248        for &(cid, cpts) in ch {
249            b.extend_from_slice(&cid.to_le_bytes());
250            b.extend_from_slice(&cpts.to_le_bytes());
251        }
252
253        // peer cache
254        b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
255        for p in &self.peers {
256            b.extend_from_slice(&p.id.to_le_bytes());
257            b.extend_from_slice(&p.access_hash.to_le_bytes());
258            b.push(p.is_channel as u8);
259        }
260
261        b
262    }
263
264    /// Atomically save the session to `path`.
265    ///
266    /// Writes to `<path>.tmp` first, then renames into place so a crash
267    /// mid-write never corrupts the existing session file.
268    pub fn save(&self, path: &Path) -> io::Result<()> {
269        let tmp = path.with_extension("tmp");
270        std::fs::write(&tmp, self.to_bytes())?;
271        std::fs::rename(&tmp, path)
272    }
273
274    // Deserialise (v1 + v2)
275
276    /// Decode a session from raw bytes (v1 or v2 binary format).
277    pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
278        if buf.is_empty() {
279            return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
280        }
281
282        let mut p = 0usize;
283
284        macro_rules! r {
285            ($n:expr) => {{
286                if p + $n > buf.len() {
287                    return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
288                }
289                let s = &buf[p..p + $n];
290                p += $n;
291                s
292            }};
293        }
294        macro_rules! r_i32 {
295            () => {
296                i32::from_le_bytes(r!(4).try_into().unwrap())
297            };
298        }
299        macro_rules! r_i64 {
300            () => {
301                i64::from_le_bytes(r!(8).try_into().unwrap())
302            };
303        }
304        macro_rules! r_u8 {
305            () => {
306                r!(1)[0]
307            };
308        }
309        macro_rules! r_u16 {
310            () => {
311                u16::from_le_bytes(r!(2).try_into().unwrap())
312            };
313        }
314
315        let first_byte = r_u8!();
316
317        let (home_dc_id, version) = if first_byte == 0x03 {
318            (r_i32!(), 3u8)
319        } else if first_byte == 0x02 {
320            (r_i32!(), 2u8)
321        } else {
322            let rest = r!(3);
323            let mut bytes = [0u8; 4];
324            bytes[0] = first_byte;
325            bytes[1..4].copy_from_slice(rest);
326            (i32::from_le_bytes(bytes), 1u8)
327        };
328
329        let dc_count = r_u8!() as usize;
330        let mut dcs = Vec::with_capacity(dc_count);
331        for _ in 0..dc_count {
332            let dc_id = r_i32!();
333            let has_key = r_u8!();
334            let auth_key = if has_key == 1 {
335                let mut k = [0u8; 256];
336                k.copy_from_slice(r!(256));
337                Some(k)
338            } else {
339                None
340            };
341            let first_salt = r_i64!();
342            let time_offset = r_i32!();
343            let al = r_u8!() as usize;
344            let addr = String::from_utf8_lossy(r!(al)).into_owned();
345            let flags = if version >= 3 {
346                DcFlags(r_u8!())
347            } else {
348                DcFlags::NONE
349            };
350            dcs.push(DcEntry {
351                dc_id,
352                addr,
353                auth_key,
354                first_salt,
355                time_offset,
356                flags,
357            });
358        }
359
360        if version < 2 {
361            return Ok(Self {
362                home_dc_id,
363                dcs,
364                updates_state: UpdatesStateSnap::default(),
365                peers: Vec::new(),
366            });
367        }
368
369        let pts = r_i32!();
370        let qts = r_i32!();
371        let date = r_i32!();
372        let seq = r_i32!();
373        let ch_count = r_u16!() as usize;
374        let mut channels = Vec::with_capacity(ch_count);
375        for _ in 0..ch_count {
376            let cid = r_i64!();
377            let cpts = r_i32!();
378            channels.push((cid, cpts));
379        }
380
381        let peer_count = r_u16!() as usize;
382        let mut peers = Vec::with_capacity(peer_count);
383        for _ in 0..peer_count {
384            let id = r_i64!();
385            let access_hash = r_i64!();
386            let is_channel = r_u8!() != 0;
387            peers.push(CachedPeer {
388                id,
389                access_hash,
390                is_channel,
391            });
392        }
393
394        Ok(Self {
395            home_dc_id,
396            dcs,
397            updates_state: UpdatesStateSnap {
398                pts,
399                qts,
400                date,
401                seq,
402                channels,
403            },
404            peers,
405        })
406    }
407
408    /// Decode a session from a URL-safe base64 string produced by [`to_string`].
409    pub fn from_string(s: &str) -> io::Result<Self> {
410        use base64::Engine as _;
411        let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
412            .decode(s.trim())
413            .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
414        Self::from_bytes(&bytes)
415    }
416
417    pub fn load(path: &Path) -> io::Result<Self> {
418        let buf = std::fs::read(path)?;
419        Self::from_bytes(&buf)
420    }
421
422    // DC address helpers
423
424    /// Find the best DC entry for a given DC ID.
425    ///
426    /// When `prefer_ipv6` is `true`, returns the IPv6 entry if one is
427    /// stored, falling back to IPv4.  When `false`, returns IPv4,
428    /// falling back to IPv6.  Returns `None` only when the DC ID is
429    /// completely unknown.
430    ///
431    /// This correctly handles the case where both an IPv4 and an IPv6
432    /// `DcEntry` exist for the same `dc_id` (different `flags` bitmask).
433    pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
434        let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
435        candidates.peek()?;
436        // Collect so we can search twice
437        let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
438        // Preferred family first, fall back to whatever is available
439        cands
440            .iter()
441            .copied()
442            .find(|d| d.is_ipv6() == prefer_ipv6)
443            .or_else(|| cands.first().copied())
444    }
445
446    /// Iterate over every stored DC entry for a given DC ID.
447    ///
448    /// Typically yields one IPv4 and one IPv6 entry per DC ID once
449    /// `help.getConfig` has been applied.
450    pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
451        self.dcs.iter().filter(move |d| d.dc_id == dc_id)
452    }
453}
454
455impl std::fmt::Display for PersistedSession {
456    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457        use base64::Engine as _;
458        f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
459    }
460}
461
462// Bootstrap DC table
463
464/// Bootstrap DC address table (fallback if GetConfig fails).
465pub fn default_dc_addresses() -> HashMap<i32, String> {
466    [
467        (1, "149.154.175.53:443"),
468        (2, "149.154.167.51:443"),
469        (3, "149.154.175.100:443"),
470        (4, "149.154.167.91:443"),
471        (5, "91.108.56.130:443"),
472    ]
473    .into_iter()
474    .map(|(id, addr)| (id, addr.to_string()))
475    .collect()
476}