Skip to main content

peat_schema/validation/
tasking.rs

1//! Tasking validators (AI/ML Detection Tasks)
2//!
3//! Validates DetectionTask messages and related configuration for Peat Protocol.
4
5use super::{ValidationError, ValidationResult};
6use crate::tasking::v1::{
7    BatchingConfig, ChipoutConfig, DetectionFilter, DetectionTask, ProductDelivery, TaskControl,
8    TaskControlAction, TaskPriority, TaskState, TaskStatistics, TaskStatus, TrackReportMode,
9    TrackReportingConfig,
10};
11
12/// Validate a DetectionTask message
13///
14/// Validates:
15/// - task_id is present
16/// - name is present
17/// - filter parameters are valid
18/// - product delivery configuration is valid
19/// - schedule is valid
20pub fn validate_detection_task(task: &DetectionTask) -> ValidationResult<()> {
21    // Check required fields
22    if task.task_id.is_empty() {
23        return Err(ValidationError::MissingField("task_id".to_string()));
24    }
25
26    if task.name.is_empty() {
27        return Err(ValidationError::MissingField("name".to_string()));
28    }
29
30    // Priority must be specified
31    if task.priority == TaskPriority::Unspecified as i32 {
32        return Err(ValidationError::InvalidValue(
33            "priority must be specified".to_string(),
34        ));
35    }
36
37    // Timestamp is required
38    if task.issued_at.is_none() {
39        return Err(ValidationError::MissingField("issued_at".to_string()));
40    }
41
42    // issued_by is required
43    if task.issued_by.is_empty() {
44        return Err(ValidationError::MissingField("issued_by".to_string()));
45    }
46
47    // Validate filter if present
48    if let Some(ref filter) = task.filter {
49        validate_detection_filter(filter)?;
50    }
51
52    // Validate product delivery if present
53    if let Some(ref delivery) = task.product_delivery {
54        validate_product_delivery(delivery)?;
55    }
56
57    Ok(())
58}
59
60/// Validate DetectionFilter parameters
61pub fn validate_detection_filter(filter: &DetectionFilter) -> ValidationResult<()> {
62    // min_confidence must be in valid range
63    if filter.min_confidence < 0.0 || filter.min_confidence > 1.0 {
64        return Err(ValidationError::InvalidConfidence(filter.min_confidence));
65    }
66
67    // min_report_interval must be non-negative
68    if filter.min_report_interval_s < 0.0 {
69        return Err(ValidationError::InvalidValue(
70            "min_report_interval_s must be non-negative".to_string(),
71        ));
72    }
73
74    Ok(())
75}
76
77/// Validate ProductDelivery configuration
78pub fn validate_product_delivery(delivery: &ProductDelivery) -> ValidationResult<()> {
79    // Validate chipout config if present
80    if let Some(ref chipout) = delivery.chipout_config {
81        validate_chipout_config(chipout)?;
82    }
83
84    // Validate track reporting config if present
85    if let Some(ref track_reporting) = delivery.track_reporting {
86        validate_track_reporting_config(track_reporting)?;
87    }
88
89    // Validate batching config if present
90    if let Some(ref batching) = delivery.batching {
91        validate_batching_config(batching)?;
92    }
93
94    Ok(())
95}
96
97/// Validate ChipoutConfig
98pub fn validate_chipout_config(config: &ChipoutConfig) -> ValidationResult<()> {
99    // JPEG quality should be in range 1-100
100    if config.jpeg_quality > 100 {
101        return Err(ValidationError::InvalidValue(format!(
102            "jpeg_quality {} must be between 1 and 100",
103            config.jpeg_quality
104        )));
105    }
106
107    // Padding percent should be reasonable (0-100%)
108    if config.padding_percent < 0.0 || config.padding_percent > 1.0 {
109        return Err(ValidationError::InvalidValue(format!(
110            "padding_percent {} must be between 0.0 and 1.0",
111            config.padding_percent
112        )));
113    }
114
115    // Full frame quality should also be valid
116    if config.full_frame_jpeg_quality > 100 {
117        return Err(ValidationError::InvalidValue(format!(
118            "full_frame_jpeg_quality {} must be between 1 and 100",
119            config.full_frame_jpeg_quality
120        )));
121    }
122
123    Ok(())
124}
125
126/// Validate TrackReportingConfig
127pub fn validate_track_reporting_config(config: &TrackReportingConfig) -> ValidationResult<()> {
128    // Mode must be specified
129    if config.mode == TrackReportMode::Unspecified as i32 {
130        return Err(ValidationError::InvalidValue(
131            "track reporting mode must be specified".to_string(),
132        ));
133    }
134
135    // Position change threshold must be non-negative
136    if config.min_position_change_m < 0.0 {
137        return Err(ValidationError::InvalidValue(
138            "min_position_change_m must be non-negative".to_string(),
139        ));
140    }
141
142    // Confidence change threshold must be valid
143    if config.min_confidence_change < 0.0 || config.min_confidence_change > 1.0 {
144        return Err(ValidationError::InvalidConfidence(
145            config.min_confidence_change,
146        ));
147    }
148
149    // Max report interval must be non-negative
150    if config.max_report_interval_s < 0.0 {
151        return Err(ValidationError::InvalidValue(
152            "max_report_interval_s must be non-negative".to_string(),
153        ));
154    }
155
156    Ok(())
157}
158
159/// Validate BatchingConfig
160pub fn validate_batching_config(config: &BatchingConfig) -> ValidationResult<()> {
161    // Max batch delay must be non-negative
162    if config.max_batch_delay_s < 0.0 {
163        return Err(ValidationError::InvalidValue(
164            "max_batch_delay_s must be non-negative".to_string(),
165        ));
166    }
167
168    Ok(())
169}
170
171/// Validate TaskStatus message
172pub fn validate_task_status(status: &TaskStatus) -> ValidationResult<()> {
173    // task_id is required
174    if status.task_id.is_empty() {
175        return Err(ValidationError::MissingField("task_id".to_string()));
176    }
177
178    // platform_id is required
179    if status.platform_id.is_empty() {
180        return Err(ValidationError::MissingField("platform_id".to_string()));
181    }
182
183    // State must be specified
184    if status.state == TaskState::Unspecified as i32 {
185        return Err(ValidationError::InvalidValue(
186            "task state must be specified".to_string(),
187        ));
188    }
189
190    // Timestamp is required
191    if status.updated_at.is_none() {
192        return Err(ValidationError::MissingField("updated_at".to_string()));
193    }
194
195    // Validate statistics if present
196    if let Some(ref stats) = status.statistics {
197        validate_task_statistics(stats)?;
198    }
199
200    Ok(())
201}
202
203/// Validate TaskStatistics
204pub fn validate_task_statistics(stats: &TaskStatistics) -> ValidationResult<()> {
205    // reported_detections should not exceed total_detections
206    if stats.reported_detections > stats.total_detections {
207        return Err(ValidationError::ConstraintViolation(
208            "reported_detections cannot exceed total_detections".to_string(),
209        ));
210    }
211
212    // Average values must be non-negative
213    if stats.avg_inference_time_ms < 0.0 {
214        return Err(ValidationError::InvalidValue(
215            "avg_inference_time_ms must be non-negative".to_string(),
216        ));
217    }
218
219    if stats.avg_fps < 0.0 {
220        return Err(ValidationError::InvalidValue(
221            "avg_fps must be non-negative".to_string(),
222        ));
223    }
224
225    if stats.uptime_s < 0.0 {
226        return Err(ValidationError::InvalidValue(
227            "uptime_s must be non-negative".to_string(),
228        ));
229    }
230
231    Ok(())
232}
233
234/// Validate TaskControl message
235pub fn validate_task_control(control: &TaskControl) -> ValidationResult<()> {
236    // task_id is required
237    if control.task_id.is_empty() {
238        return Err(ValidationError::MissingField("task_id".to_string()));
239    }
240
241    // Action must be specified
242    if control.action == TaskControlAction::TaskControlUnspecified as i32 {
243        return Err(ValidationError::InvalidValue(
244            "task control action must be specified".to_string(),
245        ));
246    }
247
248    // issued_by is required
249    if control.issued_by.is_empty() {
250        return Err(ValidationError::MissingField("issued_by".to_string()));
251    }
252
253    // Timestamp is required
254    if control.issued_at.is_none() {
255        return Err(ValidationError::MissingField("issued_at".to_string()));
256    }
257
258    // UPDATE action requires updated_task
259    if control.action == TaskControlAction::TaskControlUpdate as i32 {
260        if control.updated_task.is_none() {
261            return Err(ValidationError::MissingField(
262                "updated_task (required for UPDATE action)".to_string(),
263            ));
264        }
265        // Validate the updated task
266        if let Some(ref task) = control.updated_task {
267            validate_detection_task(task)?;
268        }
269    }
270
271    Ok(())
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::v1::Timestamp;

    /// Fixed timestamp shared by every fixture in this module.
    fn fixed_timestamp() -> Timestamp {
        Timestamp {
            seconds: 1702000000,
            nanos: 0,
        }
    }

    /// A fully-populated task that passes `validate_detection_task`.
    fn valid_detection_task() -> DetectionTask {
        let filter = DetectionFilter {
            min_confidence: 0.7,
            priority_classes: vec!["boat".to_string()],
            ignore_classes: vec![],
            min_bbox_area: 100,
            max_detections_per_frame: 10,
            min_report_interval_s: 1.0,
        };

        DetectionTask {
            task_id: "TASK-001".to_string(),
            name: "Maritime Detection".to_string(),
            description: "Detect boats in harbor area".to_string(),
            target_classes: vec!["boat".to_string(), "person".to_string()],
            filter: Some(filter),
            product_delivery: None,
            area_of_interest: None,
            schedule: None,
            priority: TaskPriority::Normal as i32,
            issued_by: "C2-WebTAK".to_string(),
            issued_at: Some(fixed_timestamp()),
            target_platforms: vec![],
        }
    }

    /// An ACTIVE status report for TASK-001 on Alpha-3 carrying `statistics`.
    fn active_status(statistics: TaskStatistics) -> TaskStatus {
        TaskStatus {
            task_id: "TASK-001".to_string(),
            platform_id: "Alpha-3".to_string(),
            state: TaskState::Active as i32,
            statistics: Some(statistics),
            error_message: String::new(),
            updated_at: Some(fixed_timestamp()),
        }
    }

    #[test]
    fn test_valid_detection_task() {
        assert!(validate_detection_task(&valid_detection_task()).is_ok());
    }

    #[test]
    fn test_missing_task_id() {
        let task = DetectionTask {
            task_id: String::new(),
            ..valid_detection_task()
        };
        let err = validate_detection_task(&task).unwrap_err();
        assert!(matches!(err, ValidationError::MissingField(f) if f == "task_id"));
    }

    #[test]
    fn test_missing_name() {
        let task = DetectionTask {
            name: String::new(),
            ..valid_detection_task()
        };
        let err = validate_detection_task(&task).unwrap_err();
        assert!(matches!(err, ValidationError::MissingField(f) if f == "name"));
    }

    #[test]
    fn test_unspecified_priority() {
        let task = DetectionTask {
            priority: TaskPriority::Unspecified as i32,
            ..valid_detection_task()
        };
        let err = validate_detection_task(&task).unwrap_err();
        assert!(matches!(err, ValidationError::InvalidValue(_)));
    }

    #[test]
    fn test_invalid_confidence_filter() {
        let task = DetectionTask {
            filter: Some(DetectionFilter {
                min_confidence: 1.5, // Invalid
                ..Default::default()
            }),
            ..valid_detection_task()
        };
        let err = validate_detection_task(&task).unwrap_err();
        assert!(matches!(err, ValidationError::InvalidConfidence(_)));
    }

    #[test]
    fn test_valid_task_status() {
        let status = active_status(TaskStatistics {
            frames_processed: 1000,
            total_detections: 50,
            reported_detections: 45,
            tracks_created: 10,
            tracks_active: 5,
            chipouts_generated: 20,
            products_sent: 65,
            avg_inference_time_ms: 25.0,
            avg_fps: 30.0,
            uptime_s: 3600.0,
        });
        assert!(validate_task_status(&status).is_ok());
    }

    #[test]
    fn test_invalid_statistics() {
        let status = active_status(TaskStatistics {
            frames_processed: 1000,
            total_detections: 50,
            reported_detections: 100, // Invalid: exceeds total
            ..Default::default()
        });
        let err = validate_task_status(&status).unwrap_err();
        assert!(matches!(err, ValidationError::ConstraintViolation(_)));
    }
}