use std::time::Duration;
#[cfg(feature = "dto")]
use dto::{FromProto, IntoProto};
#[cfg(feature = "validation")]
use ipext::IpExt;
use monostate::MustBe;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none, DurationSecondsWithFrac};
use strum::Display;
#[cfg(feature = "validation")]
use thiserror::Error;
#[cfg(feature = "validation")]
use url::Url;
#[cfg(feature = "validation")]
use validator::{Validate, ValidationError};
#[cfg(feature = "validation")]
use crate::validation_util::validation_error;
/// HTTP request method of a `Webhook` (its `http_method` field).
///
/// Serialized by serde and rendered by `Display` (strum) in upper case,
/// e.g. `"POST"`; the clap `ValueEnum` derive is also configured for
/// upper-case names. In `client` builds the enum is `#[non_exhaustive]`,
/// so `match`es outside this crate need a wildcard arm.
#[derive(Debug, Display, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "client", non_exhaustive)]
#[cfg_attr(
feature = "dto",
derive(IntoProto, FromProto),
proto(target = "proto::common::HttpMethod")
)]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
#[cfg_attr(feature = "clap", clap(rename_all = "UPPER"))]
#[serde(rename_all = "UPPERCASE")]
#[strum(serialize_all = "UPPERCASE")]
pub enum HttpMethod {
Delete,
Get,
Head,
Patch,
Post,
Put,
}
/// Webhook settings: target URL, HTTP method, timeout and an optional retry
/// policy.
///
/// On the wire the object carries a `"type": "webhook"` discriminator. In
/// `server` builds unknown fields are rejected and any missing field is
/// filled from `Webhook::default()`.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "validation", derive(Validate))]
#[cfg_attr(
feature = "dto",
derive(IntoProto, FromProto),
proto(target = "proto::common::Webhook")
)]
#[cfg_attr(feature = "server", serde(deny_unknown_fields), serde(default))]
pub struct Webhook {
/// Discriminator serialized as `"type"`; only the literal `"webhook"` is
/// accepted when deserializing.
#[serde(rename = "type")]
_kind: MustBe!("webhook"),
/// Target URL. Validation (`required`) rejects `None`; a present value must
/// pass `validate_webhook_url`.
#[cfg_attr(
feature = "validation",
validate(required, custom = "validate_webhook_url")
)]
#[cfg_attr(feature = "dto", proto(required))]
pub url: Option<String>,
/// HTTP method (`POST` by default in `server` builds).
pub http_method: HttpMethod,
/// Timeout, serialized as (fractional) seconds; validation requires it to be
/// between 1 and 30 seconds inclusive.
// NOTE(review): `Duration::from_secs_f64` panics on negative, non-finite or
// overflowing input, so a malformed proto value would panic inside
// `FromProto` — consider a checked/saturating conversion.
#[cfg_attr(feature = "validation", validate(custom = "validate_timeout"))]
#[serde_as(as = "DurationSecondsWithFrac")]
#[cfg_attr(
feature = "dto",
into_proto(map = "std::time::Duration::as_secs_f64", map_by_ref),
from_proto(map = "Duration::from_secs_f64")
)]
pub timeout_s: std::time::Duration,
/// Optional retry policy; omitted from serialized output when `None`.
pub retry: Option<RetryConfig>,
}
#[cfg(feature = "server")]
impl Default for Webhook {
    /// Server-side defaults: no URL, `POST`, a five second timeout and no
    /// retry policy.
    ///
    /// Because `Webhook` carries `serde(default)` in `server` builds, these
    /// values also fill in any field missing from deserialized input.
    fn default() -> Self {
        let timeout_s = Duration::from_millis(5_000);
        Webhook {
            http_method: HttpMethod::Post,
            timeout_s,
            url: None,
            retry: None,
            _kind: Default::default(),
        }
    }
}
/// Retry policy attached to a `Webhook` (its `retry` field).
///
/// Serialized untagged: each variant is written as its plain inner object.
/// On deserialization serde tries `SimpleRetry` first, then
/// `ExponentialBackoffRetry`; the variants' own `"type"` fields (pinned via
/// `MustBe!` to `"simple"` / `"exponential_backoff"`) are what tell them
/// apart. In `client` builds the enum is `#[non_exhaustive]`.
// NOTE(review): `rename_all` only renames variant names, which an untagged
// enum never writes to the wire, so it looks like a no-op here — confirm
// before relying on or removing it.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "client", non_exhaustive)]
#[cfg_attr(
feature = "dto",
derive(IntoProto, FromProto),
proto(target = "proto::common::RetryConfig", oneof = "policy")
)]
#[cfg_attr(feature = "server", serde(deny_unknown_fields))]
#[serde(rename_all = "snake_case")]
#[serde(untagged)]
pub enum RetryConfig {
#[cfg_attr(feature = "dto", proto(name = "Simple"))]
SimpleRetry(SimpleRetry),
#[cfg_attr(feature = "dto", proto(name = "ExponentialBackoff"))]
ExponentialBackoffRetry(ExponentialBackoffRetry),
}
/// Simple retry policy described by an attempt limit and a delay.
///
/// Presumably retries use a fixed `delay_s` between attempts; the retry loop
/// itself is not in this file, so confirm against the executor. On the wire
/// the object carries `"type": "simple"`. In `server` builds unknown fields
/// are rejected and missing fields fall back to `SimpleRetry::default()`.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "validation", derive(Validate))]
#[cfg_attr(
feature = "dto",
derive(IntoProto, FromProto),
proto(target = "proto::common::SimpleRetry")
)]
#[cfg_attr(feature = "server", serde(default), serde(deny_unknown_fields))]
pub struct SimpleRetry {
/// Discriminator serialized as `"type"`; only the literal `"simple"` is
/// accepted when deserializing.
#[serde(rename = "type")]
_kind: MustBe!("simple"),
/// Attempt limit; validation requires 1 to 10 inclusive.
#[cfg_attr(feature = "validation", validate(range(min = 1, max = 10)))]
pub max_num_attempts: u32,
/// Delay, serialized as (fractional) seconds; validation requires it to be
/// between 5 and 300 seconds inclusive.
// NOTE(review): `Duration::from_secs_f64` panics on negative, non-finite or
// overflowing input, so a malformed proto value would panic inside
// `FromProto` — consider a checked/saturating conversion.
#[serde_as(as = "DurationSecondsWithFrac")]
#[cfg_attr(
feature = "dto",
into_proto(map = "std::time::Duration::as_secs_f64", map_by_ref),
from_proto(map = "Duration::from_secs_f64")
)]
#[cfg_attr(
feature = "validation",
validate(custom = "validate_retry_delay")
)]
pub delay_s: Duration,
}
#[cfg(feature = "server")]
impl Default for SimpleRetry {
    /// Server-side defaults: five attempts, sixty seconds apart.
    ///
    /// Because `SimpleRetry` carries `serde(default)` in `server` builds,
    /// these values also fill in any field missing from deserialized input.
    fn default() -> Self {
        let delay_s = Duration::from_secs(60);
        SimpleRetry {
            delay_s,
            max_num_attempts: 5,
            _kind: Default::default(),
        }
    }
}
/// Exponential-backoff retry policy.
///
/// Presumably the wait starts at `delay_s` and grows per attempt, capped at
/// `max_delay_s`. The growth logic is not in this file, so confirm it against
/// the executor. On the wire the object carries `"type": "exponential_backoff"`.
///
/// Unknown fields are rejected only in `server` builds, matching `Webhook`,
/// `SimpleRetry` and `RetryConfig`. Client builds stay lenient, so a newer
/// server adding a field does not make the whole untagged `RetryConfig`
/// fail to deserialize.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "validation", derive(Validate))]
#[cfg_attr(
    feature = "dto",
    derive(IntoProto, FromProto),
    proto(target = "proto::common::ExponentialBackoffRetry")
)]
#[cfg_attr(feature = "server", serde(deny_unknown_fields))]
pub struct ExponentialBackoffRetry {
    /// Discriminator serialized as `"type"`; only the literal
    /// `"exponential_backoff"` is accepted when deserializing.
    #[serde(rename = "type")]
    _kind: MustBe!("exponential_backoff"),
    /// Attempt limit; validation requires 1 to 10 inclusive.
    #[cfg_attr(feature = "validation", validate(range(min = 1, max = 10)))]
    pub max_num_attempts: u32,
    /// Initial delay, serialized as (fractional) seconds; validation requires
    /// it to be between 5 and 300 seconds inclusive.
    // NOTE(review): `Duration::from_secs_f64` panics on negative, non-finite
    // or overflowing input, so a malformed proto value would panic inside
    // `FromProto` — consider a checked/saturating conversion.
    #[serde_as(as = "DurationSecondsWithFrac")]
    #[cfg_attr(
        feature = "dto",
        into_proto(map = "std::time::Duration::as_secs_f64", map_by_ref),
        from_proto(map = "Duration::from_secs_f64")
    )]
    #[cfg_attr(
        feature = "validation",
        validate(custom = "validate_retry_delay")
    )]
    pub delay_s: Duration,
    /// Delay cap, serialized as (fractional) seconds; validation requires it
    /// to be between 5 and 300 seconds inclusive.
    // NOTE(review): nothing checks `max_delay_s >= delay_s`; add a
    // struct-level validation if that invariant is expected.
    #[serde_as(as = "DurationSecondsWithFrac")]
    #[cfg_attr(
        feature = "dto",
        into_proto(map = "std::time::Duration::as_secs_f64", map_by_ref),
        from_proto(map = "Duration::from_secs_f64")
    )]
    #[cfg_attr(
        feature = "validation",
        validate(custom = "validate_retry_delay")
    )]
    pub max_delay_s: Duration,
}
/// Validator hook for `Webhook::timeout_s`.
///
/// Accepts timeouts from 1.0 to 30.0 seconds, both ends inclusive.
///
/// # Errors
///
/// Returns an `invalid_timeout` validation error outside that range.
#[cfg(feature = "validation")]
fn validate_timeout(timeout: &Duration) -> Result<(), ValidationError> {
    let secs = timeout.as_secs_f64();
    if (1.0..=30.0).contains(&secs) {
        Ok(())
    } else {
        Err(validation_error(
            "invalid_timeout",
            "Timeout must be between 1.0 and 30.0 seconds".to_string(),
        ))
    }
}
/// Validator hook for retry delays (`SimpleRetry::delay_s`,
/// `ExponentialBackoffRetry::delay_s` / `max_delay_s`).
///
/// Accepts delays from 5.0 to 300.0 seconds, both ends inclusive.
///
/// # Errors
///
/// Returns an `invalid_delay` validation error outside that range.
#[cfg(feature = "validation")]
fn validate_retry_delay(delay: &Duration) -> Result<(), ValidationError> {
    const MIN_SECS: f64 = 5.0;
    const MAX_SECS: f64 = 300.0;

    let secs = delay.as_secs_f64();
    if (MIN_SECS..=MAX_SECS).contains(&secs) {
        return Ok(());
    }
    Err(validation_error(
        "invalid_delay",
        "Retry delay must be between 5.0 and 300.0 seconds".to_string(),
    ))
}
/// Reasons a webhook URL can fail `validate_webhook_url`.
///
/// Converted into a `ValidationError` with code `EMIT_VALIDATION_FAILED`
/// whose message is this type's `Display` output.
#[cfg(feature = "validation")]
#[derive(Error, Debug)]
enum WebhookUrlValidationError {
    /// `Url::parse` rejected the input; carries the parser's message.
    #[error("Failed to parse url: {0}")]
    InvalidUrl(String),
    /// The scheme is neither `http` nor `https`.
    #[error(
        "Unsupported url scheme: {0}. Only 'http' and 'https' are supported"
    )]
    UnsupportedScheme(String),
    /// Host resolution failed; carries the URL.
    #[error("Failed to resolve ip of url '{0}'")]
    Dns(String),
    /// The URL's host is, or resolves to, an IP that `IpExt::is_global`
    /// rejects (the tests cover `10.0.10.1`, `192.168.1.1` and `::1`).
    // The message used to read "non-routable public IP", which contradicts
    // itself and also said "Domain" for IP-literal URLs.
    #[error("Url resolves to a non-public IP: {0}")]
    NonRoutableIp(String),
}
/// Validates a webhook target URL.
///
/// The URL must parse and use the `http` or `https` scheme. Unless
/// `CRONBACK__SKIP_PUBLIC_IP_VALIDATION` is set, it must also resolve only
/// to globally routable IP addresses. Resolution goes through
/// `Url::socket_addrs`, which is a blocking lookup.
///
/// # Errors
///
/// Returns a `ValidationError` with code `EMIT_VALIDATION_FAILED` describing
/// the first check that failed. The checks run in this order: parse, then
/// scheme, then IP.
#[cfg(feature = "validation")]
pub fn validate_webhook_url(url_string: &str) -> Result<(), ValidationError> {
    let parsed = match Url::parse(url_string) {
        Ok(parsed) => parsed,
        Err(err) => {
            return Err(
                WebhookUrlValidationError::InvalidUrl(err.to_string()).into()
            );
        }
    };
    validate_endpoint_scheme(parsed.scheme())
        .and_then(|()| validate_endpoint_url_public_ip(&parsed))
        .map_err(ValidationError::from)
}
/// Ensures every address the URL's host resolves to is globally routable.
///
/// If the `CRONBACK__SKIP_PUBLIC_IP_VALIDATION` environment variable is set
/// to any valid-Unicode value (even an empty one), the check is skipped and
/// a notice is printed to stderr.
///
/// # Errors
///
/// - `Dns` if resolution fails or yields no addresses at all. Without the
///   empty-result check, an empty answer would pass vacuously.
/// - `NonRoutableIp` with the first address `IpExt::is_global` rejects.
///
/// NOTE(review): this only inspects the addresses seen at validation time.
/// If the sender resolves the host again later, a DNS-rebinding host could
/// pass here — confirm the sender re-checks or pins the IP. The `Dns`
/// message embeds the full URL, including any `user:pass@` credentials.
#[cfg(feature = "validation")]
fn validate_endpoint_url_public_ip(
    url: &Url,
) -> Result<(), WebhookUrlValidationError> {
    if let Ok(val) = std::env::var("CRONBACK__SKIP_PUBLIC_IP_VALIDATION") {
        eprintln!(
            "Skipping public ip validation because \
             'CRONBACK__SKIP_PUBLIC_IP_VALIDATION' env is set to {val}!"
        );
        return Ok(());
    }

    let dns_error = || WebhookUrlValidationError::Dns(url.to_string());
    let addrs = url.socket_addrs(|| None).map_err(|_| dns_error())?;
    if addrs.is_empty() {
        return Err(dns_error());
    }

    match addrs
        .iter()
        .map(|addr| addr.ip())
        .find(|ip| !IpExt::is_global(ip))
    {
        Some(ip) => {
            Err(WebhookUrlValidationError::NonRoutableIp(ip.to_string()))
        }
        None => Ok(()),
    }
}
/// Accepts only the `http` and `https` schemes.
///
/// `Url::parse` lowercases schemes, so an exact comparison is sufficient for
/// values coming from a parsed `Url`.
///
/// # Errors
///
/// Returns `UnsupportedScheme` carrying the rejected scheme.
#[cfg(feature = "validation")]
fn validate_endpoint_scheme(
    scheme: &str,
) -> Result<(), WebhookUrlValidationError> {
    match scheme {
        "http" | "https" => Ok(()),
        other => {
            Err(WebhookUrlValidationError::UnsupportedScheme(other.to_owned()))
        }
    }
}
#[cfg(feature = "validation")]
impl From<WebhookUrlValidationError> for ValidationError {
    /// Wraps the error's `Display` text in a `ValidationError` with code
    /// `EMIT_VALIDATION_FAILED`. This lets `?` lift URL errors inside
    /// validator functions.
    fn from(err: WebhookUrlValidationError) -> Self {
        let message = err.to_string();
        validation_error("EMIT_VALIDATION_FAILED", message)
    }
}
#[cfg(all(test, feature = "validation"))]
mod tests {
    use std::time::Duration;

    use super::{
        validate_endpoint_scheme, validate_retry_delay, validate_timeout,
        validate_webhook_url, HttpMethod,
    };

    #[test]
    fn http_method_to_string() {
        assert_eq!("GET", HttpMethod::Get.to_string());
        assert_eq!("POST", HttpMethod::Post.to_string());
        assert_eq!("PATCH", HttpMethod::Patch.to_string());
        assert_eq!("DELETE", HttpMethod::Delete.to_string());
        assert_eq!("PUT", HttpMethod::Put.to_string());
        assert_eq!("HEAD", HttpMethod::Head.to_string());
    }

    // Offline checks: both validators treat their bounds as inclusive.
    #[test]
    fn timeout_bounds() {
        assert!(validate_timeout(&Duration::from_secs(1)).is_ok());
        assert!(validate_timeout(&Duration::from_secs(30)).is_ok());
        assert!(validate_timeout(&Duration::from_millis(999)).is_err());
        assert!(validate_timeout(&Duration::from_millis(30_001)).is_err());
    }

    #[test]
    fn retry_delay_bounds() {
        assert!(validate_retry_delay(&Duration::from_secs(5)).is_ok());
        assert!(validate_retry_delay(&Duration::from_secs(300)).is_ok());
        assert!(validate_retry_delay(&Duration::from_millis(4_999)).is_err());
        assert!(validate_retry_delay(&Duration::from_millis(300_001)).is_err());
    }

    #[test]
    fn endpoint_scheme() {
        assert!(validate_endpoint_scheme("http").is_ok());
        assert!(validate_endpoint_scheme("https").is_ok());
        assert!(validate_endpoint_scheme("ftp").is_err());
        assert!(validate_endpoint_scheme("").is_err());
    }

    // The two tests below resolve real hostnames and therefore need working
    // DNS/network access. They clear the skip variable so the public-IP
    // check actually runs.
    #[test]
    fn valid_urls() {
        std::env::remove_var("CRONBACK__SKIP_PUBLIC_IP_VALIDATION");
        let urls = [
            "https://google.com/url",
            "https://example.com:3030/url",
            "https://1.1.1.1/url",
            "http://[2606:4700:4700::1111]/another_url/path",
            "http://[2606:4700:4700::1111]:5050/another_url/path",
            "http://user:pass@google.com/another_url/path",
        ];
        for url in urls {
            let result = validate_webhook_url(url);
            assert!(result.is_ok(), "URL: {}, result: {:?}", url, result);
        }
    }

    #[test]
    fn invalid_urls() {
        std::env::remove_var("CRONBACK__SKIP_PUBLIC_IP_VALIDATION");
        let urls = [
            "https://10.0.10.1",
            "https://192.168.1.1",
            "https://[::1]:80",
            "ftp://google.com",
            "https://localhost/url",
            "google.com/url",
            "http---@goog.com",
            // Relies on this (random-looking) domain not being registered.
            "https://ppqqzonlnp.io/url/url",
        ];
        for url in urls {
            let result = validate_webhook_url(url);
            assert!(result.is_err(), "URL: {}, result: {:?}", url, result);
        }
    }
}