api_gateway/middleware/
license_validation.rs1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::{IntoResponse, Response};
5use dashmap::DashMap;
6use http::Method;
7use std::sync::Arc;
8
9use modkit::api::{OperationSpec, Problem};
10
/// The only license feature currently accepted by
/// [`license_validation_middleware`]; an endpoint that requires any other
/// feature is rejected.
const BASE_FEATURE: &str = "gts.x.core.lic.feat.v1~x.core.global.base.v1";

/// Lookup key for [`LicenseRequirementMap`]: the HTTP method paired with the
/// route path as recorded in `OperationSpec::path`.
type LicenseKey = (Method, String);
14
15#[derive(Clone)]
16pub struct LicenseRequirementMap {
17 requirements: Arc<DashMap<LicenseKey, Vec<String>>>,
18}
19
20impl LicenseRequirementMap {
21 #[must_use]
22 pub fn from_specs(specs: &[OperationSpec]) -> Self {
23 let requirements = DashMap::new();
24
25 for spec in specs {
26 if let Some(req) = spec.license_requirement.as_ref() {
27 requirements.insert(
28 (spec.method.clone(), spec.path.clone()),
29 req.license_names.clone(),
30 );
31 }
32 }
33
34 Self {
35 requirements: Arc::new(requirements),
36 }
37 }
38
39 fn get(&self, method: &Method, path: &str) -> Option<Vec<String>> {
40 self.requirements
41 .get(&(method.clone(), path.to_owned()))
42 .map(|v| v.value().clone())
43 }
44}
45
46pub async fn license_validation_middleware(
47 map: LicenseRequirementMap,
48 req: Request,
49 next: Next,
50) -> Response {
51 let method = req.method().clone();
52 let path = req
53 .extensions()
54 .get::<axum::extract::MatchedPath>()
55 .map_or_else(|| req.uri().path().to_owned(), |p| p.as_str().to_owned());
56
57 let Some(required) = map.get(&method, &path) else {
58 return next.run(req).await;
59 };
60
61 if required.iter().any(|r| r != BASE_FEATURE) {
65 return Problem::new(
66 StatusCode::FORBIDDEN,
67 "Forbidden",
68 format!(
69 "Endpoint requires unsupported license features '{required:?}'; only '{BASE_FEATURE}' is allowed",
70 ),
71 )
72 .into_response();
73 }
74
75 next.run(req).await
76}