1use std::io::{Read as IoRead, Write as IoWrite};
10use std::net::TcpStream;
11use std::path::PathBuf;
12use std::time::{Duration, Instant};
13
14use noxu_sync::Mutex;
15
16use crate::error::{RepError, Result};
17
/// Handshake word a restore client sends before anything else.
///
/// Read as a big-endian `u32` it spells ASCII `"NRST"` (`0x4E52_5354`);
/// `NetworkRestore::execute` and `NetworkRestore::execute_via_dispatcher`
/// both put it on the wire with `to_le_bytes`.
const RESTORE_MAGIC: u32 = u32::from_be_bytes([b'N', b'R', b'S', b'T']);
22
/// Where to restore log files from, and how to treat files already present
/// locally.
#[derive(Clone, Debug)]
pub struct NetworkRestoreConfig {
    /// Identifier of the node being restored from. The transfer itself
    /// connects using `source_host` / `source_port`.
    pub source_node: String,
    /// Host name or IP address of the restore source.
    pub source_host: String,
    /// TCP port of the restore source.
    pub source_port: u16,
    /// When `true`, a local file with the same name as an incoming one is
    /// renamed to `<name>.bak` before the incoming file is written.
    pub retain_log_files: bool,
}
38
/// Lifecycle of a [`NetworkRestore`].
///
/// `NotStarted -> InProgress` via `start`, then `InProgress -> Completed`
/// via `complete` or `InProgress -> Failed` via `fail`. `Completed` and
/// `Failed` are terminal: `start` refuses to leave them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RestoreState {
    /// Initial state; no transfer has begun.
    NotStarted,
    /// `start` has succeeded and the transfer has not yet finished.
    InProgress,
    /// The transfer finished successfully.
    Completed,
    /// The transfer was marked as failed.
    Failed,
}
51
52#[derive(Debug, Clone)]
54pub struct RestoreProgress {
55 pub state: RestoreState,
57 pub bytes_transferred: u64,
59 pub files_transferred: u32,
61 pub elapsed: Duration,
63}
64
65pub struct NetworkRestore {
70 config: NetworkRestoreConfig,
72 state: Mutex<RestoreState>,
74 progress: Mutex<RestoreProgress>,
76 local_log_dir: Option<PathBuf>,
80}
81
82fn validate_restore_filename(name: &str) -> Result<()> {
99 if name.is_empty() {
100 return Err(RepError::ProtocolError("unsafe filename: empty".into()));
101 }
102 if name == "." || name == ".." {
103 return Err(RepError::ProtocolError(format!(
104 "unsafe filename: {:?}",
105 name
106 )));
107 }
108 if name.starts_with('.') {
109 return Err(RepError::ProtocolError(format!(
110 "unsafe filename: hidden dotfile {:?}",
111 name
112 )));
113 }
114 for b in name.as_bytes() {
115 match *b {
116 b'/' | b'\\' => {
117 return Err(RepError::ProtocolError(format!(
118 "unsafe filename: path separator in {:?}",
119 name
120 )));
121 }
122 0 => {
123 return Err(RepError::ProtocolError(format!(
124 "unsafe filename: null byte in {:?}",
125 name
126 )));
127 }
128 _ => {}
129 }
130 }
131 Ok(())
132}
133
134impl NetworkRestore {
135 pub fn new(config: NetworkRestoreConfig) -> Self {
137 Self {
138 config,
139 state: Mutex::new(RestoreState::NotStarted),
140 progress: Mutex::new(RestoreProgress {
141 state: RestoreState::NotStarted,
142 bytes_transferred: 0,
143 files_transferred: 0,
144 elapsed: Duration::ZERO,
145 }),
146 local_log_dir: None,
147 }
148 }
149
150 pub fn with_local_dir(mut self, dir: impl Into<PathBuf>) -> Self {
154 self.local_log_dir = Some(dir.into());
155 self
156 }
157
158 pub fn get_state(&self) -> RestoreState {
160 *self.state.lock()
161 }
162
163 pub fn get_progress(&self) -> RestoreProgress {
165 self.progress.lock().clone()
166 }
167
168 pub fn get_config(&self) -> &NetworkRestoreConfig {
170 &self.config
171 }
172
173 pub fn execute(&self) -> Result<()> {
190 {
192 let state = self.state.lock();
193 if *state != RestoreState::NotStarted {
194 return Err(RepError::NetworkRestoreError(format!(
195 "execute called in wrong state: {:?}",
196 *state
197 )));
198 }
199 }
200
201 self.start()?;
203
204 let started_at = Instant::now();
205 let addr =
206 format!("{}:{}", self.config.source_host, self.config.source_port);
207
208 let mut stream = TcpStream::connect(&addr).map_err(|e| {
210 RepError::NetworkRestoreError(format!(
211 "cannot connect to source {}: {}",
212 addr, e
213 ))
214 })?;
215
216 let _ = stream.set_read_timeout(Some(Duration::from_secs(120)));
218
219 stream.write_all(&RESTORE_MAGIC.to_le_bytes()).map_err(|e| {
221 RepError::NetworkRestoreError(format!(
222 "sending restore magic: {}",
223 e
224 ))
225 })?;
226
227 let mut count_buf = [0u8; 4];
229 stream.read_exact(&mut count_buf).map_err(|e| {
230 RepError::NetworkRestoreError(format!("reading file count: {}", e))
231 })?;
232 let file_count = u32::from_le_bytes(count_buf);
233
234 let log_dir = self.local_log_dir.clone().unwrap_or_else(|| {
235 std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))
236 });
237
238 let mut total_bytes: u64 = 0;
239 let mut files_done: u32 = 0;
240
241 for _ in 0..file_count {
242 let mut name_len_buf = [0u8; 2];
244 stream.read_exact(&mut name_len_buf).map_err(|e| {
245 RepError::NetworkRestoreError(format!(
246 "reading filename length: {}",
247 e
248 ))
249 })?;
250 let name_len = u16::from_le_bytes(name_len_buf) as usize;
251
252 let mut name_buf = vec![0u8; name_len];
253 stream.read_exact(&mut name_buf).map_err(|e| {
254 RepError::NetworkRestoreError(format!(
255 "reading filename: {}",
256 e
257 ))
258 })?;
259 let filename = String::from_utf8(name_buf).map_err(|e| {
260 RepError::NetworkRestoreError(format!(
261 "non-UTF8 filename: {}",
262 e
263 ))
264 })?;
265 validate_restore_filename(&filename)?;
266
267 let mut size_buf = [0u8; 8];
269 stream.read_exact(&mut size_buf).map_err(|e| {
270 RepError::NetworkRestoreError(format!(
271 "reading file size for '{}': {}",
272 filename, e
273 ))
274 })?;
275 let file_size = u64::from_le_bytes(size_buf);
276
277 let dest_path = log_dir.join(&filename);
281 if self.config.retain_log_files && dest_path.exists() {
282 let backup = log_dir.join(format!("{}.bak", filename));
283 let _ = std::fs::rename(&dest_path, &backup);
284 }
285
286 let mut out = std::fs::File::create(&dest_path).map_err(|e| {
288 RepError::NetworkRestoreError(format!(
289 "creating '{}': {}",
290 dest_path.display(),
291 e
292 ))
293 })?;
294
295 let mut remaining = file_size;
296 let mut chunk = vec![0u8; 65536];
297 while remaining > 0 {
298 let to_read = (remaining as usize).min(chunk.len());
299 stream.read_exact(&mut chunk[..to_read]).map_err(|e| {
300 RepError::NetworkRestoreError(format!(
301 "reading data for '{}': {}",
302 filename, e
303 ))
304 })?;
305 out.write_all(&chunk[..to_read]).map_err(|e| {
306 RepError::NetworkRestoreError(format!(
307 "writing '{}': {}",
308 dest_path.display(),
309 e
310 ))
311 })?;
312 remaining -= to_read as u64;
313 total_bytes += to_read as u64;
314 }
315
316 files_done += 1;
317 self.update_progress(total_bytes, files_done);
318 self.update_elapsed(started_at.elapsed());
319
320 log::debug!(
321 "NetworkRestore: received '{}' ({} bytes)",
322 filename,
323 file_size
324 );
325 }
326
327 self.update_elapsed(started_at.elapsed());
328 self.complete()?;
329
330 log::info!(
331 "NetworkRestore from {}: {} file(s), {} bytes transferred in {:?}",
332 addr,
333 files_done,
334 total_bytes,
335 started_at.elapsed(),
336 );
337
338 Ok(())
339 }
340
341 pub fn execute_via_dispatcher(&self) -> Result<()> {
361 use crate::net::Channel;
362 use crate::net::service_dispatcher::connect_to_service;
363 use crate::network_restore_server::RESTORE_SERVICE_NAME;
364
365 {
367 let state = self.state.lock();
368 if *state != RestoreState::NotStarted {
369 return Err(RepError::NetworkRestoreError(format!(
370 "execute_via_dispatcher called in wrong state: {:?}",
371 *state
372 )));
373 }
374 }
375
376 self.start()?;
377 let started_at = Instant::now();
378
379 let addr_str =
380 format!("{}:{}", self.config.source_host, self.config.source_port);
381 let addr: std::net::SocketAddr = addr_str.parse().map_err(|e| {
382 RepError::NetworkRestoreError(format!(
383 "bad source address {}: {}",
384 addr_str, e
385 ))
386 })?;
387
388 let channel =
389 connect_to_service(addr, RESTORE_SERVICE_NAME).map_err(|e| {
390 RepError::NetworkRestoreError(format!(
391 "connect_to_service(RESTORE) at {}: {}",
392 addr, e
393 ))
394 })?;
395
396 channel.send(&RESTORE_MAGIC.to_le_bytes()).map_err(|e| {
398 RepError::NetworkRestoreError(format!(
399 "sending restore magic via dispatcher: {}",
400 e
401 ))
402 })?;
403
404 let payload = channel
406 .receive(Duration::from_secs(120))
407 .map_err(|e| {
408 RepError::NetworkRestoreError(format!(
409 "receiving restore payload: {}",
410 e
411 ))
412 })?
413 .ok_or_else(|| {
414 RepError::NetworkRestoreError(
415 "empty restore payload from dispatcher".to_string(),
416 )
417 })?;
418
419 if payload.len() < 4 {
421 return Err(RepError::NetworkRestoreError(format!(
422 "truncated restore payload: {} bytes",
423 payload.len()
424 )));
425 }
426 let mut off = 0usize;
427 let mut buf4 = [0u8; 4];
428 buf4.copy_from_slice(&payload[off..off + 4]);
429 off += 4;
430 let file_count = u32::from_le_bytes(buf4);
431
432 let log_dir = self.local_log_dir.clone().unwrap_or_else(|| {
433 std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))
434 });
435 std::fs::create_dir_all(&log_dir).map_err(|e| {
436 RepError::NetworkRestoreError(format!(
437 "creating log dir {}: {}",
438 log_dir.display(),
439 e
440 ))
441 })?;
442
443 let mut total_bytes: u64 = 0;
444 let mut files_done: u32 = 0;
445 let mut buf2 = [0u8; 2];
446 let mut buf8 = [0u8; 8];
447
448 for _ in 0..file_count {
449 if off + 2 > payload.len() {
450 return Err(RepError::NetworkRestoreError(
451 "truncated restore payload at name_len".to_string(),
452 ));
453 }
454 buf2.copy_from_slice(&payload[off..off + 2]);
455 off += 2;
456 let name_len = u16::from_le_bytes(buf2) as usize;
457 if off + name_len + 8 > payload.len() {
458 return Err(RepError::NetworkRestoreError(
459 "truncated restore payload at name+size".to_string(),
460 ));
461 }
462 let name_bytes = payload[off..off + name_len].to_vec();
463 off += name_len;
464 let filename = String::from_utf8(name_bytes).map_err(|e| {
465 RepError::NetworkRestoreError(format!(
466 "non-UTF8 filename: {}",
467 e
468 ))
469 })?;
470 validate_restore_filename(&filename)?;
471
472 buf8.copy_from_slice(&payload[off..off + 8]);
473 off += 8;
474 let file_size = u64::from_le_bytes(buf8) as usize;
475 if off + file_size > payload.len() {
476 return Err(RepError::NetworkRestoreError(format!(
477 "truncated restore payload at file body for {:?} \
478 (need {} bytes, have {})",
479 filename,
480 file_size,
481 payload.len() - off,
482 )));
483 }
484
485 let dest_path = log_dir.join(&filename);
486 if self.config.retain_log_files && dest_path.exists() {
487 let backup = log_dir.join(format!("{}.bak", filename));
488 let _ = std::fs::rename(&dest_path, &backup);
489 }
490
491 std::fs::write(&dest_path, &payload[off..off + file_size])
492 .map_err(|e| {
493 RepError::NetworkRestoreError(format!(
494 "writing '{}': {}",
495 dest_path.display(),
496 e
497 ))
498 })?;
499 off += file_size;
500 total_bytes += file_size as u64;
501 files_done += 1;
502 self.update_progress(total_bytes, files_done);
503 self.update_elapsed(started_at.elapsed());
504 }
505
506 self.update_elapsed(started_at.elapsed());
507 self.complete()?;
508
509 log::info!(
510 "NetworkRestore via dispatcher from {}: {} file(s), {} bytes in {:?}",
511 addr,
512 files_done,
513 total_bytes,
514 started_at.elapsed(),
515 );
516 Ok(())
517 }
518
519 pub fn start(&self) -> Result<()> {
530 let mut state = self.state.lock();
531 match *state {
532 RestoreState::NotStarted => {
533 *state = RestoreState::InProgress;
534 let mut progress = self.progress.lock();
535 progress.state = RestoreState::InProgress;
536 Ok(())
537 }
538 RestoreState::Completed => Err(RepError::NetworkRestoreError(
539 "restore already completed".into(),
540 )),
541 RestoreState::Failed => Err(RepError::NetworkRestoreError(
542 "restore already failed; create a new instance".into(),
543 )),
544 RestoreState::InProgress => Err(RepError::NetworkRestoreError(
545 "restore already in progress".into(),
546 )),
547 }
548 }
549
550 pub fn update_progress(&self, bytes: u64, files: u32) {
556 let mut progress = self.progress.lock();
557 progress.bytes_transferred = bytes;
558 progress.files_transferred = files;
559 }
560
561 pub fn update_elapsed(&self, elapsed: Duration) {
563 let mut progress = self.progress.lock();
564 progress.elapsed = elapsed;
565 }
566
567 pub fn complete(&self) -> Result<()> {
569 let mut state = self.state.lock();
570 match *state {
571 RestoreState::InProgress => {
572 *state = RestoreState::Completed;
573 let mut progress = self.progress.lock();
574 progress.state = RestoreState::Completed;
575 Ok(())
576 }
577 other => Err(RepError::NetworkRestoreError(format!(
578 "cannot complete from state {:?}",
579 other
580 ))),
581 }
582 }
583
584 pub fn fail(&self) -> Result<()> {
586 let mut state = self.state.lock();
587 match *state {
588 RestoreState::InProgress => {
589 *state = RestoreState::Failed;
590 let mut progress = self.progress.lock();
591 progress.state = RestoreState::Failed;
592 Ok(())
593 }
594 other => Err(RepError::NetworkRestoreError(format!(
595 "cannot fail from state {:?}",
596 other
597 ))),
598 }
599 }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration for the state-machine tests; nothing connects to it.
    fn test_config() -> NetworkRestoreConfig {
        NetworkRestoreConfig {
            source_node: "node1".into(),
            source_host: "192.168.1.10".into(),
            source_port: 5001,
            retain_log_files: false,
        }
    }

    #[test]
    fn test_initial_state() {
        let restore = NetworkRestore::new(test_config());
        assert_eq!(restore.get_state(), RestoreState::NotStarted);

        let progress = restore.get_progress();
        assert_eq!(progress.state, RestoreState::NotStarted);
        assert_eq!(progress.bytes_transferred, 0);
        assert_eq!(progress.files_transferred, 0);
        assert_eq!(progress.elapsed, Duration::ZERO);
    }

    #[test]
    fn test_start() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        assert_eq!(restore.get_state(), RestoreState::InProgress);
        assert_eq!(restore.get_progress().state, RestoreState::InProgress);
    }

    #[test]
    fn test_start_twice_fails() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        let result = restore.start();
        assert!(result.is_err());
    }

    #[test]
    fn test_update_progress() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();

        restore.update_progress(1024 * 1024, 3);
        let progress = restore.get_progress();
        assert_eq!(progress.bytes_transferred, 1024 * 1024);
        assert_eq!(progress.files_transferred, 3);
    }

    #[test]
    fn test_update_elapsed() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();

        let elapsed = Duration::from_secs(42);
        restore.update_elapsed(elapsed);
        assert_eq!(restore.get_progress().elapsed, elapsed);
    }

    #[test]
    fn test_complete() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        restore.complete().unwrap();
        assert_eq!(restore.get_state(), RestoreState::Completed);
        assert_eq!(restore.get_progress().state, RestoreState::Completed);
    }

    #[test]
    fn test_complete_from_not_started_fails() {
        let restore = NetworkRestore::new(test_config());
        let result = restore.complete();
        assert!(result.is_err());
    }

    #[test]
    fn test_fail() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        restore.fail().unwrap();
        assert_eq!(restore.get_state(), RestoreState::Failed);
        assert_eq!(restore.get_progress().state, RestoreState::Failed);
    }

    #[test]
    fn test_fail_from_not_started_fails() {
        let restore = NetworkRestore::new(test_config());
        let result = restore.fail();
        assert!(result.is_err());
    }

    #[test]
    fn test_start_after_completed_fails() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        restore.complete().unwrap();
        let result = restore.start();
        assert!(result.is_err());
    }

    #[test]
    fn test_start_after_failed_fails() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        restore.fail().unwrap();
        let result = restore.start();
        assert!(result.is_err());
    }

    #[test]
    fn test_terminal_states_reject_further_transitions() {
        let completed = NetworkRestore::new(test_config());
        completed.start().unwrap();
        completed.complete().unwrap();
        assert!(completed.complete().is_err());
        assert!(completed.fail().is_err());
        assert_eq!(completed.get_state(), RestoreState::Completed);

        let failed = NetworkRestore::new(test_config());
        failed.start().unwrap();
        failed.fail().unwrap();
        assert!(failed.fail().is_err());
        assert!(failed.complete().is_err());
        assert_eq!(failed.get_state(), RestoreState::Failed);
    }

    #[test]
    fn test_execute_rejects_already_started_restore() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        // The state guard runs before any network I/O, so no source is needed.
        assert!(matches!(
            restore.execute(),
            Err(RepError::NetworkRestoreError(_))
        ));
        assert!(matches!(
            restore.execute_via_dispatcher(),
            Err(RepError::NetworkRestoreError(_))
        ));
        assert_eq!(restore.get_state(), RestoreState::InProgress);
    }

    #[test]
    fn test_config_accessor() {
        let config = test_config();
        let restore = NetworkRestore::new(config);
        assert_eq!(restore.get_config().source_node, "node1");
        assert_eq!(restore.get_config().source_host, "192.168.1.10");
        assert_eq!(restore.get_config().source_port, 5001);
        assert!(!restore.get_config().retain_log_files);
    }

    #[test]
    fn test_retain_log_files_config() {
        let mut config = test_config();
        config.retain_log_files = true;
        let restore = NetworkRestore::new(config);
        assert!(restore.get_config().retain_log_files);
    }

    #[test]
    fn test_full_lifecycle() {
        let restore = NetworkRestore::new(test_config());

        assert_eq!(restore.get_state(), RestoreState::NotStarted);

        restore.start().unwrap();
        assert_eq!(restore.get_state(), RestoreState::InProgress);

        restore.update_progress(512, 1);
        restore.update_progress(2048, 2);
        restore.update_elapsed(Duration::from_secs(5));

        let progress = restore.get_progress();
        assert_eq!(progress.bytes_transferred, 2048);
        assert_eq!(progress.files_transferred, 2);
        assert_eq!(progress.elapsed, Duration::from_secs(5));

        restore.complete().unwrap();
        assert_eq!(restore.get_state(), RestoreState::Completed);
    }

    #[test]
    fn test_fail_lifecycle() {
        let restore = NetworkRestore::new(test_config());
        restore.start().unwrap();
        restore.update_progress(256, 1);
        restore.fail().unwrap();

        assert_eq!(restore.get_state(), RestoreState::Failed);
        // Progress recorded before the failure is kept.
        let progress = restore.get_progress();
        assert_eq!(progress.bytes_transferred, 256);
        assert_eq!(progress.files_transferred, 1);
    }

    /// Serves exactly one restore on an ephemeral localhost port: checks the
    /// restore magic, then replies with `files` in the TCP wire format.
    fn spawn_restore_source(
        files: Vec<(String, Vec<u8>)>,
    ) -> (u16, std::thread::JoinHandle<()>) {
        use std::io::{Read, Write};
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let handle = std::thread::spawn(move || {
            let (mut sock, _) = listener.accept().unwrap();
            let mut magic = [0u8; 4];
            sock.read_exact(&mut magic).unwrap();
            assert_eq!(u32::from_le_bytes(magic), RESTORE_MAGIC);

            let mut reply = Vec::new();
            reply.extend_from_slice(&(files.len() as u32).to_le_bytes());
            for (name, body) in &files {
                reply.extend_from_slice(&(name.len() as u16).to_le_bytes());
                reply.extend_from_slice(name.as_bytes());
                reply.extend_from_slice(&(body.len() as u64).to_le_bytes());
                reply.extend_from_slice(body);
            }
            sock.write_all(&reply).unwrap();
        });
        (port, handle)
    }

    /// Creates a fresh, existing scratch directory under the system temp dir.
    fn scratch_dir(tag: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let dir = std::env::temp_dir().join(format!(
            "noxu-restore-{}-{}-{}",
            tag,
            std::process::id(),
            nanos
        ));
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Builds a restore pointed at a localhost source, writing into `dir`.
    fn restore_from_local_source(
        port: u16,
        dir: &std::path::Path,
        retain: bool,
    ) -> NetworkRestore {
        let mut config = test_config();
        config.source_host = "127.0.0.1".into();
        config.source_port = port;
        config.retain_log_files = retain;
        NetworkRestore::new(config).with_local_dir(dir)
    }

    #[test]
    fn test_execute_receives_files_over_tcp() {
        let files = vec![
            ("00000000.ndb".to_string(), b"first log".to_vec()),
            // Larger than one 64 KiB receive chunk.
            ("00000001.ndb".to_string(), vec![7u8; 70_000]),
            ("empty.ndb".to_string(), Vec::new()),
        ];
        let expected_bytes: u64 = files.iter().map(|(_, body)| body.len() as u64).sum();
        let (port, source) = spawn_restore_source(files.clone());
        let dir = scratch_dir("tcp");

        let restore = restore_from_local_source(port, &dir, false);
        let result = restore.execute();
        source.join().unwrap();
        let written: Vec<Vec<u8>> = files
            .iter()
            .map(|(name, _)| std::fs::read(dir.join(name)).unwrap_or_default())
            .collect();
        let _ = std::fs::remove_dir_all(&dir);

        result.unwrap();
        assert_eq!(restore.get_state(), RestoreState::Completed);
        let progress = restore.get_progress();
        assert_eq!(progress.files_transferred, 3);
        assert_eq!(progress.bytes_transferred, expected_bytes);
        for ((name, body), got) in files.iter().zip(&written) {
            assert_eq!(got, body, "content mismatch for {}", name);
        }
    }

    #[test]
    fn test_execute_retains_existing_file_as_bak() {
        let dir = scratch_dir("retain");
        std::fs::write(dir.join("00000000.ndb"), b"old").unwrap();
        let (port, source) =
            spawn_restore_source(vec![("00000000.ndb".to_string(), b"new".to_vec())]);

        let restore = restore_from_local_source(port, &dir, true);
        let result = restore.execute();
        source.join().unwrap();
        let current = std::fs::read(dir.join("00000000.ndb")).unwrap_or_default();
        let backup = std::fs::read(dir.join("00000000.ndb.bak")).unwrap_or_default();
        let _ = std::fs::remove_dir_all(&dir);

        result.unwrap();
        assert_eq!(current, b"new");
        assert_eq!(backup, b"old");
    }

    /// Asserts that `name` is rejected with a `ProtocolError` whose message
    /// mentions "unsafe filename".
    fn assert_unsafe(name: &str) {
        match validate_restore_filename(name) {
            Ok(()) => panic!("expected rejection for {:?}", name),
            Err(RepError::ProtocolError(msg)) => assert!(
                msg.contains("unsafe filename"),
                "unexpected message for {:?}: {}",
                name,
                msg
            ),
            Err(other) => {
                panic!("expected ProtocolError for {:?}, got {:?}", name, other)
            }
        }
    }

    #[test]
    fn test_validate_filename_rejects_empty() {
        assert_unsafe("");
    }

    #[test]
    fn test_validate_filename_rejects_dot_and_dotdot() {
        assert_unsafe(".");
        assert_unsafe("..");
    }

    #[test]
    fn test_validate_filename_rejects_hidden_dotfile() {
        assert_unsafe(".bashrc");
        assert_unsafe(".hidden");
    }

    #[test]
    fn test_validate_filename_rejects_path_separators() {
        assert_unsafe("../etc/passwd");
        assert_unsafe("/etc/passwd");
        assert_unsafe("subdir/file.ndb");
        assert_unsafe("dir\\file.ndb");
        assert_unsafe("..\\windows\\system32");
    }

    #[test]
    fn test_validate_filename_rejects_null_byte() {
        assert_unsafe("good\0name.ndb");
        assert_unsafe("\0");
    }

    #[test]
    fn test_validate_filename_accepts_normal_log_files() {
        validate_restore_filename("00000000.ndb").unwrap();
        validate_restore_filename("00000001.ndb").unwrap();
        validate_restore_filename("ffffffff.ndb").unwrap();
        validate_restore_filename("data.bin").unwrap();
        validate_restore_filename("name-with-dashes_and_underscores.ndb").unwrap();
    }
}