markov_chain_monte_carlo/diagnostics.rs
1//! Trace recording and export helpers for MCMC diagnostics.
2//!
3//! This module stores numeric observable traces independently from plotting,
4//! notebook rendering, or downstream statistical estimators. A [`Trace`]
5//! contains one row per completed step, a stable [`ChainId`], accept/reject
6//! metadata through [`TraceStepOutcome`], the chain's cached target
7//! log-probability, and caller-defined numeric observable columns.
8//!
9//! # Acceptance signal availability by stepping path
10//!
11//! Populating accept/reject metadata requires an acceptance signal from the
12//! step that produced the row, and the three stepping paths expose different
13//! information:
14//!
15//! - In-place stepping ([`Chain::step_mut`](crate::Chain::step_mut) and
16//! [`Sampler::step_mut`](crate::Sampler::step_mut)) returns `bool`; convert
17//! it with [`TraceStepOutcome::from_proposal_acceptance`].
18//! - Delayed stepping ([`Chain::step_delayed`](crate::Chain::step_delayed) and
19//! [`Sampler::step_delayed`](crate::Sampler::step_delayed)) returns a
20//! [`Step`]/[`StepOutcome`]; convert it with the [`From`] implementations on
21//! [`TraceStepOutcome`].
22//! - By-value stepping ([`Chain::step`](crate::Chain::step) and
23//! [`Sampler::step`](crate::Sampler::step)) returns `Result<(), _>` with no
24//! acceptance information, so a by-value caller cannot truthfully populate
25//! accept/reject. Record from an in-place or delayed step when a trace must
26//! capture acceptance exactly.
27
28use std::collections::BTreeSet;
29use std::error::Error;
30use std::fmt;
31use std::io::{self, Write};
32
33use crate::{Chain, Step, StepOutcome};
34
/// Caller-assigned identifier for one recorded Markov chain.
///
/// The wrapped index is never interpreted by this module: it is compared to
/// select one chain's rows and written out as the `chain_id` CSV column.
#[must_use]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ChainId(usize);
39
40impl ChainId {
41 /// Create a chain identifier from a caller-owned index.
42 ///
43 /// ```
44 /// use markov_chain_monte_carlo::prelude::ChainId;
45 ///
46 /// let id = ChainId::new(2);
47 /// assert_eq!(id.get(), 2);
48 /// ```
49 pub const fn new(id: usize) -> Self {
50 Self(id)
51 }
52
53 /// Return the raw chain identifier.
54 ///
55 /// ```
56 /// use markov_chain_monte_carlo::prelude::ChainId;
57 ///
58 /// assert_eq!(ChainId::new(7).get(), 7);
59 /// ```
60 #[must_use]
61 pub const fn get(self) -> usize {
62 self.0
63 }
64}
65
66impl fmt::Display for ChainId {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 write!(f, "{}", self.0)
69 }
70}
71
/// Accept/reject metadata attached to one trace row.
///
/// The constructors in this module only ever build three combinations:
/// accepted proposal, rejected proposal, and no proposal at all.
#[must_use]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TraceStepOutcome {
    /// Set only when a concrete proposal was accepted.
    accepted: bool,
    /// Set whenever the step drew a concrete proposal, accepted or not.
    proposed: bool,
}
79
80impl TraceStepOutcome {
81 /// A concrete proposal was accepted.
82 ///
83 /// ```
84 /// use markov_chain_monte_carlo::prelude::TraceStepOutcome;
85 ///
86 /// let outcome = TraceStepOutcome::accepted();
87 /// assert!(outcome.is_accepted());
88 /// assert!(outcome.had_proposal());
89 /// ```
90 pub const fn accepted() -> Self {
91 Self {
92 accepted: true,
93 proposed: true,
94 }
95 }
96
97 /// A concrete proposal was rejected by the Metropolis-Hastings draw.
98 ///
99 /// ```
100 /// use markov_chain_monte_carlo::prelude::TraceStepOutcome;
101 ///
102 /// let outcome = TraceStepOutcome::rejected_proposal();
103 /// assert!(!outcome.is_accepted());
104 /// assert!(outcome.had_proposal());
105 /// ```
106 pub const fn rejected_proposal() -> Self {
107 Self {
108 accepted: false,
109 proposed: true,
110 }
111 }
112
113 /// No concrete proposal was available, so the step was a self-loop.
114 ///
115 /// ```
116 /// use markov_chain_monte_carlo::prelude::TraceStepOutcome;
117 ///
118 /// let outcome = TraceStepOutcome::no_proposal();
119 /// assert!(!outcome.is_accepted());
120 /// assert!(!outcome.had_proposal());
121 /// ```
122 pub const fn no_proposal() -> Self {
123 Self {
124 accepted: false,
125 proposed: false,
126 }
127 }
128
129 /// Convert a concrete proposal's boolean acceptance result.
130 ///
131 /// This is the natural adapter for [`crate::Chain::step_mut`] and
132 /// [`crate::Sampler::step_mut`] when the caller knows the in-place proposal
133 /// produced a concrete move. A `false` value is recorded as
134 /// [`Self::rejected_proposal`], not [`Self::no_proposal`].
135 ///
136 /// ```
137 /// use markov_chain_monte_carlo::prelude::TraceStepOutcome;
138 ///
139 /// let accepted = TraceStepOutcome::from_proposal_acceptance(true);
140 /// let rejected = TraceStepOutcome::from_proposal_acceptance(false);
141 ///
142 /// assert!(accepted.is_accepted());
143 /// assert!(accepted.had_proposal());
144 /// assert!(!rejected.is_accepted());
145 /// assert!(rejected.had_proposal());
146 /// ```
147 pub const fn from_proposal_acceptance(accepted: bool) -> Self {
148 if accepted {
149 Self::accepted()
150 } else {
151 Self::rejected_proposal()
152 }
153 }
154
155 /// Whether the step accepted and committed a concrete proposal.
156 #[must_use]
157 pub const fn is_accepted(self) -> bool {
158 self.accepted
159 }
160
161 /// Whether the step included a concrete proposal.
162 #[must_use]
163 pub const fn had_proposal(self) -> bool {
164 self.proposed
165 }
166}
167
168impl From<StepOutcome> for TraceStepOutcome {
169 fn from(outcome: StepOutcome) -> Self {
170 match outcome {
171 StepOutcome::Accepted => Self::accepted(),
172 StepOutcome::RejectedProposal => Self::rejected_proposal(),
173 StepOutcome::NoProposal => Self::no_proposal(),
174 }
175 }
176}
177
178impl<I> From<&Step<I>> for TraceStepOutcome {
179 fn from(step: &Step<I>) -> Self {
180 step.outcome.into()
181 }
182}
183
184/// One recorded post-step trace row.
185#[derive(Debug, Clone, PartialEq)]
186#[must_use]
187pub struct TraceRecord {
188 chain_id: ChainId,
189 step: usize,
190 outcome: TraceStepOutcome,
191 log_prob: f64,
192 observable_values: Vec<f64>,
193}
194
195impl TraceRecord {
196 /// Create one trace record from already-computed observable values.
197 ///
198 /// Most callers should use [`TraceRecorder::record`] so the step number and
199 /// log-probability come from a live [`Chain`]. Use this constructor when
200 /// assembling rows from an external source or merging trace data manually.
201 ///
202 /// ```
203 /// use markov_chain_monte_carlo::prelude::{ChainId, TraceRecord, TraceStepOutcome};
204 ///
205 /// let record = TraceRecord::new(
206 /// ChainId::new(0),
207 /// 10,
208 /// TraceStepOutcome::accepted(),
209 /// -1.25,
210 /// vec![3.0, 0.5],
211 /// );
212 ///
213 /// assert_eq!(record.step(), 10);
214 /// assert_eq!(record.observable_values(), &[3.0, 0.5]);
215 /// ```
216 pub const fn new(
217 chain_id: ChainId,
218 step: usize,
219 outcome: TraceStepOutcome,
220 log_prob: f64,
221 observable_values: Vec<f64>,
222 ) -> Self {
223 Self {
224 chain_id,
225 step,
226 outcome,
227 log_prob,
228 observable_values,
229 }
230 }
231
232 /// Chain identifier for this row.
233 pub const fn chain_id(&self) -> ChainId {
234 self.chain_id
235 }
236
237 /// Completed step number for this row.
238 #[must_use]
239 pub const fn step(&self) -> usize {
240 self.step
241 }
242
243 /// Acceptance/proposal outcome for this row.
244 pub const fn outcome(&self) -> TraceStepOutcome {
245 self.outcome
246 }
247
248 /// Cached target log-probability after this step.
249 #[must_use]
250 pub const fn log_prob(&self) -> f64 {
251 self.log_prob
252 }
253
254 /// Numeric observable values in the same order as the trace headers.
255 #[must_use]
256 pub fn observable_values(&self) -> &[f64] {
257 &self.observable_values
258 }
259}
260
/// Failures reported while building or combining trace data.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TraceError {
    /// An observable name was the empty string.
    EmptyObservableName {
        /// Zero-based position of the offending name.
        index: usize,
    },
    /// The same observable name was supplied twice.
    DuplicateObservableName {
        /// The repeated name.
        name: String,
    },
    /// A row's value count did not match the trace's column count.
    ObservableCountMismatch {
        /// Column count of the receiving trace.
        expected: usize,
        /// Value count carried by the rejected row.
        actual: usize,
    },
    /// Two traces could not be combined because their columns differ.
    ObservableNamesMismatch {
        /// Columns of the receiving trace.
        expected: Vec<String>,
        /// Columns of the trace being appended.
        actual: Vec<String>,
    },
}
290
291impl fmt::Display for TraceError {
292 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293 match self {
294 Self::EmptyObservableName { index } => {
295 write!(f, "observable name at index {index} is empty")
296 }
297 Self::DuplicateObservableName { name } => {
298 write!(f, "observable name {name:?} appears more than once")
299 }
300 Self::ObservableCountMismatch { expected, actual } => write!(
301 f,
302 "trace row has {actual} observable values, expected {expected}"
303 ),
304 Self::ObservableNamesMismatch { expected, actual } => write!(
305 f,
306 "trace observable columns differ: expected {expected:?}, got {actual:?}"
307 ),
308 }
309 }
310}
311
312impl Error for TraceError {}
313
314/// Multi-chain numeric trace with shared observable columns.
315#[derive(Debug, Clone, PartialEq)]
316#[must_use]
317pub struct Trace {
318 observable_names: Vec<String>,
319 records: Vec<TraceRecord>,
320}
321
322impl Trace {
323 /// Create an empty trace with the given observable columns.
324 ///
325 /// ```
326 /// use markov_chain_monte_carlo::prelude::{Trace, TraceError};
327 ///
328 /// let trace = Trace::new(["energy", "magnetization"])?;
329 ///
330 /// assert!(trace.is_empty());
331 /// assert_eq!(trace.observable_names(), &["energy", "magnetization"]);
332 /// # Ok::<(), TraceError>(())
333 /// ```
334 ///
335 /// # Errors
336 ///
337 /// Returns [`TraceError`] if an observable name is empty or duplicated.
338 pub fn new(
339 observable_names: impl IntoIterator<Item = impl Into<String>>,
340 ) -> Result<Self, TraceError> {
341 Ok(Self {
342 observable_names: validate_observable_names(observable_names)?,
343 records: Vec::new(),
344 })
345 }
346
347 /// Observable column names.
348 #[must_use]
349 pub fn observable_names(&self) -> &[String] {
350 &self.observable_names
351 }
352
353 /// Recorded rows.
354 pub fn records(&self) -> &[TraceRecord] {
355 &self.records
356 }
357
358 /// Whether the trace contains no rows.
359 #[must_use]
360 pub const fn is_empty(&self) -> bool {
361 self.records.is_empty()
362 }
363
364 /// Number of recorded rows.
365 #[must_use]
366 pub const fn len(&self) -> usize {
367 self.records.len()
368 }
369
370 /// Append one record.
371 ///
372 /// ```
373 /// use markov_chain_monte_carlo::prelude::{
374 /// ChainId, Trace, TraceError, TraceRecord, TraceStepOutcome,
375 /// };
376 ///
377 /// let mut trace = Trace::new(["energy"])?;
378 /// trace.push(TraceRecord::new(
379 /// ChainId::new(0),
380 /// 1,
381 /// TraceStepOutcome::accepted(),
382 /// -2.0,
383 /// vec![2.0],
384 /// ))?;
385 ///
386 /// assert_eq!(trace.len(), 1);
387 /// # Ok::<(), TraceError>(())
388 /// ```
389 ///
390 /// # Errors
391 ///
392 /// Returns [`TraceError::ObservableCountMismatch`] if the row does not
393 /// match this trace's observable columns.
394 pub fn push(&mut self, record: TraceRecord) -> Result<(), TraceError> {
395 let actual = record.observable_values.len();
396 let expected = self.observable_names.len();
397 if actual != expected {
398 return Err(TraceError::ObservableCountMismatch { expected, actual });
399 }
400 self.records.push(record);
401 Ok(())
402 }
403
404 /// Append all rows from another trace with identical observable columns.
405 ///
406 /// ```
407 /// use markov_chain_monte_carlo::prelude::{Trace, TraceError};
408 ///
409 /// let mut combined = Trace::new(["energy"])?;
410 /// let other = Trace::new(["energy"])?;
411 ///
412 /// combined.extend(other)?;
413 /// assert!(combined.is_empty());
414 /// # Ok::<(), TraceError>(())
415 /// ```
416 ///
417 /// # Errors
418 ///
419 /// Returns [`TraceError::ObservableNamesMismatch`] if the traces use
420 /// different observable columns.
421 pub fn extend(&mut self, other: Self) -> Result<(), TraceError> {
422 if self.observable_names != other.observable_names {
423 return Err(TraceError::ObservableNamesMismatch {
424 expected: self.observable_names.clone(),
425 actual: other.observable_names,
426 });
427 }
428 self.records.extend(other.records);
429 Ok(())
430 }
431
432 /// Iterate over rows for one chain.
433 ///
434 /// ```
435 /// use markov_chain_monte_carlo::prelude::{
436 /// ChainId, Trace, TraceError, TraceRecord, TraceStepOutcome,
437 /// };
438 ///
439 /// let mut trace = Trace::new(["energy"])?;
440 /// trace.push(TraceRecord::new(
441 /// ChainId::new(1),
442 /// 1,
443 /// TraceStepOutcome::accepted(),
444 /// 0.0,
445 /// vec![1.0],
446 /// ))?;
447 ///
448 /// assert_eq!(trace.records_for_chain(ChainId::new(1)).count(), 1);
449 /// assert_eq!(trace.records_for_chain(ChainId::new(0)).count(), 0);
450 /// # Ok::<(), TraceError>(())
451 /// ```
452 pub fn records_for_chain(&self, chain_id: ChainId) -> impl Iterator<Item = &TraceRecord> + '_ {
453 self.records
454 .iter()
455 .filter(move |record| record.chain_id == chain_id)
456 }
457
458 /// Acceptance rate for one chain, counting no-proposal self-loops as
459 /// rejected steps.
460 ///
461 /// Returns `0.0` when no rows exist for `chain_id`.
462 ///
463 /// ```
464 /// use approx::assert_relative_eq;
465 /// use markov_chain_monte_carlo::prelude::{
466 /// ChainId, Trace, TraceError, TraceRecord, TraceStepOutcome,
467 /// };
468 ///
469 /// let mut trace = Trace::new(["energy"])?;
470 /// trace.push(TraceRecord::new(
471 /// ChainId::new(0),
472 /// 1,
473 /// TraceStepOutcome::accepted(),
474 /// 0.0,
475 /// vec![1.0],
476 /// ))?;
477 /// trace.push(TraceRecord::new(
478 /// ChainId::new(0),
479 /// 2,
480 /// TraceStepOutcome::rejected_proposal(),
481 /// 0.0,
482 /// vec![1.0],
483 /// ))?;
484 ///
485 /// assert_relative_eq!(trace.acceptance_rate(ChainId::new(0)), 0.5);
486 /// # Ok::<(), TraceError>(())
487 /// ```
488 #[must_use]
489 #[expect(
490 clippy::cast_precision_loss,
491 reason = "trace lengths are diagnostic counters and fit in f64 for practical runs"
492 )]
493 pub fn acceptance_rate(&self, chain_id: ChainId) -> f64 {
494 let mut accepted = 0_usize;
495 let mut total = 0_usize;
496 for record in self.records_for_chain(chain_id) {
497 total = total.saturating_add(1);
498 if record.outcome.is_accepted() {
499 accepted = accepted.saturating_add(1);
500 }
501 }
502 if total == 0 {
503 0.0
504 } else {
505 accepted as f64 / total as f64
506 }
507 }
508
509 /// Write the trace in CSV format.
510 ///
511 /// The fixed columns are `chain_id`, `step`, `accepted`, `proposed`, and
512 /// `log_prob`, followed by the observable columns supplied when the trace
513 /// was created.
514 ///
515 /// ```
516 /// # use std::io;
517 /// use markov_chain_monte_carlo::prelude::{
518 /// ChainId, Trace, TraceError, TraceRecord, TraceStepOutcome,
519 /// };
520 /// # #[derive(Debug)]
521 /// # enum ExampleError {
522 /// # Trace(TraceError),
523 /// # Io(io::Error),
524 /// # }
525 /// # impl From<TraceError> for ExampleError {
526 /// # fn from(err: TraceError) -> Self { Self::Trace(err) }
527 /// # }
528 /// # impl From<io::Error> for ExampleError {
529 /// # fn from(err: io::Error) -> Self { Self::Io(err) }
530 /// # }
531 ///
532 /// let mut trace = Trace::new(["energy"])?;
533 /// trace.push(TraceRecord::new(
534 /// ChainId::new(0),
535 /// 1,
536 /// TraceStepOutcome::accepted(),
537 /// -1.0,
538 /// vec![1.5],
539 /// ))?;
540 ///
541 /// let mut csv = Vec::new();
542 /// trace.write_csv(&mut csv)?;
543 ///
544 /// assert_eq!(
545 /// csv,
546 /// b"chain_id,step,accepted,proposed,log_prob,energy\n0,1,true,true,-1,1.5\n",
547 /// );
548 /// # Ok::<(), ExampleError>(())
549 /// ```
550 ///
551 /// # Errors
552 ///
553 /// Returns any I/O error reported by `writer`.
554 pub fn write_csv(&self, mut writer: impl Write) -> io::Result<()> {
555 writer.write_all(b"chain_id,step,accepted,proposed,log_prob")?;
556 for name in &self.observable_names {
557 writer.write_all(b",")?;
558 write_csv_field(&mut writer, name)?;
559 }
560 writer.write_all(b"\n")?;
561
562 for record in &self.records {
563 write!(
564 writer,
565 "{},{},{},{},{}",
566 record.chain_id.get(),
567 record.step,
568 record.outcome.is_accepted(),
569 record.outcome.had_proposal(),
570 record.log_prob
571 )?;
572 for value in &record.observable_values {
573 write!(writer, ",{value}")?;
574 }
575 writer.write_all(b"\n")?;
576 }
577 Ok(())
578 }
579}
580
581/// Recorder for one chain within a multi-chain trace.
582#[derive(Debug, Clone, PartialEq)]
583#[must_use]
584pub struct TraceRecorder {
585 chain_id: ChainId,
586 trace: Trace,
587}
588
589impl TraceRecorder {
590 /// Create a recorder for one chain.
591 ///
592 /// ```
593 /// use markov_chain_monte_carlo::prelude::{ChainId, TraceError, TraceRecorder};
594 ///
595 /// let recorder = TraceRecorder::new(ChainId::new(3), ["energy"])?;
596 ///
597 /// assert_eq!(recorder.chain_id(), ChainId::new(3));
598 /// assert_eq!(recorder.trace().observable_names(), &["energy"]);
599 /// # Ok::<(), TraceError>(())
600 /// ```
601 ///
602 /// # Errors
603 ///
604 /// Returns [`TraceError`] if an observable name is empty or duplicated.
605 pub fn new(
606 chain_id: ChainId,
607 observable_names: impl IntoIterator<Item = impl Into<String>>,
608 ) -> Result<Self, TraceError> {
609 Ok(Self {
610 chain_id,
611 trace: Trace::new(observable_names)?,
612 })
613 }
614
615 /// Chain identifier used by this recorder.
616 pub const fn chain_id(&self) -> ChainId {
617 self.chain_id
618 }
619
620 /// Borrow the accumulated trace.
621 pub const fn trace(&self) -> &Trace {
622 &self.trace
623 }
624
625 /// Consume the recorder and return the accumulated trace.
626 pub fn into_trace(self) -> Trace {
627 self.trace
628 }
629
630 /// Record the current state of `chain` after a completed step.
631 ///
632 /// Observable values must be supplied in the same order as the recorder's
633 /// observable names.
634 ///
635 /// ```
636 /// use markov_chain_monte_carlo::prelude::{
637 /// Chain, ChainId, McmcError, Target, TraceError, TraceRecorder, TraceStepOutcome,
638 /// };
639 ///
640 /// # #[derive(Debug)]
641 /// # enum ExampleError {
642 /// # Mcmc(McmcError),
643 /// # Trace(TraceError),
644 /// # }
645 /// # impl From<McmcError> for ExampleError {
646 /// # fn from(err: McmcError) -> Self { Self::Mcmc(err) }
647 /// # }
648 /// # impl From<TraceError> for ExampleError {
649 /// # fn from(err: TraceError) -> Self { Self::Trace(err) }
650 /// # }
651 ///
652 /// struct Flat;
653 /// impl Target<i32> for Flat {
654 /// fn log_prob(&self, _: &i32) -> f64 { 0.0 }
655 /// }
656 ///
657 /// let chain = Chain::new(4, &Flat)?;
658 /// let mut recorder = TraceRecorder::new(ChainId::new(0), ["value"])?;
659 ///
660 /// recorder.record(&chain, TraceStepOutcome::accepted(), [f64::from(*chain.state())])?;
661 ///
662 /// assert_eq!(recorder.trace().records()[0].observable_values(), &[4.0]);
663 /// # Ok::<(), ExampleError>(())
664 /// ```
665 ///
666 /// # Errors
667 ///
668 /// Returns [`TraceError::ObservableCountMismatch`] if `observable_values`
669 /// does not match the recorder header.
670 pub fn record<S>(
671 &mut self,
672 chain: &Chain<S>,
673 outcome: TraceStepOutcome,
674 observable_values: impl IntoIterator<Item = f64>,
675 ) -> Result<(), TraceError> {
676 let values = observable_values.into_iter().collect();
677 let record = TraceRecord::new(
678 self.chain_id,
679 chain.total_steps(),
680 outcome,
681 chain.log_prob(),
682 values,
683 );
684 self.trace.push(record)
685 }
686}
687
688/// Validate trace observable names before storing them as CSV/header metadata.
689///
690/// This centralizes the non-empty and uniqueness checks so every trace row can
691/// rely on stable, unambiguous observable columns.
692fn validate_observable_names(
693 observable_names: impl IntoIterator<Item = impl Into<String>>,
694) -> Result<Vec<String>, TraceError> {
695 let names: Vec<_> = observable_names.into_iter().map(Into::into).collect();
696 let mut seen = BTreeSet::new();
697 for (index, name) in names.iter().enumerate() {
698 if name.is_empty() {
699 return Err(TraceError::EmptyObservableName { index });
700 }
701 if !seen.insert(name.as_str()) {
702 return Err(TraceError::DuplicateObservableName { name: name.clone() });
703 }
704 }
705 Ok(names)
706}
707
/// Write one CSV field, quoting it only when CSV syntax requires escaping.
///
/// A field is wrapped in double quotes when it contains a comma, a double
/// quote, or a line break (`\n` or `\r`). Embedded double quotes are doubled,
/// as RFC 4180 specifies. All other fields are written verbatim.
///
/// The quoted path writes whole runs of text between quote characters rather
/// than one byte at a time. An unbuffered `writer` therefore sees a few
/// `write` calls per field instead of one per byte.
///
/// # Errors
///
/// Returns the first I/O error reported by `writer`.
fn write_csv_field(writer: &mut impl Write, value: &str) -> io::Result<()> {
    if !value.contains([',', '"', '\n', '\r']) {
        return writer.write_all(value.as_bytes());
    }

    writer.write_all(b"\"")?;
    for (index, segment) in value.split('"').enumerate() {
        // `split` drops the separators; re-emit each one as an escaped `""`.
        if index > 0 {
            writer.write_all(b"\"\"")?;
        }
        writer.write_all(segment.as_bytes())?;
    }
    writer.write_all(b"\"")
}
725
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;
    use crate::{McmcError, Target};

    /// Target with a constant log-probability, so every state is valid.
    struct Flat;

    impl Target<i32> for Flat {
        fn log_prob(&self, _: &i32) -> f64 {
            0.0
        }
    }

    /// Single-observable row for `chain` at `step` carrying `outcome`.
    fn row(chain: usize, step: usize, outcome: TraceStepOutcome) -> TraceRecord {
        TraceRecord::new(ChainId::new(chain), step, outcome, 0.0, vec![0.0])
    }

    #[test]
    fn recorder_collects_chain_rows() -> Result<(), TraceError> {
        let chain = Chain::new(3, &Flat).expect("flat target has valid log-probability");
        let mut recorder = TraceRecorder::new(ChainId::new(2), ["energy", "magnetization"])?;

        recorder.record(&chain, TraceStepOutcome::accepted(), [1.25_f64, -0.5_f64])?;

        let trace = recorder.trace();
        assert_eq!(trace.len(), 1);
        assert_eq!(trace.records()[0].chain_id(), ChainId::new(2));
        assert_eq!(trace.records()[0].step(), 0);
        assert_eq!(trace.records()[0].outcome(), TraceStepOutcome::accepted());
        assert_relative_eq!(trace.records()[0].log_prob(), 0.0);
        assert_eq!(trace.records()[0].observable_values(), &[1.25, -0.5]);
        assert_relative_eq!(trace.acceptance_rate(ChainId::new(2)), 1.0);
        Ok(())
    }

    #[test]
    fn recorder_accessors_return_metadata_and_trace() -> Result<(), TraceError> {
        let recorder = TraceRecorder::new(ChainId::new(7), ["energy"])?;

        assert_eq!(recorder.chain_id(), ChainId::new(7));
        assert_eq!(recorder.trace().observable_names(), &["energy"]);

        let trace = recorder.into_trace();
        assert_eq!(trace.observable_names(), &["energy"]);
        assert!(trace.is_empty());
        Ok(())
    }

    #[test]
    fn recorder_rejects_wrong_width_without_recording() -> Result<(), TraceError> {
        let chain = Chain::new(1, &Flat).expect("flat target has valid log-probability");
        let mut recorder = TraceRecorder::new(ChainId::new(0), ["energy", "magnetization"])?;

        assert_eq!(
            recorder
                .record(&chain, TraceStepOutcome::accepted(), [1.0_f64])
                .unwrap_err(),
            TraceError::ObservableCountMismatch {
                expected: 2,
                actual: 1
            }
        );
        assert!(recorder.trace().is_empty());
        Ok(())
    }

    #[test]
    fn chain_id_display_formats_raw_identifier() {
        assert_eq!(ChainId::new(42).to_string(), "42");
    }

    #[test]
    fn trace_rejects_bad_headers_and_row_widths() {
        assert_eq!(
            Trace::new(["energy", ""]).unwrap_err(),
            TraceError::EmptyObservableName { index: 1 }
        );
        assert_eq!(
            Trace::new(["energy", "energy"]).unwrap_err(),
            TraceError::DuplicateObservableName {
                name: "energy".to_owned()
            }
        );

        let mut trace = Trace::new(["energy"]).unwrap();
        let record = TraceRecord::new(
            ChainId::new(0),
            1,
            TraceStepOutcome::accepted(),
            0.0,
            vec![1.0, 2.0],
        );
        assert_eq!(
            trace.push(record).unwrap_err(),
            TraceError::ObservableCountMismatch {
                expected: 1,
                actual: 2
            }
        );
        assert!(trace.is_empty());
    }

    #[test]
    fn trace_error_display_messages_include_context() {
        assert_eq!(
            TraceError::EmptyObservableName { index: 3 }.to_string(),
            "observable name at index 3 is empty"
        );
        assert_eq!(
            TraceError::DuplicateObservableName {
                name: "energy".to_owned()
            }
            .to_string(),
            "observable name \"energy\" appears more than once"
        );
        assert_eq!(
            TraceError::ObservableCountMismatch {
                expected: 2,
                actual: 1
            }
            .to_string(),
            "trace row has 1 observable values, expected 2"
        );
        assert_eq!(
            TraceError::ObservableNamesMismatch {
                expected: vec!["energy".to_owned()],
                actual: vec!["magnetization".to_owned()]
            }
            .to_string(),
            "trace observable columns differ: expected [\"energy\"], got [\"magnetization\"]"
        );
    }

    #[test]
    fn trace_extend_rejects_mismatched_headers() {
        let mut trace = Trace::new(["energy"]).expect("valid observable name");
        let other = Trace::new(["energy", "magnetization"]).expect("valid observable names");

        assert_eq!(
            trace.extend(other).unwrap_err(),
            TraceError::ObservableNamesMismatch {
                expected: vec!["energy".to_owned()],
                actual: vec!["energy".to_owned(), "magnetization".to_owned()]
            }
        );
    }

    #[test]
    fn trace_extend_appends_matching_records() -> Result<(), TraceError> {
        let mut trace = Trace::new(["energy"])?;
        let mut other = Trace::new(["energy"])?;
        other
            .push(TraceRecord::new(
                ChainId::new(1),
                9,
                TraceStepOutcome::accepted(),
                -3.5,
                vec![-7.0],
            ))
            .expect("row width matches");

        trace.extend(other)?;

        assert_eq!(trace.len(), 1);
        assert_eq!(trace.records()[0].chain_id(), ChainId::new(1));
        assert_eq!(trace.records()[0].step(), 9);
        assert_eq!(trace.records()[0].observable_values(), &[-7.0]);
        Ok(())
    }

    #[test]
    fn trace_extend_reports_equal_width_name_mismatch() {
        let mut trace = Trace::new(["energy"]).expect("valid observable name");
        let other = Trace::new(["magnetization"]).expect("valid observable name");

        assert_eq!(
            trace.extend(other).unwrap_err(),
            TraceError::ObservableNamesMismatch {
                expected: vec!["energy".to_owned()],
                actual: vec!["magnetization".to_owned()]
            }
        );
    }

    #[test]
    fn records_for_chain_filters_and_keeps_order() -> Result<(), TraceError> {
        let mut trace = Trace::new(["energy"])?;
        trace.push(row(0, 1, TraceStepOutcome::accepted()))?;
        trace.push(row(1, 1, TraceStepOutcome::rejected_proposal()))?;
        trace.push(row(0, 2, TraceStepOutcome::no_proposal()))?;

        let steps: Vec<usize> = trace
            .records_for_chain(ChainId::new(0))
            .map(TraceRecord::step)
            .collect();
        assert_eq!(steps, vec![1, 2]);
        assert_eq!(trace.records_for_chain(ChainId::new(1)).count(), 1);
        assert_eq!(trace.records_for_chain(ChainId::new(2)).count(), 0);
        Ok(())
    }

    #[test]
    fn acceptance_rate_counts_no_proposal_as_rejection_per_chain() -> Result<(), TraceError> {
        let mut trace = Trace::new(["energy"])?;
        trace.push(row(0, 1, TraceStepOutcome::accepted()))?;
        trace.push(row(0, 2, TraceStepOutcome::no_proposal()))?;
        trace.push(row(0, 3, TraceStepOutcome::rejected_proposal()))?;
        trace.push(row(0, 4, TraceStepOutcome::accepted()))?;
        trace.push(row(1, 1, TraceStepOutcome::no_proposal()))?;

        assert_relative_eq!(trace.acceptance_rate(ChainId::new(0)), 0.5);
        assert_relative_eq!(trace.acceptance_rate(ChainId::new(1)), 0.0);
        Ok(())
    }

    #[test]
    fn trace_writes_csv() -> Result<(), io::Error> {
        let mut trace = Trace::new(["energy", "quoted,name"]).expect("valid observable names");
        trace
            .push(TraceRecord::new(
                ChainId::new(0),
                1,
                TraceStepOutcome::rejected_proposal(),
                -2.5,
                vec![3.0, 4.0],
            ))
            .expect("row width matches");

        let mut csv = Vec::new();
        trace.write_csv(&mut csv)?;

        let csv = String::from_utf8(csv).expect("CSV is valid UTF-8");
        assert_eq!(
            csv,
            "chain_id,step,accepted,proposed,log_prob,energy,\"quoted,name\"\n\
             0,1,false,true,-2.5,3,4\n"
        );
        Ok(())
    }

    #[test]
    fn trace_csv_marks_no_proposal_rows() -> Result<(), io::Error> {
        let mut trace = Trace::new(["energy"]).expect("valid observable name");
        trace
            .push(TraceRecord::new(
                ChainId::new(4),
                2,
                TraceStepOutcome::no_proposal(),
                -0.5,
                vec![0.25],
            ))
            .expect("row width matches");

        let mut csv = Vec::new();
        trace.write_csv(&mut csv)?;

        let csv = String::from_utf8(csv).expect("CSV is valid UTF-8");
        assert_eq!(
            csv,
            "chain_id,step,accepted,proposed,log_prob,energy\n4,2,false,false,-0.5,0.25\n"
        );
        Ok(())
    }

    #[test]
    fn trace_without_observables_writes_fixed_columns_only() -> Result<(), io::Error> {
        let mut trace = Trace::new(Vec::<String>::new()).expect("no names is a valid header");
        trace
            .push(TraceRecord::new(
                ChainId::new(0),
                1,
                TraceStepOutcome::accepted(),
                1.5,
                Vec::new(),
            ))
            .expect("empty row matches empty header");

        let mut csv = Vec::new();
        trace.write_csv(&mut csv)?;

        let csv = String::from_utf8(csv).expect("CSV is valid UTF-8");
        assert_eq!(
            csv,
            "chain_id,step,accepted,proposed,log_prob\n0,1,true,true,1.5\n"
        );
        Ok(())
    }

    #[test]
    fn trace_csv_escapes_quoted_headers() -> Result<(), io::Error> {
        let mut trace = Trace::new([r#"quote"field"#]).expect("valid observable name");
        trace
            .push(TraceRecord::new(
                ChainId::new(0),
                1,
                TraceStepOutcome::accepted(),
                0.0,
                vec![1.0],
            ))
            .expect("row width matches");

        let mut csv = Vec::new();
        trace.write_csv(&mut csv)?;

        let csv = String::from_utf8(csv).expect("CSV is valid UTF-8");
        assert_eq!(
            csv,
            "chain_id,step,accepted,proposed,log_prob,\"quote\"\"field\"\n\
             0,1,true,true,0,1\n"
        );
        Ok(())
    }

    #[test]
    fn trace_csv_escapes_adjacent_quotes_and_line_breaks() -> Result<(), io::Error> {
        let trace = Trace::new(["\"", "a\"\"b", "line\nbreak", "carriage\rreturn"])
            .expect("valid observable names");

        let mut csv = Vec::new();
        trace.write_csv(&mut csv)?;

        let csv = String::from_utf8(csv).expect("CSV is valid UTF-8");
        assert_eq!(
            csv,
            "chain_id,step,accepted,proposed,log_prob,\"\"\"\",\"a\"\"\"\"b\",\
             \"line\nbreak\",\"carriage\rreturn\"\n"
        );
        Ok(())
    }

    #[test]
    fn trace_write_csv_propagates_row_write_errors() {
        struct FailAfter {
            accepted: usize,
            limit: usize,
        }

        impl Write for FailAfter {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                if self.accepted >= self.limit {
                    return Err(io::Error::other("writer full"));
                }
                let remaining = self.limit - self.accepted;
                let written = remaining.min(buf.len());
                self.accepted += written;
                Ok(written)
            }

            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }

        let mut trace = Trace::new(["energy"]).expect("valid observable name");
        trace
            .push(TraceRecord::new(
                ChainId::new(0),
                1,
                TraceStepOutcome::accepted(),
                0.0,
                vec![1.0],
            ))
            .expect("row width matches");
        let mut writer = FailAfter {
            accepted: 0,
            limit: "chain_id,step,accepted,proposed,log_prob,energy\n".len(),
        };

        let err = trace.write_csv(&mut writer).unwrap_err();

        assert_eq!(err.kind(), io::ErrorKind::Other);
        assert_eq!(err.to_string(), "writer full");
        writer.flush().expect("flush is infallible for test writer");
    }

    #[test]
    fn delayed_step_outcome_converts_to_trace_outcome() {
        assert_eq!(
            TraceStepOutcome::from(StepOutcome::Accepted),
            TraceStepOutcome::accepted()
        );
        assert_eq!(
            TraceStepOutcome::from(StepOutcome::RejectedProposal),
            TraceStepOutcome::rejected_proposal()
        );
        assert_eq!(
            TraceStepOutcome::from(StepOutcome::NoProposal),
            TraceStepOutcome::no_proposal()
        );
    }

    #[test]
    fn delayed_step_reference_converts_to_trace_outcome() {
        let step = Step {
            outcome: StepOutcome::NoProposal,
            info: None::<()>,
            log_prob_before: 0.0,
            log_prob_after: None,
            log_alpha: None,
        };

        assert_eq!(
            TraceStepOutcome::from(&step),
            TraceStepOutcome::no_proposal()
        );
    }

    #[test]
    fn proposal_acceptance_adapter_preserves_proposal_presence() {
        assert_eq!(
            TraceStepOutcome::from_proposal_acceptance(true),
            TraceStepOutcome::accepted()
        );
        assert_eq!(
            TraceStepOutcome::from_proposal_acceptance(false),
            TraceStepOutcome::rejected_proposal()
        );
    }

    #[test]
    fn empty_chain_acceptance_rate_is_zero() -> Result<(), TraceError> {
        let trace = Trace::new(["energy"])?;

        assert_relative_eq!(trace.acceptance_rate(ChainId::new(0)), 0.0);
        Ok(())
    }

    #[test]
    fn chain_construction_still_reports_mcmc_errors() {
        struct NanTarget;

        impl Target<i32> for NanTarget {
            fn log_prob(&self, _: &i32) -> f64 {
                f64::NAN
            }
        }

        assert_eq!(
            Chain::new(0, &NanTarget).unwrap_err(),
            McmcError::NanInitialLogProb
        );
    }
}
1051}