#![allow(unused_variables)]
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, RwLock};
use crate::eval::Value;
use crate::ast::Literal;
use crate::diagnostics::Error;
use crate::ffi::c_types::CType;
/// Result alias used throughout the safety-validation API; failures are a boxed [`SafetyError`].
pub type SafetyResult<T> = std::result::Result<T, Box<SafetyError>>;
/// Failure categories reported by the FFI type-safety layer.
///
/// The fields of each variant are the values interpolated by its `Display` message.
#[derive(Debug, Clone)]
pub enum SafetyError {
/// The supplied arguments do not fit the signature registered for `function`.
/// `actual` describes what was supplied; `validate_parameter_types` in this file
/// fills it with one `CType::Void` placeholder per argument.
SignatureMismatch {
function: String,
expected: FunctionSignature,
actual: Box<FunctionSignature>,
},
/// The callee address was rejected; the validator in this file raises it for a null pointer.
InvalidFunctionPointer {
function: String,
pointer: *const u8,
},
/// A value is not compatible with the expected C type. `parameter` is the
/// zero-based argument index; return-value failures in this file also report `0`.
RuntimeTypeCheck {
parameter: usize,
expected: CType,
actual_value: String,
},
/// A range or string-format rule was broken; `operation` names the check and
/// `description` carries the details.
BoundaryViolation {
operation: String,
description: String,
},
/// Argument `parameter` was nil where a non-null value is required;
/// `context` is the function name or the rule name that detected it.
NullPointerDereference {
parameter: usize,
context: String,
},
/// NOTE(review): not constructed anywhere in the visible code; per the Display
/// text it reports a buffer size together with the offending access offset/size.
BufferBoundsCheck {
buffer_size: usize,
access_offset: usize,
access_size: usize,
},
/// NOTE(review): not constructed anywhere in the visible code; per the Display
/// text it reports an uninitialized region by address and size.
UninitializedMemory {
pointer: *const u8,
size: usize,
},
/// The tracked call depth reached the configured maximum.
StackOverflow {
current_depth: usize,
max_depth: usize,
},
/// NOTE(review): not constructed anywhere in the visible code; per the Display
/// text it identifies a leaked resource by type and id.
ResourceLeak {
resource_type: String,
resource_id: String,
},
}
/// User-facing message for every [`SafetyError`] variant.
impl fmt::Display for SafetyError {
    /// Writes a one-line description of the error to `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SafetyError::SignatureMismatch { function, expected, actual } => {
                write!(f, "Function '{function}' signature mismatch: expected {expected:?}, got {actual:?}")
            }
            SafetyError::InvalidFunctionPointer { function, pointer } => {
                // `pointer` is bound as `&*const u8`; formatting the reference with
                // `{:p}` would print the address of the enum field itself, so
                // dereference to print the stored pointer value.
                write!(f, "Invalid function pointer for '{function}': {:p}", *pointer)
            }
            SafetyError::RuntimeTypeCheck { parameter, expected, actual_value } => {
                write!(f, "Runtime type check failed for parameter {parameter}: expected {expected}, got {actual_value}")
            }
            SafetyError::BoundaryViolation { operation, description } => {
                write!(f, "Boundary violation in {operation}: {description}")
            }
            SafetyError::NullPointerDereference { parameter, context } => {
                write!(f, "Null pointer dereference in parameter {parameter} ({context})")
            }
            SafetyError::BufferBoundsCheck { buffer_size, access_offset, access_size } => {
                write!(f, "Buffer bounds check failed: buffer size {buffer_size}, access offset {access_offset}, access size {access_size}")
            }
            SafetyError::UninitializedMemory { pointer, size } => {
                // Same reference-vs-pointer pitfall as above: print the stored address.
                write!(f, "Uninitialized memory access at {:p} (size {size})", *pointer)
            }
            SafetyError::StackOverflow { current_depth, max_depth } => {
                write!(f, "Stack overflow: current depth {current_depth}, max depth {max_depth}")
            }
            SafetyError::ResourceLeak { resource_type, resource_id } => {
                write!(f, "Resource leak detected: {resource_type} (ID: {resource_id})")
            }
        }
    }
}
// Marker impl: lets `SafetyError` be used wherever `std::error::Error` is expected
// (it relies on the `Debug` derive and the `Display` impl of the type).
impl std::error::Error for SafetyError {}
impl From<SafetyError> for Error {
fn from(safety_error: SafetyError) -> Self {
Error::runtime_error(safety_error.to_string(), None)
}
}
/// Description of a foreign function as registered with the validator.
#[derive(Debug, Clone, PartialEq)]
pub struct FunctionSignature {
/// Function name; also the key under which the signature is stored.
pub name: String,
/// Expected C types of the fixed parameters, in call order.
pub parameters: Vec<CType>,
/// Expected C type of the return value.
pub return_type: CType,
/// When `true`, the exact argument-count comparison is skipped during validation.
pub variadic: bool,
/// NOTE(review): not read anywhere in the visible code; presumably marks the
/// function as safe to call without extra checks — confirm with callers.
pub safe: bool,
/// Additional per-parameter constraints checked before the call.
pub constraints: Vec<TypeConstraint>,
}
/// Extra requirement attached to a parameter of a [`FunctionSignature`].
///
/// Parameter positions are zero-based indices into the argument list.
#[derive(Debug, Clone, PartialEq)]
pub enum TypeConstraint {
/// The argument at this index must not be nil.
NonNull(usize),
/// A numeric argument must lie in the inclusive range `min..=max`.
Bounds {
parameter: usize,
min: i64,
max: i64,
},
/// NOTE(review): declared but not enforced by `validate_constraints` in this file.
NullTerminated(usize),
/// Pairs a buffer argument with the argument holding its size. The visible
/// validation only rejects a nil buffer; the size value is not inspected.
BufferWithSize {
buffer_param: usize,
size_param: usize,
},
/// NOTE(review): declared but not enforced by `validate_constraints` in this file.
Aligned {
parameter: usize,
alignment: usize,
},
/// NOTE(review): declared but not enforced by `validate_constraints` in this file.
ResourceManagement {
parameter: usize,
resource_type: String,
lifetime: ResourceLifetime,
},
}
/// Lifetime tag carried by `TypeConstraint::ResourceManagement`.
///
/// NOTE(review): no visible code interprets these variants; their meaning is
/// presumably the ownership mode suggested by each name — confirm with the FFI layer.
#[derive(Debug, Clone, PartialEq)]
pub enum ResourceLifetime {
Owned,
Borrowed,
Transferred,
Shared,
}
/// Registry and checker for foreign-function calls.
///
/// Every field sits behind its own `RwLock`, so all methods take `&self`.
#[derive(Debug)]
pub struct TypeSafetyValidator {
// Registered signatures, keyed by function name.
signatures: RwLock<HashMap<String, FunctionSignature>>,
// Extra validation rules, keyed by function name, in insertion order.
validation_rules: RwLock<HashMap<String, Vec<ValidationRule>>>,
// Active configuration; replaced as a whole by `configure`.
config: RwLock<SafetyConfig>,
// Running counters; a copy is returned by `stats`.
stats: RwLock<SafetyStats>,
// Tracked call depth: raised by `validate_function_call`, lowered by
// `validate_function_completion` and by dropping a `StackDepthGuard`.
stack_depth: RwLock<usize>,
}
/// A named, switchable check attached to a function name.
#[derive(Debug, Clone)]
pub struct ValidationRule {
/// Rule name; reported as the context/operation in errors produced by the rule.
pub name: String,
/// Phase in which the rule applies.
pub trigger: ValidationTrigger,
/// The check to perform.
pub validator: ValidationFunction,
/// Disabled rules are skipped when rules are applied.
pub enabled: bool,
}
/// Phase at which a [`ValidationRule`] fires.
///
/// In this file only `PreCall` (from `validate_function_call`) and `PostCall`
/// (from `validate_function_completion`) are ever dispatched.
#[derive(Debug, Clone)]
pub enum ValidationTrigger {
/// Before the foreign call, with the call arguments.
PreCall,
/// After the foreign call; rules receive an empty argument slice.
PostCall,
/// NOTE(review): never dispatched in the visible code; carries a parameter index.
ParameterConversion(usize),
/// NOTE(review): never dispatched in the visible code.
ReturnConversion,
/// NOTE(review): never dispatched in the visible code; matched by string equality.
Custom(String),
}
/// The check performed by a [`ValidationRule`].
#[derive(Debug, Clone)]
pub enum ValidationFunction {
/// Fails on the first nil argument.
NullPointerCheck,
/// Fails when an integer or real literal argument lies outside the inclusive range `min..=max`.
BoundsCheck { min: i64, max: i64 },
/// NOTE(review): currently a no-op in `apply_single_validation_rule`.
BufferSizeCheck,
/// Fails when a string argument contains a NUL character but does not end with one.
StringValidation,
/// NOTE(review): currently a no-op in `apply_single_validation_rule`.
AlignmentCheck { alignment: usize },
/// NOTE(review): currently a no-op in `apply_single_validation_rule`.
Custom { name: String, description: String },
}
/// Switches controlling which checks the validator performs.
#[derive(Debug, Clone)]
pub struct SafetyConfig {
/// Check argument and return values against the registered C types.
pub runtime_type_checking: bool,
/// Enforce `TypeConstraint::NonNull`.
pub null_pointer_checking: bool,
/// Enforce `TypeConstraint::Bounds`.
pub bounds_checking: bool,
/// Enforce `TypeConstraint::BufferWithSize`.
pub buffer_overflow_protection: bool,
/// Track call depth and reject calls once `max_stack_depth` is reached.
pub stack_overflow_protection: bool,
/// Depth at which further calls are rejected.
pub max_stack_depth: usize,
/// NOTE(review): not read anywhere in the visible code.
pub resource_leak_detection: bool,
/// Reject null function pointers.
pub function_pointer_validation: bool,
/// NOTE(review): not read anywhere in the visible code.
pub memory_alignment_checking: bool,
}
impl Default for SafetyConfig {
fn default() -> Self {
Self {
runtime_type_checking: true,
null_pointer_checking: true,
bounds_checking: true,
buffer_overflow_protection: true,
stack_overflow_protection: true,
max_stack_depth: 64,
resource_leak_detection: true,
function_pointer_validation: true,
memory_alignment_checking: true,
}
}
}
/// Counters kept by the validator; all start at zero.
#[derive(Debug, Default, Clone)]
pub struct SafetyStats {
/// Calls to `validate_function_call`.
pub total_validations: u64,
/// Calls to `validate_function_call` that returned `Ok`.
pub successful_validations: u64,
/// Counter reserved for validations that failed.
pub failed_validations: u64,
/// `NonNull` constraint violations.
pub null_pointer_violations: u64,
/// `Bounds` constraint violations.
pub bounds_violations: u64,
/// `BufferWithSize` constraint rejections.
pub buffer_overflow_prevented: u64,
/// Calls rejected because the depth limit was reached.
pub stack_overflow_prevented: u64,
/// NOTE(review): not updated anywhere in the visible code.
pub resource_leaks_detected: u64,
}
impl Default for TypeSafetyValidator {
fn default() -> Self {
Self::new()
}
}
impl TypeSafetyValidator {
pub fn new() -> Self {
Self {
signatures: RwLock::new(HashMap::new()),
validation_rules: RwLock::new(HashMap::new()),
config: RwLock::new(SafetyConfig::default()),
stats: RwLock::new(SafetyStats::default()),
stack_depth: RwLock::new(0),
}
}
/// Replaces the active configuration with `config`.
///
/// # Panics
/// Panics if the configuration lock is poisoned.
pub fn configure(&self, config: SafetyConfig) {
    *self.config.write().unwrap() = config;
}
/// Stores `signature` under its own name, replacing any signature previously
/// registered under that name. Always returns `Ok(())`.
///
/// # Panics
/// Panics if the signature lock is poisoned.
pub fn register_function_signature(&self, signature: FunctionSignature) -> SafetyResult<()> {
    let key = signature.name.clone();
    self.signatures.write().unwrap().insert(key, signature);
    Ok(())
}
/// Appends `rule` to the list of rules kept for `function_name`, creating the
/// list on first use.
///
/// # Panics
/// Panics if the rule lock is poisoned.
pub fn add_validation_rule(&self, function_name: String, rule: ValidationRule) {
    self.validation_rules
        .write()
        .unwrap()
        .entry(function_name)
        .or_default()
        .push(rule);
}
/// Validates a foreign call just before it is made.
///
/// Checks run in this order, each only when enabled in the configuration:
/// 1. call-depth limit (a depth slot is reserved on success),
/// 2. null function pointer,
/// 3. for functions with a registered signature: argument types, signature
///    constraints, then `PreCall` validation rules.
///
/// Functions without a registered signature pass after steps 1 and 2.
///
/// Statistics: `total_validations` is bumped on entry, and exactly one of
/// `successful_validations` / `failed_validations` on exit.
///
/// # Errors
/// Returns the first failing check's [`SafetyError`]. When a check fails after
/// the depth slot was reserved, the slot is released again, because the call
/// will not happen and no completion will be reported for it.
///
/// # Panics
/// Panics if one of the validator's locks is poisoned.
pub fn validate_function_call(
    &self,
    function_name: &str,
    args: &[Value],
    function_ptr: *const u8,
) -> SafetyResult<()> {
    // Work on a snapshot so the config lock is released immediately:
    // `validate_constraints` acquires the same lock, and a recursive read of a
    // `std::sync::RwLock` may deadlock when a writer is queued in between.
    let config = self.config.read().unwrap().clone();
    self.stats.write().unwrap().total_validations += 1;

    let mut depth_reserved = false;
    if config.stack_overflow_protection {
        let mut depth = self.stack_depth.write().unwrap();
        if *depth >= config.max_stack_depth {
            let current_depth = *depth;
            drop(depth);
            let mut stats = self.stats.write().unwrap();
            stats.stack_overflow_prevented += 1;
            stats.failed_validations += 1;
            return Err(Box::new(SafetyError::StackOverflow {
                current_depth,
                max_depth: config.max_stack_depth,
            }));
        }
        *depth += 1;
        depth_reserved = true;
    }

    // Everything that can fail after the depth slot was taken runs inside this
    // closure so that a single place below can undo the reservation.
    let run_checks = || -> SafetyResult<()> {
        if config.function_pointer_validation && function_ptr.is_null() {
            return Err(Box::new(SafetyError::InvalidFunctionPointer {
                function: function_name.to_string(),
                pointer: function_ptr,
            }));
        }
        let signature = self.signatures.read().unwrap().get(function_name).cloned();
        if let Some(sig) = signature {
            if config.runtime_type_checking {
                self.validate_parameter_types(&sig, args, function_name)?;
            }
            self.validate_constraints(&sig, args, function_name)?;
            self.apply_validation_rules(function_name, args, ValidationTrigger::PreCall)?;
        }
        Ok(())
    };
    let outcome = run_checks();

    match &outcome {
        Ok(()) => {
            self.stats.write().unwrap().successful_validations += 1;
        }
        Err(_) => {
            self.stats.write().unwrap().failed_validations += 1;
            if depth_reserved {
                // Release the slot taken above; otherwise repeated failures
                // would eventually trip the stack-overflow check.
                let mut depth = self.stack_depth.write().unwrap();
                *depth = (*depth).saturating_sub(1);
            }
        }
    }
    outcome
}
/// Validates the outcome of a foreign call after it returned.
///
/// When stack-overflow protection is enabled the tracked call depth is lowered
/// by one (never below zero). If a signature is registered for
/// `function_name`, the return value is type-checked (when runtime type
/// checking is enabled) and `PostCall` rules are applied with an empty
/// argument slice. Unregistered functions pass.
///
/// # Errors
/// Returns the first failing check's [`SafetyError`].
///
/// # Panics
/// Panics if one of the validator's locks is poisoned.
pub fn validate_function_completion(
    &self,
    function_name: &str,
    return_value: &Value,
) -> SafetyResult<()> {
    let settings = self.config.read().unwrap();
    if settings.stack_overflow_protection {
        let mut level = self.stack_depth.write().unwrap();
        *level = (*level).saturating_sub(1);
    }
    let signature = match self.get_function_signature(function_name) {
        Some(found) => found,
        None => return Ok(()),
    };
    if settings.runtime_type_checking {
        self.validate_return_type(&signature, return_value, function_name)?;
    }
    self.apply_validation_rules(function_name, &[], ValidationTrigger::PostCall)
}
/// Checks the argument count and each argument's type against `signature`.
///
/// Non-variadic functions must receive exactly as many arguments as there are
/// declared parameters. Variadic functions must receive at least the declared
/// (fixed) parameters; extra arguments are not type-checked.
///
/// # Errors
/// * `SignatureMismatch` when the argument count is wrong. The `actual`
///   signature is a placeholder with one `CType::Void` per supplied argument.
/// * `RuntimeTypeCheck` for the first argument whose value is incompatible
///   with its declared type.
fn validate_parameter_types(
    &self,
    signature: &FunctionSignature,
    args: &[Value],
    function_name: &str,
) -> SafetyResult<()> {
    let fixed_count = signature.parameters.len();
    // A variadic callee still needs all of its fixed parameters; without this
    // lower bound the `zip` below would silently skip the missing ones.
    let arity_ok = if signature.variadic {
        args.len() >= fixed_count
    } else {
        args.len() == fixed_count
    };
    if !arity_ok {
        return Err(Box::new(SafetyError::SignatureMismatch {
            function: function_name.to_string(),
            expected: signature.clone(),
            actual: Box::new(FunctionSignature {
                name: function_name.to_string(),
                parameters: args.iter().map(|_| CType::Void).collect(),
                return_type: CType::Void,
                variadic: false,
                safe: false,
                constraints: vec![],
            }),
        }));
    }
    for (i, (arg, expected_type)) in args.iter().zip(signature.parameters.iter()).enumerate() {
        if !self.is_value_compatible_with_type(arg, expected_type) {
            return Err(Box::new(SafetyError::RuntimeTypeCheck {
                parameter: i,
                expected: expected_type.clone(),
                actual_value: format!("{arg:?}"),
            }));
        }
    }
    Ok(())
}
fn validate_return_type(
&self,
signature: &FunctionSignature,
return_value: &Value,
_function_name: &str,
) -> SafetyResult<()> {
if !self.is_value_compatible_with_type(return_value, &signature.return_type) {
return Err(Box::new(SafetyError::RuntimeTypeCheck {
parameter: 0, expected: signature.return_type.clone(),
actual_value: format!("{return_value:?}"),
}));
}
Ok(())
}
/// Reports whether `value` may be passed as (or returned for) `c_type`.
///
/// Accepted combinations:
/// * a numeric literal with a numeric C type, or with `Float` / `Double`;
/// * a boolean literal with `Bool`;
/// * a string literal with `CString`;
/// * a character literal with `Char`;
/// * nil with any pointer type.
///
/// Everything else is rejected.
fn is_value_compatible_with_type(&self, value: &Value, c_type: &CType) -> bool {
    if let Value::Literal(literal) = value {
        let float_target = matches!(c_type, CType::Float | CType::Double);
        if literal.is_number() && (c_type.is_numeric() || float_target) {
            return true;
        }
        return matches!(
            (literal, c_type),
            (Literal::Boolean(_), CType::Bool)
                | (Literal::String(_), CType::CString)
                | (Literal::Character(_), CType::Char)
        );
    }
    matches!(value, Value::Nil) && c_type.is_pointer()
}
/// Enforces the constraints listed in `signature` against `args`.
///
/// Only `NonNull`, `Bounds` and `BufferWithSize` are enforced, each gated by
/// its configuration switch; constraints whose parameter index is outside
/// `args` are skipped. The remaining constraint kinds are accepted without
/// any check. The matching statistics counter is bumped before an error is
/// returned.
///
/// # Errors
/// * `NullPointerDereference` for a nil `NonNull` argument or a nil buffer.
/// * `BoundaryViolation` for a numeric argument outside its inclusive range.
fn validate_constraints(
    &self,
    signature: &FunctionSignature,
    args: &[Value],
    function_name: &str,
) -> SafetyResult<()> {
    let config = self.config.read().unwrap();
    for constraint in &signature.constraints {
        match constraint {
            TypeConstraint::NonNull(param_idx) => {
                let is_nil = matches!(args.get(*param_idx), Some(Value::Nil));
                if config.null_pointer_checking && is_nil {
                    self.stats.write().unwrap().null_pointer_violations += 1;
                    return Err(Box::new(SafetyError::NullPointerDereference {
                        parameter: *param_idx,
                        context: function_name.to_string(),
                    }));
                }
            }
            TypeConstraint::Bounds { parameter, min, max } => {
                if !config.bounds_checking {
                    continue;
                }
                // Only literal arguments that convert to a number are range-checked.
                let numeric = match args.get(*parameter) {
                    Some(Value::Literal(literal)) => literal.to_f64(),
                    _ => None,
                };
                if let Some(val) = numeric {
                    let (lower, upper) = (*min as f64, *max as f64);
                    if val < lower || val > upper {
                        self.stats.write().unwrap().bounds_violations += 1;
                        return Err(Box::new(SafetyError::BoundaryViolation {
                            operation: format!("parameter {parameter} bounds check"),
                            description: format!("value {val} not in range [{min}..{max}]"),
                        }));
                    }
                }
            }
            TypeConstraint::BufferWithSize { buffer_param, size_param } => {
                // Both indices must be present before the buffer is inspected.
                let size_present = args.get(*size_param).is_some();
                let buffer_is_nil = matches!(args.get(*buffer_param), Some(Value::Nil));
                if config.buffer_overflow_protection && size_present && buffer_is_nil {
                    self.stats.write().unwrap().buffer_overflow_prevented += 1;
                    return Err(Box::new(SafetyError::NullPointerDereference {
                        parameter: *buffer_param,
                        context: "buffer parameter".to_string(),
                    }));
                }
            }
            // Not enforced here.
            TypeConstraint::NullTerminated(_)
            | TypeConstraint::Aligned { .. }
            | TypeConstraint::ResourceManagement { .. } => {}
        }
    }
    Ok(())
}
/// Runs every enabled rule registered for `function_name` whose trigger
/// matches `trigger`, in registration order, stopping at the first failure.
/// Functions without registered rules pass.
///
/// # Errors
/// Returns the error produced by the first failing rule.
fn apply_validation_rules(
    &self,
    function_name: &str,
    args: &[Value],
    trigger: ValidationTrigger,
) -> SafetyResult<()> {
    let registry = self.validation_rules.read().unwrap();
    let applicable = registry
        .get(function_name)
        .into_iter()
        .flatten()
        .filter(|rule| rule.enabled && self.matches_trigger(&rule.trigger, &trigger));
    for rule in applicable {
        self.apply_single_validation_rule(rule, args, function_name)?;
    }
    Ok(())
}
/// Reports whether a rule's trigger equals the trigger being dispatched:
/// same variant, and for `ParameterConversion` / `Custom` the same payload.
fn matches_trigger(&self, rule_trigger: &ValidationTrigger, actual_trigger: &ValidationTrigger) -> bool {
    match rule_trigger {
        ValidationTrigger::PreCall => matches!(actual_trigger, ValidationTrigger::PreCall),
        ValidationTrigger::PostCall => matches!(actual_trigger, ValidationTrigger::PostCall),
        ValidationTrigger::ParameterConversion(wanted) => {
            matches!(actual_trigger, ValidationTrigger::ParameterConversion(seen) if seen == wanted)
        }
        ValidationTrigger::ReturnConversion => {
            matches!(actual_trigger, ValidationTrigger::ReturnConversion)
        }
        ValidationTrigger::Custom(wanted) => {
            matches!(actual_trigger, ValidationTrigger::Custom(seen) if seen == wanted)
        }
    }
}
/// Applies one rule to `args`.
///
/// * `NullPointerCheck` fails on the first nil argument.
/// * `BoundsCheck` fails on the first integer/real literal outside the
///   inclusive range; other arguments are ignored.
/// * `StringValidation` fails when a string literal contains a NUL character
///   but does not end with one.
/// * All other validators currently accept everything.
///
/// # Errors
/// `NullPointerDereference` or `BoundaryViolation`, as described above.
fn apply_single_validation_rule(
    &self,
    rule: &ValidationRule,
    args: &[Value],
    function_name: &str,
) -> SafetyResult<()> {
    match &rule.validator {
        ValidationFunction::NullPointerCheck => {
            let first_nil = args.iter().position(|arg| matches!(arg, Value::Nil));
            if let Some(index) = first_nil {
                return Err(Box::new(SafetyError::NullPointerDereference {
                    parameter: index,
                    context: rule.name.clone(),
                }));
            }
        }
        ValidationFunction::BoundsCheck { min, max } => {
            let (lower, upper) = (*min as f64, *max as f64);
            let offending = args
                .iter()
                .filter_map(|arg| match arg {
                    Value::Literal(Literal::ExactInteger(n)) => Some(*n as f64),
                    Value::Literal(Literal::InexactReal(x)) => Some(*x),
                    _ => None,
                })
                .find(|candidate| *candidate < lower || *candidate > upper);
            if let Some(val) = offending {
                return Err(Box::new(SafetyError::BoundaryViolation {
                    operation: rule.name.clone(),
                    description: format!("value {val} not in range [{min}..{max}]"),
                }));
            }
        }
        ValidationFunction::StringValidation => {
            let malformed = args.iter().any(|arg| match arg {
                Value::Literal(Literal::String(text)) => {
                    text.contains('\0') && !text.ends_with('\0')
                }
                _ => false,
            });
            if malformed {
                return Err(Box::new(SafetyError::BoundaryViolation {
                    operation: "string validation".to_string(),
                    description: "string contains null character but is not null-terminated".to_string(),
                }));
            }
        }
        // No check implemented for these validators.
        ValidationFunction::BufferSizeCheck
        | ValidationFunction::AlignmentCheck { .. }
        | ValidationFunction::Custom { .. } => {}
    }
    Ok(())
}
/// Returns a copy of the signature registered under `function_name`, or
/// `None` when nothing is registered under that name.
///
/// # Panics
/// Panics if the signature lock is poisoned.
pub fn get_function_signature(&self, function_name: &str) -> Option<FunctionSignature> {
    self.signatures
        .read()
        .unwrap()
        .get(function_name)
        .cloned()
}
/// Returns the names of all registered functions, in the map's iteration
/// order (unspecified).
///
/// # Panics
/// Panics if the signature lock is poisoned.
pub fn list_registered_functions(&self) -> Vec<String> {
    let registry = self.signatures.read().unwrap();
    let mut names = Vec::with_capacity(registry.len());
    names.extend(registry.keys().cloned());
    names
}
/// Returns a copy of the current counters.
///
/// # Panics
/// Panics if the statistics lock is poisoned.
pub fn stats(&self) -> SafetyStats {
    let snapshot = self.stats.read().unwrap();
    (*snapshot).clone()
}
/// Removes every registered signature and every validation rule.
/// Configuration, statistics and the tracked call depth are left untouched.
///
/// # Panics
/// Panics if the signature or rule lock is poisoned.
pub fn clear(&self) {
    // Signatures first, then rules; the signature lock stays held until both are empty.
    let mut signature_map = self.signatures.write().unwrap();
    signature_map.clear();
    let mut rule_map = self.validation_rules.write().unwrap();
    rule_map.clear();
}
}
/// Guard that lowers the validator's tracked call depth by one when dropped.
///
/// Creating the guard does not raise the depth; it only arranges the decrement.
pub struct StackDepthGuard {
// Validator whose `stack_depth` is decremented on drop.
validator: Arc<TypeSafetyValidator>,
}
impl StackDepthGuard {
pub fn new(validator: Arc<TypeSafetyValidator>) -> Self {
Self { validator }
}
}
/// Lowers the tracked call depth by one (never below zero) when the guard
/// goes out of scope.
impl Drop for StackDepthGuard {
    fn drop(&mut self) {
        // A guard is typically dropped while unwinding from a panic; panicking
        // again here (e.g. on a poisoned lock) would abort the process. The
        // counter is a plain integer, so recovering the poisoned value is safe.
        let mut depth = self
            .validator
            .stack_depth
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *depth = (*depth).saturating_sub(1);
    }
}
// Process-wide validator used by the free-function wrappers in this file.
// `lazy_static` constructs it with `TypeSafetyValidator::new()` on first access.
lazy_static::lazy_static! {
pub static ref GLOBAL_TYPE_SAFETY_VALIDATOR: TypeSafetyValidator = TypeSafetyValidator::new();
}
pub fn register_function_signature(signature: FunctionSignature) -> SafetyResult<()> {
GLOBAL_TYPE_SAFETY_VALIDATOR.register_function_signature(signature)
}
pub fn validate_function_call(
function_name: &str,
args: &[Value],
function_ptr: *const u8,
) -> SafetyResult<()> {
GLOBAL_TYPE_SAFETY_VALIDATOR.validate_function_call(function_name, args, function_ptr)
}
pub fn validate_function_completion(
function_name: &str,
return_value: &Value,
) -> SafetyResult<()> {
GLOBAL_TYPE_SAFETY_VALIDATOR.validate_function_completion(function_name, return_value)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ptr;

    /// Non-null placeholder for the callee address. The validator only
    /// null-checks the pointer and never dereferences it; the default
    /// configuration rejects `ptr::null()` before any other check runs.
    fn dummy_fn_ptr() -> *const u8 {
        ptr::NonNull::<u8>::dangling().as_ptr() as *const u8
    }

    // NOTE(review): the tests below build numbers with `Literal::Number`, while
    // the validator matches on `Literal::ExactInteger` / `Literal::InexactReal`;
    // confirm that `Literal::Number` exists in `crate::ast`.

    #[test]
    fn test_validator_creation() {
        let validator = TypeSafetyValidator::new();
        let stats = validator.stats();
        assert_eq!(stats.total_validations, 0);
    }

    #[test]
    fn test_function_signature_registration() {
        let validator = TypeSafetyValidator::new();
        let signature = FunctionSignature {
            name: "test_function".to_string(),
            parameters: vec![CType::CInt, CType::CString],
            return_type: CType::CInt,
            variadic: false,
            safe: true,
            constraints: vec![TypeConstraint::NonNull(1)],
        };
        validator.register_function_signature(signature.clone()).unwrap();
        let retrieved = validator.get_function_signature("test_function").unwrap();
        assert_eq!(retrieved.name, "test_function");
        assert_eq!(retrieved.parameters.len(), 2);
    }

    #[test]
    fn test_null_function_pointer_rejected() {
        let validator = TypeSafetyValidator::new();
        let result = validator.validate_function_call("unregistered", &[], ptr::null());
        assert!(matches!(result, Err(ref err) if matches!(**err, SafetyError::InvalidFunctionPointer { .. })));
    }

    #[test]
    fn test_parameter_type_validation() {
        let validator = TypeSafetyValidator::new();
        let signature = FunctionSignature {
            name: "test_function".to_string(),
            parameters: vec![CType::CInt],
            return_type: CType::CInt,
            variadic: false,
            safe: true,
            constraints: vec![],
        };
        validator.register_function_signature(signature).unwrap();
        let args = vec![Value::Literal(Literal::Number(42.0))];
        let result = validator.validate_function_call("test_function", &args, dummy_fn_ptr());
        assert!(result.is_ok());
        let args = vec![Value::Literal(Literal::String("hello".to_string()))];
        let result = validator.validate_function_call("test_function", &args, dummy_fn_ptr());
        assert!(matches!(result, Err(ref err) if matches!(**err, SafetyError::RuntimeTypeCheck { .. })));
    }

    #[test]
    fn test_null_pointer_constraint() {
        let validator = TypeSafetyValidator::new();
        let signature = FunctionSignature {
            name: "test_function".to_string(),
            parameters: vec![CType::CString],
            return_type: CType::CInt,
            variadic: false,
            safe: true,
            constraints: vec![TypeConstraint::NonNull(0)],
        };
        validator.register_function_signature(signature).unwrap();
        let args = vec![Value::Literal(Literal::String("hello".to_string()))];
        let result = validator.validate_function_call("test_function", &args, dummy_fn_ptr());
        assert!(result.is_ok());
        let args = vec![Value::Nil];
        let result = validator.validate_function_call("test_function", &args, dummy_fn_ptr());
        assert!(matches!(result, Err(ref err) if matches!(**err, SafetyError::NullPointerDereference { .. })));
    }

    #[test]
    fn test_bounds_constraint() {
        let validator = TypeSafetyValidator::new();
        let signature = FunctionSignature {
            name: "test_function".to_string(),
            parameters: vec![CType::CInt],
            return_type: CType::CInt,
            variadic: false,
            safe: true,
            constraints: vec![TypeConstraint::Bounds {
                parameter: 0,
                min: 0,
                max: 100,
            }],
        };
        validator.register_function_signature(signature).unwrap();
        let args = vec![Value::Literal(Literal::Number(50.0))];
        let result = validator.validate_function_call("test_function", &args, dummy_fn_ptr());
        assert!(result.is_ok());
        let args = vec![Value::Literal(Literal::Number(150.0))];
        let result = validator.validate_function_call("test_function", &args, dummy_fn_ptr());
        assert!(matches!(result, Err(ref err) if matches!(**err, SafetyError::BoundaryViolation { .. })));
    }

    #[test]
    fn test_validation_rules() {
        let validator = TypeSafetyValidator::new();
        let rule = ValidationRule {
            name: "null_check".to_string(),
            trigger: ValidationTrigger::PreCall,
            validator: ValidationFunction::NullPointerCheck,
            enabled: true,
        };
        validator.add_validation_rule("test_function".to_string(), rule);
        let args = vec![Value::Nil];
        let result = validator.apply_validation_rules("test_function", &args, ValidationTrigger::PreCall);
        assert!(matches!(result, Err(ref err) if matches!(**err, SafetyError::NullPointerDereference { .. })));
    }
}