1use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::time::Instant;
9
10use anyhow::{Context, Result};
11use indicatif::{ProgressBar, ProgressStyle};
12use serde::{Deserialize, Serialize};
13
14use oxigaf_flame::sequence::FlameSequence;
15
/// Placeholder handle for a trained 3D Gaussian model.
///
/// Only the number of Gaussians is recorded; no per-Gaussian parameters are
/// stored by this type yet.
#[derive(Debug, Clone)]
pub struct GaussianModel {
    num_gaussians: usize,
}

impl GaussianModel {
    /// Creates a model descriptor holding `num_gaussians` Gaussians.
    pub fn new(num_gaussians: usize) -> Self {
        Self { num_gaussians }
    }

    /// Returns the number of Gaussians this model was created with.
    ///
    /// Exposing the count also means the field is read, so it no longer needs
    /// a `dead_code` allowance.
    pub fn num_gaussians(&self) -> usize {
        self.num_gaussians
    }
}
32
33pub trait PipelineStage: Send + Sync {
35 fn name(&self) -> &str;
37
38 fn run(&mut self, ctx: &mut PipelineContext) -> Result<()>;
44
45 fn progress(&self) -> f32 {
47 0.0
48 }
49
50 fn eta_seconds(&self) -> Option<f64> {
52 None
53 }
54}
55
56#[derive(Default)]
60pub struct PipelineContext {
61 pub video_path: Option<PathBuf>,
63
64 pub flame_sequence: Option<FlameSequence>,
66
67 pub generated_images: Vec<image::RgbImage>,
69
70 pub generated_masks: Vec<image::GrayImage>,
72
73 pub trained_model: Option<GaussianModel>,
75
76 pub metrics: HashMap<String, f32>,
78
79 pub checkpoint_dir: Option<PathBuf>,
81
82 pub current_stage: usize,
84
85 pub total_stages: usize,
87}
88
89impl PipelineContext {
90 pub fn new() -> Self {
92 Self::default()
93 }
94
95 pub fn with_checkpoint_dir(mut self, dir: PathBuf) -> Self {
97 self.checkpoint_dir = Some(dir);
98 self
99 }
100
101 pub fn save_checkpoint(&self, stage_name: &str) -> Result<()> {
107 if let Some(ref checkpoint_dir) = self.checkpoint_dir {
108 std::fs::create_dir_all(checkpoint_dir)
109 .context("Failed to create checkpoint directory")?;
110
111 let checkpoint_path = checkpoint_dir.join(format!("stage_{}.json", stage_name));
112
113 let checkpoint = CheckpointData {
114 stage_name: stage_name.to_string(),
115 current_stage: self.current_stage,
116 total_stages: self.total_stages,
117 metrics: self.metrics.clone(),
118 has_flame_sequence: self.flame_sequence.is_some(),
119 num_generated_images: self.generated_images.len(),
120 has_trained_model: self.trained_model.is_some(),
121 };
122
123 let json = serde_json::to_string_pretty(&checkpoint)
124 .context("Failed to serialize checkpoint")?;
125 std::fs::write(&checkpoint_path, json).context("Failed to write checkpoint file")?;
126
127 tracing::info!("Saved checkpoint: {}", checkpoint_path.display());
128 }
129
130 Ok(())
131 }
132
133 pub fn load_checkpoint(checkpoint_dir: &Path, stage_name: &str) -> Result<CheckpointData> {
139 let checkpoint_path = checkpoint_dir.join(format!("stage_{}.json", stage_name));
140
141 let json = std::fs::read_to_string(&checkpoint_path)
142 .with_context(|| format!("Failed to read checkpoint: {}", checkpoint_path.display()))?;
143
144 serde_json::from_str(&json).context("Failed to parse checkpoint JSON")
145 }
146}
147
148#[derive(Debug, Serialize, Deserialize)]
150pub struct CheckpointData {
151 pub stage_name: String,
152 pub current_stage: usize,
153 pub total_stages: usize,
154 pub metrics: HashMap<String, f32>,
155 pub has_flame_sequence: bool,
156 pub num_generated_images: usize,
157 pub has_trained_model: bool,
158}
159
/// Pipeline stage that tracks FLAME parameters from an input video.
pub struct TrackingStage {
    /// Video file to track.
    video_path: PathBuf,
    /// Destination passed at construction; not read by the current implementation.
    #[allow(dead_code)] // stored for a future tracking-output writer
    output_path: PathBuf,
    /// Completion fraction reported through `PipelineStage::progress`.
    progress: f32,
}

impl TrackingStage {
    /// Creates a tracking stage for `video_path`.
    ///
    /// `output_path` is stored but is not yet used by `run`.
    pub fn new(video_path: PathBuf, output_path: PathBuf) -> Self {
        let progress = 0.0;
        TrackingStage {
            progress,
            output_path,
            video_path,
        }
    }
}
178
179impl PipelineStage for TrackingStage {
180 fn name(&self) -> &str {
181 "Tracking"
182 }
183
184 fn run(&mut self, ctx: &mut PipelineContext) -> Result<()> {
185 tracing::info!(
186 "Starting FLAME tracking from video: {}",
187 self.video_path.display()
188 );
189
190 self.progress = 0.5;
193
194 let _sequence = FlameSequence::from_memory(vec![], Some(30.0));
196
197 self.progress = 1.0;
198
199 ctx.video_path = Some(self.video_path.clone());
201 ctx.metrics.insert("tracking_fps".to_string(), 30.0);
202
203 Ok(())
204 }
205
206 fn progress(&self) -> f32 {
207 self.progress
208 }
209}
210
/// Pipeline stage that generates multi-view images and masks.
pub struct DiffusionStage {
    /// Number of views to generate.
    num_views: usize,
    /// `(width, height)` of every generated image and mask.
    resolution: (u32, u32),
    /// Completion fraction reported through `PipelineStage::progress`.
    progress: f32,
}

impl DiffusionStage {
    /// Creates a stage that produces `num_views` views sized `resolution` (`(width, height)`).
    pub fn new(num_views: usize, resolution: (u32, u32)) -> Self {
        DiffusionStage {
            progress: 0.0,
            resolution,
            num_views,
        }
    }
}
228
229impl PipelineStage for DiffusionStage {
230 fn name(&self) -> &str {
231 "Diffusion"
232 }
233
234 fn run(&mut self, ctx: &mut PipelineContext) -> Result<()> {
235 tracing::info!(
236 "Generating {} views at {:?}",
237 self.num_views,
238 self.resolution
239 );
240
241 if ctx.flame_sequence.is_none() {
242 anyhow::bail!("No FLAME sequence available for diffusion");
243 }
244
245 self.progress = 0.5;
248
249 let (width, height) = self.resolution;
250 for i in 0..self.num_views {
251 let img = image::RgbImage::new(width, height);
252 ctx.generated_images.push(img);
253
254 let mask = image::GrayImage::new(width, height);
255 ctx.generated_masks.push(mask);
256
257 self.progress = 0.5 + 0.5 * (i as f32 / self.num_views as f32);
258 }
259
260 ctx.metrics
261 .insert("num_views".to_string(), self.num_views as f32);
262
263 Ok(())
264 }
265
266 fn progress(&self) -> f32 {
267 self.progress
268 }
269}
270
/// Pipeline stage that trains a Gaussian model from the generated views.
pub struct TrainingStage {
    /// Total number of training iterations to run.
    num_iterations: usize,
    /// Number of iterations completed so far.
    current_iteration: usize,
    /// When the current run started; `None` until `run` begins training.
    start_time: Option<Instant>,
}

impl TrainingStage {
    /// Creates a stage that will run `num_iterations` training iterations.
    pub fn new(num_iterations: usize) -> Self {
        TrainingStage {
            start_time: None,
            current_iteration: 0,
            num_iterations,
        }
    }
}
288
289impl PipelineStage for TrainingStage {
290 fn name(&self) -> &str {
291 "Training"
292 }
293
294 fn run(&mut self, ctx: &mut PipelineContext) -> Result<()> {
295 tracing::info!(
296 "Training 3D Gaussians for {} iterations",
297 self.num_iterations
298 );
299
300 if ctx.generated_images.is_empty() {
301 anyhow::bail!("No generated images available for training");
302 }
303
304 self.start_time = Some(Instant::now());
305
306 for i in 0..self.num_iterations {
309 self.current_iteration = i + 1;
310
311 std::thread::sleep(std::time::Duration::from_millis(1));
313
314 if i % 100 == 0 {
315 let loss = 1.0 / (i as f32 + 1.0);
316 ctx.metrics.insert(format!("loss_iter_{}", i), loss);
317 }
318 }
319
320 let model = GaussianModel::new(1000);
322 ctx.trained_model = Some(model);
323
324 ctx.metrics.insert("final_loss".to_string(), 0.01);
325 ctx.metrics
326 .insert("iterations".to_string(), self.num_iterations as f32);
327
328 Ok(())
329 }
330
331 fn progress(&self) -> f32 {
332 if self.num_iterations == 0 {
333 return 0.0;
334 }
335 self.current_iteration as f32 / self.num_iterations as f32
336 }
337
338 fn eta_seconds(&self) -> Option<f64> {
339 if let Some(start) = self.start_time {
340 if self.current_iteration > 0 {
341 let elapsed = start.elapsed().as_secs_f64();
342 let per_iter = elapsed / self.current_iteration as f64;
343 let remaining = (self.num_iterations - self.current_iteration) as f64 * per_iter;
344 return Some(remaining);
345 }
346 }
347 None
348 }
349}
350
/// Pipeline stage that writes the trained model to disk.
pub struct ExportStage {
    /// Requested output format (currently only reported in the log message).
    format: ExportFormat,
    /// File the exported model is written to.
    output_path: PathBuf,
}

/// Output formats an [`ExportStage`] can be asked to produce.
///
/// Fieldless and `Copy`, so it also derives equality and hashing for use in
/// comparisons and as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExportFormat {
    /// PLY format.
    Ply,
    /// glTF format.
    Gltf,
    /// Binary format (layout not defined in this module).
    Binary,
}

impl ExportStage {
    /// Creates an export stage that writes `format` output to `output_path`.
    pub fn new(format: ExportFormat, output_path: PathBuf) -> Self {
        Self {
            format,
            output_path,
        }
    }
}
377
378impl PipelineStage for ExportStage {
379 fn name(&self) -> &str {
380 "Export"
381 }
382
383 fn run(&mut self, ctx: &mut PipelineContext) -> Result<()> {
384 tracing::info!(
385 "Exporting model to {:?} format: {}",
386 self.format,
387 self.output_path.display()
388 );
389
390 if ctx.trained_model.is_none() {
391 anyhow::bail!("No trained model available for export");
392 }
393
394 std::fs::write(&self.output_path, b"placeholder").context("Failed to write export file")?;
397
398 ctx.metrics.insert("exported".to_string(), 1.0);
399
400 Ok(())
401 }
402
403 fn progress(&self) -> f32 {
404 1.0
406 }
407}
408
409pub struct PipelineExecutor {
411 stages: Vec<Box<dyn PipelineStage>>,
412 show_progress: bool,
413}
414
415impl PipelineExecutor {
416 pub fn new() -> Self {
418 Self {
419 stages: Vec::new(),
420 show_progress: true,
421 }
422 }
423
424 pub fn add_stage(&mut self, stage: Box<dyn PipelineStage>) -> &mut Self {
426 self.stages.push(stage);
427 self
428 }
429
430 pub fn show_progress(&mut self, show: bool) -> &mut Self {
432 self.show_progress = show;
433 self
434 }
435
436 pub fn execute(&mut self, mut ctx: PipelineContext) -> Result<PipelineContext> {
442 ctx.total_stages = self.stages.len();
443
444 for (i, stage) in self.stages.iter_mut().enumerate() {
445 ctx.current_stage = i;
446
447 let stage_name = stage.name().to_string(); tracing::info!(
449 "Executing stage {}/{}: {}",
450 i + 1,
451 ctx.total_stages,
452 stage_name
453 );
454
455 let pb = if self.show_progress {
456 let pb = ProgressBar::new(100);
457 pb.set_style(
458 ProgressStyle::default_bar()
459 .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}% {msg}")
460 .context("Failed to create progress bar template")?
461 .progress_chars("#>-"),
462 );
463 pb.set_message(stage_name.clone());
464 Some(pb)
465 } else {
466 None
467 };
468
469 stage
471 .run(&mut ctx)
472 .with_context(|| format!("Stage '{}' failed", stage_name))?;
473
474 if let Some(pb) = pb {
475 pb.set_position(100);
476 pb.finish_with_message(format!("{} complete", stage_name));
477 }
478
479 ctx.save_checkpoint(&stage_name)?;
481
482 tracing::info!("Stage {} complete", stage_name);
483 }
484
485 tracing::info!("Pipeline complete! Executed {} stages", ctx.total_stages);
486
487 Ok(ctx)
488 }
489}
490
491impl Default for PipelineExecutor {
492 fn default() -> Self {
493 Self::new()
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a scratch directory that is deleted when the returned guard drops.
    fn scratch_dir() -> TempDir {
        TempDir::new().expect("test: temp dir creation should succeed")
    }

    #[test]
    fn test_pipeline_context_creation() {
        let fresh = PipelineContext::new();
        assert_eq!(fresh.current_stage, 0);
        assert_eq!(fresh.total_stages, 0);
        assert!(fresh.flame_sequence.is_none());
        assert!(fresh.trained_model.is_none());
    }

    #[test]
    fn test_checkpoint_save_load() {
        let dir = scratch_dir();
        let root = dir.path();

        let mut ctx = PipelineContext::new().with_checkpoint_dir(root.to_path_buf());
        ctx.current_stage = 1;
        ctx.total_stages = 3;
        ctx.metrics.insert("test".to_string(), 42.0);

        ctx.save_checkpoint("test_stage")
            .expect("test: checkpoint operation should succeed");

        let restored = PipelineContext::load_checkpoint(root, "test_stage")
            .expect("test: checkpoint operation should succeed");
        assert_eq!(restored.stage_name, "test_stage");
        assert_eq!(restored.current_stage, 1);
        assert_eq!(restored.total_stages, 3);
        assert_eq!(restored.metrics.get("test"), Some(&42.0));
    }

    #[test]
    fn test_training_stage_progress() {
        let mut stage = TrainingStage::new(100);
        for (done, expected) in [(0, 0.0), (50, 0.5), (100, 1.0)] {
            stage.current_iteration = done;
            assert_eq!(stage.progress(), expected);
        }
    }

    #[test]
    fn test_pipeline_executor() {
        let mut executor = PipelineExecutor::new();
        executor.show_progress(false);

        let outcome = executor.execute(PipelineContext::new());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_export_stage() {
        let dir = scratch_dir();
        let target = dir.path().join("model.ply");

        let mut export = ExportStage::new(ExportFormat::Ply, target.clone());
        let mut ctx = PipelineContext::new();

        // Without a trained model the stage must refuse to run.
        assert!(export.run(&mut ctx).is_err());

        ctx.trained_model = Some(GaussianModel::new(10));
        assert!(export.run(&mut ctx).is_ok());

        assert!(target.exists());
    }
}