Skip to main content

winrm_rs/
transfer.rs

// File transfer via WinRM.
//
// Uploads files in base64-encoded chunks and downloads them as a single
// base64 blob, both through PowerShell scripts run on the remote host.
4
5use std::path::Path;
6
7use base64::Engine;
8use base64::engine::general_purpose::STANDARD as B64;
9
10use crate::client::WinrmClient;
11use crate::error::WinrmError;
12
/// File transfer chunk size (raw bytes before encoding).
///
/// Each chunk is base64-encoded (×4/3) and embedded in a PowerShell script; the
/// script is then UTF-16LE encoded (×2) and base64-encoded again (×4/3) for
/// `-EncodedCommand`, an overall expansion of ~3.6×. With 2000 raw bytes the
/// embedded payload is 2668 characters, which alone becomes ~7116 characters
/// on the command line. WinRS enforces an ~8191-char command-line limit, so
/// the raw chunk must stay under ~2 KB.
///
/// The remaining ~1 K characters of headroom must also absorb everything else
/// in the script — the fixed template around the payload and the escaped remote
/// path — and each of those script characters costs ~8/3 characters on the
/// command line.
///
/// NOTE(review): for append chunks with a remote path near `MAX_PATH`, that
/// leaves little or no margin; confirm against the exact command line built by
/// `run_powershell` before raising this value or the path limit.
const CHUNK_SIZE: usize = 2000;
20
21/// Maximum allowed remote path length (Windows MAX_PATH).
22const MAX_REMOTE_PATH_LEN: usize = 260;
23
24/// Validate a remote file path for safety.
25///
26/// Rejects paths containing control characters (`\x00`-`\x1F` except `\t`)
27/// or exceeding Windows MAX_PATH (260 characters).
28fn validate_remote_path(path: &str) -> Result<(), WinrmError> {
29    if path.len() > MAX_REMOTE_PATH_LEN {
30        return Err(WinrmError::Transfer(format!(
31            "remote path exceeds {MAX_REMOTE_PATH_LEN} characters"
32        )));
33    }
34    if path.chars().any(|c| c.is_control() && c != '\t') {
35        return Err(WinrmError::Transfer(
36            "remote path contains control characters".into(),
37        ));
38    }
39    Ok(())
40}
41
42impl WinrmClient {
43    /// Upload a local file to a remote Windows host.
44    ///
45    /// The file is chunked into ~2 KB pieces, base64-encoded, and written
46    /// via PowerShell `[IO.File]::WriteAllBytes` / `[IO.File]::Open('Append')`.
47    /// The small chunk size is dictated by the WinRS command-line limit
48    /// (~8191 chars) after triple encoding (base64 → UTF-16LE → base64).
49    ///
50    /// Returns the number of bytes uploaded.
51    pub async fn upload_file(
52        &self,
53        host: &str,
54        local_path: &Path,
55        remote_path: &str,
56    ) -> Result<u64, WinrmError> {
57        validate_remote_path(remote_path)?;
58
59        let data = std::fs::read(local_path).map_err(|e| {
60            WinrmError::Transfer(format!(
61                "failed to read local file {}: {e}",
62                local_path.display()
63            ))
64        })?;
65
66        let shell = self.open_shell(host).await?;
67        let total = data.len() as u64;
68        let escaped_path = remote_path.replace('\'', "''");
69
70        for (i, chunk) in data.chunks(CHUNK_SIZE).enumerate() {
71            let b64 = B64.encode(chunk);
72
73            let script = if i == 0 {
74                format!(
75                    "$bytes = [Convert]::FromBase64String('{b64}'); \
76                     [IO.File]::WriteAllBytes('{escaped_path}', $bytes)"
77                )
78            } else {
79                format!(
80                    "$bytes = [Convert]::FromBase64String('{b64}'); \
81                     $f = [IO.File]::Open('{escaped_path}', 'Append'); \
82                     $f.Write($bytes, 0, $bytes.Length); $f.Close()"
83                )
84            };
85
86            let output = shell.run_powershell(&script).await?;
87            if output.exit_code != 0 {
88                shell.close().await.ok();
89                return Err(WinrmError::Transfer(format!(
90                    "upload chunk {i} failed: {}",
91                    String::from_utf8_lossy(&output.stderr)
92                )));
93            }
94        }
95
96        shell.close().await.ok();
97        Ok(total)
98    }
99
100    /// Download a file from a remote Windows host.
101    ///
102    /// Reads the file via PowerShell base64 encoding and decodes locally.
103    ///
104    /// Returns the number of bytes downloaded.
105    pub async fn download_file(
106        &self,
107        host: &str,
108        remote_path: &str,
109        local_path: &Path,
110    ) -> Result<u64, WinrmError> {
111        validate_remote_path(remote_path)?;
112
113        let escaped = remote_path.replace('\'', "''");
114        let script = format!("[Convert]::ToBase64String([IO.File]::ReadAllBytes('{escaped}'))");
115
116        let output = self.run_powershell(host, &script).await?;
117        if output.exit_code != 0 {
118            return Err(WinrmError::Transfer(format!(
119                "download failed: {}",
120                String::from_utf8_lossy(&output.stderr)
121            )));
122        }
123
124        let b64 = String::from_utf8_lossy(&output.stdout);
125        let data = B64
126            .decode(b64.trim_ascii())
127            .map_err(|e| WinrmError::Transfer(format!("base64 decode of downloaded file: {e}")))?;
128
129        let total = data.len() as u64;
130        std::fs::write(local_path, &data).map_err(|e| {
131            WinrmError::Transfer(format!(
132                "failed to write local file {}: {e}",
133                local_path.display()
134            ))
135        })?;
136
137        Ok(total)
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: does `validate_remote_path` accept `path`?
    fn accepted(path: &str) -> bool {
        validate_remote_path(path).is_ok()
    }

    #[test]
    fn validate_remote_path_ok() {
        // A plain absolute path with no special characters.
        assert!(accepted(r"C:\Users\admin\file.txt"));
    }

    #[test]
    fn validate_remote_path_too_long() {
        // Drive prefix plus 260 filler characters overshoots the limit.
        let long_path = format!(r"C:\{}", "a".repeat(260));
        assert!(!accepted(&long_path));
    }

    #[test]
    fn validate_remote_path_control_chars() {
        // NUL and SOH must both be refused.
        for path in ["C:\\bad\x00path", "C:\\bad\x01path"] {
            assert!(!accepted(path), "control character accepted in {path:?}");
        }
    }

    #[test]
    fn validate_remote_path_tab_allowed() {
        // Tab is the one control character the validator lets through.
        assert!(accepted("C:\\path\twith\ttabs"));
    }

    #[test]
    fn validate_remote_path_max_length_boundary() {
        // Exactly 260 characters is the last accepted length; 261 is rejected.
        let (at_limit, past_limit) = ("a".repeat(260), "a".repeat(261));
        assert!(accepted(&at_limit));
        assert!(!accepted(&past_limit));
    }
}