use std::collections::HashMap;
use async_trait::async_trait;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use crate::error::{KnafehError, RpcStatusCode};
use crate::rpc::message::{RpcRequest, RpcResponse};
use crate::rpc::service::{MethodDescriptor, MethodKind, Service};
use crate::rpc::stream::{RpcStreamRequest, RpcStreamResponse};
/// Registry of Python handler objects for one RPC service.
///
/// Exposed to Python under the class name `ServiceHandler`. Each entry maps a
/// method name to the Python object registered for it together with the
/// `MethodKind` it was registered under.
#[pyclass(name = "ServiceHandler")]
pub struct PyServiceHandler {
// Service name supplied to the constructor.
pub(crate) name: String,
// Method name -> (Python handler object, kind it was registered as).
pub(crate) handlers: HashMap<String, (Py<PyAny>, MethodKind)>,
}
#[pymethods]
impl PyServiceHandler {
#[new]
fn new(name: String) -> Self {
Self {
name,
handlers: HashMap::new(),
}
}
fn add_unary_handler(&mut self, method: String, handler: Py<PyAny>) {
self.handlers.insert(method, (handler, MethodKind::Unary));
}
fn add_server_stream_handler(&mut self, method: String, handler: Py<PyAny>) {
self.handlers
.insert(method, (handler, MethodKind::ServerStreaming));
}
fn add_client_stream_handler(&mut self, method: String, handler: Py<PyAny>) {
self.handlers
.insert(method, (handler, MethodKind::ClientStreaming));
}
fn add_bidi_stream_handler(&mut self, method: String, handler: Py<PyAny>) {
self.handlers
.insert(method, (handler, MethodKind::BidiStreaming));
}
}
/// Rust-side holder of a service name and its Python handler objects.
///
/// The `Service` implementation for this type dispatches calls to the
/// handlers stored here.
pub(crate) struct PythonServiceBridge {
// Name reported through `Service::name`.
pub(crate) name: String,
// Method name -> (Python handler object, kind it was registered as).
pub(crate) handlers: HashMap<String, (Py<PyAny>, MethodKind)>,
}
impl PythonServiceBridge {
pub fn from_handler(handler: &PyServiceHandler) -> Self {
let handlers = Python::attach(|py| {
handler
.handlers
.iter()
.map(|(k, (obj, kind))| (k.clone(), (obj.clone_ref(py), *kind)))
.collect()
});
Self {
name: handler.name.clone(),
handlers,
}
}
}
#[async_trait]
impl Service for PythonServiceBridge {
    /// Returns the service name stored on this bridge.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns one descriptor (name + kind) per registered handler.
    ///
    /// The order follows the handler map's iteration order.
    fn methods(&self) -> Vec<MethodDescriptor> {
        let mut descriptors = Vec::with_capacity(self.handlers.len());
        for (method, (_, kind)) in &self.handlers {
            descriptors.push(MethodDescriptor {
                name: method.clone(),
                kind: *kind,
            });
        }
        descriptors
    }

    /// Invokes the Python handler registered for `method` with the request
    /// body as a Python `bytes` object and wraps the returned `bytes` in an
    /// OK response.
    ///
    /// The Python call runs inside `tokio::task::spawn_blocking`, so the
    /// handler does not execute on the async task itself.
    ///
    /// # Errors
    ///
    /// Returns `KnafehError::Service` with:
    /// - `NotFound` when no handler is registered under `method`;
    /// - `Internal` when the handler raises, returns something that is not
    ///   `bytes`, or the blocking task cannot be joined.
    async fn call_unary(
        &self,
        method: &str,
        request: RpcRequest,
    ) -> Result<RpcResponse, KnafehError> {
        let Some((registered, _kind)) = self.handlers.get(method) else {
            return Err(KnafehError::Service {
                code: RpcStatusCode::NotFound,
                message: format!("method not found: {method}"),
            });
        };

        // The blocking task needs an owned handle; taking a new reference
        // requires being attached to the interpreter.
        let callable = Python::attach(|py| registered.clone_ref(py));
        let payload = request.body;

        let invoke = move || {
            Python::attach(|py| -> Result<Vec<u8>, KnafehError> {
                let argument = PyBytes::new(py, &payload);
                let returned = callable.call1(py, (argument,)).map_err(|err| {
                    PythonServiceBridge::internal_error(format!("Python handler error: {err}"))
                })?;
                let reply: &Bound<PyBytes> = returned.cast_bound(py).map_err(|err| {
                    PythonServiceBridge::internal_error(format!(
                        "handler must return bytes: {err}"
                    ))
                })?;
                // Copy out of the Python object so the data outlives the attach scope.
                Ok(reply.as_bytes().to_vec())
            })
        };

        let joined = tokio::task::spawn_blocking(invoke).await;
        // First `?` unwraps the join result, second `?` the handler result.
        let response_body = joined.map_err(|err| {
            PythonServiceBridge::internal_error(format!("task join error: {err}"))
        })??;
        Ok(RpcResponse::ok(response_body))
    }

    /// Server streaming is not supported by this bridge.
    ///
    /// # Errors
    ///
    /// Always returns `KnafehError::Service` with `Unimplemented`.
    async fn call_server_stream(
        &self,
        _method: &str,
        _request: RpcRequest,
    ) -> Result<RpcStreamResponse, KnafehError> {
        Err(PythonServiceBridge::unimplemented_error(
            "server streaming not yet implemented in Python bridge",
        ))
    }

    /// Client streaming is not supported by this bridge.
    ///
    /// # Errors
    ///
    /// Always returns `KnafehError::Service` with `Unimplemented`.
    async fn call_client_stream(
        &self,
        _method: &str,
        _stream: RpcStreamRequest,
    ) -> Result<RpcResponse, KnafehError> {
        Err(PythonServiceBridge::unimplemented_error(
            "client streaming not yet implemented in Python bridge",
        ))
    }

    /// Bidirectional streaming is not supported by this bridge.
    ///
    /// # Errors
    ///
    /// Always returns `KnafehError::Service` with `Unimplemented`.
    async fn call_bidi_stream(
        &self,
        _method: &str,
        _stream: RpcStreamRequest,
    ) -> Result<RpcStreamResponse, KnafehError> {
        Err(PythonServiceBridge::unimplemented_error(
            "bidi streaming not yet implemented in Python bridge",
        ))
    }
}

// Private constructors for the error shapes used by the `Service` impl.
impl PythonServiceBridge {
    /// Builds a `KnafehError::Service` with the `Internal` status code.
    fn internal_error(message: String) -> KnafehError {
        KnafehError::Service {
            code: RpcStatusCode::Internal,
            message,
        }
    }

    /// Builds a `KnafehError::Service` with the `Unimplemented` status code.
    fn unimplemented_error(message: &str) -> KnafehError {
        KnafehError::Service {
            code: RpcStatusCode::Unimplemented,
            message: message.to_string(),
        }
    }
}