1use std::fmt::Debug;
2use std::io::{Read, Write};
3use std::net::SocketAddr;
4use std::ops::Not;
5use std::sync::Arc;
6
7use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
8use message_encoding::{m_max, m_max_list, m_opt_sum, m_static, MessageEncoding};
9use serde::ser::SerializeStruct;
10use serde::Serialize;
11
12use crate::{AgentSessionId, PortRange};
13use crate::hmac::HmacSha256;
14
/// A request sent over the control channel.
///
/// On the wire every variant is prefixed by a [`ControlRequestId`]; the id chosen for
/// each variant is decided in this type's `MessageEncoding::write_to`.
#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
pub enum ControlRequest {
    /// Always written with `PingV2`; a legacy `_PingV1` also decodes into this variant.
    Ping(Ping),
    /// Written as `AgentRegisterV1` when `proto_version <= 1`, otherwise `AgentRegisterV2`.
    AgentRegister(AgentRegister),
    AgentKeepAlive(AgentSessionId),
    SetupUdpChannel(AgentSessionId),
    AgentCheckPortMapping(AgentCheckPortMapping),
}
23
/// Numeric identifiers that prefix every [`ControlRequest`] on the wire.
///
/// Values are assigned sequentially starting at `1` (`_PingV1 = 1`, `AgentRegisterV1 = 2`,
/// ..., `AgentRegisterV2 = 7`). They are part of the wire format, so existing variants must
/// never be reordered or removed. `END` is a sentinel one past the last real id and is never
/// returned by [`ControlRequestId::from_num`].
#[repr(u32)]
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum ControlRequestId {
    _PingV1 = 1,
    AgentRegisterV1,
    AgentKeepAliveV1,
    SetupUdpChannelV1,
    AgentCheckPortMappingV1,
    PingV2,
    AgentRegisterV2,
    END,
}

impl ControlRequestId {
    /// Every valid id in ascending discriminant order; `END` is deliberately excluded.
    const ALL: [ControlRequestId; 7] = [
        ControlRequestId::_PingV1,
        ControlRequestId::AgentRegisterV1,
        ControlRequestId::AgentKeepAliveV1,
        ControlRequestId::SetupUdpChannelV1,
        ControlRequestId::AgentCheckPortMappingV1,
        ControlRequestId::PingV2,
        ControlRequestId::AgentRegisterV2,
    ];

    /// Converts a raw wire value into a [`ControlRequestId`].
    ///
    /// Returns `None` for `0`, for the `END` sentinel, and for anything larger.
    pub fn from_num(num: u32) -> Option<Self> {
        // Compile-time guard: `ALL` must list exactly the ids `1..END`, in order. Adding a
        // variant before `END` without updating `ALL` fails the build instead of silently
        // making that id undecodable. (This replaces an unchecked `mem::transmute`, which
        // would have been UB if the discriminants ever stopped being contiguous.)
        const _: () = {
            let mut index = 0;
            while index < ControlRequestId::ALL.len() {
                assert!(ControlRequestId::ALL[index] as u32 == index as u32 + 1);
                index += 1;
            }
            assert!(ControlRequestId::ALL.len() as u32 + 1 == ControlRequestId::END as u32);
        };

        Self::ALL.iter().copied().find(|id| *id as u32 == num)
    }
}
45
46impl MessageEncoding for ControlRequestId {
47 const STATIC_SIZE: Option<usize> = Some(4);
48
49 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
50 (*self as u32).write_to(out)
51 }
52
53 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
54 let v = u32::read_from(read)?;
55 ControlRequestId::from_num(v)
56 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid request id"))
57 }
58}
59
60impl MessageEncoding for ControlRequest {
61 const MAX_SIZE: Option<usize> = Some(m_static::<ControlRequestId>() + m_max_list(&[
62 m_max::<Ping>(),
63 m_max::<AgentRegister>(),
64 m_max::<AgentSessionId>(),
65 m_max::<AgentCheckPortMapping>(),
66 ]));
67
68 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
69 let mut sum = 0;
70
71 match self {
72 ControlRequest::Ping(data) => {
73 sum += ControlRequestId::PingV2.write_to(out)?;
74 sum += data.write_to(out)?;
75 }
76 ControlRequest::AgentRegister(data) => {
77 if data.proto_version <= 1 {
78 sum += ControlRequestId::AgentRegisterV1.write_to(out)?;
79 } else {
80 sum += ControlRequestId::AgentRegisterV2.write_to(out)?;
81 }
82 sum += data.write_to(out)?;
83 }
84 ControlRequest::AgentKeepAlive(data) => {
85 sum += ControlRequestId::AgentKeepAliveV1.write_to(out)?;
86 sum += data.write_to(out)?;
87 }
88 ControlRequest::SetupUdpChannel(data) => {
89 sum += ControlRequestId::SetupUdpChannelV1.write_to(out)?;
90 sum += data.write_to(out)?;
91 }
92 ControlRequest::AgentCheckPortMapping(data) => {
93 sum += ControlRequestId::AgentCheckPortMappingV1.write_to(out)?;
94 sum += data.write_to(out)?;
95 }
96 }
97
98 Ok(sum)
99 }
100
101 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
102 let id = ControlRequestId::read_from(read)?;
103
104 match id {
105 ControlRequestId::PingV2 => Ok(ControlRequest::Ping(Ping::read_from(read)?)),
106 ControlRequestId::AgentRegisterV1 => Ok(ControlRequest::AgentRegister(AgentRegisterV1::read_from(read)?.upgrade())),
107 ControlRequestId::AgentRegisterV2 => Ok(ControlRequest::AgentRegister(AgentRegister::read_from(read)?)),
108 ControlRequestId::AgentKeepAliveV1 => Ok(ControlRequest::AgentKeepAlive(AgentSessionId::read_from(read)?)),
109 ControlRequestId::SetupUdpChannelV1 => Ok(ControlRequest::SetupUdpChannel(AgentSessionId::read_from(read)?)),
110 ControlRequestId::AgentCheckPortMappingV1 => Ok(ControlRequest::AgentCheckPortMapping(AgentCheckPortMapping::read_from(read)?)),
111 ControlRequestId::_PingV1 => Ok(ControlRequest::Ping(Ping {
112 now: u64::read_from(read)?,
113 session_id: None,
114 current_ping: None,
115 })),
116 _ => Err(std::io::Error::other("old control request no longer supported")),
117 }
118 }
119}
120
/// Payload of [`ControlRequest::AgentCheckPortMapping`]: an agent session plus the port
/// range to check. Encoded as the session id followed by the port range.
#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
pub struct AgentCheckPortMapping {
    pub agent_session_id: AgentSessionId,
    pub port_range: PortRange,
}
126
127impl MessageEncoding for AgentCheckPortMapping {
128 const MAX_SIZE: Option<usize> = Some(m_static::<AgentSessionId>() + m_max::<PortRange>());
129
130 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
131 let mut sum = 0;
132 sum += self.agent_session_id.write_to(out)?;
133 sum += self.port_range.write_to(out)?;
134 Ok(sum)
135 }
136
137 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
138 Ok(AgentCheckPortMapping {
139 agent_session_id: AgentSessionId::read_from(read)?,
140 port_range: PortRange::read_from(read)?,
141 })
142 }
143}
144
/// Payload of [`ControlRequest::Ping`].
///
/// Always sent with id `PingV2`. A legacy `_PingV1` message decodes into this struct
/// with only `now` populated.
#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
pub struct Ping {
    /// Sender-supplied `u64` (the name suggests the sender's current time — confirm units).
    pub now: u64,
    pub current_ping: Option<u32>,
    pub session_id: Option<AgentSessionId>,
}
151
152impl MessageEncoding for Ping {
153 const STATIC_SIZE: Option<usize> = Some(8 + m_static::<Option<u32>>() + m_static::<Option<AgentSessionId>>());
154
155 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
156 let mut sum = 0;
157 sum += self.now.write_to(out)?;
158 sum += self.current_ping.write_to(out)?;
159 sum += self.session_id.write_to(out)?;
160 Ok(sum)
161 }
162
163 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
164 Ok(Ping {
165 now: MessageEncoding::read_from(read)?,
166 current_ping: MessageEncoding::read_from(read)?,
167 session_id: MessageEncoding::read_from(read)?,
168 })
169 }
170}
171
172
/// Agent registration request, authenticated by an HMAC over its encoded fields.
///
/// `proto_version <= 1` is encoded in the legacy layout (account id first); higher
/// versions are prefixed by a version word with `ENCODING_INCLUDES_VERSION_BIT` set.
#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
pub struct AgentRegister {
    /// Selects both the encoding layout and the request id (V1 for `<= 1`, V2 otherwise).
    pub proto_version: u64,
    /// Must not have its top bit set when `proto_version <= 1`, or encoding fails.
    pub account_id: u64,
    pub agent_id: u64,
    pub agent_version: u64,
    pub timestamp: u64,
    pub client_addr: SocketAddr,
    pub tunnel_addr: SocketAddr,
    /// HMAC over the encoding of every other field; set by `update_signature`.
    pub signature: [u8; 32],
}
184
185impl AgentRegister {
186 pub fn update_signature(&mut self, temp_buffer: &mut Vec<u8>, hmac: &HmacSha256) {
187 self.write_plain(temp_buffer);
188 self.signature = hmac.sign(temp_buffer);
189 }
190
191 pub fn verify_signature(&self, temp_buffer: &mut Vec<u8>, hmac: &HmacSha256) -> bool {
192 self.write_plain(temp_buffer);
193 hmac.verify(temp_buffer, &self.signature).is_ok()
194 }
195
196 fn write_plain(&self, temp_buffer: &mut Vec<u8>) {
197 temp_buffer.clear();
198 self.write_to(temp_buffer).unwrap();
199 assert!(self.signature.len() <= temp_buffer.len());
200
201 let adjusted_len = temp_buffer.len() - self.signature.len();
202 temp_buffer.truncate(adjusted_len);
203 }
204}
205
/// Top bit of the first `u64` of an encoded [`AgentRegister`].
///
/// When set, the first word carries the protocol version (with this bit cleared) and the
/// account id follows in the next word. When clear, the first word is the account id itself
/// (legacy layout, decoded as proto version 1).
const ENCODING_INCLUDES_VERSION_BIT: u64 = 1u64 << 63;
207
208impl MessageEncoding for AgentRegister {
209 const MAX_SIZE: Option<usize> = m_opt_sum(&[
210 u64::MAX_SIZE,
211 u64::MAX_SIZE,
212 u64::MAX_SIZE,
213 u64::MAX_SIZE,
214 u64::MAX_SIZE,
215 SocketAddr::MAX_SIZE,
216 SocketAddr::MAX_SIZE,
217 Some(32),
218 ]);
219
220 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
221 let mut sum = 0;
222
223 if self.proto_version <= 1 {
224 if (self.account_id & ENCODING_INCLUDES_VERSION_BIT) == ENCODING_INCLUDES_VERSION_BIT {
225 return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "account id too large for proto version 1"));
226 }
227
228 sum += self.account_id.write_to(out)?;
229 } else {
230 if (self.proto_version & ENCODING_INCLUDES_VERSION_BIT) == ENCODING_INCLUDES_VERSION_BIT {
231 return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid proto version"));
232 }
233
234 sum += (self.proto_version | ENCODING_INCLUDES_VERSION_BIT).write_to(out)?;
235 sum += self.account_id.write_to(out)?;
236 }
237
238 sum += self.agent_id.write_to(out)?;
239 sum += self.agent_version.write_to(out)?;
240 sum += self.timestamp.write_to(out)?;
241 sum += self.client_addr.write_to(out)?;
242 sum += self.tunnel_addr.write_to(out)?;
243 out.write_all(&self.signature)?;
244 sum += self.signature.len();
245 Ok(sum)
246 }
247
248 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
249 let first_word = u64::read_from(read)?;
250
251 let mut proto_version = 1;
252 let account_id: u64;
253
254 if (first_word & ENCODING_INCLUDES_VERSION_BIT) == ENCODING_INCLUDES_VERSION_BIT {
255 proto_version = first_word & ENCODING_INCLUDES_VERSION_BIT.not();
256 account_id = u64::read_from(read)?;
257 } else {
258 account_id = first_word;
259 }
260
261 let mut res = AgentRegister {
262 proto_version,
263 account_id,
264 agent_id: u64::read_from(read)?,
265 agent_version: u64::read_from(read)?,
266 timestamp: u64::read_from(read)?,
267 client_addr: SocketAddr::read_from(read)?,
268 tunnel_addr: SocketAddr::read_from(read)?,
269 signature: [0u8; 32],
270 };
271
272 read.read_exact(&mut res.signature[..])?;
273 Ok(res)
274 }
275}
276
/// Legacy (V1) registration layout: [`AgentRegister`] without a protocol-version field.
///
/// Decoded for `ControlRequestId::AgentRegisterV1` and converted with `upgrade`, which sets
/// `proto_version` to 1.
pub struct AgentRegisterV1 {
    pub account_id: u64,
    pub agent_id: u64,
    pub agent_version: u64,
    pub timestamp: u64,
    pub client_addr: SocketAddr,
    pub tunnel_addr: SocketAddr,
    pub signature: [u8; 32],
}
286
287impl MessageEncoding for AgentRegisterV1 {
288 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
289 out.write_u64::<BigEndian>(self.account_id)?;
290 out.write_u64::<BigEndian>(self.agent_id)?;
291 out.write_u64::<BigEndian>(self.agent_version)?;
292 out.write_u64::<BigEndian>(self.timestamp)?;
293 let mut len = 8 + 8 + 8 + 8;
294 len += self.client_addr.write_to(out)?;
295 len += self.tunnel_addr.write_to(out)?;
296 if out.write(&self.signature)? != 32 {
297 return Err(std::io::Error::new(std::io::ErrorKind::WriteZero, "failed to write full signature"));
298 }
299 len += 32;
300 Ok(len)
301 }
302
303 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
304 let mut res = Self {
305 account_id: read.read_u64::<BigEndian>()?,
306 agent_id: read.read_u64::<BigEndian>()?,
307 agent_version: read.read_u64::<BigEndian>()?,
308 timestamp: read.read_u64::<BigEndian>()?,
309 client_addr: SocketAddr::read_from(read)?,
310 tunnel_addr: SocketAddr::read_from(read)?,
311 signature: [0u8; 32],
312 };
313
314 if read.read(&mut res.signature[..])? != 32 {
315 return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "missing signature"));
316 }
317
318 Ok(res)
319 }
320}
321
322impl AgentRegisterV1 {
323 pub fn upgrade(self) -> AgentRegister {
324 AgentRegister {
325 proto_version: 1,
326 account_id: self.account_id,
327 agent_id: self.agent_id,
328 agent_version: self.agent_version,
329 timestamp: self.timestamp,
330 client_addr: self.client_addr,
331 tunnel_addr: self.tunnel_addr,
332 signature: self.signature,
333 }
334 }
335}
336
/// A response sent over the control channel.
///
/// Each variant is prefixed on the wire by a `u32` id: `Pong` = 1, `InvalidSignature` = 2,
/// `Unauthorized` = 3, `RequestQueued` = 4, `TryAgainLater` = 5, `AgentRegistered` = 6,
/// `AgentPortMapping` = 7, `UdpChannelDetails` = 8. Unit variants carry no payload.
#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
pub enum ControlResponse {
    Pong(Pong),
    InvalidSignature,
    Unauthorized,
    RequestQueued,
    TryAgainLater,
    AgentRegistered(AgentRegistered),
    AgentPortMapping(AgentPortMapping),
    UdpChannelDetails(UdpChannelDetails),
}
348
349impl MessageEncoding for ControlResponse {
350 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
351 let mut sum = 0;
352
353 match self {
354 ControlResponse::Pong(data) => {
355 sum += 1u32.write_to(out)?;
356 sum += data.write_to(out)?;
357 }
358 ControlResponse::InvalidSignature => {
359 sum += 2u32.write_to(out)?;
360 }
361 ControlResponse::Unauthorized => {
362 sum += 3u32.write_to(out)?;
363 }
364 ControlResponse::RequestQueued => {
365 sum += 4u32.write_to(out)?;
366 }
367 ControlResponse::TryAgainLater => {
368 sum += 5u32.write_to(out)?;
369 }
370 ControlResponse::AgentRegistered(data) => {
371 sum += 6u32.write_to(out)?;
372 sum += data.write_to(out)?;
373 }
374 ControlResponse::AgentPortMapping(data) => {
375 sum += 7u32.write_to(out)?;
376 sum += data.write_to(out)?;
377 }
378 ControlResponse::UdpChannelDetails(data) => {
379 sum += 8u32.write_to(out)?;
380 sum += data.write_to(out)?;
381 }
382 }
383
384 Ok(sum)
385 }
386
387 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
388 match read.read_u32::<BigEndian>()? {
389 1 => Ok(ControlResponse::Pong(Pong::read_from(read)?)),
390 2 => Ok(ControlResponse::InvalidSignature),
391 3 => Ok(ControlResponse::Unauthorized),
392 4 => Ok(ControlResponse::RequestQueued),
393 5 => Ok(ControlResponse::TryAgainLater),
394 6 => Ok(ControlResponse::AgentRegistered(AgentRegistered::read_from(read)?)),
395 7 => Ok(ControlResponse::AgentPortMapping(AgentPortMapping::read_from(read)?)),
396 8 => Ok(ControlResponse::UdpChannelDetails(UdpChannelDetails::read_from(read)?)),
397 _ => Err(std::io::Error::other("invalid ControlResponse id")),
398 }
399 }
400}
401
/// Payload of [`ControlResponse::AgentPortMapping`]: a port range and, optionally, what it
/// maps to. Presumably the reply to [`ControlRequest::AgentCheckPortMapping`] — confirm.
#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
pub struct AgentPortMapping {
    pub range: PortRange,
    /// `None` when no mapping was reported for `range`.
    pub found: Option<AgentPortMappingFound>,
}
407
408impl MessageEncoding for AgentPortMapping {
409 const MAX_SIZE: Option<usize> = Some(
410 m_max::<PortRange>() +
411 m_max::<Option<AgentPortMappingFound>>()
412 );
413
414 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
415 let mut sum = 0;
416 sum += self.range.write_to(out)?;
417 sum += self.found.write_to(out)?;
418 Ok(sum)
419 }
420
421 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
422 Ok(AgentPortMapping {
423 range: PortRange::read_from(read)?,
424 found: Option::<AgentPortMappingFound>::read_from(read)?,
425 })
426 }
427}
428
/// Target of a found port mapping. Encoded as a `u32` tag (`ToAgent` = 1) followed by the
/// variant's payload.
#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
pub enum AgentPortMappingFound {
    ToAgent(AgentSessionId),
}
433
434impl MessageEncoding for AgentPortMappingFound {
435 const MAX_SIZE: Option<usize> = Some(4 + m_max_list(&[
436 m_max::<AgentSessionId>(),
437 ]));
438
439 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
440 let mut sum = 0;
441
442 match self {
443 AgentPortMappingFound::ToAgent(id) => {
444 sum += 1u32.write_to(out)?;
445 sum += id.write_to(out)?;
446 }
447 }
448
449 Ok(sum)
450 }
451
452 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
453 match read.read_u32::<BigEndian>()? {
454 1 => Ok(AgentPortMappingFound::ToAgent(AgentSessionId::read_from(read)?)),
455 _ => Err(std::io::Error::new(std::io::ErrorKind::Other, "unknown AgentPortMappingFound id")),
456 }
457 }
458}
459
/// Payload of [`ControlResponse::UdpChannelDetails`].
///
/// `token` is held in an `Arc`, so cloning the struct shares one allocation. `Debug` and
/// `Serialize` are implemented by hand (`Debug` prints the token as hex).
#[derive(Eq, PartialEq, Clone)]
pub struct UdpChannelDetails {
    pub tunnel_addr: SocketAddr,
    pub token: Arc<Vec<u8>>,
}
465
466impl Serialize for UdpChannelDetails {
467 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer {
468 let mut s = serializer.serialize_struct("UdpChannelDetails", 2)?;
469 s.serialize_field("tunnel_addr", &self.tunnel_addr)?;
470 s.serialize_field("token", &*self.token)?;
471 s.end()
472 }
473}
474
475impl Debug for UdpChannelDetails {
476 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477 f.debug_struct("UdpChannelDetails")
478 .field("tunnel_addr", &self.tunnel_addr)
479 .field("token", &hex::encode(&self.token[..]))
480 .finish()
481 }
482}
483
484impl MessageEncoding for UdpChannelDetails {
485 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
486 let mut sum = 0;
487 sum += self.tunnel_addr.write_to(out)?;
488 sum += self.token.write_to(out)?;
489 Ok(sum)
490 }
491
492 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
493 Ok(UdpChannelDetails {
494 tunnel_addr: SocketAddr::read_from(read)?,
495 token: Arc::new(Vec::read_from(read)?),
496 })
497 }
498}
499
/// Payload of [`ControlResponse::Pong`].
#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
pub struct Pong {
    /// Presumably echoes the `now` value of the [`Ping`] being answered — confirm.
    pub request_now: u64,
    pub server_now: u64,
    pub server_id: u64,
    pub data_center_id: u32,
    pub client_addr: SocketAddr,
    pub tunnel_addr: SocketAddr,
    pub session_expire_at: Option<u64>,
}
510
511impl MessageEncoding for Pong {
512 const MAX_SIZE: Option<usize> = Some(
513 m_static::<u64>() * 3 +
514 m_static::<u32>() +
515 m_max::<SocketAddr>() * 2 +
516 m_static::<Option<u64>>()
517 );
518
519 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
520 let mut sum = 0;
521 sum += self.request_now.write_to(out)?;
522 sum += self.server_now.write_to(out)?;
523 sum += self.server_id.write_to(out)?;
524 sum += self.data_center_id.write_to(out)?;
525 sum += self.client_addr.write_to(out)?;
526 sum += self.tunnel_addr.write_to(out)?;
527 sum += self.session_expire_at.write_to(out)?;
528 Ok(sum)
529 }
530
531 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
532 Ok(Pong {
533 request_now: read.read_u64::<BigEndian>()?,
534 server_now: read.read_u64::<BigEndian>()?,
535 server_id: read.read_u64::<BigEndian>()?,
536 data_center_id: read.read_u32::<BigEndian>()?,
537 client_addr: SocketAddr::read_from(read)?,
538 tunnel_addr: SocketAddr::read_from(read)?,
539 session_expire_at: Option::read_from(read)?,
540 })
541 }
542}
543
/// Payload of [`ControlResponse::AgentRegistered`]: the session id assigned to the agent and
/// its expiry (`expires_at` units are not visible here — confirm against the server).
#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
pub struct AgentRegistered {
    pub id: AgentSessionId,
    pub expires_at: u64,
}
549
550impl MessageEncoding for AgentRegistered {
551 const STATIC_SIZE: Option<usize> = Some(
552 m_static::<AgentSessionId>() +
553 m_static::<u64>()
554 );
555
556 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
557 let mut sum = 0;
558 sum += self.id.write_to(out)?;
559 sum += self.expires_at.write_to(out)?;
560 Ok(sum)
561 }
562
563 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
564 Ok(AgentRegistered {
565 id: AgentSessionId::read_from(read)?,
566 expires_at: read.read_u64::<BigEndian>()?,
567 })
568 }
569}
570
#[cfg(test)]
mod test {
    // Tests for the control-channel message encodings in this module: HMAC signing
    // round-trips, legacy V1 decoding, randomized encode/decode round-trips, and fixed
    // byte vectors that pin the exact wire format.
    use std::fmt::Debug;
    use std::net::{IpAddr, Ipv4Addr};

    use rand::{Rng, RngCore, thread_rng};

    use crate::PortProto;
    use crate::rpc::ControlRpcMessage;

    use super::*;

    // Signing then verifying with the same key must succeed for proto versions 0 and 1.
    // Both take the legacy (no version word) branch of `AgentRegister::write_to`.
    #[test]
    fn agent_register_sign_test() {
        let mut reg = AgentRegister {
            proto_version: 0,
            account_id: 1,
            agent_id: 2,
            agent_version: 3,
            timestamp: 1000,
            client_addr: "10.20.30.40:5678".parse().unwrap(),
            tunnel_addr: "9.20.3.40:9912".parse().unwrap(),
            signature: [0u8; 32],
        };

        let hmac = HmacSha256::create("this is a super secret secret".as_bytes());

        let mut buffer = Vec::new();
        reg.update_signature(&mut buffer, &hmac);
        assert!(reg.verify_signature(&mut buffer, &hmac));

        reg.proto_version = 1;
        reg.update_signature(&mut buffer, &hmac);
        assert!(reg.verify_signature(&mut buffer, &hmac));
    }

    // A message written with the legacy `AgentRegisterV1` id and layout must decode into
    // `ControlRequest::AgentRegister` with `proto_version` set to 1 by `upgrade`.
    #[test]
    fn agent_register_old_proto_decode() {
        let reg = AgentRegisterV1 {
            account_id: 1,
            agent_id: 2,
            agent_version: 3,
            timestamp: 1000,
            client_addr: "10.20.30.40:5678".parse().unwrap(),
            tunnel_addr: "9.20.3.40:9912".parse().unwrap(),
            signature: [0u8; 32],
        };

        let mut out = Vec::new();
        ControlRequestId::AgentRegisterV1.write_to(&mut out).unwrap();
        reg.write_to(&mut out).unwrap();

        let mut reader = &out[..];
        let read = ControlRequest::read_from(&mut reader).unwrap();
        assert_eq!(read, ControlRequest::AgentRegister(AgentRegister {
            proto_version: 1,
            account_id: 1,
            agent_id: 2,
            agent_version: 3,
            timestamp: 1000,
            client_addr: "10.20.30.40:5678".parse().unwrap(),
            tunnel_addr: "9.20.3.40:9912".parse().unwrap(),
            signature: [0u8; 32],
        }))
    }

    // Randomized round-trip of `ControlRequest`, both bare and wrapped in `ControlRpcMessage`.
    #[test]
    fn fuzzy_test_control_request() {
        let mut rng = thread_rng();
        let mut buffer = vec![0u8; 2048];

        for _ in 0..100000 {
            let msg = rng_control_request(&mut rng);
            test_encoding(msg, &mut buffer);
        }

        for _ in 0..1000 {
            test_encoding(ControlRpcMessage {
                request_id: rng.next_u64(),
                content: rng_control_request(&mut rng),
            }, &mut buffer);
        }
    }

    // Randomized round-trip of `ControlResponse`, both bare and wrapped in `ControlRpcMessage`.
    #[test]
    fn fuzzy_test_control_response() {
        let mut rng = thread_rng();
        let mut buffer = vec![0u8; 2048];

        for _ in 0..100000 {
            let msg = rng_control_response(&mut rng);
            test_encoding(msg, &mut buffer);
        }

        for _ in 0..1000 {
            test_encoding(ControlRpcMessage {
                request_id: rng.next_u64(),
                content: rng_control_response(&mut rng),
            }, &mut buffer);
        }
    }

    // Encodes `msg` into `buffer` and checks the written length: it must equal
    // `STATIC_SIZE` when that is declared, and must not exceed `MAX_SIZE` when that is
    // declared. Then it decodes the written bytes and asserts the result equals `msg`.
    fn test_encoding<T: MessageEncoding + PartialEq + Debug>(msg: T, buffer: &mut [u8]) {
        // NOTE(review): `_ASSERT` comes from `MessageEncoding`; presumably it forces the
        // trait's compile-time size checks to be evaluated — confirm in message_encoding.
        assert_eq!(0, T::_ASSERT);

        // Writing into a `&mut [u8]` advances the slice, so the number of bytes written is
        // the original length minus what remains.
        let mut writer = &mut buffer[..];
        msg.write_to(&mut writer).unwrap();

        let remaining_len = writer.len();
        let written = buffer.len() - remaining_len;

        if let Some(size) = T::STATIC_SIZE {
            assert_eq!(written, size);
        }

        if let Some(size) = T::MAX_SIZE {
            assert!(written <= size);
        }

        let mut reader = &buffer[0..written];
        let recovered = T::read_from(&mut reader).unwrap();

        assert_eq!(msg, recovered);
    }

    // Builds a random `ControlRequest`. Account ids are reduced modulo `i64::MAX`, so their
    // top bit (`ENCODING_INCLUDES_VERSION_BIT`) is never set. `proto_version` is 1 or 2, so
    // both registration layouts are exercised.
    pub fn rng_control_request<R: RngCore>(rng: &mut R) -> ControlRequest {
        match rng.next_u32() % 5 {
            0 => ControlRequest::Ping(Ping {
                now: rng.next_u64(),
                current_ping: if rng.next_u32() % 2 == 0 {
                    Some(rng.next_u32())
                } else {
                    None
                },
                session_id: if rng.next_u32() % 2 == 0 {
                    Some(AgentSessionId {
                        session_id: rng.next_u64(),
                        account_id: rng.next_u64() % (i64::MAX as u64),
                        agent_id: rng.next_u64(),
                    })
                } else {
                    None
                },
            }),
            1 => ControlRequest::AgentRegister(AgentRegister {
                proto_version: 1 + rng.next_u64() % 2,
                account_id: rng.next_u64() % (i64::MAX as u64),
                agent_id: rng.next_u64(),
                agent_version: rng.next_u64(),
                timestamp: rng.next_u64(),
                client_addr: rng_socket_address(rng),
                tunnel_addr: rng_socket_address(rng),
                signature: {
                    let mut bytes = [0u8; 32];
                    rng.fill(&mut bytes);
                    bytes
                },
            }),
            2 => ControlRequest::AgentKeepAlive(AgentSessionId {
                session_id: rng.next_u64(),
                account_id: rng.next_u64() % (i64::MAX as u64),
                agent_id: rng.next_u64(),
            }),
            3 => ControlRequest::SetupUdpChannel(AgentSessionId {
                session_id: rng.next_u64(),
                account_id: rng.next_u64() % (i64::MAX as u64),
                agent_id: rng.next_u64(),
            }),
            4 => ControlRequest::AgentCheckPortMapping(AgentCheckPortMapping {
                agent_session_id: AgentSessionId {
                    session_id: rng.next_u64(),
                    account_id: rng.next_u64() % (i64::MAX as u64),
                    agent_id: rng.next_u64(),
                },
                port_range: PortRange {
                    ip: match rng.next_u32() % 2 {
                        0 => IpAddr::V4(Ipv4Addr::from(rng.next_u32())),
                        1 => IpAddr::V6({
                            let mut bytes = [0u8; 16];
                            rng.fill(&mut bytes);
                            bytes.into()
                        }),
                        _ => unreachable!(),
                    },
                    port_start: rng.next_u32() as u16,
                    port_end: rng.next_u32() as u16,
                    port_proto: match rng.next_u32() % 3 {
                        0 => PortProto::Tcp,
                        1 => PortProto::Udp,
                        2 => PortProto::Both,
                        _ => unreachable!(),
                    },
                },
            }),
            _ => unreachable!(),
        }
    }

    // Builds a random `ControlResponse` covering all eight variants; UDP tokens are
    // 32..=61 random bytes.
    pub fn rng_control_response<R: RngCore>(rng: &mut R) -> ControlResponse {
        match rng.next_u32() % 8 {
            0 => ControlResponse::Pong(Pong {
                request_now: rng.next_u64(),
                server_now: rng.next_u64(),
                server_id: rng.next_u64(),
                data_center_id: rng.next_u32(),
                client_addr: rng_socket_address(rng),
                tunnel_addr: rng_socket_address(rng),
                session_expire_at: if rng.next_u32() % 2 == 1 {
                    Some(rng.next_u64())
                } else {
                    None
                },
            }),
            1 => ControlResponse::InvalidSignature,
            2 => ControlResponse::Unauthorized,
            3 => ControlResponse::RequestQueued,
            4 => ControlResponse::TryAgainLater,
            5 => ControlResponse::AgentRegistered(AgentRegistered {
                id: AgentSessionId {
                    session_id: rng.next_u64(),
                    account_id: rng.next_u64() % (i64::MAX as u64),
                    agent_id: rng.next_u64(),
                },
                expires_at: rng.next_u64(),
            }),
            6 => ControlResponse::AgentPortMapping(AgentPortMapping {
                range: PortRange {
                    ip: match rng.next_u32() % 2 {
                        0 => IpAddr::V4(Ipv4Addr::from(rng.next_u32())),
                        1 => IpAddr::V6({
                            let mut bytes = [0u8; 16];
                            rng.fill(&mut bytes);
                            bytes.into()
                        }),
                        _ => unreachable!(),
                    },
                    port_start: rng.next_u32() as u16,
                    port_end: rng.next_u32() as u16,
                    port_proto: match rng.next_u32() % 3 {
                        0 => PortProto::Tcp,
                        1 => PortProto::Udp,
                        2 => PortProto::Both,
                        _ => unreachable!(),
                    },
                },
                found: match rng.next_u32() % 2 {
                    0 => None,
                    1 => Some(AgentPortMappingFound::ToAgent(AgentSessionId {
                        session_id: rng.next_u64(),
                        account_id: rng.next_u64() % (i64::MAX as u64),
                        agent_id: rng.next_u64(),
                    })),
                    _ => unreachable!()
                },
            }),
            7 => ControlResponse::UdpChannelDetails(UdpChannelDetails {
                tunnel_addr: rng_socket_address(rng),
                token: {
                    let len = ((rng.next_u64() % 30) + 32) as usize;
                    let mut buffer = vec![0u8; len];
                    rng.fill_bytes(&mut buffer);
                    Arc::new(buffer)
                },
            }),
            _ => unreachable!()
        }
    }

    // Random IPv4 or IPv6 socket address with a random port.
    fn rng_socket_address<R: RngCore>(rng: &mut R) -> SocketAddr {
        SocketAddr::new(
            match rng.next_u32() % 2 {
                0 => IpAddr::V4(Ipv4Addr::from(rng.next_u32())),
                1 => IpAddr::V6({
                    let mut bytes = [0u8; 16];
                    rng.fill(&mut bytes);
                    bytes.into()
                }),
                _ => unreachable!(),
            },
            rng.next_u32() as u16,
        )
    }

    // Pins the exact legacy (proto version 1) byte encoding with IPv4 addresses, including
    // the signature produced under a fixed key. Any change to the layout breaks this vector.
    #[test]
    fn agent_register_v1_ip4_same_encoding_test() {
        let mut msg = AgentRegister {
            account_id: 100,
            agent_id: 32,
            agent_version: 676,
            timestamp: 103201401,
            client_addr: "127.0.0.1:4123".parse().unwrap(),
            tunnel_addr: "99.12.34.51:5312".parse().unwrap(),
            signature: [0u8; 32],
            proto_version: 1,
        };

        let sig = HmacSha256::create("test-secret-hehehe".as_bytes());
        let mut buffer = Vec::new();
        msg.update_signature(&mut buffer, &sig);
        assert!(msg.verify_signature(&mut buffer, &sig));

        buffer.clear();
        msg.write_to(&mut buffer).unwrap();

        let hex_buffer = hex::encode(&buffer);
        assert_eq!(hex_buffer, "0000000000000064000000000000002000000000000002a4000000000626ba79047f000001101b04630c223314c0767a59319b8edfcc1e6f3d3ea2d19ac74a74e5f5333c9b335adc72cda821de5f");
    }

    // Same as the IPv4 vector above, but with IPv6 addresses.
    #[test]
    fn agent_register_v1_ip6_same_encoding_test() {
        let mut msg = AgentRegister {
            account_id: 100,
            agent_id: 32,
            agent_version: 676,
            timestamp: 103201401,
            client_addr: "[::88]:4123".parse().unwrap(),
            tunnel_addr: "[::99]:5312".parse().unwrap(),
            signature: [0u8; 32],
            proto_version: 1,
        };

        let sig = HmacSha256::create("test-secret-hehehe".as_bytes());
        let mut buffer = Vec::new();
        msg.update_signature(&mut buffer, &sig);
        assert!(msg.verify_signature(&mut buffer, &sig));

        buffer.clear();
        msg.write_to(&mut buffer).unwrap();

        let hex_buffer = hex::encode(&buffer);
        assert_eq!(hex_buffer, "0000000000000064000000000000002000000000000002a4000000000626ba790600000000000000000000000000000088101b060000000000000000000000000000009914c0724f203e7ac2f090800dbeb68afbf184f367f9ca14d8a0082e245070c3835c4b");
    }

    // Decodes a fixed legacy byte sequence (per the test name, the kind sent by a legacy
    // Minecraft Java client) into a `_PingV1` ping with `now: 0`, wrapped in a
    // `ControlRpcMessage` with `request_id: 1`.
    #[test]
    fn legacy_mc_java_ping_decode_test() {
        let data = hex::decode("000000000000000100000001000000000000000000").unwrap();
        let mut reader = &data[..];

        let msg = ControlRpcMessage::<ControlRequest>::read_from(&mut reader).unwrap();
        assert_eq!(msg, ControlRpcMessage {
            request_id: 1,
            content: ControlRequest::Ping(Ping {
                now: 0,
                current_ping: None,
                session_id: None,
            }),
        });
        println!("Got msg: {msg:?}");
    }
}