1use serde::{Deserialize, Serialize};
2
3use crate::GateValidationError;
4
/// An obligation in one of three shapes: an audit tag, a rate limit, or a
/// free-form JSON payload.
///
/// With serde it is encoded as an internally tagged object: a `"kind"` field
/// holding the snake_case variant name (`"audit"`, `"rate_limit"` or
/// `"custom"`) next to the variant's own fields. Only `Serialize` is derived;
/// `Deserialize` is implemented by hand through the private `RawObligation`
/// mirror so that decoded rate limits are checked by
/// `Obligation::try_rate_limit` instead of being accepted as-is.
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum Obligation {
    /// Audit obligation carrying a caller-supplied tag.
    Audit {
        /// Free-form tag string; passed through without validation.
        tag: String,
    },
    /// Rate-limit obligation.
    ///
    /// Both fields are non-zero when the value is built through
    /// `Obligation::try_rate_limit`, `Obligation::rate_limit` or
    /// deserialization. Constructing this variant directly bypasses that check.
    RateLimit {
        /// Window length. The `_secs` suffix suggests seconds — TODO confirm
        /// with whatever enforces this obligation.
        window_secs: u64,
        /// Limit applied per window — presumably a maximum count of events;
        /// TODO confirm.
        max: u32,
    },
    /// Opaque JSON payload, passed through without validation.
    Custom {
        /// Arbitrary JSON value, encoded under the `"value"` key.
        value: serde_json::Value,
    },
}
34
/// Unvalidated wire-format mirror of `Obligation`, used as the
/// deserialization target for `Obligation`.
///
/// Its serde attributes, variant names and field names match `Obligation`'s,
/// so it accepts the JSON shape `Obligation` serializes to. Field values are
/// checked afterwards by the `TryFrom<RawObligation> for Obligation`
/// conversion.
#[derive(Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
enum RawObligation {
    // Keep variants and fields in sync with `Obligation`; a mismatch would
    // make serialized obligations fail to round-trip.
    Audit { tag: String },
    RateLimit { window_secs: u64, max: u32 },
    Custom { value: serde_json::Value },
}
43
44impl TryFrom<RawObligation> for Obligation {
45 type Error = GateValidationError;
46
47 fn try_from(raw: RawObligation) -> Result<Self, Self::Error> {
48 match raw {
49 RawObligation::Audit { tag } => Ok(Obligation::Audit { tag }),
50 RawObligation::RateLimit { window_secs, max } => {
51 Obligation::try_rate_limit(window_secs, max)
52 }
53 RawObligation::Custom { value } => Ok(Obligation::Custom { value }),
54 }
55 }
56}
57
58impl<'de> Deserialize<'de> for Obligation {
59 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
60 where
61 D: serde::Deserializer<'de>,
62 {
63 let raw = RawObligation::deserialize(deserializer)?;
64 Obligation::try_from(raw).map_err(serde::de::Error::custom)
65 }
66}
67
68impl Obligation {
69 pub fn try_rate_limit(window_secs: u64, max: u32) -> Result<Self, GateValidationError> {
72 if window_secs == 0 {
73 return Err(GateValidationError::ZeroRateLimitWindow);
74 }
75 if max == 0 {
76 return Err(GateValidationError::ZeroRateLimitMax);
77 }
78 Ok(Self::RateLimit { window_secs, max })
79 }
80
81 pub fn rate_limit(window_secs: u64, max: u32) -> Self {
83 Self::try_rate_limit(window_secs, max)
84 .expect("Obligation::rate_limit: window_secs and max must be > 0")
85 }
86}