1use crate::stream::{
44 InStreamError, InputProtocol, OutStreamError, SilkroadStreamRead, SilkroadStreamWrite,
45};
46use bitflags::bitflags;
47use skrillax_packet::{
48 AsPacket, OutgoingPacket, Packet, PacketError, SecurityBytes, TryFromPacket,
49};
50use skrillax_security::handshake::{CheckBytesInitialization, PassiveEncryptionInitializationData};
51use skrillax_security::{
52 ActiveHandshake, PassiveHandshake, SecurityFeature, SilkroadSecurityError,
53};
54use skrillax_serde::{ByteSize, Deserialize, Serialize};
55use std::sync::Arc;
56use thiserror::Error;
57use tokio::io::{AsyncRead, AsyncWrite};
58
/// Errors that can abort the security handshake performed by
/// `ActiveSecuritySetup` or `PassiveSecuritySetup`.
#[derive(Error, Debug)]
pub enum HandshakeError {
    /// Reading a packet from the stream failed.
    #[error("An error occurred while receiving data")]
    InputError(#[from] InStreamError),
    /// Writing a packet to the stream failed.
    #[error("An error occurred while writing data")]
    OutputError(#[from] OutStreamError),
    /// The underlying security state machine rejected a step of the exchange.
    #[error("A security level error occurred")]
    SecurityError(#[from] SilkroadSecurityError),
    /// A packet could not be handled at the packet layer.
    #[error("An error occurred at the packet level")]
    PacketError(#[from] PacketError),
    /// Returned in two cases: the active side expected a `HandshakeChallenge`
    /// but got a different packet, or the passive side's finishing packet
    /// carried no challenge value.
    #[error("Expected to receive a challenge, but received something else")]
    NoChallengeReceived,
    /// The active side did not receive `HandshakeAccepted` after answering
    /// the challenge.
    #[error("We didn't get an acknowledgment for the challenge response")]
    FinalizationNotAccepted,
    /// The passive side received a finishing packet whose flag was not
    /// exactly `FINISH`.
    #[error("The flag inside the security packet did not match our expectations")]
    InvalidContentFlag,
}
76
77#[derive(Serialize, Deserialize, ByteSize, Copy, Clone, Eq, PartialEq, Debug)]
78struct HandshakeContent(u8);
79
80impl Default for HandshakeContent {
81 fn default() -> Self {
82 Self::empty()
83 }
84}
85
86bitflags! {
87 impl HandshakeContent: u8 {
88 const NONE = 1;
89 const INIT_BLOWFISH = 2;
90 const SETUP_CHECKS = 4;
91 const START_HANDSHAKE = 8;
92 const FINISH = 16;
93 }
94}
95
/// Key-exchange parameters carried by a `SecurityCapabilityCheck` whose flag
/// contains `START_HANDSHAKE`.
///
/// NOTE(review): the derived (de)serializer presumably follows declaration
/// order. Reordering these fields would then change the wire format — confirm
/// against skrillax_serde before touching them.
#[derive(Serialize, ByteSize, Deserialize, Copy, Clone, Debug)]
struct HandshakeInitialization {
    // Becomes `PassiveEncryptionInitializationData::handshake_seed` on the
    // passive side.
    handshake_seed: u64,
    // `a`, `b` and `c` become `additional_seeds[0]`, `[1]` and `[2]`.
    a: u32,
    b: u32,
    c: u32,
}
103
104#[derive(Packet, ByteSize, Serialize, Deserialize, Default, Copy, Clone, Debug)]
105#[packet(opcode = 0x5000)]
106struct SecurityCapabilityCheck {
107 flag: HandshakeContent,
108 #[silkroad(when = "flag.contains(HandshakeContent::INIT_BLOWFISH)")]
109 blowfish_seed: Option<u64>,
110 #[silkroad(when = "flag.contains(HandshakeContent::SETUP_CHECKS)")]
111 seed_count: Option<u32>,
112 #[silkroad(when = "flag.contains(HandshakeContent::SETUP_CHECKS)")]
113 seed_crc: Option<u32>,
114 #[silkroad(when = "flag.contains(HandshakeContent::START_HANDSHAKE)")]
115 handshake_init: Option<HandshakeInitialization>,
116 #[silkroad(when = "flag.contains(HandshakeContent::FINISH)")]
117 challenge: Option<u64>,
118}
119
120impl SecurityCapabilityCheck {
121 fn check_bytes_init(&self) -> Option<CheckBytesInitialization> {
122 match (self.seed_crc, self.seed_count) {
123 (Some(crc_seed), Some(count_seed)) => Some(CheckBytesInitialization {
124 count_seed,
125 crc_seed,
126 }),
127 _ => None,
128 }
129 }
130
131 fn passive_encryption_init(&self) -> Option<PassiveEncryptionInitializationData> {
132 match (self.blowfish_seed, self.handshake_init) {
133 (Some(seed), Some(init)) => Some(PassiveEncryptionInitializationData {
134 seed,
135 handshake_seed: init.handshake_seed,
136 additional_seeds: [init.a, init.b, init.c],
137 }),
138 _ => None,
139 }
140 }
141}
142
143impl From<SecurityCapabilityCheck> for HandshakeActiveProtocol {
144 fn from(value: SecurityCapabilityCheck) -> Self {
145 HandshakeActiveProtocol::SecurityCapabilityCheck(value)
146 }
147}
148
149enum HandshakeActiveProtocol {
150 SecurityCapabilityCheck(SecurityCapabilityCheck),
151}
152
153impl From<&HandshakeActiveProtocol> for OutgoingPacket {
154 fn from(value: &HandshakeActiveProtocol) -> Self {
155 match value {
156 HandshakeActiveProtocol::SecurityCapabilityCheck(check) => check.as_packet(),
157 }
158 }
159}
160
161impl InputProtocol for HandshakeActiveProtocol {
162 type Proto = HandshakeActiveProtocol;
163
164 fn create_from(opcode: u16, data: &[u8]) -> Result<(usize, Self), InStreamError> {
165 match opcode {
166 SecurityCapabilityCheck::ID => {
167 let (consumed, check) = SecurityCapabilityCheck::try_deserialize(data)?;
168 Ok((
169 consumed,
170 HandshakeActiveProtocol::SecurityCapabilityCheck(check),
171 ))
172 },
173 _ => Err(InStreamError::UnmatchedOpcode(opcode)),
174 }
175 }
176}
177
178#[derive(Packet, ByteSize, Serialize, Deserialize, Debug)]
179#[packet(opcode = 0x5000)]
180struct HandshakeChallenge {
181 pub b: u32,
182 pub key: u64,
183}
184
185impl From<HandshakeChallenge> for HandshakePassiveProtocol {
186 fn from(value: HandshakeChallenge) -> Self {
187 HandshakePassiveProtocol::HandshakeChallenge(value)
188 }
189}
190
191#[derive(Packet, ByteSize, Serialize, Deserialize, Debug)]
192#[packet(opcode = 0x9000)]
193struct HandshakeAccepted;
194
195impl From<HandshakeAccepted> for HandshakePassiveProtocol {
196 fn from(value: HandshakeAccepted) -> Self {
197 HandshakePassiveProtocol::HandshakeAccepted(value)
198 }
199}
200
201enum HandshakePassiveProtocol {
202 HandshakeChallenge(HandshakeChallenge),
203 HandshakeAccepted(HandshakeAccepted),
204}
205
206impl InputProtocol for HandshakePassiveProtocol {
207 type Proto = HandshakePassiveProtocol;
208
209 fn create_from(opcode: u16, data: &[u8]) -> Result<(usize, Self), InStreamError> {
210 match opcode {
211 HandshakeAccepted::ID => {
212 let (consumed, accepted) = HandshakeAccepted::try_deserialize(data)?;
213 Ok((
214 consumed,
215 HandshakePassiveProtocol::HandshakeAccepted(accepted),
216 ))
217 },
218 HandshakeChallenge::ID => {
219 let (consumed, challenge) = HandshakeChallenge::try_deserialize(data)?;
220 Ok((
221 consumed,
222 HandshakePassiveProtocol::HandshakeChallenge(challenge),
223 ))
224 },
225 _ => Err(InStreamError::UnmatchedOpcode(opcode)),
226 }
227 }
228}
229
230impl From<&HandshakePassiveProtocol> for OutgoingPacket {
231 fn from(value: &HandshakePassiveProtocol) -> Self {
232 match value {
233 HandshakePassiveProtocol::HandshakeChallenge(challenge) => challenge.as_packet(),
234 HandshakePassiveProtocol::HandshakeAccepted(accept) => accept.as_packet(),
235 }
236 }
237}
238
239pub struct ActiveSecuritySetup<'a, T: AsyncRead + Unpin, S: AsyncWrite + Unpin> {
258 reader: &'a mut SilkroadStreamRead<T>,
259 writer: &'a mut SilkroadStreamWrite<S>,
260 enabled_features: SecurityFeature,
261}
262
263impl<T: AsyncRead + Unpin, S: AsyncWrite + Unpin> ActiveSecuritySetup<'_, T, S> {
264 pub async fn handle(
267 reader: &mut SilkroadStreamRead<T>,
268 writer: &mut SilkroadStreamWrite<S>,
269 ) -> Result<(), HandshakeError> {
270 ActiveSecuritySetup {
271 reader,
272 writer,
273 enabled_features: SecurityFeature::all(),
274 }
275 .initialize()
276 .await
277 }
278
279 pub async fn handle_with_features(
282 reader: &mut SilkroadStreamRead<T>,
283 writer: &mut SilkroadStreamWrite<S>,
284 enabled_features: SecurityFeature,
285 ) -> Result<(), HandshakeError> {
286 ActiveSecuritySetup {
287 reader,
288 writer,
289 enabled_features,
290 }
291 .initialize()
292 .await
293 }
294
295 async fn initialize(self) -> Result<(), HandshakeError> {
296 let (reader, writer) = (self.reader, self.writer);
297 let mut setup = ActiveHandshake::default();
298 let init = setup.initialize(self.enabled_features)?;
299
300 if let Some(checks) = init.checks.as_ref() {
301 let security_bytes = Arc::new(SecurityBytes::from_seeds(
302 checks.crc_seed,
303 checks.count_seed,
304 ));
305 reader.enable_security_checks(security_bytes);
306 }
307
308 let mut flag = HandshakeContent::START_HANDSHAKE;
309 let (blowfish_seed, encryption) = if let Some(encryption) = &init.encryption_seed {
310 flag |= HandshakeContent::INIT_BLOWFISH;
311 (
312 Some(encryption.seed),
313 Some(HandshakeInitialization {
314 handshake_seed: encryption.handshake_seed,
315 a: encryption.additional_seeds[0],
316 b: encryption.additional_seeds[1],
317 c: encryption.additional_seeds[2],
318 }),
319 )
320 } else {
321 (None, None)
322 };
323
324 let (crc, count) = if let Some(checks) = &init.checks {
325 flag |= HandshakeContent::SETUP_CHECKS;
326 (Some(checks.crc_seed), Some(checks.count_seed))
327 } else {
328 (None, None)
329 };
330
331 let init_packet = SecurityCapabilityCheck {
332 flag: HandshakeContent::INIT_BLOWFISH
333 | HandshakeContent::SETUP_CHECKS
334 | HandshakeContent::START_HANDSHAKE,
335 blowfish_seed,
336 seed_count: count,
337 seed_crc: crc,
338 handshake_init: encryption,
339 ..Default::default()
340 };
341 writer.write_packet(init_packet).await?;
342
343 let response = reader.next_packet::<HandshakePassiveProtocol>().await?;
344
345 let HandshakePassiveProtocol::HandshakeChallenge(challenge) = response else {
346 return Err(HandshakeError::NoChallengeReceived);
347 };
348
349 let challenge = setup.start_challenge(challenge.b, challenge.key)?;
350 writer
351 .write_packet(SecurityCapabilityCheck {
352 flag: HandshakeContent::FINISH,
353 challenge: Some(challenge),
354 ..Default::default()
355 })
356 .await?;
357
358 let response = reader.next_packet::<HandshakePassiveProtocol>().await?;
359 if !matches!(response, HandshakePassiveProtocol::HandshakeAccepted(_)) {
360 return Err(HandshakeError::FinalizationNotAccepted);
361 }
362
363 if let Some(encryption) = setup.finish()? {
364 let security = Arc::new(encryption);
365 reader.enable_encryption(Arc::clone(&security));
366 writer.enable_encryption(security);
367 }
368
369 Ok(())
370 }
371}
372
373pub struct PassiveSecuritySetup<'a, T: AsyncRead + Unpin, S: AsyncWrite + Unpin> {
391 reader: &'a mut SilkroadStreamRead<T>,
392 writer: &'a mut SilkroadStreamWrite<S>,
393}
394
395impl<T: AsyncRead + Unpin, S: AsyncWrite + Unpin> PassiveSecuritySetup<'_, T, S> {
396 pub async fn handle(
398 reader: &mut SilkroadStreamRead<T>,
399 writer: &mut SilkroadStreamWrite<S>,
400 ) -> Result<(), HandshakeError> {
401 PassiveSecuritySetup { reader, writer }.initialize().await
402 }
403
404 async fn initialize(self) -> Result<(), HandshakeError> {
405 let (reader, writer) = (self.reader, self.writer);
406 let mut handshake = PassiveHandshake::default();
407
408 let init = reader.next_packet::<HandshakeActiveProtocol>().await?;
409 let HandshakeActiveProtocol::SecurityCapabilityCheck(capability) = init;
410
411 if capability.flag == HandshakeContent::NONE {
412 return Ok(());
413 }
414
415 if let Some(checks) = capability.check_bytes_init() {
416 let security_bytes = Arc::new(SecurityBytes::from_seeds(
417 checks.crc_seed,
418 checks.count_seed,
419 ));
420 writer.enable_security_checks(security_bytes);
421 }
422
423 let encryption_seed = capability.passive_encryption_init();
424 let challenge = handshake.initialize(encryption_seed)?;
425
426 if let Some((key, b)) = challenge {
427 writer.write_packet(HandshakeChallenge { b, key }).await?;
428
429 let finalize = reader.next_packet::<HandshakeActiveProtocol>().await?;
430 let HandshakeActiveProtocol::SecurityCapabilityCheck(capability) = finalize;
431 if !capability.flag == HandshakeContent::FINISH {
432 return Err(HandshakeError::InvalidContentFlag);
433 }
434
435 let Some(challenge) = capability.challenge else {
436 return Err(HandshakeError::NoChallengeReceived);
437 };
438
439 handshake.finish(challenge)?;
440 writer.write_packet(HandshakeAccepted).await?;
441 }
442
443 if let Some(encryption) = handshake.done()? {
444 let encryption = Arc::new(encryption);
445 reader.enable_encryption(Arc::clone(&encryption));
446 writer.enable_encryption(encryption);
447 }
448
449 Ok(())
450 }
451}
452
#[cfg(test)]
mod test {
    use super::*;
    use crate::stream::SilkroadTcpExt;
    use tokio::net::TcpSocket;

    /// Encrypted probe packet, used to confirm that both ends derived the
    /// same keys.
    #[derive(Packet, ByteSize, Serialize, Deserialize)]
    #[packet(opcode = 0x4242, encrypted = true)]
    struct Test {
        content: String,
    }

    /// Runs the active handshake on an accepted socket and the passive
    /// handshake on the connecting socket. It then sends one encrypted packet
    /// from the passive side to the active side.
    #[tokio::test]
    async fn test() {
        let listen_socket = TcpSocket::new_v4().unwrap();
        listen_socket
            .bind("127.0.0.1:0".parse().unwrap())
            .unwrap();
        let address = listen_socket.local_addr().unwrap();
        let listener = listen_socket.listen(0).unwrap();

        let active_side = tokio::spawn(async move {
            let (accepted, _) = listener.accept().await.unwrap();
            let (mut rx, mut tx) = accepted.into_silkroad_stream();
            ActiveSecuritySetup::handle(&mut rx, &mut tx).await.unwrap();
            assert!(rx.encryption().is_some());
            assert!(tx.encryption().is_some());
            let received = rx.next_packet::<Test>().await.unwrap();
            assert_eq!(received.content, "Hello!");
        });

        let connection = TcpSocket::new_v4()
            .unwrap()
            .connect(address)
            .await
            .unwrap();
        let (mut rx, mut tx) = connection.into_silkroad_stream();
        PassiveSecuritySetup::handle(&mut rx, &mut tx).await.unwrap();
        assert!(rx.encryption().is_some());
        assert!(tx.encryption().is_some());

        let greeting = Test {
            content: String::from("Hello!"),
        };
        tx.write_packet(greeting).await.unwrap();

        active_side.await.unwrap();
    }
}