oxicuda_driver/context_config.rs
1//! Context configuration: limits, cache config, and shared memory config.
2//!
3//! These functions allow tuning per-context resource limits (stack size,
4//! heap size, etc.) and cache / shared memory policies.
5//!
6//! # Example
7//!
8//! ```rust,no_run
9//! # use oxicuda_driver::context_config;
10//! # use oxicuda_driver::ffi::CUlimit;
11//! # fn main() -> Result<(), oxicuda_driver::CudaError> {
12//! let stack = context_config::get_limit(CUlimit::StackSize)?;
13//! println!("GPU thread stack size: {stack} bytes");
14//!
15//! context_config::set_cache_config(context_config::CacheConfig::PreferL1)?;
16//! # Ok(())
17//! # }
18//! ```
19
20use crate::error::{CudaError, CudaResult};
21use crate::ffi::CUlimit;
22use crate::loader::try_driver;
23
24// ---------------------------------------------------------------------------
25// Limits
26// ---------------------------------------------------------------------------
27
28/// Returns the value of a context limit.
29///
30/// # Errors
31///
32/// Returns [`CudaError::NotSupported`] if the driver lacks `cuCtxGetLimit`,
33/// or another error on failure.
34pub fn get_limit(limit: CUlimit) -> CudaResult<usize> {
35 let api = try_driver()?;
36 let f = api.cu_ctx_get_limit.ok_or(CudaError::NotSupported)?;
37 let mut value: usize = 0;
38 crate::cuda_call!(f(&mut value, limit as u32))?;
39 Ok(value)
40}
41
42/// Sets the value of a context limit.
43///
44/// # Errors
45///
46/// Returns [`CudaError::NotSupported`] if the driver lacks `cuCtxSetLimit`,
47/// or another error on failure.
48pub fn set_limit(limit: CUlimit, value: usize) -> CudaResult<()> {
49 let api = try_driver()?;
50 let f = api.cu_ctx_set_limit.ok_or(CudaError::NotSupported)?;
51 crate::cuda_call!(f(limit as u32, value))
52}
53
54// ---------------------------------------------------------------------------
55// CacheConfig
56// ---------------------------------------------------------------------------
57
/// How a CUDA context or function would like its on-chip memory split.
///
/// On devices where L1 cache and shared memory are carved out of the same
/// physical storage, this expresses which of the two should get more of it.
/// The discriminants are the raw values exchanged with the driver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CacheConfig {
    /// Leave the split entirely up to the driver.
    PreferNone = 0x00,
    /// Favour a larger shared-memory carve-out at the expense of L1.
    PreferShared = 0x01,
    /// Favour a larger L1 cache at the expense of shared memory.
    PreferL1 = 0x02,
    /// Split the storage evenly between L1 and shared memory.
    PreferEqual = 0x03,
}
74
75impl CacheConfig {
76 /// Convert a raw `u32` driver value to a `CacheConfig`.
77 fn from_raw(val: u32) -> CudaResult<Self> {
78 match val {
79 0 => Ok(Self::PreferNone),
80 1 => Ok(Self::PreferShared),
81 2 => Ok(Self::PreferL1),
82 3 => Ok(Self::PreferEqual),
83 _ => Err(CudaError::InvalidValue),
84 }
85 }
86}
87
88/// Returns the current cache configuration for the active context.
89///
90/// # Errors
91///
92/// Returns [`CudaError::NotSupported`] if the driver lacks
93/// `cuCtxGetCacheConfig`, or another error on failure.
94pub fn get_cache_config() -> CudaResult<CacheConfig> {
95 let api = try_driver()?;
96 let f = api.cu_ctx_get_cache_config.ok_or(CudaError::NotSupported)?;
97 let mut raw: u32 = 0;
98 crate::cuda_call!(f(&mut raw))?;
99 CacheConfig::from_raw(raw)
100}
101
102/// Sets the cache configuration for the active context.
103///
104/// # Errors
105///
106/// Returns [`CudaError::NotSupported`] if the driver lacks
107/// `cuCtxSetCacheConfig`, or another error on failure.
108pub fn set_cache_config(config: CacheConfig) -> CudaResult<()> {
109 let api = try_driver()?;
110 let f = api.cu_ctx_set_cache_config.ok_or(CudaError::NotSupported)?;
111 crate::cuda_call!(f(config as u32))
112}
113
114// ---------------------------------------------------------------------------
115// SharedMemConfig
116// ---------------------------------------------------------------------------
117
/// Bank width used for shared memory.
///
/// Selects between 4-byte and 8-byte shared-memory banks. Wider banks can
/// cut down on bank conflicts when kernels perform 64-bit accesses.
/// The discriminants are the raw values exchanged with the driver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum SharedMemConfig {
    /// Keep whatever bank size the device uses by default.
    Default = 0x00,
    /// 32-bit (4-byte) banks.
    FourByte = 0x01,
    /// 64-bit (8-byte) banks.
    EightByte = 0x02,
}
132
133impl SharedMemConfig {
134 /// Convert a raw `u32` driver value to a `SharedMemConfig`.
135 fn from_raw(val: u32) -> CudaResult<Self> {
136 match val {
137 0 => Ok(Self::Default),
138 1 => Ok(Self::FourByte),
139 2 => Ok(Self::EightByte),
140 _ => Err(CudaError::InvalidValue),
141 }
142 }
143}
144
145/// Returns the current shared memory configuration for the active context.
146///
147/// # Errors
148///
149/// Returns [`CudaError::NotSupported`] if the driver lacks
150/// `cuCtxGetSharedMemConfig`, or another error on failure.
151pub fn get_shared_mem_config() -> CudaResult<SharedMemConfig> {
152 let api = try_driver()?;
153 let f = api
154 .cu_ctx_get_shared_mem_config
155 .ok_or(CudaError::NotSupported)?;
156 let mut raw: u32 = 0;
157 crate::cuda_call!(f(&mut raw))?;
158 SharedMemConfig::from_raw(raw)
159}
160
161/// Sets the shared memory configuration for the active context.
162///
163/// # Errors
164///
165/// Returns [`CudaError::NotSupported`] if the driver lacks
166/// `cuCtxSetSharedMemConfig`, or another error on failure.
167pub fn set_shared_mem_config(config: SharedMemConfig) -> CudaResult<()> {
168 let api = try_driver()?;
169 let f = api
170 .cu_ctx_set_shared_mem_config
171 .ok_or(CudaError::NotSupported)?;
172 crate::cuda_call!(f(config as u32))
173}
174
175// ---------------------------------------------------------------------------
176// Tests
177// ---------------------------------------------------------------------------
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `CacheConfig` variant; drives the encode/decode round-trip check.
    const ALL_CACHE_CONFIGS: [CacheConfig; 4] = [
        CacheConfig::PreferNone,
        CacheConfig::PreferShared,
        CacheConfig::PreferL1,
        CacheConfig::PreferEqual,
    ];

    /// Every `SharedMemConfig` variant; drives the encode/decode round-trip check.
    const ALL_SHARED_MEM_CONFIGS: [SharedMemConfig; 3] = [
        SharedMemConfig::Default,
        SharedMemConfig::FourByte,
        SharedMemConfig::EightByte,
    ];

    #[test]
    fn cache_config_round_trip() {
        // Encoding with `as u32` (what the setters send) and decoding with
        // `from_raw` (what the getters use) must be exact inverses.
        for config in ALL_CACHE_CONFIGS {
            assert_eq!(CacheConfig::from_raw(config as u32).ok(), Some(config));
        }
        // Unknown driver values must be rejected rather than mapped.
        for raw in [4, 99, u32::MAX] {
            assert!(
                CacheConfig::from_raw(raw).is_err(),
                "raw value {raw} was accepted"
            );
        }
    }

    #[test]
    fn shared_mem_config_round_trip() {
        for config in ALL_SHARED_MEM_CONFIGS {
            assert_eq!(
                SharedMemConfig::from_raw(config as u32).ok(),
                Some(config)
            );
        }
        for raw in [3, 99, u32::MAX] {
            assert!(
                SharedMemConfig::from_raw(raw).is_err(),
                "raw value {raw} was accepted"
            );
        }
    }

    #[test]
    fn cache_config_repr_values() {
        // These must match the driver's cache-configuration encoding exactly.
        assert_eq!(CacheConfig::PreferNone as u32, 0);
        assert_eq!(CacheConfig::PreferShared as u32, 1);
        assert_eq!(CacheConfig::PreferL1 as u32, 2);
        assert_eq!(CacheConfig::PreferEqual as u32, 3);
    }

    #[test]
    fn shared_mem_config_repr_values() {
        // These must match the driver's shared-memory bank encoding exactly.
        assert_eq!(SharedMemConfig::Default as u32, 0);
        assert_eq!(SharedMemConfig::FourByte as u32, 1);
        assert_eq!(SharedMemConfig::EightByte as u32, 2);
    }

    #[test]
    fn get_limit_does_not_panic() {
        // Smoke test only: whether this returns `Ok` or `Err` depends on
        // whether a CUDA driver and a current context exist on the test
        // machine, so the only portable guarantee is that it does not panic.
        let _ = get_limit(CUlimit::StackSize);
    }
}