1use std::io;
22use std::net::{IpAddr, SocketAddr};
23use std::sync::Arc;
24use std::time::Duration;
25
26use axum::extract::{Query, State};
27use axum::http::StatusCode;
28use axum::response::{IntoResponse, Response};
29use axum::routing::{get, post, put};
30use axum::{Json, Router};
31use eyre::Context;
32use serde::{Deserialize, Serialize};
33use tokio::net::TcpListener;
34use tokio::sync::mpsc;
35use tokio::task::{JoinError, JoinSet};
36use tokio_util::sync::CancellationToken;
37use tower_http::trace::TraceLayer;
38use tracing::{error, info};
39
40use crate::config::loader::ConfigEntry;
41use crate::error::ConfigError;
42use crate::store::StoreDir;
43
44use super::Config;
45use crate::config::file::CONFIG_FILE_NAME_NO_EXT;
46
/// Errors reported by the HTTP dynamic configuration server.
#[derive(thiserror::Error, Debug)]
pub enum HttpError {
    /// The listener couldn't be bound to the requested address, or its local
    /// address couldn't be read right after binding (see `serve`).
    #[error("couldn't bind the address {addr}")]
    Bind {
        /// The address that was requested.
        addr: SocketAddr,
        // A field named `source` is picked up by `thiserror` as the error source.
        source: io::Error,
    },
    /// Joining a task failed; built from a tokio `JoinError` via `#[from]`.
    #[error("server panicked")]
    Join(#[from] JoinError),
}
62
63#[derive(thiserror::Error, Debug, displaydoc::Display)]
65pub enum ErrorResponse {
66 InvalidConfig(#[from] ConfigError),
68 Serialize(#[from] toml::ser::Error),
70 Write(#[from] io::Error),
72 Channel,
74 MediaType,
76 Deserialize,
78}
79
80impl IntoResponse for ErrorResponse {
81 fn into_response(self) -> Response {
82 let (status, msg) = match self {
83 ErrorResponse::InvalidConfig(err) => (
84 StatusCode::BAD_REQUEST,
85 format!("Invalid configuration: {err}"),
86 ),
87 ErrorResponse::Serialize(err) => (
88 StatusCode::INTERNAL_SERVER_ERROR,
89 format!("Error in config serialization, {err}"),
90 ),
91 ErrorResponse::Write(err) => (
92 StatusCode::INTERNAL_SERVER_ERROR,
93 format!("Unable to write in toml file, {err}"),
94 ),
95 ErrorResponse::Channel => (
96 StatusCode::INTERNAL_SERVER_ERROR,
97 "Channel error".to_string(),
98 ),
99 ErrorResponse::MediaType => (
100 StatusCode::UNSUPPORTED_MEDIA_TYPE,
101 "Config files must be either JSON or TOML".to_string(),
102 ),
103 ErrorResponse::Deserialize => (
104 StatusCode::BAD_REQUEST,
105 "Invalid configuration file".to_string(),
106 ),
107 };
108
109 let t = (
110 status,
111 Json(ConfigResponse {
112 result: "KO".to_string(),
113 message: Some(msg),
114 }),
115 );
116
117 t.into_response()
118 }
119}
120
/// JSON body returned by the configuration endpoints.
///
/// Successful responses carry `result: "OK"` and no message (the `Default`
/// value). Error responses carry `result: "KO"` and a message describing the
/// failure.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct ConfigResponse {
    /// Either `"OK"` or `"KO"`.
    result: String,
    /// Human-readable detail; `None` on success.
    message: Option<String>,
}
126
127impl Default for ConfigResponse {
128 fn default() -> Self {
129 ConfigResponse {
130 result: "OK".to_string(),
131 message: None,
132 }
133 }
134}
135
136#[derive(Debug)]
137struct ConfigServer {
138 tx: mpsc::Sender<ConfigEntry>,
139 store_dir: StoreDir,
140}
141
142impl ConfigServer {
143 fn new(tx: mpsc::Sender<ConfigEntry>, store_dir: StoreDir) -> Self {
144 Self { tx, store_dir }
145 }
146
147 async fn send_config(&self, config: Config, file_name: &str) -> Result<(), ErrorResponse> {
148 let path = self.store_dir.dynamic_config_file(file_name);
149
150 let entry = ConfigEntry::new(path, config);
151
152 self.tx
153 .send_timeout(entry, Duration::from_secs(10))
154 .await
155 .map_err(|error| {
156 error!(%error, "couldn't send configuration");
157
158 ErrorResponse::Channel
159 })?;
160
161 Ok(())
162 }
163}
164
/// JSON payload accepted by `POST /config` and `PUT /config/upload/{file_name}`.
///
/// `realm` and `pairing_url` are required. Every other field may be omitted.
/// The payload is turned into a validated [`Config`] through `TryFrom`.
// NOTE(review): unknown fields are silently ignored (no `deny_unknown_fields`);
// the tests rely on this by posting a full serialized `Config`.
#[derive(Debug, Clone, Deserialize)]
struct ConfigPayload {
    realm: String,
    device_id: Option<String>,
    credentials_secret: Option<String>,
    pairing_url: String,
    pairing_token: Option<String>,
    grpc_socket_host: Option<IpAddr>,
    grpc_socket_port: Option<u16>,
}
175
176impl TryFrom<ConfigPayload> for Config {
177 type Error = ErrorResponse;
178
179 fn try_from(value: ConfigPayload) -> Result<Self, Self::Error> {
180 let ConfigPayload {
181 realm,
182 device_id,
183 credentials_secret,
184 pairing_url,
185 pairing_token,
186 grpc_socket_host,
187 grpc_socket_port,
188 } = value;
189
190 let config = Self {
191 realm: Some(realm),
192 device_id,
193 credentials_secret,
194 pairing_url: Some(pairing_url),
195 pairing_token,
196 grpc_socket_host,
197 grpc_socket_port,
198 ..Default::default()
199 };
200
201 config.validate()?;
202
203 Ok(config)
204 }
205}
206
207pub async fn serve(
209 tasks: &mut JoinSet<eyre::Result<()>>,
210 cancel: CancellationToken,
211 address: &SocketAddr,
212 tx: mpsc::Sender<ConfigEntry>,
213 store_dir: StoreDir,
214) -> Result<SocketAddr, HttpError> {
215 let cfg_server = ConfigServer::new(tx, store_dir);
216
217 let app = Router::new()
218 .route("/", get(root))
219 .route("/config", post(set_config))
220 .route("/config/upload/{file_name}", put(upload_config))
221 .layer(TraceLayer::new_for_http())
222 .with_state(Arc::new(cfg_server));
223
224 let listener = TcpListener::bind(address)
225 .await
226 .map_err(|e| HttpError::Bind {
227 addr: *address,
228 source: e,
229 })?;
230
231 let local_addr = listener.local_addr().map_err(|error| {
232 error!(%error, "couldn't get binded address");
233
234 HttpError::Bind {
235 addr: *address,
236 source: error,
237 }
238 })?;
239
240 info!("HTTP dynamic config server listening on http://{local_addr}");
241
242 tasks.spawn(async move {
243 axum::serve(listener, app)
244 .with_graceful_shutdown(async move {
245 cancel.cancelled().await;
246
247 info!("HTTP server exiting");
248 })
249 .await
250 .wrap_err("couldn't run HTTP dynamic config server")
251 });
252
253 Ok(local_addr)
254}
255
256#[derive(Debug, Deserialize)]
257struct UploadQuery {
258 #[serde(default = "UploadQuery::default_store")]
259 store: bool,
260}
261
262impl UploadQuery {
263 fn default_store() -> bool {
264 true
265 }
266}
267
268async fn root() -> (StatusCode, Json<ConfigResponse>) {
270 (StatusCode::OK, Json(ConfigResponse::default()))
271}
272
273async fn set_config(
275 State(state): State<Arc<ConfigServer>>,
276 Query(query): Query<UploadQuery>,
277 Json(payload): Json<ConfigPayload>,
278) -> Result<(StatusCode, Json<ConfigResponse>), ErrorResponse> {
279 let config = Config::try_from(payload)?;
280
281 if query.store {
282 state
283 .store_dir
284 .store_config(&config, CONFIG_FILE_NAME_NO_EXT)
285 .await;
286 }
287
288 state.send_config(config, CONFIG_FILE_NAME_NO_EXT).await?;
289
290 Ok((StatusCode::OK, Json(ConfigResponse::default())))
291}
292
293async fn upload_config(
294 State(state): State<Arc<ConfigServer>>,
295 axum::extract::Path(file_name): axum::extract::Path<String>,
296 Query(query): Query<UploadQuery>,
297 Json(payload): Json<ConfigPayload>,
298) -> Result<StatusCode, ErrorResponse> {
299 let file_name = file_name.strip_suffix(".json").unwrap_or(&file_name);
300 let file_name = file_name.strip_suffix(".toml").unwrap_or(file_name);
301
302 let config = Config::try_from(payload)?;
303
304 if query.store {
305 state.store_dir.store_config(&config, file_name).await;
306 }
307
308 state.send_config(config, file_name).await?;
309
310 Ok(StatusCode::NO_CONTENT)
311}
312
#[cfg(test)]
mod test {
    use std::collections::HashMap;
    use std::time::Duration;

    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use serde_json::{Map, Number, Value};
    use tempfile::TempDir;

    use super::*;

    /// A server started on an OS-assigned local port, with everything a test
    /// needs to drive it and inspect its effects.
    struct TestServer {
        /// Holds the spawned server task.
        tasks: JoinSet<eyre::Result<()>>,
        /// Cancelling it triggers the server's graceful shutdown.
        cancel_token: CancellationToken,
        /// Receiving end of the channel the handlers forward configs to.
        rx: mpsc::Receiver<ConfigEntry>,
        /// Address the server actually bound to.
        address: SocketAddr,
        // Owned so the temporary store directory lives as long as the test.
        dir: TempDir,
    }

    impl TestServer {
        /// Creates a temporary store directory and starts the server on
        /// `127.0.0.1:0` with a config channel of capacity 1.
        async fn serve() -> Self {
            let dir = TempDir::new().unwrap();

            let mut tasks = JoinSet::new();
            let cancel_token = CancellationToken::new();
            let (tx, rx) = tokio::sync::mpsc::channel(1);

            let store_dir = StoreDir::create(dir.path().to_path_buf()).await.unwrap();

            // Port 0: the real address is returned by `serve`.
            let address = serve(
                &mut tasks,
                cancel_token.clone(),
                &"127.0.0.1:0".parse().unwrap(),
                tx,
                store_dir,
            )
            .await
            .expect("failed to create server");

            Self {
                tasks,
                cancel_token,
                rx,
                address,
                dir,
            }
        }
    }

    /// `POST /config` answers `OK`, forwards the config and persists it.
    #[rstest]
    #[timeout(Duration::from_secs(2))]
    #[tokio::test]
    async fn server_test() {
        // The result is discarded on purpose: installing the provider can fail
        // when one is already installed for the process.
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

        let mut server = TestServer::serve().await;

        let exp = Config {
            realm: Some("realm".to_string()),
            device_id: Some("device_id".to_string()),
            pairing_url: Some("pairing_url".to_string()),
            credentials_secret: Some("credentials_secret".to_string()),
            ..Default::default()
        };

        let client = reqwest::Client::new();

        // A serialized `Config` is posted as the payload; presumably its field
        // names match `ConfigPayload` (extra fields are ignored).
        let resp = client
            .post(format!("http://{}/config", server.address))
            .json(&exp)
            .send()
            .await
            .unwrap()
            .error_for_status()
            .unwrap();

        let json: ConfigResponse = resp.json().await.unwrap();
        assert_eq!(json.result, "OK".to_string());

        // The handler awaits the send before responding, so the entry is
        // already in the channel.
        let config = server.rx.try_recv().unwrap();

        assert_eq!(config.config, exp);

        server.cancel_token.cancel();

        // The server task must exit cleanly after cancellation.
        server.tasks.join_next().await.unwrap().unwrap().unwrap();

        // Presumably the file for `CONFIG_FILE_NAME_NO_EXT` (stored because
        // `store` defaults to true) — TODO confirm the exact name mapping.
        let config: Config = toml::from_str(
            &tokio::fs::read_to_string(server.dir.path().join("config/50-message-hub-config.toml"))
                .await
                .unwrap(),
        )
        .unwrap();

        assert_eq!(config, exp);
    }

    /// `PUT /config/upload/{file_name}` strips the extension, answers
    /// `204 No Content`, forwards the config and persists it under that name.
    #[rstest]
    #[timeout(Duration::from_secs(2))]
    #[tokio::test]
    async fn server_upload_test() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

        let mut server = TestServer::serve().await;

        let exp = Config {
            realm: Some("realm".to_string()),
            device_id: Some("device_id".to_string()),
            pairing_url: Some("pairing_url".to_string()),
            credentials_secret: Some("credentials_secret".to_string()),
            ..Default::default()
        };

        let client = reqwest::Client::new();

        let resp = client
            .put(format!(
                "http://{}/config/upload/99-custom.toml",
                server.address
            ))
            .json(&exp)
            .send()
            .await
            .unwrap()
            .error_for_status()
            .unwrap();

        assert_eq!(resp.status(), StatusCode::NO_CONTENT);

        let config = server.rx.try_recv().unwrap();

        assert_eq!(config.config, exp);

        server.cancel_token.cancel();

        server.tasks.join_next().await.unwrap().unwrap().unwrap();

        let config: Config = toml::from_str(
            &tokio::fs::read_to_string(server.dir.path().join("config/99-custom.toml"))
                .await
                .unwrap(),
        )
        .unwrap();

        assert_eq!(config, exp);
    }

    /// A payload missing the required `realm` field is rejected.
    #[rstest]
    #[timeout(Duration::from_secs(2))]
    #[tokio::test]
    async fn bad_request_test() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

        let mut server = TestServer::serve().await;

        let mut body = HashMap::new();
        body.insert("device_id", "device_id");
        body.insert("pairing_url", "pairing_url");

        let client = reqwest::Client::new();
        let resp = client
            .post(format!("http://{}/config", server.address))
            .json(&body)
            .send()
            .await
            .unwrap();

        // Only non-success is asserted; the exact code comes from the JSON
        // extractor's rejection.
        let status = resp.status();
        assert!(!status.is_success());

        server.cancel_token.cancel();

        server.tasks.join_next().await.unwrap().unwrap().unwrap();
    }

    /// A well-formed payload that fails `Config::validate` yields `400` with a
    /// `KO` body. Presumably it fails because `realm` is empty — TODO confirm
    /// against `validate`.
    #[rstest]
    #[timeout(Duration::from_secs(2))]
    #[tokio::test]
    async fn test_set_config_invalid_cfg() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

        let mut server = TestServer::serve().await;

        let mut body = Map::new();
        body.insert("realm".to_string(), Value::String("".to_string()));
        body.insert(
            "device_id".to_string(),
            Value::String("device_id".to_string()),
        );
        body.insert(
            "credentials_secret".to_string(),
            Value::String("credentials_secret".to_string()),
        );
        body.insert(
            "pairing_url".to_string(),
            Value::String("pairing_url".to_string()),
        );
        body.insert(
            "grpc_socket_port".to_string(),
            Value::Number(Number::from(22_u16)),
        );

        let client = reqwest::Client::new();
        let resp = client
            .post(format!("http://{}/config", server.address))
            .json(&body)
            .send()
            .await
            .unwrap();

        let status = resp.status();
        assert_eq!(status, reqwest::StatusCode::BAD_REQUEST);
        let json: ConfigResponse = resp.json().await.unwrap();
        assert_eq!(json.result, "KO".to_string());

        server.cancel_token.cancel();

        server.tasks.join_next().await.unwrap().unwrap().unwrap();
    }
}