use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// Name of a schema item; a plain `String` alias.
pub type IngestName = String;

/// Free-text description attached to a schema item; a plain `String` alias.
pub type IngestDescription = String;

/// Pattern text kept as a plain `String`; the alias itself neither validates
/// nor compiles the pattern.
pub type IngestRegex = String;

/// Threshold value carried as an `f64`; the alias imposes no range.
pub type IngestThreshold = f64;
/// Loosely typed scalar used for `dtype` values.
///
/// The enum is `untagged`: no variant name appears in the serialized form, and
/// on input serde tries the variants in declaration order (string, integer,
/// float, boolean) and keeps the first one that fits.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum IngestDType {
    /// A string value (the tests in this file pass `"float"` and `"str"`).
    String(String),
    /// A value that deserializes as an `i64`.
    Int(i64),
    /// A value that deserializes as an `f64`.
    Float(f64),
    /// A boolean value.
    Bool(bool),
}
/// Matching mode of a validator.
///
/// `rename_all = "lowercase"` makes the wire form `"partial"` / `"full"`.
// NOTE(review): presumably "partial" means a substring match and "full" a
// whole-string match — confirm against the code that consumes this enum.
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum IngestValidatorMode {
    /// Written as `"partial"`.
    Partial,
    /// Written as `"full"`.
    Full,
}
/// Object form of a validator: a required `pattern` plus optional settings.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct IngestValidatorDict {
    /// Pattern text; the only required key.
    pub pattern: IngestRegex,

    /// Optional matching mode; `None` when the key is absent.
    #[serde(default)]
    pub mode: Option<IngestValidatorMode>,

    /// Optional flag; `None` when the key is absent.
    // NOTE(review): presumably inverts the validator — confirm with the consumer.
    #[serde(default)]
    pub exclude: Option<bool>,
}
/// A validator given either as a bare pattern string or as an object.
///
/// Untagged: a JSON string becomes [`IngestValidator::Regex`]; anything else
/// is tried as an [`IngestValidatorDict`].
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum IngestValidator {
    /// Shorthand form: just the pattern text.
    Regex(IngestRegex),
    /// Detailed form with `pattern`, `mode` and `exclude`.
    Dict(IngestValidatorDict),
}
/// Detailed (object) form of an entity property.
///
/// Only `name` is required; every other key falls back to `None` when absent.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct IngestEntityPropertyDict {
    /// Required name of the property.
    pub name: IngestName,

    /// Optional data-type value.
    #[serde(default)]
    pub dtype: Option<IngestDType>,

    /// Optional validator, in string or object form.
    #[serde(default)]
    pub validator: Option<IngestValidator>,

    /// Optional threshold.
    #[serde(default)]
    pub threshold: Option<IngestThreshold>,

    /// Optional free-text description.
    #[serde(default)]
    pub description: Option<IngestDescription>,
}
/// An entity property given either as a plain description or as an object.
///
/// Untagged: a JSON string becomes [`IngestEntityProperty::Description`];
/// otherwise the value must fit [`IngestEntityPropertyDict`].
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum IngestEntityProperty {
    /// Shorthand form: only a description string.
    Description(IngestDescription),
    /// Detailed form; requires a `name` key.
    Dict(IngestEntityPropertyDict),
}
/// One entry of an entity list: a raw string or a keyed map of properties.
///
/// Untagged: a JSON string is tried first, then a JSON object.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum IngestEntity {
    /// The string exactly as written in the input; this type does not parse
    /// it (e.g. `"gene::str::0.9::gene symbol"` is stored unchanged).
    Stringish(String),
    /// Object form mapping a key to its property. The map type itself does
    /// not restrict how many entries are present.
    SingleEntityDict(BTreeMap<String, IngestEntityProperty>),
}
/// A collection of entities, written either as an array or as an object.
///
/// Untagged: an array of [`IngestEntity`] is tried first, then an object whose
/// values are [`IngestEntityProperty`].
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum IngestEntityList {
    /// Array form.
    List(Vec<IngestEntity>),
    /// Object form, keyed by string.
    Dict(BTreeMap<String, IngestEntityProperty>),
}
/// Settings of a single structure property.
///
/// Every key is optional and becomes `None` when absent. The struct does not
/// use `deny_unknown_fields`, so unrecognised keys are ignored on input.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct IngestStructureProperty {
    /// Optional entity list stored under the `choices` key.
    #[serde(default)]
    pub choices: Option<IngestEntityList>,

    /// Optional free-text description.
    #[serde(default)]
    pub description: Option<IngestDescription>,

    /// Optional string stored under the `value` key.
    #[serde(default)]
    pub value: Option<String>,

    /// Optional data-type value.
    #[serde(default)]
    pub dtype: Option<IngestDType>,

    /// Optional validator, in string or object form.
    #[serde(default)]
    pub validator: Option<IngestValidator>,

    /// Optional threshold.
    #[serde(default)]
    pub threshold: Option<IngestThreshold>,
}
/// The property set of a structure, in one of three accepted shapes.
///
/// Untagged, so variants are tried top to bottom. `EntityDict` wraps the same
/// map type as `IngestEntityList::Dict`, so when deserializing through this
/// enum an object of that shape is always captured as `EntityDict`;
/// `EntityList` is then only reached by arrays.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum IngestStructureProperties {
    /// Object whose values are entity properties.
    EntityDict(BTreeMap<String, IngestEntityProperty>),
    /// An entity list (array or object form).
    EntityList(IngestEntityList),
    /// Object whose values are structure properties.
    StructurePropertiesDict(BTreeMap<String, IngestStructureProperty>),
}
/// A structure that carries its own `name` key alongside its properties.
///
/// `props` is flattened: every key of the object other than `name` is
/// collected into the map, and each such value must deserialize as an
/// [`IngestStructureProperty`].
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct IngestNamedStructure {
    /// Required name of the structure.
    pub name: IngestName,

    /// All remaining keys of the object, by key.
    #[serde(flatten)]
    pub props: BTreeMap<String, IngestStructureProperty>,
}
/// A structure keyed by name whose values are kept as raw, uninterpreted
/// `serde_json::Value`s.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct IngestJsonNameKeyedStructure(
    /// Map from key to the untouched JSON value found under it.
    pub BTreeMap<String, serde_json::Value>,
);
/// One JSON structure definition, in any of three accepted shapes.
///
/// Untagged, tried in order: an object with a `name` key, then an entity
/// list, then a name-keyed map. The last variant accepts any JSON object, so
/// it is the fallback for objects the first two reject.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum IngestJsonStructure {
    /// Object carrying a `name` key plus flattened properties.
    NamedStructure(IngestNamedStructure),
    /// Entity list (array or object form).
    EntityList(IngestEntityList),
    /// Any other JSON object, kept with raw JSON values.
    JsonNameKeyedStructure(IngestJsonNameKeyedStructure),
}
/// Either one JSON structure or an array of them.
///
/// Untagged, with `Single` tried first. Because [`IngestJsonStructure`] itself
/// has an array-shaped variant (the entity list), an array that fits that
/// shape is captured as `Single`; other arrays fall through to `List`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum IngestJsonStructureList {
    /// A lone structure.
    Single(IngestJsonStructure),
    /// An array of structures.
    List(Vec<IngestJsonStructure>),
}
/// A classification task: what to classify and which labels are allowed.
///
/// `task` and `labels` are required; every other key becomes `None` when
/// absent. `threshold` and `cls_threshold` are independent fields, each read
/// from the key of the same name, so an input may set both to different
/// values.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct IngestClassification {
    /// Required task entry (string or object form).
    pub task: IngestEntity,

    /// Required set of labels (array or object form).
    pub labels: IngestEntityList,

    /// Optional value of the `threshold` key.
    #[serde(default)]
    pub threshold: Option<IngestThreshold>,

    /// Optional value of the `cls_threshold` key.
    // The former `alias = "cls_threshold"` merely repeated the field's own
    // name and had no effect, so it was removed.
    #[serde(default)]
    pub cls_threshold: Option<IngestThreshold>,

    /// Optional flag stored under the `multi_label` key.
    #[serde(default)]
    pub multi_label: Option<bool>,

    /// Optional per-label properties, by key.
    #[serde(default)]
    pub label_descriptions: Option<BTreeMap<String, IngestEntityProperty>>,
}
/// The two endpoints of a relation; both keys are required.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct IngestEntityAcquired {
    /// Entity stored under the `head` key.
    pub head: IngestEntity,
    /// Entity stored under the `tail` key.
    pub tail: IngestEntity,
}
/// A relation definition in one of three accepted shapes.
///
/// Untagged, tried in order: a bare string, an object with string values,
/// then an object whose values are `{ "head": …, "tail": … }` pairs.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum IngestRelation {
    /// Just the relation name.
    Name(String),
    /// Object mapping a key to a string (e.g. a name to its description).
    NameDescription(BTreeMap<String, String>),
    /// Object mapping a key to its head/tail entities.
    RelationEntityAcquired(BTreeMap<String, IngestEntityAcquired>),
}
/// Top-level schema document.
///
/// All four sections are optional and become `None` when absent; the derived
/// `Default` yields a schema with every section unset. `deny_unknown_fields`
/// makes deserialization fail on any top-level key other than the four below.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[serde(deny_unknown_fields)]
pub struct IngestSchema {
    /// Optional entity definitions.
    #[serde(default)]
    pub entities: Option<IngestEntityList>,

    /// Optional JSON structure definition(s).
    #[serde(default)]
    pub json_structures: Option<IngestJsonStructureList>,

    /// Optional classification tasks.
    #[serde(default)]
    pub classifications: Option<Vec<IngestClassification>>,

    /// Optional relation definitions.
    #[serde(default)]
    pub relations: Option<Vec<IngestRelation>>,
}
impl IngestSchema {
    /// Parses a schema from JSON text.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json::Error` reported when `s` is not valid JSON or
    /// does not fit the schema shape (including an unknown top-level key).
    pub fn from_json_str(s: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str::<Self>(s)
    }

    /// Parses a schema from raw JSON bytes.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json::Error` reported when `bytes` is not valid JSON
    /// or does not fit the schema shape (including an unknown top-level key).
    pub fn from_json_slice(bytes: &[u8]) -> Result<Self, serde_json::Error> {
        serde_json::from_slice::<Self>(bytes)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` into a schema, panicking at the caller's location
    /// when the input is rejected.
    #[track_caller]
    fn parse_schema(json: &str) -> IngestSchema {
        serde_json::from_str(json).unwrap()
    }

    /// A string entity must be stored exactly as written.
    #[test]
    fn ingest_entity_string_preserved() {
        let json = r#"{ "entities": ["gene::str::0.9::gene symbol"] }"#;
        let schema = parse_schema(json);
        let entities = match schema.entities.unwrap() {
            IngestEntityList::List(entries) => entries,
            IngestEntityList::Dict(_) => panic!("expected list"),
        };
        let IngestEntity::Stringish(raw) = &entities[0] else {
            panic!("expected stringish");
        };
        assert_eq!(raw, "gene::str::0.9::gene symbol");
    }

    /// An object under `entities` must land in the dict variant with its keys.
    #[test]
    fn ingest_entity_dict_form() {
        let json = r#"
{
"entities": {
"gene": "gene symbol",
"score": {
"name": "score",
"dtype": "float",
"threshold": 0.8
}
}
}
"#;
        let schema = parse_schema(json);
        let by_name = match schema.entities.unwrap() {
            IngestEntityList::Dict(entries) => entries,
            IngestEntityList::List(_) => panic!("expected dict"),
        };
        for key in ["gene", "score"] {
            assert!(by_name.contains_key(key));
        }
    }

    /// `threshold` and `cls_threshold` are read independently of each other.
    #[test]
    fn ingest_classification_aliases_preserved() {
        let json = r#"
{
"classifications": [
{
"task": "sentiment",
"labels": ["positive", "negative"],
"threshold": 0.4,
"cls_threshold": 0.7
}
]
}
"#;
        let schema = parse_schema(json);
        let classifications = schema.classifications.unwrap();
        let first = &classifications[0];
        assert_eq!(first.threshold, Some(0.4));
        assert_eq!(first.cls_threshold, Some(0.7));
    }

    /// Each accepted relation shape must map to its own variant.
    #[test]
    fn ingest_relation_variants() {
        let json = r#"
{
"relations": [
"interacts_with",
{ "expressed_in": "entity is expressed in tissue" },
{ "binds_to": { "head": "", "tail": "" } },
{ "associated_with": { "head": "gene", "tail": "disease" } }
]
}
"#;
        let schema = parse_schema(json);
        let relations = schema.relations.unwrap();
        assert_eq!(relations.len(), 4);
        assert!(matches!(relations[0], IngestRelation::Name(_)));
        assert!(matches!(relations[1], IngestRelation::NameDescription(_)));
        // The length check above guarantees this slice is exactly items 2 and 3.
        for relation in &relations[2..] {
            assert!(matches!(
                relation,
                IngestRelation::RelationEntityAcquired(_)
            ));
        }
    }

    /// An unexpected top-level key must be reported as an unknown field.
    #[test]
    fn ingest_rejects_unknown_top_level_key() {
        let json = r#"{ "entities": ["gene"], "extra_root": 1 }"#;
        let outcome = serde_json::from_str::<IngestSchema>(json);
        let err = outcome.expect_err("unknown key");
        let message = err.to_string();
        assert!(
            message.contains("unknown field"),
            "unexpected serde error: {err}"
        );
    }

    /// Same rejection, using the second fixture from the original suite.
    #[test]
    fn ingest_rejects_unknown_top_level_key_python_fixture() {
        let json = r#"{"entities": ["gene"], "not_an_ie_field": true}"#;
        let outcome = serde_json::from_str::<IngestSchema>(json);
        let err = outcome.expect_err("unknown key");
        let message = err.to_string();
        assert!(message.contains("unknown field"), "{err}");
    }

    /// An object without a `name` key falls back to the raw name-keyed map.
    #[test]
    fn ingest_name_keyed_structure_preserved_as_dirty_map() {
        let json = r#"
{
"json_structures": {
"patient_record": {
"id": { "description": "identifier", "dtype": "str" }
}
}
}
"#;
        let schema = parse_schema(json);
        let raw_map = match schema.json_structures.unwrap() {
            IngestJsonStructureList::Single(IngestJsonStructure::JsonNameKeyedStructure(
                IngestJsonNameKeyedStructure(inner),
            )) => inner,
            _ => panic!("expected name-keyed structure"),
        };
        assert!(raw_map.contains_key("patient_record"));
    }
}