Skip to main content

pyro_spec/
lib.rs

1//! Pyro-native type system and schema representation.
2//!
3//! This crate provides a lightweight schema representation optimized for
4//! `PyroValue`. Unlike Arrow's `DataType` (which has ~40 variants for timestamps,
5//! decimals, run-end-encoded, etc.), `PyroType` mirrors *exactly* the variants that
6//! `PyroValue` can represent, making match arms exhaustive and tiny.
7//!
8//! Core types:
9//! - [`PyroType`] — the main type enum, mirroring `PyroValue` discriminants 1:1
10//! - [`PyroField`] — a named, nullable column descriptor (equivalent to Arrow `Field`)
11//! - [`PyroSchema`] — an ordered collection of `PyroField`s (equivalent to Arrow `Schema`)
12//! - [`coerce_pyro_types`] — type coercion to find common supertypes
13//!
14//! Conversion to/from `arrow::datatypes::DataType` lives in the `arrow` module
15//! (behind the `arrow` feature flag).
16
17// =============================================================================
18// Pyro-native type system
19// =============================================================================
20//
21// A lightweight schema representation optimized for PyroValue.
22// Unlike Arrow's DataType (which has ~40 variants for timestamps, decimals,
23// run-end-encoded, etc.), PyroType mirrors *exactly* the variants that
24// PyroValue can represent, making match arms exhaustive and tiny.
25//
// Conversion to/from `arrow::datatypes::DataType` lives in the `arrow` module
// (behind the `arrow` feature flag).
27
28#[cfg(feature = "arrow")]
29mod arrow;
30
31use std::borrow::Cow;
32use std::collections::BTreeMap;
33use std::fmt;
34
35use serde::{Deserialize, Serialize};
36
37/// Documentation for the main function of a module
38#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
39pub struct ModuleFunc<'a> {
40    pub name: Cow<'a, str>,
41    pub description: Option<Cow<'a, str>>,
42    pub input: PyroSchema<'a>,
43    pub output: PyroSchema<'a>,
44}
45
46/// The root specification object.
47#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
48pub struct InterfaceSpec<'a> {
49    pub capability: Cow<'a, str>,
50
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub description: Option<Cow<'a, str>>,
53
54    pub classes: Vec<ClassSpec<'a>>,
55}
56
57#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
58pub struct ClassSpec<'a> {
59    pub name: Cow<'a, str>,
60    pub description: Option<Cow<'a, str>>,
61    pub methods: Vec<CapabilityFunc<'a>>,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub client: Option<PyroSchema<'a>>,
64    pub config: Option<PyroSchema<'a>>,
65}
66
67/// Documentation for a capability function
68#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
69pub struct CapabilityFunc<'a> {
70    pub name: Cow<'a, str>,
71    pub description: Option<Cow<'a, str>>,
72    pub input: PyroSchema<'a>,
73    pub output: PyroType<'a>,
74}
75
76// =============================================================================
77// PyroType
78// =============================================================================
79
80/// A data type enum that mirrors the variants of [`PyroValue`] exactly.
81///
82/// This is intentionally much smaller than `arrow::datatypes::DataType`.
83/// Every variant here has a 1:1 correspondence with a `PyroValue` discriminant.
84#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
85pub enum PyroType<'a> {
86    /// No value / unknown type (corresponds to `PyroValue::Null`).
87    Null,
88    /// Scalar primitive (Bool, Int, Float).
89    PrimitiveScalar(PrimitiveDataType),
90    /// UTF-8 string (corresponds to `PyroValue::Str`).
91    Str,
92    /// Day + millisecond interval (corresponds to `PyroValue::Timestamp`).
93    Timestamp,
94    /// Homogeneous list of a single primitive type (corresponds to `PyroValue::PrimitiveList`).
95    PrimitiveList(PrimitiveDataType),
96    /// Fixed-size homogeneous list of a single primitive type.
97    PrimitiveFixedList(PrimitiveDataType, usize),
98    /// Heterogeneous list of arbitrary pyro values (corresponds to `PyroValue::List`).
99    ///
100    /// Fields: `(element_type, element_nullable)`.
101    List(Box<PyroType<'a>>, bool),
102    /// Named struct / row (corresponds to `PyroValue::Group`).
103    Group(Cow<'a, [PyroField<'a>]>),
104    /// Key-value map (corresponds to `PyroValue::MapInternal`).
105    Map {
106        key: Box<PyroType<'a>>,
107        value: Box<PyroType<'a>>,
108    },
109}
110
111impl<'a> fmt::Display for PyroType<'a> {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        match self {
114            // Primitives
115            PyroType::Null => write!(f, "Null"),
116            PyroType::PrimitiveScalar(t) => write!(f, "{}", t),
117            PyroType::Str => write!(f, "Str"),
118            PyroType::Timestamp => write!(f, "Timestamp"),
119
120            // Complex Types
121            PyroType::PrimitiveList(inner_type) => {
122                write!(f, "[{}]", inner_type)
123            }
124            PyroType::PrimitiveFixedList(inner_type, len) => {
125                write!(f, "[{}; {}]", inner_type, len)
126            }
127            PyroType::List(inner_type, _nullable) => {
128                write!(f, "[{}]", inner_type)
129            }
130            PyroType::Group(fields) => {
131                write!(f, "{{ ")?;
132                for (i, field) in fields.iter().enumerate() {
133                    if i > 0 {
134                        write!(f, ", ")?;
135                    }
136                    write!(f, "{}: {}", field.name, field.data_type)?;
137                }
138                write!(f, " }}")
139            }
140            PyroType::Map { key, value } => {
141                write!(f, "Map<{}, {}>", key, value)
142            }
143        }
144    }
145}
146
147/// The primitive element type inside a `PrimitiveValueList`.
148#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
149pub enum PrimitiveDataType {
150    Bool,
151    U8,
152    U16,
153    U32,
154    U64,
155    I8,
156    I16,
157    I32,
158    I64,
159    F16,
160    F32,
161    F64,
162}
163
164impl fmt::Display for PrimitiveDataType {
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        match self {
167            PrimitiveDataType::Bool => write!(f, "Bool"),
168            PrimitiveDataType::U8 => write!(f, "U8"),
169            PrimitiveDataType::U16 => write!(f, "U16"),
170            PrimitiveDataType::U32 => write!(f, "U32"),
171            PrimitiveDataType::U64 => write!(f, "U64"),
172            PrimitiveDataType::I8 => write!(f, "I8"),
173            PrimitiveDataType::I16 => write!(f, "I16"),
174            PrimitiveDataType::I32 => write!(f, "I32"),
175            PrimitiveDataType::I64 => write!(f, "I64"),
176            PrimitiveDataType::F16 => write!(f, "F16"),
177            PrimitiveDataType::F32 => write!(f, "F32"),
178            PrimitiveDataType::F64 => write!(f, "F64"),
179        }
180    }
181}
182
183impl<'a> PyroType<'a> {
184    pub fn into_owned(self) -> PyroType<'static> {
185        match self {
186            PyroType::Null => PyroType::Null,
187            PyroType::PrimitiveScalar(p) => PyroType::PrimitiveScalar(p),
188            PyroType::Str => PyroType::Str,
189            PyroType::Timestamp => PyroType::Timestamp,
190            PyroType::PrimitiveList(p) => PyroType::PrimitiveList(p),
191            PyroType::PrimitiveFixedList(p, l) => PyroType::PrimitiveFixedList(p, l),
192            PyroType::List(inner, n) => PyroType::List(Box::new(inner.into_owned()), n),
193            PyroType::Group(fields) => {
194                let owned_fields: Vec<PyroField<'static>> =
195                    fields.iter().map(|f| f.clone().into_owned()).collect();
196                PyroType::Group(Cow::Owned(owned_fields))
197            }
198            PyroType::Map { key, value } => PyroType::Map {
199                key: Box::new(key.into_owned()),
200                value: Box::new(value.into_owned()),
201            },
202        }
203    }
204}
205
206// =============================================================================
207// PyroField
208// =============================================================================
209
210/// A named, nullable column descriptor — the Pyro equivalent of `arrow::datatypes::Field`.
211#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
212pub struct PyroField<'a> {
213    pub name: Cow<'a, str>,
214    pub documentation: Option<Cow<'a, str>>,
215    pub data_type: PyroType<'a>,
216    pub nullable: bool,
217}
218
219impl<'a> PyroField<'a> {
220    /// Create a new field.
221    /// Accepts `&'static str`, `String`, or `Cow<'a, str>`.
222    pub fn new(name: impl Into<Cow<'a, str>>, data_type: PyroType<'a>, nullable: bool) -> Self {
223        Self {
224            name: name.into(),
225            documentation: None,
226            data_type,
227            nullable,
228        }
229    }
230
231    #[inline]
232    pub fn name(&self) -> &str {
233        &self.name
234    }
235
236    #[inline]
237    pub fn data_type(&self) -> &PyroType<'a> {
238        &self.data_type
239    }
240
241    #[inline]
242    pub fn is_nullable(&self) -> bool {
243        self.nullable
244    }
245
246    pub fn with_nullable(mut self, nullable: bool) -> Self {
247        self.nullable = nullable;
248        self
249    }
250
251    /// Convert to an owned version (PyroField<'static>) by cloning data.
252    pub fn into_owned(self) -> PyroField<'static> {
253        PyroField {
254            name: Cow::Owned(self.name.into_owned()),
255            documentation: self.documentation.map(|d| Cow::Owned(d.into_owned())),
256            data_type: self.data_type.into_owned(),
257            nullable: self.nullable,
258        }
259    }
260
261    pub fn add_docstring(mut self, doc: impl Into<Cow<'a, str>>) -> Self {
262        self.documentation = Some(doc.into());
263        self
264    }
265}
266
267impl<'a> fmt::Display for PyroField<'a> {
268    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269        write!(
270            f,
271            "{}: {:?}{}",
272            self.name,
273            self.data_type,
274            if self.nullable { " (nullable)" } else { "" }
275        )
276    }
277}
278
279// =============================================================================
280// PyroSchema
281// =============================================================================
282
283/// An ordered collection of [`PyroField`]s — the Pyro equivalent of `arrow::datatypes::Schema`.
284#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
285pub struct PyroSchema<'a> {
286    pub documentation: Option<Cow<'a, str>>,
287    pub fields: Cow<'a, [PyroField<'a>]>,
288}
289
290impl<'a> PyroSchema<'a> {
291    pub fn new(fields: Vec<PyroField<'a>>) -> Self {
292        Self {
293            documentation: None,
294            fields: Cow::Owned(fields),
295        }
296    }
297
298    pub fn empty() -> Self {
299        Self {
300            documentation: None,
301            fields: Cow::Owned(Vec::new()),
302        }
303    }
304
305    #[inline]
306    pub fn fields(&self) -> &[PyroField<'a>] {
307        &self.fields
308    }
309
310    #[inline]
311    pub fn num_fields(&self) -> usize {
312        self.fields.len()
313    }
314
315    /// Look up a field by name (linear scan).
316    pub fn field_with_name(&self, name: &str) -> Option<&PyroField<'a>> {
317        self.fields.iter().find(|f| f.name == name)
318    }
319
320    /// Get a field by index.
321    pub fn field(&self, index: usize) -> &PyroField<'a> {
322        &self.fields[index]
323    }
324
325    /// Returns column index for the given name, if present.
326    pub fn index_of(&self, name: &str) -> Option<usize> {
327        self.fields.iter().position(|f| f.name == name)
328    }
329
330    /// Convert to an fully owned schema (useful for inference results).
331    pub fn into_owned(self) -> PyroSchema<'static> {
332        PyroSchema {
333            documentation: None,
334            fields: self
335                .fields
336                .into_iter()
337                .map(|f| f.clone().into_owned())
338                .collect(),
339        }
340    }
341
342    pub fn add_docstring(mut self, doc: impl Into<Cow<'a, str>>) -> Self {
343        self.documentation = Some(doc.into());
344        self
345    }
346}
347
348impl<'a> fmt::Display for PyroSchema<'a> {
349    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350        writeln!(f, "PyroSchema {{")?;
351        for field in self.fields.iter() {
352            writeln!(f, "  {field},")?;
353        }
354        write!(f, "}}")
355    }
356}
357
358impl<'a> From<Vec<PyroField<'a>>> for PyroSchema<'a> {
359    fn from(fields: Vec<PyroField<'a>>) -> Self {
360        Self::new(fields)
361    }
362}
363
364/// Coerce two [`PyroType`]s to a common supertype. Returns `None` if incompatible.
365pub fn coerce_pyro_types<'a>(a: &PyroType<'a>, b: &PyroType<'a>) -> Option<PyroType<'a>> {
366    if a == b {
367        return Some(a.clone());
368    }
369
370    use PyroType::*;
371
372    match (a, b) {
373        // Null widens to anything
374        (Null, other) | (other, Null) => Some(other.clone()),
375
376        // --- Primitive Scalar coercion ---
377        (PrimitiveScalar(pa), PrimitiveScalar(pb)) => {
378            coerce_primitive_types(*pa, *pb).map(PrimitiveScalar)
379        }
380
381        // --- List coercion (merge nullability) ---
382        (List(inner_a, null_a), List(inner_b, null_b)) => {
383            let merged_null = *null_a || *null_b;
384            coerce_pyro_types(&inner_a, &inner_b).map(|c| List(Box::new(c), merged_null))
385        }
386
387        // --- PrimitiveList coercion ---
388        (PrimitiveList(pa), PrimitiveList(pb)) => {
389            coerce_primitive_types(*pa, *pb).map(PrimitiveList)
390        }
391
392        // --- PrimitiveFixedList coercion ---
393        // Same size + coercible element type → PrimitiveFixedList
394        // Different size → promote to PrimitiveList
395        (PrimitiveFixedList(pa, sa), PrimitiveFixedList(pb, sb)) => {
396            let coerced_elem = coerce_primitive_types(*pa, *pb)?;
397            if sa == sb {
398                Some(PrimitiveFixedList(coerced_elem, *sa))
399            } else {
400                Some(PrimitiveList(coerced_elem))
401            }
402        }
403
404        // PrimitiveFixedList + PrimitiveList → PrimitiveList
405        (PrimitiveFixedList(pa, _), PrimitiveList(pb))
406        | (PrimitiveList(pa), PrimitiveFixedList(pb, _)) => {
407            coerce_primitive_types(*pa, *pb).map(PrimitiveList)
408        }
409
410        // --- Group (struct) coercion: merge fields ---
411        (Group(fields_a), Group(fields_b)) => {
412            let mut merged_map: BTreeMap<String, PyroField> = BTreeMap::new();
413
414            for f in fields_a.iter().chain(fields_b.iter()) {
415                match merged_map.get(f.name()) {
416                    None => {
417                        // Field only in one side so far — mark nullable since the other side lacks it
418                        merged_map.insert(
419                            f.name().to_string(),
420                            PyroField::new(
421                                Cow::Owned(f.name().to_string()),
422                                f.data_type().clone(),
423                                true,
424                            ),
425                        );
426                    }
427                    Some(existing) => {
428                        let coerced = coerce_pyro_types(existing.data_type(), f.data_type())?;
429                        let nullable = existing.is_nullable() || f.is_nullable();
430                        merged_map.insert(
431                            f.name().to_string(),
432                            PyroField::new(Cow::Owned(f.name().to_string()), coerced, nullable),
433                        );
434                    }
435                }
436            }
437
438            Some(Group(Cow::Owned(merged_map.into_values().collect())))
439        }
440
441        // --- Map Coercion ---
442        (Map { key: ka, value: va }, Map { key: kb, value: vb }) => {
443            let coerced_key = coerce_pyro_types(ka, kb)?;
444            let coerced_val = coerce_pyro_types(va, vb)?;
445            Some(Map {
446                key: Box::new(coerced_key),
447                value: Box::new(coerced_val),
448            })
449        }
450
451        _ => None,
452    }
453}
454
455fn coerce_primitive_types(a: PrimitiveDataType, b: PrimitiveDataType) -> Option<PrimitiveDataType> {
456    if a == b {
457        return Some(a);
458    }
459
460    use PrimitiveDataType as P;
461
462    match (a, b) {
463        (P::I8, P::I16) | (P::I16, P::I8) => Some(P::I16),
464        (P::I8, P::I32) | (P::I32, P::I8) => Some(P::I32),
465        (P::I8, P::I64) | (P::I64, P::I8) => Some(P::I64),
466        (P::I16, P::I32) | (P::I32, P::I16) => Some(P::I32),
467        (P::I16, P::I64) | (P::I64, P::I16) => Some(P::I64),
468        (P::I32, P::I64) | (P::I64, P::I32) => Some(P::I64),
469
470        (P::U8, P::U16) | (P::U16, P::U8) => Some(P::U16),
471        (P::U8, P::U32) | (P::U32, P::U8) => Some(P::U32),
472        (P::U8, P::U64) | (P::U64, P::U8) => Some(P::U64),
473        (P::U16, P::U32) | (P::U32, P::U16) => Some(P::U32),
474        (P::U16, P::U64) | (P::U64, P::U16) => Some(P::U64),
475        (P::U32, P::U64) | (P::U64, P::U32) => Some(P::U64),
476
477        (P::F16, P::F32) | (P::F32, P::F16) => Some(P::F32),
478        (P::F32, P::F64) | (P::F64, P::F32) => Some(P::F64),
479        (P::F16, P::F64) | (P::F64, P::F16) => Some(P::F64),
480
481        // --- Int to Float promotion ---
482        (P::I8 | P::I16 | P::I32 | P::I64, P::F64) | (P::F64, P::I8 | P::I16 | P::I32 | P::I64) => {
483            Some(P::F64)
484        }
485        (P::U8 | P::U16 | P::U32 | P::U64, P::F64) | (P::F64, P::U8 | P::U16 | P::U32 | P::U64) => {
486            Some(P::F64)
487        }
488
489        _ => None,
490    }
491}