tritonserver_rs/request/
infer.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
use std::{collections::HashMap, ffi::c_void, mem::forget, ptr::null_mut, sync::Arc};

use log::trace;
use tokio::{
    runtime::Handle,
    sync::oneshot::{self, Receiver},
};

use crate::{
    allocator::Allocator,
    error::{Error, ErrorCode},
    memory::Buffer,
    sys, Request, Response,
};

/// Inference result error. Contains output buffers that was allocated by user provided Allocator during the inference.
#[derive(Debug)]
pub struct InferenceError {
    pub error: Error,
    pub output_buffers: HashMap<String, Buffer>,
}

impl From<Error> for InferenceError {
    fn from(error: Error) -> Self {
        Self {
            error,
            output_buffers: HashMap::new(),
        }
    }
}

impl std::fmt::Display for InferenceError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.error.fmt(f)
    }
}

impl std::error::Error for InferenceError {}

/// Future that returns the inference response.
///
/// Also the input buffers assigned to the request can be returned via [get_input_release](ResponseFuture::get_input_release).
pub struct ResponseFuture {
    pub(super) response_receiver: Receiver<Result<Response, InferenceError>>,
    pub(super) input_release: Option<InputRelease>,
}

/// Struct that returns input buffers assigned to the request. \
/// Note: input buffer can be released in any time from the start of the inference
/// to the end of it.
///
/// Input buffers will be dropped if no one will await on this struct.
pub struct InputRelease(pub(super) oneshot::Receiver<HashMap<String, Buffer>>);

/// Start inference.
impl<'a> Request<'a> {
    /// Perform inference using the metadata and inputs supplied by the Request(self). \
    /// If the function returns success,
    /// the returned struct can be used to get results (.await) of the inference and
    /// to return input buffers after the inference start [ResponseFuture::get_input_release]. \
    /// Note: output buffer will be returned with [Response] or [InferenceError]. \
    pub fn infer_async(mut self) -> Result<ResponseFuture, Error> {
        // Check on all buffers are set.
        if self.input.is_empty() {
            return Err(Error::new(
                ErrorCode::NotFound,
                "Request's output buffer is not set",
            ));
        }
        if self.custom_allocator.is_none() {
            return Err(Error::new(
                ErrorCode::NotFound,
                "Request's output buffers allocator is not set",
            ));
        }
        let custom_allocator = self.custom_allocator.take().unwrap();
        let trace = self.custom_trace.take();

        // Add outputs.
        self.add_outputs()?;
        let outputs_count = self.server.get_model(&self.model_name)?.outputs.len();

        let runtime = self.server.runtime.clone();
        let request_ptr = self.ptr;
        let server_ptr = self.server.ptr.as_mut_ptr();

        // Канал, по которому мы вернем input buffer пользователю.
        let (input_tx, input_rx) = oneshot::channel();
        // На всякий случай сохраним указатель, в случае ошибки sys::TRITONSERVER_InferenceRequestSetReleaseCallback
        // разыменуем его и правильно дропнем Request.
        let boxed_request_input_recover = Box::into_raw(Box::new((self, input_tx)));
        let drop_boxed_request = |boxed_request: *mut (Request, _)| {
            let (_restored_request, _) = unsafe { *Box::from_raw(boxed_request) };
        };

        // Здесь мы отдаем Request, он нам вернется в методе release_callback.
        // Там же будет возвращен input_buffer.
        let err = unsafe {
            sys::TRITONSERVER_InferenceRequestSetReleaseCallback(
                request_ptr,
                Some(release_callback),
                boxed_request_input_recover as *mut _,
            )
        };

        if !err.is_null() {
            drop_boxed_request(boxed_request_input_recover);

            let err = Error {
                ptr: err,
                owned: true,
            };
            return Err(err);
        }

        // Allocator отправляется в alloc -> release, там он выдает запрашиваемые тритоном буферы в alloc и шлет их обратно в release.
        // Так как Allocator используется тритоном в методе release, который вызывается после удаления Response,
        // необходимо отправить алокатор в response_wrapper -> Response, чтобы Arc не дропнулся раньше времени.
        // Имена буферов отправляется в response_wrapper, на нем будем ждать возвращенные буферы для Response.
        let allocator = Arc::new(Allocator::new(custom_allocator, runtime.clone())?);

        let allocator_ptr = Arc::as_ptr(&allocator);
        // response_tx отправляется в response_wrapper,
        // когда там сконструируется Response, он будет положен в tx.
        // response_rx отправляется юзеру внутри ResponseFuture, он на нем await-ится.
        let (response_tx, response_rx) = oneshot::channel();

        triton_call!(sys::TRITONSERVER_InferenceRequestSetResponseCallback(
            request_ptr,
            allocator.get_allocator(),
            allocator_ptr as *mut c_void,
            Some(responce_wrapper),
            Box::into_raw(Box::new(ResponseCallbackItems {
                response_tx,
                allocator,
                outputs_count,
                runtime,
            })) as *mut _,
        ))?;

        let trace_ptr = trace
            .as_ref()
            .map(|trace| trace.ptr)
            .unwrap_or_else(null_mut);

        triton_call!(sys::TRITONSERVER_ServerInferAsync(
            server_ptr,
            request_ptr,
            trace_ptr
        ))?;

        // Do not drop the trace, it drops in trace::release_fn
        let _ = trace.map(forget);

        Ok(ResponseFuture {
            response_receiver: response_rx,
            input_release: Some(InputRelease(input_rx)),
        })
    }
}

struct ResponseCallbackItems {
    response_tx: oneshot::Sender<Result<Response, InferenceError>>,
    allocator: Arc<Allocator>,
    outputs_count: usize,
    runtime: Handle,
}

/// C-code returns the ownership on Request using this method.
unsafe extern "C" fn release_callback(
    ptr: *mut sys::TRITONSERVER_InferenceRequest,
    _flags: u32,
    user_data: *mut c_void,
) {
    trace!("release_callback is called");
    assert!(!ptr.is_null());
    assert!(!user_data.is_null());

    let (mut request, input_tx) = *Box::from_raw(user_data as *mut (Request, oneshot::Sender<_>));
    // Drain the input buffers
    let mut buffers = HashMap::new();
    std::mem::swap(&mut buffers, &mut request.input);

    if input_tx.send(buffers).is_err() {
        log::debug!("InputRelease was dropped before the input buffers returned from triton. Input buffers will be dropped");
    }

    assert_eq!(request.ptr, ptr);
    trace!("release_callback is ended");
}

/// C-code calls this method when Response is ready.
unsafe extern "C" fn responce_wrapper(
    response: *mut sys::TRITONSERVER_InferenceResponse,
    _flags: u32,
    user_data: *mut c_void,
) {
    trace!("response wrapper is called");
    assert!(!response.is_null());
    assert!(!user_data.is_null());

    // Allocator присылали сюда только для того, чтобы он не дропнулся во время реквеста.
    let ResponseCallbackItems {
        response_tx,
        allocator,
        outputs_count,
        runtime,
    } = *Box::from_raw(user_data as *mut ResponseCallbackItems);

    let send_res = response_tx.send(Response::new(
        response,
        outputs_count as u32,
        allocator,
        runtime,
    ));
    if send_res.is_err() {
        log::error!("error sending the result of the inference. It will be lost (including the output buffer)")
    } else {
        trace!("response wrapper: result is sent to oneshot");
    }
}