1#[cfg(feature = "arrow")]
29mod arrow;
30
31use std::borrow::Cow;
32use std::collections::BTreeMap;
33use std::fmt;
34
35use serde::{Deserialize, Serialize};
36
37#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
39pub struct ModuleFunc<'a> {
40 pub name: Cow<'a, str>,
41 pub description: Option<Cow<'a, str>>,
42 pub input: PyroSchema<'a>,
43 pub output: PyroSchema<'a>,
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
48pub struct InterfaceSpec<'a> {
49 pub capability: Cow<'a, str>,
50
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub description: Option<Cow<'a, str>>,
53
54 pub classes: Vec<ClassSpec<'a>>,
55}
56
57#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
58pub struct ClassSpec<'a> {
59 pub name: Cow<'a, str>,
60 pub description: Option<Cow<'a, str>>,
61 pub methods: Vec<CapabilityFunc<'a>>,
62 #[serde(skip_serializing_if = "Option::is_none")]
63 pub client: Option<PyroSchema<'a>>,
64 pub config: Option<PyroSchema<'a>>,
65}
66
67#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
69pub struct CapabilityFunc<'a> {
70 pub name: Cow<'a, str>,
71 pub description: Option<Cow<'a, str>>,
72 pub input: PyroSchema<'a>,
73 pub output: PyroType<'a>,
74}
75
76#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
85pub enum PyroType<'a> {
86 Null,
88 PrimitiveScalar(PrimitiveDataType),
90 Str,
92 Timestamp,
94 PrimitiveList(PrimitiveDataType),
96 PrimitiveFixedList(PrimitiveDataType, usize),
98 List(Box<PyroType<'a>>, bool),
102 Group(Cow<'a, [PyroField<'a>]>),
104 Map {
106 key: Box<PyroType<'a>>,
107 value: Box<PyroType<'a>>,
108 },
109}
110
111impl<'a> fmt::Display for PyroType<'a> {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 match self {
114 PyroType::Null => write!(f, "Null"),
116 PyroType::PrimitiveScalar(t) => write!(f, "{}", t),
117 PyroType::Str => write!(f, "Str"),
118 PyroType::Timestamp => write!(f, "Timestamp"),
119
120 PyroType::PrimitiveList(inner_type) => {
122 write!(f, "[{}]", inner_type)
123 }
124 PyroType::PrimitiveFixedList(inner_type, len) => {
125 write!(f, "[{}; {}]", inner_type, len)
126 }
127 PyroType::List(inner_type, _nullable) => {
128 write!(f, "[{}]", inner_type)
129 }
130 PyroType::Group(fields) => {
131 write!(f, "{{ ")?;
132 for (i, field) in fields.iter().enumerate() {
133 if i > 0 {
134 write!(f, ", ")?;
135 }
136 write!(f, "{}: {}", field.name, field.data_type)?;
137 }
138 write!(f, " }}")
139 }
140 PyroType::Map { key, value } => {
141 write!(f, "Map<{}, {}>", key, value)
142 }
143 }
144 }
145}
146
147#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
149pub enum PrimitiveDataType {
150 Bool,
151 U8,
152 U16,
153 U32,
154 U64,
155 I8,
156 I16,
157 I32,
158 I64,
159 F16,
160 F32,
161 F64,
162}
163
164impl fmt::Display for PrimitiveDataType {
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 match self {
167 PrimitiveDataType::Bool => write!(f, "Bool"),
168 PrimitiveDataType::U8 => write!(f, "U8"),
169 PrimitiveDataType::U16 => write!(f, "U16"),
170 PrimitiveDataType::U32 => write!(f, "U32"),
171 PrimitiveDataType::U64 => write!(f, "U64"),
172 PrimitiveDataType::I8 => write!(f, "I8"),
173 PrimitiveDataType::I16 => write!(f, "I16"),
174 PrimitiveDataType::I32 => write!(f, "I32"),
175 PrimitiveDataType::I64 => write!(f, "I64"),
176 PrimitiveDataType::F16 => write!(f, "F16"),
177 PrimitiveDataType::F32 => write!(f, "F32"),
178 PrimitiveDataType::F64 => write!(f, "F64"),
179 }
180 }
181}
182
183impl<'a> PyroType<'a> {
184 pub fn into_owned(self) -> PyroType<'static> {
185 match self {
186 PyroType::Null => PyroType::Null,
187 PyroType::PrimitiveScalar(p) => PyroType::PrimitiveScalar(p),
188 PyroType::Str => PyroType::Str,
189 PyroType::Timestamp => PyroType::Timestamp,
190 PyroType::PrimitiveList(p) => PyroType::PrimitiveList(p),
191 PyroType::PrimitiveFixedList(p, l) => PyroType::PrimitiveFixedList(p, l),
192 PyroType::List(inner, n) => PyroType::List(Box::new(inner.into_owned()), n),
193 PyroType::Group(fields) => {
194 let owned_fields: Vec<PyroField<'static>> =
195 fields.iter().map(|f| f.clone().into_owned()).collect();
196 PyroType::Group(Cow::Owned(owned_fields))
197 }
198 PyroType::Map { key, value } => PyroType::Map {
199 key: Box::new(key.into_owned()),
200 value: Box::new(value.into_owned()),
201 },
202 }
203 }
204}
205
206#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
212pub struct PyroField<'a> {
213 pub name: Cow<'a, str>,
214 pub documentation: Option<Cow<'a, str>>,
215 pub data_type: PyroType<'a>,
216 pub nullable: bool,
217}
218
219impl<'a> PyroField<'a> {
220 pub fn new(name: impl Into<Cow<'a, str>>, data_type: PyroType<'a>, nullable: bool) -> Self {
223 Self {
224 name: name.into(),
225 documentation: None,
226 data_type,
227 nullable,
228 }
229 }
230
231 #[inline]
232 pub fn name(&self) -> &str {
233 &self.name
234 }
235
236 #[inline]
237 pub fn data_type(&self) -> &PyroType<'a> {
238 &self.data_type
239 }
240
241 #[inline]
242 pub fn is_nullable(&self) -> bool {
243 self.nullable
244 }
245
246 pub fn with_nullable(mut self, nullable: bool) -> Self {
247 self.nullable = nullable;
248 self
249 }
250
251 pub fn into_owned(self) -> PyroField<'static> {
253 PyroField {
254 name: Cow::Owned(self.name.into_owned()),
255 documentation: self.documentation.map(|d| Cow::Owned(d.into_owned())),
256 data_type: self.data_type.into_owned(),
257 nullable: self.nullable,
258 }
259 }
260
261 pub fn add_docstring(mut self, doc: impl Into<Cow<'a, str>>) -> Self {
262 self.documentation = Some(doc.into());
263 self
264 }
265}
266
267impl<'a> fmt::Display for PyroField<'a> {
268 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269 write!(
270 f,
271 "{}: {:?}{}",
272 self.name,
273 self.data_type,
274 if self.nullable { " (nullable)" } else { "" }
275 )
276 }
277}
278
279#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
285pub struct PyroSchema<'a> {
286 pub documentation: Option<Cow<'a, str>>,
287 pub fields: Cow<'a, [PyroField<'a>]>,
288}
289
290impl<'a> PyroSchema<'a> {
291 pub fn new(fields: Vec<PyroField<'a>>) -> Self {
292 Self {
293 documentation: None,
294 fields: Cow::Owned(fields),
295 }
296 }
297
298 pub fn empty() -> Self {
299 Self {
300 documentation: None,
301 fields: Cow::Owned(Vec::new()),
302 }
303 }
304
305 #[inline]
306 pub fn fields(&self) -> &[PyroField<'a>] {
307 &self.fields
308 }
309
310 #[inline]
311 pub fn num_fields(&self) -> usize {
312 self.fields.len()
313 }
314
315 pub fn field_with_name(&self, name: &str) -> Option<&PyroField<'a>> {
317 self.fields.iter().find(|f| f.name == name)
318 }
319
320 pub fn field(&self, index: usize) -> &PyroField<'a> {
322 &self.fields[index]
323 }
324
325 pub fn index_of(&self, name: &str) -> Option<usize> {
327 self.fields.iter().position(|f| f.name == name)
328 }
329
330 pub fn into_owned(self) -> PyroSchema<'static> {
332 PyroSchema {
333 documentation: None,
334 fields: self
335 .fields
336 .into_iter()
337 .map(|f| f.clone().into_owned())
338 .collect(),
339 }
340 }
341
342 pub fn add_docstring(mut self, doc: impl Into<Cow<'a, str>>) -> Self {
343 self.documentation = Some(doc.into());
344 self
345 }
346}
347
348impl<'a> fmt::Display for PyroSchema<'a> {
349 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350 writeln!(f, "PyroSchema {{")?;
351 for field in self.fields.iter() {
352 writeln!(f, " {field},")?;
353 }
354 write!(f, "}}")
355 }
356}
357
358impl<'a> From<Vec<PyroField<'a>>> for PyroSchema<'a> {
359 fn from(fields: Vec<PyroField<'a>>) -> Self {
360 Self::new(fields)
361 }
362}
363
364pub fn coerce_pyro_types<'a>(a: &PyroType<'a>, b: &PyroType<'a>) -> Option<PyroType<'a>> {
366 if a == b {
367 return Some(a.clone());
368 }
369
370 use PyroType::*;
371
372 match (a, b) {
373 (Null, other) | (other, Null) => Some(other.clone()),
375
376 (PrimitiveScalar(pa), PrimitiveScalar(pb)) => {
378 coerce_primitive_types(*pa, *pb).map(PrimitiveScalar)
379 }
380
381 (List(inner_a, null_a), List(inner_b, null_b)) => {
383 let merged_null = *null_a || *null_b;
384 coerce_pyro_types(&inner_a, &inner_b).map(|c| List(Box::new(c), merged_null))
385 }
386
387 (PrimitiveList(pa), PrimitiveList(pb)) => {
389 coerce_primitive_types(*pa, *pb).map(PrimitiveList)
390 }
391
392 (PrimitiveFixedList(pa, sa), PrimitiveFixedList(pb, sb)) => {
396 let coerced_elem = coerce_primitive_types(*pa, *pb)?;
397 if sa == sb {
398 Some(PrimitiveFixedList(coerced_elem, *sa))
399 } else {
400 Some(PrimitiveList(coerced_elem))
401 }
402 }
403
404 (PrimitiveFixedList(pa, _), PrimitiveList(pb))
406 | (PrimitiveList(pa), PrimitiveFixedList(pb, _)) => {
407 coerce_primitive_types(*pa, *pb).map(PrimitiveList)
408 }
409
410 (Group(fields_a), Group(fields_b)) => {
412 let mut merged_map: BTreeMap<String, PyroField> = BTreeMap::new();
413
414 for f in fields_a.iter().chain(fields_b.iter()) {
415 match merged_map.get(f.name()) {
416 None => {
417 merged_map.insert(
419 f.name().to_string(),
420 PyroField::new(
421 Cow::Owned(f.name().to_string()),
422 f.data_type().clone(),
423 true,
424 ),
425 );
426 }
427 Some(existing) => {
428 let coerced = coerce_pyro_types(existing.data_type(), f.data_type())?;
429 let nullable = existing.is_nullable() || f.is_nullable();
430 merged_map.insert(
431 f.name().to_string(),
432 PyroField::new(Cow::Owned(f.name().to_string()), coerced, nullable),
433 );
434 }
435 }
436 }
437
438 Some(Group(Cow::Owned(merged_map.into_values().collect())))
439 }
440
441 (Map { key: ka, value: va }, Map { key: kb, value: vb }) => {
443 let coerced_key = coerce_pyro_types(ka, kb)?;
444 let coerced_val = coerce_pyro_types(va, vb)?;
445 Some(Map {
446 key: Box::new(coerced_key),
447 value: Box::new(coerced_val),
448 })
449 }
450
451 _ => None,
452 }
453}
454
455fn coerce_primitive_types(a: PrimitiveDataType, b: PrimitiveDataType) -> Option<PrimitiveDataType> {
456 if a == b {
457 return Some(a);
458 }
459
460 use PrimitiveDataType as P;
461
462 match (a, b) {
463 (P::I8, P::I16) | (P::I16, P::I8) => Some(P::I16),
464 (P::I8, P::I32) | (P::I32, P::I8) => Some(P::I32),
465 (P::I8, P::I64) | (P::I64, P::I8) => Some(P::I64),
466 (P::I16, P::I32) | (P::I32, P::I16) => Some(P::I32),
467 (P::I16, P::I64) | (P::I64, P::I16) => Some(P::I64),
468 (P::I32, P::I64) | (P::I64, P::I32) => Some(P::I64),
469
470 (P::U8, P::U16) | (P::U16, P::U8) => Some(P::U16),
471 (P::U8, P::U32) | (P::U32, P::U8) => Some(P::U32),
472 (P::U8, P::U64) | (P::U64, P::U8) => Some(P::U64),
473 (P::U16, P::U32) | (P::U32, P::U16) => Some(P::U32),
474 (P::U16, P::U64) | (P::U64, P::U16) => Some(P::U64),
475 (P::U32, P::U64) | (P::U64, P::U32) => Some(P::U64),
476
477 (P::F16, P::F32) | (P::F32, P::F16) => Some(P::F32),
478 (P::F32, P::F64) | (P::F64, P::F32) => Some(P::F64),
479 (P::F16, P::F64) | (P::F64, P::F16) => Some(P::F64),
480
481 (P::I8 | P::I16 | P::I32 | P::I64, P::F64) | (P::F64, P::I8 | P::I16 | P::I32 | P::I64) => {
483 Some(P::F64)
484 }
485 (P::U8 | P::U16 | P::U32 | P::U64, P::F64) | (P::F64, P::U8 | P::U16 | P::U32 | P::U64) => {
486 Some(P::F64)
487 }
488
489 _ => None,
490 }
491}