use std::fmt;
use selene_core::{Change, DbString};
use crate::{SeleneGraph, VectorCandidateSet};
/// Four-byte identifier returned by [`IndexProvider::provider_tag`].
///
/// Equality, ordering and hashing are derived from the raw byte array, so tags
/// compare lexicographically byte by byte (e.g. `DEMO < META`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ProviderTag(pub [u8; 4]);
impl fmt::Display for ProviderTag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt_tag(self.0, f)
}
}
/// Four-byte identifier selecting one section passed to
/// [`IndexProvider::read_section`] / [`IndexProvider::write_section`].
///
/// Equality, ordering and hashing are derived from the raw byte array, so tags
/// compare lexicographically byte by byte (e.g. `GRPH < SUBT`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SubTag(pub [u8; 4]);
impl fmt::Display for SubTag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt_tag(self.0, f)
}
}
/// A component that is notified of graph changes and can persist its state as
/// byte sections keyed by [`SubTag`].
///
/// Every method takes `&self` and the trait requires `Send + Sync + 'static`, so
/// implementations that keep mutable state need interior mutability.
pub trait IndexProvider: Send + Sync + 'static {
    /// Returns the tag identifying this provider.
    fn provider_tag(&self) -> ProviderTag;

    /// Loads the bytes of the section identified by `sub_tag` into this provider.
    fn read_section(&self, sub_tag: SubTag, bytes: &[u8]) -> Result<(), ProviderError>;

    /// Serializes the section identified by `sub_tag` into bytes.
    fn write_section(&self, sub_tag: SubTag) -> Result<Vec<u8>, ProviderError>;

    /// Applies a single graph change.
    fn on_change(&self, change: &Change) -> Result<(), ProviderError>;

    /// Reports whether this provider handles change batches itself.
    ///
    /// Defaults to `false`. NOTE(review): presumably providers that override
    /// [`IndexProvider::on_changes`] with a batch-aware implementation return
    /// `true` here — confirm against the call sites.
    fn handles_change_batches(&self) -> bool {
        false
    }

    /// Applies a batch of changes.
    ///
    /// The default forwards each change, in slice order, to
    /// [`IndexProvider::on_change`] and stops at (and returns) the first error.
    fn on_changes(&self, changes: &[Change]) -> Result<(), ProviderError> {
        changes.iter().try_for_each(|change| self.on_change(change))
    }

    /// Rebuilds this provider's state from `graph`.
    ///
    /// # Errors
    ///
    /// The default implementation never rebuilds; it always returns
    /// [`ProviderError::Inconsistent`] with a reason naming the provider tag.
    fn rebuild_from_graph(&self, _graph: &SeleneGraph) -> Result<(), ProviderError> {
        let reason = format!(
            "provider {} has persisted sections but does not support graph rebuild",
            self.provider_tag()
        );
        Err(ProviderError::Inconsistent { reason })
    }

    /// Hook receiving a commit generation (presumably called once a commit has
    /// been applied — confirm with callers). The default does nothing.
    fn on_commit_applied(&self, _generation: u64) -> Result<(), ProviderError> {
        Ok(())
    }

    /// Looks up the vector candidate set called `_name` at `_generation`.
    ///
    /// The default reports that no such set exists (`Ok(None)`).
    fn vector_candidate_set(
        &self,
        _name: &DbString,
        _generation: u64,
    ) -> Result<Option<VectorCandidateSet>, ProviderError> {
        Ok(None)
    }

    /// Summarizes the vector candidate sets at `_generation`.
    ///
    /// The default returns an empty list.
    fn vector_candidate_state_infos(
        &self,
        _generation: u64,
    ) -> Result<Vec<VectorCandidateStateInfo>, ProviderError> {
        Ok(vec![])
    }

    /// Returns the sub-tags this provider declares.
    fn declared_sub_tags(&self) -> &[SubTag];
}
/// Per-set summary returned by [`IndexProvider::vector_candidate_state_infos`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VectorCandidateStateInfo {
    /// Name of the candidate set.
    pub name: DbString,
    /// Generation this summary describes (presumably the generation that was
    /// queried — confirm).
    pub generation: u64,
    /// Number of candidates in the set.
    pub candidate_count: usize,
    /// Label filter, if any. NOTE(review): presumably a node-label constraint
    /// on candidates — confirm.
    pub required_label: Option<DbString>,
    /// Names that must appear on outgoing edges (assumed edge labels — confirm).
    pub require_outgoing: Vec<DbString>,
    /// Names that must appear on incoming edges (assumed edge labels — confirm).
    pub require_incoming: Vec<DbString>,
    /// Names that must not appear on outgoing edges (assumed edge labels — confirm).
    pub exclude_outgoing: Vec<DbString>,
    /// Names that must not appear on incoming edges (assumed edge labels — confirm).
    pub exclude_incoming: Vec<DbString>,
}
/// Errors reported by an [`IndexProvider`].
///
/// Each variant carries a human-readable `reason` that is embedded in its
/// `Display` message.
#[derive(Debug, thiserror::Error, miette::Diagnostic)]
#[non_exhaustive]
pub enum ProviderError {
    /// A payload given to the provider was rejected as invalid.
    #[error("invalid provider payload: {reason}")]
    #[diagnostic(code(SLENE_G_010))]
    InvalidPayload { reason: String },

    /// The provider failed to serialize its state.
    #[error("provider serialization failed: {reason}")]
    #[diagnostic(code(SLENE_G_012))]
    SerializationFailed { reason: String },

    /// The provider's state is inconsistent. Also returned by the default
    /// [`IndexProvider::rebuild_from_graph`].
    #[error("provider state inconsistency: {reason}")]
    #[diagnostic(code(SLENE_G_014))]
    Inconsistent { reason: String },
}
/// Formats a four-byte tag for `Display`.
///
/// A tag whose bytes are all printable, non-space ASCII (`u8::is_ascii_graphic`)
/// is written as those characters (e.g. `DEMO`). Any other tag is written as an
/// uppercase hexadecimal literal (e.g. `[0x00, 0x41, 0xFF, 0x20]` → `0x0041FF20`).
///
/// Both forms go through [`fmt::Formatter::pad`], so width, fill, alignment and
/// precision flags (`{:>8}`, `{:*^10}`, `{:.2}`) behave as they do for `str`.
fn fmt_tag(bytes: [u8; 4], f: &mut fmt::Formatter<'_>) -> fmt::Result {
    if bytes.iter().all(u8::is_ascii_graphic) {
        // All-graphic ASCII is always valid UTF-8, so this branch is always taken
        // for printable tags; the hex fallback below is only a safe default.
        if let Ok(text) = std::str::from_utf8(&bytes) {
            return f.pad(text);
        }
    }
    // Non-printable tags are rare, so a small allocation here keeps the code
    // simple while still letting `pad` apply the caller's formatting flags.
    let [b0, b1, b2, b3] = bytes;
    f.pad(&format!("0x{:02X}{:02X}{:02X}{:02X}", b0, b1, b2, b3))
}
#[cfg(test)]
mod tests {
    use parking_lot::Mutex;
    use rstest::rstest;
    use selene_core::{LabelSet, NodeId, PropertyMap};

    use super::*;
    use crate::{GraphError, GraphResult};

    /// Minimal provider that records every change it receives; relies on the
    /// trait's default methods for everything else.
    struct RecordingProvider {
        tag: ProviderTag,
        changes: Mutex<Vec<Change>>,
    }

    impl RecordingProvider {
        fn new(tag: ProviderTag) -> Self {
            Self {
                tag,
                changes: Mutex::new(Vec::new()),
            }
        }
    }

    impl IndexProvider for RecordingProvider {
        fn provider_tag(&self) -> ProviderTag {
            self.tag
        }

        fn read_section(&self, _sub_tag: SubTag, _bytes: &[u8]) -> Result<(), ProviderError> {
            Ok(())
        }

        fn write_section(&self, _sub_tag: SubTag) -> Result<Vec<u8>, ProviderError> {
            Ok(Vec::new())
        }

        fn on_change(&self, change: &Change) -> Result<(), ProviderError> {
            self.changes.lock().push(change.clone());
            Ok(())
        }

        fn declared_sub_tags(&self) -> &[SubTag] {
            &[]
        }
    }

    fn assert_send_sync_static<T: Send + Sync + 'static>() {}

    /// Builds a `NodeCreated` change with empty labels and properties.
    fn node_created() -> Change {
        Change::NodeCreated {
            id: NodeId::new(1),
            labels: LabelSet::new(),
            properties: PropertyMap::new(),
        }
    }

    #[test]
    fn provider_tag_equality_and_ordering() {
        let demo = ProviderTag(*b"DEMO");
        let meta = ProviderTag(*b"META");
        assert_eq!(demo, ProviderTag(*b"DEMO"));
        assert!(demo < meta);
        assert_eq!(demo.to_string(), "DEMO");
    }

    #[test]
    fn sub_tag_equality_and_ordering() {
        let graph = SubTag(*b"GRPH");
        let subt = SubTag(*b"SUBT");
        assert_eq!(graph, SubTag(*b"GRPH"));
        assert!(graph < subt);
        assert_eq!(graph.to_string(), "GRPH");
    }

    /// Printable tags render as text; anything containing a non-graphic byte
    /// (including ASCII space) renders as an uppercase hex literal.
    #[rstest]
    #[case(*b"DEMO", "DEMO")]
    #[case(*b"ab1!", "ab1!")]
    #[case([0x00, 0x41, 0xFF, 0x20], "0x0041FF20")]
    #[case(*b"AB C", "0x41422043")]
    fn tag_display_uses_ascii_or_hex(#[case] bytes: [u8; 4], #[case] expected: &str) {
        assert_eq!(ProviderTag(bytes).to_string(), expected);
        assert_eq!(SubTag(bytes).to_string(), expected);
    }

    #[rstest]
    #[case(ProviderError::InvalidPayload { reason: "bad".to_owned() })]
    #[case(ProviderError::SerializationFailed { reason: "io".to_owned() })]
    #[case(ProviderError::Inconsistent { reason: "duplicate".to_owned() })]
    fn provider_error_gqlstatus_mappings(#[case] provider_error: ProviderError) {
        let graph_error = GraphError::Provider(provider_error);
        assert_eq!(graph_error.gqlstatus(), "5GQL0");
    }

    #[test]
    fn dummy_provider_with_interior_mutability() -> GraphResult<()> {
        assert_send_sync_static::<RecordingProvider>();
        let provider = RecordingProvider::new(ProviderTag(*b"TEST"));
        provider.on_change(&node_created())?;
        assert_eq!(provider.changes.lock().len(), 1);
        assert_eq!(provider.provider_tag(), ProviderTag(*b"TEST"));
        Ok(())
    }

    /// Exercises the trait's default method implementations.
    #[test]
    fn default_trait_methods() -> GraphResult<()> {
        let provider = RecordingProvider::new(ProviderTag(*b"TEST"));
        assert!(!provider.handles_change_batches());

        // Default `on_changes` forwards every change to `on_change`.
        provider.on_changes(&[node_created(), node_created()])?;
        assert_eq!(provider.changes.lock().len(), 2);

        // An empty batch is a no-op.
        provider.on_changes(&[])?;
        assert_eq!(provider.changes.lock().len(), 2);

        provider.on_commit_applied(7)?;
        assert!(provider.vector_candidate_state_infos(7)?.is_empty());
        assert!(provider.declared_sub_tags().is_empty());
        Ok(())
    }
}