1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6use std::sync::Arc;
7use std::thread;
8use std::time::{Duration, Instant};
9
10use async_nng::AsyncContext;
11use nng::options::Options;
12use nng::{Message, Protocol, Socket};
13
14use ormdb_proto::framing::encode_frame;
15use ormdb_proto::{Request, Response};
16
17use crate::config::ServerConfig;
18use crate::error::Error;
19use crate::handler::RequestHandler;
20
/// Counters describing the traffic handled by the transport.
///
/// Every counter is updated and read with `Ordering::Relaxed`, and each field
/// is read independently, so a snapshot taken while workers are running is not
/// guaranteed to be consistent across fields.
#[derive(Debug)]
pub struct TransportMetrics {
    /// Number of requests counted, successful and failed combined.
    pub requests_total: AtomicU64,
    /// Requests counted as successful.
    pub requests_success: AtomicU64,
    /// Requests counted as failed.
    pub requests_failed: AtomicU64,
    /// Sum of the received byte counts passed to the `record_*` methods.
    pub bytes_received: AtomicU64,
    /// Sum of the sent byte counts passed to the `record_*` methods.
    pub bytes_sent: AtomicU64,
    /// Creation time of this metrics set; `uptime` is measured from here.
    pub started_at: Instant,
}

impl TransportMetrics {
    /// Builds a zeroed set of counters whose uptime clock starts now.
    fn new() -> Self {
        let zeroed = || AtomicU64::new(0);
        Self {
            requests_total: zeroed(),
            requests_success: zeroed(),
            requests_failed: zeroed(),
            bytes_received: zeroed(),
            bytes_sent: zeroed(),
            started_at: Instant::now(),
        }
    }

    /// Counts one request. It bumps the total, bumps the given `outcome`
    /// counter, and adds the byte totals.
    fn record(&self, outcome: &AtomicU64, received_bytes: usize, sent_bytes: usize) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
        outcome.fetch_add(1, Ordering::Relaxed);
        self.bytes_received
            .fetch_add(received_bytes as u64, Ordering::Relaxed);
        self.bytes_sent.fetch_add(sent_bytes as u64, Ordering::Relaxed);
    }

    /// Records a request that completed successfully.
    fn record_success(&self, received_bytes: usize, sent_bytes: usize) {
        self.record(&self.requests_success, received_bytes, sent_bytes);
    }

    /// Records a request that failed.
    fn record_failure(&self, received_bytes: usize, sent_bytes: usize) {
        self.record(&self.requests_failed, received_bytes, sent_bytes);
    }

    /// Reads a single counter.
    fn read(counter: &AtomicU64) -> u64 {
        counter.load(Ordering::Relaxed)
    }

    /// Time elapsed since these metrics were created.
    pub fn uptime(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// Number of requests recorded so far (successes plus failures).
    pub fn total_requests(&self) -> u64 {
        Self::read(&self.requests_total)
    }

    /// Number of requests recorded as successful.
    pub fn successful_requests(&self) -> u64 {
        Self::read(&self.requests_success)
    }

    /// Number of requests recorded as failed.
    pub fn failed_requests(&self) -> u64 {
        Self::read(&self.requests_failed)
    }

    /// Total bytes received across all recorded requests.
    pub fn total_bytes_received(&self) -> u64 {
        Self::read(&self.bytes_received)
    }

    /// Total bytes sent across all recorded requests.
    pub fn total_bytes_sent(&self) -> u64 {
        Self::read(&self.bytes_sent)
    }
}

impl Default for TransportMetrics {
    /// Equivalent to a freshly created, zeroed metrics set.
    fn default() -> Self {
        Self::new()
    }
}
103
/// Request/reply transport built on an nng REP socket.
///
/// `worker_count` OS threads share `socket`; each one decodes incoming
/// requests and dispatches them to `handler`.
pub struct Transport {
    /// REP socket that listens on the configured TCP and/or IPC addresses.
    socket: Socket,
    /// Executes decoded requests.
    handler: Arc<RequestHandler>,
    /// Largest request, in bytes, that workers will attempt to decode.
    max_message_size: usize,
    /// Counters shared with the worker threads.
    metrics: Arc<TransportMetrics>,
    /// Processing time above which a request is logged as a warning. It is
    /// only used for logging and does not abort the request.
    request_timeout: Duration,
    /// Number of worker threads to spawn (clamped to at least 1 in `new`).
    worker_count: usize,
}
113
114impl Transport {
115 pub fn new(config: &ServerConfig, handler: Arc<RequestHandler>) -> Result<Self, Error> {
117 let socket = Socket::new(Protocol::Rep0)
119 .map_err(|e| Error::Transport(format!("failed to create socket: {}", e)))?;
120
121 socket
123 .set_opt::<nng::options::RecvMaxSize>(config.max_message_size)
124 .map_err(|e| Error::Transport(format!("failed to set max message size: {}", e)))?;
125
126 if let Some(tcp_addr) = &config.tcp_address {
128 socket
129 .listen(tcp_addr)
130 .map_err(|e| Error::Transport(format!("failed to listen on {}: {}", tcp_addr, e)))?;
131
132 tracing::info!(address = %tcp_addr, "listening on TCP");
133 }
134
135 if let Some(ipc_addr) = &config.ipc_address {
137 socket
138 .listen(ipc_addr)
139 .map_err(|e| Error::Transport(format!("failed to listen on {}: {}", ipc_addr, e)))?;
140
141 tracing::info!(address = %ipc_addr, "listening on IPC");
142 }
143
144 Ok(Self {
145 socket,
146 handler,
147 max_message_size: config.max_message_size,
148 metrics: Arc::new(TransportMetrics::new()),
149 request_timeout: config.request_timeout,
150 worker_count: config.transport_workers.max(1),
151 })
152 }
153
154 pub fn metrics(&self) -> &TransportMetrics {
156 &self.metrics
157 }
158
159 pub async fn run(&self) -> Result<(), Error> {
161 let stop_flag = Arc::new(AtomicBool::new(false));
162 let _handles = self.spawn_worker_threads(stop_flag)?;
163
164 tracing::info!("transport ready, accepting requests");
165 std::future::pending::<()>().await;
166 Ok(())
167 }
168
169 pub async fn run_until_shutdown(
171 &self,
172 mut shutdown: tokio::sync::broadcast::Receiver<()>,
173 ) -> Result<(), Error> {
174 let stop_flag = Arc::new(AtomicBool::new(false));
175 let handles = self.spawn_worker_threads(stop_flag.clone())?;
176
177 tracing::info!("transport ready, accepting requests");
178
179 let _ = shutdown.recv().await;
180 tracing::info!(
181 total_requests = self.metrics.total_requests(),
182 successful = self.metrics.successful_requests(),
183 failed = self.metrics.failed_requests(),
184 bytes_received = self.metrics.total_bytes_received(),
185 bytes_sent = self.metrics.total_bytes_sent(),
186 uptime_secs = self.metrics.uptime().as_secs(),
187 "shutdown signal received, stopping transport"
188 );
189
190 stop_flag.store(true, Ordering::SeqCst);
191 let _ = tokio::task::spawn_blocking(move || {
192 for handle in handles {
193 let _ = handle.join();
194 }
195 })
196 .await;
197
198 Ok(())
199 }
200
201 fn process_message(&self, data: &[u8]) -> Vec<u8> {
203 self.worker().process_message_with_status(data).0
204 }
205
206 fn process_message_with_status(&self, data: &[u8]) -> (Vec<u8>, bool) {
208 self.worker().process_message_with_status(data)
209 }
210
211 fn worker(&self) -> TransportWorker {
212 TransportWorker::new(self.handler.clone(), self.max_message_size)
213 }
214
215 fn spawn_worker_threads(
216 &self,
217 stop_flag: Arc<AtomicBool>,
218 ) -> Result<Vec<thread::JoinHandle<()>>, Error> {
219 let mut handles = Vec::with_capacity(self.worker_count);
220 for worker_id in 0..self.worker_count {
221 let socket = self.socket.clone();
222 let worker = self.worker();
223 let metrics = self.metrics.clone();
224 let request_timeout = self.request_timeout;
225 let stop_flag = stop_flag.clone();
226
227 let handle = thread::Builder::new()
228 .name(format!("ormdb-transport-{}", worker_id))
229 .spawn(move || {
230 let runtime = tokio::runtime::Builder::new_current_thread()
231 .enable_all()
232 .build()
233 .expect("failed to build transport worker runtime");
234
235 runtime.block_on(async move {
236 let mut ctx = match AsyncContext::try_from(&socket) {
237 Ok(ctx) => ctx,
238 Err(e) => {
239 tracing::error!(error = %e, worker_id, "failed to create async context");
240 return;
241 }
242 };
243
244 loop {
245 if stop_flag.load(Ordering::SeqCst) {
246 tracing::info!(worker_id, "transport worker stopping");
247 return;
248 }
249
250 match ctx.receive(Some(Duration::from_secs(1))).await {
251 Ok(msg) => {
252 let received_bytes = msg.len();
253 let start = Instant::now();
254 let (response_bytes, is_success) =
255 worker.process_message_with_status(msg.as_slice());
256 let elapsed = start.elapsed();
257 let sent_bytes = response_bytes.len();
258
259 let response_msg = Message::from(response_bytes.as_slice());
260
261 if let Err((_, e)) = ctx.send(response_msg, None).await {
262 tracing::error!(error = %e, worker_id, "failed to send response");
263 metrics.record_failure(received_bytes, 0);
264 } else if is_success {
265 metrics.record_success(received_bytes, sent_bytes);
266 } else {
267 metrics.record_failure(received_bytes, sent_bytes);
268 }
269
270 if elapsed > request_timeout {
271 tracing::warn!(
272 worker_id,
273 duration_ms = elapsed.as_millis() as u64,
274 timeout_ms = request_timeout.as_millis() as u64,
275 "request exceeded timeout"
276 );
277 }
278 }
279 Err(nng::Error::TimedOut) => {
280 continue;
281 }
282 Err(e) => {
283 tracing::error!(error = %e, worker_id, "receive error");
284 }
285 }
286 }
287 });
288 })
289 .map_err(|e| Error::Transport(format!("failed to spawn transport worker: {}", e)))?;
290
291 handles.push(handle);
292 }
293
294 Ok(handles)
295 }
296}
297
/// Per-thread request processor: decodes a framed request, runs it through the
/// shared handler, and encodes the framed reply.
struct TransportWorker {
    /// Executes decoded requests; shared by all workers.
    handler: Arc<RequestHandler>,
    /// Requests longer than this many bytes are rejected before decoding.
    max_message_size: usize,
}
302
303impl TransportWorker {
304 fn new(handler: Arc<RequestHandler>, max_message_size: usize) -> Self {
305 Self {
306 handler,
307 max_message_size,
308 }
309 }
310
311 fn process_message_with_status(&self, data: &[u8]) -> (Vec<u8>, bool) {
313 let (response, is_success) = match self.decode_and_handle(data) {
315 Ok(response) => {
316 let is_ok = response.status.is_ok();
317 (response, is_ok)
318 }
319 Err(e) => {
320 tracing::error!(error = %e, "request processing error");
321 let response = Response::error(0, ormdb_proto::error_codes::INTERNAL, e.to_string());
323 (response, false)
324 }
325 };
326
327 let bytes = match self.encode_response(&response) {
329 Ok(bytes) => bytes,
330 Err(e) => {
331 tracing::error!(error = %e, "failed to encode response");
332 self.encode_minimal_error(&e.to_string())
334 }
335 };
336
337 (bytes, is_success)
338 }
339
340 fn decode_and_handle(&self, data: &[u8]) -> Result<Response, Error> {
342 if data.len() > self.max_message_size {
344 return Err(Error::Protocol(ormdb_proto::Error::InvalidMessage(format!(
345 "message too large: {} bytes (max: {})",
346 data.len(),
347 self.max_message_size
348 ))));
349 }
350
351 let payload = ormdb_proto::framing::extract_payload(data)?;
353
354 let mut aligned: rkyv::util::AlignedVec<16> = rkyv::util::AlignedVec::new();
356 aligned.extend_from_slice(payload);
357
358 let request: Request =
360 rkyv::from_bytes::<Request, rkyv::rancor::Error>(&aligned).map_err(|e| {
361 Error::Protocol(ormdb_proto::Error::InvalidMessage(format!(
362 "failed to deserialize request: {}",
363 e
364 )))
365 })?;
366
367 Ok(self.handler.handle(&request))
369 }
370
371 fn encode_response(&self, response: &Response) -> Result<Vec<u8>, Error> {
373 let payload = rkyv::to_bytes::<rkyv::rancor::Error>(response).map_err(|e| {
374 Error::Protocol(ormdb_proto::Error::Serialization(format!(
375 "failed to serialize response: {}",
376 e
377 )))
378 })?;
379
380 encode_frame(&payload).map_err(|e| Error::Protocol(e))
381 }
382
383 fn encode_minimal_error(&self, message: &str) -> Vec<u8> {
385 let response = Response::error(0, ormdb_proto::error_codes::INTERNAL, message);
386
387 match rkyv::to_bytes::<rkyv::rancor::Error>(&response) {
389 Ok(payload) => match encode_frame(&payload) {
390 Ok(framed) => framed,
391 Err(_) => Vec::new(),
392 },
393 Err(_) => Vec::new(),
394 }
395 }
396}
397
398
399pub fn create_transport(
401 config: &ServerConfig,
402 handler: Arc<RequestHandler>,
403) -> Result<Transport, Error> {
404 if !config.has_transport() {
405 return Err(Error::Config(
406 "no transport configured (need TCP or IPC address)".to_string(),
407 ));
408 }
409
410 Transport::new(config, handler)
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use crate::database::Database;
    use ormdb_core::catalog::{EntityDef, FieldDef, FieldType, ScalarType, SchemaBundle};
    use ormdb_proto::framing::MAX_MESSAGE_SIZE;

    /// Opens a database in a fresh temporary directory, applies a minimal
    /// `User` schema, and wraps the database in a request handler.
    ///
    /// The `TempDir` is returned so the caller keeps the directory alive while
    /// the handler is in use.
    fn setup_test_components() -> (tempfile::TempDir, Arc<RequestHandler>) {
        let tmp = tempfile::tempdir().unwrap();
        let database = Database::open(tmp.path()).unwrap();

        let user = EntityDef::new("User", "id")
            .with_field(FieldDef::new("id", FieldType::Scalar(ScalarType::Uuid)))
            .with_field(FieldDef::new("name", FieldType::Scalar(ScalarType::String)));
        database
            .catalog()
            .apply_schema(SchemaBundle::new(1).with_entity(user))
            .unwrap();

        (tmp, Arc::new(RequestHandler::new(Arc::new(database))))
    }

    /// Serializes a `Ping` request with the given id and wraps it in a frame.
    fn framed_ping(id: u64) -> Vec<u8> {
        let payload = rkyv::to_bytes::<rkyv::rancor::Error>(&Request::ping(id)).unwrap();
        encode_frame(&payload).unwrap()
    }

    /// Strips the frame from a reply and deserializes the `Response` inside it.
    fn decode_response(bytes: &[u8]) -> Response {
        let payload = ormdb_proto::framing::extract_payload(bytes).unwrap();
        let mut aligned: rkyv::util::AlignedVec<16> = rkyv::util::AlignedVec::new();
        aligned.extend_from_slice(payload);
        rkyv::from_bytes::<Response, rkyv::rancor::Error>(&aligned).unwrap()
    }

    #[test]
    fn test_transport_creation() {
        let (dir, handler) = setup_test_components();

        let socket_path = dir.path().join("ormdb.sock");
        let config = ServerConfig::new(dir.path())
            .without_tcp()
            .with_ipc_address(format!("ipc://{}", socket_path.display()))
            .with_max_message_size(MAX_MESSAGE_SIZE);

        match Transport::new(&config, handler) {
            Ok(_) => {}
            // Sandboxed environments may forbid creating IPC sockets; treat that as a skip.
            Err(Error::Transport(msg)) if msg.contains("Permission denied") => {}
            Err(err) => panic!("transport creation failed: {err}"),
        }
    }

    #[test]
    fn test_transport_requires_address() {
        let (_dir, handler) = setup_test_components();

        let config = ServerConfig::new("/tmp/test").without_tcp();

        assert!(create_transport(&config, handler).is_err());
    }

    #[test]
    fn test_process_ping_message() {
        let (_dir, handler) = setup_test_components();
        let worker = TransportWorker::new(handler, MAX_MESSAGE_SIZE);

        let (reply, ok) = worker.process_message_with_status(&framed_ping(42));
        assert!(ok);

        let response = decode_response(&reply);
        assert_eq!(response.id, 42);
        assert!(response.status.is_ok());
        assert!(matches!(
            response.payload,
            ormdb_proto::ResponsePayload::Pong
        ));
    }

    #[test]
    fn test_process_invalid_message() {
        let (_dir, handler) = setup_test_components();
        let worker = TransportWorker::new(handler, MAX_MESSAGE_SIZE);

        let (reply, ok) = worker.process_message_with_status(b"invalid data");

        assert!(!reply.is_empty());
        assert!(!ok);
    }

    #[test]
    fn test_process_messages_concurrently() {
        let (_dir, handler) = setup_test_components();

        // Eight threads, each with its own worker and a distinct request id.
        let handles: Vec<_> = (100u64..108)
            .map(|request_id| {
                let handler = Arc::clone(&handler);
                std::thread::spawn(move || {
                    let worker = TransportWorker::new(handler, MAX_MESSAGE_SIZE);

                    let (reply, ok) = worker.process_message_with_status(&framed_ping(request_id));
                    assert!(ok);

                    let response = decode_response(&reply);
                    assert_eq!(response.id, request_id);
                    assert!(response.status.is_ok());
                })
            })
            .collect();

        handles
            .into_iter()
            .for_each(|handle| handle.join().unwrap());
    }
}