switchgear_service/discovery/
service.rs

1use crate::axum::auth::BearerTokenAuthLayer;
2use crate::discovery::auth::DiscoveryBearerTokenValidator;
3use crate::discovery::handler::DiscoveryHandlers;
4use crate::discovery::state::DiscoveryState;
5use axum::routing::{delete, get, patch, post, put};
6use axum::Router;
7use switchgear_service_api::discovery::DiscoveryBackendStore;
8use switchgear_service_api::service::StatusCode;
9
10#[derive(Debug)]
11pub struct DiscoveryService;
12
13impl DiscoveryService {
14    pub fn router<S>(state: DiscoveryState<S>) -> Router
15    where
16        S: DiscoveryBackendStore + Clone + Send + Sync + 'static,
17    {
18        Router::new()
19            .route(
20                "/discovery/{public_key}",
21                get(DiscoveryHandlers::get_backend),
22            )
23            .route(
24                "/discovery/{public_key}",
25                put(DiscoveryHandlers::put_backend),
26            )
27            .route(
28                "/discovery/{public_key}",
29                patch(DiscoveryHandlers::patch_backend),
30            )
31            .route(
32                "/discovery/{public_key}",
33                delete(DiscoveryHandlers::delete_backend),
34            )
35            .route("/discovery", get(DiscoveryHandlers::get_backends))
36            .route("/discovery", post(DiscoveryHandlers::post_backend))
37            .layer(BearerTokenAuthLayer::new(
38                DiscoveryBearerTokenValidator::new(state.auth_authority().clone()),
39                "discovery",
40            ))
41            .route("/health", get(Self::health_check_handler))
42            .with_state(state)
43    }
44
45    async fn health_check_handler() -> StatusCode {
46        StatusCode::OK
47    }
48}
49
50#[cfg(test)]
51mod tests {
52    use crate::discovery::auth::{DiscoveryAudience, DiscoveryClaims};
53    use crate::discovery::service::DiscoveryService;
54    use crate::discovery::state::DiscoveryState;
55    use crate::testing::discovery::store::TestDiscoveryBackendStore;
56    use axum::http::StatusCode;
57    use axum_test::TestServer;
58    use jsonwebtoken::{encode, Algorithm, DecodingKey, EncodingKey, Header};
59    use p256::ecdsa::SigningKey;
60    use p256::pkcs8::EncodePrivateKey;
61    use p256::pkcs8::EncodePublicKey;
62    use rand::{thread_rng, Rng};
63    use secp256k1::{PublicKey, Secp256k1, SecretKey};
64    use std::time::{SystemTime, UNIX_EPOCH};
65    use switchgear_service_api::discovery::{
66        DiscoveryBackend, DiscoveryBackendPatchSparse, DiscoveryBackendSparse,
67    };
68
69    fn create_test_backend(partition: &str) -> DiscoveryBackend {
70        let secp = Secp256k1::new();
71        let mut rng = thread_rng();
72        let secret_key = SecretKey::from_byte_array(rng.gen::<[u8; 32]>()).unwrap();
73        let public_key = PublicKey::from_secret_key(&secp, &secret_key);
74
75        DiscoveryBackend {
76            public_key,
77            backend: DiscoveryBackendSparse {
78                name: None,
79                partitions: [partition.to_string()].into(),
80                weight: 100,
81                enabled: true,
82                implementation: "{}".as_bytes().to_vec(),
83            },
84        }
85    }
86
87    struct TestServerWithAuthorization {
88        server: TestServer,
89        authorization: String,
90    }
91
92    async fn setup_test_server() -> TestServerWithAuthorization {
93        let mut rng = thread_rng();
94        let private_key = SigningKey::random(&mut rng);
95        let public_key = *private_key.verifying_key();
96
97        let private_key = private_key
98            .to_pkcs8_pem(p256::pkcs8::LineEnding::default())
99            .unwrap();
100        let encoding_key = EncodingKey::from_ec_pem(private_key.as_bytes()).unwrap();
101
102        let public_key = public_key
103            .to_public_key_pem(p256::pkcs8::LineEnding::default())
104            .unwrap();
105        let decoding_key = DecodingKey::from_ec_pem(public_key.as_bytes()).unwrap();
106
107        let store = TestDiscoveryBackendStore::default();
108        let state = DiscoveryState::new(store, decoding_key);
109
110        let header = Header::new(Algorithm::ES256);
111        let claims = DiscoveryClaims {
112            aud: DiscoveryAudience::Discovery,
113            exp: (SystemTime::now()
114                .duration_since(UNIX_EPOCH)
115                .unwrap()
116                .as_secs()
117                + 3600) as usize,
118        };
119        let authorization = encode(&header, &claims, &encoding_key).unwrap();
120
121        let app = DiscoveryService::router(state);
122        TestServerWithAuthorization {
123            server: TestServer::new(app).unwrap(),
124            authorization,
125        }
126    }
127
128    #[tokio::test]
129    async fn health_check_when_called_then_returns_ok() {
130        let server = setup_test_server().await;
131
132        let response = server.server.get("/health").await;
133
134        assert_eq!(response.status_code(), StatusCode::OK);
135        // Health check returns empty body with 200 status
136        assert_eq!(response.text(), "");
137    }
138
139    #[tokio::test]
140    async fn get_backends_when_empty_then_returns_empty_list() {
141        let server = setup_test_server().await;
142
143        let response = server
144            .server
145            .get("/discovery")
146            .authorization_bearer(server.authorization.clone())
147            .await;
148
149        assert_eq!(response.status_code(), StatusCode::OK);
150        let backends: Vec<DiscoveryBackend> = response.json();
151        assert!(backends.is_empty());
152
153        // Verify cache headers
154        assert_eq!(
155            response.header("cache-control"),
156            "no-store, no-cache, must-revalidate"
157        );
158        assert_eq!(response.header("expires"), "Thu, 01 Jan 1970 00:00:00 GMT");
159        assert_eq!(response.header("pragma"), "no-cache");
160    }
161
162    #[tokio::test]
163    async fn post_backend_when_new_then_creates_and_returns_location() {
164        let server = setup_test_server().await;
165        let backend = create_test_backend("default");
166
167        let response = server
168            .server
169            .post("/discovery")
170            .authorization_bearer(server.authorization.clone())
171            .json(&backend)
172            .await;
173
174        assert_eq!(response.status_code(), StatusCode::CREATED);
175        let location = response.header("location").to_str().unwrap().to_string();
176        assert_eq!(location, backend.public_key.to_string());
177    }
178
179    #[tokio::test]
180    async fn post_backend_when_duplicate_then_returns_conflict() {
181        let server = setup_test_server().await;
182        let backend = create_test_backend("default");
183
184        // First POST should succeed
185        let response1 = server
186            .server
187            .post("/discovery")
188            .authorization_bearer(server.authorization.clone())
189            .json(&backend)
190            .await;
191        assert_eq!(response1.status_code(), StatusCode::CREATED);
192
193        // Second POST with same address should conflict
194        let response2 = server
195            .server
196            .post("/discovery")
197            .authorization_bearer(server.authorization.clone())
198            .json(&backend)
199            .await;
200        assert_eq!(response2.status_code(), StatusCode::CONFLICT);
201    }
202
203    #[tokio::test]
204    async fn get_backend_when_exists_then_returns_backend() {
205        let server = setup_test_server().await;
206        let backend = create_test_backend("default");
207
208        // First create the backend
209        let response = server
210            .server
211            .post("/discovery")
212            .authorization_bearer(server.authorization.clone())
213            .json(&backend)
214            .await;
215
216        let location = response.header("location");
217        let location = location.to_str().unwrap();
218
219        // Then retrieve it
220        let response = server
221            .server
222            .get(format!("/discovery/{location}").as_str())
223            .authorization_bearer(server.authorization.clone())
224            .await;
225
226        assert_eq!(response.status_code(), StatusCode::OK);
227        let retrieved: DiscoveryBackend = response.json();
228        assert_eq!(retrieved.public_key, backend.public_key);
229    }
230
231    #[tokio::test]
232    async fn get_backend_when_not_exists_then_returns_not_found() {
233        let server = setup_test_server().await;
234
235        let response = server
236            .server
237            .get("/discovery/default/inet/MTkyLjE2OC4xLjE6ODA4MA")
238            .authorization_bearer(server.authorization.clone())
239            .await;
240
241        assert_eq!(response.status_code(), StatusCode::NOT_FOUND);
242    }
243
244    #[tokio::test]
245    async fn put_backend_when_new_then_created() {
246        let server = setup_test_server().await;
247        let backend = create_test_backend("default");
248
249        let response = server
250            .server
251            .put(&format!("/discovery/{}", backend.public_key))
252            .authorization_bearer(server.authorization.clone())
253            .json(&backend.backend)
254            .await;
255
256        assert_eq!(response.status_code(), StatusCode::CREATED);
257    }
258
259    #[tokio::test]
260    async fn put_backend_when_exists_then_updates_no_content() {
261        let server = setup_test_server().await;
262        let mut backend = create_test_backend("default");
263
264        // Create initial backend
265        let response = server
266            .server
267            .post("/discovery")
268            .authorization_bearer(server.authorization.clone())
269            .json(&backend)
270            .await;
271
272        let location = response.header("location");
273        let location = location.to_str().unwrap();
274
275        // Update with PUT
276        backend.backend.weight = 200;
277        let response = server
278            .server
279            .put(&format!("/discovery/{location}"))
280            .authorization_bearer(server.authorization.clone())
281            .json(&backend.backend)
282            .await;
283
284        assert_eq!(response.status_code(), StatusCode::NO_CONTENT);
285
286        // Verify the update
287        let get_response = server
288            .server
289            .get(&format!("/discovery/{location}"))
290            .authorization_bearer(server.authorization.clone())
291            .await;
292        let updated: DiscoveryBackend = get_response.json();
293        assert_eq!(updated.backend.weight, 200);
294    }
295
296    #[tokio::test]
297    async fn patch_backend_then_no_content() {
298        let server = setup_test_server().await;
299        let mut backend = create_test_backend("default");
300
301        // Create initial backend
302        let response = server
303            .server
304            .post("/discovery")
305            .authorization_bearer(server.authorization.clone())
306            .json(&backend)
307            .await;
308
309        let location = response.header("location");
310        let location = location.to_str().unwrap();
311
312        let patch = DiscoveryBackendPatchSparse {
313            name: None,
314            partitions: None,
315            weight: Some(200),
316            enabled: None,
317        };
318        // Update with PATCH
319        backend.backend.weight = 200;
320        let response = server
321            .server
322            .patch(&format!("/discovery/{location}"))
323            .authorization_bearer(server.authorization.clone())
324            .json(&patch)
325            .await;
326
327        assert_eq!(response.status_code(), StatusCode::NO_CONTENT);
328
329        // Verify the update
330        let get_response = server
331            .server
332            .get(&format!("/discovery/{location}"))
333            .authorization_bearer(server.authorization.clone())
334            .await;
335        let updated: DiscoveryBackend = get_response.json();
336        assert_eq!(updated.backend.weight, 200);
337    }
338
339    #[tokio::test]
340    async fn patch_missing_backend_then_not_found() {
341        let server = setup_test_server().await;
342        let mut backend = create_test_backend("default");
343
344        let location = backend.public_key.to_string();
345
346        let patch = DiscoveryBackendPatchSparse {
347            name: None,
348            partitions: None,
349            weight: Some(200),
350            enabled: None,
351        };
352        // Update with PATCH
353        backend.backend.weight = 200;
354        let response = server
355            .server
356            .patch(&format!("/discovery/{location}"))
357            .authorization_bearer(server.authorization.clone())
358            .json(&patch)
359            .await;
360
361        assert_eq!(response.status_code(), StatusCode::NOT_FOUND);
362    }
363
364    #[tokio::test]
365    async fn delete_backend_when_exists_then_removes_and_returns_backend() {
366        let server = setup_test_server().await;
367        let backend = create_test_backend("default");
368
369        // Create backend
370        let response = server
371            .server
372            .post("/discovery")
373            .authorization_bearer(server.authorization.clone())
374            .json(&backend)
375            .await;
376        let location = response.header("location");
377        let location = location.to_str().unwrap();
378
379        // Delete backend
380        let response = server
381            .server
382            .delete(&format!("/discovery/{location}"))
383            .authorization_bearer(server.authorization.clone())
384            .await;
385        eprintln!("location: {location}");
386
387        assert_eq!(response.status_code(), StatusCode::NO_CONTENT);
388        // Delete returns empty body, no JSON to parse
389
390        // Verify it's gone
391        let get_response = server
392            .server
393            .get(&format!("/discovery/{location}"))
394            .authorization_bearer(server.authorization.clone())
395            .await;
396        assert_eq!(get_response.status_code(), StatusCode::NOT_FOUND);
397    }
398
399    #[tokio::test]
400    async fn delete_backend_when_not_exists_then_returns_not_found() {
401        let server = setup_test_server().await;
402
403        let response = server
404            .server
405            .delete("/discovery/default/inet/MTkyLjE2OC4xLjE6ODA4MA")
406            .authorization_bearer(server.authorization.clone())
407            .await;
408
409        assert_eq!(response.status_code(), StatusCode::NOT_FOUND);
410    }
411
412    #[tokio::test]
413    async fn get_backends_when_multiple_exist_then_returns_all() {
414        let server = setup_test_server().await;
415        let backend1 = create_test_backend("default");
416        let backend2 = create_test_backend("default");
417
418        // Sort backends by address before posting
419        let mut expected_backends = [backend1, backend2];
420        expected_backends.sort_by(|a, b| a.public_key.to_string().cmp(&b.public_key.to_string()));
421
422        // Create multiple backends
423        server
424            .server
425            .post("/discovery")
426            .authorization_bearer(server.authorization.clone())
427            .json(&expected_backends[0])
428            .await;
429        server
430            .server
431            .post("/discovery")
432            .authorization_bearer(server.authorization.clone())
433            .json(&expected_backends[1])
434            .await;
435
436        // Get all backends (first request)
437        let response = server
438            .server
439            .get("/discovery")
440            .authorization_bearer(server.authorization.clone())
441            .await;
442
443        assert_eq!(response.status_code(), StatusCode::OK);
444        let response_backends: Vec<DiscoveryBackend> = response.json();
445
446        assert_eq!(response_backends, expected_backends);
447
448        // Collect etag from first response
449        let etag = response.header("etag");
450
451        // Get all backends with IF_NONE_MATCH (second request)
452        let response2 = server
453            .server
454            .get("/discovery")
455            .authorization_bearer(server.authorization.clone())
456            .add_header(http::header::IF_NONE_MATCH, etag)
457            .await;
458
459        assert_eq!(response2.status_code(), StatusCode::NOT_MODIFIED);
460    }
461
462    #[tokio::test]
463    async fn api_when_invalid_json_then_returns_bad_request() {
464        let server = setup_test_server().await;
465
466        let response = server
467            .server
468            .post("/discovery")
469            .authorization_bearer(server.authorization.clone())
470            .text("invalid json")
471            .await;
472
473        assert_eq!(response.status_code(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
474    }
475
476    #[tokio::test]
477    async fn api_when_invalid_address_encoding_then_returns_bad_request() {
478        let server = setup_test_server().await;
479
480        let response = server
481            .server
482            .get("/discovery/default/inet/invalid_base64")
483            .authorization_bearer(server.authorization.clone())
484            .await;
485
486        assert_eq!(response.status_code(), StatusCode::NOT_FOUND);
487    }
488
489    #[tokio::test]
490    async fn api_when_unsupported_variant_then_returns_bad_request() {
491        let server = setup_test_server().await;
492
493        let response = server
494            .server
495            .get("/discovery/default/unsupported/dGVzdA")
496            .authorization_bearer(server.authorization.clone())
497            .await;
498
499        assert_eq!(response.status_code(), StatusCode::NOT_FOUND);
500    }
501
502    #[tokio::test]
503    async fn unauthorized() {
504        let server = setup_test_server().await;
505        let backend = create_test_backend("default");
506
507        let response = server.server.post("/discovery").json(&backend).await;
508
509        assert_eq!(response.status_code(), StatusCode::UNAUTHORIZED);
510
511        let response = server
512            .server
513            .get("/discovery/default/inet/MTkyLjE2OC4xLjE6ODA4MA")
514            .await;
515
516        assert_eq!(response.status_code(), StatusCode::UNAUTHORIZED);
517
518        let response = server
519            .server
520            .put("/discovery/default/inet/MTkyLjE2OC4xLjE6ODA4MA")
521            .json(&backend)
522            .await;
523
524        assert_eq!(response.status_code(), StatusCode::UNAUTHORIZED);
525
526        let response = server.server.delete("/discovery/default").await;
527
528        assert_eq!(response.status_code(), StatusCode::UNAUTHORIZED);
529    }
530}