use std::path::Path;
use crate::agent::AgentId;
use crate::consts::{CWD_ADDENDUM_AGENT0, CWD_ADDENDUM_AGENTINFINITY, CWD_ADDENDUM_CLONE_EXT};
use netsky_config::Config as RuntimeConfig;
// Prompt fragments compiled into the binary via `include_str!` (paths are relative to this file).
/// Template shared by every agent; `render_prompt` substitutes its placeholders.
const BASE_TEMPLATE: &str = include_str!("../prompts/base.md");
/// Stanza placed after the base template for `AgentId::Agent0` (selected by `stanza_for`).
const AGENT0_STANZA: &str = include_str!("../prompts/agent0.md");
/// Stanza placed after the base template for every `AgentId::Clone(_)`.
const CLONE_STANZA: &str = include_str!("../prompts/clone.md");
/// Stanza placed after the base template for `AgentId::Agentinfinity`.
const AGENTINFINITY_STANZA: &str = include_str!("../prompts/agentinfinity.md");
/// Markdown horizontal rule (surrounded by blank lines) inserted between prompt layers.
const SEPARATOR: &str = "\n\n---\n\n";
/// Per-agent inputs used to fill the prompt templates.
#[derive(Debug, Clone)]
pub struct PromptContext {
    /// Agent whose name/number feed the `agent_name` and `n` placeholders and whose
    /// stanza and addendum `render_prompt` selects.
    pub agent: AgentId,
    /// Value substituted for the `cwd` placeholder. `render_prompt` takes a separate
    /// `cwd: &Path` for locating addendum files; this string is only substituted text.
    pub cwd: String,
}
impl PromptContext {
pub fn new(agent: AgentId, cwd: impl Into<String>) -> Self {
Self {
agent,
cwd: cwd.into(),
}
}
fn bindings(&self) -> Vec<(&'static str, String)> {
vec![
("agent_name", self.agent.name()),
("n", self.agent.env_n()),
("cwd", self.cwd.clone()),
]
}
}
/// Failures that can occur while rendering an agent prompt.
#[derive(Debug)]
pub enum PromptError {
    /// Reading the cwd addendum failed for a reason other than NotFound / NotADirectory.
    Io(std::io::Error),
    /// An `anyhow` error; per its Display text, intended for runtime-config loading
    /// failures (`RuntimeConfig::load` in `read_runtime_addenda`).
    Config(anyhow::Error),
    /// `{{` sequences remained after placeholder substitution.
    UnsubstitutedPlaceholders {
        // Number of `{{` occurrences found.
        count: usize,
        // Up to three short snippets starting at each `{{`, joined with " | ".
        preview: String,
    },
}
impl std::fmt::Display for PromptError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Io(e) => write!(f, "io error reading addendum: {e}"),
Self::Config(e) => write!(f, "runtime config error reading addendum: {e}"),
Self::UnsubstitutedPlaceholders { count, preview } => write!(
f,
"template render left {count} unsubstituted placeholder(s): {preview}"
),
}
}
}
impl std::error::Error for PromptError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(e) => Some(e),
Self::Config(e) => Some(e.as_ref()),
_ => None,
}
}
}
impl From<std::io::Error> for PromptError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
impl From<anyhow::Error> for PromptError {
fn from(e: anyhow::Error) -> Self {
Self::Config(e)
}
}
/// Returns the embedded role stanza for `agent`; all clones share one stanza.
fn stanza_for(agent: AgentId) -> &'static str {
    match agent {
        AgentId::Clone(_) => CLONE_STANZA,
        AgentId::Agentinfinity => AGENTINFINITY_STANZA,
        AgentId::Agent0 => AGENT0_STANZA,
    }
}
/// Default addendum filename (looked up in the cwd) for `agent`.
///
/// Clones get a per-clone name built from their number plus `CWD_ADDENDUM_CLONE_EXT`.
fn cwd_addendum_filename(agent: AgentId) -> String {
    match agent {
        AgentId::Clone(idx) => format!("{idx}{CWD_ADDENDUM_CLONE_EXT}"),
        AgentId::Agentinfinity => CWD_ADDENDUM_AGENTINFINITY.to_string(),
        AgentId::Agent0 => CWD_ADDENDUM_AGENT0.to_string(),
    }
}
/// Resolves which file holds the cwd addendum for `agent`.
///
/// If `cwd/netsky.toml` loads and its `[addendum]` section sets a path for this agent
/// (`agent0`, `agentinfinity`, or `clone_default` for every clone), that path wins:
/// - absolute paths are used as-is;
/// - `~/`-prefixed paths are resolved under the home directory (if no home directory
///   is known, the literal path is joined onto `cwd`);
/// - anything else is relative to `cwd`.
///
/// Otherwise the per-agent default filename inside `cwd` is returned. A missing or
/// unloadable `netsky.toml` is deliberately treated as "no override" (its error is
/// discarded), so a broken config never blocks prompt rendering.
fn resolve_addendum_path(agent: AgentId, cwd: &Path) -> std::path::PathBuf {
    use crate::config::Config;
    let configured = Config::load_from(&cwd.join("netsky.toml"))
        .ok()
        .flatten()
        .and_then(|cfg| cfg.addendum)
        .and_then(|a| match agent {
            AgentId::Agent0 => a.agent0,
            AgentId::Agentinfinity => a.agentinfinity,
            AgentId::Clone(_) => a.clone_default,
        });
    match configured {
        None => cwd.join(cwd_addendum_filename(agent)),
        // `is_absolute` matches `starts_with('/')` on Unix and is also correct on
        // platforms whose absolute paths carry a prefix.
        Some(p) if Path::new(&p).is_absolute() => std::path::PathBuf::from(p),
        Some(p) => {
            // `strip_prefix` removes exactly one leading `~/`; the previous
            // `trim_start_matches` also ate repeats such as `~/~/notes.md`.
            let under_home = p
                .strip_prefix("~/")
                .and_then(|rest| dirs::home_dir().map(|home| home.join(rest)));
            under_home.unwrap_or_else(|| cwd.join(&p))
        }
    }
}
/// Reads the addendum file for `agent` (path from `resolve_addendum_path`).
///
/// Returns `Ok(None)` when the file is absent (`NotFound`) or a path component is
/// not a directory (`NotADirectory`). Any other I/O error is returned. An existing
/// but empty file yields `Ok(Some(""))`.
fn read_cwd_addendum(agent: AgentId, cwd: &Path) -> Result<Option<String>, std::io::Error> {
    use std::io::ErrorKind;
    let addendum_path = resolve_addendum_path(agent, cwd);
    std::fs::read_to_string(&addendum_path)
        .map(Some)
        .or_else(|err| {
            if matches!(err.kind(), ErrorKind::NotFound | ErrorKind::NotADirectory) {
                Ok(None)
            } else {
                Err(err)
            }
        })
}
/// Collects the runtime addendum layers from `netsky_config`, in order: base, then host.
///
/// Each layer is whitespace-trimmed. Layers that are unset or blank after trimming
/// are skipped.
///
/// # Errors
/// Propagates the failure from `RuntimeConfig::load` through `PromptError`'s `From` impls.
fn read_runtime_addenda() -> Result<Vec<String>, PromptError> {
    let cfg = RuntimeConfig::load()?;
    let layers: Vec<String> = [cfg.addendum.base.as_deref(), cfg.addendum.host.as_deref()]
        .into_iter()
        .flatten()
        .map(str::trim)
        .filter(|layer| !layer.is_empty())
        .map(String::from)
        .collect();
    Ok(layers)
}
/// Replaces each binding's placeholder in `body` with its value.
///
/// Four spellings are recognised per name: `{{ name }}`, `{{name}}`, `{{ name}}` and
/// `{{name }}`. Bindings are applied sequentially in slice order. Text inserted by an
/// earlier binding is therefore visible to later ones.
fn apply_bindings(body: &str, bindings: &[(&'static str, String)]) -> String {
    bindings.iter().fold(body.to_owned(), |acc, (key, val)| {
        let spellings = [
            format!("{{{{ {key} }}}}"),
            format!("{{{{{key}}}}}"),
            format!("{{{{ {key}}}}}"),
            format!("{{{{{key} }}}}"),
        ];
        spellings
            .iter()
            .fold(acc, |text, spelling| text.replace(spelling.as_str(), val))
    })
}
/// Fails if `body` still contains any `{{` after placeholder substitution.
///
/// # Errors
/// Returns [`PromptError::UnsubstitutedPlaceholders`] with:
/// - the number of `{{` occurrences;
/// - a ` | `-joined preview of at most the first three of them, each snippet starting
///   at its `{{` and truncated to at most 32 bytes on a UTF-8 character boundary.
fn assert_fully_rendered(body: &str) -> Result<(), PromptError> {
    /// Maximum byte length of one preview snippet.
    const SNIPPET_BYTES: usize = 32;
    /// Maximum number of snippets included in the preview.
    const MAX_SNIPPETS: usize = 3;

    let count = body.matches("{{").count();
    if count == 0 {
        return Ok(());
    }
    let preview = body
        .match_indices("{{")
        .take(MAX_SNIPPETS)
        .map(|(start, _)| {
            // Slicing a `str` at a non-char boundary panics, so back off when the
            // byte cut lands inside a multi-byte character. `start + 2` (just past
            // the ASCII `{{`) is always a boundary, so the loop terminates.
            let mut end = body.len().min(start + SNIPPET_BYTES);
            while !body.is_char_boundary(end) {
                end -= 1;
            }
            &body[start..end]
        })
        .collect::<Vec<_>>()
        .join(" | ");
    Err(PromptError::UnsubstitutedPlaceholders { count, preview })
}
/// Renders the complete system prompt for `ctx.agent`.
///
/// Layers are joined with [`SEPARATOR`] in this order:
/// 1. the shared base template;
/// 2. the agent's stanza;
/// 3. the cwd addendum, if present and non-blank;
/// 4. the runtime addenda from `netsky_config` (base, then host).
///
/// The output ends with a newline.
///
/// Only layers 1 and 2 are templates. They get placeholder substitution and are
/// checked for leftover `{{`. Addenda are user-authored text appended verbatim, so a
/// literal `{{` inside them (e.g. a `${{ secrets.X }}` snippet) no longer aborts the
/// render.
///
/// `cwd` is where the addendum file and `netsky.toml` are looked up. `ctx.cwd` is
/// only the value substituted for the `cwd` placeholder.
///
/// # Errors
/// - [`PromptError::UnsubstitutedPlaceholders`] if a template placeholder had no binding.
/// - [`PromptError::Io`] if the cwd addendum exists but cannot be read.
/// - Whatever `read_runtime_addenda` returns if the runtime config fails to load.
pub fn render_prompt(ctx: PromptContext, cwd: &Path) -> Result<String, PromptError> {
    let agent = ctx.agent;
    let bindings = ctx.bindings();
    let base = apply_bindings(BASE_TEMPLATE, &bindings);
    let stanza = apply_bindings(stanza_for(agent), &bindings);
    let mut out = String::with_capacity(base.len() + stanza.len() + 128);
    out.push_str(base.trim_end());
    out.push_str(SEPARATOR);
    out.push_str(stanza.trim_end());
    // Validate the templated layers only: the addenda below never go through
    // `apply_bindings`, so any `{{` in them is content, not a missed placeholder.
    assert_fully_rendered(&out)?;
    if let Some(addendum) = read_cwd_addendum(agent, cwd)? {
        let trimmed = addendum.trim();
        if !trimmed.is_empty() {
            out.push_str(SEPARATOR);
            out.push_str(trimmed);
        }
    }
    for addendum in read_runtime_addenda()? {
        out.push_str(SEPARATOR);
        out.push_str(&addendum);
    }
    out.push('\n');
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::{Mutex, MutexGuard, OnceLock};
    use tempfile::TempDir;

    /// Per-test sandbox for the process environment that `render_prompt` depends on.
    ///
    /// For its whole lifetime it:
    /// - holds the module-wide [`test_lock`];
    /// - points `XDG_CONFIG_HOME` at a fresh temp dir;
    /// - clears `MACHINE_TYPE`.
    ///
    /// On drop it restores both variables. Every test that calls `render_prompt` must
    /// create one. Runtime addenda come from `netsky_config`, which (as the tests below
    /// show) honours `XDG_CONFIG_HOME` and `MACHINE_TYPE`. An unguarded test would read
    /// the developer's real config and race with other tests' `set_var` calls.
    struct PromptTestEnv {
        _tmp: TempDir,
        _guard: MutexGuard<'static, ()>,
        prior_xdg: Option<String>,
        prior_machine_type: Option<String>,
    }
    impl PromptTestEnv {
        fn new() -> Self {
            // A panicking test poisons the lock; the guarded data is `()`, so the
            // poison carries no broken state and is ignored.
            let guard = test_lock().lock().unwrap_or_else(|err| err.into_inner());
            let tmp = TempDir::new().unwrap();
            let prior_xdg = std::env::var("XDG_CONFIG_HOME").ok();
            let prior_machine_type = std::env::var("MACHINE_TYPE").ok();
            // SAFETY: mutating the environment is only sound while no other thread
            // reads or writes it. `guard` (held until after `Drop::drop`) serializes
            // every env-dependent test in this module.
            // NOTE(review): tests in other modules of the crate are not covered by this lock.
            unsafe {
                std::env::set_var("XDG_CONFIG_HOME", tmp.path());
                std::env::remove_var("MACHINE_TYPE");
            }
            // Presumably `config_dir()` now resolves inside `tmp` — confirm in netsky_config.
            std::fs::create_dir_all(netsky_config::config_dir()).unwrap();
            Self {
                _tmp: tmp,
                _guard: guard,
                prior_xdg,
                prior_machine_type,
            }
        }
    }
    /// Process-wide lock serializing tests that touch the environment.
    fn test_lock() -> &'static Mutex<()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
    }
    impl Drop for PromptTestEnv {
        fn drop(&mut self) {
            // SAFETY: `Drop::drop` runs before the fields are dropped, so `_guard`
            // still holds `test_lock` while the environment is restored.
            unsafe {
                match &self.prior_xdg {
                    Some(value) => std::env::set_var("XDG_CONFIG_HOME", value),
                    None => std::env::remove_var("XDG_CONFIG_HOME"),
                }
                match &self.prior_machine_type {
                    Some(value) => std::env::set_var("MACHINE_TYPE", value),
                    None => std::env::remove_var("MACHINE_TYPE"),
                }
            }
        }
    }
    /// Context with a fixed, fake cwd string for substitution.
    fn ctx_for(agent: AgentId) -> PromptContext {
        PromptContext::new(agent, "/tmp/netsky-test")
    }
    #[test]
    fn renders_all_agents_without_addendum() {
        let _env = PromptTestEnv::new();
        let nowhere = PathBuf::from("/dev/null/does-not-exist");
        for agent in [
            AgentId::Agent0,
            AgentId::Clone(1),
            AgentId::Clone(8),
            AgentId::Agentinfinity,
        ] {
            let out = render_prompt(ctx_for(agent), &nowhere).unwrap();
            assert!(!out.is_empty(), "empty prompt for {agent}");
            assert!(out.contains("---"), "missing separator for {agent}");
            assert!(!out.contains("{{"), "unsubstituted placeholder for {agent}");
        }
    }
    #[test]
    fn clone_prompt_substitutes_n() {
        // Needed: render_prompt loads the runtime config from the environment.
        let _env = PromptTestEnv::new();
        let nowhere = PathBuf::from("/dev/null/does-not-exist");
        let out = render_prompt(ctx_for(AgentId::Clone(5)), &nowhere).unwrap();
        assert!(out.contains("agent5"));
        assert!(!out.contains("{{ n }}"));
    }
    #[test]
    fn cwd_addendum_is_appended() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("0.md"), "USER POLICY HERE").unwrap();
        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        assert!(out.contains("USER POLICY HERE"));
    }
    #[test]
    fn render_rejects_unsubstituted_placeholder() {
        let body = "hello {{ unknown_var }} world";
        let err = assert_fully_rendered(body).unwrap_err();
        match err {
            PromptError::UnsubstitutedPlaceholders { count, .. } => assert_eq!(count, 1),
            _ => panic!("wrong error variant"),
        }
    }
    #[test]
    fn bindings_stringify_uniformly() {
        let b0 = PromptContext::new(AgentId::Agent0, "/").bindings();
        let b5 = PromptContext::new(AgentId::Clone(5), "/").bindings();
        let binf = PromptContext::new(AgentId::Agentinfinity, "/").bindings();
        assert_eq!(lookup(&b0, "n"), "0");
        assert_eq!(lookup(&b5, "n"), "5");
        assert_eq!(lookup(&binf, "n"), "infinity");
    }
    /// Value bound to `key`; panics if the binding is missing.
    fn lookup(bindings: &[(&'static str, String)], key: &str) -> String {
        bindings.iter().find(|(k, _)| *k == key).unwrap().1.clone()
    }
    #[test]
    fn netsky_toml_addendum_overrides_default_path() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("0.md"), "OLD POLICY").unwrap();
        std::fs::create_dir_all(tmp.path().join("addenda")).unwrap();
        std::fs::write(tmp.path().join("addenda/0-personal.md"), "NEW POLICY").unwrap();
        std::fs::write(
            tmp.path().join("netsky.toml"),
            "schema_version = 1\n[addendum]\nagent0 = \"addenda/0-personal.md\"\n",
        )
        .unwrap();
        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        assert!(
            out.contains("NEW POLICY"),
            "TOML override should pick up addenda/0-personal.md"
        );
        assert!(
            !out.contains("OLD POLICY"),
            "TOML override should bypass the legacy 0.md fallback"
        );
    }
    #[test]
    fn missing_netsky_toml_falls_back_to_legacy_addendum() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("0.md"), "LEGACY ADDENDUM").unwrap();
        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        assert!(out.contains("LEGACY ADDENDUM"));
    }
    #[test]
    fn netsky_toml_without_addendum_section_falls_back() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("0.md"), "FALLBACK POLICY").unwrap();
        std::fs::write(
            tmp.path().join("netsky.toml"),
            "schema_version = 1\n[owner]\nname = \"Alice\"\n",
        )
        .unwrap();
        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        assert!(
            out.contains("FALLBACK POLICY"),
            "no [addendum] section should fall back to default filename"
        );
    }
    #[test]
    fn netsky_toml_addendum_absolute_path_used_as_is() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        let abs_addendum = tmp.path().join("absolute-addendum.md");
        std::fs::write(&abs_addendum, "ABSOLUTE POLICY").unwrap();
        std::fs::write(
            tmp.path().join("netsky.toml"),
            format!(
                "schema_version = 1\n[addendum]\nagent0 = \"{}\"\n",
                abs_addendum.display()
            ),
        )
        .unwrap();
        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        assert!(out.contains("ABSOLUTE POLICY"));
    }
    #[test]
    fn runtime_addendum_layers_append_after_cwd_addendum() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("0.md"), "CWD POLICY").unwrap();
        std::fs::write(netsky_config::owner_path(), "github_username = \"cody\"\n").unwrap();
        std::fs::write(netsky_config::addendum_path(), "BASE POLICY\n").unwrap();
        std::fs::write(netsky_config::active_host_path(), "work\n").unwrap();
        std::fs::write(netsky_config::host_addendum_path("work"), "WORK POLICY\n").unwrap();
        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        let cwd = out.find("CWD POLICY").unwrap();
        let base = out.find("BASE POLICY").unwrap();
        let host = out.find("WORK POLICY").unwrap();
        assert!(cwd < base);
        assert!(base < host);
    }
    #[test]
    fn machine_type_env_overrides_active_host_cache() {
        let _env = PromptTestEnv::new();
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(netsky_config::owner_path(), "github_username = \"cody\"\n").unwrap();
        std::fs::write(netsky_config::active_host_path(), "personal\n").unwrap();
        std::fs::write(
            netsky_config::host_addendum_path("personal"),
            "PERSONAL POLICY\n",
        )
        .unwrap();
        std::fs::write(netsky_config::host_addendum_path("work"), "WORK POLICY\n").unwrap();
        // SAFETY: `_env` holds `test_lock`, serializing env access; its Drop
        // restores the prior MACHINE_TYPE.
        unsafe {
            std::env::set_var("MACHINE_TYPE", "work");
        }
        let out = render_prompt(ctx_for(AgentId::Agent0), tmp.path()).unwrap();
        assert!(out.contains("WORK POLICY"));
        assert!(!out.contains("PERSONAL POLICY"));
    }
}