#![allow(unused_variables)]
use std::collections::HashMap;
use std::ffi::c_void;
use std::fmt;
use std::ptr;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, SystemTime};
use crate::eval::{Value, Environment};
use crate::ast::Literal;
use crate::diagnostics::Error;
use crate::ffi::c_types::{CType, TypeMarshaller, ConversionError};
#[derive(Debug, Clone)]
pub enum CallbackError {
NotFound(String),
InvalidSignature {
callback: String,
reason: String,
},
ExecutionFailed {
callback: String,
error: String,
},
ConversionError(ConversionError),
AlreadyRegistered(String),
Expired(String),
StackOverflow,
}
impl fmt::Display for CallbackError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CallbackError::NotFound(name) => {
write!(f, "Callback not found: {name}")
}
CallbackError::InvalidSignature { callback, reason } => {
write!(f, "Invalid callback signature for '{callback}': {reason}")
}
CallbackError::ExecutionFailed { callback, error } => {
write!(f, "Callback '{callback}' execution failed: {error}")
}
CallbackError::ConversionError(e) => {
write!(f, "Callback type conversion error: {e}")
}
CallbackError::AlreadyRegistered(name) => {
write!(f, "Callback '{name}' is already registered")
}
CallbackError::Expired(name) => {
write!(f, "Callback '{name}' has expired")
}
CallbackError::StackOverflow => {
write!(f, "Callback stack overflow protection triggered")
}
}
}
}
impl std::error::Error for CallbackError {}
impl From<ConversionError> for CallbackError {
fn from(e: ConversionError) -> Self {
CallbackError::ConversionError(e)
}
}
impl From<CallbackError> for Error {
fn from(callback_error: CallbackError) -> Self {
Error::runtime_error(callback_error.to_string(), None)
}
}
/// Declared C-level shape of a callback: its name, parameter and return types,
/// and calling convention.
#[derive(Debug, Clone)]
pub struct CallbackSignature {
/// Registry key; the registry refuses a second callback with the same name.
pub name: String,
/// C types of the positional arguments, in order.
pub parameters: Vec<CType>,
/// Declared C return type.
/// NOTE(review): not consulted anywhere in this file; execution currently returns a null pointer.
pub return_type: CType,
/// Whether extra trailing arguments are allowed.
/// NOTE(review): not consulted anywhere in this file.
pub variadic: bool,
/// Calling convention the C side uses.
/// NOTE(review): not consulted anywhere in this file.
pub calling_convention: CallingConvention,
}
/// Calling convention a callback is declared with; defaults to [`CallingConvention::C`].
#[derive(Debug, Clone, PartialEq, Default)]
pub enum CallingConvention {
    /// The platform C convention (the default).
    #[default]
    C,
    /// The `stdcall` convention.
    Stdcall,
    /// The `fastcall` convention.
    Fastcall,
    /// The System V convention.
    SystemV,
}
/// A script value registered for invocation from C, together with its bookkeeping.
#[derive(Debug)]
pub struct CallbackFunction {
    /// Declared C-level shape of the callback.
    pub signature: CallbackSignature,
    /// Script value to run when the callback fires.
    pub function: Value,
    /// Environment the callback executes in.
    pub environment: Arc<Mutex<Environment>>,
    /// Wall-clock time the callback was created.
    pub registered_at: SystemTime,
    /// Deadline after which [`is_expired`](Self::is_expired) reports `true`; `None` never expires.
    pub expires_at: Option<SystemTime>,
    /// Counter bumped by [`increment_call_count`](Self::increment_call_count).
    pub call_count: Arc<Mutex<u64>>,
    /// Entry point handed back to C for this callback.
    pub c_function_ptr: *const c_void,
    /// Liveness flag; `false` once the callback has been deactivated.
    pub active: Arc<Mutex<bool>>,
}

impl CallbackFunction {
    /// Returns `true` only if an expiry is set and the current time is strictly past it.
    pub fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => SystemTime::now() > deadline,
            None => false,
        }
    }

    /// Reads the liveness flag.
    ///
    /// # Panics
    /// Panics if the flag's mutex is poisoned.
    pub fn is_active(&self) -> bool {
        let flag = self.active.lock().unwrap();
        *flag
    }

    /// Reads the invocation counter.
    ///
    /// # Panics
    /// Panics if the counter's mutex is poisoned.
    pub fn call_count(&self) -> u64 {
        let counter = self.call_count.lock().unwrap();
        *counter
    }

    /// Adds one to the invocation counter.
    ///
    /// # Panics
    /// Panics if the counter's mutex is poisoned.
    pub fn increment_call_count(&self) {
        *self.call_count.lock().unwrap() += 1;
    }
}

// SAFETY: NOTE(review): these impls exist because the raw `c_function_ptr` field makes
// the type !Send/!Sync by default. They are only sound if `Value` (field `function`) and
// `Environment` can be moved/shared across threads; that cannot be verified from this
// file. Confirm before relying on cross-thread use.
unsafe impl Send for CallbackFunction {}
unsafe impl Sync for CallbackFunction {}
/// Table of named callbacks plus invocation bookkeeping; shared state is guarded by
/// `RwLock`/`Mutex`.
#[derive(Debug)]
pub struct CallbackRegistry {
/// Registered callbacks keyed by `CallbackSignature::name`.
callbacks: RwLock<HashMap<String, Arc<CallbackFunction>>>,
/// Type marshaller. NOTE(review): locked during execution but not otherwise used in this file.
marshaller: Arc<Mutex<TypeMarshaller>>,
/// Number of `execute_callback` calls currently in progress. A single counter is shared
/// by every thread using this registry.
stack_depth: Arc<Mutex<usize>>,
/// Limit on `stack_depth`; calls at or beyond it fail with `CallbackError::StackOverflow`.
max_stack_depth: usize,
/// Running counters returned (cloned) by `stats()`.
stats: RwLock<CallbackStats>,
}
/// Snapshot of registry counters; all fields start at zero.
#[derive(Debug, Default, Clone)]
pub struct CallbackStats {
/// Successful registrations since the registry was created.
pub total_registered: usize,
/// Number of callbacks in the registry as of the last register/unregister.
pub currently_active: usize,
/// Calls that found an active, unexpired callback and proceeded to execution.
pub total_invocations: u64,
/// Counted calls that completed without error.
pub successful_invocations: u64,
/// Counted calls that ended in an error.
pub failed_invocations: u64,
/// Calls rejected because the registry's nesting limit was reached.
pub stack_overflows_prevented: u64,
}
impl Default for CallbackRegistry {
fn default() -> Self {
Self::new()
}
}
impl CallbackRegistry {
    /// Creates an empty registry whose nesting limit is 32.
    pub fn new() -> Self {
        Self {
            callbacks: RwLock::new(HashMap::new()),
            marshaller: Arc::new(Mutex::new(TypeMarshaller::new())),
            stack_depth: Arc::new(Mutex::new(0)),
            max_stack_depth: 32,
            stats: RwLock::new(CallbackStats::default()),
        }
    }

    /// Sets how many `execute_callback` calls may be in progress at once on this registry.
    ///
    /// The depth counter is a single value shared by all threads, so the limit bounds
    /// concurrent calls as well as nested ones.
    pub fn set_max_stack_depth(&mut self, depth: usize) {
        self.max_stack_depth = depth;
    }

    /// Registers `function` under `signature.name` and returns its C entry point.
    ///
    /// `expires_after` is measured from the moment of registration. A duration too large
    /// to add to the current time is treated as "never expires" instead of panicking.
    ///
    /// # Errors
    /// * [`CallbackError::AlreadyRegistered`] if a callback with that name exists.
    /// * Any error from generating the C function pointer.
    pub fn register_callback(
        &self,
        signature: CallbackSignature,
        function: Value,
        environment: Arc<Mutex<Environment>>,
        expires_after: Option<Duration>,
    ) -> std::result::Result<*const c_void, CallbackError> {
        // The duplicate check and the insert share one write guard. With separate
        // read/write guards, two threads could both pass the check and the second
        // insert would silently replace the first callback.
        let mut callbacks = self.callbacks.write().unwrap();
        if callbacks.contains_key(&signature.name) {
            return Err(CallbackError::AlreadyRegistered(signature.name));
        }

        let c_function_ptr = self.generate_c_function_ptr(&signature)?;
        let registered_at = SystemTime::now();
        let name = signature.name.clone();
        let callback = Arc::new(CallbackFunction {
            signature,
            function,
            environment,
            registered_at,
            expires_at: expires_after.and_then(|duration| registered_at.checked_add(duration)),
            call_count: Arc::new(Mutex::new(0)),
            c_function_ptr,
            active: Arc::new(Mutex::new(true)),
        });
        callbacks.insert(name, callback);

        // Lock order inside this impl is always `callbacks` -> `stats`.
        let mut stats = self.stats.write().unwrap();
        stats.total_registered += 1;
        stats.currently_active = callbacks.len();
        Ok(c_function_ptr)
    }

    /// Removes the callback registered under `name` and clears its `active` flag.
    ///
    /// Any `Arc` handle obtained earlier via [`get_callback`](Self::get_callback) will then
    /// report `is_active() == false`.
    ///
    /// # Errors
    /// [`CallbackError::NotFound`] if no callback has that name.
    pub fn unregister_callback(&self, name: &str) -> std::result::Result<(), CallbackError> {
        let mut callbacks = self.callbacks.write().unwrap();
        let callback = callbacks
            .remove(name)
            .ok_or_else(|| CallbackError::NotFound(name.to_string()))?;
        *callback.active.lock().unwrap() = false;
        self.stats.write().unwrap().currently_active = callbacks.len();
        Ok(())
    }

    /// Returns a shared handle to the callback registered under `name`, if any.
    pub fn get_callback(&self, name: &str) -> Option<Arc<CallbackFunction>> {
        self.callbacks.read().unwrap().get(name).cloned()
    }

    /// Returns the names of all registered callbacks, in unspecified order.
    pub fn list_callbacks(&self) -> Vec<String> {
        self.callbacks.read().unwrap().keys().cloned().collect()
    }

    /// Removes every expired callback, clears its `active` flag, and returns how many
    /// were removed.
    pub fn cleanup_expired(&self) -> usize {
        // Scan and remove under one write guard. This way a callback re-registered under
        // the same name between the two steps is never removed by mistake, and the count
        // reflects what was actually removed.
        let mut callbacks = self.callbacks.write().unwrap();
        let mut removed = 0;
        callbacks.retain(|_, callback| {
            if callback.is_expired() {
                *callback.active.lock().unwrap() = false;
                removed += 1;
                false
            } else {
                true
            }
        });
        if removed > 0 {
            self.stats.write().unwrap().currently_active = callbacks.len();
        }
        removed
    }

    /// Returns a snapshot of the registry's counters.
    pub fn stats(&self) -> CallbackStats {
        self.stats.read().unwrap().clone()
    }

    /// Produces the C entry point for a callback.
    ///
    /// NOTE(review): placeholder. It returns the registry's own address, which is not a
    /// callable function; the signature is ignored.
    fn generate_c_function_ptr(
        &self,
        _signature: &CallbackSignature,
    ) -> std::result::Result<*const c_void, CallbackError> {
        let dummy_ptr = self as *const CallbackRegistry as *const c_void;
        Ok(dummy_ptr)
    }

    /// Looks up `name`, converts the raw C arguments, and runs the callback.
    ///
    /// Only the first `min(arg_count, parameters.len())` arguments are converted.
    /// Every call that reaches execution is counted in `total_invocations` and then in
    /// exactly one of `successful_invocations` / `failed_invocations` (argument-conversion
    /// failures included). On success the function currently returns a null pointer.
    ///
    /// # Safety
    /// When at least one argument will be converted, `args` must point to that many
    /// readable pointers. Each must reference a live value of the C type declared at its
    /// position: a `c_int` for `CInt`, and a `*const c_char` (null or NUL-terminated) for
    /// `CString`.
    ///
    /// # Errors
    /// * [`CallbackError::StackOverflow`] if the nesting limit is reached.
    /// * [`CallbackError::NotFound`] if the callback is missing or deactivated.
    /// * [`CallbackError::Expired`] if its deadline has passed.
    /// * [`CallbackError::InvalidSignature`] if a required pointer is null.
    /// * [`CallbackError::ConversionError`] if a C string is not valid UTF-8.
    /// * [`CallbackError::ExecutionFailed`] if the registered value is not a procedure.
    pub unsafe fn execute_callback(
        &self,
        name: &str,
        args: *const *const c_void,
        arg_count: usize,
    ) -> std::result::Result<*const c_void, CallbackError> {
        {
            let mut depth = self.stack_depth.lock().unwrap();
            if *depth >= self.max_stack_depth {
                drop(depth);
                self.stats.write().unwrap().stack_overflows_prevented += 1;
                return Err(CallbackError::StackOverflow);
            }
            *depth += 1;
        }
        // Pairs with the increment above and undoes it on every return path.
        let _stack_guard = StackGuard::new(Arc::clone(&self.stack_depth));

        let callback = self
            .get_callback(name)
            .ok_or_else(|| CallbackError::NotFound(name.to_string()))?;
        if callback.is_expired() {
            return Err(CallbackError::Expired(name.to_string()));
        }
        if !callback.is_active() {
            return Err(CallbackError::NotFound(name.to_string()));
        }

        self.stats.write().unwrap().total_invocations += 1;

        // SAFETY: the caller upholds this function's contract, which is exactly the
        // contract `convert_args` requires.
        let result = unsafe { self.convert_args(name, &callback, args, arg_count) }
            .and_then(|scheme_args| Self::invoke(name, &callback, scheme_args));

        match result {
            Ok(_return_value) => {
                callback.increment_call_count();
                self.stats.write().unwrap().successful_invocations += 1;
                Ok(ptr::null())
            }
            Err(e) => {
                self.stats.write().unwrap().failed_invocations += 1;
                Err(e)
            }
        }
    }

    /// Converts the raw C arguments of one call into script values.
    ///
    /// The marshaller lock is held only while converting, never during invocation, so a
    /// nested callback started from inside a procedure cannot deadlock on it.
    /// Parameter types other than `CInt` and `CString` are passed as `Value::Nil`.
    ///
    /// # Safety
    /// Same contract as [`execute_callback`](Self::execute_callback).
    unsafe fn convert_args(
        &self,
        name: &str,
        callback: &CallbackFunction,
        args: *const *const c_void,
        arg_count: usize,
    ) -> std::result::Result<Vec<Value>, CallbackError> {
        let parameters = &callback.signature.parameters;
        let count = arg_count.min(parameters.len());
        let invalid_args = |reason: String| CallbackError::InvalidSignature {
            callback: name.to_string(),
            reason,
        };
        if count > 0 && args.is_null() {
            return Err(invalid_args("argument array is null".to_string()));
        }

        let _marshaller = self.marshaller.lock().unwrap();
        let mut values = Vec::with_capacity(count);
        for (index, param_type) in parameters.iter().take(count).enumerate() {
            // SAFETY: `args` is non-null and `index < count <= arg_count`; the caller
            // guarantees that many readable entries.
            let arg_ptr = unsafe { *args.add(index) };
            let value = match param_type {
                CType::CInt => {
                    if arg_ptr.is_null() {
                        return Err(invalid_args(format!("argument {index} is a null pointer")));
                    }
                    // SAFETY: non-null, and the caller guarantees it points to a `c_int`.
                    let int_val = unsafe { *(arg_ptr as *const libc::c_int) };
                    Value::Literal(Literal::integer(i64::from(int_val)))
                }
                CType::CString => {
                    if arg_ptr.is_null() {
                        return Err(invalid_args(format!("argument {index} is a null pointer")));
                    }
                    // SAFETY: non-null, and the caller guarantees it points to a `*const c_char`.
                    let c_str_ptr = unsafe { *(arg_ptr as *const *const libc::c_char) };
                    if c_str_ptr.is_null() {
                        Value::Literal(Literal::String(String::new()))
                    } else {
                        // SAFETY: non-null; the caller guarantees a NUL-terminated string.
                        let c_str = unsafe { std::ffi::CStr::from_ptr(c_str_ptr) };
                        let rust_str = c_str.to_str().map_err(|e| {
                            CallbackError::ConversionError(ConversionError::StringConversion(
                                e.to_string(),
                            ))
                        })?;
                        Value::Literal(Literal::String(rust_str.to_string()))
                    }
                }
                _ => Value::Nil,
            };
            values.push(value);
        }
        Ok(values)
    }

    /// Runs the callback's procedure with already-converted arguments.
    ///
    /// NOTE(review): placeholder. A `Value::Procedure` is not evaluated yet and always
    /// yields integer 0; any other value fails with "Not a function".
    fn invoke(
        name: &str,
        callback: &CallbackFunction,
        _args: Vec<Value>,
    ) -> std::result::Result<Value, CallbackError> {
        let _env = callback.environment.lock().unwrap();
        match &callback.function {
            Value::Procedure(_) => Ok(Value::Literal(Literal::integer(0))),
            _ => Err(CallbackError::ExecutionFailed {
                callback: name.to_string(),
                error: "Not a function".to_string(),
            }),
        }
    }
}
/// Guard that decrements a shared nesting-depth counter when dropped.
///
/// The caller increments the counter before constructing the guard. The guard only
/// performs the matching decrement, so every early return still releases the slot.
#[must_use = "dropping the guard immediately releases the depth slot"]
struct StackGuard {
    stack_depth: Arc<Mutex<usize>>,
}

impl StackGuard {
    /// Wraps an already-incremented depth counter.
    fn new(stack_depth: Arc<Mutex<usize>>) -> Self {
        Self { stack_depth }
    }
}

impl Drop for StackGuard {
    fn drop(&mut self) {
        // Recover from poisoning instead of unwrapping. A panic inside `drop` while the
        // thread is already unwinding aborts the whole process, and the counter is a
        // plain integer that is still meaningful after another thread panicked.
        let mut depth = self
            .stack_depth
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // Never underflow, even if a guard was created without a matching increment.
        *depth = depth.saturating_sub(1);
    }
}
/// Async-facing wrapper around the synchronous callback registry (requires the `async` feature).
#[cfg(feature = "async")]
pub mod async_callbacks {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use tokio::sync::oneshot;

    /// Boxed, `Send` future resolving to a callback's script result.
    pub type AsyncCallbackResult =
        Pin<Box<dyn Future<Output = std::result::Result<Value, CallbackError>> + Send>>;

    /// Registry for callbacks driven from async code.
    #[derive(Debug)]
    pub struct AsyncCallbackRegistry {
        // Synchronous registry that owns the actual callback table.
        base: CallbackRegistry,
        // NOTE(review): only initialised in `new`, never read or written afterwards.
        pending: Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
    }

    impl AsyncCallbackRegistry {
        /// Creates a wrapper with an empty callback table and an empty pending map.
        pub fn new() -> Self {
            let base = CallbackRegistry::new();
            let pending = Arc::new(Mutex::new(HashMap::new()));
            AsyncCallbackRegistry { base, pending }
        }

        /// Registers a callback by delegating synchronously to the wrapped registry;
        /// nothing inside is awaited.
        pub async fn register_async_callback(
            &self,
            signature: CallbackSignature,
            function: Value,
            environment: Arc<Mutex<Environment>>,
            expires_after: Option<Duration>,
        ) -> std::result::Result<*const c_void, CallbackError> {
            let registry = &self.base;
            registry.register_callback(signature, function, environment, expires_after)
        }

        /// NOTE(review): stub. It ignores both arguments and always resolves to `Ok(Value::Nil)`.
        pub async fn execute_async_callback(
            &self,
            _name: &str,
            _args: Vec<Value>,
        ) -> std::result::Result<Value, CallbackError> {
            let placeholder = Value::Nil;
            Ok(placeholder)
        }
    }

    impl Default for AsyncCallbackRegistry {
        fn default() -> Self {
            AsyncCallbackRegistry::new()
        }
    }
}
lazy_static::lazy_static! {
/// Process-wide registry used by the free functions `register_callback`,
/// `unregister_callback` and `cleanup_expired_callbacks`.
///
/// It is only reachable through shared references, so `set_max_stack_depth`
/// (which takes `&mut self`) cannot be called on it; its limit stays at the `new()` value.
pub static ref GLOBAL_CALLBACK_REGISTRY: CallbackRegistry = CallbackRegistry::new();
}
pub fn register_callback(
signature: CallbackSignature,
function: Value,
environment: Arc<Mutex<Environment>>,
expires_after: Option<Duration>,
) -> std::result::Result<*const c_void, CallbackError> {
GLOBAL_CALLBACK_REGISTRY.register_callback(signature, function, environment, expires_after)
}
pub fn unregister_callback(name: &str) -> std::result::Result<(), CallbackError> {
GLOBAL_CALLBACK_REGISTRY.unregister_callback(name)
}
pub fn cleanup_expired_callbacks() -> usize {
GLOBAL_CALLBACK_REGISTRY.cleanup_expired()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eval::Environment;

    /// Builds a parameterless, void-returning signature with the given name.
    fn void_signature(name: &str) -> CallbackSignature {
        CallbackSignature {
            name: name.to_string(),
            parameters: vec![],
            return_type: CType::Void,
            variadic: false,
            calling_convention: CallingConvention::C,
        }
    }

    /// Fresh, empty environment for a test callback.
    fn empty_environment() -> Arc<Mutex<Environment>> {
        Arc::new(Mutex::new(Environment::new(None, 0)))
    }

    /// A non-procedure value: registration accepts it, execution rejects it.
    fn number_value() -> Value {
        Value::Literal(Literal::Number(42.0))
    }

    #[test]
    fn test_callback_registry_creation() {
        let registry = CallbackRegistry::new();
        let stats = registry.stats();
        assert_eq!(stats.currently_active, 0);
        assert_eq!(stats.total_registered, 0);
    }

    #[test]
    fn test_callback_signature() {
        let sig = CallbackSignature {
            name: "test_callback".to_string(),
            parameters: vec![CType::CInt, CType::CString],
            return_type: CType::CInt,
            variadic: false,
            calling_convention: CallingConvention::C,
        };
        assert_eq!(sig.name, "test_callback");
        assert_eq!(sig.parameters.len(), 2);
        assert!(!sig.variadic);
    }

    #[test]
    fn test_callback_function_creation() {
        let callback = CallbackFunction {
            signature: void_signature("test"),
            function: number_value(),
            environment: empty_environment(),
            registered_at: SystemTime::now(),
            expires_at: None,
            call_count: Arc::new(Mutex::new(0)),
            c_function_ptr: ptr::null(),
            active: Arc::new(Mutex::new(true)),
        };
        assert!(callback.is_active());
        assert!(!callback.is_expired());
        assert_eq!(callback.call_count(), 0);
    }

    #[test]
    fn test_cleanup_expired() {
        let registry = CallbackRegistry::new();
        let cleaned = registry.cleanup_expired();
        assert_eq!(cleaned, 0);
    }

    #[test]
    fn test_register_and_unregister_update_stats() {
        let registry = CallbackRegistry::new();
        registry
            .register_callback(void_signature("cb"), number_value(), empty_environment(), None)
            .expect("first registration succeeds");
        assert_eq!(registry.list_callbacks(), vec!["cb".to_string()]);
        let stats = registry.stats();
        assert_eq!(stats.total_registered, 1);
        assert_eq!(stats.currently_active, 1);

        let handle = registry.get_callback("cb").expect("callback is registered");
        registry.unregister_callback("cb").expect("unregister succeeds");
        assert!(!handle.is_active());
        assert!(registry.get_callback("cb").is_none());
        let stats = registry.stats();
        assert_eq!(stats.currently_active, 0);
        assert_eq!(stats.total_registered, 1);
    }

    #[test]
    fn test_duplicate_registration_rejected() {
        let registry = CallbackRegistry::new();
        registry
            .register_callback(void_signature("dup"), number_value(), empty_environment(), None)
            .expect("first registration succeeds");
        match registry.register_callback(
            void_signature("dup"),
            number_value(),
            empty_environment(),
            None,
        ) {
            Err(CallbackError::AlreadyRegistered(name)) => assert_eq!(name, "dup"),
            other => panic!("expected AlreadyRegistered, got {other:?}"),
        }
        assert_eq!(registry.stats().total_registered, 1);
        assert_eq!(registry.stats().currently_active, 1);
    }

    #[test]
    fn test_unregister_unknown_is_not_found() {
        let registry = CallbackRegistry::new();
        assert!(matches!(
            registry.unregister_callback("missing"),
            Err(CallbackError::NotFound(_))
        ));
    }

    #[test]
    fn test_execute_unknown_callback_releases_depth() {
        let registry = CallbackRegistry::new();
        let result = unsafe { registry.execute_callback("missing", ptr::null(), 0) };
        assert!(matches!(result, Err(CallbackError::NotFound(_))));
        assert_eq!(*registry.stack_depth.lock().unwrap(), 0);
    }

    #[test]
    fn test_execute_respects_max_stack_depth() {
        let mut registry = CallbackRegistry::new();
        registry.set_max_stack_depth(0);
        let result = unsafe { registry.execute_callback("anything", ptr::null(), 0) };
        assert!(matches!(result, Err(CallbackError::StackOverflow)));
        assert_eq!(registry.stats().stack_overflows_prevented, 1);
        assert_eq!(*registry.stack_depth.lock().unwrap(), 0);
    }

    #[test]
    fn test_execute_non_procedure_counts_failure() {
        let registry = CallbackRegistry::new();
        registry
            .register_callback(void_signature("num"), number_value(), empty_environment(), None)
            .expect("registration succeeds");
        let result = unsafe { registry.execute_callback("num", ptr::null(), 0) };
        assert!(matches!(result, Err(CallbackError::ExecutionFailed { .. })));
        let stats = registry.stats();
        assert_eq!(stats.total_invocations, 1);
        assert_eq!(stats.failed_invocations, 1);
        assert_eq!(stats.successful_invocations, 0);
        let callback = registry.get_callback("num").expect("still registered");
        assert_eq!(callback.call_count(), 0);
    }
}