use std::sync::Arc;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use crate::client::Client;
use crate::python::types::PyTlsConfig;
use crate::rpc::stream::RpcStreamReceiver;
/// Python-facing RPC client, exposed to Python as `RpcClient`.
///
/// Construction only records the connection settings; the underlying [`Client`] is built by
/// `connect()` and dropped by `close()`.
#[pyclass(name = "RpcClient")]
pub struct PyRpcClient {
    /// Connected client, or `None` before `connect()` / after `close()`.
    client: Option<Arc<Client>>,
    /// Endpoint forwarded to the client builder.
    endpoint: String,
    /// TLS settings; `inner` is cloned into the builder on `connect()`.
    tls_config: PyTlsConfig,
    /// Pool size forwarded to the client builder.
    pool_size: usize,
}
#[pymethods]
impl PyRpcClient {
    /// Create an unconnected client. `pool_size` defaults to 4.
    ///
    /// No I/O happens here; call `connect()` before issuing RPCs.
    #[new]
    #[pyo3(signature = (endpoint, tls, pool_size=4))]
    fn new(endpoint: String, tls: PyTlsConfig, pool_size: usize) -> Self {
        Self {
            client: None,
            endpoint,
            tls_config: tls,
            pool_size,
        }
    }

    /// Build the underlying client, blocking until the builder finishes. Returns `None`.
    ///
    /// A no-op when already connected. The GIL is released while the builder runs so other
    /// Python threads keep running. This method holds a mutable borrow of `self`, so another
    /// thread touching this same object meanwhile gets a borrow error instead of blocking.
    ///
    /// # Errors
    /// Raises `ConnectionError` if the builder fails.
    fn connect<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        if self.client.is_some() {
            return Ok(py.None().into_bound(py));
        }
        let endpoint = self.endpoint.clone();
        let tls = self.tls_config.inner.clone();
        let pool_size = self.pool_size;
        let rt = pyo3_async_runtimes::tokio::get_runtime();
        // Release the GIL while blocking on the runtime, exactly as `call_sync` does.
        let client = py.detach(|| {
            rt.block_on(async move {
                Client::builder()
                    .endpoint(endpoint)
                    .tls(tls)
                    .pool_size(pool_size)
                    .build()
                    .await
                    .map_err(|e| pyo3::exceptions::PyConnectionError::new_err(e.to_string()))
            })
        })?;
        self.client = Some(Arc::new(client));
        Ok(py.None().into_bound(py))
    }

    /// Start a unary call and return an awaitable resolving to the response body.
    ///
    /// Takes `&self` because it only reads the client handle. This lets several Python
    /// threads issue calls on one instance concurrently.
    ///
    /// # Errors
    /// Raises `ConnectionError` if not connected. The awaitable raises `RuntimeError` if the
    /// call fails.
    fn call<'py>(
        &self,
        py: Python<'py>,
        method: String,
        body: Vec<u8>,
    ) -> PyResult<Bound<'py, PyAny>> {
        let client = self.ensure_connected()?;
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let response = client
                .call(&method, body)
                .await
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            Ok(response.body)
        })
    }

    /// Open a server stream and return an awaitable resolving to a `StreamReceiver`.
    ///
    /// # Errors
    /// Raises `ConnectionError` if not connected. The awaitable raises `RuntimeError` if the
    /// stream cannot be opened.
    fn server_stream<'py>(
        &self,
        py: Python<'py>,
        method: String,
        body: Vec<u8>,
    ) -> PyResult<Bound<'py, PyAny>> {
        let client = self.ensure_connected()?;
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let receiver = client
                .server_stream(&method, body)
                .await
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            Python::attach(|py| {
                let stream = PyStreamReceiver {
                    receiver: Arc::new(tokio::sync::Mutex::new(Some(receiver))),
                };
                Ok(stream.into_pyobject(py)?.into_any().unbind())
            })
        })
    }

    /// Perform a unary call, blocking (with the GIL released) until the response arrives.
    ///
    /// Returns the response body as `bytes`. Takes `&self`, so concurrent calls from several
    /// Python threads on one instance no longer fail with a borrow error while the GIL is
    /// released.
    ///
    /// # Errors
    /// Raises `ConnectionError` if not connected, `RuntimeError` if the call fails.
    fn call_sync<'py>(
        &self,
        py: Python<'py>,
        method: String,
        body: Vec<u8>,
    ) -> PyResult<Bound<'py, PyBytes>> {
        let client = self.ensure_connected()?;
        let rt = pyo3_async_runtimes::tokio::get_runtime();
        let response = py.detach(|| {
            rt.block_on(async move {
                client
                    .call(&method, body)
                    .await
                    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
            })
        })?;
        Ok(PyBytes::new(py, &response.body))
    }

    /// Open a server stream, blocking (with the GIL released) until it is established.
    ///
    /// `py` is injected by PyO3 and does not appear in the Python signature.
    ///
    /// # Errors
    /// Raises `ConnectionError` if not connected, `RuntimeError` if the stream cannot be opened.
    fn server_stream_sync(
        &self,
        py: Python<'_>,
        method: String,
        body: Vec<u8>,
    ) -> PyResult<PySyncStreamReceiver> {
        let client = self.ensure_connected()?;
        let rt = pyo3_async_runtimes::tokio::get_runtime();
        // Same GIL handling as `call_sync`: the original held the GIL for the whole handshake.
        let receiver = py.detach(|| {
            rt.block_on(async move {
                client
                    .server_stream(&method, body)
                    .await
                    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
            })
        })?;
        Ok(PySyncStreamReceiver {
            receiver: Arc::new(tokio::sync::Mutex::new(Some(receiver))),
        })
    }

    /// Drop this object's client handle.
    ///
    /// Calls already in flight keep their own `Arc` clone and are not cancelled.
    fn close(&mut self) {
        self.client = None;
    }
}
impl PyRpcClient {
    /// Return a shared handle to the connected client.
    ///
    /// Call sites that only need the check can keep writing `self.ensure_connected()?;`.
    ///
    /// # Errors
    /// Returns a `ConnectionError` if `connect()` has not succeeded or `close()` was called.
    fn ensure_connected(&self) -> PyResult<Arc<Client>> {
        // Cloning the Option<Arc<_>> is only a refcount bump.
        self.client.clone().ok_or_else(|| {
            pyo3::exceptions::PyConnectionError::new_err(
                "client not connected — call connect() first",
            )
        })
    }
}
/// Awaitable stream handle produced by `RpcClient.server_stream()`, exposed to Python as
/// `StreamReceiver`.
#[pyclass(name = "StreamReceiver")]
pub struct PyStreamReceiver {
    /// Lockable slot holding the live stream; emptied by `close()`.
    receiver: Arc<tokio::sync::Mutex<Option<RpcStreamReceiver>>>,
}
#[pymethods]
impl PyStreamReceiver {
    /// Return an awaitable resolving to the next chunk.
    ///
    /// The awaitable resolves to `bytes` for a chunk, or `None` once the stream is exhausted
    /// or closed, and raises `RuntimeError` if the stream yields an error. The lock is held
    /// while waiting for data, so concurrent calls are served one at a time.
    fn next_chunk<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let shared = Arc::clone(&self.receiver);
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let mut slot = shared.lock().await;
            let item = match slot.as_mut() {
                Some(stream) => stream.next().await,
                None => None,
            };
            match item {
                Some(Ok(chunk)) => {
                    Python::attach(|py| Ok(PyBytes::new(py, &chunk).into_any().unbind()))
                }
                Some(Err(err)) => Err(pyo3::exceptions::PyRuntimeError::new_err(err.to_string())),
                None => Python::attach(|py| Ok(py.None())),
            }
        })
    }

    /// Return an awaitable that drops the underlying stream.
    ///
    /// It waits for the lock, so it finishes only after any in-flight `next_chunk()` has
    /// finished. Later `next_chunk()` calls resolve to `None`.
    fn close<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let shared = Arc::clone(&self.receiver);
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            *shared.lock().await = None;
            Ok(())
        })
    }
}
/// Blocking stream handle produced by `RpcClient.server_stream_sync()`, exposed to Python as
/// `SyncStreamReceiver`. It is iterable: `for chunk in rx` yields `bytes` until the stream ends.
#[pyclass(name = "SyncStreamReceiver")]
pub struct PySyncStreamReceiver {
    /// Lockable slot holding the live stream; emptied by `close()`.
    receiver: Arc<tokio::sync::Mutex<Option<RpcStreamReceiver>>>,
}
#[pymethods]
impl PySyncStreamReceiver {
    /// Block (with the GIL released) until the next chunk is available.
    ///
    /// Returns `bytes` for a chunk, or `None` once the stream is exhausted or closed.
    ///
    /// # Errors
    /// Raises `RuntimeError` if the stream yields an error.
    fn next_chunk<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyBytes>>> {
        let receiver = Arc::clone(&self.receiver);
        let rt = pyo3_async_runtimes::tokio::get_runtime();
        let result = py.detach(|| {
            rt.block_on(async {
                let mut guard = receiver.lock().await;
                if let Some(ref mut recv) = *guard {
                    recv.next().await
                } else {
                    None
                }
            })
        });
        match result {
            Some(Ok(bytes)) => Ok(Some(PyBytes::new(py, &bytes))),
            Some(Err(e)) => Err(pyo3::exceptions::PyRuntimeError::new_err(e.to_string())),
            None => Ok(None),
        }
    }

    /// Drop the underlying stream; later `next_chunk()` calls return `None`.
    ///
    /// Another thread's `next_chunk()` may hold the lock while waiting for data. The GIL is
    /// therefore released while waiting here; previously this froze the whole interpreter
    /// until that chunk arrived. `py` is injected by PyO3 and is not part of the Python
    /// signature.
    fn close(&self, py: Python<'_>) {
        let receiver = Arc::clone(&self.receiver);
        let rt = pyo3_async_runtimes::tokio::get_runtime();
        py.detach(|| {
            rt.block_on(async {
                let mut guard = receiver.lock().await;
                *guard = None;
            });
        });
    }

    /// The receiver is its own iterator.
    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }

    /// Iterator protocol: yields the next chunk; returning `None` ends iteration.
    fn __next__<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyBytes>>> {
        self.next_chunk(py)
    }
}