tritonserver_rs/
error.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
use std::{
    error::Error as ErrorExt,
    ffi::{CStr, CString},
    fmt, io,
    mem::transmute,
};

use crate::sys;

/// Placeholder text returned instead of a C string whose bytes are not valid UTF-8
/// (used by [`Error::name`] and [`Error::message`]).
pub(crate) const CSTR_CONVERT_ERROR_PLUG: &str = "INVALID UTF-8 STRING";

/// Triton server error codes.
///
/// Each variant's discriminant is the matching `TRITONSERVER_ERROR_*` value from the
/// Triton C API, so `code as u32` yields the raw code expected by Triton.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ErrorCode {
    /// `TRITONSERVER_ERROR_UNKNOWN`.
    Unknown = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_UNKNOWN,
    /// `TRITONSERVER_ERROR_INTERNAL`.
    Internal = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_INTERNAL,
    /// `TRITONSERVER_ERROR_NOT_FOUND`.
    NotFound = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_NOT_FOUND,
    /// `TRITONSERVER_ERROR_INVALID_ARG`.
    InvalidArg = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_INVALID_ARG,
    /// `TRITONSERVER_ERROR_UNAVAILABLE`.
    Unavailable = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_UNAVAILABLE,
    /// `TRITONSERVER_ERROR_UNSUPPORTED`.
    Unsupported = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_UNSUPPORTED,
    /// `TRITONSERVER_ERROR_ALREADY_EXISTS`.
    ///
    /// This variant name is misspelled; prefer the [`ErrorCode::AlreadyExists`] alias.
    /// It is kept as-is so existing code (including exhaustive `match`es) keeps compiling.
    Alreadyxists = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_ALREADY_EXISTS,
}

impl ErrorCode {
    /// Correctly spelled alias for [`ErrorCode::Alreadyxists`].
    ///
    /// Usable both as a value and as a `match` pattern (the enum derives `PartialEq`/`Eq`).
    #[allow(non_upper_case_globals)] // named like a variant on purpose: it stands in for one
    pub const AlreadyExists: Self = Self::Alreadyxists;
}

/// Triton server error.
///
/// Wraps a raw `TRITONSERVER_Error` pointer. When `owned` is `true` and the pointer is
/// non-null, the error object is released with `TRITONSERVER_ErrorDelete` on drop;
/// otherwise it is left alone.
pub struct Error {
    /// Raw Triton error object.
    pub(crate) ptr: *mut sys::TRITONSERVER_Error,
    /// Whether this wrapper is responsible for deleting `ptr` when dropped.
    pub(crate) owned: bool,
}

// SAFETY: `ptr` and `owned` are `pub(crate)`, so code outside this crate cannot alias or
// mutate them, and none of the methods in this file modify them after construction.
// Only the wrapper with `owned == true` deletes the pointer (in `Drop`), so moving the
// wrapper to another thread moves that responsibility with it.
// NOTE(review): `Sync` additionally assumes that the read-only Triton accessors used by
// `code`/`name`/`message` may be called concurrently on one error object — confirm.
unsafe impl Send for Error {}
unsafe impl Sync for Error {}

impl Error {
    /// Create a new custom error with the given `code` and `message`.
    ///
    /// A C string cannot contain interior NUL bytes, so any NUL bytes in `message` are
    /// removed rather than causing a panic.
    ///
    /// # Panics
    /// Panics if Triton returns a null error object.
    pub fn new<S: AsRef<str>>(code: ErrorCode, message: S) -> Self {
        let message = CString::new(message.as_ref()).unwrap_or_else(|nul_error| {
            // Reporting an error must not itself panic on user-provided text.
            let mut bytes = nul_error.into_vec();
            bytes.retain(|&byte| byte != 0);
            CString::new(bytes).expect("all NUL bytes were removed above")
        });
        // SAFETY: `message` is a valid NUL-terminated string for the whole call.
        // NOTE(review): `message` is dropped when this function returns, so this relies on
        // `TRITONSERVER_ErrorNew` copying the text.
        let ptr = unsafe { sys::TRITONSERVER_ErrorNew(code as u32, message.as_ptr()) };
        assert!(!ptr.is_null(), "TRITONSERVER_ErrorNew returned a null error");
        Self::from(ptr)
    }

    /// Return the [`ErrorCode`] of the error.
    ///
    /// Raw codes this crate does not know about (e.g. ones added by a newer Triton
    /// release) are reported as [`ErrorCode::Unknown`]. They are never transmuted into an
    /// invalid enum value, which would be undefined behaviour.
    pub fn code(&self) -> ErrorCode {
        // SAFETY: `self.ptr` is the Triton error object this wrapper was built from.
        let raw = unsafe { sys::TRITONSERVER_ErrorCode(self.ptr) };
        match raw {
            sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_INTERNAL => ErrorCode::Internal,
            sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_NOT_FOUND => ErrorCode::NotFound,
            sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_INVALID_ARG => {
                ErrorCode::InvalidArg
            }
            sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_UNAVAILABLE => {
                ErrorCode::Unavailable
            }
            sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_UNSUPPORTED => {
                ErrorCode::Unsupported
            }
            sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_ALREADY_EXISTS => {
                ErrorCode::Alreadyxists
            }
            // Covers TRITONSERVER_ERROR_UNKNOWN as well as any unrecognised raw value.
            _ => ErrorCode::Unknown,
        }
    }

    /// Return the string representation of the error's code.
    ///
    /// Returns `"NULL"` if Triton yields a null pointer and [`CSTR_CONVERT_ERROR_PLUG`] if
    /// the string is not valid UTF-8.
    pub fn name(&self) -> &str {
        // SAFETY: the returned string belongs to `self.ptr`, which outlives `&self`.
        unsafe { Self::borrow_c_str(sys::TRITONSERVER_ErrorCodeString(self.ptr)) }
    }

    /// Return the error description.
    ///
    /// Returns `"NULL"` if Triton yields a null pointer and [`CSTR_CONVERT_ERROR_PLUG`] if
    /// the string is not valid UTF-8.
    pub fn message(&self) -> &str {
        // SAFETY: the returned string belongs to `self.ptr`, which outlives `&self`.
        unsafe { Self::borrow_c_str(sys::TRITONSERVER_ErrorMessage(self.ptr)) }
    }

    /// Borrow a C string returned by Triton as `&str`.
    ///
    /// Yields `"NULL"` for a null pointer and [`CSTR_CONVERT_ERROR_PLUG`] for bytes that
    /// are not valid UTF-8.
    ///
    /// # Safety
    /// `ptr` must be null or point to a NUL-terminated string that stays valid and
    /// unmodified for the whole lifetime `'a`.
    unsafe fn borrow_c_str<'a>(ptr: *const std::os::raw::c_char) -> &'a str {
        if ptr.is_null() {
            return "NULL";
        }
        // SAFETY: non-null and valid for `'a` per this function's contract.
        unsafe { CStr::from_ptr(ptr) }
            .to_str()
            .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
    }

    /// Build an `InvalidArg` error for a memory type that needs the `gpu` feature.
    #[cfg(not(feature = "gpu"))]
    pub(crate) fn wrong_type(mem_type: crate::memory::MemoryType) -> Self {
        Self::new(
            ErrorCode::InvalidArg,
            format!("Got {mem_type:?} with gpu feature disabled"),
        )
    }
}

impl fmt::Debug for Error {
    /// Debug output is intentionally the same `<code name>: <message>` text as `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}

impl fmt::Display for Error {
    /// Renders the error as `<code name>: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code_name = self.name();
        let description = self.message();
        f.write_str(code_name)?;
        f.write_str(": ")?;
        f.write_str(description)
    }
}

impl From<*mut sys::TRITONSERVER_Error> for Error {
    /// Wrap a raw Triton error pointer, taking ownership of it: the wrapper will delete
    /// it on drop unless it is null.
    fn from(raw: *mut sys::TRITONSERVER_Error) -> Self {
        Self {
            owned: true,
            ptr: raw,
        }
    }
}

/// Makes [`Error`] usable as a `std::error::Error` (imported here as `ErrorExt`); every
/// trait method keeps its default implementation.
impl ErrorExt for Error {}

impl Drop for Error {
    /// Delete the underlying Triton error object, but only if this wrapper owns a
    /// non-null pointer.
    fn drop(&mut self) {
        if !self.owned || self.ptr.is_null() {
            return;
        }
        // SAFETY: the pointer is non-null and owned by this wrapper, and it is released
        // exactly once because `drop` runs once.
        unsafe { sys::TRITONSERVER_ErrorDelete(self.ptr) };
    }
}

impl From<Error> for io::Error {
    /// Convert into an [`io::ErrorKind::Other`] I/O error.
    ///
    /// The Triton error itself is stored as the inner error, not flattened into a
    /// `String`. The resulting `io::Error` still displays `<code name>: <message>`.
    /// Callers can also recover the original error, and hence its [`ErrorCode`], via
    /// `io::Error::get_ref`/`into_inner` plus a downcast to [`Error`].
    fn from(err: Error) -> Self {
        io::Error::new(io::ErrorKind::Other, err)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created error reports back the code and message it was built with.
    #[test]
    fn create() {
        let expected_code = ErrorCode::Unknown;
        let expected_message = "some error";

        let error = Error::new(expected_code, expected_message);

        assert_eq!(error.code(), expected_code, "error code round-trip");
        assert_eq!(error.message(), expected_message, "error message round-trip");
    }
}