mod executor;
mod permission;
pub use permission::*;
use crate::{
embedded::{
language::QueryExecutor,
messaging::{Request, Response},
storage::Storage,
},
err,
};
use executor::Executor;
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::{
sync::{mpsc, oneshot, Mutex, RwLock},
time,
};
/// A request waiting in the server's queue, paired with the one-shot channel on which
/// its `Response` is delivered back to the caller.
struct ServerRequest {
    /// The operation the caller asked for.
    request: Request,
    /// Reply channel; the task that handles `request` sends its `Response` here.
    sender: oneshot::Sender<Response>,
}
/// Embedded request server.
///
/// Requests are queued on an unbounded channel, checked against a shared permission,
/// and executed by a shared executor. Every field is a shared handle (`Arc` or a channel
/// sender), so clones refer to the same executor, queue, and permission.
#[derive(Clone)]
pub struct Server {
    /// Runs the individual operations (set/get/delete/...).
    executor: Arc<Executor>,
    /// Permission each request is checked against before it is executed.
    permission: Arc<RwLock<Permission>>,
    /// Producer side of the request queue; `cast` pushes onto it.
    sender: mpsc::UnboundedSender<ServerRequest>,
    /// Consumer side of the request queue, behind a mutex so any clone can drive `listen`.
    receiver: Arc<Mutex<mpsc::UnboundedReceiver<ServerRequest>>>,
}
impl Server {
pub fn new(storage: impl Storage) -> Self {
let (sender, receiver) = mpsc::unbounded_channel();
Self {
executor: Arc::new(Executor::new(storage)),
sender,
receiver: Arc::new(Mutex::new(receiver)),
permission: Default::default(),
}
}
pub async fn set_permission(&self, permission: Permission) {
*self.permission.write().await = permission;
}
pub async fn start(&self) {
let server = self.clone();
tokio::spawn(async move {
server.listen().await;
});
}
pub async fn listen(&self) {
while let Some(request) = self.receiver.lock().await.recv().await {
let ServerRequest { sender, request } = request;
let executor = Arc::clone(&self.executor);
let permission = Arc::clone(&self.permission);
tokio::spawn(async move {
let is_allowed = { permission.read().await.allowed(&request) };
if let Err(error) = is_allowed {
sender.send(error.as_response()).ok();
} else {
let response = match request {
Request::Set(key, value) => executor.set(key, value).await,
Request::Get(key) => executor.get(key).await,
Request::Delete(key) => executor.delete(key).await,
Request::Exists(key) => executor.exists(key).await,
Request::Increment(key, num) => executor.increment(key, num).await,
Request::Decrement(key, num) => executor.decrement(key, num).await,
Request::Search(key) => executor.search(key).await,
Request::Flush => executor.flush().await,
Request::DowngradePermission => {
let mut permission = permission.write().await;
*permission = permission.lower();
Response::Ok
}
};
sender.send(response).ok();
}
});
}
}
pub async fn cast(&self, request: Request) -> super::Result<oneshot::Receiver<Response>> {
let (sender, receiver) = oneshot::channel();
let request = ServerRequest { sender, request };
if self.sender.send(request).is_err() {
Err(err!(embedded, SendFail))
} else {
Ok(receiver)
}
}
pub async fn call(&self, request: Request) -> super::Result<Response> {
self.cast(request)
.await?
.await
.map_err(|_| err!(embedded, RecvFail))
}
pub async fn call_in(&self, request: Request, timeout: Duration) -> super::Result<Response> {
time::timeout(timeout, self.call(request))
.await
.map_err(|_| err!(embedded, RecvTimeout))?
}
pub async fn query<T>(
&self,
query: T,
env: HashMap<String, String>,
) -> super::Result<Vec<Response>>
where
T: ToString,
{
let mut runtime = QueryExecutor::new(query.to_string(), env);
runtime.execute(self).await
}
}