murk_obs/spec.rs
1//! Observation specification types.
2//!
3//! An [`ObsSpec`] defines how to extract flat observation tensors from
4//! simulation state. Each [`ObsEntry`] targets one field, specifying
5//! the spatial region to observe, the transform to apply, and the
6//! output data type.
7
8use murk_core::FieldId;
9use murk_space::RegionSpec;
10use smallvec::SmallVec;
11
12/// Specification for observation extraction.
13///
14/// An `ObsSpec` is a list of entries, each describing one slice of the
15/// output tensor. Entries are gathered in order: entry 0 fills the first
16/// `N_0` elements, entry 1 fills the next `N_1`, etc.
17///
18/// # Examples
19///
20/// ```
21/// use murk_obs::{ObsSpec, ObsEntry, ObsDtype, ObsTransform, ObsRegion};
22/// use murk_core::FieldId;
23/// use murk_space::RegionSpec;
24///
25/// let spec = ObsSpec {
26/// entries: vec![
27/// ObsEntry {
28/// field_id: FieldId(0),
29/// region: ObsRegion::Fixed(RegionSpec::All),
30/// pool: None,
31/// transform: ObsTransform::Identity,
32/// dtype: ObsDtype::F32,
33/// },
34/// ObsEntry {
35/// field_id: FieldId(1),
36/// region: ObsRegion::AgentDisk { radius: 3 },
37/// pool: None,
38/// transform: ObsTransform::Normalize { min: 0.0, max: 100.0 },
39/// dtype: ObsDtype::F32,
40/// },
41/// ],
42/// };
43///
44/// assert_eq!(spec.entries.len(), 2);
45/// assert_eq!(spec.entries[0].field_id, FieldId(0));
46/// ```
47#[derive(Clone, Debug, PartialEq)]
48pub struct ObsSpec {
49 /// Ordered observation entries.
50 pub entries: Vec<ObsEntry>,
51}
52
53/// Observation region — how to select spatial cells for an entry.
54///
55/// `Fixed` regions are resolved at plan-compile time (like the existing
56/// `RegionSpec`). `AgentDisk` and `AgentRect` are resolved at execute
57/// time relative to each agent's position (foveation).
58#[derive(Clone, Debug, PartialEq)]
59pub enum ObsRegion {
60 /// Absolute region, compiled at plan-compile time.
61 Fixed(RegionSpec),
62 /// Disk centered on the agent, resolved at execute time.
63 AgentDisk {
64 /// Maximum graph distance from agent center (inclusive).
65 radius: u32,
66 },
67 /// Axis-aligned rectangle centered on the agent, resolved at execute time.
68 AgentRect {
69 /// Half-extent per dimension (the full extent is `2 * half_extent + 1`).
70 half_extent: SmallVec<[u32; 4]>,
71 },
72}
73
74impl From<RegionSpec> for ObsRegion {
75 fn from(spec: RegionSpec) -> Self {
76 ObsRegion::Fixed(spec)
77 }
78}
79
/// Reduction applied to each pooling window during spatial downsampling.
///
/// Every kernel operates on the valid cells of the window only.
///
/// The type is a plain fieldless enum, so it is `Copy`, `Eq` and `Hash`
/// and can be used directly as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PoolKernel {
    /// Average of valid cells in the window.
    Mean,
    /// Maximum of valid cells in the window.
    Max,
    /// Minimum of valid cells in the window.
    Min,
    /// Sum of valid cells in the window.
    Sum,
}
92
93/// Configuration for spatial pooling applied after gather.
94///
95/// # Examples
96///
97/// ```
98/// use murk_obs::{PoolConfig, PoolKernel};
99///
100/// let pool = PoolConfig {
101/// kernel: PoolKernel::Mean,
102/// kernel_size: 3,
103/// stride: 2,
104/// };
105///
106/// assert_eq!(pool.kernel, PoolKernel::Mean);
107/// assert_eq!(pool.kernel_size, 3);
108/// assert_eq!(pool.stride, 2);
109/// ```
110#[derive(Clone, Debug, PartialEq, Eq)]
111pub struct PoolConfig {
112 /// Pooling kernel type.
113 pub kernel: PoolKernel,
114 /// Window size (applies to all spatial dimensions).
115 pub kernel_size: usize,
116 /// Stride between windows (applies to all spatial dimensions).
117 pub stride: usize,
118}
119
120/// A single observation entry targeting one field over a spatial region.
121///
122/// # Examples
123///
124/// ```
125/// use murk_obs::{ObsEntry, ObsDtype, ObsTransform, ObsRegion};
126/// use murk_core::FieldId;
127/// use murk_space::RegionSpec;
128///
129/// let entry = ObsEntry {
130/// field_id: FieldId(0),
131/// region: RegionSpec::All.into(),
132/// pool: None,
133/// transform: ObsTransform::Identity,
134/// dtype: ObsDtype::F32,
135/// };
136///
137/// assert_eq!(entry.field_id, FieldId(0));
138/// assert!(entry.pool.is_none());
139/// assert!(matches!(entry.region, ObsRegion::Fixed(RegionSpec::All)));
140/// ```
141#[derive(Clone, Debug, PartialEq)]
142pub struct ObsEntry {
143 /// Which simulation field to observe.
144 pub field_id: FieldId,
145 /// Spatial region to gather from.
146 pub region: ObsRegion,
147 /// Optional spatial pooling applied after gather, before transform.
148 pub pool: Option<PoolConfig>,
149 /// Transform to apply to raw field values (element-wise, after pooling).
150 pub transform: ObsTransform,
151 /// Output data type.
152 pub dtype: ObsDtype,
153}
154
/// Value transform applied to field values before they are written out.
///
/// Only `Identity` and `Normalize` exist in v1; further transforms are
/// deferred to v1.5+.
///
/// # Examples
///
/// ```
/// use murk_obs::ObsTransform;
///
/// let t = ObsTransform::Normalize { min: 0.0, max: 1.0 };
/// assert!(matches!(t, ObsTransform::Normalize { min, max } if max > min));
///
/// let identity = ObsTransform::Identity;
/// assert_ne!(identity, t);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub enum ObsTransform {
    /// Leave values exactly as they are.
    Identity,
    /// Map the input range `[min, max]` linearly onto `[0, 1]`.
    ///
    /// Inputs that fall outside the range are clamped. When `min == max`,
    /// every output is 0.0.
    Normalize {
        /// Lower end of the input range.
        min: f64,
        /// Upper end of the input range.
        max: f64,
    },
}
186
/// Data type of the values written to the observation output.
///
/// v1 supports only `F32`. `F16` and `U8` are deferred to v1.5+.
///
/// The type is a plain fieldless enum, so it is `Copy`, `Eq` and `Hash`
/// and can be used directly as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ObsDtype {
    /// 32-bit float.
    F32,
}