use std::io::{Read as IoRead, Write as IoWrite};
use std::net::TcpStream;
use std::path::PathBuf;
use std::time::{Duration, Instant};
use noxu_sync::Mutex;
use crate::error::{RepError, Result};
/// Magic value a client sends to open a restore session. Read most-significant
/// byte first, the hex digits spell ASCII `NRST`; the restore clients in this
/// file transmit it with `to_le_bytes()`.
const RESTORE_MAGIC: u32 = 0x4E52_5354;
/// Where to fetch files from during a network restore, and how to treat local
/// files that an incoming file would overwrite.
#[derive(Debug, Clone)]
pub struct NetworkRestoreConfig {
/// Identifier of the source node. The transfer code in this file does not
/// read it (NOTE(review): presumably for callers/diagnostics — confirm).
pub source_node: String,
/// Host part of the `"{host}:{port}"` address used to reach the source.
/// `NetworkRestore::execute` hands that string to `TcpStream::connect`, while
/// `NetworkRestore::execute_via_dispatcher` parses it as a `SocketAddr`, so on
/// the dispatcher path it must be a literal IP address rather than a hostname.
pub source_host: String,
/// Port part of the source address.
pub source_port: u16,
/// When true, an existing local file with the same name as an incoming file
/// is renamed to `<name>.bak` before the incoming file replaces it.
pub retain_log_files: bool,
}
/// Lifecycle of a [`NetworkRestore`] session.
///
/// `NetworkRestore` allows `NotStarted -> InProgress` (via `start`) and then
/// `InProgress -> Completed` (via `complete`) or `InProgress -> Failed` (via
/// `fail`). `start` rejects both `Completed` and `Failed`, so those states are
/// terminal for a given instance.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RestoreState {
/// Freshly constructed; `start` has not succeeded yet.
NotStarted,
/// `start` succeeded and neither `complete` nor `fail` has been called.
InProgress,
/// `complete` was called while in progress.
Completed,
/// `fail` was called while in progress.
Failed,
}
/// Snapshot of a session's progress, as returned by
/// `NetworkRestore::get_progress`.
#[derive(Debug, Clone)]
pub struct RestoreProgress {
/// Mirror of the session state, kept in sync by `start`, `complete` and `fail`.
pub state: RestoreState,
/// Value last passed to `update_progress`; the execute paths pass the running
/// total of bytes of files received so far, once per completed file.
pub bytes_transferred: u64,
/// Value last passed to `update_progress`; the execute paths pass the number
/// of files received so far.
pub files_transferred: u32,
/// Value last passed to `update_elapsed`; the execute paths record the time
/// since the transfer began.
pub elapsed: Duration,
}
/// A single-use network restore session that pulls files from a source node
/// into a local directory.
///
/// State and progress sit behind `noxu_sync::Mutex`, so every method takes
/// `&self`. Once the session reaches `Completed` or `Failed`, `start` refuses
/// to run it again; a new instance is needed.
pub struct NetworkRestore {
/// Source address and overwrite policy.
config: NetworkRestoreConfig,
/// Authoritative lifecycle state, checked and advanced by `start`,
/// `complete` and `fail`.
state: Mutex<RestoreState>,
/// Counters reported by `get_progress`; its `state` field mirrors `state`.
progress: Mutex<RestoreProgress>,
/// Destination directory for received files. `None` means the process's
/// current working directory (or `.` if that cannot be determined).
local_log_dir: Option<PathBuf>,
}
/// Rejects a peer-supplied filename unless it is a plain, single-component
/// name that is safe to join onto the local restore directory.
///
/// A name is rejected when it:
/// - is empty, or exactly `.` or `..`;
/// - starts with `.` (hidden files, and any `..`-prefixed name);
/// - contains `/` or `\` (path separators on Unix and Windows);
/// - contains `:`;
/// - contains a NUL byte.
///
/// `:` is rejected because on Windows `Path::join` with a drive-relative name
/// such as `C:evil.ndb` (a prefix without a root) discards the base directory
/// entirely, and `name:stream` addresses an NTFS alternate data stream. Either
/// would let a malicious source write outside the restore directory.
///
/// # Errors
/// Returns `RepError::ProtocolError` whose message starts with
/// `"unsafe filename"`.
fn validate_restore_filename(name: &str) -> Result<()> {
    if name.is_empty() {
        return Err(RepError::ProtocolError("unsafe filename: empty".into()));
    }
    if name == "." || name == ".." {
        return Err(RepError::ProtocolError(format!(
            "unsafe filename: {:?}",
            name
        )));
    }
    if name.starts_with('.') {
        return Err(RepError::ProtocolError(format!(
            "unsafe filename: hidden dotfile {:?}",
            name
        )));
    }
    for b in name.as_bytes() {
        match *b {
            b'/' | b'\\' => {
                return Err(RepError::ProtocolError(format!(
                    "unsafe filename: path separator in {:?}",
                    name
                )));
            }
            b':' => {
                return Err(RepError::ProtocolError(format!(
                    "unsafe filename: drive or stream separator in {:?}",
                    name
                )));
            }
            0 => {
                return Err(RepError::ProtocolError(format!(
                    "unsafe filename: null byte in {:?}",
                    name
                )));
            }
            _ => {}
        }
    }
    Ok(())
}
impl NetworkRestore {
    /// Creates a session for `config` in [`RestoreState::NotStarted`], with all
    /// progress counters zeroed and no explicit destination directory.
    pub fn new(config: NetworkRestoreConfig) -> Self {
        Self {
            config,
            state: Mutex::new(RestoreState::NotStarted),
            progress: Mutex::new(RestoreProgress {
                state: RestoreState::NotStarted,
                bytes_transferred: 0,
                files_transferred: 0,
                elapsed: Duration::ZERO,
            }),
            local_log_dir: None,
        }
    }

    /// Sets the directory that received files are written into.
    ///
    /// Without this, the process's current working directory is used (or `.`
    /// if it cannot be determined).
    pub fn with_local_dir(mut self, dir: impl Into<PathBuf>) -> Self {
        self.local_log_dir = Some(dir.into());
        self
    }

    /// Returns the current lifecycle state.
    pub fn get_state(&self) -> RestoreState {
        *self.state.lock()
    }

    /// Returns a snapshot copy of the progress counters.
    pub fn get_progress(&self) -> RestoreProgress {
        self.progress.lock().clone()
    }

    /// Returns the configuration this session was created with.
    pub fn get_config(&self) -> &NetworkRestoreConfig {
        &self.config
    }

    /// Runs a restore over a raw TCP connection to `source_host:source_port`.
    ///
    /// Wire protocol (all integers little-endian):
    /// - The client sends `RESTORE_MAGIC` (u32).
    /// - The server replies with a file count (u32).
    /// - Then, for each file: name length (u16), UTF-8 name, body length
    ///   (u64), body bytes, and a CRC-32 of the body (u32).
    ///
    /// Each name is checked with `validate_restore_filename`. Each body is
    /// streamed to `<log_dir>/<name>` in 64 KiB chunks, and the log directory
    /// is created if missing. A file that fails mid-transfer, or fails its
    /// digest check, is removed rather than left truncated on disk.
    ///
    /// # Errors
    /// Returns an error if:
    /// - the session is not in `NotStarted`;
    /// - the connection, or any read or write, fails;
    /// - a filename is unsafe;
    /// - a digest does not match.
    ///
    /// Every error raised after the session has started also moves it to
    /// [`RestoreState::Failed`].
    pub fn execute(&self) -> Result<()> {
        self.ensure_not_started("execute")?;
        self.start()?;
        let started_at = Instant::now();
        let addr =
            format!("{}:{}", self.config.source_host, self.config.source_port);
        let (files_done, total_bytes) =
            self.fail_on_err(self.receive_via_tcp(&addr, started_at), started_at)?;
        self.update_elapsed(started_at.elapsed());
        self.complete()?;
        log::info!(
            "NetworkRestore from {}: {} file(s), {} bytes transferred in {:?}",
            addr,
            files_done,
            total_bytes,
            started_at.elapsed(),
        );
        Ok(())
    }

    /// Runs a restore through the service dispatcher's RESTORE service.
    ///
    /// The source address must parse as a `SocketAddr`. After sending
    /// `RESTORE_MAGIC`, a single payload is received, holding:
    /// - a file count (u32);
    /// - then, for each file: name length (u16), name, body length (u64),
    ///   body, and CRC-32 (u32), all little-endian.
    ///
    /// Each body's digest is verified before anything is written, so a corrupt
    /// file never reaches disk.
    ///
    /// # Errors
    /// Returns an error if:
    /// - the session is not in `NotStarted`;
    /// - the address is invalid;
    /// - the dispatcher exchange fails;
    /// - the payload is truncated or declares a body larger than what remains;
    /// - a filename is unsafe;
    /// - a digest does not match;
    /// - a write fails.
    ///
    /// Every error raised after the session has started also moves it to
    /// [`RestoreState::Failed`].
    pub fn execute_via_dispatcher(&self) -> Result<()> {
        self.ensure_not_started("execute_via_dispatcher")?;
        self.start()?;
        let started_at = Instant::now();
        let addr = self.fail_on_err(self.source_socket_addr(), started_at)?;
        let (files_done, total_bytes) = self.fail_on_err(
            self.receive_via_dispatcher(addr, started_at),
            started_at,
        )?;
        self.update_elapsed(started_at.elapsed());
        self.complete()?;
        log::info!(
            "NetworkRestore via dispatcher from {}: {} file(s), {} bytes in {:?}",
            addr,
            files_done,
            total_bytes,
            started_at.elapsed(),
        );
        Ok(())
    }

    /// Moves the session from `NotStarted` to `InProgress`.
    ///
    /// # Errors
    /// Fails if the session is already in progress, completed, or failed.
    pub fn start(&self) -> Result<()> {
        let mut state = self.state.lock();
        match *state {
            RestoreState::NotStarted => {
                *state = RestoreState::InProgress;
                let mut progress = self.progress.lock();
                progress.state = RestoreState::InProgress;
                Ok(())
            }
            RestoreState::Completed => Err(RepError::NetworkRestoreError(
                "restore already completed".into(),
            )),
            RestoreState::Failed => Err(RepError::NetworkRestoreError(
                "restore already failed; create a new instance".into(),
            )),
            RestoreState::InProgress => Err(RepError::NetworkRestoreError(
                "restore already in progress".into(),
            )),
        }
    }

    /// Overwrites the byte and file counters in the progress snapshot.
    pub fn update_progress(&self, bytes: u64, files: u32) {
        let mut progress = self.progress.lock();
        progress.bytes_transferred = bytes;
        progress.files_transferred = files;
    }

    /// Overwrites the elapsed time in the progress snapshot.
    pub fn update_elapsed(&self, elapsed: Duration) {
        let mut progress = self.progress.lock();
        progress.elapsed = elapsed;
    }

    /// Moves the session from `InProgress` to `Completed`.
    ///
    /// # Errors
    /// Fails from any state other than `InProgress`.
    pub fn complete(&self) -> Result<()> {
        let mut state = self.state.lock();
        match *state {
            RestoreState::InProgress => {
                *state = RestoreState::Completed;
                let mut progress = self.progress.lock();
                progress.state = RestoreState::Completed;
                Ok(())
            }
            other => Err(RepError::NetworkRestoreError(format!(
                "cannot complete from state {:?}",
                other
            ))),
        }
    }

    /// Moves the session from `InProgress` to `Failed`.
    ///
    /// # Errors
    /// Fails from any state other than `InProgress`.
    pub fn fail(&self) -> Result<()> {
        let mut state = self.state.lock();
        match *state {
            RestoreState::InProgress => {
                *state = RestoreState::Failed;
                let mut progress = self.progress.lock();
                progress.state = RestoreState::Failed;
                Ok(())
            }
            other => Err(RepError::NetworkRestoreError(format!(
                "cannot fail from state {:?}",
                other
            ))),
        }
    }

    /// Errors with `"<operation> called in wrong state: ..."` unless the
    /// session is still `NotStarted`.
    fn ensure_not_started(&self, operation: &str) -> Result<()> {
        let state = self.state.lock();
        if *state != RestoreState::NotStarted {
            return Err(RepError::NetworkRestoreError(format!(
                "{} called in wrong state: {:?}",
                operation, *state
            )));
        }
        Ok(())
    }

    /// Passes `result` through unchanged. On `Err`, it first records the
    /// elapsed time and moves the session to `Failed`.
    fn fail_on_err<T>(&self, result: Result<T>, started_at: Instant) -> Result<T> {
        result.map_err(|err| {
            self.update_elapsed(started_at.elapsed());
            // The transfer error is what the caller needs to see; a secondary
            // state-transition error (state already moved on) is dropped.
            let _ = self.fail();
            err
        })
    }

    /// Resolves the destination directory (configured, else the current
    /// working directory, else `.`) and creates it if missing.
    fn prepare_log_dir(&self) -> Result<PathBuf> {
        let log_dir = self.local_log_dir.clone().unwrap_or_else(|| {
            std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))
        });
        std::fs::create_dir_all(&log_dir).map_err(|e| {
            RepError::NetworkRestoreError(format!(
                "creating log dir {}: {}",
                log_dir.display(),
                e
            ))
        })?;
        Ok(log_dir)
    }

    /// With `retain_log_files` set and `dest_path` present, renames it to
    /// `<log_dir>/<filename>.bak` before it is overwritten.
    fn backup_existing(
        &self,
        log_dir: &std::path::Path,
        filename: &str,
        dest_path: &std::path::Path,
    ) {
        if !(self.config.retain_log_files && dest_path.exists()) {
            return;
        }
        let backup = log_dir.join(format!("{}.bak", filename));
        if let Err(e) = std::fs::rename(dest_path, &backup) {
            // Still best effort, but no longer silent: the file the caller
            // asked to retain is about to be overwritten.
            log::warn!(
                "NetworkRestore: could not back up '{}' to '{}': {}",
                dest_path.display(),
                backup.display(),
                e
            );
        }
    }

    /// TCP transfer body of [`execute`](Self::execute).
    ///
    /// Returns `(files_received, bytes_received)`.
    fn receive_via_tcp(&self, addr: &str, started_at: Instant) -> Result<(u32, u64)> {
        let mut stream = TcpStream::connect(addr).map_err(|e| {
            RepError::NetworkRestoreError(format!(
                "cannot connect to source {}: {}",
                addr, e
            ))
        })?;
        // Best effort, as before: failing to set the timeout is not fatal.
        let _ = stream.set_read_timeout(Some(Duration::from_secs(120)));
        stream.write_all(&RESTORE_MAGIC.to_le_bytes()).map_err(|e| {
            RepError::NetworkRestoreError(format!(
                "sending restore magic: {}",
                e
            ))
        })?;
        let mut count_buf = [0u8; 4];
        stream.read_exact(&mut count_buf).map_err(|e| {
            RepError::NetworkRestoreError(format!("reading file count: {}", e))
        })?;
        let file_count = u32::from_le_bytes(count_buf);
        let log_dir = self.prepare_log_dir()?;
        // One 64 KiB transfer buffer, reused for every file.
        let mut chunk = vec![0u8; 65536];
        let mut total_bytes: u64 = 0;
        let mut files_done: u32 = 0;
        for _ in 0..file_count {
            let filename = Self::read_tcp_filename(&mut stream)?;
            validate_restore_filename(&filename)?;
            let mut size_buf = [0u8; 8];
            stream.read_exact(&mut size_buf).map_err(|e| {
                RepError::NetworkRestoreError(format!(
                    "reading file size for '{}': {}",
                    filename, e
                ))
            })?;
            let file_size = u64::from_le_bytes(size_buf);
            let dest_path = log_dir.join(&filename);
            self.backup_existing(&log_dir, &filename, &dest_path);
            Self::receive_tcp_file(
                &mut stream,
                &filename,
                &dest_path,
                file_size,
                &mut chunk,
            )?;
            total_bytes += file_size;
            files_done += 1;
            self.update_progress(total_bytes, files_done);
            self.update_elapsed(started_at.elapsed());
            log::debug!(
                "NetworkRestore: received '{}' ({} bytes)",
                filename,
                file_size
            );
        }
        Ok((files_done, total_bytes))
    }

    /// Reads a u16-length-prefixed UTF-8 filename from `stream`.
    fn read_tcp_filename(stream: &mut impl IoRead) -> Result<String> {
        let mut name_len_buf = [0u8; 2];
        stream.read_exact(&mut name_len_buf).map_err(|e| {
            RepError::NetworkRestoreError(format!(
                "reading filename length: {}",
                e
            ))
        })?;
        let mut name_buf = vec![0u8; usize::from(u16::from_le_bytes(name_len_buf))];
        stream.read_exact(&mut name_buf).map_err(|e| {
            RepError::NetworkRestoreError(format!(
                "reading filename: {}",
                e
            ))
        })?;
        String::from_utf8(name_buf).map_err(|e| {
            RepError::NetworkRestoreError(format!(
                "non-UTF8 filename: {}",
                e
            ))
        })
    }

    /// Creates `dest_path`, then fills it from `stream` and verifies its
    /// digest. On any failure after creation, the file is removed so no
    /// truncated or corrupt file is left under the final name.
    fn receive_tcp_file(
        stream: &mut impl IoRead,
        filename: &str,
        dest_path: &std::path::Path,
        file_size: u64,
        chunk: &mut [u8],
    ) -> Result<()> {
        let mut out = std::fs::File::create(dest_path).map_err(|e| {
            RepError::NetworkRestoreError(format!(
                "creating '{}': {}",
                dest_path.display(),
                e
            ))
        })?;
        let outcome =
            Self::copy_verified(stream, &mut out, filename, dest_path, file_size, chunk);
        if outcome.is_err() {
            drop(out);
            let _ = std::fs::remove_file(dest_path);
        }
        outcome
    }

    /// Copies exactly `file_size` bytes from `stream` to `out` through
    /// `chunk`. It then reads the trailing little-endian CRC-32 and compares
    /// it with the CRC-32 of the copied bytes.
    fn copy_verified(
        stream: &mut impl IoRead,
        out: &mut impl IoWrite,
        filename: &str,
        dest_path: &std::path::Path,
        file_size: u64,
        chunk: &mut [u8],
    ) -> Result<()> {
        let mut digest = crc32fast::Hasher::new();
        let mut remaining = file_size;
        while remaining > 0 {
            // Clamp in u64 before narrowing: `remaining as usize` would
            // truncate on 32-bit targets (possibly to 0, looping forever).
            let to_read = remaining.min(chunk.len() as u64) as usize;
            stream.read_exact(&mut chunk[..to_read]).map_err(|e| {
                RepError::NetworkRestoreError(format!(
                    "reading data for '{}': {}",
                    filename, e
                ))
            })?;
            let piece = &chunk[..to_read];
            digest.update(piece);
            out.write_all(piece).map_err(|e| {
                RepError::NetworkRestoreError(format!(
                    "writing '{}': {}",
                    dest_path.display(),
                    e
                ))
            })?;
            remaining -= to_read as u64;
        }
        let mut crc_buf = [0u8; 4];
        stream.read_exact(&mut crc_buf).map_err(|e| {
            RepError::NetworkRestoreError(format!(
                "reading digest for '{}': {}",
                filename, e
            ))
        })?;
        let want = u32::from_le_bytes(crc_buf);
        let got = digest.finalize();
        if want != got {
            return Err(RepError::NetworkRestoreError(format!(
                "digest mismatch for '{}': expected {:#010x}, got {:#010x} (file corrupted or truncated in transit)",
                filename, want, got
            )));
        }
        Ok(())
    }

    /// Parses `source_host:source_port` as a `SocketAddr`. This does no name
    /// resolution, so the host must be a literal IP address.
    fn source_socket_addr(&self) -> Result<std::net::SocketAddr> {
        let addr_str =
            format!("{}:{}", self.config.source_host, self.config.source_port);
        addr_str.parse::<std::net::SocketAddr>().map_err(|e| {
            RepError::NetworkRestoreError(format!(
                "bad source address {}: {}",
                addr_str, e
            ))
        })
    }

    /// Dispatcher transfer body of
    /// [`execute_via_dispatcher`](Self::execute_via_dispatcher).
    ///
    /// Returns `(files_received, bytes_received)`.
    fn receive_via_dispatcher(
        &self,
        addr: std::net::SocketAddr,
        started_at: Instant,
    ) -> Result<(u32, u64)> {
        use crate::net::Channel;
        use crate::net::service_dispatcher::connect_to_service;
        use crate::network_restore_server::RESTORE_SERVICE_NAME;

        let channel =
            connect_to_service(addr, RESTORE_SERVICE_NAME).map_err(|e| {
                RepError::NetworkRestoreError(format!(
                    "connect_to_service(RESTORE) at {}: {}",
                    addr, e
                ))
            })?;
        channel.send(&RESTORE_MAGIC.to_le_bytes()).map_err(|e| {
            RepError::NetworkRestoreError(format!(
                "sending restore magic via dispatcher: {}",
                e
            ))
        })?;
        let payload = channel
            .receive(Duration::from_secs(120))
            .map_err(|e| {
                RepError::NetworkRestoreError(format!(
                    "receiving restore payload: {}",
                    e
                ))
            })?
            .ok_or_else(|| {
                RepError::NetworkRestoreError(
                    "empty restore payload from dispatcher".to_string(),
                )
            })?;
        if payload.len() < 4 {
            return Err(RepError::NetworkRestoreError(format!(
                "truncated restore payload: {} bytes",
                payload.len()
            )));
        }
        let mut buf4 = [0u8; 4];
        buf4.copy_from_slice(&payload[0..4]);
        let file_count = u32::from_le_bytes(buf4);
        let mut off = 4usize;
        let log_dir = self.prepare_log_dir()?;
        let mut total_bytes: u64 = 0;
        let mut files_done: u32 = 0;
        let mut buf2 = [0u8; 2];
        let mut buf8 = [0u8; 8];
        for _ in 0..file_count {
            // Invariant: `off <= payload.len()` at the top of every iteration,
            // so `payload.len() - off` cannot underflow.
            if payload.len() - off < 2 {
                return Err(RepError::NetworkRestoreError(
                    "truncated restore payload at name_len".to_string(),
                ));
            }
            buf2.copy_from_slice(&payload[off..off + 2]);
            off += 2;
            let name_len = usize::from(u16::from_le_bytes(buf2));
            if payload.len() - off < name_len + 8 {
                return Err(RepError::NetworkRestoreError(
                    "truncated restore payload at name+size".to_string(),
                ));
            }
            let name_bytes = payload[off..off + name_len].to_vec();
            off += name_len;
            let filename = String::from_utf8(name_bytes).map_err(|e| {
                RepError::NetworkRestoreError(format!(
                    "non-UTF8 filename: {}",
                    e
                ))
            })?;
            validate_restore_filename(&filename)?;
            buf8.copy_from_slice(&payload[off..off + 8]);
            off += 8;
            let declared_size = u64::from_le_bytes(buf8);
            let available = payload.len() - off;
            // Never compute `off + size + 4`: the size is chosen by the peer,
            // and that sum can overflow (a panic in debug builds; in release
            // builds it wraps past the bounds check).
            let fits = matches!(
                available.checked_sub(4),
                Some(room) if declared_size <= room as u64
            );
            if !fits {
                return Err(RepError::NetworkRestoreError(format!(
                    "truncated restore payload at file body for {:?} \
                     (need {} bytes + 4 digest, have {})",
                    filename, declared_size, available,
                )));
            }
            // Lossless: `declared_size <= room`, and `room` is a usize.
            let file_size = declared_size as usize;
            let body = &payload[off..off + file_size];
            let digest_at = off + file_size;
            let want = u32::from_le_bytes(
                payload[digest_at..digest_at + 4]
                    .try_into()
                    .expect("4-byte CRC slice"),
            );
            let got = crc32fast::hash(body);
            if want != got {
                return Err(RepError::NetworkRestoreError(format!(
                    "digest mismatch for '{}': expected {:#010x}, got {:#010x} \
                     (file corrupted or truncated in transit)",
                    filename, want, got
                )));
            }
            let dest_path = log_dir.join(&filename);
            self.backup_existing(&log_dir, &filename, &dest_path);
            std::fs::write(&dest_path, body).map_err(|e| {
                RepError::NetworkRestoreError(format!(
                    "writing '{}': {}",
                    dest_path.display(),
                    e
                ))
            })?;
            off = digest_at + 4;
            total_bytes += declared_size;
            files_done += 1;
            self.update_progress(total_bytes, files_done);
            self.update_elapsed(started_at.elapsed());
            log::debug!(
                "NetworkRestore via dispatcher: received '{}' ({} bytes)",
                filename,
                file_size
            );
        }
        if off != payload.len() {
            log::warn!(
                "NetworkRestore via dispatcher: ignoring {} trailing byte(s) after {} file(s)",
                payload.len() - off,
                files_done
            );
        }
        Ok((files_done, total_bytes))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration shared by every test (retention disabled).
    fn test_config() -> NetworkRestoreConfig {
        NetworkRestoreConfig {
            source_node: String::from("node1"),
            source_host: String::from("192.168.1.10"),
            source_port: 5001,
            retain_log_files: false,
        }
    }

    /// Builds a fresh session and moves it into `InProgress`.
    fn started() -> NetworkRestore {
        let session = NetworkRestore::new(test_config());
        session.start().unwrap();
        session
    }

    /// Checks both the authoritative state and its mirror in the progress snapshot.
    fn assert_state(session: &NetworkRestore, expected: RestoreState) {
        assert_eq!(session.get_state(), expected);
        assert_eq!(session.get_progress().state, expected);
    }

    #[test]
    fn test_initial_state() {
        let session = NetworkRestore::new(test_config());
        assert_state(&session, RestoreState::NotStarted);
        let RestoreProgress {
            bytes_transferred,
            files_transferred,
            elapsed,
            ..
        } = session.get_progress();
        assert_eq!(bytes_transferred, 0);
        assert_eq!(files_transferred, 0);
        assert_eq!(elapsed, Duration::ZERO);
    }

    #[test]
    fn test_start() {
        assert_state(&started(), RestoreState::InProgress);
    }

    #[test]
    fn test_start_twice_fails() {
        let session = started();
        assert!(session.start().is_err());
    }

    #[test]
    fn test_update_progress() {
        let session = started();
        session.update_progress(1 << 20, 3);
        let snapshot = session.get_progress();
        assert_eq!(
            (snapshot.bytes_transferred, snapshot.files_transferred),
            (1 << 20, 3)
        );
    }

    #[test]
    fn test_update_elapsed() {
        let session = started();
        let forty_two = Duration::from_secs(42);
        session.update_elapsed(forty_two);
        assert_eq!(session.get_progress().elapsed, forty_two);
    }

    #[test]
    fn test_complete() {
        let session = started();
        session.complete().unwrap();
        assert_state(&session, RestoreState::Completed);
    }

    #[test]
    fn test_complete_from_not_started_fails() {
        assert!(NetworkRestore::new(test_config()).complete().is_err());
    }

    #[test]
    fn test_fail() {
        let session = started();
        session.fail().unwrap();
        assert_state(&session, RestoreState::Failed);
    }

    #[test]
    fn test_fail_from_not_started_fails() {
        assert!(NetworkRestore::new(test_config()).fail().is_err());
    }

    #[test]
    fn test_start_after_completed_fails() {
        let session = started();
        session.complete().unwrap();
        assert!(session.start().is_err());
    }

    #[test]
    fn test_start_after_failed_fails() {
        let session = started();
        session.fail().unwrap();
        assert!(session.start().is_err());
    }

    #[test]
    fn test_config_accessor() {
        let session = NetworkRestore::new(test_config());
        let cfg = session.get_config();
        assert_eq!(cfg.source_node, "node1");
        assert_eq!(cfg.source_host, "192.168.1.10");
        assert_eq!(cfg.source_port, 5001);
        assert!(!cfg.retain_log_files);
    }

    #[test]
    fn test_retain_log_files_config() {
        let config = NetworkRestoreConfig {
            retain_log_files: true,
            ..test_config()
        };
        assert!(NetworkRestore::new(config).get_config().retain_log_files);
    }

    #[test]
    fn test_full_lifecycle() {
        let session = NetworkRestore::new(test_config());
        assert_eq!(session.get_state(), RestoreState::NotStarted);
        session.start().unwrap();
        assert_eq!(session.get_state(), RestoreState::InProgress);
        for (bytes, files) in [(512, 1), (2048, 2)] {
            session.update_progress(bytes, files);
        }
        session.update_elapsed(Duration::from_secs(5));
        let snapshot = session.get_progress();
        assert_eq!(snapshot.bytes_transferred, 2048);
        assert_eq!(snapshot.files_transferred, 2);
        assert_eq!(snapshot.elapsed, Duration::from_secs(5));
        session.complete().unwrap();
        assert_eq!(session.get_state(), RestoreState::Completed);
    }

    #[test]
    fn test_fail_lifecycle() {
        let session = started();
        session.update_progress(256, 1);
        session.fail().unwrap();
        assert_eq!(session.get_state(), RestoreState::Failed);
        let snapshot = session.get_progress();
        assert_eq!(snapshot.bytes_transferred, 256);
        assert_eq!(snapshot.files_transferred, 1);
    }

    /// Asserts that `name` is rejected with a `ProtocolError` whose message
    /// mentions "unsafe filename".
    fn assert_unsafe(name: &str) {
        let rejection = validate_restore_filename(name)
            .expect_err(&format!("expected rejection for {:?}", name));
        if let RepError::ProtocolError(ref msg) = rejection {
            assert!(
                msg.contains("unsafe filename"),
                "unexpected message for {:?}: {}",
                name,
                msg
            );
        } else {
            panic!("expected ProtocolError for {:?}, got {:?}", name, rejection);
        }
    }

    #[test]
    fn test_validate_filename_rejects_empty() {
        assert_unsafe("");
    }

    #[test]
    fn test_validate_filename_rejects_dot_and_dotdot() {
        for name in [".", ".."] {
            assert_unsafe(name);
        }
    }

    #[test]
    fn test_validate_filename_rejects_hidden_dotfile() {
        for name in [".bashrc", ".hidden"] {
            assert_unsafe(name);
        }
    }

    #[test]
    fn test_validate_filename_rejects_path_separators() {
        for name in [
            "../etc/passwd",
            "/etc/passwd",
            "subdir/file.ndb",
            "dir\\file.ndb",
            "..\\windows\\system32",
        ] {
            assert_unsafe(name);
        }
    }

    #[test]
    fn test_validate_filename_rejects_null_byte() {
        for name in ["good\0name.ndb", "\0"] {
            assert_unsafe(name);
        }
    }

    #[test]
    fn test_validate_filename_accepts_normal_log_files() {
        for name in [
            "00000000.ndb",
            "00000001.ndb",
            "ffffffff.ndb",
            "data.bin",
            "name-with-dashes_and_underscores.ndb",
        ] {
            validate_restore_filename(name).unwrap();
        }
    }
}