posemesh_node_registration/
http.rs

1use crate::persist::{read_node_secret_from_path, write_node_secret_to_path};
2use crate::state::touch_healthcheck_now;
3use axum::extract::State;
4use axum::http::{header::USER_AGENT, HeaderMap, StatusCode};
5use axum::response::IntoResponse;
6use axum::routing::{get, post};
7use axum::Json;
8use serde::{Deserialize, Serialize};
9use std::path::PathBuf;
10use tracing::{debug, info, warn};
11
/// Shared state handed to every route built by `router_dds`.
///
/// Derives `Debug` in addition to `Clone` so the state can be included in
/// diagnostics; it only holds a filesystem path, never the secret itself.
#[derive(Clone, Debug)]
pub struct DdsState {
    /// Location passed to the persist helpers when storing or reading back
    /// the node secret.
    pub secret_path: PathBuf,
}
16
17#[derive(Debug, Deserialize)]
18pub struct RegistrationCallbackRequest {
19    pub id: String,
20    pub secret: String,
21    pub organization_id: Option<String>,
22    pub lighthouses_in_domains: Option<serde_json::Value>,
23    pub domains: Option<serde_json::Value>,
24}
25
26#[derive(Debug, Serialize)]
27struct OkResponse {
28    ok: bool,
29}
30
/// Failure outcomes of the registration callback.
///
/// Each variant carries the static message that becomes the response body;
/// the HTTP status is chosen by the `IntoResponse` implementation.
#[derive(Debug)]
enum CallbackError {
    /// Rendered as HTTP 422.
    Unprocessable(&'static str),
    /// Rendered as HTTP 403.
    Forbidden(&'static str),
    /// Rendered as HTTP 409.
    Conflict(&'static str),
}
37
38impl IntoResponse for CallbackError {
39    fn into_response(self) -> axum::response::Response {
40        let (status, msg) = match self {
41            CallbackError::Unprocessable(m) => (StatusCode::UNPROCESSABLE_ENTITY, m),
42            CallbackError::Forbidden(m) => (StatusCode::FORBIDDEN, m),
43            CallbackError::Conflict(m) => (StatusCode::CONFLICT, m),
44        };
45        (status, msg).into_response()
46    }
47}
48
49pub fn router_dds(state: DdsState) -> axum::Router {
50    axum::Router::new()
51        .route("/internal/v1/registrations", post(callback_registration))
52        .route("/health", get(health))
53        .with_state(state)
54}
55
56async fn health(State(_state): State<DdsState>, headers: HeaderMap) -> impl IntoResponse {
57    let ua = headers
58        .get(USER_AGENT)
59        .and_then(|v| v.to_str().ok())
60        .unwrap_or("");
61    if ua.starts_with("DDS v") {
62        match touch_healthcheck_now() {
63            Ok(()) => debug!(
64                event = "healthcheck.touch",
65                user_agent = ua,
66                "last_healthcheck updated via /health"
67            ),
68            Err(e) => {
69                warn!(event = "healthcheck.touch.error", user_agent = ua, error = %e, "failed to update last_healthcheck")
70            }
71        }
72    } else {
73        debug!(
74            event = "healthcheck.skip",
75            user_agent = ua,
76            "health check not from DDS; not updating last_healthcheck"
77        );
78    }
79    StatusCode::OK
80}
81
82async fn callback_registration(
83    State(state): State<DdsState>,
84    Json(payload): Json<RegistrationCallbackRequest>,
85) -> Result<Json<OkResponse>, CallbackError> {
86    // Basic shape validation
87    if payload.id.trim().is_empty() {
88        return Err(CallbackError::Unprocessable("missing id"));
89    }
90    if payload.secret.trim().is_empty() {
91        return Err(CallbackError::Unprocessable("missing secret"));
92    }
93
94    // Optional: enforce some maximum size to avoid abuse
95    if payload.secret.len() > 4096 {
96        return Err(CallbackError::Forbidden("secret too large"));
97    }
98
99    // Log without exposing sensitive secret
100    let secret_len = payload.secret.len();
101    let org = payload.organization_id.as_deref().unwrap_or("");
102    info!(id = %payload.id, org = %org, secret_len = secret_len, "Received registration callback");
103
104    // Persist atomically
105    write_node_secret_to_path(&state.secret_path, &payload.secret)
106        .map_err(|_| CallbackError::Conflict("persist failed"))?;
107
108    // Sanity read-back (optional; not exposing value)
109    match read_node_secret_from_path(&state.secret_path) {
110        Ok(Some(_)) => {}
111        Ok(None) => {
112            warn!("persisted secret missing after write");
113            return Err(CallbackError::Conflict("persist verify failed"));
114        }
115        Err(_) => {
116            return Err(CallbackError::Conflict("persist verify failed"));
117        }
118    }
119
120    Ok(Json(OkResponse { ok: true }))
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use parking_lot::Mutex as PLMutex;
    use std::io;
    use std::path::PathBuf;
    use std::sync::Arc;
    use tower::ServiceExt;
    use tracing::subscriber;
    use tracing_subscriber::layer::SubscriberExt;

    /// Byte buffer shared between the log writer and the assertions.
    type LogBytes = Arc<PLMutex<Vec<u8>>>;

    /// `io::Write` sink that appends everything into the shared buffer.
    struct CaptureWriter(LogBytes);

    impl io::Write for CaptureWriter {
        fn write(&mut self, chunk: &[u8]) -> io::Result<usize> {
            self.0.lock().extend_from_slice(chunk);
            Ok(chunk.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Hands out `CaptureWriter`s that all point at the same buffer.
    struct CaptureWriterFactory(LogBytes);

    impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for CaptureWriterFactory {
        type Writer = CaptureWriter;

        fn make_writer(&'a self) -> Self::Writer {
            CaptureWriter(self.0.clone())
        }
    }

    /// Returns a fresh, UUID-suffixed secret path so tests do not collide.
    fn unique_secret_path() -> PathBuf {
        PathBuf::from(format!("dds_http_test/{}", uuid::Uuid::new_v4()))
    }

    /// A valid callback must answer 200, persist the secret, and log the
    /// "received" line without ever printing the secret value.
    #[tokio::test]
    async fn callback_persists_and_redacts_secret() {
        // Route tracing output into an in-memory buffer for this test only.
        let log_bytes: LogBytes = Arc::new(PLMutex::new(Vec::<u8>::new()));
        let fmt_layer = tracing_subscriber::fmt::layer()
            .with_writer(CaptureWriterFactory(log_bytes.clone()))
            .with_ansi(false)
            .without_time();
        let _guard = subscriber::set_default(tracing_subscriber::registry().with(fmt_layer));

        let secret_path = unique_secret_path();
        let app = router_dds(DdsState {
            secret_path: secret_path.clone(),
        });

        let secret = "my-very-secret";
        let payload = serde_json::json!({
            "id": "abc123",
            "secret": secret,
            "organization_id": "org1",
            "lighthouses_in_domains": [],
            "domains": []
        });

        let request = Request::builder()
            .method("POST")
            .uri("/internal/v1/registrations")
            .header(axum::http::header::CONTENT_TYPE, "application/json")
            .body(Body::from(payload.to_string()))
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // The stored value must match what was posted.
        let stored = read_node_secret_from_path(&secret_path).unwrap();
        assert_eq!(stored.as_deref(), Some(secret));

        // The log must mention the callback but never the secret.
        let captured = String::from_utf8(log_bytes.lock().clone()).unwrap_or_default();
        assert!(captured.contains("Received registration callback"));
        assert!(
            !captured.contains(secret),
            "logs leaked secret: {}",
            captured
        );
    }

    /// `/health` answers 200 even without a `User-Agent` header.
    #[tokio::test]
    async fn health_ok() {
        let app = router_dds(DdsState {
            secret_path: unique_secret_path(),
        });

        let request = Request::builder()
            .method("GET")
            .uri("/health")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}