// ctxgraph_extract/schema.rs

1use std::collections::BTreeMap;
2use std::path::Path;
3
4use serde::{Deserialize, Serialize};
5
6/// Extraction schema defining which entity types and relation types to extract.
7///
8/// Loaded from a `ctxgraph.toml` file or constructed via `ExtractionSchema::default()`.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ExtractionSchema {
11    pub name: String,
12    pub entity_types: BTreeMap<String, String>,
13    pub relation_types: BTreeMap<String, RelationSpec>,
14}
15
16/// Specification for a relation type — which entity types can be head/tail.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct RelationSpec {
19    pub head: Vec<String>,
20    pub tail: Vec<String>,
21    pub description: String,
22}
23
24/// Raw TOML structure for deserialization.
25#[derive(Debug, Deserialize)]
26struct SchemaToml {
27    schema: SchemaSection,
28}
29
30#[derive(Debug, Deserialize)]
31struct SchemaSection {
32    name: String,
33    entities: BTreeMap<String, String>,
34    #[serde(default)]
35    relations: BTreeMap<String, RelationSpecToml>,
36}
37
38#[derive(Debug, Deserialize)]
39struct RelationSpecToml {
40    head: Vec<String>,
41    tail: Vec<String>,
42    #[serde(default)]
43    description: String,
44}
45
46impl ExtractionSchema {
47    /// Load schema from a TOML file.
48    pub fn load(path: &Path) -> Result<Self, SchemaError> {
49        let content = std::fs::read_to_string(path).map_err(|e| SchemaError::Io {
50            path: path.display().to_string(),
51            source: e,
52        })?;
53        Self::from_toml(&content)
54    }
55
56    /// Parse schema from a TOML string.
57    pub fn from_toml(content: &str) -> Result<Self, SchemaError> {
58        let parsed: SchemaToml =
59            toml::from_str(content).map_err(|e| SchemaError::Parse(e.to_string()))?;
60
61        let relation_types = parsed
62            .schema
63            .relations
64            .into_iter()
65            .map(|(k, v)| {
66                (
67                    k,
68                    RelationSpec {
69                        head: v.head,
70                        tail: v.tail,
71                        description: v.description,
72                    },
73                )
74            })
75            .collect();
76
77        Ok(Self {
78            name: parsed.schema.name,
79            entity_types: parsed.schema.entities,
80            relation_types,
81        })
82    }
83
84    /// Entity label strings for GLiNER input.
85    ///
86    /// Returns the type key names (e.g. "Person", "Database"). Suitable for
87    /// models trained on those label conventions.
88    pub fn entity_labels(&self) -> Vec<&str> {
89        self.entity_types.keys().map(|s| s.as_str()).collect()
90    }
91
92    /// Entity descriptions for zero-shot GLiNER inference.
93    ///
94    /// Returns `(description, key)` pairs. Passing the description as the label
95    /// to GLiNER improves zero-shot recall because the model uses the label text
96    /// as a natural-language prompt. The key is the canonical type name used in
97    /// `ExtractionSchema` and benchmark fixtures.
98    pub fn entity_label_descriptions(&self) -> Vec<(&str, &str)> {
99        self.entity_types
100            .iter()
101            .map(|(k, v)| (v.as_str(), k.as_str()))
102            .collect()
103    }
104
105    /// Map a GLiNER class string back to the canonical entity type key.
106    ///
107    /// When descriptions are used as labels, GLiNER returns the description as
108    /// the span class. This method reverses that lookup.
109    pub fn entity_type_from_label<'a>(&'a self, label: &str) -> Option<&'a str> {
110        // Check descriptions first (zero-shot mode)
111        if let Some(key) = self
112            .entity_types
113            .iter()
114            .find(|(_, v)| v.as_str() == label)
115            .map(|(k, _)| k.as_str())
116        {
117            return Some(key);
118        }
119        // Fall back to direct key match (standard mode)
120        if self.entity_types.contains_key(label) {
121            return Some(self.entity_types.get_key_value(label).unwrap().0.as_str());
122        }
123        None
124    }
125
126    /// Relation label strings for GLiREL/relation extraction input.
127    pub fn relation_labels(&self) -> Vec<&str> {
128        self.relation_types.keys().map(|s| s.as_str()).collect()
129    }
130}
131
132impl Default for ExtractionSchema {
133    fn default() -> Self {
134        let mut entity_types = BTreeMap::new();
135        // Descriptions are short (2-4 words) so they fit inside GLiNER's token
136        // budget alongside the input text. They are used as the actual label
137        // strings passed to the model for zero-shot extraction, and are more
138        // semantically precise than the bare key names.
139        entity_types.insert("Person".into(),         "person or engineer".into());
140        entity_types.insert("Component".into(),      "software library or framework".into());
141        entity_types.insert("Service".into(),        "cloud service or API".into());
142        entity_types.insert("Language".into(),       "programming language".into());
143        entity_types.insert("Database".into(),       "database or data store".into());
144        entity_types.insert("Infrastructure".into(), "server or cloud platform".into());
145        entity_types.insert("Decision".into(),       "architectural decision".into());
146        entity_types.insert("Constraint".into(),     "technical constraint".into());
147        entity_types.insert("Metric".into(),         "performance metric".into());
148        entity_types.insert("Pattern".into(),        "design pattern".into());
149
150        let mut relation_types = BTreeMap::new();
151        relation_types.insert(
152            "chose".into(),
153            RelationSpec {
154                head: vec!["Person".into(), "Service".into(), "Component".into()],
155                tail: vec![
156                    "Component".into(),
157                    "Database".into(),
158                    "Language".into(),
159                    "Infrastructure".into(),
160                    "Pattern".into(),
161                ],
162                description: "chose or adopted a technology".into(),
163            },
164        );
165        relation_types.insert(
166            "rejected".into(),
167            RelationSpec {
168                head: vec!["Person".into(), "Service".into(), "Component".into()],
169                tail: vec![
170                    "Component".into(),
171                    "Database".into(),
172                    "Language".into(),
173                    "Infrastructure".into(),
174                ],
175                description: "rejected an alternative".into(),
176            },
177        );
178        relation_types.insert(
179            "replaced".into(),
180            RelationSpec {
181                head: vec![
182                    "Component".into(),
183                    "Database".into(),
184                    "Infrastructure".into(),
185                    "Service".into(),
186                    "Pattern".into(),
187                    "Language".into(),
188                ],
189                tail: vec![
190                    "Component".into(),
191                    "Database".into(),
192                    "Infrastructure".into(),
193                    "Pattern".into(),
194                    "Language".into(),
195                ],
196                description: "one thing replaced another".into(),
197            },
198        );
199        relation_types.insert(
200            "depends_on".into(),
201            RelationSpec {
202                head: vec![
203                    "Service".into(), "Component".into(), "Infrastructure".into(),
204                    "Language".into(), "Pattern".into(), "Decision".into(),
205                ],
206                tail: vec![
207                    "Service".into(),
208                    "Component".into(),
209                    "Database".into(),
210                    "Infrastructure".into(),
211                    "Pattern".into(),
212                    "Language".into(),
213                ],
214                description: "dependency relationship".into(),
215            },
216        );
217        relation_types.insert(
218            "fixed".into(),
219            RelationSpec {
220                head: vec!["Person".into(), "Component".into(), "Service".into()],
221                tail: vec![
222                    "Component".into(),
223                    "Service".into(),
224                    "Database".into(),
225                    "Pattern".into(),
226                ],
227                description: "something fixed an issue".into(),
228            },
229        );
230        relation_types.insert(
231            "introduced".into(),
232            RelationSpec {
233                head: vec!["Person".into(), "Service".into(), "Infrastructure".into(), "Component".into(), "Language".into()],
234                tail: vec![
235                    "Component".into(),
236                    "Pattern".into(),
237                    "Infrastructure".into(),
238                    "Database".into(),
239                    "Language".into(),
240                    "Metric".into(),
241                ],
242                description: "introduced or added a component".into(),
243            },
244        );
245        relation_types.insert(
246            "deprecated".into(),
247            RelationSpec {
248                head: vec![
249                    "Person".into(), "Decision".into(), "Service".into(),
250                    "Component".into(), "Infrastructure".into(), "Pattern".into(),
251                ],
252                tail: vec![
253                    "Component".into(),
254                    "Pattern".into(),
255                    "Infrastructure".into(),
256                    "Database".into(),
257                    "Language".into(),
258                ],
259                description: "deprecation action".into(),
260            },
261        );
262        relation_types.insert(
263            "caused".into(),
264            RelationSpec {
265                head: vec![
266                    "Component".into(), "Decision".into(), "Service".into(),
267                    "Infrastructure".into(), "Language".into(), "Pattern".into(),
268                    "Database".into(),
269                ],
270                tail: vec!["Metric".into(), "Constraint".into(), "Pattern".into()],
271                description: "causal relationship".into(),
272            },
273        );
274        relation_types.insert(
275            "constrained_by".into(),
276            RelationSpec {
277                head: vec![
278                    "Decision".into(), "Component".into(), "Service".into(),
279                    "Infrastructure".into(), "Database".into(), "Pattern".into(),
280                ],
281                tail: vec!["Constraint".into(), "Pattern".into(), "Infrastructure".into(), "Metric".into()],
282                description: "decision constrained by".into(),
283            },
284        );
285
286        Self {
287            name: "default".into(),
288            entity_types,
289            relation_types,
290        }
291    }
292}
293
294#[derive(Debug, thiserror::Error)]
295pub enum SchemaError {
296    #[error("failed to read schema at {path}: {source}")]
297    Io {
298        path: String,
299        source: std::io::Error,
300    },
301
302    #[error("failed to parse schema: {0}")]
303    Parse(String),
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in schema exposes exactly the ten expected entity type keys.
    #[test]
    fn default_schema_has_all_entity_types() {
        let expected = [
            "Person",
            "Component",
            "Service",
            "Language",
            "Database",
            "Infrastructure",
            "Decision",
            "Constraint",
            "Metric",
            "Pattern",
        ];

        let schema = ExtractionSchema::default();
        let labels = schema.entity_labels();
        for name in expected.iter() {
            assert!(labels.contains(name));
        }
        assert_eq!(labels.len(), 10);
    }

    /// The built-in schema exposes exactly the nine expected relation names.
    #[test]
    fn default_schema_has_all_relation_types() {
        let expected = [
            "chose",
            "rejected",
            "replaced",
            "depends_on",
            "fixed",
            "introduced",
            "deprecated",
            "caused",
            "constrained_by",
        ];

        let schema = ExtractionSchema::default();
        let labels = schema.relation_labels();
        for name in expected.iter() {
            assert!(labels.contains(name));
        }
        assert_eq!(labels.len(), 9);
    }

    /// A document with entities and one relation parses into matching fields.
    #[test]
    fn parse_toml_schema() {
        let document = r#"
[schema]
name = "test"

[schema.entities]
Person = "A person"
Component = "A software component"

[schema.relations]
chose = { head = ["Person"], tail = ["Component"], description = "person chose" }
"#;
        let schema = ExtractionSchema::from_toml(document).unwrap();

        assert_eq!(schema.name, "test");
        assert_eq!(schema.entity_types.len(), 2);
        assert_eq!(schema.relation_types.len(), 1);

        let chose = &schema.relation_types["chose"];
        assert_eq!(chose.head, vec!["Person"]);
    }

    /// Omitting `[schema.relations]` yields a schema with no relation types.
    #[test]
    fn parse_toml_schema_no_relations() {
        let document = r#"
[schema]
name = "entities-only"

[schema.entities]
Person = "A person"
"#;
        let schema = ExtractionSchema::from_toml(document).unwrap();

        assert!(schema.relation_types.is_empty());
    }
}