1use std::collections::HashSet;
36use std::env;
37use std::ffi::{CStr, CString, c_char, c_void};
38use std::path::{Path, PathBuf};
39use std::sync::atomic::{AtomicU64, Ordering};
40use std::sync::{Arc, OnceLock};
41use std::time::UNIX_EPOCH;
42
43use async_trait::async_trait;
44use cudarc::driver::{DevicePtr, DeviceSlice};
45use tokio::sync::Mutex;
46use tracing::{debug, info, warn};
47
48use ort::session::Session;
49use ort::sys as ort_sys;
50use ort::value::Value as OrtValue;
51
52use rave_core::backend::{ModelMetadata, UpscaleBackend};
53use rave_core::context::GpuContext;
54use rave_core::error::{EngineError, Result};
55use rave_core::types::{GpuTexture, PixelFormat};
56
57use ort::execution_providers::{CUDAExecutionProvider, TensorRTExecutionProvider};
58
/// Wraps an existing CUDA device allocation in an ORT tensor without copying.
///
/// Builds a CUDA-flavored `OrtMemoryInfo`, then calls
/// `CreateTensorWithDataAsOrtValue` so ORT references `ptr` directly
/// (zero-copy IO binding).
///
/// # Safety
/// - `ptr` must point to valid CUDA device memory of at least `bytes` bytes.
/// - ORT does NOT take ownership: the allocation must stay alive (and not be
///   written by other work) for the whole lifetime of the returned `OrtValue`.
/// - `shape` and `elem_type` must describe exactly the memory at `ptr`.
unsafe fn create_tensor_from_device_memory(
    ptr: *mut std::ffi::c_void,
    bytes: usize,
    shape: &[i64],
    elem_type: ort::tensor::TensorElementType,
) -> Result<OrtValue> {
    let api = ort::api();

    // "Cuda" is the allocator name ORT associates with device memory.
    let mut mem_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut();
    let name = std::ffi::CString::new("Cuda").unwrap();
    let status = unsafe {
        (api.CreateMemoryInfo)(
            name.as_ptr(),
            ort_sys::OrtAllocatorType::OrtArenaAllocator,
            0, ort_sys::OrtMemType::OrtMemTypeDefault,
            &mut mem_info_ptr,
        )
    };
    // A non-null status pointer signals failure; release it to avoid leaking.
    if !status.0.is_null() {
        unsafe { (api.ReleaseStatus)(status.0) };
        return Err(EngineError::Inference(ort::Error::new(
            "Failed to create MemoryInfo",
        )));
    }

    let mut ort_value_ptr: *mut ort_sys::OrtValue = std::ptr::null_mut();
    let status = unsafe {
        (api.CreateTensorWithDataAsOrtValue)(
            mem_info_ptr,
            ptr,
            bytes as _,
            shape.as_ptr(),
            shape.len() as _,
            elem_type.into(),
            &mut ort_value_ptr,
        )
    };

    // Released unconditionally, before the status check, so neither the
    // success nor the failure path leaks the MemoryInfo.
    // NOTE(review): assumes ORT does not retain this pointer past the call —
    // confirm against the ORT C API docs.
    unsafe { (api.ReleaseMemoryInfo)(mem_info_ptr) };

    if !status.0.is_null() {
        unsafe { (api.ReleaseStatus)(status.0) };
        return Err(EngineError::Inference(ort::Error::new(
            "Failed to create Tensor",
        )));
    }

    Ok(unsafe {
        // A null ort_value_ptr would have produced a non-null status above,
        // so this unwrap cannot fire. `None` = no session keep-alive.
        ort::value::Value::<ort::value::DynValueTypeMarker>::from_ptr(
            std::ptr::NonNull::new(ort_value_ptr).unwrap(),
            None,
        )
    })
}
122
/// Numeric precision requested from the TensorRT execution provider.
#[derive(Clone, Debug, Default)]
pub enum PrecisionPolicy {
    /// Full FP32 — no mixed precision.
    Fp32,
    /// FP16 mixed precision (the default).
    #[default]
    Fp16,
    /// INT8 quantization; `calibration_table` points at the TRT calibration file.
    Int8 { calibration_table: PathBuf },
}
141
/// Batching knobs for inference.
/// NOTE(review): within this file these values are only logged, not enforced —
/// confirm consumers elsewhere in the crate.
#[derive(Clone, Debug)]
pub struct BatchConfig {
    /// Maximum frames grouped into one inference call.
    pub max_batch: usize,
    /// Deadline in microseconds before a partial batch would be flushed.
    pub latency_deadline_us: u64,
}
154
155impl Default for BatchConfig {
156 fn default() -> Self {
157 Self {
158 max_batch: 1,
159 latency_deadline_us: 8_000, }
161 }
162}
/// Lock-free counters for inference timing, updated through `&self`.
#[derive(Debug)]
pub struct InferenceMetrics {
    /// Total frames run through the session.
    pub frames_inferred: AtomicU64,
    /// Sum of per-frame inference wall time, in microseconds.
    pub total_inference_us: AtomicU64,
    /// Largest single-frame inference time observed, in microseconds.
    pub peak_inference_us: AtomicU64,
}
175
176impl InferenceMetrics {
177 pub const fn new() -> Self {
178 Self {
179 frames_inferred: AtomicU64::new(0),
180 total_inference_us: AtomicU64::new(0),
181 peak_inference_us: AtomicU64::new(0),
182 }
183 }
184
185 pub fn record(&self, elapsed_us: u64) {
186 self.frames_inferred.fetch_add(1, Ordering::Relaxed);
187 self.total_inference_us
188 .fetch_add(elapsed_us, Ordering::Relaxed);
189 self.peak_inference_us
190 .fetch_max(elapsed_us, Ordering::Relaxed);
191 }
192
193 pub fn snapshot(&self) -> InferenceMetricsSnapshot {
194 let frames = self.frames_inferred.load(Ordering::Relaxed);
195 let total = self.total_inference_us.load(Ordering::Relaxed);
196 let peak = self.peak_inference_us.load(Ordering::Relaxed);
197 InferenceMetricsSnapshot {
198 frames_inferred: frames,
199 avg_inference_us: if frames > 0 { total / frames } else { 0 },
200 peak_inference_us: peak,
201 }
202 }
203}
204
205impl Default for InferenceMetrics {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
/// Plain-data view of [`InferenceMetrics`] captured at one point in time.
#[derive(Clone, Debug)]
pub struct InferenceMetricsSnapshot {
    pub frames_inferred: u64,
    /// Mean inference time in microseconds (0 when no frames were recorded).
    pub avg_inference_us: u64,
    pub peak_inference_us: u64,
}
218
/// Counters describing how output-ring slots cycle.
#[derive(Debug)]
pub struct RingMetrics {
    /// Times a previously-used slot was handed out again.
    pub slot_reuse_count: AtomicU64,
    /// Times `acquire` found the next slot still referenced downstream.
    pub slot_contention_events: AtomicU64,
    /// Times a slot was handed out for the first time.
    pub slot_first_use_count: AtomicU64,
}
231
232impl RingMetrics {
233 pub const fn new() -> Self {
234 Self {
235 slot_reuse_count: AtomicU64::new(0),
236 slot_contention_events: AtomicU64::new(0),
237 slot_first_use_count: AtomicU64::new(0),
238 }
239 }
240
241 pub fn snapshot(&self) -> (u64, u64, u64) {
242 (
243 self.slot_reuse_count.load(Ordering::Relaxed),
244 self.slot_contention_events.load(Ordering::Relaxed),
245 self.slot_first_use_count.load(Ordering::Relaxed),
246 )
247 }
248}
249
250impl Default for RingMetrics {
251 fn default() -> Self {
252 Self::new()
253 }
254}
255
/// Fixed pool of pre-allocated CUDA output buffers, recycled round-robin.
pub struct OutputRing {
    // One device allocation per slot; an `Arc` strong count of exactly 1
    // means only the ring holds it, i.e. the slot is free.
    slots: Vec<Arc<cudarc::driver::CudaSlice<u8>>>,
    // Index of the next slot to hand out.
    cursor: usize,
    // Size in bytes of every slot.
    pub slot_bytes: usize,
    // Input (w, h) the slots were sized for; used to detect resolution changes.
    pub alloc_dims: (u32, u32),
    // Whether each slot has been handed out at least once (metrics only).
    used: Vec<bool>,
    pub metrics: RingMetrics,
}
268
impl OutputRing {
    /// Allocates `count` device slots, each sized for a 3-channel f32 frame
    /// of (`in_w` * `scale`) x (`in_h` * `scale`).
    ///
    /// # Errors
    /// `DimensionMismatch` when `count < min_slots` or `count < 2`;
    /// otherwise propagates device-allocation failures from `ctx.alloc`.
    pub fn new(
        ctx: &GpuContext,
        in_w: u32,
        in_h: u32,
        scale: u32,
        count: usize,
        min_slots: usize,
    ) -> Result<Self> {
        if count < min_slots {
            return Err(EngineError::DimensionMismatch(format!(
                "OutputRing: ring_size ({count}) < required minimum ({min_slots}). \
                Ring must be ≥ downstream_channel_capacity + 2."
            )));
        }
        if count < 2 {
            return Err(EngineError::DimensionMismatch(
                "OutputRing: ring_size must be ≥ 2 for double-buffering".into(),
            ));
        }

        // 3 planes (RGB) of f32 per output pixel.
        // NOTE(review): `in_w * scale` is a u32 multiply — assumes realistic
        // frame sizes that cannot overflow.
        let out_w = (in_w * scale) as usize;
        let out_h = (in_h * scale) as usize;
        let slot_bytes = 3 * out_w * out_h * std::mem::size_of::<f32>();

        let slots = (0..count)
            .map(|_| ctx.alloc(slot_bytes).map(Arc::new))
            .collect::<Result<Vec<_>>>()?;

        debug!(count, slot_bytes, out_w, out_h, "Output ring allocated");

        Ok(Self {
            slots,
            cursor: 0,
            slot_bytes,
            alloc_dims: (in_w, in_h),
            used: vec![false; count],
            metrics: RingMetrics::new(),
        })
    }

    /// Hands out the next slot round-robin.
    ///
    /// A slot is free iff its `Arc` strong count is exactly 1 (only the ring
    /// holds it). If the next slot is still referenced downstream, this fails
    /// instead of overwriting in-flight data.
    ///
    /// # Errors
    /// `BufferTooSmall { have: 0 }` on contention.
    /// NOTE(review): that variant is reused to signal contention; a dedicated
    /// error variant would be clearer.
    pub fn acquire(&mut self) -> Result<Arc<cudarc::driver::CudaSlice<u8>>> {
        let slot = &self.slots[self.cursor];
        let sc = Arc::strong_count(slot);

        if sc != 1 {
            self.metrics
                .slot_contention_events
                .fetch_add(1, Ordering::Relaxed);
            return Err(EngineError::BufferTooSmall {
                need: self.slot_bytes,
                have: 0,
            });
        }

        // Debug-build re-check of the free invariant right before reuse.
        debug_assert_eq!(
            Arc::strong_count(slot),
            1,
            "OutputRing: slot {} strong_count must be 1 before reuse, got {}",
            self.cursor,
            sc
        );

        // First-use vs reuse bookkeeping (metrics only).
        if self.used[self.cursor] {
            self.metrics
                .slot_reuse_count
                .fetch_add(1, Ordering::Relaxed);
        } else {
            self.used[self.cursor] = true;
            self.metrics
                .slot_first_use_count
                .fetch_add(1, Ordering::Relaxed);
        }

        let cloned = Arc::clone(slot);
        self.cursor = (self.cursor + 1) % self.slots.len();
        Ok(cloned)
    }

    /// True when the ring was sized for a different input resolution.
    pub fn needs_realloc(&self, in_w: u32, in_h: u32) -> bool {
        self.alloc_dims != (in_w, in_h)
    }

    /// Drops all slots and re-allocates them for a new input resolution.
    ///
    /// Refuses to proceed while any slot is still referenced downstream
    /// (`strong_count > 1`): freeing device memory a consumer still reads
    /// would corrupt its output.
    pub fn reallocate(&mut self, ctx: &GpuContext, in_w: u32, in_h: u32, scale: u32) -> Result<()> {
        for (i, slot) in self.slots.iter().enumerate() {
            let sc = Arc::strong_count(slot);
            if sc != 1 {
                return Err(EngineError::DimensionMismatch(format!(
                    "Cannot reallocate ring: slot {} still in use (strong_count={})",
                    i, sc,
                )));
            }
        }

        // Release the VRAM accounting for the old slots before replacing them.
        // NOTE(review): assumes `ctx.alloc` increments the same counter and
        // dropping a CudaSlice does not decrement it — confirm in GpuContext.
        for _slot in &self.slots {
            ctx.vram_dec(self.slot_bytes);
        }

        let count = self.slots.len();
        let out_w = (in_w * scale) as usize;
        let out_h = (in_h * scale) as usize;
        let slot_bytes = 3 * out_w * out_h * std::mem::size_of::<f32>();

        // New slots are allocated first; the assignment then drops the old ones.
        self.slots = (0..count)
            .map(|_| ctx.alloc(slot_bytes).map(Arc::new))
            .collect::<Result<Vec<_>>>()?;
        self.cursor = 0;
        self.slot_bytes = slot_bytes;
        self.alloc_dims = (in_w, in_h);
        self.used = vec![false; count];

        debug!(count, slot_bytes, out_w, out_h, "Output ring reallocated");
        Ok(())
    }

    /// Number of slots in the ring.
    pub fn len(&self) -> usize {
        self.slots.len()
    }

    /// True when the ring holds no slots (never the case after `new`).
    pub fn is_empty(&self) -> bool {
        self.slots.is_empty()
    }
}
406
/// Mutable backend state guarded by the async mutex: the ORT session plus the
/// lazily-created output ring.
struct InferenceState {
    session: Session,
    // `None` until the first processed frame fixes the output dimensions.
    ring: Option<OutputRing>,
}
413
414fn ort_element_type(format: PixelFormat) -> ort::tensor::TensorElementType {
416 match format {
417 PixelFormat::RgbPlanarF16 => ort::tensor::TensorElementType::Float16,
418 _ => ort::tensor::TensorElementType::Float32,
419 }
420}
421
/// `UpscaleBackend` backed by ONNX Runtime with the TensorRT (or CUDA)
/// execution provider, using zero-copy IO binding into a recycled output ring.
pub struct TensorRtBackend {
    model_path: PathBuf,
    ctx: Arc<GpuContext>,
    device_id: i32,
    // Number of output-ring slots; validated ≥ min_ring_slots at construction.
    ring_size: usize,
    // downstream_capacity + 2 — the double-buffering floor for the ring.
    min_ring_slots: usize,
    // Model metadata, set once during `initialize`.
    meta: OnceLock<ModelMetadata>,
    // Session + ring; `None` until `initialize`, taken back on `shutdown`.
    state: Mutex<Option<InferenceState>>,
    pub inference_metrics: InferenceMetrics,
    pub precision_policy: PrecisionPolicy,
    pub batch_config: BatchConfig,
}
438
/// Execution-provider policy parsed from the `RAVE_ORT_TENSORRT` env var.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum OrtEpMode {
    // Try TensorRT first; fall back to CUDA on failure.
    Auto,
    // TensorRT or error — no fallback.
    TensorRtOnly,
    // CUDA only; TensorRT is never attempted.
    CudaOnly,
}
445
// Minimal libdl bindings used to preload ORT provider libraries with
// RTLD_GLOBAL before ORT itself tries to resolve their symbols.
#[cfg(target_os = "linux")]
unsafe extern "C" {
    fn dlopen(filename: *const c_char, flags: i32) -> *mut c_void;
    fn dlerror() -> *const c_char;
}
451
// dlopen(3) flag values matching glibc's <dlfcn.h>.
#[cfg(target_os = "linux")]
const RTLD_NOW: i32 = 2;
// RTLD_LOCAL is the default (0); only the provider-load test uses it explicitly.
#[cfg(all(target_os = "linux", test))]
const RTLD_LOCAL: i32 = 0;
#[cfg(target_os = "linux")]
const RTLD_GLOBAL: i32 = 0x100;
458
/// Which ORT GPU execution-provider library pair to locate and preload.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy)]
enum OrtProviderKind {
    Cuda,
    TensorRt,
}

#[cfg(target_os = "linux")]
impl OrtProviderKind {
    /// Shared-object file name of the provider library.
    fn soname(self) -> &'static str {
        if let OrtProviderKind::Cuda = self {
            "libonnxruntime_providers_cuda.so"
        } else {
            "libonnxruntime_providers_tensorrt.so"
        }
    }

    /// Short tag used in structured log fields.
    fn label(self) -> &'static str {
        if let OrtProviderKind::Cuda = self {
            "providers_cuda"
        } else {
            "providers_tensorrt"
        }
    }
}
482
483impl TensorRtBackend {
484 pub fn new(
492 model_path: PathBuf,
493 ctx: Arc<GpuContext>,
494 device_id: i32,
495 ring_size: usize,
496 downstream_capacity: usize,
497 ) -> Self {
498 Self::with_precision(
499 model_path,
500 ctx,
501 device_id,
502 ring_size,
503 downstream_capacity,
504 PrecisionPolicy::default(),
505 BatchConfig::default(),
506 )
507 }
508
509 pub fn with_precision(
511 model_path: PathBuf,
512 ctx: Arc<GpuContext>,
513 device_id: i32,
514 ring_size: usize,
515 downstream_capacity: usize,
516 precision_policy: PrecisionPolicy,
517 batch_config: BatchConfig,
518 ) -> Self {
519 let min_ring_slots = downstream_capacity + 2;
520 assert!(
521 ring_size >= min_ring_slots,
522 "ring_size ({ring_size}) must be ≥ downstream_capacity + 2 ({min_ring_slots})"
523 );
524 Self {
525 model_path,
526 ctx,
527 device_id,
528 ring_size,
529 min_ring_slots,
530 meta: OnceLock::new(),
531 state: Mutex::new(None),
532 inference_metrics: InferenceMetrics::new(),
533 precision_policy,
534 batch_config,
535 }
536 }
537
538 fn mem_info(&self) -> Result<ort::memory::MemoryInfo> {
541 ort::memory::MemoryInfo::new(
542 ort::memory::AllocationDevice::CUDA,
543 0,
544 ort::memory::AllocatorType::Device,
545 ort::memory::MemoryType::Default,
546 )
547 .map_err(|e| EngineError::ModelMetadata(format!("MemoryInfo: {e}")))
548 }
549
550 pub async fn ring_metrics(&self) -> Option<(u64, u64, u64)> {
552 let guard = self.state.lock().await;
553 guard
554 .as_ref()
555 .and_then(|s| s.ring.as_ref())
556 .map(|r| r.metrics.snapshot())
557 }
558
    /// Derives `ModelMetadata` from the session's first input/output tensors.
    ///
    /// Expects 4D NCHW tensors. The upscale factor is inferred as `oh / ih`
    /// when spatial dims are static; with dynamic axes it falls back to 4.
    ///
    /// # Errors
    /// `ModelMetadata` when the model has no inputs/outputs, non-tensor IO,
    /// or tensors that are not 4D.
    fn extract_metadata(session: &Session) -> Result<ModelMetadata> {
        let inputs = session.inputs();
        let outputs = session.outputs();

        if inputs.is_empty() || outputs.is_empty() {
            return Err(EngineError::ModelMetadata(
                "Model must have at least one input and one output tensor".into(),
            ));
        }

        // Only the first input/output pair is considered; any additional IO
        // tensors are ignored.
        let input_info = &inputs[0];
        let output_info = &outputs[0];
        let input_name = input_info.name().to_string();
        let output_name = output_info.name().to_string();

        let input_dims = match input_info.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.clone(),
            other => {
                return Err(EngineError::ModelMetadata(format!(
                    "Expected tensor input, got {:?}",
                    other
                )));
            }
        };

        let output_dims = match output_info.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.clone(),
            other => {
                return Err(EngineError::ModelMetadata(format!(
                    "Expected tensor output, got {:?}",
                    other
                )));
            }
        };

        if input_dims.len() != 4 || output_dims.len() != 4 {
            return Err(EngineError::ModelMetadata(format!(
                "Expected 4D tensors (NCHW), got input={}D output={}D",
                input_dims.len(),
                output_dims.len()
            )));
        }

        // NOTE(review): a dynamic channel axis (reported as -1) would wrap to
        // a huge u32 here — assumes the channel count is static.
        let input_channels = input_dims[1] as u32;

        let ih = input_dims[2];
        let iw = input_dims[3];
        let oh = output_dims[2];
        let ow = output_dims[3];

        // Integer division truncates; assumes oh is an exact multiple of ih.
        // The width ratio is not cross-checked against the height ratio.
        let scale = if ih > 0 && oh > 0 && iw > 0 && ow > 0 {
            (oh / ih) as u32
        } else {
            warn!("Dynamic spatial axes — defaulting to scale=4");
            4
        };

        // Dynamic axes (<= 0) map to the loosest possible bounds.
        let min_input_hw = (
            if ih > 0 { ih as u32 } else { 1 },
            if iw > 0 { iw as u32 } else { 1 },
        );

        let max_input_hw = (
            if ih > 0 { ih as u32 } else { u32::MAX },
            if iw > 0 { iw as u32 } else { u32::MAX },
        );

        // Model name is best-effort; fall back to "unknown" on any failure.
        let name = session
            .metadata()
            .map(|m| m.name().unwrap_or("unknown".to_string()))
            .unwrap_or_else(|_| "unknown".to_string());

        Ok(ModelMetadata {
            name,
            scale,
            input_name,
            output_name,
            input_channels,
            min_input_hw,
            max_input_hw,
        })
    }
649
650 fn ort_ep_mode() -> OrtEpMode {
651 match env::var("RAVE_ORT_TENSORRT")
652 .unwrap_or_else(|_| "auto".to_string())
653 .to_lowercase()
654 .as_str()
655 {
656 "0" | "off" | "false" | "cuda" | "cuda-only" => OrtEpMode::CudaOnly,
657 "1" | "on" | "true" | "trt" | "trt-only" => OrtEpMode::TensorRtOnly,
658 _ => OrtEpMode::Auto,
659 }
660 }
661
662 #[cfg(target_os = "linux")]
663 fn is_wsl2() -> bool {
664 std::fs::read_to_string("/proc/sys/kernel/osrelease")
665 .map(|s| s.to_ascii_lowercase().contains("microsoft"))
666 .unwrap_or(false)
667 }
668
669 #[cfg(not(target_os = "linux"))]
670 fn is_wsl2() -> bool {
671 false
672 }
673
674 #[cfg(target_os = "linux")]
675 fn ort_provider_cache_dirs_newest_first() -> Vec<PathBuf> {
676 let mut dirs = Vec::<(u128, PathBuf)>::new();
677 let Some(home) = env::var_os("HOME") else {
678 return Vec::new();
679 };
680 let base = PathBuf::from(home).join(".cache/ort.pyke.io/dfbin");
681 let Ok(triples) = std::fs::read_dir(base) else {
682 return Vec::new();
683 };
684
685 for triple in triples.flatten() {
686 let triple_path = triple.path();
687 if !triple_path.is_dir() {
688 continue;
689 }
690 let Ok(hashes) = std::fs::read_dir(triple_path) else {
691 continue;
692 };
693 for hash in hashes.flatten() {
694 let path = hash.path();
695 if !path.is_dir() {
696 continue;
697 }
698 let modified = std::fs::metadata(&path)
699 .and_then(|m| m.modified())
700 .ok()
701 .and_then(|t| t.duration_since(UNIX_EPOCH).ok())
702 .map(|d| d.as_nanos())
703 .unwrap_or(0);
704 dirs.push((modified, path));
705 }
706 }
707
708 dirs.sort_by(|a, b| b.cmp(a));
709 dirs.into_iter().map(|(_, path)| path).collect()
710 }
711
712 #[cfg(target_os = "linux")]
713 fn ort_provider_search_dirs() -> Vec<(PathBuf, &'static str)> {
714 let mut dirs = Vec::<(PathBuf, &'static str)>::new();
715 if let Some(dir) = env::var_os("ORT_DYLIB_PATH") {
716 dirs.push((PathBuf::from(dir), "ORT_DYLIB_PATH"));
717 }
718 if let Some(dir) = env::var_os("ORT_LIB_LOCATION") {
719 dirs.push((PathBuf::from(dir), "ORT_LIB_LOCATION"));
720 }
721 if Self::is_wsl2() {
722 for dir in Self::ort_provider_cache_dirs_newest_first() {
723 dirs.push((dir, "ort_cache_newest"));
724 }
725 if let Ok(exe) = env::current_exe()
726 && let Some(dir) = exe.parent()
727 {
728 dirs.push((dir.to_path_buf(), "exe_dir"));
729 dirs.push((dir.join("deps"), "exe_dir/deps"));
730 }
731 } else {
732 if let Ok(exe) = env::current_exe()
733 && let Some(dir) = exe.parent()
734 {
735 dirs.push((dir.to_path_buf(), "exe_dir"));
736 dirs.push((dir.join("deps"), "exe_dir/deps"));
737 }
738 for dir in Self::ort_provider_cache_dirs_newest_first() {
739 dirs.push((dir, "ort_cache_newest"));
740 }
741 }
742
743 let mut uniq = HashSet::<PathBuf>::new();
744 dirs.retain(|(p, _)| uniq.insert(p.clone()));
745 dirs
746 }
747
748 #[cfg(all(target_os = "linux", test))]
749 fn ort_provider_candidates(lib_name: &str) -> Vec<PathBuf> {
750 Self::ort_provider_search_dirs()
751 .into_iter()
752 .map(|(dir, _)| dir.join(lib_name))
753 .collect()
754 }
755
756 #[cfg(target_os = "linux")]
757 fn configure_ort_loader_path(dir: &Path) {
758 unsafe { env::set_var("ORT_DYLIB_PATH", dir) };
760 unsafe { env::set_var("ORT_LIB_LOCATION", dir) };
762 }
763
764 #[cfg(target_os = "linux")]
765 fn resolve_ort_provider_dir(kind: OrtProviderKind) -> Result<(PathBuf, &'static str)> {
766 let provider = kind.soname();
767 for (dir, source) in Self::ort_provider_search_dirs() {
768 let shared = dir.join("libonnxruntime_providers_shared.so");
769 let provider_path = dir.join(provider);
770 if shared.is_file() && provider_path.is_file() {
771 info!(
772 path = %shared.display(),
773 source,
774 "ORT providers_shared resolved"
775 );
776 info!(
777 provider = kind.label(),
778 path = %provider_path.display(),
779 source,
780 "ORT provider resolved"
781 );
782 Self::configure_ort_loader_path(&dir);
783 return Ok((dir, source));
784 }
785 }
786
787 Err(EngineError::ModelMetadata(format!(
788 "Could not find ORT provider pair (libonnxruntime_providers_shared.so + {}). \
789Set ORT_DYLIB_PATH or ORT_LIB_LOCATION to that directory, then re-run scripts/test_ort_provider_load.sh.",
790 provider
791 )))
792 }
793
794 #[cfg(target_os = "linux")]
795 fn dlopen_path(path: &Path, flags: i32) -> Result<()> {
796 let cpath = CString::new(path.to_string_lossy().as_bytes())
797 .map_err(|_| EngineError::ModelMetadata("Invalid dlopen path".into()))?;
798 let handle = unsafe { dlopen(cpath.as_ptr(), flags) };
800 if handle.is_null() {
801 let err = unsafe {
803 let p = dlerror();
804 if p.is_null() {
805 "unknown dlopen error".to_string()
806 } else {
807 CStr::from_ptr(p).to_string_lossy().to_string()
808 }
809 };
810 return Err(EngineError::ModelMetadata(format!(
811 "dlopen failed for {}: {err}",
812 path.display()
813 )));
814 }
815 Ok(())
816 }
817
818 #[cfg(target_os = "linux")]
819 fn preload_cuda_runtime_libs() {
820 let mut roots = Vec::<PathBuf>::new();
822 if let Some(dir) = env::var_os("RAVE_CUDA_RUNTIME_DIR") {
823 roots.push(PathBuf::from(dir));
824 }
825 roots.push(PathBuf::from("/usr/local/cuda-12/targets/x86_64-linux/lib"));
826 roots.push(PathBuf::from("/usr/local/cuda/lib64"));
827
828 let libs = [
829 "libcudart.so.12",
830 "libcublasLt.so.12",
831 "libcublas.so.12",
832 "libnvrtc.so.12",
833 "libcurand.so.10",
834 "libcufft.so.11",
835 ];
836
837 for root in roots {
838 if !root.is_dir() {
839 continue;
840 }
841 let mut loaded = 0usize;
842 for lib in libs {
843 let p = root.join(lib);
844 if !p.is_file() {
845 continue;
846 }
847 if Self::dlopen_path(&p, RTLD_NOW | RTLD_GLOBAL).is_ok() {
848 loaded += 1;
849 }
850 }
851 if loaded > 0 {
852 info!(root = %root.display(), loaded, "Preloaded CUDA runtime libraries for ORT");
853 break;
854 }
855 }
856 }
857
858 #[cfg(target_os = "linux")]
862 fn preload_ort_provider_pair(kind: OrtProviderKind) -> Result<()> {
863 let (dir, source) = Self::resolve_ort_provider_dir(kind)?;
864 let shared = dir.join("libonnxruntime_providers_shared.so");
865 let provider = dir.join(kind.soname());
866
867 Self::dlopen_path(&shared, RTLD_NOW | RTLD_GLOBAL).map_err(|e| {
868 EngineError::ModelMetadata(format!(
869 "Failed loading providers_shared from {} ({e}). \
870Ensure ORT_DYLIB_PATH/ORT_LIB_LOCATION points to a valid ORT cache dir.",
871 shared.display()
872 ))
873 })?;
874
875 info!(
876 path = %provider.display(),
877 provider = kind.label(),
878 "Skipping explicit provider dlopen; relying on startup LD_LIBRARY_PATH + ORT registration"
879 );
880
881 info!(
882 source,
883 dir = %dir.display(),
884 provider = kind.label(),
885 path = %provider.display(),
886 "ORT provider pair prepared (providers_shared preloaded, provider path configured)"
887 );
888 Ok(())
889 }
890
891 #[cfg(all(target_os = "linux", test))]
892 fn preload_ort_provider_bridge() -> Result<()> {
893 for path in Self::ort_provider_candidates("libonnxruntime_providers_shared.so") {
894 if path.is_file() {
895 return Self::dlopen_path(&path, RTLD_NOW | RTLD_GLOBAL);
896 }
897 }
898 Err(EngineError::ModelMetadata(
899 "Could not preload libonnxruntime_providers_shared.so from known search paths".into(),
900 ))
901 }
902
903 #[cfg(all(not(target_os = "linux"), test))]
904 fn preload_ort_provider_bridge() -> Result<()> {
905 Ok(())
906 }
907
    // No-op on non-Linux targets: CUDA runtime preloading via dlopen is a
    // Linux-only concern in this backend.
    #[cfg(not(target_os = "linux"))]
    fn preload_cuda_runtime_libs() {}
910
911 fn build_trt_session(&self) -> Result<Session> {
912 Self::preload_cuda_runtime_libs();
913 Self::preload_ort_provider_pair(OrtProviderKind::TensorRt)?;
914
915 let mut trt_ep = TensorRTExecutionProvider::default()
916 .with_device_id(self.device_id)
917 .with_engine_cache(true)
918 .with_engine_cache_path(
919 self.model_path
920 .parent()
921 .unwrap_or(&self.model_path)
922 .join("trt_cache")
923 .to_string_lossy()
924 .to_string(),
925 );
926
927 match &self.precision_policy {
928 PrecisionPolicy::Fp32 => {
929 info!("TRT precision: FP32 (no mixed precision)");
930 }
931 PrecisionPolicy::Fp16 => {
932 trt_ep = trt_ep.with_fp16(true);
933 info!("TRT precision: FP16 mixed precision");
934 }
935 PrecisionPolicy::Int8 { calibration_table } => {
936 trt_ep = trt_ep.with_fp16(true).with_int8(true);
937 info!(
938 table = %calibration_table.display(),
939 "TRT precision: INT8 with calibration table"
940 );
941 }
942 }
943
944 Session::builder()?
945 .with_execution_providers([trt_ep.build().error_on_failure()])?
946 .with_intra_threads(1)?
947 .commit_from_file(&self.model_path)
948 .map_err(Into::into)
949 }
950
951 fn build_cuda_session(&self) -> Result<Session> {
952 Self::preload_cuda_runtime_libs();
953 Self::preload_ort_provider_pair(OrtProviderKind::Cuda)?;
954 let cuda_ep = CUDAExecutionProvider::default().with_device_id(self.device_id);
955 Session::builder()?
956 .with_execution_providers([cuda_ep.build().error_on_failure()])?
957 .with_intra_threads(1)?
958 .commit_from_file(&self.model_path)
959 .map_err(Into::into)
960 }
961
    /// Debug audit that the raw pointers handed to ORT's IO binding are the
    /// exact device pointers of the input texture and the ring output slot
    /// (the zero-copy guarantee). The assertions compile out in release
    /// builds; the debug log always fires.
    ///
    /// NOTE(review): the only call site in this file passes `output_ptr` for
    /// both `output_ptr` and `ring_slot_ptr`, making the second assertion
    /// tautological — consider threading the actual slot pointer through.
    fn verify_pointer_identity(
        input_ptr: u64,
        output_ptr: u64,
        input_texture: &GpuTexture,
        ring_slot_ptr: u64,
    ) {
        let texture_ptr = input_texture.device_ptr();
        debug!(
            input_ptr = format!("0x{:016x}", input_ptr),
            texture_ptr = format!("0x{:016x}", texture_ptr),
            output_ptr = format!("0x{:016x}", output_ptr),
            ring_slot_ptr = format!("0x{:016x}", ring_slot_ptr),
            "IO-binding pointer identity audit"
        );

        debug_assert_eq!(
            input_ptr, texture_ptr,
            "POINTER MISMATCH: IO-bound input (0x{:016x}) != GpuTexture (0x{:016x})",
            input_ptr, texture_ptr,
        );
        debug_assert_eq!(
            output_ptr, ring_slot_ptr,
            "POINTER MISMATCH: IO-bound output (0x{:016x}) != ring slot (0x{:016x})",
            output_ptr, ring_slot_ptr,
        );
    }
993
    /// Runs one zero-copy inference: binds the input texture's device memory
    /// and the ring-slot output buffer directly into the session via ORT IO
    /// binding — no host round-trip and no device-to-device copy.
    ///
    /// `output_ptr`/`output_bytes` must describe a live device allocation big
    /// enough for the upscaled NCHW result; this is checked before binding.
    ///
    /// NOTE(review): `_ctx` and `_mem_info` are currently unused — the raw
    /// FFI path in `create_tensor_from_device_memory` builds its own
    /// MemoryInfo.
    fn run_io_bound(
        session: &mut Session,
        meta: &ModelMetadata,
        input: &GpuTexture,
        output_ptr: u64,
        output_bytes: usize,
        _ctx: &GpuContext,
        _mem_info: &ort::memory::MemoryInfo,
    ) -> Result<()> {
        let in_w = input.width as i64;
        let in_h = input.height as i64;
        let out_w = in_w * meta.scale as i64;
        let out_h = in_h * meta.scale as i64;

        // NCHW with batch fixed at 1; channel count comes from the model.
        let input_shape: Vec<i64> = vec![1, meta.input_channels as i64, in_h, in_w];
        let output_shape: Vec<i64> = vec![1, meta.input_channels as i64, out_h, out_w];

        let elem_type = ort_element_type(input.format);
        let elem_bytes = input.format.element_bytes();
        let input_bytes = input.data.len();

        // Reject an undersized output buffer before handing pointers to ORT.
        let expected = (output_shape.iter().product::<i64>() as usize) * elem_bytes;
        if output_bytes < expected {
            return Err(EngineError::BufferTooSmall {
                need: expected,
                have: output_bytes,
            });
        }

        let input_ptr = input.device_ptr();

        // Debug-only audit; output_ptr is passed as both the bound output and
        // the ring-slot pointer here (see verify_pointer_identity).
        Self::verify_pointer_identity(input_ptr, output_ptr, input, output_ptr);

        let mut binding = session.create_binding()?;

        // SAFETY: both pointers reference live device allocations (the input
        // texture is borrowed for this call; the output slot is held by the
        // ring), and the shapes/byte counts were validated above. The tensors
        // only borrow the memory and do not outlive this call.
        unsafe {
            let input_tensor = create_tensor_from_device_memory(
                input_ptr as *mut _,
                input_bytes,
                &input_shape,
                elem_type,
            )?;

            binding.bind_input(&meta.input_name, &input_tensor)?;

            let output_tensor = create_tensor_from_device_memory(
                output_ptr as *mut _,
                output_bytes,
                &output_shape,
                elem_type,
            )?;

            binding.bind_output(&meta.output_name, output_tensor)?;
        }

        session.run_binding(&binding)?;

        Ok(())
    }
1062}
1063
#[cfg(all(test, target_os = "linux"))]
mod tests {
    use super::{RTLD_LOCAL, RTLD_NOW, TensorRtBackend};
    use rave_core::backend::UpscaleBackend;
    use std::env;
    use std::path::PathBuf;

    /// Verifies providers_tensorrt can be dlopen'ed once providers_shared has
    /// been preloaded — the symbol-resolution order the backend relies on.
    /// Ignored by default: needs the ORT TensorRT provider libs on the host.
    #[test]
    #[ignore = "requires ORT TensorRT provider libs on host"]
    fn providers_load_with_bridge_preloaded() {
        TensorRtBackend::preload_ort_provider_bridge().expect("failed to preload providers_shared");
        let trt_candidates =
            TensorRtBackend::ort_provider_candidates("libonnxruntime_providers_tensorrt.so");
        assert!(
            !trt_candidates.is_empty(),
            "no libonnxruntime_providers_tensorrt.so candidates found"
        );
        // RTLD_LOCAL: this only checks loadability, not symbol export.
        let path = trt_candidates[0].clone();
        TensorRtBackend::dlopen_path(&path, RTLD_NOW | RTLD_LOCAL)
            .expect("providers_tensorrt should dlopen after providers_shared preload");
    }

    /// End-to-end smoke test: backend initialization must register the
    /// TensorRT EP. Needs RAVE_TEST_ONNX_MODEL and a working CUDA/TRT stack.
    #[test]
    #[ignore = "requires model + full ORT/TensorRT runtime"]
    fn ort_registers_tensorrt_ep_smoke() {
        let model = env::var("RAVE_TEST_ONNX_MODEL").expect("set RAVE_TEST_ONNX_MODEL");
        let backend = TensorRtBackend::new(
            PathBuf::from(model),
            rave_core::context::GpuContext::new(0).expect("cuda ctx"),
            0,
            6,
            4,
        );
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("tokio runtime");
        rt.block_on(async move { backend.initialize().await })
            .expect("TensorRT EP registration should succeed");
    }
}
1106
1107#[async_trait]
1108impl UpscaleBackend for TensorRtBackend {
    /// Loads the ONNX model, selecting the execution provider per the
    /// `RAVE_ORT_TENSORRT` policy. In `Auto` mode a TensorRT failure is
    /// downgraded to a warning and the CUDA EP is used instead. Metadata is
    /// extracted once and cached in `self.meta`; the output ring is created
    /// lazily by `process` once the first frame's size is known.
    ///
    /// # Errors
    /// `ModelMetadata("Already initialized")` on a second call, plus any
    /// session-construction or metadata-extraction failure.
    async fn initialize(&self) -> Result<()> {
        let mut guard = self.state.lock().await;
        if guard.is_some() {
            return Err(EngineError::ModelMetadata("Already initialized".into()));
        }

        let ep_mode = Self::ort_ep_mode();
        let on_wsl = Self::is_wsl2();
        info!(
            path = %self.model_path.display(),
            ?ep_mode,
            on_wsl,
            "Loading ONNX model with ORT execution provider policy"
        );

        let (session, active_provider) = match ep_mode {
            OrtEpMode::CudaOnly => {
                info!("RAVE_ORT_TENSORRT disables TensorRT EP; using CUDAExecutionProvider");
                (self.build_cuda_session()?, "CUDAExecutionProvider")
            }
            OrtEpMode::TensorRtOnly => (self.build_trt_session()?, "TensorrtExecutionProvider"),
            // Auto: attempt TRT, fall back to CUDA on any build error.
            OrtEpMode::Auto => match self.build_trt_session() {
                Ok(session) => (session, "TensorrtExecutionProvider"),
                Err(e) => {
                    warn!(
                        error = %e,
                        on_wsl,
                        "TensorRT EP registration failed; falling back to CUDAExecutionProvider"
                    );
                    (self.build_cuda_session()?, "CUDAExecutionProvider")
                }
            },
        };
        info!(
            provider = active_provider,
            "ORT execution provider selected"
        );

        let metadata = Self::extract_metadata(&session)?;
        info!(
            name = %metadata.name,
            scale = metadata.scale,
            input = %metadata.input_name,
            output = %metadata.output_name,
            ring_size = self.ring_size,
            min_ring_slots = self.min_ring_slots,
            provider = active_provider,
            precision = ?self.precision_policy,
            max_batch = self.batch_config.max_batch,
            "Model loaded"
        );

        // `set` fails only if metadata was somehow already present; the
        // result is intentionally ignored.
        let _ = self.meta.set(metadata);

        *guard = Some(InferenceState {
            session,
            ring: None,
        });

        Ok(())
    }
1170
    /// Upscales one frame entirely on-device.
    ///
    /// Flow: validate pixel format → (re)build the output ring if the input
    /// resolution changed → acquire a free ring slot → run zero-copy IO-bound
    /// inference into it → return the slot wrapped in a `GpuTexture`.
    ///
    /// # Errors
    /// - `FormatMismatch` for non planar-RGB-float input.
    /// - `NotInitialized` before `initialize` has completed.
    /// - `BufferTooSmall` when the next ring slot is still held downstream.
    async fn process(&self, input: GpuTexture) -> Result<GpuTexture> {
        // Only planar RGB float formats can be IO-bound directly.
        match input.format {
            PixelFormat::RgbPlanarF32 | PixelFormat::RgbPlanarF16 => {}
            other => {
                return Err(EngineError::FormatMismatch {
                    expected: PixelFormat::RgbPlanarF32,
                    actual: other,
                });
            }
        }

        let meta = self.meta.get().ok_or(EngineError::NotInitialized)?;
        let mut guard = self.state.lock().await;
        let state = guard.as_mut().ok_or(EngineError::NotInitialized)?;

        // Ring lifecycle: reallocate on resolution change, create lazily on
        // first use, otherwise reuse as-is.
        match &mut state.ring {
            Some(ring) if ring.needs_realloc(input.width, input.height) => {
                debug!(
                    old_w = ring.alloc_dims.0,
                    old_h = ring.alloc_dims.1,
                    new_w = input.width,
                    new_h = input.height,
                    "Reallocating output ring"
                );
                ring.reallocate(&self.ctx, input.width, input.height, meta.scale)?;
            }
            None => {
                debug!(
                    w = input.width,
                    h = input.height,
                    slots = self.ring_size,
                    "Lazily creating output ring"
                );
                state.ring = Some(OutputRing::new(
                    &self.ctx,
                    input.width,
                    input.height,
                    meta.scale,
                    self.ring_size,
                    self.min_ring_slots,
                )?);
            }
            Some(_) => {}
        }

        // Guaranteed Some by the match above.
        let ring = state.ring.as_mut().unwrap();

        // With the feature enabled, assert the hot path makes no host allocations.
        #[cfg(feature = "debug-alloc")]
        {
            rave_core::debug_alloc::reset();
            rave_core::debug_alloc::enable();
        }

        let output_arc = ring.acquire()?;
        let output_ptr = *(*output_arc).device_ptr();
        let output_bytes = ring.slot_bytes;

        let t_start = std::time::Instant::now();

        let mem_info = self.mem_info()?;
        Self::run_io_bound(
            &mut state.session,
            meta,
            &input,
            output_ptr,
            output_bytes,
            &self.ctx,
            &mem_info,
        )?;

        // Host-side wall time: binding setup plus the session run.
        let elapsed_us = t_start.elapsed().as_micros() as u64;
        self.inference_metrics.record(elapsed_us);

        #[cfg(feature = "debug-alloc")]
        {
            rave_core::debug_alloc::disable();
            let host_allocs = rave_core::debug_alloc::count();
            debug_assert_eq!(
                host_allocs, 0,
                "VIOLATION: {host_allocs} host allocations during inference"
            );
        }

        // The returned texture shares the ring slot; the ring treats the slot
        // as busy (strong_count > 1) until the caller drops this texture.
        let out_w = input.width * meta.scale;
        let out_h = input.height * meta.scale;
        let elem_bytes = input.format.element_bytes();

        Ok(GpuTexture {
            data: output_arc,
            width: out_w,
            height: out_h,
            pitch: (out_w as usize) * elem_bytes,
            format: input.format,
        })
    }
1270
    /// Tears down the backend: syncs the GPU, logs final ring / inference /
    /// VRAM statistics, then drops the ring and session. Idempotent — a
    /// second call finds no state and returns `Ok`.
    async fn shutdown(&self) -> Result<()> {
        let mut guard = self.state.lock().await;
        if let Some(state) = guard.take() {
            info!("Shutting down TensorRT backend");
            // Ensure no in-flight GPU work still touches ring memory.
            self.ctx.sync_all()?;

            if let Some(ring) = &state.ring {
                let (reuse, contention, first) = ring.metrics.snapshot();
                info!(reuse, contention, first, "Final ring metrics");
            }

            let snap = self.inference_metrics.snapshot();
            info!(
                frames = snap.frames_inferred,
                avg_us = snap.avg_inference_us,
                peak_us = snap.peak_inference_us,
                precision = ?self.precision_policy,
                "Final inference metrics"
            );

            let (current, peak) = self.ctx.vram_usage();
            info!(
                current_mb = current / (1024 * 1024),
                peak_mb = peak / (1024 * 1024),
                "Final VRAM usage"
            );

            // Explicit drop order: ring (device buffers) before the session.
            drop(state.ring);
            drop(state.session);
            debug!("TensorRT backend shutdown complete");
        }
        Ok(())
    }
1307
1308 fn metadata(&self) -> Result<&ModelMetadata> {
1309 self.meta.get().ok_or(EngineError::NotInitialized)
1310 }
1311}
1312
1313impl Drop for TensorRtBackend {
1314 fn drop(&mut self) {
1315 if let Ok(mut guard) = self.state.try_lock()
1316 && let Some(state) = guard.take()
1317 {
1318 let _ = self.ctx.sync_all();
1319 drop(state);
1320 }
1321 }
1322}