use serde::{Deserialize, Serialize};
use crate::error::{ConfigValidationError, Result, ToolkitError};
/// Root of a server configuration document, deserialized from TOML.
///
/// Each field maps to the top-level key of the same name (`[server]`, `[[tools]]`, ...).
/// Keys that are omitted take their `Default` value, and any unrecognised key is rejected
/// at parse time (`deny_unknown_fields`). Semantic checks live in `ServerConfig::validate`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct ServerConfig {
    /// `[server]` table: id, name, version and related identity fields.
    pub server: ServerSection,
    /// `[metadata]` table: display name, descriptions, tags, author, visibility.
    pub metadata: MetadataSection,
    /// `[database]` table.
    pub database: DatabaseSection,
    /// Optional `[backend]` table; this field only exists with the `http` feature.
    #[cfg(feature = "http")]
    pub backend: Option<BackendSection>,
    /// Optional `[code_mode]` table; `None` when the table is absent.
    pub code_mode: Option<CodeModeSection>,
    /// `[[tools]]` array; empty when omitted.
    pub tools: Vec<ToolDecl>,
    /// `[[prompts]]` array; empty when omitted.
    pub prompts: Vec<PromptDecl>,
    /// `[[resources]]` array; empty when omitted.
    pub resources: Vec<ResourceDecl>,
    /// Optional `[shared_policy_store]` table.
    pub shared_policy_store: Option<SharedPolicyStoreSection>,
}
impl ServerConfig {
    /// Parses `toml_str` into a `ServerConfig` without running semantic validation.
    ///
    /// Omitted keys fall back to their defaults. Unknown keys anywhere in the document are
    /// rejected because every section is declared with `deny_unknown_fields`.
    ///
    /// # Errors
    ///
    /// Returns `ToolkitError::Parse` when the text is not valid TOML or does not match the
    /// configuration schema (wrong value types, unknown keys, ...).
    pub fn from_toml(toml_str: &str) -> Result<Self> {
        toml::from_str(toml_str).map_err(ToolkitError::Parse)
    }

    /// Parses `toml_str` and then runs `ServerConfig::validate` on the result.
    ///
    /// # Errors
    ///
    /// Parse/schema failures are reported first as `ToolkitError::Parse`. If parsing
    /// succeeds, a validation failure is converted into a `ToolkitError` (for example
    /// `ToolkitError::Validation(ConfigValidationError::EmptyServerName)`).
    pub fn from_toml_strict_validated(toml_str: &str) -> Result<Self> {
        Self::from_toml(toml_str).and_then(|parsed| {
            // `?` lifts ConfigValidationError into ToolkitError through its `From` impl.
            parsed.validate()?;
            Ok(parsed)
        })
    }

    /// Checks semantic rules that the TOML schema alone cannot express.
    ///
    /// Checks run in this order, and the first failure is returned:
    /// 1. `[server].name` is not blank.
    /// 2. `[server].version` is not blank.
    /// 3. For each tool in declaration order: its name is not blank, then it declares at
    ///    most one tool kind (`sql`, `path`/`method`, or `script`).
    /// 4. Every `[[database.tables]]` entry has a non-blank name.
    /// 5. With the `http` feature, a present `[backend]` has a non-blank `base_url`.
    ///
    /// "Blank" means empty after trimming surrounding whitespace.
    ///
    /// # Errors
    ///
    /// Returns the `ConfigValidationError` variant for the first rule that fails. Tool and
    /// table errors carry the zero-based index of the offending entry.
    pub fn validate(&self) -> std::result::Result<(), ConfigValidationError> {
        // Shared predicate for every "must not be blank" rule below.
        fn is_blank(value: &str) -> bool {
            value.trim().is_empty()
        }

        if is_blank(&self.server.name) {
            return Err(ConfigValidationError::EmptyServerName);
        }
        if is_blank(&self.server.version) {
            return Err(ConfigValidationError::EmptyServerVersion);
        }

        // Per tool, the name check deliberately precedes the ambiguity check.
        self.tools.iter().enumerate().try_for_each(|(idx, tool)| {
            if is_blank(&tool.name) {
                Err(ConfigValidationError::EmptyToolName(idx))
            } else if tool.declared_kind_count() > 1 {
                Err(ConfigValidationError::AmbiguousToolKind(idx))
            } else {
                Ok(())
            }
        })?;

        if let Some(idx) = self
            .database
            .tables
            .iter()
            .position(|table| is_blank(&table.name))
        {
            return Err(ConfigValidationError::EmptyTableName(idx));
        }

        #[cfg(feature = "http")]
        if matches!(&self.backend, Some(backend) if is_blank(&backend.base_url)) {
            return Err(ConfigValidationError::EmptyBackendBaseUrl);
        }

        Ok(())
    }
}
/// `[server]` table.
///
/// `ServerConfig::validate` requires `name` and `version` to be non-blank. Every key here
/// is optional at parse time.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct ServerSection {
    /// `id` key.
    pub id: Option<String>,
    /// `name` key; empty string when omitted, which fails validation.
    pub name: String,
    /// `description` key.
    pub description: Option<String>,
    /// `type` key (renamed because `type` is a Rust keyword).
    #[serde(rename = "type")]
    pub server_type: Option<String>,
    /// `version` key; empty string when omitted, which fails validation.
    pub version: String,
    /// `is_reference` key; `false` when omitted.
    pub is_reference: bool,
}
/// `[metadata]` table of descriptive fields; every key is optional.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct MetadataSection {
    /// `display_name` key.
    pub display_name: Option<String>,
    /// `short_description` key.
    pub short_description: Option<String>,
    /// `description` key.
    pub description: Option<String>,
    /// `tags` key; empty when omitted.
    pub tags: Vec<String>,
    /// `author` key.
    pub author: Option<String>,
    /// `visibility` key (free-form string; not checked in this module).
    pub visibility: Option<String>,
}
/// `[database]` table: data-source settings plus `[[database.tables]]` declarations.
///
/// Every key is optional. Only the table names are checked by `ServerConfig::validate`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct DatabaseSection {
    /// `type` key (e.g. `"sqlite"`); renamed because `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub backend_type: Option<String>,
    /// `database` key.
    pub database: Option<String>,
    /// `output_location` key.
    pub output_location: Option<String>,
    /// `workgroup` key.
    pub workgroup: Option<String>,
    /// `query_timeout_ms` key. NOTE(review): the suffix suggests milliseconds; confirm with
    /// the consumer.
    pub query_timeout_ms: Option<u64>,
    /// `[[database.tables]]` entries; empty when omitted.
    pub tables: Vec<DatabaseTableDecl>,
    /// `url` key, stored verbatim. Values like `"env:DATABASE_URL"` are not resolved here.
    pub url: Option<String>,
    /// `file_path` key.
    pub file_path: Option<String>,
    /// Optional `[database.pool]` table.
    pub pool: Option<DatabasePoolSection>,
}
/// One `[[database.tables]]` entry.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct DatabaseTableDecl {
    /// `name` key; must be non-blank to pass `ServerConfig::validate`.
    pub name: String,
    /// `description` key.
    pub description: Option<String>,
}
/// `[database.pool]` table; both keys are optional.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct DatabasePoolSection {
    /// `max_connections` key.
    pub max_connections: Option<u32>,
    /// `connection_timeout_seconds` key.
    pub connection_timeout_seconds: Option<u64>,
}
#[cfg(feature = "http")]
pub use crate::http::auth::AuthConfig;
#[cfg(feature = "http")]
pub use crate::http::client::HttpConfig;
#[cfg(feature = "http")]
/// `[backend]` table; only compiled with the `http` feature.
///
/// When this table is present, `ServerConfig::validate` requires a non-blank `base_url`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct BackendSection {
    /// `base_url` key; empty string when omitted, which fails validation.
    pub base_url: String,
    /// `[backend.auth]` table; `AuthConfig::default()` when omitted.
    pub auth: AuthConfig,
    /// `[backend.http]` table; `HttpConfig::default()` when omitted.
    pub http: HttpConfig,
}
/// `[code_mode]` table. Every key is optional; bools default to `false`, lists to empty.
///
/// `Debug` is implemented by hand so that `token_secret` is never printed. Because
/// `ServerConfig` derives `Debug` and embeds this section, logging a whole config with
/// `{:?}` would otherwise expose the secret.
// The bools mirror independent TOML switches one-to-one; folding them into an enum or
// flags type would change the config schema, so the pedantic lint is silenced here.
#[allow(clippy::struct_excessive_bools)]
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct CodeModeSection {
    /// `enabled` key.
    pub enabled: bool,
    /// `server_id` key.
    pub server_id: Option<String>,
    /// `allow_writes` key.
    pub allow_writes: bool,
    /// `allow_deletes` key.
    pub allow_deletes: bool,
    /// `allow_ddl` key.
    pub allow_ddl: bool,
    /// `require_limit` key.
    pub require_limit: bool,
    /// `max_limit` key.
    pub max_limit: Option<u64>,
    /// `blocked_tables` key.
    pub blocked_tables: Vec<String>,
    /// `sensitive_columns` key.
    pub sensitive_columns: Vec<String>,
    /// `auto_approve_levels` key (e.g. `["low"]`).
    pub auto_approve_levels: Vec<String>,
    /// `token_ttl_seconds` key.
    pub token_ttl_seconds: Option<u64>,
    /// `token_secret` key; redacted from `Debug` output. NOTE(review): how an inline value
    /// interacts with `allow_inline_token_secret_for_dev` is decided by the consumer, not here.
    pub token_secret: Option<String>,
    /// `allow_inline_token_secret_for_dev` key.
    pub allow_inline_token_secret_for_dev: bool,
    /// Optional `[code_mode.limits]` table.
    pub limits: Option<CodeModeLimits>,
}

impl std::fmt::Debug for CodeModeSection {
    // Same shape as `#[derive(Debug)]`, except `token_secret` is shown only as
    // `Some("<redacted>")` / `None`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field without updating this impl is a
        // compile error, so a new secret cannot slip into the output unnoticed.
        let Self {
            enabled,
            server_id,
            allow_writes,
            allow_deletes,
            allow_ddl,
            require_limit,
            max_limit,
            blocked_tables,
            sensitive_columns,
            auto_approve_levels,
            token_ttl_seconds,
            token_secret,
            allow_inline_token_secret_for_dev,
            limits,
        } = self;
        f.debug_struct("CodeModeSection")
            .field("enabled", enabled)
            .field("server_id", server_id)
            .field("allow_writes", allow_writes)
            .field("allow_deletes", allow_deletes)
            .field("allow_ddl", allow_ddl)
            .field("require_limit", require_limit)
            .field("max_limit", max_limit)
            .field("blocked_tables", blocked_tables)
            .field("sensitive_columns", sensitive_columns)
            .field("auto_approve_levels", auto_approve_levels)
            .field("token_ttl_seconds", token_ttl_seconds)
            .field(
                "token_secret",
                &token_secret.as_ref().map(|_| "<redacted>"),
            )
            .field(
                "allow_inline_token_secret_for_dev",
                allow_inline_token_secret_for_dev,
            )
            .field("limits", limits)
            .finish()
    }
}
/// `[code_mode.limits]` table; every key is optional.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct CodeModeLimits {
    /// `max_tables_per_query` key.
    pub max_tables_per_query: Option<u32>,
    /// `max_join_depth` key.
    pub max_join_depth: Option<u32>,
    /// `max_subquery_depth` key.
    pub max_subquery_depth: Option<u32>,
}
/// `[shared_policy_store]` table; every key is optional.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct SharedPolicyStoreSection {
    /// `creates_shared_store` key; `false` when omitted.
    pub creates_shared_store: bool,
    /// `export_to_ssm` key; `false` when omitted.
    pub export_to_ssm: bool,
    /// `ssm_path` key.
    pub ssm_path: Option<String>,
    /// `templates` key; empty when omitted.
    pub templates: Vec<String>,
}
/// One `[[tools]]` entry.
///
/// The tool's kind is inferred from which keys are set:
/// - `sql` → SQL tool;
/// - `path` and/or `method` → single-call tool;
/// - `script` → script tool.
///
/// `ServerConfig::validate` rejects more than one kind but accepts a declaration with none.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct ToolDecl {
    /// `name` key; must be non-blank to pass validation.
    pub name: String,
    /// `description` key.
    pub description: Option<String>,
    /// `sql` key; marks an SQL tool.
    pub sql: Option<String>,
    /// `path` key; marks a single-call tool.
    pub path: Option<String>,
    /// `method` key; marks a single-call tool.
    pub method: Option<String>,
    /// `base_url` key; not considered when inferring the tool kind.
    pub base_url: Option<String>,
    /// `script` key; marks a script tool.
    pub script: Option<String>,
    /// `ui_resource_uri` key.
    pub ui_resource_uri: Option<String>,
    /// `[[tools.parameters]]` entries; empty when omitted.
    pub parameters: Vec<ParamDecl>,
    /// Optional `[tools.annotations]` table.
    pub annotations: Option<AnnotationsDecl>,
}
impl ToolDecl {
    /// Returns `true` when the declaration sets `script`.
    ///
    /// Only the `script` field is inspected. A declaration that also sets `sql` or
    /// `path`/`method` still reports `true`; such mixes are rejected by
    /// `ServerConfig::validate`.
    #[must_use]
    pub fn is_script_tool(&self) -> bool {
        matches!(self.script, Some(_))
    }

    /// Counts how many of the three tool kinds this declaration opts into (0..=3).
    ///
    /// The kinds are SQL (`sql`), single call (`path` or `method`, counted once even if
    /// both are set), and script (`script`). `base_url` does not count toward any kind.
    fn declared_kind_count(&self) -> usize {
        let kinds = [
            self.sql.is_some(),
            self.path.is_some() || self.method.is_some(),
            self.script.is_some(),
        ];
        kinds.iter().filter(|&&declared| declared).count()
    }
}
/// One `[[tools.parameters]]` entry.
///
/// Derives only `PartialEq` (not `Eq`) because it holds `f64` and `toml::Value` fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct ParamDecl {
    /// `name` key.
    pub name: String,
    /// `type` key (e.g. `"string"`); renamed because `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub param_type: Option<String>,
    /// `description` key.
    pub description: Option<String>,
    /// `required` key; `false` when omitted.
    pub required: bool,
    /// `default` key; any TOML value, kept untyped.
    pub default: Option<toml::Value>,
    /// `max_length` key.
    pub max_length: Option<u64>,
    /// `minimum` key.
    pub minimum: Option<f64>,
    /// `maximum` key.
    pub maximum: Option<f64>,
    /// `enum` key (renamed because `enum` is a Rust keyword); list of allowed TOML values.
    #[serde(rename = "enum")]
    pub enum_values: Option<Vec<toml::Value>>,
}
/// `[tools.annotations]` table; the four boolean hints default to `false`.
// Each bool is an independent TOML key mirrored one-to-one, so the pedantic
// "too many bools" lint does not indicate a modelling problem here.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct AnnotationsDecl {
    /// `read_only_hint` key.
    pub read_only_hint: bool,
    /// `destructive_hint` key.
    pub destructive_hint: bool,
    /// `idempotent_hint` key.
    pub idempotent_hint: bool,
    /// `open_world_hint` key.
    pub open_world_hint: bool,
    /// `cost_hint` key (free-form string).
    pub cost_hint: Option<String>,
}
/// One `[[prompts]]` entry. Not checked by `ServerConfig::validate`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct PromptDecl {
    /// `name` key.
    pub name: String,
    /// `description` key.
    pub description: Option<String>,
    /// `include_resources` key. NOTE(review): presumably resource URIs; confirm with the
    /// consumer.
    pub include_resources: Vec<String>,
    /// `[[prompts.arguments]]` entries; empty when omitted.
    pub arguments: Vec<PromptArgumentDecl>,
}
/// One `[[prompts.arguments]]` entry.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct PromptArgumentDecl {
    /// `name` key.
    pub name: String,
    /// `description` key.
    pub description: Option<String>,
    /// `required` key; `false` when omitted.
    pub required: bool,
}
/// One `[[resources]]` entry. Not checked by `ServerConfig::validate`, so an omitted
/// `uri` parses as an empty string.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(default, deny_unknown_fields)]
pub struct ResourceDecl {
    /// `uri` key.
    pub uri: String,
    /// `name` key.
    pub name: Option<String>,
    /// `description` key.
    pub description: Option<String>,
    /// `mime_type` key.
    pub mime_type: Option<String>,
    /// `content` key.
    pub content: Option<String>,
}
#[cfg(test)]
mod tests {
    // Unit tests for parsing (`from_toml`), semantic validation (`validate`), and a
    // serialize/parse round-trip property. TOML fixtures are kept flush-left inside raw
    // strings so their text does not depend on the surrounding indentation.
    use super::*;
    use proptest::prelude::*;

    // Smallest config that both parses and validates.
    const MINIMAL: &str = r#"
[server]
name = "demo"
version = "0.1.0"
"#;

    #[test]
    fn parse_minimal_config_succeeds() {
        let cfg = ServerConfig::from_toml(MINIMAL).expect("minimal must parse");
        assert_eq!(cfg.server.name, "demo");
        assert_eq!(cfg.server.version, "0.1.0");
        assert!(cfg.tools.is_empty());
        assert!(cfg.code_mode.is_none());
    }

    #[test]
    fn parse_unknown_field_fails() {
        let toml = r#"
[server]
name = "demo"
version = "0.1.0"
unknown_field = "x"
"#;
        let err = ServerConfig::from_toml(toml).expect_err("unknown field must fail");
        assert!(matches!(err, ToolkitError::Parse(_)), "got: {err:?}");
    }

    #[test]
    fn parse_typo_in_code_mode_key_fails() {
        let toml = r#"
[server]
name = "demo"
version = "0.1.0"
[code_mode]
enabled = true
auto_aprove_levels = ["low"]
"#;
        let err = ServerConfig::from_toml(toml).expect_err("typo'd code_mode key must be rejected");
        assert!(matches!(err, ToolkitError::Parse(_)));
    }

    #[test]
    fn code_mode_section_optional() {
        let cfg = ServerConfig::from_toml(MINIMAL).expect("parse");
        assert!(cfg.code_mode.is_none());
    }

    #[test]
    fn validate_accepts_valid_config() {
        let cfg = ServerConfig::from_toml(MINIMAL).expect("parse");
        cfg.validate().expect("minimal config must validate");
    }

    #[test]
    fn validate_rejects_empty_server_name() {
        let toml = r#"
[server]
name = ""
version = "0.1.0"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("parse");
        match cfg.validate() {
            Err(ConfigValidationError::EmptyServerName) => {},
            other => panic!("expected EmptyServerName, got {other:?}"),
        }
    }

    // Blank is judged after trimming, so whitespace-only names must also be rejected.
    #[test]
    fn validate_rejects_whitespace_only_server_name() {
        let toml = r#"
[server]
name = "   "
version = "0.1.0"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("parse");
        match cfg.validate() {
            Err(ConfigValidationError::EmptyServerName) => {},
            other => panic!("expected EmptyServerName, got {other:?}"),
        }
    }

    #[test]
    fn validate_rejects_empty_server_version() {
        let toml = r#"
[server]
name = "demo"
version = ""
"#;
        let cfg = ServerConfig::from_toml(toml).expect("parse");
        match cfg.validate() {
            Err(ConfigValidationError::EmptyServerVersion) => {},
            other => panic!("expected EmptyServerVersion, got {other:?}"),
        }
    }

    #[test]
    fn validate_rejects_empty_tool_name() {
        let toml = r#"
[server]
name = "demo"
version = "0.1.0"
[[tools]]
name = "ok"
description = "first"
[[tools]]
name = ""
description = "second-is-empty"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("parse");
        match cfg.validate() {
            Err(ConfigValidationError::EmptyToolName(1)) => {},
            other => panic!("expected EmptyToolName(1), got {other:?}"),
        }
    }

    // For a single tool, the blank-name rule is checked before the ambiguity rule.
    #[test]
    fn validate_reports_empty_tool_name_before_ambiguity() {
        let toml = r#"
[server]
name = "demo"
version = "0.1.0"
[[tools]]
name = " "
sql = "SELECT 1"
script = "return 1;"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("parse");
        match cfg.validate() {
            Err(ConfigValidationError::EmptyToolName(0)) => {},
            other => panic!("expected EmptyToolName(0), got {other:?}"),
        }
    }

    #[test]
    fn validate_rejects_empty_table_name() {
        let toml = r#"
[server]
name = "demo"
version = "0.1.0"
[[database.tables]]
name = ""
description = "missing-name"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("parse");
        match cfg.validate() {
            Err(ConfigValidationError::EmptyTableName(0)) => {},
            other => panic!("expected EmptyTableName(0), got {other:?}"),
        }
    }

    #[cfg(feature = "http")]
    #[test]
    fn validate_rejects_empty_backend_base_url() {
        let toml = r#"
[server]
name = "demo"
version = "0.1.0"
[backend]
base_url = ""
"#;
        let cfg = ServerConfig::from_toml(toml).expect("parse");
        match cfg.validate() {
            Err(ConfigValidationError::EmptyBackendBaseUrl) => {},
            other => panic!("expected EmptyBackendBaseUrl, got {other:?}"),
        }
    }

    #[cfg(feature = "http")]
    #[test]
    fn validate_rejects_omitted_backend_base_url() {
        let toml = r#"
[server]
name = "demo"
version = "0.1.0"
[backend]
"#;
        let cfg = ServerConfig::from_toml(toml).expect("parse");
        match cfg.validate() {
            Err(ConfigValidationError::EmptyBackendBaseUrl) => {},
            other => panic!("expected EmptyBackendBaseUrl, got {other:?}"),
        }
    }

    #[cfg(feature = "http")]
    #[test]
    fn validate_accepts_non_empty_backend_base_url() {
        let toml = r#"
[server]
name = "demo"
version = "0.1.0"
[backend]
base_url = "https://api.example.com"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("parse");
        cfg.validate()
            .expect("config with a non-empty backend.base_url must validate");
    }

    #[cfg(feature = "http")]
    #[test]
    fn validate_accepts_absent_backend() {
        let cfg = ServerConfig::from_toml(MINIMAL).expect("parse");
        assert!(cfg.backend.is_none());
        cfg.validate()
            .expect("a config without [backend] must validate (SQL configs unaffected)");
    }

    #[cfg(feature = "http")]
    #[test]
    fn empty_backend_base_url_error_names_the_field() {
        let msg = ConfigValidationError::EmptyBackendBaseUrl.to_string();
        assert!(
            msg.contains("[backend].base_url"),
            "error must name the field, got: {msg}"
        );
    }

    #[test]
    fn database_url_optional_field_parses() {
        let toml = r#"
[server]
name = "x"
version = "0.0.1"
[database]
url = "env:DATABASE_URL"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("config with [database].url must parse");
        assert_eq!(cfg.database.url, Some("env:DATABASE_URL".to_string()));
    }

    // Parse errors and validation errors must both surface through the strict entry point.
    #[test]
    fn from_toml_strict_validated_rolls_both_errors() {
        let bad_toml = r#"
[server]
name = "demo"
version = "0.1.0"
nonsense = "x"
"#;
        let err = ServerConfig::from_toml_strict_validated(bad_toml)
            .expect_err("unknown field must surface");
        assert!(matches!(err, ToolkitError::Parse(_)), "got: {err:?}");
        let invalid_toml = r#"
[server]
name = ""
version = "0.1.0"
"#;
        let err = ServerConfig::from_toml_strict_validated(invalid_toml)
            .expect_err("empty name must surface");
        assert!(
            matches!(
                err,
                ToolkitError::Validation(ConfigValidationError::EmptyServerName)
            ),
            "got: {err:?}"
        );
    }

    #[test]
    fn test_tooldecl_single_call_parses() {
        let toml = r#"
[server]
name = "tube"
version = "0.1.0"
[[tools]]
name = "tube_status"
path = "/Line/Mode/tube/Status"
method = "GET"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("single-call tool must parse");
        let tool = &cfg.tools[0];
        assert_eq!(tool.path.as_deref(), Some("/Line/Mode/tube/Status"));
        assert_eq!(tool.method.as_deref(), Some("GET"));
        assert!(!tool.is_script_tool());
        cfg.validate()
            .expect("single-call tool is a valid single kind");
    }

    #[test]
    fn test_tooldecl_script_parses() {
        let toml = r#"
[server]
name = "tube"
version = "0.1.0"
[[tools]]
name = "plan_journey"
script = """
const a = await api.get('/Journey/JourneyResults/' + args.from + '/to/' + args.to);
return a;
"""
[[tools.parameters]]
name = "from"
type = "string"
required = true
[[tools.parameters]]
name = "to"
type = "string"
required = true
"#;
        let cfg = ServerConfig::from_toml(toml).expect("script tool must parse");
        let tool = &cfg.tools[0];
        assert!(tool.script.is_some());
        assert!(tool.is_script_tool());
        assert_eq!(tool.parameters.len(), 2);
        cfg.validate().expect("script tool is a valid single kind");
    }

    // Pins the `type` / `enum` renames and untyped `default` values on parameters.
    #[test]
    fn test_param_decl_renamed_keys_parse() {
        let doc = r#"
[server]
name = "demo"
version = "0.1.0"
[[tools]]
name = "search"
sql = "SELECT 1"
[[tools.parameters]]
name = "order"
type = "string"
default = "asc"
enum = ["asc", "desc"]
[[tools.parameters]]
name = "limit"
type = "integer"
default = 10
"#;
        let cfg = ServerConfig::from_toml(doc).expect("parameters must parse");
        let params = &cfg.tools[0].parameters;
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].param_type.as_deref(), Some("string"));
        assert_eq!(params[0].default, Some(toml::Value::String("asc".to_string())));
        assert_eq!(params[0].enum_values.as_ref().map(Vec::len), Some(2));
        assert_eq!(params[1].param_type.as_deref(), Some("integer"));
        assert_eq!(params[1].default, Some(toml::Value::Integer(10)));
        assert!(params[1].enum_values.is_none());
        assert!(!params[1].required);
    }

    #[test]
    fn test_tooldecl_detection() {
        let script = ToolDecl {
            script: Some("return 1;".to_string()),
            ..Default::default()
        };
        assert!(script.is_script_tool());
        let single = ToolDecl {
            path: Some("/x".to_string()),
            method: Some("GET".to_string()),
            ..Default::default()
        };
        assert!(!single.is_script_tool());
        let sql = ToolDecl {
            sql: Some("SELECT 1".to_string()),
            ..Default::default()
        };
        assert!(!sql.is_script_tool());
    }

    #[test]
    fn test_tooldecl_ambiguous_rejected() {
        let toml = r#"
[server]
name = "tube"
version = "0.1.0"
[[tools]]
name = "confused"
path = "/x"
method = "GET"
script = "return 1;"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("parse (ambiguity is a validate-time rule)");
        match cfg.validate() {
            Err(ConfigValidationError::AmbiguousToolKind(0)) => {},
            other => panic!("expected AmbiguousToolKind(0), got {other:?}"),
        }
    }

    #[test]
    fn test_tooldecl_ambiguous_sql_plus_script_rejected() {
        let toml = r#"
[server]
name = "tube"
version = "0.1.0"
[[tools]]
name = "confused"
sql = "SELECT 1"
script = "return 1;"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("parse");
        match cfg.validate() {
            Err(ConfigValidationError::AmbiguousToolKind(0)) => {},
            other => panic!("expected AmbiguousToolKind(0), got {other:?}"),
        }
    }

    // Covers the remaining kind pair (sql + single-call via `method` alone) and checks
    // that the reported index is the offending tool's, not the first tool's.
    #[test]
    fn test_tooldecl_ambiguous_sql_plus_single_call_rejected_at_its_index() {
        let toml = r#"
[server]
name = "tube"
version = "0.1.0"
[[tools]]
name = "fine"
sql = "SELECT 1"
[[tools]]
name = "confused"
sql = "SELECT 1"
method = "GET"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("parse");
        match cfg.validate() {
            Err(ConfigValidationError::AmbiguousToolKind(1)) => {},
            other => panic!("expected AmbiguousToolKind(1), got {other:?}"),
        }
    }

    #[test]
    fn test_tooldecl_sql_still_parses() {
        let toml = r#"
[server]
name = "demo"
version = "0.1.0"
[[tools]]
name = "list_tables"
sql = "SELECT name FROM sqlite_master"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("sql tool must still parse");
        let tool = &cfg.tools[0];
        assert_eq!(tool.sql.as_deref(), Some("SELECT name FROM sqlite_master"));
        assert!(tool.path.is_none());
        assert!(tool.method.is_none());
        assert!(tool.base_url.is_none());
        assert!(tool.script.is_none());
        assert!(!tool.is_script_tool());
        cfg.validate().expect("sql tool validates as a single kind");
    }

    #[test]
    fn test_prompts_and_resources_parse() {
        let toml = r#"
[server]
name = "demo"
version = "0.1.0"
[[prompts]]
name = "summarize"
description = "Summarize a document"
include_resources = ["docs://readme"]
[[prompts.arguments]]
name = "topic"
required = true
[[resources]]
uri = "docs://readme"
name = "README"
mime_type = "text/markdown"
content = "hello"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("prompts/resources must parse");
        assert_eq!(cfg.prompts.len(), 1);
        let prompt = &cfg.prompts[0];
        assert_eq!(prompt.name, "summarize");
        assert_eq!(prompt.include_resources, vec!["docs://readme".to_string()]);
        assert_eq!(prompt.arguments.len(), 1);
        assert_eq!(prompt.arguments[0].name, "topic");
        assert!(prompt.arguments[0].required);
        assert_eq!(cfg.resources.len(), 1);
        let resource = &cfg.resources[0];
        assert_eq!(resource.uri, "docs://readme");
        assert_eq!(resource.mime_type.as_deref(), Some("text/markdown"));
        assert_eq!(resource.content.as_deref(), Some("hello"));
        cfg.validate().expect("prompts/resources do not affect validation");
    }

    #[test]
    fn test_prompt_unknown_field_rejected() {
        let toml = r#"
[server]
name = "demo"
version = "0.1.0"
[[prompts]]
name = "summarize"
bogus = 1
"#;
        let err = ServerConfig::from_toml(toml).expect_err("unknown [[prompts]] key must be rejected");
        assert!(matches!(err, ToolkitError::Parse(_)), "got: {err:?}");
    }

    #[cfg(feature = "http")]
    #[test]
    fn test_backend_section_parses() {
        let toml = r#"
[server]
name = "tube"
version = "0.1.0"
[backend]
base_url = "https://api.tfl.gov.uk"
[backend.auth]
type = "api_key"
[backend.auth.query_params]
app_key = "${TFL_APP_KEY}"
[backend.http]
timeout_seconds = 10
retries = 2
"#;
        let cfg = ServerConfig::from_toml(toml).expect("[backend] config must parse");
        let backend = cfg.backend.expect("backend must be Some");
        assert_eq!(backend.base_url, "https://api.tfl.gov.uk");
        assert_eq!(backend.http.timeout_seconds, 10);
        assert_eq!(backend.http.retries, 2);
        assert!(
            matches!(backend.auth, AuthConfig::ApiKey { .. }),
            "auth must be api_key, got {:?}",
            backend.auth
        );
    }

    #[cfg(feature = "http")]
    #[test]
    fn test_backend_auth_defaults_to_none() {
        let toml = r#"
[server]
name = "tube"
version = "0.1.0"
[backend]
base_url = "https://api.example.com"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("backend w/o auth must parse");
        let backend = cfg.backend.expect("backend must be Some");
        assert!(matches!(backend.auth, AuthConfig::None));
        assert_eq!(backend.http, HttpConfig::default());
    }

    #[cfg(feature = "http")]
    #[test]
    fn test_sql_config_unaffected() {
        let toml = r#"
[server]
name = "demo"
version = "0.1.0"
[database]
type = "sqlite"
file_path = "/tmp/demo.db"
[[tools]]
name = "list_tables"
sql = "SELECT name FROM sqlite_master"
"#;
        let cfg = ServerConfig::from_toml(toml).expect("SQL config must still parse");
        assert!(
            cfg.backend.is_none(),
            "SQL config must have backend == None"
        );
        assert_eq!(cfg.tools.len(), 1);
    }

    #[cfg(feature = "http")]
    #[test]
    fn test_backend_unknown_field_rejected() {
        let toml = r#"
[server]
name = "tube"
version = "0.1.0"
[backend]
base_url = "https://api.example.com"
[backend.http]
foo = 1
"#;
        let err =
            ServerConfig::from_toml(toml).expect_err("unknown [backend.http] key must be rejected");
        assert!(matches!(err, ToolkitError::Parse(_)), "got: {err:?}");
    }

    proptest! {
        // Serializing a minimal config and parsing it back must preserve name and version.
        #[test]
        fn server_config_minimal_round_trips(
            name in "[a-zA-Z0-9_-]{1,32}",
            version in "[0-9]+\\.[0-9]+\\.[0-9]+",
        ) {
            let cfg = ServerConfig {
                server: ServerSection {
                    name: name.clone(),
                    version: version.clone(),
                    ..Default::default()
                },
                ..Default::default()
            };
            let s = toml::to_string(&cfg).unwrap();
            let parsed = ServerConfig::from_toml(&s).unwrap();
            prop_assert_eq!(parsed.server.name, name);
            prop_assert_eq!(parsed.server.version, version);
        }
    }
}