attuned_http/
server.rs

1//! HTTP server implementation.
2
3use crate::config::ServerConfig;
4use crate::error::HttpError;
5use crate::handlers::{
6    delete_state, get_context, get_state, health, ready, translate, upsert_state, AppState,
7};
8use crate::middleware::security_headers;
9use attuned_core::HealthCheck;
10use attuned_store::StateStore;
11use axum::{
12    middleware,
13    routing::{delete, get, post},
14    Router,
15};
16use std::sync::Arc;
17use tower_http::trace::TraceLayer;
18
19#[cfg(feature = "inference")]
20use crate::handlers::infer;
21
22/// HTTP server for the Attuned API.
23pub struct Server<S: StateStore + HealthCheck> {
24    state: Arc<AppState<S>>,
25    config: ServerConfig,
26}
27
28impl<S: StateStore + HealthCheck + 'static> Server<S> {
29    /// Create a new server with the given store and configuration.
30    pub fn new(store: S, config: ServerConfig) -> Self {
31        #[cfg(feature = "inference")]
32        let state = if config.enable_inference {
33            Arc::new(AppState::with_inference(
34                store,
35                config.inference_config.clone(),
36            ))
37        } else {
38            Arc::new(AppState::new(store))
39        };
40        #[cfg(not(feature = "inference"))]
41        let state = Arc::new(AppState::new(store));
42
43        Self { state, config }
44    }
45
46    /// Build the router with all routes.
47    pub fn router(&self) -> Router {
48        // Build routes with typed state
49        let typed_router = Router::new()
50            // State management
51            .route("/v1/state", post(upsert_state::<S>))
52            .route("/v1/state/{user_id}", get(get_state::<S>))
53            .route("/v1/state/{user_id}", delete(delete_state::<S>))
54            // Context/translation
55            .route("/v1/context/{user_id}", get(get_context::<S>))
56            .route("/v1/translate", post(translate::<S>))
57            // Operations
58            .route("/health", get(health::<S>))
59            .route("/ready", get(ready::<S>));
60
61        // Add inference endpoint if feature enabled
62        #[cfg(feature = "inference")]
63        let typed_router = typed_router.route("/v1/infer", post(infer::<S>));
64
65        // Apply state and convert to Router<()>
66        let mut router = typed_router.with_state(self.state.clone());
67
68        // Add security headers middleware (outermost layer, runs last on request, first on response)
69        if self.config.security_headers {
70            router = router.layer(middleware::from_fn(security_headers));
71        }
72
73        // Add tracing
74        router = router.layer(TraceLayer::new_for_http());
75
76        router
77    }
78
79    /// Run the server until shutdown.
80    pub async fn run(self) -> Result<(), HttpError> {
81        let app = self.router();
82
83        tracing::info!(
84            addr = %self.config.bind_addr,
85            security_headers = %self.config.security_headers,
86            auth_enabled = %self.config.auth.is_enabled(),
87            rate_limit = %self.config.rate_limit.max_requests,
88            "starting HTTP server"
89        );
90
91        let listener = tokio::net::TcpListener::bind(&self.config.bind_addr)
92            .await
93            .map_err(|e| HttpError::Bind {
94                addr: self.config.bind_addr.to_string(),
95                message: e.to_string(),
96            })?;
97
98        axum::serve(listener, app)
99            .await
100            .map_err(|e| HttpError::Request(e.to_string()))?;
101
102        Ok(())
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use attuned_store::MemoryStore;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;

    /// Server over an empty in-memory store with the default configuration.
    fn test_server() -> Server<MemoryStore> {
        Server::new(MemoryStore::default(), ServerConfig::default())
    }

    /// A body-less request for `uri` using the builder's default method (GET).
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// Push one request through `app` and hand back the response.
    async fn send(app: Router, request: Request<Body>) -> axum::response::Response {
        app.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        let app = test_server().router();
        let reply = send(app, get_request("/health")).await;
        assert_eq!(reply.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_ready_endpoint() {
        let app = test_server().router();
        let reply = send(app, get_request("/ready")).await;
        assert_eq!(reply.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_get_nonexistent_user() {
        let app = test_server().router();
        let reply = send(app, get_request("/v1/state/nonexistent")).await;
        assert_eq!(reply.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_upsert_and_get_state() {
        let app = test_server().router();

        // Store a state for `test_user`...
        let payload = r#"{"user_id": "test_user", "axes": {"warmth": 0.7}}"#;
        let upsert = Request::builder()
            .method("POST")
            .uri("/v1/state")
            .header("content-type", "application/json")
            .body(Body::from(payload))
            .unwrap();
        let stored = send(app.clone(), upsert).await;
        assert_eq!(stored.status(), StatusCode::NO_CONTENT);

        // ...then read it back.
        let fetched = send(app, get_request("/v1/state/test_user")).await;
        assert_eq!(fetched.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_security_headers_present() {
        let app = test_server().router();
        let reply = send(app, get_request("/health")).await;
        assert_eq!(reply.status(), StatusCode::OK);

        // Verify security headers
        let hdrs = reply.headers();
        assert_eq!(hdrs.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(hdrs.get("x-frame-options").unwrap(), "DENY");
        assert_eq!(hdrs.get("x-xss-protection").unwrap(), "1; mode=block");
        assert!(hdrs.get("content-security-policy").is_some());
        assert_eq!(hdrs.get("cache-control").unwrap(), "no-store, max-age=0");
    }
}