Skip to main content

oxicuda_driver/
context_config.rs

1//! Context configuration: limits, cache config, and shared memory config.
2//!
3//! These functions allow tuning per-context resource limits (stack size,
4//! heap size, etc.) and cache / shared memory policies.
5//!
6//! # Example
7//!
8//! ```rust,no_run
9//! # use oxicuda_driver::context_config;
10//! # use oxicuda_driver::ffi::CUlimit;
11//! # fn main() -> Result<(), oxicuda_driver::CudaError> {
12//! let stack = context_config::get_limit(CUlimit::StackSize)?;
13//! println!("GPU thread stack size: {stack} bytes");
14//!
15//! context_config::set_cache_config(context_config::CacheConfig::PreferL1)?;
16//! # Ok(())
17//! # }
18//! ```
19
20use crate::error::{CudaError, CudaResult};
21use crate::ffi::CUlimit;
22use crate::loader::try_driver;
23
24// ---------------------------------------------------------------------------
25// Limits
26// ---------------------------------------------------------------------------
27
28/// Returns the value of a context limit.
29///
30/// # Errors
31///
32/// Returns [`CudaError::NotSupported`] if the driver lacks `cuCtxGetLimit`,
33/// or another error on failure.
34pub fn get_limit(limit: CUlimit) -> CudaResult<usize> {
35    let api = try_driver()?;
36    let f = api.cu_ctx_get_limit.ok_or(CudaError::NotSupported)?;
37    let mut value: usize = 0;
38    crate::cuda_call!(f(&mut value, limit as u32))?;
39    Ok(value)
40}
41
42/// Sets the value of a context limit.
43///
44/// # Errors
45///
46/// Returns [`CudaError::NotSupported`] if the driver lacks `cuCtxSetLimit`,
47/// or another error on failure.
48pub fn set_limit(limit: CUlimit, value: usize) -> CudaResult<()> {
49    let api = try_driver()?;
50    let f = api.cu_ctx_set_limit.ok_or(CudaError::NotSupported)?;
51    crate::cuda_call!(f(limit as u32, value))
52}
53
54// ---------------------------------------------------------------------------
55// CacheConfig
56// ---------------------------------------------------------------------------
57
/// Cache preference for a CUDA context or function.
///
/// On devices where L1 cache and shared memory are carved out of the same
/// on-chip storage, this expresses which of the two should get the larger
/// share. The discriminants are the raw values exchanged with the driver, so
/// a variant can be handed to the driver with `as u32`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(u32)]
pub enum CacheConfig {
    /// Express no preference and let the driver decide.
    PreferNone = 0x0,
    /// Give shared memory the larger portion of the on-chip storage.
    PreferShared = 0x1,
    /// Give L1 cache the larger portion of the on-chip storage.
    PreferL1 = 0x2,
    /// Split the on-chip storage evenly between L1 and shared memory.
    PreferEqual = 0x3,
}
74
75impl CacheConfig {
76    /// Convert a raw `u32` driver value to a `CacheConfig`.
77    fn from_raw(val: u32) -> CudaResult<Self> {
78        match val {
79            0 => Ok(Self::PreferNone),
80            1 => Ok(Self::PreferShared),
81            2 => Ok(Self::PreferL1),
82            3 => Ok(Self::PreferEqual),
83            _ => Err(CudaError::InvalidValue),
84        }
85    }
86}
87
88/// Returns the current cache configuration for the active context.
89///
90/// # Errors
91///
92/// Returns [`CudaError::NotSupported`] if the driver lacks
93/// `cuCtxGetCacheConfig`, or another error on failure.
94pub fn get_cache_config() -> CudaResult<CacheConfig> {
95    let api = try_driver()?;
96    let f = api.cu_ctx_get_cache_config.ok_or(CudaError::NotSupported)?;
97    let mut raw: u32 = 0;
98    crate::cuda_call!(f(&mut raw))?;
99    CacheConfig::from_raw(raw)
100}
101
102/// Sets the cache configuration for the active context.
103///
104/// # Errors
105///
106/// Returns [`CudaError::NotSupported`] if the driver lacks
107/// `cuCtxSetCacheConfig`, or another error on failure.
108pub fn set_cache_config(config: CacheConfig) -> CudaResult<()> {
109    let api = try_driver()?;
110    let f = api.cu_ctx_set_cache_config.ok_or(CudaError::NotSupported)?;
111    crate::cuda_call!(f(config as u32))
112}
113
114// ---------------------------------------------------------------------------
115// SharedMemConfig
116// ---------------------------------------------------------------------------
117
/// Bank width used by shared memory.
///
/// Chooses between 4-byte and 8-byte shared memory banks; wider banks can
/// reduce bank conflicts for 64-bit accesses. The discriminants are the raw
/// values exchanged with the driver, so a variant can be handed to the driver
/// with `as u32`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(u32)]
pub enum SharedMemConfig {
    /// Keep the device's default bank width.
    Default = 0x0,
    /// Use 32-bit (4-byte) banks.
    FourByte = 0x1,
    /// Use 64-bit (8-byte) banks.
    EightByte = 0x2,
}
132
133impl SharedMemConfig {
134    /// Convert a raw `u32` driver value to a `SharedMemConfig`.
135    fn from_raw(val: u32) -> CudaResult<Self> {
136        match val {
137            0 => Ok(Self::Default),
138            1 => Ok(Self::FourByte),
139            2 => Ok(Self::EightByte),
140            _ => Err(CudaError::InvalidValue),
141        }
142    }
143}
144
145/// Returns the current shared memory configuration for the active context.
146///
147/// # Errors
148///
149/// Returns [`CudaError::NotSupported`] if the driver lacks
150/// `cuCtxGetSharedMemConfig`, or another error on failure.
151pub fn get_shared_mem_config() -> CudaResult<SharedMemConfig> {
152    let api = try_driver()?;
153    let f = api
154        .cu_ctx_get_shared_mem_config
155        .ok_or(CudaError::NotSupported)?;
156    let mut raw: u32 = 0;
157    crate::cuda_call!(f(&mut raw))?;
158    SharedMemConfig::from_raw(raw)
159}
160
161/// Sets the shared memory configuration for the active context.
162///
163/// # Errors
164///
165/// Returns [`CudaError::NotSupported`] if the driver lacks
166/// `cuCtxSetSharedMemConfig`, or another error on failure.
167pub fn set_shared_mem_config(config: SharedMemConfig) -> CudaResult<()> {
168    let api = try_driver()?;
169    let f = api
170        .cu_ctx_set_shared_mem_config
171        .ok_or(CudaError::NotSupported)?;
172    crate::cuda_call!(f(config as u32))
173}
174
175// ---------------------------------------------------------------------------
176// Tests
177// ---------------------------------------------------------------------------
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `CacheConfig` variant, used to drive exhaustive round-trip checks.
    const ALL_CACHE_CONFIGS: [CacheConfig; 4] = [
        CacheConfig::PreferNone,
        CacheConfig::PreferShared,
        CacheConfig::PreferL1,
        CacheConfig::PreferEqual,
    ];

    /// Every `SharedMemConfig` variant, used to drive exhaustive round-trip checks.
    const ALL_SHARED_MEM_CONFIGS: [SharedMemConfig; 3] = [
        SharedMemConfig::Default,
        SharedMemConfig::FourByte,
        SharedMemConfig::EightByte,
    ];

    /// Raw values that must never decode to a valid variant of either enum.
    const OUT_OF_RANGE: [u32; 3] = [4, 99, u32::MAX];

    /// Encoding with `as u32` and decoding with `from_raw` yields the same
    /// variant, and out-of-range raw values are rejected.
    #[test]
    fn cache_config_round_trip() {
        for &config in &ALL_CACHE_CONFIGS {
            assert_eq!(CacheConfig::from_raw(config as u32).ok(), Some(config));
        }
        for &raw in &OUT_OF_RANGE {
            assert!(CacheConfig::from_raw(raw).is_err(), "raw value {raw}");
        }
    }

    /// Same round-trip guarantee for `SharedMemConfig`. Value 3 is also
    /// invalid here because the enum only spans `0..=2`.
    #[test]
    fn shared_mem_config_round_trip() {
        for &config in &ALL_SHARED_MEM_CONFIGS {
            assert_eq!(
                SharedMemConfig::from_raw(config as u32).ok(),
                Some(config)
            );
        }
        assert!(SharedMemConfig::from_raw(3).is_err());
        for &raw in &OUT_OF_RANGE {
            assert!(SharedMemConfig::from_raw(raw).is_err(), "raw value {raw}");
        }
    }

    /// Pins the raw values passed to the driver; changing these breaks the FFI.
    #[test]
    fn cache_config_repr_values() {
        assert_eq!(CacheConfig::PreferNone as u32, 0);
        assert_eq!(CacheConfig::PreferShared as u32, 1);
        assert_eq!(CacheConfig::PreferL1 as u32, 2);
        assert_eq!(CacheConfig::PreferEqual as u32, 3);
    }

    /// Pins the raw values passed to the driver; changing these breaks the FFI.
    #[test]
    fn shared_mem_config_repr_values() {
        assert_eq!(SharedMemConfig::Default as u32, 0);
        assert_eq!(SharedMemConfig::FourByte as u32, 1);
        assert_eq!(SharedMemConfig::EightByte as u32, 2);
    }

    /// Smoke test: `get_limit` must return (either `Ok` or `Err`) rather than
    /// panic when the test has not set up a context, including on hosts with
    /// no GPU or driver. The outcome is deliberately not asserted because it
    /// depends on the host environment.
    #[test]
    fn get_limit_does_not_panic_without_context() {
        let _ = get_limit(CUlimit::StackSize);
    }
}