1use std::collections::HashMap;
14use std::io::{self, ErrorKind};
15use std::path::Path;
16
/// Single-byte bit set describing properties of a DC endpoint.
///
/// Individual bits are exposed as associated constants and can be combined
/// with `|`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DcFlags(pub u8);

impl DcFlags {
    /// Empty set (also the `Default`).
    pub const NONE: DcFlags = DcFlags(0);
    /// Bit 0: the entry's address is IPv6 (queried by `DcEntry::is_ipv6`).
    pub const IPV6: DcFlags = DcFlags(0b0000_0001);
    /// Bit 1: presumably mirrors the server-side "media only" DC option flag — confirm.
    pub const MEDIA_ONLY: DcFlags = DcFlags(0b0000_0010);
    /// Bit 2: presumably mirrors the server-side "tcpo only" DC option flag — confirm.
    pub const TCPO_ONLY: DcFlags = DcFlags(0b0000_0100);
    /// Bit 3: presumably marks a CDN endpoint — confirm.
    pub const CDN: DcFlags = DcFlags(0b0000_1000);
    /// Bit 4: presumably mirrors the server-side "static" DC option flag — confirm.
    pub const STATIC: DcFlags = DcFlags(0b0001_0000);

    /// Returns `true` when every bit of `other` is also set in `self`
    /// (so `contains(DcFlags::NONE)` is always `true`).
    pub fn contains(self, other: DcFlags) -> bool {
        let DcFlags(have) = self;
        let DcFlags(want) = other;
        (have & want) == want
    }

    /// Turns on every bit of `flag`, leaving already-set bits untouched.
    pub fn set(&mut self, flag: DcFlags) {
        *self = *self | flag;
    }
}

impl std::ops::BitOr for DcFlags {
    type Output = DcFlags;

    /// Union of the two bit sets.
    fn bitor(self, rhs: DcFlags) -> DcFlags {
        let DcFlags(lhs_bits) = self;
        let DcFlags(rhs_bits) = rhs;
        DcFlags(lhs_bits | rhs_bits)
    }
}
49
50#[derive(Clone, Debug)]
54#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
55pub struct DcEntry {
56 pub dc_id: i32,
57 pub addr: String,
58 pub auth_key: Option<[u8; 256]>,
59 pub first_salt: i64,
60 pub time_offset: i32,
61 pub flags: DcFlags,
63}
64
65impl DcEntry {
66 #[inline]
68 pub fn is_ipv6(&self) -> bool {
69 self.flags.contains(DcFlags::IPV6)
70 }
71
72 pub fn socket_addr(&self) -> io::Result<std::net::SocketAddr> {
79 self.addr.parse::<std::net::SocketAddr>().map_err(|_| {
80 io::Error::new(
81 io::ErrorKind::InvalidData,
82 format!("invalid DC address: {:?}", self.addr),
83 )
84 })
85 }
86
87 pub fn from_parts(dc_id: i32, ip: &str, port: u16, flags: DcFlags) -> Self {
100 let addr = if ip.contains(':') {
102 format!("[{ip}]:{port}")
103 } else {
104 format!("{ip}:{port}")
105 };
106 Self {
107 dc_id,
108 addr,
109 auth_key: None,
110 first_salt: 0,
111 time_offset: 0,
112 flags,
113 }
114 }
115}
116
/// Snapshot of update-sequence counters stored with a session.
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatesStateSnap {
    /// Common `pts` counter; `0` means "never set".
    pub pts: i32,
    /// `qts` counter.
    pub qts: i32,
    /// `date` value of the stored state.
    pub date: i32,
    /// `seq` counter.
    pub seq: i32,
    /// Per-channel `pts` as `(channel_id, pts)` pairs, kept in insertion order.
    pub channels: Vec<(i64, i32)>,
}

impl UpdatesStateSnap {
    /// Returns `true` once a strictly positive `pts` has been stored.
    #[inline]
    pub fn is_initialised(&self) -> bool {
        0 < self.pts
    }

    /// Stores `pts` for `channel_id`: overwrites the first matching pair, or
    /// appends a new pair when the channel is not yet tracked.
    pub fn set_channel_pts(&mut self, channel_id: i64, pts: i32) {
        match self.channels.iter_mut().find(|slot| slot.0 == channel_id) {
            Some(slot) => slot.1 = pts,
            None => self.channels.push((channel_id, pts)),
        }
    }

    /// Returns the stored `pts` for `channel_id`, or `0` if it is not tracked.
    pub fn channel_pts(&self, channel_id: i64) -> i32 {
        self.channels
            .iter()
            .find_map(|&(id, value)| if id == channel_id { Some(value) } else { None })
            .unwrap_or_default()
    }
}
161
/// A peer entry cached alongside the session.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachedPeer {
    /// Peer identifier.
    pub id: i64,
    /// Access hash stored for this peer.
    pub access_hash: i64,
    /// Whether the peer is a channel (stored as one byte on disk).
    pub is_channel: bool,
}
176
177#[derive(Clone, Debug, Default)]
181#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
182pub struct PersistedSession {
183 pub home_dc_id: i32,
184 pub dcs: Vec<DcEntry>,
185 pub updates_state: UpdatesStateSnap,
187 pub peers: Vec<CachedPeer>,
190}
191
192impl PersistedSession {
193 pub fn to_bytes(&self) -> Vec<u8> {
197 let mut b = Vec::with_capacity(512);
198
199 b.push(0x03u8); b.extend_from_slice(&self.home_dc_id.to_le_bytes());
202
203 b.push(self.dcs.len() as u8);
204 for d in &self.dcs {
205 b.extend_from_slice(&d.dc_id.to_le_bytes());
206 match &d.auth_key {
207 Some(k) => {
208 b.push(1);
209 b.extend_from_slice(k);
210 }
211 None => {
212 b.push(0);
213 }
214 }
215 b.extend_from_slice(&d.first_salt.to_le_bytes());
216 b.extend_from_slice(&d.time_offset.to_le_bytes());
217 let ab = d.addr.as_bytes();
218 b.push(ab.len() as u8);
219 b.extend_from_slice(ab);
220 b.push(d.flags.0); }
222
223 b.extend_from_slice(&self.updates_state.pts.to_le_bytes());
225 b.extend_from_slice(&self.updates_state.qts.to_le_bytes());
226 b.extend_from_slice(&self.updates_state.date.to_le_bytes());
227 b.extend_from_slice(&self.updates_state.seq.to_le_bytes());
228 let ch = &self.updates_state.channels;
229 b.extend_from_slice(&(ch.len() as u16).to_le_bytes());
230 for &(cid, cpts) in ch {
231 b.extend_from_slice(&cid.to_le_bytes());
232 b.extend_from_slice(&cpts.to_le_bytes());
233 }
234
235 b.extend_from_slice(&(self.peers.len() as u16).to_le_bytes());
237 for p in &self.peers {
238 b.extend_from_slice(&p.id.to_le_bytes());
239 b.extend_from_slice(&p.access_hash.to_le_bytes());
240 b.push(p.is_channel as u8);
241 }
242
243 b
244 }
245
246 pub fn save(&self, path: &Path) -> io::Result<()> {
251 let tmp = path.with_extension("tmp");
252 std::fs::write(&tmp, self.to_bytes())?;
253 std::fs::rename(&tmp, path)
254 }
255
256 pub fn from_bytes(buf: &[u8]) -> io::Result<Self> {
260 if buf.is_empty() {
261 return Err(io::Error::new(ErrorKind::InvalidData, "empty session data"));
262 }
263
264 let mut p = 0usize;
265
266 macro_rules! r {
267 ($n:expr) => {{
268 if p + $n > buf.len() {
269 return Err(io::Error::new(ErrorKind::InvalidData, "truncated session"));
270 }
271 let s = &buf[p..p + $n];
272 p += $n;
273 s
274 }};
275 }
276 macro_rules! r_i32 {
277 () => {
278 i32::from_le_bytes(r!(4).try_into().unwrap())
279 };
280 }
281 macro_rules! r_i64 {
282 () => {
283 i64::from_le_bytes(r!(8).try_into().unwrap())
284 };
285 }
286 macro_rules! r_u8 {
287 () => {
288 r!(1)[0]
289 };
290 }
291 macro_rules! r_u16 {
292 () => {
293 u16::from_le_bytes(r!(2).try_into().unwrap())
294 };
295 }
296
297 let first_byte = r_u8!();
298
299 let (home_dc_id, version) = if first_byte == 0x03 {
300 (r_i32!(), 3u8)
301 } else if first_byte == 0x02 {
302 (r_i32!(), 2u8)
303 } else {
304 let rest = r!(3);
305 let mut bytes = [0u8; 4];
306 bytes[0] = first_byte;
307 bytes[1..4].copy_from_slice(rest);
308 (i32::from_le_bytes(bytes), 1u8)
309 };
310
311 let dc_count = r_u8!() as usize;
312 let mut dcs = Vec::with_capacity(dc_count);
313 for _ in 0..dc_count {
314 let dc_id = r_i32!();
315 let has_key = r_u8!();
316 let auth_key = if has_key == 1 {
317 let mut k = [0u8; 256];
318 k.copy_from_slice(r!(256));
319 Some(k)
320 } else {
321 None
322 };
323 let first_salt = r_i64!();
324 let time_offset = r_i32!();
325 let al = r_u8!() as usize;
326 let addr = String::from_utf8_lossy(r!(al)).into_owned();
327 let flags = if version >= 3 {
328 DcFlags(r_u8!())
329 } else {
330 DcFlags::NONE
331 };
332 dcs.push(DcEntry {
333 dc_id,
334 addr,
335 auth_key,
336 first_salt,
337 time_offset,
338 flags,
339 });
340 }
341
342 if version < 2 {
343 return Ok(Self {
344 home_dc_id,
345 dcs,
346 updates_state: UpdatesStateSnap::default(),
347 peers: Vec::new(),
348 });
349 }
350
351 let pts = r_i32!();
352 let qts = r_i32!();
353 let date = r_i32!();
354 let seq = r_i32!();
355 let ch_count = r_u16!() as usize;
356 let mut channels = Vec::with_capacity(ch_count);
357 for _ in 0..ch_count {
358 let cid = r_i64!();
359 let cpts = r_i32!();
360 channels.push((cid, cpts));
361 }
362
363 let peer_count = r_u16!() as usize;
364 let mut peers = Vec::with_capacity(peer_count);
365 for _ in 0..peer_count {
366 let id = r_i64!();
367 let access_hash = r_i64!();
368 let is_channel = r_u8!() != 0;
369 peers.push(CachedPeer {
370 id,
371 access_hash,
372 is_channel,
373 });
374 }
375
376 Ok(Self {
377 home_dc_id,
378 dcs,
379 updates_state: UpdatesStateSnap {
380 pts,
381 qts,
382 date,
383 seq,
384 channels,
385 },
386 peers,
387 })
388 }
389
390 pub fn from_string(s: &str) -> io::Result<Self> {
392 use base64::Engine as _;
393 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
394 .decode(s.trim())
395 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
396 Self::from_bytes(&bytes)
397 }
398
399 pub fn load(path: &Path) -> io::Result<Self> {
400 let buf = std::fs::read(path)?;
401 Self::from_bytes(&buf)
402 }
403
404 pub fn dc_for(&self, dc_id: i32, prefer_ipv6: bool) -> Option<&DcEntry> {
416 let mut candidates = self.dcs.iter().filter(|d| d.dc_id == dc_id).peekable();
417 if candidates.peek().is_none() {
418 return None;
419 }
420 let cands: Vec<&DcEntry> = self.dcs.iter().filter(|d| d.dc_id == dc_id).collect();
422 cands
424 .iter()
425 .copied()
426 .find(|d| d.is_ipv6() == prefer_ipv6)
427 .or_else(|| cands.first().copied())
428 }
429
430 pub fn all_dcs_for(&self, dc_id: i32) -> impl Iterator<Item = &DcEntry> {
435 self.dcs.iter().filter(move |d| d.dc_id == dc_id)
436 }
437}
438
439impl std::fmt::Display for PersistedSession {
440 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
441 use base64::Engine as _;
442 f.write_str(&base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.to_bytes()))
443 }
444}
445
/// Returns the built-in fallback addresses for DCs 1–5 (IPv4 literals, port 443),
/// keyed by DC id.
pub fn default_dc_addresses() -> HashMap<i32, String> {
    const BUILTIN: [(i32, &str); 5] = [
        (1, "149.154.175.53:443"),
        (2, "149.154.167.51:443"),
        (3, "149.154.175.100:443"),
        (4, "149.154.167.91:443"),
        (5, "91.108.56.130:443"),
    ];

    let mut table = HashMap::with_capacity(BUILTIN.len());
    for (dc_id, addr) in BUILTIN {
        table.insert(dc_id, addr.to_owned());
    }
    table
}