Skip to main content

haystack_server/
app.rs

1//! Server builder and startup.
2
3use std::sync::Arc;
4
5use axum::Router;
6use axum::body::Body;
7use axum::extract::{DefaultBodyLimit, State};
8use axum::http::{Method, Request, StatusCode};
9use axum::middleware::{self, Next};
10use axum::response::{IntoResponse, Response};
11use axum::routing::{get, post};
12
13use haystack_core::auth::{AuthHeader, parse_auth_header};
14use haystack_core::graph::SharedGraph;
15use haystack_core::ontology::DefNamespace;
16
17use crate::actions::ActionRegistry;
18use crate::auth::AuthManager;
19use crate::his_store::HisStore;
20use crate::ops;
21use crate::state::{AppState, SharedState};
22use crate::ws;
23use crate::ws::WatchManager;
24
/// Builder for the Haystack HTTP server.
///
/// Configure it with the `with_*`, `port` and `host` methods, then call
/// `run` to build the router and start serving.
pub struct HaystackServer {
    /// Entity graph moved into the shared application state by `run`.
    graph: SharedGraph,
    /// Ontology namespace for def/spec operations (default: `DefNamespace::new()`).
    namespace: DefNamespace,
    /// Authentication manager used by the auth middleware (default: `AuthManager::empty()`).
    auth_manager: AuthManager,
    /// Actions exposed through the `invokeAction` op.
    actions: ActionRegistry,
    /// Extra routes merged WITHOUT the built-in auth middleware.
    custom_router: Option<Router<SharedState>>,
    /// Extra routes merged BEHIND the built-in auth middleware.
    authenticated_router: Option<Router<SharedState>>,
    /// History storage; `None` makes `run` fall back to an in-memory `HisStore`.
    history_provider: Option<Box<dyn crate::his_provider::HistoryProvider>>,
    /// TCP port to listen on (default: 8080).
    port: u16,
    /// Host/interface to bind to (default: "127.0.0.1").
    host: String,
}
37
38impl HaystackServer {
39    /// Create a new server with the given entity graph.
40    pub fn new(graph: SharedGraph) -> Self {
41        Self {
42            graph,
43            namespace: DefNamespace::new(),
44            auth_manager: AuthManager::empty(),
45            actions: ActionRegistry::new(),
46            custom_router: None,
47            authenticated_router: None,
48            history_provider: None,
49            port: 8080,
50            host: "127.0.0.1".to_string(),
51        }
52    }
53
54    /// Set the ontology namespace for def/spec operations.
55    pub fn with_namespace(mut self, ns: DefNamespace) -> Self {
56        self.namespace = ns;
57        self
58    }
59
60    /// Set the authentication manager.
61    pub fn with_auth(mut self, auth: AuthManager) -> Self {
62        self.auth_manager = auth;
63        self
64    }
65
66    /// Set the port to listen on (default: 8080).
67    pub fn port(mut self, port: u16) -> Self {
68        self.port = port;
69        self
70    }
71
72    /// Set the host to bind to (default: "127.0.0.1").
73    pub fn host(mut self, host: &str) -> Self {
74        self.host = host.to_string();
75        self
76    }
77
78    /// Set the action registry for the `invokeAction` op.
79    pub fn with_actions(mut self, actions: ActionRegistry) -> Self {
80        self.actions = actions;
81        self
82    }
83
84    /// Merge additional routes into the server.
85    ///
86    /// **Note:** Routes added via `with_router()` are NOT protected by the built-in
87    /// auth middleware. To protect custom routes, apply your own auth layer to the
88    /// router before passing it, or use `with_authenticated_router()` instead.
89    ///
90    /// The router's routes are merged at the top level, so paths must
91    /// include any prefix (e.g. `/custom/endpoint`).
92    pub fn with_router(mut self, router: Router<SharedState>) -> Self {
93        self.custom_router = Some(router);
94        self
95    }
96
97    /// Merge additional routes that are protected by the built-in auth middleware.
98    ///
99    /// Routes added here go through the same authentication and permission
100    /// checks as the standard Haystack API endpoints.
101    pub fn with_authenticated_router(mut self, router: Router<SharedState>) -> Self {
102        self.authenticated_router = Some(router);
103        self
104    }
105
106    /// Set the history storage provider (default: in-memory [`HisStore`]).
107    pub fn with_history_provider(
108        mut self,
109        provider: Box<dyn crate::his_provider::HistoryProvider>,
110    ) -> Self {
111        self.history_provider = Some(provider);
112        self
113    }
114
115    /// Start the HTTP server. This blocks until the server is stopped.
116    pub async fn run(self) -> std::io::Result<()> {
117        let his: Box<dyn crate::his_provider::HistoryProvider> = self
118            .history_provider
119            .unwrap_or_else(|| Box::new(HisStore::new()));
120
121        let state: SharedState = Arc::new(AppState {
122            graph: self.graph,
123            namespace: parking_lot::RwLock::new(self.namespace),
124            auth: self.auth_manager,
125            watches: WatchManager::new(),
126            actions: self.actions,
127            his,
128            started_at: std::time::Instant::now(),
129        });
130
131        let mut core_router = Router::new()
132            // GET routes
133            .route("/api/about", get(ops::about::handle))
134            .route("/api/ops", get(ops::ops_handler::handle))
135            .route("/api/formats", get(ops::formats::handle))
136            .route("/api/ws", get(ws::ws_handler))
137            // POST routes
138            .route("/api/read", post(ops::read::handle))
139            .route("/api/nav", post(ops::nav::handle))
140            .route("/api/defs", post(ops::defs::handle))
141            .route("/api/libs", post(ops::defs::handle_libs))
142            .route("/api/hisRead", post(ops::his::handle_read))
143            .route("/api/hisWrite", post(ops::his::handle_write))
144            .route("/api/watchSub", post(ops::watch::handle_sub))
145            .route("/api/watchPoll", post(ops::watch::handle_poll))
146            .route("/api/watchUnsub", post(ops::watch::handle_unsub))
147            .route("/api/pointWrite", post(ops::point_write::handle))
148            .route("/api/invokeAction", post(ops::invoke::handle))
149            .route("/api/close", post(ops::about::handle_close))
150            .route("/api/import", post(ops::data::handle_import))
151            .route("/api/export", post(ops::data::handle_export))
152            .route("/api/validate", post(ops::libs::handle_validate))
153            .route("/api/specs", post(ops::libs::handle_specs))
154            .route("/api/spec", post(ops::libs::handle_spec))
155            .route("/api/loadLib", post(ops::libs::handle_load_lib))
156            .route("/api/unloadLib", post(ops::libs::handle_unload_lib))
157            .route("/api/exportLib", post(ops::libs::handle_export_lib))
158            .route("/api/changes", post(ops::changes::handle));
159
160        // Merge the authenticated custom router before applying the auth layer,
161        // so its routes are also protected by the built-in auth middleware.
162        if let Some(auth_router) = self.authenticated_router {
163            core_router = core_router.merge(auth_router);
164        }
165
166        let mut app = core_router
167            .route_layer(middleware::from_fn_with_state(
168                state.clone(),
169                auth_middleware,
170            ))
171            .layer(DefaultBodyLimit::max(2 * 1024 * 1024))
172            .with_state(state.clone());
173
174        if let Some(custom) = self.custom_router {
175            app = app.merge(custom.with_state(state));
176        }
177
178        log::info!("Starting haystack-server on {}:{}", self.host, self.port);
179
180        let listener =
181            tokio::net::TcpListener::bind(format!("{}:{}", self.host, self.port)).await?;
182        axum::serve(listener, app).await
183    }
184}
185
/// Map a request path to the permission a user needs in order to call it.
///
/// `/api/pointWrite`, `/api/hisWrite`, `/api/invokeAction`, `/api/loadLib`,
/// `/api/unloadLib` and `/api/import` require `"write"`. Every other path
/// requires `"read"`, including routes merged via `with_authenticated_router`.
///
/// The `Option` return type would let a path be exempted from permission
/// checks, but no path currently is: this always returns `Some`. The public
/// endpoints (`/api/about`, `GET /api/ops`, `GET /api/formats`) are let through
/// by `auth_middleware` before this function is ever consulted.
fn required_permission(path: &str) -> Option<&'static str> {
    let permission = match path {
        // Operations that mutate server state.
        "/api/pointWrite" | "/api/hisWrite" | "/api/invokeAction" | "/api/loadLib"
        | "/api/unloadLib" | "/api/import" => "write",
        // Everything else that reaches the permission check is read-level.
        _ => "read",
    };
    Some(permission)
}
204
205/// Authentication middleware for Axum.
206///
207/// - GET /api/about: pass through (about handles auth itself for SCRAM)
208/// - GET /api/ops, GET /api/formats: pass through (public info)
209/// - All other endpoints: require BEARER token if auth is enabled,
210///   then check the user has the required permission for that route.
211async fn auth_middleware(
212    State(state): State<SharedState>,
213    mut req: Request<Body>,
214    next: Next,
215) -> Response {
216    let path = req.uri().path().to_string();
217    let method = req.method().clone();
218
219    // Allow about endpoint through (it handles auth itself for SCRAM handshake)
220    if path == "/api/about" {
221        return next.run(req).await;
222    }
223
224    // Allow ops and formats through without auth (public endpoints)
225    if (path == "/api/ops" || path == "/api/formats") && method == Method::GET {
226        return next.run(req).await;
227    }
228
229    // Check if auth is enabled
230    if !state.auth.is_enabled() {
231        return next.run(req).await;
232    }
233
234    // Extract and validate BEARER token
235    let auth_header = req
236        .headers()
237        .get("Authorization")
238        .and_then(|v| v.to_str().ok())
239        .map(|s| s.to_string());
240
241    match auth_header {
242        Some(header) => match parse_auth_header(&header) {
243            Ok(AuthHeader::Bearer { auth_token }) => {
244                match state.auth.validate_token(&auth_token) {
245                    Some(auth_user) => {
246                        // Check permission for the requested path
247                        if let Some(required) = required_permission(&path)
248                            && !AuthManager::check_permission(&auth_user, required)
249                        {
250                            return crate::error::HaystackError::forbidden(format!(
251                                "insufficient '{}' permission",
252                                required
253                            ))
254                            .into_response();
255                        }
256
257                        // Inject AuthUser into request extensions
258                        req.extensions_mut().insert(auth_user);
259                        next.run(req).await
260                    }
261                    None => crate::error::HaystackError::new(
262                        "invalid or expired auth token",
263                        StatusCode::UNAUTHORIZED,
264                    )
265                    .into_response(),
266                }
267            }
268            _ => {
269                crate::error::HaystackError::new("BEARER token required", StatusCode::UNAUTHORIZED)
270                    .into_response()
271            }
272        },
273        None => crate::error::HaystackError::new(
274            "Authorization header required",
275            StatusCode::UNAUTHORIZED,
276        )
277        .into_response(),
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Read-level ops map to `"read"`. This includes the public endpoints,
    /// which `auth_middleware` lets through before consulting
    /// `required_permission`.
    #[test]
    fn required_permission_read_ops() {
        for path in [
            "/api/read",
            "/api/nav",
            "/api/defs",
            "/api/libs",
            "/api/hisRead",
            "/api/watchSub",
            "/api/watchPoll",
            "/api/watchUnsub",
            "/api/close",
            "/api/about",
            "/api/ops",
            "/api/formats",
            "/api/ws",
            "/api/export",
            "/api/exportLib",
            "/api/validate",
            "/api/specs",
            "/api/spec",
            "/api/changes",
        ] {
            assert_eq!(required_permission(path), Some("read"), "path: {path}");
        }
    }

    /// Every state-mutating op requires `"write"`, including lib (un)loading.
    #[test]
    fn required_permission_write_ops() {
        for path in [
            "/api/pointWrite",
            "/api/hisWrite",
            "/api/invokeAction",
            "/api/import",
            "/api/loadLib",
            "/api/unloadLib",
        ] {
            assert_eq!(required_permission(path), Some("write"), "path: {path}");
        }
    }

    /// Paths outside the built-in API (e.g. routes from
    /// `with_authenticated_router`) default to `"read"`.
    #[test]
    fn required_permission_defaults_to_read() {
        assert_eq!(required_permission("/custom/endpoint"), Some("read"));
        assert_eq!(required_permission(""), Some("read"));
    }
}