covert_system/
router.rs

1use std::sync::Arc;
2
3use covert_framework::Backend;
4use covert_types::{error::ApiError, mount::MountConfig, request::Request};
5use dashmap::DashMap;
6use futures::future::BoxFuture;
7use tower::Service;
8use uuid::Uuid;
9
10use crate::{
11    error::{Error, ErrorType},
12    repos::{mount::MountRepo, namespace::Namespace},
13    response::{ResponseContext, ResponseWithCtx},
14    system::SYSTEM_MOUNT_PATH,
15};
16
17/// Router is used to do prefix based routing of a request to a logical backend
18pub struct Router {
19    // mount id -> Backend
20    backend_lookup: DashMap<String, Arc<Backend>>,
21    mount_repo: MountRepo,
22}
23
24impl Router {
25    #[must_use]
26    pub fn new(mount_repo: MountRepo) -> Self {
27        Router {
28            backend_lookup: DashMap::default(),
29            mount_repo,
30        }
31    }
32
33    #[tracing::instrument(
34        skip(self, req),
35        fields(
36            path = req.path,
37            operation = ?req.operation
38        )
39    )]
40    pub async fn route(&self, mut req: Request) -> Result<ResponseWithCtx, ApiError> {
41        let (backend, path, config) = match req.extensions.get::<Namespace>() {
42            Some(_) if req.path.starts_with(SYSTEM_MOUNT_PATH) => {
43                let backend = self
44                    .get_system_mount()
45                    .ok_or_else(ApiError::internal_error)?;
46
47                (
48                    backend,
49                    SYSTEM_MOUNT_PATH.to_string(),
50                    MountConfig::default(),
51                )
52            }
53            Some(ns) => {
54                let mount = self
55                    .mount_repo
56                    .longest_prefix(&req.path, &ns.id)
57                    .await?
58                    .ok_or_else(|| {
59                        Error::from(ErrorType::MountNotFound {
60                            path: req.path.clone(),
61                        })
62                    })?;
63                let backend = self
64                    .backend_lookup
65                    .get(&mount.id.to_string())
66                    .map(|b| Arc::clone(&b))
67                    .ok_or_else(ApiError::internal_error)?;
68
69                (backend, mount.path, mount.config)
70            }
71            // Namespace can be null if not unsealed
72            None => {
73                // Only system backend can handle requests when not unsealed
74                if !req.path.starts_with(SYSTEM_MOUNT_PATH) {
75                    return Err(ApiError::unauthorized());
76                }
77
78                let backend = self
79                    .get_system_mount()
80                    .ok_or_else(ApiError::internal_error)?;
81
82                (
83                    backend,
84                    SYSTEM_MOUNT_PATH.to_string(),
85                    MountConfig::default(),
86                )
87            }
88        };
89
90        req.advance_path(&path);
91        req.extensions.insert(config.clone());
92
93        let span = tracing::span!(
94            tracing::Level::DEBUG,
95            "backend_handle_request",
96            backend_mount_path = path,
97            backend_type = %backend.variant(),
98        );
99        let _enter = span.enter();
100
101        backend.handle_request(req).await.map(|response| {
102            let ctx = ResponseContext {
103                backend_config: config,
104                backend_mount_path: path,
105            };
106            ResponseWithCtx { response, ctx }
107        })
108    }
109
110    pub fn clear_mounts(&self) {
111        self.backend_lookup.clear();
112    }
113
114    pub fn mount(&self, mount_id: Uuid, backend: Arc<Backend>) {
115        self.backend_lookup.insert(mount_id.to_string(), backend);
116    }
117
118    pub fn mount_system(&self, backend: Arc<Backend>) {
119        self.backend_lookup.insert("system".to_string(), backend);
120    }
121
122    #[must_use]
123    pub fn get_system_mount(&self) -> Option<Arc<Backend>> {
124        self.backend_lookup.get("system").map(|b| Arc::clone(&b))
125    }
126
127    #[must_use]
128    pub fn remove(&self, mount_id: Uuid) -> bool {
129        self.backend_lookup.remove(&mount_id.to_string()).is_some()
130    }
131}
132
133#[derive(Clone)]
134pub struct RouterService(Arc<Router>);
135
136impl RouterService {
137    #[must_use]
138    pub fn new(router: Arc<Router>) -> Self {
139        Self(router)
140    }
141}
142
143impl Service<Request> for RouterService {
144    type Response = ResponseWithCtx;
145
146    type Error = ApiError;
147
148    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
149
150    fn poll_ready(
151        &mut self,
152        _cx: &mut std::task::Context<'_>,
153    ) -> std::task::Poll<Result<(), Self::Error>> {
154        std::task::Poll::Ready(Ok(()))
155    }
156
157    fn call(&mut self, req: Request) -> Self::Future {
158        let router = self.0.clone();
159        Box::pin(async move { router.route(req).await })
160    }
161}