use alloc::{collections::BTreeMap, format, string::String, vec::Vec};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::{
mast::{AsmOpId, MastNodeId},
serde::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
utils::{CsrMatrix, CsrValidationError},
};
/// Sparse index mapping operations of MAST nodes to [`AsmOpId`]s.
///
/// Backed by a CSR matrix with one row per [`MastNodeId`]. Rows built through
/// [`OpToAsmOpId::add_asm_ops_for_node`] hold `(op_idx, AsmOpId)` pairs sorted by strictly
/// increasing `op_idx`. A pair covers its own operation and every following operation of the
/// node up to the next pair (see [`OpToAsmOpId::asm_op_id_for_operation`]).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct OpToAsmOpId {
    // Row `i` holds the entries for node id `i`; node ids without entries have empty rows.
    inner: CsrMatrix<MastNodeId, (usize, AsmOpId)>,
}
impl Default for OpToAsmOpId {
fn default() -> Self {
Self::new()
}
}
impl OpToAsmOpId {
    /// Creates an empty index with no nodes and no recorded entries.
    pub fn new() -> Self {
        Self { inner: CsrMatrix::new() }
    }

    /// Creates an empty index with storage reserved for `nodes_capacity` node rows and
    /// `operations_capacity` `(op_idx, AsmOpId)` entries in total.
    pub fn with_capacity(nodes_capacity: usize, operations_capacity: usize) -> Self {
        Self {
            inner: CsrMatrix::with_capacity(nodes_capacity, operations_capacity),
        }
    }

    /// Returns `true` if no node rows have been added.
    ///
    /// A node added with zero entries still counts as a row, so it makes the index non-empty.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns the number of node rows stored. This includes the empty rows inserted to fill
    /// gaps between node ids.
    pub fn num_nodes(&self) -> usize {
        self.inner.num_rows()
    }

    /// Returns the total number of `(op_idx, AsmOpId)` entries stored across all nodes.
    ///
    /// This counts recorded mappings, not the operations of the nodes themselves.
    pub fn num_operations(&self) -> usize {
        self.inner.num_elements()
    }

    /// Records the `(op_idx, AsmOpId)` entries for `node_id`.
    ///
    /// Nodes must be added in increasing id order. Any ids skipped between the last stored node
    /// and `node_id` receive empty rows. `asm_ops` must be sorted by strictly increasing
    /// `op_idx`, and every `op_idx` must be below `num_operations`.
    ///
    /// # Errors
    /// - [`AsmOpIndexError::NodeIndex`] if `node_id` is below [`Self::num_nodes`], i.e. it was
    ///   already added or already filled as a gap.
    /// - [`AsmOpIndexError::NonIncreasingOpIndices`] if the op indices are not strictly
    ///   increasing (duplicates included).
    /// - [`AsmOpIndexError::OpIndexOutOfBounds`] if an op index is `>= num_operations`.
    /// - [`AsmOpIndexError::InternalStructure`] if the underlying CSR matrix rejects a row.
    ///
    /// The index is left unchanged when any of the first three errors is returned.
    pub fn add_asm_ops_for_node(
        &mut self,
        node_id: MastNodeId,
        num_operations: usize,
        asm_ops: Vec<(usize, AsmOpId)>,
    ) -> Result<(), AsmOpIndexError> {
        let next_node = self.num_nodes();
        // Compare in `usize`: widening the u32 id avoids narrowing `num_nodes()` with `as u32`.
        let node_idx = u32::from(node_id) as usize;
        if node_idx < next_node {
            return Err(AsmOpIndexError::NodeIndex(node_id));
        }

        // Validate the input before touching `self.inner`. Otherwise a rejected call would leave
        // gap rows behind, making the skipped node ids permanently unaddable.
        if asm_ops.windows(2).any(|pair| pair[0].0 >= pair[1].0) {
            return Err(AsmOpIndexError::NonIncreasingOpIndices);
        }
        // The indices are strictly increasing at this point, so only the last one needs a
        // bounds check.
        if let Some(&(max_idx, _)) = asm_ops.last()
            && max_idx >= num_operations
        {
            return Err(AsmOpIndexError::OpIndexOutOfBounds(max_idx, num_operations));
        }

        // Fill the gap between the last stored node and `node_id` with empty rows.
        for _ in next_node..node_idx {
            self.inner.push_empty_row().map_err(|_| AsmOpIndexError::InternalStructure)?;
        }
        self.inner.push_row(asm_ops).map_err(|_| AsmOpIndexError::InternalStructure)?;
        Ok(())
    }

    /// Returns the [`AsmOpId`] in effect for operation `op_idx` of `node_id`.
    ///
    /// Entries are sparse. An entry recorded at index `i` applies to every operation from `i` up
    /// to, but not including, the next recorded index.
    ///
    /// Returns `None` if the node is unknown, has no entries, or `op_idx` precedes its first
    /// entry. Indices past the last entry resolve to that last entry; no upper bound is checked.
    pub fn asm_op_id_for_operation(&self, node_id: MastNodeId, op_idx: usize) -> Option<AsmOpId> {
        let entries = self.inner.row(node_id)?;
        match entries.binary_search_by_key(&op_idx, |(idx, _)| *idx) {
            Ok(i) => Some(entries[i].1),
            // Not an exact hit: the closest preceding entry (if any) is the one in effect.
            Err(i) if i > 0 => Some(entries[i - 1].1),
            Err(_) => None,
        }
    }

    /// Returns the [`AsmOpId`] of the first entry recorded for `node_id`.
    ///
    /// Returns `None` if the node is unknown or has no entries.
    pub fn first_asm_op_for_node(&self, node_id: MastNodeId) -> Option<AsmOpId> {
        let entries = self.inner.row(node_id)?;
        entries.first().map(|(_, id)| *id)
    }

    /// Returns a copy of all `(op_idx, AsmOpId)` entries for `node_id`.
    ///
    /// Returns an empty vector if the node is unknown.
    pub fn asm_ops_for_node(&self, node_id: MastNodeId) -> Vec<(usize, AsmOpId)> {
        self.inner.row(node_id).map(|r| r.to_vec()).unwrap_or_default()
    }

    /// Checks the CSR structural invariants, and that every stored [`AsmOpId`] is below
    /// `asm_op_count`.
    ///
    /// # Errors
    /// Returns a human-readable description of the violation reported by the CSR validator.
    pub(super) fn validate_csr(&self, asm_op_count: usize) -> Result<(), String> {
        self.inner
            .validate_with(|(_op_idx, asm_op_id)| (u32::from(*asm_op_id) as usize) < asm_op_count)
            .map_err(|e| format_validation_error(e, asm_op_count))
    }

    /// Returns a copy of this index with node rows moved according to `remapping`
    /// (`old id -> new id`).
    ///
    /// The result has `max(new id) + 1` rows, and new ids that receive no data become empty
    /// rows. Old nodes absent from `remapping` are dropped. If several old ids with non-empty
    /// rows map to the same new id, the row of the largest old id wins.
    ///
    /// Special cases: an empty `remapping` is treated as the identity and returns an unchanged
    /// clone, and an empty index stays empty.
    ///
    /// # Panics
    /// Panics if the CSR matrix rejects a row while the result is built. The `expect` messages
    /// assume this only happens when the row count overflows `u32`.
    pub fn remap_nodes(&self, remapping: &BTreeMap<MastNodeId, MastNodeId>) -> Self {
        if self.is_empty() {
            return Self::new();
        }
        if remapping.is_empty() {
            return self.clone();
        }

        let max_new_id = remapping.values().map(|id| u32::from(*id)).max().unwrap_or(0) as usize;
        let num_new_nodes = max_new_id + 1;

        // Borrow the source rows rather than copying them into temporary vectors. Each row is
        // copied exactly once, into the new matrix below.
        let mut new_node_data: BTreeMap<usize, &[(usize, AsmOpId)]> = BTreeMap::new();
        for (old_id, new_id) in remapping {
            let new_idx = u32::from(*new_id) as usize;
            if let Some(entries) = self.inner.row(*old_id)
                && !entries.is_empty()
            {
                new_node_data.insert(new_idx, entries);
            }
        }

        let mut new_inner = CsrMatrix::with_capacity(num_new_nodes, self.inner.num_elements());
        for new_idx in 0..num_new_nodes {
            if let Some(entries) = new_node_data.get(&new_idx) {
                new_inner.push_row(entries.iter().copied()).expect("node count should fit in u32");
            } else {
                new_inner.push_empty_row().expect("node count should fit in u32");
            }
        }
        Self { inner: new_inner }
    }

    /// Serializes the index by delegating to the CSR matrix's own serialization.
    pub(super) fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.inner.write_into(target);
    }

    /// Deserializes an index written by [`Self::write_into`] and validates it against
    /// `asm_op_count`.
    ///
    /// # Errors
    /// Returns the reader's [`DeserializationError`] if decoding fails. If validation fails,
    /// returns [`DeserializationError::InvalidValue`].
    pub(super) fn read_from<R: ByteReader>(
        source: &mut R,
        asm_op_count: usize,
    ) -> Result<Self, DeserializationError> {
        let inner: CsrMatrix<MastNodeId, (usize, AsmOpId)> = Deserializable::read_from(source)?;
        let result = Self { inner };
        // Untrusted bytes: reject structurally broken matrices and out-of-range AsmOpIds.
        result.validate_csr(asm_op_count).map_err(|e| {
            DeserializationError::InvalidValue(format!("OpToAsmOpId validation failed: {}", e))
        })?;
        Ok(result)
    }
}
/// Errors returned by [`OpToAsmOpId::add_asm_ops_for_node`].
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum AsmOpIndexError {
    /// The node id is below the number of rows already stored. It was either added before or
    /// filled with an empty row as a gap.
    #[error("Invalid node index {0:?}")]
    NodeIndex(MastNodeId),
    /// Two consecutive entries were not in strictly increasing `op_idx` order (duplicates
    /// included).
    #[error("Operation indices must be strictly increasing")]
    NonIncreasingOpIndices,
    /// The largest op index (field 0) was not below the node's operation count (field 1).
    /// Equality is also rejected, even though the message says "exceeds".
    #[error("Operation index {0} exceeds node's operation count {1}")]
    OpIndexOutOfBounds(usize, usize),
    /// The underlying CSR matrix rejected a new row.
    #[error("Internal CSR structure error")]
    InternalStructure,
}
fn format_validation_error(error: CsrValidationError, asm_op_count: usize) -> String {
match error {
CsrValidationError::IndptrStartNotZero(val) => format!("indptr must start at 0, got {val}"),
CsrValidationError::IndptrNotMonotonic { index, prev, curr } => {
format!("indptr not monotonic at index {index}: {prev} > {curr}")
},
CsrValidationError::IndptrDataMismatch { indptr_end, data_len } => {
format!("indptr ends at {indptr_end}, but data.len() is {data_len}")
},
CsrValidationError::InvalidData { row, position } => format!(
"Invalid AsmOpId at row {row}, position {position}: exceeds asm_op count {asm_op_count}"
),
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `OpToAsmOpId`: construction, lookup semantics, input validation,
    // node remapping, CSR validation and serialization round-trips.
    use super::*;
    use crate::serde::SliceReader;

    fn test_asm_op_id(value: u32) -> AsmOpId {
        AsmOpId::new(value)
    }

    fn test_node_id(value: u32) -> MastNodeId {
        MastNodeId::new_unchecked(value)
    }

    #[test]
    fn test_op_to_asm_op_id_empty() {
        let storage = OpToAsmOpId::new();
        assert!(storage.is_empty());
        assert_eq!(storage.num_nodes(), 0);
        assert_eq!(storage.num_operations(), 0);
    }

    #[test]
    fn test_op_to_asm_op_id_default() {
        let storage = OpToAsmOpId::default();
        assert!(storage.is_empty());
    }

    #[test]
    fn test_op_to_asm_op_id_with_capacity() {
        let storage = OpToAsmOpId::with_capacity(10, 100);
        assert!(storage.is_empty());
        assert_eq!(storage.num_nodes(), 0);
    }

    #[test]
    fn test_op_to_asm_op_id_single_node() {
        let mut storage = OpToAsmOpId::new();
        let node_id = test_node_id(0);
        let asm_op_id = test_asm_op_id(0);
        storage.add_asm_ops_for_node(node_id, 3, vec![(2, asm_op_id)]).unwrap();
        assert!(!storage.is_empty());
        assert_eq!(storage.num_nodes(), 1);
        assert_eq!(storage.num_operations(), 1);
        assert_eq!(storage.asm_op_id_for_operation(node_id, 0), None);
        assert_eq!(storage.asm_op_id_for_operation(node_id, 1), None);
        assert_eq!(storage.asm_op_id_for_operation(node_id, 2), Some(asm_op_id));
        assert_eq!(storage.asm_op_id_for_operation(node_id, 3), Some(asm_op_id));
    }

    #[test]
    fn test_op_to_asm_op_id_single_node_multiple_ops() {
        let mut storage = OpToAsmOpId::new();
        let node_id = test_node_id(0);
        storage
            .add_asm_ops_for_node(
                node_id,
                6,
                vec![(0, test_asm_op_id(10)), (2, test_asm_op_id(20)), (5, test_asm_op_id(30))],
            )
            .unwrap();
        assert_eq!(storage.num_operations(), 3);
        assert_eq!(storage.asm_op_id_for_operation(node_id, 0), Some(test_asm_op_id(10)));
        assert_eq!(storage.asm_op_id_for_operation(node_id, 1), Some(test_asm_op_id(10)));
        assert_eq!(storage.asm_op_id_for_operation(node_id, 2), Some(test_asm_op_id(20)));
        assert_eq!(storage.asm_op_id_for_operation(node_id, 3), Some(test_asm_op_id(20)));
        assert_eq!(storage.asm_op_id_for_operation(node_id, 4), Some(test_asm_op_id(20)));
        assert_eq!(storage.asm_op_id_for_operation(node_id, 5), Some(test_asm_op_id(30)));
    }

    #[test]
    fn test_op_to_asm_op_id_empty_node() {
        let mut storage = OpToAsmOpId::new();
        let node_id = test_node_id(0);
        storage.add_asm_ops_for_node(node_id, 0, vec![]).unwrap();
        assert!(!storage.is_empty());
        assert_eq!(storage.num_nodes(), 1);
        assert_eq!(storage.num_operations(), 0);
        assert_eq!(storage.asm_op_id_for_operation(node_id, 0), None);
    }

    #[test]
    fn test_op_to_asm_op_id_multiple_nodes() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(test_node_id(0), 2, vec![(1, test_asm_op_id(0))])
            .unwrap();
        storage
            .add_asm_ops_for_node(
                test_node_id(1),
                3,
                vec![(0, test_asm_op_id(1)), (2, test_asm_op_id(2))],
            )
            .unwrap();
        assert_eq!(storage.num_nodes(), 2);
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 0), None);
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 1), Some(test_asm_op_id(0)));
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(1), 0), Some(test_asm_op_id(1)));
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(1), 1), Some(test_asm_op_id(1)));
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(1), 2), Some(test_asm_op_id(2)));
    }

    #[test]
    fn test_op_to_asm_op_id_mixed_empty_and_populated_nodes() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(test_node_id(0), 1, vec![(0, test_asm_op_id(0))])
            .unwrap();
        storage.add_asm_ops_for_node(test_node_id(1), 0, vec![]).unwrap();
        storage
            .add_asm_ops_for_node(test_node_id(2), 2, vec![(1, test_asm_op_id(1))])
            .unwrap();
        assert_eq!(storage.num_nodes(), 3);
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 0), Some(test_asm_op_id(0)));
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(1), 0), None);
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(2), 0), None);
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(2), 1), Some(test_asm_op_id(1)));
    }

    #[test]
    fn test_op_to_asm_op_id_gap_in_nodes() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(test_node_id(0), 1, vec![(0, test_asm_op_id(0))])
            .unwrap();
        storage
            .add_asm_ops_for_node(test_node_id(2), 1, vec![(0, test_asm_op_id(1))])
            .unwrap();
        assert_eq!(storage.num_nodes(), 3);
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(1), 0), None);
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 0), Some(test_asm_op_id(0)));
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(2), 0), Some(test_asm_op_id(1)));
    }

    #[test]
    fn test_first_asm_op_for_node() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(test_node_id(0), 3, vec![(2, test_asm_op_id(42))])
            .unwrap();
        assert_eq!(storage.first_asm_op_for_node(test_node_id(0)), Some(test_asm_op_id(42)));
    }

    #[test]
    fn test_first_asm_op_for_node_empty() {
        let mut storage = OpToAsmOpId::new();
        storage.add_asm_ops_for_node(test_node_id(0), 0, vec![]).unwrap();
        assert_eq!(storage.first_asm_op_for_node(test_node_id(0)), None);
    }

    #[test]
    fn test_first_asm_op_for_node_nonexistent() {
        let storage = OpToAsmOpId::new();
        assert_eq!(storage.first_asm_op_for_node(test_node_id(0)), None);
    }

    #[test]
    fn test_first_asm_op_for_node_multiple_ops() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(
                test_node_id(0),
                4,
                vec![(1, test_asm_op_id(10)), (3, test_asm_op_id(30))],
            )
            .unwrap();
        assert_eq!(storage.first_asm_op_for_node(test_node_id(0)), Some(test_asm_op_id(10)));
    }

    #[test]
    fn test_asm_ops_for_node() {
        let mut storage = OpToAsmOpId::new();
        let ops = vec![(0, test_asm_op_id(1)), (3, test_asm_op_id(2))];
        storage.add_asm_ops_for_node(test_node_id(0), 4, ops.clone()).unwrap();
        assert_eq!(storage.asm_ops_for_node(test_node_id(0)), ops);
        // Unknown nodes yield an empty vector rather than an error.
        assert!(storage.asm_ops_for_node(test_node_id(1)).is_empty());
    }

    #[test]
    fn test_op_to_asm_op_id_non_increasing_ops() {
        let mut storage = OpToAsmOpId::new();
        let result = storage.add_asm_ops_for_node(
            test_node_id(0),
            3,
            vec![(2, test_asm_op_id(0)), (1, test_asm_op_id(1))],
        );
        assert_eq!(result, Err(AsmOpIndexError::NonIncreasingOpIndices));
    }

    #[test]
    fn test_op_to_asm_op_id_duplicate_ops() {
        let mut storage = OpToAsmOpId::new();
        let result = storage.add_asm_ops_for_node(
            test_node_id(0),
            2,
            vec![(1, test_asm_op_id(0)), (1, test_asm_op_id(1))],
        );
        assert_eq!(result, Err(AsmOpIndexError::NonIncreasingOpIndices));
    }

    #[test]
    fn test_op_to_asm_op_id_op_index_out_of_bounds() {
        let mut storage = OpToAsmOpId::new();
        // `num_operations` is an exclusive bound, so index 2 is rejected for a 2-op node.
        let result =
            storage.add_asm_ops_for_node(test_node_id(0), 2, vec![(2, test_asm_op_id(0))]);
        assert_eq!(result, Err(AsmOpIndexError::OpIndexOutOfBounds(2, 2)));
        assert_eq!(storage.num_nodes(), 0);
    }

    #[test]
    fn test_op_to_asm_op_id_node_already_added() {
        let mut storage = OpToAsmOpId::new();
        storage.add_asm_ops_for_node(test_node_id(0), 0, vec![]).unwrap();
        storage.add_asm_ops_for_node(test_node_id(1), 0, vec![]).unwrap();
        let result = storage.add_asm_ops_for_node(test_node_id(0), 0, vec![]);
        assert_eq!(result, Err(AsmOpIndexError::NodeIndex(test_node_id(0))));
    }

    #[test]
    fn test_op_to_asm_op_id_query_nonexistent_node() {
        let storage = OpToAsmOpId::new();
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 0), None);
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(999), 0), None);
    }

    #[test]
    fn test_op_to_asm_op_id_query_out_of_bounds_op() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(test_node_id(0), 2, vec![(1, test_asm_op_id(0))])
            .unwrap();
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 2), Some(test_asm_op_id(0)));
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 100), Some(test_asm_op_id(0)));
    }

    #[test]
    fn test_remap_nodes_swaps_rows() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(test_node_id(0), 1, vec![(0, test_asm_op_id(10))])
            .unwrap();
        storage
            .add_asm_ops_for_node(test_node_id(1), 2, vec![(1, test_asm_op_id(20))])
            .unwrap();
        let mut remapping = BTreeMap::new();
        remapping.insert(test_node_id(0), test_node_id(1));
        remapping.insert(test_node_id(1), test_node_id(0));
        let remapped = storage.remap_nodes(&remapping);
        assert_eq!(remapped.num_nodes(), 2);
        assert_eq!(remapped.asm_ops_for_node(test_node_id(0)), vec![(1, test_asm_op_id(20))]);
        assert_eq!(remapped.asm_ops_for_node(test_node_id(1)), vec![(0, test_asm_op_id(10))]);
    }

    #[test]
    fn test_remap_nodes_drops_unmapped_and_fills_gaps() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(test_node_id(0), 1, vec![(0, test_asm_op_id(10))])
            .unwrap();
        storage
            .add_asm_ops_for_node(test_node_id(1), 1, vec![(0, test_asm_op_id(20))])
            .unwrap();
        let mut remapping = BTreeMap::new();
        remapping.insert(test_node_id(0), test_node_id(2));
        let remapped = storage.remap_nodes(&remapping);
        // Old node 1 is unmapped and dropped; new ids 0 and 1 become empty rows.
        assert_eq!(remapped.num_nodes(), 3);
        assert_eq!(remapped.num_operations(), 1);
        assert!(remapped.asm_ops_for_node(test_node_id(0)).is_empty());
        assert!(remapped.asm_ops_for_node(test_node_id(1)).is_empty());
        assert_eq!(remapped.asm_ops_for_node(test_node_id(2)), vec![(0, test_asm_op_id(10))]);
    }

    #[test]
    fn test_remap_nodes_empty_inputs() {
        let mut remapping = BTreeMap::new();
        remapping.insert(test_node_id(0), test_node_id(3));
        assert!(OpToAsmOpId::new().remap_nodes(&remapping).is_empty());

        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(test_node_id(0), 1, vec![(0, test_asm_op_id(7))])
            .unwrap();
        // An empty remapping is treated as the identity.
        assert_eq!(storage.remap_nodes(&BTreeMap::new()), storage);
    }

    #[test]
    fn test_validate_csr_empty() {
        let storage = OpToAsmOpId::new();
        assert!(storage.validate_csr(0).is_ok());
    }

    #[test]
    fn test_validate_csr_valid() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(
                test_node_id(0),
                2,
                vec![(0, test_asm_op_id(0)), (1, test_asm_op_id(1))],
            )
            .unwrap();
        assert!(storage.validate_csr(2).is_ok());
    }

    #[test]
    fn test_validate_csr_invalid_asm_op_id() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(
                test_node_id(0),
                2,
                vec![(0, test_asm_op_id(0)), (1, test_asm_op_id(5))],
            )
            .unwrap();
        let result = storage.validate_csr(2);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid AsmOpId"));
    }

    #[test]
    fn test_serialization_roundtrip_empty() {
        let storage = OpToAsmOpId::new();
        let mut bytes = alloc::vec::Vec::new();
        storage.write_into(&mut bytes);
        let mut reader = SliceReader::new(&bytes);
        let deserialized = OpToAsmOpId::read_from(&mut reader, 0).unwrap();
        assert_eq!(storage, deserialized);
    }

    #[test]
    fn test_serialization_roundtrip_with_data() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(
                test_node_id(0),
                3,
                vec![(0, test_asm_op_id(0)), (2, test_asm_op_id(1))],
            )
            .unwrap();
        storage.add_asm_ops_for_node(test_node_id(1), 0, vec![]).unwrap();
        storage
            .add_asm_ops_for_node(test_node_id(2), 2, vec![(1, test_asm_op_id(2))])
            .unwrap();
        let mut bytes = alloc::vec::Vec::new();
        storage.write_into(&mut bytes);
        let mut reader = SliceReader::new(&bytes);
        let deserialized = OpToAsmOpId::read_from(&mut reader, 3).unwrap();
        assert_eq!(storage, deserialized);
    }

    #[test]
    fn test_clone_and_equality() {
        let mut storage1 = OpToAsmOpId::new();
        storage1
            .add_asm_ops_for_node(test_node_id(0), 1, vec![(0, test_asm_op_id(42))])
            .unwrap();
        let storage2 = storage1.clone();
        assert_eq!(storage1, storage2);
        let mut storage3 = OpToAsmOpId::new();
        storage3
            .add_asm_ops_for_node(test_node_id(0), 1, vec![(0, test_asm_op_id(99))])
            .unwrap();
        assert_ne!(storage1, storage3);
    }

    #[test]
    fn test_debug_impl() {
        let storage = OpToAsmOpId::new();
        let debug_str = alloc::format!("{:?}", storage);
        assert!(debug_str.contains("OpToAsmOpId"));
    }
}