1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
use tokio_codec::{Encoder, Decoder}; use bytes::{BytesMut, Buf, IntoBuf}; use crate::ext::*; use crate::{Result, ClientHello, ProtocolVersionUnsupported, ServerHello, EntryAssignment, EntryUpdate, EntryFlagsUpdate, EntryDelete, ClearAllEntries, Packet}; use std::io; #[derive(Clone, Debug)] pub enum ReceivedPacket { KeepAlive, ClientHello(ClientHello), ProtocolVersionUnsupported(ProtocolVersionUnsupported), ServerHelloComplete, ServerHello(ServerHello), ClientHelloComplete, EntryAssignment(EntryAssignment), EntryUpdate(EntryUpdate), EntryFlagsUpdate(EntryFlagsUpdate), EntryDelete(EntryDelete), ClearAllEntries(ClearAllEntries) } pub struct NTCodec; impl Encoder for NTCodec { type Item = Box<dyn Packet>; type Error = failure::Error; fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<()> { dst.put_serializable(&*item); Ok(()) } } impl Decoder for NTCodec { type Item = ReceivedPacket; type Error = failure::Error; fn decode(&mut self, src: &mut BytesMut) -> Result<Option<ReceivedPacket>> { let mut buf = src.clone().freeze().into_buf(); if buf.remaining() < 1 { return Ok(None); } let (packet, bytes) = match try_decode(&mut buf) { Ok(t) => t, Err(e) => match e.find_root_cause().downcast_ref::<io::Error>() { Some(err) => if err.kind() == io::ErrorKind::UnexpectedEof { return Ok(None); }else { return Err(e); }, None => return Err(e) } }; src.advance(bytes); Ok(Some(packet)) } } fn try_decode(mut buf: &mut dyn Buf) -> Result<(ReceivedPacket, usize)> { let id = buf.read_u8()?; let mut bytes = 1; let packet = match id { 0x00 => Some(ReceivedPacket::KeepAlive), 0x01 => { let (packet, read) = ClientHello::deserialize(buf)?; bytes += read; Some(ReceivedPacket::ClientHello(packet)) } 0x02 => { let (packet, read) = ProtocolVersionUnsupported::deserialize(buf)?; bytes += read; Some(ReceivedPacket::ProtocolVersionUnsupported(packet)) } 0x03 => Some(ReceivedPacket::ServerHelloComplete), 0x04 => { let (packet, read) = ServerHello::deserialize(buf)?; bytes += read; Some(ReceivedPacket::ServerHello(packet)) } 0x05 => Some(ReceivedPacket::ClientHelloComplete), 0x10 => { let (packet, read) = EntryAssignment::deserialize(buf)?; bytes += read; Some(ReceivedPacket::EntryAssignment(packet)) } 0x11 => { let (packet, read) = EntryUpdate::deserialize(buf)?; bytes += read; Some(ReceivedPacket::EntryUpdate(packet)) } 0x12 => { let (packet, read) = EntryFlagsUpdate::deserialize(buf)?; bytes += read; Some(ReceivedPacket::EntryFlagsUpdate(packet)) } 0x13 => { let (packet, read) = EntryDelete::deserialize(buf)?; bytes += read; Some(ReceivedPacket::EntryDelete(packet)) } 0x14 => { let (packet, read) = ClearAllEntries::deserialize(buf)?; bytes += read; Some(ReceivedPacket::ClearAllEntries(packet)) } _ => None }; Ok((packet.unwrap(), bytes)) }