1use crate::error::{ForgeError, Result};
9use axum::{
10 http::StatusCode,
11 response::IntoResponse,
12 routing::get,
13 Json, Router,
14};
15use serde::{Deserialize, Serialize};
16use std::net::SocketAddr;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::info;
20
21#[cfg(feature = "quic")]
22use std::sync::Arc as QuicArc;
23
/// Settings for the plain HTTP listener.
///
/// Start from [`Default`] and refine with the `with_addr*` builder methods.
#[derive(Clone, Debug)]
pub struct HttpServerConfig {
    /// Socket address the listener binds to.
    pub bind_addr: SocketAddr,
    /// Whether CORS handling is requested.
    // NOTE(review): nothing in this file reads this flag — confirm where it is applied.
    pub cors_enabled: bool,
    /// Timeout in seconds (presumably per request — confirm).
    // NOTE(review): not consulted by `HttpServer::serve` in this file.
    pub timeout_secs: u64,
}
34
35impl Default for HttpServerConfig {
36 fn default() -> Self {
37 Self {
38 bind_addr: ([0, 0, 0, 0], 8080).into(),
39 cors_enabled: true,
40 timeout_secs: 30,
41 }
42 }
43}
44
45impl HttpServerConfig {
46 pub fn with_addr(mut self, addr: SocketAddr) -> Self {
48 self.bind_addr = addr;
49 self
50 }
51
52 pub fn with_addr_str(mut self, addr: &str) -> Result<Self> {
54 self.bind_addr = addr
55 .parse()
56 .map_err(|e| ForgeError::config(format!("Invalid address: {}", e)))?;
57 Ok(self)
58 }
59}
60
/// Router state wrapping the application value `T`.
///
/// The value sits behind `Arc<RwLock<_>>`, so every holder of an `HttpState`
/// refers to the same underlying `T`.
pub struct HttpState<T> {
    /// Shared, lock-guarded handle to the application value.
    pub app: Arc<RwLock<T>>,
}
66
67impl<T> Clone for HttpState<T> {
68 fn clone(&self) -> Self {
69 Self {
70 app: Arc::clone(&self.app),
71 }
72 }
73}
74
75#[derive(Debug, Serialize)]
77pub struct HealthResponse {
78 pub status: String,
79 pub version: String,
80 pub uptime_secs: u64,
81}
82
83#[derive(Debug, Serialize)]
85pub struct ErrorResponse {
86 pub error: String,
87 pub code: u16,
88}
89
90impl IntoResponse for ErrorResponse {
91 fn into_response(self) -> axum::response::Response {
92 let status = StatusCode::from_u16(self.code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
93 (status, Json(self)).into_response()
94 }
95}
96
97pub fn base_router<T: Send + Sync + 'static>() -> Router<HttpState<T>> {
99 Router::new()
100 .route("/health", get(health_handler))
101 .route("/ready", get(ready_handler))
102}
103
104async fn health_handler() -> Json<HealthResponse> {
105 Json(HealthResponse {
106 status: "healthy".to_string(),
107 version: env!("CARGO_PKG_VERSION").to_string(),
108 uptime_secs: 0, })
110}
111
112async fn ready_handler() -> StatusCode {
113 StatusCode::OK
114}
115
116pub struct HttpServer {
118 config: HttpServerConfig,
119 router: Router,
120}
121
122impl HttpServer {
123 pub fn new(config: HttpServerConfig) -> Self {
125 Self {
126 config,
127 router: Router::new(),
128 }
129 }
130
131 pub fn with_router(mut self, router: Router) -> Self {
133 self.router = router;
134 self
135 }
136
137 pub async fn serve(self) -> Result<()> {
139 let listener = tokio::net::TcpListener::bind(self.config.bind_addr)
140 .await
141 .map_err(|e| ForgeError::network(format!("Failed to bind: {}", e)))?;
142
143 info!(addr = %self.config.bind_addr, "HTTP server starting");
144
145 axum::serve(listener, self.router)
146 .await
147 .map_err(|e| ForgeError::network(format!("Server error: {}", e)))?;
148
149 Ok(())
150 }
151}
152
/// Settings for the QUIC transport (only compiled with the `quic` feature).
#[cfg(feature = "quic")]
#[derive(Clone, Debug)]
pub struct QuicConfig {
    /// UDP socket address the QUIC endpoint binds to.
    pub bind_addr: SocketAddr,
    /// Server name passed to the endpoint when opening outgoing connections.
    pub server_name: String,
    /// Stream limit (presumably concurrent streams per connection — confirm).
    pub max_streams: u32,
    /// Idle timeout in seconds.
    pub idle_timeout_secs: u64,
}
166
#[cfg(feature = "quic")]
impl Default for QuicConfig {
    /// Binds to every IPv4 interface on port 4433, with server name `"forge"`,
    /// a stream limit of 100 and a 30-second idle timeout.
    fn default() -> Self {
        let bind_addr = SocketAddr::from(([0, 0, 0, 0], 4433));
        Self {
            bind_addr,
            server_name: String::from("forge"),
            max_streams: 100,
            idle_timeout_secs: 30,
        }
    }
}
178
/// QUIC transport built on a `quinn::Endpoint` (only compiled with the `quic` feature).
#[cfg(feature = "quic")]
pub struct QuicTransport {
    /// Settings used when the endpoint is created and when connecting out.
    config: QuicConfig,
    /// The running endpoint; `None` until `start_server` has succeeded.
    endpoint: Option<quinn::Endpoint>,
}
185
#[cfg(feature = "quic")]
impl QuicTransport {
    /// Creates a transport without an endpoint.
    ///
    /// `start_server` must succeed before `accept` or `connect` can be used.
    pub fn new(config: QuicConfig) -> Self {
        Self {
            config,
            endpoint: None,
        }
    }

    /// Generates a fresh self-signed certificate and its private key.
    ///
    /// # Errors
    ///
    /// Returns a `ForgeError::network` error when certificate generation or
    /// key conversion fails.
    // NOTE(review): the only subject alternative name is "localhost", while
    // `connect` dials with `config.server_name` — a verifying peer would see
    // a name mismatch. Confirm how peers are expected to validate this cert.
    fn generate_self_signed_cert() -> Result<(
        rustls::pki_types::CertificateDer<'static>,
        rustls::pki_types::PrivateKeyDer<'static>,
    )> {
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
            .map_err(|e| ForgeError::network(format!("Failed to generate cert: {}", e)))?;

        let cert_der = rustls::pki_types::CertificateDer::from(cert.cert.der().to_vec());
        let key_der = rustls::pki_types::PrivateKeyDer::try_from(cert.key_pair.serialize_der())
            .map_err(|e| ForgeError::network(format!("Failed to serialize key: {}", e)))?;

        Ok((cert_der, key_der))
    }

    /// Translates the stream and idle-timeout settings of the config into a
    /// `quinn::TransportConfig`.
    ///
    /// # Errors
    ///
    /// Returns a `ForgeError::config` error when `idle_timeout_secs` is too
    /// large for quinn to represent.
    fn transport_config(&self) -> Result<quinn::TransportConfig> {
        let mut transport = quinn::TransportConfig::default();

        // Assumes `max_streams` is meant to cap both stream directions — TODO
        // confirm. With the default of 100 this matches quinn's own defaults.
        let stream_limit = quinn::VarInt::from_u32(self.config.max_streams);
        transport
            .max_concurrent_bidi_streams(stream_limit)
            .max_concurrent_uni_streams(stream_limit);

        // A zero timeout is passed as `None`, which quinn treats as "no idle timeout".
        let idle_timeout = match self.config.idle_timeout_secs {
            0 => None,
            secs => Some(
                quinn::IdleTimeout::try_from(std::time::Duration::from_secs(secs))
                    .map_err(|e| ForgeError::config(format!("Invalid idle timeout: {}", e)))?,
            ),
        };
        transport.max_idle_timeout(idle_timeout);

        Ok(transport)
    }

    /// Creates the QUIC server endpoint and stores it for later use.
    ///
    /// Uses a freshly generated self-signed certificate, no client
    /// authentication and the ALPN protocol `forge`. The stream limit and
    /// idle timeout from the config are applied to the endpoint's transport
    /// settings.
    ///
    /// # Errors
    ///
    /// Returns a `ForgeError` when certificate generation, TLS/QUIC
    /// configuration, transport configuration or binding the endpoint fails.
    pub async fn start_server(&mut self) -> Result<()> {
        let (cert, key) = Self::generate_self_signed_cert()?;

        let mut server_crypto = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![cert], key)
            .map_err(|e| ForgeError::network(format!("TLS config error: {}", e)))?;

        server_crypto.alpn_protocols = vec![b"forge".to_vec()];

        let quic_crypto = quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto)
            .map_err(|e| ForgeError::network(format!("QUIC config error: {}", e)))?;

        let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_crypto));
        // Previously `max_streams` and `idle_timeout_secs` were accepted but
        // never reached quinn; wire them in here.
        server_config.transport_config(Arc::new(self.transport_config()?));

        let endpoint = quinn::Endpoint::server(server_config, self.config.bind_addr)
            .map_err(|e| ForgeError::network(format!("Failed to create endpoint: {}", e)))?;

        // Log the address actually bound so that a port-0 config is still traceable.
        let bound_addr = endpoint.local_addr().unwrap_or(self.config.bind_addr);
        info!(addr = %bound_addr, "QUIC server started");
        self.endpoint = Some(endpoint);

        Ok(())
    }

    /// Waits for the next incoming connection and completes its handshake.
    ///
    /// # Errors
    ///
    /// Returns a `ForgeError::network` error when the endpoint has not been
    /// started, has been closed, or the handshake fails.
    pub async fn accept(&self) -> Result<quinn::Connection> {
        let endpoint = self
            .endpoint
            .as_ref()
            .ok_or_else(|| ForgeError::network("Endpoint not started"))?;

        let incoming = endpoint
            .accept()
            .await
            .ok_or_else(|| ForgeError::network("Endpoint closed"))?;

        let conn = incoming
            .await
            .map_err(|e| ForgeError::network(format!("Connection error: {}", e)))?;

        info!(remote = %conn.remote_address(), "QUIC connection accepted");
        Ok(conn)
    }

    /// Opens an outgoing connection to `addr`, using the configured server name.
    ///
    /// # Errors
    ///
    /// Returns a `ForgeError::network` error when the endpoint has not been
    /// started, the connection attempt cannot be initiated, or the handshake
    /// fails.
    // NOTE(review): the endpoint is created with `Endpoint::server` and no
    // default client config is installed anywhere in this file, so this
    // presumably fails at `connect` until one is set — confirm.
    pub async fn connect(&self, addr: SocketAddr) -> Result<quinn::Connection> {
        let endpoint = self
            .endpoint
            .as_ref()
            .ok_or_else(|| ForgeError::network("Endpoint not started"))?;

        let conn = endpoint
            .connect(addr, &self.config.server_name)
            .map_err(|e| ForgeError::network(format!("Connect error: {}", e)))?
            .await
            .map_err(|e| ForgeError::network(format!("Connection error: {}", e)))?;

        info!(remote = %addr, "QUIC connection established");
        Ok(conn)
    }

    /// Returns the address the endpoint is bound to, or `None` when the
    /// endpoint is not started or the address cannot be read.
    pub fn local_addr(&self) -> Option<SocketAddr> {
        self.endpoint.as_ref().and_then(|e| e.local_addr().ok())
    }

    /// Closes the endpoint with error code 0 and reason `shutdown`.
    ///
    /// Does nothing when the endpoint was never started.
    pub fn close(&self) {
        if let Some(endpoint) = &self.endpoint {
            endpoint.close(0u32.into(), b"shutdown");
        }
    }
}
282
283#[derive(Debug, Clone, Serialize, Deserialize)]
285pub enum PeerMessage {
286 Ping { node_id: String, timestamp: u64 },
288 Pong { node_id: String, timestamp: u64 },
290 RouteRequest { request_id: String, input: String },
292 RouteResponse {
294 request_id: String,
295 expert_index: usize,
296 result: Vec<u8>,
297 },
298 ShardAssign { shard_id: u64, node_id: String },
300 ShardMigrate {
302 shard_id: u64,
303 from_node: String,
304 to_node: String,
305 },
306}
307
308impl PeerMessage {
309 pub fn to_bytes(&self) -> Result<Vec<u8>> {
311 serde_json::to_vec(self).map_err(|e| ForgeError::network(format!("Serialize error: {}", e)))
312 }
313
314 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
316 serde_json::from_slice(bytes)
317 .map_err(|e| ForgeError::network(format!("Deserialize error: {}", e)))
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// The default HTTP config listens on port 8080 with CORS enabled.
    #[test]
    fn test_http_config_default() {
        let HttpServerConfig {
            bind_addr,
            cors_enabled,
            ..
        } = HttpServerConfig::default();

        assert_eq!(bind_addr.port(), 8080);
        assert!(cors_enabled);
    }

    /// The default QUIC config uses port 4433 and the server name "forge".
    #[test]
    #[cfg(feature = "quic")]
    fn test_quic_config_default() {
        let defaults = QuicConfig::default();

        assert_eq!(defaults.bind_addr.port(), 4433);
        assert_eq!(defaults.server_name, "forge");
    }

    /// A `Ping` survives a `to_bytes` / `from_bytes` round trip unchanged.
    #[test]
    fn test_peer_message_serialization() {
        let original = PeerMessage::Ping {
            node_id: "node-1".to_string(),
            timestamp: 12345,
        };

        let encoded = original.to_bytes().unwrap();
        let round_tripped = PeerMessage::from_bytes(&encoded).unwrap();

        if let PeerMessage::Ping { node_id, timestamp } = round_tripped {
            assert_eq!(node_id, "node-1");
            assert_eq!(timestamp, 12345);
        } else {
            panic!("Wrong message type");
        }
    }
}