use super::backend::CloudWatchConfig;
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Top-level observability settings.
///
/// Read from the `[observability]` table of a TOML document and optionally
/// overridden by `PMCP_*` environment variables. Because of
/// `#[serde(default)]`, any field missing from the document takes its value
/// from `ObservabilityConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ObservabilityConfig {
    /// Whether observability is enabled (`ObservabilityConfig::disabled()` sets `false`).
    pub enabled: bool,
    /// Backend name; this module uses `"console"` and `"cloudwatch"`.
    pub backend: String,
    /// Depth limit (default `10`). NOTE(review): the consumer is not visible
    /// here — presumably a nesting/recursion bound; confirm.
    pub max_depth: u32,
    /// Sampling fraction used by `should_sample`: `>= 1.0` samples every
    /// request, `<= 0.0` samples none.
    pub sample_rate: f64,
    /// Trace-ID propagation settings.
    pub tracing: TracingConfig,
    /// Per-field capture toggles.
    pub fields: FieldsConfig,
    /// Metric toggles and metric-name prefix.
    pub metrics: MetricsConfig,
    /// Settings for the `"cloudwatch"` backend.
    pub cloudwatch: CloudWatchConfig,
    /// Settings for the `"console"` backend.
    pub console: ConsoleConfig,
}
impl Default for ObservabilityConfig {
    /// Enabled, console backend, depth limit 10, every request sampled, and
    /// each nested section at its own default.
    fn default() -> Self {
        let backend = String::from("console");
        let sample_everything = 1.0;
        Self {
            enabled: true,
            backend,
            max_depth: 10,
            sample_rate: sample_everything,
            tracing: TracingConfig::default(),
            fields: FieldsConfig::default(),
            metrics: MetricsConfig::default(),
            cloudwatch: CloudWatchConfig::default(),
            console: ConsoleConfig::default(),
        }
    }
}
impl ObservabilityConfig {
    /// Loads configuration from `.pmcp-config.toml` (resolved against the
    /// current working directory), then applies `PMCP_*` environment overrides.
    ///
    /// A missing file is not an error: defaults are used instead.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Io`] if the file exists but cannot be read
    /// (for example permission denied, or contents that are not valid UTF-8),
    /// and [`ConfigError::Parse`] if its contents are not valid TOML for this
    /// configuration.
    pub fn load() -> Result<Self, ConfigError> {
        const CONFIG_PATH: &str = ".pmcp-config.toml";
        let mut config = match std::fs::read_to_string(CONFIG_PATH) {
            Ok(contents) => Self::from_toml(&contents)?,
            // No config file is the normal "use defaults" case.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Self::default(),
            // Anything else means a file is present but unusable. Surface it,
            // consistent with how parse errors are already propagated, rather
            // than silently running with defaults.
            Err(e) => {
                return Err(ConfigError::Io {
                    path: CONFIG_PATH.to_string(),
                    error: e.to_string(),
                });
            },
        };
        config.apply_env_overrides();
        Ok(config)
    }

    /// Loads configuration from the TOML file at `path`, then applies
    /// `PMCP_*` environment overrides.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Io`] if the file cannot be read and
    /// [`ConfigError::Parse`] if its contents are not valid TOML for this
    /// configuration.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
        let path = path.as_ref();
        let contents = std::fs::read_to_string(path).map_err(|e| ConfigError::Io {
            path: path.display().to_string(),
            error: e.to_string(),
        })?;
        let mut config = Self::from_toml(&contents)?;
        config.apply_env_overrides();
        Ok(config)
    }

    /// Parses the `[observability]` table of a TOML document.
    ///
    /// Environment overrides are *not* applied here. A document without an
    /// `[observability]` table yields `ObservabilityConfig::default()`, and
    /// fields missing from the table take their default values.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Parse`] carrying the TOML error message if the
    /// document cannot be deserialized.
    pub fn from_toml(content: &str) -> Result<Self, ConfigError> {
        // Wrapper for the whole document: the settings live under `[observability]`.
        #[derive(Deserialize)]
        struct FullConfig {
            #[serde(default)]
            observability: ObservabilityConfig,
        }

        let full: FullConfig =
            toml::from_str(content).map_err(|e| ConfigError::Parse(e.to_string()))?;
        Ok(full.observability)
    }

    /// Overrides fields from `PMCP_*` environment variables.
    ///
    /// This is best-effort by design. A variable that is unset, is not valid
    /// Unicode, or fails to parse as the field's type is ignored, and the
    /// current value is kept. `bool` fields accept exactly `true` or `false`.
    fn apply_env_overrides(&mut self) {
        // Reads `key` and parses it into the target type; `None` if absent or invalid.
        fn parsed<T: std::str::FromStr>(key: &str) -> Option<T> {
            std::env::var(key).ok()?.parse().ok()
        }

        if let Some(v) = parsed("PMCP_OBSERVABILITY_ENABLED") {
            self.enabled = v;
        }
        if let Ok(backend) = std::env::var("PMCP_OBSERVABILITY_BACKEND") {
            self.backend = backend;
        }
        if let Some(v) = parsed("PMCP_OBSERVABILITY_MAX_DEPTH") {
            self.max_depth = v;
        }
        if let Some(v) = parsed("PMCP_OBSERVABILITY_SAMPLE_RATE") {
            self.sample_rate = v;
        }
        if let Some(v) = parsed("PMCP_OBSERVABILITY_CAPTURE_TOOL_NAME") {
            self.fields.capture_tool_name = v;
        }
        if let Some(v) = parsed("PMCP_OBSERVABILITY_CAPTURE_ARGUMENTS_HASH") {
            self.fields.capture_arguments_hash = v;
        }
        if let Some(v) = parsed("PMCP_OBSERVABILITY_CAPTURE_CLIENT_IP") {
            self.fields.capture_client_ip = v;
        }
        if let Some(v) = parsed("PMCP_OBSERVABILITY_CAPTURE_RESPONSE_SIZE") {
            self.fields.capture_response_size = v;
        }
        if let Ok(namespace) = std::env::var("PMCP_CLOUDWATCH_NAMESPACE") {
            self.cloudwatch.namespace = namespace;
        }
        if let Some(v) = parsed("PMCP_CLOUDWATCH_EMF_ENABLED") {
            self.cloudwatch.emf_enabled = v;
        }
        if let Some(v) = parsed("PMCP_CONSOLE_PRETTY") {
            self.console.pretty = v;
        }
    }

    /// Decides whether the current request should be sampled.
    ///
    /// A `sample_rate >= 1.0` always samples and a `sample_rate <= 0.0` never
    /// does. Otherwise the result is `true` with probability `sample_rate`.
    /// A `NaN` rate never samples.
    pub fn should_sample(&self) -> bool {
        if self.sample_rate >= 1.0 {
            return true;
        }
        if self.sample_rate <= 0.0 {
            return false;
        }
        Self::random_unit() < self.sample_rate
    }

    /// Returns a pseudo-random `f64` uniformly distributed in `[0, 1)`.
    ///
    /// This uses only the standard library. Each `RandomState::new()` is
    /// initialized with random keys, so hashing with a fresh one yields a new
    /// well-mixed 64-bit value on every call, even in a tight loop. Wall-clock
    /// nanoseconds would barely change in that case. Not cryptographically
    /// secure, which is fine for sampling.
    fn random_unit() -> f64 {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};

        let bits = RandomState::new().build_hasher().finish();
        // Keep the top 53 bits so the value is exactly representable as an f64
        // and the quotient lies in [0, 1).
        (bits >> 11) as f64 / (1u64 << 53) as f64
    }

    /// A configuration with observability switched off; all other settings
    /// are defaults.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Default::default()
        }
    }

    /// Local-development preset: console backend with pretty, non-verbose output.
    pub fn development() -> Self {
        Self {
            enabled: true,
            backend: "console".to_string(),
            console: ConsoleConfig {
                pretty: true,
                verbose: false,
            },
            ..Default::default()
        }
    }

    /// Production preset: CloudWatch backend; every other section, including
    /// the CloudWatch settings themselves, uses its default.
    pub fn production() -> Self {
        Self {
            enabled: true,
            backend: "cloudwatch".to_string(),
            ..Default::default()
        }
    }
}
/// Trace-ID propagation settings (the `tracing` section of the observability
/// config).
///
/// Because of `#[serde(default)]`, missing fields deserialize to the values
/// from `TracingConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct TracingConfig {
    /// Whether tracing is enabled.
    pub enabled: bool,
    /// Header name; presumably the header that carries the trace ID — confirm
    /// against the consumer.
    pub trace_header: String,
    /// Field name; presumably where the trace ID is attached in payloads —
    /// confirm against the consumer.
    pub trace_field: String,
}
impl Default for TracingConfig {
    /// Tracing on, header `X-Trace-ID`, field `_trace`.
    fn default() -> Self {
        let trace_header = String::from("X-Trace-ID");
        let trace_field = String::from("_trace");
        Self {
            enabled: true,
            trace_header,
            trace_field,
        }
    }
}
/// Capture toggles, one `capture_*` flag per request field.
///
/// Each flag presumably enables recording of the field it is named after
/// (the capturing code is not visible here). Because of `#[serde(default)]`,
/// missing flags take their value from `FieldsConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
// Every flag is an independent on/off switch for one field, so a plain `bool`
// per field is the natural shape; grouping them into enums would add nothing.
#[allow(clippy::struct_excessive_bools)]
pub struct FieldsConfig {
    /// Tool-name capture toggle.
    pub capture_tool_name: bool,
    /// Resource-URI capture toggle.
    pub capture_resource_uri: bool,
    /// Prompt-name capture toggle.
    pub capture_prompt_name: bool,
    /// Arguments-hash capture toggle (off by default).
    pub capture_arguments_hash: bool,
    /// Full-arguments capture toggle (off by default).
    pub capture_full_arguments: bool,
    /// User-ID capture toggle.
    pub capture_user_id: bool,
    /// Client-type capture toggle.
    pub capture_client_type: bool,
    /// Client-version capture toggle.
    pub capture_client_version: bool,
    /// Client-IP capture toggle (off by default).
    pub capture_client_ip: bool,
    /// Session-ID capture toggle.
    pub capture_session_id: bool,
    /// Response-size capture toggle.
    pub capture_response_size: bool,
    /// Error-details capture toggle.
    pub capture_error_details: bool,
}
impl Default for FieldsConfig {
    /// Captures everything except the argument hash, the full arguments and
    /// the client IP, which are opt-in.
    fn default() -> Self {
        Self {
            // On out of the box.
            capture_tool_name: true,
            capture_resource_uri: true,
            capture_prompt_name: true,
            capture_user_id: true,
            capture_client_type: true,
            capture_client_version: true,
            capture_session_id: true,
            capture_response_size: true,
            capture_error_details: true,
            // Opt-in: off unless explicitly enabled.
            capture_arguments_hash: false,
            capture_full_arguments: false,
            capture_client_ip: false,
        }
    }
}
/// Metric toggles plus a shared metric-name prefix.
///
/// Each flag presumably enables the metric family it is named after (the
/// emitters are not visible here). Because of `#[serde(default)]`, missing
/// fields take their value from `MetricsConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
// Independent on/off switches, one per metric family; a `bool` each is the
// simplest faithful representation.
#[allow(clippy::struct_excessive_bools)]
pub struct MetricsConfig {
    /// Request-count metric toggle.
    pub request_count: bool,
    /// Request-duration metric toggle.
    pub request_duration: bool,
    /// Error-rate metric toggle.
    pub error_rate: bool,
    /// Tool-usage metric toggle.
    pub tool_usage: bool,
    /// Resource-usage metric toggle.
    pub resource_usage: bool,
    /// Prompt-usage metric toggle.
    pub prompt_usage: bool,
    /// Prefix string, presumably prepended to metric names (default `"mcp"`).
    pub prefix: String,
}
impl Default for MetricsConfig {
    /// Every metric family enabled, prefix `mcp`.
    fn default() -> Self {
        let on = true;
        Self {
            request_count: on,
            request_duration: on,
            error_rate: on,
            tool_usage: on,
            resource_usage: on,
            prompt_usage: on,
            prefix: String::from("mcp"),
        }
    }
}
/// Settings for the `"console"` backend.
///
/// Because of `#[serde(default)]`, missing fields take their value from
/// `ConsoleConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ConsoleConfig {
    /// Pretty-output toggle (default `true`).
    pub pretty: bool,
    /// Verbose-output toggle (default `false`).
    pub verbose: bool,
}
impl Default for ConsoleConfig {
    /// Pretty output on, verbose output off.
    fn default() -> Self {
        let (pretty, verbose) = (true, false);
        Self { pretty, verbose }
    }
}
/// Errors produced while loading observability configuration.
#[derive(Debug)]
pub enum ConfigError {
    /// A config file could not be read.
    Io {
        /// The file path, rendered for display.
        path: String,
        /// The underlying I/O error, rendered as text.
        error: String,
    },
    /// The TOML document could not be deserialized; holds the parser's message.
    Parse(String),
}

impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ConfigError::Parse(reason) => write!(f, "Failed to parse config: {reason}"),
            ConfigError::Io { path, error: cause } => {
                write!(f, "Failed to read config file '{path}': {cause}")
            },
        }
    }
}

// Causes are stored as strings, so there is no underlying error to expose;
// the default `source()` (returning `None`) is used.
impl std::error::Error for ConfigError {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ObservabilityConfig::default();
        assert!(config.enabled);
        assert_eq!(config.backend, "console");
        assert_eq!(config.max_depth, 10);
        assert!((config.sample_rate - 1.0).abs() < f64::EPSILON);
        assert!(config.fields.capture_tool_name);
        assert!(!config.fields.capture_client_ip);
    }

    #[test]
    fn test_from_toml() {
        let toml = r#"
[observability]
enabled = true
backend = "cloudwatch"
max_depth = 5
sample_rate = 0.5
[observability.fields]
capture_tool_name = true
capture_client_ip = true
[observability.cloudwatch]
namespace = "MyApp/MCP"
emf_enabled = true
"#;
        let config = ObservabilityConfig::from_toml(toml).unwrap();
        assert!(config.enabled);
        assert_eq!(config.backend, "cloudwatch");
        assert_eq!(config.max_depth, 5);
        assert!((config.sample_rate - 0.5).abs() < f64::EPSILON);
        assert!(config.fields.capture_tool_name);
        assert!(config.fields.capture_client_ip);
        assert_eq!(config.cloudwatch.namespace, "MyApp/MCP");
    }

    /// A document without an `[observability]` table falls back to defaults.
    #[test]
    fn test_from_toml_empty_document_uses_defaults() {
        let config = ObservabilityConfig::from_toml("").unwrap();
        assert!(config.enabled);
        assert_eq!(config.backend, "console");
        assert_eq!(config.max_depth, 10);
        assert!((config.sample_rate - 1.0).abs() < f64::EPSILON);
    }

    /// Malformed TOML surfaces as `ConfigError::Parse`, not a panic or a default.
    #[test]
    fn test_from_toml_rejects_malformed_input() {
        let result = ObservabilityConfig::from_toml("[observability");
        assert!(matches!(result, Err(ConfigError::Parse(_))));
    }

    #[test]
    fn test_disabled_config() {
        let config = ObservabilityConfig::disabled();
        assert!(!config.enabled);
    }

    #[test]
    fn test_development_config() {
        let config = ObservabilityConfig::development();
        assert!(config.enabled);
        assert_eq!(config.backend, "console");
        assert!(config.console.pretty);
    }

    #[test]
    fn test_production_config() {
        let config = ObservabilityConfig::production();
        assert!(config.enabled);
        assert_eq!(config.backend, "cloudwatch");
    }

    #[test]
    fn test_should_sample_always() {
        let config = ObservabilityConfig {
            sample_rate: 1.0,
            ..Default::default()
        };
        for _ in 0..100 {
            assert!(config.should_sample());
        }
    }

    #[test]
    fn test_should_sample_never() {
        let config = ObservabilityConfig {
            sample_rate: 0.0,
            ..Default::default()
        };
        for _ in 0..100 {
            assert!(!config.should_sample());
        }
    }

    /// Rates outside `[0, 1]` saturate instead of misbehaving.
    #[test]
    fn test_should_sample_out_of_range_rates() {
        let above = ObservabilityConfig {
            sample_rate: 2.5,
            ..Default::default()
        };
        let below = ObservabilityConfig {
            sample_rate: -1.0,
            ..Default::default()
        };
        for _ in 0..100 {
            assert!(above.should_sample());
            assert!(!below.should_sample());
        }
    }

    #[test]
    fn test_tracing_config_defaults() {
        let config = TracingConfig::default();
        assert!(config.enabled);
        assert_eq!(config.trace_header, "X-Trace-ID");
        assert_eq!(config.trace_field, "_trace");
    }

    #[test]
    fn test_metrics_config_defaults() {
        let config = MetricsConfig::default();
        assert!(config.request_count);
        assert!(config.request_duration);
        assert!(config.error_rate);
        assert_eq!(config.prefix, "mcp");
    }
}