1use serde::{Deserialize, Serialize};
2
3use crate::GateValidationError;
4
5#[derive(Clone, Debug, Serialize)]
17#[serde(tag = "kind", rename_all = "snake_case")]
18pub enum Obligation {
19 Audit {
20 tag: String,
21 },
22 RateLimit {
23 window_secs: u64,
24 max: u32,
25 },
26 Custom {
31 value: serde_json::Value,
32 },
33}
34
35#[derive(Deserialize)]
37#[serde(tag = "kind", rename_all = "snake_case")]
38enum RawObligation {
39 Audit { tag: String },
40 RateLimit { window_secs: u64, max: u32 },
41 Custom { value: serde_json::Value },
42}
43
44impl TryFrom<RawObligation> for Obligation {
45 type Error = GateValidationError;
46
47 fn try_from(raw: RawObligation) -> Result<Self, Self::Error> {
48 match raw {
49 RawObligation::Audit { tag } => {
50 if tag.is_empty() {
51 return Err(GateValidationError::EmptyAuditTag);
52 }
53 Ok(Obligation::Audit { tag })
54 }
55 RawObligation::RateLimit { window_secs, max } => {
56 Obligation::try_rate_limit(window_secs, max)
57 }
58 RawObligation::Custom { value } => Ok(Obligation::Custom { value }),
59 }
60 }
61}
62
63impl<'de> Deserialize<'de> for Obligation {
64 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
65 where
66 D: serde::Deserializer<'de>,
67 {
68 let raw = RawObligation::deserialize(deserializer)?;
69 Obligation::try_from(raw).map_err(serde::de::Error::custom)
70 }
71}
72
73impl Obligation {
74 pub fn try_rate_limit(window_secs: u64, max: u32) -> Result<Self, GateValidationError> {
77 if window_secs == 0 {
78 return Err(GateValidationError::ZeroRateLimitWindow);
79 }
80 if max == 0 {
81 return Err(GateValidationError::ZeroRateLimitMax);
82 }
83 Ok(Self::RateLimit { window_secs, max })
84 }
85
86 pub fn rate_limit(window_secs: u64, max: u32) -> Self {
88 Self::try_rate_limit(window_secs, max)
89 .expect("Obligation::rate_limit: window_secs and max must be > 0")
90 }
91}