use crate::color::{Styled, themes::DefaultTheme};
use crate::config::CompiledConfig;
use crate::{MAX_PROMPT_MS, PzshError, Result};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
/// One piece of a parsed prompt format string.
///
/// Segments are produced by `Prompt::parse_format` and rendered in order by
/// `Prompt::render`.
#[derive(Debug, Clone)]
pub enum PromptSegment {
    /// Text copied verbatim into the prompt.
    Literal(String),
    /// `{user}`: the user name captured when the `Prompt` was built.
    User,
    /// `{host}`: the host name captured when the `Prompt` was built.
    Host,
    /// `{cwd}`: the working directory, looked up on every render.
    Cwd,
    /// `{git}`: the cached git branch / dirty indicator.
    Git,
    /// `{char}`: `#` when the user name is `root`, `$` otherwise.
    Char,
    /// Any other `{name}` placeholder; rendered back as the literal `{name}`.
    Custom(String),
}
/// Cached git status shown by the `{git}` prompt segment.
///
/// `branch` and `dirty` are plain fields. `valid` is an atomic flag behind an
/// `Arc`, so it can be cleared through a shared reference. Because `Clone` is
/// derived, a clone copies `branch`/`dirty` but shares the *same* `valid`
/// flag with the original.
#[derive(Debug, Clone, Default)]
pub struct GitCache {
    /// Branch name to display; `None` renders as an empty string.
    pub branch: Option<String>,
    /// When set, rendering appends `*` and uses the "dirty" style.
    pub dirty: bool,
    /// Validity flag; starts `false` and is cleared by `invalidate`.
    valid: Arc<AtomicBool>,
}
impl GitCache {
    /// Returns an empty cache: no branch, not dirty, and not yet valid.
    #[must_use]
    pub fn new() -> Self {
        let valid = Arc::new(AtomicBool::new(false));
        Self {
            branch: None,
            dirty: false,
            valid,
        }
    }

    /// Returns `true` if the cache has been marked valid and not invalidated since.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        self.valid.load(Ordering::Relaxed)
    }

    /// Clears the validity flag.
    ///
    /// Only needs `&self` because the flag is atomic; every clone of this
    /// cache observes the change, since clones share the same `Arc`.
    pub fn invalidate(&self) {
        self.valid.store(false, Ordering::Relaxed);
    }

    /// Renders the git status without color codes, e.g. `(main)` or `(main*)`.
    #[must_use]
    pub fn render(&self) -> String {
        self.render_colored(false)
    }

    /// Renders `(branch)` / `(branch*)`, optionally wrapped in the theme's git style.
    ///
    /// Returns an empty string when no branch is cached.
    #[must_use]
    pub fn render_colored(&self, colors: bool) -> String {
        self.branch.as_deref().map_or_else(String::new, |name| {
            let marker = if self.dirty { "*" } else { "" };
            let label = format!("({name}{marker})");
            if !colors {
                return label;
            }
            let theme_style = if self.dirty {
                DefaultTheme::git_dirty()
            } else {
                DefaultTheme::git_clean()
            };
            Styled::new(label, theme_style).render()
        })
    }
}
/// A prompt renderer driven by a pre-parsed format string.
///
/// User and host names are captured once at construction; the working
/// directory is looked up on every render.
#[derive(Debug)]
pub struct Prompt {
    /// Segments parsed from the configured format string.
    segments: Vec<PromptSegment>,
    /// Git status rendered by the `{git}` segment.
    git_cache: GitCache,
    /// User name (from `$USER`, falling back to `"user"`).
    user: String,
    /// Host name (falling back to `"localhost"`).
    host: String,
    /// Whether segments are wrapped in theme styles.
    colors_enabled: bool,
}
impl Prompt {
    /// Builds a prompt from the compiled configuration.
    ///
    /// The format string is parsed once here, so rendering only walks the
    /// pre-parsed segments. The user name is read from `$USER` (falling back
    /// to `"user"`) and the host name from `hostname::get` (falling back to
    /// `"localhost"`). Colors are enabled only when the config requests them
    /// *and* `crate::color::supports_color()` returns `true`.
    #[must_use]
    pub fn new(config: &CompiledConfig) -> Self {
        let segments = Self::parse_format(&config.prompt_format);
        let user = std::env::var("USER").unwrap_or_else(|_| "user".to_string());
        let host = hostname::get()
            .ok()
            .and_then(|h| h.into_string().ok())
            .unwrap_or_else(|| "localhost".to_string());
        let colors_enabled = config.colors_enabled && crate::color::supports_color();
        Self {
            segments,
            git_cache: GitCache::new(),
            user,
            host,
            colors_enabled,
        }
    }

    /// Enables or disables colored output.
    ///
    /// Requesting colors only takes effect if `crate::color::supports_color()`
    /// also returns `true`.
    pub fn set_colors_enabled(&mut self, enabled: bool) {
        self.colors_enabled = enabled && crate::color::supports_color();
    }

    /// Returns whether rendered segments are wrapped in theme styles.
    #[must_use]
    pub const fn colors_enabled(&self) -> bool {
        self.colors_enabled
    }

    /// Parses a prompt format string into segments.
    ///
    /// `{user}`, `{host}`, `{cwd}`, `{git}` and `{char}` map to their segments.
    /// Any other `{name}` becomes [`PromptSegment::Custom`], and all remaining
    /// text becomes [`PromptSegment::Literal`]. A `{` that is never closed is
    /// kept verbatim as literal text rather than being silently dropped.
    fn parse_format(format: &str) -> Vec<PromptSegment> {
        let mut segments = Vec::new();
        let mut current_literal = String::new();
        let mut in_brace = false;
        let mut brace_content = String::new();
        for ch in format.chars() {
            match ch {
                '{' if !in_brace => {
                    if !current_literal.is_empty() {
                        segments.push(PromptSegment::Literal(std::mem::take(&mut current_literal)));
                    }
                    in_brace = true;
                }
                '}' if in_brace => {
                    let segment = match brace_content.as_str() {
                        "user" => PromptSegment::User,
                        "host" => PromptSegment::Host,
                        "cwd" => PromptSegment::Cwd,
                        "git" => PromptSegment::Git,
                        "char" => PromptSegment::Char,
                        other => PromptSegment::Custom(other.to_string()),
                    };
                    segments.push(segment);
                    brace_content.clear();
                    in_brace = false;
                }
                _ if in_brace => brace_content.push(ch),
                _ => current_literal.push(ch),
            }
        }
        if in_brace {
            // Unterminated `{...`: `current_literal` was flushed when the brace
            // opened, so the pending text lives only in `brace_content`. Keep it
            // visible (merged into the preceding literal, if any) instead of
            // losing it.
            let mut tail = String::with_capacity(brace_content.len() + 1);
            tail.push('{');
            tail.push_str(&brace_content);
            if let Some(PromptSegment::Literal(prev)) = segments.last_mut() {
                prev.push_str(&tail);
            } else {
                segments.push(PromptSegment::Literal(tail));
            }
        } else if !current_literal.is_empty() {
            segments.push(PromptSegment::Literal(current_literal));
        }
        segments
    }

    /// Renders the prompt string.
    ///
    /// # Errors
    ///
    /// Returns `PzshError::PromptBudgetExceeded(MAX_PROMPT_MS, elapsed_ms)` if
    /// rendering took longer than `MAX_PROMPT_MS` milliseconds; the rendered
    /// text is discarded in that case.
    pub fn render(&self) -> Result<String> {
        let start = Instant::now();
        let mut output = String::with_capacity(256);
        for segment in &self.segments {
            match segment {
                PromptSegment::Literal(s) => output.push_str(s),
                PromptSegment::User => {
                    if self.colors_enabled {
                        output.push_str(&Styled::new(&self.user, DefaultTheme::user()).render());
                    } else {
                        output.push_str(&self.user);
                    }
                }
                PromptSegment::Host => {
                    if self.colors_enabled {
                        output.push_str(&Styled::new(&self.host, DefaultTheme::host()).render());
                    } else {
                        output.push_str(&self.host);
                    }
                }
                PromptSegment::Cwd => {
                    // Prefer $PWD, then the process cwd, then "~" as a last resort.
                    let cwd = std::env::var("PWD")
                        .or_else(|_| std::env::current_dir().map(|p| p.display().to_string()))
                        .unwrap_or_else(|_| "~".to_string());
                    if self.colors_enabled {
                        output.push_str(&Styled::new(&cwd, DefaultTheme::cwd()).render());
                    } else {
                        output.push_str(&cwd);
                    }
                }
                PromptSegment::Git => {
                    output.push_str(&self.git_cache.render_colored(self.colors_enabled));
                }
                PromptSegment::Char => {
                    let is_root = self.user == "root";
                    let ch = if is_root { '#' } else { '$' };
                    if self.colors_enabled {
                        let style = if is_root {
                            DefaultTheme::prompt_root()
                        } else {
                            DefaultTheme::prompt_char()
                        };
                        output.push_str(&Styled::new(ch.to_string(), style).render());
                    } else {
                        output.push(ch);
                    }
                }
                PromptSegment::Custom(name) => {
                    // Unknown placeholders are echoed back verbatim as `{name}`.
                    output.push('{');
                    output.push_str(name);
                    output.push('}');
                }
            }
        }
        let elapsed = start.elapsed();
        if elapsed > Duration::from_millis(MAX_PROMPT_MS) {
            // `as_millis` is u128; saturate before narrowing so the cast can't wrap.
            let elapsed_ms = elapsed.as_millis().min(u128::from(u64::MAX)) as u64;
            return Err(PzshError::PromptBudgetExceeded(MAX_PROMPT_MS, elapsed_ms));
        }
        Ok(output)
    }

    /// Replaces the cached git status and marks the cache valid.
    pub fn update_git_cache(&mut self, branch: Option<String>, dirty: bool) {
        self.git_cache.branch = branch;
        self.git_cache.dirty = dirty;
        self.git_cache.valid.store(true, Ordering::Relaxed);
    }

    /// Marks the git cache invalid; the cached branch/dirty values are kept.
    pub fn invalidate_git_cache(&self) {
        self.git_cache.invalidate();
    }

    /// Returns the number of segments parsed from the format string.
    #[must_use]
    pub fn segment_count(&self) -> usize {
        self.segments.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Config whose format string exercises every built-in segment.
    fn test_config() -> CompiledConfig {
        CompiledConfig {
            prompt_format: "{user}@{host} {cwd} {git} {char} ".to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_prompt_render_under_2ms() {
        let config = test_config();
        let prompt = Prompt::new(&config);
        let start = Instant::now();
        let result = prompt.render();
        let elapsed = start.elapsed();
        assert!(result.is_ok());
        // Report the actual budget constant rather than a hard-coded value.
        assert!(
            elapsed < Duration::from_millis(MAX_PROMPT_MS),
            "ANDON: Prompt exceeded {}ms budget: {:?}",
            MAX_PROMPT_MS,
            elapsed
        );
    }

    #[test]
    fn test_parse_format() {
        let segments = Prompt::parse_format("{user}@{host} {cwd} {char}");
        assert_eq!(segments.len(), 7);
        assert!(matches!(segments[0], PromptSegment::User));
        assert!(matches!(segments[1], PromptSegment::Literal(ref s) if s == "@"));
        assert!(matches!(segments[2], PromptSegment::Host));
        assert!(matches!(segments[3], PromptSegment::Literal(ref s) if s == " "));
        assert!(matches!(segments[4], PromptSegment::Cwd));
        assert!(matches!(segments[5], PromptSegment::Literal(ref s) if s == " "));
        assert!(matches!(segments[6], PromptSegment::Char));
    }

    #[test]
    fn test_git_cache_render() {
        let mut cache = GitCache::new();
        assert_eq!(cache.render(), "");
        cache.branch = Some("main".to_string());
        assert_eq!(cache.render(), "(main)");
        cache.dirty = true;
        assert_eq!(cache.render(), "(main*)");
    }

    #[test]
    fn test_git_cache_invalidation() {
        let cache = GitCache::new();
        assert!(!cache.is_valid());
        cache.valid.store(true, Ordering::Relaxed);
        assert!(cache.is_valid());
        cache.invalidate();
        assert!(!cache.is_valid());
    }

    #[test]
    fn test_prompt_contains_expected_parts() {
        let config = test_config();
        let prompt = Prompt::new(&config);
        let rendered = prompt.render().unwrap();
        assert!(
            rendered.contains(&prompt.user),
            "Prompt should contain user"
        );
        assert!(
            rendered.contains(&prompt.host),
            "Prompt should contain host"
        );
        assert!(
            rendered.contains('$') || rendered.contains('#'),
            "Prompt should contain char"
        );
    }

    #[test]
    fn test_prompt_with_git_cache() {
        let config = test_config();
        let mut prompt = Prompt::new(&config);
        prompt.update_git_cache(Some("feature-branch".to_string()), true);
        let rendered = prompt.render().unwrap();
        assert!(
            rendered.contains("(feature-branch*)"),
            "Prompt should show git status: {}",
            rendered
        );
    }

    #[test]
    fn test_prompt_render_is_o1() {
        let config1 = CompiledConfig {
            prompt_format: "{user}".to_string(),
            ..Default::default()
        };
        let config2 = CompiledConfig {
            prompt_format: "{user}@{host} {cwd} {git} {char}".to_string(),
            ..Default::default()
        };
        let prompt1 = Prompt::new(&config1);
        let prompt2 = Prompt::new(&config2);
        let start = Instant::now();
        for _ in 0..1000 {
            let _ = prompt1.render();
        }
        let time1 = start.elapsed();
        let start = Instant::now();
        for _ in 0..1000 {
            let _ = prompt2.render();
        }
        let time2 = start.elapsed();
        assert!(
            time2 < time1 * 20,
            "Complex prompt too slow: {:?} vs {:?}",
            time2,
            time1
        );
    }

    #[test]
    fn test_prompt_deterministic() {
        let config = test_config();
        let prompt = Prompt::new(&config);
        let render1 = prompt.render().unwrap();
        let render2 = prompt.render().unwrap();
        assert_eq!(render1, render2, "Prompt must be deterministic");
    }

    #[test]
    fn test_prompt_segment_count() {
        let config = test_config();
        let prompt = Prompt::new(&config);
        // user, "@", host, " ", cwd, " ", git, " ", char, " "
        assert_eq!(prompt.segment_count(), 10);
    }

    #[test]
    fn test_prompt_colors_enabled() {
        let config = test_config();
        let prompt = Prompt::new(&config);
        // Colors can only be on when the config requested them.
        assert!(!prompt.colors_enabled() || config.colors_enabled);

        let disabled = CompiledConfig {
            colors_enabled: false,
            ..test_config()
        };
        assert!(!Prompt::new(&disabled).colors_enabled());
    }

    #[test]
    fn test_prompt_set_colors_enabled() {
        let config = test_config();
        let mut prompt = Prompt::new(&config);
        prompt.set_colors_enabled(false);
        assert!(!prompt.colors_enabled());
    }

    #[test]
    fn test_git_cache_new() {
        let cache = GitCache::new();
        assert!(cache.branch.is_none());
        assert!(!cache.dirty);
        assert!(!cache.is_valid());
    }

    #[test]
    fn test_git_cache_valid_flag() {
        let cache = GitCache::new();
        cache.valid.store(true, Ordering::Relaxed);
        assert!(cache.is_valid());
        cache.invalidate();
        assert!(!cache.is_valid());
    }

    #[test]
    fn test_git_cache_render_colored() {
        let mut cache = GitCache::new();
        assert_eq!(cache.render_colored(true), "");
        assert_eq!(cache.render_colored(false), "");
        cache.branch = Some("main".to_string());
        let colored = cache.render_colored(true);
        let plain = cache.render_colored(false);
        // Check the colored output on its own; `|| plain...` made this vacuous.
        assert!(colored.contains("main"));
        assert_eq!(plain, "(main)");
        cache.dirty = true;
        let colored = cache.render_colored(true);
        let plain = cache.render_colored(false);
        assert!(colored.contains("main"));
        assert_eq!(plain, "(main*)");
    }

    #[test]
    fn test_prompt_update_git_cache() {
        let config = test_config();
        let mut prompt = Prompt::new(&config);
        prompt.update_git_cache(Some("develop".to_string()), false);
        let rendered = prompt.render().unwrap();
        assert!(rendered.contains("(develop)"), "Should show clean branch");
        prompt.update_git_cache(Some("develop".to_string()), true);
        let rendered = prompt.render().unwrap();
        assert!(rendered.contains("(develop*)"), "Should show dirty branch");
    }

    #[test]
    fn test_prompt_invalidate_git_cache() {
        let config = test_config();
        let mut prompt = Prompt::new(&config);
        prompt.update_git_cache(Some("main".to_string()), false);
        assert!(prompt.git_cache.is_valid());
        prompt.invalidate_git_cache();
        assert!(!prompt.git_cache.is_valid());
    }

    #[test]
    fn test_parse_format_empty() {
        let segments = Prompt::parse_format("");
        assert!(segments.is_empty());
    }

    #[test]
    fn test_parse_format_literal_only() {
        let segments = Prompt::parse_format("hello world");
        assert_eq!(segments.len(), 1);
        assert!(matches!(segments[0], PromptSegment::Literal(ref s) if s == "hello world"));
    }

    #[test]
    fn test_parse_format_custom_segment() {
        let segments = Prompt::parse_format("{custom_thing}");
        assert_eq!(segments.len(), 1);
        assert!(matches!(segments[0], PromptSegment::Custom(ref s) if s == "custom_thing"));
    }

    #[test]
    fn test_prompt_custom_segment_render() {
        let config = CompiledConfig {
            prompt_format: "{custom} $ ".to_string(),
            ..Default::default()
        };
        let prompt = Prompt::new(&config);
        let rendered = prompt.render().unwrap();
        assert!(rendered.contains("{custom}"));
    }

    #[test]
    fn test_prompt_segment_debug() {
        let segments = vec![
            PromptSegment::Literal("test".to_string()),
            PromptSegment::User,
            PromptSegment::Host,
            PromptSegment::Cwd,
            PromptSegment::Git,
            PromptSegment::Char,
            PromptSegment::Custom("x".to_string()),
        ];
        for seg in segments {
            let debug = format!("{:?}", seg);
            assert!(!debug.is_empty());
        }
    }

    #[test]
    fn test_prompt_segment_clone() {
        let seg = PromptSegment::Literal("test".to_string());
        let cloned = seg.clone();
        assert!(matches!(cloned, PromptSegment::Literal(ref s) if s == "test"));
    }

    #[test]
    fn test_git_cache_debug() {
        let cache = GitCache::new();
        let debug = format!("{:?}", cache);
        assert!(debug.contains("GitCache"));
    }

    #[test]
    fn test_git_cache_clone() {
        let mut cache = GitCache::new();
        cache.branch = Some("main".to_string());
        cache.dirty = true;
        let cloned = cache.clone();
        assert_eq!(cloned.branch, Some("main".to_string()));
        assert!(cloned.dirty);
    }

    #[test]
    fn test_git_cache_default() {
        let cache = GitCache::default();
        assert!(cache.branch.is_none());
        assert!(!cache.dirty);
    }

    #[test]
    fn test_prompt_debug() {
        let config = test_config();
        let prompt = Prompt::new(&config);
        let debug = format!("{:?}", prompt);
        assert!(debug.contains("Prompt"));
    }

    #[test]
    fn test_prompt_root_char() {
        let config = CompiledConfig {
            prompt_format: "{char}".to_string(),
            colors_enabled: false,
            ..Default::default()
        };
        let prompt = Prompt::new(&config);
        let rendered = prompt.render().unwrap();
        // With colors off, `{char}` renders exactly one uncolored character.
        let expected = if prompt.user == "root" { "#" } else { "$" };
        assert_eq!(rendered, expected);
    }

    #[test]
    fn test_prompt_all_segments() {
        let config = CompiledConfig {
            prompt_format: "{user}@{host}:{cwd} {git} {char} ".to_string(),
            colors_enabled: false,
            ..Default::default()
        };
        let mut prompt = Prompt::new(&config);
        prompt.update_git_cache(Some("feature".to_string()), false);
        let rendered = prompt.render().unwrap();
        assert!(rendered.contains('@'));
        assert!(rendered.contains(':'));
        assert!(rendered.contains("(feature)"));
        // Colors are off, so user and host appear unwrapped at the start.
        assert!(rendered.starts_with(&format!("{}@{}:", prompt.user, prompt.host)));
    }
}