fugue/runtime/
trace.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/runtime/trace.md"))]
2
3use crate::core::address::Address;
4use crate::error::{FugueError, FugueResult};
5use std::collections::BTreeMap;
6
/// Type-safe storage for values from different distribution types.
///
/// A trace has to hold values drawn from distributions with different return
/// types. `ChoiceValue` wraps each supported primitive in its own variant, and
/// the `as_*` accessors return `None` (never a coerced value) when the stored
/// variant is not the one that was asked for.
///
/// Example:
/// ```rust
/// # use fugue::runtime::trace::ChoiceValue;
///
/// // Different value types from distributions
/// let continuous = ChoiceValue::F64(3.14159);  // Normal, Uniform, etc.
/// let discrete = ChoiceValue::U64(42);         // Poisson, Binomial
/// let categorical = ChoiceValue::Usize(2);     // Categorical selection
/// let binary = ChoiceValue::Bool(true);        // Bernoulli outcome
///
/// // Type-safe extraction
/// assert_eq!(continuous.as_f64(), Some(3.14159));
/// assert_eq!(discrete.as_u64(), Some(42));
/// assert_eq!(binary.as_bool(), Some(true));
///
/// // Type mismatches return None
/// assert_eq!(continuous.as_bool(), None);
/// ```
#[derive(Clone, Debug, PartialEq)]
pub enum ChoiceValue {
    /// Floating-point value (continuous distributions).
    F64(f64),
    /// Signed integer value.
    I64(i64),
    /// Unsigned integer value (Poisson, Binomial counts).
    U64(u64),
    /// Array index value (Categorical choices).
    Usize(usize),
    /// Boolean value (Bernoulli outcomes).
    Bool(bool),
}

impl ChoiceValue {
    /// Returns the wrapped `f64`, or `None` if this is not the `F64` variant.
    pub fn as_f64(&self) -> Option<f64> {
        match *self {
            Self::F64(value) => Some(value),
            _ => None,
        }
    }

    /// Returns the wrapped `bool`, or `None` if this is not the `Bool` variant.
    pub fn as_bool(&self) -> Option<bool> {
        match *self {
            Self::Bool(value) => Some(value),
            _ => None,
        }
    }

    /// Returns the wrapped `u64`, or `None` if this is not the `U64` variant.
    ///
    /// No conversion is attempted: a `Usize` or `I64` value yields `None`.
    pub fn as_u64(&self) -> Option<u64> {
        match *self {
            Self::U64(value) => Some(value),
            _ => None,
        }
    }

    /// Returns the wrapped `usize`, or `None` if this is not the `Usize` variant.
    ///
    /// No conversion is attempted: a `U64` value yields `None`.
    pub fn as_usize(&self) -> Option<usize> {
        match *self {
            Self::Usize(value) => Some(value),
            _ => None,
        }
    }

    /// Returns the wrapped `i64`, or `None` if this is not the `I64` variant.
    pub fn as_i64(&self) -> Option<i64> {
        match *self {
            Self::I64(value) => Some(value),
            _ => None,
        }
    }

    /// Returns the name of the stored primitive type, for use in error messages.
    ///
    /// The match is deliberately exhaustive so that adding a variant forces
    /// this method to be updated.
    pub fn type_name(&self) -> &'static str {
        match self {
            Self::F64(_) => "f64",
            Self::I64(_) => "i64",
            Self::U64(_) => "u64",
            Self::Usize(_) => "usize",
            Self::Bool(_) => "bool",
        }
    }
}
96
97/// A single recorded choice made during model execution.
98///
99/// Each Choice represents a random variable assignment at a specific address,
100/// complete with the value chosen and its log-probability. Choices form the
101/// building blocks of execution traces.
102///
103/// Example:
104/// ```rust
105/// # use fugue::*;
106/// # use fugue::runtime::trace::{Choice, ChoiceValue};
107///
108/// // Choices are typically created by handlers during execution
109/// let choice = Choice {
110///     addr: addr!("theta"),
111///     value: ChoiceValue::F64(1.5),
112///     logp: -0.918, // log-probability under generating distribution
113/// };
114///
115/// println!("Choice at {}: {:?} (logp: {:.3})",
116///          choice.addr, choice.value, choice.logp);
117///
118/// // Extract the value with type safety
119/// if let Some(val) = choice.value.as_f64() {
120///     println!("Theta value: {:.3}", val);
121/// }
122/// ```
123#[derive(Clone, Debug)]
124pub struct Choice {
125    /// Address where this choice was made.
126    pub addr: Address,
127    /// Value that was chosen.
128    pub value: ChoiceValue,
129    /// Log-probability of this value under the generating distribution.
130    pub logp: f64,
131}
132
133/// Complete execution trace of a probabilistic model.
134///
135/// A Trace records the complete execution history of a probabilistic model,
136/// including all choices made and accumulated log-weights from different sources.
137/// This enables replay, scoring, and inference operations.
138///
139/// Example:
140/// ```rust
141/// # use fugue::*;
142/// # use fugue::runtime::interpreters::PriorHandler;
143/// # use rand::rngs::StdRng;
144/// # use rand::SeedableRng;
145///
146/// // Execute a model and examine the trace
147/// let model = sample(addr!("theta"), Normal::new(0.0, 1.0).unwrap())
148///     .bind(|theta| observe(addr!("y"), Normal::new(theta, 0.5).unwrap(), 1.2)
149///         .map(move |_| theta));
150///
151/// let mut rng = StdRng::seed_from_u64(42);
152/// let (result, trace) = runtime::handler::run(
153///     PriorHandler { rng: &mut rng, trace: Trace::default() },
154///     model
155/// );
156///
157/// // Examine trace components
158/// println!("Sampled theta: {:.3}", result);
159/// println!("Prior log-weight: {:.3}", trace.log_prior);
160/// println!("Likelihood log-weight: {:.3}", trace.log_likelihood);
161/// println!("Total log-weight: {:.3}", trace.total_log_weight());
162///
163/// // Type-safe value access
164/// let theta_value = trace.get_f64(&addr!("theta")).unwrap();
165/// assert_eq!(theta_value, result);
166/// ```
167#[derive(Clone, Debug, Default)]
168pub struct Trace {
169    /// Map from addresses to the choices made at those sites.
170    pub choices: BTreeMap<Address, Choice>,
171    /// Accumulated log-prior probability from all sampling sites.
172    pub log_prior: f64,
173    /// Accumulated log-likelihood from all observation sites.
174    pub log_likelihood: f64,
175    /// Accumulated log-weight from all factor statements.
176    pub log_factors: f64,
177}
178
179impl Trace {
180    /// Compute the total unnormalized log-probability of this execution.
181    ///
182    /// The total log-weight combines all three components (prior, likelihood, factors)
183    /// and represents the unnormalized log-probability of this execution path.
184    ///
185    /// Example:
186    /// ```rust
187    /// # use fugue::runtime::trace::Trace;
188    ///
189    /// let trace = Trace {
190    ///     log_prior: -1.5,
191    ///     log_likelihood: -2.3,
192    ///     log_factors: 0.8,
193    ///     ..Default::default()
194    /// };
195    ///
196    /// assert_eq!(trace.total_log_weight(), -3.0);
197    /// ```
198    pub fn total_log_weight(&self) -> f64 {
199        self.log_prior + self.log_likelihood + self.log_factors
200    }
201
202    /// Type-safe accessor for f64 values in the trace.
203    pub fn get_f64(&self, addr: &Address) -> Option<f64> {
204        self.choices.get(addr)?.value.as_f64()
205    }
206
207    /// Type-safe accessor for bool values in the trace.
208    pub fn get_bool(&self, addr: &Address) -> Option<bool> {
209        self.choices.get(addr)?.value.as_bool()
210    }
211
212    /// Type-safe accessor for u64 values in the trace.
213    pub fn get_u64(&self, addr: &Address) -> Option<u64> {
214        self.choices.get(addr)?.value.as_u64()
215    }
216
217    /// Type-safe accessor for usize values in the trace.
218    pub fn get_usize(&self, addr: &Address) -> Option<usize> {
219        self.choices.get(addr)?.value.as_usize()
220    }
221
222    /// Type-safe accessor for i64 values in the trace.
223    pub fn get_i64(&self, addr: &Address) -> Option<i64> {
224        self.choices.get(addr)?.value.as_i64()
225    }
226
227    /// Type-safe accessor that returns a Result for better error handling.
228    pub fn get_f64_result(&self, addr: &Address) -> FugueResult<f64> {
229        let choice = self.choices.get(addr).ok_or_else(|| {
230            FugueError::trace_error(
231                "get_f64",
232                Some(addr.clone()),
233                "Address not found in trace",
234                crate::error::ErrorCode::TraceAddressNotFound,
235            )
236        })?;
237
238        choice
239            .value
240            .as_f64()
241            .ok_or_else(|| FugueError::type_mismatch(addr.clone(), "f64", choice.value.type_name()))
242    }
243
244    /// Type-safe accessor that returns a Result for better error handling.
245    pub fn get_bool_result(&self, addr: &Address) -> FugueResult<bool> {
246        let choice = self.choices.get(addr).ok_or_else(|| {
247            FugueError::trace_error(
248                "get_bool",
249                Some(addr.clone()),
250                "Address not found in trace",
251                crate::error::ErrorCode::TraceAddressNotFound,
252            )
253        })?;
254
255        choice.value.as_bool().ok_or_else(|| {
256            FugueError::type_mismatch(addr.clone(), "bool", choice.value.type_name())
257        })
258    }
259
260    /// Type-safe accessor that returns a Result for better error handling.
261    pub fn get_u64_result(&self, addr: &Address) -> FugueResult<u64> {
262        let choice = self.choices.get(addr).ok_or_else(|| {
263            FugueError::trace_error(
264                "get_u64",
265                Some(addr.clone()),
266                "Address not found in trace",
267                crate::error::ErrorCode::TraceAddressNotFound,
268            )
269        })?;
270
271        choice
272            .value
273            .as_u64()
274            .ok_or_else(|| FugueError::type_mismatch(addr.clone(), "u64", choice.value.type_name()))
275    }
276
277    /// Type-safe accessor that returns a Result for better error handling.
278    pub fn get_usize_result(&self, addr: &Address) -> FugueResult<usize> {
279        let choice = self.choices.get(addr).ok_or_else(|| {
280            FugueError::trace_error(
281                "get_usize",
282                Some(addr.clone()),
283                "Address not found in trace",
284                crate::error::ErrorCode::TraceAddressNotFound,
285            )
286        })?;
287
288        choice.value.as_usize().ok_or_else(|| {
289            FugueError::type_mismatch(addr.clone(), "usize", choice.value.type_name())
290        })
291    }
292
293    /// Type-safe accessor that returns a Result for better error handling.
294    pub fn get_i64_result(&self, addr: &Address) -> FugueResult<i64> {
295        let choice = self.choices.get(addr).ok_or_else(|| {
296            FugueError::trace_error(
297                "get_i64",
298                Some(addr.clone()),
299                "Address not found in trace",
300                crate::error::ErrorCode::TraceAddressNotFound,
301            )
302        })?;
303
304        choice
305            .value
306            .as_i64()
307            .ok_or_else(|| FugueError::type_mismatch(addr.clone(), "i64", choice.value.type_name()))
308    }
309
310    /// Insert a typed choice into the trace with type safety.
311    ///
312    /// This is a convenience method for manually constructing traces. Note that
313    /// this method only updates the choices map - it does not modify the
314    /// log-weight accumulators (log_prior, log_likelihood, log_factors).
315    ///
316    /// Example:
317    /// ```rust
318    /// # use fugue::*;
319    /// # use fugue::runtime::trace::{Trace, ChoiceValue};
320    ///
321    /// let mut trace = Trace::default();
322    ///
323    /// // Insert different types of choices
324    /// trace.insert_choice(addr!("mu"), ChoiceValue::F64(1.5), -0.125);
325    /// trace.insert_choice(addr!("success"), ChoiceValue::Bool(true), -0.693);
326    /// trace.insert_choice(addr!("count"), ChoiceValue::U64(10), -2.303);
327    ///
328    /// // Retrieve with type safety
329    /// assert_eq!(trace.get_f64(&addr!("mu")), Some(1.5));
330    /// assert_eq!(trace.get_bool(&addr!("success")), Some(true));
331    /// assert_eq!(trace.get_u64(&addr!("count")), Some(10));
332    ///
333    /// println!("Trace has {} choices", trace.choices.len());
334    /// ```
335    pub fn insert_choice(&mut self, addr: Address, value: ChoiceValue, logp: f64) {
336        let choice = Choice {
337            addr: addr.clone(),
338            value,
339            logp,
340        };
341        self.choices.insert(addr, choice);
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::addr;

    /// Builds a trace holding one choice of every `ChoiceValue` variant.
    fn sample_trace() -> Trace {
        let mut t = Trace::default();
        t.insert_choice(addr!("a"), ChoiceValue::F64(1.5), -0.5);
        t.insert_choice(addr!("b"), ChoiceValue::Bool(true), -0.7);
        t.insert_choice(addr!("c"), ChoiceValue::U64(3), -0.2);
        t.insert_choice(addr!("d"), ChoiceValue::Usize(4), -0.3);
        t.insert_choice(addr!("e"), ChoiceValue::I64(-7), -0.1);
        t
    }

    #[test]
    fn insert_and_getters_work() {
        let t = sample_trace();

        assert_eq!(t.get_f64(&addr!("a")), Some(1.5));
        assert_eq!(t.get_bool(&addr!("b")), Some(true));
        assert_eq!(t.get_u64(&addr!("c")), Some(3));
        assert_eq!(t.get_usize(&addr!("d")), Some(4));
        assert_eq!(t.get_i64(&addr!("e")), Some(-7));

        // Option-based accessors collapse mismatch and absence into `None`.
        assert_eq!(t.get_f64(&addr!("b")), None);
        assert_eq!(t.get_u64(&addr!("d")), None);
        assert_eq!(t.get_f64(&addr!("missing")), None);

        // Result-based accessors
        assert!(t.get_f64_result(&addr!("a")).is_ok());
        assert!(t.get_bool_result(&addr!("b")).is_ok());
        assert!(t.get_u64_result(&addr!("c")).is_ok());
        assert!(t.get_usize_result(&addr!("d")).is_ok());
        assert!(t.get_i64_result(&addr!("e")).is_ok());

        // Type mismatch
        let err = t.get_f64_result(&addr!("b")).unwrap_err();
        assert!(matches!(err, crate::error::FugueError::TypeMismatch { .. }));
    }

    #[test]
    fn result_accessors_report_type_mismatch_for_every_type() {
        let t = sample_trace();

        assert!(matches!(
            t.get_bool_result(&addr!("a")).unwrap_err(),
            crate::error::FugueError::TypeMismatch { .. }
        ));
        assert!(matches!(
            t.get_u64_result(&addr!("d")).unwrap_err(),
            crate::error::FugueError::TypeMismatch { .. }
        ));
        assert!(matches!(
            t.get_usize_result(&addr!("c")).unwrap_err(),
            crate::error::FugueError::TypeMismatch { .. }
        ));
        assert!(matches!(
            t.get_i64_result(&addr!("c")).unwrap_err(),
            crate::error::FugueError::TypeMismatch { .. }
        ));
    }

    #[test]
    fn total_log_weight_accumulates() {
        let mut t = Trace::default();
        // insert_choice does not modify log accumulators; set them explicitly
        t.insert_choice(addr!("x"), ChoiceValue::F64(0.0), -1.0);
        assert_eq!(t.total_log_weight(), 0.0);

        t.log_prior = -1.0;
        t.log_likelihood = -2.0;
        t.log_factors = -3.0;
        assert!((t.total_log_weight() - (-6.0)).abs() < 1e-12);
    }

    #[test]
    fn result_accessors_return_errors_for_missing_addresses() {
        let t = Trace::default();
        let missing = addr!("missing");

        let e = t.get_f64_result(&missing).unwrap_err();
        assert!(matches!(e, crate::error::FugueError::TraceError { .. }));
        assert!(matches!(
            t.get_bool_result(&missing).unwrap_err(),
            crate::error::FugueError::TraceError { .. }
        ));
        assert!(matches!(
            t.get_u64_result(&missing).unwrap_err(),
            crate::error::FugueError::TraceError { .. }
        ));
        assert!(matches!(
            t.get_usize_result(&missing).unwrap_err(),
            crate::error::FugueError::TraceError { .. }
        ));
        assert!(matches!(
            t.get_i64_result(&missing).unwrap_err(),
            crate::error::FugueError::TraceError { .. }
        ));
    }

    #[test]
    fn insert_choice_replaces_existing_choice_at_same_address() {
        let mut t = Trace::default();
        t.insert_choice(addr!("a"), ChoiceValue::F64(1.0), -0.1);
        t.insert_choice(addr!("a"), ChoiceValue::Bool(false), -0.2);

        assert_eq!(t.choices.len(), 1);
        assert_eq!(t.get_f64(&addr!("a")), None);
        assert_eq!(t.get_bool(&addr!("a")), Some(false));

        let stored = t.choices.get(&addr!("a")).unwrap();
        assert_eq!(stored.addr, addr!("a"));
        assert_eq!(stored.logp, -0.2);
    }
}