use crate::{ast::Config, Error, Result};
/// A single selection criterion that can be applied to a [`Config`].
///
/// A filter pairs a [`FilterType`] (what to select on) with a free-form
/// `pattern` whose interpretation depends on that type.
#[derive(Debug, Clone)]
pub struct Filter {
    /// Which aspect of the configuration this filter selects on.
    pub filter_type: FilterType,
    /// The value to match; how it is interpreted depends on `filter_type`.
    pub pattern: String,
}
/// The aspect of a configuration that a [`Filter`] selects on.
///
/// Each variant lists the (case-insensitive) keys accepted when parsing a
/// [`Filter`] from `type=pattern` text.
///
/// The enum carries no data, so it is `Copy` and `Hash`: it can be passed by
/// value and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FilterType {
    /// Select by server name; key `server_name` or `server`.
    ServerName,
    /// Select by port; key `port`. The pattern must parse as a `u16`.
    Port,
    /// Select by upstream; key `upstream`.
    Upstream,
    /// Select by location; key `location`.
    Location,
    /// SSL-only selection; key `ssl` or `ssl_only`.
    SslOnly,
    /// Keep only entries of `Config::directives` whose name equals the
    /// pattern exactly; key `directive`.
    Directive,
}
impl Filter {
#[must_use]
pub fn new(filter_type: FilterType, pattern: impl Into<String>) -> Self {
Self {
filter_type,
pattern: pattern.into(),
}
}
pub fn apply(&self, config: &Config) -> Result<Config> {
let mut filtered = config.clone();
match self.filter_type {
FilterType::ServerName => {
Self::filter_by_server_name(&self.pattern, &mut filtered);
}
FilterType::Port => {
Self::filter_by_port(&self.pattern, &mut filtered)?;
}
FilterType::SslOnly => {
Self::filter_ssl_only(&mut filtered);
}
FilterType::Directive => {
Self::filter_by_directive(&self.pattern, &mut filtered);
}
FilterType::Upstream | FilterType::Location => {
return Err(Error::NotImplemented(format!(
"Filter type {:?} not yet implemented",
self.filter_type
)));
}
}
Ok(filtered)
}
fn filter_by_server_name(_pattern: &str, _config: &mut Config) {
}
fn filter_by_port(pattern: &str, _config: &mut Config) -> Result<()> {
let _target_port: u16 = pattern
.parse()
.map_err(|_| Error::InvalidInput(format!("Invalid port number: {pattern}")))?;
Ok(())
}
fn filter_ssl_only(_config: &mut Config) {
}
fn filter_by_directive(directive_name: &str, config: &mut Config) {
config.directives.retain(|d| d.name() == directive_name);
}
}
impl std::str::FromStr for Filter {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
let parts: Vec<&str> = s.splitn(2, '=').collect();
if parts.len() != 2 {
return Err(Error::InvalidInput(format!(
"Invalid filter format. Expected: type=pattern, got: {s}"
)));
}
let filter_type = match parts[0].to_lowercase().as_str() {
"server_name" | "server" => FilterType::ServerName,
"port" => FilterType::Port,
"upstream" => FilterType::Upstream,
"location" => FilterType::Location,
"ssl" | "ssl_only" => FilterType::SslOnly,
"directive" => FilterType::Directive,
other => {
return Err(Error::InvalidInput(format!("Unknown filter type: {other}")));
}
};
Ok(Self::new(filter_type, parts[1]))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_filter() {
        let filter: Filter = "server_name=*.example.com".parse().unwrap();
        assert_eq!(filter.filter_type, FilterType::ServerName);
        assert_eq!(filter.pattern, "*.example.com");
    }

    #[test]
    fn test_parse_port_filter() {
        let filter: Filter = "port=443".parse().unwrap();
        assert_eq!(filter.filter_type, FilterType::Port);
        assert_eq!(filter.pattern, "443");
    }

    #[test]
    fn test_parse_ssl_filter() {
        let filter: Filter = "ssl_only=true".parse().unwrap();
        assert_eq!(filter.filter_type, FilterType::SslOnly);
    }

    #[test]
    fn test_parse_directive_filter() {
        let filter: Filter = "directive=proxy_pass".parse().unwrap();
        assert_eq!(filter.filter_type, FilterType::Directive);
        assert_eq!(filter.pattern, "proxy_pass");
    }

    #[test]
    fn test_invalid_filter_format() {
        let result: Result<Filter> = "invalid".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_filter_type() {
        let result: Result<Filter> = "unknown_type=value".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_filter_creation() {
        let filter = Filter::new(FilterType::Port, "8080");
        assert_eq!(filter.filter_type, FilterType::Port);
        assert_eq!(filter.pattern, "8080");
    }

    #[test]
    fn test_port_filter_validates_number() {
        let config = Config::default();
        let filter = Filter::new(FilterType::Port, "not_a_number");
        let result = filter.apply(&config);
        assert!(result.is_err());
    }

    /// Type keys accept aliases and are matched case-insensitively.
    #[test]
    fn test_parse_aliases_and_case() {
        let server: Filter = "server=example.com".parse().unwrap();
        assert_eq!(server.filter_type, FilterType::ServerName);

        let ssl: Filter = "SSL=yes".parse().unwrap();
        assert_eq!(ssl.filter_type, FilterType::SslOnly);

        let port: Filter = "PORT=8443".parse().unwrap();
        assert_eq!(port.filter_type, FilterType::Port);
        assert_eq!(port.pattern, "8443");
    }

    /// Only the first `=` separates type from pattern.
    #[test]
    fn test_pattern_may_contain_equals() {
        let filter: Filter = "directive=a=b".parse().unwrap();
        assert_eq!(filter.filter_type, FilterType::Directive);
        assert_eq!(filter.pattern, "a=b");
    }

    #[test]
    fn test_unimplemented_filter_types_error() {
        let config = Config::default();
        for filter_type in [FilterType::Upstream, FilterType::Location] {
            let result = Filter::new(filter_type, "anything").apply(&config);
            assert!(matches!(result, Err(Error::NotImplemented(_))));
        }
    }

    /// A number outside the `u16` range is rejected as invalid input.
    #[test]
    fn test_out_of_range_port_is_invalid_input() {
        let result = Filter::new(FilterType::Port, "70000").apply(&Config::default());
        assert!(matches!(result, Err(Error::InvalidInput(_))));
    }
}