Skip to main content

_rdx/
lib.rs

1use pyo3::prelude::*;
2use pyo3::types::PyAny;
3use pythonize::pythonize;
4
5/// Parse an RDX document and return the AST as a Python dict.
6#[pyfunction]
7fn parse<'py>(py: Python<'py>, input: &str) -> PyResult<Bound<'py, PyAny>> {
8    let root = rdx_parser::parse(input);
9    let val = serde_json::to_value(&root)
10        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
11    pythonize(py, &val).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
12}
13
14/// Parse with default transforms (auto-slug + table of contents).
15#[pyfunction]
16fn parse_with_defaults<'py>(py: Python<'py>, input: &str) -> PyResult<Bound<'py, PyAny>> {
17    let root = rdx_transform::parse_with_defaults(input);
18    let val = serde_json::to_value(&root)
19        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
20    pythonize(py, &val).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
21}
22
23/// Parse with a specific set of transforms.
24#[pyfunction]
25fn parse_with_transforms<'py>(
26    py: Python<'py>,
27    input: &str,
28    transforms: Vec<String>,
29) -> PyResult<Bound<'py, PyAny>> {
30    let mut pipeline = rdx_transform::Pipeline::new();
31    for name in &transforms {
32        match name.as_str() {
33            "auto-slug" => {
34                pipeline = pipeline.add(rdx_transform::AutoSlug::new());
35            }
36            "toc" => {
37                pipeline = pipeline.add(rdx_transform::TableOfContents::default());
38            }
39            other => {
40                return Err(pyo3::exceptions::PyValueError::new_err(format!(
41                    "unknown transform: \"{other}\""
42                )));
43            }
44        }
45    }
46    let root = pipeline.run(input);
47    let val = serde_json::to_value(&root)
48        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
49    pythonize(py, &val).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
50}
51
52/// Validate an AST dict against a schema dict.
53/// Returns a list of diagnostic dicts.
54#[pyfunction]
55fn validate<'py>(
56    py: Python<'py>,
57    ast: &Bound<'_, PyAny>,
58    schema: &Bound<'_, PyAny>,
59) -> PyResult<Bound<'py, PyAny>> {
60    let ast_val: serde_json::Value = pythonize::depythonize(ast)
61        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
62    let schema_val: serde_json::Value = pythonize::depythonize(schema)
63        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
64
65    let root: rdx_ast::Root = serde_json::from_value(ast_val)
66        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
67    let schema: rdx_schema::Schema = serde_json::from_value(schema_val)
68        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
69
70    let diagnostics = rdx_schema::validate(&root, &schema);
71
72    let results: Vec<serde_json::Value> = diagnostics
73        .into_iter()
74        .map(|d| {
75            serde_json::json!({
76                "severity": match d.severity {
77                    rdx_schema::Severity::Error => "error",
78                    rdx_schema::Severity::Warning => "warning",
79                },
80                "message": d.message,
81                "component": d.component,
82                "line": d.line,
83                "column": d.column,
84            })
85        })
86        .collect();
87
88    let val = serde_json::to_value(&results)
89        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
90    pythonize(py, &val).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
91}
92
93/// Extract plain text from an AST dict.
94#[pyfunction]
95fn collect_text(ast: &Bound<'_, PyAny>) -> PyResult<String> {
96    let ast_val: serde_json::Value = pythonize::depythonize(ast)
97        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
98    let root: rdx_ast::Root = serde_json::from_value(ast_val)
99        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
100    Ok(rdx_transform::collect_text(&root.children))
101}
102
103/// Find all nodes of a given type. Returns a list of node dicts.
104#[pyfunction]
105fn query_all<'py>(
106    py: Python<'py>,
107    ast: &Bound<'_, PyAny>,
108    node_type: &str,
109) -> PyResult<Bound<'py, PyAny>> {
110    let ast_val: serde_json::Value = pythonize::depythonize(ast)
111        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
112    let root: rdx_ast::Root = serde_json::from_value(ast_val)
113        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
114
115    let mut results: Vec<&rdx_ast::Node> = Vec::new();
116    collect_by_type(&root.children, node_type, &mut results);
117
118    let val = serde_json::to_value(&results)
119        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
120    pythonize(py, &val).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
121}
122
123fn collect_by_type<'a>(
124    nodes: &'a [rdx_ast::Node],
125    node_type: &str,
126    results: &mut Vec<&'a rdx_ast::Node>,
127) {
128    for node in nodes {
129        if node_type_matches(node, node_type) {
130            results.push(node);
131        }
132        if let Some(children) = node.children() {
133            collect_by_type(children, node_type, results);
134        }
135    }
136}
137
138#[allow(clippy::match_like_matches_macro)]
139fn node_type_matches(node: &rdx_ast::Node, expected: &str) -> bool {
140    match (node, expected) {
141        (rdx_ast::Node::Text(_), "text") => true,
142        (rdx_ast::Node::CodeInline(_), "code_inline") => true,
143        (rdx_ast::Node::CodeBlock(_), "code_block") => true,
144        (rdx_ast::Node::Paragraph(_), "paragraph") => true,
145        (rdx_ast::Node::Heading(_), "heading") => true,
146        (rdx_ast::Node::List(_), "list") => true,
147        (rdx_ast::Node::ListItem(_), "list_item") => true,
148        (rdx_ast::Node::Blockquote(_), "blockquote") => true,
149        (rdx_ast::Node::ThematicBreak(_), "thematic_break") => true,
150        (rdx_ast::Node::Html(_), "html") => true,
151        (rdx_ast::Node::Table(_), "table") => true,
152        (rdx_ast::Node::TableRow(_), "table_row") => true,
153        (rdx_ast::Node::TableCell(_), "table_cell") => true,
154        (rdx_ast::Node::Link(_), "link") => true,
155        (rdx_ast::Node::Image(_), "image") => true,
156        (rdx_ast::Node::Emphasis(_), "emphasis") => true,
157        (rdx_ast::Node::Strong(_), "strong") => true,
158        (rdx_ast::Node::Strikethrough(_), "strikethrough") => true,
159        (rdx_ast::Node::FootnoteDefinition(_), "footnote_definition") => true,
160        (rdx_ast::Node::FootnoteReference(_), "footnote_reference") => true,
161        (rdx_ast::Node::MathInline(_), "math_inline") => true,
162        (rdx_ast::Node::MathDisplay(_), "math_display") => true,
163        (rdx_ast::Node::Component(_), "component") => true,
164        (rdx_ast::Node::Variable(_), "variable") => true,
165        (rdx_ast::Node::Error(_), "error") => true,
166        _ => false,
167    }
168}
169
170/// Return the RDX parser version.
171#[pyfunction]
172fn version() -> String {
173    env!("CARGO_PKG_VERSION").to_string()
174}
175
176/// RDX Python module.
177#[pymodule]
178fn _rdx(m: &Bound<'_, PyModule>) -> PyResult<()> {
179    m.add_function(wrap_pyfunction!(parse, m)?)?;
180    m.add_function(wrap_pyfunction!(parse_with_defaults, m)?)?;
181    m.add_function(wrap_pyfunction!(parse_with_transforms, m)?)?;
182    m.add_function(wrap_pyfunction!(validate, m)?)?;
183    m.add_function(wrap_pyfunction!(collect_text, m)?)?;
184    m.add_function(wrap_pyfunction!(query_all, m)?)?;
185    m.add_function(wrap_pyfunction!(version, m)?)?;
186    Ok(())
187}