#![allow(clippy::useless_conversion)]
use std::io::Cursor;
use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyList};
use pythonize::{depythonize, pythonize};
use crate::agentlog::{hash, parser, writer, Record};
use crate::diff::{
compute_report,
cost::{ModelPricing, Pricing},
embedder::{BoxedEmbedder, Embedder},
semantic::compute_with_embedder,
};
#[pyfunction]
fn parse_agentlog<'py>(
py: Python<'py>,
data: &Bound<'py, PyBytes>,
) -> PyResult<Bound<'py, PyList>> {
let bytes = data.as_bytes();
let records =
parser::parse_all(Cursor::new(bytes)).map_err(|e| PyValueError::new_err(e.to_string()))?;
let out = PyList::empty_bound(py);
for r in records {
let v = serde_json::to_value(&r).map_err(|e| PyValueError::new_err(e.to_string()))?;
let obj = pythonize(py, &v).map_err(|e| PyValueError::new_err(e.to_string()))?;
out.append(obj)?;
}
Ok(out)
}
#[pyfunction]
fn write_agentlog<'py>(
py: Python<'py>,
records: &Bound<'py, PyList>,
) -> PyResult<Bound<'py, PyBytes>> {
let mut parsed: Vec<Record> = Vec::with_capacity(records.len());
for item in records.iter() {
let v: serde_json::Value =
depythonize(&item).map_err(|e| PyValueError::new_err(e.to_string()))?;
let r: Record =
serde_json::from_value(v).map_err(|e| PyValueError::new_err(e.to_string()))?;
parsed.push(r);
}
let mut buf = Vec::new();
writer::write_all(&mut buf, &parsed).map_err(|e| PyIOError::new_err(e.to_string()))?;
Ok(PyBytes::new_bound(py, &buf))
}
/// Return the canonical byte encoding of a Python payload.
///
/// The payload is converted to a JSON value with `depythonize` and then
/// encoded by `crate::agentlog::canonical::to_bytes`.
///
/// Raises `ValueError` if the payload cannot be converted to a JSON value.
#[pyfunction]
fn canonical_bytes<'py>(
    py: Python<'py>,
    payload: &Bound<'py, PyAny>,
) -> PyResult<Bound<'py, PyBytes>> {
    depythonize::<serde_json::Value>(payload)
        .map(|value| crate::agentlog::canonical::to_bytes(&value))
        .map(|encoded| PyBytes::new_bound(py, &encoded))
        .map_err(|err| PyValueError::new_err(err.to_string()))
}
/// Compute the content id of a Python payload.
///
/// The payload is converted to a JSON value with `depythonize` and passed to
/// `hash::content_id`; the resulting string is returned unchanged.
///
/// Raises `ValueError` if the payload cannot be converted to a JSON value.
#[pyfunction]
fn content_id(payload: &Bound<'_, PyAny>) -> PyResult<String> {
    depythonize::<serde_json::Value>(payload)
        .map(|value| hash::content_id(&value))
        .map_err(|err| PyValueError::new_err(err.to_string()))
}
/// Build a diff report comparing baseline records against candidate records.
///
/// `baseline` and `candidate` are lists of record-shaped objects. `pricing`
/// is an optional dict keyed by string; each value is either a 2-tuple of
/// floats (passed to `ModelPricing::simple`) or any object that deserializes
/// into a `ModelPricing`. `seed` is forwarded to `compute_report` as-is.
///
/// Returns the report converted to Python objects. Raises `ValueError` when
/// a record, a pricing key, or a pricing value cannot be converted.
#[pyfunction]
#[pyo3(signature = (baseline, candidate, pricing=None, seed=None))]
fn compute_diff_report<'py>(
    py: Python<'py>,
    baseline: &Bound<'py, PyList>,
    candidate: &Bound<'py, PyList>,
    pricing: Option<&Bound<'py, PyDict>>,
    seed: Option<u64>,
) -> PyResult<Bound<'py, PyAny>> {
    let baseline_records = pylist_to_records(baseline)?;
    let candidate_records = pylist_to_records(candidate)?;

    // Both fallback conversion steps report failures under the same prefix.
    let bad_value =
        |err: &dyn std::fmt::Display| PyValueError::new_err(format!("pricing value: {err}"));

    let mut price_map = Pricing::new();
    if let Some(entries) = pricing {
        for (model, spec) in entries.iter() {
            let name = model
                .extract::<String>()
                .map_err(|err| PyValueError::new_err(format!("pricing key: {err}")))?;
            let model_pricing = match spec.extract::<(f64, f64)>() {
                // Shorthand form: a plain pair of floats.
                Ok((first, second)) => ModelPricing::simple(first, second),
                // Anything else must deserialize into a full `ModelPricing`.
                Err(_) => {
                    let json: serde_json::Value =
                        depythonize(&spec).map_err(|err| bad_value(&err))?;
                    serde_json::from_value(json).map_err(|err| bad_value(&err))?
                }
            };
            price_map.insert(name, model_pricing);
        }
    }

    let report = compute_report(&baseline_records, &candidate_records, &price_map, seed);
    let json = serde_json::to_value(&report)
        .map_err(|err| PyValueError::new_err(err.to_string()))?;
    pythonize(py, &json).map_err(|err| PyValueError::new_err(err.to_string()))
}
#[pyfunction]
#[pyo3(signature = (baseline, candidate, embedder, seed=None))]
fn compute_semantic_axis_with_embedder<'py>(
py: Python<'py>,
baseline: &Bound<'py, PyList>,
candidate: &Bound<'py, PyList>,
embedder: &Bound<'py, PyAny>,
seed: Option<u64>,
) -> PyResult<Bound<'py, PyAny>> {
let baseline_records = pylist_to_records(baseline)?;
let candidate_records = pylist_to_records(candidate)?;
if !embedder.is_callable() {
return Err(PyValueError::new_err(
"embedder must be callable: fn(list[str]) -> list[list[float]]",
));
}
let pairs = pair_responses(&baseline_records, &candidate_records);
let embedder_obj: Py<PyAny> = embedder.clone().unbind();
let py_embedder = BoxedEmbedder::named(
move |texts: &[&str]| -> Vec<Vec<f32>> {
Python::with_gil(|py| {
let owned: Vec<String> = texts.iter().map(|s| (*s).to_string()).collect();
let py_list = PyList::new_bound(py, &owned);
let result = embedder_obj.call1(py, (py_list,));
let any = match result {
Ok(v) => v,
Err(_) => return Vec::new(),
};
let bound = any.bind(py);
bound.extract::<Vec<Vec<f32>>>().unwrap_or_default()
})
},
"py-callback",
);
let pair_refs: Vec<(&Record, &Record)> = pairs.iter().map(|(a, b)| (*a, *b)).collect();
let stat = compute_with_embedder(&pair_refs, &py_embedder as &dyn Embedder, seed);
let v = serde_json::to_value(&stat).map_err(|e| PyValueError::new_err(e.to_string()))?;
pythonize(py, &v).map_err(|e| PyValueError::new_err(e.to_string()))
}
/// Pair up chat responses from two record slices by position.
///
/// Only records whose `kind` is `Kind::ChatResponse` are considered. The
/// n-th response in `baseline` is paired with the n-th response in
/// `candidate`; pairing stops as soon as either side runs out, so surplus
/// responses on the longer side are dropped.
fn pair_responses<'a>(
    baseline: &'a [Record],
    candidate: &'a [Record],
) -> Vec<(&'a Record, &'a Record)> {
    use crate::agentlog::Kind;

    let baseline_responses = baseline
        .iter()
        .filter(|rec| rec.kind == Kind::ChatResponse);
    let candidate_responses = candidate
        .iter()
        .filter(|rec| rec.kind == Kind::ChatResponse);

    baseline_responses.zip(candidate_responses).collect()
}
/// Convert a Python list of record-shaped objects into `Record` values.
///
/// Each item is first turned into a JSON value with `depythonize` and then
/// deserialized into a `Record`. Conversion stops at the first failing item.
///
/// # Errors
///
/// Returns a Python `ValueError` carrying the underlying error message when
/// an item cannot be converted to JSON or does not deserialize as a `Record`.
fn pylist_to_records(list: &Bound<'_, PyList>) -> PyResult<Vec<Record>> {
    list.iter()
        .map(|item| {
            let json: serde_json::Value =
                depythonize(&item).map_err(|err| PyValueError::new_err(err.to_string()))?;
            serde_json::from_value::<Record>(json)
                .map_err(|err| PyValueError::new_err(err.to_string()))
        })
        .collect()
}
/// Native extension module `_core`.
///
/// Exposes two module-level constants (`__version__`, `SPEC_VERSION`) and the
/// agentlog / diff functions defined in this file.
#[pymodule]
fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
// Module-level constants, taken from crate-level values.
m.add("__version__", crate::VERSION)?;
m.add("SPEC_VERSION", crate::agentlog::CURRENT_VERSION)?;
// Agentlog parsing, writing and hashing helpers.
m.add_function(wrap_pyfunction!(parse_agentlog, m)?)?;
m.add_function(wrap_pyfunction!(write_agentlog, m)?)?;
m.add_function(wrap_pyfunction!(canonical_bytes, m)?)?;
m.add_function(wrap_pyfunction!(content_id, m)?)?;
// Diff entry points.
m.add_function(wrap_pyfunction!(compute_diff_report, m)?)?;
m.add_function(wrap_pyfunction!(compute_semantic_axis_with_embedder, m)?)?;
Ok(())
}