1use dynamo_runtime::{
5 engine::AsyncEngineContext,
6 pipeline::{AsyncEngineContextProvider, Context},
7 protocols::annotated::AnnotationsProvider,
8};
9use futures::{Stream, StreamExt, stream};
10use std::str::FromStr;
11use std::sync::Arc;
12
13use crate::types::Annotated;
14
15use super::kserve;
16
17use crate::http::service::{
19 disconnect::{ConnectionHandle, create_connection_monitor},
20 metrics::{Endpoint, ResponseMetricCollector},
21};
22use crate::{http::service::metrics::InflightGuard, preprocessor::LLMMetricAnnotation};
23
24use crate::protocols::tensor;
25use crate::protocols::tensor::{
26 NvCreateTensorRequest, NvCreateTensorResponse, Tensor, TensorMetadata,
27};
28
29use crate::grpc::service::kserve::inference;
30use crate::grpc::service::kserve::inference::DataType;
31
32use tonic::Status;
33
/// Annotation key under which the request ID is echoed back to the client.
pub const ANNOTATION_REQUEST_ID: &str = "request_id";
36
/// Build the annotated response stream for a tensor inference request.
///
/// Pipeline: resolve/mint a request ID, attach a connection monitor so a
/// client disconnect can cancel engine work, look up the tensor engine for
/// the requested model, generate the response stream, prepend any requested
/// `request_id` annotation, tap each item for metrics, and finally wrap the
/// whole thing in the disconnect monitor.
///
/// # Arguments
/// * `state` - shared service state (model manager + metrics).
/// * `request` - the incoming tensor request.
/// * `streaming` - whether this is a streaming RPC (recorded in metrics).
///
/// # Errors
/// * `Status::not_found` when the model has no registered tensor engine.
/// * `Status::internal` when the engine fails to produce a stream.
pub async fn tensor_response_stream(
    state: Arc<kserve::State>,
    request: NvCreateTensorRequest,
    streaming: bool,
) -> Result<impl Stream<Item = Annotated<NvCreateTensorResponse>>, Status> {
    // Reuse the client-supplied ID when it parses as a UUID; otherwise mint one.
    let request_id = get_or_create_request_id(request.id.as_deref());
    let request = Context::with_id(request, request_id.clone());
    let context = request.context();

    // Arm a monitor so a client disconnect before hand-off cancels the work.
    let (mut connection_handle, stream_handle) =
        create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;

    let model = &request.model;

    let engine = state
        .manager()
        .get_tensor_engine(model)
        .map_err(|_| Status::not_found("model not found"))?;

    // Counts this request as in-flight until the monitor stream resolves it.
    let inflight_guard =
        state
            .metrics_clone()
            .create_inflight_guard(model, Endpoint::Tensor, streaming);

    let mut response_collector = state.metrics_clone().create_response_collector(model);

    // Capture annotations before `request` is consumed by the engine call.
    let annotations = request.annotations();

    let stream = engine.generate(request).await.map_err(|e| {
        Status::internal(format!("Failed to generate tensor response stream: {}", e))
    })?;

    let ctx = stream.context();

    // Only the `request_id` annotation is honored; any others are dropped.
    let annotations = annotations.map_or(Vec::new(), |annotations| {
        annotations
            .iter()
            .filter_map(|annotation| {
                if annotation == ANNOTATION_REQUEST_ID {
                    Annotated::<NvCreateTensorResponse>::from_annotation(
                        ANNOTATION_REQUEST_ID,
                        &request_id,
                    )
                    .ok()
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
    });

    // Annotation items are emitted first, then the engine responses.
    let stream = stream::iter(annotations).chain(stream);

    // Tap each response for token metrics without altering the stream.
    let stream = stream.inspect(move |response| {
        process_metrics_only(response, &mut response_collector);
    });

    let stream = grpc_monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);

    // Hand-off complete: the monitor stream (via `stream_handle`) now owns
    // disconnect handling, so disarm the pre-hand-off handle.
    connection_handle.disarm();

    Ok(stream)
}
122
/// Wrap `stream`, forwarding its items until it ends or `context` is stopped
/// (e.g. because the client disconnected).
///
/// On clean upstream completion the inflight guard is marked OK and the
/// connection handle disarmed. On a stopped context the stream just ends;
/// the guard is not marked OK, so its drop path records the outcome.
/// NOTE(review): `stream_handle` is also not disarmed on the stopped path —
/// presumably intentional (the monitor itself triggered the stop); confirm.
pub fn grpc_monitor_for_disconnects<T>(
    stream: impl Stream<Item = Annotated<T>>,
    context: Arc<dyn AsyncEngineContext>,
    mut inflight_guard: InflightGuard,
    mut stream_handle: ConnectionHandle,
) -> impl Stream<Item = Annotated<T>> {
    stream_handle.arm();
    async_stream::stream! {
        tokio::pin!(stream);
        loop {
            // Race the next upstream item against context cancellation.
            tokio::select! {
                event = stream.next() => {
                    match event {
                        Some(response) => {
                            yield response;
                        }
                        None => {
                            // Upstream finished cleanly: record success and
                            // stand down the disconnect monitor.
                            inflight_guard.mark_ok();
                            stream_handle.disarm();
                            break;
                        }
                    }
                }
                _ = context.stopped() => {
                    tracing::trace!("Context stopped; breaking stream");
                    break;
                }
            }
        }
    }
}
163
164fn process_metrics_only<T>(
165 annotated: &Annotated<T>,
166 response_collector: &mut ResponseMetricCollector,
167) {
168 if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(annotated) {
170 response_collector.observe_current_osl(metrics.output_tokens);
171 response_collector.observe_response(metrics.input_tokens, metrics.chunk_tokens);
172 }
173}
174
175fn get_or_create_request_id(primary: Option<&str>) -> String {
177 if let Some(primary) = primary
179 && let Ok(uuid) = uuid::Uuid::parse_str(primary)
180 {
181 return uuid.to_string();
182 }
183
184 let uuid = uuid::Uuid::new_v4();
186 uuid.to_string()
187}
188
189impl TryFrom<inference::ModelInferRequest> for NvCreateTensorRequest {
190 type Error = Status;
191
192 fn try_from(request: inference::ModelInferRequest) -> Result<Self, Self::Error> {
193 if !request.raw_input_contents.is_empty()
196 && request.inputs.len() != request.raw_input_contents.len()
197 {
198 return Err(Status::invalid_argument(
199 "`raw_input_contents` must be used for all inputs",
200 ));
201 }
202
203 let mut tensor_request = NvCreateTensorRequest {
204 id: if !request.id.is_empty() {
205 Some(request.id.clone())
206 } else {
207 None
208 },
209 model: request.model_name.clone(),
210 tensors: Vec::new(),
211 nvext: None,
212 };
213
214 for (idx, input) in request.inputs.into_iter().enumerate() {
216 let mut tensor = Tensor {
217 metadata: TensorMetadata {
218 name: input.name.clone(),
219 data_type: tensor::DataType::from_str(&input.datatype)
220 .map_err(|err| Status::invalid_argument(err.to_string()))?,
221 shape: input.shape.clone(),
222 },
223 data: tensor::FlattenTensor::Bool(Vec::new()),
225 };
226 match &input.contents {
227 Some(contents) => {
229 tensor.set_data_from_tensor_contents(contents);
230 }
231 None => {
233 tensor.set_data_from_raw_contents(&request.raw_input_contents[idx])?;
234 }
235 }
236 tensor_request.tensors.push(tensor);
237 }
238 Ok(tensor_request)
239 }
240}
241
impl tensor::Tensor {
    /// Populate `self.data` from KServe typed `InferTensorContents`, selecting
    /// the protobuf field that matches `self.metadata.data_type`.
    ///
    /// The protobuf packs narrow integers into 32-bit repeated fields
    /// (`uint_contents` / `int_contents`), so UINT8/16 and INT8/16 are
    /// narrowed with `as` casts — values are assumed to already fit.
    /// NOTE(review): element counts are not validated against the declared
    /// shape here — presumably checked elsewhere; confirm.
    fn set_data_from_tensor_contents(&mut self, contents: &inference::InferTensorContents) {
        self.data = match self.metadata.data_type {
            tensor::DataType::Bool => tensor::FlattenTensor::Bool(contents.bool_contents.clone()),
            tensor::DataType::Uint8 => tensor::FlattenTensor::Uint8(
                contents.uint_contents.iter().map(|&x| x as u8).collect(),
            ),
            tensor::DataType::Uint16 => tensor::FlattenTensor::Uint16(
                contents.uint_contents.iter().map(|&x| x as u16).collect(),
            ),
            tensor::DataType::Uint32 => {
                tensor::FlattenTensor::Uint32(contents.uint_contents.clone())
            }
            tensor::DataType::Uint64 => {
                tensor::FlattenTensor::Uint64(contents.uint64_contents.clone())
            }
            tensor::DataType::Int8 => tensor::FlattenTensor::Int8(
                contents.int_contents.iter().map(|&x| x as i8).collect(),
            ),
            tensor::DataType::Int16 => tensor::FlattenTensor::Int16(
                contents.int_contents.iter().map(|&x| x as i16).collect(),
            ),
            tensor::DataType::Int32 => tensor::FlattenTensor::Int32(contents.int_contents.clone()),
            tensor::DataType::Int64 => {
                tensor::FlattenTensor::Int64(contents.int64_contents.clone())
            }

            tensor::DataType::Float32 => {
                tensor::FlattenTensor::Float32(contents.fp32_contents.clone())
            }

            tensor::DataType::Float64 => {
                tensor::FlattenTensor::Float64(contents.fp64_contents.clone())
            }

            tensor::DataType::Bytes => {
                tensor::FlattenTensor::Bytes(contents.bytes_contents.clone())
            }
        }
    }

    /// Populate `self.data` from a KServe `raw_input_contents` byte buffer.
    ///
    /// Validates that the shape has no negative dimensions, that the element
    /// count does not overflow `usize`, and that the buffer length is exactly
    /// `element_count * element_size`. BYTES tensors (element size 0) use a
    /// length-prefixed encoding handled by `raw_input_to_bytes_tensor`.
    ///
    /// # Errors
    /// `Status::invalid_argument` on any of the validation failures above.
    #[allow(clippy::result_large_err)]
    fn set_data_from_raw_contents(&mut self, raw_input: &[u8]) -> Result<(), Status> {
        // Product of dimensions, rejecting negatives and usize overflow.
        let element_count = self.metadata.shape.iter().try_fold(1usize, |acc, &d| {
            if d < 0 {
                Err(Status::invalid_argument(format!(
                    "Shape contains negative dimension: {}",
                    d
                )))
            } else {
                acc.checked_mul(d as usize).ok_or_else(|| {
                    Status::invalid_argument("Overflow occurred while calculating element count")
                })
            }
        })?;

        let data_size = self.metadata.data_type.size();

        // A zero element size marks the variable-length BYTES type.
        if data_size == 0 {
            self.data = self.raw_input_to_bytes_tensor(element_count, raw_input)?;
            return Ok(());
        }

        if !raw_input.len().is_multiple_of(data_size) {
            return Err(Status::invalid_argument(format!(
                "Raw input length must be a multiple of {}",
                data_size
            )));
        } else if raw_input.len() / data_size != element_count {
            return Err(Status::invalid_argument(format!(
                "Raw input element count for '{}' does not match expected size, expected {} elements, got {} elements",
                self.metadata.name,
                element_count,
                raw_input.len() / data_size
            )));
        }
        self.data = self.raw_input_to_typed_tensor(raw_input)?;

        Ok(())
    }

    /// Decode a length-prefixed BYTES buffer: each element is a 4-byte
    /// little-endian `u32` length followed by that many payload bytes.
    ///
    /// # Errors
    /// `Status::invalid_argument` if a length runs past the buffer, trailing
    /// bytes remain, or the decoded element count differs from
    /// `element_count`.
    #[allow(clippy::result_large_err)]
    fn raw_input_to_bytes_tensor(
        &self,
        element_count: usize,
        raw_input: &[u8],
    ) -> Result<tensor::FlattenTensor, Status> {
        let mut bytes_contents = vec![];
        let mut offset = 0;
        while offset + 4 <= raw_input.len() {
            // The slice is exactly 4 bytes, so `try_into` cannot fail.
            let len =
                u32::from_le_bytes(raw_input[offset..offset + 4].try_into().unwrap()) as usize;
            offset += 4;
            if offset + len > raw_input.len() {
                return Err(Status::invalid_argument(format!(
                    "Invalid length-prefixed BYTES input for '{}', length exceeds raw input size",
                    self.metadata.name
                )));
            }
            bytes_contents.push(raw_input[offset..offset + len].to_vec());
            offset += len;
        }
        // A partial (1-3 byte) length prefix would leave offset short of the end.
        if offset != raw_input.len() {
            return Err(Status::invalid_argument(format!(
                "Invalid length-prefixed BYTES input for '{}', extra bytes at the end",
                self.metadata.name
            )));
        }
        if element_count != bytes_contents.len() {
            return Err(Status::invalid_argument(format!(
                "Raw input element count for '{}' does not match expected size, expected {} elements, got {} elements",
                self.metadata.name,
                element_count,
                bytes_contents.len()
            )));
        }
        Ok(tensor::FlattenTensor::Bytes(bytes_contents))
    }

    /// Reinterpret a raw little-endian byte buffer as the flat typed tensor
    /// indicated by `self.metadata.data_type`.
    ///
    /// The caller (`set_data_from_raw_contents`) has already verified the
    /// buffer length is an exact multiple of the element size, so
    /// `chunks_exact` leaves no remainder. BYTES is rejected here because it
    /// uses the length-prefixed path instead.
    ///
    /// # Errors
    /// `Status::internal` if called with the BYTES data type.
    #[allow(clippy::result_large_err)]
    fn raw_input_to_typed_tensor(&self, raw_input: &[u8]) -> Result<tensor::FlattenTensor, Status> {
        match self.metadata.data_type {
            tensor::DataType::Bool => Ok(tensor::FlattenTensor::Bool(
                raw_input.iter().map(|&b| b != 0).collect(),
            )),
            tensor::DataType::Uint8 => Ok(tensor::FlattenTensor::Uint8(
                raw_input.chunks_exact(1).map(|chunk| chunk[0]).collect(),
            )),
            tensor::DataType::Uint16 => Ok(tensor::FlattenTensor::Uint16(
                raw_input
                    .chunks_exact(2)
                    .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
                    .collect(),
            )),
            tensor::DataType::Uint32 => Ok(tensor::FlattenTensor::Uint32(
                raw_input
                    .chunks_exact(4)
                    .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                    .collect(),
            )),
            tensor::DataType::Uint64 => Ok(tensor::FlattenTensor::Uint64(
                raw_input
                    .chunks_exact(8)
                    .map(|chunk| {
                        u64::from_le_bytes([
                            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
                            chunk[7],
                        ])
                    })
                    .collect(),
            )),
            tensor::DataType::Int8 => Ok(tensor::FlattenTensor::Int8(
                raw_input
                    .chunks_exact(1)
                    .map(|chunk| chunk[0] as i8)
                    .collect(),
            )),
            tensor::DataType::Int16 => Ok(tensor::FlattenTensor::Int16(
                raw_input
                    .chunks_exact(2)
                    .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]))
                    .collect(),
            )),
            tensor::DataType::Int32 => Ok(tensor::FlattenTensor::Int32(
                raw_input
                    .chunks_exact(4)
                    .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                    .collect(),
            )),
            tensor::DataType::Int64 => Ok(tensor::FlattenTensor::Int64(
                raw_input
                    .chunks_exact(8)
                    .map(|chunk| {
                        i64::from_le_bytes([
                            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
                            chunk[7],
                        ])
                    })
                    .collect(),
            )),
            tensor::DataType::Float32 => Ok(tensor::FlattenTensor::Float32(
                raw_input
                    .chunks_exact(4)
                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                    .collect(),
            )),
            tensor::DataType::Float64 => Ok(tensor::FlattenTensor::Float64(
                raw_input
                    .chunks_exact(8)
                    .map(|chunk| {
                        f64::from_le_bytes([
                            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
                            chunk[7],
                        ])
                    })
                    .collect(),
            )),
            tensor::DataType::Bytes => Err(Status::internal(format!(
                "Unexpected BYTES type in non-bytes branch for input '{}'",
                self.metadata.name
            ))),
        }
    }
}
458
459impl TryFrom<NvCreateTensorResponse> for inference::ModelInferResponse {
460 type Error = anyhow::Error;
461
462 fn try_from(response: NvCreateTensorResponse) -> Result<Self, Self::Error> {
463 let mut infer_response = inference::ModelInferResponse {
464 model_name: response.model,
465 model_version: "1".to_string(),
466 id: response.id.unwrap_or_default(),
467 outputs: vec![],
468 parameters: ::std::collections::HashMap::<String, inference::InferParameter>::new(),
469 raw_output_contents: vec![],
470 };
471 for tensor in &response.tensors {
472 infer_response
473 .outputs
474 .push(inference::model_infer_response::InferOutputTensor {
475 name: tensor.metadata.name.clone(),
476 datatype: tensor.metadata.data_type.to_string(),
477 shape: tensor.metadata.shape.clone(),
478 contents: match &tensor.data {
479 tensor::FlattenTensor::Bool(data) => Some(inference::InferTensorContents {
480 bool_contents: data.clone(),
481 ..Default::default()
482 }),
483 tensor::FlattenTensor::Uint8(data) => {
484 Some(inference::InferTensorContents {
485 uint_contents: data.iter().map(|&x| x as u32).collect(),
486 ..Default::default()
487 })
488 }
489
490 tensor::FlattenTensor::Uint16(data) => {
491 Some(inference::InferTensorContents {
492 uint_contents: data.iter().map(|&x| x as u32).collect(),
493 ..Default::default()
494 })
495 }
496
497 tensor::FlattenTensor::Uint32(data) => {
498 Some(inference::InferTensorContents {
499 uint_contents: data.clone(),
500 ..Default::default()
501 })
502 }
503
504 tensor::FlattenTensor::Uint64(data) => {
505 Some(inference::InferTensorContents {
506 uint64_contents: data.clone(),
507 ..Default::default()
508 })
509 }
510
511 tensor::FlattenTensor::Int8(data) => Some(inference::InferTensorContents {
512 int_contents: data.iter().map(|&x| x as i32).collect(),
513 ..Default::default()
514 }),
515 tensor::FlattenTensor::Int16(data) => {
516 Some(inference::InferTensorContents {
517 int_contents: data.iter().map(|&x| x as i32).collect(),
518 ..Default::default()
519 })
520 }
521
522 tensor::FlattenTensor::Int32(data) => {
523 Some(inference::InferTensorContents {
524 int_contents: data.clone(),
525 ..Default::default()
526 })
527 }
528
529 tensor::FlattenTensor::Int64(data) => {
530 Some(inference::InferTensorContents {
531 int64_contents: data.clone(),
532 ..Default::default()
533 })
534 }
535
536 tensor::FlattenTensor::Float32(data) => {
537 Some(inference::InferTensorContents {
538 fp32_contents: data.clone(),
539 ..Default::default()
540 })
541 }
542
543 tensor::FlattenTensor::Float64(data) => {
544 Some(inference::InferTensorContents {
545 fp64_contents: data.clone(),
546 ..Default::default()
547 })
548 }
549
550 tensor::FlattenTensor::Bytes(data) => {
551 Some(inference::InferTensorContents {
552 bytes_contents: data.clone(),
553 ..Default::default()
554 })
555 }
556 },
557 ..Default::default()
558 });
559 }
560
561 Ok(infer_response)
562 }
563}
564
565impl TryFrom<NvCreateTensorResponse> for inference::ModelStreamInferResponse {
566 type Error = anyhow::Error;
567
568 fn try_from(response: NvCreateTensorResponse) -> Result<Self, Self::Error> {
569 match inference::ModelInferResponse::try_from(response) {
570 Ok(response) => Ok(inference::ModelStreamInferResponse {
571 infer_response: Some(response),
572 ..Default::default()
573 }),
574 Err(e) => Ok(inference::ModelStreamInferResponse {
575 infer_response: None,
576 error_message: format!("Failed to convert response: {}", e),
577 }),
578 }
579 }
580}
581
582impl tensor::DataType {
583 pub fn to_kserve(&self) -> i32 {
584 match *self {
585 tensor::DataType::Bool => DataType::TypeBool as i32,
586 tensor::DataType::Uint8 => DataType::TypeUint8 as i32,
587 tensor::DataType::Uint16 => DataType::TypeUint16 as i32,
588 tensor::DataType::Uint32 => DataType::TypeUint32 as i32,
589 tensor::DataType::Uint64 => DataType::TypeUint64 as i32,
590 tensor::DataType::Int8 => DataType::TypeInt8 as i32,
591 tensor::DataType::Int16 => DataType::TypeInt16 as i32,
592 tensor::DataType::Int32 => DataType::TypeInt32 as i32,
593 tensor::DataType::Int64 => DataType::TypeInt64 as i32,
594 tensor::DataType::Float32 => DataType::TypeFp32 as i32,
595 tensor::DataType::Float64 => DataType::TypeFp64 as i32,
596 tensor::DataType::Bytes => DataType::TypeString as i32,
597 }
598 }
599}
600
601impl std::fmt::Display for tensor::DataType {
602 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
603 match *self {
604 tensor::DataType::Bool => write!(f, "BOOL"),
605 tensor::DataType::Uint8 => write!(f, "UINT8"),
606 tensor::DataType::Uint16 => write!(f, "UINT16"),
607 tensor::DataType::Uint32 => write!(f, "UINT32"),
608 tensor::DataType::Uint64 => write!(f, "UINT64"),
609 tensor::DataType::Int8 => write!(f, "INT8"),
610 tensor::DataType::Int16 => write!(f, "INT16"),
611 tensor::DataType::Int32 => write!(f, "INT32"),
612 tensor::DataType::Int64 => write!(f, "INT64"),
613 tensor::DataType::Float32 => write!(f, "FP32"),
614 tensor::DataType::Float64 => write!(f, "FP64"),
615 tensor::DataType::Bytes => write!(f, "BYTES"),
616 }
617 }
618}
619
620impl FromStr for tensor::DataType {
621 type Err = anyhow::Error;
622
623 fn from_str(s: &str) -> Result<Self, Self::Err> {
624 match s {
625 "BOOL" => Ok(tensor::DataType::Bool),
626 "UINT8" => Ok(tensor::DataType::Uint8),
627 "UINT16" => Ok(tensor::DataType::Uint16),
628 "UINT32" => Ok(tensor::DataType::Uint32),
629 "UINT64" => Ok(tensor::DataType::Uint64),
630 "INT8" => Ok(tensor::DataType::Int8),
631 "INT16" => Ok(tensor::DataType::Int16),
632 "INT32" => Ok(tensor::DataType::Int32),
633 "INT64" => Ok(tensor::DataType::Int64),
634 "FP32" => Ok(tensor::DataType::Float32),
635 "FP64" => Ok(tensor::DataType::Float64),
636 "BYTES" => Ok(tensor::DataType::Bytes),
637 _ => Err(anyhow::anyhow!("Invalid data type")),
638 }
639 }
640}