//! ie-schema 0.1.0
//!
//! A flexible schema specification and parser for information extraction tasks.

pub mod expanded;
pub mod ingest;
pub mod lifted;
pub mod normalized;
pub mod prompt_plan;
pub mod task_plan;
pub mod token_plan;

#[cfg(feature = "python")]
use std::sync::Arc;
#[cfg(feature = "python")]
use std::sync::atomic::{AtomicUsize, Ordering};

#[cfg(feature = "python")]
use pyo3::exceptions::PyValueError;
#[cfg(feature = "python")]
use pyo3::prelude::*;
#[cfg(feature = "python")]
use pyo3::types::PyType;



// Every schema-pipeline stage error reaches Python as a `ValueError` whose message is
// the error's `Display` text. The macro stamps out one `From` impl per error type so
// that `?` converts each stage's error into a `PyErr`.
#[cfg(feature = "python")]
macro_rules! into_py_value_error {
    ($($stage_err:ty),+ $(,)?) => {
        $(
            impl From<$stage_err> for PyErr {
                fn from(e: $stage_err) -> Self {
                    PyValueError::new_err(e.to_string())
                }
            }
        )+
    };
}

#[cfg(feature = "python")]
into_py_value_error!(
    normalized::SchemaLoadError,
    expanded::SchemaExpandError,
    lifted::SchemaLiftError,
    task_plan::TaskPlanError,
    prompt_plan::PromptPlanError,
);

/// Python extension-module initializer, exported to Python under the name `ie_schema`.
///
/// Registers `IESchema`, the `Task` base class, its four concrete task subclasses,
/// and `StructureChild`.
#[cfg(feature = "python")]
#[pymodule]
#[pyo3(name = "ie_schema")]
fn ieschema_library(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<IESchema>()?;
    m.add_class::<Task>()?;
    m.add_class::<ClassificationTask>()?;
    m.add_class::<EntityExtractionTask>()?;
    m.add_class::<RelationExtractionTask>()?;
    m.add_class::<JSONStructureTask>()?;
    m.add_class::<StructureChild>()?;
    Ok(())
}

/// Parsed information-extraction schema exposed to Python.
///
/// Built with `IESchema.loads(json_str)` or `IESchema.load(path)`. Iterating it yields
/// one Python task object per planned task, in plan order. `prompt()` returns a debug
/// rendering of the prompt plan.
#[cfg(feature = "python")]
#[pyclass]
pub struct IESchema {
    /// Planned tasks; the `Arc` is shared with every task object handed to Python.
    task_plan: Arc<task_plan::TaskPlan>,
    /// Prompt plan derived from the task plan at load time.
    prompt_plan: prompt_plan::PromptPlan,
    /// Cursor for the schema's own iterator protocol (`__iter__` resets it to 0).
    iter_index: AtomicUsize,
}

#[cfg(feature = "python")]
impl IESchema {
    /// Runs the JSON text `s` through every schema stage and builds a fresh `IESchema`.
    ///
    /// Stage order: JSON → normalized → expanded → lifted → task plan → prompt plan.
    /// Any stage failure is propagated with `?` as a `PyErr`.
    fn loads_inner(s: &str) -> PyResult<Self> {
        let norm = normalized::NormalizedSchema::from_json_str(s)?;
        let exp = expanded::ExpandedSchema::try_from(norm)?;
        let lift = lifted::LiftedSchema::try_from(exp)?;
        let plan = task_plan::TaskPlan::try_from(lift)?;

        // The prompt plan is derived from its own copy, so `plan` can still be shared.
        let prompts = prompt_plan::PromptPlan::try_from(plan.clone())?;

        let shared_plan = Arc::new(plan);
        Ok(Self {
            task_plan: shared_plan,
            prompt_plan: prompts,
            iter_index: AtomicUsize::new(0),
        })
    }
}

#[cfg(feature = "python")]
#[pymethods]
impl IESchema {
    /// Parses a schema from a JSON string.
    ///
    /// Raises `ValueError` if any pipeline stage rejects the schema.
    #[classmethod]
    fn loads(_cls: &Bound<'_, PyType>, s: String) -> PyResult<Self> {
        Self::loads_inner(&s)
    }

    /// Reads a JSON schema from the file at `path` and parses it like `loads`.
    ///
    /// `path` may be a `str` or any `os.PathLike` (e.g. `pathlib.Path`).
    /// Raises `ValueError` if the file cannot be read or the schema is invalid.
    #[classmethod]
    fn load(_cls: &Bound<'_, PyType>, path: std::path::PathBuf) -> PyResult<Self> {
        let content = std::fs::read_to_string(&path).map_err(|e| {
            PyValueError::new_err(format!("failed to read {}: {}", path.display(), e))
        })?;
        Self::loads_inner(&content)
    }

    /// Rewinds the schema's task cursor and returns the schema itself as the iterator.
    ///
    /// NOTE(review): the schema is its own iterator with a single shared cursor, so
    /// nested or interleaved iteration (`for a in s: for b in s: ...`, `zip(s, s)`)
    /// shares progress, and each new `iter(s)` restarts every active loop.
    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf.iter_index.store(0, Ordering::Relaxed);
        slf
    }

    /// Returns the next planned task wrapped in its concrete `Task` subclass,
    /// or `None` (i.e. `StopIteration`) once all tasks have been yielded.
    ///
    /// Errors raised while creating the Python object are propagated instead of
    /// panicking. Only a shared borrow is needed because the cursor is atomic.
    fn __next__(slf: PyRef<'_, Self>) -> PyResult<Option<Py<PyAny>>> {
        let len = slf.task_plan.tasks.len();
        // Claim the current index and advance in one atomic step, so the bounds check
        // and the increment cannot be split. Once exhausted the cursor stays put.
        let Ok(idx) = slf.iter_index.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |i| {
            (i < len).then_some(i + 1)
        }) else {
            return Ok(None);
        };

        let py = slf.py();
        let task_plan = Arc::clone(&slf.task_plan);
        let base = PyClassInitializer::from(Task {});

        let obj = match &slf.task_plan.tasks[idx] {
            PlannedTask::Classification(_) => Bound::new(
                py,
                base.add_subclass(ClassificationTask { task_plan, index: idx }),
            )?
            .into_any(),
            PlannedTask::Entity(_) => Bound::new(
                py,
                base.add_subclass(EntityExtractionTask { task_plan, index: idx }),
            )?
            .into_any(),
            PlannedTask::Relation(_) => Bound::new(
                py,
                base.add_subclass(RelationExtractionTask { task_plan, index: idx }),
            )?
            .into_any(),
            PlannedTask::Structure(_) => Bound::new(
                py,
                base.add_subclass(JSONStructureTask { task_plan, index: idx }),
            )?
            .into_any(),
        };
        Ok(Some(obj.unbind()))
    }

    /// Returns the debug rendering of the prompt plan.
    fn prompt(&self) -> String {
        self.prompt_plan.render_debug_string()
    }
}

/// Common Python base class for all task views; it carries no data of its own.
///
/// Declared `subclass` so that the concrete task classes can use `extends = Task`.
#[cfg(feature = "python")]
#[pyclass(subclass)]
pub struct Task {}

/// Python view of one classification task inside a shared task plan.
///
/// `index` must refer to a `PlannedTask::Classification` entry of `task_plan`. Every
/// getter relies on this and reaches `unreachable!()` otherwise.
#[cfg(feature = "python")]
#[pyclass(extends = Task)]
pub struct ClassificationTask {
    task_plan: Arc<task_plan::TaskPlan>,
    index: usize,
}

#[cfg(feature = "python")]
#[pymethods]
impl ClassificationTask {
    /// The planned task's `task` value, converted to a string.
    #[getter]
    fn task(&self) -> String {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Classification(planned) => planned.task.to_string(),
            _ => unreachable!(),
        }
    }

    /// The candidate labels, each converted to a string, in plan order.
    #[getter]
    fn labels(&self) -> Vec<String> {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Classification(planned) => {
                planned.labels.iter().map(|label| label.to_string()).collect()
            }
            _ => unreachable!(),
        }
    }

    /// The optional decision threshold stored in the plan.
    #[getter]
    fn threshold(&self) -> Option<f64> {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Classification(planned) => planned.threshold,
            _ => unreachable!(),
        }
    }

    /// Whether the plan marks this classification as multi-label.
    #[getter]
    fn multi_label(&self) -> bool {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Classification(planned) => planned.multi_label,
            _ => unreachable!(),
        }
    }
}

/// Python view of one entity-extraction task inside a shared task plan.
///
/// `index` must refer to a `PlannedTask::Entity` entry of `task_plan`; otherwise the
/// getter reaches `unreachable!()`.
#[cfg(feature = "python")]
#[pyclass(extends = Task)]
pub struct EntityExtractionTask {
    task_plan: Arc<task_plan::TaskPlan>,
    index: usize,
}

#[cfg(feature = "python")]
#[pymethods]
impl EntityExtractionTask {
    /// The entity types to extract, each converted to a string, in plan order.
    #[getter]
    fn entities(&self) -> Vec<String> {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Entity(planned) => planned
                .entities
                .iter()
                .map(|entity| entity.to_string())
                .collect(),
            _ => unreachable!(),
        }
    }
}

/// Python view of one relation-extraction task inside a shared task plan.
///
/// `index` must refer to a `PlannedTask::Relation` entry of `task_plan`. Every getter
/// relies on this and reaches `unreachable!()` otherwise.
#[cfg(feature = "python")]
#[pyclass(extends = Task)]
pub struct RelationExtractionTask {
    task_plan: Arc<task_plan::TaskPlan>,
    index: usize,
}

#[cfg(feature = "python")]
#[pymethods]
impl RelationExtractionTask {
    /// The relation's name (the plan's `relation` field), as a string.
    #[getter]
    fn name(&self) -> String {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Relation(planned) => planned.relation.to_string(),
            _ => unreachable!(),
        }
    }

    /// The relation's `head` value, as a string.
    #[getter]
    fn head(&self) -> String {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Relation(planned) => planned.head.to_string(),
            _ => unreachable!(),
        }
    }

    /// The relation's `tail` value, as a string.
    #[getter]
    fn tail(&self) -> String {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Relation(planned) => planned.tail.to_string(),
            _ => unreachable!(),
        }
    }

    /// The relation's optional description (a copy of the plan's value).
    #[getter]
    fn description(&self) -> Option<String> {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Relation(planned) => planned.description.clone(),
            _ => unreachable!(),
        }
    }
}

/// Python view of one JSON-structure task inside a shared task plan.
///
/// `index` must refer to a `PlannedTask::Structure` entry of `task_plan`. Every getter
/// relies on this and reaches `unreachable!()` otherwise.
#[cfg(feature = "python")]
#[pyclass(extends = Task)]
pub struct JSONStructureTask {
    task_plan: Arc<task_plan::TaskPlan>,
    index: usize,
}

#[cfg(feature = "python")]
#[pymethods]
impl JSONStructureTask {
    /// The structure's name (the plan's `structure` field), as a string.
    #[getter]
    fn name(&self) -> String {
        match &self.task_plan.tasks[self.index] {
            PlannedTask::Structure(planned) => planned.structure.to_string(),
            _ => unreachable!(),
        }
    }

    /// One `StructureChild` view per child of this structure, in plan order.
    ///
    /// Each view shares the same task plan and addresses its child by position.
    #[getter]
    fn children(&self) -> Vec<StructureChild> {
        let child_count = match &self.task_plan.tasks[self.index] {
            PlannedTask::Structure(planned) => planned.children.len(),
            _ => unreachable!(),
        };
        (0..child_count)
            .map(|child_index| StructureChild {
                task_plan: Arc::clone(&self.task_plan),
                structure_index: self.index,
                child_index,
            })
            .collect()
    }
}

/// Python view of one child property of a JSON-structure task.
///
/// `structure_index` must refer to a `PlannedTask::Structure` entry of `task_plan`,
/// and `child_index` must be a valid position in that structure's children.
#[cfg(feature = "python")]
#[pyclass]
pub struct StructureChild {
    task_plan: Arc<task_plan::TaskPlan>,
    structure_index: usize,
    child_index: usize,
}

#[cfg(feature = "python")]
#[pymethods]
impl StructureChild {
    /// The child's `property` name, as a string.
    #[getter]
    fn property(&self) -> String {
        match &self.task_plan.tasks[self.structure_index] {
            PlannedTask::Structure(planned) => {
                let child = &planned.children[self.child_index];
                child.property.to_string()
            }
            _ => unreachable!(),
        }
    }

    /// The child's allowed choices, each converted to a string, in plan order.
    #[getter]
    fn choices(&self) -> Vec<String> {
        match &self.task_plan.tasks[self.structure_index] {
            PlannedTask::Structure(planned) => {
                let child = &planned.children[self.child_index];
                child.choices.iter().map(|choice| choice.to_string()).collect()
            }
            _ => unreachable!(),
        }
    }

    /// The child's optional description (a copy of the plan's value).
    #[getter]
    fn description(&self) -> Option<String> {
        match &self.task_plan.tasks[self.structure_index] {
            PlannedTask::Structure(planned) => {
                let child = &planned.children[self.child_index];
                child.description.clone()
            }
            _ => unreachable!(),
        }
    }
}