1use crate::protocols::Annotated;
5use anyhow::Result;
6use dynamo_runtime::protocols::annotated::AnnotationsProvider;
7use futures::{Stream, StreamExt};
8use serde::{Deserialize, Serialize};
9use validator::Validate;
10
11pub use super::openai::nvext::{NvExt, NvExtProvider};
21
/// Element type of a tensor.
///
/// Stored in [`TensorMetadata::data_type`] and reported by
/// [`FlattenTensor::data_type`]; every variant has a same-named payload
/// variant in [`FlattenTensor`].
#[derive(Debug, Serialize, Clone, Eq, PartialEq, Deserialize)]
pub enum DataType {
    /// `bool` elements.
    Bool,
    /// `u8` elements.
    Uint8,
    /// `u16` elements.
    Uint16,
    /// `u32` elements.
    Uint32,
    /// `u64` elements.
    Uint64,
    /// `i8` elements.
    Int8,
    /// `i16` elements.
    Int16,
    /// `i32` elements.
    Int32,
    /// `i64` elements.
    Int64,
    /// `f32` elements.
    Float32,
    /// `f64` elements.
    Float64,
    /// Byte-buffer elements (`Vec<u8>` each); [`DataType::size`] reports `0`
    /// for this variant.
    Bytes,
}
37
38impl DataType {
39 pub fn size(&self) -> usize {
40 match self {
41 DataType::Bool => size_of::<bool>(),
42 DataType::Uint8 => size_of::<u8>(),
43 DataType::Uint16 => size_of::<u16>(),
44 DataType::Uint32 => size_of::<u32>(),
45 DataType::Uint64 => size_of::<u64>(),
46 DataType::Int8 => size_of::<i8>(),
47 DataType::Int16 => size_of::<i16>(),
48 DataType::Int32 => size_of::<i32>(),
49 DataType::Int64 => size_of::<i64>(),
50 DataType::Float32 => size_of::<f32>(),
51 DataType::Float64 => size_of::<f64>(),
52 DataType::Bytes => 0, }
54 }
55}
56
/// Tensor payload stored as a flat list of elements, one variant per
/// [`DataType`].
///
/// The `#[serde(tag = "data_type", content = "values")]` attribute selects
/// serde's adjacently tagged representation: the variant name is written under
/// the `data_type` key and the element list under the `values` key.
///
/// The shape is not stored here; it lives in [`TensorMetadata::shape`].
// NOTE(review): the element ordering relative to the shape (e.g. row-major)
// is not established by this file — confirm with producers/consumers.
#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)]
#[serde(tag = "data_type", content = "values")]
pub enum FlattenTensor {
    /// `bool` elements.
    Bool(Vec<bool>),
    /// `u8` elements.
    Uint8(Vec<u8>),
    /// `u16` elements.
    Uint16(Vec<u16>),
    /// `u32` elements.
    Uint32(Vec<u32>),
    /// `u64` elements.
    Uint64(Vec<u64>),
    /// `i8` elements.
    Int8(Vec<i8>),
    /// `i16` elements.
    Int16(Vec<i16>),
    /// `i32` elements.
    Int32(Vec<i32>),
    /// `i64` elements.
    Int64(Vec<i64>),
    /// `f32` elements.
    Float32(Vec<f32>),
    /// `f64` elements.
    Float64(Vec<f64>),
    /// Byte-buffer elements: one `Vec<u8>` per element, each of its own length.
    Bytes(Vec<Vec<u8>>),
}
77
78#[allow(clippy::len_without_is_empty)]
79impl FlattenTensor {
80 pub fn len(&self) -> usize {
81 match self {
82 Self::Bool(v) => v.len(),
83 Self::Uint8(v) => v.len(),
84 Self::Uint16(v) => v.len(),
85 Self::Uint32(v) => v.len(),
86 Self::Uint64(v) => v.len(),
87 Self::Int8(v) => v.len(),
88 Self::Int16(v) => v.len(),
89 Self::Int32(v) => v.len(),
90 Self::Int64(v) => v.len(),
91 Self::Float32(v) => v.len(),
92 Self::Float64(v) => v.len(),
93 Self::Bytes(v) => v.len(),
94 }
95 }
96
97 pub fn data_type(&self) -> DataType {
98 match self {
99 Self::Bool(_) => DataType::Bool,
100 Self::Uint8(_) => DataType::Uint8,
101 Self::Uint16(_) => DataType::Uint16,
102 Self::Uint32(_) => DataType::Uint32,
103 Self::Uint64(_) => DataType::Uint64,
104 Self::Int8(_) => DataType::Int8,
105 Self::Int16(_) => DataType::Int16,
106 Self::Int32(_) => DataType::Int32,
107 Self::Int64(_) => DataType::Int64,
108 Self::Float32(_) => DataType::Float32,
109 Self::Float64(_) => DataType::Float64,
110 Self::Bytes(_) => DataType::Bytes,
111 }
112 }
113}
114
/// Name, element type and shape of a tensor.
///
/// `Validate` is derived with no field-level rules declared here.
#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
pub struct TensorMetadata {
    /// Name of the tensor.
    pub name: String,
    /// Element type of the tensor.
    pub data_type: DataType,
    /// Size of each dimension. When the metadata belongs to a [`Tensor`], its
    /// `validate` rejects negative entries and requires the product of the
    /// entries to equal the number of data elements.
    // NOTE(review): whether negative entries (e.g. a `-1` wildcard) are
    // meaningful inside a `TensorModelConfig` is not shown here — confirm.
    pub shape: Vec<i64>,
}
121
/// Named collection of input and output tensor descriptions.
// NOTE(review): presumably this is the I/O signature a model advertises —
// confirm against the code that registers models.
#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
pub struct TensorModelConfig {
    /// Name associated with this configuration.
    pub name: String,
    /// Metadata of each input tensor.
    pub inputs: Vec<TensorMetadata>,
    /// Metadata of each output tensor.
    pub outputs: Vec<TensorMetadata>,
}
128
/// A tensor: its metadata together with its flattened element data.
///
/// `Validate` is implemented by hand for this type (not derived) so that the
/// metadata can be checked against the data.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Tensor {
    /// Name, declared element type and shape.
    pub metadata: TensorMetadata,
    /// Flat element payload.
    pub data: FlattenTensor,
}
134
135impl validator::Validate for Tensor {
136 fn validate(&self) -> Result<(), validator::ValidationErrors> {
137 use validator::{ValidationError, ValidationErrors};
138 let mut errs = ValidationErrors::new();
139
140 if self.metadata.data_type != self.data.data_type() {
142 let mut e = ValidationError::new("dtype_mismatch");
143 e.message = Some("metadata.data_type does not match data variant".into());
144 errs.add("data_type", e);
145 }
146
147 let mut product: usize = 1;
148 for &d in &self.metadata.shape {
149 if d < 0 {
150 let mut e = ValidationError::new("negative_dim");
151 e.message = Some("only -1 is allowed as a wildcard dimension".into());
152 errs.add("shape", e);
153 break;
154 }
155 product = product.saturating_mul(d as usize);
156 }
157 let expect_count = self.data.len();
159 if product != expect_count {
160 let mut e = ValidationError::new("element_count_mismatch");
161 e.message = Some(
162 format!(
163 "shape implies {} elements but data has {}",
164 product, expect_count
165 )
166 .into(),
167 );
168 errs.add("shape", e);
169 }
170
171 if errs.is_empty() { Ok(()) } else { Err(errs) }
172 }
173}
174
/// Request that carries a set of tensors addressed to a model.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateTensorRequest {
    /// Optional identifier of the request.
    pub id: Option<String>,

    /// Model the request is addressed to.
    pub model: String,

    /// Tensors sent with the request.
    // NOTE(review): no `#[validate(...)]` attribute is present on this field,
    // so the derived `Validate` presumably does not run `Tensor::validate` on
    // the elements — confirm whether callers validate tensors separately.
    pub tensors: Vec<Tensor>,

    /// Optional extension block; left out of the serialized form when `None`.
    /// Its `annotations` are what the `AnnotationsProvider` impl exposes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option<NvExt>,
}
189
/// Response that carries a set of tensors produced for a model.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateTensorResponse {
    /// Optional identifier of the response.
    // NOTE(review): presumably echoes the request `id` — confirm with the
    // code that builds responses.
    pub id: Option<String>,

    /// Model associated with the response.
    pub model: String,

    /// Tensors returned in the response.
    pub tensors: Vec<Tensor>,
}
203
/// Exposes the request's optional `nvext` block through [`NvExtProvider`].
impl NvExtProvider for NvCreateTensorRequest {
    /// Returns a reference to the request's `nvext` block, if one is set.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// Always returns `None`: this request type has no prompt field.
    fn raw_prompt(&self) -> Option<String> {
        None
    }
}
216
217impl AnnotationsProvider for NvCreateTensorRequest {
220 fn annotations(&self) -> Option<Vec<String>> {
222 self.nvext
223 .as_ref()
224 .and_then(|nvext| nvext.annotations.clone())
225 }
226
227 fn has_annotation(&self, annotation: &str) -> bool {
235 self.nvext
236 .as_ref()
237 .and_then(|nvext| nvext.annotations.as_ref())
238 .map(|annotations| annotations.contains(&annotation.to_string()))
239 .unwrap_or(false)
240 }
241}
242
/// Accumulator threaded through the stream fold in
/// [`NvCreateTensorResponse::from_annotated_stream`].
pub struct DeltaAggregator {
    /// First response seen on the stream, if any.
    response: Option<NvCreateTensorResponse>,
    /// First error recorded while consuming the stream, if any.
    error: Option<String>,
}
247
248impl NvCreateTensorResponse {
249 pub async fn from_annotated_stream(
250 stream: impl Stream<Item = Annotated<NvCreateTensorResponse>>,
251 ) -> Result<NvCreateTensorResponse> {
252 let aggregator = stream
253 .fold(
254 DeltaAggregator {
255 response: None,
256 error: None,
257 },
258 |mut aggregator, delta| async move {
259 let delta = match delta.ok() {
260 Ok(delta) => delta,
261 Err(error) => {
262 if aggregator.error.is_none() {
263 aggregator.error = Some(error);
264 }
265 return aggregator;
266 }
267 };
268 match delta.data {
269 Some(resp) => {
270 if aggregator.response.is_none() {
271 aggregator.response = Some(resp);
272 } else if aggregator.error.is_none() {
273 aggregator.error =
274 Some("Multiple responses in non-streaming mode".to_string());
275 }
276 }
277 None => {
278 }
280 }
281 aggregator
282 },
283 )
284 .await;
285 if let Some(error) = aggregator.error {
286 Err(anyhow::anyhow!(error))
287 } else if let Some(response) = aggregator.response {
288 Ok(response)
289 } else {
290 Err(anyhow::anyhow!("No response received"))
291 }
292 }
293}