fugue/runtime/trace.rs
1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/runtime/trace.md"))]
2
3use crate::core::address::Address;
4use crate::error::{FugueError, FugueResult};
5use std::collections::BTreeMap;
6
/// Type-safe storage for values from different distribution types.
///
/// Every value recorded in a trace is wrapped in a `ChoiceValue`, so a single
/// map can hold draws from any supported distribution. Each variant
/// corresponds to one distribution return type. Reading a value through the
/// accessor for a different variant yields `None` instead of reinterpreting
/// the payload.
///
/// Example:
/// ```rust
/// # use fugue::runtime::trace::ChoiceValue;
///
/// // Different value types from distributions
/// let continuous = ChoiceValue::F64(3.14159); // Normal, Uniform, etc.
/// let discrete = ChoiceValue::U64(42); // Poisson, Binomial
/// let categorical = ChoiceValue::Usize(2); // Categorical selection
/// let binary = ChoiceValue::Bool(true); // Bernoulli outcome
///
/// // Type-safe extraction
/// assert_eq!(continuous.as_f64(), Some(3.14159));
/// assert_eq!(discrete.as_u64(), Some(42));
/// assert_eq!(categorical.as_usize(), Some(2));
/// assert_eq!(binary.as_bool(), Some(true));
///
/// // Type mismatches return None; `type_name` says what is actually stored
/// assert_eq!(continuous.as_bool(), None);
/// assert_eq!(continuous.type_name(), "f64");
/// ```
#[derive(Clone, Debug, PartialEq)]
pub enum ChoiceValue {
    /// Floating-point value (continuous distributions).
    F64(f64),
    /// Signed integer value.
    I64(i64),
    /// Unsigned integer value (Poisson, Binomial counts).
    U64(u64),
    /// Array index value (Categorical choices).
    Usize(usize),
    /// Boolean value (Bernoulli outcomes).
    Bool(bool),
}

impl ChoiceValue {
    /// Returns the payload of an [`F64`](ChoiceValue::F64) value, or `None`
    /// for any other variant.
    pub fn as_f64(&self) -> Option<f64> {
        if let Self::F64(x) = *self {
            Some(x)
        } else {
            None
        }
    }

    /// Returns the payload of an [`I64`](ChoiceValue::I64) value, or `None`
    /// for any other variant.
    pub fn as_i64(&self) -> Option<i64> {
        if let Self::I64(n) = *self {
            Some(n)
        } else {
            None
        }
    }

    /// Returns the payload of a [`U64`](ChoiceValue::U64) value, or `None`
    /// for any other variant.
    pub fn as_u64(&self) -> Option<u64> {
        if let Self::U64(n) = *self {
            Some(n)
        } else {
            None
        }
    }

    /// Returns the payload of a [`Usize`](ChoiceValue::Usize) value, or
    /// `None` for any other variant.
    pub fn as_usize(&self) -> Option<usize> {
        if let Self::Usize(idx) = *self {
            Some(idx)
        } else {
            None
        }
    }

    /// Returns the payload of a [`Bool`](ChoiceValue::Bool) value, or `None`
    /// for any other variant.
    pub fn as_bool(&self) -> Option<bool> {
        if let Self::Bool(flag) = *self {
            Some(flag)
        } else {
            None
        }
    }

    /// Name of the stored payload's Rust type (`"f64"`, `"i64"`, `"u64"`,
    /// `"usize"` or `"bool"`), intended for error messages.
    pub fn type_name(&self) -> &'static str {
        match *self {
            Self::F64(..) => "f64",
            Self::I64(..) => "i64",
            Self::U64(..) => "u64",
            Self::Usize(..) => "usize",
            Self::Bool(..) => "bool",
        }
    }
}
96
97/// A single recorded choice made during model execution.
98///
99/// Each Choice represents a random variable assignment at a specific address,
100/// complete with the value chosen and its log-probability. Choices form the
101/// building blocks of execution traces.
102///
103/// Example:
104/// ```rust
105/// # use fugue::*;
106/// # use fugue::runtime::trace::{Choice, ChoiceValue};
107///
108/// // Choices are typically created by handlers during execution
109/// let choice = Choice {
110/// addr: addr!("theta"),
111/// value: ChoiceValue::F64(1.5),
112/// logp: -0.918, // log-probability under generating distribution
113/// };
114///
115/// println!("Choice at {}: {:?} (logp: {:.3})",
116/// choice.addr, choice.value, choice.logp);
117///
118/// // Extract the value with type safety
119/// if let Some(val) = choice.value.as_f64() {
120/// println!("Theta value: {:.3}", val);
121/// }
122/// ```
123#[derive(Clone, Debug)]
124pub struct Choice {
125 /// Address where this choice was made.
126 pub addr: Address,
127 /// Value that was chosen.
128 pub value: ChoiceValue,
129 /// Log-probability of this value under the generating distribution.
130 pub logp: f64,
131}
132
133/// Complete execution trace of a probabilistic model.
134///
135/// A Trace records the complete execution history of a probabilistic model,
136/// including all choices made and accumulated log-weights from different sources.
137/// This enables replay, scoring, and inference operations.
138///
139/// Example:
140/// ```rust
141/// # use fugue::*;
142/// # use fugue::runtime::interpreters::PriorHandler;
143/// # use rand::rngs::StdRng;
144/// # use rand::SeedableRng;
145///
146/// // Execute a model and examine the trace
147/// let model = sample(addr!("theta"), Normal::new(0.0, 1.0).unwrap())
148/// .bind(|theta| observe(addr!("y"), Normal::new(theta, 0.5).unwrap(), 1.2)
149/// .map(move |_| theta));
150///
151/// let mut rng = StdRng::seed_from_u64(42);
152/// let (result, trace) = runtime::handler::run(
153/// PriorHandler { rng: &mut rng, trace: Trace::default() },
154/// model
155/// );
156///
157/// // Examine trace components
158/// println!("Sampled theta: {:.3}", result);
159/// println!("Prior log-weight: {:.3}", trace.log_prior);
160/// println!("Likelihood log-weight: {:.3}", trace.log_likelihood);
161/// println!("Total log-weight: {:.3}", trace.total_log_weight());
162///
163/// // Type-safe value access
164/// let theta_value = trace.get_f64(&addr!("theta")).unwrap();
165/// assert_eq!(theta_value, result);
166/// ```
167#[derive(Clone, Debug, Default)]
168pub struct Trace {
169 /// Map from addresses to the choices made at those sites.
170 pub choices: BTreeMap<Address, Choice>,
171 /// Accumulated log-prior probability from all sampling sites.
172 pub log_prior: f64,
173 /// Accumulated log-likelihood from all observation sites.
174 pub log_likelihood: f64,
175 /// Accumulated log-weight from all factor statements.
176 pub log_factors: f64,
177}
178
179impl Trace {
180 /// Compute the total unnormalized log-probability of this execution.
181 ///
182 /// The total log-weight combines all three components (prior, likelihood, factors)
183 /// and represents the unnormalized log-probability of this execution path.
184 ///
185 /// Example:
186 /// ```rust
187 /// # use fugue::runtime::trace::Trace;
188 ///
189 /// let trace = Trace {
190 /// log_prior: -1.5,
191 /// log_likelihood: -2.3,
192 /// log_factors: 0.8,
193 /// ..Default::default()
194 /// };
195 ///
196 /// assert_eq!(trace.total_log_weight(), -3.0);
197 /// ```
198 pub fn total_log_weight(&self) -> f64 {
199 self.log_prior + self.log_likelihood + self.log_factors
200 }
201
202 /// Type-safe accessor for f64 values in the trace.
203 pub fn get_f64(&self, addr: &Address) -> Option<f64> {
204 self.choices.get(addr)?.value.as_f64()
205 }
206
207 /// Type-safe accessor for bool values in the trace.
208 pub fn get_bool(&self, addr: &Address) -> Option<bool> {
209 self.choices.get(addr)?.value.as_bool()
210 }
211
212 /// Type-safe accessor for u64 values in the trace.
213 pub fn get_u64(&self, addr: &Address) -> Option<u64> {
214 self.choices.get(addr)?.value.as_u64()
215 }
216
217 /// Type-safe accessor for usize values in the trace.
218 pub fn get_usize(&self, addr: &Address) -> Option<usize> {
219 self.choices.get(addr)?.value.as_usize()
220 }
221
222 /// Type-safe accessor for i64 values in the trace.
223 pub fn get_i64(&self, addr: &Address) -> Option<i64> {
224 self.choices.get(addr)?.value.as_i64()
225 }
226
227 /// Type-safe accessor that returns a Result for better error handling.
228 pub fn get_f64_result(&self, addr: &Address) -> FugueResult<f64> {
229 let choice = self.choices.get(addr).ok_or_else(|| {
230 FugueError::trace_error(
231 "get_f64",
232 Some(addr.clone()),
233 "Address not found in trace",
234 crate::error::ErrorCode::TraceAddressNotFound,
235 )
236 })?;
237
238 choice
239 .value
240 .as_f64()
241 .ok_or_else(|| FugueError::type_mismatch(addr.clone(), "f64", choice.value.type_name()))
242 }
243
244 /// Type-safe accessor that returns a Result for better error handling.
245 pub fn get_bool_result(&self, addr: &Address) -> FugueResult<bool> {
246 let choice = self.choices.get(addr).ok_or_else(|| {
247 FugueError::trace_error(
248 "get_bool",
249 Some(addr.clone()),
250 "Address not found in trace",
251 crate::error::ErrorCode::TraceAddressNotFound,
252 )
253 })?;
254
255 choice.value.as_bool().ok_or_else(|| {
256 FugueError::type_mismatch(addr.clone(), "bool", choice.value.type_name())
257 })
258 }
259
260 /// Type-safe accessor that returns a Result for better error handling.
261 pub fn get_u64_result(&self, addr: &Address) -> FugueResult<u64> {
262 let choice = self.choices.get(addr).ok_or_else(|| {
263 FugueError::trace_error(
264 "get_u64",
265 Some(addr.clone()),
266 "Address not found in trace",
267 crate::error::ErrorCode::TraceAddressNotFound,
268 )
269 })?;
270
271 choice
272 .value
273 .as_u64()
274 .ok_or_else(|| FugueError::type_mismatch(addr.clone(), "u64", choice.value.type_name()))
275 }
276
277 /// Type-safe accessor that returns a Result for better error handling.
278 pub fn get_usize_result(&self, addr: &Address) -> FugueResult<usize> {
279 let choice = self.choices.get(addr).ok_or_else(|| {
280 FugueError::trace_error(
281 "get_usize",
282 Some(addr.clone()),
283 "Address not found in trace",
284 crate::error::ErrorCode::TraceAddressNotFound,
285 )
286 })?;
287
288 choice.value.as_usize().ok_or_else(|| {
289 FugueError::type_mismatch(addr.clone(), "usize", choice.value.type_name())
290 })
291 }
292
293 /// Type-safe accessor that returns a Result for better error handling.
294 pub fn get_i64_result(&self, addr: &Address) -> FugueResult<i64> {
295 let choice = self.choices.get(addr).ok_or_else(|| {
296 FugueError::trace_error(
297 "get_i64",
298 Some(addr.clone()),
299 "Address not found in trace",
300 crate::error::ErrorCode::TraceAddressNotFound,
301 )
302 })?;
303
304 choice
305 .value
306 .as_i64()
307 .ok_or_else(|| FugueError::type_mismatch(addr.clone(), "i64", choice.value.type_name()))
308 }
309
310 /// Insert a typed choice into the trace with type safety.
311 ///
312 /// This is a convenience method for manually constructing traces. Note that
313 /// this method only updates the choices map - it does not modify the
314 /// log-weight accumulators (log_prior, log_likelihood, log_factors).
315 ///
316 /// Example:
317 /// ```rust
318 /// # use fugue::*;
319 /// # use fugue::runtime::trace::{Trace, ChoiceValue};
320 ///
321 /// let mut trace = Trace::default();
322 ///
323 /// // Insert different types of choices
324 /// trace.insert_choice(addr!("mu"), ChoiceValue::F64(1.5), -0.125);
325 /// trace.insert_choice(addr!("success"), ChoiceValue::Bool(true), -0.693);
326 /// trace.insert_choice(addr!("count"), ChoiceValue::U64(10), -2.303);
327 ///
328 /// // Retrieve with type safety
329 /// assert_eq!(trace.get_f64(&addr!("mu")), Some(1.5));
330 /// assert_eq!(trace.get_bool(&addr!("success")), Some(true));
331 /// assert_eq!(trace.get_u64(&addr!("count")), Some(10));
332 ///
333 /// println!("Trace has {} choices", trace.choices.len());
334 /// ```
335 pub fn insert_choice(&mut self, addr: Address, value: ChoiceValue, logp: f64) {
336 let choice = Choice {
337 addr: addr.clone(),
338 value,
339 logp,
340 };
341 self.choices.insert(addr, choice);
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::addr;
    use crate::error::FugueError;

    #[test]
    fn insert_and_getters_work() {
        let mut t = Trace::default();
        t.insert_choice(addr!("a"), ChoiceValue::F64(1.5), -0.5);
        t.insert_choice(addr!("b"), ChoiceValue::Bool(true), -0.7);
        t.insert_choice(addr!("c"), ChoiceValue::U64(3), -0.2);
        t.insert_choice(addr!("d"), ChoiceValue::Usize(4), -0.3);
        t.insert_choice(addr!("e"), ChoiceValue::I64(-7), -0.1);

        assert_eq!(t.get_f64(&addr!("a")), Some(1.5));
        assert_eq!(t.get_bool(&addr!("b")), Some(true));
        assert_eq!(t.get_u64(&addr!("c")), Some(3));
        assert_eq!(t.get_usize(&addr!("d")), Some(4));
        assert_eq!(t.get_i64(&addr!("e")), Some(-7));

        // Result-based accessors return the same values on success
        assert_eq!(t.get_f64_result(&addr!("a")).ok(), Some(1.5));
        assert_eq!(t.get_bool_result(&addr!("b")).ok(), Some(true));
        assert_eq!(t.get_u64_result(&addr!("c")).ok(), Some(3));
        assert_eq!(t.get_usize_result(&addr!("d")).ok(), Some(4));
        assert_eq!(t.get_i64_result(&addr!("e")).ok(), Some(-7));

        // Type mismatch
        let err = t.get_f64_result(&addr!("b")).unwrap_err();
        assert!(matches!(err, FugueError::TypeMismatch { .. }));
    }

    #[test]
    fn option_getters_return_none_for_missing_or_mismatched() {
        let mut t = Trace::default();
        t.insert_choice(addr!("flag"), ChoiceValue::Bool(false), -0.7);

        assert_eq!(t.get_f64(&addr!("missing")), None);
        assert_eq!(t.get_bool(&addr!("missing")), None);

        assert_eq!(t.get_f64(&addr!("flag")), None);
        assert_eq!(t.get_u64(&addr!("flag")), None);
        assert_eq!(t.get_usize(&addr!("flag")), None);
        assert_eq!(t.get_i64(&addr!("flag")), None);
        assert_eq!(t.get_bool(&addr!("flag")), Some(false));
    }

    #[test]
    fn insert_choice_replaces_existing_choice_and_keeps_accumulators() {
        let mut t = Trace::default();
        t.insert_choice(addr!("x"), ChoiceValue::F64(1.0), -1.0);
        t.insert_choice(addr!("x"), ChoiceValue::U64(2), -2.0);

        assert_eq!(t.choices.len(), 1);
        assert_eq!(t.get_f64(&addr!("x")), None);
        assert_eq!(t.get_u64(&addr!("x")), Some(2));
        assert_eq!(t.choices[&addr!("x")].logp, -2.0);

        // insert_choice never touches the log-weight accumulators
        assert_eq!(t.log_prior, 0.0);
        assert_eq!(t.log_likelihood, 0.0);
        assert_eq!(t.log_factors, 0.0);
    }

    #[test]
    fn total_log_weight_accumulates() {
        let mut t = Trace::default();
        // insert_choice does not modify log accumulators; set them explicitly
        t.insert_choice(addr!("x"), ChoiceValue::F64(0.0), -1.0);
        t.log_prior = -1.0;
        t.log_likelihood = -2.0;
        t.log_factors = -3.0;
        assert!((t.total_log_weight() - (-6.0)).abs() < 1e-12);
    }

    #[test]
    fn result_accessors_return_errors_for_missing_addresses() {
        let t = Trace::default();
        let missing = addr!("missing");

        let e = t.get_f64_result(&missing).unwrap_err();
        assert!(matches!(e, FugueError::TraceError { .. }));

        assert!(matches!(
            t.get_bool_result(&missing),
            Err(FugueError::TraceError { .. })
        ));
        assert!(matches!(
            t.get_u64_result(&missing),
            Err(FugueError::TraceError { .. })
        ));
        assert!(matches!(
            t.get_usize_result(&missing),
            Err(FugueError::TraceError { .. })
        ));
        assert!(matches!(
            t.get_i64_result(&missing),
            Err(FugueError::TraceError { .. })
        ));
    }

    #[test]
    fn result_accessors_report_type_mismatches() {
        let mut t = Trace::default();
        t.insert_choice(addr!("f"), ChoiceValue::F64(0.5), -0.1);
        t.insert_choice(addr!("b"), ChoiceValue::Bool(true), -0.1);
        let f = addr!("f");
        let b = addr!("b");

        assert!(matches!(
            t.get_f64_result(&b),
            Err(FugueError::TypeMismatch { .. })
        ));
        assert!(matches!(
            t.get_bool_result(&f),
            Err(FugueError::TypeMismatch { .. })
        ));
        assert!(matches!(
            t.get_u64_result(&f),
            Err(FugueError::TypeMismatch { .. })
        ));
        assert!(matches!(
            t.get_usize_result(&f),
            Err(FugueError::TypeMismatch { .. })
        ));
        assert!(matches!(
            t.get_i64_result(&f),
            Err(FugueError::TypeMismatch { .. })
        ));
    }
}