1use smallvec::SmallVec;
2
/// The QR bit of a DNS message header: whether the message is a query or a response.
///
/// The default value is [`QueryResponse::Query`].
#[derive(Default, Clone, Copy, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum QueryResponse {
    /// QR bit clear (`0`): the message is a query.
    #[default]
    Query = 0,
    /// QR bit set (`1`): the message is a response.
    Response = 1,
}
11
/// The 4-bit OPCODE field of a DNS message header.
///
/// Opcodes without a dedicated variant are carried verbatim in [`Opcode::Other`], so
/// converting a raw value to `Opcode` and back always yields the original value.
#[derive(Default, Clone, Copy, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum Opcode {
    /// Wire value `0` (QUERY).
    #[default]
    Query,
    /// Wire value `1` (IQUERY).
    InverseQuery,
    /// Wire value `2` (STATUS).
    Status,
    /// Wire value `5` (UPDATE).
    Update,
    /// Any value without a dedicated variant (e.g. `3`, `4`, `6..=15`), kept as-is.
    Other(u8),
}

impl From<u8> for Opcode {
    /// Decodes a raw opcode value.
    ///
    /// The conversion is total, as the `From` contract requires: every value without a
    /// named variant — including values above the 4-bit wire range, which previously
    /// panicked — becomes [`Opcode::Other`]. Use `Header::validate` to reject headers
    /// whose opcode is unknown.
    fn from(value: u8) -> Self {
        match value {
            0 => Opcode::Query,
            1 => Opcode::InverseQuery,
            2 => Opcode::Status,
            5 => Opcode::Update,
            other => Opcode::Other(other),
        }
    }
}

impl From<Opcode> for u8 {
    /// Encodes an opcode as its raw value; the inverse of `Opcode::from(u8)`.
    ///
    /// [`Opcode::Other`] yields its stored value unchanged (it is not masked to 4 bits
    /// here; `Header::to_bytes` masks it when serializing).
    fn from(value: Opcode) -> u8 {
        match value {
            Opcode::Query => 0,
            Opcode::InverseQuery => 1,
            Opcode::Status => 2,
            Opcode::Update => 5,
            Opcode::Other(x) => x,
        }
    }
}
48
/// The 4-bit RCODE field of a DNS message header.
///
/// Codes without a dedicated variant are carried verbatim in [`ResponseCode::Other`],
/// so converting a raw value to `ResponseCode` and back always yields the original value.
#[derive(Default, Clone, Copy, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum ResponseCode {
    /// Wire value `0` (NOERROR).
    #[default]
    NoError,
    /// Wire value `1` (FORMERR).
    FormatError,
    /// Wire value `2` (SERVFAIL).
    ServerFailure,
    /// Wire value `3` (NXDOMAIN / "Name Error").
    NameError,
    /// Wire value `4` (NOTIMP).
    NotImplemented,
    /// Wire value `5` (REFUSED).
    Refused,
    /// Wire value `6` (YXDOMAIN).
    YxDomain,
    /// Wire value `7` (YXRRSET).
    YxRRSet,
    /// Wire value `8` (NXRRSET).
    NxRRSet,
    /// Wire value `9` (NOTAUTH).
    NotAuth,
    /// Wire value `10` (NOTZONE).
    NotZone,
    /// Any value without a dedicated variant (e.g. `11..=15`), kept as-is.
    Other(u8),
}

impl From<u8> for ResponseCode {
    /// Decodes a raw response-code value.
    ///
    /// The conversion is total, as the `From` contract requires: every value without a
    /// named variant — including values above the 4-bit wire range, which previously
    /// panicked — becomes [`ResponseCode::Other`]. Use `Header::validate` to reject
    /// headers whose response code is unknown.
    fn from(value: u8) -> Self {
        match value {
            0 => ResponseCode::NoError,
            1 => ResponseCode::FormatError,
            2 => ResponseCode::ServerFailure,
            3 => ResponseCode::NameError,
            4 => ResponseCode::NotImplemented,
            5 => ResponseCode::Refused,
            6 => ResponseCode::YxDomain,
            7 => ResponseCode::YxRRSet,
            8 => ResponseCode::NxRRSet,
            9 => ResponseCode::NotAuth,
            10 => ResponseCode::NotZone,
            other => ResponseCode::Other(other),
        }
    }
}

impl From<ResponseCode> for u8 {
    /// Encodes a response code as its raw value; the inverse of `ResponseCode::from(u8)`.
    ///
    /// [`ResponseCode::Other`] yields its stored value unchanged (it is not masked to
    /// 4 bits here; `Header::to_bytes` masks it when serializing).
    fn from(value: ResponseCode) -> u8 {
        match value {
            ResponseCode::NoError => 0,
            ResponseCode::FormatError => 1,
            ResponseCode::ServerFailure => 2,
            ResponseCode::NameError => 3,
            ResponseCode::NotImplemented => 4,
            ResponseCode::Refused => 5,
            ResponseCode::YxDomain => 6,
            ResponseCode::YxRRSet => 7,
            ResponseCode::NxRRSet => 8,
            ResponseCode::NotAuth => 9,
            ResponseCode::NotZone => 10,
            ResponseCode::Other(x) => x,
        }
    }
}
106
107#[derive(Default, Clone, Debug)]
108#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
109pub struct Header {
110 pub id: u16,
111 pub query_response: QueryResponse,
112 pub opcode: Opcode,
113 pub is_authoritative: bool,
114 pub is_truncated: bool,
115 pub recursion_desired: bool,
116 pub recursion_available: bool,
117 pub reserved: u8,
118 pub response_code: ResponseCode,
119 pub question_count: u16,
120 pub answer_count: u16,
121 pub nameserver_count: u16,
122 pub additional_record_count: u16,
123}
124
125impl Header {
126 pub const LENGTH: usize = 12;
127
128 pub(crate) fn validate(&self) -> bool {
129 if matches!(self.opcode, Opcode::Other(_)) {
130 return false;
131 }
132 if matches!(self.response_code, ResponseCode::Other(_)) {
133 return false;
134 }
135 true
136 }
137
138 pub(crate) fn parse(data: [u8; Self::LENGTH]) -> Self {
139 let flags = u16::from_be_bytes(data[2..4].try_into().unwrap());
140 Self {
141 id: u16::from_be_bytes(data[..2].try_into().unwrap()),
142 query_response: if flags >> 15 & 0b1 == 0 {
143 QueryResponse::Query
144 } else {
145 QueryResponse::Response
146 },
147 opcode: Opcode::from((flags >> 11) as u8 & 0b1111),
148 is_authoritative: flags >> 10 & 0b1 != 0,
149 is_truncated: flags >> 9 & 0b1 != 0,
150 recursion_desired: flags >> 8 & 0b1 != 0,
151 recursion_available: flags >> 7 & 0b1 != 0,
152 reserved: (flags >> 4 & 0b111) as u8,
153 response_code: ResponseCode::from((flags & 0b1111) as u8),
154 question_count: u16::from_be_bytes(data[4..6].try_into().unwrap()),
155 answer_count: u16::from_be_bytes(data[6..8].try_into().unwrap()),
156 nameserver_count: u16::from_be_bytes(data[8..10].try_into().unwrap()),
157 additional_record_count: u16::from_be_bytes(data[10..12].try_into().unwrap()),
158 }
159 }
160
161 pub fn to_bytes(&self) -> [u8; 12] {
162 let mut output: SmallVec<[u8; Self::LENGTH]> = SmallVec::new();
163 output.extend(self.id.to_be_bytes());
164 let mut flags = 0u16;
165 if self.query_response == QueryResponse::Response {
166 flags |= 0b1 << 15;
167 }
168 let opcode: u8 = self.opcode.into();
169 let response_code: u8 = self.response_code.into();
170 flags |= (opcode as u16 & 0b1111) << 11;
171 flags |= (self.is_authoritative as u8 as u16) << 10;
172 flags |= (self.is_truncated as u8 as u16) << 9;
173 flags |= (self.recursion_desired as u8 as u16) << 8;
174 flags |= (self.recursion_available as u8 as u16) << 7;
175 flags |= (self.reserved as u16 & 0b111) << 4;
176 flags |= response_code as u16 & 0b1111;
177 output.extend(flags.to_be_bytes());
178 output.extend(self.question_count.to_be_bytes());
179 output.extend(self.answer_count.to_be_bytes());
180 output.extend(self.nameserver_count.to_be_bytes());
181 output.extend(self.additional_record_count.to_be_bytes());
182 output.into_inner().unwrap()
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_data::*;

    /// Parses the captured query/response fixtures and checks they re-encode byte-for-byte.
    #[test]
    fn test_header_parse() {
        let header = Header::parse(DNS_QUERY[..Header::LENGTH].try_into().unwrap());
        assert!(header.validate());
        assert!(!header.is_authoritative);
        assert!(!header.is_truncated);
        assert!(header.recursion_desired);
        assert_eq!(header.query_response, QueryResponse::Query);
        assert_eq!(header.question_count, 1);
        assert_eq!(header.additional_record_count, 1);

        assert_eq!(&DNS_QUERY[..Header::LENGTH], &header.to_bytes());

        let header = Header::parse(DNS_RESPONSE[..Header::LENGTH].try_into().unwrap());
        assert!(header.validate());
        assert!(!header.is_authoritative);
        assert!(!header.is_truncated);
        assert!(header.recursion_desired);
        assert!(header.recursion_available);
        assert_eq!(header.response_code, ResponseCode::NoError);
        assert_eq!(header.query_response, QueryResponse::Response);
        assert_eq!(header.question_count, 1);
        assert_eq!(header.answer_count, 1);
        assert_eq!(header.additional_record_count, 1);

        assert_eq!(&DNS_RESPONSE[..Header::LENGTH], &header.to_bytes());
    }

    /// Checks every flag bit and multi-bit field against a hand-computed flags word.
    ///
    /// This covers bits the fixtures leave clear (AA, TC, Z) and a non-query opcode.
    #[test]
    fn test_header_flags_round_trip() {
        let header = Header {
            id: 0xBEEF,
            query_response: QueryResponse::Response,
            opcode: Opcode::Update,
            is_authoritative: true,
            is_truncated: true,
            recursion_desired: false,
            recursion_available: true,
            reserved: 0b101,
            response_code: ResponseCode::NxRRSet,
            question_count: 1,
            answer_count: 2,
            nameserver_count: 3,
            additional_record_count: 4,
        };
        // QR 0x8000 | opcode 5<<11 0x2800 | AA 0x0400 | TC 0x0200 | RA 0x0080
        // | Z 0b101<<4 0x0050 | RCODE 8 = 0xAED8
        let bytes = header.to_bytes();
        assert_eq!(
            bytes,
            [0xBE, 0xEF, 0xAE, 0xD8, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04]
        );

        let parsed = Header::parse(bytes);
        assert!(parsed.validate());
        assert_eq!(parsed.id, 0xBEEF);
        assert_eq!(parsed.query_response, QueryResponse::Response);
        assert_eq!(parsed.opcode, Opcode::Update);
        assert!(parsed.is_authoritative);
        assert!(parsed.is_truncated);
        assert!(!parsed.recursion_desired);
        assert!(parsed.recursion_available);
        assert_eq!(parsed.reserved, 0b101);
        assert_eq!(parsed.response_code, ResponseCode::NxRRSet);
        assert_eq!(parsed.question_count, 1);
        assert_eq!(parsed.answer_count, 2);
        assert_eq!(parsed.nameserver_count, 3);
        assert_eq!(parsed.additional_record_count, 4);
    }

    /// Unknown opcodes/response codes parse and re-encode, but fail validation.
    #[test]
    fn test_header_validate_rejects_unknown_codes() {
        assert!(Header::default().validate());

        let header = Header {
            opcode: Opcode::Other(3),
            ..Header::default()
        };
        assert!(!header.validate());

        let header = Header {
            response_code: ResponseCode::Other(11),
            ..Header::default()
        };
        assert!(!header.validate());

        // Opcode nibble 4 has no named variant: 4 << 11 == 0x2000.
        let bytes = [0x00, 0x00, 0x20, 0x00, 0, 0, 0, 0, 0, 0, 0, 0];
        let parsed = Header::parse(bytes);
        assert_eq!(parsed.opcode, Opcode::Other(4));
        assert!(!parsed.validate());
        assert_eq!(parsed.to_bytes(), bytes);
    }
}