use crate::page::Page;
use base64::{engine::general_purpose, Engine as _};
use chromiumoxide_cdp::cdp::browser_protocol::{
fetch::TakeResponseBodyAsStreamParams,
io::{CloseParams, ReadParams, StreamHandle},
};
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::Semaphore;
/// Chunk size requested from the browser on every `ReadParams` call (64 KiB).
const DEFAULT_IO_READ_CHUNK: i64 = 65_536;
/// Upper bound on the decoded body size (50 MiB); the chunk readers return
/// `StreamError::BodyTooLarge` once it is exceeded.
const MAX_BODY_BYTES: usize = 50 * 1024 * 1024;
/// Timeout applied to each individual CDP call made while streaming
/// (opening the stream and every chunk read).
const STREAM_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
/// Value returned by `streaming_threshold_bytes` when no pages are active (256 KiB).
const BASE_STREAMING_THRESHOLD: usize = 262_144;
/// Value returned by `streaming_threshold_bytes` once `HIGH_PRESSURE_PAGES` is reached (32 KiB).
const MIN_STREAMING_THRESHOLD: usize = 32_768;
/// Active-page count at or above which the threshold is pinned to `MIN_STREAMING_THRESHOLD`.
const HIGH_PRESSURE_PAGES: usize = 64;
/// Process-wide semaphore with 12 permits; `read_response_body_as_stream`
/// holds one permit for the duration of each call.
static GLOBAL_STREAM_PERMITS: Semaphore = Semaphore::const_new(12);
/// Number of `read_response_body_as_stream` calls that currently hold a permit;
/// incremented after acquisition and decremented by `DecrementOnDrop`.
static INFLIGHT_STREAMS: AtomicUsize = AtomicUsize::new(0);
#[inline]
pub fn inflight_stream_count() -> usize {
INFLIGHT_STREAMS.load(Ordering::Relaxed)
}
/// Computes the streaming threshold, in bytes, for the current page load.
///
/// The result starts at `BASE_STREAMING_THRESHOLD` with zero active pages and
/// shrinks linearly (integer arithmetic) as the active-page count grows. At
/// `HIGH_PRESSURE_PAGES` or more it is pinned to `MIN_STREAMING_THRESHOLD`.
#[inline]
pub fn streaming_threshold_bytes() -> usize {
    let active_pages = crate::handler::page::active_page_count();

    if active_pages < HIGH_PRESSURE_PAGES {
        // Linear interpolation between the base and the minimum threshold.
        let span = BASE_STREAMING_THRESHOLD - MIN_STREAMING_THRESHOLD;
        BASE_STREAMING_THRESHOLD - span * active_pages / HIGH_PRESSURE_PAGES
    } else {
        MIN_STREAMING_THRESHOLD
    }
}
/// Holds a CDP stream handle and closes it when dropped, unless the handle
/// has been taken out of the guard first.
struct StreamGuard<'a> {
/// Page used to send the `CloseParams` command.
page: &'a Page,
/// The handle still owned by the guard; `None` once it has been taken.
handle: Option<StreamHandle>,
}
impl<'a> StreamGuard<'a> {
fn new(page: &'a Page, handle: StreamHandle) -> Self {
Self {
page,
handle: Some(handle),
}
}
fn take(&mut self) -> Option<StreamHandle> {
self.handle.take()
}
}
impl Drop for StreamGuard<'_> {
    /// Closes the stream in a detached task if the handle is still held.
    ///
    /// The close result is ignored: this is best-effort cleanup for paths
    /// that did not close the stream explicitly (early error returns or a
    /// cancelled future).
    fn drop(&mut self) {
        if let Some(handle) = self.handle.take() {
            // `tokio::spawn` panics when no runtime context is active, and a
            // panic inside `drop` while already unwinding aborts the process.
            // Without a runtime the close command cannot be sent, so skip it.
            if let Ok(runtime) = tokio::runtime::Handle::try_current() {
                let page = self.page.clone();
                runtime.spawn(async move {
                    let _ = page.execute(CloseParams { handle }).await;
                });
            }
        }
    }
}
/// Removes the file at `path` when dropped, unless the guard was disarmed
/// with `into_path`.
#[cfg(feature = "_cache_stream_disk")]
struct TempFileGuard {
/// Path to delete on drop; an empty path means there is nothing to delete.
path: std::path::PathBuf,
}
#[cfg(feature = "_cache_stream_disk")]
impl TempFileGuard {
    /// Creates a guard that deletes `path` when dropped.
    fn new(path: std::path::PathBuf) -> Self {
        TempFileGuard { path }
    }

    /// Disarms the guard and hands the path back to the caller.
    ///
    /// The guard's `Drop` never runs, so the file is left in place and the
    /// caller becomes responsible for it.
    fn into_path(self) -> std::path::PathBuf {
        // `ManuallyDrop` suppresses the destructor; the path left behind is
        // an empty `PathBuf`.
        let mut disarmed = std::mem::ManuallyDrop::new(self);
        std::mem::take(&mut disarmed.path)
    }
}
#[cfg(feature = "_cache_stream_disk")]
impl Drop for TempFileGuard {
    /// Deletes the guarded file on a best-effort basis; errors are ignored.
    ///
    /// Inside a Tokio runtime the removal runs in a detached task so `drop`
    /// does not block. Outside a runtime it is done synchronously, because
    /// `tokio::spawn` would panic there and the file would be left behind.
    fn drop(&mut self) {
        // An empty path means there is nothing to remove.
        if self.path.as_os_str().is_empty() {
            return;
        }
        match tokio::runtime::Handle::try_current() {
            Ok(runtime) => {
                let path = self.path.clone();
                runtime.spawn(async move {
                    let _ = tokio::fs::remove_file(&path).await;
                });
            }
            Err(_) => {
                // No executor is running on this thread, so blocking is fine.
                let _ = std::fs::remove_file(&self.path);
            }
        }
    }
}
/// Drains a CDP stream into a temporary file, then returns the file's
/// contents as a single buffer.
///
/// Chunks are requested `DEFAULT_IO_READ_CHUNK` bytes at a time, and each
/// request is bounded by `STREAM_TIMEOUT`. The temp file is covered by a
/// `TempFileGuard` until it has been read back and removed, so every early
/// return (and cancellation of this future) cleans it up.
///
/// # Errors
///
/// * `StreamError::Io` — creating, writing, flushing, or reading the temp file failed.
/// * `StreamError::Build` / `StreamError::Cdp` — a read command could not be built or executed.
/// * `StreamError::Timeout` — a read did not finish within `STREAM_TIMEOUT`.
/// * `StreamError::Base64` — a chunk flagged as base64 could not be decoded.
/// * `StreamError::BodyTooLarge` — more than `MAX_BODY_BYTES` were received; the
///   stream is closed through `guard` before returning.
#[cfg(feature = "_cache_stream_disk")]
async fn read_chunks_to_disk(
    page: &Page,
    stream_handle: &StreamHandle,
    guard: &mut StreamGuard<'_>,
) -> Result<Vec<u8>, StreamError> {
    use tokio::io::AsyncWriteExt;

    // File name mixes the process id with (in-flight count XOR current
    // nanoseconds) to keep concurrent streams on distinct paths.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos() as u64;
    let file_name = format!(
        "chromey_stream_{:x}_{:x}.tmp",
        std::process::id(),
        INFLIGHT_STREAMS.load(Ordering::Relaxed) as u64 ^ nanos,
    );
    let tmp_path = std::env::temp_dir().join(file_name);

    // Declared before `file` so that on early return the file handle is
    // dropped first and the guard's removal runs afterwards.
    let file_guard = TempFileGuard::new(tmp_path.clone());
    let mut file = tokio::fs::File::create(&tmp_path)
        .await
        .map_err(StreamError::Io)?;

    let mut total_bytes: usize = 0;
    loop {
        let read_result = tokio::time::timeout(STREAM_TIMEOUT, async {
            page.execute(
                ReadParams::builder()
                    .handle(stream_handle.clone())
                    .size(DEFAULT_IO_READ_CHUNK)
                    .build()
                    .map_err(StreamError::Build)?,
            )
            .await
            .map_err(StreamError::Cdp)
        })
        .await
        .map_err(|_| StreamError::Timeout)??;

        let chunk = &read_result.result;
        // Borrow plain-text chunks instead of copying them; only base64
        // chunks need a freshly allocated buffer.
        let decoded: std::borrow::Cow<'_, [u8]> = if chunk.base64_encoded.unwrap_or(false) {
            std::borrow::Cow::Owned(
                general_purpose::STANDARD
                    .decode(&chunk.data)
                    .map_err(StreamError::Base64)?,
            )
        } else {
            std::borrow::Cow::Borrowed(chunk.data.as_bytes())
        };

        total_bytes += decoded.len();
        if total_bytes > MAX_BODY_BYTES {
            // Close the stream inline so the guard's Drop has nothing left to do.
            if let Some(h) = guard.take() {
                let _ = page.execute(CloseParams { handle: h }).await;
            }
            return Err(StreamError::BodyTooLarge(total_bytes));
        }

        file.write_all(&decoded).await.map_err(StreamError::Io)?;
        if chunk.eof {
            break;
        }
    }

    file.flush().await.map_err(StreamError::Io)?;
    drop(file);

    // Keep the guard armed while reading back: if the read fails, the early
    // return drops the guard and the temp file is still removed.
    let body = tokio::fs::read(&tmp_path)
        .await
        .map_err(StreamError::Io)?;
    let _ = tokio::fs::remove_file(&tmp_path).await;
    // Removal has been attempted inline; disarm the guard so it is not repeated.
    let _ = file_guard.into_path();
    Ok(body)
}
/// Drains a CDP stream into an in-memory buffer.
///
/// Chunks are requested `DEFAULT_IO_READ_CHUNK` bytes at a time, and each
/// request is bounded by `STREAM_TIMEOUT`. `content_length_hint`, capped at
/// `MAX_BODY_BYTES`, is used only to pre-size the buffer.
///
/// # Errors
///
/// * `StreamError::Build` / `StreamError::Cdp` — a read command could not be built or executed.
/// * `StreamError::Timeout` — a read did not finish within `STREAM_TIMEOUT`.
/// * `StreamError::Base64` — a chunk flagged as base64 could not be decoded.
/// * `StreamError::BodyTooLarge` — appending a chunk would exceed `MAX_BODY_BYTES`;
///   the stream is closed through `guard` before returning.
#[cfg(not(feature = "_cache_stream_disk"))]
async fn read_chunks_to_memory(
    page: &Page,
    stream_handle: &StreamHandle,
    guard: &mut StreamGuard<'_>,
    content_length_hint: Option<usize>,
) -> Result<Vec<u8>, StreamError> {
    let capacity = content_length_hint.map_or(0, |hint| hint.min(MAX_BODY_BYTES));
    let mut collected: Vec<u8> = Vec::with_capacity(capacity);
    let mut reached_eof = false;

    while !reached_eof {
        let params = ReadParams::builder()
            .handle(stream_handle.clone())
            .size(DEFAULT_IO_READ_CHUNK)
            .build()
            .map_err(StreamError::Build)?;

        let response = match tokio::time::timeout(STREAM_TIMEOUT, page.execute(params)).await {
            Ok(reply) => reply.map_err(StreamError::Cdp)?,
            Err(_elapsed) => return Err(StreamError::Timeout),
        };
        let chunk = &response.result;

        // A missing flag is treated the same as "not base64".
        let bytes = match chunk.base64_encoded {
            Some(true) => general_purpose::STANDARD
                .decode(&chunk.data)
                .map_err(StreamError::Base64)?,
            _ => chunk.data.as_bytes().to_vec(),
        };

        let projected = collected.len() + bytes.len();
        if projected > MAX_BODY_BYTES {
            // Close the stream inline so the guard's Drop has nothing left to do.
            if let Some(open) = guard.take() {
                let _ = page.execute(CloseParams { handle: open }).await;
            }
            return Err(StreamError::BodyTooLarge(projected));
        }

        collected.extend_from_slice(&bytes);
        reached_eof = chunk.eof;
    }

    Ok(collected)
}
/// Reads the full response body of an intercepted request via a CDP stream.
///
/// The call waits for one of the global stream permits and counts itself in
/// `INFLIGHT_STREAMS` while it runs. It then asks the browser to expose the
/// body as a stream, drains it (to disk or memory, depending on the
/// `_cache_stream_disk` feature), and closes the stream.
///
/// `content_length_hint` is only used by the in-memory reader to pre-size
/// its buffer.
///
/// # Errors
///
/// * `StreamError::SemaphoreClosed` — the global semaphore was closed.
/// * `StreamError::Timeout` — opening the stream took longer than `STREAM_TIMEOUT`.
/// * `StreamError::Cdp` — the browser rejected the take-stream command.
/// * Any error produced by the chunk reader.
pub async fn read_response_body_as_stream(
    page: &Page,
    request_id: impl Into<chromiumoxide_cdp::cdp::browser_protocol::fetch::RequestId>,
    content_length_hint: Option<usize>,
) -> Result<Vec<u8>, StreamError> {
    let _permit = match GLOBAL_STREAM_PERMITS.acquire().await {
        Ok(permit) => permit,
        Err(_) => return Err(StreamError::SemaphoreClosed),
    };

    // Count this read as in-flight until the function returns or is cancelled.
    INFLIGHT_STREAMS.fetch_add(1, Ordering::Relaxed);
    let _inflight = DecrementOnDrop(&INFLIGHT_STREAMS);

    let take_stream = page.execute(TakeResponseBodyAsStreamParams::new(request_id));
    let response = match tokio::time::timeout(STREAM_TIMEOUT, take_stream).await {
        Ok(reply) => reply.map_err(StreamError::Cdp)?,
        Err(_elapsed) => return Err(StreamError::Timeout),
    };

    let stream_handle = response.result.stream;
    // From here on the guard closes the stream on any early exit.
    let mut guard = StreamGuard::new(page, stream_handle.clone());

    #[cfg(feature = "_cache_stream_disk")]
    let body = {
        // The disk-backed reader has no use for the size hint.
        let _ = content_length_hint;
        read_chunks_to_disk(page, &stream_handle, &mut guard).await?
    };
    #[cfg(not(feature = "_cache_stream_disk"))]
    let body = read_chunks_to_memory(page, &stream_handle, &mut guard, content_length_hint).await?;

    // Normal completion: close inline instead of leaving it to the guard's Drop.
    if let Some(open) = guard.take() {
        let _ = page.execute(CloseParams { handle: open }).await;
    }

    Ok(body)
}
/// Errors produced while reading a response body through a CDP stream.
#[derive(Debug)]
pub enum StreamError {
/// The global stream semaphore was closed, so no permit could be acquired.
SemaphoreClosed,
/// A CDP command issued while streaming failed.
Cdp(crate::error::CdpError),
/// A CDP call did not complete within `STREAM_TIMEOUT`.
Timeout,
/// The body exceeded `MAX_BODY_BYTES`; carries the byte count reached when the limit was hit.
BodyTooLarge(usize),
/// A chunk flagged as base64-encoded could not be decoded.
Base64(base64::DecodeError),
/// Building the CDP read parameters failed; carries the builder's message.
Build(String),
/// A temp-file operation failed (in this file, only the disk-backed reader produces it).
Io(std::io::Error),
}
impl std::fmt::Display for StreamError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::SemaphoreClosed => write!(f, "stream semaphore closed"),
Self::Cdp(e) => write!(f, "CDP error during stream read: {e}"),
Self::Timeout => write!(f, "stream read timed out"),
Self::BodyTooLarge(n) => write!(f, "response body too large: {n} bytes"),
Self::Base64(e) => write!(f, "base64 decode error: {e}"),
Self::Build(e) => write!(f, "CDP param build error: {e}"),
Self::Io(e) => write!(f, "stream I/O error: {e}"),
}
}
}
// Uses the trait's default methods only: `source()` is not overridden, and
// wrapped errors are reported through the `Display` message instead.
impl std::error::Error for StreamError {}
/// Extracts the `Content-Length` value from a header map.
///
/// The header name is matched ASCII case-insensitively. Whitespace around
/// the value is ignored, so `" 1024 "` parses as `1024`.
///
/// Returns `None` when no `Content-Length` entry holds a valid `usize`. If
/// the map has several entries whose names differ only by case, the first
/// parseable one is used, so a malformed duplicate cannot mask a valid one.
pub fn content_length_from_headers(
    headers: &std::collections::HashMap<String, String>,
) -> Option<usize> {
    headers
        .iter()
        .filter(|(name, _)| name.eq_ignore_ascii_case("content-length"))
        .find_map(|(_, value)| value.trim().parse().ok())
}
/// Guard that subtracts one from the referenced counter when it goes out of
/// scope, pairing with an increment done by whoever created it.
struct DecrementOnDrop<'a>(&'a AtomicUsize);

impl Drop for DecrementOnDrop<'_> {
    /// Decrements the counter by one using relaxed ordering.
    fn drop(&mut self) {
        let DecrementOnDrop(counter) = self;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}