#![cfg(feature = "cuda")]
use std::sync::Arc;
use cudarc::driver::CudaContext;
use ferrotorch_gpu::stream::{
StreamPool, StreamPriority, get_stream_priority_range, new_stream_with_priority,
};
use ferrotorch_gpu::transfer::{cpu_to_gpu, gpu_to_cpu};
fn ctx() -> Arc<CudaContext> {
CudaContext::new(0).expect("CUDA device 0 must be available for these tests")
}
#[test]
fn priority_range_query_succeeds_on_3090() {
let c = ctx();
let (least, greatest) = get_stream_priority_range(&c).expect("priority range");
assert!(greatest <= least);
}
#[test]
fn high_priority_stream_synchronize_does_not_crash() {
let c = ctx();
let stream = new_stream_with_priority(&c, StreamPriority::High).expect("high stream");
stream.synchronize().expect("synchronize high stream");
let stream = new_stream_with_priority(&c, StreamPriority::Normal).expect("normal stream");
stream.synchronize().expect("synchronize normal stream");
let stream = new_stream_with_priority(&c, StreamPriority::Low).expect("low stream");
stream.synchronize().expect("synchronize low stream");
}
#[test]
fn pool_serves_distinct_streams_for_high_and_low_priority() {
let c = ctx();
let high = StreamPool::get_priority_stream(&c, 0, StreamPriority::High).unwrap();
let low = StreamPool::get_priority_stream(&c, 0, StreamPriority::Low).unwrap();
assert!(!Arc::ptr_eq(&high, &low));
high.synchronize().unwrap();
low.synchronize().unwrap();
}
#[test]
fn priority_stream_supports_buffer_round_trip() {
use ferrotorch_gpu::device::GpuDevice;
use ferrotorch_gpu::stream::StreamGuard;
let dev = GpuDevice::new(0).expect("CUDA device 0");
let c = dev.context();
let stream = new_stream_with_priority(c, StreamPriority::High).expect("high stream");
let _guard = StreamGuard::new(0, Arc::clone(&stream));
let host: Vec<f32> = (0..1024).map(|i| i as f32).collect();
let gpu = cpu_to_gpu(&host, &dev).expect("upload");
let back = gpu_to_cpu(&gpu, &dev).expect("download");
assert_eq!(back, host);
stream.synchronize().expect("sync");
}