Documentation
// Copyright (c) 2025, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

//! Port picker for integration testing
//!
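//! This module provides utilities for finding unused ports for integration testing.
//! The public entry point is `pick_unused_port`. A port is only returned after it has been
//! successfully bound over both TCP and UDP. The probe sockets are closed again right away, so the
//! port is not reserved: another process may still take it before the caller binds to it.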
//!

use std::net::{
    Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, TcpListener, ToSocketAddrs, UdpSocket,
};

use getrandom::getrandom;

use crate::port::Port;

/// Attempts to bind a UDP socket to `addr`.
///
/// Returns the local port the socket ended up bound to, or `None` if either the
/// bind or the local-address query failed. The socket is dropped before this
/// function returns, so the port is not held open.
fn test_bind_udp<A: ToSocketAddrs>(addr: A) -> Option<Port> {
    UdpSocket::bind(addr)
        .and_then(|socket| socket.local_addr())
        .map(|local| local.port())
        .ok()
}

/// Attempts to bind a TCP listener to `addr`.
///
/// Returns the local port the listener was bound to, or `None` if either the
/// bind or the local-address query failed. The listener is dropped before
/// returning, so the port is not held open.
fn test_bind_tcp<A: ToSocketAddrs>(addr: A) -> Option<Port> {
    let listener = TcpListener::bind(addr).ok()?;
    let bound_to = listener.local_addr().ok()?;
    Some(bound_to.port())
}

/// Check if a port is free on UDP.
///
/// The port counts as free only if a UDP socket can be bound to it on the
/// unspecified IPv6 address *and* on the unspecified IPv4 address. IPv6 is
/// probed first; IPv4 is only probed when the IPv6 bind succeeds.
///
/// NOTE(review): on a host where IPv6 binding is unavailable, the IPv6 probe
/// fails and every port is reported as not free — confirm this is intended.
fn is_free_udp(port: Port) -> bool {
    let ipv6_bindable = test_bind_udp(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0)).is_some();

    ipv6_bindable && test_bind_udp(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port)).is_some()
}

/// Check if a port is free on TCP.
///
/// The port counts as free only if a TCP listener can be bound to it on the
/// unspecified IPv6 address *and* on the unspecified IPv4 address. IPv6 is
/// probed first; IPv4 is only probed when the IPv6 bind succeeds.
///
/// NOTE(review): on a host where IPv6 binding is unavailable, the IPv6 probe
/// fails and every port is reported as not free — confirm this is intended.
fn is_free_tcp(port: Port) -> bool {
    if test_bind_tcp(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0)).is_none() {
        return false;
    }

    test_bind_tcp(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port)).is_some()
}

/// Check if a port is free on both TCP and UDP.
///
/// TCP is probed first; the UDP probe only runs if the port is free on TCP.
fn is_free(port: Port) -> bool {
    if !is_free_tcp(port) {
        return false;
    }

    is_free_udp(port)
}

/// Asks the OS for a free TCP port.
///
/// Binds a TCP listener to port 0, so the OS picks the port. The unspecified IPv6 address
/// is tried first, falling back to the unspecified IPv4 address if that fails.
/// Returns the assigned port, or `None` if both binds fail. The listener is
/// released before returning, so the port is not reserved.
fn ask_free_tcp_port() -> Option<Port> {
    test_bind_tcp(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
        .or_else(|| test_bind_tcp(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)))
}

/// Generates a random candidate port in the half-open range `[MIN, MAX)`,
/// i.e. `15000..25000`.
///
/// Random bytes come from the OS via `getrandom`. A raw draw that falls in the
/// incomplete top block of the `Port` value space is rejected and redrawn.
/// Otherwise `rnd % RANGE` would favour low residues (modulo bias). With the
/// redraw, every port in the range is equally likely.
///
/// # Panics
///
/// Panics if `getrandom` reports an error (the OS random source is unavailable).
fn generate_port() -> Port {
    const MIN: Port = 15000;
    const MAX: Port = 25000;
    const RANGE: Port = MAX - MIN;
    // LIMIT is a multiple of RANGE, so draws in `0..LIMIT` hit every residue
    // modulo RANGE the same number of times; draws at or above it are biased.
    const LIMIT: Port = Port::MAX - Port::MAX % RANGE;
    const BUFFER_SIZE: usize = std::mem::size_of::<Port>();

    loop {
        let mut buffer = [0u8; BUFFER_SIZE];
        getrandom(&mut buffer).expect("getrandom failed: OS random source unavailable");

        let rnd = Port::from_le_bytes(buffer);
        if rnd < LIMIT {
            return MIN + rnd % RANGE;
        }
        // Rejected draw (rare: ~8% for u16); try again.
    }
}

/// Select an unused port.
///
/// Two strategies are tried in order:
/// 1. Up to 10 random candidates from `generate_port` are probed. The first one
///    free on both TCP and UDP is returned.
/// 2. Otherwise the OS is asked, up to 10 times, for a TCP port. The first
///    such port that is also free on UDP is returned.
///
/// Returns `None` if every attempt fails. The port is not reserved: another
/// process may bind it between this call and the caller's own bind.
pub fn pick_unused_port() -> Option<Port> {
    const ATTEMPTS: usize = 10;

    let random_choice = (0..ATTEMPTS)
        .map(|_| generate_port())
        .find(|&candidate| is_free(candidate));

    random_choice.or_else(|| {
        (0..ATTEMPTS)
            .filter_map(|_| ask_free_tcp_port())
            .find(|&candidate| is_free_udp(candidate))
    })
}

#[cfg(test)]
mod tests {
    use super::{generate_port, pick_unused_port};

    /// The picker should find some port on a normal test host.
    #[test]
    fn it_works() {
        assert!(pick_unused_port().is_some());
    }

    /// Random candidates must always fall inside the `15000..25000` window.
    #[test]
    fn generated_ports_stay_in_range() {
        for _ in 0..1000 {
            let port = generate_port();
            assert!(
                (15000..25000).contains(&port),
                "port {} out of range",
                port
            );
        }
    }
}