use prometheus::{Histogram, exponential_buckets};
use crate::{
config::filter::Filter as FilterConfig,
filters::{FilterRegistry, prelude::*},
metrics::{CollectorExt, histogram_opts},
};
const FILTER_LABEL: &str = "filter";
const BUCKET_START: f64 = 0.000125;
const BUCKET_FACTOR: f64 = 2.5;
const BUCKET_COUNT: usize = 11;
#[derive(Clone, Default)]
pub struct FilterChain {
filters: Vec<(String, FilterInstance)>,
filter_read_duration_seconds: Vec<Histogram>,
filter_write_duration_seconds: Vec<Histogram>,
}
impl FilterChain {
pub fn new(filters: Vec<(String, FilterInstance)>) -> Result<Self, CreationError> {
let subsystem = "filter";
Ok(Self {
filter_read_duration_seconds: filters
.iter()
.map(|(name, _)| {
Histogram::with_opts(
histogram_opts(
"read_duration_seconds",
subsystem,
"Seconds taken to execute a given filter's `read`.",
Some(
exponential_buckets(BUCKET_START, BUCKET_FACTOR, BUCKET_COUNT)
.unwrap(),
),
)
.const_label(FILTER_LABEL, name),
)
.and_then(|histogram| histogram.register_if_not_exists())
})
.collect::<Result<_, prometheus::Error>>()?,
filter_write_duration_seconds: filters
.iter()
.map(|(name, _)| {
Histogram::with_opts(
histogram_opts(
"write_duration_seconds",
subsystem,
"Seconds taken to execute a given filter's `write`.",
Some(exponential_buckets(0.000125, 2.5, 11).unwrap()),
)
.const_label(FILTER_LABEL, name),
)
.and_then(|histogram| histogram.register_if_not_exists())
})
.collect::<Result<_, prometheus::Error>>()?,
filters,
})
}
pub fn testing<const N: usize>(filters: [FilterInstance; N]) -> Self {
let filters = filters.into_iter().map(|f| (String::new(), f)).collect();
Self::new(filters).unwrap()
}
#[inline]
pub fn len(&self) -> usize {
self.filters.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.filters.is_empty()
}
pub fn iter(&self) -> impl Iterator<Item = FilterConfig> + '_ {
self.filters.iter().map(|(name, instance)| FilterConfig {
name: name.clone(),
label: instance.label().map(String::from),
config: match instance.config() {
serde_json::Value::Null => None,
value => Some(value.clone()),
},
})
}
pub fn try_create_fallible<Item>(
filter_configs: impl IntoIterator<Item = Item>,
) -> Result<Self, CreationError>
where
Item: TryInto<FilterConfig, Error = CreationError>,
{
let mut filters = Vec::new();
for filter_config in filter_configs {
let filter_config = filter_config.try_into()?;
let filter = FilterRegistry::get(
&filter_config.name,
CreateFilterArgs::fixed(filter_config.config),
)?;
filters.push((filter_config.name, filter));
}
Self::new(filters)
}
pub fn try_create(
filter_configs: impl IntoIterator<Item = FilterConfig>,
) -> Result<Self, CreationError> {
let mut filters = Vec::new();
for filter_config in filter_configs {
let filter = FilterRegistry::get(
&filter_config.name,
CreateFilterArgs::fixed(filter_config.config),
)?;
filters.push((filter_config.name, filter));
}
Self::new(filters)
}
}
impl std::fmt::Debug for FilterChain {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut filters = f.debug_struct("Filters");
for (id, instance) in &self.filters {
filters.field(id, instance.config());
}
filters.finish()
}
}
impl PartialEq for FilterChain {
fn eq(&self, rhs: &Self) -> bool {
if self.filters.len() != rhs.filters.len() {
return false;
}
self.filters.iter().zip(&rhs.filters).all(
|((lhs_name, lhs_instance), (rhs_name, rhs_instance))| {
lhs_name == rhs_name
&& lhs_instance.config() == rhs_instance.config()
&& lhs_instance.label() == rhs_instance.label()
},
)
}
}
use crate::generated::envoy::config::listener::v3::FilterChain as EnvoyFilterChain;
impl TryFrom<FilterChain> for EnvoyFilterChain {
type Error = CreationError;
fn try_from(chain: FilterChain) -> Result<Self, Self::Error> {
Self::try_from(&chain)
}
}
impl TryFrom<&'_ FilterChain> for EnvoyFilterChain {
type Error = CreationError;
fn try_from(chain: &FilterChain) -> Result<Self, Self::Error> {
Ok(Self {
filters: chain
.iter()
.map(TryFrom::try_from)
.collect::<Result<_, Self::Error>>()?,
..<_>::default()
})
}
}
impl TryFrom<&'_ FilterChain> for crate::net::cluster::proto::FilterChain {
type Error = CreationError;
fn try_from(value: &'_ FilterChain) -> Result<Self, Self::Error> {
Ok(Self {
filters: value
.iter()
.map(|filter| crate::net::cluster::proto::Filter {
name: filter.name,
label: filter.label,
config: filter.config.map(|v| v.to_string()),
})
.collect(),
})
}
}
impl std::ops::Index<usize> for FilterChain {
type Output = (String, FilterInstance);
fn index(&self, index: usize) -> &Self::Output {
&self.filters[index]
}
}
impl<'de> serde::Deserialize<'de> for FilterChain {
fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
let filters = <Vec<FilterConfig>>::deserialize(de)?;
Self::try_create(filters).map_err(serde::de::Error::custom)
}
}
impl serde::Serialize for FilterChain {
fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
let filters = self
.filters
.iter()
.map(|(name, instance)| FilterConfig {
name: name.clone(),
label: instance.label().map(String::from),
config: Some(serde_json::Value::clone(instance.config())),
})
.collect::<Vec<_>>();
filters.serialize(ser)
}
}
impl schemars::JsonSchema for FilterChain {
fn schema_name() -> std::borrow::Cow<'static, str> {
<Vec<FilterConfig>>::schema_name()
}
fn json_schema(sg: &mut schemars::generate::SchemaGenerator) -> schemars::Schema {
<Vec<FilterConfig>>::json_schema(sg)
}
}
impl Filter for FilterChain {
fn read<P: PacketMut>(&self, ctx: &mut ReadContext<'_, P>) -> Result<(), FilterError> {
for ((id, instance), histogram) in self
.filters
.iter()
.zip(self.filter_read_duration_seconds.iter())
{
tracing::trace!(%id, "read filtering packet");
let timer = histogram.start_timer();
let result = instance.filter().read(ctx);
timer.stop_and_record();
match result {
Ok(()) => tracing::trace!(%id, "read passing packet"),
Err(error) => {
tracing::trace!(%id, "read dropping packet");
return Err(error);
}
}
}
if ctx.destinations.is_empty() {
ctx.destinations
.extend(ctx.endpoints.endpoints().into_iter().map(|ep| ep.address));
}
Ok(())
}
fn write<P: PacketMut>(&self, ctx: &mut WriteContext<P>) -> Result<(), FilterError> {
for ((id, instance), histogram) in self
.filters
.iter()
.rev()
.zip(self.filter_write_duration_seconds.iter().rev())
{
tracing::trace!(%id, "write filtering packet");
let timer = histogram.start_timer();
let result = instance.filter().write(ctx);
timer.stop_and_record();
match result {
Ok(()) => tracing::trace!(%id, "write passing packet"),
Err(error) => {
tracing::trace!(%id, "write dropping packet");
return Err(error);
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use crate::{
filters::Debug,
net::endpoint::Endpoint,
test::{TestConfig, TestFilter, alloc_buffer},
};
use super::*;
#[test]
fn from_config() {
let provider = Debug::factory();
let filter_configs = [FilterConfig {
name: provider.name().into(),
label: None,
config: Some(serde_json::Map::default().into()),
}];
let chain = FilterChain::try_create(filter_configs).unwrap();
assert_eq!(1, chain.filters.len());
let filter_configs = [FilterConfig {
name: "this is so wrong".into(),
label: None,
config: Default::default(),
}];
let result = FilterChain::try_create(filter_configs);
assert!(result.is_err());
}
fn endpoints() -> std::sync::Arc<crate::net::cluster::ClusterMap> {
crate::net::cluster::ClusterMap::new_default(
[
Endpoint::new("127.0.0.1:80".parse().unwrap()),
Endpoint::new("127.0.0.1:90".parse().unwrap()),
]
.into(),
)
.into()
}
#[tokio::test]
async fn chain_single_test_filter() {
crate::test::load_test_filters();
let config = TestConfig::new();
let endpoints_fixture = endpoints();
let mut dest = Vec::new();
let mut context = ReadContext::new(
&endpoints_fixture,
"127.0.0.1:70".parse().unwrap(),
alloc_buffer(b"hello"),
&mut dest,
);
config.filters.read(&mut context).unwrap();
let expected = endpoints_fixture.clone();
assert_eq!(&*expected.endpoints(), &*context.destinations);
assert_eq!(
"hello:odr:127.0.0.1:70",
std::str::from_utf8(&context.contents).unwrap()
);
assert_eq!(
"receive",
context.metadata[&"downstream".into()].as_string().unwrap()
);
let mut context = WriteContext::new(
endpoints_fixture
.endpoints()
.first()
.unwrap()
.address
.clone(),
"127.0.0.1:70".parse().unwrap(),
alloc_buffer(b"hello"),
);
config.filters.write(&mut context).unwrap();
assert_eq!(
"receive",
context.metadata[&"upstream".into()].as_string().unwrap()
);
assert_eq!(b"hello:our:127.0.0.1:80:127.0.0.1:70", &*context.contents,);
}
#[tokio::test]
async fn chain_double_test_filter() {
let chain = FilterChain::new(vec![
(
TestFilter::NAME.into(),
FilterInstance::new(serde_json::json!(null), TestFilter.into()),
),
(
TestFilter::NAME.into(),
FilterInstance::new(serde_json::json!(null), TestFilter.into()),
),
])
.unwrap();
let endpoints_fixture = endpoints();
let mut dest = Vec::new();
let (contents, metadata) = {
let mut context = ReadContext::new(
&endpoints_fixture,
"127.0.0.1:70".parse().unwrap(),
alloc_buffer(b"hello"),
&mut dest,
);
chain.read(&mut context).unwrap();
(context.contents, context.metadata)
};
let expected = endpoints_fixture.clone();
assert_eq!(expected.endpoints(), dest);
assert_eq!(b"hello:odr:127.0.0.1:70:odr:127.0.0.1:70", &*contents);
assert_eq!(
"receive:receive",
metadata[&"downstream".into()].as_string().unwrap()
);
let mut context = WriteContext::new(
endpoints_fixture
.endpoints()
.first()
.unwrap()
.address
.clone(),
"127.0.0.1:70".parse().unwrap(),
alloc_buffer(b"hello"),
);
chain.write(&mut context).unwrap();
assert_eq!(
"hello:our:127.0.0.1:80:127.0.0.1:70:our:127.0.0.1:80:127.0.0.1:70",
std::str::from_utf8(&context.contents).unwrap(),
);
assert_eq!(
"receive:receive",
context.metadata[&"upstream".into()].as_string().unwrap()
);
}
#[test]
fn get_configs() {
let filter_chain = FilterChain::new(vec![(
"TestFilter".into(),
FilterInstance::new(serde_json::json!(null), TestFilter.into()),
)])
.unwrap();
let configs = filter_chain.iter().collect::<Vec<_>>();
assert_eq!(
vec![FilterConfig {
name: "TestFilter".into(),
label: None,
config: None,
},],
configs
);
}
}