Skip to main content

mii_http/
server.rs

1//! HTTP server runtime built on top of axum.
2
3use crate::exec::{self, BodyValue, ExecContext, ExecOutput};
4use crate::spec::*;
5use crate::value::{self, ValidationError};
6use axum::{
7    Router,
8    body::Bytes,
9    extract::{DefaultBodyLimit, Path as AxPath, Query, State},
10    http::{HeaderMap, StatusCode, header},
11    response::{IntoResponse, Response},
12    routing::{MethodFilter, MethodRouter},
13};
14use std::collections::{BTreeMap, HashMap};
15use std::net::SocketAddr;
16use std::sync::Arc;
17use std::time::Duration;
18use tokio::net::TcpListener;
19
20#[derive(Clone)]
21struct AppState {
22    spec: Arc<Spec>,
23    auth_secret: Option<Vec<u8>>,
24    auth_jwt_verifier: Option<String>,
25    dry_run: bool,
26}
27
28pub async fn serve(spec: Spec, addr: SocketAddr, dry_run: bool) -> std::io::Result<()> {
29    tracing::debug!(addr = %addr, dry_run, endpoints = spec.endpoints.len(), "server::serve");
30    let auth_secret = match &spec.setup.token_secret {
31        Some(src) => Some(resolve_static_source(src)?.into_bytes()),
32        None => None,
33    };
34    let auth_jwt_verifier = match &spec.setup.jwt_verifier {
35        Some(src) => Some(resolve_static_source(src)?),
36        None => None,
37    };
38    let state = AppState {
39        spec: Arc::new(spec),
40        auth_secret,
41        auth_jwt_verifier,
42        dry_run,
43    };
44    let router = build_router(state.clone());
45    let listener = TcpListener::bind(addr).await?;
46    if dry_run {
47        tracing::info!(
48            "mii-http listening on {} (dry-run: commands will not be executed)",
49            addr
50        );
51    } else {
52        tracing::info!("mii-http listening on {}", addr);
53    }
54    axum::serve(listener, router.into_make_service())
55        .await
56        .map_err(|e| std::io::Error::other(e.to_string()))
57}
58
59fn resolve_static_source(src: &ValueSource) -> std::io::Result<String> {
60    match src {
61        ValueSource::Env { name, .. } => std::env::var(name)
62            .map_err(|_| std::io::Error::other(format!("env var `{}` not set", name))),
63        ValueSource::Literal { value, .. } => Ok(value.clone()),
64        ValueSource::Header { .. } => Err(std::io::Error::other(
65            "[HEADER ...] is not valid for static setup values",
66        )),
67    }
68}
69
70fn build_router(state: AppState) -> Router {
71    tracing::debug!("server::build_router");
72    let mut routes: HashMap<String, MethodRouter<AppState>> = HashMap::new();
73    let prefix = compute_prefix(&state.spec.setup);
74    let body_limit = state.spec.setup.max_body_size.map(saturating_usize);
75
76    for (idx, ep) in state.spec.endpoints.iter().enumerate() {
77        let path = format!("{}{}", prefix, axum_path(&ep.path_segments));
78        tracing::debug!(method = ep.method.as_str(), path = %path, "server::build_router: mounting route");
79        let entry = routes.entry(path).or_default();
80        let idx_clone = idx;
81        let mr = MethodRouter::<AppState>::new().on(
82            method_filter(ep.method),
83            move |s: State<AppState>,
84                  p: AxPath<HashMap<String, String>>,
85                  q: Query<HashMap<String, String>>,
86                  h: HeaderMap,
87                  b: Bytes| handle(s, p, q, h, b, idx_clone),
88        );
89        let merged = std::mem::take(entry).merge(mr);
90        *entry = merged;
91    }
92
93    let mut router = Router::new();
94    for (path, mr) in routes {
95        router = router.route(&path, mr);
96    }
97    let router = router.with_state(state);
98    if let Some(limit) = body_limit {
99        router.layer(DefaultBodyLimit::max(limit))
100    } else {
101        router
102    }
103}
104
/// Converts `n` to `usize`, clamping to `usize::MAX` when it does not fit.
fn saturating_usize(n: u64) -> usize {
    match usize::try_from(n) {
        Ok(exact) => exact,
        Err(_) => usize::MAX,
    }
}
108
109fn method_filter(m: Method) -> MethodFilter {
110    match m {
111        Method::Get => MethodFilter::GET,
112        Method::Post => MethodFilter::POST,
113        Method::Put => MethodFilter::PUT,
114        Method::Delete => MethodFilter::DELETE,
115        Method::Patch => MethodFilter::PATCH,
116    }
117}
118
119fn compute_prefix(setup: &Setup) -> String {
120    let base = setup.base.clone().unwrap_or_default();
121    let version = setup
122        .version
123        .map(|v| format!("/v{}", v))
124        .unwrap_or_default();
125    format!("{}{}", base, version)
126}
127
128fn axum_path(segs: &[PathSegment]) -> String {
129    let mut out = String::new();
130    for seg in segs {
131        out.push('/');
132        match seg {
133            PathSegment::Literal(s) => out.push_str(s),
134            PathSegment::Param { name, .. } => {
135                out.push(':');
136                out.push_str(name);
137            }
138        }
139    }
140    if out.is_empty() { "/".into() } else { out }
141}
142
143async fn handle(
144    State(state): State<AppState>,
145    AxPath(path): AxPath<HashMap<String, String>>,
146    Query(query): Query<HashMap<String, String>>,
147    headers: HeaderMap,
148    body: Bytes,
149    endpoint_idx: usize,
150) -> Response {
151    let ep = match state.spec.endpoints.get(endpoint_idx) {
152        Some(e) => e,
153        None => return error_response(StatusCode::INTERNAL_SERVER_ERROR, "endpoint missing"),
154    };
155    tracing::info!(method = ep.method.as_str(), path = %ep.path, "server::handle: incoming request");
156    match handle_inner(&state, ep, path, query, headers, body).await {
157        Ok(r) => r,
158        Err(err) => {
159            tracing::warn!(method = ep.method.as_str(), path = %ep.path, status = %err.status, error = %err.message, "server::handle: returning error");
160            err.into_response()
161        }
162    }
163}
164
165async fn handle_inner(
166    state: &AppState,
167    ep: &Endpoint,
168    path: HashMap<String, String>,
169    query: HashMap<String, String>,
170    headers: HeaderMap,
171    body: Bytes,
172) -> Result<Response, HandlerError> {
173    let setup = &state.spec.setup;
174
175    enforce_body_size(setup, &body)?;
176    authenticate(state, &headers)?;
177
178    let ctx = ExecContext {
179        query: validate_query(setup, ep, &query)?,
180        headers: validate_headers(setup, ep, &headers)?,
181        path: validate_path(ep, &path)?,
182        vars: resolve_vars(setup, ep, &headers)?,
183        body: build_body(ep, body)?,
184    };
185
186    let timeout = setup.timeout_ms.map(Duration::from_millis);
187
188    if state.dry_run {
189        let preview = exec::preview_pipeline(&ep.exec.pipeline, &ctx);
190        tracing::info!(
191            method = ep.method.as_str(),
192            path = %ep.path,
193            stages = ?preview,
194            "dry-run: skipping execution",
195        );
196        let mut body_text = String::from("[dry-run] would execute:\n");
197        for stage in &preview {
198            body_text.push_str("  ");
199            body_text.push_str(stage);
200            body_text.push('\n');
201        }
202        let mut resp = Response::new(body_text.into());
203        resp.headers_mut().insert(
204            header::CONTENT_TYPE,
205            header::HeaderValue::from_static("text/plain; charset=utf-8"),
206        );
207        return Ok(resp);
208    }
209
210    let ExecOutput {
211        status,
212        stdout,
213        stderr,
214    } = exec::run_pipeline(&ep.exec.pipeline, &ctx, timeout)
215        .await
216        .map_err(|e| HandlerError::new(StatusCode::INTERNAL_SERVER_ERROR, e))?;
217
218    if status != 0 {
219        tracing::warn!(
220            method = ep.method.as_str(),
221            path = %ep.path,
222            status,
223            stderr = %String::from_utf8_lossy(&stderr),
224            "exec returned non-zero"
225        );
226        return Err(HandlerError::new(
227            StatusCode::INTERNAL_SERVER_ERROR,
228            format!("command exited with status {}", status),
229        ));
230    }
231
232    let content_type = ep
233        .response_type
234        .clone()
235        .unwrap_or_else(|| "text/plain; charset=utf-8".into());
236    let mut resp = Response::new(stdout.into());
237    resp.headers_mut().insert(
238        header::CONTENT_TYPE,
239        content_type
240            .parse()
241            .unwrap_or_else(|_| header::HeaderValue::from_static("text/plain; charset=utf-8")),
242    );
243    Ok(resp)
244}
245
246fn check_validation(r: Result<(), ValidationError>, scope: &str) -> Result<(), HandlerError> {
247    r.map_err(|e| HandlerError::new(StatusCode::BAD_REQUEST, format!("{}: {}", scope, e.message)))
248}
249
250fn enforce_body_size(setup: &Setup, body: &Bytes) -> Result<(), HandlerError> {
251    if let Some(max) = setup.max_body_size
252        && body.len() as u64 > max
253    {
254        return Err(HandlerError::new(
255            StatusCode::PAYLOAD_TOO_LARGE,
256            format!("body exceeds max size of {} bytes", max),
257        ));
258    }
259    Ok(())
260}
261
262fn authenticate(state: &AppState, headers: &HeaderMap) -> Result<(), HandlerError> {
263    tracing::debug!("server::authenticate");
264    if let Some(AuthSpec::BearerHeader { header: hname, .. }) = &state.spec.setup.auth {
265        let token = extract_bearer(headers, hname, state.spec.setup.max_header_size)?;
266        verify_token(state, &token)?;
267    }
268    Ok(())
269}
270
271fn enforce_size(
272    actual: usize,
273    max: Option<u64>,
274    status: StatusCode,
275    label: impl FnOnce() -> String,
276) -> Result<(), HandlerError> {
277    if let Some(max) = max
278        && actual as u64 > max
279    {
280        return Err(HandlerError::new(status, label()));
281    }
282    Ok(())
283}
284
285fn require_or_optional<T>(
286    found: Option<T>,
287    optional: bool,
288    missing_msg: impl FnOnce() -> String,
289) -> Result<Option<T>, HandlerError> {
290    match found {
291        Some(v) => Ok(Some(v)),
292        None if optional => Ok(None),
293        None => Err(HandlerError::new(StatusCode::BAD_REQUEST, missing_msg())),
294    }
295}
296
297fn validate_query(
298    setup: &Setup,
299    ep: &Endpoint,
300    query: &HashMap<String, String>,
301) -> Result<BTreeMap<String, String>, HandlerError> {
302    tracing::debug!(endpoint = %ep.path, fields = ep.query_params.len(), "server::validate_query");
303    let mut out = BTreeMap::new();
304    for f in &ep.query_params {
305        let v = require_or_optional(query.get(&f.name).cloned(), f.optional, || {
306            format!("missing query parameter `{}`", f.name)
307        })?;
308        if let Some(v) = v {
309            enforce_size(
310                v.len(),
311                setup.max_query_param_size,
312                StatusCode::URI_TOO_LONG,
313                || format!("query param `{}` exceeds max size", f.name),
314            )?;
315            check_validation(
316                value::validate_text(&v, &f.ty),
317                &format!("query `{}`", f.name),
318            )?;
319            out.insert(f.name.clone(), v);
320        }
321    }
322    Ok(out)
323}
324
325fn validate_headers(
326    setup: &Setup,
327    ep: &Endpoint,
328    headers: &HeaderMap,
329) -> Result<BTreeMap<String, String>, HandlerError> {
330    tracing::debug!(endpoint = %ep.path, fields = ep.headers.len(), "server::validate_headers");
331    let mut out = BTreeMap::new();
332    for f in &ep.headers {
333        let v = require_or_optional(header_get(headers, &f.name), f.optional, || {
334            format!("missing header `{}`", f.name)
335        })?;
336        if let Some(v) = v {
337            enforce_size(
338                v.len(),
339                setup.max_header_size,
340                StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
341                || format!("header `{}` exceeds max size", f.name),
342            )?;
343            check_validation(
344                value::validate_text(&v, &f.ty),
345                &format!("header `{}`", f.name),
346            )?;
347            out.insert(f.name.clone(), v);
348        }
349    }
350    Ok(out)
351}
352
353fn validate_path(
354    ep: &Endpoint,
355    path: &HashMap<String, String>,
356) -> Result<BTreeMap<String, String>, HandlerError> {
357    tracing::debug!(endpoint = %ep.path, "server::validate_path");
358    let mut out = BTreeMap::new();
359    for seg in &ep.path_segments {
360        if let PathSegment::Param { name, ty, .. } = seg {
361            let v = path.get(name).cloned().ok_or_else(|| {
362                HandlerError::new(
363                    StatusCode::BAD_REQUEST,
364                    format!("missing path param `{}`", name),
365                )
366            })?;
367            check_validation(value::validate_text(&v, ty), &format!("path `{}`", name))?;
368            out.insert(name.clone(), v);
369        }
370    }
371    Ok(out)
372}
373
374fn resolve_vars(
375    setup: &Setup,
376    ep: &Endpoint,
377    headers: &HeaderMap,
378) -> Result<BTreeMap<String, String>, HandlerError> {
379    tracing::debug!(endpoint = %ep.path, vars = ep.vars.len(), "server::resolve_vars");
380    let mut out = BTreeMap::new();
381    for v in &ep.vars {
382        let resolved = resolve_runtime_source(setup, &v.source, headers)?;
383        out.insert(v.name.clone(), resolved);
384    }
385    Ok(out)
386}
387
388fn build_body(ep: &Endpoint, body: Bytes) -> Result<BodyValue, HandlerError> {
389    tracing::debug!(endpoint = %ep.path, body_len = body.len(), "server::build_body");
390    Ok(match &ep.body {
391        None => BodyValue::None,
392        Some(BodySpec::String { .. }) => {
393            BodyValue::Text(String::from_utf8(body.to_vec()).map_err(|_| {
394                HandlerError::new(StatusCode::BAD_REQUEST, "body is not valid UTF-8")
395            })?)
396        }
397        Some(BodySpec::Binary { .. }) => BodyValue::Binary(body),
398        Some(BodySpec::Json { schema, .. }) => {
399            let v: serde_json::Value = serde_json::from_slice(&body).map_err(|e| {
400                HandlerError::new(StatusCode::BAD_REQUEST, format!("invalid JSON body: {}", e))
401            })?;
402            if let Some(schema) = schema {
403                check_validation(value::validate_json(&v, schema), "json body")?;
404            }
405            BodyValue::Json(v)
406        }
407        Some(BodySpec::Form { fields, .. }) => {
408            let parsed: BTreeMap<String, String> =
409                form_urlencoded::parse(&body).into_owned().collect();
410            for f in fields {
411                let v = require_or_optional(parsed.get(&f.name), f.optional, || {
412                    format!("missing form field `{}`", f.name)
413                })?;
414                if let Some(v) = v {
415                    check_validation(
416                        value::validate_text(v, &f.ty),
417                        &format!("form field `{}`", f.name),
418                    )?;
419                }
420            }
421            BodyValue::Form(parsed)
422        }
423    })
424}
425
426fn header_get(headers: &HeaderMap, name: &str) -> Option<String> {
427    headers
428        .get(name)
429        .and_then(|v| v.to_str().ok())
430        .map(|s| s.to_string())
431}
432
433fn extract_bearer(
434    headers: &HeaderMap,
435    header_name: &str,
436    max_header_size: Option<u64>,
437) -> Result<String, HandlerError> {
438    let raw = header_get(headers, header_name).ok_or_else(|| {
439        HandlerError::new(
440            StatusCode::UNAUTHORIZED,
441            format!("missing `{}`", header_name),
442        )
443    })?;
444    enforce_size(
445        raw.len(),
446        max_header_size,
447        StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
448        || format!("auth header `{}` exceeds max size", header_name),
449    )?;
450    let token = raw
451        .strip_prefix("Bearer ")
452        .or_else(|| raw.strip_prefix("bearer "))
453        .unwrap_or(&raw)
454        .trim()
455        .to_string();
456    if token.is_empty() {
457        return Err(HandlerError::new(
458            StatusCode::UNAUTHORIZED,
459            "empty bearer token",
460        ));
461    }
462    Ok(token)
463}
464
465fn verify_token(state: &AppState, token: &str) -> Result<(), HandlerError> {
466    if let Some(verifier) = &state.auth_jwt_verifier {
467        use jsonwebtoken::{DecodingKey, Validation, decode};
468        let key = DecodingKey::from_secret(verifier.as_bytes());
469        let mut validation = Validation::default();
470        validation.validate_exp = true;
471        decode::<serde_json::Value>(token, &key, &validation).map_err(|e| {
472            HandlerError::new(StatusCode::UNAUTHORIZED, format!("invalid token: {}", e))
473        })?;
474        return Ok(());
475    }
476    if let Some(secret) = &state.auth_secret {
477        if constant_time_eq(token.as_bytes(), secret) {
478            return Ok(());
479        }
480        return Err(HandlerError::new(StatusCode::UNAUTHORIZED, "invalid token"));
481    }
482    Ok(())
483}
484
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different length are unequal and return immediately; for equal lengths
/// every byte pair is XOR-ed into an accumulator that is checked once at the end.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs))
            == 0
}
495
496fn resolve_runtime_source(
497    setup: &Setup,
498    src: &ValueSource,
499    headers: &HeaderMap,
500) -> Result<String, HandlerError> {
501    match src {
502        ValueSource::Env { name, .. } => std::env::var(name).map_err(|_| {
503            HandlerError::new(
504                StatusCode::INTERNAL_SERVER_ERROR,
505                format!("env var `{}` not set", name),
506            )
507        }),
508        ValueSource::Header { name, .. } => {
509            let value = header_get(headers, name).ok_or_else(|| {
510                HandlerError::new(
511                    StatusCode::BAD_REQUEST,
512                    format!("missing VAR header `{}`", name),
513                )
514            })?;
515            enforce_size(
516                value.len(),
517                setup.max_header_size,
518                StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
519                || format!("VAR header `{}` exceeds max size", name),
520            )?;
521            Ok(value)
522        }
523        ValueSource::Literal { value, .. } => Ok(value.clone()),
524    }
525}
526
527#[derive(Debug)]
528struct HandlerError {
529    status: StatusCode,
530    message: String,
531}
532
533impl HandlerError {
534    fn new(status: StatusCode, msg: impl Into<String>) -> Self {
535        Self {
536            status,
537            message: msg.into(),
538        }
539    }
540}
541
542impl IntoResponse for HandlerError {
543    fn into_response(self) -> Response {
544        error_response(self.status, &self.message)
545    }
546}
547
548fn error_response(status: StatusCode, msg: &str) -> Response {
549    let mut resp = Response::new(format!("{}\n", msg).into());
550    *resp.status_mut() = status;
551    resp.headers_mut().insert(
552        header::CONTENT_TYPE,
553        header::HeaderValue::from_static("text/plain; charset=utf-8"),
554    );
555    resp
556}