1use std::collections::BTreeMap;
14
15use serde::{Deserialize, Serialize};
16
17use crate::attributes::{Labels, Tags};
18use crate::ids::TrialId;
19use crate::names::{ElementName, NameError, ParameterName};
20use crate::value::Value;
21
22#[derive(Debug, thiserror::Error, PartialEq, Eq)]
32pub enum TrialError {
33 #[error(transparent)]
35 Name(#[from] NameError),
36
37 #[error("element '{element}' has no parameter assignments; an Assignments entry must be non-empty")]
39 EmptyElementAssignments {
40 element: String,
42 },
43
44 #[error(
47 "mis-addressed value: assignment key '{expected}' does not match value.provenance.parameter '{actual}'"
48 )]
49 MisaddressedValue {
50 expected: ParameterName,
52 actual: ParameterName,
54 },
55}
56
57type Result<T, E = TrialError> = std::result::Result<T, E>;
58
59#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
76#[serde(transparent)]
77pub struct Assignments(BTreeMap<ElementName, BTreeMap<ParameterName, Value>>);
78
79impl Assignments {
80 #[must_use]
82 pub fn empty() -> Self {
83 Self::default()
84 }
85
86 pub fn new(map: BTreeMap<ElementName, BTreeMap<ParameterName, Value>>) -> Result<Self> {
88 for (element, params) in &map {
89 if params.is_empty() {
90 return Err(TrialError::EmptyElementAssignments {
91 element: element.as_str().to_owned(),
92 });
93 }
94 for (pname, value) in params {
95 if value.parameter() != pname {
96 return Err(TrialError::MisaddressedValue {
97 expected: pname.clone(),
98 actual: value.parameter().clone(),
99 });
100 }
101 }
102 }
103 Ok(Self(map))
104 }
105
106 #[must_use]
108 pub fn get(&self, element: &ElementName, param: &ParameterName) -> Option<&Value> {
109 self.0.get(element).and_then(|p| p.get(param))
110 }
111
112 #[must_use]
114 pub fn for_element(
115 &self,
116 element: &ElementName,
117 ) -> Option<&BTreeMap<ParameterName, Value>> {
118 self.0.get(element)
119 }
120
121 pub fn iter(&self) -> impl Iterator<Item = (&ElementName, &ParameterName, &Value)> {
123 self.0
124 .iter()
125 .flat_map(|(e, params)| params.iter().map(move |(p, v)| (e, p, v)))
126 }
127
128 #[must_use]
130 pub fn len(&self) -> usize {
131 self.0.values().map(BTreeMap::len).sum()
132 }
133
134 #[must_use]
136 pub fn is_empty(&self) -> bool {
137 self.0.is_empty()
138 }
139
140 #[must_use]
142 pub fn element_count(&self) -> usize {
143 self.0.len()
144 }
145
146 pub(crate) const fn as_map(&self) -> &BTreeMap<ElementName, BTreeMap<ParameterName, Value>> {
148 &self.0
149 }
150}
151
152#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, bon::Builder)]
161pub struct TrialMetadata {
162 pub enumeration_index: Option<u32>,
165
166 pub group: Option<String>,
168
169 pub generation_method: Option<String>,
171
172 pub priority: Option<i32>,
174}
175
176#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, bon::Builder)]
182pub struct Trial {
183 pub id: TrialId,
185
186 pub assignments: Assignments,
188
189 #[builder(default)]
191 pub labels: Labels,
192
193 #[builder(default)]
195 pub tags: Tags,
196
197 pub metadata: Option<TrialMetadata>,
199}
200
201pub const TRIAL_TAG: u8 = 0x60;
205
206impl Trial {
207 #[must_use]
215 pub fn canonical_bytes(&self) -> Vec<u8> {
216 let mut out = Vec::new();
217 out.push(TRIAL_TAG);
218
219 let assignments = self.assignments.as_map();
220 let n_elements = u32::try_from(assignments.len()).expect("element count fits in u32");
221 out.extend_from_slice(&n_elements.to_le_bytes());
222
223 for (element_name, params) in assignments {
224 write_str_len_prefixed(&mut out, element_name.as_str());
225 let n_params = u32::try_from(params.len()).expect("parameter count fits in u32");
226 out.extend_from_slice(&n_params.to_le_bytes());
227 for (param_name, value) in params {
228 write_str_len_prefixed(&mut out, param_name.as_str());
229 out.extend_from_slice(value.fingerprint().as_bytes());
233 }
234 }
235
236 out
237 }
238}
239
/// Appends `s` to `out` as its UTF-8 byte length (little-endian `u32`) followed by its bytes.
///
/// # Panics
///
/// If `s` is longer than `u32::MAX` bytes.
fn write_str_len_prefixed(out: &mut Vec<u8>, s: &str) {
    let prefix = u32::try_from(s.len())
        .expect("string length fits in u32")
        .to_le_bytes();
    out.extend(prefix.iter().chain(s.as_bytes()));
}
246
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::{IntegerValue, LabelKey, LabelValue, Value};
    use ulid::Ulid;

    use super::*;

    fn ename(s: &str) -> ElementName {
        ElementName::new(s).unwrap()
    }

    fn pname(s: &str) -> ParameterName {
        ParameterName::new(s).unwrap()
    }

    fn tid(n: u64) -> TrialId {
        TrialId::from_ulid(Ulid::from_parts(n, 1))
    }

    /// Builds an `Assignments` holding exactly one integer assignment `elem.param = value`.
    fn one_element_assignments(elem: &str, param: &str, value: i64) -> Assignments {
        let mut inner = BTreeMap::new();
        inner.insert(pname(param), Value::integer(pname(param), value, None));
        let mut outer = BTreeMap::new();
        outer.insert(ename(elem), inner);
        Assignments::new(outer).unwrap()
    }

    /// Builds a trial with id `id` around `one_element_assignments(elem, param, value)`.
    fn trial(id: u64, elem: &str, param: &str, value: i64) -> Trial {
        Trial::builder()
            .id(tid(id))
            .assignments(one_element_assignments(elem, param, value))
            .build()
    }

    #[test]
    fn assignments_rejects_empty_inner_map() {
        let mut outer = BTreeMap::new();
        outer.insert(ename("db"), BTreeMap::new());
        let err = Assignments::new(outer).unwrap_err();
        assert_eq!(
            err,
            TrialError::EmptyElementAssignments {
                element: "db".to_owned(),
            }
        );
    }

    #[test]
    fn assignments_rejects_misaddressed_value() {
        let mut inner = BTreeMap::new();
        // Stored under "threads" while the value itself names "connections".
        inner.insert(
            pname("threads"),
            Value::integer(pname("connections"), 8, None),
        );
        let mut outer = BTreeMap::new();
        outer.insert(ename("db"), inner);
        let err = Assignments::new(outer).unwrap_err();
        assert_eq!(
            err,
            TrialError::MisaddressedValue {
                expected: pname("threads"),
                actual: pname("connections"),
            }
        );
    }

    #[test]
    fn empty_assignments_have_no_entries() {
        let a = Assignments::empty();
        assert!(a.is_empty());
        assert_eq!(a.len(), 0);
        assert_eq!(a.element_count(), 0);
        assert_eq!(a.iter().count(), 0);
    }

    #[test]
    fn assignments_lookup_by_element_and_parameter() {
        let a = one_element_assignments("db", "threads", 8);
        assert_eq!(
            a.get(&ename("db"), &pname("threads"))
                .and_then(|v| v.as_integer()),
            Some(8)
        );
        assert!(a.get(&ename("db"), &pname("missing")).is_none());
        assert!(a.get(&ename("cache"), &pname("threads")).is_none());
        assert_eq!(a.for_element(&ename("db")).map(BTreeMap::len), Some(1));
        assert!(a.for_element(&ename("cache")).is_none());
    }

    #[test]
    fn assignments_iter_visits_every_triple() {
        let mut inner = BTreeMap::new();
        inner.insert(pname("a"), Value::integer(pname("a"), 1, None));
        inner.insert(pname("b"), Value::integer(pname("b"), 2, None));
        let mut outer = BTreeMap::new();
        outer.insert(ename("x"), inner);
        let a = Assignments::new(outer).unwrap();
        let triples: Vec<(&str, &str, i64)> = a
            .iter()
            .map(|(e, p, v)| (e.as_str(), p.as_str(), v.as_integer().unwrap()))
            .collect();
        assert_eq!(triples, vec![("x", "a", 1), ("x", "b", 2)]);
    }

    #[test]
    fn assignments_len_sums_across_elements() {
        let mut inner1 = BTreeMap::new();
        inner1.insert(pname("a"), Value::integer(pname("a"), 1, None));
        let mut inner2 = BTreeMap::new();
        inner2.insert(pname("a"), Value::integer(pname("a"), 1, None));
        inner2.insert(pname("b"), Value::integer(pname("b"), 2, None));
        let mut outer = BTreeMap::new();
        outer.insert(ename("x"), inner1);
        outer.insert(ename("y"), inner2);
        let a = Assignments::new(outer).unwrap();
        assert_eq!(a.len(), 3);
        assert_eq!(a.element_count(), 2);
    }

    #[test]
    fn assignments_serde_roundtrip() {
        let a = one_element_assignments("db", "threads", 8);
        let json = serde_json::to_string(&a).unwrap();
        let back: Assignments = serde_json::from_str(&json).unwrap();
        assert_eq!(a, back);
    }

    #[test]
    fn trial_builds_and_accesses_its_fields() {
        let t = trial(1, "db", "threads", 8);
        assert_eq!(t.id, tid(1));
        assert_eq!(t.assignments.len(), 1);
        assert!(t.labels.is_empty());
        assert!(t.metadata.is_none());
    }

    #[test]
    fn trial_metadata_builder() {
        let md = TrialMetadata::builder()
            .enumeration_index(3)
            .group("baseline".to_owned())
            .priority(10)
            .build();
        assert_eq!(md.enumeration_index, Some(3));
        assert_eq!(md.group.as_deref(), Some("baseline"));
        assert!(md.generation_method.is_none());
        assert_eq!(md.priority, Some(10));
    }

    #[test]
    fn trial_serde_roundtrip() {
        let t = trial(42, "db", "threads", 8);
        let json = serde_json::to_string(&t).unwrap();
        let back: Trial = serde_json::from_str(&json).unwrap();
        assert_eq!(t, back);
    }

    #[test]
    fn canonical_bytes_deterministic() {
        // Different ids, identical assignments: the id is not part of the encoding.
        let a = trial(1, "db", "threads", 8);
        let b = trial(2, "db", "threads", 8);
        assert_eq!(a.canonical_bytes(), b.canonical_bytes());
    }

    #[test]
    fn canonical_bytes_distinguish_values() {
        let a = trial(1, "db", "threads", 8);
        let b = trial(1, "db", "threads", 16);
        assert_ne!(a.canonical_bytes(), b.canonical_bytes());
    }

    #[test]
    fn canonical_bytes_distinguish_element_names() {
        let a = trial(1, "db", "threads", 8);
        let b = trial(1, "cache", "threads", 8);
        assert_ne!(a.canonical_bytes(), b.canonical_bytes());
    }

    #[test]
    fn canonical_bytes_distinguish_parameter_names() {
        let a = trial(1, "db", "threads", 8);
        let b = trial(1, "db", "connections", 8);
        assert_ne!(a.canonical_bytes(), b.canonical_bytes());
    }

    #[test]
    fn canonical_bytes_excludes_labels_and_metadata() {
        let mut labels = Labels::new();
        labels.insert(
            LabelKey::new("trial_code").unwrap(),
            LabelValue::new("0x0001").unwrap(),
        );
        let with_labels = Trial::builder()
            .id(tid(1))
            .assignments(one_element_assignments("db", "threads", 8))
            .labels(labels)
            .metadata(TrialMetadata::builder().enumeration_index(7).build())
            .build();
        let plain = trial(1, "db", "threads", 8);
        assert_eq!(with_labels.canonical_bytes(), plain.canonical_bytes());
    }

    #[test]
    fn canonical_bytes_match_hand_built_layout() {
        let got = trial(1, "db", "threads", 42).canonical_bytes();

        let mut expected = vec![TRIAL_TAG];
        // One element.
        expected.extend_from_slice(&1u32.to_le_bytes());
        let elem = "db";
        expected.extend_from_slice(&u32::try_from(elem.len()).unwrap().to_le_bytes());
        expected.extend_from_slice(elem.as_bytes());
        // One parameter on that element.
        expected.extend_from_slice(&1u32.to_le_bytes());
        let param = "threads";
        expected.extend_from_slice(&u32::try_from(param.len()).unwrap().to_le_bytes());
        expected.extend_from_slice(param.as_bytes());
        let fp = IntegerValue::fingerprint_of(&pname("threads"), 42);
        expected.extend_from_slice(fp.as_bytes());

        assert_eq!(got, expected);
    }
}