1use std::{
2 env, fs,
3 path::{Path, PathBuf},
4};
5
6use anyhow::Context;
7
8use crate::SHELL_NAME;
9
10pub fn get_pwd() -> anyhow::Result<PathBuf> {
11 std::env::current_dir().context("could not retrieve current working directory")
12}
13
14pub fn get_prompt() -> String {
15 let curr_dir = env::current_dir()
16 .ok()
17 .and_then(|path| path.file_name().map(|s| s.to_string_lossy().into_owned()))
18 .unwrap_or_else(|| "".to_string());
19
20 let git_branch = current_git_branch().unwrap_or_default();
21 let branch_part = if git_branch.is_empty() {
22 String::new()
23 } else {
24 format!(" ({git_branch})")
25 };
26
27 format!(
28 "\x1b[32m{}\x1b[0m [{}{}] \x1b[1m>\x1b[0m ",
29 SHELL_NAME, curr_dir, branch_part
30 )
31}
32
33fn current_git_branch() -> Option<String> {
34 let cwd = env::current_dir().ok()?;
35 let git_dir = find_git_dir(&cwd)?;
36 let head_path = git_dir.join("HEAD");
37 let head_contents = fs::read_to_string(head_path).ok()?;
38
39 if let Some(rest) = head_contents.trim().strip_prefix("ref:") {
40 let ref_path = rest.trim();
41 return Path::new(ref_path)
42 .file_name()
43 .and_then(|s| Some(s.to_string_lossy().into_owned()));
44 }
45
46 let sha = head_contents.trim();
48 if sha.is_empty() {
49 None
50 } else {
51 Some(sha.chars().take(7).collect())
52 }
53}
54
/// Locates the git directory for `start` by checking it and each of its ancestors.
///
/// For the nearest ancestor that has a `.git` entry:
/// - a `.git` directory is returned as-is;
/// - a `.git` file (used by submodules and linked worktrees) is read, and its
///   `gitdir: <path>` pointer is resolved relative to the directory containing
///   the file. Absolute pointers are used unchanged.
///
/// Returns `None` when no ancestor has a `.git` entry. It also returns `None`
/// when the nearest `.git` file cannot be read or parsed, or does not point at
/// an existing directory.
fn find_git_dir(start: &Path) -> Option<PathBuf> {
    for dir in start.ancestors() {
        let candidate = dir.join(".git");
        if candidate.is_dir() {
            return Some(candidate);
        }
        if candidate.is_file() {
            // A `.git` file marks the repository root even if it is malformed.
            // Walking further up would report an unrelated parent repository.
            return resolve_gitdir_file(dir, &candidate);
        }
    }
    None
}

/// Resolves the `gitdir: <path>` pointer stored in the `.git` file `git_file` located in `dir`.
fn resolve_gitdir_file(dir: &Path, git_file: &Path) -> Option<PathBuf> {
    let contents = fs::read_to_string(git_file).ok()?;
    let target = contents.lines().next()?.strip_prefix("gitdir:")?.trim();
    if target.is_empty() {
        return None;
    }
    // `join` keeps absolute targets intact and anchors relative ones at `dir`.
    let resolved = dir.join(target);
    if resolved.is_dir() {
        Some(resolved)
    } else {
        None
    }
}
68
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that read or change the process-wide working directory.
    ///
    /// `cargo test` runs tests on parallel threads and `env::set_current_dir`
    /// affects all of them, so without this lock one test can observe another's cwd.
    static CWD_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `CWD_LOCK`, ignoring poisoning from an earlier panicking test.
    /// The cwd itself is restored by `CwdGuard`, so nothing is left inconsistent.
    fn lock_cwd() -> MutexGuard<'static, ()> {
        CWD_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Switches the working directory and restores the previous one on drop,
    /// including when the test panics.
    struct CwdGuard {
        previous: PathBuf,
    }

    impl CwdGuard {
        fn change_to(dir: &Path) -> Self {
            let previous = env::current_dir().unwrap();
            env::set_current_dir(dir).unwrap();
            CwdGuard { previous }
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            let _ = env::set_current_dir(&self.previous);
        }
    }

    #[test]
    fn prompt_includes_current_directory_name() {
        let _lock = lock_cwd();
        let cwd = env::current_dir().unwrap();
        let name = cwd
            .file_name()
            .map(|s| s.to_string_lossy().into_owned())
            .unwrap_or_default();

        let prompt = get_prompt();
        assert!(prompt.contains(&name), "prompt {prompt:?} should contain {name:?}");
    }

    #[test]
    fn prompt_includes_git_branch_when_in_repo() {
        let _lock = lock_cwd();
        let tmp = env::temp_dir().join(format!("ase_git_test_{}", std::process::id()));
        let git_dir = tmp.join(".git");
        fs::create_dir_all(&git_dir).unwrap();
        fs::write(git_dir.join("HEAD"), "ref: refs/heads/feature/test\n").unwrap();

        let prompt = {
            let _cwd = CwdGuard::change_to(&tmp);
            get_prompt()
        };

        // Removing `tmp` also removes `.git` and `HEAD` inside it.
        fs::remove_dir_all(&tmp).ok();

        assert!(prompt.contains("(test)"), "prompt {prompt:?} should show the branch");
    }
}