use std::sync::Arc;
use crate::learn::lora::{ApplicatorError, ModelApplicator, TrainedModel};
/// Decides whether [`Applier::apply`] applies a trained model right away or
/// only reports it as ready.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ApplyMode {
    /// `apply` logs the model and returns a skipped result; this is the default.
    #[default]
    Manual,
    /// `apply` hands the model to the applicator immediately.
    Auto,
}
/// Settings that control how an [`Applier`] handles trained models.
#[derive(Debug, Clone)]
pub struct ApplierConfig {
    /// Whether `Applier::apply` applies models or merely reports them.
    pub mode: ApplyMode,
    /// Upper bound on the number of applied model ids the applier remembers.
    pub max_history: usize,
}

impl Default for ApplierConfig {
    /// Manual mode, remembering at most five applied models.
    fn default() -> Self {
        let mode = ApplyMode::Manual;
        let max_history = 5;
        Self { mode, max_history }
    }
}

impl ApplierConfig {
    /// Builder step: switch the config to [`ApplyMode::Auto`].
    pub fn auto_apply(self) -> Self {
        Self {
            mode: ApplyMode::Auto,
            ..self
        }
    }

    /// Builder step: replace the history bound with `n`.
    pub fn max_history(self, n: usize) -> Self {
        Self {
            max_history: n,
            ..self
        }
    }
}
/// Errors returned by [`Applier`] operations.
#[derive(Debug)]
pub enum ApplierError {
    /// Wraps a failure reported by the underlying [`ModelApplicator`].
    Applicator(ApplicatorError),
    /// An apply that was skipped, with a human-readable reason.
    /// (Not constructed in this module.)
    Skipped(String),
    /// Any other failure, carried as a plain message.
    Other(String),
}

impl std::fmt::Display for ApplierError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ApplierError::Applicator(inner) => write!(f, "Applicator error: {}", inner),
            ApplierError::Skipped(reason) => write!(f, "Apply skipped: {}", reason),
            // The catch-all variant prints its message verbatim, without a prefix.
            ApplierError::Other(message) => f.write_str(message),
        }
    }
}

impl std::error::Error for ApplierError {}

/// Lets `?` lift applicator failures into [`ApplierError::Applicator`].
impl From<ApplicatorError> for ApplierError {
    fn from(inner: ApplicatorError) -> Self {
        ApplierError::Applicator(inner)
    }
}
/// Outcome of asking an [`Applier`] to apply a model.
#[derive(Debug)]
pub enum ApplyResult {
    /// The model was handed to the applicator successfully.
    Applied {
        /// Id of the model that was applied.
        model_id: String,
        /// Id the applier reported as the previously active model, if any.
        previous_model_id: Option<String>,
    },
    /// The model was not applied.
    Skipped {
        /// Id of the model that was not applied.
        model_id: String,
        /// Human-readable explanation of why it was skipped.
        reason: String,
    },
}

impl ApplyResult {
    /// `true` for [`ApplyResult::Applied`], `false` for [`ApplyResult::Skipped`].
    pub fn is_applied(&self) -> bool {
        match self {
            ApplyResult::Applied { .. } => true,
            ApplyResult::Skipped { .. } => false,
        }
    }
}
/// Applies trained models through a [`ModelApplicator`] according to an
/// [`ApplierConfig`], remembering the ids of recently applied models.
pub struct Applier {
    config: ApplierConfig,
    applicator: Arc<dyn ModelApplicator>,
    /// Ids of models applied through this applier, oldest first, capped at
    /// `config.max_history` entries.
    history: Vec<String>,
}

impl Applier {
    /// Creates an applier with an empty history.
    pub fn new(config: ApplierConfig, applicator: Arc<dyn ModelApplicator>) -> Self {
        Self {
            config,
            applicator,
            history: Vec::new(),
        }
    }

    /// Returns the configuration this applier was created with.
    pub fn config(&self) -> &ApplierConfig {
        &self.config
    }

    /// Returns the ids of models successfully applied via [`Applier::apply_now`]
    /// (directly, or through [`Applier::apply`] in auto mode), oldest first.
    ///
    /// Rollbacks performed with [`Applier::rollback`] are not recorded here.
    pub fn history(&self) -> &[String] {
        &self.history
    }

    /// Applies `model` when the config is in [`ApplyMode::Auto`]. In
    /// [`ApplyMode::Manual`] it only logs that the model is ready and returns
    /// [`ApplyResult::Skipped`].
    ///
    /// # Errors
    ///
    /// In auto mode, returns any error from [`Applier::apply_now`].
    pub async fn apply(&mut self, model: &TrainedModel) -> Result<ApplyResult, ApplierError> {
        match self.config.mode {
            ApplyMode::Manual => {
                tracing::info!(
                    model_id = %model.id,
                    "Model ready for manual apply (auto-apply disabled)"
                );
                Ok(ApplyResult::Skipped {
                    model_id: model.id.to_string(),
                    reason: "Auto-apply disabled".into(),
                })
            }
            ApplyMode::Auto => self.apply_now(model).await,
        }
    }

    /// Applies `model` immediately, ignoring the configured mode.
    ///
    /// On success, the model id is appended to the history, evicting the oldest
    /// entries beyond `max_history`. The result carries the id of the model
    /// that was active before this call, if the applicator reported one.
    ///
    /// # Errors
    ///
    /// Returns [`ApplierError::Applicator`] if the applicator fails to apply the
    /// model. The history is left unchanged in that case.
    pub async fn apply_now(&mut self, model: &TrainedModel) -> Result<ApplyResult, ApplierError> {
        // The model being replaced is the one active right now. The
        // applicator's `previous_model_id()` is the rollback target (see
        // `rollback` below), i.e. the model *before* the current one. Reading
        // it here would report a predecessor two steps back.
        let previous_model_id = self
            .applicator
            .current()
            .map(|active| active.id.to_string());
        let model_id = model.id.to_string();
        tracing::info!(
            model_id = %model_id,
            previous = ?previous_model_id,
            "Applying trained model"
        );
        self.applicator.apply(model).await?;
        self.record_history(model_id.clone());
        tracing::info!(model_id = %model_id, "Model applied successfully");
        Ok(ApplyResult::Applied {
            model_id,
            previous_model_id,
        })
    }

    /// Appends `model_id` to the history, then drops the oldest entries so that
    /// at most `max_history` remain (none when `max_history` is 0).
    fn record_history(&mut self, model_id: String) {
        self.history.push(model_id);
        let excess = self.history.len().saturating_sub(self.config.max_history);
        // An empty range is a no-op, and this does not assume exactly one
        // surplus entry.
        self.history.drain(..excess);
    }

    /// Reverts the applicator to the model it reports via `previous_model_id()`.
    ///
    /// The history is not modified.
    ///
    /// # Errors
    ///
    /// Returns [`ApplierError::Other`] if the applicator reports no previous
    /// model. Returns [`ApplierError::Applicator`] if the applicator's rollback
    /// fails.
    pub async fn rollback(&self) -> Result<(), ApplierError> {
        let previous_id = self
            .applicator
            .previous_model_id()
            .ok_or_else(|| ApplierError::Other("No previous model to rollback to".into()))?;
        tracing::info!(target_id = %previous_id, "Rolling back to previous model");
        self.applicator.rollback(&previous_id).await?;
        Ok(())
    }

    /// Returns the model the applicator currently reports as active, if any.
    pub fn current_model(&self) -> Option<TrainedModel> {
        self.applicator.current()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::lora::{LoraModelId, NoOpApplicator};
    use std::path::PathBuf;

    /// Builds a minimal `TrainedModel` whose id is parsed from `id`.
    fn create_test_model(id: &str) -> TrainedModel {
        TrainedModel {
            id: LoraModelId::parse(id),
            base_model: String::from("test-base"),
            adapter_path: PathBuf::from("/tmp/test"),
            learn_model_name: String::from("test"),
            episode_ids: Vec::new(),
            sample_count: 10,
            created_at: 0,
            metrics: None,
        }
    }

    /// Wraps `config` in an `Applier` backed by a fresh `NoOpApplicator`.
    fn noop_applier(config: ApplierConfig) -> Applier {
        Applier::new(config, Arc::new(NoOpApplicator::new()))
    }

    #[tokio::test]
    async fn test_applier_manual_mode() {
        let mut applier = noop_applier(ApplierConfig::default());
        let model = create_test_model("test-model-1");
        let outcome = applier.apply(&model).await.unwrap();
        assert!(!outcome.is_applied());
        if let ApplyResult::Skipped { model_id, .. } = outcome {
            assert_eq!(model_id, "test-model-1");
        } else {
            panic!("Expected Skipped");
        }
    }

    #[tokio::test]
    async fn test_applier_auto_mode() {
        let mut applier = noop_applier(ApplierConfig::default().auto_apply());
        let model = create_test_model("test-model-1");
        let outcome = applier.apply(&model).await.unwrap();
        assert!(outcome.is_applied());
        assert_eq!(applier.history().len(), 1);
    }

    #[tokio::test]
    async fn test_applier_history_limit() {
        let mut applier = noop_applier(ApplierConfig::default().auto_apply().max_history(2));
        for model in (0..3).map(|i| create_test_model(&format!("model-{}", i))) {
            applier.apply(&model).await.unwrap();
        }
        // Only the two newest ids survive, oldest first.
        assert_eq!(applier.history(), ["model-1", "model-2"]);
    }
}