1use crate::errors::{AuthError, Result};
22use md5::Digest;
23use ring::rand::SecureRandom;
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use std::net::SocketAddr;
27use std::time::Duration;
28use tokio::net::UdpSocket;
29
/// RADIUS packet codes (the first octet of every packet).
pub mod code {
    /// Access-Request.
    pub const ACCESS_REQUEST: u8 = 1;
    /// Access-Accept.
    pub const ACCESS_ACCEPT: u8 = 2;
    /// Access-Reject.
    pub const ACCESS_REJECT: u8 = 3;

    /// Accounting-Request.
    pub const ACCOUNTING_REQUEST: u8 = 4;
    /// Accounting-Response.
    pub const ACCOUNTING_RESPONSE: u8 = 5;

    /// Access-Challenge.
    pub const ACCESS_CHALLENGE: u8 = 11;
}
41
/// RADIUS attribute type numbers (the first octet of every attribute).
pub mod attr {
    // --- Identity and credentials ---
    /// User-Name.
    pub const USER_NAME: u8 = 1;
    /// User-Password.
    pub const USER_PASSWORD: u8 = 2;

    // --- NAS / session description ---
    /// NAS-IP-Address.
    pub const NAS_IP_ADDRESS: u8 = 4;
    /// NAS-Port.
    pub const NAS_PORT: u8 = 5;
    /// Service-Type.
    pub const SERVICE_TYPE: u8 = 6;
    /// Framed-Protocol.
    pub const FRAMED_PROTOCOL: u8 = 7;
    /// Filter-Id.
    pub const FILTER_ID: u8 = 11;
    /// Reply-Message.
    pub const REPLY_MESSAGE: u8 = 18;
    /// State.
    pub const STATE: u8 = 24;
    /// Session-Timeout.
    pub const SESSION_TIMEOUT: u8 = 27;
    /// Calling-Station-Id.
    pub const CALLING_STATION_ID: u8 = 31;
    /// NAS-Identifier.
    pub const NAS_IDENTIFIER: u8 = 32;

    // --- Accounting ---
    /// Acct-Status-Type.
    pub const ACCT_STATUS_TYPE: u8 = 40;
    /// Acct-Session-Id.
    pub const ACCT_SESSION_ID: u8 = 44;

    // --- Extensions ---
    /// NAS-Port-Type.
    pub const NAS_PORT_TYPE: u8 = 61;
    /// EAP-Message.
    pub const EAP_MESSAGE: u8 = 79;
    /// Message-Authenticator.
    pub const MESSAGE_AUTHENTICATOR: u8 = 80;
}
62
/// Size in bytes of the authenticator field carried in every packet header.
const AUTHENTICATOR_LEN: usize = 16;

/// Size in bytes of the fixed packet header: code (1), identifier (1),
/// length (2) and the authenticator.
const HEADER_LEN: usize = 1 + 1 + 2 + AUTHENTICATOR_LEN;

/// Size in bytes of the receive buffer, also used as the initial capacity
/// when encoding an outgoing packet.
const MAX_PACKET_SIZE: usize = 4096;
71
/// Connection settings for a [`RadiusClient`].
///
/// `Debug` is implemented by hand so that the shared secret never ends up in
/// logs, panic messages or `unwrap_err()` output.
#[derive(Clone)]
pub struct RadiusConfig {
    /// Authentication server as an `ip:port` socket address.
    pub server_addr: String,

    /// Secret shared with the server; used for password hiding and for the
    /// request/response authenticators.
    pub shared_secret: String,

    /// How long to wait for a reply after each transmission.
    pub timeout: Duration,

    /// Number of retransmissions after the first attempt times out.
    pub retries: u32,

    /// Value sent in the NAS-Identifier attribute of every request.
    pub nas_identifier: String,

    /// Accounting server as an `ip:port` socket address; when `None`,
    /// accounting requests go to `127.0.0.1:1813`.
    pub accounting_addr: Option<String>,
}

impl std::fmt::Debug for RadiusConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RadiusConfig")
            .field("server_addr", &self.server_addr)
            // Never print the secret itself.
            .field("shared_secret", &"<redacted>")
            .field("timeout", &self.timeout)
            .field("retries", &self.retries)
            .field("nas_identifier", &self.nas_identifier)
            .field("accounting_addr", &self.accounting_addr)
            .finish()
    }
}
95
96impl Default for RadiusConfig {
97 fn default() -> Self {
98 Self {
99 server_addr: "127.0.0.1:1812".into(),
100 shared_secret: String::new(),
101 timeout: Duration::from_secs(5),
102 retries: 3,
103 nas_identifier: "auth-framework".into(),
104 accounting_addr: None,
105 }
106 }
107}
108
109impl RadiusConfig {
110 pub fn with_server(
123 server_addr: impl Into<String>,
124 shared_secret: impl Into<String>,
125 ) -> Result<Self> {
126 let secret = shared_secret.into();
127 if secret.len() < 6 {
128 return Err(AuthError::config(
129 "RADIUS shared_secret must be at least 6 bytes",
130 ));
131 }
132 Ok(Self {
133 server_addr: server_addr.into(),
134 shared_secret: secret,
135 ..Default::default()
136 })
137 }
138
139 pub fn with_options(
154 server_addr: impl Into<String>,
155 shared_secret: impl Into<String>,
156 timeout: Duration,
157 retries: u32,
158 ) -> Result<Self> {
159 let mut cfg = Self::with_server(server_addr, shared_secret)?;
160 cfg.timeout = timeout;
161 cfg.retries = retries;
162 Ok(cfg)
163 }
164}
165
/// A single attribute of a RADIUS packet: a type number and its raw value.
///
/// The on-wire length octet is not stored; it is derived from `value` when
/// the packet is encoded.
#[derive(Clone, Debug)]
pub struct RadiusAttribute {
    /// Attribute type number (see the [`attr`] constants).
    pub attr_type: u8,
    /// Raw attribute value, without the type and length octets.
    pub value: Vec<u8>,
}
174
175#[derive(Debug, Clone)]
177pub struct RadiusPacket {
178 pub code: u8,
179 pub identifier: u8,
180 pub authenticator: [u8; AUTHENTICATOR_LEN],
181 pub attributes: Vec<RadiusAttribute>,
182}
183
184impl RadiusPacket {
185 pub fn add_attribute(&mut self, attr_type: u8, value: impl AsRef<[u8]>) {
203 self.attributes.push(RadiusAttribute {
204 attr_type,
205 value: value.as_ref().to_vec(),
206 });
207 }
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct RadiusAuthResult {
213 pub accepted: bool,
215
216 pub reply_message: Option<String>,
218
219 pub session_timeout: Option<u32>,
221
222 pub filter_id: Option<String>,
224
225 pub challenge: bool,
227
228 pub state: Option<Vec<u8>>,
230
231 pub reply_attributes: HashMap<u8, Vec<Vec<u8>>>,
233}
234
235#[derive(Debug)]
239pub struct RadiusClient {
240 config: RadiusConfig,
241}
242
243impl RadiusClient {
244 pub fn new(config: RadiusConfig) -> Result<Self> {
246 if config.shared_secret.is_empty() {
247 return Err(AuthError::config("RADIUS shared secret must not be empty"));
248 }
249 if config.shared_secret.len() < 6 {
250 return Err(AuthError::config(
251 "RADIUS shared secret should be at least 6 bytes",
252 ));
253 }
254 Ok(Self { config })
255 }
256
257 pub async fn authenticate(&self, username: &str, password: &str) -> Result<RadiusAuthResult> {
259 self.authenticate_with_state(username, password, None).await
260 }
261
262 pub async fn authenticate_with_state(
264 &self,
265 username: &str,
266 password: &str,
267 state: Option<&[u8]>,
268 ) -> Result<RadiusAuthResult> {
269 let rng = ring::rand::SystemRandom::new();
270 let mut authenticator = [0u8; AUTHENTICATOR_LEN];
271 rng.fill(&mut authenticator)
272 .map_err(|_| AuthError::crypto("Failed to generate RADIUS authenticator"))?;
273
274 let mut id_buf = [0u8; 1];
275 rng.fill(&mut id_buf)
276 .map_err(|_| AuthError::crypto("Failed to generate RADIUS identifier"))?;
277
278 let mut packet = RadiusPacket {
279 code: code::ACCESS_REQUEST,
280 identifier: id_buf[0],
281 authenticator,
282 attributes: Vec::new(),
283 };
284
285 packet.attributes.push(RadiusAttribute {
287 attr_type: attr::USER_NAME,
288 value: username.as_bytes().to_vec(),
289 });
290
291 let encrypted_password =
293 encrypt_pap_password(password, &self.config.shared_secret, &authenticator);
294 packet.attributes.push(RadiusAttribute {
295 attr_type: attr::USER_PASSWORD,
296 value: encrypted_password,
297 });
298
299 packet.attributes.push(RadiusAttribute {
301 attr_type: attr::NAS_IDENTIFIER,
302 value: self.config.nas_identifier.as_bytes().to_vec(),
303 });
304
305 if let Some(state_val) = state {
307 packet.attributes.push(RadiusAttribute {
308 attr_type: attr::STATE,
309 value: state_val.to_vec(),
310 });
311 }
312
313 let msg_auth =
315 compute_message_authenticator(&packet, self.config.shared_secret.as_bytes())?;
316 packet.attributes.push(RadiusAttribute {
317 attr_type: attr::MESSAGE_AUTHENTICATOR,
318 value: msg_auth.to_vec(),
319 });
320
321 let response = self.send_request(&packet).await?;
322 self.parse_response(&response, &authenticator)
323 }
324
325 pub async fn send_accounting(
327 &self,
328 username: &str,
329 session_id: &str,
330 status_type: u32,
331 ) -> Result<bool> {
332 let addr = self
333 .config
334 .accounting_addr
335 .as_deref()
336 .unwrap_or("127.0.0.1:1813");
337
338 let rng = ring::rand::SystemRandom::new();
339 let mut authenticator = [0u8; AUTHENTICATOR_LEN];
340 rng.fill(&mut authenticator)
341 .map_err(|_| AuthError::crypto("Failed to generate RADIUS authenticator"))?;
342
343 let mut id_buf = [0u8; 1];
344 rng.fill(&mut id_buf)
345 .map_err(|_| AuthError::crypto("Failed to generate RADIUS identifier"))?;
346
347 let mut packet = RadiusPacket {
348 code: code::ACCOUNTING_REQUEST,
349 identifier: id_buf[0],
350 authenticator,
351 attributes: Vec::new(),
352 };
353
354 packet.attributes.push(RadiusAttribute {
355 attr_type: attr::USER_NAME,
356 value: username.as_bytes().to_vec(),
357 });
358
359 packet.attributes.push(RadiusAttribute {
360 attr_type: attr::ACCT_SESSION_ID,
361 value: session_id.as_bytes().to_vec(),
362 });
363
364 packet.attributes.push(RadiusAttribute {
365 attr_type: attr::ACCT_STATUS_TYPE,
366 value: status_type.to_be_bytes().to_vec(),
367 });
368
369 packet.attributes.push(RadiusAttribute {
370 attr_type: attr::NAS_IDENTIFIER,
371 value: self.config.nas_identifier.as_bytes().to_vec(),
372 });
373
374 let encoded = encode_packet(&packet);
376 let acct_auth =
377 compute_accounting_authenticator(&encoded, self.config.shared_secret.as_bytes());
378 let mut final_packet = packet;
379 final_packet.authenticator = acct_auth;
380
381 let response = self
382 .send_packet(&encode_packet(&final_packet), addr)
383 .await?;
384 Ok(response[0] == code::ACCOUNTING_RESPONSE)
385 }
386
387 async fn send_request(&self, packet: &RadiusPacket) -> Result<Vec<u8>> {
389 let encoded = encode_packet(packet);
390 self.send_packet(&encoded, &self.config.server_addr).await
391 }
392
393 async fn send_packet(&self, data: &[u8], addr: &str) -> Result<Vec<u8>> {
395 let server_addr: SocketAddr = addr
396 .parse()
397 .map_err(|e| AuthError::config(format!("Invalid RADIUS server address: {e}")))?;
398
399 let socket = UdpSocket::bind("0.0.0.0:0")
400 .await
401 .map_err(|e| AuthError::internal(format!("Failed to bind UDP socket: {e}")))?;
402
403 for attempt in 0..=self.config.retries {
404 socket
405 .send_to(data, server_addr)
406 .await
407 .map_err(|e| AuthError::internal(format!("RADIUS send failed: {e}")))?;
408
409 let mut buf = vec![0u8; MAX_PACKET_SIZE];
410 match tokio::time::timeout(self.config.timeout, socket.recv_from(&mut buf)).await {
411 Ok(Ok((len, _))) => return Ok(buf[..len].to_vec()),
412 Ok(Err(e)) => {
413 return Err(AuthError::internal(format!("RADIUS recv failed: {e}")));
414 }
415 Err(_) if attempt < self.config.retries => continue,
416 Err(_) => {
417 return Err(AuthError::internal("RADIUS request timed out"));
418 }
419 }
420 }
421
422 Err(AuthError::internal("RADIUS request failed after retries"))
423 }
424
425 fn parse_response(
427 &self,
428 data: &[u8],
429 request_authenticator: &[u8; AUTHENTICATOR_LEN],
430 ) -> Result<RadiusAuthResult> {
431 if data.len() < HEADER_LEN {
432 return Err(AuthError::validation("RADIUS response too short"));
433 }
434
435 let response_code = data[0];
436 let _identifier = data[1];
437 let length = u16::from_be_bytes([data[2], data[3]]) as usize;
438
439 if length > data.len() {
440 return Err(AuthError::validation("RADIUS response length mismatch"));
441 }
442
443 let expected_auth = compute_response_authenticator(
445 data,
446 request_authenticator,
447 self.config.shared_secret.as_bytes(),
448 );
449 let actual_auth = &data[4..20];
450 if !constant_time_eq(actual_auth, &expected_auth) {
451 return Err(AuthError::validation(
452 "RADIUS response authenticator verification failed",
453 ));
454 }
455
456 let mut reply_attributes: HashMap<u8, Vec<Vec<u8>>> = HashMap::new();
458 let mut pos = HEADER_LEN;
459 while pos + 2 <= length {
460 let attr_type = data[pos];
461 let attr_len = data[pos + 1] as usize;
462 if attr_len < 2 || pos + attr_len > length {
463 break;
464 }
465 let value = data[pos + 2..pos + attr_len].to_vec();
466 reply_attributes.entry(attr_type).or_default().push(value);
467 pos += attr_len;
468 }
469
470 let reply_message = reply_attributes
471 .get(&attr::REPLY_MESSAGE)
472 .and_then(|v| v.first())
473 .and_then(|b| String::from_utf8(b.clone()).ok());
474
475 let session_timeout = reply_attributes
476 .get(&attr::SESSION_TIMEOUT)
477 .and_then(|v| v.first())
478 .and_then(|b| {
479 if b.len() == 4 {
480 Some(u32::from_be_bytes([b[0], b[1], b[2], b[3]]))
481 } else {
482 None
483 }
484 });
485
486 let filter_id = reply_attributes
487 .get(&attr::FILTER_ID)
488 .and_then(|v| v.first())
489 .and_then(|b| String::from_utf8(b.clone()).ok());
490
491 let state = reply_attributes
492 .get(&attr::STATE)
493 .and_then(|v| v.first())
494 .cloned();
495
496 Ok(RadiusAuthResult {
497 accepted: response_code == code::ACCESS_ACCEPT,
498 reply_message,
499 session_timeout,
500 filter_id,
501 challenge: response_code == code::ACCESS_CHALLENGE,
502 state,
503 reply_attributes,
504 })
505 }
506}
507
508fn encode_packet(packet: &RadiusPacket) -> Vec<u8> {
512 let mut buf = Vec::with_capacity(MAX_PACKET_SIZE);
513
514 buf.push(packet.code);
516 buf.push(packet.identifier);
517 buf.extend_from_slice(&[0, 0]); buf.extend_from_slice(&packet.authenticator);
519
520 for attr in &packet.attributes {
522 let attr_len = (2 + attr.value.len()) as u8;
523 buf.push(attr.attr_type);
524 buf.push(attr_len);
525 buf.extend_from_slice(&attr.value);
526 }
527
528 let len = buf.len() as u16;
530 buf[2..4].copy_from_slice(&len.to_be_bytes());
531
532 buf
533}
534
535fn encrypt_pap_password(
540 password: &str,
541 shared_secret: &str,
542 authenticator: &[u8; AUTHENTICATOR_LEN],
543) -> Vec<u8> {
544 let pwd_bytes = password.as_bytes();
545 let padded_len = ((pwd_bytes.len() + 15) / 16) * 16;
547 let padded_len = padded_len.max(16);
548 let mut padded = vec![0u8; padded_len];
549 padded[..pwd_bytes.len()].copy_from_slice(pwd_bytes);
550
551 let mut result = vec![0u8; padded_len];
552 let mut prev_block = authenticator.to_vec();
553
554 for i in 0..(padded_len / 16) {
555 let hasher = md5_hash(shared_secret.as_bytes(), &prev_block);
556 let chunk_start = i * 16;
557 for j in 0..16 {
558 result[chunk_start + j] = padded[chunk_start + j] ^ hasher[j];
559 }
560 prev_block = result[chunk_start..chunk_start + 16].to_vec();
561 }
562
563 result
564}
565
566fn md5_hash(a: &[u8], b: &[u8]) -> [u8; 16] {
568 let mut hasher = md5::Md5::new();
569 hasher.update(a);
570 hasher.update(b);
571 hasher.finalize().into()
572}
573
574fn compute_message_authenticator(packet: &RadiusPacket, secret: &[u8]) -> Result<[u8; 16]> {
576 let mut temp_packet = packet.clone();
579 temp_packet
581 .attributes
582 .retain(|a| a.attr_type != attr::MESSAGE_AUTHENTICATOR);
583 temp_packet.attributes.push(RadiusAttribute {
585 attr_type: attr::MESSAGE_AUTHENTICATOR,
586 value: vec![0u8; 16],
587 });
588
589 let encoded = encode_packet(&temp_packet);
590 let hmac_result = hmac_md5_truncated(secret, &encoded);
591 Ok(hmac_result)
592}
593
594fn hmac_md5_truncated(key: &[u8], data: &[u8]) -> [u8; 16] {
596 use hmac::Mac;
597 type HmacMd5 = hmac::Hmac<md5::Md5>;
598 let mut mac = HmacMd5::new_from_slice(key).expect("HMAC key length");
599 mac.update(data);
600 let result = mac.finalize().into_bytes();
601 let mut out = [0u8; 16];
602 out.copy_from_slice(&result[..16]);
603 out
604}
605
606fn compute_response_authenticator(
610 response: &[u8],
611 request_auth: &[u8; AUTHENTICATOR_LEN],
612 secret: &[u8],
613) -> [u8; 16] {
614 let mut hasher = md5::Md5::new();
615 hasher.update(&response[..4]); hasher.update(request_auth); if response.len() > HEADER_LEN {
618 hasher.update(&response[HEADER_LEN..]); }
620 hasher.update(secret);
621 hasher.finalize().into()
622}
623
624fn compute_accounting_authenticator(packet_bytes: &[u8], secret: &[u8]) -> [u8; AUTHENTICATOR_LEN] {
626 let mut hasher = md5::Md5::new();
627 hasher.update(&packet_bytes[..4]); hasher.update(&[0u8; AUTHENTICATOR_LEN]); if packet_bytes.len() > HEADER_LEN {
630 hasher.update(&packet_bytes[HEADER_LEN..]); }
632 hasher.update(secret);
633 hasher.finalize().into()
634}
635
636fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
638 use subtle::ConstantTimeEq;
639 if a.len() != b.len() {
640 return false;
641 }
642 a.ct_eq(b).into()
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration carrying the given shared secret.
    fn config_with_secret(secret: &str) -> RadiusConfig {
        RadiusConfig {
            shared_secret: secret.into(),
            ..Default::default()
        }
    }

    /// Empty Access-Request with a zeroed authenticator.
    fn empty_request(identifier: u8) -> RadiusPacket {
        RadiusPacket {
            code: code::ACCESS_REQUEST,
            identifier,
            authenticator: [0u8; AUTHENTICATOR_LEN],
            attributes: Vec::new(),
        }
    }

    #[test]
    fn test_config_defaults() {
        let defaults = RadiusConfig::default();
        assert_eq!(defaults.server_addr, "127.0.0.1:1812");
        assert_eq!(defaults.retries, 3);
    }

    #[test]
    fn test_client_requires_secret() {
        // The default configuration has an empty secret.
        let failure = RadiusClient::new(RadiusConfig::default()).unwrap_err();
        assert!(failure.to_string().contains("secret"));
    }

    #[test]
    fn test_client_rejects_short_secret() {
        let failure = RadiusClient::new(config_with_secret("abc")).unwrap_err();
        assert!(failure.to_string().contains("6 bytes"));
    }

    #[test]
    fn test_client_creation() {
        let outcome = RadiusClient::new(config_with_secret("testing123"));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_packet_encoding() {
        let mut request = empty_request(42);
        request.attributes.push(RadiusAttribute {
            attr_type: attr::USER_NAME,
            value: b"test".to_vec(),
        });

        let wire = encode_packet(&request);
        assert_eq!(wire[0], code::ACCESS_REQUEST);
        assert_eq!(wire[1], 42);

        // The length field must describe the whole encoded packet.
        let declared = u16::from_be_bytes([wire[2], wire[3]]);
        assert_eq!(declared as usize, wire.len());
    }

    #[test]
    fn test_pap_password_encryption() {
        let request_auth = [1u8; AUTHENTICATOR_LEN];
        let hidden = encrypt_pap_password("password", "secret", &request_auth);

        // A short password is padded to one block and must not appear in
        // clear text.
        assert_eq!(hidden.len(), 16);
        assert_ne!(&hidden[..8], b"password");
    }

    #[test]
    fn test_radius_attribute_codes() {
        assert_eq!(attr::USER_NAME, 1);
        assert_eq!(attr::USER_PASSWORD, 2);
        assert_eq!(attr::MESSAGE_AUTHENTICATOR, 80);
    }

    #[test]
    fn test_radius_config_with_server() {
        let built = RadiusConfig::with_server("10.0.0.1:1812", "testing123").unwrap();
        assert_eq!(built.server_addr, "10.0.0.1:1812");
        assert_eq!(built.shared_secret, "testing123");
        // Untouched settings come from the defaults.
        assert_eq!(built.retries, 3);
    }

    #[test]
    fn test_radius_config_with_server_rejects_short_secret() {
        let failure = RadiusConfig::with_server("10.0.0.1:1812", "abc").unwrap_err();
        assert!(failure.to_string().contains("6 bytes"));
    }

    #[test]
    fn test_radius_config_with_options() {
        let wait = Duration::from_secs(10);
        let built = RadiusConfig::with_options("10.0.0.1:1812", "testing123", wait, 5).unwrap();
        assert_eq!(built.timeout, Duration::from_secs(10));
        assert_eq!(built.retries, 5);
    }

    #[test]
    fn test_radius_packet_add_attribute() {
        let mut request = empty_request(1);
        request.add_attribute(attr::USER_NAME, b"alice");
        request.add_attribute(attr::NAS_IDENTIFIER, b"my-nas");

        assert_eq!(request.attributes.len(), 2);

        let (first, second) = (&request.attributes[0], &request.attributes[1]);
        assert_eq!(first.attr_type, attr::USER_NAME);
        assert_eq!(first.value, b"alice");
        assert_eq!(second.attr_type, attr::NAS_IDENTIFIER);
    }
}