use base64::{Engine, prelude::BASE64_STANDARD_NO_PAD};
use futures_concurrency::future::Race;
use futures_util::{AsyncReadExt, AsyncWriteExt};
use stdcode::StdcodeSerializeExt;
use crate::{Config, auth::IS_PLUS, bw_token::bw_token_consume};
use anyctx::AnyCtx;
use std::{
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::{Duration, Instant},
};
const BW_TOKEN_BYTES: u64 = 10_000_000;
#[derive(Clone)]
pub struct BwAccountingHandle {
force_refresh: smol::channel::Sender<()>,
accounted_bytes: Arc<AtomicU64>,
}
pub struct BwAccountingLoop {
handle: BwAccountingHandle,
force_refresh: smol::channel::Receiver<()>,
}
pub fn bw_accounting_pair() -> (BwAccountingHandle, BwAccountingLoop) {
let (send, recv) = smol::channel::unbounded();
let handle = BwAccountingHandle {
force_refresh: send,
accounted_bytes: Arc::new(AtomicU64::new(0)),
};
let accounting_loop = BwAccountingLoop {
handle: handle.clone(),
force_refresh: recv,
};
(handle, accounting_loop)
}
pub fn notify_bw_accounting(
ctx: &AnyCtx<Config>,
accounting: &BwAccountingHandle,
consumed: usize,
) {
if !ctx.get(IS_PLUS).load(Ordering::SeqCst) {
return;
}
let old_bytes = accounting
.accounted_bytes
.fetch_add(consumed as u64, Ordering::SeqCst);
let new_bytes = old_bytes + consumed as u64;
let tokens_due = new_bytes / BW_TOKEN_BYTES - old_bytes / BW_TOKEN_BYTES;
for _ in 0..tokens_due {
let _ = accounting.force_refresh.try_send(());
}
}
#[tracing::instrument(skip_all)]
pub async fn bw_accounting_client_loop(
ctx: AnyCtx<Config>,
stream: picomux::Stream,
accounting: BwAccountingLoop,
) -> anyhow::Result<()> {
tracing::info!("starting bandwidth accounting");
let (mut read, mut write) = stream.split();
let read_fut = {
async {
let mut buf = [0u8; 8];
let mut last_log = Instant::now() - Duration::from_secs(5);
loop {
read.read_exact(&mut buf).await?;
let current_bytes = u64::from_be_bytes(buf);
if last_log.elapsed() >= Duration::from_secs(5) {
tracing::debug!(current_bytes, "received remote bandwidth accounting update");
last_log = Instant::now();
}
if current_bytes == 0 {
tracing::debug!(
current_bytes,
"remote bytes went to zero, FORCING a refresh"
);
}
if current_bytes == 0 {
let _ = accounting.handle.force_refresh.send(()).await;
}
}
}
};
let write_fut = {
async {
loop {
let _ = accounting.force_refresh.recv().await;
if !ctx.get(IS_PLUS).load(Ordering::SeqCst) {
tracing::debug!("not plus, skipping bandwidth token send");
continue;
}
let (token, sig) = bw_token_consume(&ctx).await?;
let enc = BASE64_STANDARD_NO_PAD.encode((token, sig).stdcode());
write.write_all(enc.as_bytes()).await?;
write.write_all(b"\n").await?;
tracing::debug!("consuming a bandwidth token");
}
}
};
(read_fut, write_fut).race().await
}