// ase_shell/utils.rs

use std::{
  env, fs,
  path::{Path, PathBuf},
};

use anyhow::Context;

use crate::SHELL_NAME;

10pub fn get_pwd() -> anyhow::Result<PathBuf> {
11  std::env::current_dir().context("could not retrieve current working directory")
12}
13
14pub fn get_prompt() -> String {
15  let curr_dir = env::current_dir()
16    .ok()
17    .and_then(|path| path.file_name().map(|s| s.to_string_lossy().into_owned()))
18    .unwrap_or_else(|| "".to_string());
19
20  let git_branch = current_git_branch().unwrap_or_default();
21  let branch_part = if git_branch.is_empty() {
22    String::new()
23  } else {
24    format!(" ({git_branch})")
25  };
26
27  format!(
28    "\x1b[32m{}\x1b[0m [{}{}] \x1b[1m>\x1b[0m ",
29    SHELL_NAME, curr_dir, branch_part
30  )
31}
32
33fn current_git_branch() -> Option<String> {
34  let cwd = env::current_dir().ok()?;
35  let git_dir = find_git_dir(&cwd)?;
36  let head_path = git_dir.join("HEAD");
37  let head_contents = fs::read_to_string(head_path).ok()?;
38
39  if let Some(rest) = head_contents.trim().strip_prefix("ref:") {
40    let ref_path = rest.trim();
41    return Path::new(ref_path)
42      .file_name()
43      .and_then(|s| Some(s.to_string_lossy().into_owned()));
44  }
45
46  // Detached HEAD: show short SHA
47  let sha = head_contents.trim();
48  if sha.is_empty() {
49    None
50  } else {
51    Some(sha.chars().take(7).collect())
52  }
53}
54
/// Finds the git directory for `start` by checking `start` and then each of
/// its ancestors, nearest first.
///
/// At each level, a `.git` entry is considered:
/// - A directory is returned as-is.
/// - A regular file is treated as a gitfile (the form used by submodules and
///   linked worktrees). Its `gitdir: <path>` line names the real git
///   directory. A relative target is resolved against the directory holding
///   the `.git` file. The target is returned only if it is an existing
///   directory.
///
/// Levels without a usable `.git` entry are skipped. Returns `None` when no
/// level has one.
fn find_git_dir(start: &Path) -> Option<PathBuf> {
  start.ancestors().find_map(|dir| {
    let candidate = dir.join(".git");
    if candidate.is_dir() {
      Some(candidate)
    } else if candidate.is_file() {
      resolve_gitfile(&candidate, dir)
    } else {
      None
    }
  })
}

/// Reads a `.git` gitfile and returns the directory named by its `gitdir:`
/// line, resolved against `base` and only if that directory exists.
fn resolve_gitfile(gitfile: &Path, base: &Path) -> Option<PathBuf> {
  let contents = fs::read_to_string(gitfile).ok()?;
  let target = contents
    .lines()
    .find_map(|line| line.strip_prefix("gitdir:"))?
    .trim();
  if target.is_empty() {
    return None;
  }
  // `Path::join` replaces `base` entirely when `target` is absolute.
  Some(base.join(target)).filter(|path| path.is_dir())
}

#[cfg(test)]
mod tests {
  use super::*;
  use std::sync::{Mutex, MutexGuard};

  /// Serialises tests that read or change the working directory.
  ///
  /// The working directory is process-wide, and the test harness may run tests
  /// concurrently. Without this lock, one test's `set_current_dir` can leak
  /// into another test's `current_dir` / `get_prompt` calls.
  static CWD_LOCK: Mutex<()> = Mutex::new(());

  /// Acquires `CWD_LOCK`, tolerating poisoning from a panicked test (the
  /// guarded value is `()`, so there is no state to be left inconsistent).
  fn lock_cwd() -> MutexGuard<'static, ()> {
    CWD_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
  }

  /// Restores the saved working directory when dropped, even on panic.
  struct RestoreCwd(PathBuf);

  impl Drop for RestoreCwd {
    fn drop(&mut self) {
      let _ = env::set_current_dir(&self.0);
    }
  }

  #[test]
  fn prompt_includes_current_directory_name() {
    let _lock = lock_cwd();
    let cwd = env::current_dir().unwrap();
    let name = cwd
      .file_name()
      .map(|s| s.to_string_lossy().into_owned())
      .unwrap_or_default();

    let prompt = get_prompt();
    assert!(prompt.contains(&name));
  }

  #[test]
  fn prompt_includes_git_branch_when_in_repo() {
    let _lock = lock_cwd();
    let tmp = env::temp_dir().join(format!("ase_git_test_{}", std::process::id()));
    let git_dir = tmp.join(".git");
    fs::create_dir_all(&git_dir).unwrap();
    fs::write(git_dir.join("HEAD"), "ref: refs/heads/feature/test\n").unwrap();

    let prompt = {
      // The guard puts the original cwd back at the end of this scope,
      // including when `get_prompt` panics.
      let _restore = RestoreCwd(env::current_dir().unwrap());
      env::set_current_dir(&tmp).unwrap();
      get_prompt()
    };

    fs::remove_dir_all(&tmp).ok();

    assert!(prompt.contains("(test)"));
  }
}
103    fs::remove_dir_all(&tmp).ok();
104
105    assert!(prompt.contains("(test)"));
106  }
107}