use crate::page::Page;
use base64::{engine::general_purpose, Engine as _};
use chromiumoxide_cdp::cdp::browser_protocol::{
fetch::TakeResponseBodyAsStreamParams,
io::{CloseParams, ReadParams, StreamHandle},
};
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::Semaphore;
/// Number of bytes requested per CDP `IO.read` call (64 KiB).
const DEFAULT_IO_READ_CHUNK: i64 = 65_536;
/// Optional hard cap on a streamed body, read from the
/// `CHROMEY_STREAM_MAX_BODY_BYTES` environment variable.
///
/// Returns `None` when the variable is unset, not valid UTF-8, or does not
/// parse as a `usize` — i.e. "no cap" on any misconfiguration.
fn max_body_bytes() -> Option<usize> {
    match std::env::var("CHROMEY_STREAM_MAX_BODY_BYTES") {
        Ok(raw) => raw.parse::<usize>().ok(),
        Err(_) => None,
    }
}
/// Streaming-threshold ceiling when few pages are active (256 KiB).
const BASE_STREAMING_THRESHOLD: usize = 262_144;
/// Streaming-threshold floor under high page pressure (32 KiB).
const MIN_STREAMING_THRESHOLD: usize = 32_768;
/// Active-page count at which the threshold bottoms out at the floor.
const HIGH_PRESSURE_PAGES: usize = 64;
/// Default cap on concurrent body streams; overridable via the
/// `CHROMEY_STREAM_CONCURRENCY` environment variable.
const DEFAULT_STREAM_CONCURRENCY: usize = 48;
/// How long a stream waits for a global permit before giving up.
const STREAM_PERMIT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);
lazy_static::lazy_static! {
    /// Process-wide semaphore bounding concurrent response-body streams.
    ///
    /// Sized from `CHROMEY_STREAM_CONCURRENCY` when it parses as a positive
    /// integer, otherwise `DEFAULT_STREAM_CONCURRENCY`.
    static ref GLOBAL_STREAM_PERMITS: Semaphore = {
        let n = std::env::var("CHROMEY_STREAM_CONCURRENCY")
            .ok()
            .and_then(|v| v.parse().ok())
            // A zero-permit semaphore would never grant a permit, so every
            // stream would hit STREAM_PERMIT_TIMEOUT; treat an explicit 0
            // the same as a parse failure and fall back to the default.
            .filter(|&n| n > 0)
            .unwrap_or(DEFAULT_STREAM_CONCURRENCY);
        Semaphore::new(n)
    };
}
/// Number of body streams currently being drained (diagnostics only; the
/// counter is maintained by `read_response_body_as_stream`).
static INFLIGHT_STREAMS: AtomicUsize = AtomicUsize::new(0);
/// Monotonic suffix making spill-file names unique within this process.
#[cfg(feature = "_cache_stream_disk")]
static STREAM_FILE_SEQ: AtomicUsize = AtomicUsize::new(0);

/// Returns how many response-body streams are in flight right now.
///
/// The value is a relaxed snapshot and may be stale by the time the caller
/// inspects it; it is intended for metrics/back-pressure heuristics only.
#[inline]
pub fn inflight_stream_count() -> usize {
    INFLIGHT_STREAMS.load(std::sync::atomic::Ordering::Relaxed)
}
/// Size (in bytes) above which a response body should be streamed rather
/// than fetched whole.
///
/// The threshold shrinks linearly from `BASE_STREAMING_THRESHOLD` down to
/// `MIN_STREAMING_THRESHOLD` as the active-page count climbs toward
/// `HIGH_PRESSURE_PAGES`, and is pinned at the floor beyond that point.
#[inline]
pub fn streaming_threshold_bytes() -> usize {
    let pages = crate::handler::page::active_page_count();
    if pages >= HIGH_PRESSURE_PAGES {
        MIN_STREAMING_THRESHOLD
    } else {
        // Linear interpolation between ceiling and floor.
        let span = BASE_STREAMING_THRESHOLD - MIN_STREAMING_THRESHOLD;
        BASE_STREAMING_THRESHOLD - span * pages / HIGH_PRESSURE_PAGES
    }
}
/// RAII guard that closes a CDP stream handle if the owner exits early.
struct StreamGuard<'a> {
    page: &'a Page,
    // `None` once the handle has been taken back (guard disarmed).
    handle: Option<StreamHandle>,
}
impl<'a> StreamGuard<'a> {
fn new(page: &'a Page, handle: StreamHandle) -> Self {
Self {
page,
handle: Some(handle),
}
}
fn take(&mut self) -> Option<StreamHandle> {
self.handle.take()
}
}
impl Drop for StreamGuard<'_> {
    /// Best-effort close of a still-armed stream handle.
    fn drop(&mut self) {
        if let Some(handle) = self.handle.take() {
            let page = self.page.clone();
            // `tokio::spawn` panics when no runtime is current; a panic
            // inside `drop` (possibly during unwinding) would abort the
            // process. Only spawn when a runtime is actually available —
            // otherwise the handle is leaked, which Chrome eventually
            // garbage-collects with the target.
            if let Ok(rt) = tokio::runtime::Handle::try_current() {
                rt.spawn(async move {
                    let _ = page.execute(CloseParams { handle }).await;
                });
            }
        }
    }
}
/// Destination for streamed body chunks: a temp file on disk when one can
/// be created, otherwise an in-memory buffer.
#[cfg(feature = "_cache_stream_disk")]
enum ChunkSink {
    /// Chunks are appended to a process-unique temp file.
    Disk {
        file: tokio::fs::File,
        path: std::path::PathBuf,
    },
    /// Chunks accumulate in memory (initial fallback or post-error recovery).
    Memory { buf: Vec<u8> },
}
#[cfg(feature = "_cache_stream_disk")]
impl ChunkSink {
    /// Opens a disk-backed sink, silently degrading to an in-memory sink
    /// when the temp file cannot be created (failure is only logged).
    async fn open_disk() -> Self {
        match Self::try_open_disk().await {
            Ok(sink) => sink,
            Err(err) => {
                tracing::debug!("stream disk init failed, using memory: {err}");
                ChunkSink::Memory { buf: Vec::new() }
            }
        }
    }
    /// Creates a temp file under the OS temp dir, named uniquely from the
    /// process id plus a monotonically increasing sequence number.
    async fn try_open_disk() -> Result<Self, std::io::Error> {
        let tmp_dir = std::env::temp_dir();
        tokio::fs::create_dir_all(&tmp_dir).await?;
        let seq = STREAM_FILE_SEQ.fetch_add(1, Ordering::Relaxed);
        let file_name = format!("chromey_stream_{}_{}.tmp", std::process::id(), seq);
        let path = tmp_dir.join(file_name);
        let file = tokio::fs::File::create(&path).await?;
        Ok(ChunkSink::Disk { file, path })
    }
    /// Appends one decoded chunk. On a disk write failure the sink recovers
    /// whatever already reached the file, deletes it, and permanently
    /// switches to the memory variant for the rest of the stream.
    async fn write_chunk(&mut self, decoded: &[u8]) {
        match self {
            ChunkSink::Disk { file, path } => {
                use tokio::io::AsyncWriteExt;
                if let Err(err) = file.write_all(decoded).await {
                    tracing::debug!(
                        "stream disk write failed, falling back to memory: {err}"
                    );
                    // Flush first so buffered bytes land on disk before the
                    // read-back; errors here are deliberately ignored.
                    let _ = file.flush().await;
                    // Best effort: if the read-back also fails, bytes that
                    // never reached disk are lost and we restart from empty.
                    let mut recovered = match tokio::fs::read(path.as_path()).await {
                        Ok(bytes) => bytes,
                        Err(_) => Vec::new(),
                    };
                    let _ = tokio::fs::remove_file(path.as_path()).await;
                    recovered.extend_from_slice(decoded);
                    // Switching variants also disarms the Drop-side file
                    // cleanup; the file was already removed above.
                    *self = ChunkSink::Memory { buf: recovered };
                }
            }
            ChunkSink::Memory { buf } => {
                buf.extend_from_slice(decoded);
            }
        }
    }
    /// Finalizes the sink and returns the complete body.
    ///
    /// The disk variant is flushed, read back, and deleted, then the sink is
    /// reset to an empty memory variant so Drop will not try to remove the
    /// (already deleted) file again. A failed read-back yields an empty body.
    async fn finish(&mut self) -> Vec<u8> {
        match self {
            ChunkSink::Disk { ref mut file, ref path } => {
                use tokio::io::AsyncWriteExt;
                if let Err(err) = file.flush().await {
                    tracing::debug!("stream disk flush failed on finish: {err}");
                }
                let path = path.clone();
                let body = match tokio::fs::read(&path).await {
                    Ok(bytes) => bytes,
                    Err(err) => {
                        tracing::debug!("stream disk read-back failed: {err}");
                        Vec::new()
                    }
                };
                let _ = tokio::fs::remove_file(&path).await;
                *self = ChunkSink::Memory { buf: Vec::new() };
                body
            }
            ChunkSink::Memory { ref mut buf } => std::mem::take(buf),
        }
    }
}
#[cfg(feature = "_cache_stream_disk")]
impl Drop for ChunkSink {
    /// Removes a leftover spill file when the sink is dropped without
    /// having been finished (e.g. an error aborted the stream).
    fn drop(&mut self) {
        if let ChunkSink::Disk { path, .. } = self {
            let path = path.clone();
            match tokio::runtime::Handle::try_current() {
                // Prefer non-blocking removal when a runtime is present.
                Ok(rt) => {
                    rt.spawn(async move {
                        let _ = tokio::fs::remove_file(&path).await;
                    });
                }
                // `tokio::spawn` panics outside a runtime, and panicking in
                // `drop` during unwinding aborts the process — fall back to
                // a blocking removal instead.
                Err(_) => {
                    let _ = std::fs::remove_file(&path);
                }
            }
        }
    }
}
/// Drains an open CDP IO stream into one contiguous body buffer.
///
/// Repeatedly issues `IO.read` requests of `DEFAULT_IO_READ_CHUNK` bytes
/// until EOF, decoding base64-encoded chunks, and enforces the optional
/// `CHROMEY_STREAM_MAX_BODY_BYTES` cap.
///
/// # Errors
/// * `StreamError::Build` / `StreamError::Cdp` — CDP command failures.
/// * `StreamError::Base64` — a chunk failed to decode.
/// * `StreamError::BodyTooLarge` — the cap was exceeded; the stream is left
///   open and the caller's guard is responsible for closing it.
async fn read_all_chunks(
    page: &Page,
    stream_handle: &StreamHandle,
    content_length_hint: Option<usize>,
) -> Result<Vec<u8>, StreamError> {
    #[cfg(feature = "_cache_stream_disk")]
    let mut sink = ChunkSink::open_disk().await;
    // Content-Length is remote-controlled input: clamp the speculative
    // preallocation so a bogus header cannot force a huge reservation.
    #[cfg(not(feature = "_cache_stream_disk"))]
    let mut body = {
        const MAX_PREALLOC_BYTES: usize = 1 << 20; // 1 MiB
        Vec::with_capacity(content_length_hint.unwrap_or(0).min(MAX_PREALLOC_BYTES))
    };
    let cap = max_body_bytes();
    let mut total_bytes: usize = 0;
    // Scratch buffer reused across iterations to avoid per-chunk allocation.
    let mut decode_buf: Vec<u8> = Vec::with_capacity(DEFAULT_IO_READ_CHUNK as usize);
    loop {
        let read_result = page
            .execute(
                ReadParams::builder()
                    .handle(stream_handle.clone())
                    .size(DEFAULT_IO_READ_CHUNK)
                    .build()
                    .map_err(StreamError::Build)?,
            )
            .await
            .map_err(StreamError::Cdp)?;
        let chunk = &read_result.result;
        // Chrome flags each chunk individually; only decode when asked.
        let data_bytes: &[u8] = if chunk.base64_encoded.unwrap_or(false) {
            decode_buf.clear();
            general_purpose::STANDARD
                .decode_vec(&chunk.data, &mut decode_buf)
                .map_err(StreamError::Base64)?;
            &decode_buf
        } else {
            chunk.data.as_bytes()
        };
        total_bytes += data_bytes.len();
        if let Some(max) = cap {
            if total_bytes > max {
                return Err(StreamError::BodyTooLarge(total_bytes));
            }
        }
        #[cfg(feature = "_cache_stream_disk")]
        sink.write_chunk(data_bytes).await;
        #[cfg(not(feature = "_cache_stream_disk"))]
        body.extend_from_slice(data_bytes);
        if chunk.eof {
            break;
        }
    }
    #[cfg(feature = "_cache_stream_disk")]
    {
        // The hint only drives the in-memory preallocation path.
        let _ = content_length_hint;
        Ok(sink.finish().await)
    }
    #[cfg(not(feature = "_cache_stream_disk"))]
    Ok(body)
}
/// Outcome of [`read_response_body_as_stream`].
#[derive(Debug)]
pub enum StreamResult {
    /// The complete body was read successfully.
    Ok(Vec<u8>),
    /// The stream could never be opened; no body bytes were consumed.
    NotStarted(StreamError),
    /// Reading failed after the stream was opened.
    PartialBody {
        // NOTE(review): in this file `body` is always constructed empty —
        // partially read bytes are not surfaced by `read_all_chunks`.
        // Confirm before relying on this field carrying data.
        body: Vec<u8>,
        error: StreamError,
    },
}
pub async fn read_response_body_as_stream(
page: &Page,
request_id: impl Into<chromiumoxide_cdp::cdp::browser_protocol::fetch::RequestId>,
content_length_hint: Option<usize>,
) -> StreamResult {
let _permit = match tokio::time::timeout(
STREAM_PERMIT_TIMEOUT,
GLOBAL_STREAM_PERMITS.acquire(),
)
.await
{
Ok(Ok(p)) => p,
Ok(Err(_)) => return StreamResult::NotStarted(StreamError::SemaphoreClosed),
Err(_) => return StreamResult::NotStarted(StreamError::Timeout),
};
INFLIGHT_STREAMS.fetch_add(1, Ordering::Relaxed);
let _dec = DecrementOnDrop(&INFLIGHT_STREAMS);
let returns = match page
.execute(TakeResponseBodyAsStreamParams::new(request_id))
.await
{
Ok(r) => r,
Err(e) => {
return StreamResult::NotStarted(StreamError::Cdp(e));
}
};
let stream_handle = returns.result.stream;
let mut guard = StreamGuard::new(page, stream_handle.clone());
match read_all_chunks(page, &stream_handle, content_length_hint).await {
Ok(body) => {
if let Some(h) = guard.take() {
let _ = page.execute(CloseParams { handle: h }).await;
}
StreamResult::Ok(body)
}
Err(err) => {
if let Some(h) = guard.take() {
let _ = page.execute(CloseParams { handle: h }).await;
}
#[cfg(feature = "_cache_stream_disk")]
let partial = Vec::new();
#[cfg(not(feature = "_cache_stream_disk"))]
let partial = Vec::new();
StreamResult::PartialBody {
body: partial,
error: err,
}
}
}
}
/// Errors produced while streaming a response body over CDP.
#[derive(Debug)]
pub enum StreamError {
    /// The global stream semaphore was closed (acquire failed).
    SemaphoreClosed,
    /// A CDP command failed.
    Cdp(crate::error::CdpError),
    /// Timed out waiting for a global stream permit.
    Timeout,
    /// Body exceeded `CHROMEY_STREAM_MAX_BODY_BYTES`; payload is the byte
    /// count observed when the cap was tripped.
    BodyTooLarge(usize),
    /// A base64-encoded chunk failed to decode.
    Base64(base64::DecodeError),
    /// The CDP parameter builder rejected its input.
    Build(String),
    /// Stream I/O failure. (Not constructed in this file — presumably used
    /// by other modules; verify before removing.)
    Io(std::io::Error),
}
impl std::fmt::Display for StreamError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::SemaphoreClosed => write!(f, "stream semaphore closed"),
Self::Cdp(e) => write!(f, "CDP error during stream read: {e}"),
Self::Timeout => write!(f, "stream read timed out"),
Self::BodyTooLarge(n) => write!(f, "response body too large: {n} bytes"),
Self::Base64(e) => write!(f, "base64 decode error: {e}"),
Self::Build(e) => write!(f, "CDP param build error: {e}"),
Self::Io(e) => write!(f, "stream I/O error: {e}"),
}
}
}
impl std::error::Error for StreamError {}
/// Extracts a `Content-Length` value from a header map.
///
/// The header-name comparison is ASCII case-insensitive. The value is
/// trimmed of surrounding whitespace before parsing, since HTTP permits
/// optional whitespace around field values (RFC 9110 §5.5).
///
/// Returns `None` when the header is absent or its value does not parse as
/// a `usize`.
pub fn content_length_from_headers(
    headers: &std::collections::HashMap<String, String>,
) -> Option<usize> {
    headers
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case("content-length"))
        .and_then(|(_, value)| value.trim().parse().ok())
}
/// RAII guard: decrements the wrapped counter exactly once when dropped.
struct DecrementOnDrop<'a>(&'a AtomicUsize);

impl<'a> Drop for DecrementOnDrop<'a> {
    fn drop(&mut self) {
        let DecrementOnDrop(counter) = self;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}