Skip to main content

entrenar/staging/
mod.rs

//! Model Staging Workflows (GH-70)
//!
//! Provides a lightweight model staging registry for managing model versions
//! through lifecycle stages: Dev -> Staging -> Production.
//!
//! Transition rules enforce no skipping: a model must progress through each
//! stage sequentially, and demotion follows the reverse path.
//!
//! # Example
//!
//! ```
//! use entrenar::staging::{Stage, StagingRegistry};
//!
//! let mut registry = StagingRegistry::new();
//! let mv = registry.register_model("llama-7b", "1.0.0", "/models/llama-7b-v1");
//! assert_eq!(mv.stage, Stage::Dev);
//!
//! registry.promote("llama-7b", "1.0.0", Stage::Staging).expect("promote to staging");
//! registry.promote("llama-7b", "1.0.0", Stage::Production).expect("promote to production");
//! ```
21
22#[cfg(test)]
23mod tests;
24
25use std::collections::HashMap;
26
27use chrono::{DateTime, Utc};
28use serde::{Deserialize, Serialize};
29use thiserror::Error;
30
31/// Model lifecycle stage.
32///
33/// Models progress linearly: Dev -> Staging -> Production.
34/// No stage may be skipped during promotion or demotion.
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
36pub enum Stage {
37    /// In active development
38    Dev,
39    /// Under validation/testing
40    Staging,
41    /// Deployed to production
42    Production,
43}
44
45impl Stage {
46    /// Numeric ordering for stage progression.
47    fn ordinal(self) -> u8 {
48        match self {
49            Stage::Dev => 0,
50            Stage::Staging => 1,
51            Stage::Production => 2,
52        }
53    }
54
55    /// Display name for the stage.
56    pub fn as_str(self) -> &'static str {
57        match self {
58            Stage::Dev => "Dev",
59            Stage::Staging => "Staging",
60            Stage::Production => "Production",
61        }
62    }
63}
64
65impl std::fmt::Display for Stage {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        write!(f, "{}", self.as_str())
68    }
69}
70
71/// Metadata for a registered model version.
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct ModelVersion {
74    /// Model name (e.g., "llama-7b-finetuned")
75    pub name: String,
76    /// Semantic version string (e.g., "1.0.0")
77    pub version: String,
78    /// Current lifecycle stage
79    pub stage: Stage,
80    /// Arbitrary key-value metadata
81    pub metadata: HashMap<String, String>,
82    /// When this version was registered
83    pub created_at: DateTime<Utc>,
84    /// When this version was last promoted or demoted (None if still at initial stage)
85    pub promoted_at: Option<DateTime<Utc>>,
86    /// Path to model artifacts
87    pub path: String,
88}
89
90/// Errors from staging operations.
91#[derive(Debug, Error)]
92pub enum StagingError {
93    /// The requested model/version was not found in the registry.
94    #[error("model not found: {name} v{version}")]
95    NotFound { name: String, version: String },
96
97    /// The requested stage transition is not allowed.
98    /// Only single-step transitions (adjacent stages) are permitted.
99    #[error("invalid transition from {from} to {to} for {name} v{version}")]
100    InvalidTransition { name: String, version: String, from: Stage, to: Stage },
101
102    /// A model with this name and version already exists.
103    #[error("model already exists: {name} v{version}")]
104    AlreadyExists { name: String, version: String },
105}
106
107/// Result type for staging operations.
108pub type Result<T> = std::result::Result<T, StagingError>;
109
110/// Registry managing model versions and their lifecycle stages.
111///
112/// Models are keyed by (name, version) pairs. The registry enforces
113/// sequential stage transitions: Dev -> Staging -> Production for
114/// promotion, and the reverse for demotion.
115#[derive(Debug, Default)]
116pub struct StagingRegistry {
117    /// All registered model versions, keyed by (name, version).
118    models: HashMap<(String, String), ModelVersion>,
119}
120
121impl StagingRegistry {
122    /// Create an empty staging registry.
123    pub fn new() -> Self {
124        Self { models: HashMap::new() }
125    }
126
127    /// Register a new model version at the Dev stage.
128    ///
129    /// # Panics
130    ///
131    /// Does not panic. Returns the existing version if already registered
132    /// (idempotent behavior for re-registration at the same key).
133    pub fn register_model(&mut self, name: &str, version: &str, path: &str) -> ModelVersion {
134        let key = (name.to_string(), version.to_string());
135        let mv = ModelVersion {
136            name: name.to_string(),
137            version: version.to_string(),
138            stage: Stage::Dev,
139            metadata: HashMap::new(),
140            created_at: Utc::now(),
141            promoted_at: None,
142            path: path.to_string(),
143        };
144        self.models.entry(key).or_insert(mv).clone()
145    }
146
147    /// Promote a model version to the given target stage.
148    ///
149    /// The target must be exactly one stage above the current stage:
150    /// - Dev -> Staging
151    /// - Staging -> Production
152    ///
153    /// Skipping stages (e.g., Dev -> Production) is rejected.
154    pub fn promote(&mut self, name: &str, version: &str, target: Stage) -> Result<ModelVersion> {
155        let key = (name.to_string(), version.to_string());
156        let mv = self.models.get_mut(&key).ok_or_else(|| StagingError::NotFound {
157            name: name.to_string(),
158            version: version.to_string(),
159        })?;
160
161        let current_ord = mv.stage.ordinal();
162        let target_ord = target.ordinal();
163
164        // Promote must go exactly one step up
165        if target_ord != current_ord + 1 {
166            return Err(StagingError::InvalidTransition {
167                name: name.to_string(),
168                version: version.to_string(),
169                from: mv.stage,
170                to: target,
171            });
172        }
173
174        mv.stage = target;
175        mv.promoted_at = Some(Utc::now());
176        Ok(mv.clone())
177    }
178
179    /// Demote a model version to the given target stage.
180    ///
181    /// The target must be exactly one stage below the current stage:
182    /// - Production -> Staging
183    /// - Staging -> Dev
184    ///
185    /// Skipping stages (e.g., Production -> Dev) is rejected.
186    pub fn demote(&mut self, name: &str, version: &str, target: Stage) -> Result<ModelVersion> {
187        let key = (name.to_string(), version.to_string());
188        let mv = self.models.get_mut(&key).ok_or_else(|| StagingError::NotFound {
189            name: name.to_string(),
190            version: version.to_string(),
191        })?;
192
193        let current_ord = mv.stage.ordinal();
194        let target_ord = target.ordinal();
195
196        // Demote must go exactly one step down
197        if current_ord == 0 || target_ord != current_ord - 1 {
198            return Err(StagingError::InvalidTransition {
199                name: name.to_string(),
200                version: version.to_string(),
201                from: mv.stage,
202                to: target,
203            });
204        }
205
206        mv.stage = target;
207        mv.promoted_at = Some(Utc::now());
208        Ok(mv.clone())
209    }
210
211    /// Get the latest version of a model at the given stage.
212    ///
213    /// "Latest" is determined by `created_at` timestamp. Returns `None`
214    /// if no version of the model exists at that stage.
215    pub fn get_latest(&self, name: &str, stage: Stage) -> Option<&ModelVersion> {
216        self.models
217            .values()
218            .filter(|mv| mv.name == name && mv.stage == stage)
219            .max_by_key(|mv| mv.created_at)
220    }
221
222    /// List all versions of a model, sorted by creation time (oldest first).
223    pub fn list_versions(&self, name: &str) -> Vec<&ModelVersion> {
224        let mut versions: Vec<&ModelVersion> =
225            self.models.values().filter(|mv| mv.name == name).collect();
226        versions.sort_by_key(|mv| mv.created_at);
227        versions
228    }
229}