Skip to main content

layer_client/
session.rs

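//! Session persistence: saves auth key, salt, time offset, DC table,
//! update sequence counters (pts/qts/seq/date/per-channel pts), and
//! peer access-hash cache.
//!
//! ## Binary format versioning
//!
//! Versioned files start with a single **version byte**:
//! - v1 (legacy): *no* version byte — the file begins directly with the
//!   little-endian `home_dc_id`, followed by the DC table only (no update
//!   state or peers).
//! - `0x02`: DC table + update state + peer cache.
//! - `0x03`: current format — like v2, plus one flags byte per DC entry.
//!
//! Because v1 has no marker, a v1 file whose home DC ID has `0x02` or `0x03`
//! as its low byte cannot be told apart from a versioned file by its first
//! byte alone.
//!
//! `load()` handles all three.  `save()` always writes v3.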
12
13use std::collections::HashMap;
14use std::io::{self, ErrorKind};
15use std::path::Path;
16
17// DcFlags
18
/// Bitmask of per-DC option flags, laid out like tDesktop's `DcOption` flags.
///
/// Persisted as a single byte in v3+ sessions so that media-only / CDN DC
/// entries survive a restart.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// Empty set: no flag bits.
    pub const NONE: DcFlags = DcFlags(0);
    // Single-bit flags; bit positions follow tDesktop's `DcOption`.
    pub const IPV6: DcFlags = DcFlags(0b0000_0001);
    pub const MEDIA_ONLY: DcFlags = DcFlags(0b0000_0010);
    pub const TCPO_ONLY: DcFlags = DcFlags(0b0000_0100);
    pub const CDN: DcFlags = DcFlags(0b0000_1000);
    pub const STATIC: DcFlags = DcFlags(0b0001_0000);

    /// `true` when every bit of `other` is also set in `self`.
    ///
    /// An empty `other` (e.g. [`DcFlags::NONE`]) is therefore always contained.
    pub fn contains(self, other: DcFlags) -> bool {
        let DcFlags(mask) = other;
        (self.0 & mask) == mask
    }

    /// Turn on every bit of `flag` in place; bits already set stay set.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = DcFlags;

    /// Union of two flag sets.
    fn bitor(self, rhs: DcFlags) -> DcFlags {
        let DcFlags(lhs_bits) = self;
        let DcFlags(rhs_bits) = rhs;
        DcFlags(lhs_bits | rhs_bits)
    }
}
49
50// DcEntry
51
52/// One entry in the DC address table.
53#[derive(Clone, Debug)]
54#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
55pub struct DcEntry {
56    pub dc_id: i32,
57    pub addr: String,
58    pub auth_key: Option<[u8; 256]>,
59    pub first_salt: i64,
60    pub time_offset: i32,
61    /// DC capability flags (IPv6, media-only, CDN, …).
62    pub flags: DcFlags,
63}
64
65impl DcEntry {
66    /// Returns `true` when this entry represents an IPv6 address.
67    #[inline]
68    pub fn is_ipv6(&self) -> bool {
69        self.flags.contains(DcFlags::IPV6)
70    }
71
72    /// Parse the stored `"ip:port"` / `"[ipv6]:port"` address into a
73    /// [`std::net::SocketAddr`].
74    ///
75    /// Both formats are valid:
76    /// - IPv4: `"149.154.175.53:443"`
77    /// - IPv6: `"[2001:b28:f23d:f001::a]:443"`
78    pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
79        self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
80            io::Error::new(
81                io::ErrorKind::InvalidData,
82                format!("invalid DC address: {:?}", self.addr),
83            )
84        })
85    }
86
87    /// Construct a `DcEntry` from separate IP string, port, and flags.
88    ///
89    /// IPv6 addresses are automatically wrapped in brackets so that
90    /// `socket_addr()` can round-trip them correctly:
91    ///
92    /// ```text
93    /// DcEntry::from_parts(2, "2001:b28:f23d:f001::a", 443, DcFlags::IPV6)
94    /// // addr = "[2001:b28:f23d:f001::a]:443"
95    /// ```
96    ///
97    /// This is the preferred constructor when processing `help.getConfig`
98    /// `DcOption` objects from the Telegram API.
99    pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
100        // IPv6 addresses contain colons; wrap in brackets for SocketAddr compat.
101        let addr = if ip.contains(':') {
102            format!("[{ip}]:{port}")
103        } else {
104            format!("{ip}:{port}")
105        };
106        Self {
107            dc_id,
108            addr,
109            auth_key: None,
110            first_salt: 0,
111            time_offset: 0,
112            flags,
113        }
114    }
115}
116
117// UpdatesStateSnap
118
/// Persisted snapshot of the MTProto update-sequence counters.
///
/// Saving these lets `catch_up: true` resume `updates.getDifference` from the
/// pts the client had *before* it shut down.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Main persistence counter (messages, non-channel updates).
    pub pts: i32,
    /// Secondary counter for secret chats.
    pub qts: i32,
    /// Date of the last received update (Unix timestamp).
    pub date: i32,
    /// Combined-container sequence number.
    pub seq: i32,
    /// Per-channel persistence counters as `(channel_id, pts)` pairs, kept in
    /// the order each channel was first recorded.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// Whether a real server state has been recorded, i.e. `pts` is positive.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        0 < self.pts
    }

    /// Record `pts` for `channel_id`.
    ///
    /// An existing entry is overwritten unconditionally (the old value is not
    /// compared against). An unseen channel is appended to the end.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self.channels.iter_mut().find(|(id, _)| *id == channel_id) {
            Some((_, slot)) => *slot = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Stored pts for `channel_id`, or `0` when the channel is unknown.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        for &(id, stored) in &self.channels {
            if id == channel_id {
                return stored;
            }
        }
        0
    }
}
161
162// CachedPeer
163
/// A cached access-hash entry so that the peer can be addressed across restarts
/// without re-resolving it from Telegram.
///
/// In the binary session format each peer is stored as `id` (i64 LE),
/// `access_hash` (i64 LE) and one `is_channel` byte; on decode any non-zero
/// byte reads back as `true`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Bare Telegram peer ID (expected to be positive; not enforced here).
    pub id: i64,
    /// Access hash bound to the current session.
    pub access_hash: i64,
    /// `true` → channel / supergroup.  `false` → user.
    pub is_channel: bool,
}
176
177// PersistedSession
178
/// Everything that needs to survive a process restart.
///
/// Persisted with `save` / `load` as a binary file, or as a URL-safe base64
/// string via the `Display` impl and `from_string`. The encoded form contains
/// every stored DC auth key, so handle it like a credential.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PersistedSession {
    /// ID of the home DC.
    pub home_dc_id: i32,
    /// Known DC entries. The same `dc_id` may appear more than once (e.g. an
    /// IPv4 and an IPv6 address); `dc_for` picks between them.
    pub dcs: Vec<DcEntry>,
    /// Update counters to enable reliable catch-up after a disconnect.
    pub updates_state: UpdatesStateSnap,
    /// Peer access-hash cache so that the client can reach out to any previously
    /// seen user or channel without re-resolving them.
    pub peers: Vec<CachedPeer>,
}
191
192impl PersistedSession {
193    // Serialise (v2)
194
195    /// Encode the session to raw bytes (v2 binary format).
196    pub fn to_bytes(&self) -> Vec<u8> {
197        let mut b = Vec::with_capacity(512);
198
199        b.push(0x03u8); // version
200
201        b.extend_from_slice(&self.home_dc_id.to_le_bytes());
202
203        b.push(self.dcs.len() as u8);
204        for d in &self.dcs {
205            b.extend_from_slice(&d.dc_id.to_le_bytes());
206            match &d.auth_key {
207                Some(k) => {
208                    b.push(1);
209                    b.extend_from_slice(k);
210                }
211                None => {
212                    b.push(0);
213                }
214            }
215            b.extend_from_slice(&d.first_salt.to_le_bytes());
216            b.extend_from_slice(&d.time_offset.to_le_bytes());
217            let ab = d.addr.as_bytes();
218            b.push(ab.len() as u8);
219            b.extend_from_slice(ab);
220            b.push(d.flags.0); // v3: DC flags byte
221        }
222
223        // update state
224        b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
225        b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
226        b.extend_from_slice(&self.updates_state.date.to_le_bytes());
227        b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
228        let ch = &self.updates_state.channels;
229        b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
230        for &(cid, cpts) in ch {
231            b.extend_from_slice(&cid.to_le_bytes());
232            b.extend_from_slice(&cpts.to_le_bytes());
233        }
234
235        // peer cache
236        b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
237        for p in &self.peers {
238            b.extend_from_slice(&p.id.to_le_bytes());
239            b.extend_from_slice(&p.access_hash.to_le_bytes());
240            b.push(p.is_channel as u8);
241        }
242
243        b
244    }
245
246    /// Atomically save the session to `path`.
247    ///
248    /// Writes to `<path>.tmp` first, then renames into place so a crash
249    /// mid-write never corrupts the existing session file.
250    pub fn save(&self, path: &Path) -> io::Result<()> {
251        let tmp = path.with_extension("tmp");
252        std::fs::write(&tmp, self.to_bytes())?;
253        std::fs::rename(&tmp, path)
254    }
255
256    // Deserialise (v1 + v2)
257
258    /// Decode a session from raw bytes (v1 or v2 binary format).
259    pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
260        if buf.is_empty() {
261            return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
262        }
263
264        let mut p = 0usize;
265
266        macro_rules! r {
267            ($n:expr) => {{
268                if p + $n > buf.len() {
269                    return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
270                }
271                let s = &buf[p..p + $n];
272                p += $n;
273                s
274            }};
275        }
276        macro_rules! r_i32 {
277            () => {
278                i32::from_le_bytes(r!(4).try_into().unwrap())
279            };
280        }
281        macro_rules! r_i64 {
282            () => {
283                i64::from_le_bytes(r!(8).try_into().unwrap())
284            };
285        }
286        macro_rules! r_u8 {
287            () => {
288                r!(1)[0]
289            };
290        }
291        macro_rules! r_u16 {
292            () => {
293                u16::from_le_bytes(r!(2).try_into().unwrap())
294            };
295        }
296
297        let first_byte = r_u8!();
298
299        let (home_dc_id, version) = if first_byte == 0x03 {
300            (r_i32!(), 3u8)
301        } else if first_byte == 0x02 {
302            (r_i32!(), 2u8)
303        } else {
304            let rest = r!(3);
305            let mut bytes = [0u8; 4];
306            bytes[0] = first_byte;
307            bytes[1..4].copy_from_slice(rest);
308            (i32::from_le_bytes(bytes), 1u8)
309        };
310
311        let dc_count = r_u8!() as usize;
312        let mut dcs = Vec::with_capacity(dc_count);
313        for _ in 0..dc_count {
314            let dc_id = r_i32!();
315            let has_key = r_u8!();
316            let auth_key = if has_key == 1 {
317                let mut k = [0u8; 256];
318                k.copy_from_slice(r!(256));
319                Some(k)
320            } else {
321                None
322            };
323            let first_salt = r_i64!();
324            let time_offset = r_i32!();
325            let al = r_u8!() as usize;
326            let addr = String::from_utf8_lossy(r!(al)).into_owned();
327            let flags = if version >= 3 {
328                DcFlags(r_u8!())
329            } else {
330                DcFlags::NONE
331            };
332            dcs.push(DcEntry {
333                dc_id,
334                addr,
335                auth_key,
336                first_salt,
337                time_offset,
338                flags,
339            });
340        }
341
342        if version < 2 {
343            return Ok(Self {
344                home_dc_id,
345                dcs,
346                updates_state: UpdatesStateSnap::default(),
347                peers: Vec::new(),
348            });
349        }
350
351        let pts = r_i32!();
352        let qts = r_i32!();
353        let date = r_i32!();
354        let seq = r_i32!();
355        let ch_count = r_u16!() as usize;
356        let mut channels = Vec::with_capacity(ch_count);
357        for _ in 0..ch_count {
358            let cid = r_i64!();
359            let cpts = r_i32!();
360            channels.push((cid, cpts));
361        }
362
363        let peer_count = r_u16!() as usize;
364        let mut peers = Vec::with_capacity(peer_count);
365        for _ in 0..peer_count {
366            let id = r_i64!();
367            let access_hash = r_i64!();
368            let is_channel = r_u8!() != 0;
369            peers.push(CachedPeer {
370                id,
371                access_hash,
372                is_channel,
373            });
374        }
375
376        Ok(Self {
377            home_dc_id,
378            dcs,
379            updates_state: UpdatesStateSnap {
380                pts,
381                qts,
382                date,
383                seq,
384                channels,
385            },
386            peers,
387        })
388    }
389
390    /// Decode a session from a URL-safe base64 string produced by [`to_string`].
391    pub fn from_string(s: &str) -> io::Result<Self> {
392        use base64::Engine as _;
393        let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
394            .decode(s.trim())
395            .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
396        Self::from_bytes(&bytes)
397    }
398
399    pub fn load(path: &Path) -> io::Result<Self> {
400        let buf = std::fs::read(path)?;
401        Self::from_bytes(&buf)
402    }
403
404    // DC address helpers
405
406    /// Find the best DC entry for a given DC ID.
407    ///
408    /// When `prefer_ipv6` is `true`, returns the IPv6 entry if one is
409    /// stored, falling back to IPv4.  When `false`, returns IPv4,
410    /// falling back to IPv6.  Returns `None` only when the DC ID is
411    /// completely unknown.
412    ///
413    /// This correctly handles the case where both an IPv4 and an IPv6
414    /// `DcEntry` exist for the same `dc_id` (different `flags` bitmask).
415    pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
416        let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
417        if candidates.peek().is_none() {
418            return None;
419        }
420        // Collect so we can search twice
421        let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
422        // Preferred family first, fall back to whatever is available
423        cands
424            .iter()
425            .copied()
426            .find(|d| d.is_ipv6() == prefer_ipv6)
427            .or_else(|| cands.first().copied())
428    }
429
430    /// Iterate over every stored DC entry for a given DC ID.
431    ///
432    /// Typically yields one IPv4 and one IPv6 entry per DC ID once
433    /// `help.getConfig` has been applied.
434    pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
435        self.dcs.iter().filter(move |d| d.dc_id == dc_id)
436    }
437}
438
439impl std::fmt::Display for PersistedSession {
440    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
441        use base64::Engine as _;
442        f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
443    }
444}
445
446// Bootstrap DC table
447
/// Hard-coded bootstrap table mapping DC ID → `"ip:port"`.
///
/// Used as a fallback when the live DC list (e.g. from GetConfig) is
/// unavailable. Each call returns a freshly built map.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const BOOTSTRAP: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(BOOTSTRAP.len());
    for (dc_id, addr) in BOOTSTRAP {
        table.insert(dc_id, addr.to_owned());
    }
    table
}