use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
#[derive(Debug, Clone, thiserror::Error)]
pub enum ConfigValidationError {
#[error("Unknown key '{key}' in section [{section}]. Did you mean '{suggestion:?}'?")]
UnknownKey {
section: String,
key: String,
suggestion: Option<String>,
},
#[error("Invalid type for [{section}.{key}]: expected {expected}, got {actual}")]
InvalidType {
section: String,
key: String,
expected: String,
actual: String,
},
#[error("Config validation failed:\n{}", .0.iter().map(|e| format!(" - {}", e)).collect::<Vec<_>>().join("\n"))]
Multiple(Vec<ConfigValidationError>),
}
/// Plugin settings read from the `[plugins]` table of the config file.
///
/// NOTE(review): `#[derive(Default)]` yields `auto_discover: false` and an empty
/// `directories` list, whereas deserializing a `[plugins]` table that omits those keys
/// yields `true` and `default_plugin_dirs()`. The effective values therefore depend on
/// whether the table is present at all. The `test_praison_config_default` test currently
/// pins the derived values — confirm this asymmetry is intended.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PluginsConfig {
    /// Either a boolean switch or a list of plugin names; disabled when omitted.
    #[serde(default)]
    pub enabled: PluginsEnabled,
    /// Auto-discovery flag; `true` when omitted from the table (presumably controls
    /// scanning of `directories` — confirm against the plugin loader).
    #[serde(default = "default_true")]
    pub auto_discover: bool,
    /// Plugin search directories; `default_plugin_dirs()` when omitted from the table.
    #[serde(default = "default_plugin_dirs")]
    pub directories: Vec<String>,
}
/// Serde default helper for boolean fields that are on unless explicitly disabled.
fn default_true() -> bool {
    true
}

/// Serde default for `PluginsConfig::directories`: the project-local plugin folder
/// first, then the one under the user's home. The `~` is stored literally here; any
/// expansion has to happen wherever these paths are consumed.
fn default_plugin_dirs() -> Vec<String> {
    ["./.praisonai/plugins/", "~/.praisonai/plugins/"]
        .iter()
        .copied()
        .map(String::from)
        .collect()
}
/// The `enabled` setting of the `[plugins]` table.
///
/// Accepts either a boolean (`enabled = true`) or a list of plugin names
/// (`enabled = ["a", "b"]`). Because of `#[serde(untagged)]`, serde tries the variants in
/// declaration order, so keep `Bool` before `List`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PluginsEnabled {
    /// Plugins switched on or off as a whole.
    Bool(bool),
    /// Only the named plugins; an empty list counts as disabled (see `is_enabled`).
    List(Vec<String>),
}
impl Default for PluginsEnabled {
fn default() -> Self {
Self::Bool(false)
}
}
impl PluginsEnabled {
    /// Returns `true` when plugins are on: a `true` flag, or a non-empty name list.
    pub fn is_enabled(&self) -> bool {
        match self {
            Self::List(names) => !names.is_empty(),
            Self::Bool(flag) => *flag,
        }
    }

    /// Returns the explicit plugin names, or `None` when configured with a boolean.
    pub fn get_list(&self) -> Option<&[String]> {
        if let Self::List(names) = self {
            Some(names.as_slice())
        } else {
            None
        }
    }
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DefaultsConfig {
pub model: Option<String>,
pub base_url: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub allow_delegation: bool,
#[serde(default)]
pub allow_code_execution: bool,
#[serde(default = "default_code_execution_mode")]
pub code_execution_mode: String,
pub memory: Option<serde_json::Value>,
pub knowledge: Option<serde_json::Value>,
pub planning: Option<serde_json::Value>,
pub reflection: Option<serde_json::Value>,
pub guardrails: Option<serde_json::Value>,
pub web: Option<serde_json::Value>,
pub output: Option<serde_json::Value>,
pub execution: Option<serde_json::Value>,
pub caching: Option<serde_json::Value>,
pub autonomy: Option<serde_json::Value>,
pub skills: Option<serde_json::Value>,
pub context: Option<serde_json::Value>,
pub hooks: Option<serde_json::Value>,
pub templates: Option<serde_json::Value>,
}
/// Serde default for `DefaultsConfig::code_execution_mode`.
fn default_code_execution_mode() -> String {
    String::from("safe")
}
/// Manager-related settings (presumably for a hierarchical manager agent — confirm
/// against callers).
///
/// NOTE(review): this is not a field of `PraisonConfig` in this file, so it is not
/// populated by `load_config_from_file`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ManagerConfig {
    /// Optional LLM identifier (presumably a model name).
    pub llm: Option<String>,
    /// Optional iteration cap.
    pub max_iter: Option<usize>,
    /// Verbose flag; `false` when omitted.
    #[serde(default)]
    pub verbose: bool,
}
/// Session settings.
///
/// NOTE(review): this is not a field of `PraisonConfig` in this file, so it is not
/// populated by `load_config_from_file` — confirm where it is constructed.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionConfig {
    /// Optional session identifier.
    pub session_id: Option<String>,
    /// Optional user identifier.
    pub user_id: Option<String>,
    /// Persistence flag; `false` when omitted.
    #[serde(default)]
    pub persist: bool,
    /// Optional storage location (presumably only relevant when `persist` is set).
    pub storage_path: Option<String>,
}
/// Automatic retrieval-augmented-generation settings.
///
/// NOTE(review): this is not a field of `PraisonConfig` in this file, so it is not
/// populated by `load_config_from_file`. No relationship between `chunk_size` and
/// `chunk_overlap` is enforced here.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AutoRagConfig {
    /// Enabled flag; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Optional chunk size (units not specified here).
    pub chunk_size: Option<usize>,
    /// Optional overlap between chunks (units not specified here).
    pub chunk_overlap: Option<usize>,
    /// Optional embedding model identifier.
    pub embedding_model: Option<String>,
    /// Optional vector store identifier.
    pub vector_store: Option<String>,
}
/// Root of the configuration file: the `[plugins]` and `[defaults]` tables.
///
/// Both tables are optional; a missing table deserializes to that section type's
/// `Default` value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PraisonConfig {
    /// The `[plugins]` table.
    #[serde(default)]
    pub plugins: PluginsConfig,
    /// The `[defaults]` table.
    #[serde(default)]
    pub defaults: DefaultsConfig,
}
impl PraisonConfig {
    /// Serializes the whole configuration into a JSON value.
    ///
    /// Yields `serde_json::Value::Null` if serialization fails. The result includes
    /// `defaults.api_key` when one is set.
    pub fn to_dict(&self) -> serde_json::Value {
        match serde_json::to_value(self) {
            Ok(value) => value,
            Err(_) => serde_json::Value::Null,
        }
    }
}
/// Process-wide configuration cache, filled lazily by `get_config()`.
///
/// A `OnceLock` is initialised at most once, so the file is read once per process and
/// later changes on disk are not picked up.
static CONFIG_CACHE: OnceLock<PraisonConfig> = OnceLock::new();
/// Locates the configuration file to load, if any.
///
/// Candidates are checked in priority order, and the first one that is a regular file
/// wins:
/// 1. `<cwd>/.praisonai/config.toml`
/// 2. `<cwd>/praisonai.toml`
/// 3. `<home>/.praisonai/config.toml`
///
/// If the current directory cannot be determined (e.g. it was deleted), only the
/// project-local candidates are skipped; the global file is still considered.
fn find_config_file() -> Option<PathBuf> {
    if let Ok(cwd) = std::env::current_dir() {
        let local_candidates = [
            cwd.join(".praisonai").join("config.toml"),
            cwd.join("praisonai.toml"),
        ];
        // `is_file` rather than `exists`: a directory that happens to carry a config-like
        // name must not shadow a real config file further down the list.
        if let Some(found) = local_candidates.iter().find(|path| path.is_file()) {
            return Some(found.clone());
        }
    }
    dirs::home_dir()
        .map(|home| home.join(".praisonai").join("config.toml"))
        .filter(|global| global.is_file())
}
/// Discovers, reads and parses the config file, falling back to defaults.
///
/// This is best-effort and never fails:
/// - No file found gives `PraisonConfig::default()`.
/// - A file that cannot be read or parsed is reported with `tracing::warn!`, and the
///   defaults are returned.
fn load_config_from_file() -> PraisonConfig {
    let Some(config_path) = find_config_file() else {
        return PraisonConfig::default();
    };
    let content = match std::fs::read_to_string(&config_path) {
        Ok(content) => content,
        Err(e) => {
            // Previously swallowed silently; a found-but-unreadable file (permissions,
            // invalid UTF-8, …) deserves the same visibility as a parse error.
            tracing::warn!("Failed to read config from {:?}: {}", config_path, e);
            return PraisonConfig::default();
        }
    };
    toml::from_str(&content).unwrap_or_else(|e| {
        tracing::warn!("Failed to parse config from {:?}: {}", config_path, e);
        PraisonConfig::default()
    })
}
pub fn get_config() -> &'static PraisonConfig {
CONFIG_CACHE.get_or_init(load_config_from_file)
}
pub fn get_config_path() -> Option<PathBuf> {
find_config_file()
}
pub fn get_plugins_config() -> &'static PluginsConfig {
&get_config().plugins
}
pub fn get_defaults_config() -> &'static DefaultsConfig {
&get_config().defaults
}
/// Looks up a value in the `[defaults]` config section by dotted path.
///
/// `key` is split on `.`, and each segment indexes one level deeper into the serialized
/// `DefaultsConfig` (e.g. `"model"` or `"memory.<field>"`). `fallback` is returned when
/// any of these holds:
/// - the path does not exist;
/// - the value at the path is unset (`null`, e.g. an `Option` field that is `None`);
/// - the value cannot be deserialized into `T`.
///
/// Treating `null` as "unset" matches how a `null` intermediate segment already
/// behaves. It means `get_default::<Option<String>>("model", Some(x))` returns `Some(x)`
/// instead of `None` when no model is configured.
pub fn get_default<T: serde::de::DeserializeOwned>(key: &str, fallback: T) -> T {
    let defaults = serde_json::to_value(get_defaults_config()).unwrap_or_default();
    let found = key
        .split('.')
        .try_fold(&defaults, |node, segment| node.get(segment));
    match found {
        Some(value) if !value.is_null() => {
            serde_json::from_value(value.clone()).unwrap_or(fallback)
        }
        _ => fallback,
    }
}
/// Interpretation of a non-blank `PRAISONAI_PLUGINS` environment variable.
enum PluginsEnvOverride {
    /// A recognised boolean spelling (`true`/`1`/`yes`/`on` or `false`/`0`/`no`/`off`).
    Flag(bool),
    /// Any other value, split on commas into a non-empty list of plugin names.
    Names(Vec<String>),
}

/// Reads `PRAISONAI_PLUGINS` and classifies it.
///
/// Returns `None`, so callers fall back to the config file, when the variable is:
/// - unset or not valid Unicode;
/// - blank;
/// - lacking any plugin name after splitting (e.g. `","`).
///
/// Surrounding whitespace is ignored, so `" TRUE "` is the boolean `true` rather than a
/// plugin named `"TRUE"`. Plugin names keep their original case.
fn plugins_env_override() -> Option<PluginsEnvOverride> {
    let raw = std::env::var("PRAISONAI_PLUGINS").ok()?;
    let value = raw.trim();
    match value.to_lowercase().as_str() {
        "" => None,
        "true" | "1" | "yes" | "on" => Some(PluginsEnvOverride::Flag(true)),
        "false" | "0" | "no" | "off" => Some(PluginsEnvOverride::Flag(false)),
        _ => {
            let names: Vec<String> = value
                .split(',')
                .map(str::trim)
                .filter(|name| !name.is_empty())
                .map(str::to_string)
                .collect();
            (!names.is_empty()).then_some(PluginsEnvOverride::Names(names))
        }
    }
}

/// Reports whether plugins are enabled.
///
/// A non-blank `PRAISONAI_PLUGINS` environment variable takes precedence over the config
/// file. A boolean spelling decides directly, and a list of plugin names means "enabled".
/// Otherwise the `[plugins].enabled` setting decides.
pub fn is_plugins_enabled() -> bool {
    match plugins_env_override() {
        Some(PluginsEnvOverride::Flag(enabled)) => enabled,
        Some(PluginsEnvOverride::Names(_)) => true,
        None => get_plugins_config().enabled.is_enabled(),
    }
}

/// Returns the explicit list of plugins to enable, if any.
///
/// A comma-separated `PRAISONAI_PLUGINS` value wins. Otherwise the list comes from
/// `[plugins].enabled`, and is `None` when that setting is a boolean.
pub fn get_enabled_plugins() -> Option<Vec<String>> {
    match plugins_env_override() {
        Some(PluginsEnvOverride::Names(names)) => Some(names),
        // A boolean override says nothing about *which* plugins to load, so the list
        // still comes from the config file.
        Some(PluginsEnvOverride::Flag(_)) | None => {
            get_plugins_config().enabled.get_list().map(<[String]>::to_vec)
        }
    }
}
/// Resolves an optional feature parameter, falling back to the `[defaults]` config.
///
/// 1. An explicit `Some(..)` always wins and is returned unchanged.
/// 2. Otherwise the value at `param_name` (a dotted path, as accepted by `get_default`)
///    is looked up. A table containing `enabled = false` means the feature is switched
///    off, and `None` is returned.
/// 3. Any other value is deserialized into `T`. A missing value or a deserialization
///    failure yields `None`.
///
/// The former `T: Default` bound was never used and has been dropped, so types without a
/// `Default` impl can be resolved too. Existing callers are unaffected.
pub fn apply_config_defaults<T: serde::de::DeserializeOwned>(
    param_name: &str,
    explicit_value: Option<T>,
) -> Option<T> {
    explicit_value.or_else(|| {
        let config_value = get_default::<Option<serde_json::Value>>(param_name, None)?;
        let disabled = config_value
            .get("enabled")
            .and_then(serde_json::Value::as_bool)
            == Some(false);
        if disabled {
            None
        } else {
            serde_json::from_value(config_value).ok()
        }
    })
}
// Allowed keys per section, checked by `validate_config`. They mirror the fields of
// `PraisonConfig`, `PluginsConfig` and `DefaultsConfig` and must be kept in sync with
// those structs by hand. Order matters: `suggest_similar_key` prefers earlier entries
// when several keys match equally well.

/// Top-level tables of the config file.
const VALID_ROOT_KEYS: &[&str] = &["plugins", "defaults"];
/// Keys of the `[plugins]` table.
const VALID_PLUGINS_KEYS: &[&str] = &["enabled", "auto_discover", "directories"];
/// Keys of the `[defaults]` table.
const VALID_DEFAULTS_KEYS: &[&str] = &[
    "model",
    "base_url",
    "api_key",
    "allow_delegation",
    "allow_code_execution",
    "code_execution_mode",
    "memory",
    "knowledge",
    "planning",
    "reflection",
    "guardrails",
    "web",
    "output",
    "execution",
    "caching",
    "autonomy",
    "skills",
    "context",
    "hooks",
    "templates",
];
/// Suggests the valid key that `key` was most likely meant to be.
///
/// Matching is case-insensitive and tried in tiers; the first tier with a hit wins:
/// 1. exact match;
/// 2. prefix match in either direction (`"mod"` → `"model"`);
/// 3. substring match in either direction;
/// 4. smallest Levenshtein distance, if within `max(1, len / 3)` edits
///    (`"enabeld"` → `"enabled"`).
///
/// Within a tier, the earliest entry of `valid_keys` wins. Previously the tiers were
/// interleaved per key, so an earlier substring hit could beat a later exact match. An
/// empty `key` yields `None` rather than an arbitrary first entry.
fn suggest_similar_key(key: &str, valid_keys: &[&str]) -> Option<String> {
    let needle = key.to_lowercase();
    if needle.is_empty() {
        return None;
    }
    let candidates: Vec<(&str, String)> = valid_keys
        .iter()
        .map(|&valid| (valid, valid.to_lowercase()))
        .collect();

    let exact = candidates.iter().find(|(_, lower)| *lower == needle);
    let prefix = || {
        candidates.iter().find(|(_, lower)| {
            lower.starts_with(needle.as_str()) || needle.starts_with(lower.as_str())
        })
    };
    let substring = || {
        candidates.iter().find(|(_, lower)| {
            lower.contains(needle.as_str()) || needle.contains(lower.as_str())
        })
    };

    exact
        .or_else(prefix)
        .or_else(substring)
        .map(|(valid, _)| (*valid).to_string())
        .or_else(|| {
            let budget = (needle.chars().count() / 3).max(1);
            candidates
                .iter()
                .map(|(valid, lower)| (key_edit_distance(&needle, lower), *valid))
                .filter(|&(distance, _)| distance <= budget)
                // `min_by_key` keeps the first minimum, preserving list order on ties.
                .min_by_key(|&(distance, _)| distance)
                .map(|(_, valid)| valid.to_string())
        })
}

/// Levenshtein distance between `a` and `b`, counted in Unicode scalar values.
fn key_edit_distance(a: &str, b: &str) -> usize {
    let b: Vec<char> = b.chars().collect();
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    let mut curr = vec![0usize; b.len() + 1];
    for (i, ca) in a.chars().enumerate() {
        curr[0] = i + 1;
        for (j, &cb) in b.iter().enumerate() {
            let substitution = prev[j] + usize::from(ca != cb);
            curr[j + 1] = substitution.min(prev[j + 1] + 1).min(curr[j] + 1);
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    prev[b.len()]
}
/// Validates a configuration document (as JSON) against the known sections and keys.
///
/// Errors are collected in this order: unknown root keys first, then the `plugins`
/// section, then the `defaults` section. The checks are:
/// - `ConfigValidationError::UnknownKey` for every unrecognised key, with a suggestion
///   from `suggest_similar_key` when one is found;
/// - `ConfigValidationError::InvalidType` when `plugins` or `defaults` is present but is
///   not a table/object.
///
/// A root that is not an object is not inspected and is accepted.
///
/// # Errors
/// Returns the single error directly, or `ConfigValidationError::Multiple` when several
/// problems were found.
pub fn validate_config(config: &serde_json::Value) -> Result<(), ConfigValidationError> {
    let mut errors = Vec::new();
    if let Some(root) = config.as_object() {
        push_unknown_keys("root", root, VALID_ROOT_KEYS, &mut errors);
        for (section, valid_keys) in [
            ("plugins", VALID_PLUGINS_KEYS),
            ("defaults", VALID_DEFAULTS_KEYS),
        ] {
            match root.get(section) {
                None => {}
                Some(serde_json::Value::Object(table)) => {
                    push_unknown_keys(section, table, valid_keys, &mut errors);
                }
                Some(other) => errors.push(ConfigValidationError::InvalidType {
                    section: "root".to_string(),
                    key: section.to_string(),
                    expected: "table".to_string(),
                    actual: json_kind_name(other).to_string(),
                }),
            }
        }
    }
    match errors.len() {
        0 => Ok(()),
        1 => Err(errors.remove(0)),
        _ => Err(ConfigValidationError::Multiple(errors)),
    }
}

/// Appends an `UnknownKey` error for every key of `table` not listed in `valid_keys`.
fn push_unknown_keys(
    section: &str,
    table: &serde_json::Map<String, serde_json::Value>,
    valid_keys: &[&str],
    errors: &mut Vec<ConfigValidationError>,
) {
    errors.extend(
        table
            .keys()
            .filter(|key| !valid_keys.contains(&key.as_str()))
            .map(|key| ConfigValidationError::UnknownKey {
                section: section.to_string(),
                key: key.clone(),
                suggestion: suggest_similar_key(key, valid_keys),
            }),
    );
}

/// Human-readable kind of a JSON value, using TOML's word "table" for objects.
fn json_kind_name(value: &serde_json::Value) -> &'static str {
    match value {
        serde_json::Value::Null => "null",
        serde_json::Value::Bool(_) => "boolean",
        serde_json::Value::Number(_) => "number",
        serde_json::Value::String(_) => "string",
        serde_json::Value::Array(_) => "array",
        serde_json::Value::Object(_) => "table",
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Bool` reports its flag and never exposes a plugin list.
    #[test]
    fn test_plugins_enabled_bool() {
        let on = PluginsEnabled::Bool(true);
        assert!(on.is_enabled());
        assert_eq!(on.get_list(), None);
        assert!(!PluginsEnabled::Bool(false).is_enabled());
    }

    /// `List` is enabled only when it names at least one plugin.
    #[test]
    fn test_plugins_enabled_list() {
        let listed = PluginsEnabled::List(vec!["plugin1".to_string(), "plugin2".to_string()]);
        assert!(listed.is_enabled());
        assert_eq!(listed.get_list().map(|names| names.len()), Some(2));
        assert!(!PluginsEnabled::List(Vec::new()).is_enabled());
    }

    /// The derived default has plugins and auto-discovery switched off.
    #[test]
    fn test_praison_config_default() {
        let PraisonConfig { plugins, .. } = PraisonConfig::default();
        assert!(!plugins.enabled.is_enabled());
        assert!(!plugins.auto_discover);
    }

    #[test]
    fn test_suggest_similar_key() {
        let cases: [(&str, Option<&str>); 4] = [
            ("model", Some("model")),
            ("mod", Some("model")),
            ("mem", Some("memory")),
            ("xyzabc", None),
        ];
        for (input, expected) in cases {
            assert_eq!(
                suggest_similar_key(input, VALID_DEFAULTS_KEYS).as_deref(),
                expected,
                "suggestion for {:?}",
                input
            );
        }
    }

    #[test]
    fn test_validate_config_valid() {
        let config = serde_json::json!({
            "plugins": { "enabled": true },
            "defaults": { "model": "gpt-4o" }
        });
        assert!(validate_config(&config).is_ok());
    }

    #[test]
    fn test_validate_config_invalid_key() {
        // "enabeld" is a deliberate misspelling of "enabled".
        let config = serde_json::json!({ "plugins": { "enabeld": true } });
        assert!(validate_config(&config).is_err());
    }

    #[test]
    fn test_defaults_config() {
        let defaults = DefaultsConfig {
            model: Some("gpt-4o".to_string()),
            allow_delegation: true,
            ..Default::default()
        };
        assert_eq!(defaults.model.as_deref(), Some("gpt-4o"));
        assert!(defaults.allow_delegation);
        assert!(!defaults.allow_code_execution);
    }
}