use crate::distributed::{RunInfo, RunStatus};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde_json;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Errors reported by [`RunInfoStore`] operations.
#[derive(Debug, thiserror::Error)]
pub enum RunInfoError {
    /// An I/O failure, carried as a human-readable message.
    ///
    /// Not produced by the in-memory store in this file.
    #[error("I/O error: {0}")]
    Io(String),

    /// No run is recorded under the requested id.
    ///
    /// The in-memory store returns this from its `update_*` methods when the
    /// `run_id` is absent from its map.
    #[error("Not found")]
    NotFound,

    /// Any other failure, carried as a human-readable message.
    ///
    /// Not produced by the in-memory store in this file.
    #[error("Other error: {0}")]
    Other(String),
}
/// Asynchronous storage interface for [`RunInfo`] records addressed by run id.
///
/// Every method takes `&self`, so an implementation that mutates state has to
/// provide its own interior mutability (the in-memory store in this file uses
/// an async mutex).
#[async_trait]
pub trait RunInfoStore {
    /// Records `info` in the store.
    ///
    /// How an already-present run id is treated is up to the implementation;
    /// [`InMemoryRunInfoStore`] replaces the existing record.
    async fn insert_run(&self, info: RunInfo) -> Result<(), RunInfoError>;

    /// Sets the status of the run identified by `run_id` to `status`.
    ///
    /// [`InMemoryRunInfoStore`] returns [`RunInfoError::NotFound`] when the
    /// id is unknown.
    async fn update_status(&self, run_id: &str, status: RunStatus) -> Result<(), RunInfoError>;

    /// Sets the completion timestamp of the run identified by `run_id`.
    ///
    /// [`InMemoryRunInfoStore`] returns [`RunInfoError::NotFound`] when the
    /// id is unknown.
    async fn update_finished_at(
        &self,
        run_id: &str,
        finished_at: DateTime<Utc>,
    ) -> Result<(), RunInfoError>;

    /// Looks up the run identified by `run_id`.
    ///
    /// The `Option` in the return type lets an implementation report an
    /// absent run as `Ok(None)` rather than as an error;
    /// [`InMemoryRunInfoStore`] does exactly that.
    async fn get_run(&self, run_id: &str) -> Result<Option<RunInfo>, RunInfoError>;

    /// Lists stored runs, optionally restricted by status.
    ///
    /// In [`InMemoryRunInfoStore`], `None` returns every run and
    /// `Some(status)` returns only the runs whose status equals `status`.
    async fn list_runs(&self, filter: Option<RunStatus>) -> Result<Vec<RunInfo>, RunInfoError>;

    /// Sets the JSON output of the run identified by `run_id`.
    ///
    /// [`InMemoryRunInfoStore`] returns [`RunInfoError::NotFound`] when the
    /// id is unknown.
    async fn update_output(
        &self,
        run_id: &str,
        output: serde_json::Value,
    ) -> Result<(), RunInfoError>;
}
/// [`RunInfoStore`] implementation that keeps every record in a process-local
/// `HashMap` keyed by run id.
///
/// The map sits behind `Arc<Mutex<..>>`, so cloning the store yields another
/// handle to the *same* map rather than an independent copy. `Default`
/// produces a store with an empty map.
#[derive(Clone, Default)]
pub struct InMemoryRunInfoStore {
    /// Shared, mutex-guarded map from run id to its record.
    inner: Arc<Mutex<HashMap<String, RunInfo>>>,
}
/// In-memory implementation: every method locks the shared map for the
/// duration of the call and never returns an error other than
/// [`RunInfoError::NotFound`].
#[async_trait]
impl RunInfoStore for InMemoryRunInfoStore {
    /// Stores `info` under its own `run_id`, replacing any record already
    /// held for that id. Always succeeds.
    async fn insert_run(&self, info: RunInfo) -> Result<(), RunInfoError> {
        // Clone the key first: `info` itself is moved into the map.
        let key = info.run_id.clone();
        self.inner.lock().await.insert(key, info);
        Ok(())
    }

    /// Overwrites the `status` of the run stored under `run_id`.
    ///
    /// # Errors
    ///
    /// Returns [`RunInfoError::NotFound`] if no run is stored under `run_id`.
    async fn update_status(&self, run_id: &str, status: RunStatus) -> Result<(), RunInfoError> {
        let mut runs = self.inner.lock().await;
        let Some(record) = runs.get_mut(run_id) else {
            return Err(RunInfoError::NotFound);
        };
        record.status = status;
        Ok(())
    }

    /// Sets `finished_at` of the run stored under `run_id` to
    /// `Some(finished_at)`.
    ///
    /// # Errors
    ///
    /// Returns [`RunInfoError::NotFound`] if no run is stored under `run_id`.
    async fn update_finished_at(
        &self,
        run_id: &str,
        finished_at: DateTime<Utc>,
    ) -> Result<(), RunInfoError> {
        let mut runs = self.inner.lock().await;
        let Some(record) = runs.get_mut(run_id) else {
            return Err(RunInfoError::NotFound);
        };
        record.finished_at = Some(finished_at);
        Ok(())
    }

    /// Returns a clone of the run stored under `run_id`, or `Ok(None)` when
    /// the id is unknown. Never returns an error.
    async fn get_run(&self, run_id: &str) -> Result<Option<RunInfo>, RunInfoError> {
        let runs = self.inner.lock().await;
        let found = runs.get(run_id).cloned();
        Ok(found)
    }

    /// Returns clones of the stored runs, in the map's iteration order.
    ///
    /// With `filter == None` every run is returned; with `Some(status)` only
    /// runs whose `status` equals it are returned. Never returns an error.
    async fn list_runs(&self, filter: Option<RunStatus>) -> Result<Vec<RunInfo>, RunInfoError> {
        let runs = self.inner.lock().await;
        let mut selected = Vec::new();
        for record in runs.values() {
            let keep = match &filter {
                Some(wanted) => *wanted == record.status,
                None => true,
            };
            if keep {
                selected.push(record.clone());
            }
        }
        Ok(selected)
    }

    /// Sets `output` of the run stored under `run_id` to `Some(output)`.
    ///
    /// # Errors
    ///
    /// Returns [`RunInfoError::NotFound`] if no run is stored under `run_id`.
    async fn update_output(
        &self,
        run_id: &str,
        output: serde_json::Value,
    ) -> Result<(), RunInfoError> {
        let mut runs = self.inner.lock().await;
        let Some(record) = runs.get_mut(run_id) else {
            return Err(RunInfoError::NotFound);
        };
        record.output = Some(output);
        Ok(())
    }
}