triton_rs/backend.rs
1use super::Error;
2
3pub trait Backend {
4 /// Initialize a backend. This function is optional, a backend is not
5 /// required to implement it. This function is called once when a
6 /// backend is loaded to allow the backend to initialize any state
7 /// associated with the backend. A backend has a single state that is
8 /// shared across all models that use the backend.
9 ///
10 /// Corresponds to TRITONBACKEND_Initialize.
11 fn initialize() -> Result<(), Error> {
12 Ok(())
13 }
14
15 /// Finalize for a backend. This function is optional, a backend is
16 /// not required to implement it. This function is called once, just
17 /// before the backend is unloaded. All state associated with the
18 /// backend should be freed and any threads created for the backend
19 /// should be exited/joined before returning from this function.
20 /// Corresponds to TRITONBACKEND_Finalize.
21 fn finalize() -> Result<(), Error> {
22 Ok(())
23 }
24
25 /// Initialize for a model instance. This function is optional, a
26 /// backend is not required to implement it. This function is called
27 /// once when a model instance is created to allow the backend to
28 /// initialize any state associated with the instance.
29 ///
30 /// Corresponds to TRITONBACKEND_ModelInstanceInitialize.
31 fn model_instance_initialize() -> Result<(), Error> {
32 Ok(())
33 }
34
35 /// Finalize for a model instance. This function is optional, a
36 /// backend is not required to implement it. This function is called
37 /// once for an instance, just before the corresponding model is
38 /// unloaded from Triton. All state associated with the instance
39 /// should be freed and any threads created for the instance should be
40 /// exited/joined before returning from this function.
41 ///
42 /// Corresponds to TRITONBACKEND_ModelInstanceFinalize.
43 fn model_instance_finalize() -> Result<(), Error> {
44 Ok(())
45 }
46
47 /// Execute a batch of one or more requests on a model instance. This
48 /// function is required. Triton will not perform multiple
49 /// simultaneous calls to this function for a given model 'instance';
50 /// however, there may be simultaneous calls for different model
51 /// instances (for the same or different models).
52 ///
53 /// Corresponds to TRITONBACKEND_ModelInstanceExecute.
54 fn model_instance_execute(
55 model: super::Model,
56 requests: &[super::Request],
57 ) -> Result<(), Error>;
58}
59
/// Converts a `Result` into the raw error pointer a Triton backend entry
/// point has to return.
///
/// * `Ok(_)` expands to a null pointer; the success value is discarded.
/// * `Err(e)` expands to a freshly created `TRITONSERVER_Error` carrying the
///   code `TRITONSERVER_ERROR_INTERNAL` and `e.to_string()` as its message.
///
/// The expansion always has the type `*const sys::TRITONSERVER_Error`.
///
/// The macro is self-contained: it refers to the standard library and to this
/// crate through absolute paths (`::std`, `$crate`), so the call site does not
/// need any `use` declarations.
///
/// Interior NUL bytes in the error message are dropped instead of causing a
/// panic, because this macro is expanded inside `extern "C"` functions.
#[macro_export]
macro_rules! call_checked {
    ($res:expr) => {
        match $res {
            Ok(_) => ::std::ptr::null::<$crate::sys::TRITONSERVER_Error>(),
            Err(err) => {
                // A C string cannot contain NUL bytes. Strip them so that the
                // conversion below cannot fail on an unusual error message.
                let message: ::std::vec::Vec<u8> = err
                    .to_string()
                    .into_bytes()
                    .into_iter()
                    .filter(|byte| *byte != 0)
                    .collect();
                // Cannot fail after the filtering above; the empty-string
                // fallback only exists to avoid a panic path.
                let message = ::std::ffi::CString::new(message).unwrap_or_default();
                // SAFETY: `message` is a valid NUL-terminated string that stays
                // alive for the whole call.
                // NOTE(review): assumes TRITONSERVER_ErrorNew copies the
                // message, since `message` is dropped at the end of this arm.
                unsafe {
                    $crate::sys::TRITONSERVER_ErrorNew(
                        $crate::sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_INTERNAL,
                        message.as_ptr(),
                    ) as *const $crate::sys::TRITONSERVER_Error
                }
            }
        }
    };
}
77
/// Generates the C entry points Triton looks up in a backend shared library
/// and forwards each of them to the matching associated function of `$class`.
///
/// The following symbols are emitted with `#[no_mangle]`:
///
/// | symbol                                  | forwards to                           |
/// |-----------------------------------------|---------------------------------------|
/// | `TRITONBACKEND_Initialize`              | `$class::initialize()`                |
/// | `TRITONBACKEND_Finalize`                | `$class::finalize()`                  |
/// | `TRITONBACKEND_ModelInstanceInitialize` | `$class::model_instance_initialize()` |
/// | `TRITONBACKEND_ModelInstanceFinalize`   | `$class::model_instance_finalize()`   |
/// | `TRITONBACKEND_ModelInstanceExecute`    | `$class::model_instance_execute(..)`  |
///
/// Each `Result` is translated into a Triton error pointer by
/// `call_checked!`. The functions above must be resolvable on `$class` at the
/// call site (for trait methods, the trait has to be in scope there).
///
/// The expansion only uses absolute paths (`::std`, `$crate`), so the call
/// site does not need to import `ptr`, `slice` or name this crate `triton_rs`.
#[macro_export]
macro_rules! declare_backend {
    ($class:ident) => {
        #[no_mangle]
        extern "C" fn TRITONBACKEND_Initialize(
            _backend: *const $crate::sys::TRITONBACKEND_Backend,
        ) -> *const $crate::sys::TRITONSERVER_Error {
            $crate::call_checked!($class::initialize())
        }

        #[no_mangle]
        extern "C" fn TRITONBACKEND_Finalize(
            _backend: *const $crate::sys::TRITONBACKEND_Backend,
        ) -> *const $crate::sys::TRITONSERVER_Error {
            $crate::call_checked!($class::finalize())
        }

        #[no_mangle]
        extern "C" fn TRITONBACKEND_ModelInstanceInitialize(
            _instance: *mut $crate::sys::TRITONBACKEND_ModelInstance,
        ) -> *const $crate::sys::TRITONSERVER_Error {
            $crate::call_checked!($class::model_instance_initialize())
        }

        #[no_mangle]
        extern "C" fn TRITONBACKEND_ModelInstanceFinalize(
            _instance: *const $crate::sys::TRITONBACKEND_ModelInstance,
        ) -> *const $crate::sys::TRITONSERVER_Error {
            $crate::call_checked!($class::model_instance_finalize())
        }

        #[no_mangle]
        extern "C" fn TRITONBACKEND_ModelInstanceExecute(
            instance: *mut $crate::sys::TRITONBACKEND_ModelInstance,
            requests: *const *mut $crate::sys::TRITONBACKEND_Request,
            request_count: u32,
        ) -> *const $crate::sys::TRITONSERVER_Error {
            // Look up the model this instance belongs to; a non-null result
            // is an error that is handed straight back to Triton.
            let mut model: *mut $crate::sys::TRITONBACKEND_Model = ::std::ptr::null_mut();
            // SAFETY: `&mut model` is a valid out-pointer to a local.
            // NOTE(review): assumes `instance` is the valid pointer Triton
            // passed to this entry point.
            let err =
                unsafe { $crate::sys::TRITONBACKEND_ModelInstanceModel(instance, &mut model) };
            if !err.is_null() {
                return err;
            }

            let model = $crate::Model::from_ptr(model);

            // `slice::from_raw_parts` requires a non-null pointer even for a
            // length of zero, so an empty or null batch is mapped to an empty
            // slice without touching the raw pointer.
            let raw_requests: &[*mut $crate::sys::TRITONBACKEND_Request] =
                if requests.is_null() || request_count == 0 {
                    &[]
                } else {
                    // SAFETY: `requests` is non-null (checked above).
                    // NOTE(review): assumes it points to `request_count`
                    // readable elements for the duration of this call.
                    unsafe { ::std::slice::from_raw_parts(requests, request_count as usize) }
                };
            let requests = raw_requests
                .iter()
                .map(|req| $crate::Request::from_ptr(*req))
                .collect::<::std::vec::Vec<$crate::Request>>();

            $crate::call_checked!($class::model_instance_execute(model, &requests))
        }
    };
}