Skip to main content

singe_cuda/
nvtx.rs

1use std::{
2    ffi::{CStr, CString},
3    marker::PhantomData,
4};
5
6use num_enum::{IntoPrimitive, TryFromPrimitive};
7use singe_core::{impl_enum_conversion, impl_enum_display};
8use singe_cuda_sys::nvtx as sys;
9
10use crate::error::{Error, Result};
11
// TODO: move to a core version type
/// A major-only version number; [`version`] reports the NVTX API version with it.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Version {
    /// Major version component.
    pub major: u32,
}
17
/// A packed 32-bit color in `0xAARRGGBB` layout, as attached to NVTX events.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Color(u32);

impl Color {
    /// Wraps an already packed `0xAARRGGBB` value unchanged.
    pub const fn argb(value: u32) -> Self {
        Color(value)
    }

    /// Packs individual channels into `0xAARRGGBB`.
    ///
    /// The parameters are ordered red, green, blue, alpha, but the packed layout
    /// puts alpha in the most significant byte.
    pub const fn rgba(red: u8, green: u8, blue: u8, alpha: u8) -> Self {
        // Big-endian byte order: first array element becomes the top byte.
        Color(u32::from_be_bytes([alpha, red, green, blue]))
    }

    /// Returns the packed `0xAARRGGBB` value.
    pub const fn as_raw(self) -> u32 {
        let Color(packed) = self;
        packed
    }
}
34
/// A numeric NVTX category identifier attached to events and named via
/// `name_category`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Category(u32);

impl Category {
    /// Wraps a raw category id without validation.
    pub const fn from_raw(value: u32) -> Self {
        Category(value)
    }

    /// Returns the raw category id.
    pub const fn as_raw(self) -> u32 {
        let Category(id) = self;
        id
    }
}
47
/// Safe mirror of `sys::nvtxColorType_t`, the tag stored in
/// `nvtxEventAttributes_t::colorType`.
///
/// `impl_enum_conversion!` provides conversions to/from the sys enum (the tests
/// exercise `ColorType::from(sys::nvtxColorType_t)`), and `impl_enum_display!`
/// renders each variant as its C constant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum ColorType {
    /// `NVTX_COLOR_UNKNOWN`; never written by `EventAttributes::encode`.
    Unknown = sys::nvtxColorType_t::NVTX_COLOR_UNKNOWN as _,
    /// `NVTX_COLOR_ARGB`; written whenever a `Color` (packed `0xAARRGGBB`) is set.
    Argb = sys::nvtxColorType_t::NVTX_COLOR_ARGB as _,
}

impl_enum_conversion!(sys::nvtxColorType_t, ColorType);

// `Display` yields the C constant name for each variant.
impl_enum_display!(ColorType, {
    Self::Unknown => "NVTX_COLOR_UNKNOWN",
    Self::Argb => "NVTX_COLOR_ARGB",
});
62
/// Safe mirror of `sys::nvtxMessageType_t`, the tag stored in
/// `nvtxEventAttributes_t::messageType`.
///
/// Conversions come from `impl_enum_conversion!`; `Display` (from
/// `impl_enum_display!`) renders the C constant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum MessageType {
    /// `NVTX_MESSAGE_UNKNOWN`.
    Unknown = sys::nvtxMessageType_t::NVTX_MESSAGE_UNKNOWN as _,
    /// `NVTX_MESSAGE_TYPE_ASCII`; the only message type `EventAttributes::encode` emits.
    Ascii = sys::nvtxMessageType_t::NVTX_MESSAGE_TYPE_ASCII as _,
    /// `NVTX_MESSAGE_TYPE_UNICODE`; not produced by the wrappers in this module.
    Unicode = sys::nvtxMessageType_t::NVTX_MESSAGE_TYPE_UNICODE as _,
    /// `NVTX_MESSAGE_TYPE_REGISTERED`; not produced by the wrappers in this module.
    Registered = sys::nvtxMessageType_t::NVTX_MESSAGE_TYPE_REGISTERED as _,
}

impl_enum_conversion!(sys::nvtxMessageType_t, MessageType);

// `Display` yields the C constant name for each variant.
impl_enum_display!(MessageType, {
    Self::Unknown => "NVTX_MESSAGE_UNKNOWN",
    Self::Ascii => "NVTX_MESSAGE_TYPE_ASCII",
    Self::Unicode => "NVTX_MESSAGE_TYPE_UNICODE",
    Self::Registered => "NVTX_MESSAGE_TYPE_REGISTERED",
});
81
/// Safe mirror of `sys::nvtxPayloadType_t`, the tag stored in
/// `nvtxEventAttributes_t::payloadType`.
///
/// Each non-`Unknown` variant is selected by exactly one `Payload` variant in
/// `Payload::encode_type`. Conversions come from `impl_enum_conversion!`;
/// `Display` renders the C constant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum PayloadType {
    /// `NVTX_PAYLOAD_UNKNOWN`; never produced by `Payload`.
    Unknown = sys::nvtxPayloadType_t::NVTX_PAYLOAD_UNKNOWN as _,
    /// Selected for `Payload::U64`.
    UnsignedInt64 = sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_UNSIGNED_INT64 as _,
    /// Selected for `Payload::I64`.
    Int64 = sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_INT64 as _,
    /// Selected for `Payload::F64`.
    Double = sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_DOUBLE as _,
    /// Selected for `Payload::U32`.
    UnsignedInt32 = sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_UNSIGNED_INT32 as _,
    /// Selected for `Payload::I32`.
    Int32 = sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_INT32 as _,
    /// Selected for `Payload::F32`.
    Float = sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_FLOAT as _,
}

impl_enum_conversion!(sys::nvtxPayloadType_t, PayloadType);

// `Display` yields the C constant name for each variant.
impl_enum_display!(PayloadType, {
    Self::Unknown => "NVTX_PAYLOAD_UNKNOWN",
    Self::UnsignedInt64 => "NVTX_PAYLOAD_TYPE_UNSIGNED_INT64",
    Self::Int64 => "NVTX_PAYLOAD_TYPE_INT64",
    Self::Double => "NVTX_PAYLOAD_TYPE_DOUBLE",
    Self::UnsignedInt32 => "NVTX_PAYLOAD_TYPE_UNSIGNED_INT32",
    Self::Int32 => "NVTX_PAYLOAD_TYPE_INT32",
    Self::Float => "NVTX_PAYLOAD_TYPE_FLOAT",
});
106
/// Safe mirror of `sys::nvtxResourceGenericType_t`.
///
/// Not used by the other wrappers in this module; exposed for callers that work
/// with NVTX resource naming directly. Conversions come from
/// `impl_enum_conversion!`; `Display` renders the C constant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum ResourceGenericType {
    /// `NVTX_RESOURCE_TYPE_UNKNOWN`.
    Unknown = sys::nvtxResourceGenericType_t::NVTX_RESOURCE_TYPE_UNKNOWN as _,
    /// `NVTX_RESOURCE_TYPE_GENERIC_POINTER`.
    GenericPointer = sys::nvtxResourceGenericType_t::NVTX_RESOURCE_TYPE_GENERIC_POINTER as _,
    /// `NVTX_RESOURCE_TYPE_GENERIC_HANDLE`.
    GenericHandle = sys::nvtxResourceGenericType_t::NVTX_RESOURCE_TYPE_GENERIC_HANDLE as _,
    /// `NVTX_RESOURCE_TYPE_GENERIC_THREAD_NATIVE`.
    GenericThreadNative =
        sys::nvtxResourceGenericType_t::NVTX_RESOURCE_TYPE_GENERIC_THREAD_NATIVE as _,
    /// `NVTX_RESOURCE_TYPE_GENERIC_THREAD_POSIX`.
    GenericThreadPosix =
        sys::nvtxResourceGenericType_t::NVTX_RESOURCE_TYPE_GENERIC_THREAD_POSIX as _,
}

impl_enum_conversion!(sys::nvtxResourceGenericType_t, ResourceGenericType);

// `Display` yields the C constant name for each variant.
impl_enum_display!(ResourceGenericType, {
    Self::Unknown => "NVTX_RESOURCE_TYPE_UNKNOWN",
    Self::GenericPointer => "NVTX_RESOURCE_TYPE_GENERIC_POINTER",
    Self::GenericHandle => "NVTX_RESOURCE_TYPE_GENERIC_HANDLE",
    Self::GenericThreadNative => "NVTX_RESOURCE_TYPE_GENERIC_THREAD_NATIVE",
    Self::GenericThreadPosix => "NVTX_RESOURCE_TYPE_GENERIC_THREAD_POSIX",
});
129
130#[derive(Debug, Clone, Copy, PartialEq)]
131#[non_exhaustive]
132pub enum Payload {
133    I32(i32),
134    I64(i64),
135    U32(u32),
136    U64(u64),
137    F32(f32),
138    F64(f64),
139}
140
141impl Payload {
142    fn encode_type(self) -> sys::nvtxPayloadType_t {
143        match self {
144            Self::I32(_) => PayloadType::Int32.into(),
145            Self::I64(_) => PayloadType::Int64.into(),
146            Self::U32(_) => PayloadType::UnsignedInt32.into(),
147            Self::U64(_) => PayloadType::UnsignedInt64.into(),
148            Self::F32(_) => PayloadType::Float.into(),
149            Self::F64(_) => PayloadType::Double.into(),
150        }
151    }
152
153    fn encode_value(self) -> sys::nvtxEventAttributes_v2_payload_t {
154        match self {
155            Self::I32(value) => sys::nvtxEventAttributes_v2_payload_t { iValue: value },
156            Self::I64(value) => sys::nvtxEventAttributes_v2_payload_t { llValue: value },
157            Self::U32(value) => sys::nvtxEventAttributes_v2_payload_t { uiValue: value },
158            Self::U64(value) => sys::nvtxEventAttributes_v2_payload_t { ullValue: value },
159            Self::F32(value) => sys::nvtxEventAttributes_v2_payload_t { fValue: value },
160            Self::F64(value) => sys::nvtxEventAttributes_v2_payload_t { dValue: value },
161        }
162    }
163}
164
165#[derive(Debug, Clone, Copy)]
166pub struct EventAttributes<'a> {
167    message: Option<&'a CStr>,
168    category: Option<Category>,
169    color: Option<Color>,
170    payload: Option<Payload>,
171}
172
173impl<'a> EventAttributes<'a> {
174    pub const fn new() -> Self {
175        Self {
176            message: None,
177            category: None,
178            color: None,
179            payload: None,
180        }
181    }
182
183    pub fn with_message(mut self, message: &'a CStr) -> Self {
184        self.message = Some(message);
185        self
186    }
187
188    pub fn with_category(mut self, category: Category) -> Self {
189        self.category = Some(category);
190        self
191    }
192
193    pub fn with_color(mut self, color: Color) -> Self {
194        self.color = Some(color);
195        self
196    }
197
198    pub fn with_payload(mut self, payload: Payload) -> Self {
199        self.payload = Some(payload);
200        self
201    }
202
203    pub const fn message(&self) -> Option<&'a CStr> {
204        self.message
205    }
206
207    pub const fn category(&self) -> Option<Category> {
208        self.category
209    }
210
211    pub const fn color(&self) -> Option<Color> {
212        self.color
213    }
214
215    pub const fn payload(&self) -> Option<Payload> {
216        self.payload
217    }
218
219    fn encode(self) -> sys::nvtxEventAttributes_t {
220        let mut raw = sys::nvtxEventAttributes_t {
221            version: sys::NVTX_VERSION as u16,
222            size: size_of::<sys::nvtxEventAttributes_t>() as u16,
223            ..Default::default()
224        };
225
226        if let Some(category) = self.category {
227            raw.category = category.0;
228        }
229
230        if let Some(color) = self.color {
231            raw.colorType = sys::nvtxColorType_t::from(ColorType::Argb) as i32;
232            raw.color = color.0;
233        }
234
235        if let Some(payload) = self.payload {
236            raw.payloadType = payload.encode_type() as i32;
237            raw.payload = payload.encode_value();
238        }
239
240        if let Some(message) = self.message {
241            raw.messageType = sys::nvtxMessageType_t::from(MessageType::Ascii) as i32;
242            raw.message.ascii = message.as_ptr();
243        }
244
245        raw
246    }
247}
248
249impl Default for EventAttributes<'_> {
250    fn default() -> Self {
251        Self::new()
252    }
253}
254
255#[derive(Debug, Clone)]
256pub struct Event {
257    message: CString,
258    category: Option<Category>,
259    color: Option<Color>,
260    payload: Option<Payload>,
261}
262
263impl Event {
264    pub fn create(message: &str) -> Result<Self> {
265        Ok(Self {
266            message: CString::new(message)?,
267            category: None,
268            color: None,
269            payload: None,
270        })
271    }
272
273    pub fn create_from_c_string(message: CString) -> Self {
274        Self {
275            message,
276            category: None,
277            color: None,
278            payload: None,
279        }
280    }
281
282    pub fn with_category(mut self, category: Category) -> Self {
283        self.category = Some(category);
284        self
285    }
286
287    pub fn with_color(mut self, color: Color) -> Self {
288        self.color = Some(color);
289        self
290    }
291
292    pub fn with_payload(mut self, payload: Payload) -> Self {
293        self.payload = Some(payload);
294        self
295    }
296
297    pub fn mark(&self) {
298        mark_with_attributes(self.attributes());
299    }
300
301    pub fn local_range(&self) -> LocalRange {
302        LocalRange::from_attributes(self.attributes())
303    }
304
305    pub fn range(&self) -> Range {
306        Range::from_attributes(self.attributes())
307    }
308
309    pub fn domain_mark(&self, domain: &Domain) {
310        domain.mark_with_attributes(self.attributes());
311    }
312
313    pub fn domain_local_range<'a>(&self, domain: &'a Domain) -> DomainLocalRange<'a> {
314        domain.range_with_attributes(self.attributes())
315    }
316
317    pub fn domain_range<'a>(&self, domain: &'a Domain) -> DomainRange<'a> {
318        domain.start_range_with_attributes(self.attributes())
319    }
320
321    pub fn attributes(&self) -> EventAttributes<'_> {
322        let mut attributes = EventAttributes::new().with_message(&self.message);
323
324        if let Some(category) = self.category {
325            attributes = attributes.with_category(category);
326        }
327
328        if let Some(color) = self.color {
329            attributes = attributes.with_color(color);
330        }
331
332        if let Some(payload) = self.payload {
333            attributes = attributes.with_payload(payload);
334        }
335
336        attributes
337    }
338}
339
/// A named NVTX domain that scopes the marks, ranges and category names issued
/// through it.
///
/// The handle is released with `nvtxDomainDestroy` when the value is dropped.
#[derive(Debug)]
pub struct Domain {
    // Raw handle obtained from `nvtxDomainCreateA` in `Domain::create_from_c_str`
    // (the only constructor, since the field is private).
    handle: sys::nvtxDomainHandle_t,
}

// NVTX domains are process-wide annotation handles. The wrapper owns the handle
// and only passes immutable copies to NVTX entry points.
// NOTE(review): soundness of these impls also relies on the `nvtxDomain*` entry
// points being safe to call concurrently from several threads with the same
// handle — TODO confirm against the NVTX documentation.
unsafe impl Send for Domain {}
unsafe impl Sync for Domain {}
349
350impl Domain {
351    pub fn create(name: &str) -> Result<Self> {
352        let name = CString::new(name)?;
353        Self::create_from_c_str(&name)
354    }
355
356    pub fn create_from_c_str(name: &CStr) -> Result<Self> {
357        let handle = unsafe { sys::nvtxDomainCreateA(name.as_ptr()) };
358        if handle.is_null() {
359            return Err(Error::NullHandle);
360        }
361        Ok(Self { handle })
362    }
363
364    pub fn as_raw(&self) -> sys::nvtxDomainHandle_t {
365        self.handle
366    }
367
368    pub fn mark(&self, message: &str) -> Result<()> {
369        let message = CString::new(message)?;
370        self.mark_c_str(&message);
371        Ok(())
372    }
373
374    pub fn mark_c_str(&self, message: &CStr) {
375        self.mark_with_attributes(EventAttributes::new().with_message(message));
376    }
377
378    pub fn mark_with_attributes(&self, attributes: EventAttributes<'_>) {
379        let raw = attributes.encode();
380        unsafe { sys::nvtxDomainMarkEx(self.handle, &raw) };
381    }
382
383    pub fn range<'a>(&'a self, message: &str) -> Result<DomainLocalRange<'a>> {
384        let message = CString::new(message)?;
385        Ok(self.range_c_str(&message))
386    }
387
388    pub fn range_c_str<'a>(&'a self, message: &CStr) -> DomainLocalRange<'a> {
389        self.range_with_attributes(EventAttributes::new().with_message(message))
390    }
391
392    pub fn range_with_attributes<'a>(
393        &'a self,
394        attributes: EventAttributes<'_>,
395    ) -> DomainLocalRange<'a> {
396        let raw = attributes.encode();
397        unsafe { sys::nvtxDomainRangePushEx(self.handle, &raw) };
398        DomainLocalRange {
399            domain: self,
400            _not_send: PhantomData,
401        }
402    }
403
404    pub fn start_range(&self, message: &str) -> Result<DomainRange<'_>> {
405        let message = CString::new(message)?;
406        Ok(self.start_range_c_str(&message))
407    }
408
409    pub fn start_range_c_str(&self, message: &CStr) -> DomainRange<'_> {
410        self.start_range_with_attributes(EventAttributes::new().with_message(message))
411    }
412
413    pub fn start_range_with_attributes(&self, attributes: EventAttributes<'_>) -> DomainRange<'_> {
414        let raw = attributes.encode();
415        let id = unsafe { sys::nvtxDomainRangeStartEx(self.handle, &raw) };
416        DomainRange { domain: self, id }
417    }
418
419    pub fn name_category(&self, category: Category, name: &str) -> Result<()> {
420        let name = CString::new(name)?;
421        unsafe { sys::nvtxDomainNameCategoryA(self.handle, category.0, name.as_ptr()) };
422        Ok(())
423    }
424}
425
426impl Drop for Domain {
427    fn drop(&mut self) {
428        unsafe { sys::nvtxDomainDestroy(self.handle) };
429    }
430}
431
432#[derive(Debug)]
433pub struct LocalRange {
434    _not_send: PhantomData<*mut ()>,
435}
436
437impl LocalRange {
438    pub fn create(message: &str) -> Result<Self> {
439        let message = CString::new(message)?;
440        Ok(Self::create_from_c_str(&message))
441    }
442
443    pub fn create_from_c_str(message: &CStr) -> Self {
444        unsafe { sys::nvtxRangePushA(message.as_ptr()) };
445        Self {
446            _not_send: PhantomData,
447        }
448    }
449
450    pub fn from_attributes(attributes: EventAttributes<'_>) -> Self {
451        let raw = attributes.encode();
452        unsafe { sys::nvtxRangePushEx(&raw) };
453        Self {
454            _not_send: PhantomData,
455        }
456    }
457}
458
459impl Drop for LocalRange {
460    fn drop(&mut self) {
461        unsafe { sys::nvtxRangePop() };
462    }
463}
464
465#[derive(Debug)]
466pub struct Range {
467    id: sys::nvtxRangeId_t,
468}
469
470impl Range {
471    pub fn create(message: &str) -> Result<Self> {
472        let message = CString::new(message)?;
473        Ok(Self::create_from_c_str(&message))
474    }
475
476    pub fn create_from_c_str(message: &CStr) -> Self {
477        let id = unsafe { sys::nvtxRangeStartA(message.as_ptr()) };
478        Self { id }
479    }
480
481    pub fn from_attributes(attributes: EventAttributes<'_>) -> Self {
482        let raw = attributes.encode();
483        let id = unsafe { sys::nvtxRangeStartEx(&raw) };
484        Self { id }
485    }
486}
487
488impl Drop for Range {
489    fn drop(&mut self) {
490        unsafe { sys::nvtxRangeEnd(self.id) };
491    }
492}
493
494#[derive(Debug)]
495pub struct DomainLocalRange<'a> {
496    domain: &'a Domain,
497    _not_send: PhantomData<*mut ()>,
498}
499
500impl Drop for DomainLocalRange<'_> {
501    fn drop(&mut self) {
502        unsafe { sys::nvtxDomainRangePop(self.domain.handle) };
503    }
504}
505
506#[derive(Debug)]
507pub struct DomainRange<'a> {
508    domain: &'a Domain,
509    id: sys::nvtxRangeId_t,
510}
511
512impl Drop for DomainRange<'_> {
513    fn drop(&mut self) {
514        unsafe { sys::nvtxDomainRangeEnd(self.domain.handle, self.id) };
515    }
516}
517
518pub fn version() -> Version {
519    Version {
520        major: sys::NVTX_VERSION,
521    }
522}
523
524pub fn initialize() {
525    unsafe { sys::nvtxInitialize(std::ptr::null()) };
526}
527
528pub fn mark(message: &str) -> Result<()> {
529    Event::create(message)?.mark();
530    Ok(())
531}
532
533pub fn mark_c_str(message: &CStr) {
534    unsafe { sys::nvtxMarkA(message.as_ptr()) };
535}
536
537pub fn mark_with_attributes(attributes: EventAttributes<'_>) {
538    let raw = attributes.encode();
539    unsafe { sys::nvtxMarkEx(&raw) };
540}
541
542pub fn name_category(category: Category, name: &str) -> Result<()> {
543    let name = CString::new(name)?;
544    unsafe { sys::nvtxNameCategoryA(category.0, name.as_ptr()) };
545    Ok(())
546}
547
548pub fn name_os_thread(thread_id: u32, name: &str) -> Result<()> {
549    let name = CString::new(name)?;
550    unsafe { sys::nvtxNameOsThreadA(thread_id, name.as_ptr()) };
551    Ok(())
552}
553
554pub fn scoped_range(message: &str) -> Result<LocalRange> {
555    LocalRange::create(message)
556}
557
#[cfg(test)]
mod tests {
    use std::mem;

    use super::*;

    #[test]
    fn encodes_event_attributes() {
        let message = c"work";
        let raw = EventAttributes::new()
            .with_message(message)
            .with_category(Category::from_raw(7))
            .with_color(Color::rgba(1, 2, 3, 4))
            .with_payload(Payload::I64(-42))
            .encode();

        assert_eq!(raw.version, sys::NVTX_VERSION as u16);
        assert_eq!(
            raw.size,
            mem::size_of::<sys::nvtxEventAttributes_t>() as u16
        );
        assert_eq!(raw.category, 7);
        assert_eq!(raw.colorType, sys::nvtxColorType_t::NVTX_COLOR_ARGB as i32);
        assert_eq!(raw.color, 0x0401_0203);
        assert_eq!(
            raw.messageType,
            sys::nvtxMessageType_t::NVTX_MESSAGE_TYPE_ASCII as i32
        );
        assert_eq!(unsafe { raw.message.ascii }, message.as_ptr());
        assert_eq!(
            raw.payloadType,
            sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_INT64 as i32
        );
        assert_eq!(unsafe { raw.payload.llValue }, -42);
    }

    /// With nothing set, only `version`/`size` are populated and every type tag
    /// stays at its UNKNOWN value.
    #[test]
    fn empty_attributes_leave_optional_fields_unknown() {
        let raw = EventAttributes::new().encode();

        assert_eq!(raw.version, sys::NVTX_VERSION as u16);
        assert_eq!(
            raw.size,
            mem::size_of::<sys::nvtxEventAttributes_t>() as u16
        );
        assert_eq!(raw.category, 0);
        assert_eq!(
            raw.colorType,
            sys::nvtxColorType_t::NVTX_COLOR_UNKNOWN as i32
        );
        assert_eq!(
            raw.messageType,
            sys::nvtxMessageType_t::NVTX_MESSAGE_UNKNOWN as i32
        );
        assert_eq!(
            raw.payloadType,
            sys::nvtxPayloadType_t::NVTX_PAYLOAD_UNKNOWN as i32
        );
        assert!(unsafe { raw.message.ascii }.is_null());
    }

    #[test]
    fn owned_event_builds_attributes() {
        let event = Event::create("owned")
            .unwrap()
            .with_category(Category::from_raw(3))
            .with_color(Color::argb(0xff00_00ff))
            .with_payload(Payload::U32(11));

        let attributes = event.attributes();
        let raw = attributes.encode();

        assert_eq!(attributes.message(), Some(c"owned".as_ref()));
        assert_eq!(attributes.category(), Some(Category::from_raw(3)));
        assert_eq!(attributes.color(), Some(Color::argb(0xff00_00ff)));
        assert_eq!(attributes.payload(), Some(Payload::U32(11)));
        assert_eq!(raw.category, 3);
        assert_eq!(raw.color, 0xff00_00ff);
        assert_eq!(
            raw.payloadType,
            sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_UNSIGNED_INT32 as i32
        );
        assert_eq!(unsafe { raw.payload.uiValue }, 11);
    }

    /// An interior NUL cannot be represented in a C string and must be rejected.
    #[test]
    fn event_rejects_interior_nul() {
        assert!(Event::create("bad\0message").is_err());
    }

    /// `rgba` takes channels in R, G, B, A order but packs them as 0xAARRGGBB.
    #[test]
    fn rgba_packs_channels_as_argb() {
        assert_eq!(Color::rgba(0x12, 0x34, 0x56, 0x78).as_raw(), 0x7812_3456);
        assert_eq!(Color::rgba(0xff, 0x00, 0x00, 0x80), Color::argb(0x80ff_0000));
    }

    #[test]
    fn payload_variants_select_matching_types() {
        let cases = [
            (
                Payload::I32(-1),
                sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_INT32 as i32,
            ),
            (
                Payload::I64(-1),
                sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_INT64 as i32,
            ),
            (
                Payload::U32(1),
                sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_UNSIGNED_INT32 as i32,
            ),
            (
                Payload::U64(1),
                sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_UNSIGNED_INT64 as i32,
            ),
            (
                Payload::F32(1.0),
                sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_FLOAT as i32,
            ),
            (
                Payload::F64(1.0),
                sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_DOUBLE as i32,
            ),
        ];

        for (payload, expected) in cases {
            assert_eq!(payload.encode_type() as i32, expected, "{payload:?}");
        }

        assert_eq!(unsafe { Payload::F64(1.5).encode_value().dValue }, 1.5);
        assert_eq!(unsafe { Payload::U64(u64::MAX).encode_value().ullValue }, u64::MAX);
    }

    #[test]
    fn enum_wrappers_convert_and_display() {
        assert_eq!(
            ColorType::from(sys::nvtxColorType_t::NVTX_COLOR_ARGB),
            ColorType::Argb
        );
        assert_eq!(
            sys::nvtxMessageType_t::from(MessageType::Ascii),
            sys::nvtxMessageType_t::NVTX_MESSAGE_TYPE_ASCII
        );
        assert_eq!(
            PayloadType::UnsignedInt64.to_string(),
            "NVTX_PAYLOAD_TYPE_UNSIGNED_INT64"
        );
        assert_eq!(
            ResourceGenericType::GenericThreadPosix.to_string(),
            "NVTX_RESOURCE_TYPE_GENERIC_THREAD_POSIX"
        );
    }
}