1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
14
15use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
16
/// Leading byte of the extended layout, which also carries `server_salt`,
/// `seq_no` and `layer`.
const VERSION_V1: u8 = 0x01;
/// Leading byte of the compact layout.
const VERSION_V2: u8 = 0x02;

/// Auth key size in bytes (2048 bits), identical in both layouts.
const AUTH_KEY_LEN: usize = 2048 / 8;
21
22#[derive(Debug, Clone)]
23pub struct FullSession {
24 pub dc_id: u8,
25 pub ip: IpAddr,
26 pub port: u16,
27 pub auth_key: [u8; AUTH_KEY_LEN],
28 pub user_id: i64,
29 pub server_salt: i64,
30 pub seq_no: u32,
31 pub layer: u32,
32}
33
34#[derive(Debug, Clone)]
35pub struct Session {
36 pub dc_id: u8,
37 pub ip: IpAddr,
38 pub port: u16,
39 pub auth_key: [u8; AUTH_KEY_LEN],
40 pub user_id: i64,
41}
42
43#[derive(Debug, Clone)]
44pub enum StringSession {
45 V1(FullSession),
46 V2(Session),
47}
48
49#[derive(Debug, thiserror::Error)]
50pub enum StringSessionError {
51 #[error("base64 decode error: {0}")]
52 Base64(#[from] base64::DecodeError),
53 #[error("invalid or truncated session data")]
54 InvalidData,
55 #[error("unsupported version: {0}")]
56 UnsupportedVersion(u8),
57 #[error("unknown ip type byte: {0}")]
58 UnknownIpType(u8),
59}
60
61impl StringSession {
62 pub fn decode(s: &str) -> Result<Self, StringSessionError> {
64 let bytes = URL_SAFE_NO_PAD.decode(s.trim())?;
65
66 if bytes.is_empty() {
67 return Err(StringSessionError::InvalidData);
68 }
69
70 match bytes[0] {
71 VERSION_V1 => decode_v1(&bytes).map(StringSession::V1),
72 VERSION_V2 => decode_v2(&bytes).map(StringSession::V2),
73 v => Err(StringSessionError::UnsupportedVersion(v)),
74 }
75 }
76
77 pub fn encode(&self) -> String {
79 match self {
80 StringSession::V2(s) => encode_v2(s),
81 StringSession::V1(s) => encode_v2(&Session {
82 dc_id: s.dc_id,
83 ip: s.ip,
84 port: s.port,
85 auth_key: s.auth_key,
86 user_id: s.user_id,
87 }),
88 }
89 }
90
91 pub fn encode_v1(&self) -> String {
94 match self {
95 StringSession::V1(s) => encode_v1(s),
96 StringSession::V2(_) => {
97 panic!("cannot encode V2 session as V1: missing server_salt, seq_no, layer")
98 }
99 }
100 }
101
102 pub fn session(&self) -> Session {
106 match self {
107 StringSession::V2(s) => s.clone(),
108 StringSession::V1(s) => Session {
109 dc_id: s.dc_id,
110 ip: s.ip,
111 port: s.port,
112 auth_key: s.auth_key,
113 user_id: s.user_id,
114 },
115 }
116 }
117
118 pub fn full_session(&self) -> Option<&FullSession> {
121 match self {
122 StringSession::V1(s) => Some(s),
123 StringSession::V2(_) => None,
124 }
125 }
126
127 pub fn version(&self) -> u8 {
130 match self {
131 StringSession::V1(_) => VERSION_V1,
132 StringSession::V2(_) => VERSION_V2,
133 }
134 }
135}
136
137impl From<Session> for StringSession {
138 fn from(s: Session) -> Self {
139 StringSession::V2(s)
140 }
141}
142
143impl From<FullSession> for StringSession {
144 fn from(s: FullSession) -> Self {
145 StringSession::V1(s)
146 }
147}
148
149fn encode_v2(s: &Session) -> String {
150 let ip_bytes = ip_to_bytes(s.ip);
151 let ip_type = ip_type_byte(s.ip);
152
153 let mut buf = Vec::with_capacity(1 + 1 + 1 + ip_bytes.len() + 2 + 8 + AUTH_KEY_LEN);
154 buf.push(VERSION_V2);
155 buf.push(s.dc_id);
156 buf.push(ip_type);
157 buf.extend_from_slice(&ip_bytes);
158 buf.extend_from_slice(&s.port.to_be_bytes());
159 buf.extend_from_slice(&s.user_id.to_be_bytes());
160 buf.extend_from_slice(&s.auth_key);
161
162 URL_SAFE_NO_PAD.encode(&buf)
163}
164
165fn encode_v1(s: &FullSession) -> String {
166 let ip_bytes = ip_to_bytes(s.ip);
167 let ip_type = ip_type_byte(s.ip);
168
169 let mut buf = Vec::with_capacity(1 + 1 + 1 + ip_bytes.len() + 2 + 8 + 8 + 4 + 4 + AUTH_KEY_LEN);
170 buf.push(VERSION_V1);
171 buf.push(s.dc_id);
172 buf.push(ip_type);
173 buf.extend_from_slice(&ip_bytes);
174 buf.extend_from_slice(&s.port.to_be_bytes());
175 buf.extend_from_slice(&s.user_id.to_be_bytes());
176 buf.extend_from_slice(&s.server_salt.to_be_bytes());
177 buf.extend_from_slice(&s.seq_no.to_be_bytes());
178 buf.extend_from_slice(&s.layer.to_be_bytes());
179 buf.extend_from_slice(&s.auth_key);
180
181 URL_SAFE_NO_PAD.encode(&buf)
182}
183
184fn decode_v2(bytes: &[u8]) -> Result<Session, StringSessionError> {
185 let mut c = 1usize;
186
187 let dc_id = read_u8(bytes, &mut c)?;
188 let ip = read_ip(bytes, &mut c)?;
189
190 if bytes.len() < c + 2 + 8 + AUTH_KEY_LEN {
191 return Err(StringSessionError::InvalidData);
192 }
193
194 let port = read_u16_be(bytes, &mut c)?;
195 let user_id = read_i64_be(bytes, &mut c)?;
196 let auth_key = read_auth_key(bytes, &mut c)?;
197
198 Ok(Session {
199 dc_id,
200 ip,
201 port,
202 auth_key,
203 user_id,
204 })
205}
206
207fn decode_v1(bytes: &[u8]) -> Result<FullSession, StringSessionError> {
208 let mut c = 1usize;
209
210 let dc_id = read_u8(bytes, &mut c)?;
211 let ip = read_ip(bytes, &mut c)?;
212
213 if bytes.len() < c + 2 + 8 + 8 + 4 + 4 + AUTH_KEY_LEN {
214 return Err(StringSessionError::InvalidData);
215 }
216
217 let port = read_u16_be(bytes, &mut c)?;
218 let user_id = read_i64_be(bytes, &mut c)?;
219 let server_salt = read_i64_be(bytes, &mut c)?;
220 let seq_no = read_u32_be(bytes, &mut c)?;
221 let layer = read_u32_be(bytes, &mut c)?;
222 let auth_key = read_auth_key(bytes, &mut c)?;
223
224 Ok(FullSession {
225 dc_id,
226 ip,
227 port,
228 auth_key,
229 user_id,
230 server_salt,
231 seq_no,
232 layer,
233 })
234}
235
236fn read_u8(bytes: &[u8], c: &mut usize) -> Result<u8, StringSessionError> {
237 if bytes.len() < *c + 1 {
238 return Err(StringSessionError::InvalidData);
239 }
240 let v = bytes[*c];
241 *c += 1;
242 Ok(v)
243}
244
245fn read_u16_be(bytes: &[u8], c: &mut usize) -> Result<u16, StringSessionError> {
246 let v = u16::from_be_bytes(
247 bytes[*c..*c + 2]
248 .try_into()
249 .map_err(|_| StringSessionError::InvalidData)?,
250 );
251 *c += 2;
252 Ok(v)
253}
254
255fn read_u32_be(bytes: &[u8], c: &mut usize) -> Result<u32, StringSessionError> {
256 let v = u32::from_be_bytes(
257 bytes[*c..*c + 4]
258 .try_into()
259 .map_err(|_| StringSessionError::InvalidData)?,
260 );
261 *c += 4;
262 Ok(v)
263}
264
265fn read_i64_be(bytes: &[u8], c: &mut usize) -> Result<i64, StringSessionError> {
266 let v = i64::from_be_bytes(
267 bytes[*c..*c + 8]
268 .try_into()
269 .map_err(|_| StringSessionError::InvalidData)?,
270 );
271 *c += 8;
272 Ok(v)
273}
274
275fn read_auth_key(bytes: &[u8], c: &mut usize) -> Result<[u8; AUTH_KEY_LEN], StringSessionError> {
276 let key: [u8; AUTH_KEY_LEN] = bytes[*c..*c + AUTH_KEY_LEN]
277 .try_into()
278 .map_err(|_| StringSessionError::InvalidData)?;
279 *c += AUTH_KEY_LEN;
280 Ok(key)
281}
282
283fn read_ip(bytes: &[u8], c: &mut usize) -> Result<IpAddr, StringSessionError> {
284 let ip_type = read_u8(bytes, c)?;
285 match ip_type {
286 4 => {
287 if bytes.len() < *c + 4 {
288 return Err(StringSessionError::InvalidData);
289 }
290 let octets: [u8; 4] = bytes[*c..*c + 4]
291 .try_into()
292 .map_err(|_| StringSessionError::InvalidData)?;
293 *c += 4;
294 Ok(IpAddr::V4(Ipv4Addr::from(octets)))
295 }
296 6 => {
297 if bytes.len() < *c + 16 {
298 return Err(StringSessionError::InvalidData);
299 }
300 let octets: [u8; 16] = bytes[*c..*c + 16]
301 .try_into()
302 .map_err(|_| StringSessionError::InvalidData)?;
303 *c += 16;
304 Ok(IpAddr::V6(Ipv6Addr::from(octets)))
305 }
306 other => Err(StringSessionError::UnknownIpType(other)),
307 }
308}
309
/// Returns the address octets in network order: 4 bytes for IPv4, 16 for IPv6.
fn ip_to_bytes(ip: IpAddr) -> Vec<u8> {
    match ip {
        IpAddr::V4(addr) => Vec::from(addr.octets()),
        IpAddr::V6(addr) => Vec::from(addr.octets()),
    }
}
316
/// Returns the type tag written before the address octets: 4 for IPv4, 6 for IPv6.
fn ip_type_byte(ip: IpAddr) -> u8 {
    if ip.is_ipv4() {
        4
    } else {
        6
    }
}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Auth key whose byte at index `i` is `i as u8`, so truncation or shifting
    /// is detectable.
    fn dummy_key() -> [u8; AUTH_KEY_LEN] {
        let mut k = [0u8; AUTH_KEY_LEN];
        for (i, b) in k.iter_mut().enumerate() {
            *b = i as u8;
        }
        k
    }

    fn ipv4() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51))
    }

    fn ipv6() -> IpAddr {
        IpAddr::V6(Ipv6Addr::new(0x2001, 0xb28, 0xf23d, 0, 0, 0, 0, 0xa))
    }

    /// A V2 session over IPv4 shared by the error-path tests.
    fn sample_v2() -> StringSession {
        StringSession::V2(Session {
            dc_id: 2,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 42,
        })
    }

    /// A V1 session over IPv4 shared by the error-path tests.
    fn sample_v1() -> StringSession {
        StringSession::V1(FullSession {
            dc_id: 1,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 111,
            server_salt: -999,
            seq_no: 42,
            layer: 166,
        })
    }

    /// Re-encodes `encoded` with its last decoded byte removed.
    fn drop_last_byte(encoded: &str) -> String {
        let mut raw = URL_SAFE_NO_PAD.decode(encoded).unwrap();
        raw.pop();
        URL_SAFE_NO_PAD.encode(&raw)
    }

    #[test]
    fn v2_roundtrip_ipv4() {
        let s = StringSession::V2(Session {
            dc_id: 2,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 123456789,
        });

        let encoded = s.encode();
        let decoded = StringSession::decode(&encoded).unwrap();

        assert_eq!(decoded.version(), 2);
        let d = decoded.session();
        assert_eq!(d.dc_id, 2);
        assert_eq!(d.ip, ipv4());
        assert_eq!(d.port, 443);
        assert_eq!(d.user_id, 123456789);
        assert_eq!(d.auth_key, dummy_key());
    }

    #[test]
    fn v2_roundtrip_ipv6() {
        let s = StringSession::V2(Session {
            dc_id: 4,
            ip: ipv6(),
            port: 443,
            auth_key: dummy_key(),
            user_id: -987654321,
        });

        let encoded = s.encode();
        let decoded = StringSession::decode(&encoded).unwrap();

        assert_eq!(decoded.version(), 2);
        let d = decoded.session();
        assert_eq!(d.ip, ipv6());
        assert_eq!(d.user_id, -987654321);
    }

    #[test]
    fn v1_roundtrip_ipv4() {
        let s = StringSession::V1(FullSession {
            dc_id: 1,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 111,
            server_salt: -999,
            seq_no: 42,
            layer: 166,
        });

        let encoded = s.encode_v1();
        let decoded = StringSession::decode(&encoded).unwrap();

        assert_eq!(decoded.version(), 1);
        let f = decoded.full_session().unwrap();
        assert_eq!(f.dc_id, 1);
        assert_eq!(f.ip, ipv4());
        assert_eq!(f.port, 443);
        assert_eq!(f.user_id, 111);
        assert_eq!(f.server_salt, -999);
        assert_eq!(f.seq_no, 42);
        assert_eq!(f.layer, 166);
        assert_eq!(f.auth_key, dummy_key());
    }

    #[test]
    fn v1_roundtrip_ipv6() {
        let s = StringSession::V1(FullSession {
            dc_id: 5,
            ip: ipv6(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 777,
            server_salt: 12345,
            seq_no: 10,
            layer: 166,
        });

        let encoded = s.encode_v1();
        let decoded = StringSession::decode(&encoded).unwrap();

        assert_eq!(decoded.version(), 1);
        let f = decoded.full_session().unwrap();
        assert_eq!(f.ip, ipv6());
        assert_eq!(f.layer, 166);
    }

    #[test]
    fn v1_encode_produces_v2_when_called_via_encode() {
        let s = StringSession::V1(FullSession {
            dc_id: 2,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 555,
            server_salt: 0,
            seq_no: 0,
            layer: 166,
        });

        let encoded = s.encode();
        let decoded = StringSession::decode(&encoded).unwrap();
        assert_eq!(decoded.version(), 2);
    }

    #[test]
    fn v2_encoded_length_ipv4() {
        let s = StringSession::V2(Session {
            dc_id: 1,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 1,
        });
        // 1 + 1 + 1 + 4 + 2 + 8 + 256 = 273 bytes -> 364 base64 chars.
        assert_eq!(s.encode().len(), 364);
    }

    #[test]
    fn v2_encoded_length_ipv6() {
        let s = StringSession::V2(Session {
            dc_id: 1,
            ip: ipv6(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 1,
        });
        // 1 + 1 + 1 + 16 + 2 + 8 + 256 = 285 bytes -> 380 base64 chars.
        assert_eq!(s.encode().len(), 380);
    }

    #[test]
    fn v1_encoded_length_ipv4() {
        // 1 + 1 + 1 + 4 + 2 + 8 + 8 + 4 + 4 + 256 = 289 bytes -> 386 base64 chars.
        assert_eq!(sample_v1().encode_v1().len(), 386);
    }

    #[test]
    fn rejects_truncated() {
        assert!(StringSession::decode("Ag").is_err());
    }

    #[test]
    fn rejects_empty_input() {
        assert!(matches!(
            StringSession::decode(""),
            Err(StringSessionError::InvalidData)
        ));
    }

    #[test]
    fn rejects_invalid_base64() {
        assert!(matches!(
            StringSession::decode("@@@@"),
            Err(StringSessionError::Base64(_))
        ));
    }

    #[test]
    fn rejects_unsupported_version() {
        let bad = URL_SAFE_NO_PAD.encode(&[99u8]);
        assert!(matches!(
            StringSession::decode(&bad),
            Err(StringSessionError::UnsupportedVersion(99))
        ));
    }

    #[test]
    fn rejects_unknown_ip_type() {
        let bad = URL_SAFE_NO_PAD.encode(&[VERSION_V2, 1, 9]);
        assert!(matches!(
            StringSession::decode(&bad),
            Err(StringSessionError::UnknownIpType(9))
        ));
    }

    #[test]
    fn rejects_truncated_v2_payload() {
        let truncated = drop_last_byte(&sample_v2().encode());
        assert!(matches!(
            StringSession::decode(&truncated),
            Err(StringSessionError::InvalidData)
        ));
    }

    #[test]
    fn rejects_truncated_v1_payload() {
        let truncated = drop_last_byte(&sample_v1().encode_v1());
        assert!(matches!(
            StringSession::decode(&truncated),
            Err(StringSessionError::InvalidData)
        ));
    }

    #[test]
    fn decode_ignores_surrounding_whitespace() {
        let padded = format!("  {}\n", sample_v2().encode());
        let decoded = StringSession::decode(&padded).unwrap();
        assert_eq!(decoded.version(), 2);
        assert_eq!(decoded.session().user_id, 42);
    }

    #[test]
    #[should_panic(expected = "cannot encode V2 session as V1")]
    fn encode_v1_panics_for_v2() {
        let _ = sample_v2().encode_v1();
    }

    #[test]
    fn full_session_returns_none_for_v2() {
        let s = StringSession::V2(Session {
            dc_id: 1,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 1,
        });
        assert!(s.full_session().is_none());
    }
}