numcodecs_python/
schema.rs

1use std::{
2    borrow::Cow,
3    collections::{hash_map::Entry, HashMap},
4};
5
6use pyo3::{intern, prelude::*, sync::GILOnceCell};
7use pythonize::{depythonize, PythonizeError};
8use schemars::Schema;
9use serde_json::{Map, Value};
10use thiserror::Error;
11
12use crate::{export::RustCodec, PyCodecClass};
13
/// Lazily imports the Python module `$module`, follows the optional chain of
/// `$path` attribute lookups, and caches the resulting object in a
/// `GILOnceCell` so the import and lookups only happen on first use.
///
/// Every expansion of the macro owns its own cache. The cached object is
/// returned bound to the GIL token `$py`.
macro_rules! once {
    ($py:ident, $module:literal $(, $path:literal)*) => {{
        fn cached(py: Python) -> Result<&Bound<PyAny>, PyErr> {
            static CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

            let object = CACHE.get_or_try_init(py, || -> Result<Py<PyAny>, PyErr> {
                let object = py.import(intern!(py, $module))?.into_any();
                $(let object = object.getattr(intern!(py, $path))?;)*
                Ok(object.unbind())
            })?;

            Ok(object.bind(py))
        }

        cached($py)
    }};
}
29
30pub fn schema_from_codec_class(
31    py: Python,
32    class: &Bound<PyCodecClass>,
33) -> Result<Schema, SchemaError> {
34    if let Ok(schema) = class.getattr(intern!(py, RustCodec::SCHEMA_ATTRIBUTE)) {
35        return depythonize(&schema)
36            .map_err(|err| SchemaError::InvalidCachedJsonSchema { source: err });
37    }
38
39    let mut schema = Schema::default();
40
41    {
42        let schema = schema.ensure_object();
43
44        schema.insert(String::from("type"), Value::String(String::from("object")));
45
46        if let Ok(init) = class.getattr(intern!(py, "__init__")) {
47            let mut properties = Map::new();
48            let mut additional_properties = false;
49            let mut required = Vec::new();
50
51            let object_init = once!(py, "builtins", "object", "__init__")?;
52            let signature = once!(py, "inspect", "signature")?;
53            let empty_parameter = once!(py, "inspect", "Parameter", "empty")?;
54            let args_parameter = once!(py, "inspect", "Parameter", "VAR_POSITIONAL")?;
55            let kwargs_parameter = once!(py, "inspect", "Parameter", "VAR_KEYWORD")?;
56
57            for (i, param) in signature
58                .call1((&init,))?
59                .getattr(intern!(py, "parameters"))?
60                .call_method0(intern!(py, "items"))?
61                .try_iter()?
62                .enumerate()
63            {
64                let (name, param): (String, Bound<PyAny>) = param?.extract()?;
65
66                if i == 0 && name == "self" {
67                    continue;
68                }
69
70                let kind = param.getattr(intern!(py, "kind"))?;
71
72                if kind.eq(args_parameter)? && !init.eq(object_init)? {
73                    return Err(SchemaError::ArgsParameterInSignature);
74                }
75
76                if kind.eq(kwargs_parameter)? {
77                    additional_properties = true;
78                } else {
79                    let default = param.getattr(intern!(py, "default"))?;
80
81                    let mut parameter = Map::new();
82
83                    if default.eq(empty_parameter)? {
84                        required.push(Value::String(name.clone()));
85                    } else {
86                        let default = depythonize(&default).map_err(|err| {
87                            SchemaError::InvalidParameterDefault {
88                                name: name.clone(),
89                                source: err,
90                            }
91                        })?;
92                        parameter.insert(String::from("default"), default);
93                    }
94
95                    properties.insert(name, Value::Object(parameter));
96                }
97            }
98
99            schema.insert(
100                String::from("additionalProperties"),
101                Value::Bool(additional_properties),
102            );
103            schema.insert(String::from("properties"), Value::Object(properties));
104            schema.insert(String::from("required"), Value::Array(required));
105        } else {
106            schema.insert(String::from("additionalProperties"), Value::Bool(true));
107        }
108
109        if let Ok(doc) = class.getattr(intern!(py, "__doc__")) {
110            if !doc.is_none() {
111                let doc: String = doc
112                    .extract()
113                    .map_err(|err| SchemaError::InvalidClassDocs { source: err })?;
114                schema.insert(String::from("description"), Value::String(doc));
115            }
116        }
117
118        let name = class
119            .getattr(intern!(py, "__name__"))
120            .and_then(|name| name.extract())
121            .map_err(|err| SchemaError::InvalidClassName { source: err })?;
122        schema.insert(String::from("title"), Value::String(name));
123
124        schema.insert(
125            String::from("$schema"),
126            Value::String(String::from("https://json-schema.org/draft/2020-12/schema")),
127        );
128    }
129
130    Ok(schema)
131}
132
133pub fn docs_from_schema(schema: &Schema, codec_id: &str) -> Option<String> {
134    let parameters = parameters_from_schema(schema);
135    let schema = schema.as_object()?;
136
137    let mut docs = String::new();
138
139    docs.push_str("# ");
140    docs.push_str(codec_id);
141
142    if let Some(Value::String(title)) = schema.get("title") {
143        docs.push_str(" (");
144        docs.push_str(title);
145        docs.push(')');
146    }
147
148    docs.push_str("\n\n");
149
150    if let Some(Value::String(description)) = schema.get("description") {
151        docs.push_str(description);
152        docs.push_str("\n\n");
153    }
154
155    docs.push_str("## Parameters\n\n");
156
157    for parameter in &parameters.named {
158        docs.push_str(" - ");
159        docs.push_str(parameter.name);
160
161        docs.push_str(" (");
162
163        if parameter.required {
164            docs.push_str("required");
165        } else {
166            docs.push_str("optional");
167        }
168
169        if let Some(default) = parameter.default {
170            docs.push_str(", default = `");
171            docs.push_str(&format!("{default}"));
172            docs.push('`');
173        }
174
175        docs.push(')');
176
177        if let Some(info) = &parameter.docs {
178            docs.push_str(": ");
179            docs.push_str(&info.replace('\n', "\n   "));
180        }
181
182        docs.push('\n');
183    }
184
185    if parameters.named.is_empty() {
186        if parameters.additional {
187            docs.push_str("This codec takes *any* parameters.");
188        } else {
189            docs.push_str("This codec does *not* take any parameters.");
190        }
191    } else if parameters.additional {
192        docs.push_str("\nThis codec takes *any* additional parameters.");
193    }
194
195    docs.truncate(docs.trim_end().len());
196
197    Some(docs)
198}
199
200pub fn signature_from_schema(schema: &Schema) -> String {
201    let parameters = parameters_from_schema(schema);
202
203    let mut signature = String::new();
204    signature.push_str("self");
205
206    for parameter in parameters.named {
207        signature.push_str(", ");
208        signature.push_str(parameter.name);
209
210        if let Some(default) = parameter.default {
211            signature.push('=');
212            signature.push_str(&format!("{default}"));
213        } else if !parameter.required {
214            signature.push_str("=None");
215        }
216    }
217
218    if parameters.additional {
219        signature.push_str(", **kwargs");
220    }
221
222    signature
223}
224
225#[allow(clippy::too_many_lines)] // FIXME
226fn parameters_from_schema(schema: &Schema) -> Parameters {
227    // schema = true means that any parameters are allowed
228    if schema.as_bool() == Some(true) {
229        return Parameters {
230            named: Vec::new(),
231            additional: true,
232        };
233    }
234
235    // schema = false means that no config is valid
236    // we approximate that by saying that no parameters are allowed
237    let Some(schema) = schema.as_object() else {
238        return Parameters {
239            named: Vec::new(),
240            additional: false,
241        };
242    };
243
244    let mut parameters = Vec::new();
245
246    let required = match schema.get("required") {
247        Some(Value::Array(required)) => &**required,
248        _ => &[],
249    };
250
251    // extract the top-level parameters
252    if let Some(Value::Object(properties)) = schema.get("properties") {
253        for (name, parameter) in properties {
254            parameters.push(Parameter::new(name, parameter, required));
255        }
256    }
257
258    let mut additional = false;
259
260    extend_parameters_from_one_of_schema(schema, &mut parameters, &mut additional);
261
262    // iterate over allOf to handle flattened enums
263    if let Some(Value::Array(all)) = schema.get("allOf") {
264        for variant in all {
265            if let Some(variant) = variant.as_object() {
266                extend_parameters_from_one_of_schema(variant, &mut parameters, &mut additional);
267            }
268        }
269    }
270
271    // sort parameters by name and so that required parameters come first
272    parameters.sort_by_key(|p| (!p.required, p.name));
273
274    additional = match (
275        schema.get("additionalProperties"),
276        schema.get("unevaluatedProperties"),
277    ) {
278        (Some(Value::Bool(false)), None) => additional,
279        (None | Some(Value::Bool(false)), Some(Value::Bool(false))) => false,
280        _ => true,
281    };
282
283    Parameters {
284        named: parameters,
285        additional,
286    }
287}
288
289fn extend_parameters_from_one_of_schema<'a>(
290    schema: &'a Map<String, Value>,
291    parameters: &mut Vec<Parameter<'a>>,
292    additional: &mut bool,
293) {
294    // iterate over oneOf to handle top-level or flattened enums
295    if let Some(Value::Array(variants)) = schema.get("oneOf") {
296        let mut variant_parameters = HashMap::new();
297
298        for (generation, schema) in variants.iter().enumerate() {
299            // if any variant allows additional parameters, the top-level also
300            //  allows additional parameters
301            #[allow(clippy::unnested_or_patterns)]
302            if let Some(schema) = schema.as_object() {
303                *additional |= !matches!(
304                    (
305                        schema.get("additionalProperties"),
306                        schema.get("unevaluatedProperties")
307                    ),
308                    (Some(Value::Bool(false)), None)
309                        | (None, Some(Value::Bool(false)))
310                        | (Some(Value::Bool(false)), Some(Value::Bool(false)))
311                );
312            }
313
314            let required = match schema.get("required") {
315                Some(Value::Array(required)) => &**required,
316                _ => &[],
317            };
318            let variant_docs = match schema.get("description") {
319                Some(Value::String(docs)) => Some(docs.as_str()),
320                _ => None,
321            };
322
323            // extract the per-variant parameters and check for tag parameters
324            if let Some(Value::Object(properties)) = schema.get("properties") {
325                for (name, parameter) in properties {
326                    match variant_parameters.entry(name) {
327                        Entry::Vacant(entry) => {
328                            entry.insert(VariantParameter::new(
329                                generation,
330                                name,
331                                parameter,
332                                required,
333                                variant_docs,
334                            ));
335                        }
336                        Entry::Occupied(mut entry) => {
337                            entry.get_mut().merge(
338                                generation,
339                                name,
340                                parameter,
341                                required,
342                                variant_docs,
343                            );
344                        }
345                    }
346                }
347            }
348
349            // ensure that only parameters in all variants are required or tags
350            for parameter in variant_parameters.values_mut() {
351                parameter.update_generation(generation);
352            }
353        }
354
355        // merge the variant parameters into the top-level parameters
356        parameters.extend(
357            variant_parameters
358                .into_values()
359                .map(VariantParameter::into_parameter),
360        );
361    }
362}
363
364#[derive(Debug, Error)]
365pub enum SchemaError {
366    #[error("codec class' cached config schema is invalid")]
367    InvalidCachedJsonSchema { source: PythonizeError },
368    #[error("extracting the codec signature failed")]
369    SignatureExtraction {
370        #[from]
371        source: PyErr,
372    },
373    #[error("codec's signature must not contain an `*args` parameter")]
374    ArgsParameterInSignature,
375    #[error("{name} parameter's default value is invalid")]
376    InvalidParameterDefault {
377        name: String,
378        source: PythonizeError,
379    },
380    #[error("codec class's `__doc__` must be a string")]
381    InvalidClassDocs { source: PyErr },
382    #[error("codec class must have a string `__name__`")]
383    InvalidClassName { source: PyErr },
384}
385
386struct Parameters<'a> {
387    named: Vec<Parameter<'a>>,
388    additional: bool,
389}
390
391struct Parameter<'a> {
392    name: &'a str,
393    required: bool,
394    default: Option<&'a Value>,
395    docs: Option<Cow<'a, str>>,
396}
397
398impl<'a> Parameter<'a> {
399    #[must_use]
400    pub fn new(name: &'a str, parameter: &'a Value, required: &[Value]) -> Self {
401        Self {
402            name,
403            required: required
404                .iter()
405                .any(|r| matches!(r, Value::String(n) if n == name)),
406            default: parameter.get("default"),
407            docs: match parameter.get("description") {
408                Some(Value::String(docs)) => Some(Cow::Borrowed(docs)),
409                _ => None,
410            },
411        }
412    }
413}
414
415struct VariantParameter<'a> {
416    generation: usize,
417    parameter: Parameter<'a>,
418    #[allow(clippy::type_complexity)]
419    tag_docs: Option<Vec<(&'a Value, Option<Cow<'a, str>>)>>,
420}
421
422impl<'a> VariantParameter<'a> {
423    #[must_use]
424    pub fn new(
425        generation: usize,
426        name: &'a str,
427        parameter: &'a Value,
428        required: &[Value],
429        variant_docs: Option<&'a str>,
430    ) -> Self {
431        let r#const = parameter.get("const");
432
433        let mut parameter = Parameter::new(name, parameter, required);
434        parameter.required &= generation == 0;
435
436        let tag_docs = match r#const {
437            // a tag parameter must be introduced in the first generation
438            Some(r#const) if generation == 0 => {
439                #[allow(clippy::or_fun_call)]
440                let docs = parameter.docs.take().or(variant_docs.map(Cow::Borrowed));
441                Some(vec![(r#const, docs)])
442            }
443            _ => None,
444        };
445
446        Self {
447            generation,
448            parameter,
449            tag_docs,
450        }
451    }
452
453    pub fn merge(
454        &mut self,
455        generation: usize,
456        name: &'a str,
457        parameter: &'a Value,
458        required: &[Value],
459        variant_docs: Option<&'a str>,
460    ) {
461        self.generation = generation;
462
463        let r#const = parameter.get("const");
464
465        let parameter = Parameter::new(name, parameter, required);
466
467        self.parameter.required &= parameter.required;
468        if self.parameter.default != parameter.default {
469            self.parameter.default = None;
470        }
471
472        if let Some(tag_docs) = &mut self.tag_docs {
473            // we're building docs for a tag-like parameter
474            if let Some(r#const) = r#const {
475                #[allow(clippy::or_fun_call)]
476                tag_docs.push((r#const, parameter.docs.or(variant_docs.map(Cow::Borrowed))));
477            } else {
478                // mixing tag and non-tag parameter => no docs
479                self.tag_docs = None;
480                self.parameter.docs = None;
481            }
482        } else {
483            // we're building docs for a normal parameter
484            if r#const.is_none() {
485                // we only accept always matching docs for normal parameters
486                if self.parameter.docs != parameter.docs {
487                    self.parameter.docs = None;
488                }
489            } else {
490                // mixing tag and non-tag parameter => no docs
491                self.tag_docs = None;
492            }
493        }
494    }
495
496    pub fn update_generation(&mut self, generation: usize) {
497        if self.generation < generation {
498            // required and tag parameters must appear in all generations
499            self.parameter.required = false;
500            self.tag_docs = None;
501        }
502    }
503
504    #[must_use]
505    pub fn into_parameter(mut self) -> Parameter<'a> {
506        if let Some(tag_docs) = self.tag_docs {
507            let mut docs = String::from("\n");
508
509            for (tag, tag_docs) in tag_docs {
510                docs.push_str(" - ");
511                docs.push_str(&format!("{tag}"));
512                if let Some(tag_docs) = tag_docs {
513                    docs.push_str(": ");
514                    docs.push_str(&tag_docs);
515                }
516                docs.push('\n');
517            }
518
519            docs.truncate(docs.trim_end().len());
520
521            self.parameter.docs = Some(Cow::Owned(docs));
522        }
523
524        self.parameter
525    }
526}
527
#[cfg(test)]
mod tests {
    use schemars::{schema_for, JsonSchema};

    use super::*;

    // the schemars-generated schema of `MyCodec` has the expected shape
    #[test]
    fn schema() {
        let actual = schema_for!(MyCodec).to_value().to_string();
        let expected = r#"{"type":"object","properties":{"param":{"type":["integer","null"],"format":"int32","description":"An optional integer value."}},"unevaluatedProperties":false,"oneOf":[{"type":"object","description":"Mode a.","properties":{"value":{"type":"boolean","description":"A boolean value."},"common":{"type":"string","description":"A common string value."},"mode":{"type":"string","const":"A"}},"required":["mode","value","common"]},{"type":"object","description":"Mode b.","properties":{"common":{"type":"string","description":"A common string value."},"mode":{"type":"string","const":"B"}},"required":["mode","common"]}],"description":"A codec that does something on encoding and decoding.","title":"MyCodec","$schema":"https://json-schema.org/draft/2020-12/schema"}"#;

        assert_eq!(actual, expected);
    }

    // the rendered docs list common, tag, and variant-only parameters
    #[test]
    fn docs() {
        let expected = [
            "# my-codec (MyCodec)",
            "",
            "A codec that does something on encoding and decoding.",
            "",
            "## Parameters",
            "",
            " - common (required): A common string value.",
            // the tag's docs start on the next line, after a trailing space
            " - mode (required): ",
            "    - \"A\": Mode a.",
            "    - \"B\": Mode b.",
            " - param (optional): An optional integer value.",
            " - value (optional): A boolean value.",
        ]
        .join("\n");

        let actual = docs_from_schema(&schema_for!(MyCodec), "my-codec");

        assert_eq!(actual.as_deref(), Some(expected.as_str()));
    }

    // required parameters come first, optional ones default to `None`
    #[test]
    fn signature() {
        let actual = signature_from_schema(&schema_for!(MyCodec));

        assert_eq!(actual, "self, common, mode, param=None, value=None");
    }

    #[allow(dead_code)]
    #[derive(JsonSchema)]
    #[schemars(deny_unknown_fields)]
    /// A codec that does something on encoding and decoding.
    struct MyCodec {
        /// An optional integer value.
        #[schemars(default, skip_serializing_if = "Option::is_none")]
        param: Option<i32>,
        /// The flattened configuration.
        #[schemars(flatten)]
        config: Config,
    }

    #[allow(dead_code)]
    #[derive(JsonSchema)]
    #[schemars(tag = "mode")]
    #[schemars(deny_unknown_fields)]
    enum Config {
        /// Mode a.
        A {
            /// A boolean value.
            value: bool,
            /// A common string value.
            common: String,
        },
        /// Mode b.
        B {
            /// A common string value.
            common: String,
        },
    }
}