1use std::sync::Arc;
2
3use covert_framework::Backend;
4use covert_types::{error::ApiError, mount::MountConfig, request::Request};
5use dashmap::DashMap;
6use futures::future::BoxFuture;
7use tower::Service;
8use uuid::Uuid;
9
10use crate::{
11 error::{Error, ErrorType},
12 repos::{mount::MountRepo, namespace::Namespace},
13 response::{ResponseContext, ResponseWithCtx},
14 system::SYSTEM_MOUNT_PATH,
15};
16
17pub struct Router {
19 backend_lookup: DashMap<String, Arc<Backend>>,
21 mount_repo: MountRepo,
22}
23
24impl Router {
25 #[must_use]
26 pub fn new(mount_repo: MountRepo) -> Self {
27 Router {
28 backend_lookup: DashMap::default(),
29 mount_repo,
30 }
31 }
32
33 #[tracing::instrument(
34 skip(self, req),
35 fields(
36 path = req.path,
37 operation = ?req.operation
38 )
39 )]
40 pub async fn route(&self, mut req: Request) -> Result<ResponseWithCtx, ApiError> {
41 let (backend, path, config) = match req.extensions.get::<Namespace>() {
42 Some(_) if req.path.starts_with(SYSTEM_MOUNT_PATH) => {
43 let backend = self
44 .get_system_mount()
45 .ok_or_else(ApiError::internal_error)?;
46
47 (
48 backend,
49 SYSTEM_MOUNT_PATH.to_string(),
50 MountConfig::default(),
51 )
52 }
53 Some(ns) => {
54 let mount = self
55 .mount_repo
56 .longest_prefix(&req.path, &ns.id)
57 .await?
58 .ok_or_else(|| {
59 Error::from(ErrorType::MountNotFound {
60 path: req.path.clone(),
61 })
62 })?;
63 let backend = self
64 .backend_lookup
65 .get(&mount.id.to_string())
66 .map(|b| Arc::clone(&b))
67 .ok_or_else(ApiError::internal_error)?;
68
69 (backend, mount.path, mount.config)
70 }
71 None => {
73 if !req.path.starts_with(SYSTEM_MOUNT_PATH) {
75 return Err(ApiError::unauthorized());
76 }
77
78 let backend = self
79 .get_system_mount()
80 .ok_or_else(ApiError::internal_error)?;
81
82 (
83 backend,
84 SYSTEM_MOUNT_PATH.to_string(),
85 MountConfig::default(),
86 )
87 }
88 };
89
90 req.advance_path(&path);
91 req.extensions.insert(config.clone());
92
93 let span = tracing::span!(
94 tracing::Level::DEBUG,
95 "backend_handle_request",
96 backend_mount_path = path,
97 backend_type = %backend.variant(),
98 );
99 let _enter = span.enter();
100
101 backend.handle_request(req).await.map(|response| {
102 let ctx = ResponseContext {
103 backend_config: config,
104 backend_mount_path: path,
105 };
106 ResponseWithCtx { response, ctx }
107 })
108 }
109
110 pub fn clear_mounts(&self) {
111 self.backend_lookup.clear();
112 }
113
114 pub fn mount(&self, mount_id: Uuid, backend: Arc<Backend>) {
115 self.backend_lookup.insert(mount_id.to_string(), backend);
116 }
117
118 pub fn mount_system(&self, backend: Arc<Backend>) {
119 self.backend_lookup.insert("system".to_string(), backend);
120 }
121
122 #[must_use]
123 pub fn get_system_mount(&self) -> Option<Arc<Backend>> {
124 self.backend_lookup.get("system").map(|b| Arc::clone(&b))
125 }
126
127 #[must_use]
128 pub fn remove(&self, mount_id: Uuid) -> bool {
129 self.backend_lookup.remove(&mount_id.to_string()).is_some()
130 }
131}
132
133#[derive(Clone)]
134pub struct RouterService(Arc<Router>);
135
136impl RouterService {
137 #[must_use]
138 pub fn new(router: Arc<Router>) -> Self {
139 Self(router)
140 }
141}
142
143impl Service<Request> for RouterService {
144 type Response = ResponseWithCtx;
145
146 type Error = ApiError;
147
148 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
149
150 fn poll_ready(
151 &mut self,
152 _cx: &mut std::task::Context<'_>,
153 ) -> std::task::Poll<Result<(), Self::Error>> {
154 std::task::Poll::Ready(Ok(()))
155 }
156
157 fn call(&mut self, req: Request) -> Self::Future {
158 let router = self.0.clone();
159 Box::pin(async move { router.route(req).await })
160 }
161}