Skip to main content

use_guardrail/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8    pub use crate::{
9        GuardrailAction, GuardrailCheckKind, GuardrailError, GuardrailId, GuardrailKind,
10        GuardrailName, GuardrailPolicyArea, GuardrailResultKind, GuardrailSeverity,
11        GuardrailStatus, GuardrailViolationKind,
12    };
13}
14
/// Defines a public newtype around a trimmed, non-empty `String`.
///
/// Values are built through `non_empty_text`, so the stored text never carries
/// surrounding whitespace and is never empty. The generated type implements
/// `AsRef<str>`, `Display` (honouring width/fill/alignment/precision flags),
/// `FromStr`, `TryFrom<&str>`, `TryFrom<String>`, and `From<Type> for String`.
///
/// Outer attributes (such as `///` doc comments) written before the type name
/// are forwarded to the generated struct; plain `guardrail_text_newtype!(Name)`
/// invocations keep working unchanged.
macro_rules! guardrail_text_newtype {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            #[doc = concat!(
                "Creates a [`", stringify!($name),
                "`] from `value`, trimming surrounding whitespace."
            )]
            ///
            /// # Errors
            ///
            /// Returns `GuardrailError::Empty` if `value` is empty or whitespace-only.
            pub fn new(value: impl AsRef<str>) -> Result<Self, GuardrailError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the validated text.
            #[must_use]
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Borrows the validated text; an alias of `as_str`.
            #[must_use]
            pub fn value(&self) -> &str {
                self.as_str()
            }

            /// Consumes the wrapper and returns the owned text.
            #[must_use]
            pub fn into_string(self) -> String {
                self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` (unlike `write_str`) applies width, fill, alignment and
                // precision flags, so `{:>20}` behaves exactly as it does for `str`.
                formatter.pad(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = GuardrailError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = GuardrailError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl TryFrom<String> for $name {
            type Error = GuardrailError;

            /// Validates an owned string with the same rules as `new`.
            fn try_from(value: String) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl From<$name> for String {
            fn from(value: $name) -> Self {
                value.0
            }
        }
    };
}
67
/// Defines a public, `Copy` label enum with one canonical string per variant.
///
/// Syntax: `guardrail_enum!(Name { Variant => "label", ... })`. Outer
/// attributes (e.g. `///` docs) may precede the enum name and each variant and
/// are forwarded to the generated items; attribute-free invocations are
/// unchanged.
///
/// The generated enum provides:
/// - `ALL`: every variant in declaration order (also the derived `Ord` order);
/// - `as_str`: the canonical label;
/// - `Display` (honouring width/fill/alignment flags), `FromStr`, and `TryFrom<&str>`.
///
/// Parsing first passes the input through `normalized_label`, which trims it,
/// lowercases ASCII letters, and rewrites `_` and space separators to `-`.
/// Every `$label` must therefore already be in that normalized form (lowercase,
/// `-`-separated), otherwise it can never be matched.
macro_rules! guardrail_enum {
    (
        $(#[$meta:meta])*
        $name:ident {
            $(
                $(#[$variant_meta:meta])*
                $variant:ident => $label:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $(
                $(#[$variant_meta])*
                #[doc = concat!("Canonical label: `", $label, "`.")]
                $variant
            ),+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical label of this variant.
            #[must_use]
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` honours width/fill/alignment, so labels line up in tables.
                formatter.pad(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = GuardrailError;

            /// Parses a label case-insensitively, accepting `_` or space in
            /// place of `-`.
            ///
            /// Fails with `GuardrailError::Empty` for blank input and
            /// `GuardrailError::UnknownLabel` when no variant matches.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(GuardrailError::UnknownLabel),
                }
            }
        }

        impl TryFrom<&str> for $name {
            type Error = GuardrailError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}
103
104guardrail_text_newtype!(GuardrailName);
105guardrail_text_newtype!(GuardrailId);
106
107guardrail_enum!(GuardrailKind {
108    Input => "input",
109    Output => "output",
110    ToolUse => "tool-use",
111    Retrieval => "retrieval",
112    Memory => "memory",
113    Policy => "policy",
114    Format => "format",
115    RateLimit => "rate-limit",
116    CostLimit => "cost-limit",
117    HumanReview => "human-review",
118    Custom => "custom",
119});
120
121guardrail_enum!(GuardrailAction {
122    Allow => "allow",
123    Block => "block",
124    Redact => "redact",
125    Transform => "transform",
126    Warn => "warn",
127    Escalate => "escalate",
128    RequireReview => "require-review",
129    Refuse => "refuse",
130    Unknown => "unknown",
131});
132
133guardrail_enum!(GuardrailSeverity {
134    Informational => "informational",
135    Low => "low",
136    Medium => "medium",
137    High => "high",
138    Critical => "critical",
139});
140
141guardrail_enum!(GuardrailStatus {
142    Enabled => "enabled",
143    Disabled => "disabled",
144    Shadow => "shadow",
145    Testing => "testing",
146    Deprecated => "deprecated",
147});
148
149guardrail_enum!(GuardrailPolicyArea {
150    Safety => "safety",
151    Security => "security",
152    Privacy => "privacy",
153    Compliance => "compliance",
154    Copyright => "copyright",
155    Pii => "pii",
156    Secrets => "secrets",
157    Abuse => "abuse",
158    Quality => "quality",
159    Custom => "custom",
160});
161
162guardrail_enum!(GuardrailCheckKind {
163    Moderation => "moderation",
164    PiiDetection => "pii-detection",
165    SecretDetection => "secret-detection",
166    JailbreakDetection => "jailbreak-detection",
167    PromptInjectionDetection => "prompt-injection-detection",
168    CitationCheck => "citation-check",
169    SchemaCheck => "schema-check",
170    ToolPermissionCheck => "tool-permission-check",
171    Custom => "custom",
172});
173
174guardrail_enum!(GuardrailResultKind {
175    Passed => "passed",
176    Failed => "failed",
177    Warning => "warning",
178    Skipped => "skipped",
179    Error => "error",
180    Unknown => "unknown",
181});
182
183guardrail_enum!(GuardrailViolationKind {
184    UnsafeContent => "unsafe-content",
185    Pii => "pii",
186    Secret => "secret",
187    PromptInjection => "prompt-injection",
188    Jailbreak => "jailbreak",
189    PolicyViolation => "policy-violation",
190    SchemaViolation => "schema-violation",
191    ToolMisuse => "tool-misuse",
192    Unknown => "unknown",
193});
194
/// Error produced when constructing or parsing guardrail metadata.
///
/// Derives `Hash` like every other public type in this crate, so errors can be
/// collected in sets or used as map keys (e.g. when tallying failures).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GuardrailError {
    /// The input was empty, or contained only whitespace.
    Empty,
    /// The input did not match any label of the target enum.
    UnknownLabel,
}

impl fmt::Display for GuardrailError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str(match self {
            Self::Empty => "guardrail metadata text cannot be empty",
            Self::UnknownLabel => "unknown guardrail metadata label",
        })
    }
}

/// `GuardrailError` has no underlying cause, so `source()` is always `None`.
impl Error for GuardrailError {}
211
212fn non_empty_text(value: impl AsRef<str>) -> Result<String, GuardrailError> {
213    let trimmed = value.as_ref().trim();
214    if trimmed.is_empty() {
215        Err(GuardrailError::Empty)
216    } else {
217        Ok(trimmed.to_string())
218    }
219}
220
221fn normalized_label(value: &str) -> Result<String, GuardrailError> {
222    let trimmed = value.trim();
223    if trimmed.is_empty() {
224        Err(GuardrailError::Empty)
225    } else {
226        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::{
        GuardrailAction, GuardrailCheckKind, GuardrailError, GuardrailId, GuardrailKind,
        GuardrailName, GuardrailPolicyArea, GuardrailResultKind, GuardrailSeverity,
        GuardrailStatus, GuardrailViolationKind,
    };
    use core::{fmt, str::FromStr};

    /// Exercises every constructor and accessor of a text newtype.
    ///
    /// The value is padded with spaces before construction to confirm the
    /// stored text is trimmed. Must be expanded inside a function returning
    /// `Result<_, GuardrailError>`, because failures propagate with `?`.
    macro_rules! assert_text_newtype {
        ($type:ty, $value:literal) => {{
            let expected: &str = $value;
            let built = <$type>::new(concat!(" ", $value, " "))?;
            assert_eq!(built.as_str(), expected);
            assert_eq!(built.value(), expected);
            assert_eq!(built.as_ref(), expected);
            assert_eq!(built.to_string(), expected);
            let converted = <$type as TryFrom<&str>>::try_from(expected)?;
            assert_eq!(converted, built);
            assert_eq!(built.into_string(), String::from(expected));
        }};
    }

    /// Checks that every variant round-trips through `Display`/`FromStr`,
    /// including the `_`- and space-separated spellings of its label.
    fn assert_enum_family<T>(variants: &[T]) -> Result<(), GuardrailError>
    where
        T: Copy + Eq + fmt::Debug + fmt::Display + FromStr<Err = GuardrailError>,
    {
        variants
            .iter()
            .copied()
            .try_for_each(|variant| -> Result<(), GuardrailError> {
                let canonical = variant.to_string();
                let underscored = canonical.replace('-', "_");
                let spaced = canonical.replace('-', " ");
                for spelling in [canonical.as_str(), underscored.as_str(), spaced.as_str()] {
                    assert_eq!(spelling.parse::<T>()?, variant);
                }
                Ok(())
            })
    }

    #[test]
    fn validates_guardrail_text_newtypes() -> Result<(), GuardrailError> {
        assert_text_newtype!(GuardrailName, "pii-redaction");
        assert_text_newtype!(GuardrailId, "guardrail-001");
        assert_eq!(GuardrailName::new("  "), Err(GuardrailError::Empty));
        Ok(())
    }

    #[test]
    fn displays_and_parses_guardrail_enums() -> Result<(), GuardrailError> {
        assert_enum_family(GuardrailKind::ALL)?;
        assert_enum_family(GuardrailAction::ALL)?;
        assert_enum_family(GuardrailSeverity::ALL)?;
        assert_enum_family(GuardrailStatus::ALL)?;
        assert_enum_family(GuardrailPolicyArea::ALL)?;
        assert_enum_family(GuardrailCheckKind::ALL)?;
        assert_enum_family(GuardrailResultKind::ALL)?;
        assert_enum_family(GuardrailViolationKind::ALL)?;

        let parsed: GuardrailAction = "require review".parse()?;
        assert_eq!(parsed, GuardrailAction::RequireReview);
        Ok(())
    }
}