use crate::error::{Result, SklearsError};
use crate::types::FloatBounds;
use std::fmt::Debug;
#[allow(dead_code)]
pub(crate) trait InternalValidate {
fn internal_validate(&self) -> Result<()>;
fn validation_context(&self) -> ValidationDebugInfo;
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) struct ValidationDebugInfo {
pub type_name: &'static str,
pub rules_applied: Vec<String>,
pub validation_time_ns: u64,
pub memory_used_bytes: usize,
}
impl Default for ValidationDebugInfo {
fn default() -> Self {
Self {
type_name: "Unknown",
rules_applied: Vec::new(),
validation_time_ns: 0,
memory_used_bytes: 0,
}
}
}
pub(crate) mod memory {
use super::*;
#[allow(dead_code)]
pub(crate) struct InternalMemoryPool {
allocated_bytes: usize,
max_allocation: usize,
allocation_count: usize,
}
#[allow(dead_code)]
impl InternalMemoryPool {
pub(crate) fn new(max_allocation: usize) -> Self {
Self {
allocated_bytes: 0,
max_allocation,
allocation_count: 0,
}
}
pub(crate) fn try_allocate(&mut self, bytes: usize) -> Result<AllocationHandle> {
if self.allocated_bytes + bytes > self.max_allocation {
return Err(SklearsError::InvalidOperation(
"Memory pool exhausted".to_string(),
));
}
self.allocated_bytes += bytes;
self.allocation_count += 1;
Ok(AllocationHandle {
id: self.allocation_count,
size: bytes,
})
}
pub(crate) fn deallocate(&mut self, handle: AllocationHandle) {
self.allocated_bytes = self.allocated_bytes.saturating_sub(handle.size);
}
pub(crate) fn stats(&self) -> MemoryStats {
MemoryStats {
allocated_bytes: self.allocated_bytes,
max_allocation: self.max_allocation,
allocation_count: self.allocation_count,
}
}
}
#[derive(Debug, Clone, Copy)]
#[allow(dead_code)]
pub(crate) struct AllocationHandle {
pub(crate) id: usize,
pub(crate) size: usize,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) struct MemoryStats {
pub allocated_bytes: usize,
pub max_allocation: usize,
pub allocation_count: usize,
}
}
pub(crate) mod algorithm_internals {
use super::*;
#[allow(dead_code)]
pub(crate) trait AlgorithmState {
fn current_state(&self) -> InternalState;
fn transition_to(&mut self, new_state: InternalState) -> Result<()>;
fn can_transition_to(&self, new_state: InternalState) -> bool;
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
pub(crate) enum InternalState {
Uninitialized,
Configured,
Fitting,
Fitted,
Error,
}
#[allow(dead_code)]
pub(crate) struct ConvergenceChecker<T: FloatBounds> {
tolerance: T,
max_iterations: usize,
current_iteration: usize,
last_value: Option<T>,
history: Vec<T>,
}
impl<T: FloatBounds> ConvergenceChecker<T> {
#[allow(dead_code)]
pub(crate) fn new(tolerance: T, max_iterations: usize) -> Self {
Self {
tolerance,
max_iterations,
current_iteration: 0,
last_value: None,
history: Vec::new(),
}
}
#[allow(dead_code)]
pub(crate) fn check_convergence(&mut self, current_value: T) -> ConvergenceStatus {
self.current_iteration += 1;
self.history.push(current_value);
if self.current_iteration >= self.max_iterations {
return ConvergenceStatus::MaxIterationsReached;
}
if let Some(last) = self.last_value {
let diff = if current_value > last {
current_value - last
} else {
last - current_value
};
if diff < self.tolerance {
return ConvergenceStatus::Converged;
}
}
self.last_value = Some(current_value);
ConvergenceStatus::Continuing
}
#[allow(dead_code)]
pub(crate) fn stats(&self) -> ConvergenceStats<T> {
ConvergenceStats {
iterations: self.current_iteration,
tolerance: self.tolerance,
history: self.history.clone(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
pub(crate) enum ConvergenceStatus {
Continuing,
Converged,
MaxIterationsReached,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) struct ConvergenceStats<T> {
pub iterations: usize,
pub tolerance: T,
pub history: Vec<T>,
}
}
#[allow(non_snake_case)]
#[cfg(test)]
pub(crate) mod testing {
use super::*;
pub(crate) fn generate_test_matrix(rows: usize, cols: usize) -> Vec<Vec<f64>> {
(0..rows)
.map(|i| (0..cols).map(|j| (i * cols + j) as f64).collect())
.collect()
}
pub(crate) fn assert_matrices_close(a: &[Vec<f64>], b: &[Vec<f64>], tolerance: f64) {
assert_eq!(a.len(), b.len(), "Matrix row count mismatch");
for (i, (row_a, row_b)) in a.iter().zip(b.iter()).enumerate() {
assert_eq!(
row_a.len(),
row_b.len(),
"Matrix column count mismatch at row {}",
i
);
for (j, (&val_a, &val_b)) in row_a.iter().zip(row_b.iter()).enumerate() {
let diff = (val_a - val_b).abs();
assert!(
diff < tolerance,
"Values differ at position ({}, {}): {} vs {} (diff: {})",
i,
j,
val_a,
val_b,
diff
);
}
}
}
pub(crate) struct MockAlgorithm {
state: algorithm_internals::InternalState,
#[allow(dead_code)]
fitted_data: Option<Vec<f64>>,
}
impl MockAlgorithm {
pub(crate) fn new() -> Self {
Self {
state: algorithm_internals::InternalState::Uninitialized,
fitted_data: None,
}
}
#[allow(dead_code)] pub(crate) fn fit(&mut self, data: Vec<f64>) -> Result<()> {
self.state = algorithm_internals::InternalState::Fitting;
self.fitted_data = Some(data);
self.state = algorithm_internals::InternalState::Fitted;
Ok(())
}
#[allow(dead_code)] pub(crate) fn predict(&self, input: &[f64]) -> Result<Vec<f64>> {
if self.state != algorithm_internals::InternalState::Fitted {
return Err(SklearsError::InvalidOperation(
"Algorithm must be fitted before prediction".to_string(),
));
}
Ok(input.iter().map(|&x| x * 2.0).collect())
}
}
impl algorithm_internals::AlgorithmState for MockAlgorithm {
fn current_state(&self) -> algorithm_internals::InternalState {
self.state
}
fn transition_to(&mut self, new_state: algorithm_internals::InternalState) -> Result<()> {
if !self.can_transition_to(new_state) {
return Err(SklearsError::InvalidOperation(format!(
"Invalid state transition from {:?} to {:?}",
self.state, new_state
)));
}
self.state = new_state;
Ok(())
}
fn can_transition_to(&self, new_state: algorithm_internals::InternalState) -> bool {
use algorithm_internals::InternalState::*;
matches!(
(self.state, new_state),
(Uninitialized, Configured)
| (Configured, Fitting)
| (Fitting, Fitted)
| (Fitting, Error)
| (_, Error)
)
}
}
}
pub(crate) mod debug {
use super::*;
#[allow(dead_code)]
pub(crate) struct DebugCollector {
enabled: bool,
entries: Vec<DebugEntry>,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) struct DebugEntry {
pub timestamp: std::time::Instant,
pub category: DebugCategory,
pub message: String,
pub data: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
pub(crate) enum DebugCategory {
Memory,
Performance,
Algorithm,
Validation,
Other,
}
impl DebugCollector {
#[allow(dead_code)]
pub(crate) fn new(enabled: bool) -> Self {
Self {
enabled,
entries: Vec::new(),
}
}
#[allow(dead_code)]
pub(crate) fn add_entry(
&mut self,
category: DebugCategory,
message: String,
data: Option<String>,
) {
if self.enabled {
self.entries.push(DebugEntry {
timestamp: std::time::Instant::now(),
category,
message,
data,
});
}
}
#[allow(dead_code)]
pub(crate) fn entries(&self) -> &[DebugEntry] {
&self.entries
}
#[allow(dead_code)]
pub(crate) fn clear(&mut self) {
self.entries.clear();
}
}
}
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_memory_pool() {
let mut pool = memory::InternalMemoryPool::new(1024);
let handle1 = pool.try_allocate(512).expect("try_allocate should succeed");
assert_eq!(handle1.size, 512);
let handle2 = pool.try_allocate(256).expect("try_allocate should succeed");
assert_eq!(handle2.size, 256);
assert!(pool.try_allocate(512).is_err());
pool.deallocate(handle1);
let handle3 = pool.try_allocate(512).expect("try_allocate should succeed");
assert_eq!(handle3.size, 512);
let stats = pool.stats();
assert_eq!(stats.allocated_bytes, 768); assert_eq!(stats.allocation_count, 3);
}
#[test]
fn test_convergence_checker() {
let mut checker = algorithm_internals::ConvergenceChecker::new(0.01, 10);
assert_eq!(
checker.check_convergence(1.0),
algorithm_internals::ConvergenceStatus::Continuing
);
assert_eq!(
checker.check_convergence(1.005),
algorithm_internals::ConvergenceStatus::Converged
);
let mut checker2 = algorithm_internals::ConvergenceChecker::new(0.01, 2);
assert_eq!(
checker2.check_convergence(1.0),
algorithm_internals::ConvergenceStatus::Continuing
);
assert_eq!(
checker2.check_convergence(2.0),
algorithm_internals::ConvergenceStatus::MaxIterationsReached
);
}
#[test]
fn test_mock_algorithm() {
use algorithm_internals::AlgorithmState;
let mut algo = testing::MockAlgorithm::new();
assert_eq!(
algo.current_state(),
algorithm_internals::InternalState::Uninitialized
);
assert!(algo
.transition_to(algorithm_internals::InternalState::Configured)
.is_ok());
assert!(algo
.transition_to(algorithm_internals::InternalState::Fitting)
.is_ok());
assert!(algo
.transition_to(algorithm_internals::InternalState::Fitted)
.is_ok());
assert!(algo
.transition_to(algorithm_internals::InternalState::Uninitialized)
.is_err());
}
#[test]
fn test_debug_collector() {
let mut collector = debug::DebugCollector::new(true);
collector.add_entry(
debug::DebugCategory::Performance,
"Test message".to_string(),
Some("Test data".to_string()),
);
let entries = collector.entries();
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].category, debug::DebugCategory::Performance);
assert_eq!(entries[0].message, "Test message");
collector.clear();
assert_eq!(collector.entries().len(), 0);
}
#[test]
fn test_testing_utilities() {
let matrix = testing::generate_test_matrix(2, 3);
assert_eq!(matrix.len(), 2);
assert_eq!(matrix[0].len(), 3);
assert_eq!(matrix[0], vec![0.0, 1.0, 2.0]);
assert_eq!(matrix[1], vec![3.0, 4.0, 5.0]);
let matrix2 = vec![vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0]];
testing::assert_matrices_close(&matrix, &matrix2, 1e-10);
}
}