tritonserver_rs/request.rs

pub(crate) mod infer;
mod utils;
pub use crate::trace::Trace;
pub use infer::{InferenceError, InputRelease, ResponseFuture};

use std::{collections::HashMap, mem::transmute, os::raw::c_char, ptr::null, time::Duration};

use crate::{
    error::ErrorCode,
    from_char_array,
    memory::{Buffer, DataType, MemoryType},
    message::Shape,
    parameter::{Parameter, ParameterContent},
    run_in_context,
    sys::{
        self, TRITONSERVER_InferenceRequestRemoveAllInputData,
        TRITONSERVER_InferenceRequestRemoveAllInputs, TRITONSERVER_InferenceRequestRemoveInput,
    },
    to_cstring, Error, Server,
};

#[deprecated(
    since = "0.4.1",
    note = "Can't be used with bit operations as flags. Use `SequenceFlag` instead"
)]
/// Inference request sequence flag.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum Sequence {
    Start = sys::tritonserver_requestflag_enum_TRITONSERVER_REQUEST_FLAG_SEQUENCE_START,
    End = sys::tritonserver_requestflag_enum_TRITONSERVER_REQUEST_FLAG_SEQUENCE_END,
}

bitflags::bitflags! {
    /// Inference request sequence flag.
    pub struct SequenceFlag: u32 {
        const START = sys::tritonserver_requestflag_enum_TRITONSERVER_REQUEST_FLAG_SEQUENCE_START;
        const END = sys::tritonserver_requestflag_enum_TRITONSERVER_REQUEST_FLAG_SEQUENCE_END;
    }
}

/// Allocator that the user provides in order to allocate output buffers when Triton needs them. \
/// [Allocator::allocate] will be invoked after the [Request::infer_async] call, once for each of the model's outputs.
/// The name of the requested output, its memory type,
/// byte size and data type are passed as arguments.
///
/// The allocator should be able to allocate a buffer for each of the model's outputs.
///
/// **Note**: Allocated buffers can be returned via the [crate::Response::return_buffers] method. \
/// **Note**: The allocate() method may not be invoked at all, for example, if a model error occurs before the output is needed.
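///
/// A minimal sketch of a custom allocator (the `PinnedAllocator` name is illustrative; it mirrors
/// [DefaultAllocator] but always returns pinned host memory, assuming a `MemoryType::Pinned` variant):
///
/// ```ignore
/// struct PinnedAllocator;
///
/// #[async_trait::async_trait]
/// impl Allocator for PinnedAllocator {
///     async fn allocate(
///         &mut self,
///         _tensor_name: String,
///         _requested_memory_type: MemoryType,
///         byte_size: usize,
///         data_type: DataType,
///     ) -> Result<Buffer, Error> {
///         // Pinned host memory is acceptable even when a GPU buffer was requested
///         // (see the notes on `allocate`).
///         let count = (byte_size as f32 / data_type.size() as f32).ceil() as usize;
///         run_in_context(0, move || {
///             Buffer::alloc_with_data_type(count, MemoryType::Pinned, data_type)
///         })
///         .await?
///     }
/// }
/// ```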
#[async_trait::async_trait]
pub trait Allocator: Send {
    /// Allocate an output buffer for the output named `tensor_name`.
    ///
    /// **NOTES:**
    /// - It's not necessary to allocate the buffer with the exact `requested_memory_type`: for example,
    ///     it's fine to allocate a Pinned buffer when Triton requested a GPU buffer.
    ///     The only requirement is not to allocate CPU memory when GPU is requested and vice versa.
    /// - A buffer of size greater than or equal to the requested `byte_size` may be allocated, but not a smaller one.
    /// - The allocated buffer's datatype must match the output's datatype specified in the model's config.
    /// - The method will be invoked in an asynchronous context.
    async fn allocate(
        &mut self,
        tensor_name: String,
        requested_memory_type: MemoryType,
        byte_size: usize,
        data_type: DataType,
    ) -> Result<Buffer, Error>;

    /// Enable or disable pre-allocation querying. For more info about querying see [Allocator::pre_allocation_query]. \
    /// Default is false.
    fn enable_queries(&self) -> bool {
        false
    }

    /// If [self.enable_queries()](Allocator::enable_queries) is true,
    /// this function will be called to query the allocator's preferred memory type. \
    /// As much as possible, the allocator should attempt to return the same memory_type
    /// values that will be returned by the subsequent call to [Allocator::allocate],
    /// but the allocator is not required to do so.
    ///
    /// `tensor_name`: The name of the output tensor. None indicates that the tensor name has not been determined. \
    /// `byte_size`: The expected size of the buffer. None indicates that the byte size has not been determined. \
    /// `requested_memory_type`: The memory type preferred by the Triton inference. \
    /// Returns the memory type preferred by the allocator, taking the caller's preferred type into account.
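    ///
    /// A sketch of an override that always reports a preference for pinned host memory
    /// (assuming a `MemoryType::Pinned` variant, as mentioned in the notes on [Allocator::allocate]):
    ///
    /// ```ignore
    /// async fn pre_allocation_query(
    ///     &mut self,
    ///     _tensor_name: Option<String>,
    ///     _byte_size: Option<usize>,
    ///     _requested_memory_type: MemoryType,
    /// ) -> MemoryType {
    ///     MemoryType::Pinned
    /// }
    /// ```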
    #[allow(unused_variables)]
    async fn pre_allocation_query(
        &mut self,
        tensor_name: Option<String>,
        byte_size: Option<usize>,
        requested_memory_type: MemoryType,
    ) -> MemoryType {
        requested_memory_type
    }
}

/// Default allocator.
///
/// Allocates exactly `byte_size` bytes of datatype `data_type` in `requested_memory_type` memory for each output.
#[derive(Debug, Default, Clone, Copy)]
pub struct DefaultAllocator;

#[async_trait::async_trait]
impl Allocator for DefaultAllocator {
    async fn allocate(
        &mut self,
        _tensor_name: String,
        requested_mem_type: MemoryType,
        byte_size: usize,
        data_type: DataType,
    ) -> Result<Buffer, Error> {
        let data_type_size = data_type.size();
        run_in_context(0, move || {
            Buffer::alloc_with_data_type(
                (byte_size as f32 / data_type_size as f32).ceil() as usize,
                requested_mem_type,
                data_type,
            )
        })
        .await?
    }
}

/// Inference request object.\
/// One can get this item using [Server::create_request].
///
/// Before the inference, input data must be added via one of the [add_input](Request::add_input) methods,
/// and an [Allocator] via [Request::add_allocator] or [Request::add_default_allocator].
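///
/// A minimal usage sketch (the model name, input name and `input_buffer` are illustrative,
/// and `infer_async` is assumed to consume the request and return a [ResponseFuture]):
///
/// ```ignore
/// let mut request = server.create_request("my_model")?;
/// request
///     .add_default_allocator()
///     .add_input("input__0", input_buffer)?;
/// let response = request.infer_async()?.await?;
/// ```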
pub struct Request<'a> {
    ptr: *mut sys::TRITONSERVER_InferenceRequest,
    model_name: String,
    input: HashMap<String, Buffer>,
    custom_allocator: Option<Box<dyn Allocator>>,
    custom_trace: Option<Trace>,
    // Ensures that the Server is not dropped while the Request is in flight.
    // Server(Arc<Inner>)
    server: &'a Server,
}

impl<'a> Request<'a> {
    pub(crate) fn new<M: AsRef<str>>(
        ptr: *mut sys::TRITONSERVER_InferenceRequest,
        server: &'a Server,
        model: M,
    ) -> Result<Request<'a>, Error> {
        Ok(Request {
            ptr,
            model_name: model.as_ref().to_string(),
            input: HashMap::new(),
            custom_allocator: None,
            custom_trace: None,
            server,
        })
    }

    /// Add a custom Allocator to the request. \
    /// Check the [Allocator] trait for more info.
    pub fn add_allocator(&mut self, custom_allocator: Box<dyn Allocator>) -> &mut Self {
        let _ = self.custom_allocator.replace(custom_allocator);
        self
    }

    /// Add the [DefaultAllocator] to the request. \
    /// Check the [Allocator] trait and [DefaultAllocator] for more info.
    pub fn add_default_allocator(&mut self) -> &mut Self {
        let _ = self.custom_allocator.replace(Box::new(DefaultAllocator));
        self
    }

    /// Add a custom Trace to the request. \
    /// If this method is not invoked, no tracing will be provided. \
    /// Check [Trace] for more info about tracing.
    pub fn add_trace(&mut self, custom_trace: Trace) -> &mut Self {
        let _ = self.custom_trace.replace(custom_trace);
        self
    }

    /// Get the ID of the request.
    pub fn get_id(&self) -> Result<String, Error> {
        let mut id = null::<c_char>();
        triton_call!(
            sys::TRITONSERVER_InferenceRequestId(self.ptr, &mut id as *mut _),
            from_char_array(id)
        )
    }

    /// Set the ID of the request.
    pub fn set_id<I: AsRef<str>>(&mut self, id: I) -> Result<&mut Self, Error> {
        let id = to_cstring(id)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetId(self.ptr, id.as_ptr()),
            self
        )
    }

    #[deprecated(
        since = "0.4.1",
        note = "Can't be used with bit operations as flags. Use `Request::get_sequence_flags` instead"
    )]
    #[allow(deprecated)]
    /// Get the flag(s) associated with the request. \
    /// Check [Sequence] for available flags.
    pub fn get_flags(&self) -> Result<Sequence, Error> {
        let mut flag: u32 = 0;
        triton_call!(sys::TRITONSERVER_InferenceRequestFlags(
            self.ptr,
            &mut flag as *mut _
        ))?;
        unsafe { Ok(transmute::<u32, Sequence>(flag)) }
    }

    /// Get the flag(s) associated with the request. \
    /// Check [SequenceFlag] for available flags.
    pub fn get_sequence_flags(&self) -> Result<SequenceFlag, Error> {
        let mut flag: u32 = 0;
        triton_call!(sys::TRITONSERVER_InferenceRequestFlags(
            self.ptr,
            &mut flag as *mut _
        ))?;
        SequenceFlag::from_bits(flag).ok_or(Error::new(
            ErrorCode::Internal,
            format!("Triton returned unknown SequenceFlag: {flag:#b}"),
        ))
    }

    #[deprecated(
        since = "0.4.1",
        note = "Can't be used with bit operations as flags. Use `Request::set_sequence_flags` instead"
    )]
    #[allow(deprecated)]
    /// Set the flag(s) associated with a request. \
    /// Check [Sequence] for available flags.
    pub fn set_flags(&mut self, flags: Sequence) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetFlags(self.ptr, flags as _),
            self
        )
    }

    /// Set the flag(s) associated with a request. \
    /// Check [SequenceFlag] for available flags.
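    ///
    /// For example, to mark a request as both the start and the end of a sequence
    /// (illustrative; assumes an existing `request`):
    ///
    /// ```ignore
    /// request.set_sequence_flags(SequenceFlag::START | SequenceFlag::END)?;
    /// ```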
    pub fn set_sequence_flags(&mut self, flags: SequenceFlag) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetFlags(self.ptr, flags.bits()),
            self
        )
    }

    /// Get the correlation ID of the inference request. \
    /// Default is 0, which indicates that the request has no correlation ID. \
    /// If the correlation ID associated with the inference request is a string, this function will return a failure. \
    /// The correlation ID is used to indicate that two or more inference requests are related to each other. \
    /// How this relationship is handled by the inference server is determined by the model's scheduling policy.
    pub fn get_correlation_id(&self) -> Result<u64, Error> {
        let mut id: u64 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestCorrelationId(self.ptr, &mut id as *mut _),
            id
        )
    }

    /// Get the correlation ID of the inference request as a string. \
    /// Default is the empty string "", which indicates that the request has no correlation ID. \
    /// If the correlation ID associated with the inference request is an unsigned integer, this function will return a failure. \
    /// The correlation ID is used to indicate that two or more inference requests are related to each other. \
    /// How this relationship is handled by the inference server is determined by the model's scheduling policy.
    pub fn get_correlation_id_as_string(&self) -> Result<String, Error> {
        let mut id = null::<c_char>();
        triton_call!(
            sys::TRITONSERVER_InferenceRequestCorrelationIdString(self.ptr, &mut id as *mut _),
            from_char_array(id)
        )
    }

    /// Set the correlation ID of the inference request to an unsigned integer. \
    /// Default is 0, which indicates that the request has no correlation ID. \
    /// The correlation ID is used to indicate that two or more inference requests are related to each other. \
    /// How this relationship is handled by the inference server is determined by the model's scheduling policy.
    pub fn set_correlation_id(&mut self, id: u64) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetCorrelationId(self.ptr, id),
            self
        )
    }

    /// Set the correlation ID of the inference request to a string. \
    /// The correlation ID is used to indicate that two or more inference requests are related to each other. \
    /// How this relationship is handled by the inference server is determined by the model's scheduling policy.
    pub fn set_correlation_id_as_str<I: AsRef<str>>(&mut self, id: I) -> Result<&mut Self, Error> {
        let id = to_cstring(id)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetCorrelationIdString(self.ptr, id.as_ptr()),
            self
        )
    }

    /// Get the priority of the request. \
    /// The default is 0, indicating that the request does not specify a priority and so will use the model's default priority.
    pub fn get_priority(&self) -> Result<u32, Error> {
        let mut priority: u32 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestPriority(self.ptr, &mut priority as *mut _),
            priority
        )
    }

    /// Set the priority of the request. \
    /// The default is 0, indicating that the request does not specify a priority and so will use the model's default priority.
    pub fn set_priority(&mut self, priority: u32) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetPriority(self.ptr, priority),
            self
        )
    }

    /// Get the timeout of the request. \
    /// The default is 0, which indicates that the request has no timeout.
    pub fn get_timeout(&self) -> Result<Duration, Error> {
        let mut timeout_us: u64 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestTimeoutMicroseconds(
                self.ptr,
                &mut timeout_us as *mut _,
            ),
            Duration::from_micros(timeout_us)
        )
    }

    /// Set the timeout of the request. \
    /// The default is 0, which indicates that the request has no timeout.
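    ///
    /// For example (illustrative; assumes an existing `request`; the timeout is stored with microsecond precision):
    ///
    /// ```ignore
    /// request.set_timeout(Duration::from_millis(500))?;
    /// ```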
    pub fn set_timeout(&mut self, timeout: Duration) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetTimeoutMicroseconds(
                self.ptr,
                timeout.as_micros() as u64,
            ),
            self
        )
    }

    /// Add an input to the request.\
    /// `input_name`: The name of the input. \
    /// `buffer`: Buffer containing the input data. \
    /// Note: input data will be returned after the inference. Check [ResponseFuture::get_input_release] for more info.
    pub fn add_input<N: AsRef<str>>(
        &mut self,
        input_name: N,
        buffer: Buffer,
    ) -> Result<&mut Self, Error> {
        self.add_input_inner(input_name, buffer, None::<String>, None::<Vec<i64>>)
    }

    /// Add an input with the specified shape to the request.\
    /// `input_name`: The name of the input. \
    /// `buffer`: Buffer containing the input data. \
    /// `dims`: Dimensions of the input.\
    /// Note: input data will be returned after the inference. Check [ResponseFuture::get_input_release] for more info.
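    ///
    /// For example, overriding the model's default shape (the input name and dimensions are illustrative;
    /// assumes an existing `request` and `buffer`):
    ///
    /// ```ignore
    /// request.add_input_with_dims("input__0", buffer, [1, 3, 224, 224])?;
    /// ```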
    pub fn add_input_with_dims<N, D>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        dims: D,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        D: AsRef<[i64]>,
    {
        self.add_input_inner(input_name, buffer, None::<String>, Some(dims))
    }

    /// Add an input with the specified host policy to the request.\
    /// `input_name`: The name of the input.\
    /// `buffer`: Buffer containing the input data. \
    /// `policy`: The policy name; all model instances executing with this policy will use this input buffer for execution.\
    /// Note: input data will be returned after the inference. Check [ResponseFuture::get_input_release] for more info.
    pub fn add_input_with_policy<N, P>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        policy: P,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        P: AsRef<str>,
    {
        self.add_input_inner(input_name, buffer, Some(policy), None::<Vec<i64>>)
    }

    /// Add an input with the specified host policy and shape to the request.\
    /// `input_name`: The name of the input.\
    /// `buffer`: Buffer containing the input data. \
    /// `policy`: The policy name; all model instances executing with this policy will use this input buffer for execution.\
    /// `dims`: Dimensions of the input.\
    /// Note: input data will be returned after the inference. Check [ResponseFuture::get_input_release] for more info.
    pub fn add_input_with_policy_and_dims<N, P, D>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        policy: P,
        dims: D,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        P: AsRef<str>,
        D: AsRef<[i64]>,
    {
        self.add_input_inner(input_name, buffer, Some(policy), Some(dims))
    }

    fn add_input_inner<N, P, D>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        policy: Option<P>,
        dims: Option<D>,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        P: AsRef<str>,
        D: AsRef<[i64]>,
    {
        if self.input.contains_key(input_name.as_ref()) {
            return Err(Error::new(
                ErrorCode::Alreadyxists,
                format!(
                    "Request already has buffer for input \"{}\"",
                    input_name.as_ref()
                ),
            ));
        }
        let model_shape = self.get_shape(input_name.as_ref())?;

        let shape = if let Some(dims) = dims {
            Shape {
                name: input_name.as_ref().to_string(),
                datatype: model_shape.datatype,
                dims: dims.as_ref().to_vec(),
            }
        } else {
            Shape {
                name: input_name.as_ref().to_string(),
                datatype: model_shape.datatype,
                dims: model_shape.dims.clone(),
            }
        };

        assert_buffer_shape(&shape, &buffer, input_name.as_ref())?;

        self.add_input_triton(&input_name, &shape)?;

        if let Some(policy) = policy {
            self.append_input_data_with_policy(input_name, &policy, buffer)?;
        } else {
            self.append_input_data(input_name, buffer)?;
        }

        Ok(self)
    }

    fn get_shape<N: AsRef<str>>(&self, source: N) -> Result<&Shape, Error> {
        let model_name = &self.model_name;
        let model = self.server.get_model(model_name)?;

        match model
            .inputs
            .iter()
            .find(|shape| shape.name == source.as_ref())
        {
            None => Err(Error::new(
                ErrorCode::Internal,
                format!("Model {model_name} has no input named: {}", source.as_ref()),
            )),
            Some(shape) => Ok(shape),
        }
    }

    fn add_input_triton<I: AsRef<str>>(&self, input_name: I, input: &Shape) -> Result<(), Error> {
        let name = to_cstring(input_name)?;
        triton_call!(sys::TRITONSERVER_InferenceRequestAddInput(
            self.ptr,
            name.as_ptr(),
            input.datatype as u32,
            input.dims.as_ptr(),
            input.dims.len() as u64,
        ))
    }

    fn append_input_data<I: AsRef<str>>(
        &mut self,
        input_name: I,
        buffer: Buffer,
    ) -> Result<&mut Self, Error> {
        let name = to_cstring(&input_name)?;
        triton_call!(sys::TRITONSERVER_InferenceRequestAppendInputData(
            self.ptr,
            name.as_ptr(),
            buffer.ptr,
            buffer.len,
            buffer.memory_type as u32,
            0,
        ))?;

        let _ = self.input.insert(input_name.as_ref().to_string(), buffer);
        Ok(self)
    }

    fn append_input_data_with_policy<I: AsRef<str>, P: AsRef<str>>(
        &mut self,
        input_name: I,
        policy: P,
        buffer: Buffer,
    ) -> Result<&mut Self, Error> {
        let name = to_cstring(&input_name)?;
        let policy = to_cstring(policy)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestAppendInputDataWithHostPolicy(
                self.ptr,
                name.as_ptr(),
                buffer.ptr,
                buffer.len,
                buffer.memory_type as u32,
                0,
                policy.as_ptr(),
            )
        )?;

        self.input.insert(input_name.as_ref().to_string(), buffer);
        Ok(self)
    }

    /// Remove an input from the request. Returns the data that was appended to the input.
    ///
    /// `name`: The name of the input.
    pub fn remove_input<N: AsRef<str>>(&mut self, name: N) -> Result<Buffer, Error> {
        let buffer = self.input.remove(name.as_ref()).ok_or_else(|| {
            Error::new(
                ErrorCode::InvalidArg,
                format!(
                    "Can't find input {} in a request. Appended inputs: {:?}",
                    name.as_ref(),
                    self.input.keys()
                ),
            )
        })?;
        let name = to_cstring(name)?;

        triton_call!(TRITONSERVER_InferenceRequestRemoveAllInputData(
            self.ptr,
            name.as_ptr()
        ))?;
        triton_call!(
            TRITONSERVER_InferenceRequestRemoveInput(self.ptr, name.as_ptr()),
            buffer
        )
    }

    /// Remove all inputs from the request. Returns the data that was appended to the inputs.
    pub fn remove_all_inputs(&mut self) -> Result<HashMap<String, Buffer>, Error> {
        let mut buffers = HashMap::new();
        std::mem::swap(&mut buffers, &mut self.input);

        triton_call!(
            TRITONSERVER_InferenceRequestRemoveAllInputs(self.ptr),
            buffers
        )
    }

    pub(crate) fn add_outputs(&mut self) -> Result<HashMap<String, DataType>, Error> {
        let model = self.server.get_model(&self.model_name)?;
        let mut datatype_hints = HashMap::new();

        for output in &model.outputs {
            self.add_output(&output.name)?;
            datatype_hints.insert(output.name.clone(), output.datatype);
        }

        Ok(datatype_hints)
    }

    /// Add an output request to an inference request.\
    /// `name`: The name of the output.\
    /// The output data buffer is obtained from the Triton allocator;
    /// the output (e.g. embeddings) will be put in that buffer.
    /// One can obtain the buffer back using Response::output() or from the infer_async() Error.
    pub(crate) fn add_output<N: AsRef<str>>(&mut self, name: N) -> Result<&mut Self, Error> {
        let output_name = to_cstring(&name)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestAddRequestedOutput(self.ptr, output_name.as_ptr()),
            self
        )
    }

    /// Set a parameter in the request. Does not support ParameterContent::Bytes.
    pub fn set_parameter(&mut self, parameter: Parameter) -> Result<&mut Self, Error> {
        let name = to_cstring(&parameter.name)?;
        match parameter.content.clone() {
            ParameterContent::Bool(value) => triton_call!(
                sys::TRITONSERVER_InferenceRequestSetBoolParameter(self.ptr, name.as_ptr(), value),
                self
            ),
            ParameterContent::Bytes(_) => Err(Error::new(
                ErrorCode::Unsupported,
                "Request::set_parameter does not support ParameterContent::Bytes",
            )),
            ParameterContent::Int(value) => triton_call!(
                sys::TRITONSERVER_InferenceRequestSetIntParameter(self.ptr, name.as_ptr(), value),
                self
            ),
            ParameterContent::Double(value) => {
                triton_call!(
                    sys::TRITONSERVER_InferenceRequestSetDoubleParameter(
                        self.ptr,
                        name.as_ptr(),
                        value
                    ),
                    self
                )
            }
            ParameterContent::String(value) => {
                let value = to_cstring(value)?;
                triton_call!(
                    sys::TRITONSERVER_InferenceRequestSetStringParameter(
                        self.ptr,
                        name.as_ptr(),
                        value.as_ptr()
                    ),
                    self
                )
            }
        }
    }
}

unsafe impl Send for Request<'_> {}

impl Drop for Request<'_> {
    fn drop(&mut self) {
        unsafe {
            sys::TRITONSERVER_InferenceRequestDelete(self.ptr);
        }
    }
}

fn assert_buffer_shape<N: AsRef<str>>(
    shape: &Shape,
    buffer: &Buffer,
    source: N,
) -> Result<(), Error> {
    if shape.datatype != buffer.data_type {
        return Err(Error::new(
            ErrorCode::InvalidArg,
            format!(
                "input buffer datatype {:?} mismatches the model shape datatype: {:?}. input name: {}",
                buffer.data_type,
                shape.datatype,
                source.as_ref()
            ),
        ));
    }
    let shape_size = if shape.dims.iter().any(|n| *n < 0) || shape.dims.is_empty() {
        0
    } else {
        shape.dims.iter().product::<i64>() as u32 * shape.datatype.size()
    };

    if shape_size as usize > buffer.size() {
        Err(Error::new(
            ErrorCode::InvalidArg,
            format!(
                "Buffer has size: {}, which is less than the minimum shape size: {shape_size}. input name: {}",
                buffer.size(),
                source.as_ref()
            ),
        ))
    } else {
        Ok(())
    }
}