Skip to main content

mii_http/
server.rs

1//! HTTP server runtime built on top of axum.
2
3use crate::exec::{self, BodyValue, ExecContext, ExecOutput};
4use crate::spec::*;
5use crate::value::{self, ValidationError};
6use axum::{
7    Router,
8    body::Bytes,
9    extract::{Path as AxPath, Query, State},
10    http::{HeaderMap, StatusCode, header},
11    response::{IntoResponse, Response},
12    routing::{MethodFilter, MethodRouter},
13};
14use std::collections::{BTreeMap, HashMap};
15use std::net::SocketAddr;
16use std::sync::Arc;
17use std::time::Duration;
18use tokio::net::TcpListener;
19
20#[derive(Clone)]
21struct AppState {
22    spec: Arc<Spec>,
23    auth_secret: Option<Vec<u8>>,
24    auth_jwt_verifier: Option<String>,
25    dry_run: bool,
26}
27
28pub async fn serve(spec: Spec, addr: SocketAddr, dry_run: bool) -> std::io::Result<()> {
29    tracing::debug!(addr = %addr, dry_run, endpoints = spec.endpoints.len(), "server::serve");
30    let auth_secret = match &spec.setup.token_secret {
31        Some(src) => Some(resolve_static_source(src)?.into_bytes()),
32        None => None,
33    };
34    let auth_jwt_verifier = match &spec.setup.jwt_verifier {
35        Some(src) => Some(resolve_static_source(src)?),
36        None => None,
37    };
38    let state = AppState {
39        spec: Arc::new(spec),
40        auth_secret,
41        auth_jwt_verifier,
42        dry_run,
43    };
44    let router = build_router(state.clone());
45    let listener = TcpListener::bind(addr).await?;
46    if dry_run {
47        tracing::info!("mii-http listening on {} (dry-run: commands will not be executed)", addr);
48    } else {
49        tracing::info!("mii-http listening on {}", addr);
50    }
51    axum::serve(listener, router.into_make_service())
52        .await
53        .map_err(|e| std::io::Error::other(e.to_string()))
54}
55
56fn resolve_static_source(src: &ValueSource) -> std::io::Result<String> {
57    match src {
58        ValueSource::Env { name, .. } => std::env::var(name).map_err(|_| {
59            std::io::Error::other(format!("env var `{}` not set", name))
60        }),
61        ValueSource::Literal { value, .. } => Ok(value.clone()),
62        ValueSource::Header { .. } => Err(std::io::Error::other(
63            "[HEADER ...] is not valid for static setup values",
64        )),
65    }
66}
67
68fn build_router(state: AppState) -> Router {
69    tracing::debug!("server::build_router");
70    let mut routes: HashMap<String, MethodRouter<AppState>> = HashMap::new();
71    let prefix = compute_prefix(&state.spec.setup);
72
73    for (idx, ep) in state.spec.endpoints.iter().enumerate() {
74        let path = format!("{}{}", prefix, axum_path(&ep.path_segments));
75        tracing::debug!(method = ep.method.as_str(), path = %path, "server::build_router: mounting route");
76        let entry = routes.entry(path).or_insert_with(MethodRouter::<AppState>::new);
77        let idx_clone = idx;
78        let mr = MethodRouter::<AppState>::new().on(
79            method_filter(ep.method),
80            move |s: State<AppState>,
81                  p: AxPath<HashMap<String, String>>,
82                  q: Query<HashMap<String, String>>,
83                  h: HeaderMap,
84                  b: Bytes| handle(s, p, q, h, b, idx_clone),
85        );
86        let merged = std::mem::take(entry).merge(mr);
87        *entry = merged;
88    }
89
90    let mut router = Router::new();
91    for (path, mr) in routes {
92        router = router.route(&path, mr);
93    }
94    router.with_state(state)
95}
96
97fn method_filter(m: Method) -> MethodFilter {
98    match m {
99        Method::Get => MethodFilter::GET,
100        Method::Post => MethodFilter::POST,
101        Method::Put => MethodFilter::PUT,
102        Method::Delete => MethodFilter::DELETE,
103        Method::Patch => MethodFilter::PATCH,
104    }
105}
106
107fn compute_prefix(setup: &Setup) -> String {
108    let base = setup.base.clone().unwrap_or_default();
109    let version = setup
110        .version
111        .map(|v| format!("/v{}", v))
112        .unwrap_or_default();
113    format!("{}{}", base, version)
114}
115
116fn axum_path(segs: &[PathSegment]) -> String {
117    let mut out = String::new();
118    for seg in segs {
119        out.push('/');
120        match seg {
121            PathSegment::Literal(s) => out.push_str(s),
122            PathSegment::Param { name, .. } => {
123                out.push(':');
124                out.push_str(name);
125            }
126        }
127    }
128    if out.is_empty() {
129        "/".into()
130    } else {
131        out
132    }
133}
134
135async fn handle(
136    State(state): State<AppState>,
137    AxPath(path): AxPath<HashMap<String, String>>,
138    Query(query): Query<HashMap<String, String>>,
139    headers: HeaderMap,
140    body: Bytes,
141    endpoint_idx: usize,
142) -> Response {
143    let ep = match state.spec.endpoints.get(endpoint_idx) {
144        Some(e) => e,
145        None => return error_response(StatusCode::INTERNAL_SERVER_ERROR, "endpoint missing"),
146    };
147    tracing::info!(method = ep.method.as_str(), path = %ep.path, "server::handle: incoming request");
148    match handle_inner(&state, ep, path, query, headers, body).await {
149        Ok(r) => r,
150        Err(err) => {
151            tracing::warn!(method = ep.method.as_str(), path = %ep.path, status = %err.status, error = %err.message, "server::handle: returning error");
152            err.into_response()
153        }
154    }
155}
156
157async fn handle_inner(
158    state: &AppState,
159    ep: &Endpoint,
160    path: HashMap<String, String>,
161    query: HashMap<String, String>,
162    headers: HeaderMap,
163    body: Bytes,
164) -> Result<Response, HandlerError> {
165    let setup = &state.spec.setup;
166
167    enforce_body_size(setup, &body)?;
168    authenticate(state, &headers)?;
169
170    let ctx = ExecContext {
171        query: validate_query(setup, ep, &query)?,
172        headers: validate_headers(setup, ep, &headers)?,
173        path: validate_path(ep, &path)?,
174        vars: resolve_vars(ep, &headers)?,
175        body: build_body(ep, body)?,
176    };
177
178    let timeout = setup.timeout_ms.map(Duration::from_millis);
179
180    if state.dry_run {
181        let preview = exec::preview_pipeline(&ep.exec.pipeline, &ctx);
182        tracing::info!(
183            method = ep.method.as_str(),
184            path = %ep.path,
185            stages = ?preview,
186            "dry-run: skipping execution",
187        );
188        let mut body_text = String::from("[dry-run] would execute:\n");
189        for stage in &preview {
190            body_text.push_str("  ");
191            body_text.push_str(stage);
192            body_text.push('\n');
193        }
194        let mut resp = Response::new(body_text.into());
195        resp.headers_mut().insert(
196            header::CONTENT_TYPE,
197            header::HeaderValue::from_static("text/plain; charset=utf-8"),
198        );
199        return Ok(resp);
200    }
201
202    let ExecOutput {
203        status,
204        stdout,
205        stderr,
206    } = exec::run_pipeline(&ep.exec.pipeline, &ctx, timeout)
207        .await
208        .map_err(|e| HandlerError::new(StatusCode::INTERNAL_SERVER_ERROR, e))?;
209
210    if status != 0 {
211        tracing::warn!(
212            method = ep.method.as_str(),
213            path = %ep.path,
214            status,
215            stderr = %String::from_utf8_lossy(&stderr),
216            "exec returned non-zero"
217        );
218        return Err(HandlerError::new(
219            StatusCode::INTERNAL_SERVER_ERROR,
220            format!("command exited with status {}", status),
221        ));
222    }
223
224    let content_type = ep
225        .response_type
226        .clone()
227        .unwrap_or_else(|| "text/plain; charset=utf-8".into());
228    let mut resp = Response::new(stdout.into());
229    resp.headers_mut().insert(
230        header::CONTENT_TYPE,
231        content_type.parse().unwrap_or_else(|_| {
232            header::HeaderValue::from_static("text/plain; charset=utf-8")
233        }),
234    );
235    Ok(resp)
236}
237
238fn check_validation(r: Result<(), ValidationError>, scope: &str) -> Result<(), HandlerError> {
239    r.map_err(|e| HandlerError::new(StatusCode::BAD_REQUEST, format!("{}: {}", scope, e.message)))
240}
241
242fn enforce_body_size(setup: &Setup, body: &Bytes) -> Result<(), HandlerError> {
243    if let Some(max) = setup.max_body_size {
244        if body.len() as u64 > max {
245            return Err(HandlerError::new(
246                StatusCode::PAYLOAD_TOO_LARGE,
247                format!("body exceeds max size of {} bytes", max),
248            ));
249        }
250    }
251    Ok(())
252}
253
254fn authenticate(state: &AppState, headers: &HeaderMap) -> Result<(), HandlerError> {
255    tracing::debug!("server::authenticate");
256    if let Some(AuthSpec::BearerHeader { header: hname, .. }) = &state.spec.setup.auth {
257        let token = extract_bearer(headers, hname)?;
258        verify_token(state, &token)?;
259    }
260    Ok(())
261}
262
263fn enforce_size(
264    actual: usize,
265    max: Option<u64>,
266    status: StatusCode,
267    label: impl FnOnce() -> String,
268) -> Result<(), HandlerError> {
269    if let Some(max) = max {
270        if actual as u64 > max {
271            return Err(HandlerError::new(status, label()));
272        }
273    }
274    Ok(())
275}
276
277fn require_or_optional<T>(
278    found: Option<T>,
279    optional: bool,
280    missing_msg: impl FnOnce() -> String,
281) -> Result<Option<T>, HandlerError> {
282    match found {
283        Some(v) => Ok(Some(v)),
284        None if optional => Ok(None),
285        None => Err(HandlerError::new(StatusCode::BAD_REQUEST, missing_msg())),
286    }
287}
288
289fn validate_query(
290    setup: &Setup,
291    ep: &Endpoint,
292    query: &HashMap<String, String>,
293) -> Result<BTreeMap<String, String>, HandlerError> {
294    tracing::debug!(endpoint = %ep.path, fields = ep.query_params.len(), "server::validate_query");
295    let mut out = BTreeMap::new();
296    for f in &ep.query_params {
297        let v = require_or_optional(query.get(&f.name).cloned(), f.optional, || {
298            format!("missing query parameter `{}`", f.name)
299        })?;
300        if let Some(v) = v {
301            enforce_size(v.len(), setup.max_query_param_size, StatusCode::URI_TOO_LONG, || {
302                format!("query param `{}` exceeds max size", f.name)
303            })?;
304            check_validation(value::validate_text(&v, &f.ty), &format!("query `{}`", f.name))?;
305            out.insert(f.name.clone(), v);
306        }
307    }
308    Ok(out)
309}
310
311fn validate_headers(
312    setup: &Setup,
313    ep: &Endpoint,
314    headers: &HeaderMap,
315) -> Result<BTreeMap<String, String>, HandlerError> {
316    tracing::debug!(endpoint = %ep.path, fields = ep.headers.len(), "server::validate_headers");
317    let mut out = BTreeMap::new();
318    for f in &ep.headers {
319        let v = require_or_optional(header_get(headers, &f.name), f.optional, || {
320            format!("missing header `{}`", f.name)
321        })?;
322        if let Some(v) = v {
323            enforce_size(
324                v.len(),
325                setup.max_header_size,
326                StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
327                || format!("header `{}` exceeds max size", f.name),
328            )?;
329            check_validation(value::validate_text(&v, &f.ty), &format!("header `{}`", f.name))?;
330            out.insert(f.name.clone(), v);
331        }
332    }
333    Ok(out)
334}
335
336fn validate_path(
337    ep: &Endpoint,
338    path: &HashMap<String, String>,
339) -> Result<BTreeMap<String, String>, HandlerError> {
340    tracing::debug!(endpoint = %ep.path, "server::validate_path");
341    let mut out = BTreeMap::new();
342    for seg in &ep.path_segments {
343        if let PathSegment::Param { name, ty, .. } = seg {
344            let v = path.get(name).cloned().ok_or_else(|| {
345                HandlerError::new(StatusCode::BAD_REQUEST, format!("missing path param `{}`", name))
346            })?;
347            check_validation(value::validate_text(&v, ty), &format!("path `{}`", name))?;
348            out.insert(name.clone(), v);
349        }
350    }
351    Ok(out)
352}
353
354fn resolve_vars(
355    ep: &Endpoint,
356    headers: &HeaderMap,
357) -> Result<BTreeMap<String, String>, HandlerError> {
358    tracing::debug!(endpoint = %ep.path, vars = ep.vars.len(), "server::resolve_vars");
359    let mut out = BTreeMap::new();
360    for v in &ep.vars {
361        let resolved = resolve_runtime_source(&v.source, headers)
362            .map_err(|e| HandlerError::new(StatusCode::INTERNAL_SERVER_ERROR, e))?;
363        out.insert(v.name.clone(), resolved);
364    }
365    Ok(out)
366}
367
368fn build_body(ep: &Endpoint, body: Bytes) -> Result<BodyValue, HandlerError> {
369    tracing::debug!(endpoint = %ep.path, body_len = body.len(), "server::build_body");
370    Ok(match &ep.body {
371        None => BodyValue::None,
372        Some(BodySpec::String { .. }) => BodyValue::Text(
373            String::from_utf8(body.to_vec())
374                .map_err(|_| HandlerError::new(StatusCode::BAD_REQUEST, "body is not valid UTF-8"))?,
375        ),
376        Some(BodySpec::Binary { .. }) => BodyValue::Binary(body),
377        Some(BodySpec::Json { schema, .. }) => {
378            let v: serde_json::Value = serde_json::from_slice(&body).map_err(|e| {
379                HandlerError::new(StatusCode::BAD_REQUEST, format!("invalid JSON body: {}", e))
380            })?;
381            if let Some(schema) = schema {
382                check_validation(value::validate_json(&v, schema), "json body")?;
383            }
384            BodyValue::Json(v)
385        }
386        Some(BodySpec::Form { fields, .. }) => {
387            let parsed: BTreeMap<String, String> =
388                form_urlencoded::parse(&body).into_owned().collect();
389            for f in fields {
390                let v = require_or_optional(parsed.get(&f.name), f.optional, || {
391                    format!("missing form field `{}`", f.name)
392                })?;
393                if let Some(v) = v {
394                    check_validation(
395                        value::validate_text(v, &f.ty),
396                        &format!("form field `{}`", f.name),
397                    )?;
398                }
399            }
400            BodyValue::Form(parsed)
401        }
402    })
403}
404
405fn header_get(headers: &HeaderMap, name: &str) -> Option<String> {
406    headers
407        .get(name)
408        .and_then(|v| v.to_str().ok())
409        .map(|s| s.to_string())
410}
411
412fn extract_bearer(headers: &HeaderMap, header_name: &str) -> Result<String, HandlerError> {
413    let raw = header_get(headers, header_name).ok_or_else(|| {
414        HandlerError::new(StatusCode::UNAUTHORIZED, format!("missing `{}`", header_name))
415    })?;
416    let token = raw
417        .strip_prefix("Bearer ")
418        .or_else(|| raw.strip_prefix("bearer "))
419        .unwrap_or(&raw)
420        .trim()
421        .to_string();
422    if token.is_empty() {
423        return Err(HandlerError::new(
424            StatusCode::UNAUTHORIZED,
425            "empty bearer token",
426        ));
427    }
428    Ok(token)
429}
430
431fn verify_token(state: &AppState, token: &str) -> Result<(), HandlerError> {
432    if let Some(verifier) = &state.auth_jwt_verifier {
433        use jsonwebtoken::{DecodingKey, Validation, decode};
434        let key = DecodingKey::from_secret(verifier.as_bytes());
435        let mut validation = Validation::default();
436        validation.validate_exp = true;
437        decode::<serde_json::Value>(token, &key, &validation).map_err(|e| {
438            HandlerError::new(StatusCode::UNAUTHORIZED, format!("invalid token: {}", e))
439        })?;
440        return Ok(());
441    }
442    if let Some(secret) = &state.auth_secret {
443        if constant_time_eq(token.as_bytes(), secret) {
444            return Ok(());
445        }
446        return Err(HandlerError::new(StatusCode::UNAUTHORIZED, "invalid token"));
447    }
448    Ok(())
449}
450
/// Returns whether `a` and `b` are equal, in time that depends only on `b.len()`.
///
/// Pass the secret as `b`. The previous early return on a length mismatch let an
/// attacker learn the secret's length by timing tokens of different lengths.
/// Here the length difference is folded into the accumulator instead.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let mut diff: usize = a.len() ^ b.len();
    for (i, &y) in b.iter().enumerate() {
        // Positions past the end of `a` compare against 0; the length term above
        // already guarantees a mismatch in that case.
        let x = a.get(i).copied().unwrap_or(0);
        diff |= usize::from(x ^ y);
    }
    diff == 0
}
461
462fn resolve_runtime_source(src: &ValueSource, headers: &HeaderMap) -> Result<String, String> {
463    match src {
464        ValueSource::Env { name, .. } => {
465            std::env::var(name).map_err(|_| format!("env var `{}` not set", name))
466        }
467        ValueSource::Header { name, .. } => header_get(headers, name)
468            .ok_or_else(|| format!("header `{}` not present", name)),
469        ValueSource::Literal { value, .. } => Ok(value.clone()),
470    }
471}
472
473#[derive(Debug)]
474struct HandlerError {
475    status: StatusCode,
476    message: String,
477}
478
479impl HandlerError {
480    fn new(status: StatusCode, msg: impl Into<String>) -> Self {
481        Self {
482            status,
483            message: msg.into(),
484        }
485    }
486}
487
488impl IntoResponse for HandlerError {
489    fn into_response(self) -> Response {
490        error_response(self.status, &self.message)
491    }
492}
493
494fn error_response(status: StatusCode, msg: &str) -> Response {
495    let mut resp = Response::new(format!("{}\n", msg).into());
496    *resp.status_mut() = status;
497    resp.headers_mut().insert(
498        header::CONTENT_TYPE,
499        header::HeaderValue::from_static("text/plain; charset=utf-8"),
500    );
501    resp
502}
503