Skip to main content

mecab_ko_dict/matrix/
mod.rs

1//! # 연접 비용 행렬 (Connection Cost Matrix)
2//!
3//! 형태소 간 연접 비용을 저장하고 조회하는 모듈입니다.
4//!
5//! ## 포맷 지원
6//!
7//! - **텍스트 포맷** (`matrix.def`): `MeCab` 표준 형식
8//! - **바이너리 포맷** (`matrix.bin`): 고정 크기 i16 배열
9//! - **압축 포맷** (`matrix.bin.zst`): Zstd 압축 바이너리
10//!
11//! ## 예제
12//!
13//! ```rust,ignore
14//! use mecab_ko_dict::matrix::ConnectionMatrix;
15//!
16//! // 텍스트 파일에서 로드
17//! let matrix = ConnectionMatrix::from_def_file("matrix.def").unwrap();
18//!
19//! // 연접 비용 조회 (left_id=0, right_id=0)
20//! let cost = matrix.get(0, 0);
21//! ```
22//!
23//! ## 행렬 구조
24//!
25//! 연접 비용 행렬은 `lsize x rsize` 크기의 2차원 배열입니다.
26//! - `lsize`: 좌문맥 ID 개수
27//! - `rsize`: 우문맥 ID 개수
28//! - 접근: `matrix[right_id + lsize * left_id]`
29
30use std::io::{self, BufRead, BufReader};
31#[cfg(feature = "zstd")]
32use std::io::{Read, Write as IoWrite};
33use std::path::Path;
34
35use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
36
37use crate::error::{DictError, Result};
38
39// SIMD 최적화 모듈
40#[cfg(feature = "simd")]
41pub mod simd;
42
43#[cfg(feature = "simd")]
44pub use simd::SimdMatrix;
45
46/// 연접 비용 행렬 헤더 크기 (바이트)
47const MATRIX_HEADER_SIZE: usize = 4;
48
49/// 기본 비용 (연결 불가능한 경우)
50pub const INVALID_CONNECTION_COST: i32 = i32::MAX;
51
52/// 연접 비용 행렬 인터페이스
53///
54/// 형태소 간 연접 비용을 조회하는 인터페이스입니다.
55/// mecab-ko-core의 `ConnectionCost` trait과 호환됩니다.
56pub trait Matrix {
57    /// 연접 비용 조회
58    ///
59    /// # Arguments
60    ///
61    /// * `right_id` - 이전 노드의 우문맥 ID (right context ID)
62    /// * `left_id` - 현재 노드의 좌문맥 ID (left context ID)
63    ///
64    /// # Returns
65    ///
66    /// 연접 비용 (i32). 연결 불가능한 경우 `INVALID_CONNECTION_COST` 반환
67    fn get(&self, right_id: u16, left_id: u16) -> i32;
68
69    /// 좌문맥 크기
70    fn left_size(&self) -> usize;
71
72    /// 우문맥 크기
73    fn right_size(&self) -> usize;
74
75    /// 전체 엔트리 수
76    fn entry_count(&self) -> usize {
77        self.left_size() * self.right_size()
78    }
79}
80
81/// 밀집 연접 비용 행렬 (Dense Matrix)
82///
83/// 모든 연접 비용을 메모리에 저장하는 구현입니다.
84/// 희소 행렬이 아닌 경우에 적합합니다.
85#[derive(Debug, Clone)]
86pub struct DenseMatrix {
87    /// 좌문맥 크기
88    lsize: usize,
89    /// 우문맥 크기
90    rsize: usize,
91    /// 비용 배열 (row-major: costs[`right_id` + lsize * `left_id`])
92    costs: Vec<i16>,
93}
94
95impl DenseMatrix {
96    /// 새로운 밀집 행렬 생성 (모든 값을 기본값으로 초기화)
97    ///
98    /// # Arguments
99    ///
100    /// * `lsize` - 좌문맥 크기
101    /// * `rsize` - 우문맥 크기
102    /// * `default_cost` - 기본 비용 값
103    #[must_use]
104    pub fn new(lsize: usize, rsize: usize, default_cost: i16) -> Self {
105        let costs = vec![default_cost; lsize * rsize];
106        Self {
107            lsize,
108            rsize,
109            costs,
110        }
111    }
112
113    /// 기존 비용 벡터로 밀집 행렬 생성
114    ///
115    /// # Arguments
116    ///
117    /// * `lsize` - 좌문맥 크기
118    /// * `rsize` - 우문맥 크기
119    /// * `costs` - 비용 배열
120    ///
121    /// # Returns
122    ///
123    /// 성공 시 `DenseMatrix`, 크기 불일치 시 에러
124    ///
125    /// # Errors
126    ///
127    /// 비용 배열의 길이가 `lsize * rsize`와 일치하지 않으면 에러를 반환합니다.
128    pub fn from_vec(lsize: usize, rsize: usize, costs: Vec<i16>) -> Result<Self> {
129        let expected_size = lsize * rsize;
130        if costs.len() != expected_size {
131            return Err(DictError::Format(format!(
132                "Matrix size mismatch: expected {} entries, got {}",
133                expected_size,
134                costs.len()
135            )));
136        }
137        Ok(Self {
138            lsize,
139            rsize,
140            costs,
141        })
142    }
143
144    /// 비용 설정
145    ///
146    /// # Arguments
147    ///
148    /// * `right_id` - 우문맥 ID
149    /// * `left_id` - 좌문맥 ID
150    /// * `cost` - 비용 값
151    pub fn set(&mut self, right_id: u16, left_id: u16, cost: i16) {
152        let index = right_id as usize + self.lsize * left_id as usize;
153        if index < self.costs.len() {
154            self.costs[index] = cost;
155        }
156    }
157
158    /// 텍스트 파일(matrix.def)에서 로드
159    ///
160    /// # Format
161    ///
162    /// ```text
163    /// <lsize> <rsize>
164    /// <right_id> <left_id> <cost>
165    /// ...
166    /// ```
167    ///
168    /// # Arguments
169    ///
170    /// * `path` - matrix.def 파일 경로
171    ///
172    /// # Errors
173    ///
174    /// 파일을 읽을 수 없거나 형식이 잘못된 경우 에러를 반환합니다.
175    pub fn from_def_file<P: AsRef<Path>>(path: P) -> Result<Self> {
176        let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
177        let reader = BufReader::new(file);
178        Self::from_def_reader(reader)
179    }
180
181    /// 텍스트 리더에서 로드
182    ///
183    /// # Errors
184    ///
185    /// 리더에서 데이터를 읽을 수 없거나 형식이 잘못된 경우 에러를 반환합니다.
186    pub fn from_def_reader<R: BufRead>(mut reader: R) -> Result<Self> {
187        // 첫 줄: 크기 정보
188        let mut first_line = String::new();
189        reader.read_line(&mut first_line).map_err(DictError::Io)?;
190
191        let sizes: Vec<usize> = first_line
192            .split_whitespace()
193            .filter_map(|s| s.parse().ok())
194            .collect();
195
196        if sizes.len() != 2 {
197            return Err(DictError::Format(
198                "Invalid matrix header: expected 'lsize rsize'".to_string(),
199            ));
200        }
201
202        let lsize = sizes[0];
203        let rsize = sizes[1];
204
205        // 기본값으로 초기화 (i16::MAX는 연결 불가능을 의미)
206        let mut matrix = Self::new(lsize, rsize, i16::MAX);
207
208        // 나머지 줄: 연접 비용
209        for line in reader.lines() {
210            let line = line.map_err(DictError::Io)?;
211            let line = line.trim();
212
213            if line.is_empty() || line.starts_with('#') {
214                continue;
215            }
216
217            let parts: Vec<&str> = line.split_whitespace().collect();
218            if parts.len() != 3 {
219                continue;
220            }
221
222            let right_id: u16 = parts[0]
223                .parse()
224                .map_err(|_| DictError::Format(format!("Invalid right_id: {}", parts[0])))?;
225            let left_id: u16 = parts[1]
226                .parse()
227                .map_err(|_| DictError::Format(format!("Invalid left_id: {}", parts[1])))?;
228            let cost: i16 = parts[2]
229                .parse()
230                .map_err(|_| DictError::Format(format!("Invalid cost: {}", parts[2])))?;
231
232            matrix.set(right_id, left_id, cost);
233        }
234
235        Ok(matrix)
236    }
237
238    /// 바이너리 파일(matrix.bin)에서 로드
239    ///
240    /// # Format
241    ///
242    /// ```text
243    /// [2 bytes] lsize (little-endian u16)
244    /// [2 bytes] rsize (little-endian u16)
245    /// [lsize * rsize * 2 bytes] costs (little-endian i16 array)
246    /// ```
247    ///
248    /// # Errors
249    ///
250    /// 파일을 읽을 수 없거나 형식이 잘못된 경우 에러를 반환합니다.
251    pub fn from_bin_file<P: AsRef<Path>>(path: P) -> Result<Self> {
252        let data = std::fs::read(path.as_ref()).map_err(DictError::Io)?;
253        Self::from_bin_bytes(&data)
254    }
255
256    /// 바이너리 바이트에서 로드
257    ///
258    /// # Errors
259    ///
260    /// 데이터가 유효한 바이너리 형식이 아닌 경우 에러를 반환합니다.
261    pub fn from_bin_bytes(data: &[u8]) -> Result<Self> {
262        if data.len() < MATRIX_HEADER_SIZE {
263            return Err(DictError::Format(
264                "Matrix binary too short for header".to_string(),
265            ));
266        }
267
268        let mut cursor = io::Cursor::new(data);
269
270        let lsize = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)? as usize;
271        let rsize = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)? as usize;
272
273        let expected_size = lsize * rsize * 2;
274        let data_size = data.len() - MATRIX_HEADER_SIZE;
275
276        if data_size != expected_size {
277            return Err(DictError::Format(format!(
278                "Matrix data size mismatch: expected {expected_size} bytes, got {data_size}"
279            )));
280        }
281
282        let mut costs = Vec::with_capacity(lsize * rsize);
283        for _ in 0..(lsize * rsize) {
284            costs.push(cursor.read_i16::<LittleEndian>().map_err(DictError::Io)?);
285        }
286
287        Ok(Self {
288            lsize,
289            rsize,
290            costs,
291        })
292    }
293
294    /// 압축된 바이너리 파일(matrix.bin.zst)에서 로드
295    ///
296    /// # Errors
297    ///
298    /// 파일을 읽거나 압축 해제할 수 없는 경우 에러를 반환합니다.
299    #[cfg(feature = "zstd")]
300    pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<Self> {
301        let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
302        let decoder = zstd::Decoder::new(file).map_err(DictError::Io)?;
303        let mut data = Vec::new();
304        BufReader::new(decoder)
305            .read_to_end(&mut data)
306            .map_err(DictError::Io)?;
307        Self::from_bin_bytes(&data)
308    }
309
310    /// 압축된 바이너리 파일에서 로드 (zstd feature 비활성화 시)
311    ///
312    /// # Errors
313    ///
314    /// zstd feature가 비활성화된 경우 항상 에러를 반환합니다.
315    #[cfg(not(feature = "zstd"))]
316    pub fn from_compressed_file<P: AsRef<Path>>(_path: P) -> Result<Self> {
317        Err(DictError::Format(
318            "zstd feature is not enabled. Use uncompressed files or enable the 'zstd' feature."
319                .to_string(),
320        ))
321    }
322
323    /// 바이너리 형식으로 저장
324    #[must_use]
325    pub fn to_bin_bytes(&self) -> Vec<u8> {
326        let mut buf = Vec::with_capacity(MATRIX_HEADER_SIZE + self.costs.len() * 2);
327
328        // 헤더
329        #[allow(clippy::cast_possible_truncation)]
330        buf.write_u16::<LittleEndian>(self.lsize as u16).ok();
331        #[allow(clippy::cast_possible_truncation)]
332        buf.write_u16::<LittleEndian>(self.rsize as u16).ok();
333
334        // 데이터
335        for &cost in &self.costs {
336            buf.write_i16::<LittleEndian>(cost).ok();
337        }
338
339        buf
340    }
341
342    /// 바이너리 파일로 저장
343    ///
344    /// # Errors
345    ///
346    /// 파일을 쓸 수 없는 경우 에러를 반환합니다.
347    pub fn to_bin_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
348        let data = self.to_bin_bytes();
349        std::fs::write(path.as_ref(), data).map_err(DictError::Io)
350    }
351
352    /// 압축된 바이너리 파일로 저장
353    ///
354    /// # Errors
355    ///
356    /// 파일을 쓰거나 압축할 수 없는 경우 에러를 반환합니다.
357    #[cfg(feature = "zstd")]
358    pub fn to_compressed_file<P: AsRef<Path>>(&self, path: P, level: i32) -> Result<()> {
359        let data = self.to_bin_bytes();
360        let file = std::fs::File::create(path.as_ref()).map_err(DictError::Io)?;
361        let mut encoder = zstd::Encoder::new(file, level).map_err(DictError::Io)?;
362        encoder.write_all(&data).map_err(DictError::Io)?;
363        encoder.finish().map_err(DictError::Io)?;
364        Ok(())
365    }
366
367    /// 압축된 바이너리 파일로 저장 (zstd feature 비활성화 시)
368    ///
369    /// # Errors
370    ///
371    /// zstd feature가 비활성화된 경우 항상 에러를 반환합니다.
372    #[cfg(not(feature = "zstd"))]
373    pub fn to_compressed_file<P: AsRef<Path>>(&self, _path: P, _level: i32) -> Result<()> {
374        Err(DictError::Format(
375            "zstd feature is not enabled. Use uncompressed files or enable the 'zstd' feature."
376                .to_string(),
377        ))
378    }
379
380    /// 원본 비용 배열 참조
381    #[must_use]
382    pub fn costs(&self) -> &[i16] {
383        &self.costs
384    }
385
386    /// 메모리 사용량 (바이트)
387    #[must_use]
388    pub fn memory_size(&self) -> usize {
389        std::mem::size_of::<Self>() + self.costs.len() * std::mem::size_of::<i16>()
390    }
391}
392
393impl Matrix for DenseMatrix {
394    #[inline(always)]
395    fn get(&self, right_id: u16, left_id: u16) -> i32 {
396        let index = right_id as usize + self.lsize * left_id as usize;
397        if index < self.costs.len() {
398            i32::from(self.costs[index])
399        } else {
400            INVALID_CONNECTION_COST
401        }
402    }
403
404    fn left_size(&self) -> usize {
405        self.lsize
406    }
407
408    fn right_size(&self) -> usize {
409        self.rsize
410    }
411}
412
413/// 메모리 맵 연접 비용 행렬 (Memory-Mapped Matrix)
414///
415/// 대용량 행렬을 메모리 맵으로 로드하여 효율적으로 접근합니다.
416/// 프로세스 간 메모리 공유가 가능합니다.
417///
418/// # Safety
419///
420/// 이 구조체는 메모리 맵을 사용하므로 내부적으로 unsafe 코드가 필요합니다.
421/// 파일이 외부에서 수정되지 않아야 합니다.
422pub struct MmapMatrix {
423    /// 좌문맥 크기
424    lsize: usize,
425    /// 우문맥 크기
426    rsize: usize,
427    /// 메모리 맵
428    mmap: memmap2::Mmap,
429}
430
431impl MmapMatrix {
432    /// 바이너리 파일에서 메모리 맵으로 로드
433    ///
434    /// # Safety
435    ///
436    /// 파일이 외부에서 수정되지 않아야 합니다.
437    /// memmap2는 파일을 메모리에 매핑하며, 이는 본질적으로 unsafe입니다.
438    ///
439    /// # Errors
440    ///
441    /// 파일을 읽거나 메모리 맵을 생성할 수 없는 경우 에러를 반환합니다.
442    #[allow(unsafe_code)]
443    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
444        let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
445
446        // SAFETY: 파일이 열려 있는 동안 수정되지 않는다고 가정
447        // memmap2::Mmap::map은 파일 내용이 변경되지 않을 때 안전합니다.
448        let mmap = unsafe { memmap2::Mmap::map(&file).map_err(DictError::Io)? };
449
450        if mmap.len() < MATRIX_HEADER_SIZE {
451            return Err(DictError::Format(
452                "Matrix file too short for header".to_string(),
453            ));
454        }
455
456        // 헤더 읽기
457        let mut cursor = io::Cursor::new(&mmap[..MATRIX_HEADER_SIZE]);
458        let lsize = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)? as usize;
459        let rsize = cursor.read_u16::<LittleEndian>().map_err(DictError::Io)? as usize;
460
461        let expected_size = MATRIX_HEADER_SIZE + lsize * rsize * 2;
462        if mmap.len() != expected_size {
463            return Err(DictError::Format(format!(
464                "Matrix file size mismatch: expected {} bytes, got {}",
465                expected_size,
466                mmap.len()
467            )));
468        }
469
470        Ok(Self { lsize, rsize, mmap })
471    }
472
473    /// 압축된 파일에서 로드 (메모리에 전체 압축 해제)
474    ///
475    /// 압축 파일은 메모리 맵이 아닌 전체 압축 해제 후 로드됩니다.
476    ///
477    /// # Errors
478    ///
479    /// 파일을 읽거나 압축 해제할 수 없는 경우 에러를 반환합니다.
480    pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<DenseMatrix> {
481        // 압축 파일은 DenseMatrix로 로드
482        DenseMatrix::from_compressed_file(path)
483    }
484
485    /// 비용 배열의 오프셋 계산
486    #[inline]
487    const fn offset(&self, right_id: u16, left_id: u16) -> usize {
488        MATRIX_HEADER_SIZE + (right_id as usize + self.lsize * left_id as usize) * 2
489    }
490}
491
492impl Matrix for MmapMatrix {
493    #[inline(always)]
494    fn get(&self, right_id: u16, left_id: u16) -> i32 {
495        let offset = self.offset(right_id, left_id);
496        if offset + 2 <= self.mmap.len() {
497            let bytes = [self.mmap[offset], self.mmap[offset + 1]];
498            i32::from(i16::from_le_bytes(bytes))
499        } else {
500            INVALID_CONNECTION_COST
501        }
502    }
503
504    fn left_size(&self) -> usize {
505        self.lsize
506    }
507
508    fn right_size(&self) -> usize {
509        self.rsize
510    }
511}
512
513/// 희소 연접 비용 행렬 (Sparse Matrix)
514///
515/// 희소 행렬을 효율적으로 저장하는 구현입니다.
516/// 대부분의 값이 기본값인 경우 메모리를 절약합니다.
517#[derive(Debug, Clone)]
518pub struct SparseMatrix {
519    /// 좌문맥 크기
520    lsize: usize,
521    /// 우문맥 크기
522    rsize: usize,
523    /// 기본 비용 (희소 엔트리에 없는 경우)
524    default_cost: i16,
525    /// 희소 엔트리 (key: `right_id` + lsize * `left_id`, value: cost)
526    entries: std::collections::HashMap<usize, i16>,
527}
528
529impl SparseMatrix {
530    /// 새로운 희소 행렬 생성
531    #[must_use]
532    pub fn new(lsize: usize, rsize: usize, default_cost: i16) -> Self {
533        Self {
534            lsize,
535            rsize,
536            default_cost,
537            entries: std::collections::HashMap::new(),
538        }
539    }
540
541    /// 비용 설정
542    pub fn set(&mut self, right_id: u16, left_id: u16, cost: i16) {
543        let index = right_id as usize + self.lsize * left_id as usize;
544        if cost == self.default_cost {
545            self.entries.remove(&index);
546        } else {
547            self.entries.insert(index, cost);
548        }
549    }
550
551    /// `DenseMatrix에서` 변환 (기본값과 다른 엔트리만 저장)
552    #[must_use]
553    pub fn from_dense(dense: &DenseMatrix, default_cost: i16) -> Self {
554        let mut sparse = Self::new(dense.lsize, dense.rsize, default_cost);
555        for (index, &cost) in dense.costs.iter().enumerate() {
556            if cost != default_cost {
557                sparse.entries.insert(index, cost);
558            }
559        }
560        sparse
561    }
562
563    /// `DenseMatrix로` 변환
564    #[must_use]
565    pub fn to_dense(&self) -> DenseMatrix {
566        let mut costs = vec![self.default_cost; self.lsize * self.rsize];
567        for (&index, &cost) in &self.entries {
568            if index < costs.len() {
569                costs[index] = cost;
570            }
571        }
572        DenseMatrix {
573            lsize: self.lsize,
574            rsize: self.rsize,
575            costs,
576        }
577    }
578
579    /// 엔트리 수
580    #[must_use]
581    pub fn entry_count_stored(&self) -> usize {
582        self.entries.len()
583    }
584
585    /// 희소도 (0.0 ~ 1.0, 1.0 = 완전 희소)
586    #[must_use]
587    pub fn sparsity(&self) -> f64 {
588        let total = self.lsize * self.rsize;
589        if total == 0 {
590            return 0.0;
591        }
592        #[allow(clippy::cast_precision_loss)]
593        let entries_len = self.entries.len() as f64;
594        #[allow(clippy::cast_precision_loss)]
595        let total_f64 = total as f64;
596        1.0 - (entries_len / total_f64)
597    }
598
599    /// 메모리 사용량 (바이트, 대략적)
600    #[must_use]
601    pub fn memory_size(&self) -> usize {
602        std::mem::size_of::<Self>()
603            + self.entries.capacity() * (std::mem::size_of::<usize>() + std::mem::size_of::<i16>())
604    }
605}
606
607impl Matrix for SparseMatrix {
608    #[inline(always)]
609    fn get(&self, right_id: u16, left_id: u16) -> i32 {
610        let index = right_id as usize + self.lsize * left_id as usize;
611        self.entries
612            .get(&index)
613            .map_or_else(|| i32::from(self.default_cost), |&c| i32::from(c))
614    }
615
616    fn left_size(&self) -> usize {
617        self.lsize
618    }
619
620    fn right_size(&self) -> usize {
621        self.rsize
622    }
623}
624
625/// 연접 비용 행렬 로더
626///
627/// 다양한 포맷에서 연접 비용 행렬을 로드합니다.
628pub struct MatrixLoader;
629
630impl MatrixLoader {
631    /// 자동 포맷 감지 로드
632    ///
633    /// 파일 확장자에 따라 적절한 로더를 선택합니다.
634    /// - `.def`: 텍스트 포맷
635    /// - `.bin`: 바이너리 포맷
636    /// - `.bin.zst`, `.zst`: 압축 바이너리 포맷
637    ///
638    /// # Errors
639    ///
640    /// 파일을 읽거나 파싱할 수 없는 경우 에러를 반환합니다.
641    pub fn load<P: AsRef<Path>>(path: P) -> Result<DenseMatrix> {
642        let path = path.as_ref();
643        let path_str = path.to_string_lossy();
644
645        if path_str.ends_with(".def") {
646            DenseMatrix::from_def_file(path)
647        } else if path_str.ends_with(".zst") || path_str.ends_with(".bin.zst") {
648            DenseMatrix::from_compressed_file(path)
649        } else if path_str.ends_with(".bin") {
650            DenseMatrix::from_bin_file(path)
651        } else {
652            // 기본: 바이너리 시도 후 텍스트 시도
653            DenseMatrix::from_bin_file(path).or_else(|_| DenseMatrix::from_def_file(path))
654        }
655    }
656
657    /// 메모리 맵으로 로드 (바이너리 파일만 지원)
658    ///
659    /// # Errors
660    ///
661    /// 파일을 읽거나 메모리 맵을 생성할 수 없는 경우 에러를 반환합니다.
662    pub fn load_mmap<P: AsRef<Path>>(path: P) -> Result<MmapMatrix> {
663        MmapMatrix::from_file(path)
664    }
665}
666
667/// 연접 비용 행렬을 위한 통합 타입
668///
669/// 다양한 행렬 구현을 하나의 타입으로 사용할 수 있습니다.
670pub enum ConnectionMatrix {
671    /// 밀집 행렬
672    Dense(DenseMatrix),
673    /// 희소 행렬
674    Sparse(SparseMatrix),
675    /// 메모리 맵 행렬
676    Mmap(MmapMatrix),
677}
678
679impl ConnectionMatrix {
680    /// 텍스트 파일에서 로드
681    ///
682    /// # Errors
683    ///
684    /// 파일을 읽거나 파싱할 수 없는 경우 에러를 반환합니다.
685    pub fn from_def_file<P: AsRef<Path>>(path: P) -> Result<Self> {
686        Ok(Self::Dense(DenseMatrix::from_def_file(path)?))
687    }
688
689    /// 바이너리 파일에서 로드
690    ///
691    /// # Errors
692    ///
693    /// 파일을 읽거나 파싱할 수 없는 경우 에러를 반환합니다.
694    pub fn from_bin_file<P: AsRef<Path>>(path: P) -> Result<Self> {
695        Ok(Self::Dense(DenseMatrix::from_bin_file(path)?))
696    }
697
698    /// 메모리 맵으로 로드
699    ///
700    /// # Errors
701    ///
702    /// 파일을 읽거나 메모리 맵을 생성할 수 없는 경우 에러를 반환합니다.
703    pub fn from_mmap_file<P: AsRef<Path>>(path: P) -> Result<Self> {
704        Ok(Self::Mmap(MmapMatrix::from_file(path)?))
705    }
706
707    /// 압축된 바이너리 파일에서 로드 (.zst)
708    ///
709    /// # Errors
710    ///
711    /// 파일을 읽거나 압축 해제/파싱할 수 없는 경우 에러를 반환합니다.
712    pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<Self> {
713        Ok(Self::Dense(DenseMatrix::from_compressed_file(path)?))
714    }
715
716    /// 자동 포맷 감지 로드
717    ///
718    /// # Errors
719    ///
720    /// 파일을 읽거나 파싱할 수 없는 경우 에러를 반환합니다.
721    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
722        Ok(Self::Dense(MatrixLoader::load(path)?))
723    }
724}
725
726impl Matrix for ConnectionMatrix {
727    #[inline(always)]
728    fn get(&self, right_id: u16, left_id: u16) -> i32 {
729        match self {
730            Self::Dense(m) => m.get(right_id, left_id),
731            Self::Sparse(m) => m.get(right_id, left_id),
732            Self::Mmap(m) => m.get(right_id, left_id),
733        }
734    }
735
736    fn left_size(&self) -> usize {
737        match self {
738            Self::Dense(m) => m.left_size(),
739            Self::Sparse(m) => m.left_size(),
740            Self::Mmap(m) => m.left_size(),
741        }
742    }
743
744    fn right_size(&self) -> usize {
745        match self {
746            Self::Dense(m) => m.right_size(),
747            Self::Sparse(m) => m.right_size(),
748            Self::Mmap(m) => m.right_size(),
749        }
750    }
751}
752
753#[cfg(test)]
754#[allow(clippy::expect_used, clippy::unwrap_used, clippy::cast_lossless)]
755mod tests {
756    use super::*;
757
758    #[test]
759    fn test_dense_matrix_new() {
760        let matrix = DenseMatrix::new(10, 10, 0);
761        assert_eq!(matrix.left_size(), 10);
762        assert_eq!(matrix.right_size(), 10);
763        assert_eq!(matrix.entry_count(), 100);
764        assert_eq!(matrix.get(0, 0), 0);
765    }
766
767    #[test]
768    fn test_dense_matrix_set_get() {
769        let mut matrix = DenseMatrix::new(10, 10, 0);
770        matrix.set(3, 5, 100);
771        assert_eq!(matrix.get(3, 5), 100);
772        assert_eq!(matrix.get(5, 3), 0);
773    }
774
775    #[test]
776    fn test_dense_matrix_from_vec() {
777        let costs = vec![1, 2, 3, 4, 5, 6];
778        let matrix = DenseMatrix::from_vec(2, 3, costs).unwrap();
779        // costs[right_id + lsize * left_id]
780        // (0,0) = costs[0] = 1
781        // (1,0) = costs[1] = 2
782        // (0,1) = costs[2] = 3
783        // (1,1) = costs[3] = 4
784        // (0,2) = costs[4] = 5
785        // (1,2) = costs[5] = 6
786        assert_eq!(matrix.get(0, 0), 1);
787        assert_eq!(matrix.get(1, 0), 2);
788        assert_eq!(matrix.get(0, 1), 3);
789        assert_eq!(matrix.get(1, 1), 4);
790        assert_eq!(matrix.get(0, 2), 5);
791        assert_eq!(matrix.get(1, 2), 6);
792    }
793
794    #[test]
795    fn test_dense_matrix_from_vec_size_mismatch() {
796        let costs = vec![1, 2, 3];
797        let result = DenseMatrix::from_vec(2, 3, costs);
798        assert!(result.is_err());
799    }
800
801    #[test]
802    fn test_dense_matrix_boundary() {
803        let matrix = DenseMatrix::new(10, 10, 0);
804        // 경계 외 접근
805        assert_eq!(matrix.get(100, 100), INVALID_CONNECTION_COST);
806    }
807
808    #[test]
809    fn test_dense_matrix_def_reader() {
810        let data = "3 3\n0 0 100\n1 1 200\n2 2 300\n";
811        let reader = std::io::Cursor::new(data);
812        let matrix = DenseMatrix::from_def_reader(reader).unwrap();
813
814        assert_eq!(matrix.left_size(), 3);
815        assert_eq!(matrix.right_size(), 3);
816        assert_eq!(matrix.get(0, 0), 100);
817        assert_eq!(matrix.get(1, 1), 200);
818        assert_eq!(matrix.get(2, 2), 300);
819        // 설정되지 않은 값은 i16::MAX
820        assert_eq!(matrix.get(0, 1), i16::MAX as i32);
821    }
822
823    #[test]
824    fn test_dense_matrix_binary_roundtrip() {
825        let mut matrix = DenseMatrix::new(5, 5, 0);
826        matrix.set(0, 0, 100);
827        matrix.set(1, 2, -500);
828        matrix.set(4, 4, 32767);
829
830        let bytes = matrix.to_bin_bytes();
831        let loaded = DenseMatrix::from_bin_bytes(&bytes).unwrap();
832
833        assert_eq!(loaded.left_size(), 5);
834        assert_eq!(loaded.right_size(), 5);
835        assert_eq!(loaded.get(0, 0), 100);
836        assert_eq!(loaded.get(1, 2), -500);
837        assert_eq!(loaded.get(4, 4), 32767);
838    }
839
840    #[test]
841    fn test_sparse_matrix() {
842        let mut sparse = SparseMatrix::new(100, 100, 0);
843        sparse.set(10, 20, 500);
844        sparse.set(50, 50, -100);
845
846        assert_eq!(sparse.get(10, 20), 500);
847        assert_eq!(sparse.get(50, 50), -100);
848        assert_eq!(sparse.get(0, 0), 0); // 기본값
849
850        assert_eq!(sparse.entry_count_stored(), 2);
851        assert!(sparse.sparsity() > 0.99); // 거의 희소
852    }
853
854    #[test]
855    fn test_sparse_dense_conversion() {
856        let mut dense = DenseMatrix::new(10, 10, 0);
857        dense.set(3, 3, 100);
858        dense.set(5, 7, 200);
859
860        let sparse = SparseMatrix::from_dense(&dense, 0);
861        assert_eq!(sparse.entry_count_stored(), 2);
862        assert_eq!(sparse.get(3, 3), 100);
863        assert_eq!(sparse.get(5, 7), 200);
864
865        let converted = sparse.to_dense();
866        assert_eq!(converted.get(3, 3), 100);
867        assert_eq!(converted.get(5, 7), 200);
868        assert_eq!(converted.get(0, 0), 0);
869    }
870
871    #[test]
872    fn test_memory_size() {
873        let dense = DenseMatrix::new(100, 100, 0);
874        let mem_size = dense.memory_size();
875        // 최소 20000 바이트 (100*100*2)
876        assert!(mem_size >= 20000);
877
878        let sparse = SparseMatrix::new(100, 100, 0);
879        let sparse_size = sparse.memory_size();
880        // 희소 행렬은 훨씬 작음
881        assert!(sparse_size < mem_size);
882    }
883
884    #[test]
885    fn test_connection_matrix_enum() {
886        let dense = DenseMatrix::new(5, 5, 100);
887        let matrix = ConnectionMatrix::Dense(dense);
888
889        assert_eq!(matrix.left_size(), 5);
890        assert_eq!(matrix.right_size(), 5);
891        assert_eq!(matrix.get(0, 0), 100);
892    }
893
894    #[test]
895    fn test_large_matrix() {
896        // mecab-ko-dic의 실제 크기 (약 2800 x 2800)
897        let matrix = DenseMatrix::new(178, 178, 0);
898        assert_eq!(matrix.entry_count(), 178 * 178);
899        assert_eq!(
900            matrix.memory_size(),
901            std::mem::size_of::<DenseMatrix>() + 178 * 178 * 2
902        );
903    }
904
905    #[test]
906    fn test_def_with_comments_and_empty_lines() {
907        let data = "2 2\n# This is a comment\n\n0 0 10\n0 1 20\n\n1 0 30\n1 1 40\n";
908        let reader = std::io::Cursor::new(data);
909        let matrix = DenseMatrix::from_def_reader(reader).unwrap();
910
911        assert_eq!(matrix.get(0, 0), 10);
912        assert_eq!(matrix.get(0, 1), 20);
913        assert_eq!(matrix.get(1, 0), 30);
914        assert_eq!(matrix.get(1, 1), 40);
915    }
916}