use crate::models::PyDecisionSnapshot;
use crate::runtime::{PythonAsyncExt, PythonAsyncVecExt};
use crate::storage::{PyLakeFSBackend, PySqliteBackend, PyStorageBackend};
use briefcase_core::{
ReplayEngine, ReplayMode, ReplayPolicy, ReplayResult, ReplayStats, ReplayStatus,
};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
#[pyclass(name = "ReplayEngine")]
pub struct PyReplayEngine {
pub inner: ReplayEngine<PyStorageBackend>,
}
#[pymethods]
impl PyReplayEngine {
#[new]
fn new(storage: PyObject) -> PyResult<Self> {
Python::with_gil(|py| {
if let Ok(sqlite_backend) = storage.extract::<PyRef<PySqliteBackend>>(py) {
let unified_backend = PyStorageBackend::Sqlite(sqlite_backend.inner.clone());
Ok(Self {
inner: ReplayEngine::new(unified_backend),
})
} else if let Ok(lakefs_backend) = storage.extract::<PyRef<PyLakeFSBackend>>(py) {
let unified_backend = PyStorageBackend::LakeFS(lakefs_backend.inner.clone());
Ok(Self {
inner: ReplayEngine::new(unified_backend),
})
} else {
Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
"Storage must be SqliteBackend or LakeFSBackend",
))
}
})
}
#[getter]
fn default_mode(&self) -> String {
match self.inner.default_mode() {
ReplayMode::Strict => "strict".to_string(),
ReplayMode::Tolerant => "tolerant".to_string(),
ReplayMode::ValidationOnly => "validation_only".to_string(),
}
}
fn replay(&self, snapshot_id: String, mode: Option<String>) -> PyResult<PyReplayResult> {
let replay_mode = if let Some(mode_str) = mode {
Some(match mode_str.as_str() {
"strict" => ReplayMode::Strict,
"tolerant" => ReplayMode::Tolerant,
"validation_only" => ReplayMode::ValidationOnly,
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Invalid replay mode: {}",
mode_str
)))
}
})
} else {
None
};
let engine = self.inner.clone();
engine
.replay(&snapshot_id, replay_mode, None)
.block_on_python()
.map(|result| PyReplayResult { inner: result })
}
fn replay_with_policy(
&self,
snapshot_id: String,
policy: PyRef<PyReplayPolicy>,
mode: Option<String>,
) -> PyResult<PyReplayResult> {
let replay_mode = if let Some(mode_str) = mode {
Some(match mode_str.as_str() {
"strict" => ReplayMode::Strict,
"tolerant" => ReplayMode::Tolerant,
"validation_only" => ReplayMode::ValidationOnly,
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Invalid replay mode: {}",
mode_str
)))
}
})
} else {
None
};
let engine = self.inner.clone();
let policy_inner = policy.inner.clone();
engine
.replay_with_policy(&snapshot_id, &policy_inner, replay_mode)
.block_on_python()
.map(|result| PyReplayResult { inner: result })
}
fn replay_batch(
&self,
snapshot_ids: Vec<String>,
mode: Option<String>,
max_concurrent: Option<usize>,
) -> PyResult<PyObject> {
let replay_mode = if let Some(mode_str) = mode {
Some(match mode_str.as_str() {
"strict" => ReplayMode::Strict,
"tolerant" => ReplayMode::Tolerant,
"validation_only" => ReplayMode::ValidationOnly,
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Invalid replay mode: {}",
mode_str
)))
}
})
} else {
None
};
let engine = self.inner.clone();
let results = PythonAsyncVecExt::block_on_python(engine.replay_batch(
&snapshot_ids,
replay_mode,
max_concurrent.unwrap_or(4),
))?;
Python::with_gil(|py| {
let list = PyList::empty(py);
for replay_result in results {
let py_result = PyReplayResult {
inner: replay_result,
};
list.append(Py::new(py, py_result)?)?;
}
Ok(list.into())
})
}
fn validate(&self, snapshot_id: String, policy: PyRef<PyReplayPolicy>) -> PyResult<PyObject> {
let engine = self.inner.clone();
let policy_inner = policy.inner.clone();
let violations = engine
.validate(&snapshot_id, &policy_inner)
.block_on_python()?;
Python::with_gil(|py| {
let list = PyList::empty(py);
for violation in violations {
let violation_dict = PyDict::new(py);
violation_dict.set_item("rule_name", &violation.rule_name)?;
violation_dict.set_item("field", &violation.field)?;
violation_dict.set_item("expected", &violation.expected)?;
violation_dict.set_item("actual", &violation.actual)?;
violation_dict.set_item("message", &violation.message)?;
list.append(violation_dict)?;
}
Ok(list.into())
})
}
fn get_replay_stats(&self, snapshot_ids: Vec<String>) -> PyResult<PyReplayStats> {
let engine = self.inner.clone();
engine
.get_replay_stats(&snapshot_ids)
.block_on_python()
.map(|stats| PyReplayStats { inner: stats })
}
fn __repr__(&self) -> String {
"ReplayEngine()".to_string()
}
}
/// Python-facing wrapper around a core replay policy, exposed to Python as `ReplayPolicy`.
///
/// The `with_*` methods mutate the policy in place and return the same object, so rule
/// additions can be chained from Python.
#[pyclass(name = "ReplayPolicy")]
pub struct PyReplayPolicy {
    pub inner: ReplayPolicy,
}

/// Moves the policy out of `slot` so a consuming builder method can be applied to it.
///
/// A freshly constructed, unnamed policy is left behind as a placeholder. This avoids
/// deep-cloning the existing policy (and its rule list) on every builder call. Callers must
/// write the builder's result back into `slot` straight away.
fn take_replay_policy(slot: &mut ReplayPolicy) -> ReplayPolicy {
    std::mem::replace(slot, ReplayPolicy::new(String::new()))
}

#[pymethods]
impl PyReplayPolicy {
    /// Creates a policy named `name` via `ReplayPolicy::new`.
    #[new]
    fn new(name: String) -> Self {
        Self {
            inner: ReplayPolicy::new(name),
        }
    }

    /// Adds an exact-match rule for `field` and returns `self` for chaining.
    fn with_exact_match(mut slf: PyRefMut<Self>, field: String) -> PyRefMut<Self> {
        let policy = take_replay_policy(&mut slf.inner);
        slf.inner = policy.with_exact_match(field);
        slf
    }

    /// Adds a similarity-threshold rule for `field` and returns `self` for chaining.
    ///
    /// `threshold` is forwarded unchanged to the core policy. Its valid range is defined
    /// there, not here.
    fn with_similarity_threshold(
        mut slf: PyRefMut<Self>,
        field: String,
        threshold: f64,
    ) -> PyRefMut<Self> {
        let policy = take_replay_policy(&mut slf.inner);
        slf.inner = policy.with_similarity_threshold(field, threshold);
        slf
    }

    /// The policy's name.
    #[getter]
    fn name(&self) -> String {
        self.inner.name.clone()
    }

    /// Number of rules currently in the policy.
    #[getter]
    fn rule_count(&self) -> usize {
        self.inner.rules.len()
    }

    fn __repr__(&self) -> String {
        format!(
            "ReplayPolicy(name='{}', rules={})",
            self.inner.name,
            self.inner.rules.len()
        )
    }
}
/// Outcome of replaying a single snapshot, exposed to Python as `ReplayResult`.
#[pyclass(name = "ReplayResult")]
pub struct PyReplayResult {
    pub inner: ReplayResult,
}

#[pymethods]
impl PyReplayResult {
    /// Lower-case status name: `"success"`, `"failed"`, `"partial"`, `"pending"` or `"running"`.
    #[getter]
    fn status(&self) -> String {
        let label = match self.inner.status {
            ReplayStatus::Success => "success",
            ReplayStatus::Failed => "failed",
            ReplayStatus::Partial => "partial",
            ReplayStatus::Pending => "pending",
            ReplayStatus::Running => "running",
        };
        String::from(label)
    }

    /// A copy of the snapshot that was replayed.
    #[getter]
    fn original_snapshot(&self) -> PyDecisionSnapshot {
        let inner = self.inner.original_snapshot.clone();
        PyDecisionSnapshot { inner }
    }

    /// Whether the replayed output matched the original.
    #[getter]
    fn outputs_match(&self) -> bool {
        self.inner.outputs_match
    }

    /// Wall-clock duration of the replay in milliseconds.
    #[getter]
    fn execution_time_ms(&self) -> f64 {
        self.inner.execution_time_ms
    }

    /// Policy violations as a list of dicts.
    ///
    /// Each dict has the keys `rule_name`, `field`, `expected`, `actual` and `message`.
    #[getter]
    fn policy_violations(&self) -> PyResult<PyObject> {
        Python::with_gil(|py| {
            let out = PyList::empty(py);
            self.inner
                .policy_violations
                .iter()
                .try_for_each(|v| {
                    let entry = PyDict::new(py);
                    entry.set_item("rule_name", &v.rule_name)?;
                    entry.set_item("field", &v.field)?;
                    entry.set_item("expected", &v.expected)?;
                    entry.set_item("actual", &v.actual)?;
                    entry.set_item("message", &v.message)?;
                    out.append(entry)
                })?;
            Ok(out.into())
        })
    }

    /// The replayed output converted to Python, or `None` when there is none.
    #[getter]
    fn replay_output(&self) -> PyResult<PyObject> {
        Python::with_gil(|py| match self.inner.replay_output.as_ref() {
            Some(output) => crate::models::json_value_to_python(output, py),
            None => Ok(py.None()),
        })
    }

    /// Summary dict with `status`, `outputs_match`, `execution_time_ms`, `policy_violations`,
    /// and `replay_output` only when an output exists.
    ///
    /// The original snapshot is not included.
    fn to_dict(&self) -> PyResult<PyObject> {
        Python::with_gil(|py| {
            let out = PyDict::new(py);
            out.set_item("status", self.status())?;
            out.set_item("outputs_match", self.inner.outputs_match)?;
            out.set_item("execution_time_ms", self.inner.execution_time_ms)?;
            if let Some(output) = &self.inner.replay_output {
                let converted = crate::models::json_value_to_python(output, py)?;
                out.set_item("replay_output", converted)?;
            }
            let violations = self.policy_violations()?;
            out.set_item("policy_violations", violations)?;
            Ok(out.into())
        })
    }

    fn __repr__(&self) -> String {
        let status = self.status();
        let matched = self.inner.outputs_match;
        format!(
            "ReplayResult(status='{}', outputs_match={})",
            status, matched
        )
    }
}
/// Aggregate statistics over a set of replays, exposed to Python as `ReplayStats`.
#[pyclass(name = "ReplayStats")]
pub struct PyReplayStats {
    pub inner: ReplayStats,
}

#[pymethods]
impl PyReplayStats {
    /// Number of replays counted.
    #[getter]
    fn total_replays(&self) -> usize {
        self.inner.total_replays
    }

    /// Number of replays counted as successful.
    #[getter]
    fn successful_replays(&self) -> usize {
        self.inner.successful_replays
    }

    /// Number of replays counted as failed.
    #[getter]
    fn failed_replays(&self) -> usize {
        self.inner.failed_replays
    }

    /// Number of replays whose output matched exactly.
    #[getter]
    fn exact_matches(&self) -> usize {
        self.inner.exact_matches
    }

    /// Number of replays whose output did not match.
    #[getter]
    fn mismatches(&self) -> usize {
        self.inner.mismatches
    }

    /// Mean replay duration in milliseconds, as reported by the core stats.
    #[getter]
    fn average_execution_time_ms(&self) -> f64 {
        self.inner.average_execution_time_ms
    }

    /// Summed replay duration in milliseconds, as reported by the core stats.
    #[getter]
    fn total_execution_time_ms(&self) -> f64 {
        self.inner.total_execution_time_ms
    }

    /// Fraction of replays that succeeded, in `[0, 1]`. Returns `0.0` when no replays were counted.
    #[getter]
    fn success_rate(&self) -> f64 {
        match self.inner.total_replays {
            0 => 0.0,
            total => self.inner.successful_replays as f64 / total as f64,
        }
    }

    /// All statistics, including the derived `success_rate`, as a Python dict.
    fn to_dict(&self) -> PyResult<PyObject> {
        let stats = &self.inner;
        // Insertion order is kept stable: the counters first, then the timing-derived values.
        let counters = [
            ("total_replays", stats.total_replays),
            ("successful_replays", stats.successful_replays),
            ("failed_replays", stats.failed_replays),
            ("exact_matches", stats.exact_matches),
            ("mismatches", stats.mismatches),
        ];
        let ratios = [
            ("average_execution_time_ms", stats.average_execution_time_ms),
            ("total_execution_time_ms", stats.total_execution_time_ms),
            ("success_rate", self.success_rate()),
        ];
        Python::with_gil(|py| {
            let out = PyDict::new(py);
            for &(key, value) in counters.iter() {
                out.set_item(key, value)?;
            }
            for &(key, value) in ratios.iter() {
                out.set_item(key, value)?;
            }
            Ok(out.into())
        })
    }

    fn __repr__(&self) -> String {
        let percent = self.success_rate() * 100.0;
        format!(
            "ReplayStats(total={}, success_rate={:.1}%)",
            self.inner.total_replays, percent
        )
    }
}