tritonserver_rs/request.rs

pub(crate) mod infer;
mod utils;
pub use crate::trace::Trace;
pub use infer::{InferenceError, InputRelease, ResponseFuture};

use std::{collections::HashMap, mem::transmute, os::raw::c_char, ptr::null, time::Duration};

use crate::{
    error::ErrorCode,
    from_char_array,
    memory::{Buffer, DataType, MemoryType},
    message::Shape,
    parameter::{Parameter, ParameterContent},
    run_in_context,
    sys::{
        self, TRITONSERVER_InferenceRequestRemoveAllInputData,
        TRITONSERVER_InferenceRequestRemoveAllInputs, TRITONSERVER_InferenceRequestRemoveInput,
    },
    to_cstring, Error, Server,
};

/// Inference request sequence flag.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum Sequence {
    Start = sys::tritonserver_requestflag_enum_TRITONSERVER_REQUEST_FLAG_SEQUENCE_START,
    End = sys::tritonserver_requestflag_enum_TRITONSERVER_REQUEST_FLAG_SEQUENCE_END,
}

/// Allocator that the user provides in order to allocate output buffers when Triton needs them. \
/// [Allocator::allocate] is invoked after the [Request::infer_async] call, once for each of the model's outputs.
/// The name of the requested output, its memory type,
/// byte size, and data type are passed as arguments.
///
/// The allocator should be able to allocate a buffer for every model output.
///
/// **Note**: Allocated buffers can be returned via the [crate::Response::return_buffers] method. \
/// **Note**: the allocate() method may not be invoked at all, for example if a model error occurs before the output is needed.
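///
/// # Example
///
/// A minimal sketch of a custom allocator (the name `MyAllocator` is illustrative, not part of the crate).
/// It simply delegates to [Buffer::alloc_with_data_type], which is also what [DefaultAllocator] does:
///
/// ```ignore
/// struct MyAllocator;
///
/// #[async_trait::async_trait]
/// impl Allocator for MyAllocator {
///     async fn allocate(
///         &mut self,
///         _tensor_name: String,
///         requested_memory_type: MemoryType,
///         byte_size: usize,
///         data_type: DataType,
///     ) -> Result<Buffer, Error> {
///         // Convert the requested byte size into a number of `data_type` elements
///         // and allocate on the memory type Triton asked for.
///         let elements = (byte_size as f32 / data_type.size() as f32).ceil() as usize;
///         Buffer::alloc_with_data_type(elements, requested_memory_type, data_type)
///     }
/// }
/// ```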
#[async_trait::async_trait]
pub trait Allocator: Send {
    /// Allocate an output buffer for the output named `tensor_name`.
    ///
    /// **NOTES:**
    /// - It's not necessary to allocate the buffer on the exact `requested_memory_type`: for example,
    ///     it's fine to allocate a Pinned buffer when Triton requested a GPU buffer.
    ///     The only requirement is not to allocate CPU memory when GPU is requested and vice versa.
    /// - A buffer of size greater than or equal to the requested `byte_size` may be allocated, but not a smaller one.
    /// - The allocated buffer's datatype must match the datatype of this output as specified in the model's config.
    /// - The method will be invoked in an asynchronous context.
    async fn allocate(
        &mut self,
        tensor_name: String,
        requested_memory_type: MemoryType,
        byte_size: usize,
        data_type: DataType,
    ) -> Result<Buffer, Error>;

    /// Enable or disable pre-allocation querying. For more info about querying see [Allocator::pre_allocation_query]. \
    /// Default is false.
    fn enable_queries(&self) -> bool {
        false
    }

    /// If [self.enable_queries()](Allocator::enable_queries) is true,
    /// this function will be called to query the allocator's preferred memory type. \
    /// As much as possible, the allocator should attempt to return the same memory_type
    /// value that the subsequent call to [Allocator::allocate] will use,
    /// but the allocator is not required to do so.
    ///
    /// `tensor_name` The name of the output tensor. None indicates that the tensor name has not been determined. \
    /// `byte_size` The expected size of the buffer. None indicates that the byte size has not been determined. \
    /// `requested_memory_type` The memory type preferred by the Triton inference. \
    /// Returns the memory type preferred by the allocator, taking into account the caller's preferred type.
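    ///
    /// A rough sketch of an allocator that always prefers one memory type, shown as the two
    /// overrides inside an `impl Allocator` block (`self.preferred` is a hypothetical field of
    /// the implementing struct, not part of this crate):
    ///
    /// ```ignore
    /// fn enable_queries(&self) -> bool {
    ///     true
    /// }
    ///
    /// async fn pre_allocation_query(
    ///     &mut self,
    ///     _tensor_name: Option<String>,
    ///     _byte_size: Option<usize>,
    ///     _requested_memory_type: MemoryType,
    /// ) -> MemoryType {
    ///     // Ignore Triton's preference and report the memory type this allocator
    ///     // will actually use in `allocate`.
    ///     self.preferred
    /// }
    /// ```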
    #[allow(unused_variables)]
    async fn pre_allocation_query(
        &mut self,
        tensor_name: Option<String>,
        byte_size: Option<usize>,
        requested_memory_type: MemoryType,
    ) -> MemoryType {
        requested_memory_type
    }
}

/// Default allocator.
///
/// Allocates exactly `byte_size` bytes of datatype `data_type` on `requested_memory_type` for each output.
#[derive(Debug, Default, Clone, Copy)]
pub struct DefaultAllocator;

#[async_trait::async_trait]
impl Allocator for DefaultAllocator {
    async fn allocate(
        &mut self,
        _tensor_name: String,
        requested_mem_type: MemoryType,
        byte_size: usize,
        data_type: DataType,
    ) -> Result<Buffer, Error> {
        let data_type_size = data_type.size();
        run_in_context(0, move || {
            Buffer::alloc_with_data_type(
                (byte_size as f32 / data_type_size as f32).ceil() as usize,
                requested_mem_type,
                data_type,
            )
        })
        .await?
    }
}

/// Inference request object. \
/// One can get this item using [Server::create_request].
///
/// Before running inference, input data must be added to this structure via one of the [add_input](Request::add_input) methods,
/// and an allocator via the [Request::add_allocator] or [Request::add_default_allocator] method.
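///
/// # Example
///
/// A rough end-to-end sketch (the model and input names are placeholders, `input` is a
/// prepared [Buffer], and the exact [Server::create_request] and [Request::infer_async]
/// signatures live in the server and `infer` modules):
///
/// ```ignore
/// let mut request = server.create_request("my_model")?;
/// request
///     .add_default_allocator()
///     .add_input("INPUT0", input)?;
/// // Run the inference and await the response (see `ResponseFuture`).
/// let response = request.infer_async()?.await?;
/// ```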
pub struct Request<'a> {
    ptr: *mut sys::TRITONSERVER_InferenceRequest,
    model_name: String,
    input: HashMap<String, Buffer>,
    custom_allocator: Option<Box<dyn Allocator>>,
    custom_trace: Option<Trace>,
    // Make sure the Server is not dropped while the Request is running.
    // Server(Arc<Inner>)
    server: &'a Server,
}

impl<'a> Request<'a> {
    pub(crate) fn new<M: AsRef<str>>(
        ptr: *mut sys::TRITONSERVER_InferenceRequest,
        server: &'a Server,
        model: M,
    ) -> Result<Request<'a>, Error> {
        Ok(Request {
            ptr,
            model_name: model.as_ref().to_string(),
            input: HashMap::new(),
            custom_allocator: None,
            custom_trace: None,
            server,
        })
    }

    /// Add a custom Allocator to the request. \
    /// Check the [Allocator] trait for more info.
    pub fn add_allocator(&mut self, custom_allocator: Box<dyn Allocator>) -> &mut Self {
        let _ = self.custom_allocator.replace(custom_allocator);
        self
    }

    /// Add [DefaultAllocator] to the request. \
    /// Check the [Allocator] trait and [DefaultAllocator] for more info.
    pub fn add_default_allocator(&mut self) -> &mut Self {
        let _ = self.custom_allocator.replace(Box::new(DefaultAllocator));
        self
    }

    /// Add a custom Trace to the request. \
    /// If this method is not invoked, no tracing will be provided. \
    /// Check [Trace] for more info about tracing.
    pub fn add_trace(&mut self, custom_trace: Trace) -> &mut Self {
        let _ = self.custom_trace.replace(custom_trace);
        self
    }

    /// Get the ID of the request.
    pub fn get_id(&self) -> Result<String, Error> {
        let mut id = null::<c_char>();
        triton_call!(
            sys::TRITONSERVER_InferenceRequestId(self.ptr, &mut id as *mut _),
            from_char_array(id)
        )
    }

    /// Set the ID of the request.
    pub fn set_id<I: AsRef<str>>(&mut self, id: I) -> Result<&mut Self, Error> {
        let id = to_cstring(id)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetId(self.ptr, id.as_ptr()),
            self
        )
    }

    /// Get the flag(s) associated with the request. \
    /// Check [Sequence] for available flags.
    pub fn get_flags(&self) -> Result<Sequence, Error> {
        let mut flag: u32 = 0;
        triton_call!(sys::TRITONSERVER_InferenceRequestFlags(
            self.ptr,
            &mut flag as *mut _
        ))?;
        unsafe { Ok(transmute::<u32, Sequence>(flag)) }
    }

    /// Set the flag(s) associated with a request. \
    /// Check [Sequence] for available flags.
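    ///
    /// A rough sketch of marking the first and last requests of a sequence
    /// (the correlation ID value is arbitrary; how the sequence is scheduled
    /// depends on the model's configuration):
    ///
    /// ```ignore
    /// first_request.set_correlation_id(42)?.set_flags(Sequence::Start)?;
    /// // ...intermediate requests carry the same correlation ID, without flags...
    /// last_request.set_correlation_id(42)?.set_flags(Sequence::End)?;
    /// ```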
    pub fn set_flags(&mut self, flags: Sequence) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetFlags(self.ptr, flags as _),
            self
        )
    }

    /// Get the correlation ID of the inference request. \
    /// Default is 0, which indicates that the request has no correlation ID. \
    /// If the correlation id associated with the inference request is a string, this function will return a failure. \
    /// The correlation ID is used to indicate that two or more inference requests are related to each other. \
    /// How this relationship is handled by the inference server is determined by the model's scheduling policy.
    pub fn get_correlation_id(&self) -> Result<u64, Error> {
        let mut id: u64 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestCorrelationId(self.ptr, &mut id as *mut _),
            id
        )
    }

    /// Get the correlation ID of the inference request as a string. \
    /// Default is the empty string "", which indicates that the request has no correlation ID. \
    /// If the correlation id associated with the inference request is an unsigned integer, then this function will return a failure. \
    /// The correlation ID is used to indicate that two or more inference requests are related to each other. \
    /// How this relationship is handled by the inference server is determined by the model's scheduling policy.
    pub fn get_correlation_id_as_string(&self) -> Result<String, Error> {
        let mut id = null::<c_char>();
        triton_call!(
            sys::TRITONSERVER_InferenceRequestCorrelationIdString(self.ptr, &mut id as *mut _),
            from_char_array(id)
        )
    }

    /// Set the correlation ID of the inference request to be an unsigned integer. \
    /// Default is 0, which indicates that the request has no correlation ID. \
    /// The correlation ID is used to indicate that two or more inference requests are related to each other. \
    /// How this relationship is handled by the inference server is determined by the model's scheduling policy.
    pub fn set_correlation_id(&mut self, id: u64) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetCorrelationId(self.ptr, id),
            self
        )
    }

    /// Set the correlation ID of the inference request to be a string. \
    /// The correlation ID is used to indicate that two or more inference requests are related to each other. \
    /// How this relationship is handled by the inference server is determined by the model's scheduling policy.
    pub fn set_correlation_id_as_str<I: AsRef<str>>(&mut self, id: I) -> Result<&mut Self, Error> {
        let id = to_cstring(id)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetCorrelationIdString(self.ptr, id.as_ptr()),
            self
        )
    }

    /// Get the priority of the request. \
    /// The default is 0, indicating that the request does not specify a priority and so will use the model's default priority.
    pub fn get_priority(&self) -> Result<u32, Error> {
        let mut priority: u32 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestPriority(self.ptr, &mut priority as *mut _),
            priority
        )
    }

    /// Set the priority of the request. \
    /// The default is 0, indicating that the request does not specify a priority and so will use the model's default priority.
    pub fn set_priority(&mut self, priority: u32) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetPriority(self.ptr, priority),
            self
        )
    }

    /// Get the timeout of the request. \
    /// The default is 0, which indicates that the request has no timeout.
    pub fn get_timeout(&self) -> Result<Duration, Error> {
        let mut timeout_us: u64 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestTimeoutMicroseconds(
                self.ptr,
                &mut timeout_us as *mut _,
            ),
            Duration::from_micros(timeout_us)
        )
    }

    /// Set the timeout of the request. \
    /// The default is 0, which indicates that the request has no timeout.
    pub fn set_timeout(&mut self, timeout: Duration) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetTimeoutMicroseconds(
                self.ptr,
                timeout.as_micros() as u64,
            ),
            self
        )
    }

    /// Add an input to the request. \
    /// `input_name`: The name of the input. \
    /// `buffer`: buffer containing the input data. \
    /// Note: input data will be returned after the inference. Check [ResponseFuture::get_input_release] for more info.
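    ///
    /// A minimal sketch (the input name is a placeholder; `mem_type` and `dtype` stand for
    /// [MemoryType] and [DataType] values matching the model config, and the `memory` module
    /// is the place to look for ways to create and fill a [Buffer]):
    ///
    /// ```ignore
    /// // Allocate space for 3 * 224 * 224 elements of `dtype`, fill it with the input data,
    /// // then hand it to the request.
    /// let input = Buffer::alloc_with_data_type(3 * 224 * 224, mem_type, dtype)?;
    /// request.add_input("INPUT0", input)?;
    /// ```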
    pub fn add_input<N: AsRef<str>>(
        &mut self,
        input_name: N,
        buffer: Buffer,
    ) -> Result<&mut Self, Error> {
        self.add_input_inner(input_name, buffer, None::<String>, None::<Vec<i64>>)
    }

    /// Add an input with the specified shape to the request. \
    /// `input_name`: The name of the input. \
    /// `buffer`: buffer containing the input data. \
    /// `dims`: Dimensions of the input. \
    /// Note: input data will be returned after the inference. Check [ResponseFuture::get_input_release] for more info.
    pub fn add_input_with_dims<N, D>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        dims: D,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        D: AsRef<[i64]>,
    {
        self.add_input_inner(input_name, buffer, None::<String>, Some(dims))
    }

    /// Add an input with the specified host policy to the request. \
    /// `input_name`: The name of the input. \
    /// `buffer`: buffer containing the input data. \
    /// `policy`: The policy name; all model instances executing with this policy will use this input buffer for execution. \
    /// Note: input data will be returned after the inference. Check [ResponseFuture::get_input_release] for more info.
    pub fn add_input_with_policy<N, P>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        policy: P,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        P: AsRef<str>,
    {
        self.add_input_inner(input_name, buffer, Some(policy), None::<Vec<i64>>)
    }

    /// Add an input with the specified host policy and shape to the request. \
    /// `input_name`: The name of the input. \
    /// `buffer`: buffer containing the input data. \
    /// `policy`: The policy name; all model instances executing with this policy will use this input buffer for execution. \
    /// `dims`: Dimensions of the input. \
    /// Note: input data will be returned after the inference. Check [ResponseFuture::get_input_release] for more info.
    pub fn add_input_with_policy_and_dims<N, P, D>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        policy: P,
        dims: D,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        P: AsRef<str>,
        D: AsRef<[i64]>,
    {
        self.add_input_inner(input_name, buffer, Some(policy), Some(dims))
    }

    fn add_input_inner<N, P, D>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        policy: Option<P>,
        dims: Option<D>,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        P: AsRef<str>,
        D: AsRef<[i64]>,
    {
        if self.input.contains_key(input_name.as_ref()) {
            return Err(Error::new(
                ErrorCode::Alreadyxists,
                format!(
                    "Request already has buffer for input \"{}\"",
                    input_name.as_ref()
                ),
            ));
        }
        let model_shape = self.get_shape(input_name.as_ref())?;

        let shape = if let Some(dims) = dims {
            Shape {
                name: input_name.as_ref().to_string(),
                datatype: model_shape.datatype,
                dims: dims.as_ref().to_vec(),
            }
        } else {
            Shape {
                name: input_name.as_ref().to_string(),
                datatype: model_shape.datatype,
                dims: model_shape.dims.clone(),
            }
        };

        assert_buffer_shape(&shape, &buffer, input_name.as_ref())?;

        self.add_input_triton(&input_name, &shape)?;

        if let Some(policy) = policy {
            self.append_input_data_with_policy(input_name, &policy, buffer)?;
        } else {
            self.append_input_data(input_name, buffer)?;
        }

        Ok(self)
    }

    fn get_shape<N: AsRef<str>>(&self, source: N) -> Result<&Shape, Error> {
        let model_name = &self.model_name;
        let model = self.server.get_model(model_name)?;

        match model
            .inputs
            .iter()
            .find(|shape| shape.name == source.as_ref())
        {
            None => Err(Error::new(
                ErrorCode::Internal,
                format!("Model {model_name} has no input named: {}", source.as_ref()),
            )),
            Some(shape) => Ok(shape),
        }
    }

    fn add_input_triton<I: AsRef<str>>(&self, input_name: I, input: &Shape) -> Result<(), Error> {
        let name = to_cstring(input_name)?;
        triton_call!(sys::TRITONSERVER_InferenceRequestAddInput(
            self.ptr,
            name.as_ptr(),
            input.datatype as u32,
            input.dims.as_ptr(),
            input.dims.len() as u64,
        ))
    }

    fn append_input_data<I: AsRef<str>>(
        &mut self,
        input_name: I,
        buffer: Buffer,
    ) -> Result<&mut Self, Error> {
        let name = to_cstring(&input_name)?;
        triton_call!(sys::TRITONSERVER_InferenceRequestAppendInputData(
            self.ptr,
            name.as_ptr(),
            buffer.ptr,
            buffer.len,
            buffer.memory_type as u32,
            0,
        ))?;

        let _ = self.input.insert(input_name.as_ref().to_string(), buffer);
        Ok(self)
    }

    fn append_input_data_with_policy<I: AsRef<str>, P: AsRef<str>>(
        &mut self,
        input_name: I,
        policy: P,
        buffer: Buffer,
    ) -> Result<&mut Self, Error> {
        let name = to_cstring(&input_name)?;
        let policy = to_cstring(policy)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestAppendInputDataWithHostPolicy(
                self.ptr,
                name.as_ptr(),
                buffer.ptr,
                buffer.len,
                buffer.memory_type as u32,
                0,
                policy.as_ptr(),
            )
        )?;

        self.input.insert(input_name.as_ref().to_string(), buffer);
        Ok(self)
    }

    /// Remove an input from a request. Returns the data previously appended to the input.
    ///
    /// `name`: The name of the input.
    pub fn remove_input<N: AsRef<str>>(&mut self, name: N) -> Result<Buffer, Error> {
        let buffer = self.input.remove(name.as_ref()).ok_or_else(|| {
            Error::new(
                ErrorCode::InvalidArg,
                format!(
                    "Can't find input {} in a request. Appended inputs: {:?}",
                    name.as_ref(),
                    self.input.keys()
                ),
            )
        })?;
        let name = to_cstring(name)?;

        triton_call!(TRITONSERVER_InferenceRequestRemoveAllInputData(
            self.ptr,
            name.as_ptr()
        ))?;
        triton_call!(
            TRITONSERVER_InferenceRequestRemoveInput(self.ptr, name.as_ptr()),
            buffer
        )
    }

    /// Remove all the inputs from a request. Returns the data previously appended to the inputs.
    pub fn remove_all_inputs(&mut self) -> Result<HashMap<String, Buffer>, Error> {
        let mut buffers = HashMap::new();
        std::mem::swap(&mut buffers, &mut self.input);

        triton_call!(
            TRITONSERVER_InferenceRequestRemoveAllInputs(self.ptr),
            buffers
        )
    }

    pub(crate) fn add_outputs(&mut self) -> Result<HashMap<String, DataType>, Error> {
        let model = self.server.get_model(&self.model_name)?;
        let mut datatype_hints = HashMap::new();

        for output in &model.outputs {
            self.add_output(&output.name)?;
            datatype_hints.insert(output.name.clone(), output.datatype);
        }

        Ok(datatype_hints)
    }

    /// Add an output request to an inference request. \
    /// `name`: The name of the output. \
    /// The output data buffer is requested via the Triton allocator,
    /// and the embeddings will be put in that buffer.
    /// One can obtain the buffer back using Response::output() or from the infer_async() Error.
    pub(crate) fn add_output<N: AsRef<str>>(&mut self, name: N) -> Result<&mut Self, Error> {
        let output_name = to_cstring(&name)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestAddRequestedOutput(self.ptr, output_name.as_ptr()),
            self
        )
    }

    /// Set a parameter in the request. Does not support ParameterContent::Bytes.
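    ///
    /// A rough sketch of setting a string parameter (the parameter name and value are
    /// placeholders, and the struct-literal construction of [Parameter] is an assumption;
    /// see the `parameter` module for the actual way to build one):
    ///
    /// ```ignore
    /// let param = Parameter {
    ///     name: "my_parameter".to_string(),
    ///     content: ParameterContent::String("value".to_string()),
    /// };
    /// request.set_parameter(param)?;
    /// ```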
    pub fn set_parameter(&mut self, parameter: Parameter) -> Result<&mut Self, Error> {
        let name = to_cstring(&parameter.name)?;
        match parameter.content.clone() {
            ParameterContent::Bool(value) => triton_call!(
                sys::TRITONSERVER_InferenceRequestSetBoolParameter(self.ptr, name.as_ptr(), value),
                self
            ),
            ParameterContent::Bytes(_) => Err(Error::new(
                ErrorCode::Unsupported,
                "Request::set_parameter does not support ParameterContent::Bytes",
            )),
            ParameterContent::Int(value) => triton_call!(
                sys::TRITONSERVER_InferenceRequestSetIntParameter(self.ptr, name.as_ptr(), value),
                self
            ),
            ParameterContent::Double(value) => {
                triton_call!(
                    sys::TRITONSERVER_InferenceRequestSetDoubleParameter(
                        self.ptr,
                        name.as_ptr(),
                        value
                    ),
                    self
                )
            }
            ParameterContent::String(value) => {
                let value = to_cstring(value)?;
                triton_call!(
                    sys::TRITONSERVER_InferenceRequestSetStringParameter(
                        self.ptr,
                        name.as_ptr(),
                        value.as_ptr()
                    ),
                    self
                )
            }
        }
    }
}

unsafe impl Send for Request<'_> {}

impl Drop for Request<'_> {
    fn drop(&mut self) {
        unsafe {
            sys::TRITONSERVER_InferenceRequestDelete(self.ptr);
        }
    }
}

/// Check that the buffer's datatype matches the shape's datatype and that the buffer is large enough for the shape.
fn assert_buffer_shape<N: AsRef<str>>(
    shape: &Shape,
    buffer: &Buffer,
    source: N,
) -> Result<(), Error> {
    if shape.datatype != buffer.data_type {
        return Err(Error::new(
            ErrorCode::InvalidArg,
            format!(
                "input buffer datatype {:?} mismatches model shape datatype: {:?}. input name: {}",
                buffer.data_type,
                shape.datatype,
                source.as_ref()
            ),
        ));
    }
    // Dynamic dimensions (negative values) or an empty shape cannot be validated against the buffer size.
    let shape_size = if shape.dims.iter().any(|n| *n < 0) || shape.dims.is_empty() {
        0
    } else {
        shape.dims.iter().product::<i64>() as u32 * shape.datatype.size()
    };

    if shape_size as usize > buffer.size() {
        Err(Error::new(
            ErrorCode::InvalidArg,
            format!(
                "Buffer has size: {}, which is less than the shape's minimum size: {shape_size}. input name: {}",
                buffer.size(),
                source.as_ref()
            ),
        ))
    } else {
        Ok(())
    }
}