1use std::io::{self, BufRead, BufReader};
31#[cfg(feature = "zstd")]
32use std::io::{Read, Write as IoWrite};
33use std::path::Path;
34
35use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
36
37use crate::error::{DictError, Result};
38
39#[cfg(feature = "simd")]
41pub mod simd;
42
43#[cfg(feature = "simd")]
44pub use simd::SimdMatrix;
45
/// Byte length of the binary matrix header: two little-endian `u16` sizes
/// (`lsize`, then `rsize`).
const MATRIX_HEADER_SIZE: usize = 2 * std::mem::size_of::<u16>();

/// Sentinel cost returned by lookups whose ids fall outside a matrix.
pub const INVALID_CONNECTION_COST: i32 = i32::MAX;
51
/// Read-only lookup interface shared by the connection-cost matrix types.
///
/// A matrix is addressed by a `(right_id, left_id)` pair and has
/// `left_size() * right_size()` cells. Implementations may return
/// [`INVALID_CONNECTION_COST`] for ids that fall outside the matrix.
pub trait Matrix {
    /// Returns the cost stored for `(right_id, left_id)` as an `i32`.
    fn get(&self, right_id: u16, left_id: u16) -> i32;

    /// Size of the matrix's left dimension.
    fn left_size(&self) -> usize;

    /// Size of the matrix's right dimension.
    fn right_size(&self) -> usize;

    /// Total number of cells, i.e. `left_size() * right_size()`.
    fn entry_count(&self) -> usize {
        let (left, right) = (self.left_size(), self.right_size());
        left * right
    }
}
80
81#[derive(Debug, Clone)]
86pub struct DenseMatrix {
87 lsize: usize,
89 rsize: usize,
91 costs: Vec<i16>,
93}
94
95impl DenseMatrix {
96 #[must_use]
104 pub fn new(lsize: usize, rsize: usize, default_cost: i16) -> Self {
105 let costs = vec![default_cost; lsize * rsize];
106 Self {
107 lsize,
108 rsize,
109 costs,
110 }
111 }
112
113 pub fn from_vec(lsize: usize, rsize: usize, costs: Vec<i16>) -> Result<Self> {
129 let expected_size = lsize * rsize;
130 if costs.len() != expected_size {
131 return Err(DictError::Format(format!(
132 "Matrix size mismatch: expected {} entries, got {}",
133 expected_size,
134 costs.len()
135 )));
136 }
137 Ok(Self {
138 lsize,
139 rsize,
140 costs,
141 })
142 }
143
144 pub fn set(&mut self, right_id: u16, left_id: u16, cost: i16) {
152 let index = right_id as usize + self.lsize * left_id as usize;
153 if index < self.costs.len() {
154 self.costs[index] = cost;
155 }
156 }
157
158 pub fn from_def_file<P: AsRef<Path>>(path: P) -> Result<Self> {
176 let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
177 let reader = BufReader::new(file);
178 Self::from_def_reader(reader)
179 }
180
181 pub fn from_def_reader<R: BufRead>(mut reader: R) -> Result<Self> {
187 let mut first_line = String::new();
189 reader.read_line(&mut first_line).map_err(DictError::Io)?;
190
191 let sizes: Vec<usize> = first_line
192 .split_whitespace()
193 .filter_map(|s| s.parse().ok())
194 .collect();
195
196 if sizes.len() != 2 {
197 return Err(DictError::Format(
198 "Invalid matrix header: expected 'lsize rsize'".to_string(),
199 ));
200 }
201
202 let lsize = sizes[0];
203 let rsize = sizes[1];
204
205 let mut matrix = Self::new(lsize, rsize, i16::MAX);
207
208 for line in reader.lines() {
210 let line = line.map_err(DictError::Io)?;
211 let line = line.trim();
212
213 if line.is_empty() || line.starts_with('#') {
214 continue;
215 }
216
217 let parts: Vec<&str> = line.split_whitespace().collect();
218 if parts.len() != 3 {
219 continue;
220 }
221
222 let right_id: u16 = parts[0]
223 .parse()
224 .map_err(|_| DictError::Format(format!("Invalid right_id: {}", parts[0])))?;
225 let left_id: u16 = parts[1]
226 .parse()
227 .map_err(|_| DictError::Format(format!("Invalid left_id: {}", parts[1])))?;
228 let cost: i16 = parts[2]
229 .parse()
230 .map_err(|_| DictError::Format(format!("Invalid cost: {}", parts[2])))?;
231
232 matrix.set(right_id, left_id, cost);
233 }
234
235 Ok(matrix)
236 }
237
238 pub fn from_bin_file<P: AsRef<Path>>(path: P) -> Result<Self> {
252 let data = std::fs::read(path.as_ref()).map_err(DictError::Io)?;
253 Self::from_bin_bytes(&data)
254 }
255
256 pub fn from_bin_bytes(data: &[u8]) -> Result<Self> {
262 if data.len() < MATRIX_HEADER_SIZE {
263 return Err(DictError::Format(
264 "Matrix binary too short for header".to_string(),
265 ));
266 }
267
268 let mut cursor = io::Cursor::new(data);
269
270 let lsize = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)? as usize;
271 let rsize = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)? as usize;
272
273 let expected_size = lsize * rsize * 2;
274 let data_size = data.len() - MATRIX_HEADER_SIZE;
275
276 if data_size != expected_size {
277 return Err(DictError::Format(format!(
278 "Matrix data size mismatch: expected {expected_size} bytes, got {data_size}"
279 )));
280 }
281
282 let mut costs = Vec::with_capacity(lsize * rsize);
283 for _ in 0..(lsize * rsize) {
284 costs.push(cursor.read_i16::<LittleEndian>().map_err(DictError::Io)?);
285 }
286
287 Ok(Self {
288 lsize,
289 rsize,
290 costs,
291 })
292 }
293
294 #[cfg(feature = "zstd")]
300 pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<Self> {
301 let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
302 let decoder = zstd::Decoder::new(file).map_err(DictError::Io)?;
303 let mut data = Vec::new();
304 BufReader::new(decoder)
305 .read_to_end(&mut data)
306 .map_err(DictError::Io)?;
307 Self::from_bin_bytes(&data)
308 }
309
310 #[cfg(not(feature = "zstd"))]
316 pub fn from_compressed_file<P: AsRef<Path>>(_path: P) -> Result<Self> {
317 Err(DictError::Format(
318 "zstd feature is not enabled. Use uncompressed files or enable the 'zstd' feature."
319 .to_string(),
320 ))
321 }
322
323 #[must_use]
325 pub fn to_bin_bytes(&self) -> Vec<u8> {
326 let mut buf = Vec::with_capacity(MATRIX_HEADER_SIZE + self.costs.len() * 2);
327
328 #[allow(clippy::cast_possible_truncation)]
330 buf.write_u16::<LittleEndian>(self.lsize as u16).ok();
331 #[allow(clippy::cast_possible_truncation)]
332 buf.write_u16::<LittleEndian>(self.rsize as u16).ok();
333
334 for &cost in &self.costs {
336 buf.write_i16::<LittleEndian>(cost).ok();
337 }
338
339 buf
340 }
341
342 pub fn to_bin_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
348 let data = self.to_bin_bytes();
349 std::fs::write(path.as_ref(), data).map_err(DictError::Io)
350 }
351
352 #[cfg(feature = "zstd")]
358 pub fn to_compressed_file<P: AsRef<Path>>(&self, path: P, level: i32) -> Result<()> {
359 let data = self.to_bin_bytes();
360 let file = std::fs::File::create(path.as_ref()).map_err(DictError::Io)?;
361 let mut encoder = zstd::Encoder::new(file, level).map_err(DictError::Io)?;
362 encoder.write_all(&data).map_err(DictError::Io)?;
363 encoder.finish().map_err(DictError::Io)?;
364 Ok(())
365 }
366
367 #[cfg(not(feature = "zstd"))]
373 pub fn to_compressed_file<P: AsRef<Path>>(&self, _path: P, _level: i32) -> Result<()> {
374 Err(DictError::Format(
375 "zstd feature is not enabled. Use uncompressed files or enable the 'zstd' feature."
376 .to_string(),
377 ))
378 }
379
380 #[must_use]
382 pub fn costs(&self) -> &[i16] {
383 &self.costs
384 }
385
386 #[must_use]
388 pub fn memory_size(&self) -> usize {
389 std::mem::size_of::<Self>() + self.costs.len() * std::mem::size_of::<i16>()
390 }
391}
392
393impl Matrix for DenseMatrix {
394 #[inline(always)]
395 fn get(&self, right_id: u16, left_id: u16) -> i32 {
396 let index = right_id as usize + self.lsize * left_id as usize;
397 if index < self.costs.len() {
398 i32::from(self.costs[index])
399 } else {
400 INVALID_CONNECTION_COST
401 }
402 }
403
404 fn left_size(&self) -> usize {
405 self.lsize
406 }
407
408 fn right_size(&self) -> usize {
409 self.rsize
410 }
411}
412
413pub struct MmapMatrix {
423 lsize: usize,
425 rsize: usize,
427 mmap: memmap2::Mmap,
429}
430
431impl MmapMatrix {
432 #[allow(unsafe_code)]
443 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
444 let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
445
446 let mmap = unsafe { memmap2::Mmap::map(&file).map_err(DictError::Io)? };
449
450 if mmap.len() < MATRIX_HEADER_SIZE {
451 return Err(DictError::Format(
452 "Matrix file too short for header".to_string(),
453 ));
454 }
455
456 let mut cursor = io::Cursor::new(&mmap[..MATRIX_HEADER_SIZE]);
458 let lsize = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)? as usize;
459 let rsize = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)? as usize;
460
461 let expected_size = MATRIX_HEADER_SIZE + lsize * rsize * 2;
462 if mmap.len() != expected_size {
463 return Err(DictError::Format(format!(
464 "Matrix file size mismatch: expected {} bytes, got {}",
465 expected_size,
466 mmap.len()
467 )));
468 }
469
470 Ok(Self { lsize, rsize, mmap })
471 }
472
473 pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<DenseMatrix> {
481 DenseMatrix::from_compressed_file(path)
483 }
484
485 #[inline]
487 const fn offset(&self, right_id: u16, left_id: u16) -> usize {
488 MATRIX_HEADER_SIZE + (right_id as usize + self.lsize * left_id as usize) * 2
489 }
490}
491
492impl Matrix for MmapMatrix {
493 #[inline(always)]
494 fn get(&self, right_id: u16, left_id: u16) -> i32 {
495 let offset = self.offset(right_id, left_id);
496 if offset + 2 <= self.mmap.len() {
497 let bytes = [self.mmap[offset], self.mmap[offset + 1]];
498 i32::from(i16::from_le_bytes(bytes))
499 } else {
500 INVALID_CONNECTION_COST
501 }
502 }
503
504 fn left_size(&self) -> usize {
505 self.lsize
506 }
507
508 fn right_size(&self) -> usize {
509 self.rsize
510 }
511}
512
513#[derive(Debug, Clone)]
518pub struct SparseMatrix {
519 lsize: usize,
521 rsize: usize,
523 default_cost: i16,
525 entries: std::collections::HashMap<usize, i16>,
527}
528
529impl SparseMatrix {
530 #[must_use]
532 pub fn new(lsize: usize, rsize: usize, default_cost: i16) -> Self {
533 Self {
534 lsize,
535 rsize,
536 default_cost,
537 entries: std::collections::HashMap::new(),
538 }
539 }
540
541 pub fn set(&mut self, right_id: u16, left_id: u16, cost: i16) {
543 let index = right_id as usize + self.lsize * left_id as usize;
544 if cost == self.default_cost {
545 self.entries.remove(&index);
546 } else {
547 self.entries.insert(index, cost);
548 }
549 }
550
551 #[must_use]
553 pub fn from_dense(dense: &DenseMatrix, default_cost: i16) -> Self {
554 let mut sparse = Self::new(dense.lsize, dense.rsize, default_cost);
555 for (index, &cost) in dense.costs.iter().enumerate() {
556 if cost != default_cost {
557 sparse.entries.insert(index, cost);
558 }
559 }
560 sparse
561 }
562
563 #[must_use]
565 pub fn to_dense(&self) -> DenseMatrix {
566 let mut costs = vec![self.default_cost; self.lsize * self.rsize];
567 for (&index, &cost) in &self.entries {
568 if index < costs.len() {
569 costs[index] = cost;
570 }
571 }
572 DenseMatrix {
573 lsize: self.lsize,
574 rsize: self.rsize,
575 costs,
576 }
577 }
578
579 #[must_use]
581 pub fn entry_count_stored(&self) -> usize {
582 self.entries.len()
583 }
584
585 #[must_use]
587 pub fn sparsity(&self) -> f64 {
588 let total = self.lsize * self.rsize;
589 if total == 0 {
590 return 0.0;
591 }
592 #[allow(clippy::cast_precision_loss)]
593 let entries_len = self.entries.len() as f64;
594 #[allow(clippy::cast_precision_loss)]
595 let total_f64 = total as f64;
596 1.0 - (entries_len / total_f64)
597 }
598
599 #[must_use]
601 pub fn memory_size(&self) -> usize {
602 std::mem::size_of::<Self>()
603 + self.entries.capacity() * (std::mem::size_of::<usize>() + std::mem::size_of::<i16>())
604 }
605}
606
607impl Matrix for SparseMatrix {
608 #[inline(always)]
609 fn get(&self, right_id: u16, left_id: u16) -> i32 {
610 let index = right_id as usize + self.lsize * left_id as usize;
611 self.entries
612 .get(&index)
613 .map_or_else(|| i32::from(self.default_cost), |&c| i32::from(c))
614 }
615
616 fn left_size(&self) -> usize {
617 self.lsize
618 }
619
620 fn right_size(&self) -> usize {
621 self.rsize
622 }
623}
624
625pub struct MatrixLoader;
629
630impl MatrixLoader {
631 pub fn load<P: AsRef<Path>>(path: P) -> Result<DenseMatrix> {
642 let path = path.as_ref();
643 let path_str = path.to_string_lossy();
644
645 if path_str.ends_with(".def") {
646 DenseMatrix::from_def_file(path)
647 } else if path_str.ends_with(".zst") || path_str.ends_with(".bin.zst") {
648 DenseMatrix::from_compressed_file(path)
649 } else if path_str.ends_with(".bin") {
650 DenseMatrix::from_bin_file(path)
651 } else {
652 DenseMatrix::from_bin_file(path).or_else(|_| DenseMatrix::from_def_file(path))
654 }
655 }
656
657 pub fn load_mmap<P: AsRef<Path>>(path: P) -> Result<MmapMatrix> {
663 MmapMatrix::from_file(path)
664 }
665}
666
667pub enum ConnectionMatrix {
671 Dense(DenseMatrix),
673 Sparse(SparseMatrix),
675 Mmap(MmapMatrix),
677}
678
679impl ConnectionMatrix {
680 pub fn from_def_file<P: AsRef<Path>>(path: P) -> Result<Self> {
686 Ok(Self::Dense(DenseMatrix::from_def_file(path)?))
687 }
688
689 pub fn from_bin_file<P: AsRef<Path>>(path: P) -> Result<Self> {
695 Ok(Self::Dense(DenseMatrix::from_bin_file(path)?))
696 }
697
698 pub fn from_mmap_file<P: AsRef<Path>>(path: P) -> Result<Self> {
704 Ok(Self::Mmap(MmapMatrix::from_file(path)?))
705 }
706
707 pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<Self> {
713 Ok(Self::Dense(DenseMatrix::from_compressed_file(path)?))
714 }
715
716 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
722 Ok(Self::Dense(MatrixLoader::load(path)?))
723 }
724}
725
726impl Matrix for ConnectionMatrix {
727 #[inline(always)]
728 fn get(&self, right_id: u16, left_id: u16) -> i32 {
729 match self {
730 Self::Dense(m) => m.get(right_id, left_id),
731 Self::Sparse(m) => m.get(right_id, left_id),
732 Self::Mmap(m) => m.get(right_id, left_id),
733 }
734 }
735
736 fn left_size(&self) -> usize {
737 match self {
738 Self::Dense(m) => m.left_size(),
739 Self::Sparse(m) => m.left_size(),
740 Self::Mmap(m) => m.left_size(),
741 }
742 }
743
744 fn right_size(&self) -> usize {
745 match self {
746 Self::Dense(m) => m.right_size(),
747 Self::Sparse(m) => m.right_size(),
748 Self::Mmap(m) => m.right_size(),
749 }
750 }
751}
752
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::cast_lossless)]
mod tests {
    use super::*;

    #[test]
    fn test_dense_matrix_new() {
        let matrix = DenseMatrix::new(10, 10, 0);
        assert_eq!(matrix.left_size(), 10);
        assert_eq!(matrix.right_size(), 10);
        assert_eq!(matrix.entry_count(), 100);
        assert_eq!(matrix.get(0, 0), 0);
    }

    #[test]
    fn test_dense_matrix_set_get() {
        let mut matrix = DenseMatrix::new(10, 10, 0);
        matrix.set(3, 5, 100);
        assert_eq!(matrix.get(3, 5), 100);
        assert_eq!(matrix.get(5, 3), 0);
    }

    #[test]
    fn test_dense_matrix_from_vec() {
        // Layout is `right_id + lsize * left_id` with lsize = 2.
        let costs = vec![1, 2, 3, 4, 5, 6];
        let matrix = DenseMatrix::from_vec(2, 3, costs).unwrap();
        assert_eq!(matrix.get(0, 0), 1);
        assert_eq!(matrix.get(1, 0), 2);
        assert_eq!(matrix.get(0, 1), 3);
        assert_eq!(matrix.get(1, 1), 4);
        assert_eq!(matrix.get(0, 2), 5);
        assert_eq!(matrix.get(1, 2), 6);
    }

    #[test]
    fn test_dense_matrix_from_vec_size_mismatch() {
        let costs = vec![1, 2, 3];
        let result = DenseMatrix::from_vec(2, 3, costs);
        assert!(result.is_err());
    }

    #[test]
    fn test_dense_matrix_boundary() {
        let matrix = DenseMatrix::new(10, 10, 0);
        assert_eq!(matrix.get(100, 100), INVALID_CONNECTION_COST);
    }

    #[test]
    fn test_dense_matrix_set_out_of_range_is_ignored() {
        let mut matrix = DenseMatrix::new(10, 10, 0);
        matrix.set(100, 100, 42);
        assert_eq!(matrix.get(100, 100), INVALID_CONNECTION_COST);
        assert!(matrix.costs().iter().all(|&cost| cost == 0));
    }

    #[test]
    fn test_dense_matrix_def_reader() {
        let data = "3 3\n0 0 100\n1 1 200\n2 2 300\n";
        let reader = std::io::Cursor::new(data);
        let matrix = DenseMatrix::from_def_reader(reader).unwrap();

        assert_eq!(matrix.left_size(), 3);
        assert_eq!(matrix.right_size(), 3);
        assert_eq!(matrix.get(0, 0), 100);
        assert_eq!(matrix.get(1, 1), 200);
        assert_eq!(matrix.get(2, 2), 300);
        // Cells absent from the file keep the i16::MAX fill value.
        assert_eq!(matrix.get(0, 1), i16::MAX as i32);
    }

    #[test]
    fn test_def_reader_rejects_missing_header() {
        assert!(DenseMatrix::from_def_reader(std::io::Cursor::new("")).is_err());
        assert!(DenseMatrix::from_def_reader(std::io::Cursor::new("3\n0 0 1\n")).is_err());
    }

    #[test]
    fn test_def_reader_rejects_invalid_cost() {
        let data = "2 2\n0 0 notanumber\n";
        assert!(DenseMatrix::from_def_reader(std::io::Cursor::new(data)).is_err());
    }

    #[test]
    fn test_dense_matrix_binary_roundtrip() {
        let mut matrix = DenseMatrix::new(5, 5, 0);
        matrix.set(0, 0, 100);
        matrix.set(1, 2, -500);
        matrix.set(4, 4, 32767);

        let bytes = matrix.to_bin_bytes();
        let loaded = DenseMatrix::from_bin_bytes(&bytes).unwrap();

        assert_eq!(loaded.left_size(), 5);
        assert_eq!(loaded.right_size(), 5);
        assert_eq!(loaded.get(0, 0), 100);
        assert_eq!(loaded.get(1, 2), -500);
        assert_eq!(loaded.get(4, 4), 32767);
    }

    #[test]
    fn test_from_bin_bytes_rejects_truncated_input() {
        assert!(DenseMatrix::from_bin_bytes(&[]).is_err());
        assert!(DenseMatrix::from_bin_bytes(&[2, 0, 2]).is_err());
        // Header claims 2x2 (8 payload bytes) but only 2 follow.
        assert!(DenseMatrix::from_bin_bytes(&[2, 0, 2, 0, 1, 0]).is_err());
    }

    #[test]
    fn test_sparse_matrix() {
        let mut sparse = SparseMatrix::new(100, 100, 0);
        sparse.set(10, 20, 500);
        sparse.set(50, 50, -100);

        assert_eq!(sparse.get(10, 20), 500);
        assert_eq!(sparse.get(50, 50), -100);
        assert_eq!(sparse.get(0, 0), 0);
        assert_eq!(sparse.entry_count_stored(), 2);
        assert!(sparse.sparsity() > 0.99);
    }

    #[test]
    fn test_sparse_dense_conversion() {
        let mut dense = DenseMatrix::new(10, 10, 0);
        dense.set(3, 3, 100);
        dense.set(5, 7, 200);

        let sparse = SparseMatrix::from_dense(&dense, 0);
        assert_eq!(sparse.entry_count_stored(), 2);
        assert_eq!(sparse.get(3, 3), 100);
        assert_eq!(sparse.get(5, 7), 200);

        let converted = sparse.to_dense();
        assert_eq!(converted.get(3, 3), 100);
        assert_eq!(converted.get(5, 7), 200);
        assert_eq!(converted.get(0, 0), 0);
    }

    #[test]
    fn test_memory_size() {
        let dense = DenseMatrix::new(100, 100, 0);
        let mem_size = dense.memory_size();
        // 100 * 100 cells of 2 bytes each.
        assert!(mem_size >= 20000);

        let sparse = SparseMatrix::new(100, 100, 0);
        let sparse_size = sparse.memory_size();
        assert!(sparse_size < mem_size);
    }

    #[test]
    fn test_connection_matrix_enum() {
        let dense = DenseMatrix::new(5, 5, 100);
        let matrix = ConnectionMatrix::Dense(dense);

        assert_eq!(matrix.left_size(), 5);
        assert_eq!(matrix.right_size(), 5);
        assert_eq!(matrix.get(0, 0), 100);
    }

    #[test]
    fn test_large_matrix() {
        let matrix = DenseMatrix::new(178, 178, 0);
        assert_eq!(matrix.entry_count(), 178 * 178);
        assert_eq!(
            matrix.memory_size(),
            std::mem::size_of::<DenseMatrix>() + 178 * 178 * 2
        );
    }

    #[test]
    fn test_def_with_comments_and_empty_lines() {
        let data = "2 2\n# This is a comment\n\n0 0 10\n0 1 20\n\n1 0 30\n1 1 40\n";
        let reader = std::io::Cursor::new(data);
        let matrix = DenseMatrix::from_def_reader(reader).unwrap();

        assert_eq!(matrix.get(0, 0), 10);
        assert_eq!(matrix.get(0, 1), 20);
        assert_eq!(matrix.get(1, 0), 30);
        assert_eq!(matrix.get(1, 1), 40);
    }
}