1use pyo3::prelude::*;
6use pyo3::exceptions::{PyValueError, PyRuntimeError};
7use std::collections::HashMap;
8
9use mpl_core::{
10 hash::{canonicalize as rust_canonicalize, semantic_hash as rust_semantic_hash},
11 qom::{QomMetrics as RustQomMetrics, QomProfile as RustQomProfile},
12 stype::SType as RustSType,
13 validation::SchemaValidator as RustSchemaValidator,
14};
15
16#[pyclass(name = "SType")]
18#[derive(Clone)]
19pub struct PySType {
20 inner: RustSType,
21}
22
23#[pymethods]
24impl PySType {
25 #[new]
27 fn new(stype_str: &str) -> PyResult<Self> {
28 RustSType::parse(stype_str)
29 .map(|inner| Self { inner })
30 .map_err(|e| PyValueError::new_err(e.to_string()))
31 }
32
33 #[staticmethod]
35 fn create(namespace: &str, domain: &str, name: &str, major_version: u32) -> Self {
36 Self {
37 inner: RustSType::new(namespace, domain, name, major_version),
38 }
39 }
40
41 #[getter]
43 fn namespace(&self) -> &str {
44 &self.inner.namespace
45 }
46
47 #[getter]
49 fn domain(&self) -> &str {
50 &self.inner.domain
51 }
52
53 #[getter]
55 fn name(&self) -> &str {
56 &self.inner.name
57 }
58
59 #[getter]
61 fn major_version(&self) -> u32 {
62 self.inner.major_version
63 }
64
65 fn id(&self) -> String {
67 self.inner.id()
68 }
69
70 fn urn(&self) -> String {
72 self.inner.urn()
73 }
74
75 fn registry_path(&self) -> String {
77 self.inner.registry_path()
78 }
79
80 fn __str__(&self) -> String {
81 self.inner.id()
82 }
83
84 fn __repr__(&self) -> String {
85 format!("SType('{}')", self.inner.id())
86 }
87}
88
89#[pyclass(name = "SchemaValidator")]
91pub struct PySchemaValidator {
92 inner: RustSchemaValidator,
93}
94
95#[pymethods]
96impl PySchemaValidator {
97 #[new]
98 fn new() -> Self {
99 Self {
100 inner: RustSchemaValidator::new(),
101 }
102 }
103
104 fn register(&mut self, stype: &str, schema_json: &str) -> PyResult<()> {
106 self.inner
107 .register_json(stype, schema_json)
108 .map_err(|e| PyValueError::new_err(e.to_string()))
109 }
110
111 fn has_schema(&self, stype: &str) -> bool {
113 self.inner.has_schema(stype)
114 }
115
116 fn validate(&self, stype: &str, payload_json: &str) -> PyResult<PyValidationResult> {
118 let payload: serde_json::Value = serde_json::from_str(payload_json)
119 .map_err(|e| PyValueError::new_err(format!("Invalid JSON: {}", e)))?;
120
121 let result = self.inner.validate(stype, &payload)
122 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
123
124 Ok(PyValidationResult {
125 valid: result.valid,
126 errors: result.errors.iter().map(|e| PySchemaError {
127 path: e.path.clone(),
128 message: e.message.clone(),
129 }).collect(),
130 })
131 }
132
133 fn validate_or_raise(&self, stype: &str, payload_json: &str) -> PyResult<()> {
135 let payload: serde_json::Value = serde_json::from_str(payload_json)
136 .map_err(|e| PyValueError::new_err(format!("Invalid JSON: {}", e)))?;
137
138 self.inner.validate_or_error(stype, &payload)
139 .map_err(|e| PyValueError::new_err(e.to_string()))
140 }
141
142 fn registered_stypes(&self) -> Vec<String> {
144 self.inner.registered_stypes().iter().map(|s| s.to_string()).collect()
145 }
146}
147
148#[pyclass(name = "ValidationResult")]
150#[derive(Clone)]
151pub struct PyValidationResult {
152 #[pyo3(get)]
153 valid: bool,
154 #[pyo3(get)]
155 errors: Vec<PySchemaError>,
156}
157
158#[pymethods]
159impl PyValidationResult {
160 fn __bool__(&self) -> bool {
161 self.valid
162 }
163
164 fn __repr__(&self) -> String {
165 if self.valid {
166 "ValidationResult(valid=True)".to_string()
167 } else {
168 format!("ValidationResult(valid=False, errors={})", self.errors.len())
169 }
170 }
171}
172
173#[pyclass(name = "SchemaError")]
175#[derive(Clone)]
176pub struct PySchemaError {
177 #[pyo3(get)]
178 path: String,
179 #[pyo3(get)]
180 message: String,
181}
182
183#[pymethods]
184impl PySchemaError {
185 fn __repr__(&self) -> String {
186 format!("SchemaError(path='{}', message='{}')", self.path, self.message)
187 }
188}
189
190#[pyclass(name = "QomMetrics")]
192#[derive(Clone)]
193pub struct PyQomMetrics {
194 #[pyo3(get, set)]
195 schema_fidelity: f64,
196 #[pyo3(get, set)]
197 instruction_compliance: Option<f64>,
198 #[pyo3(get, set)]
199 groundedness: Option<f64>,
200 #[pyo3(get, set)]
201 determinism_jitter: Option<f64>,
202 #[pyo3(get, set)]
203 ontology_adherence: Option<f64>,
204 #[pyo3(get, set)]
205 tool_outcome_correctness: Option<f64>,
206}
207
208#[pymethods]
209impl PyQomMetrics {
210 #[new]
211 #[pyo3(signature = (schema_fidelity=1.0, instruction_compliance=None, groundedness=None, determinism_jitter=None, ontology_adherence=None, tool_outcome_correctness=None))]
212 fn new(
213 schema_fidelity: f64,
214 instruction_compliance: Option<f64>,
215 groundedness: Option<f64>,
216 determinism_jitter: Option<f64>,
217 ontology_adherence: Option<f64>,
218 tool_outcome_correctness: Option<f64>,
219 ) -> Self {
220 Self {
221 schema_fidelity,
222 instruction_compliance,
223 groundedness,
224 determinism_jitter,
225 ontology_adherence,
226 tool_outcome_correctness,
227 }
228 }
229
230 #[staticmethod]
232 fn schema_valid() -> Self {
233 Self {
234 schema_fidelity: 1.0,
235 instruction_compliance: None,
236 groundedness: None,
237 determinism_jitter: None,
238 ontology_adherence: None,
239 tool_outcome_correctness: None,
240 }
241 }
242
243 #[staticmethod]
245 fn schema_invalid() -> Self {
246 Self {
247 schema_fidelity: 0.0,
248 instruction_compliance: None,
249 groundedness: None,
250 determinism_jitter: None,
251 ontology_adherence: None,
252 tool_outcome_correctness: None,
253 }
254 }
255
256 fn to_dict(&self) -> HashMap<String, f64> {
258 let mut map = HashMap::new();
259 map.insert("schema_fidelity".to_string(), self.schema_fidelity);
260 if let Some(ic) = self.instruction_compliance {
261 map.insert("instruction_compliance".to_string(), ic);
262 }
263 if let Some(g) = self.groundedness {
264 map.insert("groundedness".to_string(), g);
265 }
266 if let Some(dj) = self.determinism_jitter {
267 map.insert("determinism_jitter".to_string(), dj);
268 }
269 if let Some(oa) = self.ontology_adherence {
270 map.insert("ontology_adherence".to_string(), oa);
271 }
272 if let Some(toc) = self.tool_outcome_correctness {
273 map.insert("tool_outcome_correctness".to_string(), toc);
274 }
275 map
276 }
277
278 fn __repr__(&self) -> String {
279 format!("QomMetrics(schema_fidelity={:.2})", self.schema_fidelity)
280 }
281}
282
283impl From<PyQomMetrics> for RustQomMetrics {
284 fn from(py: PyQomMetrics) -> Self {
285 RustQomMetrics {
286 schema_fidelity: py.schema_fidelity,
287 instruction_compliance: py.instruction_compliance,
288 groundedness: py.groundedness,
289 determinism_jitter: py.determinism_jitter,
290 ontology_adherence: py.ontology_adherence,
291 tool_outcome_correctness: py.tool_outcome_correctness,
292 }
293 }
294}
295
296#[pyclass(name = "QomProfile")]
298#[derive(Clone)]
299pub struct PyQomProfile {
300 inner: RustQomProfile,
301}
302
303#[pymethods]
304impl PyQomProfile {
305 #[staticmethod]
307 fn basic() -> Self {
308 Self {
309 inner: RustQomProfile::basic(),
310 }
311 }
312
313 #[staticmethod]
315 fn strict_argcheck() -> Self {
316 Self {
317 inner: RustQomProfile::strict_argcheck(),
318 }
319 }
320
321 #[getter]
323 fn name(&self) -> &str {
324 &self.inner.name
325 }
326
327 #[getter]
329 fn description(&self) -> Option<&str> {
330 self.inner.description.as_deref()
331 }
332
333 fn evaluate(&self, metrics: &PyQomMetrics) -> PyQomEvaluation {
335 let rust_metrics: RustQomMetrics = metrics.clone().into();
336 let eval = self.inner.evaluate(&rust_metrics);
337 PyQomEvaluation {
338 meets_profile: eval.meets_profile,
339 profile: eval.profile,
340 failures: eval.failures.iter().map(|f| PyMetricFailure {
341 metric: f.metric.clone(),
342 actual: f.actual,
343 threshold: f.threshold,
344 }).collect(),
345 }
346 }
347
348 fn __repr__(&self) -> String {
349 format!("QomProfile(name='{}')", self.inner.name)
350 }
351}
352
353#[pyclass(name = "QomEvaluation")]
355#[derive(Clone)]
356pub struct PyQomEvaluation {
357 #[pyo3(get)]
358 meets_profile: bool,
359 #[pyo3(get)]
360 profile: String,
361 #[pyo3(get)]
362 failures: Vec<PyMetricFailure>,
363}
364
365#[pymethods]
366impl PyQomEvaluation {
367 fn __bool__(&self) -> bool {
368 self.meets_profile
369 }
370
371 fn __repr__(&self) -> String {
372 if self.meets_profile {
373 format!("QomEvaluation(meets_profile=True, profile='{}')", self.profile)
374 } else {
375 format!("QomEvaluation(meets_profile=False, failures={})", self.failures.len())
376 }
377 }
378}
379
380#[pyclass(name = "MetricFailure")]
382#[derive(Clone)]
383pub struct PyMetricFailure {
384 #[pyo3(get)]
385 metric: String,
386 #[pyo3(get)]
387 actual: f64,
388 #[pyo3(get)]
389 threshold: f64,
390}
391
392#[pymethods]
393impl PyMetricFailure {
394 fn __repr__(&self) -> String {
395 format!(
396 "MetricFailure(metric='{}', actual={:.2}, threshold={:.2})",
397 self.metric, self.actual, self.threshold
398 )
399 }
400}
401
402#[pyfunction]
404fn canonicalize(json_str: &str) -> PyResult<String> {
405 let value: serde_json::Value = serde_json::from_str(json_str)
406 .map_err(|e| PyValueError::new_err(format!("Invalid JSON: {}", e)))?;
407
408 rust_canonicalize(&value)
409 .map_err(|e| PyRuntimeError::new_err(e.to_string()))
410}
411
412#[pyfunction]
414fn semantic_hash(json_str: &str) -> PyResult<String> {
415 let value: serde_json::Value = serde_json::from_str(json_str)
416 .map_err(|e| PyValueError::new_err(format!("Invalid JSON: {}", e)))?;
417
418 rust_semantic_hash(&value)
419 .map_err(|e| PyRuntimeError::new_err(e.to_string()))
420}
421
422#[pyfunction]
424fn verify_hash(json_str: &str, expected_hash: &str) -> PyResult<bool> {
425 let value: serde_json::Value = serde_json::from_str(json_str)
426 .map_err(|e| PyValueError::new_err(format!("Invalid JSON: {}", e)))?;
427
428 mpl_core::hash::verify_hash(&value, expected_hash)
429 .map_err(|e| PyRuntimeError::new_err(e.to_string()))
430}
431
432#[pyclass(name = "MplEnvelope")]
434#[derive(Clone)]
435pub struct PyMplEnvelope {
436 #[pyo3(get)]
437 id: String,
438 #[pyo3(get, set)]
439 stype: String,
440 #[pyo3(get, set)]
441 payload: String, #[pyo3(get, set)]
443 args_stype: Option<String>,
444 #[pyo3(get, set)]
445 profile: Option<String>,
446 #[pyo3(get, set)]
447 sem_hash: Option<String>,
448 #[pyo3(get, set)]
449 features: Vec<String>,
450}
451
452#[pymethods]
453impl PyMplEnvelope {
454 #[new]
455 #[pyo3(signature = (stype, payload, args_stype=None, profile=None))]
456 fn new(
457 stype: String,
458 payload: String,
459 args_stype: Option<String>,
460 profile: Option<String>,
461 ) -> PyResult<Self> {
462 let _: serde_json::Value = serde_json::from_str(&payload)
464 .map_err(|e| PyValueError::new_err(format!("Invalid JSON payload: {}", e)))?;
465
466 Ok(Self {
467 id: uuid::Uuid::new_v4().to_string(),
468 stype,
469 payload,
470 args_stype,
471 profile,
472 sem_hash: None,
473 features: Vec::new(),
474 })
475 }
476
477 fn compute_hash(&mut self) -> PyResult<String> {
479 let hash = semantic_hash(&self.payload)?;
480 self.sem_hash = Some(hash.clone());
481 Ok(hash)
482 }
483
484 fn verify_hash(&self) -> PyResult<bool> {
486 match &self.sem_hash {
487 Some(expected) => verify_hash(&self.payload, expected),
488 None => Ok(true),
489 }
490 }
491
492 fn get_payload(&self) -> PyResult<PyObject> {
494 Python::with_gil(|py| {
495 let value: serde_json::Value = serde_json::from_str(&self.payload)
496 .map_err(|e| PyValueError::new_err(e.to_string()))?;
497 json_to_py(py, &value)
498 })
499 }
500
501 fn to_json(&self) -> PyResult<String> {
503 let payload_value: serde_json::Value = serde_json::from_str(&self.payload)
504 .map_err(|e| PyValueError::new_err(format!("Invalid payload JSON: {}", e)))?;
505 let envelope = serde_json::json!({
506 "id": self.id,
507 "stype": self.stype,
508 "payload": payload_value,
509 "args_stype": self.args_stype,
510 "profile": self.profile,
511 "sem_hash": self.sem_hash,
512 "features": self.features,
513 });
514 serde_json::to_string_pretty(&envelope)
515 .map_err(|e| PyRuntimeError::new_err(e.to_string()))
516 }
517
518 fn __repr__(&self) -> String {
519 format!("MplEnvelope(id='{}', stype='{}')", self.id, self.stype)
520 }
521}
522
523fn json_to_py(py: Python<'_>, value: &serde_json::Value) -> PyResult<PyObject> {
525 match value {
526 serde_json::Value::Null => Ok(py.None()),
527 serde_json::Value::Bool(b) => Ok(b.into_py(py)),
528 serde_json::Value::Number(n) => {
529 if let Some(i) = n.as_i64() {
530 Ok(i.into_py(py))
531 } else if let Some(f) = n.as_f64() {
532 Ok(f.into_py(py))
533 } else {
534 Ok(py.None())
535 }
536 }
537 serde_json::Value::String(s) => Ok(s.into_py(py)),
538 serde_json::Value::Array(arr) => {
539 let list: Vec<PyObject> = arr.iter()
540 .map(|v| json_to_py(py, v))
541 .collect::<PyResult<_>>()?;
542 Ok(list.into_py(py))
543 }
544 serde_json::Value::Object(map) => {
545 let dict = pyo3::types::PyDict::new_bound(py);
546 for (k, v) in map {
547 dict.set_item(k, json_to_py(py, v)?)?;
548 }
549 Ok(dict.into())
550 }
551 }
552}
553
554#[pymodule]
556fn _mpl_core(m: &Bound<'_, PyModule>) -> PyResult<()> {
557 m.add_class::<PySType>()?;
558 m.add_class::<PySchemaValidator>()?;
559 m.add_class::<PyValidationResult>()?;
560 m.add_class::<PySchemaError>()?;
561 m.add_class::<PyQomMetrics>()?;
562 m.add_class::<PyQomProfile>()?;
563 m.add_class::<PyQomEvaluation>()?;
564 m.add_class::<PyMetricFailure>()?;
565 m.add_class::<PyMplEnvelope>()?;
566 m.add_function(wrap_pyfunction!(canonicalize, m)?)?;
567 m.add_function(wrap_pyfunction!(semantic_hash, m)?)?;
568 m.add_function(wrap_pyfunction!(verify_hash, m)?)?;
569
570 m.add("__version__", env!("CARGO_PKG_VERSION"))?;
572
573 Ok(())
574}