Skip to main content

cvx_bayes/
variable.rs

//! Discrete random variables for Bayesian network nodes.
2
3use serde::{Deserialize, Serialize};
4
5/// Unique identifier for a variable in the network.
6pub type VariableId = u32;
7
8/// A discrete random variable with named states.
9///
10/// # Example
11///
12/// ```
13/// use cvx_bayes::Variable;
14///
15/// let task = Variable::new(0, "task_type", vec![
16///     "pick_and_place".into(),
17///     "heat_then_place".into(),
18///     "clean_then_place".into(),
19/// ]);
20/// assert_eq!(task.n_states(), 3);
21/// assert_eq!(task.state_index("heat_then_place"), Some(1));
22/// ```
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct Variable {
25    /// Unique ID in the network.
26    pub id: VariableId,
27    /// Human-readable name.
28    pub name: String,
29    /// Named states (e.g., ["success", "failure"]).
30    pub states: Vec<String>,
31}
32
33impl Variable {
34    /// Create a new variable with named states.
35    pub fn new(id: VariableId, name: impl Into<String>, states: Vec<String>) -> Self {
36        assert!(!states.is_empty(), "variable must have at least one state");
37        Self {
38            id,
39            name: name.into(),
40            states,
41        }
42    }
43
44    /// Create a binary variable (two states: "true", "false").
45    pub fn binary(id: VariableId, name: impl Into<String>) -> Self {
46        Self::new(id, name, vec!["true".into(), "false".into()])
47    }
48
49    /// Number of states.
50    pub fn n_states(&self) -> usize {
51        self.states.len()
52    }
53
54    /// Get state index by name.
55    pub fn state_index(&self, state_name: &str) -> Option<usize> {
56        self.states.iter().position(|s| s == state_name)
57    }
58
59    /// Get state name by index.
60    pub fn state_name(&self, index: usize) -> Option<&str> {
61        self.states.get(index).map(|s| s.as_str())
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn variable_creation() {
        let v = Variable::new(
            0,
            "color",
            vec!["red".into(), "blue".into(), "green".into()],
        );
        assert_eq!(v.id, 0);
        assert_eq!(v.name, "color");
        assert_eq!(v.n_states(), 3);
        assert_eq!(v.state_index("blue"), Some(1));
        assert_eq!(v.state_name(2), Some("green"));
        assert_eq!(v.state_index("yellow"), None);
        // One past the last state must be reported as missing, not panic.
        assert_eq!(v.state_name(3), None);
    }

    #[test]
    fn binary_variable() {
        let v = Variable::binary(0, "success");
        assert_eq!(v.n_states(), 2);
        assert_eq!(v.state_index("true"), Some(0));
        assert_eq!(v.state_index("false"), Some(1));
        assert_eq!(v.state_name(0), Some("true"));
        assert_eq!(v.state_name(1), Some("false"));
    }

    #[test]
    fn state_index_and_name_round_trip() {
        let v = Variable::new(
            3,
            "weather",
            vec!["sunny".into(), "cloudy".into(), "rainy".into()],
        );
        // Every state name maps to its position and back again.
        for (i, s) in v.states.iter().enumerate() {
            assert_eq!(v.state_index(s), Some(i));
            assert_eq!(v.state_name(i), Some(s.as_str()));
        }
    }

    #[test]
    // Pin the message so an unrelated panic cannot make this test pass.
    #[should_panic(expected = "variable must have at least one state")]
    fn empty_states_panics() {
        Variable::new(0, "empty", vec![]);
    }
}