// memvid_cli/commands/capsule.rs

1use std::ffi::OsStr;
2use std::fs;
3use std::io::{self, BufRead, Write};
4use std::path::{Path, PathBuf};
5
6use anyhow::{bail, Context, Result};
7use clap::{ArgAction, Parser};
8
9#[cfg(unix)]
10use libc::{isatty, tcgetattr, tcsetattr, termios, ECHO, TCSANOW};
11
12#[derive(Parser, Debug)]
13#[command(about = "Encrypt a memory file, creating an encrypted capsule (.mv2e)")]
14pub struct LockArgs {
15    /// Path to the .mv2 file to encrypt
16    #[arg(value_name = "FILE")]
17    pub file: PathBuf,
18
19    /// Interactive password prompt (default, safest)
20    #[arg(long, action = ArgAction::SetTrue, conflicts_with = "password_stdin")]
21    pub password: bool,
22
23    /// Read password from stdin (for CI/scripts)
24    #[arg(long = "password-stdin", action = ArgAction::SetTrue, conflicts_with = "password")]
25    pub password_stdin: bool,
26
27    /// Output file path (default: <FILE>.mv2e)
28    #[arg(short, long, value_name = "PATH")]
29    pub out: Option<PathBuf>,
30
31    /// Overwrite output file if it exists
32    #[arg(long, action = ArgAction::SetTrue)]
33    pub force: bool,
34
35    /// Keep the original .mv2 file after encryption (default: delete for security)
36    #[arg(long = "keep-original", action = ArgAction::SetTrue)]
37    pub keep_original: bool,
38
39    /// Output result as JSON
40    #[arg(long, action = ArgAction::SetTrue)]
41    pub json: bool,
42}
43
44#[derive(Parser, Debug)]
45#[command(about = "Decrypt an encrypted capsule, recreating the original .mv2 file")]
46pub struct UnlockArgs {
47    /// Path to the .mv2e file to decrypt
48    #[arg(value_name = "FILE")]
49    pub file: PathBuf,
50
51    /// Interactive password prompt (default, safest)
52    #[arg(long, action = ArgAction::SetTrue, conflicts_with = "password_stdin")]
53    pub password: bool,
54
55    /// Read password from stdin (for CI/scripts)
56    #[arg(long = "password-stdin", action = ArgAction::SetTrue, conflicts_with = "password")]
57    pub password_stdin: bool,
58
59    /// Output file path (default: <FILE> without .mv2e)
60    #[arg(short, long, value_name = "PATH")]
61    pub out: Option<PathBuf>,
62
63    /// Overwrite output file if it exists
64    #[arg(long, action = ArgAction::SetTrue)]
65    pub force: bool,
66
67    /// Output result as JSON
68    #[arg(long, action = ArgAction::SetTrue)]
69    pub json: bool,
70}
71
72pub fn handle_lock(args: LockArgs) -> Result<()> {
73    ensure_extension(&args.file, "mv2")?;
74
75    let output_path = args
76        .out
77        .unwrap_or_else(|| args.file.with_extension("mv2e"));
78    ensure_distinct_paths(&args.file, &output_path)?;
79    ensure_output_path(&output_path, args.force)?;
80
81    let mut password_bytes = read_password_bytes(PasswordMode::from_args(
82        args.password,
83        args.password_stdin,
84    ), true)?;
85
86    let input_len = fs::metadata(&args.file)
87        .with_context(|| format!("failed to stat {}", args.file.display()))?
88        .len();
89
90    let result_path = memvid_core::encryption::lock_file(
91        &args.file,
92        Some(output_path.as_path()),
93        &password_bytes,
94    )
95    .map_err(anyhow::Error::from)?;
96    password_bytes.fill(0);
97
98    // Delete the original .mv2 file for security (unless --keep-original)
99    if !args.keep_original {
100        fs::remove_file(&args.file)
101            .with_context(|| format!("failed to remove original file {}", args.file.display()))?;
102
103        if !args.json {
104            println!("Deleted: {} (use --keep-original to preserve)", args.file.display());
105        }
106    }
107
108    print_capsule_result(&args.file, &result_path, input_len, args.json, !args.keep_original)?;
109    Ok(())
110}
111
112pub fn handle_unlock(args: UnlockArgs) -> Result<()> {
113    ensure_extension(&args.file, "mv2e")?;
114
115    let output_path = args.out.unwrap_or_else(|| args.file.with_extension("mv2"));
116    ensure_distinct_paths(&args.file, &output_path)?;
117    ensure_output_path(&output_path, args.force)?;
118
119    let mut password_bytes = read_password_bytes(PasswordMode::from_args(
120        args.password,
121        args.password_stdin,
122    ), false)?;
123
124    let input_len = fs::metadata(&args.file)
125        .with_context(|| format!("failed to stat {}", args.file.display()))?
126        .len();
127
128    let result_path = memvid_core::encryption::unlock_file(
129        &args.file,
130        Some(output_path.as_path()),
131        &password_bytes,
132    )
133    .map_err(anyhow::Error::from)?;
134    password_bytes.fill(0);
135
136    print_capsule_result(&args.file, &result_path, input_len, args.json, false)?;
137    Ok(())
138}
139
140fn print_capsule_result(input: &Path, output: &Path, size: u64, json: bool, deleted_original: bool) -> Result<()> {
141    if json {
142        let payload = serde_json::json!({
143            "input": input.display().to_string(),
144            "output": output.display().to_string(),
145            "size": size,
146            "original_deleted": deleted_original,
147        });
148        println!("{}", payload);
149        return Ok(());
150    }
151
152    println!("Wrote: {}", output.display());
153    Ok(())
154}
155
156fn ensure_extension(path: &Path, expected: &str) -> Result<()> {
157    let ext = path.extension().and_then(OsStr::to_str);
158    if ext != Some(expected) {
159        bail!(
160            "Expected .{} file, got: {}",
161            expected,
162            ext.unwrap_or("<none>")
163        );
164    }
165    Ok(())
166}
167
168fn ensure_distinct_paths(input: &Path, output: &Path) -> Result<()> {
169    if input == output {
170        bail!("Refusing to overwrite input file: {}", input.display());
171    }
172    Ok(())
173}
174
175fn ensure_output_path(output: &Path, force: bool) -> Result<()> {
176    if !output.exists() {
177        return Ok(());
178    }
179    if !force {
180        bail!(
181            "Output file already exists: {}\nUse --force to overwrite",
182            output.display()
183        );
184    }
185    fs::remove_file(output).with_context(|| format!("failed to remove {}", output.display()))?;
186    Ok(())
187}
188
/// Where the capsule password comes from.
#[derive(Debug, Clone, Copy)]
enum PasswordMode {
    /// Ask interactively (prompt written to stderr, answer read from stdin).
    Prompt,
    /// Read a single line from standard input.
    Stdin,
}

impl PasswordMode {
    /// Chooses the mode from the `--password` / `--password-stdin` flags.
    ///
    /// Only `password_stdin` matters: `--password` merely names the default,
    /// so its value is ignored and everything else resolves to `Prompt`.
    fn from_args(_password: bool, password_stdin: bool) -> Self {
        match password_stdin {
            true => Self::Stdin,
            false => Self::Prompt,
        }
    }
}
203
204fn read_password_bytes(mode: PasswordMode, confirm: bool) -> Result<Vec<u8>> {
205    let password = match mode {
206        PasswordMode::Stdin => read_password_from_stdin()?,
207        PasswordMode::Prompt => read_password_from_prompt(confirm)?,
208    };
209
210    if password.trim().is_empty() {
211        bail!("Password cannot be empty");
212    }
213
214    Ok(password.into_bytes())
215}
216
217fn read_password_from_stdin() -> Result<String> {
218    let stdin = io::stdin();
219    let mut reader = stdin.lock();
220    let mut line = String::new();
221    let bytes = reader
222        .read_line(&mut line)
223        .context("failed to read password from stdin")?;
224    if bytes == 0 {
225        bail!("Password cannot be empty");
226    }
227    Ok(line.trim_end_matches(&['\n', '\r'][..]).to_string())
228}
229
230fn read_password_from_prompt(confirm: bool) -> Result<String> {
231    let password = read_password_hidden("Password: ")?;
232    if !confirm {
233        return Ok(password);
234    }
235    let confirm_pw = read_password_hidden("Confirm:  ")?;
236    if password != confirm_pw {
237        bail!("Passwords do not match");
238    }
239    Ok(password)
240}
241
242fn read_password_hidden(prompt: &str) -> Result<String> {
243    let is_tty = stdin_is_tty();
244    let mut stderr = io::stderr();
245    stderr.write_all(prompt.as_bytes())?;
246    stderr.flush()?;
247
248    let guard = if is_tty { disable_stdin_echo()? } else { None };
249
250    let stdin = io::stdin();
251    let mut reader = stdin.lock();
252    let mut line = String::new();
253    reader.read_line(&mut line)?;
254
255    if guard.is_some() {
256        let _ = stderr.write_all(b"\n");
257        let _ = stderr.flush();
258    }
259
260    Ok(line.trim_end_matches(&['\n', '\r'][..]).to_string())
261}
262
263fn stdin_is_tty() -> bool {
264    #[cfg(unix)]
265    unsafe {
266        isatty(libc::STDIN_FILENO) == 1
267    }
268
269    #[cfg(not(unix))]
270    {
271        false
272    }
273}
274
275#[cfg(unix)]
276fn disable_stdin_echo() -> Result<Option<EchoGuard>> {
277    let fd = libc::STDIN_FILENO;
278    unsafe {
279        let mut current: termios = std::mem::zeroed();
280        if tcgetattr(fd, &mut current) != 0 {
281            return Ok(None);
282        }
283        let mut updated = current;
284        updated.c_lflag &= !ECHO;
285        if tcsetattr(fd, TCSANOW, &updated) != 0 {
286            return Ok(None);
287        }
288        Ok(Some(EchoGuard { fd, previous: current }))
289    }
290}
291
/// Echo suppression is not implemented off Unix.
///
/// Always reports that no guard was installed, so the caller reads the
/// password with echo unchanged.
#[cfg(not(unix))]
fn disable_stdin_echo() -> Result<Option<EchoGuard>> {
    let no_guard: Option<EchoGuard> = None;
    Ok(no_guard)
}
296
/// Guard returned by `disable_stdin_echo`.
///
/// It holds the terminal attributes that were in effect before echo was
/// turned off, and its `Drop` impl writes them back.
#[cfg(unix)]
struct EchoGuard {
    /// Descriptor whose attributes were changed (stdin in `disable_stdin_echo`).
    fd: i32,
    /// Attributes captured by `tcgetattr` before `ECHO` was cleared.
    previous: termios,
}
302
303#[cfg(unix)]
304impl Drop for EchoGuard {
305    fn drop(&mut self) {
306        unsafe {
307            let _ = tcsetattr(self.fd, TCSANOW, &self.previous);
308        }
309    }
310}
311
/// Placeholder guard for non-Unix platforms.
///
/// The non-Unix `disable_stdin_echo` always returns `Ok(None)`, so this type is
/// never constructed. It exists only so the shared function signatures compile.
#[cfg(not(unix))]
struct EchoGuard;

// Nothing to restore: no terminal state is changed on non-Unix platforms.
#[cfg(not(unix))]
impl Drop for EchoGuard {
    fn drop(&mut self) {}
}