use bytes::{Buf, Bytes, BytesMut};
use hitbox::predicate::PredicateResult;
use http_body_util::BodyExt;
use hyper::body::Body as HttpBody;
use crate::{BufferedBody, PartialBufferedBody, Remaining};
async fn streaming_search<B>(body: BufferedBody<B>, pattern: &[u8]) -> (bool, BufferedBody<B>)
where
B: HttpBody + Unpin,
B::Data: Send,
{
if pattern.is_empty() {
return (true, body);
}
match body {
BufferedBody::Complete(Some(bytes)) => {
let found = bytes.windows(pattern.len()).any(|w| w == pattern);
(found, BufferedBody::Complete(Some(bytes)))
}
BufferedBody::Complete(None) => (false, BufferedBody::Complete(None)),
BufferedBody::Partial(partial) => {
let (prefix, remaining) = partial.into_parts();
match remaining {
Remaining::Body(body) => streaming_search_body(prefix, body, pattern).await,
Remaining::Error(error) => {
let found = prefix
.as_ref()
.map(|b| b.windows(pattern.len()).any(|w| w == pattern))
.unwrap_or(false);
(
found,
BufferedBody::Partial(PartialBufferedBody::new(
prefix,
Remaining::Error(error),
)),
)
}
}
}
BufferedBody::Passthrough(stream) => streaming_search_body(None, stream, pattern).await,
}
}
/// Streams `body` until `pattern` is found or the stream ends.
///
/// `initial_prefix` holds bytes that were buffered before this call. They are
/// treated as the start of the body and are searched first, so a match there
/// returns immediately without polling the stream.
///
/// Returns `(found, body)`. The returned body always yields every byte of the
/// original body:
/// - on a match: the bytes read so far plus the still-unread remainder of
///   `body` (`Remaining::Body`);
/// - when the stream ends: everything that was read (`BufferedBody::Complete`);
/// - on a stream error: the bytes read so far plus the error
///   (`Remaining::Error`).
///
/// An empty `pattern` trivially matches and nothing is read.
async fn streaming_search_body<B>(
    initial_prefix: Option<Bytes>,
    mut body: B,
    pattern: &[u8],
) -> (bool, BufferedBody<B>)
where
    B: HttpBody + Unpin,
    B::Data: Send,
{
    let occurs_in = move |haystack: &[u8]| {
        haystack
            .windows(pattern.len())
            .any(|window| window == pattern)
    };

    // `slice::windows(0)` panics, so the empty pattern must short-circuit
    // before any search. A match inside the already-buffered prefix needs no
    // further reads either. In both cases the stream is handed back unread.
    if pattern.is_empty()
        || matches!(initial_prefix.as_deref(), Some(prefix) if occurs_in(prefix))
    {
        let unread = PartialBufferedBody::new(initial_prefix, Remaining::Body(body));
        return (true, BufferedBody::Partial(unread));
    }

    let mut buffer = BytesMut::new();
    if let Some(prefix_bytes) = initial_prefix {
        buffer.extend_from_slice(&prefix_bytes);
    }
    // A match can straddle the old buffer end and a new chunk by at most
    // `pattern.len() - 1` bytes.
    let overlap_size = pattern.len().saturating_sub(1);

    // Invariant: at the top of every iteration `buffer` contains no occurrence
    // of `pattern`. The prefix was checked above, and each new chunk is
    // checked together with the `overlap_size` bytes preceding it. Hence
    // reaching EOF or an error means the pattern was not found.
    loop {
        match body.frame().await {
            Some(Ok(frame)) => {
                // NOTE(review): non-data frames (e.g. trailers) are discarded here.
                if let Ok(mut data) = frame.into_data() {
                    let search_start = buffer.len().saturating_sub(overlap_size);
                    buffer.extend_from_slice(&data.copy_to_bytes(data.remaining()));
                    if occurs_in(&buffer[search_start..]) {
                        // Hand back the unread remainder as well; otherwise
                        // everything after this chunk would be lost.
                        let read_so_far = PartialBufferedBody::new(
                            Some(buffer.freeze()),
                            Remaining::Body(body),
                        );
                        return (true, BufferedBody::Partial(read_so_far));
                    }
                }
            }
            Some(Err(error)) => {
                let buffered = if buffer.is_empty() {
                    None
                } else {
                    Some(buffer.freeze())
                };
                let result_body = BufferedBody::Partial(PartialBufferedBody::new(
                    buffered,
                    Remaining::Error(Some(error)),
                ));
                return (false, result_body);
            }
            None => {
                let combined = if buffer.is_empty() {
                    None
                } else {
                    Some(buffer.freeze())
                };
                return (false, BufferedBody::Complete(combined));
            }
        }
    }
}
/// A byte-level test applied to an HTTP body by [`PlainOperation::check`].
///
/// `check` reports `Cacheable` when the test passes and `NonCacheable`
/// otherwise.
#[derive(Debug)]
pub enum PlainOperation {
    /// Passes when the entire body is byte-for-byte equal to the given bytes.
    /// Requires collecting the whole body.
    Eq(Bytes),
    /// Passes when the given byte sequence occurs anywhere in the body.
    /// An empty sequence always passes.
    Contains(Bytes),
    /// Passes when the body begins with the given bytes.
    /// An empty prefix always passes.
    Starts(Bytes),
    /// Passes when the body ends with the given bytes.
    /// Requires collecting the whole body.
    Ends(Bytes),
    /// Passes when the regular expression matches somewhere in the body
    /// (`Regex::is_match`). Requires collecting the whole body.
    RegExp(regex::bytes::Regex),
}
impl PlainOperation {
    /// Evaluates this operation against `body`.
    ///
    /// Yields [`PredicateResult::Cacheable`] when the body satisfies the
    /// operation and [`PredicateResult::NonCacheable`] otherwise. In both cases
    /// the wrapped body still carries the bytes that were read while checking.
    ///
    /// * `Starts` reads only `prefix.len()` bytes. An empty prefix passes
    ///   without reading anything.
    /// * `Contains` stops reading as soon as the sequence is found.
    /// * `Eq`, `Ends` and `RegExp` need the whole body. If collecting it fails,
    ///   the result is `NonCacheable` wrapping the body that `collect` returned.
    pub async fn check<B>(&self, body: BufferedBody<B>) -> PredicateResult<BufferedBody<B>>
    where
        B: HttpBody + Unpin,
        B::Data: Send,
    {
        use crate::CollectExactResult;

        match self {
            PlainOperation::Starts(prefix) if prefix.is_empty() => PredicateResult::Cacheable(body),
            PlainOperation::Starts(prefix) => {
                let collected = body.collect_exact(prefix.len()).await;
                // A body shorter than the prefix (`Incomplete`) can never start with it.
                let satisfied = matches!(
                    &collected,
                    CollectExactResult::AtLeast { buffered, .. } if buffered.starts_with(prefix)
                );
                Self::verdict(satisfied, collected.into_buffered_body())
            }
            PlainOperation::Contains(sequence) => {
                let (found, searched) = streaming_search(body, sequence.as_ref()).await;
                Self::verdict(found, searched)
            }
            PlainOperation::Eq(expected) => {
                Self::check_whole_body(body, |bytes| bytes == &expected[..]).await
            }
            PlainOperation::Ends(suffix) => {
                Self::check_whole_body(body, |bytes| bytes.ends_with(suffix)).await
            }
            PlainOperation::RegExp(regex) => {
                Self::check_whole_body(body, |bytes| regex.is_match(bytes)).await
            }
        }
    }

    /// Collects the full body and classifies it with `predicate`.
    ///
    /// A body that cannot be collected is classified as non-cacheable and is
    /// handed back exactly as `collect` returned it.
    async fn check_whole_body<B, F>(
        body: BufferedBody<B>,
        predicate: F,
    ) -> PredicateResult<BufferedBody<B>>
    where
        B: HttpBody + Unpin,
        B::Data: Send,
        F: FnOnce(&[u8]) -> bool,
    {
        match body.collect().await {
            Ok(bytes) => {
                let satisfied = predicate(&bytes[..]);
                Self::verdict(satisfied, BufferedBody::Complete(Some(bytes)))
            }
            Err(unreadable) => PredicateResult::NonCacheable(unreadable),
        }
    }

    /// Maps a pass/fail decision onto the corresponding [`PredicateResult`] variant.
    fn verdict<B>(satisfied: bool, body: BufferedBody<B>) -> PredicateResult<BufferedBody<B>>
    where
        B: HttpBody + Unpin,
        B::Data: Send,
    {
        if satisfied {
            PredicateResult::Cacheable(body)
        } else {
            PredicateResult::NonCacheable(body)
        }
    }
}