Skip to main content

api_gateway/middleware/
license_validation.rs

1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::{IntoResponse, Response};
5use dashmap::DashMap;
6use http::Method;
7use std::sync::Arc;
8
9use modkit::api::{OperationSpec, Problem};
10
/// Identifier of the base license feature.
///
/// The validation middleware in this file is currently a stub that accepts only
/// this feature: an endpoint requiring any other feature is rejected with
/// `403 Forbidden`.
const BASE_FEATURE: &str = "gts.x.core.lic.feat.v1~x.core.global.base.v1";
12
13type LicenseKey = (Method, String);
14
15#[derive(Clone)]
16pub struct LicenseRequirementMap {
17    requirements: Arc<DashMap<LicenseKey, Vec<String>>>,
18}
19
20impl LicenseRequirementMap {
21    #[must_use]
22    pub fn from_specs(specs: &[OperationSpec]) -> Self {
23        let requirements = DashMap::new();
24
25        for spec in specs {
26            if let Some(req) = spec.license_requirement.as_ref() {
27                requirements.insert(
28                    (spec.method.clone(), spec.path.clone()),
29                    req.license_names.clone(),
30                );
31            }
32        }
33
34        Self {
35            requirements: Arc::new(requirements),
36        }
37    }
38
39    fn get(&self, method: &Method, path: &str) -> Option<Vec<String>> {
40        self.requirements
41            .get(&(method.clone(), path.to_owned()))
42            .map(|v| v.value().clone())
43    }
44}
45
46pub async fn license_validation_middleware(
47    map: LicenseRequirementMap,
48    req: Request,
49    next: Next,
50) -> Response {
51    let method = req.method().clone();
52    let path = req
53        .extensions()
54        .get::<axum::extract::MatchedPath>()
55        .map_or_else(|| req.uri().path().to_owned(), |p| p.as_str().to_owned());
56
57    let Some(required) = map.get(&method, &path) else {
58        return next.run(req).await;
59    };
60
61    // TODO: this is a stub implementation
62    // We need first to implement plugin and get its client from client_hub
63    // Plugin should provide an interface to get a list of global features (features that are not scoped to particular resource)
64    if required.iter().any(|r| r != BASE_FEATURE) {
65        return Problem::new(
66            StatusCode::FORBIDDEN,
67            "Forbidden",
68            format!(
69                "Endpoint requires unsupported license features '{required:?}'; only '{BASE_FEATURE}' is allowed",
70            ),
71        )
72        .into_response();
73    }
74
75    next.run(req).await
76}