1use std::{fmt::Display, net::IpAddr, str::FromStr, time::Duration};
2
3use macaddr::MacAddr;
4use measurements::{AngularVelocity, Frequency, Power, Temperature, Voltage};
5use pyo3::{
6 PyTypeInfo,
7 exceptions::PyValueError,
8 prelude::*,
9 types::{PyAnyMethods, PyBool, PyDict, PyDictMethods, PyList, PyListMethods, PyType},
10};
11
12pub use asic_rs_pydantic_macros::{
13 PyPydanticData, PyPydanticEnum, PyPydanticModel, PyPydanticTaggedEnum, PyPydanticTaggedUnion,
14 py_pydantic_model,
15};
16
/// Selects which side of a pydantic-core schema is being built.
///
/// Types may return a looser schema for validation (what Python input is
/// accepted) than for serialization (the shape `to_pydantic_data` produces).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PydanticSchemaMode {
    /// Schema applied to incoming Python values.
    Validation,
    /// Schema describing the data emitted when serializing.
    Serialization,
}
22
23pub trait PyPydanticType: Sized {
24 fn pydantic_schema<'py>(
25 core_schema: &Bound<'py, PyAny>,
26 mode: PydanticSchemaMode,
27 ) -> PyResult<Bound<'py, PyAny>>;
28
29 fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self>;
30
31 fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>>;
32
33 fn to_pydantic_repr_value(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
34 self.to_pydantic_data(py)
35 }
36}
37
38pub trait PyPydanticStringEnum: Clone + Display + FromStr + PyTypeInfo + Sized {
39 const PYDANTIC_VALUES: &'static [&'static str];
40
41 fn to_pydantic_enum_repr_value(&self, py: Python<'_>) -> PyResult<Py<PyAny>>;
42}
43
44impl<T> PyPydanticType for T
45where
46 T: PyPydanticStringEnum + for<'py> FromPyObject<'py, 'py>,
47 <T as FromStr>::Err: Display,
48 for<'py> <T as FromPyObject<'py, 'py>>::Error: Into<PyErr>,
49{
50 fn pydantic_schema<'py>(
51 core_schema: &Bound<'py, PyAny>,
52 mode: PydanticSchemaMode,
53 ) -> PyResult<Bound<'py, PyAny>> {
54 let string_schema = literal_schema(core_schema, T::PYDANTIC_VALUES)?;
55 if mode == PydanticSchemaMode::Serialization {
56 return Ok(string_schema);
57 }
58
59 let py = core_schema.py();
60 let instance_schema =
61 core_schema.call_method1("is_instance_schema", (py.get_type::<T>(),))?;
62 union_schema(core_schema, [instance_schema, string_schema])
63 }
64
65 fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
66 if let Ok(value) = value.extract::<T>() {
67 return Ok(value);
68 }
69
70 let value = value.extract::<String>()?;
71 T::from_str(&value).map_err(|error| PyValueError::new_err(error.to_string()))
72 }
73
74 fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
75 Ok(self.to_string().into_pyobject(py)?.into_any().unbind())
76 }
77
78 fn to_pydantic_repr_value(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
79 self.to_pydantic_enum_repr_value(py)
80 }
81}
82
83macro_rules! impl_pydantic_python_value {
84 ($schema:literal; $($ty:ty),* $(,)?) => {
85 $(
86 impl PyPydanticType for $ty {
87 fn pydantic_schema<'py>(
88 core_schema: &Bound<'py, PyAny>,
89 _mode: PydanticSchemaMode,
90 ) -> PyResult<Bound<'py, PyAny>> {
91 core_schema.call_method0($schema)
92 }
93
94 fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
95 value.extract()
96 }
97
98 fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
99 Ok((*self).into_pyobject(py)?.clone().into_any().unbind())
100 }
101 }
102 )*
103 };
104}
105
106impl_pydantic_python_value!("int_schema"; i8, i16, i32, i64, isize, u8, u16, u32, u64, usize);
107impl_pydantic_python_value!("float_schema"; f32, f64);
108
109impl PyPydanticType for bool {
110 fn pydantic_schema<'py>(
111 core_schema: &Bound<'py, PyAny>,
112 _mode: PydanticSchemaMode,
113 ) -> PyResult<Bound<'py, PyAny>> {
114 core_schema.call_method0("bool_schema")
115 }
116
117 fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
118 value.extract()
119 }
120
121 fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
122 Ok(PyBool::new(py, *self).to_owned().into_any().unbind())
123 }
124}
125
126impl PyPydanticType for String {
127 fn pydantic_schema<'py>(
128 core_schema: &Bound<'py, PyAny>,
129 _mode: PydanticSchemaMode,
130 ) -> PyResult<Bound<'py, PyAny>> {
131 core_schema.call_method0("str_schema")
132 }
133
134 fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
135 value.extract()
136 }
137
138 fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
139 Ok(self.clone().into_pyobject(py)?.into_any().unbind())
140 }
141}
142
143impl PyPydanticType for IpAddr {
144 fn pydantic_schema<'py>(
145 core_schema: &Bound<'py, PyAny>,
146 _mode: PydanticSchemaMode,
147 ) -> PyResult<Bound<'py, PyAny>> {
148 core_schema.call_method0("str_schema")
149 }
150
151 fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
152 if let Ok(ip) = value.extract::<Self>() {
153 return Ok(ip);
154 }
155 value
156 .extract::<String>()?
157 .parse()
158 .map_err(|error| PyValueError::new_err(format!("Invalid IP address: {error}")))
159 }
160
161 fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
162 Ok(self.to_string().into_pyobject(py)?.into_any().unbind())
163 }
164}
165
166impl PyPydanticType for MacAddr {
167 fn pydantic_schema<'py>(
168 core_schema: &Bound<'py, PyAny>,
169 _mode: PydanticSchemaMode,
170 ) -> PyResult<Bound<'py, PyAny>> {
171 core_schema.call_method0("str_schema")
172 }
173
174 fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
175 value
176 .extract::<String>()?
177 .parse()
178 .map_err(|error| PyValueError::new_err(format!("Invalid MAC address: {error}")))
179 }
180
181 fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
182 Ok(self.to_string().into_pyobject(py)?.into_any().unbind())
183 }
184}
185
/// Converts a [`Duration`] into seconds as an `f64` for serialization.
///
/// Keeps the fractional part. Validation accepts fractional seconds and the
/// serialization schema is `float_schema`, so truncating to whole seconds
/// would make e.g. 1.5 s serialize as 1.0.
fn duration_to_seconds(duration: Duration) -> f64 {
    duration.as_secs_f64()
}
189
190impl PyPydanticType for Duration {
191 fn pydantic_schema<'py>(
192 core_schema: &Bound<'py, PyAny>,
193 mode: PydanticSchemaMode,
194 ) -> PyResult<Bound<'py, PyAny>> {
195 match mode {
196 PydanticSchemaMode::Validation => core_schema.call_method0("any_schema"),
197 PydanticSchemaMode::Serialization => core_schema.call_method0("float_schema"),
198 }
199 }
200
201 fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
202 if let Ok(duration) = value.extract::<Self>() {
203 return Ok(duration);
204 }
205 if let Ok(seconds) = value.extract::<f64>()
206 && seconds.is_finite()
207 && seconds >= 0.0
208 {
209 return Ok(Self::from_secs_f64(seconds));
210 }
211 if let Ok(dict) = value.cast::<PyDict>() {
212 let secs = required_dict_item(dict, "secs")?.extract::<u64>()?;
213 return Ok(Self::from_secs(secs));
214 }
215 Err(PyValueError::new_err(
216 "Expected duration as timedelta, non-negative seconds, or {secs} dict",
217 ))
218 }
219
220 fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
221 Ok(duration_to_seconds(*self)
222 .into_pyobject(py)?
223 .into_any()
224 .unbind())
225 }
226}
227
228macro_rules! impl_pydantic_measurement {
229 ($ty:ty, $from_unit:ident, $as_unit:ident) => {
230 impl PyPydanticType for $ty {
231 fn pydantic_schema<'py>(
232 core_schema: &Bound<'py, PyAny>,
233 _mode: PydanticSchemaMode,
234 ) -> PyResult<Bound<'py, PyAny>> {
235 core_schema.call_method0("float_schema")
236 }
237
238 fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
239 Ok(Self::$from_unit(value.extract::<f64>()?))
240 }
241
242 fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
243 Ok(self.$as_unit().into_pyobject(py)?.into_any().unbind())
244 }
245 }
246 };
247}
248
249impl_pydantic_measurement!(AngularVelocity, from_rpm, as_rpm);
250impl_pydantic_measurement!(Frequency, from_megahertz, as_megahertz);
251impl_pydantic_measurement!(Power, from_watts, as_watts);
252impl_pydantic_measurement!(Temperature, from_celsius, as_celsius);
253impl_pydantic_measurement!(Voltage, from_volts, as_volts);
254
255impl<T> PyPydanticType for Option<T>
256where
257 T: PyPydanticType,
258{
259 fn pydantic_schema<'py>(
260 core_schema: &Bound<'py, PyAny>,
261 mode: PydanticSchemaMode,
262 ) -> PyResult<Bound<'py, PyAny>> {
263 let inner = T::pydantic_schema(core_schema, mode)?;
264 nullable_schema(core_schema, &inner)
265 }
266
267 fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
268 if value.is_none() {
269 Ok(None)
270 } else {
271 T::from_pydantic(value).map(Some)
272 }
273 }
274
275 fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
276 if let Some(value) = self {
277 value.to_pydantic_data(py)
278 } else {
279 Ok(py.None())
280 }
281 }
282
283 fn to_pydantic_repr_value(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
284 if let Some(value) = self {
285 value.to_pydantic_repr_value(py)
286 } else {
287 Ok(py.None())
288 }
289 }
290}
291
292impl<T> PyPydanticType for Vec<T>
293where
294 T: PyPydanticType,
295{
296 fn pydantic_schema<'py>(
297 core_schema: &Bound<'py, PyAny>,
298 mode: PydanticSchemaMode,
299 ) -> PyResult<Bound<'py, PyAny>> {
300 let inner = T::pydantic_schema(core_schema, mode)?;
301 list_schema(core_schema, &inner)
302 }
303
304 fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
305 value
306 .try_iter()?
307 .map(|item| {
308 let item = item?;
309 T::from_pydantic(&item)
310 })
311 .collect()
312 }
313
314 fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
315 let list = PyList::empty(py);
316 for value in self {
317 list.append(value.to_pydantic_data(py)?)?;
318 }
319 Ok(list.into_any().unbind())
320 }
321
322 fn to_pydantic_repr_value(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
323 let list = PyList::empty(py);
324 for value in self {
325 list.append(value.to_pydantic_repr_value(py)?)?;
326 }
327 Ok(list.into_any().unbind())
328 }
329}
330
331pub fn typed_dict_field<'py>(
332 core_schema: &Bound<'py, PyAny>,
333 schema: &Bound<'py, PyAny>,
334 required: bool,
335) -> PyResult<Bound<'py, PyAny>> {
336 let kwargs = PyDict::new(core_schema.py());
337 kwargs.set_item("required", required)?;
338 core_schema.call_method("typed_dict_field", (schema,), Some(&kwargs))
339}
340
341pub fn typed_dict_schema<'py>(
342 core_schema: &Bound<'py, PyAny>,
343 fields: &Bound<'py, PyDict>,
344 ref_name: Option<&str>,
345) -> PyResult<Bound<'py, PyAny>> {
346 let kwargs = PyDict::new(core_schema.py());
347 if let Some(ref_name) = ref_name {
348 kwargs.set_item("ref", ref_name)?;
349 }
350 core_schema.call_method("typed_dict_schema", (fields,), Some(&kwargs))
351}
352
/// Builds a `core_schema.typed_dict_schema` from a declarative field list.
///
/// The expansion uses `?`, so it must be invoked inside a function returning
/// `PyResult<_>`:
///
/// ```ignore
/// pydantic_typed_dict_schema!(core_schema, "RefName", {
///     "a" => required(a_schema),
///     "b" => required_if(b_schema, b_is_required),
///     "c" => nullable(c_schema),
///     "d" => nullable_if(d_schema, d_is_required),
/// })
/// ```
///
/// - `required(schema)`: field with `required = true`.
/// - `required_if(schema, flag)`: field with `required = flag`.
/// - `nullable(schema)`: schema wrapped by `nullable_field`, `required = true`.
/// - `nullable_if(schema, flag)`: schema wrapped by `nullable_field`, `required = flag`.
///
/// Each schema expression is passed by reference (`&$schema`) to
/// `typed_dict_field` / `nullable_field`. `$ref_name` is always passed as
/// `Some(..)`. `$core_schema` is substituted into every field's expansion, so it
/// is evaluated once per field; pass a plain binding. The `set_item` calls
/// presumably need pyo3's `PyDictMethods` in scope at the call site (e.g. via
/// `pyo3::prelude::*`); verify if a call site fails to resolve `set_item`.
#[macro_export]
macro_rules! pydantic_typed_dict_schema {
    // Public entry point: create the fields dict, expand each field into it,
    // then build the typed-dict schema. A trailing `,` is appended so every
    // field entry is comma-terminated for the `@fields` arms below.
    ($core_schema:expr, $ref_name:expr, { $($fields:tt)* }) => {{
        let fields = ::pyo3::types::PyDict::new($core_schema.py());
        $crate::pydantic_typed_dict_schema!(@fields fields, $core_schema, $($fields)*,);
        $crate::typed_dict_schema($core_schema, &fields, Some($ref_name))
    }};

    // Recursion base case: no entries left (tolerates one leftover comma when the
    // caller already wrote a trailing comma).
    (@fields $fields:ident, $core_schema:expr, $(,)?) => {};

    // `@fields` arms: peel off one `key => kind(...)` entry, insert it, recurse.
    (@fields $fields:ident, $core_schema:expr, $field:expr => required($schema:expr), $($rest:tt)*) => {{
        $crate::pydantic_typed_dict_schema!(@insert $fields, $core_schema, $field, required($schema));
        $crate::pydantic_typed_dict_schema!(@fields $fields, $core_schema, $($rest)*);
    }};

    (@fields $fields:ident, $core_schema:expr, $field:expr => required_if($schema:expr, $required:expr), $($rest:tt)*) => {{
        $crate::pydantic_typed_dict_schema!(@insert $fields, $core_schema, $field, required_if($schema, $required));
        $crate::pydantic_typed_dict_schema!(@fields $fields, $core_schema, $($rest)*);
    }};

    (@fields $fields:ident, $core_schema:expr, $field:expr => nullable($schema:expr), $($rest:tt)*) => {{
        $crate::pydantic_typed_dict_schema!(@insert $fields, $core_schema, $field, nullable($schema));
        $crate::pydantic_typed_dict_schema!(@fields $fields, $core_schema, $($rest)*);
    }};

    (@fields $fields:ident, $core_schema:expr, $field:expr => nullable_if($schema:expr, $required:expr), $($rest:tt)*) => {{
        $crate::pydantic_typed_dict_schema!(@insert $fields, $core_schema, $field, nullable_if($schema, $required));
        $crate::pydantic_typed_dict_schema!(@fields $fields, $core_schema, $($rest)*);
    }};

    // `@insert` arms: store one `typed_dict_field` (optionally nullable) in the
    // fields dict under `$field`.
    (@insert $fields:ident, $core_schema:expr, $field:expr, required($schema:expr)) => {{
        $fields.set_item(
            $field,
            $crate::typed_dict_field($core_schema, &$schema, true)?,
        )?;
    }};

    (@insert $fields:ident, $core_schema:expr, $field:expr, required_if($schema:expr, $required:expr)) => {{
        $fields.set_item(
            $field,
            $crate::typed_dict_field($core_schema, &$schema, $required)?,
        )?;
    }};

    (@insert $fields:ident, $core_schema:expr, $field:expr, nullable($schema:expr)) => {{
        $fields.set_item(
            $field,
            $crate::nullable_field($core_schema, &$schema, true)?,
        )?;
    }};

    (@insert $fields:ident, $core_schema:expr, $field:expr, nullable_if($schema:expr, $required:expr)) => {{
        $fields.set_item(
            $field,
            $crate::nullable_field($core_schema, &$schema, $required)?,
        )?;
    }};

}
412
413pub fn tagged_union_schema<'py, I>(
414 core_schema: &Bound<'py, PyAny>,
415 choices: I,
416 discriminator: &str,
417 ref_name: Option<&str>,
418) -> PyResult<Bound<'py, PyAny>>
419where
420 I: IntoIterator<Item = (&'static str, Bound<'py, PyAny>)>,
421{
422 let py = core_schema.py();
423 let choices_dict = PyDict::new(py);
424 for (tag, schema) in choices {
425 choices_dict.set_item(tag, schema)?;
426 }
427 let kwargs = PyDict::new(py);
428 if let Some(ref_name) = ref_name {
429 kwargs.set_item("ref", ref_name)?;
430 }
431 core_schema.call_method(
432 "tagged_union_schema",
433 (choices_dict, discriminator),
434 Some(&kwargs),
435 )
436}
437
438pub fn union_schema<'py, I>(
439 core_schema: &Bound<'py, PyAny>,
440 choices: I,
441) -> PyResult<Bound<'py, PyAny>>
442where
443 I: IntoIterator<Item = Bound<'py, PyAny>>,
444{
445 let choices_list = PyList::empty(core_schema.py());
446 for schema in choices {
447 choices_list.append(schema)?;
448 }
449 core_schema.call_method1("union_schema", (choices_list,))
450}
451
452pub fn literal_schema<'py>(
453 core_schema: &Bound<'py, PyAny>,
454 values: &[&str],
455) -> PyResult<Bound<'py, PyAny>> {
456 let values = PyList::new(core_schema.py(), values)?;
457 core_schema.call_method1("literal_schema", (values,))
458}
459
460pub fn nullable_schema<'py>(
461 core_schema: &Bound<'py, PyAny>,
462 schema: &Bound<'py, PyAny>,
463) -> PyResult<Bound<'py, PyAny>> {
464 core_schema.call_method1("nullable_schema", (schema,))
465}
466
467pub fn nullable_field<'py>(
468 core_schema: &Bound<'py, PyAny>,
469 schema: &Bound<'py, PyAny>,
470 required: bool,
471) -> PyResult<Bound<'py, PyAny>> {
472 let schema = nullable_schema(core_schema, schema)?;
473 typed_dict_field(core_schema, &schema, required)
474}
475
476pub fn list_schema<'py>(
477 core_schema: &Bound<'py, PyAny>,
478 item_schema: &Bound<'py, PyAny>,
479) -> PyResult<Bound<'py, PyAny>> {
480 core_schema.call_method1("list_schema", (item_schema,))
481}
482
483pub fn required_dict_item<'py>(
484 dict: &Bound<'py, PyDict>,
485 key: &str,
486) -> PyResult<Bound<'py, PyAny>> {
487 dict.get_item(key)?
488 .ok_or_else(|| PyValueError::new_err(format!("Missing required key: {key}")))
489}
490
491pub fn py_to_string(value: &Bound<'_, PyAny>) -> PyResult<String> {
492 Ok(value.str()?.to_str()?.to_string())
493}
494
495pub fn get_required_field<'py>(
496 value: &Bound<'py, PyAny>,
497 key: &str,
498) -> PyResult<Bound<'py, PyAny>> {
499 if let Ok(dict) = value.cast::<PyDict>() {
500 required_dict_item(dict, key)
501 } else {
502 value.getattr(key)
503 }
504}
505
506pub fn get_optional_field<'py>(
507 value: &Bound<'py, PyAny>,
508 key: &str,
509) -> PyResult<Option<Bound<'py, PyAny>>> {
510 if let Ok(dict) = value.cast::<PyDict>() {
511 dict.get_item(key)
512 } else if value.hasattr(key)? {
513 Ok(Some(value.getattr(key)?))
514 } else {
515 Ok(None)
516 }
517}
518
519pub fn parse_optional<T>(value: Option<Bound<'_, PyAny>>) -> PyResult<Option<T>>
520where
521 for<'a> T: FromPyObject<'a, 'a>,
522 for<'a> <T as FromPyObject<'a, 'a>>::Error: Into<PyErr>,
523{
524 match value {
525 Some(value) if value.is_none() => Ok(None),
526 Some(value) => value.extract().map(Some).map_err(Into::into),
527 None => Ok(None),
528 }
529}
530
531pub fn parse_required_list<T, F>(value: &Bound<'_, PyAny>, key: &str, parse: F) -> PyResult<Vec<T>>
532where
533 F: for<'py> Fn(&Bound<'py, PyAny>) -> PyResult<T>,
534{
535 get_required_field(value, key)?
536 .try_iter()?
537 .map(|item| {
538 let item = item?;
539 parse(&item)
540 })
541 .collect()
542}
543
544pub fn parse_required_option<T>(value: &Bound<'_, PyAny>, key: &str) -> PyResult<Option<T>>
545where
546 for<'a> T: FromPyObject<'a, 'a>,
547 for<'a> <T as FromPyObject<'a, 'a>>::Error: Into<PyErr>,
548{
549 get_required_field(value, key)?
550 .extract::<Option<T>>()
551 .map_err(Into::into)
552}
553
554pub fn model_core_schema(
555 cls: &Bound<'_, PyType>,
556 validation_schema: &Bound<'_, PyAny>,
557 serialization_schema: &Bound<'_, PyAny>,
558) -> PyResult<Py<PyAny>> {
559 let py = cls.py();
560 let core_schema = py.import("pydantic_core")?.getattr("core_schema")?;
561 let validator = cls.getattr("_pydantic_validate")?;
562 let serializer = cls.getattr("_pydantic_serialize")?;
563 let instance_schema = core_schema.call_method1("is_instance_schema", (cls,))?;
564 let python_schema = union_schema(&core_schema, [instance_schema, validation_schema.clone()])?;
565 let serializer_kwargs = PyDict::new(py);
566 serializer_kwargs.set_item("return_schema", serialization_schema)?;
567 let serializer_schema = core_schema.call_method(
568 "plain_serializer_function_ser_schema",
569 (serializer,),
570 Some(&serializer_kwargs),
571 )?;
572 let kwargs = PyDict::new(py);
573 kwargs.set_item("json_schema_input_schema", validation_schema)?;
574 kwargs.set_item("serialization", serializer_schema)?;
575 let schema = core_schema.call_method(
576 "no_info_after_validator_function",
577 (validator, python_schema),
578 Some(&kwargs),
579 )?;
580 Ok(schema.unbind())
581}
582
583pub fn model_json_schema(
584 cls: &Bound<'_, PyType>,
585 kwargs: Option<&Bound<'_, PyDict>>,
586) -> PyResult<Py<PyAny>> {
587 let adapter = cls
588 .py()
589 .import("pydantic")?
590 .getattr("TypeAdapter")?
591 .call1((cls,))?;
592 Ok(adapter.call_method("json_schema", (), kwargs)?.unbind())
593}
594
595pub fn reject_model_kwargs(kwargs: Option<&Bound<'_, PyDict>>, method: &str) -> PyResult<()> {
596 if let Some(kwargs) = kwargs
597 && !kwargs.is_empty()
598 {
599 return Err(PyValueError::new_err(format!(
600 "{method} keyword arguments are not supported by asic_rs models"
601 )));
602 }
603 Ok(())
604}