pub(crate) mod infer;
mod utils;
pub use crate::trace::Trace;
pub use infer::{InferenceError, InputRelease, ResponseFuture};

use std::{collections::HashMap, mem::transmute, os::raw::c_char, ptr::null, time::Duration};

use crate::{
    error::ErrorCode,
    from_char_array,
    memory::{Buffer, DataType, MemoryType},
    message::Shape,
    parameter::{Parameter, ParameterContent},
    run_in_context,
    sys::{
        self, TRITONSERVER_InferenceRequestRemoveAllInputData,
        TRITONSERVER_InferenceRequestRemoveAllInputs, TRITONSERVER_InferenceRequestRemoveInput,
    },
    to_cstring, Error, Server,
};

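/// Sequence-control flags for an inference request, mirroring Triton's
/// `TRITONSERVER_REQUEST_FLAG_SEQUENCE_START`/`_END` values.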
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum Sequence {
    Start = sys::tritonserver_requestflag_enum_TRITONSERVER_REQUEST_FLAG_SEQUENCE_START,
    End = sys::tritonserver_requestflag_enum_TRITONSERVER_REQUEST_FLAG_SEQUENCE_END,
}

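/// Custom allocator for the output buffers of an inference request.
///
/// `allocate` is called once per output tensor; the returned [`Buffer`]
/// receives that tensor's data.
///
/// # Example
///
/// A minimal sketch of an allocator that always places outputs in CPU memory
/// (the `MemoryType::Cpu` variant name is an assumption):
///
/// ```ignore
/// struct CpuAllocator;
///
/// #[async_trait::async_trait]
/// impl Allocator for CpuAllocator {
///     async fn allocate(
///         &mut self,
///         _tensor_name: String,
///         _requested_memory_type: MemoryType,
///         byte_size: usize,
///         data_type: DataType,
///     ) -> Result<Buffer, Error> {
///         // Element count, rounded up so the buffer covers `byte_size` bytes.
///         let count = byte_size.div_ceil(data_type.size() as usize);
///         Buffer::alloc_with_data_type(count, MemoryType::Cpu, data_type)
///     }
/// }
/// ```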
#[async_trait::async_trait]
pub trait Allocator: Send {
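    /// Allocate a [`Buffer`] of at least `byte_size` bytes for the output
    /// tensor `tensor_name`. `requested_memory_type` is the memory type
    /// Triton prefers for this output.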
    async fn allocate(
        &mut self,
        tensor_name: String,
        requested_memory_type: MemoryType,
        byte_size: usize,
        data_type: DataType,
    ) -> Result<Buffer, Error>;

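    /// Return `true` to let Triton query the allocator via
    /// [`pre_allocation_query`](Allocator::pre_allocation_query) before
    /// output buffers are allocated. Defaults to `false`.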
    fn enable_queries(&self) -> bool {
        false
    }

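    /// Report the memory type this allocator would use for an output with
    /// the given (possibly not-yet-known) tensor name and byte size. Called
    /// only if [`enable_queries`](Allocator::enable_queries) returns `true`;
    /// the default implementation accepts the requested memory type.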
    #[allow(unused_variables)]
    async fn pre_allocation_query(
        &mut self,
        tensor_name: Option<String>,
        byte_size: Option<usize>,
        requested_memory_type: MemoryType,
    ) -> MemoryType {
        requested_memory_type
    }
}

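/// [`Allocator`] that allocates each output buffer with exactly the
/// requested memory type, data type, and byte size.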
#[derive(Debug, Default, Clone, Copy)]
pub struct DefaultAllocator;

#[async_trait::async_trait]
impl Allocator for DefaultAllocator {
    async fn allocate(
        &mut self,
        _tensor_name: String,
        requested_mem_type: MemoryType,
        byte_size: usize,
        data_type: DataType,
    ) -> Result<Buffer, Error> {
        let data_type_size = data_type.size() as usize;
        run_in_context(0, move || {
            // Element count, rounded up with integer ceiling division so the
            // buffer covers `byte_size` bytes without float precision loss.
            Buffer::alloc_with_data_type(
                byte_size.div_ceil(data_type_size),
                requested_mem_type,
                data_type,
            )
        })
        .await?
    }
}

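/// An inference request bound to a model on a [`Server`].
///
/// Holds the buffers appended as inputs, plus the optional custom output
/// [`Allocator`] and [`Trace`].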
pub struct Request<'a> {
    ptr: *mut sys::TRITONSERVER_InferenceRequest,
    model_name: String,
    input: HashMap<String, Buffer>,
    custom_allocator: Option<Box<dyn Allocator>>,
    custom_trace: Option<Trace>,
    server: &'a Server,
}

impl<'a> Request<'a> {
    pub(crate) fn new<M: AsRef<str>>(
        ptr: *mut sys::TRITONSERVER_InferenceRequest,
        server: &'a Server,
        model: M,
    ) -> Result<Request<'a>, Error> {
        Ok(Request {
            ptr,
            model_name: model.as_ref().to_string(),
            input: HashMap::new(),
            custom_allocator: None,
            custom_trace: None,
            server,
        })
    }

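    /// Set a custom output [`Allocator`] for this request, replacing any
    /// allocator set previously.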
    pub fn add_allocator(&mut self, custom_allocator: Box<dyn Allocator>) -> &mut Self {
        let _ = self.custom_allocator.replace(custom_allocator);
        self
    }

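    /// Use [`DefaultAllocator`] for this request's outputs, replacing any
    /// allocator set previously.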
    pub fn add_default_allocator(&mut self) -> &mut Self {
        let _ = self.custom_allocator.replace(Box::new(DefaultAllocator));
        self
    }

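    /// Attach a [`Trace`] to this request, replacing any trace set
    /// previously.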
    pub fn add_trace(&mut self, custom_trace: Trace) -> &mut Self {
        let _ = self.custom_trace.replace(custom_trace);
        self
    }

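    /// Get the ID of the request.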
    pub fn get_id(&self) -> Result<String, Error> {
        let mut id = null::<c_char>();
        triton_call!(
            sys::TRITONSERVER_InferenceRequestId(self.ptr, &mut id as *mut _),
            from_char_array(id)
        )
    }

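    /// Set the ID of the request.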
    pub fn set_id<I: AsRef<str>>(&mut self, id: I) -> Result<&mut Self, Error> {
        let id = to_cstring(id)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetId(self.ptr, id.as_ptr()),
            self
        )
    }

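    /// Get the sequence flags of the request.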
    pub fn get_flags(&self) -> Result<Sequence, Error> {
        let mut flag: u32 = 0;
        triton_call!(sys::TRITONSERVER_InferenceRequestFlags(
            self.ptr,
            &mut flag as *mut _
        ))?;
        // SAFETY: `Sequence` is `repr(u32)`. This assumes Triton reports
        // exactly one of the flag values defined in `Sequence`; any other
        // value here would be undefined behavior.
        unsafe { Ok(transmute::<u32, Sequence>(flag)) }
    }

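    /// Set the sequence flags of the request.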
    pub fn set_flags(&mut self, flags: Sequence) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetFlags(self.ptr, flags as _),
            self
        )
    }

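    /// Get the correlation ID of the request as an unsigned integer.
    /// Correlation IDs are used to group requests that belong to the same
    /// sequence.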
    pub fn get_correlation_id(&self) -> Result<u64, Error> {
        let mut id: u64 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestCorrelationId(self.ptr, &mut id as *mut _),
            id
        )
    }

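    /// Get the correlation ID of the request as a string. Use this instead
    /// of [`get_correlation_id`](Request::get_correlation_id) when the ID
    /// was set as a string.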
    pub fn get_correlation_id_as_string(&self) -> Result<String, Error> {
        let mut id = null::<c_char>();
        triton_call!(
            sys::TRITONSERVER_InferenceRequestCorrelationIdString(self.ptr, &mut id as *mut _),
            from_char_array(id)
        )
    }

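    /// Set the correlation ID of the request as an unsigned integer.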
    pub fn set_correlation_id(&mut self, id: u64) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetCorrelationId(self.ptr, id),
            self
        )
    }

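    /// Set the correlation ID of the request as a string.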
    pub fn set_correlation_id_as_str<I: AsRef<str>>(&mut self, id: I) -> Result<&mut Self, Error> {
        let id = to_cstring(id)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetCorrelationIdString(self.ptr, id.as_ptr()),
            self
        )
    }

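    /// Get the priority of the request.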
    pub fn get_priority(&self) -> Result<u32, Error> {
        let mut priority: u32 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestPriority(self.ptr, &mut priority as *mut _),
            priority
        )
    }

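    /// Set the priority of the request.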
    pub fn set_priority(&mut self, priority: u32) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetPriority(self.ptr, priority),
            self
        )
    }

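    /// Get the timeout of the request (microsecond resolution).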
    pub fn get_timeout(&self) -> Result<Duration, Error> {
        let mut timeout_us: u64 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestTimeoutMicroseconds(
                self.ptr,
                &mut timeout_us as *mut _,
            ),
            Duration::from_micros(timeout_us)
        )
    }

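    /// Set the timeout of the request. The value is passed to Triton in
    /// microseconds, so sub-microsecond precision is truncated.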
    pub fn set_timeout(&mut self, timeout: Duration) -> Result<&mut Self, Error> {
        triton_call!(
            sys::TRITONSERVER_InferenceRequestSetTimeoutMicroseconds(
                self.ptr,
                timeout.as_micros() as u64,
            ),
            self
        )
    }

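    /// Append `buffer` as the data for the input named `input_name`, taking
    /// the input dimensions from the model configuration.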
    pub fn add_input<N: AsRef<str>>(
        &mut self,
        input_name: N,
        buffer: Buffer,
    ) -> Result<&mut Self, Error> {
        self.add_input_inner(input_name, buffer, None::<String>, None::<Vec<i64>>)
    }

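    /// Append `buffer` as the data for the input named `input_name`,
    /// overriding the model-configured dimensions with `dims` (useful when
    /// the model declares dynamic dimensions).
    ///
    /// A minimal sketch (the `MemoryType::Cpu` and `DataType::Float`
    /// variant names are assumptions):
    ///
    /// ```ignore
    /// // A 1x3 FP32 input, allocated the same way DefaultAllocator does.
    /// let buffer = Buffer::alloc_with_data_type(3, MemoryType::Cpu, DataType::Float)?;
    /// request.add_input_with_dims("INPUT0", buffer, [1, 3])?;
    /// ```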
    pub fn add_input_with_dims<N, D>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        dims: D,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        D: AsRef<[i64]>,
    {
        self.add_input_inner(input_name, buffer, None::<String>, Some(dims))
    }

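    /// Append `buffer` as the data for the input named `input_name`,
    /// associated with the host policy named `policy`.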
    pub fn add_input_with_policy<N, P>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        policy: P,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        P: AsRef<str>,
    {
        self.add_input_inner(input_name, buffer, Some(policy), None::<Vec<i64>>)
    }

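    /// Append `buffer` as the data for the input named `input_name`, with
    /// both a host policy and explicit dimensions.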
    pub fn add_input_with_policy_and_dims<N, P, D>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        policy: P,
        dims: D,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        P: AsRef<str>,
        D: AsRef<[i64]>,
    {
        self.add_input_inner(input_name, buffer, Some(policy), Some(dims))
    }

    fn add_input_inner<N, P, D>(
        &mut self,
        input_name: N,
        buffer: Buffer,
        policy: Option<P>,
        dims: Option<D>,
    ) -> Result<&mut Self, Error>
    where
        N: AsRef<str>,
        P: AsRef<str>,
        D: AsRef<[i64]>,
    {
        if self.input.contains_key(input_name.as_ref()) {
            return Err(Error::new(
                ErrorCode::AlreadyExists,
                format!(
                    "Request already has a buffer for input \"{}\"",
                    input_name.as_ref()
                ),
            ));
        }
        let model_shape = self.get_shape(input_name.as_ref())?;

        // Use the caller-provided dims if any; otherwise fall back to the
        // dims from the model configuration.
        let shape = Shape {
            name: input_name.as_ref().to_string(),
            datatype: model_shape.datatype,
            dims: match dims {
                Some(dims) => dims.as_ref().to_vec(),
                None => model_shape.dims.clone(),
            },
        };

        assert_buffer_shape(&shape, &buffer, input_name.as_ref())?;

        self.add_input_triton(&input_name, &shape)?;

        if let Some(policy) = policy {
            self.append_input_data_with_policy(input_name, &policy, buffer)?;
        } else {
            self.append_input_data(input_name, buffer)?;
        }

        Ok(self)
    }

    fn get_shape<N: AsRef<str>>(&self, source: N) -> Result<&Shape, Error> {
        let model_name = &self.model_name;
        let model = self.server.get_model(model_name)?;

        match model
            .inputs
            .iter()
            .find(|shape| shape.name == source.as_ref())
        {
            None => Err(Error::new(
                ErrorCode::Internal,
                format!("Model {model_name} has no input named: {}", source.as_ref()),
            )),
            Some(shape) => Ok(shape),
        }
    }

    fn add_input_triton<I: AsRef<str>>(&self, input_name: I, input: &Shape) -> Result<(), Error> {
        let name = to_cstring(input_name)?;
        triton_call!(sys::TRITONSERVER_InferenceRequestAddInput(
            self.ptr,
            name.as_ptr(),
            input.datatype as u32,
            input.dims.as_ptr(),
            input.dims.len() as u64,
        ))
    }

    fn append_input_data<I: AsRef<str>>(
        &mut self,
        input_name: I,
        buffer: Buffer,
    ) -> Result<&mut Self, Error> {
        let name = to_cstring(&input_name)?;
        triton_call!(sys::TRITONSERVER_InferenceRequestAppendInputData(
            self.ptr,
            name.as_ptr(),
            buffer.ptr,
            buffer.len,
            buffer.memory_type as u32,
            0,
        ))?;

        let _ = self.input.insert(input_name.as_ref().to_string(), buffer);
        Ok(self)
    }

    fn append_input_data_with_policy<I: AsRef<str>, P: AsRef<str>>(
        &mut self,
        input_name: I,
        policy: P,
        buffer: Buffer,
    ) -> Result<&mut Self, Error> {
        let name = to_cstring(&input_name)?;
        let policy = to_cstring(policy)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestAppendInputDataWithHostPolicy(
                self.ptr,
                name.as_ptr(),
                buffer.ptr,
                buffer.len,
                buffer.memory_type as u32,
                0,
                policy.as_ptr(),
            )
        )?;

        let _ = self.input.insert(input_name.as_ref().to_string(), buffer);
        Ok(self)
    }

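    /// Remove the input `name` from the request and return its buffer.
    /// Fails if no buffer was appended for that input.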
    pub fn remove_input<N: AsRef<str>>(&mut self, name: N) -> Result<Buffer, Error> {
        let buffer = self.input.remove(name.as_ref()).ok_or_else(|| {
            Error::new(
                ErrorCode::InvalidArg,
                format!(
                    "Can't find input {} in the request. Appended inputs: {:?}",
                    name.as_ref(),
                    self.input.keys()
                ),
            )
        })?;
        let name = to_cstring(name)?;

        triton_call!(TRITONSERVER_InferenceRequestRemoveAllInputData(
            self.ptr,
            name.as_ptr()
        ))?;
        triton_call!(
            TRITONSERVER_InferenceRequestRemoveInput(self.ptr, name.as_ptr()),
            buffer
        )
    }

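    /// Remove every input from the request, returning the buffers keyed by
    /// input name.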
    pub fn remove_all_inputs(&mut self) -> Result<HashMap<String, Buffer>, Error> {
        let buffers = std::mem::take(&mut self.input);

        triton_call!(
            TRITONSERVER_InferenceRequestRemoveAllInputs(self.ptr),
            buffers
        )
    }

    pub(crate) fn add_outputs(&mut self) -> Result<HashMap<String, DataType>, Error> {
        let model = self.server.get_model(&self.model_name)?;
        let mut datatype_hints = HashMap::new();

        for output in &model.outputs {
            self.add_output(&output.name)?;
            datatype_hints.insert(output.name.clone(), output.datatype);
        }

        Ok(datatype_hints)
    }

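    /// Request the model output `name` to be computed and returned for this
    /// request.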
    pub(crate) fn add_output<N: AsRef<str>>(&mut self, name: N) -> Result<&mut Self, Error> {
        let output_name = to_cstring(&name)?;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestAddRequestedOutput(self.ptr, output_name.as_ptr()),
            self
        )
    }

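    /// Set a custom parameter on the request. `ParameterContent::Bytes` is
    /// not supported.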
    pub fn set_parameter(&mut self, parameter: Parameter) -> Result<&mut Self, Error> {
        let name = to_cstring(&parameter.name)?;
        match &parameter.content {
            ParameterContent::Bool(value) => triton_call!(
                sys::TRITONSERVER_InferenceRequestSetBoolParameter(self.ptr, name.as_ptr(), *value),
                self
            ),
            ParameterContent::Bytes(_) => Err(Error::new(
                ErrorCode::Unsupported,
                "Request::set_parameter does not support ParameterContent::Bytes",
            )),
            ParameterContent::Int(value) => triton_call!(
                sys::TRITONSERVER_InferenceRequestSetIntParameter(self.ptr, name.as_ptr(), *value),
                self
            ),
            ParameterContent::Double(value) => {
                triton_call!(
                    sys::TRITONSERVER_InferenceRequestSetDoubleParameter(
                        self.ptr,
                        name.as_ptr(),
                        *value
                    ),
                    self
                )
            }
            ParameterContent::String(value) => {
                let value = to_cstring(value)?;
                triton_call!(
                    sys::TRITONSERVER_InferenceRequestSetStringParameter(
                        self.ptr,
                        name.as_ptr(),
                        value.as_ptr()
                    ),
                    self
                )
            }
        }
    }
}

unsafe impl Send for Request<'_> {}

impl Drop for Request<'_> {
    fn drop(&mut self) {
        unsafe {
            sys::TRITONSERVER_InferenceRequestDelete(self.ptr);
        }
    }
}

fn assert_buffer_shape<N: AsRef<str>>(
    shape: &Shape,
    buffer: &Buffer,
    source: N,
) -> Result<(), Error> {
    if shape.datatype != buffer.data_type {
        return Err(Error::new(
            ErrorCode::InvalidArg,
            format!(
                "input buffer datatype {:?} mismatches model shape datatype {:?}. Input name: {}",
                buffer.data_type,
                shape.datatype,
                source.as_ref()
            ),
        ));
    }
    // Dynamic (negative) or absent dims make the minimum size unknown, so
    // only fully specified shapes are checked against the buffer size.
    let shape_size = if shape.dims.iter().any(|n| *n < 0) || shape.dims.is_empty() {
        0
    } else {
        shape.dims.iter().product::<i64>() as u32 * shape.datatype.size()
    };

    if shape_size as usize > buffer.size() {
        Err(Error::new(
            ErrorCode::InvalidArg,
            format!(
                "Buffer has size {}, which is less than the minimum shape size {shape_size}. Input name: {}",
                buffer.size(),
                source.as_ref()
            ),
        ))
    } else {
        Ok(())
    }
}