shotover 0.7.2

Shotover API for building custom transforms
Documentation
use super::{DownChainProtocol, TransformContextBuilder, TransformContextConfig, UpChainProtocol};
use crate::message::{Message, MessageIdMap, Messages, QueryType};
use crate::transforms::{ChainState, Transform, TransformBuilder, TransformConfig};
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub enum Filter {
    AllowList(Vec<QueryType>),
    DenyList(Vec<QueryType>),
}

#[derive(Clone)]
pub struct QueryTypeFilter {
    pub filter: Filter,
    pub filtered_requests: MessageIdMap<Message>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct QueryTypeFilterConfig {
    #[serde(flatten)]
    pub filter: Filter,
}

const NAME: &str = "QueryTypeFilter";
#[typetag::serde(name = "QueryTypeFilter")]
#[async_trait(?Send)]
impl TransformConfig for QueryTypeFilterConfig {
    async fn get_builder(
        &self,
        _transform_context: TransformContextConfig,
    ) -> Result<Box<dyn TransformBuilder>> {
        Ok(Box::new(QueryTypeFilter {
            filter: self.filter.clone(),
            filtered_requests: MessageIdMap::default(),
        }))
    }

    fn up_chain_protocol(&self) -> UpChainProtocol {
        UpChainProtocol::Any
    }

    fn down_chain_protocol(&self) -> DownChainProtocol {
        DownChainProtocol::SameAsUpChain
    }
}

impl TransformBuilder for QueryTypeFilter {
    fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
        Box::new(self.clone())
    }

    fn get_name(&self) -> &'static str {
        NAME
    }
}

#[async_trait]
impl Transform for QueryTypeFilter {
    fn get_name(&self) -> &'static str {
        NAME
    }

    async fn transform<'shorter, 'longer: 'shorter>(
        &mut self,
        chain_state: &'shorter mut ChainState<'longer>,
    ) -> Result<Messages> {
        for request in chain_state.requests.iter_mut() {
            let filter_out = match &self.filter {
                Filter::AllowList(allow_list) => !allow_list.contains(&request.get_query_type()),
                Filter::DenyList(deny_list) => deny_list.contains(&request.get_query_type()),
            };

            if filter_out {
                self.filtered_requests.insert(
                    request.id(),
                    request
                        .from_request_to_error_response(
                            "Message was filtered out by shotover".to_owned(),
                        )
                        .map_err(|e| e.context("Failed to filter message"))?,
                );
                request.replace_with_dummy();
            }
        }

        let mut responses = chain_state.call_next_transform().await?;
        for response in responses.iter_mut() {
            if let Some(request_id) = response.request_id() {
                if let Some(error_response) = self.filtered_requests.remove(&request_id) {
                    *response = error_response;
                }
            }
        }

        Ok(responses)
    }
}

#[cfg(all(test, feature = "valkey"))]
mod test {
    use super::Filter;
    use crate::frame::Frame;
    use crate::frame::ValkeyFrame;
    use crate::message::MessageIdMap;
    use crate::message::{Message, QueryType};
    use crate::transforms::chain::TransformAndMetrics;
    use crate::transforms::filter::QueryTypeFilter;
    use crate::transforms::loopback::Loopback;
    use crate::transforms::{ChainState, Transform};
    use pretty_assertions::assert_eq;

    #[tokio::test(flavor = "multi_thread")]
    async fn test_filter_denylist() {
        let mut filter_transform = QueryTypeFilter {
            filter: Filter::DenyList(vec![QueryType::Read]),
            filtered_requests: MessageIdMap::default(),
        };

        let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))];

        let messages: Vec<_> = (0..26)
            .map(|i| {
                if i % 2 == 0 {
                    Message::from_frame(Frame::Valkey(ValkeyFrame::Array(vec![
                        ValkeyFrame::BulkString("GET".into()),
                        ValkeyFrame::BulkString("key".into()),
                    ])))
                } else {
                    Message::from_frame(Frame::Valkey(ValkeyFrame::Array(vec![
                        ValkeyFrame::BulkString("SET".into()),
                        ValkeyFrame::BulkString("key".into()),
                        ValkeyFrame::BulkString("value".into()),
                    ])))
                }
            })
            .collect();

        let mut chain_state = ChainState::new_test(messages);
        chain_state.reset(&mut chain);
        let result = filter_transform.transform(&mut chain_state).await.unwrap();

        assert_eq!(result.len(), 26);

        for (i, mut message) in result.into_iter().enumerate() {
            if let Some(frame) = message.frame() {
                if i % 2 == 0 {
                    assert_eq!(
                        frame,
                        &Frame::Valkey(ValkeyFrame::Error(
                            "ERR Message was filtered out by shotover".into()
                        )),
                    )
                } else {
                    assert_eq!(
                        frame,
                        &Frame::Valkey(ValkeyFrame::Array(vec![
                            ValkeyFrame::BulkString("SET".into()),
                            ValkeyFrame::BulkString("key".into()),
                            ValkeyFrame::BulkString("value".into()),
                        ]))
                    )
                }
            }
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_filter_allowlist() {
        let mut filter_transform = QueryTypeFilter {
            filter: Filter::AllowList(vec![QueryType::Write]),
            filtered_requests: MessageIdMap::default(),
        };

        let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))];

        let messages: Vec<_> = (0..26)
            .map(|i| {
                if i % 2 == 0 {
                    Message::from_frame(Frame::Valkey(ValkeyFrame::Array(vec![
                        ValkeyFrame::BulkString("GET".into()),
                        ValkeyFrame::BulkString("key".into()),
                    ])))
                } else {
                    Message::from_frame(Frame::Valkey(ValkeyFrame::Array(vec![
                        ValkeyFrame::BulkString("SET".into()),
                        ValkeyFrame::BulkString("key".into()),
                        ValkeyFrame::BulkString("value".into()),
                    ])))
                }
            })
            .collect();

        let mut chain_state = ChainState::new_test(messages);
        chain_state.reset(&mut chain);
        let result = filter_transform.transform(&mut chain_state).await.unwrap();

        assert_eq!(result.len(), 26);

        for (i, mut message) in result.into_iter().enumerate() {
            if let Some(frame) = message.frame() {
                if i % 2 == 0 {
                    assert_eq!(
                        frame,
                        &Frame::Valkey(ValkeyFrame::Error(
                            "ERR Message was filtered out by shotover".into()
                        )),
                    )
                } else {
                    assert_eq!(
                        frame,
                        &Frame::Valkey(ValkeyFrame::Array(vec![
                            ValkeyFrame::BulkString("SET".into()),
                            ValkeyFrame::BulkString("key".into()),
                            ValkeyFrame::BulkString("value".into()),
                        ]))
                    )
                }
            }
        }
    }
}