use std::collections::HashMap;
use std::fs;
use std::path::Path;
/// Prints Python stubs for every `#[pyfunction]` found under this crate's
/// `src/` directory to stdout, followed by a summary line on stderr.
///
/// `import numpy` is emitted only when some stub references `numpy.`.
/// Doc text is escaped so it always forms a valid triple-quoted Python
/// literal. Whitespace-only docs are treated as absent.
fn main() {
    let src_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("src");
    let functions = collect_pyfunctions(&src_dir);
    // Annotations such as `numpy.ndarray` need the module name in scope for
    // type checkers, even with postponed evaluation of annotations.
    let needs_numpy = functions
        .iter()
        .any(|f| f.params.contains("numpy.") || f.return_type.contains("numpy."));

    println!("# Auto-generated stubs for scirs2-python PyO3 bindings");
    println!("# Generated by: cargo run --example generate_stubs -p scirs2-python");
    println!("# Do NOT edit manually — regenerate with the command above.");
    println!("#");
    println!("# Note: parameter names use positional placeholders (a0, a1, ...).");
    println!("# A proc-macro approach can yield exact parameter names in future.");
    println!();
    println!("from __future__ import annotations");
    println!("from typing import Any");
    if needs_numpy {
        println!();
        println!("import numpy");
    }
    println!();
    for func in &functions {
        let signature = format!(
            "def {}({}) -> {}:",
            func.name, func.params, func.return_type
        );
        match func.doc.as_deref().filter(|doc| !doc.trim().is_empty()) {
            Some(doc) => {
                println!("{signature}");
                println!("    \"\"\"{}\"\"\"", escape_docstring(doc));
                println!("    ...");
            }
            None => println!("{signature} ..."),
        }
        println!();
    }
    let rs_count = count_rs_files(&src_dir);
    eprintln!(
        "Generated {} function stubs from {} Rust source files",
        functions.len(),
        rs_count
    );
}

/// Escapes `doc` for embedding between `"""` delimiters in Python source.
///
/// Backslashes are doubled and every `"` is backslash-escaped. This means
/// Rust doc text containing `"""`, ending in a quote, or containing sequences
/// such as `\u` / `\N` can neither close the literal early nor form an
/// invalid escape.
fn escape_docstring(doc: &str) -> String {
    doc.replace('\\', "\\\\").replace('"', "\\\"")
}
/// One `#[pyfunction]` discovered in the crate sources, already rendered
/// into the pieces of a Python `def` stub.
#[derive(Debug, Clone, PartialEq, Eq)]
struct PyFunctionStub {
    /// Rust function name, emitted verbatim as the Python name.
    name: String,
    /// Python parameter list without the surrounding parentheses.
    params: String,
    /// Python return annotation.
    return_type: String,
    /// Text of the function's `///` doc comment, if any.
    doc: Option<String>,
}
/// Scans every `.rs` file under `src_dir` and returns one stub per distinct
/// `#[pyfunction]` name, ordered alphabetically by that name.
fn collect_pyfunctions(src_dir: &Path) -> Vec<PyFunctionStub> {
    let mut by_name = HashMap::new();
    collect_from_dir(src_dir, &mut by_name);
    // Map keys are the stub names, so names are unique and an unstable sort
    // yields the same order as a stable one.
    let mut sorted: Vec<PyFunctionStub> = by_name.into_values().collect();
    sorted.sort_unstable_by(|lhs, rhs| lhs.name.cmp(&rhs.name));
    sorted
}
/// Recursively walks `dir`, passing every `.rs` file to `collect_from_file`.
///
/// Entries are visited in sorted path order. When two files define a
/// `#[pyfunction]` with the same name, the surviving stub (the one inserted
/// last) is therefore the same on every platform and every run. Unreadable
/// directories and entries are skipped (best effort).
fn collect_from_dir(dir: &Path, map: &mut HashMap<String, PyFunctionStub>) {
    let entries = match fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };
    // `read_dir` order is unspecified and platform-dependent; sort for
    // reproducible output.
    let mut paths: Vec<_> = entries.flatten().map(|entry| entry.path()).collect();
    paths.sort();
    for path in paths {
        if path.is_dir() {
            collect_from_dir(&path, map);
        } else if path.extension().map(|ext| ext == "rs").unwrap_or(false) {
            collect_from_file(&path, map);
        }
    }
}
/// Scans one Rust source file and records every `#[pyfunction]` it declares
/// into `map`, keyed by function name. A later definition with the same name
/// replaces an earlier one.
///
/// Unreadable files are skipped silently, matching the best-effort directory
/// walk. After each `#[pyfunction]` / `#[pyfunction(...)]` attribute the
/// scanner skips:
/// - further attributes, consuming ones wrapped over several lines (such as
///   `#[pyo3(signature = (...))]`) whole;
/// - doc comments, `//` comments, and blank lines.
///
/// It then parses the `fn` signature that follows, joining it across lines
/// when it has been wrapped. `///` lines before the attributes and `///`
/// lines between the attributes and the `fn` both become the docstring.
fn collect_from_file(path: &Path, map: &mut HashMap<String, PyFunctionStub>) {
    let content = match fs::read_to_string(path) {
        Ok(content) => content,
        Err(_) => return,
    };
    let lines: Vec<&str> = content.lines().collect();
    for (attr_line, line) in lines.iter().enumerate() {
        let trimmed = line.trim();
        if trimmed != "#[pyfunction]" && !trimmed.starts_with("#[pyfunction(") {
            continue;
        }
        let (fn_line, inner_docs) = find_fn_after_attributes(&lines, attr_line);
        let Some(fn_line) = fn_line else {
            continue;
        };
        let doc = merge_doc_parts(preceding_doc_comment(&lines, attr_line), inner_docs);
        if let Some(stub) = parse_fn_signature(&join_fn_signature(&lines, fn_line), doc) {
            map.insert(stub.name.clone(), stub);
        }
    }
}

/// Starting at the attribute on line `attr_line`, skips attribute lines,
/// `//` comments, and blank lines.
///
/// `[`/`]` balance is tracked so an attribute wrapped over several lines is
/// consumed whole. Returns the index of the first remaining line (if any)
/// together with the text of the `///` lines that were skipped.
fn find_fn_after_attributes(lines: &[&str], attr_line: usize) -> (Option<usize>, Vec<String>) {
    let mut docs = Vec::new();
    let mut open_brackets: i64 = 0;
    for (idx, raw) in lines.iter().enumerate().skip(attr_line) {
        let text = raw.trim();
        if open_brackets > 0 || text.starts_with('#') {
            open_brackets = (open_brackets + bracket_balance(text)).max(0);
        } else if text.starts_with("///") {
            docs.push(text.trim_start_matches("///").trim().to_string());
        } else if !text.is_empty() && !text.starts_with("//") {
            return (Some(idx), docs);
        }
    }
    (None, docs)
}

/// Net number of `[` minus `]` characters on `line`.
fn bracket_balance(line: &str) -> i64 {
    line.chars().fold(0, |acc, ch| match ch {
        '[' => acc + 1,
        ']' => acc - 1,
        _ => acc,
    })
}

/// Appends the `///` lines found after the attributes to the doc text found
/// before them. Parts are separated by single spaces.
fn merge_doc_parts(before: Option<String>, after: Vec<String>) -> Option<String> {
    if after.is_empty() {
        return before;
    }
    let mut parts: Vec<String> = before.into_iter().collect();
    parts.extend(after);
    Some(parts.join(" "))
}

/// Returns the `fn` signature starting at line `start` as a single line.
///
/// Lines are appended until one contains the body's `{` or ends with `;`,
/// up to `MAX_LINES` lines. `//` comments are removed so they cannot leak
/// commas or parentheses into parsing. The trailing comma that rustfmt
/// leaves before a wrapped `)` is dropped.
fn join_fn_signature(lines: &[&str], start: usize) -> String {
    const MAX_LINES: usize = 32;
    let mut parts: Vec<&str> = Vec::new();
    for raw in lines.iter().skip(start).take(MAX_LINES) {
        let code = raw.split("//").next().unwrap_or("").trim();
        parts.push(code);
        if code.contains('{') || code.ends_with(';') {
            break;
        }
    }
    parts.join(" ").replace(", )", ")")
}
/// Collects the `///` doc comment that sits above the attribute on line
/// `attr_line`.
///
/// Walking upward, attribute lines (starting with `#`) and blank lines are
/// stepped over, and `///` lines are gathered. The walk stops at the first
/// other line. Gathered lines are stripped of their `///` marker and
/// surrounding whitespace, put back into source order, and joined with
/// single spaces. Returns `None` when no doc line was found.
fn preceding_doc_comment(lines: &[&str], attr_line: usize) -> Option<String> {
    let mut docs: Vec<&str> = lines[..attr_line]
        .iter()
        .rev()
        .map(|raw| raw.trim())
        .take_while(|text| text.starts_with("///") || text.starts_with('#') || text.is_empty())
        .filter(|text| text.starts_with("///"))
        .map(|text| text.trim_start_matches("///").trim())
        .collect();
    if docs.is_empty() {
        return None;
    }
    docs.reverse();
    Some(docs.join(" "))
}
/// Parses a single-line Rust `fn` signature into a stub carrying `doc`.
///
/// Leading visibility (`pub`, `pub(crate)`, `pub(super)`, `pub(in path)`)
/// and `const` / `async` / `unsafe` qualifiers are skipped in any order.
/// Returns `None` when the remaining text does not start with `fn ` or no
/// name precedes the first `(` or `<`.
fn parse_fn_signature(line: &str, doc: Option<String>) -> Option<PyFunctionStub> {
    let line = strip_fn_qualifiers(line);
    let after_fn = line.strip_prefix("fn ")?;
    let name_end = after_fn.find(['(', '<'])?;
    let name = after_fn[..name_end].trim();
    if name.is_empty() {
        return None;
    }
    Some(PyFunctionStub {
        name: name.to_string(),
        params: infer_params(line),
        return_type: infer_return_type(line),
        doc,
    })
}

/// Removes leading visibility and `const` / `async` / `unsafe` qualifiers,
/// so the returned slice starts at the `fn` keyword of a function signature.
fn strip_fn_qualifiers(line: &str) -> &str {
    const QUALIFIERS: [&str; 4] = ["pub ", "const ", "async ", "unsafe "];
    let mut rest = line.trim();
    loop {
        let next = if let Some(after) = rest.strip_prefix("pub(") {
            // Restricted visibility: skip up to the closing parenthesis.
            match after.find(')') {
                Some(close) => &after[close + 1..],
                None => return rest,
            }
        } else if let Some(after) = QUALIFIERS.iter().find_map(|q| rest.strip_prefix(*q)) {
            after
        } else {
            return rest;
        };
        rest = next.trim_start();
    }
}
/// Infers the Python return annotation from a Rust `fn` signature.
///
/// Only the signature head is examined: the text before the body's opening
/// `{`, with any trailing `;` removed. A one-line body such as
/// `fn answer() -> i32 { 42 }` therefore does not leak into the type. A
/// signature without `->` returns `()` in Rust, which maps to `None`.
fn infer_return_type(line: &str) -> String {
    let head = line.find('{').map_or(line, |brace| &line[..brace]);
    match head.rfind("->") {
        Some(arrow) => {
            let rust_type = head[arrow + 2..].trim().trim_end_matches(';').trim_end();
            map_rust_to_python_type(rust_type)
        }
        None => "None".to_string(),
    }
}
/// Maps a Rust type, as written in a PyO3 signature, to a Python annotation.
///
/// Conversions, in order:
/// - Transparent wrappers (`PyResult<T>`, `Py<T>`, `Bound<'py, T>`) map to
///   their target `T`.
/// - Tuples become `tuple[...]`; `()` is `None`.
/// - `Vec<T>` becomes `list[T]`, and `HashMap<K, V>` becomes `dict[K, V]`,
///   with element types mapped recursively.
/// - Scalars and common PyO3 object types get their Python equivalents.
/// - numpy array handles become `numpy.ndarray`.
/// - `Option<T>` stays the deliberately loose `Any | None`.
///
/// Anything unrecognized becomes `Any`.
fn map_rust_to_python_type(rust_type: &str) -> String {
    let rust_type = rust_type.trim();

    if let Some(inner) =
        strip_wrapper(rust_type, "PyResult<").or_else(|| strip_wrapper(rust_type, "Py<"))
    {
        return map_rust_to_python_type(inner);
    }
    if let Some(inner) = strip_wrapper(rust_type, "Bound<") {
        // `Bound<'py, T>`: skip the lifetime argument and map `T`.
        let target = match inner.split_once(',') {
            Some((lifetime, ty)) if lifetime.trim().starts_with('\'') => ty,
            _ => inner,
        };
        return map_rust_to_python_type(target);
    }

    if let Some(inner) = rust_type.strip_prefix('(').and_then(|t| t.strip_suffix(')')) {
        let elements = split_type_list(inner);
        if elements.is_empty() {
            return "None".to_string();
        }
        // `(T)` without a trailing comma is just a parenthesized `T`.
        if elements.len() == 1 && !inner.trim_end().ends_with(',') {
            return map_rust_to_python_type(elements[0]);
        }
        let mapped: Vec<String> = elements.into_iter().map(map_rust_to_python_type).collect();
        return format!("tuple[{}]", mapped.join(", "));
    }

    if let Some(inner) = strip_wrapper(rust_type, "Vec<") {
        return format!("list[{}]", map_rust_to_python_type(inner));
    }
    if let Some(inner) = strip_wrapper(rust_type, "HashMap<") {
        return match split_type_list(inner).as_slice() {
            [key, value] => format!(
                "dict[{}, {}]",
                map_rust_to_python_type(key),
                map_rust_to_python_type(value)
            ),
            _ => "dict[Any, Any]".to_string(),
        };
    }

    const NUMPY_PREFIXES: [&str; 3] = ["PyArray", "PyReadonlyArray", "PyReadwriteArray"];
    let python = match rust_type {
        "f64" | "f32" => "float",
        "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64"
        | "u128" | "usize" => "int",
        "bool" => "bool",
        "String" | "&str" | "str" => "str",
        "PyObject" | "PyAny" | "JsValue" => "Any",
        "PyDict" => "dict[Any, Any]",
        "PyList" => "list[Any]",
        "PyTuple" => "tuple[Any, ...]",
        _ if rust_type.starts_with("Option<") => "Any | None",
        _ if NUMPY_PREFIXES.iter().any(|p| rust_type.starts_with(p)) => "numpy.ndarray",
        _ => "Any",
    };
    python.to_string()
}

/// Splits a comma-separated list of Rust types on commas that are not
/// nested inside `<>`, `()` or `[]`. Items are trimmed, and empty items
/// (e.g. from a trailing comma) are dropped.
fn split_type_list(list: &str) -> Vec<&str> {
    let mut items = Vec::new();
    let mut depth = 0usize;
    let mut start = 0usize;
    let mut prev = ' ';
    for (idx, ch) in list.char_indices() {
        match ch {
            '<' | '(' | '[' => depth += 1,
            // The `>` of an `->` arrow is not a closing bracket.
            '>' if prev != '-' => depth = depth.saturating_sub(1),
            ')' | ']' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                items.push(list[start..idx].trim());
                start = idx + 1;
            }
            _ => {}
        }
        prev = ch;
    }
    items.push(list[start..].trim());
    items.retain(|item| !item.is_empty());
    items
}
/// Returns the text between `prefix` and a final `>` when `s` has the shape
/// `{prefix}…>`, e.g. `strip_wrapper("PyResult<f64>", "PyResult<")` is
/// `Some("f64")`. Returns `None` otherwise.
///
/// Uses `strip_prefix`/`strip_suffix` so the prefix and the closing `>`
/// cannot overlap. The previous manual slicing panicked when `prefix` itself
/// ended in `>` and `s == prefix`.
fn strip_wrapper<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
    s.strip_prefix(prefix)?.strip_suffix('>')
}
/// Builds the Python parameter list for the Rust signature on `line`.
///
/// Parameters are emitted as positional placeholders `a0: Any, a1: Any, …`.
/// An argument of type `Python<'_>` (the GIL token, which PyO3 supplies
/// itself) is not counted.
///
/// Returns `"*args: Any"` when the argument list is not closed on `line`.
/// Returns an empty string when there are no Python-visible arguments.
fn infer_params(line: &str) -> String {
    let Some(open) = line.find('(') else {
        return "*args: Any".to_string();
    };
    // Find the `)` matching the argument list's `(`. A plain `rfind(')')`
    // would land in a tuple return type such as `-> (f64, f64)`.
    let mut depth = 0usize;
    let mut close = None;
    for (idx, ch) in line[open..].char_indices() {
        match ch {
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                if depth == 0 {
                    close = Some(open + idx);
                    break;
                }
            }
            _ => {}
        }
    }
    let Some(close) = close else {
        return "*args: Any".to_string();
    };
    let visible = split_top_level_args(&line[open + 1..close])
        .into_iter()
        .filter(|arg| !is_python_token_arg(arg))
        .count();
    (0..visible)
        .map(|i| format!("a{i}: Any"))
        .collect::<Vec<_>>()
        .join(", ")
}

/// Splits a Rust argument list on commas not nested inside `<>`, `()` or
/// `[]`. Segments are trimmed, and empty ones (e.g. from a trailing comma)
/// are dropped.
fn split_top_level_args(args: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut depth = 0usize;
    let mut start = 0usize;
    let mut prev = ' ';
    for (idx, ch) in args.char_indices() {
        match ch {
            '<' | '(' | '[' => depth += 1,
            // The `>` of an `->` arrow (e.g. in `Fn(T) -> U`) closes nothing.
            '>' if prev != '-' => depth = depth.saturating_sub(1),
            ')' | ']' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                parts.push(args[start..idx].trim());
                start = idx + 1;
            }
            _ => {}
        }
        prev = ch;
    }
    parts.push(args[start..].trim());
    parts.retain(|part| !part.is_empty());
    parts
}

/// True when `arg` (a `name: Type` argument) has type `Python` or
/// `Python<'…>`, optionally path-qualified such as `pyo3::Python<'py>`.
fn is_python_token_arg(arg: &str) -> bool {
    match arg.split_once(':') {
        Some((_, ty)) => {
            let base = ty.split('<').next().unwrap_or(ty).trim();
            base.rsplit("::").next().unwrap_or(base) == "Python"
        }
        None => false,
    }
}
/// Counts the `.rs` files under `dir`, descending into subdirectories.
/// A directory that cannot be read, and any entry that cannot be read,
/// contributes 0.
fn count_rs_files(dir: &Path) -> usize {
    let Ok(listing) = fs::read_dir(dir) else {
        return 0;
    };
    let mut total = 0;
    for entry in listing.flatten() {
        let path = entry.path();
        total += if path.is_dir() {
            count_rs_files(&path)
        } else {
            match path.extension() {
                Some(ext) if ext == "rs" => 1,
                _ => 0,
            }
        };
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// `<crate>/examples`, where this generator lives.
    fn examples_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("examples")
    }

    /// `<crate>/src`, the tree the generator scans.
    fn src_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("src")
    }

    #[test]
    fn generate_stubs_example_source_exists() {
        let source = examples_dir().join("generate_stubs.rs");
        assert!(
            source.exists(),
            "generate_stubs.rs should exist at {}",
            source.display()
        );
    }

    #[test]
    fn src_dir_has_rust_files() {
        assert!(
            count_rs_files(&src_dir()) > 0,
            "src/ should contain at least one .rs file"
        );
    }

    #[test]
    fn parse_fn_signature_extracts_name() {
        let parsed = parse_fn_signature(
            "pub fn gpu_matmul(a: Vec<f64>, b: Vec<f64>) -> PyResult<Vec<f64>> {",
            None,
        )
        .expect("should parse successfully");
        assert_eq!(parsed.name, "gpu_matmul");
    }

    #[test]
    fn parse_fn_signature_extracts_float_return() {
        let parsed = parse_fn_signature(
            "pub fn det_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {",
            None,
        )
        .expect("should parse successfully");
        assert_eq!(parsed.return_type, "float");
    }

    #[test]
    fn parse_fn_signature_extracts_vec_return() {
        let parsed = parse_fn_signature(
            "pub fn gpu_elementwise(data: Vec<f64>, op: &str) -> PyResult<Vec<f64>> {",
            None,
        )
        .expect("should parse");
        assert_eq!(parsed.return_type, "list[float]");
    }

    #[test]
    fn parse_fn_signature_string_return() {
        let parsed =
            parse_fn_signature("pub fn gpu_device_info() -> String {", None).expect("should parse");
        assert_eq!(parsed.return_type, "str");
    }

    #[test]
    fn parse_fn_signature_void_return() {
        let parsed = parse_fn_signature("pub fn init() {", None).expect("should parse");
        assert_eq!(parsed.return_type, "None");
    }

    #[test]
    fn collect_pyfunctions_finds_gpu_ops() {
        let stubs = collect_pyfunctions(&src_dir());
        let names: Vec<&str> = stubs.iter().map(|stub| stub.name.as_str()).collect();
        let found = ["gpu_matmul", "gpu_device_info"]
            .iter()
            .any(|wanted| names.contains(wanted));
        assert!(
            found,
            "Expected to find at least gpu_matmul or gpu_device_info in {:?}",
            names
        );
    }

    #[test]
    fn infer_return_type_option() {
        assert_eq!(map_rust_to_python_type("Option<f64>"), "Any | None");
    }

    #[test]
    fn infer_return_type_ndarray() {
        assert_eq!(map_rust_to_python_type("PyArray2<f64>"), "numpy.ndarray");
    }
}