qml-rs 1.1.0

A Rust implementation of QML background job processing
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
//! Authentication and CSRF protection for the dashboard.
//!
//! Two axum middlewares are exposed:
//!
//! - [`require_auth`] — gates every dashboard request on a configured
//!   [`DashboardAuth`]. Supports HTTP Basic and Bearer tokens; comparisons
//!   are constant-time to keep the timing side-channel closed.
//! - [`csrf_guard`] — on state-changing methods (`POST`/`PUT`/`PATCH`/
//!   `DELETE`), requires that the `Origin` header (falling back to `Referer`)
//!   names the same authority as the request's `Host` header. Rejects with
//!   `403` otherwise. Same-origin browser requests set `Origin` automatically,
//!   so the legitimate dashboard UI keeps working; a cross-site `<form>`
//!   submit from an attacker page does not.

use axum::{
    extract::{Request, State},
    http::{Method, StatusCode, header},
    middleware::Next,
    response::{IntoResponse, Response},
};
use std::sync::Arc;

/// How to authenticate requests to the dashboard.
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// value (for example when logging configuration) never prints the password
/// or the bearer token; only the variant name and the username are shown.
#[derive(Clone)]
pub enum DashboardAuth {
    /// HTTP Basic authentication. The supplied `username`/`password` are
    /// compared in constant time against the credentials in the request's
    /// `Authorization: Basic …` header.
    Basic { username: String, password: String },
    /// HTTP Bearer token authentication. The supplied `token` is compared in
    /// constant time against the token in the request's `Authorization:
    /// Bearer …` header.
    Bearer { token: String },
}

impl std::fmt::Debug for DashboardAuth {
    /// Formats the variant with every secret field replaced by `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &format_args!("<redacted>"))
                .finish(),
            Self::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("token", &format_args!("<redacted>"))
                .finish(),
        }
    }
}

/// Axum middleware that lets a request through only when it carries the
/// credential described by the configured [`DashboardAuth`].
///
/// The `Authorization` header is read as text and handed to the checker that
/// matches the configured scheme. A missing header, a header that is not
/// valid visible ASCII, or a failed check all produce the `401 Unauthorized`
/// response built by `unauthorized` (which adds a `WWW-Authenticate`
/// challenge for Basic auth).
pub async fn require_auth(
    State(auth): State<Arc<DashboardAuth>>,
    req: Request,
    next: Next,
) -> Response {
    let authorized = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|raw| raw.to_str().ok())
        .is_some_and(|credentials| match auth.as_ref() {
            DashboardAuth::Basic { username, password } => {
                check_basic(credentials, username, password)
            }
            DashboardAuth::Bearer { token } => check_bearer(credentials, token),
        });

    if !authorized {
        return unauthorized(&auth);
    }
    next.run(req).await
}

/// Axum middleware that rejects cross-origin state-changing requests.
///
/// Only `POST`, `PUT`, `PATCH` and `DELETE` are inspected; every other method
/// is forwarded unchanged. For the inspected methods the authority taken from
/// `Origin` (or from `Referer` when no `Origin` header is present) must equal
/// the `Host` header, ignoring ASCII case. A missing or unparsable value on
/// either side, or a mismatch, yields `403 Forbidden`.
pub async fn csrf_guard(req: Request, next: Next) -> Response {
    let guarded =
        [Method::POST, Method::PUT, Method::PATCH, Method::DELETE].contains(req.method());

    if guarded {
        let headers = req.headers();

        // `Referer` is consulted only when `Origin` is absent altogether.
        let claimed_origin = match headers.get(header::ORIGIN) {
            Some(origin) => Some(origin),
            None => headers.get(header::REFERER),
        }
        .and_then(|raw| raw.to_str().ok())
        .and_then(authority_of);

        let host = headers.get(header::HOST).and_then(|raw| raw.to_str().ok());

        let same_origin = match (host, claimed_origin) {
            (Some(host), Some(origin)) => host.eq_ignore_ascii_case(&origin),
            _ => false,
        };
        if !same_origin {
            return (StatusCode::FORBIDDEN, "CSRF check failed").into_response();
        }
    }

    next.run(req).await
}

/// Report whether `host` names a loopback interface; used to decide whether
/// the dashboard may start without an auth guard.
///
/// Accepts the name `localhost` (any ASCII case) and any IPv4/IPv6 literal
/// whose address is loopback. Anything that is not `localhost` and does not
/// parse as an IP address is treated as non-loopback.
pub(crate) fn is_loopback_host(host: &str) -> bool {
    // Config strings may carry an IPv6 literal in brackets (`[::1]`), so the
    // brackets are peeled off before parsing.
    let unbracketed = host.trim_start_matches('[').trim_end_matches(']');

    host.eq_ignore_ascii_case("localhost")
        || matches!(
            unbracketed.parse::<std::net::IpAddr>(),
            Ok(addr) if addr.is_loopback()
        )
}

/// Build the `401 Unauthorized` response for a rejected request.
///
/// When the dashboard is configured for Basic auth the response also carries
/// a `WWW-Authenticate: Basic realm="qml-dashboard"` challenge; for any other
/// configuration the bare 401 is returned.
fn unauthorized(auth: &DashboardAuth) -> Response {
    let mut response = (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();

    if !matches!(auth, DashboardAuth::Basic { .. }) {
        return response;
    }

    // If the challenge cannot be turned into a header value the response is
    // simply sent without it.
    if let Ok(challenge) = header::HeaderValue::from_str("Basic realm=\"qml-dashboard\"") {
        response
            .headers_mut()
            .insert(header::WWW_AUTHENTICATE, challenge);
    }
    response
}

/// Validate an `Authorization` header value against Basic credentials.
///
/// `header_value` must have the form `<scheme> <base64(user:pass)>`. The
/// scheme name is matched case-insensitively, as HTTP requires for
/// auth-scheme tokens, so `basic …` and `BASIC …` are accepted alongside
/// `Basic …`. The decoded payload is split at the first `:`; everything after
/// it is the password, which may therefore contain further colons.
///
/// Returns `true` only when both the user name and the password match.
/// Any malformed input (wrong scheme, invalid base64, non-UTF-8 payload,
/// missing colon) returns `false`.
fn check_basic(header_value: &str, expected_user: &str, expected_pass: &str) -> bool {
    let Some((scheme, encoded)) = header_value.split_once(' ') else {
        return false;
    };
    if !scheme.eq_ignore_ascii_case("Basic") {
        return false;
    }
    let Some(bytes) = base64_decode(encoded.trim()) else {
        return false;
    };
    let Ok(decoded) = std::str::from_utf8(&bytes) else {
        return false;
    };
    let Some((user, pass)) = decoded.split_once(':') else {
        return false;
    };
    // Always run both comparisons so the total time doesn't depend on which
    // field mismatched first; `&` (not `&&`) keeps the second one from being
    // skipped.
    let user_ok = constant_time_eq(user.as_bytes(), expected_user.as_bytes());
    let pass_ok = constant_time_eq(pass.as_bytes(), expected_pass.as_bytes());
    user_ok & pass_ok
}

/// Validate an `Authorization` header value against a bearer token.
///
/// `header_value` must have the form `<scheme> <token>`. The scheme name is
/// matched case-insensitively, as HTTP requires for auth-scheme tokens, so
/// `bearer …` is accepted alongside `Bearer …`. Surrounding whitespace on the
/// token is ignored, and the token itself is compared with
/// `constant_time_eq`.
///
/// Returns `false` for any other scheme or when the token differs.
fn check_bearer(header_value: &str, expected: &str) -> bool {
    let Some((scheme, token)) = header_value.split_once(' ') else {
        return false;
    };
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return false;
    }
    constant_time_eq(token.trim().as_bytes(), expected.as_bytes())
}

/// Compare two byte strings without stopping at the first differing byte.
///
/// Slices of different length are reported unequal immediately, so only the
/// contents — not the length — are shielded from timing differences. For
/// equal-length inputs every byte pair is XOR-ed and the results are OR-ed
/// together; the slices are equal exactly when that accumulator stays zero.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mismatch = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    mismatch == 0
}

/// Extract the authority (`host[:port]`) from a URL. Returns `None` if the
/// input doesn't start with a valid scheme — this mirrors browser-emitted
/// `Origin` / `Referer` headers, which are always absolute URLs.
///
/// The text before the first `://` must be a syntactically valid URI scheme:
/// an ASCII letter followed by ASCII letters, digits, `+`, `-` or `.`. This
/// keeps a `://` that merely appears later in the string (for example inside
/// a query parameter of a relative reference) from being mistaken for the
/// URL's own scheme separator.
///
/// The authority is everything after `://` up to the first `/`, `?` or `#`,
/// lower-cased. An empty authority yields `None`.
fn authority_of(url: &str) -> Option<String> {
    let (scheme, rest) = url.split_once("://")?;

    let mut scheme_chars = scheme.chars();
    let leads_with_letter = matches!(scheme_chars.next(), Some(c) if c.is_ascii_alphabetic());
    let tail_is_valid =
        scheme_chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'));
    if !(leads_with_letter && tail_is_valid) {
        return None;
    }

    let authority = rest.split(['/', '?', '#']).next()?;
    if authority.is_empty() {
        None
    } else {
        Some(authority.to_ascii_lowercase())
    }
}

/// Decode standard-alphabet base64 (`A–Z`, `a–z`, `0–9`, `+`, `/`).
///
/// Trailing `=` padding is optional and is stripped before decoding. Returns
/// `None` when the input contains a character outside the alphabet, or when
/// the unpadded length is `1 (mod 4)`: a lone final character carries only
/// six bits and can never complete a byte, so such input is truncated or has
/// garbage appended and must not decode to the same bytes as the valid
/// prefix.
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    /// Map one alphabet character to its 6-bit value.
    fn val(c: u8) -> Option<u32> {
        match c {
            b'A'..=b'Z' => Some((c - b'A') as u32),
            b'a'..=b'z' => Some((c - b'a' + 26) as u32),
            b'0'..=b'9' => Some((c - b'0' + 52) as u32),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    let input = input.trim_end_matches('=').as_bytes();
    if input.len() % 4 == 1 {
        return None;
    }

    let mut out = Vec::with_capacity(input.len() * 3 / 4);
    // `buf` accumulates pending bits; `bits` counts how many of them are
    // still waiting to be emitted.
    let mut buf: u32 = 0;
    let mut bits: u8 = 0;
    for &c in input {
        buf = (buf << 6) | val(c)?;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push((buf >> bits) as u8);
            // Keep only the bits that have not been written out yet.
            buf &= (1u32 << bits) - 1;
        }
    }
    Some(out)
}

// Unit tests for the pure helpers, followed by router-level tests that drive
// `require_auth` and `csrf_guard` through a small stub application.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        Router,
        body::Body,
        http::{Method, Request, StatusCode, header},
        middleware,
        routing::{delete, get, post},
    };
    use tower::ServiceExt;

    /// Padded, unpadded-length and empty inputs decode; a character outside
    /// the alphabet is rejected.
    #[test]
    fn base64_decodes_standard_input() {
        assert_eq!(
            base64_decode("YWRtaW46c2VjcmV0"),
            Some(b"admin:secret".to_vec())
        );
        assert_eq!(base64_decode("Zm9vOmJhcg=="), Some(b"foo:bar".to_vec()));
        assert_eq!(base64_decode(""), Some(Vec::new()));
        assert!(base64_decode("not*base64!").is_none());
    }

    /// Basic credentials match only when user and password are both right,
    /// the scheme is Basic, and the payload is valid base64.
    #[test]
    fn check_basic_accepts_matching_creds() {
        let header = "Basic YWRtaW46c2VjcmV0"; // admin:secret
        assert!(check_basic(header, "admin", "secret"));
        assert!(!check_basic(header, "admin", "wrong"));
        assert!(!check_basic(header, "root", "secret"));
        assert!(!check_basic("Bearer xyz", "admin", "secret"));
        assert!(!check_basic("Basic !!!", "admin", "secret"));
    }

    /// Bearer tokens match only with the Bearer scheme and an identical token.
    #[test]
    fn check_bearer_accepts_matching_token() {
        assert!(check_bearer("Bearer abc123", "abc123"));
        assert!(!check_bearer("Bearer abc123", "xyz"));
        assert!(!check_bearer("Basic abc123", "abc123"));
    }

    /// Different lengths and differing bytes both compare unequal.
    #[test]
    fn constant_time_eq_handles_length_mismatch() {
        assert!(!constant_time_eq(b"abc", b"abcd"));
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
    }

    /// The authority keeps the port, drops path/query, is lower-cased, and is
    /// absent for input without `://`.
    #[test]
    fn authority_of_strips_path_and_scheme() {
        assert_eq!(
            authority_of("http://example.com:8080/foo?bar=1"),
            Some("example.com:8080".to_string())
        );
        assert_eq!(
            authority_of("https://DASH.local"),
            Some("dash.local".to_string())
        );
        assert_eq!(authority_of("not-a-url"), None);
    }

    /// `localhost` and loopback IP literals (bracketed or not) are loopback;
    /// other addresses and host names are not.
    #[test]
    fn is_loopback_host_recognizes_expected_values() {
        assert!(is_loopback_host("localhost"));
        assert!(is_loopback_host("127.0.0.1"));
        assert!(is_loopback_host("::1"));
        assert!(is_loopback_host("[::1]"));
        assert!(!is_loopback_host("10.0.0.1"));
        assert!(!is_loopback_host("example.com"));
    }

    /// Build a router with three stub routes (`GET /api/health`,
    /// `POST /api/jobs/{id}/retry`, `DELETE /api/jobs/{id}`), layered with
    /// `csrf_guard` and, when `auth` is `Some`, additionally with
    /// `require_auth`.
    fn test_app(auth: Option<DashboardAuth>) -> Router {
        let mut app = Router::new()
            .route("/api/health", get(|| async { "ok" }))
            .route(
                "/api/jobs/{id}/retry",
                post(|| async { StatusCode::NO_CONTENT }),
            )
            .route(
                "/api/jobs/{id}",
                delete(|| async { StatusCode::NO_CONTENT }),
            );

        app = app.layer(middleware::from_fn(csrf_guard));

        if let Some(auth) = auth {
            app = app.layer(middleware::from_fn_with_state(Arc::new(auth), require_auth));
        }
        app
    }

    /// Send one request through `app` and return only the response status.
    /// Panics if the service call itself fails.
    async fn send(app: Router, req: Request<Body>) -> StatusCode {
        app.oneshot(req).await.unwrap().status()
    }

    /// No `Authorization` header at all yields 401.
    #[tokio::test]
    async fn require_auth_rejects_missing_credentials() {
        let app = test_app(Some(DashboardAuth::Bearer {
            token: "secret".into(),
        }));
        let req = Request::builder()
            .uri("/api/health")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, req).await, StatusCode::UNAUTHORIZED);
    }

    /// The configured bearer token is accepted.
    #[tokio::test]
    async fn require_auth_accepts_matching_bearer_token() {
        let app = test_app(Some(DashboardAuth::Bearer {
            token: "secret".into(),
        }));
        let req = Request::builder()
            .uri("/api/health")
            .header(header::AUTHORIZATION, "Bearer secret")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, req).await, StatusCode::OK);
    }

    /// A different bearer token yields 401.
    #[tokio::test]
    async fn require_auth_rejects_wrong_bearer_token() {
        let app = test_app(Some(DashboardAuth::Bearer {
            token: "secret".into(),
        }));
        let req = Request::builder()
            .uri("/api/health")
            .header(header::AUTHORIZATION, "Bearer wrong")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, req).await, StatusCode::UNAUTHORIZED);
    }

    /// The configured Basic user/password pair is accepted.
    #[tokio::test]
    async fn require_auth_accepts_matching_basic_credentials() {
        let app = test_app(Some(DashboardAuth::Basic {
            username: "admin".into(),
            password: "secret".into(),
        }));
        // base64("admin:secret") = "YWRtaW46c2VjcmV0"
        let req = Request::builder()
            .uri("/api/health")
            .header(header::AUTHORIZATION, "Basic YWRtaW46c2VjcmV0")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, req).await, StatusCode::OK);
    }

    /// A POST with `Host` but neither `Origin` nor `Referer` yields 403.
    #[tokio::test]
    async fn csrf_guard_blocks_mutation_without_origin() {
        let app = test_app(None);
        let req = Request::builder()
            .method(Method::POST)
            .uri("/api/jobs/abc/retry")
            .header(header::HOST, "localhost:8080")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, req).await, StatusCode::FORBIDDEN);
    }

    /// A POST whose `Origin` authority equals `Host` reaches the handler.
    #[tokio::test]
    async fn csrf_guard_allows_same_origin_mutation() {
        let app = test_app(None);
        let req = Request::builder()
            .method(Method::POST)
            .uri("/api/jobs/abc/retry")
            .header(header::HOST, "localhost:8080")
            .header(header::ORIGIN, "http://localhost:8080")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, req).await, StatusCode::NO_CONTENT);
    }

    /// A DELETE whose `Origin` names a different authority yields 403.
    #[tokio::test]
    async fn csrf_guard_rejects_cross_origin_mutation() {
        let app = test_app(None);
        let req = Request::builder()
            .method(Method::DELETE)
            .uri("/api/jobs/abc")
            .header(header::HOST, "localhost:8080")
            .header(header::ORIGIN, "https://evil.example.com")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, req).await, StatusCode::FORBIDDEN);
    }

    /// A GET passes even though it carries no `Origin` or `Referer`.
    #[tokio::test]
    async fn csrf_guard_ignores_safe_methods() {
        let app = test_app(None);
        let req = Request::builder()
            .method(Method::GET)
            .uri("/api/health")
            .header(header::HOST, "localhost:8080")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, req).await, StatusCode::OK);
    }

    /// With no `Origin`, a same-authority `Referer` is sufficient.
    #[tokio::test]
    async fn csrf_guard_falls_back_to_referer() {
        let app = test_app(None);
        let req = Request::builder()
            .method(Method::POST)
            .uri("/api/jobs/abc/retry")
            .header(header::HOST, "localhost:8080")
            .header(header::REFERER, "http://localhost:8080/jobs")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, req).await, StatusCode::NO_CONTENT);
    }
}