1use std::collections::HashMap;
27use std::io::{self, ErrorKind};
28use std::path::Path;
29
/// Bit set describing properties of a DC endpoint, stored as a single raw byte.
///
/// The individual bits are exposed as associated constants; combine them with `|`
/// or [`DcFlags::set`] and test them with [`DcFlags::contains`].
/// NOTE(review): the names presumably mirror the Telegram `dcOption` flags — confirm.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// Empty set (also the `Default`).
    pub const NONE: DcFlags = DcFlags(0b0000_0000);
    /// Bit 0: the endpoint address is IPv6 (read by `DcEntry::is_ipv6`).
    pub const IPV6: DcFlags = DcFlags(0b0000_0001);
    /// Bit 1.
    pub const MEDIA_ONLY: DcFlags = DcFlags(0b0000_0010);
    /// Bit 2.
    pub const TCPO_ONLY: DcFlags = DcFlags(0b0000_0100);
    /// Bit 3.
    pub const CDN: DcFlags = DcFlags(0b0000_1000);
    /// Bit 4.
    pub const STATIC: DcFlags = DcFlags(0b0001_0000);

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// `contains(NONE)` is therefore always `true`.
    pub fn contains(self, other: DcFlags) -> bool {
        let DcFlags(wanted) = other;
        (self.0 & wanted) == wanted
    }

    /// Turns on every bit of `flag`, leaving already-set bits untouched.
    pub fn set(&mut self, flag: DcFlags) {
        *self = DcFlags(self.0 | flag.0);
    }
}
55
56impl std::ops::BitOr for DcFlags {
57 type Output = DcFlags;
58 fn bitor(self, rhs: DcFlags) -> DcFlags {
59 DcFlags(self.0 | rhs.0)
60 }
61}
62
63#[derive(Clone, Debug)]
67#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
68pub struct DcEntry {
69 pub dc_id: i32,
70 pub addr: String,
71 pub auth_key: Option<[u8; 256]>,
72 pub first_salt: i64,
73 pub time_offset: i32,
74 pub flags: DcFlags,
76}
77
78impl DcEntry {
79 #[inline]
81 pub fn is_ipv6(&self) -> bool {
82 self.flags.contains(DcFlags::IPV6)
83 }
84
85 pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
92 self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
93 io::Error::new(
94 io::ErrorKind::InvalidData,
95 format!("invalid DC address: {:?}", self.addr),
96 )
97 })
98 }
99
100 pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
113 let addr = if ip.contains(':') {
115 format!("[{ip}]:{port}")
116 } else {
117 format!("{ip}:{port}")
118 };
119 Self {
120 dc_id,
121 addr,
122 auth_key: None,
123 first_salt: 0,
124 time_offset: 0,
125 flags,
126 }
127 }
128}
129
/// Snapshot of the update counters (`pts`, `qts`, `date`, `seq`) plus per-channel `pts`
/// values, persisted as part of a `PersistedSession`.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Common `pts`; `0` in the default (never-initialised) snapshot.
    pub pts: i32,
    pub qts: i32,
    pub date: i32,
    pub seq: i32,
    /// Per-channel `pts` as `(channel_id, pts)` pairs, in insertion order.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// `true` once a strictly positive `pts` has been recorded.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        0 < self.pts
    }

    /// Stores `pts` for `channel_id`, overwriting its existing entry or appending a new one.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        let slot = self.channels.iter().position(|&(id, _)| id == channel_id);
        match slot {
            Some(idx) => self.channels[idx].1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Returns the stored `pts` for `channel_id`, or `0` when the channel has no entry.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        self.channels
            .iter()
            .find_map(|&(id, stored)| if id == channel_id { Some(stored) } else { None })
            .unwrap_or(0)
    }
}
174
/// A peer's id and access hash as persisted in the session.
///
/// Derives `PartialEq`/`Eq`/`Hash` so cached peers can be compared, deduplicated or
/// used as set/map keys. All fields are plain integers and a flag.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    pub id: i64,
    pub access_hash: i64,
    /// Stored as a single `0`/`1` byte in the binary session format.
    pub is_channel: bool,
}
189
190#[derive(Clone, Debug, Default)]
194#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
195pub struct PersistedSession {
196 pub home_dc_id: i32,
197 pub dcs: Vec<DcEntry>,
198 pub updates_state: UpdatesStateSnap,
200 pub peers: Vec<CachedPeer>,
203}
204
205impl PersistedSession {
206 pub fn to_bytes(&self) -> Vec<u8> {
210 let mut b = Vec::with_capacity(512);
211
212 b.push(0x03u8); b.extend_from_slice(&self.home_dc_id.to_le_bytes());
215
216 b.push(self.dcs.len() as u8);
217 for d in &self.dcs {
218 b.extend_from_slice(&d.dc_id.to_le_bytes());
219 match &d.auth_key {
220 Some(k) => {
221 b.push(1);
222 b.extend_from_slice(k);
223 }
224 None => {
225 b.push(0);
226 }
227 }
228 b.extend_from_slice(&d.first_salt.to_le_bytes());
229 b.extend_from_slice(&d.time_offset.to_le_bytes());
230 let ab = d.addr.as_bytes();
231 b.push(ab.len() as u8);
232 b.extend_from_slice(ab);
233 b.push(d.flags.0); }
235
236 b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
238 b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
239 b.extend_from_slice(&self.updates_state.date.to_le_bytes());
240 b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
241 let ch = &self.updates_state.channels;
242 b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
243 for &(cid, cpts) in ch {
244 b.extend_from_slice(&cid.to_le_bytes());
245 b.extend_from_slice(&cpts.to_le_bytes());
246 }
247
248 b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
250 for p in &self.peers {
251 b.extend_from_slice(&p.id.to_le_bytes());
252 b.extend_from_slice(&p.access_hash.to_le_bytes());
253 b.push(p.is_channel as u8);
254 }
255
256 b
257 }
258
259 pub fn save(&self, path: &Path) -> io::Result<()> {
264 let tmp = path.with_extension("tmp");
265 std::fs::write(&tmp, self.to_bytes())?;
266 std::fs::rename(&tmp, path)
267 }
268
269 pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
273 if buf.is_empty() {
274 return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
275 }
276
277 let mut p = 0usize;
278
279 macro_rules! r {
280 ($n:expr) => {{
281 if p + $n > buf.len() {
282 return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
283 }
284 let s = &buf[p..p + $n];
285 p += $n;
286 s
287 }};
288 }
289 macro_rules! r_i32 {
290 () => {
291 i32::from_le_bytes(r!(4).try_into().unwrap())
292 };
293 }
294 macro_rules! r_i64 {
295 () => {
296 i64::from_le_bytes(r!(8).try_into().unwrap())
297 };
298 }
299 macro_rules! r_u8 {
300 () => {
301 r!(1)[0]
302 };
303 }
304 macro_rules! r_u16 {
305 () => {
306 u16::from_le_bytes(r!(2).try_into().unwrap())
307 };
308 }
309
310 let first_byte = r_u8!();
311
312 let (home_dc_id, version) = if first_byte == 0x03 {
313 (r_i32!(), 3u8)
314 } else if first_byte == 0x02 {
315 (r_i32!(), 2u8)
316 } else {
317 let rest = r!(3);
318 let mut bytes = [0u8; 4];
319 bytes[0] = first_byte;
320 bytes[1..4].copy_from_slice(rest);
321 (i32::from_le_bytes(bytes), 1u8)
322 };
323
324 let dc_count = r_u8!() as usize;
325 let mut dcs = Vec::with_capacity(dc_count);
326 for _ in 0..dc_count {
327 let dc_id = r_i32!();
328 let has_key = r_u8!();
329 let auth_key = if has_key == 1 {
330 let mut k = [0u8; 256];
331 k.copy_from_slice(r!(256));
332 Some(k)
333 } else {
334 None
335 };
336 let first_salt = r_i64!();
337 let time_offset = r_i32!();
338 let al = r_u8!() as usize;
339 let addr = String::from_utf8_lossy(r!(al)).into_owned();
340 let flags = if version >= 3 {
341 DcFlags(r_u8!())
342 } else {
343 DcFlags::NONE
344 };
345 dcs.push(DcEntry {
346 dc_id,
347 addr,
348 auth_key,
349 first_salt,
350 time_offset,
351 flags,
352 });
353 }
354
355 if version < 2 {
356 return Ok(Self {
357 home_dc_id,
358 dcs,
359 updates_state: UpdatesStateSnap::default(),
360 peers: Vec::new(),
361 });
362 }
363
364 let pts = r_i32!();
365 let qts = r_i32!();
366 let date = r_i32!();
367 let seq = r_i32!();
368 let ch_count = r_u16!() as usize;
369 let mut channels = Vec::with_capacity(ch_count);
370 for _ in 0..ch_count {
371 let cid = r_i64!();
372 let cpts = r_i32!();
373 channels.push((cid, cpts));
374 }
375
376 let peer_count = r_u16!() as usize;
377 let mut peers = Vec::with_capacity(peer_count);
378 for _ in 0..peer_count {
379 let id = r_i64!();
380 let access_hash = r_i64!();
381 let is_channel = r_u8!() != 0;
382 peers.push(CachedPeer {
383 id,
384 access_hash,
385 is_channel,
386 });
387 }
388
389 Ok(Self {
390 home_dc_id,
391 dcs,
392 updates_state: UpdatesStateSnap {
393 pts,
394 qts,
395 date,
396 seq,
397 channels,
398 },
399 peers,
400 })
401 }
402
403 pub fn from_string(s: &str) -> io::Result<Self> {
405 use base64::Engine as _;
406 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
407 .decode(s.trim())
408 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
409 Self::from_bytes(&bytes)
410 }
411
412 pub fn load(path: &Path) -> io::Result<Self> {
413 let buf = std::fs::read(path)?;
414 Self::from_bytes(&buf)
415 }
416
417 pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
429 let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
430 candidates.peek()?;
431 let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
433 cands
435 .iter()
436 .copied()
437 .find(|d| d.is_ipv6() == prefer_ipv6)
438 .or_else(|| cands.first().copied())
439 }
440
441 pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
446 self.dcs.iter().filter(move |d| d.dc_id == dc_id)
447 }
448}
449
450impl std::fmt::Display for PersistedSession {
451 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
452 use base64::Engine as _;
453 f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
454 }
455}
456
/// Hard-coded `ip:port` endpoints for DC ids 1 through 5, keyed by DC id.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const BUILTIN: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(BUILTIN.len());
    for &(dc_id, endpoint) in BUILTIN.iter() {
        table.insert(dc_id, endpoint.to_owned());
    }
    table
}