mise_server/
mise.rs

1use crate::server::Server;
2use http::StatusCode;
3use metrics::{Counter, Histogram, counter, histogram};
4use serde_json::Value;
5use std::{
6    collections::HashMap,
7    net::SocketAddr,
8    panic::{AssertUnwindSafe, catch_unwind},
9    time::Instant,
10};
11use tracing::{trace, warn};
12
13pub struct Request(pub http::Request<Value>);
14pub struct Response(pub http::Response<Value>);
15
16/// Final routes are always in JSON.
17pub trait Route: Fn(Request) -> Response + 'static {}
18impl<T> Route for T where T: Fn(Request) -> Response + 'static {}
19
/// A route that responds with plain text instead of JSON.
///
/// This exists for side behavior such as returning Prometheus scrape renders.
/// Text routes are meant to be static: they take no parameters, match only on
/// the exact request path, receive no request object, and always respond with
/// a `String`.
pub trait TextRoute: Fn() -> String + 'static {}

impl<F: Fn() -> String + 'static> TextRoute for F {}
27
28/// Server resource and entry point.
29#[derive(Default)]
30pub struct Mise {
31    routes: HashMap<RouteMethod, HashMap<String, RouteContext>>,
32    text_routes: HashMap<String, Box<dyn TextRoute>>,
33}
34
35impl Mise {
36    /// Create a new default server.
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Register a get route.
42    pub fn get<F: Route>(mut self, path: &str, f: F) -> Self {
43        let get_routes = self.routes.entry(RouteMethod::Get).or_default();
44        let d: Box<dyn Route> = Box::new(f);
45        get_routes.insert(path.to_string(), (path, d).into());
46        self
47    }
48
49    /// Register a text result only.
50    /// This is used for cases such as prometheus scrape renders.
51    /// Text routes always take precedence even if the route is already defined
52    /// for any other methods.
53    pub fn text<F: TextRoute>(mut self, path: &str, f: F) -> Self {
54        self.text_routes.insert(path.to_string(), Box::new(f));
55        self
56    }
57
58    /// Registers a post route. Body is obtained from [Request::body] and it is
59    /// always a json [Value].
60    pub fn post<F: Route>(mut self, path: &str, f: F) -> Self {
61        let get_routes = self.routes.entry(RouteMethod::Post).or_default();
62        let d: Box<dyn Route> = Box::new(f);
63        get_routes.insert(path.to_string(), (path, d).into());
64        self
65    }
66
67    /// Starts the server. Blocks until the server quits.
68    /// Can panic if cannot bind the server.
69    pub fn serve(self, addr: SocketAddr) {
70        Server::serve(
71            RequestProcessor {
72                routes: self.routes,
73                text_routes: self.text_routes,
74            },
75            addr,
76        )
77        .run();
78    }
79}
80
81impl From<Request> for Value {
82    fn from(value: Request) -> Self {
83        value.0.body().to_owned()
84    }
85}
86
87impl From<Value> for Response {
88    fn from(value: Value) -> Self {
89        Response(http::Response::new(value))
90    }
91}
92
93impl From<http::StatusCode> for Response {
94    fn from(value: http::StatusCode) -> Self {
95        Response(
96            http::Response::builder()
97                .status(value)
98                .body(Value::Null)
99                .expect("Statically built body should not fail"),
100        )
101    }
102}
103
104struct RouteMetrics {
105    // Histograms also count the events so no need to also have a counter for
106    // each request. The histogram will do both.
107    call_seconds: Histogram,
108    // call_seconds only measure successful requests, otherwise use this for
109    // marking when a route returns a 500/panic
110    error_count: Counter,
111}
112
113impl RouteMetrics {
114    fn new(name: &str) -> Self {
115        Self {
116            call_seconds: histogram!("http_request_seconds", "uri" => name.to_string()),
117            error_count: counter!("http_error", "uri" => name.to_string()),
118        }
119    }
120}
121
122struct RouteContext {
123    route: Box<dyn Route>,
124    metrics: RouteMetrics,
125}
126
127impl From<(&str, Box<dyn Route>)> for RouteContext {
128    fn from(value: (&str, Box<dyn Route>)) -> Self {
129        Self {
130            route: value.1,
131            metrics: RouteMetrics::new(value.0),
132        }
133    }
134}
135
/// HTTP methods under which JSON routes can be registered.
///
/// Fieldless, so it is `Copy`: it can be passed and compared by value.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) enum RouteMethod {
    Get,
    Post,
}

/// Parses an HTTP method name, ignoring ASCII case.
///
/// Any method other than `GET` or `POST` is rejected with `Err(())`.
impl TryFrom<&str> for RouteMethod {
    type Error = ();

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        match value {
            v if v.eq_ignore_ascii_case("GET") => Ok(RouteMethod::Get),
            v if v.eq_ignore_ascii_case("POST") => Ok(RouteMethod::Post),
            _ => Err(()),
        }
    }
}
155
156pub(crate) struct RequestProcessor {
157    routes: HashMap<RouteMethod, HashMap<String, RouteContext>>,
158    text_routes: HashMap<String, Box<dyn TextRoute>>,
159}
160
161pub(crate) enum ProcessorResponse {
162    Json(Response),
163    Text(String),
164}
165
166impl RequestProcessor {
167    pub(crate) fn process(&self, req: Request) -> ProcessorResponse {
168        if let Some(txt) = self.text_routes.get(req.0.uri().path()) {
169            return ProcessorResponse::Text(txt());
170        }
171
172        let start = Instant::now();
173
174        let Ok(method) = req.0.method().as_str().try_into() else {
175            return ProcessorResponse::Json(StatusCode::METHOD_NOT_ALLOWED.into());
176        };
177        let Some(routes) = self.routes.get(&method) else {
178            return ProcessorResponse::Json(StatusCode::NOT_FOUND.into());
179        };
180        let Some(route) = routes.get(req.0.uri().path()) else {
181            return ProcessorResponse::Json(StatusCode::NOT_FOUND.into());
182        };
183
184        trace!("Received {:?}", req.0);
185        let res = catch_unwind(AssertUnwindSafe(|| (route.route)(req)));
186        let resp = match res {
187            Ok(resp) => resp,
188            Err(ref e) => {
189                route.metrics.error_count.increment(1);
190                warn!("Internal server error: {e:?}");
191                return ProcessorResponse::Json(StatusCode::INTERNAL_SERVER_ERROR.into());
192            }
193        };
194        trace!("Sending {:?}", resp.0);
195
196        route
197            .metrics
198            .call_seconds
199            .record(Instant::now().duration_since(start).as_secs_f64());
200
201        ProcessorResponse::Json(resp)
202    }
203}
204
205impl Request {
206    /// Returns the query param by name.
207    pub fn query_param(&self, name: &str) -> Option<String> {
208        let q = self.0.uri().query()?;
209        let f = format!("{name}=");
210        let idx = q.find(&f)?;
211        let end = q[idx..].find('&').unwrap_or(q.len());
212        Some(q[idx + f.len()..end].to_string())
213    }
214
215    /// If there is a body in the request, then this will be the json of that
216    pub fn body(&self) -> &Value {
217        self.0.body()
218    }
219}