Skip to main content

_rdx/
lib.rs

//! Python bindings for the RDX parser via PyO3.
//!
//! Exposes parsing, schema validation, transforms, and utility functions
//! to Python. All AST data is returned as native Python dicts and lists.

6use pyo3::prelude::*;
7use pyo3::types::PyAny;
8use pythonize::pythonize;
9
10/// Parse an RDX document and return the AST as a Python dict.
11#[pyfunction]
12fn parse<'py>(py: Python<'py>, input: &str) -> PyResult<Bound<'py, PyAny>> {
13    let root = rdx_parser::parse(input);
14    let val = serde_json::to_value(&root)
15        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
16    pythonize(py, &val).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
17}
18
19/// Parse with default transforms (auto-slug + table of contents).
20#[pyfunction]
21fn parse_with_defaults<'py>(py: Python<'py>, input: &str) -> PyResult<Bound<'py, PyAny>> {
22    let root = rdx_transform::parse_with_defaults(input);
23    let val = serde_json::to_value(&root)
24        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
25    pythonize(py, &val).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
26}
27
28/// Parse with a specific set of transforms.
29///
30/// Supported transform names: ``"auto-slug"``, ``"toc"``, ``"github"``.
31/// Pass ``repo="owner/repo"`` for the github transform.
32#[pyfunction]
33#[pyo3(signature = (input, transforms, repo=None))]
34fn parse_with_transforms<'py>(
35    py: Python<'py>,
36    input: &str,
37    transforms: Vec<String>,
38    repo: Option<String>,
39) -> PyResult<Bound<'py, PyAny>> {
40    let mut pipeline = rdx_transform::Pipeline::new();
41    for name in &transforms {
42        match name.as_str() {
43            "auto-slug" => {
44                pipeline = pipeline.add(rdx_transform::AutoSlug::new());
45            }
46            "toc" => {
47                pipeline = pipeline.add(rdx_transform::TableOfContents::default());
48            }
49            "github" => {
50                let gh = if let Some(ref r) = repo {
51                    rdx_github::GithubReferences::new(r)
52                } else {
53                    rdx_github::GithubReferences::default()
54                };
55                pipeline = pipeline.add(gh);
56            }
57            other => {
58                return Err(pyo3::exceptions::PyValueError::new_err(format!(
59                    "unknown transform: \"{other}\""
60                )));
61            }
62        }
63    }
64    let root = pipeline.run(input);
65    let val = serde_json::to_value(&root)
66        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
67    pythonize(py, &val).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
68}
69
70/// Validate an AST dict against a schema dict.
71/// Returns a list of diagnostic dicts.
72#[pyfunction]
73fn validate<'py>(
74    py: Python<'py>,
75    ast: &Bound<'_, PyAny>,
76    schema: &Bound<'_, PyAny>,
77) -> PyResult<Bound<'py, PyAny>> {
78    let ast_val: serde_json::Value = pythonize::depythonize(ast)
79        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
80    let schema_val: serde_json::Value = pythonize::depythonize(schema)
81        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
82
83    let root: rdx_ast::Root = serde_json::from_value(ast_val)
84        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
85    let schema: rdx_schema::Schema = serde_json::from_value(schema_val)
86        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
87
88    let diagnostics = rdx_schema::validate(&root, &schema);
89
90    let results: Vec<serde_json::Value> = diagnostics
91        .into_iter()
92        .map(|d| {
93            serde_json::json!({
94                "severity": match d.severity {
95                    rdx_schema::Severity::Error => "error",
96                    rdx_schema::Severity::Warning => "warning",
97                },
98                "message": d.message,
99                "component": d.component,
100                "line": d.line,
101                "column": d.column,
102            })
103        })
104        .collect();
105
106    let val = serde_json::to_value(&results)
107        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
108    pythonize(py, &val).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
109}
110
111/// Extract plain text from an AST dict.
112#[pyfunction]
113fn collect_text(ast: &Bound<'_, PyAny>) -> PyResult<String> {
114    let ast_val: serde_json::Value = pythonize::depythonize(ast)
115        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
116    let root: rdx_ast::Root = serde_json::from_value(ast_val)
117        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
118    Ok(rdx_transform::collect_text(&root.children))
119}
120
121/// Find all nodes of a given type. Returns a list of node dicts.
122#[pyfunction]
123fn query_all<'py>(
124    py: Python<'py>,
125    ast: &Bound<'_, PyAny>,
126    node_type: &str,
127) -> PyResult<Bound<'py, PyAny>> {
128    let ast_val: serde_json::Value = pythonize::depythonize(ast)
129        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
130    let root: rdx_ast::Root = serde_json::from_value(ast_val)
131        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
132
133    let mut results: Vec<&rdx_ast::Node> = Vec::new();
134    collect_by_type(&root.children, node_type, &mut results);
135
136    let val = serde_json::to_value(&results)
137        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
138    pythonize(py, &val).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
139}
140
141fn collect_by_type<'a>(
142    nodes: &'a [rdx_ast::Node],
143    node_type: &str,
144    results: &mut Vec<&'a rdx_ast::Node>,
145) {
146    for node in nodes {
147        if node_type_matches(node, node_type) {
148            results.push(node);
149        }
150        if let Some(children) = node.children() {
151            collect_by_type(children, node_type, results);
152        }
153    }
154}
155
156#[allow(clippy::match_like_matches_macro)]
157fn node_type_matches(node: &rdx_ast::Node, expected: &str) -> bool {
158    match (node, expected) {
159        (rdx_ast::Node::Text(_), "text") => true,
160        (rdx_ast::Node::CodeInline(_), "code_inline") => true,
161        (rdx_ast::Node::CodeBlock(_), "code_block") => true,
162        (rdx_ast::Node::Paragraph(_), "paragraph") => true,
163        (rdx_ast::Node::Heading(_), "heading") => true,
164        (rdx_ast::Node::List(_), "list") => true,
165        (rdx_ast::Node::ListItem(_), "list_item") => true,
166        (rdx_ast::Node::Blockquote(_), "blockquote") => true,
167        (rdx_ast::Node::ThematicBreak(_), "thematic_break") => true,
168        (rdx_ast::Node::Html(_), "html") => true,
169        (rdx_ast::Node::Table(_), "table") => true,
170        (rdx_ast::Node::TableRow(_), "table_row") => true,
171        (rdx_ast::Node::TableCell(_), "table_cell") => true,
172        (rdx_ast::Node::Link(_), "link") => true,
173        (rdx_ast::Node::Image(_), "image") => true,
174        (rdx_ast::Node::Emphasis(_), "emphasis") => true,
175        (rdx_ast::Node::Strong(_), "strong") => true,
176        (rdx_ast::Node::Strikethrough(_), "strikethrough") => true,
177        (rdx_ast::Node::FootnoteDefinition(_), "footnote_definition") => true,
178        (rdx_ast::Node::FootnoteReference(_), "footnote_reference") => true,
179        (rdx_ast::Node::MathInline(_), "math_inline") => true,
180        (rdx_ast::Node::MathDisplay(_), "math_display") => true,
181        (rdx_ast::Node::Component(_), "component") => true,
182        (rdx_ast::Node::Variable(_), "variable") => true,
183        (rdx_ast::Node::Error(_), "error") => true,
184        _ => false,
185    }
186}
187
188/// Return the RDX parser version.
189#[pyfunction]
190fn version() -> String {
191    env!("CARGO_PKG_VERSION").to_string()
192}
193
194/// RDX Python module.
195#[pymodule]
196fn _rdx(m: &Bound<'_, PyModule>) -> PyResult<()> {
197    m.add_function(wrap_pyfunction!(parse, m)?)?;
198    m.add_function(wrap_pyfunction!(parse_with_defaults, m)?)?;
199    m.add_function(wrap_pyfunction!(parse_with_transforms, m)?)?;
200    m.add_function(wrap_pyfunction!(validate, m)?)?;
201    m.add_function(wrap_pyfunction!(collect_text, m)?)?;
202    m.add_function(wrap_pyfunction!(query_all, m)?)?;
203    m.add_function(wrap_pyfunction!(version, m)?)?;
204    Ok(())
205}