// kagi_server/server/mod.rs

1pub mod errors;
2pub mod routes;
3pub mod state;
4
5use crate::server::state::AppState;
6use std::net::SocketAddr;
7use std::path::Path;
8use tower_governor::GovernorLayer;
9use tower_governor::governor::GovernorConfigBuilder;
10use tower_http::limit::RequestBodyLimitLayer;
11use tower_http::trace::TraceLayer;
12
13pub async fn serve(
14    bind: SocketAddr,
15    db_path: &Path,
16    key_file_path: &Path,
17    max_body_size: usize,
18) -> Result<(), anyhow::Error> {
19    let state = AppState::new(db_path, key_file_path).await?;
20
21    let governor_conf = GovernorConfigBuilder::default()
22        .per_second(2)
23        .burst_size(30)
24        .finish()
25        .unwrap();
26
27    let app = routes::router(state.clone())
28        .layer(GovernorLayer::new(governor_conf))
29        .layer(RequestBodyLimitLayer::new(max_body_size))
30        .layer(TraceLayer::new_for_http());
31
32    tracing::info!("kagi: server key fingerprint {}", state.fingerprint);
33
34    let listener = tokio::net::TcpListener::bind(bind).await?;
35    let addr = listener.local_addr()?;
36    println!("kagi: server running on http://{addr}");
37    tracing::info!("kagi: listening on http://{}", addr);
38
39    if bind.ip().is_unspecified() || !bind.ip().is_loopback() {
40        tracing::warn!(
41            "kagi: server bound to public interface without HTTPS. Application-layer encryption protects payloads, but HTTPS is recommended for metadata safety."
42        );
43    }
44
45    tracing::info!("kagi: server started successfully");
46    axum::serve(
47        listener,
48        app.into_make_service_with_connect_info::<SocketAddr>(),
49    )
50    .await?;
51    Ok(())
52}
53
#[cfg(test)]
mod tests {
    use super::*;

    /// Body-size limit for tests that are not exercising `RequestBodyLimitLayer` (10 MiB).
    const DEFAULT_MAX_BODY: usize = 10 * 1024 * 1024;

    /// Spawns a test server with a lenient rate limit so functional tests are not throttled.
    ///
    /// The limit is a burst of 100 with one slot replenished every 60 s.
    ///
    /// Returns the bound address and the temp dirs holding the database and key file.
    /// Keep the dirs alive for as long as the server is in use.
    async fn spawn_test_server(
        max_body_size: usize,
    ) -> (SocketAddr, tempfile::TempDir, tempfile::TempDir) {
        spawn_test_server_with_rate_limit(max_body_size, 60, 100).await
    }

    /// Spawns a test server on an ephemeral loopback port with the given settings.
    ///
    /// The server uses the same layer stack as the production server: rate limiter,
    /// body limit, then tracing.
    ///
    /// `per_second` is tower_governor's replenish period in seconds, and `burst_size` is
    /// the number of requests allowed back-to-back per client key.
    async fn spawn_test_server_with_rate_limit(
        max_body_size: usize,
        per_second: u64,
        burst_size: u32,
    ) -> (SocketAddr, tempfile::TempDir, tempfile::TempDir) {
        let db_dir = tempfile::TempDir::new().unwrap();
        let key_dir = tempfile::TempDir::new().unwrap();
        let db_path = db_dir.path().join("server.db");
        let key_path = key_dir.path().join("server.key");

        let state = AppState::new(&db_path, &key_path).await.unwrap();
        let governor_conf = GovernorConfigBuilder::default()
            .per_second(per_second)
            .burst_size(burst_size)
            .finish()
            .unwrap();
        let app = routes::router(state)
            .layer(GovernorLayer::new(governor_conf))
            .layer(RequestBodyLimitLayer::new(max_body_size))
            .layer(TraceLayer::new_for_http());

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(
                listener,
                app.into_make_service_with_connect_info::<SocketAddr>(),
            )
            .await
            .unwrap();
        });

        // No start-up sleep is needed: the listener is already bound, so client
        // connections wait in the accept backlog until the server task starts accepting.
        (addr, db_dir, key_dir)
    }

    /// HTTP client that ignores proxy environment variables so requests reach loopback.
    fn test_http_client() -> reqwest::Client {
        reqwest::Client::builder().no_proxy().build().unwrap()
    }

    #[tokio::test]
    async fn test_health_check_endpoint() {
        let (addr, _db_dir, _key_dir) = spawn_test_server(DEFAULT_MAX_BODY).await;
        let client = test_http_client();
        let resp = client.get(format!("http://{addr}/")).send().await.unwrap();
        assert_eq!(resp.status(), 200);
        let body = resp.text().await.unwrap();
        assert!(body.contains("Kagi"));
        assert!(body.contains("Secure secrets"));
    }

    #[tokio::test]
    async fn test_server_key_endpoint() {
        let (addr, _db_dir, _key_dir) = spawn_test_server(DEFAULT_MAX_BODY).await;
        let client = test_http_client();
        let resp = client
            .get(format!("http://{addr}/v1/server-key"))
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        let body: serde_json::Value = resp.json().await.unwrap();
        assert_eq!(body["version"], 1);
        assert!(body["server_key_id"].as_str().unwrap().starts_with("kgs_"));
        assert!(body["recipient"].as_str().unwrap().starts_with("age1"));
        assert!(!body["fingerprint"].as_str().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_oversized_request_body_rejected() {
        let (addr, _db_dir, _key_dir) = spawn_test_server(1024).await;
        let client = test_http_client();
        let large_body = serde_json::json!({"data": "x".repeat(2048) });
        let resp = client
            .post(format!("http://{addr}/v1/projects/kgp_test/push"))
            .json(&large_body)
            .send()
            .await
            .unwrap();
        // RequestBodyLimitLayer returns 413 Payload Too Large
        assert_eq!(resp.status(), 413);
    }

    #[tokio::test]
    async fn test_malformed_json_rejected() {
        let (addr, _db_dir, _key_dir) = spawn_test_server(DEFAULT_MAX_BODY).await;
        let client = test_http_client();
        let resp = client
            .post(format!("http://{addr}/v1/projects/kgp_test/push"))
            .header("Content-Type", "application/json")
            .body("not valid json {")
            .send()
            .await
            .unwrap();
        // Axum's Json extractor returns 400 Bad Request for malformed JSON
        assert_eq!(resp.status(), 400);
    }

    #[tokio::test]
    async fn test_encrypted_roundtrip_create_project_request() {
        use age::x25519;
        use kagi_sync::domain::envelope::{RequestPlaintext, ResponseEnvelope};
        use kagi_sync::infrastructure::remote_envelope::{decrypt_response, encrypt_request};
        use std::str::FromStr;
        use time::OffsetDateTime;

        let (addr, _db_dir, _key_dir) = spawn_test_server(DEFAULT_MAX_BODY).await;
        let client = test_http_client();

        // 1. Fetch server key
        let server_key_resp = client
            .get(format!("http://{addr}/v1/server-key"))
            .send()
            .await
            .unwrap();
        assert_eq!(server_key_resp.status(), 200);
        let server_key: serde_json::Value = server_key_resp.json().await.unwrap();
        let server_recipient_str = server_key["recipient"].as_str().unwrap();
        let server_recipient = x25519::Recipient::from_str(server_recipient_str).unwrap();

        // 2. Create client identity
        let client_identity = x25519::Identity::generate();
        let client_recipient = client_identity.to_public();

        // 3. Build plaintext
        let issued_at = OffsetDateTime::now_utc()
            .format(&time::format_description::well_known::Rfc3339)
            .unwrap();
        let alice_identity = x25519::Identity::generate();
        let alice_recipient = alice_identity.to_public().to_string();
        let plaintext = RequestPlaintext {
            version: 1,
            request_id: "kgr_test_1".into(),
            issued_at,
            operation: "create_project_request".into(),
            method: "POST".into(),
            path: "/v1/projects/requests".into(),
            project_id: Some("kgp_roundtrip".into()),
            token: None,
            claim_secret: None,
            payload: serde_json::json!({
                "requester_member_id": "kgm_alice",
                "requester_name": "Alice",
                "requester_recipient": alice_recipient,
                "claim_secret_hash": "cs:test",
            }),
        };

        // 4. Encrypt request and address it to the fetched server key
        let mut envelope =
            encrypt_request(&plaintext, &server_recipient, &client_recipient).unwrap();
        envelope.server_key_id = server_key["server_key_id"].as_str().unwrap().into();

        // 5. Send encrypted request
        let resp = client
            .post(format!("http://{addr}/v1/projects/requests"))
            .json(&envelope)
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);

        // 6. Parse and decrypt response
        let response_envelope: ResponseEnvelope = resp.json().await.unwrap();
        assert_eq!(response_envelope.request_id, "kgr_test_1");

        let ciphertext =
            kagi_sync::domain::project_token::base64_decode_url(&response_envelope.ciphertext)
                .unwrap();
        let decrypted = decrypt_response(&ciphertext, &client_identity).unwrap();
        assert_eq!(decrypted["ok"], true);
        assert_eq!(decrypted["data"]["project_id"], "kgp_roundtrip");
        assert_eq!(decrypted["data"]["status"], "pending");
    }

    #[tokio::test]
    async fn test_rate_limit_rejects_excess_requests() {
        // A burst of one with a 60 s replenish period, so the slot cannot refill between
        // the two requests even on a slow CI machine.
        let (addr, _db_dir, _key_dir) =
            spawn_test_server_with_rate_limit(DEFAULT_MAX_BODY, 60, 1).await;
        let client = test_http_client();

        // First request should succeed
        let resp1 = client
            .get(format!("http://{addr}/v1/server-key"))
            .send()
            .await
            .unwrap();
        assert_eq!(resp1.status(), 200);

        // Immediate second request should be rate limited (429)
        let resp2 = client
            .get(format!("http://{addr}/v1/server-key"))
            .send()
            .await
            .unwrap();
        assert_eq!(resp2.status(), 429);
    }

    #[tokio::test]
    async fn test_metrics_endpoint_requires_auth() {
        let (addr, _db_dir, _key_dir) = spawn_test_server(DEFAULT_MAX_BODY).await;
        let client = test_http_client();
        // No auth header -> should fail
        let resp = client
            .get(format!("http://{addr}/v1/metrics"))
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status(), 401);
    }
}