use crate::autoscaler::{Autoscaler, MetricsSnapshot, ScalingDecision};
use crate::builder::ForgeConfig;
use crate::error::Result;
use crate::job::Job;
use crate::metrics::ForgeMetrics;
use crate::moe::{BoxedMoERouter, RouteResult};
use crate::networking::{HttpServer, HttpState};
use crate::nomad::NomadClient;
use crate::storage::{keys, store_get_json, store_set_json, BoxedStateStore};
use crate::types::{Expert, NodeId, Shard, ShardId};
use axum::{
    extract::{Path, State},
    routing::{delete, get, post},
    Json, Router,
};
use dashmap::DashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, RwLock};
use tracing::{error, info, warn};

/// Lifecycle state of the Forge control plane.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuntimeState {
    /// Not yet started, or fully stopped.
    Stopped,
    /// `run()` has been called and initialization is in progress.
    Starting,
    /// Initialized and serving requests.
    Running,
    /// A shutdown has been requested; state is being persisted.
    ShuttingDown,
}

/// The Forge control plane: owns configuration, runtime state, cluster
/// clients, and the in-memory job/expert/shard registries.
pub struct Forge {
    config: ForgeConfig,
    state: Arc<RwLock<RuntimeState>>,
    node_id: NodeId,
    start_time: Option<Instant>,

    /// Optional Nomad client; `None` when running without a scheduler.
    nomad: Option<NomadClient>,
    store: BoxedStateStore,
    router: BoxedMoERouter,
    autoscaler: Autoscaler,
    metrics: Option<Arc<ForgeMetrics>>,

    /// Shared with the HTTP handlers via `Arc` so that jobs submitted over
    /// HTTP land in the same map the control plane persists on shutdown.
    jobs: Arc<DashMap<String, Job>>,
    experts: DashMap<usize, Expert>,
    shards: DashMap<ShardId, Shard>,

    shutdown_tx: broadcast::Sender<()>,
}

impl Forge {
    pub(crate) fn new(
        config: ForgeConfig,
        nomad: Option<NomadClient>,
        store: BoxedStateStore,
        router: BoxedMoERouter,
        autoscaler: Autoscaler,
        metrics: Option<Arc<ForgeMetrics>>,
    ) -> Self {
        // Capacity 1 is enough: the shutdown signal is a one-shot broadcast.
        let (shutdown_tx, _) = broadcast::channel(1);

        Self {
            config,
            state: Arc::new(RwLock::new(RuntimeState::Stopped)),
            node_id: NodeId::new(),
            start_time: None,
            nomad,
            store,
            router,
            autoscaler,
            metrics,
            jobs: Arc::new(DashMap::new()),
            experts: DashMap::new(),
            shards: DashMap::new(),
            shutdown_tx,
        }
    }

    /// Returns this node's unique identifier.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }

    /// Returns the current runtime state.
    pub async fn state(&self) -> RuntimeState {
        *self.state.read().await
    }

    /// Seconds since `run()` was called, or 0 if not yet started.
    pub fn uptime_secs(&self) -> u64 {
        self.start_time
            .map(|t| t.elapsed().as_secs())
            .unwrap_or(0)
    }

    /// Returns the metrics registry, if metrics are enabled.
    pub fn metrics(&self) -> Option<&Arc<ForgeMetrics>> {
        self.metrics.as_ref()
    }

    /// Returns the state store.
    pub fn store(&self) -> &BoxedStateStore {
        &self.store
    }

    /// Returns the MoE router.
    pub fn router(&self) -> &BoxedMoERouter {
        &self.router
    }

    /// Whether a Nomad client is configured.
    pub fn has_nomad(&self) -> bool {
        self.nomad.is_some()
    }

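    /// Runs the control plane to completion: transitions through `Starting`
    /// to `Running`, serves the HTTP API, and shuts down when the server
    /// exits or a shutdown is signaled. A minimal usage sketch, assuming the
    /// builder defaults used in the tests below:
    ///
    /// ```ignore
    /// let forge = ForgeBuilder::new().build()?;
    /// forge.run().await?;
    /// ```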
    pub async fn run(mut self) -> Result<()> {
        {
            let mut state = self.state.write().await;
            *state = RuntimeState::Starting;
        }

        info!(
            node_id = %self.node_id,
            node_name = %self.config.node_name,
            "Starting Forge control plane"
        );

        self.start_time = Some(Instant::now());

        // Verify Nomad connectivity up front; failures are logged, not fatal.
        if let Some(nomad) = &self.nomad {
            match nomad.health().await {
                Ok(true) => info!("Nomad connection verified"),
                Ok(false) => warn!("Nomad returned unhealthy status"),
                Err(e) => warn!(error = %e, "Failed to connect to Nomad"),
            }
        }

        self.load_state().await?;

        {
            let mut state = self.state.write().await;
            *state = RuntimeState::Running;
        }

        info!("Forge control plane running");

        // Share the job map with the HTTP handlers. Cloning the Arc (not the
        // map) keeps HTTP-submitted jobs visible to the control plane and to
        // save_state() at shutdown.
        let forge_state = Arc::new(RwLock::new(ForgeHttpState {
            jobs: Arc::clone(&self.jobs),
            metrics: self.metrics.clone(),
        }));

        let http_router = self.build_http_router(forge_state);

        let http_server = HttpServer::new(self.config.http_config.clone())
            .with_router(http_router);

        let mut shutdown_rx = self.shutdown_tx.subscribe();

        // Serve until either the HTTP server exits or a shutdown is signaled.
        tokio::select! {
            result = http_server.serve() => {
                if let Err(e) = result {
                    error!(error = %e, "HTTP server error");
                }
            }
            _ = shutdown_rx.recv() => {
                info!("Shutdown signal received");
            }
        }

        self.shutdown().await?;

        Ok(())
    }

    /// Persists state and transitions to `Stopped`. Idempotent: returns
    /// immediately if already stopped.
    pub async fn shutdown(&self) -> Result<()> {
        {
            let mut state = self.state.write().await;
            if *state == RuntimeState::Stopped {
                return Ok(());
            }
            *state = RuntimeState::ShuttingDown;
        }

        info!("Shutting down Forge control plane");

        self.save_state().await?;

        // Wake anyone waiting on a shutdown receiver; ignore the error when
        // there are no subscribers.
        let _ = self.shutdown_tx.send(());

        {
            let mut state = self.state.write().await;
            *state = RuntimeState::Stopped;
        }

        info!("Forge control plane stopped");
        Ok(())
    }

    /// Requests a shutdown without waiting for it to complete.
    pub fn signal_shutdown(&self) {
        let _ = self.shutdown_tx.send(());
    }

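    /// Returns a receiver that resolves once shutdown is signaled. A sketch
    /// of coordinating a background task with shutdown (the task body is
    /// illustrative, not part of this crate):
    ///
    /// ```ignore
    /// let mut rx = forge.shutdown_receiver();
    /// tokio::spawn(async move {
    ///     let _ = rx.recv().await;
    ///     // flush buffers, close connections, ...
    /// });
    /// ```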
    pub fn shutdown_receiver(&self) -> broadcast::Receiver<()> {
        self.shutdown_tx.subscribe()
    }

    /// Rehydrates the in-memory job map from the state store.
    async fn load_state(&self) -> Result<()> {
        let job_keys = self.store.list_prefix(keys::JOBS).await?;
        for key in job_keys {
            if let Some(job) = store_get_json::<Job>(self.store.as_ref(), &key).await? {
                self.jobs.insert(job.id.clone(), job);
            }
        }

        info!(jobs = self.jobs.len(), "Loaded state from store");
        Ok(())
    }

    /// Writes every in-memory job back to the state store.
    async fn save_state(&self) -> Result<()> {
        for entry in self.jobs.iter() {
            let key = keys::job(entry.key());
            store_set_json(self.store.as_ref(), &key, entry.value()).await?;
        }

        info!(jobs = self.jobs.len(), "Saved state to store");
        Ok(())
    }

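    /// Submits a job: registers it with Nomad when configured, persists it to
    /// the store, and tracks it in memory. A minimal sketch using the same
    /// builder chain as the tests below:
    ///
    /// ```ignore
    /// let job = Job::new("test-job").with_group(
    ///     "api",
    ///     Task::new("server").driver(Driver::Exec).command("/bin/server"),
    /// );
    /// let job_id = forge.submit_job(job).await?;
    /// ```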
    pub async fn submit_job(&self, job: Job) -> Result<String> {
        let job_id = job.id.clone();

        if let Some(nomad) = &self.nomad {
            nomad.submit_job(&job).await?;
        }

        // Persist first, then track in memory.
        let key = keys::job(&job_id);
        store_set_json(self.store.as_ref(), &key, &job).await?;
        self.jobs.insert(job_id.clone(), job);

        if let Some(metrics) = &self.metrics {
            metrics.record_job_submitted();
        }

        info!(job_id = %job_id, "Job submitted");
        Ok(job_id)
    }

    /// Returns a clone of the job with the given ID, if tracked.
    pub fn get_job(&self, job_id: &str) -> Option<Job> {
        self.jobs.get(job_id).map(|e| e.value().clone())
    }

    /// Returns clones of all tracked jobs.
    pub fn list_jobs(&self) -> Vec<Job> {
        self.jobs.iter().map(|e| e.value().clone()).collect()
    }

    /// Stops a job. With `purge`, the job is also removed from the store and
    /// the in-memory map; otherwise it is only stopped in Nomad.
    pub async fn stop_job(&self, job_id: &str, purge: bool) -> Result<()> {
        if let Some(nomad) = &self.nomad {
            nomad.stop_job(job_id, purge).await?;
        }

        if purge {
            let key = keys::job(job_id);
            self.store.delete(&key).await?;
            self.jobs.remove(job_id);
        }

        if let Some(metrics) = &self.metrics {
            metrics.record_job_completed(true);
        }

        info!(job_id = %job_id, purge = purge, "Job stopped");
        Ok(())
    }

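    /// Manually scales a task group in Nomad and mirrors the new desired
    /// count in memory. A minimal sketch (the count is illustrative; `"api"`
    /// matches the group name used in the tests below):
    ///
    /// ```ignore
    /// forge.scale_job(&job_id, "api", 3).await?;
    /// ```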
    pub async fn scale_job(&self, job_id: &str, group: &str, count: u32) -> Result<()> {
        if let Some(nomad) = &self.nomad {
            nomad
                .scale_job(job_id, group, count, Some("Manual scale"))
                .await?;
        }

        // Keep the in-memory desired count in sync with the scheduler.
        if let Some(mut job) = self.jobs.get_mut(job_id) {
            for g in &mut job.groups {
                if g.name == group {
                    g.scaling.desired = count;
                    break;
                }
            }
        }

        if let Some(metrics) = &self.metrics {
            let direction = "manual";
            metrics.record_scale_event(job_id, direction);
            metrics.set_instances(job_id, group, count as f64);
        }

        info!(job_id = %job_id, group = %group, count = count, "Job scaled");
        Ok(())
    }

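    /// Routes an input through the MoE router, preferring registered experts
    /// and falling back to a default pool of 8 when none are registered. A
    /// minimal sketch, mirroring the routing test below:
    ///
    /// ```ignore
    /// let result = forge.route("test-input").await;
    /// assert!(result.expert_index < 8);
    /// ```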
    pub async fn route(&self, input: &str) -> RouteResult {
        let experts: Vec<Expert> = self.experts.iter().map(|e| e.value().clone()).collect();

        let result = if experts.is_empty() {
            // No registered experts: route over the default pool of 8.
            self.router.route(input, 8).await
        } else {
            self.router.route_with_experts(input, &experts).await
        };

        if let Some(metrics) = &self.metrics {
            metrics.record_route(self.router.name(), result.expert_index, 0.001);
        }

        result
    }

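    /// Registers an expert so the router can target it. A minimal sketch,
    /// mirroring the expert test below:
    ///
    /// ```ignore
    /// forge.register_expert(Expert::new(0, NodeId::new()));
    /// forge.update_expert_load(0, 0.5);
    /// ```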
    pub fn register_expert(&self, expert: Expert) {
        info!(index = expert.index, node = %expert.node, "Expert registered");
        self.experts.insert(expert.index, expert);
    }

    /// Updates the reported load for the expert at `index`, if registered.
    pub fn update_expert_load(&self, index: usize, load: f64) {
        if let Some(mut expert) = self.experts.get_mut(&index) {
            expert.update_load(load);
        }
    }

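    /// Asks the autoscaler whether `job_id` should scale given a CPU/memory
    /// snapshot and the current instance count. A minimal sketch (the numbers
    /// are illustrative; the thresholds live in the `Autoscaler`):
    ///
    /// ```ignore
    /// let decision = forge.evaluate_scaling("web", 0.92, 0.60, 2).await;
    /// if decision.is_scaling() {
    ///     // apply the decision, e.g. via scale_job(...)
    /// }
    /// ```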
    pub async fn evaluate_scaling(
        &self,
        job_id: &str,
        cpu: f64,
        memory: f64,
        instances: u32,
    ) -> ScalingDecision {
        let metrics = MetricsSnapshot::new(cpu, memory, instances);
        let decision = self.autoscaler.evaluate(job_id, metrics).await;

        if decision.is_scaling() {
            if let Some(m) = &self.metrics {
                let direction = match &decision {
                    ScalingDecision::ScaleUp(_) => "up",
                    ScalingDecision::ScaleDown(_) => "down",
                    _ => "none",
                };
                m.record_scale_event(job_id, direction);
            }
        }

        decision
    }

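    /// Builds the HTTP API router. Endpoints:
    ///
    /// - `GET /health`: liveness probe
    /// - `GET /ready`: readiness probe
    /// - `GET /api/v1/jobs`: list jobs
    /// - `POST /api/v1/jobs`: submit a job
    /// - `GET /api/v1/jobs/:id`: fetch a job by ID
    /// - `DELETE /api/v1/jobs/:id`: remove a job
    /// - `GET /metrics`: metrics in text form (404 when metrics are disabled)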
    fn build_http_router(&self, state: Arc<RwLock<ForgeHttpState>>) -> Router {
        let state = HttpState { app: state };

        Router::new()
            .route("/health", get(health_handler))
            .route("/ready", get(ready_handler))
            .route("/api/v1/jobs", get(list_jobs_handler))
            .route("/api/v1/jobs", post(submit_job_handler))
            .route("/api/v1/jobs/:id", get(get_job_handler))
            .route("/api/v1/jobs/:id", delete(stop_job_handler))
            .route("/metrics", get(metrics_handler))
            .with_state(state)
    }
}

/// State shared with the HTTP handlers. `jobs` is the same map the control
/// plane owns (shared via `Arc`), so HTTP mutations are visible everywhere.
struct ForgeHttpState {
    jobs: Arc<DashMap<String, Job>>,
    metrics: Option<Arc<ForgeMetrics>>,
}

/// `GET /health`: liveness probe reporting the crate version.
async fn health_handler() -> Json<serde_json::Value> {
    Json(serde_json::json!({
        "status": "healthy",
        "version": env!("CARGO_PKG_VERSION")
    }))
}

/// `GET /ready`: readiness probe.
async fn ready_handler() -> axum::http::StatusCode {
    axum::http::StatusCode::OK
}

/// `GET /api/v1/jobs`: lists all tracked jobs.
async fn list_jobs_handler(
    State(state): State<HttpState<ForgeHttpState>>,
) -> Json<Vec<Job>> {
    let app = state.app.read().await;
    let jobs: Vec<Job> = app.jobs.iter().map(|e| e.value().clone()).collect();
    Json(jobs)
}

/// `GET /api/v1/jobs/:id`: fetches a single job, or 404 if unknown.
async fn get_job_handler(
    State(state): State<HttpState<ForgeHttpState>>,
    Path(id): Path<String>,
) -> std::result::Result<Json<Job>, axum::http::StatusCode> {
    let app = state.app.read().await;
    app.jobs
        .get(&id)
        .map(|e| Json(e.value().clone()))
        .ok_or(axum::http::StatusCode::NOT_FOUND)
}

/// `POST /api/v1/jobs`: inserts a job into the shared map and echoes its ID.
async fn submit_job_handler(
    State(state): State<HttpState<ForgeHttpState>>,
    Json(job): Json<Job>,
) -> std::result::Result<Json<serde_json::Value>, axum::http::StatusCode> {
    // A read lock suffices: DashMap allows insertion through a shared reference.
    let app = state.app.read().await;
    let job_id = job.id.clone();
    app.jobs.insert(job_id.clone(), job);
    Ok(Json(serde_json::json!({ "job_id": job_id })))
}

/// `DELETE /api/v1/jobs/:id`: removes a job, or 404 if unknown.
async fn stop_job_handler(
    State(state): State<HttpState<ForgeHttpState>>,
    Path(id): Path<String>,
) -> axum::http::StatusCode {
    let app = state.app.read().await;
    if app.jobs.remove(&id).is_some() {
        axum::http::StatusCode::NO_CONTENT
    } else {
        axum::http::StatusCode::NOT_FOUND
    }
}

/// `GET /metrics`: text-format metrics, or 404 when metrics are disabled.
async fn metrics_handler(
    State(state): State<HttpState<ForgeHttpState>>,
) -> std::result::Result<String, axum::http::StatusCode> {
    let app = state.app.read().await;
    match &app.metrics {
        Some(m) => m
            .gather_text()
            .map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR),
        None => Err(axum::http::StatusCode::NOT_FOUND),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::ForgeBuilder;
    use crate::job::{Driver, Task};

    #[tokio::test]
    async fn test_forge_creation() {
        let forge = ForgeBuilder::new().build().unwrap();
        assert_eq!(forge.state().await, RuntimeState::Stopped);
    }

    #[tokio::test]
    async fn test_job_management() {
        let forge = ForgeBuilder::new().build().unwrap();

        let job = Job::new("test-job").with_group(
            "api",
            Task::new("server")
                .driver(Driver::Exec)
                .command("/bin/server"),
        );

        let job_id = forge.submit_job(job).await.unwrap();
        assert!(forge.get_job(&job_id).is_some());

        let jobs = forge.list_jobs();
        assert_eq!(jobs.len(), 1);

        forge.stop_job(&job_id, true).await.unwrap();
        assert!(forge.get_job(&job_id).is_none());
    }

    #[tokio::test]
    async fn test_routing() {
        let forge = ForgeBuilder::new().build().unwrap();

        let result = forge.route("test-input").await;
        assert!(result.expert_index < 8);
    }

    #[tokio::test]
    async fn test_expert_registration() {
        let forge = ForgeBuilder::new().build().unwrap();

        let expert = Expert::new(0, NodeId::new());
        forge.register_expert(expert);

        forge.update_expert_load(0, 0.5);
    }
}