pub(crate) mod infer;
mod utils;
pub use crate::trace::Trace;
pub use infer::{InferenceError, InputRelease, ResponseFuture};

use std::{collections::HashMap, mem::transmute, os::raw::c_char, ptr::null, time::Duration};

use crate::{
    error::ErrorCode,
    from_char_array,
    memory::{Buffer, DataType, MemoryType},
    message::Shape,
    parameter::{Parameter, ParameterContent},
    run_in_context,
    sys::{
        self, TRITONSERVER_InferenceRequestRemoveAllInputData,
        TRITONSERVER_InferenceRequestRemoveAllInputs, TRITONSERVER_InferenceRequestRemoveInput,
    },
    to_cstring, Error, Server,
};

#[deprecated(
    since = "0.4.1",
    note = "Can't be used with bit operations as flags. Use `SequenceFlag` instead"
)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum Sequence {
    Start = sys::tritonserver_requestflag_enum_TRITONSERVER_REQUEST_FLAG_SEQUENCE_START,
    End = sys::tritonserver_requestflag_enum_TRITONSERVER_REQUEST_FLAG_SEQUENCE_END,
}

bitflags::bitflags! {
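    /// Flags marking a request's position within a sequence.
    ///
    /// These are bitflags, so they can be combined with bit operations. A
    /// sketch (not from the original docs) of a one-request sequence that is
    /// both the start and the end:
    ///
    /// ```ignore
    /// let flags = SequenceFlag::START | SequenceFlag::END;
    /// assert!(flags.contains(SequenceFlag::START));
    /// assert!(flags.contains(SequenceFlag::END));
    /// ```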
    pub struct SequenceFlag: u32 {
        const START = sys::tritonserver_requestflag_enum_TRITONSERVER_REQUEST_FLAG_SEQUENCE_START;
        const END = sys::tritonserver_requestflag_enum_TRITONSERVER_REQUEST_FLAG_SEQUENCE_END;
    }
}

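/// Custom allocator for the output buffers of a [`Request`].
///
/// Below is a minimal sketch of an implementation (illustrative only, not
/// from the original docs; `CountingAllocator` is a made-up name). It keeps
/// the requested memory type and sizes the buffer the same way
/// [`DefaultAllocator`] does:
///
/// ```ignore
/// struct CountingAllocator {
///     calls: usize,
/// }
///
/// #[async_trait::async_trait]
/// impl Allocator for CountingAllocator {
///     async fn allocate(
///         &mut self,
///         _tensor_name: String,
///         requested_memory_type: MemoryType,
///         byte_size: usize,
///         data_type: DataType,
///     ) -> Result<Buffer, Error> {
///         // Count allocations, then allocate exactly as requested.
///         self.calls += 1;
///         let elements = (byte_size as f32 / data_type.size() as f32).ceil() as usize;
///         Buffer::alloc_with_data_type(elements, requested_memory_type, data_type)
///     }
/// }
/// ```
///
/// An allocator is attached to a request with [`Request::add_allocator`].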
#[async_trait::async_trait]
pub trait Allocator: Send {
    async fn allocate(
        &mut self,
        tensor_name: String,
        requested_memory_type: MemoryType,
        byte_size: usize,
        data_type: DataType,
    ) -> Result<Buffer, Error>;

    fn enable_queries(&self) -> bool {
        false
    }

    #[allow(unused_variables)]
    async fn pre_allocation_query(
        &mut self,
        tensor_name: Option<String>,
        byte_size: Option<usize>,
        requested_memory_type: MemoryType,
    ) -> MemoryType {
        requested_memory_type
    }
}

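/// Default [`Allocator`] used by [`Request::add_default_allocator`]: it
/// allocates each output buffer in the memory type requested by the backend,
/// sized as `ceil(byte_size / data_type.size())` elements of the requested
/// data type.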
#[derive(Debug, Default, Clone, Copy)]
pub struct DefaultAllocator;

#[async_trait::async_trait]
impl Allocator for DefaultAllocator {
    async fn allocate(
        &mut self,
        _tensor_name: String,
        requested_mem_type: MemoryType,
        byte_size: usize,
        data_type: DataType,
    ) -> Result<Buffer, Error> {
        let data_type_size = data_type.size();
        run_in_context(0, move || {
            Buffer::alloc_with_data_type(
                (byte_size as f32 / data_type_size as f32).ceil() as usize,
                requested_mem_type,
                data_type,
            )
        })
        .await?
    }
}

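/// An inference request for a single model on a [`Server`], holding the
/// appended input buffers.
///
/// A rough preparation sketch (not from the original docs; the id, priority,
/// timeout, and input name are made up, and `buffer` is assumed to already
/// hold the input data in the model's data type):
///
/// ```ignore
/// fn prepare(request: &mut Request<'_>, buffer: Buffer) -> Result<(), Error> {
///     request
///         .set_id("request-0")?
///         .set_priority(1)?
///         .set_timeout(Duration::from_millis(500))?
///         .add_default_allocator()
///         .add_input("INPUT0", buffer)?;
///     Ok(())
/// }
/// ```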
pub struct Request<'a> {
    ptr: *mut sys::TRITONSERVER_InferenceRequest,
    model_name: String,
    input: HashMap<String, Buffer>,
    custom_allocator: Option<Box<dyn Allocator>>,
    custom_trace: Option<Trace>,
    server: &'a Server,
}

impl<'a> Request<'a> {
    pub(crate) fn new<M: AsRef<str>>(
        ptr: *mut sys::TRITONSERVER_InferenceRequest,
        server: &'a Server,
        model: M,
    ) -> Result<Request<'a>, Error> {
        Ok(Request {
            ptr,
            model_name: model.as_ref().to_string(),
            input: HashMap::new(),
            custom_allocator: None,
            custom_trace: None,
            server,
        })
    }

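    /// Attaches a custom [`Allocator`] used for this request's output
    /// buffers, replacing any previously set allocator (including
    /// [`DefaultAllocator`]).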
    pub fn add_allocator(&mut self, custom_allocator: Box<dyn Allocator>) -> &mut Self {
        let _ = self.custom_allocator.replace(custom_allocator);
        self
    }

    pub fn add_default_allocator(&mut self) -> &mut Self {
        let _ = self.custom_allocator.replace(Box::new(DefaultAllocator));
        self
    }

    pub fn add_trace(&mut self, custom_trace: Trace) -> &mut Self {
        let _ = self.custom_trace.replace(custom_trace);
        self
    }

    pub fn get_id(&self) -> Result<String, Error> {
        let mut id = null::<c_char>();
        triton_call!(
            sys::TRITONSERVER_InferenceRequestId(self.ptr, &mut id as *mut _),
            from_char_array(id)
        )
    }

    pub fn set_id<I: AsRef<str>>(&mut self, id: I) -> Result<&mut Self, Error> {
        let id = to_cstring(id)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetId(self.ptr, id.as_ptr()),
            self
        )
    }

    #[deprecated(
        since = "0.4.1",
        note = "Can't be used with bit operations as flags. Use `Request::get_sequence_flags` instead"
    )]
    #[allow(deprecated)]
    pub fn get_flags(&self) -> Result<Sequence, Error> {
        let mut flag: u32 = 0;
        triton_call!(sys::TRITONSERVER_InferenceRequestFlags(
            self.ptr,
            &mut flag as *mut _
        ))?;
        unsafe { Ok(transmute::<u32, Sequence>(flag)) }
    }

    pub fn get_sequence_flags(&self) -> Result<SequenceFlag, Error> {
        let mut flag: u32 = 0;
        triton_call!(sys::TRITONSERVER_InferenceRequestFlags(
            self.ptr,
            &mut flag as *mut _
        ))?;
        SequenceFlag::from_bits(flag).ok_or(Error::new(
            ErrorCode::Internal,
            format!("Triton returned unknown SequenceFlag: {flag:#b}"),
        ))
    }

    #[deprecated(
        since = "0.4.1",
        note = "Can't be used with bit operations as flags. Use `Request::set_sequence_flags` instead"
    )]
    #[allow(deprecated)]
    pub fn set_flags(&mut self, flags: Sequence) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetFlags(self.ptr, flags as _),
            self
        )
    }

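    /// Marks this request's position in a sequence using [`SequenceFlag`] bits.
    ///
    /// Illustrative sketch (the correlation id is made up) of starting a
    /// sequence:
    ///
    /// ```ignore
    /// request
    ///     .set_correlation_id(7)?
    ///     .set_sequence_flags(SequenceFlag::START)?;
    /// ```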
    pub fn set_sequence_flags(&mut self, flags: SequenceFlag) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetFlags(self.ptr, flags.bits()),
            self
        )
    }

    pub fn get_correlation_id(&self) -> Result<u64, Error> {
        let mut id: u64 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestCorrelationId(self.ptr, &mut id as *mut _),
            id
        )
    }

    pub fn get_correlation_id_as_string(&self) -> Result<String, Error> {
        let mut id = null::<c_char>();
        triton_call!(
            sys::TRITONSERVER_InferenceRequestCorrelationIdString(self.ptr, &mut id as *mut _),
            from_char_array(id)
        )
    }

    pub fn set_correlation_id(&mut self, id: u64) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetCorrelationId(self.ptr, id),
            self
        )
    }

    pub fn set_correlation_id_as_str<I: AsRef<str>>(&mut self, id: I) -> Result<&mut Self, Error> {
        let id = to_cstring(id)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetCorrelationIdString(self.ptr, id.as_ptr()),
            self
        )
    }

    pub fn get_priority(&self) -> Result<u32, Error> {
        let mut priority: u32 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestPriority(self.ptr, &mut priority as *mut _),
            priority
        )
    }

    pub fn set_priority(&mut self, priority: u32) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetPriority(self.ptr, priority),
            self
        )
    }

    pub fn get_timeout(&self) -> Result<Duration, Error> {
        let mut timeout_us: u64 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestTimeoutMicroseconds(
                self.ptr,
                &mut timeout_us as *mut _,
            ),
            Duration::from_micros(timeout_us)
        )
    }

    pub fn set_timeout(&mut self, timeout: Duration) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetTimeoutMicroseconds(
                self.ptr,
                timeout.as_micros() as u64,
            ),
            self
        )
    }

    pub fn add_input<N: AsRef<str>>(
        &mut self,
        input_name: N,
        buffer: Buffer,
    ) -> Result<&mut Self, Error> {
        self.add_input_inner(input_name, buffer, None::<String>, None::<Vec<i64>>)
    }

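    /// Adds an input buffer, overriding the dimensions taken from the model
    /// configuration (e.g. for inputs with dynamic dimensions).
    ///
    /// Minimal sketch (the input name and dims are made up; `buffer` must
    /// already contain data of the model's data type for that input):
    ///
    /// ```ignore
    /// request.add_input_with_dims("INPUT0", buffer, [1_i64, 8])?;
    /// ```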
    pub fn add_input_with_dims<N, D>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        dims: D,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        D: AsRef<[i64]>,
    {
        self.add_input_inner(input_name, buffer, None::<String>, Some(dims))
    }

    pub fn add_input_with_policy<N, P>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        policy: P,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        P: AsRef<str>,
    {
        self.add_input_inner(input_name, buffer, Some(policy), None::<Vec<i64>>)
    }

    pub fn add_input_with_policy_and_dims<N, P, D>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        policy: P,
        dims: D,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        P: AsRef<str>,
        D: AsRef<[i64]>,
    {
        self.add_input_inner(input_name, buffer, Some(policy), Some(dims))
    }

    fn add_input_inner<N, P, D>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        policy: Option<P>,
        dims: Option<D>,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        P: AsRef<str>,
        D: AsRef<[i64]>,
    {
        if self.input.contains_key(input_name.as_ref()) {
            return Err(Error::new(
                ErrorCode::AlreadyExists,
                format!(
                    "Request already has a buffer for input \"{}\"",
                    input_name.as_ref()
                ),
            ));
        }
        let model_shape = self.get_shape(input_name.as_ref())?;

        let shape = if let Some(dims) = dims {
            Shape {
                name: input_name.as_ref().to_string(),
                datatype: model_shape.datatype,
                dims: dims.as_ref().to_vec(),
            }
        } else {
            Shape {
                name: input_name.as_ref().to_string(),
                datatype: model_shape.datatype,
                dims: model_shape.dims.clone(),
            }
        };

        assert_buffer_shape(&shape, &buffer, input_name.as_ref())?;

        self.add_input_triton(&input_name, &shape)?;

        if let Some(policy) = policy {
            self.append_input_data_with_policy(input_name, &policy, buffer)?;
        } else {
            self.append_input_data(input_name, buffer)?;
        }

        Ok(self)
    }

    fn get_shape<N: AsRef<str>>(&self, source: N) -> Result<&Shape, Error> {
        let model_name = &self.model_name;
        let model = self.server.get_model(model_name)?;

        match model
            .inputs
            .iter()
            .find(|shape| shape.name == source.as_ref())
        {
            None => Err(Error::new(
                ErrorCode::Internal,
                format!("Model {model_name} has no input named: {}", source.as_ref()),
            )),
            Some(shape) => Ok(shape),
        }
    }

    fn add_input_triton<I: AsRef<str>>(&self, input_name: I, input: &Shape) -> Result<(), Error> {
        let name = to_cstring(input_name)?;
        triton_call!(sys::TRITONSERVER_InferenceRequestAddInput(
            self.ptr,
            name.as_ptr(),
            input.datatype as u32,
            input.dims.as_ptr(),
            input.dims.len() as u64,
        ))
    }

    fn append_input_data<I: AsRef<str>>(
        &mut self,
        input_name: I,
        buffer: Buffer,
    ) -> Result<&mut Self, Error> {
        let name = to_cstring(&input_name)?;
        triton_call!(sys::TRITONSERVER_InferenceRequestAppendInputData(
            self.ptr,
            name.as_ptr(),
            buffer.ptr,
            buffer.len,
            buffer.memory_type as u32,
            0,
        ))?;

        let _ = self.input.insert(input_name.as_ref().to_string(), buffer);
        Ok(self)
    }

    fn append_input_data_with_policy<I: AsRef<str>, P: AsRef<str>>(
        &mut self,
        input_name: I,
        policy: P,
        buffer: Buffer,
    ) -> Result<&mut Self, Error> {
        let name = to_cstring(&input_name)?;
        let policy = to_cstring(policy)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestAppendInputDataWithHostPolicy(
                self.ptr,
                name.as_ptr(),
                buffer.ptr,
                buffer.len,
                buffer.memory_type as u32,
                0,
                policy.as_ptr(),
            )
        )?;

        self.input.insert(input_name.as_ref().to_string(), buffer);
        Ok(self)
    }

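    /// Removes a previously appended input from the request and returns its
    /// [`Buffer`]; fails with [`ErrorCode::InvalidArg`] if no input with that
    /// name was appended.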
    pub fn remove_input<N: AsRef<str>>(&mut self, name: N) -> Result<Buffer, Error> {
        let buffer = self.input.remove(name.as_ref()).ok_or_else(|| {
            Error::new(
                ErrorCode::InvalidArg,
                format!(
                    "Can't find input {} in the request. Appended inputs: {:?}",
                    name.as_ref(),
                    self.input.keys()
                ),
            )
        })?;
        let name = to_cstring(name)?;

        triton_call!(TRITONSERVER_InferenceRequestRemoveAllInputData(
            self.ptr,
            name.as_ptr()
        ))?;
        triton_call!(
            TRITONSERVER_InferenceRequestRemoveInput(self.ptr, name.as_ptr()),
            buffer
        )
    }

    pub fn remove_all_inputs(&mut self) -> Result<HashMap<String, Buffer>, Error> {
        let mut buffers = HashMap::new();
        std::mem::swap(&mut buffers, &mut self.input);

        triton_call!(
            TRITONSERVER_InferenceRequestRemoveAllInputs(self.ptr),
            buffers
        )
    }

    pub(crate) fn add_outputs(&mut self) -> Result<HashMap<String, DataType>, Error> {
        let model = self.server.get_model(&self.model_name)?;
        let mut datatype_hints = HashMap::new();

        for output in &model.outputs {
            self.add_output(&output.name)?;
            datatype_hints.insert(output.name.clone(), output.datatype);
        }

        Ok(datatype_hints)
    }

    pub(crate) fn add_output<N: AsRef<str>>(&mut self, name: N) -> Result<&mut Self, Error> {
        let output_name = to_cstring(&name)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestAddRequestedOutput(self.ptr, output_name.as_ptr()),
            self
        )
    }

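    /// Sets a custom parameter on the request. [`ParameterContent::Bytes`] is
    /// not supported and yields [`ErrorCode::Unsupported`].
    ///
    /// Sketch only (how the [`Parameter`] is built is outside this module, so
    /// it is simply passed in here):
    ///
    /// ```ignore
    /// fn tag_request(request: &mut Request<'_>, parameter: Parameter) -> Result<(), Error> {
    ///     request.set_parameter(parameter)?;
    ///     Ok(())
    /// }
    /// ```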
    pub fn set_parameter(&mut self, parameter: Parameter) -> Result<&mut Self, Error> {
        let name = to_cstring(&parameter.name)?;
        match parameter.content.clone() {
            ParameterContent::Bool(value) => triton_call!(
                sys::TRITONSERVER_InferenceRequestSetBoolParameter(self.ptr, name.as_ptr(), value),
                self
            ),
            ParameterContent::Bytes(_) => Err(Error::new(
                ErrorCode::Unsupported,
                "Request::set_parameter does not support ParameterContent::Bytes",
            )),
            ParameterContent::Int(value) => triton_call!(
                sys::TRITONSERVER_InferenceRequestSetIntParameter(self.ptr, name.as_ptr(), value),
                self
            ),
            ParameterContent::Double(value) => {
                triton_call!(
                    sys::TRITONSERVER_InferenceRequestSetDoubleParameter(
                        self.ptr,
                        name.as_ptr(),
                        value
                    ),
                    self
                )
            }
            ParameterContent::String(value) => {
                let value = to_cstring(value)?;
                triton_call!(
                    sys::TRITONSERVER_InferenceRequestSetStringParameter(
                        self.ptr,
                        name.as_ptr(),
                        value.as_ptr()
                    ),
                    self
                )
            }
        }
    }
}

unsafe impl Send for Request<'_> {}

impl Drop for Request<'_> {
    fn drop(&mut self) {
        unsafe {
            sys::TRITONSERVER_InferenceRequestDelete(self.ptr);
        }
    }
}

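/// Checks that `buffer` is compatible with the model `shape`: the data types
/// must match, and the buffer must hold at least
/// `dims.product() * datatype.size()` bytes. Shapes with any negative
/// (dynamic) dimension, or with no dimensions at all, skip the size check.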
fn assert_buffer_shape<N: AsRef<str>>(
    shape: &Shape,
    buffer: &Buffer,
    source: N,
) -> Result<(), Error> {
    if shape.datatype != buffer.data_type {
        return Err(Error::new(
            ErrorCode::InvalidArg,
            format!(
                "Input buffer datatype {:?} mismatches the model shape datatype {:?}. Input name: {}",
                buffer.data_type,
                shape.datatype,
                source.as_ref()
            ),
        ));
    }
    let shape_size = if shape.dims.iter().any(|n| *n < 0) || shape.dims.is_empty() {
        0
    } else {
        shape.dims.iter().product::<i64>() as u32 * shape.datatype.size()
    };

    if shape_size as usize > buffer.size() {
        Err(Error::new(
            ErrorCode::InvalidArg,
            format!(
                "Buffer size {} is less than the minimum shape size {shape_size}. Input name: {}",
                buffer.size(),
                source.as_ref()
            ),
        ))
    } else {
        Ok(())
    }
}