// fast_cache/cuda.rs

1//! GPU-facing configuration and transfer descriptors.
2//!
3//! The crate keeps these types in the public API so storage callers can
4//! describe KV-cache chunk transfers without depending on a CUDA runtime in the
5//! core crate. The actual GPU execution layer is intentionally outside this
6//! package.
7
8use serde::{Deserialize, Serialize};
9
10#[cfg(feature = "sharded")]
11use crate::storage::LocalEmbeddedReadSlice;
12use crate::storage::{Bytes, hash_key};
13
14/// Runtime configuration for the optional CUDA/GPU tier.
15///
16/// The fields are kept in the core config surface so operators can describe a
17/// GPU-tier budget and staging policy in one place even when running on a CPU-
18/// only build. The direct CUDA runtime remains feature-gated.
19#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(default)]
21pub struct CudaConfig {
22    pub enabled: bool,
23    pub device_ordinal: usize,
24    pub hot_tier_bytes: u64,
25    pub pinned_host_bytes: u64,
26    pub transfer_stream_count: usize,
27    pub layer_streaming: bool,
28    pub prefer_direct_host_dma: bool,
29    pub pinned_staging_threshold_bytes: usize,
30    pub allow_cpu_fallback: bool,
31}
32
33impl Default for CudaConfig {
34    fn default() -> Self {
35        Self {
36            enabled: false,
37            device_ordinal: 0,
38            hot_tier_bytes: 10 * 1024 * 1024 * 1024,
39            pinned_host_bytes: 512 * 1024 * 1024,
40            transfer_stream_count: 4,
41            layer_streaming: true,
42            prefer_direct_host_dma: true,
43            pinned_staging_threshold_bytes: 2 * 1024 * 1024,
44            allow_cpu_fallback: true,
45        }
46    }
47}
48
49/// Precomputed routing metadata for a chunk that should be transferred to a GPU
50/// destination in layer order.
51#[derive(Debug, Clone, PartialEq, Eq)]
52pub struct CudaChunkTransferDescriptor {
53    key: Bytes,
54    key_hash: u64,
55    layer_index: u32,
56    dst_offset_bytes: u64,
57    expected_len: Option<usize>,
58}
59
60impl CudaChunkTransferDescriptor {
61    pub fn new<K>(key: K, layer_index: u32, dst_offset_bytes: u64) -> Self
62    where
63        K: Into<Bytes>,
64    {
65        let key = key.into();
66        let key_hash = hash_key(&key);
67        Self {
68            key,
69            key_hash,
70            layer_index,
71            dst_offset_bytes,
72            expected_len: None,
73        }
74    }
75
76    #[inline(always)]
77    pub fn with_expected_len(mut self, expected_len: usize) -> Self {
78        self.expected_len = Some(expected_len);
79        self
80    }
81
82    #[inline(always)]
83    pub fn key(&self) -> &[u8] {
84        &self.key
85    }
86
87    #[inline(always)]
88    pub fn key_hash(&self) -> u64 {
89        self.key_hash
90    }
91
92    #[inline(always)]
93    pub fn layer_index(&self) -> u32 {
94        self.layer_index
95    }
96
97    #[inline(always)]
98    pub fn dst_offset_bytes(&self) -> u64 {
99        self.dst_offset_bytes
100    }
101
102    #[inline(always)]
103    pub fn expected_len(&self) -> Option<usize> {
104        self.expected_len
105    }
106}
107
108/// A session-scoped transfer request for streaming KV chunks in layer order to
109/// a GPU-facing consumer.
110#[derive(Debug, Clone, PartialEq, Eq)]
111pub struct CudaSessionTransferRequest {
112    session_prefix: Bytes,
113    chunks: Vec<CudaChunkTransferDescriptor>,
114}
115
116impl CudaSessionTransferRequest {
117    pub fn new<S>(session_prefix: S, chunks: Vec<CudaChunkTransferDescriptor>) -> Self
118    where
119        S: Into<Bytes>,
120    {
121        Self {
122            session_prefix: session_prefix.into(),
123            chunks,
124        }
125    }
126
127    #[inline(always)]
128    pub fn session_prefix(&self) -> &[u8] {
129        &self.session_prefix
130    }
131
132    #[inline(always)]
133    pub fn chunks(&self) -> &[CudaChunkTransferDescriptor] {
134        &self.chunks
135    }
136
137    #[inline(always)]
138    pub fn item_count(&self) -> usize {
139        self.chunks.len()
140    }
141
142    #[inline(always)]
143    pub fn total_expected_bytes(&self) -> Option<usize> {
144        self.chunks
145            .iter()
146            .map(CudaChunkTransferDescriptor::expected_len)
147            .try_fold(0usize, |sum, len| len.map(|len| sum.saturating_add(len)))
148    }
149}
150
/// Aggregate outcome for a session-scoped streaming transfer.
///
/// All counters are zero in the `Default` value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CudaSessionTransferStats {
    /// Number of chunks that were requested.
    pub requested_chunks: usize,
    /// Number of requested chunks counted as hits.
    pub hit_chunks: usize,
    /// Number of requested chunks counted as misses.
    pub missed_chunks: usize,
    /// Number of bytes transferred.
    pub transferred_bytes: usize,
}

impl CudaSessionTransferStats {
    /// Reports whether the hit count equals the requested count.
    ///
    /// Only `requested_chunks` and `hit_chunks` are compared; `missed_chunks`
    /// is not consulted. Stats with zero requested and zero hit chunks
    /// therefore report `true`.
    #[inline(always)]
    pub fn all_hit(&self) -> bool {
        let Self {
            requested_chunks,
            hit_chunks,
            ..
        } = *self;
        hit_chunks == requested_chunks
    }
}
166
/// A chunk that was found: the descriptor that asked for it together with the
/// value slice that was read.
///
/// Only compiled when the `sharded` feature is enabled.
#[cfg(feature = "sharded")]
#[derive(Debug, Clone)]
pub struct CudaChunkTransferHit<'a> {
    /// Descriptor this hit answers.
    descriptor: &'a CudaChunkTransferDescriptor,
    /// Slice holding the chunk's value.
    value: LocalEmbeddedReadSlice<'a>,
}

#[cfg(feature = "sharded")]
impl<'a> CudaChunkTransferHit<'a> {
    /// Pairs a descriptor with the value slice read for it (crate-internal).
    pub(crate) fn new(
        descriptor: &'a CudaChunkTransferDescriptor,
        value: LocalEmbeddedReadSlice<'a>,
    ) -> Self {
        Self { value, descriptor }
    }

    /// Returns the descriptor this hit belongs to, with the full `'a` lifetime.
    #[inline(always)]
    pub fn descriptor(&self) -> &'a CudaChunkTransferDescriptor {
        let descriptor: &'a CudaChunkTransferDescriptor = self.descriptor;
        descriptor
    }

    /// Returns a clone of the value slice.
    #[inline(always)]
    pub fn value(&self) -> LocalEmbeddedReadSlice<'a> {
        Clone::clone(&self.value)
    }

    /// Borrows the chunk's value as raw bytes.
    #[inline(always)]
    pub fn as_slice(&self) -> &[u8] {
        let bytes: &[u8] = self.value.as_slice();
        bytes
    }
}
198
/// Per-chunk outcome of a session transfer: the chunk was found (`Hit`) or
/// was not (`Miss`).
///
/// Only compiled when the `sharded` feature is enabled.
#[cfg(feature = "sharded")]
#[derive(Debug, Clone)]
pub enum CudaSessionChunkEvent<'a> {
    /// The chunk was found; carries its descriptor and value.
    Hit(CudaChunkTransferHit<'a>),
    /// The chunk was not found; carries only its descriptor.
    Miss(&'a CudaChunkTransferDescriptor),
}

#[cfg(feature = "sharded")]
impl<'a> CudaSessionChunkEvent<'a> {
    /// Returns the descriptor the event refers to, whichever variant it is.
    #[inline(always)]
    pub fn descriptor(&self) -> &'a CudaChunkTransferDescriptor {
        match *self {
            Self::Hit(ref hit) => hit.descriptor(),
            // The payload is a shared reference (`Copy`), so it is returned as is.
            Self::Miss(descriptor) => descriptor,
        }
    }
}
216
#[cfg(test)]
mod tests {
    use super::{CudaChunkTransferDescriptor, CudaSessionTransferRequest};

    /// A two-chunk request reports its size, sums the expected lengths, and
    /// exposes each descriptor's precomputed hash and routing fields.
    #[test]
    fn transfer_request_precomputes_hashes_and_expected_bytes() {
        let layer_zero =
            CudaChunkTransferDescriptor::new(b"s:42:l:0".to_vec(), 0, 0).with_expected_len(128);
        let layer_one =
            CudaChunkTransferDescriptor::new(b"s:42:l:1".to_vec(), 1, 128).with_expected_len(256);

        let request =
            CudaSessionTransferRequest::new(b"s:42".to_vec(), vec![layer_zero, layer_one]);
        let chunks = request.chunks();

        assert_eq!(request.item_count(), 2);
        assert_eq!(request.total_expected_bytes(), Some(384));
        assert_ne!(chunks[0].key_hash(), 0);
        assert_eq!(chunks[1].layer_index(), 1);
        assert_eq!(chunks[1].dst_offset_bytes(), 128);
    }
}