use crate::{EmbeddingModel, Vector};
use anyhow::Result;
use scirs2_core::ndarray_ext::{Array1, Array2, Array3};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
use uuid::Uuid;
use scirs2_core::random::{thread_rng, RngExt};
/// Top-level configuration bundling per-algorithm settings for every
/// supported meta-learning strategy (MAML, Reptile, prototypical/matching/
/// relation networks, MANN), plus task sampling and global training knobs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningConfig {
    /// Model-Agnostic Meta-Learning settings.
    pub maml_config: MAMLConfig,
    /// Reptile (first-order meta-learning) settings.
    pub reptile_config: ReptileConfig,
    /// Prototypical-network few-shot settings.
    pub prototypical_config: PrototypicalConfig,
    /// Matching-network settings.
    pub matching_config: MatchingConfig,
    /// Relation-network settings.
    pub relation_config: RelationConfig,
    /// Memory-Augmented Neural Network settings.
    pub mann_config: MANNConfig,
    /// How tasks/episodes are sampled for meta-training.
    pub task_config: TaskSamplingConfig,
    /// Settings shared across all algorithms (learning rates, stopping, …).
    pub global_settings: GlobalMetaSettings,
}
impl Default for MetaLearningConfig {
fn default() -> Self {
Self {
maml_config: MAMLConfig::default(),
reptile_config: ReptileConfig::default(),
prototypical_config: PrototypicalConfig::default(),
matching_config: MatchingConfig::default(),
relation_config: RelationConfig::default(),
mann_config: MANNConfig::default(),
task_config: TaskSamplingConfig::default(),
global_settings: GlobalMetaSettings::default(),
}
}
}
/// Training knobs shared by all meta-learning algorithms.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalMetaSettings {
    /// Total number of meta-training episodes to run.
    pub meta_episodes: usize,
    /// Learning rate for the outer (meta) update.
    pub meta_learning_rate: f32,
    /// Learning rate for the inner (per-task) adaptation.
    pub task_learning_rate: f32,
    /// Inner-loop gradient steps taken per task.
    pub adaptation_steps: usize,
    /// Evaluate every this many episodes.
    pub eval_frequency: usize,
    /// Whether to clip gradients during training.
    pub enable_gradient_clipping: bool,
    /// Clip threshold used when clipping is enabled.
    pub gradient_clip_value: f32,
    /// Whether to stop early when validation stops improving.
    pub enable_early_stopping: bool,
    /// Episodes without improvement tolerated before stopping.
    pub early_stopping_patience: usize,
}
impl Default for GlobalMetaSettings {
fn default() -> Self {
Self {
meta_episodes: 1000,
meta_learning_rate: 0.001,
task_learning_rate: 0.01,
adaptation_steps: 5,
eval_frequency: 100,
enable_gradient_clipping: true,
gradient_clip_value: 0.5,
enable_early_stopping: true,
early_stopping_patience: 50,
}
}
}
/// Configuration for Model-Agnostic Meta-Learning (MAML).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MAMLConfig {
    /// Architecture of the base learner.
    pub model_architecture: ModelArchitecture,
    /// Inner-loop (task adaptation) learning rate.
    pub inner_lr: f32,
    /// Outer-loop (meta update) learning rate.
    pub outer_lr: f32,
    /// Gradient steps taken in the inner loop.
    pub inner_steps: usize,
    /// If true, use the first-order MAML approximation (no second-order grads).
    pub first_order: bool,
    /// Allow parameters that receive no gradient during adaptation.
    pub allow_unused: bool,
}
impl Default for MAMLConfig {
fn default() -> Self {
Self {
model_architecture: ModelArchitecture::default(),
inner_lr: 0.01,
outer_lr: 0.001,
inner_steps: 5,
first_order: false,
allow_unused: true,
}
}
}
/// Shape of the base feed-forward network used by the meta-learners.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelArchitecture {
    /// Input feature dimension.
    pub input_dim: usize,
    /// Sizes of the hidden layers, in order.
    pub hidden_dims: Vec<usize>,
    /// Output dimension.
    pub output_dim: usize,
    /// Activation function name (e.g. "relu").
    pub activation: String,
    /// Whether hidden layers use batch normalization.
    pub use_batch_norm: bool,
    /// Dropout probability applied during training.
    pub dropout_rate: f32,
}
impl Default for ModelArchitecture {
fn default() -> Self {
Self {
input_dim: 128,
hidden_dims: vec![256, 256],
output_dim: 128,
activation: "relu".to_string(),
use_batch_norm: true,
dropout_rate: 0.1,
}
}
}
/// Learnable parameters of the base network, one entry per layer.
#[derive(Debug, Clone)]
pub struct ModelParameters {
    /// Per-layer weight matrices, each shaped (output_dim, input_dim).
    pub weights: Vec<Array2<f32>>,
    /// Per-layer bias vectors, each of length output_dim.
    pub biases: Vec<Array1<f32>>,
    /// Batch-norm state, present only when the architecture enables it.
    pub batch_norm_params: Option<BatchNormParameters>,
}
/// Batch-normalization state, one entry per hidden layer.
#[derive(Debug, Clone)]
pub struct BatchNormParameters {
    /// Learnable scale (gamma) per layer.
    pub scale: Vec<Array1<f32>>,
    /// Learnable shift (beta) per layer.
    pub shift: Vec<Array1<f32>>,
    /// Running mean estimate per layer (inference-time statistics).
    pub running_mean: Vec<Array1<f32>>,
    /// Running variance estimate per layer (inference-time statistics).
    pub running_var: Vec<Array1<f32>>,
}
/// Optimizer state for the outer (meta) update.
///
/// Holds buffers for both momentum-SGD (`momentum_buffers`,
/// `bias_momentum_buffers`) and Adam (`m_*`/`v_*` moments, `beta1`/`beta2`,
/// `epsilon`, `time_step`). Which set is consumed depends on the update
/// routine, which is not visible in this file chunk — confirm against caller.
#[derive(Debug)]
pub struct MetaOptimizer {
    /// Base learning rate.
    pub learning_rate: f32,
    /// Momentum coefficient (momentum-SGD path).
    pub momentum: f32,
    /// Per-layer weight momentum buffers.
    pub momentum_buffers: Vec<Array2<f32>>,
    /// Per-layer bias momentum buffers.
    pub bias_momentum_buffers: Vec<Array1<f32>>,
    /// Adam first-moment decay.
    pub beta1: f32,
    /// Adam second-moment decay.
    pub beta2: f32,
    /// Adam numerical-stability term.
    pub epsilon: f32,
    /// Adam first-moment estimates for weights.
    pub m_weights: Vec<Array2<f32>>,
    /// Adam second-moment estimates for weights.
    pub v_weights: Vec<Array2<f32>>,
    /// Adam first-moment estimates for biases.
    pub m_biases: Vec<Array1<f32>>,
    /// Adam second-moment estimates for biases.
    pub v_biases: Vec<Array1<f32>>,
    /// Adam step counter (used for bias correction).
    pub time_step: usize,
}
/// Outcome of adapting the model to a single task.
#[derive(Debug, Clone)]
pub struct AdaptationResult {
    /// Task this result belongs to.
    pub task_id: Uuid,
    /// Loss before adaptation.
    pub initial_loss: f32,
    /// Loss after adaptation.
    pub final_loss: f32,
    /// Number of inner-loop steps actually taken.
    pub adaptation_steps: usize,
    /// Wall-clock time spent adapting.
    pub duration: Duration,
    /// Metadata of the adapted task.
    pub task_metadata: TaskMetadata,
}
/// Descriptive metadata attached to a sampled meta-learning task.
///
/// NOTE: the serde derives were removed because `std::time::Instant`
/// implements neither `Serialize` nor `Deserialize`, so the previous
/// `#[derive(Debug, Clone, Serialize, Deserialize)]` could not compile.
/// If on-disk serialization is required, switch `created_at` to
/// `std::time::SystemTime` (serde-supported) and restore the derives.
#[derive(Debug, Clone)]
pub struct TaskMetadata {
    /// Domain the task was sampled from.
    pub domain: String,
    /// Difficulty score — presumably in [0, 1]; confirm against the sampler.
    pub difficulty: f32,
    /// Number of support (adaptation) examples.
    pub support_size: usize,
    /// Number of query (evaluation) examples.
    pub query_size: usize,
    /// Creation timestamp (monotonic clock, not serializable).
    pub created_at: Instant,
}
/// Configuration for the Reptile meta-learning algorithm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReptileConfig {
    /// Architecture of the base learner.
    pub model_architecture: ModelArchitecture,
    /// Inner-loop (task adaptation) learning rate.
    pub inner_lr: f32,
    /// Outer-loop (meta interpolation) step size.
    pub outer_lr: f32,
    /// Gradient steps taken per task in the inner loop.
    pub inner_steps: usize,
    /// Number of tasks averaged per meta-batch.
    pub tasks_per_batch: usize,
}
impl Default for ReptileConfig {
fn default() -> Self {
Self {
model_architecture: ModelArchitecture::default(),
inner_lr: 0.01,
outer_lr: 0.001,
inner_steps: 10,
tasks_per_batch: 5,
}
}
}
/// Configuration for prototypical networks (few-shot classification).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrototypicalConfig {
    /// Dimensionality of the learned embedding space.
    pub embedding_dim: usize,
    /// Support examples per class (the "shot" count).
    pub support_size: usize,
    /// Query examples per class per episode.
    pub query_size: usize,
    /// Classes per episode (the "way" count).
    pub num_classes: usize,
    /// Distance metric name (e.g. "euclidean").
    pub distance_metric: String,
    /// Softmax temperature applied to negative distances.
    pub temperature: f32,
}
impl Default for PrototypicalConfig {
fn default() -> Self {
Self {
embedding_dim: 128,
support_size: 5,
query_size: 15,
num_classes: 5,
distance_metric: "euclidean".to_string(),
temperature: 1.0,
}
}
}
/// Configuration for matching networks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MatchingConfig {
    /// Dimensionality of the learned embedding space.
    pub embedding_dim: usize,
    /// Hidden size of the context LSTM.
    pub lstm_hidden_dim: usize,
    /// Number of attention/processing steps.
    pub processing_steps: usize,
    /// Whether to embed with full-context embeddings (FCE).
    pub use_full_context: bool,
}
impl Default for MatchingConfig {
fn default() -> Self {
Self {
embedding_dim: 128,
lstm_hidden_dim: 256,
processing_steps: 5,
use_full_context: true,
}
}
}
/// Configuration for relation networks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelationConfig {
    /// Feature dimension fed to the relation module.
    pub feature_dim: usize,
    /// Hidden layer sizes of the relation (similarity-scoring) head;
    /// final size 1 yields a scalar relation score.
    pub relation_hidden_dims: Vec<usize>,
    /// Hidden layer sizes of the embedding module.
    pub embedding_hidden_dims: Vec<usize>,
    /// Dropout probability applied during training.
    pub dropout_rate: f32,
}
impl Default for RelationConfig {
fn default() -> Self {
Self {
feature_dim: 128,
relation_hidden_dims: vec![256, 128, 64, 1],
embedding_hidden_dims: vec![128, 128, 128, 128],
dropout_rate: 0.1,
}
}
}
/// Configuration for a Memory-Augmented Neural Network (MANN).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MANNConfig {
    /// Number of memory slots.
    pub memory_size: usize,
    /// Width of each memory slot.
    pub memory_dim: usize,
    /// Hidden size of the controller network.
    pub controller_hidden_dim: usize,
    /// Number of read heads.
    pub num_read_heads: usize,
    /// Number of write heads.
    pub num_write_heads: usize,
    /// Memory initialization scheme name (e.g. "random").
    pub memory_init: String,
}
impl Default for MANNConfig {
fn default() -> Self {
Self {
memory_size: 128,
memory_dim: 64,
controller_hidden_dim: 256,
num_read_heads: 4,
num_write_heads: 1,
memory_init: "random".to_string(),
}
}
}
/// Controls how meta-training tasks are sampled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskSamplingConfig {
    /// Minimum support-set size per task.
    pub min_support: usize,
    /// Maximum support-set size per task.
    pub max_support: usize,
    /// Minimum query-set size per task.
    pub min_query: usize,
    /// Maximum query-set size per task.
    pub max_query: usize,
    /// Difficulty sampling strategy name (e.g. "uniform").
    pub difficulty_sampling: String,
    /// Relative sampling weight per domain; empty means unweighted.
    pub domain_weights: HashMap<String, f32>,
}
impl Default for TaskSamplingConfig {
fn default() -> Self {
Self {
min_support: 1,
max_support: 10,
min_query: 5,
max_query: 20,
difficulty_sampling: "uniform".to_string(),
domain_weights: HashMap::new(),
}
}
}
/// Accumulated record of an entire meta-training run.
#[derive(Debug)]
pub struct MetaLearningHistory {
    /// Per-episode results in chronological order.
    pub episodes: Vec<EpisodeResult>,
    /// Periodic performance snapshots.
    pub performance_history: Vec<PerformanceSnapshot>,
    /// Aggregate statistics over sampled tasks.
    pub task_statistics: TaskStatistics,
}
/// Summary of one meta-training episode (a batch of tasks).
#[derive(Debug, Clone)]
pub struct EpisodeResult {
    /// Episode index.
    pub episode: usize,
    /// Mean loss over the episode's tasks.
    pub avg_loss: f32,
    /// Mean accuracy over the episode's tasks.
    pub avg_accuracy: f32,
    /// Per-task breakdown.
    pub task_results: Vec<TaskResult>,
    /// Wall-clock time of the episode.
    pub duration: Duration,
}
/// Result for a single task within an episode.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Task this result belongs to.
    pub task_id: Uuid,
    /// Final (post-adaptation) loss.
    pub loss: f32,
    /// Final accuracy on the query set.
    pub accuracy: f32,
    /// Inner-loop steps taken.
    pub adaptation_steps: usize,
    /// Metadata of the evaluated task.
    pub metadata: TaskMetadata,
}
/// Point-in-time performance measurement taken during training.
#[derive(Debug, Clone)]
pub struct PerformanceSnapshot {
    /// When the snapshot was taken (monotonic clock).
    pub timestamp: Instant,
    /// Mean loss at snapshot time.
    pub avg_loss: f32,
    /// Mean accuracy at snapshot time.
    pub avg_accuracy: f32,
    /// Meta learning rate in effect.
    pub learning_rate: f32,
    /// Memory usage in bytes — presumably; confirm unit with producer.
    pub memory_usage: usize,
}
/// Aggregate statistics over all tasks sampled so far, keyed by category.
#[derive(Debug)]
pub struct TaskStatistics {
    /// Task count per domain.
    pub domain_distribution: HashMap<String, usize>,
    /// Task count per difficulty bucket (bucket labels are strings).
    pub difficulty_distribution: HashMap<String, usize>,
    /// Success rate per domain.
    pub success_rate_by_domain: HashMap<String, f32>,
    /// Mean adaptation wall-clock time per domain.
    pub avg_adaptation_time: HashMap<String, Duration>,
}
/// Live training metrics tracked by the meta-trainer.
#[derive(Debug)]
pub struct MetaPerformanceMetrics {
    /// Current (possibly scheduled) meta learning rate.
    pub current_meta_lr: f32,
    /// Best validation accuracy observed so far.
    pub best_validation_accuracy: f32,
    /// Best (lowest) validation loss observed so far.
    pub best_validation_loss: f32,
    /// Episodes elapsed since the last improvement (early-stopping counter).
    pub episodes_without_improvement: usize,
    /// Total wall-clock training time.
    pub total_training_time: Duration,
    /// Memory usage statistics.
    pub memory_stats: MemoryStats,
}
/// Memory usage counters for the training process.
#[derive(Debug)]
pub struct MemoryStats {
    /// Current usage — presumably bytes; confirm unit with producer.
    pub current_usage: usize,
    /// Peak usage observed.
    pub peak_usage: usize,
    /// Average usage over the run.
    pub avg_usage: f32,
    /// Number of allocations recorded.
    pub allocation_count: usize,
}
/// Optimizer family used for the meta update.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizerType {
    /// Plain stochastic gradient descent.
    SGD,
    /// Adam (adaptive moments).
    Adam,
    /// Adam with decoupled weight decay.
    AdamW,
    /// RMSprop.
    RMSprop,
    /// SGD with momentum.
    Momentum,
}
/// Probability distribution used to sample task difficulty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DifficultyDistribution {
    /// Equal probability across the difficulty range.
    Uniform,
    /// Exponentially decaying probability.
    Exponential,
    /// Normally distributed difficulty.
    Gaussian,
    /// Beta-distributed difficulty.
    Beta,
}
/// Synthetic data generator family for constructing tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DataGenerator {
    /// Sine-wave regression tasks.
    Sinusoidal,
    /// Linear regression tasks.
    Linear,
    /// Polynomial regression tasks.
    Polynomial,
    /// Gaussian-distributed data.
    Gaussian,
    /// Categorical/classification data.
    Categorical,
}
/// A single (input, target) example within a task.
#[derive(Debug, Clone)]
pub struct DataPoint {
    /// Input feature vector.
    pub input: Array1<f32>,
    /// Target vector (regression values or class encoding).
    pub target: Array1<f32>,
    /// Optional provenance/quality metadata.
    pub metadata: Option<DataPointMetadata>,
}
/// Provenance and quality information for a data point.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataPointMetadata {
    /// Unique identifier of this example.
    pub id: Uuid,
    /// Where the example came from.
    pub source: String,
    /// Quality score — presumably in [0, 1]; confirm with producer.
    pub quality: f32,
    /// Free-form key/value properties.
    pub properties: HashMap<String, String>,
}
/// A meta-learning task: a support set to adapt on and a query set to
/// evaluate on, plus descriptive metadata.
#[derive(Debug, Clone)]
pub struct Task {
    /// Unique task identifier.
    pub id: Uuid,
    /// Free-form task type label.
    pub task_type: String,
    /// Examples used for inner-loop adaptation.
    pub support_set: Vec<DataPoint>,
    /// Examples used for post-adaptation evaluation.
    pub query_set: Vec<DataPoint>,
    /// Task metadata (domain, difficulty, sizes, timestamp).
    pub metadata: TaskMetadata,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningResult {
pub average_loss: f32,
pub adaptation_results: Vec<AdaptationResult>,
pub convergence_metric: f32,
}
/// Gradients matching the layout of [`ModelParameters`], one entry per layer.
#[derive(Debug, Clone)]
pub struct ModelGradients {
    /// Gradients of the per-layer weight matrices.
    pub weight_gradients: Vec<Array2<f32>>,
    /// Gradients of the per-layer bias vectors.
    pub bias_gradients: Vec<Array1<f32>>,
    /// Batch-norm gradients, present only when batch norm is in use.
    pub batch_norm_gradients: Option<BatchNormGradients>,
}
/// Gradients for the learnable batch-norm parameters, one entry per layer.
#[derive(Debug, Clone)]
pub struct BatchNormGradients {
    /// Gradients of the scale (gamma) parameters.
    pub gamma_gradients: Vec<Array1<f32>>,
    /// Gradients of the shift (beta) parameters.
    pub beta_gradients: Vec<Array1<f32>>,
}
impl ModelParameters {
pub fn new(architecture: &ModelArchitecture) -> Self {
let mut weights = Vec::new();
let mut biases = Vec::new();
let mut rng = rand::rng();
let mut dims = vec![architecture.input_dim];
dims.extend(&architecture.hidden_dims);
dims.push(architecture.output_dim);
for i in 0..dims.len() - 1 {
let input_dim = dims[i];
let output_dim = dims[i + 1];
let std = (2.0 / (input_dim + output_dim) as f32).sqrt();
let weight = Array2::from_shape_fn((output_dim, input_dim), |_| {
rng.uniform(-std, std)
});
let bias = Array1::zeros(output_dim);
weights.push(weight);
biases.push(bias);
}
let batch_norm_params = if architecture.use_batch_norm {
let mut scale = Vec::new();
let mut shift = Vec::new();
let mut running_mean = Vec::new();
let mut running_var = Vec::new();
for &dim in &architecture.hidden_dims {
scale.push(Array1::ones(dim));
shift.push(Array1::zeros(dim));
running_mean.push(Array1::zeros(dim));
running_var.push(Array1::ones(dim));
}
Some(BatchNormParameters {
scale,
shift,
running_mean,
running_var,
})
} else {
None
};
Self {
weights,
biases,
batch_norm_params,
}
}
pub fn clone_for_adaptation(&self) -> Self {
Self {
weights: self.weights.clone(),
biases: self.biases.clone(),
batch_norm_params: self.batch_norm_params.clone(),
}
}
pub fn update_with_gradients(&mut self, gradients: &ModelGradients, learning_rate: f32) {
for (weight, grad_weight) in self.weights.iter_mut().zip(&gradients.weight_gradients) {
*weight = &*weight - learning_rate * grad_weight;
}
for (bias, grad_bias) in self.biases.iter_mut().zip(&gradients.bias_gradients) {
*bias = &*bias - learning_rate * grad_bias;
}
if let (Some(bn_params), Some(bn_grads)) = (&mut self.batch_norm_params, &gradients.batch_norm_gradients) {
for (scale, grad_scale) in bn_params.scale.iter_mut().zip(&bn_grads.gamma_gradients) {
*scale = &*scale - learning_rate * grad_scale;
}
for (shift, grad_shift) in bn_params.shift.iter_mut().zip(&bn_grads.beta_gradients) {
*shift = &*shift - learning_rate * grad_shift;
}
}
}
}