Skip to main content

kagi_server/server/
mod.rs

1pub mod errors;
2pub mod routes;
3pub mod state;
4
5use crate::server::state::AppState;
6use std::net::SocketAddr;
7use std::path::Path;
8use tower_governor::GovernorLayer;
9use tower_governor::governor::GovernorConfigBuilder;
10use tower_http::limit::RequestBodyLimitLayer;
11use tower_http::trace::TraceLayer;
12
13pub async fn serve(
14    bind: SocketAddr,
15    db_path: &Path,
16    key_file_path: &Path,
17    max_body_size: usize,
18) -> Result<(), anyhow::Error> {
19    let state = AppState::new(db_path, key_file_path).await?;
20
21    let governor_conf = GovernorConfigBuilder::default()
22        .per_second(2)
23        .burst_size(30)
24        .finish()
25        .unwrap();
26
27    let app = routes::router(state.clone())
28        .layer(GovernorLayer::new(governor_conf))
29        .layer(RequestBodyLimitLayer::new(max_body_size))
30        .layer(TraceLayer::new_for_http());
31
32    tracing::info!("kagi: server key fingerprint {}", state.fingerprint);
33
34    let listener = tokio::net::TcpListener::bind(bind).await?;
35    let addr = listener.local_addr()?;
36    println!("kagi: server running on http://{}", addr);
37    tracing::info!("kagi: listening on http://{}", addr);
38
39    if bind.ip().is_unspecified() || !bind.ip().is_loopback() {
40        tracing::warn!(
41            "kagi: server bound to public interface without HTTPS. Application-layer encryption protects payloads, but HTTPS is recommended for metadata safety."
42        );
43    }
44
45    tracing::info!("kagi: server started successfully");
46    axum::serve(
47        listener,
48        app.into_make_service_with_connect_info::<SocketAddr>(),
49    )
50    .await?;
51    Ok(())
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;

    /// Body limit for tests that are not exercising the limit itself (10 MiB).
    const TEST_MAX_BODY_SIZE: usize = 10 * 1024 * 1024;

    /// Starts a test server whose rate limit is generous enough that ordinary tests never
    /// see 429. It allows a burst of 100 requests, then one more every 60 seconds.
    ///
    /// Returns the bound address plus the temp dirs holding the database and key file.
    /// Keep the dirs alive for the whole test, because dropping them deletes the files.
    async fn spawn_test_server(
        max_body_size: usize,
    ) -> (SocketAddr, tempfile::TempDir, tempfile::TempDir) {
        spawn_test_server_with_rate_limit(max_body_size, 60, 100).await
    }

    /// Starts a test server on an ephemeral loopback port with the given body limit and rate
    /// limit, using the same layer stack as `serve`.
    ///
    /// `replenish_interval_secs` is passed to tower_governor's `per_second`. That value is the
    /// number of seconds after which one request of quota is restored. `burst_size` is how
    /// many requests a client may make back-to-back.
    async fn spawn_test_server_with_rate_limit(
        max_body_size: usize,
        replenish_interval_secs: u64,
        burst_size: u32,
    ) -> (SocketAddr, tempfile::TempDir, tempfile::TempDir) {
        let db_dir = tempfile::TempDir::new().expect("create temp db dir");
        let key_dir = tempfile::TempDir::new().expect("create temp key dir");
        let db_path = db_dir.path().join("server.db");
        let key_path = key_dir.path().join("server.key");

        let state = AppState::new(&db_path, &key_path)
            .await
            .expect("initialize app state");
        let governor_conf = GovernorConfigBuilder::default()
            .per_second(replenish_interval_secs)
            .burst_size(burst_size)
            .finish()
            .expect("rate limit settings must be non-zero");
        let app = routes::router(state)
            .layer(GovernorLayer::new(governor_conf))
            .layer(RequestBodyLimitLayer::new(max_body_size))
            .layer(TraceLayer::new_for_http());

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind ephemeral port");
        let addr = listener.local_addr().expect("read bound address");
        tokio::spawn(async move {
            axum::serve(
                listener,
                app.into_make_service_with_connect_info::<SocketAddr>(),
            )
            .await
            .unwrap();
        });

        // No startup sleep is needed. The listener is already bound and listening, so client
        // connections queue in the OS backlog until the spawned task starts accepting them.
        (addr, db_dir, key_dir)
    }

    /// HTTP client that ignores proxy environment variables, so requests to 127.0.0.1 are
    /// never routed through a proxy.
    fn test_http_client() -> reqwest::Client {
        reqwest::Client::builder().no_proxy().build().unwrap()
    }

    #[tokio::test]
    async fn test_health_check_endpoint() {
        let (addr, _db_dir, _key_dir) = spawn_test_server(TEST_MAX_BODY_SIZE).await;
        let client = test_http_client();
        let resp = client
            .get(format!("http://{}/", addr))
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        let body = resp.text().await.unwrap();
        assert!(body.contains("Kagi"));
        assert!(body.contains("Secure secrets"));
    }

    #[tokio::test]
    async fn test_server_key_endpoint() {
        let (addr, _db_dir, _key_dir) = spawn_test_server(TEST_MAX_BODY_SIZE).await;
        let client = test_http_client();
        let resp = client
            .get(format!("http://{}/v1/server-key", addr))
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        let body: serde_json::Value = resp.json().await.unwrap();
        assert_eq!(body["version"], 1);
        assert!(body["server_key_id"].as_str().unwrap().starts_with("kgs_"));
        assert!(body["recipient"].as_str().unwrap().starts_with("age1"));
        assert!(!body["fingerprint"].as_str().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_oversized_request_body_rejected() {
        let (addr, _db_dir, _key_dir) = spawn_test_server(1024).await;
        let client = test_http_client();
        let large_body = serde_json::json!({"data": "x".repeat(2048) });
        let resp = client
            .post(format!("http://{}/v1/projects/kgp_test/push", addr))
            .json(&large_body)
            .send()
            .await
            .unwrap();
        // RequestBodyLimitLayer returns 413 Payload Too Large
        assert_eq!(resp.status(), 413);
    }

    #[tokio::test]
    async fn test_malformed_json_rejected() {
        let (addr, _db_dir, _key_dir) = spawn_test_server(TEST_MAX_BODY_SIZE).await;
        let client = test_http_client();
        let resp = client
            .post(format!("http://{}/v1/projects/kgp_test/push", addr))
            .header("Content-Type", "application/json")
            .body("not valid json {")
            .send()
            .await
            .unwrap();
        // Axum's Json extractor returns 400 Bad Request for malformed JSON
        assert_eq!(resp.status(), 400);
    }

    #[tokio::test]
    async fn test_encrypted_roundtrip_create_project_request() {
        use age::x25519;
        use kagi_sync::domain::envelope::{RequestPlaintext, ResponseEnvelope};
        use kagi_sync::infrastructure::remote_envelope::{decrypt_response, encrypt_request};
        use std::str::FromStr;
        use time::OffsetDateTime;

        let (addr, _db_dir, _key_dir) = spawn_test_server(TEST_MAX_BODY_SIZE).await;
        let client = test_http_client();

        // 1. Fetch server key
        let server_key_resp = client
            .get(format!("http://{}/v1/server-key", addr))
            .send()
            .await
            .unwrap();
        assert_eq!(server_key_resp.status(), 200);
        let server_key: serde_json::Value = server_key_resp.json().await.unwrap();
        let server_recipient_str = server_key["recipient"].as_str().unwrap();
        let server_recipient = x25519::Recipient::from_str(server_recipient_str).unwrap();

        // 2. Create client identity
        let client_identity = x25519::Identity::generate();
        let client_recipient = client_identity.to_public();

        // 3. Build plaintext
        let issued_at = OffsetDateTime::now_utc()
            .format(&time::format_description::well_known::Rfc3339)
            .unwrap();
        let alice_identity = x25519::Identity::generate();
        let alice_recipient = alice_identity.to_public().to_string();
        let plaintext = RequestPlaintext {
            version: 1,
            request_id: "kgr_test_1".into(),
            issued_at,
            operation: "create_project_request".into(),
            method: "POST".into(),
            path: "/v1/projects/requests".into(),
            project_id: Some("kgp_roundtrip".into()),
            token: None,
            claim_secret: None,
            payload: serde_json::json!({
                "requester_member_id": "kgm_alice",
                "requester_name": "Alice",
                "requester_recipient": alice_recipient,
                "claim_secret_hash": "cs:test",
            }),
        };

        // 4. Encrypt request and address it to the fetched server key
        let mut envelope =
            encrypt_request(&plaintext, &server_recipient, &client_recipient).unwrap();
        let server_key_id = server_key["server_key_id"].as_str().unwrap();
        envelope.server_key_id = server_key_id.into();

        // 5. Send encrypted request
        let resp = client
            .post(format!("http://{}/v1/projects/requests", addr))
            .json(&envelope)
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);

        // 6. Parse and decrypt response
        let response_envelope: ResponseEnvelope = resp.json().await.unwrap();
        assert_eq!(response_envelope.request_id, "kgr_test_1");

        let ciphertext =
            kagi_sync::domain::project_token::base64_decode_url(&response_envelope.ciphertext)
                .unwrap();
        let decrypted = decrypt_response(&ciphertext, &client_identity).unwrap();
        assert_eq!(decrypted["ok"], true);
        assert_eq!(decrypted["data"]["project_id"], "kgp_roundtrip");
        assert_eq!(decrypted["data"]["status"], "pending");
    }

    #[tokio::test]
    async fn test_rate_limit_rejects_excess_requests() {
        // Burst of 1 with a 1-second replenish interval: only one request is allowed
        // immediately.
        let (addr, _db_dir, _key_dir) =
            spawn_test_server_with_rate_limit(TEST_MAX_BODY_SIZE, 1, 1).await;
        let client = test_http_client();

        // First request should succeed
        let resp1 = client
            .get(format!("http://{}/v1/server-key", addr))
            .send()
            .await
            .unwrap();
        assert_eq!(resp1.status(), 200);

        // Immediate second request should be rate limited (429)
        let resp2 = client
            .get(format!("http://{}/v1/server-key", addr))
            .send()
            .await
            .unwrap();
        assert_eq!(resp2.status(), 429);
    }

    #[tokio::test]
    async fn test_metrics_endpoint_requires_auth() {
        let (addr, _db_dir, _key_dir) = spawn_test_server(TEST_MAX_BODY_SIZE).await;
        let client = test_http_client();
        // No auth header -> should fail
        let resp = client
            .get(format!("http://{}/v1/metrics", addr))
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status(), 401);
    }
}