ic_bn_lib/http/middleware/
waf.rs

1use std::{
2    net::IpAddr,
3    num::NonZeroU32,
4    path::PathBuf,
5    str::FromStr,
6    sync::Arc,
7    task::{Context, Poll},
8    time::Duration,
9};
10
11use ahash::RandomState;
12use anyhow::{Context as _, anyhow};
13use arc_swap::ArcSwap;
14use async_trait::async_trait;
15use axum::{
16    Router,
17    extract::{Request as AxumRequest, State},
18    response::{IntoResponse, Response as AxumResponse},
19    routing::post,
20};
21use bytes::Bytes;
22use futures::future::BoxFuture;
23use governor::{
24    Quota, RateLimiter,
25    clock::{Clock, DefaultClock},
26    state::{InMemoryState, NotKeyed, keyed::DashMapStateStore},
27};
28use http::{
29    HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode, Version,
30    header::{HOST, RETRY_AFTER},
31};
32use humantime::parse_duration;
33use ic_bn_lib_common::{
34    traits::{Run, http::Client},
35    types::http::{Error, WafCli},
36};
37use itertools::Itertools;
38use regex::Regex;
39use serde::Deserialize;
40use serde_with::{DeserializeFromStr, DisplayFromStr, serde_as};
41use tokio::{fs, select, time::interval};
42use tokio_util::sync::CancellationToken;
43use tower::{Layer, Service};
44use tracing::warn;
45use url::Url;
46
47use crate::http::middleware::extract_ip_from_request;
48
/// Matches a single HTTP status code or an inclusive range of codes.
///
/// Built from strings such as `404` or `400-499` via the `FromStr` impl.
#[derive(Debug, Clone, Copy, Eq, PartialEq, DeserializeFromStr)]
pub struct StatusRange {
    /// Exact code to match, or the lower bound when `to` is set
    from: u16,
    /// Inclusive upper bound of the range, if any
    to: Option<u16>,
}
55
56impl StatusRange {
57    /// Check status code against the range
58    pub const fn evaluate(&self, v: StatusCode) -> bool {
59        let code = v.as_u16();
60
61        if let Some(to) = self.to {
62            return code >= self.from && code <= to;
63        }
64
65        code == self.from
66    }
67}
68
69impl FromStr for StatusRange {
70    type Err = Error;
71
72    fn from_str(s: &str) -> Result<Self, Self::Err> {
73        let mut it = s.split('-');
74        let (from, to) = (it.next().unwrap(), it.next());
75        let from: u16 = from
76            .trim()
77            .parse()
78            .context("unable to parse status range start")?;
79
80        if !(100..=599).contains(&from) {
81            return Err(anyhow!("Status code can be between 100 and 599, not {from}").into());
82        }
83
84        let to = if let Some(v) = to {
85            let v = v
86                .trim()
87                .parse()
88                .context("unable to parse status range end")?;
89
90            if !(100..=599).contains(&v) {
91                return Err(anyhow!("Status code can be between 100 and 599, not {v}").into());
92            }
93
94            if v <= from {
95                return Err(anyhow!(
96                    "End of the range should be greater than start ({v} > {from})"
97                )
98                .into());
99            }
100
101            Some(v)
102        } else {
103            None
104        };
105
106        Ok(Self { from, to })
107    }
108}
109
/// Matches a header by its exact name and by a regex applied to its value.
#[serde_as]
#[derive(Debug, Clone, Deserialize)]
pub struct HeaderMatcher {
    /// Name of the header to look for
    #[serde_as(as = "DisplayFromStr")]
    pub name: HeaderName,
    /// Pattern tested against the header value.
    /// When deserializing, `value` is accepted as an alias for this field.
    #[serde_as(as = "DisplayFromStr")]
    #[serde(alias = "value")]
    pub regex: Regex,
}
120
121impl PartialEq for HeaderMatcher {
122    fn eq(&self, other: &Self) -> bool {
123        self.name == other.name && self.regex.as_str() == other.regex.as_str()
124    }
125}
126impl Eq for HeaderMatcher {}
127
128impl Ord for HeaderMatcher {
129    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
130        self.name.as_str().cmp(other.name.as_str())
131    }
132}
133impl PartialOrd for HeaderMatcher {
134    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
135        Some(self.cmp(other))
136    }
137}
138
139impl HeaderMatcher {
140    /// Check if the header name/value matches
141    pub fn evaluate(&self, name: &HeaderName, value: &HeaderValue) -> bool {
142        if name != self.name {
143            return false;
144        }
145
146        let Ok(value) = value.to_str() else {
147            return false;
148        };
149
150        self.regex.is_match(value)
151    }
152
153    /// Check if the header map matches
154    pub fn evaluate_headermap(&self, map: &HeaderMap) -> bool {
155        map.iter().any(|(name, value)| self.evaluate(name, value))
156    }
157}
158
/// Matches against HTTP requests.
///
/// Every criterion that is set must match; criteria left as `None` are ignored.
#[serde_as]
#[derive(Debug, Clone, Deserialize)]
pub struct RequestMatcher {
    /// Regex tested against the request host
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub host: Option<Regex>,
    /// Regex tested against the request path and query string
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub path: Option<Regex>,
    /// Accepted methods; the request method must be one of them
    #[serde_as(as = "Option<Vec<DisplayFromStr>>")]
    pub methods: Option<Vec<Method>>,
    /// Header matchers, all of which must be satisfied by the request headers
    pub headers: Option<Vec<HeaderMatcher>>,
}
171
172impl PartialEq for RequestMatcher {
173    fn eq(&self, other: &Self) -> bool {
174        self.methods == other.methods
175        // Sort header matchers before comparison
176            && self
177                .headers
178                .as_ref()
179                .map(|x| x.clone().into_iter().sorted().collect::<Vec<_>>())
180                == other
181                    .headers
182                    .as_ref()
183                    .map(|x| x.clone().into_iter().sorted().collect::<Vec<_>>())
184            && self.host.as_ref().map(|x| x.as_str()) == other.host.as_ref().map(|x| x.as_str())
185            && self.path.as_ref().map(|x| x.as_str()) == other.path.as_ref().map(|x| x.as_str())
186    }
187}
188impl Eq for RequestMatcher {}
189
190impl RequestMatcher {
191    /// Check if the request matches
192    pub fn evaluate<T>(&self, req: &Request<T>) -> bool {
193        // Check if host matches
194        if let Some(v) = &self.host {
195            let host = match req.version() {
196                // With <HTTP/2 the host portion of the URI is not populated,
197                // so extract the Host header.
198                Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => req
199                    .headers()
200                    .get(HOST)
201                    .and_then(|x| x.to_str().ok())
202                    .unwrap_or_default(),
203
204                // With >=HTTP/2 it is the other way around - there's no Host header.
205                _ => req.uri().host().unwrap_or_default(),
206            };
207
208            if !v.is_match(host) {
209                return false;
210            }
211        }
212
213        // Check if path matches
214        if let Some(v) = &self.path
215            && !v.is_match(
216                req.uri()
217                    .path_and_query()
218                    .map(|x| x.as_str())
219                    .unwrap_or_default(),
220            )
221        {
222            return false;
223        }
224
225        // Check if any methods match
226        if let Some(v) = &self.methods
227            && !v.iter().contains(req.method())
228        {
229            return false;
230        }
231
232        // Check that all of header rules match
233        if let Some(v) = &self.headers
234            && !v.iter().all(|rule| rule.evaluate_headermap(req.headers()))
235        {
236            return false;
237        }
238
239        // Empty rule matches anything
240        true
241    }
242}
243
/// Matches against HTTP responses.
///
/// Every criterion that is set must match; criteria left as `None` are ignored.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct ResponseMatcher {
    /// Header matchers, all of which must be satisfied by the response headers
    pub headers: Option<Vec<HeaderMatcher>>,
    /// Status codes/ranges; the response status must fall into at least one of them
    pub status: Option<Vec<StatusRange>>,
}
251
252impl ResponseMatcher {
253    /// Check if the response matches
254    pub fn evaluate<T>(&self, req: &Response<T>) -> bool {
255        // Check status codes
256        if let Some(v) = &self.status
257            && !v.iter().any(|x| x.evaluate(req.status()))
258        {
259            return false;
260        }
261
262        // Check that all of header rules match
263        if let Some(v) = &self.headers
264            && !v.iter().all(|rule| rule.evaluate_headermap(req.headers()))
265        {
266            return false;
267        }
268
269        // Empty rule matches anything
270        true
271    }
272}
273
/// Decision on the rate limit
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum RateLimitDecision {
    /// The request is within the limit
    Pass,
    /// The request exceeds the limit; carries the wait time reported by the limiter
    Throttle(Duration),
}
280
/// Type of the rate limiting applied.
///
/// Each variant keeps the `Quota` next to the limiter built from it;
/// the quota is what the `PartialEq` impl compares.
#[derive(Debug)]
pub enum RateLimitType {
    /// One shared limit for all requests
    Global(Quota, RateLimiter<NotKeyed, InMemoryState, DefaultClock>),
    /// A separate limit for each client IP address
    PerIp(
        Quota,
        RateLimiter<IpAddr, DashMapStateStore<IpAddr, RandomState>, DefaultClock>,
    ),
}
290
291impl FromStr for RateLimitType {
292    type Err = Error;
293
294    fn from_str(s: &str) -> Result<Self, Self::Err> {
295        let Some((typ, limit)) = s.split_once(':') else {
296            return Err(anyhow!("expecting limit in 'type:rate' format").into());
297        };
298
299        let Some((rate, dur)) = limit.split_once("/") else {
300            return Err(anyhow!("expecting rate in 'rate/duration' format").into());
301        };
302
303        let rate = rate.parse::<u32>().context("unable to parse rate as u32")?;
304        let dur = parse_duration(dur).context("unable to parse duration")?;
305
306        if rate == 0 {
307            return Err(anyhow!("rate must be > 0").into());
308        }
309
310        if dur == Duration::ZERO {
311            return Err(anyhow!("duration cannot be zero").into());
312        }
313
314        // We already checked that rate is > 0
315        let replenish_period = dur / rate;
316        let quota = Quota::with_period(replenish_period)
317            .unwrap()
318            .allow_burst(NonZeroU32::new(rate).unwrap());
319
320        Ok(match typ {
321            "global" => Self::Global(quota, RateLimiter::direct(quota)),
322            "per_ip" => Self::PerIp(
323                quota,
324                RateLimiter::dashmap_with_hasher(quota, RandomState::new()),
325            ),
326            _ => return Err(anyhow!("unknown rate limiter type {typ}").into()),
327        })
328    }
329}
330
331impl PartialEq for RateLimitType {
332    fn eq(&self, other: &Self) -> bool {
333        match (self, other) {
334            (Self::Global(q1, _), Self::Global(q2, _)) => q1 == q2,
335            (Self::PerIp(q1, _), Self::PerIp(q2, _)) => q1 == q2,
336            _ => false,
337        }
338    }
339}
340impl Eq for RateLimitType {}
341
342impl RateLimitType {
343    /// Evaluate the request against the rate limit
344    pub fn allowed<B>(&self, req: &Request<B>) -> RateLimitDecision {
345        let (clock, r) = match self {
346            Self::Global(_, v) => (v.clock(), v.check()),
347            Self::PerIp(_, v) => {
348                // Allow if we fail to extract IP.
349                // It shouldn't happen ever under normal workload
350                // and it's probably better to allow the request in this case.
351                let Some(ip) = extract_ip_from_request(req) else {
352                    return RateLimitDecision::Pass;
353                };
354
355                (v.clock(), v.check_key(&ip))
356            }
357        };
358
359        if let Err(e) = r {
360            let dur = e.wait_time_from(clock.now());
361            return RateLimitDecision::Throttle(dur);
362        }
363
364        RateLimitDecision::Pass
365    }
366}
367
/// Action that applies to the requests.
///
/// Parsed from `pass`, `block`, `block:<status>` or `limit:<type>:<rate>/<duration>`.
#[derive(Debug, PartialEq, Eq)]
pub enum RequestAction {
    /// Let the request through
    Pass,
    /// Reject the request with the given status code
    Block(StatusCode),
    /// Apply the given rate limit to the request
    RateLimit(RateLimitType),
}
375
376impl FromStr for RequestAction {
377    type Err = Error;
378
379    fn from_str(s: &str) -> Result<Self, Self::Err> {
380        if s == "pass" {
381            return Ok(Self::Pass);
382        }
383
384        let mut it = s.split(':');
385        let (pfx, sfx) = (it.next().unwrap(), it.next());
386        if pfx == "block" {
387            let code = if let Some(code) = sfx {
388                StatusCode::from_str(code).context("unable to parse status code")?
389            } else {
390                StatusCode::FORBIDDEN
391            };
392
393            return Ok(Self::Block(code));
394        }
395
396        if pfx == "limit" {
397            let Some((_, v)) = s.split_once(':') else {
398                return Err(anyhow!("expecting limit definition after ':'").into());
399            };
400
401            return Ok(Self::RateLimit(RateLimitType::from_str(v)?));
402        }
403
404        Err(anyhow!("unsupported action format").into())
405    }
406}
407
/// Action that applies to the responses.
///
/// Parsed from `pass`, `block` or `block:<status>`.
#[derive(Debug, PartialEq, Eq)]
pub enum ResponseAction {
    /// Let the response through
    Pass,
    /// Replace the response with a block response carrying the given status code
    Block(StatusCode),
}
414
415impl FromStr for ResponseAction {
416    type Err = Error;
417
418    fn from_str(s: &str) -> Result<Self, Self::Err> {
419        if s == "pass" {
420            return Ok(Self::Pass);
421        }
422
423        let mut it = s.split(':');
424        let (pfx, sfx) = (it.next().unwrap(), it.next());
425        if pfx == "block" {
426            let code = if let Some(code) = sfx {
427                StatusCode::from_str(code).context("unable to parse status code")?
428            } else {
429                StatusCode::FORBIDDEN
430            };
431
432            return Ok(Self::Block(code));
433        }
434
435        Err(anyhow!("unsupported action format").into())
436    }
437}
438
/// Outcome of the rule evaluation
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Decision {
    /// Nothing to enforce, processing continues
    Pass,
    /// Respond with the given status code
    Block(StatusCode),
    /// Respond with 429; the duration is used to compute the `Retry-After` header
    Throttle(Duration),
}
446
447impl IntoResponse for Decision {
448    fn into_response(self) -> AxumResponse {
449        match self {
450            Self::Pass => StatusCode::OK.into_response(),
451            Self::Block(v) => (v, "Blocked for policy reasons").into_response(),
452            Self::Throttle(v) => (
453                StatusCode::TOO_MANY_REQUESTS,
454                [(
455                    RETRY_AFTER,
456                    HeaderValue::from_str(&(v.as_secs() + 1).to_string()).unwrap(),
457                )],
458                "Request was rate-limited, consult Retry-After header for a number of seconds after which it can be retried",
459            )
460                .into_response(),
461        }
462    }
463}
464
/// Request rule: an action to take on requests that satisfy the matcher
#[serde_as]
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct RequestRule {
    /// Criteria the request must satisfy; `match` is accepted as an alias when deserializing
    #[serde(alias = "match")]
    pub matcher: RequestMatcher,
    /// Action applied when the matcher matches, parsed from its string form
    #[serde_as(as = "DisplayFromStr")]
    pub action: RequestAction,
}
474
475impl RequestRule {
476    pub fn evaluate<B>(&self, req: &Request<B>) -> Option<Decision> {
477        if !self.matcher.evaluate(req) {
478            return None;
479        }
480
481        Some(match &self.action {
482            RequestAction::Pass => Decision::Pass,
483            RequestAction::Block(v) => Decision::Block(*v),
484            RequestAction::RateLimit(v) => match v.allowed(req) {
485                RateLimitDecision::Pass => Decision::Pass,
486                RateLimitDecision::Throttle(v) => Decision::Throttle(v),
487            },
488        })
489    }
490}
491
/// Response rule: an action to take on responses that satisfy the matchers
#[serde_as]
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct ResponseRule {
    /// Optional criteria for the request that produced the response;
    /// `match_req` is accepted as an alias when deserializing
    #[serde(alias = "match_req")]
    pub matcher_req: Option<RequestMatcher>,
    /// Criteria the response must satisfy;
    /// `match_resp` is accepted as an alias when deserializing
    #[serde(alias = "match_resp")]
    pub matcher: ResponseMatcher,
    /// Action applied when the matchers match, parsed from its string form
    #[serde_as(as = "DisplayFromStr")]
    pub action: ResponseAction,
}
503
504impl ResponseRule {
505    pub fn evaluate<B1, B2>(&self, req: &Request<B1>, resp: &Response<B2>) -> Option<Decision> {
506        if let Some(v) = &self.matcher_req
507            && !v.evaluate(req)
508        {
509            return None;
510        }
511
512        if !self.matcher.evaluate(resp) {
513            return None;
514        }
515
516        Some(match &self.action {
517            ResponseAction::Pass => Decision::Pass,
518            ResponseAction::Block(v) => Decision::Block(*v),
519        })
520    }
521}
522
/// Ruleset: ordered lists of request and response rules.
/// The first rule that matches determines the decision.
#[derive(Debug, PartialEq, Eq, Deserialize, Default)]
pub struct Ruleset {
    /// Rules evaluated against incoming requests
    pub requests: Option<Vec<RequestRule>>,
    /// Rules evaluated against responses (together with the originating request)
    pub responses: Option<Vec<ResponseRule>>,
}
529
530impl Ruleset {
531    fn is_empty(&self) -> bool {
532        (self.requests.is_none() || self.requests.as_ref().map(|x| x.is_empty()) == Some(true))
533            && (self.responses.is_none()
534                || self.responses.as_ref().map(|x| x.is_empty()) == Some(true))
535    }
536
537    /// Evaluate given request against ruleset
538    fn evaluate_request<B>(&self, req: &Request<B>) -> Decision {
539        let Some(v) = &self.requests else {
540            return Decision::Pass;
541        };
542
543        v.iter()
544            .find_map(|x| x.evaluate(req))
545            .unwrap_or(Decision::Pass)
546    }
547
548    /// Evaluate given request parts & response against ruleset
549    fn evaluate_response<B1, B2>(&self, req: &Request<B1>, resp: &Response<B2>) -> Decision {
550        let Some(v) = &self.responses else {
551            return Decision::Pass;
552        };
553
554        v.iter()
555            .find_map(|x| x.evaluate(req, resp))
556            .unwrap_or(Decision::Pass)
557    }
558}
559
/// Web Application Firewall: a service wrapper that evaluates requests and
/// responses of the inner service against the current ruleset.
#[derive(Debug, Clone)]
pub struct Waf<S> {
    /// Swappable ruleset handle; a snapshot is loaded for each call
    ruleset: Arc<ArcSwap<Ruleset>>,
    /// Wrapped service
    inner: S,
}
566
567/// Implement Tower Service for Waf
568impl<S> Service<AxumRequest> for Waf<S>
569where
570    S: Service<AxumRequest, Response = AxumResponse> + Send + 'static,
571    S::Future: Send + 'static,
572{
573    type Response = S::Response;
574    type Error = S::Error;
575    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
576
577    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
578        self.inner.poll_ready(cx)
579    }
580
581    fn call(&mut self, request: AxumRequest) -> Self::Future {
582        let ruleset = self.ruleset.load_full();
583        // Fast-path
584        if ruleset.is_empty() {
585            let fut = self.inner.call(request);
586            return Box::pin(fut);
587        }
588
589        // Evaluate the request
590        let decision = ruleset.evaluate_request(&request);
591        if decision != Decision::Pass {
592            return Box::pin(async move { Ok(decision.into_response()) });
593        }
594
595        // Clone the parts for later evaluation
596        let (parts, body) = request.into_parts();
597        let parts_clone = parts.clone();
598        let request = AxumRequest::from_parts(parts, body);
599
600        let future = self.inner.call(request);
601        Box::pin(async move {
602            let response: AxumResponse = future.await?;
603
604            // Evaluate the response
605            let req = Request::from_parts(parts_clone, ());
606            let decision = ruleset.evaluate_response(&req, &response);
607            if decision != Decision::Pass {
608                return Ok(decision.into_response());
609            }
610
611            Ok(response)
612        })
613    }
614}
615
616fn parse_ruleset(data: &[u8]) -> Result<Ruleset, Error> {
617    let ruleset: Ruleset = serde_json::from_slice(data)
618        .context("unable to parse ruleset as JSON")
619        .or_else(|_| serde_yaml_ng::from_slice(data))
620        .context("unable to parse ruleset as YAML")?;
621
622    Ok(ruleset)
623}
624
/// Trait to fetch ruleset from some source
#[async_trait]
trait FetchesRuleset: Send + Sync {
    /// Fetches and parses the ruleset, or returns an error if that is not possible
    async fn fetch_rules(&self) -> Result<Ruleset, Error>;
}
630
631/// Fetches ruleset from the file
632struct RulesetFetcherFile {
633    path: PathBuf,
634}
635
636#[async_trait]
637impl FetchesRuleset for RulesetFetcherFile {
638    async fn fetch_rules(&self) -> Result<Ruleset, Error> {
639        let data = fs::read(&self.path)
640            .await
641            .context("unable to read ruleset from file")?;
642
643        parse_ruleset(&data)
644    }
645}
646
647/// Fetches ruleset from the URL
648struct RulesetFetcherUrl {
649    http_client: Arc<dyn Client>,
650    url: Url,
651}
652
653#[async_trait]
654impl FetchesRuleset for RulesetFetcherUrl {
655    async fn fetch_rules(&self) -> Result<Ruleset, Error> {
656        let req = reqwest::Request::new(Method::GET, self.url.clone());
657        let resp = self
658            .http_client
659            .execute(req)
660            .await
661            .context("unable to execute request")?;
662        let data = resp.bytes().await.context("unable to get response body")?;
663
664        parse_ruleset(&data)
665    }
666}
667
/// Waf layer usable as an Axum middleware
#[derive(Clone, derive_new::new)]
pub struct WafLayer {
    /// Ruleset handle handed to every `Waf` service created by this layer
    ruleset: Arc<ArcSwap<Ruleset>>,
    /// Optional source used to refresh the ruleset
    fetcher: Option<Arc<dyn FetchesRuleset>>,
    /// Period between ruleset refresh attempts
    interval: Duration,
}
675
676impl<S> Layer<S> for WafLayer {
677    type Service = Waf<S>;
678
679    fn layer(&self, inner: S) -> Self::Service {
680        Waf {
681            ruleset: self.ruleset.clone(),
682            inner,
683        }
684    }
685}
686
687/// API handler to update the ruleset.
688/// Supports JSON and YAML.
689async fn api_handler(State(state): State<WafLayer>, body: Bytes) -> AxumResponse {
690    let ruleset = match parse_ruleset(&body) {
691        Ok(v) => v,
692        Err(e) => {
693            return (
694                StatusCode::BAD_REQUEST,
695                format!("Unable to parse ruleset: {e:#}"),
696            )
697                .into_response();
698        }
699    };
700
701    warn!("WAF: Ruleset updated over API");
702    if state.set_ruleset(ruleset) {
703        "Ruleset updated\n"
704    } else {
705        "Ruleset is the same, not updated\n"
706    }
707    .into_response()
708}
709
710/// Create an API router for nesting
711pub fn create_router<S: Send + Sync + Clone + 'static>(layer: WafLayer) -> Router<S> {
712    Router::new().route("/update", post(api_handler).with_state(layer))
713}
714
715impl WafLayer {
716    /// Create a new layer from provided CLI and optional HTTP client
717    pub fn new_from_cli(cli: &WafCli, http_client: Option<Arc<dyn Client>>) -> Result<Self, Error> {
718        let fetcher = if let Some(v) = &cli.waf_url {
719            let Some(http_client) = http_client else {
720                return Err(anyhow!("URL source requires HTTP client").into());
721            };
722
723            Some(Arc::new(RulesetFetcherUrl {
724                http_client,
725                url: v.clone(),
726            }) as Arc<dyn FetchesRuleset>)
727        } else {
728            cli.waf_file.as_ref().map(|v| {
729                Arc::new(RulesetFetcherFile { path: v.clone() }) as Arc<dyn FetchesRuleset>
730            })
731        };
732
733        Ok(Self {
734            ruleset: Arc::new(ArcSwap::new(Arc::new(Ruleset::default()))),
735            fetcher,
736            interval: cli.waf_interval,
737        })
738    }
739
740    /// Updates the ruleset, but only if the new ruleset is different.
741    pub fn set_ruleset(&self, new: Ruleset) -> bool {
742        let new = Arc::new(new);
743
744        // Check if the new ruleset is different
745        if new == self.ruleset.load_full() {
746            return false;
747        }
748
749        self.ruleset.store(new);
750        true
751    }
752
753    async fn update_ruleset(&self) {
754        let Some(fetcher) = &self.fetcher else {
755            return;
756        };
757
758        let ruleset = match fetcher.fetch_rules().await {
759            Ok(v) => v,
760            Err(e) => {
761                warn!("WAF: unable to fetch ruleset: {e:#}");
762                return;
763            }
764        };
765
766        if self.set_ruleset(ruleset) {
767            warn!("WAF: Ruleset was updated");
768        }
769    }
770}
771
772#[async_trait]
773impl Run for WafLayer {
774    async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> {
775        let mut interval = interval(self.interval);
776        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
777
778        loop {
779            select! {
780                () = token.cancelled() => return Ok(()),
781                _ = interval.tick() => self.update_ruleset().await,
782            }
783        }
784    }
785}
786
787#[cfg(test)]
788mod test {
789    use axum::{Router, body::Body};
790    use serde_json::json;
791
792    use super::*;
793
794    #[test]
795    fn test_status_range() {
796        assert_eq!(
797            StatusRange::from_str("100 -   200").unwrap(),
798            StatusRange {
799                from: 100,
800                to: Some(200)
801            }
802        );
803
804        assert_eq!(
805            StatusRange::from_str("100").unwrap(),
806            StatusRange {
807                from: 100,
808                to: None
809            }
810        );
811
812        assert!(StatusRange::from_str("").is_err());
813        assert!(StatusRange::from_str("+").is_err());
814        assert!(StatusRange::from_str("-").is_err());
815        assert!(StatusRange::from_str("99").is_err());
816        assert!(StatusRange::from_str("99-600").is_err());
817        assert!(StatusRange::from_str("99-").is_err());
818        assert!(StatusRange::from_str("-500").is_err());
819        assert!(StatusRange::from_str("101-600").is_err());
820        assert!(StatusRange::from_str("199-100").is_err());
821
822        let range = StatusRange::from_str("200-499").unwrap();
823
824        assert!(range.evaluate(StatusCode::OK));
825        assert!(range.evaluate(StatusCode::ACCEPTED));
826        assert!(range.evaluate(StatusCode::PERMANENT_REDIRECT));
827        assert!(range.evaluate(StatusCode::NOT_FOUND));
828        assert!(!range.evaluate(StatusCode::CONTINUE));
829        assert!(!range.evaluate(StatusCode::INTERNAL_SERVER_ERROR));
830        assert!(!range.evaluate(StatusCode::SERVICE_UNAVAILABLE));
831
832        let range = StatusRange::from_str("200").unwrap();
833        assert!(range.evaluate(StatusCode::OK));
834        assert!(!range.evaluate(StatusCode::ACCEPTED));
835    }
836
837    #[test]
838    fn test_request() {
839        let rule = json!({
840            "methods": ["GET", "OPTIONS"],
841            "headers": [
842                {
843                    "name": "foo",
844                    "regex": "^bar.*$"
845                },
846                {
847                    "name": "dead",
848                    "regex": "^beef.*$"
849                }
850            ],
851            "host": "^lala",
852            "path": "^/foo",
853        })
854        .to_string();
855
856        let rule: RequestMatcher = serde_json::from_str(&rule).unwrap();
857        assert_eq!(
858            rule,
859            RequestMatcher {
860                methods: Some(vec![Method::GET, Method::OPTIONS]),
861                headers: Some(vec![
862                    HeaderMatcher {
863                        name: HeaderName::from_static("foo"),
864                        regex: Regex::from_str("^bar.*$").unwrap(),
865                    },
866                    HeaderMatcher {
867                        name: HeaderName::from_static("dead"),
868                        regex: Regex::from_str("^beef.*$").unwrap(),
869                    }
870                ]),
871                host: Some(Regex::from_str("^lala").unwrap()),
872                path: Some(Regex::from_str("^/foo").unwrap()),
873            }
874        );
875
876        // Test full matches
877        let req = Request::builder()
878            .header("foo", "barfuss")
879            .header("dead", "beefbeef")
880            .method(Method::GET)
881            .version(Version::HTTP_2)
882            .uri("https://lala/foo")
883            .body("")
884            .unwrap();
885        assert!(rule.evaluate(&req));
886
887        for http_ver in [Version::HTTP_09, Version::HTTP_10, Version::HTTP_11] {
888            let req = Request::builder()
889                .header("foo", "barfuss")
890                .header("dead", "beefbeef")
891                .header("host", "lala")
892                .version(http_ver)
893                .method(Method::OPTIONS)
894                .uri("https://lala/foo")
895                .body("")
896                .unwrap();
897            assert!(rule.evaluate(&req));
898        }
899
900        let req = Request::builder()
901            .header("dead", "beefbeef")
902            .header("foo", "barfuss")
903            .version(Version::HTTP_2)
904            .method(Method::OPTIONS)
905            .uri("https://lala/foo")
906            .body("")
907            .unwrap();
908        assert!(rule.evaluate(&req));
909
910        // Test partial matches (no match)
911        let req = Request::builder()
912            .header("foo", "barfuss")
913            .header("dead", "beefbeef")
914            .method(Method::POST)
915            .uri("https://lala/foo")
916            .body("")
917            .unwrap();
918        assert!(!rule.evaluate(&req));
919
920        let req = Request::builder()
921            .header("foo", "barfuss")
922            .header("dead", "beefbeef")
923            .method(Method::GET)
924            .uri("https://lala/bar")
925            .body("")
926            .unwrap();
927        assert!(!rule.evaluate(&req));
928
929        let req = Request::builder()
930            .header("fox", "barfuss")
931            .header("dead", "beefbeef")
932            .method(Method::GET)
933            .uri("https://lala/foo")
934            .body("")
935            .unwrap();
936        assert!(!rule.evaluate(&req));
937    }
938
939    #[test]
940    fn test_response() {
941        let rule = json!({
942            "status": ["100-200", "307", "400-500"],
943            "headers": [
944                {
945                    "name": "foo",
946                    "regex": "^bar.*$"
947                },
948                {
949                    "name": "dead",
950                    "regex": "^beef.*$"
951                }
952            ],
953        })
954        .to_string();
955
956        let rule: ResponseMatcher = serde_json::from_str(&rule).unwrap();
957        assert_eq!(
958            rule,
959            ResponseMatcher {
960                status: Some(vec![
961                    StatusRange::from_str("100-200").unwrap(),
962                    StatusRange::from_str("307").unwrap(),
963                    StatusRange::from_str("400-500").unwrap()
964                ]),
965                headers: Some(vec![
966                    HeaderMatcher {
967                        name: HeaderName::from_static("foo"),
968                        regex: Regex::from_str("^bar.*$").unwrap(),
969                    },
970                    HeaderMatcher {
971                        name: HeaderName::from_static("dead"),
972                        regex: Regex::from_str("^beef.*$").unwrap(),
973                    }
974                ]),
975            }
976        );
977
978        // Test full matches
979        let resp = Response::builder()
980            .header("foo", "barfuss")
981            .header("dead", "beefbeef")
982            .status(StatusCode::OK)
983            .body("")
984            .unwrap();
985        assert!(rule.evaluate(&resp));
986
987        let resp = Response::builder()
988            .header("foo", "barfuss")
989            .header("dead", "beefbeef")
990            .status(StatusCode::CONTINUE)
991            .body("")
992            .unwrap();
993        assert!(rule.evaluate(&resp));
994
995        let resp = Response::builder()
996            .header("foo", "barfuss")
997            .header("dead", "beefbeef")
998            .status(StatusCode::TEMPORARY_REDIRECT)
999            .body("")
1000            .unwrap();
1001        assert!(rule.evaluate(&resp));
1002
1003        let resp = Response::builder()
1004            .header("foo", "barfuss")
1005            .header("dead", "beefbeef")
1006            .status(StatusCode::NOT_FOUND)
1007            .body("")
1008            .unwrap();
1009        assert!(rule.evaluate(&resp));
1010
1011        // Test partial matches (no match)
1012        let resp = Response::builder()
1013            .header("foo", "barfuss")
1014            .header("dead", "beefbeef")
1015            .status(StatusCode::PERMANENT_REDIRECT)
1016            .body("")
1017            .unwrap();
1018        assert!(!rule.evaluate(&resp));
1019
1020        let resp = Response::builder()
1021            .header("foo", "barfuss")
1022            .header("dead", "zbeefbeef")
1023            .status(StatusCode::OK)
1024            .body("")
1025            .unwrap();
1026        assert!(!rule.evaluate(&resp));
1027    }
1028
/// Verifies `RequestAction::from_str` for the `pass`, `block[:<status>]` and
/// `limit:<type>:<count>/<interval>` forms, including rejected inputs.
#[test]
fn test_request_action() {
    assert_eq!(
        RequestAction::from_str("pass").unwrap(),
        RequestAction::Pass
    );

    // A bare `block` is expected to use 403
    assert_eq!(
        RequestAction::from_str("block").unwrap(),
        RequestAction::Block(StatusCode::FORBIDDEN)
    );

    assert_eq!(
        RequestAction::from_str("block:451").unwrap(),
        RequestAction::Block(StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS)
    );

    // Invalid status codes and unknown actions must be rejected
    assert!(RequestAction::from_str("block:0").is_err());
    assert!(RequestAction::from_str("block:foo").is_err());
    assert!(RequestAction::from_str("foo").is_err());

    // 10 requests per minute is expected to become a 6s period with a burst of 10.
    // NOTE(review): the expected limiter is built from a quota without the burst,
    // so equality presumably compares only the quota part - confirm against the
    // `PartialEq` implementation.
    assert_eq!(
        RequestAction::from_str("limit:global:10/1m").unwrap(),
        RequestAction::RateLimit(RateLimitType::Global(
            Quota::with_period(Duration::from_secs(6))
                .unwrap()
                .allow_burst(NonZeroU32::new(10).unwrap()),
            RateLimiter::direct(Quota::with_period(Duration::from_secs(6)).unwrap())
        ))
    );

    assert_eq!(
        RequestAction::from_str("limit:per_ip:10/1m").unwrap(),
        RequestAction::RateLimit(RateLimitType::PerIp(
            Quota::with_period(Duration::from_secs(6))
                .unwrap()
                .allow_burst(NonZeroU32::new(10).unwrap()),
            RateLimiter::dashmap_with_hasher(
                Quota::with_period(Duration::from_secs(6)).unwrap(),
                RandomState::new()
            )
        ))
    );

    // Structurally incomplete rate limits (limiter type is missing)
    assert!(RequestAction::from_str("limit").is_err());
    assert!(RequestAction::from_str("limit:").is_err());
    assert!(RequestAction::from_str("limit:foo").is_err());
    assert!(RequestAction::from_str("limit:0/1s").is_err());
    assert!(RequestAction::from_str("limit:1/0s").is_err());
    assert!(RequestAction::from_str("limit:1/foo").is_err());

    // Structurally complete rate limits with an invalid component. The cases
    // above are already rejected for the missing type, so these are the ones
    // that exercise the count, interval and type validation themselves.
    assert!(RequestAction::from_str("limit:global:0/1s").is_err());
    assert!(RequestAction::from_str("limit:global:1/0s").is_err());
    assert!(RequestAction::from_str("limit:global:1/foo").is_err());
    assert!(RequestAction::from_str("limit:per_ip:0/1s").is_err());
    assert!(RequestAction::from_str("limit:foo:10/1m").is_err());
}
1080
/// Verifies `ResponseAction::from_str` against accepted and rejected inputs.
#[test]
fn test_response_action() {
    // Inputs that must parse, paired with the action they must yield
    let accepted = [
        ("pass", ResponseAction::Pass),
        ("block", ResponseAction::Block(StatusCode::FORBIDDEN)),
        (
            "block:451",
            ResponseAction::Block(StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS),
        ),
    ];

    for (input, expected) in accepted {
        assert_eq!(ResponseAction::from_str(input).unwrap(), expected);
    }

    // Inputs that must be rejected
    for input in ["block:0", "block:foo", "foo"] {
        assert!(ResponseAction::from_str(input).is_err());
    }
}
1102
/// Verifies `Ruleset` deserialization, emptiness detection and the decisions
/// produced for requests and responses.
#[test]
fn test_ruleset() {
    // Serializes the JSON value and deserializes it into a `Ruleset`
    let parse = |spec: serde_json::Value| -> Ruleset {
        serde_json::from_str(&spec.to_string()).unwrap()
    };

    // Only request rules present
    let requests_only = parse(json!({
        "requests": [
        {
            "match": {
                "methods": ["GET", "POST"],
                "host": "^foo",
                "path": "^/bar"
            },
            "action": "limit:global:10/1h",
        }]
    }));
    assert!(!requests_only.is_empty());

    // Only response rules present
    let responses_only = parse(json!({
        "responses": [
        {
            "match_req": {
                "methods": ["OPTIONS"],
            },
            "match_resp": {
                "status": ["100-200", "400-499", "599"],
            },
            "action": "block:499",
        }]
    }));
    assert!(!responses_only.is_empty());

    // Both lists present, but without any rules
    let empty_lists = parse(json!({
        "requests": [],
        "responses": [],
    }));
    assert!(empty_lists.is_empty());

    // No lists at all
    let empty_object = parse(json!({}));
    assert!(empty_object.is_empty());

    // The ruleset that is evaluated below
    let ruleset = parse(json!({
        "requests": [
        {
            "match": {
                "methods": ["GET", "POST"],
                "host": "^foo",
                "path": "^/bar"
            },
            "action": "limit:global:10/1h",
        },
        {
            "match": {
                "methods": ["DELETE"],
            },
            "action": "block:403",
        }],

        "responses": [
        {
            "match_req": {
                "methods": ["OPTIONS"],
            },
            "match_resp": {
                "status": ["100-200", "400-499", "599"],
            },
            "action": "block:499",
        },
        {
            "match_resp": {
                "status": ["100-200", "400-499", "599"],
            },
            "action": "block:451",
        },
        {
            "match_resp": {
                "status": ["500"],
                "headers": [{
                    "name": "foo",
                    "value": "bar.*",
                }]
            },
            "action": "block:401",
        }]
    }));
    assert!(!ruleset.is_empty());

    // Request with only the method set
    let bare_request = |method: Method| Request::builder().method(method).body("").unwrap();

    // Request that matches the rate-limited rule
    let limited_request = || {
        Request::builder()
            .method(Method::GET)
            .version(Version::HTTP_2)
            .uri("https://foo/bar")
            .body("")
            .unwrap()
    };

    // Response with only the status set
    let bare_response = |status: StatusCode| Response::builder().status(status).body("").unwrap();

    // ---- Requests ----

    // OPTIONS is not covered by any request rule
    for _ in 0..1000 {
        let req = bare_request(Method::OPTIONS);
        assert_eq!(ruleset.evaluate_request(&req), Decision::Pass);
    }

    // DELETE is covered by the blocking rule
    for _ in 0..1000 {
        let req = bare_request(Method::DELETE);
        assert_eq!(
            ruleset.evaluate_request(&req),
            Decision::Block(StatusCode::FORBIDDEN)
        );
    }

    // The limit is 10 per hour, so the first 10 requests go through
    for _ in 0..10 {
        let req = limited_request();
        assert_eq!(ruleset.evaluate_request(&req), Decision::Pass);
    }

    // The 11th one has to be throttled for roughly a tenth of an hour
    let req = limited_request();
    let Decision::Throttle(wait) = ruleset.evaluate_request(&req) else {
        unreachable!()
    };
    assert!((Duration::from_secs(359)..=Duration::from_secs(360)).contains(&wait));

    // All following ones are throttled too
    for _ in 0..1000 {
        let req = limited_request();
        assert!(matches!(
            ruleset.evaluate_request(&req),
            Decision::Throttle(_)
        ));
    }

    // ---- Responses ----

    let post = Request::builder().method(Method::POST).body(()).unwrap();
    let redirect = bare_response(StatusCode::PERMANENT_REDIRECT);

    // 308 is not listed in any of the response rules
    for _ in 0..1000 {
        assert_eq!(ruleset.evaluate_response(&post, &redirect), Decision::Pass);
    }

    // POST + 200 is expected to yield 451
    for _ in 0..1000 {
        let resp = bare_response(StatusCode::OK);
        assert_eq!(
            ruleset.evaluate_response(&post, &resp),
            Decision::Block(StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS)
        );
    }

    // OPTIONS + 200 is expected to yield 499
    let options = Request::builder().method(Method::OPTIONS).body(()).unwrap();

    for _ in 0..1000 {
        let resp = bare_response(StatusCode::OK);
        assert_eq!(
            ruleset.evaluate_response(&options, &resp),
            Decision::Block(StatusCode::from_u16(499).unwrap())
        );
    }

    // 500 with a matching `foo` header is expected to yield 401
    let failure = Response::builder()
        .status(StatusCode::INTERNAL_SERVER_ERROR)
        .header("foo", "bardead")
        .body("")
        .unwrap();

    for _ in 0..1000 {
        assert_eq!(
            ruleset.evaluate_response(&options, &failure),
            Decision::Block(StatusCode::UNAUTHORIZED)
        );
    }
}
1300
1301    #[tokio::test]
1302    async fn test_waflayer() {
1303        use axum::routing::get;
1304
1305        let ruleset = Ruleset::default();
1306        let layer = WafLayer::new(
1307            Arc::new(ArcSwap::new(Arc::new(ruleset))),
1308            None,
1309            Duration::ZERO,
1310        );
1311
1312        let ruleset = r#"
1313        requests:
1314        - action: block:451
1315          match:
1316            methods:
1317            - OPTIONS
1318            - GET
1319            headers:
1320            - name: foo
1321              regex: ^bar.*$
1322        "#;
1323
1324        assert!(layer.set_ruleset(parse_ruleset(ruleset.as_bytes()).unwrap()));
1325        assert!(!layer.set_ruleset(parse_ruleset(ruleset.as_bytes()).unwrap()));
1326
1327        let mut router = Router::new()
1328            .route("/", get(|| async { "foo" }).options(|| async { "bar" }))
1329            .layer(layer);
1330
1331        // Should block
1332        let req = Request::builder()
1333            .method(Method::OPTIONS)
1334            .header("foo", "barfuss")
1335            .body(Body::empty())
1336            .unwrap();
1337        let resp = router.call(req).await.unwrap();
1338        assert_eq!(resp.status(), StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS);
1339
1340        let req = Request::builder()
1341            .method(Method::GET)
1342            .header("foo", "barfuss")
1343            .body(Body::empty())
1344            .unwrap();
1345        let resp = router.call(req).await.unwrap();
1346        assert_eq!(resp.status(), StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS);
1347
1348        // Should pass
1349        let req = Request::builder()
1350            .method(Method::OPTIONS)
1351            .body(Body::empty())
1352            .unwrap();
1353        let resp = router.call(req).await.unwrap();
1354        assert_eq!(resp.status(), StatusCode::OK);
1355    }
1356}