use std::io::{self, BufRead, BufReader};
#[cfg(feature = "zstd")]
use std::io::{Read, Write as IoWrite};
use std::path::Path;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use crate::error::{DictError, Result};
#[cfg(feature = "simd")]
pub mod simd;
#[cfg(feature = "simd")]
pub use simd::SimdMatrix;
/// Size in bytes of the binary matrix header: two little-endian `u16`
/// values, `lsize` followed by `rsize`.
const MATRIX_HEADER_SIZE: usize = 2 * std::mem::size_of::<u16>();
/// Sentinel cost used to signal a connection lookup outside the matrix.
///
/// It equals `i32::MAX`, so it can never collide with a real `i16` cost.
pub const INVALID_CONNECTION_COST: i32 = i32::MAX;
/// Read-only access to a connection-cost matrix.
///
/// A cost is addressed by a `(right_id, left_id)` pair. The implementations in
/// this module store that cell at linear position
/// `right_id + left_size() * left_id`.
pub trait Matrix {
    /// Returns the connection cost stored for `(right_id, left_id)`.
    fn get(&self, right_id: u16, left_id: u16) -> i32;

    /// Extent of the `right_id` axis (the matrix's `lsize`).
    fn left_size(&self) -> usize;

    /// Extent of the `left_id` axis (the matrix's `rsize`).
    fn right_size(&self) -> usize;

    /// Total number of cells: `left_size() * right_size()`.
    fn entry_count(&self) -> usize {
        let (lsize, rsize) = (self.left_size(), self.right_size());
        lsize * rsize
    }
}
/// Connection-cost matrix stored as one contiguous `Vec<i16>`.
#[derive(Debug, Clone)]
pub struct DenseMatrix {
    /// Extent of the `right_id` axis. It is also the stride between
    /// consecutive `left_id` values in `costs`.
    lsize: usize,
    /// Extent of the `left_id` axis.
    rsize: usize,
    /// Costs laid out as `costs[right_id + lsize * left_id]`. The constructors
    /// in this file create it with exactly `lsize * rsize` elements.
    costs: Vec<i16>,
}
impl DenseMatrix {
    /// Creates an `lsize` x `rsize` matrix with every cell set to `default_cost`.
    #[must_use]
    pub fn new(lsize: usize, rsize: usize, default_cost: i16) -> Self {
        let costs = vec![default_cost; lsize * rsize];
        Self {
            lsize,
            rsize,
            costs,
        }
    }

    /// Wraps an existing cost vector laid out as `costs[right_id + lsize * left_id]`.
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Format`] if `costs.len() != lsize * rsize`.
    pub fn from_vec(lsize: usize, rsize: usize, costs: Vec<i16>) -> Result<Self> {
        let expected_size = lsize * rsize;
        if costs.len() != expected_size {
            return Err(DictError::Format(format!(
                "Matrix size mismatch: expected {} entries, got {}",
                expected_size,
                costs.len()
            )));
        }
        Ok(Self {
            lsize,
            rsize,
            costs,
        })
    }

    /// Returns the linear index of `(right_id, left_id)`, or `None` when
    /// either id is outside the matrix.
    ///
    /// Each axis is checked separately. A combined `index < costs.len()` test
    /// would let an oversized `right_id` silently alias a cell that belongs to
    /// a different `left_id`.
    #[inline]
    fn checked_index(&self, right_id: u16, left_id: u16) -> Option<usize> {
        let right = usize::from(right_id);
        let left = usize::from(left_id);
        if right < self.lsize && left < self.rsize {
            Some(right + self.lsize * left)
        } else {
            None
        }
    }

    /// Sets the cost of `(right_id, left_id)`. Ids outside the matrix are ignored.
    pub fn set(&mut self, right_id: u16, left_id: u16, cost: i16) {
        if let Some(slot) = self
            .checked_index(right_id, left_id)
            .and_then(|index| self.costs.get_mut(index))
        {
            *slot = cost;
        }
    }

    /// Opens `path` and parses it with [`DenseMatrix::from_def_reader`].
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Io`] if the file cannot be opened, otherwise any
    /// error reported by `from_def_reader`.
    pub fn from_def_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
        let reader = BufReader::new(file);
        Self::from_def_reader(reader)
    }

    /// Parses a text matrix definition.
    ///
    /// The first line is the header `lsize rsize`. Both values must fit in a
    /// `u16`, the width of the binary header written by `to_bin_bytes`.
    /// Every later line is `right_id left_id cost`. The following are skipped:
    /// blank lines, lines starting with `#`, and lines without exactly three
    /// whitespace-separated fields. Cells that no line sets keep the cost
    /// `i16::MAX`.
    ///
    /// # Errors
    ///
    /// - [`DictError::Io`] if reading from `reader` fails.
    /// - [`DictError::Format`] if the header is malformed, a field does not
    ///   parse, or an entry's ids lie outside the declared `lsize` x `rsize`.
    pub fn from_def_reader<R: BufRead>(mut reader: R) -> Result<Self> {
        let mut line = String::new();
        reader.read_line(&mut line).map_err(DictError::Io)?;
        let (lsize, rsize) = Self::parse_def_header(&line)?;
        let mut matrix = Self::new(lsize, rsize, i16::MAX);

        // Reuse one buffer for every line instead of allocating a String per line.
        loop {
            line.clear();
            if reader.read_line(&mut line).map_err(DictError::Io)? == 0 {
                break;
            }
            let entry = line.trim();
            if entry.is_empty() || entry.starts_with('#') {
                continue;
            }
            let mut fields = entry.split_whitespace();
            let (right, left, cost) =
                match (fields.next(), fields.next(), fields.next(), fields.next()) {
                    (Some(right), Some(left), Some(cost), None) => (right, left, cost),
                    _ => continue,
                };
            let right_id: u16 = right
                .parse()
                .map_err(|_| DictError::Format(format!("Invalid right_id: {right}")))?;
            let left_id: u16 = left
                .parse()
                .map_err(|_| DictError::Format(format!("Invalid left_id: {left}")))?;
            let cost: i16 = cost
                .parse()
                .map_err(|_| DictError::Format(format!("Invalid cost: {cost}")))?;
            // Out-of-range ids indicate a corrupt definition. Reject them
            // instead of silently dropping or aliasing the entry.
            if usize::from(right_id) >= lsize || usize::from(left_id) >= rsize {
                return Err(DictError::Format(format!(
                    "Matrix entry ({right_id}, {left_id}) out of range for {lsize}x{rsize} matrix"
                )));
            }
            matrix.set(right_id, left_id, cost);
        }
        Ok(matrix)
    }

    /// Parses the `lsize rsize` header line of a text definition.
    ///
    /// The line must contain exactly two `u16` values.
    fn parse_def_header(line: &str) -> Result<(usize, usize)> {
        let mut fields = line.split_whitespace().map(str::parse::<u16>);
        match (fields.next(), fields.next(), fields.next()) {
            (Some(Ok(lsize)), Some(Ok(rsize)), None) => {
                Ok((usize::from(lsize), usize::from(rsize)))
            }
            _ => Err(DictError::Format(
                "Invalid matrix header: expected 'lsize rsize'".to_string(),
            )),
        }
    }

    /// Reads `path` fully and decodes it with [`DenseMatrix::from_bin_bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Io`] if the file cannot be read, otherwise any
    /// error reported by `from_bin_bytes`.
    pub fn from_bin_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let data = std::fs::read(path.as_ref()).map_err(DictError::Io)?;
        Self::from_bin_bytes(&data)
    }

    /// Decodes the binary format produced by [`DenseMatrix::to_bin_bytes`].
    ///
    /// The format is a 4-byte header (little-endian `u16` `lsize`, then
    /// `rsize`) followed by `lsize * rsize` little-endian `i16` costs.
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Format`] if the data is shorter than the header,
    /// the dimensions overflow `usize`, or the body length does not match
    /// the dimensions.
    pub fn from_bin_bytes(data: &[u8]) -> Result<Self> {
        if data.len() < MATRIX_HEADER_SIZE {
            return Err(DictError::Format(
                "Matrix binary too short for header".to_string(),
            ));
        }
        let mut cursor = io::Cursor::new(data);
        let lsize = usize::from(cursor.read_u16::<LittleEndian>().map_err(DictError::Io)?);
        let rsize = usize::from(cursor.read_u16::<LittleEndian>().map_err(DictError::Io)?);
        // Checked: 65535 * 65535 * 2 exceeds usize::MAX on 32-bit targets.
        let expected_size = lsize
            .checked_mul(rsize)
            .and_then(|cells| cells.checked_mul(2))
            .ok_or_else(|| {
                DictError::Format(format!("Matrix dimensions {lsize}x{rsize} are too large"))
            })?;
        let body = &data[MATRIX_HEADER_SIZE..];
        let data_size = body.len();
        if data_size != expected_size {
            return Err(DictError::Format(format!(
                "Matrix data size mismatch: expected {expected_size} bytes, got {data_size}"
            )));
        }
        // The length check above guarantees exactly lsize * rsize pairs.
        let costs = body
            .chunks_exact(2)
            .map(|pair| i16::from_le_bytes([pair[0], pair[1]]))
            .collect();
        Ok(Self {
            lsize,
            rsize,
            costs,
        })
    }

    /// Decompresses a zstd-compressed binary matrix file and decodes it.
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Io`] on open or decompression failure, otherwise
    /// any error reported by [`DenseMatrix::from_bin_bytes`].
    #[cfg(feature = "zstd")]
    pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
        // `zstd::Decoder::new` already wraps its reader in a BufReader, so no
        // second buffering layer is added here.
        let mut decoder = zstd::Decoder::new(file).map_err(DictError::Io)?;
        let mut data = Vec::new();
        decoder.read_to_end(&mut data).map_err(DictError::Io)?;
        Self::from_bin_bytes(&data)
    }

    /// Stub used when the `zstd` feature is disabled.
    ///
    /// # Errors
    ///
    /// Always returns [`DictError::Format`].
    #[cfg(not(feature = "zstd"))]
    pub fn from_compressed_file<P: AsRef<Path>>(_path: P) -> Result<Self> {
        Err(DictError::Format(
            "zstd feature is not enabled. Use uncompressed files or enable the 'zstd' feature."
                .to_string(),
        ))
    }

    /// Serializes the matrix in the format read by [`DenseMatrix::from_bin_bytes`].
    ///
    /// The `as u16` casts truncate `lsize`/`rsize` values above `u16::MAX`,
    /// because the header cannot represent them. Such output will not round-trip.
    #[must_use]
    pub fn to_bin_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(MATRIX_HEADER_SIZE + self.costs.len() * 2);
        // Writing into a Vec<u8> cannot fail, so the io::Result is discarded.
        #[allow(clippy::cast_possible_truncation)]
        buf.write_u16::<LittleEndian>(self.lsize as u16).ok();
        #[allow(clippy::cast_possible_truncation)]
        buf.write_u16::<LittleEndian>(self.rsize as u16).ok();
        for &cost in &self.costs {
            buf.write_i16::<LittleEndian>(cost).ok();
        }
        buf
    }

    /// Writes [`DenseMatrix::to_bin_bytes`] output to `path`.
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Io`] if the file cannot be written.
    pub fn to_bin_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let data = self.to_bin_bytes();
        std::fs::write(path.as_ref(), data).map_err(DictError::Io)
    }

    /// Writes the binary form to `path`, compressed with zstd at `level`.
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Io`] if creating, compressing, or finishing the file fails.
    #[cfg(feature = "zstd")]
    pub fn to_compressed_file<P: AsRef<Path>>(&self, path: P, level: i32) -> Result<()> {
        let data = self.to_bin_bytes();
        let file = std::fs::File::create(path.as_ref()).map_err(DictError::Io)?;
        let mut encoder = zstd::Encoder::new(file, level).map_err(DictError::Io)?;
        encoder.write_all(&data).map_err(DictError::Io)?;
        // `finish` writes the zstd frame epilogue; dropping the encoder would not.
        encoder.finish().map_err(DictError::Io)?;
        Ok(())
    }

    /// Stub used when the `zstd` feature is disabled.
    ///
    /// # Errors
    ///
    /// Always returns [`DictError::Format`].
    #[cfg(not(feature = "zstd"))]
    pub fn to_compressed_file<P: AsRef<Path>>(&self, _path: P, _level: i32) -> Result<()> {
        Err(DictError::Format(
            "zstd feature is not enabled. Use uncompressed files or enable the 'zstd' feature."
                .to_string(),
        ))
    }

    /// Raw cost slice, laid out as `costs[right_id + lsize * left_id]`.
    #[must_use]
    pub fn costs(&self) -> &[i16] {
        &self.costs
    }

    /// Approximate footprint in bytes: the struct itself plus the stored costs.
    #[must_use]
    pub fn memory_size(&self) -> usize {
        std::mem::size_of::<Self>() + self.costs.len() * std::mem::size_of::<i16>()
    }
}
impl Matrix for DenseMatrix {
    /// Returns the cost of `(right_id, left_id)`.
    ///
    /// Returns [`INVALID_CONNECTION_COST`] when either id is outside the matrix.
    #[inline(always)]
    fn get(&self, right_id: u16, left_id: u16) -> i32 {
        let right = usize::from(right_id);
        let left = usize::from(left_id);
        // Bound each axis separately. A single `index < len` check would let an
        // oversized `right_id` alias a cell that belongs to another `left_id`.
        if right >= self.lsize || left >= self.rsize {
            return INVALID_CONNECTION_COST;
        }
        self.costs
            .get(right + self.lsize * left)
            .map_or(INVALID_CONNECTION_COST, |&cost| i32::from(cost))
    }

    fn left_size(&self) -> usize {
        self.lsize
    }

    fn right_size(&self) -> usize {
        self.rsize
    }
}
/// Connection-cost matrix served directly from a memory-mapped binary file.
///
/// The mapping covers the whole file, including the 4-byte header. Costs are
/// read on demand as little-endian `i16` values.
#[derive(Debug)]
pub struct MmapMatrix {
    /// Extent of the `right_id` axis, taken from the file header.
    lsize: usize,
    /// Extent of the `left_id` axis, taken from the file header.
    rsize: usize,
    /// The mapped file contents, header included.
    mmap: memmap2::Mmap,
}
impl MmapMatrix {
    /// Memory-maps a binary matrix file and checks its length against the header.
    ///
    /// The file must use the same layout as `DenseMatrix::to_bin_bytes`: a
    /// 4-byte header, then little-endian `i16` costs.
    ///
    /// # Errors
    ///
    /// - [`DictError::Io`] if the file cannot be opened or mapped.
    /// - [`DictError::Format`] if the file is shorter than the header, the
    ///   dimensions overflow `usize`, or the length disagrees with `lsize` x `rsize`.
    #[allow(unsafe_code)] // memory-mapping a file is inherently unsafe; see SAFETY below
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
        // SAFETY: `Mmap::map` is unsafe because modifying or truncating the
        // file while it is mapped is undefined behavior. This type only reads
        // through the map. Matrix files are assumed to stay unchanged while
        // the returned value is alive; callers must not rewrite the file in
        // the meantime.
        let mmap = unsafe { memmap2::Mmap::map(&file) }.map_err(DictError::Io)?;
        if mmap.len() < MATRIX_HEADER_SIZE {
            return Err(DictError::Format(
                "Matrix file too short for header".to_string(),
            ));
        }
        let mut cursor = io::Cursor::new(&mmap[..MATRIX_HEADER_SIZE]);
        let lsize = usize::from(cursor.read_u16::<LittleEndian>().map_err(DictError::Io)?);
        let rsize = usize::from(cursor.read_u16::<LittleEndian>().map_err(DictError::Io)?);
        // Checked: the byte count can exceed usize::MAX on 32-bit targets.
        let expected_size = lsize
            .checked_mul(rsize)
            .and_then(|cells| cells.checked_mul(2))
            .and_then(|bytes| bytes.checked_add(MATRIX_HEADER_SIZE))
            .ok_or_else(|| {
                DictError::Format(format!("Matrix dimensions {lsize}x{rsize} are too large"))
            })?;
        if mmap.len() != expected_size {
            return Err(DictError::Format(format!(
                "Matrix file size mismatch: expected {} bytes, got {}",
                expected_size,
                mmap.len()
            )));
        }
        Ok(Self { lsize, rsize, mmap })
    }

    /// Loads a zstd-compressed matrix.
    ///
    /// Compressed data cannot be mapped in place, so this decompresses into a
    /// [`DenseMatrix`] instead of returning an `MmapMatrix`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `DenseMatrix::from_compressed_file`.
    pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<DenseMatrix> {
        DenseMatrix::from_compressed_file(path)
    }

    /// Byte offset of the cost for `(right_id, left_id)` within the mapping.
    ///
    /// This does not check that the ids are in range.
    #[inline]
    const fn offset(&self, right_id: u16, left_id: u16) -> usize {
        MATRIX_HEADER_SIZE + (right_id as usize + self.lsize * left_id as usize) * 2
    }
}
impl Matrix for MmapMatrix {
    /// Reads the little-endian `i16` cost of `(right_id, left_id)` from the mapping.
    ///
    /// Returns [`INVALID_CONNECTION_COST`] when either id is outside the matrix.
    #[inline(always)]
    fn get(&self, right_id: u16, left_id: u16) -> i32 {
        // Bound each axis separately so an oversized `right_id` cannot read
        // the cell of another `left_id`.
        if usize::from(right_id) >= self.lsize || usize::from(left_id) >= self.rsize {
            return INVALID_CONNECTION_COST;
        }
        let offset = self.offset(right_id, left_id);
        self.mmap
            .get(offset..offset + 2)
            .map_or(INVALID_CONNECTION_COST, |bytes| {
                i32::from(i16::from_le_bytes([bytes[0], bytes[1]]))
            })
    }

    fn left_size(&self) -> usize {
        self.lsize
    }

    fn right_size(&self) -> usize {
        self.rsize
    }
}
/// Connection-cost matrix that stores only the cells whose cost differs from
/// a shared default.
#[derive(Debug, Clone)]
pub struct SparseMatrix {
    /// Extent of the `right_id` axis. It is also the stride applied to `left_id`.
    lsize: usize,
    /// Extent of the `left_id` axis.
    rsize: usize,
    /// Cost reported for any cell without an explicit entry.
    default_cost: i16,
    /// Explicit costs, keyed by linear index `right_id + lsize * left_id`.
    entries: std::collections::HashMap<usize, i16>,
}
impl SparseMatrix {
    /// Creates an `lsize` x `rsize` matrix in which every cell reads as `default_cost`.
    #[must_use]
    pub fn new(lsize: usize, rsize: usize, default_cost: i16) -> Self {
        Self {
            lsize,
            rsize,
            default_cost,
            entries: std::collections::HashMap::new(),
        }
    }

    /// Sets the cost of `(right_id, left_id)`.
    ///
    /// Storing `default_cost` removes any explicit entry, so the map holds only
    /// real exceptions. Ids outside the matrix are ignored rather than stored
    /// under an index that could alias another cell.
    pub fn set(&mut self, right_id: u16, left_id: u16, cost: i16) {
        let right = usize::from(right_id);
        let left = usize::from(left_id);
        if right >= self.lsize || left >= self.rsize {
            return;
        }
        let index = right + self.lsize * left;
        if cost == self.default_cost {
            self.entries.remove(&index);
        } else {
            self.entries.insert(index, cost);
        }
    }

    /// Builds a sparse copy of `dense`, keeping only cells whose cost is not `default_cost`.
    #[must_use]
    pub fn from_dense(dense: &DenseMatrix, default_cost: i16) -> Self {
        let entries = dense
            .costs
            .iter()
            .enumerate()
            .filter(|&(_, &cost)| cost != default_cost)
            .map(|(index, &cost)| (index, cost))
            .collect();
        Self {
            lsize: dense.lsize,
            rsize: dense.rsize,
            default_cost,
            entries,
        }
    }

    /// Expands into a [`DenseMatrix`], filling cells without an entry with `default_cost`.
    #[must_use]
    pub fn to_dense(&self) -> DenseMatrix {
        let mut costs = vec![self.default_cost; self.lsize * self.rsize];
        for (&index, &cost) in &self.entries {
            if let Some(slot) = costs.get_mut(index) {
                *slot = cost;
            }
        }
        DenseMatrix {
            lsize: self.lsize,
            rsize: self.rsize,
            costs,
        }
    }

    /// Number of explicitly stored (non-default) cells.
    #[must_use]
    pub fn entry_count_stored(&self) -> usize {
        self.entries.len()
    }

    /// Fraction of cells that hold the default cost, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` for an empty matrix.
    #[must_use]
    pub fn sparsity(&self) -> f64 {
        let total = self.lsize * self.rsize;
        if total == 0 {
            return 0.0;
        }
        #[allow(clippy::cast_precision_loss)]
        let entries_len = self.entries.len() as f64;
        #[allow(clippy::cast_precision_loss)]
        let total_f64 = total as f64;
        1.0 - (entries_len / total_f64)
    }

    /// Approximate footprint in bytes.
    ///
    /// Each slot of the map's capacity is counted as one `(usize, i16)` tuple,
    /// including alignment padding, plus one control byte. `std`'s `HashMap`
    /// is a SwissTable port, which keeps one control byte per slot. Spare
    /// buckets beyond `capacity()` are not counted.
    #[must_use]
    pub fn memory_size(&self) -> usize {
        const CONTROL_BYTES_PER_SLOT: usize = 1;
        std::mem::size_of::<Self>()
            + self.entries.capacity()
                * (std::mem::size_of::<(usize, i16)>() + CONTROL_BYTES_PER_SLOT)
    }
}
impl Matrix for SparseMatrix {
#[inline(always)]
fn get(&self, right_id: u16, left_id: u16) -> i32 {
let index = right_id as usize + self.lsize * left_id as usize;
self.entries
.get(&index)
.map_or_else(|| i32::from(self.default_cost), |&c| i32::from(c))
}
fn left_size(&self) -> usize {
self.lsize
}
fn right_size(&self) -> usize {
self.rsize
}
}
/// Loads connection matrices from disk, choosing a decoder by file extension.
pub struct MatrixLoader;

impl MatrixLoader {
    /// Loads a matrix into memory as a [`DenseMatrix`].
    ///
    /// The loader is chosen by the file extension:
    /// - `.def` → text definition.
    /// - `.zst` (which includes `*.bin.zst`) → zstd-compressed binary.
    /// - `.bin` → binary.
    ///
    /// Any other path is tried as binary first and then as text. If both
    /// attempts fail, the text loader's error is returned and the binary
    /// loader's error is discarded.
    ///
    /// # Errors
    ///
    /// Propagates the error of the loader that ran last.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<DenseMatrix> {
        let path = path.as_ref();
        match path.extension().and_then(|ext| ext.to_str()) {
            Some("def") => DenseMatrix::from_def_file(path),
            Some("zst") => DenseMatrix::from_compressed_file(path),
            Some("bin") => DenseMatrix::from_bin_file(path),
            _ => DenseMatrix::from_bin_file(path).or_else(|_| DenseMatrix::from_def_file(path)),
        }
    }

    /// Memory-maps a binary matrix file with `MmapMatrix::from_file`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `MmapMatrix::from_file`.
    pub fn load_mmap<P: AsRef<Path>>(path: P) -> Result<MmapMatrix> {
        MmapMatrix::from_file(path)
    }
}
/// A connection matrix held in one of the supported storage representations.
pub enum ConnectionMatrix {
    /// All costs materialized in a contiguous vector.
    Dense(DenseMatrix),
    /// Only non-default costs, stored in a hash map.
    Sparse(SparseMatrix),
    /// Costs read on demand from a memory-mapped binary file.
    Mmap(MmapMatrix),
}
impl ConnectionMatrix {
    /// Parses a text definition into the `Dense` representation.
    ///
    /// # Errors
    ///
    /// Propagates any error from `DenseMatrix::from_def_file`.
    pub fn from_def_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        DenseMatrix::from_def_file(path).map(Self::Dense)
    }

    /// Reads a binary matrix file into the `Dense` representation.
    ///
    /// # Errors
    ///
    /// Propagates any error from `DenseMatrix::from_bin_file`.
    pub fn from_bin_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        DenseMatrix::from_bin_file(path).map(Self::Dense)
    }

    /// Memory-maps a binary matrix file into the `Mmap` representation.
    ///
    /// # Errors
    ///
    /// Propagates any error from `MmapMatrix::from_file`.
    pub fn from_mmap_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        MmapMatrix::from_file(path).map(Self::Mmap)
    }

    /// Decompresses a zstd-compressed matrix file into the `Dense` representation.
    ///
    /// # Errors
    ///
    /// Propagates any error from `DenseMatrix::from_compressed_file`.
    pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        DenseMatrix::from_compressed_file(path).map(Self::Dense)
    }

    /// Loads a matrix with `MatrixLoader::load` (extension-based dispatch).
    ///
    /// # Errors
    ///
    /// Propagates any error from `MatrixLoader::load`.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        MatrixLoader::load(path).map(Self::Dense)
    }
}
impl Matrix for ConnectionMatrix {
    /// Forwards the lookup to the wrapped matrix.
    ///
    /// Each arm calls a concrete type, so the hot lookup path uses static
    /// dispatch rather than a virtual call.
    #[inline(always)]
    fn get(&self, right_id: u16, left_id: u16) -> i32 {
        match self {
            Self::Dense(dense) => dense.get(right_id, left_id),
            Self::Sparse(sparse) => sparse.get(right_id, left_id),
            Self::Mmap(mapped) => mapped.get(right_id, left_id),
        }
    }

    /// Extent of the `right_id` axis of the wrapped matrix.
    fn left_size(&self) -> usize {
        let inner: &dyn Matrix = match self {
            Self::Dense(dense) => dense,
            Self::Sparse(sparse) => sparse,
            Self::Mmap(mapped) => mapped,
        };
        inner.left_size()
    }

    /// Extent of the `left_id` axis of the wrapped matrix.
    fn right_size(&self) -> usize {
        let inner: &dyn Matrix = match self {
            Self::Dense(dense) => dense,
            Self::Sparse(sparse) => sparse,
            Self::Mmap(mapped) => mapped,
        };
        inner.right_size()
    }
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::cast_lossless)]
mod tests {
    use super::*;

    /// Parses an in-memory text definition; panics if it is rejected.
    fn parse_def(text: &str) -> DenseMatrix {
        DenseMatrix::from_def_reader(std::io::Cursor::new(text)).unwrap()
    }

    #[test]
    fn test_dense_matrix_new() {
        let matrix = DenseMatrix::new(10, 10, 0);
        assert_eq!((matrix.left_size(), matrix.right_size()), (10, 10));
        assert_eq!(matrix.entry_count(), 100);
        assert_eq!(matrix.get(0, 0), 0);
    }

    #[test]
    fn test_dense_matrix_set_get() {
        let mut matrix = DenseMatrix::new(10, 10, 0);
        matrix.set(3, 5, 100);
        // Lookups are not symmetric: (5, 3) is a different cell.
        assert_eq!(matrix.get(3, 5), 100);
        assert_eq!(matrix.get(5, 3), 0);
    }

    #[test]
    fn test_dense_matrix_from_vec() {
        let matrix = DenseMatrix::from_vec(2, 3, vec![1, 2, 3, 4, 5, 6]).unwrap();
        // Cells are laid out as costs[right_id + lsize * left_id].
        let expected = [
            ((0, 0), 1),
            ((1, 0), 2),
            ((0, 1), 3),
            ((1, 1), 4),
            ((0, 2), 5),
            ((1, 2), 6),
        ];
        for ((right_id, left_id), cost) in expected {
            assert_eq!(matrix.get(right_id, left_id), cost);
        }
    }

    #[test]
    fn test_dense_matrix_from_vec_size_mismatch() {
        assert!(DenseMatrix::from_vec(2, 3, vec![1, 2, 3]).is_err());
    }

    #[test]
    fn test_dense_matrix_boundary() {
        let matrix = DenseMatrix::new(10, 10, 0);
        assert_eq!(matrix.get(100, 100), INVALID_CONNECTION_COST);
    }

    #[test]
    fn test_dense_matrix_def_reader() {
        let matrix = parse_def("3 3\n0 0 100\n1 1 200\n2 2 300\n");
        assert_eq!((matrix.left_size(), matrix.right_size()), (3, 3));
        for (id, cost) in [(0, 100), (1, 200), (2, 300)] {
            assert_eq!(matrix.get(id, id), cost);
        }
        // Cells absent from the file keep the i16::MAX default.
        assert_eq!(matrix.get(0, 1), i16::MAX as i32);
    }

    #[test]
    fn test_dense_matrix_binary_roundtrip() {
        let cells = [((0, 0), 100), ((1, 2), -500), ((4, 4), 32767)];
        let mut matrix = DenseMatrix::new(5, 5, 0);
        for ((right_id, left_id), cost) in cells {
            matrix.set(right_id, left_id, cost);
        }
        let loaded = DenseMatrix::from_bin_bytes(&matrix.to_bin_bytes()).unwrap();
        assert_eq!((loaded.left_size(), loaded.right_size()), (5, 5));
        for ((right_id, left_id), cost) in cells {
            assert_eq!(loaded.get(right_id, left_id), i32::from(cost));
        }
    }

    #[test]
    fn test_sparse_matrix() {
        let mut sparse = SparseMatrix::new(100, 100, 0);
        sparse.set(10, 20, 500);
        sparse.set(50, 50, -100);
        assert_eq!(sparse.get(10, 20), 500);
        assert_eq!(sparse.get(50, 50), -100);
        assert_eq!(sparse.get(0, 0), 0);
        assert_eq!(sparse.entry_count_stored(), 2);
        assert!(sparse.sparsity() > 0.99);
    }

    #[test]
    fn test_sparse_dense_conversion() {
        let mut dense = DenseMatrix::new(10, 10, 0);
        dense.set(3, 3, 100);
        dense.set(5, 7, 200);
        let sparse = SparseMatrix::from_dense(&dense, 0);
        assert_eq!(sparse.entry_count_stored(), 2);
        let converted = sparse.to_dense();
        for (right_id, left_id, cost) in [(3, 3, 100), (5, 7, 200)] {
            assert_eq!(sparse.get(right_id, left_id), cost);
            assert_eq!(converted.get(right_id, left_id), cost);
        }
        assert_eq!(converted.get(0, 0), 0);
    }

    #[test]
    fn test_memory_size() {
        let dense_size = DenseMatrix::new(100, 100, 0).memory_size();
        assert!(dense_size >= 20000);
        let sparse_size = SparseMatrix::new(100, 100, 0).memory_size();
        assert!(sparse_size < dense_size);
    }

    #[test]
    fn test_connection_matrix_enum() {
        let matrix = ConnectionMatrix::Dense(DenseMatrix::new(5, 5, 100));
        assert_eq!((matrix.left_size(), matrix.right_size()), (5, 5));
        assert_eq!(matrix.get(0, 0), 100);
    }

    #[test]
    fn test_large_matrix() {
        const SIDE: usize = 178;
        let matrix = DenseMatrix::new(SIDE, SIDE, 0);
        assert_eq!(matrix.entry_count(), SIDE * SIDE);
        assert_eq!(
            matrix.memory_size(),
            std::mem::size_of::<DenseMatrix>() + SIDE * SIDE * 2
        );
    }

    #[test]
    fn test_def_with_comments_and_empty_lines() {
        let matrix = parse_def("2 2\n# This is a comment\n\n0 0 10\n0 1 20\n\n1 0 30\n1 1 40\n");
        for (right_id, left_id, cost) in [(0, 0, 10), (0, 1, 20), (1, 0, 30), (1, 1, 40)] {
            assert_eq!(matrix.get(right_id, left_id), cost);
        }
    }
}