Skip to main content

rave_tensorrt/
tensorrt.rs

1//! TensorRT inference backend — ORT + TensorRTExecutionProvider + IO Binding.
2//!
3//! # Zero-copy contract
4//!
5//! Input and output tensors are bound to CUDA device pointers via ORT's
6//! IO Binding API.  At no point does frame data touch host memory.
7//!
8//! # Execution provider policy
9//!
10//! **Only TensorRT EP is permitted.**  CPU EP is explicitly disabled.
11//! After session creation, the provider list is validated.  If ORT falls
12//! back to CPU for any graph node, `initialize()` returns an error.
13//!
14//! # CUDA stream ordering
15//!
16//! ORT creates its own internal CUDA stream for TensorRT EP execution.
17//! We cannot inject `GpuContext::inference_stream` because `cudarc::CudaStream`
18//! does not expose its raw `CUstream` handle (`pub(crate)` field).
19//!
20//! Correctness is maintained because `session.run_with_binding()` is
21//! **synchronous** — ORT blocks the calling thread until all GPU kernels
22//! on its internal stream complete.  Therefore:
23//!
24//! 1. Output buffer is fully written when `run_with_binding()` returns.
25//! 2. CUDA global memory coherency guarantees visibility to any subsequent
26//!    reader on any stream after this synchronization point.
27//! 3. No additional inter-stream event is needed.
28//!
29//! # Output ring serialization
30//!
31//! `OutputRing` owns N pre-allocated device buffers.  `acquire()` checks
32//! `Arc::strong_count == 1` before returning a slot, guaranteeing no
33//! concurrent reader.  Ring size must be ≥ `downstream_channel_capacity + 2`.
34
35#[cfg(target_os = "linux")]
36use std::collections::HashSet;
37use std::env;
38#[cfg(target_os = "linux")]
39use std::ffi::{CStr, CString, c_char, c_void};
40#[cfg(target_os = "linux")]
41use std::path::Path;
42use std::path::PathBuf;
43use std::sync::atomic::{AtomicU64, Ordering};
44use std::sync::{Arc, OnceLock};
45#[cfg(target_os = "linux")]
46use std::time::UNIX_EPOCH;
47
48use async_trait::async_trait;
49use cudarc::driver::{DevicePtr, DeviceSlice};
50use tokio::sync::Mutex;
51use tracing::{debug, info, warn};
52
53use ort::session::Session;
54use ort::sys as ort_sys;
55use ort::value::Value as OrtValue;
56
57use rave_core::backend::{ModelMetadata, UpscaleBackend};
58use rave_core::context::GpuContext;
59use rave_core::error::{EngineError, Result};
60use rave_core::types::{GpuTexture, PixelFormat};
61
62use ort::execution_providers::{CUDAExecutionProvider, TensorRTExecutionProvider};
63
/// Create an ORT tensor value that aliases existing CUDA device memory
/// (zero-copy — the frame data is borrowed by ORT, never copied to host).
///
/// # Safety
///
/// - `ptr` must be a valid CUDA device pointer to at least `bytes` bytes.
/// - The memory behind `ptr` must remain alive (and not be reused) for the
///   whole lifetime of the returned [`OrtValue`]; ORT borrows the buffer,
///   it does not take ownership of the data.
/// - `shape` must be a non-empty slice whose element count, times the size
///   of `elem_type`, does not exceed `bytes`.
///
/// NOTE(review): the device id passed to `CreateMemoryInfo` is hard-coded
/// to 0 — confirm this matches the backend's configured `device_id` on
/// multi-GPU systems.
unsafe fn create_tensor_from_device_memory(
    ptr: *mut std::ffi::c_void,
    bytes: usize,
    shape: &[i64],
    elem_type: ort::tensor::TensorElementType,
) -> Result<OrtValue> {
    let api = ort::api();

    // Create MemoryInfo describing the allocation as CUDA device memory so
    // ORT binds the pointer directly instead of staging through the host.
    let mut mem_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut();
    let name = std::ffi::CString::new("Cuda").unwrap();
    let status = unsafe {
        (api.CreateMemoryInfo)(
            name.as_ptr(),
            ort_sys::OrtAllocatorType::OrtArenaAllocator,
            0, // device id
            ort_sys::OrtMemType::OrtMemTypeDefault,
            &mut mem_info_ptr,
        )
    };
    // A non-null status pointer signals failure; release it to avoid leaking
    // the OrtStatus allocation.
    if !status.0.is_null() {
        unsafe { (api.ReleaseStatus)(status.0) };
        return Err(EngineError::Inference(ort::Error::new(
            "Failed to create MemoryInfo",
        )));
    }

    // Create Tensor
    let mut ort_value_ptr: *mut ort_sys::OrtValue = std::ptr::null_mut();
    // SAFETY: mem_info_ptr is a valid OrtMemoryInfo created above; ptr is a
    // caller-supplied device pointer of at least `bytes` bytes; shape is a
    // valid non-empty slice of i64 dimension values; ort_value_ptr is an
    // initialized null pointer used as the output parameter.
    let status = unsafe {
        (api.CreateTensorWithDataAsOrtValue)(
            mem_info_ptr,
            ptr,
            bytes as _,
            shape.as_ptr(),
            shape.len() as _,
            elem_type.into(),
            &mut ort_value_ptr,
        )
    };

    // ORT API contract: CreateTensorWithDataAsOrtValue does NOT take ownership
    // of mem_info — it copies what it needs internally. We release our handle
    // immediately after tensor creation, before any error check, as required.
    unsafe { (api.ReleaseMemoryInfo)(mem_info_ptr) };

    if !status.0.is_null() {
        unsafe { (api.ReleaseStatus)(status.0) };
        return Err(EngineError::Inference(ort::Error::new(
            "Failed to create Tensor",
        )));
    }

    // Wrap in OrtValue.
    // SAFETY: ort_value_ptr is non-null on success (status checked above).
    // `from_ptr` presumably assumes ownership so the value is released on
    // drop — TODO confirm against the `ort` crate's `Value::from_ptr` docs.
    Ok(unsafe {
        ort::value::Value::<ort::value::DynValueTypeMarker>::from_ptr(
            std::ptr::NonNull::new(ort_value_ptr).unwrap(),
            None,
        )
    })
}
130
131// ─── Precision policy ───────────────────────────────────────────────────────
132
/// TensorRT precision policy — controls EP optimization flags.
///
/// Chosen at construction time (see `TensorRtBackend::with_precision`) and
/// applied when the TensorRT execution-provider session is built.
#[derive(Clone, Debug, Default)]
pub enum PrecisionPolicy {
    /// FP32 only — maximum accuracy, baseline performance.
    Fp32,
    /// FP16 mixed precision — 2× throughput on Tensor Cores.  This is the
    /// default policy.
    #[default]
    Fp16,
    /// INT8 quantized with calibration table — 4× throughput.
    /// Requires a pre-generated calibration table path.
    /// NOTE(review): the table's existence/format is not validated here —
    /// confirm it is checked where the session is created.
    Int8 { calibration_table: PathBuf },
}
145
146// ─── Batch config ──────────────────────────────────────────────────────────
147
/// Batch inference configuration.
#[derive(Clone, Debug)]
pub struct BatchConfig {
    /// Maximum batch size for pipelined inference.
    /// Must be ≤ model’s max dynamic batch axis.
    /// NOTE: micro-batching is not implemented yet — `validate_batch_config`
    /// rejects any value greater than 1.
    pub max_batch: usize,
    /// Latency deadline in microseconds: dispatch a partially filled batch
    /// once this much time has elapsed, even if `max_batch` frames were not
    /// collected (bounds the latency added by batching).
    pub latency_deadline_us: u64,
}
158
159impl Default for BatchConfig {
160    fn default() -> Self {
161        Self {
162            max_batch: 1,
163            latency_deadline_us: 8_000, // 8ms — half a 60fps frame
164        }
165    }
166}
167
168/// Validate a [`BatchConfig`], returning an error if `max_batch > 1`.
169///
170/// Micro-batching is not yet implemented; this enforces the `max_batch = 1` requirement.
171pub fn validate_batch_config(cfg: &BatchConfig) -> Result<()> {
172    if cfg.max_batch > 1 {
173        return Err(EngineError::InvariantViolation(
174            "micro-batching is not implemented; max_batch must be 1 (set max_batch=1)".into(),
175        ));
176    }
177    Ok(())
178}
179// ─── Inference metrics ───────────────────────────────────────────────────────
180
/// Atomic counters for inference stage observability.
///
/// All counters are updated with `Ordering::Relaxed` (see `record`); they
/// are monotonic and safe to read concurrently via `snapshot`.
#[derive(Debug)]
pub struct InferenceMetrics {
    /// Total frames inferred.
    pub frames_inferred: AtomicU64,
    /// Cumulative inference time in microseconds (for avg latency).
    pub total_inference_us: AtomicU64,
    /// Peak single-frame inference time in microseconds.
    pub peak_inference_us: AtomicU64,
}
191
192impl InferenceMetrics {
193    /// Create a new zeroed [`InferenceMetrics`].
194    pub const fn new() -> Self {
195        Self {
196            frames_inferred: AtomicU64::new(0),
197            total_inference_us: AtomicU64::new(0),
198            peak_inference_us: AtomicU64::new(0),
199        }
200    }
201
202    /// Record a single frame's inference latency in microseconds.
203    pub fn record(&self, elapsed_us: u64) {
204        self.frames_inferred.fetch_add(1, Ordering::Relaxed);
205        self.total_inference_us
206            .fetch_add(elapsed_us, Ordering::Relaxed);
207        self.peak_inference_us
208            .fetch_max(elapsed_us, Ordering::Relaxed);
209    }
210
211    /// Return a point-in-time snapshot of all inference counters.
212    pub fn snapshot(&self) -> InferenceMetricsSnapshot {
213        let frames = self.frames_inferred.load(Ordering::Relaxed);
214        let total = self.total_inference_us.load(Ordering::Relaxed);
215        let peak = self.peak_inference_us.load(Ordering::Relaxed);
216        InferenceMetricsSnapshot {
217            frames_inferred: frames,
218            avg_inference_us: if frames > 0 { total / frames } else { 0 },
219            peak_inference_us: peak,
220        }
221    }
222}
223
224impl Default for InferenceMetrics {
225    fn default() -> Self {
226        Self::new()
227    }
228}
229
/// Snapshot of inference metrics for reporting.
///
/// Produced by [`InferenceMetrics::snapshot`]; plain values, safe to copy
/// across threads/tasks.
#[derive(Clone, Debug)]
pub struct InferenceMetricsSnapshot {
    /// Total frames inferred.
    pub frames_inferred: u64,
    /// Average inference latency in microseconds (0 if no frames recorded).
    pub avg_inference_us: u64,
    /// Peak single-frame inference latency in microseconds.
    pub peak_inference_us: u64,
}
240
241// ─── Ring metrics ────────────────────────────────────────────────────────────
242
/// A point-in-time snapshot of [`RingMetrics`] counters.
#[derive(Debug, Clone, Copy)]
pub struct RingMetricsSnapshot {
    /// Number of times a slot was acquired and was already free (strong_count == 1).
    pub reuse: u64,
    /// Number of times `acquire()` found a slot still held downstream (strong_count > 1).
    pub contention: u64,
    /// Number of times a slot was used for the very first time (not a reuse).
    pub first_use: u64,
}
253
/// Atomic counters for output ring buffer activity.
///
/// Updated from [`OutputRing::acquire`] with `Ordering::Relaxed`; read via
/// [`RingMetrics::snapshot`].
#[derive(Debug)]
pub struct RingMetrics {
    /// Successful slot reuses (slot was free, strong_count == 1).
    pub slot_reuse_count: AtomicU64,
    /// Times `acquire()` found a slot still held downstream (strong_count > 1).
    pub slot_contention_events: AtomicU64,
    /// Times a slot was acquired but it was the first use (not a reuse).
    pub slot_first_use_count: AtomicU64,
}
264
265impl RingMetrics {
266    /// Create a new zeroed [`RingMetrics`].
267    pub const fn new() -> Self {
268        Self {
269            slot_reuse_count: AtomicU64::new(0),
270            slot_contention_events: AtomicU64::new(0),
271            slot_first_use_count: AtomicU64::new(0),
272        }
273    }
274
275    /// Returns a point-in-time snapshot of all ring counters.
276    pub fn snapshot(&self) -> RingMetricsSnapshot {
277        RingMetricsSnapshot {
278            reuse: self.slot_reuse_count.load(Ordering::Relaxed),
279            contention: self.slot_contention_events.load(Ordering::Relaxed),
280            first_use: self.slot_first_use_count.load(Ordering::Relaxed),
281        }
282    }
283}
284
285impl Default for RingMetrics {
286    fn default() -> Self {
287        Self::new()
288    }
289}
290
291// ─── Output ring buffer ─────────────────────────────────────────────────────
292
/// Fixed-size ring of pre-allocated device buffers for inference output.
///
/// Slots are handed out round-robin by `acquire()`; a slot is only reusable
/// once every downstream holder has dropped its `Arc` (strong_count == 1).
pub struct OutputRing {
    /// The device buffers; each is shared with downstream via `Arc`.
    slots: Vec<Arc<cudarc::driver::CudaSlice<u8>>>,
    /// Index of the next slot `acquire()` will hand out (round-robin).
    cursor: usize,
    /// Size of each slot in bytes (`3 × out_w × out_h × sizeof(f32)`).
    pub slot_bytes: usize,
    /// Input dimensions `(width, height)` used to allocate the current slots.
    pub alloc_dims: (u32, u32),
    /// Whether each slot has been used at least once (for first-use tracking).
    used: Vec<bool>,
    /// Atomic counters tracking reuse, contention, and first-use events.
    pub metrics: RingMetrics,
}
306
impl OutputRing {
    /// Allocate `count` output buffers.
    ///
    /// Each slot holds one RGB planar f32 frame at the upscaled resolution,
    /// i.e. `3 × (in_w·scale) × (in_h·scale) × 4` bytes.
    ///
    /// `min_slots` is the enforced minimum (`downstream_capacity + 2`).
    /// Returns error if `count < min_slots`.
    pub fn new(
        ctx: &GpuContext,
        in_w: u32,
        in_h: u32,
        scale: u32,
        count: usize,
        min_slots: usize,
    ) -> Result<Self> {
        if count < min_slots {
            return Err(EngineError::DimensionMismatch(format!(
                "OutputRing: ring_size ({count}) < required minimum ({min_slots}). \
                 Ring must be ≥ downstream_channel_capacity + 2."
            )));
        }
        // Defensive: even if the caller passes a tiny `min_slots`, double
        // buffering requires at least two slots.
        if count < 2 {
            return Err(EngineError::DimensionMismatch(
                "OutputRing: ring_size must be ≥ 2 for double-buffering".into(),
            ));
        }

        // NOTE(review): `in_w * scale` / `in_h * scale` are u32 multiplies
        // and could overflow for extreme inputs — confirm upstream validation
        // bounds the dimensions.
        let out_w = (in_w * scale) as usize;
        let out_h = (in_h * scale) as usize;
        let slot_bytes = 3 * out_w * out_h * std::mem::size_of::<f32>();

        // Fail fast: the first allocation error aborts construction.
        let slots = (0..count)
            .map(|_| ctx.alloc(slot_bytes).map(Arc::new))
            .collect::<Result<Vec<_>>>()?;

        debug!(count, slot_bytes, out_w, out_h, "Output ring allocated");

        Ok(Self {
            slots,
            cursor: 0,
            slot_bytes,
            alloc_dims: (in_w, in_h),
            used: vec![false; count],
            metrics: RingMetrics::new(),
        })
    }

    /// Acquire the next ring slot for writing.
    ///
    /// # Serialization invariant
    ///
    /// Asserts `Arc::strong_count == 1` before returning.  If downstream
    /// still holds a reference, returns error and increments contention counter.
    pub fn acquire(&mut self) -> Result<Arc<cudarc::driver::CudaSlice<u8>>> {
        let slot = &self.slots[self.cursor];
        let sc = Arc::strong_count(slot);

        if sc != 1 {
            self.metrics
                .slot_contention_events
                .fetch_add(1, Ordering::Relaxed);
            // NOTE(review): contention is surfaced as `BufferTooSmall { have: 0 }`,
            // which reads oddly for a "slot still in use" condition — confirm
            // callers treat this as backpressure rather than a sizing error.
            return Err(EngineError::BufferTooSmall {
                need: self.slot_bytes,
                have: 0,
            });
        }

        // Debug assertion — belt-and-suspenders check.
        debug_assert_eq!(
            Arc::strong_count(slot),
            1,
            "OutputRing: slot {} strong_count must be 1 before reuse, got {}",
            self.cursor,
            sc
        );

        // Metrics only: distinguish first-ever use from a reuse of the slot.
        if self.used[self.cursor] {
            self.metrics
                .slot_reuse_count
                .fetch_add(1, Ordering::Relaxed);
        } else {
            self.used[self.cursor] = true;
            self.metrics
                .slot_first_use_count
                .fetch_add(1, Ordering::Relaxed);
        }

        let cloned = Arc::clone(slot);
        // Round-robin advance to the next slot.
        self.cursor = (self.cursor + 1) % self.slots.len();
        Ok(cloned)
    }

    /// Return `true` if the current slot dimensions differ from `(in_w, in_h)`.
    pub fn needs_realloc(&self, in_w: u32, in_h: u32) -> bool {
        self.alloc_dims != (in_w, in_h)
    }

    /// Reallocate all slots.  All must have `strong_count == 1`.
    pub fn reallocate(&mut self, ctx: &GpuContext, in_w: u32, in_h: u32, scale: u32) -> Result<()> {
        // Refuse to reallocate while any slot is still referenced downstream.
        for (i, slot) in self.slots.iter().enumerate() {
            let sc = Arc::strong_count(slot);
            if sc != 1 {
                return Err(EngineError::DimensionMismatch(format!(
                    "Cannot reallocate ring: slot {} still in use (strong_count={})",
                    i, sc,
                )));
            }
        }

        // Free old slots — decrement VRAM accounting.
        // NOTE(review): accounting is decremented BEFORE the new allocation
        // succeeds; if `ctx.alloc` fails below, the old slots still exist but
        // are no longer accounted — confirm `vram_dec`/`alloc` semantics.
        for _slot in &self.slots {
            ctx.vram_dec(self.slot_bytes);
        }

        let count = self.slots.len();
        let out_w = (in_w * scale) as usize;
        let out_h = (in_h * scale) as usize;
        let slot_bytes = 3 * out_w * out_h * std::mem::size_of::<f32>();

        // Assigning over `self.slots` drops the old buffers once the new
        // vector has been fully allocated.
        self.slots = (0..count)
            .map(|_| ctx.alloc(slot_bytes).map(Arc::new))
            .collect::<Result<Vec<_>>>()?;
        self.cursor = 0;
        self.slot_bytes = slot_bytes;
        self.alloc_dims = (in_w, in_h);
        self.used = vec![false; count];

        debug!(count, slot_bytes, out_w, out_h, "Output ring reallocated");
        Ok(())
    }

    /// Total number of slots.
    pub fn len(&self) -> usize {
        self.slots.len()
    }

    /// Return `true` if the ring has no slots (should not happen in normal operation).
    pub fn is_empty(&self) -> bool {
        self.slots.is_empty()
    }
}
446
447// ─── Inference state ─────────────────────────────────────────────────────────
448
/// Mutable per-session state, guarded by `TensorRtBackend::state`.
struct InferenceState {
    // Live ORT session (TensorRT/CUDA EP).
    session: Session,
    // Output ring buffer; `None` until allocated — presumably lazily, once
    // the first frame's dimensions are known. NOTE(review): confirm where
    // this is populated (not visible in this chunk).
    ring: Option<OutputRing>,
}
453
454/// Resolve ORT tensor element type from our PixelFormat.
455fn ort_element_type(format: PixelFormat) -> ort::tensor::TensorElementType {
456    match format {
457        PixelFormat::RgbPlanarF16 => ort::tensor::TensorElementType::Float16,
458        _ => ort::tensor::TensorElementType::Float32,
459    }
460}
461
462// ─── Backend ─────────────────────────────────────────────────────────────────
463
/// TensorRT/CUDA ORT inference backend.
///
/// Implements [`UpscaleBackend`](rave_core::backend::UpscaleBackend) using
/// ONNX Runtime with a TensorRT or CUDA execution provider.  Output buffers
/// are managed via a fixed-size [`OutputRing`] to avoid per-frame allocation.
pub struct TensorRtBackend {
    // Path to the ONNX model file.
    model_path: PathBuf,
    // Shared GPU context (device handle, allocation, VRAM accounting).
    ctx: Arc<GpuContext>,
    // CUDA device ordinal for the execution provider.
    device_id: i32,
    // Number of output ring slots to pre-allocate.
    ring_size: usize,
    // Enforced lower bound for `ring_size` (`downstream_capacity + 2`).
    min_ring_slots: usize,
    // Model metadata, set at most once (presumably during initialize()).
    meta: OnceLock<ModelMetadata>,
    // Name of the EP actually selected, set at most once during init.
    selected_provider: OnceLock<String>,
    // Session + output ring; `None` until initialization succeeds.
    state: Mutex<Option<InferenceState>>,
    /// Atomic inference latency and frame count metrics.
    pub inference_metrics: InferenceMetrics,
    /// Precision policy used when building the TensorRT EP session.
    pub precision_policy: PrecisionPolicy,
    /// Batch configuration (must have `max_batch = 1` — batching not yet implemented).
    pub batch_config: BatchConfig,
}
485
/// Execution-provider selection mode, parsed from the `RAVE_ORT_TENSORRT`
/// environment variable (see `ort_ep_mode`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum OrtEpMode {
    // Default when the variable is unset or unrecognized.
    Auto,
    // Values "1" / "on" / "true" / "trt" / "trt-only".
    TensorRtOnly,
    // Values "0" / "off" / "false" / "cuda" / "cuda-only".
    CudaOnly,
}
492
// Minimal libdl bindings for loading ORT provider shared objects directly.
// NOTE(review): the call site is outside this chunk — presumably providers
// are dlopen'ed with RTLD_GLOBAL so their symbols are visible to
// onnxruntime's own loader; confirm at the point of use.
#[cfg(target_os = "linux")]
unsafe extern "C" {
    fn dlopen(filename: *const c_char, flags: i32) -> *mut c_void;
    fn dlerror() -> *const c_char;
}

// dlfcn flag (glibc): resolve all undefined symbols at load time.
#[cfg(target_os = "linux")]
const RTLD_NOW: i32 = 2;
// dlfcn flag (glibc): 0 = symbols stay local to the loaded object (default).
#[cfg(all(target_os = "linux", test))]
const RTLD_LOCAL: i32 = 0;
// dlfcn flag (glibc): make the object's symbols available to later loads.
#[cfg(target_os = "linux")]
const RTLD_GLOBAL: i32 = 0x100;
505
/// Which ORT execution-provider shared library is being resolved.
#[derive(Clone, Copy)]
enum OrtProviderKind {
    // `libonnxruntime_providers_cuda.so`
    Cuda,
    // `libonnxruntime_providers_tensorrt.so`
    TensorRt,
}

/// A candidate provider directory plus a short label describing where the
/// candidate came from (e.g. `"ORT_DYLIB_PATH"`, `"exe_dir"`,
/// `"ort_cache_newest"`), used in logs and error messages.
#[cfg(target_os = "linux")]
type ProviderCandidate = (PathBuf, &'static str);
514
515impl OrtProviderKind {
516    #[cfg(target_os = "linux")]
517    fn soname(self) -> &'static str {
518        match self {
519            OrtProviderKind::Cuda => "libonnxruntime_providers_cuda.so",
520            OrtProviderKind::TensorRt => "libonnxruntime_providers_tensorrt.so",
521        }
522    }
523
524    #[cfg(target_os = "linux")]
525    fn label(self) -> &'static str {
526        match self {
527            OrtProviderKind::Cuda => "providers_cuda",
528            OrtProviderKind::TensorRt => "providers_tensorrt",
529        }
530    }
531}
532
533impl TensorRtBackend {
    /// Create a new backend instance.
    ///
    /// Delegates to [`TensorRtBackend::with_precision`] with the default
    /// precision policy (FP16) and default batch config (`max_batch = 1`).
    ///
    /// # Parameters
    ///
    /// - `ring_size`: number of output ring slots to pre-allocate.
    /// - `downstream_capacity`: the bounded channel capacity between inference
    ///   and the encoder.  Ring size is validated ≥ `downstream_capacity + 2`.
    ///
    /// # Panics
    ///
    /// Panics (inside `with_precision`) if
    /// `ring_size < downstream_capacity + 2`.
    pub fn new(
        model_path: PathBuf,
        ctx: Arc<GpuContext>,
        device_id: i32,
        ring_size: usize,
        downstream_capacity: usize,
    ) -> Self {
        Self::with_precision(
            model_path,
            ctx,
            device_id,
            ring_size,
            downstream_capacity,
            PrecisionPolicy::default(),
            BatchConfig::default(),
        )
    }
558
    /// Create with explicit precision policy and batch config.
    ///
    /// The minimum ring size (`downstream_capacity + 2`) guarantees the ring
    /// can cover every frame held downstream plus a double-buffer margin.
    ///
    /// # Panics
    ///
    /// Panics if `ring_size < downstream_capacity + 2`.
    pub fn with_precision(
        model_path: PathBuf,
        ctx: Arc<GpuContext>,
        device_id: i32,
        ring_size: usize,
        downstream_capacity: usize,
        precision_policy: PrecisionPolicy,
        batch_config: BatchConfig,
    ) -> Self {
        let min_ring_slots = downstream_capacity + 2;
        assert!(
            ring_size >= min_ring_slots,
            "ring_size ({ring_size}) must be ≥ downstream_capacity + 2 ({min_ring_slots})"
        );
        Self {
            model_path,
            ctx,
            device_id,
            ring_size,
            min_ring_slots,
            meta: OnceLock::new(),
            selected_provider: OnceLock::new(),
            state: Mutex::new(None),
            inference_metrics: InferenceMetrics::new(),
            precision_policy,
            batch_config,
        }
    }
588
    /// Create a CUDA device `MemoryInfo` (device 0) for IO binding.
    ///
    /// NOTE(review): builds a fresh `MemoryInfo` on every call — nothing is
    /// cached here, despite what an earlier comment claimed.  Also, device id
    /// is hard-coded to 0 rather than using `self.device_id` — confirm for
    /// multi-GPU setups.
    fn mem_info(&self) -> Result<ort::memory::MemoryInfo> {
        ort::memory::MemoryInfo::new(
            ort::memory::AllocationDevice::CUDA,
            0,
            ort::memory::AllocatorType::Device,
            ort::memory::MemoryType::Default,
        )
        .map_err(|e| EngineError::ModelMetadata(format!("MemoryInfo: {e}")))
    }
600
601    /// Access ring metrics (if initialized).
602    pub async fn ring_metrics(&self) -> Option<RingMetricsSnapshot> {
603        let guard = self.state.lock().await;
604        guard
605            .as_ref()
606            .and_then(|s| s.ring.as_ref())
607            .map(|r| r.metrics.snapshot())
608    }
609
610    /// Active ORT execution provider selected during initialization.
611    pub fn selected_provider(&self) -> Option<&str> {
612        self.selected_provider.get().map(String::as_str)
613    }
614
615    fn infer_scale_from_name(name: &str) -> Option<u32> {
616        let lower = name.to_ascii_lowercase();
617        let bytes = lower.as_bytes();
618        for i in 0..bytes.len() {
619            if bytes[i] != b'x' {
620                continue;
621            }
622            if i > 0 && bytes[i - 1].is_ascii_digit() {
623                let d = (bytes[i - 1] - b'0') as u32;
624                if d >= 2 {
625                    return Some(d);
626                }
627            }
628            if i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
629                let d = (bytes[i + 1] - b'0') as u32;
630                if d >= 2 {
631                    return Some(d);
632                }
633            }
634        }
635        None
636    }
637
    /// Derive [`ModelMetadata`] (tensor names, channel count, input bounds,
    /// and upscale factor) from a freshly created ORT session.
    fn extract_metadata(session: &Session) -> Result<ModelMetadata> {
        let inputs = session.inputs();
        let outputs = session.outputs();

        if inputs.is_empty() || outputs.is_empty() {
            return Err(EngineError::ModelMetadata(
                "Model must have at least one input and one output tensor".into(),
            ));
        }

        // Only the first input/output tensor is inspected — this backend
        // assumes single-input / single-output super-resolution graphs.
        let input_info = &inputs[0];
        let output_info = &outputs[0];
        let input_name = input_info.name().to_string();
        let output_name = output_info.name().to_string();

        let input_dims = match input_info.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.clone(),
            other => {
                return Err(EngineError::ModelMetadata(format!(
                    "Expected tensor input, got {:?}",
                    other
                )));
            }
        };

        let output_dims = match output_info.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.clone(),
            other => {
                return Err(EngineError::ModelMetadata(format!(
                    "Expected tensor output, got {:?}",
                    other
                )));
            }
        };

        if input_dims.len() != 4 || output_dims.len() != 4 {
            return Err(EngineError::ModelMetadata(format!(
                "Expected 4D tensors (NCHW), got input={}D output={}D",
                input_dims.len(),
                output_dims.len()
            )));
        }

        // NCHW: axis 1 is the channel count.
        let input_channels = input_dims[1] as u32;

        // Spatial axes. Non-positive values are treated as dynamic below.
        // NOTE(review): presumably ORT reports dynamic axes as -1 — confirm
        // against the `ort` crate's shape conventions.
        let ih = input_dims[2];
        let iw = input_dims[3];
        let oh = output_dims[2];
        let ow = output_dims[3];

        let name = session
            .metadata()
            .map(|m| m.name().unwrap_or("unknown".to_string()))
            .unwrap_or_else(|_| "unknown".to_string());

        // Scale resolution order: ratio of static spatial dims, then a hint
        // parsed from the model name (e.g. "x4"), then a default of 2.
        let scale = if ih > 0 && oh > 0 && iw > 0 && ow > 0 {
            // NOTE(review): integer division — assumes `oh` is an exact
            // multiple of `ih` and that width scales by the same factor.
            (oh / ih) as u32
        } else if let Some(inferred) = Self::infer_scale_from_name(&name) {
            warn!(
                model_name = %name,
                inferred_scale = inferred,
                "Dynamic spatial axes — inferring upscale scale from model name"
            );
            inferred
        } else {
            warn!(
                model_name = %name,
                "Dynamic spatial axes — unable to infer scale from metadata; defaulting to scale=2"
            );
            2
        };

        // Static axes pin both bounds to the fixed size; dynamic axes accept
        // anything from 1 up to u32::MAX.
        let min_input_hw = (
            if ih > 0 { ih as u32 } else { 1 },
            if iw > 0 { iw as u32 } else { 1 },
        );

        let max_input_hw = (
            if ih > 0 { ih as u32 } else { u32::MAX },
            if iw > 0 { iw as u32 } else { u32::MAX },
        );

        Ok(ModelMetadata {
            name,
            scale,
            input_name,
            output_name,
            input_channels,
            min_input_hw,
            max_input_hw,
        })
    }
738
739    fn ort_ep_mode() -> OrtEpMode {
740        match env::var("RAVE_ORT_TENSORRT")
741            .unwrap_or_else(|_| "auto".to_string())
742            .to_lowercase()
743            .as_str()
744        {
745            "0" | "off" | "false" | "cuda" | "cuda-only" => OrtEpMode::CudaOnly,
746            "1" | "on" | "true" | "trt" | "trt-only" => OrtEpMode::TensorRtOnly,
747            _ => OrtEpMode::Auto,
748        }
749    }
750
751    #[cfg(target_os = "linux")]
752    fn is_wsl2() -> bool {
753        std::fs::read_to_string("/proc/sys/kernel/osrelease")
754            .map(|s| s.to_ascii_lowercase().contains("microsoft"))
755            .unwrap_or(false)
756    }
757
758    #[cfg(not(target_os = "linux"))]
759    fn is_wsl2() -> bool {
760        false
761    }
762
763    #[cfg(target_os = "linux")]
764    fn ort_provider_cache_dirs_newest_first() -> Vec<PathBuf> {
765        let mut dirs = Vec::<(u128, PathBuf)>::new();
766        let Some(home) = env::var_os("HOME") else {
767            return Vec::new();
768        };
769        let base = PathBuf::from(home).join(".cache/ort.pyke.io/dfbin");
770        let Ok(triples) = std::fs::read_dir(base) else {
771            return Vec::new();
772        };
773
774        for triple in triples.flatten() {
775            let triple_path = triple.path();
776            if !triple_path.is_dir() {
777                continue;
778            }
779            let Ok(hashes) = std::fs::read_dir(triple_path) else {
780                continue;
781            };
782            for hash in hashes.flatten() {
783                let path = hash.path();
784                if !path.is_dir() {
785                    continue;
786                }
787                let modified = std::fs::metadata(&path)
788                    .and_then(|m| m.modified())
789                    .ok()
790                    .and_then(|t| t.duration_since(UNIX_EPOCH).ok())
791                    .map(|d| d.as_nanos())
792                    .unwrap_or(0);
793                dirs.push((modified, path));
794            }
795        }
796
797        dirs.sort_by(|a, b| b.cmp(a));
798        dirs.into_iter().map(|(_, path)| path).collect()
799    }
800
801    #[cfg(target_os = "linux")]
802    fn provider_dir_candidates(
803        ort_dylib_path: Option<PathBuf>,
804        ort_lib_location: Option<PathBuf>,
805        exe_dir: Option<PathBuf>,
806        cache_dirs: Vec<PathBuf>,
807        on_wsl: bool,
808    ) -> Vec<ProviderCandidate> {
809        let mut dirs = Vec::<ProviderCandidate>::new();
810        if let Some(dir) = ort_dylib_path {
811            dirs.push((dir, "ORT_DYLIB_PATH"));
812        }
813        if let Some(dir) = ort_lib_location {
814            dirs.push((dir, "ORT_LIB_LOCATION"));
815        }
816        if on_wsl {
817            for dir in cache_dirs {
818                dirs.push((dir, "ort_cache_newest"));
819            }
820            if let Some(dir) = exe_dir {
821                dirs.push((dir.clone(), "exe_dir"));
822                dirs.push((dir.join("deps"), "exe_dir/deps"));
823            }
824        } else {
825            if let Some(dir) = exe_dir {
826                dirs.push((dir.clone(), "exe_dir"));
827                dirs.push((dir.join("deps"), "exe_dir/deps"));
828            }
829            for dir in cache_dirs {
830                dirs.push((dir, "ort_cache_newest"));
831            }
832        }
833
834        let mut uniq = HashSet::<PathBuf>::new();
835        dirs.retain(|(p, _)| uniq.insert(p.clone()));
836        dirs
837    }
838
839    #[cfg(target_os = "linux")]
840    fn ort_provider_search_inputs() -> (
841        Option<PathBuf>,
842        Option<PathBuf>,
843        Option<PathBuf>,
844        Vec<PathBuf>,
845        bool,
846    ) {
847        let ort_dylib_path = env::var_os("ORT_DYLIB_PATH").map(PathBuf::from);
848        let ort_lib_location = env::var_os("ORT_LIB_LOCATION").map(PathBuf::from);
849        let exe_dir = env::current_exe()
850            .ok()
851            .and_then(|exe| exe.parent().map(|dir| dir.to_path_buf()));
852        let cache_dirs = Self::ort_provider_cache_dirs_newest_first();
853        let on_wsl = Self::is_wsl2();
854        (
855            ort_dylib_path,
856            ort_lib_location,
857            exe_dir,
858            cache_dirs,
859            on_wsl,
860        )
861    }
862
863    #[cfg(all(target_os = "linux", test))]
864    fn ort_provider_search_dirs() -> Vec<ProviderCandidate> {
865        let (ort_dylib_path, ort_lib_location, exe_dir, cache_dirs, on_wsl) =
866            Self::ort_provider_search_inputs();
867        Self::provider_dir_candidates(
868            ort_dylib_path,
869            ort_lib_location,
870            exe_dir,
871            cache_dirs,
872            on_wsl,
873        )
874    }
875
876    #[cfg(target_os = "linux")]
877    fn render_provider_candidates(candidates: &[ProviderCandidate]) -> String {
878        if candidates.is_empty() {
879            return "none".into();
880        }
881        candidates
882            .iter()
883            .enumerate()
884            .map(|(idx, (dir, source))| format!("{}:{source}:{}", idx + 1, dir.display()))
885            .collect::<Vec<_>>()
886            .join(", ")
887    }
888
889    #[cfg(all(target_os = "linux", test))]
890    fn ort_provider_candidates(lib_name: &str) -> Vec<PathBuf> {
891        Self::ort_provider_search_dirs()
892            .into_iter()
893            .map(|(dir, _)| dir.join(lib_name))
894            .collect()
895    }
896
897    #[cfg(target_os = "linux")]
898    fn configure_ort_loader_path(dir: &Path) {
899        // SAFETY: Process env mutation is done at startup during backend init.
900        unsafe { env::set_var("ORT_DYLIB_PATH", dir) };
901        // SAFETY: Process env mutation is done at startup during backend init.
902        unsafe { env::set_var("ORT_LIB_LOCATION", dir) };
903    }
904
    /// Scan `candidates` in precedence order and return the first directory
    /// that contains BOTH `libonnxruntime_providers_shared.so` and the
    /// provider shared object for `kind`, together with the candidate's
    /// source label (e.g. "ORT_DYLIB_PATH").
    ///
    /// `command`, `mode`, and `ep_mode` are only interpolated into the error
    /// message to give operators diagnostic context.
    ///
    /// # Errors
    /// `EngineError::ModelMetadata` when no candidate directory holds the
    /// required pair; the message lists the overrides and every directory
    /// checked, in order.
    #[cfg(target_os = "linux")]
    fn resolve_provider_dir(
        kind: OrtProviderKind,
        candidates: &[ProviderCandidate],
        ort_dylib_path: Option<&Path>,
        ort_lib_location: Option<&Path>,
        command: &str,
        mode: &str,
        ep_mode: &str,
    ) -> Result<(PathBuf, &'static str)> {
        let provider = kind.soname();
        // First candidate offering the full pair wins; precedence is encoded
        // by the order of `candidates`.
        for (dir, source) in candidates {
            let shared = dir.join("libonnxruntime_providers_shared.so");
            let provider_path = dir.join(provider);
            if shared.is_file() && provider_path.is_file() {
                info!(
                    path = %shared.display(),
                    source = *source,
                    "ORT providers_shared resolved"
                );
                info!(
                    provider = kind.label(),
                    path = %provider_path.display(),
                    source = *source,
                    "ORT provider resolved"
                );
                return Ok((dir.clone(), *source));
            }
        }

        // "<unset>" keeps the override summary readable when an env var is absent.
        let render_path = |v: Option<&Path>| -> String {
            v.map(|p| p.display().to_string())
                .unwrap_or_else(|| "<unset>".to_string())
        };
        let checked = Self::render_provider_candidates(candidates);
        Err(EngineError::ModelMetadata(format!(
            "ORT provider directory resolution failed: command={command} mode={mode} ep_mode={ep_mode} provider={} \
reason=no candidate directory contained required pair \
(libonnxruntime_providers_shared.so + {}). \
overrides={{ORT_DYLIB_PATH={}, ORT_LIB_LOCATION={}}}. \
candidates_checked=[{}]. \
Set ORT_DYLIB_PATH or ORT_LIB_LOCATION to a valid provider directory and re-run scripts/test_ort_provider_load.sh.",
            kind.label(),
            provider,
            render_path(ort_dylib_path),
            render_path(ort_lib_location),
            checked
        )))
    }
954
955    #[cfg(target_os = "linux")]
956    fn resolve_ort_provider_dir(kind: OrtProviderKind) -> Result<(PathBuf, &'static str)> {
957        let (ort_dylib_path, ort_lib_location, exe_dir, cache_dirs, on_wsl) =
958            Self::ort_provider_search_inputs();
959        let candidates = Self::provider_dir_candidates(
960            ort_dylib_path.clone(),
961            ort_lib_location.clone(),
962            exe_dir,
963            cache_dirs,
964            on_wsl,
965        );
966        let (dir, source) = Self::resolve_provider_dir(
967            kind,
968            &candidates,
969            ort_dylib_path.as_deref(),
970            ort_lib_location.as_deref(),
971            "backend_initialize",
972            "provider_preload",
973            kind.label(),
974        )?;
975        Self::configure_ort_loader_path(&dir);
976        Ok((dir, source))
977    }
978
979    #[cfg(target_os = "linux")]
980    fn dlopen_path(path: &Path, flags: i32) -> Result<()> {
981        let cpath = CString::new(path.to_string_lossy().as_bytes())
982            .map_err(|_| EngineError::ModelMetadata("Invalid dlopen path".into()))?;
983        // SAFETY: `dlopen` expects a valid, NUL-terminated C string and flags.
984        let handle = unsafe { dlopen(cpath.as_ptr(), flags) };
985        if handle.is_null() {
986            // SAFETY: `dlerror` returns a thread-local pointer or null.
987            let err = unsafe {
988                let p = dlerror();
989                if p.is_null() {
990                    "unknown dlopen error".to_string()
991                } else {
992                    CStr::from_ptr(p).to_string_lossy().to_string()
993                }
994            };
995            return Err(EngineError::ModelMetadata(format!(
996                "dlopen failed for {}: {err}",
997                path.display()
998            )));
999        }
1000        Ok(())
1001    }
1002
1003    #[cfg(target_os = "linux")]
1004    fn preload_cuda_runtime_libs() {
1005        // Prefer explicit toolkit paths to avoid reliance on unrelated app installs.
1006        let mut roots = Vec::<PathBuf>::new();
1007        if let Some(dir) = env::var_os("RAVE_CUDA_RUNTIME_DIR") {
1008            roots.push(PathBuf::from(dir));
1009        }
1010        roots.push(PathBuf::from("/usr/local/cuda-12/targets/x86_64-linux/lib"));
1011        roots.push(PathBuf::from("/usr/local/cuda/lib64"));
1012
1013        let libs = [
1014            "libcudart.so.12",
1015            "libcublasLt.so.12",
1016            "libcublas.so.12",
1017            "libnvrtc.so.12",
1018            "libcurand.so.10",
1019            "libcufft.so.11",
1020        ];
1021
1022        for root in roots {
1023            if !root.is_dir() {
1024                continue;
1025            }
1026            let mut loaded = 0usize;
1027            for lib in libs {
1028                let p = root.join(lib);
1029                if !p.is_file() {
1030                    continue;
1031                }
1032                if Self::dlopen_path(&p, RTLD_NOW | RTLD_GLOBAL).is_ok() {
1033                    loaded += 1;
1034                }
1035            }
1036            if loaded > 0 {
1037                info!(root = %root.display(), loaded, "Preloaded CUDA runtime libraries for ORT");
1038                break;
1039            }
1040        }
1041    }
1042
1043    /// When ORT is linked statically, TensorRT provider expects `Provider_GetHost`
1044    /// from `libonnxruntime_providers_shared.so` to already be present in the process.
1045    /// Preloading this bridge with `RTLD_GLOBAL` satisfies that symbol before TRT EP load.
1046    fn preload_ort_provider_pair(kind: OrtProviderKind) -> Result<()> {
1047        #[cfg(target_os = "linux")]
1048        {
1049            let (dir, source) = Self::resolve_ort_provider_dir(kind)?;
1050            let shared = dir.join("libonnxruntime_providers_shared.so");
1051            let provider = dir.join(kind.soname());
1052
1053            Self::dlopen_path(&shared, RTLD_NOW | RTLD_GLOBAL).map_err(|e| {
1054                EngineError::ModelMetadata(format!(
1055                    "Failed loading providers_shared from {} ({e}). \
1056Ensure ORT_DYLIB_PATH/ORT_LIB_LOCATION points to a valid ORT cache dir.",
1057                    shared.display()
1058                ))
1059            })?;
1060
1061            info!(
1062                path = %provider.display(),
1063                provider = kind.label(),
1064                "Skipping explicit provider dlopen; relying on startup LD_LIBRARY_PATH + ORT registration"
1065            );
1066
1067            info!(
1068                source,
1069                dir = %dir.display(),
1070                provider = kind.label(),
1071                path = %provider.display(),
1072                "ORT provider pair prepared (providers_shared preloaded, provider path configured)"
1073            );
1074            Ok(())
1075        }
1076
1077        #[cfg(not(target_os = "linux"))]
1078        {
1079            let _ = kind;
1080            Ok(())
1081        }
1082    }
1083
1084    #[cfg(all(target_os = "linux", test))]
1085    fn preload_ort_provider_bridge() -> Result<()> {
1086        for path in Self::ort_provider_candidates("libonnxruntime_providers_shared.so") {
1087            if path.is_file() {
1088                return Self::dlopen_path(&path, RTLD_NOW | RTLD_GLOBAL);
1089            }
1090        }
1091        Err(EngineError::ModelMetadata(
1092            "Could not preload libonnxruntime_providers_shared.so from known search paths".into(),
1093        ))
1094    }
1095
    // No-op stub for non-Linux test builds: there is no provider bridge to preload.
    #[cfg(all(not(target_os = "linux"), test))]
    #[allow(dead_code)]
    fn preload_ort_provider_bridge() -> Result<()> {
        Ok(())
    }
1101
    // No-op stub: CUDA runtime preloading only applies to Linux hosts.
    #[cfg(not(target_os = "linux"))]
    fn preload_cuda_runtime_libs() {}
1104
1105    fn build_trt_session(&self) -> Result<Session> {
1106        Self::preload_cuda_runtime_libs();
1107        Self::preload_ort_provider_pair(OrtProviderKind::TensorRt)?;
1108
1109        let mut trt_ep = TensorRTExecutionProvider::default()
1110            .with_device_id(self.device_id)
1111            .with_engine_cache(true)
1112            .with_engine_cache_path(
1113                self.model_path
1114                    .parent()
1115                    .unwrap_or(&self.model_path)
1116                    .join("trt_cache")
1117                    .to_string_lossy()
1118                    .to_string(),
1119            );
1120
1121        match &self.precision_policy {
1122            PrecisionPolicy::Fp32 => {
1123                info!("TRT precision: FP32 (no mixed precision)");
1124            }
1125            PrecisionPolicy::Fp16 => {
1126                trt_ep = trt_ep.with_fp16(true);
1127                info!("TRT precision: FP16 mixed precision");
1128            }
1129            PrecisionPolicy::Int8 { calibration_table } => {
1130                trt_ep = trt_ep.with_fp16(true).with_int8(true);
1131                info!(
1132                    table = %calibration_table.display(),
1133                    "TRT precision: INT8 with calibration table"
1134                );
1135            }
1136        }
1137
1138        Session::builder()?
1139            .with_execution_providers([trt_ep.build().error_on_failure()])?
1140            .with_intra_threads(1)?
1141            .commit_from_file(&self.model_path)
1142            .map_err(Into::into)
1143    }
1144
1145    fn build_cuda_session(&self) -> Result<Session> {
1146        Self::preload_cuda_runtime_libs();
1147        Self::preload_ort_provider_pair(OrtProviderKind::Cuda)?;
1148        let cuda_ep = CUDAExecutionProvider::default().with_device_id(self.device_id);
1149        Session::builder()?
1150            .with_execution_providers([cuda_ep.build().error_on_failure()])?
1151            .with_intra_threads(1)?
1152            .commit_from_file(&self.model_path)
1153            .map_err(Into::into)
1154    }
1155
1156    fn pointer_identity_mismatch(
1157        input_ptr: u64,
1158        texture_ptr: u64,
1159        output_ptr: u64,
1160        ring_slot_ptr: u64,
1161    ) -> Option<String> {
1162        if input_ptr != texture_ptr {
1163            return Some(format!(
1164                "POINTER MISMATCH: IO-bound input (0x{input_ptr:016x}) != GpuTexture (0x{texture_ptr:016x})"
1165            ));
1166        }
1167        if output_ptr != ring_slot_ptr {
1168            return Some(format!(
1169                "POINTER MISMATCH: IO-bound output (0x{output_ptr:016x}) != ring slot (0x{ring_slot_ptr:016x})"
1170            ));
1171        }
1172        None
1173    }
1174
    /// Verify that IO-bound device pointers match the source GpuTexture
    /// and OutputRing slot pointers exactly (pointer identity check).
    ///
    /// Called by `run_io_bound` to audit that ORT IO Binding uses our
    /// device pointers without any host staging or reallocation.
    ///
    /// On mismatch: debug builds trip a `debug_assert!`; otherwise, strict
    /// audit mode (feature `audit-no-host-copies`) returns
    /// `EngineError::InvariantViolation`, and non-strict builds record the
    /// violation and continue.
    fn verify_pointer_identity(
        input_ptr: u64,
        output_ptr: u64,
        input_texture: &GpuTexture,
        ring_slot_ptr: u64,
    ) -> Result<()> {
        let texture_ptr = input_texture.device_ptr();
        // Always log all four pointers so audits can be traced offline.
        debug!(
            input_ptr = format!("0x{:016x}", input_ptr),
            texture_ptr = format!("0x{:016x}", texture_ptr),
            output_ptr = format!("0x{:016x}", output_ptr),
            ring_slot_ptr = format!("0x{:016x}", ring_slot_ptr),
            "IO-binding pointer identity audit"
        );

        if let Some(message) =
            Self::pointer_identity_mismatch(input_ptr, texture_ptr, output_ptr, ring_slot_ptr)
        {
            // Fail loudly in debug builds.
            debug_assert!(false, "{message}");
            // Strict audit mode promotes the mismatch to a hard error.
            #[cfg(feature = "audit-no-host-copies")]
            if rave_core::host_copy_audit::is_strict_mode() {
                return Err(EngineError::InvariantViolation(message));
            }
            // Otherwise record the violation and keep processing.
            rave_core::host_copy_violation!("inference", "{message}");
        }
        Ok(())
    }
1207
    /// Execute one inference pass with both tensors bound to raw CUDA device
    /// pointers (zero-copy IO binding), then run the session synchronously.
    ///
    /// `output_ptr`/`output_bytes` describe the pre-allocated ring slot the
    /// model output is written into.
    ///
    /// # Errors
    /// `EngineError::BufferTooSmall` when `output_bytes` cannot hold the
    /// `1 x C x (H*scale) x (W*scale)` output tensor.
    fn run_io_bound(
        session: &mut Session,
        meta: &ModelMetadata,
        input: &GpuTexture,
        output_ptr: u64,
        output_bytes: usize,
        _ctx: &GpuContext,
        _mem_info: &ort::memory::MemoryInfo, // Unused argument if we recreate it, or use it if passed.
    ) -> Result<()> {
        // NCHW dims: batch fixed at 1; output spatially scaled by the model factor.
        let in_w = input.width as i64;
        let in_h = input.height as i64;
        let out_w = in_w * meta.scale as i64;
        let out_h = in_h * meta.scale as i64;

        let input_shape: Vec<i64> = vec![1, meta.input_channels as i64, in_h, in_w];
        let output_shape: Vec<i64> = vec![1, meta.input_channels as i64, out_h, out_w];

        // Resolve element type from pixel format (F16 or F32).
        let elem_type = ort_element_type(input.format);
        let elem_bytes = input.format.element_bytes();
        let input_bytes = input.data.len();

        // Validate the ring slot is large enough before binding anything.
        let expected = (output_shape.iter().product::<i64>() as usize) * elem_bytes;
        if output_bytes < expected {
            return Err(EngineError::BufferTooSmall {
                need: expected,
                have: output_bytes,
            });
        }

        let input_ptr = input.device_ptr();

        // Phase 7: Pointer identity audit — verify device pointers match.
        // NOTE(review): `output_ptr` is passed here as both the bound output
        // pointer and the expected ring-slot pointer, so the output half of
        // the audit is trivially satisfied at this call site — confirm whether
        // the caller should plumb the ring slot pointer through separately.
        Self::verify_pointer_identity(input_ptr, output_ptr, input, output_ptr)?;

        let mut binding = session.create_binding()?;

        // Create fresh MemoryInfo for this opt (cheap, avoids Sync issues with cache)
        // (Removed: using FFI helper instead)

        // SAFETY — ORT IO Binding with raw device pointers:
        unsafe {
            // Create Input Tensor from raw device memory (zero-copy)
            let input_tensor = create_tensor_from_device_memory(
                input_ptr as *mut _,
                input_bytes,
                &input_shape,
                elem_type,
            )?;

            binding.bind_input(&meta.input_name, &input_tensor)?;

            // Create Output Tensor from raw device memory (zero-copy)
            let output_tensor = create_tensor_from_device_memory(
                output_ptr as *mut _,
                output_bytes,
                &output_shape,
                elem_type,
            )?;

            binding.bind_output(&meta.output_name, output_tensor)?;
        }

        // Synchronous execution — ORT blocks until all TensorRT kernels complete.
        session.run_binding(&binding)?;

        Ok(())
    }
1276}
1277
#[cfg(test)]
mod pointer_audit_tests {
    use super::TensorRtBackend;

    #[test]
    fn pointer_identity_audit_accepts_matching_addresses() {
        // Matching input and output pointer pairs must produce no diagnostic.
        assert!(
            TensorRtBackend::pointer_identity_mismatch(0x1000, 0x1000, 0x2000, 0x2000).is_none()
        );
    }

    #[test]
    fn pointer_identity_audit_reports_mismatch() {
        // A diverging input pointer must be named in the diagnostic message.
        let message = TensorRtBackend::pointer_identity_mismatch(0x1000, 0x1111, 0x2000, 0x2000)
            .expect("mismatch should be reported");
        assert!(message.contains("IO-bound input"));
    }
}
1295
#[cfg(all(test, target_os = "linux"))]
mod provider_resolution_tests {
    //! Tests for provider directory candidate ordering, resolution precedence,
    //! and the diagnostics emitted when no candidate contains the provider pair.
    use super::{OrtProviderKind, TensorRtBackend};
    use std::env;
    use std::fs::{self, File};
    use std::path::{Path, PathBuf};
    use std::process;
    use std::sync::atomic::{AtomicU64, Ordering};

    // Monotonic counter so concurrently-running tests get distinct temp dirs.
    static NEXT_TMP_ID: AtomicU64 = AtomicU64::new(0);

    /// Self-cleaning temp directory scoped to a single test.
    struct TempDir {
        path: PathBuf,
    }

    impl TempDir {
        /// Create a unique directory under the system temp dir; uniqueness
        /// comes from the prefix, the process id, and the atomic counter.
        fn new(prefix: &str) -> Self {
            let unique = format!(
                "rave_provider_resolution_{prefix}_{}_{}",
                process::id(),
                NEXT_TMP_ID.fetch_add(1, Ordering::Relaxed)
            );
            let path = env::temp_dir().join(unique);
            fs::create_dir_all(&path).expect("create temp dir");
            Self { path }
        }

        // Path of `suffix` inside this temp directory.
        fn join(&self, suffix: &str) -> PathBuf {
            self.path.join(suffix)
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best-effort cleanup; failures are ignored on purpose.
            let _ = fs::remove_dir_all(&self.path);
        }
    }

    /// Create empty stand-ins for the providers_shared + provider `.so` pair
    /// that resolution looks for (only existence is checked, not contents).
    fn write_provider_pair(dir: &Path, kind: OrtProviderKind) {
        fs::create_dir_all(dir).expect("create provider dir");
        File::create(dir.join("libonnxruntime_providers_shared.so")).expect("create shared lib");
        File::create(dir.join(kind.soname())).expect("create provider lib");
    }

    // An ORT_DYLIB_PATH override must beat both the exe dir and cache dirs,
    // even when all of them contain a valid provider pair.
    #[test]
    fn provider_dir_env_override_wins() {
        let tmp = TempDir::new("env_override");
        let override_dir = tmp.join("override");
        let exe_dir = tmp.join("exe");
        let cache_dir = tmp.join("cache");

        write_provider_pair(&override_dir, OrtProviderKind::TensorRt);
        write_provider_pair(&exe_dir, OrtProviderKind::TensorRt);
        write_provider_pair(&cache_dir, OrtProviderKind::TensorRt);

        let candidates = TensorRtBackend::provider_dir_candidates(
            Some(override_dir.clone()),
            None,
            Some(exe_dir),
            vec![cache_dir],
            false,
        );
        let (resolved, source) = TensorRtBackend::resolve_provider_dir(
            OrtProviderKind::TensorRt,
            &candidates,
            Some(&override_dir),
            None,
            "upscale",
            "direct",
            "tensorrt",
        )
        .expect("provider directory should resolve");

        assert_eq!(resolved, override_dir);
        assert_eq!(source, "ORT_DYLIB_PATH");
    }

    // Without overrides, the first existing candidate in fallback order wins
    // (here: the exe's `deps/` directory before the cache directory).
    #[test]
    fn provider_dir_fallback_order_uses_first_existing_candidate() {
        let tmp = TempDir::new("fallback_order");
        let exe_dir = tmp.join("exe");
        let exe_deps_dir = exe_dir.join("deps");
        let cache_dir = tmp.join("cache");

        write_provider_pair(&exe_deps_dir, OrtProviderKind::Cuda);
        write_provider_pair(&cache_dir, OrtProviderKind::Cuda);

        let candidates = TensorRtBackend::provider_dir_candidates(
            None,
            None,
            Some(exe_dir.clone()),
            vec![cache_dir],
            false,
        );
        let (resolved, source) = TensorRtBackend::resolve_provider_dir(
            OrtProviderKind::Cuda,
            &candidates,
            None,
            None,
            "benchmark",
            "provider_preload",
            "cuda",
        )
        .expect("provider directory should resolve");

        assert_eq!(resolved, exe_deps_dir);
        assert_eq!(source, "exe_dir/deps");
    }

    // A failed resolution must report command/mode context and list every
    // checked candidate in precedence order.
    #[test]
    fn provider_dir_failure_reports_checked_candidates_in_order() {
        let tmp = TempDir::new("failure_diagnostics");
        let override_dir = tmp.join("missing_override");
        let lib_dir = tmp.join("missing_lib");

        let candidates = TensorRtBackend::provider_dir_candidates(
            Some(override_dir.clone()),
            Some(lib_dir.clone()),
            None,
            vec![],
            false,
        );
        let err = TensorRtBackend::resolve_provider_dir(
            OrtProviderKind::TensorRt,
            &candidates,
            Some(&override_dir),
            Some(&lib_dir),
            "validate",
            "provider_preload",
            "tensorrt",
        )
        .expect_err("resolution should fail");
        let msg = err.to_string();

        assert!(msg.contains("command=validate"));
        assert!(msg.contains("mode=provider_preload"));
        assert!(msg.contains("ep_mode=tensorrt"));
        assert!(msg.contains("candidates_checked=[1:ORT_DYLIB_PATH:"));
        let idx_override = msg
            .find("1:ORT_DYLIB_PATH:")
            .expect("missing ORT_DYLIB_PATH candidate");
        let idx_lib = msg
            .find("2:ORT_LIB_LOCATION:")
            .expect("missing ORT_LIB_LOCATION candidate");
        assert!(
            idx_override < idx_lib,
            "candidate order should be deterministic and preserve precedence"
        );
    }

    // Rendering must be pure (same input → same output) and match the
    // documented "index:source:path" comma-separated format exactly.
    #[test]
    fn provider_candidate_rendering_is_deterministic() {
        let candidates = vec![
            (PathBuf::from("/tmp/a"), "ORT_DYLIB_PATH"),
            (PathBuf::from("/tmp/b"), "ORT_LIB_LOCATION"),
            (PathBuf::from("/tmp/c"), "exe_dir"),
        ];

        let rendered_a = TensorRtBackend::render_provider_candidates(&candidates);
        let rendered_b = TensorRtBackend::render_provider_candidates(&candidates);

        assert_eq!(rendered_a, rendered_b);
        assert_eq!(
            rendered_a,
            "1:ORT_DYLIB_PATH:/tmp/a, 2:ORT_LIB_LOCATION:/tmp/b, 3:exe_dir:/tmp/c"
        );
    }
}
1464
#[cfg(all(test, target_os = "linux"))]
mod tests {
    //! Smoke tests that need real ORT/TensorRT libraries on the host; both
    //! are `#[ignore]`d so they only run when explicitly requested.
    use super::{RTLD_LOCAL, RTLD_NOW, TensorRtBackend};
    use rave_core::backend::UpscaleBackend;
    use std::env;
    use std::path::PathBuf;

    // Verifies the providers_shared bridge makes the TensorRT provider
    // dlopen-able (the symbol-dependency ordering documented on
    // `preload_ort_provider_pair`).
    #[test]
    #[ignore = "requires ORT TensorRT provider libs on host"]
    fn providers_load_with_bridge_preloaded() {
        // Ensure bridge is globally visible before loading TensorRT provider.
        TensorRtBackend::preload_ort_provider_bridge().expect("failed to preload providers_shared");
        let trt_candidates =
            TensorRtBackend::ort_provider_candidates("libonnxruntime_providers_tensorrt.so");
        assert!(
            !trt_candidates.is_empty(),
            "no libonnxruntime_providers_tensorrt.so candidates found"
        );
        let path = trt_candidates[0].clone();
        TensorRtBackend::dlopen_path(&path, RTLD_NOW | RTLD_LOCAL)
            .expect("providers_tensorrt should dlopen after providers_shared preload");
    }

    // End-to-end `initialize()` against a real model; the model path comes
    // from the RAVE_TEST_ONNX_MODEL env var.
    #[test]
    #[ignore = "requires model + full ORT/TensorRT runtime"]
    fn ort_registers_tensorrt_ep_smoke() {
        let model = env::var("RAVE_TEST_ONNX_MODEL").expect("set RAVE_TEST_ONNX_MODEL");
        let backend = TensorRtBackend::new(
            PathBuf::from(model),
            rave_core::context::GpuContext::new(0).expect("cuda ctx"),
            0,
            6,
            4,
        );
        // initialize() is async; drive it on a minimal current-thread runtime.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("tokio runtime");
        rt.block_on(async move { backend.initialize().await })
            .expect("TensorRT EP registration should succeed");
    }
}
1507
#[async_trait]
impl UpscaleBackend for TensorRtBackend {
    /// Load the ONNX model and register an execution provider according to
    /// the `RAVE_ORT_TENSORRT` policy (TensorRT-only, CUDA-only, or auto
    /// with TensorRT → CUDA fallback).
    ///
    /// # Errors
    /// Fails when called twice, when the batch config is invalid, or when
    /// session construction fails for the selected provider mode.
    async fn initialize(&self) -> Result<()> {
        let mut guard = self.state.lock().await;
        if guard.is_some() {
            return Err(EngineError::ModelMetadata("Already initialized".into()));
        }

        // Reject unsupported batch settings up front (see batch_config_tests:
        // only max_batch=1 is accepted).
        validate_batch_config(&self.batch_config)?;

        let ep_mode = Self::ort_ep_mode();
        let on_wsl = Self::is_wsl2();
        info!(
            path = %self.model_path.display(),
            ?ep_mode,
            on_wsl,
            "Loading ONNX model with ORT execution provider policy"
        );

        // Build the session per policy; Auto tries TensorRT first and falls
        // back to CUDA only when TRT session construction fails.
        let (session, active_provider) = match ep_mode {
            OrtEpMode::CudaOnly => {
                info!("RAVE_ORT_TENSORRT disables TensorRT EP; using CUDAExecutionProvider");
                (self.build_cuda_session()?, "CUDAExecutionProvider")
            }
            OrtEpMode::TensorRtOnly => (self.build_trt_session()?, "TensorrtExecutionProvider"),
            OrtEpMode::Auto => match self.build_trt_session() {
                Ok(session) => (session, "TensorrtExecutionProvider"),
                Err(e) => {
                    warn!(
                        error = %e,
                        on_wsl,
                        "TensorRT EP registration failed; falling back to CUDAExecutionProvider"
                    );
                    (self.build_cuda_session()?, "CUDAExecutionProvider")
                }
            },
        };
        info!(
            provider = active_provider,
            "ORT execution provider selected"
        );

        let metadata = Self::extract_metadata(&session)?;
        info!(
            name = %metadata.name,
            scale = metadata.scale,
            input = %metadata.input_name,
            output = %metadata.output_name,
            ring_size = self.ring_size,
            min_ring_slots = self.min_ring_slots,
            provider = active_provider,
            precision = ?self.precision_policy,
            max_batch = self.batch_config.max_batch,
            "Model loaded"
        );

        // OnceLock setters: a second set attempt returns Err, which is ignored.
        let _ = self.meta.set(metadata);
        let _ = self.selected_provider.set(active_provider.to_string());

        // Ring allocation is deferred until the first frame's dims are known.
        *guard = Some(InferenceState {
            session,
            ring: None,
        });

        Ok(())
    }

    /// Upscale one frame: bind the input texture and a ring-slot output to
    /// the session and run inference, returning a texture that borrows the
    /// ring slot (same pixel format as the input, dims multiplied by the
    /// model's scale factor).
    ///
    /// # Errors
    /// `FormatMismatch` for non-planar-RGB inputs; `NotInitialized` before
    /// `initialize()`; any error from ring allocation or `run_io_bound`.
    async fn process(&self, input: GpuTexture) -> Result<GpuTexture> {
        // Accept both F32 and F16 planar RGB.
        match input.format {
            PixelFormat::RgbPlanarF32 | PixelFormat::RgbPlanarF16 => {}
            other => {
                return Err(EngineError::FormatMismatch {
                    expected: PixelFormat::RgbPlanarF32,
                    actual: other,
                });
            }
        }

        let meta = self.meta.get().ok_or(EngineError::NotInitialized)?;
        let mut guard = self.state.lock().await;
        let state = guard.as_mut().ok_or(EngineError::NotInitialized)?;

        // Lazy ring init / realloc.
        match &mut state.ring {
            // Frame dims changed since the ring was sized: rebuild the slots.
            Some(ring) if ring.needs_realloc(input.width, input.height) => {
                debug!(
                    old_w = ring.alloc_dims.0,
                    old_h = ring.alloc_dims.1,
                    new_w = input.width,
                    new_h = input.height,
                    "Reallocating output ring"
                );
                ring.reallocate(&self.ctx, input.width, input.height, meta.scale)?;
            }
            // First frame: create the ring now that dims are known.
            None => {
                debug!(
                    w = input.width,
                    h = input.height,
                    slots = self.ring_size,
                    "Lazily creating output ring"
                );
                state.ring = Some(OutputRing::new(
                    &self.ctx,
                    input.width,
                    input.height,
                    meta.scale,
                    self.ring_size,
                    self.min_ring_slots,
                )?);
            }
            // Existing ring with matching dims: reuse as-is.
            Some(_) => {}
        }

        // Safe: the match above guarantees the ring exists at this point.
        let ring = state.ring.as_mut().unwrap();

        // Debug-mode host allocation tracking.
        #[cfg(feature = "debug-alloc")]
        {
            rave_core::debug_alloc::reset();
            rave_core::debug_alloc::enable();
        }

        let output_arc = ring.acquire()?;
        let output_ptr = *(*output_arc).device_ptr();
        let output_bytes = ring.slot_bytes;

        // ── Inference with latency measurement ──
        let t_start = std::time::Instant::now();

        let mem_info = self.mem_info()?;
        Self::run_io_bound(
            &mut state.session,
            meta,
            &input,
            output_ptr,
            output_bytes,
            &self.ctx,
            &mem_info,
        )?;

        let elapsed_us = t_start.elapsed().as_micros() as u64;
        self.inference_metrics.record(elapsed_us);

        // Any host allocation during inference violates the zero-copy contract.
        #[cfg(feature = "debug-alloc")]
        {
            rave_core::debug_alloc::disable();
            let host_allocs = rave_core::debug_alloc::count();
            debug_assert_eq!(
                host_allocs, 0,
                "VIOLATION: {host_allocs} host allocations during inference"
            );
        }

        let out_w = input.width * meta.scale;
        let out_h = input.height * meta.scale;
        let elem_bytes = input.format.element_bytes();

        // The returned texture shares ownership of the ring slot via `output_arc`.
        Ok(GpuTexture {
            data: output_arc,
            width: out_w,
            height: out_h,
            pitch: (out_w as usize) * elem_bytes,
            format: input.format, // Preserve F32 or F16
        })
    }

    /// Tear down the session and ring, logging final ring/inference/VRAM
    /// metrics.  Idempotent: a second call finds no state and returns Ok.
    async fn shutdown(&self) -> Result<()> {
        let mut guard = self.state.lock().await;
        if let Some(state) = guard.take() {
            info!("Shutting down TensorRT backend");
            // Drain all in-flight GPU work before freeing device buffers.
            self.ctx.sync_all()?;

            // Report ring metrics.
            if let Some(ring) = &state.ring {
                let snap = ring.metrics.snapshot();
                info!(
                    reuse = snap.reuse,
                    contention = snap.contention,
                    first_use = snap.first_use,
                    "Final ring metrics"
                );
            }

            // Report inference metrics.
            let snap = self.inference_metrics.snapshot();
            info!(
                frames = snap.frames_inferred,
                avg_us = snap.avg_inference_us,
                peak_us = snap.peak_inference_us,
                precision = ?self.precision_policy,
                "Final inference metrics"
            );

            // Report VRAM.
            let (current, peak) = self.ctx.vram_usage();
            info!(
                current_mb = current / (1024 * 1024),
                peak_mb = peak / (1024 * 1024),
                "Final VRAM usage"
            );

            // Drop ring buffers before the session that produced into them.
            drop(state.ring);
            drop(state.session);
            debug!("TensorRT backend shutdown complete");
        }
        Ok(())
    }

    /// Model metadata extracted at `initialize()` time.
    ///
    /// # Errors
    /// `NotInitialized` before `initialize()` has completed.
    fn metadata(&self) -> Result<&ModelMetadata> {
        self.meta.get().ok_or(EngineError::NotInitialized)
    }
}
1721
#[cfg(test)]
mod batch_config_tests {
    use super::{BatchConfig, validate_batch_config};

    #[test]
    fn batch_config_validator_accepts_single_frame() {
        // Single-frame batching is the only supported configuration.
        let config = BatchConfig {
            max_batch: 1,
            latency_deadline_us: 8_000,
        };
        validate_batch_config(&config).expect("max_batch=1 should be accepted");
    }

    #[test]
    fn batch_config_validator_rejects_micro_batching() {
        // More than one frame per batch is unimplemented and must error out.
        let config = BatchConfig {
            max_batch: 2,
            latency_deadline_us: 8_000,
        };
        let message = validate_batch_config(&config)
            .expect_err("max_batch>1 must be rejected")
            .to_string();
        assert!(message.contains("max_batch"));
        assert!(message.contains("not implemented"));
    }
}
1747
1748impl Drop for TensorRtBackend {
1749    fn drop(&mut self) {
1750        if let Ok(mut guard) = self.state.try_lock()
1751            && let Some(state) = guard.take()
1752        {
1753            let _ = self.ctx.sync_all();
1754            drop(state);
1755        }
1756    }
1757}