1use serde::{Deserialize, Serialize};
9
10#[cfg(feature = "sharded")]
11use crate::storage::LocalEmbeddedReadSlice;
12use crate::storage::{Bytes, hash_key};
13
14#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(default)]
21pub struct CudaConfig {
22 pub enabled: bool,
23 pub device_ordinal: usize,
24 pub hot_tier_bytes: u64,
25 pub pinned_host_bytes: u64,
26 pub transfer_stream_count: usize,
27 pub layer_streaming: bool,
28 pub prefer_direct_host_dma: bool,
29 pub pinned_staging_threshold_bytes: usize,
30 pub allow_cpu_fallback: bool,
31}
32
33impl Default for CudaConfig {
34 fn default() -> Self {
35 Self {
36 enabled: false,
37 device_ordinal: 0,
38 hot_tier_bytes: 10 * 1024 * 1024 * 1024,
39 pinned_host_bytes: 512 * 1024 * 1024,
40 transfer_stream_count: 4,
41 layer_streaming: true,
42 prefer_direct_host_dma: true,
43 pinned_staging_threshold_bytes: 2 * 1024 * 1024,
44 allow_cpu_fallback: true,
45 }
46 }
47}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
52pub struct CudaChunkTransferDescriptor {
53 key: Bytes,
54 key_hash: u64,
55 layer_index: u32,
56 dst_offset_bytes: u64,
57 expected_len: Option<usize>,
58}
59
60impl CudaChunkTransferDescriptor {
61 pub fn new<K>(key: K, layer_index: u32, dst_offset_bytes: u64) -> Self
62 where
63 K: Into<Bytes>,
64 {
65 let key = key.into();
66 let key_hash = hash_key(&key);
67 Self {
68 key,
69 key_hash,
70 layer_index,
71 dst_offset_bytes,
72 expected_len: None,
73 }
74 }
75
76 #[inline(always)]
77 pub fn with_expected_len(mut self, expected_len: usize) -> Self {
78 self.expected_len = Some(expected_len);
79 self
80 }
81
82 #[inline(always)]
83 pub fn key(&self) -> &[u8] {
84 &self.key
85 }
86
87 #[inline(always)]
88 pub fn key_hash(&self) -> u64 {
89 self.key_hash
90 }
91
92 #[inline(always)]
93 pub fn layer_index(&self) -> u32 {
94 self.layer_index
95 }
96
97 #[inline(always)]
98 pub fn dst_offset_bytes(&self) -> u64 {
99 self.dst_offset_bytes
100 }
101
102 #[inline(always)]
103 pub fn expected_len(&self) -> Option<usize> {
104 self.expected_len
105 }
106}
107
108#[derive(Debug, Clone, PartialEq, Eq)]
111pub struct CudaSessionTransferRequest {
112 session_prefix: Bytes,
113 chunks: Vec<CudaChunkTransferDescriptor>,
114}
115
116impl CudaSessionTransferRequest {
117 pub fn new<S>(session_prefix: S, chunks: Vec<CudaChunkTransferDescriptor>) -> Self
118 where
119 S: Into<Bytes>,
120 {
121 Self {
122 session_prefix: session_prefix.into(),
123 chunks,
124 }
125 }
126
127 #[inline(always)]
128 pub fn session_prefix(&self) -> &[u8] {
129 &self.session_prefix
130 }
131
132 #[inline(always)]
133 pub fn chunks(&self) -> &[CudaChunkTransferDescriptor] {
134 &self.chunks
135 }
136
137 #[inline(always)]
138 pub fn item_count(&self) -> usize {
139 self.chunks.len()
140 }
141
142 #[inline(always)]
143 pub fn total_expected_bytes(&self) -> Option<usize> {
144 self.chunks
145 .iter()
146 .map(CudaChunkTransferDescriptor::expected_len)
147 .try_fold(0usize, |sum, len| len.map(|len| sum.saturating_add(len)))
148 }
149}
150
/// Counters describing the outcome of a session transfer.
///
/// All fields are plain public counters; the derived `Default` yields all
/// zeroes. Nothing in this type enforces a relationship between the fields
/// (for example `hit_chunks + missed_chunks == requested_chunks`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CudaSessionTransferStats {
    /// Number of chunks that were requested.
    pub requested_chunks: usize,
    /// Number of requested chunks counted as hits.
    pub hit_chunks: usize,
    /// Number of requested chunks counted as misses.
    pub missed_chunks: usize,
    /// Number of bytes transferred.
    pub transferred_bytes: usize,
}

impl CudaSessionTransferStats {
    /// Returns `true` when the hit count equals the requested count.
    ///
    /// Only `requested_chunks` and `hit_chunks` are compared; `missed_chunks`
    /// and `transferred_bytes` are not consulted. An all-zero value therefore
    /// reports `true`.
    #[inline(always)]
    pub fn all_hit(&self) -> bool {
        self.hit_chunks == self.requested_chunks
    }
}
166
/// A descriptor paired with the value that was found for it.
///
/// Only compiled when the `sharded` feature is enabled. Both the descriptor
/// and the value are tied to the lifetime `'a`.
#[cfg(feature = "sharded")]
#[derive(Debug, Clone)]
pub struct CudaChunkTransferHit<'a> {
    descriptor: &'a CudaChunkTransferDescriptor,
    value: LocalEmbeddedReadSlice<'a>,
}

#[cfg(feature = "sharded")]
impl<'a> CudaChunkTransferHit<'a> {
    /// Pairs `descriptor` with the `value` read for it (crate-internal).
    pub(crate) fn new(
        descriptor: &'a CudaChunkTransferDescriptor,
        value: LocalEmbeddedReadSlice<'a>,
    ) -> Self {
        Self { descriptor, value }
    }

    /// Returns the descriptor this hit belongs to, with the full `'a`
    /// lifetime rather than the lifetime of the borrow of `self`.
    #[inline(always)]
    pub fn descriptor(&self) -> &'a CudaChunkTransferDescriptor {
        self.descriptor
    }

    /// Returns a clone of the stored `LocalEmbeddedReadSlice`.
    ///
    /// NOTE(review): the cost of that clone depends on
    /// `LocalEmbeddedReadSlice`, which is defined outside this file.
    #[inline(always)]
    pub fn value(&self) -> LocalEmbeddedReadSlice<'a> {
        self.value.clone()
    }

    /// Returns the value's bytes via `LocalEmbeddedReadSlice::as_slice`.
    ///
    /// The returned slice is tied to the borrow of `self`, not to `'a`.
    #[inline(always)]
    pub fn as_slice(&self) -> &[u8] {
        self.value.as_slice()
    }
}
198
/// Per-chunk outcome of a session transfer.
///
/// Only compiled when the `sharded` feature is enabled.
#[cfg(feature = "sharded")]
#[derive(Debug, Clone)]
pub enum CudaSessionChunkEvent<'a> {
    /// The chunk was found; carries the descriptor together with its value.
    Hit(CudaChunkTransferHit<'a>),
    /// The chunk was not found; carries only the descriptor.
    Miss(&'a CudaChunkTransferDescriptor),
}

#[cfg(feature = "sharded")]
impl<'a> CudaSessionChunkEvent<'a> {
    /// Returns the descriptor the event refers to, for either variant.
    #[inline(always)]
    pub fn descriptor(&self) -> &'a CudaChunkTransferDescriptor {
        match self {
            Self::Hit(hit) => hit.descriptor(),
            // `descriptor` binds as `&&'a _` here; copy the inner reference
            // out so the result keeps the full `'a` lifetime.
            Self::Miss(descriptor) => *descriptor,
        }
    }
}
216
#[cfg(test)]
mod tests {
    use super::{CudaChunkTransferDescriptor, CudaSessionTransferRequest};

    /// A request built from fully sized descriptors reports the item count,
    /// the summed expected bytes and the per-chunk metadata it was given.
    #[test]
    fn transfer_request_precomputes_hashes_and_expected_bytes() {
        let first =
            CudaChunkTransferDescriptor::new(b"s:42:l:0".to_vec(), 0, 0).with_expected_len(128);
        let second =
            CudaChunkTransferDescriptor::new(b"s:42:l:1".to_vec(), 1, 128).with_expected_len(256);
        let request = CudaSessionTransferRequest::new(b"s:42".to_vec(), vec![first, second]);

        assert_eq!(request.item_count(), 2);
        assert_eq!(request.total_expected_bytes(), Some(384));

        let chunks = request.chunks();
        // Relies on `hash_key` not mapping this particular key to zero.
        assert_ne!(chunks[0].key_hash(), 0);
        assert_eq!(chunks[1].layer_index(), 1);
        assert_eq!(chunks[1].dst_offset_bytes(), 128);
    }

    /// A single descriptor without an expected length makes the total unknown.
    #[test]
    fn total_expected_bytes_is_none_when_a_chunk_has_no_expected_len() {
        let sized =
            CudaChunkTransferDescriptor::new(b"s:7:l:0".to_vec(), 0, 0).with_expected_len(64);
        let unsized_chunk = CudaChunkTransferDescriptor::new(b"s:7:l:1".to_vec(), 1, 64);
        let request = CudaSessionTransferRequest::new(b"s:7".to_vec(), vec![sized, unsized_chunk]);

        assert_eq!(request.item_count(), 2);
        assert_eq!(request.chunks()[1].expected_len(), None);
        assert_eq!(request.total_expected_bytes(), None);
    }
}
238}