use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::envelope::{DataCategory, Envelope};
use crate::options::ClientOptions;
mod ratelimit;
pub use ratelimit::RateLimiter;
/// Sink that hands envelopes off for delivery to the backend.
///
/// Transports are handed out as `Arc<dyn Transport>` by a [`TransportFactory`],
/// which is why implementations must be `Send + Sync`.
pub trait Transport: Send + Sync {
    /// Hands `envelope` to the transport. Implementations may drop it (for
    /// example when an internal queue is full) without reporting an error.
    fn send_envelope(&self, envelope: Envelope);

    /// Waits up to `timeout` for previously submitted envelopes to be
    /// processed; returns `true` if they were (or nothing was pending).
    fn flush(&self, timeout: Duration) -> bool;

    /// Flushes pending work using `timeout`, then releases the transport's
    /// resources; returns whether the flush succeeded.
    fn shutdown(&self, timeout: Duration) -> bool;
}
/// Constructs a shared [`Transport`] from a set of [`ClientOptions`].
pub trait TransportFactory: Send + Sync {
    /// Builds a new transport configured from `options`.
    fn create_transport(&self, options: &ClientOptions) -> Arc<dyn Transport>;
}
/// Factory that produces the HTTP-backed [`ReqwestTransport`].
pub struct DefaultTransportFactory;

impl TransportFactory for DefaultTransportFactory {
    /// Creates a fresh [`ReqwestTransport`]; each call spawns its own worker thread.
    fn create_transport(&self, options: &ClientOptions) -> Arc<dyn Transport> {
        let transport = ReqwestTransport::new(options);
        Arc::new(transport)
    }
}
/// Retries allowed after the first failed request in `deliver`; an envelope
/// therefore costs at most `MAX_RETRIES + 1` HTTP requests.
const MAX_RETRIES: u32 = 3;
/// Work item sent from a [`ReqwestTransport`] handle to its worker thread.
enum Message {
    /// Deliver this envelope.
    Send(Envelope),
    /// Signal on this channel once every message queued ahead of it has been
    /// processed by the worker.
    Flush(std::sync::mpsc::SyncSender<()>),
}
/// HTTP transport that queues envelopes for a dedicated background worker
/// thread, which posts them with `reqwest`.
pub struct ReqwestTransport {
    /// Bounded queue into the worker; becomes `None` once shut down.
    sender: Mutex<Option<std::sync::mpsc::SyncSender<Message>>>,
    /// Worker thread handle; `None` if spawning failed or it was already taken.
    handle: Mutex<Option<std::thread::JoinHandle<()>>>,
}
impl ReqwestTransport {
    /// Starts a transport for `options`.
    ///
    /// Spawns a thread named `allstak-transport` that runs a single-threaded
    /// Tokio runtime driving the worker loop. The queue between the handle and
    /// the worker holds `options.transport_queue_size` messages, and any
    /// trailing `/` is stripped from `options.host`.
    ///
    /// If the thread or its runtime cannot be created, the receiving end of the
    /// queue is dropped, so every later envelope is silently discarded.
    pub fn new(options: &ClientOptions) -> Self {
        let (tx, rx) = std::sync::mpsc::sync_channel::<Message>(options.transport_queue_size);
        let host = options.host.trim_end_matches('/').to_string();
        let api_key = options.api_key.clone();
        let user_agent = crate::util::user_agent();

        let worker = move || {
            let Ok(runtime) = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
            else {
                return;
            };
            runtime.block_on(worker_loop(rx, host, api_key, user_agent));
        };
        let handle = std::thread::Builder::new()
            .name("allstak-transport".into())
            .spawn(worker)
            .ok();

        Self {
            sender: Mutex::new(Some(tx)),
            handle: Mutex::new(handle),
        }
    }
}
/// Body of the transport's worker thread.
///
/// Builds one HTTP client (30s request timeout), then processes queued
/// messages strictly in order. Each envelope is fully handled before the next
/// message is read, so a `Flush` ack means everything queued before it is done.
/// Returns when the client cannot be built or when every sender is dropped.
async fn worker_loop(
    rx: std::sync::mpsc::Receiver<Message>,
    host: String,
    api_key: String,
    user_agent: String,
) {
    let Ok(client) = reqwest::Client::builder()
        .user_agent(user_agent)
        .timeout(Duration::from_secs(30))
        .build()
    else {
        return;
    };
    let limiter = RateLimiter::new();

    // `recv` blocks the thread. That is tolerable only because this runtime is
    // dedicated to this loop; it yields `Err` once all senders are gone.
    while let Ok(message) = rx.recv() {
        match message {
            Message::Send(envelope) => {
                deliver(&client, &host, &api_key, &limiter, envelope).await
            }
            Message::Flush(ack) => {
                // The flusher may have stopped waiting; ignore a closed channel.
                let _ = ack.send(());
            }
        }
    }
}
/// Posts one envelope to `host` + `env.path`, retrying transient failures.
///
/// - Nothing is sent if `limiter` already reports the envelope's category as limited.
/// - A `429 Too Many Requests` response passes its headers to `limiter` and
///   abandons the envelope.
/// - Transport errors and 5xx responses are retried after [`backoff`], up to
///   [`MAX_RETRIES`] times.
/// - Any other response ends delivery.
///
/// Failures are not reported to the caller.
async fn deliver(
    client: &reqwest::Client,
    host: &str,
    api_key: &str,
    limiter: &RateLimiter,
    env: Envelope,
) {
    if limiter.is_limited(env.category) {
        return;
    }
    let url = format!("{host}{}", env.path);
    let mut attempt = 0u32;
    loop {
        let result = client
            .post(&url)
            .header("X-AllStak-Key", api_key)
            .header("Content-Type", "application/json")
            .json(&env.body)
            .send()
            .await;
        // Classify the outcome first so the response (and the connection it
        // holds) is dropped here rather than kept alive across the backoff sleep.
        let retryable = match result {
            Ok(resp) if resp.status() == reqwest::StatusCode::TOO_MANY_REQUESTS => {
                limiter.update_from_response(env.category, resp.headers());
                return;
            }
            Ok(resp) => resp.status().is_server_error(),
            Err(_) => true,
        };
        if !retryable || attempt >= MAX_RETRIES {
            return;
        }
        backoff(attempt).await;
        attempt += 1;
    }
}
/// Sleeps before retry `attempt` (0-based) using capped exponential backoff.
///
/// The delay is `250ms * 2^attempt`, capped at 4 seconds. `checked_shl`
/// saturates large `attempt` values to the cap. A plain `1 << attempt` would
/// panic in debug builds and wrap to a short delay in release builds once
/// `attempt >= 64`.
async fn backoff(attempt: u32) {
    let factor = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
    let millis = 250u64.saturating_mul(factor).min(4000);
    tokio::time::sleep(Duration::from_millis(millis)).await;
}
impl Transport for ReqwestTransport {
fn send_envelope(&self, envelope: Envelope) {
if let Ok(guard) = self.sender.lock() {
if let Some(tx) = guard.as_ref() {
let _ = tx.try_send(Message::Send(envelope));
}
}
}
fn flush(&self, timeout: Duration) -> bool {
let tx = match self.sender.lock() {
Ok(g) => g.as_ref().cloned(),
Err(_) => None,
};
let Some(tx) = tx else {
return true;
};
let (ack_tx, ack_rx) = std::sync::mpsc::sync_channel::<()>(0);
if tx.send(Message::Flush(ack_tx)).is_err() {
return true; }
ack_rx.recv_timeout(timeout).is_ok()
}
fn shutdown(&self, timeout: Duration) -> bool {
let drained = self.flush(timeout);
if let Ok(mut guard) = self.sender.lock() {
guard.take();
}
if let Ok(mut h) = self.handle.lock() {
if let Some(handle) = h.take() {
let _ = handle.join();
}
}
drained
}
}
impl Drop for ReqwestTransport {
    /// Runs [`Transport::shutdown`] with a two-second timeout; the result is discarded.
    fn drop(&mut self) {
        let _drained = self.shutdown(Duration::from_secs(2));
    }
}
/// In-memory [`Transport`] that records envelopes instead of sending them.
///
/// Clones share one recorded list, so any clone observes what the others received.
#[derive(Clone, Default)]
pub struct StubTransport {
    /// Every envelope passed to `send_envelope`, in call order.
    sent: Arc<Mutex<Vec<Envelope>>>,
}
impl StubTransport {
pub fn new() -> Self {
StubTransport::default()
}
pub fn sent(&self) -> Vec<Envelope> {
self.sent.lock().map(|v| v.clone()).unwrap_or_default()
}
pub fn sent_for(&self, category: DataCategory) -> Vec<Envelope> {
self.sent()
.into_iter()
.filter(|e| e.category == category)
.collect()
}
}
impl Transport for StubTransport {
    /// Records `envelope`; it is discarded if the lock is poisoned.
    fn send_envelope(&self, envelope: Envelope) {
        if let Ok(mut recorded) = self.sent.lock() {
            recorded.push(envelope);
        }
    }

    /// Nothing is ever pending, so this always succeeds immediately.
    fn flush(&self, _: Duration) -> bool {
        true
    }

    /// Nothing to tear down; always reports success.
    fn shutdown(&self, _: Duration) -> bool {
        true
    }
}
/// [`TransportFactory`] that always hands out clones of one [`StubTransport`].
pub struct StubTransportFactory {
    /// Template stub; every created transport shares its recorded list.
    transport: StubTransport,
}
impl StubTransportFactory {
pub fn new(transport: StubTransport) -> Self {
StubTransportFactory { transport }
}
}
impl TransportFactory for StubTransportFactory {
fn create_transport(&self, _options: &ClientOptions) -> Arc<dyn Transport> {
Arc::new(self.transport.clone())
}
}
/// Map from data category to an instant.
///
/// NOTE(review): presumably the instant until which that category is
/// rate-limited, as tracked by the `ratelimit` module. Confirm there.
pub(crate) type LimitMap = HashMap<DataCategory, Instant>;