#[cfg(target_os = "linux")]
use std::collections::HashSet;
use std::env;
#[cfg(target_os = "linux")]
use std::ffi::{CStr, CString, c_char, c_void};
#[cfg(target_os = "linux")]
use std::path::Path;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock};
#[cfg(target_os = "linux")]
use std::time::UNIX_EPOCH;

use async_trait::async_trait;
use cudarc::driver::{DevicePtr, DeviceSlice};
use tokio::sync::Mutex;
use tracing::{debug, info, warn};

use ort::session::Session;
use ort::sys as ort_sys;
use ort::value::Value as OrtValue;

use rave_core::backend::{ModelMetadata, UpscaleBackend};
use rave_core::context::GpuContext;
use rave_core::error::{EngineError, Result};
use rave_core::types::{GpuTexture, PixelFormat};

use ort::execution_providers::{CUDAExecutionProvider, TensorRTExecutionProvider};

unsafe fn create_tensor_from_device_memory(
    ptr: *mut std::ffi::c_void,
    bytes: usize,
    shape: &[i64],
    elem_type: ort::tensor::TensorElementType,
) -> Result<OrtValue> {
    let api = ort::api();

    let mut mem_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut();
    let name = std::ffi::CString::new("Cuda").unwrap();
    let status = unsafe {
        (api.CreateMemoryInfo)(
            name.as_ptr(),
            ort_sys::OrtAllocatorType::OrtArenaAllocator,
            0,
            ort_sys::OrtMemType::OrtMemTypeDefault,
            &mut mem_info_ptr,
        )
    };
    if !status.0.is_null() {
        unsafe { (api.ReleaseStatus)(status.0) };
        return Err(EngineError::Inference(ort::Error::new(
            "Failed to create MemoryInfo",
        )));
    }

    let mut ort_value_ptr: *mut ort_sys::OrtValue = std::ptr::null_mut();
    let status = unsafe {
        (api.CreateTensorWithDataAsOrtValue)(
            mem_info_ptr,
            ptr,
            bytes as _,
            shape.as_ptr(),
            shape.len() as _,
            elem_type.into(),
            &mut ort_value_ptr,
        )
    };

    unsafe { (api.ReleaseMemoryInfo)(mem_info_ptr) };

    if !status.0.is_null() {
        unsafe { (api.ReleaseStatus)(status.0) };
        return Err(EngineError::Inference(ort::Error::new(
            "Failed to create Tensor",
        )));
    }

    Ok(unsafe {
        ort::value::Value::<ort::value::DynValueTypeMarker>::from_ptr(
            std::ptr::NonNull::new(ort_value_ptr).unwrap(),
            None,
        )
    })
}
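
// Illustrative sketch of how the helper above is intended to be called: wrap an
// existing CUDA allocation (here a hypothetical `dev_ptr` holding one 1x3xHxW
// FP32 frame) as an ORT value without copying. The caller must keep the device
// memory alive for as long as the returned value is bound to a session.
//
//     let value = unsafe {
//         create_tensor_from_device_memory(
//             dev_ptr as *mut std::ffi::c_void,
//             3 * 1080 * 1920 * std::mem::size_of::<f32>(),
//             &[1, 3, 1080, 1920],
//             ort::tensor::TensorElementType::Float32,
//         )?
//     };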

#[derive(Clone, Debug, Default)]
pub enum PrecisionPolicy {
    Fp32,
    #[default]
    Fp16,
    Int8 { calibration_table: PathBuf },
}
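
// Illustrative precision selections (a sketch; the calibration table path is a
// hypothetical placeholder). `Fp16` is the default, and `build_trt_session`
// maps it to `with_fp16(true)` on the TensorRT execution provider.
//
//     let fast = PrecisionPolicy::default(); // Fp16
//     let exact = PrecisionPolicy::Fp32;
//     let quantized = PrecisionPolicy::Int8 {
//         calibration_table: PathBuf::from("calib/int8_table.flatbuffers"),
//     };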

#[derive(Clone, Debug)]
pub struct BatchConfig {
    pub max_batch: usize,
    pub latency_deadline_us: u64,
}

impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            max_batch: 1,
            latency_deadline_us: 8_000,
        }
    }
}

pub fn validate_batch_config(cfg: &BatchConfig) -> Result<()> {
    if cfg.max_batch > 1 {
        return Err(EngineError::InvariantViolation(
            "micro-batching is not implemented; max_batch must be 1 (set max_batch=1)".into(),
        ));
    }
    Ok(())
}

#[derive(Debug)]
pub struct InferenceMetrics {
    pub frames_inferred: AtomicU64,
    pub total_inference_us: AtomicU64,
    pub peak_inference_us: AtomicU64,
}

impl InferenceMetrics {
    pub const fn new() -> Self {
        Self {
            frames_inferred: AtomicU64::new(0),
            total_inference_us: AtomicU64::new(0),
            peak_inference_us: AtomicU64::new(0),
        }
    }

    pub fn record(&self, elapsed_us: u64) {
        self.frames_inferred.fetch_add(1, Ordering::Relaxed);
        self.total_inference_us
            .fetch_add(elapsed_us, Ordering::Relaxed);
        self.peak_inference_us
            .fetch_max(elapsed_us, Ordering::Relaxed);
    }

    pub fn snapshot(&self) -> InferenceMetricsSnapshot {
        let frames = self.frames_inferred.load(Ordering::Relaxed);
        let total = self.total_inference_us.load(Ordering::Relaxed);
        let peak = self.peak_inference_us.load(Ordering::Relaxed);
        InferenceMetricsSnapshot {
            frames_inferred: frames,
            avg_inference_us: if frames > 0 { total / frames } else { 0 },
            peak_inference_us: peak,
        }
    }
}

impl Default for InferenceMetrics {
    fn default() -> Self {
        Self::new()
    }
}
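
// Illustrative use of the counters above (a sketch; in practice `elapsed_us`
// comes from timing a real inference call, as `process` does further down):
//
//     let metrics = InferenceMetrics::new();
//     metrics.record(7_250);
//     metrics.record(8_900);
//     let snap = metrics.snapshot();
//     assert_eq!(snap.frames_inferred, 2);
//     assert_eq!(snap.avg_inference_us, 8_075);
//     assert_eq!(snap.peak_inference_us, 8_900);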

#[derive(Clone, Debug)]
pub struct InferenceMetricsSnapshot {
    pub frames_inferred: u64,
    pub avg_inference_us: u64,
    pub peak_inference_us: u64,
}

#[derive(Debug, Clone, Copy)]
pub struct RingMetricsSnapshot {
    pub reuse: u64,
    pub contention: u64,
    pub first_use: u64,
}

#[derive(Debug)]
pub struct RingMetrics {
    pub slot_reuse_count: AtomicU64,
    pub slot_contention_events: AtomicU64,
    pub slot_first_use_count: AtomicU64,
}

impl RingMetrics {
    pub const fn new() -> Self {
        Self {
            slot_reuse_count: AtomicU64::new(0),
            slot_contention_events: AtomicU64::new(0),
            slot_first_use_count: AtomicU64::new(0),
        }
    }

    pub fn snapshot(&self) -> RingMetricsSnapshot {
        RingMetricsSnapshot {
            reuse: self.slot_reuse_count.load(Ordering::Relaxed),
            contention: self.slot_contention_events.load(Ordering::Relaxed),
            first_use: self.slot_first_use_count.load(Ordering::Relaxed),
        }
    }
}

impl Default for RingMetrics {
    fn default() -> Self {
        Self::new()
    }
}

pub struct OutputRing {
    slots: Vec<Arc<cudarc::driver::CudaSlice<u8>>>,
    cursor: usize,
    pub slot_bytes: usize,
    pub alloc_dims: (u32, u32),
    used: Vec<bool>,
    pub metrics: RingMetrics,
}

impl OutputRing {
    pub fn new(
        ctx: &GpuContext,
        in_w: u32,
        in_h: u32,
        scale: u32,
        count: usize,
        min_slots: usize,
    ) -> Result<Self> {
        if count < min_slots {
            return Err(EngineError::DimensionMismatch(format!(
                "OutputRing: ring_size ({count}) < required minimum ({min_slots}). \
                 Ring must be ≥ downstream_channel_capacity + 2."
            )));
        }
        if count < 2 {
            return Err(EngineError::DimensionMismatch(
                "OutputRing: ring_size must be ≥ 2 for double-buffering".into(),
            ));
        }

        let out_w = (in_w * scale) as usize;
        let out_h = (in_h * scale) as usize;
        let slot_bytes = 3 * out_w * out_h * std::mem::size_of::<f32>();

        let slots = (0..count)
            .map(|_| ctx.alloc(slot_bytes).map(Arc::new))
            .collect::<Result<Vec<_>>>()?;

        debug!(count, slot_bytes, out_w, out_h, "Output ring allocated");

        Ok(Self {
            slots,
            cursor: 0,
            slot_bytes,
            alloc_dims: (in_w, in_h),
            used: vec![false; count],
            metrics: RingMetrics::new(),
        })
    }

    pub fn acquire(&mut self) -> Result<Arc<cudarc::driver::CudaSlice<u8>>> {
        let slot = &self.slots[self.cursor];
        let sc = Arc::strong_count(slot);

        if sc != 1 {
            self.metrics
                .slot_contention_events
                .fetch_add(1, Ordering::Relaxed);
            return Err(EngineError::BufferTooSmall {
                need: self.slot_bytes,
                have: 0,
            });
        }

        debug_assert_eq!(
            Arc::strong_count(slot),
            1,
            "OutputRing: slot {} strong_count must be 1 before reuse, got {}",
            self.cursor,
            sc
        );

        if self.used[self.cursor] {
            self.metrics
                .slot_reuse_count
                .fetch_add(1, Ordering::Relaxed);
        } else {
            self.used[self.cursor] = true;
            self.metrics
                .slot_first_use_count
                .fetch_add(1, Ordering::Relaxed);
        }

        let cloned = Arc::clone(slot);
        self.cursor = (self.cursor + 1) % self.slots.len();
        Ok(cloned)
    }

    pub fn needs_realloc(&self, in_w: u32, in_h: u32) -> bool {
        self.alloc_dims != (in_w, in_h)
    }

    pub fn reallocate(&mut self, ctx: &GpuContext, in_w: u32, in_h: u32, scale: u32) -> Result<()> {
        for (i, slot) in self.slots.iter().enumerate() {
            let sc = Arc::strong_count(slot);
            if sc != 1 {
                return Err(EngineError::DimensionMismatch(format!(
                    "Cannot reallocate ring: slot {} still in use (strong_count={})",
                    i, sc,
                )));
            }
        }

        for _slot in &self.slots {
            ctx.vram_dec(self.slot_bytes);
        }

        let count = self.slots.len();
        let out_w = (in_w * scale) as usize;
        let out_h = (in_h * scale) as usize;
        let slot_bytes = 3 * out_w * out_h * std::mem::size_of::<f32>();

        self.slots = (0..count)
            .map(|_| ctx.alloc(slot_bytes).map(Arc::new))
            .collect::<Result<Vec<_>>>()?;
        self.cursor = 0;
        self.slot_bytes = slot_bytes;
        self.alloc_dims = (in_w, in_h);
        self.used = vec![false; count];

        debug!(count, slot_bytes, out_w, out_h, "Output ring reallocated");
        Ok(())
    }

    pub fn len(&self) -> usize {
        self.slots.len()
    }

    pub fn is_empty(&self) -> bool {
        self.slots.is_empty()
    }
}
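
// Illustrative ring lifecycle (a sketch; `ctx` is an existing `GpuContext` and
// the dimensions are hypothetical). Slots are handed out as `Arc` clones, and a
// slot can only be handed out again once every previous holder has dropped its
// clone by the time the cursor wraps back around; otherwise `acquire` records a
// contention event and fails. This is why the ring must hold at least
// `downstream_channel_capacity + 2` slots.
//
//     let mut ring = OutputRing::new(&ctx, 1920, 1080, 2, 6, 4 + 2)?;
//     let slot = ring.acquire()?;   // Arc strong_count becomes 2 while in flight
//     // ... bind `slot` as the IO-binding output, send it downstream ...
//     drop(slot);                   // must happen before the cursor wraps back here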

struct InferenceState {
    session: Session,
    ring: Option<OutputRing>,
}

fn ort_element_type(format: PixelFormat) -> ort::tensor::TensorElementType {
    match format {
        PixelFormat::RgbPlanarF16 => ort::tensor::TensorElementType::Float16,
        _ => ort::tensor::TensorElementType::Float32,
    }
}

pub struct TensorRtBackend {
    model_path: PathBuf,
    ctx: Arc<GpuContext>,
    device_id: i32,
    ring_size: usize,
    min_ring_slots: usize,
    meta: OnceLock<ModelMetadata>,
    selected_provider: OnceLock<String>,
    state: Mutex<Option<InferenceState>>,
    pub inference_metrics: InferenceMetrics,
    pub precision_policy: PrecisionPolicy,
    pub batch_config: BatchConfig,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum OrtEpMode {
    Auto,
    TensorRtOnly,
    CudaOnly,
}

#[cfg(target_os = "linux")]
unsafe extern "C" {
    fn dlopen(filename: *const c_char, flags: i32) -> *mut c_void;
    fn dlerror() -> *const c_char;
}

#[cfg(target_os = "linux")]
const RTLD_NOW: i32 = 2;
#[cfg(all(target_os = "linux", test))]
const RTLD_LOCAL: i32 = 0;
#[cfg(target_os = "linux")]
const RTLD_GLOBAL: i32 = 0x100;

#[derive(Clone, Copy)]
enum OrtProviderKind {
    Cuda,
    TensorRt,
}

#[cfg(target_os = "linux")]
type ProviderCandidate = (PathBuf, &'static str);

impl OrtProviderKind {
    #[cfg(target_os = "linux")]
    fn soname(self) -> &'static str {
        match self {
            OrtProviderKind::Cuda => "libonnxruntime_providers_cuda.so",
            OrtProviderKind::TensorRt => "libonnxruntime_providers_tensorrt.so",
        }
    }

    #[cfg(target_os = "linux")]
    fn label(self) -> &'static str {
        match self {
            OrtProviderKind::Cuda => "providers_cuda",
            OrtProviderKind::TensorRt => "providers_tensorrt",
        }
    }
}

impl TensorRtBackend {
    pub fn new(
        model_path: PathBuf,
        ctx: Arc<GpuContext>,
        device_id: i32,
        ring_size: usize,
        downstream_capacity: usize,
    ) -> Self {
        Self::with_precision(
            model_path,
            ctx,
            device_id,
            ring_size,
            downstream_capacity,
            PrecisionPolicy::default(),
            BatchConfig::default(),
        )
    }

    pub fn with_precision(
        model_path: PathBuf,
        ctx: Arc<GpuContext>,
        device_id: i32,
        ring_size: usize,
        downstream_capacity: usize,
        precision_policy: PrecisionPolicy,
        batch_config: BatchConfig,
    ) -> Self {
        let min_ring_slots = downstream_capacity + 2;
        assert!(
            ring_size >= min_ring_slots,
            "ring_size ({ring_size}) must be ≥ downstream_capacity + 2 ({min_ring_slots})"
        );
        Self {
            model_path,
            ctx,
            device_id,
            ring_size,
            min_ring_slots,
            meta: OnceLock::new(),
            selected_provider: OnceLock::new(),
            state: Mutex::new(None),
            inference_metrics: InferenceMetrics::new(),
            precision_policy,
            batch_config,
        }
    }

    fn mem_info(&self) -> Result<ort::memory::MemoryInfo> {
        ort::memory::MemoryInfo::new(
            ort::memory::AllocationDevice::CUDA,
            0,
            ort::memory::AllocatorType::Device,
            ort::memory::MemoryType::Default,
        )
        .map_err(|e| EngineError::ModelMetadata(format!("MemoryInfo: {e}")))
    }

    pub async fn ring_metrics(&self) -> Option<RingMetricsSnapshot> {
        let guard = self.state.lock().await;
        guard
            .as_ref()
            .and_then(|s| s.ring.as_ref())
            .map(|r| r.metrics.snapshot())
    }

    pub fn selected_provider(&self) -> Option<&str> {
        self.selected_provider.get().map(String::as_str)
    }

    fn infer_scale_from_name(name: &str) -> Option<u32> {
        let lower = name.to_ascii_lowercase();
        let bytes = lower.as_bytes();
        for i in 0..bytes.len() {
            if bytes[i] != b'x' {
                continue;
            }
            if i > 0 && bytes[i - 1].is_ascii_digit() {
                let d = (bytes[i - 1] - b'0') as u32;
                if d >= 2 {
                    return Some(d);
                }
            }
            if i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
                let d = (bytes[i + 1] - b'0') as u32;
                if d >= 2 {
                    return Some(d);
                }
            }
        }
        None
    }
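
    // Illustrative behavior of the heuristic above (hypothetical model names):
    //
    //     assert_eq!(Self::infer_scale_from_name("RealESRGAN_x4plus"), Some(4));
    //     assert_eq!(Self::infer_scale_from_name("2x-AnimeSharp"), Some(2));
    //     assert_eq!(Self::infer_scale_from_name("denoise_only"), None);
    //
    // Only a single digit of 2 or more is recognized, on either side of the
    // literal 'x'.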

    fn extract_metadata(session: &Session) -> Result<ModelMetadata> {
        let inputs = session.inputs();
        let outputs = session.outputs();

        if inputs.is_empty() || outputs.is_empty() {
            return Err(EngineError::ModelMetadata(
                "Model must have at least one input and one output tensor".into(),
            ));
        }

        let input_info = &inputs[0];
        let output_info = &outputs[0];
        let input_name = input_info.name().to_string();
        let output_name = output_info.name().to_string();

        let input_dims = match input_info.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.clone(),
            other => {
                return Err(EngineError::ModelMetadata(format!(
                    "Expected tensor input, got {:?}",
                    other
                )));
            }
        };

        let output_dims = match output_info.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.clone(),
            other => {
                return Err(EngineError::ModelMetadata(format!(
                    "Expected tensor output, got {:?}",
                    other
                )));
            }
        };

        if input_dims.len() != 4 || output_dims.len() != 4 {
            return Err(EngineError::ModelMetadata(format!(
                "Expected 4D tensors (NCHW), got input={}D output={}D",
                input_dims.len(),
                output_dims.len()
            )));
        }

        let input_channels = input_dims[1] as u32;

        let ih = input_dims[2];
        let iw = input_dims[3];
        let oh = output_dims[2];
        let ow = output_dims[3];

        let name = session
            .metadata()
            .map(|m| m.name().unwrap_or("unknown".to_string()))
            .unwrap_or_else(|_| "unknown".to_string());

        let scale = if ih > 0 && oh > 0 && iw > 0 && ow > 0 {
            (oh / ih) as u32
        } else if let Some(inferred) = Self::infer_scale_from_name(&name) {
            warn!(
                model_name = %name,
                inferred_scale = inferred,
                "Dynamic spatial axes — inferring upscale factor from model name"
            );
            inferred
        } else {
            warn!(
                model_name = %name,
                "Dynamic spatial axes — unable to infer scale from metadata; defaulting to scale=2"
            );
            2
        };

        let min_input_hw = (
            if ih > 0 { ih as u32 } else { 1 },
            if iw > 0 { iw as u32 } else { 1 },
        );

        let max_input_hw = (
            if ih > 0 { ih as u32 } else { u32::MAX },
            if iw > 0 { iw as u32 } else { u32::MAX },
        );

        Ok(ModelMetadata {
            name,
            scale,
            input_name,
            output_name,
            input_channels,
            min_input_hw,
            max_input_hw,
        })
    }

    fn ort_ep_mode() -> OrtEpMode {
        match env::var("RAVE_ORT_TENSORRT")
            .unwrap_or_else(|_| "auto".to_string())
            .to_lowercase()
            .as_str()
        {
            "0" | "off" | "false" | "cuda" | "cuda-only" => OrtEpMode::CudaOnly,
            "1" | "on" | "true" | "trt" | "trt-only" => OrtEpMode::TensorRtOnly,
            _ => OrtEpMode::Auto,
        }
    }
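
    // Hypothetical invocations showing how the override is interpreted
    // (values are matched case-insensitively; anything unrecognized means Auto):
    //
    //     RAVE_ORT_TENSORRT=cuda <app>   # CudaOnly: skip TensorRT entirely
    //     RAVE_ORT_TENSORRT=trt  <app>   # TensorRtOnly: fail rather than fall back
    //     RAVE_ORT_TENSORRT=auto <app>   # Auto: try TensorRT, fall back to CUDA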

    #[cfg(target_os = "linux")]
    fn is_wsl2() -> bool {
        std::fs::read_to_string("/proc/sys/kernel/osrelease")
            .map(|s| s.to_ascii_lowercase().contains("microsoft"))
            .unwrap_or(false)
    }

    #[cfg(not(target_os = "linux"))]
    fn is_wsl2() -> bool {
        false
    }

    #[cfg(target_os = "linux")]
    fn ort_provider_cache_dirs_newest_first() -> Vec<PathBuf> {
        let mut dirs = Vec::<(u128, PathBuf)>::new();
        let Some(home) = env::var_os("HOME") else {
            return Vec::new();
        };
        let base = PathBuf::from(home).join(".cache/ort.pyke.io/dfbin");
        let Ok(triples) = std::fs::read_dir(base) else {
            return Vec::new();
        };

        for triple in triples.flatten() {
            let triple_path = triple.path();
            if !triple_path.is_dir() {
                continue;
            }
            let Ok(hashes) = std::fs::read_dir(triple_path) else {
                continue;
            };
            for hash in hashes.flatten() {
                let path = hash.path();
                if !path.is_dir() {
                    continue;
                }
                let modified = std::fs::metadata(&path)
                    .and_then(|m| m.modified())
                    .ok()
                    .and_then(|t| t.duration_since(UNIX_EPOCH).ok())
                    .map(|d| d.as_nanos())
                    .unwrap_or(0);
                dirs.push((modified, path));
            }
        }

        dirs.sort_by(|a, b| b.cmp(a));
        dirs.into_iter().map(|(_, path)| path).collect()
    }

    #[cfg(target_os = "linux")]
    fn provider_dir_candidates(
        ort_dylib_path: Option<PathBuf>,
        ort_lib_location: Option<PathBuf>,
        exe_dir: Option<PathBuf>,
        cache_dirs: Vec<PathBuf>,
        on_wsl: bool,
    ) -> Vec<ProviderCandidate> {
        let mut dirs = Vec::<ProviderCandidate>::new();
        if let Some(dir) = ort_dylib_path {
            dirs.push((dir, "ORT_DYLIB_PATH"));
        }
        if let Some(dir) = ort_lib_location {
            dirs.push((dir, "ORT_LIB_LOCATION"));
        }
        if on_wsl {
            for dir in cache_dirs {
                dirs.push((dir, "ort_cache_newest"));
            }
            if let Some(dir) = exe_dir {
                dirs.push((dir.clone(), "exe_dir"));
                dirs.push((dir.join("deps"), "exe_dir/deps"));
            }
        } else {
            if let Some(dir) = exe_dir {
                dirs.push((dir.clone(), "exe_dir"));
                dirs.push((dir.join("deps"), "exe_dir/deps"));
            }
            for dir in cache_dirs {
                dirs.push((dir, "ort_cache_newest"));
            }
        }

        let mut uniq = HashSet::<PathBuf>::new();
        dirs.retain(|(p, _)| uniq.insert(p.clone()));
        dirs
    }

    #[cfg(target_os = "linux")]
    fn ort_provider_search_inputs() -> (
        Option<PathBuf>,
        Option<PathBuf>,
        Option<PathBuf>,
        Vec<PathBuf>,
        bool,
    ) {
        let ort_dylib_path = env::var_os("ORT_DYLIB_PATH").map(PathBuf::from);
        let ort_lib_location = env::var_os("ORT_LIB_LOCATION").map(PathBuf::from);
        let exe_dir = env::current_exe()
            .ok()
            .and_then(|exe| exe.parent().map(|dir| dir.to_path_buf()));
        let cache_dirs = Self::ort_provider_cache_dirs_newest_first();
        let on_wsl = Self::is_wsl2();
        (
            ort_dylib_path,
            ort_lib_location,
            exe_dir,
            cache_dirs,
            on_wsl,
        )
    }

    #[cfg(all(target_os = "linux", test))]
    fn ort_provider_search_dirs() -> Vec<ProviderCandidate> {
        let (ort_dylib_path, ort_lib_location, exe_dir, cache_dirs, on_wsl) =
            Self::ort_provider_search_inputs();
        Self::provider_dir_candidates(
            ort_dylib_path,
            ort_lib_location,
            exe_dir,
            cache_dirs,
            on_wsl,
        )
    }

    #[cfg(target_os = "linux")]
    fn render_provider_candidates(candidates: &[ProviderCandidate]) -> String {
        if candidates.is_empty() {
            return "none".into();
        }
        candidates
            .iter()
            .enumerate()
            .map(|(idx, (dir, source))| format!("{}:{source}:{}", idx + 1, dir.display()))
            .collect::<Vec<_>>()
            .join(", ")
    }

    #[cfg(all(target_os = "linux", test))]
    fn ort_provider_candidates(lib_name: &str) -> Vec<PathBuf> {
        Self::ort_provider_search_dirs()
            .into_iter()
            .map(|(dir, _)| dir.join(lib_name))
            .collect()
    }

    #[cfg(target_os = "linux")]
    fn configure_ort_loader_path(dir: &Path) {
        unsafe { env::set_var("ORT_DYLIB_PATH", dir) };
        unsafe { env::set_var("ORT_LIB_LOCATION", dir) };
    }

    #[cfg(target_os = "linux")]
    fn resolve_provider_dir(
        kind: OrtProviderKind,
        candidates: &[ProviderCandidate],
        ort_dylib_path: Option<&Path>,
        ort_lib_location: Option<&Path>,
        command: &str,
        mode: &str,
        ep_mode: &str,
    ) -> Result<(PathBuf, &'static str)> {
        let provider = kind.soname();
        for (dir, source) in candidates {
            let shared = dir.join("libonnxruntime_providers_shared.so");
            let provider_path = dir.join(provider);
            if shared.is_file() && provider_path.is_file() {
                info!(
                    path = %shared.display(),
                    source = *source,
                    "ORT providers_shared resolved"
                );
                info!(
                    provider = kind.label(),
                    path = %provider_path.display(),
                    source = *source,
                    "ORT provider resolved"
                );
                return Ok((dir.clone(), *source));
            }
        }

        let render_path = |v: Option<&Path>| -> String {
            v.map(|p| p.display().to_string())
                .unwrap_or_else(|| "<unset>".to_string())
        };
        let checked = Self::render_provider_candidates(candidates);
        Err(EngineError::ModelMetadata(format!(
            "ORT provider directory resolution failed: command={command} mode={mode} ep_mode={ep_mode} provider={} \
reason=no candidate directory contained required pair \
(libonnxruntime_providers_shared.so + {}). \
overrides={{ORT_DYLIB_PATH={}, ORT_LIB_LOCATION={}}}. \
candidates_checked=[{}]. \
Set ORT_DYLIB_PATH or ORT_LIB_LOCATION to a valid provider directory and re-run scripts/test_ort_provider_load.sh.",
            kind.label(),
            provider,
            render_path(ort_dylib_path),
            render_path(ort_lib_location),
            checked
        )))
    }

    #[cfg(target_os = "linux")]
    fn resolve_ort_provider_dir(kind: OrtProviderKind) -> Result<(PathBuf, &'static str)> {
        let (ort_dylib_path, ort_lib_location, exe_dir, cache_dirs, on_wsl) =
            Self::ort_provider_search_inputs();
        let candidates = Self::provider_dir_candidates(
            ort_dylib_path.clone(),
            ort_lib_location.clone(),
            exe_dir,
            cache_dirs,
            on_wsl,
        );
        let (dir, source) = Self::resolve_provider_dir(
            kind,
            &candidates,
            ort_dylib_path.as_deref(),
            ort_lib_location.as_deref(),
            "backend_initialize",
            "provider_preload",
            kind.label(),
        )?;
        Self::configure_ort_loader_path(&dir);
        Ok((dir, source))
    }

    #[cfg(target_os = "linux")]
    fn dlopen_path(path: &Path, flags: i32) -> Result<()> {
        let cpath = CString::new(path.to_string_lossy().as_bytes())
            .map_err(|_| EngineError::ModelMetadata("Invalid dlopen path".into()))?;
        let handle = unsafe { dlopen(cpath.as_ptr(), flags) };
        if handle.is_null() {
            let err = unsafe {
                let p = dlerror();
                if p.is_null() {
                    "unknown dlopen error".to_string()
                } else {
                    CStr::from_ptr(p).to_string_lossy().to_string()
                }
            };
            return Err(EngineError::ModelMetadata(format!(
                "dlopen failed for {}: {err}",
                path.display()
            )));
        }
        Ok(())
    }

    #[cfg(target_os = "linux")]
    fn preload_cuda_runtime_libs() {
        let mut roots = Vec::<PathBuf>::new();
        if let Some(dir) = env::var_os("RAVE_CUDA_RUNTIME_DIR") {
            roots.push(PathBuf::from(dir));
        }
        roots.push(PathBuf::from("/usr/local/cuda-12/targets/x86_64-linux/lib"));
        roots.push(PathBuf::from("/usr/local/cuda/lib64"));

        let libs = [
            "libcudart.so.12",
            "libcublasLt.so.12",
            "libcublas.so.12",
            "libnvrtc.so.12",
            "libcurand.so.10",
            "libcufft.so.11",
        ];

        for root in roots {
            if !root.is_dir() {
                continue;
            }
            let mut loaded = 0usize;
            for lib in libs {
                let p = root.join(lib);
                if !p.is_file() {
                    continue;
                }
                if Self::dlopen_path(&p, RTLD_NOW | RTLD_GLOBAL).is_ok() {
                    loaded += 1;
                }
            }
            if loaded > 0 {
                info!(root = %root.display(), loaded, "Preloaded CUDA runtime libraries for ORT");
                break;
            }
        }
    }

    fn preload_ort_provider_pair(kind: OrtProviderKind) -> Result<()> {
        #[cfg(target_os = "linux")]
        {
            let (dir, source) = Self::resolve_ort_provider_dir(kind)?;
            let shared = dir.join("libonnxruntime_providers_shared.so");
            let provider = dir.join(kind.soname());

            Self::dlopen_path(&shared, RTLD_NOW | RTLD_GLOBAL).map_err(|e| {
                EngineError::ModelMetadata(format!(
                    "Failed loading providers_shared from {} ({e}). \
Ensure ORT_DYLIB_PATH/ORT_LIB_LOCATION points to a valid ORT cache dir.",
                    shared.display()
                ))
            })?;

            info!(
                path = %provider.display(),
                provider = kind.label(),
                "Skipping explicit provider dlopen; relying on startup LD_LIBRARY_PATH + ORT registration"
            );

            info!(
                source,
                dir = %dir.display(),
                provider = kind.label(),
                path = %provider.display(),
                "ORT provider pair prepared (providers_shared preloaded, provider path configured)"
            );
            Ok(())
        }

        #[cfg(not(target_os = "linux"))]
        {
            let _ = kind;
            Ok(())
        }
    }

    #[cfg(all(target_os = "linux", test))]
    fn preload_ort_provider_bridge() -> Result<()> {
        for path in Self::ort_provider_candidates("libonnxruntime_providers_shared.so") {
            if path.is_file() {
                return Self::dlopen_path(&path, RTLD_NOW | RTLD_GLOBAL);
            }
        }
        Err(EngineError::ModelMetadata(
            "Could not preload libonnxruntime_providers_shared.so from known search paths".into(),
        ))
    }

    #[cfg(all(not(target_os = "linux"), test))]
    #[allow(dead_code)]
    fn preload_ort_provider_bridge() -> Result<()> {
        Ok(())
    }

    #[cfg(not(target_os = "linux"))]
    fn preload_cuda_runtime_libs() {}

    fn build_trt_session(&self) -> Result<Session> {
        Self::preload_cuda_runtime_libs();
        Self::preload_ort_provider_pair(OrtProviderKind::TensorRt)?;

        let mut trt_ep = TensorRTExecutionProvider::default()
            .with_device_id(self.device_id)
            .with_engine_cache(true)
            .with_engine_cache_path(
                self.model_path
                    .parent()
                    .unwrap_or(&self.model_path)
                    .join("trt_cache")
                    .to_string_lossy()
                    .to_string(),
            );

        match &self.precision_policy {
            PrecisionPolicy::Fp32 => {
                info!("TRT precision: FP32 (no mixed precision)");
            }
            PrecisionPolicy::Fp16 => {
                trt_ep = trt_ep.with_fp16(true);
                info!("TRT precision: FP16 mixed precision");
            }
            PrecisionPolicy::Int8 { calibration_table } => {
                trt_ep = trt_ep.with_fp16(true).with_int8(true);
                info!(
                    table = %calibration_table.display(),
                    "TRT precision: INT8 with calibration table"
                );
            }
        }

        Session::builder()?
            .with_execution_providers([trt_ep.build().error_on_failure()])?
            .with_intra_threads(1)?
            .commit_from_file(&self.model_path)
            .map_err(Into::into)
    }

    fn build_cuda_session(&self) -> Result<Session> {
        Self::preload_cuda_runtime_libs();
        Self::preload_ort_provider_pair(OrtProviderKind::Cuda)?;
        let cuda_ep = CUDAExecutionProvider::default().with_device_id(self.device_id);
        Session::builder()?
            .with_execution_providers([cuda_ep.build().error_on_failure()])?
            .with_intra_threads(1)?
            .commit_from_file(&self.model_path)
            .map_err(Into::into)
    }

    fn pointer_identity_mismatch(
        input_ptr: u64,
        texture_ptr: u64,
        output_ptr: u64,
        ring_slot_ptr: u64,
    ) -> Option<String> {
        if input_ptr != texture_ptr {
            return Some(format!(
                "POINTER MISMATCH: IO-bound input (0x{input_ptr:016x}) != GpuTexture (0x{texture_ptr:016x})"
            ));
        }
        if output_ptr != ring_slot_ptr {
            return Some(format!(
                "POINTER MISMATCH: IO-bound output (0x{output_ptr:016x}) != ring slot (0x{ring_slot_ptr:016x})"
            ));
        }
        None
    }

    fn verify_pointer_identity(
        input_ptr: u64,
        output_ptr: u64,
        input_texture: &GpuTexture,
        ring_slot_ptr: u64,
    ) -> Result<()> {
        let texture_ptr = input_texture.device_ptr();
        debug!(
            input_ptr = format!("0x{:016x}", input_ptr),
            texture_ptr = format!("0x{:016x}", texture_ptr),
            output_ptr = format!("0x{:016x}", output_ptr),
            ring_slot_ptr = format!("0x{:016x}", ring_slot_ptr),
            "IO-binding pointer identity audit"
        );

        if let Some(message) =
            Self::pointer_identity_mismatch(input_ptr, texture_ptr, output_ptr, ring_slot_ptr)
        {
            debug_assert!(false, "{message}");
            #[cfg(feature = "audit-no-host-copies")]
            if rave_core::host_copy_audit::is_strict_mode() {
                return Err(EngineError::InvariantViolation(message));
            }
            rave_core::host_copy_violation!("inference", "{message}");
        }
        Ok(())
    }

    fn run_io_bound(
        session: &mut Session,
        meta: &ModelMetadata,
        input: &GpuTexture,
        output_ptr: u64,
        output_bytes: usize,
        _ctx: &GpuContext,
        _mem_info: &ort::memory::MemoryInfo,
    ) -> Result<()> {
        let in_w = input.width as i64;
        let in_h = input.height as i64;
        let out_w = in_w * meta.scale as i64;
        let out_h = in_h * meta.scale as i64;

        let input_shape: Vec<i64> = vec![1, meta.input_channels as i64, in_h, in_w];
        let output_shape: Vec<i64> = vec![1, meta.input_channels as i64, out_h, out_w];

        let elem_type = ort_element_type(input.format);
        let elem_bytes = input.format.element_bytes();
        let input_bytes = input.data.len();

        let expected = (output_shape.iter().product::<i64>() as usize) * elem_bytes;
        if output_bytes < expected {
            return Err(EngineError::BufferTooSmall {
                need: expected,
                have: output_bytes,
            });
        }

        let input_ptr = input.device_ptr();

        Self::verify_pointer_identity(input_ptr, output_ptr, input, output_ptr)?;

        let mut binding = session.create_binding()?;

        unsafe {
            let input_tensor = create_tensor_from_device_memory(
                input_ptr as *mut _,
                input_bytes,
                &input_shape,
                elem_type,
            )?;

            binding.bind_input(&meta.input_name, &input_tensor)?;

            let output_tensor = create_tensor_from_device_memory(
                output_ptr as *mut _,
                output_bytes,
                &output_shape,
                elem_type,
            )?;

            binding.bind_output(&meta.output_name, output_tensor)?;
        }

        session.run_binding(&binding)?;

        Ok(())
    }
}
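
// Worked example of the buffer sizing check in `run_io_bound` above, using a
// hypothetical 1080p input and a 2x model with FP32 planar RGB: the output
// shape is [1, 3, 2160, 3840], so the bound output slot must hold at least
// 1 * 3 * 2160 * 3840 * 4 = 99,532,800 bytes (~95 MiB), which matches the
// `slot_bytes` the OutputRing allocates for that input size.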

#[cfg(test)]
mod pointer_audit_tests {
    use super::TensorRtBackend;

    #[test]
    fn pointer_identity_audit_accepts_matching_addresses() {
        let mismatch = TensorRtBackend::pointer_identity_mismatch(0x1000, 0x1000, 0x2000, 0x2000);
        assert!(mismatch.is_none());
    }

    #[test]
    fn pointer_identity_audit_reports_mismatch() {
        let mismatch = TensorRtBackend::pointer_identity_mismatch(0x1000, 0x1111, 0x2000, 0x2000)
            .expect("mismatch should be reported");
        assert!(mismatch.contains("IO-bound input"));
    }
}

#[cfg(all(test, target_os = "linux"))]
mod provider_resolution_tests {
    use super::{OrtProviderKind, TensorRtBackend};
    use std::env;
    use std::fs::{self, File};
    use std::path::{Path, PathBuf};
    use std::process;
    use std::sync::atomic::{AtomicU64, Ordering};

    static NEXT_TMP_ID: AtomicU64 = AtomicU64::new(0);

    struct TempDir {
        path: PathBuf,
    }

    impl TempDir {
        fn new(prefix: &str) -> Self {
            let unique = format!(
                "rave_provider_resolution_{prefix}_{}_{}",
                process::id(),
                NEXT_TMP_ID.fetch_add(1, Ordering::Relaxed)
            );
            let path = env::temp_dir().join(unique);
            fs::create_dir_all(&path).expect("create temp dir");
            Self { path }
        }

        fn join(&self, suffix: &str) -> PathBuf {
            self.path.join(suffix)
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.path);
        }
    }

    fn write_provider_pair(dir: &Path, kind: OrtProviderKind) {
        fs::create_dir_all(dir).expect("create provider dir");
        File::create(dir.join("libonnxruntime_providers_shared.so")).expect("create shared lib");
        File::create(dir.join(kind.soname())).expect("create provider lib");
    }

    #[test]
    fn provider_dir_env_override_wins() {
        let tmp = TempDir::new("env_override");
        let override_dir = tmp.join("override");
        let exe_dir = tmp.join("exe");
        let cache_dir = tmp.join("cache");

        write_provider_pair(&override_dir, OrtProviderKind::TensorRt);
        write_provider_pair(&exe_dir, OrtProviderKind::TensorRt);
        write_provider_pair(&cache_dir, OrtProviderKind::TensorRt);

        let candidates = TensorRtBackend::provider_dir_candidates(
            Some(override_dir.clone()),
            None,
            Some(exe_dir),
            vec![cache_dir],
            false,
        );
        let (resolved, source) = TensorRtBackend::resolve_provider_dir(
            OrtProviderKind::TensorRt,
            &candidates,
            Some(&override_dir),
            None,
            "upscale",
            "direct",
            "tensorrt",
        )
        .expect("provider directory should resolve");

        assert_eq!(resolved, override_dir);
        assert_eq!(source, "ORT_DYLIB_PATH");
    }

    #[test]
    fn provider_dir_fallback_order_uses_first_existing_candidate() {
        let tmp = TempDir::new("fallback_order");
        let exe_dir = tmp.join("exe");
        let exe_deps_dir = exe_dir.join("deps");
        let cache_dir = tmp.join("cache");

        write_provider_pair(&exe_deps_dir, OrtProviderKind::Cuda);
        write_provider_pair(&cache_dir, OrtProviderKind::Cuda);

        let candidates = TensorRtBackend::provider_dir_candidates(
            None,
            None,
            Some(exe_dir.clone()),
            vec![cache_dir],
            false,
        );
        let (resolved, source) = TensorRtBackend::resolve_provider_dir(
            OrtProviderKind::Cuda,
            &candidates,
            None,
            None,
            "benchmark",
            "provider_preload",
            "cuda",
        )
        .expect("provider directory should resolve");

        assert_eq!(resolved, exe_deps_dir);
        assert_eq!(source, "exe_dir/deps");
    }

    #[test]
    fn provider_dir_failure_reports_checked_candidates_in_order() {
        let tmp = TempDir::new("failure_diagnostics");
        let override_dir = tmp.join("missing_override");
        let lib_dir = tmp.join("missing_lib");

        let candidates = TensorRtBackend::provider_dir_candidates(
            Some(override_dir.clone()),
            Some(lib_dir.clone()),
            None,
            vec![],
            false,
        );
        let err = TensorRtBackend::resolve_provider_dir(
            OrtProviderKind::TensorRt,
            &candidates,
            Some(&override_dir),
            Some(&lib_dir),
            "validate",
            "provider_preload",
            "tensorrt",
        )
        .expect_err("resolution should fail");
        let msg = err.to_string();

        assert!(msg.contains("command=validate"));
        assert!(msg.contains("mode=provider_preload"));
        assert!(msg.contains("ep_mode=tensorrt"));
        assert!(msg.contains("candidates_checked=[1:ORT_DYLIB_PATH:"));
        let idx_override = msg
            .find("1:ORT_DYLIB_PATH:")
            .expect("missing ORT_DYLIB_PATH candidate");
        let idx_lib = msg
            .find("2:ORT_LIB_LOCATION:")
            .expect("missing ORT_LIB_LOCATION candidate");
        assert!(
            idx_override < idx_lib,
            "candidate order should be deterministic and preserve precedence"
        );
    }

    #[test]
    fn provider_candidate_rendering_is_deterministic() {
        let candidates = vec![
            (PathBuf::from("/tmp/a"), "ORT_DYLIB_PATH"),
            (PathBuf::from("/tmp/b"), "ORT_LIB_LOCATION"),
            (PathBuf::from("/tmp/c"), "exe_dir"),
        ];

        let rendered_a = TensorRtBackend::render_provider_candidates(&candidates);
        let rendered_b = TensorRtBackend::render_provider_candidates(&candidates);

        assert_eq!(rendered_a, rendered_b);
        assert_eq!(
            rendered_a,
            "1:ORT_DYLIB_PATH:/tmp/a, 2:ORT_LIB_LOCATION:/tmp/b, 3:exe_dir:/tmp/c"
        );
    }
}

#[cfg(all(test, target_os = "linux"))]
mod tests {
    use super::{RTLD_LOCAL, RTLD_NOW, TensorRtBackend};
    use rave_core::backend::UpscaleBackend;
    use std::env;
    use std::path::PathBuf;

    #[test]
    #[ignore = "requires ORT TensorRT provider libs on host"]
    fn providers_load_with_bridge_preloaded() {
        TensorRtBackend::preload_ort_provider_bridge().expect("failed to preload providers_shared");
        let trt_candidates =
            TensorRtBackend::ort_provider_candidates("libonnxruntime_providers_tensorrt.so");
        assert!(
            !trt_candidates.is_empty(),
            "no libonnxruntime_providers_tensorrt.so candidates found"
        );
        let path = trt_candidates[0].clone();
        TensorRtBackend::dlopen_path(&path, RTLD_NOW | RTLD_LOCAL)
            .expect("providers_tensorrt should dlopen after providers_shared preload");
    }

    #[test]
    #[ignore = "requires model + full ORT/TensorRT runtime"]
    fn ort_registers_tensorrt_ep_smoke() {
        let model = env::var("RAVE_TEST_ONNX_MODEL").expect("set RAVE_TEST_ONNX_MODEL");
        let backend = TensorRtBackend::new(
            PathBuf::from(model),
            rave_core::context::GpuContext::new(0).expect("cuda ctx"),
            0,
            6,
            4,
        );
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("tokio runtime");
        rt.block_on(async move { backend.initialize().await })
            .expect("TensorRT EP registration should succeed");
    }
}

#[async_trait]
impl UpscaleBackend for TensorRtBackend {
    async fn initialize(&self) -> Result<()> {
        let mut guard = self.state.lock().await;
        if guard.is_some() {
            return Err(EngineError::ModelMetadata("Already initialized".into()));
        }

        validate_batch_config(&self.batch_config)?;

        let ep_mode = Self::ort_ep_mode();
        let on_wsl = Self::is_wsl2();
        info!(
            path = %self.model_path.display(),
            ?ep_mode,
            on_wsl,
            "Loading ONNX model with ORT execution provider policy"
        );

        let (session, active_provider) = match ep_mode {
            OrtEpMode::CudaOnly => {
                info!("RAVE_ORT_TENSORRT disables TensorRT EP; using CUDAExecutionProvider");
                (self.build_cuda_session()?, "CUDAExecutionProvider")
            }
            OrtEpMode::TensorRtOnly => (self.build_trt_session()?, "TensorrtExecutionProvider"),
            OrtEpMode::Auto => match self.build_trt_session() {
                Ok(session) => (session, "TensorrtExecutionProvider"),
                Err(e) => {
                    warn!(
                        error = %e,
                        on_wsl,
                        "TensorRT EP registration failed; falling back to CUDAExecutionProvider"
                    );
                    (self.build_cuda_session()?, "CUDAExecutionProvider")
                }
            },
        };
        info!(
            provider = active_provider,
            "ORT execution provider selected"
        );

        let metadata = Self::extract_metadata(&session)?;
        info!(
            name = %metadata.name,
            scale = metadata.scale,
            input = %metadata.input_name,
            output = %metadata.output_name,
            ring_size = self.ring_size,
            min_ring_slots = self.min_ring_slots,
            provider = active_provider,
            precision = ?self.precision_policy,
            max_batch = self.batch_config.max_batch,
            "Model loaded"
        );

        let _ = self.meta.set(metadata);
        let _ = self.selected_provider.set(active_provider.to_string());

        *guard = Some(InferenceState {
            session,
            ring: None,
        });

        Ok(())
    }

    async fn process(&self, input: GpuTexture) -> Result<GpuTexture> {
        match input.format {
            PixelFormat::RgbPlanarF32 | PixelFormat::RgbPlanarF16 => {}
            other => {
                return Err(EngineError::FormatMismatch {
                    expected: PixelFormat::RgbPlanarF32,
                    actual: other,
                });
            }
        }

        let meta = self.meta.get().ok_or(EngineError::NotInitialized)?;
        let mut guard = self.state.lock().await;
        let state = guard.as_mut().ok_or(EngineError::NotInitialized)?;

        match &mut state.ring {
            Some(ring) if ring.needs_realloc(input.width, input.height) => {
                debug!(
                    old_w = ring.alloc_dims.0,
                    old_h = ring.alloc_dims.1,
                    new_w = input.width,
                    new_h = input.height,
                    "Reallocating output ring"
                );
                ring.reallocate(&self.ctx, input.width, input.height, meta.scale)?;
            }
            None => {
                debug!(
                    w = input.width,
                    h = input.height,
                    slots = self.ring_size,
                    "Lazily creating output ring"
                );
                state.ring = Some(OutputRing::new(
                    &self.ctx,
                    input.width,
                    input.height,
                    meta.scale,
                    self.ring_size,
                    self.min_ring_slots,
                )?);
            }
            Some(_) => {}
        }

        let ring = state.ring.as_mut().unwrap();

        #[cfg(feature = "debug-alloc")]
        {
            rave_core::debug_alloc::reset();
            rave_core::debug_alloc::enable();
        }

        let output_arc = ring.acquire()?;
        let output_ptr = *(*output_arc).device_ptr();
        let output_bytes = ring.slot_bytes;

        let t_start = std::time::Instant::now();

        let mem_info = self.mem_info()?;
        Self::run_io_bound(
            &mut state.session,
            meta,
            &input,
            output_ptr,
            output_bytes,
            &self.ctx,
            &mem_info,
        )?;

        let elapsed_us = t_start.elapsed().as_micros() as u64;
        self.inference_metrics.record(elapsed_us);

        #[cfg(feature = "debug-alloc")]
        {
            rave_core::debug_alloc::disable();
            let host_allocs = rave_core::debug_alloc::count();
            debug_assert_eq!(
                host_allocs, 0,
                "VIOLATION: {host_allocs} host allocations during inference"
            );
        }

        let out_w = input.width * meta.scale;
        let out_h = input.height * meta.scale;
        let elem_bytes = input.format.element_bytes();

        Ok(GpuTexture {
            data: output_arc,
            width: out_w,
            height: out_h,
            pitch: (out_w as usize) * elem_bytes,
            format: input.format,
        })
    }

    async fn shutdown(&self) -> Result<()> {
        let mut guard = self.state.lock().await;
        if let Some(state) = guard.take() {
            info!("Shutting down TensorRT backend");
            self.ctx.sync_all()?;

            if let Some(ring) = &state.ring {
                let snap = ring.metrics.snapshot();
                info!(
                    reuse = snap.reuse,
                    contention = snap.contention,
                    first_use = snap.first_use,
                    "Final ring metrics"
                );
            }

            let snap = self.inference_metrics.snapshot();
            info!(
                frames = snap.frames_inferred,
                avg_us = snap.avg_inference_us,
                peak_us = snap.peak_inference_us,
                precision = ?self.precision_policy,
                "Final inference metrics"
            );

            let (current, peak) = self.ctx.vram_usage();
            info!(
                current_mb = current / (1024 * 1024),
                peak_mb = peak / (1024 * 1024),
                "Final VRAM usage"
            );

            drop(state.ring);
            drop(state.session);
            debug!("TensorRT backend shutdown complete");
        }
        Ok(())
    }

    fn metadata(&self) -> Result<&ModelMetadata> {
        self.meta.get().ok_or(EngineError::NotInitialized)
    }
}
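
// Illustrative end-to-end usage of the backend (a sketch, not one of the
// crate's documented examples). The model path and the `frame` texture are
// hypothetical and come from the surrounding application; the `UpscaleBackend`
// trait must be in scope for `initialize`/`process`/`shutdown`.
//
//     let ctx = rave_core::context::GpuContext::new(0)?;
//     let backend = TensorRtBackend::new(
//         PathBuf::from("models/realesrgan_x2.onnx"), // hypothetical path
//         ctx,
//         0, // device_id
//         6, // ring_size
//         4, // downstream_capacity (ring_size must be >= this + 2)
//     );
//     backend.initialize().await?;
//     let upscaled = backend.process(frame).await?;
//     backend.shutdown().await?;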

#[cfg(test)]
mod batch_config_tests {
    use super::{BatchConfig, validate_batch_config};

    #[test]
    fn batch_config_validator_accepts_single_frame() {
        let cfg = BatchConfig {
            max_batch: 1,
            latency_deadline_us: 8_000,
        };
        validate_batch_config(&cfg).expect("max_batch=1 should be accepted");
    }

    #[test]
    fn batch_config_validator_rejects_micro_batching() {
        let cfg = BatchConfig {
            max_batch: 2,
            latency_deadline_us: 8_000,
        };
        let err = validate_batch_config(&cfg).expect_err("max_batch>1 must be rejected");
        let msg = err.to_string();
        assert!(msg.contains("max_batch"));
        assert!(msg.contains("not implemented"));
    }
}

impl Drop for TensorRtBackend {
    fn drop(&mut self) {
        if let Ok(mut guard) = self.state.try_lock()
            && let Some(state) = guard.take()
        {
            let _ = self.ctx.sync_all();
            drop(state);
        }
    }
}