1use std::collections::{BTreeMap, BTreeSet};
10
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13
14use selene_core::{Change, DbString, EdgeId, LabelSet, NodeId};
15
16use crate::index_provider::{
17 IndexProvider, ProviderError, ProviderTag, SubTag, VectorCandidateStateInfo,
18};
19use crate::store::RowIndex;
20use crate::{SeleneGraph, VectorCandidateSet};
21
22pub const CANDIDATE_STATE_PROVIDER_TAG: [u8; 4] = *b"CSET";
24
25pub const CANDIDATE_STATE_SUB: [u8; 4] = *b"STAT";
27
28const SNAPSHOT_VERSION: u8 = 1;
29const SUB_TAGS: &[SubTag] = &[SubTag(CANDIDATE_STATE_SUB)];
30
31#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
33pub struct CandidateStateSpec {
34 pub name: DbString,
36 pub required_label: Option<DbString>,
38 pub require_outgoing: Vec<DbString>,
40 pub require_incoming: Vec<DbString>,
42 pub exclude_outgoing: Vec<DbString>,
44 pub exclude_incoming: Vec<DbString>,
46}
47
48impl CandidateStateSpec {
49 #[must_use]
51 pub fn new(name: DbString) -> Self {
52 Self {
53 name,
54 required_label: None,
55 require_outgoing: Vec::new(),
56 require_incoming: Vec::new(),
57 exclude_outgoing: Vec::new(),
58 exclude_incoming: Vec::new(),
59 }
60 }
61
62 #[must_use]
64 pub fn require_label(mut self, label: DbString) -> Self {
65 self.required_label = Some(label);
66 self
67 }
68
69 #[must_use]
71 pub fn require_outgoing(mut self, label: DbString) -> Self {
72 insert_sorted_unique(&mut self.require_outgoing, label);
73 self
74 }
75
76 #[must_use]
78 pub fn require_incoming(mut self, label: DbString) -> Self {
79 insert_sorted_unique(&mut self.require_incoming, label);
80 self
81 }
82
83 #[must_use]
85 pub fn exclude_outgoing(mut self, label: DbString) -> Self {
86 insert_sorted_unique(&mut self.exclude_outgoing, label);
87 self
88 }
89
90 #[must_use]
92 pub fn exclude_incoming(mut self, label: DbString) -> Self {
93 insert_sorted_unique(&mut self.exclude_incoming, label);
94 self
95 }
96}
97
98pub struct MaintainedCandidateStateProvider {
100 specs: Vec<CandidateStateSpec>,
101 state: Mutex<CandidateState>,
102}
103
104impl MaintainedCandidateStateProvider {
105 pub fn new(specs: impl IntoIterator<Item = CandidateStateSpec>) -> Result<Self, ProviderError> {
111 let mut specs = specs.into_iter().collect::<Vec<_>>();
112 for spec in &mut specs {
113 canonicalize_labels(&mut spec.require_outgoing);
114 canonicalize_labels(&mut spec.require_incoming);
115 canonicalize_labels(&mut spec.exclude_outgoing);
116 canonicalize_labels(&mut spec.exclude_incoming);
117 }
118 validate_unique_specs(&specs)?;
119 Ok(Self {
120 state: Mutex::new(CandidateState::new(&specs)),
121 specs,
122 })
123 }
124
125 pub fn from_graph(
132 specs: impl IntoIterator<Item = CandidateStateSpec>,
133 graph: &SeleneGraph,
134 ) -> Result<Self, ProviderError> {
135 let provider = Self::new(specs)?;
136 provider.rebuild_from_graph(graph)?;
137 Ok(provider)
138 }
139
140 pub fn rebuild_from_graph(&self, graph: &SeleneGraph) -> Result<(), ProviderError> {
149 let mut rebuilt = CandidateState::new(&self.specs);
150 for row in graph.live_nodes() {
151 let row = RowIndex::new(row);
152 let id = graph.node_id_for_row(row).ok_or_else(|| {
153 inconsistent(format!("live node row {} has no external id", row.get()))
154 })?;
155 let labels = graph
156 .node_labels(id)
157 .ok_or_else(|| inconsistent(format!("live node {id} has no label column entry")))?;
158 rebuilt.node_labels.insert(id, labels.clone());
159 }
160 for row in graph.live_edges() {
161 let row = RowIndex::new(row);
162 let id = graph.edge_id_for_row(row).ok_or_else(|| {
163 inconsistent(format!("live edge row {} has no external id", row.get()))
164 })?;
165 let label = graph
166 .edge_label(id)
167 .ok_or_else(|| inconsistent(format!("live edge {id} has no label")))?;
168 if !watches_label(&self.specs, label) {
169 continue;
170 }
171 let (source, target) = graph
172 .edge_endpoints(id)
173 .ok_or_else(|| inconsistent(format!("live edge {id} has no endpoints")))?;
174 rebuilt.edges.insert(
175 id,
176 TrackedEdge {
177 label: label.clone(),
178 source,
179 target,
180 },
181 );
182 }
183 rebuilt.rebuild_derived(&self.specs);
184 rebuilt.generation = graph.meta.generation;
185 *self.state.lock() = rebuilt;
186 Ok(())
187 }
188
189 #[must_use]
191 pub fn spec(&self, name: &DbString) -> Option<&CandidateStateSpec> {
192 self.specs.iter().find(|spec| &spec.name == name)
193 }
194
195 #[must_use]
197 pub fn candidate_set(&self, name: &DbString) -> Option<VectorCandidateSet> {
198 let state = self.state.lock();
199 state.members.get(name).map(|members| {
200 VectorCandidateSet::from_canonical_nodes(members.iter().copied().collect())
201 })
202 }
203
204 #[must_use]
206 pub fn generation(&self) -> u64 {
207 self.state.lock().generation
208 }
209
210 pub fn candidate_set_at_generation(
217 &self,
218 name: &DbString,
219 generation: u64,
220 ) -> Result<Option<VectorCandidateSet>, ProviderError> {
221 let state = self.state.lock();
222 if state.generation != generation {
223 return Err(inconsistent(format!(
224 "candidate-state generation {} does not match graph generation {generation}",
225 state.generation
226 )));
227 }
228 Ok(state.members.get(name).map(|members| {
229 VectorCandidateSet::from_canonical_nodes(members.iter().copied().collect())
230 }))
231 }
232
233 pub fn candidate_state_infos_at_generation(
240 &self,
241 generation: u64,
242 ) -> Result<Vec<VectorCandidateStateInfo>, ProviderError> {
243 let state = self.state.lock();
244 if state.generation != generation {
245 return Err(inconsistent(format!(
246 "candidate-state generation {} does not match graph generation {generation}",
247 state.generation
248 )));
249 }
250 Ok(self
251 .specs
252 .iter()
253 .map(|spec| VectorCandidateStateInfo {
254 name: spec.name.clone(),
255 generation,
256 candidate_count: state.members.get(&spec.name).map_or(0, BTreeSet::len),
257 required_label: spec.required_label.clone(),
258 require_outgoing: spec.require_outgoing.clone(),
259 require_incoming: spec.require_incoming.clone(),
260 exclude_outgoing: spec.exclude_outgoing.clone(),
261 exclude_incoming: spec.exclude_incoming.clone(),
262 })
263 .collect())
264 }
265
266 #[must_use]
268 pub fn contains(&self, name: &DbString, node: NodeId) -> bool {
269 self.state
270 .lock()
271 .members
272 .get(name)
273 .is_some_and(|members| members.contains(&node))
274 }
275}
276
277impl IndexProvider for MaintainedCandidateStateProvider {
278 fn provider_tag(&self) -> ProviderTag {
279 ProviderTag(CANDIDATE_STATE_PROVIDER_TAG)
280 }
281
282 fn read_section(&self, sub_tag: SubTag, bytes: &[u8]) -> Result<(), ProviderError> {
283 ensure_state_subtag(sub_tag)?;
284 let snapshot: CandidateStateSnapshot = postcard::from_bytes(bytes).map_err(|error| {
285 invalid_payload(format!("CSET/STAT postcard decode failed: {error}"))
286 })?;
287 if snapshot.version != SNAPSHOT_VERSION {
288 return Err(invalid_payload(format!(
289 "unsupported CSET/STAT version {}",
290 snapshot.version
291 )));
292 }
293 if snapshot.specs != self.specs {
294 return Err(invalid_payload(
295 "CSET/STAT specs differ from provider configuration".to_owned(),
296 ));
297 }
298 let mut state = CandidateState::new(&self.specs);
299 state.generation = snapshot.generation;
300 for (id, labels) in snapshot.node_labels {
301 if state.node_labels.insert(id, labels).is_some() {
302 return Err(invalid_payload(format!(
303 "duplicate node id {id} in CSET/STAT"
304 )));
305 }
306 }
307 for (id, edge) in snapshot.edges {
308 if !watches_label(&self.specs, &edge.label) {
309 return Err(invalid_payload(format!(
310 "unwatched edge label {} in CSET/STAT",
311 edge.label.as_str()
312 )));
313 }
314 if !state.node_labels.contains_key(&edge.source)
315 || !state.node_labels.contains_key(&edge.target)
316 {
317 return Err(invalid_payload(format!(
318 "tracked edge {id} references missing endpoint in CSET/STAT"
319 )));
320 }
321 if state.edges.insert(id, edge).is_some() {
322 return Err(invalid_payload(format!(
323 "duplicate edge id {id} in CSET/STAT"
324 )));
325 }
326 }
327 state.rebuild_derived(&self.specs);
328 *self.state.lock() = state;
329 Ok(())
330 }
331
332 fn write_section(&self, sub_tag: SubTag) -> Result<Vec<u8>, ProviderError> {
333 ensure_state_subtag(sub_tag)?;
334 let state = self.state.lock();
335 let snapshot = CandidateStateSnapshot {
336 version: SNAPSHOT_VERSION,
337 generation: state.generation,
338 specs: self.specs.clone(),
339 node_labels: state
340 .node_labels
341 .iter()
342 .map(|(id, labels)| (*id, labels.clone()))
343 .collect(),
344 edges: state
345 .edges
346 .iter()
347 .map(|(id, edge)| (*id, edge.clone()))
348 .collect(),
349 };
350 postcard::to_stdvec(&snapshot).map_err(|error| ProviderError::SerializationFailed {
351 reason: format!("CSET/STAT postcard encode failed: {error}"),
352 })
353 }
354
355 fn on_change(&self, change: &Change) -> Result<(), ProviderError> {
356 self.state.lock().apply_change(&self.specs, change)
357 }
358
359 fn handles_change_batches(&self) -> bool {
360 true
361 }
362
363 fn on_changes(&self, changes: &[Change]) -> Result<(), ProviderError> {
364 let mut state = self.state.lock();
365 for change in changes {
366 state.apply_change(&self.specs, change)?;
367 }
368 Ok(())
369 }
370
371 fn rebuild_from_graph(&self, graph: &SeleneGraph) -> Result<(), ProviderError> {
372 MaintainedCandidateStateProvider::rebuild_from_graph(self, graph)
373 }
374
375 fn on_commit_applied(&self, generation: u64) -> Result<(), ProviderError> {
376 self.state.lock().generation = generation;
377 Ok(())
378 }
379
380 fn vector_candidate_set(
381 &self,
382 name: &DbString,
383 generation: u64,
384 ) -> Result<Option<VectorCandidateSet>, ProviderError> {
385 self.candidate_set_at_generation(name, generation)
386 }
387
388 fn vector_candidate_state_infos(
389 &self,
390 generation: u64,
391 ) -> Result<Vec<VectorCandidateStateInfo>, ProviderError> {
392 self.candidate_state_infos_at_generation(generation)
393 }
394
395 fn declared_sub_tags(&self) -> &[SubTag] {
396 SUB_TAGS
397 }
398}
399
400#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
401struct TrackedEdge {
402 label: DbString,
403 source: NodeId,
404 target: NodeId,
405}
406
407#[derive(Clone, Debug, Deserialize, Serialize)]
408struct CandidateStateSnapshot {
409 version: u8,
410 generation: u64,
411 specs: Vec<CandidateStateSpec>,
412 node_labels: Vec<(NodeId, LabelSet)>,
413 edges: Vec<(EdgeId, TrackedEdge)>,
414}
415
416#[derive(Clone, Debug)]
417struct CandidateState {
418 generation: u64,
419 node_labels: BTreeMap<NodeId, LabelSet>,
420 edges: BTreeMap<EdgeId, TrackedEdge>,
421 outgoing_counts: BTreeMap<(NodeId, DbString), usize>,
422 incoming_counts: BTreeMap<(NodeId, DbString), usize>,
423 members: BTreeMap<DbString, BTreeSet<NodeId>>,
424}
425
426impl CandidateState {
427 fn new(specs: &[CandidateStateSpec]) -> Self {
428 Self {
429 generation: 0,
430 node_labels: BTreeMap::new(),
431 edges: BTreeMap::new(),
432 outgoing_counts: BTreeMap::new(),
433 incoming_counts: BTreeMap::new(),
434 members: empty_members(specs),
435 }
436 }
437
438 fn apply_change(
439 &mut self,
440 specs: &[CandidateStateSpec],
441 change: &Change,
442 ) -> Result<(), ProviderError> {
443 match change {
444 Change::NodeCreated { id, labels, .. } => {
445 if self.node_labels.insert(*id, labels.clone()).is_some() {
446 return Err(inconsistent(format!("duplicate node create for {id}")));
447 }
448 self.recompute_node(specs, *id);
449 }
450 Change::NodeUpdated {
451 id, labels_diff, ..
452 } => {
453 let labels = self
454 .node_labels
455 .get_mut(id)
456 .ok_or_else(|| inconsistent(format!("label update for unknown node {id}")))?;
457 for label in &labels_diff.removed {
458 labels.remove(label);
459 }
460 for label in &labels_diff.added {
461 labels.insert(label.clone());
462 }
463 self.recompute_node(specs, *id);
464 }
465 Change::NodeDeleted { id } => {
466 self.node_labels.remove(id);
467 self.remove_incident_edges(specs, *id);
468 self.recompute_node(specs, *id);
469 }
470 Change::NodeLabelRemoved { id, label } => {
471 let labels = self
472 .node_labels
473 .get_mut(id)
474 .ok_or_else(|| inconsistent(format!("label removal for unknown node {id}")))?;
475 labels.remove(label);
476 self.recompute_node(specs, *id);
477 }
478 Change::EdgeCreated {
479 id,
480 label,
481 source,
482 target,
483 ..
484 } => {
485 if watches_label(specs, label) {
486 let edge = TrackedEdge {
487 label: label.clone(),
488 source: *source,
489 target: *target,
490 };
491 if self.edges.insert(*id, edge.clone()).is_some() {
492 return Err(inconsistent(format!("duplicate edge create for {id}")));
493 }
494 self.increment_edge(&edge);
495 self.recompute_node(specs, *source);
496 self.recompute_node(specs, *target);
497 }
498 }
499 Change::EdgeDeleted { id } => {
500 if let Some(edge) = self.edges.remove(id) {
501 self.decrement_edge(&edge);
502 self.recompute_node(specs, edge.source);
503 self.recompute_node(specs, edge.target);
504 }
505 }
506 Change::GraphReset {} => {
507 *self = Self::new(specs);
508 }
509 Change::NodesOfTypeTruncated { label } => {
510 let removed = self
511 .node_labels
512 .iter()
513 .filter_map(|(id, labels)| labels.contains(label).then_some(*id))
514 .collect::<BTreeSet<_>>();
515 if !removed.is_empty() {
516 self.node_labels.retain(|id, _| !removed.contains(id));
517 self.edges.retain(|_, edge| {
518 !removed.contains(&edge.source) && !removed.contains(&edge.target)
519 });
520 self.rebuild_derived(specs);
521 }
522 }
523 Change::EdgesOfTypeTruncated { label } => {
524 if watches_label(specs, label) {
525 self.edges.retain(|_, edge| edge.label != *label);
526 self.rebuild_derived(specs);
527 }
528 }
529 Change::EdgeUpdated { .. }
530 | Change::EdgePropertyRemoved { .. }
531 | Change::NodePropertyRemoved { .. }
532 | Change::SchemaChanged { .. } => {}
533 }
534 Ok(())
535 }
536
537 fn rebuild_derived(&mut self, specs: &[CandidateStateSpec]) {
538 self.outgoing_counts.clear();
539 self.incoming_counts.clear();
540 self.members = empty_members(specs);
541 for edge in self.edges.values().cloned().collect::<Vec<_>>() {
542 self.increment_edge(&edge);
543 }
544 for id in self.node_labels.keys().copied().collect::<Vec<_>>() {
545 self.recompute_node(specs, id);
546 }
547 }
548
549 fn increment_edge(&mut self, edge: &TrackedEdge) {
550 *self
551 .outgoing_counts
552 .entry((edge.source, edge.label.clone()))
553 .or_insert(0) += 1;
554 *self
555 .incoming_counts
556 .entry((edge.target, edge.label.clone()))
557 .or_insert(0) += 1;
558 }
559
560 fn decrement_edge(&mut self, edge: &TrackedEdge) {
561 decrement_count(&mut self.outgoing_counts, (edge.source, edge.label.clone()));
562 decrement_count(&mut self.incoming_counts, (edge.target, edge.label.clone()));
563 }
564
565 fn remove_incident_edges(&mut self, specs: &[CandidateStateSpec], node: NodeId) {
566 let incident = self
567 .edges
568 .iter()
569 .filter_map(|(id, edge)| {
570 (edge.source == node || edge.target == node).then_some((*id, edge.clone()))
571 })
572 .collect::<Vec<_>>();
573 for (id, edge) in incident {
574 self.edges.remove(&id);
575 self.decrement_edge(&edge);
576 if edge.source != node {
577 self.recompute_node(specs, edge.source);
578 }
579 if edge.target != node {
580 self.recompute_node(specs, edge.target);
581 }
582 }
583 }
584
585 fn recompute_node(&mut self, specs: &[CandidateStateSpec], node: NodeId) {
586 let labels = self.node_labels.get(&node).cloned();
587 for spec in specs {
588 let include = labels.as_ref().is_some_and(|labels| {
589 spec.required_label
590 .as_ref()
591 .is_none_or(|required| labels.contains(required))
592 && spec
593 .require_outgoing
594 .iter()
595 .all(|label| has_count(&self.outgoing_counts, node, label))
596 && spec
597 .require_incoming
598 .iter()
599 .all(|label| has_count(&self.incoming_counts, node, label))
600 && spec
601 .exclude_outgoing
602 .iter()
603 .all(|label| !has_count(&self.outgoing_counts, node, label))
604 && spec
605 .exclude_incoming
606 .iter()
607 .all(|label| !has_count(&self.incoming_counts, node, label))
608 });
609 let members = self.members.entry(spec.name.clone()).or_default();
610 if include {
611 members.insert(node);
612 } else {
613 members.remove(&node);
614 }
615 }
616 }
617}
618
619fn validate_unique_specs(specs: &[CandidateStateSpec]) -> Result<(), ProviderError> {
620 let mut seen = BTreeSet::new();
621 for spec in specs {
622 if !seen.insert(spec.name.clone()) {
623 return Err(inconsistent(format!(
624 "duplicate candidate-state spec name {}",
625 spec.name.as_str()
626 )));
627 }
628 }
629 Ok(())
630}
631
632fn empty_members(specs: &[CandidateStateSpec]) -> BTreeMap<DbString, BTreeSet<NodeId>> {
633 specs
634 .iter()
635 .map(|spec| (spec.name.clone(), BTreeSet::new()))
636 .collect()
637}
638
639fn watches_label(specs: &[CandidateStateSpec], label: &DbString) -> bool {
640 specs.iter().any(|spec| {
641 spec.require_outgoing.binary_search(label).is_ok()
642 || spec.require_incoming.binary_search(label).is_ok()
643 || spec.exclude_outgoing.binary_search(label).is_ok()
644 || spec.exclude_incoming.binary_search(label).is_ok()
645 })
646}
647
648fn has_count(counts: &BTreeMap<(NodeId, DbString), usize>, node: NodeId, label: &DbString) -> bool {
649 counts
650 .get(&(node, label.clone()))
651 .is_some_and(|count| *count > 0)
652}
653
654fn decrement_count(counts: &mut BTreeMap<(NodeId, DbString), usize>, key: (NodeId, DbString)) {
655 if let Some(count) = counts.get_mut(&key) {
656 *count = count.saturating_sub(1);
657 if *count == 0 {
658 counts.remove(&key);
659 }
660 }
661}
662
663fn insert_sorted_unique(labels: &mut Vec<DbString>, label: DbString) {
664 match labels.binary_search(&label) {
665 Ok(_) => {}
666 Err(index) => labels.insert(index, label),
667 }
668}
669
670fn canonicalize_labels(labels: &mut Vec<DbString>) {
671 labels.sort_unstable();
672 labels.dedup();
673}
674
675fn ensure_state_subtag(sub_tag: SubTag) -> Result<(), ProviderError> {
676 if sub_tag == SubTag(CANDIDATE_STATE_SUB) {
677 Ok(())
678 } else {
679 Err(invalid_payload(format!("unknown CSET sub-tag {sub_tag}")))
680 }
681}
682
683fn invalid_payload(reason: String) -> ProviderError {
684 ProviderError::InvalidPayload { reason }
685}
686
687fn inconsistent(reason: String) -> ProviderError {
688 ProviderError::Inconsistent { reason }
689}
690
691#[cfg(test)]
692#[path = "candidate_state/tests.rs"]
693mod tests;