use crate::error::{CdpError, Result};
use crate::page::Page;
use chromiumoxide_cdp::cdp::js_protocol::runtime::{
CallArgument, CallFunctionOnParams, EvaluateParams, RemoteObjectId,
};
use futures_util::stream::{try_unfold, Stream};
/// Default number of UTF-16 code units fetched per remote slice call.
pub const DEFAULT_CHUNK_UNITS: u32 = 65_536;
/// Smallest chunk size a caller may request (see `clamp_chunk_units`).
pub const MIN_CHUNK_UNITS: u32 = 1024;
/// Largest chunk size a caller may request (see `clamp_chunk_units`).
pub const MAX_CHUNK_UNITS: u32 = 4_194_304;
/// Hard ceiling on document length in UTF-16 code units; longer documents
/// are rejected up front by `init_state`.
pub const MAX_DOCUMENT_UNITS: u32 = 256 * 1024 * 1024;
/// Default cap on accumulated content bytes in `read_chunks`; overridable
/// via the `CHROMEY_CONTENT_STREAM_MAX_BYTES` environment variable.
pub const DEFAULT_MAX_ACCUMULATED_BYTES: usize = 512 * 1024 * 1024;
/// Clamp a caller-supplied chunk size into the supported
/// `MIN_CHUNK_UNITS..=MAX_CHUNK_UNITS` range.
#[inline]
fn clamp_chunk_units(units: u32) -> u32 {
    if units < MIN_CHUNK_UNITS {
        MIN_CHUNK_UNITS
    } else if units > MAX_CHUNK_UNITS {
        MAX_CHUNK_UNITS
    } else {
        units
    }
}
/// Safety valve: absolute upper bound on slice round-trips per document.
const MAX_CHUNKS: usize = 262_144;
/// Evaluated once per document: serializes the doctype plus the root
/// element's outerHTML and returns it wrapped in an object (`{h: ...}`)
/// so later calls can address the string by remote object id.
const INIT_JS: &str = r###"(()=>{
let rv='';
if(document.doctype){rv+=new XMLSerializer().serializeToString(document.doctype);}
if(document.documentElement){rv+=document.documentElement.outerHTML;}
return {h:rv};
})()"###;
/// Returns the cached HTML's length in UTF-16 code units (JS string length).
const LEN_FN: &str = "function(){return this.h.length}";
/// Returns up to `size` UTF-16 units starting at `start`; the slice is
/// shortened by one unit when it would otherwise end on a high surrogate
/// (0xD800..=0xDBFF), so a surrogate pair is never split across chunks.
const SLICE_FN: &str = r###"function(start,size){
const s=this.h;
const L=s.length;
if(start>=L)return '';
let end=start+size;
if(end>L)end=L;
if(end<L){
const c=s.charCodeAt(end-1);
if(c>=0xD800&&c<=0xDBFF)end-=1;
}
return s.slice(start,end);
}"###;
/// RAII guard owning a CDP remote object id; the `Drop` impl releases the
/// remote handle so the page's JS heap can reclaim the cached HTML string.
struct RemoteRefGuard {
    page: Page,
    // Remote object id of the `{h: ...}` wrapper created by INIT_JS.
    object_id: RemoteObjectId,
}
impl RemoteRefGuard {
    #[inline]
    fn new(page: Page, object_id: RemoteObjectId) -> Self {
        Self { page, object_id }
    }
    /// Borrow the remote object id for length/slice calls.
    #[inline]
    fn id(&self) -> &RemoteObjectId {
        &self.object_id
    }
}
impl Drop for RemoteRefGuard {
    /// Best-effort release of the remote object; failures are ignored since
    /// releasing is only an optimization for the browser's memory.
    fn drop(&mut self) {
        // Move the id out (leaving the Default in place) so ownership can be
        // handed to the release worker.
        let id = std::mem::take(&mut self.object_id);
        // Skip empty ids — a defaulted id has nothing to release.
        if !id.0.is_empty() {
            crate::runtime_release::try_release(self.page.clone(), id);
        }
    }
}
#[cfg(feature = "_cache_stream_disk")]
use std::sync::atomic::{AtomicUsize, Ordering};
// Process-wide sequence number giving each content temp file a unique name.
#[cfg(feature = "_cache_stream_disk")]
static CONTENT_FILE_SEQ: AtomicUsize = AtomicUsize::new(0);
/// Destination for accumulated content when the `_cache_stream_disk`
/// feature is enabled: spill to a temp file when possible, otherwise
/// buffer in memory.
#[cfg(feature = "_cache_stream_disk")]
enum Sink {
    /// Spilling to a temp file under the OS temp directory.
    Disk {
        file: tokio::fs::File,
        path: std::path::PathBuf,
    },
    /// In-memory fallback (disk init or a disk write failed).
    Memory {
        buf: Vec<u8>,
    },
}
#[cfg(feature = "_cache_stream_disk")]
impl Sink {
    /// Open a disk-backed sink, falling back to a memory buffer (with
    /// `cap_hint` bytes preallocated) when temp-file creation fails.
    async fn open(cap_hint: usize) -> Self {
        match Self::try_open_disk().await {
            Ok(s) => s,
            Err(err) => {
                tracing::debug!("content stream disk init failed, using memory: {err}");
                Sink::Memory {
                    buf: Vec::with_capacity(cap_hint),
                }
            }
        }
    }
    /// Create a uniquely named temp file (pid + process-wide sequence
    /// number) under the OS temp directory.
    async fn try_open_disk() -> std::io::Result<Self> {
        let tmp_dir = std::env::temp_dir();
        tokio::fs::create_dir_all(&tmp_dir).await?;
        let seq = CONTENT_FILE_SEQ.fetch_add(1, Ordering::Relaxed);
        let name = format!("chromey_content_{}_{}.tmp", std::process::id(), seq);
        let path = tmp_dir.join(name);
        let file = tokio::fs::File::create(&path).await?;
        Ok(Sink::Disk { file, path })
    }
    /// Append `data`. On a disk write error, recover whatever already made
    /// it to the file, delete it, and permanently switch to the memory
    /// variant; this method itself never fails.
    async fn write(&mut self, data: &[u8]) {
        match self {
            Sink::Disk { file, path } => {
                use tokio::io::AsyncWriteExt;
                if let Err(err) = file.write_all(data).await {
                    tracing::debug!(
                        "content stream disk write failed, falling back to memory: {err}"
                    );
                    let _ = file.flush().await;
                    // NOTE(review): if `write_all` failed partway through,
                    // the file may already contain a prefix of `data`, which
                    // the extend below would then duplicate — confirm whether
                    // partial writes can occur on this path.
                    let mut recovered = tokio::fs::read(path.as_path()).await.unwrap_or_default();
                    let _ = tokio::fs::remove_file(path.as_path()).await;
                    recovered.extend_from_slice(data);
                    *self = Sink::Memory { buf: recovered };
                }
            }
            Sink::Memory { buf } => buf.extend_from_slice(data),
        }
    }
    /// Flush and read back the accumulated bytes, deleting the temp file and
    /// leaving the sink as an empty memory variant (so `Drop` does not try
    /// to remove the already-deleted file).
    async fn finish(&mut self) -> Vec<u8> {
        match self {
            Sink::Disk { file, path } => {
                use tokio::io::AsyncWriteExt;
                let _ = file.flush().await;
                let p = path.clone();
                let body = tokio::fs::read(&p).await.unwrap_or_default();
                let _ = tokio::fs::remove_file(&p).await;
                *self = Sink::Memory { buf: Vec::new() };
                body
            }
            Sink::Memory { buf } => std::mem::take(buf),
        }
    }
}
#[cfg(feature = "_cache_stream_disk")]
impl Drop for Sink {
    /// Best-effort cleanup of the temp file when the sink is dropped before
    /// `finish` ran (e.g. an error aborted the read loop).
    fn drop(&mut self) {
        if let Sink::Disk { path, .. } = self {
            let p = path.clone();
            // `tokio::spawn` panics when called outside a runtime context,
            // which would turn an ordinary drop (e.g. during unwinding on a
            // non-runtime thread) into an abort. Spawn only when a runtime
            // handle is available; otherwise remove the file synchronously.
            match tokio::runtime::Handle::try_current() {
                Ok(handle) => {
                    handle.spawn(async move {
                        let _ = tokio::fs::remove_file(&p).await;
                    });
                }
                Err(_) => {
                    let _ = std::fs::remove_file(&p);
                }
            }
        }
    }
}
/// Document-size threshold (UTF-16 units) below which content is fetched in
/// a single oversized chunk. Shrinks linearly as more pages are active,
/// bottoming out at a floor once the high-pressure page count is reached.
#[inline]
fn stream_threshold_units() -> u32 {
    const BASE: u32 = 1_048_576;
    const MIN: u32 = 262_144;
    const HIGH_PRESSURE_PAGES: u32 = 128;
    let pages = crate::handler::page::active_page_count() as u32;
    if pages >= HIGH_PRESSURE_PAGES {
        MIN
    } else {
        // Linear interpolation from BASE (idle) down toward MIN (busy).
        BASE - (BASE - MIN) * pages / HIGH_PRESSURE_PAGES
    }
}
/// Fetch the page's full serialized HTML as bytes, reading it from the
/// browser in bounded slices instead of one huge CDP payload. An empty
/// document yields an empty vector.
pub async fn content_bytes_streaming(page: &Page) -> Result<Vec<u8>> {
    match init_state(page).await? {
        None => Ok(Vec::new()),
        Some((guard, total_units, chunk_units)) => {
            read_chunks(page, guard.id(), total_units, chunk_units).await
        }
    }
}
/// Stream the page's HTML as raw byte chunks. `chunk_units` overrides the
/// per-slice size (clamped to the supported range); `None` uses the size
/// chosen by `init_state`. The stream ends once every UTF-16 unit has been
/// read, the remote returns an empty slice, or MAX_CHUNKS rounds elapse.
pub fn content_bytes_stream(
    page: &Page,
    chunk_units: Option<u32>,
) -> impl Stream<Item = Result<Vec<u8>>> + Send + 'static {
    let page = page.clone();
    let override_units = chunk_units.map(clamp_chunk_units);
    // Lazy state machine: `Init` performs the one-time remote setup on the
    // first poll, then `Pumping` fetches one slice per step.
    try_unfold(
        PumpState::Init {
            page,
            override_units,
        },
        |state| async move {
            match state {
                PumpState::Init {
                    page,
                    override_units,
                } => match init_state(&page).await? {
                    // Empty document: finish without yielding any chunk.
                    None => Ok(None),
                    Some((guard, total, default_units)) => {
                        let chunk_units = override_units.unwrap_or(default_units);
                        pump_next(page, guard, total, 0, chunk_units, 0).await
                    }
                },
                PumpState::Pumping {
                    page,
                    guard,
                    total,
                    offset,
                    chunk_units,
                    rounds,
                } => pump_next(page, guard, total, offset, chunk_units, rounds).await,
            }
        },
    )
}
/// Stream the page content as UTF-8 `String` chunks by decoding each byte
/// chunk produced by [`content_bytes_stream`]. A chunk that is not valid
/// UTF-8 becomes an error item.
pub fn content_stream(
    page: &Page,
    chunk_units: Option<u32>,
) -> impl Stream<Item = Result<String>> + Send + 'static {
    use futures_util::StreamExt;
    content_bytes_stream(page, chunk_units).map(|chunk| match chunk {
        Ok(bytes) => String::from_utf8(bytes)
            .map_err(|e| CdpError::msg(format!("invalid UTF-8 in page content chunk: {e}"))),
        Err(err) => Err(err),
    })
}
/// State carried between `try_unfold` steps in `content_bytes_stream`.
enum PumpState {
    /// First poll: remote setup has not been performed yet.
    Init {
        page: Page,
        /// Caller-requested chunk size (already clamped), if any.
        override_units: Option<u32>,
    },
    /// Setup done; fetching slices.
    Pumping {
        page: Page,
        /// Keeps the remote HTML object alive until the stream ends.
        guard: RemoteRefGuard,
        /// Total document length in UTF-16 code units.
        total: u32,
        /// Next read position in UTF-16 code units.
        offset: u32,
        chunk_units: u32,
        /// Slices fetched so far (bounded by MAX_CHUNKS).
        rounds: usize,
    },
}
/// Produce the next chunk for the unfold-based stream, advancing the
/// UTF-16 offset by the number of units the returned bytes represent.
/// Returns `Ok(None)` when the document is exhausted, the remote yields an
/// empty slice, or the MAX_CHUNKS safety valve trips.
async fn pump_next(
    page: Page,
    guard: RemoteRefGuard,
    total: u32,
    offset: u32,
    chunk_units: u32,
    rounds: usize,
) -> Result<Option<(Vec<u8>, PumpState)>> {
    let exhausted = offset >= total || rounds >= MAX_CHUNKS;
    if exhausted {
        return Ok(None);
    }
    let bytes = read_slice(&page, guard.id(), offset, chunk_units).await?;
    // Offsets are tracked in UTF-16 units, so advance by the chunk's
    // UTF-16 length rather than its byte length.
    let advanced = utf16_len_of_utf8(&bytes);
    if bytes.is_empty() || advanced == 0 {
        return Ok(None);
    }
    let next = PumpState::Pumping {
        page,
        guard,
        total,
        offset: offset.saturating_add(advanced),
        chunk_units,
        rounds: rounds + 1,
    };
    Ok(Some((bytes, next)))
}
/// One-time remote setup: snapshot the document HTML into a remote object
/// and measure it. Returns `None` for an empty document, otherwise the
/// guard for the remote object, the total length in UTF-16 code units, and
/// the default chunk size to use.
async fn init_state(page: &Page) -> Result<Option<(RemoteRefGuard, u32, u32)>> {
    crate::runtime_release::init_worker();
    let ctx = page.execution_context().await?;
    let mut init = EvaluateParams::new(INIT_JS);
    init.context_id = ctx;
    init.await_promise = Some(true);
    // Keep the result on the browser side (by reference) so it can be
    // sliced incrementally instead of transferred all at once.
    init.return_by_value = Some(false);
    let init_resp = page.execute(init).await?.result;
    if let Some(ex) = init_resp.exception_details {
        return Err(CdpError::JavascriptException(Box::new(ex)));
    }
    let object_id = init_resp
        .result
        .object_id
        .ok_or_else(|| CdpError::msg("content stream: init returned no objectId"))?;
    // Guard ensures the remote object is released even on early error paths.
    let guard = RemoteRefGuard::new(page.clone(), object_id);
    let len_params = CallFunctionOnParams::builder()
        .function_declaration(LEN_FN)
        .object_id(guard.id().clone())
        .return_by_value(true)
        .await_promise(false)
        .build()
        .map_err(CdpError::msg)?;
    let len_resp = page.execute(len_params).await?.result;
    if let Some(ex) = len_resp.exception_details {
        return Err(CdpError::JavascriptException(Box::new(ex)));
    }
    let total_units_u64 = len_resp
        .result
        .value
        .and_then(|v| v.as_u64())
        .ok_or_else(|| CdpError::msg("content stream: length was not a number"))?;
    // Reject absurdly large documents before attempting to read them.
    if total_units_u64 > MAX_DOCUMENT_UNITS as u64 {
        return Err(CdpError::msg(format!(
            "content stream: document exceeds MAX_DOCUMENT_UNITS ({} > {})",
            total_units_u64, MAX_DOCUMENT_UNITS
        )));
    }
    let total_units: u32 = total_units_u64 as u32;
    if total_units == 0 {
        return Ok(None);
    }
    // Small documents (below the load-dependent threshold) are fetched in a
    // single slice sized to cover the whole string; larger ones use the
    // default chunk size.
    let chunk = if total_units < stream_threshold_units() {
        DEFAULT_CHUNK_UNITS.max(total_units)
    } else {
        DEFAULT_CHUNK_UNITS
    };
    Ok(Some((guard, total_units, chunk)))
}
/// Invoke the remote slice function: fetch up to `chunk_units` UTF-16 code
/// units of the cached HTML starting at `offset`, returned as UTF-8 bytes.
/// A null/absent value maps to an empty vector (end of data).
async fn read_slice(
    page: &Page,
    object_id: &RemoteObjectId,
    offset: u32,
    chunk_units: u32,
) -> Result<Vec<u8>> {
    let start_arg = CallArgument::builder()
        .value(serde_json::json!(offset))
        .build();
    let size_arg = CallArgument::builder()
        .value(serde_json::json!(chunk_units))
        .build();
    let params = CallFunctionOnParams::builder()
        .function_declaration(SLICE_FN)
        .object_id(object_id.clone())
        .argument(start_arg)
        .argument(size_arg)
        .return_by_value(true)
        .await_promise(false)
        .build()
        .map_err(CdpError::msg)?;
    let resp = page.execute(params).await?.result;
    if let Some(ex) = resp.exception_details {
        return Err(CdpError::JavascriptException(Box::new(ex)));
    }
    match resp.result.value {
        Some(serde_json::Value::String(s)) => Ok(s.into_bytes()),
        Some(serde_json::Value::Null) | None => Ok(Vec::new()),
        other => Err(CdpError::msg(format!(
            "content stream: unexpected slice value: {other:?}"
        ))),
    }
}
/// Fetch the full page content as a `String`, streaming the bytes from the
/// browser under the hood and validating them as UTF-8.
pub async fn content_streaming(page: &Page) -> Result<String> {
    let bytes = content_bytes_streaming(page).await?;
    match String::from_utf8(bytes) {
        Ok(text) => Ok(text),
        Err(e) => Err(CdpError::msg(format!("invalid UTF-8 in page content: {e}"))),
    }
}
/// Read the whole document by repeatedly slicing `chunk_units` at a time,
/// accumulating into a disk-backed sink (feature `_cache_stream_disk`) or
/// an in-memory buffer, while enforcing the accumulated-byte cap.
async fn read_chunks(
    page: &Page,
    object_id: &RemoteObjectId,
    total_units: u32,
    chunk_units: u32,
) -> Result<Vec<u8>> {
    let byte_cap = max_accumulated_bytes();
    // Preallocation heuristic: UTF-8 bytes can exceed the UTF-16 unit
    // count; 1.5x is a rough estimate, capped at 8 MiB and the byte cap.
    let cap_hint = (total_units as usize).saturating_mul(3) / 2;
    let cap_hint = cap_hint.min(8 * 1024 * 1024).min(byte_cap);
    #[cfg(feature = "_cache_stream_disk")]
    let mut sink = Sink::open(cap_hint).await;
    #[cfg(not(feature = "_cache_stream_disk"))]
    let mut buf: Vec<u8> = Vec::with_capacity(cap_hint);
    let mut offset: u32 = 0;
    let mut rounds: usize = 0;
    let mut total_bytes: usize = 0;
    while offset < total_units {
        // Safety valve against a remote that never stops returning data.
        if rounds >= MAX_CHUNKS {
            return Err(CdpError::msg("content stream exceeded MAX_CHUNKS"));
        }
        rounds += 1;
        let chunk_bytes = read_slice(page, object_id, offset, chunk_units).await?;
        if chunk_bytes.is_empty() {
            break;
        }
        // Offsets are in UTF-16 units, so advance by the chunk's UTF-16
        // length rather than its byte length.
        let units_advanced = utf16_len_of_utf8(&chunk_bytes);
        if units_advanced == 0 {
            break;
        }
        total_bytes = total_bytes.saturating_add(chunk_bytes.len());
        // Enforce the byte cap before buffering the chunk.
        if total_bytes > byte_cap {
            return Err(CdpError::msg(format!(
                "content stream: accumulated bytes exceeded cap ({} > {})",
                total_bytes, byte_cap
            )));
        }
        #[cfg(feature = "_cache_stream_disk")]
        sink.write(&chunk_bytes).await;
        #[cfg(not(feature = "_cache_stream_disk"))]
        buf.extend_from_slice(&chunk_bytes);
        offset = offset.saturating_add(units_advanced);
    }
    #[cfg(feature = "_cache_stream_disk")]
    {
        Ok(sink.finish().await)
    }
    #[cfg(not(feature = "_cache_stream_disk"))]
    {
        Ok(buf)
    }
}
#[inline]
fn max_accumulated_bytes() -> usize {
std::env::var("CHROMEY_CONTENT_STREAM_MAX_BYTES")
.ok()
.and_then(|v| v.parse::<usize>().ok())
.unwrap_or(DEFAULT_MAX_ACCUMULATED_BYTES)
}
/// Count the UTF-16 code units a valid UTF-8 byte slice decodes to, without
/// allocating: `units = bytes - continuation bytes + 4-byte lead bytes`
/// (each 4-byte scalar becomes a surrogate pair, i.e. two units).
#[inline]
fn utf16_len_of_utf8(bytes: &[u8]) -> u32 {
    let mut continuations: u64 = 0;
    let mut four_byte_leads: u64 = 0;
    // SWAR fast path: classify eight bytes per u64 word.
    let mut words = bytes.chunks_exact(8);
    for word in &mut words {
        let w = u64::from_ne_bytes(word.try_into().expect("chunk is exactly 8 bytes"));
        // A byte is a continuation iff (b & 0xC0) == 0x80, i.e. the masked
        // XOR below zeroes that lane.
        let cont_zeroed = (w & 0xC0C0_C0C0_C0C0_C0C0) ^ 0x8080_8080_8080_8080;
        continuations += u64::from(count_zero_bytes_u64(cont_zeroed));
        // A byte leads a 4-byte sequence iff (b & 0xF0) == 0xF0.
        let four_zeroed = (w & 0xF0F0_F0F0_F0F0_F0F0) ^ 0xF0F0_F0F0_F0F0_F0F0;
        four_byte_leads += u64::from(count_zero_bytes_u64(four_zeroed));
    }
    // Scalar tail for the final 0..=7 bytes.
    for &b in words.remainder() {
        continuations += u64::from((b & 0xC0) == 0x80);
        four_byte_leads += u64::from(b >= 0xF0);
    }
    let units = (bytes.len() as u64)
        .saturating_sub(continuations)
        .saturating_add(four_byte_leads);
    units.min(u64::from(u32::MAX)) as u32
}
/// Count the zero bytes within a u64 word (classic SWAR has-zero trick:
/// the high bit of a lane survives iff that lane was zero).
#[inline(always)]
fn count_zero_bytes_u64(v: u64) -> u32 {
    let zero_flags = v.wrapping_sub(0x0101_0101_0101_0101) & !v & 0x8080_8080_8080_8080;
    zero_flags.count_ones()
}
#[cfg(test)]
mod utf16_len_tests {
    //! Checks `utf16_len_of_utf8` against `str::encode_utf16` for 1-, 2-,
    //! 3- and 4-byte sequences, covering both the 8-byte SWAR fast path
    //! and the scalar tail.
    use super::utf16_len_of_utf8;
    #[test]
    fn empty() {
        assert_eq!(utf16_len_of_utf8(b""), 0);
    }
    #[test]
    fn pure_ascii_short() {
        // 5 bytes: tail-only path.
        assert_eq!(utf16_len_of_utf8(b"hello"), 5);
    }
    #[test]
    fn pure_ascii_multi_word() {
        // 24 bytes: exactly three SWAR words, no tail.
        let s = b"abcdefghijklmnopqrstuvwx";
        assert_eq!(utf16_len_of_utf8(s), 24);
    }
    #[test]
    fn ascii_with_tail() {
        // 11 bytes: one SWAR word plus a 3-byte tail.
        let s = b"hello world";
        assert_eq!(utf16_len_of_utf8(s), 11);
    }
    #[test]
    fn two_byte_sequences() {
        let s = "éééééééé".as_bytes();
        assert_eq!(s.len(), 16);
        assert_eq!(utf16_len_of_utf8(s), 8);
    }
    #[test]
    fn three_byte_sequences() {
        let s = "€€€€€".as_bytes();
        assert_eq!(s.len(), 15);
        assert_eq!(utf16_len_of_utf8(s), 5);
    }
    #[test]
    fn four_byte_sequences() {
        // Each emoji is one 4-byte scalar => two UTF-16 units.
        let s = "😀😀😀".as_bytes();
        assert_eq!(s.len(), 12);
        assert_eq!(utf16_len_of_utf8(s), 6);
    }
    #[test]
    fn mixed_matches_std() {
        let s = "<p>héllo 世界 🌍</p>";
        let expected: u32 = s.encode_utf16().count() as u32;
        assert_eq!(utf16_len_of_utf8(s.as_bytes()), expected);
    }
    #[test]
    fn large_ascii_matches_std() {
        let s: String = "abcdefghij".repeat(1024);
        let expected: u32 = s.encode_utf16().count() as u32;
        assert_eq!(utf16_len_of_utf8(s.as_bytes()), expected);
    }
    #[test]
    fn large_mixed_matches_std() {
        let s: String = "a€b😀c".repeat(512);
        let expected: u32 = s.encode_utf16().count() as u32;
        assert_eq!(utf16_len_of_utf8(s.as_bytes()), expected);
    }
}