Skip to main content

peat_protocol/discovery/
capability_query.rs

1//! Capability-based queries for platform and squad discovery
2//!
3//! Implements the capability query system for finding nodes and squads
4//! based on required capabilities during the bootstrap phase.
5//!
6//! # Architecture
7//!
8//! The capability query system allows C2 or nodes to discover other entities
9//! based on capability requirements:
10//!
11//! ## Query Types
12//!
13//! - **Type-based**: Find entities with specific capability types (Sensor, Compute, etc.)
14//! - **Confidence-based**: Filter by minimum confidence thresholds
15//! - **Combination queries**: Require multiple capabilities (AND logic)
16//! - **Ranked results**: Score and rank matches by relevance
17//!
18//! ## Use Cases
19//!
20//! - **Mission planning**: Find nodes with required sensor capabilities
21//! - **Cell formation**: Form cells with complementary capabilities
22//! - **Resource discovery**: Locate available compute/comms resources
23//! - **Redundancy**: Find backup nodes with similar capabilities
24//!
25//! ## Example
26//!
27//! ```rust,ignore
28//! // Find nodes with sensor AND communication capabilities
29//! let query = CapabilityQuery::builder()
30//!     .require_type(CapabilityType::Sensor)
31//!     .require_type(CapabilityType::Communication)
32//!     .min_confidence(0.8)
33//!     .build();
34//!
//! let matches = engine.query_platforms(&query, &nodes);
36//! ```
37
38use crate::models::{cell::CellState, node::NodeConfig, Capability, CapabilityExt, CapabilityType};
39use serde::{Deserialize, Serialize};
40use std::collections::HashMap;
41
42/// Capability query for finding nodes or squads
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct CapabilityQuery {
45    /// Required capability types (AND logic - all must be present)
46    pub required_types: Vec<CapabilityType>,
47    /// Optional capability types (OR logic - any can be present for bonus score)
48    pub optional_types: Vec<CapabilityType>,
49    /// Minimum confidence threshold (0.0 - 1.0)
50    pub min_confidence: f32,
51    /// Minimum number of total capabilities
52    pub min_capability_count: Option<usize>,
53    /// Maximum results to return
54    pub limit: Option<usize>,
55}
56
57impl CapabilityQuery {
58    /// Create a new query builder
59    pub fn builder() -> CapabilityQueryBuilder {
60        CapabilityQueryBuilder::new()
61    }
62
63    /// Check if a set of capabilities satisfies this query
64    pub fn matches(&self, capabilities: &[Capability]) -> bool {
65        // Check minimum capability count
66        if let Some(min_count) = self.min_capability_count {
67            if capabilities.len() < min_count {
68                return false;
69            }
70        }
71
72        // Check all required types are present with sufficient confidence
73        for required_type in &self.required_types {
74            let has_type = capabilities.iter().any(|cap| {
75                cap.get_capability_type() == *required_type && cap.confidence >= self.min_confidence
76            });
77
78            if !has_type {
79                return false;
80            }
81        }
82
83        true
84    }
85
86    /// Calculate a relevance score for a set of capabilities (0.0 - 1.0)
87    ///
88    /// Score components:
89    /// - Required types: 0.6 weight (normalized by count)
90    /// - Optional types: 0.3 weight (normalized by count)
91    /// - Average confidence: 0.1 weight
92    pub fn score(&self, capabilities: &[Capability]) -> f32 {
93        if capabilities.is_empty() {
94            return 0.0;
95        }
96
97        let mut score = 0.0;
98
99        // Required types score (60% weight)
100        if !self.required_types.is_empty() {
101            let required_score: f32 = self
102                .required_types
103                .iter()
104                .map(|req_type| {
105                    capabilities
106                        .iter()
107                        .filter(|cap| cap.get_capability_type() == *req_type)
108                        .map(|cap| cap.confidence)
109                        .max_by(|a, b| a.partial_cmp(b).unwrap())
110                        .unwrap_or(0.0)
111                })
112                .sum::<f32>()
113                / self.required_types.len() as f32;
114
115            score += required_score * 0.6;
116        } else {
117            // If no required types, give full weight
118            score += 0.6;
119        }
120
121        // Optional types score (30% weight)
122        if !self.optional_types.is_empty() {
123            let optional_score: f32 = self
124                .optional_types
125                .iter()
126                .map(|opt_type| {
127                    capabilities
128                        .iter()
129                        .filter(|cap| cap.get_capability_type() == *opt_type)
130                        .map(|cap| cap.confidence)
131                        .max_by(|a, b| a.partial_cmp(b).unwrap())
132                        .unwrap_or(0.0)
133                })
134                .sum::<f32>()
135                / self.optional_types.len() as f32;
136
137            score += optional_score * 0.3;
138        } else {
139            score += 0.3;
140        }
141
142        // Average confidence score (10% weight)
143        let avg_confidence: f32 =
144            capabilities.iter().map(|cap| cap.confidence).sum::<f32>() / capabilities.len() as f32;
145        score += avg_confidence * 0.1;
146
147        score.clamp(0.0, 1.0)
148    }
149}
150
151/// Builder for creating capability queries
152#[derive(Debug, Default)]
153pub struct CapabilityQueryBuilder {
154    required_types: Vec<CapabilityType>,
155    optional_types: Vec<CapabilityType>,
156    min_confidence: f32,
157    min_capability_count: Option<usize>,
158    limit: Option<usize>,
159}
160
161impl CapabilityQueryBuilder {
162    /// Create a new query builder
163    pub fn new() -> Self {
164        Self {
165            min_confidence: 0.0,
166            ..Default::default()
167        }
168    }
169
170    /// Add a required capability type
171    pub fn require_type(mut self, capability_type: CapabilityType) -> Self {
172        self.required_types.push(capability_type);
173        self
174    }
175
176    /// Add an optional capability type
177    pub fn prefer_type(mut self, capability_type: CapabilityType) -> Self {
178        self.optional_types.push(capability_type);
179        self
180    }
181
182    /// Set minimum confidence threshold
183    pub fn min_confidence(mut self, min_confidence: f32) -> Self {
184        self.min_confidence = min_confidence.clamp(0.0, 1.0);
185        self
186    }
187
188    /// Set minimum capability count
189    pub fn min_capability_count(mut self, count: usize) -> Self {
190        self.min_capability_count = Some(count);
191        self
192    }
193
194    /// Set maximum results limit
195    pub fn limit(mut self, limit: usize) -> Self {
196        self.limit = Some(limit);
197        self
198    }
199
200    /// Build the query
201    pub fn build(self) -> CapabilityQuery {
202        CapabilityQuery {
203            required_types: self.required_types,
204            optional_types: self.optional_types,
205            min_confidence: self.min_confidence,
206            min_capability_count: self.min_capability_count,
207            limit: self.limit,
208        }
209    }
210}
211
/// A single query result: a matched entity together with its relevance score.
#[derive(Debug, Clone, PartialEq)]
pub struct QueryMatch<T> {
    /// The matched entity (a platform/node or a squad/cell).
    pub entity: T,
    /// Relevance score from `CapabilityQuery::score`, normally within 0.0 - 1.0.
    pub score: f32,
}
220
/// Capability query engine for finding nodes and squads.
///
/// The engine holds no state (it is a unit struct), so it is cheap to copy and share.
#[derive(Debug, Clone, Copy)]
pub struct CapabilityQueryEngine;
223
224impl CapabilityQueryEngine {
225    /// Create a new query engine
226    pub fn new() -> Self {
227        Self
228    }
229
230    /// Query nodes by capabilities
231    pub fn query_platforms(
232        &self,
233        query: &CapabilityQuery,
234        nodes: &[NodeConfig],
235    ) -> Vec<QueryMatch<NodeConfig>> {
236        let mut matches: Vec<QueryMatch<NodeConfig>> = nodes
237            .iter()
238            .filter(|node| query.matches(&node.capabilities))
239            .map(|node| QueryMatch {
240                score: query.score(&node.capabilities),
241                entity: node.clone(),
242            })
243            .collect();
244
245        // Sort by score descending
246        matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
247
248        // Apply limit if specified
249        if let Some(limit) = query.limit {
250            matches.truncate(limit);
251        }
252
253        matches
254    }
255
256    /// Query cells by capabilities
257    pub fn query_squads(
258        &self,
259        query: &CapabilityQuery,
260        squads: &[CellState],
261    ) -> Vec<QueryMatch<CellState>> {
262        let mut matches: Vec<QueryMatch<CellState>> = squads
263            .iter()
264            .filter(|squad| query.matches(&squad.capabilities))
265            .map(|squad| QueryMatch {
266                score: query.score(&squad.capabilities),
267                entity: squad.clone(),
268            })
269            .collect();
270
271        // Sort by score descending
272        matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
273
274        // Apply limit if specified
275        if let Some(limit) = query.limit {
276            matches.truncate(limit);
277        }
278
279        matches
280    }
281
282    /// Get capability statistics for a set of platforms
283    pub fn platform_capability_stats(
284        &self,
285        nodes: &[NodeConfig],
286    ) -> HashMap<CapabilityType, CapabilityStats> {
287        let mut stats: HashMap<CapabilityType, Vec<f32>> = HashMap::new();
288
289        for node in nodes {
290            for capability in &node.capabilities {
291                stats
292                    .entry(capability.get_capability_type())
293                    .or_default()
294                    .push(capability.confidence);
295            }
296        }
297
298        stats
299            .into_iter()
300            .map(|(cap_type, confidences)| {
301                (cap_type, CapabilityStats::from_confidences(&confidences))
302            })
303            .collect()
304    }
305}
306
307impl Default for CapabilityQueryEngine {
308    fn default() -> Self {
309        Self::new()
310    }
311}
312
/// Statistical summary of the confidences reported for one capability type.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CapabilityStats {
    /// Number of confidence values summarized.
    pub count: usize,
    /// Mean confidence (0.0 when `count` is 0).
    pub avg_confidence: f32,
    /// Lowest confidence (0.0 when `count` is 0).
    pub min_confidence: f32,
    /// Highest confidence (0.0 when `count` is 0).
    pub max_confidence: f32,
}

impl CapabilityStats {
    /// Summarize a slice of confidence values.
    ///
    /// An empty slice yields all-zero statistics. `NaN` entries are skipped when picking
    /// the minimum and maximum, so they no longer panic. They still propagate into the
    /// average. Only an all-`NaN` slice yields `NaN` extremes.
    pub fn from_confidences(confidences: &[f32]) -> Self {
        let count = confidences.len();

        let avg_confidence = if count == 0 {
            0.0
        } else {
            confidences.iter().sum::<f32>() / count as f32
        };

        // `f32::min` / `f32::max` return the non-NaN operand when exactly one is NaN.
        let min_confidence = confidences.iter().copied().reduce(f32::min).unwrap_or(0.0);
        let max_confidence = confidences.iter().copied().reduce(f32::max).unwrap_or(0.0);

        Self {
            count,
            avg_confidence,
            min_confidence,
            max_confidence,
        }
    }
}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{CapabilityExt, NodeConfigExt};

    /// Build a capability with a descriptive name derived from its type.
    fn create_test_capability(id: &str, cap_type: CapabilityType, confidence: f32) -> Capability {
        Capability::new(
            id.to_string(),
            format!("{:?} capability", cap_type),
            cap_type,
            confidence,
        )
    }

    /// Build a node config with the given id, platform type, and capabilities.
    fn create_test_platform(
        id: &str,
        platform_type: &str,
        capabilities: Vec<Capability>,
    ) -> NodeConfig {
        let mut platform = NodeConfig::new(platform_type.to_string());
        platform.id = id.to_string();
        for cap in capabilities {
            platform.add_capability(cap);
        }
        platform
    }

    #[test]
    fn test_query_builder() {
        let query = CapabilityQuery::builder()
            .require_type(CapabilityType::Sensor)
            .require_type(CapabilityType::Communication)
            .min_confidence(0.8)
            .limit(10)
            .build();

        assert_eq!(query.required_types.len(), 2);
        assert!(query.optional_types.is_empty());
        assert_eq!(query.min_confidence, 0.8);
        assert_eq!(query.min_capability_count, None);
        assert_eq!(query.limit, Some(10));
    }

    #[test]
    fn test_min_confidence_is_clamped() {
        let high = CapabilityQuery::builder().min_confidence(1.5).build();
        assert_eq!(high.min_confidence, 1.0);

        let low = CapabilityQuery::builder().min_confidence(-0.5).build();
        assert_eq!(low.min_confidence, 0.0);
    }

    #[test]
    fn test_query_matches_required_types() {
        let query = CapabilityQuery::builder()
            .require_type(CapabilityType::Sensor)
            .require_type(CapabilityType::Communication)
            .min_confidence(0.7)
            .build();

        // Node with both required capabilities
        let caps1 = vec![
            create_test_capability("sensor1", CapabilityType::Sensor, 0.9),
            create_test_capability("comms1", CapabilityType::Communication, 0.8),
        ];
        assert!(query.matches(&caps1));

        // Node missing one required capability
        let caps2 = vec![create_test_capability(
            "sensor1",
            CapabilityType::Sensor,
            0.9,
        )];
        assert!(!query.matches(&caps2));

        // Node whose communication capability is below the confidence threshold
        let caps3 = vec![
            create_test_capability("sensor1", CapabilityType::Sensor, 0.9),
            create_test_capability("comms1", CapabilityType::Communication, 0.5),
        ];
        assert!(!query.matches(&caps3));
    }

    #[test]
    fn test_optional_types_do_not_affect_matching() {
        let query = CapabilityQuery::builder()
            .require_type(CapabilityType::Sensor)
            .prefer_type(CapabilityType::Communication)
            .build();

        // Missing the preferred type must not prevent a match.
        let caps = vec![create_test_capability(
            "sensor1",
            CapabilityType::Sensor,
            0.9,
        )];
        assert!(query.matches(&caps));
    }

    #[test]
    fn test_query_matches_min_capability_count() {
        let query = CapabilityQuery::builder().min_capability_count(3).build();

        let caps1 = vec![
            create_test_capability("sensor1", CapabilityType::Sensor, 0.9),
            create_test_capability("comms1", CapabilityType::Communication, 0.8),
            create_test_capability("compute1", CapabilityType::Compute, 0.7),
        ];
        assert!(query.matches(&caps1));

        let caps2 = vec![
            create_test_capability("sensor1", CapabilityType::Sensor, 0.9),
            create_test_capability("comms1", CapabilityType::Communication, 0.8),
        ];
        assert!(!query.matches(&caps2));
    }

    #[test]
    fn test_query_scoring() {
        let query = CapabilityQuery::builder()
            .require_type(CapabilityType::Sensor)
            .prefer_type(CapabilityType::Communication)
            .build();

        // Node with both required and optional
        let caps1 = vec![
            create_test_capability("sensor1", CapabilityType::Sensor, 0.9),
            create_test_capability("comms1", CapabilityType::Communication, 0.8),
        ];
        let score1 = query.score(&caps1);

        // Node with only required
        let caps2 = vec![create_test_capability(
            "sensor1",
            CapabilityType::Sensor,
            0.9,
        )];
        let score2 = query.score(&caps2);

        // First platform should score higher
        assert!(score1 > score2);
        assert!(score1 <= 1.0);
        assert!(score2 > 0.0);
    }

    #[test]
    fn test_query_engine_platforms() {
        let engine = CapabilityQueryEngine::new();

        let nodes = vec![
            create_test_platform(
                "platform1",
                "UAV",
                vec![
                    create_test_capability("sensor1", CapabilityType::Sensor, 0.9),
                    create_test_capability("comms1", CapabilityType::Communication, 0.8),
                ],
            ),
            create_test_platform(
                "platform2",
                "UAV",
                vec![create_test_capability(
                    "sensor2",
                    CapabilityType::Sensor,
                    0.7,
                )],
            ),
            create_test_platform(
                "platform3",
                "UAV",
                vec![
                    create_test_capability("sensor3", CapabilityType::Sensor, 0.95),
                    create_test_capability("comms3", CapabilityType::Communication, 0.9),
                    create_test_capability("compute3", CapabilityType::Compute, 0.85),
                ],
            ),
        ];

        let query = CapabilityQuery::builder()
            .require_type(CapabilityType::Sensor)
            .prefer_type(CapabilityType::Communication)
            .min_confidence(0.7)
            .build();

        let matches = engine.query_platforms(&query, &nodes);

        // All nodes have a sensor capability at or above the threshold
        assert_eq!(matches.len(), 3);

        // platform3 should score highest (has all capabilities with high confidence)
        assert_eq!(matches[0].entity.id, "platform3");
        assert!(matches[0].score > matches[1].score);
    }

    #[test]
    fn test_query_engine_limit() {
        let engine = CapabilityQueryEngine::new();

        let nodes = vec![
            create_test_platform(
                "platform1",
                "UAV",
                vec![create_test_capability(
                    "sensor1",
                    CapabilityType::Sensor,
                    0.9,
                )],
            ),
            create_test_platform(
                "platform2",
                "UAV",
                vec![create_test_capability(
                    "sensor2",
                    CapabilityType::Sensor,
                    0.8,
                )],
            ),
            create_test_platform(
                "platform3",
                "UAV",
                vec![create_test_capability(
                    "sensor3",
                    CapabilityType::Sensor,
                    0.7,
                )],
            ),
        ];

        let query = CapabilityQuery::builder()
            .require_type(CapabilityType::Sensor)
            .limit(2)
            .build();

        let matches = engine.query_platforms(&query, &nodes);

        assert_eq!(matches.len(), 2);
        // The limit must keep the top 2 by score, not just any 2.
        assert_eq!(matches[0].entity.id, "platform1");
        assert_eq!(matches[1].entity.id, "platform2");
        assert!(matches[0].score >= matches[1].score);
    }

    #[test]
    fn test_capability_stats() {
        let engine = CapabilityQueryEngine::new();

        let nodes = vec![
            create_test_platform(
                "platform1",
                "UAV",
                vec![
                    create_test_capability("sensor1", CapabilityType::Sensor, 0.9),
                    create_test_capability("comms1", CapabilityType::Communication, 0.8),
                ],
            ),
            create_test_platform(
                "platform2",
                "UAV",
                vec![
                    create_test_capability("sensor2", CapabilityType::Sensor, 0.7),
                    create_test_capability("compute2", CapabilityType::Compute, 0.85),
                ],
            ),
        ];

        let stats = engine.platform_capability_stats(&nodes);

        assert_eq!(stats.len(), 3);
        assert_eq!(stats.get(&CapabilityType::Sensor).unwrap().count, 2);
        assert_eq!(stats.get(&CapabilityType::Communication).unwrap().count, 1);
        assert_eq!(stats.get(&CapabilityType::Compute).unwrap().count, 1);

        let sensor_stats = stats.get(&CapabilityType::Sensor).unwrap();
        assert_eq!(sensor_stats.min_confidence, 0.7);
        assert_eq!(sensor_stats.max_confidence, 0.9);
        assert!((sensor_stats.avg_confidence - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_empty_query() {
        let query = CapabilityQuery::builder().build();

        let caps = vec![
            create_test_capability("sensor1", CapabilityType::Sensor, 0.9),
            create_test_capability("comms1", CapabilityType::Communication, 0.8),
        ];

        // Empty query should match any platform
        assert!(query.matches(&caps));
        // Score should be non-zero
        assert!(query.score(&caps) > 0.0);
        // An empty capability set always scores zero
        assert_eq!(query.score(&[]), 0.0);
    }
}