1#[cfg(feature = "arrow")]
29mod arrow;
30
31use std::borrow::Cow;
32use std::collections::BTreeMap;
33use std::fmt;
34
35use serde::{Deserialize, Serialize};
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
39#[serde(rename_all = "snake_case")]
40pub enum ModuleKind {
41 #[default]
42 Normal,
43 Session,
44 SessionDiff,
45}
46
47#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
49pub struct ModuleFunc<'a> {
50 pub name: Cow<'a, str>,
51 pub description: Option<Cow<'a, str>>,
52 pub input: PyroSchema<'a>,
53 pub output: PyroSchema<'a>,
54 #[serde(default)]
55 pub kind: ModuleKind,
56}
57
58#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
60pub struct InterfaceSpec<'a> {
61 pub capability: Cow<'a, str>,
62
63 #[serde(skip_serializing_if = "Option::is_none")]
64 pub description: Option<Cow<'a, str>>,
65
66 pub classes: Vec<ClassSpec<'a>>,
67
68 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
69 pub structs: BTreeMap<Cow<'a, str>, PyroSchema<'a>>,
70}
71
72#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
73pub struct ClassSpec<'a> {
74 pub name: Cow<'a, str>,
75 pub description: Option<Cow<'a, str>>,
76 pub methods: Vec<CapabilityFunc<'a>>,
77 #[serde(skip_serializing_if = "Option::is_none")]
78 pub client: Option<PyroSchema<'a>>,
79 pub config: Option<PyroSchema<'a>>,
80}
81
82#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
84pub struct CapabilityFunc<'a> {
85 pub name: Cow<'a, str>,
86 pub description: Option<Cow<'a, str>>,
87 pub input: PyroSchema<'a>,
88 pub output: PyroType<'a>,
89}
90
91#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
100pub enum PyroType<'a> {
101 Null,
103 PrimitiveScalar(PrimitiveDataType),
105 Str,
107 Timestamp,
109 PrimitiveList(PrimitiveDataType),
111 PrimitiveFixedList(PrimitiveDataType, usize),
113 List(Box<PyroType<'a>>, bool),
117 Group(Cow<'a, [PyroField<'a>]>),
119 Map {
121 key: Box<PyroType<'a>>,
122 value: Box<PyroType<'a>>,
123 },
124}
125
126impl<'a> fmt::Display for PyroType<'a> {
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 match self {
129 PyroType::Null => write!(f, "Null"),
131 PyroType::PrimitiveScalar(t) => write!(f, "{}", t),
132 PyroType::Str => write!(f, "Str"),
133 PyroType::Timestamp => write!(f, "Timestamp"),
134
135 PyroType::PrimitiveList(inner_type) => {
137 write!(f, "[{}]", inner_type)
138 }
139 PyroType::PrimitiveFixedList(inner_type, len) => {
140 write!(f, "[{}; {}]", inner_type, len)
141 }
142 PyroType::List(inner_type, _nullable) => {
143 write!(f, "[{}]", inner_type)
144 }
145 PyroType::Group(fields) => {
146 write!(f, "{{ ")?;
147 for (i, field) in fields.iter().enumerate() {
148 if i > 0 {
149 write!(f, ", ")?;
150 }
151 write!(f, "{}: {}", field.name, field.data_type)?;
152 }
153 write!(f, " }}")
154 }
155 PyroType::Map { key, value } => {
156 write!(f, "Map<{}, {}>", key, value)
157 }
158 }
159 }
160}
161
162#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
164pub enum PrimitiveDataType {
165 Bool,
166 U8,
167 U16,
168 U32,
169 U64,
170 I8,
171 I16,
172 I32,
173 I64,
174 F16,
175 F32,
176 F64,
177}
178
179impl fmt::Display for PrimitiveDataType {
180 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181 match self {
182 PrimitiveDataType::Bool => write!(f, "Bool"),
183 PrimitiveDataType::U8 => write!(f, "U8"),
184 PrimitiveDataType::U16 => write!(f, "U16"),
185 PrimitiveDataType::U32 => write!(f, "U32"),
186 PrimitiveDataType::U64 => write!(f, "U64"),
187 PrimitiveDataType::I8 => write!(f, "I8"),
188 PrimitiveDataType::I16 => write!(f, "I16"),
189 PrimitiveDataType::I32 => write!(f, "I32"),
190 PrimitiveDataType::I64 => write!(f, "I64"),
191 PrimitiveDataType::F16 => write!(f, "F16"),
192 PrimitiveDataType::F32 => write!(f, "F32"),
193 PrimitiveDataType::F64 => write!(f, "F64"),
194 }
195 }
196}
197
198impl<'a> PyroType<'a> {
199 pub fn into_owned(self) -> PyroType<'static> {
200 match self {
201 PyroType::Null => PyroType::Null,
202 PyroType::PrimitiveScalar(p) => PyroType::PrimitiveScalar(p),
203 PyroType::Str => PyroType::Str,
204 PyroType::Timestamp => PyroType::Timestamp,
205 PyroType::PrimitiveList(p) => PyroType::PrimitiveList(p),
206 PyroType::PrimitiveFixedList(p, l) => PyroType::PrimitiveFixedList(p, l),
207 PyroType::List(inner, n) => PyroType::List(Box::new(inner.into_owned()), n),
208 PyroType::Group(fields) => {
209 let owned_fields: Vec<PyroField<'static>> =
210 fields.iter().map(|f| f.clone().into_owned()).collect();
211 PyroType::Group(Cow::Owned(owned_fields))
212 }
213 PyroType::Map { key, value } => PyroType::Map {
214 key: Box::new(key.into_owned()),
215 value: Box::new(value.into_owned()),
216 },
217 }
218 }
219}
220
221#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
227pub struct PyroField<'a> {
228 pub name: Cow<'a, str>,
229 pub documentation: Option<Cow<'a, str>>,
230 pub data_type: PyroType<'a>,
231 pub nullable: bool,
232}
233
234impl<'a> PyroField<'a> {
235 pub fn new(name: impl Into<Cow<'a, str>>, data_type: PyroType<'a>, nullable: bool) -> Self {
238 Self {
239 name: name.into(),
240 documentation: None,
241 data_type,
242 nullable,
243 }
244 }
245
246 #[inline]
247 pub fn name(&self) -> &str {
248 &self.name
249 }
250
251 #[inline]
252 pub fn data_type(&self) -> &PyroType<'a> {
253 &self.data_type
254 }
255
256 #[inline]
257 pub fn is_nullable(&self) -> bool {
258 self.nullable
259 }
260
261 pub fn with_nullable(mut self, nullable: bool) -> Self {
262 self.nullable = nullable;
263 self
264 }
265
266 pub fn into_owned(self) -> PyroField<'static> {
268 PyroField {
269 name: Cow::Owned(self.name.into_owned()),
270 documentation: self.documentation.map(|d| Cow::Owned(d.into_owned())),
271 data_type: self.data_type.into_owned(),
272 nullable: self.nullable,
273 }
274 }
275
276 pub fn add_docstring(mut self, doc: impl Into<Cow<'a, str>>) -> Self {
277 self.documentation = Some(doc.into());
278 self
279 }
280}
281
282impl<'a> fmt::Display for PyroField<'a> {
283 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284 write!(
285 f,
286 "{}: {:?}{}",
287 self.name,
288 self.data_type,
289 if self.nullable { " (nullable)" } else { "" }
290 )
291 }
292}
293
294#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
300pub struct PyroSchema<'a> {
301 pub documentation: Option<Cow<'a, str>>,
302 pub fields: Cow<'a, [PyroField<'a>]>,
303}
304
305impl<'a> PyroSchema<'a> {
306 pub fn new(fields: Vec<PyroField<'a>>) -> Self {
307 Self {
308 documentation: None,
309 fields: Cow::Owned(fields),
310 }
311 }
312
313 pub fn empty() -> Self {
314 Self {
315 documentation: None,
316 fields: Cow::Owned(Vec::new()),
317 }
318 }
319
320 #[inline]
321 pub fn fields(&self) -> &[PyroField<'a>] {
322 &self.fields
323 }
324
325 #[inline]
326 pub fn num_fields(&self) -> usize {
327 self.fields.len()
328 }
329
330 pub fn field_with_name(&self, name: &str) -> Option<&PyroField<'a>> {
332 self.fields.iter().find(|f| f.name == name)
333 }
334
335 pub fn field(&self, index: usize) -> &PyroField<'a> {
337 &self.fields[index]
338 }
339
340 pub fn index_of(&self, name: &str) -> Option<usize> {
342 self.fields.iter().position(|f| f.name == name)
343 }
344
345 pub fn into_owned(self) -> PyroSchema<'static> {
347 PyroSchema {
348 documentation: self.documentation.map(|d| Cow::Owned(d.into_owned())),
349 fields: self.fields.iter().map(|f| f.clone().into_owned()).collect(),
350 }
351 }
352
353 pub fn add_docstring(mut self, doc: impl Into<Cow<'a, str>>) -> Self {
354 self.documentation = Some(doc.into());
355 self
356 }
357}
358
359impl<'a> fmt::Display for PyroSchema<'a> {
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 writeln!(f, "PyroSchema {{")?;
362 for field in self.fields.iter() {
363 writeln!(f, " {field},")?;
364 }
365 write!(f, "}}")
366 }
367}
368
369impl<'a> From<Vec<PyroField<'a>>> for PyroSchema<'a> {
370 fn from(fields: Vec<PyroField<'a>>) -> Self {
371 Self::new(fields)
372 }
373}
374
375pub fn coerce_pyro_types<'a>(a: &PyroType<'a>, b: &PyroType<'a>) -> Option<PyroType<'a>> {
377 if a == b {
378 return Some(a.clone());
379 }
380
381 use PyroType::*;
382
383 match (a, b) {
384 (Null, other) | (other, Null) => Some(other.clone()),
386
387 (PrimitiveScalar(pa), PrimitiveScalar(pb)) => {
389 coerce_primitive_types(*pa, *pb).map(PrimitiveScalar)
390 }
391
392 (List(inner_a, null_a), List(inner_b, null_b)) => {
394 let merged_null = *null_a || *null_b;
395 coerce_pyro_types(inner_a, inner_b).map(|c| List(Box::new(c), merged_null))
396 }
397
398 (PrimitiveList(pa), PrimitiveList(pb)) => {
400 coerce_primitive_types(*pa, *pb).map(PrimitiveList)
401 }
402
403 (PrimitiveFixedList(pa, sa), PrimitiveFixedList(pb, sb)) => {
407 let coerced_elem = coerce_primitive_types(*pa, *pb)?;
408 if sa == sb {
409 Some(PrimitiveFixedList(coerced_elem, *sa))
410 } else {
411 Some(PrimitiveList(coerced_elem))
412 }
413 }
414
415 (PrimitiveFixedList(pa, _), PrimitiveList(pb))
417 | (PrimitiveList(pa), PrimitiveFixedList(pb, _)) => {
418 coerce_primitive_types(*pa, *pb).map(PrimitiveList)
419 }
420
421 (Group(fields_a), Group(fields_b)) => {
423 let mut merged_map: BTreeMap<String, PyroField> = BTreeMap::new();
424
425 for f in fields_a.iter().chain(fields_b.iter()) {
426 match merged_map.get(f.name()) {
427 None => {
428 merged_map.insert(
430 f.name().to_string(),
431 PyroField::new(
432 Cow::Owned(f.name().to_string()),
433 f.data_type().clone(),
434 true,
435 ),
436 );
437 }
438 Some(existing) => {
439 let coerced = coerce_pyro_types(existing.data_type(), f.data_type())?;
440 let nullable = existing.is_nullable() || f.is_nullable();
441 merged_map.insert(
442 f.name().to_string(),
443 PyroField::new(Cow::Owned(f.name().to_string()), coerced, nullable),
444 );
445 }
446 }
447 }
448
449 Some(Group(Cow::Owned(merged_map.into_values().collect())))
450 }
451
452 (Map { key: ka, value: va }, Map { key: kb, value: vb }) => {
454 let coerced_key = coerce_pyro_types(ka, kb)?;
455 let coerced_val = coerce_pyro_types(va, vb)?;
456 Some(Map {
457 key: Box::new(coerced_key),
458 value: Box::new(coerced_val),
459 })
460 }
461
462 _ => None,
463 }
464}
465
466fn coerce_primitive_types(a: PrimitiveDataType, b: PrimitiveDataType) -> Option<PrimitiveDataType> {
467 if a == b {
468 return Some(a);
469 }
470
471 use PrimitiveDataType as P;
472
473 match (a, b) {
474 (P::I8, P::I16) | (P::I16, P::I8) => Some(P::I16),
475 (P::I8, P::I32) | (P::I32, P::I8) => Some(P::I32),
476 (P::I8, P::I64) | (P::I64, P::I8) => Some(P::I64),
477 (P::I16, P::I32) | (P::I32, P::I16) => Some(P::I32),
478 (P::I16, P::I64) | (P::I64, P::I16) => Some(P::I64),
479 (P::I32, P::I64) | (P::I64, P::I32) => Some(P::I64),
480
481 (P::U8, P::U16) | (P::U16, P::U8) => Some(P::U16),
482 (P::U8, P::U32) | (P::U32, P::U8) => Some(P::U32),
483 (P::U8, P::U64) | (P::U64, P::U8) => Some(P::U64),
484 (P::U16, P::U32) | (P::U32, P::U16) => Some(P::U32),
485 (P::U16, P::U64) | (P::U64, P::U16) => Some(P::U64),
486 (P::U32, P::U64) | (P::U64, P::U32) => Some(P::U64),
487
488 (P::F16, P::F32) | (P::F32, P::F16) => Some(P::F32),
489 (P::F32, P::F64) | (P::F64, P::F32) => Some(P::F64),
490 (P::F16, P::F64) | (P::F64, P::F16) => Some(P::F64),
491
492 (P::I8 | P::I16 | P::I32 | P::I64, P::F64) | (P::F64, P::I8 | P::I16 | P::I32 | P::I64) => {
494 Some(P::F64)
495 }
496 (P::U8 | P::U16 | P::U32 | P::U64, P::F64) | (P::F64, P::U8 | P::U16 | P::U32 | P::U64) => {
497 Some(P::F64)
498 }
499
500 _ => None,
501 }
502}