Skip to main content

posemesh_node_registration/
http.rs

1#[cfg(test)]
2use crate::state::clear_node_secret;
3use crate::state::{
4    read_node_secret, set_status, touch_healthcheck_now, write_node_secret, STATUS_REGISTERED,
5};
6use axum::extract::State;
7use axum::http::{header::USER_AGENT, HeaderMap, StatusCode};
8use axum::response::IntoResponse;
9use axum::routing::{get, post};
10use axum::Json;
11use serde::{Deserialize, Serialize};
12use tracing::{debug, info, warn};
13
/// Shared state for the DDS-facing router built by `router_dds`.
///
/// Currently carries no data; it exists so handlers can take `State<DdsState>`.
/// `Debug` is derived so the type can appear in diagnostics like other public types.
#[derive(Debug, Clone)]
pub struct DdsState;
16
17#[derive(Debug, Deserialize)]
18pub struct RegistrationCallbackRequest {
19    pub id: String,
20    pub secret: String,
21    pub organization_id: Option<String>,
22    pub lighthouses_in_domains: Option<serde_json::Value>,
23    pub domains: Option<serde_json::Value>,
24}
25
26#[derive(Debug, Serialize)]
27struct OkResponse {
28    ok: bool,
29}
30
31#[derive(Debug)]
32enum CallbackError {
33    Unprocessable(&'static str), // 422
34    Forbidden(&'static str),     // 403
35    Conflict(&'static str),      // 409
36}
37
38impl IntoResponse for CallbackError {
39    fn into_response(self) -> axum::response::Response {
40        let (status, msg) = match self {
41            CallbackError::Unprocessable(m) => (StatusCode::UNPROCESSABLE_ENTITY, m),
42            CallbackError::Forbidden(m) => (StatusCode::FORBIDDEN, m),
43            CallbackError::Conflict(m) => (StatusCode::CONFLICT, m),
44        };
45        (status, msg).into_response()
46    }
47}
48
49pub fn router_dds(state: DdsState) -> axum::Router {
50    axum::Router::new()
51        .route("/internal/v1/registrations", post(callback_registration))
52        .route("/health", get(health))
53        .with_state(state)
54}
55
56async fn health(State(_state): State<DdsState>, headers: HeaderMap) -> impl IntoResponse {
57    let ua = headers
58        .get(USER_AGENT)
59        .and_then(|v| v.to_str().ok())
60        .unwrap_or("");
61    if ua.starts_with("DDS v") {
62        match touch_healthcheck_now() {
63            Ok(()) => debug!(
64                event = "healthcheck.touch",
65                user_agent = ua,
66                "last_healthcheck updated via /health"
67            ),
68            Err(e) => {
69                warn!(event = "healthcheck.touch.error", user_agent = ua, error = %e, "failed to update last_healthcheck")
70            }
71        }
72    } else {
73        debug!(
74            event = "healthcheck.skip",
75            user_agent = ua,
76            "health check not from DDS; not updating last_healthcheck"
77        );
78    }
79    StatusCode::OK
80}
81
82async fn callback_registration(
83    State(_state): State<DdsState>,
84    Json(payload): Json<RegistrationCallbackRequest>,
85) -> Result<Json<OkResponse>, CallbackError> {
86    // Basic shape validation
87    if payload.id.trim().is_empty() {
88        return Err(CallbackError::Unprocessable("missing id"));
89    }
90    if payload.secret.trim().is_empty() {
91        return Err(CallbackError::Unprocessable("missing secret"));
92    }
93
94    // Optional: enforce some maximum size to avoid abuse
95    if payload.secret.len() > 4096 {
96        return Err(CallbackError::Forbidden("secret too large"));
97    }
98
99    // Log without exposing sensitive secret
100    let secret_len = payload.secret.len();
101    let org = payload.organization_id.as_deref().unwrap_or("");
102    info!(id = %payload.id, org = %org, secret_len = secret_len, "Received registration callback");
103
104    // Persist atomically
105    write_node_secret(&payload.secret).map_err(|_| CallbackError::Conflict("persist failed"))?;
106
107    // Sanity read-back (optional; not exposing value)
108    match read_node_secret() {
109        Ok(Some(_)) => {}
110        Ok(None) => {
111            warn!("persisted secret missing after write");
112            return Err(CallbackError::Conflict("persist verify failed"));
113        }
114        Err(_) => {
115            return Err(CallbackError::Conflict("persist verify failed"));
116        }
117    }
118
119    // Mark registration state explicitly so downstream listeners can trigger SIWE promptly.
120    if let Err(e) = set_status(STATUS_REGISTERED) {
121        warn!(error = %e, "failed to set registration status to 'registered'");
122        // Not fatal to the callback; secret persistence is the primary signal.
123    }
124
125    Ok(Json(OkResponse { ok: true }))
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use parking_lot::Mutex as PLMutex;
    use std::io;
    use std::sync::Arc;
    use tower::ServiceExt;
    use tracing::subscriber;
    use tracing_subscriber::layer::SubscriberExt;

    /// Builds a JSON `POST /internal/v1/registrations` request carrying `body`.
    fn registration_request(body: serde_json::Value) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri("/internal/v1/registrations")
            .header(axum::http::header::CONTENT_TYPE, "application/json")
            .body(Body::from(body.to_string()))
            .unwrap()
    }

    #[tokio::test]
    async fn callback_persists_and_redacts_secret() {
        struct BufWriter(Arc<PLMutex<Vec<u8>>>);
        impl io::Write for BufWriter {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.0.lock().extend_from_slice(buf);
                Ok(buf.len())
            }
            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }
        struct MakeBufWriter(Arc<PLMutex<Vec<u8>>>);
        impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for MakeBufWriter {
            type Writer = BufWriter;
            fn make_writer(&'a self) -> Self::Writer {
                BufWriter(self.0.clone())
            }
        }

        let buf = Arc::new(PLMutex::new(Vec::<u8>::new()));
        let make = MakeBufWriter(buf.clone());
        let layer = tracing_subscriber::fmt::layer()
            .with_writer(make)
            .with_ansi(false)
            .without_time();
        let subscriber = tracing_subscriber::registry().with(layer);
        let _guard = subscriber::set_default(subscriber);

        clear_node_secret().unwrap();
        let app = router_dds(DdsState);

        let secret = "my-very-secret";
        let req = registration_request(serde_json::json!({
            "id": "abc123",
            "secret": secret,
            "organization_id": "org1",
            "lighthouses_in_domains": [],
            "domains": []
        }));

        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let got = read_node_secret().unwrap();
        assert_eq!(got.as_deref(), Some(secret));

        // Registration status should be set to 'registered'.
        let st = crate::state::read_state().unwrap();
        assert_eq!(st.status.as_str(), STATUS_REGISTERED);

        let captured = String::from_utf8(buf.lock().clone()).unwrap_or_default();
        assert!(captured.contains("Received registration callback"));
        assert!(
            !captured.contains(secret),
            "logs leaked secret: {}",
            captured
        );
    }

    // The validation tests below are rejected before any state is written,
    // so they can safely run in parallel with the persistence test.

    #[tokio::test]
    async fn callback_rejects_blank_id() {
        let app = router_dds(DdsState);
        let req = registration_request(serde_json::json!({ "id": "   ", "secret": "s" }));
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn callback_rejects_blank_secret() {
        let app = router_dds(DdsState);
        let req = registration_request(serde_json::json!({ "id": "abc123", "secret": "  " }));
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn callback_rejects_oversized_secret() {
        let app = router_dds(DdsState);
        let req = registration_request(serde_json::json!({
            "id": "abc123",
            "secret": "x".repeat(4097)
        }));
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn health_ok() {
        // Deliberately does not clear the node secret. That state is presumably
        // shared across tests, and tests run in parallel, so clearing it here could
        // race with `callback_persists_and_redacts_secret`. Without a DDS user agent,
        // `/health` does not touch state at all.
        let app = router_dds(DdsState);

        let req = Request::builder()
            .method("GET")
            .uri("/health")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
}