use anyhow::Result;
use base64::{Engine as _, engine::general_purpose};
use dynamo_memory::SystemStorage;
use dynamo_memory::nixl::{self, NixlAgent, NixlDescriptor, RegisteredView};
use flate2::{Compression, write::ZlibEncoder};
use ndarray::{ArrayBase, Dimension, OwnedRepr};
use serde::{Deserialize, Serialize};
use std::io::Write;
use std::sync::Arc;
use super::decoders::DecodedMediaMetadata;
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum DataType {
UINT8,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct MediaTensorInfo {
pub(crate) shape: Vec<usize>,
pub(crate) dtype: DataType,
pub(crate) metadata: Option<DecodedMediaMetadata>,
}
#[derive(Debug)]
pub struct DecodedMediaData {
pub(crate) data: SystemStorage,
pub(crate) tensor_info: MediaTensorInfo,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct RdmaMediaDataDescriptor {
pub(crate) nixl_metadata: String,
pub(crate) nixl_descriptor: NixlDescriptor,
#[serde(flatten)]
pub(crate) tensor_info: MediaTensorInfo,
#[serde(skip, default)]
#[allow(dead_code)]
pub(crate) source_storage: Option<Arc<nixl::NixlRegistered<SystemStorage>>>,
}
impl DecodedMediaData {
pub fn into_rdma_descriptor(self, nixl_agent: &NixlAgent) -> Result<RdmaMediaDataDescriptor> {
let source_storage = self.data;
let registered = nixl::register_with_nixl(source_storage, nixl_agent, None)
.map_err(|_| anyhow::anyhow!("Failed to register storage with NIXL"))?;
let nixl_descriptor = registered.descriptor();
let nixl_metadata = get_nixl_metadata(nixl_agent, registered.storage())?;
Ok(RdmaMediaDataDescriptor {
nixl_metadata,
nixl_descriptor,
tensor_info: self.tensor_info,
source_storage: Some(Arc::new(registered)),
})
}
}
impl<D: Dimension> TryFrom<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
type Error = anyhow::Error;
fn try_from(array: ArrayBase<OwnedRepr<u8>, D>) -> Result<Self, Self::Error> {
let shape = array.shape().to_vec();
let (data_vec, _) = array.into_raw_vec_and_offset();
let mut storage = SystemStorage::new(data_vec.len())?;
unsafe {
std::ptr::copy_nonoverlapping(data_vec.as_ptr(), storage.as_mut_ptr(), data_vec.len());
}
Ok(Self {
data: storage,
tensor_info: MediaTensorInfo {
shape,
dtype: DataType::UINT8,
metadata: None,
},
})
}
}
pub fn get_nixl_metadata(agent: &NixlAgent, _storage: &SystemStorage) -> Result<String> {
let nixl_md = agent.raw_agent().get_local_md()?;
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::new(6));
encoder.write_all(&nixl_md)?;
let compressed = encoder.finish()?;
let b64_encoded = general_purpose::STANDARD.encode(&compressed);
Ok(format!("b64:{}", b64_encoded))
}
pub fn get_nixl_agent() -> Result<NixlAgent> {
let name = format!("media-loader-{}", uuid::Uuid::new_v4());
let nixl_agent = NixlAgent::with_backends(&name, &["UCX"])?;
Ok(nixl_agent)
}