#![allow(dead_code)]
#![allow(clippy::too_many_arguments)]
use super::ice::{CandidateType, IceCandidate, IceServer, TransportProtocol};
use super::stun::{Attribute, AttributeType, Message, MessageType};
use crate::error::{NetError, NetResult};
use std::net::{IpAddr, SocketAddr, UdpSocket};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::net::UdpSocket as TokioUdpSocket;
/// Which side of the ICE negotiation an agent plays.
///
/// Only one agent of a session is `Controlling`; the other is `Controlled`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IceRole {
    /// The agent that decides which candidate pair is used.
    Controlling,
    /// The agent that follows the controlling agent's choice.
    Controlled,
}
/// Overall connection state of an [`IceAgent`].
///
/// In this module `check_connectivity` moves the agent to `Checking`, then to
/// `Connected` on the first successful pair or `Failed` otherwise. The remaining
/// variants are not set by `IceAgent` in this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IceConnectionState {
    /// Freshly created agent.
    New,
    /// Candidate gathering in progress.
    Gathering,
    /// Connectivity checks in progress.
    Checking,
    /// A candidate pair has been validated.
    Connected,
    /// Negotiation finished.
    Completed,
    /// No candidate pair could be validated.
    Failed,
    /// A previously working connection was lost.
    Disconnected,
    /// The agent has been shut down.
    Closed,
}
/// Progress of local candidate gathering for an [`IceAgent`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IceGatheringState {
    /// Gathering has not started.
    New,
    /// `gather_candidates` is running.
    Gathering,
    /// Gathering finished and the local candidates are available.
    Complete,
}
/// Connectivity-check state of a single [`CandidatePair`].
///
/// `CandidatePair::new` starts every pair in `Waiting`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CandidatePairState {
    /// Ready to be checked.
    Waiting,
    /// A check has been sent and is awaiting a response.
    InProgress,
    /// A check for this pair succeeded.
    Succeeded,
    /// Checks for this pair failed.
    Failed,
    /// Not yet eligible for checking.
    Frozen,
}
/// A local candidate matched with a remote candidate, plus its check bookkeeping.
#[derive(Clone, Debug)]
pub struct CandidatePair {
    /// This agent's candidate.
    pub local: IceCandidate,
    /// The peer's candidate.
    pub remote: IceCandidate,
    /// Connectivity-check state of the pair.
    pub state: CandidatePairState,
    /// Pair priority as computed by `CandidatePair::calculate_priority`; higher is preferred.
    pub priority: u64,
    /// Whether the pair is nominated. A controlling agent adds USE-CANDIDATE to
    /// checks for nominated pairs.
    pub nominated: bool,
    /// When the most recent check for this pair was sent (`None` until one is recorded).
    pub last_check: Option<Instant>,
    /// Number of connectivity checks recorded for this pair.
    pub checks_sent: u32,
}
impl CandidatePair {
#[must_use]
pub fn new(local: IceCandidate, remote: IceCandidate, controlling: bool) -> Self {
let priority = Self::calculate_priority(&local, &remote, controlling);
Self {
local,
remote,
state: CandidatePairState::Waiting,
priority,
nominated: false,
last_check: None,
checks_sent: 0,
}
}
#[must_use]
pub fn calculate_priority(
local: &IceCandidate,
remote: &IceCandidate,
controlling: bool,
) -> u64 {
let g = if controlling {
local.priority
} else {
remote.priority
};
let d = if controlling {
remote.priority
} else {
local.priority
};
let g = u64::from(g);
let d = u64::from(d);
(1_u64 << 32).wrapping_mul(g.min(d)) + 2 * g.max(d) + if g > d { 1 } else { 0 }
}
pub fn local_addr(&self) -> NetResult<SocketAddr> {
let ip: IpAddr = self
.local
.address
.parse()
.map_err(|_| NetError::parse(0, "Invalid local address"))?;
Ok(SocketAddr::new(ip, self.local.port))
}
pub fn remote_addr(&self) -> NetResult<SocketAddr> {
let ip: IpAddr = self
.remote
.address
.parse()
.map_err(|_| NetError::parse(0, "Invalid remote address"))?;
Ok(SocketAddr::new(ip, self.remote.port))
}
}
/// Configuration for an [`IceAgent`].
#[derive(Clone, Debug)]
pub struct IceAgentConfig {
    /// STUN/TURN servers. Candidate gathering skips servers for which `is_turn()` is true.
    pub ice_servers: Vec<IceServer>,
    /// Local username fragment (second half of the USERNAME sent in checks).
    pub local_ufrag: String,
    /// Local ICE password.
    pub local_pwd: String,
    /// Peer's username fragment, once known (first half of the check USERNAME).
    pub remote_ufrag: Option<String>,
    /// Peer's ICE password, once known; used as the MESSAGE-INTEGRITY key for checks.
    pub remote_pwd: Option<String>,
    /// Whether this agent takes the controlling role.
    pub controlling: bool,
    /// Value sent in the ICE-CONTROLLING / ICE-CONTROLLED attribute.
    pub tie_breaker: u64,
}
impl Default for IceAgentConfig {
    /// A controlling configuration with no ICE servers and no remote credentials.
    ///
    /// It uses a random 8-character ufrag, a random 24-character password and a
    /// random tie-breaker.
    fn default() -> Self {
        use rand::RngExt;
        let mut rng = rand::rng();
        // Draw order (ufrag, pwd, tie-breaker) matches the struct-literal order.
        let local_ufrag = generate_ice_string(8);
        let local_pwd = generate_ice_string(24);
        let tie_breaker = rng.random::<u64>();
        Self {
            ice_servers: Vec::new(),
            local_ufrag,
            local_pwd,
            remote_ufrag: None,
            remote_pwd: None,
            controlling: true,
            tie_breaker,
        }
    }
}
/// An ICE agent: gathers local candidates, pairs them with remote candidates and
/// runs STUN connectivity checks.
///
/// All mutable state sits behind `Arc<Mutex<_>>`, so the agent's methods take `&self`.
pub struct IceAgent {
    /// Static configuration (credentials, role, servers).
    config: IceAgentConfig,
    /// Overall connection state.
    state: Arc<Mutex<IceConnectionState>>,
    /// Candidate-gathering progress.
    gathering_state: Arc<Mutex<IceGatheringState>>,
    /// Candidates gathered for this agent.
    local_candidates: Arc<Mutex<Vec<IceCandidate>>>,
    /// Candidates received from the peer.
    remote_candidates: Arc<Mutex<Vec<IceCandidate>>>,
    /// Local x remote pairs, highest priority first.
    pairs: Arc<Mutex<Vec<CandidatePair>>>,
    /// Pair chosen by the connectivity checks, if any.
    selected_pair: Arc<Mutex<Option<CandidatePair>>>,
    /// UDP socket used for STUN traffic; `None` until one is bound.
    socket: Arc<Mutex<Option<Arc<TokioUdpSocket>>>>,
}
impl IceAgent {
    /// Default port for `stun:` URLs that do not specify one (RFC 7064).
    const DEFAULT_STUN_PORT: u16 = 3478;
    /// How long gathering waits for a STUN server's Binding response.
    const STUN_TIMEOUT: Duration = Duration::from_secs(5);
    /// How long a connectivity check waits for its Binding response.
    const CHECK_TIMEOUT: Duration = Duration::from_millis(500);
    /// Destination used only to ask the OS which local interface carries the
    /// default route. It is a TEST-NET-1 documentation address (RFC 5737); UDP
    /// `connect` performs a route lookup only, so nothing is ever sent to it.
    const ROUTE_PROBE_ADDR: &'static str = "192.0.2.1:9";

    /// Creates an agent in the `New` connection and gathering states.
    ///
    /// The agent starts with no candidates, no pairs, no selected pair and no socket.
    #[must_use]
    pub fn new(config: IceAgentConfig) -> Self {
        Self {
            config,
            state: Arc::new(Mutex::new(IceConnectionState::New)),
            gathering_state: Arc::new(Mutex::new(IceGatheringState::New)),
            local_candidates: Arc::new(Mutex::new(Vec::new())),
            remote_candidates: Arc::new(Mutex::new(Vec::new())),
            pairs: Arc::new(Mutex::new(Vec::new())),
            selected_pair: Arc::new(Mutex::new(None)),
            socket: Arc::new(Mutex::new(None)),
        }
    }

    /// Gathers local candidates: one host candidate, plus one server-reflexive
    /// candidate per reachable non-TURN ICE server.
    ///
    /// The host socket is kept as the agent's socket. This keeps the advertised
    /// port reserved and makes the reflexive mappings belong to the socket that
    /// later sends the connectivity checks. Pairs are rebuilt afterwards, so
    /// remote candidates added earlier get paired too.
    ///
    /// # Errors
    /// Returns an error if the host socket cannot be bound or its address read;
    /// the gathering state is then reset to `New`. A failing STUN server is not
    /// fatal: it simply contributes no candidate.
    pub async fn gather_candidates(&self) -> NetResult<Vec<IceCandidate>> {
        self.set_gathering_state(IceGatheringState::Gathering);
        let (socket, base) = match self.bind_host_socket().await {
            Ok(bound) => bound,
            Err(e) => {
                // Do not leave the agent stuck in `Gathering` after a failure.
                self.set_gathering_state(IceGatheringState::New);
                return Err(e);
            }
        };
        let mut candidates = vec![Self::host_candidate(base)];
        for server in self.config.ice_servers.iter().filter(|s| !s.is_turn()) {
            // Best effort: one unreachable or misconfigured STUN server must not
            // discard the candidates gathered so far.
            if let Ok(Some(srflx)) = self.gather_srflx_candidate(server, &socket, base).await {
                candidates.push(srflx);
            }
        }
        *Self::lock_or_recover(&self.local_candidates) = candidates.clone();
        // Remote candidates may already be known; pair them with the new locals.
        self.create_pairs();
        self.set_gathering_state(IceGatheringState::Complete);
        Ok(candidates)
    }

    /// Binds the agent's UDP socket on the IPv4 wildcard and stores it as the agent socket.
    ///
    /// Returns the socket together with the address to advertise for it.
    async fn bind_host_socket(&self) -> NetResult<(Arc<TokioUdpSocket>, SocketAddr)> {
        let socket = TokioUdpSocket::bind("0.0.0.0:0")
            .await
            .map_err(|e| NetError::connection(format!("Failed to bind socket: {e}")))?;
        let bound = socket
            .local_addr()
            .map_err(|e| NetError::connection(format!("Failed to get local address: {e}")))?;
        // A wildcard bind reports 0.0.0.0, which no peer can reach. Advertise the
        // interface that carries the default route instead, when there is one.
        let ip = Self::outbound_interface_ip().unwrap_or_else(|| bound.ip());
        let socket = Arc::new(socket);
        *Self::lock_or_recover(&self.socket) = Some(Arc::clone(&socket));
        Ok((socket, SocketAddr::new(ip, bound.port())))
    }

    /// Returns the local IP the OS would use to reach the internet, or `None`
    /// if there is no such route or it cannot be determined.
    fn outbound_interface_ip() -> Option<IpAddr> {
        let probe = UdpSocket::bind("0.0.0.0:0").ok()?;
        probe.connect(Self::ROUTE_PROBE_ADDR).ok()?;
        let ip = probe.local_addr().ok()?.ip();
        (!ip.is_unspecified()).then_some(ip)
    }

    /// Builds the UDP host candidate for the socket address `base`.
    fn host_candidate(base: SocketAddr) -> IceCandidate {
        let foundation = super::ice::compute_foundation(
            CandidateType::Host,
            &base.ip().to_string(),
            TransportProtocol::Udp,
            None,
        );
        IceCandidate::host(foundation, base.ip().to_string(), base.port())
    }

    /// Extracts the `host:port` part of a `stun:` / `stun://` URL.
    ///
    /// The default STUN port is appended when the URL gives none. Returns `None`
    /// for other schemes or an empty authority.
    fn stun_server_authority(url: &str) -> Option<String> {
        // "stun://" must be tried first because "stun:" is a prefix of it.
        let rest = url
            .strip_prefix("stun://")
            .or_else(|| url.strip_prefix("stun:"))?;
        if rest.is_empty() {
            return None;
        }
        let has_port = match rest.rfind(']') {
            // Bracketed IPv6 literal: a port can only follow the closing bracket.
            Some(end) => rest[end + 1..].starts_with(':'),
            None => rest.contains(':'),
        };
        Some(if has_port {
            rest.to_owned()
        } else {
            format!("{rest}:{}", Self::DEFAULT_STUN_PORT)
        })
    }

    /// Sends a STUN Binding request to `server` from `socket`.
    ///
    /// On success it returns the server-reflexive candidate whose base (related
    /// address) is `base`. Only the server's first URL is used.
    ///
    /// # Errors
    /// Errors are returned for:
    /// - an invalid URL;
    /// - resolution failure, or no IPv4 address for the server;
    /// - a send or receive failure;
    /// - no matching response within the timeout;
    /// - an unparseable response from the server.
    ///
    /// `Ok(None)` means the server has no URL, or its matching response was not
    /// a Binding success carrying XOR-MAPPED-ADDRESS.
    async fn gather_srflx_candidate(
        &self,
        server: &IceServer,
        socket: &TokioUdpSocket,
        base: SocketAddr,
    ) -> NetResult<Option<IceCandidate>> {
        let url = match server.urls.first() {
            Some(url) => url,
            None => return Ok(None),
        };
        let addr = Self::stun_server_authority(url)
            .ok_or_else(|| NetError::invalid_url("Invalid STUN URL"))?;
        let server_addr = tokio::net::lookup_host(addr.as_str())
            .await
            .map_err(|e| NetError::connection(format!("Failed to resolve STUN server: {e}")))?
            // The agent socket is bound to an IPv4 wildcard, so only IPv4 servers are usable.
            .find(SocketAddr::is_ipv4)
            .ok_or_else(|| NetError::connection("No addresses found for STUN server"))?;
        let request = Message::binding_request();
        let transaction_id = request.transaction_id.clone();
        let encoded = request.encode();
        socket
            .send_to(&encoded, server_addr)
            .await
            .map_err(|e| NetError::connection(format!("Failed to send STUN request: {e}")))?;
        let deadline = tokio::time::Instant::now() + Self::STUN_TIMEOUT;
        let mut buf = [0u8; 2048];
        loop {
            let (len, from) =
                match tokio::time::timeout_at(deadline, socket.recv_from(&mut buf)).await {
                    Ok(Ok(received)) => received,
                    Ok(Err(e)) => {
                        return Err(NetError::connection(format!(
                            "Failed to receive STUN response: {e}"
                        )));
                    }
                    Err(_) => return Err(NetError::timeout("STUN request timeout")),
                };
            // The socket is shared, so skip anything that is not this server's
            // answer to this request.
            if from != server_addr {
                continue;
            }
            let response = Message::parse(&buf[..len])?;
            if response.transaction_id != transaction_id {
                continue;
            }
            if response.message_type != MessageType::BindingResponse {
                return Ok(None);
            }
            let attr = match response.get_attribute(AttributeType::XorMappedAddress) {
                Some(attr) => attr,
                None => return Ok(None),
            };
            let mapped_addr = attr.parse_xor_mapped_address(&response.transaction_id)?;
            let foundation = super::ice::compute_foundation(
                CandidateType::ServerReflexive,
                &mapped_addr.ip().to_string(),
                TransportProtocol::Udp,
                Some(addr.as_str()),
            );
            return Ok(Some(IceCandidate::server_reflexive(
                foundation,
                mapped_addr.ip().to_string(),
                mapped_addr.port(),
                base.ip().to_string(),
                base.port(),
            )));
        }
    }

    /// Adds a peer candidate and rebuilds the candidate pairs.
    pub fn add_remote_candidate(&self, candidate: IceCandidate) {
        // The guard is a temporary and is released here. `create_pairs` locks
        // `remote_candidates` again, and `std::sync::Mutex` is not reentrant.
        Self::lock_or_recover(&self.remote_candidates).push(candidate);
        self.create_pairs();
    }

    /// Stores the peer's username fragment and password used by connectivity checks.
    pub fn set_remote_params(&mut self, ufrag: String, pwd: String) {
        self.config.remote_ufrag = Some(ufrag);
        self.config.remote_pwd = Some(pwd);
    }

    /// Replaces all pairs with the full local x remote cross product.
    ///
    /// Pairs are sorted by descending priority; the sort is stable for ties.
    fn create_pairs(&self) {
        let local = Self::lock_or_recover(&self.local_candidates);
        let remote = Self::lock_or_recover(&self.remote_candidates);
        let mut pairs = Self::lock_or_recover(&self.pairs);
        let controlling = self.config.controlling;
        pairs.clear();
        for local_cand in local.iter() {
            for remote_cand in remote.iter() {
                pairs.push(CandidatePair::new(
                    local_cand.clone(),
                    remote_cand.clone(),
                    controlling,
                ));
            }
        }
        pairs.sort_by(|a, b| b.priority.cmp(&a.priority));
    }

    /// Checks pairs in priority order and selects the first one that succeeds.
    ///
    /// Each pair's state (`Succeeded`/`Failed`), `checks_sent` and `last_check`
    /// are recorded in the pair list. The selected pair carries the same updates.
    ///
    /// # Errors
    /// Returns an error, after setting the state to `Failed`, when there are no
    /// pairs or no pair succeeds.
    pub async fn check_connectivity(&self) -> NetResult<()> {
        self.set_state(IceConnectionState::Checking);
        let pairs_to_check = Self::lock_or_recover(&self.pairs).clone();
        if pairs_to_check.is_empty() {
            self.set_state(IceConnectionState::Failed);
            return Err(NetError::connection("No candidate pairs to check"));
        }
        for pair in &pairs_to_check {
            let outcome = self.check_pair(pair).await;
            // `check_pair` returns `Ok(_)` exactly when the request was sent.
            let sent = outcome.is_ok();
            let succeeded = matches!(outcome, Ok(true));
            let checked = self.record_check(pair, succeeded, sent);
            if succeeded {
                *Self::lock_or_recover(&self.selected_pair) = Some(checked);
                self.set_state(IceConnectionState::Connected);
                return Ok(());
            }
        }
        self.set_state(IceConnectionState::Failed);
        Err(NetError::connection("No valid candidate pairs found"))
    }

    /// Records a check result on the matching stored pair.
    ///
    /// Returns a copy of `pair` with the same update applied.
    fn record_check(&self, pair: &CandidatePair, succeeded: bool, sent: bool) -> CandidatePair {
        let state = if succeeded {
            CandidatePairState::Succeeded
        } else {
            CandidatePairState::Failed
        };
        let now = Instant::now();
        let apply = |p: &mut CandidatePair| {
            p.state = state;
            if sent {
                p.checks_sent = p.checks_sent.saturating_add(1);
                p.last_check = Some(now);
            }
        };
        let mut checked = pair.clone();
        apply(&mut checked);
        // The stored list may have been rebuilt since the snapshot was taken,
        // so locate the pair by its endpoints rather than by index.
        let mut pairs = Self::lock_or_recover(&self.pairs);
        if let Some(stored) = pairs.iter_mut().find(|p| Self::same_endpoints(p, pair)) {
            apply(stored);
        }
        checked
    }

    /// True when both pairs join the same local and remote address/port.
    fn same_endpoints(a: &CandidatePair, b: &CandidatePair) -> bool {
        a.local.address == b.local.address
            && a.local.port == b.local.port
            && a.remote.address == b.remote.address
            && a.remote.port == b.remote.port
    }

    /// Sends one STUN Binding check for `pair` and waits for its response.
    ///
    /// Returns `Ok(true)` only for a Binding success that comes from the pair's
    /// remote address and carries this request's transaction ID. `Ok(false)`
    /// means the request was sent but no such success arrived before the
    /// timeout, or the socket failed while waiting.
    ///
    /// NOTE(review): this agent never answers the peer's own checks. Inbound
    /// requests are simply skipped here.
    ///
    /// # Errors
    /// Returns an error (nothing was sent) if an address does not parse, the
    /// socket cannot be bound, or the send fails.
    async fn check_pair(&self, pair: &CandidatePair) -> NetResult<bool> {
        let local_addr = pair.local_addr()?;
        let remote_addr = pair.remote_addr()?;
        let socket = self.socket_for(local_addr).await?;
        // USERNAME for a check is "remote_ufrag:local_ufrag".
        let username = format!(
            "{}:{}",
            self.config.remote_ufrag.as_deref().unwrap_or(""),
            self.config.local_ufrag
        );
        let mut request = Message::binding_request()
            .with_attribute(Attribute::username(&username))
            .with_attribute(Attribute::priority(pair.local.priority));
        if self.config.controlling {
            request = request.with_attribute(Attribute::ice_controlling(self.config.tie_breaker));
            if pair.nominated {
                request = request.with_attribute(Attribute::use_candidate());
            }
        } else {
            request = request.with_attribute(Attribute::ice_controlled(self.config.tie_breaker));
        }
        let pwd = self
            .config
            .remote_pwd
            .as_ref()
            .unwrap_or(&self.config.local_pwd);
        let transaction_id = request.transaction_id.clone();
        let encoded = request.encode_with_integrity(pwd);
        socket
            .send_to(&encoded, remote_addr)
            .await
            .map_err(|e| NetError::connection(format!("Failed to send: {e}")))?;
        let deadline = tokio::time::Instant::now() + Self::CHECK_TIMEOUT;
        let mut buf = [0u8; 2048];
        loop {
            let (len, from) =
                match tokio::time::timeout_at(deadline, socket.recv_from(&mut buf)).await {
                    Ok(Ok(received)) => received,
                    // Timed out or the socket failed: the check did not succeed.
                    _ => return Ok(false),
                };
            // Skip datagrams that are not the answer to this request. These
            // include traffic from other addresses, non-STUN data, the peer's own
            // requests, and late responses to earlier timed-out checks, which
            // would otherwise validate the wrong pair.
            if from != remote_addr {
                continue;
            }
            let response = match Message::parse(&buf[..len]) {
                Ok(response) => response,
                Err(_) => continue,
            };
            if response.transaction_id != transaction_id {
                continue;
            }
            return Ok(response.message_type == MessageType::BindingResponse);
        }
    }

    /// Returns the agent socket, binding one on `local_addr` if none exists yet.
    async fn socket_for(&self, local_addr: SocketAddr) -> NetResult<Arc<TokioUdpSocket>> {
        if let Some(existing) = self.socket() {
            return Ok(existing);
        }
        // Bind without holding the lock: a `std::sync::MutexGuard` must not live
        // across `.await` (it would block other threads and make the future `!Send`).
        let bound = TokioUdpSocket::bind(local_addr)
            .await
            .map_err(|e| NetError::connection(format!("Failed to bind: {e}")))?;
        let mut slot = Self::lock_or_recover(&self.socket);
        // If another task stored a socket while we were binding, keep that one.
        Ok(Arc::clone(slot.get_or_insert_with(|| Arc::new(bound))))
    }

    /// Current connection state.
    #[must_use]
    pub fn state(&self) -> IceConnectionState {
        *Self::lock_or_recover(&self.state)
    }

    /// Current candidate-gathering state.
    #[must_use]
    pub fn gathering_state(&self) -> IceGatheringState {
        *Self::lock_or_recover(&self.gathering_state)
    }

    /// Snapshot of the gathered local candidates.
    #[must_use]
    pub fn local_candidates(&self) -> Vec<IceCandidate> {
        Self::lock_or_recover(&self.local_candidates).clone()
    }

    /// The pair chosen by `check_connectivity`, if any.
    #[must_use]
    pub fn selected_pair(&self) -> Option<CandidatePair> {
        Self::lock_or_recover(&self.selected_pair).clone()
    }

    /// The agent's UDP socket, if one has been bound.
    #[must_use]
    pub fn socket(&self) -> Option<Arc<TokioUdpSocket>> {
        Self::lock_or_recover(&self.socket).clone()
    }

    /// Sets the connection state.
    fn set_state(&self, state: IceConnectionState) {
        *Self::lock_or_recover(&self.state) = state;
    }

    /// Sets the gathering state.
    fn set_gathering_state(&self, state: IceGatheringState) {
        *Self::lock_or_recover(&self.gathering_state) = state;
    }

    /// Locks `mutex`, recovering the data if a previous holder panicked.
    fn lock_or_recover<T>(mutex: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
        mutex.lock().unwrap_or_else(|e| e.into_inner())
    }
}
/// Returns `length` random characters drawn uniformly from `[a-zA-Z0-9]`.
///
/// Used for ICE username fragments and passwords.
#[must_use]
fn generate_ice_string(length: usize) -> String {
    use rand::RngExt;
    const ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    let mut rng = rand::rng();
    let mut out = String::with_capacity(length);
    for _ in 0..length {
        let pick = rng.random_range(0..ALPHABET.len());
        out.push(char::from(ALPHABET[pick]));
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ice_role() {
        assert_eq!(IceRole::Controlling, IceRole::Controlling);
        assert_ne!(IceRole::Controlling, IceRole::Controlled);
    }

    #[test]
    fn test_ice_config_default() {
        let config = IceAgentConfig::default();
        assert!(!config.local_ufrag.is_empty());
        assert!(!config.local_pwd.is_empty());
        assert!(config.controlling);
        assert_eq!(config.local_ufrag.len(), 8);
        assert_eq!(config.local_pwd.len(), 24);
        assert!(config.remote_ufrag.is_none());
        assert!(config.remote_pwd.is_none());
        assert!(config.ice_servers.is_empty());
    }

    #[test]
    fn test_candidate_pair_priority() {
        let local = IceCandidate::host("1", "192.168.1.1", 5000).with_priority(100);
        let remote = IceCandidate::host("2", "192.168.1.2", 5001).with_priority(200);
        let pair = CandidatePair::new(local.clone(), remote.clone(), true);
        assert!(pair.priority > 0);
        // RFC 8445 §6.1.2.3: 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D ? 1 : 0).
        // Controlling: G = local (100), D = remote (200).
        assert_eq!(pair.priority, (100_u64 << 32) + 2 * 200);
        // Controlled: G = remote (200), D = local (100), so the tie-break bit is set.
        let controlled = CandidatePair::new(local, remote, false);
        assert_eq!(controlled.priority, (100_u64 << 32) + 2 * 200 + 1);
    }

    #[test]
    fn test_candidate_pair_initial_state() {
        let local = IceCandidate::host("1", "192.168.1.1", 5000);
        let remote = IceCandidate::host("2", "192.168.1.2", 5001);
        let pair = CandidatePair::new(local, remote, true);
        assert_eq!(pair.state, CandidatePairState::Waiting);
        assert!(!pair.nominated);
        assert_eq!(pair.checks_sent, 0);
        assert!(pair.last_check.is_none());
    }

    #[test]
    fn test_candidate_pair_addresses() {
        let local = IceCandidate::host("1", "192.168.1.1", 5000);
        let remote = IceCandidate::host("2", "not-an-ip", 5001);
        let pair = CandidatePair::new(local, remote, true);
        assert_eq!(
            pair.local_addr().ok(),
            Some(SocketAddr::from(([192, 168, 1, 1], 5000)))
        );
        assert!(pair.remote_addr().is_err());
    }

    #[test]
    fn test_new_agent_initial_state() {
        let agent = IceAgent::new(IceAgentConfig::default());
        assert_eq!(agent.state(), IceConnectionState::New);
        assert_eq!(agent.gathering_state(), IceGatheringState::New);
        assert!(agent.local_candidates().is_empty());
        assert!(agent.selected_pair().is_none());
        assert!(agent.socket().is_none());
    }

    #[test]
    fn test_generate_ice_string() {
        let s = generate_ice_string(16);
        assert_eq!(s.len(), 16);
        assert!(s.chars().all(|c| c.is_alphanumeric()));
        assert!(s.bytes().all(|b| b.is_ascii_alphanumeric()));
        assert!(generate_ice_string(0).is_empty());
    }
}