use crate::HashMap;
use rustdb::{GenTransaction, Transaction};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::{broadcast, mpsc, oneshot};
pub struct SharedState {
pub spd: Arc<rustdb::SharedPagedData>,
pub bmap: Arc<rustdb::BuiltinMap>,
pub update_tx: mpsc::Sender<UpdateMessage>,
pub email_tx: mpsc::UnboundedSender<()>,
pub sleep_tx: mpsc::UnboundedSender<u64>,
pub wait_tx: broadcast::Sender<()>,
pub is_master: bool,
pub replicate_source: String,
pub replicate_credentials: String,
pub dos_limit: UA,
pub dos: Mutex<HashMap<String, UseInfo>>,
pub tracetime: bool,
pub tracedos: bool,
pub tracemem: bool,
}
pub type UA = [u64; 4];
pub const U_COUNT: usize = 0;
pub const U_READ: usize = 1;
pub const U_CPU: usize = 2;
pub const U_WRITE: usize = 3;
#[derive(Debug)]
pub struct UseInfo {
pub used: UA,
pub limit: UA,
}
impl UseInfo {
fn new(limit: &UA) -> Self {
Self {
used: [0, 0, 0, 0],
limit: *limit,
}
}
}
impl Default for UseInfo {
fn default() -> Self {
Self::new(&[0, 0, 0, 0])
}
}
impl SharedState {
pub fn u_budget(&self, uid: String) -> UA {
let mut m = self.dos.lock().unwrap();
let info = m
.entry(uid)
.or_insert_with(|| UseInfo::new(&self.dos_limit));
let mut result = [0; 4];
for (i, item) in result.iter_mut().enumerate() {
if info.used[i] >= info.limit[i] {
return [0; 4];
}
*item = info.limit[i] - info.used[i];
}
result
}
pub fn u_inc(&self, uid: &str, amount: UA) {
let mut m = self.dos.lock().unwrap();
if let Some(info) = m.get_mut(uid) {
for (i, amt) in amount.iter().enumerate() {
info.used[i] += *amt;
}
if self.tracedos {
println!(
"uid={} Count={}% Read={}% Cpu={}% Write={}%",
uid,
100. * info.used[0] as f32 / info.limit[0] as f32,
100. * info.used[1] as f32 / info.limit[1] as f32,
100. * info.used[2] as f32 / info.limit[2] as f32,
100. * info.used[3] as f32 / info.limit[3] as f32,
);
}
}
}
pub fn u_set_limits(&self, uid: String, limit: UA) -> bool {
let mut m = self.dos.lock().unwrap();
let info = m.entry(uid).or_default();
info.limit = limit;
for i in 0..4 {
if info.used[i] >= info.limit[i] {
return false;
}
}
true
}
pub fn u_decay(&self) {
let mut m = self.dos.lock().unwrap();
m.retain(|_uid, info| {
let mut nz = false;
for i in 0..4 {
if info.used[i] > 0 {
info.used[i] -= 1 + info.used[i] / 10;
}
nz = nz || info.used[i] > 0;
}
nz
});
}
pub fn new_trans(&self) {
let _ = self.wait_tx.send(());
}
pub async fn process(&self, mut trans: Trans) -> Trans {
let start = std::time::SystemTime::now();
let mut trans = if trans.readonly {
let spd = self.spd.clone();
let bmap = self.bmap.clone();
let task = tokio::task::spawn_blocking(move || {
let apd = spd.new_reader();
let db = rustdb::Database::new(apd, "", bmap);
let sql = trans.x.qy.sql.clone();
db.run(&sql, &mut trans.x);
trans
});
task.await.unwrap()
} else {
let (reply, rx) = oneshot::channel::<Trans>();
let _ = self.update_tx.send(UpdateMessage { trans, reply }).await;
rx.await.unwrap()
};
if trans.updates > 0 {
self.new_trans();
}
trans.run_time = start.elapsed().unwrap();
let ext = trans.x.get_extension();
if let Some(ext) = ext.downcast_ref::<TransExt>() {
trans.uid = ext.uid.clone();
if self.is_master {
if ext.sleep > 0 {
let _ = self.sleep_tx.send(ext.sleep);
}
if ext.tx_email {
let _ = self.email_tx.send(());
}
}
if ext.trans_wait {
let mut wait_rx = self.wait_tx.subscribe();
tokio::select! {
_ = wait_rx.recv() => {}
_ = tokio::time::sleep(Duration::from_secs(600)) => {}
}
}
if ext.trans_flush {
let spd = self.spd.clone();
let _ = tokio::task::spawn_blocking(move || {
spd.wait_complete();
})
.await;
}
if ext.to_pdf {
trans.convert_to_pdf();
}
}
trans.x.set_extension(ext);
trans
}
}
pub struct Trans {
pub x: GenTransaction,
pub log: bool,
pub readonly: bool,
pub run_time: core::time::Duration,
pub updates: usize,
pub uid: String,
}
impl Trans {
fn make() -> Self {
Self {
x: GenTransaction::new(),
log: true,
readonly: false,
run_time: Duration::from_micros(0),
updates: 0,
uid: String::new(),
}
}
pub fn new() -> Self {
let mut result = Self::make();
result.x.ext = TransExt::new();
result
}
pub fn new_with_state(ss: Arc<SharedState>, ip: String) -> Self {
let mut result = Self::make();
let mut ext = TransExt::new();
ext.ss = Some(ss);
ext.uid = ip;
result.x.ext = ext;
result
}
fn convert_to_pdf(&mut self) {
let source = &self.x.rp.output;
let mut w = pdf_min::Writer::default();
pdf_min::html(&mut w, source);
w.finish();
self.x.rp.output = w.b.b;
}
pub fn no_log(&mut self) -> bool {
let mut result = false;
let ext = self.x.get_extension();
if let Some(ext) = ext.downcast_ref::<TransExt>() {
result = ext.no_log;
}
self.x.set_extension(ext);
result
}
}
impl Default for Trans {
fn default() -> Self {
Self::new()
}
}
pub struct UpdateMessage {
pub trans: Trans,
pub reply: oneshot::Sender<Trans>,
}
pub struct TransExt {
pub ss: Option<Arc<SharedState>>,
pub uid: String,
pub tx_email: bool,
pub sleep: u64,
pub trans_wait: bool,
pub trans_flush: bool,
pub to_pdf: bool,
pub no_log: bool,
}
impl TransExt {
fn new() -> Box<Self> {
Box::new(Self {
ss: None,
uid: String::new(),
tx_email: false,
sleep: 0,
trans_wait: false,
trans_flush: false,
to_pdf: false,
no_log: false,
})
}
pub fn set_dos(&self, uid: String, to: UA) -> bool {
if let Some(ss) = &self.ss {
ss.u_set_limits(uid, to)
} else {
true
}
}
}
#[derive(Debug)]
pub struct Error {
pub code: u16,
}
impl From<std::io::Error> for Error {
fn from(_e: std::io::Error) -> Self {
Self { code: 500 }
}
}
impl From<std::str::Utf8Error> for Error {
fn from(_e: std::str::Utf8Error) -> Self {
Self { code: 400 } }
}
impl From<serde_urlencoded::de::Error> for Error {
fn from(_e: serde_urlencoded::de::Error) -> Self {
Self { code: 400 } }
}
impl std::error::Error for Error {}
impl core::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(f, "{}", self.code)
}
}