use crate::average::Average;
use crate::spawn;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use bytes::{BufMut, BytesMut};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use crate::commands::CommandDictionary;
use crate::config::Config;
use crate::platform::Platform;
use crate::request::Request;
use crate::response::OutputError;
use arc_swap::ArcSwap;
use std::sync::Mutex;
use tokio::net::tcp::WriteHalf;
/// Initial capacity (in bytes) of a per-connection input buffer; buffers that grew beyond this
/// size are replaced by a fresh one of this size once the buffered request has been served.
const DEFAULT_BUFFER_SIZE: usize = 8 * 1024;
/// Upper bound for a single socket read, so the protocol loop periodically re-checks whether the
/// platform is still running and the connection is still active.
const READ_WAIT_TIMEOUT: Duration = Duration::from_millis(500);
/// Upper bound for a single `accept` call, so the server loop periodically re-checks its
/// shutdown flags.
const CONNECT_WAIT_TIMEOUT: Duration = Duration::from_millis(500);
/// A single client connection tracked by the server.
///
/// Equality is defined by the peer address only (see the `PartialEq` impl below).
pub struct Connection {
    /// Remote address of the client as reported by the socket.
    peer_address: String,
    /// `true` while the connection should keep being served; cleared by [`Connection::quit`].
    active: AtomicBool,
    /// Execution-time statistics of the commands served on this connection.
    commands: Average,
    /// Client name assigned via [`Connection::set_name`]; `None` until a name is set.
    name: ArcSwap<Option<String>>,
}

impl PartialEq for Connection {
    /// Two connections compare equal when they report the same peer address.
    fn eq(&self, other: &Self) -> bool {
        other.peer_address.eq(&self.peer_address)
    }
}

impl Connection {
    /// Returns whether this connection is still meant to be served.
    pub fn is_active(&self) -> bool {
        self.active.load(Ordering::Acquire)
    }

    /// Flags this connection as inactive so that its serving loop stops at its next check.
    pub fn quit(&self) {
        self.active.store(false, Ordering::Release)
    }

    /// Replaces the client name of this connection.
    pub fn set_name(&self, name: &str) {
        let new_name = Some(String::from(name));
        self.name.store(Arc::new(new_name));
    }

    /// Returns a shared handle to the current client name (`None` if no name was set).
    pub fn get_name(&self) -> Arc<Option<String>> {
        self.name.load_full()
    }

    /// Returns the command execution statistics of this connection.
    pub fn commands(&self) -> &Average {
        &self.commands
    }
}
/// Point-in-time snapshot of a client connection, as produced by `Server::connections`.
pub struct ConnectionInfo {
    /// Client name assigned to the connection, or an empty string if none was set.
    pub client: String,
    /// Remote address of the client.
    pub peer_address: String,
    /// Copy of the connection's command execution statistics taken at snapshot time.
    pub commands: Average,
}
/// TCP server that accepts client connections and serves them using the RESP protocol.
pub struct Server {
    /// Platform the server is registered with; provides config, commands and the shutdown flag.
    platform: Arc<Platform>,
    /// Set by the event loop once it resolved the address to bind; cleared when the configured
    /// address changes so that the server socket gets re-opened.
    running: AtomicBool,
    /// Address the server socket was most recently bound to (`None` before the first bind).
    current_address: Mutex<Option<String>>,
    /// All client connections that are currently being served.
    connections: Mutex<Vec<Arc<Connection>>>,
}
impl Server {
    /// Creates a stopped server and registers it with `platform`, so that it can later be
    /// retrieved via `platform.require::<Server>()`.
    ///
    /// Call [`Server::fork`] or [`Server::fork_and_await`] to start accepting connections.
    pub fn install(platform: &Arc<Platform>) -> Arc<Self> {
        let server = Arc::new(Server {
            running: AtomicBool::new(false),
            current_address: Mutex::new(None),
            platform: platform.clone(),
            connections: Mutex::new(Vec::new()),
        });
        platform.register::<Server>(server.clone());
        server
    }

    /// Returns a snapshot of all client connections that are currently being served.
    ///
    /// Connections without a client name are reported with an empty `client` string.
    pub fn connections(&self) -> Vec<ConnectionInfo> {
        self.connections
            .lock()
            .unwrap()
            .iter()
            .map(|connection| ConnectionInfo {
                peer_address: connection.peer_address.clone(),
                commands: connection.commands.clone(),
                client: connection.name.load().as_deref().unwrap_or("").to_owned(),
            })
            .collect()
    }

    /// Flags the first connection whose peer address equals `peer_address` as inactive, which
    /// makes its protocol loop stop at its next check.
    ///
    /// Returns `true` if a matching connection was found.
    pub fn kill(&self, peer_address: &str) -> bool {
        let connections = self.connections.lock().unwrap();
        match connections
            .iter()
            .find(|connection| connection.peer_address == peer_address)
        {
            Some(connection) => {
                connection.quit();
                true
            }
            None => false,
        }
    }

    /// Registers a newly opened connection.
    fn add_connection(&self, connection: Arc<Connection>) {
        self.connections.lock().unwrap().push(connection);
    }

    /// Unregisters exactly the given connection instance.
    fn remove_connection(&self, connection: Arc<Connection>) {
        let mut connections = self.connections.lock().unwrap();
        // Compare by identity: `Connection::eq` only compares peer addresses, which may collide
        // (e.g. several connections whose address could not be determined share "<unknown>"),
        // and would then remove some other, still open connection.
        if let Some(index) = connections
            .iter()
            .position(|other| Arc::ptr_eq(other, &connection))
        {
            let _ = connections.remove(index);
        }
    }

    /// Returns whether the event loop currently considers the server to be running.
    fn is_running(&self) -> bool {
        self.running.load(Ordering::Acquire)
    }

    /// Determines the `host:port` address to bind from `server.host` / `server.port` in the
    /// current config, falling back to `0.0.0.0` and port `2410` respectively. Ports outside
    /// `1..=65535` are ignored in favour of the default.
    fn address(&self) -> String {
        self.platform
            .find::<Config>()
            .map(|config| {
                let handle = config.current();
                let settings = handle.config();
                let host = settings["server"]["host"].as_str().unwrap_or("0.0.0.0");
                let port = settings["server"]["port"]
                    .as_i64()
                    .filter(|port| *port > 0 && *port <= i64::from(u16::MAX))
                    .unwrap_or(2410);
                format!("{}:{}", host, port)
            })
            .unwrap_or_else(|| "0.0.0.0:2410".to_owned())
    }

    /// Starts the event loop of `server` in a background task.
    pub fn fork(server: &Arc<Server>) {
        let cloned_server = server.clone();
        spawn!(async move {
            cloned_server.event_loop().await;
        });
    }

    /// Starts the event loop and waits (polling once per second) until it reports the server as
    /// running.
    ///
    /// NOTE(review): the running flag is set before the socket is bound, so this may return
    /// before connections are actually accepted.
    pub async fn fork_and_await(server: &Arc<Server>) {
        Server::fork(server);
        while !server.is_running() {
            tokio::time::sleep(Duration::from_secs(1)).await;
        }
    }

    /// Main loop of the server; runs until the platform stops.
    ///
    /// Binds the configured address and serves connections via `server_loop`. Whenever that
    /// loop exits (address change, accept failure) the socket is re-opened. Bind failures are
    /// retried every 500 ms and logged at most about once every few seconds.
    pub async fn event_loop(&self) {
        // Pause between two bind attempts.
        const BIND_RETRY_INTERVAL: Duration = Duration::from_millis(500);
        // A bind failure is logged again only after more than this many whole seconds.
        const BIND_ERROR_REPORT_SECS: u64 = 5;

        let mut address = String::new();
        // `None` until a failure was reported, so that the very first failure is logged right
        // away instead of being suppressed during the first seconds.
        let mut last_bind_error_reported: Option<Instant> = None;
        while self.platform.is_running() {
            if !self.is_running() {
                address = self.address();
                self.running.store(true, Ordering::Release);
            }
            match TcpListener::bind(&address).await {
                Ok(mut listener) => {
                    log::info!("Opened server socket on {}...", &address);
                    *self.current_address.lock().unwrap() = Some(address.clone());
                    last_bind_error_reported = None;
                    self.server_loop(&mut listener).await;
                    log::info!("Closing server socket on {}.", &address);
                }
                Err(error) => {
                    let report_due = last_bind_error_reported.map_or(true, |reported| {
                        reported.elapsed().as_secs() > BIND_ERROR_REPORT_SECS
                    });
                    if report_due {
                        log::error!(
                            "Cannot open server address: {} ({}). Retrying every 500ms...",
                            &address,
                            error
                        );
                        last_bind_error_reported = Some(Instant::now());
                    }
                    tokio::time::sleep(BIND_RETRY_INTERVAL).await;
                    // No socket is open, so `server_loop` is not watching for config changes;
                    // re-resolve the address so a corrected host/port is used on the next try.
                    address = self.address();
                }
            }
        }
    }

    /// Accepts connections on `listener` until the platform stops, the server is flagged as not
    /// running, the configured address changes, or accepting a connection fails.
    async fn server_loop(&self, listener: &mut TcpListener) {
        let mut config_changed_flag = self.platform.require::<Config>().notifier();
        while self.platform.is_running() && self.is_running() {
            tokio::select! {
                // Bounded wait so the loop condition is re-checked regularly.
                accepted = tokio::time::timeout(CONNECT_WAIT_TIMEOUT, listener.accept()) => {
                    match accepted {
                        Ok(Ok((stream, _))) => self.handle_new_connection(stream),
                        Ok(Err(error)) => {
                            log::warn!(
                                "Failed to accept a connection: {}. Re-opening server socket...",
                                error
                            );
                            return;
                        }
                        // Timed out: loop around and re-check the flags.
                        Err(_) => {}
                    }
                }
                _ = config_changed_flag.recv() => {
                    let new_address = self.address();
                    let address_changed = matches!(
                        self.current_address.lock().unwrap().as_deref(),
                        Some(current) if current != new_address
                    );
                    if address_changed {
                        log::info!("Server address has changed. Restarting server socket...");
                        // Makes `event_loop` re-resolve the address before binding again.
                        self.running.store(false, Ordering::Release);
                        return;
                    }
                }
            }
        }
    }

    /// Spawns a task that serves `stream`. The connection is registered while the task runs
    /// and unregistered once its protocol loop has ended.
    fn handle_new_connection(&self, stream: TcpStream) {
        let platform = self.platform.clone();
        spawn!(async move {
            // Disable Nagle's algorithm so that small responses are sent without delay.
            let _ = stream.set_nodelay(true);
            let server = platform.require::<Server>();
            let connection = Arc::new(Connection {
                peer_address: stream
                    .peer_addr()
                    .map(|addr| addr.to_string())
                    .unwrap_or_else(|_| "<unknown>".to_owned()),
                active: AtomicBool::new(true),
                commands: Average::new(),
                name: ArcSwap::new(Arc::new(None)),
            });
            log::debug!("Opened connection from {}...", connection.peer_address);
            server.add_connection(connection.clone());
            if let Err(error) = resp_protocol_loop(platform, connection.clone(), stream).await {
                log::debug!(
                    "An IO error occurred in connection {}: {}",
                    connection.peer_address,
                    error
                );
            }
            log::debug!("Closing connection to {}...", connection.peer_address);
            server.remove_connection(connection);
        });
    }
}
async fn resp_protocol_loop(
platform: Arc<Platform>,
connection: Arc<Connection>,
mut stream: TcpStream,
) -> anyhow::Result<()> {
let mut dispatcher = platform.require::<CommandDictionary>().dispatcher();
let mut input_buffer = BytesMut::with_capacity(DEFAULT_BUFFER_SIZE);
let (mut reader, mut writer) = stream.split();
while platform.is_running() && connection.is_active() {
match tokio::time::timeout(READ_WAIT_TIMEOUT, reader.read_buf(&mut input_buffer)).await {
Ok(Ok(bytes_read)) if bytes_read > 0 => match Request::parse(&input_buffer) {
Ok(Some(request)) => {
log::debug!("Received {}", request.command());
let watch = Instant::now();
let request_len = request.len();
match dispatcher.invoke(request, Some(&connection)).await {
Ok(response_data) => {
connection.commands.add(watch.elapsed().as_micros() as i32);
writer.write_all(response_data.as_ref()).await?;
writer.flush().await?;
}
Err(error) => {
handle_error(error, &mut writer).await?;
return Ok(());
}
}
input_buffer = clear_input_buffer(input_buffer, request_len);
}
Err(error) => {
handle_protocol_error(error, &mut writer).await?;
return Ok(());
}
_ => (),
},
Ok(Ok(0)) => return Ok(()),
Ok(Err(error)) => {
return Err(anyhow::anyhow!(
"An error occurred while reading from the client: {}",
error
));
}
_ => (),
}
}
Ok(())
}
/// Reports a dispatcher error back to the client if it is one the client should see.
///
/// Only `OutputError::ProtocolError` is forwarded, as a `-SERVER:` error line. CR and LF
/// characters are replaced by spaces so the message stays on a single line. Any other error
/// variant is dropped without writing anything.
///
/// # Errors
/// Returns an error if writing to or flushing `writer` fails.
async fn handle_error(error: OutputError, writer: &mut WriteHalf<'_>) -> anyhow::Result<()> {
    match error {
        OutputError::ProtocolError(cause) => {
            let single_line = cause.to_string().replace(['\r', '\n'], " ");
            let response = format!("-SERVER: {}\r\n", single_line);
            writer.write_all(response.as_bytes()).await?;
            writer.flush().await?;
            Ok(())
        }
        _ => Ok(()),
    }
}
/// Answers a request that could not be parsed with a `-CLIENT:` error line.
///
/// CR and LF characters in the parser's message are replaced by spaces, as `handle_error`
/// does. The message may echo raw client input; an embedded line break would otherwise end the
/// error line early and let the remainder be read as further protocol frames.
///
/// # Errors
/// Returns an error if writing to or flushing `writer` fails.
async fn handle_protocol_error(
    error: anyhow::Error,
    writer: &mut WriteHalf<'_>,
) -> anyhow::Result<()> {
    let error_message = error.to_string().replace(['\r', '\n'], " ");
    writer
        .write_all(
            format!(
                "-CLIENT: A malformed RESP request was received: {}\r\n",
                error_message
            )
            .as_bytes(),
        )
        .await?;
    writer.flush().await?;
    Ok(())
}
/// Discards the first `request_len` bytes (the request that was just served) and returns a
/// buffer holding only the bytes that followed it.
///
/// If the buffer's capacity exceeds `DEFAULT_BUFFER_SIZE`, or unconsumed bytes remain, a fresh
/// buffer of the default capacity is allocated and the remainder is copied into it. This
/// releases memory claimed by an unusually large request. Otherwise the existing buffer is
/// emptied and reused.
fn clear_input_buffer(input_buffer: BytesMut, request_len: usize) -> BytesMut {
    let oversized = input_buffer.capacity() > DEFAULT_BUFFER_SIZE;
    let has_leftover = input_buffer.len() > request_len;
    if !(oversized || has_leftover) {
        // Nothing left over and the allocation is small enough to keep around.
        let mut reused = input_buffer;
        reused.clear();
        return reused;
    }
    let mut fresh = BytesMut::with_capacity(DEFAULT_BUFFER_SIZE);
    if has_leftover {
        fresh.put_slice(&input_buffer[request_len..]);
    }
    fresh
}
#[cfg(test)]
mod tests {
    use crate::builder::Builder;
    use crate::config::Config;
    use crate::server::Server;
    use crate::testing::{query_redis_async, test_async};

    /// End-to-end check: builds a platform, loads an inline config that sets the server port to
    /// 1503, starts the server and verifies that `PING` sent through a Redis client returns
    /// `PONG`.
    #[test]
    fn integration_test() {
        log::info!("Acquiring shared resources...");
        // Presumably serializes tests that share global resources (e.g. fixed ports) — verify
        // against crate::testing.
        let _guard = crate::testing::SHARED_TEST_RESOURCES.lock().unwrap();
        log::info!("Successfully acquired shared resources.");
        test_async(async {
            let platform = Builder::new().enable_all().disable_config().build().await;
            let _ = crate::config::install(platform.clone(), false).await;
            // NOTE(review): this YAML literal relies on `port` being nested under `server`; as
            // it appears here both keys start at column 0 — confirm the indentation inside the
            // string literal is intact in the real file.
            platform
                .require::<Config>()
                .load_from_string(
                    "
server:
port: 1503
",
                    None,
                )
                .unwrap();
            // Returns once the event loop reports itself as running.
            Server::fork_and_await(&platform.require::<Server>()).await;
            // Presumably connects to the port configured above — confirm in crate::testing.
            let result = query_redis_async(|con| redis::cmd("PING").query::<String>(con))
                .await
                .unwrap();
            assert_eq!(result, "PONG");
            platform.terminate();
        });
    }
}