/// Validates the 16-bit flags word of a DNS message header.
///
/// The word is split into its fields, and the following rules are applied in order.
/// The first violation is reported:
/// 1. the Z field (bits 4-6) must be zero,
/// 2. OPCODE must be in 0..=5,
/// 3. RCODE must be in 0..=5,
/// 4. RA must be clear in queries,
/// 5. for responses only (QR = 1), the AA/TC rules tied to OPCODE and RCODE.
///
/// # Errors
/// Returns a human-readable `String` describing the first rule that failed.
///
/// On success, the input `flags` value is returned unchanged.
pub fn verify_dns_flags(flags: u16) -> Result<u16, String> {
    let (
        is_response,
        op,
        authoritative,
        truncated,
        _recursion_desired,
        recursion_available,
        reserved,
        response_code,
    ) = extract_dns_flags(flags);

    // Checks that apply to every message, query or response.
    verify_z_field(reserved)?;
    verify_opcode(op)?;
    verify_rcode(response_code)?;
    verify_ra_in_query(is_response, recursion_available)?;

    // Only responses carry the additional AA/TC constraints.
    if is_response != 1 {
        return Ok(flags);
    }
    verify_response_flags(op, authoritative, truncated, response_code)?;
    Ok(flags)
}
/// Splits a DNS header flags word into its individual fields.
///
/// Each field is returned right-aligned in its own `u16`, in the order
/// `(qr, opcode, aa, tc, rd, ra, z, rcode)`:
/// - `qr`: bit 15
/// - `opcode`: bits 11-14
/// - `aa`: bit 10
/// - `tc`: bit 9
/// - `rd`: bit 8
/// - `ra`: bit 7
/// - `z`: bits 4-6
/// - `rcode`: bits 0-3
fn extract_dns_flags(flags: u16) -> (u16, u16, u16, u16, u16, u16, u16, u16) {
    // Pulls out `width` bits whose lowest bit sits at position `shift` (0 = LSB).
    let field = |shift: u32, width: u32| (flags >> shift) & ((1u16 << width) - 1);

    (
        field(15, 1),
        field(11, 4),
        field(10, 1),
        field(9, 1),
        field(8, 1),
        field(7, 1),
        field(4, 3),
        field(0, 4),
    )
}
/// Checks that the header's Z field is zero.
///
/// `z` is the already-extracted field value. `verify_dns_flags` passes bits 4-6
/// of the flags word here.
///
/// NOTE(review): RFC 4035 assigns bits 5 and 4 of that range to AD and CD.
/// As written, DNSSEC-aware messages with either bit set are rejected.
/// Confirm that this is intended.
///
/// # Errors
/// Returns `"Invalid Z field, must be 0. Here it's: <z>"` when `z` is non-zero.
fn verify_z_field(z: u16) -> Result<(), String> {
    if z != 0 {
        return Err(format!("Invalid Z field, must be 0. Here it's: {}", z));
    }
    Ok(())
}
/// Accepts an OPCODE only if it lies in the range 0 through 5 (inclusive).
///
/// NOTE(review): the IANA registry also defines opcode 6 (DSO, RFC 8490).
/// Confirm that rejecting it is intended.
///
/// # Errors
/// Returns a message quoting the rejected value when `opcode` is greater than 5.
fn verify_opcode(opcode: u16) -> Result<(), String> {
    match opcode {
        0..=5 => Ok(()),
        rejected => Err(format!(
            "Invalid Opcode, must be between 0 and 5. Here it's: {}",
            rejected
        )),
    }
}
/// Accepts an RCODE only if it lies in the range 0 through 5 (inclusive).
///
/// NOTE(review): RFC 2136 defines header RCODEs 6-10 for dynamic update
/// (YXDomain, YXRRSet, NXRRSet, NotAuth, NotZone). Confirm that rejecting them
/// is intended.
///
/// # Errors
/// Returns a message quoting the rejected value when `rcode` is greater than 5.
fn verify_rcode(rcode: u16) -> Result<(), String> {
    match rcode {
        0..=5 => Ok(()),
        rejected => Err(format!(
            "Invalid RCode, must be between 0 and 5. Here it's: {}",
            rejected
        )),
    }
}
/// Rejects a query (`qr == 0`) that has the Recursion Available bit set.
///
/// Any non-zero `qr` is treated as a response and always passes this check.
///
/// # Errors
/// Returns `"RA must be 0 in queries. Here it's: <ra>"` for a query with `ra != 0`.
fn verify_ra_in_query(qr: u16, ra: u16) -> Result<(), String> {
    let is_query = qr == 0;
    if !is_query || ra == 0 {
        Ok(())
    } else {
        Err(format!("RA must be 0 in queries. Here it's: {}", ra))
    }
}
/// Applies the response-only AA/TC rules.
///
/// The rules, checked in this order:
/// - OPCODE 2 (STATUS): both AA and TC must be 0.
/// - RCODE 2 (server failure): AA must be 0.
/// - RCODE 3 (name error): AA must be exactly 1.
/// - RCODE 5 (refused): AA must be 0.
///
/// All other RCODE values place no constraint on AA. TC is only inspected for STATUS.
///
/// # Errors
/// Returns a message naming the violated rule and the offending bit value(s).
fn verify_response_flags(opcode: u16, aa: u16, tc: u16, rcode: u16) -> Result<(), String> {
    // The STATUS rule is evaluated first, so it wins over any RCODE rule.
    if opcode == 2 && (aa != 0 || tc != 0) {
        return Err(format!(
            "AA and TC must be 0 in STATUS responses. Here AA is: {}, TC is: {}",
            aa, tc
        ));
    }

    // The RCODE rules are mutually exclusive, so at most one arm can fire.
    match rcode {
        2 if aa != 0 => Err(format!(
            "Rcode = 2 so AA must be 0 in Server failure responses. Here it's: {}",
            aa
        )),
        3 if aa != 1 => Err(format!(
            "Rcode = 3 AA must be 1 in Name Error responses. Here it's: {}",
            aa
        )),
        5 if aa != 0 => Err(format!(
            "Rcode = 5 AA must be 0 in Refused responses. Here it's: {}",
            aa
        )),
        _ => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_verify_z_field() {
        assert_eq!(verify_z_field(0), Ok(()));
        assert_eq!(
            verify_z_field(1),
            Err("Invalid Z field, must be 0. Here it's: 1".to_string())
        );
    }

    #[test]
    fn test_verify_opcode() {
        assert_eq!(verify_opcode(0), Ok(()));
        assert_eq!(verify_opcode(5), Ok(()));
        assert_eq!(
            verify_opcode(6),
            Err("Invalid Opcode, must be between 0 and 5. Here it's: 6".to_string())
        );
    }

    #[test]
    fn test_verify_rcode() {
        assert_eq!(verify_rcode(0), Ok(()));
        assert_eq!(verify_rcode(5), Ok(()));
        assert_eq!(
            verify_rcode(6),
            Err("Invalid RCode, must be between 0 and 5. Here it's: 6".to_string())
        );
    }

    #[test]
    fn test_verify_ra_in_query() {
        assert_eq!(verify_ra_in_query(0, 0), Ok(()));
        assert_eq!(
            verify_ra_in_query(0, 1),
            Err("RA must be 0 in queries. Here it's: 1".to_string())
        );
        assert_eq!(verify_ra_in_query(1, 1), Ok(()));
    }

    #[test]
    fn test_verify_response_flags() {
        assert_eq!(verify_response_flags(2, 0, 0, 0), Ok(()));
        assert_eq!(
            verify_response_flags(2, 1, 0, 0),
            Err("AA and TC must be 0 in STATUS responses. Here AA is: 1, TC is: 0".to_string())
        );
        assert_eq!(
            verify_response_flags(2, 0, 1, 0),
            Err("AA and TC must be 0 in STATUS responses. Here AA is: 0, TC is: 1".to_string())
        );
        assert_eq!(
            verify_response_flags(0, 1, 0, 2),
            Err("Rcode = 2 so AA must be 0 in Server failure responses. Here it's: 1".to_string())
        );
        assert_eq!(
            verify_response_flags(0, 0, 0, 3),
            Err("Rcode = 3 AA must be 1 in Name Error responses. Here it's: 0".to_string())
        );
        assert_eq!(verify_response_flags(0, 0, 0, 5), Ok(()));
        assert_eq!(
            verify_response_flags(0, 1, 0, 5),
            Err("Rcode = 5 AA must be 0 in Refused responses. Here it's: 1".to_string())
        );
    }

    #[test]
    fn test_flags_zero() {
        let flags: u16 = 0x0000;
        assert_eq!(verify_dns_flags(flags), Ok(flags));
    }

    #[test]
    fn test_valid_flags_standard_query() {
        let flags: u16 = 0x0100;
        assert_eq!(verify_dns_flags(flags), Ok(flags));
    }

    #[test]
    fn test_valid_flags_response_no_such_name() {
        let flags: u16 = 0x8583;
        assert_eq!(verify_dns_flags(flags), Ok(flags));
    }

    #[test]
    fn test_valid_flags_response_no_error() {
        let flags: u16 = 0x8180;
        assert_eq!(verify_dns_flags(flags), Ok(flags));
    }

    #[test]
    fn test_invalid_z_field() {
        let flags: u16 = 0x8010;
        assert_eq!(
            verify_dns_flags(flags),
            Err("Invalid Z field, must be 0. Here it's: 1".to_string())
        );
    }

    #[test]
    fn test_invalid_z_field_reported_before_other_rules() {
        // 0x8410: QR=1, AA=1, OPCODE=0, bit 4 set. The Z check runs first.
        let flags: u16 = 0x8410;
        assert_eq!(
            verify_dns_flags(flags),
            Err("Invalid Z field, must be 0. Here it's: 1".to_string())
        );
    }

    #[test]
    fn test_invalid_opcode() {
        let flags: u16 = 0x7104;
        assert_eq!(
            verify_dns_flags(flags),
            Err("Invalid Opcode, must be between 0 and 5. Here it's: 14".to_string())
        );
    }

    #[test]
    fn test_invalid_rcode() {
        let flags: u16 = 0x8006;
        assert_eq!(
            verify_dns_flags(flags),
            Err("Invalid RCode, must be between 0 and 5. Here it's: 6".to_string())
        );
    }

    #[test]
    fn test_ra_in_query() {
        let flags: u16 = 0x0080;
        assert_eq!(
            verify_dns_flags(flags),
            Err("RA must be 0 in queries. Here it's: 1".to_string())
        );
    }

    #[test]
    fn test_aa_tc_in_status_response() {
        // 0x9000: QR=1, OPCODE=2 (STATUS), AA=0, TC=0. Accepted.
        let flags: u16 = 0x9000;
        assert_eq!(verify_dns_flags(flags), Ok(flags));

        // 0x9400: QR=1, OPCODE=2 (STATUS), AA=1.
        let flags: u16 = 0x9400;
        assert_eq!(
            verify_dns_flags(flags),
            Err("AA and TC must be 0 in STATUS responses. Here AA is: 1, TC is: 0".to_string())
        );

        // 0x9200: QR=1, OPCODE=2 (STATUS), TC=1.
        let flags: u16 = 0x9200;
        assert_eq!(
            verify_dns_flags(flags),
            Err("AA and TC must be 0 in STATUS responses. Here AA is: 0, TC is: 1".to_string())
        );
    }

    #[test]
    fn test_aa_in_server_failure() {
        // 0x8082: QR=1, RA=1, RCODE=2, AA=0. Accepted.
        let flags: u16 = 0x8082;
        assert_eq!(verify_dns_flags(flags), Ok(flags));

        // 0x8482: same, but with AA=1.
        let flags: u16 = 0x8482;
        assert_eq!(
            verify_dns_flags(flags),
            Err("Rcode = 2 so AA must be 0 in Server failure responses. Here it's: 1".to_string())
        );
    }

    #[test]
    fn test_aa_in_name_error() {
        let flags: u16 = 0x8183;
        assert_eq!(
            verify_dns_flags(flags),
            Err("Rcode = 3 AA must be 1 in Name Error responses. Here it's: 0".to_string())
        );
    }

    #[test]
    fn test_aa_in_refused() {
        // 0x8185: QR=1, RD=1, RA=1, RCODE=5, AA=0. Accepted.
        let flags: u16 = 0x8185;
        assert_eq!(verify_dns_flags(flags), Ok(flags));

        // 0x8585: same, but with AA=1.
        let flags: u16 = 0x8585;
        assert_eq!(
            verify_dns_flags(flags),
            Err("Rcode = 5 AA must be 0 in Refused responses. Here it's: 1".to_string())
        );
    }

    #[test]
    fn test_random_val() {
        let flags: u16 = 0x9786;
        assert_eq!(
            verify_dns_flags(flags),
            Err("Invalid RCode, must be between 0 and 5. Here it's: 6".to_string())
        );
    }
}