use jacquard_common::{BosStr, CowStr, DefaultStr, FromStaticStr};
use jacquard_derive::IntoStatic;
use serde::{Deserialize, Serialize};
/// Unit type whose `Display` form is `"arrowTensor"`, the same string
/// [`ArrayFormat::ArrowTensor`] reports from `as_str`.
///
/// NOTE(review): serde's derive treats a unit struct as a unit value, so this
/// type does not (de)serialize as the string `"arrowTensor"` — confirm that is
/// the intended wire form.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, IntoStatic, Hash)]
pub struct ArrowTensor;
impl core::fmt::Display for ArrowTensor {
    /// Writes `"arrowTensor"`.
    ///
    /// `Formatter::pad` honours width, alignment and precision flags exactly as
    /// `str`'s own `Display` does. A bare `write!` would silently ignore them.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.pad("arrowTensor")
    }
}
/// Array format identifier with a fixed set of known string values.
///
/// Each known value has a dedicated variant. Any other string is carried verbatim
/// in [`ArrayFormat::Other`]. The string storage type `S` defaults to `DefaultStr`.
///
/// Prefer [`ArrayFormat::from_value`] over constructing `Other` by hand. The
/// derived `PartialEq`/`Hash` compare variants structurally, so a hand-built
/// `Other("ndarrayBytes")` is *not* equal to [`ArrayFormat::NdarrayBytes`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArrayFormat<S: BosStr = DefaultStr> {
    /// The string `"ndarrayBytes"`.
    NdarrayBytes,
    /// The string `"sparseBytes"`.
    SparseBytes,
    /// The string `"structuredBytes"`.
    StructuredBytes,
    /// The string `"arrowTensor"`.
    ArrowTensor,
    /// The string `"safetensors"`.
    Safetensors,
    /// Any string that matched none of the known values, stored as-is.
    Other(S),
}
impl<S: BosStr> ArrayFormat<S> {
    /// Returns the string form of this format.
    ///
    /// Known variants yield a fixed literal. `Other` yields the string it holds.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Other(custom) => custom.as_ref(),
            Self::NdarrayBytes => "ndarrayBytes",
            Self::SparseBytes => "sparseBytes",
            Self::StructuredBytes => "structuredBytes",
            Self::ArrowTensor => "arrowTensor",
            Self::Safetensors => "safetensors",
        }
    }

    /// Classifies `s` as one of the known formats.
    ///
    /// Matching is exact and case-sensitive. A string that matches no known
    /// value is kept unchanged in [`ArrayFormat::Other`].
    pub fn from_value(s: S) -> Self {
        let known = match s.as_ref() {
            "ndarrayBytes" => Some(Self::NdarrayBytes),
            "sparseBytes" => Some(Self::SparseBytes),
            "structuredBytes" => Some(Self::StructuredBytes),
            "arrowTensor" => Some(Self::ArrowTensor),
            "safetensors" => Some(Self::Safetensors),
            _ => None,
        };
        // The borrow from `as_ref` has ended here, so `s` can be moved into `Other`.
        known.unwrap_or_else(|| Self::Other(s))
    }
}
impl<S: BosStr> AsRef<str> for ArrayFormat<S> {
    /// Borrows the format's string form. This is identical to [`ArrayFormat::as_str`].
    fn as_ref(&self) -> &str {
        Self::as_str(self)
    }
}
impl<S: BosStr> core::fmt::Display for ArrayFormat<S> {
    /// Writes the format's string form (see [`ArrayFormat::as_str`]).
    ///
    /// `Formatter::pad` applies width, alignment and precision flags just as
    /// they apply to a plain `&str`. For example, `format!("{:>15}", fmt)`
    /// right-aligns the text. Re-formatting through `write!(f, "{}", ..)` would
    /// discard the caller's flags.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.pad(self.as_str())
    }
}
impl<S: BosStr> Serialize for ArrayFormat<S> {
    /// Emits the format as a bare string, the same text `as_str` returns.
    /// Known variants and `Other` values therefore share one representation.
    fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
    where
        Ser: serde::Serializer,
    {
        let wire = self.as_str();
        serializer.serialize_str(wire)
    }
}
impl<'de, S: Deserialize<'de> + BosStr> Deserialize<'de> for ArrayFormat<S> {
    /// Reads a value of type `S` and classifies it with `from_value`.
    ///
    /// Unrecognised strings are preserved in `Other` rather than rejected. Only
    /// errors from deserializing `S` itself are propagated.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        S::deserialize(deserializer).map(Self::from_value)
    }
}
impl<S: BosStr> jacquard_common::IntoStatic for ArrayFormat<S>
where
    S: BosStr + jacquard_common::IntoStatic,
    S::Output: BosStr,
{
    type Output = ArrayFormat<S::Output>;

    /// Converts the string payload with `IntoStatic`.
    ///
    /// Only `Other` carries string data, so it is the only variant whose payload
    /// is converted. Every unit variant maps to its same-named counterpart.
    fn into_static(self) -> Self::Output {
        match self {
            Self::Other(inner) => ArrayFormat::Other(inner.into_static()),
            Self::NdarrayBytes => ArrayFormat::NdarrayBytes,
            Self::SparseBytes => ArrayFormat::SparseBytes,
            Self::StructuredBytes => ArrayFormat::StructuredBytes,
            Self::ArrowTensor => ArrayFormat::ArrowTensor,
            Self::Safetensors => ArrayFormat::Safetensors,
        }
    }
}
/// Unit type whose `Display` form is `"ndarrayBytes"`, the same string
/// [`ArrayFormat::NdarrayBytes`] reports from `as_str`.
///
/// NOTE(review): serde's derive treats a unit struct as a unit value, so this
/// type does not (de)serialize as the string `"ndarrayBytes"` — confirm that is
/// the intended wire form.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, IntoStatic, Hash)]
pub struct NdarrayBytes;
impl core::fmt::Display for NdarrayBytes {
    /// Writes `"ndarrayBytes"`.
    ///
    /// `Formatter::pad` honours width, alignment and precision flags exactly as
    /// `str`'s own `Display` does. A bare `write!` would silently ignore them.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.pad("ndarrayBytes")
    }
}
/// Unit type whose `Display` form is `"safetensors"`, the same string
/// [`ArrayFormat::Safetensors`] reports from `as_str`.
///
/// NOTE(review): serde's derive treats a unit struct as a unit value, so this
/// type does not (de)serialize as the string `"safetensors"` — confirm that is
/// the intended wire form.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, IntoStatic, Hash)]
pub struct Safetensors;
impl core::fmt::Display for Safetensors {
    /// Writes `"safetensors"`.
    ///
    /// `Formatter::pad` honours width, alignment and precision flags exactly as
    /// `str`'s own `Display` does. A bare `write!` would silently ignore them.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.pad("safetensors")
    }
}
/// Unit type whose `Display` form is `"sparseBytes"`, the same string
/// [`ArrayFormat::SparseBytes`] reports from `as_str`.
///
/// NOTE(review): serde's derive treats a unit struct as a unit value, so this
/// type does not (de)serialize as the string `"sparseBytes"` — confirm that is
/// the intended wire form.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, IntoStatic, Hash)]
pub struct SparseBytes;
impl core::fmt::Display for SparseBytes {
    /// Writes `"sparseBytes"`.
    ///
    /// `Formatter::pad` honours width, alignment and precision flags exactly as
    /// `str`'s own `Display` does. A bare `write!` would silently ignore them.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.pad("sparseBytes")
    }
}
/// Unit type whose `Display` form is `"structuredBytes"`, the same string
/// [`ArrayFormat::StructuredBytes`] reports from `as_str`.
///
/// NOTE(review): serde's derive treats a unit struct as a unit value, so this
/// type does not (de)serialize as the string `"structuredBytes"` — confirm that
/// is the intended wire form.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, IntoStatic, Hash)]
pub struct StructuredBytes;
impl core::fmt::Display for StructuredBytes {
    /// Writes `"structuredBytes"`.
    ///
    /// `Formatter::pad` honours width, alignment and precision flags exactly as
    /// `str`'s own `Display` does. A bare `write!` would silently ignore them.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.pad("structuredBytes")
    }
}