1use std::{
4 fs::File,
5 io::Write,
6 path::{Path, PathBuf},
7};
8
9use git2::{IndexAddOption, Repository, RepositoryOpenFlags};
10use scopetime::scope_time;
11
12use super::{repository::repo, CommitId, RepoPath, ShowUntrackedFilesConfig};
13use crate::{
14 error::{Error, Result},
15 sync::config::untracked_files_config_repo,
16};
17
/// Snapshot of the repository's current `HEAD`: the reference it resolves to
/// and the commit that reference points at (built by `get_head_tuple`).
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct Head {
    /// Full name of the reference `HEAD` resolves to, as returned by
    /// `get_head_refname` (typically `refs/heads/<branch>`; presumably just
    /// `HEAD` when detached — TODO confirm).
    pub name: String,
    /// Commit id `HEAD` points at, as returned by `get_head_repo`.
    pub id: CommitId,
}
26
27pub fn repo_open_error(repo_path: &RepoPath) -> Option<String> {
29 Repository::open_ext(
30 repo_path.gitpath(),
31 RepositoryOpenFlags::empty(),
32 Vec::<&Path>::new(),
33 )
34 .map_or_else(|e| Some(e.to_string()), |_| None)
35}
36
37pub(crate) fn work_dir(repo: &Repository) -> Result<&Path> {
39 repo.workdir().ok_or(Error::NoWorkDir)
40}
41
42pub fn repo_dir(repo_path: &RepoPath) -> Result<PathBuf> {
44 let repo = repo(repo_path)?;
45 Ok(repo.path().to_owned())
46}
47
48pub fn repo_work_dir(repo_path: &RepoPath) -> Result<String> {
50 let repo = repo(repo_path)?;
51 work_dir(&repo)?.to_str().map_or_else(
52 || Err(Error::Generic("invalid workdir".to_string())),
53 |workdir| Ok(workdir.to_string()),
54 )
55}
56
57pub fn get_head(repo_path: &RepoPath) -> Result<CommitId> {
59 let repo = repo(repo_path)?;
60 get_head_repo(&repo)
61}
62
63pub fn get_head_tuple(repo_path: &RepoPath) -> Result<Head> {
65 let repo = repo(repo_path)?;
66 let id = get_head_repo(&repo)?;
67 let name = get_head_refname(&repo)?;
68
69 Ok(Head { name, id })
70}
71
72pub fn get_head_refname(repo: &Repository) -> Result<String> {
74 let head = repo.head()?;
75 let ref_name = bytes2string(head.name_bytes())?;
76
77 Ok(ref_name)
78}
79
80pub fn get_head_repo(repo: &Repository) -> Result<CommitId> {
82 scope_time!("get_head_repo");
83
84 let head = repo.head()?.target();
85
86 head.map_or(Err(Error::NoHead), |head_id| Ok(head_id.into()))
87}
88
89pub fn stage_add_file(repo_path: &RepoPath, path: &Path) -> Result<()> {
92 scope_time!("stage_add_file");
93
94 let repo = repo(repo_path)?;
95
96 let mut index = repo.index()?;
97
98 index.add_path(path)?;
99 index.write()?;
100
101 Ok(())
102}
103
104pub fn stage_add_all(
107 repo_path: &RepoPath,
108 pattern: &str,
109 stage_untracked: Option<ShowUntrackedFilesConfig>,
110) -> Result<()> {
111 scope_time!("stage_add_all");
112
113 let repo = repo(repo_path)?;
114
115 let mut index = repo.index()?;
116
117 let stage_untracked = if let Some(config) = stage_untracked {
118 config
119 } else {
120 untracked_files_config_repo(&repo)?
121 };
122
123 if stage_untracked.include_untracked() {
124 index.add_all(vec![pattern], IndexAddOption::DEFAULT, None)?;
125 } else {
126 index.update_all(vec![pattern], None)?;
127 }
128
129 index.write()?;
130
131 Ok(())
132}
133
134pub fn undo_last_commit(repo_path: &RepoPath) -> Result<()> {
136 let repo = repo(repo_path)?;
137 let previous_commit = repo.revparse_single("HEAD~")?;
138
139 Repository::reset(&repo, &previous_commit, git2::ResetType::Soft, None)?;
140
141 Ok(())
142}
143
144pub fn stage_addremoved(repo_path: &RepoPath, path: &Path) -> Result<()> {
146 scope_time!("stage_addremoved");
147
148 let repo = repo(repo_path)?;
149
150 let mut index = repo.index()?;
151
152 index.remove_path(path)?;
153 index.write()?;
154
155 Ok(())
156}
157
158pub(crate) fn bytes2string(bytes: &[u8]) -> Result<String> {
159 Ok(String::from_utf8(bytes.to_vec())?)
160}
161
162pub(crate) fn repo_write_file(repo: &Repository, file: &str, content: &str) -> Result<()> {
164 let dir = work_dir(repo)?.join(file);
165 let file_path = dir
166 .to_str()
167 .ok_or_else(|| Error::Generic(String::from("invalid file path")))?;
168 let mut file = File::create(file_path)?;
169 file.write_all(content.as_bytes())?;
170 Ok(())
171}
172
173pub fn read_file(path: &Path) -> Result<String> {
175 use std::io::Read;
176
177 let mut file = File::open(path)?;
178 let mut buffer = Vec::new();
179 file.read_to_end(&mut buffer)?;
180
181 Ok(String::from_utf8(buffer)?)
182}
183
/// Test helper: reads `file`, a path relative to the repository's working
/// directory, and decodes it as UTF-8.
///
/// # Errors
/// Returns [`Error::NoWorkDir`] when the repository has no working directory,
/// an I/O error if the file cannot be read, or a UTF-8 conversion error for
/// non-UTF-8 content.
#[cfg(test)]
pub(crate) fn repo_read_file(repo: &Repository, file: &str) -> Result<String> {
    // Use the joined `Path` directly. The former `&str` round-trip rejected
    // valid non-UTF-8 working-directory paths with a spurious error.
    let file_path = work_dir(repo)?.join(file);
    let bytes = std::fs::read(file_path)?;
    Ok(String::from_utf8(bytes)?)
}
199
#[cfg(test)]
mod tests {
    use std::{
        fs::{self, remove_file, File},
        io::Write,
        path::Path,
    };

    use super::*;
    use crate::sync::{
        commit,
        diff::get_diff,
        status::{get_status, StatusType},
        tests::{debug_cmd_print, get_statuses, repo_init, repo_init_empty, write_commit_file},
    };

    // Throughout: `repo.path()` is the `.git` directory, so its parent is the
    // working-tree root. `get_statuses` appears to return
    // `(working dir count, stage count)`; this is inferred from the
    // assertions below.

    /// Staging a file that does not exist in an empty repository must fail.
    #[test]
    fn test_stage_add_smoke() {
        let file_path = Path::new("foo");
        let (_td, repo) = repo_init_empty().unwrap();
        let root = repo.path().parent().unwrap();
        let repo_path = root.as_os_str().to_str().unwrap();

        assert!(stage_add_file(&repo_path.into(), file_path).is_err());
    }

    /// Staging one of two new files moves exactly that one to the stage.
    #[test]
    fn test_staging_one_file() {
        let file_path = Path::new("file1.txt");
        let (_td, repo) = repo_init().unwrap();
        let root = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &root.as_os_str().to_str().unwrap().into();

        File::create(root.join(file_path))
            .unwrap()
            .write_all(b"test file1 content")
            .unwrap();

        File::create(root.join("file2.txt"))
            .unwrap()
            .write_all(b"test file2 content")
            .unwrap();

        assert_eq!(get_statuses(repo_path), (2, 0));

        stage_add_file(repo_path, file_path).unwrap();

        assert_eq!(get_statuses(repo_path), (1, 1));
    }

    /// A folder pathspec stages only the files beneath that folder.
    #[test]
    fn test_staging_folder() -> Result<()> {
        let (_td, repo) = repo_init().unwrap();
        let root = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &root.as_os_str().to_str().unwrap().into();

        let status_count =
            |s: StatusType| -> usize { get_status(repo_path, s, None).unwrap().len() };

        fs::create_dir_all(root.join("a/d"))?;
        File::create(root.join("a/d/f1.txt"))?.write_all(b"foo")?;
        File::create(root.join("a/d/f2.txt"))?.write_all(b"foo")?;
        File::create(root.join("a/f3.txt"))?.write_all(b"foo")?;

        assert_eq!(status_count(StatusType::WorkingDir), 3);

        stage_add_all(repo_path, "a/d", None).unwrap();

        assert_eq!(status_count(StatusType::WorkingDir), 1);
        assert_eq!(status_count(StatusType::Stage), 2);

        Ok(())
    }

    /// Undo fails when `HEAD` has no parent commit.
    #[test]
    fn test_undo_commit_empty_repo() {
        let (_td, repo) = repo_init().unwrap();
        let root = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &root.as_os_str().to_str().unwrap().into();

        assert!(undo_last_commit(repo_path).is_err());
    }

    /// Undo moves `HEAD` back one commit and leaves that commit's change staged.
    #[test]
    fn test_undo_commit() {
        let (_td, repo) = repo_init().unwrap();
        let root = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &root.as_os_str().to_str().unwrap().into();

        let c1 = write_commit_file(&repo, "test.txt", "content1", "c1");
        let _c2 = write_commit_file(&repo, "test.txt", "content2", "c2");
        assert!(undo_last_commit(repo_path).is_ok());

        assert_eq!(c1, get_head_repo(&repo).unwrap());

        assert_eq!(get_statuses(repo_path), (0, 1));

        let diff = get_diff(repo_path, "test.txt", true, None).unwrap();
        assert_eq!(&*diff.hunks[0].lines[0].content, "@@ -1 +1 @@");
    }

    /// With `status.showUntrackedFiles=no`, `stage_add_all` must not stage
    /// untracked files.
    #[test]
    fn test_not_staging_untracked_folder() -> Result<()> {
        let (_td, repo) = repo_init().unwrap();
        let root = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &root.as_os_str().to_str().unwrap().into();

        fs::create_dir_all(root.join("a/d"))?;
        File::create(root.join("a/d/f1.txt"))?.write_all(b"foo")?;
        File::create(root.join("a/d/f2.txt"))?.write_all(b"foo")?;
        File::create(root.join("f3.txt"))?.write_all(b"foo")?;

        assert_eq!(get_statuses(repo_path), (3, 0));

        repo.config()?.set_str("status.showUntrackedFiles", "no")?;

        assert_eq!(get_statuses(repo_path), (0, 0));

        stage_add_all(repo_path, "*", None).unwrap();

        assert_eq!(get_statuses(repo_path), (0, 0));

        Ok(())
    }

    /// A committed file deleted from disk can have its removal staged.
    #[test]
    fn test_staging_deleted_file() {
        let file_path = Path::new("file1.txt");
        let (_td, repo) = repo_init().unwrap();
        let root = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &root.as_os_str().to_str().unwrap().into();

        let status_count =
            |s: StatusType| -> usize { get_status(repo_path, s, None).unwrap().len() };

        let full_path = &root.join(file_path);

        File::create(full_path)
            .unwrap()
            .write_all(b"test file1 content")
            .unwrap();

        stage_add_file(repo_path, file_path).unwrap();

        commit(repo_path, "commit msg").unwrap();

        assert!(remove_file(full_path).is_ok());

        assert_eq!(status_count(StatusType::WorkingDir), 1);

        stage_addremoved(repo_path, file_path).unwrap();

        assert_eq!(status_count(StatusType::WorkingDir), 0);
        assert_eq!(status_count(StatusType::Stage), 1);
    }

    /// Staging a folder that contains a nested git repository is expected to fail.
    #[test]
    fn test_staging_sub_git_folder() -> Result<()> {
        let (_td, repo) = repo_init().unwrap();
        let root = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &root.as_os_str().to_str().unwrap().into();

        let status_count =
            |s: StatusType| -> usize { get_status(repo_path, s, None).unwrap().len() };

        let sub = &root.join("sub");

        fs::create_dir_all(sub)?;

        debug_cmd_print(&sub.to_str().unwrap().into(), "git init subgit");

        File::create(sub.join("subgit/foo.txt"))
            .unwrap()
            .write_all(b"content")
            .unwrap();

        assert_eq!(status_count(StatusType::WorkingDir), 1);

        assert!(stage_add_all(repo_path, "sub", None).is_err());

        Ok(())
    }

    /// An empty repository has no resolvable `HEAD`.
    #[test]
    fn test_head_empty() -> Result<()> {
        let (_td, repo) = repo_init_empty()?;
        let root = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &root.as_os_str().to_str().unwrap().into();

        assert!(get_head(repo_path).is_err());

        Ok(())
    }

    /// An initialised repository (with its initial commit) has a `HEAD`.
    #[test]
    fn test_head() -> Result<()> {
        let (_td, repo) = repo_init()?;
        let root = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &root.as_os_str().to_str().unwrap().into();

        assert!(get_head(repo_path).is_ok());

        Ok(())
    }
}