1use std::io::Cursor;
16
17use bstr::ByteSlice as _;
18use tokio::io::AsyncRead;
19use tokio::io::AsyncReadExt as _;
20
21use crate::config::ConfigGetError;
22use crate::settings::UserSettings;
23
/// Heuristically decides whether `bytes` looks like binary (non-text) content.
///
/// The contents are considered binary as soon as either of these is found:
/// - a NUL byte, or
/// - a `\r` that is not immediately followed by `\n` (a "lone CR"), including a
///   `\r` that is the very last byte of `bytes`.
///
/// Anything else (including empty input) is considered text.
fn is_binary(bytes: &[u8]) -> bool {
    bytes.iter().enumerate().any(|(index, &byte)| match byte {
        b'\0' => true,
        // Only a CR that begins a CRLF pair is acceptable in text.
        b'\r' => bytes.get(index + 1) != Some(&b'\n'),
        _ => false,
    })
}
42
43#[derive(Clone)]
44pub(crate) struct TargetEolStrategy {
45 eol_conversion_mode: EolConversionMode,
46}
47
48impl TargetEolStrategy {
49 pub(crate) fn new(eol_conversion_mode: EolConversionMode) -> Self {
50 Self {
51 eol_conversion_mode,
52 }
53 }
54
55 const PROBE_LIMIT: u64 = 8 << 10;
57
58 pub(crate) async fn convert_eol_for_snapshot<'a>(
59 &self,
60 mut contents: impl AsyncRead + Send + Unpin + 'a,
61 ) -> Result<Box<dyn AsyncRead + Send + Unpin + 'a>, std::io::Error> {
62 match self.eol_conversion_mode {
63 EolConversionMode::None => Ok(Box::new(contents)),
64 EolConversionMode::Input | EolConversionMode::InputOutput => {
65 let mut peek = vec![];
66 (&mut contents)
67 .take(Self::PROBE_LIMIT)
68 .read_to_end(&mut peek)
69 .await?;
70 let target_eol = if is_binary(&peek) {
71 TargetEol::PassThrough
72 } else {
73 TargetEol::Lf
74 };
75 let peek = Cursor::new(peek);
76 let contents = peek.chain(contents);
77 convert_eol(contents, target_eol).await
78 }
79 }
80 }
81
82 pub(crate) async fn convert_eol_for_update<'a>(
83 &self,
84 mut contents: impl AsyncRead + Send + Unpin + 'a,
85 ) -> Result<Box<dyn AsyncRead + Send + Unpin + 'a>, std::io::Error> {
86 match self.eol_conversion_mode {
87 EolConversionMode::None | EolConversionMode::Input => Ok(Box::new(contents)),
88 EolConversionMode::InputOutput => {
89 let mut peek = vec![];
90 (&mut contents)
91 .take(Self::PROBE_LIMIT)
92 .read_to_end(&mut peek)
93 .await?;
94 let target_eol = if is_binary(&peek) {
95 TargetEol::PassThrough
96 } else {
97 TargetEol::Crlf
98 };
99 let peek = Cursor::new(peek);
100 let contents = peek.chain(contents);
101 convert_eol(contents, target_eol).await
102 }
103 }
104 }
105}
106
107#[derive(Debug, PartialEq, Eq, Copy, Clone, serde::Deserialize)]
110#[serde(rename_all(deserialize = "kebab-case"))]
111pub enum EolConversionMode {
112 None,
114 Input,
117 InputOutput,
121}
122
123impl EolConversionMode {
124 pub fn try_from_settings(user_settings: &UserSettings) -> Result<Self, ConfigGetError> {
127 user_settings.get("working-copy.eol-conversion")
128 }
129}
130
/// The line ending that `convert_eol` rewrites every line terminator to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TargetEol {
    /// Terminate each line with `\n`.
    Lf,
    /// Terminate each line with `\r\n`.
    Crlf,
    /// Leave the contents untouched.
    PassThrough,
}
137
138async fn convert_eol<'a>(
139 mut input: impl AsyncRead + Send + Unpin + 'a,
140 target_eol: TargetEol,
141) -> Result<Box<dyn AsyncRead + Send + Unpin + 'a>, std::io::Error> {
142 let eol = match target_eol {
143 TargetEol::PassThrough => {
144 return Ok(Box::new(input));
145 }
146 TargetEol::Lf => b"\n".as_slice(),
147 TargetEol::Crlf => b"\r\n".as_slice(),
148 };
149
150 let mut contents = vec![];
151 input.read_to_end(&mut contents).await?;
152 let lines = contents.lines_with_terminator();
153 let mut res = Vec::<u8>::with_capacity(contents.len());
154 fn trim_last_eol(input: &[u8]) -> Option<&[u8]> {
155 input
156 .strip_suffix(b"\r\n")
157 .or_else(|| input.strip_suffix(b"\n"))
158 }
159 for line in lines {
160 if let Some(line) = trim_last_eol(line) {
161 res.extend_from_slice(line);
162 res.extend_from_slice(eol);
164 } else {
165 res.extend_from_slice(line);
168 }
169 }
170 Ok(Box::new(Cursor::new(res)))
171}
172
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::pin::Pin;
    use std::task::Poll;

    use test_case::test_case;

    use super::*;

    #[test_case(b"", false; "empty input")]
    #[test_case(b"a\nb\n", false; "LF text")]
    #[test_case(b"a\r\nb\r\n", false; "CRLF text")]
    #[test_case(b"a\0b", true; "NUL byte")]
    #[test_case(b"a\rb", true; "lone CR in the middle")]
    #[test_case(b"a\r", true; "lone CR at the end")]
    fn test_is_binary(input: &[u8], expected: bool) {
        assert_eq!(is_binary(input), expected);
    }

    #[tokio::main(flavor = "current_thread")]
    #[test_case(b"a\n", TargetEol::PassThrough, b"a\n"; "LF text with no EOL conversion")]
    #[test_case(b"a\r\n", TargetEol::PassThrough, b"a\r\n"; "CRLF text with no EOL conversion")]
    #[test_case(b"a", TargetEol::PassThrough, b"a"; "no EOL text with no EOL conversion")]
    #[test_case(b"a\n", TargetEol::Crlf, b"a\r\n"; "LF text with CRLF EOL conversion")]
    #[test_case(b"a\r\n", TargetEol::Crlf, b"a\r\n"; "CRLF text with CRLF EOL conversion")]
    #[test_case(b"a", TargetEol::Crlf, b"a"; "no EOL text with CRLF conversion")]
    #[test_case(b"", TargetEol::Crlf, b""; "empty text with CRLF EOL conversion")]
    #[test_case(b"a\nb", TargetEol::Crlf, b"a\r\nb"; "text ends without EOL with CRLF EOL conversion")]
    #[test_case(b"a\n", TargetEol::Lf, b"a\n"; "LF text with LF EOL conversion")]
    #[test_case(b"a\r\n", TargetEol::Lf, b"a\n"; "CRLF text with LF EOL conversion")]
    #[test_case(b"a", TargetEol::Lf, b"a"; "no EOL text with LF conversion")]
    #[test_case(b"", TargetEol::Lf, b""; "empty text with LF EOL conversion")]
    #[test_case(b"a\r\nb", TargetEol::Lf, b"a\nb"; "text ends without EOL with LF EOL conversion")]
    async fn test_eol_conversion(input: &[u8], target_eol: TargetEol, expected_output: &[u8]) {
        let mut input = input;
        let mut output = vec![];
        convert_eol(&mut input, target_eol)
            .await
            .expect("Failed to call convert_eol")
            .read_to_end(&mut output)
            .await
            .expect("Failed to read from the result");
        assert_eq!(output, expected_output);
    }

    /// Reader that fails its first read with the stored error, then reports
    /// EOF on every subsequent read.
    struct ErrorReader(Option<std::io::Error>);

    impl ErrorReader {
        fn new(error: std::io::Error) -> Self {
            Self(Some(error))
        }
    }

    impl AsyncRead for ErrorReader {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
            _buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            if let Some(e) = self.0.take() {
                return Poll::Ready(Err(e));
            }
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::main(flavor = "current_thread")]
    #[test_case(TargetEol::PassThrough; "no EOL conversion")]
    #[test_case(TargetEol::Lf; "LF EOL conversion")]
    #[test_case(TargetEol::Crlf; "CRLF EOL conversion")]
    async fn test_eol_convert_eol_read_error(target_eol: TargetEol) {
        let message = "test error";
        let error_reader = ErrorReader::new(std::io::Error::other(message));
        let mut output = vec![];
        // The error may surface either from `convert_eol` itself (when it reads
        // eagerly) or from reading the returned reader (pass-through).
        let err = match convert_eol(error_reader, target_eol).await {
            Ok(mut reader) => reader.read_to_end(&mut output).await,
            Err(e) => Err(e),
        }
        .expect_err("should fail");
        // Walk the `source()` chain looking for the original message.
        let has_expected_error_message = (0..)
            .scan(Some(&err as &(dyn Error + 'static)), |err, _| {
                let current_err = err.take()?;
                *err = current_err.source();
                Some(current_err)
            })
            .any(|e| e.to_string() == message);
        assert!(
            has_expected_error_message,
            "should have expected error message: {message}"
        );
    }

    #[tokio::main(flavor = "current_thread")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::None,
    }, b"\r\n", b"\r\n"; "none settings")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::Input,
    }, b"\r\n", b"\n"; "input settings text input")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::InputOutput,
    }, b"\r\n", b"\n"; "input output settings text input")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::Input,
    }, b"\0\r\n", b"\0\r\n"; "input settings binary input")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::InputOutput,
    }, b"\0\r\n", b"\0\r\n"; "input output settings binary input with NUL")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::InputOutput,
    }, b"\r\r\n", b"\r\r\n"; "input output settings binary input with lone CR")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::Input,
    }, &[0; 20 << 10], &[0; 20 << 10]; "input settings long binary input")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::InputOutput,
    }, &[0; 20 << 10], &[0; 20 << 10]; "input output settings long binary input")]
    async fn test_eol_strategy_convert_eol_for_snapshot(
        strategy: TargetEolStrategy,
        contents: &[u8],
        expected_output: &[u8],
    ) {
        let mut actual_output = vec![];
        strategy
            .convert_eol_for_snapshot(contents)
            .await
            .unwrap()
            .read_to_end(&mut actual_output)
            .await
            .unwrap();
        assert_eq!(actual_output, expected_output);
    }

    #[tokio::main(flavor = "current_thread")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::None,
    }, b"\n", b"\n"; "none settings")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::Input,
    }, b"\n", b"\n"; "input settings")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::InputOutput,
    }, b"\n", b"\r\n"; "input output settings text input")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::InputOutput,
    }, b"\r\n", b"\r\n"; "input output settings CRLF text input")]
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::InputOutput,
    }, b"\0\n", b"\0\n"; "input output settings binary input")]
    // Must use `InputOutput`: `Input` never probes on update, so it would not
    // exercise the probe limit at all.
    #[test_case(TargetEolStrategy {
        eol_conversion_mode: EolConversionMode::InputOutput,
    }, &[0; 20 << 10], &[0; 20 << 10]; "input output settings long binary input")]
    async fn test_eol_strategy_convert_eol_for_update(
        strategy: TargetEolStrategy,
        contents: &[u8],
        expected_output: &[u8],
    ) {
        let mut actual_output = vec![];
        strategy
            .convert_eol_for_update(contents)
            .await
            .unwrap()
            .read_to_end(&mut actual_output)
            .await
            .unwrap();
        assert_eq!(actual_output, expected_output);
    }
}