use std::collections::HashMap;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use futures_util::stream::Stream;
use crate::error::ScrapflyError;
use crate::result::scrape::ScrapeResult;
// Single CRLF line terminator used by the multipart wire format.
const CRLF: &[u8] = b"\r\n";
// Blank line (CRLF CRLF) separating a part's headers from its body.
const DOUBLE_CRLF: &[u8] = b"\r\n\r\n";
/// One raw part extracted from a `multipart/mixed` batch response.
#[derive(Debug)]
pub struct BatchPart {
/// Part headers; names are lowercased by the parser, values keep their case.
pub headers: HashMap<String, String>,
/// Raw body bytes of the part.
pub body: Bytes,
}
/// An opaque (proxified) HTTP response reconstructed from a batch part.
#[derive(Debug)]
pub struct BatchProxifiedResponse {
    /// HTTP status code associated with the part.
    pub status: u16,
    /// Response headers with lowercased names.
    pub headers: HashMap<String, String>,
    /// Raw response body.
    pub body: Bytes,
}

impl BatchProxifiedResponse {
    /// Returns the body decoded as UTF-8, replacing invalid sequences.
    pub fn text(&self) -> String {
        String::from_utf8_lossy(self.body.as_ref()).to_string()
    }

    /// Returns the `content-type` header value, if present.
    pub fn content_type(&self) -> Option<&str> {
        self.headers.get("content-type").map(|value| value.as_str())
    }

    /// Returns the Scrapfly log reference, preferring `x-scrapfly-log` and
    /// falling back to `x-scrapfly-log-uuid`.
    pub fn scrapfly_log(&self) -> Option<&str> {
        match self.headers.get("x-scrapfly-log") {
            Some(log) => Some(log.as_str()),
            None => self.headers.get("x-scrapfly-log-uuid").map(|v| v.as_str()),
        }
    }
}
/// Wire format used to encode individual batch response parts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BatchFormat {
    /// JSON-encoded parts (the default).
    #[default]
    Json,
    /// MessagePack-encoded parts.
    Msgpack,
}

impl BatchFormat {
    /// MIME type to advertise in the `Accept` header for this format.
    pub(crate) fn accept_header(self) -> &'static str {
        match self {
            Self::Json => "application/json",
            Self::Msgpack => "application/msgpack",
        }
    }
}
/// Options controlling how a batch request is issued.
#[derive(Debug, Clone, Default)]
pub struct BatchOptions {
/// Encoding requested for the response parts (JSON by default).
pub format: BatchFormat,
}
/// Outcome of handling one part of a batch response.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum BatchOutcome {
/// Part decoded into a structured scrape result.
Scrape(ScrapeResult),
/// Part reconstructed as a raw proxified HTTP response.
Proxified(BatchProxifiedResponse),
/// Error outcome carrying the part's `ScrapflyError`.
Err(ScrapflyError),
}
/// Returns the index of the first occurrence of `needle` within `buf`,
/// or `None` when it does not occur. An empty needle matches at index 0.
fn find_subslice(buf: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return Some(0);
    }
    // `windows` yields nothing when `buf` is shorter than `needle`, so the
    // short-buffer case falls out naturally as `None`.
    buf.windows(needle.len()).position(|window| window == needle)
}
/// Splits a `Content-Type` header value into its lowercased MIME type and a
/// map of lowercased parameter names to (unquoted) parameter values.
fn parse_content_type(value: &str) -> (String, HashMap<String, String>) {
    let mut segments = value.split(';');
    // The first segment is always the media type itself.
    let mime = segments.next().unwrap_or("").trim().to_lowercase();
    let mut params = HashMap::new();
    for segment in segments {
        // Segments without an '=' are ignored, matching lenient HTTP parsing.
        let (raw_key, raw_val) = match segment.split_once('=') {
            Some(kv) => kv,
            None => continue,
        };
        let key = raw_key.trim().to_lowercase();
        let mut val = raw_val.trim();
        // Strip one pair of surrounding double quotes, e.g. boundary="abc".
        if val.len() >= 2 && val.starts_with('"') && val.ends_with('"') {
            val = &val[1..val.len() - 1];
        }
        params.insert(key, val.to_string());
    }
    (mime, params)
}
/// Stream adapter that incrementally parses a `multipart/mixed` byte stream
/// into `BatchPart` items.
pub struct BatchPartStream<S> {
// Underlying byte stream supplying the raw multipart payload.
inner: S,
// `--boundary`: marker used to locate the first delimiter line.
boundary_line: Vec<u8>,
// `\r\n--boundary`: separator between a part body and the next delimiter.
boundary_sep: Vec<u8>,
// Bytes received but not yet consumed by the state machine.
buf: BytesMut,
// Current position in the multipart grammar.
state: State,
// True once the inner stream has returned `None`.
done: bool,
}
/// Parser position within the multipart grammar.
enum State {
// Scanning for the first `--boundary` marker (skipping any preamble).
FindFirstBoundary,
// Just past a boundary: classify what follows (`--` closes the body,
// a line break starts the next part's headers).
BoundarySuffix,
// Accumulating a part's header block up to the blank line.
Headers,
// Accumulating a part's body bytes.
Body {
// Headers already parsed for the part currently being read.
headers: HashMap<String, String>,
// Declared `content-length`, when present and parseable.
content_length: Option<usize>,
},
// Discarding bytes up to and past the separator that follows a
// length-delimited body.
ConsumeSeparator,
// Terminal state: no further parts will be produced.
Done,
}
impl<S> BatchPartStream<S>
where
    S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin,
{
    /// Creates a parser over `stream` using the given multipart `boundary`
    /// (the bare token, without the leading `--`).
    pub fn new(stream: S, boundary: &str) -> Self {
        Self {
            inner: stream,
            // `--boundary`: locates the first delimiter line.
            boundary_line: format!("--{}", boundary).into_bytes(),
            // `\r\n--boundary`: separates a body from the next delimiter.
            boundary_sep: format!("\r\n--{}", boundary).into_bytes(),
            buf: BytesMut::new(),
            state: State::FindFirstBoundary,
            done: false,
        }
    }
}
// Streaming multipart parser: drives a state machine over buffered bytes and
// yields one `BatchPart` per multipart section.
impl<S> Stream for BatchPartStream<S>
where
S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin,
{
type Item = Result<BatchPart, ScrapflyError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
let this = &mut *self;
// Advance the state machine as far as buffered bytes allow; states that
// cannot make progress fall through to poll the inner stream below.
match &mut this.state {
State::Done => return Poll::Ready(None),
State::FindFirstBoundary => {
// Skip any preamble up to and including the first `--boundary`.
if let Some(idx) = find_subslice(&this.buf, &this.boundary_line) {
let _ = this.buf.split_to(idx + this.boundary_line.len());
this.state = State::BoundarySuffix;
continue;
}
}
State::BoundarySuffix => {
// Classify what follows a boundary marker.
if this.buf.len() < 2 {
// Not enough bytes to classify yet; wait for more input.
} else {
let head = &this.buf[..2];
if head == b"--" {
// Closing delimiter `--boundary--`: end of the multipart body.
this.state = State::Done;
return Poll::Ready(None);
}
if head == CRLF {
let _ = this.buf.split_to(2);
this.state = State::Headers;
continue;
}
// Lenient: also accept a bare LF as the line break.
if this.buf[0] == b'\n' {
let _ = this.buf.split_to(1);
this.state = State::Headers;
continue;
}
// Anything else is malformed; stop parsing.
this.state = State::Done;
return Poll::Ready(None);
}
}
State::Headers => {
// The header block ends at the first blank line (CRLF CRLF).
if let Some(idx) = find_subslice(&this.buf, DOUBLE_CRLF) {
let header_block = this.buf.split_to(idx).freeze();
let _ = this.buf.split_to(DOUBLE_CRLF.len());
let mut headers: HashMap<String, String> = HashMap::new();
let bytes_ref: &[u8] = header_block.as_ref();
for line in bytes_ref.split(|b: &u8| *b == b'\n') {
// Tolerate both CRLF and bare LF line endings.
let line: &[u8] = if let Some(l) = line.strip_suffix(&[b'\r'][..]) {
l
} else {
line
};
if line.is_empty() {
continue;
}
// Header lines that are not valid UTF-8 are silently skipped.
let s = match std::str::from_utf8(line) {
Ok(s) => s,
Err(_) => continue,
};
if let Some(colon) = s.find(':') {
// Header names are lowercased; values keep their case.
let k = s[..colon].trim().to_lowercase();
let v = s[colon + 1..].trim().to_string();
headers.insert(k, v);
}
}
let content_length = headers
.get("content-length")
.and_then(|v| v.parse::<usize>().ok());
this.state = State::Body {
headers,
content_length,
};
continue;
}
}
State::Body {
headers,
content_length,
} => {
// With a content-length the body is sliced exactly and the parser
// must then skip forward past the next separator; without one the
// body extends up to the `\r\n--boundary` separator itself.
let (body_end, consume_sep_after_yield) = match *content_length {
Some(cl) if this.buf.len() >= cl => (Some(cl), true),
Some(_) => (None, false),
None => (find_subslice(&this.buf, &this.boundary_sep), false),
};
if let Some(end) = body_end {
let body = this.buf.split_to(end).freeze();
let part = BatchPart {
// Move the headers out; the leftover empty map lives only in
// the state about to be replaced.
headers: std::mem::take(headers),
body,
};
if consume_sep_after_yield {
this.state = State::ConsumeSeparator;
} else {
// The buffer currently starts with the separator; drop it.
let _ = this.buf.split_to(this.boundary_sep.len());
this.state = State::BoundarySuffix;
}
return Poll::Ready(Some(Ok(part)));
}
}
State::ConsumeSeparator => {
// After a length-delimited body, discard everything up to and
// including the next `\r\n--boundary` separator.
if let Some(idx) = find_subslice(&this.buf, &this.boundary_sep) {
let _ = this.buf.split_to(idx + this.boundary_sep.len());
this.state = State::BoundarySuffix;
continue;
}
}
}
// No state could advance: the buffer is starved. Once the inner stream
// is exhausted, leftover bytes belong to an incomplete trailing section
// and are dropped.
if this.done {
return Poll::Ready(None);
}
match Pin::new(&mut this.inner).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => {
this.done = true;
continue;
}
Poll::Ready(Some(Err(e))) => {
// Surface transport errors from the underlying stream.
return Poll::Ready(Some(Err(ScrapflyError::Config(format!(
"batch stream error: {}",
e
)))));
}
Poll::Ready(Some(Ok(bytes))) => {
this.buf.extend_from_slice(&bytes);
continue;
}
}
}
}
}
/// Validates that `resp` is a `multipart/mixed` response and wraps its byte
/// stream in a `BatchPartStream` keyed on the multipart boundary.
///
/// # Errors
/// Returns `ScrapflyError::Config` when the Content-Type is missing, is not
/// `multipart/mixed`, or lacks a `boundary` parameter.
pub fn parts_from_response(
    resp: reqwest::Response,
) -> Result<BatchPartStream<impl Stream<Item = Result<Bytes, reqwest::Error>>>, ScrapflyError> {
    // A missing or non-ASCII Content-Type degrades to the empty string, which
    // fails the multipart check below with a descriptive error.
    let content_type = match resp.headers().get("content-type") {
        Some(raw) => raw.to_str().unwrap_or("").to_string(),
        None => String::new(),
    };
    let (mime, params) = parse_content_type(&content_type);
    if mime != "multipart/mixed" {
        return Err(ScrapflyError::Config(format!(
            "scrape_batch: expected Content-Type multipart/mixed, got {:?}",
            content_type
        )));
    }
    match params.get("boundary") {
        Some(boundary) => Ok(BatchPartStream::new(resp.bytes_stream(), boundary)),
        None => Err(ScrapflyError::Config(format!(
            "scrape_batch: Content-Type multipart/mixed missing boundary: {:?}",
            content_type
        ))),
    }
}
/// Header-name prefix marking upstream (target-site) response headers.
const UPSTREAM_PREFIX: &str = "x-scrapfly-upstream-";

/// Reconstructs a proxified HTTP response from a raw batch part.
///
/// The status is read from `x-scrapfly-scrape-status`, defaulting to 200 when
/// absent or unparseable. Headers are filtered: `content-type` is kept,
/// `x-scrapfly-upstream-*` headers are kept with the prefix stripped, and
/// other `x-scrapfly-*` headers are kept verbatim. An `x-scrapfly-log` alias
/// is synthesized from `x-scrapfly-log-uuid` when absent.
///
/// NOTE(review): if a part carries both `content-type` and
/// `x-scrapfly-upstream-content-type`, the winner depends on HashMap
/// iteration order — confirm the intended precedence.
pub fn build_proxified_response(part: BatchPart) -> BatchProxifiedResponse {
    let status = part
        .headers
        .get("x-scrapfly-scrape-status")
        .and_then(|raw| raw.parse::<u16>().ok())
        .unwrap_or(200);
    let mut headers: HashMap<String, String> = HashMap::new();
    for (name, value) in &part.headers {
        if name == "content-type" {
            headers.insert("content-type".to_string(), value.clone());
        } else if let Some(upstream_name) = name.strip_prefix(UPSTREAM_PREFIX) {
            headers.insert(upstream_name.to_string(), value.clone());
        } else if name.starts_with("x-scrapfly-") {
            headers.insert(name.clone(), value.clone());
        }
    }
    if !headers.contains_key("x-scrapfly-log") {
        if let Some(uuid) = headers.get("x-scrapfly-log-uuid").cloned() {
            headers.insert("x-scrapfly-log".to_string(), uuid);
        }
    }
    BatchProxifiedResponse {
        status,
        headers,
        body: part.body,
    }
}
/// Decodes a batch part body into `T` according to the part's
/// `content-type` header (JSON is assumed when the header is absent).
///
/// # Errors
/// Returns `ScrapflyError::Config` when the content type is unsupported or
/// the payload fails to deserialize.
pub fn decode_part_body<T: serde::de::DeserializeOwned>(
    part: &BatchPart,
) -> Result<T, ScrapflyError> {
    let ct = part
        .headers
        .get("content-type")
        .cloned()
        .unwrap_or_else(|| "application/json".to_string());
    // Media types are case-insensitive (RFC 2045 §5.1). Header *names* are
    // lowercased during parsing but values keep their original case, so a
    // part sending e.g. `Application/JSON` would previously be rejected.
    // Normalize before matching; keep the original `ct` for error messages.
    let normalized = ct.trim().to_ascii_lowercase();
    if normalized.starts_with("application/json") {
        return serde_json::from_slice::<T>(&part.body)
            .map_err(|e| ScrapflyError::Config(format!("scrape_batch: decode JSON part: {}", e)));
    }
    if normalized.starts_with("application/msgpack")
        || normalized.starts_with("application/x-msgpack")
    {
        return rmp_serde::from_slice::<T>(&part.body).map_err(|e| {
            ScrapflyError::Config(format!("scrape_batch: decode msgpack part: {}", e))
        });
    }
    Err(ScrapflyError::Config(format!(
        "scrape_batch: unsupported part Content-Type: {:?}",
        ct
    )))
}