use super::{
Error, get_hash_key, get_plugin_factory, get_str_conf, get_str_slice_conf,
};
use async_trait::async_trait;
use bstr::ByteSlice;
use bytes::{Bytes, BytesMut};
use ctor::ctor;
use pingap_config::{PluginCategory, PluginConf};
use pingap_core::{
Ctx, HTTP_HEADER_TRANSFER_CHUNKED, ModifyResponseBody, Plugin,
ResponseBodyPluginResult, ResponsePluginResult,
};
use pingora::http::ResponseHeader;
use pingora::proxy::Session;
use regex::Regex;
use regex::bytes::RegexBuilder;
use std::borrow::Cow;
use std::sync::Arc;
use std::sync::LazyLock;
/// Key under which this plugin registers its body modifier on `Ctx` in
/// `handle_response` and looks it up again in `handle_response_body`.
const PLUGIN_ID: &str = "_sub_filter_";
/// Module-local result alias whose error type defaults to the plugin `Error`.
type Result<T, E = Error> = std::result::Result<T, E>;
/// Response plugin that rewrites upstream response bodies with a list of
/// substitution rules, optionally limited by request path and status code.
pub struct SubFilter {
// Regex tested against the request URI path; `None` applies to every path.
path: Option<Regex>,
// Template replacer; a clone of it is attached to each matching response.
replacer: SubFilterReplacer,
// Value handed back by `config_key`, obtained from `get_hash_key`.
hash_value: String,
// Upstream status codes the plugin applies to; `None` applies to any status.
status_codes: Option<Vec<u16>>,
}
/// Grammar of a single filter rule, compiled lazily on first use.
///
/// Capture groups:
/// 1. the directive, `subs_filter` or `sub_filter`;
/// 2. the single-quoted pattern;
/// 3. the single-quoted replacement;
/// 4. the optional flags, any run of `i` and `g`.
///
/// Pattern and replacement are both matched with `[^']+`, so neither can be
/// empty and neither can contain a single quote. The expression is not
/// anchored, so text around a well-formed rule does not prevent a match.
static SUBS_FILTER_REGEX: LazyLock<Regex> = LazyLock::new(|| {
    let grammar =
        r"(subs_filter|sub_filter)\s+'([^']+)'\s+'([^']+)'(?:\s+([ig]+))?";
    Regex::new(grammar).expect("Failed to compile subs filter regex")
});
/// One parsed substitution rule, as produced by `parse_subs_filter`.
#[derive(Debug, Default, Clone)]
struct SubFilterParams {
// Compiled pattern for `subs_filter` rules; `None` for literal `sub_filter` rules.
regex_pattern: Option<regex::bytes::Regex>,
// Literal needle for `sub_filter` rules; left empty when `regex_pattern` is set.
pattern: Vec<u8>,
// Bytes substituted for each match.
replacement: Vec<u8>,
// Flag characters in the order written in the rule (`i` and/or `g`).
flags: Vec<char>,
}
/// Parses one rule of the form `<directive> '<pattern>' '<replacement>' [flags]`.
///
/// * `subs_filter` compiles `<pattern>` as a byte regex; the `i` flag makes that
///   regex case-insensitive.
/// * `sub_filter` keeps `<pattern>` as a literal byte string. The `i` flag is
///   recorded in `flags` but nothing in this function acts on it for literals.
///
/// Returns `None` when the rule does not match the rule grammar or when the
/// `subs_filter` pattern is not a valid regular expression.
fn parse_subs_filter(rule: &str) -> Option<SubFilterParams> {
    let caps = SUBS_FILTER_REGEX.captures(rule)?;
    let directive = caps.get(1)?.as_str();
    let needle = caps.get(2)?.as_str();
    let replacement = caps.get(3)?.as_str().as_bytes().to_vec();
    // The flags group is optional; a rule without flags yields an empty list.
    let flags: Vec<char> = caps
        .get(4)
        .map_or_else(Vec::new, |m| m.as_str().chars().collect());

    let (regex_pattern, pattern) = if directive == "subs_filter" {
        let compiled = RegexBuilder::new(needle)
            .case_insensitive(flags.contains(&'i'))
            .build()
            .ok()?;
        (Some(compiled), Vec::new())
    } else {
        (None, needle.as_bytes().to_vec())
    };

    Some(SubFilterParams {
        regex_pattern,
        pattern,
        replacement,
        flags,
    })
}
/// Body modifier that accumulates response chunks and applies the rules once
/// the stream has ended.
#[derive(Debug, Default, Clone)]
struct SubFilterReplacer {
// Rules applied one after another, in this order.
filters: Vec<SubFilterParams>,
// Body bytes collected so far across `handle` calls.
buffer: BytesMut,
}
impl ModifyResponseBody for SubFilterReplacer {
    /// Collects body chunks and, on the final call, emits the rewritten body.
    ///
    /// While `end_of_stream` is `false`, each incoming chunk is appended to the
    /// internal buffer and emptied in place, so nothing is forwarded yet. On
    /// the final call the buffered bytes are run through every rule in order
    /// and the result is written to `body`.
    ///
    /// For each rule the `g` flag selects "replace every match"; without it
    /// only the first match is replaced. Regex rules use the compiled pattern,
    /// literal rules use a byte-string search.
    ///
    /// The whole body is held in memory until the stream ends.
    ///
    /// NOTE(review): regex rules hand the replacement to the `regex` crate
    /// unchanged, so its `$group` expansion syntax presumably applies — confirm.
    fn handle(
        &mut self,
        _session: &Session,
        body: &mut Option<bytes::Bytes>,
        end_of_stream: bool,
    ) -> pingora::Result<()> {
        if let Some(chunk) = body {
            self.buffer.extend_from_slice(&chunk[..]);
            // Leave an empty chunk behind so no partial data is sent.
            chunk.clear();
        }
        if !end_of_stream {
            return Ok(());
        }

        // Take ownership of the accumulated bytes instead of copying them;
        // the buffer is left empty and its storage is released with `data`.
        let mut data = Vec::from(std::mem::take(&mut self.buffer));
        for item in &self.filters {
            let replace_every_match = item.flags.contains(&'g');
            data = match &item.regex_pattern {
                // `into_owned` reuses the allocation when a substitution
                // happened and only copies when nothing matched.
                Some(re) if replace_every_match => {
                    re.replace_all(&data, &item.replacement).into_owned()
                },
                Some(re) => re.replace(&data, &item.replacement).into_owned(),
                None if replace_every_match => {
                    data.replace(&item.pattern, &item.replacement)
                },
                None => data.replacen(&item.pattern, &item.replacement, 1),
            };
        }
        *body = Some(Bytes::from(data));
        Ok(())
    }

    /// Name reported for this body modifier.
    fn name(&self) -> String {
        "sub_filter".to_string()
    }
}
impl TryFrom<&PluginConf> for SubFilter {
type Error = Error;
fn try_from(value: &PluginConf) -> Result<Self> {
let path_value = get_str_conf(value, "path");
let path = if path_value.is_empty() {
None
} else {
let regex = Regex::new(get_str_conf(value, "path").as_str())
.map_err(|e| Error::Invalid {
category: PluginCategory::SubFilter.to_string(),
message: e.to_string(),
})?;
Some(regex)
};
let filters = get_str_slice_conf(value, "filters")
.iter()
.map(|s| {
parse_subs_filter(s).ok_or(Error::Invalid {
category: PluginCategory::SubFilter.to_string(),
message: format!("invalid subs filter: {s}"),
})
})
.collect::<Result<Vec<_>>>()?;
let status_codes = get_str_conf(value, "status_codes");
let status_codes = if !status_codes.is_empty() {
Some(
status_codes
.split(",")
.flat_map(|s| s.trim().parse::<u16>().ok())
.collect::<Vec<_>>(),
)
} else {
None
};
let hash_value = get_hash_key(value);
Ok(Self {
path,
replacer: SubFilterReplacer {
filters,
buffer: BytesMut::new(),
},
hash_value,
status_codes,
})
}
}
impl SubFilter {
    /// Creates the plugin from `params`.
    ///
    /// Thin wrapper over the `TryFrom<&PluginConf>` implementation; it fails
    /// with the same error that conversion produces.
    pub fn new(params: &PluginConf) -> Result<Self> {
        <Self as TryFrom<&PluginConf>>::try_from(params)
    }
}
#[async_trait]
impl Plugin for SubFilter {
    /// Returns the hash value stored when the plugin was built.
    fn config_key(&self) -> Cow<'_, str> {
        Cow::Borrowed(self.hash_value.as_str())
    }

    /// Decides whether this response should be rewritten and, if so, prepares it.
    ///
    /// The response is left untouched when a status-code list is configured
    /// and does not contain the upstream status, or when a path regex is
    /// configured and does not match the request path. Otherwise the
    /// `Content-Length` header is removed, `Transfer-Encoding` is set to the
    /// chunked header value, and a clone of the replacer is registered on
    /// `ctx` under `PLUGIN_ID`.
    ///
    /// NOTE(review): nothing here inspects `Content-Encoding`; presumably
    /// compressed bodies are dealt with elsewhere — confirm.
    async fn handle_response(
        &self,
        session: &mut Session,
        ctx: &mut Ctx,
        upstream_response: &mut ResponseHeader,
    ) -> pingora::Result<ResponsePluginResult> {
        let status = upstream_response.status.as_u16();
        let status_allowed = self
            .status_codes
            .as_ref()
            .is_none_or(|codes| codes.contains(&status));
        if !status_allowed {
            return Ok(ResponsePluginResult::Unchanged);
        }

        let path_allowed = self
            .path
            .as_ref()
            .is_none_or(|re| re.is_match(session.req_header().uri.path()));
        if !path_allowed {
            return Ok(ResponsePluginResult::Unchanged);
        }

        // The rewritten body may differ in length, so the original length
        // header is removed and the response is marked as chunked.
        upstream_response.remove_header(&http::header::CONTENT_LENGTH);
        let _ = upstream_response.insert_header(
            http::header::TRANSFER_ENCODING,
            HTTP_HEADER_TRANSFER_CHUNKED.1.clone(),
        );
        let modifier = Box::new(self.replacer.clone());
        ctx.add_modify_body_handler(PLUGIN_ID, modifier);
        Ok(ResponsePluginResult::Modified)
    }

    /// Feeds a body chunk to the modifier registered by `handle_response`.
    ///
    /// Returns `Unchanged` when no modifier is registered under `PLUGIN_ID`.
    /// Otherwise it returns `PartialReplaced` for intermediate chunks and
    /// `FullyReplaced` for the final one. Errors from the modifier are
    /// propagated.
    fn handle_response_body(
        &self,
        session: &mut Session,
        ctx: &mut Ctx,
        body: &mut Option<bytes::Bytes>,
        end_of_stream: bool,
    ) -> pingora::Result<ResponseBodyPluginResult> {
        let Some(modifier) = ctx.get_modify_body_handler(PLUGIN_ID) else {
            return Ok(ResponseBodyPluginResult::Unchanged);
        };
        modifier.handle(session, body, end_of_stream)?;
        Ok(if end_of_stream {
            ResponseBodyPluginResult::FullyReplaced
        } else {
            ResponseBodyPluginResult::PartialReplaced
        })
    }
}
/// Registers the `sub_filter` plugin constructor with the plugin factory.
///
/// Marked `#[ctor]`, so it takes no arguments and is not called from this file.
#[ctor]
fn init() {
    get_plugin_factory().register("sub_filter", |params| {
        let plugin = SubFilter::new(params)?;
        Ok(Arc::new(plugin))
    });
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Both directives parse, and only `subs_filter` produces a compiled regex.
    #[test]
    fn test_parse_subs_filter() {
        let rule = "subs_filter 'http://pingap.io' 'https://pingap.io/api' ig";
        let params = parse_subs_filter(rule).unwrap();
        assert_eq!(params.regex_pattern.unwrap().as_str(), "http://pingap.io");
        assert_eq!(params.pattern, b"");
        assert_eq!(params.replacement, b"https://pingap.io/api");
        assert_eq!(params.flags, vec!['i', 'g']);

        let rule = "sub_filter 'http://pingap.io' 'https://pingap.io/api' ig";
        let params = parse_subs_filter(rule).unwrap();
        assert!(params.regex_pattern.is_none());
        assert_eq!(params.pattern, b"http://pingap.io");
        assert_eq!(params.replacement, b"https://pingap.io/api");
        assert_eq!(params.flags, vec!['i', 'g']);
    }

    /// The flags group is optional and defaults to an empty list.
    #[test]
    fn test_parse_subs_filter_without_flags() {
        let params = parse_subs_filter("sub_filter 'a' 'b'").unwrap();
        assert!(params.regex_pattern.is_none());
        assert_eq!(params.pattern, b"a");
        assert_eq!(params.replacement, b"b");
        assert!(params.flags.is_empty());
    }

    /// Rules that miss a part or carry an invalid regex are rejected.
    #[test]
    fn test_parse_subs_filter_invalid() {
        // Replacement is missing.
        assert!(parse_subs_filter("sub_filter 'a'").is_none());
        // Unknown directive.
        assert!(parse_subs_filter("replace 'a' 'b'").is_none());
        // `(` is not a valid regular expression.
        assert!(parse_subs_filter("subs_filter '(' 'b'").is_none());
    }

    /// Covers the parts of the replacer that do not need a `Session`.
    #[test]
    fn test_sub_filter_replacer() {
        let replacer = SubFilterReplacer {
            filters: vec![parse_subs_filter("sub_filter 'a' 'b' g").unwrap()],
            buffer: BytesMut::new(),
        };
        assert_eq!(replacer.name(), "sub_filter");

        // A clone carries the rules and starts with an empty buffer.
        let cloned = replacer.clone();
        assert_eq!(cloned.filters.len(), 1);
        assert!(cloned.buffer.is_empty());

        // TODO: `handle` takes a `&Session`, which this module has no way of
        // constructing yet; add body-rewrite assertions once one is available.
    }
}