nt_network/
codec.rs

1use crate::ext::*;
2use crate::{
3    ClearAllEntries, ClientHello, EntryAssignment, EntryDelete, EntryFlagsUpdate, EntryUpdate,
4    Packet, ProtocolVersionUnsupported, Result, RpcExecute, RpcResponse, ServerHello,
5};
6use anyhow::anyhow;
7use bytes::{Buf, BytesMut};
8use std::io;
9use tokio_util::codec::{Decoder, Encoder};
10
11#[derive(Clone, Debug)]
12pub enum ReceivedPacket {
13    KeepAlive,
14    ClientHello(ClientHello),
15    ProtocolVersionUnsupported(ProtocolVersionUnsupported),
16    ServerHelloComplete,
17    ServerHello(ServerHello),
18    ClientHelloComplete,
19    EntryAssignment(EntryAssignment),
20    EntryUpdate(EntryUpdate),
21    EntryFlagsUpdate(EntryFlagsUpdate),
22    EntryDelete(EntryDelete),
23    ClearAllEntries(ClearAllEntries),
24    RpcExecute(RpcExecute),
25    RpcResponse(RpcResponse),
26}
27
/// Stateless codec for the packet stream.
///
/// As an `Encoder` it writes outgoing `Box<dyn Packet>` values; as a
/// `Decoder` it turns incoming bytes into [`ReceivedPacket`]s. It carries no
/// state, so it is freely copyable and default-constructible.
#[derive(Clone, Copy, Debug, Default)]
pub struct NTCodec;
29
30impl Encoder for NTCodec {
31    type Item = Box<dyn Packet>;
32    type Error = anyhow::Error;
33
34    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<()> {
35        dst.put_serializable(&*item);
36        Ok(())
37    }
38}
39
40impl Decoder for NTCodec {
41    type Item = ReceivedPacket;
42    type Error = anyhow::Error;
43
44    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<ReceivedPacket>> {
45        let mut buf = src.clone().freeze();
46
47        if buf.remaining() < 1 {
48            return Ok(None);
49        }
50
51        let (packet, bytes) = match try_decode(&mut buf) {
52            Ok(t) => t,
53            Err(e) => match e.downcast_ref::<io::Error>() {
54                Some(err) if err.kind() == io::ErrorKind::UnexpectedEof => {
55                    return Ok(None);
56                }
57                _ => return Err(e),
58            },
59        };
60
61        src.advance(bytes);
62        Ok(Some(packet))
63    }
64}
65
66fn try_decode(mut buf: &mut dyn Buf) -> Result<(ReceivedPacket, usize)> {
67    let id = buf.read_u8()?;
68
69    let mut bytes = 1;
70
71    let packet = match id {
72        0x00 => Some(ReceivedPacket::KeepAlive),
73        0x01 => {
74            let (packet, read) = ClientHello::deserialize(buf)?;
75            bytes += read;
76            Some(ReceivedPacket::ClientHello(packet))
77        }
78        0x02 => {
79            let (packet, read) = ProtocolVersionUnsupported::deserialize(buf)?;
80            bytes += read;
81            Some(ReceivedPacket::ProtocolVersionUnsupported(packet))
82        }
83        0x03 => Some(ReceivedPacket::ServerHelloComplete),
84        0x04 => {
85            let (packet, read) = ServerHello::deserialize(buf)?;
86            bytes += read;
87            Some(ReceivedPacket::ServerHello(packet))
88        }
89        0x05 => Some(ReceivedPacket::ClientHelloComplete),
90        0x10 => {
91            let (packet, read) = EntryAssignment::deserialize(buf)?;
92            bytes += read;
93            Some(ReceivedPacket::EntryAssignment(packet))
94        }
95        0x11 => {
96            let (packet, read) = EntryUpdate::deserialize(buf)?;
97            bytes += read;
98            Some(ReceivedPacket::EntryUpdate(packet))
99        }
100        0x12 => {
101            let (packet, read) = EntryFlagsUpdate::deserialize(buf)?;
102            bytes += read;
103            Some(ReceivedPacket::EntryFlagsUpdate(packet))
104        }
105        0x13 => {
106            let (packet, read) = EntryDelete::deserialize(buf)?;
107            bytes += read;
108            Some(ReceivedPacket::EntryDelete(packet))
109        }
110        0x14 => {
111            let (packet, read) = ClearAllEntries::deserialize(buf)?;
112            bytes += read;
113            Some(ReceivedPacket::ClearAllEntries(packet))
114        }
115        0x20 => {
116            let (packet, read) = RpcExecute::deserialize(buf)?;
117            bytes += read;
118            Some(ReceivedPacket::RpcExecute(packet))
119        }
120        0x21 => {
121            let (packet, read) = RpcResponse::deserialize(buf)?;
122            bytes += read;
123            Some(ReceivedPacket::RpcResponse(packet))
124        }
125        _ => None,
126    };
127
128    match packet {
129        Some(packet) => Ok((packet, bytes)),
130        None => Err(anyhow!("Failed to decode packet")),
131    }
132}