1use crate::config::ServerConfig;
4use crate::error::HttpError;
5use crate::handlers::{
6 delete_state, get_context, get_state, health, ready, translate, upsert_state, AppState,
7};
8use crate::middleware::security_headers;
9use attuned_core::HealthCheck;
10use attuned_store::StateStore;
11use axum::{
12 middleware,
13 routing::{delete, get, post},
14 Router,
15};
16use std::sync::Arc;
17use tower_http::trace::TraceLayer;
18
19#[cfg(feature = "inference")]
20use crate::handlers::infer;
21
22pub struct Server<S: StateStore + HealthCheck> {
24 state: Arc<AppState<S>>,
25 config: ServerConfig,
26}
27
28impl<S: StateStore + HealthCheck + 'static> Server<S> {
29 pub fn new(store: S, config: ServerConfig) -> Self {
31 #[cfg(feature = "inference")]
32 let state = if config.enable_inference {
33 Arc::new(AppState::with_inference(
34 store,
35 config.inference_config.clone(),
36 ))
37 } else {
38 Arc::new(AppState::new(store))
39 };
40 #[cfg(not(feature = "inference"))]
41 let state = Arc::new(AppState::new(store));
42
43 Self { state, config }
44 }
45
46 pub fn router(&self) -> Router {
48 let typed_router = Router::new()
50 .route("/v1/state", post(upsert_state::<S>))
52 .route("/v1/state/{user_id}", get(get_state::<S>))
53 .route("/v1/state/{user_id}", delete(delete_state::<S>))
54 .route("/v1/context/{user_id}", get(get_context::<S>))
56 .route("/v1/translate", post(translate::<S>))
57 .route("/health", get(health::<S>))
59 .route("/ready", get(ready::<S>));
60
61 #[cfg(feature = "inference")]
63 let typed_router = typed_router.route("/v1/infer", post(infer::<S>));
64
65 let mut router = typed_router.with_state(self.state.clone());
67
68 if self.config.security_headers {
70 router = router.layer(middleware::from_fn(security_headers));
71 }
72
73 router = router.layer(TraceLayer::new_for_http());
75
76 router
77 }
78
79 pub async fn run(self) -> Result<(), HttpError> {
81 let app = self.router();
82
83 tracing::info!(
84 addr = %self.config.bind_addr,
85 security_headers = %self.config.security_headers,
86 auth_enabled = %self.config.auth.is_enabled(),
87 rate_limit = %self.config.rate_limit.max_requests,
88 "starting HTTP server"
89 );
90
91 let listener = tokio::net::TcpListener::bind(&self.config.bind_addr)
92 .await
93 .map_err(|e| HttpError::Bind {
94 addr: self.config.bind_addr.to_string(),
95 message: e.to_string(),
96 })?;
97
98 axum::serve(listener, app)
99 .await
100 .map_err(|e| HttpError::Request(e.to_string()))?;
101
102 Ok(())
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use attuned_store::MemoryStore;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::response::Response;
    use tower::ServiceExt;

    /// JSON payload used by the upsert-based tests.
    const TEST_USER_BODY: &str = r#"{"user_id": "test_user", "axes": {"warmth": 0.7}}"#;

    /// Server backed by an empty in-memory store and the default config.
    fn test_server() -> Server<MemoryStore> {
        test_server_with(ServerConfig::default())
    }

    /// Server backed by an empty in-memory store and the given config.
    fn test_server_with(config: ServerConfig) -> Server<MemoryStore> {
        Server::new(MemoryStore::default(), config)
    }

    /// Request with no body for `method` and `uri`.
    fn empty_request(method: &str, uri: &str) -> Request<Body> {
        Request::builder()
            .method(method)
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// `POST` request carrying `body` as JSON.
    fn json_post(uri: &str, body: &'static str) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(body))
            .unwrap()
    }

    /// Sends a single request through `app` and returns the response.
    async fn send(app: Router, request: Request<Body>) -> Response {
        app.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        let response = send(test_server().router(), empty_request("GET", "/health")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_ready_endpoint() {
        let response = send(test_server().router(), empty_request("GET", "/ready")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_get_nonexistent_user() {
        let response = send(
            test_server().router(),
            empty_request("GET", "/v1/state/nonexistent"),
        )
        .await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_upsert_and_get_state() {
        let app = test_server().router();

        let response = send(app.clone(), json_post("/v1/state", TEST_USER_BODY)).await;
        assert_eq!(response.status(), StatusCode::NO_CONTENT);

        let response = send(app, empty_request("GET", "/v1/state/test_user")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_delete_state() {
        let app = test_server().router();

        let response = send(app.clone(), json_post("/v1/state", TEST_USER_BODY)).await;
        assert_eq!(response.status(), StatusCode::NO_CONTENT);

        // DELETE shares its path with GET; it must be routed, not rejected with 405.
        let response = send(app.clone(), empty_request("DELETE", "/v1/state/test_user")).await;
        assert!(
            response.status().is_success(),
            "unexpected DELETE status {}",
            response.status()
        );

        let response = send(app, empty_request("GET", "/v1/state/test_user")).await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_security_headers_present() {
        // Uses the default config: security headers are expected to be on by default.
        let response = send(test_server().router(), empty_request("GET", "/health")).await;
        assert_eq!(response.status(), StatusCode::OK);

        let headers = response.headers();
        assert_eq!(headers.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(headers.get("x-frame-options").unwrap(), "DENY");
        assert_eq!(headers.get("x-xss-protection").unwrap(), "1; mode=block");
        assert!(headers.get("content-security-policy").is_some());
        assert_eq!(headers.get("cache-control").unwrap(), "no-store, max-age=0");
    }

    #[tokio::test]
    async fn test_security_headers_absent_when_disabled() {
        let mut config = ServerConfig::default();
        config.security_headers = false;

        let response = send(
            test_server_with(config).router(),
            empty_request("GET", "/health"),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().get("x-content-type-options").is_none());
    }
}