dynamo_llm/protocols/
tensor.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::protocols::Annotated;
5use anyhow::Result;
6use dynamo_runtime::protocols::annotated::AnnotationsProvider;
7use futures::{Stream, StreamExt};
8use serde::{Deserialize, Serialize};
9use validator::Validate;
10
11// [gluo TODO] whether it makes sense to have aggregator for tensor..
12// we could if considering aggregation to be stacking the tensors by adding
13// one more dimension. i.e. stream of [2, 2] tensors to be aggregated to
14// [-1, 2, 2]. Will decide it later and currently do not allow aggregation.
15// mod aggregator;
16
17// pub use aggregator::DeltaAggregator;
18
19// [gluo TODO] nvext is LLM specific, we really only use the annotation field
20pub use super::openai::nvext::{NvExt, NvExtProvider};
21
22#[derive(Debug, Serialize, Clone, Eq, PartialEq, Deserialize)]
23pub enum DataType {
24    Bool,
25    Uint8,
26    Uint16,
27    Uint32,
28    Uint64,
29    Int8,
30    Int16,
31    Int32,
32    Int64,
33    Float32,
34    Float64,
35    Bytes,
36}
37
38impl DataType {
39    pub fn size(&self) -> usize {
40        match self {
41            DataType::Bool => size_of::<bool>(),
42            DataType::Uint8 => size_of::<u8>(),
43            DataType::Uint16 => size_of::<u16>(),
44            DataType::Uint32 => size_of::<u32>(),
45            DataType::Uint64 => size_of::<u64>(),
46            DataType::Int8 => size_of::<i8>(),
47            DataType::Int16 => size_of::<i16>(),
48            DataType::Int32 => size_of::<i32>(),
49            DataType::Int64 => size_of::<i64>(),
50            DataType::Float32 => size_of::<f32>(),
51            DataType::Float64 => size_of::<f64>(),
52            DataType::Bytes => 0, // variable length, return 0 as indicator
53        }
54    }
55}
56
57#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)]
58// Self-describing encoding removes ambiguity between signed/unsigned and width variants.
59#[serde(tag = "data_type", content = "values")]
60pub enum FlattenTensor {
61    Bool(Vec<bool>),
62    // [gluo NOTE] f16, and bf16 is not stably supported
63    Uint8(Vec<u8>),
64    Uint16(Vec<u16>),
65    Uint32(Vec<u32>),
66    Uint64(Vec<u64>),
67    Int8(Vec<i8>),
68    Int16(Vec<i16>),
69    Int32(Vec<i32>),
70    Int64(Vec<i64>),
71    Float32(Vec<f32>),
72    Float64(Vec<f64>),
73    // Typically use to store string data, but really it can store
74    // arbitrary data such as serialized handles for custom worker behavior.
75    Bytes(Vec<Vec<u8>>),
76}
77
78#[allow(clippy::len_without_is_empty)]
79impl FlattenTensor {
80    pub fn len(&self) -> usize {
81        match self {
82            Self::Bool(v) => v.len(),
83            Self::Uint8(v) => v.len(),
84            Self::Uint16(v) => v.len(),
85            Self::Uint32(v) => v.len(),
86            Self::Uint64(v) => v.len(),
87            Self::Int8(v) => v.len(),
88            Self::Int16(v) => v.len(),
89            Self::Int32(v) => v.len(),
90            Self::Int64(v) => v.len(),
91            Self::Float32(v) => v.len(),
92            Self::Float64(v) => v.len(),
93            Self::Bytes(v) => v.len(),
94        }
95    }
96
97    pub fn data_type(&self) -> DataType {
98        match self {
99            Self::Bool(_) => DataType::Bool,
100            Self::Uint8(_) => DataType::Uint8,
101            Self::Uint16(_) => DataType::Uint16,
102            Self::Uint32(_) => DataType::Uint32,
103            Self::Uint64(_) => DataType::Uint64,
104            Self::Int8(_) => DataType::Int8,
105            Self::Int16(_) => DataType::Int16,
106            Self::Int32(_) => DataType::Int32,
107            Self::Int64(_) => DataType::Int64,
108            Self::Float32(_) => DataType::Float32,
109            Self::Float64(_) => DataType::Float64,
110            Self::Bytes(_) => DataType::Bytes,
111        }
112    }
113}
114
115#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
116pub struct TensorMetadata {
117    pub name: String,
118    pub data_type: DataType,
119    pub shape: Vec<i64>,
120}
121
122#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
123pub struct TensorModelConfig {
124    pub name: String,
125    pub inputs: Vec<TensorMetadata>,
126    pub outputs: Vec<TensorMetadata>,
127}
128
129#[derive(Serialize, Deserialize, Debug, Clone)]
130pub struct Tensor {
131    pub metadata: TensorMetadata,
132    pub data: FlattenTensor,
133}
134
135impl validator::Validate for Tensor {
136    fn validate(&self) -> Result<(), validator::ValidationErrors> {
137        use validator::{ValidationError, ValidationErrors};
138        let mut errs = ValidationErrors::new();
139
140        // dtype must match
141        if self.metadata.data_type != self.data.data_type() {
142            let mut e = ValidationError::new("dtype_mismatch");
143            e.message = Some("metadata.data_type does not match data variant".into());
144            errs.add("data_type", e);
145        }
146
147        let mut product: usize = 1;
148        for &d in &self.metadata.shape {
149            if d < 0 {
150                let mut e = ValidationError::new("negative_dim");
151                e.message = Some("only -1 is allowed as a wildcard dimension".into());
152                errs.add("shape", e);
153                break;
154            }
155            product = product.saturating_mul(d as usize);
156        }
157        // bytes payloads may be variable-length per item; enforce outer count only
158        let expect_count = self.data.len();
159        if product != expect_count {
160            let mut e = ValidationError::new("element_count_mismatch");
161            e.message = Some(
162                format!(
163                    "shape implies {} elements but data has {}",
164                    product, expect_count
165                )
166                .into(),
167            );
168            errs.add("shape", e);
169        }
170
171        if errs.is_empty() { Ok(()) } else { Err(errs) }
172    }
173}
174
175#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
176pub struct NvCreateTensorRequest {
177    /// ID of the request
178    pub id: Option<String>,
179
180    /// ID of the model to use.
181    pub model: String,
182
183    /// Input tensors.
184    pub tensors: Vec<Tensor>,
185
186    #[serde(skip_serializing_if = "Option::is_none")]
187    pub nvext: Option<NvExt>,
188}
189
190/// A response structure for unary chat completion responses, embedding OpenAI's
191/// `CreateChatCompletionResponse`.
192#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
193pub struct NvCreateTensorResponse {
194    /// ID of the corresponding request.
195    pub id: Option<String>,
196
197    /// ID of the model.
198    pub model: String,
199
200    /// Output tensors.
201    pub tensors: Vec<Tensor>,
202}
203
204/// Implements `NvExtProvider` for `NvCreateTensorRequest`,
205/// providing access to NVIDIA-specific extensions.
206impl NvExtProvider for NvCreateTensorRequest {
207    fn nvext(&self) -> Option<&NvExt> {
208        self.nvext.as_ref()
209    }
210
211    fn raw_prompt(&self) -> Option<String> {
212        // Not really apply here.
213        None
214    }
215}
216
217/// Implements `AnnotationsProvider` for `NvCreateTensorRequest`,
218/// enabling retrieval and management of request annotations.
219impl AnnotationsProvider for NvCreateTensorRequest {
220    /// Retrieves the list of annotations from `NvExt`, if present.
221    fn annotations(&self) -> Option<Vec<String>> {
222        self.nvext
223            .as_ref()
224            .and_then(|nvext| nvext.annotations.clone())
225    }
226
227    /// Checks whether a specific annotation exists in the request.
228    ///
229    /// # Arguments
230    /// * `annotation` - A string slice representing the annotation to check.
231    ///
232    /// # Returns
233    /// `true` if the annotation exists, `false` otherwise.
234    fn has_annotation(&self, annotation: &str) -> bool {
235        self.nvext
236            .as_ref()
237            .and_then(|nvext| nvext.annotations.as_ref())
238            .map(|annotations| annotations.contains(&annotation.to_string()))
239            .unwrap_or(false)
240    }
241}
242
243pub struct DeltaAggregator {
244    response: Option<NvCreateTensorResponse>,
245    error: Option<String>,
246}
247
248impl NvCreateTensorResponse {
249    pub async fn from_annotated_stream(
250        stream: impl Stream<Item = Annotated<NvCreateTensorResponse>>,
251    ) -> Result<NvCreateTensorResponse> {
252        let aggregator = stream
253            .fold(
254                DeltaAggregator {
255                    response: None,
256                    error: None,
257                },
258                |mut aggregator, delta| async move {
259                    let delta = match delta.ok() {
260                        Ok(delta) => delta,
261                        Err(error) => {
262                            if aggregator.error.is_none() {
263                                aggregator.error = Some(error);
264                            }
265                            return aggregator;
266                        }
267                    };
268                    match delta.data {
269                        Some(resp) => {
270                            if aggregator.response.is_none() {
271                                aggregator.response = Some(resp);
272                            } else if aggregator.error.is_none() {
273                                aggregator.error =
274                                    Some("Multiple responses in non-streaming mode".to_string());
275                            }
276                        }
277                        None => {
278                            // Ignore metadata-only deltas in non-streaming mode.
279                        }
280                    }
281                    aggregator
282                },
283            )
284            .await;
285        if let Some(error) = aggregator.error {
286            Err(anyhow::anyhow!(error))
287        } else if let Some(response) = aggregator.response {
288            Ok(response)
289        } else {
290            Err(anyhow::anyhow!("No response received"))
291        }
292    }
293}