use crate::protocols::Annotated;
use anyhow::Result;
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use validator::Validate;
pub use super::openai::nvext::{NvExt, NvExtProvider};
/// Element type tag for tensor data.
///
/// Every variant here has a same-named payload variant in `FlattenTensor`,
/// and `FlattenTensor::data_type` maps a payload back to this tag.
#[derive(Debug, Serialize, Clone, Eq, PartialEq, Deserialize)]
pub enum DataType {
Bool,
Uint8,
Uint16,
Uint32,
Uint64,
Int8,
Int16,
Int32,
Int64,
Float32,
Float64,
// Paired with `FlattenTensor::Bytes`, whose elements are `Vec<u8>` values.
Bytes,
}
impl DataType {
pub fn size(&self) -> usize {
match self {
DataType::Bool => size_of::<bool>(),
DataType::Uint8 => size_of::<u8>(),
DataType::Uint16 => size_of::<u16>(),
DataType::Uint32 => size_of::<u32>(),
DataType::Uint64 => size_of::<u64>(),
DataType::Int8 => size_of::<i8>(),
DataType::Int16 => size_of::<i16>(),
DataType::Int32 => size_of::<i32>(),
DataType::Int64 => size_of::<i64>(),
DataType::Float32 => size_of::<f32>(),
DataType::Float64 => size_of::<f64>(),
DataType::Bytes => 0, }
}
}
/// Tensor contents stored as a single flat vector, one variant per `DataType`.
///
/// Serialized by serde in adjacently tagged form: the variant name is written
/// under the `data_type` key and the element vector under the `values` key.
#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)]
#[serde(tag = "data_type", content = "values")]
pub enum FlattenTensor {
Bool(Vec<bool>),
Uint8(Vec<u8>),
Uint16(Vec<u16>),
Uint32(Vec<u32>),
Uint64(Vec<u64>),
Int8(Vec<i8>),
Int16(Vec<i16>),
Int32(Vec<i32>),
Int64(Vec<i64>),
Float32(Vec<f32>),
Float64(Vec<f64>),
// Each element is itself a byte vector, so elements may differ in length.
Bytes(Vec<Vec<u8>>),
}
impl FlattenTensor {
    /// Returns the number of elements held by the tensor.
    ///
    /// For `Bytes` this is the number of byte-string elements (the length of
    /// the outer vector), not the total number of bytes.
    pub fn len(&self) -> usize {
        match self {
            Self::Bool(v) => v.len(),
            Self::Uint8(v) => v.len(),
            Self::Uint16(v) => v.len(),
            Self::Uint32(v) => v.len(),
            Self::Uint64(v) => v.len(),
            Self::Int8(v) => v.len(),
            Self::Int16(v) => v.len(),
            Self::Int32(v) => v.len(),
            Self::Int64(v) => v.len(),
            Self::Float32(v) => v.len(),
            Self::Float64(v) => v.len(),
            Self::Bytes(v) => v.len(),
        }
    }

    /// Returns `true` when the tensor holds no elements.
    ///
    /// Provided alongside [`len`](Self::len) so the type follows the standard
    /// `len`/`is_empty` pairing instead of suppressing the Clippy lint.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the `DataType` tag that corresponds to the stored variant.
    pub fn data_type(&self) -> DataType {
        match self {
            Self::Bool(_) => DataType::Bool,
            Self::Uint8(_) => DataType::Uint8,
            Self::Uint16(_) => DataType::Uint16,
            Self::Uint32(_) => DataType::Uint32,
            Self::Uint64(_) => DataType::Uint64,
            Self::Int8(_) => DataType::Int8,
            Self::Int16(_) => DataType::Int16,
            Self::Int32(_) => DataType::Int32,
            Self::Int64(_) => DataType::Int64,
            Self::Float32(_) => DataType::Float32,
            Self::Float64(_) => DataType::Float64,
            Self::Bytes(_) => DataType::Bytes,
        }
    }
}
/// Descriptive information about a tensor: its name, element type, shape and
/// free-form parameters.
///
/// Deserialization rejects unknown fields. `Validate` is derived, but no
/// field-level validation rules are declared on this struct.
#[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct TensorMetadata {
// Tensor name.
pub name: String,
// Element type of the tensor.
pub data_type: DataType,
// Dimension sizes. The type is signed; `Tensor`'s `Validate` impl rejects
// negative entries for tensors that carry data.
pub shape: Vec<i64>,
// Omitted from serialized output when empty; defaults to an empty map when
// the field is absent on input.
#[serde(skip_serializing_if = "HashMap::is_empty", default)]
pub parameters: Parameters,
}
/// Describes a model by name together with the metadata of its input and
/// output tensors.
///
/// Deserialization rejects unknown fields.
#[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq, Default)]
#[serde(deny_unknown_fields)]
pub struct TensorModelConfig {
// Model name.
pub name: String,
// Metadata for each input tensor.
pub inputs: Vec<TensorMetadata>,
// Metadata for each output tensor.
pub outputs: Vec<TensorMetadata>,
// Optional raw bytes; omitted from serialized output when `None` and
// defaults to `None` when absent on input.
// NOTE(review): presumably an encoded Triton model configuration — confirm
// the exact encoding with the producer of this field.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub triton_model_config: Option<Vec<u8>>,
}
/// A tensor: descriptive metadata paired with its flattened element data.
///
/// Deserialization rejects unknown fields. Agreement between `metadata` and
/// `data` (element type and element count) is not enforced by construction;
/// it is checked by this type's manual `validator::Validate` implementation.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct Tensor {
// Name, declared element type, shape and parameters of the tensor.
pub metadata: TensorMetadata,
// Flattened element values.
pub data: FlattenTensor,
}
impl validator::Validate for Tensor {
    /// Checks that the tensor's metadata agrees with its data.
    ///
    /// The following problems are reported (several may be reported at once):
    ///
    /// * `dtype_mismatch` (field `data_type`): `metadata.data_type` differs
    ///   from the variant stored in `data`.
    /// * `negative_dim` (field `shape`): the shape contains a negative entry.
    /// * `element_count_mismatch` (field `shape`): the product of the shape's
    ///   dimensions differs from `data.len()`. This is only evaluated when
    ///   every dimension is non-negative, because the product is meaningless
    ///   otherwise.
    ///
    /// An empty shape has a product of `1`, so it requires exactly one element.
    ///
    /// # Errors
    ///
    /// Returns the collected `ValidationErrors` when at least one check fails.
    fn validate(&self) -> Result<(), validator::ValidationErrors> {
        use validator::{ValidationError, ValidationErrors};

        let mut errs = ValidationErrors::new();

        if self.metadata.data_type != self.data.data_type() {
            let mut e = ValidationError::new("dtype_mismatch");
            e.message = Some("metadata.data_type does not match data variant".into());
            errs.add("data_type", e);
        }

        let shape = &self.metadata.shape;
        if shape.iter().any(|&d| d < 0) {
            // NOTE(review): the message mentions -1 as an allowed wildcard, but
            // this check rejects every negative value, -1 included. Confirm
            // which is intended before changing either the text or the check.
            let mut e = ValidationError::new("negative_dim");
            e.message = Some("only -1 is allowed as a wildcard dimension".into());
            errs.add("shape", e);
        } else {
            // Saturate instead of overflowing: a saturated product cannot match
            // a real element count, so an oversized shape is reported as a
            // mismatch. `try_from` keeps the conversion lossless on targets
            // where `usize` is narrower than `i64`.
            let product = shape.iter().fold(1usize, |acc, &d| {
                acc.saturating_mul(usize::try_from(d).unwrap_or(usize::MAX))
            });
            let expect_count = self.data.len();
            if product != expect_count {
                let mut e = ValidationError::new("element_count_mismatch");
                e.message = Some(
                    format!(
                        "shape implies {} elements but data has {}",
                        product, expect_count
                    )
                    .into(),
                );
                errs.add("shape", e);
            }
        }

        if errs.is_empty() { Ok(()) } else { Err(errs) }
    }
}
/// Request carrying a set of tensors addressed to a named model.
///
/// Deserialization rejects unknown fields. The derived `Validate` impl
/// validates each entry of `tensors`.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct NvCreateTensorRequest {
// Optional request identifier.
pub id: Option<String>,
// Name of the target model.
pub model: String,
// Tensors carried by the request; each one is validated via `Tensor`'s
// `Validate` impl.
#[validate(nested)]
pub tensors: Vec<Tensor>,
// Omitted from serialized output when empty; defaults to an empty map when
// the field is absent on input.
#[serde(skip_serializing_if = "HashMap::is_empty", default)]
pub parameters: Parameters,
// Optional extension block; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>,
}
/// Response carrying a set of tensors produced for a named model.
///
/// Deserialization rejects unknown fields. The derived `Validate` impl
/// validates each entry of `tensors`.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct NvCreateTensorResponse {
// Optional identifier.
pub id: Option<String>,
// Name of the model.
pub model: String,
// Tensors carried by the response; each one is validated via `Tensor`'s
// `Validate` impl.
#[validate(nested)]
pub tensors: Vec<Tensor>,
// Omitted from serialized output when empty; defaults to an empty map when
// the field is absent on input.
#[serde(skip_serializing_if = "HashMap::is_empty", default)]
pub parameters: Parameters,
}
impl NvExtProvider for NvCreateTensorRequest {
    /// Borrows the request's extension block, if one was supplied.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// Always `None`: this request type has no prompt field to return.
    fn raw_prompt(&self) -> Option<String> {
        None
    }
}
impl AnnotationsProvider for NvCreateTensorRequest {
    /// Returns a copy of the annotations listed in `nvext`, or `None` when the
    /// request has no `nvext` block or the block carries no annotations.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.clone())
    }

    /// Returns `true` when `annotation` is present in `nvext.annotations`.
    ///
    /// A request without `nvext`, or without annotations, yields `false`.
    fn has_annotation(&self, annotation: &str) -> bool {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_ref())
            // Compare in place rather than allocating a `String` per lookup.
            .is_some_and(|annotations| annotations.iter().any(|a| a == annotation))
    }
}
/// Accumulator used by `NvCreateTensorResponse::from_annotated_stream` while
/// folding over a stream of annotated responses.
pub struct DeltaAggregator {
// The first response payload seen on the stream, if any.
response: Option<NvCreateTensorResponse>,
// The first error recorded while folding, if any.
error: Option<String>,
}
impl NvCreateTensorResponse {
pub async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvCreateTensorResponse>>,
) -> Result<NvCreateTensorResponse> {
let aggregator = stream
.fold(
DeltaAggregator {
response: None,
error: None,
},
|mut aggregator, delta| async move {
let delta = match delta.ok() {
Ok(delta) => delta,
Err(error) => {
if aggregator.error.is_none() {
aggregator.error = Some(error);
}
return aggregator;
}
};
match delta.data {
Some(resp) => {
if aggregator.response.is_none() {
aggregator.response = Some(resp);
} else if aggregator.error.is_none() {
aggregator.error =
Some("Multiple responses in non-streaming mode".to_string());
}
}
None => {
}
}
aggregator
},
)
.await;
if let Some(error) = aggregator.error {
Err(anyhow::anyhow!(error))
} else if let Some(response) = aggregator.response {
Ok(response)
} else {
Err(anyhow::anyhow!("No response received"))
}
}
}
/// A single parameter value: a boolean, a signed or unsigned 64-bit integer,
/// a string, or a 64-bit float.
///
/// Variant names are serialized in `snake_case`.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ParameterValue {
Bool(bool),
Int64(i64),
String(String),
Double(f64),
Uint64(u64),
}
/// Map from parameter name to its value.
pub type Parameters = HashMap<String, ParameterValue>;