1use std::io::{Read as IoRead, Write as IoWrite};
10use std::net::TcpStream;
11use std::path::PathBuf;
12use std::time::{Duration, Instant};
13
14use noxu_sync::Mutex;
15
16use crate::error::{RepError, Result};
17
/// Handshake word a client sends first to request a restore.
///
/// The value is the ASCII bytes `NRST` read big-endian (`0x4E52_5354`). It is
/// put on the wire with `to_le_bytes`, so the bytes actually sent are `TSRN`.
const RESTORE_MAGIC: u32 = u32::from_be_bytes(*b"NRST");
22
/// Where to restore from and how to treat files that already exist locally.
#[derive(Debug, Clone)]
pub struct NetworkRestoreConfig {
    /// Identifier of the source node. Not used to connect — only
    /// `source_host`/`source_port` are (NOTE(review): presumably kept for
    /// callers' bookkeeping/logging — confirm).
    pub source_node: String,
    /// Host of the source; combined with `source_port` as `host:port`.
    pub source_host: String,
    /// Port on `source_host` that serves the restore.
    pub source_port: u16,
    /// When `true`, an existing local file with the same name as an incoming
    /// one is moved to `<name>.bak` before the received copy is written.
    pub retain_log_files: bool,
}
38
/// Lifecycle of a [`NetworkRestore`]: `NotStarted -> InProgress -> Completed | Failed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RestoreState {
    /// Freshly created; `start` has not been called.
    NotStarted,
    /// `start` succeeded and the restore has not yet been completed or failed.
    InProgress,
    /// Set by `complete`; terminal — `start` rejects this state.
    Completed,
    /// Set by `fail`; terminal — `start` rejects this state.
    Failed,
}
51
/// Snapshot of a restore's progress, as returned by `get_progress`.
#[derive(Debug, Clone)]
pub struct RestoreProgress {
    /// Mirror of the restore's lifecycle state at snapshot time.
    pub state: RestoreState,
    /// Cumulative body bytes recorded via `update_progress` (a total, not a delta).
    pub bytes_transferred: u64,
    /// Cumulative number of files recorded via `update_progress`.
    pub files_transferred: u32,
    /// Wall-clock time since the transfer began, as last recorded via
    /// `update_elapsed`.
    pub elapsed: Duration,
}
64
/// Client side of a network restore: pulls files from a source node into a
/// local directory while tracking lifecycle state and progress.
///
/// Instances are single-use: `start` only succeeds from `NotStarted`.
pub struct NetworkRestore {
    /// Source address and retention settings.
    config: NetworkRestoreConfig,
    /// Authoritative lifecycle state. Whenever both locks are held, this one
    /// is taken first.
    state: Mutex<RestoreState>,
    /// Progress counters; its `state` field is kept in sync with `state`.
    progress: Mutex<RestoreProgress>,
    /// Destination directory; `None` means the current working directory
    /// (or `.` if that cannot be determined).
    local_log_dir: Option<PathBuf>,
}
81
82fn validate_restore_filename(name: &str) -> Result<()> {
99 if name.is_empty() {
100 return Err(RepError::ProtocolError("unsafe filename: empty".into()));
101 }
102 if name == "." || name == ".." {
103 return Err(RepError::ProtocolError(format!(
104 "unsafe filename: {:?}",
105 name
106 )));
107 }
108 if name.starts_with('.') {
109 return Err(RepError::ProtocolError(format!(
110 "unsafe filename: hidden dotfile {:?}",
111 name
112 )));
113 }
114 for b in name.as_bytes() {
115 match *b {
116 b'/' | b'\\' => {
117 return Err(RepError::ProtocolError(format!(
118 "unsafe filename: path separator in {:?}",
119 name
120 )));
121 }
122 0 => {
123 return Err(RepError::ProtocolError(format!(
124 "unsafe filename: null byte in {:?}",
125 name
126 )));
127 }
128 _ => {}
129 }
130 }
131 Ok(())
132}
133
134impl NetworkRestore {
135 pub fn new(config: NetworkRestoreConfig) -> Self {
137 Self {
138 config,
139 state: Mutex::new(RestoreState::NotStarted),
140 progress: Mutex::new(RestoreProgress {
141 state: RestoreState::NotStarted,
142 bytes_transferred: 0,
143 files_transferred: 0,
144 elapsed: Duration::ZERO,
145 }),
146 local_log_dir: None,
147 }
148 }
149
150 pub fn with_local_dir(mut self, dir: impl Into<PathBuf>) -> Self {
154 self.local_log_dir = Some(dir.into());
155 self
156 }
157
158 pub fn get_state(&self) -> RestoreState {
160 *self.state.lock()
161 }
162
163 pub fn get_progress(&self) -> RestoreProgress {
165 self.progress.lock().clone()
166 }
167
168 pub fn get_config(&self) -> &NetworkRestoreConfig {
170 &self.config
171 }
172
173 pub fn execute(&self) -> Result<()> {
190 {
192 let state = self.state.lock();
193 if *state != RestoreState::NotStarted {
194 return Err(RepError::NetworkRestoreError(format!(
195 "execute called in wrong state: {:?}",
196 *state
197 )));
198 }
199 }
200
201 self.start()?;
203
204 let started_at = Instant::now();
205 let addr =
206 format!("{}:{}", self.config.source_host, self.config.source_port);
207
208 let mut stream = TcpStream::connect(&addr).map_err(|e| {
210 RepError::NetworkRestoreError(format!(
211 "cannot connect to source {}: {}",
212 addr, e
213 ))
214 })?;
215
216 let _ = stream.set_read_timeout(Some(Duration::from_secs(120)));
218
219 stream.write_all(&RESTORE_MAGIC.to_le_bytes()).map_err(|e| {
221 RepError::NetworkRestoreError(format!(
222 "sending restore magic: {}",
223 e
224 ))
225 })?;
226
227 let mut count_buf = [0u8; 4];
229 stream.read_exact(&mut count_buf).map_err(|e| {
230 RepError::NetworkRestoreError(format!("reading file count: {}", e))
231 })?;
232 let file_count = u32::from_le_bytes(count_buf);
233
234 let log_dir = self.local_log_dir.clone().unwrap_or_else(|| {
235 std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))
236 });
237
238 let mut total_bytes: u64 = 0;
239 let mut files_done: u32 = 0;
240
241 for _ in 0..file_count {
242 let mut name_len_buf = [0u8; 2];
244 stream.read_exact(&mut name_len_buf).map_err(|e| {
245 RepError::NetworkRestoreError(format!(
246 "reading filename length: {}",
247 e
248 ))
249 })?;
250 let name_len = u16::from_le_bytes(name_len_buf) as usize;
251
252 let mut name_buf = vec![0u8; name_len];
253 stream.read_exact(&mut name_buf).map_err(|e| {
254 RepError::NetworkRestoreError(format!(
255 "reading filename: {}",
256 e
257 ))
258 })?;
259 let filename = String::from_utf8(name_buf).map_err(|e| {
260 RepError::NetworkRestoreError(format!(
261 "non-UTF8 filename: {}",
262 e
263 ))
264 })?;
265 validate_restore_filename(&filename)?;
266
267 let mut size_buf = [0u8; 8];
269 stream.read_exact(&mut size_buf).map_err(|e| {
270 RepError::NetworkRestoreError(format!(
271 "reading file size for '{}': {}",
272 filename, e
273 ))
274 })?;
275 let file_size = u64::from_le_bytes(size_buf);
276
277 let dest_path = log_dir.join(&filename);
281 if self.config.retain_log_files && dest_path.exists() {
282 let backup = log_dir.join(format!("{}.bak", filename));
283 let _ = std::fs::rename(&dest_path, &backup);
284 }
285
286 let mut out = std::fs::File::create(&dest_path).map_err(|e| {
288 RepError::NetworkRestoreError(format!(
289 "creating '{}': {}",
290 dest_path.display(),
291 e
292 ))
293 })?;
294
295 let mut remaining = file_size;
296 let mut chunk = vec![0u8; 65536];
297 let mut digest = crc32fast::Hasher::new();
301 while remaining > 0 {
302 let to_read = (remaining as usize).min(chunk.len());
303 stream.read_exact(&mut chunk[..to_read]).map_err(|e| {
304 RepError::NetworkRestoreError(format!(
305 "reading data for '{}': {}",
306 filename, e
307 ))
308 })?;
309 digest.update(&chunk[..to_read]);
310 out.write_all(&chunk[..to_read]).map_err(|e| {
311 RepError::NetworkRestoreError(format!(
312 "writing '{}': {}",
313 dest_path.display(),
314 e
315 ))
316 })?;
317 remaining -= to_read as u64;
318 total_bytes += to_read as u64;
319 }
320 let mut crc_buf = [0u8; 4];
322 stream.read_exact(&mut crc_buf).map_err(|e| {
323 RepError::NetworkRestoreError(format!(
324 "reading digest for '{}': {}",
325 filename, e
326 ))
327 })?;
328 let want = u32::from_le_bytes(crc_buf);
329 let got = digest.finalize();
330 if want != got {
331 let _ = std::fs::remove_file(&dest_path);
332 return Err(RepError::NetworkRestoreError(format!(
333 "digest mismatch for '{}': expected {:#010x}, got {:#010x} (file corrupted or truncated in transit)",
334 filename, want, got
335 )));
336 }
337
338 files_done += 1;
339 self.update_progress(total_bytes, files_done);
340 self.update_elapsed(started_at.elapsed());
341
342 log::debug!(
343 "NetworkRestore: received '{}' ({} bytes)",
344 filename,
345 file_size
346 );
347 }
348
349 self.update_elapsed(started_at.elapsed());
350 self.complete()?;
351
352 log::info!(
353 "NetworkRestore from {}: {} file(s), {} bytes transferred in {:?}",
354 addr,
355 files_done,
356 total_bytes,
357 started_at.elapsed(),
358 );
359
360 Ok(())
361 }
362
363 pub fn execute_via_dispatcher(&self) -> Result<()> {
383 use crate::net::Channel;
384 use crate::net::service_dispatcher::connect_to_service;
385 use crate::network_restore_server::RESTORE_SERVICE_NAME;
386
387 {
389 let state = self.state.lock();
390 if *state != RestoreState::NotStarted {
391 return Err(RepError::NetworkRestoreError(format!(
392 "execute_via_dispatcher called in wrong state: {:?}",
393 *state
394 )));
395 }
396 }
397
398 self.start()?;
399 let started_at = Instant::now();
400
401 let addr_str =
402 format!("{}:{}", self.config.source_host, self.config.source_port);
403 let addr: std::net::SocketAddr = addr_str.parse().map_err(|e| {
404 RepError::NetworkRestoreError(format!(
405 "bad source address {}: {}",
406 addr_str, e
407 ))
408 })?;
409
410 let channel =
411 connect_to_service(addr, RESTORE_SERVICE_NAME).map_err(|e| {
412 RepError::NetworkRestoreError(format!(
413 "connect_to_service(RESTORE) at {}: {}",
414 addr, e
415 ))
416 })?;
417
418 channel.send(&RESTORE_MAGIC.to_le_bytes()).map_err(|e| {
420 RepError::NetworkRestoreError(format!(
421 "sending restore magic via dispatcher: {}",
422 e
423 ))
424 })?;
425
426 let payload = channel
428 .receive(Duration::from_secs(120))
429 .map_err(|e| {
430 RepError::NetworkRestoreError(format!(
431 "receiving restore payload: {}",
432 e
433 ))
434 })?
435 .ok_or_else(|| {
436 RepError::NetworkRestoreError(
437 "empty restore payload from dispatcher".to_string(),
438 )
439 })?;
440
441 if payload.len() < 4 {
443 return Err(RepError::NetworkRestoreError(format!(
444 "truncated restore payload: {} bytes",
445 payload.len()
446 )));
447 }
448 let mut off = 0usize;
449 let mut buf4 = [0u8; 4];
450 buf4.copy_from_slice(&payload[off..off + 4]);
451 off += 4;
452 let file_count = u32::from_le_bytes(buf4);
453
454 let log_dir = self.local_log_dir.clone().unwrap_or_else(|| {
455 std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))
456 });
457 std::fs::create_dir_all(&log_dir).map_err(|e| {
458 RepError::NetworkRestoreError(format!(
459 "creating log dir {}: {}",
460 log_dir.display(),
461 e
462 ))
463 })?;
464
465 let mut total_bytes: u64 = 0;
466 let mut files_done: u32 = 0;
467 let mut buf2 = [0u8; 2];
468 let mut buf8 = [0u8; 8];
469
470 for _ in 0..file_count {
471 if off + 2 > payload.len() {
472 return Err(RepError::NetworkRestoreError(
473 "truncated restore payload at name_len".to_string(),
474 ));
475 }
476 buf2.copy_from_slice(&payload[off..off + 2]);
477 off += 2;
478 let name_len = u16::from_le_bytes(buf2) as usize;
479 if off + name_len + 8 > payload.len() {
480 return Err(RepError::NetworkRestoreError(
481 "truncated restore payload at name+size".to_string(),
482 ));
483 }
484 let name_bytes = payload[off..off + name_len].to_vec();
485 off += name_len;
486 let filename = String::from_utf8(name_bytes).map_err(|e| {
487 RepError::NetworkRestoreError(format!(
488 "non-UTF8 filename: {}",
489 e
490 ))
491 })?;
492 validate_restore_filename(&filename)?;
493
494 buf8.copy_from_slice(&payload[off..off + 8]);
495 off += 8;
496 let file_size = u64::from_le_bytes(buf8) as usize;
497 if off + file_size + 4 > payload.len() {
499 return Err(RepError::NetworkRestoreError(format!(
500 "truncated restore payload at file body for {:?} \
501 (need {} bytes + 4 digest, have {})",
502 filename,
503 file_size,
504 payload.len() - off,
505 )));
506 }
507 let body = &payload[off..off + file_size];
509 let want = u32::from_le_bytes(
510 payload[off + file_size..off + file_size + 4]
511 .try_into()
512 .expect("4-byte CRC slice"),
513 );
514 let got = crc32fast::hash(body);
515 if want != got {
516 return Err(RepError::NetworkRestoreError(format!(
517 "digest mismatch for '{}': expected {:#010x}, got {:#010x} \
518 (file corrupted or truncated in transit)",
519 filename, want, got
520 )));
521 }
522
523 let dest_path = log_dir.join(&filename);
524 if self.config.retain_log_files && dest_path.exists() {
525 let backup = log_dir.join(format!("{}.bak", filename));
526 let _ = std::fs::rename(&dest_path, &backup);
527 }
528
529 std::fs::write(&dest_path, body).map_err(|e| {
530 RepError::NetworkRestoreError(format!(
531 "writing '{}': {}",
532 dest_path.display(),
533 e
534 ))
535 })?;
536 off += file_size + 4;
537 total_bytes += file_size as u64;
538 files_done += 1;
539 self.update_progress(total_bytes, files_done);
540 self.update_elapsed(started_at.elapsed());
541 }
542
543 self.update_elapsed(started_at.elapsed());
544 self.complete()?;
545
546 log::info!(
547 "NetworkRestore via dispatcher from {}: {} file(s), {} bytes in {:?}",
548 addr,
549 files_done,
550 total_bytes,
551 started_at.elapsed(),
552 );
553 Ok(())
554 }
555
556 pub fn start(&self) -> Result<()> {
567 let mut state = self.state.lock();
568 match *state {
569 RestoreState::NotStarted => {
570 *state = RestoreState::InProgress;
571 let mut progress = self.progress.lock();
572 progress.state = RestoreState::InProgress;
573 Ok(())
574 }
575 RestoreState::Completed => Err(RepError::NetworkRestoreError(
576 "restore already completed".into(),
577 )),
578 RestoreState::Failed => Err(RepError::NetworkRestoreError(
579 "restore already failed; create a new instance".into(),
580 )),
581 RestoreState::InProgress => Err(RepError::NetworkRestoreError(
582 "restore already in progress".into(),
583 )),
584 }
585 }
586
587 pub fn update_progress(&self, bytes: u64, files: u32) {
593 let mut progress = self.progress.lock();
594 progress.bytes_transferred = bytes;
595 progress.files_transferred = files;
596 }
597
598 pub fn update_elapsed(&self, elapsed: Duration) {
600 let mut progress = self.progress.lock();
601 progress.elapsed = elapsed;
602 }
603
604 pub fn complete(&self) -> Result<()> {
606 let mut state = self.state.lock();
607 match *state {
608 RestoreState::InProgress => {
609 *state = RestoreState::Completed;
610 let mut progress = self.progress.lock();
611 progress.state = RestoreState::Completed;
612 Ok(())
613 }
614 other => Err(RepError::NetworkRestoreError(format!(
615 "cannot complete from state {:?}",
616 other
617 ))),
618 }
619 }
620
621 pub fn fail(&self) -> Result<()> {
623 let mut state = self.state.lock();
624 match *state {
625 RestoreState::InProgress => {
626 *state = RestoreState::Failed;
627 let mut progress = self.progress.lock();
628 progress.state = RestoreState::Failed;
629 Ok(())
630 }
631 other => Err(RepError::NetworkRestoreError(format!(
632 "cannot fail from state {:?}",
633 other
634 ))),
635 }
636 }
637}
638
#[cfg(test)]
mod tests {
    use super::*;

    /// A config pointing at a non-local address; only used by tests that
    /// never touch the network.
    fn test_config() -> NetworkRestoreConfig {
        NetworkRestoreConfig {
            source_node: "node1".into(),
            source_host: "192.168.1.10".into(),
            source_port: 5001,
            retain_log_files: false,
        }
    }

    #[test]
    fn test_initial_state() {
        let restore = NetworkRestore::new(test_config());
        assert_eq!(restore.get_state(), RestoreState::NotStarted);

        let progress = restore.get_progress();
        assert_eq!(progress.state, RestoreState::NotStarted);
        assert_eq!(progress.bytes_transferred, 0);
        assert_eq!(progress.files_transferred, 0);
        assert_eq!(progress.elapsed, Duration::ZERO);
    }

    #[test]
    fn test_start() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        assert_eq!(restore.get_state(), RestoreState::InProgress);
        assert_eq!(restore.get_progress().state, RestoreState::InProgress);
    }

    #[test]
    fn test_start_twice_fails() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        let result = restore.start();
        assert!(result.is_err());
    }

    #[test]
    fn test_update_progress() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();

        restore.update_progress(1024 * 1024, 3);
        let progress = restore.get_progress();
        assert_eq!(progress.bytes_transferred, 1024 * 1024);
        assert_eq!(progress.files_transferred, 3);
    }

    #[test]
    fn test_update_elapsed() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();

        let elapsed = Duration::from_secs(42);
        restore.update_elapsed(elapsed);
        assert_eq!(restore.get_progress().elapsed, elapsed);
    }

    #[test]
    fn test_complete() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        restore.complete().unwrap();
        assert_eq!(restore.get_state(), RestoreState::Completed);
        assert_eq!(restore.get_progress().state, RestoreState::Completed);
    }

    #[test]
    fn test_complete_from_not_started_fails() {
        let restore = NetworkRestore::new(test_config());
        let result = restore.complete();
        assert!(result.is_err());
    }

    #[test]
    fn test_fail() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        restore.fail().unwrap();
        assert_eq!(restore.get_state(), RestoreState::Failed);
        assert_eq!(restore.get_progress().state, RestoreState::Failed);
    }

    #[test]
    fn test_fail_from_not_started_fails() {
        let restore = NetworkRestore::new(test_config());
        let result = restore.fail();
        assert!(result.is_err());
    }

    #[test]
    fn test_start_after_completed_fails() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        restore.complete().unwrap();
        let result = restore.start();
        assert!(result.is_err());
    }

    #[test]
    fn test_start_after_failed_fails() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        restore.fail().unwrap();
        let result = restore.start();
        assert!(result.is_err());
    }

    #[test]
    fn test_config_accessor() {
        let config = test_config();
        let restore = NetworkRestore::new(config);
        assert_eq!(restore.get_config().source_node, "node1");
        assert_eq!(restore.get_config().source_host, "192.168.1.10");
        assert_eq!(restore.get_config().source_port, 5001);
        assert!(!restore.get_config().retain_log_files);
    }

    #[test]
    fn test_retain_log_files_config() {
        let mut config = test_config();
        config.retain_log_files = true;
        let restore = NetworkRestore::new(config);
        assert!(restore.get_config().retain_log_files);
    }

    #[test]
    fn test_full_lifecycle() {
        let restore = NetworkRestore::new(test_config());

        assert_eq!(restore.get_state(), RestoreState::NotStarted);

        restore.start().unwrap();
        assert_eq!(restore.get_state(), RestoreState::InProgress);

        restore.update_progress(512, 1);
        restore.update_progress(2048, 2);
        restore.update_elapsed(Duration::from_secs(5));

        let progress = restore.get_progress();
        assert_eq!(progress.bytes_transferred, 2048);
        assert_eq!(progress.files_transferred, 2);
        assert_eq!(progress.elapsed, Duration::from_secs(5));

        restore.complete().unwrap();
        assert_eq!(restore.get_state(), RestoreState::Completed);
    }

    #[test]
    fn test_fail_lifecycle() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        restore.update_progress(256, 1);
        restore.fail().unwrap();

        assert_eq!(restore.get_state(), RestoreState::Failed);
        let progress = restore.get_progress();
        // Progress recorded before the failure is kept.
        assert_eq!(progress.bytes_transferred, 256);
        assert_eq!(progress.files_transferred, 1);
    }

    /// Asserts that `name` is rejected with an "unsafe filename" ProtocolError.
    fn assert_unsafe(name: &str) {
        // Match rather than `expect_err(&format!(..))`, which would build the
        // panic message even when the check passes.
        match validate_restore_filename(name) {
            Ok(()) => panic!("expected rejection for {:?}", name),
            Err(RepError::ProtocolError(msg)) => assert!(
                msg.contains("unsafe filename"),
                "unexpected message for {:?}: {}",
                name,
                msg
            ),
            Err(other) => {
                panic!("expected ProtocolError for {:?}, got {:?}", name, other)
            }
        }
    }

    #[test]
    fn test_validate_filename_rejects_empty() {
        assert_unsafe("");
    }

    #[test]
    fn test_validate_filename_rejects_dot_and_dotdot() {
        assert_unsafe(".");
        assert_unsafe("..");
    }

    #[test]
    fn test_validate_filename_rejects_hidden_dotfile() {
        assert_unsafe(".bashrc");
        assert_unsafe(".hidden");
        assert_unsafe("...");
    }

    #[test]
    fn test_validate_filename_rejects_path_separators() {
        assert_unsafe("../etc/passwd");
        assert_unsafe("/etc/passwd");
        assert_unsafe("subdir/file.ndb");
        assert_unsafe("dir\\file.ndb");
        assert_unsafe("..\\windows\\system32");
    }

    #[test]
    fn test_validate_filename_rejects_null_byte() {
        assert_unsafe("good\0name.ndb");
        assert_unsafe("\0");
    }

    #[test]
    fn test_validate_filename_accepts_normal_log_files() {
        validate_restore_filename("00000000.ndb").unwrap();
        validate_restore_filename("00000001.ndb").unwrap();
        validate_restore_filename("ffffffff.ndb").unwrap();
        validate_restore_filename("data.bin").unwrap();
        validate_restore_filename("name-with-dashes_and_underscores.ndb")
            .unwrap();
    }

    /// One file record as a source sends it: name, body, advertised CRC32.
    type WireFile = (&'static str, Vec<u8>, u32);

    /// Spawns a one-shot restore source on 127.0.0.1 that checks the client's
    /// magic and then serves `files` using the direct-TCP wire format.
    /// Returns the bound port and the server thread.
    fn serve_once(files: Vec<WireFile>) -> (u16, std::thread::JoinHandle<()>) {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let handle = std::thread::spawn(move || {
            let (mut conn, _) = listener.accept().unwrap();
            let mut magic = [0u8; 4];
            conn.read_exact(&mut magic).unwrap();
            assert_eq!(u32::from_le_bytes(magic), RESTORE_MAGIC);
            conn.write_all(&(files.len() as u32).to_le_bytes()).unwrap();
            for (name, body, crc) in &files {
                conn.write_all(&(name.len() as u16).to_le_bytes()).unwrap();
                conn.write_all(name.as_bytes()).unwrap();
                conn.write_all(&(body.len() as u64).to_le_bytes()).unwrap();
                conn.write_all(body).unwrap();
                conn.write_all(&crc.to_le_bytes()).unwrap();
            }
        });
        (port, handle)
    }

    /// Creates an empty scratch directory unique to this process and `tag`.
    fn scratch_dir(tag: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "network-restore-test-{}-{}",
            tag,
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// A restore pointed at the loopback source on `port`, writing into `dir`.
    fn loopback_restore(port: u16, dir: &std::path::Path) -> NetworkRestore {
        let mut config = test_config();
        config.source_host = "127.0.0.1".into();
        config.source_port = port;
        NetworkRestore::new(config).with_local_dir(dir.to_path_buf())
    }

    #[test]
    fn test_execute_receives_file_over_tcp() {
        let body = b"restored log bytes".to_vec();
        let crc = crc32fast::hash(&body);
        let (port, server) = serve_once(vec![("00000001.ndb", body.clone(), crc)]);
        let dir = scratch_dir("ok");

        let restore = loopback_restore(port, &dir);
        restore.execute().unwrap();
        server.join().unwrap();

        assert_eq!(std::fs::read(dir.join("00000001.ndb")).unwrap(), body);
        assert_eq!(restore.get_state(), RestoreState::Completed);
        let progress = restore.get_progress();
        assert_eq!(progress.files_transferred, 1);
        assert_eq!(progress.bytes_transferred, body.len() as u64);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_execute_rejects_digest_mismatch() {
        let body = b"corrupted in transit".to_vec();
        let bad_crc = crc32fast::hash(&body) ^ 1;
        let (port, server) = serve_once(vec![("00000002.ndb", body, bad_crc)]);
        let dir = scratch_dir("crc");

        let restore = loopback_restore(port, &dir);
        let err = restore.execute().unwrap_err();
        server.join().unwrap();

        match err {
            RepError::NetworkRestoreError(msg) => {
                assert!(msg.contains("digest mismatch"), "unexpected message: {}", msg)
            }
            other => panic!("expected NetworkRestoreError, got {:?}", other),
        }
        // A file that failed verification must not be left behind.
        assert!(!dir.join("00000002.ndb").exists());

        let _ = std::fs::remove_dir_all(&dir);
    }
}