adns_proto/
header.rs

1use smallvec::SmallVec;
2
/// Direction of a DNS message: a query, or a response to one.
///
/// `Header::parse` derives this from the most significant bit of the flags
/// word (clear means `Query`, set means `Response`), and `Header::to_bytes`
/// writes it back to the same bit.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum QueryResponse {
    /// The message is a query. This is the default value.
    #[default]
    Query = 0,
    /// The message is a response.
    Response = 1,
}
11
/// Operation code carried in a DNS header.
///
/// Values that have no dedicated variant are kept verbatim in
/// [`Opcode::Other`], so every `u8` converts to an `Opcode` and back to the
/// same `u8`. `Header::validate` rejects headers whose opcode is `Other`.
#[derive(Default, Clone, Copy, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum Opcode {
    /// Opcode `0`. This is the default value.
    #[default]
    Query,
    /// Opcode `1`.
    InverseQuery,
    /// Opcode `2`.
    Status,
    /// Opcode `5`.
    Update,
    /// Any opcode without a dedicated variant, stored as the raw value.
    Other(u8),
}

impl From<u8> for Opcode {
    /// Maps a raw opcode value to its variant.
    ///
    /// The conversion is total, as `From` requires: `0`, `1`, `2` and `5`
    /// map to their named variants and every other value, including values
    /// that do not fit the four-bit header field, becomes
    /// [`Opcode::Other`] instead of panicking.
    fn from(value: u8) -> Self {
        match value {
            0 => Opcode::Query,
            1 => Opcode::InverseQuery,
            2 => Opcode::Status,
            5 => Opcode::Update,
            unassigned => Opcode::Other(unassigned),
        }
    }
}

impl From<Opcode> for u8 {
    /// Returns the raw opcode value; [`Opcode::Other`] yields the value it holds.
    fn from(value: Opcode) -> u8 {
        match value {
            Opcode::Query => 0,
            Opcode::InverseQuery => 1,
            Opcode::Status => 2,
            Opcode::Update => 5,
            Opcode::Other(raw) => raw,
        }
    }
}
48
/// Response code carried in a DNS header.
///
/// Values `0` through `10` have dedicated variants; any other value is kept
/// verbatim in [`ResponseCode::Other`], so every `u8` converts to a
/// `ResponseCode` and back to the same `u8`. `Header::validate` rejects
/// headers whose response code is `Other`.
#[derive(Default, Clone, Copy, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum ResponseCode {
    /// Response code `0`. This is the default value.
    #[default]
    NoError,
    /// Response code `1`.
    FormatError,
    /// Response code `2`.
    ServerFailure,
    /// Response code `3`.
    NameError,
    /// Response code `4`.
    NotImplemented,
    /// Response code `5`.
    Refused,
    /// Response code `6`.
    YxDomain,
    /// Response code `7`.
    YxRRSet,
    /// Response code `8`.
    NxRRSet,
    /// Response code `9`.
    NotAuth,
    /// Response code `10`.
    NotZone,
    /// Any response code without a dedicated variant, stored as the raw value.
    Other(u8),
}

impl From<u8> for ResponseCode {
    /// Maps a raw response code to its variant.
    ///
    /// The conversion is total, as `From` requires: `0..=10` map to their
    /// named variants and every other value, including values that do not
    /// fit the four-bit header field, becomes [`ResponseCode::Other`]
    /// instead of panicking.
    fn from(value: u8) -> Self {
        match value {
            0 => ResponseCode::NoError,
            1 => ResponseCode::FormatError,
            2 => ResponseCode::ServerFailure,
            3 => ResponseCode::NameError,
            4 => ResponseCode::NotImplemented,
            5 => ResponseCode::Refused,
            6 => ResponseCode::YxDomain,
            7 => ResponseCode::YxRRSet,
            8 => ResponseCode::NxRRSet,
            9 => ResponseCode::NotAuth,
            10 => ResponseCode::NotZone,
            unassigned => ResponseCode::Other(unassigned),
        }
    }
}

impl From<ResponseCode> for u8 {
    /// Returns the raw response code; [`ResponseCode::Other`] yields the value it holds.
    fn from(value: ResponseCode) -> u8 {
        match value {
            ResponseCode::NoError => 0,
            ResponseCode::FormatError => 1,
            ResponseCode::ServerFailure => 2,
            ResponseCode::NameError => 3,
            ResponseCode::NotImplemented => 4,
            ResponseCode::Refused => 5,
            ResponseCode::YxDomain => 6,
            ResponseCode::YxRRSet => 7,
            ResponseCode::NxRRSet => 8,
            ResponseCode::NotAuth => 9,
            ResponseCode::NotZone => 10,
            ResponseCode::Other(raw) => raw,
        }
    }
}
106
107#[derive(Default, Clone, Debug)]
108#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
109pub struct Header {
110    pub id: u16,
111    pub query_response: QueryResponse,
112    pub opcode: Opcode,
113    pub is_authoritative: bool,
114    pub is_truncated: bool,
115    pub recursion_desired: bool,
116    pub recursion_available: bool,
117    pub reserved: u8,
118    pub response_code: ResponseCode,
119    pub question_count: u16,
120    pub answer_count: u16,
121    pub nameserver_count: u16,
122    pub additional_record_count: u16,
123}
124
125impl Header {
126    pub const LENGTH: usize = 12;
127
128    pub(crate) fn validate(&self) -> bool {
129        if matches!(self.opcode, Opcode::Other(_)) {
130            return false;
131        }
132        if matches!(self.response_code, ResponseCode::Other(_)) {
133            return false;
134        }
135        true
136    }
137
138    pub(crate) fn parse(data: [u8; Self::LENGTH]) -> Self {
139        let flags = u16::from_be_bytes(data[2..4].try_into().unwrap());
140        Self {
141            id: u16::from_be_bytes(data[..2].try_into().unwrap()),
142            query_response: if flags >> 15 & 0b1 == 0 {
143                QueryResponse::Query
144            } else {
145                QueryResponse::Response
146            },
147            opcode: Opcode::from((flags >> 11) as u8 & 0b1111),
148            is_authoritative: flags >> 10 & 0b1 != 0,
149            is_truncated: flags >> 9 & 0b1 != 0,
150            recursion_desired: flags >> 8 & 0b1 != 0,
151            recursion_available: flags >> 7 & 0b1 != 0,
152            reserved: (flags >> 4 & 0b111) as u8,
153            response_code: ResponseCode::from((flags & 0b1111) as u8),
154            question_count: u16::from_be_bytes(data[4..6].try_into().unwrap()),
155            answer_count: u16::from_be_bytes(data[6..8].try_into().unwrap()),
156            nameserver_count: u16::from_be_bytes(data[8..10].try_into().unwrap()),
157            additional_record_count: u16::from_be_bytes(data[10..12].try_into().unwrap()),
158        }
159    }
160
161    pub fn to_bytes(&self) -> [u8; 12] {
162        let mut output: SmallVec<[u8; Self::LENGTH]> = SmallVec::new();
163        output.extend(self.id.to_be_bytes());
164        let mut flags = 0u16;
165        if self.query_response == QueryResponse::Response {
166            flags |= 0b1 << 15;
167        }
168        let opcode: u8 = self.opcode.into();
169        let response_code: u8 = self.response_code.into();
170        flags |= (opcode as u16 & 0b1111) << 11;
171        flags |= (self.is_authoritative as u8 as u16) << 10;
172        flags |= (self.is_truncated as u8 as u16) << 9;
173        flags |= (self.recursion_desired as u8 as u16) << 8;
174        flags |= (self.recursion_available as u8 as u16) << 7;
175        flags |= (self.reserved as u16 & 0b111) << 4;
176        flags |= response_code as u16 & 0b1111;
177        output.extend(flags.to_be_bytes());
178        output.extend(self.question_count.to_be_bytes());
179        output.extend(self.answer_count.to_be_bytes());
180        output.extend(self.nameserver_count.to_be_bytes());
181        output.extend(self.additional_record_count.to_be_bytes());
182        output.into_inner().unwrap()
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_data::*;

    /// Decodes a header from `prefix`, which must be exactly
    /// `Header::LENGTH` bytes long.
    fn decode(prefix: &[u8]) -> Header {
        Header::parse(prefix.try_into().unwrap())
    }

    /// Parses the headers of the sample query and response packets, checks
    /// the decoded fields, and verifies that re-encoding reproduces the
    /// original bytes.
    #[test]
    fn test_header_parse() {
        let query_prefix = &DNS_QUERY[..Header::LENGTH];
        let query = decode(query_prefix);
        assert!(query.validate());
        assert!(!query.is_authoritative);
        assert!(!query.is_truncated);
        assert!(query.recursion_desired);
        assert_eq!(query.query_response, QueryResponse::Query);
        assert_eq!(query.question_count, 1);
        assert_eq!(query.additional_record_count, 1);

        assert_eq!(query_prefix, &query.to_bytes());

        let response_prefix = &DNS_RESPONSE[..Header::LENGTH];
        let response = decode(response_prefix);
        assert!(response.validate());
        assert!(!response.is_authoritative);
        assert!(!response.is_truncated);
        assert!(response.recursion_desired);
        assert!(response.recursion_available);
        assert_eq!(response.response_code, ResponseCode::NoError);
        assert_eq!(response.query_response, QueryResponse::Response);
        assert_eq!(response.question_count, 1);
        assert_eq!(response.answer_count, 1);
        assert_eq!(response.additional_record_count, 1);

        assert_eq!(response_prefix, &response.to_bytes());
    }
}
218}