// pubky_homeserver/admin_server/app.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::time::Duration;
4
5use super::routes::{
6    dav_handler, delete_entry,
7    disable_users::{disable_user, enable_user},
8    generate_signup_token, info, root, user_quota,
9};
10use super::trace::with_trace_layer;
11use super::{app_state::AppState, auth_middleware::AdminAuthLayer};
12use crate::AppContext;
13#[cfg(any(test, feature = "testing"))]
14use crate::MockDataDir;
15use crate::{AppContextConversionError, PersistentDataDir};
16use axum::routing::{any, delete, post};
17use axum::{routing::get, Router};
18use axum_server::Handle;
19use tokio::task::JoinHandle;
20use tower_http::cors::CorsLayer;
21
22/// Admin password protected router.
23fn create_protected_router(password: &str) -> Router<AppState> {
24    Router::new()
25        .route(
26            "/generate_signup_token",
27            get(generate_signup_token::generate_signup_token)
28                .post(generate_signup_token::generate_signup_token_with_limits),
29        )
30        .route("/info", get(info::info))
31        .route("/webdav/{*entry_path}", delete(delete_entry::delete_entry))
32        .route("/users/{pubkey}/disable", post(disable_user))
33        .route("/users/{pubkey}/enable", post(enable_user))
34        .route(
35            "/users/{pubkey}/quota",
36            get(user_quota::get_user_quota).patch(user_quota::patch_user_quota),
37        )
38        .layer(AdminAuthLayer::new(password.to_string()))
39}
40
41/// Public router without any authentication.
42/// NO PASSWORD PROTECTION!
43fn create_public_router() -> Router<AppState> {
44    Router::new().route("/", get(root::handler))
45}
46
47/// Create the app
48pub(crate) fn create_app(
49    state: AppState,
50    password: &str,
51) -> axum::routing::IntoMakeService<Router> {
52    let admin_router = create_protected_router(password);
53    let public_router = create_public_router();
54    let app = Router::new()
55        .merge(admin_router)
56        .merge(public_router)
57        .route("/dav{*path}", any(dav_handler::dav_handler))
58        .with_state(state)
59        .layer(CorsLayer::very_permissive());
60
61    with_trace_layer(app).into_make_service()
62}
63
64/// Errors that can occur when building a `AdminServer`.
65#[derive(thiserror::Error, Debug)]
66pub enum AdminServerBuildError {
67    /// Failed to create admin server.
68    #[error("Failed to create admin server: {0}")]
69    Server(anyhow::Error),
70
71    /// Failed to boostrap from the data directory.
72    #[error("Failed to boostrap from the data directory: {0}")]
73    DataDir(AppContextConversionError),
74}
75
/// Admin server
///
/// Most routes are protected by the admin auth middleware (`AdminAuthLayer`).
/// The root route `/` is public, and `/dav` is routed outside that layer.
///
/// When dropped, the server will stop.
pub struct AdminServer {
    /// Handle to the running HTTP server; used to request a graceful shutdown on drop.
    http_handle: Handle<SocketAddr>,
    /// Handle of the background task serving requests; aborted on drop.
    join_handle: JoinHandle<()>,
    /// Address the listener is actually bound to.
    ///
    /// This is read back via `local_addr()` after binding, so a configured port of 0
    /// shows up here as the OS-assigned port.
    socket: SocketAddr,
    /// Admin password; sent as the `X-Admin-Password` header by `create_signup_token`.
    password: String,
}
87
88impl AdminServer {
89    /// Create a new admin server from a data directory.
90    pub async fn from_data_dir(data_dir: PersistentDataDir) -> Result<Self, AdminServerBuildError> {
91        let context = AppContext::read_from(data_dir)
92            .await
93            .map_err(AdminServerBuildError::DataDir)?;
94        Self::start(&context).await
95    }
96
97    /// Create a new admin server from a data directory path.
98    pub async fn from_data_dir_path(data_dir_path: PathBuf) -> Result<Self, AdminServerBuildError> {
99        let data_dir = PersistentDataDir::new(data_dir_path);
100        Self::from_data_dir(data_dir).await
101    }
102
103    /// Create a new admin server from a mock data directory.
104    #[cfg(any(test, feature = "testing"))]
105    pub async fn from_mock_dir(mock_dir: MockDataDir) -> Result<Self, AdminServerBuildError> {
106        let context = AppContext::read_from(mock_dir)
107            .await
108            .map_err(AdminServerBuildError::DataDir)?;
109        Self::start(&context).await
110    }
111
112    /// Run the admin server.
113    pub async fn start(context: &AppContext) -> Result<Self, AdminServerBuildError> {
114        let password = context.config_toml.admin.admin_password.clone();
115        let state = AppState::new(
116            context.sql_db.clone(),
117            context.file_service.clone(),
118            &password,
119            context.user_service.clone(),
120        )
121        .with_metadata_from_config(
122            context.keypair.public_key().z32(),
123            &context.config_toml,
124            env!("CARGO_PKG_VERSION"),
125        );
126        let socket = context.config_toml.admin.listen_socket;
127        let app = create_app(state, password.as_str());
128        let listener = std::net::TcpListener::bind(socket)
129            .map_err(|e| AdminServerBuildError::Server(e.into()))?;
130        listener
131            .set_nonblocking(true)
132            .map_err(|e| AdminServerBuildError::Server(e.into()))?;
133        let socket = listener
134            .local_addr()
135            .map_err(|e| AdminServerBuildError::Server(e.into()))?;
136        let http_handle = Handle::new();
137        let inner_http_handle = http_handle.clone();
138        let server =
139            axum_server::from_tcp(listener).map_err(|e| AdminServerBuildError::Server(e.into()))?;
140        let join_handle = tokio::spawn(async move {
141            server
142                .handle(inner_http_handle)
143                .serve(app)
144                .await
145                .unwrap_or_else(|e| tracing::error!("Admin server error: {}", e));
146        });
147        Ok(Self {
148            http_handle,
149            socket,
150            join_handle,
151            password,
152        })
153    }
154
155    /// Get the socket address of the admin server.
156    pub fn listen_socket(&self) -> SocketAddr {
157        self.socket
158    }
159
160    /// Create a signup token for the given homeserver.
161    pub async fn create_signup_token(&self) -> anyhow::Result<String> {
162        let admin_socket = self.listen_socket();
163        let url = format!("http://{}/generate_signup_token", admin_socket);
164        let response = reqwest::Client::new()
165            .get(url)
166            .header("X-Admin-Password", &self.password)
167            .send()
168            .await?;
169        let response = response.error_for_status()?;
170        let body = response.text().await?;
171        Ok(body)
172    }
173}
174
175impl Drop for AdminServer {
176    fn drop(&mut self) {
177        self.http_handle
178            .graceful_shutdown(Some(Duration::from_secs(5)));
179        self.join_handle.abort();
180    }
181}
182
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use axum::http::Method;
    use axum_test::TestServer;
    use base64::Engine;

    use crate::data_directory::quota_config::BandwidthQuota;
    use crate::persistence::files::FileService;

    use super::*;

    /// Parse a bandwidth quota string such as `"200mb/m"`; panics on invalid input.
    fn bw(s: &str) -> BandwidthQuota {
        BandwidthQuota::from_str(s).unwrap()
    }

    /// Build an in-process test server for the admin app.
    ///
    /// Two different passwords are in play:
    /// - `"test"` is passed to `create_app`, so it guards the routes behind
    ///   `AdminAuthLayer` (sent as the `X-Admin-Password` header).
    /// - `""` is passed to `AppState::new`. The `/dav` tests authenticate with Basic
    ///   auth `admin:` (see `auth_header`), presumably checked against this value.
    fn create_test_server(context: &AppContext) -> TestServer {
        TestServer::new(create_app(
            AppState::new(
                context.sql_db.clone(),
                FileService::new_from_context(context).unwrap(),
                "",
                context.user_service.clone(),
            ),
            "test",
        ))
        .unwrap()
    }

    /// The public root route answers without any credentials.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_root() {
        let context = AppContext::test().await;
        let server = create_test_server(&context);
        let response = server.get("/").expect_success().await;
        response.assert_status_ok();
    }

    /// A protected route rejects a missing or wrong admin password with 401.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_generate_signup_token_fail() {
        let context = AppContext::test().await;
        let server = create_test_server(&context);
        // No password
        let response = server.get("/generate_signup_token").expect_failure().await;
        response.assert_status_unauthorized();

        // wrong password
        let response = server
            .get("/generate_signup_token")
            .add_header("X-Admin-Password", "wrongpassword")
            .expect_failure()
            .await;
        response.assert_status_unauthorized();
    }

    /// The correct admin password (the one given to `create_app`) is accepted.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_generate_signup_token_success() {
        let context = AppContext::test().await;
        let server = create_test_server(&context);
        let response = server
            .get("/generate_signup_token")
            .add_header("X-Admin-Password", "test")
            .expect_success()
            .await;
        response.assert_status_ok();
    }

    /// Basic-auth `Authorization` header value for user `admin` with an empty password.
    fn auth_header() -> String {
        // AppState is created with password "" in create_test_server
        let auth = base64::engine::general_purpose::STANDARD.encode("admin:");
        format!("Basic {auth}")
    }

    /// PROPFIND and GET on /dav/ root should succeed.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_dav_root_propfind_and_get() {
        let context = AppContext::test().await;
        let server = create_test_server(&context);
        let auth_value = auth_header();

        // PROPFIND is a WebDAV verb with no predefined `Method` constant.
        let propfind = Method::from_bytes(b"PROPFIND").unwrap();
        let response = server
            .method(propfind, "/dav/")
            .add_header("Authorization", auth_value.as_str())
            .add_header("Depth", "1")
            .expect_success()
            .await;
        // WebDAV PROPFIND returns 207 Multi-Status on success
        response.assert_status(axum::http::StatusCode::MULTI_STATUS);

        let response = server
            .get("/dav/")
            .add_header("Authorization", auth_value.as_str())
            .expect_success()
            .await;
        response.assert_status_ok();
    }

    /// PUT a file via WebDAV, GET it back, then DELETE it.
    ///
    /// Steps are order-dependent: each request observes the effect of the previous one.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_dav_put_get_delete_file() {
        use crate::persistence::sql::user::UserRepository;
        use pubky_common::crypto::Keypair;

        let context = AppContext::test().await;
        let server = create_test_server(&context);
        let auth_value = auth_header();

        // Register a user so writes are accepted by the entry layer
        let keypair = Keypair::from_secret(&[0; 32]);
        let pubkey = keypair.public_key();
        UserRepository::create(&pubkey, &mut context.sql_db.pool().into())
            .await
            .unwrap();

        let file_content = b"hello webdav";
        let file_url = format!("/dav/{}/pub/test.txt", pubkey.z32());

        // PUT a file
        let response = server
            .put(&file_url)
            .add_header("Authorization", auth_value.as_str())
            .bytes(file_content.to_vec().into())
            .expect_success()
            .await;
        response.assert_status(axum::http::StatusCode::CREATED);

        // GET it back
        let response = server
            .get(&file_url)
            .add_header("Authorization", auth_value.as_str())
            .expect_success()
            .await;
        response.assert_status_ok();
        assert_eq!(response.as_bytes().as_ref(), file_content);

        // PROPFIND on the user's pub directory should list the file
        let propfind = Method::from_bytes(b"PROPFIND").unwrap();
        let dir_url = format!("/dav/{}/pub/", pubkey.z32());
        let response = server
            .method(propfind, &dir_url)
            .add_header("Authorization", auth_value.as_str())
            .add_header("Depth", "1")
            .expect_success()
            .await;
        response.assert_status(axum::http::StatusCode::MULTI_STATUS);
        let body = response.text();
        assert!(body.contains("test.txt"), "PROPFIND should list the file");

        // DELETE the file
        let response = server
            .delete(&file_url)
            .add_header("Authorization", auth_value.as_str())
            .expect_success()
            .await;
        response.assert_status(axum::http::StatusCode::NO_CONTENT);

        // GET should now return 404
        let response = server
            .get(&file_url)
            .add_header("Authorization", auth_value.as_str())
            .expect_failure()
            .await;
        response.assert_status(axum::http::StatusCode::NOT_FOUND);
    }

    /// Exceeding user quota through the admin DAV endpoint currently returns 500.
    ///
    /// NOTE(review): this pins current behavior. A quota violation arguably deserves
    /// a 4xx/507 status rather than 500; update this test if the handler changes.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_dav_put_quota_overflow_returns_500() {
        use crate::persistence::sql::user::UserRepository;
        use pubky_common::crypto::Keypair;

        let mut context = AppContext::test().await;
        // Set before building the server; presumably picked up when
        // `create_test_server` constructs the FileService from the context. Verify.
        context.config_toml.storage.default_quota_mb = Some(1);
        let server = create_test_server(&context);
        let auth_value = auth_header();

        let keypair = Keypair::from_secret(&[0; 32]);
        let pubkey = keypair.public_key();
        UserRepository::create(&pubkey, &mut context.sql_db.pool().into())
            .await
            .unwrap();

        let pubkey = keypair.public_key().z32();
        let file1_url = format!("/dav/{pubkey}/pub/one.bin");
        let file2_url = format!("/dav/{pubkey}/pub/two.bin");
        // Each file fits in 1 MB alone. Two of them (1_200_000 bytes) exceed 1 MB
        // whether MB means 10^6 or 2^20 bytes.
        let file_content = vec![0u8; 600_000];

        let response = server
            .put(&file1_url)
            .add_header("Authorization", auth_value.as_str())
            .bytes(file_content.clone().into())
            .expect_success()
            .await;
        response.assert_status(axum::http::StatusCode::CREATED);

        let response = server
            .put(&file2_url)
            .add_header("Authorization", auth_value.as_str())
            .bytes(file_content.into())
            .expect_failure()
            .await;
        response.assert_status(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    /// POST to `/generate_signup_token` stores the custom limits on the created code.
    ///
    /// Limits omitted from the request come back as `QuotaOverride::Default`.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_generate_signup_token_with_limits() {
        use crate::persistence::sql::signup_code::{SignupCodeId, SignupCodeRepository};
        use crate::shared::user_quota::QuotaOverride;

        let context = AppContext::test().await;
        let server = create_test_server(&context);

        // POST with custom limits: null = Default, absent = Default, value = Value(T)
        let body = serde_json::json!({
            "storage_quota_mb": 1024,
            "rate_read": "200mb/m"
        });
        let response = server
            .post("/generate_signup_token")
            .add_header("X-Admin-Password", "test")
            .content_type("application/json")
            .bytes(serde_json::to_vec(&body).unwrap().into())
            .expect_success()
            .await;
        response.assert_status_ok();

        // Verify the code was created with custom limits
        let token_str = response.text();
        let code_id = SignupCodeId::new(token_str).unwrap();
        let code = SignupCodeRepository::get(&code_id, &mut context.sql_db.pool().into())
            .await
            .unwrap();
        let limits = code.quota();
        assert_eq!(limits.storage_quota_mb, QuotaOverride::Value(1024));
        assert_eq!(limits.rate_read, QuotaOverride::Value(bw("200mb/m")));
        assert_eq!(limits.rate_write, QuotaOverride::Default);
    }
}
424        assert_eq!(limits.storage_quota_mb, QuotaOverride::Value(1024));
425        assert_eq!(limits.rate_read, QuotaOverride::Value(bw("200mb/m")));
426        assert_eq!(limits.rate_write, QuotaOverride::Default);
427    }
428}