// api_gateway/middleware/license_validation.rs
use axum::extract::Request;
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use dashmap::DashMap;
use http::Method;
use modkit::api::{OperationSpec, Problem};
use std::sync::Arc;

use crate::middleware::common;

/// Identifier of the base license feature.
///
/// The license validation middleware in this file lets a request through only
/// when every feature required by the endpoint equals this value.
const BASE_FEATURE: &str = "gts.x.core.lic.feat.v1~x.core.global.base.v1";

14type LicenseKey = (Method, String);
15
16#[derive(Clone)]
17pub struct LicenseRequirementMap {
18 requirements: Arc<DashMap<LicenseKey, Vec<String>>>,
19}
20
21impl LicenseRequirementMap {
22 #[must_use]
23 pub fn from_specs(specs: &[OperationSpec]) -> Self {
24 let requirements = DashMap::new();
25
26 for spec in specs {
27 if let Some(req) = spec.license_requirement.as_ref() {
28 requirements.insert(
29 (spec.method.clone(), spec.path.clone()),
30 req.license_names.clone(),
31 );
32 }
33 }
34
35 Self {
36 requirements: Arc::new(requirements),
37 }
38 }
39
40 fn get(&self, method: &Method, path: &str) -> Option<Vec<String>> {
41 self.requirements
42 .get(&(method.clone(), path.to_owned()))
43 .map(|v| v.value().clone())
44 }
45}
46
47pub async fn license_validation_middleware(
48 map: LicenseRequirementMap,
49 req: Request,
50 next: Next,
51) -> Response {
52 let method = req.method().clone();
53 let path = req
54 .extensions()
55 .get::<axum::extract::MatchedPath>()
56 .map_or_else(|| req.uri().path().to_owned(), |p| p.as_str().to_owned());
57
58 let path = common::resolve_path(&req, path.as_str());
59
60 let Some(required) = map.get(&method, &path) else {
61 return next.run(req).await;
62 };
63
64 if required.iter().any(|r| r != BASE_FEATURE) {
68 return Problem::new(
69 StatusCode::FORBIDDEN,
70 "Forbidden",
71 format!(
72 "Endpoint requires unsupported license features '{required:?}'; only '{BASE_FEATURE}' is allowed",
73 ),
74 )
75 .into_response();
76 }
77
78 next.run(req).await
79}