1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
14
15use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
16
/// Leading payload byte of the full layout, decoded into `FullSession`.
const VERSION_V1: u8 = 0x01;
/// Leading payload byte of the compact layout, decoded into `Session`.
const VERSION_V2: u8 = 0x02;

/// Fixed size, in bytes, of the `auth_key` field in both layouts (2048 bits).
const AUTH_KEY_LEN: usize = 256;
21
22#[derive(Debug, Clone)]
23pub struct FullSession {
24 pub dc_id: u8,
25 pub ip: IpAddr,
26 pub port: u16,
27 pub auth_key: [u8; AUTH_KEY_LEN],
28 pub user_id: i64,
29 pub server_salt: i64,
30 pub seq_no: u32,
31 pub layer: u32,
32}
33
34#[derive(Debug, Clone)]
35pub struct Session {
36 pub dc_id: u8,
37 pub ip: IpAddr,
38 pub port: u16,
39 pub auth_key: [u8; AUTH_KEY_LEN],
40 pub user_id: i64,
41}
42
/// A decoded string session, tagged by the payload's leading version byte.
///
/// Note the numbering: version 1 is the *larger* layout ([`FullSession`],
/// including `server_salt`, `seq_no` and `layer`), while version 2 is the
/// compact layout ([`Session`]).
#[derive(Debug, Clone)]
pub enum StringSession {
    /// Payload whose version byte is 1; carries the full session state.
    V1(FullSession),
    /// Payload whose version byte is 2; carries the compact session.
    V2(Session),
}
48
/// Errors produced while decoding a [`StringSession`].
#[derive(Debug, thiserror::Error)]
pub enum StringSessionError {
    /// The input text was not valid base64 for the engine used by the decoder.
    #[error("base64 decode error: {0}")]
    Base64(#[from] base64::DecodeError),
    /// The decoded payload was empty or shorter than its layout requires.
    #[error("invalid or truncated session data")]
    InvalidData,
    /// The leading version byte was not a known version; the byte is carried.
    #[error("unsupported version: {0}")]
    UnsupportedVersion(u8),
    /// The IP type tag was neither 4 (IPv4) nor 6 (IPv6); the tag is carried.
    #[error("unknown ip type byte: {0}")]
    UnknownIpType(u8),
}
60
61impl StringSession {
62 pub fn decode(s: &str) -> Result<Self, StringSessionError> {
64 let bytes = URL_SAFE_NO_PAD.decode(s.trim())?;
65
66 if bytes.is_empty() {
67 return Err(StringSessionError::InvalidData);
68 }
69
70 match bytes[0] {
71 VERSION_V1 => decode_v1(&bytes).map(StringSession::V1),
72 VERSION_V2 => decode_v2(&bytes).map(StringSession::V2),
73 v => Err(StringSessionError::UnsupportedVersion(v)),
74 }
75 }
76
77 pub fn encode(&self) -> String {
79 match self {
80 StringSession::V2(s) => encode_v2(s),
81 StringSession::V1(s) => encode_v2(&Session {
82 dc_id: s.dc_id,
83 ip: s.ip,
84 port: s.port,
85 auth_key: s.auth_key,
86 user_id: s.user_id,
87 }),
88 }
89 }
90
91 pub fn encode_v1(&self) -> String {
94 match self {
95 StringSession::V1(s) => encode_v1(s),
96 StringSession::V2(_) => {
97 panic!("cannot encode V2 session as V1: missing server_salt, seq_no, layer")
98 }
99 }
100 }
101
102 pub fn session(&self) -> Session {
103 match self {
104 StringSession::V2(s) => s.clone(),
105 StringSession::V1(s) => Session {
106 dc_id: s.dc_id,
107 ip: s.ip,
108 port: s.port,
109 auth_key: s.auth_key,
110 user_id: s.user_id,
111 },
112 }
113 }
114
115 pub fn full_session(&self) -> Option<&FullSession> {
116 match self {
117 StringSession::V1(s) => Some(s),
118 StringSession::V2(_) => None,
119 }
120 }
121
122 pub fn version(&self) -> u8 {
123 match self {
124 StringSession::V1(_) => VERSION_V1,
125 StringSession::V2(_) => VERSION_V2,
126 }
127 }
128}
129
130impl From<Session> for StringSession {
131 fn from(s: Session) -> Self {
132 StringSession::V2(s)
133 }
134}
135
136impl From<FullSession> for StringSession {
137 fn from(s: FullSession) -> Self {
138 StringSession::V1(s)
139 }
140}
141
142fn encode_v2(s: &Session) -> String {
143 let ip_bytes = ip_to_bytes(s.ip);
144 let ip_type = ip_type_byte(s.ip);
145
146 let mut buf = Vec::with_capacity(1 + 1 + 1 + ip_bytes.len() + 2 + 8 + AUTH_KEY_LEN);
147 buf.push(VERSION_V2);
148 buf.push(s.dc_id);
149 buf.push(ip_type);
150 buf.extend_from_slice(&ip_bytes);
151 buf.extend_from_slice(&s.port.to_be_bytes());
152 buf.extend_from_slice(&s.user_id.to_be_bytes());
153 buf.extend_from_slice(&s.auth_key);
154
155 URL_SAFE_NO_PAD.encode(&buf)
156}
157
158fn encode_v1(s: &FullSession) -> String {
159 let ip_bytes = ip_to_bytes(s.ip);
160 let ip_type = ip_type_byte(s.ip);
161
162 let mut buf = Vec::with_capacity(1 + 1 + 1 + ip_bytes.len() + 2 + 8 + 8 + 4 + 4 + AUTH_KEY_LEN);
163 buf.push(VERSION_V1);
164 buf.push(s.dc_id);
165 buf.push(ip_type);
166 buf.extend_from_slice(&ip_bytes);
167 buf.extend_from_slice(&s.port.to_be_bytes());
168 buf.extend_from_slice(&s.user_id.to_be_bytes());
169 buf.extend_from_slice(&s.server_salt.to_be_bytes());
170 buf.extend_from_slice(&s.seq_no.to_be_bytes());
171 buf.extend_from_slice(&s.layer.to_be_bytes());
172 buf.extend_from_slice(&s.auth_key);
173
174 URL_SAFE_NO_PAD.encode(&buf)
175}
176
177fn decode_v2(bytes: &[u8]) -> Result<Session, StringSessionError> {
178 let mut c = 1usize;
179
180 let dc_id = read_u8(bytes, &mut c)?;
181 let ip = read_ip(bytes, &mut c)?;
182
183 if bytes.len() < c + 2 + 8 + AUTH_KEY_LEN {
184 return Err(StringSessionError::InvalidData);
185 }
186
187 let port = read_u16_be(bytes, &mut c)?;
188 let user_id = read_i64_be(bytes, &mut c)?;
189 let auth_key = read_auth_key(bytes, &mut c)?;
190
191 Ok(Session {
192 dc_id,
193 ip,
194 port,
195 auth_key,
196 user_id,
197 })
198}
199
200fn decode_v1(bytes: &[u8]) -> Result<FullSession, StringSessionError> {
201 let mut c = 1usize;
202
203 let dc_id = read_u8(bytes, &mut c)?;
204 let ip = read_ip(bytes, &mut c)?;
205
206 if bytes.len() < c + 2 + 8 + 8 + 4 + 4 + AUTH_KEY_LEN {
207 return Err(StringSessionError::InvalidData);
208 }
209
210 let port = read_u16_be(bytes, &mut c)?;
211 let user_id = read_i64_be(bytes, &mut c)?;
212 let server_salt = read_i64_be(bytes, &mut c)?;
213 let seq_no = read_u32_be(bytes, &mut c)?;
214 let layer = read_u32_be(bytes, &mut c)?;
215 let auth_key = read_auth_key(bytes, &mut c)?;
216
217 Ok(FullSession {
218 dc_id,
219 ip,
220 port,
221 auth_key,
222 user_id,
223 server_salt,
224 seq_no,
225 layer,
226 })
227}
228
229fn read_u8(bytes: &[u8], c: &mut usize) -> Result<u8, StringSessionError> {
230 if bytes.len() < *c + 1 {
231 return Err(StringSessionError::InvalidData);
232 }
233 let v = bytes[*c];
234 *c += 1;
235 Ok(v)
236}
237
238fn read_u16_be(bytes: &[u8], c: &mut usize) -> Result<u16, StringSessionError> {
239 let v = u16::from_be_bytes(
240 bytes[*c..*c + 2]
241 .try_into()
242 .map_err(|_| StringSessionError::InvalidData)?,
243 );
244 *c += 2;
245 Ok(v)
246}
247
248fn read_u32_be(bytes: &[u8], c: &mut usize) -> Result<u32, StringSessionError> {
249 let v = u32::from_be_bytes(
250 bytes[*c..*c + 4]
251 .try_into()
252 .map_err(|_| StringSessionError::InvalidData)?,
253 );
254 *c += 4;
255 Ok(v)
256}
257
258fn read_i64_be(bytes: &[u8], c: &mut usize) -> Result<i64, StringSessionError> {
259 let v = i64::from_be_bytes(
260 bytes[*c..*c + 8]
261 .try_into()
262 .map_err(|_| StringSessionError::InvalidData)?,
263 );
264 *c += 8;
265 Ok(v)
266}
267
268fn read_auth_key(bytes: &[u8], c: &mut usize) -> Result<[u8; AUTH_KEY_LEN], StringSessionError> {
269 let key: [u8; AUTH_KEY_LEN] = bytes[*c..*c + AUTH_KEY_LEN]
270 .try_into()
271 .map_err(|_| StringSessionError::InvalidData)?;
272 *c += AUTH_KEY_LEN;
273 Ok(key)
274}
275
276fn read_ip(bytes: &[u8], c: &mut usize) -> Result<IpAddr, StringSessionError> {
277 let ip_type = read_u8(bytes, c)?;
278 match ip_type {
279 4 => {
280 if bytes.len() < *c + 4 {
281 return Err(StringSessionError::InvalidData);
282 }
283 let octets: [u8; 4] = bytes[*c..*c + 4]
284 .try_into()
285 .map_err(|_| StringSessionError::InvalidData)?;
286 *c += 4;
287 Ok(IpAddr::V4(Ipv4Addr::from(octets)))
288 }
289 6 => {
290 if bytes.len() < *c + 16 {
291 return Err(StringSessionError::InvalidData);
292 }
293 let octets: [u8; 16] = bytes[*c..*c + 16]
294 .try_into()
295 .map_err(|_| StringSessionError::InvalidData)?;
296 *c += 16;
297 Ok(IpAddr::V6(Ipv6Addr::from(octets)))
298 }
299 other => Err(StringSessionError::UnknownIpType(other)),
300 }
301}
302
/// Returns the octets of `ip` as an owned buffer: 4 bytes for IPv4 and
/// 16 bytes for IPv6, in the order produced by `octets()`.
fn ip_to_bytes(ip: IpAddr) -> Vec<u8> {
    let mut out = Vec::with_capacity(16);
    match ip {
        IpAddr::V4(addr) => out.extend_from_slice(&addr.octets()),
        IpAddr::V6(addr) => out.extend_from_slice(&addr.octets()),
    }
    out
}
309
/// Returns the wire tag for the address family of `ip`: `4` for IPv4,
/// `6` for IPv6.
fn ip_type_byte(ip: IpAddr) -> u8 {
    if ip.is_ipv4() {
        4
    } else {
        6
    }
}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, non-trivial auth key: byte `i` holds `i as u8`.
    fn dummy_key() -> [u8; AUTH_KEY_LEN] {
        let mut k = [0u8; AUTH_KEY_LEN];
        for (i, b) in k.iter_mut().enumerate() {
            *b = i as u8;
        }
        k
    }

    fn ipv4() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51))
    }

    fn ipv6() -> IpAddr {
        IpAddr::V6(Ipv6Addr::new(0x2001, 0xb28, 0xf23d, 0, 0, 0, 0, 0xa))
    }

    /// Minimal IPv4 v2 session shared by the error-path tests.
    fn sample_v2() -> StringSession {
        StringSession::V2(Session {
            dc_id: 2,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 1,
        })
    }

    /// Full IPv4 v1 session shared by the error-path tests.
    fn sample_v1() -> StringSession {
        StringSession::V1(FullSession {
            dc_id: 1,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 111,
            server_salt: -999,
            seq_no: 42,
            layer: 166,
        })
    }

    #[test]
    fn v2_roundtrip_ipv4() {
        let s = StringSession::V2(Session {
            dc_id: 2,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 123456789,
        });

        let encoded = s.encode();
        let decoded = StringSession::decode(&encoded).unwrap();

        assert_eq!(decoded.version(), 2);
        let d = decoded.session();
        assert_eq!(d.dc_id, 2);
        assert_eq!(d.ip, ipv4());
        assert_eq!(d.port, 443);
        assert_eq!(d.user_id, 123456789);
        assert_eq!(d.auth_key, dummy_key());
    }

    #[test]
    fn v2_roundtrip_ipv6() {
        let s = StringSession::V2(Session {
            dc_id: 4,
            ip: ipv6(),
            port: 443,
            auth_key: dummy_key(),
            user_id: -987654321,
        });

        let encoded = s.encode();
        let decoded = StringSession::decode(&encoded).unwrap();

        assert_eq!(decoded.version(), 2);
        let d = decoded.session();
        assert_eq!(d.dc_id, 4);
        assert_eq!(d.ip, ipv6());
        assert_eq!(d.port, 443);
        assert_eq!(d.user_id, -987654321);
        assert_eq!(d.auth_key, dummy_key());
    }

    #[test]
    fn v1_roundtrip_ipv4() {
        let s = StringSession::V1(FullSession {
            dc_id: 1,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 111,
            server_salt: -999,
            seq_no: 42,
            layer: 166,
        });

        let encoded = s.encode_v1();
        let decoded = StringSession::decode(&encoded).unwrap();

        assert_eq!(decoded.version(), 1);
        let f = decoded.full_session().unwrap();
        assert_eq!(f.dc_id, 1);
        assert_eq!(f.ip, ipv4());
        assert_eq!(f.port, 443);
        assert_eq!(f.user_id, 111);
        assert_eq!(f.server_salt, -999);
        assert_eq!(f.seq_no, 42);
        assert_eq!(f.layer, 166);
        assert_eq!(f.auth_key, dummy_key());
    }

    #[test]
    fn v1_roundtrip_ipv6() {
        let s = StringSession::V1(FullSession {
            dc_id: 5,
            ip: ipv6(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 777,
            server_salt: 12345,
            seq_no: 10,
            layer: 166,
        });

        let encoded = s.encode_v1();
        let decoded = StringSession::decode(&encoded).unwrap();

        assert_eq!(decoded.version(), 1);
        let f = decoded.full_session().unwrap();
        assert_eq!(f.dc_id, 5);
        assert_eq!(f.ip, ipv6());
        assert_eq!(f.port, 443);
        assert_eq!(f.user_id, 777);
        assert_eq!(f.server_salt, 12345);
        assert_eq!(f.seq_no, 10);
        assert_eq!(f.layer, 166);
        assert_eq!(f.auth_key, dummy_key());
    }

    #[test]
    fn v1_encode_produces_v2_when_called_via_encode() {
        let s = StringSession::V1(FullSession {
            dc_id: 2,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 555,
            server_salt: 0,
            seq_no: 0,
            layer: 166,
        });

        let encoded = s.encode();
        let decoded = StringSession::decode(&encoded).unwrap();
        assert_eq!(decoded.version(), 2);
        // The downgrade keeps the compact fields and drops the v1-only ones.
        assert!(decoded.full_session().is_none());
        let d = decoded.session();
        assert_eq!(d.dc_id, 2);
        assert_eq!(d.user_id, 555);
        assert_eq!(d.auth_key, dummy_key());
    }

    #[test]
    fn v2_encoded_length_ipv4() {
        let s = StringSession::V2(Session {
            dc_id: 1,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 1,
        });
        assert_eq!(s.encode().len(), 364);
    }

    #[test]
    fn v2_encoded_length_ipv6() {
        let s = StringSession::V2(Session {
            dc_id: 1,
            ip: ipv6(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 1,
        });
        // 1 + 1 + 1 + 16 + 2 + 8 + 256 = 285 bytes -> 380 base64 chars.
        assert_eq!(s.encode().len(), 380);
    }

    #[test]
    fn v1_encoded_length_ipv4() {
        // 1 + 1 + 1 + 4 + 2 + 8 + 8 + 4 + 4 + 256 = 289 bytes -> 386 chars.
        assert_eq!(sample_v1().encode_v1().len(), 386);
    }

    #[test]
    fn rejects_truncated() {
        assert!(StringSession::decode("Ag").is_err());
    }

    #[test]
    fn rejects_truncated_payloads_for_both_versions() {
        for encoded in [sample_v2().encode(), sample_v1().encode_v1()] {
            let raw = URL_SAFE_NO_PAD.decode(&encoded).unwrap();
            // Cut inside the IP tag, inside the IP octets, and one byte short.
            for keep in [2, 3, raw.len() - 1] {
                let truncated = URL_SAFE_NO_PAD.encode(&raw[..keep]);
                assert!(
                    matches!(
                        StringSession::decode(&truncated),
                        Err(StringSessionError::InvalidData)
                    ),
                    "keep = {}",
                    keep
                );
            }
        }
    }

    #[test]
    fn rejects_empty_input() {
        assert!(matches!(
            StringSession::decode(""),
            Err(StringSessionError::InvalidData)
        ));
        assert!(matches!(
            StringSession::decode("  \n"),
            Err(StringSessionError::InvalidData)
        ));
    }

    #[test]
    fn rejects_invalid_base64() {
        assert!(matches!(
            StringSession::decode("!!!!"),
            Err(StringSessionError::Base64(_))
        ));
    }

    #[test]
    fn rejects_unsupported_version() {
        let bad = URL_SAFE_NO_PAD.encode(&[99u8]);
        assert!(matches!(
            StringSession::decode(&bad),
            Err(StringSessionError::UnsupportedVersion(99))
        ));
        let zero = URL_SAFE_NO_PAD.encode(&[0u8]);
        assert!(matches!(
            StringSession::decode(&zero),
            Err(StringSessionError::UnsupportedVersion(0))
        ));
    }

    #[test]
    fn rejects_unknown_ip_type() {
        let bad = URL_SAFE_NO_PAD.encode(&[VERSION_V2, 1, 5]);
        assert!(matches!(
            StringSession::decode(&bad),
            Err(StringSessionError::UnknownIpType(5))
        ));
    }

    #[test]
    fn decode_ignores_surrounding_whitespace() {
        let padded = format!("  {}\n", sample_v2().encode());
        let decoded = StringSession::decode(&padded).unwrap();
        assert_eq!(decoded.version(), 2);
        assert_eq!(decoded.session().user_id, 1);
    }

    #[test]
    #[should_panic(expected = "cannot encode V2 session as V1")]
    fn encode_v1_panics_for_v2() {
        let _ = sample_v2().encode_v1();
    }

    #[test]
    fn session_projects_v1_fields() {
        let s = sample_v1();
        assert_eq!(s.version(), 1);
        let d = s.session();
        assert_eq!(d.dc_id, 1);
        assert_eq!(d.ip, ipv4());
        assert_eq!(d.port, 443);
        assert_eq!(d.user_id, 111);
        assert_eq!(d.auth_key, dummy_key());
    }

    #[test]
    fn full_session_returns_none_for_v2() {
        let s = StringSession::V2(Session {
            dc_id: 1,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 1,
        });
        assert!(s.full_session().is_none());
    }
}