1use crate::error::{ForgeError, Result};
9use axum::{
10 http::StatusCode,
11 response::IntoResponse,
12 routing::get,
13 Json, Router,
14};
15use serde::{Deserialize, Serialize};
16use std::net::SocketAddr;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::info;
20
/// Settings held by an [`HttpServer`].
#[derive(Debug, Clone)]
pub struct HttpServerConfig {
    /// Socket address the TCP listener is bound to when the server starts.
    pub bind_addr: SocketAddr,
    /// Whether CORS handling is requested.
    ///
    /// NOTE(review): nothing visible in this file reads this flag yet.
    pub cors_enabled: bool,
    /// Timeout budget, in seconds.
    ///
    /// NOTE(review): nothing visible in this file reads this value yet.
    pub timeout_secs: u64,
}
31
32impl Default for HttpServerConfig {
33 fn default() -> Self {
34 Self {
35 bind_addr: ([0, 0, 0, 0], 8080).into(),
36 cors_enabled: true,
37 timeout_secs: 30,
38 }
39 }
40}
41
42impl HttpServerConfig {
43 pub fn with_addr(mut self, addr: SocketAddr) -> Self {
45 self.bind_addr = addr;
46 self
47 }
48
49 pub fn with_addr_str(mut self, addr: &str) -> Result<Self> {
51 self.bind_addr = addr
52 .parse()
53 .map_err(|e| ForgeError::config(format!("Invalid address: {}", e)))?;
54 Ok(self)
55 }
56}
57
58pub struct HttpState<T> {
60 pub app: Arc<RwLock<T>>,
62}
63
64impl<T> Clone for HttpState<T> {
65 fn clone(&self) -> Self {
66 Self {
67 app: Arc::clone(&self.app),
68 }
69 }
70}
71
72#[derive(Debug, Serialize)]
74pub struct HealthResponse {
75 pub status: String,
76 pub version: String,
77 pub uptime_secs: u64,
78}
79
80#[derive(Debug, Serialize)]
82pub struct ErrorResponse {
83 pub error: String,
84 pub code: u16,
85}
86
87impl IntoResponse for ErrorResponse {
88 fn into_response(self) -> axum::response::Response {
89 let status = StatusCode::from_u16(self.code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
90 (status, Json(self)).into_response()
91 }
92}
93
94pub fn base_router<T: Send + Sync + 'static>() -> Router<HttpState<T>> {
96 Router::new()
97 .route("/health", get(health_handler))
98 .route("/ready", get(ready_handler))
99}
100
101async fn health_handler() -> Json<HealthResponse> {
102 Json(HealthResponse {
103 status: "healthy".to_string(),
104 version: env!("CARGO_PKG_VERSION").to_string(),
105 uptime_secs: 0, })
107}
108
109async fn ready_handler() -> StatusCode {
110 StatusCode::OK
111}
112
113pub struct HttpServer {
115 config: HttpServerConfig,
116 router: Router,
117}
118
119impl HttpServer {
120 pub fn new(config: HttpServerConfig) -> Self {
122 Self {
123 config,
124 router: Router::new(),
125 }
126 }
127
128 pub fn with_router(mut self, router: Router) -> Self {
130 self.router = router;
131 self
132 }
133
134 pub async fn serve(self) -> Result<()> {
136 let listener = tokio::net::TcpListener::bind(self.config.bind_addr)
137 .await
138 .map_err(|e| ForgeError::network(format!("Failed to bind: {}", e)))?;
139
140 info!(addr = %self.config.bind_addr, "HTTP server starting");
141
142 axum::serve(listener, self.router)
143 .await
144 .map_err(|e| ForgeError::network(format!("Server error: {}", e)))?;
145
146 Ok(())
147 }
148}
149
/// Settings held by a [`QuicTransport`].
#[derive(Debug, Clone)]
pub struct QuicConfig {
    /// Socket address the QUIC endpoint binds to.
    pub bind_addr: SocketAddr,
    /// Server name passed along when dialing a peer.
    pub server_name: String,
    /// Stream limit.
    ///
    /// NOTE(review): nothing visible in this file applies this value to the
    /// endpoint yet.
    pub max_streams: u32,
    /// Idle timeout, in seconds.
    ///
    /// NOTE(review): nothing visible in this file applies this value to the
    /// endpoint yet.
    pub idle_timeout_secs: u64,
}
162
163impl Default for QuicConfig {
164 fn default() -> Self {
165 Self {
166 bind_addr: ([0, 0, 0, 0], 4433).into(),
167 server_name: "forge".to_string(),
168 max_streams: 100,
169 idle_timeout_secs: 30,
170 }
171 }
172}
173
174pub struct QuicTransport {
176 config: QuicConfig,
177 endpoint: Option<quinn::Endpoint>,
178}
179
180impl QuicTransport {
181 pub fn new(config: QuicConfig) -> Self {
183 Self {
184 config,
185 endpoint: None,
186 }
187 }
188
189 fn generate_self_signed_cert() -> Result<(rustls::pki_types::CertificateDer<'static>, rustls::pki_types::PrivateKeyDer<'static>)> {
191 let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
192 .map_err(|e| ForgeError::network(format!("Failed to generate cert: {}", e)))?;
193
194 let cert_der = rustls::pki_types::CertificateDer::from(cert.cert.der().to_vec());
195 let key_der = rustls::pki_types::PrivateKeyDer::try_from(cert.key_pair.serialize_der())
196 .map_err(|e| ForgeError::network(format!("Failed to serialize key: {}", e)))?;
197
198 Ok((cert_der, key_der))
199 }
200
201 pub async fn start_server(&mut self) -> Result<()> {
203 let (cert, key) = Self::generate_self_signed_cert()?;
204
205 let mut server_crypto = rustls::ServerConfig::builder()
206 .with_no_client_auth()
207 .with_single_cert(vec![cert], key)
208 .map_err(|e| ForgeError::network(format!("TLS config error: {}", e)))?;
209
210 server_crypto.alpn_protocols = vec![b"forge".to_vec()];
211
212 let server_config = quinn::ServerConfig::with_crypto(Arc::new(
213 quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto)
214 .map_err(|e| ForgeError::network(format!("QUIC config error: {}", e)))?,
215 ));
216
217 let endpoint = quinn::Endpoint::server(server_config, self.config.bind_addr)
218 .map_err(|e| ForgeError::network(format!("Failed to create endpoint: {}", e)))?;
219
220 info!(addr = %self.config.bind_addr, "QUIC server started");
221 self.endpoint = Some(endpoint);
222
223 Ok(())
224 }
225
226 pub async fn accept(&self) -> Result<quinn::Connection> {
228 let endpoint = self
229 .endpoint
230 .as_ref()
231 .ok_or_else(|| ForgeError::network("Endpoint not started"))?;
232
233 let incoming = endpoint
234 .accept()
235 .await
236 .ok_or_else(|| ForgeError::network("Endpoint closed"))?;
237
238 let conn = incoming
239 .await
240 .map_err(|e| ForgeError::network(format!("Connection error: {}", e)))?;
241
242 info!(remote = %conn.remote_address(), "QUIC connection accepted");
243 Ok(conn)
244 }
245
246 pub async fn connect(&self, addr: SocketAddr) -> Result<quinn::Connection> {
248 let endpoint = self
249 .endpoint
250 .as_ref()
251 .ok_or_else(|| ForgeError::network("Endpoint not started"))?;
252
253 let conn = endpoint
254 .connect(addr, &self.config.server_name)
255 .map_err(|e| ForgeError::network(format!("Connect error: {}", e)))?
256 .await
257 .map_err(|e| ForgeError::network(format!("Connection error: {}", e)))?;
258
259 info!(remote = %addr, "QUIC connection established");
260 Ok(conn)
261 }
262
263 pub fn local_addr(&self) -> Option<SocketAddr> {
265 self.endpoint.as_ref().and_then(|e| e.local_addr().ok())
266 }
267
268 pub fn close(&self) {
270 if let Some(endpoint) = &self.endpoint {
271 endpoint.close(0u32.into(), b"shutdown");
272 }
273 }
274}
275
276#[derive(Debug, Clone, Serialize, Deserialize)]
278pub enum PeerMessage {
279 Ping { node_id: String, timestamp: u64 },
281 Pong { node_id: String, timestamp: u64 },
283 RouteRequest { request_id: String, input: String },
285 RouteResponse {
287 request_id: String,
288 expert_index: usize,
289 result: Vec<u8>,
290 },
291 ShardAssign { shard_id: u64, node_id: String },
293 ShardMigrate {
295 shard_id: u64,
296 from_node: String,
297 to_node: String,
298 },
299}
300
301impl PeerMessage {
302 pub fn to_bytes(&self) -> Result<Vec<u8>> {
304 serde_json::to_vec(self).map_err(|e| ForgeError::network(format!("Serialize error: {}", e)))
305 }
306
307 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
309 serde_json::from_slice(bytes)
310 .map_err(|e| ForgeError::network(format!("Deserialize error: {}", e)))
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// The default HTTP config listens on port 8080 with CORS enabled.
    #[test]
    fn test_http_config_default() {
        let HttpServerConfig {
            bind_addr,
            cors_enabled,
            ..
        } = HttpServerConfig::default();

        assert_eq!(bind_addr.port(), 8080);
        assert!(cors_enabled);
    }

    /// The default QUIC config listens on port 4433 under the name "forge".
    #[test]
    fn test_quic_config_default() {
        let QuicConfig {
            bind_addr,
            server_name,
            ..
        } = QuicConfig::default();

        assert_eq!(bind_addr.port(), 4433);
        assert_eq!(server_name, "forge");
    }

    /// A `Ping` survives an encode/decode round trip unchanged.
    #[test]
    fn test_peer_message_serialization() {
        let original = PeerMessage::Ping {
            node_id: String::from("node-1"),
            timestamp: 12345,
        };

        let wire = original.to_bytes().unwrap();

        let PeerMessage::Ping { node_id, timestamp } = PeerMessage::from_bytes(&wire).unwrap()
        else {
            panic!("Wrong message type")
        };

        assert_eq!(node_id, "node-1");
        assert_eq!(timestamp, 12345);
    }
}