1use std::path::Path;
15
16use crate::agent::AgentId;
17use crate::consts::{CWD_ADDENDUM_AGENT0, CWD_ADDENDUM_AGENTINFINITY, CWD_ADDENDUM_CLONE_EXT};
18use netsky_config::Config as RuntimeConfig;
19
// Prompt templates compiled into the binary via `include_str!` (paths are relative to
// this source file). `BASE_TEMPLATE` is rendered for every agent; `stanza_for` picks
// exactly one of the per-agent stanzas below to follow it.
const BASE_TEMPLATE: &str = include_str!("../prompts/base.md");
const AGENT0_STANZA: &str = include_str!("../prompts/agent0.md");
const CLONE_STANZA: &str = include_str!("../prompts/clone.md");
const AGENTINFINITY_STANZA: &str = include_str!("../prompts/agentinfinity.md");

/// Markdown horizontal rule, padded with blank lines, placed between prompt sections.
const SEPARATOR: &str = "\n\n---\n\n";
29
30#[derive(Debug, Clone)]
32pub struct PromptContext {
33 pub agent: AgentId,
34 pub cwd: String,
35}
36
37impl PromptContext {
38 pub fn new(agent: AgentId, cwd: impl Into<String>) -> Self {
39 Self {
40 agent,
41 cwd: cwd.into(),
42 }
43 }
44
45 fn bindings(&self) -> Vec<(&'static str, String)> {
49 vec![
50 ("agent_name", self.agent.name()),
51 ("n", self.agent.env_n()),
52 ("cwd", self.cwd.clone()),
53 ]
54 }
55}
56
/// Errors produced while rendering an agent prompt.
#[derive(Debug)]
pub enum PromptError {
    /// A filesystem error, e.g. reading the cwd addendum failed for a reason other than
    /// the file (or a parent path component) being absent.
    Io(std::io::Error),
    /// Loading the runtime configuration (`netsky_config`) failed.
    Config(anyhow::Error),
    /// The rendered prompt still contains `{{`. `count` is the number of occurrences;
    /// `preview` holds short snippets starting at the first few of them.
    UnsubstitutedPlaceholders { count: usize, preview: String },
}
63
64impl std::fmt::Display for PromptError {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 match self {
67 Self::Io(e) => write!(f, "io error reading addendum: {e}"),
68 Self::Config(e) => write!(f, "runtime config error reading addendum: {e}"),
69 Self::UnsubstitutedPlaceholders { count, preview } => write!(
70 f,
71 "template render left {count} unsubstituted placeholder(s): {preview}"
72 ),
73 }
74 }
75}
76
77impl std::error::Error for PromptError {
78 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
79 match self {
80 Self::Io(e) => Some(e),
81 Self::Config(e) => Some(e.as_ref()),
82 _ => None,
83 }
84 }
85}
86
87impl From<std::io::Error> for PromptError {
88 fn from(e: std::io::Error) -> Self {
89 Self::Io(e)
90 }
91}
92
93impl From<anyhow::Error> for PromptError {
94 fn from(e: anyhow::Error) -> Self {
95 Self::Config(e)
96 }
97}
98
99fn stanza_for(agent: AgentId) -> &'static str {
100 match agent {
101 AgentId::Agent0 => AGENT0_STANZA,
102 AgentId::Clone(_) => CLONE_STANZA,
103 AgentId::Agentinfinity => AGENTINFINITY_STANZA,
104 }
105}
106
107fn cwd_addendum_filename(agent: AgentId) -> String {
110 match agent {
111 AgentId::Agent0 => CWD_ADDENDUM_AGENT0.to_string(),
112 AgentId::Agentinfinity => CWD_ADDENDUM_AGENTINFINITY.to_string(),
113 AgentId::Clone(n) => format!("{n}{CWD_ADDENDUM_CLONE_EXT}"),
114 }
115}
116
117fn resolve_addendum_path(agent: AgentId, cwd: &Path) -> std::path::PathBuf {
129 use crate::config::Config;
130
131 let configured = Config::load_from(&cwd.join("netsky.toml"))
132 .ok()
133 .flatten()
134 .and_then(|cfg| cfg.addendum)
135 .and_then(|a| match agent {
136 AgentId::Agent0 => a.agent0,
137 AgentId::Agentinfinity => a.agentinfinity,
138 AgentId::Clone(_) => a.clone_default,
139 });
140
141 match configured {
142 Some(p) if p.starts_with('/') => std::path::PathBuf::from(p),
143 Some(p) if p.starts_with("~/") => {
144 if let Some(home) = dirs::home_dir() {
145 home.join(p.trim_start_matches("~/"))
146 } else {
147 cwd.join(p)
148 }
149 }
150 Some(p) => cwd.join(p),
151 None => cwd.join(cwd_addendum_filename(agent)),
152 }
153}
154
155fn read_cwd_addendum(agent: AgentId, cwd: &Path) -> Result<Option<String>, std::io::Error> {
160 let path = resolve_addendum_path(agent, cwd);
161 match std::fs::read_to_string(&path) {
162 Ok(s) => Ok(Some(s)),
163 Err(e) => match e.kind() {
164 std::io::ErrorKind::NotFound | std::io::ErrorKind::NotADirectory => Ok(None),
167 _ => Err(e),
168 },
169 }
170}
171
172fn read_runtime_addenda() -> Result<Vec<String>, PromptError> {
173 let cfg = RuntimeConfig::load()?;
174 let mut layers = Vec::new();
175
176 if let Some(base) = cfg.addendum.base.as_deref() {
177 let trimmed = base.trim();
178 if !trimmed.is_empty() {
179 layers.push(trimmed.to_string());
180 }
181 }
182
183 if let Some(host) = cfg.addendum.host.as_deref() {
184 let trimmed = host.trim();
185 if !trimmed.is_empty() {
186 layers.push(trimmed.to_string());
187 }
188 }
189
190 Ok(layers)
191}
192
/// Substitutes every binding into `body`, accepting four spacings of each placeholder:
/// `{{ name }}`, `{{name}}`, `{{ name}}` and `{{name }}`.
///
/// Bindings are applied in slice order, so text inserted by an earlier binding is still
/// scanned for later bindings' placeholders.
fn apply_bindings(body: &str, bindings: &[(&'static str, String)]) -> String {
    bindings.iter().fold(body.to_owned(), |text, (name, value)| {
        let spellings = [
            format!("{{{{ {name} }}}}"),
            format!("{{{{{name}}}}}"),
            format!("{{{{ {name}}}}}"),
            format!("{{{{{name} }}}}"),
        ];
        spellings
            .iter()
            .fold(text, |acc, placeholder| acc.replace(placeholder.as_str(), value))
    })
}
214
215fn assert_fully_rendered(body: &str) -> Result<(), PromptError> {
218 let count = body.matches("{{").count();
219 if count == 0 {
220 return Ok(());
221 }
222 let preview = body
223 .match_indices("{{")
224 .take(3)
225 .map(|(i, _)| {
226 let end = body.len().min(i + 32);
227 body[i..end].to_string()
228 })
229 .collect::<Vec<_>>()
230 .join(" | ");
231 Err(PromptError::UnsubstitutedPlaceholders { count, preview })
232}
233
234pub fn render_prompt(ctx: PromptContext, cwd: &Path) -> Result<String, PromptError> {
237 let agent = ctx.agent;
238 let bindings = ctx.bindings();
239
240 let base = apply_bindings(BASE_TEMPLATE, &bindings);
241 let stanza = apply_bindings(stanza_for(agent), &bindings);
242
243 let mut out = String::with_capacity(base.len() + stanza.len() + 128);
244 out.push_str(base.trim_end());
245 out.push_str(SEPARATOR);
246 out.push_str(stanza.trim_end());
247
248 if let Some(addendum) = read_cwd_addendum(agent, cwd)? {
249 let trimmed = addendum.trim();
250 if !trimmed.is_empty() {
251 out.push_str(SEPARATOR);
252 out.push_str(trimmed);
253 }
254 }
255 for addendum in read_runtime_addenda()? {
256 out.push_str(SEPARATOR);
257 out.push_str(&addendum);
258 }
259 out.push('\n');
260
261 assert_fully_rendered(&out)?;
262 Ok(out)
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::{Mutex, MutexGuard, OnceLock};
    use tempfile::TempDir;

    /// Isolates a test from the developer's real runtime config.
    ///
    /// While alive it holds `test_lock`, points `XDG_CONFIG_HOME` at a fresh temp dir
    /// and clears `MACHINE_TYPE`. On drop it restores both variables. Every test that
    /// calls `render_prompt` must hold one, because rendering reads the runtime config.
    struct PromptTestEnv {
        // Field order matters: the temp dir is removed before the lock is released.
        _tmp: TempDir,
        _guard: MutexGuard<'static, ()>,
        prior_xdg: Option<String>,
        prior_machine_type: Option<String>,
    }

    impl PromptTestEnv {
        fn new() -> Self {
            // A panicking test poisons the lock. The guarded data is `()`, so recover.
            let guard = test_lock().lock().unwrap_or_else(|err| err.into_inner());
            let tmp = TempDir::new().unwrap();
            let prior_xdg = std::env::var("XDG_CONFIG_HOME").ok();
            let prior_machine_type = std::env::var("MACHINE_TYPE").ok();
            // SAFETY: tests in this module that read or write the environment hold
            // `test_lock`, so none of them touches it concurrently.
            // NOTE(review): tests in other modules are not covered by this lock.
            unsafe {
                std::env::set_var("XDG_CONFIG_HOME", tmp.path());
                std::env::remove_var("MACHINE_TYPE");
            }
            std::fs::create_dir_all(netsky_config::config_dir()).unwrap();
            Self {
                _tmp: tmp,
                _guard: guard,
                prior_xdg,
                prior_machine_type,
            }
        }
    }

    /// Process-wide lock serializing tests that mutate environment variables.
    fn test_lock() -> &'static Mutex<()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
    }

    impl Drop for PromptTestEnv {
        fn drop(&mut self) {
            // SAFETY: `_guard` is still held here, because fields drop only after
            // `drop` returns. The same invariant as in `PromptTestEnv::new` applies.
            unsafe {
                match &self.prior_xdg {
                    Some(value) => std::env::set_var("XDG_CONFIG_HOME", value),
                    None => std::env::remove_var("XDG_CONFIG_HOME"),
                }
                match &self.prior_machine_type {
                    Some(value) => std::env::set_var("MACHINE_TYPE", value),
                    None => std::env::remove_var("MACHINE_TYPE"),
                }
            }
        }
    }

    fn ctx_for(agent: AgentId) -> PromptContext {
        PromptContext::new(agent, "/tmp/netsky-test")
    }

    #[test]
    fn renders_all_agents_without_addendum() {
        let _env = PromptTestEnv::new();
        let nowhere = PathBuf::from("/dev/null/does-not-exist");
        for agent in [
            AgentId::Agent0,
            AgentId::Clone(1),
            AgentId::Clone(8),
            AgentId::Agentinfinity,
        ] {
            let out = render_prompt(ctx_for(agent), &nowhere).unwrap();
            assert!(!out.is_empty(), "empty prompt for {agent}");
            assert!(out.contains("---"), "missing separator for {agent}");
            assert!(!out.contains("{{"), "unsubstituted placeholder for {agent}");
        }
    }

    #[test]
    fn clone_prompt_substitutes_n() {
        // render_prompt reads the runtime config. Without the isolated env this test
        // would read the developer's real config and race with env-mutating tests.
        let _env = PromptTestEnv::new();
        let nowhere = PathBuf::from("/dev/null/does-not-exist");
        let out = render_prompt(ctx_for(AgentId::Clone(5)), &nowhere).unwrap();
        assert!(out.contains("agent5"));
        assert!(!out.contains("{{ n }}"));
    }

    #[test]
    fn cwd_addendum_is_appended() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("0.md"), "USER POLICY HERE").unwrap();
        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        assert!(out.contains("USER POLICY HERE"));
    }

    #[test]
    fn render_rejects_unsubstituted_placeholder() {
        let body = "hello {{ unknown_var }} world";
        let err = assert_fully_rendered(body).unwrap_err();
        match err {
            PromptError::UnsubstitutedPlaceholders { count, .. } => assert_eq!(count, 1),
            _ => panic!("wrong error variant"),
        }
    }

    #[test]
    fn bindings_stringify_uniformly() {
        let b0 = PromptContext::new(AgentId::Agent0, "/").bindings();
        let b5 = PromptContext::new(AgentId::Clone(5), "/").bindings();
        let binf = PromptContext::new(AgentId::Agentinfinity, "/").bindings();
        assert_eq!(lookup(&b0, "n"), "0");
        assert_eq!(lookup(&b5, "n"), "5");
        assert_eq!(lookup(&binf, "n"), "infinity");
    }

    fn lookup(bindings: &[(&'static str, String)], key: &str) -> String {
        bindings.iter().find(|(k, _)| *k == key).unwrap().1.clone()
    }

    #[test]
    fn netsky_toml_addendum_overrides_default_path() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("0.md"), "OLD POLICY").unwrap();
        std::fs::create_dir_all(tmp.path().join("addenda")).unwrap();
        std::fs::write(tmp.path().join("addenda/0-personal.md"), "NEW POLICY").unwrap();
        std::fs::write(
            tmp.path().join("netsky.toml"),
            "schema_version = 1\n[addendum]\nagent0 = \"addenda/0-personal.md\"\n",
        )
        .unwrap();

        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        assert!(
            out.contains("NEW POLICY"),
            "TOML override should pick up addenda/0-personal.md"
        );
        assert!(
            !out.contains("OLD POLICY"),
            "TOML override should bypass the legacy 0.md fallback"
        );
    }

    #[test]
    fn missing_netsky_toml_falls_back_to_legacy_addendum() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("0.md"), "LEGACY ADDENDUM").unwrap();
        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        assert!(out.contains("LEGACY ADDENDUM"));
    }

    #[test]
    fn netsky_toml_without_addendum_section_falls_back() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("0.md"), "FALLBACK POLICY").unwrap();
        std::fs::write(
            tmp.path().join("netsky.toml"),
            "schema_version = 1\n[owner]\nname = \"Alice\"\n",
        )
        .unwrap();
        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        assert!(
            out.contains("FALLBACK POLICY"),
            "no [addendum] section should fall back to default filename"
        );
    }

    #[test]
    fn netsky_toml_addendum_absolute_path_used_as_is() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        let abs_addendum = tmp.path().join("absolute-addendum.md");
        std::fs::write(&abs_addendum, "ABSOLUTE POLICY").unwrap();
        std::fs::write(
            tmp.path().join("netsky.toml"),
            format!(
                "schema_version = 1\n[addendum]\nagent0 = \"{}\"\n",
                abs_addendum.display()
            ),
        )
        .unwrap();
        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        assert!(out.contains("ABSOLUTE POLICY"));
    }

    #[test]
    fn runtime_addendum_layers_append_after_cwd_addendum() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("0.md"), "CWD POLICY").unwrap();
        std::fs::write(netsky_config::owner_path(), "github_username = \"cody\"\n").unwrap();
        std::fs::write(netsky_config::addendum_path(), "BASE POLICY\n").unwrap();
        std::fs::write(netsky_config::active_host_path(), "work\n").unwrap();
        std::fs::write(netsky_config::host_addendum_path("work"), "WORK POLICY\n").unwrap();

        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        let cwd = out.find("CWD POLICY").unwrap();
        let base = out.find("BASE POLICY").unwrap();
        let host = out.find("WORK POLICY").unwrap();
        assert!(cwd < base);
        assert!(base < host);
    }

    #[test]
    fn machine_type_env_overrides_active_host_cache() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(netsky_config::owner_path(), "github_username = \"cody\"\n").unwrap();
        std::fs::write(netsky_config::active_host_path(), "personal\n").unwrap();
        std::fs::write(
            netsky_config::host_addendum_path("personal"),
            "PERSONAL POLICY\n",
        )
        .unwrap();
        std::fs::write(netsky_config::host_addendum_path("work"), "WORK POLICY\n").unwrap();
        // SAFETY: `_env` holds `test_lock`, and its Drop restores MACHINE_TYPE.
        unsafe {
            std::env::set_var("MACHINE_TYPE", "work");
        }

        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        assert!(out.contains("WORK POLICY"));
        assert!(!out.contains("PERSONAL POLICY"));
    }
}