1use crate::models::{cell::CellState, node::NodeConfig, Capability, CapabilityExt, CapabilityType};
39use serde::{Deserialize, Serialize};
40use std::collections::HashMap;
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct CapabilityQuery {
45 pub required_types: Vec<CapabilityType>,
47 pub optional_types: Vec<CapabilityType>,
49 pub min_confidence: f32,
51 pub min_capability_count: Option<usize>,
53 pub limit: Option<usize>,
55}
56
57impl CapabilityQuery {
58 pub fn builder() -> CapabilityQueryBuilder {
60 CapabilityQueryBuilder::new()
61 }
62
63 pub fn matches(&self, capabilities: &[Capability]) -> bool {
65 if let Some(min_count) = self.min_capability_count {
67 if capabilities.len() < min_count {
68 return false;
69 }
70 }
71
72 for required_type in &self.required_types {
74 let has_type = capabilities.iter().any(|cap| {
75 cap.get_capability_type() == *required_type && cap.confidence >= self.min_confidence
76 });
77
78 if !has_type {
79 return false;
80 }
81 }
82
83 true
84 }
85
86 pub fn score(&self, capabilities: &[Capability]) -> f32 {
93 if capabilities.is_empty() {
94 return 0.0;
95 }
96
97 let mut score = 0.0;
98
99 if !self.required_types.is_empty() {
101 let required_score: f32 = self
102 .required_types
103 .iter()
104 .map(|req_type| {
105 capabilities
106 .iter()
107 .filter(|cap| cap.get_capability_type() == *req_type)
108 .map(|cap| cap.confidence)
109 .max_by(|a, b| a.partial_cmp(b).unwrap())
110 .unwrap_or(0.0)
111 })
112 .sum::<f32>()
113 / self.required_types.len() as f32;
114
115 score += required_score * 0.6;
116 } else {
117 score += 0.6;
119 }
120
121 if !self.optional_types.is_empty() {
123 let optional_score: f32 = self
124 .optional_types
125 .iter()
126 .map(|opt_type| {
127 capabilities
128 .iter()
129 .filter(|cap| cap.get_capability_type() == *opt_type)
130 .map(|cap| cap.confidence)
131 .max_by(|a, b| a.partial_cmp(b).unwrap())
132 .unwrap_or(0.0)
133 })
134 .sum::<f32>()
135 / self.optional_types.len() as f32;
136
137 score += optional_score * 0.3;
138 } else {
139 score += 0.3;
140 }
141
142 let avg_confidence: f32 =
144 capabilities.iter().map(|cap| cap.confidence).sum::<f32>() / capabilities.len() as f32;
145 score += avg_confidence * 0.1;
146
147 score.clamp(0.0, 1.0)
148 }
149}
150
151#[derive(Debug, Default)]
153pub struct CapabilityQueryBuilder {
154 required_types: Vec<CapabilityType>,
155 optional_types: Vec<CapabilityType>,
156 min_confidence: f32,
157 min_capability_count: Option<usize>,
158 limit: Option<usize>,
159}
160
161impl CapabilityQueryBuilder {
162 pub fn new() -> Self {
164 Self {
165 min_confidence: 0.0,
166 ..Default::default()
167 }
168 }
169
170 pub fn require_type(mut self, capability_type: CapabilityType) -> Self {
172 self.required_types.push(capability_type);
173 self
174 }
175
176 pub fn prefer_type(mut self, capability_type: CapabilityType) -> Self {
178 self.optional_types.push(capability_type);
179 self
180 }
181
182 pub fn min_confidence(mut self, min_confidence: f32) -> Self {
184 self.min_confidence = min_confidence.clamp(0.0, 1.0);
185 self
186 }
187
188 pub fn min_capability_count(mut self, count: usize) -> Self {
190 self.min_capability_count = Some(count);
191 self
192 }
193
194 pub fn limit(mut self, limit: usize) -> Self {
196 self.limit = Some(limit);
197 self
198 }
199
200 pub fn build(self) -> CapabilityQuery {
202 CapabilityQuery {
203 required_types: self.required_types,
204 optional_types: self.optional_types,
205 min_confidence: self.min_confidence,
206 min_capability_count: self.min_capability_count,
207 limit: self.limit,
208 }
209 }
210}
211
/// One query result: the matched entity together with its relevance score.
#[derive(Clone, Debug)]
pub struct QueryMatch<T> {
    /// The entity (node, squad, ...) that satisfied the query.
    pub entity: T,
    /// Relevance as computed by the query's scoring function; higher is better.
    pub score: f32,
}
220
/// Stateless engine that evaluates [`CapabilityQuery`]s against collections
/// of nodes or squads and aggregates capability statistics.
///
/// It holds no data, so it is trivially `Copy`. `Debug` is derived so the
/// public type can appear in diagnostic output like the other types here.
#[derive(Debug, Clone, Copy)]
pub struct CapabilityQueryEngine;
223
224impl CapabilityQueryEngine {
225 pub fn new() -> Self {
227 Self
228 }
229
230 pub fn query_platforms(
232 &self,
233 query: &CapabilityQuery,
234 nodes: &[NodeConfig],
235 ) -> Vec<QueryMatch<NodeConfig>> {
236 let mut matches: Vec<QueryMatch<NodeConfig>> = nodes
237 .iter()
238 .filter(|node| query.matches(&node.capabilities))
239 .map(|node| QueryMatch {
240 score: query.score(&node.capabilities),
241 entity: node.clone(),
242 })
243 .collect();
244
245 matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
247
248 if let Some(limit) = query.limit {
250 matches.truncate(limit);
251 }
252
253 matches
254 }
255
256 pub fn query_squads(
258 &self,
259 query: &CapabilityQuery,
260 squads: &[CellState],
261 ) -> Vec<QueryMatch<CellState>> {
262 let mut matches: Vec<QueryMatch<CellState>> = squads
263 .iter()
264 .filter(|squad| query.matches(&squad.capabilities))
265 .map(|squad| QueryMatch {
266 score: query.score(&squad.capabilities),
267 entity: squad.clone(),
268 })
269 .collect();
270
271 matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
273
274 if let Some(limit) = query.limit {
276 matches.truncate(limit);
277 }
278
279 matches
280 }
281
282 pub fn platform_capability_stats(
284 &self,
285 nodes: &[NodeConfig],
286 ) -> HashMap<CapabilityType, CapabilityStats> {
287 let mut stats: HashMap<CapabilityType, Vec<f32>> = HashMap::new();
288
289 for node in nodes {
290 for capability in &node.capabilities {
291 stats
292 .entry(capability.get_capability_type())
293 .or_default()
294 .push(capability.confidence);
295 }
296 }
297
298 stats
299 .into_iter()
300 .map(|(cap_type, confidences)| {
301 (cap_type, CapabilityStats::from_confidences(&confidences))
302 })
303 .collect()
304 }
305}
306
307impl Default for CapabilityQueryEngine {
308 fn default() -> Self {
309 Self::new()
310 }
311}
312
/// Summary statistics over the confidence values of one capability type.
#[derive(Debug, Clone, PartialEq)]
pub struct CapabilityStats {
    /// Number of samples (including any NaN samples).
    pub count: usize,
    /// Arithmetic mean of all samples; 0.0 when there are none.
    pub avg_confidence: f32,
    /// Smallest non-NaN sample; 0.0 when there are none.
    pub min_confidence: f32,
    /// Largest non-NaN sample; 0.0 when there are none.
    pub max_confidence: f32,
}

impl CapabilityStats {
    /// Computes count, mean, minimum and maximum of `confidences`.
    ///
    /// An empty slice yields all-zero statistics. NaN samples no longer
    /// panic: `min_confidence` / `max_confidence` skip them (they are NaN
    /// only when every sample is NaN). The mean is a plain sum over `count`,
    /// so a NaN sample makes `avg_confidence` NaN, flagging the bad input.
    pub fn from_confidences(confidences: &[f32]) -> Self {
        let count = confidences.len();
        let avg_confidence = if count > 0 {
            confidences.iter().sum::<f32>() / count as f32
        } else {
            0.0
        };

        // `f32::min` / `f32::max` return the non-NaN operand, unlike
        // `partial_cmp(..).unwrap()`, which panicked on NaN.
        let min_confidence = confidences.iter().copied().reduce(f32::min).unwrap_or(0.0);
        let max_confidence = confidences.iter().copied().reduce(f32::max).unwrap_or(0.0);

        Self {
            count,
            avg_confidence,
            min_confidence,
            max_confidence,
        }
    }
}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{CapabilityExt, NodeConfigExt};

    /// Builds a capability whose display name is derived from its type.
    fn make_capability(id: &str, capability_type: CapabilityType, confidence: f32) -> Capability {
        let name = format!("{:?} capability", capability_type);
        Capability::new(id.to_string(), name, capability_type, confidence)
    }

    /// Builds a node of the given platform type carrying `capabilities`.
    fn make_platform(id: &str, platform_type: &str, capabilities: Vec<Capability>) -> NodeConfig {
        let mut node = NodeConfig::new(platform_type.to_string());
        node.id = id.to_string();
        capabilities
            .into_iter()
            .for_each(|capability| node.add_capability(capability));
        node
    }

    #[test]
    fn test_query_builder() {
        let built = CapabilityQuery::builder()
            .require_type(CapabilityType::Sensor)
            .require_type(CapabilityType::Communication)
            .min_confidence(0.8)
            .limit(10)
            .build();

        assert_eq!(built.required_types.len(), 2);
        assert_eq!(built.min_confidence, 0.8);
        assert_eq!(built.limit, Some(10));
    }

    #[test]
    fn test_query_matches_required_types() {
        let query = CapabilityQuery::builder()
            .require_type(CapabilityType::Sensor)
            .require_type(CapabilityType::Communication)
            .min_confidence(0.7)
            .build();

        // Both required types present and above the threshold.
        let complete = vec![
            make_capability("sensor1", CapabilityType::Sensor, 0.9),
            make_capability("comms1", CapabilityType::Communication, 0.8),
        ];
        // Communication missing entirely.
        let sensor_only = vec![make_capability("sensor1", CapabilityType::Sensor, 0.9)];
        // Communication present but below the threshold.
        let weak_comms = vec![
            make_capability("sensor1", CapabilityType::Sensor, 0.9),
            make_capability("comms1", CapabilityType::Communication, 0.5),
        ];

        assert!(query.matches(&complete));
        assert!(!query.matches(&sensor_only));
        assert!(!query.matches(&weak_comms));
    }

    #[test]
    fn test_query_matches_min_capability_count() {
        let query = CapabilityQuery::builder().min_capability_count(3).build();

        let three = vec![
            make_capability("sensor1", CapabilityType::Sensor, 0.9),
            make_capability("comms1", CapabilityType::Communication, 0.8),
            make_capability("compute1", CapabilityType::Compute, 0.7),
        ];
        let two = vec![
            make_capability("sensor1", CapabilityType::Sensor, 0.9),
            make_capability("comms1", CapabilityType::Communication, 0.8),
        ];

        assert!(query.matches(&three));
        assert!(!query.matches(&two));
    }

    #[test]
    fn test_query_scoring() {
        let query = CapabilityQuery::builder()
            .require_type(CapabilityType::Sensor)
            .prefer_type(CapabilityType::Communication)
            .build();

        let with_preferred = vec![
            make_capability("sensor1", CapabilityType::Sensor, 0.9),
            make_capability("comms1", CapabilityType::Communication, 0.8),
        ];
        let without_preferred = vec![make_capability("sensor1", CapabilityType::Sensor, 0.9)];

        let high = query.score(&with_preferred);
        let low = query.score(&without_preferred);

        // Having the preferred type must rank higher, and scores stay in (0, 1].
        assert!(high > low);
        assert!(high <= 1.0);
        assert!(low > 0.0);
    }

    #[test]
    fn test_query_engine_platforms() {
        let engine = CapabilityQueryEngine::new();

        let nodes = vec![
            make_platform(
                "platform1",
                "UAV",
                vec![
                    make_capability("sensor1", CapabilityType::Sensor, 0.9),
                    make_capability("comms1", CapabilityType::Communication, 0.8),
                ],
            ),
            make_platform(
                "platform2",
                "UAV",
                vec![make_capability("sensor2", CapabilityType::Sensor, 0.7)],
            ),
            make_platform(
                "platform3",
                "UAV",
                vec![
                    make_capability("sensor3", CapabilityType::Sensor, 0.95),
                    make_capability("comms3", CapabilityType::Communication, 0.9),
                    make_capability("compute3", CapabilityType::Compute, 0.85),
                ],
            ),
        ];

        let query = CapabilityQuery::builder()
            .require_type(CapabilityType::Sensor)
            .prefer_type(CapabilityType::Communication)
            .min_confidence(0.7)
            .build();

        let results = engine.query_platforms(&query, &nodes);

        // Every platform has a sensor at >= 0.7, so all three match.
        assert_eq!(results.len(), 3);
        // The strongest, most complete platform comes first.
        assert_eq!(results[0].entity.id, "platform3");
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn test_query_engine_limit() {
        let engine = CapabilityQueryEngine::new();

        let fixtures = [
            ("platform1", "sensor1", 0.9),
            ("platform2", "sensor2", 0.8),
            ("platform3", "sensor3", 0.7),
        ];
        let nodes: Vec<NodeConfig> = fixtures
            .iter()
            .map(|&(node_id, capability_id, confidence)| {
                make_platform(
                    node_id,
                    "UAV",
                    vec![make_capability(capability_id, CapabilityType::Sensor, confidence)],
                )
            })
            .collect();

        let query = CapabilityQuery::builder()
            .require_type(CapabilityType::Sensor)
            .limit(2)
            .build();

        let results = engine.query_platforms(&query, &nodes);

        assert_eq!(results.len(), 2);
        assert!(results[0].score >= results[1].score);
    }

    #[test]
    fn test_capability_stats() {
        let engine = CapabilityQueryEngine::new();

        let nodes = vec![
            make_platform(
                "platform1",
                "UAV",
                vec![
                    make_capability("sensor1", CapabilityType::Sensor, 0.9),
                    make_capability("comms1", CapabilityType::Communication, 0.8),
                ],
            ),
            make_platform(
                "platform2",
                "UAV",
                vec![
                    make_capability("sensor2", CapabilityType::Sensor, 0.7),
                    make_capability("compute2", CapabilityType::Compute, 0.85),
                ],
            ),
        ];

        let stats = engine.platform_capability_stats(&nodes);
        let count_of = |capability_type: CapabilityType| stats.get(&capability_type).unwrap().count;

        assert_eq!(stats.len(), 3);
        assert_eq!(count_of(CapabilityType::Sensor), 2);
        assert_eq!(count_of(CapabilityType::Communication), 1);
        assert_eq!(count_of(CapabilityType::Compute), 1);

        let sensor = stats.get(&CapabilityType::Sensor).unwrap();
        assert_eq!(sensor.min_confidence, 0.7);
        assert_eq!(sensor.max_confidence, 0.9);
        assert!((sensor.avg_confidence - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_empty_query() {
        let query = CapabilityQuery::builder().build();

        let capabilities = vec![
            make_capability("sensor1", CapabilityType::Sensor, 0.9),
            make_capability("comms1", CapabilityType::Communication, 0.8),
        ];

        // With no constraints everything matches and still gets a positive score.
        assert!(query.matches(&capabilities));
        assert!(query.score(&capabilities) > 0.0);
    }
}
617}