#[cfg(feature = "arrow")]
mod arrow;
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::fmt;
use serde::{Deserialize, Serialize};
/// Kind of a module function (the `kind` field of `ModuleFunc`).
///
/// Serialized in `snake_case`: `"normal"`, `"session"`, `"session_diff"`.
/// Keep variant order stable: derived serde impls also assign each variant
/// an index that some formats encode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ModuleKind {
    /// The `Default` variant.
    #[default]
    Normal,
    // NOTE(review): what distinguishes `Session` / `SessionDiff` is not visible
    // in this file — document at the point where they are interpreted.
    Session,
    SessionDiff,
}
/// Signature of a function exposed by a module: name, optional description,
/// input and output schemas, and its `ModuleKind`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModuleFunc<'a> {
    /// Function name.
    pub name: Cow<'a, str>,
    /// Optional human-readable description; always serialized, even when `None`.
    pub description: Option<Cow<'a, str>>,
    /// Schema of the function's input.
    pub input: PyroSchema<'a>,
    /// Schema of the function's output (a full schema, unlike `CapabilityFunc::output`).
    pub output: PyroSchema<'a>,
    /// Function kind; when missing on deserialization it falls back to
    /// `ModuleKind::default()` (`Normal`).
    #[serde(default)]
    pub kind: ModuleKind,
}
/// Specification of an interface that provides a capability through a set of
/// classes, plus any named struct schemas it declares.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InterfaceSpec<'a> {
    /// Name of the capability this interface provides.
    pub capability: Cow<'a, str>,
    /// Optional description; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<Cow<'a, str>>,
    /// Classes that make up the interface.
    pub classes: Vec<ClassSpec<'a>>,
    /// Named struct schemas, kept sorted by name (`BTreeMap`). Omitted from
    /// serialized output when empty; an absent key deserializes to an empty map.
    // NOTE(review): presumably these names are referenced from method schemas — confirm.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub structs: BTreeMap<Cow<'a, str>, PyroSchema<'a>>,
}
/// Specification of one class of an interface: its methods and optional
/// client/config schemas.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClassSpec<'a> {
    /// Class name.
    pub name: Cow<'a, str>,
    /// Optional description; always serialized, even when `None`.
    pub description: Option<Cow<'a, str>>,
    /// Methods exposed by the class.
    pub methods: Vec<CapabilityFunc<'a>>,
    /// Optional client schema; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client: Option<PyroSchema<'a>>,
    /// Optional configuration schema.
    // NOTE(review): unlike `client`, `config` has no `skip_serializing_if`, so a
    // `None` is still emitted (e.g. as `null` in JSON) — confirm this asymmetry
    // is intended before consumers start depending on it.
    pub config: Option<PyroSchema<'a>>,
}
/// A method of a class: name, optional description, an input schema and a
/// single output type (a `PyroType`, not a full `PyroSchema` as in `ModuleFunc`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CapabilityFunc<'a> {
    /// Method name.
    pub name: Cow<'a, str>,
    /// Optional description; always serialized, even when `None`.
    pub description: Option<Cow<'a, str>>,
    /// Schema of the method's input.
    pub input: PyroSchema<'a>,
    /// Type of the method's return value.
    pub output: PyroType<'a>,
}
/// Data type of a value described by a Pyro schema.
///
/// Keep variant order stable: derived serde impls assign each variant an index
/// that some formats encode.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PyroType<'a> {
    /// The null type; `coerce_pyro_types` coerces it to any other type.
    Null,
    /// A single primitive value.
    PrimitiveScalar(PrimitiveDataType),
    /// A string.
    Str,
    /// A timestamp.
    // NOTE(review): unit / time zone are not specified in this file — confirm.
    Timestamp,
    /// A variable-length list of primitives.
    PrimitiveList(PrimitiveDataType),
    /// A list of exactly `len` primitives (rendered as `[T; len]`).
    PrimitiveFixedList(PrimitiveDataType, usize),
    /// A list of arbitrary element type. The `bool` looks like a nullability
    /// flag (named `_nullable` in `Display`, OR-ed during coercion).
    // NOTE(review): confirm whether the flag applies to the list or its elements.
    List(Box<PyroType<'a>>, bool),
    /// A struct-like group of named fields.
    Group(Cow<'a, [PyroField<'a>]>),
    /// A map from `key` type to `value` type.
    Map {
        key: Box<PyroType<'a>>,
        value: Box<PyroType<'a>>,
    },
}
impl<'a> fmt::Display for PyroType<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PyroType::Null => write!(f, "Null"),
PyroType::PrimitiveScalar(t) => write!(f, "{}", t),
PyroType::Str => write!(f, "Str"),
PyroType::Timestamp => write!(f, "Timestamp"),
PyroType::PrimitiveList(inner_type) => {
write!(f, "[{}]", inner_type)
}
PyroType::PrimitiveFixedList(inner_type, len) => {
write!(f, "[{}; {}]", inner_type, len)
}
PyroType::List(inner_type, _nullable) => {
write!(f, "[{}]", inner_type)
}
PyroType::Group(fields) => {
write!(f, "{{ ")?;
for (i, field) in fields.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}: {}", field.name, field.data_type)?;
}
write!(f, " }}")
}
PyroType::Map { key, value } => {
write!(f, "Map<{}, {}>", key, value)
}
}
}
}
/// Primitive element type used by scalar and primitive-list `PyroType`s.
///
/// `U*` are unsigned integers, `I*` signed integers and `F*` floating-point
/// types, each suffixed with its bit width; `coerce_primitive_types` widens
/// only within those families (plus integer + `F64`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PrimitiveDataType {
    Bool,
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
    F16,
    F32,
    F64,
}
impl fmt::Display for PrimitiveDataType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PrimitiveDataType::Bool => write!(f, "Bool"),
PrimitiveDataType::U8 => write!(f, "U8"),
PrimitiveDataType::U16 => write!(f, "U16"),
PrimitiveDataType::U32 => write!(f, "U32"),
PrimitiveDataType::U64 => write!(f, "U64"),
PrimitiveDataType::I8 => write!(f, "I8"),
PrimitiveDataType::I16 => write!(f, "I16"),
PrimitiveDataType::I32 => write!(f, "I32"),
PrimitiveDataType::I64 => write!(f, "I64"),
PrimitiveDataType::F16 => write!(f, "F16"),
PrimitiveDataType::F32 => write!(f, "F32"),
PrimitiveDataType::F64 => write!(f, "F64"),
}
}
}
impl<'a> PyroType<'a> {
pub fn into_owned(self) -> PyroType<'static> {
match self {
PyroType::Null => PyroType::Null,
PyroType::PrimitiveScalar(p) => PyroType::PrimitiveScalar(p),
PyroType::Str => PyroType::Str,
PyroType::Timestamp => PyroType::Timestamp,
PyroType::PrimitiveList(p) => PyroType::PrimitiveList(p),
PyroType::PrimitiveFixedList(p, l) => PyroType::PrimitiveFixedList(p, l),
PyroType::List(inner, n) => PyroType::List(Box::new(inner.into_owned()), n),
PyroType::Group(fields) => {
let owned_fields: Vec<PyroField<'static>> =
fields.iter().map(|f| f.clone().into_owned()).collect();
PyroType::Group(Cow::Owned(owned_fields))
}
PyroType::Map { key, value } => PyroType::Map {
key: Box::new(key.into_owned()),
value: Box::new(value.into_owned()),
},
}
}
}
/// A named, typed field of a `PyroSchema` or of a `PyroType::Group`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PyroField<'a> {
    /// Field name; `PyroSchema` lookups compare it for exact equality.
    pub name: Cow<'a, str>,
    /// Optional documentation for the field.
    pub documentation: Option<Cow<'a, str>>,
    /// Type of the field's values.
    pub data_type: PyroType<'a>,
    /// Whether the field is nullable (shown as a ` (nullable)` suffix by `Display`).
    pub nullable: bool,
}
impl<'a> PyroField<'a> {
    /// Builds an undocumented field.
    ///
    /// `name` accepts anything convertible into a `Cow<str>`, such as `&str` or `String`.
    pub fn new(name: impl Into<Cow<'a, str>>, data_type: PyroType<'a>, nullable: bool) -> Self {
        let name = name.into();
        PyroField {
            name,
            documentation: None,
            data_type,
            nullable,
        }
    }

    /// The field's name.
    #[inline]
    pub fn name(&self) -> &str {
        self.name.as_ref()
    }

    /// The field's type.
    #[inline]
    pub fn data_type(&self) -> &PyroType<'a> {
        &self.data_type
    }

    /// Whether the field is nullable.
    #[inline]
    pub fn is_nullable(&self) -> bool {
        self.nullable
    }

    /// Returns the field with its nullability replaced by `nullable`.
    pub fn with_nullable(self, nullable: bool) -> Self {
        PyroField { nullable, ..self }
    }

    /// Converts the field into one that owns all of its data (`'static`).
    ///
    /// The name and documentation are copied into owned strings, and the type
    /// is converted recursively with `PyroType::into_owned`.
    pub fn into_owned(self) -> PyroField<'static> {
        let PyroField {
            name,
            documentation,
            data_type,
            nullable,
        } = self;
        PyroField {
            name: Cow::Owned(name.into_owned()),
            documentation: documentation.map(|doc| Cow::Owned(doc.into_owned())),
            data_type: data_type.into_owned(),
            nullable,
        }
    }

    /// Returns the field with its documentation set to `doc`, replacing any
    /// previous documentation.
    pub fn add_docstring(self, doc: impl Into<Cow<'a, str>>) -> Self {
        PyroField {
            documentation: Some(doc.into()),
            ..self
        }
    }
}
/// Renders the field as `name: Type`, followed by ` (nullable)` when the
/// field is nullable, e.g. `id: I64 (nullable)`.
impl<'a> fmt::Display for PyroField<'a> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Use the type's `Display` (not `Debug`) so the output stays readable
        // (`I64` rather than `PrimitiveScalar(I64)`) and matches the
        // `name: Type` form `PyroType` uses for group members.
        write!(
            f,
            "{}: {}{}",
            self.name,
            self.data_type,
            if self.nullable { " (nullable)" } else { "" }
        )
    }
}
/// An ordered list of fields plus optional schema-level documentation.
///
/// `fields` is a `Cow`, so a schema may either borrow a field slice or own it;
/// `PyroSchema::into_owned` produces a fully owned (`'static`) schema.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PyroSchema<'a> {
    /// Optional documentation for the whole schema.
    pub documentation: Option<Cow<'a, str>>,
    /// Fields in declaration order.
    pub fields: Cow<'a, [PyroField<'a>]>,
}
impl<'a> PyroSchema<'a> {
pub fn new(fields: Vec<PyroField<'a>>) -> Self {
Self {
documentation: None,
fields: Cow::Owned(fields),
}
}
pub fn empty() -> Self {
Self {
documentation: None,
fields: Cow::Owned(Vec::new()),
}
}
#[inline]
pub fn fields(&self) -> &[PyroField<'a>] {
&self.fields
}
#[inline]
pub fn num_fields(&self) -> usize {
self.fields.len()
}
pub fn field_with_name(&self, name: &str) -> Option<&PyroField<'a>> {
self.fields.iter().find(|f| f.name == name)
}
pub fn field(&self, index: usize) -> &PyroField<'a> {
&self.fields[index]
}
pub fn index_of(&self, name: &str) -> Option<usize> {
self.fields.iter().position(|f| f.name == name)
}
pub fn into_owned(self) -> PyroSchema<'static> {
PyroSchema {
documentation: self.documentation.map(|d| Cow::Owned(d.into_owned())),
fields: self.fields.iter().map(|f| f.clone().into_owned()).collect(),
}
}
pub fn add_docstring(mut self, doc: impl Into<Cow<'a, str>>) -> Self {
self.documentation = Some(doc.into());
self
}
}
/// Multi-line rendering: a `PyroSchema {` header line, one line per field
/// (each followed by a comma), then a closing `}`.
///
/// Schema-level documentation is not printed.
impl<'a> fmt::Display for PyroSchema<'a> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("PyroSchema {\n")?;
        self.fields
            .iter()
            .try_for_each(|field| writeln!(f, " {field},"))?;
        f.write_str("}")
    }
}
impl<'a> From<Vec<PyroField<'a>>> for PyroSchema<'a> {
fn from(fields: Vec<PyroField<'a>>) -> Self {
Self::new(fields)
}
}
/// Computes a common type that both `a` and `b` can be represented as.
///
/// Identical types coerce to themselves. Otherwise:
/// - `Null` coerces to the other type.
/// - `PrimitiveScalar`, `PrimitiveList` and `PrimitiveFixedList` coerce their
///   element types with `coerce_primitive_types`. Two fixed lists of different
///   lengths, or a fixed list combined with a `PrimitiveList`, become a `PrimitiveList`.
/// - `List`s coerce their inner types recursively; the resulting flag is the
///   OR of both flags.
/// - `Group`s are merged by field name, with the result ordered by name:
///   * a field present in both groups gets the coerced type and is nullable
///     if it is nullable on either side;
///   * a field present in only one group is made nullable, because merged data
///     may lack it;
///   * the documentation of the first occurrence of a field is kept.
/// - `Map`s coerce keys and values independently.
///
/// Returns `None` when the types are incompatible, including when any nested
/// coercion fails.
pub fn coerce_pyro_types<'a>(a: &PyroType<'a>, b: &PyroType<'a>) -> Option<PyroType<'a>> {
    if a == b {
        return Some(a.clone());
    }
    use PyroType::*;
    match (a, b) {
        (Null, other) | (other, Null) => Some(other.clone()),
        (PrimitiveScalar(pa), PrimitiveScalar(pb)) => {
            coerce_primitive_types(*pa, *pb).map(PrimitiveScalar)
        }
        (List(inner_a, null_a), List(inner_b, null_b)) => {
            let merged_null = *null_a || *null_b;
            coerce_pyro_types(inner_a, inner_b).map(|c| List(Box::new(c), merged_null))
        }
        (PrimitiveList(pa), PrimitiveList(pb)) => {
            coerce_primitive_types(*pa, *pb).map(PrimitiveList)
        }
        (PrimitiveFixedList(pa, sa), PrimitiveFixedList(pb, sb)) => {
            let coerced_elem = coerce_primitive_types(*pa, *pb)?;
            if sa == sb {
                Some(PrimitiveFixedList(coerced_elem, *sa))
            } else {
                Some(PrimitiveList(coerced_elem))
            }
        }
        (PrimitiveFixedList(pa, _), PrimitiveList(pb))
        | (PrimitiveList(pa), PrimitiveFixedList(pb, _)) => {
            coerce_primitive_types(*pa, *pb).map(PrimitiveList)
        }
        (Group(fields_a), Group(fields_b)) => {
            // Per field name: the merged field, and whether the name was seen
            // in `a` and/or in `b`.
            let mut merged: BTreeMap<&str, (PyroField<'a>, bool, bool)> = BTreeMap::new();
            let tagged = fields_a
                .iter()
                .map(|f| (f, true))
                .chain(fields_b.iter().map(|f| (f, false)));
            for (f, from_a) in tagged {
                match merged.get_mut(f.name()) {
                    Some((existing, seen_a, seen_b)) => {
                        existing.data_type =
                            coerce_pyro_types(existing.data_type(), f.data_type())?;
                        existing.nullable |= f.is_nullable();
                        if existing.documentation.is_none() {
                            existing.documentation = f.documentation.clone();
                        }
                        *seen_a |= from_a;
                        *seen_b |= !from_a;
                    }
                    None => {
                        let field = PyroField {
                            name: f.name.clone(),
                            documentation: f.documentation.clone(),
                            data_type: f.data_type().clone(),
                            nullable: f.is_nullable(),
                        };
                        merged.insert(f.name(), (field, from_a, !from_a));
                    }
                }
            }
            let fields: Vec<PyroField<'a>> = merged
                .into_values()
                .map(|(mut field, seen_a, seen_b)| {
                    // A field missing from either group may be absent in merged data.
                    field.nullable |= !(seen_a && seen_b);
                    field
                })
                .collect();
            Some(Group(Cow::Owned(fields)))
        }
        (Map { key: ka, value: va }, Map { key: kb, value: vb }) => {
            let coerced_key = coerce_pyro_types(ka, kb)?;
            let coerced_val = coerce_pyro_types(va, vb)?;
            Some(Map {
                key: Box::new(coerced_key),
                value: Box::new(coerced_val),
            })
        }
        _ => None,
    }
}
/// Finds a primitive type that both `a` and `b` widen to, if any.
///
/// Besides `a == b`, the supported widenings are:
/// - two signed integers, two unsigned integers or two floats: the wider one;
/// - any integer together with `F64`: `F64`.
///
/// Anything else yields `None`. That includes `Bool` with another type, signed
/// with unsigned, and an integer with `F16`/`F32`.
///
/// NOTE(review): if `F64` is an IEEE-754 double, 64-bit integers above 2^53
/// are not exactly representable after this widening.
fn coerce_primitive_types(a: PrimitiveDataType, b: PrimitiveDataType) -> Option<PrimitiveDataType> {
    use PrimitiveDataType as P;

    // Numeric family of a primitive; widening happens only inside a family.
    #[derive(Clone, Copy, PartialEq)]
    enum Family {
        Signed,
        Unsigned,
        Float,
    }

    // Family and bit width of `t`, or `None` for `Bool` (which never widens).
    fn classify(t: P) -> Option<(Family, u8)> {
        let class = match t {
            P::Bool => return None,
            P::I8 => (Family::Signed, 8),
            P::I16 => (Family::Signed, 16),
            P::I32 => (Family::Signed, 32),
            P::I64 => (Family::Signed, 64),
            P::U8 => (Family::Unsigned, 8),
            P::U16 => (Family::Unsigned, 16),
            P::U32 => (Family::Unsigned, 32),
            P::U64 => (Family::Unsigned, 64),
            P::F16 => (Family::Float, 16),
            P::F32 => (Family::Float, 32),
            P::F64 => (Family::Float, 64),
        };
        Some(class)
    }

    if a == b {
        return Some(a);
    }
    let (family_a, bits_a) = classify(a)?;
    let (family_b, bits_b) = classify(b)?;
    if family_a == family_b {
        // Widths are distinct within a family, so the wider type holds both.
        return Some(if bits_a > bits_b { a } else { b });
    }
    // Across families, the only meeting point is an integer combined with F64.
    let integer_meets_f64 = |family: Family, other: P| family != Family::Float && other == P::F64;
    if integer_meets_f64(family_a, b) || integer_meets_f64(family_b, a) {
        Some(P::F64)
    } else {
        None
    }
}