use serde::{Deserialize, Serialize};
/// Hyperparameters for a meta-learning run.
///
/// This struct only stores values. How each field is interpreted is decided by the code that
/// consumes it (not visible here), so the per-field notes below describe what the field names
/// indicate and should be confirmed against that code. The [`Default`] impl provides a
/// baseline configuration.
///
/// The struct is serialized through serde derives. Renaming a field changes its serialized
/// key, and reordering fields affects formats that encode structs positionally.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MetaLearningConfig {
    /// Meta-learning algorithm to run.
    pub algorithm: MetaAlgorithm,
    /// Inner-loop (per-task adaptation) learning rate.
    pub inner_lr: f64,
    /// Outer-loop (meta-update) learning rate.
    pub meta_lr: f64,
    /// Number of inner-loop steps taken per task.
    pub inner_steps: usize,
    /// Number of support examples per task.
    pub support_size: usize,
    /// Number of query examples per task.
    pub query_size: usize,
    /// Number of classes per task ("N-way").
    pub num_ways: usize,
    /// Number of labelled examples per class ("K-shot").
    pub num_shots: usize,
    /// Whether a first-order approximation of the meta-gradient is requested.
    pub first_order: bool,
    /// Temperature / scaling factor; its exact role is algorithm-specific (confirm in the consumer).
    pub temperature: f64,
    /// Dimensionality of the embeddings.
    pub embedding_dim: usize,
    /// Whether embeddings are normalized.
    pub normalize_embeddings: bool,
    /// Number of memory slots.
    pub memory_size: usize,
    /// Dimensionality of memory keys.
    pub memory_key_dim: usize,
    /// Dimensionality of memory values.
    pub memory_value_dim: usize,
    /// Number of tasks per meta-update.
    pub meta_batch_size: usize,
    /// Whether task-specific parameters are used.
    pub task_specific_params: bool,
    /// L2 regularization coefficient applied in the inner loop.
    pub inner_l2_reg: f64,
    /// Gradient-norm clipping threshold.
    pub grad_clip_norm: f64,
}
impl Default for MetaLearningConfig {
fn default() -> Self {
Self {
algorithm: MetaAlgorithm::MAML,
inner_lr: 0.01,
meta_lr: 0.001,
inner_steps: 5,
support_size: 5,
query_size: 15,
num_ways: 5,
num_shots: 1,
first_order: false,
temperature: 1.0,
embedding_dim: 512,
normalize_embeddings: true,
memory_size: 128,
memory_key_dim: 64,
memory_value_dim: 256,
meta_batch_size: 32,
task_specific_params: false,
inner_l2_reg: 0.0001,
grad_clip_norm: 10.0,
}
}
}
/// Selects which meta-learning algorithm a [`MetaLearningConfig`] describes.
///
/// Variant names are conventional acronyms from the meta-learning literature. The expansions
/// noted on each variant are those conventional meanings. What each variant actually does is
/// defined by the code that matches on this enum.
///
/// The enum is `Copy + Eq + Hash`, so it can be used directly as a `HashMap`/`HashSet` key.
///
/// Under the serde derives, variant names are the serialized representation. Formats that
/// encode enums by variant index also depend on declaration order. Neither should change
/// without considering previously stored configs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MetaAlgorithm {
    /// Model-Agnostic Meta-Learning.
    MAML,
    /// Reptile.
    Reptile,
    /// Prototypical Networks.
    ProtoNet,
    /// Matching Networks.
    MatchingNet,
    /// Relation Networks.
    RelationNet,
    /// Memory-Augmented Neural Network.
    MANN,
    /// Gradient-based meta-learning.
    GBML,
    /// Meta-SGD.
    MetaSGD,
    /// Learning to learn.
    L2L,
}