1#![allow(dead_code)]
2use bytes::Bytes;
3use md5::{Digest, Md5};
4use rand::Rng;
5use rand::RngCore;
6use std::sync::Arc;
7use thiserror::Error;
8use zeroize::Zeroize;
9
10use crate::{
11 Code,
12 attribute::{
13 self, AttributeParseError, AttributeValue, Attributes, FromRadiusAttribute,
14 ToRadiusAttribute,
15 },
16};
17
/// Upper bound, in bytes, on the packets accepted by `Packet::parse_packet`
/// (both the received buffer and the header length field) and on the packets
/// produced by `Packet::encode_raw`.
pub const MAX_PACKET_SIZE: usize = 4096;
19#[derive(Debug, Clone)]
24pub struct Packet {
25 pub code: Code,
27
28 pub identifier: u8,
31
32 pub authenticator: [u8; 16],
35
36 pub attributes: Attributes,
38
39 pub secret: Arc<[u8]>,
41}
42
/// Errors produced while parsing or encoding a `Packet`.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum PacketParseError {
    /// The input buffer is shorter than the fixed 20-byte header.
    #[error("Packet not at least 20 bytes long")]
    TooShortHeader,

    /// Carries the raw code octet; `Packet::encode` returns it for codes it
    /// has no authenticator rule for.
    #[error("Unknown packet code: {0}")]
    UnknownPacketCode(u8),

    /// A length outside the accepted range: the received buffer size, the
    /// header length field, or the size of a packet being encoded.
    #[error("Invalid packet length in header: {0}")]
    InvalidLength(usize),

    /// The header length field claims more bytes than the buffer holds.
    #[error("Buffer too short: header expects {expected} bytes, but only received {actual}")]
    BufferTooShort { expected: usize, actual: usize },

    /// Wraps a failure from the attribute layer.
    #[error("Attribute parsing failed: {0}")]
    AttributeError(AttributeParseError),
}
61impl Packet {
62 fn new(code: Code, secret: Arc<[u8]>) -> Self {
63 let mut rng = rand::rng();
64 let mut authenticator = [0u8; 16];
65 rng.fill_bytes(&mut authenticator);
66
67 Packet {
68 code,
69 identifier: rng.random::<u8>(),
70 authenticator,
71 attributes: Attributes::default(),
72 secret,
73 }
74 }
75
76 pub fn parse_packet(b: Bytes, secret: Arc<[u8]>) -> Result<Self, PacketParseError> {
77 if b.len() < 20 {
78 return Err(PacketParseError::TooShortHeader);
79 }
80
81 if b.len() > MAX_PACKET_SIZE {
82 return Err(PacketParseError::InvalidLength(b.len()));
83 }
84
85 let length = u16::from_be_bytes([b[2], b[3]]) as usize;
86
87 if !(20..=MAX_PACKET_SIZE).contains(&length) {
89 return Err(PacketParseError::InvalidLength(length));
90 }
91
92 if b.len() < length {
93 return Err(PacketParseError::BufferTooShort {
94 expected: length,
95 actual: b.len(),
96 });
97 }
98
99 let code = Code::try_from(b[0])?;
100 let identifier = b[1];
101
102 let mut authenticator = [0u8; 16];
103 authenticator.copy_from_slice(&b[4..20]);
104
105 let attribute_data = b.slice(20..length);
106 let attrs = attribute::parse_attributes(attribute_data)?;
107
108 Ok(Self {
109 code,
110 identifier,
111 authenticator,
112 attributes: attrs,
113 secret,
114 })
115 }
116
117 pub fn encode(&self) -> Result<Vec<u8>, PacketParseError> {
118 let mut b = self.encode_raw()?;
119 let code: Code = self.code;
120
121 match code {
122 Code::AccessRequest | Code::StatusServer => Ok(b),
126
127 Code::AccessAccept
130 | Code::AccessReject
131 | Code::AccessChallenge
132 | Code::AccountingResponse
133 | Code::DisconnectAck
134 | Code::DisconnectNak => {
135 let mut hasher = Md5::new();
136
137 hasher.update(&b[0..4]);
138
139 hasher.update(&b[4..20]);
140
141 hasher.update(&b[20..]);
142
143 hasher.update(&self.secret);
144
145 let hash_result = hasher.finalize();
146 b[4..20].copy_from_slice(&hash_result);
147
148 Ok(b)
149 }
150 Code::AccountingRequest | Code::DisconnectRequest | Code::CoARequest => {
151 let mut hasher = Md5::new();
152 hasher.update(&b[0..4]);
153
154 const NUL_AUTHENTICATOR: [u8; 16] = [0u8; 16];
156 hasher.update(NUL_AUTHENTICATOR);
157
158 hasher.update(&b[20..]);
159 hasher.update(&self.secret);
160
161 let hash_result = hasher.finalize();
162 b[4..20].copy_from_slice(&hash_result);
163
164 Ok(b)
165 }
166
167 _ => Err(PacketParseError::UnknownPacketCode(b[0])),
168 }
169 }
170 pub fn encode_raw(&self) -> Result<Vec<u8>, PacketParseError> {
171 let attributes_len = self.attributes.encoded_len()?;
172 let size: usize = 20 + attributes_len;
173 if size > MAX_PACKET_SIZE {
174 return Err(PacketParseError::InvalidLength(size));
175 }
176 let mut b = vec![0u8; size];
177 b[0] = self.code as u8;
178 b[1] = self.identifier;
179 b[2..4].copy_from_slice(&(size as u16).to_be_bytes());
180 b[4..20].copy_from_slice(&self.authenticator);
181 self.attributes
182 .encode_to(&mut b[20..])
183 .map_err(PacketParseError::AttributeError)?;
184 Ok(b)
185 }
186 pub fn verify_request(&self) -> bool {
187 if self.secret.is_empty() {
188 return false;
189 }
190
191 match self.code {
192 Code::AccessRequest | Code::StatusServer => true,
196
197 Code::AccountingRequest | Code::DisconnectRequest | Code::CoARequest => {
201 let packet_raw_result = self.encode_raw();
202 if packet_raw_result.is_err() {
203 return false;
204 }
205 let mut packet_raw = packet_raw_result.unwrap();
206
207 const NUL_AUTHENTICATOR: [u8; 16] = [0u8; 16];
208 packet_raw[4..20].copy_from_slice(&NUL_AUTHENTICATOR);
209
210 let mut hasher = Md5::new();
211
212 hasher.update(&packet_raw);
214
215 hasher.update(&*self.secret);
217
218 let calculated_hash = hasher.finalize();
219
220 let calculated_bytes: [u8; 16] = calculated_hash.into();
221
222 calculated_bytes == self.authenticator
224 }
225
226 _ => false,
227 }
228 }
229
230 pub fn get_attribute(&self, key: u8) -> Option<&AttributeValue> {
231 self.attributes.get(key)
232 }
233
234 pub fn set_attribute(&mut self, key: u8, value: AttributeValue) {
235 self.attributes.set(key, value);
236 }
237 pub fn get_vsa_attribute(&self, vendor_id: u32, vendor_type: u8) -> Option<&[u8]> {
238 self.attributes.get_vsa_attribute(vendor_id, vendor_type)
239 }
240 pub fn set_vsa_attribute(&mut self, vendor_id: u32, vendor_type: u8, value: AttributeValue) {
241 self.attributes
242 .set_vsa_attribute(vendor_id, vendor_type, value);
243 }
244 pub fn encrypt_user_password(&self, plaintext: &[u8]) -> Option<Vec<u8>> {
246 if plaintext.len() > 128 || self.secret.is_empty() {
247 return None;
248 }
249
250 let chunks = if plaintext.is_empty() {
251 1
252 } else {
253 plaintext.len().div_ceil(16)
254 };
255 let mut enc = Vec::with_capacity(chunks * 16);
256
257 let mut hasher = Md5::new();
258 hasher.update(&self.secret);
259 hasher.update(self.authenticator);
260 let mut b = hasher.finalize();
261
262 for i in 0..16 {
264 let p_byte = if i < plaintext.len() { plaintext[i] } else { 0 };
265 enc.push(p_byte ^ b[i]);
266 }
267
268 for i in 1..chunks {
270 hasher = Md5::new();
271 hasher.update(&self.secret);
272 hasher.update(&enc[(i - 1) * 16..i * 16]);
273 b = hasher.finalize();
274
275 for j in 0..16 {
276 let offset = i * 16 + j;
277 let p_byte = if offset < plaintext.len() {
278 plaintext[offset]
279 } else {
280 0
281 };
282 enc.push(p_byte ^ b[j]);
283 }
284 }
285 Some(enc)
286 }
287
288 pub fn decrypt_user_password(&self, encrypted: &[u8]) -> Option<Vec<u8>> {
290 if encrypted.is_empty() || !encrypted.len().is_multiple_of(16) || self.secret.is_empty() {
291 return None;
292 }
293
294 let mut plaintext = Vec::with_capacity(encrypted.len());
295 let mut last_round = [0u8; 16];
296 last_round.copy_from_slice(&self.authenticator);
297
298 for chunk in encrypted.chunks(16) {
299 let mut hasher = Md5::new();
300 hasher.update(&*self.secret);
301 hasher.update(last_round);
302 let b = hasher.finalize();
303 for i in 0..16 {
304 plaintext.push(chunk[i] ^ b[i]);
305 }
306 last_round.copy_from_slice(chunk);
307 }
308 let mut end = plaintext.len();
309 while end > 0 && plaintext[end - 1] == 0 {
310 end -= 1;
311 }
312 Some(plaintext[..end].to_vec())
313 }
314
315 pub fn encrypt_tunnel_password(&self, plaintext: &[u8]) -> Option<Vec<u8>> {
317 if self.secret.is_empty() {
318 return None;
319 }
320
321 let mut salt = [0u8; 2];
322 rand::rng().fill_bytes(&mut salt);
323 salt[0] |= 0x80;
324
325 let mut data = vec![plaintext.len() as u8];
326 data.extend_from_slice(plaintext);
327 while !data.len().is_multiple_of(16) {
328 data.push(0);
329 }
330
331 let mut result = salt.to_vec();
332 let mut last_round = Vec::with_capacity(16 + 2);
333 last_round.extend_from_slice(&self.authenticator);
334 last_round.extend_from_slice(&salt);
335
336 for chunk in data.chunks(16) {
337 let mut hasher = Md5::new();
338 hasher.update(&self.secret);
339 hasher.update(&last_round);
340 let b = hasher.finalize();
341
342 let mut encrypted_chunk = [0u8; 16];
343 for i in 0..16 {
344 encrypted_chunk[i] = chunk[i] ^ b[i];
345 }
346 result.extend_from_slice(&encrypted_chunk);
347 last_round = encrypted_chunk.to_vec();
348 }
349 Some(result)
350 }
351
352 pub fn decrypt_tunnel_password(&self, encrypted: &[u8]) -> Option<Vec<u8>> {
354 if encrypted.len() < 18
355 || !(encrypted.len() - 2).is_multiple_of(16)
356 || self.secret.is_empty()
357 {
358 return None;
359 }
360
361 let salt = &encrypted[0..2];
362 let ciphertext = &encrypted[2..];
363 let mut plaintext = Vec::with_capacity(ciphertext.len());
364
365 let mut last_round = Vec::with_capacity(16 + 2);
366 last_round.extend_from_slice(&self.authenticator);
367 last_round.extend_from_slice(salt);
368
369 for chunk in ciphertext.chunks(16) {
370 let mut hasher = Md5::new();
371 hasher.update(&self.secret);
372 hasher.update(&last_round);
373 let b = hasher.finalize();
374 for i in 0..16 {
375 plaintext.push(chunk[i] ^ b[i]);
376 }
377 last_round = chunk.to_vec();
378 }
379
380 let len = plaintext[0] as usize;
381 if len > plaintext.len() - 1 {
382 return None;
383 }
384 Some(plaintext[1..1 + len].to_vec())
385 }
386
387 pub fn get_attribute_as<T: FromRadiusAttribute>(&self, type_code: u8) -> Option<T> {
388 match type_code {
389 2 => {
390 let raw = self.get_attribute(2)?;
392 let mut decrypted = self.decrypt_user_password(raw)?;
393 let result = T::from_bytes(&decrypted);
394 decrypted.zeroize();
395 result
396 }
397 69 => {
398 let raw = self.get_attribute(69)?;
400 let mut decrypted = self.decrypt_tunnel_password(raw)?;
401 let result = T::from_bytes(&decrypted);
402 decrypted.zeroize();
403 result
404 }
405 _ => self
406 .get_attribute(type_code)
407 .and_then(|raw| T::from_bytes(raw)),
408 }
409 }
410
411 pub fn set_attribute_as<T: ToRadiusAttribute>(&mut self, type_code: u8, value: T) {
412 match type_code {
413 2 => {
414 if let Some(encrypted_vec) = self.encrypt_user_password(&value.to_bytes()) {
415 let encrypted_bytes = Bytes::from(encrypted_vec);
416 self.set_attribute(2, encrypted_bytes);
417 }
418 }
419 69 => {
420 if let Some(encrypted_vec) = self.encrypt_tunnel_password(&value.to_bytes()) {
421 let encrypted_bytes = Bytes::from(encrypted_vec);
423 self.set_attribute(69, encrypted_bytes);
424 }
425 }
426 _ => {
427 let attr_bytes = Bytes::from(value.to_bytes());
429 self.set_attribute(type_code, attr_bytes);
430 }
431 }
432 }
433
434 pub fn get_vsa_attribute_as<T: FromRadiusAttribute>(&self, v_id: u32, v_type: u8) -> Option<T> {
435 self.get_vsa_attribute(v_id, v_type)
436 .and_then(|raw| T::from_bytes(raw))
437 }
438
439 pub fn set_vsa_attribute_as<T: ToRadiusAttribute>(&mut self, v_id: u32, v_type: u8, value: T) {
440 let raw_bytes = value.to_bytes();
441
442 let value_bytes = Bytes::from(raw_bytes);
443
444 self.set_vsa_attribute(v_id, v_type, value_bytes);
445 }
446
447 pub fn create_response_packet(&self, code: Code) -> Packet {
448 Packet {
449 code,
450 identifier: self.identifier,
451 authenticator: self.authenticator,
452 attributes: Attributes::default(),
453 secret: self.secret.clone(),
454 }
455 }
456}
457
458impl From<AttributeParseError> for PacketParseError {
459 fn from(err: AttributeParseError) -> Self {
460 PacketParseError::AttributeError(err)
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn parse_packet_too_short() {
        let secret = Arc::from(&b"shared-secret"[..]);
        let buf = Bytes::from(vec![0u8; 10]);

        let err = Packet::parse_packet(buf, secret).unwrap_err();
        assert_eq!(err, PacketParseError::TooShortHeader);
    }

    // A header length field below the 20-byte header size is rejected.
    #[test]
    fn parse_packet_length_field_below_header_size() {
        let secret = Arc::from(&b"shared-secret"[..]);
        let mut raw = vec![0u8; 20];
        raw[0] = 1;
        raw[3] = 19;

        let err = Packet::parse_packet(Bytes::from(raw), secret).unwrap_err();
        assert_eq!(err, PacketParseError::InvalidLength(19));
    }

    // A header length field larger than the received buffer is rejected.
    #[test]
    fn parse_packet_length_field_exceeds_buffer() {
        let secret = Arc::from(&b"shared-secret"[..]);
        let mut raw = vec![0u8; 20];
        raw[0] = 1;
        raw[3] = 30;

        let err = Packet::parse_packet(Bytes::from(raw), secret).unwrap_err();
        assert_eq!(
            err,
            PacketParseError::BufferTooShort {
                expected: 30,
                actual: 20
            }
        );
    }

    #[test]
    fn test_encode_access_request_preserves_authenticator() {
        let mut packet = Packet::new(Code::AccessRequest, Arc::from(&b"secret"[..]));
        let auth = [1u8; 16];
        packet.authenticator = auth;

        let encoded = packet.encode().unwrap();
        assert_eq!(&encoded[4..20], &auth);
    }

    // A reply's authenticator is MD5(code | id | length | request auth | attrs | secret).
    #[test]
    fn test_encode_response_authenticator() {
        let secret: Arc<[u8]> = Arc::from(&b"secret"[..]);
        let request_auth = [0x5au8; 16];
        let mut request = Packet::new(Code::AccessRequest, Arc::clone(&secret));
        request.authenticator = request_auth;

        let response = request.create_response_packet(Code::AccessAccept);
        let encoded = response.encode().unwrap();

        let mut hasher = Md5::new();
        hasher.update(&encoded[0..4]);
        hasher.update(request_auth);
        hasher.update(&encoded[20..]);
        hasher.update(&*secret);
        let expected: [u8; 16] = hasher.finalize().into();

        assert_eq!(encoded[1], request.identifier);
        assert_eq!(&encoded[4..20], &expected);
    }

    #[test]
    fn test_verify_request_accounting_valid() {
        let secret = Arc::from(&b"super-secret"[..]);
        let packet = Packet::new(Code::AccountingRequest, Arc::clone(&secret));

        let encoded_bytes = Bytes::from(packet.encode().unwrap());

        let received = Packet::parse_packet(encoded_bytes, Arc::clone(&secret)).unwrap();

        assert!(received.verify_request());
    }

    #[test]
    fn test_verify_request_accounting_invalid_secret() {
        let secret = Arc::from("real-secret".as_bytes());
        let wrong_secret = Arc::from("hacker-secret".as_bytes());

        let packet = Packet::new(Code::AccountingRequest, Arc::clone(&secret));
        let encoded_bytes = Bytes::from(packet.encode().unwrap());

        let received = Packet::parse_packet(encoded_bytes, wrong_secret).unwrap();

        assert!(!received.verify_request());
    }

    #[test]
    fn test_verify_request_accounting_tampered_data() {
        let secret = Arc::from(&b"secret"[..]);
        let packet_to_send = Packet::new(Code::AccountingRequest, Arc::clone(&secret));

        let mut encoded_vec = packet_to_send.encode().unwrap();

        // Flip the identifier byte after the authenticator was computed.
        encoded_vec[1] ^= 0xFF;

        let received = Packet::parse_packet(Bytes::from(encoded_vec), secret).unwrap();
        assert!(!received.verify_request());
    }

    #[test]
    fn test_verify_request_access_request_always_true() {
        let secret = Arc::from(&b"secret"[..]);
        let packet = Packet::new(Code::AccessRequest, secret);
        assert!(packet.verify_request());
    }

    #[test]
    fn test_user_password_roundtrip() {
        let secret = Arc::from(&b"shared-secret"[..]);
        let mut packet = Packet::new(Code::AccessRequest, secret);
        packet.authenticator = [0x42; 16];

        let original = b"very-secure-password-123";
        let encrypted = packet
            .encrypt_user_password(original)
            .expect("Encryption failed");
        let decrypted = packet
            .decrypt_user_password(&encrypted)
            .expect("Decryption failed");

        assert_eq!(original.to_vec(), decrypted);
    }

    // Empty input and lengths that are not a multiple of 16 are rejected.
    #[test]
    fn test_decrypt_user_password_rejects_bad_length() {
        let secret = Arc::from(&b"shared-secret"[..]);
        let packet = Packet::new(Code::AccessRequest, secret);

        assert!(packet.decrypt_user_password(&[]).is_none());
        assert!(packet.decrypt_user_password(&[0u8; 15]).is_none());
    }

    #[test]
    fn test_tunnel_password_roundtrip() {
        let secret = Arc::from(&b"shared-secret"[..]);
        let mut packet = Packet::new(Code::AccessRequest, secret);
        packet.authenticator = [0x77; 16];

        let original = b"tunnel-secret-password";
        let encrypted = packet
            .encrypt_tunnel_password(original)
            .expect("Tunnel encryption failed");
        let decrypted = packet
            .decrypt_tunnel_password(&encrypted)
            .expect("Tunnel decryption failed");

        assert_eq!(original.to_vec(), decrypted);
        assert_eq!(encrypted.len(), 2 + 32);
        assert!(encrypted[0] >= 0x80, "Salt MSB must be set");
    }

    // An empty tunnel password still yields salt + one block and round-trips.
    #[test]
    fn test_tunnel_password_empty_and_short_input() {
        let secret = Arc::from(&b"shared-secret"[..]);
        let mut packet = Packet::new(Code::AccessRequest, secret);
        packet.authenticator = [0x33; 16];

        let encrypted = packet.encrypt_tunnel_password(b"").unwrap();
        assert_eq!(encrypted.len(), 2 + 16);
        assert_eq!(packet.decrypt_tunnel_password(&encrypted), Some(Vec::new()));

        assert!(packet.decrypt_tunnel_password(&[0x80u8; 17]).is_none());
    }

    #[test]
    fn test_encrypt_user_password_blocks() {
        let secret = Arc::from(&b"mysecret"[..]);
        let mut packet = Packet::new(Code::AccessRequest, Arc::clone(&secret));
        packet.authenticator = [0x11; 16];

        let pass1 = b"password";
        let enc1 = packet.encrypt_user_password(pass1).unwrap();
        assert_eq!(enc1.len(), 16);

        let pass2 = b"this-is-a-very-long-password-exceeding-16-bytes";
        let enc2 = packet.encrypt_user_password(pass2).unwrap();
        assert_eq!(enc2.len(), 48);

        // The first keystream block is MD5(secret | authenticator).
        let mut hasher = Md5::new();
        hasher.update(&*secret);
        hasher.update(&packet.authenticator);
        let b1 = hasher.finalize();

        let decrypted_p1 = enc1[0] ^ b1[0];
        assert_eq!(decrypted_p1, b'p');
    }
}