Skip to main content

selene_graph/
candidate_state.rs

1//! Maintained graph-derived candidate sets.
2//!
3//! This module owns small, policy-neutral maintained node sets for graph/vector
4//! retrieval. A set can require node labels, require incoming/outgoing edge
5//! evidence, and exclude nodes that have disqualifying incoming/outgoing edges.
6//! That is enough to model active/current/unresolved memory subsets without
7//! hard-coding those application labels into the engine.
8
9use std::collections::{BTreeMap, BTreeSet};
10
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13
14use selene_core::{Change, DbString, EdgeId, LabelSet, NodeId};
15
16use crate::index_provider::{
17    IndexProvider, ProviderError, ProviderTag, SubTag, VectorCandidateStateInfo,
18};
19use crate::store::RowIndex;
20use crate::{SeleneGraph, VectorCandidateSet};
21
/// Provider tag for maintained graph candidate-state sections.
pub const CANDIDATE_STATE_PROVIDER_TAG: [u8; 4] = *b"CSET";

/// Provider-owned snapshot section for maintained candidate-state data.
pub const CANDIDATE_STATE_SUB: [u8; 4] = *b"STAT";

/// Format version stamped into every snapshot this provider writes; reading a
/// snapshot carrying any other version is rejected.
const SNAPSHOT_VERSION: u8 = 1;
/// The sub-tags this provider declares: only [`CANDIDATE_STATE_SUB`].
const SUB_TAGS: &[SubTag] = &[SubTag(CANDIDATE_STATE_SUB)];
30
/// Declarative rule for one maintained candidate set.
///
/// A node belongs to the set when it carries `required_label` (if one is
/// set), has at least one tracked edge for every `require_*` label, and has no
/// tracked edge for any `exclude_*` label. The builder methods on this type
/// and `MaintainedCandidateStateProvider::new` keep the four label lists
/// sorted and deduplicated; label lookups use binary search and rely on that.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct CandidateStateSpec {
    /// Stable set name used by callers to retrieve candidates.
    pub name: DbString,
    /// Optional node label required for membership.
    pub required_label: Option<DbString>,
    /// Outgoing edge labels required on the source node.
    pub require_outgoing: Vec<DbString>,
    /// Incoming edge labels required on the target node.
    pub require_incoming: Vec<DbString>,
    /// Outgoing edge labels that disqualify the source node.
    pub exclude_outgoing: Vec<DbString>,
    /// Incoming edge labels that disqualify the target node.
    pub exclude_incoming: Vec<DbString>,
}
47
48impl CandidateStateSpec {
49    /// Construct an unconstrained named candidate set.
50    #[must_use]
51    pub fn new(name: DbString) -> Self {
52        Self {
53            name,
54            required_label: None,
55            require_outgoing: Vec::new(),
56            require_incoming: Vec::new(),
57            exclude_outgoing: Vec::new(),
58            exclude_incoming: Vec::new(),
59        }
60    }
61
62    /// Require `label` for candidate membership.
63    #[must_use]
64    pub fn require_label(mut self, label: DbString) -> Self {
65        self.required_label = Some(label);
66        self
67    }
68
69    /// Require an outgoing edge carrying `label`.
70    #[must_use]
71    pub fn require_outgoing(mut self, label: DbString) -> Self {
72        insert_sorted_unique(&mut self.require_outgoing, label);
73        self
74    }
75
76    /// Require an incoming edge carrying `label`.
77    #[must_use]
78    pub fn require_incoming(mut self, label: DbString) -> Self {
79        insert_sorted_unique(&mut self.require_incoming, label);
80        self
81    }
82
83    /// Exclude nodes with an outgoing edge carrying `label`.
84    #[must_use]
85    pub fn exclude_outgoing(mut self, label: DbString) -> Self {
86        insert_sorted_unique(&mut self.exclude_outgoing, label);
87        self
88    }
89
90    /// Exclude nodes with an incoming edge carrying `label`.
91    #[must_use]
92    pub fn exclude_incoming(mut self, label: DbString) -> Self {
93        insert_sorted_unique(&mut self.exclude_incoming, label);
94        self
95    }
96}
97
/// First-party provider maintaining named graph-derived candidate sets.
pub struct MaintainedCandidateStateProvider {
    /// Configured specs, with label lists canonicalized and names checked for
    /// uniqueness by the constructor.
    specs: Vec<CandidateStateSpec>,
    /// Tracked graph facts and derived membership, guarded by a mutex because
    /// every provider method takes `&self`.
    state: Mutex<CandidateState>,
}
103
104impl MaintainedCandidateStateProvider {
105    /// Construct an empty provider for `specs`.
106    ///
107    /// # Errors
108    ///
109    /// Returns [`ProviderError`] when two specs use the same name.
110    pub fn new(specs: impl IntoIterator<Item = CandidateStateSpec>) -> Result<Self, ProviderError> {
111        let mut specs = specs.into_iter().collect::<Vec<_>>();
112        for spec in &mut specs {
113            canonicalize_labels(&mut spec.require_outgoing);
114            canonicalize_labels(&mut spec.require_incoming);
115            canonicalize_labels(&mut spec.exclude_outgoing);
116            canonicalize_labels(&mut spec.exclude_incoming);
117        }
118        validate_unique_specs(&specs)?;
119        Ok(Self {
120            state: Mutex::new(CandidateState::new(&specs)),
121            specs,
122        })
123    }
124
125    /// Construct a provider and initialize it from a graph snapshot.
126    ///
127    /// # Errors
128    ///
129    /// Returns [`ProviderError`] when specs are invalid or the graph snapshot is
130    /// internally inconsistent.
131    pub fn from_graph(
132        specs: impl IntoIterator<Item = CandidateStateSpec>,
133        graph: &SeleneGraph,
134    ) -> Result<Self, ProviderError> {
135        let provider = Self::new(specs)?;
136        provider.rebuild_from_graph(graph)?;
137        Ok(provider)
138    }
139
140    /// Rebuild all maintained state from `graph`.
141    ///
142    /// This is the safe attachment path when a provider is registered against an
143    /// already-populated graph instead of observing mutations from graph birth.
144    ///
145    /// # Errors
146    ///
147    /// Returns [`ProviderError`] if live row-to-id mappings are inconsistent.
148    pub fn rebuild_from_graph(&self, graph: &SeleneGraph) -> Result<(), ProviderError> {
149        let mut rebuilt = CandidateState::new(&self.specs);
150        for row in graph.live_nodes() {
151            let row = RowIndex::new(row);
152            let id = graph.node_id_for_row(row).ok_or_else(|| {
153                inconsistent(format!("live node row {} has no external id", row.get()))
154            })?;
155            let labels = graph
156                .node_labels(id)
157                .ok_or_else(|| inconsistent(format!("live node {id} has no label column entry")))?;
158            rebuilt.node_labels.insert(id, labels.clone());
159        }
160        for row in graph.live_edges() {
161            let row = RowIndex::new(row);
162            let id = graph.edge_id_for_row(row).ok_or_else(|| {
163                inconsistent(format!("live edge row {} has no external id", row.get()))
164            })?;
165            let label = graph
166                .edge_label(id)
167                .ok_or_else(|| inconsistent(format!("live edge {id} has no label")))?;
168            if !watches_label(&self.specs, label) {
169                continue;
170            }
171            let (source, target) = graph
172                .edge_endpoints(id)
173                .ok_or_else(|| inconsistent(format!("live edge {id} has no endpoints")))?;
174            rebuilt.edges.insert(
175                id,
176                TrackedEdge {
177                    label: label.clone(),
178                    source,
179                    target,
180                },
181            );
182        }
183        rebuilt.rebuild_derived(&self.specs);
184        rebuilt.generation = graph.meta.generation;
185        *self.state.lock() = rebuilt;
186        Ok(())
187    }
188
189    /// Return the configured spec named `name`.
190    #[must_use]
191    pub fn spec(&self, name: &DbString) -> Option<&CandidateStateSpec> {
192        self.specs.iter().find(|spec| &spec.name == name)
193    }
194
195    /// Return the current candidate set for `name`.
196    #[must_use]
197    pub fn candidate_set(&self, name: &DbString) -> Option<VectorCandidateSet> {
198        let state = self.state.lock();
199        state.members.get(name).map(|members| {
200            VectorCandidateSet::from_canonical_nodes(members.iter().copied().collect())
201        })
202    }
203
204    /// Return the provider generation watermark.
205    #[must_use]
206    pub fn generation(&self) -> u64 {
207        self.state.lock().generation
208    }
209
210    /// Return the current candidate set for `name` if it matches `generation`.
211    ///
212    /// # Errors
213    ///
214    /// Returns [`ProviderError`] when this provider has not applied every
215    /// mutation through `generation`.
216    pub fn candidate_set_at_generation(
217        &self,
218        name: &DbString,
219        generation: u64,
220    ) -> Result<Option<VectorCandidateSet>, ProviderError> {
221        let state = self.state.lock();
222        if state.generation != generation {
223            return Err(inconsistent(format!(
224                "candidate-state generation {} does not match graph generation {generation}",
225                state.generation
226            )));
227        }
228        Ok(state.members.get(name).map(|members| {
229            VectorCandidateSet::from_canonical_nodes(members.iter().copied().collect())
230        }))
231    }
232
233    /// Return generation-checked metadata for every configured candidate set.
234    ///
235    /// # Errors
236    ///
237    /// Returns [`ProviderError`] when this provider has not applied every
238    /// mutation through `generation`.
239    pub fn candidate_state_infos_at_generation(
240        &self,
241        generation: u64,
242    ) -> Result<Vec<VectorCandidateStateInfo>, ProviderError> {
243        let state = self.state.lock();
244        if state.generation != generation {
245            return Err(inconsistent(format!(
246                "candidate-state generation {} does not match graph generation {generation}",
247                state.generation
248            )));
249        }
250        Ok(self
251            .specs
252            .iter()
253            .map(|spec| VectorCandidateStateInfo {
254                name: spec.name.clone(),
255                generation,
256                candidate_count: state.members.get(&spec.name).map_or(0, BTreeSet::len),
257                required_label: spec.required_label.clone(),
258                require_outgoing: spec.require_outgoing.clone(),
259                require_incoming: spec.require_incoming.clone(),
260                exclude_outgoing: spec.exclude_outgoing.clone(),
261                exclude_incoming: spec.exclude_incoming.clone(),
262            })
263            .collect())
264    }
265
266    /// Return true when `node` is currently a member of the named set.
267    #[must_use]
268    pub fn contains(&self, name: &DbString, node: NodeId) -> bool {
269        self.state
270            .lock()
271            .members
272            .get(name)
273            .is_some_and(|members| members.contains(&node))
274    }
275}
276
277impl IndexProvider for MaintainedCandidateStateProvider {
278    fn provider_tag(&self) -> ProviderTag {
279        ProviderTag(CANDIDATE_STATE_PROVIDER_TAG)
280    }
281
282    fn read_section(&self, sub_tag: SubTag, bytes: &[u8]) -> Result<(), ProviderError> {
283        ensure_state_subtag(sub_tag)?;
284        let snapshot: CandidateStateSnapshot = postcard::from_bytes(bytes).map_err(|error| {
285            invalid_payload(format!("CSET/STAT postcard decode failed: {error}"))
286        })?;
287        if snapshot.version != SNAPSHOT_VERSION {
288            return Err(invalid_payload(format!(
289                "unsupported CSET/STAT version {}",
290                snapshot.version
291            )));
292        }
293        if snapshot.specs != self.specs {
294            return Err(invalid_payload(
295                "CSET/STAT specs differ from provider configuration".to_owned(),
296            ));
297        }
298        let mut state = CandidateState::new(&self.specs);
299        state.generation = snapshot.generation;
300        for (id, labels) in snapshot.node_labels {
301            if state.node_labels.insert(id, labels).is_some() {
302                return Err(invalid_payload(format!(
303                    "duplicate node id {id} in CSET/STAT"
304                )));
305            }
306        }
307        for (id, edge) in snapshot.edges {
308            if !watches_label(&self.specs, &edge.label) {
309                return Err(invalid_payload(format!(
310                    "unwatched edge label {} in CSET/STAT",
311                    edge.label.as_str()
312                )));
313            }
314            if !state.node_labels.contains_key(&edge.source)
315                || !state.node_labels.contains_key(&edge.target)
316            {
317                return Err(invalid_payload(format!(
318                    "tracked edge {id} references missing endpoint in CSET/STAT"
319                )));
320            }
321            if state.edges.insert(id, edge).is_some() {
322                return Err(invalid_payload(format!(
323                    "duplicate edge id {id} in CSET/STAT"
324                )));
325            }
326        }
327        state.rebuild_derived(&self.specs);
328        *self.state.lock() = state;
329        Ok(())
330    }
331
332    fn write_section(&self, sub_tag: SubTag) -> Result<Vec<u8>, ProviderError> {
333        ensure_state_subtag(sub_tag)?;
334        let state = self.state.lock();
335        let snapshot = CandidateStateSnapshot {
336            version: SNAPSHOT_VERSION,
337            generation: state.generation,
338            specs: self.specs.clone(),
339            node_labels: state
340                .node_labels
341                .iter()
342                .map(|(id, labels)| (*id, labels.clone()))
343                .collect(),
344            edges: state
345                .edges
346                .iter()
347                .map(|(id, edge)| (*id, edge.clone()))
348                .collect(),
349        };
350        postcard::to_stdvec(&snapshot).map_err(|error| ProviderError::SerializationFailed {
351            reason: format!("CSET/STAT postcard encode failed: {error}"),
352        })
353    }
354
355    fn on_change(&self, change: &Change) -> Result<(), ProviderError> {
356        self.state.lock().apply_change(&self.specs, change)
357    }
358
359    fn handles_change_batches(&self) -> bool {
360        true
361    }
362
363    fn on_changes(&self, changes: &[Change]) -> Result<(), ProviderError> {
364        let mut state = self.state.lock();
365        for change in changes {
366            state.apply_change(&self.specs, change)?;
367        }
368        Ok(())
369    }
370
371    fn rebuild_from_graph(&self, graph: &SeleneGraph) -> Result<(), ProviderError> {
372        MaintainedCandidateStateProvider::rebuild_from_graph(self, graph)
373    }
374
375    fn on_commit_applied(&self, generation: u64) -> Result<(), ProviderError> {
376        self.state.lock().generation = generation;
377        Ok(())
378    }
379
380    fn vector_candidate_set(
381        &self,
382        name: &DbString,
383        generation: u64,
384    ) -> Result<Option<VectorCandidateSet>, ProviderError> {
385        self.candidate_set_at_generation(name, generation)
386    }
387
388    fn vector_candidate_state_infos(
389        &self,
390        generation: u64,
391    ) -> Result<Vec<VectorCandidateStateInfo>, ProviderError> {
392        self.candidate_state_infos_at_generation(generation)
393    }
394
395    fn declared_sub_tags(&self) -> &[SubTag] {
396        SUB_TAGS
397    }
398}
399
/// One edge kept in the maintained state: its label and both endpoints.
///
/// Only edges whose label appears in at least one spec are stored.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
struct TrackedEdge {
    /// Edge label; always one that some spec requires or excludes.
    label: DbString,
    /// Node the edge leaves from (counted as outgoing for this node).
    source: NodeId,
    /// Node the edge arrives at (counted as incoming for this node).
    target: NodeId,
}
406
/// Serialized form of the provider state, as written by `write_section` and
/// consumed by `read_section`.
///
/// Edge counts and set membership are not stored; they are recomputed from
/// `node_labels` and `edges` when a snapshot is read.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct CandidateStateSnapshot {
    /// Snapshot format version; must equal `SNAPSHOT_VERSION` on read.
    version: u8,
    /// Generation watermark of the state at the time of writing.
    generation: u64,
    /// Specs the state was maintained for; must equal the reading provider's.
    specs: Vec<CandidateStateSpec>,
    /// Label set of every tracked node.
    node_labels: Vec<(NodeId, LabelSet)>,
    /// Every tracked edge keyed by its id.
    edges: Vec<(EdgeId, TrackedEdge)>,
}
415
/// Mutable state behind the provider's mutex: tracked graph facts plus the
/// counts and memberships derived from them.
#[derive(Clone, Debug)]
struct CandidateState {
    /// Generation watermark; set on rebuild, snapshot read and commit.
    generation: u64,
    /// Current label set of every known node.
    node_labels: BTreeMap<NodeId, LabelSet>,
    /// Tracked edges (watched labels only) keyed by edge id.
    edges: BTreeMap<EdgeId, TrackedEdge>,
    /// Number of tracked edges leaving a node, per (node, edge label).
    outgoing_counts: BTreeMap<(NodeId, DbString), usize>,
    /// Number of tracked edges arriving at a node, per (node, edge label).
    incoming_counts: BTreeMap<(NodeId, DbString), usize>,
    /// Current members of each candidate set, keyed by spec name.
    members: BTreeMap<DbString, BTreeSet<NodeId>>,
}
425
426impl CandidateState {
427    fn new(specs: &[CandidateStateSpec]) -> Self {
428        Self {
429            generation: 0,
430            node_labels: BTreeMap::new(),
431            edges: BTreeMap::new(),
432            outgoing_counts: BTreeMap::new(),
433            incoming_counts: BTreeMap::new(),
434            members: empty_members(specs),
435        }
436    }
437
438    fn apply_change(
439        &mut self,
440        specs: &[CandidateStateSpec],
441        change: &Change,
442    ) -> Result<(), ProviderError> {
443        match change {
444            Change::NodeCreated { id, labels, .. } => {
445                if self.node_labels.insert(*id, labels.clone()).is_some() {
446                    return Err(inconsistent(format!("duplicate node create for {id}")));
447                }
448                self.recompute_node(specs, *id);
449            }
450            Change::NodeUpdated {
451                id, labels_diff, ..
452            } => {
453                let labels = self
454                    .node_labels
455                    .get_mut(id)
456                    .ok_or_else(|| inconsistent(format!("label update for unknown node {id}")))?;
457                for label in &labels_diff.removed {
458                    labels.remove(label);
459                }
460                for label in &labels_diff.added {
461                    labels.insert(label.clone());
462                }
463                self.recompute_node(specs, *id);
464            }
465            Change::NodeDeleted { id } => {
466                self.node_labels.remove(id);
467                self.remove_incident_edges(specs, *id);
468                self.recompute_node(specs, *id);
469            }
470            Change::NodeLabelRemoved { id, label } => {
471                let labels = self
472                    .node_labels
473                    .get_mut(id)
474                    .ok_or_else(|| inconsistent(format!("label removal for unknown node {id}")))?;
475                labels.remove(label);
476                self.recompute_node(specs, *id);
477            }
478            Change::EdgeCreated {
479                id,
480                label,
481                source,
482                target,
483                ..
484            } => {
485                if watches_label(specs, label) {
486                    let edge = TrackedEdge {
487                        label: label.clone(),
488                        source: *source,
489                        target: *target,
490                    };
491                    if self.edges.insert(*id, edge.clone()).is_some() {
492                        return Err(inconsistent(format!("duplicate edge create for {id}")));
493                    }
494                    self.increment_edge(&edge);
495                    self.recompute_node(specs, *source);
496                    self.recompute_node(specs, *target);
497                }
498            }
499            Change::EdgeDeleted { id } => {
500                if let Some(edge) = self.edges.remove(id) {
501                    self.decrement_edge(&edge);
502                    self.recompute_node(specs, edge.source);
503                    self.recompute_node(specs, edge.target);
504                }
505            }
506            Change::GraphReset {} => {
507                *self = Self::new(specs);
508            }
509            Change::NodesOfTypeTruncated { label } => {
510                let removed = self
511                    .node_labels
512                    .iter()
513                    .filter_map(|(id, labels)| labels.contains(label).then_some(*id))
514                    .collect::<BTreeSet<_>>();
515                if !removed.is_empty() {
516                    self.node_labels.retain(|id, _| !removed.contains(id));
517                    self.edges.retain(|_, edge| {
518                        !removed.contains(&edge.source) && !removed.contains(&edge.target)
519                    });
520                    self.rebuild_derived(specs);
521                }
522            }
523            Change::EdgesOfTypeTruncated { label } => {
524                if watches_label(specs, label) {
525                    self.edges.retain(|_, edge| edge.label != *label);
526                    self.rebuild_derived(specs);
527                }
528            }
529            Change::EdgeUpdated { .. }
530            | Change::EdgePropertyRemoved { .. }
531            | Change::NodePropertyRemoved { .. }
532            | Change::SchemaChanged { .. } => {}
533        }
534        Ok(())
535    }
536
537    fn rebuild_derived(&mut self, specs: &[CandidateStateSpec]) {
538        self.outgoing_counts.clear();
539        self.incoming_counts.clear();
540        self.members = empty_members(specs);
541        for edge in self.edges.values().cloned().collect::<Vec<_>>() {
542            self.increment_edge(&edge);
543        }
544        for id in self.node_labels.keys().copied().collect::<Vec<_>>() {
545            self.recompute_node(specs, id);
546        }
547    }
548
549    fn increment_edge(&mut self, edge: &TrackedEdge) {
550        *self
551            .outgoing_counts
552            .entry((edge.source, edge.label.clone()))
553            .or_insert(0) += 1;
554        *self
555            .incoming_counts
556            .entry((edge.target, edge.label.clone()))
557            .or_insert(0) += 1;
558    }
559
560    fn decrement_edge(&mut self, edge: &TrackedEdge) {
561        decrement_count(&mut self.outgoing_counts, (edge.source, edge.label.clone()));
562        decrement_count(&mut self.incoming_counts, (edge.target, edge.label.clone()));
563    }
564
565    fn remove_incident_edges(&mut self, specs: &[CandidateStateSpec], node: NodeId) {
566        let incident = self
567            .edges
568            .iter()
569            .filter_map(|(id, edge)| {
570                (edge.source == node || edge.target == node).then_some((*id, edge.clone()))
571            })
572            .collect::<Vec<_>>();
573        for (id, edge) in incident {
574            self.edges.remove(&id);
575            self.decrement_edge(&edge);
576            if edge.source != node {
577                self.recompute_node(specs, edge.source);
578            }
579            if edge.target != node {
580                self.recompute_node(specs, edge.target);
581            }
582        }
583    }
584
585    fn recompute_node(&mut self, specs: &[CandidateStateSpec], node: NodeId) {
586        let labels = self.node_labels.get(&node).cloned();
587        for spec in specs {
588            let include = labels.as_ref().is_some_and(|labels| {
589                spec.required_label
590                    .as_ref()
591                    .is_none_or(|required| labels.contains(required))
592                    && spec
593                        .require_outgoing
594                        .iter()
595                        .all(|label| has_count(&self.outgoing_counts, node, label))
596                    && spec
597                        .require_incoming
598                        .iter()
599                        .all(|label| has_count(&self.incoming_counts, node, label))
600                    && spec
601                        .exclude_outgoing
602                        .iter()
603                        .all(|label| !has_count(&self.outgoing_counts, node, label))
604                    && spec
605                        .exclude_incoming
606                        .iter()
607                        .all(|label| !has_count(&self.incoming_counts, node, label))
608            });
609            let members = self.members.entry(spec.name.clone()).or_default();
610            if include {
611                members.insert(node);
612            } else {
613                members.remove(&node);
614            }
615        }
616    }
617}
618
619fn validate_unique_specs(specs: &[CandidateStateSpec]) -> Result<(), ProviderError> {
620    let mut seen = BTreeSet::new();
621    for spec in specs {
622        if !seen.insert(spec.name.clone()) {
623            return Err(inconsistent(format!(
624                "duplicate candidate-state spec name {}",
625                spec.name.as_str()
626            )));
627        }
628    }
629    Ok(())
630}
631
632fn empty_members(specs: &[CandidateStateSpec]) -> BTreeMap<DbString, BTreeSet<NodeId>> {
633    specs
634        .iter()
635        .map(|spec| (spec.name.clone(), BTreeSet::new()))
636        .collect()
637}
638
639fn watches_label(specs: &[CandidateStateSpec], label: &DbString) -> bool {
640    specs.iter().any(|spec| {
641        spec.require_outgoing.binary_search(label).is_ok()
642            || spec.require_incoming.binary_search(label).is_ok()
643            || spec.exclude_outgoing.binary_search(label).is_ok()
644            || spec.exclude_incoming.binary_search(label).is_ok()
645    })
646}
647
648fn has_count(counts: &BTreeMap<(NodeId, DbString), usize>, node: NodeId, label: &DbString) -> bool {
649    counts
650        .get(&(node, label.clone()))
651        .is_some_and(|count| *count > 0)
652}
653
654fn decrement_count(counts: &mut BTreeMap<(NodeId, DbString), usize>, key: (NodeId, DbString)) {
655    if let Some(count) = counts.get_mut(&key) {
656        *count = count.saturating_sub(1);
657        if *count == 0 {
658            counts.remove(&key);
659        }
660    }
661}
662
663fn insert_sorted_unique(labels: &mut Vec<DbString>, label: DbString) {
664    match labels.binary_search(&label) {
665        Ok(_) => {}
666        Err(index) => labels.insert(index, label),
667    }
668}
669
670fn canonicalize_labels(labels: &mut Vec<DbString>) {
671    labels.sort_unstable();
672    labels.dedup();
673}
674
675fn ensure_state_subtag(sub_tag: SubTag) -> Result<(), ProviderError> {
676    if sub_tag == SubTag(CANDIDATE_STATE_SUB) {
677        Ok(())
678    } else {
679        Err(invalid_payload(format!("unknown CSET sub-tag {sub_tag}")))
680    }
681}
682
/// Wrap `reason` in [`ProviderError::InvalidPayload`]; used for rejected
/// snapshot payloads and unknown sub-tags.
fn invalid_payload(reason: String) -> ProviderError {
    ProviderError::InvalidPayload { reason }
}
686
/// Wrap `reason` in [`ProviderError::Inconsistent`]; used when specs, graph
/// contents, change events or generations contradict the maintained state.
fn inconsistent(reason: String) -> ProviderError {
    ProviderError::Inconsistent { reason }
}
690
691#[cfg(test)]
692#[path = "candidate_state/tests.rs"]
693mod tests;