use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Semaphore;
use crate::config::{ProxyProtocolConfig, VersionPreference};
use crate::error::{AcceptError, ParseError};
use crate::parse::parse;
use crate::policy::PolicyDecision;
use crate::stream::ProxiedStream;
use crate::types::Version;
/// A [`TcpListener`] wrapper that, depending on the configured policy, reads
/// and validates a PROXY protocol header on each accepted connection before
/// handing the connection to the caller as a `ProxiedStream`.
pub struct ProxyProtocolListener {
/// The underlying listener that connections are accepted from.
inner: TcpListener,
/// Policy, header timeout, header size limit, version preference and optional
/// validator applied to every accepted connection.
config: ProxyProtocolConfig,
/// Limits how many header reads may be in progress at once; created in `new`
/// with `config.max_pending_handshakes` permits.
handshake_semaphore: Arc<Semaphore>,
}
impl ProxyProtocolListener {
pub fn new(listener: TcpListener, config: ProxyProtocolConfig) -> Self {
let semaphore = Arc::new(Semaphore::new(config.max_pending_handshakes));
Self {
inner: listener,
config,
handshake_semaphore: semaphore,
}
}
pub async fn accept(&self) -> Result<ProxiedStream, AcceptError> {
let (stream, peer_addr) = self.inner.accept().await.map_err(AcceptError::Io)?;
let decision = self.config.policy.evaluate(peer_addr);
match decision {
PolicyDecision::Reject => {
drop(stream);
Err(AcceptError::Rejected(peer_addr))
}
PolicyDecision::Ignore => Ok(ProxiedStream::new(stream, Vec::new(), None, peer_addr)),
PolicyDecision::Require | PolicyDecision::Use => {
let required = decision == PolicyDecision::Require;
self.read_and_validate(stream, peer_addr, required).await
}
}
}
async fn read_and_validate(
&self,
mut stream: TcpStream,
peer_addr: SocketAddr,
required: bool,
) -> Result<ProxiedStream, AcceptError> {
let _permit = tokio::time::timeout(
self.config.header_timeout,
self.handshake_semaphore.acquire(),
)
.await
.map_err(|_| AcceptError::HeaderTimeout(peer_addr))?
.map_err(|_| AcceptError::Io(io::Error::other("semaphore closed")))?;
let deadline = tokio::time::Instant::now() + self.config.header_timeout;
let mut buf = Vec::with_capacity(256);
let mut tmp = [0u8; 256];
loop {
let n = tokio::time::timeout_at(deadline, stream.read(&mut tmp))
.await
.map_err(|_| AcceptError::HeaderTimeout(peer_addr))?
.map_err(AcceptError::Io)?;
if n == 0 {
if buf.is_empty() {
return Err(AcceptError::EmptyConnection(peer_addr));
}
return Err(AcceptError::Parse(ParseError::Incomplete, peer_addr));
}
buf.extend_from_slice(&tmp[..n]);
match parse(&buf) {
Ok((info, consumed)) => {
match (self.config.version, info.version) {
(VersionPreference::V1Only, Version::V2)
| (VersionPreference::V2Only, Version::V1) => {
return Err(AcceptError::VersionMismatch(peer_addr));
}
_ => {}
}
if let Some(validator) = &self.config.validator {
validator
.validate(&info, peer_addr)
.map_err(|e| AcceptError::ValidationFailed(e, peer_addr))?;
}
let leftover = buf[consumed..].to_vec();
return Ok(ProxiedStream::new(stream, leftover, Some(info), peer_addr));
}
Err(ParseError::Incomplete) if buf.len() < self.config.max_header_size => {
continue;
}
Err(ParseError::NotProxyProtocol) if !required => {
return Ok(ProxiedStream::new(stream, buf, None, peer_addr));
}
Err(e) => {
return Err(AcceptError::Parse(e, peer_addr));
}
}
}
}
pub fn inner(&self) -> &TcpListener {
&self.inner
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.local_addr()
}
}
impl From<TcpListener> for ProxyProtocolListener {
fn from(listener: TcpListener) -> Self {
Self::new(listener, ProxyProtocolConfig::default())
}
}