// forge_orchestration/runtime.rs

//! Forge runtime and control plane
//!
//! ## Table of Contents
//! - **Forge**: Main runtime struct
//! - **ForgeHandle**: Handle for interacting with running Forge

7use crate::autoscaler::{Autoscaler, MetricsSnapshot, ScalingDecision};
8use crate::builder::ForgeConfig;
9use crate::error::Result;
10use crate::job::Job;
11use crate::metrics::ForgeMetrics;
12use crate::moe::{BoxedMoERouter, RouteResult};
13use crate::networking::{HttpServer, HttpState};
14use crate::nomad::NomadClient;
15use crate::storage::{keys, store_get_json, store_set_json, BoxedStateStore};
16use crate::types::{Expert, NodeId, Shard, ShardId};
17use axum::{
18    extract::{Path, State},
19    routing::{delete, get, post},
20    Json, Router,
21};
22use dashmap::DashMap;
23use std::sync::Arc;
24use std::time::Instant;
25use tokio::sync::{broadcast, RwLock};
26use tracing::{error, info, warn};
27
/// Lifecycle state of the Forge runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuntimeState {
    /// The runtime has not been started (or has fully shut down).
    Stopped,
    /// `run` has been invoked and startup work is in progress.
    Starting,
    /// Startup completed; the control plane is serving.
    Running,
    /// A shutdown was requested and teardown is in progress.
    ShuttingDown,
}
40
/// Main Forge runtime
///
/// Owns the control-plane components (optional Nomad client, state store,
/// MoE router, autoscaler, optional metrics) plus the in-memory registries
/// for jobs, experts, and shards. Construct via `ForgeBuilder`, then drive
/// the control plane with `Forge::run`.
pub struct Forge {
    // Immutable configuration captured at build time.
    config: ForgeConfig,
    // Current lifecycle state; behind a lock so `state()` can read it
    // concurrently with `run`/`shutdown` writes.
    state: Arc<RwLock<RuntimeState>>,
    // Identity of this control-plane node, generated in `new`.
    node_id: NodeId,
    // Set when `run` starts; `None` before that (uptime reports 0).
    start_time: Option<Instant>,

    // Core components
    // Optional Nomad API client; job operations are local-only when `None`.
    nomad: Option<NomadClient>,
    // Persistent key/value store used by load_state/save_state.
    store: BoxedStateStore,
    // Mixture-of-experts router backing `route`.
    router: BoxedMoERouter,
    // Scaling-policy evaluator backing `evaluate_scaling`.
    autoscaler: Autoscaler,
    // Metrics sink; `None` when metrics are disabled.
    metrics: Option<Arc<ForgeMetrics>>,

    // Runtime state
    // Jobs keyed by job id.
    jobs: DashMap<String, Job>,
    // Registered experts keyed by expert index.
    experts: DashMap<usize, Expert>,
    // Shards keyed by shard id.
    // NOTE(review): no method in this file reads or writes `shards` —
    // confirm it is populated elsewhere or remove it.
    shards: DashMap<ShardId, Shard>,

    // Shutdown signal
    // Broadcast channel used to fan out the shutdown notification.
    shutdown_tx: broadcast::Sender<()>,
}
63
64impl Forge {
65    /// Create a new Forge instance (use ForgeBuilder instead)
66    pub(crate) fn new(
67        config: ForgeConfig,
68        nomad: Option<NomadClient>,
69        store: BoxedStateStore,
70        router: BoxedMoERouter,
71        autoscaler: Autoscaler,
72        metrics: Option<Arc<ForgeMetrics>>,
73    ) -> Self {
74        let (shutdown_tx, _) = broadcast::channel(1);
75
76        Self {
77            config,
78            state: Arc::new(RwLock::new(RuntimeState::Stopped)),
79            node_id: NodeId::new(),
80            start_time: None,
81            nomad,
82            store,
83            router,
84            autoscaler,
85            metrics,
86            jobs: DashMap::new(),
87            experts: DashMap::new(),
88            shards: DashMap::new(),
89            shutdown_tx,
90        }
91    }
92
93    /// Get the node ID
94    pub fn node_id(&self) -> NodeId {
95        self.node_id
96    }
97
98    /// Get current runtime state
99    pub async fn state(&self) -> RuntimeState {
100        *self.state.read().await
101    }
102
103    /// Get uptime in seconds
104    pub fn uptime_secs(&self) -> u64 {
105        self.start_time
106            .map(|t| t.elapsed().as_secs())
107            .unwrap_or(0)
108    }
109
110    /// Get metrics instance
111    pub fn metrics(&self) -> Option<&Arc<ForgeMetrics>> {
112        self.metrics.as_ref()
113    }
114
115    /// Get the state store
116    pub fn store(&self) -> &BoxedStateStore {
117        &self.store
118    }
119
120    /// Get the MoE router
121    pub fn router(&self) -> &BoxedMoERouter {
122        &self.router
123    }
124
125    /// Check if Nomad is configured
126    pub fn has_nomad(&self) -> bool {
127        self.nomad.is_some()
128    }
129
130    /// Run the Forge control plane
131    pub async fn run(mut self) -> Result<()> {
132        {
133            let mut state = self.state.write().await;
134            *state = RuntimeState::Starting;
135        }
136
137        info!(
138            node_id = %self.node_id,
139            node_name = %self.config.node_name,
140            "Starting Forge control plane"
141        );
142
143        self.start_time = Some(Instant::now());
144
145        // Verify Nomad connectivity if configured
146        if let Some(nomad) = &self.nomad {
147            match nomad.health().await {
148                Ok(true) => info!("Nomad connection verified"),
149                Ok(false) => warn!("Nomad returned unhealthy status"),
150                Err(e) => warn!(error = %e, "Failed to connect to Nomad"),
151            }
152        }
153
154        // Load existing state from store
155        self.load_state().await?;
156
157        {
158            let mut state = self.state.write().await;
159            *state = RuntimeState::Running;
160        }
161
162        info!("Forge control plane running");
163
164        // Create shared state for HTTP handlers
165        let forge_state = Arc::new(RwLock::new(ForgeHttpState {
166            jobs: self.jobs.clone(),
167            metrics: self.metrics.clone(),
168        }));
169
170        // Build HTTP router
171        let http_router = self.build_http_router(forge_state);
172
173        // Start HTTP server
174        let http_server = HttpServer::new(self.config.http_config.clone())
175            .with_router(http_router);
176
177        // Run until shutdown
178        let mut shutdown_rx = self.shutdown_tx.subscribe();
179
180        tokio::select! {
181            result = http_server.serve() => {
182                if let Err(e) = result {
183                    error!(error = %e, "HTTP server error");
184                }
185            }
186            _ = shutdown_rx.recv() => {
187                info!("Shutdown signal received");
188            }
189        }
190
191        self.shutdown().await?;
192
193        Ok(())
194    }
195
196    /// Shutdown the control plane
197    pub async fn shutdown(&self) -> Result<()> {
198        {
199            let mut state = self.state.write().await;
200            if *state == RuntimeState::Stopped {
201                return Ok(());
202            }
203            *state = RuntimeState::ShuttingDown;
204        }
205
206        info!("Shutting down Forge control plane");
207
208        // Save state
209        self.save_state().await?;
210
211        // Send shutdown signal
212        let _ = self.shutdown_tx.send(());
213
214        {
215            let mut state = self.state.write().await;
216            *state = RuntimeState::Stopped;
217        }
218
219        info!("Forge control plane stopped");
220        Ok(())
221    }
222
223    /// Signal shutdown
224    pub fn signal_shutdown(&self) {
225        let _ = self.shutdown_tx.send(());
226    }
227
228    /// Subscribe to shutdown signal
229    pub fn shutdown_receiver(&self) -> broadcast::Receiver<()> {
230        self.shutdown_tx.subscribe()
231    }
232
233    // State management
234
235    async fn load_state(&self) -> Result<()> {
236        // Load jobs from store
237        let job_keys = self.store.list_prefix(keys::JOBS).await?;
238        for key in job_keys {
239            if let Some(job) = store_get_json::<Job>(self.store.as_ref(), &key).await? {
240                self.jobs.insert(job.id.clone(), job);
241            }
242        }
243
244        info!(jobs = self.jobs.len(), "Loaded state from store");
245        Ok(())
246    }
247
248    async fn save_state(&self) -> Result<()> {
249        // Save all jobs
250        for entry in self.jobs.iter() {
251            let key = keys::job(&entry.key());
252            store_set_json(self.store.as_ref(), &key, entry.value()).await?;
253        }
254
255        info!(jobs = self.jobs.len(), "Saved state to store");
256        Ok(())
257    }
258
259    // Job management
260
261    /// Submit a job
262    pub async fn submit_job(&self, job: Job) -> Result<String> {
263        let job_id = job.id.clone();
264
265        // Submit to Nomad if configured
266        if let Some(nomad) = &self.nomad {
267            nomad.submit_job(&job).await?;
268        }
269
270        // Store locally
271        let key = keys::job(&job_id);
272        store_set_json(self.store.as_ref(), &key, &job).await?;
273        self.jobs.insert(job_id.clone(), job);
274
275        if let Some(metrics) = &self.metrics {
276            metrics.record_job_submitted();
277        }
278
279        info!(job_id = %job_id, "Job submitted");
280        Ok(job_id)
281    }
282
283    /// Get a job by ID
284    pub fn get_job(&self, job_id: &str) -> Option<Job> {
285        self.jobs.get(job_id).map(|e| e.value().clone())
286    }
287
288    /// List all jobs
289    pub fn list_jobs(&self) -> Vec<Job> {
290        self.jobs.iter().map(|e| e.value().clone()).collect()
291    }
292
293    /// Stop a job
294    pub async fn stop_job(&self, job_id: &str, purge: bool) -> Result<()> {
295        // Stop in Nomad if configured
296        if let Some(nomad) = &self.nomad {
297            nomad.stop_job(job_id, purge).await?;
298        }
299
300        // Remove locally
301        if purge {
302            let key = keys::job(job_id);
303            self.store.delete(&key).await?;
304            self.jobs.remove(job_id);
305        }
306
307        if let Some(metrics) = &self.metrics {
308            metrics.record_job_completed(true);
309        }
310
311        info!(job_id = %job_id, purge = purge, "Job stopped");
312        Ok(())
313    }
314
315    /// Scale a job's task group
316    pub async fn scale_job(&self, job_id: &str, group: &str, count: u32) -> Result<()> {
317        // Scale in Nomad if configured
318        if let Some(nomad) = &self.nomad {
319            nomad
320                .scale_job(job_id, group, count, Some("Manual scale"))
321                .await?;
322        }
323
324        // Update local state
325        if let Some(mut job) = self.jobs.get_mut(job_id) {
326            for g in &mut job.groups {
327                if g.name == group {
328                    g.scaling.desired = count;
329                    break;
330                }
331            }
332        }
333
334        if let Some(metrics) = &self.metrics {
335            let direction = "manual";
336            metrics.record_scale_event(job_id, direction);
337            metrics.set_instances(job_id, group, count as f64);
338        }
339
340        info!(job_id = %job_id, group = %group, count = count, "Job scaled");
341        Ok(())
342    }
343
344    // MoE routing
345
346    /// Route an input to an expert
347    pub async fn route(&self, input: &str) -> RouteResult {
348        let experts: Vec<Expert> = self.experts.iter().map(|e| e.value().clone()).collect();
349
350        let result = if experts.is_empty() {
351            self.router.route(input, 8).await
352        } else {
353            self.router.route_with_experts(input, &experts).await
354        };
355
356        if let Some(metrics) = &self.metrics {
357            metrics.record_route(self.router.name(), result.expert_index, 0.001);
358        }
359
360        result
361    }
362
363    /// Register an expert
364    pub fn register_expert(&self, expert: Expert) {
365        info!(index = expert.index, node = %expert.node, "Expert registered");
366        self.experts.insert(expert.index, expert);
367    }
368
369    /// Update expert load
370    pub fn update_expert_load(&self, index: usize, load: f64) {
371        if let Some(mut expert) = self.experts.get_mut(&index) {
372            expert.update_load(load);
373        }
374    }
375
376    // Autoscaling
377
378    /// Evaluate autoscaling for a job
379    pub async fn evaluate_scaling(
380        &self,
381        job_id: &str,
382        cpu: f64,
383        memory: f64,
384        instances: u32,
385    ) -> ScalingDecision {
386        let metrics = MetricsSnapshot::new(cpu, memory, instances);
387        let decision = self.autoscaler.evaluate(job_id, metrics).await;
388
389        if decision.is_scaling() {
390            if let Some(m) = &self.metrics {
391                let direction = match &decision {
392                    ScalingDecision::ScaleUp(_) => "up",
393                    ScalingDecision::ScaleDown(_) => "down",
394                    _ => "none",
395                };
396                m.record_scale_event(job_id, direction);
397            }
398        }
399
400        decision
401    }
402
403    // HTTP router
404
405    fn build_http_router(&self, state: Arc<RwLock<ForgeHttpState>>) -> Router {
406        let state = HttpState { app: state };
407
408        Router::new()
409            .route("/health", get(health_handler))
410            .route("/ready", get(ready_handler))
411            .route("/api/v1/jobs", get(list_jobs_handler))
412            .route("/api/v1/jobs", post(submit_job_handler))
413            .route("/api/v1/jobs/:id", get(get_job_handler))
414            .route("/api/v1/jobs/:id", delete(stop_job_handler))
415            .route("/metrics", get(metrics_handler))
416            .with_state(state)
417    }
418}
419
// HTTP state for handlers.
//
// NOTE(review): `Forge::run` fills this with a *clone* of the runtime's
// job map, so mutations performed by handlers are not visible to the
// owning `Forge` instance — confirm the snapshot semantic is intentional.
struct ForgeHttpState {
    // Job registry the handlers read and mutate.
    jobs: DashMap<String, Job>,
    // Metrics sink; `None` when metrics are disabled.
    metrics: Option<Arc<ForgeMetrics>>,
}
425
426// HTTP handlers
427
428async fn health_handler() -> Json<serde_json::Value> {
429    Json(serde_json::json!({
430        "status": "healthy",
431        "version": env!("CARGO_PKG_VERSION")
432    }))
433}
434
/// GET /ready — readiness probe; unconditionally returns 200 OK.
// NOTE(review): this does not consult RuntimeState, so the endpoint reports
// ready even before startup completes — confirm that is intended.
async fn ready_handler() -> axum::http::StatusCode {
    axum::http::StatusCode::OK
}
438
439async fn list_jobs_handler(
440    State(state): State<HttpState<ForgeHttpState>>,
441) -> Json<Vec<Job>> {
442    let app = state.app.read().await;
443    let jobs: Vec<Job> = app.jobs.iter().map(|e| e.value().clone()).collect();
444    Json(jobs)
445}
446
447async fn get_job_handler(
448    State(state): State<HttpState<ForgeHttpState>>,
449    Path(id): Path<String>,
450) -> std::result::Result<Json<Job>, axum::http::StatusCode> {
451    let app = state.app.read().await;
452    app.jobs
453        .get(&id)
454        .map(|e| Json(e.value().clone()))
455        .ok_or(axum::http::StatusCode::NOT_FOUND)
456}
457
458async fn submit_job_handler(
459    State(state): State<HttpState<ForgeHttpState>>,
460    Json(job): Json<Job>,
461) -> std::result::Result<Json<serde_json::Value>, axum::http::StatusCode> {
462    let app = state.app.read().await;
463    let job_id = job.id.clone();
464    app.jobs.insert(job_id.clone(), job);
465    Ok(Json(serde_json::json!({ "job_id": job_id })))
466}
467
468async fn stop_job_handler(
469    State(state): State<HttpState<ForgeHttpState>>,
470    Path(id): Path<String>,
471) -> axum::http::StatusCode {
472    let app = state.app.read().await;
473    if app.jobs.remove(&id).is_some() {
474        axum::http::StatusCode::NO_CONTENT
475    } else {
476        axum::http::StatusCode::NOT_FOUND
477    }
478}
479
480async fn metrics_handler(
481    State(state): State<HttpState<ForgeHttpState>>,
482) -> std::result::Result<String, axum::http::StatusCode> {
483    let app = state.app.read().await;
484    match &app.metrics {
485        Some(m) => m
486            .gather_text()
487            .map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR),
488        None => Err(axum::http::StatusCode::NOT_FOUND),
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::ForgeBuilder;
    use crate::job::{Driver, Task};

    #[tokio::test]
    async fn test_forge_creation() {
        // A freshly built Forge starts in the Stopped state.
        let runtime = ForgeBuilder::new().build().unwrap();
        assert_eq!(runtime.state().await, RuntimeState::Stopped);
    }

    #[tokio::test]
    async fn test_job_management() {
        let runtime = ForgeBuilder::new().build().unwrap();

        let task = Task::new("server")
            .driver(Driver::Exec)
            .command("/bin/server");
        let job = Job::new("test-job").with_group("api", task);

        // Submit, then verify the job is visible through both accessors.
        let id = runtime.submit_job(job).await.unwrap();
        assert!(runtime.get_job(&id).is_some());
        assert_eq!(runtime.list_jobs().len(), 1);

        // Purging removes the job from the local registry.
        runtime.stop_job(&id, true).await.unwrap();
        assert!(runtime.get_job(&id).is_none());
    }

    #[tokio::test]
    async fn test_routing() {
        let runtime = ForgeBuilder::new().build().unwrap();

        // With no experts registered, routing falls back to a pool of 8.
        let routed = runtime.route("test-input").await;
        assert!(routed.expert_index < 8);
    }

    #[tokio::test]
    async fn test_expert_registration() {
        let runtime = ForgeBuilder::new().build().unwrap();

        runtime.register_expert(Expert::new(0, NodeId::new()));

        // Updating the load of a registered expert must not panic.
        runtime.update_expert_load(0, 0.5);
    }
}