use crate::api_data_structures::{AssociatedType, MethodInfo, TraitInfo};
use crate::error::{Result, SklearsError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Aggregated result of exploring a single trait.
///
/// Bundles the trait's registry metadata with its implementors, dependency and
/// performance analyses, an optional relationship graph, usage examples and related traits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraitExplorationResult {
/// Name of the explored trait.
pub trait_name: String,
/// Metadata describing the trait (path, generics, associated types, methods, ...).
pub trait_info: TraitInfo,
/// Names of types that implement the trait.
pub implementations: Vec<String>,
/// Direct, transitive and circular dependency information for the trait.
pub dependencies: DependencyAnalysis,
/// Compile-time, runtime and memory-footprint estimates for the trait.
pub performance: PerformanceAnalysis,
/// Overall complexity score (0.0 by default).
/// NOTE(review): the scale/range of this score is not defined in this file — confirm.
pub complexity_score: f64,
/// Optional relationship graph around the trait; `None` by default.
pub graph: Option<TraitGraph>,
/// Usage examples for the trait.
pub examples: Vec<UsageExample>,
/// Names of traits related to this one.
pub related_traits: Vec<String>,
}
impl Default for TraitExplorationResult {
fn default() -> Self {
Self {
trait_name: String::new(),
trait_info: TraitInfo::default(),
implementations: Vec::new(),
dependencies: DependencyAnalysis::default(),
performance: PerformanceAnalysis::default(),
complexity_score: 0.0,
graph: None,
examples: Vec::new(),
related_traits: Vec::new(),
}
}
}
/// In-memory registry of trait metadata with bidirectional trait ↔ implementation indexes.
#[derive(Debug)]
pub struct TraitRegistry {
// Trait name -> trait metadata.
traits: HashMap<String, TraitInfo>,
// Trait name -> names of types implementing it (filled by `add_trait`).
implementations: HashMap<String, Vec<String>>,
// Implementation (type) name -> names of traits it implements (reverse of `implementations`).
implementation_traits: HashMap<String, Vec<String>>,
}
impl TraitRegistry {
    /// Creates an empty registry with no traits and no implementation mappings.
    pub fn new() -> Self {
        Self {
            traits: HashMap::new(),
            implementations: HashMap::new(),
            implementation_traits: HashMap::new(),
        }
    }

    /// Registers the built-in sklears core traits (`Estimator`, `Fit`, `Predict` and
    /// `Transform`) together with their known implementations.
    ///
    /// Every trait goes through [`TraitRegistry::add_trait`]. Calling this more than once
    /// therefore replaces the built-in entries instead of duplicating their mappings.
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` return type leaves room for loaders that can.
    pub fn load_sklears_traits(&mut self) -> Result<()> {
        let estimator_trait = TraitInfo {
            name: "Estimator".to_string(),
            description: "Base trait for all machine learning estimators".to_string(),
            path: "sklears_core::traits::Estimator".to_string(),
            generics: Vec::new(),
            associated_types: vec![AssociatedType {
                name: "Config".to_string(),
                description: "Configuration type for the estimator".to_string(),
                bounds: Vec::new(),
            }],
            methods: vec![MethodInfo {
                name: "name".to_string(),
                signature: "fn name(&self) -> &'static str".to_string(),
                description: "Get the name of the estimator".to_string(),
                parameters: Vec::new(),
                return_type: "&'static str".to_string(),
                required: true,
            }],
            supertraits: Vec::new(),
            implementations: vec![
                "LinearRegression".to_string(),
                "LogisticRegression".to_string(),
                "RandomForest".to_string(),
            ],
        };
        let fit_trait = TraitInfo {
            name: "Fit".to_string(),
            description: "Trait for estimators that can be fitted to training data".to_string(),
            path: "sklears_core::traits::Fit".to_string(),
            generics: vec!["X".to_string(), "Y".to_string()],
            associated_types: vec![AssociatedType {
                name: "Fitted".to_string(),
                description: "The type returned after fitting".to_string(),
                bounds: vec!["Send".to_string(), "Sync".to_string()],
            }],
            methods: vec![MethodInfo {
                name: "fit".to_string(),
                signature: "fn fit(self, x: &X, y: &Y) -> Result<Self::Fitted>".to_string(),
                description: "Fit the estimator to training data".to_string(),
                parameters: Vec::new(),
                return_type: "Result<Self::Fitted>".to_string(),
                required: true,
            }],
            supertraits: vec!["Estimator".to_string()],
            implementations: vec![
                "LinearRegression".to_string(),
                "LogisticRegression".to_string(),
            ],
        };
        let predict_trait = TraitInfo {
            name: "Predict".to_string(),
            description: "Trait for making predictions on new data".to_string(),
            path: "sklears_core::traits::Predict".to_string(),
            generics: vec!["X".to_string()],
            associated_types: vec![AssociatedType {
                name: "Output".to_string(),
                description: "The type of predictions made".to_string(),
                bounds: Vec::new(),
            }],
            methods: vec![MethodInfo {
                name: "predict".to_string(),
                signature: "fn predict(&self, x: &X) -> Result<Self::Output>".to_string(),
                description: "Make predictions on input data".to_string(),
                parameters: Vec::new(),
                return_type: "Result<Self::Output>".to_string(),
                required: true,
            }],
            supertraits: Vec::new(),
            implementations: vec![
                "LinearRegression".to_string(),
                "LogisticRegression".to_string(),
                "RandomForest".to_string(),
            ],
        };
        let transform_trait = TraitInfo {
            name: "Transform".to_string(),
            description: "Trait for data transformation operations".to_string(),
            path: "sklears_core::traits::Transform".to_string(),
            generics: vec!["X".to_string()],
            associated_types: vec![AssociatedType {
                name: "Output".to_string(),
                description: "The type of transformed data".to_string(),
                bounds: Vec::new(),
            }],
            methods: vec![MethodInfo {
                name: "transform".to_string(),
                signature: "fn transform(&self, x: &X) -> Result<Self::Output>".to_string(),
                description: "Transform input data".to_string(),
                parameters: Vec::new(),
                return_type: "Result<Self::Output>".to_string(),
                required: true,
            }],
            supertraits: Vec::new(),
            implementations: vec![
                "StandardScaler".to_string(),
                "PCA".to_string(),
                "MinMaxScaler".to_string(),
            ],
        };
        self.add_trait(estimator_trait);
        self.add_trait(fit_trait);
        self.add_trait(predict_trait);
        self.add_trait(transform_trait);
        Ok(())
    }

    /// Registers `trait_info`, replacing any trait previously registered under the same name.
    ///
    /// Updates both the trait → implementations index and the implementation → traits
    /// reverse index. Mappings from an earlier registration of the same trait are removed
    /// first. A name repeated inside `trait_info.implementations` is recorded only once.
    /// Re-registering a trait therefore never produces duplicate or stale entries.
    pub fn add_trait(&mut self, trait_info: TraitInfo) {
        let trait_name = trait_info.name.clone();

        // Undo the reverse-index entries of a previous registration of this trait. Without
        // this, re-adding a trait would append its implementations a second time and would
        // keep implementations that the new definition no longer lists.
        if let Some(previous_impls) = self.implementations.remove(&trait_name) {
            for impl_name in previous_impls {
                if let Some(traits) = self.implementation_traits.get_mut(&impl_name) {
                    traits.retain(|name| name != &trait_name);
                    // Drop empty entries so `has_implementation`/`implementation_count`
                    // only reflect types that still implement some registered trait.
                    if traits.is_empty() {
                        self.implementation_traits.remove(&impl_name);
                    }
                }
            }
        }

        for impl_name in &trait_info.implementations {
            let impls = self.implementations.entry(trait_name.clone()).or_default();
            // Skip names repeated within the same `implementations` list.
            if impls.contains(impl_name) {
                continue;
            }
            impls.push(impl_name.clone());
            self.implementation_traits
                .entry(impl_name.clone())
                .or_default()
                .push(trait_name.clone());
        }

        self.traits.insert(trait_name, trait_info);
    }

    /// Looks up a registered trait by name.
    pub fn get_trait(&self, name: &str) -> Option<&TraitInfo> {
        self.traits.get(name)
    }

    /// Returns every registered trait, sorted by trait name for deterministic output.
    pub fn get_all_traits(&self) -> Vec<&TraitInfo> {
        let mut all: Vec<&TraitInfo> = self.traits.values().collect();
        // Names are unique (they are the map keys), so an unstable sort is deterministic.
        all.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        all
    }

    /// Returns the names of every registered trait, sorted alphabetically.
    pub fn get_all_trait_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.traits.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Returns the implementation names registered for `trait_name`, in registration order.
    ///
    /// Returns an empty vector when the trait is unknown or has no implementations.
    pub fn get_implementations(&self, trait_name: &str) -> Vec<String> {
        self.implementations
            .get(trait_name)
            .cloned()
            .unwrap_or_default()
    }

    /// Returns the names of the traits `implementation` is registered as implementing.
    ///
    /// Returns an empty vector when the implementation is unknown.
    pub fn get_traits_for_implementation(&self, implementation: &str) -> Vec<String> {
        self.implementation_traits
            .get(implementation)
            .cloned()
            .unwrap_or_default()
    }

    /// Returns `true` if a trait named `trait_name` is registered.
    pub fn has_trait(&self, trait_name: &str) -> bool {
        self.traits.contains_key(trait_name)
    }

    /// Returns `true` if `implementation` implements at least one registered trait.
    pub fn has_implementation(&self, implementation: &str) -> bool {
        self.implementation_traits.contains_key(implementation)
    }

    /// Number of registered traits.
    pub fn trait_count(&self) -> usize {
        self.traits.len()
    }

    /// Number of distinct implementation names that implement at least one registered trait.
    pub fn implementation_count(&self) -> usize {
        self.implementation_traits.len()
    }
}
impl Default for TraitRegistry {
fn default() -> Self {
Self::new()
}
}
/// Dependency information for a trait; every field is empty/zero by default.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DependencyAnalysis {
/// Names of traits this trait depends on directly.
pub direct_dependencies: Vec<String>,
/// Names of traits reached only through other dependencies.
pub transitive_dependencies: Vec<String>,
/// Depth of the dependency chain.
/// NOTE(review): whether a trait with no dependencies has depth 0 or 1 is decided by the
/// producer, which is not in this file — confirm.
pub dependency_depth: usize,
/// Dependency cycles; each inner vector presumably lists the trait names forming one cycle.
pub circular_dependencies: Vec<Vec<String>>,
}
/// Performance estimates for using a trait, grouped by compile time, runtime and memory.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PerformanceAnalysis {
/// Estimated compile-time cost.
pub compilation_impact: CompilationImpact,
/// Estimated runtime cost.
pub runtime_overhead: RuntimeOverhead,
/// Estimated memory overhead.
pub memory_footprint: MemoryFootprint,
/// Free-form optimization suggestions.
pub optimization_hints: Vec<String>,
}
/// Estimated compile-time impact of a trait.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CompilationImpact {
/// Estimated compile time in milliseconds.
pub estimated_compile_time_ms: usize,
/// Estimated cost of generic instantiations.
/// NOTE(review): unit not defined in this file — confirm.
pub generic_instantiation_cost: usize,
/// Relative type-checking complexity score (scale not defined here).
pub type_checking_complexity: f64,
}
/// Estimated runtime overhead of a trait.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RuntimeOverhead {
/// Estimated cost of dynamic dispatch.
/// NOTE(review): unit (cycles? ns?) not defined in this file — confirm.
pub virtual_dispatch_cost: usize,
/// Estimated stack frame size (presumably bytes — confirm).
pub stack_frame_size: usize,
/// Qualitative cache-pressure description; free-form text, empty by default.
pub cache_pressure: String,
}
/// Estimated memory overhead of a trait.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MemoryFootprint {
/// Estimated vtable size in bytes.
pub vtable_size_bytes: usize,
/// Size of associated data (presumably bytes — confirm).
pub associated_data_size: usize,
/// Total estimated overhead.
/// NOTE(review): whether this is the sum of the fields above is decided by the producer,
/// which is not in this file.
pub total_overhead: usize,
}
/// A graph of traits, implementations and associated types, exportable to DOT or JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraitGraph {
/// Graph nodes.
pub nodes: Vec<TraitGraphNode>,
/// Directed edges; endpoints are expected to match `TraitGraphNode::id` values.
pub edges: Vec<TraitGraphEdge>,
/// Generation metadata (centre node, timestamp, intended export format).
pub metadata: TraitGraphMetadata,
}
impl TraitGraph {
    /// Renders the graph as a Graphviz DOT document.
    ///
    /// Nodes are drawn as filled, rounded boxes coloured by [`TraitNodeType`]:
    /// `lightblue` for traits, `lightgreen` for implementations and `lightyellow` for
    /// associated types.
    ///
    /// Edges are styled by [`EdgeType`]: `solid` for `Inherits`, `dashed` for `Implements`
    /// and `dotted` for `AssociatedWith`.
    ///
    /// Node ids, labels and edge endpoints are escaped, so a `"` or `\` inside a name cannot
    /// terminate the quoted DOT string early and corrupt the document.
    pub fn to_dot(&self) -> String {
        // Needed for `writeln!` into a `String`. Writing into a `String` cannot fail, so the
        // `fmt::Result` values below are deliberately discarded.
        use std::fmt::Write;

        let mut dot = String::from("digraph TraitGraph {\n");
        dot.push_str(" rankdir=TB;\n");
        dot.push_str(" node [shape=box, style=rounded];\n");

        for node in &self.nodes {
            let color = match node.node_type {
                TraitNodeType::Trait => "lightblue",
                TraitNodeType::Implementation => "lightgreen",
                TraitNodeType::AssociatedType => "lightyellow",
            };
            let _ = writeln!(
                dot,
                " \"{}\" [label=\"{}\" fillcolor={} style=\"filled,rounded\"];",
                Self::escape_dot(&node.id),
                Self::escape_dot(&node.label),
                color
            );
        }

        for edge in &self.edges {
            let style = match edge.edge_type {
                EdgeType::Inherits => "solid",
                EdgeType::Implements => "dashed",
                EdgeType::AssociatedWith => "dotted",
            };
            let _ = writeln!(
                dot,
                " \"{}\" -> \"{}\" [style={}];",
                Self::escape_dot(&edge.from),
                Self::escape_dot(&edge.to),
                style
            );
        }

        dot.push_str("}\n");
        dot
    }

    /// Serializes the whole graph (nodes, edges and metadata) as pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// Returns `SklearsError::Other` if `serde_json` fails to serialize the graph.
    pub fn to_json(&self) -> Result<String> {
        serde_json::to_string_pretty(self)
            .map_err(|e| SklearsError::Other(format!("Failed to serialize graph to JSON: {}", e)))
    }

    /// Escapes `raw` for use inside a double-quoted DOT string.
    ///
    /// Backslashes are doubled first, then double quotes are backslash-escaped.
    fn escape_dot(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '\\' => escaped.push_str("\\\\"),
                '"' => escaped.push_str("\\\""),
                other => escaped.push(other),
            }
        }
        escaped
    }
}
/// A node in a [`TraitGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraitGraphNode {
/// Unique node identifier; edges refer to nodes by this value.
pub id: String,
/// Human-readable label shown when rendering the node.
pub label: String,
/// What the node represents (trait, implementation or associated type).
pub node_type: TraitNodeType,
/// Free-form description of the node.
pub description: String,
/// Complexity score for the node (scale not defined in this file).
pub complexity: f64,
}
/// A directed edge in a [`TraitGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraitGraphEdge {
/// Id of the source node (expected to match a `TraitGraphNode::id`).
pub from: String,
/// Id of the target node (expected to match a `TraitGraphNode::id`).
pub to: String,
/// Relationship the edge expresses.
pub edge_type: EdgeType,
/// Edge weight.
/// NOTE(review): meaning/scale not defined in this file — confirm.
pub weight: f64,
}
/// Kind of entity a [`TraitGraphNode`] represents.
///
/// A fieldless enum: it derives `Copy`, `PartialEq`, `Eq` and `Hash` so node kinds can be
/// compared (`node.node_type == TraitNodeType::Trait`) and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TraitNodeType {
    /// A trait.
    Trait,
    /// A type implementing a trait.
    Implementation,
    /// An associated type of a trait.
    AssociatedType,
}
/// Relationship expressed by a [`TraitGraphEdge`].
///
/// A fieldless enum: it derives `Copy`, `PartialEq`, `Eq` and `Hash` so edge kinds can be
/// compared and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EdgeType {
    /// Supertrait / inheritance relationship.
    Inherits,
    /// A type implements a trait.
    Implements,
    /// An associated-type relationship.
    AssociatedWith,
}
/// Metadata describing how a [`TraitGraph`] was produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraitGraphMetadata {
/// Id of the node the graph was built around.
pub center_node: String,
/// When the graph was generated, as a UTC timestamp.
pub generation_time: chrono::DateTime<chrono::Utc>,
/// Export format the graph was requested in.
pub export_format: GraphExportFormat,
}
/// Output format for exporting a [`TraitGraph`].
///
/// A fieldless enum: it derives `Copy`, `PartialEq`, `Eq` and `Hash` so callers can compare
/// and dispatch on the requested format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GraphExportFormat {
    /// Graphviz DOT text.
    Dot,
    /// JSON.
    Json,
    /// SVG image.
    /// NOTE(review): no SVG exporter exists in this module — rendering is presumably done
    /// elsewhere.
    Svg,
}
/// A usage example attached to a trait exploration result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageExample {
/// Short title of the example.
pub title: String,
/// Explanation of what the example demonstrates.
pub description: String,
/// Example source code.
pub code: String,
/// Kind of example (implementation, usage, generic, advanced).
pub category: ExampleCategory,
/// Intended audience level.
pub difficulty: ExampleDifficulty,
/// Whether the example is marked as runnable.
/// NOTE(review): nothing in this file executes `code` — confirm how this flag is used.
pub runnable: bool,
}
/// Category of a [`UsageExample`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExampleCategory {
/// Shows how to implement the trait.
Implementation,
/// Shows how to use the trait.
Usage,
/// Shows generic usage of the trait.
Generic,
/// Advanced usage.
Advanced,
}
/// Difficulty level of a [`UsageExample`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExampleDifficulty {
/// Entry-level example.
Beginner,
/// Intermediate example.
Intermediate,
/// Advanced example.
Advanced,
}
#[cfg(test)]
mod tests {
    // Unit tests for the trait registry, the graph exporters and the plain data types above.
    use super::*;

    #[test]
    fn test_trait_registry_creation() {
        let registry = TraitRegistry::new();
        assert_eq!(registry.trait_count(), 0);
        assert_eq!(registry.implementation_count(), 0);
    }

    #[test]
    fn test_registry_default_is_empty() {
        let registry = TraitRegistry::default();
        assert_eq!(registry.trait_count(), 0);
        assert_eq!(registry.implementation_count(), 0);
    }

    #[test]
    fn test_load_sklears_traits() -> Result<()> {
        let mut registry = TraitRegistry::new();
        registry.load_sklears_traits()?;
        assert!(registry.has_trait("Estimator"));
        assert!(registry.has_trait("Fit"));
        assert!(registry.has_trait("Predict"));
        assert!(registry.has_trait("Transform"));
        assert_eq!(registry.trait_count(), 4);
        Ok(())
    }

    #[test]
    fn test_trait_retrieval() -> Result<()> {
        let mut registry = TraitRegistry::new();
        registry.load_sklears_traits()?;
        let estimator_trait = registry
            .get_trait("Estimator")
            .expect("get_trait should succeed");
        assert_eq!(estimator_trait.name, "Estimator");
        assert_eq!(estimator_trait.path, "sklears_core::traits::Estimator");
        assert!(registry.get_trait("NonExistent").is_none());
        Ok(())
    }

    #[test]
    fn test_implementation_mapping() -> Result<()> {
        let mut registry = TraitRegistry::new();
        registry.load_sklears_traits()?;
        let estimator_impls = registry.get_implementations("Estimator");
        assert!(estimator_impls.contains(&"LinearRegression".to_string()));
        assert!(estimator_impls.contains(&"LogisticRegression".to_string()));
        assert!(estimator_impls.contains(&"RandomForest".to_string()));
        let linear_reg_traits = registry.get_traits_for_implementation("LinearRegression");
        assert!(linear_reg_traits.contains(&"Estimator".to_string()));
        assert!(linear_reg_traits.contains(&"Fit".to_string()));
        assert!(linear_reg_traits.contains(&"Predict".to_string()));
        Ok(())
    }

    #[test]
    fn test_custom_trait_addition() {
        let mut registry = TraitRegistry::new();
        let custom_trait = TraitInfo {
            name: "CustomTrait".to_string(),
            description: "A custom trait for testing".to_string(),
            path: "test::CustomTrait".to_string(),
            generics: vec!["T".to_string()],
            associated_types: vec![],
            methods: vec![],
            supertraits: vec![],
            implementations: vec!["CustomImpl".to_string()],
        };
        registry.add_trait(custom_trait);
        assert!(registry.has_trait("CustomTrait"));
        assert!(registry.has_implementation("CustomImpl"));
        assert_eq!(
            registry.get_implementations("CustomTrait"),
            vec!["CustomImpl"]
        );
        assert_eq!(
            registry.get_traits_for_implementation("CustomImpl"),
            vec!["CustomTrait"]
        );
    }

    #[test]
    fn test_trait_exploration_result_default() {
        let result = TraitExplorationResult::default();
        assert!(result.trait_name.is_empty());
        assert_eq!(result.complexity_score, 0.0);
        assert!(result.implementations.is_empty());
        assert!(result.examples.is_empty());
        assert!(result.related_traits.is_empty());
    }

    #[test]
    fn test_dependency_analysis_default() {
        let analysis = DependencyAnalysis::default();
        assert!(analysis.direct_dependencies.is_empty());
        assert!(analysis.transitive_dependencies.is_empty());
        assert_eq!(analysis.dependency_depth, 0);
        assert!(analysis.circular_dependencies.is_empty());
    }

    #[test]
    fn test_performance_analysis_default() {
        let performance = PerformanceAnalysis::default();
        assert_eq!(performance.compilation_impact.estimated_compile_time_ms, 0);
        assert_eq!(performance.runtime_overhead.virtual_dispatch_cost, 0);
        assert_eq!(performance.memory_footprint.vtable_size_bytes, 0);
        assert!(performance.optimization_hints.is_empty());
    }

    #[test]
    fn test_trait_graph_dot_export() {
        use chrono::Utc;
        let graph = TraitGraph {
            nodes: vec![
                TraitGraphNode {
                    id: "Estimator".to_string(),
                    label: "Estimator".to_string(),
                    node_type: TraitNodeType::Trait,
                    description: "Base trait".to_string(),
                    complexity: 1.0,
                },
                TraitGraphNode {
                    id: "LinearRegression".to_string(),
                    label: "LinearRegression".to_string(),
                    node_type: TraitNodeType::Implementation,
                    description: "Linear regression implementation".to_string(),
                    complexity: 2.0,
                },
            ],
            edges: vec![TraitGraphEdge {
                from: "LinearRegression".to_string(),
                to: "Estimator".to_string(),
                edge_type: EdgeType::Implements,
                weight: 1.0,
            }],
            metadata: TraitGraphMetadata {
                center_node: "Estimator".to_string(),
                generation_time: Utc::now(),
                export_format: GraphExportFormat::Dot,
            },
        };
        let dot = graph.to_dot();
        assert!(dot.contains("digraph TraitGraph"));
        assert!(dot.contains("Estimator"));
        assert!(dot.contains("LinearRegression"));
        // Node colours encode the node type; the dashed style encodes the `Implements` edge.
        assert!(dot.contains("lightblue"));
        assert!(dot.contains("lightgreen"));
        assert!(dot.contains("dashed"));
    }

    #[test]
    fn test_trait_graph_json_export() -> Result<()> {
        use chrono::Utc;
        let graph = TraitGraph {
            nodes: vec![TraitGraphNode {
                id: "Estimator".to_string(),
                label: "Estimator".to_string(),
                node_type: TraitNodeType::Trait,
                description: "Base trait".to_string(),
                complexity: 1.0,
            }],
            edges: vec![],
            metadata: TraitGraphMetadata {
                center_node: "Estimator".to_string(),
                generation_time: Utc::now(),
                export_format: GraphExportFormat::Json,
            },
        };
        let json = graph.to_json()?;
        assert!(json.contains("\"nodes\""));
        assert!(json.contains("\"edges\""));
        assert!(json.contains("\"metadata\""));
        assert!(json.contains("Estimator"));
        Ok(())
    }

    #[test]
    fn test_usage_example_creation() {
        let example = UsageExample {
            title: "Basic Example".to_string(),
            description: "A simple usage example".to_string(),
            code: "println!(\"Hello, world!\");".to_string(),
            category: ExampleCategory::Usage,
            difficulty: ExampleDifficulty::Beginner,
            runnable: true,
        };
        assert_eq!(example.category, ExampleCategory::Usage);
        assert_eq!(example.difficulty, ExampleDifficulty::Beginner);
        assert!(example.runnable);
    }

    #[test]
    fn test_all_trait_names() -> Result<()> {
        let mut registry = TraitRegistry::new();
        registry.load_sklears_traits()?;
        let trait_names = registry.get_all_trait_names();
        assert_eq!(trait_names.len(), 4);
        assert!(trait_names.contains(&"Estimator".to_string()));
        assert!(trait_names.contains(&"Fit".to_string()));
        assert!(trait_names.contains(&"Predict".to_string()));
        assert!(trait_names.contains(&"Transform".to_string()));
        Ok(())
    }

    #[test]
    fn test_all_traits() -> Result<()> {
        let mut registry = TraitRegistry::new();
        registry.load_sklears_traits()?;
        let all_traits = registry.get_all_traits();
        assert_eq!(all_traits.len(), 4);
        let trait_names: Vec<&str> = all_traits.iter().map(|t| t.name.as_str()).collect();
        assert!(trait_names.contains(&"Estimator"));
        assert!(trait_names.contains(&"Fit"));
        assert!(trait_names.contains(&"Predict"));
        assert!(trait_names.contains(&"Transform"));
        Ok(())
    }

    #[test]
    fn test_nonexistent_queries() {
        let registry = TraitRegistry::new();
        assert!(!registry.has_trait("NonExistent"));
        assert!(!registry.has_implementation("NonExistent"));
        assert!(registry.get_implementations("NonExistent").is_empty());
        assert!(registry
            .get_traits_for_implementation("NonExistent")
            .is_empty());
    }

    #[test]
    fn test_scirs2_compliance() {
        use scirs2_core::random::Random;
        use scirs2_core::random::RngExt;
        let mut rng = Random::seed(42);
        let _random_value: f64 = rng.random();
        use scirs2_core::ndarray::{Array1, Array2};
        let _arr1: Array1<f64> = Array1::zeros(10);
        let _arr2: Array2<f64> = Array2::zeros((5, 5));
    }

    #[test]
    fn test_registry_counts() -> Result<()> {
        let mut registry = TraitRegistry::new();
        assert_eq!(registry.trait_count(), 0);
        assert_eq!(registry.implementation_count(), 0);
        registry.load_sklears_traits()?;
        assert_eq!(registry.trait_count(), 4);
        assert!(registry.implementation_count() > 0);
        Ok(())
    }
}