use super::maml::MamlOptimizer;
use super::types::{linear_predict, MetaLearnerConfig, Task};
use crate::error::TimeSeriesError;
/// Accumulates streamed (feature, target) observations for the task
/// currently being built, then splits them into support/query sets.
struct TaskBuilder {
// Feature vectors, one per observation, in arrival order.
xs: Vec<Vec<f64>>,
// Target values aligned index-for-index with `xs`.
ys: Vec<f64>,
// Fraction of observations assigned to the support split; the
// remainder becomes the query split.
split_ratio: f64, }
impl TaskBuilder {
    /// Creates an empty builder. `split_ratio` is the fraction of
    /// observations that will go into the support set on `build`.
    fn new(split_ratio: f64) -> Self {
        Self {
            xs: Vec::new(),
            ys: Vec::new(),
            split_ratio,
        }
    }

    /// Records one (feature vector, target) observation.
    fn add(&mut self, x: Vec<f64>, y: f64) {
        self.xs.push(x);
        self.ys.push(y);
    }

    /// Number of observations accumulated so far.
    fn len(&self) -> usize {
        self.ys.len()
    }

    /// Consumes the builder and partitions its observations into a
    /// support/query `Task`.
    ///
    /// Returns `None` when fewer than two observations were collected,
    /// since both splits must be non-empty.
    fn build(self) -> Option<Task> {
        let n = self.xs.len();
        if n < 2 {
            return None;
        }
        // Support size: ceil(n * ratio), clamped so both splits are
        // non-empty. `n >= 2` guarantees `1 <= n - 1`, so `clamp` is sound
        // and a separate `n_query == 0` check is unnecessary.
        let n_support = ((n as f64 * self.split_ratio).ceil() as usize).clamp(1, n - 1);
        // We own the data, so split in place with `split_off` instead of
        // cloning both halves out of slices (`to_vec` copied every element).
        let mut support_x = self.xs;
        let mut support_y = self.ys;
        let query_x = support_x.split_off(n_support);
        let query_y = support_y.split_off(n_support);
        Some(Task {
            support_x,
            support_y,
            query_x,
            query_y,
        })
    }
}
/// Streams observations into fixed-size tasks, adapts a linear model to
/// each finished task via the inner MAML loop, and periodically
/// meta-trains the shared initialization over a buffer of recent tasks.
pub struct OnlineMetaLearner {
// MAML optimizer holding the shared meta-initialization.
optimizer: MamlOptimizer,
// Sliding window of completed tasks used for meta-updates.
task_buffer: Vec<Task>,
// Builder collecting observations for the task in progress.
current_task_builder: TaskBuilder,
// Number of observations that triggers automatic task finalization.
task_window_size: usize,
// Weights produced by the most recent inner-loop adaptation;
// used by `predict`.
adapted_weights: Vec<f64>,
// Bias produced by the most recent inner-loop adaptation.
adapted_bias: f64,
// Count of meta-updates performed so far.
meta_update_count: usize,
}
impl OnlineMetaLearner {
    /// Creates a learner with zeroed adapted parameters.
    ///
    /// `task_window_size` is the observation count at which
    /// `maybe_finalize` closes the current task.
    pub fn new(config: MetaLearnerConfig, task_window_size: usize) -> Self {
        let adapted_weights = vec![0.0; config.feature_dim];
        let optimizer = MamlOptimizer::new(config);
        Self {
            optimizer,
            task_buffer: Vec::new(),
            current_task_builder: TaskBuilder::new(0.7),
            task_window_size,
            adapted_weights,
            adapted_bias: 0.0,
            meta_update_count: 0,
        }
    }

    /// Appends one observation to the task currently being built.
    pub fn update(&mut self, x: &[f64], y: f64) {
        self.current_task_builder.add(x.to_vec(), y);
    }

    /// Number of observations in the in-progress task.
    pub fn current_task_len(&self) -> usize {
        self.current_task_builder.len()
    }

    /// Number of completed tasks currently buffered for meta-training.
    pub fn task_buffer_len(&self) -> usize {
        self.task_buffer.len()
    }

    /// Closes the in-progress task, runs the inner adaptation loop on it,
    /// caches the adapted parameters for `predict`, and pushes the task
    /// into the sliding buffer (evicting the oldest entry when full).
    ///
    /// Returns `None` if the task had too few observations to be built.
    pub fn finalize_task(&mut self) -> Option<Task> {
        // Swap in a fresh builder (same 0.7 split as `new`) and build the old one.
        let old_builder = std::mem::replace(&mut self.current_task_builder, TaskBuilder::new(0.7));
        let task = old_builder.build()?;
        let adapted = self.optimizer.model().inner_update(
            &task,
            self.optimizer.config.inner_lr,
            self.optimizer.config.n_inner_steps,
        );
        self.adapted_weights = adapted.weights.clone();
        self.adapted_bias = adapted.bias;
        let max_buf = self.optimizer.config.task_buffer_size;
        // Evict the oldest task when the buffer is at capacity. The
        // non-empty guard prevents a panic on `remove(0)` when
        // `task_buffer_size` is configured as 0.
        if self.task_buffer.len() >= max_buf && !self.task_buffer.is_empty() {
            self.task_buffer.remove(0);
        }
        // Clone: the task is both stored in the buffer and returned to the caller.
        self.task_buffer.push(task.clone());
        Some(task)
    }

    /// Finalizes the current task only once it has reached the configured
    /// window size; otherwise returns `None` without side effects.
    pub fn maybe_finalize(&mut self) -> Option<Task> {
        if self.current_task_builder.len() >= self.task_window_size {
            self.finalize_task()
        } else {
            None
        }
    }

    /// Runs one meta-training step over the buffered tasks if enough have
    /// accumulated, returning the meta-loss; otherwise `None`.
    pub fn meta_update_if_ready(&mut self) -> Option<f64> {
        let min_tasks = self.optimizer.config.min_tasks_for_update;
        if self.task_buffer.len() < min_tasks {
            return None;
        }
        let loss = self.optimizer.meta_train_step(&self.task_buffer);
        self.meta_update_count += 1;
        Some(loss)
    }

    /// Predicts with the most recently adapted (task-specific) parameters.
    pub fn predict(&self, x: &[f64]) -> f64 {
        linear_predict(&self.adapted_weights, self.adapted_bias, x)
    }

    /// Total number of meta-updates performed.
    pub fn meta_update_count(&self) -> usize {
        self.meta_update_count
    }

    /// Force-finalizes any in-progress task and attempts a meta-update.
    ///
    /// Currently infallible; the `Result` return is kept for interface
    /// stability with other learners in this crate.
    pub fn flush(&mut self) -> Result<Option<f64>, TimeSeriesError> {
        // Deliberately ignore the finalized task; an unbuilt (too-short)
        // task simply means no new entry for the buffer.
        let _ = self.finalize_task();
        Ok(self.meta_update_if_ready())
    }

    /// Returns the shared meta-initialization (weights, bias), as opposed
    /// to the task-adapted parameters used by `predict`.
    pub fn meta_weights(&self) -> (&[f64], f64) {
        let m = self.optimizer.model();
        (&m.weights, m.bias)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeding four complete tasks should fill a buffer of capacity 4.
    #[test]
    fn test_online_maml_buffer_fills() {
        let config = MetaLearnerConfig {
            feature_dim: 2,
            task_buffer_size: 4,
            min_tasks_for_update: 4,
            ..Default::default()
        };
        let mut learner = OnlineMetaLearner::new(config, 10);
        for task_idx in 0..4_usize {
            let offset = task_idx as f64;
            for i in 0..10 {
                let scaled = i as f64 / 10.0;
                learner.update(&[scaled, offset], scaled + offset);
            }
            let _ = learner.finalize_task();
        }
        assert_eq!(
            learner.task_buffer_len(),
            4,
            "buffer should contain 4 tasks"
        );
    }

    /// A fresh learner must produce a finite prediction from zeroed weights.
    #[test]
    fn test_online_maml_predict_shape() {
        let config = MetaLearnerConfig {
            feature_dim: 3,
            ..Default::default()
        };
        let learner = OnlineMetaLearner::new(config, 20);
        let prediction = learner.predict(&[1.0, 2.0, 3.0]);
        assert!(prediction.is_finite(), "prediction should be finite");
    }

    /// Once the buffer reaches `min_tasks_for_update`, a meta-update
    /// should fire and yield a finite loss.
    #[test]
    fn test_online_maml_meta_update_triggered() {
        let config = MetaLearnerConfig {
            feature_dim: 1,
            task_buffer_size: 4,
            min_tasks_for_update: 4,
            inner_lr: 0.05,
            outer_lr: 0.01,
            n_inner_steps: 5,
        };
        let mut learner = OnlineMetaLearner::new(config, 8);
        for slope in [1.0_f64, 2.0, 3.0, 4.0] {
            for i in 0..8 {
                let t = i as f64 / 8.0;
                learner.update(&[t], slope * t);
            }
            let _ = learner.finalize_task();
        }
        let result = learner.meta_update_if_ready();
        assert!(result.is_some(), "meta-update should have been triggered");
        let loss = result.expect("meta loss should exist");
        assert!(loss.is_finite(), "meta loss should be finite: {loss}");
    }
}