Skip to main content

ferogram_session/
string_session.rs

1// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
2//
3// ferogram: async Telegram MTProto client in Rust
4// https://github.com/ankit-chaubey/ferogram
5//
6// Licensed under either the MIT License or the Apache License 2.0.
7// See the LICENSE-MIT or LICENSE-APACHE file in this repository:
8// https://github.com/ankit-chaubey/ferogram
9//
10// Feel free to use, modify, and share this code.
11// Please keep this notice when redistributing.
12
13use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
14
15use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
16
/// Leading version byte of the full format, which also carries
/// `server_salt`, `seq_no` and `layer`.
const VERSION_V1: u8 = 0x01;
/// Leading version byte of the minimal format (connection info, user id, auth key).
const VERSION_V2: u8 = 0x02;

/// Length in bytes of the serialized `auth_key` field.
const AUTH_KEY_LEN: usize = 0x100;
21
22#[derive(Debug, Clone)]
23pub struct FullSession {
24    pub dc_id: u8,
25    pub ip: IpAddr,
26    pub port: u16,
27    pub auth_key: [u8; AUTH_KEY_LEN],
28    pub user_id: i64,
29    pub server_salt: i64,
30    pub seq_no: u32,
31    pub layer: u32,
32}
33
34#[derive(Debug, Clone)]
35pub struct Session {
36    pub dc_id: u8,
37    pub ip: IpAddr,
38    pub port: u16,
39    pub auth_key: [u8; AUTH_KEY_LEN],
40    pub user_id: i64,
41}
42
43#[derive(Debug, Clone)]
44pub enum StringSession {
45    V1(FullSession),
46    V2(Session),
47}
48
49#[derive(Debug, thiserror::Error)]
50pub enum StringSessionError {
51    #[error("base64 decode error: {0}")]
52    Base64(#[from] base64::DecodeError),
53    #[error("invalid or truncated session data")]
54    InvalidData,
55    #[error("unsupported version: {0}")]
56    UnsupportedVersion(u8),
57    #[error("unknown ip type byte: {0}")]
58    UnknownIpType(u8),
59}
60
61impl StringSession {
62    /// Decode a string session. Auto-detects V1 or V2 from the version byte.
63    pub fn decode(s: &str) -> Result<Self, StringSessionError> {
64        let bytes = URL_SAFE_NO_PAD.decode(s.trim())?;
65
66        if bytes.is_empty() {
67            return Err(StringSessionError::InvalidData);
68        }
69
70        match bytes[0] {
71            VERSION_V1 => decode_v1(&bytes).map(StringSession::V1),
72            VERSION_V2 => decode_v2(&bytes).map(StringSession::V2),
73            v => Err(StringSessionError::UnsupportedVersion(v)),
74        }
75    }
76
77    /// Encode as V2 (minimal). This is the default.
78    pub fn encode(&self) -> String {
79        match self {
80            StringSession::V2(s) => encode_v2(s),
81            StringSession::V1(s) => encode_v2(&Session {
82                dc_id: s.dc_id,
83                ip: s.ip,
84                port: s.port,
85                auth_key: s.auth_key,
86                user_id: s.user_id,
87            }),
88        }
89    }
90
91    /// Encode as V1 (full session with salt, seq_no, layer).
92    /// Use this for manual transfer or when full state is needed.
93    pub fn encode_v1(&self) -> String {
94        match self {
95            StringSession::V1(s) => encode_v1(s),
96            StringSession::V2(_) => {
97                panic!("cannot encode V2 session as V1: missing server_salt, seq_no, layer")
98            }
99        }
100    }
101
102    /// The minimal V2 fields, regardless of which version was decoded. For a
103    /// V1 session this drops `server_salt`/`seq_no`/`layer`; use
104    /// [`Self::full_session`] if you need those.
105    pub fn session(&self) -> Session {
106        match self {
107            StringSession::V2(s) => s.clone(),
108            StringSession::V1(s) => Session {
109                dc_id: s.dc_id,
110                ip: s.ip,
111                port: s.port,
112                auth_key: s.auth_key,
113                user_id: s.user_id,
114            },
115        }
116    }
117
118    /// The full V1 fields, or `None` if this is a V2 session (V2 never
119    /// carried `server_salt`/`seq_no`/`layer` to begin with).
120    pub fn full_session(&self) -> Option<&FullSession> {
121        match self {
122            StringSession::V1(s) => Some(s),
123            StringSession::V2(_) => None,
124        }
125    }
126
127    /// Which version this session was decoded as (1 or 2), independent of
128    /// which version [`Self::encode`] would produce.
129    pub fn version(&self) -> u8 {
130        match self {
131            StringSession::V1(_) => VERSION_V1,
132            StringSession::V2(_) => VERSION_V2,
133        }
134    }
135}
136
137impl From<Session> for StringSession {
138    fn from(s: Session) -> Self {
139        StringSession::V2(s)
140    }
141}
142
143impl From<FullSession> for StringSession {
144    fn from(s: FullSession) -> Self {
145        StringSession::V1(s)
146    }
147}
148
149fn encode_v2(s: &Session) -> String {
150    let ip_bytes = ip_to_bytes(s.ip);
151    let ip_type = ip_type_byte(s.ip);
152
153    let mut buf = Vec::with_capacity(1 + 1 + 1 + ip_bytes.len() + 2 + 8 + AUTH_KEY_LEN);
154    buf.push(VERSION_V2);
155    buf.push(s.dc_id);
156    buf.push(ip_type);
157    buf.extend_from_slice(&ip_bytes);
158    buf.extend_from_slice(&s.port.to_be_bytes());
159    buf.extend_from_slice(&s.user_id.to_be_bytes());
160    buf.extend_from_slice(&s.auth_key);
161
162    URL_SAFE_NO_PAD.encode(&buf)
163}
164
165fn encode_v1(s: &FullSession) -> String {
166    let ip_bytes = ip_to_bytes(s.ip);
167    let ip_type = ip_type_byte(s.ip);
168
169    let mut buf = Vec::with_capacity(1 + 1 + 1 + ip_bytes.len() + 2 + 8 + 8 + 4 + 4 + AUTH_KEY_LEN);
170    buf.push(VERSION_V1);
171    buf.push(s.dc_id);
172    buf.push(ip_type);
173    buf.extend_from_slice(&ip_bytes);
174    buf.extend_from_slice(&s.port.to_be_bytes());
175    buf.extend_from_slice(&s.user_id.to_be_bytes());
176    buf.extend_from_slice(&s.server_salt.to_be_bytes());
177    buf.extend_from_slice(&s.seq_no.to_be_bytes());
178    buf.extend_from_slice(&s.layer.to_be_bytes());
179    buf.extend_from_slice(&s.auth_key);
180
181    URL_SAFE_NO_PAD.encode(&buf)
182}
183
184fn decode_v2(bytes: &[u8]) -> Result<Session, StringSessionError> {
185    let mut c = 1usize;
186
187    let dc_id = read_u8(bytes, &mut c)?;
188    let ip = read_ip(bytes, &mut c)?;
189
190    if bytes.len() < c + 2 + 8 + AUTH_KEY_LEN {
191        return Err(StringSessionError::InvalidData);
192    }
193
194    let port = read_u16_be(bytes, &mut c)?;
195    let user_id = read_i64_be(bytes, &mut c)?;
196    let auth_key = read_auth_key(bytes, &mut c)?;
197
198    Ok(Session {
199        dc_id,
200        ip,
201        port,
202        auth_key,
203        user_id,
204    })
205}
206
207fn decode_v1(bytes: &[u8]) -> Result<FullSession, StringSessionError> {
208    let mut c = 1usize;
209
210    let dc_id = read_u8(bytes, &mut c)?;
211    let ip = read_ip(bytes, &mut c)?;
212
213    if bytes.len() < c + 2 + 8 + 8 + 4 + 4 + AUTH_KEY_LEN {
214        return Err(StringSessionError::InvalidData);
215    }
216
217    let port = read_u16_be(bytes, &mut c)?;
218    let user_id = read_i64_be(bytes, &mut c)?;
219    let server_salt = read_i64_be(bytes, &mut c)?;
220    let seq_no = read_u32_be(bytes, &mut c)?;
221    let layer = read_u32_be(bytes, &mut c)?;
222    let auth_key = read_auth_key(bytes, &mut c)?;
223
224    Ok(FullSession {
225        dc_id,
226        ip,
227        port,
228        auth_key,
229        user_id,
230        server_salt,
231        seq_no,
232        layer,
233    })
234}
235
236fn read_u8(bytes: &[u8], c: &mut usize) -> Result<u8, StringSessionError> {
237    if bytes.len() < *c + 1 {
238        return Err(StringSessionError::InvalidData);
239    }
240    let v = bytes[*c];
241    *c += 1;
242    Ok(v)
243}
244
245fn read_u16_be(bytes: &[u8], c: &mut usize) -> Result<u16, StringSessionError> {
246    let v = u16::from_be_bytes(
247        bytes[*c..*c + 2]
248            .try_into()
249            .map_err(|_| StringSessionError::InvalidData)?,
250    );
251    *c += 2;
252    Ok(v)
253}
254
255fn read_u32_be(bytes: &[u8], c: &mut usize) -> Result<u32, StringSessionError> {
256    let v = u32::from_be_bytes(
257        bytes[*c..*c + 4]
258            .try_into()
259            .map_err(|_| StringSessionError::InvalidData)?,
260    );
261    *c += 4;
262    Ok(v)
263}
264
265fn read_i64_be(bytes: &[u8], c: &mut usize) -> Result<i64, StringSessionError> {
266    let v = i64::from_be_bytes(
267        bytes[*c..*c + 8]
268            .try_into()
269            .map_err(|_| StringSessionError::InvalidData)?,
270    );
271    *c += 8;
272    Ok(v)
273}
274
275fn read_auth_key(bytes: &[u8], c: &mut usize) -> Result<[u8; AUTH_KEY_LEN], StringSessionError> {
276    let key: [u8; AUTH_KEY_LEN] = bytes[*c..*c + AUTH_KEY_LEN]
277        .try_into()
278        .map_err(|_| StringSessionError::InvalidData)?;
279    *c += AUTH_KEY_LEN;
280    Ok(key)
281}
282
283fn read_ip(bytes: &[u8], c: &mut usize) -> Result<IpAddr, StringSessionError> {
284    let ip_type = read_u8(bytes, c)?;
285    match ip_type {
286        4 => {
287            if bytes.len() < *c + 4 {
288                return Err(StringSessionError::InvalidData);
289            }
290            let octets: [u8; 4] = bytes[*c..*c + 4]
291                .try_into()
292                .map_err(|_| StringSessionError::InvalidData)?;
293            *c += 4;
294            Ok(IpAddr::V4(Ipv4Addr::from(octets)))
295        }
296        6 => {
297            if bytes.len() < *c + 16 {
298                return Err(StringSessionError::InvalidData);
299            }
300            let octets: [u8; 16] = bytes[*c..*c + 16]
301                .try_into()
302                .map_err(|_| StringSessionError::InvalidData)?;
303            *c += 16;
304            Ok(IpAddr::V6(Ipv6Addr::from(octets)))
305        }
306        other => Err(StringSessionError::UnknownIpType(other)),
307    }
308}
309
/// Raw address octets of `ip`: 4 bytes for IPv4, 16 bytes for IPv6.
fn ip_to_bytes(ip: IpAddr) -> Vec<u8> {
    match ip {
        IpAddr::V4(addr) => Vec::from(addr.octets()),
        IpAddr::V6(addr) => Vec::from(addr.octets()),
    }
}
316
/// Tag byte written before the address octets: `4` for IPv4, `6` for IPv6.
fn ip_type_byte(ip: IpAddr) -> u8 {
    if ip.is_ipv4() { 4 } else { 6 }
}
323
#[cfg(test)]
mod tests {
    use super::*;

    fn dummy_key() -> [u8; AUTH_KEY_LEN] {
        let mut k = [0u8; AUTH_KEY_LEN];
        for (i, b) in k.iter_mut().enumerate() {
            *b = i as u8;
        }
        k
    }

    fn ipv4() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51))
    }

    fn ipv6() -> IpAddr {
        IpAddr::V6(Ipv6Addr::new(0x2001, 0xb28, 0xf23d, 0, 0, 0, 0, 0xa))
    }

    /// Base64-decode `encoded`, drop its last `n` bytes, and re-encode it.
    fn drop_tail_bytes(encoded: &str, n: usize) -> String {
        let mut raw = URL_SAFE_NO_PAD.decode(encoded).unwrap();
        raw.truncate(raw.len() - n);
        URL_SAFE_NO_PAD.encode(&raw)
    }

    #[test]
    fn v2_roundtrip_ipv4() {
        let s = StringSession::V2(Session {
            dc_id: 2,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 123456789,
        });

        let encoded = s.encode();
        let decoded = StringSession::decode(&encoded).unwrap();

        assert_eq!(decoded.version(), 2);
        let d = decoded.session();
        assert_eq!(d.dc_id, 2);
        assert_eq!(d.ip, ipv4());
        assert_eq!(d.port, 443);
        assert_eq!(d.user_id, 123456789);
        assert_eq!(d.auth_key, dummy_key());
    }

    #[test]
    fn v2_roundtrip_ipv6() {
        let s = StringSession::V2(Session {
            dc_id: 4,
            ip: ipv6(),
            port: 443,
            auth_key: dummy_key(),
            user_id: -987654321,
        });

        let encoded = s.encode();
        let decoded = StringSession::decode(&encoded).unwrap();

        assert_eq!(decoded.version(), 2);
        let d = decoded.session();
        assert_eq!(d.ip, ipv6());
        assert_eq!(d.user_id, -987654321);
    }

    #[test]
    fn v1_roundtrip_ipv4() {
        let s = StringSession::V1(FullSession {
            dc_id: 1,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 111,
            server_salt: -999,
            seq_no: 42,
            layer: 166,
        });

        let encoded = s.encode_v1();
        let decoded = StringSession::decode(&encoded).unwrap();

        assert_eq!(decoded.version(), 1);
        let f = decoded.full_session().unwrap();
        assert_eq!(f.dc_id, 1);
        assert_eq!(f.ip, ipv4());
        assert_eq!(f.port, 443);
        assert_eq!(f.user_id, 111);
        assert_eq!(f.server_salt, -999);
        assert_eq!(f.seq_no, 42);
        assert_eq!(f.layer, 166);
        assert_eq!(f.auth_key, dummy_key());
    }

    #[test]
    fn v1_roundtrip_ipv6() {
        let s = StringSession::V1(FullSession {
            dc_id: 5,
            ip: ipv6(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 777,
            server_salt: 12345,
            seq_no: 10,
            layer: 166,
        });

        let encoded = s.encode_v1();
        let decoded = StringSession::decode(&encoded).unwrap();

        assert_eq!(decoded.version(), 1);
        let f = decoded.full_session().unwrap();
        assert_eq!(f.ip, ipv6());
        assert_eq!(f.layer, 166);
    }

    #[test]
    fn v1_encode_produces_v2_when_called_via_encode() {
        let s = StringSession::V1(FullSession {
            dc_id: 2,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 555,
            server_salt: 0,
            seq_no: 0,
            layer: 166,
        });

        let encoded = s.encode();
        let decoded = StringSession::decode(&encoded).unwrap();
        assert_eq!(decoded.version(), 2);
    }

    #[test]
    fn v2_encoded_length_ipv4() {
        let s = StringSession::V2(Session {
            dc_id: 1,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 1,
        });
        assert_eq!(s.encode().len(), 364);
    }

    #[test]
    fn rejects_truncated() {
        assert!(StringSession::decode("Ag").is_err());
    }

    #[test]
    fn rejects_unsupported_version() {
        let bad = URL_SAFE_NO_PAD.encode(&[99u8]);
        assert!(matches!(
            StringSession::decode(&bad),
            Err(StringSessionError::UnsupportedVersion(99))
        ));
    }

    #[test]
    fn rejects_empty_input() {
        assert!(matches!(
            StringSession::decode(""),
            Err(StringSessionError::InvalidData)
        ));
        // Whitespace-only input trims down to an empty payload.
        assert!(matches!(
            StringSession::decode("   "),
            Err(StringSessionError::InvalidData)
        ));
    }

    #[test]
    fn rejects_invalid_base64() {
        assert!(matches!(
            StringSession::decode("!!!!"),
            Err(StringSessionError::Base64(_))
        ));
    }

    #[test]
    fn rejects_unknown_ip_type() {
        let bad = URL_SAFE_NO_PAD.encode(&[VERSION_V2, 1u8, 7]);
        assert!(matches!(
            StringSession::decode(&bad),
            Err(StringSessionError::UnknownIpType(7))
        ));
    }

    #[test]
    fn rejects_truncated_auth_key_v2() {
        let encoded = StringSession::V2(Session {
            dc_id: 2,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 1,
        })
        .encode();
        let short = drop_tail_bytes(&encoded, 1);
        assert!(matches!(
            StringSession::decode(&short),
            Err(StringSessionError::InvalidData)
        ));
    }

    #[test]
    fn rejects_truncated_auth_key_v1() {
        let encoded = StringSession::V1(FullSession {
            dc_id: 2,
            ip: ipv6(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 1,
            server_salt: 7,
            seq_no: 3,
            layer: 166,
        })
        .encode_v1();
        let short = drop_tail_bytes(&encoded, 1);
        assert!(matches!(
            StringSession::decode(&short),
            Err(StringSessionError::InvalidData)
        ));
    }

    #[test]
    fn decode_trims_surrounding_whitespace() {
        let encoded = StringSession::V2(Session {
            dc_id: 3,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 42,
        })
        .encode();
        let padded = format!("  {}\n", encoded);
        let d = StringSession::decode(&padded).unwrap().session();
        assert_eq!(d.dc_id, 3);
        assert_eq!(d.user_id, 42);
    }

    #[test]
    #[should_panic(expected = "cannot encode V2 session as V1")]
    fn encode_v1_panics_for_v2_session() {
        let s = StringSession::V2(Session {
            dc_id: 1,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 1,
        });
        let _ = s.encode_v1();
    }

    #[test]
    fn full_session_returns_none_for_v2() {
        let s = StringSession::V2(Session {
            dc_id: 1,
            ip: ipv4(),
            port: 443,
            auth_key: dummy_key(),
            user_id: 1,
        });
        assert!(s.full_session().is_none());
    }
}