mwapi_responses_derive 0.5.1

Automatically generate strict types for MediaWiki API responses (macro)
Documentation
/*
Copyright (C) 2020-2021 Kunal Mehta <legoktm@debian.org>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

use crate::builder::StructField;
use crate::Result;
use quote::quote;
use serde::Deserialize;
use std::collections::{BTreeMap, HashMap};
use std::fs;
use syn::{Error, LitStr};

/// Metadata describing one API module, deserialized from a bundled JSON
/// file (see [`Metadata::new`]).
#[derive(Debug, Deserialize)]
pub(crate) struct Metadata {
    // Unused in Rust currently
    // pub(crate) name: String,
    // pub(crate) mode: String,
    /// NOTE(review): presumably the key under which this module's data
    /// appears in the API response; not used in this file, confirm against
    /// the caller.
    pub(crate) fieldname: String,
    /// NOTE(review): optional prop name associated with the module; its
    /// exact meaning is not visible in this file, confirm against the caller.
    pub(crate) prop: Option<String>,
    /// Every field this module can contribute; narrowed down to the
    /// requested props by [`Metadata::get_fields`].
    pub(crate) fields: Vec<Field>,
    /// Falls back to `false` when the key is absent from the JSON.
    /// NOTE(review): presumably tells the builder to wrap the generated
    /// type in a `Vec`; confirm against the builder.
    #[serde(default)]
    pub(crate) wrap_in_vec: bool,
}

impl Metadata {
    /// Load the metadata for the API module called `name`.
    ///
    /// The metadata is read from `data/query+<name>.json` underneath this
    /// crate's manifest directory and deserialized with `serde_json`.
    ///
    /// # Errors
    ///
    /// Returns an error spanned at `name` when the file cannot be read
    /// (reported as an unsupported API module) or when its contents fail
    /// to deserialize (reported as an internal error).
    pub fn new(name: &LitStr) -> Result<Self> {
        let module = name.value();
        // XXX: Is loading JSON files the best way to store/access this metadata?
        let path = format!(
            "{}/data/query+{}.json",
            env!("CARGO_MANIFEST_DIR"),
            module
        );
        // Any read failure is presented to the user as "module not supported";
        // the underlying I/O error is intentionally not surfaced.
        let contents = fs::read_to_string(path).map_err(|_| {
            Error::new(
                name.span(),
                format!("unable to load metadata for \"{}\", please file a bug to add support for this API module", module),
            )
        })?;
        // This should not be reachable by users
        serde_json::from_str(&contents).map_err(|err| {
            Error::new(
                name.span(),
                format!(
                    "internal error: parsing metadata for \"{}\" failed: {}",
                    module, err
                ),
            )
        })
    }

    /// Consume the metadata and keep only the fields whose prop matches one
    /// of `props` (fields marked `=default` are always kept), preserving
    /// their original order.
    pub fn get_fields(self, props: &[&str]) -> Vec<Field> {
        self.fields
            .into_iter()
            .filter(|candidate| candidate.matches_prop(props))
            .collect()
    }
}

/// Description of a single field of an API response, either deserialized
/// from module metadata or constructed directly in code.
#[derive(Debug, Deserialize)]
pub(crate) struct Field {
    /// Name of field in API response
    pub(crate) name: String,
    /// Rust type to map to, written as source text (e.g. "Option<u32>")
    pub(crate) type_: String,
    /// prop= value that gives this field. Several alternatives can be
    /// listed separated by "||"; the field matches if any one of them
    /// was requested. The magic string "=default" can be used to
    /// indicate that the field is always available.
    pub(crate) prop: String,
    /// Value this field should be renamed to (e.g. "type" -> "type_")
    pub(crate) rename: Option<String>,
    /// Whether `#[serde(default)]` should be set on this field
    #[serde(default)]
    pub(crate) default: bool,
    /// Optional value for a `#[serde(deserialize_with = "...")]` attribute;
    /// `None` means no such attribute is requested
    pub(crate) deserialize_with: Option<String>,
}

impl Field {
    /// Whether this field is provided by any of the requested `props`.
    ///
    /// A field whose prop is the magic string "=default" always matches.
    /// Otherwise the prop is treated as a "||"-separated list of
    /// alternatives, and the field matches when at least one alternative
    /// is present in `props`.
    fn matches_prop(&self, props: &[&str]) -> bool {
        if self.prop == "=default" {
            return true;
        }
        self.prop
            .split("||")
            .any(|alternative| props.iter().any(|requested| *requested == alternative))
    }
}

/// Turn a type string from the metadata into one that can be parsed and
/// emitted regardless of what the macro caller has imported.
///
/// * The pseudo-type `"enum"` is mapped to `String`.
/// * Every bare `HashMap<` is rewritten to `::std::collections::HashMap<`,
///   including occurrences nested inside other generics such as
///   `Option<HashMap<..>>`. Occurrences that are already path-qualified
///   (preceded by `:`) or that are merely the tail of a longer identifier
///   (e.g. `MyHashMap<`) are left untouched.
/// * Anything else is returned unchanged.
fn normalize_type(input: &str) -> String {
    const BARE: &str = "HashMap<";
    const QUALIFIED: &str = "::std::collections::HashMap<";

    if input == "enum" {
        return "String".to_string();
    }

    let mut out = String::with_capacity(input.len());
    let mut rest = input;
    while let Some(pos) = rest.find(BARE) {
        let (before, after) = rest.split_at(pos);
        out.push_str(before);
        // Inspect the character emitted just before this occurrence: an
        // identifier character means we are inside a longer name, and a
        // colon means the path is already qualified.
        let standalone = out
            .chars()
            .next_back()
            .map_or(true, |c| !(c.is_alphanumeric() || c == '_' || c == ':'));
        out.push_str(if standalone { QUALIFIED } else { BARE });
        rest = &after[BARE.len()..];
    }
    out.push_str(rest);
    out
}

impl TryFrom<Field> for StructField {
    type Error = Error;

    fn try_from(other: Field) -> Result<Self> {
        let parsed: syn::Type = syn::parse_str(&normalize_type(&other.type_))?;
        Ok(Self {
            name: other.name,
            type_: quote! { #parsed },
            default: other.default
                || (other.type_.starts_with("Option<")
                    && other.type_.ends_with('>')),
            rename: other.rename,
            deserialize_with: other.deserialize_with,
        })
    }
}

/// Accumulates fields gathered via [`FieldContainer::add_fields`],
/// separated into top-level fields and fields nested under a named
/// wrapper field.
#[derive(Default)]
pub(crate) struct FieldContainer {
    /// Fields added without a wrapper, keyed by field name
    pub(crate) top: BTreeMap<String, Field>,
    /// Fields added with a wrapper: wrapper field name -> (field name -> field)
    pub(crate) sub: HashMap<String, BTreeMap<String, Field>>,
}

impl FieldContainer {
    /// Merge `fields` into the container, keyed by each field's name.
    ///
    /// With `wrap_field` set, the fields go into the `sub` map under that
    /// name (the entry is created if needed, even when `fields` is empty);
    /// otherwise they go into `top`. A field whose name is already present
    /// in the target map replaces the existing one, and when `fields`
    /// itself contains duplicate names the last one wins.
    pub(crate) fn add_fields(
        &mut self,
        wrap_field: Option<String>,
        fields: Vec<Field>,
    ) {
        let target = match wrap_field {
            Some(fieldname) => self.sub.entry(fieldname).or_default(),
            None => &mut self.top,
        };
        target.extend(
            fields
                .into_iter()
                .map(|entry| (entry.name.clone(), entry)),
        );
    }
}

/// Fields that are present in every action=query response
/// when titles=/pageids= is used.
///
/// Returns a map from field name to its [`Field`] description. Every
/// entry uses the "=default" prop and has neither a rename nor a custom
/// deserializer.
pub(crate) fn default_query_fields() -> BTreeMap<String, Field> {
    // (field name, Rust type, whether `#[serde(default)]` is set)
    const SPECS: [(&str, &str, bool); 6] = [
        // Not present if title is invalid
        ("ns", "i32", true),
        ("title", "String", false),
        ("pageid", "Option<u32>", false),
        ("missing", "bool", true),
        ("invalid", "bool", true),
        ("invalidreason", "Option<String>", false),
    ];

    SPECS
        .iter()
        .map(|&(name, type_, default)| {
            let field = Field {
                name: name.to_string(),
                type_: type_.to_string(),
                prop: "=default".to_string(),
                rename: None,
                default,
                deserialize_with: None,
            };
            (name.to_string(), field)
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A "||"-separated prop matches when any one alternative is requested
    /// and does not match an unrelated prop.
    #[test]
    fn test_field() {
        let field = Field {
            prop: "test1||test2".to_string(),
            name: "test".to_string(),
            type_: "String".to_string(),
            default: false,
            rename: None,
            deserialize_with: None,
        };
        for requested in ["test1", "test2"] {
            assert!(field.matches_prop(&[requested]));
        }
        assert!(!field.matches_prop(&["test3"]));
    }

    /// The default query fields include the namespace field.
    #[test]
    fn test_default_query_fields() {
        assert!(default_query_fields().contains_key("ns"));
    }
}