portalgun_lib/
lib.rs

1// SPDX-FileCopyrightText: 2023 perillamint <perillamint@silicon.moe>
2// SPDX-FileCopyrightText: 2020-2022 Alex Grinman <me@alexgr.in>
3//
4// SPDX-License-Identifier: MIT
5
6use base64::Engine;
7use rand::prelude::*;
8use serde::{Deserialize, Serialize};
9use sha2::Digest;
10
11#[derive(Serialize, Deserialize, Debug, Clone)]
12#[serde(transparent)]
13pub struct SecretKey(pub String);
14impl SecretKey {
15    pub fn generate() -> Self {
16        let mut rng = rand::thread_rng();
17        let binstr = std::iter::repeat(())
18            .map(|_| rng.sample(rand::distributions::Alphanumeric))
19            .take(22)
20            .collect::<Vec<u8>>();
21        Self(String::from_utf8(binstr).unwrap())
22    }
23
24    pub fn client_id(&self) -> ClientId {
25        ClientId(
26            base64::engine::general_purpose::STANDARD
27                .encode(sha2::Sha256::digest(self.0.as_bytes())),
28        )
29    }
30}
31
32#[derive(Serialize, Deserialize, Debug, Clone)]
33#[serde(transparent)]
34pub struct ReconnectToken(pub String);
35
36#[derive(Serialize, Deserialize, Debug, Clone)]
37#[serde(rename_all = "snake_case")]
38pub enum ServerHello {
39    Success {
40        sub_domain: String,
41        hostname: String,
42        client_id: ClientId,
43    },
44    SubDomainInUse,
45    InvalidSubDomain,
46    AuthFailed,
47    AuthInfo {
48        oidc_client_id: String,
49        oidc_discovery: String,
50        oidc_scopes: Vec<String>,
51    },
52    Error(String),
53}
54
55impl ServerHello {
56    #[allow(unused)]
57    pub fn random_domain() -> String {
58        let mut rng = rand::thread_rng();
59        let binstr = std::iter::repeat(())
60            .map(|_| rng.sample(rand::distributions::Alphanumeric))
61            .take(8)
62            .collect::<Vec<u8>>();
63
64        String::from_utf8(binstr).unwrap().to_lowercase()
65    }
66
67    #[allow(unused)]
68    pub fn prefixed_random_domain(prefix: &str) -> String {
69        format!("{}-{}", prefix, Self::random_domain())
70    }
71}
72
73#[derive(Serialize, Deserialize, Debug, Clone)]
74pub struct ClientHello {
75    /// deprecated: just send some garbage
76    id: ClientId,
77    pub sub_domain: Option<String>,
78    pub client_type: ClientType,
79    pub reconnect_token: Option<ReconnectToken>,
80}
81
82impl ClientHello {
83    pub fn generate(sub_domain: Option<String>, typ: ClientType) -> Self {
84        ClientHello {
85            id: ClientId::generate(),
86            client_type: typ,
87            sub_domain,
88            reconnect_token: None,
89        }
90    }
91
92    pub fn reconnect(reconnect_token: ReconnectToken) -> Self {
93        ClientHello {
94            id: ClientId::generate(),
95            sub_domain: None,
96            client_type: ClientType::Anonymous,
97            reconnect_token: Some(reconnect_token),
98        }
99    }
100}
101
102#[derive(Serialize, Deserialize, Debug, Clone)]
103pub enum ClientType {
104    Auth { key: SecretKey },
105    Anonymous,
106    AuthInfo,
107}
108
109#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
110#[serde(transparent)]
111pub struct ClientId(String);
112
113impl std::fmt::Display for ClientId {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        self.0.fmt(f)
116    }
117}
118impl ClientId {
119    pub fn generate() -> Self {
120        let mut id = [0u8; 32];
121        rand::thread_rng().fill_bytes(&mut id);
122        ClientId(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(id))
123    }
124
125    pub fn safe_id(self) -> ClientId {
126        ClientId(
127            base64::engine::general_purpose::STANDARD
128                .encode(sha2::Sha256::digest(self.0.as_bytes())),
129        )
130    }
131}
132
/// Stream identifier: 8 raw bytes, copied verbatim into bytes `1..9` of
/// every serialized [`ControlPacket`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct StreamId([u8; 8]);
135
136impl StreamId {
137    pub fn generate() -> StreamId {
138        let mut id = [0u8; 8];
139        rand::thread_rng().fill_bytes(&mut id);
140        StreamId(id)
141    }
142}
143
144impl ToString for StreamId {
145    fn to_string(&self) -> String {
146        format!(
147            "stream_{}",
148            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.0)
149        )
150    }
151}
152
153#[derive(Debug, Clone)]
154pub enum ControlPacket {
155    Init(StreamId),
156    Data(StreamId, Vec<u8>),
157    Refused(StreamId),
158    End(StreamId),
159    Ping(Option<ReconnectToken>),
160}
161
162pub const PING_INTERVAL: u64 = 30;
163
164const EMPTY_STREAM: StreamId = StreamId([0xF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
165const TOKEN_STREAM: StreamId = StreamId([0xF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]);
166
167impl ControlPacket {
168    pub fn serialize(self) -> Vec<u8> {
169        match self {
170            ControlPacket::Init(sid) => [vec![0x01], sid.0.to_vec()].concat(),
171            ControlPacket::Data(sid, data) => [vec![0x02], sid.0.to_vec(), data].concat(),
172            ControlPacket::Refused(sid) => [vec![0x03], sid.0.to_vec()].concat(),
173            ControlPacket::End(sid) => [vec![0x04], sid.0.to_vec()].concat(),
174            ControlPacket::Ping(tok) => {
175                let data = tok.map_or(EMPTY_STREAM.0.to_vec(), |t| {
176                    [TOKEN_STREAM.0.to_vec(), t.0.into_bytes()].concat()
177                });
178                [vec![0x05], data].concat()
179            }
180        }
181    }
182
183    pub fn packet_type(&self) -> &str {
184        match &self {
185            ControlPacket::Ping(_) => "PING",
186            ControlPacket::Init(_) => "INIT STREAM",
187            ControlPacket::Data(_, _) => "STREAM DATA",
188            ControlPacket::Refused(_) => "REFUSED",
189            ControlPacket::End(_) => "END STREAM",
190        }
191    }
192
193    pub fn deserialize(data: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
194        if data.len() < 9 {
195            return Err("invalid DataPacket, missing stream id".into());
196        }
197
198        let mut stream_id = [0u8; 8];
199        stream_id.clone_from_slice(&data[1..9]);
200        let stream_id = StreamId(stream_id);
201
202        let packet = match data[0] {
203            0x01 => ControlPacket::Init(stream_id),
204            0x02 => ControlPacket::Data(stream_id, data[9..].to_vec()),
205            0x03 => ControlPacket::Refused(stream_id),
206            0x04 => ControlPacket::End(stream_id),
207            0x05 => {
208                if stream_id == EMPTY_STREAM {
209                    ControlPacket::Ping(None)
210                } else {
211                    ControlPacket::Ping(Some(ReconnectToken(
212                        String::from_utf8_lossy(&data[9..]).to_string(),
213                    )))
214                }
215            }
216            _ => return Err("invalid control byte in DataPacket".into()),
217        };
218
219        Ok(packet)
220    }
221}