use std::collections::HashMap;
use std::path::Path;
use std::str::FromStr;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// HTTP methods supported by this configuration.
///
/// Because of `rename_all = "UPPERCASE"`, the serde representation of each
/// variant is its upper-case name (`GET`, `POST`, ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum Method {
    /// The `GET` method.
    Get,
    /// The `POST` method.
    Post,
    /// The `PUT` method.
    Put,
    /// The `PATCH` method.
    Patch,
    /// The `DELETE` method.
    Delete,
}
impl Method {
    /// Every supported method; the lookup table behind [`Method::from_http_str`].
    const SUPPORTED: [Method; 5] = [
        Method::Get,
        Method::Post,
        Method::Put,
        Method::Patch,
        Method::Delete,
    ];

    /// Looks up a method by name, ignoring ASCII case.
    ///
    /// Returns `None` when `s` does not name a supported method. The
    /// comparison is done in place with `eq_ignore_ascii_case`, so no
    /// temporary upper-cased `String` is allocated per call.
    pub fn from_http_str(s: &str) -> Option<Self> {
        Self::SUPPORTED
            .iter()
            .copied()
            .find(|method| s.eq_ignore_ascii_case(method.as_str()))
    }

    /// Returns the canonical upper-case name of the method.
    pub fn as_str(&self) -> &'static str {
        match self {
            Method::Get => "GET",
            Method::Post => "POST",
            Method::Put => "PUT",
            Method::Patch => "PATCH",
            Method::Delete => "DELETE",
        }
    }
}
/// Allows `"get".parse::<Method>()`, delegating to [`Method::from_http_str`].
impl FromStr for Method {
    type Err = UnknownMethodError;

    /// Parses a method name; on failure the error carries the rejected input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match Method::from_http_str(s) {
            Some(method) => Ok(method),
            None => Err(UnknownMethodError(s.to_string())),
        }
    }
}
/// Error produced by `Method`'s `FromStr` implementation for an unsupported name.
///
/// The wrapped `String` is the rejected input; it is interpolated into the
/// display message through `{0}`.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[error("unsupported HTTP method: {0}")]
pub struct UnknownMethodError(pub String);
/// Optional request criteria attached to a route.
///
/// Every field may be omitted in the serialized form. How the criteria are
/// compared against a request is decided outside this file.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct RequestMatch {
    /// Query parameters, keyed by name; empty when omitted.
    #[serde(default)]
    pub query: HashMap<String, String>,

    /// Headers, keyed by name; empty when omitted.
    #[serde(default)]
    pub headers: HashMap<String, String>,

    /// Arbitrary JSON-like body value; left out of the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub body: Option<Value>,
}
/// Description of a single canned response.
///
/// All fields are optional when deserializing; missing ones take the defaults
/// noted on each field.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ResponseConfig {
    /// Status code; `default_status()` supplies the value when omitted.
    #[serde(default = "default_status")]
    pub status: u16,

    /// Response headers, keyed by name; empty when omitted.
    #[serde(default)]
    pub headers: HashMap<String, String>,

    /// Arbitrary JSON-like body value; left out of the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub body: Option<Value>,

    /// Optional delay, (de)serialized as text through the `duration_option`
    /// module; left out of the output when `None`.
    #[serde(default, with = "duration_option", skip_serializing_if = "Option::is_none")]
    pub delay: Option<Duration>,

    /// Defaults to `false`.
    // NOTE(review): presumably asks the server to close the connection after
    // responding; the consumer is not visible here, so confirm.
    #[serde(default)]
    pub close_connection: bool,
}
/// Mirrors the per-field serde defaults of `ResponseConfig`.
impl Default for ResponseConfig {
    /// Builds a response with the default status, no headers, no body,
    /// no delay, and `close_connection` unset.
    fn default() -> Self {
        Self {
            status: default_status(),
            headers: HashMap::default(),
            body: None,
            delay: None,
            close_connection: false,
        }
    }
}
/// Status code used when a response does not specify one.
fn default_status() -> u16 {
    const OK: u16 = 200;
    OK
}
/// Response section of a route: either one response or a list of them.
///
/// The enum is untagged, so serde tries the variants in declaration order.
/// Every field of `ResponseConfig` has a default, which means `Single` would
/// accept any mapping; `Sequence` must therefore stay first.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ResponseSpec {
    /// A mapping with a `sequence` key holding a list of responses.
    Sequence { sequence: Vec<ResponseConfig> },
    /// A single response written inline.
    Single(ResponseConfig),
}
impl ResponseSpec {
pub fn into_responses(self) -> Vec<ResponseConfig> {
match self {
ResponseSpec::Single(r) => vec![r],
ResponseSpec::Sequence { sequence } => sequence,
}
}
}
/// A route without a `response` section gets one default response.
impl Default for ResponseSpec {
    /// Returns `Single` wrapping `ResponseConfig::default()`.
    fn default() -> Self {
        let response = ResponseConfig::default();
        Self::Single(response)
    }
}
/// One configured route: a method and path plus the response to produce.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Route {
    /// HTTP method of the route (required).
    pub method: Method,

    /// Path of the route (required); stored verbatim as written in the config.
    pub path: String,

    /// Optional extra request criteria; left out of the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub when: Option<RequestMatch>,

    /// Response specification; falls back to `ResponseSpec::default()` when omitted.
    #[serde(default)]
    pub response: ResponseSpec,
}
/// Top-level configuration document.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Config {
    /// Listen address; `default_listen()` supplies the value when omitted.
    #[serde(default = "default_listen")]
    pub listen: String,

    /// Configured routes; empty when omitted.
    #[serde(default)]
    pub routes: Vec<Route>,
}
/// Mirrors the per-field serde defaults of `Config`.
impl Default for Config {
    /// Returns a configuration with the default listen address and no routes.
    fn default() -> Self {
        Self {
            listen: default_listen(),
            routes: Vec::default(),
        }
    }
}
/// Listen address used when the configuration does not specify one.
// NOTE(review): the value has no host part; confirm that whatever binds the
// socket accepts this form.
fn default_listen() -> String {
    String::from(":8080")
}
impl Config {
    /// Parses a configuration from YAML text.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Parse`] when `serde_yaml` cannot deserialize
    /// `input` into a `Config`.
    pub fn parse(input: &str) -> Result<Self, ConfigError> {
        let parsed: Self = serde_yaml::from_str(input)?;
        Ok(parsed)
    }

    /// Reads the file at `path` and parses its contents with [`Config::parse`].
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Read`] when the file cannot be read as a string,
    /// or [`ConfigError::Parse`] when its contents fail to parse.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
        std::fs::read_to_string(path.as_ref())
            .map_err(ConfigError::Read)
            .and_then(|text| Self::parse(&text))
    }
}
/// Failures that can occur while loading a `Config`.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The configuration file could not be read; wraps the I/O error.
    #[error("could not read config file: {0}")]
    Read(#[source] std::io::Error),

    /// The text could not be deserialized; converts from `serde_yaml::Error`
    /// automatically (usable with `?`).
    #[error("could not parse config: {0}")]
    Parse(#[from] serde_yaml::Error),
}
mod duration_option {
use super::*;
pub fn serialize<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match value {
None => serializer.serialize_none(),
Some(d) => serializer.serialize_str(&humantime::format_duration(*d).to_string()),
}
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: serde::Deserializer<'de>,
{
let opt: Option<String> = Option::deserialize(deserializer)?;
match opt {
None => Ok(None),
Some(raw) => humantime::parse_duration(&raw)
.map(Some)
.map_err(serde::de::Error::custom),
}
}
}
/// Unit tests for parsing, defaults and round-tripping of the configuration.
///
/// YAML fixtures are written with explicit nesting indentation: without it the
/// nested keys (`path`, `when`, `response`, ...) would be read as top-level
/// keys and the fixtures would not describe the intended documents.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn method_round_trip_uppercase() {
        let yaml = "GET";
        let method: Method = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(method, Method::Get);
        let back: String = serde_yaml::to_string(&method).unwrap();
        assert!(back.contains("GET"));
    }

    #[test]
    fn method_from_http_str_is_case_insensitive() {
        assert_eq!(Method::from_http_str("get"), Some(Method::Get));
        assert_eq!(Method::from_http_str("Delete"), Some(Method::Delete));
        assert_eq!(Method::from_http_str("FOO"), None);
    }

    #[test]
    fn empty_uses_defaults() {
        let cfg = Config::parse("").unwrap();
        assert_eq!(cfg.listen, ":8080");
        assert!(cfg.routes.is_empty());
    }

    #[test]
    fn parses_full_example() {
        let yaml = r#"
listen: ":9090"
routes:
  - method: GET
    path: /users/{id}
    when:
      query:
        role: admin
    response:
      status: 200
      delay: 2s
      body:
        id: "{{path.id}}"
"#;
        let cfg = Config::parse(yaml).unwrap();
        assert_eq!(cfg.listen, ":9090");
        assert_eq!(cfg.routes.len(), 1);
        let route = &cfg.routes[0];
        assert_eq!(route.method, Method::Get);
        assert_eq!(route.path, "/users/{id}");
        assert_eq!(
            route.when.as_ref().unwrap().query.get("role").unwrap(),
            "admin"
        );
        let resp = match &route.response {
            ResponseSpec::Single(r) => r,
            ResponseSpec::Sequence { .. } => panic!("expected Single"),
        };
        assert_eq!(resp.status, 200);
        assert_eq!(resp.delay, Some(Duration::from_secs(2)));
    }

    #[test]
    fn invalid_yaml_is_rejected() {
        let yaml = "listen: :8080\n routes: [broken\n";
        assert!(Config::parse(yaml).is_err());
    }

    #[test]
    fn unknown_method_is_rejected() {
        // The document is well-formed YAML, so the only thing that can fail
        // is the unsupported method name.
        let yaml = "routes:\n  - method: FOO\n    path: /x\n";
        let err = Config::parse(yaml).unwrap_err();
        assert!(matches!(err, ConfigError::Parse(_)));
        assert!(err.to_string().contains("FOO"), "unexpected error: {err}");
    }

    #[test]
    fn round_trip_keeps_listen_and_routes() {
        let yaml = r#"
listen: ":8080"
routes:
  - method: POST
    path: /items
    response:
      status: 201
"#;
        let cfg = Config::parse(yaml).unwrap();
        let reserialized = serde_yaml::to_string(&cfg).unwrap();
        let cfg2 = Config::parse(&reserialized).unwrap();
        assert_eq!(cfg, cfg2);
    }

    #[test]
    fn missing_file_errors() {
        let err = Config::from_file("/nonexistent/path/to/config.yaml").unwrap_err();
        assert!(matches!(err, ConfigError::Read(_)));
    }

    #[test]
    fn parses_sequence_response() {
        let yaml = r#"
routes:
  - method: GET
    path: /flaky
    response:
      sequence:
        - status: 500
        - status: 200
          body:
            ok: true
"#;
        let cfg = Config::parse(yaml).unwrap();
        let route = &cfg.routes[0];
        match &route.response {
            ResponseSpec::Sequence { sequence } => {
                assert_eq!(sequence.len(), 2);
                assert_eq!(sequence[0].status, 500);
                assert_eq!(sequence[1].status, 200);
                assert_eq!(
                    sequence[1].body,
                    Some(Value::Object(serde_json::Map::from_iter([(
                        "ok".to_string(),
                        Value::Bool(true)
                    )])))
                );
            }
            other => panic!("expected Sequence, got {other:?}"),
        }
    }

    #[test]
    fn parses_single_response_by_default() {
        let yaml = r#"
routes:
  - method: GET
    path: /health
    response:
      status: 200
      body:
        ok: true
"#;
        let cfg = Config::parse(yaml).unwrap();
        assert!(matches!(cfg.routes[0].response, ResponseSpec::Single(_)));
    }

    #[test]
    fn sequence_round_trip() {
        let yaml = r#"
routes:
  - method: GET
    path: /retry
    response:
      sequence:
        - status: 500
        - status: 200
"#;
        let cfg = Config::parse(yaml).unwrap();
        let reserialized = serde_yaml::to_string(&cfg).unwrap();
        let cfg2 = Config::parse(&reserialized).unwrap();
        assert_eq!(cfg, cfg2);
    }
}