1use std::collections::HashMap;
32use std::io::{self, ErrorKind};
33use std::path::Path;
34
/// Bit set describing properties of a DC address entry.
///
/// The raw byte is public and is stored verbatim in the session format, so the
/// bit positions of the constants below must not change.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// No bits set.
    pub const NONE: Self = Self(0b0000_0000);
    /// Bit 0: the entry's address is IPv6.
    pub const IPV6: Self = Self(0b0000_0001);
    /// Bit 1: `media_only`.
    pub const MEDIA_ONLY: Self = Self(0b0000_0010);
    /// Bit 2: `tcpo_only`.
    pub const TCPO_ONLY: Self = Self(0b0000_0100);
    /// Bit 3: `cdn`.
    pub const CDN: Self = Self(0b0000_1000);
    /// Bit 4: `static`.
    pub const STATIC: Self = Self(0b0001_0000);

    /// Returns `true` when every bit of `other` is also set in `self`
    /// (so `x.contains(DcFlags::NONE)` is always `true`).
    pub fn contains(self, other: DcFlags) -> bool {
        let shared = self.0 & other.0;
        shared == other.0
    }

    /// Turns on every bit of `flag`, leaving the other bits untouched.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

/// Union of two flag sets.
impl std::ops::BitOr for DcFlags {
    type Output = DcFlags;

    fn bitor(self, rhs: Self) -> Self::Output {
        Self(self.0 | rhs.0)
    }
}
67
68#[derive(Clone, Debug)]
72#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
73pub struct DcEntry {
74 pub dc_id: i32,
75 pub addr: String,
76 pub auth_key: Option<[u8; 256]>,
77 pub first_salt: i64,
78 pub time_offset: i32,
79 pub flags: DcFlags,
81}
82
83impl DcEntry {
84 #[inline]
86 pub fn is_ipv6(&self) -> bool {
87 self.flags.contains(DcFlags::IPV6)
88 }
89
90 pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
97 self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
98 io::Error::new(
99 io::ErrorKind::InvalidData,
100 format!("invalid DC address: {:?}", self.addr),
101 )
102 })
103 }
104
105 pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
118 let addr = if ip.contains(':') {
120 format!("[{ip}]:{port}")
121 } else {
122 format!("{ip}:{port}")
123 };
124 Self {
125 dc_id,
126 addr,
127 auth_key: None,
128 first_salt: 0,
129 time_offset: 0,
130 flags,
131 }
132 }
133}
134
/// Snapshot of the update-sequence counters persisted with a session.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Common `pts` counter; a non-positive value means "not initialised".
    pub pts: i32,
    /// `qts` counter.
    pub qts: i32,
    /// `date` value.
    pub date: i32,
    /// `seq` counter.
    pub seq: i32,
    /// Per-channel `pts` as `(channel_id, pts)` pairs, in first-seen order.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// Returns `true` once `pts` holds a positive value.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        0 < self.pts
    }

    /// Records `pts` for `channel_id`. It overwrites the first existing pair for
    /// that channel, or appends a new pair if there is none.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self.channels.iter().position(|&(id, _)| id == channel_id) {
            Some(idx) => self.channels[idx].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Returns the stored `pts` for `channel_id`, or `0` when none is recorded.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        for &(id, value) in &self.channels {
            if id == channel_id {
                return value;
            }
        }
        0
    }
}
179
/// A peer cached in the session: its id, access hash and whether it is a channel.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Peer id.
    pub id: i64,
    /// Access hash stored alongside the id (presumably the value needed to
    /// reference this peer in requests — confirm).
    pub access_hash: i64,
    /// `true` when the peer is a channel.
    pub is_channel: bool,
}
194
/// Everything persisted between runs: the home DC, the known DC entries, the
/// updates-state snapshot and the peer cache.
///
/// NOTE: the derived `Debug` output includes every DC's auth key.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PersistedSession {
    /// Id of the DC recorded as this session's home.
    pub home_dc_id: i32,
    /// Known DC entries. The same `dc_id` may appear more than once (for example
    /// an IPv4 and an IPv6 entry).
    pub dcs: Vec<DcEntry>,
    /// Update-sequence counters at the time of saving.
    pub updates_state: UpdatesStateSnap,
    /// Cached peers.
    pub peers: Vec<CachedPeer>,
}
209
210impl PersistedSession {
211 pub fn to_bytes(&self) -> Vec<u8> {
215 let mut b = Vec::with_capacity(512);
216
217 b.push(0x03u8); b.extend_from_slice(&self.home_dc_id.to_le_bytes());
220
221 b.push(self.dcs.len() as u8);
222 for d in &self.dcs {
223 b.extend_from_slice(&d.dc_id.to_le_bytes());
224 match &d.auth_key {
225 Some(k) => {
226 b.push(1);
227 b.extend_from_slice(k);
228 }
229 None => {
230 b.push(0);
231 }
232 }
233 b.extend_from_slice(&d.first_salt.to_le_bytes());
234 b.extend_from_slice(&d.time_offset.to_le_bytes());
235 let ab = d.addr.as_bytes();
236 b.push(ab.len() as u8);
237 b.extend_from_slice(ab);
238 b.push(d.flags.0); }
240
241 b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
243 b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
244 b.extend_from_slice(&self.updates_state.date.to_le_bytes());
245 b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
246 let ch = &self.updates_state.channels;
247 b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
248 for &(cid, cpts) in ch {
249 b.extend_from_slice(&cid.to_le_bytes());
250 b.extend_from_slice(&cpts.to_le_bytes());
251 }
252
253 b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
255 for p in &self.peers {
256 b.extend_from_slice(&p.id.to_le_bytes());
257 b.extend_from_slice(&p.access_hash.to_le_bytes());
258 b.push(p.is_channel as u8);
259 }
260
261 b
262 }
263
264 pub fn save(&self, path: &Path) -> io::Result<()> {
269 let tmp = path.with_extension("tmp");
270 std::fs::write(&tmp, self.to_bytes())?;
271 std::fs::rename(&tmp, path)
272 }
273
274 pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
278 if buf.is_empty() {
279 return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
280 }
281
282 let mut p = 0usize;
283
284 macro_rules! r {
285 ($n:expr) => {{
286 if p + $n > buf.len() {
287 return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
288 }
289 let s = &buf[p..p + $n];
290 p += $n;
291 s
292 }};
293 }
294 macro_rules! r_i32 {
295 () => {
296 i32::from_le_bytes(r!(4).try_into().unwrap())
297 };
298 }
299 macro_rules! r_i64 {
300 () => {
301 i64::from_le_bytes(r!(8).try_into().unwrap())
302 };
303 }
304 macro_rules! r_u8 {
305 () => {
306 r!(1)[0]
307 };
308 }
309 macro_rules! r_u16 {
310 () => {
311 u16::from_le_bytes(r!(2).try_into().unwrap())
312 };
313 }
314
315 let first_byte = r_u8!();
316
317 let (home_dc_id, version) = if first_byte == 0x03 {
318 (r_i32!(), 3u8)
319 } else if first_byte == 0x02 {
320 (r_i32!(), 2u8)
321 } else {
322 let rest = r!(3);
323 let mut bytes = [0u8; 4];
324 bytes[0] = first_byte;
325 bytes[1..4].copy_from_slice(rest);
326 (i32::from_le_bytes(bytes), 1u8)
327 };
328
329 let dc_count = r_u8!() as usize;
330 let mut dcs = Vec::with_capacity(dc_count);
331 for _ in 0..dc_count {
332 let dc_id = r_i32!();
333 let has_key = r_u8!();
334 let auth_key = if has_key == 1 {
335 let mut k = [0u8; 256];
336 k.copy_from_slice(r!(256));
337 Some(k)
338 } else {
339 None
340 };
341 let first_salt = r_i64!();
342 let time_offset = r_i32!();
343 let al = r_u8!() as usize;
344 let addr = String::from_utf8_lossy(r!(al)).into_owned();
345 let flags = if version >= 3 {
346 DcFlags(r_u8!())
347 } else {
348 DcFlags::NONE
349 };
350 dcs.push(DcEntry {
351 dc_id,
352 addr,
353 auth_key,
354 first_salt,
355 time_offset,
356 flags,
357 });
358 }
359
360 if version < 2 {
361 return Ok(Self {
362 home_dc_id,
363 dcs,
364 updates_state: UpdatesStateSnap::default(),
365 peers: Vec::new(),
366 });
367 }
368
369 let pts = r_i32!();
370 let qts = r_i32!();
371 let date = r_i32!();
372 let seq = r_i32!();
373 let ch_count = r_u16!() as usize;
374 let mut channels = Vec::with_capacity(ch_count);
375 for _ in 0..ch_count {
376 let cid = r_i64!();
377 let cpts = r_i32!();
378 channels.push((cid, cpts));
379 }
380
381 let peer_count = r_u16!() as usize;
382 let mut peers = Vec::with_capacity(peer_count);
383 for _ in 0..peer_count {
384 let id = r_i64!();
385 let access_hash = r_i64!();
386 let is_channel = r_u8!() != 0;
387 peers.push(CachedPeer {
388 id,
389 access_hash,
390 is_channel,
391 });
392 }
393
394 Ok(Self {
395 home_dc_id,
396 dcs,
397 updates_state: UpdatesStateSnap {
398 pts,
399 qts,
400 date,
401 seq,
402 channels,
403 },
404 peers,
405 })
406 }
407
408 pub fn from_string(s: &str) -> io::Result<Self> {
410 use base64::Engine as _;
411 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
412 .decode(s.trim())
413 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
414 Self::from_bytes(&bytes)
415 }
416
417 pub fn load(path: &Path) -> io::Result<Self> {
418 let buf = std::fs::read(path)?;
419 Self::from_bytes(&buf)
420 }
421
422 pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
434 let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
435 candidates.peek()?;
436 let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
438 cands
440 .iter()
441 .copied()
442 .find(|d| d.is_ipv6() == prefer_ipv6)
443 .or_else(|| cands.first().copied())
444 }
445
446 pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
451 self.dcs.iter().filter(move |d| d.dc_id == dc_id)
452 }
453}
454
455impl std::fmt::Display for PersistedSession {
456 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457 use base64::Engine as _;
458 f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
459 }
460}
461
/// Returns the built-in IPv4 `host:port` address for each of DCs 1 through 5.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const BUILTIN: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(BUILTIN.len());
    for (dc, addr) in BUILTIN {
        table.insert(dc, String::from(addr));
    }
    table
}