Skip to main content

expman_core/
engine.rs

1//! Async logging engine: the heart of expman-rs.
2//!
3//! `LoggingEngine::new()` spawns a background tokio task that owns all file handles.
4//! `log_metrics()` is a channel send — O(1), never blocks the experiment process.
5//! The background task batches rows and flushes to Parquet periodically.
6
7use std::collections::HashMap;
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12
13use chrono::Utc;
14use tokio::runtime::Runtime;
15use tokio::sync::{mpsc, oneshot};
16use tokio::time::interval;
17use tracing::{error, info};
18
19use crate::error::{ExpmanError, Result};
20use crate::models::{ExperimentConfig, MetricRow, MetricValue, RunMetadata, RunStatus};
21use crate::storage;
22
23/// Commands sent to the background logging task.
24enum LogCommand {
25    /// Log a row of metrics.
26    Metric(MetricRow),
27    /// Update the config/params YAML.
28    Params(HashMap<String, serde_yaml::Value>),
29    /// Copy an artifact file into the run's artifacts directory.
30    Artifact(PathBuf),
31    /// Log a message to the run log file.
32    Log { level: LogLevel, message: String },
33    /// Force flush the current buffer to disk.
34    Flush(oneshot::Sender<Result<()>>),
35    /// Gracefully shut down: flush everything, write final metadata.
36    Shutdown {
37        status: RunStatus,
38        reply: oneshot::Sender<()>,
39    },
40}
41
/// Severity attached to a message written to the run log.
///
/// Derives `PartialEq`/`Eq` so callers can compare and filter on levels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogLevel {
    /// Informational message.
    Info,
    /// Warning-level message.
    Warn,
    /// Error-level message.
    Error,
}
48
49/// The non-blocking logging engine.
50///
51/// Internally holds a sender to a tokio mpsc channel. All heavy I/O
52/// happens in a background task on a dedicated tokio runtime thread.
53pub struct LoggingEngine {
54    sender: mpsc::UnboundedSender<LogCommand>,
55    /// Keep the runtime alive as long as the engine exists.
56    _runtime: Arc<Runtime>,
57    config: ExperimentConfig,
58}
59
60impl LoggingEngine {
61    /// Create a new `LoggingEngine` for the given config.
62    ///
63    /// This initializes the run directory, writes initial metadata,
64    /// and spawns the background I/O task.
65    pub fn new(config: ExperimentConfig) -> Result<Self> {
66        // Set up directories
67        let run_dir = config.run_dir();
68        storage::ensure_dir(&run_dir)?;
69        storage::ensure_dir(&run_dir.join("artifacts"))?;
70
71        // Write initial run metadata
72        let meta = RunMetadata {
73            name: config.run_name.clone(),
74            experiment: config.name.clone(),
75            status: RunStatus::Running,
76            started_at: Utc::now(),
77            language: Some(config.language.clone()),
78            env_path: config.env_path.clone(),
79            ..Default::default()
80        };
81        storage::save_run_metadata(&run_dir, &meta)?;
82
83        // Ensure experiment metadata exists
84        let exp_dir = config.experiment_dir();
85        storage::ensure_dir(&exp_dir)?;
86        let exp_meta_path = exp_dir.join("experiment.yaml");
87        if !exp_meta_path.exists() {
88            storage::save_experiment_metadata(
89                &exp_dir,
90                &crate::models::ExperimentMetadata::default(),
91            )?;
92        }
93
94        // Set up log file
95        let log_path = run_dir.join("run.log");
96
97        // Build dedicated tokio runtime for background I/O
98        let runtime = Arc::new(
99            tokio::runtime::Builder::new_multi_thread()
100                .worker_threads(1)
101                .thread_name("expman-io")
102                .enable_all()
103                .build()
104                .map_err(|e| ExpmanError::Other(e.to_string()))?,
105        );
106
107        let (sender, receiver) = mpsc::unbounded_channel::<LogCommand>();
108
109        // Spawn background task
110        let flush_rows = config.flush_interval_rows;
111        let flush_ms = config.flush_interval_ms;
112        let run_dir_clone = run_dir.clone();
113        runtime.spawn(background_task(
114            receiver,
115            run_dir_clone,
116            log_path,
117            flush_rows,
118            flush_ms,
119        ));
120
121        info!(
122            experiment = %config.name,
123            run = %config.run_name,
124            "LoggingEngine initialized"
125        );
126
127        Ok(Self {
128            sender,
129            _runtime: runtime,
130            config,
131        })
132    }
133
134    /// Log a row of metrics. Non-blocking — channel send only.
135    pub fn log_metrics(&self, values: HashMap<String, MetricValue>, step: Option<u64>) {
136        let row = MetricRow::new(values, step);
137        // If channel is closed (engine shut down), silently drop.
138        let _ = self.sender.send(LogCommand::Metric(row));
139    }
140
141    /// Log/update experiment parameters (config). Non-blocking.
142    pub fn log_params(&self, params: HashMap<String, serde_yaml::Value>) {
143        let _ = self.sender.send(LogCommand::Params(params));
144    }
145
146    /// Save an artifact file asynchronously. Non-blocking.
147    /// The path is relative to the current working directory for the source,
148    /// and will be preserved as a relative path within the run's artifacts directory.
149    pub fn save_artifact(&self, path: PathBuf) {
150        let _ = self.sender.send(LogCommand::Artifact(path));
151    }
152
153    /// Log a message to the run log. Non-blocking.
154    pub fn log_message(&self, level: LogLevel, message: String) {
155        let _ = self.sender.send(LogCommand::Log { level, message });
156    }
157
158    /// Force flush the metric buffer to disk. Async — awaits completion.
159    pub async fn flush(&self) -> Result<()> {
160        let (tx, rx) = oneshot::channel();
161        self.sender
162            .send(LogCommand::Flush(tx))
163            .map_err(|_| ExpmanError::ChannelClosed)?;
164        rx.await.map_err(|_| ExpmanError::ChannelClosed)?
165    }
166
167    /// Gracefully shut down: flush all pending metrics, write final metadata.
168    /// Blocks until complete. Should be called at experiment end.
169    pub fn close(&self, status: RunStatus) {
170        let (tx, rx) = oneshot::channel();
171        if self
172            .sender
173            .send(LogCommand::Shutdown { status, reply: tx })
174            .is_ok()
175        {
176            // Block current thread until background task confirms shutdown.
177            // We use the runtime's block_on for this.
178            let _ = self._runtime.block_on(rx);
179        }
180    }
181
182    pub fn config(&self) -> &ExperimentConfig {
183        &self.config
184    }
185}
186
187impl Drop for LoggingEngine {
188    fn drop(&mut self) {
189        // Best-effort graceful shutdown on drop
190        let (tx, rx) = oneshot::channel();
191        if self
192            .sender
193            .send(LogCommand::Shutdown {
194                status: RunStatus::Finished,
195                reply: tx,
196            })
197            .is_ok()
198        {
199            let _ = self
200                ._runtime
201                .block_on(async { tokio::time::timeout(Duration::from_secs(5), rx).await });
202        }
203    }
204}
205
206// ─── Background I/O task ─────────────────────────────────────────────────────
207
208async fn background_task(
209    mut receiver: mpsc::UnboundedReceiver<LogCommand>,
210    run_dir: PathBuf,
211    log_path: PathBuf,
212    flush_interval_rows: usize,
213    flush_interval_ms: u64,
214) {
215    let metrics_path = run_dir.join("metrics.parquet");
216    let config_path = run_dir.join("config.yaml");
217    let _meta_path = run_dir.join("run.yaml");
218    let artifacts_dir = run_dir.join("artifacts");
219
220    let mut metric_buffer: Vec<MetricRow> = Vec::with_capacity(flush_interval_rows * 2);
221    let mut log_lines: Vec<String> = Vec::new();
222    let mut flush_ticker = interval(Duration::from_millis(flush_interval_ms));
223    flush_ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
224
225    let started_at = Utc::now();
226
227    loop {
228        tokio::select! {
229            // Prioritize incoming commands
230            biased;
231
232            cmd = receiver.recv() => {
233                match cmd {
234                    None => {
235                        // Channel closed — flush and exit
236                        flush_metrics(&metrics_path, &mut metric_buffer);
237                        flush_logs(&log_path, &mut log_lines);
238                        break;
239                    }
240                    Some(LogCommand::Metric(row)) => {
241                        metric_buffer.push(row);
242                        if metric_buffer.len() >= flush_interval_rows {
243                            flush_metrics(&metrics_path, &mut metric_buffer);
244                        }
245                    }
246                    Some(LogCommand::Params(params)) => {
247                        handle_params(&config_path, params);
248                    }
249                    Some(LogCommand::Artifact(path)) => {
250                        handle_artifact(&artifacts_dir, path);
251                    }
252                    Some(LogCommand::Log { level, message }) => {
253                        let ts = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ");
254                        let level_str = match level {
255                            LogLevel::Info => "INFO",
256                            LogLevel::Warn => "WARN",
257                            LogLevel::Error => "ERROR",
258                        };
259                        log_lines.push(format!("[{ts}] [{level_str}] {message}"));
260                        if log_lines.len() >= 20 {
261                            flush_logs(&log_path, &mut log_lines);
262                        }
263                    }
264                    Some(LogCommand::Flush(reply)) => {
265                        flush_metrics(&metrics_path, &mut metric_buffer);
266                        flush_logs(&log_path, &mut log_lines);
267                        let _ = reply.send(Ok(()));
268                    }
269                    Some(LogCommand::Shutdown { status, reply }) => {
270                        // Final flush
271                        flush_metrics(&metrics_path, &mut metric_buffer);
272                        flush_logs(&log_path, &mut log_lines);
273
274                        // Update run metadata with final status
275                        let finished_at = Utc::now();
276                        let duration = (finished_at - started_at).num_milliseconds() as f64 / 1000.0;
277
278                        if let Ok(mut meta) = storage::load_run_metadata(&run_dir) {
279                            meta.status = status;
280                            meta.finished_at = Some(finished_at);
281                            meta.duration_secs = Some(duration);
282                            let _ = storage::save_run_metadata(&run_dir, &meta);
283                        }
284
285                        let _ = reply.send(());
286                        break;
287                    }
288                }
289            }
290
291            // Periodic flush
292            _ = flush_ticker.tick() => {
293                if !metric_buffer.is_empty() {
294                    flush_metrics(&metrics_path, &mut metric_buffer);
295                }
296                if !log_lines.is_empty() {
297                    flush_logs(&log_path, &mut log_lines);
298                }
299            }
300        }
301    }
302}
303
304fn flush_metrics(path: &std::path::Path, buffer: &mut Vec<MetricRow>) {
305    if buffer.is_empty() {
306        return;
307    }
308    if let Err(e) = storage::append_metrics(path, buffer) {
309        error!("Failed to flush metrics: {}", e);
310    }
311    buffer.clear();
312}
313
314fn flush_logs(path: &std::path::Path, lines: &mut Vec<String>) {
315    if lines.is_empty() {
316        return;
317    }
318    use std::io::Write;
319    match fs::OpenOptions::new().create(true).append(true).open(path) {
320        Ok(mut f) => {
321            for line in lines.iter() {
322                let _ = writeln!(f, "{}", line);
323            }
324        }
325        Err(e) => error!("Failed to write log: {}", e),
326    }
327    lines.clear();
328}
329
330fn handle_params(config_path: &std::path::Path, new_params: HashMap<String, serde_yaml::Value>) {
331    // Load existing, merge, save
332    let mut existing: HashMap<String, serde_yaml::Value> =
333        storage::load_yaml(config_path).unwrap_or_default();
334    existing.extend(new_params);
335    if let Err(e) = storage::save_yaml(config_path, &existing) {
336        error!("Failed to save params: {}", e);
337    }
338}
339
340fn handle_artifact(artifacts_dir: &std::path::Path, path: PathBuf) {
341    // If path is absolute, join() will replace artifacts_dir.
342    // We want to save the file into artifacts_dir, preserving its filename.
343    let dest = if path.is_absolute() {
344        if let Some(filename) = path.file_name() {
345            artifacts_dir.join(filename)
346        } else {
347            error!("Invalid artifact path: {}", path.display());
348            return;
349        }
350    } else {
351        artifacts_dir.join(&path)
352    };
353
354    if let Some(parent) = dest.parent() {
355        if let Err(e) = fs::create_dir_all(parent) {
356            error!("Failed to create artifact dir: {}", e);
357            return;
358        }
359    }
360    if let Err(e) = fs::copy(&path, &dest) {
361        error!(
362            "Failed to copy artifact {} -> {}: {}",
363            path.display(),
364            dest.display(),
365            e
366        );
367    }
368}