// plato_vision_jepa/lib.rs

//! # plato-vision-jepa
//!
//! Vision JEPA for the PLATO nervous system. Processes camera frames into
//! structured room state vectors suitable for downstream nervous-system tiles.

use serde::{Deserialize, Serialize};
use uuid::Uuid;

// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------

13/// Structured room state produced by the vision JEPA from a single frame.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct VisionTile {
16    pub id: Uuid,
17    pub brightness: f32,
18    pub occupancy: f32,
19    pub motion_level: f32,
20    pub object_count: u32,
21    pub anomalies_detected: u32,
22    pub timestamp: u64,
23}
24
25/// Deadband filter — only pass frames whose histogram diff exceeds a threshold.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct VisionDeadband {
28    pub threshold: f64,
29    pub last_histogram: Option<Vec<u8>>,
30}
31
32impl Default for VisionDeadband {
33    fn default() -> Self {
34        Self {
35            threshold: 0.05,
36            last_histogram: None,
37        }
38    }
39}
40
41impl VisionDeadband {
42    pub fn new(threshold: f64) -> Self {
43        Self {
44            threshold,
45            last_histogram: None,
46        }
47    }
48
49    /// Returns `true` if the new histogram represents a significant change.
50    pub fn should_process(&mut self, histogram: &[u8; 256]) -> bool {
51        let significant = match self.last_histogram {
52            None => true,
53            Some(ref prev) => {
54                let prev_arr: [u8; 256] = prev.clone().try_into().unwrap_or([0u8; 256]);
55                is_significant_change(compute_frame_diff(&prev_arr, histogram), self.threshold)
56            }
57        };
58        if significant {
59            self.last_histogram = Some(histogram.to_vec());
60        }
61        significant
62    }
63}
64
65/// 16-dimensional vision state vector for a room.
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct RoomVisionState {
68    pub brightness: f32,
69    pub motion_level: f32,
70    pub occupancy: f32,
71    pub anomaly_score: f32,
72    pub region_states: [f32; 4],
73    pub temporal_patterns: [f32; 4],
74    pub reserved: [f32; 4],
75}
76
77impl Default for RoomVisionState {
78    fn default() -> Self {
79        Self {
80            brightness: 0.0,
81            motion_level: 0.0,
82            occupancy: 0.0,
83            anomaly_score: 0.0,
84            region_states: [0.0; 4],
85            temporal_patterns: [0.0; 4],
86            reserved: [0.0; 4],
87        }
88    }
89}
90
91impl RoomVisionState {
92    pub fn to_vector(&self) -> [f32; 16] {
93        let mut v = [0.0f32; 16];
94        v[0] = self.brightness;
95        v[1] = self.motion_level;
96        v[2] = self.occupancy;
97        v[3] = self.anomaly_score;
98        v[4..8].copy_from_slice(&self.region_states);
99        v[8..12].copy_from_slice(&self.temporal_patterns);
100        v[12..16].copy_from_slice(&self.reserved);
101        v
102    }
103
104    pub fn from_vector(v: &[f32; 16]) -> Self {
105        let mut region = [0.0f32; 4];
106        let mut temporal = [0.0f32; 4];
107        let mut reserved = [0.0f32; 4];
108        region.copy_from_slice(&v[4..8]);
109        temporal.copy_from_slice(&v[8..12]);
110        reserved.copy_from_slice(&v[12..16]);
111        Self {
112            brightness: v[0],
113            motion_level: v[1],
114            occupancy: v[2],
115            anomaly_score: v[3],
116            region_states: region,
117            temporal_patterns: temporal,
118            reserved,
119        }
120    }
121}
122
// ---------------------------------------------------------------------------
// Functions
// ---------------------------------------------------------------------------

/// Compute histogram intersection distance between two frames.
///
/// The bin-wise minimum of the two histograms is summed and divided by the
/// larger of the two histogram totals; the distance is one minus that ratio.
///
/// Returns a value in [0, 1] where 0 = identical, 1 = maximally different
/// (the histograms share no mass). Two all-zero histograms yield `0.0`.
#[must_use]
pub fn compute_frame_diff(prev_histogram: &[u8; 256], curr_histogram: &[u8; 256]) -> f64 {
    // Accumulate in integers: 256 bins of at most 255 each cannot overflow a
    // `u32`, and the later conversion to `f64` is exact.
    let (overlap, prev_total, curr_total) = prev_histogram.iter().zip(curr_histogram).fold(
        (0u32, 0u32, 0u32),
        |(overlap, prev_sum, curr_sum), (&p, &c)| {
            (
                overlap + u32::from(p.min(c)),
                prev_sum + u32::from(p),
                curr_sum + u32::from(c),
            )
        },
    );

    let total = prev_total.max(curr_total);
    if total == 0 {
        // Both histograms are empty: nothing to compare, report "no change".
        return 0.0;
    }

    1.0 - f64::from(overlap) / f64::from(total)
}

/// Determine whether a frame diff exceeds the significance threshold.
///
/// The comparison is strict: a `diff` exactly equal to `threshold` is not
/// significant. If either argument is NaN the result is `false`.
#[must_use]
pub fn is_significant_change(diff: f64, threshold: f64) -> bool {
    diff > threshold
}

/// Divide an image grid into 4 quadrants and return average intensity of each
/// normalized to [0, 1].
///
/// `grid` is indexed as `grid[row][col]`. Quadrant indices in the result are:
/// 0 = first half of rows, first half of columns; 1 = first half of rows,
/// second half of columns; 2 = second half of rows, first half of columns;
/// 3 = second half of rows, second half of columns.
///
/// The row midpoint is `grid.len() / 2` and the column midpoint is computed
/// per row as `row.len() / 2`, so ragged rows are accepted. With an odd size,
/// the middle row or column belongs to the second half; a single-row grid
/// therefore only populates quadrants 2 and 3. A quadrant that receives no
/// cells (including every quadrant of an empty grid) is reported as `0.0`.
#[must_use]
pub fn extract_region_states(grid: &[Vec<u8>]) -> [f32; 4] {
    let mid_row = grid.len() / 2;
    let mut sums = [0.0f64; 4];
    let mut counts = [0usize; 4];

    for (r, row) in grid.iter().enumerate() {
        let mid_col = row.len() / 2;
        // Quadrants 0/1 hold the first half of the rows, 2/3 the second half.
        let row_base = if r < mid_row { 0 } else { 2 };
        for (c, &val) in row.iter().enumerate() {
            let quadrant = row_base + usize::from(c >= mid_col);
            sums[quadrant] += f64::from(val);
            counts[quadrant] += 1;
        }
    }

    let mut result = [0.0f32; 4];
    for ((out, &sum), &count) in result.iter_mut().zip(&sums).zip(&counts) {
        if count > 0 {
            // Mean intensity, then scale the 0..=255 range down to 0..=1.
            *out = (sum / count as f64 / 255.0) as f32;
        }
    }
    result
}

/// Compute average motion vector from tracked point positions.
///
/// Points are paired by index: element `i` of `curr_positions` is taken to be
/// the new position of element `i` of `prev_positions`. If the slices differ
/// in length only the pairs up to the shorter length are used. Returns the
/// mean `(dx, dy)` displacement over those pairs, or `(0.0, 0.0)` when either
/// slice is empty.
#[must_use]
pub fn compute_motion_vector(
    prev_positions: &[(f32, f32)],
    curr_positions: &[(f32, f32)],
) -> (f32, f32) {
    let n = prev_positions.len().min(curr_positions.len());
    if n == 0 {
        return (0.0, 0.0);
    }

    // `zip` stops at the shorter slice, so exactly `n` pairs are summed.
    let (dx, dy) = prev_positions.iter().zip(curr_positions).fold(
        (0.0f32, 0.0f32),
        |(dx, dy), (prev, curr)| (dx + (curr.0 - prev.0), dy + (curr.1 - prev.1)),
    );

    (dx / n as f32, dy / n as f32)
}

215/// Convert a RoomVisionState into a VisionTile.
216pub fn vision_state_to_tile(state: &RoomVisionState) -> VisionTile {
217    VisionTile {
218        id: Uuid::new_v4(),
219        brightness: state.brightness,
220        occupancy: state.occupancy,
221        motion_level: state.motion_level,
222        object_count: state.occupancy.round() as u32,
223        anomalies_detected: if state.anomaly_score > 0.5 { 1 } else { 0 },
224        timestamp: 0,
225    }
226}