use std::path::Path;
use tracing::{debug, info, instrument};
use viewpoint_cdp::protocol::page::{
CaptureScreenshotParams, CaptureScreenshotResult, ScreenshotFormat as CdpScreenshotFormat,
Viewport,
};
/// Image encoding requested from the browser when capturing a screenshot.
///
/// [`ScreenshotFormat::Png`] is the default.
#[derive(Default, Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScreenshotFormat {
    /// PNG output (the default).
    #[default]
    Png,
    /// JPEG output.
    Jpeg,
    /// WebP output.
    Webp,
}
impl From<ScreenshotFormat> for CdpScreenshotFormat {
    /// Maps the public screenshot format onto the protocol-level format of the same name.
    fn from(value: ScreenshotFormat) -> Self {
        match value {
            ScreenshotFormat::Jpeg => Self::Jpeg,
            ScreenshotFormat::Webp => Self::Webp,
            ScreenshotFormat::Png => Self::Png,
        }
    }
}
use crate::error::PageError;
use super::Page;
/// Whether CSS animations and transitions keep running while a screenshot is taken.
///
/// [`Animations::Allow`] is the default.
#[derive(Default, Clone, Copy, Debug, Eq, PartialEq)]
pub enum Animations {
    /// Leave animations and transitions untouched (the default).
    #[default]
    Allow,
    /// Force animation/transition durations and delays to zero for the capture.
    Disabled,
}
/// Rectangle, given as an origin plus a size, that restricts a screenshot to part of the page.
///
/// Derives `PartialEq` so regions can be compared directly (e.g. in tests or when
/// de-duplicating requests), like the other public value types in this module.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ClipRegion {
    /// Horizontal offset of the region's origin.
    pub x: f64,
    /// Vertical offset of the region's origin.
    pub y: f64,
    /// Width of the region.
    pub width: f64,
    /// Height of the region.
    pub height: f64,
}
impl ClipRegion {
pub fn new(x: f64, y: f64, width: f64, height: f64) -> Self {
Self {
x,
y,
width,
height,
}
}
}
/// Builder that collects screenshot options for a [`Page`] and then captures the image.
///
/// The builder borrows the page for as long as it lives.
#[derive(Clone, Debug)]
pub struct ScreenshotBuilder<'page> {
    /// Page the screenshot is taken from.
    page: &'page Page,
    /// Image encoding to request.
    format: ScreenshotFormat,
    /// Optional compression quality forwarded with the capture request.
    quality: Option<u8>,
    /// Whether the whole scrollable document is captured instead of the viewport.
    full_page: bool,
    /// Optional rectangle the capture is restricted to.
    clip: Option<ClipRegion>,
    /// Optional file path the captured bytes are also written to.
    path: Option<String>,
    /// Value set through the `omit_background` builder method.
    omit_background: bool,
    /// Whether animations run or are frozen during the capture.
    animations: Animations,
    /// Forwarded as the protocol's `capture_beyond_viewport` flag.
    capture_beyond_viewport: bool,
}
impl<'a> ScreenshotBuilder<'a> {
pub(crate) fn new(page: &'a Page) -> Self {
Self {
page,
format: ScreenshotFormat::Png,
quality: None,
full_page: false,
clip: None,
path: None,
omit_background: false,
animations: Animations::default(),
capture_beyond_viewport: false,
}
}
#[must_use]
pub fn png(mut self) -> Self {
self.format = ScreenshotFormat::Png;
self
}
#[must_use]
pub fn jpeg(mut self, quality: Option<u8>) -> Self {
self.format = ScreenshotFormat::Jpeg;
self.quality = quality;
self
}
#[must_use]
pub fn format(mut self, format: ScreenshotFormat) -> Self {
self.format = format;
self
}
#[must_use]
pub fn quality(mut self, quality: u8) -> Self {
self.quality = Some(quality.min(100));
self
}
#[must_use]
pub fn full_page(mut self, full_page: bool) -> Self {
self.full_page = full_page;
self.capture_beyond_viewport = full_page;
self
}
#[must_use]
pub fn clip(mut self, x: f64, y: f64, width: f64, height: f64) -> Self {
self.clip = Some(ClipRegion::new(x, y, width, height));
self
}
#[must_use]
pub fn clip_region(mut self, region: ClipRegion) -> Self {
self.clip = Some(region);
self
}
#[must_use]
pub fn path(mut self, path: impl AsRef<Path>) -> Self {
self.path = Some(path.as_ref().to_string_lossy().to_string());
self
}
#[must_use]
pub fn omit_background(mut self, omit: bool) -> Self {
self.omit_background = omit;
self
}
#[must_use]
pub fn animations(mut self, animations: Animations) -> Self {
self.animations = animations;
self
}
#[instrument(level = "info", skip(self), fields(format = ?self.format, full_page = self.full_page, has_path = self.path.is_some()))]
pub async fn capture(self) -> Result<Vec<u8>, PageError> {
if self.page.is_closed() {
return Err(PageError::Closed);
}
info!("Capturing screenshot");
if self.animations == Animations::Disabled {
debug!("Disabling animations");
self.disable_animations().await?;
}
let clip = if self.full_page {
let dimensions = self.get_full_page_dimensions().await?;
debug!(
width = dimensions.0,
height = dimensions.1,
"Full page dimensions"
);
Some(Viewport {
x: 0.0,
y: 0.0,
width: dimensions.0,
height: dimensions.1,
scale: 1.0,
})
} else {
self.clip.map(|c| Viewport {
x: c.x,
y: c.y,
width: c.width,
height: c.height,
scale: 1.0,
})
};
let params = CaptureScreenshotParams {
format: Some(self.format.into()),
quality: self.quality,
clip,
from_surface: Some(true),
capture_beyond_viewport: Some(self.capture_beyond_viewport),
optimize_for_speed: None,
};
debug!("Sending Page.captureScreenshot command");
let result: CaptureScreenshotResult = self
.page
.connection()
.send_command(
"Page.captureScreenshot",
Some(params),
Some(self.page.session_id()),
)
.await?;
if self.animations == Animations::Disabled {
debug!("Re-enabling animations");
self.enable_animations().await?;
}
let data = base64_decode(&result.data)?;
debug!(bytes = data.len(), "Screenshot captured");
if let Some(ref path) = self.path {
debug!(path = path, "Saving screenshot to file");
tokio::fs::write(path, &data).await.map_err(|e| {
PageError::EvaluationFailed(format!("Failed to save screenshot: {e}"))
})?;
info!(path = path, "Screenshot saved");
}
Ok(data)
}
async fn get_full_page_dimensions(&self) -> Result<(f64, f64), PageError> {
let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
.page
.connection()
.send_command(
"Runtime.evaluate",
Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
expression: r"
JSON.stringify({
width: Math.max(
document.body.scrollWidth,
document.documentElement.scrollWidth,
document.body.offsetWidth,
document.documentElement.offsetWidth,
document.body.clientWidth,
document.documentElement.clientWidth
),
height: Math.max(
document.body.scrollHeight,
document.documentElement.scrollHeight,
document.body.offsetHeight,
document.documentElement.offsetHeight,
document.body.clientHeight,
document.documentElement.clientHeight
)
})
"
.to_string(),
object_group: None,
include_command_line_api: None,
silent: Some(true),
context_id: None,
return_by_value: Some(true),
await_promise: Some(false),
}),
Some(self.page.session_id()),
)
.await?;
let json_str = result
.result
.value
.and_then(|v| v.as_str().map(String::from))
.ok_or_else(|| {
PageError::EvaluationFailed("Failed to get page dimensions".to_string())
})?;
let dimensions: serde_json::Value = serde_json::from_str(&json_str)
.map_err(|e| PageError::EvaluationFailed(format!("Failed to parse dimensions: {e}")))?;
let width = dimensions["width"].as_f64().unwrap_or(800.0);
let height = dimensions["height"].as_f64().unwrap_or(600.0);
Ok((width, height))
}
async fn disable_animations(&self) -> Result<(), PageError> {
let script = r"
(function() {
const style = document.createElement('style');
style.id = '__viewpoint_disable_animations__';
style.textContent = '*, *::before, *::after { animation-duration: 0s !important; animation-delay: 0s !important; transition-duration: 0s !important; transition-delay: 0s !important; }';
document.head.appendChild(style);
})()
";
self.page
.connection()
.send_command::<_, serde_json::Value>(
"Runtime.evaluate",
Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
expression: script.to_string(),
object_group: None,
include_command_line_api: None,
silent: Some(true),
context_id: None,
return_by_value: Some(true),
await_promise: Some(false),
}),
Some(self.page.session_id()),
)
.await?;
Ok(())
}
async fn enable_animations(&self) -> Result<(), PageError> {
let script = r"
(function() {
const style = document.getElementById('__viewpoint_disable_animations__');
if (style) style.remove();
})()
";
self.page
.connection()
.send_command::<_, serde_json::Value>(
"Runtime.evaluate",
Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
expression: script.to_string(),
object_group: None,
include_command_line_api: None,
silent: Some(true),
context_id: None,
return_by_value: Some(true),
await_promise: Some(false),
}),
Some(self.page.session_id()),
)
.await?;
Ok(())
}
}
/// Decodes standard (RFC 4648 alphabet, `+` and `/`) base64 text into raw bytes.
///
/// Decoding stops at the first `=` padding character. Spaces, tabs, CR and LF are skipped,
/// so line-wrapped input is accepted. Trailing bits that do not complete a byte are dropped.
///
/// # Errors
///
/// Returns [`PageError::EvaluationFailed`] if the input contains any other character
/// outside the base64 alphabet.
pub(crate) fn base64_decode(input: &str) -> Result<Vec<u8>, PageError> {
    /// Maps one base64 alphabet byte to its 6-bit value in constant time. This replaces a
    /// linear alphabet scan, which matters because screenshots decode megabytes of data.
    fn decode_char(c: u8) -> Option<u8> {
        match c {
            b'A'..=b'Z' => Some(c - b'A'),
            b'a'..=b'z' => Some(c - b'a' + 26),
            b'0'..=b'9' => Some(c - b'0' + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    let input = input.as_bytes();
    // Every 4 input chars yield at most 3 bytes; divide first so the product cannot overflow.
    let mut output = Vec::with_capacity(input.len() / 4 * 3 + 3);
    let mut buffer = 0u32;
    let mut bits = 0u8;
    for &byte in input {
        if byte == b'=' {
            break;
        }
        if matches!(byte, b'\n' | b'\r' | b' ' | b'\t') {
            continue;
        }
        let val = decode_char(byte)
            .ok_or_else(|| PageError::EvaluationFailed("Invalid base64 character".to_string()))?;
        // At most 14 low bits are live; older bits shift out of the u32 harmlessly.
        buffer = (buffer << 6) | u32::from(val);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            // Truncating cast keeps exactly the 8 bits that were just completed.
            output.push((buffer >> bits) as u8);
        }
    }
    Ok(output)
}