sark-client 0.2.1

Simple Asynchronous Rust webKit - Client
use std::io::Read;
use std::ptr::NonNull;
use std::time::{Duration, Instant};

use cartel_core::{FatalSlot, Slab};
use dope::WakerSet;
use dope::manifold::connector;
use dope::runtime::token::Token;
use o3::buffer::Owned;
use sark_core::http::Response;
use sark_core::http::codec::{DecodeMode, HeaderLookup, Parse};

use crate::connector::codec::{self, Head};
use crate::connector::error::Error;
use crate::connector::retry::RetryPolicy;

/// Outcome stored in a connection's slab for each response: the parsed
/// response on success, or the error that ended the exchange.
pub(super) type Outcome = Result<Response, Error>;

/// Default value of `Shared::idle_timeout`. When used as the limit passed to
/// `Shared::acquire`, a connection whose server advertised no `Keep-Alive`
/// timeout is treated as stale after this long without activity.
pub(super) const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(30);

/// Default value of `Shared::request_timeout` (presumably enforced by the
/// request path outside this file — confirm).
pub(super) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(15);

/// Margin subtracted (saturating) from a server-advertised `Keep-Alive`
/// timeout in `Shared::acquire`, so a connection is retired slightly before
/// the server's stated deadline.
const KEEPALIVE_MARGIN: Duration = Duration::from_secs(1);

/// How gzip-encoded response bodies are treated when decoding them fails.
///
/// A body that decompresses past the configured size cap is rejected under
/// either policy; the policy only governs malformed gzip data.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum DecompressionPolicy {
    /// Surface a gzip decoding failure as an error (the default).
    #[default]
    Strict,
    /// Ignore a gzip decoding failure and keep the still-encoded body as is.
    Lenient,
}

/// Per-connection state owned by the connector and handed to the session
/// callbacks through `ctx.state`.
#[derive(Default)]
pub struct ConnState {
    // Set by `Session::response` when the connection must not be reused
    // (server did not allow keep-alive, or the codec reported a framing
    // error); reported as `Close::Reconnect` by `wants_close` and cleared in
    // `Session::disconnect`.
    pending_close: bool,
}

impl connector::Lifecycle for ConnState {
    /// Returns `Close::Reconnect` once a close has been scheduled through
    /// `pending_close`, and `Close::Keep` otherwise.
    fn wants_close(&self) -> connector::Close {
        match self.pending_close {
            true => connector::Close::Reconnect,
            false => connector::Close::Keep,
        }
    }

    /// A scheduled close is never postponed.
    fn defer_close(&self) -> bool {
        false
    }

    /// This state holds nothing that needs draining, so it always reports drained.
    fn is_drained(&self) -> bool {
        true
    }
}

/// Pool bookkeeping for one live connection.
struct ConnEntry {
    // Connector token identifying the connection.
    conn_id: Token,
    // Index into `Shared::slabs` of the slab receiving this connection's outcomes.
    slab_ix: usize,
    // Last time a response arrived or `Shared::touch` was called; compared
    // against the staleness limit in `Shared::acquire`.
    last_activity: Instant,
    // Timeout from the server's `Keep-Alive: timeout=N` header, if one was seen;
    // overrides the idle timeout in `Shared::acquire`.
    keepalive: Option<Duration>,
}

/// Connection pool state and client configuration shared by the session.
pub struct Shared {
    // One entry per connected connection, added in `note_connect`.
    conns: Vec<ConnEntry>,
    // Boxed on purpose: `slab_ptr_for` hands out raw `NonNull` pointers to
    // slabs, and boxing keeps each slab at a stable address when `alloc_slab`
    // pushes onto (and possibly reallocates) this Vec.
    #[allow(clippy::vec_box)]
    slabs: Vec<Box<Slab<Outcome>>>,
    // Woken when a connection is registered, a response is delivered, or a
    // connection fails.
    pub active_wakers: WakerSet,
    // Fatal error for the current connection attempt; cleared on connect and
    // used on disconnect to fail pending outcomes with `Error::Http`.
    pub fatal: FatalSlot<Error>,
    pub host: String,
    // How gzip decoding failures are handled (defaults to `Strict`).
    pub decompression: DecompressionPolicy,
    // Defaults to 10; presumably read by redirect handling elsewhere — confirm.
    pub max_redirects: u32,
    pub retry: RetryPolicy,
    // Defaults to `DEFAULT_IDLE_TIMEOUT`.
    pub idle_timeout: Duration,
    // Defaults to `DEFAULT_REQUEST_TIMEOUT`.
    pub request_timeout: Duration,
}

impl Shared {
    fn new(host: String) -> Self {
        Self {
            conns: Vec::new(),
            slabs: Vec::new(),
            active_wakers: WakerSet::new(),
            fatal: FatalSlot::default(),
            host,
            decompression: DecompressionPolicy::Strict,
            max_redirects: 10,
            retry: RetryPolicy::default(),
            idle_timeout: DEFAULT_IDLE_TIMEOUT,
            request_timeout: DEFAULT_REQUEST_TIMEOUT,
        }
    }

    pub fn any_ready(&self) -> bool {
        !self.conns.is_empty()
    }

    pub fn live_conns(&self) -> usize {
        self.conns.len()
    }

    fn alloc_slab(&mut self) -> usize {
        for ix in 0..self.slabs.len() {
            let bound = self.conns.iter().any(|c| c.slab_ix == ix);
            if !bound && self.slabs[ix].is_drained() {
                return ix;
            }
        }
        self.slabs.push(Box::new(Slab::new()));
        self.slabs.len() - 1
    }

    fn note_connect(&mut self, conn_id: Token, now: Instant) {
        if self.conns.iter().any(|c| c.conn_id == conn_id) {
            return;
        }
        let slab_ix = self.alloc_slab();
        self.conns.push(ConnEntry {
            conn_id,
            slab_ix,
            last_activity: now,
            keepalive: None,
        });
        self.active_wakers.drain_wake();
    }

    fn push_response(
        &mut self,
        conn_id: Token,
        outcome: Outcome,
        keepalive: Option<Duration>,
        now: Instant,
    ) {
        let Some(pos) = self.conns.iter().position(|c| c.conn_id == conn_id) else {
            return;
        };
        let slab_ix = {
            let c = &mut self.conns[pos];
            c.last_activity = now;
            if keepalive.is_some() {
                c.keepalive = keepalive;
            }
            c.slab_ix
        };
        self.slabs[slab_ix].push(outcome);
        self.slabs[slab_ix].complete();
        self.active_wakers.drain_wake();
    }

    fn fail_connection(&mut self, conn_id: Token, fatal: Option<String>) {
        if let Some(pos) = self.conns.iter().position(|c| c.conn_id == conn_id) {
            let slab_ix = self.conns[pos].slab_ix;
            self.slabs[slab_ix].fail_all(|| match &fatal {
                Some(m) => Err(Error::Http(m.clone())),
                None => Err(Error::Closed),
            });
            self.conns.remove(pos);
        }
        self.active_wakers.drain_wake();
    }

    pub fn drop_conn(&mut self, conn_id: Token) {
        if let Some(pos) = self.conns.iter().position(|c| c.conn_id == conn_id) {
            let slab_ix = self.conns[pos].slab_ix;
            self.slabs[slab_ix].fail_all(|| Err(Error::Closed));
            self.conns.remove(pos);
        }
    }

    pub fn acquire(&mut self, now: Instant, idle_timeout: Duration) -> (Option<Token>, Vec<Token>) {
        let mut chosen: Option<usize> = None;
        let mut best_depth = usize::MAX;
        let mut recycle_idx: Vec<usize> = Vec::new();
        for (i, c) in self.conns.iter().enumerate() {
            let limit = c
                .keepalive
                .map(|k| k.saturating_sub(KEEPALIVE_MARGIN))
                .unwrap_or(idle_timeout);
            let stale = now.saturating_duration_since(c.last_activity) >= limit;
            let depth = self.slabs[c.slab_ix].depth();
            if stale {
                recycle_idx.push(i);
                continue;
            }
            if depth < best_depth {
                best_depth = depth;
                chosen = Some(i);
            }
        }
        let chosen_tok = chosen.map(|i| self.conns[i].conn_id);
        let mut recycle = Vec::with_capacity(recycle_idx.len());
        recycle_idx.sort_unstable();
        for &i in recycle_idx.iter().rev() {
            recycle.push(self.conns[i].conn_id);
            self.conns.remove(i);
        }
        (chosen_tok, recycle)
    }

    pub fn slab_ptr_for(&mut self, conn_id: Token) -> Option<NonNull<Slab<Outcome>>> {
        let pos = self.conns.iter().position(|c| c.conn_id == conn_id)?;
        let slab_ix = self.conns[pos].slab_ix;
        Some(NonNull::from(&mut *self.slabs[slab_ix]))
    }

    pub fn touch(&mut self, conn_id: Token, now: Instant) {
        if let Some(c) = self.conns.iter_mut().find(|c| c.conn_id == conn_id) {
            c.last_activity = now;
        }
    }
}

/// HTTP client session: the response codec together with the pool
/// bookkeeping and client settings held in [`Shared`].
pub struct Session {
    codec: codec::Codec,
    pub shared: Shared,
}

impl Session {
    pub fn new(host: impl Into<String>) -> Self {
        Self {
            codec: codec::Codec::default(),
            shared: Shared::new(host.into()),
        }
    }

    pub fn with_decompression(host: impl Into<String>, policy: DecompressionPolicy) -> Self {
        let mut session = Self::new(host);
        session.shared.decompression = policy;
        session
    }

    pub fn max_response_body(&mut self, cap: usize) -> &mut Self {
        self.codec.max_response_body = cap;
        self
    }

    pub fn with_max_redirects(mut self, max: u32) -> Self {
        self.shared.max_redirects = max;
        self
    }

    pub fn with_retry(mut self, retry: RetryPolicy) -> Self {
        self.shared.retry = retry;
        self
    }

    pub fn with_idle_timeout(mut self, idle: Duration) -> Self {
        self.shared.idle_timeout = idle;
        self
    }

    pub fn with_request_timeout(mut self, dur: Duration) -> Self {
        self.shared.request_timeout = dur;
        self
    }
}

impl connector::Session for Session {
    type Codec = codec::Codec;
    type ConnState = ConnState;

    fn codec(&self) -> &codec::Codec {
        &self.codec
    }

    fn connect(&mut self, ctx: &mut connector::Ctx<'_, Self>) {
        self.shared.fatal.clear();
        self.shared.note_connect(ctx.conn_id, Instant::now());
    }

    fn response(&mut self, head: Head, ctx: &mut connector::Ctx<'_, Self>) {
        if let Some(reason) = head.error {
            self.shared.push_response(
                ctx.conn_id,
                Err(Error::Parse(reason.into())),
                None,
                Instant::now(),
            );
            ctx.state.pending_close = true;
            return;
        }
        let bytes = head.full.as_ref();
        let (outcome, keep_alive, keepalive_timeout) =
            match Parse::response(bytes, DecodeMode::Response) {
                Ok(Some(mut resp)) => {
                    let keep = Self::should_keep_alive(&resp);
                    let timeout = Self::keepalive_timeout(&resp);
                    let outcome = match Self::decompress(
                        &mut resp,
                        self.shared.decompression,
                        self.codec.max_response_body,
                    ) {
                        Ok(()) => Ok(resp),
                        Err(e) => Err(e),
                    };
                    (outcome, keep, timeout)
                }
                Ok(None) => (
                    Err(Error::Parse("incomplete response frame".into())),
                    true,
                    None,
                ),
                Err(e) => (Err(Error::Parse(e.to_string())), true, None),
            };
        if !keep_alive {
            ctx.state.pending_close = true;
        }
        self.shared
            .push_response(ctx.conn_id, outcome, keepalive_timeout, Instant::now());
    }

    fn disconnect(&mut self, ctx: &mut connector::Ctx<'_, Self>) {
        let fatal_msg = self.shared.fatal.as_ref().map(|e| e.to_string());
        self.shared.fail_connection(ctx.conn_id, fatal_msg);
        ctx.state.pending_close = false;
    }
}

impl Session {
    /// Decides whether the connection may be reused after `resp`.
    ///
    /// The rules are checked in order:
    /// 1. `Connection: close` or `upgrade` forbids reuse.
    /// 2. `Connection: keep-alive` allows it.
    /// 3. Responses without a body (1xx, 204, 304) allow it.
    /// 4. Otherwise reuse is allowed only when the body length is delimited,
    ///    by `Content-Length` or by `Transfer-Encoding: chunked`.
    fn should_keep_alive(resp: &Response) -> bool {
        let headers = resp.headers();
        if headers.has_token(http::header::CONNECTION, "close")
            || headers.has_token(http::header::CONNECTION, "upgrade")
        {
            return false;
        }
        if headers.has_token(http::header::CONNECTION, "keep-alive") {
            return true;
        }

        let status = resp.status().as_u16();
        if status < 200 || status == 204 || status == 304 {
            return true;
        }

        headers.contains_key(http::header::CONTENT_LENGTH)
            || headers.value_eq_ascii_case(http::header::TRANSFER_ENCODING, "chunked")
    }

    /// Extracts the `timeout` parameter, in seconds, from a `Keep-Alive` header.
    ///
    /// The header is a comma-separated parameter list such as
    /// `timeout=5, max=100`. The parameter name is matched case-insensitively
    /// and whitespace around `=` is ignored. The first `timeout` whose value
    /// parses as an integer wins.
    ///
    /// Returns `None` when the header is absent, cannot be read as a string,
    /// or carries no valid `timeout`.
    fn keepalive_timeout(resp: &Response) -> Option<Duration> {
        let name = http::header::HeaderName::from_static("keep-alive");
        let raw = resp.headers().get(name)?;
        let value = raw.to_str().ok()?;
        value.split(',').find_map(|part| {
            let (key, secs) = part.split_once('=')?;
            if !key.trim().eq_ignore_ascii_case("timeout") {
                return None;
            }
            secs.trim().parse::<u64>().ok().map(Duration::from_secs)
        })
    }

    /// Decompresses a gzip-encoded body in place, bounded by `max_body` bytes.
    ///
    /// Non-gzip and empty bodies are left untouched. On success the body is
    /// replaced and the `content-encoding` and `content-length` headers are
    /// removed.
    ///
    /// # Errors
    /// Returns `Error::Parse` when the decompressed body exceeds `max_body`
    /// (under either policy), or when gzip decoding fails under
    /// [`DecompressionPolicy::Strict`]. Under `Lenient`, a decoding failure
    /// leaves the response unchanged.
    fn decompress(
        resp: &mut Response,
        policy: DecompressionPolicy,
        max_body: usize,
    ) -> Result<(), Error> {
        let is_gzip = resp
            .headers()
            .value_eq_ascii_case(http::header::CONTENT_ENCODING, "gzip");
        if !is_gzip || resp.body().is_empty() {
            return Ok(());
        }

        // Lossless on all current targets; saturate rather than truncate otherwise.
        let limit = u64::try_from(max_body).unwrap_or(u64::MAX);
        // Read one byte past the limit so an oversized body is detectable.
        // Saturating keeps an "unlimited" cap of `usize::MAX` from overflowing:
        // that would panic in debug builds, and in release it would wrap to
        // `take(0)` and silently replace the body with nothing.
        let mut decoder =
            flate2::read::GzDecoder::new(resp.body()).take(limit.saturating_add(1));
        let mut decompressed = Vec::new();
        match decoder.read_to_end(&mut decompressed) {
            Ok(_) if decompressed.len() as u64 > limit => Err(Error::Parse(
                "decompressed response body exceeds size limit".into(),
            )),
            Ok(_) => {
                resp.set_body(Owned::from(&decompressed[..]));
                resp.headers_mut().remove("content-encoding");
                resp.headers_mut().remove("content-length");
                Ok(())
            }
            Err(e) if policy == DecompressionPolicy::Strict => {
                Err(Error::Parse(format!("gzip decompression failed: {e}")))
            }
            Err(_) => Ok(()),
        }
    }
}