use fidius_core::python_descriptor::PythonInterfaceDescriptor;
use fidius_core::PluginError;
use pyo3::prelude::*;
use pyo3::types::{PyAnyMethods, PyBytes, PyDict, PyTuple};
use crate::error::pyerr_to_plugin_error;
use crate::value_bridge::{pyobject_to_value, value_to_pyobject};
/// Errors produced while dispatching a call into a Python plugin method.
#[derive(Debug, thiserror::Error)]
pub enum PythonCallError {
    /// The requested method index is not present in the handle's method table.
    #[error("invalid method index {index} (interface has {count} method(s))")]
    InvalidMethodIndex { index: usize, count: usize },
    /// The method was invoked through the wrong dispatch path (raw vs. typed)
    /// relative to its descriptor's `wire_raw` flag.
    #[error(
        "wire-mode mismatch on method '{method}': declared wire_raw={declared}, dispatcher used wire_raw={attempted}"
    )]
    WireModeMismatch {
        /// Method name taken from the interface descriptor.
        method: &'static str,
        /// The `wire_raw` value declared for the method.
        declared: bool,
        /// The `wire_raw` mode the caller attempted to use.
        attempted: bool,
    },
    /// The call input could not be decoded or converted into Python arguments.
    #[error("failed to decode typed input: {0}")]
    InputDecode(String),
    /// The Python return value could not be converted or encoded for the caller.
    #[error("failed to encode typed output: {0}")]
    OutputEncode(String),
    /// The Python callable raised; the exception has been converted to a
    /// [`PluginError`].
    #[error("plugin raised: [{}] {}", .0.code, .0.message)]
    Plugin(PluginError),
}
/// A loaded Python plugin: its static interface descriptor together with the
/// Python callables used to dispatch each declared method.
#[derive(Debug)]
pub struct PythonPluginHandle {
    /// Static interface description; each entry in `methods` carries the
    /// method's `name` and its `wire_raw` flag.
    descriptor: &'static PythonInterfaceDescriptor,
    // Strong reference to the loaded module object. It is never read in this
    // file; presumably it is held so the module stays alive as long as the
    // handle does — TODO confirm.
    _module: Py<PyAny>,
    /// Python callables, looked up by the same index as `descriptor.methods`.
    method_callables: Vec<Py<PyAny>>,
}
impl PythonPluginHandle {
    /// Creates a handle from an interface descriptor, the loaded module
    /// object, and the per-method Python callables.
    ///
    /// `method_callables[i]` is expected to correspond to
    /// `descriptor.methods[i]`; dispatch resolves both by the same index.
    pub(crate) fn new(
        descriptor: &'static PythonInterfaceDescriptor,
        module: Py<PyAny>,
        method_callables: Vec<Py<PyAny>>,
    ) -> Self {
        Self {
            descriptor,
            _module: module,
            method_callables,
        }
    }

    /// Returns the static interface descriptor this handle was built from.
    pub fn descriptor(&self) -> &'static PythonInterfaceDescriptor {
        self.descriptor
    }

    /// Returns the number of methods declared by the interface descriptor.
    pub fn method_count(&self) -> usize {
        self.descriptor.methods.len()
    }

    /// Calls a typed (non-raw) method.
    ///
    /// This delegates directly to [`Self::call_typed_json`]. `input` is
    /// therefore parsed as a JSON document, and the returned bytes are JSON.
    /// No bincode decoding happens here.
    ///
    /// NOTE(review): the parameter used to be named `input_bincode`. Confirm
    /// that hosts encode typed payloads as JSON for Python plugins.
    ///
    /// # Errors
    ///
    /// Same as [`Self::call_typed_json`].
    pub fn call_typed(
        &self,
        method_index: usize,
        input: &[u8],
    ) -> Result<Vec<u8>, PythonCallError> {
        self.call_typed_json(method_index, input)
    }

    /// Calls a typed method using a JSON wire encoding.
    ///
    /// `input_json` is parsed before the GIL is acquired. It is then turned
    /// into positional arguments:
    /// - a JSON array becomes one argument per element;
    /// - `null` becomes no arguments;
    /// - any other value becomes a single argument.
    ///
    /// The Python return value is converted back to a JSON value and
    /// serialized to bytes.
    ///
    /// # Errors
    ///
    /// - [`PythonCallError::InvalidMethodIndex`] if `method_index` is unknown.
    /// - [`PythonCallError::WireModeMismatch`] if the method is declared raw.
    /// - [`PythonCallError::InputDecode`] if the input is not valid JSON or
    ///   cannot be converted into Python objects.
    /// - [`PythonCallError::Plugin`] if the Python callable raises.
    /// - [`PythonCallError::OutputEncode`] if the result cannot be converted
    ///   to JSON or serialized.
    pub fn call_typed_json(
        &self,
        method_index: usize,
        input_json: &[u8],
    ) -> Result<Vec<u8>, PythonCallError> {
        let method = self.lookup_method(method_index, false)?;
        // Decode outside the GIL so malformed input never blocks Python threads.
        let input_value: serde_json::Value = serde_json::from_slice(input_json)
            .map_err(|e| PythonCallError::InputDecode(e.to_string()))?;
        let result_value = Python::with_gil(|py| -> Result<serde_json::Value, PythonCallError> {
            let callable = method.callable.bind(py);
            let py_args = build_call_args(py, &input_value)
                .map_err(|e| PythonCallError::InputDecode(e.to_string()))?;
            let result = callable
                .call(py_args, None::<&Bound<'_, PyDict>>)
                .map_err(|e| PythonCallError::Plugin(pyerr_to_plugin_error(e)))?;
            pyobject_to_value(&result).map_err(|e| PythonCallError::OutputEncode(e.to_string()))
        })?;
        serde_json::to_vec(&result_value).map_err(|e| PythonCallError::OutputEncode(e.to_string()))
    }

    /// Calls a raw-wire method.
    ///
    /// `input` is passed to the callable as a single Python `bytes` argument.
    /// The callable must return `bytes` or `bytearray`, and its contents are
    /// returned unchanged.
    ///
    /// # Errors
    ///
    /// - [`PythonCallError::InvalidMethodIndex`] if `method_index` is unknown.
    /// - [`PythonCallError::WireModeMismatch`] if the method is not declared raw.
    /// - [`PythonCallError::Plugin`] if the Python callable raises.
    /// - [`PythonCallError::OutputEncode`] if the return value is not
    ///   `bytes`/`bytearray`.
    pub fn call_raw(&self, method_index: usize, input: &[u8]) -> Result<Vec<u8>, PythonCallError> {
        let method = self.lookup_method(method_index, true)?;
        Python::with_gil(|py| {
            let callable = method.callable.bind(py);
            let arg = PyBytes::new(py, input);
            let result = callable
                .call1((arg,))
                .map_err(|e| PythonCallError::Plugin(pyerr_to_plugin_error(e)))?;
            // Extracting `Cow<[u8]>` reads a `bytes` buffer directly and
            // copies a `bytearray` in one shot. Extracting `Vec<u8>` instead
            // walks the object as a generic sequence, converting one Python
            // int per byte, and would also accept e.g. `list[int]`.
            let bytes: std::borrow::Cow<'_, [u8]> = result.extract().map_err(|e| {
                PythonCallError::OutputEncode(format!(
                    "raw method must return bytes/bytearray, got: {e}"
                ))
            })?;
            Ok(bytes.into_owned())
        })
    }

    /// Resolves `index` to its callable after checking that the method's
    /// declared `wire_raw` flag matches the dispatch path being used.
    ///
    /// Both the descriptor table and the callable table are accessed with
    /// checked lookups. A length mismatch between them is therefore reported
    /// as [`PythonCallError::InvalidMethodIndex`] rather than causing a panic.
    fn lookup_method(
        &self,
        index: usize,
        attempting_raw: bool,
    ) -> Result<MethodLookup<'_>, PythonCallError> {
        let (desc, callable) = match (
            self.descriptor.methods.get(index),
            self.method_callables.get(index),
        ) {
            (Some(desc), Some(callable)) => (desc, callable),
            _ => {
                return Err(PythonCallError::InvalidMethodIndex {
                    index,
                    // Only indices present in both tables are dispatchable.
                    count: self
                        .method_callables
                        .len()
                        .min(self.descriptor.methods.len()),
                });
            }
        };
        if desc.wire_raw != attempting_raw {
            return Err(PythonCallError::WireModeMismatch {
                method: desc.name,
                declared: desc.wire_raw,
                attempted: attempting_raw,
            });
        }
        Ok(MethodLookup { callable })
    }
}
/// A method resolved and wire-mode-checked by `PythonPluginHandle::lookup_method`.
struct MethodLookup<'a> {
    /// The Python callable for the method, borrowed from the handle.
    callable: &'a Py<PyAny>,
}
/// Converts a decoded JSON input into the positional-argument tuple for a
/// Python call.
///
/// - A JSON array is spread into one positional argument per element.
/// - `null` yields an empty tuple, so the callable is invoked with no arguments.
/// - Any other value (scalar or object) becomes the single positional argument.
///
/// # Errors
///
/// Returns the first conversion error from `value_to_pyobject`, or any error
/// raised while building the tuple.
fn build_call_args<'py>(
    py: Python<'py>,
    input: &serde_json::Value,
) -> PyResult<Bound<'py, PyTuple>> {
    let positional: Vec<Bound<'_, PyAny>> = match input {
        serde_json::Value::Null => Vec::new(),
        serde_json::Value::Array(elements) => elements
            .iter()
            .map(|element| value_to_pyobject(py, element))
            .collect::<PyResult<Vec<_>>>()?,
        single => vec![value_to_pyobject(py, single)?],
    };
    PyTuple::new(py, positional)
}