1use base64::Engine;
7use rand::prelude::*;
8use serde::{Deserialize, Serialize};
9use sha2::Digest;
10
11#[derive(Serialize, Deserialize, Debug, Clone)]
12#[serde(transparent)]
13pub struct SecretKey(pub String);
14impl SecretKey {
15 pub fn generate() -> Self {
16 let mut rng = rand::thread_rng();
17 let binstr = std::iter::repeat(())
18 .map(|_| rng.sample(rand::distributions::Alphanumeric))
19 .take(22)
20 .collect::<Vec<u8>>();
21 Self(String::from_utf8(binstr).unwrap())
22 }
23
24 pub fn client_id(&self) -> ClientId {
25 ClientId(
26 base64::engine::general_purpose::STANDARD
27 .encode(sha2::Sha256::digest(self.0.as_bytes())),
28 )
29 }
30}
31
32#[derive(Serialize, Deserialize, Debug, Clone)]
33#[serde(transparent)]
34pub struct ReconnectToken(pub String);
35
36#[derive(Serialize, Deserialize, Debug, Clone)]
37#[serde(rename_all = "snake_case")]
38pub enum ServerHello {
39 Success {
40 sub_domain: String,
41 hostname: String,
42 client_id: ClientId,
43 },
44 SubDomainInUse,
45 InvalidSubDomain,
46 AuthFailed,
47 AuthInfo {
48 oidc_client_id: String,
49 oidc_discovery: String,
50 oidc_scopes: Vec<String>,
51 },
52 Error(String),
53}
54
55impl ServerHello {
56 #[allow(unused)]
57 pub fn random_domain() -> String {
58 let mut rng = rand::thread_rng();
59 let binstr = std::iter::repeat(())
60 .map(|_| rng.sample(rand::distributions::Alphanumeric))
61 .take(8)
62 .collect::<Vec<u8>>();
63
64 String::from_utf8(binstr).unwrap().to_lowercase()
65 }
66
67 #[allow(unused)]
68 pub fn prefixed_random_domain(prefix: &str) -> String {
69 format!("{}-{}", prefix, Self::random_domain())
70 }
71}
72
73#[derive(Serialize, Deserialize, Debug, Clone)]
74pub struct ClientHello {
75 id: ClientId,
77 pub sub_domain: Option<String>,
78 pub client_type: ClientType,
79 pub reconnect_token: Option<ReconnectToken>,
80}
81
82impl ClientHello {
83 pub fn generate(sub_domain: Option<String>, typ: ClientType) -> Self {
84 ClientHello {
85 id: ClientId::generate(),
86 client_type: typ,
87 sub_domain,
88 reconnect_token: None,
89 }
90 }
91
92 pub fn reconnect(reconnect_token: ReconnectToken) -> Self {
93 ClientHello {
94 id: ClientId::generate(),
95 sub_domain: None,
96 client_type: ClientType::Anonymous,
97 reconnect_token: Some(reconnect_token),
98 }
99 }
100}
101
102#[derive(Serialize, Deserialize, Debug, Clone)]
103pub enum ClientType {
104 Auth { key: SecretKey },
105 Anonymous,
106 AuthInfo,
107}
108
109#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
110#[serde(transparent)]
111pub struct ClientId(String);
112
113impl std::fmt::Display for ClientId {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 self.0.fmt(f)
116 }
117}
118impl ClientId {
119 pub fn generate() -> Self {
120 let mut id = [0u8; 32];
121 rand::thread_rng().fill_bytes(&mut id);
122 ClientId(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(id))
123 }
124
125 pub fn safe_id(self) -> ClientId {
126 ClientId(
127 base64::engine::general_purpose::STANDARD
128 .encode(sha2::Sha256::digest(self.0.as_bytes())),
129 )
130 }
131}
132
133#[derive(Debug, Clone, PartialEq, Eq, Hash)]
134pub struct StreamId([u8; 8]);
135
136impl StreamId {
137 pub fn generate() -> StreamId {
138 let mut id = [0u8; 8];
139 rand::thread_rng().fill_bytes(&mut id);
140 StreamId(id)
141 }
142}
143
144impl ToString for StreamId {
145 fn to_string(&self) -> String {
146 format!(
147 "stream_{}",
148 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(self.0)
149 )
150 }
151}
152
153#[derive(Debug, Clone)]
154pub enum ControlPacket {
155 Init(StreamId),
156 Data(StreamId, Vec<u8>),
157 Refused(StreamId),
158 End(StreamId),
159 Ping(Option<ReconnectToken>),
160}
161
162pub const PING_INTERVAL: u64 = 30;
163
164const EMPTY_STREAM: StreamId = StreamId([0xF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
165const TOKEN_STREAM: StreamId = StreamId([0xF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]);
166
167impl ControlPacket {
168 pub fn serialize(self) -> Vec<u8> {
169 match self {
170 ControlPacket::Init(sid) => [vec![0x01], sid.0.to_vec()].concat(),
171 ControlPacket::Data(sid, data) => [vec![0x02], sid.0.to_vec(), data].concat(),
172 ControlPacket::Refused(sid) => [vec![0x03], sid.0.to_vec()].concat(),
173 ControlPacket::End(sid) => [vec![0x04], sid.0.to_vec()].concat(),
174 ControlPacket::Ping(tok) => {
175 let data = tok.map_or(EMPTY_STREAM.0.to_vec(), |t| {
176 [TOKEN_STREAM.0.to_vec(), t.0.into_bytes()].concat()
177 });
178 [vec![0x05], data].concat()
179 }
180 }
181 }
182
183 pub fn packet_type(&self) -> &str {
184 match &self {
185 ControlPacket::Ping(_) => "PING",
186 ControlPacket::Init(_) => "INIT STREAM",
187 ControlPacket::Data(_, _) => "STREAM DATA",
188 ControlPacket::Refused(_) => "REFUSED",
189 ControlPacket::End(_) => "END STREAM",
190 }
191 }
192
193 pub fn deserialize(data: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
194 if data.len() < 9 {
195 return Err("invalid DataPacket, missing stream id".into());
196 }
197
198 let mut stream_id = [0u8; 8];
199 stream_id.clone_from_slice(&data[1..9]);
200 let stream_id = StreamId(stream_id);
201
202 let packet = match data[0] {
203 0x01 => ControlPacket::Init(stream_id),
204 0x02 => ControlPacket::Data(stream_id, data[9..].to_vec()),
205 0x03 => ControlPacket::Refused(stream_id),
206 0x04 => ControlPacket::End(stream_id),
207 0x05 => {
208 if stream_id == EMPTY_STREAM {
209 ControlPacket::Ping(None)
210 } else {
211 ControlPacket::Ping(Some(ReconnectToken(
212 String::from_utf8_lossy(&data[9..]).to_string(),
213 )))
214 }
215 }
216 _ => return Err("invalid control byte in DataPacket".into()),
217 };
218
219 Ok(packet)
220 }
221}