use std::ffi::{CStr, CString};
use std::os::raw::c_char;
use std::ptr;
use tensorlogic_compiler::CompilerContext;
use crate::executor::{Backend, ExecutionConfig};
use crate::optimize::OptimizationLevel;
use crate::parser::parse_expression;
/// Result of `tl_compile_expr`, laid out with C-compatible field order.
///
/// The struct is heap-allocated by Rust and handed to the caller as a raw pointer.
/// Release it (and the strings it owns) with `tl_free_graph_result`.
/// On success `graph_data` and the counts are filled in and `error_message` is NULL.
/// On failure `error_message` carries the error and the other fields keep their
/// NULL / 0 initial values.
#[repr(C)]
pub struct TLGraphResult {
/// Pretty-printed JSON of the compiled einsum graph, or NULL on failure.
pub graph_data: *mut c_char,
/// Error description when compilation fails; NULL on success.
pub error_message: *mut c_char,
/// Number of tensors in the compiled graph (0 on failure).
pub tensor_count: usize,
/// Number of nodes in the compiled graph (0 on failure).
pub node_count: usize,
}
/// Result of `tl_execute_graph`, laid out with C-compatible field order.
///
/// Heap-allocated by Rust. Release it with `tl_free_execution_result`.
/// On failure only `error_message` is set.
#[repr(C)]
pub struct TLExecutionResult {
/// Pretty-printed JSON serialization of the executor's output, or NULL on failure.
pub output_data: *mut c_char,
/// Error description when any stage (decode, parse, execute, serialize) fails; NULL on success.
pub error_message: *mut c_char,
/// Execution time in microseconds, converted from the executor's millisecond
/// figure with an `as u64` cast (fractional microseconds are truncated).
pub execution_time_us: u64,
}
/// Result of `tl_optimize_graph`, laid out with C-compatible field order.
///
/// Heap-allocated by Rust. Release it with `tl_free_optimization_result`.
#[repr(C)]
pub struct TLOptimizationResult {
/// Pretty-printed JSON of the optimized graph, or NULL on failure.
pub graph_data: *mut c_char,
/// Error description when decoding, parsing, level validation, optimization or
/// serialization fails; NULL on success.
pub error_message: *mut c_char,
/// Tensors in the input graph minus tensors in the optimized graph (saturates at 0).
pub tensors_removed: usize,
/// Nodes in the input graph minus nodes in the optimized graph (saturates at 0).
pub nodes_removed: usize,
}
/// Result of `tl_benchmark_compilation`, laid out with C-compatible field order.
///
/// Heap-allocated by Rust. Release it with `tl_free_benchmark_result`.
/// All timings are derived from per-iteration samples truncated to whole microseconds.
#[repr(C)]
pub struct TLBenchmarkResult {
/// Error description when the benchmark could not run; NULL on success.
pub error_message: *mut c_char,
/// Arithmetic mean of the per-iteration compile times, in microseconds.
pub mean_us: f64,
/// Population standard deviation (divides by the iteration count), in microseconds.
pub std_dev_us: f64,
/// Fastest observed iteration, in microseconds.
pub min_us: u64,
/// Slowest observed iteration, in microseconds.
pub max_us: u64,
/// Number of iterations that were run (left at 0 when the benchmark fails).
pub iterations: usize,
}
/// Converts an owned Rust string into a heap-allocated, NUL-terminated C string.
///
/// Ownership of the returned buffer passes to the caller. It must be released with
/// `tl_free_string`, which reclaims it through `CString::from_raw`.
///
/// C strings cannot contain interior NUL bytes. Previously such input produced NULL.
/// For an error message, NULL makes the failure vanish: the caller sees both output
/// pointers NULL and cannot tell an error occurred. Interior NULs are therefore
/// stripped, so the rest of the text is always delivered and the result is never NULL.
fn to_c_string(s: String) -> *mut c_char {
    let c_string = CString::new(s).unwrap_or_else(|err| {
        // Recover the original bytes and drop every NUL so the conversion must succeed.
        let mut bytes = err.into_vec();
        bytes.retain(|&b| b != 0);
        CString::new(bytes).expect("every NUL byte was removed above")
    });
    c_string.into_raw()
}
/// Copies a NUL-terminated C string into an owned Rust `String`.
///
/// Returns `Err` with a descriptive message when the pointer is NULL or when the
/// bytes are not valid UTF-8.
///
/// # Safety
///
/// `s` must be NULL or point to a valid NUL-terminated byte sequence that stays
/// alive and unmodified for the duration of the call.
unsafe fn from_c_string(s: *const c_char) -> Result<String, String> {
    if s.is_null() {
        Err(String::from("NULL pointer passed"))
    } else {
        // SAFETY: non-NULL was checked above; validity is the caller's contract.
        match CStr::from_ptr(s).to_str() {
            Ok(text) => Ok(text.to_owned()),
            Err(utf8_err) => Err(format!("Invalid UTF-8 string: {}", utf8_err)),
        }
    }
}
#[no_mangle]
pub unsafe extern "C" fn tl_compile_expr(expr: *const c_char) -> *mut TLGraphResult {
let result = Box::new(TLGraphResult {
graph_data: ptr::null_mut(),
error_message: ptr::null_mut(),
tensor_count: 0,
node_count: 0,
});
let expr_str = match from_c_string(expr) {
Ok(s) => s,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Invalid expression: {}", e));
return Box::into_raw(result);
}
};
let tlexpr = match parse_expression(&expr_str) {
Ok(e) => e,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Parse error: {}", e));
return Box::into_raw(result);
}
};
let mut context = CompilerContext::new();
let graph = match tensorlogic_compiler::compile_to_einsum_with_context(&tlexpr, &mut context) {
Ok(g) => g,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Compilation error: {:?}", e));
return Box::into_raw(result);
}
};
let json = match serde_json::to_string_pretty(&graph) {
Ok(j) => j,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Serialization error: {}", e));
return Box::into_raw(result);
}
};
let mut result = result;
result.graph_data = to_c_string(json);
result.tensor_count = graph.tensors.len();
result.node_count = graph.nodes.len();
Box::into_raw(result)
}
#[no_mangle]
pub unsafe extern "C" fn tl_execute_graph(
graph_json: *const c_char,
backend: *const c_char,
) -> *mut TLExecutionResult {
let result = Box::new(TLExecutionResult {
output_data: ptr::null_mut(),
error_message: ptr::null_mut(),
execution_time_us: 0,
});
let json_str = match from_c_string(graph_json) {
Ok(s) => s,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Invalid graph JSON: {}", e));
return Box::into_raw(result);
}
};
let backend_str = match from_c_string(backend) {
Ok(s) => s,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Invalid backend: {}", e));
return Box::into_raw(result);
}
};
let graph: tensorlogic_ir::EinsumGraph = match serde_json::from_str(&json_str) {
Ok(g) => g,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("JSON parse error: {}", e));
return Box::into_raw(result);
}
};
let backend_enum = match Backend::from_str(&backend_str) {
Ok(b) => b,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Unknown backend: {}", e));
return Box::into_raw(result);
}
};
let config = ExecutionConfig {
backend: backend_enum,
device: tensorlogic_scirs_backend::DeviceType::Cpu,
show_metrics: false,
show_intermediates: false,
validate_shapes: true,
trace: false,
};
use crate::executor::CliExecutor;
let executor = match CliExecutor::new(config) {
Ok(e) => e,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Executor creation error: {}", e));
return Box::into_raw(result);
}
};
let exec_result = match executor.execute(&graph) {
Ok(r) => r,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Execution error: {}", e));
return Box::into_raw(result);
}
};
let output_json = match serde_json::to_string_pretty(&exec_result.output) {
Ok(j) => j,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Serialization error: {}", e));
return Box::into_raw(result);
}
};
let mut result = result;
result.output_data = to_c_string(output_json);
result.execution_time_us = (exec_result.execution_time_ms * 1000.0) as u64;
Box::into_raw(result)
}
#[no_mangle]
pub unsafe extern "C" fn tl_optimize_graph(
graph_json: *const c_char,
level: i32,
) -> *mut TLOptimizationResult {
let result = Box::new(TLOptimizationResult {
graph_data: ptr::null_mut(),
error_message: ptr::null_mut(),
tensors_removed: 0,
nodes_removed: 0,
});
let json_str = match from_c_string(graph_json) {
Ok(s) => s,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Invalid graph JSON: {}", e));
return Box::into_raw(result);
}
};
let graph: tensorlogic_ir::EinsumGraph = match serde_json::from_str(&json_str) {
Ok(g) => g,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("JSON parse error: {}", e));
return Box::into_raw(result);
}
};
let opt_level = match level {
0 => OptimizationLevel::None,
1 => OptimizationLevel::Basic,
2 => OptimizationLevel::Standard,
3 => OptimizationLevel::Aggressive,
_ => {
let mut result = result;
result.error_message = to_c_string(format!("Invalid optimization level: {}", level));
return Box::into_raw(result);
}
};
use crate::optimize::OptimizationConfig;
let config = OptimizationConfig {
level: opt_level,
enable_dce: true,
enable_cse: true,
enable_identity: true,
show_stats: false,
verbose: false,
};
let initial_nodes = graph.nodes.len();
let initial_tensors = graph.tensors.len();
let (optimized, _stats) = match crate::optimize::optimize_einsum_graph(graph, &config) {
Ok(r) => r,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Optimization error: {}", e));
return Box::into_raw(result);
}
};
let output_json = match serde_json::to_string_pretty(&optimized) {
Ok(j) => j,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Serialization error: {}", e));
return Box::into_raw(result);
}
};
let mut result = result;
result.graph_data = to_c_string(output_json);
result.tensors_removed = initial_tensors.saturating_sub(optimized.tensors.len());
result.nodes_removed = initial_nodes.saturating_sub(optimized.nodes.len());
Box::into_raw(result)
}
#[no_mangle]
pub unsafe extern "C" fn tl_benchmark_compilation(
expr: *const c_char,
iterations: usize,
) -> *mut TLBenchmarkResult {
let result = Box::new(TLBenchmarkResult {
error_message: ptr::null_mut(),
mean_us: 0.0,
std_dev_us: 0.0,
min_us: 0,
max_us: 0,
iterations: 0,
});
let expr_str = match from_c_string(expr) {
Ok(s) => s,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Invalid expression: {}", e));
return Box::into_raw(result);
}
};
let tlexpr = match parse_expression(&expr_str) {
Ok(e) => e,
Err(e) => {
let mut result = result;
result.error_message = to_c_string(format!("Parse error: {}", e));
return Box::into_raw(result);
}
};
let mut timings = Vec::with_capacity(iterations);
for _ in 0..iterations {
let mut context = CompilerContext::new();
let start = std::time::Instant::now();
if tensorlogic_compiler::compile_to_einsum_with_context(&tlexpr, &mut context).is_ok() {
timings.push(start.elapsed());
} else {
let mut result = result;
result.error_message = to_c_string("Compilation failed during benchmark".to_string());
return Box::into_raw(result);
}
}
let mut sum_us = 0u64;
let mut min_us = u64::MAX;
let mut max_us = 0u64;
for timing in &timings {
let us = timing.as_micros() as u64;
sum_us += us;
min_us = min_us.min(us);
max_us = max_us.max(us);
}
let mean_us = sum_us as f64 / iterations as f64;
let mut variance_sum = 0.0;
for timing in &timings {
let us = timing.as_micros() as f64;
let diff = us - mean_us;
variance_sum += diff * diff;
}
let std_dev_us = (variance_sum / iterations as f64).sqrt();
let mut result = result;
result.mean_us = mean_us;
result.std_dev_us = std_dev_us;
result.min_us = min_us;
result.max_us = max_us;
result.iterations = iterations;
Box::into_raw(result)
}
/// Releases a string previously handed out by this library. NULL is ignored.
///
/// # Safety
///
/// `s` must be NULL or a pointer produced by `CString::into_raw` inside this library,
/// such as `tl_version` or a result field. It must not have been freed already.
#[no_mangle]
pub unsafe extern "C" fn tl_free_string(s: *mut c_char) {
    if s.is_null() {
        return;
    }
    // SAFETY: per the contract above, `s` came from `CString::into_raw` and is
    // reclaimed exactly once here. The owned value is dropped at end of scope.
    let _reclaimed = CString::from_raw(s);
}
/// Frees a `TLGraphResult` and both strings it owns. NULL is ignored.
///
/// # Safety
///
/// `result` must be NULL or a pointer returned by `tl_compile_expr` that has not
/// been freed yet. Its string fields must not have been freed separately.
#[no_mangle]
pub unsafe extern "C" fn tl_free_graph_result(result: *mut TLGraphResult) {
    if result.is_null() {
        return;
    }
    // SAFETY: the pointer came from `Box::into_raw` in `tl_compile_expr`.
    let owned = Box::from_raw(result);
    // `tl_free_string` already ignores NULL fields.
    tl_free_string(owned.graph_data);
    tl_free_string(owned.error_message);
}
/// Frees a `TLExecutionResult` and both strings it owns. NULL is ignored.
///
/// # Safety
///
/// `result` must be NULL or a pointer returned by `tl_execute_graph` that has not
/// been freed yet. Its string fields must not have been freed separately.
#[no_mangle]
pub unsafe extern "C" fn tl_free_execution_result(result: *mut TLExecutionResult) {
    if result.is_null() {
        return;
    }
    // SAFETY: the pointer came from `Box::into_raw` in `tl_execute_graph`.
    let owned = Box::from_raw(result);
    // `tl_free_string` already ignores NULL fields.
    tl_free_string(owned.output_data);
    tl_free_string(owned.error_message);
}
/// Frees a `TLOptimizationResult` and both strings it owns. NULL is ignored.
///
/// # Safety
///
/// `result` must be NULL or a pointer returned by `tl_optimize_graph` that has not
/// been freed yet. Its string fields must not have been freed separately.
#[no_mangle]
pub unsafe extern "C" fn tl_free_optimization_result(result: *mut TLOptimizationResult) {
    if result.is_null() {
        return;
    }
    // SAFETY: the pointer came from `Box::into_raw` in `tl_optimize_graph`.
    let owned = Box::from_raw(result);
    // `tl_free_string` already ignores NULL fields.
    tl_free_string(owned.graph_data);
    tl_free_string(owned.error_message);
}
/// Frees a `TLBenchmarkResult` and its error string. NULL is ignored.
///
/// # Safety
///
/// `result` must be NULL or a pointer returned by `tl_benchmark_compilation` that
/// has not been freed yet. Its `error_message` must not have been freed separately.
#[no_mangle]
pub unsafe extern "C" fn tl_free_benchmark_result(result: *mut TLBenchmarkResult) {
    if result.is_null() {
        return;
    }
    // SAFETY: the pointer came from `Box::into_raw` in `tl_benchmark_compilation`.
    let owned = Box::from_raw(result);
    // `tl_free_string` already ignores a NULL field.
    tl_free_string(owned.error_message);
}
/// Returns the crate version (`CARGO_PKG_VERSION`, fixed at build time) as a C string.
///
/// The caller owns the returned string and must release it with `tl_free_string`.
#[no_mangle]
pub extern "C" fn tl_version() -> *mut c_char {
    let version: &str = env!("CARGO_PKG_VERSION");
    to_c_string(String::from(version))
}
/// Reports whether the named backend is available: `1` if so, otherwise `0`.
///
/// A NULL pointer, non-UTF-8 text, or an unrecognized backend name all yield `0`.
///
/// # Safety
///
/// `backend` must be NULL or point to a valid NUL-terminated string that stays alive
/// for the duration of the call.
#[no_mangle]
pub unsafe extern "C" fn tl_is_backend_available(backend: *const c_char) -> i32 {
    // SAFETY: pointer validity is the caller's contract (see `# Safety`).
    let decoded = from_c_string(backend).ok();
    decoded
        .and_then(|name| Backend::from_str(&name).ok())
        .map_or(0, |parsed| i32::from(parsed.is_available()))
}
// Unit tests that call the C ABI entry points directly from Rust and manually
// release everything the entry points allocate.
#[cfg(test)]
mod tests {
use super::*;
use std::ffi::CString;
// A two-predicate conjunction should compile to a graph with at least one tensor.
// The struct is reclaimed with `Box::from_raw` (freed when `result` drops) and its
// strings are freed individually, mirroring `tl_free_graph_result`.
#[test]
fn test_compile_expr_success() {
let expr = CString::new("AND(pred1(x), pred2(x, y))").unwrap();
unsafe {
let result = tl_compile_expr(expr.as_ptr());
assert!(!result.is_null());
let result = Box::from_raw(result);
// Print diagnostics before asserting so a failure shows the error text.
if !result.error_message.is_null() {
let err = CStr::from_ptr(result.error_message).to_str().unwrap();
println!("Compilation error: {}", err);
}
if !result.graph_data.is_null() {
let graph = CStr::from_ptr(result.graph_data).to_str().unwrap();
println!("Graph: {}", &graph[..graph.len().min(200)]);
println!(
"Tensors: {}, Nodes: {}",
result.tensor_count, result.node_count
);
}
assert!(result.error_message.is_null(), "Compilation should succeed");
assert!(!result.graph_data.is_null());
assert!(result.tensor_count > 0, "Should have at least one tensor");
if !result.graph_data.is_null() {
tl_free_string(result.graph_data);
}
if !result.error_message.is_null() {
tl_free_string(result.error_message);
}
}
}
// Malformed input must still return a non-NULL result that can be cleaned up.
// NOTE(review): this only checks non-NULL + cleanup; it does not assert that
// `error_message` is set. Presumably a parse error is expected — confirm the
// parser rejects this input and tighten the assertion.
#[test]
fn test_compile_expr_invalid_syntax() {
let expr = CString::new("AND(pred1(x), )").unwrap();
unsafe {
let result = tl_compile_expr(expr.as_ptr());
assert!(!result.is_null());
let result = Box::from_raw(result);
if !result.error_message.is_null() {
tl_free_string(result.error_message);
}
if !result.graph_data.is_null() {
tl_free_string(result.graph_data);
}
}
}
// An unterminated quoted string must still return a non-NULL, freeable result.
// NOTE(review): like the test above, no assertion is made about which field is set.
#[test]
fn test_compile_expr_with_error() {
let expr = CString::new("\"unclosed_string").unwrap();
unsafe {
let result = tl_compile_expr(expr.as_ptr());
assert!(!result.is_null());
let result = Box::from_raw(result);
if !result.error_message.is_null() {
tl_free_string(result.error_message);
}
if !result.graph_data.is_null() {
tl_free_string(result.graph_data);
}
}
}
// The version string is non-NULL, valid UTF-8, non-empty and freeable.
#[test]
fn test_version() {
unsafe {
let version = tl_version();
assert!(!version.is_null());
let version_str = CStr::from_ptr(version).to_str().unwrap();
assert!(!version_str.is_empty());
tl_free_string(version);
}
}
// "cpu" is expected to always be available; unknown names report 0.
#[test]
fn test_backend_availability() {
let cpu = CString::new("cpu").unwrap();
unsafe {
assert_eq!(tl_is_backend_available(cpu.as_ptr()), 1);
}
let invalid = CString::new("invalid_backend").unwrap();
unsafe {
assert_eq!(tl_is_backend_available(invalid.as_ptr()), 0);
}
}
}