use std::fmt;
use std::net::{TcpListener, TcpStream, ToSocketAddrs};
use tungstenite::protocol::WebSocket;
use tungstenite::Message;
use crate::error::FrpError;
use crate::message::{FrpMessage, FrpProtocolMessage, Severity};
/// Server-side FRP endpoint: accepts TCP clients, upgrades them to WebSocket,
/// and negotiates an FRP protocol version with each one.
pub struct FrpListener {
    /// Bound TCP socket that clients connect to.
    listener: TcpListener,
    /// Versions this server can speak, as given to [`FrpListener::bind`].
    supported_versions: Vec<String>,
}

impl FrpListener {
    /// Binds a listener on `addr` that will offer `supported_versions` during
    /// the FRP handshake.
    ///
    /// # Errors
    ///
    /// Returns [`FrpError::WebSocket`] wrapping the underlying I/O error if the
    /// address cannot be bound.
    pub fn bind(
        addr: impl ToSocketAddrs,
        supported_versions: &[&str],
    ) -> Result<Self, FrpError> {
        let listener = TcpListener::bind(addr)
            .map_err(|e| FrpError::WebSocket(Box::new(tungstenite::Error::Io(e))))?;
        Ok(Self {
            listener,
            supported_versions: supported_versions.iter().map(|&s| s.to_owned()).collect(),
        })
    }

    /// Blocks until a client connects, then performs the WebSocket upgrade and
    /// the FRP handshake.
    ///
    /// The handshake waits for the client's `Start` message. Any other frame,
    /// and any text that does not parse as `Start`, is skipped. The first
    /// version in the client's list that this server also supports is selected
    /// and sent back in an `Init` message.
    ///
    /// # Errors
    ///
    /// - [`FrpError::WebSocket`] if accepting the TCP connection or the
    ///   WebSocket upgrade fails.
    /// - [`FrpError::Handshake`] if the upgrade is interrupted, or if the client
    ///   shares no version with this server. In the latter case a critical
    ///   `Alert` and a Close frame are sent best-effort first.
    /// - [`FrpError::Closed`] if the client closes before sending `Start`.
    /// - Transport or serialization failures while exchanging handshake
    ///   messages, converted into [`FrpError`] via `?`.
    pub fn accept(&self) -> Result<FrpConnection, FrpError> {
        let (stream, _addr) = self
            .listener
            .accept()
            .map_err(|e| FrpError::WebSocket(Box::new(tungstenite::Error::Io(e))))?;
        let mut socket = tungstenite::accept(stream).map_err(|e| match e {
            tungstenite::HandshakeError::Failure(e) => FrpError::WebSocket(Box::new(e)),
            tungstenite::HandshakeError::Interrupted(_) => {
                FrpError::Handshake("WebSocket handshake interrupted".into())
            }
        })?;

        // Wait for the client's Start message; ignore anything that precedes it.
        let (client_versions, client_name) = loop {
            match socket.read()? {
                Message::Text(text) => {
                    if let Ok(FrpMessage::Protocol(FrpProtocolMessage::Start {
                        version,
                        name,
                    })) = FrpMessage::parse(&text)
                    {
                        break (version, name);
                    }
                }
                Message::Close(_) => return Err(FrpError::Closed),
                _ => {}
            }
        };

        match select_version(&self.supported_versions, &client_versions) {
            Some(version) => {
                let init = FrpProtocolMessage::Init {
                    version: version.clone(),
                };
                let json = serde_json::to_string(&init)?;
                socket.send(Message::text(json))?;
                Ok(FrpConnection {
                    socket,
                    version,
                    client_name,
                })
            }
            None => {
                const REASON: &str = "No compatible FRP version";
                let alert = FrpProtocolMessage::Alert {
                    severity: Severity::Critical,
                    message: REASON.into(),
                };
                let json = serde_json::to_string(&alert)?;
                // The handshake has already failed. Notifying the client and
                // closing are best-effort, so that a transport error here (for
                // example, the client already hung up) does not mask the
                // Handshake error, which is the real reason for failing.
                let _ = socket.send(Message::text(json));
                let _ = socket.close(None);
                Err(FrpError::Handshake(REASON.into()))
            }
        }
    }
}

impl fmt::Debug for FrpListener {
    /// Shows the supported versions; the raw TCP listener is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FrpListener")
            .field("supported_versions", &self.supported_versions)
            .finish_non_exhaustive()
    }
}
pub struct FrpConnection {
socket: WebSocket<TcpStream>,
version: String,
client_name: Option<String>,
}
impl FrpConnection {
#[must_use]
pub fn version(&self) -> &str {
&self.version
}
#[must_use]
pub fn client_name(&self) -> Option<&str> {
self.client_name.as_deref()
}
pub fn recv(&mut self) -> Result<FrpMessage, FrpError> {
loop {
match self.socket.read()? {
Message::Text(text) => return FrpMessage::parse(&text),
Message::Close(_) => return Err(FrpError::Closed),
_ => {}
}
}
}
pub fn try_recv(&mut self) -> Result<Option<FrpMessage>, FrpError> {
loop {
match self.socket.read() {
Ok(Message::Text(text)) => return Ok(Some(FrpMessage::parse(&text)?)),
Ok(Message::Close(_)) => return Err(FrpError::Closed),
Ok(_) => {}
Err(tungstenite::Error::Io(ref e))
if e.kind() == std::io::ErrorKind::WouldBlock =>
{
return Ok(None);
}
Err(e) => return Err(e.into()),
}
}
}
pub fn send(&mut self, msg: &FrpMessage) -> Result<(), FrpError> {
let json = msg.to_json()?;
self.socket.send(Message::text(json))?;
Ok(())
}
pub fn send_envelope(&mut self, env: &crate::FrpEnvelope) -> Result<(), FrpError> {
let json = serde_json::to_string(env)?;
self.socket.send(Message::text(json))?;
Ok(())
}
pub fn set_nonblocking(&self, nonblocking: bool) -> Result<(), FrpError> {
self.socket
.get_ref()
.set_nonblocking(nonblocking)
.map_err(|e| FrpError::WebSocket(Box::new(tungstenite::Error::Io(e))))
}
pub fn close(mut self) -> Result<(), FrpError> {
self.socket.close(None)?;
loop {
match self.socket.read() {
Ok(Message::Close(_)) | Err(tungstenite::Error::ConnectionClosed) => {
return Ok(());
}
Err(tungstenite::Error::AlreadyClosed) => return Ok(()),
Err(e) => return Err(e.into()),
_ => {}
}
}
}
}
impl fmt::Debug for FrpConnection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FrpConnection")
.field("version", &self.version)
.field("client_name", &self.client_name)
.finish_non_exhaustive()
}
}
/// Chooses the protocol version for a session.
///
/// Walks `client` in order (the client's preference) and returns the first
/// entry that `server` also lists. Returns `None` when the lists share no
/// version.
fn select_version(server: &[String], client: &[String]) -> Option<String> {
    client
        .iter()
        .find(|&candidate| server.contains(candidate))
        .cloned()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned version list from string literals.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_owned()).collect()
    }

    /// The client's most-preferred version that the server also supports wins.
    #[test]
    fn version_selection() {
        let picked = select_version(&owned(&["0.1.0", "0.2.0"]), &owned(&["0.2.0", "0.1.0"]));
        assert_eq!(picked.as_deref(), Some("0.2.0"));
    }

    /// Disjoint version lists yield no selection.
    #[test]
    fn version_selection_no_match() {
        let picked = select_version(&owned(&["0.1.0"]), &owned(&["0.2.0"]));
        assert!(picked.is_none());
    }
}