1pub mod expanded;
2pub mod ingest;
3pub mod json_schema;
4pub mod lifted;
5pub mod normalized;
6pub mod prompt_plan;
7pub mod task_plan;
8pub mod token_plan;
9
10#[cfg(feature = "python")]
11use std::sync::Arc;
12#[cfg(feature = "python")]
13use std::sync::atomic::{AtomicUsize, Ordering};
14
15#[cfg(feature = "python")]
16use pyo3::exceptions::PyValueError;
17#[cfg(feature = "python")]
18use pyo3::prelude::*;
19#[cfg(feature = "python")]
20use pyo3::types::{PyAnyMethods, PyDict, PyModule, PyString, PyType};
21#[cfg(feature = "python")]
22use pyo3_stub_gen::define_stub_info_gatherer;
23#[cfg(feature = "python")]
24use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
25#[cfg(feature = "python")]
26use task_plan::PlannedTask;
27
#[cfg(feature = "python")]
impl From<normalized::SchemaLoadError> for PyErr {
    /// Surfaces a schema load failure to Python as a `ValueError` whose message is the
    /// error's `to_string()` text.
    fn from(err: normalized::SchemaLoadError) -> Self {
        let message = err.to_string();
        PyValueError::new_err(message)
    }
}
34
#[cfg(feature = "python")]
impl From<normalized::SchemaNormalizeError> for PyErr {
    /// Surfaces a schema normalization failure to Python as a `ValueError` whose message is
    /// the error's `to_string()` text.
    fn from(err: normalized::SchemaNormalizeError) -> Self {
        let message = err.to_string();
        PyValueError::new_err(message)
    }
}
41
#[cfg(feature = "python")]
impl From<expanded::SchemaExpandError> for PyErr {
    /// Surfaces a schema expansion failure to Python as a `ValueError` whose message is the
    /// error's `to_string()` text.
    fn from(err: expanded::SchemaExpandError) -> Self {
        let message = err.to_string();
        PyValueError::new_err(message)
    }
}
48
#[cfg(feature = "python")]
impl From<lifted::SchemaLiftError> for PyErr {
    /// Surfaces a schema lifting failure to Python as a `ValueError` whose message is the
    /// error's `to_string()` text.
    fn from(err: lifted::SchemaLiftError) -> Self {
        let message = err.to_string();
        PyValueError::new_err(message)
    }
}
55
#[cfg(feature = "python")]
impl From<task_plan::TaskPlanError> for PyErr {
    /// Surfaces a task-planning failure to Python as a `ValueError` whose message is the
    /// error's `to_string()` text.
    fn from(err: task_plan::TaskPlanError) -> Self {
        let message = err.to_string();
        PyValueError::new_err(message)
    }
}
62
#[cfg(feature = "python")]
impl From<prompt_plan::PromptPlanError> for PyErr {
    /// Surfaces a prompt-planning failure to Python as a `ValueError` whose message is the
    /// error's `to_string()` text.
    fn from(err: prompt_plan::PromptPlanError) -> Self {
        let message = err.to_string();
        PyValueError::new_err(message)
    }
}
69
#[cfg(feature = "python")]
/// Python module `ie_schema`: registers the schema class, the task base class, the four
/// concrete task classes and the structure-child class, in that order.
#[pymodule]
#[pyo3(name = "ie_schema")]
fn ieschema_library(module: &Bound<'_, PyModule>) -> PyResult<()> {
    module.add_class::<IESchema>()?;
    module.add_class::<Task>()?;
    module.add_class::<ClassificationTask>()?;
    module.add_class::<EntityExtractionTask>()?;
    module.add_class::<RelationExtractionTask>()?;
    module.add_class::<JSONStructureTask>()?;
    // The last registration's result doubles as the function's result.
    module.add_class::<StructureChild>()
}
83
#[cfg(feature = "python")]
/// Python-facing schema object: owns the task plan derived from a schema and the prompt
/// plan built from that task plan, and doubles as its own iterator over the planned tasks.
#[gen_stub_pyclass]
#[pyclass(module = "ie_schema")]
pub struct IESchema {
    /// Planned tasks; reference-counted so task objects handed to Python can share it.
    task_plan: Arc<task_plan::TaskPlan>,
    /// Prompt plan built from a clone of the task plan.
    prompt_plan: prompt_plan::PromptPlan,
    /// Position of the next task to yield from `__next__`.
    iter_index: AtomicUsize,
}

#[cfg(feature = "python")]
impl IESchema {
    /// Runs the conversion chain normalized → expanded → lifted → task plan → prompt plan.
    ///
    /// Any stage error is converted to a Python exception through the `From<...> for PyErr`
    /// impls in this file.
    fn from_normalized(normalized: normalized::NormalizedSchema) -> PyResult<Self> {
        let expanded_schema = expanded::ExpandedSchema::try_from(normalized)?;
        let lifted_schema = lifted::LiftedSchema::try_from(expanded_schema)?;
        let plan = task_plan::TaskPlan::try_from(lifted_schema)?;
        // The prompt plan consumes its own copy; the original is kept for task objects.
        let prompts = prompt_plan::PromptPlan::try_from(plan.clone())?;

        let schema = Self {
            task_plan: Arc::new(plan),
            prompt_plan: prompts,
            iter_index: AtomicUsize::new(0),
        };
        Ok(schema)
    }

    /// Parses raw JSON bytes into a normalized schema and builds the full `IESchema` from it.
    fn loads_inner_bytes(bytes: &[u8]) -> PyResult<Self> {
        Self::from_normalized(normalized::NormalizedSchema::from_json_bytes(bytes)?)
    }

    /// String convenience wrapper around [`Self::loads_inner_bytes`].
    fn loads_inner(s: &str) -> PyResult<Self> {
        let bytes = s.as_bytes();
        Self::loads_inner_bytes(bytes)
    }
}
131
#[cfg(feature = "python")]
/// Produces the UTF-8 encoded JSON Schema text for a Python type.
///
/// * Subclasses of `pydantic.BaseModel` are asked for `model_json_schema()`.
/// * Stdlib dataclasses go through `pydantic.TypeAdapter(type).json_schema()`.
/// * Anything else yields the generic unsupported-input error, as does a non-dataclass
///   when `pydantic` cannot be imported.
///
/// A dataclass without an importable `pydantic` gets a dedicated `ValueError` that explains
/// how to install it and includes the original import error.
fn json_schema_utf8_bytes_from_type<'py>(
    py: Python<'py>,
    type_obj: &Bound<'py, PyType>,
) -> PyResult<Vec<u8>> {
    // Modules are imported in a fixed order so import failures surface deterministically.
    let json = PyModule::import(py, "json")?;
    let builtins = PyModule::import(py, "builtins")?;

    let is_dc: bool = PyModule::import(py, "dataclasses")?
        .getattr("is_dataclass")?
        .call1((type_obj,))?
        .extract()?;

    let pydantic = match PyModule::import(py, "pydantic") {
        Ok(module) => module,
        // Dataclasses can only be converted via Pydantic, so tell the user how to get it.
        Err(e) if is_dc => {
            return Err(PyValueError::new_err(format!(
                "IESchema.loads: converting a dataclass to JSON schema requires Pydantic v2 \
                 (install with `uv add pydantic` or `pip install pydantic`). \
                 Original import error: {e}"
            )));
        }
        // Without Pydantic a non-dataclass can be neither a model nor convertible.
        Err(_) => return Err(loads_unsupported_input_error()),
    };

    let base_model = pydantic.getattr("BaseModel")?;
    let issubclass = builtins.getattr("issubclass")?;
    // An exception raised by `issubclass` itself is treated as "not a model".
    let is_model = if let Ok(verdict) = issubclass.call1((type_obj, &base_model)) {
        verdict.is_truthy()?
    } else {
        false
    };

    let schema = match (is_model, is_dc) {
        (true, _) => type_obj.call_method0("model_json_schema")?,
        (false, true) => pydantic
            .getattr("TypeAdapter")?
            .call1((type_obj,))?
            .call_method0("json_schema")?,
        (false, false) => return Err(loads_unsupported_input_error()),
    };

    // json.dumps(schema, ensure_ascii=False).encode("utf-8")
    let dumps = json.getattr("dumps")?;
    let dump_options = PyDict::new(py);
    dump_options.set_item("ensure_ascii", false)?;
    let text = dumps.call((&schema,), Some(&dump_options))?;
    let utf8 = text.call_method1("encode", ("utf-8",))?;
    utf8.extract()
}
184
#[cfg(feature = "python")]
/// Builds the `ValueError` raised when `IESchema.loads` receives a value it cannot turn
/// into a schema.
fn loads_unsupported_input_error() -> PyErr {
    let message = "IESchema.loads: expected a JSON `str` (IE ingest or root JSON Schema), a `type` \
                   (stdlib dataclass or Pydantic v2 BaseModel), or an instance of such a type; got an \
                   unsupported value";
    PyValueError::new_err(message)
}
193
#[cfg(feature = "python")]
#[gen_stub_pymethods]
#[pymethods]
impl IESchema {
    /// Builds a schema from `input`.
    ///
    /// * A Python `str` is parsed directly as JSON.
    /// * A `type` is converted to JSON Schema by `json_schema_utf8_bytes_from_type`.
    /// * Any other object is treated as an instance and its type is converted instead.
    ///
    /// # Errors
    ///
    /// Raises `ValueError` for unsupported inputs and for any failure in the schema
    /// pipeline; exceptions raised by the Python calls made during conversion propagate.
    #[classmethod]
    fn loads(_cls: &Bound<'_, PyType>, input: &Bound<'_, PyAny>) -> PyResult<Self> {
        if input.is_instance_of::<PyString>() {
            let s: String = input.extract()?;
            return Self::loads_inner(&s);
        }

        // Accept both a class and an instance of that class.
        let type_obj: Bound<'_, PyType> = if let Ok(t) = input.cast::<PyType>() {
            t.clone()
        } else {
            input.get_type()
        };

        let utf8 = json_schema_utf8_bytes_from_type(input.py(), &type_obj)?;
        Self::loads_inner_bytes(&utf8)
    }

    /// Reads the file at `path` as text and parses it the same way `loads` parses a `str`.
    ///
    /// # Errors
    ///
    /// Raises `ValueError` when the file cannot be read (the message includes the path and
    /// the I/O error) or when the content fails the schema pipeline.
    #[classmethod]
    fn load(_cls: &Bound<'_, PyType>, path: String) -> PyResult<Self> {
        let content = std::fs::read_to_string(&path)
            .map_err(|e| PyValueError::new_err(format!("failed to read {}: {}", path, e)))?;
        Self::loads_inner(&content)
    }

    /// Rewinds the cursor to the first task and returns the schema as its own iterator.
    ///
    /// NOTE(review): the cursor is stored on the schema object itself, so nested or
    /// interleaved iterations over the same `IESchema` share (and reset) one position.
    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf.iter_index.store(0, Ordering::Relaxed);
        slf
    }

    /// Yields the next planned task wrapped in its Python task class, or `None` (which PyO3
    /// turns into `StopIteration`) once every task has been produced.
    ///
    /// # Errors
    ///
    /// Propagates the Python exception if allocating the task object fails. This used to be
    /// an `unwrap()`, which panicked across the FFI boundary.
    fn __next__(slf: PyRefMut<'_, Self>) -> PyResult<Option<Py<PyAny>>> {
        let idx = slf.iter_index.load(Ordering::Relaxed);
        if idx >= slf.task_plan.tasks.len() {
            return Ok(None);
        }
        // Advance before constructing the object, matching the original ordering.
        slf.iter_index.store(idx + 1, Ordering::Relaxed);

        let py = slf.py();
        let task_plan = Arc::clone(&slf.task_plan);
        let base = PyClassInitializer::from(Task {});

        // Each arm differs only in which subclass is layered on top of `Task`.
        let task = match &slf.task_plan.tasks[idx] {
            PlannedTask::Classification(_) => Bound::new(
                py,
                base.add_subclass(ClassificationTask {
                    task_plan,
                    index: idx,
                }),
            )?
            .into_any(),
            PlannedTask::Entity(_) => Bound::new(
                py,
                base.add_subclass(EntityExtractionTask {
                    task_plan,
                    index: idx,
                }),
            )?
            .into_any(),
            PlannedTask::Relation(_) => Bound::new(
                py,
                base.add_subclass(RelationExtractionTask {
                    task_plan,
                    index: idx,
                }),
            )?
            .into_any(),
            PlannedTask::Structure(_) => Bound::new(
                py,
                base.add_subclass(JSONStructureTask {
                    task_plan,
                    index: idx,
                }),
            )?
            .into_any(),
        };

        Ok(Some(task.unbind()))
    }

    /// Returns the prompt plan rendered with `render_debug_string`.
    fn prompt(&self) -> String {
        self.prompt_plan.render_debug_string()
    }
}
315
#[cfg(feature = "python")]
/// Base Python class for the task objects yielded when iterating an `IESchema`.
///
/// It carries no data of its own; the concrete task classes in this file declare
/// `extends = Task` and are constructed on top of a `Task {}` initializer.
#[gen_stub_pyclass]
#[pyclass(subclass, module = "ie_schema")]
pub struct Task {}
321
#[cfg(feature = "python")]
/// Python view of one classification entry in a shared task plan.
#[gen_stub_pyclass]
#[pyclass(extends = Task, module = "ie_schema")]
pub struct ClassificationTask {
    /// Plan shared with the originating `IESchema`.
    task_plan: Arc<task_plan::TaskPlan>,
    /// Position in `task_plan.tasks`; the getters treat any non-classification entry at
    /// this position as unreachable.
    index: usize,
}

#[cfg(feature = "python")]
#[gen_stub_pymethods]
#[pymethods]
impl ClassificationTask {
    /// The planned entry's `task` field, converted with `to_string()`.
    #[getter]
    fn task(&self) -> String {
        let PlannedTask::Classification(plan) = &self.task_plan.tasks[self.index] else {
            unreachable!()
        };
        plan.task.to_string()
    }

    /// The planned entry's labels, each converted with `to_string()`, in plan order.
    #[getter]
    fn labels(&self) -> Vec<String> {
        let PlannedTask::Classification(plan) = &self.task_plan.tasks[self.index] else {
            unreachable!()
        };
        plan.labels.iter().map(|label| label.to_string()).collect()
    }

    /// The planned entry's optional `threshold`.
    #[getter]
    fn threshold(&self) -> Option<f64> {
        let PlannedTask::Classification(plan) = &self.task_plan.tasks[self.index] else {
            unreachable!()
        };
        plan.threshold
    }

    /// The planned entry's `multi_label` flag.
    #[getter]
    fn multi_label(&self) -> bool {
        let PlannedTask::Classification(plan) = &self.task_plan.tasks[self.index] else {
            unreachable!()
        };
        plan.multi_label
    }
}
371
#[cfg(feature = "python")]
/// Python view of one entity-extraction entry in a shared task plan.
#[gen_stub_pyclass]
#[pyclass(extends = Task, module = "ie_schema")]
pub struct EntityExtractionTask {
    /// Plan shared with the originating `IESchema`.
    task_plan: Arc<task_plan::TaskPlan>,
    /// Position in `task_plan.tasks`; the getter treats any non-entity entry at this
    /// position as unreachable.
    index: usize,
}

#[cfg(feature = "python")]
#[gen_stub_pymethods]
#[pymethods]
impl EntityExtractionTask {
    /// The planned entry's entities, each converted with `to_string()`, in plan order.
    #[getter]
    fn entities(&self) -> Vec<String> {
        let PlannedTask::Entity(plan) = &self.task_plan.tasks[self.index] else {
            unreachable!()
        };
        plan.entities
            .iter()
            .map(|entity| entity.to_string())
            .collect()
    }
}
394
#[cfg(feature = "python")]
/// Python view of one relation-extraction entry in a shared task plan.
#[gen_stub_pyclass]
#[pyclass(extends = Task, module = "ie_schema")]
pub struct RelationExtractionTask {
    /// Plan shared with the originating `IESchema`.
    task_plan: Arc<task_plan::TaskPlan>,
    /// Position in `task_plan.tasks`; the getters treat any non-relation entry at this
    /// position as unreachable.
    index: usize,
}

#[cfg(feature = "python")]
#[gen_stub_pymethods]
#[pymethods]
impl RelationExtractionTask {
    /// The planned entry's `relation` field, converted with `to_string()`.
    #[getter]
    fn name(&self) -> String {
        let PlannedTask::Relation(plan) = &self.task_plan.tasks[self.index] else {
            unreachable!()
        };
        plan.relation.to_string()
    }

    /// The planned entry's `head` field, converted with `to_string()`.
    #[getter]
    fn head(&self) -> String {
        let PlannedTask::Relation(plan) = &self.task_plan.tasks[self.index] else {
            unreachable!()
        };
        plan.head.to_string()
    }

    /// The planned entry's `tail` field, converted with `to_string()`.
    #[getter]
    fn tail(&self) -> String {
        let PlannedTask::Relation(plan) = &self.task_plan.tasks[self.index] else {
            unreachable!()
        };
        plan.tail.to_string()
    }

    /// A copy of the planned entry's optional `description`.
    #[getter]
    fn description(&self) -> Option<String> {
        let PlannedTask::Relation(plan) = &self.task_plan.tasks[self.index] else {
            unreachable!()
        };
        plan.description.clone()
    }
}
444
#[cfg(feature = "python")]
/// Python view of one structure entry in a shared task plan.
#[gen_stub_pyclass]
#[pyclass(extends = Task, module = "ie_schema")]
pub struct JSONStructureTask {
    /// Plan shared with the originating `IESchema`.
    task_plan: Arc<task_plan::TaskPlan>,
    /// Position in `task_plan.tasks`; the getters treat any non-structure entry at this
    /// position as unreachable.
    index: usize,
}

#[cfg(feature = "python")]
#[gen_stub_pymethods]
#[pymethods]
impl JSONStructureTask {
    /// The planned entry's `structure` field, converted with `to_string()`.
    #[getter]
    fn name(&self) -> String {
        let PlannedTask::Structure(plan) = &self.task_plan.tasks[self.index] else {
            unreachable!()
        };
        plan.structure.to_string()
    }

    /// One `StructureChild` handle per child of the planned structure, in plan order.
    ///
    /// Each handle stores only indices plus a shared reference to the plan.
    #[getter]
    fn children(&self) -> Vec<StructureChild> {
        let PlannedTask::Structure(plan) = &self.task_plan.tasks[self.index] else {
            unreachable!()
        };
        let shared_plan = &self.task_plan;
        let structure_index = self.index;
        plan.children
            .iter()
            .enumerate()
            .map(|(child_index, _)| StructureChild {
                task_plan: Arc::clone(shared_plan),
                structure_index,
                child_index,
            })
            .collect()
    }
}
484
#[cfg(feature = "python")]
/// Python view of a single child of a structure entry in a shared task plan.
#[gen_stub_pyclass]
#[pyclass(module = "ie_schema")]
pub struct StructureChild {
    /// Plan shared with the originating `IESchema`.
    task_plan: Arc<task_plan::TaskPlan>,
    /// Position of the owning structure in `task_plan.tasks`; the getters treat any
    /// non-structure entry at this position as unreachable.
    structure_index: usize,
    /// Position of this child within the structure's `children`.
    child_index: usize,
}

#[cfg(feature = "python")]
#[gen_stub_pymethods]
#[pymethods]
impl StructureChild {
    /// The child's `property` field, converted with `to_string()`.
    #[getter]
    fn property(&self) -> String {
        let PlannedTask::Structure(plan) = &self.task_plan.tasks[self.structure_index] else {
            unreachable!()
        };
        let child = &plan.children[self.child_index];
        child.property.to_string()
    }

    /// The child's choices, each converted with `to_string()`, in plan order.
    #[getter]
    fn choices(&self) -> Vec<String> {
        let PlannedTask::Structure(plan) = &self.task_plan.tasks[self.structure_index] else {
            unreachable!()
        };
        let child = &plan.children[self.child_index];
        child
            .choices
            .iter()
            .map(|choice| choice.to_string())
            .collect()
    }

    /// A copy of the child's optional `description`.
    #[getter]
    fn description(&self) -> Option<String> {
        let PlannedTask::Structure(plan) = &self.task_plan.tasks[self.structure_index] else {
            unreachable!()
        };
        let child = &plan.children[self.child_index];
        child.description.clone()
    }
}
530
#[cfg(feature = "python")]
// Invokes pyo3-stub-gen's `define_stub_info_gatherer!` with the identifier `stub_info`.
// NOTE(review): presumably this defines a `stub_info` function that a stub-generation
// binary calls to emit `.pyi` files for the classes above — confirm against the crate's
// bin target.
define_stub_info_gatherer!(stub_info);