1use clap::Clap;
2use context::ContextErrorKind;
3use log::debug;
4use std::{error::Error, fs::{File}, io::Read, sync::mpsc::channel};
5
6use crate::{
7 context::{self, Context, PivotByLineContext},
8 encoder::Capacity,
9 method::complex::{eluv::ELUVMethod, extended_line::ExtendedLineMethod},
10};
11
12use super::{encoder::{determine_pivot_size, validate_pivot_smaller_than_text}, progress::{new_progress_bar, spawn_progress_thread, ProgressStatus}, writer::Writer};
13
14#[derive(Clap)]
16pub struct GetCapacityCommand {
17 #[clap(short, long)]
19 cover: String,
20
21 #[clap(short, long)]
23 pivot: usize,
24
25 #[clap(long, group = "method_args")]
29 eluv: bool,
30
31 #[clap(long = "eline", group = "method_args")]
35 #[allow(dead_code)]
36 extended_line: bool,
37}
38
39impl GetCapacityCommand {
40 pub fn run(&self) -> Result<u32, Box<dyn Error>> {
41 let cover_file_input = File::open(&self.cover)?;
42
43 self.get_cover_text_capacity(cover_file_input)
44 }
45
46 pub(crate) fn get_cover_text_capacity(
47 &self,
48 mut cover_input: impl Read,
49 ) -> Result<u32, Box<dyn Error>> {
50 let mut cover_text = String::new();
51
52 cover_input.read_to_string(&mut cover_text)?;
53 let mut pivot_word_context = PivotByLineContext::new(&cover_text, self.pivot);
54 let mut text_fragment_count = 0;
55
56 let max_word_length = determine_pivot_size(cover_text.split_whitespace());
57 validate_pivot_smaller_than_text(self.pivot, &cover_text)?;
58
59 debug!("Longest word in the cover text is {}", max_word_length);
60
61 if max_word_length > self.pivot {
62 Writer::warn("This pivot might not guarantee the secret data will be encodable!");
63 }
64
65 let progress_bar = new_progress_bar(cover_text.len() as u64);
66 let (tx, rx) = channel::<ProgressStatus>();
67 progress_bar.set_message("Calculating the capacity...");
68 spawn_progress_thread(progress_bar.clone(), rx);
69
70 loop {
71 let result = pivot_word_context.load_text();
72
73 match result {
74 Ok(fragment) => {
75 tx.send(ProgressStatus::Step(fragment.len() as u64)).ok();
76 text_fragment_count += 1;
77 }
78 Err(error) => match error.kind() {
79 ContextErrorKind::CannotConstructLine => {
80 tx.send(ProgressStatus::Finished).ok();
81 progress_bar.abandon_with_message("Error occurred");
82 return Err(error.into());
83 }
84 ContextErrorKind::NoTextLeft => {
85 tx.send(ProgressStatus::Finished).ok();
86 progress_bar.finish_with_message("Capacity calculated");
87 break;
88 }
89 },
90 }
91
92 pivot_word_context.next_word();
93 }
94
95 let method = self.get_method();
96 Ok(text_fragment_count * method.bitrate() as u32)
97 }
98
99 pub(crate) fn get_method(&self) -> Box<dyn Capacity> {
100 if self.eluv {
101 Box::new(ELUVMethod::default())
102 } else {
103 Box::new(ExtendedLineMethod::default())
104 }
105 }
106}
107
108
109
// Compiled only for test builds, so the imports below never linger unused in
// regular builds.
#[cfg(test)]
mod test {
    use std::error::Error;

    use super::GetCapacityCommand;

    /// The capacity is computed from an in-memory cover text using the
    /// Extended Line method.
    #[test]
    fn returns_capacity_for_given_method() -> Result<(), Box<dyn Error>> {
        let cover_input = "a b c ".repeat(2);

        let command = GetCapacityCommand {
            cover: "stub".into(),
            pivot: 3,
            eluv: false,
            extended_line: true,
        };

        // Propagate the error with `?` so a failure reports its cause instead
        // of collapsing into `None`.
        let capacity = command.get_cover_text_capacity(cover_input.as_bytes())?;
        assert_eq!(capacity, 6);
        Ok(())
    }

    /// A pivot larger than the whole cover text must be rejected.
    #[test]
    fn fails_when_pivot_is_too_large() {
        let cover_input = "aaaaa";

        let command = GetCapacityCommand {
            cover: "stub".into(),
            pivot: 6,
            eluv: false,
            extended_line: true,
        };

        let result = command.get_cover_text_capacity(cover_input.as_bytes());
        assert!(result.is_err());
    }
}