Skip to main content

murk_obs/
spec.rs

//! Observation specification types.
//!
//! An [`ObsSpec`] defines how to extract flat observation tensors from
//! simulation state. Each [`ObsEntry`] targets one field and specifies the
//! spatial region to observe, optional spatial pooling, the transform to
//! apply, and the output data type.
7
8use murk_core::FieldId;
9use murk_space::RegionSpec;
10use smallvec::SmallVec;
11
12/// Specification for observation extraction.
13///
14/// An `ObsSpec` is a list of entries, each describing one slice of the
15/// output tensor. Entries are gathered in order: entry 0 fills the first
16/// `N_0` elements, entry 1 fills the next `N_1`, etc.
17///
18/// # Examples
19///
20/// ```
21/// use murk_obs::{ObsSpec, ObsEntry, ObsDtype, ObsTransform, ObsRegion};
22/// use murk_core::FieldId;
23/// use murk_space::RegionSpec;
24///
25/// let spec = ObsSpec {
26///     entries: vec![
27///         ObsEntry {
28///             field_id: FieldId(0),
29///             region: ObsRegion::Fixed(RegionSpec::All),
30///             pool: None,
31///             transform: ObsTransform::Identity,
32///             dtype: ObsDtype::F32,
33///         },
34///         ObsEntry {
35///             field_id: FieldId(1),
36///             region: ObsRegion::AgentDisk { radius: 3 },
37///             pool: None,
38///             transform: ObsTransform::Normalize { min: 0.0, max: 100.0 },
39///             dtype: ObsDtype::F32,
40///         },
41///     ],
42/// };
43///
44/// assert_eq!(spec.entries.len(), 2);
45/// assert_eq!(spec.entries[0].field_id, FieldId(0));
46/// ```
47#[derive(Clone, Debug, PartialEq)]
48pub struct ObsSpec {
49    /// Ordered observation entries.
50    pub entries: Vec<ObsEntry>,
51}
52
53/// Observation region — how to select spatial cells for an entry.
54///
55/// `Fixed` regions are resolved at plan-compile time (like the existing
56/// `RegionSpec`). `AgentDisk` and `AgentRect` are resolved at execute
57/// time relative to each agent's position (foveation).
58#[derive(Clone, Debug, PartialEq)]
59pub enum ObsRegion {
60    /// Absolute region, compiled at plan-compile time.
61    Fixed(RegionSpec),
62    /// Disk centered on the agent, resolved at execute time.
63    AgentDisk {
64        /// Maximum graph distance from agent center (inclusive).
65        radius: u32,
66    },
67    /// Axis-aligned rectangle centered on the agent, resolved at execute time.
68    AgentRect {
69        /// Half-extent per dimension (the full extent is `2 * half_extent + 1`).
70        half_extent: SmallVec<[u32; 4]>,
71    },
72}
73
74impl From<RegionSpec> for ObsRegion {
75    fn from(spec: RegionSpec) -> Self {
76        ObsRegion::Fixed(spec)
77    }
78}
79
/// Pooling kernel type for spatial downsampling.
///
/// Each variant reduces the *valid* cells of one pooling window to a single
/// value.
///
/// `Hash` is derived alongside `Eq` so that kernels can be used as keys in
/// `HashMap`/`HashSet`, in line with the usual common-trait derives for
/// small public enums.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PoolKernel {
    /// Average of valid cells in the window.
    Mean,
    /// Maximum of valid cells in the window.
    Max,
    /// Minimum of valid cells in the window.
    Min,
    /// Sum of valid cells in the window.
    Sum,
}
92
93/// Configuration for spatial pooling applied after gather.
94///
95/// # Examples
96///
97/// ```
98/// use murk_obs::{PoolConfig, PoolKernel};
99///
100/// let pool = PoolConfig {
101///     kernel: PoolKernel::Mean,
102///     kernel_size: 3,
103///     stride: 2,
104/// };
105///
106/// assert_eq!(pool.kernel, PoolKernel::Mean);
107/// assert_eq!(pool.kernel_size, 3);
108/// assert_eq!(pool.stride, 2);
109/// ```
110#[derive(Clone, Debug, PartialEq, Eq)]
111pub struct PoolConfig {
112    /// Pooling kernel type.
113    pub kernel: PoolKernel,
114    /// Window size (applies to all spatial dimensions).
115    pub kernel_size: usize,
116    /// Stride between windows (applies to all spatial dimensions).
117    pub stride: usize,
118}
119
120/// A single observation entry targeting one field over a spatial region.
121///
122/// # Examples
123///
124/// ```
125/// use murk_obs::{ObsEntry, ObsDtype, ObsTransform, ObsRegion};
126/// use murk_core::FieldId;
127/// use murk_space::RegionSpec;
128///
129/// let entry = ObsEntry {
130///     field_id: FieldId(0),
131///     region: RegionSpec::All.into(),
132///     pool: None,
133///     transform: ObsTransform::Identity,
134///     dtype: ObsDtype::F32,
135/// };
136///
137/// assert_eq!(entry.field_id, FieldId(0));
138/// assert!(entry.pool.is_none());
139/// assert!(matches!(entry.region, ObsRegion::Fixed(RegionSpec::All)));
140/// ```
141#[derive(Clone, Debug, PartialEq)]
142pub struct ObsEntry {
143    /// Which simulation field to observe.
144    pub field_id: FieldId,
145    /// Spatial region to gather from.
146    pub region: ObsRegion,
147    /// Optional spatial pooling applied after gather, before transform.
148    pub pool: Option<PoolConfig>,
149    /// Transform to apply to raw field values (element-wise, after pooling).
150    pub transform: ObsTransform,
151    /// Output data type.
152    pub dtype: ObsDtype,
153}
154
/// Element-wise transform applied to raw field values before output.
///
/// Only `Identity` and `Normalize` exist in v1; further transforms are
/// planned for v1.5 and later.
///
/// # Examples
///
/// ```
/// use murk_obs::ObsTransform;
///
/// let passthrough = ObsTransform::Identity;
/// let unit = ObsTransform::Normalize { min: 0.0, max: 1.0 };
///
/// if let ObsTransform::Normalize { min, max } = unit {
///     assert!(min < max);
/// }
/// assert_ne!(passthrough, unit);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub enum ObsTransform {
    /// Leave every value untouched.
    Identity,
    /// Map the input range `[min, max]` linearly onto `[0, 1]`.
    ///
    /// Inputs outside the range are clamped. A degenerate range
    /// (`min == max`) yields 0.0 for every input.
    Normalize {
        /// Lower end of the input range; maps to 0.0.
        min: f64,
        /// Upper end of the input range; maps to 1.0.
        max: f64,
    },
}
186
/// Element type of the emitted observation values.
///
/// v1 offers `F32` only; `F16` and `U8` are planned for v1.5+.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ObsDtype {
    /// 32-bit floating-point value.
    F32,
}