use oxgraph_layout_util::{OffsetIntegrityIssue, check_offset_section, check_value_range};
use crate::{
error::{BcsrError, BcsrRoleSide, BcsrSection},
internal::view::BcsrSections,
word::{BcsrIndex, BcsrWord},
};
/// Translates an `OffsetIntegrityIssue` reported by `check_offset_section` into
/// the matching `BcsrError` for the offsets array identified by `section`.
///
/// `offsets` is consulted only for `UsizeOverflow`, to recover the word at the
/// reported index. If that word is missing or `to_usize` yields `None`,
/// `usize::MAX` is reported instead. Any issue kind without a dedicated arm
/// also collapses to `UsizeOverflow { value: usize::MAX }`.
fn map_offsets_issue<W: BcsrWord>(
    section: BcsrSection,
    offsets: &[W],
    issue: OffsetIntegrityIssue,
) -> BcsrError {
    match issue {
        OffsetIntegrityIssue::Length { expected, actual } => {
            BcsrError::OffsetLength { section, expected, actual }
        }
        OffsetIntegrityIssue::FirstNonZero { actual } => BcsrError::FirstOffset { section, actual },
        OffsetIntegrityIssue::NonMonotonic { index, previous, actual } => {
            BcsrError::NonMonotonicOffset { section, index, previous, actual }
        }
        OffsetIntegrityIssue::FinalMismatch { final_offset, value_len } => {
            BcsrError::FinalOffset { section, final_offset, value_len }
        }
        OffsetIntegrityIssue::UsizeOverflow { index } => {
            let readable = match offsets.get(index).copied() {
                Some(word) => word.get().to_usize(),
                None => None,
            };
            BcsrError::UsizeOverflow {
                value: readable.unwrap_or(usize::MAX),
            }
        }
        _ => BcsrError::UsizeOverflow { value: usize::MAX },
    }
}
/// Translates an issue from `check_value_range` over a vertex-valued section
/// into a `BcsrError`.
///
/// `ValueOutOfRange` becomes `VertexOutOfRange`, where the range bound is
/// reported as the vertex count. `UsizeOverflow` reports the word at the given
/// index when it can be read as `usize`, else `usize::MAX`. Every other issue
/// kind maps to `UsizeOverflow { value: usize::MAX }`.
fn map_vertex_value_issue<W: BcsrWord>(
    section: BcsrSection,
    values: &[W],
    issue: OffsetIntegrityIssue,
) -> BcsrError {
    match issue {
        OffsetIntegrityIssue::ValueOutOfRange {
            index,
            value: vertex,
            bound: vertex_count,
        } => BcsrError::VertexOutOfRange {
            section,
            index,
            vertex,
            vertex_count,
        },
        OffsetIntegrityIssue::UsizeOverflow { index } => {
            let readable = match values.get(index).copied() {
                Some(word) => word.get().to_usize(),
                None => None,
            };
            BcsrError::UsizeOverflow {
                value: readable.unwrap_or(usize::MAX),
            }
        }
        _ => BcsrError::UsizeOverflow { value: usize::MAX },
    }
}
/// Translates an issue from `check_value_range` over a hyperedge-valued section
/// into a `BcsrError`.
///
/// `ValueOutOfRange` becomes `HyperedgeOutOfRange`, where the range bound is
/// reported as the hyperedge count. `UsizeOverflow` reports the word at the
/// given index when it can be read as `usize`, else `usize::MAX`. Every other
/// issue kind maps to `UsizeOverflow { value: usize::MAX }`.
fn map_hyperedge_value_issue<W: BcsrWord>(
    section: BcsrSection,
    values: &[W],
    issue: OffsetIntegrityIssue,
) -> BcsrError {
    match issue {
        OffsetIntegrityIssue::ValueOutOfRange {
            index,
            value: hyperedge,
            bound: hyperedge_count,
        } => BcsrError::HyperedgeOutOfRange {
            section,
            index,
            hyperedge,
            hyperedge_count,
        },
        OffsetIntegrityIssue::UsizeOverflow { index } => {
            let readable = match values.get(index).copied() {
                Some(word) => word.get().to_usize(),
                None => None,
            };
            BcsrError::UsizeOverflow {
                value: readable.unwrap_or(usize::MAX),
            }
        }
        _ => BcsrError::UsizeOverflow { value: usize::MAX },
    }
}
/// How thoroughly `validate_sections` checks a bipartite-CSR layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum BcsrValidation {
    /// Structural checks: offset integrity, matching incidence totals, value
    /// ranges, and strictly ascending order within every bucket.
    Layout,
    /// Everything `Layout` checks, plus that every incidence listed per
    /// hyperedge also appears in the corresponding per-vertex index.
    Strict,
}
/// Element counts derived from section lengths by `derive_counts`, and
/// returned by `validate_sections` once every check has passed.
#[derive(Clone, Copy, Debug)]
pub(in crate::internal) struct DerivedCounts {
/// Number of vertices: `vertex_outgoing_offsets.len() - 1`.
pub(in crate::internal) vertex_count: usize,
/// Number of hyperedges: `head_offsets.len() - 1`.
pub(in crate::internal) hyperedge_count: usize,
/// Length of `vertex_outgoing_hyperedges`.
pub(in crate::internal) p_outgoing: usize,
/// Length of `vertex_incoming_hyperedges`.
pub(in crate::internal) p_incoming: usize,
/// `p_outgoing + p_incoming`, computed with overflow checking.
pub(in crate::internal) total_incidences: usize,
}
/// Runs every structural check over `sections` and returns the derived counts.
///
/// The checks run in this order and stop at the first failure:
/// 1. count derivation,
/// 2. offset integrity for all four offset arrays,
/// 3. outgoing/incoming total lengths,
/// 4. value ranges,
/// 5. strictly ascending order within each bucket.
///
/// `BcsrValidation::Strict` additionally verifies that every hyperedge-side
/// incidence appears in the vertex-side index.
///
/// # Errors
/// Returns the `BcsrError` produced by the first check that fails.
pub(in crate::internal) fn validate_sections<OffsetWord, VertexWord, RelationWord>(
    sections: &BcsrSections<'_, OffsetWord, VertexWord, RelationWord>,
    level: BcsrValidation,
) -> Result<DerivedCounts, BcsrError>
where
    OffsetWord: BcsrWord,
    VertexWord: BcsrWord,
    RelationWord: BcsrWord,
{
    let counts = derive_counts(sections)?;

    validate_all_offsets(sections, counts)
        .and_then(|()| validate_total_lengths(sections))
        .and_then(|()| validate_value_ranges(sections, counts))
        .and_then(|()| validate_within_range_sorted(sections))?;

    if let BcsrValidation::Strict = level {
        validate_cross_direction(sections)?;
    }

    Ok(counts)
}
/// Derives vertex and hyperedge counts and incidence totals from section lengths.
///
/// # Errors
/// - `HyperedgeOffsetLengthMismatch` if the head and tail offset arrays differ in length.
/// - `VertexOffsetLengthMismatch` if the outgoing and incoming vertex offset arrays differ.
/// - `OffsetLength` if the head or the outgoing vertex offset array is empty
///   (the head array is reported first).
/// - `TotalIncidenceCountOverflow` if adding the two incidence totals overflows `usize`.
fn derive_counts<OffsetWord, VertexWord, RelationWord>(
    sections: &BcsrSections<'_, OffsetWord, VertexWord, RelationWord>,
) -> Result<DerivedCounts, BcsrError>
where
    OffsetWord: BcsrWord,
    VertexWord: BcsrWord,
    RelationWord: BcsrWord,
{
    let head_offsets_len = sections.head_offsets.len();
    let tail_offsets_len = sections.tail_offsets.len();
    if head_offsets_len != tail_offsets_len {
        return Err(BcsrError::HyperedgeOffsetLengthMismatch {
            head_offsets_len,
            tail_offsets_len,
        });
    }

    let outgoing_offsets_len = sections.vertex_outgoing_offsets.len();
    let incoming_offsets_len = sections.vertex_incoming_offsets.len();
    if outgoing_offsets_len != incoming_offsets_len {
        return Err(BcsrError::VertexOffsetLengthMismatch {
            outgoing_offsets_len,
            incoming_offsets_len,
        });
    }

    // The hyperedge count is derived before the vertex count so that an empty
    // head offsets array is the error reported when both are empty.
    let hyperedge_count = derive_count_from_offsets(head_offsets_len, BcsrSection::HeadOffsets)?;
    let vertex_count =
        derive_count_from_offsets(outgoing_offsets_len, BcsrSection::VertexOutgoingOffsets)?;

    let p_outgoing = sections.vertex_outgoing_hyperedges.len();
    let p_incoming = sections.vertex_incoming_hyperedges.len();
    let total_incidences = match p_outgoing.checked_add(p_incoming) {
        Some(total) => total,
        None => {
            return Err(BcsrError::TotalIncidenceCountOverflow {
                p_head: p_outgoing,
                p_tail: p_incoming,
            });
        }
    };

    Ok(DerivedCounts {
        vertex_count,
        hyperedge_count,
        p_outgoing,
        p_incoming,
        total_incidences,
    })
}
/// Returns the element count described by an offsets array of length
/// `offsets_len`, which is `offsets_len - 1`.
///
/// # Errors
/// `OffsetLength { section, expected: 1, actual: 0 }` when the array is empty.
const fn derive_count_from_offsets(
    offsets_len: usize,
    section: BcsrSection,
) -> Result<usize, BcsrError> {
    match offsets_len.checked_sub(1) {
        Some(count) => Ok(count),
        None => Err(BcsrError::OffsetLength {
            section,
            expected: 1,
            actual: 0,
        }),
    }
}
/// Validates the four offset arrays against their derived counts and the
/// lengths of the value arrays they index.
///
/// The checks run in the order head, tail, vertex-outgoing, vertex-incoming,
/// and stop at the first failure.
fn validate_all_offsets<OffsetWord, VertexWord, RelationWord>(
    sections: &BcsrSections<'_, OffsetWord, VertexWord, RelationWord>,
    counts: DerivedCounts,
) -> Result<(), BcsrError>
where
    OffsetWord: BcsrWord,
    VertexWord: BcsrWord,
    RelationWord: BcsrWord,
{
    // (offsets, section tag, element count, length of the indexed value array)
    let checks = [
        (
            sections.head_offsets,
            BcsrSection::HeadOffsets,
            counts.hyperedge_count,
            sections.head_participants.len(),
        ),
        (
            sections.tail_offsets,
            BcsrSection::TailOffsets,
            counts.hyperedge_count,
            sections.tail_participants.len(),
        ),
        (
            sections.vertex_outgoing_offsets,
            BcsrSection::VertexOutgoingOffsets,
            counts.vertex_count,
            sections.vertex_outgoing_hyperedges.len(),
        ),
        (
            sections.vertex_incoming_offsets,
            BcsrSection::VertexIncomingOffsets,
            counts.vertex_count,
            sections.vertex_incoming_hyperedges.len(),
        ),
    ];

    for (offsets, section, count, value_len) in checks {
        validate_one_offsets(offsets, section, count, value_len)?;
    }
    Ok(())
}
/// Validates one offsets array against its derived `count` and the length of
/// the value array it indexes.
///
/// # Errors
/// - `OffsetLengthOverflow` if `count + 1` does not fit `usize`.
/// - Otherwise, any issue from `check_offset_section`, translated by
///   `map_offsets_issue`.
fn validate_one_offsets<Word: BcsrWord>(
    offsets: &[Word],
    section: BcsrSection,
    count: usize,
    value_len: usize,
) -> Result<(), BcsrError> {
    // A well-formed offsets array holds `count + 1` entries (see
    // `derive_count_from_offsets`), so that sum must be representable.
    if count == usize::MAX {
        return Err(BcsrError::OffsetLengthOverflow { count });
    }
    match check_offset_section(offsets, count, value_len) {
        Ok(checked) => Ok(checked),
        Err(issue) => Err(map_offsets_issue(section, offsets, issue)),
    }
}
/// Checks that both sides of each direction list the same number of incidences.
///
/// # Errors
/// - `OutgoingTotalMismatch` if the head participants and outgoing hyperedges
///   differ in length (checked first).
/// - `IncomingTotalMismatch` if the tail participants and incoming hyperedges
///   differ in length.
const fn validate_total_lengths<OffsetWord, VertexWord, RelationWord>(
    sections: &BcsrSections<'_, OffsetWord, VertexWord, RelationWord>,
) -> Result<(), BcsrError>
where
    OffsetWord: BcsrWord,
    VertexWord: BcsrWord,
    RelationWord: BcsrWord,
{
    let head_participants_len = sections.head_participants.len();
    let outgoing_hyperedges_len = sections.vertex_outgoing_hyperedges.len();
    let tail_participants_len = sections.tail_participants.len();
    let incoming_hyperedges_len = sections.vertex_incoming_hyperedges.len();

    if head_participants_len != outgoing_hyperedges_len {
        return Err(BcsrError::OutgoingTotalMismatch {
            head_participants_len,
            outgoing_hyperedges_len,
        });
    }
    if tail_participants_len != incoming_hyperedges_len {
        return Err(BcsrError::IncomingTotalMismatch {
            tail_participants_len,
            incoming_hyperedges_len,
        });
    }
    Ok(())
}
fn validate_value_ranges<OffsetWord, VertexWord, RelationWord>(
sections: &BcsrSections<'_, OffsetWord, VertexWord, RelationWord>,
counts: DerivedCounts,
) -> Result<(), BcsrError>
where
OffsetWord: BcsrWord,
VertexWord: BcsrWord,
RelationWord: BcsrWord,
{
check_vertex_values(
sections.head_participants,
BcsrSection::HeadParticipants,
counts.vertex_count,
)?;
check_vertex_values(
sections.tail_participants,
BcsrSection::TailParticipants,
counts.vertex_count,
)?;
check_hyperedge_values(
sections.vertex_outgoing_hyperedges,
BcsrSection::VertexOutgoingHyperedges,
counts.hyperedge_count,
)?;
check_hyperedge_values(
sections.vertex_incoming_hyperedges,
BcsrSection::VertexIncomingHyperedges,
counts.hyperedge_count,
)
}
/// Runs `check_value_range` over a vertex-valued section with bound `vertex_count`.
///
/// Any issue found is translated by `map_vertex_value_issue`.
fn check_vertex_values<Word: BcsrWord>(
    values: &[Word],
    section: BcsrSection,
    vertex_count: usize,
) -> Result<(), BcsrError> {
    match check_value_range(values, vertex_count) {
        Ok(checked) => Ok(checked),
        Err(issue) => Err(map_vertex_value_issue(section, values, issue)),
    }
}
/// Runs `check_value_range` over a hyperedge-valued section with bound
/// `hyperedge_count`.
///
/// Any issue found is translated by `map_hyperedge_value_issue`.
fn check_hyperedge_values<Word: BcsrWord>(
    values: &[Word],
    section: BcsrSection,
    hyperedge_count: usize,
) -> Result<(), BcsrError> {
    match check_value_range(values, hyperedge_count) {
        Ok(checked) => Ok(checked),
        Err(issue) => Err(map_hyperedge_value_issue(section, values, issue)),
    }
}
/// Requires every bucket of all four value arrays to be strictly ascending.
///
/// The checks run in the order head, tail, vertex-outgoing, vertex-incoming,
/// and stop at the first failure.
fn validate_within_range_sorted<OffsetWord, VertexWord, RelationWord>(
    sections: &BcsrSections<'_, OffsetWord, VertexWord, RelationWord>,
) -> Result<(), BcsrError>
where
    OffsetWord: BcsrWord,
    VertexWord: BcsrWord,
    RelationWord: BcsrWord,
{
    check_strictly_ascending_buckets(
        sections.head_offsets,
        sections.head_participants,
        BcsrSection::HeadParticipants,
    )
    .and_then(|()| {
        check_strictly_ascending_buckets(
            sections.tail_offsets,
            sections.tail_participants,
            BcsrSection::TailParticipants,
        )
    })
    .and_then(|()| {
        check_strictly_ascending_buckets(
            sections.vertex_outgoing_offsets,
            sections.vertex_outgoing_hyperedges,
            BcsrSection::VertexOutgoingHyperedges,
        )
    })
    .and_then(|()| {
        check_strictly_ascending_buckets(
            sections.vertex_incoming_offsets,
            sections.vertex_incoming_hyperedges,
            BcsrSection::VertexIncomingHyperedges,
        )
    })
}
/// Checks that each bucket `values[offsets[i]..offsets[i + 1]]` is strictly
/// ascending.
///
/// # Errors
/// - `UsizeOverflow` if an offset does not convert to `usize`.
/// - Any error from `check_strictly_ascending_range` for the first offending
///   bucket.
fn check_strictly_ascending_buckets<OffsetWord, Word>(
    offsets: &[OffsetWord],
    values: &[Word],
    section: BcsrSection,
) -> Result<(), BcsrError>
where
    OffsetWord: BcsrWord,
    Word: BcsrWord,
{
    // `windows(2)` yields nothing for fewer than two offsets, which covers the
    // empty case without a separate guard.
    offsets.windows(2).try_for_each(|bounds| {
        let start = index_to_usize(bounds[0].get())?;
        let end = index_to_usize(bounds[1].get())?;
        check_strictly_ascending_range(values, start, end, section)
    })
}
/// Checks that the bucket `values[start..end]` is strictly ascending.
///
/// Buckets with fewer than two entries are trivially sorted. An inverted range
/// (`end < start`) is treated as empty.
///
/// # Errors
/// - `NotStrictlyAscending`, carrying the absolute index in `values`, for the
///   first entry that is not greater than its predecessor.
/// - `UsizeOverflow` if an entry in the bucket does not convert to `usize`.
///
/// # Panics
/// Panics if a range with at least two entries does not lie within `values`.
/// Callers only pass ranges taken from already-validated offsets.
fn check_strictly_ascending_range<Word: BcsrWord>(
    values: &[Word],
    start: usize,
    end: usize,
    section: BcsrSection,
) -> Result<(), BcsrError> {
    // `saturating_sub` avoids the overflow that `start + 1` would hit when
    // `start == usize::MAX`.
    if end.saturating_sub(start) < 2 {
        return Ok(());
    }
    // Slice once so the loop below iterates without per-element bounds checks.
    let bucket = &values[start..end];
    let mut previous = index_to_usize(bucket[0].get())?;
    for (offset, word) in bucket.iter().enumerate().skip(1) {
        let actual = index_to_usize(word.get())?;
        if actual <= previous {
            return Err(BcsrError::NotStrictlyAscending {
                section,
                index: start + offset,
                previous,
                actual,
            });
        }
        previous = actual;
    }
    Ok(())
}
/// Strict-mode check that the hyperedge-side and vertex-side indexes agree.
///
/// The function walks head participants against the outgoing vertex index,
/// then tail participants against the incoming vertex index, and stops at the
/// first mismatch.
///
/// Only the hyperedge-to-vertex direction is walked. `validate_sections` has
/// already required equal totals on both sides and strictly ascending
/// (duplicate-free) buckets, so containment in one direction implies that both
/// indexes hold the same incidences.
fn validate_cross_direction<OffsetWord, VertexWord, RelationWord>(
    sections: &BcsrSections<'_, OffsetWord, VertexWord, RelationWord>,
) -> Result<(), BcsrError>
where
    OffsetWord: BcsrWord,
    VertexWord: BcsrWord,
    RelationWord: BcsrWord,
{
    let outgoing = cross_direction_walk(
        sections.head_offsets,
        sections.head_participants,
        sections.vertex_outgoing_offsets,
        sections.vertex_outgoing_hyperedges,
        BcsrRoleSide::Outgoing,
    );
    outgoing.and_then(|()| {
        cross_direction_walk(
            sections.tail_offsets,
            sections.tail_participants,
            sections.vertex_incoming_offsets,
            sections.vertex_incoming_hyperedges,
            BcsrRoleSide::Incoming,
        )
    })
}
/// For each hyperedge `h`, checks that every vertex in the bucket
/// `edge_values[edge_offsets[h]..edge_offsets[h + 1]]` lists `h` in its own
/// bucket of `vertex_values`.
///
/// # Errors
/// - `UsizeOverflow` if an offset does not convert to `usize`.
/// - Any error from `cross_direction_check_bucket` for the first failing
///   hyperedge.
fn cross_direction_walk<OffsetWord, VertexWord, RelationWord>(
    edge_offsets: &[OffsetWord],
    edge_values: &[VertexWord],
    vertex_offsets: &[OffsetWord],
    vertex_values: &[RelationWord],
    side: BcsrRoleSide,
) -> Result<(), BcsrError>
where
    OffsetWord: BcsrWord,
    VertexWord: BcsrWord,
    RelationWord: BcsrWord,
{
    // `windows(2)` yields nothing for fewer than two offsets, so an empty
    // hyperedge table needs no special case.
    edge_offsets
        .windows(2)
        .enumerate()
        .try_for_each(|(hyperedge, bounds)| {
            let start = index_to_usize(bounds[0].get())?;
            let end = index_to_usize(bounds[1].get())?;
            cross_direction_check_bucket(CrossDirectionBucket {
                edge_values,
                start,
                end,
                vertex_offsets,
                vertex_values,
                hyperedge,
                side,
            })
        })
}
/// Groups the arguments of `cross_direction_check_bucket` into one value: one
/// hyperedge-side bucket plus the vertex-side index it is checked against.
#[derive(Clone, Copy)]
struct CrossDirectionBucket<'a, OffsetWord, VertexWord, RelationWord> {
/// Hyperedge-side value array; each entry is read as a vertex index.
edge_values: &'a [VertexWord],
/// Start (inclusive) of `hyperedge`'s bucket within `edge_values`.
start: usize,
/// End (exclusive) of `hyperedge`'s bucket within `edge_values`.
end: usize,
/// Vertex-side offsets; entries `v` and `v + 1` bound vertex `v`'s bucket.
vertex_offsets: &'a [OffsetWord],
/// Vertex-side value array, binary-searched per bucket for `hyperedge`.
vertex_values: &'a [RelationWord],
/// Index of the hyperedge whose bucket is being checked.
hyperedge: usize,
/// Direction being checked; reported in `CrossDirectionMismatch`.
side: BcsrRoleSide,
}
fn cross_direction_check_bucket<OffsetWord, VertexWord, RelationWord>(
args: CrossDirectionBucket<'_, OffsetWord, VertexWord, RelationWord>,
) -> Result<(), BcsrError>
where
OffsetWord: BcsrWord,
VertexWord: BcsrWord,
RelationWord: BcsrWord,
{
for word in args.edge_values.iter().take(args.end).skip(args.start) {
let vertex = index_to_usize(word.get())?;
let v_index = vertex;
let bucket_start = index_to_usize(args.vertex_offsets[v_index].get())?;
let bucket_end = index_to_usize(args.vertex_offsets[v_index + 1].get())?;
if !bucket_contains(args.vertex_values, bucket_start, bucket_end, args.hyperedge) {
return Err(BcsrError::CrossDirectionMismatch {
side: args.side,
hyperedge: args.hyperedge,
vertex,
});
}
}
Ok(())
}
/// Binary-searches `values[start..end]` for `needle`.
///
/// The bucket must already be sorted ascending and every entry must convert to
/// `usize`. `index_to_usize_validated` panics otherwise. Panics if `start..end`
/// is not within `values`.
fn bucket_contains<Word: BcsrWord>(
    values: &[Word],
    start: usize,
    end: usize,
    needle: usize,
) -> bool {
    values[start..end]
        .binary_search_by_key(&needle, |word| index_to_usize_validated(word.get()))
        .is_ok()
}
/// Converts a stored index to `usize`.
///
/// # Errors
/// Returns `UsizeOverflow { value: usize::MAX }` when `value.to_usize()` yields
/// `None`. The original value cannot be carried in a `usize` field.
pub(in crate::internal) fn index_to_usize<Index: BcsrIndex>(
    value: Index,
) -> Result<usize, BcsrError> {
    match value.to_usize() {
        Some(converted) => Ok(converted),
        None => Err(BcsrError::UsizeOverflow { value: usize::MAX }),
    }
}
/// Converts a stored index to `usize` after validation has proven it fits.
///
/// # Panics
/// Panics via `unreachable!` if `value.to_usize()` yields `None`, which would
/// indicate a validation bug.
pub(in crate::internal) fn index_to_usize_validated<Index: BcsrIndex>(value: Index) -> usize {
    match value.to_usize() {
        Some(converted) => converted,
        None => unreachable!("validated bipartite-CSR index must fit usize"),
    }
}
/// Converts a `usize` slot back into the stored index type after validation
/// has proven it fits.
///
/// # Panics
/// Panics via `unreachable!` if `Index::from_usize(value)` yields `None`, which
/// would indicate a validation bug.
pub(in crate::internal) fn usize_to_index_validated<Index: BcsrIndex>(value: usize) -> Index {
    match Index::from_usize(value) {
        Some(index) => index,
        None => unreachable!("validated BCSR slot must fit index"),
    }
}