Skip to main content

playit_agent_proto/
control_messages.rs

1use std::fmt::Debug;
2use std::io::{Read, Write};
3use std::net::SocketAddr;
4use std::ops::Not;
5use std::sync::Arc;
6
7use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
8use message_encoding::{m_max, m_max_list, m_opt_sum, m_static, MessageEncoding};
9use serde::ser::SerializeStruct;
10use serde::Serialize;
11
12use crate::{AgentSessionId, PortRange};
13use crate::hmac::HmacSha256;
14
15#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
16pub enum ControlRequest {
17    Ping(Ping),
18    AgentRegister(AgentRegister),
19    AgentKeepAlive(AgentSessionId),
20    SetupUdpChannel(AgentSessionId),
21    AgentCheckPortMapping(AgentCheckPortMapping),
22}
23
/// Numeric identifier written in front of every encoded control request.
///
/// Discriminants start at 1 and grow by one per variant. `END` is a sentinel
/// marking the first unassigned id; [`ControlRequestId::from_num`] never
/// returns it.
#[repr(u32)]
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum ControlRequestId {
    _PingV1 = 1,
    AgentRegisterV1,
    AgentKeepAliveV1,
    SetupUdpChannelV1,
    AgentCheckPortMappingV1,
    PingV2,
    AgentRegisterV2,
    END,
}

impl ControlRequestId {
    /// Every id that may appear on the wire, i.e. all variants except `END`.
    ///
    /// The array length is derived from `END`, so adding a variant without
    /// listing it here is a compile error rather than a silently rejected id.
    const KNOWN: [ControlRequestId; ControlRequestId::END as usize - 1] = [
        ControlRequestId::_PingV1,
        ControlRequestId::AgentRegisterV1,
        ControlRequestId::AgentKeepAliveV1,
        ControlRequestId::SetupUdpChannelV1,
        ControlRequestId::AgentCheckPortMappingV1,
        ControlRequestId::PingV2,
        ControlRequestId::AgentRegisterV2,
    ];

    /// Maps a raw wire value to its id.
    ///
    /// Returns `None` for `0`, for the `END` sentinel and for anything above it.
    /// Implemented as a lookup instead of a `transmute`, so a future gap in the
    /// discriminants can never turn into undefined behaviour.
    pub fn from_num(num: u32) -> Option<Self> {
        Self::KNOWN.iter().copied().find(|id| *id as u32 == num)
    }
}
45
46impl MessageEncoding for ControlRequestId {
47    const STATIC_SIZE: Option<usize> = Some(4);
48    
49    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
50        (*self as u32).write_to(out)
51    }
52
53    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
54        let v = u32::read_from(read)?;
55        ControlRequestId::from_num(v)
56            .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid request id"))
57    }
58}
59
60impl MessageEncoding for ControlRequest {
61    const MAX_SIZE: Option<usize> = Some(m_static::<ControlRequestId>() + m_max_list(&[
62        m_max::<Ping>(),
63        m_max::<AgentRegister>(),
64        m_max::<AgentSessionId>(),
65        m_max::<AgentCheckPortMapping>(),
66    ]));
67
68    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
69        let mut sum = 0;
70
71        match self {
72            ControlRequest::Ping(data) => {
73                sum += ControlRequestId::PingV2.write_to(out)?;
74                sum += data.write_to(out)?;
75            }
76            ControlRequest::AgentRegister(data) => {
77                if data.proto_version <= 1 {
78                    sum += ControlRequestId::AgentRegisterV1.write_to(out)?;
79                } else {
80                    sum += ControlRequestId::AgentRegisterV2.write_to(out)?;
81                }
82                sum += data.write_to(out)?;
83            }
84            ControlRequest::AgentKeepAlive(data) => {
85                sum += ControlRequestId::AgentKeepAliveV1.write_to(out)?;
86                sum += data.write_to(out)?;
87            }
88            ControlRequest::SetupUdpChannel(data) => {
89                sum += ControlRequestId::SetupUdpChannelV1.write_to(out)?;
90                sum += data.write_to(out)?;
91            }
92            ControlRequest::AgentCheckPortMapping(data) => {
93                sum += ControlRequestId::AgentCheckPortMappingV1.write_to(out)?;
94                sum += data.write_to(out)?;
95            }
96        }
97
98        Ok(sum)
99    }
100
101    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
102        let id = ControlRequestId::read_from(read)?;
103        
104        match id {
105            ControlRequestId::PingV2 => Ok(ControlRequest::Ping(Ping::read_from(read)?)),
106            ControlRequestId::AgentRegisterV1 => Ok(ControlRequest::AgentRegister(AgentRegisterV1::read_from(read)?.upgrade())),
107            ControlRequestId::AgentRegisterV2 => Ok(ControlRequest::AgentRegister(AgentRegister::read_from(read)?)),
108            ControlRequestId::AgentKeepAliveV1 => Ok(ControlRequest::AgentKeepAlive(AgentSessionId::read_from(read)?)),
109            ControlRequestId::SetupUdpChannelV1 => Ok(ControlRequest::SetupUdpChannel(AgentSessionId::read_from(read)?)),
110            ControlRequestId::AgentCheckPortMappingV1 => Ok(ControlRequest::AgentCheckPortMapping(AgentCheckPortMapping::read_from(read)?)),
111            ControlRequestId::_PingV1 => Ok(ControlRequest::Ping(Ping {
112                now: u64::read_from(read)?,
113                session_id: None,
114                current_ping: None,
115            })),
116            _ => Err(std::io::Error::other("old control request no longer supported")),
117        }
118    }
119}
120
121#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
122pub struct AgentCheckPortMapping {
123    pub agent_session_id: AgentSessionId,
124    pub port_range: PortRange,
125}
126
127impl MessageEncoding for AgentCheckPortMapping {
128    const MAX_SIZE: Option<usize> = Some(m_static::<AgentSessionId>() + m_max::<PortRange>());
129
130    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
131        let mut sum = 0;
132        sum += self.agent_session_id.write_to(out)?;
133        sum += self.port_range.write_to(out)?;
134        Ok(sum)
135    }
136
137    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
138        Ok(AgentCheckPortMapping {
139            agent_session_id: AgentSessionId::read_from(read)?,
140            port_range: PortRange::read_from(read)?,
141        })
142    }
143}
144
145#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
146pub struct Ping {
147    pub now: u64,
148    pub current_ping: Option<u32>,
149    pub session_id: Option<AgentSessionId>,
150}
151
152impl MessageEncoding for Ping {
153    const STATIC_SIZE: Option<usize> = Some(8 + m_static::<Option<u32>>() + m_static::<Option<AgentSessionId>>());
154
155    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
156        let mut sum = 0;
157        sum += self.now.write_to(out)?;
158        sum += self.current_ping.write_to(out)?;
159        sum += self.session_id.write_to(out)?;
160        Ok(sum)
161    }
162
163    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
164        Ok(Ping {
165            now: MessageEncoding::read_from(read)?,
166            current_ping: MessageEncoding::read_from(read)?,
167            session_id: MessageEncoding::read_from(read)?,
168        })
169    }
170}
171
172
173#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
174pub struct AgentRegister {
175    pub proto_version: u64,
176    pub account_id: u64,
177    pub agent_id: u64,
178    pub agent_version: u64,
179    pub timestamp: u64,
180    pub client_addr: SocketAddr,
181    pub tunnel_addr: SocketAddr,
182    pub signature: [u8; 32],
183}
184
185impl AgentRegister {
186    pub fn update_signature(&mut self, temp_buffer: &mut Vec<u8>, hmac: &HmacSha256) {
187        self.write_plain(temp_buffer);
188        self.signature = hmac.sign(temp_buffer);
189    }
190
191    pub fn verify_signature(&self, temp_buffer: &mut Vec<u8>, hmac: &HmacSha256) -> bool {
192        self.write_plain(temp_buffer);
193        hmac.verify(temp_buffer, &self.signature).is_ok()
194    }
195
196    fn write_plain(&self, temp_buffer: &mut Vec<u8>) {
197        temp_buffer.clear();
198        self.write_to(temp_buffer).unwrap();
199        assert!(self.signature.len() <= temp_buffer.len());
200
201        let adjusted_len = temp_buffer.len() - self.signature.len();
202        temp_buffer.truncate(adjusted_len);
203    }
204}
205
/// Top bit of the first encoded `u64` of an `AgentRegister`; when set, that word
/// carries the proto version instead of the account id.
const ENCODING_INCLUDES_VERSION_BIT: u64 = 0x8000_0000_0000_0000;
207
208impl MessageEncoding for AgentRegister {
209    const MAX_SIZE: Option<usize> = m_opt_sum(&[
210        u64::MAX_SIZE,
211        u64::MAX_SIZE,
212        u64::MAX_SIZE,
213        u64::MAX_SIZE,
214        u64::MAX_SIZE,
215        SocketAddr::MAX_SIZE,
216        SocketAddr::MAX_SIZE,
217        Some(32),
218    ]);
219
220    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
221        let mut sum = 0;
222
223        if self.proto_version <= 1 {
224            if (self.account_id & ENCODING_INCLUDES_VERSION_BIT) == ENCODING_INCLUDES_VERSION_BIT {
225                return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "account id too large for proto version 1"));
226            }
227
228            sum += self.account_id.write_to(out)?;
229        } else {
230            if (self.proto_version & ENCODING_INCLUDES_VERSION_BIT) == ENCODING_INCLUDES_VERSION_BIT {
231                return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid proto version"));
232            }
233
234            sum += (self.proto_version | ENCODING_INCLUDES_VERSION_BIT).write_to(out)?;
235            sum += self.account_id.write_to(out)?;
236        }
237
238        sum += self.agent_id.write_to(out)?;
239        sum += self.agent_version.write_to(out)?;
240        sum += self.timestamp.write_to(out)?;
241        sum += self.client_addr.write_to(out)?;
242        sum += self.tunnel_addr.write_to(out)?;
243        out.write_all(&self.signature)?;
244        sum += self.signature.len();
245        Ok(sum)
246    }
247
248    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
249        let first_word = u64::read_from(read)?;
250
251        let mut proto_version = 1;
252        let account_id: u64;
253
254        if (first_word & ENCODING_INCLUDES_VERSION_BIT) == ENCODING_INCLUDES_VERSION_BIT {
255            proto_version = first_word & ENCODING_INCLUDES_VERSION_BIT.not();
256            account_id = u64::read_from(read)?;
257        } else {
258            account_id = first_word;
259        }
260
261        let mut res = AgentRegister {
262            proto_version,
263            account_id,
264            agent_id: u64::read_from(read)?,
265            agent_version: u64::read_from(read)?,
266            timestamp: u64::read_from(read)?,
267            client_addr: SocketAddr::read_from(read)?,
268            tunnel_addr: SocketAddr::read_from(read)?,
269            signature: [0u8; 32],
270        };
271
272        read.read_exact(&mut res.signature[..])?;
273        Ok(res)
274    }
275}
276
/// Legacy registration payload without a `proto_version` field.
///
/// Decoded for `ControlRequestId::AgentRegisterV1` and converted to
/// [`AgentRegister`] through [`AgentRegisterV1::upgrade`].
pub struct AgentRegisterV1 {
    pub account_id: u64,
    pub agent_id: u64,
    pub agent_version: u64,
    pub timestamp: u64,
    pub client_addr: SocketAddr,
    pub tunnel_addr: SocketAddr,
    pub signature: [u8; 32],
}
286
287impl MessageEncoding for AgentRegisterV1 {
288    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
289        out.write_u64::<BigEndian>(self.account_id)?;
290        out.write_u64::<BigEndian>(self.agent_id)?;
291        out.write_u64::<BigEndian>(self.agent_version)?;
292        out.write_u64::<BigEndian>(self.timestamp)?;
293        let mut len = 8 + 8 + 8 + 8;
294        len += self.client_addr.write_to(out)?;
295        len += self.tunnel_addr.write_to(out)?;
296        if out.write(&self.signature)? != 32 {
297            return Err(std::io::Error::new(std::io::ErrorKind::WriteZero, "failed to write full signature"));
298        }
299        len += 32;
300        Ok(len)
301    }
302
303    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
304        let mut res = Self {
305            account_id: read.read_u64::<BigEndian>()?,
306            agent_id: read.read_u64::<BigEndian>()?,
307            agent_version: read.read_u64::<BigEndian>()?,
308            timestamp: read.read_u64::<BigEndian>()?,
309            client_addr: SocketAddr::read_from(read)?,
310            tunnel_addr: SocketAddr::read_from(read)?,
311            signature: [0u8; 32],
312        };
313
314        if read.read(&mut res.signature[..])? != 32 {
315            return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "missing signature"));
316        }
317
318        Ok(res)
319    }
320}
321
322impl AgentRegisterV1 {
323    pub fn upgrade(self) -> AgentRegister {
324        AgentRegister {
325            proto_version: 1,
326            account_id: self.account_id,
327            agent_id: self.agent_id,
328            agent_version: self.agent_version,
329            timestamp: self.timestamp,
330            client_addr: self.client_addr,
331            tunnel_addr: self.tunnel_addr,
332            signature: self.signature,
333        }
334    }
335}
336
337#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
338pub enum ControlResponse {
339    Pong(Pong),
340    InvalidSignature,
341    Unauthorized,
342    RequestQueued,
343    TryAgainLater,
344    AgentRegistered(AgentRegistered),
345    AgentPortMapping(AgentPortMapping),
346    UdpChannelDetails(UdpChannelDetails),
347}
348
349impl MessageEncoding for ControlResponse {
350    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
351        let mut sum = 0;
352
353        match self {
354            ControlResponse::Pong(data) => {
355                sum += 1u32.write_to(out)?;
356                sum += data.write_to(out)?;
357            }
358            ControlResponse::InvalidSignature => {
359                sum += 2u32.write_to(out)?;
360            }
361            ControlResponse::Unauthorized => {
362                sum += 3u32.write_to(out)?;
363            }
364            ControlResponse::RequestQueued => {
365                sum += 4u32.write_to(out)?;
366            }
367            ControlResponse::TryAgainLater => {
368                sum += 5u32.write_to(out)?;
369            }
370            ControlResponse::AgentRegistered(data) => {
371                sum += 6u32.write_to(out)?;
372                sum += data.write_to(out)?;
373            }
374            ControlResponse::AgentPortMapping(data) => {
375                sum += 7u32.write_to(out)?;
376                sum += data.write_to(out)?;
377            }
378            ControlResponse::UdpChannelDetails(data) => {
379                sum += 8u32.write_to(out)?;
380                sum += data.write_to(out)?;
381            }
382        }
383
384        Ok(sum)
385    }
386
387    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
388        match read.read_u32::<BigEndian>()? {
389            1 => Ok(ControlResponse::Pong(Pong::read_from(read)?)),
390            2 => Ok(ControlResponse::InvalidSignature),
391            3 => Ok(ControlResponse::Unauthorized),
392            4 => Ok(ControlResponse::RequestQueued),
393            5 => Ok(ControlResponse::TryAgainLater),
394            6 => Ok(ControlResponse::AgentRegistered(AgentRegistered::read_from(read)?)),
395            7 => Ok(ControlResponse::AgentPortMapping(AgentPortMapping::read_from(read)?)),
396            8 => Ok(ControlResponse::UdpChannelDetails(UdpChannelDetails::read_from(read)?)),
397            _ => Err(std::io::Error::other("invalid ControlResponse id")),
398        }
399    }
400}
401
402#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
403pub struct AgentPortMapping {
404    pub range: PortRange,
405    pub found: Option<AgentPortMappingFound>,
406}
407
408impl MessageEncoding for AgentPortMapping {
409    const MAX_SIZE: Option<usize> = Some(
410        m_max::<PortRange>() +
411        m_max::<Option<AgentPortMappingFound>>()
412    );
413
414    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
415        let mut sum = 0;
416        sum += self.range.write_to(out)?;
417        sum += self.found.write_to(out)?;
418        Ok(sum)
419    }
420
421    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
422        Ok(AgentPortMapping {
423            range: PortRange::read_from(read)?,
424            found: Option::<AgentPortMappingFound>::read_from(read)?,
425        })
426    }
427}
428
429#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
430pub enum AgentPortMappingFound {
431    ToAgent(AgentSessionId),
432}
433
434impl MessageEncoding for AgentPortMappingFound {
435    const MAX_SIZE: Option<usize> = Some(4 + m_max_list(&[
436        m_max::<AgentSessionId>(),
437    ]));
438    
439    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
440        let mut sum = 0;
441
442        match self {
443            AgentPortMappingFound::ToAgent(id) => {
444                sum += 1u32.write_to(out)?;
445                sum += id.write_to(out)?;
446            }
447        }
448
449        Ok(sum)
450    }
451
452    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
453        match read.read_u32::<BigEndian>()? {
454            1 => Ok(AgentPortMappingFound::ToAgent(AgentSessionId::read_from(read)?)),
455            _ => Err(std::io::Error::new(std::io::ErrorKind::Other, "unknown AgentPortMappingFound id")),
456        }
457    }
458}
459
/// Payload of [`ControlResponse::UdpChannelDetails`].
///
/// `Debug` and `Serialize` are implemented by hand for this type.
#[derive(Clone, Eq, PartialEq)]
pub struct UdpChannelDetails {
    pub tunnel_addr: SocketAddr,
    /// Shared byte token; printed as hex by the `Debug` impl.
    pub token: Arc<Vec<u8>>,
}
465
466impl Serialize for UdpChannelDetails {
467    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer {
468        let mut s = serializer.serialize_struct("UdpChannelDetails", 2)?;
469        s.serialize_field("tunnel_addr", &self.tunnel_addr)?;
470        s.serialize_field("token", &*self.token)?;
471        s.end()
472    }
473}
474
475impl Debug for UdpChannelDetails {
476    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477        f.debug_struct("UdpChannelDetails")
478            .field("tunnel_addr", &self.tunnel_addr)
479            .field("token", &hex::encode(&self.token[..]))
480            .finish()
481    }
482}
483
484impl MessageEncoding for UdpChannelDetails {
485    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
486        let mut sum = 0;
487        sum += self.tunnel_addr.write_to(out)?;
488        sum += self.token.write_to(out)?;
489        Ok(sum)
490    }
491
492    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
493        Ok(UdpChannelDetails {
494            tunnel_addr: SocketAddr::read_from(read)?,
495            token: Arc::new(Vec::read_from(read)?),
496        })
497    }
498}
499
500#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
501pub struct Pong {
502    pub request_now: u64,
503    pub server_now: u64,
504    pub server_id: u64,
505    pub data_center_id: u32,
506    pub client_addr: SocketAddr,
507    pub tunnel_addr: SocketAddr,
508    pub session_expire_at: Option<u64>,
509}
510
511impl MessageEncoding for Pong {
512    const MAX_SIZE: Option<usize> = Some(
513        m_static::<u64>() * 3 +
514        m_static::<u32>() +
515        m_max::<SocketAddr>() * 2 +
516        m_static::<Option<u64>>()
517    );
518
519    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
520        let mut sum = 0;
521        sum += self.request_now.write_to(out)?;
522        sum += self.server_now.write_to(out)?;
523        sum += self.server_id.write_to(out)?;
524        sum += self.data_center_id.write_to(out)?;
525        sum += self.client_addr.write_to(out)?;
526        sum += self.tunnel_addr.write_to(out)?;
527        sum += self.session_expire_at.write_to(out)?;
528        Ok(sum)
529    }
530
531    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
532        Ok(Pong {
533            request_now: read.read_u64::<BigEndian>()?,
534            server_now: read.read_u64::<BigEndian>()?,
535            server_id: read.read_u64::<BigEndian>()?,
536            data_center_id: read.read_u32::<BigEndian>()?,
537            client_addr: SocketAddr::read_from(read)?,
538            tunnel_addr: SocketAddr::read_from(read)?,
539            session_expire_at: Option::read_from(read)?,
540        })
541    }
542}
543
544#[derive(Debug, Eq, PartialEq, Clone, Serialize)]
545pub struct AgentRegistered {
546    pub id: AgentSessionId,
547    pub expires_at: u64,
548}
549
550impl MessageEncoding for AgentRegistered {
551    const STATIC_SIZE: Option<usize> = Some(
552        m_static::<AgentSessionId>() +
553        m_static::<u64>()
554    );
555
556    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
557        let mut sum = 0;
558        sum += self.id.write_to(out)?;
559        sum += self.expires_at.write_to(out)?;
560        Ok(sum)
561    }
562
563    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
564        Ok(AgentRegistered {
565            id: AgentSessionId::read_from(read)?,
566            expires_at: read.read_u64::<BigEndian>()?,
567        })
568    }
569}
570
571#[cfg(test)]
572mod test {
573    use std::fmt::Debug;
574    use std::net::{IpAddr, Ipv4Addr};
575
576    use rand::{Rng, RngCore, thread_rng};
577
578    use crate::PortProto;
579    use crate::rpc::ControlRpcMessage;
580
581    use super::*;
582
583    #[test]
584    fn agent_register_sign_test() {
585        let mut reg = AgentRegister {
586            proto_version: 0,
587            account_id: 1,
588            agent_id: 2,
589            agent_version: 3,
590            timestamp: 1000,
591            client_addr: "10.20.30.40:5678".parse().unwrap(),
592            tunnel_addr: "9.20.3.40:9912".parse().unwrap(),
593            signature: [0u8; 32],
594        };
595
596        let hmac = HmacSha256::create("this is a super secret secret".as_bytes());
597
598        let mut buffer = Vec::new();
599        reg.update_signature(&mut buffer, &hmac);
600        assert!(reg.verify_signature(&mut buffer, &hmac));
601
602        reg.proto_version = 1;
603        reg.update_signature(&mut buffer, &hmac);
604        assert!(reg.verify_signature(&mut buffer, &hmac));
605    }
606
607    #[test]
608    fn agent_register_old_proto_decode() {
609        let reg = AgentRegisterV1 {
610            account_id: 1,
611            agent_id: 2,
612            agent_version: 3,
613            timestamp: 1000,
614            client_addr: "10.20.30.40:5678".parse().unwrap(),
615            tunnel_addr: "9.20.3.40:9912".parse().unwrap(),
616            signature: [0u8; 32],
617        };
618
619        let mut out = Vec::new();
620        ControlRequestId::AgentRegisterV1.write_to(&mut out).unwrap();
621        reg.write_to(&mut out).unwrap();
622
623        let mut reader = &out[..];
624        let read = ControlRequest::read_from(&mut reader).unwrap();
625        assert_eq!(read, ControlRequest::AgentRegister(AgentRegister {
626            proto_version: 1,
627            account_id: 1,
628            agent_id: 2,
629            agent_version: 3,
630            timestamp: 1000,
631            client_addr: "10.20.30.40:5678".parse().unwrap(),
632            tunnel_addr: "9.20.3.40:9912".parse().unwrap(),
633            signature: [0u8; 32],
634        }))
635    }
636
637    #[test]
638    fn fuzzy_test_control_request() {
639        let mut rng = thread_rng();
640        let mut buffer = vec![0u8; 2048];
641
642        for _ in 0..100000 {
643            let msg = rng_control_request(&mut rng);
644            test_encoding(msg, &mut buffer);
645        }
646
647        for _ in 0..1000 {
648            test_encoding(ControlRpcMessage {
649                request_id: rng.next_u64(),
650                content: rng_control_request(&mut rng),
651            }, &mut buffer);
652        }
653    }
654
655    #[test]
656    fn fuzzy_test_control_response() {
657        let mut rng = thread_rng();
658        let mut buffer = vec![0u8; 2048];
659
660        for _ in 0..100000 {
661            let msg = rng_control_response(&mut rng);
662            test_encoding(msg, &mut buffer);
663        }
664
665        for _ in 0..1000 {
666            test_encoding(ControlRpcMessage {
667                request_id: rng.next_u64(),
668                content: rng_control_response(&mut rng),
669            }, &mut buffer);
670        }
671    }
672
673    fn test_encoding<T: MessageEncoding + PartialEq + Debug>(msg: T, buffer: &mut [u8]) {
674        assert_eq!(0, T::_ASSERT);
675
676        let mut writer = &mut buffer[..];
677        msg.write_to(&mut writer).unwrap();
678
679        let remaining_len = writer.len();
680        let written = buffer.len() - remaining_len;
681
682        if let Some(size) =  T::STATIC_SIZE {
683            assert_eq!(written, size);
684        }
685
686        if let Some(size) = T::MAX_SIZE {
687            assert!(written <= size);
688        }
689
690        let mut reader = &buffer[0..written];
691        let recovered = T::read_from(&mut reader).unwrap();
692
693        assert_eq!(msg, recovered);
694    }
695
696    pub fn rng_control_request<R: RngCore>(rng: &mut R) -> ControlRequest {
697        match rng.next_u32() % 5 {
698            0 => ControlRequest::Ping(Ping {
699                now: rng.next_u64(),
700                current_ping: if rng.next_u32() % 2 == 0 {
701                    Some(rng.next_u32())
702                } else {
703                    None
704                },
705                session_id: if rng.next_u32() % 2 == 0 {
706                    Some(AgentSessionId {
707                        session_id: rng.next_u64(),
708                        account_id: rng.next_u64() % (i64::MAX as u64),
709                        agent_id: rng.next_u64(),
710                    })
711                } else {
712                    None
713                },
714            }),
715            1 => ControlRequest::AgentRegister(AgentRegister {
716                proto_version: 1 + rng.next_u64() % 2,
717                account_id: rng.next_u64() % (i64::MAX as u64),
718                agent_id: rng.next_u64(),
719                agent_version: rng.next_u64(),
720                timestamp: rng.next_u64(),
721                client_addr: rng_socket_address(rng),
722                tunnel_addr: rng_socket_address(rng),
723                signature: {
724                    let mut bytes = [0u8; 32];
725                    rng.fill(&mut bytes);
726                    bytes
727                },
728            }),
729            2 => ControlRequest::AgentKeepAlive(AgentSessionId {
730                session_id: rng.next_u64(),
731                account_id: rng.next_u64() % (i64::MAX as u64),
732                agent_id: rng.next_u64(),
733            }),
734            3 => ControlRequest::SetupUdpChannel(AgentSessionId {
735                session_id: rng.next_u64(),
736                account_id: rng.next_u64() % (i64::MAX as u64),
737                agent_id: rng.next_u64(),
738            }),
739            4 => ControlRequest::AgentCheckPortMapping(AgentCheckPortMapping {
740                agent_session_id: AgentSessionId {
741                    session_id: rng.next_u64(),
742                    account_id: rng.next_u64() % (i64::MAX as u64),
743                    agent_id: rng.next_u64(),
744                },
745                port_range: PortRange {
746                    ip: match rng.next_u32() % 2 {
747                        0 => IpAddr::V4(Ipv4Addr::from(rng.next_u32())),
748                        1 => IpAddr::V6({
749                            let mut bytes = [0u8; 16];
750                            rng.fill(&mut bytes);
751                            bytes.into()
752                        }),
753                        _ => unreachable!(),
754                    },
755                    port_start: rng.next_u32() as u16,
756                    port_end: rng.next_u32() as u16,
757                    port_proto: match rng.next_u32() % 3 {
758                        0 => PortProto::Tcp,
759                        1 => PortProto::Udp,
760                        2 => PortProto::Both,
761                        _ => unreachable!(),
762                    },
763                },
764            }),
765            _ => unreachable!(),
766        }
767    }
768
769    pub fn rng_control_response<R: RngCore>(rng: &mut R) -> ControlResponse {
770        match rng.next_u32() % 8 {
771            0 => ControlResponse::Pong(Pong {
772                request_now: rng.next_u64(),
773                server_now: rng.next_u64(),
774                server_id: rng.next_u64(),
775                data_center_id: rng.next_u32(),
776                client_addr: rng_socket_address(rng),
777                tunnel_addr: rng_socket_address(rng),
778                session_expire_at: if rng.next_u32() % 2 == 1 {
779                    Some(rng.next_u64())
780                } else {
781                    None
782                },
783            }),
784            1 => ControlResponse::InvalidSignature,
785            2 => ControlResponse::Unauthorized,
786            3 => ControlResponse::RequestQueued,
787            4 => ControlResponse::TryAgainLater,
788            5 => ControlResponse::AgentRegistered(AgentRegistered {
789                id: AgentSessionId {
790                    session_id: rng.next_u64(),
791                    account_id: rng.next_u64() % (i64::MAX as u64),
792                    agent_id: rng.next_u64(),
793                },
794                expires_at: rng.next_u64(),
795            }),
796            6 => ControlResponse::AgentPortMapping(AgentPortMapping {
797                range: PortRange {
798                    ip: match rng.next_u32() % 2 {
799                        0 => IpAddr::V4(Ipv4Addr::from(rng.next_u32())),
800                        1 => IpAddr::V6({
801                            let mut bytes = [0u8; 16];
802                            rng.fill(&mut bytes);
803                            bytes.into()
804                        }),
805                        _ => unreachable!(),
806                    },
807                    port_start: rng.next_u32() as u16,
808                    port_end: rng.next_u32() as u16,
809                    port_proto: match rng.next_u32() % 3 {
810                        0 => PortProto::Tcp,
811                        1 => PortProto::Udp,
812                        2 => PortProto::Both,
813                        _ => unreachable!(),
814                    },
815                },
816                found: match rng.next_u32() % 2 {
817                    0 => None,
818                    1 => Some(AgentPortMappingFound::ToAgent(AgentSessionId {
819                        session_id: rng.next_u64(),
820                        account_id: rng.next_u64() % (i64::MAX as u64),
821                        agent_id: rng.next_u64(),
822                    })),
823                    _ => unreachable!()
824                },
825            }),
826            7 => ControlResponse::UdpChannelDetails(UdpChannelDetails {
827                tunnel_addr: rng_socket_address(rng),
828                token: {
829                    let len = ((rng.next_u64() % 30) + 32) as usize;
830                    let mut buffer = vec![0u8; len];
831                    rng.fill_bytes(&mut buffer);
832                    Arc::new(buffer)
833                },
834            }),
835            _ => unreachable!()
836        }
837    }
838
839    fn rng_socket_address<R: RngCore>(rng: &mut R) -> SocketAddr {
840        SocketAddr::new(
841            match rng.next_u32() % 2 {
842                0 => IpAddr::V4(Ipv4Addr::from(rng.next_u32())),
843                1 => IpAddr::V6({
844                    let mut bytes = [0u8; 16];
845                    rng.fill(&mut bytes);
846                    bytes.into()
847                }),
848                _ => unreachable!(),
849            },
850            rng.next_u32() as u16,
851        )
852    }
853
854    #[test]
855    fn agent_register_v1_ip4_same_encoding_test() {
856        let mut msg = AgentRegister {
857            account_id: 100,
858            agent_id: 32,
859            agent_version: 676,
860            timestamp: 103201401,
861            client_addr: "127.0.0.1:4123".parse().unwrap(),
862            tunnel_addr: "99.12.34.51:5312".parse().unwrap(),
863            signature: [0u8; 32],
864            proto_version: 1,
865        };
866
867        let sig = HmacSha256::create("test-secret-hehehe".as_bytes());
868        let mut buffer = Vec::new();
869        msg.update_signature(&mut buffer, &sig);
870        assert!(msg.verify_signature(&mut buffer, &sig));
871
872        buffer.clear();
873        msg.write_to(&mut buffer).unwrap();
874
875        let hex_buffer = hex::encode(&buffer);
876        assert_eq!(hex_buffer, "0000000000000064000000000000002000000000000002a4000000000626ba79047f000001101b04630c223314c0767a59319b8edfcc1e6f3d3ea2d19ac74a74e5f5333c9b335adc72cda821de5f");
877    }
878
879    #[test]
880    fn agent_register_v1_ip6_same_encoding_test() {
881        let mut msg = AgentRegister {
882            account_id: 100,
883            agent_id: 32,
884            agent_version: 676,
885            timestamp: 103201401,
886            client_addr: "[::88]:4123".parse().unwrap(),
887            tunnel_addr: "[::99]:5312".parse().unwrap(),
888            signature: [0u8; 32],
889            proto_version: 1,
890        };
891
892        let sig = HmacSha256::create("test-secret-hehehe".as_bytes());
893        let mut buffer = Vec::new();
894        msg.update_signature(&mut buffer, &sig);
895        assert!(msg.verify_signature(&mut buffer, &sig));
896
897        buffer.clear();
898        msg.write_to(&mut buffer).unwrap();
899
900        let hex_buffer = hex::encode(&buffer);
901        assert_eq!(hex_buffer, "0000000000000064000000000000002000000000000002a4000000000626ba790600000000000000000000000000000088101b060000000000000000000000000000009914c0724f203e7ac2f090800dbeb68afbf184f367f9ca14d8a0082e245070c3835c4b");
902    }
903
904    #[test]
905    fn legacy_mc_java_ping_decode_test() {
906        let data = hex::decode("000000000000000100000001000000000000000000").unwrap();
907        let mut reader = &data[..];
908
909        let msg = ControlRpcMessage::<ControlRequest>::read_from(&mut reader).unwrap();
910        assert_eq!(msg, ControlRpcMessage {
911            request_id: 1,
912            content: ControlRequest::Ping(Ping {
913                now: 0,
914                current_ping: None,
915                session_id: None,
916            }),
917        });
918        println!("Got msg: {msg:?}");
919    }
920}