Skip to main content

stdiobus_backend_native/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2026-present Raman Marozau <raman@worktif.com>
3// Copyright (c) 2026-present stdiobus contributors
4
5#![cfg_attr(docsrs, feature(doc_cfg))]
6
7//! Native FFI backend for stdio_bus
8//!
9//! This backend links directly to libstdio_bus.a and provides
10//! the highest performance option for Unix systems.
11
12use async_trait::async_trait;
13use std::ffi::{CStr, CString};
14use std::os::raw::{c_char, c_int, c_void};
15use std::ptr;
16use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, AtomicUsize, Ordering};
17use std::sync::Arc;
18use stdiobus_core::{Backend, BusMessage, BusState, BusStats, ConfigSource, Error, Result};
19use stdiobus_ffi::*;
20use tokio::sync::{mpsc, Mutex};
21
22/// Thread-safe wrapper for bus pointer
23struct BusPtr(AtomicUsize);
24
25impl BusPtr {
26    fn new() -> Self {
27        Self(AtomicUsize::new(0))
28    }
29
30    fn set(&self, ptr: *mut stdio_bus_t) {
31        self.0.store(ptr as usize, Ordering::SeqCst);
32    }
33
34    fn get(&self) -> Option<*mut stdio_bus_t> {
35        let ptr = self.0.load(Ordering::SeqCst);
36        if ptr == 0 {
37            None
38        } else {
39            Some(ptr as *mut stdio_bus_t)
40        }
41    }
42
43    fn take(&self) -> Option<*mut stdio_bus_t> {
44        let ptr = self.0.swap(0, Ordering::SeqCst);
45        if ptr == 0 {
46            None
47        } else {
48            Some(ptr as *mut stdio_bus_t)
49        }
50    }
51}
52
53unsafe impl Send for BusPtr {}
54unsafe impl Sync for BusPtr {}
55
56/// Wrapper for raw callback context pointer to allow storage in async Mutex.
57/// Safety: The pointer is only accessed under Mutex guard and follows
58/// strict lifecycle rules (see CallbackContext docs).
59struct CtxPtr(*mut CallbackContext);
60unsafe impl Send for CtxPtr {}
61unsafe impl Sync for CtxPtr {}
62
63fn state_to_u8(s: BusState) -> u8 {
64    match s {
65        BusState::Created => 0,
66        BusState::Starting => 1,
67        BusState::Running => 2,
68        BusState::Stopping => 3,
69        BusState::Stopped => 4,
70    }
71}
72
73fn u8_to_state(v: u8) -> BusState {
74    match v {
75        0 => BusState::Created,
76        1 => BusState::Starting,
77        2 => BusState::Running,
78        3 => BusState::Stopping,
79        4 => BusState::Stopped,
80        _ => BusState::Created,
81    }
82}
83
84/// Context passed to C callbacks via user_data.
85///
86/// Safety: This context is shared with C callbacks via raw pointer.
87/// The `alive` flag MUST be set to `false` before `stdio_bus_stop` is called,
88/// and the context MUST NOT be freed until after `stdio_bus_destroy` completes.
89struct CallbackContext {
90    /// Set to false during shutdown to prevent callbacks from accessing Rust state
91    alive: AtomicBool,
92    message_tx: mpsc::Sender<BusMessage>,
93    stats: Arc<Stats>,
94}
95
96/// Native backend using FFI to libstdio_bus
97pub struct NativeBackend {
98    bus: Arc<BusPtr>,
99    config: InternalConfig,
100    state: Arc<AtomicU8>,
101    message_tx: mpsc::Sender<BusMessage>,
102    message_rx: Mutex<Option<mpsc::Receiver<BusMessage>>>,
103    stats: Arc<Stats>,
104    running: Arc<AtomicBool>,
105    /// Owned callback context — freed only after C library is fully destroyed
106    callback_ctx: Mutex<Option<CtxPtr>>,
107}
108
/// Configuration source in the form handed to the C library.
///
/// Exactly one of the two variants is forwarded to `stdio_bus_create`: `Path` as
/// `config_path`, `Json` as `config_json`. `Clone` is derived because `start`
/// needs an owned copy to move into its blocking task.
#[derive(Debug, Clone, PartialEq, Eq)]
enum InternalConfig {
    /// Path to a configuration file on disk.
    Path(String),
    /// Inline configuration, already serialized to JSON.
    Json(String),
}
114
/// Traffic and health counters shared between the backend and its C callbacks.
///
/// Every field is an atomic that this file only ever increments, so the counters
/// can be updated from callback threads without locking. `Default` yields all
/// zeros.
#[derive(Debug, Default)]
struct Stats {
    /// Messages successfully passed to `stdio_bus_ingest` by `send`.
    messages_in: AtomicU64,
    /// Valid UTF-8 messages delivered by the C message callback.
    messages_out: AtomicU64,
    /// Bytes successfully passed to `stdio_bus_ingest` by `send`.
    bytes_in: AtomicU64,
    /// Bytes of valid UTF-8 messages delivered by the C message callback.
    bytes_out: AtomicU64,
    /// NOTE(review): never incremented in this file — confirm who should update it.
    worker_restarts: AtomicU64,
    /// NOTE(review): never incremented in this file — confirm who should update it.
    routing_errors: AtomicU64,
}
123
124impl NativeBackend {
125    /// Create from a file path (legacy)
126    pub fn new(config_path: &str) -> Result<Self> {
127        Self::create(InternalConfig::Path(config_path.to_string()))
128    }
129
130    /// Create from a ConfigSource (primary)
131    pub fn from_config_source(source: &ConfigSource) -> Result<Self> {
132        let internal = match source {
133            ConfigSource::Path(p) => InternalConfig::Path(p.clone()),
134            ConfigSource::Config(cfg) => {
135                let json = cfg.to_json().map_err(|e| Error::InvalidArgument {
136                    message: format!("Failed to serialize config: {}", e),
137                })?;
138                InternalConfig::Json(json)
139            }
140        };
141        Self::create(internal)
142    }
143
144    fn create(config: InternalConfig) -> Result<Self> {
145        let (tx, rx) = mpsc::channel(1000);
146
147        Ok(Self {
148            bus: Arc::new(BusPtr::new()),
149            config,
150            state: Arc::new(AtomicU8::new(0)),
151            message_tx: tx,
152            message_rx: Mutex::new(Some(rx)),
153            stats: Arc::new(Stats {
154                messages_in: AtomicU64::new(0),
155                messages_out: AtomicU64::new(0),
156                bytes_in: AtomicU64::new(0),
157                bytes_out: AtomicU64::new(0),
158                worker_restarts: AtomicU64::new(0),
159                routing_errors: AtomicU64::new(0),
160            }),
161            running: Arc::new(AtomicBool::new(false)),
162            callback_ctx: Mutex::new(None),
163        })
164    }
165
166    fn get_state(&self) -> BusState {
167        u8_to_state(self.state.load(Ordering::SeqCst))
168    }
169
170    fn set_state(&self, state: BusState) {
171        self.state.store(state_to_u8(state), Ordering::SeqCst);
172    }
173}
174
175impl Drop for NativeBackend {
176    fn drop(&mut self) {
177        self.running.store(false, Ordering::SeqCst);
178
179        // Phase 1: Signal callbacks to stop accessing Rust state
180        if let Ok(guard) = self.callback_ctx.try_lock() {
181            if let Some(ref wrapper) = *guard {
182                unsafe { (*wrapper.0).alive.store(false, Ordering::SeqCst) };
183            }
184        }
185
186        // Phase 2: Stop and destroy the C bus (no more callbacks after this)
187        if let Some(bus) = self.bus.take() {
188            unsafe {
189                stdio_bus_stop(bus, 1);
190                stdio_bus_destroy(bus);
191            }
192        }
193
194        // Phase 3: Now safe to free the callback context
195        if let Ok(mut guard) = self.callback_ctx.try_lock() {
196            if let Some(wrapper) = guard.take() {
197                unsafe { drop(Box::from_raw(wrapper.0)) };
198            }
199        }
200    }
201}
202
203
204#[async_trait]
205impl Backend for NativeBackend {
206    async fn start(&self) -> Result<()> {
207        let current_state = self.get_state();
208        if !current_state.can_start() {
209            return Err(Error::InvalidState {
210                expected: "CREATED or STOPPED".to_string(),
211                actual: current_state.to_string(),
212            });
213        }
214
215        self.set_state(BusState::Starting);
216
217        // Create callback context with alive flag for safe teardown
218        let ctx = Box::new(CallbackContext {
219            alive: AtomicBool::new(true),
220            message_tx: self.message_tx.clone(),
221            stats: self.stats.clone(),
222        });
223        let ctx_ptr = Box::into_raw(ctx);
224        let ctx_usize = ctx_ptr as usize;
225
226        // Store the pointer so we can free it on stop/drop
227        *self.callback_ctx.lock().await = Some(CtxPtr(ctx_ptr));
228
229        // Clone config for the blocking task
230        let config = match &self.config {
231            InternalConfig::Path(p) => InternalConfig::Path(p.clone()),
232            InternalConfig::Json(j) => InternalConfig::Json(j.clone()),
233        };
234        
235        let bus = tokio::task::spawn_blocking(move || {
236            // Prepare C strings based on config source
237            let (path_ptr, json_ptr, _path_cstr, _json_cstr) = match &config {
238                InternalConfig::Path(p) => {
239                    let cstr = CString::new(p.as_str()).map_err(|_| Error::InvalidArgument {
240                        message: "Invalid config path".to_string(),
241                    })?;
242                    let ptr = cstr.as_ptr();
243                    (ptr, ptr::null(), Some(cstr), None)
244                }
245                InternalConfig::Json(j) => {
246                    let cstr = CString::new(j.as_str()).map_err(|_| Error::InvalidArgument {
247                        message: "Invalid config JSON (contains null byte)".to_string(),
248                    })?;
249                    let ptr = cstr.as_ptr();
250                    (ptr::null(), ptr, None, Some(cstr))
251                }
252            };
253
254            let listener = stdio_bus_listener_config_t {
255                mode: stdio_bus_listen_mode_t::STDIO_BUS_LISTEN_NONE,
256                tcp_host: ptr::null(),
257                tcp_port: 0,
258                unix_path: ptr::null(),
259            };
260
261            let options = stdio_bus_options_t {
262                config_path: path_ptr,
263                config_json: json_ptr,
264                listener,
265                on_message: Some(on_message_callback),
266                on_error: Some(on_error_callback),
267                on_log: Some(on_log_callback),
268                on_worker: None,
269                on_client_connect: None,
270                on_client_disconnect: None,
271                user_data: ctx_usize as *mut c_void,
272                log_level: 1,
273            };
274
275            let bus = unsafe { stdio_bus_create(&options) };
276            if bus.is_null() {
277                return Err(Error::InternalError {
278                    message: "Failed to create bus".to_string(),
279                });
280            }
281
282            let result = unsafe { stdio_bus_start(bus) };
283            if result != STDIO_BUS_OK {
284                unsafe { stdio_bus_destroy(bus) };
285                return Err(Error::InternalError {
286                    message: format!("Failed to start bus: error code {}", result),
287                });
288            }
289
290            Ok(bus as usize)
291        })
292        .await
293        .map_err(|e| Error::InternalError {
294            message: format!("Task join error: {}", e),
295        })??;
296
297        self.bus.set(bus as *mut stdio_bus_t);
298        self.set_state(BusState::Running);
299        self.running.store(true, Ordering::SeqCst);
300
301        // Start polling task
302        let bus_ptr = self.bus.clone();
303        let running = self.running.clone();
304
305        tokio::spawn(async move {
306            while running.load(Ordering::SeqCst) {
307                if let Some(bus) = bus_ptr.get() {
308                    let bus_usize = bus as usize;
309                    let _ = tokio::task::spawn_blocking(move || {
310                        unsafe { stdio_bus_step(bus_usize as *mut stdio_bus_t, 10) };
311                    })
312                    .await;
313                }
314                tokio::time::sleep(std::time::Duration::from_millis(1)).await;
315            }
316        });
317
318        Ok(())
319    }
320
321    async fn stop(&self, timeout_secs: u32) -> Result<()> {
322        self.running.store(false, Ordering::SeqCst);
323        self.set_state(BusState::Stopping);
324
325        // Phase 1: Signal callbacks to stop accessing Rust state
326        {
327            let guard = self.callback_ctx.lock().await;
328            if let Some(ref wrapper) = *guard {
329                unsafe { (*wrapper.0).alive.store(false, Ordering::SeqCst) };
330            }
331        }
332
333        // Phase 2: Stop and destroy the C bus (no more callbacks after this)
334        if let Some(bus) = self.bus.take() {
335            let bus_usize = bus as usize;
336            let timeout = timeout_secs as c_int;
337            
338            tokio::task::spawn_blocking(move || {
339                unsafe {
340                    stdio_bus_stop(bus_usize as *mut stdio_bus_t, timeout);
341                    stdio_bus_destroy(bus_usize as *mut stdio_bus_t);
342                }
343            })
344            .await
345            .map_err(|e| Error::InternalError {
346                message: format!("Task join error: {}", e),
347            })?;
348        }
349
350        // Phase 3: Now safe to free the callback context
351        {
352            let mut guard = self.callback_ctx.lock().await;
353            if let Some(wrapper) = guard.take() {
354                unsafe { drop(Box::from_raw(wrapper.0)) };
355            }
356        }
357
358        self.set_state(BusState::Stopped);
359        Ok(())
360    }
361
362    async fn send(&self, message: &str) -> Result<()> {
363        let bus = self.bus.get().ok_or_else(|| Error::InvalidState {
364            expected: "RUNNING".to_string(),
365            actual: "not initialized".to_string(),
366        })?;
367
368        let bus_usize = bus as usize;
369        let msg = message.to_string();
370        let msg_len = msg.len();
371
372        let result = tokio::task::spawn_blocking(move || {
373            unsafe {
374                stdio_bus_ingest(
375                    bus_usize as *mut stdio_bus_t,
376                    msg.as_ptr() as *const c_char,
377                    msg_len,
378                )
379            }
380        })
381        .await
382        .map_err(|e| Error::InternalError {
383            message: format!("Task join error: {}", e),
384        })?;
385
386        if result != STDIO_BUS_OK {
387            return Err(Error::TransportError {
388                message: format!("Failed to send message: error code {}", result),
389            });
390        }
391
392        self.stats.messages_in.fetch_add(1, Ordering::Relaxed);
393        self.stats.bytes_in.fetch_add(msg_len as u64, Ordering::Relaxed);
394
395        Ok(())
396    }
397
398    fn state(&self) -> BusState {
399        self.get_state()
400    }
401
402    fn stats(&self) -> BusStats {
403        BusStats {
404            messages_in: self.stats.messages_in.load(Ordering::Relaxed),
405            messages_out: self.stats.messages_out.load(Ordering::Relaxed),
406            bytes_in: self.stats.bytes_in.load(Ordering::Relaxed),
407            bytes_out: self.stats.bytes_out.load(Ordering::Relaxed),
408            worker_restarts: self.stats.worker_restarts.load(Ordering::Relaxed),
409            routing_errors: self.stats.routing_errors.load(Ordering::Relaxed),
410            ..Default::default()
411        }
412    }
413
414    fn worker_count(&self) -> i32 {
415        self.bus
416            .get()
417            .map(|bus| unsafe { stdio_bus_worker_count(bus) })
418            .unwrap_or(-1)
419    }
420
421    fn client_count(&self) -> i32 {
422        self.bus
423            .get()
424            .map(|bus| unsafe { stdio_bus_client_count(bus) })
425            .unwrap_or(0)
426    }
427
428    fn subscribe(&self) -> Option<mpsc::Receiver<BusMessage>> {
429        self.message_rx.try_lock().ok().and_then(|mut rx| rx.take())
430    }
431
432    fn backend_type(&self) -> &'static str {
433        "native"
434    }
435}
436
437
438extern "C" fn on_message_callback(
439    _bus: *mut stdio_bus_t,
440    msg: *const c_char,
441    len: usize,
442    user_data: *mut c_void,
443) {
444    // Guard: catch any panic to prevent unwinding across FFI boundary
445    let _ = std::panic::catch_unwind(|| {
446        if user_data.is_null() {
447            return;
448        }
449
450        let ctx = unsafe { &*(user_data as *const CallbackContext) };
451
452        // Check alive flag — if shutting down, do not touch Rust state
453        if !ctx.alive.load(Ordering::SeqCst) {
454            return;
455        }
456        
457        let slice = unsafe { std::slice::from_raw_parts(msg as *const u8, len) };
458        if let Ok(json) = std::str::from_utf8(slice) {
459            ctx.stats.messages_out.fetch_add(1, Ordering::Relaxed);
460            ctx.stats.bytes_out.fetch_add(len as u64, Ordering::Relaxed);
461            
462            let message = BusMessage { json: json.to_string() };
463            if let Err(e) = ctx.message_tx.try_send(message) {
464                tracing::warn!("Message channel full: {}", e);
465            }
466        }
467    });
468}
469
470extern "C" fn on_error_callback(
471    _bus: *mut stdio_bus_t,
472    code: c_int,
473    msg: *const c_char,
474    user_data: *mut c_void,
475) {
476    let _ = std::panic::catch_unwind(|| {
477        if !user_data.is_null() {
478            let ctx = unsafe { &*(user_data as *const CallbackContext) };
479            if !ctx.alive.load(Ordering::SeqCst) {
480                return;
481            }
482        }
483        let msg = unsafe { CStr::from_ptr(msg) };
484        tracing::error!("Bus error {}: {:?}", code, msg);
485    });
486}
487
488extern "C" fn on_log_callback(
489    _bus: *mut stdio_bus_t,
490    level: c_int,
491    msg: *const c_char,
492    user_data: *mut c_void,
493) {
494    let _ = std::panic::catch_unwind(|| {
495        if !user_data.is_null() {
496            let ctx = unsafe { &*(user_data as *const CallbackContext) };
497            if !ctx.alive.load(Ordering::SeqCst) {
498                return;
499            }
500        }
501        let msg = unsafe { CStr::from_ptr(msg) };
502        match level {
503            0 => tracing::debug!("{:?}", msg),
504            1 => tracing::info!("{:?}", msg),
505            2 => tracing::warn!("{:?}", msg),
506            _ => tracing::error!("{:?}", msg),
507        }
508    });
509}
510
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Config path shared by the tests. None of them successfully starts the bus,
    /// so the file never has to exist.
    const TEST_CONFIG: &str = "./test-config.json";

    #[test]
    fn test_native_backend_new() {
        let result = NativeBackend::new(TEST_CONFIG);
        assert!(result.is_ok());
    }

    #[test]
    fn test_native_backend_from_config_source_path() {
        let source = ConfigSource::Path(TEST_CONFIG.to_string());
        let backend = NativeBackend::from_config_source(&source).unwrap();
        assert_eq!(backend.state(), BusState::Created);
        assert_eq!(backend.backend_type(), "native");
    }

    #[test]
    fn test_native_backend_initial_state() {
        let backend = NativeBackend::new(TEST_CONFIG).unwrap();
        assert_eq!(backend.state(), BusState::Created);
    }

    #[test]
    fn test_native_backend_stats_initial() {
        let backend = NativeBackend::new(TEST_CONFIG).unwrap();
        let stats = backend.stats();

        assert_eq!(stats.messages_in, 0);
        assert_eq!(stats.messages_out, 0);
        assert_eq!(stats.bytes_in, 0);
        assert_eq!(stats.bytes_out, 0);
        assert_eq!(stats.worker_restarts, 0);
        assert_eq!(stats.routing_errors, 0);
    }

    #[test]
    fn test_native_backend_type() {
        let backend = NativeBackend::new(TEST_CONFIG).unwrap();
        assert_eq!(backend.backend_type(), "native");
    }

    #[test]
    fn test_native_backend_worker_count_not_started() {
        let backend = NativeBackend::new(TEST_CONFIG).unwrap();
        assert_eq!(backend.worker_count(), -1);
    }

    #[test]
    fn test_native_backend_client_count_not_started() {
        let backend = NativeBackend::new(TEST_CONFIG).unwrap();
        assert_eq!(backend.client_count(), 0);
    }

    #[test]
    fn test_native_backend_subscribe() {
        let backend = NativeBackend::new(TEST_CONFIG).unwrap();

        // First subscribe should succeed
        assert!(backend.subscribe().is_some());

        // Second subscribe should fail: the receiver has already been handed out
        assert!(backend.subscribe().is_none());
    }

    #[test]
    fn test_state_conversion() {
        assert_eq!(u8_to_state(0), BusState::Created);
        assert_eq!(u8_to_state(1), BusState::Starting);
        assert_eq!(u8_to_state(2), BusState::Running);
        assert_eq!(u8_to_state(3), BusState::Stopping);
        assert_eq!(u8_to_state(4), BusState::Stopped);
        assert_eq!(u8_to_state(255), BusState::Created);

        assert_eq!(state_to_u8(BusState::Created), 0);
        assert_eq!(state_to_u8(BusState::Starting), 1);
        assert_eq!(state_to_u8(BusState::Running), 2);
        assert_eq!(state_to_u8(BusState::Stopping), 3);
        assert_eq!(state_to_u8(BusState::Stopped), 4);

        // Every valid code must survive a decode/encode round trip.
        for code in 0..=4u8 {
            assert_eq!(state_to_u8(u8_to_state(code)), code);
        }
    }

    #[test]
    fn test_bus_ptr_operations() {
        let ptr = BusPtr::new();
        assert!(ptr.get().is_none());

        let fake_ptr = 0x12345678 as *mut stdio_bus_t;
        ptr.set(fake_ptr);

        assert_eq!(ptr.get().map(|p| p as usize), Some(0x12345678));

        // `take` must return the stored handle and leave the slot empty.
        assert_eq!(ptr.take().map(|p| p as usize), Some(0x12345678));
        assert!(ptr.get().is_none());
        assert!(ptr.take().is_none());
    }

    #[tokio::test]
    async fn test_native_backend_start_invalid_state() {
        let backend = NativeBackend::new(TEST_CONFIG).unwrap();

        backend.state.store(state_to_u8(BusState::Running), Ordering::SeqCst);

        // Match exhaustively so an unexpected error variant fails the test instead
        // of slipping past an `if let`.
        match backend.start().await {
            Err(Error::InvalidState { expected, actual }) => {
                assert!(expected.contains("CREATED"));
                assert!(actual.contains("RUNNING"));
            }
            Err(other) => panic!("Expected InvalidState error, got {:?}", other),
            Ok(()) => panic!("start() must fail while RUNNING"),
        }

        // A rejected start must not disturb the current state.
        assert_eq!(backend.state(), BusState::Running);
    }

    #[tokio::test]
    async fn test_native_backend_send_not_started() {
        let backend = NativeBackend::new(TEST_CONFIG).unwrap();

        let result = backend.send(r#"{"test": true}"#).await;
        assert!(
            matches!(result, Err(Error::InvalidState { .. })),
            "Expected InvalidState error, got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_native_backend_stop_not_started() {
        let backend = NativeBackend::new(TEST_CONFIG).unwrap();

        let result = backend.stop(1).await;
        assert!(result.is_ok());
        assert_eq!(backend.state(), BusState::Stopped);
    }
}