use base64::Engine;
use http::header::HeaderName;
use std::net::SocketAddr;
use std::time::Duration;
use crate::challenge::error::ConfigError;
pub use crate::otoroshi::protocol::{
DEFAULT_STATE_HEADER, DEFAULT_STATE_RESP_HEADER, DEFAULT_TOKEN_EXPIRY_SECONDS,
};
/// Default port for the proxy's listen address.
pub const DEFAULT_LISTEN_PORT: u16 = 8080;
/// Default port of the backend the proxy forwards to.
pub const DEFAULT_BACKEND_PORT: u16 = 9000;
/// Default backend host (IPv4 loopback).
pub const DEFAULT_BACKEND_HOST: &str = "127.0.0.1";
/// Default request timeout, in seconds.
pub const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 30;
/// Default token TTL; aliases the protocol module's
/// `DEFAULT_TOKEN_EXPIRY_SECONDS` so the two values cannot drift apart.
pub const DEFAULT_TOKEN_TTL_SECS: i64 = DEFAULT_TOKEN_EXPIRY_SECONDS;
/// Protocol version a [`ProxyConfig`] is set up for.
///
/// `ProxyConfig::new` selects [`ProtocolVersion::V1`] when its `use_v1` flag
/// is `true` and [`ProtocolVersion::V2`] otherwise.
// NOTE(review): the wire-level differences between the two versions are not
// visible in this file — see the protocol module for details.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolVersion {
/// Version 1 of the protocol (opt-in via `use_v1`).
V1,
/// Version 2 of the protocol (used when `use_v1` is `false`).
V2,
}
/// Validated proxy settings, as assembled by `ProxyConfig::new`.
#[derive(Debug, Clone)]
pub struct ProxyConfig {
/// Listen address; `new` always uses `0.0.0.0` with the configured port.
pub listen_addr: SocketAddr,
/// Backend base URL; `new` builds it as plain `http://` from the backend
/// host and port.
pub backend_url: String,
/// Raw secret bytes (already base64-decoded when that was requested), or
/// `None` when no secret was supplied.
pub secret: Option<Vec<u8>>,
/// Validated header name for the state value.
// NOTE(review): presumably the request header carrying the challenge
// state — confirm against the proxy code.
pub state_header: HeaderName,
/// Validated header name for the state response value.
// NOTE(review): presumably the response header carrying the challenge
// answer — confirm against the proxy code.
pub state_resp_header: HeaderName,
/// Request timeout, built from a whole number of seconds.
pub request_timeout: Duration,
/// Token time-to-live; `new` guarantees it is strictly positive.
// NOTE(review): unit is presumably seconds (see `DEFAULT_TOKEN_TTL_SECS`).
pub token_ttl: i64,
/// Protocol version this configuration is for.
pub version: ProtocolVersion,
}
impl ProxyConfig {
#[allow(clippy::too_many_arguments)]
pub fn new(
port: u16,
backend_host: String,
backend_port: u16,
secret: Option<String>,
secret_base64: bool,
state_header: String,
state_resp_header: String,
timeout_secs: u64,
token_ttl: i64,
use_v1: bool,
) -> Result<Self, ConfigError> {
let state_header = HeaderName::from_bytes(state_header.as_bytes()).map_err(|e| {
ConfigError::InvalidHeader {
name: "state_header",
source: e,
}
})?;
let state_resp_header =
HeaderName::from_bytes(state_resp_header.as_bytes()).map_err(|e| {
ConfigError::InvalidHeader {
name: "state_resp_header",
source: e,
}
})?;
let version = if use_v1 {
ProtocolVersion::V1
} else {
ProtocolVersion::V2
};
if token_ttl <= 0 {
return Err(ConfigError::InvalidTokenTtl(token_ttl));
}
if port == 0 || backend_port == 0 {
return Err(ConfigError::InvalidPort);
}
if backend_host.is_empty() {
return Err(ConfigError::InvalidBackendHost(
"host cannot be empty".to_string(),
));
}
if backend_host.chars().any(|c| c.is_whitespace()) {
return Err(ConfigError::InvalidBackendHost(
"host cannot contain whitespace".to_string(),
));
}
let secret_bytes = match secret {
Some(s) if secret_base64 => Some(base64::engine::general_purpose::STANDARD.decode(&s)?),
Some(s) => Some(s.into_bytes()),
None => None,
};
Ok(ProxyConfig {
listen_addr: SocketAddr::from(([0, 0, 0, 0], port)),
backend_url: format!("http://{}:{}", backend_host, backend_port),
secret: secret_bytes,
state_header,
state_resp_header,
request_timeout: Duration::from_secs(timeout_secs),
token_ttl,
version,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Arguments for `ProxyConfig::new`, pre-filled with the values most
    /// tests share so that each test spells out only what it varies.
    struct Args {
        port: u16,
        backend_host: &'static str,
        backend_port: u16,
        secret: Option<&'static str>,
        secret_base64: bool,
        state_header: String,
        state_resp_header: String,
        timeout_secs: u64,
        token_ttl: i64,
        use_v1: bool,
    }

    impl Default for Args {
        fn default() -> Self {
            Args {
                port: 8080,
                backend_host: "127.0.0.1",
                backend_port: 9000,
                secret: Some("secret"),
                secret_base64: false,
                state_header: DEFAULT_STATE_HEADER.to_string(),
                state_resp_header: DEFAULT_STATE_RESP_HEADER.to_string(),
                timeout_secs: 30,
                token_ttl: 30,
                use_v1: false,
            }
        }
    }

    impl Args {
        /// Forwards the collected arguments to `ProxyConfig::new`.
        fn build(self) -> Result<ProxyConfig, ConfigError> {
            ProxyConfig::new(
                self.port,
                self.backend_host.to_string(),
                self.backend_port,
                self.secret.map(String::from),
                self.secret_base64,
                self.state_header,
                self.state_resp_header,
                self.timeout_secs,
                self.token_ttl,
                self.use_v1,
            )
        }
    }

    // Happy paths.

    #[test]
    fn test_config_new_with_defaults_v2() {
        let config = Args {
            port: DEFAULT_LISTEN_PORT,
            backend_host: DEFAULT_BACKEND_HOST,
            backend_port: DEFAULT_BACKEND_PORT,
            secret: Some("test-secret"),
            timeout_secs: DEFAULT_REQUEST_TIMEOUT_SECS,
            token_ttl: DEFAULT_TOKEN_TTL_SECS,
            ..Args::default()
        }
        .build();
        assert!(config.is_ok());
        let config = config.unwrap();
        assert_eq!(config.listen_addr.port(), 8080);
        assert_eq!(config.backend_url, "http://127.0.0.1:9000");
        assert_eq!(config.secret, Some(b"test-secret".to_vec()));
        assert_eq!(config.state_header.as_str(), "otoroshi-state");
        assert_eq!(config.state_resp_header.as_str(), "otoroshi-state-resp");
        assert_eq!(config.request_timeout, Duration::from_secs(30));
        assert_eq!(config.token_ttl, 30);
        assert_eq!(config.version, ProtocolVersion::V2);
    }

    #[test]
    fn test_config_v1_mode() {
        let config = Args {
            secret: None,
            use_v1: true,
            ..Args::default()
        }
        .build();
        assert!(config.is_ok());
        let config = config.unwrap();
        assert_eq!(config.version, ProtocolVersion::V1);
        assert!(config.secret.is_none());
    }

    #[test]
    fn test_config_custom_values() {
        let config = Args {
            port: 3000,
            backend_host: "localhost",
            backend_port: 8000,
            secret: Some("my-secret"),
            state_header: "X-Challenge".to_string(),
            state_resp_header: "X-Challenge-Resp".to_string(),
            timeout_secs: 60,
            token_ttl: 45,
            ..Args::default()
        }
        .build();
        assert!(config.is_ok());
        let config = config.unwrap();
        assert_eq!(config.listen_addr.port(), 3000);
        assert_eq!(config.backend_url, "http://localhost:8000");
        assert_eq!(config.request_timeout, Duration::from_secs(60));
        assert_eq!(config.token_ttl, 45);
    }

    #[test]
    fn test_config_base64_secret() {
        // "aGVsbG8=" is the standard base64 encoding of "hello".
        let config = Args {
            secret: Some("aGVsbG8="),
            secret_base64: true,
            ..Args::default()
        }
        .build();
        assert!(config.is_ok());
        let config = config.unwrap();
        assert_eq!(config.secret, Some(b"hello".to_vec()));
    }

    // Rejected inputs.

    #[test]
    fn test_config_invalid_header() {
        let config = Args {
            state_header: "Invalid Header With Spaces".to_string(),
            ..Args::default()
        }
        .build();
        assert!(config.is_err());
    }

    #[test]
    fn test_config_invalid_base64_secret() {
        let config = Args {
            secret: Some("not-valid-base64!!!"),
            secret_base64: true,
            ..Args::default()
        }
        .build();
        assert!(config.is_err());
        assert!(matches!(
            config.unwrap_err(),
            ConfigError::InvalidBase64Secret(_)
        ));
    }

    #[test]
    fn test_config_invalid_ttl_zero() {
        let config = Args {
            token_ttl: 0,
            ..Args::default()
        }
        .build();
        assert!(config.is_err());
        assert!(matches!(
            config.unwrap_err(),
            ConfigError::InvalidTokenTtl(0)
        ));
    }

    #[test]
    fn test_config_invalid_ttl_negative() {
        let config = Args {
            token_ttl: -10,
            ..Args::default()
        }
        .build();
        assert!(config.is_err());
        assert!(matches!(
            config.unwrap_err(),
            ConfigError::InvalidTokenTtl(-10)
        ));
    }

    #[test]
    fn test_config_invalid_backend_host_empty() {
        let config = Args {
            backend_host: "",
            ..Args::default()
        }
        .build();
        assert!(config.is_err());
        assert!(matches!(
            config.unwrap_err(),
            ConfigError::InvalidBackendHost(_)
        ));
    }

    #[test]
    fn test_config_invalid_backend_host_whitespace() {
        let config = Args {
            backend_host: "host with spaces",
            ..Args::default()
        }
        .build();
        assert!(config.is_err());
        assert!(matches!(
            config.unwrap_err(),
            ConfigError::InvalidBackendHost(_)
        ));
    }

    #[test]
    fn test_config_invalid_port_zero() {
        let config = Args {
            port: 0,
            ..Args::default()
        }
        .build();
        assert!(config.is_err());
        assert!(matches!(config.unwrap_err(), ConfigError::InvalidPort));
    }

    #[test]
    fn test_config_invalid_backend_port_zero() {
        let config = Args {
            backend_port: 0,
            ..Args::default()
        }
        .build();
        assert!(config.is_err());
        assert!(matches!(config.unwrap_err(), ConfigError::InvalidPort));
    }
}