Skip to main content

markov_chain_monte_carlo/
diagnostics.rs

1//! Trace recording and export helpers for MCMC diagnostics.
2//!
3//! This module stores numeric observable traces independently from plotting,
4//! notebook rendering, or downstream statistical estimators.  A [`Trace`]
5//! contains one row per completed step, a stable [`ChainId`], accept/reject
6//! metadata through [`TraceStepOutcome`], the chain's cached target
7//! log-probability, and caller-defined numeric observable columns.
8//!
9//! # Acceptance signal availability by stepping path
10//!
11//! Populating accept/reject metadata requires an acceptance signal from the
12//! step that produced the row, and the three stepping paths expose different
13//! information:
14//!
15//! - In-place stepping ([`Chain::step_mut`](crate::Chain::step_mut) and
16//!   [`Sampler::step_mut`](crate::Sampler::step_mut)) returns `bool`; convert
17//!   it with [`TraceStepOutcome::from_proposal_acceptance`].
18//! - Delayed stepping ([`Chain::step_delayed`](crate::Chain::step_delayed) and
19//!   [`Sampler::step_delayed`](crate::Sampler::step_delayed)) returns a
20//!   [`Step`]/[`StepOutcome`]; convert it with the [`From`] implementations on
21//!   [`TraceStepOutcome`].
22//! - By-value stepping ([`Chain::step`](crate::Chain::step) and
23//!   [`Sampler::step`](crate::Sampler::step)) returns `Result<(), _>` with no
24//!   acceptance information, so a by-value caller cannot truthfully populate
25//!   accept/reject. Record from an in-place or delayed step when a trace must
26//!   capture acceptance exactly.
27
28use std::collections::BTreeSet;
29use std::error::Error;
30use std::fmt;
31use std::io::{self, Write};
32
33use crate::{Chain, Step, StepOutcome};
34
/// Opaque, stable label for a single recorded Markov chain.
///
/// The wrapped index is owned by the caller; this type only stores it and
/// makes it comparable, hashable, and printable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[must_use]
pub struct ChainId(usize);

impl ChainId {
    /// Wrap a caller-owned index as a chain identifier.
    ///
    /// ```
    /// use markov_chain_monte_carlo::prelude::ChainId;
    ///
    /// let id = ChainId::new(2);
    /// assert_eq!(id.get(), 2);
    /// ```
    pub const fn new(id: usize) -> Self {
        Self(id)
    }

    /// Unwrap the identifier back into the raw index it was built from.
    ///
    /// ```
    /// use markov_chain_monte_carlo::prelude::ChainId;
    ///
    /// assert_eq!(ChainId::new(7).get(), 7);
    /// ```
    #[must_use]
    pub const fn get(self) -> usize {
        let Self(raw) = self;
        raw
    }
}

impl fmt::Display for ChainId {
    /// Print the identifier as its bare numeric index.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw = self.get();
        write!(f, "{raw}")
    }
}
71
/// What happened to the proposal (if any) behind one trace row.
///
/// Only three combinations are constructible: accepted proposal, rejected
/// proposal, and no proposal at all.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[must_use]
pub struct TraceStepOutcome {
    accepted: bool,
    proposed: bool,
}

impl TraceStepOutcome {
    /// The step had a concrete proposal and it was accepted.
    ///
    /// ```
    /// use markov_chain_monte_carlo::prelude::TraceStepOutcome;
    ///
    /// let outcome = TraceStepOutcome::accepted();
    /// assert!(outcome.is_accepted());
    /// assert!(outcome.had_proposal());
    /// ```
    pub const fn accepted() -> Self {
        Self::from_proposal_acceptance(true)
    }

    /// The step had a concrete proposal and the Metropolis-Hastings draw
    /// rejected it.
    ///
    /// ```
    /// use markov_chain_monte_carlo::prelude::TraceStepOutcome;
    ///
    /// let outcome = TraceStepOutcome::rejected_proposal();
    /// assert!(!outcome.is_accepted());
    /// assert!(outcome.had_proposal());
    /// ```
    pub const fn rejected_proposal() -> Self {
        Self::from_proposal_acceptance(false)
    }

    /// The step had no concrete proposal and therefore was a self-loop.
    ///
    /// ```
    /// use markov_chain_monte_carlo::prelude::TraceStepOutcome;
    ///
    /// let outcome = TraceStepOutcome::no_proposal();
    /// assert!(!outcome.is_accepted());
    /// assert!(!outcome.had_proposal());
    /// ```
    pub const fn no_proposal() -> Self {
        Self {
            proposed: false,
            accepted: false,
        }
    }

    /// Build an outcome from the boolean acceptance result of a concrete
    /// proposal.
    ///
    /// Intended as the adapter for [`crate::Chain::step_mut`] and
    /// [`crate::Sampler::step_mut`] when the caller knows the in-place step
    /// produced a concrete move.  Passing `false` yields
    /// [`Self::rejected_proposal`], never [`Self::no_proposal`].
    ///
    /// ```
    /// use markov_chain_monte_carlo::prelude::TraceStepOutcome;
    ///
    /// let accepted = TraceStepOutcome::from_proposal_acceptance(true);
    /// let rejected = TraceStepOutcome::from_proposal_acceptance(false);
    ///
    /// assert!(accepted.is_accepted());
    /// assert!(accepted.had_proposal());
    /// assert!(!rejected.is_accepted());
    /// assert!(rejected.had_proposal());
    /// ```
    pub const fn from_proposal_acceptance(accepted: bool) -> Self {
        // A proposal is always present on this path; only acceptance varies.
        Self {
            accepted,
            proposed: true,
        }
    }

    /// `true` when a concrete proposal was accepted and committed.
    #[must_use]
    pub const fn is_accepted(self) -> bool {
        let Self { accepted, .. } = self;
        accepted
    }

    /// `true` when the step involved a concrete proposal at all.
    #[must_use]
    pub const fn had_proposal(self) -> bool {
        let Self { proposed, .. } = self;
        proposed
    }
}
167
168impl From<StepOutcome> for TraceStepOutcome {
169    fn from(outcome: StepOutcome) -> Self {
170        match outcome {
171            StepOutcome::Accepted => Self::accepted(),
172            StepOutcome::RejectedProposal => Self::rejected_proposal(),
173            StepOutcome::NoProposal => Self::no_proposal(),
174        }
175    }
176}
177
178impl<I> From<&Step<I>> for TraceStepOutcome {
179    fn from(step: &Step<I>) -> Self {
180        step.outcome.into()
181    }
182}
183
184/// One recorded post-step trace row.
185#[derive(Debug, Clone, PartialEq)]
186#[must_use]
187pub struct TraceRecord {
188    chain_id: ChainId,
189    step: usize,
190    outcome: TraceStepOutcome,
191    log_prob: f64,
192    observable_values: Vec<f64>,
193}
194
195impl TraceRecord {
196    /// Create one trace record from already-computed observable values.
197    ///
198    /// Most callers should use [`TraceRecorder::record`] so the step number and
199    /// log-probability come from a live [`Chain`].  Use this constructor when
200    /// assembling rows from an external source or merging trace data manually.
201    ///
202    /// ```
203    /// use markov_chain_monte_carlo::prelude::{ChainId, TraceRecord, TraceStepOutcome};
204    ///
205    /// let record = TraceRecord::new(
206    ///     ChainId::new(0),
207    ///     10,
208    ///     TraceStepOutcome::accepted(),
209    ///     -1.25,
210    ///     vec![3.0, 0.5],
211    /// );
212    ///
213    /// assert_eq!(record.step(), 10);
214    /// assert_eq!(record.observable_values(), &[3.0, 0.5]);
215    /// ```
216    pub const fn new(
217        chain_id: ChainId,
218        step: usize,
219        outcome: TraceStepOutcome,
220        log_prob: f64,
221        observable_values: Vec<f64>,
222    ) -> Self {
223        Self {
224            chain_id,
225            step,
226            outcome,
227            log_prob,
228            observable_values,
229        }
230    }
231
232    /// Chain identifier for this row.
233    pub const fn chain_id(&self) -> ChainId {
234        self.chain_id
235    }
236
237    /// Completed step number for this row.
238    #[must_use]
239    pub const fn step(&self) -> usize {
240        self.step
241    }
242
243    /// Acceptance/proposal outcome for this row.
244    pub const fn outcome(&self) -> TraceStepOutcome {
245        self.outcome
246    }
247
248    /// Cached target log-probability after this step.
249    #[must_use]
250    pub const fn log_prob(&self) -> f64 {
251        self.log_prob
252    }
253
254    /// Numeric observable values in the same order as the trace headers.
255    #[must_use]
256    pub fn observable_values(&self) -> &[f64] {
257        &self.observable_values
258    }
259}
260
/// Failures that can occur while building or combining trace data.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TraceError {
    /// One of the observable names was the empty string.
    EmptyObservableName {
        /// Zero-based position of the offending name.
        index: usize,
    },
    /// The same observable name was supplied more than once.
    DuplicateObservableName {
        /// The name that was repeated.
        name: String,
    },
    /// A row carried a different number of values than the header has columns.
    ObservableCountMismatch {
        /// Number of values the header requires.
        expected: usize,
        /// Number of values the row actually carried.
        actual: usize,
    },
    /// Two traces were combined although their observable columns differ.
    ObservableNamesMismatch {
        /// Columns of the trace receiving the rows.
        expected: Vec<String>,
        /// Columns of the trace being appended.
        actual: Vec<String>,
    },
}

impl fmt::Display for TraceError {
    /// Render a one-line, human-readable description including the
    /// variant's context fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ObservableNamesMismatch { expected, actual } => {
                write!(
                    f,
                    "trace observable columns differ: expected {expected:?}, got {actual:?}"
                )
            }
            Self::ObservableCountMismatch { expected, actual } => {
                write!(
                    f,
                    "trace row has {actual} observable values, expected {expected}"
                )
            }
            Self::DuplicateObservableName { name } => {
                write!(f, "observable name {name:?} appears more than once")
            }
            Self::EmptyObservableName { index } => {
                write!(f, "observable name at index {index} is empty")
            }
        }
    }
}

impl Error for TraceError {}
313
314/// Multi-chain numeric trace with shared observable columns.
315#[derive(Debug, Clone, PartialEq)]
316#[must_use]
317pub struct Trace {
318    observable_names: Vec<String>,
319    records: Vec<TraceRecord>,
320}
321
322impl Trace {
323    /// Create an empty trace with the given observable columns.
324    ///
325    /// ```
326    /// use markov_chain_monte_carlo::prelude::{Trace, TraceError};
327    ///
328    /// let trace = Trace::new(["energy", "magnetization"])?;
329    ///
330    /// assert!(trace.is_empty());
331    /// assert_eq!(trace.observable_names(), &["energy", "magnetization"]);
332    /// # Ok::<(), TraceError>(())
333    /// ```
334    ///
335    /// # Errors
336    ///
337    /// Returns [`TraceError`] if an observable name is empty or duplicated.
338    pub fn new(
339        observable_names: impl IntoIterator<Item = impl Into<String>>,
340    ) -> Result<Self, TraceError> {
341        Ok(Self {
342            observable_names: validate_observable_names(observable_names)?,
343            records: Vec::new(),
344        })
345    }
346
347    /// Observable column names.
348    #[must_use]
349    pub fn observable_names(&self) -> &[String] {
350        &self.observable_names
351    }
352
353    /// Recorded rows.
354    pub fn records(&self) -> &[TraceRecord] {
355        &self.records
356    }
357
358    /// Whether the trace contains no rows.
359    #[must_use]
360    pub const fn is_empty(&self) -> bool {
361        self.records.is_empty()
362    }
363
364    /// Number of recorded rows.
365    #[must_use]
366    pub const fn len(&self) -> usize {
367        self.records.len()
368    }
369
370    /// Append one record.
371    ///
372    /// ```
373    /// use markov_chain_monte_carlo::prelude::{
374    ///     ChainId, Trace, TraceError, TraceRecord, TraceStepOutcome,
375    /// };
376    ///
377    /// let mut trace = Trace::new(["energy"])?;
378    /// trace.push(TraceRecord::new(
379    ///     ChainId::new(0),
380    ///     1,
381    ///     TraceStepOutcome::accepted(),
382    ///     -2.0,
383    ///     vec![2.0],
384    /// ))?;
385    ///
386    /// assert_eq!(trace.len(), 1);
387    /// # Ok::<(), TraceError>(())
388    /// ```
389    ///
390    /// # Errors
391    ///
392    /// Returns [`TraceError::ObservableCountMismatch`] if the row does not
393    /// match this trace's observable columns.
394    pub fn push(&mut self, record: TraceRecord) -> Result<(), TraceError> {
395        let actual = record.observable_values.len();
396        let expected = self.observable_names.len();
397        if actual != expected {
398            return Err(TraceError::ObservableCountMismatch { expected, actual });
399        }
400        self.records.push(record);
401        Ok(())
402    }
403
404    /// Append all rows from another trace with identical observable columns.
405    ///
406    /// ```
407    /// use markov_chain_monte_carlo::prelude::{Trace, TraceError};
408    ///
409    /// let mut combined = Trace::new(["energy"])?;
410    /// let other = Trace::new(["energy"])?;
411    ///
412    /// combined.extend(other)?;
413    /// assert!(combined.is_empty());
414    /// # Ok::<(), TraceError>(())
415    /// ```
416    ///
417    /// # Errors
418    ///
419    /// Returns [`TraceError::ObservableNamesMismatch`] if the traces use
420    /// different observable columns.
421    pub fn extend(&mut self, other: Self) -> Result<(), TraceError> {
422        if self.observable_names != other.observable_names {
423            return Err(TraceError::ObservableNamesMismatch {
424                expected: self.observable_names.clone(),
425                actual: other.observable_names,
426            });
427        }
428        self.records.extend(other.records);
429        Ok(())
430    }
431
432    /// Iterate over rows for one chain.
433    ///
434    /// ```
435    /// use markov_chain_monte_carlo::prelude::{
436    ///     ChainId, Trace, TraceError, TraceRecord, TraceStepOutcome,
437    /// };
438    ///
439    /// let mut trace = Trace::new(["energy"])?;
440    /// trace.push(TraceRecord::new(
441    ///     ChainId::new(1),
442    ///     1,
443    ///     TraceStepOutcome::accepted(),
444    ///     0.0,
445    ///     vec![1.0],
446    /// ))?;
447    ///
448    /// assert_eq!(trace.records_for_chain(ChainId::new(1)).count(), 1);
449    /// assert_eq!(trace.records_for_chain(ChainId::new(0)).count(), 0);
450    /// # Ok::<(), TraceError>(())
451    /// ```
452    pub fn records_for_chain(&self, chain_id: ChainId) -> impl Iterator<Item = &TraceRecord> + '_ {
453        self.records
454            .iter()
455            .filter(move |record| record.chain_id == chain_id)
456    }
457
458    /// Acceptance rate for one chain, counting no-proposal self-loops as
459    /// rejected steps.
460    ///
461    /// Returns `0.0` when no rows exist for `chain_id`.
462    ///
463    /// ```
464    /// use approx::assert_relative_eq;
465    /// use markov_chain_monte_carlo::prelude::{
466    ///     ChainId, Trace, TraceError, TraceRecord, TraceStepOutcome,
467    /// };
468    ///
469    /// let mut trace = Trace::new(["energy"])?;
470    /// trace.push(TraceRecord::new(
471    ///     ChainId::new(0),
472    ///     1,
473    ///     TraceStepOutcome::accepted(),
474    ///     0.0,
475    ///     vec![1.0],
476    /// ))?;
477    /// trace.push(TraceRecord::new(
478    ///     ChainId::new(0),
479    ///     2,
480    ///     TraceStepOutcome::rejected_proposal(),
481    ///     0.0,
482    ///     vec![1.0],
483    /// ))?;
484    ///
485    /// assert_relative_eq!(trace.acceptance_rate(ChainId::new(0)), 0.5);
486    /// # Ok::<(), TraceError>(())
487    /// ```
488    #[must_use]
489    #[expect(
490        clippy::cast_precision_loss,
491        reason = "trace lengths are diagnostic counters and fit in f64 for practical runs"
492    )]
493    pub fn acceptance_rate(&self, chain_id: ChainId) -> f64 {
494        let mut accepted = 0_usize;
495        let mut total = 0_usize;
496        for record in self.records_for_chain(chain_id) {
497            total = total.saturating_add(1);
498            if record.outcome.is_accepted() {
499                accepted = accepted.saturating_add(1);
500            }
501        }
502        if total == 0 {
503            0.0
504        } else {
505            accepted as f64 / total as f64
506        }
507    }
508
509    /// Write the trace in CSV format.
510    ///
511    /// The fixed columns are `chain_id`, `step`, `accepted`, `proposed`, and
512    /// `log_prob`, followed by the observable columns supplied when the trace
513    /// was created.
514    ///
515    /// ```
516    /// # use std::io;
517    /// use markov_chain_monte_carlo::prelude::{
518    ///     ChainId, Trace, TraceError, TraceRecord, TraceStepOutcome,
519    /// };
520    /// # #[derive(Debug)]
521    /// # enum ExampleError {
522    /// #     Trace(TraceError),
523    /// #     Io(io::Error),
524    /// # }
525    /// # impl From<TraceError> for ExampleError {
526    /// #     fn from(err: TraceError) -> Self { Self::Trace(err) }
527    /// # }
528    /// # impl From<io::Error> for ExampleError {
529    /// #     fn from(err: io::Error) -> Self { Self::Io(err) }
530    /// # }
531    ///
532    /// let mut trace = Trace::new(["energy"])?;
533    /// trace.push(TraceRecord::new(
534    ///     ChainId::new(0),
535    ///     1,
536    ///     TraceStepOutcome::accepted(),
537    ///     -1.0,
538    ///     vec![1.5],
539    /// ))?;
540    ///
541    /// let mut csv = Vec::new();
542    /// trace.write_csv(&mut csv)?;
543    ///
544    /// assert_eq!(
545    ///     csv,
546    ///     b"chain_id,step,accepted,proposed,log_prob,energy\n0,1,true,true,-1,1.5\n",
547    /// );
548    /// # Ok::<(), ExampleError>(())
549    /// ```
550    ///
551    /// # Errors
552    ///
553    /// Returns any I/O error reported by `writer`.
554    pub fn write_csv(&self, mut writer: impl Write) -> io::Result<()> {
555        writer.write_all(b"chain_id,step,accepted,proposed,log_prob")?;
556        for name in &self.observable_names {
557            writer.write_all(b",")?;
558            write_csv_field(&mut writer, name)?;
559        }
560        writer.write_all(b"\n")?;
561
562        for record in &self.records {
563            write!(
564                writer,
565                "{},{},{},{},{}",
566                record.chain_id.get(),
567                record.step,
568                record.outcome.is_accepted(),
569                record.outcome.had_proposal(),
570                record.log_prob
571            )?;
572            for value in &record.observable_values {
573                write!(writer, ",{value}")?;
574            }
575            writer.write_all(b"\n")?;
576        }
577        Ok(())
578    }
579}
580
581/// Recorder for one chain within a multi-chain trace.
582#[derive(Debug, Clone, PartialEq)]
583#[must_use]
584pub struct TraceRecorder {
585    chain_id: ChainId,
586    trace: Trace,
587}
588
589impl TraceRecorder {
590    /// Create a recorder for one chain.
591    ///
592    /// ```
593    /// use markov_chain_monte_carlo::prelude::{ChainId, TraceError, TraceRecorder};
594    ///
595    /// let recorder = TraceRecorder::new(ChainId::new(3), ["energy"])?;
596    ///
597    /// assert_eq!(recorder.chain_id(), ChainId::new(3));
598    /// assert_eq!(recorder.trace().observable_names(), &["energy"]);
599    /// # Ok::<(), TraceError>(())
600    /// ```
601    ///
602    /// # Errors
603    ///
604    /// Returns [`TraceError`] if an observable name is empty or duplicated.
605    pub fn new(
606        chain_id: ChainId,
607        observable_names: impl IntoIterator<Item = impl Into<String>>,
608    ) -> Result<Self, TraceError> {
609        Ok(Self {
610            chain_id,
611            trace: Trace::new(observable_names)?,
612        })
613    }
614
615    /// Chain identifier used by this recorder.
616    pub const fn chain_id(&self) -> ChainId {
617        self.chain_id
618    }
619
620    /// Borrow the accumulated trace.
621    pub const fn trace(&self) -> &Trace {
622        &self.trace
623    }
624
625    /// Consume the recorder and return the accumulated trace.
626    pub fn into_trace(self) -> Trace {
627        self.trace
628    }
629
630    /// Record the current state of `chain` after a completed step.
631    ///
632    /// Observable values must be supplied in the same order as the recorder's
633    /// observable names.
634    ///
635    /// ```
636    /// use markov_chain_monte_carlo::prelude::{
637    ///     Chain, ChainId, McmcError, Target, TraceError, TraceRecorder, TraceStepOutcome,
638    /// };
639    ///
640    /// # #[derive(Debug)]
641    /// # enum ExampleError {
642    /// #     Mcmc(McmcError),
643    /// #     Trace(TraceError),
644    /// # }
645    /// # impl From<McmcError> for ExampleError {
646    /// #     fn from(err: McmcError) -> Self { Self::Mcmc(err) }
647    /// # }
648    /// # impl From<TraceError> for ExampleError {
649    /// #     fn from(err: TraceError) -> Self { Self::Trace(err) }
650    /// # }
651    ///
652    /// struct Flat;
653    /// impl Target<i32> for Flat {
654    ///     fn log_prob(&self, _: &i32) -> f64 { 0.0 }
655    /// }
656    ///
657    /// let chain = Chain::new(4, &Flat)?;
658    /// let mut recorder = TraceRecorder::new(ChainId::new(0), ["value"])?;
659    ///
660    /// recorder.record(&chain, TraceStepOutcome::accepted(), [f64::from(*chain.state())])?;
661    ///
662    /// assert_eq!(recorder.trace().records()[0].observable_values(), &[4.0]);
663    /// # Ok::<(), ExampleError>(())
664    /// ```
665    ///
666    /// # Errors
667    ///
668    /// Returns [`TraceError::ObservableCountMismatch`] if `observable_values`
669    /// does not match the recorder header.
670    pub fn record<S>(
671        &mut self,
672        chain: &Chain<S>,
673        outcome: TraceStepOutcome,
674        observable_values: impl IntoIterator<Item = f64>,
675    ) -> Result<(), TraceError> {
676        let values = observable_values.into_iter().collect();
677        let record = TraceRecord::new(
678            self.chain_id,
679            chain.total_steps(),
680            outcome,
681            chain.log_prob(),
682            values,
683        );
684        self.trace.push(record)
685    }
686}
687
688/// Validate trace observable names before storing them as CSV/header metadata.
689///
690/// This centralizes the non-empty and uniqueness checks so every trace row can
691/// rely on stable, unambiguous observable columns.
692fn validate_observable_names(
693    observable_names: impl IntoIterator<Item = impl Into<String>>,
694) -> Result<Vec<String>, TraceError> {
695    let names: Vec<_> = observable_names.into_iter().map(Into::into).collect();
696    let mut seen = BTreeSet::new();
697    for (index, name) in names.iter().enumerate() {
698        if name.is_empty() {
699            return Err(TraceError::EmptyObservableName { index });
700        }
701        if !seen.insert(name.as_str()) {
702            return Err(TraceError::DuplicateObservableName { name: name.clone() });
703        }
704    }
705    Ok(names)
706}
707
/// Write one CSV field, quoting it only when CSV syntax requires escaping.
///
/// A field containing a comma, a double quote, a line feed, or a carriage
/// return is wrapped in double quotes, and every embedded double quote is
/// doubled.  Any other field is written verbatim.
///
/// The quoted path emits whole runs of text between quotes with a single
/// `write_all` each, instead of one call per byte, so an unbuffered writer
/// is not hit once per character.
///
/// # Errors
///
/// Returns any I/O error reported by `writer`.
fn write_csv_field(writer: &mut impl Write, value: &str) -> io::Result<()> {
    if !value.contains([',', '"', '\n', '\r']) {
        return writer.write_all(value.as_bytes());
    }

    writer.write_all(b"\"")?;
    // Splitting on `"` yields the text between quotes; each boundary between
    // two segments stands for one original quote, emitted doubled.
    for (index, segment) in value.split('"').enumerate() {
        if index > 0 {
            writer.write_all(b"\"\"")?;
        }
        writer.write_all(segment.as_bytes())?;
    }
    writer.write_all(b"\"")
}
725
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;
    use crate::{McmcError, Target};

    /// Target that assigns log-probability zero to every state.
    struct Flat;

    impl Target<i32> for Flat {
        fn log_prob(&self, _: &i32) -> f64 {
            0.0
        }
    }

    /// Serialize `trace` to CSV and decode the bytes as UTF-8 text.
    fn render_csv(trace: &Trace) -> Result<String, io::Error> {
        let mut bytes = Vec::new();
        trace.write_csv(&mut bytes)?;
        Ok(String::from_utf8(bytes).expect("CSV is valid UTF-8"))
    }

    // A recorder copies step count and log-probability from the live chain.
    #[test]
    fn recorder_collects_chain_rows() -> Result<(), TraceError> {
        let id = ChainId::new(2);
        let chain = Chain::new(3, &Flat).expect("flat target has valid log-probability");
        let mut recorder = TraceRecorder::new(id, ["energy", "magnetization"])?;

        recorder.record(&chain, TraceStepOutcome::accepted(), [1.25_f64, -0.5_f64])?;

        let trace = recorder.trace();
        assert_eq!(trace.len(), 1);

        let row = &trace.records()[0];
        assert_eq!(row.chain_id(), id);
        assert_eq!(row.step(), 0);
        assert_eq!(row.outcome(), TraceStepOutcome::accepted());
        assert_relative_eq!(row.log_prob(), 0.0);
        assert_eq!(row.observable_values(), &[1.25, -0.5]);
        assert_relative_eq!(trace.acceptance_rate(id), 1.0);
        Ok(())
    }

    // Accessors expose the chain id and the (still empty) trace.
    #[test]
    fn recorder_accessors_return_metadata_and_trace() -> Result<(), TraceError> {
        let id = ChainId::new(7);
        let recorder = TraceRecorder::new(id, ["energy"])?;

        assert_eq!(recorder.chain_id(), id);
        assert_eq!(recorder.trace().observable_names(), &["energy"]);

        let owned = recorder.into_trace();
        assert_eq!(owned.observable_names(), &["energy"]);
        assert!(owned.is_empty());
        Ok(())
    }

    #[test]
    fn chain_id_display_formats_raw_identifier() {
        let rendered = ChainId::new(42).to_string();
        assert_eq!(rendered, "42");
    }

    // Header validation and row-width validation each report their own error.
    #[test]
    fn trace_rejects_bad_headers_and_row_widths() {
        let empty_name = Trace::new(["energy", ""]).unwrap_err();
        assert_eq!(empty_name, TraceError::EmptyObservableName { index: 1 });

        let duplicate = Trace::new(["energy", "energy"]).unwrap_err();
        assert_eq!(
            duplicate,
            TraceError::DuplicateObservableName {
                name: "energy".to_owned()
            }
        );

        let mut trace = Trace::new(["energy"]).unwrap();
        let too_wide = TraceRecord::new(
            ChainId::new(0),
            1,
            TraceStepOutcome::accepted(),
            0.0,
            vec![1.0, 2.0],
        );
        let width_error = trace.push(too_wide).unwrap_err();
        assert_eq!(
            width_error,
            TraceError::ObservableCountMismatch {
                expected: 1,
                actual: 2
            }
        );
    }

    // Every variant's message carries its context fields.
    #[test]
    fn trace_error_display_messages_include_context() {
        let empty = TraceError::EmptyObservableName { index: 3 };
        assert_eq!(empty.to_string(), "observable name at index 3 is empty");

        let duplicate = TraceError::DuplicateObservableName {
            name: "energy".to_owned(),
        };
        assert_eq!(
            duplicate.to_string(),
            "observable name \"energy\" appears more than once"
        );

        let count = TraceError::ObservableCountMismatch {
            expected: 2,
            actual: 1,
        };
        assert_eq!(
            count.to_string(),
            "trace row has 1 observable values, expected 2"
        );

        let names = TraceError::ObservableNamesMismatch {
            expected: vec!["energy".to_owned()],
            actual: vec!["magnetization".to_owned()],
        };
        assert_eq!(
            names.to_string(),
            "trace observable columns differ: expected [\"energy\"], got [\"magnetization\"]"
        );
    }

    #[test]
    fn trace_extend_rejects_mismatched_headers() {
        let mut receiver = Trace::new(["energy"]).expect("valid observable name");
        let wider = Trace::new(["energy", "magnetization"]).expect("valid observable names");

        let error = receiver.extend(wider).unwrap_err();

        assert_eq!(
            error,
            TraceError::ObservableNamesMismatch {
                expected: vec!["energy".to_owned()],
                actual: vec!["energy".to_owned(), "magnetization".to_owned()]
            }
        );
    }

    #[test]
    fn trace_extend_appends_matching_records() -> Result<(), TraceError> {
        let mut receiver = Trace::new(["energy"])?;
        let mut donor = Trace::new(["energy"])?;
        let donated_row = TraceRecord::new(
            ChainId::new(1),
            9,
            TraceStepOutcome::accepted(),
            -3.5,
            vec![-7.0],
        );
        donor.push(donated_row).expect("row width matches");

        receiver.extend(donor)?;

        assert_eq!(receiver.len(), 1);
        let row = &receiver.records()[0];
        assert_eq!(row.chain_id(), ChainId::new(1));
        assert_eq!(row.step(), 9);
        assert_eq!(row.observable_values(), &[-7.0]);
        Ok(())
    }

    // Same column count but different names must still be rejected.
    #[test]
    fn trace_extend_reports_equal_width_name_mismatch() {
        let mut receiver = Trace::new(["energy"]).expect("valid observable name");
        let renamed = Trace::new(["magnetization"]).expect("valid observable name");

        let error = receiver.extend(renamed).unwrap_err();

        assert_eq!(
            error,
            TraceError::ObservableNamesMismatch {
                expected: vec!["energy".to_owned()],
                actual: vec!["magnetization".to_owned()]
            }
        );
    }

    #[test]
    fn trace_writes_csv() -> Result<(), io::Error> {
        let mut trace = Trace::new(["energy", "quoted,name"]).expect("valid observable names");
        let row = TraceRecord::new(
            ChainId::new(0),
            1,
            TraceStepOutcome::rejected_proposal(),
            -2.5,
            vec![3.0, 4.0],
        );
        trace.push(row).expect("row width matches");

        let csv = render_csv(&trace)?;

        assert_eq!(
            csv,
            concat!(
                "chain_id,step,accepted,proposed,log_prob,energy,\"quoted,name\"\n",
                "0,1,false,true,-2.5,3,4\n",
            )
        );
        Ok(())
    }

    #[test]
    fn trace_csv_escapes_quoted_headers() -> Result<(), io::Error> {
        let mut trace = Trace::new([r#"quote"field"#]).expect("valid observable name");
        let row = TraceRecord::new(
            ChainId::new(0),
            1,
            TraceStepOutcome::accepted(),
            0.0,
            vec![1.0],
        );
        trace.push(row).expect("row width matches");

        let csv = render_csv(&trace)?;

        assert_eq!(
            csv,
            concat!(
                "chain_id,step,accepted,proposed,log_prob,\"quote\"\"field\"\n",
                "0,1,true,true,0,1\n",
            )
        );
        Ok(())
    }

    // The writer takes exactly the header and then fails, so the error must
    // surface from the first data row.
    #[test]
    fn trace_write_csv_propagates_row_write_errors() {
        /// Writer that takes at most `capacity` bytes in total, then errors.
        struct FailAfter {
            consumed: usize,
            capacity: usize,
        }

        impl Write for FailAfter {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                let room = self.capacity.saturating_sub(self.consumed);
                if room == 0 {
                    return Err(io::Error::other("writer full"));
                }
                let taken = buf.len().min(room);
                self.consumed += taken;
                Ok(taken)
            }

            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }

        let mut trace = Trace::new(["energy"]).expect("valid observable name");
        let row = TraceRecord::new(
            ChainId::new(0),
            1,
            TraceStepOutcome::accepted(),
            0.0,
            vec![1.0],
        );
        trace.push(row).expect("row width matches");

        let header_len = "chain_id,step,accepted,proposed,log_prob,energy\n".len();
        let mut writer = FailAfter {
            consumed: 0,
            capacity: header_len,
        };

        let error = trace.write_csv(&mut writer).unwrap_err();

        assert_eq!(error.kind(), io::ErrorKind::Other);
        assert_eq!(error.to_string(), "writer full");
        writer.flush().expect("flush is infallible for test writer");
    }

    #[test]
    fn delayed_step_outcome_converts_to_trace_outcome() {
        let cases = [
            (StepOutcome::Accepted, TraceStepOutcome::accepted()),
            (
                StepOutcome::RejectedProposal,
                TraceStepOutcome::rejected_proposal(),
            ),
            (StepOutcome::NoProposal, TraceStepOutcome::no_proposal()),
        ];

        for (source, expected) in cases {
            assert_eq!(TraceStepOutcome::from(source), expected);
        }
    }

    #[test]
    fn delayed_step_reference_converts_to_trace_outcome() {
        let self_loop = Step {
            outcome: StepOutcome::NoProposal,
            info: None::<()>,
            log_prob_before: 0.0,
            log_prob_after: None,
            log_alpha: None,
        };

        let converted = TraceStepOutcome::from(&self_loop);

        assert_eq!(converted, TraceStepOutcome::no_proposal());
    }

    // `false` maps to a rejected proposal, never to "no proposal".
    #[test]
    fn proposal_acceptance_adapter_preserves_proposal_presence() {
        let from_true = TraceStepOutcome::from_proposal_acceptance(true);
        let from_false = TraceStepOutcome::from_proposal_acceptance(false);

        assert_eq!(from_true, TraceStepOutcome::accepted());
        assert_eq!(from_false, TraceStepOutcome::rejected_proposal());
    }

    #[test]
    fn empty_chain_acceptance_rate_is_zero() -> Result<(), TraceError> {
        let trace = Trace::new(["energy"])?;
        let rate = trace.acceptance_rate(ChainId::new(0));

        assert_relative_eq!(rate, 0.0);
        Ok(())
    }

    #[test]
    fn chain_construction_still_reports_mcmc_errors() {
        /// Target whose log-probability is NaN for every state.
        struct NanTarget;

        impl Target<i32> for NanTarget {
            fn log_prob(&self, _: &i32) -> f64 {
                f64::NAN
            }
        }

        let error = Chain::new(0, &NanTarget).unwrap_err();

        assert_eq!(error, McmcError::NanInitialLogProb);
    }
}