1use std::io::{self, Read, Write};
4use std::process::{Command, Stdio};
5
6use git_lfs_pointer::{Extension, Oid, Pointer};
7use git_lfs_store::{Store, StoreError};
8use sha2::{Digest, Sha256};
9use tempfile::NamedTempFile;
10
11use crate::detect_pointer;
12
/// Size in bytes of the scratch buffer used when streaming content through the hasher
/// (64 KiB).
const COPY_BUFFER: usize = 1 << 16;
14
15#[derive(Debug)]
17pub enum CleanOutcome {
18 Passthrough(Pointer),
22 Stored(Pointer),
26}
27
28impl CleanOutcome {
29 pub fn pointer(&self) -> &Pointer {
32 match self {
33 Self::Passthrough(p) | Self::Stored(p) => p,
34 }
35 }
36
37 pub fn was_passthrough(&self) -> bool {
39 matches!(self, Self::Passthrough(_))
40 }
41}
42
/// A configured LFS extension as consumed by [`clean`].
#[derive(Debug, Clone)]
pub struct CleanExtension {
    /// Extension name; copied into the matching extension entry of the pointer.
    pub name: String,
    /// Extension priority; copied into the matching extension entry of the pointer.
    pub priority: u8,
    /// Clean command line. It is split on whitespace, and `%f` is replaced with the
    /// path of the file being cleaned. A blank command is rejected by [`clean`].
    pub command: String,
}
54
55#[derive(Debug, thiserror::Error)]
56pub enum CleanError {
57 #[error(transparent)]
58 Io(#[from] io::Error),
59 #[error(transparent)]
60 Store(#[from] StoreError),
61 #[error("extension {name:?} has no clean command configured")]
62 ExtensionMissingCommand { name: String },
63 #[error("failed to spawn extension {name:?}: {source}")]
64 ExtensionSpawnFailed {
65 name: String,
66 #[source]
67 source: io::Error,
68 },
69 #[error("extension {name:?} exited with status {status:?}")]
70 ExtensionFailed { name: String, status: Option<i32> },
71}
72
73pub fn clean<R: Read, W: Write>(
90 store: &Store,
91 input: &mut R,
92 output: &mut W,
93 path: &str,
94 extensions: &[CleanExtension],
95) -> Result<CleanOutcome, CleanError> {
96 let (head, maybe_pointer) = detect_pointer(input)?;
97
98 if let Some(pointer) = maybe_pointer {
99 output.write_all(&head)?;
100 return Ok(CleanOutcome::Passthrough(pointer));
101 }
102
103 if extensions.is_empty() {
104 let mut combined = head.as_slice().chain(input);
105 let (oid, size) = store.insert(&mut combined)?;
106 let pointer = Pointer::new(oid, size);
107 output.write_all(pointer.encode().as_bytes())?;
108 return Ok(CleanOutcome::Stored(pointer));
109 }
110
111 for ext in extensions {
112 if ext.command.trim().is_empty() {
113 return Err(CleanError::ExtensionMissingCommand {
114 name: ext.name.clone(),
115 });
116 }
117 }
118
119 let tmp_dir = store.tmp_dir();
120 std::fs::create_dir_all(&tmp_dir)?;
121
122 let mut combined = head.as_slice().chain(input);
124 let mut current_tmp = NamedTempFile::new_in(&tmp_dir)?;
125 let orig_oid = hash_and_write(&mut combined, current_tmp.as_file_mut())?;
126 let mut input_oids: Vec<Oid> = Vec::with_capacity(extensions.len());
127 input_oids.push(orig_oid);
128
129 for (i, ext) in extensions.iter().enumerate() {
133 let cmd_str = ext.command.replace("%f", path);
134 let mut parts = cmd_str.split_whitespace();
135 let prog = parts
136 .next()
137 .ok_or_else(|| CleanError::ExtensionMissingCommand {
138 name: ext.name.clone(),
139 })?;
140 let args: Vec<&str> = parts.collect();
141
142 let stdin_file = std::fs::File::open(current_tmp.path())?;
143 let mut child = Command::new(prog)
144 .args(&args)
145 .stdin(stdin_file)
146 .stdout(Stdio::piped())
147 .stderr(Stdio::inherit())
148 .spawn()
149 .map_err(|e| CleanError::ExtensionSpawnFailed {
150 name: ext.name.clone(),
151 source: e,
152 })?;
153 let mut stdout = child.stdout.take().expect("piped stdout");
154
155 let is_last = i + 1 == extensions.len();
156 if is_last {
157 let (oid, size) = store.insert(&mut stdout)?;
158 let status = child.wait()?;
159 if !status.success() {
160 return Err(CleanError::ExtensionFailed {
161 name: ext.name.clone(),
162 status: status.code(),
163 });
164 }
165
166 let pointer_extensions = build_pointer_extensions(extensions, &input_oids);
167 let pointer = Pointer {
168 oid,
169 size,
170 extensions: pointer_extensions,
171 canonical: true,
172 };
173 output.write_all(pointer.encode().as_bytes())?;
174 return Ok(CleanOutcome::Stored(pointer));
175 }
176
177 let mut next_tmp = NamedTempFile::new_in(&tmp_dir)?;
178 let next_oid = hash_and_write(&mut stdout, next_tmp.as_file_mut())?;
179 let status = child.wait()?;
180 if !status.success() {
181 return Err(CleanError::ExtensionFailed {
182 name: ext.name.clone(),
183 status: status.code(),
184 });
185 }
186
187 current_tmp = next_tmp;
188 input_oids.push(next_oid);
189 }
190
191 unreachable!("clean loop exited without storing")
194}
195
196fn build_pointer_extensions(extensions: &[CleanExtension], input_oids: &[Oid]) -> Vec<Extension> {
197 extensions
198 .iter()
199 .enumerate()
200 .map(|(i, ext)| Extension {
201 name: ext.name.clone(),
202 priority: ext.priority,
203 oid: input_oids[i],
204 })
205 .collect()
206}
207
208fn hash_and_write<R: Read>(src: &mut R, dst: &mut std::fs::File) -> io::Result<Oid> {
209 let mut hasher = Sha256::new();
210 let mut buf = vec![0u8; COPY_BUFFER];
211 loop {
212 let n = src.read(&mut buf)?;
213 if n == 0 {
214 break;
215 }
216 hasher.update(&buf[..n]);
217 dst.write_all(&buf[..n])?;
218 }
219 dst.flush()?;
220 let bytes: [u8; 32] = hasher.finalize().into();
221 Ok(Oid::from_bytes(bytes))
222}
223
#[cfg(test)]
mod tests {
    //! Unit tests for `clean`. Tests that actually execute external programs (`tr`,
    //! `false`) are Unix-only, because those utilities are not available elsewhere.

    use super::*;
    use git_lfs_pointer::VERSION_LATEST;
    use tempfile::TempDir;

    /// SHA-256 of `b"abc"`.
    const ABC_OID: &str = "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad";
    /// SHA-256 of `b"ABC"`.
    const UPPER_ABC_OID: &str = "b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78";

    fn fixture() -> (TempDir, Store) {
        let tmp = TempDir::new().unwrap();
        let store = Store::new(tmp.path().join("lfs"));
        (tmp, store)
    }

    fn run(store: &Store, input: &[u8]) -> (CleanOutcome, Vec<u8>) {
        let mut out = Vec::new();
        let outcome = clean(store, &mut { input }, &mut out, "", &[]).unwrap();
        (outcome, out)
    }

    fn expect_stored(outcome: CleanOutcome) -> Pointer {
        match outcome {
            CleanOutcome::Stored(p) => p,
            o => panic!("expected Stored, got {o:?}"),
        }
    }

    fn ext(name: &str, priority: u8, command: &str) -> CleanExtension {
        CleanExtension {
            name: name.into(),
            priority,
            command: command.into(),
        }
    }

    #[test]
    fn small_content_is_hashed_and_stored() {
        let (_t, store) = fixture();
        let (outcome, out) = run(&store, b"hello world!");
        let p = expect_stored(outcome);
        assert_eq!(p.size, 12);
        assert!(store.contains(p.oid));
        assert_eq!(out, p.encode().as_bytes());
    }

    #[test]
    fn known_sha256_for_abc() {
        let (_t, store) = fixture();
        let (outcome, _) = run(&store, b"abc");
        let expected: Oid = ABC_OID.parse().unwrap();
        assert_eq!(outcome.pointer().oid, expected);
    }

    #[test]
    fn pseudo_pointer_with_extra_text_is_hashed() {
        let input = b"version https://git-lfs.github.com/spec/v1\n\
            oid sha256:7cd8be1d2cd0dd22cd9d229bb6b5785009a05e8b39d405615d882caac56562b5\n\
            size 1024\n\
            \n\
            This is my test pointer.\n";
        let (_t, store) = fixture();
        let (outcome, out) = run(&store, input);
        let p = expect_stored(outcome);
        assert_eq!(p.size, input.len() as u64);
        assert!(store.contains(p.oid));
        assert_eq!(out, p.encode().as_bytes());
    }

    #[test]
    fn oversized_pointer_shaped_input_is_hashed() {
        let mut input = Vec::from(
            &b"version https://git-lfs.github.com/spec/v1\n\
            oid sha256:cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc\n\
            size 5\n"[..],
        );
        input.extend(std::iter::repeat_n(b'x', 2000));
        let (_t, store) = fixture();
        let (outcome, _) = run(&store, &input);
        let p = expect_stored(outcome);
        assert_eq!(p.size, input.len() as u64);
        assert!(store.contains(p.oid));
    }

    #[test]
    fn streaming_megabyte_input_works() {
        let (_t, store) = fixture();
        let content: Vec<u8> = (0..1_048_576u32).map(|i| (i ^ (i >> 5)) as u8).collect();
        let (outcome, _) = run(&store, &content);
        assert_eq!(outcome.pointer().size, content.len() as u64);
        assert!(store.contains(outcome.pointer().oid));
    }

    #[test]
    fn canonical_pointer_passes_through_verbatim() {
        let (_t, store) = fixture();
        let oid_hex = "4d7a214614ab2935c943f9e0ff69d22eadbb8f32b1258daaa5e2ca24d17e2393";
        let pointer_text = format!("version {VERSION_LATEST}\noid sha256:{oid_hex}\nsize 12345\n");
        let (outcome, out) = run(&store, pointer_text.as_bytes());
        match &outcome {
            CleanOutcome::Passthrough(p) => assert!(p.canonical),
            o => panic!("expected Passthrough, got {o:?}"),
        }
        assert_eq!(
            out,
            pointer_text.as_bytes(),
            "output must be input verbatim"
        );
        assert!(!store.root().join("objects").exists());
    }

    #[test]
    fn non_canonical_pointer_passes_through_verbatim() {
        let (_t, store) = fixture();
        let oid_hex = "4d7a214614ab2935c943f9e0ff69d22eadbb8f32b1258daaa5e2ca24d17e2393";
        let crlf = format!("version {VERSION_LATEST}\r\noid sha256:{oid_hex}\r\nsize 12345\r\n");
        let (outcome, out) = run(&store, crlf.as_bytes());
        match &outcome {
            CleanOutcome::Passthrough(p) => assert!(!p.canonical),
            o => panic!("expected Passthrough, got {o:?}"),
        }
        assert_eq!(out, crlf.as_bytes());
    }

    #[test]
    fn empty_input_is_passthrough_empty_pointer() {
        let (_t, store) = fixture();
        let (outcome, out) = run(&store, b"");
        match &outcome {
            CleanOutcome::Passthrough(p) => {
                assert_eq!(p, &Pointer::empty());
            }
            o => panic!("expected Passthrough, got {o:?}"),
        }
        assert!(out.is_empty(), "empty pointer encodes to empty bytes");
    }

    #[test]
    fn passthrough_is_idempotent() {
        let (_t, store) = fixture();
        let (_, first) = run(&store, b"some content here");
        let (outcome2, second) = run(&store, &first);
        assert!(matches!(outcome2, CleanOutcome::Passthrough(_)));
        assert_eq!(first, second);
    }

    #[cfg(unix)]
    #[test]
    fn single_extension_records_input_oid() {
        let (_t, store) = fixture();
        let exts = vec![ext("upper", 0, "tr a-z A-Z")];

        let mut out = Vec::new();
        let outcome = clean(&store, &mut &b"abc"[..], &mut out, "foo.txt", &exts).unwrap();
        let pointer = expect_stored(outcome);

        let abc_oid: Oid = ABC_OID.parse().unwrap();
        let upper_oid: Oid = UPPER_ABC_OID.parse().unwrap();

        assert_eq!(pointer.extensions.len(), 1);
        assert_eq!(pointer.extensions[0].name, "upper");
        assert_eq!(pointer.extensions[0].priority, 0);
        assert_eq!(pointer.extensions[0].oid, abc_oid);
        assert_eq!(pointer.oid, upper_oid);
        assert_eq!(pointer.size, 3);
        assert!(store.contains(upper_oid));
        let mut f = store.open(upper_oid).unwrap();
        let mut bytes = Vec::new();
        std::io::Read::read_to_end(&mut f, &mut bytes).unwrap();
        assert_eq!(bytes, b"ABC");
    }

    #[cfg(unix)]
    #[test]
    fn chained_extensions_record_each_stage_input() {
        let (_t, store) = fixture();
        let exts = vec![ext("upper", 0, "tr a-z A-Z"), ext("lower", 1, "tr A-Z a-z")];

        let mut out = Vec::new();
        let outcome = clean(&store, &mut &b"abc"[..], &mut out, "foo.txt", &exts).unwrap();
        let pointer = expect_stored(outcome);

        let abc_oid: Oid = ABC_OID.parse().unwrap();
        let upper_oid: Oid = UPPER_ABC_OID.parse().unwrap();

        assert_eq!(pointer.extensions.len(), 2);
        assert_eq!(pointer.extensions[0].name, "upper");
        assert_eq!(pointer.extensions[0].oid, abc_oid);
        assert_eq!(pointer.extensions[1].name, "lower");
        assert_eq!(pointer.extensions[1].priority, 1);
        assert_eq!(pointer.extensions[1].oid, upper_oid);
        // upper then lower round-trips back to the original bytes.
        assert_eq!(pointer.oid, abc_oid);
        assert_eq!(pointer.size, 3);
        assert!(store.contains(abc_oid));
        assert_eq!(out, pointer.encode().as_bytes());
    }

    #[test]
    fn extensions_skipped_for_passthrough_pointer() {
        let (_t, store) = fixture();
        let oid_hex = "4d7a214614ab2935c943f9e0ff69d22eadbb8f32b1258daaa5e2ca24d17e2393";
        let pointer_text = format!("version {VERSION_LATEST}\noid sha256:{oid_hex}\nsize 12345\n");
        // Would fail if it were ever run; passthrough must not run it.
        let exts = vec![ext("fail", 0, "false")];
        let mut out = Vec::new();
        let outcome = clean(&store, &mut pointer_text.as_bytes(), &mut out, "x", &exts).unwrap();
        assert!(matches!(outcome, CleanOutcome::Passthrough(_)));
        assert_eq!(out, pointer_text.as_bytes());
    }

    #[cfg(unix)]
    #[test]
    fn extension_failure_is_propagated() {
        let (_t, store) = fixture();
        let exts = vec![ext("fail", 0, "false")];
        let mut out = Vec::new();
        let err = clean(&store, &mut &b"hello"[..], &mut out, "x", &exts).unwrap_err();
        assert!(
            matches!(err, CleanError::ExtensionFailed { .. }),
            "got {err:?}"
        );
    }

    #[test]
    fn blank_extension_command_is_rejected() {
        let (_t, store) = fixture();
        let exts = vec![ext("blank", 0, "   ")];
        let mut out = Vec::new();
        let err = clean(&store, &mut &b"hello"[..], &mut out, "x", &exts).unwrap_err();
        assert!(
            matches!(err, CleanError::ExtensionMissingCommand { ref name } if name == "blank"),
            "got {err:?}"
        );
        assert!(out.is_empty());
    }

    #[test]
    fn unknown_extension_program_reports_spawn_failure() {
        let (_t, store) = fixture();
        let exts = vec![ext("ghost", 0, "git-lfs-test-no-such-program-8f3a1c")];
        let mut out = Vec::new();
        let err = clean(&store, &mut &b"hello"[..], &mut out, "x", &exts).unwrap_err();
        assert!(
            matches!(err, CleanError::ExtensionSpawnFailed { .. }),
            "got {err:?}"
        );
        assert!(out.is_empty());
    }
}