use axum::routing::{get, post};
use axum::{Router, middleware};
use std::pin::Pin;
use std::sync::Arc;
use systemprompt_database::DbPool;
use systemprompt_models::modules::ApiPaths;
use systemprompt_models::{AgentConfig, AiProvider};
use tokio::sync::RwLock;
use tower_http::cors::CorsLayer;
use tower_http::services::ServeDir;
use super::auth::{AgentOAuthConfig, AgentOAuthState, agent_oauth_middleware_wrapper};
use super::handlers::{AgentHandlerState, handle_agent_card, handle_agent_request};
use crate::state::AgentState;
pub struct Server {
db_pool: DbPool,
config: Arc<RwLock<AgentConfig>>,
oauth_state: Arc<AgentOAuthState>,
agent_state: Arc<AgentState>,
ai_service: Arc<dyn AiProvider>,
port: u16,
}
impl std::fmt::Debug for Server {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Server")
.field("db_pool", &"<DbPool>")
.field("config", &"Arc<RwLock<AgentConfig>>")
.field("oauth_state", &"Arc<AgentOAuthState>")
.field("agent_state", &"Arc<AgentState>")
.field("ai_service", &"<Arc<dyn AiProvider>>")
.field("port", &self.port)
.finish()
}
}
impl Server {
pub async fn new(
db_pool: DbPool,
agent_state: Arc<AgentState>,
ai_service: Arc<dyn AiProvider>,
agent_name: Option<String>,
port: u16,
) -> anyhow::Result<Self> {
use crate::services::registry::AgentRegistry;
let mut config = if let Some(name) = agent_name {
let registry = AgentRegistry::new()?;
registry.get_agent(&name).await?
} else {
return Err(anyhow::anyhow!("Agent name is required"));
};
config.extract_oauth_scopes_from_card();
let oauth_config = AgentOAuthConfig::default();
let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?.to_string();
let global_config = systemprompt_models::Config::get()?;
let mut oauth_state = AgentOAuthState::new(
Arc::clone(&db_pool),
oauth_config,
jwt_secret,
global_config.jwt_issuer.clone(),
global_config.jwt_audiences.clone(),
);
oauth_state = oauth_state.with_jwt_provider(Arc::clone(agent_state.jwt_provider()));
if let Some(user_provider) = agent_state.user_provider().cloned() {
oauth_state = oauth_state.with_user_provider(user_provider);
}
Ok(Self {
db_pool,
config: Arc::new(RwLock::new(config)),
oauth_state: Arc::new(oauth_state),
agent_state,
ai_service,
port,
})
}
pub async fn reload_config(&self) -> anyhow::Result<()> {
use crate::services::registry::AgentRegistry;
let agent_name = {
let config = self.config.read().await;
config.name.clone()
};
let registry = AgentRegistry::new()?;
let mut new_config = registry.get_agent(&agent_name).await?;
new_config.extract_oauth_scopes_from_card();
*self.config.write().await = new_config;
tracing::info!(agent_name = %agent_name, "Configuration reloaded");
Ok(())
}
pub fn create_router(&self) -> Router {
let state = Arc::new(AgentHandlerState {
db_pool: Arc::clone(&self.db_pool),
config: Arc::clone(&self.config),
oauth_state: Arc::clone(&self.oauth_state),
agent_state: Arc::clone(&self.agent_state),
ai_service: Arc::clone(&self.ai_service),
});
let post_router = Router::new()
.route("/", post(handle_agent_request))
.with_state(Arc::clone(&state))
.layer(middleware::from_fn_with_state(
Arc::clone(&state),
agent_oauth_middleware_wrapper,
));
let get_router = Router::new()
.route(ApiPaths::WELLKNOWN_AGENT_CARD, get(handle_agent_card))
.route(ApiPaths::A2A_CARD, get(handle_agent_card))
.with_state(state);
let api_router = Router::new().merge(post_router).merge(get_router);
let web_dist_path = std::path::Path::new("web/dist");
let router = if web_dist_path.exists() {
api_router.fallback_service(ServeDir::new(web_dist_path))
} else {
api_router
};
router.layer(CorsLayer::permissive())
}
pub async fn run(self) -> anyhow::Result<()> {
Self::log_server_configuration();
self.start_server(None).await
}
pub async fn run_with_shutdown(
self,
shutdown_signal: impl Future<Output = ()> + Send + 'static,
) -> anyhow::Result<()> {
Self::log_server_configuration();
self.start_server(Some(Box::pin(shutdown_signal))).await
}
const fn log_server_configuration() {}
async fn start_server(
self,
shutdown_signal: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
) -> anyhow::Result<()> {
let app = self.create_router();
let addr = format!("0.0.0.0:{}", self.port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
match shutdown_signal {
Some(signal) => axum::serve(listener, app)
.with_graceful_shutdown(signal)
.await
.map_err(|e| anyhow::anyhow!("Server error: {}", e)),
None => axum::serve(listener, app)
.await
.map_err(|e| anyhow::anyhow!("Server error: {}", e)),
}
}
}