1use ssec_core::decrypt::{Decrypt, SsecHeaderError};
2use futures_util::{Stream, StreamExt};
3use tokio::io::AsyncWriteExt;
4use zeroize::Zeroizing;
5use indicatif::{ProgressBar, ProgressStyle};
6use std::path::PathBuf;
7use crate::cli::{DecArgs, FetchArgs};
8use crate::file::new_async_tempfile;
9use crate::password::prompt_password;
10use crate::io::IoBundle;
11use crate::{DEFINITE_BAR_STYLE, INDEFINITE_BAR_STYLE};
12
/// Progress-bar template applied while the password is being tried against
/// the input (see the `try_password` call in `dec_stream_to`).
const SPINNER_STYLE: &str = "{spinner} deriving decryption key";
14
/// Reports a fatal error and leaves the enclosing function with `Err(())`.
///
/// `$bar` names a value with a `suspend` method (a progress bar in this
/// file); the message is printed to stderr from inside that `suspend`
/// callback. `$msg` is a format-string literal and may capture variables
/// that are in scope at the call site (e.g. `"{e}"`).
macro_rules! bail {
    ($bar:ident, $msg:literal) => {
        $bar.suspend(|| {
            eprintln!($msg);
        });
        return Err(());
    };
}
23
24async fn dec_stream_to<E, S>(
25 stream: S,
26 password: Zeroizing<Vec<u8>>,
27 out_path: PathBuf,
28 show_progress: bool,
29 enc_len: Option<u64>
30) -> Result<(), ()>
31where
32 E: std::error::Error,
33 S: Stream<Item = Result<bytes::Bytes, E>> + Unpin + Send + 'static
34{
35 let progress = match show_progress {
36 true => ProgressBar::new_spinner(),
37 false => ProgressBar::hidden()
38 };
39 let stream = stream.map({
40 let progress = progress.clone();
41 move |b| {
42 if let Ok(b) = &b {
43 progress.inc(b.len() as u64);
44 }
45 b
46 }
47 });
48
49 let (dec, f_out) = tokio::join!(
50 async {
51 let dec = Decrypt::new(stream).await?;
52 Ok::<_, SsecHeaderError<E>>(tokio::task::spawn_blocking({
53 let progress = progress.clone();
54 move || {
55 progress.set_style(ProgressStyle::with_template(SPINNER_STYLE).unwrap());
56 progress.enable_steady_tick(std::time::Duration::from_millis(100));
57
58 dec.try_password(&password)
59 }
60 }).await.unwrap())
61 },
62 new_async_tempfile()
63 );
64
65 let mut dec = match dec {
66 Ok(Ok(dec)) => dec,
67 Ok(Err(_)) => {
68 bail!(progress, "password incorrect");
69 },
70 Err(SsecHeaderError::NotSsec) => {
71 bail!(progress, "input is not a SSEC file");
72 },
73 Err(SsecHeaderError::UnsupportedVersion(0)) => {
74 bail!(progress, "input is from an old version of SSEC, consider downgrading to `ssec-cli` version 0.3");
75 },
76 Err(SsecHeaderError::UnsupportedVersion(v)) => {
77 bail!(progress, "input is from a future version of SSEC (version {v:?}), consider updating `ssec-cli` to the latest version");
78 },
79 Err(SsecHeaderError::UnsupportedCompression(c)) => {
80 bail!(progress, "input has unimplemented compression (type {c:?}), consider updating `ssec-cli` to the latest version");
81 },
82 Err(SsecHeaderError::Stream(e)) => {
83 bail!(progress, "input stream failed: {e}");
84 }
85 };
86 let mut f_out = f_out.unwrap();
87
88 progress.disable_steady_tick();
89 match enc_len {
90 Some(enc_len) => {
91 progress.set_length(enc_len);
92 progress.set_style(ProgressStyle::with_template(DEFINITE_BAR_STYLE).unwrap());
93 },
94 None => progress.set_style(ProgressStyle::with_template(INDEFINITE_BAR_STYLE).unwrap())
95 }
96 progress.reset();
97
98 while let Some(bytes) = dec.next().await {
99 let b = match bytes {
100 Ok(b) => b,
101 Err(e) => {
102 bail!(progress, "{e}");
103 },
104 };
105
106 f_out.as_mut().write_all(&b).await.unwrap();
107 }
108
109 f_out.as_mut().shutdown().await.unwrap();
110
111 f_out.persist(out_path).await.unwrap();
112
113 Ok(())
114}
115
116pub async fn dec_file<B: IoBundle>(args: DecArgs, io: B) -> Result<(), ()> {
117 let password = prompt_password(io).await.map_err(|e| {
118 eprintln!("failed to read password interactively: {e}");
119 })?;
120
121 let f_in = tokio::fs::File::open(&args.in_file).await.map_err(|e| {
122 eprintln!("failed to open file {:?}: {e}", args.in_file);
123 })?;
124
125 let f_in_metadata = f_in.metadata().await.map_err(|e| {
126 eprintln!("failed to get metadata of input file: {e}");
127 })?;
128
129 let s = tokio_util::io::ReaderStream::new(f_in);
130
131 dec_stream_to(
132 s,
133 password,
134 args.out_file,
135 B::is_interactive() && !args.silent,
136 Some(f_in_metadata.len())
137 ).await
138}
139
140pub async fn dec_fetch<B: IoBundle>(args: FetchArgs, io: B) -> Result<(), ()> {
141 let password = prompt_password(io).await.map_err(|e| {
142 eprintln!("failed to read password interactively: {e}");
143 })?;
144
145 let client = reqwest::Client::new();
146
147 let resp = client.get(args.url.clone()).send().await.map_err(|e| {
148 eprintln!("failed to fetch remote file {:?}: {e}", args.url);
149 })?;
150 let enc_len = resp.content_length();
151 let s = resp.bytes_stream();
152
153 dec_stream_to(
154 s,
155 password,
156 args.out_file,
157 B::is_interactive() && !args.silent,
158 enc_len
159 ).await
160}