systemprompt-loader 0.1.18

File loading infrastructure for systemprompt.io - separates I/O from shared models
Documentation
use anyhow::{Context, Result};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};

use systemprompt_models::AppPaths;
use systemprompt_models::mcp::Deployment;
use systemprompt_models::services::{
    AgentConfig, AiConfig, PartialServicesConfig, PluginConfig, SchedulerConfig, ServicesConfig,
    Settings as ServicesSettings, WebConfig,
};

use crate::ConfigWriter;

/// Loads a services configuration file together with the files it includes.
#[derive(Debug)]
pub struct EnhancedConfigLoader {
    /// Directory that relative include paths are joined onto.
    base_path: PathBuf,
    /// Path of the root configuration file that is read by `load`.
    config_path: PathBuf,
}

/// Deserialized shape of the root configuration file.
#[derive(serde::Deserialize)]
struct RootConfig {
    /// Paths of additional files to load and merge; empty when the key is absent.
    #[serde(default)]
    includes: Vec<String>,
    /// All remaining top-level keys, flattened into the service sections.
    #[serde(flatten)]
    config: PartialServicesRootConfig,
}

/// Service sections that may appear directly in the root configuration file.
///
/// Every field is marked `#[serde(default)]`, so any section may be omitted.
#[derive(serde::Deserialize, Default)]
struct PartialServicesRootConfig {
    /// Agent definitions keyed by name.
    #[serde(default)]
    pub agents: HashMap<String, AgentConfig>,
    /// MCP server deployments keyed by name.
    #[serde(default)]
    pub mcp_servers: HashMap<String, Deployment>,
    /// Services settings; falls back to the type's default when absent.
    #[serde(default)]
    pub settings: ServicesSettings,
    /// Optional scheduler section.
    #[serde(default)]
    pub scheduler: Option<SchedulerConfig>,
    /// Optional AI section; the loader substitutes `AiConfig::default()` when missing.
    #[serde(default)]
    pub ai: Option<AiConfig>,
    /// Optional web section; the loader substitutes `WebConfig::default()` when missing.
    #[serde(default)]
    pub web: Option<WebConfig>,
    /// Plugin definitions keyed by name.
    #[serde(default)]
    pub plugins: HashMap<String, PluginConfig>,
}

impl EnhancedConfigLoader {
    /// Creates a loader for the configuration file at `config_path`.
    ///
    /// The base path used to resolve includes is the parent of `config_path`;
    /// when `Path::parent` yields `None`, `"."` is used instead.
    pub fn new(config_path: PathBuf) -> Self {
        let base_path = match config_path.parent() {
            Some(parent) => parent.to_path_buf(),
            None => Path::new(".").to_path_buf(),
        };

        Self {
            base_path,
            config_path,
        }
    }

    /// Builds a loader that points at the settings path reported by `AppPaths`.
    ///
    /// # Errors
    ///
    /// Returns an error carrying the message of the failure when
    /// `AppPaths::get` fails.
    pub fn from_env() -> Result<Self> {
        let paths = match AppPaths::get() {
            Ok(paths) => paths,
            Err(e) => return Err(anyhow::anyhow!("{}", e)),
        };

        Ok(Self::new(paths.system().settings().to_path_buf()))
    }

    /// Reads the root configuration file and loads it via `load_from_content`.
    ///
    /// # Errors
    ///
    /// Fails when the file cannot be read, or when `load_from_content` fails.
    pub fn load(&self) -> Result<ServicesConfig> {
        fs::read_to_string(&self.config_path)
            .with_context(|| format!("Failed to read config: {}", self.config_path.display()))
            .and_then(|raw| self.load_from_content(&raw))
    }

    pub fn load_from_content(&self, content: &str) -> Result<ServicesConfig> {
        let root: RootConfig = serde_yaml::from_str(content)
            .with_context(|| format!("Failed to parse config: {}", self.config_path.display()))?;

        let mut merged = ServicesConfig {
            agents: root.config.agents,
            mcp_servers: root.config.mcp_servers,
            settings: root.config.settings,
            scheduler: root.config.scheduler,
            ai: root.config.ai.unwrap_or_else(AiConfig::default),
            web: root.config.web.unwrap_or_else(WebConfig::default),
            plugins: root.config.plugins,
        };

        for include_path in &root.includes {
            let partial = self.load_include(include_path)?;
            Self::merge_partial(&mut merged, partial)?;
        }

        self.discover_and_load_agents(&root.includes, &mut merged)?;

        self.resolve_includes(&mut merged)?;

        merged.settings.apply_env_overrides();

        merged
            .validate()
            .map_err(|e| anyhow::anyhow!("Services config validation failed: {}", e))?;

        Ok(merged)
    }

    /// Loads YAML files found in `<base_path>/../agents` that are not already
    /// covered by `existing_includes`, merges them into `merged`, and records
    /// each one in the config file via `ConfigWriter::add_include`.
    ///
    /// Returns `Ok(())` without doing anything when the agents directory does
    /// not exist. Discovered files are processed in sorted file-name order so
    /// that the merge result and the order of the written includes do not
    /// depend on the platform-specific ordering of `fs::read_dir`.
    ///
    /// # Errors
    ///
    /// Fails when the directory or one of its entries cannot be read, when a
    /// discovered file cannot be loaded or merged, or when the include cannot
    /// be written back to the config file.
    fn discover_and_load_agents(
        &self,
        existing_includes: &[String],
        merged: &mut ServicesConfig,
    ) -> Result<()> {
        let agents_dir = self.base_path.join("../agents");

        if !agents_dir.exists() {
            return Ok(());
        }

        // NOTE(review): existing includes are matched by file name only, so an
        // include with the same file name in a different directory also
        // suppresses discovery here — confirm this is intended.
        let included_files: HashSet<String> = existing_includes
            .iter()
            .filter_map(|inc| Path::new(inc).file_name())
            .map(|f| f.to_string_lossy().into_owned())
            .collect();

        let entries = fs::read_dir(&agents_dir).with_context(|| {
            format!("Failed to read agents directory: {}", agents_dir.display())
        })?;

        let mut discovered: Vec<String> = Vec::new();
        for entry in entries {
            let path = entry
                .with_context(|| format!("Failed to read entry in: {}", agents_dir.display()))?
                .path();

            let is_yaml = path
                .extension()
                .is_some_and(|ext| ext == "yaml" || ext == "yml");

            if !is_yaml {
                continue;
            }

            let file_name = path
                .file_name()
                .map(|f| f.to_string_lossy().into_owned())
                .ok_or_else(|| anyhow::anyhow!("Invalid file path: {}", path.display()))?;

            if !included_files.contains(&file_name) {
                discovered.push(file_name);
            }
        }

        // `read_dir` yields entries in a platform- and filesystem-dependent
        // order; sort so loading and the written includes are reproducible.
        discovered.sort_unstable();

        for file_name in discovered {
            let relative_path = format!("../agents/{}", file_name);
            let partial = self.load_include(&relative_path)?;
            Self::merge_partial(merged, partial)?;

            ConfigWriter::add_include(&relative_path, &self.config_path).with_context(|| {
                format!(
                    "Failed to add discovered agent to includes: {}",
                    relative_path
                )
            })?;
        }

        Ok(())
    }

    /// Reads and parses the include file at `path`, resolved against the base path.
    ///
    /// `path` is joined onto `base_path` with `Path::join`, so an absolute
    /// `path` replaces the base path entirely.
    ///
    /// # Errors
    ///
    /// Fails when the resolved file does not exist, cannot be read, or is not
    /// valid YAML for `PartialServicesConfig`.
    fn load_include(&self, path: &str) -> Result<PartialServicesConfig> {
        let full_path = self.base_path.join(path);

        if !full_path.exists() {
            // Report the actual config file rather than assuming it is named
            // `config.yaml` inside the base directory.
            anyhow::bail!(
                "Include file not found: {}\nReferenced in: {}\nEither create the \
                 file or remove it from the includes list.",
                full_path.display(),
                self.config_path.display()
            );
        }

        let content = fs::read_to_string(&full_path)
            .with_context(|| format!("Failed to read include: {}", full_path.display()))?;

        serde_yaml::from_str(&content)
            .with_context(|| format!("Failed to parse include: {}", full_path.display()))
    }

    /// Merges one included partial configuration into `target`.
    ///
    /// Merge rules per section:
    /// - `agents`, `mcp_servers`, `plugins`: inserted by name; a name that is
    ///   already present is an error.
    /// - `scheduler`: only fills `target.scheduler` when it is still `None`.
    /// - `ai`: replaces `target.ai` wholesale when the target has no providers
    ///   and the partial has some; otherwise only the providers are inserted,
    ///   overwriting providers with the same name.
    /// - `web`: replaces `target.web` whenever the partial defines it.
    ///
    /// # Errors
    ///
    /// Returns an error on the first duplicate agent, MCP server, or plugin
    /// name. Entries merged before the duplicate remain in `target`.
    fn merge_partial(target: &mut ServicesConfig, partial: PartialServicesConfig) -> Result<()> {
        use std::collections::hash_map::Entry;

        for (name, agent) in partial.agents {
            match target.agents.entry(name) {
                Entry::Occupied(taken) => {
                    let name = taken.key();
                    anyhow::bail!("Duplicate agent definition: {name}");
                },
                Entry::Vacant(slot) => {
                    slot.insert(agent);
                },
            }
        }

        for (name, mcp) in partial.mcp_servers {
            match target.mcp_servers.entry(name) {
                Entry::Occupied(taken) => {
                    let name = taken.key();
                    anyhow::bail!("Duplicate MCP server definition: {name}");
                },
                Entry::Vacant(slot) => {
                    slot.insert(mcp);
                },
            }
        }

        // First definition wins: an already-set scheduler is never replaced.
        if target.scheduler.is_none() {
            target.scheduler = partial.scheduler;
        }

        if let Some(ai) = partial.ai {
            let adopt_whole = target.ai.providers.is_empty() && !ai.providers.is_empty();
            if adopt_whole {
                target.ai = ai;
            } else {
                for (name, provider) in ai.providers {
                    target.ai.providers.insert(name, provider);
                }
            }
        }

        // Last definition wins for the web section.
        if let Some(web) = partial.web {
            target.web = web;
        }

        for (name, plugin) in partial.plugins {
            match target.plugins.entry(name) {
                Entry::Occupied(taken) => {
                    let name = taken.key();
                    anyhow::bail!("Duplicate plugin definition: {name}");
                },
                Entry::Vacant(slot) => {
                    slot.insert(plugin);
                },
            }
        }

        Ok(())
    }

    /// Replaces `!include <path>` system prompts with the referenced file's contents.
    ///
    /// For every agent whose `metadata.system_prompt` starts with `"!include "`,
    /// the remainder is trimmed, joined onto the base path, and the file is
    /// read into the prompt. Other prompts are left untouched.
    ///
    /// # Errors
    ///
    /// Fails when a referenced file cannot be read.
    fn resolve_includes(&self, config: &mut ServicesConfig) -> Result<()> {
        for (name, agent) in config.agents.iter_mut() {
            let Some(include_path) = agent
                .metadata
                .system_prompt
                .as_deref()
                .and_then(|prompt| prompt.strip_prefix("!include "))
            else {
                continue;
            };

            let full_path = self.base_path.join(include_path.trim());
            let resolved = fs::read_to_string(&full_path).with_context(|| {
                format!(
                    "Failed to resolve system_prompt include for agent '{name}': {}",
                    full_path.display()
                )
            })?;

            agent.metadata.system_prompt = Some(resolved);
        }

        Ok(())
    }

    /// Checks that the configuration file at `path` loads successfully.
    ///
    /// Runs the full `load` pipeline and discards the resulting configuration.
    ///
    /// # Errors
    ///
    /// Returns whatever error `load` reports.
    pub fn validate_file(path: &Path) -> Result<()> {
        Self::new(path.to_path_buf()).load().map(|_config| ())
    }

    pub fn get_includes(&self) -> Result<Vec<String>> {
        #[derive(serde::Deserialize)]
        struct IncludesOnly {
            #[serde(default)]
            includes: Vec<String>,
        }

        let content = fs::read_to_string(&self.config_path)?;
        let parsed: IncludesOnly = serde_yaml::from_str(&content)?;
        Ok(parsed.includes)
    }

    /// Lists every declared include together with whether its file exists.
    ///
    /// Each include path is joined onto the base path before the existence
    /// check; the returned string is the include exactly as declared.
    ///
    /// # Errors
    ///
    /// Fails when `get_includes` fails.
    pub fn list_all_includes(&self) -> Result<Vec<(String, bool)>> {
        let includes = self.get_includes()?;

        let mut listing = Vec::with_capacity(includes.len());
        for include in includes {
            let exists = self.base_path.join(&include).exists();
            listing.push((include, exists));
        }

        Ok(listing)
    }

    /// Returns the directory that include paths are resolved against.
    pub fn base_path(&self) -> &Path {
        self.base_path.as_path()
    }
}