Skip to main content

ctxgraph_extract/
schema.rs

1use std::collections::BTreeMap;
2use std::path::Path;
3
4use serde::{Deserialize, Serialize};
5
/// Extraction schema defining which entity types and relation types to extract.
///
/// Loaded from a `ctxgraph.toml` file or constructed via `ExtractionSchema::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionSchema {
    /// Name of the schema (the built-in schema uses `"default"`).
    pub name: String,
    /// Entity type key (e.g. `"Person"`) mapped to its natural-language
    /// description. Stored in a `BTreeMap`, so iteration is ordered by key.
    pub entity_types: BTreeMap<String, String>,
    /// Relation name (e.g. `"chose"`) mapped to the specification of which
    /// entity types may appear as its head and tail.
    pub relation_types: BTreeMap<String, RelationSpec>,
}
15
/// Specification for a relation type — which entity types can be head/tail.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelationSpec {
    /// Entity type keys allowed as the head (subject) of the relation.
    pub head: Vec<String>,
    /// Entity type keys allowed as the tail (object) of the relation.
    pub tail: Vec<String>,
    /// Short natural-language description of the relation.
    pub description: String,
}
23
/// Raw TOML structure for deserialization.
///
/// Mirrors the top level of the document: everything lives under a single
/// `[schema]` table.
#[derive(Debug, Deserialize)]
struct SchemaToml {
    /// Contents of the `[schema]` table.
    schema: SchemaSection,
}
29
/// Contents of the `[schema]` table in the TOML document.
#[derive(Debug, Deserialize)]
struct SchemaSection {
    /// Schema name (`name = "..."`).
    name: String,
    /// `[schema.entities]`: entity type key mapped to its description.
    entities: BTreeMap<String, String>,
    /// `[schema.relations]`: relation name mapped to its specification.
    /// Optional — an absent table deserializes to an empty map.
    #[serde(default)]
    relations: BTreeMap<String, RelationSpecToml>,
}
37
/// TOML form of a single relation entry under `[schema.relations]`.
#[derive(Debug, Deserialize)]
struct RelationSpecToml {
    /// Entity type keys allowed as the head of the relation.
    head: Vec<String>,
    /// Entity type keys allowed as the tail of the relation.
    tail: Vec<String>,
    /// Optional description; defaults to the empty string when omitted.
    #[serde(default)]
    description: String,
}
45
46impl ExtractionSchema {
47    /// Load schema from a TOML file.
48    pub fn load(path: &Path) -> Result<Self, SchemaError> {
49        let content = std::fs::read_to_string(path).map_err(|e| SchemaError::Io {
50            path: path.display().to_string(),
51            source: e,
52        })?;
53        Self::from_toml(&content)
54    }
55
56    /// Parse schema from a TOML string.
57    pub fn from_toml(content: &str) -> Result<Self, SchemaError> {
58        let parsed: SchemaToml =
59            toml::from_str(content).map_err(|e| SchemaError::Parse(e.to_string()))?;
60
61        let relation_types = parsed
62            .schema
63            .relations
64            .into_iter()
65            .map(|(k, v)| {
66                (
67                    k,
68                    RelationSpec {
69                        head: v.head,
70                        tail: v.tail,
71                        description: v.description,
72                    },
73                )
74            })
75            .collect();
76
77        Ok(Self {
78            name: parsed.schema.name,
79            entity_types: parsed.schema.entities,
80            relation_types,
81        })
82    }
83
84    /// Entity label strings for GLiNER input.
85    ///
86    /// Returns the type key names (e.g. "Person", "Database"). Suitable for
87    /// models trained on those label conventions.
88    pub fn entity_labels(&self) -> Vec<&str> {
89        self.entity_types.keys().map(|s| s.as_str()).collect()
90    }
91
92    /// Entity descriptions for zero-shot GLiNER inference.
93    ///
94    /// Returns `(description, key)` pairs. Passing the description as the label
95    /// to GLiNER improves zero-shot recall because the model uses the label text
96    /// as a natural-language prompt. The key is the canonical type name used in
97    /// `ExtractionSchema` and benchmark fixtures.
98    pub fn entity_label_descriptions(&self) -> Vec<(&str, &str)> {
99        self.entity_types
100            .iter()
101            .map(|(k, v)| (v.as_str(), k.as_str()))
102            .collect()
103    }
104
105    /// Map a GLiNER class string back to the canonical entity type key.
106    ///
107    /// When descriptions are used as labels, GLiNER returns the description as
108    /// the span class. This method reverses that lookup.
109    pub fn entity_type_from_label<'a>(&'a self, label: &str) -> Option<&'a str> {
110        // Check descriptions first (zero-shot mode)
111        if let Some(key) = self
112            .entity_types
113            .iter()
114            .find(|(_, v)| v.as_str() == label)
115            .map(|(k, _)| k.as_str())
116        {
117            return Some(key);
118        }
119        // Fall back to direct key match (standard mode)
120        if self.entity_types.contains_key(label) {
121            return Some(self.entity_types.get_key_value(label).unwrap().0.as_str());
122        }
123        None
124    }
125
126    /// Relation label strings for GLiREL/relation extraction input.
127    pub fn relation_labels(&self) -> Vec<&str> {
128        self.relation_types.keys().map(|s| s.as_str()).collect()
129    }
130}
131
132impl Default for ExtractionSchema {
133    fn default() -> Self {
134        let mut entity_types = BTreeMap::new();
135        // Descriptions are short (2-4 words) so they fit inside GLiNER's token
136        // budget alongside the input text. They are used as the actual label
137        // strings passed to the model for zero-shot extraction, and are more
138        // semantically precise than the bare key names.
139        entity_types.insert("Person".into(), "software engineer".into());
140        entity_types.insert("Component".into(), "software library or framework".into());
141        entity_types.insert("Service".into(), "software service or API".into());
142        entity_types.insert("Language".into(), "programming language".into());
143        entity_types.insert("Database".into(), "database system".into());
144        entity_types.insert("Infrastructure".into(), "infrastructure tool".into());
145        entity_types.insert("Decision".into(), "technology decision".into());
146        entity_types.insert("Constraint".into(), "technical constraint or requirement".into());
147        entity_types.insert("Metric".into(), "performance metric".into());
148        entity_types.insert("Pattern".into(), "design pattern".into());
149
150        let mut relation_types = BTreeMap::new();
151        relation_types.insert(
152            "chose".into(),
153            RelationSpec {
154                head: vec!["Person".into()],
155                tail: vec!["Component".into(), "Database".into(), "Language".into()],
156                description: "person chose a technology".into(),
157            },
158        );
159        relation_types.insert(
160            "rejected".into(),
161            RelationSpec {
162                head: vec!["Person".into()],
163                tail: vec!["Component".into(), "Database".into()],
164                description: "person rejected an alternative".into(),
165            },
166        );
167        relation_types.insert(
168            "replaced".into(),
169            RelationSpec {
170                head: vec!["Component".into(), "Database".into()],
171                tail: vec!["Component".into(), "Database".into()],
172                description: "one thing replaced another".into(),
173            },
174        );
175        relation_types.insert(
176            "depends_on".into(),
177            RelationSpec {
178                head: vec!["Service".into(), "Component".into()],
179                tail: vec![
180                    "Service".into(),
181                    "Component".into(),
182                    "Database".into(),
183                ],
184                description: "dependency relationship".into(),
185            },
186        );
187        relation_types.insert(
188            "fixed".into(),
189            RelationSpec {
190                head: vec!["Person".into(), "Component".into()],
191                tail: vec!["Component".into(), "Service".into()],
192                description: "something fixed an issue".into(),
193            },
194        );
195        relation_types.insert(
196            "introduced".into(),
197            RelationSpec {
198                head: vec!["Person".into()],
199                tail: vec!["Component".into(), "Pattern".into()],
200                description: "person introduced a component".into(),
201            },
202        );
203        relation_types.insert(
204            "deprecated".into(),
205            RelationSpec {
206                head: vec!["Person".into(), "Decision".into()],
207                tail: vec!["Component".into(), "Pattern".into()],
208                description: "deprecation action".into(),
209            },
210        );
211        relation_types.insert(
212            "caused".into(),
213            RelationSpec {
214                head: vec!["Component".into(), "Decision".into()],
215                tail: vec!["Metric".into(), "Constraint".into()],
216                description: "causal relationship".into(),
217            },
218        );
219        relation_types.insert(
220            "constrained_by".into(),
221            RelationSpec {
222                head: vec!["Decision".into(), "Component".into()],
223                tail: vec!["Constraint".into()],
224                description: "decision constrained by".into(),
225            },
226        );
227
228        Self {
229            name: "default".into(),
230            entity_types,
231            relation_types,
232        }
233    }
234}
235
/// Errors produced while loading or parsing an extraction schema.
#[derive(Debug, thiserror::Error)]
pub enum SchemaError {
    /// The schema file could not be read. `path` is the display form of the
    /// requested path and `source` is the underlying I/O error.
    #[error("failed to read schema at {path}: {source}")]
    Io {
        path: String,
        source: std::io::Error,
    },

    /// The TOML text could not be deserialized into a schema. Carries the
    /// parser's error message as a string.
    #[error("failed to parse schema: {0}")]
    Parse(String),
}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in schema exposes exactly the ten expected entity keys.
    #[test]
    fn default_schema_has_all_entity_types() {
        let expected = [
            "Person",
            "Component",
            "Service",
            "Language",
            "Database",
            "Infrastructure",
            "Decision",
            "Constraint",
            "Metric",
            "Pattern",
        ];

        let schema = ExtractionSchema::default();
        let labels = schema.entity_labels();
        for entity in expected.iter() {
            assert!(labels.contains(entity));
        }
        assert_eq!(labels.len(), 10);
    }

    /// The built-in schema exposes exactly the nine expected relation names.
    #[test]
    fn default_schema_has_all_relation_types() {
        let expected = [
            "chose",
            "rejected",
            "replaced",
            "depends_on",
            "fixed",
            "introduced",
            "deprecated",
            "caused",
            "constrained_by",
        ];

        let schema = ExtractionSchema::default();
        let labels = schema.relation_labels();
        for relation in expected.iter() {
            assert!(labels.contains(relation));
        }
        assert_eq!(labels.len(), 9);
    }

    /// A document with entities and one relation parses into matching maps.
    #[test]
    fn parse_toml_schema() {
        let input = r#"
[schema]
name = "test"

[schema.entities]
Person = "A person"
Component = "A software component"

[schema.relations]
chose = { head = ["Person"], tail = ["Component"], description = "person chose" }
"#;
        let schema = ExtractionSchema::from_toml(input).unwrap();
        assert_eq!(schema.name, "test");
        assert_eq!(schema.entity_types.len(), 2);
        assert_eq!(schema.relation_types.len(), 1);
        assert_eq!(schema.relation_types["chose"].head, vec!["Person"]);
    }

    /// Omitting `[schema.relations]` yields an empty relation map.
    #[test]
    fn parse_toml_schema_no_relations() {
        let input = r#"
[schema]
name = "entities-only"

[schema.entities]
Person = "A person"
"#;
        let schema = ExtractionSchema::from_toml(input).unwrap();
        assert!(schema.relation_types.is_empty());
    }
}