1use std::fmt;
7
/// Reasons a launch configuration (block size, grid size, shared memory,
/// individual dimensions) can be rejected.
///
/// Each `*ExceedsLimit` variant carries the value that was requested together
/// with the device maximum it was checked against; the `Display` impl renders
/// both numbers in the message.
// Every field is `Copy` (`u32` / `&'static str`), so the error is deliberately
// `Copy` and `Hash`: callers can reuse it without cloning and use it as a
// set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LaunchError {
    /// The requested number of threads per block is larger than the device allows.
    BlockSizeExceedsLimit {
        /// Threads per block that were requested.
        requested: u32,
        /// Maximum threads per block reported for the device.
        max: u32,
    },
    /// A requested grid dimension is larger than the device allows.
    GridSizeExceedsLimit {
        /// Requested grid dimension.
        requested: u32,
        /// Maximum grid dimension reported for the device.
        max: u32,
    },
    /// The requested shared memory size is larger than the device allows.
    SharedMemoryExceedsLimit {
        /// Shared memory requested, in bytes.
        requested: u32,
        /// Maximum shared memory reported for the device, in bytes.
        max: u32,
    },
    /// A launch dimension has an invalid value.
    ///
    /// The `Display` message states that dimensions must be `> 0`; presumably
    /// this is produced for zero-sized dimensions (confirm at the call site).
    InvalidDimension {
        /// Label naming the offending dimension, e.g. `"block.x"` or `"grid.z"`.
        dim: &'static str,
        /// The rejected value.
        value: u32,
    },
}
43
44impl fmt::Display for LaunchError {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 match self {
47 Self::BlockSizeExceedsLimit { requested, max } => {
48 write!(
49 f,
50 "block size {requested} exceeds device maximum {max} threads per block"
51 )
52 }
53 Self::GridSizeExceedsLimit { requested, max } => {
54 write!(f, "grid dimension {requested} exceeds device maximum {max}")
55 }
56 Self::SharedMemoryExceedsLimit { requested, max } => {
57 write!(
58 f,
59 "shared memory {requested} bytes exceeds device maximum {max} bytes"
60 )
61 }
62 Self::InvalidDimension { dim, value } => {
63 write!(f, "invalid dimension {dim} = {value} (must be > 0)")
64 }
65 }
66 }
67}
68
69impl std::error::Error for LaunchError {}
70
/// Unit tests for `LaunchError`'s `Display`, `Debug`, `PartialEq`, and
/// `std::error::Error` implementations.
#[cfg(test)]
mod tests {
    use super::*;

    // Messages are compared exactly rather than with `contains`: a substring
    // check such as `contains("0")` is also satisfied by unrelated parts of the
    // message (e.g. the "(must be > 0)" suffix) and so cannot catch regressions.

    #[test]
    fn block_size_exceeds_display() {
        let err = LaunchError::BlockSizeExceedsLimit {
            requested: 2048,
            max: 1024,
        };
        assert_eq!(
            err.to_string(),
            "block size 2048 exceeds device maximum 1024 threads per block"
        );
    }

    #[test]
    fn grid_size_exceeds_display() {
        let err = LaunchError::GridSizeExceedsLimit {
            requested: 100_000,
            max: 65535,
        };
        assert_eq!(
            err.to_string(),
            "grid dimension 100000 exceeds device maximum 65535"
        );
    }

    #[test]
    fn shared_memory_exceeds_display() {
        let err = LaunchError::SharedMemoryExceedsLimit {
            requested: 65536,
            max: 49152,
        };
        assert_eq!(
            err.to_string(),
            "shared memory 65536 bytes exceeds device maximum 49152 bytes"
        );
    }

    #[test]
    fn invalid_dimension_display() {
        let err = LaunchError::InvalidDimension {
            dim: "block.x",
            value: 0,
        };
        assert_eq!(
            err.to_string(),
            "invalid dimension block.x = 0 (must be > 0)"
        );
    }

    #[test]
    fn launch_error_implements_std_error() {
        fn assert_error<T: std::error::Error>() {}
        assert_error::<LaunchError>();

        // No underlying cause is wrapped.
        let err = LaunchError::InvalidDimension {
            dim: "grid.y",
            value: 0,
        };
        assert!(std::error::Error::source(&err).is_none());
    }

    #[test]
    fn launch_error_eq() {
        let a = LaunchError::BlockSizeExceedsLimit {
            requested: 512,
            max: 256,
        };
        let b = LaunchError::BlockSizeExceedsLimit {
            requested: 512,
            max: 256,
        };
        assert_eq!(a, b);

        // A differing field value or a differing variant must compare unequal.
        assert_ne!(
            a,
            LaunchError::BlockSizeExceedsLimit {
                requested: 512,
                max: 1024,
            }
        );
        assert_ne!(
            a,
            LaunchError::GridSizeExceedsLimit {
                requested: 512,
                max: 256,
            }
        );
    }

    #[test]
    fn launch_error_debug() {
        let err = LaunchError::InvalidDimension {
            dim: "grid.z",
            value: 0,
        };
        let dbg = format!("{err:?}");
        assert!(dbg.contains("InvalidDimension"));
        assert!(dbg.contains("grid.z"));
    }
}