#![allow(missing_docs)]
use super::{Type, TypeVar, TypeScheme, TypeChecker, TypeLevel};
use super::algebraic::{Pattern, PatternMatcher};
use super::advanced_type_classes::AdvancedTypeClassEnv;
use super::r7rs_integration::R7RSIntegration;
use crate::eval::value::{Value, PrimitiveProcedure, PrimitiveImpl, ThreadSafeEnvironment};
use crate::diagnostics::{Error, Result, Span};
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use std::fmt;
/// Ties the type checker, type-class environment, R7RS integration and pattern
/// matcher together with the evaluator's primitive procedures.
pub struct TypeSystemBridge {
/// Typing level currently in effect; replaced by `set_type_level`.
type_level: TypeLevel,
/// Checker whose environment receives the schemes bound by
/// `integrate_primitive` and `migrate_function`.
type_checker: TypeChecker,
/// Type-class environment; only initialised (never read) by the code visible in this file.
advanced_classes: AdvancedTypeClassEnv,
/// Used by `value_matches_type` as the fallback validator for types it does not handle itself.
r7rs_integration: R7RSIntegration,
/// Compiles the single-clause match built by `type_safe_pattern_match`.
pattern_matcher: PatternMatcher,
/// Integrated primitives keyed by name, shared behind a read/write lock.
primitive_cache: Arc<RwLock<HashMap<String, OptimizedPrimitive>>>,
/// Bookkeeping for functions migrated to static typing, including collected warnings.
migration_state: MigrationState,
}
/// Tracks which functions have been moved to static typing and the warnings
/// collected while doing so.
#[derive(Debug, Clone)]
pub struct MigrationState {
/// Names of functions migrated to static typing (filled by `TypeSystemBridge::migrate_function`).
static_functions: HashSet<String>,
/// Type schemes keyed by function name; presumably for functions carrying explicit
/// annotations — nothing in the visible non-test code inserts here, confirm with callers.
annotated_functions: HashMap<String, TypeScheme>,
/// Type schemes keyed by function name; presumably produced by inference.
/// Consulted by `check_migration_issues` when deciding whether to warn.
inferred_functions: HashMap<String, TypeScheme>,
/// Warnings accumulated during migration, in the order they were raised.
migration_warnings: Vec<MigrationWarning>,
}
/// A diagnostic produced while migrating a function to static typing.
///
/// The `Display` impl renders it as `[SEVERITY] message`, followed by
/// ` (Suggestion: ...)` when a suggestion is present; the span is not printed.
#[derive(Debug, Clone)]
pub struct MigrationWarning {
/// Human-readable description of the issue.
pub message: String,
/// Source location the warning refers to, when known.
pub span: Option<Span>,
/// Optional hint on how to resolve the issue.
pub suggestion: Option<String>,
/// How serious the issue is.
pub severity: WarningSeverity,
}
/// Severity attached to a [`MigrationWarning`].
///
/// Displayed as `INFO`, `WARNING` or `ERROR` respectively.
#[derive(Debug, Clone, PartialEq)]
pub enum WarningSeverity {
/// Rendered as `INFO`.
Info,
/// Rendered as `WARNING`.
Warning,
/// Rendered as `ERROR`.
Error,
}
/// A primitive procedure bundled with its type-specialized implementations
/// and its usage statistics.
#[derive(Debug, Clone)]
pub struct OptimizedPrimitive {
/// The primitive as passed to `TypeSystemBridge::integrate_primitive`.
pub base: PrimitiveProcedure,
/// Specialized implementations keyed by the signature they serve.
pub specializations: HashMap<TypeSignature, PrimitiveImpl>,
/// Call statistics for this primitive.
pub stats: PrimitiveStats,
}
/// Parameter and return types of one concrete use of a primitive.
///
/// Used as a hash-map key for specializations and per-signature call counts.
/// Note that `TypeSystemBridge::specialize_primitive` always builds its lookup
/// key with `Type::Dynamic` as the return type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TypeSignature {
/// Types of the arguments, in call order.
pub params: Vec<Type>,
/// Type of the result.
pub return_type: Type,
}
/// Usage statistics for a single primitive, updated through `record_call`.
#[derive(Debug, Clone)]
pub struct PrimitiveStats {
/// Total number of recorded calls.
pub call_count: u64,
/// Number of recorded calls per type signature.
pub type_calls: HashMap<TypeSignature, u64>,
/// Running integer mean of the recorded execution times
/// (printed with an `ns` suffix by `PerformanceReport`'s `Display` impl).
pub avg_execution_time: u64,
/// Running integer mean of the recorded memory usage
/// (printed with a `bytes` suffix by `PerformanceReport`'s `Display` impl).
pub avg_memory_usage: u64,
}
/// Options for constructing a `TypeSystemBridge`.
///
/// NOTE(review): `TypeSystemBridge::new` takes this as `_config` and does not
/// read it, so none of these fields influences behaviour in the code visible
/// here. The per-field descriptions below are inferred from the names only.
#[derive(Debug, Clone)]
pub struct IntegrationConfig {
/// Presumably enables type inference for primitives. Default: `true`.
pub infer_primitive_types: bool,
/// Presumably enables generation of type-specialized primitives. Default: `true`.
pub specialize_primitives: bool,
/// Presumably enables gradual typing. Default: `true`.
pub gradual_typing: bool,
/// Presumably enables pattern compilation. Default: `true`.
pub compile_patterns: bool,
/// Presumably a recursion limit; what is being limited is not visible here. Default: `100`.
pub max_recursion_depth: usize,
/// Requested optimization level. Default: `OptimizationLevel::Basic`.
pub optimization_level: OptimizationLevel,
}
/// Optimization level selectable through `IntegrationConfig::optimization_level`.
///
/// NOTE(review): no code visible in this file branches on the level; the exact
/// effect of each variant must be confirmed elsewhere.
#[derive(Debug, Clone, PartialEq)]
pub enum OptimizationLevel {
/// Presumably disables optimization.
None,
/// The level used by `IntegrationConfig::default()`.
Basic,
/// Presumably the most aggressive setting.
Aggressive,
}
impl TypeSystemBridge {
pub fn new(_config: IntegrationConfig) -> Self {
Self {
type_level: TypeLevel::Dynamic,
type_checker: TypeChecker::new(TypeLevel::Dynamic),
advanced_classes: AdvancedTypeClassEnv::default(),
r7rs_integration: R7RSIntegration::new(),
pattern_matcher: PatternMatcher::new(),
primitive_cache: Arc::new(RwLock::new(HashMap::new())),
migration_state: MigrationState::new(),
}
}
pub fn set_type_level(&mut self, level: TypeLevel) {
self.type_level = level;
self.type_checker = TypeChecker::new(level);
}
pub fn type_level(&self) -> TypeLevel {
self.type_level
}
pub fn integrate_primitive(&mut self, name: String, primitive: PrimitiveProcedure) -> Result<()> {
let type_scheme = self.infer_primitive_type(&name, &primitive)?;
self.type_checker.env_mut().bind(name.clone(), type_scheme.clone());
let optimized = OptimizedPrimitive {
base: primitive,
specializations: HashMap::new(),
stats: PrimitiveStats::new(),
};
self.primitive_cache.write().unwrap().insert(name, optimized);
Ok(())
}
fn infer_primitive_type(&self, name: &str, primitive: &PrimitiveProcedure) -> Result<TypeScheme> {
match name {
"+" | "-" | "*" | "/" => Ok(TypeScheme::polymorphic(
vec![TypeVar::with_name("a")],
vec![super::Constraint { class: "Num".to_string(), type_: Type::named_var("a") }],
Type::function(
vec![Type::named_var("a"), Type::named_var("a")],
Type::named_var("a"),
),
)),
"=" | "<" | ">" | "<=" | ">=" => Ok(TypeScheme::polymorphic(
vec![TypeVar::with_name("a")],
vec![super::Constraint { class: "Ord".to_string(), type_: Type::named_var("a") }],
Type::function(
vec![Type::named_var("a"), Type::named_var("a")],
Type::Boolean,
),
)),
"cons" => Ok(TypeScheme::polymorphic(
vec![TypeVar::with_name("a")],
vec![],
Type::function(
vec![Type::named_var("a"), Type::list(Type::named_var("a"))],
Type::list(Type::named_var("a")),
),
)),
"car" => Ok(TypeScheme::polymorphic(
vec![TypeVar::with_name("a")],
vec![],
Type::function(
vec![Type::pair(Type::named_var("a"), Type::Dynamic)],
Type::named_var("a"),
),
)),
"cdr" => Ok(TypeScheme::polymorphic(
vec![TypeVar::with_name("a"), TypeVar::with_name("b")],
vec![],
Type::function(
vec![Type::pair(Type::named_var("a"), Type::named_var("b"))],
Type::named_var("b"),
),
)),
"display" | "write" => Ok(TypeScheme::polymorphic(
vec![TypeVar::with_name("a")],
vec![super::Constraint { class: "Show".to_string(), type_: Type::named_var("a") }],
Type::Effectful {
input: Box::new(Type::named_var("a")),
effects: vec![super::Effect::IO],
output: Box::new(Type::Unit),
},
)),
"string-append" => Ok(TypeScheme::monomorphic(
Type::function(
vec![Type::String, Type::String],
Type::String,
),
)),
"string-length" => Ok(TypeScheme::monomorphic(
Type::function(vec![Type::String], Type::Number),
)),
"vector-ref" => Ok(TypeScheme::polymorphic(
vec![TypeVar::with_name("a")],
vec![],
Type::function(
vec![Type::vector(Type::named_var("a")), Type::Number],
Type::named_var("a"),
),
)),
"vector-set!" => Ok(TypeScheme::polymorphic(
vec![TypeVar::with_name("a")],
vec![],
Type::Effectful {
input: Box::new(Type::vector(Type::named_var("a"))),
effects: vec![super::Effect::State(Type::Unit)],
output: Box::new(Type::Unit),
},
)),
_ => {
let param_types = (0..primitive.arity_min)
.map(|i| Type::named_var(format!("a{i}")))
.collect();
let return_type = Type::named_var("result");
Ok(TypeScheme::polymorphic(
(0..=primitive.arity_min).map(|i| TypeVar::with_name(format!("a{i}"))).collect(),
vec![],
Type::function(param_types, return_type),
))
}
}
}
pub fn specialize_primitive(
&mut self,
name: &str,
type_args: &[Type]
) -> Result<Option<PrimitiveImpl>> {
let cache = self.primitive_cache.read().unwrap();
if let Some(optimized) = cache.get(name) {
let sig = TypeSignature {
params: type_args.to_vec(),
return_type: Type::Dynamic, };
if let Some(specialized) = optimized.specializations.get(&sig) {
return Ok(Some(specialized.clone()));
}
drop(cache); return self.generate_specialization(name, &sig);
}
Ok(None)
}
fn generate_specialization(&mut self, name: &str, sig: &TypeSignature) -> Result<Option<PrimitiveImpl>> {
match name {
"+" if sig.params.len() == 2 && sig.params.iter().all(|t| *t == Type::Number) => {
Ok(Some(PrimitiveImpl::RustFn(|args| {
if let (Some(n1), Some(n2)) = (args[0].as_number(), args[1].as_number()) {
Ok(Value::number(n1 + n2))
} else {
Err(Box::new(Error::runtime_error("Type error in specialized +".to_string(), None)))
}
})))
}
"*" if sig.params.len() == 2 && sig.params.iter().all(|t| *t == Type::Number) => {
Ok(Some(PrimitiveImpl::RustFn(|args| {
if let (Some(n1), Some(n2)) = (args[0].as_number(), args[1].as_number()) {
Ok(Value::number(n1 * n2))
} else {
Err(Box::new(Error::runtime_error("Type error in specialized *".to_string(), None)))
}
})))
}
"string-append" if sig.params.iter().all(|t| *t == Type::String) => {
Ok(Some(PrimitiveImpl::RustFn(|args| {
let mut result = String::new();
for arg in args {
if let Some(s) = arg.as_string() {
result.push_str(s);
} else {
return Err(Box::new(Error::runtime_error("Type error in specialized string-append".to_string(), None)));
}
}
Ok(Value::string(result))
})))
}
_ => Ok(None), }
}
pub fn type_safe_pattern_match(&mut self, pattern: &Pattern, value: &Value, expected_type: &Type) -> Result<bool> {
if !self.value_matches_type(value, expected_type)? {
return Ok(false);
}
self.pattern_matcher.compile_match(&super::algebraic::MatchExpression {
scrutinee: "value".to_string(), clauses: vec![super::algebraic::MatchClause {
pattern: pattern.clone(),
guard: None,
body: "true".to_string(),
span: None,
}],
span: None,
})?;
Ok(true)
}
pub fn value_matches_type(&self, value: &Value, ty: &Type) -> Result<bool> {
match ty {
Type::Dynamic => Ok(true),
Type::Number => Ok(value.is_number()),
Type::String => Ok(value.is_string()),
Type::Boolean => Ok(matches!(value, Value::Literal(crate::ast::Literal::Boolean(_)))),
Type::Symbol => Ok(value.is_symbol()),
Type::List(_) => Ok(value.is_list()),
Type::Vector(_) => Ok(value.is_vector()),
Type::Pair(_, _) => Ok(value.is_pair()),
_ => {
self.r7rs_integration.validate_gradual_typing(ty, value)
}
}
}
pub fn migrate_function(&mut self, name: String, type_scheme: TypeScheme) -> Result<()> {
self.migration_state.static_functions.insert(name.clone());
self.type_checker.env_mut().bind(name.clone(), type_scheme);
self.check_migration_issues(&name)?;
Ok(())
}
fn check_migration_issues(&mut self, name: &str) -> Result<()> {
if name.starts_with("string-") && self.migration_state.inferred_functions.contains_key(name) {
self.migration_state.migration_warnings.push(MigrationWarning {
message: format!("Function {name} migrated to static typing - verify all call sites use strings"),
span: None,
suggestion: Some("Add type annotations to caller functions".to_string()),
severity: WarningSeverity::Warning,
});
}
Ok(())
}
pub fn migration_warnings(&self) -> &[MigrationWarning] {
&self.migration_state.migration_warnings
}
pub fn optimize_for_types(&mut self, env: &Arc<ThreadSafeEnvironment>) -> Result<()> {
for name in env.all_variable_names() {
if let Some(optimized) = self.primitive_cache.read().unwrap().get(&name) {
self.analyze_primitive_usage(&name, optimized)?;
}
}
Ok(())
}
fn analyze_primitive_usage(&self, name: &str, optimized: &OptimizedPrimitive) -> Result<()> {
let mut most_common: Option<(TypeSignature, u64)> = None;
for (sig, count) in &optimized.stats.type_calls {
if let Some((_, current_max)) = &most_common {
if count > current_max {
most_common = Some((sig.clone(), *count));
}
} else {
most_common = Some((sig.clone(), *count));
}
}
if let Some((sig, count)) = most_common {
if count > 100 && !optimized.specializations.contains_key(&sig) {
println!("Suggestion: Specialize {name} for signature {sig:?} (used {count} times)");
}
}
Ok(())
}
pub fn performance_report(&self) -> PerformanceReport {
let cache = self.primitive_cache.read().unwrap();
let mut report = PerformanceReport::new();
for (name, optimized) in cache.iter() {
let primitive_report = PrimitivePerformanceReport {
name: name.clone(),
total_calls: optimized.stats.call_count,
specializations: optimized.specializations.len(),
avg_execution_time: optimized.stats.avg_execution_time,
avg_memory_usage: optimized.stats.avg_memory_usage,
};
report.add_primitive(primitive_report);
}
report
}
}
impl MigrationState {
    /// Creates an empty migration state: no tracked functions and no warnings.
    pub fn new() -> Self {
        Self {
            static_functions: HashSet::new(),
            annotated_functions: HashMap::new(),
            inferred_functions: HashMap::new(),
            migration_warnings: Vec::new(),
        }
    }

    /// Returns `true` when `name` has been migrated to static typing.
    pub fn is_static(&self, name: &str) -> bool {
        self.static_functions.contains(name)
    }

    /// Percentage (0.0 to 100.0) of tracked functions that are statically typed.
    ///
    /// A function is counted once even when its name appears in several of the
    /// three collections; a migrated function typically remains in
    /// `inferred_functions`, and counting it twice would keep the result below
    /// 100% forever. Returns `0.0` when nothing is tracked.
    pub fn progress(&self) -> f64 {
        let static_count = self.static_functions.len();
        let annotated_only = self
            .annotated_functions
            .keys()
            .filter(|name| !self.static_functions.contains(name.as_str()))
            .count();
        let inferred_only = self
            .inferred_functions
            .keys()
            .filter(|name| {
                !self.static_functions.contains(name.as_str())
                    && !self.annotated_functions.contains_key(name.as_str())
            })
            .count();

        let tracked = static_count + annotated_only + inferred_only;
        if tracked == 0 {
            return 0.0;
        }
        (static_count as f64 / tracked as f64) * 100.0
    }
}
impl PrimitiveStats {
    /// Creates statistics with no recorded calls.
    pub fn new() -> Self {
        Self {
            call_count: 0,
            type_calls: HashMap::new(),
            avg_execution_time: 0,
            avg_memory_usage: 0,
        }
    }

    /// Records one call with signature `sig`.
    ///
    /// Increments the total and per-signature counters and folds
    /// `execution_time` and `memory_usage` into their running integer means
    /// (each update truncates towards zero).
    pub fn record_call(&mut self, sig: TypeSignature, execution_time: u64, memory_usage: u64) {
        let prior_calls = self.call_count;
        self.call_count += 1;
        *self.type_calls.entry(sig).or_insert(0) += 1;
        self.avg_execution_time =
            Self::running_mean(self.avg_execution_time, prior_calls, execution_time);
        self.avg_memory_usage =
            Self::running_mean(self.avg_memory_usage, prior_calls, memory_usage);
    }

    /// Mean after adding `sample` to `prior_samples` samples whose mean was `previous_mean`.
    fn running_mean(previous_mean: u64, prior_samples: u64, sample: u64) -> u64 {
        // Widen to u128: `previous_mean * prior_samples` can exceed u64::MAX.
        let weighted_sum =
            u128::from(previous_mean) * u128::from(prior_samples) + u128::from(sample);
        let mean = weighted_sum / (u128::from(prior_samples) + 1);
        // The mean never exceeds max(previous_mean, sample), so it fits in u64.
        mean as u64
    }
}
/// Aggregated performance figures over a set of primitives.
///
/// The totals are maintained by `add_primitive`.
#[derive(Debug, Clone)]
pub struct PerformanceReport {
/// Per-primitive entries in the order they were added.
pub primitives: Vec<PrimitivePerformanceReport>,
/// Sum of `total_calls` over all added entries.
pub total_calls: u64,
/// Sum of `specializations` over all added entries.
pub total_specializations: usize,
}
/// Performance figures for a single primitive.
#[derive(Debug, Clone)]
pub struct PrimitivePerformanceReport {
/// Name of the primitive.
pub name: String,
/// Number of recorded calls.
pub total_calls: u64,
/// Number of specialized implementations available.
pub specializations: usize,
/// Average execution time (printed with an `ns` suffix by the report's `Display` impl).
pub avg_execution_time: u64,
/// Average memory usage (printed with a `bytes` suffix by the report's `Display` impl).
pub avg_memory_usage: u64,
}
impl PerformanceReport {
pub fn new() -> Self {
Self {
primitives: Vec::new(),
total_calls: 0,
total_specializations: 0,
}
}
pub fn add_primitive(&mut self, report: PrimitivePerformanceReport) {
self.total_calls += report.total_calls;
self.total_specializations += report.specializations;
self.primitives.push(report);
}
}
impl Default for IntegrationConfig {
fn default() -> Self {
Self {
infer_primitive_types: true,
specialize_primitives: true,
gradual_typing: true,
compile_patterns: true,
max_recursion_depth: 100,
optimization_level: OptimizationLevel::Basic,
}
}
}
impl Default for MigrationState {
fn default() -> Self {
Self::new()
}
}
impl Default for PrimitiveStats {
fn default() -> Self {
Self::new()
}
}
impl Default for PerformanceReport {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for WarningSeverity {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WarningSeverity::Info => write!(f, "INFO"),
WarningSeverity::Warning => write!(f, "WARNING"),
WarningSeverity::Error => write!(f, "ERROR"),
}
}
}
impl fmt::Display for MigrationWarning {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "[{}] {}", self.severity, self.message)?;
if let Some(suggestion) = &self.suggestion {
write!(f, " (Suggestion: {suggestion})")?;
}
Ok(())
}
}
impl fmt::Display for PerformanceReport {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "Type System Performance Report")?;
writeln!(f, "==============================")?;
writeln!(f, "Total calls: {}", self.total_calls)?;
writeln!(f, "Total specializations: {}", self.total_specializations)?;
writeln!(f)?;
for primitive in &self.primitives {
writeln!(f, "Primitive: {}", primitive.name)?;
writeln!(f, " Calls: {}", primitive.total_calls)?;
writeln!(f, " Specializations: {}", primitive.specializations)?;
writeln!(f, " Avg execution time: {}ns", primitive.avg_execution_time)?;
writeln!(f, " Avg memory usage: {} bytes", primitive.avg_memory_usage)?;
writeln!(f)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bridge from the default integration configuration.
    fn default_bridge() -> TypeSystemBridge {
        TypeSystemBridge::new(IntegrationConfig::default())
    }

    /// Signature of a binary numeric operation returning a number.
    fn binary_number_signature() -> TypeSignature {
        TypeSignature {
            params: vec![Type::Number, Type::Number],
            return_type: Type::Number,
        }
    }

    /// A freshly created bridge starts in dynamic mode.
    #[test]
    fn test_bridge_creation() {
        assert_eq!(default_bridge().type_level(), TypeLevel::Dynamic);
    }

    /// Integrating a well-known primitive succeeds.
    #[test]
    fn test_primitive_integration() {
        let mut bridge = default_bridge();
        let plus = PrimitiveProcedure {
            name: "+".to_string(),
            arity_min: 2,
            arity_max: None,
            implementation: PrimitiveImpl::RustFn(|_| Ok(Value::Unspecified)),
            effects: vec![crate::effects::Effect::Pure],
        };
        assert!(bridge.integrate_primitive("+".to_string(), plus).is_ok());
    }

    /// Shallow value/type compatibility for numbers, strings and `Dynamic`.
    #[test]
    fn test_value_type_matching() {
        let bridge = default_bridge();
        let forty_two = Value::integer(42);
        let greeting = Value::string("hello");

        assert!(bridge.value_matches_type(&forty_two, &Type::Number).unwrap());
        assert!(!bridge.value_matches_type(&forty_two, &Type::String).unwrap());
        assert!(bridge.value_matches_type(&greeting, &Type::String).unwrap());
        assert!(bridge.value_matches_type(&forty_two, &Type::Dynamic).unwrap());
    }

    /// Progress is 0 for an empty state and 50% with one static and one annotated function.
    #[test]
    fn test_migration_state() {
        let mut migration = MigrationState::new();
        assert_eq!(migration.progress(), 0.0);

        migration.static_functions.insert("test-func".to_string());
        migration
            .annotated_functions
            .insert("other-func".to_string(), TypeScheme::monomorphic(Type::Number));

        assert!(migration.is_static("test-func"));
        assert!(!migration.is_static("other-func"));
        assert_eq!(migration.progress(), 50.0);
    }

    /// Two recorded calls update the counters and the running averages.
    #[test]
    fn test_primitive_stats() {
        let signature = binary_number_signature();
        let mut stats = PrimitiveStats::new();
        stats.record_call(signature.clone(), 100, 64);
        stats.record_call(signature.clone(), 200, 128);

        assert_eq!(stats.call_count, 2);
        assert_eq!(stats.avg_execution_time, 150);
        assert_eq!(stats.avg_memory_usage, 96);
        assert_eq!(*stats.type_calls.get(&signature).unwrap(), 2);
    }

    /// Adding an entry updates the report totals.
    #[test]
    fn test_performance_report() {
        let mut report = PerformanceReport::new();
        report.add_primitive(PrimitivePerformanceReport {
            name: "+".to_string(),
            total_calls: 100,
            specializations: 2,
            avg_execution_time: 50,
            avg_memory_usage: 32,
        });

        assert_eq!(report.total_calls, 100);
        assert_eq!(report.total_specializations, 2);
        assert_eq!(report.primitives.len(), 1);
    }
}