use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use anyhow::anyhow;
use bytes::BytesMut;
use crate::average::Average;
use crate::platform::Platform;
use crate::request::Request;
use crate::response::{OutputError, Response};
use crate::server::Connection;
/// Error outcome of a command handler.
///
/// `Call::complete` decides how to report it based on the variant: an `OutputError` is
/// forwarded unchanged to the waiting callback, while `ClientError` and `ServerError` are
/// rendered into the response as an error message prefixed with `CLIENT: ` / `SERVER: `.
#[derive(Debug)]
pub enum CommandError {
    /// Writing the response failed (see `crate::response::OutputError`).
    OutputError(OutputError),
    /// The request itself was invalid; reported with a `CLIENT: ` prefix.
    /// Also produced by `?` on any `anyhow::Error` via the `From` impl in this module.
    ClientError(anyhow::Error),
    /// The server failed to handle an otherwise valid request; reported with a `SERVER: ` prefix.
    ServerError(anyhow::Error),
}
/// Builds a `CommandError::ServerError` from either a single message/error expression or a
/// format string with arguments, using the same syntax as `anyhow::anyhow!`.
///
/// The path goes through `$crate` so the macro resolves correctly both inside this crate and
/// in downstream crates, regardless of the name the dependency is imported under.
///
/// NOTE(review): the expansion still names `anyhow::anyhow!` directly, so calling crates need
/// `anyhow` as a direct dependency.
#[macro_export]
macro_rules! server_error {
    ($err:expr $(,)?) => ({
        $crate::commands::CommandError::ServerError(anyhow::anyhow!($err))
    });
    ($fmt:expr, $($arg:tt)*) => {
        $crate::commands::CommandError::ServerError(anyhow::anyhow!($fmt, $($arg)*))
    };
}
/// Builds a `CommandError::ClientError` from either a single message/error expression or a
/// format string with arguments, using the same syntax as `anyhow::anyhow!`.
///
/// The path goes through `$crate` so the macro resolves correctly both inside this crate and
/// in downstream crates, regardless of the name the dependency is imported under.
///
/// NOTE(review): the expansion still names `anyhow::anyhow!` directly, so calling crates need
/// `anyhow` as a direct dependency.
#[macro_export]
macro_rules! client_error {
    ($err:expr $(,)?) => ({
        $crate::commands::CommandError::ClientError(anyhow::anyhow!($err))
    });
    ($fmt:expr, $($arg:tt)*) => {
        $crate::commands::CommandError::ClientError(anyhow::anyhow!($fmt, $($arg)*))
    };
}
impl From<OutputError> for CommandError {
fn from(output_error: OutputError) -> Self {
CommandError::OutputError(output_error)
}
}
impl From<anyhow::Error> for CommandError {
fn from(error: anyhow::Error) -> Self {
CommandError::ClientError(error)
}
}
pub type CommandResult = std::result::Result<(), CommandError>;
pub trait ResultExt {
fn complete(self, call: Call);
}
impl ResultExt for CommandResult {
fn complete(self, call: Call) {
call.complete(self);
}
}
pub struct Call {
pub request: Request,
pub response: Response,
pub token: usize,
callback: tokio::sync::oneshot::Sender<Result<BytesMut, OutputError>>,
}
impl Call {
#[allow(clippy::question_mark)] pub fn complete(mut self, result: CommandResult) {
let result = match result {
Ok(_) => self.response.complete(),
Err(CommandError::OutputError(error)) => Err(error),
Err(CommandError::ClientError(error)) => {
if let Err(error) = self.response.error(format!("CLIENT: {}", error)) {
Err(error)
} else {
self.response.complete()
}
}
Err(CommandError::ServerError(error)) => {
if let Err(error) = self.response.error(format!("SERVER: {}", error)) {
Err(error)
} else {
self.response.complete()
}
}
};
if self.callback.send(result).is_err() {
log::error!("Failed to submit a result to a oneshot callback channel!");
}
}
pub fn handle_unknown_token(self) {
let token = self.token;
self.complete(Err(CommandError::ServerError(anyhow::anyhow!(
"Unknown token received: {}!",
token
))));
}
}
/// Sending half used to submit `Call`s to a command handler.
pub type Queue = tokio::sync::mpsc::Sender<Call>;

/// Receiving half from which a command handler reads its `Call`s.
pub type Endpoint = tokio::sync::mpsc::Receiver<Call>;

/// Maximum number of calls buffered in a queue; once full, `send(..).await` waits for room.
const QUEUE_CAPACITY: usize = 1024;

/// Creates a bounded queue (see `QUEUE_CAPACITY`) connecting the dispatcher with a handler.
pub fn queue() -> (Queue, Endpoint) {
    tokio::sync::mpsc::channel(QUEUE_CAPACITY)
}
/// A registered command.
///
/// It holds the command name, the queue its calls are sent to, the token attached to each call,
/// and the timing metrics recorded by the dispatcher.
pub struct Command {
    pub name: &'static str,
    queue: Queue,
    token: usize,
    call_metrics: Average,
}

impl Command {
    /// Number of samples in the call metrics.
    ///
    /// The dispatcher records one sample per call that yielded a result.
    pub fn call_count(&self) -> u64 {
        let metrics = &self.call_metrics;
        metrics.count()
    }

    /// Average of the recorded call durations.
    ///
    /// The dispatcher records samples in microseconds.
    pub fn avg_duration(&self) -> i32 {
        let metrics = &self.call_metrics;
        metrics.avg()
    }
}
/// Registry of all commands known to the server, keyed by command name.
///
/// Access is guarded by a mutex so commands can be registered through a shared reference.
#[derive(Default)]
pub struct CommandDictionary {
    commands: Mutex<HashMap<&'static str, Arc<Command>>>,
}
/// Routes requests to registered commands or to the built-in handlers.
///
/// It holds a copy of the registry taken when `CommandDictionary::dispatcher` was called, so
/// commands registered later are not visible to it.
/// NOTE(review): presumably one instance per connection — confirm with callers.
pub struct Dispatcher {
    commands: HashMap<&'static str, (Arc<Command>, Queue)>,
}
impl CommandDictionary {
pub fn new() -> Self {
CommandDictionary {
commands: Mutex::new(HashMap::default()),
}
}
pub fn install(platform: &Arc<Platform>) -> Arc<Self> {
let commands = Arc::new(CommandDictionary::new());
platform.register::<CommandDictionary>(commands.clone());
commands
}
pub fn register_command(&self, name: &'static str, queue: Queue, token: usize) {
let mut commands = self.commands.lock().unwrap();
if commands.get(name).is_some() {
log::error!("Not going to register command {} as there is already a command present for this name",
name);
} else {
log::debug!("Registering command {}...", name);
let _ = commands.insert(
name,
Arc::new(Command {
name,
queue,
token,
call_metrics: Average::new(),
}),
);
}
}
pub fn commands(&self) -> Vec<Arc<Command>> {
let mut result = Vec::new();
for command in self.commands.lock().unwrap().values() {
result.push(command.clone());
}
result
}
pub fn dispatcher(&self) -> Dispatcher {
let commands = self.commands.lock().unwrap();
let mut cloned_commands = HashMap::with_capacity(commands.len());
for command in commands.values() {
let _ = cloned_commands.insert(command.name, (command.clone(), command.queue.clone()));
}
Dispatcher {
commands: cloned_commands,
}
}
}
impl Dispatcher {
    /// Executes `request` and returns the serialized response.
    ///
    /// Registered commands are forwarded to their queue. Anything else goes to the built-in
    /// handlers (QUIT, CLIENT, PING), which report unknown commands as a `CLIENT:` error.
    /// Registered command names are matched case-sensitively; built-ins are matched
    /// case-insensitively.
    pub async fn invoke(
        &mut self,
        request: Request,
        connection: Option<&Arc<Connection>>,
    ) -> Result<BytesMut, OutputError> {
        let response = Response::new();
        match self.commands.get_mut(request.command()) {
            Some((command, queue)) => {
                Dispatcher::invoke_command(command, queue, request, response).await
            }
            None => self.handle_built_in(request, response, connection).await,
        }
    }

    /// Handles the commands that are answered by the dispatcher itself.
    ///
    /// * `QUIT`: calls `quit()` on the connection (if any) and answers OK.
    /// * `CLIENT SETNAME <name>`: sets the connection name (if any) and answers OK. Any other
    ///   `CLIENT` sub-command is also answered OK without further action.
    /// * `PING [msg]`: echoes `msg` as a bulk string, or answers `PONG`.
    /// * anything else: a `CLIENT: Unknown command: ...` error response.
    async fn handle_built_in(
        &mut self,
        request: Request,
        mut response: Response,
        connection: Option<&Arc<Connection>>,
    ) -> Result<BytesMut, OutputError> {
        match request.command().to_uppercase().as_str() {
            "QUIT" => {
                if let Some(connection) = connection {
                    connection.quit();
                }
                response.ok()?;
            }
            "CLIENT" => {
                if request.str_parameter(0)?.to_uppercase() == "SETNAME" {
                    if let Some(connection) = connection {
                        connection.set_name(request.str_parameter(1)?);
                    }
                }
                response.ok()?;
            }
            "PING" => {
                if request.parameter_count() > 0 {
                    response.bulk(request.str_parameter(0)?)?;
                } else {
                    response.simple("PONG")?;
                }
            }
            _ => response.error(format!("CLIENT: Unknown command: {}", request.command()))?,
        }
        response.complete()
    }

    /// Submits `request` to the command's queue and waits for the handler's result.
    ///
    /// The elapsed time, measured from just before queue submission, is recorded in microseconds
    /// in the command's metrics when a result arrives. Fails with a protocol error if the queue
    /// is closed or the handler drops the call without completing it.
    async fn invoke_command(
        command: &Arc<Command>,
        queue: &mut Queue,
        request: Request,
        response: Response,
    ) -> Result<BytesMut, OutputError> {
        let (callback, promise) = tokio::sync::oneshot::channel();
        let task = Call {
            request,
            response,
            callback,
            token: command.token,
        };

        let watch = Instant::now();
        if queue.send(task).await.is_err() {
            return Err(OutputError::ProtocolError(anyhow!(
                "Failed to submit command into queue!"
            )));
        }

        match promise.await {
            Ok(result) => {
                // as_micros() yields a u128. Clamp it instead of truncating with a plain `as`
                // cast so that an extremely slow call cannot wrap into a bogus (even negative)
                // sample.
                let micros = watch.elapsed().as_micros().min(i32::MAX as u128) as i32;
                command.call_metrics.add(micros);
                result
            }
            Err(_) => Err(OutputError::ProtocolError(anyhow!(
                "Command {} did not yield any result!",
                command.name
            ))),
        }
    }
}