1use std::{
2 net::IpAddr,
3 num::NonZeroU32,
4 path::PathBuf,
5 str::FromStr,
6 sync::Arc,
7 task::{Context, Poll},
8 time::Duration,
9};
10
11use ahash::RandomState;
12use anyhow::{Context as _, anyhow};
13use arc_swap::ArcSwap;
14use async_trait::async_trait;
15use axum::{
16 Router,
17 extract::{Request as AxumRequest, State},
18 response::{IntoResponse, Response as AxumResponse},
19 routing::post,
20};
21use bytes::Bytes;
22use futures::future::BoxFuture;
23use governor::{
24 Quota, RateLimiter,
25 clock::{Clock, DefaultClock},
26 state::{InMemoryState, NotKeyed, keyed::DashMapStateStore},
27};
28use http::{
29 HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode, Version,
30 header::{HOST, RETRY_AFTER},
31};
32use humantime::parse_duration;
33use ic_bn_lib_common::{
34 traits::{Run, http::Client},
35 types::http::{Error, WafCli},
36};
37use itertools::Itertools;
38use regex::Regex;
39use serde::Deserialize;
40use serde_with::{DeserializeFromStr, DisplayFromStr, serde_as};
41use tokio::{fs, select, time::interval};
42use tokio_util::sync::CancellationToken;
43use tower::{Layer, Service};
44use tracing::warn;
45use url::Url;
46
47use crate::http::middleware::extract_ip_from_request;
48
/// An inclusive range of HTTP status codes (`from..=to`), or a single code when `to` is `None`.
///
/// Deserialized from strings such as `"404"` or `"500-599"` through its `FromStr` impl.
#[derive(Debug, Clone, Copy, Eq, PartialEq, DeserializeFromStr)]
pub struct StatusRange {
    /// First code of the range (the only code when `to` is `None`).
    from: u16,
    /// Inclusive end of the range; `None` means the range contains only `from`.
    to: Option<u16>,
}
55
56impl StatusRange {
57 pub const fn evaluate(&self, v: StatusCode) -> bool {
59 let code = v.as_u16();
60
61 if let Some(to) = self.to {
62 return code >= self.from && code <= to;
63 }
64
65 code == self.from
66 }
67}
68
69impl FromStr for StatusRange {
70 type Err = Error;
71
72 fn from_str(s: &str) -> Result<Self, Self::Err> {
73 let mut it = s.split('-');
74 let (from, to) = (it.next().unwrap(), it.next());
75 let from: u16 = from
76 .trim()
77 .parse()
78 .context("unable to parse status range start")?;
79
80 if !(100..=599).contains(&from) {
81 return Err(anyhow!("Status code can be between 100 and 599, not {from}").into());
82 }
83
84 let to = if let Some(v) = to {
85 let v = v
86 .trim()
87 .parse()
88 .context("unable to parse status range end")?;
89
90 if !(100..=599).contains(&v) {
91 return Err(anyhow!("Status code can be between 100 and 599, not {v}").into());
92 }
93
94 if v <= from {
95 return Err(anyhow!(
96 "End of the range should be greater than start ({v} > {from})"
97 )
98 .into());
99 }
100
101 Some(v)
102 } else {
103 None
104 };
105
106 Ok(Self { from, to })
107 }
108}
109
/// Matches an HTTP header by exact name and a regular expression applied to its value.
#[serde_as]
#[derive(Debug, Clone, Deserialize)]
pub struct HeaderMatcher {
    /// Name of the header to inspect.
    #[serde_as(as = "DisplayFromStr")]
    pub name: HeaderName,
    /// Regex applied to the header value; the config key may also be spelled `value`.
    #[serde_as(as = "DisplayFromStr")]
    #[serde(alias = "value")]
    pub regex: Regex,
}
120
121impl PartialEq for HeaderMatcher {
122 fn eq(&self, other: &Self) -> bool {
123 self.name == other.name && self.regex.as_str() == other.regex.as_str()
124 }
125}
126impl Eq for HeaderMatcher {}
127
128impl Ord for HeaderMatcher {
129 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
130 self.name.as_str().cmp(other.name.as_str())
131 }
132}
133impl PartialOrd for HeaderMatcher {
134 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
135 Some(self.cmp(other))
136 }
137}
138
139impl HeaderMatcher {
140 pub fn evaluate(&self, name: &HeaderName, value: &HeaderValue) -> bool {
142 if name != self.name {
143 return false;
144 }
145
146 let Ok(value) = value.to_str() else {
147 return false;
148 };
149
150 self.regex.is_match(value)
151 }
152
153 pub fn evaluate_headermap(&self, map: &HeaderMap) -> bool {
155 map.iter().any(|(name, value)| self.evaluate(name, value))
156 }
157}
158
/// Criteria for matching an HTTP request. Every field is optional; a request matches only
/// when all configured criteria match.
#[serde_as]
#[derive(Debug, Clone, Deserialize)]
pub struct RequestMatcher {
    /// Regex applied to the request host.
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub host: Option<Regex>,
    /// Regex applied to the request path *including* the query string.
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub path: Option<Regex>,
    /// The request method must be one of these.
    #[serde_as(as = "Option<Vec<DisplayFromStr>>")]
    pub methods: Option<Vec<Method>>,
    /// Header matchers that must all match.
    pub headers: Option<Vec<HeaderMatcher>>,
}
171
172impl PartialEq for RequestMatcher {
173 fn eq(&self, other: &Self) -> bool {
174 self.methods == other.methods
175 && self
177 .headers
178 .as_ref()
179 .map(|x| x.clone().into_iter().sorted().collect::<Vec<_>>())
180 == other
181 .headers
182 .as_ref()
183 .map(|x| x.clone().into_iter().sorted().collect::<Vec<_>>())
184 && self.host.as_ref().map(|x| x.as_str()) == other.host.as_ref().map(|x| x.as_str())
185 && self.path.as_ref().map(|x| x.as_str()) == other.path.as_ref().map(|x| x.as_str())
186 }
187}
188impl Eq for RequestMatcher {}
189
190impl RequestMatcher {
191 pub fn evaluate<T>(&self, req: &Request<T>) -> bool {
193 if let Some(v) = &self.host {
195 let host = match req.version() {
196 Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => req
199 .headers()
200 .get(HOST)
201 .and_then(|x| x.to_str().ok())
202 .unwrap_or_default(),
203
204 _ => req.uri().host().unwrap_or_default(),
206 };
207
208 if !v.is_match(host) {
209 return false;
210 }
211 }
212
213 if let Some(v) = &self.path
215 && !v.is_match(
216 req.uri()
217 .path_and_query()
218 .map(|x| x.as_str())
219 .unwrap_or_default(),
220 )
221 {
222 return false;
223 }
224
225 if let Some(v) = &self.methods
227 && !v.iter().contains(req.method())
228 {
229 return false;
230 }
231
232 if let Some(v) = &self.headers
234 && !v.iter().all(|rule| rule.evaluate_headermap(req.headers()))
235 {
236 return false;
237 }
238
239 true
241 }
242}
243
/// Criteria for matching an HTTP response. Unset criteria always match.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct ResponseMatcher {
    /// Header matchers that must all match.
    pub headers: Option<Vec<HeaderMatcher>>,
    /// Status ranges; the response status must fall into at least one of them.
    pub status: Option<Vec<StatusRange>>,
}
251
252impl ResponseMatcher {
253 pub fn evaluate<T>(&self, req: &Response<T>) -> bool {
255 if let Some(v) = &self.status
257 && !v.iter().any(|x| x.evaluate(req.status()))
258 {
259 return false;
260 }
261
262 if let Some(v) = &self.headers
264 && !v.iter().all(|rule| rule.evaluate_headermap(req.headers()))
265 {
266 return false;
267 }
268
269 true
271 }
272}
273
/// Outcome of a rate-limit check.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum RateLimitDecision {
    /// The request is within the limit.
    Pass,
    /// The limit is exceeded; the duration is the wait reported by the limiter.
    Throttle(Duration),
}
280
/// A configured rate limiter: the quota it was built from plus the live limiter state.
#[derive(Debug)]
pub enum RateLimitType {
    /// A single limiter shared by all matching requests.
    Global(Quota, RateLimiter<NotKeyed, InMemoryState, DefaultClock>),
    /// A limiter keyed by client IP address.
    PerIp(
        Quota,
        RateLimiter<IpAddr, DashMapStateStore<IpAddr, RandomState>, DefaultClock>,
    ),
}
290
291impl FromStr for RateLimitType {
292 type Err = Error;
293
294 fn from_str(s: &str) -> Result<Self, Self::Err> {
295 let Some((typ, limit)) = s.split_once(':') else {
296 return Err(anyhow!("expecting limit in 'type:rate' format").into());
297 };
298
299 let Some((rate, dur)) = limit.split_once("/") else {
300 return Err(anyhow!("expecting rate in 'rate/duration' format").into());
301 };
302
303 let rate = rate.parse::<u32>().context("unable to parse rate as u32")?;
304 let dur = parse_duration(dur).context("unable to parse duration")?;
305
306 if rate == 0 {
307 return Err(anyhow!("rate must be > 0").into());
308 }
309
310 if dur == Duration::ZERO {
311 return Err(anyhow!("duration cannot be zero").into());
312 }
313
314 let replenish_period = dur / rate;
316 let quota = Quota::with_period(replenish_period)
317 .unwrap()
318 .allow_burst(NonZeroU32::new(rate).unwrap());
319
320 Ok(match typ {
321 "global" => Self::Global(quota, RateLimiter::direct(quota)),
322 "per_ip" => Self::PerIp(
323 quota,
324 RateLimiter::dashmap_with_hasher(quota, RandomState::new()),
325 ),
326 _ => return Err(anyhow!("unknown rate limiter type {typ}").into()),
327 })
328 }
329}
330
331impl PartialEq for RateLimitType {
332 fn eq(&self, other: &Self) -> bool {
333 match (self, other) {
334 (Self::Global(q1, _), Self::Global(q2, _)) => q1 == q2,
335 (Self::PerIp(q1, _), Self::PerIp(q2, _)) => q1 == q2,
336 _ => false,
337 }
338 }
339}
340impl Eq for RateLimitType {}
341
342impl RateLimitType {
343 pub fn allowed<B>(&self, req: &Request<B>) -> RateLimitDecision {
345 let (clock, r) = match self {
346 Self::Global(_, v) => (v.clock(), v.check()),
347 Self::PerIp(_, v) => {
348 let Some(ip) = extract_ip_from_request(req) else {
352 return RateLimitDecision::Pass;
353 };
354
355 (v.clock(), v.check_key(&ip))
356 }
357 };
358
359 if let Err(e) = r {
360 let dur = e.wait_time_from(clock.now());
361 return RateLimitDecision::Throttle(dur);
362 }
363
364 RateLimitDecision::Pass
365 }
366}
367
/// Action applied to a request matched by a [`RequestRule`]. Parsed from strings such as
/// `pass`, `block:451` or `limit:per_ip:10/1m`.
#[derive(Debug, PartialEq, Eq)]
pub enum RequestAction {
    /// Let the request through.
    Pass,
    /// Reject the request with the given status code.
    Block(StatusCode),
    /// Check the rate limiter; the request is throttled once the limit is exceeded.
    RateLimit(RateLimitType),
}
375
376impl FromStr for RequestAction {
377 type Err = Error;
378
379 fn from_str(s: &str) -> Result<Self, Self::Err> {
380 if s == "pass" {
381 return Ok(Self::Pass);
382 }
383
384 let mut it = s.split(':');
385 let (pfx, sfx) = (it.next().unwrap(), it.next());
386 if pfx == "block" {
387 let code = if let Some(code) = sfx {
388 StatusCode::from_str(code).context("unable to parse status code")?
389 } else {
390 StatusCode::FORBIDDEN
391 };
392
393 return Ok(Self::Block(code));
394 }
395
396 if pfx == "limit" {
397 let Some((_, v)) = s.split_once(':') else {
398 return Err(anyhow!("expecting limit definition after ':'").into());
399 };
400
401 return Ok(Self::RateLimit(RateLimitType::from_str(v)?));
402 }
403
404 Err(anyhow!("unsupported action format").into())
405 }
406}
407
/// Action applied to a response matched by a [`ResponseRule`]. Parsed from `pass`, `block`
/// or `block:<status>`.
#[derive(Debug, PartialEq, Eq)]
pub enum ResponseAction {
    /// Keep the upstream response.
    Pass,
    /// Replace the upstream response with a block response using the given status code.
    Block(StatusCode),
}
414
415impl FromStr for ResponseAction {
416 type Err = Error;
417
418 fn from_str(s: &str) -> Result<Self, Self::Err> {
419 if s == "pass" {
420 return Ok(Self::Pass);
421 }
422
423 let mut it = s.split(':');
424 let (pfx, sfx) = (it.next().unwrap(), it.next());
425 if pfx == "block" {
426 let code = if let Some(code) = sfx {
427 StatusCode::from_str(code).context("unable to parse status code")?
428 } else {
429 StatusCode::FORBIDDEN
430 };
431
432 return Ok(Self::Block(code));
433 }
434
435 Err(anyhow!("unsupported action format").into())
436 }
437}
438
/// Final verdict of rule evaluation; rendered to an HTTP response by its `IntoResponse` impl.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Decision {
    /// Continue processing normally.
    Pass,
    /// Reject with the given status code.
    Block(StatusCode),
    /// Reject with `429 Too Many Requests`; the duration is the limiter's reported wait.
    Throttle(Duration),
}
446
447impl IntoResponse for Decision {
448 fn into_response(self) -> AxumResponse {
449 match self {
450 Self::Pass => StatusCode::OK.into_response(),
451 Self::Block(v) => (v, "Blocked for policy reasons").into_response(),
452 Self::Throttle(v) => (
453 StatusCode::TOO_MANY_REQUESTS,
454 [(
455 RETRY_AFTER,
456 HeaderValue::from_str(&(v.as_secs() + 1).to_string()).unwrap(),
457 )],
458 "Request was rate-limited, consult Retry-After header for a number of seconds after which it can be retried",
459 )
460 .into_response(),
461 }
462 }
463}
464
/// A request rule: when `matcher` matches, `action` decides what happens to the request.
#[serde_as]
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct RequestRule {
    /// Request criteria; the config key may also be spelled `match`.
    #[serde(alias = "match")]
    pub matcher: RequestMatcher,
    /// Action in string form, parsed via `RequestAction`'s `FromStr`.
    #[serde_as(as = "DisplayFromStr")]
    pub action: RequestAction,
}
474
475impl RequestRule {
476 pub fn evaluate<B>(&self, req: &Request<B>) -> Option<Decision> {
477 if !self.matcher.evaluate(req) {
478 return None;
479 }
480
481 Some(match &self.action {
482 RequestAction::Pass => Decision::Pass,
483 RequestAction::Block(v) => Decision::Block(*v),
484 RequestAction::RateLimit(v) => match v.allowed(req) {
485 RateLimitDecision::Pass => Decision::Pass,
486 RateLimitDecision::Throttle(v) => Decision::Throttle(v),
487 },
488 })
489 }
490}
491
/// A response rule: when the optional request matcher and the response matcher both match,
/// `action` decides what is returned to the client.
#[serde_as]
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct ResponseRule {
    /// Optional criteria on the originating request; the config key may also be `match_req`.
    #[serde(alias = "match_req")]
    pub matcher_req: Option<RequestMatcher>,
    /// Criteria on the response; the config key may also be `match_resp`.
    #[serde(alias = "match_resp")]
    pub matcher: ResponseMatcher,
    /// Action in string form, parsed via `ResponseAction`'s `FromStr`.
    #[serde_as(as = "DisplayFromStr")]
    pub action: ResponseAction,
}
503
504impl ResponseRule {
505 pub fn evaluate<B1, B2>(&self, req: &Request<B1>, resp: &Response<B2>) -> Option<Decision> {
506 if let Some(v) = &self.matcher_req
507 && !v.evaluate(req)
508 {
509 return None;
510 }
511
512 if !self.matcher.evaluate(resp) {
513 return None;
514 }
515
516 Some(match &self.action {
517 ResponseAction::Pass => Decision::Pass,
518 ResponseAction::Block(v) => Decision::Block(*v),
519 })
520 }
521}
522
/// Complete WAF configuration: ordered request and response rules. Within each list the first
/// matching rule decides.
#[derive(Debug, PartialEq, Eq, Deserialize, Default)]
pub struct Ruleset {
    /// Rules evaluated before the request is forwarded to the inner service.
    pub requests: Option<Vec<RequestRule>>,
    /// Rules evaluated against the inner service's response.
    pub responses: Option<Vec<ResponseRule>>,
}
529
530impl Ruleset {
531 fn is_empty(&self) -> bool {
532 (self.requests.is_none() || self.requests.as_ref().map(|x| x.is_empty()) == Some(true))
533 && (self.responses.is_none()
534 || self.responses.as_ref().map(|x| x.is_empty()) == Some(true))
535 }
536
537 fn evaluate_request<B>(&self, req: &Request<B>) -> Decision {
539 let Some(v) = &self.requests else {
540 return Decision::Pass;
541 };
542
543 v.iter()
544 .find_map(|x| x.evaluate(req))
545 .unwrap_or(Decision::Pass)
546 }
547
548 fn evaluate_response<B1, B2>(&self, req: &Request<B1>, resp: &Response<B2>) -> Decision {
550 let Some(v) = &self.responses else {
551 return Decision::Pass;
552 };
553
554 v.iter()
555 .find_map(|x| x.evaluate(req, resp))
556 .unwrap_or(Decision::Pass)
557 }
558}
559
/// Tower service that applies the current WAF ruleset around an inner service.
#[derive(Debug, Clone)]
pub struct Waf<S> {
    /// Atomically swappable ruleset, shared with the `WafLayer` that created this service.
    ruleset: Arc<ArcSwap<Ruleset>>,
    /// The wrapped service.
    inner: S,
}
566
567impl<S> Service<AxumRequest> for Waf<S>
569where
570 S: Service<AxumRequest, Response = AxumResponse> + Send + 'static,
571 S::Future: Send + 'static,
572{
573 type Response = S::Response;
574 type Error = S::Error;
575 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
576
577 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
578 self.inner.poll_ready(cx)
579 }
580
581 fn call(&mut self, request: AxumRequest) -> Self::Future {
582 let ruleset = self.ruleset.load_full();
583 if ruleset.is_empty() {
585 let fut = self.inner.call(request);
586 return Box::pin(fut);
587 }
588
589 let decision = ruleset.evaluate_request(&request);
591 if decision != Decision::Pass {
592 return Box::pin(async move { Ok(decision.into_response()) });
593 }
594
595 let (parts, body) = request.into_parts();
597 let parts_clone = parts.clone();
598 let request = AxumRequest::from_parts(parts, body);
599
600 let future = self.inner.call(request);
601 Box::pin(async move {
602 let response: AxumResponse = future.await?;
603
604 let req = Request::from_parts(parts_clone, ());
606 let decision = ruleset.evaluate_response(&req, &response);
607 if decision != Decision::Pass {
608 return Ok(decision.into_response());
609 }
610
611 Ok(response)
612 })
613 }
614}
615
616fn parse_ruleset(data: &[u8]) -> Result<Ruleset, Error> {
617 let ruleset: Ruleset = serde_json::from_slice(data)
618 .context("unable to parse ruleset as JSON")
619 .or_else(|_| serde_yaml_ng::from_slice(data))
620 .context("unable to parse ruleset as YAML")?;
621
622 Ok(ruleset)
623}
624
/// A source from which the ruleset can be (re)loaded.
#[async_trait]
trait FetchesRuleset: Send + Sync {
    /// Loads and parses the current ruleset from the source.
    async fn fetch_rules(&self) -> Result<Ruleset, Error>;
}
630
/// Loads the ruleset from a local file.
struct RulesetFetcherFile {
    /// Path of the ruleset file (JSON or YAML).
    path: PathBuf,
}
635
636#[async_trait]
637impl FetchesRuleset for RulesetFetcherFile {
638 async fn fetch_rules(&self) -> Result<Ruleset, Error> {
639 let data = fs::read(&self.path)
640 .await
641 .context("unable to read ruleset from file")?;
642
643 parse_ruleset(&data)
644 }
645}
646
/// Loads the ruleset from a URL over HTTP.
struct RulesetFetcherUrl {
    /// Client used to perform the download.
    http_client: Arc<dyn Client>,
    /// Location of the ruleset (JSON or YAML).
    url: Url,
}
652
653#[async_trait]
654impl FetchesRuleset for RulesetFetcherUrl {
655 async fn fetch_rules(&self) -> Result<Ruleset, Error> {
656 let req = reqwest::Request::new(Method::GET, self.url.clone());
657 let resp = self
658 .http_client
659 .execute(req)
660 .await
661 .context("unable to execute request")?;
662 let data = resp.bytes().await.context("unable to get response body")?;
663
664 parse_ruleset(&data)
665 }
666}
667
/// Tower layer that installs [`Waf`]. It owns the shared ruleset and the optional source the
/// ruleset is refreshed from by the `Run` implementation.
#[derive(Clone, derive_new::new)]
pub struct WafLayer {
    /// Active ruleset, shared with every `Waf` service created by this layer.
    ruleset: Arc<ArcSwap<Ruleset>>,
    /// Where to refresh the ruleset from; `None` means it only changes via `set_ruleset`.
    fetcher: Option<Arc<dyn FetchesRuleset>>,
    /// Period between refreshes from `fetcher`.
    interval: Duration,
}
675
676impl<S> Layer<S> for WafLayer {
677 type Service = Waf<S>;
678
679 fn layer(&self, inner: S) -> Self::Service {
680 Waf {
681 ruleset: self.ruleset.clone(),
682 inner,
683 }
684 }
685}
686
687async fn api_handler(State(state): State<WafLayer>, body: Bytes) -> AxumResponse {
690 let ruleset = match parse_ruleset(&body) {
691 Ok(v) => v,
692 Err(e) => {
693 return (
694 StatusCode::BAD_REQUEST,
695 format!("Unable to parse ruleset: {e:#}"),
696 )
697 .into_response();
698 }
699 };
700
701 warn!("WAF: Ruleset updated over API");
702 if state.set_ruleset(ruleset) {
703 "Ruleset updated\n"
704 } else {
705 "Ruleset is the same, not updated\n"
706 }
707 .into_response()
708}
709
710pub fn create_router<S: Send + Sync + Clone + 'static>(layer: WafLayer) -> Router<S> {
712 Router::new().route("/update", post(api_handler).with_state(layer))
713}
714
715impl WafLayer {
716 pub fn new_from_cli(cli: &WafCli, http_client: Option<Arc<dyn Client>>) -> Result<Self, Error> {
718 let fetcher = if let Some(v) = &cli.waf_url {
719 let Some(http_client) = http_client else {
720 return Err(anyhow!("URL source requires HTTP client").into());
721 };
722
723 Some(Arc::new(RulesetFetcherUrl {
724 http_client,
725 url: v.clone(),
726 }) as Arc<dyn FetchesRuleset>)
727 } else {
728 cli.waf_file.as_ref().map(|v| {
729 Arc::new(RulesetFetcherFile { path: v.clone() }) as Arc<dyn FetchesRuleset>
730 })
731 };
732
733 Ok(Self {
734 ruleset: Arc::new(ArcSwap::new(Arc::new(Ruleset::default()))),
735 fetcher,
736 interval: cli.waf_interval,
737 })
738 }
739
740 pub fn set_ruleset(&self, new: Ruleset) -> bool {
742 let new = Arc::new(new);
743
744 if new == self.ruleset.load_full() {
746 return false;
747 }
748
749 self.ruleset.store(new);
750 true
751 }
752
753 async fn update_ruleset(&self) {
754 let Some(fetcher) = &self.fetcher else {
755 return;
756 };
757
758 let ruleset = match fetcher.fetch_rules().await {
759 Ok(v) => v,
760 Err(e) => {
761 warn!("WAF: unable to fetch ruleset: {e:#}");
762 return;
763 }
764 };
765
766 if self.set_ruleset(ruleset) {
767 warn!("WAF: Ruleset was updated");
768 }
769 }
770}
771
772#[async_trait]
773impl Run for WafLayer {
774 async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> {
775 let mut interval = interval(self.interval);
776 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
777
778 loop {
779 select! {
780 () = token.cancelled() => return Ok(()),
781 _ = interval.tick() => self.update_ruleset().await,
782 }
783 }
784 }
785}
786
#[cfg(test)]
mod test {
    use axum::{Router, body::Body};
    use serde_json::json;

    use super::*;

    /// Parsing and evaluation of `StatusRange`.
    #[test]
    fn test_status_range() {
        // Whitespace around the bounds is trimmed.
        assert_eq!(
            StatusRange::from_str("100 - 200").unwrap(),
            StatusRange {
                from: 100,
                to: Some(200)
            }
        );

        assert_eq!(
            StatusRange::from_str("100").unwrap(),
            StatusRange {
                from: 100,
                to: None
            }
        );

        // Malformed input, codes outside 100..=599 and reversed ranges are rejected.
        assert!(StatusRange::from_str("").is_err());
        assert!(StatusRange::from_str("+").is_err());
        assert!(StatusRange::from_str("-").is_err());
        assert!(StatusRange::from_str("99").is_err());
        assert!(StatusRange::from_str("99-600").is_err());
        assert!(StatusRange::from_str("99-").is_err());
        assert!(StatusRange::from_str("-500").is_err());
        assert!(StatusRange::from_str("101-600").is_err());
        assert!(StatusRange::from_str("199-100").is_err());

        let range = StatusRange::from_str("200-499").unwrap();

        assert!(range.evaluate(StatusCode::OK));
        assert!(range.evaluate(StatusCode::ACCEPTED));
        assert!(range.evaluate(StatusCode::PERMANENT_REDIRECT));
        assert!(range.evaluate(StatusCode::NOT_FOUND));
        assert!(!range.evaluate(StatusCode::CONTINUE));
        assert!(!range.evaluate(StatusCode::INTERNAL_SERVER_ERROR));
        assert!(!range.evaluate(StatusCode::SERVICE_UNAVAILABLE));

        // A single-code range matches only that code.
        let range = StatusRange::from_str("200").unwrap();
        assert!(range.evaluate(StatusCode::OK));
        assert!(!range.evaluate(StatusCode::ACCEPTED));
    }

    /// Deserialization and evaluation of `RequestMatcher`.
    #[test]
    fn test_request() {
        let rule = json!({
            "methods": ["GET", "OPTIONS"],
            "headers": [
                {
                    "name": "foo",
                    "regex": "^bar.*$"
                },
                {
                    "name": "dead",
                    "regex": "^beef.*$"
                }
            ],
            "host": "^lala",
            "path": "^/foo",
        })
        .to_string();

        let rule: RequestMatcher = serde_json::from_str(&rule).unwrap();
        assert_eq!(
            rule,
            RequestMatcher {
                methods: Some(vec![Method::GET, Method::OPTIONS]),
                headers: Some(vec![
                    HeaderMatcher {
                        name: HeaderName::from_static("foo"),
                        regex: Regex::from_str("^bar.*$").unwrap(),
                    },
                    HeaderMatcher {
                        name: HeaderName::from_static("dead"),
                        regex: Regex::from_str("^beef.*$").unwrap(),
                    }
                ]),
                host: Some(Regex::from_str("^lala").unwrap()),
                path: Some(Regex::from_str("^/foo").unwrap()),
            }
        );

        // HTTP/2: the host is taken from the request URI.
        let req = Request::builder()
            .header("foo", "barfuss")
            .header("dead", "beefbeef")
            .method(Method::GET)
            .version(Version::HTTP_2)
            .uri("https://lala/foo")
            .body("")
            .unwrap();
        assert!(rule.evaluate(&req));

        // HTTP/0.9-1.1: the host is taken from the `Host` header.
        for http_ver in [Version::HTTP_09, Version::HTTP_10, Version::HTTP_11] {
            let req = Request::builder()
                .header("foo", "barfuss")
                .header("dead", "beefbeef")
                .header("host", "lala")
                .version(http_ver)
                .method(Method::OPTIONS)
                .uri("https://lala/foo")
                .body("")
                .unwrap();
            assert!(rule.evaluate(&req));
        }

        // Header order in the request does not matter.
        let req = Request::builder()
            .header("dead", "beefbeef")
            .header("foo", "barfuss")
            .version(Version::HTTP_2)
            .method(Method::OPTIONS)
            .uri("https://lala/foo")
            .body("")
            .unwrap();
        assert!(rule.evaluate(&req));

        // POST is not an allowed method.
        // NOTE(review): this and the next two requests use the default version (HTTP/1.1)
        // without a `Host` header, so the host check already fails; they do not isolate the
        // method/path/header criteria they appear to target.
        let req = Request::builder()
            .header("foo", "barfuss")
            .header("dead", "beefbeef")
            .method(Method::POST)
            .uri("https://lala/foo")
            .body("")
            .unwrap();
        assert!(!rule.evaluate(&req));

        // Path does not match `^/foo`.
        let req = Request::builder()
            .header("foo", "barfuss")
            .header("dead", "beefbeef")
            .method(Method::GET)
            .uri("https://lala/bar")
            .body("")
            .unwrap();
        assert!(!rule.evaluate(&req));

        // Required header `foo` is missing.
        let req = Request::builder()
            .header("fox", "barfuss")
            .header("dead", "beefbeef")
            .method(Method::GET)
            .uri("https://lala/foo")
            .body("")
            .unwrap();
        assert!(!rule.evaluate(&req));
    }

    /// Deserialization and evaluation of `ResponseMatcher`.
    #[test]
    fn test_response() {
        let rule = json!({
            "status": ["100-200", "307", "400-500"],
            "headers": [
                {
                    "name": "foo",
                    "regex": "^bar.*$"
                },
                {
                    "name": "dead",
                    "regex": "^beef.*$"
                }
            ],
        })
        .to_string();

        let rule: ResponseMatcher = serde_json::from_str(&rule).unwrap();
        assert_eq!(
            rule,
            ResponseMatcher {
                status: Some(vec![
                    StatusRange::from_str("100-200").unwrap(),
                    StatusRange::from_str("307").unwrap(),
                    StatusRange::from_str("400-500").unwrap()
                ]),
                headers: Some(vec![
                    HeaderMatcher {
                        name: HeaderName::from_static("foo"),
                        regex: Regex::from_str("^bar.*$").unwrap(),
                    },
                    HeaderMatcher {
                        name: HeaderName::from_static("dead"),
                        regex: Regex::from_str("^beef.*$").unwrap(),
                    }
                ]),
            }
        );

        // 200 falls into 100-200.
        let resp = Response::builder()
            .header("foo", "barfuss")
            .header("dead", "beefbeef")
            .status(StatusCode::OK)
            .body("")
            .unwrap();
        assert!(rule.evaluate(&resp));

        // 100 is the (inclusive) lower bound of 100-200.
        let resp = Response::builder()
            .header("foo", "barfuss")
            .header("dead", "beefbeef")
            .status(StatusCode::CONTINUE)
            .body("")
            .unwrap();
        assert!(rule.evaluate(&resp));

        // 307 matches the single-code entry.
        let resp = Response::builder()
            .header("foo", "barfuss")
            .header("dead", "beefbeef")
            .status(StatusCode::TEMPORARY_REDIRECT)
            .body("")
            .unwrap();
        assert!(rule.evaluate(&resp));

        // 404 falls into 400-500.
        let resp = Response::builder()
            .header("foo", "barfuss")
            .header("dead", "beefbeef")
            .status(StatusCode::NOT_FOUND)
            .body("")
            .unwrap();
        assert!(rule.evaluate(&resp));

        // 308 is in none of the ranges.
        let resp = Response::builder()
            .header("foo", "barfuss")
            .header("dead", "beefbeef")
            .status(StatusCode::PERMANENT_REDIRECT)
            .body("")
            .unwrap();
        assert!(!rule.evaluate(&resp));

        // Status matches but the `dead` header value does not start with "beef".
        let resp = Response::builder()
            .header("foo", "barfuss")
            .header("dead", "zbeefbeef")
            .status(StatusCode::OK)
            .body("")
            .unwrap();
        assert!(!rule.evaluate(&resp));
    }

    /// Parsing of `RequestAction` strings.
    #[test]
    fn test_request_action() {
        assert_eq!(
            RequestAction::from_str("pass").unwrap(),
            RequestAction::Pass
        );

        // `block` without a code defaults to 403.
        assert_eq!(
            RequestAction::from_str("block").unwrap(),
            RequestAction::Block(StatusCode::FORBIDDEN)
        );

        assert_eq!(
            RequestAction::from_str("block:451").unwrap(),
            RequestAction::Block(StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS)
        );

        assert!(RequestAction::from_str("block:0").is_err());
        assert!(RequestAction::from_str("block:foo").is_err());
        assert!(RequestAction::from_str("foo").is_err());

        // 10 per minute => one cell every 6s with a burst of 10. `RateLimitType` equality only
        // compares the quota, so the limiter instance passed here does not matter.
        assert_eq!(
            RequestAction::from_str("limit:global:10/1m").unwrap(),
            RequestAction::RateLimit(RateLimitType::Global(
                Quota::with_period(Duration::from_secs(6))
                    .unwrap()
                    .allow_burst(NonZeroU32::new(10).unwrap()),
                RateLimiter::direct(Quota::with_period(Duration::from_secs(6)).unwrap())
            ))
        );

        assert_eq!(
            RequestAction::from_str("limit:per_ip:10/1m").unwrap(),
            RequestAction::RateLimit(RateLimitType::PerIp(
                Quota::with_period(Duration::from_secs(6))
                    .unwrap()
                    .allow_burst(NonZeroU32::new(10).unwrap()),
                RateLimiter::dashmap_with_hasher(
                    Quota::with_period(Duration::from_secs(6)).unwrap(),
                    RandomState::new()
                )
            ))
        );

        // NOTE(review): "limit:0/1s" and "limit:1/0s" already fail because they lack the
        // limiter type, so they do not exercise the zero-rate / zero-duration checks.
        assert!(RequestAction::from_str("limit").is_err());
        assert!(RequestAction::from_str("limit:").is_err());
        assert!(RequestAction::from_str("limit:foo").is_err());
        assert!(RequestAction::from_str("limit:0/1s").is_err());
        assert!(RequestAction::from_str("limit:1/0s").is_err());
        assert!(RequestAction::from_str("limit:1/foo").is_err());
    }

    /// Parsing of `ResponseAction` strings.
    #[test]
    fn test_response_action() {
        assert_eq!(
            ResponseAction::from_str("pass").unwrap(),
            ResponseAction::Pass
        );

        assert_eq!(
            ResponseAction::from_str("block").unwrap(),
            ResponseAction::Block(StatusCode::FORBIDDEN)
        );

        assert_eq!(
            ResponseAction::from_str("block:451").unwrap(),
            ResponseAction::Block(StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS)
        );

        assert!(ResponseAction::from_str("block:0").is_err());
        assert!(ResponseAction::from_str("block:foo").is_err());
        assert!(ResponseAction::from_str("foo").is_err());
    }

    /// End-to-end evaluation of a `Ruleset`, including rule ordering and rate limiting.
    #[test]
    fn test_ruleset() {
        let ruleset = json!({
            "requests": [
            {
                "match": {
                    "methods": ["GET", "POST"],
                    "host": "^foo",
                    "path": "^/bar"
                },
                "action": "limit:global:10/1h",
            }]
        })
        .to_string();
        let ruleset: Ruleset = serde_json::from_str(&ruleset).unwrap();
        assert!(!ruleset.is_empty());

        let ruleset = json!({
            "responses": [
            {
                "match_req": {
                    "methods": ["OPTIONS"],
                },
                "match_resp": {
                    "status": ["100-200", "400-499", "599"],
                },
                "action": "block:499",
            }]
        })
        .to_string();
        let ruleset: Ruleset = serde_json::from_str(&ruleset).unwrap();
        assert!(!ruleset.is_empty());

        // Empty lists and missing lists both count as empty.
        let ruleset = json!({
            "requests": [],
            "responses": [],
        })
        .to_string();
        let ruleset: Ruleset = serde_json::from_str(&ruleset).unwrap();
        assert!(ruleset.is_empty());

        let ruleset = json!({}).to_string();
        let ruleset: Ruleset = serde_json::from_str(&ruleset).unwrap();
        assert!(ruleset.is_empty());

        let ruleset = json!({
            "requests": [
            {
                "match": {
                    "methods": ["GET", "POST"],
                    "host": "^foo",
                    "path": "^/bar"
                },
                "action": "limit:global:10/1h",
            },
            {
                "match": {
                    "methods": ["DELETE"],
                },
                "action": "block:403",
            }],

            "responses": [
            {
                "match_req": {
                    "methods": ["OPTIONS"],
                },
                "match_resp": {
                    "status": ["100-200", "400-499", "599"],
                },
                "action": "block:499",
            },
            {
                "match_resp": {
                    "status": ["100-200", "400-499", "599"],
                },
                "action": "block:451",
            },
            {
                "match_resp": {
                    "status": ["500"],
                    "headers": [{
                        "name": "foo",
                        "value": "bar.*",
                    }]
                },
                "action": "block:401",
            }]
        })
        .to_string();

        let ruleset: Ruleset = serde_json::from_str(&ruleset).unwrap();
        assert!(!ruleset.is_empty());

        // OPTIONS matches no request rule, so it always passes and never touches the limiter.
        for _ in 0..1000 {
            let req = Request::builder().method(Method::OPTIONS).body("").unwrap();
            assert_eq!(ruleset.evaluate_request(&req), Decision::Pass);
        }

        // DELETE is caught by the second request rule.
        for _ in 0..1000 {
            let req = Request::builder().method(Method::DELETE).body("").unwrap();
            assert_eq!(
                ruleset.evaluate_request(&req),
                Decision::Block(StatusCode::FORBIDDEN)
            );
        }

        // The global limiter allows a burst of 10 requests...
        for _ in 0..10 {
            let req = Request::builder()
                .method(Method::GET)
                .version(Version::HTTP_2)
                .uri("https://foo/bar")
                .body("")
                .unwrap();
            assert_eq!(ruleset.evaluate_request(&req), Decision::Pass);
        }

        // ...and the 11th has to wait for the next cell (1h / 10 = 360s).
        let req = Request::builder()
            .method(Method::GET)
            .version(Version::HTTP_2)
            .uri("https://foo/bar")
            .body("")
            .unwrap();

        let r = ruleset.evaluate_request(&req);
        match r {
            Decision::Throttle(v) => {
                assert!(v >= Duration::from_secs(359) && v <= Duration::from_secs(360))
            }
            _ => unreachable!(),
        }

        for _ in 0..1000 {
            let req = Request::builder()
                .method(Method::GET)
                .version(Version::HTTP_2)
                .uri("https://foo/bar")
                .body("")
                .unwrap();
            assert!(matches!(
                ruleset.evaluate_request(&req),
                Decision::Throttle(_)
            ));
        }

        let req = Request::builder().method(Method::POST).body(()).unwrap();

        let resp = Response::builder()
            .status(StatusCode::PERMANENT_REDIRECT)
            .body("")
            .unwrap();

        // 308 matches none of the response rules.
        for _ in 0..1000 {
            assert_eq!(ruleset.evaluate_response(&req, &resp), Decision::Pass);
        }

        // For a POST the OPTIONS-only first rule is skipped and the second rule applies.
        for _ in 0..1000 {
            let resp = Response::builder().status(StatusCode::OK).body("").unwrap();
            assert_eq!(
                ruleset.evaluate_response(&req, &resp),
                Decision::Block(StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS)
            );
        }

        // For an OPTIONS request the first response rule wins.
        let req = Request::builder().method(Method::OPTIONS).body(()).unwrap();

        for _ in 0..1000 {
            let resp = Response::builder().status(StatusCode::OK).body("").unwrap();
            assert_eq!(
                ruleset.evaluate_response(&req, &resp),
                Decision::Block(StatusCode::from_u16(499).unwrap())
            );
        }

        // 500 with a matching `foo` header reaches only the third rule.
        let resp = Response::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .header("foo", "bardead")
            .body("")
            .unwrap();

        for _ in 0..1000 {
            assert_eq!(
                ruleset.evaluate_response(&req, &resp),
                Decision::Block(StatusCode::UNAUTHORIZED)
            );
        }
    }

    /// The layer applies a ruleset set at runtime to requests routed through axum.
    #[tokio::test]
    async fn test_waflayer() {
        use axum::routing::get;

        let ruleset = Ruleset::default();
        let layer = WafLayer::new(
            Arc::new(ArcSwap::new(Arc::new(ruleset))),
            None,
            Duration::ZERO,
        );

        let ruleset = r#"
requests:
  - action: block:451
    match:
      methods:
        - OPTIONS
        - GET
      headers:
        - name: foo
          regex: ^bar.*$
"#;

        // The first update changes the ruleset; an identical second one is a no-op.
        assert!(layer.set_ruleset(parse_ruleset(ruleset.as_bytes()).unwrap()));
        assert!(!layer.set_ruleset(parse_ruleset(ruleset.as_bytes()).unwrap()));

        let mut router = Router::new()
            .route("/", get(|| async { "foo" }).options(|| async { "bar" }))
            .layer(layer);

        // Matching method and header: blocked.
        let req = Request::builder()
            .method(Method::OPTIONS)
            .header("foo", "barfuss")
            .body(Body::empty())
            .unwrap();
        let resp = router.call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS);

        let req = Request::builder()
            .method(Method::GET)
            .header("foo", "barfuss")
            .body(Body::empty())
            .unwrap();
        let resp = router.call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS);

        // Without the `foo` header the rule does not match and the handler runs.
        let req = Request::builder()
            .method(Method::OPTIONS)
            .body(Body::empty())
            .unwrap();
        let resp = router.call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
}