// memvid_cli/commands/capsule.rs

1use std::ffi::OsStr;
2use std::fs;
3use std::io::{self, BufRead, Write};
4use std::path::{Path, PathBuf};
5
6use anyhow::{bail, Context, Result};
7use clap::{ArgAction, Parser};
8
9#[cfg(unix)]
10use libc::{isatty, tcgetattr, tcsetattr, termios, ECHO, TCSANOW};
11
// Command-line arguments consumed by `handle_lock`.
//
// Maintainer notes in this struct deliberately use plain `//` comments: the `///`
// comments on the fields are what clap's derive macro shows as help text, so they
// are left untouched.
#[derive(Parser, Debug)]
#[command(about = "Encrypt a memory file, creating an encrypted capsule (.mv2e)")]
pub struct LockArgs {
    /// Path to the .mv2 file to encrypt
    // `handle_lock` rejects any path whose extension is not exactly `mv2`.
    #[arg(value_name = "FILE")]
    pub file: PathBuf,

    /// Interactive password prompt (default, safest)
    // The value itself is ignored by `PasswordMode::from_args`; the flag exists so
    // it can be declared mutually exclusive with `--password-stdin`.
    #[arg(long, action = ArgAction::SetTrue, conflicts_with = "password_stdin")]
    pub password: bool,

    /// Read password from stdin (for CI/scripts)
    // Selects `PasswordMode::Stdin`, which reads a single line from stdin.
    #[arg(long = "password-stdin", action = ArgAction::SetTrue, conflicts_with = "password")]
    pub password_stdin: bool,

    /// Output file path (default: <FILE>.mv2e)
    // When absent, `handle_lock` derives the path with `file.with_extension("mv2e")`.
    #[arg(short, long, value_name = "PATH")]
    pub out: Option<PathBuf>,

    /// Overwrite output file if it exists
    // When set, an existing file at the output path is removed before encryption;
    // when unset, an existing output file is an error.
    #[arg(long, action = ArgAction::SetTrue)]
    pub force: bool,

    /// Keep the original .mv2 file after encryption (default: delete for security)
    // When unset, `handle_lock` removes `file` once encryption has succeeded.
    #[arg(long = "keep-original", action = ArgAction::SetTrue)]
    pub keep_original: bool,

    /// Output result as JSON
    // Switches the result report to a single JSON object and suppresses the
    // human-readable "Deleted: ..." line.
    #[arg(long, action = ArgAction::SetTrue)]
    pub json: bool,
}
43
// Command-line arguments consumed by `handle_unlock`.
//
// Maintainer notes in this struct deliberately use plain `//` comments: the `///`
// comments on the fields are what clap's derive macro shows as help text, so they
// are left untouched.
#[derive(Parser, Debug)]
#[command(about = "Decrypt an encrypted capsule, recreating the original .mv2 file")]
pub struct UnlockArgs {
    /// Path to the .mv2e file to decrypt
    // `handle_unlock` rejects any path whose extension is not exactly `mv2e`.
    #[arg(value_name = "FILE")]
    pub file: PathBuf,

    /// Interactive password prompt (default, safest)
    // The value itself is ignored by `PasswordMode::from_args`; the flag exists so
    // it can be declared mutually exclusive with `--password-stdin`.
    #[arg(long, action = ArgAction::SetTrue, conflicts_with = "password_stdin")]
    pub password: bool,

    /// Read password from stdin (for CI/scripts)
    // Selects `PasswordMode::Stdin`, which reads a single line from stdin.
    #[arg(long = "password-stdin", action = ArgAction::SetTrue, conflicts_with = "password")]
    pub password_stdin: bool,

    /// Output file path (default: <FILE> without .mv2e)
    // When absent, `handle_unlock` derives the path with `file.with_extension("mv2")`.
    #[arg(short, long, value_name = "PATH")]
    pub out: Option<PathBuf>,

    /// Overwrite output file if it exists
    // When set, an existing file at the output path is removed before decryption;
    // when unset, an existing output file is an error.
    #[arg(long, action = ArgAction::SetTrue)]
    pub force: bool,

    /// Output result as JSON
    // Switches the result report to a single JSON object.
    #[arg(long, action = ArgAction::SetTrue)]
    pub json: bool,
}
71
72pub fn handle_lock(args: LockArgs) -> Result<()> {
73    ensure_extension(&args.file, "mv2")?;
74
75    let output_path = args.out.unwrap_or_else(|| args.file.with_extension("mv2e"));
76    ensure_distinct_paths(&args.file, &output_path)?;
77    ensure_output_path(&output_path, args.force)?;
78
79    let mut password_bytes = read_password_bytes(
80        PasswordMode::from_args(args.password, args.password_stdin),
81        true,
82    )?;
83
84    let input_len = fs::metadata(&args.file)
85        .with_context(|| format!("failed to stat {}", args.file.display()))?
86        .len();
87
88    let result_path = memvid_core::encryption::lock_file(
89        &args.file,
90        Some(output_path.as_path()),
91        &password_bytes,
92    )
93    .map_err(anyhow::Error::from)?;
94    password_bytes.fill(0);
95
96    // Delete the original .mv2 file for security (unless --keep-original)
97    if !args.keep_original {
98        fs::remove_file(&args.file)
99            .with_context(|| format!("failed to remove original file {}", args.file.display()))?;
100
101        if !args.json {
102            println!(
103                "Deleted: {} (use --keep-original to preserve)",
104                args.file.display()
105            );
106        }
107    }
108
109    print_capsule_result(
110        &args.file,
111        &result_path,
112        input_len,
113        args.json,
114        !args.keep_original,
115    )?;
116    Ok(())
117}
118
119pub fn handle_unlock(args: UnlockArgs) -> Result<()> {
120    ensure_extension(&args.file, "mv2e")?;
121
122    let output_path = args.out.unwrap_or_else(|| args.file.with_extension("mv2"));
123    ensure_distinct_paths(&args.file, &output_path)?;
124    ensure_output_path(&output_path, args.force)?;
125
126    let mut password_bytes = read_password_bytes(
127        PasswordMode::from_args(args.password, args.password_stdin),
128        false,
129    )?;
130
131    let input_len = fs::metadata(&args.file)
132        .with_context(|| format!("failed to stat {}", args.file.display()))?
133        .len();
134
135    let result_path = memvid_core::encryption::unlock_file(
136        &args.file,
137        Some(output_path.as_path()),
138        &password_bytes,
139    )
140    .map_err(anyhow::Error::from)?;
141    password_bytes.fill(0);
142
143    print_capsule_result(&args.file, &result_path, input_len, args.json, false)?;
144    Ok(())
145}
146
147fn print_capsule_result(
148    input: &Path,
149    output: &Path,
150    size: u64,
151    json: bool,
152    deleted_original: bool,
153) -> Result<()> {
154    if json {
155        let payload = serde_json::json!({
156            "input": input.display().to_string(),
157            "output": output.display().to_string(),
158            "size": size,
159            "original_deleted": deleted_original,
160        });
161        println!("{}", payload);
162        return Ok(());
163    }
164
165    println!("Wrote: {}", output.display());
166    Ok(())
167}
168
169fn ensure_extension(path: &Path, expected: &str) -> Result<()> {
170    let ext = path.extension().and_then(OsStr::to_str);
171    if ext != Some(expected) {
172        bail!(
173            "Expected .{} file, got: {}",
174            expected,
175            ext.unwrap_or("<none>")
176        );
177    }
178    Ok(())
179}
180
181fn ensure_distinct_paths(input: &Path, output: &Path) -> Result<()> {
182    if input == output {
183        bail!("Refusing to overwrite input file: {}", input.display());
184    }
185    Ok(())
186}
187
188fn ensure_output_path(output: &Path, force: bool) -> Result<()> {
189    if !output.exists() {
190        return Ok(());
191    }
192    if !force {
193        bail!(
194            "Output file already exists: {}\nUse --force to overwrite",
195            output.display()
196        );
197    }
198    fs::remove_file(output).with_context(|| format!("failed to remove {}", output.display()))?;
199    Ok(())
200}
201
/// Where the password is obtained from.
#[derive(Debug, Clone, Copy)]
enum PasswordMode {
    /// Ask the user interactively.
    Prompt,
    /// Read a single line from standard input.
    Stdin,
}

impl PasswordMode {
    /// Maps the two CLI flags onto a mode.
    ///
    /// `password_stdin` selects [`PasswordMode::Stdin`]; anything else falls back to
    /// [`PasswordMode::Prompt`]. `_password` is accepted for call-site symmetry but
    /// has no effect, because prompting is already the fallback.
    fn from_args(_password: bool, password_stdin: bool) -> Self {
        if password_stdin {
            Self::Stdin
        } else {
            Self::Prompt
        }
    }
}
216
217fn read_password_bytes(mode: PasswordMode, confirm: bool) -> Result<Vec<u8>> {
218    let password = match mode {
219        PasswordMode::Stdin => read_password_from_stdin()?,
220        PasswordMode::Prompt => read_password_from_prompt(confirm)?,
221    };
222
223    if password.trim().is_empty() {
224        bail!("Password cannot be empty");
225    }
226
227    Ok(password.into_bytes())
228}
229
230fn read_password_from_stdin() -> Result<String> {
231    let stdin = io::stdin();
232    let mut reader = stdin.lock();
233    let mut line = String::new();
234    let bytes = reader
235        .read_line(&mut line)
236        .context("failed to read password from stdin")?;
237    if bytes == 0 {
238        bail!("Password cannot be empty");
239    }
240    Ok(line.trim_end_matches(&['\n', '\r'][..]).to_string())
241}
242
243fn read_password_from_prompt(confirm: bool) -> Result<String> {
244    let password = read_password_hidden("Password: ")?;
245    if !confirm {
246        return Ok(password);
247    }
248    let confirm_pw = read_password_hidden("Confirm:  ")?;
249    if password != confirm_pw {
250        bail!("Passwords do not match");
251    }
252    Ok(password)
253}
254
255fn read_password_hidden(prompt: &str) -> Result<String> {
256    let is_tty = stdin_is_tty();
257    let mut stderr = io::stderr();
258    stderr.write_all(prompt.as_bytes())?;
259    stderr.flush()?;
260
261    let guard = if is_tty { disable_stdin_echo()? } else { None };
262
263    let stdin = io::stdin();
264    let mut reader = stdin.lock();
265    let mut line = String::new();
266    reader.read_line(&mut line)?;
267
268    if guard.is_some() {
269        let _ = stderr.write_all(b"\n");
270        let _ = stderr.flush();
271    }
272
273    Ok(line.trim_end_matches(&['\n', '\r'][..]).to_string())
274}
275
276fn stdin_is_tty() -> bool {
277    #[cfg(unix)]
278    unsafe {
279        isatty(libc::STDIN_FILENO) == 1
280    }
281
282    #[cfg(not(unix))]
283    {
284        false
285    }
286}
287
288#[cfg(unix)]
289fn disable_stdin_echo() -> Result<Option<EchoGuard>> {
290    let fd = libc::STDIN_FILENO;
291    unsafe {
292        let mut current: termios = std::mem::zeroed();
293        if tcgetattr(fd, &mut current) != 0 {
294            return Ok(None);
295        }
296        let mut updated = current;
297        updated.c_lflag &= !ECHO;
298        if tcsetattr(fd, TCSANOW, &updated) != 0 {
299            return Ok(None);
300        }
301        Ok(Some(EchoGuard {
302            fd,
303            previous: current,
304        }))
305    }
306}
307
/// Non-Unix stand-in for the echo-disabling helper.
///
/// Does nothing and always returns `Ok(None)`, i.e. no guard is created and echo is
/// left as it is.
#[cfg(not(unix))]
fn disable_stdin_echo() -> Result<Option<EchoGuard>> {
    Ok(None)
}
312
/// Guard that puts saved terminal settings back when it is dropped (Unix).
#[cfg(unix)]
struct EchoGuard {
    /// File descriptor whose terminal attributes are re-applied on drop.
    fd: i32,
    /// Terminal attributes to re-apply on drop.
    previous: termios,
}

#[cfg(unix)]
impl Drop for EchoGuard {
    /// Re-applies `previous` to `fd` immediately (`TCSANOW`).
    ///
    /// The return value of `tcsetattr` is deliberately ignored: `drop` has no way to
    /// report a failure.
    fn drop(&mut self) {
        // SAFETY: `self.previous` is a live `termios` owned by this guard and is only
        // read by the call.
        unsafe {
            let _ = tcsetattr(self.fd, TCSANOW, &self.previous);
        }
    }
}
327
/// Non-Unix placeholder so code mentioning `EchoGuard` compiles on every platform.
/// It carries no state.
#[cfg(not(unix))]
struct EchoGuard;

#[cfg(not(unix))]
impl Drop for EchoGuard {
    /// No-op: there are no terminal settings to restore on this platform.
    fn drop(&mut self) {}
}