Skip to main content

asic_rs_pydantic/
lib.rs

1use std::{fmt::Display, net::IpAddr, str::FromStr, time::Duration};
2
3use macaddr::MacAddr;
4use measurements::{AngularVelocity, Frequency, Power, Temperature, Voltage};
5use pyo3::{
6    PyTypeInfo,
7    exceptions::PyValueError,
8    prelude::*,
9    types::{PyAnyMethods, PyBool, PyDict, PyDictMethods, PyList, PyListMethods, PyType},
10};
11
12pub use asic_rs_pydantic_macros::{
13    PyPydanticData, PyPydanticEnum, PyPydanticModel, PyPydanticTaggedEnum, PyPydanticTaggedUnion,
14    py_pydantic_model,
15};
16
/// Selects which flavour of pydantic-core schema is being built.
///
/// [`PyPydanticType::pydantic_schema`] receives this so a type can describe the
/// input it accepts separately from the output it emits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PydanticSchemaMode {
    /// Schema used when validating incoming Python values.
    Validation,
    /// Schema used to describe serialized output.
    Serialization,
}
22
23pub trait PyPydanticType: Sized {
24    fn pydantic_schema<'py>(
25        core_schema: &Bound<'py, PyAny>,
26        mode: PydanticSchemaMode,
27    ) -> PyResult<Bound<'py, PyAny>>;
28
29    fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self>;
30
31    fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>>;
32
33    fn to_pydantic_repr_value(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
34        self.to_pydantic_data(py)
35    }
36}
37
38pub trait PyPydanticStringEnum: Clone + Display + FromStr + PyTypeInfo + Sized {
39    const PYDANTIC_VALUES: &'static [&'static str];
40
41    fn to_pydantic_enum_repr_value(&self, py: Python<'_>) -> PyResult<Py<PyAny>>;
42}
43
44impl<T> PyPydanticType for T
45where
46    T: PyPydanticStringEnum + for<'py> FromPyObject<'py, 'py>,
47    <T as FromStr>::Err: Display,
48    for<'py> <T as FromPyObject<'py, 'py>>::Error: Into<PyErr>,
49{
50    fn pydantic_schema<'py>(
51        core_schema: &Bound<'py, PyAny>,
52        mode: PydanticSchemaMode,
53    ) -> PyResult<Bound<'py, PyAny>> {
54        let string_schema = literal_schema(core_schema, T::PYDANTIC_VALUES)?;
55        if mode == PydanticSchemaMode::Serialization {
56            return Ok(string_schema);
57        }
58
59        let py = core_schema.py();
60        let instance_schema =
61            core_schema.call_method1("is_instance_schema", (py.get_type::<T>(),))?;
62        union_schema(core_schema, [instance_schema, string_schema])
63    }
64
65    fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
66        if let Ok(value) = value.extract::<T>() {
67            return Ok(value);
68        }
69
70        let value = value.extract::<String>()?;
71        T::from_str(&value).map_err(|error| PyValueError::new_err(error.to_string()))
72    }
73
74    fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
75        Ok(self.to_string().into_pyobject(py)?.into_any().unbind())
76    }
77
78    fn to_pydantic_repr_value(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
79        self.to_pydantic_enum_repr_value(py)
80    }
81}
82
83macro_rules! impl_pydantic_python_value {
84    ($schema:literal; $($ty:ty),* $(,)?) => {
85        $(
86            impl PyPydanticType for $ty {
87                fn pydantic_schema<'py>(
88                    core_schema: &Bound<'py, PyAny>,
89                    _mode: PydanticSchemaMode,
90                ) -> PyResult<Bound<'py, PyAny>> {
91                    core_schema.call_method0($schema)
92                }
93
94                fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
95                    value.extract()
96                }
97
98                fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
99                    Ok((*self).into_pyobject(py)?.clone().into_any().unbind())
100                }
101            }
102        )*
103    };
104}
105
106impl_pydantic_python_value!("int_schema"; i8, i16, i32, i64, isize, u8, u16, u32, u64, usize);
107impl_pydantic_python_value!("float_schema"; f32, f64);
108
109impl PyPydanticType for bool {
110    fn pydantic_schema<'py>(
111        core_schema: &Bound<'py, PyAny>,
112        _mode: PydanticSchemaMode,
113    ) -> PyResult<Bound<'py, PyAny>> {
114        core_schema.call_method0("bool_schema")
115    }
116
117    fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
118        value.extract()
119    }
120
121    fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
122        Ok(PyBool::new(py, *self).to_owned().into_any().unbind())
123    }
124}
125
126impl PyPydanticType for String {
127    fn pydantic_schema<'py>(
128        core_schema: &Bound<'py, PyAny>,
129        _mode: PydanticSchemaMode,
130    ) -> PyResult<Bound<'py, PyAny>> {
131        core_schema.call_method0("str_schema")
132    }
133
134    fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
135        value.extract()
136    }
137
138    fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
139        Ok(self.clone().into_pyobject(py)?.into_any().unbind())
140    }
141}
142
143impl PyPydanticType for IpAddr {
144    fn pydantic_schema<'py>(
145        core_schema: &Bound<'py, PyAny>,
146        _mode: PydanticSchemaMode,
147    ) -> PyResult<Bound<'py, PyAny>> {
148        core_schema.call_method0("str_schema")
149    }
150
151    fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
152        if let Ok(ip) = value.extract::<Self>() {
153            return Ok(ip);
154        }
155        value
156            .extract::<String>()?
157            .parse()
158            .map_err(|error| PyValueError::new_err(format!("Invalid IP address: {error}")))
159    }
160
161    fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
162        Ok(self.to_string().into_pyobject(py)?.into_any().unbind())
163    }
164}
165
166impl PyPydanticType for MacAddr {
167    fn pydantic_schema<'py>(
168        core_schema: &Bound<'py, PyAny>,
169        _mode: PydanticSchemaMode,
170    ) -> PyResult<Bound<'py, PyAny>> {
171        core_schema.call_method0("str_schema")
172    }
173
174    fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
175        value
176            .extract::<String>()?
177            .parse()
178            .map_err(|error| PyValueError::new_err(format!("Invalid MAC address: {error}")))
179    }
180
181    fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
182        Ok(self.to_string().into_pyobject(py)?.into_any().unbind())
183    }
184}
185
/// Converts a `Duration` into floating-point seconds, including the fractional part.
///
/// The previous `as_secs() as f64` discarded sub-second precision. A 1.5 s
/// duration serialized as `1.0`, even though validation (`from_secs_f64`)
/// accepts fractional seconds, so the round-trip was lossy.
fn duration_to_seconds(duration: Duration) -> f64 {
    duration.as_secs_f64()
}
189
190impl PyPydanticType for Duration {
191    fn pydantic_schema<'py>(
192        core_schema: &Bound<'py, PyAny>,
193        mode: PydanticSchemaMode,
194    ) -> PyResult<Bound<'py, PyAny>> {
195        match mode {
196            PydanticSchemaMode::Validation => core_schema.call_method0("any_schema"),
197            PydanticSchemaMode::Serialization => core_schema.call_method0("float_schema"),
198        }
199    }
200
201    fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
202        if let Ok(duration) = value.extract::<Self>() {
203            return Ok(duration);
204        }
205        if let Ok(seconds) = value.extract::<f64>()
206            && seconds.is_finite()
207            && seconds >= 0.0
208        {
209            return Ok(Self::from_secs_f64(seconds));
210        }
211        if let Ok(dict) = value.cast::<PyDict>() {
212            let secs = required_dict_item(dict, "secs")?.extract::<u64>()?;
213            return Ok(Self::from_secs(secs));
214        }
215        Err(PyValueError::new_err(
216            "Expected duration as timedelta, non-negative seconds, or {secs} dict",
217        ))
218    }
219
220    fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
221        Ok(duration_to_seconds(*self)
222            .into_pyobject(py)?
223            .into_any()
224            .unbind())
225    }
226}
227
228macro_rules! impl_pydantic_measurement {
229    ($ty:ty, $from_unit:ident, $as_unit:ident) => {
230        impl PyPydanticType for $ty {
231            fn pydantic_schema<'py>(
232                core_schema: &Bound<'py, PyAny>,
233                _mode: PydanticSchemaMode,
234            ) -> PyResult<Bound<'py, PyAny>> {
235                core_schema.call_method0("float_schema")
236            }
237
238            fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
239                Ok(Self::$from_unit(value.extract::<f64>()?))
240            }
241
242            fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
243                Ok(self.$as_unit().into_pyobject(py)?.into_any().unbind())
244            }
245        }
246    };
247}
248
249impl_pydantic_measurement!(AngularVelocity, from_rpm, as_rpm);
250impl_pydantic_measurement!(Frequency, from_megahertz, as_megahertz);
251impl_pydantic_measurement!(Power, from_watts, as_watts);
252impl_pydantic_measurement!(Temperature, from_celsius, as_celsius);
253impl_pydantic_measurement!(Voltage, from_volts, as_volts);
254
255impl<T> PyPydanticType for Option<T>
256where
257    T: PyPydanticType,
258{
259    fn pydantic_schema<'py>(
260        core_schema: &Bound<'py, PyAny>,
261        mode: PydanticSchemaMode,
262    ) -> PyResult<Bound<'py, PyAny>> {
263        let inner = T::pydantic_schema(core_schema, mode)?;
264        nullable_schema(core_schema, &inner)
265    }
266
267    fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
268        if value.is_none() {
269            Ok(None)
270        } else {
271            T::from_pydantic(value).map(Some)
272        }
273    }
274
275    fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
276        if let Some(value) = self {
277            value.to_pydantic_data(py)
278        } else {
279            Ok(py.None())
280        }
281    }
282
283    fn to_pydantic_repr_value(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
284        if let Some(value) = self {
285            value.to_pydantic_repr_value(py)
286        } else {
287            Ok(py.None())
288        }
289    }
290}
291
292impl<T> PyPydanticType for Vec<T>
293where
294    T: PyPydanticType,
295{
296    fn pydantic_schema<'py>(
297        core_schema: &Bound<'py, PyAny>,
298        mode: PydanticSchemaMode,
299    ) -> PyResult<Bound<'py, PyAny>> {
300        let inner = T::pydantic_schema(core_schema, mode)?;
301        list_schema(core_schema, &inner)
302    }
303
304    fn from_pydantic(value: &Bound<'_, PyAny>) -> PyResult<Self> {
305        value
306            .try_iter()?
307            .map(|item| {
308                let item = item?;
309                T::from_pydantic(&item)
310            })
311            .collect()
312    }
313
314    fn to_pydantic_data(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
315        let list = PyList::empty(py);
316        for value in self {
317            list.append(value.to_pydantic_data(py)?)?;
318        }
319        Ok(list.into_any().unbind())
320    }
321
322    fn to_pydantic_repr_value(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
323        let list = PyList::empty(py);
324        for value in self {
325            list.append(value.to_pydantic_repr_value(py)?)?;
326        }
327        Ok(list.into_any().unbind())
328    }
329}
330
331pub fn typed_dict_field<'py>(
332    core_schema: &Bound<'py, PyAny>,
333    schema: &Bound<'py, PyAny>,
334    required: bool,
335) -> PyResult<Bound<'py, PyAny>> {
336    let kwargs = PyDict::new(core_schema.py());
337    kwargs.set_item("required", required)?;
338    core_schema.call_method("typed_dict_field", (schema,), Some(&kwargs))
339}
340
341pub fn typed_dict_schema<'py>(
342    core_schema: &Bound<'py, PyAny>,
343    fields: &Bound<'py, PyDict>,
344    ref_name: Option<&str>,
345) -> PyResult<Bound<'py, PyAny>> {
346    let kwargs = PyDict::new(core_schema.py());
347    if let Some(ref_name) = ref_name {
348        kwargs.set_item("ref", ref_name)?;
349    }
350    core_schema.call_method("typed_dict_schema", (fields,), Some(&kwargs))
351}
352
/// Builds a pydantic-core `typed_dict_schema` from a list of field entries.
///
/// Must be expanded inside a function returning `PyResult<_>`, because every
/// Python call is propagated with `?`. `$core_schema` is re-expanded once per
/// field, so pass a cheap expression such as a reference binding.
///
/// Entry forms (comma separated, trailing comma optional):
/// - `key => required(schema)`: required field.
/// - `key => required_if(schema, is_required)`: requiredness given as a `bool` expression.
/// - `key => nullable(schema)`: required field whose schema is wrapped in `nullable_schema`.
/// - `key => nullable_if(schema, is_required)`: nullable field; requiredness given as a `bool` expression.
///
/// Evaluates to the `PyResult` returned by `typed_dict_schema`, with `$ref_name`
/// passed as the schema `ref`.
#[macro_export]
macro_rules! pydantic_typed_dict_schema {
    ($core_schema:expr, $ref_name:expr, { $($entries:tt)* }) => {{
        let fields = ::pyo3::types::PyDict::new($core_schema.py());
        // A trailing comma is appended so every entry, including the last, is
        // matched by the `entry, rest...` recursion below.
        $crate::pydantic_typed_dict_schema!(@fields fields, $core_schema, $($entries)*,);
        $crate::typed_dict_schema($core_schema, &fields, Some($ref_name))
    }};

    // Recursion terminator: only an optional stray comma remains.
    (@fields $fields:ident, $core_schema:expr, $(,)?) => {};

    // Peel one `key => kind(args)` entry, insert it, then recurse on the rest.
    // The `@insert` rules below decide which kinds are valid.
    (@fields $fields:ident, $core_schema:expr, $field:expr => $kind:ident ( $($args:tt)* ), $($rest:tt)*) => {{
        $crate::pydantic_typed_dict_schema!(@insert $fields, $core_schema, $field, $kind($($args)*));
        $crate::pydantic_typed_dict_schema!(@fields $fields, $core_schema, $($rest)*);
    }};

    (@insert $fields:ident, $core_schema:expr, $field:expr, required($schema:expr)) => {{
        $fields.set_item(
            $field,
            $crate::typed_dict_field($core_schema, &$schema, true)?,
        )?;
    }};

    (@insert $fields:ident, $core_schema:expr, $field:expr, required_if($schema:expr, $required:expr)) => {{
        $fields.set_item(
            $field,
            $crate::typed_dict_field($core_schema, &$schema, $required)?,
        )?;
    }};

    (@insert $fields:ident, $core_schema:expr, $field:expr, nullable($schema:expr)) => {{
        $fields.set_item(
            $field,
            $crate::nullable_field($core_schema, &$schema, true)?,
        )?;
    }};

    (@insert $fields:ident, $core_schema:expr, $field:expr, nullable_if($schema:expr, $required:expr)) => {{
        $fields.set_item(
            $field,
            $crate::nullable_field($core_schema, &$schema, $required)?,
        )?;
    }};
}
412
413pub fn tagged_union_schema<'py, I>(
414    core_schema: &Bound<'py, PyAny>,
415    choices: I,
416    discriminator: &str,
417    ref_name: Option<&str>,
418) -> PyResult<Bound<'py, PyAny>>
419where
420    I: IntoIterator<Item = (&'static str, Bound<'py, PyAny>)>,
421{
422    let py = core_schema.py();
423    let choices_dict = PyDict::new(py);
424    for (tag, schema) in choices {
425        choices_dict.set_item(tag, schema)?;
426    }
427    let kwargs = PyDict::new(py);
428    if let Some(ref_name) = ref_name {
429        kwargs.set_item("ref", ref_name)?;
430    }
431    core_schema.call_method(
432        "tagged_union_schema",
433        (choices_dict, discriminator),
434        Some(&kwargs),
435    )
436}
437
438pub fn union_schema<'py, I>(
439    core_schema: &Bound<'py, PyAny>,
440    choices: I,
441) -> PyResult<Bound<'py, PyAny>>
442where
443    I: IntoIterator<Item = Bound<'py, PyAny>>,
444{
445    let choices_list = PyList::empty(core_schema.py());
446    for schema in choices {
447        choices_list.append(schema)?;
448    }
449    core_schema.call_method1("union_schema", (choices_list,))
450}
451
452pub fn literal_schema<'py>(
453    core_schema: &Bound<'py, PyAny>,
454    values: &[&str],
455) -> PyResult<Bound<'py, PyAny>> {
456    let values = PyList::new(core_schema.py(), values)?;
457    core_schema.call_method1("literal_schema", (values,))
458}
459
460pub fn nullable_schema<'py>(
461    core_schema: &Bound<'py, PyAny>,
462    schema: &Bound<'py, PyAny>,
463) -> PyResult<Bound<'py, PyAny>> {
464    core_schema.call_method1("nullable_schema", (schema,))
465}
466
467pub fn nullable_field<'py>(
468    core_schema: &Bound<'py, PyAny>,
469    schema: &Bound<'py, PyAny>,
470    required: bool,
471) -> PyResult<Bound<'py, PyAny>> {
472    let schema = nullable_schema(core_schema, schema)?;
473    typed_dict_field(core_schema, &schema, required)
474}
475
476pub fn list_schema<'py>(
477    core_schema: &Bound<'py, PyAny>,
478    item_schema: &Bound<'py, PyAny>,
479) -> PyResult<Bound<'py, PyAny>> {
480    core_schema.call_method1("list_schema", (item_schema,))
481}
482
483pub fn required_dict_item<'py>(
484    dict: &Bound<'py, PyDict>,
485    key: &str,
486) -> PyResult<Bound<'py, PyAny>> {
487    dict.get_item(key)?
488        .ok_or_else(|| PyValueError::new_err(format!("Missing required key: {key}")))
489}
490
491pub fn py_to_string(value: &Bound<'_, PyAny>) -> PyResult<String> {
492    Ok(value.str()?.to_str()?.to_string())
493}
494
495pub fn get_required_field<'py>(
496    value: &Bound<'py, PyAny>,
497    key: &str,
498) -> PyResult<Bound<'py, PyAny>> {
499    if let Ok(dict) = value.cast::<PyDict>() {
500        required_dict_item(dict, key)
501    } else {
502        value.getattr(key)
503    }
504}
505
506pub fn get_optional_field<'py>(
507    value: &Bound<'py, PyAny>,
508    key: &str,
509) -> PyResult<Option<Bound<'py, PyAny>>> {
510    if let Ok(dict) = value.cast::<PyDict>() {
511        dict.get_item(key)
512    } else if value.hasattr(key)? {
513        Ok(Some(value.getattr(key)?))
514    } else {
515        Ok(None)
516    }
517}
518
519pub fn parse_optional<T>(value: Option<Bound<'_, PyAny>>) -> PyResult<Option<T>>
520where
521    for<'a> T: FromPyObject<'a, 'a>,
522    for<'a> <T as FromPyObject<'a, 'a>>::Error: Into<PyErr>,
523{
524    match value {
525        Some(value) if value.is_none() => Ok(None),
526        Some(value) => value.extract().map(Some).map_err(Into::into),
527        None => Ok(None),
528    }
529}
530
531pub fn parse_required_list<T, F>(value: &Bound<'_, PyAny>, key: &str, parse: F) -> PyResult<Vec<T>>
532where
533    F: for<'py> Fn(&Bound<'py, PyAny>) -> PyResult<T>,
534{
535    get_required_field(value, key)?
536        .try_iter()?
537        .map(|item| {
538            let item = item?;
539            parse(&item)
540        })
541        .collect()
542}
543
544pub fn parse_required_option<T>(value: &Bound<'_, PyAny>, key: &str) -> PyResult<Option<T>>
545where
546    for<'a> T: FromPyObject<'a, 'a>,
547    for<'a> <T as FromPyObject<'a, 'a>>::Error: Into<PyErr>,
548{
549    get_required_field(value, key)?
550        .extract::<Option<T>>()
551        .map_err(Into::into)
552}
553
554pub fn model_core_schema(
555    cls: &Bound<'_, PyType>,
556    validation_schema: &Bound<'_, PyAny>,
557    serialization_schema: &Bound<'_, PyAny>,
558) -> PyResult<Py<PyAny>> {
559    let py = cls.py();
560    let core_schema = py.import("pydantic_core")?.getattr("core_schema")?;
561    let validator = cls.getattr("_pydantic_validate")?;
562    let serializer = cls.getattr("_pydantic_serialize")?;
563    let instance_schema = core_schema.call_method1("is_instance_schema", (cls,))?;
564    let python_schema = union_schema(&core_schema, [instance_schema, validation_schema.clone()])?;
565    let serializer_kwargs = PyDict::new(py);
566    serializer_kwargs.set_item("return_schema", serialization_schema)?;
567    let serializer_schema = core_schema.call_method(
568        "plain_serializer_function_ser_schema",
569        (serializer,),
570        Some(&serializer_kwargs),
571    )?;
572    let kwargs = PyDict::new(py);
573    kwargs.set_item("json_schema_input_schema", validation_schema)?;
574    kwargs.set_item("serialization", serializer_schema)?;
575    let schema = core_schema.call_method(
576        "no_info_after_validator_function",
577        (validator, python_schema),
578        Some(&kwargs),
579    )?;
580    Ok(schema.unbind())
581}
582
583pub fn model_json_schema(
584    cls: &Bound<'_, PyType>,
585    kwargs: Option<&Bound<'_, PyDict>>,
586) -> PyResult<Py<PyAny>> {
587    let adapter = cls
588        .py()
589        .import("pydantic")?
590        .getattr("TypeAdapter")?
591        .call1((cls,))?;
592    Ok(adapter.call_method("json_schema", (), kwargs)?.unbind())
593}
594
595pub fn reject_model_kwargs(kwargs: Option<&Bound<'_, PyDict>>, method: &str) -> PyResult<()> {
596    if let Some(kwargs) = kwargs
597        && !kwargs.is_empty()
598    {
599        return Err(PyValueError::new_err(format!(
600            "{method} keyword arguments are not supported by asic_rs models"
601        )));
602    }
603    Ok(())
604}