use std::net::SocketAddr;
use std::time::Duration;
use thiserror::Error;
use tokio::net::UdpSocket;
use tokio::time::timeout;
use crate::core::packet::Packet;
#[derive(Error, Debug)]
pub enum ClientError {
#[error("failed to bind a UDP socket; {0}")]
FailedUdpSocketBindingError(String),
#[error("failed to establish a UDP connection to {0}; {1}")]
FailedEstablishingUdpConnectionError(String, String),
#[error("failed to encode a RADIUS request; {0}")]
FailedRadiusPacketEncodingError(String),
#[error("failed to send a UDP datagram to {0}; {1}")]
FailedSendingRadiusPacketError(String, String),
#[error("failed to receive the UDP response from {0}; {1}")]
FailedReceivingResponseError(String, String),
#[error("failed to decode a RADIUS response packet; {0}")]
FailedDecodingRadiusResponseError(String),
#[error("connection timeout")]
ConnectionTimeoutError(),
#[error("socket timeout")]
SocketTimeoutError(),
}
/// Asynchronous client that sends one RADIUS request over UDP and waits for its response.
///
/// Both timeouts are optional. `None` means the corresponding step is not time-limited.
#[derive(Debug, Clone)]
pub struct Client {
    /// Upper bound on connecting the UDP socket to the server.
    connection_timeout: Option<Duration>,
    /// Upper bound on sending the request and receiving the response, measured together.
    socket_timeout: Option<Duration>,
}
impl Client {
    /// Size of the receive buffer: the largest UDP payload an IPv4 datagram can carry
    /// (65535 minus the 8-byte UDP header and the 20-byte IPv4 header).
    const MAX_DATAGRAM_SIZE: usize = 65507;

    /// Creates a client with optional timeouts.
    ///
    /// * `connection_timeout`: upper bound on connecting the UDP socket to the server.
    ///   `None` waits indefinitely.
    /// * `socket_timeout`: upper bound on sending the request and receiving the response,
    ///   measured together. `None` waits indefinitely.
    pub fn new(connection_timeout: Option<Duration>, socket_timeout: Option<Duration>) -> Self {
        Client {
            connection_timeout,
            socket_timeout,
        }
    }

    /// Sends `request_packet` to `remote_addr` and returns the decoded response packet.
    ///
    /// Each call binds a fresh UDP socket and connects it to `remote_addr`. The socket is
    /// bound to an ephemeral port on the unspecified address of the same family as
    /// `remote_addr`. It is used for one send and one receive and is dropped on return.
    /// The response is decoded with the request packet's secret.
    ///
    /// # Errors
    ///
    /// * [`ClientError::FailedUdpSocketBindingError`]: the local socket could not be bound.
    /// * [`ClientError::FailedEstablishingUdpConnectionError`]: connecting the socket failed.
    /// * [`ClientError::ConnectionTimeoutError`]: `connection_timeout` elapsed while connecting.
    /// * [`ClientError::FailedRadiusPacketEncodingError`]: `request_packet` could not be encoded.
    /// * [`ClientError::FailedSendingRadiusPacketError`] or
    ///   [`ClientError::FailedReceivingResponseError`]: socket I/O failed.
    /// * [`ClientError::SocketTimeoutError`]: `socket_timeout` elapsed before the exchange finished.
    /// * [`ClientError::FailedDecodingRadiusResponseError`]: the response could not be decoded.
    pub async fn send_packet(
        &self,
        remote_addr: &SocketAddr,
        request_packet: &Packet,
    ) -> Result<Packet, ClientError> {
        // Match the server's address family. The wildcard address is built directly
        // rather than parsed from a string literal, so there is no `unwrap` panic path.
        let wildcard: std::net::IpAddr = if remote_addr.is_ipv4() {
            std::net::Ipv4Addr::UNSPECIFIED.into()
        } else {
            std::net::Ipv6Addr::UNSPECIFIED.into()
        };
        let conn = UdpSocket::bind(SocketAddr::new(wildcard, 0))
            .await
            .map_err(|e| ClientError::FailedUdpSocketBindingError(e.to_string()))?;

        // The outer `?` handles an elapsed timeout; the inner `?` handles a connect failure.
        match self.connection_timeout {
            Some(limit) => timeout(limit, self.connect(&conn, remote_addr))
                .await
                .map_err(|_| ClientError::ConnectionTimeoutError())??,
            None => self.connect(&conn, remote_addr).await?,
        }

        let request_data = request_packet
            .encode()
            .map_err(|e| ClientError::FailedRadiusPacketEncodingError(e.to_string()))?;

        // `socket_timeout` bounds the send and the receive together, not each one separately.
        let response = match self.socket_timeout {
            Some(limit) => timeout(limit, self.request(&conn, &request_data, remote_addr))
                .await
                .map_err(|_| ClientError::SocketTimeoutError())??,
            None => self.request(&conn, &request_data, remote_addr).await?,
        };

        // `response` is already an owned buffer, so borrow it instead of copying it again.
        Packet::decode(&response, request_packet.get_secret())
            .map_err(|e| ClientError::FailedDecodingRadiusResponseError(e.to_string()))
    }

    /// Connects `conn` to `remote_addr`, so the later `send`/`recv` calls in `request`
    /// need no explicit destination.
    async fn connect(
        &self,
        conn: &UdpSocket,
        remote_addr: &SocketAddr,
    ) -> Result<(), ClientError> {
        conn.connect(remote_addr).await.map_err(|e| {
            ClientError::FailedEstablishingUdpConnectionError(
                remote_addr.to_string(),
                e.to_string(),
            )
        })
    }

    /// Sends `request_data` as one datagram on the connected socket and returns the
    /// payload of the next datagram received on it.
    ///
    /// `remote_addr` is used only to add context to error messages.
    async fn request(
        &self,
        conn: &UdpSocket,
        request_data: &[u8],
        remote_addr: &SocketAddr,
    ) -> Result<Vec<u8>, ClientError> {
        conn.send(request_data).await.map_err(|e| {
            ClientError::FailedSendingRadiusPacketError(remote_addr.to_string(), e.to_string())
        })?;

        let mut buf = vec![0u8; Self::MAX_DATAGRAM_SIZE];
        let len = conn.recv(&mut buf).await.map_err(|e| {
            ClientError::FailedReceivingResponseError(remote_addr.to_string(), e.to_string())
        })?;
        // Shrink the receive buffer in place rather than copying the payload into a
        // second allocation.
        buf.truncate(len);
        Ok(buf)
    }
}