use std::collections::BTreeSet;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use selene_core::{Change, DbString, NodeId};
#[cfg(test)]
use selene_core::{EdgeId, LabelSet};
use crate::index_provider::{
IndexProvider, ProviderError, ProviderTag, SubTag, VectorCandidateStateInfo,
};
use crate::store::RowIndex;
use crate::{SeleneGraph, VectorCandidateSet};
#[path = "candidate_state/state.rs"]
mod state;
use state::{
CandidateState, CandidateStateSnapshot, TrackedEdge, canonicalize_labels, ensure_state_subtag,
inconsistent, insert_sorted_unique, invalid_payload, validate_unique_specs, watches_label,
};
/// Provider tag (ASCII `CSET`) reported by this provider's `provider_tag`.
pub const CANDIDATE_STATE_PROVIDER_TAG: [u8; 4] = [b'C', b'S', b'E', b'T'];
/// Sub-tag (ASCII `STAT`) of the single section this provider declares.
pub const CANDIDATE_STATE_SUB: [u8; 4] = [b'S', b'T', b'A', b'T'];
/// Version byte stamped by `write_section` and required by `read_section`.
const SNAPSHOT_VERSION: u8 = 0x01;
/// Sub-tags returned from `declared_sub_tags`: only the `STAT` section.
const SUB_TAGS: &[SubTag] = &[SubTag(CANDIDATE_STATE_SUB)];
/// Declarative definition of one named candidate set.
///
/// The full spec list is embedded in every `CSET/STAT` snapshot (postcard
/// encoding, which is positional) and compared against the provider's own
/// specs on load. Field order and types are therefore part of the persisted
/// format; do not reorder or retype fields without bumping the snapshot version.
///
/// NOTE(review): how these constraints decide membership is implemented in
/// `state::CandidateState` (not visible here); confirm the semantics there.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct CandidateStateSpec {
    /// Key used to look the spec and its member set up (`spec`, `candidate_set`,
    /// `contains`). Presumably must be unique across specs; see
    /// `validate_unique_specs`.
    pub name: DbString,
    /// Optional node-label constraint, set via `require_label`.
    pub required_label: Option<DbString>,
    /// Outgoing edge labels, set via `require_outgoing`.
    pub require_outgoing: Vec<DbString>,
    /// Incoming edge labels, set via `require_incoming`.
    pub require_incoming: Vec<DbString>,
    /// Outgoing edge labels, set via `exclude_outgoing`.
    pub exclude_outgoing: Vec<DbString>,
    /// Incoming edge labels, set via `exclude_incoming`.
    pub exclude_incoming: Vec<DbString>,
}
impl CandidateStateSpec {
    /// Starts a spec named `name` with no label requirement and empty
    /// edge-label lists.
    #[must_use]
    pub fn new(name: DbString) -> Self {
        Self {
            name,
            required_label: None,
            require_outgoing: vec![],
            require_incoming: vec![],
            exclude_outgoing: vec![],
            exclude_incoming: vec![],
        }
    }

    /// Sets (or replaces) the required node label.
    #[must_use]
    pub fn require_label(self, label: DbString) -> Self {
        Self {
            required_label: Some(label),
            ..self
        }
    }

    /// Adds `label` to `require_outgoing` via `insert_sorted_unique`.
    #[must_use]
    pub fn require_outgoing(self, label: DbString) -> Self {
        self.with_edge_label(|spec| &mut spec.require_outgoing, label)
    }

    /// Adds `label` to `require_incoming` via `insert_sorted_unique`.
    #[must_use]
    pub fn require_incoming(self, label: DbString) -> Self {
        self.with_edge_label(|spec| &mut spec.require_incoming, label)
    }

    /// Adds `label` to `exclude_outgoing` via `insert_sorted_unique`.
    #[must_use]
    pub fn exclude_outgoing(self, label: DbString) -> Self {
        self.with_edge_label(|spec| &mut spec.exclude_outgoing, label)
    }

    /// Adds `label` to `exclude_incoming` via `insert_sorted_unique`.
    #[must_use]
    pub fn exclude_incoming(self, label: DbString) -> Self {
        self.with_edge_label(|spec| &mut spec.exclude_incoming, label)
    }

    /// Shared builder step: inserts `label` into the list chosen by `field`
    /// and returns the spec for further chaining.
    fn with_edge_label(
        mut self,
        field: fn(&mut Self) -> &mut Vec<DbString>,
        label: DbString,
    ) -> Self {
        insert_sorted_unique(field(&mut self), label);
        self
    }
}
/// Index provider that keeps named candidate node sets up to date as graph
/// changes are applied, and persists them as a `CSET/STAT` section.
pub struct MaintainedCandidateStateProvider {
    /// Mutable maintained state (tracked nodes/edges, derived members, and the
    /// generation it reflects). Behind a mutex because provider hooks take `&self`.
    state: Mutex<CandidateState>,
    /// Canonicalized, validated specs; fixed after construction.
    specs: Vec<CandidateStateSpec>,
}
impl MaintainedCandidateStateProvider {
pub fn new(specs: impl IntoIterator<Item = CandidateStateSpec>) -> Result<Self, ProviderError> {
let mut specs = specs.into_iter().collect::<Vec<_>>();
for spec in &mut specs {
canonicalize_labels(&mut spec.require_outgoing);
canonicalize_labels(&mut spec.require_incoming);
canonicalize_labels(&mut spec.exclude_outgoing);
canonicalize_labels(&mut spec.exclude_incoming);
}
validate_unique_specs(&specs)?;
Ok(Self {
state: Mutex::new(CandidateState::new(&specs)),
specs,
})
}
pub fn from_graph(
specs: impl IntoIterator<Item = CandidateStateSpec>,
graph: &SeleneGraph,
) -> Result<Self, ProviderError> {
let provider = Self::new(specs)?;
provider.rebuild_from_graph(graph)?;
Ok(provider)
}
pub fn rebuild_from_graph(&self, graph: &SeleneGraph) -> Result<(), ProviderError> {
let mut rebuilt = CandidateState::new(&self.specs);
for row in graph.live_nodes() {
let row = RowIndex::new(row);
let id = graph.node_id_for_row(row).ok_or_else(|| {
inconsistent(format!("live node row {} has no external id", row.get()))
})?;
let labels = graph
.node_labels(id)
.ok_or_else(|| inconsistent(format!("live node {id} has no label column entry")))?;
rebuilt.node_labels.insert(id, labels.clone());
}
for row in graph.live_edges() {
let row = RowIndex::new(row);
let id = graph.edge_id_for_row(row).ok_or_else(|| {
inconsistent(format!("live edge row {} has no external id", row.get()))
})?;
let label = graph
.edge_label(id)
.ok_or_else(|| inconsistent(format!("live edge {id} has no label")))?;
if !watches_label(&self.specs, label) {
continue;
}
let (source, target) = graph
.edge_endpoints(id)
.ok_or_else(|| inconsistent(format!("live edge {id} has no endpoints")))?;
rebuilt.edges.insert(
id,
TrackedEdge {
label: label.clone(),
source,
target,
},
);
}
rebuilt.rebuild_derived(&self.specs);
rebuilt.generation = graph.meta.generation;
*self.state.lock() = rebuilt;
Ok(())
}
#[must_use]
pub fn spec(&self, name: &DbString) -> Option<&CandidateStateSpec> {
self.specs.iter().find(|spec| &spec.name == name)
}
#[must_use]
pub fn candidate_set(&self, name: &DbString) -> Option<VectorCandidateSet> {
let state = self.state.lock();
state.members.get(name).map(|members| {
VectorCandidateSet::from_canonical_nodes(members.iter().copied().collect())
})
}
#[must_use]
pub fn generation(&self) -> u64 {
self.state.lock().generation
}
pub fn candidate_set_at_generation(
&self,
name: &DbString,
generation: u64,
) -> Result<Option<VectorCandidateSet>, ProviderError> {
let state = self.state.lock();
if state.generation != generation {
return Err(inconsistent(format!(
"candidate-state generation {} does not match graph generation {generation}",
state.generation
)));
}
Ok(state.members.get(name).map(|members| {
VectorCandidateSet::from_canonical_nodes(members.iter().copied().collect())
}))
}
pub fn candidate_state_infos_at_generation(
&self,
generation: u64,
) -> Result<Vec<VectorCandidateStateInfo>, ProviderError> {
let state = self.state.lock();
if state.generation != generation {
return Err(inconsistent(format!(
"candidate-state generation {} does not match graph generation {generation}",
state.generation
)));
}
Ok(self
.specs
.iter()
.map(|spec| VectorCandidateStateInfo {
name: spec.name.clone(),
generation,
candidate_count: state.members.get(&spec.name).map_or(0, BTreeSet::len),
required_label: spec.required_label.clone(),
require_outgoing: spec.require_outgoing.clone(),
require_incoming: spec.require_incoming.clone(),
exclude_outgoing: spec.exclude_outgoing.clone(),
exclude_incoming: spec.exclude_incoming.clone(),
})
.collect())
}
#[must_use]
pub fn contains(&self, name: &DbString, node: NodeId) -> bool {
self.state
.lock()
.members
.get(name)
.is_some_and(|members| members.contains(&node))
}
}
impl IndexProvider for MaintainedCandidateStateProvider {
fn provider_tag(&self) -> ProviderTag {
ProviderTag(CANDIDATE_STATE_PROVIDER_TAG)
}
fn read_section(&self, sub_tag: SubTag, bytes: &[u8]) -> Result<(), ProviderError> {
ensure_state_subtag(sub_tag)?;
let snapshot: CandidateStateSnapshot = postcard::from_bytes(bytes).map_err(|error| {
invalid_payload(format!("CSET/STAT postcard decode failed: {error}"))
})?;
if snapshot.version != SNAPSHOT_VERSION {
return Err(invalid_payload(format!(
"unsupported CSET/STAT version {}",
snapshot.version
)));
}
if snapshot.specs != self.specs {
return Err(invalid_payload(
"CSET/STAT specs differ from provider configuration".to_owned(),
));
}
let mut state = CandidateState::new(&self.specs);
state.generation = snapshot.generation;
for (id, labels) in snapshot.node_labels {
if state.node_labels.insert(id, labels).is_some() {
return Err(invalid_payload(format!(
"duplicate node id {id} in CSET/STAT"
)));
}
}
for (id, edge) in snapshot.edges {
if !watches_label(&self.specs, &edge.label) {
return Err(invalid_payload(format!(
"unwatched edge label {} in CSET/STAT",
edge.label.as_str()
)));
}
if !state.node_labels.contains_key(&edge.source)
|| !state.node_labels.contains_key(&edge.target)
{
return Err(invalid_payload(format!(
"tracked edge {id} references missing endpoint in CSET/STAT"
)));
}
if state.edges.insert(id, edge).is_some() {
return Err(invalid_payload(format!(
"duplicate edge id {id} in CSET/STAT"
)));
}
}
state.rebuild_derived(&self.specs);
*self.state.lock() = state;
Ok(())
}
fn write_section(&self, sub_tag: SubTag) -> Result<Vec<u8>, ProviderError> {
ensure_state_subtag(sub_tag)?;
let state = self.state.lock();
let snapshot = CandidateStateSnapshot {
version: SNAPSHOT_VERSION,
generation: state.generation,
specs: self.specs.clone(),
node_labels: state
.node_labels
.iter()
.map(|(id, labels)| (*id, labels.clone()))
.collect(),
edges: state
.edges
.iter()
.map(|(id, edge)| (*id, edge.clone()))
.collect(),
};
postcard::to_stdvec(&snapshot).map_err(|error| ProviderError::SerializationFailed {
reason: format!("CSET/STAT postcard encode failed: {error}"),
})
}
fn on_change(&self, change: &Change) -> Result<(), ProviderError> {
self.state.lock().apply_change(&self.specs, change)
}
fn handles_change_batches(&self) -> bool {
true
}
fn on_changes(&self, changes: &[Change]) -> Result<(), ProviderError> {
let mut state = self.state.lock();
for change in changes {
state.apply_change(&self.specs, change)?;
}
Ok(())
}
fn rebuild_from_graph(&self, graph: &SeleneGraph) -> Result<(), ProviderError> {
MaintainedCandidateStateProvider::rebuild_from_graph(self, graph)
}
fn on_commit_applied(&self, generation: u64) -> Result<(), ProviderError> {
self.state.lock().generation = generation;
Ok(())
}
fn vector_candidate_set(
&self,
name: &DbString,
generation: u64,
) -> Result<Option<VectorCandidateSet>, ProviderError> {
self.candidate_set_at_generation(name, generation)
}
fn vector_candidate_state_infos(
&self,
generation: u64,
) -> Result<Vec<VectorCandidateStateInfo>, ProviderError> {
self.candidate_state_infos_at_generation(generation)
}
fn declared_sub_tags(&self) -> &[SubTag] {
SUB_TAGS
}
}
#[cfg(test)]
#[path = "candidate_state/tests.rs"]
mod tests;