1use super::super::Sample;
7use super::timestamped_observations::TimestampedObservations;
8use super::trimmed_observation::{ObservationLength, TrimmedObservation};
9use crate::internal::Timestamp;
10use std::collections::HashMap;
11use std::io;
12
/// Backing storage of an initialized [`Observations`].
///
/// Note: fields are dropped in declaration order.
struct NonEmptyObservations {
    /// Observations added without a timestamp; values for the same sample are summed
    /// (saturating) into a single entry.
    aggregated_data: AggregatedObservations,
    /// Observations added with a timestamp; each successful `add` is yielded back
    /// individually, in insertion order, by `Observations::try_into_iter`.
    timestamped_data: TimestampedObservations,
    /// Number of values every observation must carry.
    obs_len: ObservationLength,
    /// Number of successful timestamped `add` calls (not distinct samples).
    timestamped_samples_count: usize,
}
22
/// Collection of sample observations, stored either aggregated per sample (no timestamp)
/// or per timestamp.
///
/// The derived `Default` produces an *uninitialized* instance (`inner` is `None`). Such an
/// instance reports itself as empty, iterates to nothing, and rejects `add` with an error.
/// Build usable instances with [`Observations::new`] or [`Observations::try_new`].
#[derive(Default)]
pub struct Observations {
    /// `None` until constructed through `new`/`try_new`.
    inner: Option<NonEmptyObservations>,
}
27
28impl Observations {
30 pub fn new(observations_len: usize) -> Self {
31 #[allow(clippy::expect_used)]
32 Self::try_new(observations_len).expect("failed to initialize observations")
33 }
34
35 pub fn try_new(observations_len: usize) -> io::Result<Self> {
36 Ok(Observations {
37 inner: Some(NonEmptyObservations {
38 aggregated_data: AggregatedObservations::new(observations_len),
39 timestamped_data: TimestampedObservations::try_new(observations_len).map_err(
40 |err| {
41 io::Error::new(
42 err.kind(),
43 format!("failed to create timestamped observations: {err}"),
44 )
45 },
46 )?,
47 obs_len: ObservationLength::new(observations_len),
48 timestamped_samples_count: 0,
49 }),
50 })
51 }
52
53 pub fn add(
54 &mut self,
55 sample: Sample,
56 timestamp: Option<Timestamp>,
57 values: &[i64],
58 ) -> anyhow::Result<()> {
59 anyhow::ensure!(
60 self.inner.is_some(),
61 "Use of add on Observations that were not initialized"
62 );
63
64 let observations = unsafe { self.inner.as_mut().unwrap_unchecked() };
66 let obs_len = observations.obs_len;
67
68 anyhow::ensure!(
69 obs_len.eq(values.len()),
70 "Observation length mismatch, expected {obs_len:?} values, got {} instead",
71 values.len()
72 );
73
74 if let Some(ts) = timestamp {
75 observations.timestamped_data.add(sample, ts, values)?;
76 observations.timestamped_samples_count += 1;
77 } else {
78 observations.aggregated_data.add(sample, values)?;
79 }
80
81 Ok(())
82 }
83
84 pub fn is_empty(&self) -> bool {
85 self.inner.is_none()
86 || (self.aggregated_samples_count() == 0 && self.timestamped_samples_count() == 0)
87 }
88
89 pub fn aggregated_samples_count(&self) -> usize {
90 self.inner
91 .as_ref()
92 .map(|o| o.aggregated_data.len())
93 .unwrap_or(0)
94 }
95
96 pub fn timestamped_samples_count(&self) -> usize {
97 self.inner
98 .as_ref()
99 .map(|o| o.timestamped_samples_count)
100 .unwrap_or(0)
101 }
102
103 pub fn try_into_iter(self) -> io::Result<ObservationsIntoIter> {
104 match self.inner {
105 None => Ok(ObservationsIntoIter {
106 it: Box::new(std::iter::empty()),
107 }),
108 Some(NonEmptyObservations {
109 mut aggregated_data,
110 timestamped_data,
111 obs_len,
112 ..
113 }) => {
114 let ts_it = timestamped_data
115 .try_into_iter()?
116 .map(|(s, t, o)| (s, Some(t), o));
117
118 let agg_it = std::mem::take(&mut aggregated_data.data)
119 .into_iter()
120 .map(move |(s, o)| (s, None, unsafe { o.into_vec(obs_len) }));
121
122 Ok(ObservationsIntoIter {
123 it: Box::new(ts_it.chain(agg_it)),
124 })
125 }
126 }
127 }
128}
129
/// Per-sample running totals for observations recorded without a timestamp.
#[derive(Default)]
struct AggregatedObservations {
    /// Length every stored observation was created with. It is required to access or free
    /// the `TrimmedObservation`s in `data`.
    obs_len: ObservationLength,
    /// One trimmed observation per sample. Entries are released in `Drop` via `consume`.
    data: HashMap<Sample, TrimmedObservation>,
}
135
136impl AggregatedObservations {
137 pub fn new(obs_len: usize) -> Self {
138 AggregatedObservations {
139 obs_len: ObservationLength::new(obs_len),
140 data: Default::default(),
141 }
142 }
143
144 fn add(&mut self, sample: Sample, values: &[i64]) -> anyhow::Result<()> {
145 anyhow::ensure!(
146 self.obs_len.eq(values.len()),
147 "Observation length mismatch, expected {:?} values, got {} instead",
148 self.obs_len,
149 values.len()
150 );
151
152 if let Some(v) = self.data.get_mut(&sample) {
153 unsafe { v.as_mut_slice(self.obs_len) }
156 .iter_mut()
157 .zip(values)
158 .for_each(|(a, b)| *a = a.saturating_add(*b));
159 } else {
160 let trimmed = TrimmedObservation::new(values, self.obs_len);
161 self.data.insert(sample, trimmed);
162 }
163
164 Ok(())
165 }
166
167 fn len(&self) -> usize {
168 self.data.len()
169 }
170
171 #[allow(dead_code)]
172 fn is_empty(&self) -> bool {
173 self.data.is_empty()
174 }
175
176 #[allow(dead_code)]
177 fn contains_key(&self, sample: &Sample) -> bool {
178 self.data.contains_key(sample)
179 }
180
181 #[allow(dead_code)]
182 fn remove(&mut self, sample: &Sample) -> Option<TrimmedObservation> {
183 self.data.remove(sample)
184 }
185}
186
187impl Drop for AggregatedObservations {
188 fn drop(&mut self) {
189 let o = self.obs_len;
190 self.data.drain().for_each(|(_, v)| {
191 unsafe { v.consume(o) };
194 });
195 }
196}
197
/// Owning iterator returned by [`Observations::try_into_iter`].
///
/// Yields `(sample, timestamp, values)`: timestamped observations first (`Some(ts)`), then
/// aggregated observations (`None`).
pub struct ObservationsIntoIter {
    it: Box<dyn Iterator<Item = <ObservationsIntoIter as Iterator>::Item>>,
}
201
202impl Iterator for ObservationsIntoIter {
203 type Item = (Sample, Option<Timestamp>, Vec<i64>);
204 fn next(&mut self) -> Option<Self::Item> {
205 self.it.next()
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collections::identifiable::*;
    use crate::internal::{LabelSetId, StackTraceId};
    use bolero::generator::*;
    use std::num::NonZeroI64;

    #[test]
    fn add_and_iter_test() {
        let mut o = Observations::new(3);
        let s1 = Sample {
            labels: LabelSetId::from_offset(1),
            stacktrace: StackTraceId::from_offset(1),
        };
        let s2 = Sample {
            labels: LabelSetId::from_offset(2),
            stacktrace: StackTraceId::from_offset(2),
        };
        let s3 = Sample {
            labels: LabelSetId::from_offset(3),
            stacktrace: StackTraceId::from_offset(3),
        };
        let t1 = Some(Timestamp::new(1).unwrap());
        let t2 = Some(Timestamp::new(2).unwrap());

        o.add(s1, None, &[1, 2, 3]).unwrap();
        o.add(s1, None, &[4, 5, 6]).unwrap();
        o.add(s2, None, &[7, 8, 9]).unwrap();
        o.add(s3, t1, &[10, 11, 12]).unwrap();
        o.add(s2, t2, &[13, 14, 15]).unwrap();

        assert_eq!(2, o.aggregated_samples_count());

        assert_eq!(2, o.timestamped_samples_count());

        let mut count = 0;
        o.try_into_iter().unwrap().for_each(|(k, ts, v)| {
            count += 1;
            if k == s1 {
                assert_eq!(v, vec![5, 7, 9]);
            } else if k == s2 {
                if ts.is_some() {
                    assert_eq!(v, vec![13, 14, 15]);
                    assert_eq!(ts, t2);
                } else {
                    assert_eq!(v, vec![7, 8, 9]);
                    assert!(ts.is_none());
                }
            } else if k == s3 {
                assert_eq!(v, vec![10, 11, 12]);
                assert_eq!(ts, t1);
            } else {
                panic!("Unexpected key");
            }
        });
        // Two aggregated samples (s1, s2) plus two timestamped observations (s3@t1, s2@t2).
        assert_eq!(count, 4);
    }

    #[test]
    fn different_lengths_panic_different_key_no_ts() {
        let s1 = Sample {
            labels: LabelSetId::from_offset(1),
            stacktrace: StackTraceId::from_offset(1),
        };
        let s2 = Sample {
            labels: LabelSetId::from_offset(2),
            stacktrace: StackTraceId::from_offset(2),
        };

        let mut o = Observations::new(3);
        o.add(s1, None, &[1, 2, 3]).unwrap();
        o.add(s2, None, &[4, 5]).unwrap_err();
    }

    #[test]
    fn different_lengths_panic_same_key_no_ts() {
        let s1 = Sample {
            labels: LabelSetId::from_offset(1),
            stacktrace: StackTraceId::from_offset(1),
        };

        let mut o = Observations::new(3);
        o.add(s1, None, &[1, 2, 3]).unwrap();
        o.add(s1, None, &[4, 5]).unwrap_err();
    }

    #[test]
    fn different_lengths_panic_different_key_ts() {
        let s1 = Sample {
            labels: LabelSetId::from_offset(1),
            stacktrace: StackTraceId::from_offset(1),
        };
        let s2 = Sample {
            labels: LabelSetId::from_offset(2),
            stacktrace: StackTraceId::from_offset(2),
        };

        let mut o = Observations::new(3);
        let ts = NonZeroI64::new(1).unwrap();
        o.add(s1, Some(ts), &[1, 2, 3]).unwrap();
        o.add(s2, Some(ts), &[4, 5]).unwrap_err();
    }

    #[test]
    fn different_lengths_panic_same_key_ts() {
        let s1 = Sample {
            labels: LabelSetId::from_offset(1),
            stacktrace: StackTraceId::from_offset(1),
        };

        let mut o = Observations::new(3);
        let ts = NonZeroI64::new(1).unwrap();
        o.add(s1, Some(ts), &[1, 2, 3]).unwrap();
        o.add(s1, Some(ts), &[4, 5]).unwrap_err();
    }

    #[test]
    fn different_lengths_panic_different_key_mixed() {
        let s1 = Sample {
            labels: LabelSetId::from_offset(1),
            stacktrace: StackTraceId::from_offset(1),
        };
        let s2 = Sample {
            labels: LabelSetId::from_offset(2),
            stacktrace: StackTraceId::from_offset(2),
        };

        let mut o = Observations::new(3);
        let ts = NonZeroI64::new(1).unwrap();
        o.add(s1, None, &[1, 2, 3]).unwrap();
        o.add(s2, Some(ts), &[4, 5]).unwrap_err();
    }

    #[test]
    fn different_lengths_panic_same_key_mixed() {
        let s1 = Sample {
            labels: LabelSetId::from_offset(1),
            stacktrace: StackTraceId::from_offset(1),
        };

        let mut o = Observations::new(3);
        let ts = NonZeroI64::new(1).unwrap();
        o.add(s1, Some(ts), &[1, 2, 3]).unwrap();
        // Assert on the returned error, like the sibling tests, rather than on an arbitrary panic.
        o.add(s1, None, &[4, 5]).unwrap_err();
    }

    #[test]
    fn into_iter_test() {
        let mut o = Observations::new(3);
        let s1 = Sample {
            labels: LabelSetId::from_offset(1),
            stacktrace: StackTraceId::from_offset(1),
        };
        let s2 = Sample {
            labels: LabelSetId::from_offset(2),
            stacktrace: StackTraceId::from_offset(2),
        };
        let s3 = Sample {
            labels: LabelSetId::from_offset(3),
            stacktrace: StackTraceId::from_offset(3),
        };
        let t1 = Some(Timestamp::new(1).unwrap());

        o.add(s1, None, &[1, 2, 3]).unwrap();
        o.add(s1, None, &[4, 5, 6]).unwrap();
        o.add(s2, None, &[7, 8, 9]).unwrap();
        o.add(s3, t1, &[1, 1, 2]).unwrap();

        let mut count = 0;
        o.try_into_iter().unwrap().for_each(|(k, ts, v)| {
            count += 1;
            if k == s1 {
                assert!(ts.is_none());
                assert_eq!(v, vec![5, 7, 9]);
            } else if k == s2 {
                assert!(ts.is_none());
                assert_eq!(v, vec![7, 8, 9]);
            } else if k == s3 {
                assert_eq!(ts, t1);
                assert_eq!(v, vec![1, 1, 2]);
            } else {
                panic!("Unexpected key");
            }
        });
        assert_eq!(count, 3);
    }

    /// Adds every sample to a fresh `Observations` and checks the results against a
    /// reference `AggregatedObservations`. Samples whose length differs from
    /// `observations_len` must be rejected.
    fn fuzz_inner(
        observations_len: &usize,
        ts_samples: &[(Sample, Timestamp, Vec<i64>)],
        no_ts_samples: &[(Sample, Vec<i64>)],
    ) {
        let obs_len = ObservationLength::new(*observations_len);

        let mut o = Observations::new(*observations_len);
        assert!(o.is_empty());

        let mut ts_samples_added = 0;

        for (s, ts, v) in ts_samples {
            if v.len() == *observations_len {
                o.add(*s, Some(*ts), v).unwrap();
                ts_samples_added += 1;
            } else {
                assert!(o.add(*s, Some(*ts), v).is_err());
            }
        }
        assert_eq!(o.timestamped_samples_count(), ts_samples_added);

        let mut aggregated_observations = AggregatedObservations::new(*observations_len);

        for (s, v) in no_ts_samples {
            if v.len() == *observations_len {
                o.add(*s, None, v).unwrap();
                aggregated_observations.add(*s, v).unwrap();
            } else {
                assert!(o.add(*s, None, v).is_err());
            }
        }

        assert_eq!(o.aggregated_samples_count(), aggregated_observations.len());

        // Timestamped observations come first, in insertion order.
        let mut iter = o.try_into_iter().unwrap();
        for (expected_sample, expected_ts, expected_values) in ts_samples.iter() {
            if expected_values.len() != *observations_len {
                continue;
            }
            let (sample, ts, values) = iter.next().unwrap();
            assert_eq!(*expected_sample, sample);
            assert_eq!(*expected_ts, ts.unwrap());
            assert_eq!(*expected_values, values);
        }

        // What remains are the aggregated observations, in arbitrary order.
        for (sample, ts, values) in iter {
            assert!(ts.is_none());
            assert!(aggregated_observations.contains_key(&sample));
            let expected_values = aggregated_observations.remove(&sample).unwrap();
            // SAFETY: `expected_values` was stored by `aggregated_observations`, which was
            // created with `*observations_len`, the same length as `obs_len`.
            let expected_values = unsafe { expected_values.into_vec(obs_len) };
            assert_eq!(expected_values, values);
        }
        assert!(aggregated_observations.is_empty());
    }

    #[test]
    fn fuzz_with_same_obs_len() {
        let obs_len_gen = if cfg!(miri) {
            1..=16usize
        } else {
            1..=1024usize
        };
        let num_ts_samples_gen = if cfg!(miri) {
            1..=16usize
        } else {
            1..=1024usize
        };
        let num_samples_gen = if cfg!(miri) {
            1..=16usize
        } else {
            1..=1024usize
        };

        bolero::check!()
            .with_generator((obs_len_gen, num_ts_samples_gen, num_samples_gen))
            .and_then(|(observations_len, num_ts_samples, num_samples)| {
                let ts_samples = Vec::<(Sample, Timestamp, Vec<i64>)>::produce()
                    .with()
                    .values((
                        Sample::produce(),
                        Timestamp::produce(),
                        Vec::<i64>::produce().with().len(observations_len),
                    ))
                    .len(num_ts_samples);

                let no_ts_samples = Vec::<(Sample, Vec<i64>)>::produce()
                    .with()
                    .values((
                        Sample::produce(),
                        Vec::<i64>::produce().with().len(observations_len),
                    ))
                    .len(num_samples);

                (observations_len, ts_samples, no_ts_samples)
            })
            .for_each(|(observations_len, ts_samples, no_ts_samples)| {
                fuzz_inner(observations_len, ts_samples, no_ts_samples);
            });
    }

    #[test]
    fn fuzz_with_random_obs_len() {
        let num_ts_samples_gen = if cfg!(miri) {
            1..=16usize
        } else {
            1..=1024usize
        };
        let num_samples_gen = if cfg!(miri) {
            1..=16usize
        } else {
            1..=1024usize
        };

        bolero::check!()
            .with_generator((num_ts_samples_gen, num_samples_gen))
            .and_then(|(num_ts_samples, num_samples)| {
                let ts_samples = Vec::<(Sample, Timestamp, Vec<i64>)>::produce()
                    .with()
                    .values((
                        Sample::produce(),
                        Timestamp::produce(),
                        Vec::<i64>::produce(),
                    ))
                    .len(num_ts_samples);

                let no_ts_samples = Vec::<(Sample, Vec<i64>)>::produce()
                    .with()
                    .values((Sample::produce(), Vec::<i64>::produce()))
                    .len(num_samples);
                (ts_samples, no_ts_samples)
            })
            .for_each(|(ts_samples, no_ts_samples)| {
                // Both generators produce at least one element, so indexing [0] is safe.
                fuzz_inner(&ts_samples[0].2.len(), ts_samples, no_ts_samples);
                fuzz_inner(&no_ts_samples[0].1.len(), ts_samples, no_ts_samples);
            });
    }
}