1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
/// Numeric threshold value, taken verbatim from the schema as an `f64`.
pub type IngestThreshold = f64;
/// Identifier of an entity, property or structure as written in the schema.
pub type IngestName = String;
/// Free-form, human-readable description text.
pub type IngestDescription = String;
/// Regex pattern source, kept as raw text.
pub type IngestRegex = String;
13
14#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
15#[serde(untagged)]
16pub enum IngestDType {
17 String(String),
18 Int(i64),
19 Float(f64),
20 Bool(bool),
21}
22
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
24#[serde(rename_all = "lowercase")]
25pub enum IngestValidatorMode {
26 Partial,
27 Full,
28}
29
30#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
31pub struct IngestValidatorDict {
32 pub pattern: IngestRegex,
33 #[serde(default)]
34 pub mode: Option<IngestValidatorMode>,
35 #[serde(default)]
36 pub exclude: Option<bool>,
37}
38
39#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
40#[serde(untagged)]
41pub enum IngestValidator {
42 Regex(IngestRegex),
43 Dict(IngestValidatorDict),
44}
45
46#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
47pub struct IngestEntityPropertyDict {
48 pub name: IngestName,
49 #[serde(default)]
50 pub dtype: Option<IngestDType>,
51 #[serde(default)]
52 pub validator: Option<IngestValidator>,
53 #[serde(default)]
54 pub threshold: Option<IngestThreshold>,
55 #[serde(default)]
56 pub description: Option<IngestDescription>,
57}
58
59#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
60#[serde(untagged)]
61pub enum IngestEntityProperty {
62 Description(IngestDescription),
63 Dict(IngestEntityPropertyDict),
64}
65
66#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
67#[serde(untagged)]
68pub enum IngestEntity {
69 Stringish(String),
72
73 SingleEntityDict(BTreeMap<String, IngestEntityProperty>),
75}
76
77#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
78#[serde(untagged)]
79pub enum IngestEntityList {
80 List(Vec<IngestEntity>),
81 Dict(BTreeMap<String, IngestEntityProperty>),
82}
83
84#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
85pub struct IngestStructureProperty {
86 #[serde(default)]
87 pub choices: Option<IngestEntityList>,
88 #[serde(default)]
89 pub description: Option<IngestDescription>,
90 #[serde(default)]
91 pub value: Option<String>,
92 #[serde(default)]
93 pub dtype: Option<IngestDType>,
94 #[serde(default)]
95 pub validator: Option<IngestValidator>,
96 #[serde(default)]
97 pub threshold: Option<IngestThreshold>,
98}
99
100#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
101#[serde(untagged)]
102pub enum IngestStructureProperties {
103 EntityDict(BTreeMap<String, IngestEntityProperty>),
104 EntityList(IngestEntityList),
105 StructurePropertiesDict(BTreeMap<String, IngestStructureProperty>),
106}
107
108#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
109pub struct IngestNamedStructure {
110 pub name: IngestName,
111 #[serde(flatten)]
112 pub props: BTreeMap<String, IngestStructureProperty>,
113}
114
115#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
116pub struct IngestJsonNameKeyedStructure(pub BTreeMap<String, serde_json::Value>);
117
118#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
119#[serde(untagged)]
120pub enum IngestJsonStructure {
121 NamedStructure(IngestNamedStructure),
122 EntityList(IngestEntityList),
123 JsonNameKeyedStructure(IngestJsonNameKeyedStructure),
124}
125
126#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
127#[serde(untagged)]
128pub enum IngestJsonStructureList {
129 Single(IngestJsonStructure),
130 List(Vec<IngestJsonStructure>),
131}
132
133#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
134pub struct IngestClassification {
135 pub task: IngestEntity,
136 pub labels: IngestEntityList,
137
138 #[serde(default)]
139 pub threshold: Option<IngestThreshold>,
140
141 #[serde(default, alias = "cls_threshold")]
142 pub cls_threshold: Option<IngestThreshold>,
143
144 #[serde(default)]
145 pub multi_label: Option<bool>,
146
147 #[serde(default)]
148 pub label_descriptions: Option<BTreeMap<String, IngestEntityProperty>>,
149}
150
151#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
152pub struct IngestEntityAcquired {
153 pub head: IngestEntity,
154 pub tail: IngestEntity,
155}
156
157#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
158#[serde(untagged)]
159pub enum IngestRelation {
160 Name(String),
161 NameDescription(BTreeMap<String, String>),
162 RelationEntityAcquired(BTreeMap<String, IngestEntityAcquired>),
163}
164
165#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Default)]
168#[serde(deny_unknown_fields)]
169pub struct IngestSchema {
170 #[serde(default)]
171 pub entities: Option<IngestEntityList>,
172 #[serde(default)]
173 pub json_structures: Option<IngestJsonStructureList>,
174 #[serde(default)]
175 pub classifications: Option<Vec<IngestClassification>>,
176 #[serde(default)]
177 pub relations: Option<Vec<IngestRelation>>,
178}
179
180impl IngestSchema {
181 pub fn from_json_str(s: &str) -> Result<Self, serde_json::Error> {
182 serde_json::from_str(s)
183 }
184
185 pub fn from_json_slice(bytes: &[u8]) -> Result<Self, serde_json::Error> {
186 serde_json::from_slice(bytes)
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared entry point: deserializes `json` straight into an [`IngestSchema`].
    fn parse(json: &str) -> Result<IngestSchema, serde_json::Error> {
        serde_json::from_str::<IngestSchema>(json)
    }

    #[test]
    fn ingest_entity_string_preserved() {
        let schema = parse(r#"{ "entities": ["gene::str::0.9::gene symbol"] }"#).unwrap();

        match schema.entities {
            Some(IngestEntityList::List(items)) => match items.first() {
                Some(IngestEntity::Stringish(raw)) => {
                    assert_eq!(raw, "gene::str::0.9::gene symbol")
                }
                _ => panic!("expected stringish"),
            },
            _ => panic!("expected list"),
        }
    }

    #[test]
    fn ingest_entity_dict_form() {
        let json = r#"
        {
            "entities": {
                "gene": "gene symbol",
                "score": { "name": "score", "dtype": "float", "threshold": 0.8 }
            }
        }
        "#;

        let Some(IngestEntityList::Dict(map)) = parse(json).unwrap().entities else {
            panic!("expected dict");
        };
        for key in ["gene", "score"] {
            assert!(map.contains_key(key), "missing key {key}");
        }
    }

    #[test]
    fn ingest_classification_aliases_preserved() {
        let json = r#"
        {
            "classifications": [
                {
                    "task": "sentiment",
                    "labels": ["positive", "negative"],
                    "threshold": 0.4,
                    "cls_threshold": 0.7
                }
            ]
        }
        "#;

        let classifications = parse(json).unwrap().classifications.unwrap();
        let first = classifications.first().expect("one classification");
        assert_eq!(
            (first.threshold, first.cls_threshold),
            (Some(0.4), Some(0.7))
        );
    }

    #[test]
    fn ingest_relation_variants() {
        let json = r#"
        {
            "relations": [
                "interacts_with",
                { "expressed_in": "entity is expressed in tissue" },
                { "binds_to": { "head": "", "tail": "" } },
                { "associated_with": { "head": "gene", "tail": "disease" } }
            ]
        }
        "#;

        let relations = parse(json).unwrap().relations.unwrap();
        assert_eq!(relations.len(), 4);
        assert!(matches!(relations[0], IngestRelation::Name(_)));
        assert!(matches!(relations[1], IngestRelation::NameDescription(_)));
        assert!(relations[2..]
            .iter()
            .all(|rel| matches!(rel, IngestRelation::RelationEntityAcquired(_))));
    }

    #[test]
    fn ingest_rejects_unknown_top_level_key() {
        let err = parse(r#"{ "entities": ["gene"], "extra_root": 1 }"#).expect_err("unknown key");
        let message = err.to_string();
        assert!(
            message.contains("unknown field"),
            "unexpected serde error: {err}"
        );
    }

    #[test]
    fn ingest_rejects_unknown_top_level_key_python_fixture() {
        let err = parse(r#"{"entities": ["gene"], "not_an_ie_field": true}"#)
            .expect_err("unknown key");
        let message = err.to_string();
        assert!(message.contains("unknown field"), "{err}");
    }

    #[test]
    fn ingest_name_keyed_structure_preserved_as_dirty_map() {
        let json = r#"
        {
            "json_structures": {
                "patient_record": {
                    "id": { "description": "identifier", "dtype": "str" }
                }
            }
        }
        "#;

        let structures = parse(json).unwrap().json_structures.unwrap();
        match structures {
            IngestJsonStructureList::Single(IngestJsonStructure::JsonNameKeyedStructure(
                IngestJsonNameKeyedStructure(map),
            )) => assert!(map.contains_key("patient_record")),
            _ => panic!("expected name-keyed structure"),
        }
    }
}