Skip to main content

ormdb_server/
transport.rs

1//! Server transport layer using async-nng.
2//!
3//! Provides TCP and IPC transport for the ORMDB server using NNG's REP socket.
4
5use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6use std::sync::Arc;
7use std::thread;
8use std::time::{Duration, Instant};
9
10use async_nng::AsyncContext;
11use nng::options::Options;
12use nng::{Message, Protocol, Socket};
13
14use ormdb_proto::framing::encode_frame;
15use ormdb_proto::{Request, Response};
16
17use crate::config::ServerConfig;
18use crate::error::Error;
19use crate::handler::RequestHandler;
20
/// Traffic counters exposed by the transport for monitoring.
///
/// All counters are atomics, so they can be updated through a shared reference.
#[derive(Debug)]
pub struct TransportMetrics {
    /// How many requests have been recorded, regardless of outcome.
    pub requests_total: AtomicU64,
    /// How many recorded requests succeeded.
    pub requests_success: AtomicU64,
    /// How many recorded requests failed.
    pub requests_failed: AtomicU64,
    /// Running total of request bytes received.
    pub bytes_received: AtomicU64,
    /// Running total of response bytes sent.
    pub bytes_sent: AtomicU64,
    /// Moment the server started.
    pub started_at: Instant,
}
37
38impl TransportMetrics {
39    /// Create new metrics.
40    fn new() -> Self {
41        Self {
42            requests_total: AtomicU64::new(0),
43            requests_success: AtomicU64::new(0),
44            requests_failed: AtomicU64::new(0),
45            bytes_received: AtomicU64::new(0),
46            bytes_sent: AtomicU64::new(0),
47            started_at: Instant::now(),
48        }
49    }
50
51    /// Record a successful request.
52    fn record_success(&self, received_bytes: usize, sent_bytes: usize) {
53        self.requests_total.fetch_add(1, Ordering::Relaxed);
54        self.requests_success.fetch_add(1, Ordering::Relaxed);
55        self.bytes_received.fetch_add(received_bytes as u64, Ordering::Relaxed);
56        self.bytes_sent.fetch_add(sent_bytes as u64, Ordering::Relaxed);
57    }
58
59    /// Record a failed request.
60    fn record_failure(&self, received_bytes: usize, sent_bytes: usize) {
61        self.requests_total.fetch_add(1, Ordering::Relaxed);
62        self.requests_failed.fetch_add(1, Ordering::Relaxed);
63        self.bytes_received.fetch_add(received_bytes as u64, Ordering::Relaxed);
64        self.bytes_sent.fetch_add(sent_bytes as u64, Ordering::Relaxed);
65    }
66
67    /// Get the uptime duration.
68    pub fn uptime(&self) -> Duration {
69        self.started_at.elapsed()
70    }
71
72    /// Get total requests count.
73    pub fn total_requests(&self) -> u64 {
74        self.requests_total.load(Ordering::Relaxed)
75    }
76
77    /// Get successful requests count.
78    pub fn successful_requests(&self) -> u64 {
79        self.requests_success.load(Ordering::Relaxed)
80    }
81
82    /// Get failed requests count.
83    pub fn failed_requests(&self) -> u64 {
84        self.requests_failed.load(Ordering::Relaxed)
85    }
86
87    /// Get total bytes received.
88    pub fn total_bytes_received(&self) -> u64 {
89        self.bytes_received.load(Ordering::Relaxed)
90    }
91
92    /// Get total bytes sent.
93    pub fn total_bytes_sent(&self) -> u64 {
94        self.bytes_sent.load(Ordering::Relaxed)
95    }
96}
97
98impl Default for TransportMetrics {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104/// Server transport that handles incoming connections.
105pub struct Transport {
106    socket: Socket,
107    handler: Arc<RequestHandler>,
108    max_message_size: usize,
109    metrics: Arc<TransportMetrics>,
110    request_timeout: Duration,
111    worker_count: usize,
112}
113
114impl Transport {
115    /// Create a new transport with the given configuration and request handler.
116    pub fn new(config: &ServerConfig, handler: Arc<RequestHandler>) -> Result<Self, Error> {
117        // Create REP socket
118        let socket = Socket::new(Protocol::Rep0)
119            .map_err(|e| Error::Transport(format!("failed to create socket: {}", e)))?;
120
121        // Set socket options
122        socket
123            .set_opt::<nng::options::RecvMaxSize>(config.max_message_size)
124            .map_err(|e| Error::Transport(format!("failed to set max message size: {}", e)))?;
125
126        // Bind to TCP address if configured
127        if let Some(tcp_addr) = &config.tcp_address {
128            socket
129                .listen(tcp_addr)
130                .map_err(|e| Error::Transport(format!("failed to listen on {}: {}", tcp_addr, e)))?;
131
132            tracing::info!(address = %tcp_addr, "listening on TCP");
133        }
134
135        // Bind to IPC address if configured
136        if let Some(ipc_addr) = &config.ipc_address {
137            socket
138                .listen(ipc_addr)
139                .map_err(|e| Error::Transport(format!("failed to listen on {}: {}", ipc_addr, e)))?;
140
141            tracing::info!(address = %ipc_addr, "listening on IPC");
142        }
143
144        Ok(Self {
145            socket,
146            handler,
147            max_message_size: config.max_message_size,
148            metrics: Arc::new(TransportMetrics::new()),
149            request_timeout: config.request_timeout,
150            worker_count: config.transport_workers.max(1),
151        })
152    }
153
154    /// Get a reference to the transport metrics.
155    pub fn metrics(&self) -> &TransportMetrics {
156        &self.metrics
157    }
158
159    /// Run the transport loop, processing incoming requests.
160    pub async fn run(&self) -> Result<(), Error> {
161        let stop_flag = Arc::new(AtomicBool::new(false));
162        let _handles = self.spawn_worker_threads(stop_flag)?;
163
164        tracing::info!("transport ready, accepting requests");
165        std::future::pending::<()>().await;
166        Ok(())
167    }
168
169    /// Run the transport with graceful shutdown support.
170    pub async fn run_until_shutdown(
171        &self,
172        mut shutdown: tokio::sync::broadcast::Receiver<()>,
173    ) -> Result<(), Error> {
174        let stop_flag = Arc::new(AtomicBool::new(false));
175        let handles = self.spawn_worker_threads(stop_flag.clone())?;
176
177        tracing::info!("transport ready, accepting requests");
178
179        let _ = shutdown.recv().await;
180        tracing::info!(
181            total_requests = self.metrics.total_requests(),
182            successful = self.metrics.successful_requests(),
183            failed = self.metrics.failed_requests(),
184            bytes_received = self.metrics.total_bytes_received(),
185            bytes_sent = self.metrics.total_bytes_sent(),
186            uptime_secs = self.metrics.uptime().as_secs(),
187            "shutdown signal received, stopping transport"
188        );
189
190        stop_flag.store(true, Ordering::SeqCst);
191        let _ = tokio::task::spawn_blocking(move || {
192            for handle in handles {
193                let _ = handle.join();
194            }
195        })
196        .await;
197
198        Ok(())
199    }
200
201    /// Process a raw message and return the response bytes.
202    fn process_message(&self, data: &[u8]) -> Vec<u8> {
203        self.worker().process_message_with_status(data).0
204    }
205
206    /// Process a raw message and return (response bytes, is_success).
207    fn process_message_with_status(&self, data: &[u8]) -> (Vec<u8>, bool) {
208        self.worker().process_message_with_status(data)
209    }
210
211    fn worker(&self) -> TransportWorker {
212        TransportWorker::new(self.handler.clone(), self.max_message_size)
213    }
214
215    fn spawn_worker_threads(
216        &self,
217        stop_flag: Arc<AtomicBool>,
218    ) -> Result<Vec<thread::JoinHandle<()>>, Error> {
219        let mut handles = Vec::with_capacity(self.worker_count);
220        for worker_id in 0..self.worker_count {
221            let socket = self.socket.clone();
222            let worker = self.worker();
223            let metrics = self.metrics.clone();
224            let request_timeout = self.request_timeout;
225            let stop_flag = stop_flag.clone();
226
227            let handle = thread::Builder::new()
228                .name(format!("ormdb-transport-{}", worker_id))
229                .spawn(move || {
230                    let runtime = tokio::runtime::Builder::new_current_thread()
231                        .enable_all()
232                        .build()
233                        .expect("failed to build transport worker runtime");
234
235                    runtime.block_on(async move {
236                        let mut ctx = match AsyncContext::try_from(&socket) {
237                            Ok(ctx) => ctx,
238                            Err(e) => {
239                                tracing::error!(error = %e, worker_id, "failed to create async context");
240                                return;
241                            }
242                        };
243
244                        loop {
245                            if stop_flag.load(Ordering::SeqCst) {
246                                tracing::info!(worker_id, "transport worker stopping");
247                                return;
248                            }
249
250                            match ctx.receive(Some(Duration::from_secs(1))).await {
251                                Ok(msg) => {
252                                    let received_bytes = msg.len();
253                                    let start = Instant::now();
254                                    let (response_bytes, is_success) =
255                                        worker.process_message_with_status(msg.as_slice());
256                                    let elapsed = start.elapsed();
257                                    let sent_bytes = response_bytes.len();
258
259                                    let response_msg = Message::from(response_bytes.as_slice());
260
261                                    if let Err((_, e)) = ctx.send(response_msg, None).await {
262                                        tracing::error!(error = %e, worker_id, "failed to send response");
263                                        metrics.record_failure(received_bytes, 0);
264                                    } else if is_success {
265                                        metrics.record_success(received_bytes, sent_bytes);
266                                    } else {
267                                        metrics.record_failure(received_bytes, sent_bytes);
268                                    }
269
270                                    if elapsed > request_timeout {
271                                        tracing::warn!(
272                                            worker_id,
273                                            duration_ms = elapsed.as_millis() as u64,
274                                            timeout_ms = request_timeout.as_millis() as u64,
275                                            "request exceeded timeout"
276                                        );
277                                    }
278                                }
279                                Err(nng::Error::TimedOut) => {
280                                    continue;
281                                }
282                                Err(e) => {
283                                    tracing::error!(error = %e, worker_id, "receive error");
284                                }
285                            }
286                        }
287                    });
288                })
289                .map_err(|e| Error::Transport(format!("failed to spawn transport worker: {}", e)))?;
290
291            handles.push(handle);
292        }
293
294        Ok(handles)
295    }
296}
297
298struct TransportWorker {
299    handler: Arc<RequestHandler>,
300    max_message_size: usize,
301}
302
303impl TransportWorker {
304    fn new(handler: Arc<RequestHandler>, max_message_size: usize) -> Self {
305        Self {
306            handler,
307            max_message_size,
308        }
309    }
310
311    /// Process a raw message and return (response bytes, is_success).
312    fn process_message_with_status(&self, data: &[u8]) -> (Vec<u8>, bool) {
313        // Decode and process the request
314        let (response, is_success) = match self.decode_and_handle(data) {
315            Ok(response) => {
316                let is_ok = response.status.is_ok();
317                (response, is_ok)
318            }
319            Err(e) => {
320                tracing::error!(error = %e, "request processing error");
321                // Return error response with request ID 0 (unknown)
322                let response = Response::error(0, ormdb_proto::error_codes::INTERNAL, e.to_string());
323                (response, false)
324            }
325        };
326
327        // Serialize response
328        let bytes = match self.encode_response(&response) {
329            Ok(bytes) => bytes,
330            Err(e) => {
331                tracing::error!(error = %e, "failed to encode response");
332                // Try to send a minimal error response
333                self.encode_minimal_error(&e.to_string())
334            }
335        };
336
337        (bytes, is_success)
338    }
339
340    /// Decode a request and dispatch to handler.
341    fn decode_and_handle(&self, data: &[u8]) -> Result<Response, Error> {
342        // Check message size
343        if data.len() > self.max_message_size {
344            return Err(Error::Protocol(ormdb_proto::Error::InvalidMessage(format!(
345                "message too large: {} bytes (max: {})",
346                data.len(),
347                self.max_message_size
348            ))));
349        }
350
351        // Extract payload from framed message
352        let payload = ormdb_proto::framing::extract_payload(data)?;
353
354        // Copy to aligned buffer for rkyv (required for zero-copy access)
355        let mut aligned: rkyv::util::AlignedVec<16> = rkyv::util::AlignedVec::new();
356        aligned.extend_from_slice(payload);
357
358        // Deserialize request using rkyv
359        let request: Request =
360            rkyv::from_bytes::<Request, rkyv::rancor::Error>(&aligned).map_err(|e| {
361                Error::Protocol(ormdb_proto::Error::InvalidMessage(format!(
362                    "failed to deserialize request: {}",
363                    e
364                )))
365            })?;
366
367        // Handle the request
368        Ok(self.handler.handle(&request))
369    }
370
371    /// Encode a response to framed bytes.
372    fn encode_response(&self, response: &Response) -> Result<Vec<u8>, Error> {
373        let payload = rkyv::to_bytes::<rkyv::rancor::Error>(response).map_err(|e| {
374            Error::Protocol(ormdb_proto::Error::Serialization(format!(
375                "failed to serialize response: {}",
376                e
377            )))
378        })?;
379
380        encode_frame(&payload).map_err(|e| Error::Protocol(e))
381    }
382
383    /// Create a minimal error response when normal encoding fails.
384    fn encode_minimal_error(&self, message: &str) -> Vec<u8> {
385        let response = Response::error(0, ormdb_proto::error_codes::INTERNAL, message);
386
387        // Try to encode, fall back to empty on failure
388        match rkyv::to_bytes::<rkyv::rancor::Error>(&response) {
389            Ok(payload) => match encode_frame(&payload) {
390                Ok(framed) => framed,
391                Err(_) => Vec::new(),
392            },
393            Err(_) => Vec::new(),
394        }
395    }
396}
397
398
399/// Create a transport that listens on the configured addresses.
400pub fn create_transport(
401    config: &ServerConfig,
402    handler: Arc<RequestHandler>,
403) -> Result<Transport, Error> {
404    if !config.has_transport() {
405        return Err(Error::Config(
406            "no transport configured (need TCP or IPC address)".to_string(),
407        ));
408    }
409
410    Transport::new(config, handler)
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use crate::database::Database;
    use ormdb_core::catalog::{EntityDef, FieldDef, FieldType, ScalarType, SchemaBundle};
    use ormdb_proto::framing::MAX_MESSAGE_SIZE;

    /// Open a database in a fresh temp dir, apply a one-entity schema and wrap
    /// it in a request handler. The temp dir is returned so it outlives the test.
    fn setup_test_components() -> (tempfile::TempDir, Arc<RequestHandler>) {
        let dir = tempfile::tempdir().unwrap();
        let db = Database::open(dir.path()).unwrap();

        let user = EntityDef::new("User", "id")
            .with_field(FieldDef::new("id", FieldType::Scalar(ScalarType::Uuid)))
            .with_field(FieldDef::new("name", FieldType::Scalar(ScalarType::String)));
        let schema = SchemaBundle::new(1).with_entity(user);
        db.catalog().apply_schema(schema).unwrap();

        let handler = Arc::new(RequestHandler::new(Arc::new(db)));
        (dir, handler)
    }

    /// Serialize and frame a ping request carrying `request_id`.
    fn framed_ping(request_id: u64) -> Vec<u8> {
        let request = Request::ping(request_id);
        let payload = rkyv::to_bytes::<rkyv::rancor::Error>(&request).unwrap();
        encode_frame(&payload).unwrap()
    }

    /// Unframe and deserialize a response, copying into an aligned buffer for rkyv.
    fn decode_response(framed: &[u8]) -> Response {
        let payload = ormdb_proto::framing::extract_payload(framed).unwrap();
        let mut aligned: rkyv::util::AlignedVec<16> = rkyv::util::AlignedVec::new();
        aligned.extend_from_slice(payload);
        rkyv::from_bytes::<Response, rkyv::rancor::Error>(&aligned).unwrap()
    }

    #[test]
    fn test_transport_creation() {
        let (dir, handler) = setup_test_components();

        let ipc_path = format!("ipc://{}", dir.path().join("ormdb.sock").display());
        let config = ServerConfig::new(dir.path())
            .without_tcp()
            .with_ipc_address(ipc_path)
            .with_max_message_size(MAX_MESSAGE_SIZE);

        match Transport::new(&config, handler) {
            Ok(_) => {}
            // Some environments forbid creating the IPC socket; treat that as a skip.
            Err(Error::Transport(msg)) if msg.contains("Permission denied") => {}
            Err(err) => panic!("transport creation failed: {err}"),
        }
    }

    #[test]
    fn test_transport_requires_address() {
        let (_dir, handler) = setup_test_components();

        let config = ServerConfig::new("/tmp/test").without_tcp();

        assert!(create_transport(&config, handler).is_err());
    }

    #[test]
    fn test_process_ping_message() {
        let (_dir, handler) = setup_test_components();
        let worker = TransportWorker::new(handler, MAX_MESSAGE_SIZE);

        let (response_bytes, is_success) = worker.process_message_with_status(&framed_ping(42));
        assert!(is_success);

        let response = decode_response(&response_bytes);
        assert_eq!(response.id, 42);
        assert!(response.status.is_ok());
        assert!(matches!(
            response.payload,
            ormdb_proto::ResponsePayload::Pong
        ));
    }

    #[test]
    fn test_process_invalid_message() {
        let (_dir, handler) = setup_test_components();
        let worker = TransportWorker::new(handler, MAX_MESSAGE_SIZE);

        // Garbage input must still yield an (error) response.
        let (response_bytes, is_success) = worker.process_message_with_status(b"invalid data");

        assert!(!response_bytes.is_empty());
        assert!(!is_success);
    }

    #[test]
    fn test_process_messages_concurrently() {
        let (_dir, handler) = setup_test_components();

        // Collect every handle before joining so all threads run at the same time.
        let handles: Vec<_> = (0..8u64)
            .map(|i| {
                let handler = handler.clone();
                std::thread::spawn(move || {
                    let worker = TransportWorker::new(handler, MAX_MESSAGE_SIZE);
                    let request_id = 100 + i;

                    let (response_bytes, is_success) =
                        worker.process_message_with_status(&framed_ping(request_id));
                    assert!(is_success);

                    let response = decode_response(&response_bytes);
                    assert_eq!(response.id, request_id);
                    assert!(response.status.is_ok());
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }
    }
}