posemesh_node_registration/
http.rs

1use crate::persist::{read_node_secret, write_node_secret};
2use crate::state::touch_healthcheck_now;
3use axum::extract::State;
4use axum::http::{header::USER_AGENT, HeaderMap, StatusCode};
5use axum::response::IntoResponse;
6use axum::routing::{get, post};
7use axum::Json;
8use serde::{Deserialize, Serialize};
9use tracing::{debug, info, warn};
10
/// Shared state handed to the router built by [`router_dds`].
///
/// Currently a field-less marker: the handlers in this file extract it but do
/// not read anything from it.
#[derive(Debug, Clone)]
pub struct DdsState;
13
14#[derive(Debug, Deserialize)]
15pub struct RegistrationCallbackRequest {
16    pub id: String,
17    pub secret: String,
18    pub organization_id: Option<String>,
19    pub lighthouses_in_domains: Option<serde_json::Value>,
20    pub domains: Option<serde_json::Value>,
21}
22
/// JSON body returned by the registration callback on success.
#[derive(Debug, Serialize)]
struct OkResponse {
    /// Success flag; the handler in this file always sets it to `true`.
    ok: bool,
}
27
/// Failure modes of the registration callback.
///
/// Each variant carries a static message that becomes the response body; the
/// `IntoResponse` impl maps the variant to its HTTP status code.
#[derive(Debug)]
enum CallbackError {
    /// Answered with HTTP 422 Unprocessable Entity.
    Unprocessable(&'static str), // 422
    /// Answered with HTTP 403 Forbidden.
    Forbidden(&'static str),     // 403
    /// Answered with HTTP 409 Conflict.
    Conflict(&'static str),      // 409
}
34
35impl IntoResponse for CallbackError {
36    fn into_response(self) -> axum::response::Response {
37        let (status, msg) = match self {
38            CallbackError::Unprocessable(m) => (StatusCode::UNPROCESSABLE_ENTITY, m),
39            CallbackError::Forbidden(m) => (StatusCode::FORBIDDEN, m),
40            CallbackError::Conflict(m) => (StatusCode::CONFLICT, m),
41        };
42        (status, msg).into_response()
43    }
44}
45
46pub fn router_dds(state: DdsState) -> axum::Router {
47    axum::Router::new()
48        .route("/internal/v1/registrations", post(callback_registration))
49        .route("/health", get(health))
50        .with_state(state)
51}
52
53async fn health(State(_state): State<DdsState>, headers: HeaderMap) -> impl IntoResponse {
54    let ua = headers
55        .get(USER_AGENT)
56        .and_then(|v| v.to_str().ok())
57        .unwrap_or("");
58    if ua.starts_with("DDS v") {
59        match touch_healthcheck_now() {
60            Ok(()) => debug!(
61                event = "healthcheck.touch",
62                user_agent = ua,
63                "last_healthcheck updated via /health"
64            ),
65            Err(e) => {
66                warn!(event = "healthcheck.touch.error", user_agent = ua, error = %e, "failed to update last_healthcheck")
67            }
68        }
69    } else {
70        debug!(
71            event = "healthcheck.skip",
72            user_agent = ua,
73            "health check not from DDS; not updating last_healthcheck"
74        );
75    }
76    StatusCode::OK
77}
78
79async fn callback_registration(
80    State(_state): State<DdsState>,
81    Json(payload): Json<RegistrationCallbackRequest>,
82) -> Result<Json<OkResponse>, CallbackError> {
83    // Basic shape validation
84    if payload.id.trim().is_empty() {
85        return Err(CallbackError::Unprocessable("missing id"));
86    }
87    if payload.secret.trim().is_empty() {
88        return Err(CallbackError::Unprocessable("missing secret"));
89    }
90
91    // Optional: enforce some maximum size to avoid abuse
92    if payload.secret.len() > 4096 {
93        return Err(CallbackError::Forbidden("secret too large"));
94    }
95
96    // Log without exposing sensitive secret
97    let secret_len = payload.secret.len();
98    let org = payload.organization_id.as_deref().unwrap_or("");
99    info!(id = %payload.id, org = %org, secret_len = secret_len, "Received registration callback");
100
101    // Persist atomically
102    write_node_secret(&payload.secret).map_err(|_| CallbackError::Conflict("persist failed"))?;
103
104    // Sanity read-back (optional; not exposing value)
105    match read_node_secret() {
106        Ok(Some(_)) => {}
107        Ok(None) => {
108            warn!("persisted secret missing after write");
109            return Err(CallbackError::Conflict("persist verify failed"));
110        }
111        Err(_) => {
112            return Err(CallbackError::Conflict("persist verify failed"));
113        }
114    }
115
116    Ok(Json(OkResponse { ok: true }))
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persist;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use parking_lot::Mutex as PLMutex;
    use std::io;
    use std::sync::Arc;
    use tower::ServiceExt;
    use tracing::subscriber;
    use tracing_subscriber::layer::SubscriberExt;

    /// Posts a valid registration and checks that (a) the handler answers 200,
    /// (b) the secret can be read back, and (c) the captured log output
    /// contains the "received" line but never the secret itself.
    #[tokio::test]
    async fn callback_persists_and_redacts_secret() {
        // In-memory sink so the test can inspect everything `tracing` emits.
        struct BufWriter(Arc<PLMutex<Vec<u8>>>);
        impl io::Write for BufWriter {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.0.lock().extend_from_slice(buf);
                Ok(buf.len())
            }
            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }
        struct MakeBufWriter(Arc<PLMutex<Vec<u8>>>);
        impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for MakeBufWriter {
            type Writer = BufWriter;
            fn make_writer(&'a self) -> Self::Writer {
                BufWriter(self.0.clone())
            }
        }

        let buf = Arc::new(PLMutex::new(Vec::<u8>::new()));
        let make = MakeBufWriter(buf.clone());
        let layer = tracing_subscriber::fmt::layer()
            .with_writer(make)
            .with_ansi(false)
            .without_time();
        let subscriber = tracing_subscriber::registry().with(layer);
        // Installed for this thread only, until `_guard` is dropped.
        let _guard = subscriber::set_default(subscriber);

        persist::clear_node_secret().unwrap();
        let app = router_dds(DdsState);

        let secret = "my-very-secret";
        let body = serde_json::json!({
            "id": "abc123",
            "secret": secret,
            "organization_id": "org1",
            "lighthouses_in_domains": [],
            "domains": []
        })
        .to_string();

        let req = Request::builder()
            .method("POST")
            .uri("/internal/v1/registrations")
            .header(axum::http::header::CONTENT_TYPE, "application/json")
            .body(Body::from(body))
            .unwrap();

        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let got = read_node_secret().unwrap();
        assert_eq!(got.as_deref(), Some(secret));

        let captured = String::from_utf8(buf.lock().clone()).unwrap_or_default();
        assert!(captured.contains("Received registration callback"));
        assert!(
            !captured.contains(secret),
            "logs leaked secret: {}",
            captured
        );
    }

    /// `GET /health` without a DDS user agent answers 200.
    #[tokio::test]
    async fn health_ok() {
        // Deliberately does NOT clear the persisted node secret: the health
        // handler never reads it, and clearing it here could race with
        // `callback_persists_and_redacts_secret` when tests run in parallel.
        let app = router_dds(DdsState);

        let req = Request::builder()
            .method("GET")
            .uri("/health")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
}