cronback_api_model/
webhook.rs

1use std::time::Duration;
2
3#[cfg(feature = "dto")]
4use dto::{FromProto, IntoProto};
5#[cfg(feature = "validation")]
6use ipext::IpExt;
7use monostate::MustBe;
8use serde::{Deserialize, Serialize};
9use serde_with::{serde_as, skip_serializing_none, DurationSecondsWithFrac};
10use strum::Display;
11#[cfg(feature = "validation")]
12use thiserror::Error;
13#[cfg(feature = "validation")]
14use url::Url;
15#[cfg(feature = "validation")]
16use validator::{Validate, ValidationError};
17
18#[cfg(feature = "validation")]
19use crate::validation_util::validation_error;
20
21#[derive(Debug, Display, Clone, Copy, Serialize, Deserialize, PartialEq)]
22#[cfg_attr(feature = "client", non_exhaustive)]
23#[cfg_attr(
24    feature = "dto",
25    derive(IntoProto, FromProto),
26    proto(target = "proto::common::HttpMethod")
27)]
28#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
29#[cfg_attr(feature = "clap", clap(rename_all = "UPPER"))]
30#[serde(rename_all = "UPPERCASE")]
31#[strum(serialize_all = "UPPERCASE")]
32pub enum HttpMethod {
33    Delete,
34    Get,
35    Head,
36    Patch,
37    Post,
38    Put,
39}
40
41#[serde_as]
42#[skip_serializing_none]
43#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
44#[cfg_attr(feature = "validation", derive(Validate))]
45#[cfg_attr(
46    feature = "dto",
47    derive(IntoProto, FromProto),
48    proto(target = "proto::common::Webhook")
49)]
50#[cfg_attr(feature = "server", serde(deny_unknown_fields), serde(default))]
51pub struct Webhook {
52    // allows an optional "type" field to be passed in. This enables other
53    // variants of action to be differentiated.
54    #[serde(rename = "type")]
55    _kind: MustBe!("webhook"),
56    #[cfg_attr(
57        feature = "validation",
58        validate(required, custom = "validate_webhook_url")
59    )]
60    #[cfg_attr(feature = "dto", proto(required))]
61    pub url: Option<String>,
62    pub http_method: HttpMethod,
63    #[cfg_attr(feature = "validation", validate(custom = "validate_timeout"))]
64    #[serde_as(as = "DurationSecondsWithFrac")]
65    #[cfg_attr(
66        feature = "dto",
67        into_proto(map = "std::time::Duration::as_secs_f64", map_by_ref),
68        from_proto(map = "Duration::from_secs_f64")
69    )]
70    pub timeout_s: std::time::Duration,
71    // None means no retry
72    pub retry: Option<RetryConfig>,
73}
74
/// Server-side defaults: no URL, `POST`, a 5 second timeout and no retries.
#[cfg(feature = "server")]
impl Default for Webhook {
    fn default() -> Self {
        let timeout_s = Duration::from_secs(5);
        Webhook {
            url: None,
            http_method: HttpMethod::Post,
            timeout_s,
            retry: None,
            _kind: Default::default(),
        }
    }
}
87
88#[skip_serializing_none]
89#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
90#[cfg_attr(feature = "client", non_exhaustive)]
91#[cfg_attr(
92    feature = "dto",
93    derive(IntoProto, FromProto),
94    proto(target = "proto::common::RetryConfig", oneof = "policy")
95)]
96#[cfg_attr(feature = "server", serde(deny_unknown_fields))]
97#[serde(rename_all = "snake_case")]
98#[serde(untagged)]
99pub enum RetryConfig {
100    #[cfg_attr(feature = "dto", proto(name = "Simple"))]
101    SimpleRetry(SimpleRetry),
102    #[cfg_attr(feature = "dto", proto(name = "ExponentialBackoff"))]
103    ExponentialBackoffRetry(ExponentialBackoffRetry),
104}
105
106#[serde_as]
107#[skip_serializing_none]
108#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
109#[cfg_attr(feature = "validation", derive(Validate))]
110#[cfg_attr(
111    feature = "dto",
112    derive(IntoProto, FromProto),
113    proto(target = "proto::common::SimpleRetry")
114)]
115#[cfg_attr(feature = "server", serde(default), serde(deny_unknown_fields))]
116pub struct SimpleRetry {
117    #[serde(rename = "type")]
118    _kind: MustBe!("simple"),
119    #[cfg_attr(feature = "validation", validate(range(min = 1, max = 10)))]
120    pub max_num_attempts: u32,
121    #[serde_as(as = "DurationSecondsWithFrac")]
122    #[cfg_attr(
123        feature = "dto",
124        into_proto(map = "std::time::Duration::as_secs_f64", map_by_ref),
125        from_proto(map = "Duration::from_secs_f64")
126    )]
127    #[cfg_attr(
128        feature = "validation",
129        validate(custom = "validate_retry_delay")
130    )]
131    pub delay_s: Duration,
132}
133
/// Server-side defaults: 5 attempts with a 60 second `delay_s`.
#[cfg(feature = "server")]
impl Default for SimpleRetry {
    fn default() -> Self {
        SimpleRetry {
            delay_s: Duration::from_secs(60),
            max_num_attempts: 5,
            _kind: Default::default(),
        }
    }
}
144
145#[serde_as]
146#[skip_serializing_none]
147#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
148#[cfg_attr(feature = "validation", derive(Validate))]
149#[cfg_attr(
150    feature = "dto",
151    derive(IntoProto, FromProto),
152    proto(target = "proto::common::ExponentialBackoffRetry")
153)]
154#[serde(deny_unknown_fields)]
155pub struct ExponentialBackoffRetry {
156    #[serde(rename = "type")]
157    _kind: MustBe!("exponential_backoff"),
158    #[cfg_attr(feature = "validation", validate(range(min = 1, max = 10)))]
159    pub max_num_attempts: u32,
160    #[serde_as(as = "DurationSecondsWithFrac")]
161    #[cfg_attr(
162        feature = "dto",
163        into_proto(map = "std::time::Duration::as_secs_f64", map_by_ref),
164        from_proto(map = "Duration::from_secs_f64")
165    )]
166    #[cfg_attr(
167        feature = "validation",
168        validate(custom = "validate_retry_delay")
169    )]
170    pub delay_s: Duration,
171    #[serde_as(as = "DurationSecondsWithFrac")]
172    #[cfg_attr(
173        feature = "dto",
174        into_proto(map = "std::time::Duration::as_secs_f64", map_by_ref),
175        from_proto(map = "Duration::from_secs_f64")
176    )]
177    #[cfg_attr(
178        feature = "validation",
179        validate(custom = "validate_retry_delay")
180    )]
181    pub max_delay_s: Duration,
182}
183
/// Accepts a webhook timeout only if it lies within `[1.0, 30.0]` seconds
/// (both ends inclusive).
#[cfg(feature = "validation")]
fn validate_timeout(timeout: &Duration) -> Result<(), ValidationError> {
    let secs = timeout.as_secs_f64();
    if (1.0..=30.0).contains(&secs) {
        return Ok(());
    }
    Err(validation_error(
        "invalid_timeout",
        "Timeout must be between 1.0 and 30.0 seconds".to_string(),
    ))
}
194
/// Accepts a retry delay only if it lies within `[5.0, 300.0]` seconds
/// (both ends inclusive).
#[cfg(feature = "validation")]
fn validate_retry_delay(delay: &Duration) -> Result<(), ValidationError> {
    let within_bounds = (5.0..=300.0).contains(&delay.as_secs_f64());
    if within_bounds {
        Ok(())
    } else {
        Err(validation_error(
            "invalid_delay",
            "Retry delay must be between 5.0 and 300.0 seconds".to_string(),
        ))
    }
}
205
/// Reasons a webhook URL can be rejected by [`validate_webhook_url`].
#[cfg(feature = "validation")]
#[derive(Error, Debug)]
enum WebhookUrlValidationError {
    /// The string could not be parsed as a URL; carries the parser's message.
    #[error("Failed to parse url: {0}")]
    InvalidUrl(String),

    /// The scheme is neither `http` nor `https`.
    #[error(
        "Unsupported url scheme: {0}. Only 'http' and 'https' are supported"
    )]
    UnsupportedScheme(String),

    /// Resolving the URL's host to socket addresses failed; carries the URL.
    #[error("Failed to resolve ip of url '{0}'")]
    Dns(String),

    /// The URL resolves to (or literally is) an IP address that is not
    /// globally routable; carries that address.
    #[error("Domain resolves to non-public IP: {0}")]
    NonRoutableIp(String),
}
223
/// Validates a webhook target URL.
///
/// The URL must parse, use the `http` or `https` scheme, and resolve only to
/// globally routable IP addresses. The last check is skipped when
/// `CRONBACK__SKIP_PUBLIC_IP_VALIDATION` is set. Resolution performs a
/// blocking DNS lookup.
///
/// # Errors
///
/// Returns a [`ValidationError`] converted from the first failing check's
/// `WebhookUrlValidationError`.
#[cfg(feature = "validation")]
pub fn validate_webhook_url(url_string: &str) -> Result<(), ValidationError> {
    let url = match Url::parse(url_string) {
        Ok(parsed) => parsed,
        Err(parse_err) => {
            return Err(
                WebhookUrlValidationError::InvalidUrl(parse_err.to_string())
                    .into(),
            );
        }
    };
    validate_endpoint_scheme(url.scheme())?;
    validate_endpoint_url_public_ip(&url).map_err(ValidationError::from)
}
233
/// Ensures every IP address the URL resolves to is globally routable, so that
/// webhooks cannot target private, loopback or otherwise non-public hosts.
///
/// The whole check is skipped when `CRONBACK__SKIP_PUBLIC_IP_VALIDATION` is
/// set to any (valid Unicode) value, including an empty string or "false".
///
/// # Errors
///
/// * [`WebhookUrlValidationError::Dns`] if the host cannot be resolved, or
///   resolves to no addresses at all.
/// * [`WebhookUrlValidationError::NonRoutableIp`] for the first non-global
///   address encountered.
#[cfg(feature = "validation")]
fn validate_endpoint_url_public_ip(
    url: &Url,
) -> Result<(), WebhookUrlValidationError> {
    // TODO: Move to a non-global setting.
    if let Ok(val) = std::env::var("CRONBACK__SKIP_PUBLIC_IP_VALIDATION") {
        eprintln!(
            "Skipping public ip validation because  \
             'CRONBACK__SKIP_PUBLIC_IP_VALIDATION' env is set to {val}!"
        );
        return Ok(());
    }
    // This function does the DNS resolution. Unfortunately, it's synchronous.
    let addrs = url
        // TODO: Replace with non-blocking nameservice lookup
        .socket_addrs(|| None)
        .map_err(|_| WebhookUrlValidationError::Dns(url.to_string()))?;

    // An empty resolution would otherwise pass the "all IPs are public" check
    // vacuously; fail closed instead.
    if addrs.is_empty() {
        return Err(WebhookUrlValidationError::Dns(url.to_string()));
    }

    // To err on the safe side, a hostname is valid only if ALL its IPs are
    // publicly addressable.
    match addrs
        .iter()
        .map(|addr| addr.ip())
        .find(|ip| !IpExt::is_global(ip))
    {
        Some(ip) => {
            Err(WebhookUrlValidationError::NonRoutableIp(ip.to_string()))
        }
        None => Ok(()),
    }
}
263
/// Accepts only the `http` and `https` URL schemes.
#[cfg(feature = "validation")]
fn validate_endpoint_scheme(
    scheme: &str,
) -> Result<(), WebhookUrlValidationError> {
    match scheme {
        "http" | "https" => Ok(()),
        other => Err(WebhookUrlValidationError::UnsupportedScheme(
            other.to_string(),
        )),
    }
}
276
/// Wraps a URL validation failure as a `validator` error whose message is the
/// failure's `Display` text, under the `"EMIT_VALIDATION_FAILED"` code.
#[cfg(feature = "validation")]
impl From<WebhookUrlValidationError> for ValidationError {
    fn from(err: WebhookUrlValidationError) -> Self {
        let message = err.to_string();
        validation_error("EMIT_VALIDATION_FAILED", message)
    }
}
283
#[cfg(all(test, feature = "validation"))]
mod tests {
    use super::{validate_webhook_url, HttpMethod};

    #[test]
    fn http_method_to_string() {
        let cases = [
            (HttpMethod::Get, "GET"),
            (HttpMethod::Post, "POST"),
            (HttpMethod::Patch, "PATCH"),
            (HttpMethod::Delete, "DELETE"),
            (HttpMethod::Put, "PUT"),
            (HttpMethod::Head, "HEAD"),
        ];
        for (method, expected) in cases {
            assert_eq!(expected, method.to_string());
        }
    }

    #[test]
    fn valid_urls() {
        // Best-effort attempt to make sure the public-IP validation runs.
        // Environment variables are process-wide, so this can still fail
        // sporadically if another test sets the variable concurrently.
        // TODO: Replace with a more robust approach
        std::env::remove_var("CRONBACK__SKIP_PUBLIC_IP_VALIDATION");
        let urls = [
            "https://google.com/url",
            "https://example.com:3030/url",
            "https://1.1.1.1/url",
            "http://[2606:4700:4700::1111]/another_url/path",
            "http://[2606:4700:4700::1111]:5050/another_url/path",
            "http://user:pass@google.com/another_url/path",
        ];

        for url in urls {
            let result = validate_webhook_url(url);
            assert!(result.is_ok(), "URL: {}, result: {:?}", url, result);
        }
    }

    #[test]
    fn invalid_urls() {
        std::env::remove_var("CRONBACK__SKIP_PUBLIC_IP_VALIDATION");
        let urls = [
            // Private IPs
            "https://10.0.10.1",
            "https://192.168.1.1",
            "https://[::1]:80",
            // Non-http url
            "ftp://google.com",
            // Loopback address
            "https://localhost/url",
            // Scheme-less
            "google.com/url",
            // Unparsable URL
            "http---@goog.com",
            // Non-existent domains
            "https://ppqqzonlnp.io/url/url",
        ];

        for url in urls {
            let result = validate_webhook_url(url);
            assert!(result.is_err(), "URL: {}, result: {:?}", url, result);
        }
    }
}