1use anyhow::{Context, Result};
2use std::fs;
3use std::path::{Path, PathBuf};
4
5use crate::ui;
6
/// Opening sentinel line of the mvmctl-managed shell block. Its presence in an
/// rc file is what marks the block as already installed.
const MARKER_START: &str = "# >>> mvmctl >>>";
/// Closing sentinel line of the mvmctl-managed shell block.
const MARKER_END: &str = "# <<< mvmctl <<<";
9
10pub fn generate_block(kv_root: &str) -> String {
15 let mut tera = tera::Tera::default();
16 tera.add_raw_template(
17 "shell_init",
18 include_str!("../resources/shell_init.sh.tera"),
19 )
20 .expect("embedded shell_init template should parse");
21 let mut ctx = tera::Context::new();
22 ctx.insert("kv_root", kv_root);
23 ctx.insert("marker_start", MARKER_START);
24 ctx.insert("marker_end", MARKER_END);
25 tera.render("shell_init", &ctx)
26 .expect("shell_init template should render")
27 .trim()
28 .to_string()
29}
30
31pub fn detect_kv_root() -> Result<PathBuf> {
36 let cwd = std::env::current_dir().context("Failed to get current directory")?;
37 let mut dir = cwd.as_path();
38
39 loop {
40 let cargo_toml = dir.join("Cargo.toml");
41 if cargo_toml.exists() {
42 let contents = fs::read_to_string(&cargo_toml).unwrap_or_default();
43 if contents.contains("name = \"mvmctl\"") {
44 return dir
45 .parent()
46 .map(Path::to_path_buf)
47 .context("mvm repo root has no parent directory");
48 }
49 }
50 dir = match dir.parent() {
51 Some(p) => p,
52 None => anyhow::bail!(
53 "Could not find mvm repo root (Cargo.toml with name = \"mvmctl\") \
54 in any parent of {}",
55 cwd.display()
56 ),
57 };
58 }
59}
60
61fn host_rc_path() -> Result<PathBuf> {
65 let home = std::env::var("HOME").context("HOME not set")?;
66 let rc_name = if cfg!(target_os = "macos") {
67 ".zshrc"
68 } else {
69 ".bashrc"
70 };
71 Ok(PathBuf::from(home).join(rc_name))
72}
73
74fn has_marker(contents: &str) -> bool {
76 contents.contains(MARKER_START)
77}
78
79pub fn ensure_shell_init() -> Result<()> {
82 let kv_root = match detect_kv_root() {
83 Ok(p) => p,
84 Err(e) => {
85 ui::warn(&format!("Skipping shell init: {e}"));
86 return Ok(());
87 }
88 };
89
90 let rc_path = host_rc_path()?;
91 let existing = if rc_path.exists() {
92 fs::read_to_string(&rc_path)
93 .with_context(|| format!("Failed to read {}", rc_path.display()))?
94 } else {
95 String::new()
96 };
97
98 if has_marker(&existing) {
99 ui::info(&format!(
100 "Shell init already configured in {}",
101 rc_path.display()
102 ));
103 return Ok(());
104 }
105
106 let block = generate_block(&kv_root.display().to_string());
107 let separator = if existing.is_empty() || existing.ends_with('\n') {
108 ""
109 } else {
110 "\n"
111 };
112 let new_contents = format!("{existing}{separator}\n{block}\n");
113
114 fs::write(&rc_path, new_contents)
115 .with_context(|| format!("Failed to write {}", rc_path.display()))?;
116
117 ui::success(&format!("Added mvmctl shell init to {}", rc_path.display()));
118 Ok(())
119}
120
121pub fn print_shell_init() -> Result<()> {
123 let kv_root = detect_kv_root()?;
124 let block = generate_block(&kv_root.display().to_string());
125 println!("{block}");
126 Ok(())
127}
128
129pub fn ensure_shell_init_in_vm() -> Result<()> {
135 use mvm_runtime::shell;
136
137 let kv_root = match detect_kv_root() {
138 Ok(p) => p,
139 Err(e) => {
140 ui::warn(&format!("Skipping VM shell init: {e}"));
141 return Ok(());
142 }
143 };
144
145 let block = generate_block(&kv_root.display().to_string());
146 let escaped_marker = MARKER_START.replace('"', r#"\""#);
147 let escaped_block = block.replace('\\', r"\\").replace('"', r#"\""#);
148
149 let script = format!(
151 r#"
152 if grep -qF '{marker}' ~/.bashrc 2>/dev/null; then
153 true
154 else
155 printf '\n{block}\n' >> ~/.bashrc
156 fi
157 "#,
158 marker = escaped_marker,
159 block = escaped_block,
160 );
161
162 shell::run_in_vm(&script).map(|_| ())?;
163 Ok(())
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// The rendered block begins with the start marker and ends with the end
    /// marker.
    #[test]
    fn test_generate_block_contains_markers() {
        let rendered = generate_block("/some/path");
        assert!(rendered.starts_with(MARKER_START));
        assert!(rendered.ends_with(MARKER_END));
    }

    /// The rendered block wires up shell completions.
    #[test]
    fn test_generate_block_contains_completions() {
        assert!(generate_block("/some/path").contains("mvmctl completions"));
    }

    /// The rendered block defines both aliases and points them at manifests
    /// under the supplied kv root.
    #[test]
    fn test_generate_block_contains_aliases() {
        let rendered = generate_block("/work/kv");
        let expected_fragments = [
            "alias mvmctl=",
            "alias mvmd=",
            r#"KV_ROOT="/work/kv""#,
            "$KV_ROOT/mvm/Cargo.toml",
            "$KV_ROOT/mvmd/Cargo.toml",
        ];
        for fragment in expected_fragments.iter() {
            assert!(rendered.contains(fragment));
        }
    }

    /// Text that includes the start marker is recognised.
    #[test]
    fn test_has_marker_positive() {
        let rc_text = ["some stuff", MARKER_START, "more", MARKER_END, ""].join("\n");
        assert!(has_marker(&rc_text));
    }

    /// Text without the start marker is not recognised.
    #[test]
    fn test_has_marker_negative() {
        let rc_text = "just some zshrc content\n";
        assert!(!has_marker(rc_text));
    }

    /// Detection may fail depending on where the tests run; when it does
    /// succeed, the reported directory must exist.
    #[test]
    fn test_detect_kv_root() {
        if let Ok(found) = detect_kv_root() {
            assert!(found.exists());
        }
    }
}