// knafeh 1.1.0
//
// QUIC-based RPC library with Python bindings.
use std::sync::Arc;

use pyo3::prelude::*;
use pyo3::types::PyBytes;

use crate::client::Client;
use crate::python::types::PyTlsConfig;
use crate::rpc::stream::RpcStreamReceiver;

/// Python-visible RPC client (exposed to Python as `RpcClient`).
///
/// `call` and `server_stream` return native Python awaitables via
/// `future_into_py`. `connect`, `call_sync` and `server_stream_sync` instead
/// block the calling thread on the shared tokio runtime until they finish.
///
/// The object starts out unconnected; every RPC method fails with
/// `ConnectionError` until `connect()` has succeeded.
#[pyclass(name = "RpcClient")]
pub struct PyRpcClient {
    /// Live client handle: `None` until `connect()` succeeds and again after
    /// `close()`. Kept in an `Arc` so each call can take its own clone.
    client: Option<Arc<Client>>,
    /// Endpoint string forwarded to the client builder on `connect()`.
    endpoint: String,
    /// TLS settings; the wrapped `inner` value is forwarded to the client builder.
    tls_config: PyTlsConfig,
    /// Pool size forwarded to the client builder (defaults to 4 from Python).
    pool_size: usize,
}

#[pymethods]
impl PyRpcClient {
    #[new]
    #[pyo3(signature = (endpoint, tls, pool_size=4))]
    fn new(endpoint: String, tls: PyTlsConfig, pool_size: usize) -> Self {
        Self {
            client: None,
            endpoint,
            tls_config: tls,
            pool_size,
        }
    }

    /// Connect to the server. Returns a Python awaitable.
    fn connect<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        if self.client.is_some() {
            return Ok(py.None().into_bound(py));
        }

        let endpoint = self.endpoint.clone();
        let tls = self.tls_config.inner.clone();
        let pool_size = self.pool_size;

        // We need to build the client synchronously here since we store it
        // in &mut self. Use the runtime's block_on.
        let rt = pyo3_async_runtimes::tokio::get_runtime();
        let client = rt
            .block_on(async move {
                Client::builder()
                    .endpoint(endpoint)
                    .tls(tls)
                    .pool_size(pool_size)
                    .build()
                    .await
            })
            .map_err(|e| pyo3::exceptions::PyConnectionError::new_err(e.to_string()))?;

        self.client = Some(Arc::new(client));
        Ok(py.None().into_bound(py))
    }

    /// Make a unary RPC call. Returns a Python awaitable that resolves to bytes.
    fn call<'py>(
        &mut self,
        py: Python<'py>,
        method: String,
        body: Vec<u8>,
    ) -> PyResult<Bound<'py, PyAny>> {
        self.ensure_connected()?;
        let client = Arc::clone(self.client.as_ref().unwrap());

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let response = client
                .call(&method, body)
                .await
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            Ok(response.body)
        })
    }

    /// Initiate a server-streaming RPC call.
    /// Returns a Python awaitable that resolves to a `StreamReceiver`.
    fn server_stream<'py>(
        &mut self,
        py: Python<'py>,
        method: String,
        body: Vec<u8>,
    ) -> PyResult<Bound<'py, PyAny>> {
        self.ensure_connected()?;
        let client = Arc::clone(self.client.as_ref().unwrap());

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let receiver = client
                .server_stream(&method, body)
                .await
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

            Python::attach(|py| {
                let stream = PyStreamReceiver {
                    receiver: Arc::new(tokio::sync::Mutex::new(Some(receiver))),
                };
                Ok(stream.into_pyobject(py)?.into_any().unbind())
            })
        })
    }

    // -----------------------------------------------------------------
    // Synchronous API — for scripts, notebooks, and non-async contexts.
    // Uses the tokio runtime's block_on to drive futures synchronously.
    // -----------------------------------------------------------------

    /// Make a unary RPC call synchronously. Returns bytes.
    fn call_sync<'py>(
        &mut self,
        py: Python<'py>,
        method: String,
        body: Vec<u8>,
    ) -> PyResult<Bound<'py, PyBytes>> {
        self.ensure_connected()?;
        let client = Arc::clone(self.client.as_ref().unwrap());
        let rt = pyo3_async_runtimes::tokio::get_runtime();

        let result = py.detach(|| {
            rt.block_on(async move {
                client
                    .call(&method, body)
                    .await
                    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
            })
        })?;

        Ok(PyBytes::new(py, &result.body))
    }

    /// Initiate a server-streaming RPC call synchronously.
    /// Returns a `SyncStreamReceiver`.
    fn server_stream_sync(
        &mut self,
        method: String,
        body: Vec<u8>,
    ) -> PyResult<PySyncStreamReceiver> {
        self.ensure_connected()?;
        let client = Arc::clone(self.client.as_ref().unwrap());
        let rt = pyo3_async_runtimes::tokio::get_runtime();

        let receiver = rt.block_on(async move {
            client
                .server_stream(&method, body)
                .await
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
        })?;

        Ok(PySyncStreamReceiver {
            receiver: Arc::new(tokio::sync::Mutex::new(Some(receiver))),
        })
    }

    /// Close the client and release all pooled connections.
    fn close(&mut self) {
        self.client = None;
    }
}

impl PyRpcClient {
    /// Guard used by the RPC methods: succeeds only when a client is stored,
    /// otherwise fails with a Python `ConnectionError`.
    fn ensure_connected(&self) -> PyResult<()> {
        match self.client {
            Some(_) => Ok(()),
            None => Err(pyo3::exceptions::PyConnectionError::new_err(
                "client not connected — call connect() first",
            )),
        }
    }
}

/// A streaming response receiver exposed to Python as `StreamReceiver`.
///
/// `next_chunk()` returns a native Python awaitable that resolves to
/// bytes or None.
#[pyclass(name = "StreamReceiver")]
pub struct PyStreamReceiver {
    /// Shared slot holding the underlying stream; `close()` sets it to `None`.
    /// A tokio mutex is used because the lock is held across the `.await`
    /// on the stream's next item.
    receiver: Arc<tokio::sync::Mutex<Option<RpcStreamReceiver>>>,
}

#[pymethods]
impl PyStreamReceiver {
    /// Fetch the next chunk of the stream.
    ///
    /// Returns an awaitable that resolves to `bytes`, or to `None` once the
    /// stream has ended or has been closed. A stream error is raised from the
    /// awaitable as `RuntimeError`.
    fn next_chunk<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let shared = Arc::clone(&self.receiver);

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            // The lock is held until this future completes, so overlapping
            // `next_chunk()` calls are served one at a time.
            let mut slot = shared.lock().await;
            let item = match &mut *slot {
                Some(stream) => stream.next().await,
                // Already closed: report end-of-stream.
                None => None,
            };

            match item {
                Some(Ok(bytes)) => {
                    Python::attach(|py| Ok(PyBytes::new(py, &bytes).into_any().unbind()))
                }
                Some(Err(e)) => Err(pyo3::exceptions::PyRuntimeError::new_err(e.to_string())),
                None => Python::attach(|py| Ok(py.None().into())),
            }
        })
    }

    /// Close the stream early.
    ///
    /// Returns an awaitable that resolves to `None` after the underlying
    /// receiver has been dropped; later `next_chunk()` calls resolve to `None`.
    fn close<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let shared = Arc::clone(&self.receiver);

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let mut slot = shared.lock().await;
            // Empty the slot and drop whatever receiver it held.
            drop(slot.take());
            Ok(())
        })
    }
}

/// Synchronous streaming response receiver, exposed to Python as
/// `SyncStreamReceiver`.
///
/// Each `next_chunk()` call blocks until the next chunk is available.
/// For use in scripts, notebooks, and non-async contexts. The object is also
/// a Python iterator (`__iter__` / `__next__`).
#[pyclass(name = "SyncStreamReceiver")]
pub struct PySyncStreamReceiver {
    /// Shared slot holding the underlying stream; `close()` sets it to `None`.
    /// The lock is held while waiting for the stream's next item.
    receiver: Arc<tokio::sync::Mutex<Option<RpcStreamReceiver>>>,
}

#[pymethods]
impl PySyncStreamReceiver {
    /// Get the next chunk, blocking until it is available.
    ///
    /// Returns `bytes`, or `None` once the stream has ended or has been
    /// closed. The interpreter is detached while waiting.
    ///
    /// # Errors
    /// Raises `RuntimeError` when the stream yields an error.
    fn next_chunk<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyBytes>>> {
        let receiver = Arc::clone(&self.receiver);
        let rt = pyo3_async_runtimes::tokio::get_runtime();

        let item = py.detach(|| {
            rt.block_on(async {
                let mut guard = receiver.lock().await;
                match &mut *guard {
                    Some(recv) => recv.next().await,
                    // Already closed: report end-of-stream.
                    None => None,
                }
            })
        });

        // Option<Result<..>> -> Result<Option<..>>, converting the payload to
        // Python bytes and the error to RuntimeError on the way.
        item.map(|chunk| {
            chunk
                .map(|bytes| PyBytes::new(py, &bytes))
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
        })
        .transpose()
    }

    /// Close the stream by dropping the underlying receiver.
    ///
    /// The lock may be held by a `next_chunk()` call in another thread that is
    /// still waiting for data, so the interpreter is detached while waiting
    /// for it; blocking while attached would freeze every Python thread until
    /// that chunk arrives. The `py` token is injected by PyO3 and is not part
    /// of the Python-visible signature.
    fn close(&self, py: Python<'_>) {
        let receiver = Arc::clone(&self.receiver);
        let rt = pyo3_async_runtimes::tokio::get_runtime();
        py.detach(|| {
            rt.block_on(async {
                let mut guard = receiver.lock().await;
                *guard = None;
            });
        });
    }

    /// Iterator protocol: the receiver is its own iterator.
    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }

    /// Iterator protocol: yields the next chunk. Returning `None` here is
    /// what ends a Python `for` loop over the receiver.
    fn __next__<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyBytes>>> {
        self.next_chunk(py)
    }
}