Skip to main content

bitrouter_runtime/
server.rs

1use std::sync::Arc;
2
3use bitrouter_api::router::{anthropic, google, openai, routes};
4use bitrouter_config::BitrouterConfig;
5use bitrouter_core::routers::{model_router::LanguageModelRouter, routing_table::RoutingTable};
6use sea_orm::DatabaseConnection;
7use warp::Filter;
8
9use crate::auth::{self, AuthContext, Unauthorized};
10use crate::error::Result;
11use crate::keys;
12
/// Placeholder model router that refuses every routing request.
///
/// Used when the server starts without a real provider-backed router. Health
/// checks and other non-model endpoints still work; only model API requests
/// will return an error.
///
/// The type is a zero-sized unit struct, so it derives the common traits
/// (`Debug`, `Clone`, `Copy`, `Default`) to stay easy to log, share, and
/// construct generically.
#[derive(Debug, Clone, Copy, Default)]
pub struct StubModelRouter;
19
20impl LanguageModelRouter for StubModelRouter {
21    async fn route_model(
22        &self,
23        _target: bitrouter_core::routers::routing_table::RoutingTarget,
24    ) -> bitrouter_core::errors::Result<
25        Box<bitrouter_core::models::language::language_model::DynLanguageModel<'static>>,
26    > {
27        Err(bitrouter_core::errors::BitrouterError::unsupported(
28            "runtime",
29            "model routing",
30            Some("no model router configured — configure providers to enable API endpoints".into()),
31        ))
32    }
33}
34
35pub struct ServerPlan<T, R> {
36    config: BitrouterConfig,
37    table: Arc<T>,
38    router: Arc<R>,
39    db: Option<Arc<DatabaseConnection>>,
40}
41
42impl<T, R> ServerPlan<T, R>
43where
44    T: RoutingTable + Send + Sync + 'static,
45    R: LanguageModelRouter + Send + Sync + 'static,
46{
47    pub fn new(config: BitrouterConfig, table: Arc<T>, router: Arc<R>) -> Self {
48        Self {
49            config,
50            table,
51            router,
52            db: None,
53        }
54    }
55
56    /// Set the database connection for virtual key lookups and key management.
57    pub fn with_db(mut self, db: DatabaseConnection) -> Self {
58        self.db = Some(Arc::new(db));
59        self
60    }
61
62    pub async fn serve(self) -> Result<()> {
63        let addr = self.config.server.listen;
64
65        // Build auth context.
66        let auth_ctx = Arc::new(AuthContext::new(
67            self.config.master_key.as_deref(),
68            self.db.as_ref().map(|db| db.as_ref().clone()),
69        ));
70
71        let health = warp::path("health")
72            .and(warp::get())
73            .map(|| warp::reply::json(&serde_json::json!({ "status": "ok" })));
74
75        // Route listing — no auth required.
76        let route_list = routes::routes_filter(self.table.clone());
77
78        // Model API routes — gated by protocol-appropriate auth.
79        let chat = auth_gate(auth::openai_auth(auth_ctx.clone())).and(
80            openai::chat::filters::chat_completions_filter(self.table.clone(), self.router.clone()),
81        );
82        let messages = auth_gate(auth::anthropic_auth(auth_ctx.clone())).and(
83            anthropic::messages::filters::messages_filter(self.table.clone(), self.router.clone()),
84        );
85        let responses = auth_gate(auth::openai_auth(auth_ctx.clone())).and(
86            openai::responses::filters::responses_filter(self.table.clone(), self.router.clone()),
87        );
88        let generate_content = auth_gate(auth::openai_auth(auth_ctx.clone())).and(
89            google::generate_content::filters::generate_content_filter(
90                self.table.clone(),
91                self.router.clone(),
92            ),
93        );
94
95        // Key management routes — always mounted (returns 404 if no DB, since
96        // the filter will not match without the DB anyway).
97        let key_mgmt = keys::key_routes(auth_ctx.clone(), self.db.clone());
98
99        let routes = health
100            .or(route_list)
101            .or(chat)
102            .or(messages)
103            .or(responses)
104            .or(generate_content)
105            .or(key_mgmt)
106            .recover(handle_auth_rejection)
107            .with(warp::trace::request());
108
109        let server = warp::serve(routes)
110            .bind(addr)
111            .await
112            .graceful(shutdown_signal());
113
114        if auth_ctx.is_open() {
115            tracing::info!(%addr, "server listening (auth disabled — no master_key configured)");
116        } else {
117            tracing::info!(%addr, "server listening (auth enabled)");
118        }
119        server.run().await;
120        tracing::info!("server stopped");
121
122        Ok(())
123    }
124}
125
126/// Convert an auth filter into a gate that rejects unauthorized requests
127/// but does not add anything to the extract tuple. This lets us compose
128/// `auth_gate(auth).and(existing_filter)` without changing the existing
129/// filter's handler signature.
130fn auth_gate(
131    auth: impl Filter<Extract = (bitrouter_accounts::identity::Identity,), Error = warp::Rejection>
132    + Clone,
133) -> impl Filter<Extract = (), Error = warp::Rejection> + Clone {
134    auth.map(|_| ()).untuple_one()
135}
136
137/// Rejection handler that turns [`Unauthorized`] into a JSON 401 response.
138async fn handle_auth_rejection(
139    rejection: warp::Rejection,
140) -> std::result::Result<impl warp::Reply, warp::Rejection> {
141    if let Some(e) = rejection.find::<Unauthorized>() {
142        let json = warp::reply::json(&serde_json::json!({
143            "error": {
144                "message": e.to_string(),
145                "type": "authentication_error",
146            }
147        }));
148        return Ok(warp::reply::with_status(
149            json,
150            warp::http::StatusCode::UNAUTHORIZED,
151        ));
152    }
153    Err(rejection)
154}
155
156async fn shutdown_signal() {
157    let ctrl_c = tokio::signal::ctrl_c();
158
159    #[cfg(unix)]
160    {
161        let mut term =
162            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()).unwrap();
163        tokio::select! {
164            _ = ctrl_c => {}
165            _ = term.recv() => {}
166        }
167    }
168
169    #[cfg(not(unix))]
170    {
171        ctrl_c.await.ok();
172    }
173
174    tracing::info!("shutdown signal received");
175}