// dbx_core/engine/stream.rs

#[cfg(feature = "gpu")]
use std::sync::Arc;

#[cfg(feature = "gpu")]
use cudarc::driver::{CudaContext, CudaStream};

use crate::error::{DbxError, DbxResult};

/// Priority requested for a GPU stream.
///
/// The value is stored alongside the stream it was requested for. `GpuStreamContext::new`
/// does not forward it to the CUDA stream-creation call, so it is currently metadata only.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum StreamPriority {
    /// Higher-priority stream.
    High,
    /// Regular stream; this is the `Default` value.
    #[default]
    Normal,
}

/// A CUDA stream paired with the identifier and priority the engine tracks for it.
#[cfg(feature = "gpu")]
pub struct GpuStreamContext {
    /// Identifier supplied by whoever created this context.
    pub stream_id: usize,
    /// Priority requested when the context was created.
    pub priority: StreamPriority,
    /// The underlying CUDA stream.
    stream: Arc<CudaStream>,
    /// Context the stream was created on. Holding this `Arc` keeps the context alive
    /// for at least as long as this value.
    device: Arc<CudaContext>,
}

#[cfg(feature = "gpu")]
impl GpuStreamContext {
    /// Creates a stream on `device` and wraps it together with `stream_id` and `priority`.
    ///
    /// The stream is obtained from `device.fork_default_stream()`. `priority` is recorded on
    /// the returned context only; it is not passed to the stream-creation call.
    ///
    /// # Errors
    ///
    /// Returns [`DbxError::Gpu`] if the stream cannot be created.
    pub fn new(
        stream_id: usize,
        priority: StreamPriority,
        device: Arc<CudaContext>,
    ) -> DbxResult<Self> {
        let stream = device
            .fork_default_stream()
            .map_err(|e| DbxError::Gpu(format!("Failed to create stream: {:?}", e)))?;

        Ok(Self {
            stream_id,
            priority,
            stream,
            device,
        })
    }

    /// Returns a reference to the underlying CUDA stream.
    pub fn stream(&self) -> &CudaStream {
        &self.stream
    }

    /// Returns the CUDA context this stream was created on.
    pub fn device(&self) -> &Arc<CudaContext> {
        &self.device
    }

    /// Synchronizes this stream via `CudaStream::synchronize`.
    ///
    /// # Errors
    ///
    /// Returns [`DbxError::Gpu`] if synchronization fails.
    pub fn synchronize(&self) -> DbxResult<()> {
        self.stream
            .synchronize()
            .map_err(|e| DbxError::Gpu(format!("Stream sync failed: {:?}", e)))
    }
}

/// Creates and tracks [`GpuStreamContext`]s that all share one CUDA context.
#[cfg(feature = "gpu")]
pub struct StreamManager {
    /// Context every managed stream is created on.
    device: Arc<CudaContext>,
    /// Streams created so far, in creation order.
    streams: Vec<GpuStreamContext>,
    /// Identifier to hand out to the next stream.
    next_id: usize,
}

#[cfg(feature = "gpu")]
impl StreamManager {
    /// Creates a manager with no streams, bound to `device`.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` return type is kept for API stability.
    pub fn new(device: Arc<CudaContext>) -> DbxResult<Self> {
        Ok(Self {
            device,
            streams: Vec::new(),
            next_id: 0,
        })
    }

    /// Creates a new stream with `priority` and returns its identifier.
    ///
    /// Identifiers are handed out sequentially starting at 0. An identifier is consumed
    /// only when the stream is created successfully.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`GpuStreamContext::new`]; on failure the manager is
    /// left unchanged.
    pub fn create_stream(&mut self, priority: StreamPriority) -> DbxResult<usize> {
        let stream_id = self.next_id;
        let context = GpuStreamContext::new(stream_id, priority, Arc::clone(&self.device))?;
        self.streams.push(context);
        // Advance only after success so a failed creation does not leave a gap in the ids.
        self.next_id += 1;

        Ok(stream_id)
    }

    /// Returns the stream with identifier `stream_id`, or `None` if no such stream exists.
    pub fn get_stream(&self, stream_id: usize) -> Option<&GpuStreamContext> {
        self.streams.iter().find(|s| s.stream_id == stream_id)
    }

    /// Synchronizes every managed stream in creation order.
    ///
    /// # Errors
    ///
    /// Returns the first synchronization error encountered. Streams after the failing one
    /// are not synchronized.
    pub fn synchronize_all(&self) -> DbxResult<()> {
        self.streams.iter().try_for_each(|s| s.synchronize())
    }

    /// Returns the number of streams created so far.
    pub fn stream_count(&self) -> usize {
        self.streams.len()
    }
}

/// Placeholder for `GpuStreamContext` when the crate is built without the `gpu` feature.
///
/// It carries no data.
#[cfg(not(feature = "gpu"))]
pub struct GpuStreamContext;

/// Placeholder for `StreamManager` when the crate is built without the `gpu` feature.
///
/// It carries no data.
#[cfg(not(feature = "gpu"))]
pub struct StreamManager;

123#[cfg(not(feature = "gpu"))]
124impl StreamManager {
125 pub fn new(_device: ()) -> DbxResult<Self> {
126 Err(DbxError::NotImplemented(
127 "GPU acceleration is not enabled".to_string(),
128 ))
129 }
130}