1use std::collections::HashMap;
2use std::fs;
3use std::path::Path;
4
5use serde::Deserialize;
6use serde_json::{Map, Value};
7use thiserror::Error;
8
9use crate::PiiClass;
10
11#[derive(Debug, Deserialize)]
12#[serde(deny_unknown_fields)]
13pub(crate) struct RawContext {
14 pub(crate) dictionaries: HashMap<String, RawContextDictionary>,
15 pub(crate) class_map: HashMap<String, String>,
16 pub(crate) fields: Map<String, Value>,
17}
18
19#[derive(Debug, Deserialize)]
20#[serde(deny_unknown_fields)]
21pub(crate) struct RawContextDictionary {
22 pub(crate) terms: Vec<String>,
23 #[serde(default)]
24 pub(crate) case_sensitive: bool,
25}
26
27#[derive(Debug, Clone, PartialEq)]
28pub struct Context {
29 pub dictionaries: HashMap<String, ContextDictionary>,
30 pub class_map: HashMap<String, PiiClass>,
31 pub fields: Map<String, Value>,
32}
33
34#[derive(Debug, Clone, Copy)]
35pub struct ContextFieldsRef<'a>(&'a Map<String, Value>);
36
37impl<'a> ContextFieldsRef<'a> {
38 pub fn as_map(&self) -> &'a Map<String, Value> {
39 self.0
40 }
41
42 pub fn get(&self, key: &str) -> Option<&'a Value> {
43 self.0.get(key)
44 }
45
46 pub fn iter(&self) -> impl Iterator<Item = (&'a String, &'a Value)> + 'a {
47 self.0.iter()
48 }
49
50 pub fn len(&self) -> usize {
51 self.0.len()
52 }
53
54 pub fn is_empty(&self) -> bool {
55 self.0.is_empty()
56 }
57}
58
/// A dictionary that passed the checks performed while building a [`Context`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextDictionary {
    /// Dictionary terms, exactly as supplied in the input.
    pub terms: Vec<String>,
    /// Case-sensitivity flag copied from the input (`false` when it was omitted).
    pub case_sensitive: bool,
}
64
65#[derive(Debug, Error)]
66#[non_exhaustive]
67pub enum ContextError {
68 #[error("failed to read context JSON: {0}")]
69 Io(#[source] std::io::Error),
70 #[error("failed to parse context JSON: {0}")]
71 Json(#[source] serde_json::Error),
72 #[error("unknown pii class in context class_map: {0}")]
73 UnknownClass(String),
74 #[error("context dictionary '{name}' has no terms")]
75 EmptyDictionary { name: String },
76 #[error(
77 "unicode dictionary insensitive matching unsupported in v0.4.0, use case_sensitive = true"
78 )]
79 UnicodeInsensitiveDictionaryUnsupported { name: String },
80}
81
82impl Context {
83 pub fn load(path: &Path) -> Result<Self, ContextError> {
84 let raw = fs::read_to_string(path).map_err(ContextError::Io)?;
85 Self::from_json_str(&raw)
86 }
87
88 pub fn from_json_str(raw: &str) -> Result<Self, ContextError> {
89 let raw = serde_json::from_str::<RawContext>(raw).map_err(ContextError::Json)?;
90 Self::from_raw(raw)
91 }
92
93 pub fn fields_typed(&self) -> ContextFieldsRef<'_> {
94 ContextFieldsRef(&self.fields)
95 }
96
97 fn from_raw(raw: RawContext) -> Result<Self, ContextError> {
98 let mut class_map = HashMap::with_capacity(raw.class_map.len());
99 for (name, class) in raw.class_map {
100 let parsed = PiiClass::from_policy_name(&class)
101 .ok_or_else(|| ContextError::UnknownClass(class.clone()))?;
102 class_map.insert(name, parsed);
103 }
104
105 let mut dictionaries = HashMap::with_capacity(raw.dictionaries.len());
106 for (name, dictionary) in raw.dictionaries {
107 if dictionary.terms.is_empty() {
108 return Err(ContextError::EmptyDictionary { name });
109 }
110 if !dictionary.case_sensitive && dictionary.terms.iter().any(|term| !term.is_ascii()) {
111 return Err(ContextError::UnicodeInsensitiveDictionaryUnsupported { name });
112 }
113 dictionaries.insert(
114 name,
115 ContextDictionary {
116 terms: dictionary.terms,
117 case_sensitive: dictionary.case_sensitive,
118 },
119 );
120 }
121
122 Ok(Self {
123 dictionaries,
124 class_map,
125 fields: raw.fields,
126 })
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed document is parsed into typed dictionaries, classes and fields.
    #[test]
    fn parses_typed_context_envelope() {
        let ctx = Context::from_json_str(
            r#"{
                "dictionaries": {
                    "dict_alpha": { "terms": ["AAA-12345"], "case_sensitive": true }
                },
                "class_map": { "dict_alpha": "custom:class_alpha" },
                "fields": { "tenant": "demo" }
            }"#,
        )
        .expect("context");
        assert_eq!(ctx.dictionaries["dict_alpha"].terms, vec!["AAA-12345"]);
        assert_eq!(
            ctx.class_map["dict_alpha"],
            PiiClass::Custom("class_alpha".to_string())
        );
        assert_eq!(ctx.fields["tenant"], Value::String("demo".to_string()));
    }

    /// Top-level keys outside the three known ones are a deserialization error.
    #[test]
    fn rejects_unknown_top_level_context_keys() {
        let err = serde_json::from_str::<RawContext>(
            r#"{
                "dictionaries": {},
                "class_map": {},
                "fields": {},
                "extra": true
            }"#,
        )
        .expect_err("unknown key must fail");
        assert!(err.to_string().contains("unknown field"));
    }

    /// A case-insensitive dictionary with a non-ASCII term is rejected, and the
    /// error identifies the offending dictionary.
    #[test]
    fn rejects_unicode_case_insensitive_terms() {
        let err = Context::from_json_str(
            r#"{
                "dictionaries": {
                    "songs": { "terms": ["Beyoncé"], "case_sensitive": false }
                },
                "class_map": { "songs": "custom:song" },
                "fields": {}
            }"#,
        )
        .expect_err("non-ASCII case-insensitive dictionary must fail");

        match err {
            ContextError::UnicodeInsensitiveDictionaryUnsupported { name } => {
                assert_eq!(name, "songs");
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    /// A dictionary with an empty `terms` list is rejected, and the error
    /// identifies the offending dictionary.
    #[test]
    fn rejects_empty_dictionary() {
        let err = Context::from_json_str(
            r#"{
                "dictionaries": {
                    "empty": { "terms": [] }
                },
                "class_map": {},
                "fields": {}
            }"#,
        )
        .expect_err("empty dictionary must fail");

        match err {
            ContextError::EmptyDictionary { name } => assert_eq!(name, "empty"),
            other => panic!("unexpected error: {:?}", other),
        }
    }

    /// The borrowed view exposes the same entries as the owned map.
    #[test]
    fn context_fields_ref_iter_matches_underlying_map() {
        let ctx = Context::from_json_str(
            r#"{
                "dictionaries": {},
                "class_map": {},
                "fields": { "tenant": "demo", "region": "eu" }
            }"#,
        )
        .expect("context");

        let typed = ctx.fields_typed();
        let from_ref = typed.iter().collect::<Vec<_>>();
        let from_map = ctx.fields.iter().collect::<Vec<_>>();

        assert_eq!(from_ref, from_map);
        assert_eq!(typed.len(), ctx.fields.len());
        assert_eq!(
            typed.get("tenant"),
            Some(&Value::String("demo".to_string()))
        );
        assert!(!typed.is_empty());
    }

    /// The borrowed view points at the context's own map rather than a copy.
    #[test]
    fn context_fields_ref_borrows_without_clone() {
        let ctx = Context::from_json_str(
            r#"{
                "dictionaries": {},
                "class_map": {},
                "fields": { "tenant": "demo" }
            }"#,
        )
        .expect("context");

        assert!(std::ptr::eq(ctx.fields_typed().as_map(), &ctx.fields));
    }
}