1use std::collections::{BTreeMap, HashMap, HashSet};
27use std::time::Duration;
28
29use serde::{Deserialize, Serialize};
30
31use super::state::{FoldIndex, FoldState, NodeId};
32use super::FoldKind;
33
/// Operational state a node advertises alongside a capability membership.
///
/// Every fold entry is indexed under exactly one state bucket
/// (`CapabilityIndexInner::by_state`), and `CapabilityQuery::InState` /
/// `CapabilityFilter::state` select on it. Serialized in `snake_case`.
///
/// NOTE(review): the derived `PartialOrd`/`Ord` follow declaration order, and
/// a positional (non-self-describing) wire encoding would too — do not reorder
/// variants without confirming the announcement encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NodeState {
    /// Advertised as idle.
    Idle,
    /// Advertised as busy.
    Busy,
    /// Advertised as reserved.
    Reserved,
    /// Advertised as faulty.
    Faulty,
}
51
/// Compact hardware description carried in `CapabilityMembership::hardware`.
///
/// `gpu_vendor` and `gpu_count` feed the synthetic index tags `gpu:present`
/// and `gpu:vendor:<vendor>` (see `derive_synthetic_index_tags`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct HardwareSummary {
    /// GPU vendor name; when set, also marks the node as GPU-bearing.
    pub gpu_vendor: Option<String>,
    /// Number of GPUs; any non-zero value marks the node as GPU-bearing.
    pub gpu_count: u8,
    /// System memory — presumably in GiB per the field name; TODO confirm.
    pub memory_gb: Option<u32>,
    /// GPU memory — presumably in GiB per the field name; TODO confirm.
    pub vram_gb: Option<u32>,
}
73
/// Payload of one capability announcement: a publisher's membership in the
/// capability class identified by `class_hash`.
///
/// The fold keys entries by `(class_hash, publisher NodeId)` (see
/// `CapabilityFold::key_for`), so one publisher may hold several class entries.
///
/// NOTE(review): this is a serde payload inside signed announcements; field
/// order may matter for the wire encoding — do not reorder casually.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CapabilityMembership {
    /// Identifier of the capability class; first half of the fold key.
    pub class_hash: u64,
    /// Raw tag strings, indexed verbatim in `by_tag`. Tags that parse as
    /// software-axis `model.<slot>.id` / `tool.<slot>.tool_id` values also
    /// yield synthetic `model:<value>` / `tool:<value>` index tags.
    pub tags: Vec<String>,
    /// Optional hardware summary; feeds the synthetic `gpu:*` index tags.
    pub hardware: Option<HardwareSummary>,
    /// Advertised node state; indexed in `by_state`.
    pub state: NodeState,
    /// Advertised region, indexed in `by_region` when present.
    pub region: Option<String>,
    /// Price quote. NOTE(review): unit/currency not visible in this module.
    pub price_quote: Option<u64>,
    /// Address returned by `reflex_addr_for` for this publisher.
    pub reflex_addr: Option<std::net::SocketAddr>,
    /// Node allow-list. Not read in this module — presumably access control;
    /// confirm with consumers.
    pub allowed_nodes: Vec<u64>,
    /// Subnet allow-list. Not read in this module — presumably access control.
    pub allowed_subnets: Vec<super::super::subnet::SubnetId>,
    /// Group allow-list. Not read in this module — presumably access control.
    pub allowed_groups: Vec<super::super::group::GroupId>,
    /// Free-form key/value metadata; not indexed here.
    pub metadata: BTreeMap<String, String>,
}
130
/// Query accepted by the capability fold (`CapabilityFold::query`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CapabilityQuery {
    /// Every entry whose class hash equals the value (full scan of entries).
    InClass(u64),
    /// Entries carrying every listed raw tag. An empty list matches every
    /// indexed entry.
    HasAllTags(Vec<String>),
    /// Entries carrying at least one listed raw tag. An empty list matches
    /// nothing.
    HasAnyTag(Vec<String>),
    /// Entries advertising the given state.
    InState(NodeState),
    /// Entries advertising the given region.
    InRegion(String),
    /// Multi-axis filter; every populated axis must hold.
    Composite(CapabilityFilter),
}
156
/// Multi-axis filter evaluated by `CapabilityQuery::Composite`.
///
/// Populated axes are ANDed together; `Default` leaves every axis
/// unconstrained.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CapabilityFilter {
    /// Only entries whose class hash equals this value.
    pub class: Option<u64>,
    /// Raw tags the entry must carry — all of them. Empty = no constraint.
    pub tags_all: Vec<String>,
    /// Raw tags of which the entry must carry at least one. Empty = no
    /// constraint.
    pub tags_any: Vec<String>,
    /// Conjunction of disjunctions over *synthetic* index tags (e.g.
    /// `model:<id>`, `tool:<id>`, `gpu:present`, `gpu:vendor:<vendor>`): the
    /// entry must match at least one tag of every non-empty inner group.
    /// Empty inner groups are ignored; raw tags are not consulted here.
    pub tag_groups_all: Vec<Vec<String>>,
    /// Only entries advertising this state.
    pub state: Option<NodeState>,
    /// Only entries advertising this region.
    pub region: Option<String>,
    /// Maximum number of matches returned; `0` means unlimited.
    pub limit: usize,
}
185
/// One query hit: the fold key `(class_hash, publisher)` paired with a clone
/// of the stored membership payload.
pub type CapabilityMatch = ((u64, NodeId), CapabilityMembership);
188
/// Secondary indexes over the capability fold, mapping a bucket value to the
/// set of fold keys `(class_hash, publisher)` that fall in it.
///
/// Maintained by the `FoldIndex` callbacks; buckets emptied by `on_remove`
/// are dropped rather than kept as empty sets.
#[derive(Debug, Default)]
pub struct CapabilityIndexInner {
    /// Raw tag string -> keys carrying that tag.
    by_tag: HashMap<String, HashSet<(u64, NodeId)>>,
    /// Synthetic tag (from `derive_synthetic_index_tags`) -> keys.
    by_synthetic: HashMap<String, HashSet<(u64, NodeId)>>,
    /// Region -> keys; entries without a region are absent here.
    by_region: HashMap<String, HashSet<(u64, NodeId)>>,
    /// Node state -> keys; every entry appears in exactly one bucket.
    by_state: HashMap<NodeState, HashSet<(u64, NodeId)>>,
}
216
217impl FoldIndex<CapabilityFold> for CapabilityIndexInner {
218 fn on_insert(&mut self, key: &(u64, NodeId), payload: &CapabilityMembership) {
219 for tag in &payload.tags {
220 self.by_tag.entry(tag.clone()).or_default().insert(*key);
221 }
222 for tag in derive_synthetic_index_tags(payload) {
229 self.by_synthetic.entry(tag).or_default().insert(*key);
230 }
231 if let Some(region) = &payload.region {
232 self.by_region
233 .entry(region.clone())
234 .or_default()
235 .insert(*key);
236 }
237 self.by_state.entry(payload.state).or_default().insert(*key);
238 }
239
240 fn on_remove(&mut self, key: &(u64, NodeId), payload: &CapabilityMembership) {
241 for tag in &payload.tags {
242 if let Some(set) = self.by_tag.get_mut(tag) {
243 set.remove(key);
244 if set.is_empty() {
245 self.by_tag.remove(tag);
246 }
247 }
248 }
249 for tag in derive_synthetic_index_tags(payload) {
252 if let Some(set) = self.by_synthetic.get_mut(&tag) {
253 set.remove(key);
254 if set.is_empty() {
255 self.by_synthetic.remove(&tag);
256 }
257 }
258 }
259 if let Some(region) = &payload.region {
260 if let Some(set) = self.by_region.get_mut(region) {
261 set.remove(key);
262 if set.is_empty() {
263 self.by_region.remove(region);
264 }
265 }
266 }
267 if let Some(set) = self.by_state.get_mut(&payload.state) {
268 set.remove(key);
269 if set.is_empty() {
270 self.by_state.remove(&payload.state);
271 }
272 }
273 }
274
275 fn clear(&mut self) {
276 self.by_tag.clear();
277 self.by_synthetic.clear();
278 self.by_region.clear();
279 self.by_state.clear();
280 }
281}
282
283fn derive_synthetic_index_tags(payload: &CapabilityMembership) -> Vec<String> {
304 use super::super::tag::{Tag, TaxonomyAxis};
305 let mut out = Vec::new();
306 for s in &payload.tags {
307 let Ok(Tag::AxisValue {
308 axis: TaxonomyAxis::Software,
309 key,
310 value,
311 ..
312 }) = Tag::parse(s)
313 else {
314 continue;
315 };
316 if let Some(rest) = key.strip_prefix("model.") {
317 if matches!(rest.split_once('.'), Some((_, "id"))) {
318 out.push(format!("model:{value}"));
319 }
320 } else if let Some(rest) = key.strip_prefix("tool.") {
321 if matches!(rest.split_once('.'), Some((_, "tool_id"))) {
322 out.push(format!("tool:{value}"));
323 }
324 }
325 }
326 if let Some(h) = &payload.hardware {
327 if h.gpu_count > 0 || h.gpu_vendor.is_some() {
328 out.push("gpu:present".to_string());
329 }
330 if let Some(vendor) = &h.gpu_vendor {
331 out.push(format!("gpu:vendor:{vendor}"));
332 }
333 }
334 out
335}
336
/// Marker type that plugs capability announcements into the generic fold
/// machinery through its `FoldKind` implementation.
#[derive(Debug)]
pub struct CapabilityFold;
340
341impl FoldKind for CapabilityFold {
342 const KIND_ID: u16 = 1;
345 const CHANNEL_PREFIX: &'static str = "fold:cap:";
346 const DEFAULT_TTL: Duration = Duration::from_secs(60);
351
352 type Key = (u64, NodeId);
353 type Payload = CapabilityMembership;
354 type Query = CapabilityQuery;
355 type Result = Vec<CapabilityMatch>;
356 type Index = CapabilityIndexInner;
357
358 fn key_for(node_id: NodeId, payload: &Self::Payload) -> Self::Key {
359 (payload.class_hash, node_id)
360 }
361
362 fn build_index() -> CapabilityIndexInner {
363 CapabilityIndexInner::default()
364 }
365
366 fn query(
367 state: &FoldState<Self>,
368 index: &CapabilityIndexInner,
369 query: CapabilityQuery,
370 ) -> Vec<CapabilityMatch> {
371 match query {
372 CapabilityQuery::InClass(class) => state
373 .entries
374 .iter()
375 .filter(|((c, _), _)| *c == class)
376 .map(|(k, e)| (*k, e.payload.clone()))
377 .collect(),
378 CapabilityQuery::HasAllTags(tags) => resolve_keys_all_tags(index, &tags)
379 .into_iter()
380 .filter_map(|k| state.entries.get(&k).map(|e| (k, e.payload.clone())))
381 .collect(),
382 CapabilityQuery::HasAnyTag(tags) => {
383 let mut seen: HashSet<(u64, NodeId)> = HashSet::new();
384 for tag in &tags {
385 if let Some(keys) = index.by_tag.get(tag) {
386 seen.extend(keys.iter().copied());
387 }
388 }
389 seen.into_iter()
390 .filter_map(|k| state.entries.get(&k).map(|e| (k, e.payload.clone())))
391 .collect()
392 }
393 CapabilityQuery::InState(s) => index
394 .by_state
395 .get(&s)
396 .into_iter()
397 .flat_map(|set| set.iter().copied())
398 .filter_map(|k| state.entries.get(&k).map(|e| (k, e.payload.clone())))
399 .collect(),
400 CapabilityQuery::InRegion(r) => index
401 .by_region
402 .get(&r)
403 .into_iter()
404 .flat_map(|set| set.iter().copied())
405 .filter_map(|k| state.entries.get(&k).map(|e| (k, e.payload.clone())))
406 .collect(),
407 CapabilityQuery::Composite(filter) => composite_query(state, index, &filter),
408 }
409 }
410}
411
412fn resolve_keys_all_tags(index: &CapabilityIndexInner, tags: &[String]) -> HashSet<(u64, NodeId)> {
418 if tags.is_empty() {
419 return index
424 .by_state
425 .values()
426 .flat_map(|set| set.iter().copied())
427 .collect();
428 }
429 let mut tags_by_selectivity: Vec<&String> = tags.iter().collect();
431 tags_by_selectivity.sort_by_key(|t| index.by_tag.get(*t).map(|s| s.len()).unwrap_or(0));
432
433 let Some(first) = tags_by_selectivity.first() else {
434 return HashSet::new();
435 };
436 let Some(initial) = index.by_tag.get(*first) else {
437 return HashSet::new();
439 };
440 let mut candidates: HashSet<(u64, NodeId)> = initial.iter().copied().collect();
441 for tag in tags_by_selectivity.iter().skip(1) {
442 let Some(bucket) = index.by_tag.get(*tag) else {
443 return HashSet::new();
444 };
445 candidates.retain(|k| bucket.contains(k));
446 if candidates.is_empty() {
447 break;
448 }
449 }
450 candidates
451}
452
/// Resolves the set of fold keys that satisfy every populated axis of `filter`.
///
/// Strategy: seed a candidate set from one axis, then narrow it by
/// intersecting with each remaining constraint.
///
/// Seed selection, first match wins:
/// 1. `tags_all` via `resolve_keys_all_tags`. Synthetic-tag group unions are
///    only built when this seed is non-empty, because an empty seed already
///    guarantees no match.
/// 2. Otherwise, the smallest union among the non-empty `tag_groups_all`
///    groups. It is removed from the list so it is not re-applied.
/// 3. Otherwise, the `state` bucket, then the `region` bucket, then a scan of
///    `state.entries` for `class`, and finally every key in the fold.
///
/// Narrowing is then applied in this order:
/// 1. `class`
/// 2. `state`
/// 3. `region`
/// 4. `tags_any` — the key must appear in at least one listed raw tag's bucket
/// 5. every remaining synthetic-tag group union
///
/// `filter.limit` is not applied here; `composite_query` handles it.
pub(crate) fn resolve_candidate_keys(
    state: &FoldState<CapabilityFold>,
    index: &CapabilityIndexInner,
    filter: &CapabilityFilter,
) -> HashSet<(u64, NodeId)> {
    // Synthetic-tag group unions still to be intersected into the candidates.
    // This stays empty when the `tags_all` seed is already empty.
    let mut group_unions: Vec<HashSet<(u64, NodeId)>> = Vec::new();

    // Seed the candidate set from the first populated axis, in priority order.
    let mut candidates: HashSet<(u64, NodeId)> = if !filter.tags_all.is_empty() {
        let seed = resolve_keys_all_tags(index, &filter.tags_all);
        if !seed.is_empty() {
            // Only pay for building the group unions when something can still match.
            group_unions = build_group_unions(index, &filter.tag_groups_all);
        }
        seed
    } else {
        group_unions = build_group_unions(index, &filter.tag_groups_all);
        if !group_unions.is_empty() {
            // Seed from the smallest group union and take it out of the list,
            // so it is not intersected with itself below. `swap_remove` is fine
            // here because union order does not affect the intersection.
            let smallest = group_unions
                .iter()
                .enumerate()
                .min_by_key(|(_, u)| u.len())
                .map(|(i, _)| i)
                .unwrap_or(0);
            group_unions.swap_remove(smallest)
        } else if let Some(state_filter) = filter.state {
            index
                .by_state
                .get(&state_filter)
                .cloned()
                .unwrap_or_default()
        } else if let Some(region) = &filter.region {
            index.by_region.get(region).cloned().unwrap_or_default()
        } else if let Some(class) = filter.class {
            // There is no class index, so this scans every entry.
            state
                .entries
                .keys()
                .filter(|(c, _)| *c == class)
                .copied()
                .collect()
        } else {
            // No seeding axis is populated, so start from every key.
            state.entries.keys().copied().collect()
        }
    };

    // Narrow by each axis. Re-applying the axis that produced the seed is
    // redundant but harmless. A missing bucket means nothing matches.
    if let Some(class) = filter.class {
        candidates.retain(|(c, _)| *c == class);
    }
    if let Some(state_filter) = filter.state {
        if let Some(bucket) = index.by_state.get(&state_filter) {
            candidates.retain(|k| bucket.contains(k));
        } else {
            candidates.clear();
        }
    }
    if let Some(region) = &filter.region {
        if let Some(bucket) = index.by_region.get(region) {
            candidates.retain(|k| bucket.contains(k));
        } else {
            candidates.clear();
        }
    }
    if !filter.tags_any.is_empty() {
        // Union of the listed raw tags' buckets. If none of the tags is
        // indexed, the union is empty and every candidate is dropped.
        let mut tags_any_union: HashSet<(u64, NodeId)> = HashSet::new();
        for tag in &filter.tags_any {
            if let Some(bucket) = index.by_tag.get(tag) {
                tags_any_union.extend(bucket.iter().copied());
            }
        }
        candidates.retain(|k| tags_any_union.contains(k));
    }

    // Each remaining group union is a required OR-group, so intersect with it.
    for union in &group_unions {
        candidates.retain(|k| union.contains(k));
        if candidates.is_empty() {
            break;
        }
    }

    candidates
}
570
571fn build_group_unions(
576 index: &CapabilityIndexInner,
577 groups: &[Vec<String>],
578) -> Vec<HashSet<(u64, NodeId)>> {
579 groups
580 .iter()
581 .filter(|g| !g.is_empty())
582 .map(|g| group_union(index, g))
583 .collect()
584}
585
586fn group_union(index: &CapabilityIndexInner, group: &[String]) -> HashSet<(u64, NodeId)> {
596 let mut union: HashSet<(u64, NodeId)> = HashSet::new();
597 for tag in group {
598 if let Some(bucket) = index.by_synthetic.get(tag) {
599 union.extend(bucket.iter().copied());
600 }
601 }
602 union
603}
604
605fn composite_query(
610 state: &FoldState<CapabilityFold>,
611 index: &CapabilityIndexInner,
612 filter: &CapabilityFilter,
613) -> Vec<CapabilityMatch> {
614 let candidates = resolve_candidate_keys(state, index, filter);
615 let mut matches: Vec<CapabilityMatch> = candidates
617 .into_iter()
618 .filter_map(|k| state.entries.get(&k).map(|e| (k, e.payload.clone())))
619 .collect();
620 if filter.limit > 0 && matches.len() > filter.limit {
621 matches.truncate(filter.limit);
622 }
623 matches
624}
625
626pub fn capability_tags_for(fold: &super::Fold<CapabilityFold>, node_id: NodeId) -> Vec<String> {
637 fold.with_state(|state| tags_union_for(state, node_id))
638}
639
640pub fn capability_tags_for_all(
647 fold: &super::Fold<CapabilityFold>,
648) -> std::collections::HashMap<NodeId, Vec<String>> {
649 fold.with_state(|state| {
650 let mut out: std::collections::HashMap<NodeId, Vec<String>> =
651 std::collections::HashMap::with_capacity(state.by_node.len());
652 for node_id in state.by_node.keys() {
653 out.insert(*node_id, tags_union_for(state, *node_id));
654 }
655 out
656 })
657}
658
659fn tags_union_for(state: &FoldState<CapabilityFold>, node_id: NodeId) -> Vec<String> {
662 let Some(keys) = state.by_node.get(&node_id) else {
663 return Vec::new();
664 };
665 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
666 for key in keys {
667 if let Some(entry) = state.entries.get(key) {
668 for tag in &entry.payload.tags {
669 seen.insert(tag.clone());
670 }
671 }
672 }
673 seen.into_iter().collect()
674}
675
676pub fn reflex_addr_for(
684 fold: &super::Fold<CapabilityFold>,
685 node_id: NodeId,
686) -> Option<std::net::SocketAddr> {
687 fold.with_state(|state| {
688 let keys = state.by_node.get(&node_id)?;
689 for key in keys {
690 if let Some(entry) = state.entries.get(key) {
691 if let Some(addr) = entry.payload.reflex_addr {
692 return Some(addr);
693 }
694 }
695 }
696 None
697 })
698}
699
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use super::*;
    use crate::adapter::net::behavior::fold::{
        ApplyOutcome, EnvelopeMeta, Fold, FoldRegistry, SignedAnnouncement,
    };
    use crate::adapter::net::identity::EntityKeypair;

    /// Signs a capability announcement with no hardware and no reflex address.
    fn sign_cap(
        keypair: &EntityKeypair,
        publisher: NodeId,
        generation: u64,
        class: u64,
        tags: Vec<&str>,
        state: NodeState,
        region: Option<&str>,
    ) -> SignedAnnouncement<CapabilityMembership> {
        sign_cap_with_reflex(
            keypair, publisher, generation, class, tags, state, region, None,
        )
    }

    /// Signs a capability announcement with an optional reflex address.
    #[allow(clippy::too_many_arguments)]
    fn sign_cap_with_reflex(
        keypair: &EntityKeypair,
        publisher: NodeId,
        generation: u64,
        class: u64,
        tags: Vec<&str>,
        state: NodeState,
        region: Option<&str>,
        reflex_addr: Option<std::net::SocketAddr>,
    ) -> SignedAnnouncement<CapabilityMembership> {
        SignedAnnouncement::sign(
            keypair,
            CapabilityFold::KIND_ID,
            class,
            publisher,
            generation,
            EnvelopeMeta::default(),
            CapabilityMembership {
                class_hash: class,
                tags: tags.into_iter().map(String::from).collect(),
                hardware: None,
                state,
                region: region.map(String::from),
                price_quote: None,
                reflex_addr,
                allowed_nodes: Vec::new(),
                allowed_subnets: Vec::new(),
                allowed_groups: Vec::new(),
                metadata: BTreeMap::new(),
            },
        )
        .expect("sign succeeds")
    }

    /// Signs an idle class-0x100 announcement with no tags but the given
    /// hardware summary, which only feeds the synthetic `gpu:*` tags.
    fn sign_cap_with_hardware(
        keypair: &EntityKeypair,
        publisher: NodeId,
        generation: u64,
        hardware: Option<HardwareSummary>,
    ) -> SignedAnnouncement<CapabilityMembership> {
        SignedAnnouncement::sign(
            keypair,
            CapabilityFold::KIND_ID,
            0x100,
            publisher,
            generation,
            EnvelopeMeta::default(),
            CapabilityMembership {
                class_hash: 0x100,
                tags: Vec::new(),
                hardware,
                state: NodeState::Idle,
                region: None,
                price_quote: None,
                reflex_addr: None,
                allowed_nodes: Vec::new(),
                allowed_subnets: Vec::new(),
                allowed_groups: Vec::new(),
                metadata: BTreeMap::new(),
            },
        )
        .expect("sign succeeds")
    }

    /// Builds a hardware summary with only the GPU fields populated.
    fn gpu_hardware(vendor: Option<&str>, gpu_count: u8) -> HardwareSummary {
        HardwareSummary {
            gpu_vendor: vendor.map(String::from),
            gpu_count,
            memory_gb: None,
            vram_gb: None,
        }
    }

    /// Runs a composite query constrained only by `tag_groups_all` and returns
    /// the matching publisher ids.
    fn nodes_matching_groups(
        fold: &Fold<CapabilityFold>,
        groups: Vec<Vec<&str>>,
    ) -> std::collections::HashSet<NodeId> {
        let filter = CapabilityFilter {
            tag_groups_all: groups
                .into_iter()
                .map(|group| group.into_iter().map(String::from).collect())
                .collect(),
            ..CapabilityFilter::default()
        };
        fold.query(CapabilityQuery::Composite(filter))
            .into_iter()
            .map(|((_, n), _)| n)
            .collect()
    }

    /// A fold with background sweeping disabled, so tests control expiry.
    fn new_fold() -> Fold<CapabilityFold> {
        Fold::with_sweep_interval(Duration::ZERO)
    }

    #[test]
    fn first_announcement_installs_and_populates_secondary_index() {
        let fold = new_fold();
        let kp = EntityKeypair::generate();
        let outcome = fold
            .apply(sign_cap(
                &kp,
                0xA,
                1,
                0x100,
                vec!["hardware.gpu", "vendor.nvidia"],
                NodeState::Idle,
                Some("us-east"),
            ))
            .expect("apply");
        assert_eq!(outcome, ApplyOutcome::Inserted);

        let hits = fold.query(CapabilityQuery::InClass(0x100));
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, (0x100, 0xA));

        let hits = fold.query(CapabilityQuery::HasAllTags(vec!["hardware.gpu".into()]));
        assert_eq!(hits.len(), 1);

        let hits = fold.query(CapabilityQuery::InState(NodeState::Idle));
        assert_eq!(hits.len(), 1);

        let hits = fold.query(CapabilityQuery::InRegion("us-east".into()));
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn each_publisher_owns_its_own_class_entry_no_cross_override() {
        let fold = new_fold();
        let kp_a = EntityKeypair::generate();
        let kp_b = EntityKeypair::generate();

        fold.apply(sign_cap(
            &kp_a,
            0xA,
            1,
            0x100,
            vec!["gpu"],
            NodeState::Idle,
            None,
        ))
        .expect("a");
        fold.apply(sign_cap(
            &kp_b,
            0xB,
            1,
            0x100,
            vec!["gpu"],
            NodeState::Busy,
            None,
        ))
        .expect("b");

        let hits = fold.query(CapabilityQuery::InClass(0x100));
        assert_eq!(hits.len(), 2, "both publishers' entries coexist");

        let idle = fold.query(CapabilityQuery::InState(NodeState::Idle));
        assert_eq!(idle.len(), 1);
        assert_eq!(idle[0].0, (0x100, 0xA));

        let busy = fold.query(CapabilityQuery::InState(NodeState::Busy));
        assert_eq!(busy.len(), 1);
        assert_eq!(busy[0].0, (0x100, 0xB));
    }

    #[test]
    fn replace_updates_secondary_index_drops_stale_tags() {
        let fold = new_fold();
        let kp = EntityKeypair::generate();

        fold.apply(sign_cap(
            &kp,
            0xA,
            1,
            0x100,
            vec!["gpu"],
            NodeState::Idle,
            Some("us-east"),
        ))
        .expect("v1");

        fold.apply(sign_cap(
            &kp,
            0xA,
            2,
            0x100,
            vec!["tpu"],
            NodeState::Busy,
            Some("us-west"),
        ))
        .expect("v2");

        let stale = fold.query(CapabilityQuery::HasAllTags(vec!["gpu".into()]));
        assert!(stale.is_empty());
        let fresh = fold.query(CapabilityQuery::HasAllTags(vec!["tpu".into()]));
        assert_eq!(fresh.len(), 1);

        let stale_state = fold.query(CapabilityQuery::InState(NodeState::Idle));
        assert!(stale_state.is_empty());
        let new_state = fold.query(CapabilityQuery::InState(NodeState::Busy));
        assert_eq!(new_state.len(), 1);

        assert!(fold
            .query(CapabilityQuery::InRegion("us-east".into()))
            .is_empty());
        assert_eq!(
            fold.query(CapabilityQuery::InRegion("us-west".into()))
                .len(),
            1
        );
    }

    #[test]
    fn has_all_tags_finds_only_entries_carrying_every_tag() {
        let fold = new_fold();
        let kp = EntityKeypair::generate();
        fold.apply(sign_cap(
            &kp,
            0x1,
            1,
            0x100,
            vec!["a", "b", "c"],
            NodeState::Idle,
            None,
        ))
        .unwrap();
        fold.apply(sign_cap(
            &kp,
            0x2,
            1,
            0x100,
            vec!["a", "b"],
            NodeState::Idle,
            None,
        ))
        .unwrap();
        fold.apply(sign_cap(
            &kp,
            0x3,
            1,
            0x100,
            vec!["a"],
            NodeState::Idle,
            None,
        ))
        .unwrap();

        let hits: std::collections::HashSet<_> = fold
            .query(CapabilityQuery::HasAllTags(vec![
                "a".into(),
                "b".into(),
                "c".into(),
            ]))
            .into_iter()
            .map(|((_, n), _)| n)
            .collect();
        assert_eq!(hits, [0x1].into_iter().collect());

        let hits: std::collections::HashSet<_> = fold
            .query(CapabilityQuery::HasAllTags(vec!["a".into(), "b".into()]))
            .into_iter()
            .map(|((_, n), _)| n)
            .collect();
        assert_eq!(hits, [0x1, 0x2].into_iter().collect());

        let hits: std::collections::HashSet<_> = fold
            .query(CapabilityQuery::HasAllTags(vec!["a".into()]))
            .into_iter()
            .map(|((_, n), _)| n)
            .collect();
        assert_eq!(hits, [0x1, 0x2, 0x3].into_iter().collect());
    }

    #[test]
    fn has_any_tag_returns_union_across_buckets() {
        let fold = new_fold();
        let kp = EntityKeypair::generate();
        fold.apply(sign_cap(
            &kp,
            0x1,
            1,
            0x100,
            vec!["x"],
            NodeState::Idle,
            None,
        ))
        .unwrap();
        fold.apply(sign_cap(
            &kp,
            0x2,
            1,
            0x100,
            vec!["y"],
            NodeState::Idle,
            None,
        ))
        .unwrap();
        fold.apply(sign_cap(
            &kp,
            0x3,
            1,
            0x100,
            vec!["z"],
            NodeState::Idle,
            None,
        ))
        .unwrap();

        let hits: std::collections::HashSet<_> = fold
            .query(CapabilityQuery::HasAnyTag(vec!["x".into(), "y".into()]))
            .into_iter()
            .map(|((_, n), _)| n)
            .collect();
        assert_eq!(hits, [0x1, 0x2].into_iter().collect());
    }

    #[test]
    fn composite_query_intersects_every_populated_filter_axis() {
        let fold = new_fold();
        let kp = EntityKeypair::generate();

        // Only 0xA satisfies every axis; 0xB differs in state, 0xC in region.
        fold.apply(sign_cap(
            &kp,
            0xA,
            1,
            0x100,
            vec!["gpu"],
            NodeState::Idle,
            Some("us-east"),
        ))
        .unwrap();
        fold.apply(sign_cap(
            &kp,
            0xB,
            1,
            0x100,
            vec!["gpu"],
            NodeState::Busy,
            Some("us-east"),
        ))
        .unwrap();
        fold.apply(sign_cap(
            &kp,
            0xC,
            1,
            0x100,
            vec!["gpu"],
            NodeState::Idle,
            Some("us-west"),
        ))
        .unwrap();

        let filter = CapabilityFilter {
            class: Some(0x100),
            tags_all: vec!["gpu".into()],
            state: Some(NodeState::Idle),
            region: Some("us-east".into()),
            ..CapabilityFilter::default()
        };
        let hits: Vec<_> = fold
            .query(CapabilityQuery::Composite(filter))
            .into_iter()
            .map(|((_, n), _)| n)
            .collect();
        assert_eq!(hits, vec![0xA]);
    }

    #[test]
    fn composite_query_honours_limit() {
        let fold = new_fold();
        let kp = EntityKeypair::generate();
        for i in 0..10 {
            fold.apply(sign_cap(
                &kp,
                i,
                1,
                0x100,
                vec!["gpu"],
                NodeState::Idle,
                None,
            ))
            .unwrap();
        }
        let filter = CapabilityFilter {
            class: Some(0x100),
            limit: 3,
            ..CapabilityFilter::default()
        };
        let hits = fold.query(CapabilityQuery::Composite(filter));
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn composite_query_limit_zero_or_oversized_returns_everything() {
        let fold = new_fold();
        let kp = EntityKeypair::generate();
        for i in 0..5 {
            fold.apply(sign_cap(
                &kp,
                i,
                1,
                0x100,
                vec!["gpu"],
                NodeState::Idle,
                None,
            ))
            .unwrap();
        }
        // `limit == 0` means unlimited.
        let filter = CapabilityFilter {
            class: Some(0x100),
            limit: 0,
            ..CapabilityFilter::default()
        };
        assert_eq!(fold.query(CapabilityQuery::Composite(filter)).len(), 5);

        // A limit above the match count must not drop anything.
        let filter = CapabilityFilter {
            class: Some(0x100),
            limit: 10,
            ..CapabilityFilter::default()
        };
        assert_eq!(fold.query(CapabilityQuery::Composite(filter)).len(), 5);
    }

    #[test]
    fn composite_query_with_tags_any_filters_correctly() {
        let fold = new_fold();
        let kp = EntityKeypair::generate();
        fold.apply(sign_cap(
            &kp,
            0xA,
            1,
            0x100,
            vec!["common", "fast"],
            NodeState::Idle,
            None,
        ))
        .unwrap();
        fold.apply(sign_cap(
            &kp,
            0xB,
            1,
            0x100,
            vec!["common", "slow"],
            NodeState::Idle,
            None,
        ))
        .unwrap();
        fold.apply(sign_cap(
            &kp,
            0xC,
            1,
            0x100,
            vec!["common"],
            NodeState::Idle,
            None,
        ))
        .unwrap();

        // 0xC carries `common` but neither `fast` nor `slow`, so it is excluded.
        let filter = CapabilityFilter {
            tags_all: vec!["common".into()],
            tags_any: vec!["fast".into(), "slow".into()],
            ..CapabilityFilter::default()
        };
        let hits: std::collections::HashSet<_> = fold
            .query(CapabilityQuery::Composite(filter))
            .into_iter()
            .map(|((_, n), _)| n)
            .collect();
        assert_eq!(hits, [0xA, 0xB].into_iter().collect());
    }

    #[test]
    fn composite_query_tag_groups_match_synthetic_hardware_tags() {
        let fold = new_fold();
        let kp = EntityKeypair::generate();
        fold.apply(sign_cap_with_hardware(
            &kp,
            0xA,
            1,
            Some(gpu_hardware(Some("nvidia"), 1)),
        ))
        .unwrap();
        fold.apply(sign_cap_with_hardware(
            &kp,
            0xB,
            1,
            Some(gpu_hardware(Some("amd"), 2)),
        ))
        .unwrap();
        // Hardware present but no GPU: no `gpu:*` synthetic tags.
        fold.apply(sign_cap_with_hardware(&kp, 0xC, 1, Some(gpu_hardware(None, 0))))
            .unwrap();
        fold.apply(sign_cap_with_hardware(&kp, 0xD, 1, None)).unwrap();

        // A single group is an OR over its synthetic tags.
        assert_eq!(
            nodes_matching_groups(&fold, vec![vec!["gpu:vendor:nvidia", "gpu:vendor:amd"]]),
            [0xA, 0xB].into_iter().collect()
        );
        assert_eq!(
            nodes_matching_groups(&fold, vec![vec!["gpu:present"]]),
            [0xA, 0xB].into_iter().collect()
        );
        // Separate groups are ANDed.
        assert_eq!(
            nodes_matching_groups(&fold, vec![vec!["gpu:present"], vec!["gpu:vendor:nvidia"]]),
            [0xA].into_iter().collect()
        );
        assert!(nodes_matching_groups(&fold, vec![vec!["gpu:vendor:intel"]]).is_empty());
        // An empty group imposes no constraint.
        assert_eq!(nodes_matching_groups(&fold, vec![vec![]]).len(), 4);

        // Re-announcing 0xA without hardware must drop its stale synthetic tags.
        fold.apply(sign_cap_with_hardware(&kp, 0xA, 2, None)).unwrap();
        assert!(nodes_matching_groups(&fold, vec![vec!["gpu:vendor:nvidia"]]).is_empty());
        assert_eq!(
            nodes_matching_groups(&fold, vec![vec!["gpu:present"]]),
            [0xB].into_iter().collect()
        );
    }

    #[test]
    fn evict_node_drops_every_class_entry_and_cleans_indexes() {
        let fold = new_fold();
        let kp = EntityKeypair::generate();
        fold.apply(sign_cap(
            &kp,
            0xA,
            1,
            0x100,
            vec!["gpu"],
            NodeState::Idle,
            Some("r1"),
        ))
        .unwrap();
        fold.apply(sign_cap(
            &kp,
            0xA,
            1,
            0x200,
            vec!["tpu"],
            NodeState::Busy,
            Some("r2"),
        ))
        .unwrap();
        fold.apply(sign_cap(
            &kp,
            0xB,
            1,
            0x100,
            vec!["gpu"],
            NodeState::Idle,
            Some("r1"),
        ))
        .unwrap();
        assert_eq!(fold.stats().entries, 3);

        fold.evict_node(0xA, "test");
        assert_eq!(fold.stats().entries, 1);
        assert_eq!(fold.stats().evictions, 2);

        let gpu_hits: std::collections::HashSet<_> = fold
            .query(CapabilityQuery::HasAllTags(vec!["gpu".into()]))
            .into_iter()
            .map(|((_, n), _)| n)
            .collect();
        assert_eq!(gpu_hits, [0xB].into_iter().collect());
        let tpu_hits = fold.query(CapabilityQuery::HasAllTags(vec!["tpu".into()]));
        assert!(tpu_hits.is_empty());
    }

    #[test]
    fn reflex_addr_for_returns_first_advertised_addr_across_publisher_classes() {
        use std::net::SocketAddr;
        let fold = new_fold();
        let kp = EntityKeypair::generate();
        let addr: SocketAddr = "203.0.113.4:7000".parse().unwrap();

        // Class 0x100 has no address and class 0x101 has one; lookup must find it.
        fold.apply(sign_cap_with_reflex(
            &kp,
            0xAA,
            1,
            0x100,
            vec![],
            NodeState::Idle,
            None,
            None,
        ))
        .expect("class 0x100");
        fold.apply(sign_cap_with_reflex(
            &kp,
            0xAA,
            1,
            0x101,
            vec![],
            NodeState::Idle,
            None,
            Some(addr),
        ))
        .expect("class 0x101");

        assert_eq!(super::reflex_addr_for(&fold, 0xAA), Some(addr));
        assert_eq!(super::reflex_addr_for(&fold, 0xBB), None);
    }

    #[test]
    fn reflex_addr_for_returns_none_when_publisher_advertises_no_addr() {
        let fold = new_fold();
        let kp = EntityKeypair::generate();
        fold.apply(sign_cap(&kp, 0xAA, 1, 0x100, vec![], NodeState::Idle, None))
            .expect("class 0x100");
        assert_eq!(super::reflex_addr_for(&fold, 0xAA), None);
    }

    #[test]
    fn capability_tags_for_all_matches_per_node_walk() {
        let fold = new_fold();
        let kp_a = EntityKeypair::generate();
        let kp_b = EntityKeypair::generate();
        fold.apply(sign_cap(
            &kp_a,
            0xA,
            1,
            0x100,
            vec!["gpu", "vendor.nvidia"],
            NodeState::Idle,
            None,
        ))
        .expect("a-100");
        fold.apply(sign_cap(
            &kp_a,
            0xA,
            1,
            0x200,
            vec!["gpu", "model:llama"],
            NodeState::Idle,
            None,
        ))
        .expect("a-200");
        fold.apply(sign_cap(
            &kp_b,
            0xB,
            1,
            0x100,
            vec!["cpu-only"],
            NodeState::Idle,
            None,
        ))
        .expect("b-100");

        let batched = super::capability_tags_for_all(&fold);
        assert_eq!(batched.len(), 2);

        let mut tags_a = batched.get(&0xA).cloned().unwrap_or_default();
        tags_a.sort();
        assert_eq!(
            tags_a,
            vec![
                "gpu".to_string(),
                "model:llama".to_string(),
                "vendor.nvidia".to_string()
            ],
            "publisher A unions tags across both class entries"
        );

        let mut tags_b = batched.get(&0xB).cloned().unwrap_or_default();
        tags_b.sort();
        assert_eq!(tags_b, vec!["cpu-only".to_string()]);

        // The batched form must agree with the per-node lookup.
        for (node_id, batched_tags) in &batched {
            let mut single = super::capability_tags_for(&fold, *node_id);
            single.sort();
            let mut batched_sorted = batched_tags.clone();
            batched_sorted.sort();
            assert_eq!(single, batched_sorted, "mismatch for node 0x{:x}", node_id);
        }
    }

    #[test]
    fn capability_tags_for_all_returns_empty_for_empty_fold() {
        let fold = new_fold();
        let batched = super::capability_tags_for_all(&fold);
        assert!(batched.is_empty());
    }

    #[test]
    fn runtime_ttl_sweeps_stale_capability_entries() {
        let fold = new_fold();
        let kp = EntityKeypair::generate();
        let ann = SignedAnnouncement::sign(
            &kp,
            CapabilityFold::KIND_ID,
            0x100,
            0xA,
            1,
            EnvelopeMeta {
                ttl_secs: Some(0),
                ..Default::default()
            },
            CapabilityMembership {
                class_hash: 0x100,
                tags: vec!["gpu".into()],
                hardware: None,
                state: NodeState::Idle,
                region: None,
                price_quote: None,
                reflex_addr: None,
                allowed_nodes: Vec::new(),
                allowed_subnets: Vec::new(),
                allowed_groups: Vec::new(),
                metadata: BTreeMap::new(),
            },
        )
        .unwrap();
        fold.apply(ann).unwrap();
        assert_eq!(fold.stats().entries, 1);

        // A zero TTL expires immediately; give the clock a moment to advance.
        std::thread::sleep(Duration::from_millis(10));
        let n = fold.sweep_expired_now();
        assert_eq!(n, 1);
        assert_eq!(fold.stats().entries, 0);
        assert_eq!(fold.stats().expiries, 1);

        assert!(fold
            .query(CapabilityQuery::HasAllTags(vec!["gpu".into()]))
            .is_empty());
    }

    #[test]
    fn capability_fold_plugs_into_registry_and_dispatches_signed_envelopes() {
        let registry = FoldRegistry::new();
        let fold: Arc<Fold<CapabilityFold>> = Arc::new(new_fold());
        registry.register(fold.clone());

        let kp = EntityKeypair::generate();
        let ann = sign_cap(
            &kp,
            0xA,
            1,
            0x100,
            vec!["gpu"],
            NodeState::Idle,
            Some("us-east"),
        );
        let bytes = ann.encode().expect("encode");
        let outcome = registry.dispatch(&bytes, kp.entity_id()).expect("dispatch");
        assert_eq!(outcome, ApplyOutcome::Inserted);
        assert_eq!(fold.stats().entries, 1);
    }
}