1use crate::axes::is_valid_axis_name;
4use crate::error::ValidationError;
5use crate::types::Source;
6use serde::{Deserialize, Serialize};
7use std::collections::BTreeMap;
8use std::fmt;
9
/// Maximum accepted `user_id` length, measured in bytes (`str::len`),
/// as enforced by `validate_user_id`.
pub const MAX_USER_ID_LENGTH: usize = 256;

/// Scalar type stored for each named axis of a [`StateSnapshot`].
///
/// `StateSnapshot::validate` requires every axis value to lie in `[0.0, 1.0]`.
pub type AxisValue = f32;
15
16#[derive(Clone, Serialize, Deserialize)]
27pub struct StateSnapshot {
28 pub user_id: String,
30
31 pub updated_at_unix_ms: i64,
33
34 pub source: Source,
36
37 pub confidence: f32,
39
40 pub axes: BTreeMap<String, AxisValue>,
43}
44
/// Renders a user id in a form that is safe to put in logs and `Debug` output.
///
/// Ids of at most four characters are replaced entirely by `"[redacted]"`.
/// Longer ids keep their first four characters, followed by `"..."`.
///
/// The prefix is cut on a `char` boundary rather than at byte 4. User ids may
/// contain non-ASCII characters: `validate_user_id` accepts any
/// `char::is_alphanumeric` character, and deserialized snapshots are not
/// validated at all. Slicing `&user_id[..4]` would panic when byte 4 falls
/// inside a multi-byte character.
fn redact_user_id(user_id: &str) -> String {
    const VISIBLE_CHARS: usize = 4;

    match user_id.char_indices().nth(VISIBLE_CHARS) {
        // A fifth character exists; its byte offset is where the visible prefix ends.
        Some((cut, _)) => format!("{}...", &user_id[..cut]),
        None => "[redacted]".to_string(),
    }
}
57
58impl fmt::Debug for StateSnapshot {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.debug_struct("StateSnapshot")
61 .field("user_id", &redact_user_id(&self.user_id))
62 .field("updated_at_unix_ms", &self.updated_at_unix_ms)
63 .field("source", &self.source)
64 .field("confidence", &self.confidence)
65 .field("axes", &self.axes)
66 .finish()
67 }
68}
69
70impl StateSnapshot {
71 pub fn builder() -> StateSnapshotBuilder {
73 StateSnapshotBuilder::new()
74 }
75
76 pub fn validate(&self) -> Result<(), ValidationError> {
78 validate_user_id(&self.user_id)?;
80
81 if !(0.0..=1.0).contains(&self.confidence) {
83 return Err(ValidationError::ConfidenceOutOfRange {
84 value: self.confidence,
85 });
86 }
87
88 for (name, value) in &self.axes {
90 if !is_valid_axis_name(name) {
91 return Err(ValidationError::InvalidAxisName { axis: name.clone() });
92 }
93 if !(0.0..=1.0).contains(value) {
94 return Err(ValidationError::AxisOutOfRange {
95 axis: name.clone(),
96 value: *value,
97 });
98 }
99 }
100
101 Ok(())
102 }
103
104 pub fn get_axis(&self, name: &str) -> AxisValue {
106 *self.axes.get(name).unwrap_or(&0.5)
107 }
108
109 pub fn get_axis_opt(&self, name: &str) -> Option<AxisValue> {
111 self.axes.get(name).copied()
112 }
113}
114
115impl Default for StateSnapshot {
116 fn default() -> Self {
117 Self {
118 user_id: String::new(),
119 updated_at_unix_ms: chrono::Utc::now().timestamp_millis(),
120 source: Source::default(),
121 confidence: 1.0,
122 axes: BTreeMap::new(),
123 }
124 }
125}
126
127#[derive(Default)]
129pub struct StateSnapshotBuilder {
130 user_id: Option<String>,
131 updated_at_unix_ms: Option<i64>,
132 source: Source,
133 confidence: f32,
134 axes: BTreeMap<String, AxisValue>,
135}
136
137impl StateSnapshotBuilder {
138 pub fn new() -> Self {
140 Self {
141 user_id: None,
142 updated_at_unix_ms: None,
143 source: Source::SelfReport,
144 confidence: 1.0,
145 axes: BTreeMap::new(),
146 }
147 }
148
149 pub fn user_id(mut self, user_id: impl Into<String>) -> Self {
151 self.user_id = Some(user_id.into());
152 self
153 }
154
155 pub fn updated_at(mut self, unix_ms: i64) -> Self {
157 self.updated_at_unix_ms = Some(unix_ms);
158 self
159 }
160
161 pub fn source(mut self, source: Source) -> Self {
163 self.source = source;
164 self
165 }
166
167 pub fn confidence(mut self, confidence: f32) -> Self {
169 self.confidence = confidence;
170 self
171 }
172
173 pub fn axis(mut self, name: impl Into<String>, value: AxisValue) -> Self {
175 self.axes.insert(name.into(), value);
176 self
177 }
178
179 pub fn axes(mut self, axes: impl IntoIterator<Item = (String, AxisValue)>) -> Self {
181 self.axes.extend(axes);
182 self
183 }
184
185 pub fn build(self) -> Result<StateSnapshot, ValidationError> {
187 let user_id = self.user_id.ok_or(ValidationError::MissingField {
188 field: "user_id".to_string(),
189 })?;
190
191 let snapshot = StateSnapshot {
192 user_id,
193 updated_at_unix_ms: self
194 .updated_at_unix_ms
195 .unwrap_or_else(|| chrono::Utc::now().timestamp_millis()),
196 source: self.source,
197 confidence: self.confidence,
198 axes: self.axes,
199 };
200
201 snapshot.validate()?;
202 Ok(snapshot)
203 }
204}
205
206pub fn validate_user_id(user_id: &str) -> Result<(), ValidationError> {
208 if user_id.is_empty() {
209 return Err(ValidationError::EmptyUserId);
210 }
211
212 if user_id.len() > MAX_USER_ID_LENGTH {
213 return Err(ValidationError::UserIdTooLong {
214 max: MAX_USER_ID_LENGTH,
215 });
216 }
217
218 if !user_id
219 .chars()
220 .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
221 {
222 return Err(ValidationError::InvalidUserIdChars);
223 }
224
225 Ok(())
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_basic() {
        let snapshot = StateSnapshot::builder()
            .user_id("user_123")
            .axis("warmth", 0.7)
            .build()
            .unwrap();

        assert_eq!(snapshot.user_id, "user_123");
        assert_eq!(snapshot.get_axis("warmth"), 0.7);
        // Axes that were never set fall back to 0.5.
        assert_eq!(snapshot.get_axis("unknown"), 0.5);
    }

    #[test]
    fn test_get_axis_opt_distinguishes_missing_axes() {
        let snapshot = StateSnapshot::builder()
            .user_id("user_123")
            .axis("warmth", 0.7)
            .build()
            .unwrap();

        assert_eq!(snapshot.get_axis_opt("warmth"), Some(0.7));
        assert_eq!(snapshot.get_axis_opt("unknown"), None);
    }

    #[test]
    fn test_validation_axis_out_of_range() {
        let result = StateSnapshot::builder()
            .user_id("user_123")
            .axis("warmth", 1.5)
            .build();

        assert!(matches!(
            result,
            Err(ValidationError::AxisOutOfRange { axis, value })
                if axis == "warmth" && value == 1.5
        ));
    }

    #[test]
    fn test_validation_nan_axis_rejected() {
        let result = StateSnapshot::builder()
            .user_id("user_123")
            .axis("warmth", f32::NAN)
            .build();

        assert!(matches!(
            result,
            Err(ValidationError::AxisOutOfRange { axis, .. }) if axis == "warmth"
        ));
    }

    #[test]
    fn test_validation_confidence_out_of_range() {
        let result = StateSnapshot::builder()
            .user_id("user_123")
            .confidence(1.5)
            .build();

        assert!(matches!(
            result,
            Err(ValidationError::ConfidenceOutOfRange { value }) if value == 1.5
        ));
    }

    #[test]
    fn test_validation_nan_confidence_rejected() {
        let result = StateSnapshot::builder()
            .user_id("user_123")
            .confidence(f32::NAN)
            .build();

        assert!(matches!(
            result,
            Err(ValidationError::ConfidenceOutOfRange { .. })
        ));
    }

    #[test]
    fn test_missing_user_id_rejected() {
        let result = StateSnapshot::builder().axis("warmth", 0.5).build();

        assert!(matches!(
            result,
            Err(ValidationError::MissingField { field }) if field == "user_id"
        ));
    }

    #[test]
    fn test_validation_invalid_user_id() {
        let result = StateSnapshot::builder().user_id("user with spaces").build();

        assert!(matches!(result, Err(ValidationError::InvalidUserIdChars)));
    }

    #[test]
    fn test_validation_empty_user_id() {
        let result = StateSnapshot::builder().user_id("").build();

        assert!(matches!(result, Err(ValidationError::EmptyUserId)));
    }

    #[test]
    fn test_user_id_length_limit() {
        let at_limit = "a".repeat(MAX_USER_ID_LENGTH);
        assert!(StateSnapshot::builder().user_id(at_limit).build().is_ok());

        let too_long = "a".repeat(MAX_USER_ID_LENGTH + 1);
        assert!(matches!(
            StateSnapshot::builder().user_id(too_long).build(),
            Err(ValidationError::UserIdTooLong { max }) if max == MAX_USER_ID_LENGTH
        ));
    }

    #[test]
    fn test_serialization() {
        let snapshot = StateSnapshot::builder()
            .user_id("u_123")
            .source(Source::SelfReport)
            .confidence(1.0)
            .axis("warmth", 0.6)
            .axis("formality", 0.3)
            .build()
            .unwrap();

        let json = serde_json::to_string(&snapshot).unwrap();
        let parsed: StateSnapshot = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.user_id, snapshot.user_id);
        assert_eq!(parsed.get_axis("warmth"), 0.6);
    }

    #[test]
    fn test_debug_redacts_user_id() {
        let snapshot = StateSnapshot::builder()
            .user_id("user_123456789")
            .axis("warmth", 0.5)
            .build()
            .unwrap();

        let debug_output = format!("{:?}", snapshot);

        assert!(debug_output.contains("user..."));
        assert!(!debug_output.contains("user_123456789"));
    }

    #[test]
    fn test_debug_redacts_short_user_id() {
        let snapshot = StateSnapshot::builder()
            .user_id("ab12")
            .axis("warmth", 0.5)
            .build()
            .unwrap();

        let debug_output = format!("{:?}", snapshot);

        assert!(debug_output.contains("[redacted]"));
        assert!(!debug_output.contains("ab12"));
    }

    #[test]
    fn test_redact_user_id_function() {
        assert_eq!(redact_user_id("user_12345"), "user...");
        assert_eq!(redact_user_id("abcde"), "abcd...");

        assert_eq!(redact_user_id("abc"), "[redacted]");
        assert_eq!(redact_user_id("abcd"), "[redacted]");
        assert_eq!(redact_user_id(""), "[redacted]");
    }

    mod property_tests {
        use super::*;
        use proptest::prelude::*;

        fn valid_user_id() -> impl Strategy<Value = String> {
            // `{1,64}` already guarantees a non-empty id, so no filter is needed.
            "[a-zA-Z0-9_-]{1,64}"
        }

        fn valid_axis_value() -> impl Strategy<Value = f32> {
            0.0f32..=1.0f32
        }

        fn valid_axis_name() -> impl Strategy<Value = String> {
            // Lowercase, starts with a letter, never ends with an underscore.
            "[a-z][a-z0-9_]{0,30}[a-z0-9]?"
                .prop_filter("must not end with underscore", |s| !s.ends_with('_'))
        }

        proptest! {
            #[test]
            fn prop_valid_axis_values_accepted(value in valid_axis_value()) {
                let result = StateSnapshot::builder()
                    .user_id("test_user")
                    .axis("test_axis", value)
                    .build();

                prop_assert!(result.is_ok());
                // Compare via `get_axis_opt` so that a dropped axis cannot pass
                // by matching the 0.5 fallback of `get_axis`.
                let snapshot = result.unwrap();
                prop_assert_eq!(snapshot.get_axis_opt("test_axis"), Some(value));
            }

            #[test]
            fn prop_invalid_axis_values_rejected(value in prop::num::f32::ANY) {
                // NaN is deliberately not excluded. It is not contained in
                // 0.0..=1.0, so `validate` must reject it like any other
                // out-of-range value.
                prop_assume!(!(0.0..=1.0).contains(&value));

                let result = StateSnapshot::builder()
                    .user_id("test_user")
                    .axis("test_axis", value)
                    .build();

                prop_assert!(result.is_err());
            }

            #[test]
            fn prop_valid_user_ids_accepted(user_id in valid_user_id()) {
                let result = StateSnapshot::builder()
                    .user_id(&user_id)
                    .build();

                prop_assert!(result.is_ok());
            }

            #[test]
            fn prop_snapshot_roundtrip_serialization(
                user_id in valid_user_id(),
                warmth in valid_axis_value(),
                formality in valid_axis_value(),
                confidence in valid_axis_value(),
            ) {
                let snapshot = StateSnapshot::builder()
                    .user_id(&user_id)
                    .confidence(confidence)
                    .axis("warmth", warmth)
                    .axis("formality", formality)
                    .build()
                    .unwrap();

                let json = serde_json::to_string(&snapshot).unwrap();
                let parsed: StateSnapshot = serde_json::from_str(&json).unwrap();

                prop_assert_eq!(&parsed.user_id, &snapshot.user_id);
                prop_assert_eq!(parsed.updated_at_unix_ms, snapshot.updated_at_unix_ms);
                prop_assert!((parsed.confidence - snapshot.confidence).abs() < f32::EPSILON);
                prop_assert!(
                    (parsed.get_axis("warmth") - snapshot.get_axis("warmth")).abs() < f32::EPSILON
                );
                prop_assert!(
                    (parsed.get_axis("formality") - snapshot.get_axis("formality")).abs()
                        < f32::EPSILON
                );
            }

            #[test]
            fn prop_multiple_axes_preserved(
                axes in prop::collection::btree_map(
                    valid_axis_name(),
                    valid_axis_value(),
                    0..20
                )
            ) {
                let mut builder = StateSnapshot::builder().user_id("test_user");

                for (name, value) in &axes {
                    builder = builder.axis(name, *value);
                }

                let snapshot = builder.build().unwrap();

                prop_assert_eq!(snapshot.axes.len(), axes.len());
                for (name, expected_value) in &axes {
                    prop_assert_eq!(
                        snapshot.get_axis_opt(name),
                        Some(*expected_value),
                        "Axis {} lost or changed",
                        name
                    );
                }
            }

            #[test]
            fn prop_get_axis_returns_default_for_unknown(
                axis_name in valid_axis_name()
            ) {
                let snapshot = StateSnapshot::builder()
                    .user_id("test_user")
                    .build()
                    .unwrap();

                prop_assert_eq!(snapshot.get_axis_opt(&axis_name), None);
                let value = snapshot.get_axis(&axis_name);
                prop_assert_eq!(value, 0.5, "Unknown axis should return default 0.5");
            }
        }
    }
}