forge_orchestration/
networking.rs

//! Networking module for Forge
//!
//! ## Table of Contents
//! - **QuicTransport**: QUIC-based peer communication
//! - **HttpServer**: Axum-based HTTP/REST API server
//! - **PeerMessage**: JSON-serialized messages exchanged between peers
//! - **GrpcServer**: Tonic-based gRPC server (placeholder; not yet implemented in this file)
7
8use crate::error::{ForgeError, Result};
9use axum::{
10    http::StatusCode,
11    response::IntoResponse,
12    routing::get,
13    Json, Router,
14};
15use serde::{Deserialize, Serialize};
16use std::net::SocketAddr;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::info;
20
/// HTTP server configuration.
///
/// Construct it with [`Default::default`] and refine it with the `with_*`
/// builder methods.
#[derive(Debug, Clone)]
pub struct HttpServerConfig {
    /// Socket address the TCP listener binds to.
    pub bind_addr: SocketAddr,
    /// Enable CORS.
    // NOTE(review): nothing in this file reads this flag — confirm where it is applied.
    pub cors_enabled: bool,
    /// Request timeout in seconds.
    // NOTE(review): nothing in this file reads this value — confirm where it is applied.
    pub timeout_secs: u64,
}
31
32impl Default for HttpServerConfig {
33    fn default() -> Self {
34        Self {
35            bind_addr: ([0, 0, 0, 0], 8080).into(),
36            cors_enabled: true,
37            timeout_secs: 30,
38        }
39    }
40}
41
42impl HttpServerConfig {
43    /// Create with custom bind address
44    pub fn with_addr(mut self, addr: SocketAddr) -> Self {
45        self.bind_addr = addr;
46        self
47    }
48
49    /// Parse from string address
50    pub fn with_addr_str(mut self, addr: &str) -> Result<Self> {
51        self.bind_addr = addr
52            .parse()
53            .map_err(|e| ForgeError::config(format!("Invalid address: {}", e)))?;
54        Ok(self)
55    }
56}
57
/// Shared state for HTTP handlers.
///
/// `Clone` is implemented by hand rather than derived so that cloning does
/// not require `T: Clone`; clones share the same underlying `T`.
pub struct HttpState<T> {
    /// Application state, shared behind an async read/write lock.
    pub app: Arc<RwLock<T>>,
}
63
64impl<T> Clone for HttpState<T> {
65    fn clone(&self) -> Self {
66        Self {
67            app: Arc::clone(&self.app),
68        }
69    }
70}
71
/// Health check response, serialized as the body of `GET /health`.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Human-readable health status (the handler in this file reports `"healthy"`).
    pub status: String,
    /// Crate version string.
    pub version: String,
    /// Uptime in whole seconds, as reported by the handler.
    pub uptime_secs: u64,
}
79
/// Error response body returned to HTTP clients.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// Error message.
    pub error: String,
    /// Numeric status code; used as the HTTP status when the value is
    /// converted into a response.
    pub code: u16,
}
86
87impl IntoResponse for ErrorResponse {
88    fn into_response(self) -> axum::response::Response {
89        let status = StatusCode::from_u16(self.code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
90        (status, Json(self)).into_response()
91    }
92}
93
94/// Create the base router with health endpoints
95pub fn base_router<T: Send + Sync + 'static>() -> Router<HttpState<T>> {
96    Router::new()
97        .route("/health", get(health_handler))
98        .route("/ready", get(ready_handler))
99}
100
101async fn health_handler() -> Json<HealthResponse> {
102    Json(HealthResponse {
103        status: "healthy".to_string(),
104        version: env!("CARGO_PKG_VERSION").to_string(),
105        uptime_secs: 0, // TODO: Track actual uptime
106    })
107}
108
/// Handler for `GET /ready`.
///
/// Always answers `200 OK`; no dependencies are checked.
async fn ready_handler() -> StatusCode {
    StatusCode::OK
}
112
/// HTTP server wrapper pairing a configuration with the router to serve.
pub struct HttpServer {
    /// Server settings; `bind_addr` selects the listening socket.
    config: HttpServerConfig,
    /// Router that handles incoming requests.
    router: Router,
}
118
119impl HttpServer {
120    /// Create a new HTTP server
121    pub fn new(config: HttpServerConfig) -> Self {
122        Self {
123            config,
124            router: Router::new(),
125        }
126    }
127
128    /// Set the router
129    pub fn with_router(mut self, router: Router) -> Self {
130        self.router = router;
131        self
132    }
133
134    /// Start the server
135    pub async fn serve(self) -> Result<()> {
136        let listener = tokio::net::TcpListener::bind(self.config.bind_addr)
137            .await
138            .map_err(|e| ForgeError::network(format!("Failed to bind: {}", e)))?;
139
140        info!(addr = %self.config.bind_addr, "HTTP server starting");
141
142        axum::serve(listener, self.router)
143            .await
144            .map_err(|e| ForgeError::network(format!("Server error: {}", e)))?;
145
146        Ok(())
147    }
148}
149
/// QUIC transport configuration.
#[derive(Debug, Clone)]
pub struct QuicConfig {
    /// Address the QUIC endpoint binds to.
    pub bind_addr: SocketAddr,
    /// Server name for TLS; passed as the server name when connecting to a peer.
    pub server_name: String,
    /// Max concurrent streams.
    pub max_streams: u32,
    /// Idle timeout in seconds.
    pub idle_timeout_secs: u64,
}
162
163impl Default for QuicConfig {
164    fn default() -> Self {
165        Self {
166            bind_addr: ([0, 0, 0, 0], 4433).into(),
167            server_name: "forge".to_string(),
168            max_streams: 100,
169            idle_timeout_secs: 30,
170        }
171    }
172}
173
/// QUIC transport for peer communication.
pub struct QuicTransport {
    /// Transport settings.
    config: QuicConfig,
    /// Underlying endpoint; `None` until `start_server` has succeeded.
    endpoint: Option<quinn::Endpoint>,
}
179
180impl QuicTransport {
181    /// Create a new QUIC transport
182    pub fn new(config: QuicConfig) -> Self {
183        Self {
184            config,
185            endpoint: None,
186        }
187    }
188
189    /// Generate self-signed certificate for development
190    fn generate_self_signed_cert() -> Result<(rustls::pki_types::CertificateDer<'static>, rustls::pki_types::PrivateKeyDer<'static>)> {
191        let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
192            .map_err(|e| ForgeError::network(format!("Failed to generate cert: {}", e)))?;
193
194        let cert_der = rustls::pki_types::CertificateDer::from(cert.cert.der().to_vec());
195        let key_der = rustls::pki_types::PrivateKeyDer::try_from(cert.key_pair.serialize_der())
196            .map_err(|e| ForgeError::network(format!("Failed to serialize key: {}", e)))?;
197
198        Ok((cert_der, key_der))
199    }
200
201    /// Start as server
202    pub async fn start_server(&mut self) -> Result<()> {
203        let (cert, key) = Self::generate_self_signed_cert()?;
204
205        let mut server_crypto = rustls::ServerConfig::builder()
206            .with_no_client_auth()
207            .with_single_cert(vec![cert], key)
208            .map_err(|e| ForgeError::network(format!("TLS config error: {}", e)))?;
209
210        server_crypto.alpn_protocols = vec![b"forge".to_vec()];
211
212        let server_config = quinn::ServerConfig::with_crypto(Arc::new(
213            quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto)
214                .map_err(|e| ForgeError::network(format!("QUIC config error: {}", e)))?,
215        ));
216
217        let endpoint = quinn::Endpoint::server(server_config, self.config.bind_addr)
218            .map_err(|e| ForgeError::network(format!("Failed to create endpoint: {}", e)))?;
219
220        info!(addr = %self.config.bind_addr, "QUIC server started");
221        self.endpoint = Some(endpoint);
222
223        Ok(())
224    }
225
226    /// Accept incoming connections
227    pub async fn accept(&self) -> Result<quinn::Connection> {
228        let endpoint = self
229            .endpoint
230            .as_ref()
231            .ok_or_else(|| ForgeError::network("Endpoint not started"))?;
232
233        let incoming = endpoint
234            .accept()
235            .await
236            .ok_or_else(|| ForgeError::network("Endpoint closed"))?;
237
238        let conn = incoming
239            .await
240            .map_err(|e| ForgeError::network(format!("Connection error: {}", e)))?;
241
242        info!(remote = %conn.remote_address(), "QUIC connection accepted");
243        Ok(conn)
244    }
245
246    /// Connect to a peer
247    pub async fn connect(&self, addr: SocketAddr) -> Result<quinn::Connection> {
248        let endpoint = self
249            .endpoint
250            .as_ref()
251            .ok_or_else(|| ForgeError::network("Endpoint not started"))?;
252
253        let conn = endpoint
254            .connect(addr, &self.config.server_name)
255            .map_err(|e| ForgeError::network(format!("Connect error: {}", e)))?
256            .await
257            .map_err(|e| ForgeError::network(format!("Connection error: {}", e)))?;
258
259        info!(remote = %addr, "QUIC connection established");
260        Ok(conn)
261    }
262
263    /// Get local address
264    pub fn local_addr(&self) -> Option<SocketAddr> {
265        self.endpoint.as_ref().and_then(|e| e.local_addr().ok())
266    }
267
268    /// Close the transport
269    pub fn close(&self) {
270        if let Some(endpoint) = &self.endpoint {
271            endpoint.close(0u32.into(), b"shutdown");
272        }
273    }
274}
275
/// Message types for peer communication.
///
/// Messages are converted to and from bytes with `to_bytes` / `from_bytes`,
/// which use JSON via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PeerMessage {
    /// Heartbeat ping
    // NOTE(review): the unit of `timestamp` is not established in this file — confirm.
    Ping { node_id: String, timestamp: u64 },
    /// Heartbeat pong
    Pong { node_id: String, timestamp: u64 },
    /// Route request to expert
    RouteRequest { request_id: String, input: String },
    /// Route response from expert
    RouteResponse {
        request_id: String,
        expert_index: usize,
        result: Vec<u8>,
    },
    /// Shard assignment notification
    ShardAssign { shard_id: u64, node_id: String },
    /// Shard migration request
    ShardMigrate {
        shard_id: u64,
        from_node: String,
        to_node: String,
    },
}
300
301impl PeerMessage {
302    /// Serialize to bytes
303    pub fn to_bytes(&self) -> Result<Vec<u8>> {
304        serde_json::to_vec(self).map_err(|e| ForgeError::network(format!("Serialize error: {}", e)))
305    }
306
307    /// Deserialize from bytes
308    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
309        serde_json::from_slice(bytes)
310            .map_err(|e| ForgeError::network(format!("Deserialize error: {}", e)))
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// The default HTTP configuration listens on port 8080 with CORS enabled.
    #[test]
    fn test_http_config_default() {
        let HttpServerConfig {
            bind_addr,
            cors_enabled,
            ..
        } = HttpServerConfig::default();

        assert_eq!(bind_addr.port(), 8080);
        assert!(cors_enabled);
    }

    /// The default QUIC configuration listens on port 4433 as `forge`.
    #[test]
    fn test_quic_config_default() {
        let defaults = QuicConfig::default();

        assert_eq!(defaults.bind_addr.port(), 4433);
        assert_eq!(defaults.server_name, "forge");
    }

    /// A `Ping` survives a round trip through `to_bytes` / `from_bytes`.
    #[test]
    fn test_peer_message_serialization() {
        let original = PeerMessage::Ping {
            node_id: "node-1".to_string(),
            timestamp: 12345,
        };

        let encoded = original.to_bytes().unwrap();
        let round_tripped = PeerMessage::from_bytes(&encoded).unwrap();

        if let PeerMessage::Ping { node_id, timestamp } = round_tripped {
            assert_eq!(node_id, "node-1");
            assert_eq!(timestamp, 12345);
        } else {
            panic!("Wrong message type");
        }
    }
}