use std::net::{SocketAddr, TcpListener};
use crate::error::{Error, Result};
/// IPv4 loopback host used both when binding the listener and when formatting
/// the redirect URI, so the two always refer to the same address.
const LOOPBACK_ADDR: &str = "127.0.0.1";
/// A TCP listener bound to the loopback interface, together with the local
/// address it was bound to.
///
/// `Debug` is derived so values can be logged or used in `assert!` output;
/// both fields already implement it.
#[derive(Debug)]
pub struct LoopbackListener {
    /// The bound socket.
    listener: TcpListener,
    /// Local address of `listener`, stored at construction so the port can be
    /// reported without a fallible call.
    addr: SocketAddr,
}
impl LoopbackListener {
pub fn bind(preferred_port: Option<u16>) -> Result<Self> {
if let Some(port) = preferred_port {
if let Ok(listener) = TcpListener::bind((LOOPBACK_ADDR, port)) {
return Self::finalize(listener);
}
}
let listener = TcpListener::bind((LOOPBACK_ADDR, 0)).map_err(|err| {
Error::TokenProvider(format!("failed to bind loopback listener: {err}"))
})?;
Self::finalize(listener)
}
fn finalize(listener: TcpListener) -> Result<Self> {
listener
.set_nonblocking(true)
.map_err(|err| Error::TokenProvider(format!("failed to configure listener: {err}")))?;
let addr = listener.local_addr().map_err(|err| {
Error::TokenProvider(format!("failed to get listener address: {err}"))
})?;
Ok(Self { listener, addr })
}
pub fn port(&self) -> u16 {
self.addr.port()
}
pub fn into_std(self) -> TcpListener {
self.listener
}
}
/// Builds the plain-`http` redirect URI for the loopback listener on `port`.
///
/// The result has the form `http://<loopback address>:<port>`, with no path
/// and no trailing slash.
pub fn build_redirect_uri(port: u16) -> String {
    let host = LOOPBACK_ADDR;
    format!("http://{host}:{port}")
}
/// Free-function shorthand for [`LoopbackListener::bind`].
///
/// `preferred_port` is forwarded unchanged and the result is returned as-is,
/// so the port-selection behaviour and error cases are exactly those of
/// [`LoopbackListener::bind`].
pub fn bind_loopback_listener(preferred_port: Option<u16>) -> Result<LoopbackListener> {
LoopbackListener::bind(preferred_port)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::ErrorKind;
    use std::net::TcpStream;

    /// Asks the OS for an ephemeral loopback port and releases it again.
    ///
    /// The port is only *probably* free afterwards: another process may claim
    /// it between the probe socket being dropped and the caller binding it.
    fn probe_free_port() -> u16 {
        let probe = TcpListener::bind((LOOPBACK_ADDR, 0)).unwrap();
        probe.local_addr().unwrap().port()
    }

    #[test]
    fn test_bind_loopback_listener_without_preferred_port() {
        let listener = LoopbackListener::bind(None).unwrap();
        assert!(listener.addr.ip().is_loopback());
        let port = listener.port();
        assert!(port > 0);

        // The unwrapped socket must still report the port the wrapper cached.
        let std_listener = listener.into_std();
        assert_eq!(std_listener.local_addr().unwrap().port(), port);
    }

    #[test]
    fn test_bind_loopback_listener_with_available_preferred_port() {
        let available_port = probe_free_port();

        let listener = LoopbackListener::bind(Some(available_port)).unwrap();
        assert_eq!(listener.port(), available_port);
        assert!(listener.addr.ip().is_loopback());
    }

    #[test]
    fn test_bind_loopback_listener_with_unavailable_preferred_port() {
        // `occupier` stays alive until the end of the test so the port remains
        // taken while `bind` runs and has to fall back.
        let occupier = TcpListener::bind((LOOPBACK_ADDR, 0)).unwrap();
        let occupied_port = occupier.local_addr().unwrap().port();

        let listener = LoopbackListener::bind(Some(occupied_port)).unwrap();
        assert_ne!(listener.port(), occupied_port);
        assert!(listener.port() > 0);
    }

    #[test]
    fn test_loopback_listener_is_nonblocking() {
        let std_listener = LoopbackListener::bind(None).unwrap().into_std();

        // Nothing is connecting, so a non-blocking accept must return
        // immediately with `WouldBlock` instead of waiting.
        match std_listener.accept() {
            Err(err) if err.kind() == ErrorKind::WouldBlock => {}
            other => panic!("Expected WouldBlock error for non-blocking listener, got {other:?}"),
        }
    }

    #[test]
    fn test_loopback_listener_can_accept_connections() {
        let listener = LoopbackListener::bind(None).unwrap();
        let port = listener.port();
        let std_listener = listener.into_std();
        std_listener.set_nonblocking(false).unwrap();

        // The socket is already listening, so a connection made before
        // `accept` is called simply waits in the backlog; no sleep is needed
        // to order the two threads.
        let client = std::thread::spawn(move || TcpStream::connect((LOOPBACK_ADDR, port)));

        let accepted = std_listener.accept();
        assert!(accepted.is_ok(), "accept failed: {accepted:?}");
        client.join().unwrap().unwrap();
    }

    #[test]
    fn test_build_redirect_uri() {
        // Literal expectations on purpose: they pin the exact URI format
        // rather than re-deriving it from `LOOPBACK_ADDR`.
        assert_eq!(build_redirect_uri(8080), "http://127.0.0.1:8080");
        assert_eq!(build_redirect_uri(4317), "http://127.0.0.1:4317");
    }

    #[test]
    fn test_bind_loopback_listener_convenience_function() {
        let listener = bind_loopback_listener(None).unwrap();
        assert!(listener.port() > 0);

        let available_port = probe_free_port();
        let listener = bind_loopback_listener(Some(available_port)).unwrap();
        assert_eq!(listener.port(), available_port);
    }
}