dynamo_llm/grpc/service/tensor.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use dynamo_runtime::{
    engine::AsyncEngineContext,
    pipeline::{AsyncEngineContextProvider, Context},
    protocols::annotated::AnnotationsProvider,
};
use futures::{Stream, StreamExt, stream};
use std::str::FromStr;
use std::sync::Arc;

use crate::types::Annotated;

use super::kserve;

// [gluo NOTE] These are common utilities that should be shared between frontends
use crate::http::service::{
    disconnect::{ConnectionHandle, create_connection_monitor},
    metrics::{Endpoint, ResponseMetricCollector},
};
use crate::{http::service::metrics::InflightGuard, preprocessor::LLMMetricAnnotation};

use crate::protocols::tensor;
use crate::protocols::tensor::{
    NvCreateTensorRequest, NvCreateTensorResponse, Tensor, TensorMetadata,
};

use crate::grpc::service::kserve::inference;
use crate::grpc::service::kserve::inference::DataType;

use tonic::Status;

/// Dynamo Annotation for the request ID
pub const ANNOTATION_REQUEST_ID: &str = "request_id";

/// Tensor Request Handler
///
/// This method handles the incoming request for models of type tensor. The endpoint is a "source"
/// for an [`super::OpenAICompletionsStreamingEngine`] and will return a stream of
/// responses which will be forwarded to the client.
///
/// Note: For all requests, streaming or non-streaming, we always call the engine with streaming enabled. For
/// non-streaming requests, we will fold the stream into a single response as part of this handler.
pub async fn tensor_response_stream(
    state: Arc<kserve::State>,
    request: NvCreateTensorRequest,
    streaming: bool,
) -> Result<impl Stream<Item = Annotated<NvCreateTensorResponse>>, Status> {
    // create the context for the request
    let request_id = get_or_create_request_id(request.id.as_deref());
    let request = Context::with_id(request, request_id.clone());
    let context = request.context();

    // [gluo TODO] revisit metrics to properly expose it
    // create the connection handles
    let (mut connection_handle, stream_handle) =
        create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;

    // todo - make the model name optional in the protocol
    // todo - when optional, if none, apply a default
    let model = &request.model;

    // todo - error handling should be more robust
    let engine = state
        .manager()
        .get_tensor_engine(model)
        .map_err(|_| Status::not_found("model not found"))?;

    let inflight_guard =
        state
            .metrics_clone()
            .create_inflight_guard(model, Endpoint::Tensor, streaming);

    let mut response_collector = state.metrics_clone().create_response_collector(model);

    // prepare to process any annotations
    let annotations = request.annotations();

    // issue the generate call on the engine
    let stream = engine.generate(request).await.map_err(|e| {
        Status::internal(format!("Failed to generate tensor response stream: {}", e))
    })?;

    // capture the context to cancel the stream if the client disconnects
    let ctx = stream.context();

    // prepare any requested annotations
    let annotations = annotations.map_or(Vec::new(), |annotations| {
        annotations
            .iter()
            .filter_map(|annotation| {
                if annotation == ANNOTATION_REQUEST_ID {
                    Annotated::<NvCreateTensorResponse>::from_annotation(
                        ANNOTATION_REQUEST_ID,
                        &request_id,
                    )
                    .ok()
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
    });

    // apply any annotations to the front of the stream
    let stream = stream::iter(annotations).chain(stream);

    // Tap the stream to collect response metrics
    let stream = stream.inspect(move |response| {
        process_metrics_only(response, &mut response_collector);
    });

    let stream = grpc_monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);

    // If we got here, we will return a response and the potentially long-running task has
    // completed successfully without needing to be cancelled.
    connection_handle.disarm();

    Ok(stream)
}

/// Consumes an AsyncEngineStream and monitors for disconnects or context cancellation.
/// This is the gRPC variant of `monitor_for_disconnects`, as that implementation has
/// SSE-specific handling. TODO: decouple and reuse `monitor_for_disconnects`.
///
/// Uses `tokio::select!` to choose between receiving responses from the source stream and detecting
/// when the context is stopped. If the context is stopped, we break the stream. If the source stream
/// ends naturally, we mark the request as successful.
pub fn grpc_monitor_for_disconnects<T>(
    stream: impl Stream<Item = Annotated<T>>,
    context: Arc<dyn AsyncEngineContext>,
    mut inflight_guard: InflightGuard,
    mut stream_handle: ConnectionHandle,
) -> impl Stream<Item = Annotated<T>> {
    stream_handle.arm();
    async_stream::stream! {
        tokio::pin!(stream);
        loop {
            tokio::select! {
                event = stream.next() => {
                    match event {
                        Some(response) => {
                            yield response;
                        }
                        None => {
                            // Stream ended normally
                            inflight_guard.mark_ok();
                            stream_handle.disarm();
                            break;
                        }
                    }
                }
                // todo - test request cancellation with kserve frontend and tensor-based models
                _ = context.stopped() => {
                    tracing::trace!("Context stopped; breaking stream");
                    break;
                }
            }
        }
    }
}

fn process_metrics_only<T>(
    annotated: &Annotated<T>,
    response_collector: &mut ResponseMetricCollector,
) {
    // update metrics
    if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(annotated) {
        response_collector.observe_current_osl(metrics.output_tokens);
        response_collector.observe_response(metrics.input_tokens, metrics.chunk_tokens);
    }
}

/// Get the request ID from the primary source, or create a new one if it is missing or invalid
fn get_or_create_request_id(primary: Option<&str>) -> String {
    // Try to get the request ID from the primary source
    if let Some(primary) = primary
        && let Ok(uuid) = uuid::Uuid::parse_str(primary)
    {
        return uuid.to_string();
    }

    // The primary source was missing or not a valid UUID; generate a new one
    let uuid = uuid::Uuid::new_v4();
    uuid.to_string()
}
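
// A minimal sanity check for the request-ID selection above: a valid UUID from the
// primary source is kept (normalized by the `uuid` crate), while anything else is
// replaced with a freshly generated v4 UUID. Hypothetical test module, not part of
// the original file.
#[cfg(test)]
mod request_id_tests {
    use super::*;

    #[test]
    fn keeps_valid_primary_uuid() {
        let id = "67e55044-10b1-426f-9247-bb680e5fe0c8";
        assert_eq!(get_or_create_request_id(Some(id)), id);
    }

    #[test]
    fn generates_uuid_when_missing_or_invalid() {
        let generated = get_or_create_request_id(None);
        assert!(uuid::Uuid::parse_str(&generated).is_ok());
        // An invalid primary is ignored and a new UUID is generated instead.
        let replaced = get_or_create_request_id(Some("not-a-uuid"));
        assert_ne!(replaced, "not-a-uuid");
        assert!(uuid::Uuid::parse_str(&replaced).is_ok());
    }
}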

impl TryFrom<inference::ModelInferRequest> for NvCreateTensorRequest {
    type Error = Status;

    fn try_from(request: inference::ModelInferRequest) -> Result<Self, Self::Error> {
        // The protocol requires that if `raw_input_contents` is used to hold input data,
        // it must be used for all inputs.
        if !request.raw_input_contents.is_empty()
            && request.inputs.len() != request.raw_input_contents.len()
        {
            return Err(Status::invalid_argument(
                "`raw_input_contents` must be used for all inputs",
            ));
        }

        let mut tensor_request = NvCreateTensorRequest {
            id: if !request.id.is_empty() {
                Some(request.id.clone())
            } else {
                None
            },
            model: request.model_name.clone(),
            tensors: Vec::new(),
            nvext: None,
        };

        // iterate through inputs
        for (idx, input) in request.inputs.into_iter().enumerate() {
            let mut tensor = Tensor {
                metadata: TensorMetadata {
                    name: input.name.clone(),
                    data_type: tensor::DataType::from_str(&input.datatype)
                        .map_err(|err| Status::invalid_argument(err.to_string()))?,
                    shape: input.shape.clone(),
                },
                // Placeholder; the actual data is filled in below
                data: tensor::FlattenTensor::Bool(Vec::new()),
            };
            match &input.contents {
                // Contents provided inline in the InferInputTensor
                Some(contents) => {
                    tensor.set_data_from_tensor_contents(contents);
                }
                // Otherwise the contents must be provided in raw_input_contents
                None => {
                    tensor.set_data_from_raw_contents(&request.raw_input_contents[idx])?;
                }
            }
            tensor_request.tensors.push(tensor);
        }
        Ok(tensor_request)
    }
}
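
// A usage sketch for the conversion above: one FP32 input delivered via
// `raw_input_contents` as little-endian bytes. Hypothetical test module, not part
// of the original file; it assumes the prost-generated `ModelInferRequest` and
// `model_infer_request::InferInputTensor` implement `Default` (standard prost
// codegen), mirroring the `model_infer_response::InferOutputTensor` path used
// further down in this file.
#[cfg(test)]
mod infer_request_conversion_tests {
    use super::*;

    #[test]
    fn converts_fp32_raw_input() {
        // Four f32 values serialized as 16 little-endian bytes.
        let raw: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let request = inference::ModelInferRequest {
            id: "67e55044-10b1-426f-9247-bb680e5fe0c8".to_string(),
            model_name: "example-model".to_string(), // hypothetical model name
            inputs: vec![inference::model_infer_request::InferInputTensor {
                name: "input0".to_string(),
                datatype: "FP32".to_string(),
                shape: vec![2, 2],
                ..Default::default()
            }],
            raw_input_contents: vec![raw],
            ..Default::default()
        };
        let converted = NvCreateTensorRequest::try_from(request).expect("conversion succeeds");
        assert_eq!(converted.tensors.len(), 1);
        match &converted.tensors[0].data {
            tensor::FlattenTensor::Float32(values) => {
                assert_eq!(values, &vec![1.0, 2.0, 3.0, 4.0])
            }
            _ => panic!("expected a Float32 tensor"),
        }
    }
}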

impl tensor::Tensor {
    fn set_data_from_tensor_contents(&mut self, contents: &inference::InferTensorContents) {
        self.data = match self.metadata.data_type {
            tensor::DataType::Bool => tensor::FlattenTensor::Bool(contents.bool_contents.clone()),
            tensor::DataType::Uint8 => tensor::FlattenTensor::Uint8(
                contents.uint_contents.iter().map(|&x| x as u8).collect(),
            ),
            tensor::DataType::Uint16 => tensor::FlattenTensor::Uint16(
                contents.uint_contents.iter().map(|&x| x as u16).collect(),
            ),
            tensor::DataType::Uint32 => {
                tensor::FlattenTensor::Uint32(contents.uint_contents.clone())
            }
            tensor::DataType::Uint64 => {
                tensor::FlattenTensor::Uint64(contents.uint64_contents.clone())
            }
            tensor::DataType::Int8 => tensor::FlattenTensor::Int8(
                contents.int_contents.iter().map(|&x| x as i8).collect(),
            ),
            tensor::DataType::Int16 => tensor::FlattenTensor::Int16(
                contents.int_contents.iter().map(|&x| x as i16).collect(),
            ),
            tensor::DataType::Int32 => tensor::FlattenTensor::Int32(contents.int_contents.clone()),
            tensor::DataType::Int64 => {
                tensor::FlattenTensor::Int64(contents.int64_contents.clone())
            }
            tensor::DataType::Float32 => {
                tensor::FlattenTensor::Float32(contents.fp32_contents.clone())
            }
            tensor::DataType::Float64 => {
                tensor::FlattenTensor::Float64(contents.fp64_contents.clone())
            }
            tensor::DataType::Bytes => {
                tensor::FlattenTensor::Bytes(contents.bytes_contents.clone())
            }
        }
    }

    #[allow(clippy::result_large_err)]
    fn set_data_from_raw_contents(&mut self, raw_input: &[u8]) -> Result<(), Status> {
        let element_count = self.metadata.shape.iter().try_fold(1usize, |acc, &d| {
            if d < 0 {
                Err(Status::invalid_argument(format!(
                    "Shape contains negative dimension: {}",
                    d
                )))
            } else {
                acc.checked_mul(d as usize).ok_or_else(|| {
                    Status::invalid_argument("Overflow occurred while calculating element count")
                })
            }
        })?;

        let data_size = self.metadata.data_type.size();

        // For the BYTES type (element size 0), parse the length-prefixed strings,
        // slice them into a vector of byte vectors, and return early
        if data_size == 0 {
            self.data = self.raw_input_to_bytes_tensor(element_count, raw_input)?;
            return Ok(());
        }

        // Control reaches here only for fixed-size (non-BYTES) types;
        // validate the raw input length before conversion
        if !raw_input.len().is_multiple_of(data_size) {
            return Err(Status::invalid_argument(format!(
                "Raw input length must be a multiple of {}",
                data_size
            )));
        } else if raw_input.len() / data_size != element_count {
            return Err(Status::invalid_argument(format!(
                "Raw input element count for '{}' does not match expected size, expected {} elements, got {} elements",
                self.metadata.name,
                element_count,
                raw_input.len() / data_size
            )));
        }
        self.data = self.raw_input_to_typed_tensor(raw_input)?;

        Ok(())
    }

    #[allow(clippy::result_large_err)]
    fn raw_input_to_bytes_tensor(
        &self,
        element_count: usize,
        raw_input: &[u8],
    ) -> Result<tensor::FlattenTensor, Status> {
        // Elements are not fixed-size for the BYTES type, so the raw input holds
        // length-prefixed bytes for each element.
        let mut bytes_contents = vec![];
        let mut offset = 0;
        while offset + 4 <= raw_input.len() {
            let len =
                u32::from_le_bytes(raw_input[offset..offset + 4].try_into().unwrap()) as usize;
            offset += 4;
            if offset + len > raw_input.len() {
                return Err(Status::invalid_argument(format!(
                    "Invalid length-prefixed BYTES input for '{}', length exceeds raw input size",
                    self.metadata.name
                )));
            }
            bytes_contents.push(raw_input[offset..offset + len].to_vec());
            offset += len;
        }
        if offset != raw_input.len() {
            return Err(Status::invalid_argument(format!(
                "Invalid length-prefixed BYTES input for '{}', extra bytes at the end",
                self.metadata.name
            )));
        }
        if element_count != bytes_contents.len() {
            return Err(Status::invalid_argument(format!(
                "Raw input element count for '{}' does not match expected size, expected {} elements, got {} elements",
                self.metadata.name,
                element_count,
                bytes_contents.len()
            )));
        }
        Ok(tensor::FlattenTensor::Bytes(bytes_contents))
    }

    #[allow(clippy::result_large_err)]
    fn raw_input_to_typed_tensor(&self, raw_input: &[u8]) -> Result<tensor::FlattenTensor, Status> {
        // In Rust, we cannot "reinterpret cast" a Vec<u8> to a Vec<T> directly, as Vec
        // requires the pointer to be aligned for the element type T, which cannot be
        // guaranteed for a Vec<u8>. We have to reconstruct the Vec<T> element by element,
        // which results in a data copy.
        // We assume little-endianness for all types, as the KServe protocol does not
        // specify the endianness (though it should).
        match self.metadata.data_type {
            tensor::DataType::Bool => Ok(tensor::FlattenTensor::Bool(
                raw_input.iter().map(|&b| b != 0).collect(),
            )),
            tensor::DataType::Uint8 => Ok(tensor::FlattenTensor::Uint8(
                raw_input.chunks_exact(1).map(|chunk| chunk[0]).collect(),
            )),
            tensor::DataType::Uint16 => Ok(tensor::FlattenTensor::Uint16(
                raw_input
                    .chunks_exact(2)
                    .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
                    .collect(),
            )),
            tensor::DataType::Uint32 => Ok(tensor::FlattenTensor::Uint32(
                raw_input
                    .chunks_exact(4)
                    .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                    .collect(),
            )),
            tensor::DataType::Uint64 => Ok(tensor::FlattenTensor::Uint64(
                raw_input
                    .chunks_exact(8)
                    .map(|chunk| {
                        u64::from_le_bytes([
                            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
                            chunk[7],
                        ])
                    })
                    .collect(),
            )),
            tensor::DataType::Int8 => Ok(tensor::FlattenTensor::Int8(
                raw_input
                    .chunks_exact(1)
                    .map(|chunk| chunk[0] as i8)
                    .collect(),
            )),
            tensor::DataType::Int16 => Ok(tensor::FlattenTensor::Int16(
                raw_input
                    .chunks_exact(2)
                    .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]))
                    .collect(),
            )),
            tensor::DataType::Int32 => Ok(tensor::FlattenTensor::Int32(
                raw_input
                    .chunks_exact(4)
                    .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                    .collect(),
            )),
            tensor::DataType::Int64 => Ok(tensor::FlattenTensor::Int64(
                raw_input
                    .chunks_exact(8)
                    .map(|chunk| {
                        i64::from_le_bytes([
                            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
                            chunk[7],
                        ])
                    })
                    .collect(),
            )),
            tensor::DataType::Float32 => Ok(tensor::FlattenTensor::Float32(
                raw_input
                    .chunks_exact(4)
                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                    .collect(),
            )),
            tensor::DataType::Float64 => Ok(tensor::FlattenTensor::Float64(
                raw_input
                    .chunks_exact(8)
                    .map(|chunk| {
                        f64::from_le_bytes([
                            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
                            chunk[7],
                        ])
                    })
                    .collect(),
            )),
            tensor::DataType::Bytes => Err(Status::internal(format!(
                "Unexpected BYTES type in non-bytes branch for input '{}'",
                self.metadata.name
            ))),
        }
    }
}
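
// Round-trip sketches for the raw-input decoding above. Hypothetical test module,
// not part of the original file; it exercises the length-prefixed BYTES wire format
// (little-endian u32 length, then that many bytes) and the little-endian fixed-size
// decode path. Assumes `TensorMetadata.shape` is `Vec<i64>` (as the conversion code
// above implies) and that `DataType::size()` reports 2 bytes for UINT16.
#[cfg(test)]
mod raw_input_decoding_tests {
    use super::*;

    fn placeholder_tensor(data_type: tensor::DataType, shape: Vec<i64>) -> Tensor {
        Tensor {
            metadata: TensorMetadata {
                name: "test_input".to_string(),
                data_type,
                shape,
            },
            // Placeholder data, overwritten by set_data_from_raw_contents
            data: tensor::FlattenTensor::Bool(Vec::new()),
        }
    }

    #[test]
    fn decodes_length_prefixed_bytes() {
        let mut tensor = placeholder_tensor(tensor::DataType::Bytes, vec![2]);
        // Two elements: "hi" and "dynamo", each preceded by a 4-byte LE length.
        let mut raw = Vec::new();
        for s in [b"hi".as_slice(), b"dynamo".as_slice()] {
            raw.extend_from_slice(&(s.len() as u32).to_le_bytes());
            raw.extend_from_slice(s);
        }
        tensor
            .set_data_from_raw_contents(&raw)
            .expect("valid BYTES input");
        match &tensor.data {
            tensor::FlattenTensor::Bytes(items) => {
                assert_eq!(items, &vec![b"hi".to_vec(), b"dynamo".to_vec()])
            }
            _ => panic!("expected a BYTES tensor"),
        }
    }

    #[test]
    fn decodes_little_endian_uint16() {
        let mut tensor = placeholder_tensor(tensor::DataType::Uint16, vec![3]);
        // 0x0102, 0x0304, 0x0506 in little-endian byte order.
        let raw = [0x02u8, 0x01, 0x04, 0x03, 0x06, 0x05];
        tensor
            .set_data_from_raw_contents(&raw)
            .expect("valid UINT16 input");
        match &tensor.data {
            tensor::FlattenTensor::Uint16(values) => {
                assert_eq!(values, &vec![0x0102, 0x0304, 0x0506])
            }
            _ => panic!("expected a UINT16 tensor"),
        }
    }
}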

impl TryFrom<NvCreateTensorResponse> for inference::ModelInferResponse {
    type Error = anyhow::Error;

    fn try_from(response: NvCreateTensorResponse) -> Result<Self, Self::Error> {
        let mut infer_response = inference::ModelInferResponse {
            model_name: response.model,
            model_version: "1".to_string(),
            id: response.id.unwrap_or_default(),
            outputs: vec![],
            parameters: ::std::collections::HashMap::<String, inference::InferParameter>::new(),
            raw_output_contents: vec![],
        };
        for tensor in &response.tensors {
            infer_response
                .outputs
                .push(inference::model_infer_response::InferOutputTensor {
                    name: tensor.metadata.name.clone(),
                    datatype: tensor.metadata.data_type.to_string(),
                    shape: tensor.metadata.shape.clone(),
                    contents: match &tensor.data {
                        tensor::FlattenTensor::Bool(data) => Some(inference::InferTensorContents {
                            bool_contents: data.clone(),
                            ..Default::default()
                        }),
                        tensor::FlattenTensor::Uint8(data) => {
                            Some(inference::InferTensorContents {
                                uint_contents: data.iter().map(|&x| x as u32).collect(),
                                ..Default::default()
                            })
                        }
                        tensor::FlattenTensor::Uint16(data) => {
                            Some(inference::InferTensorContents {
                                uint_contents: data.iter().map(|&x| x as u32).collect(),
                                ..Default::default()
                            })
                        }
                        tensor::FlattenTensor::Uint32(data) => {
                            Some(inference::InferTensorContents {
                                uint_contents: data.clone(),
                                ..Default::default()
                            })
                        }
                        tensor::FlattenTensor::Uint64(data) => {
                            Some(inference::InferTensorContents {
                                uint64_contents: data.clone(),
                                ..Default::default()
                            })
                        }
                        tensor::FlattenTensor::Int8(data) => Some(inference::InferTensorContents {
                            int_contents: data.iter().map(|&x| x as i32).collect(),
                            ..Default::default()
                        }),
                        tensor::FlattenTensor::Int16(data) => {
                            Some(inference::InferTensorContents {
                                int_contents: data.iter().map(|&x| x as i32).collect(),
                                ..Default::default()
                            })
                        }
                        tensor::FlattenTensor::Int32(data) => {
                            Some(inference::InferTensorContents {
                                int_contents: data.clone(),
                                ..Default::default()
                            })
                        }
                        tensor::FlattenTensor::Int64(data) => {
                            Some(inference::InferTensorContents {
                                int64_contents: data.clone(),
                                ..Default::default()
                            })
                        }
                        tensor::FlattenTensor::Float32(data) => {
                            Some(inference::InferTensorContents {
                                fp32_contents: data.clone(),
                                ..Default::default()
                            })
                        }
                        tensor::FlattenTensor::Float64(data) => {
                            Some(inference::InferTensorContents {
                                fp64_contents: data.clone(),
                                ..Default::default()
                            })
                        }
                        tensor::FlattenTensor::Bytes(data) => {
                            Some(inference::InferTensorContents {
                                bytes_contents: data.clone(),
                                ..Default::default()
                            })
                        }
                    },
                    ..Default::default()
                });
        }

        Ok(infer_response)
    }
}

impl TryFrom<NvCreateTensorResponse> for inference::ModelStreamInferResponse {
    type Error = anyhow::Error;

    fn try_from(response: NvCreateTensorResponse) -> Result<Self, Self::Error> {
        match inference::ModelInferResponse::try_from(response) {
            Ok(response) => Ok(inference::ModelStreamInferResponse {
                infer_response: Some(response),
                ..Default::default()
            }),
            Err(e) => Ok(inference::ModelStreamInferResponse {
                infer_response: None,
                error_message: format!("Failed to convert response: {}", e),
            }),
        }
    }
}

impl tensor::DataType {
    pub fn to_kserve(&self) -> i32 {
        match *self {
            tensor::DataType::Bool => DataType::TypeBool as i32,
            tensor::DataType::Uint8 => DataType::TypeUint8 as i32,
            tensor::DataType::Uint16 => DataType::TypeUint16 as i32,
            tensor::DataType::Uint32 => DataType::TypeUint32 as i32,
            tensor::DataType::Uint64 => DataType::TypeUint64 as i32,
            tensor::DataType::Int8 => DataType::TypeInt8 as i32,
            tensor::DataType::Int16 => DataType::TypeInt16 as i32,
            tensor::DataType::Int32 => DataType::TypeInt32 as i32,
            tensor::DataType::Int64 => DataType::TypeInt64 as i32,
            tensor::DataType::Float32 => DataType::TypeFp32 as i32,
            tensor::DataType::Float64 => DataType::TypeFp64 as i32,
            tensor::DataType::Bytes => DataType::TypeString as i32,
        }
    }
}
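
// A small consistency sketch for the mapping above: a few representative `tensor`
// data types paired with their KServe wire-enum values. Hypothetical test module,
// not part of the original file.
#[cfg(test)]
mod kserve_datatype_mapping_tests {
    use super::*;

    #[test]
    fn maps_to_kserve_enum_values() {
        assert_eq!(
            tensor::DataType::Float32.to_kserve(),
            DataType::TypeFp32 as i32
        );
        assert_eq!(
            tensor::DataType::Int64.to_kserve(),
            DataType::TypeInt64 as i32
        );
        // BYTES maps onto the KServe STRING type.
        assert_eq!(
            tensor::DataType::Bytes.to_kserve(),
            DataType::TypeString as i32
        );
    }
}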

impl std::fmt::Display for tensor::DataType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            tensor::DataType::Bool => write!(f, "BOOL"),
            tensor::DataType::Uint8 => write!(f, "UINT8"),
            tensor::DataType::Uint16 => write!(f, "UINT16"),
            tensor::DataType::Uint32 => write!(f, "UINT32"),
            tensor::DataType::Uint64 => write!(f, "UINT64"),
            tensor::DataType::Int8 => write!(f, "INT8"),
            tensor::DataType::Int16 => write!(f, "INT16"),
            tensor::DataType::Int32 => write!(f, "INT32"),
            tensor::DataType::Int64 => write!(f, "INT64"),
            tensor::DataType::Float32 => write!(f, "FP32"),
            tensor::DataType::Float64 => write!(f, "FP64"),
            tensor::DataType::Bytes => write!(f, "BYTES"),
        }
    }
}

impl FromStr for tensor::DataType {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "BOOL" => Ok(tensor::DataType::Bool),
            "UINT8" => Ok(tensor::DataType::Uint8),
            "UINT16" => Ok(tensor::DataType::Uint16),
            "UINT32" => Ok(tensor::DataType::Uint32),
            "UINT64" => Ok(tensor::DataType::Uint64),
            "INT8" => Ok(tensor::DataType::Int8),
            "INT16" => Ok(tensor::DataType::Int16),
            "INT32" => Ok(tensor::DataType::Int32),
            "INT64" => Ok(tensor::DataType::Int64),
            "FP32" => Ok(tensor::DataType::Float32),
            "FP64" => Ok(tensor::DataType::Float64),
            "BYTES" => Ok(tensor::DataType::Bytes),
            _ => Err(anyhow::anyhow!("Invalid data type: {s}")),
        }
    }
}
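
// A round-trip sketch over the string names accepted above: `Display` and `FromStr`
// should be mutual inverses for every supported KServe data-type name. Hypothetical
// test module, not part of the original file.
#[cfg(test)]
mod datatype_string_tests {
    use super::*;

    #[test]
    fn display_and_from_str_round_trip() {
        for name in [
            "BOOL", "UINT8", "UINT16", "UINT32", "UINT64", "INT8", "INT16", "INT32", "INT64",
            "FP32", "FP64", "BYTES",
        ] {
            let parsed = tensor::DataType::from_str(name).expect("known data type");
            assert_eq!(parsed.to_string(), name);
        }
        // Unsupported names (e.g. FP16, which this enum does not model) are rejected.
        assert!(tensor::DataType::from_str("FP16").is_err());
    }
}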