// ie-schema 0.1.5
//
// A flexible schema specification and parser for information extraction tasks.

use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;

// Serde-facing ingest structures for the IE JSON grammar.
//
// The root `IngestSchema` uses `deny_unknown_fields` so documents are not misclassified
// (for example JSON Schema with `type` / `properties` at the root). Nested enums and maps
// remain shape-tolerant where needed.

/// Name of a schema item, carried as a plain string.
pub type IngestName = String;
/// Free-text description, carried as a plain string.
pub type IngestDescription = String;
/// Regular-expression source text; as a plain `String` alias it is not compiled or checked here.
pub type IngestRegex = String;
/// Numeric threshold as written in the JSON; the alias itself enforces no range.
pub type IngestThreshold = f64;

/// Raw `dtype` value as it appears in the ingest JSON.
///
/// The enum is `untagged`: the JSON value carries no wrapper and the variant is chosen from the
/// value's own shape. Serde tries the variants top to bottom and keeps the first that
/// deserializes, so the declaration order below is part of the behavior.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum IngestDType {
    /// A JSON string.
    String(String),
    /// A JSON number that deserializes as an `i64`; tried before `Float`.
    Int(i64),
    /// A JSON number, reached only when `Int` did not accept it.
    Float(f64),
    /// A JSON boolean.
    Bool(bool),
}

/// Value of a validator's `mode` key.
///
/// `rename_all = "lowercase"` makes the JSON spellings `"partial"` and `"full"`.
// NOTE(review): what each mode means for matching is decided by the consumer of these
// structures and is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum IngestValidatorMode {
    Partial,
    Full,
}

/// Object form of a validator: a required `pattern` plus optional settings.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct IngestValidatorDict {
    /// Pattern text; required.
    pub pattern: IngestRegex,
    /// Optional mode; `None` when the key is absent.
    #[serde(default)]
    pub mode: Option<IngestValidatorMode>,
    /// Optional flag; `None` when the key is absent.
    #[serde(default)]
    pub exclude: Option<bool>,
}

/// A validator as written in the ingest JSON: either a bare pattern string or an object.
///
/// Untagged, so the variant is chosen by the JSON value's shape, tried in declaration order.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum IngestValidator {
    /// Bare JSON string holding the pattern.
    Regex(IngestRegex),
    /// JSON object with a `pattern` key and optional settings.
    Dict(IngestValidatorDict),
}

/// Object form of an entity property.
///
/// Only `name` is required; each optional key becomes `None` when it is absent.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct IngestEntityPropertyDict {
    /// Required name.
    pub name: IngestName,
    /// Optional raw `dtype` value.
    #[serde(default)]
    pub dtype: Option<IngestDType>,
    /// Optional validator, in either its string or its object form.
    #[serde(default)]
    pub validator: Option<IngestValidator>,
    /// Optional numeric threshold.
    #[serde(default)]
    pub threshold: Option<IngestThreshold>,
    /// Optional free-text description.
    #[serde(default)]
    pub description: Option<IngestDescription>,
}

/// Value attached to an entity key: a description string or a property object.
///
/// Untagged; a JSON string matches `Description`, otherwise `Dict` is tried.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum IngestEntityProperty {
    /// Bare JSON string used as the description.
    Description(IngestDescription),
    /// JSON object; it must contain `name` to deserialize.
    Dict(IngestEntityPropertyDict),
}

/// A single entity entry of the ingest JSON.
///
/// Untagged: a JSON string becomes `Stringish`, a JSON object whose values are entity
/// properties becomes `SingleEntityDict`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum IngestEntity {
    /// Plain string in phase 1:
    /// could be a Name or a ColonDelimitedEntity mini-language.
    /// The text is stored as-is; nothing in this type splits or interprets it.
    Stringish(String),

    /// Semantically intended to be a single-entry dict, but not enforced here:
    /// a map with zero or several entries deserializes just as well.
    SingleEntityDict(BTreeMap<String, IngestEntityProperty>),
}

/// A collection of entities, written either as a JSON array or as a JSON object.
///
/// Untagged; the array form is tried first, then the object form.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum IngestEntityList {
    /// JSON array whose items are individual entities.
    List(Vec<IngestEntity>),
    /// JSON object mapping a key to its entity property.
    Dict(BTreeMap<String, IngestEntityProperty>),
}

/// One property of a JSON structure.
///
/// Every field is optional and becomes `None` when its key is absent, so an empty JSON object
/// deserializes to a value with all fields unset.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct IngestStructureProperty {
    /// Optional entity collection, in list or dict form.
    #[serde(default)]
    pub choices: Option<IngestEntityList>,
    /// Optional free-text description.
    #[serde(default)]
    pub description: Option<IngestDescription>,
    /// Optional string value.
    #[serde(default)]
    pub value: Option<String>,
    /// Optional raw `dtype` value.
    #[serde(default)]
    pub dtype: Option<IngestDType>,
    /// Optional validator, in either its string or its object form.
    #[serde(default)]
    pub validator: Option<IngestValidator>,
    /// Optional numeric threshold.
    #[serde(default)]
    pub threshold: Option<IngestThreshold>,
}

/// The accepted shapes for a structure's set of properties.
///
/// Untagged, so variants are tried in declaration order and the first match wins.
// NOTE(review): `EntityDict` has the same payload type as `IngestEntityList::Dict`, so a JSON
// object that fits that shape is always captured by `EntityDict`; through `EntityList` only
// the array form can be reached when deserializing.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum IngestStructureProperties {
    /// JSON object whose values are entity properties.
    EntityDict(BTreeMap<String, IngestEntityProperty>),
    /// Entity collection (see the review note above about which forms reach this variant).
    EntityList(IngestEntityList),
    /// JSON object whose values are structure properties; tried last.
    StructurePropertiesDict(BTreeMap<String, IngestStructureProperty>),
}

/// A structure that names itself with an explicit `name` key.
///
/// All keys other than `name` are gathered into `props` via `flatten`; each of their values
/// must deserialize as an [`IngestStructureProperty`].
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct IngestNamedStructure {
    /// Required structure name.
    pub name: IngestName,
    /// Remaining keys of the same JSON object, keyed by property name.
    #[serde(flatten)]
    pub props: BTreeMap<String, IngestStructureProperty>,
}

/// A JSON object kept as a map from key to raw `serde_json::Value`.
///
/// The values are not interpreted by this type; they stay as generic JSON.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct IngestJsonNameKeyedStructure(pub BTreeMap<String, serde_json::Value>);

/// One JSON structure definition in any of its accepted shapes.
///
/// Untagged; variants are tried in declaration order. The last variant stores raw JSON values,
/// so it accepts any JSON object the earlier variants rejected.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum IngestJsonStructure {
    /// Object with an explicit `name` key plus flattened properties.
    NamedStructure(IngestNamedStructure),
    /// Entity collection in list or dict form.
    EntityList(IngestEntityList),
    /// Fallback: object kept as a map of raw JSON values.
    JsonNameKeyedStructure(IngestJsonNameKeyedStructure),
}

/// Value of the root `json_structures` key: one structure or an array of structures.
///
/// Untagged; the single-structure form is tried before the array form.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum IngestJsonStructureList {
    /// A lone structure, not wrapped in an array.
    Single(IngestJsonStructure),
    /// JSON array of structures.
    List(Vec<IngestJsonStructure>),
}

/// One classification task as written in the ingest JSON.
///
/// `task` and `labels` are required; every other key is optional and becomes `None` when it is
/// absent. `threshold` and `cls_threshold` are independent keys: a document may set both, and
/// each value lands in its own field.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct IngestClassification {
    /// The task, given as an entity (string or single-entry object).
    pub task: IngestEntity,
    /// The labels, given as an entity collection in list or dict form.
    pub labels: IngestEntityList,

    /// Optional value of the `threshold` key.
    #[serde(default)]
    pub threshold: Option<IngestThreshold>,

    /// Optional value of the `cls_threshold` key.
    // No `alias` attribute: the JSON key is already the field's own name, so an alias that
    // repeats it would not accept any additional spelling.
    #[serde(default)]
    pub cls_threshold: Option<IngestThreshold>,

    /// Optional multi-label flag.
    #[serde(default)]
    pub multi_label: Option<bool>,

    /// Optional map from a key to an entity property describing it.
    #[serde(default)]
    pub label_descriptions: Option<BTreeMap<String, IngestEntityProperty>>,
}

/// The two endpoints of a relation; both keys are required.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct IngestEntityAcquired {
    /// Entity given under the `head` key.
    pub head: IngestEntity,
    /// Entity given under the `tail` key.
    pub tail: IngestEntity,
}

/// One relation entry in any of its accepted shapes.
///
/// Untagged; variants are tried in declaration order and the first match wins.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum IngestRelation {
    /// Bare JSON string.
    Name(String),
    /// JSON object whose values are all strings.
    NameDescription(BTreeMap<String, String>),
    /// JSON object whose values are `head` / `tail` objects.
    RelationEntityAcquired(BTreeMap<String, IngestEntityAcquired>),
}

/// Root ingest document: unknown top-level keys are rejected so JSON Schema documents
/// (e.g. `type` / `properties`) do not deserialize as an empty ingest.
///
/// All four sections are optional; a missing key leaves its field `None`, and the derived
/// `Default` yields a schema with every section unset.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Default)]
#[serde(deny_unknown_fields)]
pub struct IngestSchema {
    /// Entities, in list or dict form.
    #[serde(default)]
    pub entities: Option<IngestEntityList>,
    /// One JSON structure or an array of them.
    #[serde(default)]
    pub json_structures: Option<IngestJsonStructureList>,
    /// Classification tasks.
    #[serde(default)]
    pub classifications: Option<Vec<IngestClassification>>,
    /// Relations.
    #[serde(default)]
    pub relations: Option<Vec<IngestRelation>>,
}

impl IngestSchema {
    /// Deserializes an ingest document from JSON text.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json::Error` raised by deserialization, for instance when the text
    /// is not valid JSON or the root object contains a key the schema does not declare.
    pub fn from_json_str(s: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str::<Self>(s)
    }

    /// Deserializes an ingest document from raw JSON bytes.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json::Error` raised by deserialization, under the same conditions
    /// as [`IngestSchema::from_json_str`].
    pub fn from_json_slice(bytes: &[u8]) -> Result<Self, serde_json::Error> {
        serde_json::from_slice::<Self>(bytes)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `raw` into an [`IngestSchema`], panicking if the document is rejected.
    fn parse(raw: &str) -> IngestSchema {
        serde_json::from_str(raw).unwrap()
    }

    /// Returns the rendered serde error for a document that must be rejected.
    fn rejection(raw: &str) -> String {
        serde_json::from_str::<IngestSchema>(raw)
            .expect_err("unknown key")
            .to_string()
    }

    /// A colon-delimited entity string is stored untouched as `Stringish`.
    #[test]
    fn ingest_entity_string_preserved() {
        let doc = r#"{ "entities": ["gene::str::0.9::gene symbol"] }"#;

        let items = match parse(doc).entities.unwrap() {
            IngestEntityList::List(items) => items,
            IngestEntityList::Dict(_) => panic!("expected list"),
        };

        let IngestEntity::Stringish(text) = &items[0] else {
            panic!("expected stringish");
        };
        assert_eq!(text, "gene::str::0.9::gene symbol");
    }

    /// An object under `entities` deserializes as the dict form and keeps every key.
    #[test]
    fn ingest_entity_dict_form() {
        let doc = r#"
        {
            "entities": {
                "gene": "gene symbol",
                "score": {
                    "name": "score",
                    "dtype": "float",
                    "threshold": 0.8
                }
            }
        }
        "#;

        let by_name = match parse(doc).entities.unwrap() {
            IngestEntityList::Dict(by_name) => by_name,
            IngestEntityList::List(_) => panic!("expected dict"),
        };

        for key in ["gene", "score"] {
            assert!(by_name.contains_key(key));
        }
    }

    /// `threshold` and `cls_threshold` are kept in separate fields when both are given.
    #[test]
    fn ingest_classification_aliases_preserved() {
        let doc = r#"
        {
            "classifications": [
                {
                    "task": "sentiment",
                    "labels": ["positive", "negative"],
                    "threshold": 0.4,
                    "cls_threshold": 0.7
                }
            ]
        }
        "#;

        let classifications = parse(doc).classifications.unwrap();
        let first = &classifications[0];
        assert_eq!(first.threshold, Some(0.4));
        assert_eq!(first.cls_threshold, Some(0.7));
    }

    /// Each accepted relation shape lands in its matching variant.
    #[test]
    fn ingest_relation_variants() {
        let doc = r#"
        {
            "relations": [
                "interacts_with",
                { "expressed_in": "entity is expressed in tissue" },
                { "binds_to": { "head": "", "tail": "" } },
                { "associated_with": { "head": "gene", "tail": "disease" } }
            ]
        }
        "#;

        let relations = parse(doc).relations.unwrap();
        assert_eq!(relations.len(), 4);

        assert!(matches!(&relations[0], IngestRelation::Name(_)));
        assert!(matches!(&relations[1], IngestRelation::NameDescription(_)));
        // The last two entries are both head/tail objects.
        for relation in &relations[2..] {
            assert!(matches!(
                relation,
                IngestRelation::RelationEntityAcquired(_)
            ));
        }
    }

    /// A root key the schema does not declare is refused.
    #[test]
    fn ingest_rejects_unknown_top_level_key() {
        let message = rejection(r#"{ "entities": ["gene"], "extra_root": 1 }"#);
        assert!(
            message.contains("unknown field"),
            "unexpected serde error: {message}"
        );
    }

    /// Same rejection, using the document from the Python-side fixture.
    #[test]
    fn ingest_rejects_unknown_top_level_key_python_fixture() {
        let message = rejection(r#"{"entities": ["gene"], "not_an_ie_field": true}"#);
        assert!(message.contains("unknown field"), "{message}");
    }

    /// A name-keyed structure falls through to the raw-JSON map variant.
    #[test]
    fn ingest_name_keyed_structure_preserved_as_dirty_map() {
        let doc = r#"
        {
            "json_structures": {
                "patient_record": {
                    "id": { "description": "identifier", "dtype": "str" }
                }
            }
        }
        "#;

        let raw_map = match parse(doc).json_structures.unwrap() {
            IngestJsonStructureList::Single(IngestJsonStructure::JsonNameKeyedStructure(
                IngestJsonNameKeyedStructure(raw_map),
            )) => raw_map,
            _ => panic!("expected name-keyed structure"),
        };

        assert!(raw_map.contains_key("patient_record"));
    }
}