Skip to main content

rave_tensorrt/
tensorrt.rs

1//! TensorRT inference backend — ORT + TensorRTExecutionProvider + IO Binding.
2//!
3//! # Zero-copy contract
4//!
5//! Input and output tensors are bound to CUDA device pointers via ORT's
6//! IO Binding API.  At no point does frame data touch host memory.
7//!
8//! # Execution provider policy
9//!
10//! **Only TensorRT EP is permitted.**  CPU EP is explicitly disabled.
11//! After session creation, the provider list is validated.  If ORT falls
12//! back to CPU for any graph node, `initialize()` returns an error.
13//!
14//! # CUDA stream ordering
15//!
16//! ORT creates its own internal CUDA stream for TensorRT EP execution.
17//! We cannot inject `GpuContext::inference_stream` because `cudarc::CudaStream`
18//! does not expose its raw `CUstream` handle (`pub(crate)` field).
19//!
20//! Correctness is maintained because `session.run_with_binding()` is
21//! **synchronous** — ORT blocks the calling thread until all GPU kernels
22//! on its internal stream complete.  Therefore:
23//!
24//! 1. Output buffer is fully written when `run_with_binding()` returns.
25//! 2. CUDA global memory coherency guarantees visibility to any subsequent
26//!    reader on any stream after this synchronization point.
27//! 3. No additional inter-stream event is needed.
28//!
29//! # Output ring serialization
30//!
31//! `OutputRing` owns N pre-allocated device buffers.  `acquire()` checks
32//! `Arc::strong_count == 1` before returning a slot, guaranteeing no
33//! concurrent reader.  Ring size must be ≥ `downstream_channel_capacity + 2`.
34
35use std::collections::HashSet;
36use std::env;
37use std::ffi::{CStr, CString, c_char, c_void};
38use std::path::{Path, PathBuf};
39use std::sync::atomic::{AtomicU64, Ordering};
40use std::sync::{Arc, OnceLock};
41use std::time::UNIX_EPOCH;
42
43use async_trait::async_trait;
44use cudarc::driver::{DevicePtr, DeviceSlice};
45use tokio::sync::Mutex;
46use tracing::{debug, info, warn};
47
48use ort::session::Session;
49use ort::sys as ort_sys;
50use ort::value::Value as OrtValue;
51
52use rave_core::backend::{ModelMetadata, UpscaleBackend};
53use rave_core::context::GpuContext;
54use rave_core::error::{EngineError, Result};
55use rave_core::types::{GpuTexture, PixelFormat};
56
57use ort::execution_providers::{CUDAExecutionProvider, TensorRTExecutionProvider};
58
// Helper to create a tensor from device memory using the ORT C API (zero-copy).
//
// # Safety
//
// - `ptr` must point to a valid CUDA device allocation of at least `bytes`
//   bytes that remains alive (and is not concurrently rewritten) for the whole
//   lifetime of the returned `OrtValue` — the tensor aliases the memory, it
//   does not copy it.
// - `shape` must describe the buffer exactly: product of dims × element size
//   must equal `bytes` (not validated here; ORT may reject mismatches).
unsafe fn create_tensor_from_device_memory(
    ptr: *mut std::ffi::c_void,
    bytes: usize,
    shape: &[i64],
    elem_type: ort::tensor::TensorElementType,
) -> Result<OrtValue> {
    let api = ort::api();

    // Describe the memory as CUDA device memory.
    // NOTE(review): device id is hard-coded to 0 here while the backend carries
    // a configurable `device_id` — confirm multi-GPU setups are unaffected.
    let mut mem_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut();
    let name = std::ffi::CString::new("Cuda").unwrap();
    let status = unsafe {
        (api.CreateMemoryInfo)(
            name.as_ptr(),
            ort_sys::OrtAllocatorType::OrtArenaAllocator,
            0, // device id
            ort_sys::OrtMemType::OrtMemTypeDefault,
            &mut mem_info_ptr,
        )
    };
    // A non-null OrtStatus pointer signals failure; release it to avoid a leak.
    if !status.0.is_null() {
        unsafe { (api.ReleaseStatus)(status.0) };
        return Err(EngineError::Inference(ort::Error::new(
            "Failed to create MemoryInfo",
        )));
    }

    // Create the tensor aliasing `ptr` (no copy is performed).
    let mut ort_value_ptr: *mut ort_sys::OrtValue = std::ptr::null_mut();
    let status = unsafe {
        (api.CreateTensorWithDataAsOrtValue)(
            mem_info_ptr,
            ptr,
            bytes as _,
            shape.as_ptr(),
            shape.len() as _,
            elem_type.into(),
            &mut ort_value_ptr,
        )
    };

    // Per the ORT C API, CreateTensorWithDataAsOrtValue does NOT take ownership
    // of the OrtMemoryInfo, so our handle is released unconditionally here.
    // (Presumably ORT copies what it needs internally — TODO confirm against
    // the pinned onnxruntime version's docs.)
    unsafe { (api.ReleaseMemoryInfo)(mem_info_ptr) };

    if !status.0.is_null() {
        unsafe { (api.ReleaseStatus)(status.0) };
        return Err(EngineError::Inference(ort::Error::new(
            "Failed to create Tensor",
        )));
    }

    // Wrap the raw pointer in an owned ort `Value` (from_ptr is assumed to take
    // ownership and release the OrtValue on drop — confirm against ort docs).
    Ok(unsafe {
        ort::value::Value::<ort::value::DynValueTypeMarker>::from_ptr(
            std::ptr::NonNull::new(ort_value_ptr).unwrap(),
            None,
        )
    })
}
122
123// use half::f16; // If half is dependency. ort might export it?
124// I will guess ort::f16 or just try referencing it fully qualified if needed.
125// I'll stick to basic imports for now and use full path in code if unsafe.
126
127// ─── Precision policy ───────────────────────────────────────────────────────
128
/// TensorRT precision policy — controls EP optimization flags.
///
/// Defaults to [`PrecisionPolicy::Fp16`].
#[derive(Clone, Debug, Default)]
pub enum PrecisionPolicy {
    /// FP32 only — maximum accuracy, baseline performance.
    Fp32,
    /// FP16 mixed precision — 2× throughput on Tensor Cores.
    #[default]
    Fp16,
    /// INT8 quantized with calibration table — 4× throughput.
    /// Requires a pre-generated calibration table path.
    Int8 { calibration_table: PathBuf },
}
141
142// ─── Batch config ──────────────────────────────────────────────────────────
143
/// Batch inference configuration.
#[derive(Clone, Debug)]
pub struct BatchConfig {
    /// Maximum batch size for pipelined inference.
    /// Must be ≤ model’s max dynamic batch axis.
    pub max_batch: usize,
    /// Dispatch a partially-filled batch once this many microseconds have
    /// elapsed, even if `max_batch` was never reached (latency bound).
    pub latency_deadline_us: u64,
}

impl Default for BatchConfig {
    fn default() -> Self {
        // 8 ms — roughly half of a 60 fps frame interval.
        BatchConfig {
            latency_deadline_us: 8_000,
            max_batch: 1,
        }
    }
}
163// ─── Inference metrics ───────────────────────────────────────────────────────
164
/// Atomic counters for inference stage observability.
#[derive(Debug)]
pub struct InferenceMetrics {
    /// Total frames inferred.
    pub frames_inferred: AtomicU64,
    /// Cumulative inference time in microseconds (for avg latency).
    pub total_inference_us: AtomicU64,
    /// Peak single-frame inference time in microseconds.
    pub peak_inference_us: AtomicU64,
}

impl InferenceMetrics {
    /// All counters start at zero; `const` so it can back a `static`.
    pub const fn new() -> Self {
        InferenceMetrics {
            frames_inferred: AtomicU64::new(0),
            total_inference_us: AtomicU64::new(0),
            peak_inference_us: AtomicU64::new(0),
        }
    }

    /// Fold one frame's latency into the counters (Relaxed ordering — these
    /// are observability counters, not synchronization points).
    pub fn record(&self, elapsed_us: u64) {
        self.frames_inferred.fetch_add(1, Ordering::Relaxed);
        self.total_inference_us.fetch_add(elapsed_us, Ordering::Relaxed);
        self.peak_inference_us.fetch_max(elapsed_us, Ordering::Relaxed);
    }

    /// Point-in-time snapshot for reporting.  Loads are Relaxed, so values
    /// may race with concurrent `record` calls.
    pub fn snapshot(&self) -> InferenceMetricsSnapshot {
        let n = self.frames_inferred.load(Ordering::Relaxed);
        let sum = self.total_inference_us.load(Ordering::Relaxed);
        let avg = match n {
            0 => 0,
            frames => sum / frames,
        };
        InferenceMetricsSnapshot {
            frames_inferred: n,
            avg_inference_us: avg,
            peak_inference_us: self.peak_inference_us.load(Ordering::Relaxed),
        }
    }
}

impl Default for InferenceMetrics {
    fn default() -> Self {
        InferenceMetrics::new()
    }
}

/// Snapshot of inference metrics for reporting.
#[derive(Clone, Debug)]
pub struct InferenceMetricsSnapshot {
    pub frames_inferred: u64,
    pub avg_inference_us: u64,
    pub peak_inference_us: u64,
}
218
219// ─── Ring metrics ────────────────────────────────────────────────────────────
220
/// Atomic counters for output ring buffer activity.
#[derive(Debug)]
pub struct RingMetrics {
    /// Successful slot reuses (slot was free, strong_count == 1).
    pub slot_reuse_count: AtomicU64,
    /// Times `acquire()` found a slot still held downstream (strong_count > 1).
    pub slot_contention_events: AtomicU64,
    /// Times a slot was acquired but it was the first use (not a reuse).
    pub slot_first_use_count: AtomicU64,
}

impl RingMetrics {
    /// Zero-initialized counters; `const` so it can live in a `static`.
    pub const fn new() -> Self {
        RingMetrics {
            slot_first_use_count: AtomicU64::new(0),
            slot_contention_events: AtomicU64::new(0),
            slot_reuse_count: AtomicU64::new(0),
        }
    }

    /// `(reuse, contention, first_use)` counter values, loaded Relaxed.
    pub fn snapshot(&self) -> (u64, u64, u64) {
        let reuse = self.slot_reuse_count.load(Ordering::Relaxed);
        let contention = self.slot_contention_events.load(Ordering::Relaxed);
        let first_use = self.slot_first_use_count.load(Ordering::Relaxed);
        (reuse, contention, first_use)
    }
}

impl Default for RingMetrics {
    fn default() -> Self {
        RingMetrics::new()
    }
}
255
256// ─── Output ring buffer ─────────────────────────────────────────────────────
257
/// Fixed-size ring of pre-allocated device buffers for inference output.
pub struct OutputRing {
    /// Pre-allocated device buffers; the `Arc` strong count doubles as the
    /// "still held downstream" flag (see `acquire`).
    slots: Vec<Arc<cudarc::driver::CudaSlice<u8>>>,
    /// Index of the next slot `acquire()` will hand out.
    cursor: usize,
    /// Size of every slot in bytes.
    pub slot_bytes: usize,
    /// Input (width, height) the ring was sized for; checked by `needs_realloc`.
    pub alloc_dims: (u32, u32),
    /// Whether each slot has been used at least once (for first-use tracking).
    used: Vec<bool>,
    /// Reuse / contention / first-use counters.
    pub metrics: RingMetrics,
}
268
impl OutputRing {
    /// Allocate `count` output buffers.
    ///
    /// `min_slots` is the enforced minimum (`downstream_capacity + 2`).
    /// Returns error if `count < min_slots`.
    pub fn new(
        ctx: &GpuContext,
        in_w: u32,
        in_h: u32,
        scale: u32,
        count: usize,
        min_slots: usize,
    ) -> Result<Self> {
        if count < min_slots {
            return Err(EngineError::DimensionMismatch(format!(
                "OutputRing: ring_size ({count}) < required minimum ({min_slots}). \
                 Ring must be ≥ downstream_channel_capacity + 2."
            )));
        }
        if count < 2 {
            return Err(EngineError::DimensionMismatch(
                "OutputRing: ring_size must be ≥ 2 for double-buffering".into(),
            ));
        }

        // Slot size: 3 channels × upscaled W × upscaled H × sizeof(f32).
        // NOTE(review): assumes the model emits 3-channel f32 output — confirm
        // against the bound output tensor's dtype/layout.
        let out_w = (in_w * scale) as usize;
        let out_h = (in_h * scale) as usize;
        let slot_bytes = 3 * out_w * out_h * std::mem::size_of::<f32>();

        let slots = (0..count)
            .map(|_| ctx.alloc(slot_bytes).map(Arc::new))
            .collect::<Result<Vec<_>>>()?;

        debug!(count, slot_bytes, out_w, out_h, "Output ring allocated");

        Ok(Self {
            slots,
            cursor: 0,
            slot_bytes,
            alloc_dims: (in_w, in_h),
            used: vec![false; count],
            metrics: RingMetrics::new(),
        })
    }

    /// Acquire the next ring slot for writing.
    ///
    /// # Serialization invariant
    ///
    /// Asserts `Arc::strong_count == 1` before returning.  If downstream
    /// still holds a reference, returns error and increments contention counter.
    pub fn acquire(&mut self) -> Result<Arc<cudarc::driver::CudaSlice<u8>>> {
        let slot = &self.slots[self.cursor];
        let sc = Arc::strong_count(slot);

        if sc != 1 {
            self.metrics
                .slot_contention_events
                .fetch_add(1, Ordering::Relaxed);
            // NOTE(review): `BufferTooSmall` is repurposed here to signal
            // contention (`have: 0`); the cursor is deliberately NOT advanced,
            // so the same slot is retried on the next call.
            return Err(EngineError::BufferTooSmall {
                need: self.slot_bytes,
                have: 0,
            });
        }

        // Debug assertion — belt-and-suspenders check.
        debug_assert_eq!(
            Arc::strong_count(slot),
            1,
            "OutputRing: slot {} strong_count must be 1 before reuse, got {}",
            self.cursor,
            sc
        );

        // Distinguish first use from reuse for the ring metrics.
        if self.used[self.cursor] {
            self.metrics
                .slot_reuse_count
                .fetch_add(1, Ordering::Relaxed);
        } else {
            self.used[self.cursor] = true;
            self.metrics
                .slot_first_use_count
                .fetch_add(1, Ordering::Relaxed);
        }

        // Hand out a second strong reference (count becomes 2 until the
        // downstream consumer drops it), then advance to the next slot.
        let cloned = Arc::clone(slot);
        self.cursor = (self.cursor + 1) % self.slots.len();
        Ok(cloned)
    }

    /// True when the ring was sized for different input dimensions.
    pub fn needs_realloc(&self, in_w: u32, in_h: u32) -> bool {
        self.alloc_dims != (in_w, in_h)
    }

    /// Reallocate all slots.  All must have `strong_count == 1`.
    pub fn reallocate(&mut self, ctx: &GpuContext, in_w: u32, in_h: u32, scale: u32) -> Result<()> {
        // Refuse to reallocate while any slot is still referenced downstream.
        for (i, slot) in self.slots.iter().enumerate() {
            let sc = Arc::strong_count(slot);
            if sc != 1 {
                return Err(EngineError::DimensionMismatch(format!(
                    "Cannot reallocate ring: slot {} still in use (strong_count={})",
                    i, sc,
                )));
            }
        }

        // Free old slots — decrement VRAM accounting.
        // (The device memory itself is released when `self.slots` is
        // overwritten below and the old Arcs drop — presumably `CudaSlice`
        // frees on drop; confirm against cudarc.)
        for _slot in &self.slots {
            ctx.vram_dec(self.slot_bytes);
        }

        let count = self.slots.len();
        let out_w = (in_w * scale) as usize;
        let out_h = (in_h * scale) as usize;
        let slot_bytes = 3 * out_w * out_h * std::mem::size_of::<f32>();

        self.slots = (0..count)
            .map(|_| ctx.alloc(slot_bytes).map(Arc::new))
            .collect::<Result<Vec<_>>>()?;
        self.cursor = 0;
        self.slot_bytes = slot_bytes;
        self.alloc_dims = (in_w, in_h);
        self.used = vec![false; count];

        debug!(count, slot_bytes, out_w, out_h, "Output ring reallocated");
        Ok(())
    }

    /// Total number of slots.
    pub fn len(&self) -> usize {
        self.slots.len()
    }

    /// True when the ring holds no slots (never the case after a successful `new`).
    pub fn is_empty(&self) -> bool {
        self.slots.is_empty()
    }
}
406
407// ─── Inference state ─────────────────────────────────────────────────────────
408
/// Mutable backend state, guarded by `TensorRtBackend::state`.
struct InferenceState {
    /// Live ORT session.
    session: Session,
    /// Output ring; `None` until allocated (presumably lazily, once frame
    /// dimensions are known — confirm in the initialization/inference path).
    ring: Option<OutputRing>,
}
413
414/// Resolve ORT tensor element type from our PixelFormat.
415fn ort_element_type(format: PixelFormat) -> ort::tensor::TensorElementType {
416    match format {
417        PixelFormat::RgbPlanarF16 => ort::tensor::TensorElementType::Float16,
418        _ => ort::tensor::TensorElementType::Float32,
419    }
420}
421
422// ─── Backend ─────────────────────────────────────────────────────────────────
423
pub struct TensorRtBackend {
    /// Path to the ONNX model on disk.
    model_path: PathBuf,
    /// Shared GPU context (device allocations, VRAM accounting).
    ctx: Arc<GpuContext>,
    /// CUDA device ordinal handed to the TensorRT EP.
    device_id: i32,
    /// Number of output ring slots to pre-allocate.
    ring_size: usize,
    /// Enforced minimum ring size (`downstream_capacity + 2`).
    min_ring_slots: usize,
    /// Model metadata, set at most once (OnceLock).
    meta: OnceLock<ModelMetadata>,
    /// Session + output ring; `None` until initialization.
    state: Mutex<Option<InferenceState>>,
    pub inference_metrics: InferenceMetrics,
    /// Phase 8: precision policy for TRT EP.
    pub precision_policy: PrecisionPolicy,
    /// Phase 8: batch configuration.
    pub batch_config: BatchConfig,
}
438
/// Execution-provider selection mode, parsed from `RAVE_ORT_TENSORRT`
/// (see `ort_ep_mode` for the accepted string values).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum OrtEpMode {
    /// Default — no explicit override given.
    Auto,
    /// Explicitly requested TensorRT EP.
    TensorRtOnly,
    /// Explicitly requested CUDA EP.
    CudaOnly,
}
445
// Raw libc dynamic-loader bindings; used to preload CUDA/ORT provider libs.
#[cfg(target_os = "linux")]
unsafe extern "C" {
    fn dlopen(filename: *const c_char, flags: i32) -> *mut c_void;
    fn dlerror() -> *const c_char;
}

#[cfg(target_os = "linux")]
const RTLD_NOW: i32 = 2; // resolve all symbols at load time
#[cfg(all(target_os = "linux", test))]
const RTLD_LOCAL: i32 = 0; // symbols not re-exported (test-only)
#[cfg(target_os = "linux")]
const RTLD_GLOBAL: i32 = 0x100; // symbols visible to subsequently loaded libs
458
/// Which ORT execution-provider shared library we are dealing with.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy)]
enum OrtProviderKind {
    Cuda,
    TensorRt,
}

#[cfg(target_os = "linux")]
impl OrtProviderKind {
    /// Shared-object filename of this provider library.
    fn soname(self) -> &'static str {
        match self {
            Self::TensorRt => "libonnxruntime_providers_tensorrt.so",
            Self::Cuda => "libonnxruntime_providers_cuda.so",
        }
    }

    /// Short tag used in structured log fields.
    fn label(self) -> &'static str {
        match self {
            Self::TensorRt => "providers_tensorrt",
            Self::Cuda => "providers_cuda",
        }
    }
}
482
483impl TensorRtBackend {
484    /// Create a new backend instance.
485    ///
486    /// # Parameters
487    ///
488    /// - `ring_size`: number of output ring slots to pre-allocate.
489    /// - `downstream_capacity`: the bounded channel capacity between inference
490    ///   and the encoder.  Ring size is validated ≥ `downstream_capacity + 2`.
491    pub fn new(
492        model_path: PathBuf,
493        ctx: Arc<GpuContext>,
494        device_id: i32,
495        ring_size: usize,
496        downstream_capacity: usize,
497    ) -> Self {
498        Self::with_precision(
499            model_path,
500            ctx,
501            device_id,
502            ring_size,
503            downstream_capacity,
504            PrecisionPolicy::default(),
505            BatchConfig::default(),
506        )
507    }
508
509    /// Create with explicit precision policy and batch config.
510    pub fn with_precision(
511        model_path: PathBuf,
512        ctx: Arc<GpuContext>,
513        device_id: i32,
514        ring_size: usize,
515        downstream_capacity: usize,
516        precision_policy: PrecisionPolicy,
517        batch_config: BatchConfig,
518    ) -> Self {
519        let min_ring_slots = downstream_capacity + 2;
520        assert!(
521            ring_size >= min_ring_slots,
522            "ring_size ({ring_size}) must be ≥ downstream_capacity + 2 ({min_ring_slots})"
523        );
524        Self {
525            model_path,
526            ctx,
527            device_id,
528            ring_size,
529            min_ring_slots,
530            meta: OnceLock::new(),
531            state: Mutex::new(None),
532            inference_metrics: InferenceMetrics::new(),
533            precision_policy,
534            batch_config,
535        }
536    }
537
    /// Build an ORT `MemoryInfo` describing CUDA device memory.
    ///
    /// NOTE(review): despite earlier "cached" wording, this constructs a fresh
    /// MemoryInfo on every call — cache it if per-frame allocation matters.
    /// NOTE(review): device id is hard-coded to 0 while `self.device_id` is
    /// configurable — confirm behavior on multi-GPU hosts.
    fn mem_info(&self) -> Result<ort::memory::MemoryInfo> {
        ort::memory::MemoryInfo::new(
            ort::memory::AllocationDevice::CUDA,
            0,
            ort::memory::AllocatorType::Device,
            ort::memory::MemoryType::Default,
        )
        .map_err(|e| EngineError::ModelMetadata(format!("MemoryInfo: {e}")))
    }
549
550    /// Access ring metrics (if initialized).
551    pub async fn ring_metrics(&self) -> Option<(u64, u64, u64)> {
552        let guard = self.state.lock().await;
553        guard
554            .as_ref()
555            .and_then(|s| s.ring.as_ref())
556            .map(|r| r.metrics.snapshot())
557    }
558
    /// Derive `ModelMetadata` from the session's first input/output tensors.
    ///
    /// Expects both tensors to be 4D (NCHW).  Axes that are not positive are
    /// treated as dynamic and mapped to permissive min/max bounds.
    fn extract_metadata(session: &Session) -> Result<ModelMetadata> {
        let inputs = session.inputs();
        let outputs = session.outputs();

        if inputs.is_empty() || outputs.is_empty() {
            return Err(EngineError::ModelMetadata(
                "Model must have at least one input and one output tensor".into(),
            ));
        }

        // Only the first input/output tensor is considered.
        let input_info = &inputs[0];
        let output_info = &outputs[0];
        let input_name = input_info.name().to_string();
        let output_name = output_info.name().to_string();

        let input_dims = match input_info.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.clone(),
            other => {
                return Err(EngineError::ModelMetadata(format!(
                    "Expected tensor input, got {:?}",
                    other
                )));
            }
        };

        let output_dims = match output_info.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.clone(),
            other => {
                return Err(EngineError::ModelMetadata(format!(
                    "Expected tensor output, got {:?}",
                    other
                )));
            }
        };

        if input_dims.len() != 4 || output_dims.len() != 4 {
            return Err(EngineError::ModelMetadata(format!(
                "Expected 4D tensors (NCHW), got input={}D output={}D",
                input_dims.len(),
                output_dims.len()
            )));
        }

        let input_channels = input_dims[1] as u32;

        let ih = input_dims[2];
        let iw = input_dims[3];
        let oh = output_dims[2];
        let ow = output_dims[3];

        // NOTE(review): scale is derived from the height axis only (oh / ih);
        // the width ratio is never cross-checked — confirm for anisotropic models.
        let scale = if ih > 0 && oh > 0 && iw > 0 && ow > 0 {
            (oh / ih) as u32
        } else {
            warn!("Dynamic spatial axes — defaulting to scale=4");
            4
        };

        // Dynamic (non-positive) axes get the loosest possible bounds.
        let min_input_hw = (
            if ih > 0 { ih as u32 } else { 1 },
            if iw > 0 { iw as u32 } else { 1 },
        );

        let max_input_hw = (
            if ih > 0 { ih as u32 } else { u32::MAX },
            if iw > 0 { iw as u32 } else { u32::MAX },
        );

        // Model name from graph metadata; falls back to "unknown" on any error.
        let name = session
            .metadata()
            .map(|m| m.name().unwrap_or("unknown".to_string()))
            .unwrap_or_else(|_| "unknown".to_string());

        Ok(ModelMetadata {
            name,
            scale,
            input_name,
            output_name,
            input_channels,
            min_input_hw,
            max_input_hw,
        })
    }
649
650    fn ort_ep_mode() -> OrtEpMode {
651        match env::var("RAVE_ORT_TENSORRT")
652            .unwrap_or_else(|_| "auto".to_string())
653            .to_lowercase()
654            .as_str()
655        {
656            "0" | "off" | "false" | "cuda" | "cuda-only" => OrtEpMode::CudaOnly,
657            "1" | "on" | "true" | "trt" | "trt-only" => OrtEpMode::TensorRtOnly,
658            _ => OrtEpMode::Auto,
659        }
660    }
661
662    #[cfg(target_os = "linux")]
663    fn is_wsl2() -> bool {
664        std::fs::read_to_string("/proc/sys/kernel/osrelease")
665            .map(|s| s.to_ascii_lowercase().contains("microsoft"))
666            .unwrap_or(false)
667    }
668
669    #[cfg(not(target_os = "linux"))]
670    fn is_wsl2() -> bool {
671        false
672    }
673
    /// Enumerate ORT download-cache dirs (`~/.cache/ort.pyke.io/dfbin/<triple>/<hash>`),
    /// newest first by directory mtime, so the most recently fetched build wins.
    /// Returns empty when `HOME` is unset or the cache base cannot be read.
    #[cfg(target_os = "linux")]
    fn ort_provider_cache_dirs_newest_first() -> Vec<PathBuf> {
        let mut dirs = Vec::<(u128, PathBuf)>::new();
        let Some(home) = env::var_os("HOME") else {
            return Vec::new();
        };
        let base = PathBuf::from(home).join(".cache/ort.pyke.io/dfbin");
        let Ok(triples) = std::fs::read_dir(base) else {
            return Vec::new();
        };

        for triple in triples.flatten() {
            let triple_path = triple.path();
            if !triple_path.is_dir() {
                continue;
            }
            let Ok(hashes) = std::fs::read_dir(triple_path) else {
                continue;
            };
            for hash in hashes.flatten() {
                let path = hash.path();
                if !path.is_dir() {
                    continue;
                }
                // Unreadable mtimes map to 0 so those dirs sort last.
                let modified = std::fs::metadata(&path)
                    .and_then(|m| m.modified())
                    .ok()
                    .and_then(|t| t.duration_since(UNIX_EPOCH).ok())
                    .map(|d| d.as_nanos())
                    .unwrap_or(0);
                dirs.push((modified, path));
            }
        }

        // Descending tuple compare: newest mtime first (path breaks ties).
        dirs.sort_by(|a, b| b.cmp(a));
        dirs.into_iter().map(|(_, path)| path).collect()
    }
711
    /// Ordered list of directories to probe for ORT provider libraries, each
    /// tagged with the source it came from (used in log output).
    ///
    /// Explicit env overrides come first.  On WSL2 the ORT download cache is
    /// preferred over the executable's directory; elsewhere the order is
    /// reversed.  Duplicates keep their first occurrence.
    #[cfg(target_os = "linux")]
    fn ort_provider_search_dirs() -> Vec<(PathBuf, &'static str)> {
        let mut dirs = Vec::<(PathBuf, &'static str)>::new();
        if let Some(dir) = env::var_os("ORT_DYLIB_PATH") {
            dirs.push((PathBuf::from(dir), "ORT_DYLIB_PATH"));
        }
        if let Some(dir) = env::var_os("ORT_LIB_LOCATION") {
            dirs.push((PathBuf::from(dir), "ORT_LIB_LOCATION"));
        }
        if Self::is_wsl2() {
            for dir in Self::ort_provider_cache_dirs_newest_first() {
                dirs.push((dir, "ort_cache_newest"));
            }
            if let Ok(exe) = env::current_exe()
                && let Some(dir) = exe.parent()
            {
                dirs.push((dir.to_path_buf(), "exe_dir"));
                dirs.push((dir.join("deps"), "exe_dir/deps"));
            }
        } else {
            if let Ok(exe) = env::current_exe()
                && let Some(dir) = exe.parent()
            {
                dirs.push((dir.to_path_buf(), "exe_dir"));
                dirs.push((dir.join("deps"), "exe_dir/deps"));
            }
            for dir in Self::ort_provider_cache_dirs_newest_first() {
                dirs.push((dir, "ort_cache_newest"));
            }
        }

        // De-dup while preserving order (first source tag wins).
        let mut uniq = HashSet::<PathBuf>::new();
        dirs.retain(|(p, _)| uniq.insert(p.clone()));
        dirs
    }
747
748    #[cfg(all(target_os = "linux", test))]
749    fn ort_provider_candidates(lib_name: &str) -> Vec<PathBuf> {
750        Self::ort_provider_search_dirs()
751            .into_iter()
752            .map(|(dir, _)| dir.join(lib_name))
753            .collect()
754    }
755
    /// Point ORT's dynamic-loading env vars at `dir`, so both the ort crate
    /// and the provider bridge resolve libraries from the same directory.
    #[cfg(target_os = "linux")]
    fn configure_ort_loader_path(dir: &Path) {
        // SAFETY: Process env mutation is done at startup during backend init.
        unsafe { env::set_var("ORT_DYLIB_PATH", dir) };
        // SAFETY: Process env mutation is done at startup during backend init.
        unsafe { env::set_var("ORT_LIB_LOCATION", dir) };
    }
763
    /// Find the first search directory containing BOTH
    /// `libonnxruntime_providers_shared.so` and the requested provider library,
    /// then point ORT's loader env vars at it.
    ///
    /// Returns the directory and the source tag it was discovered under, or an
    /// actionable error when no directory qualifies.
    #[cfg(target_os = "linux")]
    fn resolve_ort_provider_dir(kind: OrtProviderKind) -> Result<(PathBuf, &'static str)> {
        let provider = kind.soname();
        for (dir, source) in Self::ort_provider_search_dirs() {
            // Both files must live in the same directory.
            let shared = dir.join("libonnxruntime_providers_shared.so");
            let provider_path = dir.join(provider);
            if shared.is_file() && provider_path.is_file() {
                info!(
                    path = %shared.display(),
                    source,
                    "ORT providers_shared resolved"
                );
                info!(
                    provider = kind.label(),
                    path = %provider_path.display(),
                    source,
                    "ORT provider resolved"
                );
                Self::configure_ort_loader_path(&dir);
                return Ok((dir, source));
            }
        }

        Err(EngineError::ModelMetadata(format!(
            "Could not find ORT provider pair (libonnxruntime_providers_shared.so + {}). \
Set ORT_DYLIB_PATH or ORT_LIB_LOCATION to that directory, then re-run scripts/test_ort_provider_load.sh.",
            provider
        )))
    }
793
    /// `dlopen` the library at `path` with the given flags, translating a null
    /// handle into an `EngineError` carrying the `dlerror` text.
    ///
    /// The handle is never `dlclose`d — the library stays loaded for the
    /// remainder of the process lifetime.
    #[cfg(target_os = "linux")]
    fn dlopen_path(path: &Path, flags: i32) -> Result<()> {
        let cpath = CString::new(path.to_string_lossy().as_bytes())
            .map_err(|_| EngineError::ModelMetadata("Invalid dlopen path".into()))?;
        // SAFETY: `dlopen` expects a valid, NUL-terminated C string and flags.
        let handle = unsafe { dlopen(cpath.as_ptr(), flags) };
        if handle.is_null() {
            // SAFETY: `dlerror` returns a thread-local pointer or null.
            let err = unsafe {
                let p = dlerror();
                if p.is_null() {
                    "unknown dlopen error".to_string()
                } else {
                    CStr::from_ptr(p).to_string_lossy().to_string()
                }
            };
            return Err(EngineError::ModelMetadata(format!(
                "dlopen failed for {}: {err}",
                path.display()
            )));
        }
        Ok(())
    }
817
    /// Best-effort preload of CUDA 12 runtime libraries with `RTLD_GLOBAL`, so
    /// their symbols are visible when ORT's provider libraries load later.
    ///
    /// Roots are tried in priority order (`RAVE_CUDA_RUNTIME_DIR` override
    /// first, then standard toolkit paths); the first root where at least one
    /// library loads wins and the search stops.  Individual load failures are
    /// silently skipped — this is an optimization, not a hard requirement.
    #[cfg(target_os = "linux")]
    fn preload_cuda_runtime_libs() {
        // Prefer explicit toolkit paths to avoid reliance on unrelated app installs.
        let mut roots = Vec::<PathBuf>::new();
        if let Some(dir) = env::var_os("RAVE_CUDA_RUNTIME_DIR") {
            roots.push(PathBuf::from(dir));
        }
        roots.push(PathBuf::from("/usr/local/cuda-12/targets/x86_64-linux/lib"));
        roots.push(PathBuf::from("/usr/local/cuda/lib64"));

        let libs = [
            "libcudart.so.12",
            "libcublasLt.so.12",
            "libcublas.so.12",
            "libnvrtc.so.12",
            "libcurand.so.10",
            "libcufft.so.11",
        ];

        for root in roots {
            if !root.is_dir() {
                continue;
            }
            let mut loaded = 0usize;
            for lib in libs {
                let p = root.join(lib);
                if !p.is_file() {
                    continue;
                }
                if Self::dlopen_path(&p, RTLD_NOW | RTLD_GLOBAL).is_ok() {
                    loaded += 1;
                }
            }
            if loaded > 0 {
                info!(root = %root.display(), loaded, "Preloaded CUDA runtime libraries for ORT");
                break;
            }
        }
    }
857
    /// When ORT is linked statically, TensorRT provider expects `Provider_GetHost`
    /// from `libonnxruntime_providers_shared.so` to already be present in the process.
    /// Preloading this bridge with `RTLD_GLOBAL` satisfies that symbol before TRT EP load.
    ///
    /// Note: only the shared bridge is dlopen'ed here; the provider library
    /// itself is left for ORT to load via the configured loader paths.
    #[cfg(target_os = "linux")]
    fn preload_ort_provider_pair(kind: OrtProviderKind) -> Result<()> {
        let (dir, source) = Self::resolve_ort_provider_dir(kind)?;
        let shared = dir.join("libonnxruntime_providers_shared.so");
        let provider = dir.join(kind.soname());

        Self::dlopen_path(&shared, RTLD_NOW | RTLD_GLOBAL).map_err(|e| {
            EngineError::ModelMetadata(format!(
                "Failed loading providers_shared from {} ({e}). \
Ensure ORT_DYLIB_PATH/ORT_LIB_LOCATION points to a valid ORT cache dir.",
                shared.display()
            ))
        })?;

        info!(
            path = %provider.display(),
            provider = kind.label(),
            "Skipping explicit provider dlopen; relying on startup LD_LIBRARY_PATH + ORT registration"
        );

        info!(
            source,
            dir = %dir.display(),
            provider = kind.label(),
            path = %provider.display(),
            "ORT provider pair prepared (providers_shared preloaded, provider path configured)"
        );
        Ok(())
    }
890
891    #[cfg(all(target_os = "linux", test))]
892    fn preload_ort_provider_bridge() -> Result<()> {
893        for path in Self::ort_provider_candidates("libonnxruntime_providers_shared.so") {
894            if path.is_file() {
895                return Self::dlopen_path(&path, RTLD_NOW | RTLD_GLOBAL);
896            }
897        }
898        Err(EngineError::ModelMetadata(
899            "Could not preload libonnxruntime_providers_shared.so from known search paths".into(),
900        ))
901    }
902
    /// Non-Linux test builds: there is no dlopen-based provider bridge to
    /// preload, so this is a no-op that always succeeds.
    #[cfg(all(not(target_os = "linux"), test))]
    fn preload_ort_provider_bridge() -> Result<()> {
        Ok(())
    }
907
    /// CUDA runtime preloading is a Linux-only concern; no-op on other targets.
    #[cfg(not(target_os = "linux"))]
    fn preload_cuda_runtime_libs() {}
910
911    fn build_trt_session(&self) -> Result<Session> {
912        Self::preload_cuda_runtime_libs();
913        Self::preload_ort_provider_pair(OrtProviderKind::TensorRt)?;
914
915        let mut trt_ep = TensorRTExecutionProvider::default()
916            .with_device_id(self.device_id)
917            .with_engine_cache(true)
918            .with_engine_cache_path(
919                self.model_path
920                    .parent()
921                    .unwrap_or(&self.model_path)
922                    .join("trt_cache")
923                    .to_string_lossy()
924                    .to_string(),
925            );
926
927        match &self.precision_policy {
928            PrecisionPolicy::Fp32 => {
929                info!("TRT precision: FP32 (no mixed precision)");
930            }
931            PrecisionPolicy::Fp16 => {
932                trt_ep = trt_ep.with_fp16(true);
933                info!("TRT precision: FP16 mixed precision");
934            }
935            PrecisionPolicy::Int8 { calibration_table } => {
936                trt_ep = trt_ep.with_fp16(true).with_int8(true);
937                info!(
938                    table = %calibration_table.display(),
939                    "TRT precision: INT8 with calibration table"
940                );
941            }
942        }
943
944        Session::builder()?
945            .with_execution_providers([trt_ep.build().error_on_failure()])?
946            .with_intra_threads(1)?
947            .commit_from_file(&self.model_path)
948            .map_err(Into::into)
949    }
950
951    fn build_cuda_session(&self) -> Result<Session> {
952        Self::preload_cuda_runtime_libs();
953        Self::preload_ort_provider_pair(OrtProviderKind::Cuda)?;
954        let cuda_ep = CUDAExecutionProvider::default().with_device_id(self.device_id);
955        Session::builder()?
956            .with_execution_providers([cuda_ep.build().error_on_failure()])?
957            .with_intra_threads(1)?
958            .commit_from_file(&self.model_path)
959            .map_err(Into::into)
960    }
961
962    /// Verify that IO-bound device pointers match the source GpuTexture
963    /// and OutputRing slot pointers exactly (pointer identity check).
964    ///
965    /// Called by `run_io_bound` to audit that ORT IO Binding uses our
966    /// device pointers without any host staging or reallocation.
967    fn verify_pointer_identity(
968        input_ptr: u64,
969        output_ptr: u64,
970        input_texture: &GpuTexture,
971        ring_slot_ptr: u64,
972    ) {
973        let texture_ptr = input_texture.device_ptr();
974        debug!(
975            input_ptr = format!("0x{:016x}", input_ptr),
976            texture_ptr = format!("0x{:016x}", texture_ptr),
977            output_ptr = format!("0x{:016x}", output_ptr),
978            ring_slot_ptr = format!("0x{:016x}", ring_slot_ptr),
979            "IO-binding pointer identity audit"
980        );
981
982        debug_assert_eq!(
983            input_ptr, texture_ptr,
984            "POINTER MISMATCH: IO-bound input (0x{:016x}) != GpuTexture (0x{:016x})",
985            input_ptr, texture_ptr,
986        );
987        debug_assert_eq!(
988            output_ptr, ring_slot_ptr,
989            "POINTER MISMATCH: IO-bound output (0x{:016x}) != ring slot (0x{:016x})",
990            output_ptr, ring_slot_ptr,
991        );
992    }
993
    /// Execute one inference pass through ORT IO Binding over raw CUDA
    /// device pointers — frame data never touches host memory.
    ///
    /// * `input` — source texture; its device pointer is bound as the model
    ///   input with shape `[1, C, H, W]`.
    /// * `output_ptr` / `output_bytes` — destination ring-slot device buffer;
    ///   bound as the model output with shape `[1, C, H*scale, W*scale]`.
    /// * `_ctx` / `_mem_info` — currently unused (tensor creation goes through
    ///   the `create_tensor_from_device_memory` FFI helper instead); kept for
    ///   signature stability with the caller.
    ///
    /// Returns `EngineError::BufferTooSmall` when the slot cannot hold the
    /// upscaled tensor. `session.run_binding` is synchronous: ORT blocks
    /// until its internal-stream kernels complete (see module docs).
    fn run_io_bound(
        session: &mut Session,
        meta: &ModelMetadata,
        input: &GpuTexture,
        output_ptr: u64,
        output_bytes: usize,
        _ctx: &GpuContext,
        _mem_info: &ort::memory::MemoryInfo, // Unused argument if we recreate it, or use it if passed.
    ) -> Result<()> {
        let in_w = input.width as i64;
        let in_h = input.height as i64;
        let out_w = in_w * meta.scale as i64;
        let out_h = in_h * meta.scale as i64;

        // NOTE(review): output channel count reuses `meta.input_channels` —
        // assumes the model preserves channel count across upscaling; confirm
        // for models where input/output channels differ.
        let input_shape: Vec<i64> = vec![1, meta.input_channels as i64, in_h, in_w];
        let output_shape: Vec<i64> = vec![1, meta.input_channels as i64, out_h, out_w];

        // Resolve element type from pixel format (F16 or F32).
        let elem_type = ort_element_type(input.format);
        let elem_bytes = input.format.element_bytes();
        let input_bytes = input.data.len();

        // Guard: the ring slot must hold the full upscaled tensor.
        let expected = (output_shape.iter().product::<i64>() as usize) * elem_bytes;
        if output_bytes < expected {
            return Err(EngineError::BufferTooSmall {
                need: expected,
                have: output_bytes,
            });
        }

        let input_ptr = input.device_ptr();

        // Phase 7: Pointer identity audit — verify device pointers match.
        // NOTE(review): `output_ptr` is passed as both the bound-output and the
        // ring-slot pointer, so the output==slot assert inside is trivially
        // true here; only the input==texture comparison carries information.
        Self::verify_pointer_identity(input_ptr, output_ptr, input, output_ptr);

        let mut binding = session.create_binding()?;

        // MemoryInfo is not needed here: device-tensor creation is delegated
        // to the `create_tensor_from_device_memory` FFI helper below.

        // SAFETY — both tensors wrap raw CUDA device pointers that outlive the
        // binding and the synchronous `run_binding` call below: `input` is
        // borrowed for the whole function, and the ring slot backing
        // `output_ptr` is held alive by the caller's Arc.
        unsafe {
            // Create Input Tensor from raw device memory (zero-copy)
            let input_tensor = create_tensor_from_device_memory(
                input_ptr as *mut _,
                input_bytes,
                &input_shape,
                elem_type,
            )?;

            binding.bind_input(&meta.input_name, &input_tensor)?;

            // Create Output Tensor from raw device memory (zero-copy)
            let output_tensor = create_tensor_from_device_memory(
                output_ptr as *mut _,
                output_bytes,
                &output_shape,
                elem_type,
            )?;

            binding.bind_output(&meta.output_name, output_tensor)?;
        }

        // Synchronous execution — ORT blocks until all TensorRT kernels complete.
        session.run_binding(&binding)?;

        Ok(())
    }
1062}
1063
//! (tests) Host-dependent integration tests — both are `#[ignore]`d because
//! they require ORT/TensorRT provider libraries (and for the smoke test, a
//! CUDA device plus a model path in `RAVE_TEST_ONNX_MODEL`).
#[cfg(all(test, target_os = "linux"))]
mod tests {
    use super::{RTLD_LOCAL, RTLD_NOW, TensorRtBackend};
    use rave_core::backend::UpscaleBackend;
    use std::env;
    use std::path::PathBuf;

    // Verifies dlopen ordering: the TensorRT provider library should load
    // once providers_shared has been preloaded with RTLD_GLOBAL.
    #[test]
    #[ignore = "requires ORT TensorRT provider libs on host"]
    fn providers_load_with_bridge_preloaded() {
        // Ensure bridge is globally visible before loading TensorRT provider.
        TensorRtBackend::preload_ort_provider_bridge().expect("failed to preload providers_shared");
        let trt_candidates =
            TensorRtBackend::ort_provider_candidates("libonnxruntime_providers_tensorrt.so");
        assert!(
            !trt_candidates.is_empty(),
            "no libonnxruntime_providers_tensorrt.so candidates found"
        );
        let path = trt_candidates[0].clone();
        TensorRtBackend::dlopen_path(&path, RTLD_NOW | RTLD_LOCAL)
            .expect("providers_tensorrt should dlopen after providers_shared preload");
    }

    // End-to-end EP registration smoke test: builds a backend against a real
    // model and runs `initialize()` on a current-thread tokio runtime.
    #[test]
    #[ignore = "requires model + full ORT/TensorRT runtime"]
    fn ort_registers_tensorrt_ep_smoke() {
        let model = env::var("RAVE_TEST_ONNX_MODEL").expect("set RAVE_TEST_ONNX_MODEL");
        let backend = TensorRtBackend::new(
            PathBuf::from(model),
            rave_core::context::GpuContext::new(0).expect("cuda ctx"),
            0,
            6,
            4,
        );
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("tokio runtime");
        rt.block_on(async move { backend.initialize().await })
            .expect("TensorRT EP registration should succeed");
    }
}
1106
#[async_trait]
impl UpscaleBackend for TensorRtBackend {
    /// Build the ORT session according to the EP policy (`ort_ep_mode`),
    /// extract model metadata, and install the inference state.
    ///
    /// Errors if called twice, or if every permitted provider fails to build.
    async fn initialize(&self) -> Result<()> {
        let mut guard = self.state.lock().await;
        if guard.is_some() {
            return Err(EngineError::ModelMetadata("Already initialized".into()));
        }

        let ep_mode = Self::ort_ep_mode();
        let on_wsl = Self::is_wsl2();
        info!(
            path = %self.model_path.display(),
            ?ep_mode,
            on_wsl,
            "Loading ONNX model with ORT execution provider policy"
        );

        // Provider selection: CUDA-only, TRT-only, or auto (TRT with CUDA
        // fallback when TRT registration fails).
        let (session, active_provider) = match ep_mode {
            OrtEpMode::CudaOnly => {
                info!("RAVE_ORT_TENSORRT disables TensorRT EP; using CUDAExecutionProvider");
                (self.build_cuda_session()?, "CUDAExecutionProvider")
            }
            OrtEpMode::TensorRtOnly => (self.build_trt_session()?, "TensorrtExecutionProvider"),
            OrtEpMode::Auto => match self.build_trt_session() {
                Ok(session) => (session, "TensorrtExecutionProvider"),
                Err(e) => {
                    warn!(
                        error = %e,
                        on_wsl,
                        "TensorRT EP registration failed; falling back to CUDAExecutionProvider"
                    );
                    (self.build_cuda_session()?, "CUDAExecutionProvider")
                }
            },
        };
        info!(
            provider = active_provider,
            "ORT execution provider selected"
        );

        let metadata = Self::extract_metadata(&session)?;
        info!(
            name = %metadata.name,
            scale = metadata.scale,
            input = %metadata.input_name,
            output = %metadata.output_name,
            ring_size = self.ring_size,
            min_ring_slots = self.min_ring_slots,
            provider = active_provider,
            precision = ?self.precision_policy,
            max_batch = self.batch_config.max_batch,
            "Model loaded"
        );

        // First write wins; a redundant set is deliberately ignored.
        let _ = self.meta.set(metadata);

        // Output ring is created lazily on the first `process()` call, once
        // the input dimensions are known.
        *guard = Some(InferenceState {
            session,
            ring: None,
        });

        Ok(())
    }

    /// Upscale one frame. Accepts F32 or F16 planar RGB, runs IO-bound
    /// inference into a ring slot, and returns a texture whose `data` aliases
    /// that slot (released back to the ring when the Arc is dropped).
    async fn process(&self, input: GpuTexture) -> Result<GpuTexture> {
        // Accept both F32 and F16 planar RGB.
        match input.format {
            PixelFormat::RgbPlanarF32 | PixelFormat::RgbPlanarF16 => {}
            other => {
                return Err(EngineError::FormatMismatch {
                    expected: PixelFormat::RgbPlanarF32,
                    actual: other,
                });
            }
        }

        let meta = self.meta.get().ok_or(EngineError::NotInitialized)?;
        let mut guard = self.state.lock().await;
        let state = guard.as_mut().ok_or(EngineError::NotInitialized)?;

        // Lazy ring init / realloc.
        match &mut state.ring {
            // Existing ring but frame dimensions changed: rebuild the slots.
            Some(ring) if ring.needs_realloc(input.width, input.height) => {
                debug!(
                    old_w = ring.alloc_dims.0,
                    old_h = ring.alloc_dims.1,
                    new_w = input.width,
                    new_h = input.height,
                    "Reallocating output ring"
                );
                ring.reallocate(&self.ctx, input.width, input.height, meta.scale)?;
            }
            // First frame: allocate the ring now that dimensions are known.
            None => {
                debug!(
                    w = input.width,
                    h = input.height,
                    slots = self.ring_size,
                    "Lazily creating output ring"
                );
                state.ring = Some(OutputRing::new(
                    &self.ctx,
                    input.width,
                    input.height,
                    meta.scale,
                    self.ring_size,
                    self.min_ring_slots,
                )?);
            }
            // Ring exists and dimensions match: reuse as-is.
            Some(_) => {}
        }

        let ring = state.ring.as_mut().unwrap();

        // Debug-mode host allocation tracking.
        #[cfg(feature = "debug-alloc")]
        {
            rave_core::debug_alloc::reset();
            rave_core::debug_alloc::enable();
        }

        let output_arc = ring.acquire()?;
        let output_ptr = *(*output_arc).device_ptr();
        let output_bytes = ring.slot_bytes;

        // ── Inference with latency measurement ──
        let t_start = std::time::Instant::now();

        let mem_info = self.mem_info()?;
        Self::run_io_bound(
            &mut state.session,
            meta,
            &input,
            output_ptr,
            output_bytes,
            &self.ctx,
            &mem_info,
        )?;

        let elapsed_us = t_start.elapsed().as_micros() as u64;
        self.inference_metrics.record(elapsed_us);

        // If debug-alloc is on, any host allocation during inference is a
        // violation of the zero-copy contract (debug builds only).
        #[cfg(feature = "debug-alloc")]
        {
            rave_core::debug_alloc::disable();
            let host_allocs = rave_core::debug_alloc::count();
            debug_assert_eq!(
                host_allocs, 0,
                "VIOLATION: {host_allocs} host allocations during inference"
            );
        }

        let out_w = input.width * meta.scale;
        let out_h = input.height * meta.scale;
        let elem_bytes = input.format.element_bytes();

        Ok(GpuTexture {
            data: output_arc,
            width: out_w,
            height: out_h,
            // Tightly packed rows: pitch = width × bytes-per-element.
            pitch: (out_w as usize) * elem_bytes,
            format: input.format, // Preserve F32 or F16
        })
    }

    /// Orderly teardown: synchronize all GPU work, report ring / inference /
    /// VRAM metrics, then drop the ring and the ORT session.
    /// Idempotent — a second call finds no state and returns Ok.
    async fn shutdown(&self) -> Result<()> {
        let mut guard = self.state.lock().await;
        if let Some(state) = guard.take() {
            info!("Shutting down TensorRT backend");
            // Ensure no kernels are still in flight before freeing buffers.
            self.ctx.sync_all()?;

            // Report ring metrics.
            if let Some(ring) = &state.ring {
                let (reuse, contention, first) = ring.metrics.snapshot();
                info!(reuse, contention, first, "Final ring metrics");
            }

            // Report inference metrics.
            let snap = self.inference_metrics.snapshot();
            info!(
                frames = snap.frames_inferred,
                avg_us = snap.avg_inference_us,
                peak_us = snap.peak_inference_us,
                precision = ?self.precision_policy,
                "Final inference metrics"
            );

            // Report VRAM.
            let (current, peak) = self.ctx.vram_usage();
            info!(
                current_mb = current / (1024 * 1024),
                peak_mb = peak / (1024 * 1024),
                "Final VRAM usage"
            );

            // Explicit drop order: ring buffers before the ORT session.
            drop(state.ring);
            drop(state.session);
            debug!("TensorRT backend shutdown complete");
        }
        Ok(())
    }

    /// Model metadata captured during `initialize()`; `NotInitialized` before then.
    fn metadata(&self) -> Result<&ModelMetadata> {
        self.meta.get().ok_or(EngineError::NotInitialized)
    }
}
1312
1313impl Drop for TensorRtBackend {
1314    fn drop(&mut self) {
1315        if let Ok(mut guard) = self.state.try_lock()
1316            && let Some(state) = guard.take()
1317        {
1318            let _ = self.ctx.sync_all();
1319            drop(state);
1320        }
1321    }
1322}