use std::future::Future;
use std::net::IpAddr;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use axum::Router;
use axum::body::Body;
use axum::http::Request;
use axum::response::Response;
use tokio::sync::broadcast;
use forge_core::CircuitBreakerClient;
use forge_core::cluster::{LeaderRole, NodeId, NodeInfo, NodeRole, NodeStatus};
use forge_core::config::{ForgeConfig, NodeRole as ConfigNodeRole};
use forge_core::error::{ForgeError, Result};
use forge_core::function::{ForgeMutation, ForgeQuery};
use forge_core::mcp::ForgeMcpTool;
use forge_runtime::migrations::{Migration, MigrationRunner, load_migrations_from_dir};
use forge_runtime::cluster::{
GracefulShutdown, HeartbeatConfig, HeartbeatLoop, LeaderConfig, LeaderElection, NodeRegistry,
ShutdownConfig,
};
use forge_runtime::cron::{CronRegistry, CronRunner, CronRunnerConfig};
use forge_runtime::daemon::{DaemonRegistry, DaemonRunner};
use forge_runtime::db::Database;
use forge_runtime::function::FunctionRegistry;
use forge_runtime::gateway::{AuthConfig, GatewayConfig as RuntimeGatewayConfig, GatewayServer};
use forge_runtime::jobs::{JobDispatcher, JobQueue, JobRegistry, Worker, WorkerConfig};
use forge_runtime::mcp::McpToolRegistry;
use forge_runtime::webhook::{WebhookRegistry, WebhookState, webhook_handler};
use forge_runtime::workflow::{
EventStore, WorkflowExecutor, WorkflowRegistry, WorkflowScheduler, WorkflowSchedulerConfig,
};
use tokio_util::sync::CancellationToken;
/// Signature of a user-supplied frontend fallback handler: receives the raw
/// HTTP request and returns a boxed future resolving to the response.
pub type FrontendHandler = fn(Request<Body>) -> Pin<Box<dyn Future<Output = Response> + Send>>;
/// Convenience re-exports for application code.
///
/// `use forge::prelude::*;` brings in the traits, context types, and common
/// third-party items needed to define queries, mutations, jobs, crons,
/// workflows, daemons, webhooks, and MCP tools.
pub mod prelude {
    // Common third-party types used throughout application handlers.
    pub use chrono::{DateTime, Utc};
    pub use uuid::Uuid;
    pub use serde::{Deserialize, Serialize};
    pub use serde_json;
    /// Shorthand for the timestamp type used across Forge APIs.
    pub type Timestamp = DateTime<Utc>;
    // Core traits and context types from `forge_core`.
    pub use forge_core::auth::TokenPair;
    pub use forge_core::cluster::NodeRole;
    pub use forge_core::config::ForgeConfig;
    pub use forge_core::cron::{CronContext, ForgeCron};
    pub use forge_core::daemon::{DaemonContext, ForgeDaemon};
    pub use forge_core::env::EnvAccess;
    pub use forge_core::error::{ForgeError, Result};
    pub use forge_core::function::{
        AuthContext, DbConn, ForgeMutation, ForgeQuery, MutationContext, QueryContext,
    };
    pub use forge_core::job::{ForgeJob, JobContext, JobPriority};
    pub use forge_core::mcp::{ForgeMcpTool, McpToolContext, McpToolResult};
    pub use forge_core::realtime::Delta;
    pub use forge_core::schema::{FieldDef, ModelMeta, SchemaRegistry, TableDef};
    pub use forge_core::schemars::JsonSchema;
    pub use forge_core::types::Upload;
    pub use forge_core::webhook::{ForgeWebhook, WebhookContext, WebhookResult, WebhookSignature};
    pub use forge_core::workflow::{ForgeWorkflow, WorkflowContext};
    // Re-export axum so applications can define custom routes without
    // depending on it directly.
    pub use axum;
    pub use crate::{Forge, ForgeBuilder};
}
/// The assembled Forge runtime: configuration plus every registry needed to
/// boot a node. Constructed via [`ForgeBuilder`]; driven by [`Forge::run`].
pub struct Forge {
    config: ForgeConfig,
    // Populated during `run()`; `None` until the database is connected.
    db: Option<Database>,
    // Placeholder until `run()` derives the real identity from `NodeInfo`.
    node_id: NodeId,
    function_registry: FunctionRegistry,
    mcp_registry: McpToolRegistry,
    job_registry: JobRegistry,
    // Arc-wrapped registries are shared with spawned runner tasks.
    cron_registry: Arc<CronRegistry>,
    workflow_registry: WorkflowRegistry,
    daemon_registry: Arc<DaemonRegistry>,
    webhook_registry: Arc<WebhookRegistry>,
    // Broadcast channel used to signal graceful shutdown to all subscribers.
    shutdown_tx: broadcast::Sender<()>,
    migrations_dir: PathBuf,
    // Migrations registered programmatically, applied after the directory ones.
    extra_migrations: Vec<Migration>,
    frontend_handler: Option<FrontendHandler>,
    // Deferred factory so custom routes can be built with the live pool.
    custom_routes_factory: Option<Box<dyn FnOnce(sqlx::PgPool) -> Router + Send + Sync>>,
}
impl Forge {
/// Entry point: create a fresh [`ForgeBuilder`].
pub fn builder() -> ForgeBuilder {
    ForgeBuilder::new()
}
/// This node's cluster identity (finalized during [`Forge::run`]).
pub fn node_id(&self) -> NodeId {
    self.node_id
}
/// Borrow the active configuration.
pub fn config(&self) -> &ForgeConfig {
    &self.config
}
/// Borrow the query/mutation registry.
pub fn function_registry(&self) -> &FunctionRegistry {
    &self.function_registry
}
/// Mutably borrow the query/mutation registry for late registration.
pub fn function_registry_mut(&mut self) -> &mut FunctionRegistry {
    &mut self.function_registry
}
/// Mutably borrow the MCP tool registry for late registration.
pub fn mcp_registry_mut(&mut self) -> &mut McpToolRegistry {
    &mut self.mcp_registry
}
/// Register an MCP tool after construction; returns `&mut Self` for chaining.
pub fn register_mcp_tool<T: ForgeMcpTool>(&mut self) -> &mut Self {
    self.mcp_registry.register::<T>();
    self
}
/// Borrow the background-job registry.
pub fn job_registry(&self) -> &JobRegistry {
    &self.job_registry
}
/// Mutably borrow the background-job registry for late registration.
pub fn job_registry_mut(&mut self) -> &mut JobRegistry {
    &mut self.job_registry
}
/// Shared handle to the cron registry (cheap refcount bump, no deep copy).
pub fn cron_registry(&self) -> Arc<CronRegistry> {
    Arc::clone(&self.cron_registry)
}
/// Borrow the workflow registry.
pub fn workflow_registry(&self) -> &WorkflowRegistry {
    &self.workflow_registry
}
/// Mutably borrow the workflow registry for late registration.
pub fn workflow_registry_mut(&mut self) -> &mut WorkflowRegistry {
    &mut self.workflow_registry
}
/// Shared handle to the daemon registry (cheap refcount bump, no deep copy).
pub fn daemon_registry(&self) -> Arc<DaemonRegistry> {
    Arc::clone(&self.daemon_registry)
}
/// Shared handle to the webhook registry (cheap refcount bump, no deep copy).
pub fn webhook_registry(&self) -> Arc<WebhookRegistry> {
    Arc::clone(&self.webhook_registry)
}
/// Persist every registered workflow definition to
/// `forge_workflow_definitions`, enforcing that a (name, version) pair never
/// changes its signature once recorded: existing rows get a status refresh,
/// new rows are inserted.
///
/// # Errors
/// Returns `ForgeError::Config` on a signature mismatch for an existing
/// (name, version), and `ForgeError::Database` on any query failure.
async fn persist_workflow_definitions(&self, pool: &sqlx::PgPool) -> Result<()> {
    for info in self.workflow_registry.definitions() {
        // Map registry flags onto the persisted status string.
        // NOTE(review): a definition that is neither active nor deprecated
        // also falls through to "active" — confirm that is intentional.
        let status = if info.is_active {
            "active"
        } else if info.is_deprecated {
            "deprecated"
        } else {
            "active"
        };
        // Look up the previously-recorded signature, if any.
        let existing = sqlx::query!(
            r#"
            SELECT workflow_signature FROM forge_workflow_definitions
            WHERE workflow_name = $1 AND workflow_version = $2
            "#,
            info.name,
            info.version,
        )
        .fetch_optional(pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        if let Some(row) = existing {
            // Same version must mean same contract: reject silent redefinition.
            if row.workflow_signature != info.signature {
                return Err(ForgeError::Config(format!(
                    "Workflow '{}' version '{}' has a different signature than previously registered. \
                    Persisted contract changed under the same version. \
                    Expected signature: {}, got: {}. \
                    Create a new version instead of modifying the existing one.",
                    info.name, info.version, row.workflow_signature, info.signature
                )));
            }
            // Signature unchanged: only the status may move (e.g. deprecation).
            sqlx::query!(
                "UPDATE forge_workflow_definitions SET status = $3 WHERE workflow_name = $1 AND workflow_version = $2",
                info.name,
                info.version,
                status,
            )
            .execute(pool)
            .await
            .map_err(|e| ForgeError::Database(e.to_string()))?;
        } else {
            // First time this (name, version) is seen: insert it.
            sqlx::query!(
                r#"
                INSERT INTO forge_workflow_definitions (workflow_name, workflow_version, workflow_signature, status)
                VALUES ($1, $2, $3, $4)
                "#,
                info.name,
                info.version,
                info.signature,
                status,
            )
            .execute(pool)
            .await
            .map_err(|e| ForgeError::Database(e.to_string()))?;
        }
        tracing::debug!(
            workflow = info.name,
            version = info.version,
            signature = info.signature,
            status = status,
            "Workflow definition registered"
        );
    }
    Ok(())
}
/// Boot this node and run until a shutdown signal arrives.
///
/// Startup order: telemetry → database + migrations → workflow definition
/// persistence → node identity/registration → leader election → background
/// tasks (heartbeat, worker, cron, workflow scheduler, daemons) → gateway
/// HTTP server. On ctrl-c or [`Forge::shutdown`], tears everything down in
/// roughly reverse order.
///
/// # Errors
/// Fails on database connection/migration errors, invalid gateway/auth
/// configuration, or a workflow signature conflict. Most cluster-level
/// failures (node registration, leadership) are logged and tolerated.
pub async fn run(mut self) -> Result<()> {
    // --- Telemetry -------------------------------------------------------
    let telemetry_config = forge_runtime::TelemetryConfig::from_observability_config(
        &self.config.observability,
        &self.config.project.name,
        &self.config.project.version,
    );
    let telemetry_result = forge_runtime::init_telemetry(
        &telemetry_config,
        &self.config.project.name,
        &self.config.observability.log_level,
    );
    match &telemetry_result {
        Ok(true) | Ok(false) => {
            tracing::debug!(
                endpoint = %telemetry_config.otlp_endpoint,
                traces = telemetry_config.enable_traces,
                metrics = telemetry_config.enable_metrics,
                logs = telemetry_config.enable_logs,
                sampling = telemetry_config.sampling_ratio,
                "Telemetry initialized"
            );
        }
        // tracing may not be set up if init failed, so report via stderr.
        Err(e) => eprintln!("forge: failed to initialize telemetry: {e}"),
    }
    // --- Database --------------------------------------------------------
    tracing::debug!("Connecting to database");
    let db =
        Database::from_config_with_service(&self.config.database, &self.config.project.name)
            .await?;
    let pool = db.primary().clone();
    let jobs_pool = db.jobs_pool().clone();
    let observability_pool = db.observability_pool().clone();
    // Tie the optional health-monitor task to the shutdown broadcast so it
    // stops when the node shuts down.
    if let Some(handle) = db.start_health_monitor() {
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        tokio::spawn(async move {
            tokio::select! {
                _ = shutdown_rx.recv() => {}
                _ = handle => {}
            }
        });
    }
    self.db = Some(db);
    tracing::debug!("Database connected");
    // --- Migrations ------------------------------------------------------
    // Directory migrations first, then programmatically-registered ones.
    let runner = MigrationRunner::new(pool.clone());
    let mut user_migrations = load_migrations_from_dir(&self.migrations_dir)?;
    user_migrations.extend(self.extra_migrations.clone());
    runner.run(user_migrations).await?;
    tracing::debug!("Migrations applied");
    if !self.workflow_registry.is_empty() {
        self.persist_workflow_definitions(&pool).await?;
    }
    // --- Node identity ---------------------------------------------------
    let hostname = get_hostname();
    // HOST may be a bind address; anything unparsable falls back to 0.0.0.0.
    let ip_address: IpAddr = std::env::var("HOST")
        .unwrap_or_else(|_| "0.0.0.0".to_string())
        .parse()
        .unwrap_or_else(|_| "0.0.0.0".parse().expect("valid IP literal"));
    // A PORT env var overrides the configured gateway port (PaaS convention).
    if let Ok(port_str) = std::env::var("PORT")
        && let Ok(port) = port_str.parse::<u16>()
    {
        self.config.gateway.port = port;
    }
    let roles: Vec<NodeRole> = self
        .config
        .node
        .roles
        .iter()
        .map(config_role_to_node_role)
        .collect();
    let node_info = NodeInfo::new_local(
        hostname,
        ip_address,
        self.config.gateway.port,
        self.config.gateway.grpc_port,
        roles.clone(),
        self.config.node.worker_capabilities.clone(),
        env!("CARGO_PKG_VERSION").to_string(),
    );
    let node_id = node_info.id;
    self.node_id = node_id;
    // Registration is best-effort: cluster tables may not exist yet.
    let node_registry = Arc::new(NodeRegistry::new(pool.clone(), node_info));
    if let Err(e) = node_registry.register().await {
        tracing::debug!("Failed to register node (tables may not exist): {}", e);
    }
    if let Err(e) = node_registry.set_status(NodeStatus::Active).await {
        tracing::debug!("Failed to set node status: {}", e);
    }
    // --- Leader election (scheduler nodes only) --------------------------
    let leader_election = if roles.contains(&NodeRole::Scheduler) {
        let election = Arc::new(LeaderElection::new(
            pool.clone(),
            node_id,
            LeaderRole::Scheduler,
            LeaderConfig::default(),
        ));
        if let Err(e) = election.try_become_leader().await {
            tracing::debug!("Failed to acquire leadership: {}", e);
        }
        Some(election)
    } else {
        None
    };
    // Coordinates deregistration + leadership release at shutdown.
    let shutdown = Arc::new(GracefulShutdown::new(
        node_registry.clone(),
        leader_election.clone(),
        ShutdownConfig::default(),
    ));
    // Outbound HTTP client shared by cron/daemon/workflow runners.
    let http_client = CircuitBreakerClient::with_defaults(reqwest::Client::new());
    let mut handles = Vec::new();
    // --- Heartbeat loop --------------------------------------------------
    {
        let heartbeat_pool = pool.clone();
        let heartbeat_node_id = node_id;
        let config = HeartbeatConfig::from_cluster_config(&self.config.cluster);
        handles.push(tokio::spawn(async move {
            let heartbeat = HeartbeatLoop::new(heartbeat_pool, heartbeat_node_id, config);
            heartbeat.run().await;
        }));
    }
    // Keep renewing/contesting leadership in the background.
    if let Some(ref election) = leader_election {
        let election = election.clone();
        handles.push(tokio::spawn(async move {
            election.run().await;
        }));
    }
    // --- Job worker (worker nodes only) ----------------------------------
    if roles.contains(&NodeRole::Worker) {
        let job_queue = JobQueue::new(jobs_pool.clone());
        let worker_config = WorkerConfig {
            id: Some(node_id.as_uuid()),
            capabilities: self.config.node.worker_capabilities.clone(),
            max_concurrent: self.config.worker.max_concurrent_jobs,
            poll_interval: Duration::from_millis(self.config.worker.poll_interval_ms),
            ..Default::default()
        };
        let mut worker = Worker::new(
            worker_config,
            job_queue,
            self.job_registry.clone(),
            jobs_pool.clone(),
        );
        handles.push(tokio::spawn(async move {
            if let Err(e) = worker.run().await {
                tracing::error!("Worker error: {}", e);
            }
        }));
        tracing::debug!("Job worker started");
    }
    // --- Cron scheduler (scheduler nodes only) ---------------------------
    if roles.contains(&NodeRole::Scheduler) {
        let cron_registry = self.cron_registry.clone();
        let cron_pool = jobs_pool.clone();
        let cron_http = http_client.clone();
        let cron_leader_election = leader_election.clone();
        let cron_config = CronRunnerConfig {
            poll_interval: Duration::from_secs(1),
            node_id: node_id.as_uuid(),
            // No election present means single-node mode: act as leader.
            is_leader: cron_leader_election.is_none(),
            leader_election: cron_leader_election,
            run_stale_threshold: Duration::from_secs(15 * 60),
        };
        let cron_runner = CronRunner::new(cron_registry, cron_pool, cron_http, cron_config);
        handles.push(tokio::spawn(async move {
            if let Err(e) = cron_runner.run().await {
                tracing::error!("Cron runner error: {}", e);
            }
        }));
        tracing::debug!("Cron scheduler started");
    }
    // --- Workflow scheduler (scheduler nodes only) -----------------------
    // Cancellation token lets us stop the scheduler promptly at shutdown.
    let workflow_shutdown_token = CancellationToken::new();
    if roles.contains(&NodeRole::Scheduler) {
        let scheduler_executor = Arc::new(WorkflowExecutor::new(
            Arc::new(self.workflow_registry.clone()),
            jobs_pool.clone(),
            http_client.clone(),
        ));
        let event_store = Arc::new(EventStore::new(jobs_pool.clone()));
        let scheduler = WorkflowScheduler::new(
            jobs_pool.clone(),
            scheduler_executor,
            event_store,
            WorkflowSchedulerConfig::default(),
        );
        let shutdown_token = workflow_shutdown_token.clone();
        handles.push(tokio::spawn(async move {
            scheduler.run(shutdown_token).await;
        }));
        tracing::debug!("Workflow scheduler started");
    }
    // Dispatchers shared by the gateway, daemons, and webhooks.
    let job_queue_for_dispatch = JobQueue::new(jobs_pool.clone());
    let job_dispatcher = Arc::new(JobDispatcher::new(
        job_queue_for_dispatch,
        self.job_registry.clone(),
    ));
    let workflow_executor = Arc::new(WorkflowExecutor::new(
        Arc::new(self.workflow_registry.clone()),
        jobs_pool.clone(),
        http_client.clone(),
    ));
    // --- Daemon runner (scheduler nodes with registered daemons) ---------
    if roles.contains(&NodeRole::Scheduler) && !self.daemon_registry.is_empty() {
        let daemon_registry = self.daemon_registry.clone();
        let daemon_pool = jobs_pool.clone();
        let daemon_http = http_client.clone();
        let daemon_shutdown_rx = self.shutdown_tx.subscribe();
        let daemon_runner = DaemonRunner::new(
            daemon_registry,
            daemon_pool,
            daemon_http,
            node_id.as_uuid(),
            daemon_shutdown_rx,
        )
        .with_job_dispatch(job_dispatcher.clone())
        .with_workflow_dispatch(workflow_executor.clone());
        handles.push(tokio::spawn(async move {
            if let Err(e) = daemon_runner.run().await {
                tracing::error!("Daemon runner error: {}", e);
            }
        }));
        tracing::debug!("Daemon runner started");
    }
    // --- Gateway HTTP server (gateway nodes only) ------------------------
    let mut reactor_handle = None;
    if roles.contains(&NodeRole::Gateway) {
        // Project the user config onto the runtime gateway config.
        let gateway_config = RuntimeGatewayConfig {
            port: self.config.gateway.port,
            max_connections: self.config.gateway.max_connections,
            sse_max_sessions: self.config.gateway.sse_max_sessions,
            request_timeout_secs: self.config.gateway.request_timeout_secs,
            // Listing any origin implicitly enables CORS.
            cors_enabled: self.config.gateway.cors_enabled
                || !self.config.gateway.cors_origins.is_empty(),
            cors_origins: self.config.gateway.cors_origins.clone(),
            auth: AuthConfig::from_forge_config(&self.config.auth)
                .map_err(|e| ForgeError::Config(e.to_string()))?,
            mcp: self.config.mcp.clone(),
            quiet_routes: self.config.gateway.quiet_routes.clone(),
            max_body_size_bytes: self.config.gateway.max_body_size_bytes()?,
            max_file_size_bytes: self.config.gateway.max_file_size_bytes()?,
            token_ttl: forge_core::AuthTokenTtl {
                access_token_secs: self.config.auth.access_token_ttl_secs(),
                refresh_token_days: self.config.auth.refresh_token_ttl_days(),
            },
            project_name: self.config.project.name.clone(),
        };
        let db_ref = self
            .db
            .clone()
            .ok_or_else(|| ForgeError::Internal("Database not initialized".into()))?;
        let mut gateway = GatewayServer::new(
            gateway_config,
            self.function_registry.clone(),
            db_ref.clone(),
        )
        .with_job_dispatcher(job_dispatcher.clone())
        .with_workflow_dispatcher(workflow_executor.clone())
        .with_mcp_registry(self.mcp_registry.clone());
        // Optional analytics/diagnostics signal collection.
        if self.config.signals.enabled {
            let signals_pool = std::sync::Arc::new(db_ref.analytics_pool().clone());
            let collector = forge_runtime::signals::SignalsCollector::spawn(
                signals_pool.clone(),
                self.config.signals.batch_size,
                std::time::Duration::from_millis(self.config.signals.flush_interval_ms),
            );
            // City-level GeoIP only when an MMDB path is configured.
            let geoip = match &self.config.signals.geoip_db_path {
                Some(path) => {
                    let resolver = forge_runtime::signals::geoip::GeoIpResolver::from_mmdb(
                        std::path::Path::new(path),
                    )?;
                    tracing::info!(path, "GeoIP: MaxMind MMDB loaded (city-level)");
                    resolver
                }
                None => forge_runtime::signals::geoip::GeoIpResolver::new(),
            };
            gateway = gateway
                .with_signals_collector(collector)
                .with_signals_anonymize_ip(self.config.signals.anonymize_ip)
                .with_signals_geoip(geoip);
            forge_runtime::signals::session::spawn_session_reaper(
                signals_pool,
                self.config.signals.session_timeout_mins,
            );
            tracing::info!("Signals enabled (analytics + diagnostics)");
        }
        // User-supplied routes, built with the live primary pool.
        if let Some(factory) = self.custom_routes_factory.take() {
            gateway = gateway.with_custom_routes(factory(pool.clone()));
            tracing::debug!("Custom routes merged into gateway middleware stack");
        }
        // Realtime reactor: failure to start is non-fatal.
        let reactor = gateway.reactor();
        if let Err(e) = reactor.start().await {
            tracing::error!("Failed to start reactor: {}", e);
        } else {
            tracing::debug!("Reactor started");
            reactor_handle = Some(reactor);
        }
        let api_router = gateway.router();
        let mut router = Router::new().nest("/_api", api_router);
        // Webhook routes, when any webhooks are registered.
        if !self.webhook_registry.is_empty() {
            use axum::routing::post;
            use tower_http::cors::{Any, CorsLayer};
            let webhook_state = Arc::new(
                WebhookState::new(self.webhook_registry.clone(), pool.clone())
                    .with_job_dispatcher(job_dispatcher.clone()),
            );
            // Mirror the gateway's CORS policy on the webhook routes.
            let webhook_cors = if self.config.gateway.cors_enabled
                || !self.config.gateway.cors_origins.is_empty()
            {
                if self.config.gateway.cors_origins.iter().any(|o| o == "*") {
                    CorsLayer::new()
                        .allow_origin(Any)
                        .allow_methods(Any)
                        .allow_headers(Any)
                } else {
                    use axum::http::Method;
                    let origins: Vec<_> = self
                        .config
                        .gateway
                        .cors_origins
                        .iter()
                        .filter_map(|o| o.parse().ok())
                        .collect();
                    CorsLayer::new()
                        .allow_origin(origins)
                        .allow_methods([
                            Method::GET,
                            Method::POST,
                            Method::PUT,
                            Method::DELETE,
                            Method::PATCH,
                            Method::OPTIONS,
                        ])
                        .allow_headers([
                            axum::http::header::CONTENT_TYPE,
                            axum::http::header::AUTHORIZATION,
                            axum::http::header::ACCEPT,
                            axum::http::HeaderName::from_static("x-webhook-signature"),
                            axum::http::HeaderName::from_static("x-idempotency-key"),
                        ])
                        .allow_credentials(true)
                }
            } else {
                CorsLayer::new()
            };
            // NOTE(review): webhook body limit is hard-coded to 1 MiB rather
            // than using gateway.max_body_size_bytes() — confirm intentional.
            let webhook_router = Router::new()
                .route("/{*path}", post(webhook_handler).with_state(webhook_state))
                .layer(axum::extract::DefaultBodyLimit::max(1024 * 1024))
                .layer(
                    tower::ServiceBuilder::new()
                        // Timeout/load-shed errors are mapped to HTTP statuses.
                        .layer(axum::error_handling::HandleErrorLayer::new(
                            |err: tower::BoxError| async move {
                                if err.is::<tower::timeout::error::Elapsed>() {
                                    return (
                                        axum::http::StatusCode::REQUEST_TIMEOUT,
                                        "Request timed out",
                                    );
                                }
                                (
                                    axum::http::StatusCode::SERVICE_UNAVAILABLE,
                                    "Server overloaded",
                                )
                            },
                        ))
                        .layer(tower::limit::ConcurrencyLimitLayer::new(
                            self.config.gateway.max_connections,
                        ))
                        .layer(tower::timeout::TimeoutLayer::new(Duration::from_secs(
                            self.config.gateway.request_timeout_secs,
                        ))),
                )
                .layer(webhook_cors);
            router = router.nest("/_api/webhooks", webhook_router);
            tracing::debug!(
                webhooks = ?self.webhook_registry.paths().collect::<Vec<_>>(),
                "Webhook routes registered"
            );
        }
        // OAuth discovery endpoints for MCP clients.
        if self.config.mcp.enabled {
            use axum::routing::get;
            if let Some((oauth_api_router, oauth_state)) = gateway.oauth_router() {
                router = router.nest("/_api", oauth_api_router);
                router = router
                    .route(
                        "/.well-known/oauth-authorization-server",
                        get(forge_runtime::gateway::oauth::well_known_oauth_metadata)
                            .with_state(oauth_state.clone()),
                    )
                    .route(
                        "/.well-known/oauth-protected-resource",
                        get(forge_runtime::gateway::oauth::well_known_resource_metadata)
                            .with_state(oauth_state),
                    );
                tracing::info!("OAuth 2.1 endpoints enabled for MCP");
            } else {
                // Serve explicit 404s so clients stop probing for OAuth.
                async fn oauth_not_supported() -> impl axum::response::IntoResponse {
                    (
                        axum::http::StatusCode::NOT_FOUND,
                        axum::Json(serde_json::json!({
                            "error": "oauth_not_supported",
                            "error_description": "This server does not support OAuth. Connect without authentication."
                        })),
                    )
                }
                router = router
                    .route(
                        "/.well-known/oauth-authorization-server",
                        get(oauth_not_supported),
                    )
                    .route(
                        "/.well-known/oauth-protected-resource",
                        get(oauth_not_supported),
                    );
            }
        }
        // Everything not matched above falls through to the frontend.
        if let Some(handler) = self.frontend_handler {
            use axum::routing::get;
            router = router.fallback(get(handler));
            tracing::debug!("Frontend handler enabled");
        }
        let addr = gateway.addr();
        handles.push(tokio::spawn(async move {
            tracing::debug!(addr = %addr, "Gateway server binding");
            let listener = tokio::net::TcpListener::bind(addr)
                .await
                .expect("Failed to bind");
            if let Err(e) = axum::serve(listener, router).await {
                tracing::error!("Gateway server error: {}", e);
            }
        }));
    }
    // --- Startup summary -------------------------------------------------
    tracing::info!(
        queries = self.function_registry.queries().count(),
        mutations = self.function_registry.mutations().count(),
        jobs = self.job_registry.len(),
        crons = self.cron_registry.len(),
        workflows = self.workflow_registry.len(),
        daemons = self.daemon_registry.len(),
        webhooks = self.webhook_registry.len(),
        mcp_tools = self.mcp_registry.len(),
        "Functions registered"
    );
    // Periodic DB pool metrics; task ends when the runtime shuts down.
    {
        let metrics_pool = observability_pool;
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(Duration::from_secs(15)).await;
                forge_runtime::observability::record_pool_metrics(&metrics_pool);
            }
        });
    }
    let role_names: Vec<&str> = roles.iter().map(|r| r.as_str()).collect();
    let capabilities = &self.config.node.worker_capabilities;
    tracing::info!(
        node_id = %node_id,
        project = %self.config.project.name,
        version = env!("CARGO_PKG_VERSION"),
        roles = ?role_names,
        worker_capabilities = ?capabilities,
        port = self.config.gateway.port,
        db_pool_size = self.config.database.pool_size,
        cluster_discovery = ?self.config.cluster.discovery,
        observability = self.config.observability.enabled,
        mcp = self.config.mcp.enabled,
        "Forge started"
    );
    // --- Wait for shutdown -----------------------------------------------
    let mut shutdown_rx = self.shutdown_tx.subscribe();
    tokio::select! {
        _ = tokio::signal::ctrl_c() => {
            tracing::debug!("Received ctrl-c");
        }
        _ = shutdown_rx.recv() => {
            tracing::debug!("Received shutdown notification");
        }
    }
    // --- Teardown: scheduler → cluster state → election → reactor → DB ---
    tracing::debug!("Graceful shutdown starting");
    workflow_shutdown_token.cancel();
    if let Err(e) = shutdown.shutdown().await {
        tracing::warn!(error = %e, "Shutdown error");
    }
    if let Some(ref election) = leader_election {
        election.stop();
    }
    if let Some(ref reactor) = reactor_handle {
        reactor.stop();
    }
    if let Some(ref db) = self.db {
        db.close().await;
    }
    forge_runtime::shutdown_telemetry();
    tracing::info!("Forge stopped");
    Ok(())
}
/// Broadcast the shutdown signal to every subscriber (including `run`).
pub fn shutdown(&self) {
    // `send` only errors when no receiver is alive; that is harmless here.
    self.shutdown_tx.send(()).ok();
}
}
/// Staged configuration and registries for constructing a [`Forge`].
/// All registries start empty; `config` must be supplied before `build()`.
pub struct ForgeBuilder {
    // Required: `build()` fails without a configuration.
    config: Option<ForgeConfig>,
    function_registry: FunctionRegistry,
    mcp_registry: McpToolRegistry,
    job_registry: JobRegistry,
    cron_registry: CronRegistry,
    workflow_registry: WorkflowRegistry,
    daemon_registry: DaemonRegistry,
    webhook_registry: WebhookRegistry,
    // Defaults to "./migrations".
    migrations_dir: PathBuf,
    extra_migrations: Vec<Migration>,
    frontend_handler: Option<FrontendHandler>,
    // Deferred so the routes can be built against the live pool at startup.
    custom_routes_factory: Option<Box<dyn FnOnce(sqlx::PgPool) -> Router + Send + Sync>>,
}
impl ForgeBuilder {
/// Create a builder with empty registries and default settings.
pub fn new() -> Self {
    Self {
        config: None,
        function_registry: FunctionRegistry::new(),
        mcp_registry: McpToolRegistry::new(),
        job_registry: JobRegistry::new(),
        cron_registry: CronRegistry::new(),
        workflow_registry: WorkflowRegistry::new(),
        daemon_registry: DaemonRegistry::new(),
        webhook_registry: WebhookRegistry::new(),
        // Conventional default relative to the working directory.
        migrations_dir: PathBuf::from("migrations"),
        extra_migrations: Vec::new(),
        frontend_handler: None,
        custom_routes_factory: None,
    }
}
/// Override the directory migrations are loaded from (default: `migrations`).
pub fn migrations_dir(mut self, path: impl Into<PathBuf>) -> Self {
    self.migrations_dir = path.into();
    self
}
/// Register an inline migration, applied after the directory migrations.
pub fn migration(mut self, name: impl Into<String>, sql: impl Into<String>) -> Self {
    self.extra_migrations.push(Migration::new(name, sql));
    self
}
/// Install a fallback handler for requests not matched by any API route.
pub fn frontend_handler(mut self, handler: FrontendHandler) -> Self {
    self.frontend_handler = Some(handler);
    self
}
/// Register a factory for extra axum routes; it is invoked at startup with
/// the live primary pool and its router is merged into the gateway.
pub fn custom_routes<F>(mut self, f: F) -> Self
where
    F: FnOnce(sqlx::PgPool) -> Router + Send + Sync + 'static,
{
    self.custom_routes_factory = Some(Box::new(f));
    self
}
/// Populate every registry from the crate's auto-registration machinery
/// (items discovered via `crate::auto_register`).
pub fn auto_register(mut self) -> Self {
    crate::auto_register::auto_register_all(
        &mut self.function_registry,
        &mut self.job_registry,
        &mut self.cron_registry,
        &mut self.workflow_registry,
        &mut self.daemon_registry,
        &mut self.webhook_registry,
        &mut self.mcp_registry,
    );
    self
}
/// Supply the configuration (required before `build()`).
pub fn config(mut self, config: ForgeConfig) -> Self {
    self.config = Some(config);
    self
}
/// Mutably borrow the query/mutation registry for manual registration.
pub fn function_registry_mut(&mut self) -> &mut FunctionRegistry {
    &mut self.function_registry
}
/// Mutably borrow the job registry for manual registration.
pub fn job_registry_mut(&mut self) -> &mut JobRegistry {
    &mut self.job_registry
}
/// Mutably borrow the MCP tool registry for manual registration.
pub fn mcp_registry_mut(&mut self) -> &mut McpToolRegistry {
    &mut self.mcp_registry
}
/// Register an MCP tool (builder-style, consumes and returns `self`).
pub fn register_mcp_tool<T: ForgeMcpTool>(mut self) -> Self {
    self.mcp_registry.register::<T>();
    self
}
/// Mutably borrow the cron registry for manual registration.
pub fn cron_registry_mut(&mut self) -> &mut CronRegistry {
    &mut self.cron_registry
}
/// Mutably borrow the workflow registry for manual registration.
pub fn workflow_registry_mut(&mut self) -> &mut WorkflowRegistry {
    &mut self.workflow_registry
}
/// Mutably borrow the daemon registry for manual registration.
pub fn daemon_registry_mut(&mut self) -> &mut DaemonRegistry {
    &mut self.daemon_registry
}
/// Mutably borrow the webhook registry for manual registration.
pub fn webhook_registry_mut(&mut self) -> &mut WebhookRegistry {
    &mut self.webhook_registry
}
/// Register a query handler. Args must deserialize from the wire; output
/// must serialize back.
pub fn register_query<Q: ForgeQuery>(mut self) -> Self
where
    Q::Args: serde::de::DeserializeOwned + Send + 'static,
    Q::Output: serde::Serialize + Send + 'static,
{
    self.function_registry.register_query::<Q>();
    self
}
/// Register a mutation handler. Args must deserialize from the wire; output
/// must serialize back.
pub fn register_mutation<M: ForgeMutation>(mut self) -> Self
where
    M::Args: serde::de::DeserializeOwned + Send + 'static,
    M::Output: serde::Serialize + Send + 'static,
{
    self.function_registry.register_mutation::<M>();
    self
}
/// Register a background job type executable by worker nodes.
pub fn register_job<J: forge_core::ForgeJob>(mut self) -> Self
where
    J::Args: serde::de::DeserializeOwned + Send + 'static,
    J::Output: serde::Serialize + Send + 'static,
{
    self.job_registry.register::<J>();
    self
}
/// Register a cron task run by scheduler nodes.
pub fn register_cron<C: forge_core::ForgeCron>(mut self) -> Self {
    self.cron_registry.register::<C>();
    self
}
/// Register a workflow; its definition is persisted (with signature
/// checking) at startup.
pub fn register_workflow<W: forge_core::ForgeWorkflow>(mut self) -> Self
where
    W::Input: serde::de::DeserializeOwned,
    W::Output: serde::Serialize,
{
    self.workflow_registry.register::<W>();
    self
}
/// Register a long-running daemon run by scheduler nodes.
pub fn register_daemon<D: forge_core::ForgeDaemon>(mut self) -> Self {
    self.daemon_registry.register::<D>();
    self
}
/// Register a webhook handler served under `/_api/webhooks`.
pub fn register_webhook<W: forge_core::ForgeWebhook>(mut self) -> Self {
    self.webhook_registry.register::<W>();
    self
}
/// Finalize the builder into a runnable [`Forge`].
///
/// # Errors
/// Returns `ForgeError::Config` when no configuration was supplied.
pub fn build(self) -> Result<Forge> {
    let config = self
        .config
        .ok_or_else(|| ForgeError::Config("Configuration is required".to_string()))?;
    // Capacity 1 is sufficient: shutdown is a one-shot broadcast signal.
    let (shutdown_tx, _) = broadcast::channel(1);
    Ok(Forge {
        config,
        // Database is connected lazily in `run()`.
        db: None,
        // Placeholder id; replaced with the real identity in `run()`.
        node_id: NodeId::new(),
        function_registry: self.function_registry,
        mcp_registry: self.mcp_registry,
        job_registry: self.job_registry,
        cron_registry: Arc::new(self.cron_registry),
        workflow_registry: self.workflow_registry,
        daemon_registry: Arc::new(self.daemon_registry),
        webhook_registry: Arc::new(self.webhook_registry),
        shutdown_tx,
        migrations_dir: self.migrations_dir,
        extra_migrations: self.extra_migrations,
        frontend_handler: self.frontend_handler,
        custom_routes_factory: self.custom_routes_factory,
    })
}
}
// `Default` simply delegates to `new()`, so `ForgeBuilder::default()` and
// `ForgeBuilder::new()` are interchangeable.
impl Default for ForgeBuilder {
    fn default() -> Self {
        Self::new()
    }
}
// Unix: ask the OS via gethostname(2); fall back to a placeholder on error.
#[cfg(unix)]
fn get_hostname() -> String {
    nix::unistd::gethostname()
        .map(|h| h.to_string_lossy().to_string())
        .unwrap_or_else(|_| "unknown".to_string())
}
// Non-Unix: probe the conventional hostname environment variables in order
// (Windows sets COMPUTERNAME; HOSTNAME is a common fallback), defaulting to
// a fixed placeholder when neither is present or valid Unicode.
#[cfg(not(unix))]
fn get_hostname() -> String {
    ["COMPUTERNAME", "HOSTNAME"]
        .into_iter()
        .find_map(|key| std::env::var(key).ok())
        .unwrap_or_else(|| "unknown".to_string())
}
/// Translate a configuration-level role into the cluster-level `NodeRole`.
/// The match is exhaustive, so adding a config role forces a change here.
fn config_role_to_node_role(role: &ConfigNodeRole) -> NodeRole {
    match role {
        ConfigNodeRole::Gateway => NodeRole::Gateway,
        ConfigNodeRole::Function => NodeRole::Function,
        ConfigNodeRole::Worker => NodeRole::Worker,
        ConfigNodeRole::Scheduler => NodeRole::Scheduler,
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use forge_core::mcp::{McpToolAnnotations, McpToolInfo};
    // Minimal MCP tool used to exercise builder registration.
    struct TestMcpTool;
    impl ForgeMcpTool for TestMcpTool {
        type Args = serde_json::Value;
        type Output = serde_json::Value;
        fn info() -> McpToolInfo {
            McpToolInfo {
                name: "test.mcp.tool",
                title: None,
                description: None,
                required_role: None,
                is_public: false,
                timeout: None,
                rate_limit_requests: None,
                rate_limit_per_secs: None,
                rate_limit_key: None,
                annotations: McpToolAnnotations::default(),
                icons: &[],
            }
        }
        fn execute(
            _ctx: &forge_core::McpToolContext,
            _args: Self::Args,
        ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
            Box::pin(async { Ok(serde_json::json!({ "ok": true })) })
        }
    }
    // A fresh builder carries no configuration.
    #[test]
    fn test_forge_builder_new() {
        let builder = ForgeBuilder::new();
        assert!(builder.config.is_none());
    }
    // `build()` must fail when no configuration was supplied.
    #[test]
    fn test_forge_builder_requires_config() {
        let builder = ForgeBuilder::new();
        let result = builder.build();
        assert!(result.is_err());
    }
    // `build()` succeeds once a configuration is provided.
    #[test]
    fn test_forge_builder_with_config() {
        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
        let result = ForgeBuilder::new().config(config).build();
        assert!(result.is_ok());
    }
    // Registering a tool grows the MCP registry.
    #[test]
    fn test_forge_builder_register_mcp_tool() {
        let builder = ForgeBuilder::new().register_mcp_tool::<TestMcpTool>();
        assert_eq!(builder.mcp_registry.len(), 1);
    }
    // Every config role maps to the matching cluster role.
    #[test]
    fn test_config_role_conversion() {
        assert_eq!(
            config_role_to_node_role(&ConfigNodeRole::Gateway),
            NodeRole::Gateway
        );
        assert_eq!(
            config_role_to_node_role(&ConfigNodeRole::Worker),
            NodeRole::Worker
        );
        assert_eq!(
            config_role_to_node_role(&ConfigNodeRole::Scheduler),
            NodeRole::Scheduler
        );
        assert_eq!(
            config_role_to_node_role(&ConfigNodeRole::Function),
            NodeRole::Function
        );
    }
}