use std::{
collections::HashMap,
path::{Path, PathBuf},
};
use serde::{Deserialize, Serialize};
/// Upper bound on the number of entries in `GatewayConfig::subgraphs`;
/// `validate_config` reports `ConfigError::TooManySubgraphs` above this.
const MAX_SUBGRAPHS: usize = 64;
/// Upper bound on `TimeoutConfig::total_request_ms` (300 000 ms = 5 minutes);
/// `validate_config` reports `ConfigError::TotalTimeoutTooLarge` above this.
const MAX_TOTAL_REQUEST_MS: u64 = 300_000;
/// Top-level shape of the configuration file.
///
/// `load_config` deserializes the whole TOML document into this wrapper and
/// then hands back only the inner `gateway` value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfigFile {
/// Contents of the required top-level `gateway` table.
pub gateway: GatewayConfig,
}
/// Gateway settings, deserialized from the `gateway` table of the config file.
///
/// Every field carries a serde default, so any of them may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
/// Defaults to `"127.0.0.1:4000"` when omitted.
/// NOTE(review): presumably the socket address the gateway binds to — confirm against the server code.
#[serde(default = "default_listen")]
pub listen: String,
/// Defaults to `false` when omitted.
/// NOTE(review): presumably toggles an interactive playground UI — confirm against the server code.
#[serde(default)]
pub playground: bool,
/// Subgraphs keyed by name. Defaults to an empty map, which `validate_config`
/// rejects with `ConfigError::NoSubgraphs`.
#[serde(default)]
pub subgraphs: HashMap<String, SubgraphConfig>,
/// Request timeouts; falls back to `TimeoutConfig::default()` when omitted.
#[serde(default)]
pub timeouts: TimeoutConfig,
/// Circuit-breaker settings; falls back to `CircuitBreakerConfig::default()` when omitted.
#[serde(default)]
pub circuit_breaker: CircuitBreakerConfig,
}
/// Configuration of a single subgraph (one value of `GatewayConfig::subgraphs`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubgraphConfig {
/// URL of the subgraph. Required; `validate_config` checks that it parses as a URL.
pub url: String,
/// Optional path to a schema file for this subgraph. When present it is checked by
/// `validate_config`, which resolves relative paths against its `base_dir` argument.
pub schema: Option<PathBuf>,
}
/// Timeout settings, all expressed in milliseconds (per the `_ms` field suffixes).
///
/// `validate_config` requires `subgraph_request_ms <= total_request_ms <= MAX_TOTAL_REQUEST_MS`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeoutConfig {
/// Defaults to 5 000 ms when omitted.
/// NOTE(review): presumably the budget for one request to one subgraph — confirm against the caller.
#[serde(default = "default_subgraph_timeout")]
pub subgraph_request_ms: u64,
/// Defaults to 30 000 ms when omitted.
/// NOTE(review): presumably the budget for a whole gateway request — confirm against the caller.
#[serde(default = "default_total_timeout")]
pub total_request_ms: u64,
}
/// Circuit-breaker settings. Nothing in this file validates these values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
/// Defaults to 5 when omitted.
/// NOTE(review): presumably the failure count that trips the breaker — confirm against the breaker implementation.
#[serde(default = "default_failure_threshold")]
pub failure_threshold: u32,
/// Defaults to 30 000 ms when omitted.
/// NOTE(review): presumably how long the breaker stays open before retrying — confirm against the breaker implementation.
#[serde(default = "default_recovery_timeout")]
pub recovery_timeout_ms: u64,
}
/// Serde default for `GatewayConfig::listen`: the loopback address on port 4000.
fn default_listen() -> String {
    String::from("127.0.0.1:4000")
}
/// Serde default for `TimeoutConfig::subgraph_request_ms`: 5 seconds, in milliseconds.
const fn default_subgraph_timeout() -> u64 {
    5 * 1_000
}
/// Serde default for `TimeoutConfig::total_request_ms`: 30 seconds, in milliseconds.
const fn default_total_timeout() -> u64 {
    30 * 1_000
}
/// Serde default for `CircuitBreakerConfig::failure_threshold`.
const fn default_failure_threshold() -> u32 {
    5_u32
}
/// Serde default for `CircuitBreakerConfig::recovery_timeout_ms`: 30 seconds, in milliseconds.
const fn default_recovery_timeout() -> u64 {
    30 * 1_000
}
impl Default for TimeoutConfig {
fn default() -> Self {
Self {
subgraph_request_ms: default_subgraph_timeout(),
total_request_ms: default_total_timeout(),
}
}
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: default_failure_threshold(),
recovery_timeout_ms: default_recovery_timeout(),
}
}
}
/// A single problem found while validating a gateway configuration.
///
/// `Display` renders a human-readable, single-line description of the problem.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConfigError {
    /// The subgraph map is empty.
    NoSubgraphs,
    /// More subgraphs are configured than the allowed maximum.
    TooManySubgraphs {
        /// Number of subgraphs found in the configuration.
        count: usize,
        /// Largest number of subgraphs accepted.
        max: usize,
    },
    /// A subgraph URL failed validation.
    InvalidUrl {
        /// Name (map key) of the offending subgraph.
        name: String,
        /// The URL exactly as written in the configuration.
        url: String,
        /// Description of why the URL was rejected.
        reason: String,
    },
    /// A subgraph's schema file could not be found.
    SchemaFileNotFound {
        /// Name (map key) of the offending subgraph.
        name: String,
        /// The path that was checked.
        path: PathBuf,
    },
    /// `total_request_ms` is smaller than `subgraph_request_ms`.
    TotalTimeoutTooSmall,
    /// `total_request_ms` is above the allowed maximum.
    TotalTimeoutTooLarge {
        /// The configured value, in milliseconds.
        ms: u64,
        /// Largest accepted value, in milliseconds.
        max: u64,
    },
}

impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages: written verbatim, no interpolation needed.
            ConfigError::NoSubgraphs => {
                f.write_str("No subgraphs defined in [gateway.subgraphs]")
            }
            ConfigError::TotalTimeoutTooSmall => {
                f.write_str("total_request_ms must be >= subgraph_request_ms")
            }
            // Messages that interpolate the variant's fields.
            ConfigError::TooManySubgraphs { count, max } => {
                write!(f, "Too many subgraphs: {count} (max {max})")
            }
            ConfigError::TotalTimeoutTooLarge { ms, max } => {
                write!(f, "total_request_ms ({ms}ms) exceeds maximum ({max}ms)")
            }
            ConfigError::InvalidUrl { name, url, reason } => {
                write!(f, "Subgraph '{name}' has invalid URL '{url}': {reason}")
            }
            ConfigError::SchemaFileNotFound { name, path } => {
                write!(f, "Subgraph '{name}' schema file not found: {}", path.display())
            }
        }
    }
}

impl std::error::Error for ConfigError {}
pub fn load_config(path: &Path) -> anyhow::Result<GatewayConfig> {
let content = std::fs::read_to_string(path)?;
let file: GatewayConfigFile = toml::from_str(&content)?;
Ok(file.gateway)
}
/// Checks `config` for semantic problems and reports all of them at once.
///
/// Checks performed:
/// - at least one and at most `MAX_SUBGRAPHS` subgraphs are configured;
/// - every subgraph `url` parses as a URL;
/// - every subgraph `schema` path, when set, points at an existing file
///   (relative paths are resolved against `base_dir`);
/// - `subgraph_request_ms <= total_request_ms <= MAX_TOTAL_REQUEST_MS`.
///
/// # Errors
///
/// Returns every [`ConfigError`] found rather than stopping at the first one.
/// Per-subgraph errors are reported in ascending order of subgraph name, so the
/// output is stable from run to run.
pub fn validate_config(config: &GatewayConfig, base_dir: &Path) -> Result<(), Vec<ConfigError>> {
    let mut errors = Vec::new();

    let subgraph_count = config.subgraphs.len();
    if subgraph_count == 0 {
        errors.push(ConfigError::NoSubgraphs);
    }
    if subgraph_count > MAX_SUBGRAPHS {
        errors.push(ConfigError::TooManySubgraphs {
            count: subgraph_count,
            max: MAX_SUBGRAPHS,
        });
    }

    // `HashMap` iteration order is unspecified; sort by name so the same config
    // always yields the same error list. Keys are unique, so an unstable sort is fine.
    let mut subgraphs: Vec<_> = config.subgraphs.iter().collect();
    subgraphs.sort_unstable_by(|a, b| a.0.cmp(b.0));

    for (name, sg) in subgraphs {
        if let Err(e) = reqwest::Url::parse(&sg.url) {
            errors.push(ConfigError::InvalidUrl {
                name: name.clone(),
                url: sg.url.clone(),
                reason: e.to_string(),
            });
        }

        if let Some(schema_path) = &sg.schema {
            // `Path::join` returns its argument unchanged when that argument is
            // absolute, so no explicit relative/absolute branch is needed.
            let resolved = base_dir.join(schema_path);
            // `is_file` rather than `exists`: a directory at that path is not a
            // usable schema file.
            if !resolved.is_file() {
                errors.push(ConfigError::SchemaFileNotFound {
                    name: name.clone(),
                    path: resolved,
                });
            }
        }
    }

    let timeouts = &config.timeouts;
    if timeouts.total_request_ms < timeouts.subgraph_request_ms {
        errors.push(ConfigError::TotalTimeoutTooSmall);
    }
    if timeouts.total_request_ms > MAX_TOTAL_REQUEST_MS {
        errors.push(ConfigError::TotalTimeoutTooLarge {
            ms: timeouts.total_request_ms,
            max: MAX_TOTAL_REQUEST_MS,
        });
    }

    if errors.is_empty() {
        Ok(())
    } else {
        Err(errors)
    }
}