1use std::ffi::OsString;
29use std::io::Write;
30use std::path::{Component, Path, PathBuf};
31
32use clap::Parser;
33use mkit_core::hash::Hash;
34use mkit_core::index::{self, EntryStatus, Index, IndexEntry};
35use mkit_core::object::Object;
36use mkit_core::ops::restore::{RestoreOptions, SparsePattern, restore_tree_to_worktree};
37use mkit_core::store::ObjectStore;
38use mkit_core::worktree;
39
40use crate::clap_shim;
41use crate::exit;
42
// Command-line options for `mkit restore`, parsed by clap's derive API.
//
// NOTE: plain `//` comments are used on purpose. clap's derive turns `///`
// doc comments into `--help` text, so adding them here would change the
// program's output.
#[derive(Debug, Parser)]
#[command(
    name = "mkit restore",
    about = "Restore worktree files (discard local changes) or unstage them."
)]
struct RestoreOpts {
    // `-S` / `--staged`: restore the index entries for the given paths.
    #[arg(short = 'S', long)]
    staged: bool,

    // `-W` / `--worktree`: restore the files on disk. `run` also treats this
    // as enabled whenever `--staged` is absent, so it is the default mode.
    #[arg(short = 'W', long)]
    worktree: bool,

    // `--source REV`: revision to restore from. `run` resolves it to a tree;
    // when omitted, the HEAD tree (index side) or the index (worktree side)
    // is used instead.
    #[arg(long, value_name = "REV")]
    source: Option<String>,

    // `-f` / `--force`: skip the unstaged-changes check before overwriting
    // worktree files.
    #[arg(short = 'f', long)]
    force: bool,

    // One or more paths to restore; at least one is required.
    #[arg(required = true)]
    paths: Vec<String>,
}
78
79#[must_use]
80pub fn run(args: &[String]) -> u8 {
81 let opts = match clap_shim::parse::<RestoreOpts>("mkit restore", args) {
82 Ok(o) => o,
83 Err(code) => return code,
84 };
85 let cwd = match std::env::current_dir() {
86 Ok(p) => p,
87 Err(e) => return emit_err(&format!("cwd: {e}"), exit::NOINPUT),
88 };
89 let mkit_dir = cwd.join(mkit_core::MKIT_DIR);
90 let store = match ObjectStore::open(&cwd) {
91 Ok(s) => s,
92 Err(e) => return emit_err(&format!("not a mkit repo: {e}"), exit::GENERAL_ERROR),
93 };
94 let _lock = match super::acquire_worktree_lock(&cwd) {
95 Ok(l) => l,
96 Err(code) => return code,
97 };
98
99 let do_staged = opts.staged;
103 let do_worktree = opts.worktree || !opts.staged;
104
105 let mut idx = match super::read_or_seed_index_from_head(&cwd, &store) {
106 Ok(i) => i,
107 Err(e) => return emit_err(&e, exit::GENERAL_ERROR),
108 };
109
110 let head_tree = match super::current_head_tree(&cwd, &store) {
113 Ok(t) => t,
114 Err(e) => return emit_err(&e, exit::GENERAL_ERROR),
115 };
116
117 let source_tree: Option<Hash> = match &opts.source {
120 Some(spec) => match resolve_source_tree(&store, &mkit_dir, spec) {
121 Ok(t) => Some(t),
122 Err((msg, code)) => return emit_err(&msg, code),
123 },
124 None => None,
125 };
126
127 let restore_index: Option<Index> = match resolve_restore_index(&store, source_tree, head_tree) {
132 Ok(i) => i,
133 Err(e) => return emit_err(&e, exit::GENERAL_ERROR),
134 };
135
136 let mut rels: Vec<String> = Vec::with_capacity(opts.paths.len());
138 for raw in &opts.paths {
139 match index_path_for_arg(&cwd, Path::new(raw)) {
140 Ok(p) => rels.push(p),
141 Err(e) => return emit_err(&e, exit::DATAERR),
142 }
143 }
144
145 if do_staged && let Err(code) = restore_staged(&cwd, &mut idx, restore_index.as_ref(), &rels) {
146 return code;
147 }
148
149 if do_worktree
150 && let Err(code) = restore_worktree(
151 &cwd,
152 &store,
153 &idx,
154 restore_index.as_ref(),
155 &rels,
156 source_tree.is_some(),
157 opts.force,
158 )
159 {
160 return code;
161 }
162
163 exit::OK
164}
165
166fn resolve_source_tree(
168 store: &ObjectStore,
169 mkit_dir: &Path,
170 spec: &str,
171) -> Result<Hash, (String, u8)> {
172 let commit = super::revspec::resolve_revision(store, mkit_dir, spec)
173 .map_err(|e| (format!("bad --source '{spec}': {e}"), exit::GENERAL_ERROR))?;
174 match store.read_object(&commit) {
175 Ok(Object::Commit(c)) => Ok(c.tree_hash),
176 Ok(Object::Remix(r)) => Ok(r.tree_hash),
177 Ok(Object::Tree(_)) => Ok(commit),
178 Ok(_) => Err((
179 format!("--source '{spec}' does not resolve to a commit or tree"),
180 exit::GENERAL_ERROR,
181 )),
182 Err(e) => Err((format!("read --source object: {e}"), exit::GENERAL_ERROR)),
183 }
184}
185
186fn resolve_restore_index(
191 store: &ObjectStore,
192 source_tree: Option<Hash>,
193 head_tree: Option<Hash>,
194) -> Result<Option<Index>, String> {
195 let tree = source_tree.or(head_tree);
196 match tree {
197 Some(t) => index::from_tree(store, t)
198 .map(Some)
199 .map_err(|e| format!("read source tree: {e}")),
200 None => Ok(None),
201 }
202}
203
204fn restore_staged(
209 cwd: &Path,
210 idx: &mut Index,
211 restore_index: Option<&Index>,
212 rels: &[String],
213) -> Result<(), u8> {
214 let mut matched_any = false;
215 for rel in rels {
216 let in_index = entry_matches(idx, rel);
217 let in_source = restore_index
218 .map(|src| entry_matches(src, rel))
219 .unwrap_or_default();
220 if in_index.is_empty() && in_source.is_empty() {
221 return Err(emit_err(
222 &format!("pathspec '{rel}' did not match any tracked or staged files"),
223 exit::GENERAL_ERROR,
224 ));
225 }
226 matched_any = true;
227
228 let mut affected: Vec<String> = in_index
230 .iter()
231 .chain(in_source.iter())
232 .map(|e| e.path.clone())
233 .collect();
234 affected.sort_unstable();
235 affected.dedup();
236
237 for path in affected {
238 let source_entry =
239 restore_index.and_then(|src| src.entries.iter().find(|e| e.path == path).cloned());
240 apply_index_restore(idx, &path, source_entry);
241 }
242 }
243
244 if !matched_any {
245 return Ok(());
246 }
247 index::write_index(cwd, idx)
248 .map_err(|e| emit_err(&format!("write index: {e}"), exit::CANTCREAT))
249}
250
251fn apply_index_restore(idx: &mut Index, path: &str, source: Option<IndexEntry>) {
253 match source {
254 Some(src) => {
255 if let Some(pos) = idx.entries.iter().position(|e| e.path == path) {
256 idx.entries[pos] = src;
257 } else {
258 idx.entries.push(src);
259 }
260 }
261 None => {
262 idx.entries.retain(|e| e.path != path);
265 }
266 }
267}
268
269fn restore_worktree(
274 cwd: &Path,
275 store: &ObjectStore,
276 idx: &Index,
277 restore_index: Option<&Index>,
278 rels: &[String],
279 explicit_source: bool,
280 force: bool,
281) -> Result<(), u8> {
282 let source = if explicit_source {
285 restore_index.unwrap_or(idx)
286 } else {
287 idx
288 };
289
290 let mut to_write: Vec<IndexEntry> = Vec::new();
292 for rel in rels {
293 let matches = entry_matches(source, rel);
294 if matches.is_empty() {
295 return Err(emit_err(
296 &format!("pathspec '{rel}' did not match any tracked files"),
297 exit::GENERAL_ERROR,
298 ));
299 }
300 to_write.extend(matches);
301 }
302 to_write.sort_by(|a, b| a.path.cmp(&b.path));
303 to_write.dedup_by(|a, b| a.path == b.path);
304
305 if !force {
310 for entry in &to_write {
311 if let Some(reason) = dirty_reason(cwd, store, idx, &entry.path) {
312 return Err(emit_err(&reason, exit::GENERAL_ERROR));
313 }
314 }
315 }
316
317 let source_tree = match worktree::build_tree_from_index(store, source) {
323 Ok(t) => t,
324 Err(e) => {
325 return Err(emit_err(
326 &format!("build source tree: {e}"),
327 exit::GENERAL_ERROR,
328 ));
329 }
330 };
331 let patterns: Vec<SparsePattern> = to_write
332 .iter()
333 .map(|e| SparsePattern {
334 pattern: e.path.clone(),
335 negated: false,
336 dir_only: false,
337 })
338 .collect();
339 let restore_opts = RestoreOptions {
340 clean: false,
341 sparse_patterns: Some(patterns),
342 };
343 if let Err(e) = restore_tree_to_worktree(store, &source_tree, cwd, &restore_opts) {
344 return Err(emit_err(&format!("restore worktree: {e}"), exit::CANTCREAT));
345 }
346 Ok(())
347}
348
349fn entry_matches(idx: &Index, rel: &str) -> Vec<IndexEntry> {
351 idx.entries
352 .iter()
353 .filter(|e| {
354 e.status != EntryStatus::Removed && super::index_path_matches_or_descends(&e.path, rel)
355 })
356 .cloned()
357 .collect()
358}
359
360fn dirty_reason(root: &Path, _store: &ObjectStore, idx: &Index, path: &str) -> Option<String> {
364 let staged = idx
365 .entries
366 .iter()
367 .find(|e| e.path == path && e.status != EntryStatus::Removed)?;
368 let abs = root.join(path);
369 let meta = abs.symlink_metadata().ok()?;
370 let work_hash = if meta.file_type().is_symlink() {
371 let target = std::fs::read_link(&abs).ok()?;
372 let target_str = target.to_str()?;
373 symlink_blob_hash(target_str)?
374 } else if meta.file_type().is_file() {
375 worktree::read_regular_file_bounded(&abs)
376 .ok()
377 .and_then(|(_, data)| worktree::hash_file_object(&data).ok())?
378 } else {
379 return None;
381 };
382 if work_hash == staged.object_hash {
383 None
384 } else {
385 Some(format!(
386 "'{path}' has unstaged changes; use --force to discard them"
387 ))
388 }
389}
390
391fn symlink_blob_hash(target: &str) -> Option<Hash> {
393 let prologue = mkit_core::serialize::blob_prologue(target.len()).ok()?;
396 let mut hasher = mkit_core::hash::Hasher::new();
397 hasher.update(&prologue).update(target.as_bytes());
398 Some(hasher.finalize())
399}
400
401fn index_path_for_arg(root: &Path, arg: &Path) -> Result<String, String> {
405 let rel = if arg.is_absolute() {
406 absolute_arg_to_repo_relative(root, arg)?
407 } else {
408 arg.to_path_buf()
409 };
410
411 let mut parts: Vec<String> = Vec::new();
412 for component in rel.as_path().components() {
413 match component {
414 Component::Normal(part) => {
415 let part = part
416 .to_str()
417 .ok_or_else(|| "path is not valid UTF-8".to_string())?;
418 parts.push(part.to_string());
419 }
420 Component::CurDir => {}
421 Component::ParentDir => {
422 if parts.pop().is_none() {
423 return Err(format!("invalid path: {}", arg.display()));
424 }
425 }
426 Component::Prefix(_) | Component::RootDir => {
427 return Err(format!("invalid path: {}", arg.display()));
428 }
429 }
430 }
431
432 let path = parts.join("/");
433 if !index::validate_index_path(&path) {
434 return Err(format!("invalid path: {path}"));
435 }
436 Ok(path)
437}
438
/// Rewrites an absolute path argument as a path relative to the repository
/// root.
///
/// The root is canonicalised first. If `arg` already sits lexically under
/// it, the remainder is returned directly. Otherwise the nearest ancestor
/// of `arg` that exists on disk is canonicalised, the trailing components
/// are re-attached unchanged, and the result is checked against the root
/// again. `arg`'s final component is never resolved, so a symlink named by
/// `arg` keeps its own name.
fn absolute_arg_to_repo_relative(root: &Path, arg: &Path) -> Result<PathBuf, String> {
    let invalid = || format!("invalid path: {}", arg.display());
    let outside = || format!("path is outside repository: {}", arg.display());

    let canonical_root = root.canonicalize().map_err(|e| format!("repo root: {e}"))?;

    // Fast path: purely lexical containment.
    if let Ok(rel) = arg.strip_prefix(&canonical_root) {
        return Ok(rel.to_path_buf());
    }

    // Components not yet canonicalised, stored leaf-first.
    let leaf = arg.file_name().ok_or_else(&invalid)?;
    let mut pending: Vec<OsString> = vec![leaf.to_os_string()];
    let mut existing = arg.parent().ok_or_else(&invalid)?;

    // Walk upwards until something exists that `canonicalize` can resolve.
    loop {
        if existing.symlink_metadata().is_ok() {
            break;
        }
        let name = existing.file_name().ok_or_else(&outside)?;
        pending.push(name.to_os_string());
        existing = existing.parent().ok_or_else(&outside)?;
    }

    let mut resolved = existing
        .canonicalize()
        .map_err(|e| format!("path {}: {e}", existing.display()))?;
    resolved.extend(pending.iter().rev());

    match resolved.strip_prefix(&canonical_root) {
        Ok(rel) => Ok(rel.to_path_buf()),
        Err(_) => Err(outside()),
    }
}
476
/// Prints `error: {msg}` on stderr and hands `code` back unchanged, so a
/// caller can report and return in one expression.
///
/// A failure to write to stderr is ignored on purpose.
fn emit_err(msg: &str, code: u8) -> u8 {
    let _ = writeln!(std::io::stderr().lock(), "error: {msg}");
    code
}