triton_rs/
backend.rs

1use super::Error;
2
3pub trait Backend {
4    /// Initialize a backend. This function is optional, a backend is not
5    /// required to implement it. This function is called once when a
6    /// backend is loaded to allow the backend to initialize any state
7    /// associated with the backend. A backend has a single state that is
8    /// shared across all models that use the backend.
9    ///
10    /// Corresponds to TRITONBACKEND_Initialize.
11    fn initialize() -> Result<(), Error> {
12        Ok(())
13    }
14
15    /// Finalize for a backend. This function is optional, a backend is
16    /// not required to implement it. This function is called once, just
17    /// before the backend is unloaded. All state associated with the
18    /// backend should be freed and any threads created for the backend
19    /// should be exited/joined before returning from this function.
20    /// Corresponds to TRITONBACKEND_Finalize.
21    fn finalize() -> Result<(), Error> {
22        Ok(())
23    }
24
25    /// Initialize for a model instance. This function is optional, a
26    /// backend is not required to implement it. This function is called
27    /// once when a model instance is created to allow the backend to
28    /// initialize any state associated with the instance.
29    ///
30    /// Corresponds to TRITONBACKEND_ModelInstanceInitialize.
31    fn model_instance_initialize() -> Result<(), Error> {
32        Ok(())
33    }
34
35    /// Finalize for a model instance. This function is optional, a
36    /// backend is not required to implement it. This function is called
37    /// once for an instance, just before the corresponding model is
38    /// unloaded from Triton. All state associated with the instance
39    /// should be freed and any threads created for the instance should be
40    /// exited/joined before returning from this function.
41    ///
42    /// Corresponds to TRITONBACKEND_ModelInstanceFinalize.
43    fn model_instance_finalize() -> Result<(), Error> {
44        Ok(())
45    }
46
47    /// Execute a batch of one or more requests on a model instance. This
48    /// function is required. Triton will not perform multiple
49    /// simultaneous calls to this function for a given model 'instance';
50    /// however, there may be simultaneous calls for different model
51    /// instances (for the same or different models).
52    ///
53    /// Corresponds to TRITONBACKEND_ModelInstanceExecute.
54    fn model_instance_execute(
55        model: super::Model,
56        requests: &[super::Request],
57    ) -> Result<(), Error>;
58}
59
/// Converts a `Result` into the raw error pointer expected by Triton's C API.
///
/// * `Ok(_)` becomes a null pointer, which Triton treats as success.
/// * `Err(e)` becomes a new `TRITONSERVER_Error` with code
///   `TRITONSERVER_ERROR_INTERNAL` whose message is `e.to_string()`, so the
///   error type must implement `ToString` (e.g. via `Display`).
///
/// Interior NUL bytes in the message are replaced by the two characters `\0`.
/// This keeps the C-string conversion from panicking, which matters because
/// the macro is expanded inside `extern "C"` functions.
///
/// Every path is fully qualified (`::std::…`, `$crate::…`), so the expansion
/// does not depend on what the call site has imported or on how the crate is
/// named in the caller's `Cargo.toml`.
#[macro_export]
macro_rules! call_checked {
    ($res:expr) => {
        match $res {
            Err(err) => {
                let msg = err.to_string().replace('\0', "\\0");
                let msg = ::std::ffi::CString::new(msg)
                    .expect("interior NUL bytes were replaced above");
                // SAFETY: `msg` is a valid NUL-terminated string that stays alive
                // for the duration of the call. It is dropped right afterwards,
                // so this relies on Triton copying the message (as the original
                // code also did) — confirm against the Triton API docs.
                unsafe {
                    $crate::sys::TRITONSERVER_ErrorNew(
                        $crate::sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_INTERNAL,
                        msg.as_ptr(),
                    )
                }
            }
            Ok(_) => ::std::ptr::null(),
        }
    };
}
77
/// Generates the `extern "C"` entry points that Triton looks up in a backend
/// shared library, forwarding each one to the matching associated function of
/// `$class` (normally a type implementing the `Backend` trait).
///
/// Each entry point returns null on success. On failure it returns the
/// `TRITONSERVER_Error` produced by `call_checked!`, or the error returned by
/// `TRITONBACKEND_ModelInstanceModel`.
///
/// The functions are invoked as `$class::name()`. If they come from the
/// `Backend` trait, that trait must therefore be in scope where the macro is
/// invoked. All other paths are fully qualified (`::std::…`, `$crate::…`), so
/// the expansion does not depend on the caller's imports.
#[macro_export]
macro_rules! declare_backend {
    ($class:ident) => {
        #[no_mangle]
        extern "C" fn TRITONBACKEND_Initialize(
            _backend: *const $crate::sys::TRITONBACKEND_Backend,
        ) -> *const $crate::sys::TRITONSERVER_Error {
            $crate::call_checked!($class::initialize())
        }

        #[no_mangle]
        extern "C" fn TRITONBACKEND_Finalize(
            _backend: *const $crate::sys::TRITONBACKEND_Backend,
        ) -> *const $crate::sys::TRITONSERVER_Error {
            $crate::call_checked!($class::finalize())
        }

        #[no_mangle]
        extern "C" fn TRITONBACKEND_ModelInstanceInitialize(
            _instance: *mut $crate::sys::TRITONBACKEND_ModelInstance,
        ) -> *const $crate::sys::TRITONSERVER_Error {
            $crate::call_checked!($class::model_instance_initialize())
        }

        #[no_mangle]
        extern "C" fn TRITONBACKEND_ModelInstanceFinalize(
            _instance: *const $crate::sys::TRITONBACKEND_ModelInstance,
        ) -> *const $crate::sys::TRITONSERVER_Error {
            $crate::call_checked!($class::model_instance_finalize())
        }

        #[no_mangle]
        extern "C" fn TRITONBACKEND_ModelInstanceExecute(
            instance: *mut $crate::sys::TRITONBACKEND_ModelInstance,
            requests: *const *mut $crate::sys::TRITONBACKEND_Request,
            request_count: u32,
        ) -> *const $crate::sys::TRITONSERVER_Error {
            // Look up the model that owns this instance; propagate Triton's error as-is.
            let mut model_ptr: *mut $crate::sys::TRITONBACKEND_Model = ::std::ptr::null_mut();
            // SAFETY: `instance` is the pointer Triton passed to this entry point,
            // and `model_ptr` is a valid out-parameter for the duration of the call.
            let err = unsafe {
                $crate::sys::TRITONBACKEND_ModelInstanceModel(instance, &mut model_ptr)
            };
            if !err.is_null() {
                return err;
            }
            let model = $crate::Model::from_ptr(model_ptr);

            // `slice::from_raw_parts` requires a non-null pointer even for an empty
            // slice, so an empty or null request array is mapped to `&[]` instead.
            let request_ptrs: &[*mut $crate::sys::TRITONBACKEND_Request] =
                if requests.is_null() || request_count == 0 {
                    &[]
                } else {
                    // SAFETY: Triton passes `request_count` valid request pointers
                    // in `requests`, and they outlive this call.
                    unsafe { ::std::slice::from_raw_parts(requests, request_count as usize) }
                };
            let requests: ::std::vec::Vec<$crate::Request> = request_ptrs
                .iter()
                .map(|&req| $crate::Request::from_ptr(req))
                .collect();

            $crate::call_checked!($class::model_instance_execute(model, &requests))
        }
    };
}