Skip to main content

api_gateway/middleware/
license_validation.rs

1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::{IntoResponse, Response};
5use dashmap::DashMap;
6use http::Method;
7use modkit::api::{OperationSpec, Problem};
8use std::sync::Arc;
9
10use crate::middleware::common;
11
12const BASE_FEATURE: &str = "gts.x.core.lic.feat.v1~x.core.global.base.v1";
13
14type LicenseKey = (Method, String);
15
16#[derive(Clone)]
17pub struct LicenseRequirementMap {
18    requirements: Arc<DashMap<LicenseKey, Vec<String>>>,
19}
20
21impl LicenseRequirementMap {
22    #[must_use]
23    pub fn from_specs(specs: &[OperationSpec]) -> Self {
24        let requirements = DashMap::new();
25
26        for spec in specs {
27            if let Some(req) = spec.license_requirement.as_ref() {
28                requirements.insert(
29                    (spec.method.clone(), spec.path.clone()),
30                    req.license_names.clone(),
31                );
32            }
33        }
34
35        Self {
36            requirements: Arc::new(requirements),
37        }
38    }
39
40    fn get(&self, method: &Method, path: &str) -> Option<Vec<String>> {
41        self.requirements
42            .get(&(method.clone(), path.to_owned()))
43            .map(|v| v.value().clone())
44    }
45}
46
47pub async fn license_validation_middleware(
48    map: LicenseRequirementMap,
49    req: Request,
50    next: Next,
51) -> Response {
52    let method = req.method().clone();
53    let path = req
54        .extensions()
55        .get::<axum::extract::MatchedPath>()
56        .map_or_else(|| req.uri().path().to_owned(), |p| p.as_str().to_owned());
57
58    let path = common::resolve_path(&req, path.as_str());
59
60    let Some(required) = map.get(&method, &path) else {
61        return next.run(req).await;
62    };
63
64    // TODO: this is a stub implementation
65    // We need first to implement plugin and get its client from client_hub
66    // Plugin should provide an interface to get a list of global features (features that are not scoped to particular resource)
67    if required.iter().any(|r| r != BASE_FEATURE) {
68        return Problem::new(
69            StatusCode::FORBIDDEN,
70            "Forbidden",
71            format!(
72                "Endpoint requires unsupported license features '{required:?}'; only '{BASE_FEATURE}' is allowed",
73            ),
74        )
75        .into_response();
76    }
77
78    next.run(req).await
79}