Skip to main content

pyro_spec/
lib.rs

//! Pyro-native type system and schema representation.
//!
//! This crate provides a lightweight schema representation optimized for
//! `PyroValue`. Unlike Arrow's `DataType` (which has ~40 variants for timestamps,
//! decimals, run-end-encoded, etc.), `PyroType` mirrors *exactly* the variants that
//! `PyroValue` can represent, making match arms exhaustive and tiny.
//!
//! Core types:
//! - [`PyroType`] — the main type enum, mirroring `PyroValue` discriminants 1:1
//! - [`PyroField`] — a named, nullable column descriptor (equivalent to Arrow `Field`)
//! - [`PyroSchema`] — an ordered collection of `PyroField`s (equivalent to Arrow `Schema`)
//! - [`coerce_pyro_types`] — type coercion to find common supertypes
//!
//! Conversion to/from `arrow::datatypes::DataType` lives in the `arrow` module
//! (behind the `arrow` feature flag).
16
// =============================================================================
// Pyro-native type system
// =============================================================================
//
// A lightweight schema representation optimized for PyroValue.
// Unlike Arrow's DataType (which has ~40 variants for timestamps, decimals,
// run-end-encoded, etc.), PyroType mirrors *exactly* the variants that
// PyroValue can represent, making match arms exhaustive and tiny.
//
// Conversion to/from `arrow::datatypes::DataType` lives in the `arrow` module
// declared below (compiled only when the `arrow` feature is enabled).
27
28#[cfg(feature = "arrow")]
29mod arrow;
30
31use std::borrow::Cow;
32use std::collections::BTreeMap;
33use std::fmt;
34
35use serde::{Deserialize, Serialize};
36
37/// The kind of module execution model
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
39#[serde(rename_all = "snake_case")]
40pub enum ModuleKind {
41    #[default]
42    Normal,
43    Session,
44    SessionDiff,
45}
46
47/// Documentation for the main function of a module
48#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
49pub struct ModuleFunc<'a> {
50    pub name: Cow<'a, str>,
51    pub description: Option<Cow<'a, str>>,
52    pub input: PyroSchema<'a>,
53    pub output: PyroSchema<'a>,
54    #[serde(default)]
55    pub kind: ModuleKind,
56}
57
58/// The root specification object.
59#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
60pub struct InterfaceSpec<'a> {
61    pub capability: Cow<'a, str>,
62
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub description: Option<Cow<'a, str>>,
65
66    pub classes: Vec<ClassSpec<'a>>,
67
68    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
69    pub structs: BTreeMap<Cow<'a, str>, PyroSchema<'a>>,
70}
71
72#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
73pub struct ClassSpec<'a> {
74    pub name: Cow<'a, str>,
75    pub description: Option<Cow<'a, str>>,
76    pub methods: Vec<CapabilityFunc<'a>>,
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub client: Option<PyroSchema<'a>>,
79    pub config: Option<PyroSchema<'a>>,
80}
81
82/// Documentation for a capability function
83#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
84pub struct CapabilityFunc<'a> {
85    pub name: Cow<'a, str>,
86    pub description: Option<Cow<'a, str>>,
87    pub input: PyroSchema<'a>,
88    pub output: PyroType<'a>,
89}
90
91// =============================================================================
92// PyroType
93// =============================================================================
94
95/// A data type enum that mirrors the variants of [`PyroValue`] exactly.
96///
97/// This is intentionally much smaller than `arrow::datatypes::DataType`.
98/// Every variant here has a 1:1 correspondence with a `PyroValue` discriminant.
99#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
100pub enum PyroType<'a> {
101    /// No value / unknown type (corresponds to `PyroValue::Null`).
102    Null,
103    /// Scalar primitive (Bool, Int, Float).
104    PrimitiveScalar(PrimitiveDataType),
105    /// UTF-8 string (corresponds to `PyroValue::Str`).
106    Str,
107    /// Day + millisecond interval (corresponds to `PyroValue::Timestamp`).
108    Timestamp,
109    /// Homogeneous list of a single primitive type (corresponds to `PyroValue::PrimitiveList`).
110    PrimitiveList(PrimitiveDataType),
111    /// Fixed-size homogeneous list of a single primitive type.
112    PrimitiveFixedList(PrimitiveDataType, usize),
113    /// Heterogeneous list of arbitrary pyro values (corresponds to `PyroValue::List`).
114    ///
115    /// Fields: `(element_type, element_nullable)`.
116    List(Box<PyroType<'a>>, bool),
117    /// Named struct / row (corresponds to `PyroValue::Group`).
118    Group(Cow<'a, [PyroField<'a>]>),
119    /// Key-value map (corresponds to `PyroValue::MapInternal`).
120    Map {
121        key: Box<PyroType<'a>>,
122        value: Box<PyroType<'a>>,
123    },
124}
125
126impl<'a> fmt::Display for PyroType<'a> {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        match self {
129            // Primitives
130            PyroType::Null => write!(f, "Null"),
131            PyroType::PrimitiveScalar(t) => write!(f, "{}", t),
132            PyroType::Str => write!(f, "Str"),
133            PyroType::Timestamp => write!(f, "Timestamp"),
134
135            // Complex Types
136            PyroType::PrimitiveList(inner_type) => {
137                write!(f, "[{}]", inner_type)
138            }
139            PyroType::PrimitiveFixedList(inner_type, len) => {
140                write!(f, "[{}; {}]", inner_type, len)
141            }
142            PyroType::List(inner_type, _nullable) => {
143                write!(f, "[{}]", inner_type)
144            }
145            PyroType::Group(fields) => {
146                write!(f, "{{ ")?;
147                for (i, field) in fields.iter().enumerate() {
148                    if i > 0 {
149                        write!(f, ", ")?;
150                    }
151                    write!(f, "{}: {}", field.name, field.data_type)?;
152                }
153                write!(f, " }}")
154            }
155            PyroType::Map { key, value } => {
156                write!(f, "Map<{}, {}>", key, value)
157            }
158        }
159    }
160}
161
162/// The primitive element type inside a `PrimitiveValueList`.
163#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
164pub enum PrimitiveDataType {
165    Bool,
166    U8,
167    U16,
168    U32,
169    U64,
170    I8,
171    I16,
172    I32,
173    I64,
174    F16,
175    F32,
176    F64,
177}
178
179impl fmt::Display for PrimitiveDataType {
180    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        match self {
182            PrimitiveDataType::Bool => write!(f, "Bool"),
183            PrimitiveDataType::U8 => write!(f, "U8"),
184            PrimitiveDataType::U16 => write!(f, "U16"),
185            PrimitiveDataType::U32 => write!(f, "U32"),
186            PrimitiveDataType::U64 => write!(f, "U64"),
187            PrimitiveDataType::I8 => write!(f, "I8"),
188            PrimitiveDataType::I16 => write!(f, "I16"),
189            PrimitiveDataType::I32 => write!(f, "I32"),
190            PrimitiveDataType::I64 => write!(f, "I64"),
191            PrimitiveDataType::F16 => write!(f, "F16"),
192            PrimitiveDataType::F32 => write!(f, "F32"),
193            PrimitiveDataType::F64 => write!(f, "F64"),
194        }
195    }
196}
197
198impl<'a> PyroType<'a> {
199    pub fn into_owned(self) -> PyroType<'static> {
200        match self {
201            PyroType::Null => PyroType::Null,
202            PyroType::PrimitiveScalar(p) => PyroType::PrimitiveScalar(p),
203            PyroType::Str => PyroType::Str,
204            PyroType::Timestamp => PyroType::Timestamp,
205            PyroType::PrimitiveList(p) => PyroType::PrimitiveList(p),
206            PyroType::PrimitiveFixedList(p, l) => PyroType::PrimitiveFixedList(p, l),
207            PyroType::List(inner, n) => PyroType::List(Box::new(inner.into_owned()), n),
208            PyroType::Group(fields) => {
209                let owned_fields: Vec<PyroField<'static>> =
210                    fields.iter().map(|f| f.clone().into_owned()).collect();
211                PyroType::Group(Cow::Owned(owned_fields))
212            }
213            PyroType::Map { key, value } => PyroType::Map {
214                key: Box::new(key.into_owned()),
215                value: Box::new(value.into_owned()),
216            },
217        }
218    }
219}
220
221// =============================================================================
222// PyroField
223// =============================================================================
224
225/// A named, nullable column descriptor — the Pyro equivalent of `arrow::datatypes::Field`.
226#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
227pub struct PyroField<'a> {
228    pub name: Cow<'a, str>,
229    pub documentation: Option<Cow<'a, str>>,
230    pub data_type: PyroType<'a>,
231    pub nullable: bool,
232}
233
234impl<'a> PyroField<'a> {
235    /// Create a new field.
236    /// Accepts `&'static str`, `String`, or `Cow<'a, str>`.
237    pub fn new(name: impl Into<Cow<'a, str>>, data_type: PyroType<'a>, nullable: bool) -> Self {
238        Self {
239            name: name.into(),
240            documentation: None,
241            data_type,
242            nullable,
243        }
244    }
245
246    #[inline]
247    pub fn name(&self) -> &str {
248        &self.name
249    }
250
251    #[inline]
252    pub fn data_type(&self) -> &PyroType<'a> {
253        &self.data_type
254    }
255
256    #[inline]
257    pub fn is_nullable(&self) -> bool {
258        self.nullable
259    }
260
261    pub fn with_nullable(mut self, nullable: bool) -> Self {
262        self.nullable = nullable;
263        self
264    }
265
266    /// Convert to an owned version (PyroField<'static>) by cloning data.
267    pub fn into_owned(self) -> PyroField<'static> {
268        PyroField {
269            name: Cow::Owned(self.name.into_owned()),
270            documentation: self.documentation.map(|d| Cow::Owned(d.into_owned())),
271            data_type: self.data_type.into_owned(),
272            nullable: self.nullable,
273        }
274    }
275
276    pub fn add_docstring(mut self, doc: impl Into<Cow<'a, str>>) -> Self {
277        self.documentation = Some(doc.into());
278        self
279    }
280}
281
282impl<'a> fmt::Display for PyroField<'a> {
283    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284        write!(
285            f,
286            "{}: {:?}{}",
287            self.name,
288            self.data_type,
289            if self.nullable { " (nullable)" } else { "" }
290        )
291    }
292}
293
294// =============================================================================
295// PyroSchema
296// =============================================================================
297
298/// An ordered collection of [`PyroField`]s — the Pyro equivalent of `arrow::datatypes::Schema`.
299#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
300pub struct PyroSchema<'a> {
301    pub documentation: Option<Cow<'a, str>>,
302    pub fields: Cow<'a, [PyroField<'a>]>,
303}
304
305impl<'a> PyroSchema<'a> {
306    pub fn new(fields: Vec<PyroField<'a>>) -> Self {
307        Self {
308            documentation: None,
309            fields: Cow::Owned(fields),
310        }
311    }
312
313    pub fn empty() -> Self {
314        Self {
315            documentation: None,
316            fields: Cow::Owned(Vec::new()),
317        }
318    }
319
320    #[inline]
321    pub fn fields(&self) -> &[PyroField<'a>] {
322        &self.fields
323    }
324
325    #[inline]
326    pub fn num_fields(&self) -> usize {
327        self.fields.len()
328    }
329
330    /// Look up a field by name (linear scan).
331    pub fn field_with_name(&self, name: &str) -> Option<&PyroField<'a>> {
332        self.fields.iter().find(|f| f.name == name)
333    }
334
335    /// Get a field by index.
336    pub fn field(&self, index: usize) -> &PyroField<'a> {
337        &self.fields[index]
338    }
339
340    /// Returns column index for the given name, if present.
341    pub fn index_of(&self, name: &str) -> Option<usize> {
342        self.fields.iter().position(|f| f.name == name)
343    }
344
345    /// Convert to an fully owned schema (useful for inference results).
346    pub fn into_owned(self) -> PyroSchema<'static> {
347        PyroSchema {
348            documentation: self.documentation.map(|d| Cow::Owned(d.into_owned())),
349            fields: self.fields.iter().map(|f| f.clone().into_owned()).collect(),
350        }
351    }
352
353    pub fn add_docstring(mut self, doc: impl Into<Cow<'a, str>>) -> Self {
354        self.documentation = Some(doc.into());
355        self
356    }
357}
358
359impl<'a> fmt::Display for PyroSchema<'a> {
360    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361        writeln!(f, "PyroSchema {{")?;
362        for field in self.fields.iter() {
363            writeln!(f, "  {field},")?;
364        }
365        write!(f, "}}")
366    }
367}
368
369impl<'a> From<Vec<PyroField<'a>>> for PyroSchema<'a> {
370    fn from(fields: Vec<PyroField<'a>>) -> Self {
371        Self::new(fields)
372    }
373}
374
375/// Coerce two [`PyroType`]s to a common supertype. Returns `None` if incompatible.
376pub fn coerce_pyro_types<'a>(a: &PyroType<'a>, b: &PyroType<'a>) -> Option<PyroType<'a>> {
377    if a == b {
378        return Some(a.clone());
379    }
380
381    use PyroType::*;
382
383    match (a, b) {
384        // Null widens to anything
385        (Null, other) | (other, Null) => Some(other.clone()),
386
387        // --- Primitive Scalar coercion ---
388        (PrimitiveScalar(pa), PrimitiveScalar(pb)) => {
389            coerce_primitive_types(*pa, *pb).map(PrimitiveScalar)
390        }
391
392        // --- List coercion (merge nullability) ---
393        (List(inner_a, null_a), List(inner_b, null_b)) => {
394            let merged_null = *null_a || *null_b;
395            coerce_pyro_types(inner_a, inner_b).map(|c| List(Box::new(c), merged_null))
396        }
397
398        // --- PrimitiveList coercion ---
399        (PrimitiveList(pa), PrimitiveList(pb)) => {
400            coerce_primitive_types(*pa, *pb).map(PrimitiveList)
401        }
402
403        // --- PrimitiveFixedList coercion ---
404        // Same size + coercible element type → PrimitiveFixedList
405        // Different size → promote to PrimitiveList
406        (PrimitiveFixedList(pa, sa), PrimitiveFixedList(pb, sb)) => {
407            let coerced_elem = coerce_primitive_types(*pa, *pb)?;
408            if sa == sb {
409                Some(PrimitiveFixedList(coerced_elem, *sa))
410            } else {
411                Some(PrimitiveList(coerced_elem))
412            }
413        }
414
415        // PrimitiveFixedList + PrimitiveList → PrimitiveList
416        (PrimitiveFixedList(pa, _), PrimitiveList(pb))
417        | (PrimitiveList(pa), PrimitiveFixedList(pb, _)) => {
418            coerce_primitive_types(*pa, *pb).map(PrimitiveList)
419        }
420
421        // --- Group (struct) coercion: merge fields ---
422        (Group(fields_a), Group(fields_b)) => {
423            let mut merged_map: BTreeMap<String, PyroField> = BTreeMap::new();
424
425            for f in fields_a.iter().chain(fields_b.iter()) {
426                match merged_map.get(f.name()) {
427                    None => {
428                        // Field only in one side so far — mark nullable since the other side lacks it
429                        merged_map.insert(
430                            f.name().to_string(),
431                            PyroField::new(
432                                Cow::Owned(f.name().to_string()),
433                                f.data_type().clone(),
434                                true,
435                            ),
436                        );
437                    }
438                    Some(existing) => {
439                        let coerced = coerce_pyro_types(existing.data_type(), f.data_type())?;
440                        let nullable = existing.is_nullable() || f.is_nullable();
441                        merged_map.insert(
442                            f.name().to_string(),
443                            PyroField::new(Cow::Owned(f.name().to_string()), coerced, nullable),
444                        );
445                    }
446                }
447            }
448
449            Some(Group(Cow::Owned(merged_map.into_values().collect())))
450        }
451
452        // --- Map Coercion ---
453        (Map { key: ka, value: va }, Map { key: kb, value: vb }) => {
454            let coerced_key = coerce_pyro_types(ka, kb)?;
455            let coerced_val = coerce_pyro_types(va, vb)?;
456            Some(Map {
457                key: Box::new(coerced_key),
458                value: Box::new(coerced_val),
459            })
460        }
461
462        _ => None,
463    }
464}
465
466fn coerce_primitive_types(a: PrimitiveDataType, b: PrimitiveDataType) -> Option<PrimitiveDataType> {
467    if a == b {
468        return Some(a);
469    }
470
471    use PrimitiveDataType as P;
472
473    match (a, b) {
474        (P::I8, P::I16) | (P::I16, P::I8) => Some(P::I16),
475        (P::I8, P::I32) | (P::I32, P::I8) => Some(P::I32),
476        (P::I8, P::I64) | (P::I64, P::I8) => Some(P::I64),
477        (P::I16, P::I32) | (P::I32, P::I16) => Some(P::I32),
478        (P::I16, P::I64) | (P::I64, P::I16) => Some(P::I64),
479        (P::I32, P::I64) | (P::I64, P::I32) => Some(P::I64),
480
481        (P::U8, P::U16) | (P::U16, P::U8) => Some(P::U16),
482        (P::U8, P::U32) | (P::U32, P::U8) => Some(P::U32),
483        (P::U8, P::U64) | (P::U64, P::U8) => Some(P::U64),
484        (P::U16, P::U32) | (P::U32, P::U16) => Some(P::U32),
485        (P::U16, P::U64) | (P::U64, P::U16) => Some(P::U64),
486        (P::U32, P::U64) | (P::U64, P::U32) => Some(P::U64),
487
488        (P::F16, P::F32) | (P::F32, P::F16) => Some(P::F32),
489        (P::F32, P::F64) | (P::F64, P::F32) => Some(P::F64),
490        (P::F16, P::F64) | (P::F64, P::F16) => Some(P::F64),
491
492        // --- Int to Float promotion ---
493        (P::I8 | P::I16 | P::I32 | P::I64, P::F64) | (P::F64, P::I8 | P::I16 | P::I32 | P::I64) => {
494            Some(P::F64)
495        }
496        (P::U8 | P::U16 | P::U32 | P::U64, P::F64) | (P::F64, P::U8 | P::U16 | P::U32 | P::U64) => {
497            Some(P::F64)
498        }
499
500        _ => None,
501    }
502}