forge_orchestration/
networking.rs

//! Networking module for Forge
//!
//! ## Table of Contents
//! - **QuicTransport**: QUIC-based peer communication (requires the `quic` feature)
//! - **HttpServer**: Axum-based HTTP/REST API server
//! - **PeerMessage**: serializable messages exchanged between peers
//! - **GrpcServer**: Tonic-based gRPC server (planned; not yet implemented in this module)
7
8use crate::error::{ForgeError, Result};
9use axum::{
10    http::StatusCode,
11    response::IntoResponse,
12    routing::get,
13    Json, Router,
14};
15use serde::{Deserialize, Serialize};
16use std::net::SocketAddr;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::info;
20
21#[cfg(feature = "quic")]
22use std::sync::Arc as QuicArc;
23
/// Settings for the HTTP/REST server.
///
/// Defaults: listen on every IPv4 interface at port 8080, CORS flag on,
/// 30-second request timeout.
#[derive(Clone, Debug)]
pub struct HttpServerConfig {
    /// Socket address the server binds to.
    pub bind_addr: SocketAddr,
    /// Whether CORS should be enabled.
    pub cors_enabled: bool,
    /// Request timeout, in seconds.
    pub timeout_secs: u64,
}

impl Default for HttpServerConfig {
    fn default() -> Self {
        let bind_addr = SocketAddr::from(([0, 0, 0, 0], 8080));
        Self {
            bind_addr,
            timeout_secs: 30,
            cors_enabled: true,
        }
    }
}
44
45impl HttpServerConfig {
46    /// Create with custom bind address
47    pub fn with_addr(mut self, addr: SocketAddr) -> Self {
48        self.bind_addr = addr;
49        self
50    }
51
52    /// Parse from string address
53    pub fn with_addr_str(mut self, addr: &str) -> Result<Self> {
54        self.bind_addr = addr
55            .parse()
56            .map_err(|e| ForgeError::config(format!("Invalid address: {}", e)))?;
57        Ok(self)
58    }
59}
60
/// State handed to every HTTP handler.
///
/// Cloning is cheap: all clones share the same `Arc`-wrapped application state.
pub struct HttpState<T> {
    /// Application state, guarded by a read/write lock.
    pub app: Arc<RwLock<T>>,
}

// Implemented by hand rather than derived so that `T` itself does not have to
// be `Clone`; only the `Arc` handle is duplicated.
impl<T> Clone for HttpState<T> {
    fn clone(&self) -> Self {
        let app = self.app.clone();
        Self { app }
    }
}
74
75/// Health check response
76#[derive(Debug, Serialize)]
77pub struct HealthResponse {
78    pub status: String,
79    pub version: String,
80    pub uptime_secs: u64,
81}
82
83/// Error response
84#[derive(Debug, Serialize)]
85pub struct ErrorResponse {
86    pub error: String,
87    pub code: u16,
88}
89
90impl IntoResponse for ErrorResponse {
91    fn into_response(self) -> axum::response::Response {
92        let status = StatusCode::from_u16(self.code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
93        (status, Json(self)).into_response()
94    }
95}
96
97/// Create the base router with health endpoints
98pub fn base_router<T: Send + Sync + 'static>() -> Router<HttpState<T>> {
99    Router::new()
100        .route("/health", get(health_handler))
101        .route("/ready", get(ready_handler))
102}
103
104async fn health_handler() -> Json<HealthResponse> {
105    Json(HealthResponse {
106        status: "healthy".to_string(),
107        version: env!("CARGO_PKG_VERSION").to_string(),
108        uptime_secs: 0, // TODO: Track actual uptime
109    })
110}
111
112async fn ready_handler() -> StatusCode {
113    StatusCode::OK
114}
115
116/// HTTP server wrapper
117pub struct HttpServer {
118    config: HttpServerConfig,
119    router: Router,
120}
121
122impl HttpServer {
123    /// Create a new HTTP server
124    pub fn new(config: HttpServerConfig) -> Self {
125        Self {
126            config,
127            router: Router::new(),
128        }
129    }
130
131    /// Set the router
132    pub fn with_router(mut self, router: Router) -> Self {
133        self.router = router;
134        self
135    }
136
137    /// Start the server
138    pub async fn serve(self) -> Result<()> {
139        let listener = tokio::net::TcpListener::bind(self.config.bind_addr)
140            .await
141            .map_err(|e| ForgeError::network(format!("Failed to bind: {}", e)))?;
142
143        info!(addr = %self.config.bind_addr, "HTTP server starting");
144
145        axum::serve(listener, self.router)
146            .await
147            .map_err(|e| ForgeError::network(format!("Server error: {}", e)))?;
148
149        Ok(())
150    }
151}
152
/// Settings for a QUIC transport endpoint.
///
/// Defaults: listen on every IPv4 interface at port 4433, TLS server name
/// `forge`, 100 streams, 30-second idle timeout.
#[cfg(feature = "quic")]
#[derive(Clone, Debug)]
pub struct QuicConfig {
    /// Socket address the endpoint binds to.
    pub bind_addr: SocketAddr,
    /// Server name used for TLS.
    pub server_name: String,
    /// Maximum number of concurrent streams.
    pub max_streams: u32,
    /// Idle timeout, in seconds.
    pub idle_timeout_secs: u64,
}

#[cfg(feature = "quic")]
impl Default for QuicConfig {
    fn default() -> Self {
        let bind_addr = SocketAddr::from(([0, 0, 0, 0], 4433));
        Self {
            bind_addr,
            server_name: String::from("forge"),
            idle_timeout_secs: 30,
            max_streams: 100,
        }
    }
}
178
/// QUIC transport for peer communication.
///
/// Holds the transport configuration and, once the server has been started,
/// the bound `quinn` endpoint.
#[cfg(feature = "quic")]
pub struct QuicTransport {
    /// Settings supplied at construction.
    config: QuicConfig,
    /// Bound endpoint; `None` until the server has been started.
    endpoint: Option<quinn::Endpoint>,
}
185
#[cfg(feature = "quic")]
impl QuicTransport {
    /// Create a new QUIC transport.
    ///
    /// No socket is bound until [`start_server`](Self::start_server) is called.
    pub fn new(config: QuicConfig) -> Self {
        Self {
            config,
            endpoint: None,
        }
    }

    /// Generate a self-signed certificate for development.
    ///
    /// The subject alternative names are `localhost` plus `server_name` (when it
    /// is non-empty and different). A peer dialing with the configured server
    /// name can therefore match the certificate.
    fn generate_self_signed_cert(
        server_name: &str,
    ) -> Result<(
        rustls::pki_types::CertificateDer<'static>,
        rustls::pki_types::PrivateKeyDer<'static>,
    )> {
        let mut subject_alt_names = vec!["localhost".to_string()];
        if !server_name.is_empty() && server_name != "localhost" {
            subject_alt_names.push(server_name.to_string());
        }

        let cert = rcgen::generate_simple_self_signed(subject_alt_names)
            .map_err(|e| ForgeError::network(format!("Failed to generate cert: {}", e)))?;

        let cert_der = rustls::pki_types::CertificateDer::from(cert.cert.der().to_vec());
        let key_der = rustls::pki_types::PrivateKeyDer::try_from(cert.key_pair.serialize_der())
            .map_err(|e| ForgeError::network(format!("Failed to serialize key: {}", e)))?;

        Ok((cert_der, key_der))
    }

    /// Build QUIC transport parameters from the configured limits.
    ///
    /// `max_streams` caps both bidirectional and unidirectional concurrent
    /// streams. An `idle_timeout_secs` of `0` disables the idle timeout.
    ///
    /// # Errors
    /// Returns a network error if the idle timeout cannot be encoded.
    fn transport_config(&self) -> Result<quinn::TransportConfig> {
        let mut transport = quinn::TransportConfig::default();

        let streams = quinn::VarInt::from_u32(self.config.max_streams);
        transport.max_concurrent_bidi_streams(streams);
        transport.max_concurrent_uni_streams(streams);

        // RFC 9000 treats a zero idle timeout as "disabled"; quinn spells that `None`.
        let idle_timeout = match self.config.idle_timeout_secs {
            0 => None,
            secs => Some(
                quinn::IdleTimeout::try_from(std::time::Duration::from_secs(secs))
                    .map_err(|e| ForgeError::network(format!("Invalid idle timeout: {}", e)))?,
            ),
        };
        transport.max_idle_timeout(idle_timeout);

        Ok(transport)
    }

    /// Start as server.
    ///
    /// Binds a QUIC endpoint to `config.bind_addr` with:
    /// - a freshly generated self-signed certificate;
    /// - ALPN protocol `forge`;
    /// - the configured stream and idle-timeout limits.
    ///
    /// # Errors
    /// Returns a network error if certificate generation, TLS/QUIC configuration,
    /// or binding the endpoint fails.
    pub async fn start_server(&mut self) -> Result<()> {
        let (cert, key) = Self::generate_self_signed_cert(&self.config.server_name)?;

        let mut server_crypto = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![cert], key)
            .map_err(|e| ForgeError::network(format!("TLS config error: {}", e)))?;

        server_crypto.alpn_protocols = vec![b"forge".to_vec()];

        let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
            quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto)
                .map_err(|e| ForgeError::network(format!("QUIC config error: {}", e)))?,
        ));
        // Apply the configured stream and idle-timeout limits to every connection.
        server_config.transport_config(Arc::new(self.transport_config()?));

        let endpoint = quinn::Endpoint::server(server_config, self.config.bind_addr)
            .map_err(|e| ForgeError::network(format!("Failed to create endpoint: {}", e)))?;

        // Log the address actually bound; it differs from the config when port 0 is used.
        let bound = endpoint.local_addr().unwrap_or(self.config.bind_addr);
        info!(addr = %bound, "QUIC server started");
        self.endpoint = Some(endpoint);

        Ok(())
    }

    /// Install the default client configuration used by [`connect`](Self::connect).
    ///
    /// Endpoints created by `start_server` have no default client configuration,
    /// so outbound connections fail until one is provided here. The client's TLS
    /// configuration should advertise the `forge` ALPN protocol to match the server.
    ///
    /// # Errors
    /// Returns a network error if the endpoint has not been started.
    pub fn set_client_config(&mut self, config: quinn::ClientConfig) -> Result<()> {
        let endpoint = self
            .endpoint
            .as_mut()
            .ok_or_else(|| ForgeError::network("Endpoint not started"))?;
        endpoint.set_default_client_config(config);
        Ok(())
    }

    /// Accept the next incoming connection and complete its handshake.
    ///
    /// # Errors
    /// Returns a network error if the endpoint is not started, has been closed,
    /// or the handshake fails.
    pub async fn accept(&self) -> Result<quinn::Connection> {
        let endpoint = self
            .endpoint
            .as_ref()
            .ok_or_else(|| ForgeError::network("Endpoint not started"))?;

        let incoming = endpoint
            .accept()
            .await
            .ok_or_else(|| ForgeError::network("Endpoint closed"))?;

        let conn = incoming
            .await
            .map_err(|e| ForgeError::network(format!("Connection error: {}", e)))?;

        info!(remote = %conn.remote_address(), "QUIC connection accepted");
        Ok(conn)
    }

    /// Connect to a peer at `addr`, using the configured `server_name` for TLS.
    ///
    /// Requires a started endpoint with a default client configuration; see
    /// [`set_client_config`](Self::set_client_config).
    ///
    /// # Errors
    /// Returns a network error if the endpoint is not started, the connection
    /// cannot be initiated (e.g. no client configuration), or the handshake fails.
    pub async fn connect(&self, addr: SocketAddr) -> Result<quinn::Connection> {
        let endpoint = self
            .endpoint
            .as_ref()
            .ok_or_else(|| ForgeError::network("Endpoint not started"))?;

        let conn = endpoint
            .connect(addr, &self.config.server_name)
            .map_err(|e| ForgeError::network(format!("Connect error: {}", e)))?
            .await
            .map_err(|e| ForgeError::network(format!("Connection error: {}", e)))?;

        info!(remote = %addr, "QUIC connection established");
        Ok(conn)
    }

    /// Local address of the bound endpoint, or `None` if not started or unavailable.
    pub fn local_addr(&self) -> Option<SocketAddr> {
        self.endpoint.as_ref().and_then(|e| e.local_addr().ok())
    }

    /// Close the endpoint (error code 0, reason `shutdown`); no-op if not started.
    pub fn close(&self) {
        if let Some(endpoint) = &self.endpoint {
            endpoint.close(0u32.into(), b"shutdown");
        }
    }
}
282
283/// Message types for peer communication
284#[derive(Debug, Clone, Serialize, Deserialize)]
285pub enum PeerMessage {
286    /// Heartbeat ping
287    Ping { node_id: String, timestamp: u64 },
288    /// Heartbeat pong
289    Pong { node_id: String, timestamp: u64 },
290    /// Route request to expert
291    RouteRequest { request_id: String, input: String },
292    /// Route response from expert
293    RouteResponse {
294        request_id: String,
295        expert_index: usize,
296        result: Vec<u8>,
297    },
298    /// Shard assignment notification
299    ShardAssign { shard_id: u64, node_id: String },
300    /// Shard migration request
301    ShardMigrate {
302        shard_id: u64,
303        from_node: String,
304        to_node: String,
305    },
306}
307
308impl PeerMessage {
309    /// Serialize to bytes
310    pub fn to_bytes(&self) -> Result<Vec<u8>> {
311        serde_json::to_vec(self).map_err(|e| ForgeError::network(format!("Serialize error: {}", e)))
312    }
313
314    /// Deserialize from bytes
315    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
316        serde_json::from_slice(bytes)
317            .map_err(|e| ForgeError::network(format!("Deserialize error: {}", e)))
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode then decode a message, panicking on any error.
    fn roundtrip(msg: &PeerMessage) -> PeerMessage {
        let bytes = msg.to_bytes().unwrap();
        PeerMessage::from_bytes(&bytes).unwrap()
    }

    #[test]
    fn test_http_config_default() {
        let config = HttpServerConfig::default();
        assert_eq!(config.bind_addr.port(), 8080);
        assert!(config.cors_enabled);
        assert_eq!(config.timeout_secs, 30);
    }

    #[test]
    fn test_http_config_with_addr_str() {
        let config = HttpServerConfig::default()
            .with_addr_str("127.0.0.1:9000")
            .unwrap();
        assert_eq!(config.bind_addr, SocketAddr::from(([127, 0, 0, 1], 9000)));

        // Missing port and non-address input must both be rejected.
        assert!(HttpServerConfig::default().with_addr_str("127.0.0.1").is_err());
        assert!(HttpServerConfig::default()
            .with_addr_str("not-an-address")
            .is_err());
    }

    #[test]
    #[cfg(feature = "quic")]
    fn test_quic_config_default() {
        let config = QuicConfig::default();
        assert_eq!(config.bind_addr.port(), 4433);
        assert_eq!(config.server_name, "forge");
    }

    #[test]
    fn test_peer_message_serialization() {
        let msg = PeerMessage::Ping {
            node_id: "node-1".to_string(),
            timestamp: 12345,
        };

        match roundtrip(&msg) {
            PeerMessage::Ping { node_id, timestamp } => {
                assert_eq!(node_id, "node-1");
                assert_eq!(timestamp, 12345);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_peer_message_binary_payload_roundtrip() {
        let msg = PeerMessage::RouteResponse {
            request_id: "req-7".to_string(),
            expert_index: 3,
            result: vec![0, 1, 254, 255],
        };

        match roundtrip(&msg) {
            PeerMessage::RouteResponse {
                request_id,
                expert_index,
                result,
            } => {
                assert_eq!(request_id, "req-7");
                assert_eq!(expert_index, 3);
                assert_eq!(result, vec![0, 1, 254, 255]);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_peer_message_rejects_malformed_bytes() {
        assert!(PeerMessage::from_bytes(b"not json").is_err());
        assert!(PeerMessage::from_bytes(b"").is_err());
    }
}