rama_proxy/proxydb/str.rs

1use serde::{Deserialize, Serialize};
2use std::{convert::Infallible, str::FromStr};
3use unicode_normalization::UnicodeNormalization;
4
/// A string filter value that is normalized when it is constructed.
///
/// Normalization steps, applied in this order:
///
/// - leading and trailing whitespace is trimmed;
/// - the text is lowercased, which makes comparisons case-insensitive;
/// - the result is NFC-normalized.
///
/// The normalized value `"*"` is a wildcard: it compares equal to every other
/// filter (see the `PartialEq` impl) and is reported by `is_any`.
#[derive(Debug, Clone)]
pub struct StringFilter(String);
14
15impl StringFilter {
16    /// Create a string filter which will match anything
17    #[must_use]
18    pub fn any() -> Self {
19        "*".into()
20    }
21
22    /// Create a new string filter.
23    pub fn new(value: impl AsRef<str>) -> Self {
24        Self(value.as_ref().trim().to_lowercase().nfc().collect())
25    }
26
27    /// Get the inner string.
28    #[must_use]
29    pub fn inner(&self) -> &str {
30        &self.0
31    }
32
33    /// Convert the string filter into the inner string.
34    #[must_use]
35    pub fn into_inner(self) -> String {
36        self.0
37    }
38
39    /// Return `true` if this value is considered an "any" value
40    #[must_use]
41    pub fn is_any(&self) -> bool {
42        self.0 == "*"
43    }
44}
45
46impl PartialEq for StringFilter {
47    fn eq(&self, other: &Self) -> bool {
48        match (self.0.as_str(), other.0.as_str()) {
49            ("*", _) | (_, "*") => true,
50            _ => self.0 == other.0,
51        }
52    }
53}
54
55impl Eq for StringFilter {}
56
57impl std::hash::Hash for StringFilter {
58    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
59        self.0.hash(state);
60    }
61}
62
63impl AsRef<str> for StringFilter {
64    fn as_ref(&self) -> &str {
65        &self.0
66    }
67}
68
69impl FromStr for StringFilter {
70    type Err = Infallible;
71
72    fn from_str(s: &str) -> Result<Self, Self::Err> {
73        Ok(Self::new(s))
74    }
75}
76
77impl std::fmt::Display for StringFilter {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        write!(f, "{}", self.0)
80    }
81}
82
83impl From<StringFilter> for String {
84    fn from(filter: StringFilter) -> Self {
85        filter.0
86    }
87}
88
89impl From<&StringFilter> for String {
90    fn from(filter: &StringFilter) -> Self {
91        filter.0.clone()
92    }
93}
94
95impl From<&str> for StringFilter {
96    fn from(value: &str) -> Self {
97        Self::new(value)
98    }
99}
100
101impl From<String> for StringFilter {
102    fn from(value: String) -> Self {
103        Self::new(value)
104    }
105}
106
107impl From<&String> for StringFilter {
108    fn from(value: &String) -> Self {
109        Self::new(value)
110    }
111}
112
113impl Serialize for StringFilter {
114    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
115    where
116        S: serde::Serializer,
117    {
118        self.0.serialize(serializer)
119    }
120}
121
122impl<'de> Deserialize<'de> for StringFilter {
123    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
124    where
125        D: serde::Deserializer<'de>,
126    {
127        String::deserialize(deserializer).map(Self::new)
128    }
129}
130
#[cfg(feature = "memory-db")]
impl venndb::Any for StringFilter {
    /// Report whether this filter is the `"*"` wildcard.
    ///
    /// `StringFilter::is_any` resolves to the inherent method, because inherent
    /// associated functions take precedence over trait methods in path lookup.
    /// This call therefore does not recurse into itself.
    fn is_any(&self) -> bool {
        StringFilter::is_any(self)
    }
}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_string_filter_creation() {
        let filter = StringFilter::new("  Hello World  ");
        // `==` is wildcard-aware, so also pin the exact normalized value.
        assert_eq!(filter.inner(), "hello world");
        assert_eq!(filter, "hello world".into());
    }

    #[test]
    fn test_string_filter_nfc() {
        // `A` + U+030A COMBINING RING ABOVE is lowercased and then composed
        // into the single precomposed code point U+00E5 (`å`).
        let filter = StringFilter::new("A\u{30A}");
        assert_eq!(filter.inner(), "\u{E5}");

        // U+212B ANGSTROM SIGN and U+2126 OHM SIGN must end up as the ordinary
        // lowercase letters U+00E5 (`å`) and U+03C9 (`ω`).
        let filter = StringFilter::new("\u{212B}\u{2126}");
        assert_eq!(filter.inner(), "\u{E5}\u{3C9}");
        assert_eq!(filter, "\u{C5}\u{3A9}".into());
    }

    #[test]
    fn test_string_filter_case_insensitive() {
        let filter = StringFilter::new("Hello World");
        assert_eq!(filter, "hello world".into());
    }

    #[test]
    fn test_string_filter_deref() {
        let filter = StringFilter::new("Hello World");
        assert_eq!(filter.as_ref().to_ascii_uppercase(), "HELLO WORLD");
    }

    #[test]
    fn test_string_filter_as_str() {
        let filter = StringFilter::new("Hello World");
        assert_eq!(filter.as_ref(), "hello world");
    }

    #[test]
    fn test_string_filter_serialization() {
        let filter = StringFilter::new("Hello World");
        let json = serde_json::to_string(&filter).unwrap();
        assert_eq!(json, "\"hello world\"");
        let filter2: StringFilter = serde_json::from_str(&json).unwrap();
        assert_eq!(filter, filter2);
    }

    #[test]
    fn test_string_filter_deserialization() {
        let json = "\"  Hello World\"";
        let filter: StringFilter = serde_json::from_str(json).unwrap();
        assert_eq!(filter, "hello world".into());
    }

    #[test]
    fn test_string_filter_any() {
        let filter = StringFilter::any();
        assert!(filter.is_any());

        let filter: StringFilter = "hello".into();
        assert!(!filter.is_any());
    }

    #[test]
    fn test_string_filter_any_after_normalization() {
        assert!(StringFilter::new("  *  ").is_any());
        assert!(!StringFilter::new("**").is_any());
        assert!(!StringFilter::new("").is_any());
    }

    #[test]
    fn test_string_filter_string_round_trip() {
        let filter: StringFilter = "  Foo Bar ".parse().unwrap();
        assert_eq!(filter.inner(), "foo bar");
        assert_eq!(filter.to_string(), "foo bar");
        assert_eq!(String::from(&filter), "foo bar");
        assert_eq!(filter.into_inner(), "foo bar");
    }

    #[test]
    fn test_string_filter_eq_cases() {
        for (a, b) in [
            ("hello", "hello"),
            ("hello", "HELLO"),
            ("HELLO", "hello"),
            ("HELLO", "HELLO"),
            (" foo", "foo "),
            ("foo ", " foo"),
            (" FOO ", " foo"),
            ("*", "*"),
            ("*", "foo"),
            ("foo", "*"),
            ("  * ", "foo"),
            ("foo", "  * "),
        ] {
            let a: StringFilter = a.into();
            let b: StringFilter = b.into();
            assert_eq!(a, b);
        }
    }

    #[test]
    fn test_string_filter_neq() {
        for (a, b) in [
            ("hello", "world"),
            ("world", "hello"),
            ("hello", ""),
            ("", "hello"),
        ] {
            let a: StringFilter = a.into();
            let b: StringFilter = b.into();
            assert_ne!(a, b);
        }
    }
}