1use std::collections::HashMap;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7
/// The type a search attribute is declared to hold.
///
/// Every variant has a same-named counterpart in [`SearchAttributeValue`];
/// [`SearchAttributeValue::attribute_type`] maps a value to its tag.
#[derive(Serialize, Deserialize, ts_rs::TS, Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SearchAttributeType {
    /// Tag for [`SearchAttributeValue::String`].
    String,
    /// Tag for [`SearchAttributeValue::Int`].
    Int,
    /// Tag for [`SearchAttributeValue::Float`].
    Float,
    /// Tag for [`SearchAttributeValue::Bool`].
    Bool,
    /// Tag for [`SearchAttributeValue::Datetime`].
    Datetime,
    /// Tag for [`SearchAttributeValue::KeywordList`].
    KeywordList,
}
24
/// A concrete search attribute value together with its type.
///
/// Serialized in serde's adjacently tagged form: the variant name is stored
/// under the `type` key and the payload under the `data` key.
#[derive(Serialize, Deserialize, ts_rs::TS, Clone, Debug, PartialEq)]
#[serde(tag = "type", content = "data")]
pub enum SearchAttributeValue {
    /// An owned string payload.
    String(String),
    /// A signed 64-bit integer payload.
    Int(i64),
    /// A 64-bit floating point payload.
    Float(f64),
    /// A boolean payload.
    Bool(bool),
    /// A UTC timestamp payload.
    Datetime(DateTime<Utc>),
    /// A list of owned strings.
    KeywordList(Vec<String>),
}
42
43impl SearchAttributeValue {
44 #[must_use]
46 pub const fn attribute_type(&self) -> SearchAttributeType {
47 match self {
48 Self::String(_) => SearchAttributeType::String,
49 Self::Int(_) => SearchAttributeType::Int,
50 Self::Float(_) => SearchAttributeType::Float,
51 Self::Bool(_) => SearchAttributeType::Bool,
52 Self::Datetime(_) => SearchAttributeType::Datetime,
53 Self::KeywordList(_) => SearchAttributeType::KeywordList,
54 }
55 }
56}
57
/// A registry that records which [`SearchAttributeType`] each attribute name is declared with.
///
/// The derived `Default` yields a schema with no registered attributes.
#[derive(Serialize, Deserialize, ts_rs::TS, Clone, Debug, Default, PartialEq, Eq)]
pub struct SearchAttributeSchema {
    /// Registered attribute names mapped to their declared types.
    attributes: HashMap<String, SearchAttributeType>,
}
63
64impl SearchAttributeSchema {
65 #[must_use]
67 pub fn new() -> Self {
68 Self::default()
69 }
70
71 pub fn register(
78 &mut self,
79 name: impl Into<String>,
80 attribute_type: SearchAttributeType,
81 ) -> Result<(), SearchAttributeError> {
82 let name = name.into();
83 if let Some(existing) = self.attributes.get(&name).copied() {
84 if existing == attribute_type {
85 return Ok(());
86 }
87
88 return Err(SearchAttributeError::ConflictingType {
89 name,
90 existing,
91 requested: attribute_type,
92 });
93 }
94
95 self.attributes.insert(name, attribute_type);
96 Ok(())
97 }
98
99 pub fn validate(
107 &self,
108 name: &str,
109 value: &SearchAttributeValue,
110 ) -> Result<(), SearchAttributeError> {
111 let expected = self.attributes.get(name).copied().ok_or_else(|| {
112 SearchAttributeError::UnregisteredAttribute {
113 name: String::from(name),
114 }
115 })?;
116 let actual = value.attribute_type();
117
118 if expected == actual {
119 Ok(())
120 } else {
121 Err(SearchAttributeError::TypeMismatch {
122 name: String::from(name),
123 expected,
124 actual,
125 })
126 }
127 }
128}
129
130#[must_use]
136pub fn search_attributes_from_events(
137 events: &[crate::Event],
138) -> HashMap<String, SearchAttributeValue> {
139 let mut attributes = HashMap::new();
140 for event in events {
141 if let crate::Event::SearchAttributesUpdated {
142 attributes: updated,
143 ..
144 } = event
145 {
146 attributes.extend(updated.clone());
147 }
148 }
149 attributes
150}
151
/// Errors produced when registering or validating search attributes against a
/// [`SearchAttributeSchema`].
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq)]
pub enum SearchAttributeError {
    /// Returned by [`SearchAttributeSchema::register`] when the name is already
    /// registered with a different type.
    #[error("search attribute `{name}` is already registered as {existing:?}, not {requested:?}")]
    ConflictingType {
        /// Name of the attribute being registered.
        name: String,
        /// Type the attribute is already registered with.
        existing: SearchAttributeType,
        /// Type that the rejected registration asked for.
        requested: SearchAttributeType,
    },
    /// Returned by [`SearchAttributeSchema::validate`] when the name has not been registered.
    #[error("search attribute `{name}` is not registered")]
    UnregisteredAttribute {
        /// Name of the attribute that was looked up.
        name: String,
    },
    /// Returned by [`SearchAttributeSchema::validate`] when the value's type
    /// differs from the registered type.
    #[error("search attribute `{name}` expected {expected:?}, got {actual:?}")]
    TypeMismatch {
        /// Name of the attribute being validated.
        name: String,
        /// Type the attribute is registered with.
        expected: SearchAttributeType,
        /// Type of the value that was supplied.
        actual: SearchAttributeType,
    },
}
182
#[cfg(test)]
mod tests {
    use chrono::{DateTime, Utc};

    use super::{
        SearchAttributeError, SearchAttributeSchema, SearchAttributeType, SearchAttributeValue,
    };

    /// Shorthand for tests that propagate failures with `?`.
    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Attribute name shared by the schema tests.
    const CUSTOMER_ID: &str = "customer_id";

    /// Every attribute type, listed in the same order as `sample_values`.
    const ALL_TYPES: [SearchAttributeType; 6] = [
        SearchAttributeType::String,
        SearchAttributeType::Int,
        SearchAttributeType::Float,
        SearchAttributeType::Bool,
        SearchAttributeType::Datetime,
        SearchAttributeType::KeywordList,
    ];

    /// Fixed timestamp with a sub-second component, used as the `Datetime` sample.
    fn recorded_at() -> DateTime<Utc> {
        DateTime::from_timestamp(1_700_000_000, 123_000_000).unwrap_or_default()
    }

    /// String value used wherever a test needs a value for `CUSTOMER_ID`.
    fn customer_value() -> SearchAttributeValue {
        SearchAttributeValue::String(String::from("customer-123"))
    }

    /// One sample value per variant, in the same order as `ALL_TYPES`.
    fn sample_values() -> Vec<SearchAttributeValue> {
        vec![
            customer_value(),
            SearchAttributeValue::Int(42),
            SearchAttributeValue::Float(12.5),
            SearchAttributeValue::Bool(true),
            SearchAttributeValue::Datetime(recorded_at()),
            SearchAttributeValue::KeywordList(vec![String::from("vip"), String::from("west")]),
        ]
    }

    #[test]
    fn values_report_matching_attribute_types() {
        let reported: Vec<SearchAttributeType> = sample_values()
            .iter()
            .map(SearchAttributeValue::attribute_type)
            .collect();

        assert_eq!(reported, ALL_TYPES);
    }

    #[test]
    fn search_attribute_types_round_trip_through_json() -> TestResult {
        for &attribute_type in &ALL_TYPES {
            let encoded = serde_json::to_string(&attribute_type)?;
            let decoded: SearchAttributeType = serde_json::from_str(&encoded)?;
            assert_eq!(attribute_type, decoded);
        }
        Ok(())
    }

    #[test]
    fn search_attribute_values_round_trip_through_json() -> TestResult {
        for value in sample_values() {
            let encoded = serde_json::to_string(&value)?;
            let decoded: SearchAttributeValue = serde_json::from_str(&encoded)?;
            assert_eq!(value, decoded);
        }
        Ok(())
    }

    #[test]
    fn schema_registers_and_validates_matching_types() -> TestResult {
        let mut schema = SearchAttributeSchema::new();
        // The second, identical registration must be accepted as a no-op.
        for _ in 0..2 {
            schema.register(CUSTOMER_ID, SearchAttributeType::String)?;
        }

        schema.validate(CUSTOMER_ID, &customer_value())?;
        Ok(())
    }

    #[test]
    fn registering_same_name_with_different_type_errors() -> TestResult {
        let mut schema = SearchAttributeSchema::new();
        schema.register(CUSTOMER_ID, SearchAttributeType::String)?;

        let outcome = schema.register(CUSTOMER_ID, SearchAttributeType::Int);
        let expected: Result<(), SearchAttributeError> =
            Err(SearchAttributeError::ConflictingType {
                name: String::from(CUSTOMER_ID),
                existing: SearchAttributeType::String,
                requested: SearchAttributeType::Int,
            });

        assert_eq!(outcome, expected);
        Ok(())
    }

    #[test]
    fn validating_unregistered_attribute_errors() {
        let schema = SearchAttributeSchema::new();

        let outcome = schema.validate(CUSTOMER_ID, &customer_value());
        let expected: Result<(), SearchAttributeError> =
            Err(SearchAttributeError::UnregisteredAttribute {
                name: String::from(CUSTOMER_ID),
            });

        assert_eq!(outcome, expected);
    }

    #[test]
    fn validating_mismatched_type_errors() -> TestResult {
        let mut schema = SearchAttributeSchema::new();
        schema.register(CUSTOMER_ID, SearchAttributeType::String)?;

        let outcome = schema.validate(CUSTOMER_ID, &SearchAttributeValue::Int(42));
        let expected: Result<(), SearchAttributeError> = Err(SearchAttributeError::TypeMismatch {
            name: String::from(CUSTOMER_ID),
            expected: SearchAttributeType::String,
            actual: SearchAttributeType::Int,
        });

        assert_eq!(outcome, expected);
        Ok(())
    }
}