1use crate::ext::*;
2use crate::{
3 ClearAllEntries, ClientHello, EntryAssignment, EntryDelete, EntryFlagsUpdate, EntryUpdate,
4 Packet, ProtocolVersionUnsupported, Result, RpcExecute, RpcResponse, ServerHello,
5};
6use anyhow::anyhow;
7use bytes::{Buf, BytesMut};
8use std::io;
9use tokio_util::codec::{Decoder, Encoder};
10
11#[derive(Clone, Debug)]
12pub enum ReceivedPacket {
13 KeepAlive,
14 ClientHello(ClientHello),
15 ProtocolVersionUnsupported(ProtocolVersionUnsupported),
16 ServerHelloComplete,
17 ServerHello(ServerHello),
18 ClientHelloComplete,
19 EntryAssignment(EntryAssignment),
20 EntryUpdate(EntryUpdate),
21 EntryFlagsUpdate(EntryFlagsUpdate),
22 EntryDelete(EntryDelete),
23 ClearAllEntries(ClearAllEntries),
24 RpcExecute(RpcExecute),
25 RpcResponse(RpcResponse),
26}
27
/// Stateless codec that encodes outgoing [`Packet`]s and frames incoming
/// bytes into [`ReceivedPacket`]s.
///
/// It holds no state, so it is cheap to copy and can be created with
/// `NTCodec` or `NTCodec::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NTCodec;
29
30impl Encoder for NTCodec {
31 type Item = Box<dyn Packet>;
32 type Error = anyhow::Error;
33
34 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<()> {
35 dst.put_serializable(&*item);
36 Ok(())
37 }
38}
39
40impl Decoder for NTCodec {
41 type Item = ReceivedPacket;
42 type Error = anyhow::Error;
43
44 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<ReceivedPacket>> {
45 let mut buf = src.clone().freeze();
46
47 if buf.remaining() < 1 {
48 return Ok(None);
49 }
50
51 let (packet, bytes) = match try_decode(&mut buf) {
52 Ok(t) => t,
53 Err(e) => match e.downcast_ref::<io::Error>() {
54 Some(err) if err.kind() == io::ErrorKind::UnexpectedEof => {
55 return Ok(None);
56 }
57 _ => return Err(e),
58 },
59 };
60
61 src.advance(bytes);
62 Ok(Some(packet))
63 }
64}
65
66fn try_decode(mut buf: &mut dyn Buf) -> Result<(ReceivedPacket, usize)> {
67 let id = buf.read_u8()?;
68
69 let mut bytes = 1;
70
71 let packet = match id {
72 0x00 => Some(ReceivedPacket::KeepAlive),
73 0x01 => {
74 let (packet, read) = ClientHello::deserialize(buf)?;
75 bytes += read;
76 Some(ReceivedPacket::ClientHello(packet))
77 }
78 0x02 => {
79 let (packet, read) = ProtocolVersionUnsupported::deserialize(buf)?;
80 bytes += read;
81 Some(ReceivedPacket::ProtocolVersionUnsupported(packet))
82 }
83 0x03 => Some(ReceivedPacket::ServerHelloComplete),
84 0x04 => {
85 let (packet, read) = ServerHello::deserialize(buf)?;
86 bytes += read;
87 Some(ReceivedPacket::ServerHello(packet))
88 }
89 0x05 => Some(ReceivedPacket::ClientHelloComplete),
90 0x10 => {
91 let (packet, read) = EntryAssignment::deserialize(buf)?;
92 bytes += read;
93 Some(ReceivedPacket::EntryAssignment(packet))
94 }
95 0x11 => {
96 let (packet, read) = EntryUpdate::deserialize(buf)?;
97 bytes += read;
98 Some(ReceivedPacket::EntryUpdate(packet))
99 }
100 0x12 => {
101 let (packet, read) = EntryFlagsUpdate::deserialize(buf)?;
102 bytes += read;
103 Some(ReceivedPacket::EntryFlagsUpdate(packet))
104 }
105 0x13 => {
106 let (packet, read) = EntryDelete::deserialize(buf)?;
107 bytes += read;
108 Some(ReceivedPacket::EntryDelete(packet))
109 }
110 0x14 => {
111 let (packet, read) = ClearAllEntries::deserialize(buf)?;
112 bytes += read;
113 Some(ReceivedPacket::ClearAllEntries(packet))
114 }
115 0x20 => {
116 let (packet, read) = RpcExecute::deserialize(buf)?;
117 bytes += read;
118 Some(ReceivedPacket::RpcExecute(packet))
119 }
120 0x21 => {
121 let (packet, read) = RpcResponse::deserialize(buf)?;
122 bytes += read;
123 Some(ReceivedPacket::RpcResponse(packet))
124 }
125 _ => None,
126 };
127
128 match packet {
129 Some(packet) => Ok((packet, bytes)),
130 None => Err(anyhow!("Failed to decode packet")),
131 }
132}