Skip to main content

ie_schema/
lib.rs

1pub mod expanded;
2pub mod ingest;
3pub mod json_schema;
4pub mod lifted;
5pub mod normalized;
6pub mod prompt_plan;
7pub mod task_plan;
8pub mod token_plan;
9
10#[cfg(feature = "python")]
11use std::sync::Arc;
12#[cfg(feature = "python")]
13use std::sync::atomic::{AtomicUsize, Ordering};
14
15#[cfg(feature = "python")]
16use pyo3::exceptions::PyValueError;
17#[cfg(feature = "python")]
18use pyo3::prelude::*;
19#[cfg(feature = "python")]
20use pyo3::types::{PyAnyMethods, PyDict, PyModule, PyString, PyType};
21#[cfg(feature = "python")]
22use pyo3_stub_gen::define_stub_info_gatherer;
23#[cfg(feature = "python")]
24use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
25#[cfg(feature = "python")]
26use task_plan::PlannedTask;
27
/// Implements `From<$err> for PyErr` for each listed error type, surfacing the error's
/// `Display` text to Python as a `ValueError`.
///
/// Each stage module (`normalized`, `expanded`, `lifted`, `task_plan`, `prompt_plan`) reports
/// failures through its own error type; converting all of them the same way lets the Python
/// bindings propagate any stage's failure with `?`.
#[cfg(feature = "python")]
macro_rules! impl_value_error_from {
    ($($err:ty),+ $(,)?) => {
        $(
            impl From<$err> for PyErr {
                fn from(err: $err) -> Self {
                    PyValueError::new_err(err.to_string())
                }
            }
        )+
    };
}

#[cfg(feature = "python")]
impl_value_error_from!(
    normalized::SchemaLoadError,
    normalized::SchemaNormalizeError,
    expanded::SchemaExpandError,
    lifted::SchemaLiftError,
    task_plan::TaskPlanError,
    prompt_plan::PromptPlanError,
);
69
#[cfg(feature = "python")]
#[pymodule]
#[pyo3(name = "ie_schema")]
fn ieschema_library(module: &Bound<'_, PyModule>) -> PyResult<()> {
    // Python module initialiser, importable as `ie_schema`; registers every exposed class.

    // Schema entry point.
    module.add_class::<IESchema>()?;

    // Task hierarchy: the `Task` base class, then one concrete class per planned task kind.
    module.add_class::<Task>()?;
    module.add_class::<ClassificationTask>()?;
    module.add_class::<EntityExtractionTask>()?;
    module.add_class::<RelationExtractionTask>()?;
    module.add_class::<JSONStructureTask>()?;

    // Per-field view returned by `JSONStructureTask.children`.
    module.add_class::<StructureChild>()?;

    Ok(())
}
83
#[cfg(feature = "python")]
#[gen_stub_pyclass]
#[pyclass(module = "ie_schema")]
/// Information-extraction schema loaded from JSON (or from a dataclass / Pydantic model type).
///
/// Build an `IESchema` from a JSON string with `loads()` (IE ingest JSON or a root JSON Schema
/// object), from a path with `load()`, or from a stdlib dataclass / Pydantic v2 `BaseModel` by
/// passing the class (or an instance) to `loads()`. Iterating over the object yields task instances
/// in schema order.
///
/// Example:
/// >>> import ie_schema
/// >>> _j = '{"json_structures":[{"name":"Business","business_name":{"dtype":"str"}}]}'
/// >>> schema = ie_schema.IESchema.loads(_j)
/// >>> isinstance(schema, ie_schema.IESchema)
/// True
/// >>> len(list(schema))
/// 1
pub struct IESchema {
    // Planned tasks; the `Arc` is cloned into every task object produced during iteration.
    task_plan: Arc<task_plan::TaskPlan>,
    // Prompt plan derived from the task plan at construction time.
    prompt_plan: prompt_plan::PromptPlan,
    // Iteration cursor. It lives on the schema itself, so every `iter()` call rewinds it and
    // concurrent iterations over one schema share a single position.
    iter_index: AtomicUsize,
}

#[cfg(feature = "python")]
impl IESchema {
    /// Runs the post-normalization pipeline stages (expand → lift → task plan → prompt plan)
    /// and assembles the Python-facing object. Each stage's error becomes a `ValueError`
    /// through the crate's `From<…> for PyErr` impls.
    fn from_normalized(normalized: normalized::NormalizedSchema) -> PyResult<Self> {
        let lifted_schema = lifted::LiftedSchema::try_from(
            expanded::ExpandedSchema::try_from(normalized)?,
        )?;
        let plan = task_plan::TaskPlan::try_from(lifted_schema)?;
        // The prompt plan consumes its own copy; the original goes behind an `Arc`.
        let prompts = prompt_plan::PromptPlan::try_from(plan.clone())?;

        Ok(Self {
            task_plan: Arc::new(plan),
            prompt_plan: prompts,
            iter_index: AtomicUsize::new(0),
        })
    }

    /// Parses raw JSON bytes into a normalized schema, then runs the rest of the pipeline.
    fn loads_inner_bytes(bytes: &[u8]) -> PyResult<Self> {
        Self::from_normalized(normalized::NormalizedSchema::from_json_bytes(bytes)?)
    }

    /// String convenience wrapper around [`Self::loads_inner_bytes`].
    fn loads_inner(json_text: &str) -> PyResult<Self> {
        Self::loads_inner_bytes(json_text.as_bytes())
    }
}
131
/// JSON Schema as UTF-8 bytes for a stdlib dataclass type or Pydantic v2 `BaseModel` subclass.
///
/// `BaseModel` subclasses are rendered with `model_json_schema()`; dataclasses go through
/// `pydantic.TypeAdapter(type_obj).json_schema()`. The resulting Python object is serialized
/// with `json.dumps(..., ensure_ascii=False)` and returned as UTF-8 bytes.
///
/// # Errors
///
/// * `ValueError` when `type_obj` is a dataclass but `pydantic` cannot be imported; the
///   original import error is included in the message.
/// * `ValueError` from `loads_unsupported_input_error` when `type_obj` is neither a
///   `BaseModel` subclass nor a dataclass, including the case where `pydantic` is unavailable.
/// * Any Python exception raised by schema generation or JSON serialization.
#[cfg(feature = "python")]
fn json_schema_utf8_bytes_from_type<'py>(
    py: Python<'py>,
    type_obj: &Bound<'py, PyType>,
) -> PyResult<Vec<u8>> {
    let json_mod = PyModule::import(py, "json")?;

    let is_dc: bool = PyModule::import(py, "dataclasses")?
        .getattr("is_dataclass")?
        .call1((type_obj,))?
        .extract()?;

    // Pydantic is optional. Without it, a dataclass gets a dedicated install hint. Any other
    // type cannot be a `BaseModel` subclass, so it gets the generic unsupported-input error.
    let pydantic = match PyModule::import(py, "pydantic") {
        Ok(module) => module,
        Err(e) if is_dc => {
            return Err(PyValueError::new_err(format!(
                "IESchema.loads: converting a dataclass to JSON schema requires Pydantic v2 \
                 (install with `uv add pydantic` or `pip install pydantic`). \
                 Original import error: {e}"
            )));
        }
        Err(_) => return Err(loads_unsupported_input_error()),
    };

    let base_model = pydantic.getattr("BaseModel")?;
    // Any exception raised by the subclass check is deliberately treated as "not a model".
    let is_model = type_obj.is_subclass(&base_model).unwrap_or(false);

    let schema_obj = if is_model {
        type_obj.call_method0("model_json_schema")?
    } else if is_dc {
        pydantic
            .getattr("TypeAdapter")?
            .call1((type_obj,))?
            .call_method0("json_schema")?
    } else {
        return Err(loads_unsupported_input_error());
    };

    let kwargs = PyDict::new(py);
    kwargs.set_item("ensure_ascii", false)?;
    // Extracting the `str` directly yields its UTF-8 encoding. This avoids a Python-side
    // `.encode()` and an element-by-element `bytes` -> `Vec<u8>` extraction.
    let text: String = json_mod
        .getattr("dumps")?
        .call((&schema_obj,), Some(&kwargs))?
        .extract()?;
    Ok(text.into_bytes())
}
184
/// Builds the `ValueError` raised when `IESchema.loads` receives a value it cannot interpret.
#[cfg(feature = "python")]
fn loads_unsupported_input_error() -> PyErr {
    const MESSAGE: &str = "IESchema.loads: expected a JSON `str` (IE ingest or root JSON \
                           Schema), a `type` (stdlib dataclass or Pydantic v2 BaseModel), or an \
                           instance of such a type; got an unsupported value";
    PyValueError::new_err(MESSAGE)
}
193
#[cfg(feature = "python")]
#[gen_stub_pymethods]
#[pymethods]
impl IESchema {
    #[classmethod]
    /// Parse an `IESchema` from a JSON string or from a dataclass / Pydantic v2 `BaseModel` type.
    ///
    /// String input must be either IE ingest JSON (top-level keys such as `json_structures`,
    /// `entities`, …) or a root JSON Schema object (`type`, `properties`, …). Unknown top-level
    /// keys are rejected for the IE shape so JSON Schema is not misread as an empty ingest.
    ///
    /// For a dataclass or `BaseModel` type (or instance), Pydantic v2 builds JSON Schema
    /// (`TypeAdapter` for dataclasses, `model_json_schema()` for `BaseModel` subclasses), which is
    /// then parsed like JSON Schema string input.
    ///
    /// Example:
    /// >>> import ie_schema
    /// >>> schema = ie_schema.IESchema.loads('{"json_structures":[{"name":"Business","business_name":{"dtype":"str"}}]}')
    /// >>> len(list(schema))
    /// 1
    fn loads(_cls: &Bound<'_, PyType>, input: &Bound<'_, PyAny>) -> PyResult<Self> {
        // `str` (including subclasses) is parsed directly as JSON text.
        if input.is_instance_of::<PyString>() {
            let s: String = input.extract()?;
            return Self::loads_inner(&s);
        }

        // Anything else is treated as a class, or as an instance whose class is used.
        let type_obj: Bound<'_, PyType> = if let Ok(t) = input.cast::<PyType>() {
            t.clone()
        } else {
            input.get_type()
        };

        let utf8 = json_schema_utf8_bytes_from_type(input.py(), &type_obj)?;
        Self::loads_inner_bytes(&utf8)
    }

    #[classmethod]
    /// Parse an `IESchema` from a JSON file path.
    ///
    /// `path` may be a `str` or any `os.PathLike` (e.g. `pathlib.Path`). Read failures are
    /// raised as `ValueError`.
    fn load(_cls: &Bound<'_, PyType>, path: std::path::PathBuf) -> PyResult<Self> {
        let content = std::fs::read_to_string(&path).map_err(|e| {
            PyValueError::new_err(format!("failed to read {}: {}", path.display(), e))
        })?;
        Self::loads_inner(&content)
    }

    /// Return an iterator over planned extraction tasks.
    ///
    /// The schema is its own iterator: each `iter()` call rewinds a cursor stored on the schema,
    /// so nested or interleaved iterations over the same schema share one position.
    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf.iter_index.store(0, Ordering::Relaxed);
        slf
    }

    /// Return the next planned task; raises `StopIteration` once every task has been yielded.
    ///
    /// The returned object is an instance of the `Task` subclass matching the planned task kind.
    /// It holds a shared reference to the task plan plus its index and reads its fields lazily.
    fn __next__(slf: PyRefMut<'_, Self>) -> PyResult<Option<Py<PyAny>>> {
        let idx = slf.iter_index.load(Ordering::Relaxed);
        if idx >= slf.task_plan.tasks.len() {
            return Ok(None);
        }
        slf.iter_index.store(idx + 1, Ordering::Relaxed);

        let py = slf.py();
        let task_plan = Arc::clone(&slf.task_plan);
        let base = PyClassInitializer::from(Task {});

        // Object creation can fail (e.g. type initialisation); propagate instead of panicking.
        let obj = match &slf.task_plan.tasks[idx] {
            PlannedTask::Classification(_) => Bound::new(
                py,
                base.add_subclass(ClassificationTask {
                    task_plan,
                    index: idx,
                }),
            )?
            .into_any(),
            PlannedTask::Entity(_) => Bound::new(
                py,
                base.add_subclass(EntityExtractionTask {
                    task_plan,
                    index: idx,
                }),
            )?
            .into_any(),
            PlannedTask::Relation(_) => Bound::new(
                py,
                base.add_subclass(RelationExtractionTask {
                    task_plan,
                    index: idx,
                }),
            )?
            .into_any(),
            PlannedTask::Structure(_) => Bound::new(
                py,
                base.add_subclass(JSONStructureTask {
                    task_plan,
                    index: idx,
                }),
            )?
            .into_any(),
        };
        Ok(Some(obj.unbind()))
    }

    /// Render the generated extraction prompt as a debug string.
    ///
    /// Example:
    /// >>> import ie_schema
    /// >>> schema = ie_schema.IESchema.loads('{"json_structures":[{"name":"Business","business_name":{"dtype":"str"}}]}')
    /// >>> s = schema.prompt()
    /// >>> ("[P]" in s) and ("business_name" in s)
    /// True
    fn prompt(&self) -> String {
        self.prompt_plan.render_debug_string()
    }
}
315
#[cfg(feature = "python")]
#[gen_stub_pyclass]
#[pyclass(subclass, module = "ie_schema")]
/// Base class for all extraction tasks yielded by `IESchema`.
// Holds no state of its own: each concrete task class (`ClassificationTask`,
// `EntityExtractionTask`, `RelationExtractionTask`, `JSONStructureTask`) stores the shared task
// plan and its index. The `subclass` flag is what lets those classes declare `extends = Task`.
pub struct Task {}
321
#[cfg(feature = "python")]
#[gen_stub_pyclass]
#[pyclass(extends = Task, module = "ie_schema")]
/// Classification task definition with labels and threshold metadata.
pub struct ClassificationTask {
    // Plan shared with the owning `IESchema`. Invariant (upheld by `IESchema.__next__`):
    // `task_plan.tasks[index]` is a `PlannedTask::Classification`.
    task_plan: Arc<task_plan::TaskPlan>,
    index: usize,
}

#[cfg(feature = "python")]
#[gen_stub_pymethods]
#[pymethods]
impl ClassificationTask {
    #[getter]
    /// Classification task name.
    fn task(&self) -> String {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Classification(plan) => plan.task.to_string(),
            _ => unreachable!(),
        }
    }

    #[getter]
    /// Ordered list of class labels.
    fn labels(&self) -> Vec<String> {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Classification(plan) => {
                plan.labels.iter().map(|label| label.to_string()).collect()
            }
            _ => unreachable!(),
        }
    }

    #[getter]
    /// Optional confidence threshold for the classification.
    fn threshold(&self) -> Option<f64> {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Classification(plan) => plan.threshold,
            _ => unreachable!(),
        }
    }

    #[getter]
    /// Whether multiple labels may be assigned.
    fn multi_label(&self) -> bool {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Classification(plan) => plan.multi_label,
            _ => unreachable!(),
        }
    }
}
371
#[cfg(feature = "python")]
#[gen_stub_pyclass]
#[pyclass(extends = Task, module = "ie_schema")]
/// Entity extraction task definition.
pub struct EntityExtractionTask {
    // Plan shared with the owning `IESchema`. Invariant (upheld by `IESchema.__next__`):
    // `task_plan.tasks[index]` is a `PlannedTask::Entity`.
    task_plan: Arc<task_plan::TaskPlan>,
    index: usize,
}

#[cfg(feature = "python")]
#[gen_stub_pymethods]
#[pymethods]
impl EntityExtractionTask {
    #[getter]
    /// Entity labels that should be extracted.
    fn entities(&self) -> Vec<String> {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Entity(plan) => plan
                .entities
                .iter()
                .map(|entity| entity.to_string())
                .collect(),
            _ => unreachable!(),
        }
    }
}
394
#[cfg(feature = "python")]
#[gen_stub_pyclass]
#[pyclass(extends = Task, module = "ie_schema")]
/// Relation extraction task between head and tail entity types.
pub struct RelationExtractionTask {
    // Plan shared with the owning `IESchema`. Invariant (upheld by `IESchema.__next__`):
    // `task_plan.tasks[index]` is a `PlannedTask::Relation`.
    task_plan: Arc<task_plan::TaskPlan>,
    index: usize,
}

#[cfg(feature = "python")]
#[gen_stub_pymethods]
#[pymethods]
impl RelationExtractionTask {
    #[getter]
    /// Relation name.
    fn name(&self) -> String {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Relation(plan) => plan.relation.to_string(),
            _ => unreachable!(),
        }
    }

    #[getter]
    /// Head entity type.
    fn head(&self) -> String {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Relation(plan) => plan.head.to_string(),
            _ => unreachable!(),
        }
    }

    #[getter]
    /// Tail entity type.
    fn tail(&self) -> String {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Relation(plan) => plan.tail.to_string(),
            _ => unreachable!(),
        }
    }

    #[getter]
    /// Optional human-readable relation description.
    fn description(&self) -> Option<String> {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Relation(plan) => plan.description.clone(),
            _ => unreachable!(),
        }
    }
}
444
#[cfg(feature = "python")]
#[gen_stub_pyclass]
#[pyclass(extends = Task, module = "ie_schema")]
/// Structured JSON extraction task with named children.
pub struct JSONStructureTask {
    // Plan shared with the owning `IESchema`. Invariant (upheld by `IESchema.__next__`):
    // `task_plan.tasks[index]` is a `PlannedTask::Structure`.
    task_plan: Arc<task_plan::TaskPlan>,
    index: usize,
}

#[cfg(feature = "python")]
#[gen_stub_pymethods]
#[pymethods]
impl JSONStructureTask {
    #[getter]
    /// Structure name.
    fn name(&self) -> String {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Structure(plan) => plan.structure.to_string(),
            _ => unreachable!(),
        }
    }

    #[getter]
    /// Child fields that belong to this structure.
    fn children(&self) -> Vec<StructureChild> {
        let child_count = match &self.task_plan.tasks[self.index] {
            PlannedTask::Structure(plan) => plan.children.iter().count(),
            _ => unreachable!(),
        };
        // Each child is a lightweight handle (shared plan + indices), not a copy of the data.
        (0..child_count)
            .map(|child_index| StructureChild {
                task_plan: Arc::clone(&self.task_plan),
                structure_index: self.index,
                child_index,
            })
            .collect()
    }
}
484
#[cfg(feature = "python")]
#[gen_stub_pyclass]
#[pyclass(module = "ie_schema")]
/// Child field in a `JSONStructureTask`.
pub struct StructureChild {
    // Shared plan; `tasks[structure_index]` is a `PlannedTask::Structure` (upheld by
    // `JSONStructureTask.children`), and `child_index` indexes into its `children`.
    task_plan: Arc<task_plan::TaskPlan>,
    structure_index: usize,
    child_index: usize,
}

#[cfg(feature = "python")]
#[gen_stub_pymethods]
#[pymethods]
impl StructureChild {
    #[getter]
    /// Property name for this child field.
    fn property(&self) -> String {
        match &self.task_plan.tasks[self.structure_index] {
            PlannedTask::Structure(plan) => plan.children[self.child_index].property.to_string(),
            _ => unreachable!(),
        }
    }

    #[getter]
    /// Allowed string choices for this property.
    fn choices(&self) -> Vec<String> {
        match &self.task_plan.tasks[self.structure_index] {
            PlannedTask::Structure(plan) => plan.children[self.child_index]
                .choices
                .iter()
                .map(|choice| choice.to_string())
                .collect(),
            _ => unreachable!(),
        }
    }

    #[getter]
    /// Optional child-field description.
    fn description(&self) -> Option<String> {
        match &self.task_plan.tasks[self.structure_index] {
            PlannedTask::Structure(plan) => plan.children[self.child_index].description.clone(),
            _ => unreachable!(),
        }
    }
}
530
// Expands (via pyo3-stub-gen) into a `stub_info` function that gathers the
// `#[gen_stub_pyclass]` / `#[gen_stub_pymethods]` items of this crate for `.pyi` stub
// generation. NOTE(review): presumably invoked from a separate stub-generation binary — confirm.
#[cfg(feature = "python")]
define_stub_info_gatherer!(stub_info);