use crate::config::CompiledConfig;
use crate::{MAX_PROMPT_MS, PzshError, Result};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
/// One parsed piece of a prompt format string.
///
/// Segments are produced by `Prompt::parse_format` and consumed in order by
/// `Prompt::render`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptSegment {
    /// Text copied to the output verbatim.
    Literal(String),
    /// `{user}`: the user name stored on the prompt.
    User,
    /// `{host}`: the host name stored on the prompt.
    Host,
    /// `{cwd}`: the current working directory.
    Cwd,
    /// `{git}`: the rendered git cache (branch plus dirty marker).
    Git,
    /// `{char}`: the prompt character, `#` when the user is `root`, otherwise `$`.
    Char,
    /// Any other `{name}` placeholder; rendered back as `{name}`.
    Custom(String),
}
/// Cached git state shown by the prompt.
///
/// `branch` and `dirty` are plain values, while the validity flag lives behind
/// an `Arc`, so clones of a cache share one flag: invalidating any clone
/// invalidates all of them.
#[derive(Debug, Clone, Default)]
pub struct GitCache {
    /// Branch name to display, or `None` when nothing should be rendered.
    pub branch: Option<String>,
    /// When `true`, a `*` is appended after the branch name.
    pub dirty: bool,
    /// Shared "cache is up to date" flag; starts out `false`.
    valid: Arc<AtomicBool>,
}

impl GitCache {
    /// Creates an empty cache: no branch, not dirty, and not valid.
    #[must_use]
    pub fn new() -> Self {
        // The derived `Default` yields exactly the empty state: `None`,
        // `false`, and a fresh flag initialised to `false`.
        Self::default()
    }

    /// Reports whether the cache is currently marked valid.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        let flag: &AtomicBool = &self.valid;
        flag.load(Ordering::Relaxed)
    }

    /// Clears the validity flag; `branch` and `dirty` are left untouched.
    pub fn invalidate(&self) {
        let flag: &AtomicBool = &self.valid;
        flag.store(false, Ordering::Relaxed);
    }

    /// Formats the cached state as `(branch)` or `(branch*)` when dirty.
    ///
    /// Returns an empty string when no branch is set. The validity flag is
    /// not consulted here.
    #[must_use]
    pub fn render(&self) -> String {
        self.branch.as_deref().map_or_else(String::new, |name| {
            let marker = if self.dirty { "*" } else { "" };
            format!("({name}{marker})")
        })
    }
}
/// A prompt parsed from a format string, holding everything `render` needs.
#[derive(Debug)]
pub struct Prompt {
/// Segments parsed from the configured prompt format, in output order.
segments: Vec<PromptSegment>,
/// Cached git branch/dirty state rendered by the `Git` segment.
git_cache: GitCache,
/// User name captured when the prompt was constructed.
user: String,
/// Host name captured when the prompt was constructed.
host: String,
}
impl Prompt {
    /// Builds a prompt from `config.prompt_format`.
    ///
    /// The user name is read from the `USER` environment variable (falling
    /// back to `"user"`), and the host name from `hostname::get()` (falling
    /// back to `"localhost"` when the lookup fails or the name cannot be
    /// converted to a `String`). Both are captured once here and reused by
    /// every call to [`Prompt::render`]. The git cache starts out empty.
    #[must_use]
    pub fn new(config: &CompiledConfig) -> Self {
        let segments = Self::parse_format(&config.prompt_format);
        let user = std::env::var("USER").unwrap_or_else(|_| "user".to_string());
        let host = hostname::get()
            .ok()
            .and_then(|h| h.into_string().ok())
            .unwrap_or_else(|| "localhost".to_string());
        Self {
            segments,
            git_cache: GitCache::new(),
            user,
            host,
        }
    }

    /// Splits a format string into literal text and `{name}` placeholders.
    ///
    /// The names `user`, `host`, `cwd`, `git` and `char` map to their
    /// dedicated segments; any other name (including the empty name in `{}`)
    /// becomes [`PromptSegment::Custom`]. A `}` outside a placeholder is
    /// ordinary literal text, and a `{` inside a placeholder is part of its
    /// name.
    ///
    /// An unterminated placeholder such as the `{user` in `"> {user"` is kept
    /// as literal text (`"{user"`) instead of being silently discarded, so
    /// nothing the user wrote disappears from the prompt.
    fn parse_format(format: &str) -> Vec<PromptSegment> {
        let mut segments = Vec::new();
        let mut current_literal = String::new();
        let mut in_brace = false;
        let mut brace_content = String::new();

        for ch in format.chars() {
            match ch {
                '{' if !in_brace => {
                    // Flush pending literal text before the placeholder starts.
                    if !current_literal.is_empty() {
                        segments.push(PromptSegment::Literal(std::mem::take(
                            &mut current_literal,
                        )));
                    }
                    in_brace = true;
                }
                '}' if in_brace => {
                    let segment = match brace_content.as_str() {
                        "user" => PromptSegment::User,
                        "host" => PromptSegment::Host,
                        "cwd" => PromptSegment::Cwd,
                        "git" => PromptSegment::Git,
                        "char" => PromptSegment::Char,
                        other => PromptSegment::Custom(other.to_string()),
                    };
                    segments.push(segment);
                    brace_content.clear();
                    in_brace = false;
                }
                _ if in_brace => brace_content.push(ch),
                _ => current_literal.push(ch),
            }
        }

        if in_brace {
            // The format ended inside `{...`. `current_literal` is necessarily
            // empty here (it was flushed when the `{` was seen), so only the
            // unterminated text needs to be emitted, verbatim.
            let mut unterminated = String::with_capacity(brace_content.len() + 1);
            unterminated.push('{');
            unterminated.push_str(&brace_content);
            segments.push(PromptSegment::Literal(unterminated));
        } else if !current_literal.is_empty() {
            segments.push(PromptSegment::Literal(current_literal));
        }
        segments
    }

    /// Renders the prompt by concatenating every segment in order.
    ///
    /// - `Cwd` uses the `PWD` environment variable, then
    ///   `std::env::current_dir()`, and finally `"~"` if both fail.
    /// - `Git` renders the cached git state as-is; the cache's validity flag
    ///   is not consulted.
    /// - `Char` is `#` when the stored user name is `root`, otherwise `$`.
    /// - `Custom(name)` is echoed back as `{name}`.
    ///
    /// # Errors
    ///
    /// Returns [`PzshError::PromptBudgetExceeded`] carrying `MAX_PROMPT_MS`
    /// and the elapsed time in milliseconds when rendering took longer than
    /// `MAX_PROMPT_MS` milliseconds.
    pub fn render(&self) -> Result<String> {
        let start = Instant::now();
        let mut output = String::with_capacity(128);

        for segment in &self.segments {
            match segment {
                PromptSegment::Literal(text) => output.push_str(text),
                PromptSegment::User => output.push_str(&self.user),
                PromptSegment::Host => output.push_str(&self.host),
                PromptSegment::Cwd => {
                    let cwd = std::env::var("PWD")
                        .or_else(|_| std::env::current_dir().map(|p| p.display().to_string()))
                        .unwrap_or_else(|_| "~".to_string());
                    output.push_str(&cwd);
                }
                PromptSegment::Git => output.push_str(&self.git_cache.render()),
                PromptSegment::Char => {
                    let is_root = self.user == "root";
                    output.push(if is_root { '#' } else { '$' });
                }
                PromptSegment::Custom(name) => {
                    // Append directly instead of building a temporary
                    // `String` with `format!` on every render.
                    output.push('{');
                    output.push_str(name);
                    output.push('}');
                }
            }
        }

        let elapsed = start.elapsed();
        if elapsed > Duration::from_millis(MAX_PROMPT_MS) {
            // Whole milliseconds as a `u64`, saturating rather than
            // truncating the way a `u128 as u64` cast would.
            let elapsed_ms = elapsed
                .as_secs()
                .saturating_mul(1000)
                .saturating_add(u64::from(elapsed.subsec_millis()));
            return Err(PzshError::PromptBudgetExceeded(MAX_PROMPT_MS, elapsed_ms));
        }
        Ok(output)
    }

    /// Stores a new branch/dirty state and marks the git cache as valid.
    pub fn update_git_cache(&mut self, branch: Option<String>, dirty: bool) {
        self.git_cache.branch = branch;
        self.git_cache.dirty = dirty;
        self.git_cache.valid.store(true, Ordering::Relaxed);
    }

    /// Marks the git cache as invalid without clearing its branch/dirty data.
    pub fn invalidate_git_cache(&self) {
        self.git_cache.invalidate();
    }

    /// Returns the number of segments parsed from the prompt format.
    #[must_use]
    pub fn segment_count(&self) -> usize {
        self.segments.len()
    }
}
// Unit tests for prompt parsing, rendering, and the git cache.
#[cfg(test)]
mod tests {
use super::*;
use std::time::Instant;
// Builds a default config whose prompt format uses every built-in placeholder.
fn test_config() -> CompiledConfig {
let mut config = CompiledConfig::default();
config.prompt_format = "{user}@{host} {cwd} {git} {char} ".to_string();
config
}
// A single render must succeed and finish within MAX_PROMPT_MS, measured
// from the outside as well as by `render` itself.
// NOTE(review): the failure message says "2ms"; assumes MAX_PROMPT_MS is 2 — confirm.
#[test]
fn test_prompt_render_under_2ms() {
let config = test_config();
let prompt = Prompt::new(&config);
let start = Instant::now();
let result = prompt.render();
let elapsed = start.elapsed();
assert!(result.is_ok());
assert!(
elapsed < Duration::from_millis(MAX_PROMPT_MS),
"ANDON: Prompt exceeded 2ms budget: {:?}",
elapsed
);
}
// Four placeholders separated by three literals yield seven segments, in
// the same order as they appear in the format string.
#[test]
fn test_parse_format() {
let segments = Prompt::parse_format("{user}@{host} {cwd} {char}");
assert_eq!(segments.len(), 7);
assert!(matches!(segments[0], PromptSegment::User));
assert!(matches!(segments[1], PromptSegment::Literal(ref s) if s == "@"));
assert!(matches!(segments[2], PromptSegment::Host));
assert!(matches!(segments[3], PromptSegment::Literal(ref s) if s == " "));
assert!(matches!(segments[4], PromptSegment::Cwd));
assert!(matches!(segments[5], PromptSegment::Literal(ref s) if s == " "));
assert!(matches!(segments[6], PromptSegment::Char));
}
// Rendering: empty without a branch, "(branch)" when clean, "(branch*)" when dirty.
#[test]
fn test_git_cache_render() {
let mut cache = GitCache::new();
assert_eq!(cache.render(), "");
cache.branch = Some("main".to_string());
assert_eq!(cache.render(), "(main)");
cache.dirty = true;
assert_eq!(cache.render(), "(main*)");
}
// A new cache starts invalid; `invalidate` clears a flag that was set.
#[test]
fn test_git_cache_invalidation() {
let cache = GitCache::new();
assert!(!cache.is_valid());
// Set the private flag directly; the test lives in a child module of the defining one.
cache.valid.store(true, Ordering::Relaxed);
assert!(cache.is_valid());
cache.invalidate();
assert!(!cache.is_valid());
}
// The rendered prompt includes the stored user, the stored host, and a prompt character.
#[test]
fn test_prompt_contains_expected_parts() {
let config = test_config();
let prompt = Prompt::new(&config);
let rendered = prompt.render().unwrap();
assert!(
rendered.contains(&prompt.user),
"Prompt should contain user"
);
assert!(
rendered.contains(&prompt.host),
"Prompt should contain host"
);
assert!(
rendered.contains('$') || rendered.contains('#'),
"Prompt should contain char"
);
}
// After `update_git_cache`, the `{git}` segment shows the branch with the dirty marker.
#[test]
fn test_prompt_with_git_cache() {
let config = test_config();
let mut prompt = Prompt::new(&config);
prompt.update_git_cache(Some("feature-branch".to_string()), true);
let rendered = prompt.render().unwrap();
assert!(
rendered.contains("(feature-branch*)"),
"Prompt should show git status: {}",
rendered
);
}
// Compares 1000 renders of a one-segment prompt against 1000 renders of a
// multi-segment prompt and requires the latter to take less than 5x as long.
// NOTE(review): this is a wall-clock comparison and may be flaky on a loaded machine.
#[test]
fn test_prompt_render_is_o1() {
let config1 = CompiledConfig {
prompt_format: "{user}".to_string(),
..Default::default()
};
let config2 = CompiledConfig {
prompt_format: "{user}@{host} {cwd} {git} {char}".to_string(),
..Default::default()
};
let prompt1 = Prompt::new(&config1);
let prompt2 = Prompt::new(&config2);
let start = Instant::now();
for _ in 0..1000 {
let _ = prompt1.render();
}
let time1 = start.elapsed();
let start = Instant::now();
for _ in 0..1000 {
let _ = prompt2.render();
}
let time2 = start.elapsed();
assert!(
time2 < time1 * 5,
"Complex prompt too slow: {:?} vs {:?}",
time2,
time1
);
}
// Two consecutive renders of the same prompt must produce identical output.
#[test]
fn test_prompt_deterministic() {
let config = test_config();
let prompt = Prompt::new(&config);
let render1 = prompt.render().unwrap();
let render2 = prompt.render().unwrap();
assert_eq!(render1, render2, "Prompt must be deterministic");
}
}