use binary_vec_io::{binary_read_to_ref, binary_read_vec, binary_write_vec};
use std::cmp::max;

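// MirrorSparseMatrix: a sparse matrix of nonnegative integer entries, with string labels for
// its rows and columns, packed into a single Vec<u8> (see build_from_vec for the byte layout).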
#[derive(Clone)]
pub struct MirrorSparseMatrix {
    x: Vec<u8>,
}

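// Return the code version recorded in a MirrorSparseMatrix file without loading the matrix:
// the first eleven u32 words of the file are read and the last of them is returned.  The
// assert pins usize to 8 bytes, so that this offset lines up with the length prefix that
// binary_write_vec writes ahead of the data (an assumption about binary_vec_io's format).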
pub fn get_code_version_from_file(f: &str) -> u32 {
    assert_eq!(std::mem::size_of::<usize>(), 8);
    let mut ff = std::fs::File::open(&f).unwrap();
    let mut x = vec![0_u32; 11];
    binary_read_to_ref::<u32>(&mut ff, &mut x[0], 11).unwrap();
    x[10]
}

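// Read a MirrorSparseMatrix from a file, panicking if its code version or storage version is
// not one of the supported values (0 or 1).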
pub fn read_from_file(s: &mut MirrorSparseMatrix, f: &str) {
    let mut ff = std::fs::File::open(&f).unwrap();
    binary_read_vec::<u8>(&mut ff, &mut s.x).unwrap();
    if s.code_version() != 0 && s.code_version() != 1 {
        panic!(
            "\nMirrorSparseMatrix: code_version has to be 0 or 1, but it is {}.\n",
            s.code_version()
        );
    }
    if s.storage_version() != 0 && s.storage_version() != 1 {
        panic!(
            "\nMirrorSparseMatrix: storage_version has to be 0 or 1, but it is {}.\n",
            s.storage_version()
        );
    }
}

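// Write a MirrorSparseMatrix to a file.  The assert rejects matrices recorded with code
// version 0.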
pub fn write_to_file(s: &MirrorSparseMatrix, f: &str) {
    assert!(s.code_version() > 0);
    let mut ff =
        std::fs::File::create(&f).unwrap_or_else(|_| panic!("Failed to create file {}.", f));
    binary_write_vec::<u8>(&mut ff, &s.x).unwrap();
}

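// Little-endian helpers for reading, overwriting, and appending fixed-width unsigned integers
// within the byte vector that backs a MirrorSparseMatrix.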
fn get_u8_at_pos(v: &[u8], pos: usize) -> u8 {
    v[pos]
}

fn get_u16_at_pos(v: &[u8], pos: usize) -> u16 {
    let mut z = [0_u8; 2];
    z.clone_from_slice(&v[pos..(2 + pos)]);
    u16::from_le_bytes(z)
}

fn get_u32_at_pos(v: &[u8], pos: usize) -> u32 {
    let mut z = [0_u8; 4];
    z.clone_from_slice(&v[pos..(4 + pos)]);
    u32::from_le_bytes(z)
}

fn _put_u8_at_pos(v: &mut Vec<u8>, pos: usize, val: u8) {
    v[pos] = val;
}

fn _put_u16_at_pos(v: &mut Vec<u8>, pos: usize, val: u16) {
    let z = val.to_le_bytes();
    v[pos..(2 + pos)].clone_from_slice(&z[..2]);
}

fn put_u32_at_pos(v: &mut Vec<u8>, pos: usize, val: u32) {
    let z = val.to_le_bytes();
    v[pos..(4 + pos)].clone_from_slice(&z[..4]);
}

fn push_u8(v: &mut Vec<u8>, val: u8) {
    v.push(val);
}

fn push_u16(v: &mut Vec<u8>, val: u16) {
    let z = val.to_le_bytes();
    v.extend_from_slice(&z);
}

fn push_u32(v: &mut Vec<u8>, val: u32) {
    let z = val.to_le_bytes();
    v.extend_from_slice(&z);
}

impl MirrorSparseMatrix {
    pub fn new() -> MirrorSparseMatrix {
        let v = Vec::<u8>::new();
        MirrorSparseMatrix { x: v }
    }

    pub fn initialized(&self) -> bool {
        !self.x.is_empty()
    }

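    // The fixed header is a 32-byte identification string followed by four u32 fields:
    // code version, storage version, number of rows, and number of columns.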
    fn header_size() -> usize {
        32 + 4 + 4 + 4 + 4
    }

    fn code_version(&self) -> usize {
        get_u32_at_pos(&self.x, 32) as usize
    }

    fn storage_version(&self) -> usize {
        get_u32_at_pos(&self.x, 36) as usize
    }

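    // Build a MirrorSparseMatrix from rows given as (column, value) pairs, plus row and column
    // labels.  Storage version 0 stores column indices as u16 and is used when every column
    // index is < 65536; storage version 1 stores them as u32.  Values occupy 1, 2, or 4 bytes
    // depending on their size.  Layout: header, then one u32 start position per row, then
    // (n + 1) u32 row-label offsets and (k + 1) u32 column-label offsets, then the per-row data
    // (the counts m1/m2/m4 of 1-, 2-, and 4-byte values, followed by the entries grouped by
    // value width), and finally the row-label and column-label bytes.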
    pub fn build_from_vec(
        x: &[Vec<(i32, i32)>],
        row_labels: &[String],
        col_labels: &[String],
    ) -> MirrorSparseMatrix {
        let mut max_col = 0_i32;
        for i in 0..x.len() {
            for j in 0..x[i].len() {
                max_col = max(max_col, x[i][j].0);
            }
        }
        let mut storage_version = 0_u32;
        if max_col >= 65536 {
            storage_version = 1_u32;
        }
        let hs = MirrorSparseMatrix::header_size();
        let mut v = Vec::<u8>::new();
        let mut total_bytes = hs + 4 * x.len();
        for i in 0..x.len() {
            let (mut m1, mut m2, mut m4) = (0, 0, 0);
            for j in 0..x[i].len() {
                if x[i][j].1 < 256 {
                    m1 += 1;
                } else if x[i][j].1 < 65536 {
                    m2 += 1;
                } else {
                    m4 += 1;
                }
            }
            if storage_version == 0 {
                total_bytes += 6 + 3 * m1 + 4 * m2 + 6 * m4;
            } else {
                total_bytes += 12 + 5 * m1 + 6 * m2 + 8 * m4;
            }
        }
        let (n, k) = (row_labels.len(), col_labels.len());
        assert_eq!(n, x.len());
        total_bytes += 4 * (1 + n);
        total_bytes += 4 * (1 + k);
        let byte_start_of_row_labels = total_bytes;
        for i in 0..n {
            total_bytes += row_labels[i].len();
        }
        let byte_start_of_col_labels = total_bytes;
        for j in 0..k {
            total_bytes += col_labels[j].len();
        }
        v.reserve(total_bytes);
        v.append(&mut b"MirrorSparseMatrix binary file \n".to_vec());
        assert_eq!(v.len(), 32);
        const CURRENT_CODE_VERSION: usize = 1;
        let code_version = CURRENT_CODE_VERSION as u32;
        push_u32(&mut v, code_version);
        push_u32(&mut v, storage_version);
        push_u32(&mut v, n as u32);
        push_u32(&mut v, k as u32);
        assert_eq!(v.len(), hs);
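        // Reserve one u32 per row for the row start positions; these are filled in below as
        // each row's data is written.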
        for _ in 0..n {
            push_u32(&mut v, 0_u32);
        }

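        // Write the byte offsets of the row labels, then of the column labels.  One extra
        // offset at the end of each list marks where the last label ends.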
        let mut pos = byte_start_of_row_labels;
        for i in 0..=n {
            push_u32(&mut v, pos as u32);
            if i < n {
                pos += row_labels[i].len();
            }
        }
        let mut pos = byte_start_of_col_labels;
        for j in 0..=k {
            push_u32(&mut v, pos as u32);
            if j < k {
                pos += col_labels[j].len();
            }
        }

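        // Write the data for each row: record its start position back in the table above,
        // then write the counts (m1, m2, m4) of 1-, 2-, and 4-byte values, then the entries
        // themselves, grouped by value width.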
        for i in 0..n {
            let p = v.len() as u32;
            put_u32_at_pos(&mut v, hs + 4 * i, p);
            let (mut m1, mut m2, mut m4) = (0, 0, 0);
            for j in 0..x[i].len() {
                if x[i][j].1 < 256 {
                    m1 += 1;
                } else if x[i][j].1 < 65536 {
                    m2 += 1;
                } else {
                    m4 += 1;
                }
            }
            if storage_version == 0 {
                push_u16(&mut v, m1 as u16);
                push_u16(&mut v, m2 as u16);
                push_u16(&mut v, m4 as u16);
            } else {
                push_u32(&mut v, m1 as u32);
                push_u32(&mut v, m2 as u32);
                push_u32(&mut v, m4 as u32);
            }
            for j in 0..x[i].len() {
                if x[i][j].1 < 256 {
                    if storage_version == 0 {
                        push_u16(&mut v, x[i][j].0 as u16);
                    } else {
                        push_u32(&mut v, x[i][j].0 as u32);
                    }
                    push_u8(&mut v, x[i][j].1 as u8);
                }
            }
            for j in 0..x[i].len() {
                if x[i][j].1 >= 256 && x[i][j].1 < 65536 {
                    if storage_version == 0 {
                        push_u16(&mut v, x[i][j].0 as u16);
                    } else {
                        push_u32(&mut v, x[i][j].0 as u32);
                    }
                    push_u16(&mut v, x[i][j].1 as u16);
                }
            }
            for j in 0..x[i].len() {
                if x[i][j].1 >= 65536 {
                    if storage_version == 0 {
                        push_u16(&mut v, x[i][j].0 as u16);
                    } else {
                        push_u32(&mut v, x[i][j].0 as u32);
                    }
                    push_u32(&mut v, x[i][j].1 as u32);
                }
            }
        }

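        // Append the raw bytes of the row labels, then of the column labels.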
        for i in 0..n {
            for p in 0..row_labels[i].len() {
                v.push(row_labels[i].as_bytes()[p]);
            }
        }
        for j in 0..k {
            for p in 0..col_labels[j].len() {
                v.push(col_labels[j].as_bytes()[p]);
                pos += 1;
            }
        }

        assert_eq!(total_bytes, v.len());
        MirrorSparseMatrix { x: v }
    }

    pub fn nrows(&self) -> usize {
        get_u32_at_pos(&self.x, 40) as usize
    }

    pub fn ncols(&self) -> usize {
        get_u32_at_pos(&self.x, 44) as usize
    }

    fn start_of_row(&self, row: usize) -> usize {
        let pos = MirrorSparseMatrix::header_size() + row * 4;
        get_u32_at_pos(&self.x, pos) as usize
    }

    pub fn row_label(&self, i: usize) -> String {
        let row_labels_start = MirrorSparseMatrix::header_size() + self.nrows() * 4;
        let label_start = get_u32_at_pos(&self.x, row_labels_start + i * 4);
        let label_stop = get_u32_at_pos(&self.x, row_labels_start + (i + 1) * 4);
        let label_bytes = &self.x[label_start as usize..label_stop as usize];
        String::from_utf8(label_bytes.to_vec()).unwrap()
    }

    pub fn col_label(&self, j: usize) -> String {
        let col_labels_start = MirrorSparseMatrix::header_size() + self.nrows() * 8 + 4;
        let label_start = get_u32_at_pos(&self.x, col_labels_start + j * 4);
        let label_stop = get_u32_at_pos(&self.x, col_labels_start + (j + 1) * 4);
        let label_bytes = &self.x[label_start as usize..label_stop as usize];
        String::from_utf8(label_bytes.to_vec()).unwrap()
    }

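    // Return the entries of a row as (column, value) pairs, sorted by column index.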
    pub fn row(&self, row: usize) -> Vec<(usize, usize)> {
        let mut all = Vec::<(usize, usize)>::new();
        let s = self.start_of_row(row);
        if self.storage_version() == 0 {
            let m1 = get_u16_at_pos(&self.x, s) as usize;
            let m2 = get_u16_at_pos(&self.x, s + 2) as usize;
            let m4 = get_u16_at_pos(&self.x, s + 4) as usize;
            for i in 0..m1 {
                let pos = s + 6 + 3 * i;
                let col = get_u16_at_pos(&self.x, pos) as usize;
                let entry = get_u8_at_pos(&self.x, pos + 2) as usize;
                all.push((col, entry));
            }
            for i in 0..m2 {
                let pos = s + 6 + 3 * m1 + 4 * i;
                let col = get_u16_at_pos(&self.x, pos) as usize;
                let entry = get_u16_at_pos(&self.x, pos + 2) as usize;
                all.push((col, entry));
            }
            for i in 0..m4 {
                let pos = s + 6 + 3 * m1 + 4 * m2 + 6 * i;
                let col = get_u16_at_pos(&self.x, pos) as usize;
                let entry = get_u32_at_pos(&self.x, pos + 2) as usize;
                all.push((col, entry));
            }
        } else {
            let m1 = get_u32_at_pos(&self.x, s) as usize;
            let m2 = get_u32_at_pos(&self.x, s + 4) as usize;
            let m4 = get_u32_at_pos(&self.x, s + 8) as usize;
            for i in 0..m1 {
                let pos = s + 12 + 5 * i;
                let col = get_u32_at_pos(&self.x, pos) as usize;
                let entry = get_u8_at_pos(&self.x, pos + 4) as usize;
                all.push((col, entry));
            }
            for i in 0..m2 {
                let pos = s + 12 + 5 * m1 + 6 * i;
                let col = get_u32_at_pos(&self.x, pos) as usize;
                let entry = get_u16_at_pos(&self.x, pos + 4) as usize;
                all.push((col, entry));
            }
            for i in 0..m4 {
                let pos = s + 12 + 5 * m1 + 6 * m2 + 8 * i;
                let col = get_u32_at_pos(&self.x, pos) as usize;
                let entry = get_u32_at_pos(&self.x, pos + 4) as usize;
                all.push((col, entry));
            }
        }
        all.sort_unstable();
        all
    }

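    // Return the sum of the values in a row.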
    pub fn sum_of_row(&self, row: usize) -> usize {
        let s = self.start_of_row(row);
        let mut sum = 0;
        if self.storage_version() == 0 {
            let m1 = get_u16_at_pos(&self.x, s) as usize;
            let m2 = get_u16_at_pos(&self.x, s + 2) as usize;
            let m4 = get_u16_at_pos(&self.x, s + 4) as usize;
            for i in 0..m1 {
                let pos = s + 6 + 3 * i + 2;
                sum += get_u8_at_pos(&self.x, pos) as usize;
            }
            for i in 0..m2 {
                let pos = s + 6 + 3 * m1 + 4 * i + 2;
                sum += get_u16_at_pos(&self.x, pos) as usize;
            }
            for i in 0..m4 {
                let pos = s + 6 + 3 * m1 + 4 * m2 + 6 * i + 2;
                sum += get_u32_at_pos(&self.x, pos) as usize;
            }
        } else {
            let m1 = get_u32_at_pos(&self.x, s) as usize;
            let m2 = get_u32_at_pos(&self.x, s + 4) as usize;
            let m4 = get_u32_at_pos(&self.x, s + 8) as usize;
            for i in 0..m1 {
                let pos = s + 12 + 5 * i + 4;
                sum += get_u8_at_pos(&self.x, pos) as usize;
            }
            for i in 0..m2 {
                let pos = s + 12 + 5 * m1 + 6 * i + 4;
                sum += get_u16_at_pos(&self.x, pos) as usize;
            }
            for i in 0..m4 {
                let pos = s + 12 + 5 * m1 + 6 * m2 + 8 * i + 4;
                sum += get_u32_at_pos(&self.x, pos) as usize;
            }
        }
        sum
    }

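    // Return the sum of the values in a column.  This scans every row, so it is much slower
    // than sum_of_row.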
    #[allow(dead_code)]
    pub fn sum_of_col(&self, col: usize) -> usize {
        let mut sum = 0;
        if self.storage_version() == 0 {
            for row in 0..self.nrows() {
                let s = self.start_of_row(row);
                let m1 = get_u16_at_pos(&self.x, s) as usize;
                let m2 = get_u16_at_pos(&self.x, s + 2) as usize;
                let m4 = get_u16_at_pos(&self.x, s + 4) as usize;
                for i in 0..m1 {
                    let pos = s + 6 + 3 * i;
                    let f = get_u16_at_pos(&self.x, pos) as usize;
                    if f == col {
                        sum += get_u8_at_pos(&self.x, pos + 2) as usize;
                    }
                }
                for i in 0..m2 {
                    let pos = s + 6 + 3 * m1 + 4 * i;
                    let f = get_u16_at_pos(&self.x, pos) as usize;
                    if f == col {
                        sum += get_u16_at_pos(&self.x, pos + 2) as usize;
                    }
                }
                for i in 0..m4 {
                    let pos = s + 6 + 3 * m1 + 4 * m2 + 6 * i;
                    let f = get_u16_at_pos(&self.x, pos) as usize;
                    if f == col {
                        sum += get_u32_at_pos(&self.x, pos + 2) as usize;
                    }
                }
            }
        } else {
            for row in 0..self.nrows() {
                let s = self.start_of_row(row);
                let m1 = get_u32_at_pos(&self.x, s) as usize;
                let m2 = get_u32_at_pos(&self.x, s + 4) as usize;
                let m4 = get_u32_at_pos(&self.x, s + 8) as usize;
                for i in 0..m1 {
                    let pos = s + 12 + 5 * i;
                    let f = get_u32_at_pos(&self.x, pos) as usize;
                    if f == col {
                        sum += get_u8_at_pos(&self.x, pos + 4) as usize;
                    }
                }
                for i in 0..m2 {
                    let pos = s + 12 + 5 * m1 + 6 * i;
                    let f = get_u32_at_pos(&self.x, pos) as usize;
                    if f == col {
                        sum += get_u16_at_pos(&self.x, pos + 4) as usize;
                    }
                }
                for i in 0..m4 {
                    let pos = s + 12 + 5 * m1 + 6 * m2 + 8 * i;
                    let f = get_u32_at_pos(&self.x, pos) as usize;
                    if f == col {
                        sum += get_u32_at_pos(&self.x, pos + 4) as usize;
                    }
                }
            }
        }
        sum
    }

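    // Return the value at (row, col), or 0 if that position has no stored entry.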
    pub fn value(&self, row: usize, col: usize) -> usize {
        let s = self.start_of_row(row);
        if self.storage_version() == 0 {
            let m1 = get_u16_at_pos(&self.x, s) as usize;
            let m2 = get_u16_at_pos(&self.x, s + 2) as usize;
            let m4 = get_u16_at_pos(&self.x, s + 4) as usize;
            for i in 0..m1 {
                let pos = s + 6 + 3 * i;
                let f = get_u16_at_pos(&self.x, pos) as usize;
                if f == col {
                    return get_u8_at_pos(&self.x, pos + 2) as usize;
                }
            }
            for i in 0..m2 {
                let pos = s + 6 + 3 * m1 + 4 * i;
                let f = get_u16_at_pos(&self.x, pos) as usize;
                if f == col {
                    return get_u16_at_pos(&self.x, pos + 2) as usize;
                }
            }
            for i in 0..m4 {
                let pos = s + 6 + 3 * m1 + 4 * m2 + 6 * i;
                let f = get_u16_at_pos(&self.x, pos) as usize;
                if f == col {
                    return get_u32_at_pos(&self.x, pos + 2) as usize;
                }
            }
            0
        } else {
            let m1 = get_u32_at_pos(&self.x, s) as usize;
            let m2 = get_u32_at_pos(&self.x, s + 4) as usize;
            let m4 = get_u32_at_pos(&self.x, s + 8) as usize;
            for i in 0..m1 {
                let pos = s + 12 + 5 * i;
                let f = get_u32_at_pos(&self.x, pos) as usize;
                if f == col {
                    return get_u8_at_pos(&self.x, pos + 4) as usize;
                }
            }
            for i in 0..m2 {
                let pos = s + 12 + 5 * m1 + 6 * i;
                let f = get_u32_at_pos(&self.x, pos) as usize;
                if f == col {
                    return get_u16_at_pos(&self.x, pos + 4) as usize;
                }
            }
            for i in 0..m4 {
                let pos = s + 12 + 5 * m1 + 6 * m2 + 8 * i;
                let f = get_u32_at_pos(&self.x, pos) as usize;
                if f == col {
                    return get_u32_at_pos(&self.x, pos + 4) as usize;
                }
            }
            0
        }
    }
}

impl Default for MirrorSparseMatrix {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use io_utils::printme;
    use pretty_trace::PrettyTrace;

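    // Build matrices that exercise both storage versions (small and large column indices),
    // then check value, sum_of_row, sum_of_col, row_label, and col_label against the inputs.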
    #[test]
    fn test_mirror_sparse_matrix() {
        PrettyTrace::new().on();
        for storage_version in 0..2 {
            printme!(storage_version);
            let mut x = Vec::<Vec<(i32, i32)>>::new();
            let (n, k) = (10, 100);
            for i in 0..n {
                let mut y = Vec::<(i32, i32)>::new();
                for j in 0..k {
                    let col: usize = if storage_version == 0 { i + j } else { 10000 * i + j };
                    y.push((col as i32, (i * i * j) as i32));
                }
                x.push(y);
            }
            let test_row = 9;
            let mut row_sum = 0;
            for j in 0..x[test_row].len() {
                row_sum += x[test_row][j].1 as usize;
            }
            let (mut row_labels, mut col_labels) = (Vec::<String>::new(), Vec::<String>::new());
            for i in 0..n {
                row_labels.push(format!("row {}", i));
            }
            for j in 0..k {
                col_labels.push(format!("col {}", j));
            }
            let y = MirrorSparseMatrix::build_from_vec(&x, &row_labels, &col_labels);
            let row_sum2 = y.sum_of_row(test_row);
            assert_eq!(row_sum, row_sum2);
            let test_col = if storage_version == 0 { 15 } else { 90001 };
            let mut col_sum = 0;
            for i in 0..x.len() {
                for j in 0..x[i].len() {
                    assert_eq!(x[i][j].1 as usize, y.value(i, x[i][j].0 as usize));
                    if x[i][j].0 as usize == test_col {
                        col_sum += x[i][j].1 as usize;
                    }
                }
            }
            let col_sum2 = y.sum_of_col(test_col);
            printme!(col_sum, col_sum2);
            assert_eq!(col_sum, col_sum2);
            assert_eq!(y.storage_version(), storage_version);
            assert_eq!(y.row_label(5), row_labels[5]);
            assert_eq!(y.col_label(7), col_labels[7]);
        }
    }
}