tritonserver_rs/
context.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
use std::{
    collections::HashMap,
    ffi::{c_char, c_int},
    ptr::null_mut,
    sync::Arc,
};

use cuda_driver_sys::{
    cuCtxCreate_v2, cuCtxDestroy_v2, cuCtxGetApiVersion, cuCtxPopCurrent_v2, cuCtxPushCurrent_v2,
    cuDeviceGet, cuDeviceGetAttribute, cuDeviceGetName, cuDeviceTotalMem_v2, cuInit, CUcontext,
    CUdevice, CUdevice_attribute,
};
use parking_lot::{Once, RwLock};

use crate::{error::Error, from_char_array};

/// Initialize the Cuda driver API.
///
/// Must be called before any other Cuda function, ideally once at application startup.
///
/// # Errors
/// Returns an error if `cuInit` does not report success.
pub fn init_cuda() -> Result<(), Error> {
    // No initialization flags are passed to the driver.
    let flags = 0;
    cuda_call!(cuInit(flags))
}

lazy_static::lazy_static! {
    // Process-wide cache of contexts, keyed by the device ordinal handed to `get_context`.
    static ref CUDA_CONTEXTS: RwLock<HashMap<i32, Arc<Context>>> = RwLock::new(HashMap::new());
    // Guards the one-time `init_cuda` call performed by `get_context`.
    static ref ONCE: Once = Once::new();
}

/// Get the shared Cuda context for `device`, creating and caching it on first use.
///
/// Contexts are cached per device ordinal, so repeated calls with the same `device`
/// return clones of the same `Arc<Context>`. The Cuda driver is initialized lazily
/// (at most once per process) the first time a context has to be created.
///
/// # Errors
/// Returns an error instead of panicking if driver initialization, device lookup,
/// the device queries used for the info log line, or context creation fails.
pub fn get_context(device: i32) -> Result<Arc<Context>, Error> {
    // Fast path: an already created context only needs the shared lock.
    if let Some(ctx) = CUDA_CONTEXTS.read().get(&device) {
        return Ok(ctx.clone());
    }

    // Hand the initialization result back to the caller; panicking inside `call_once`
    // would also poison the `Once` for every later caller.
    let mut init_result = Ok(());
    ONCE.call_once(|| init_result = init_cuda());
    init_result?;

    // Take the exclusive lock before creating anything and look again: another thread
    // may have created the context between the read above and this point. Creating
    // under the lock guarantees a single context per device.
    let mut contexts = CUDA_CONTEXTS.write();
    if let Some(ctx) = contexts.get(&device) {
        return Ok(ctx.clone());
    }

    let dev = CuDevice::new(device)?;
    // The queries stay inside the macro arguments on purpose, so they are only
    // performed when the logger actually evaluates them.
    log::info!(
        "Using: {} {:.2}Gb",
        dev.get_name()?,
        dev.get_total_mem()? as f64 / (1_000_000_000) as f64
    );

    let arc = Arc::new(Context::new(dev, 0)?);
    contexts.insert(device, arc.clone());

    Ok(arc)
}

/// Guard for a Cuda context that was pushed as current.
///
/// Returned by [`Context::make_current`]. When the guard is dropped, the context is
/// popped from the current-context stack again.
pub struct ContextHandler<'a> {
    /// Borrow of the pushed context; ties the guard's lifetime to the `Context`
    /// so the guard cannot outlive it.
    _ctx: &'a Context,
}

impl Drop for ContextHandler<'_> {
    /// Pops the current context via `cuCtxPopCurrent_v2`.
    fn drop(&mut self) {
        // A null out-pointer is passed, so the popped handle is not received back.
        // `drop` has no way to report a failure, hence the result is discarded.
        let _ = cuda_call!(cuCtxPopCurrent_v2(null_mut()));
    }
}

/// Cuda Context.
pub struct Context {
    context: cuda_driver_sys::CUcontext,
}

// SAFETY (NOTE(review)): `Context` only holds a raw driver handle, which is why the
// compiler does not derive these automatically. These impls presumably rely on the Cuda
// driver allowing a context handle to be used and destroyed from any thread — confirm
// against the driver API documentation.
unsafe impl Send for Context {}
unsafe impl Sync for Context {}

impl Context {
    /// Create Context on device `dev`. It is recommended to use zeroed `flags`.
    pub fn new(dev: CuDevice, flags: u32) -> Result<Context, Error> {
        let mut ctx = Context {
            context: std::ptr::null_mut(),
        };

        cuda_call!(cuCtxCreate_v2(
            &mut ctx.context as *mut CUcontext,
            flags,
            dev.device
        ))
        .map(|_| ctx)
    }

    /// Get Cuda API version.
    pub fn get_api_version(&self) -> Result<u32, Error> {
        let mut ver = 0;
        cuda_call!(cuCtxGetApiVersion(self.context, &mut ver as *mut u32)).map(|_| ver)
    }

    /// Make this context current.
    pub fn make_current(&self) -> Result<ContextHandler<'_>, Error> {
        cuda_call!(cuCtxPushCurrent_v2(self.context))?;

        Ok(ContextHandler { _ctx: self })
    }
}

impl Drop for Context {
    /// Destroys the underlying driver context, if one was ever created.
    fn drop(&mut self) {
        // A null handle means `cuCtxCreate_v2` never produced a context.
        if self.context.is_null() {
            return;
        }

        // `drop` cannot report failures, so the result is intentionally discarded.
        let _ = cuda_call!(cuCtxDestroy_v2(self.context));
    }
}

/// Cuda representation of the device.
///
/// A cheap, copyable wrapper around the raw driver device handle.
#[derive(Debug, Clone, Copy, Default)]
pub struct CuDevice {
    /// Raw device handle, as filled in by `cuDeviceGet` in [`CuDevice::new`].
    pub device: CUdevice,
}

impl CuDevice {
    /// Create new device with id `ordinal`.
    ///
    /// # Errors
    /// Returns an error if `cuDeviceGet` fails for `ordinal`.
    pub fn new(ordinal: c_int) -> Result<CuDevice, Error> {
        let mut d = CuDevice { device: 0 };

        cuda_call!(cuDeviceGet(&mut d.device as *mut i32, ordinal)).map(|_| d)
    }

    /// Get the value of attribute `attr` of the device.
    ///
    /// # Errors
    /// Returns an error if `cuDeviceGetAttribute` fails.
    pub fn get_attribute(&self, attr: CUdevice_attribute) -> Result<c_int, Error> {
        let mut pi = 0;

        cuda_call!(cuDeviceGetAttribute(&mut pi as *mut i32, attr, self.device)).map(|_| pi)
    }

    /// Get name of the device.
    ///
    /// # Errors
    /// Returns an error if `cuDeviceGetName` fails.
    pub fn get_name(&self) -> Result<String, Error> {
        /// Size in bytes of the buffer offered to the driver for the device name.
        const NAME_CAPACITY: usize = 256;

        // `cuDeviceGetName` writes into caller-provided storage, so it must be given a
        // real buffer (a null pointer gives it nowhere to write). The buffer is
        // zero-initialized, so it holds a valid empty C string even before the call.
        let mut buf = [0 as c_char; NAME_CAPACITY];
        let name = buf.as_mut_ptr();

        cuda_call!(
            cuDeviceGetName(name, NAME_CAPACITY as c_int, self.device),
            from_char_array(name)
        )
    }

    /// Get total memory of the device, in bytes.
    ///
    /// # Errors
    /// Returns an error if `cuDeviceTotalMem_v2` fails.
    pub fn get_total_mem(&self) -> Result<usize, Error> {
        let mut val = 0;

        cuda_call!(cuDeviceTotalMem_v2(
            &mut val as *mut usize as *mut _,
            self.device
        ))
        .map(|_| val)
    }
}