use pyo3::prelude::*;
use pyo3::types::PyBytes;
use std::io::{Read, Write};
use super::json_utils::{json_value_to_py, py_to_json_value};
use crate::error::IpcError;
use crate::local_socket::{
LocalSocketListener as RustLocalSocketListener, LocalSocketStream as RustLocalSocketStream,
};
/// Python-facing handle around the crate's local socket listener.
///
/// Exposed to Python under the class name `LocalSocketListener`.
#[pyclass(name = "LocalSocketListener")]
pub struct PyLocalSocketListener {
// Wrapped Rust listener. Methods receive `&self`, so access goes through
// this mutex; each method locks it before touching the listener.
inner: parking_lot::Mutex<RustLocalSocketListener>,
}
#[pymethods]
impl PyLocalSocketListener {
#[new]
fn new(name: &str) -> PyResult<Self> {
let inner = RustLocalSocketListener::bind(name)?;
Ok(Self {
inner: parking_lot::Mutex::new(inner),
})
}
#[staticmethod]
fn bind(name: &str) -> PyResult<Self> {
Self::new(name)
}
fn accept(&self, _py: Python<'_>) -> PyResult<PyLocalSocketStream> {
let guard = self.inner.lock();
let stream = guard.accept()?;
Ok(PyLocalSocketStream {
inner: parking_lot::Mutex::new(stream),
})
}
#[getter]
fn name(&self) -> String {
self.inner.lock().name().to_string()
}
}
/// Python-facing handle around the crate's local socket stream.
///
/// Exposed to Python under the class name `LocalSocketStream`.
#[pyclass(name = "LocalSocketStream")]
pub struct PyLocalSocketStream {
// Wrapped Rust stream. A single mutex guards both reading and writing, so
// every method locks it before touching the stream.
inner: parking_lot::Mutex<RustLocalSocketStream>,
}
#[pymethods]
impl PyLocalSocketStream {
    // Locking rule for this impl: the mutex is only ever acquired, and socket
    // I/O only ever performed, with the GIL released. Conversions to and from
    // Python objects happen outside those sections. One lock covers both
    // directions, so a call that is waiting on the socket delays every other
    // call made on the same stream object.

    /// Connect to the local socket called `name` and return the stream.
    ///
    /// An error reported by the underlying `connect` call is raised as a
    /// Python exception.
    #[staticmethod]
    fn connect(name: &str) -> PyResult<Self> {
        let inner = RustLocalSocketStream::connect(name)?;
        Ok(Self {
            inner: parking_lot::Mutex::new(inner),
        })
    }

    /// Name reported by the underlying stream, as a string.
    #[getter]
    fn name(&self, py: Python<'_>) -> String {
        // Another thread may hold the lock during I/O; wait without the GIL.
        py.allow_threads(|| self.inner.lock().name().to_string())
    }

    /// Perform one read of at most `size` bytes and return what was received.
    ///
    /// The returned `bytes` object is as long as the number of bytes the
    /// underlying read reported, which may be fewer than `size`.
    fn read(&self, py: Python<'_>, size: usize) -> PyResult<Py<PyBytes>> {
        let mut buf = vec![0u8; size];
        let received =
            py.allow_threads(|| self.inner.lock().read(&mut buf).map_err(PyErr::from))?;
        buf.truncate(received);
        Ok(PyBytes::new(py, &buf).into())
    }

    /// Perform one write of `data` and return how many bytes were accepted.
    fn write(&self, py: Python<'_>, data: Vec<u8>) -> PyResult<usize> {
        py.allow_threads(|| self.inner.lock().write(&data).map_err(PyErr::from))
    }

    /// Read exactly `size` bytes, raising if the stream cannot supply them.
    fn read_exact(&self, py: Python<'_>, size: usize) -> PyResult<Py<PyBytes>> {
        let mut buf = vec![0u8; size];
        py.allow_threads(|| self.inner.lock().read_exact(&mut buf).map_err(PyErr::from))?;
        Ok(PyBytes::new(py, &buf).into())
    }

    /// Write all of `data`, raising if the stream reports an error first.
    fn write_all(&self, py: Python<'_>, data: Vec<u8>) -> PyResult<()> {
        py.allow_threads(|| self.inner.lock().write_all(&data).map_err(PyErr::from))
    }

    /// Flush the underlying stream.
    fn flush(&self, py: Python<'_>) -> PyResult<()> {
        py.allow_threads(|| self.inner.lock().flush().map_err(PyErr::from))
    }

    /// Serialize `obj` to JSON and send it as one length-prefixed frame.
    ///
    /// Frame layout: a 4-byte big-endian unsigned length followed by that many
    /// bytes of JSON. The stream is flushed after the frame is written.
    ///
    /// Raises `ValueError` if the value cannot be serialized or if the encoded
    /// payload is too large for the 32-bit length prefix. Conversion and I/O
    /// errors are raised as Python exceptions.
    fn send_json(&self, py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<()> {
        // Python -> JSON conversion needs the GIL, so it happens up front.
        let value = py_to_json_value(obj)?;
        let json_bytes = serde_json::to_vec(&value)
            .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;

        // A plain `as u32` cast would silently wrap for oversized payloads and
        // announce a length that does not match the bytes that follow.
        let len_prefix = u32::try_from(json_bytes.len())
            .map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                    "JSON payload of {} bytes does not fit in the 32-bit length prefix",
                    json_bytes.len()
                ))
            })?
            .to_be_bytes();

        // The lock is held across prefix, payload and flush so the frame is
        // written as one unit with respect to other users of this object.
        py.allow_threads(|| -> PyResult<()> {
            let mut guard = self.inner.lock();
            guard.write_all(&len_prefix)?;
            guard.write_all(&json_bytes)?;
            guard.flush()?;
            Ok(())
        })
    }

    /// Receive one length-prefixed JSON frame and return it as a Python object.
    ///
    /// Expects the layout produced by `send_json`: a 4-byte big-endian length
    /// followed by that many bytes of JSON. I/O errors and JSON parse errors
    /// are raised as Python exceptions.
    fn recv_json(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
        let json_bytes = py.allow_threads(|| -> PyResult<Vec<u8>> {
            // Prefix and payload are read under one lock acquisition so no
            // other reader on this object can split the frame.
            let mut guard = self.inner.lock();
            let mut len_prefix = [0u8; 4];
            guard.read_exact(&mut len_prefix)?;
            // NOTE(review): the length comes from the peer and is used as-is,
            // so a peer can request a buffer of up to 4 GiB. Consider a
            // configurable upper bound if peers are not trusted.
            let len = u32::from_be_bytes(len_prefix) as usize;
            let mut payload = vec![0u8; len];
            guard.read_exact(&mut payload)?;
            Ok(payload)
        })?;

        // Parsing and conversion to Python objects run with the GIL held and
        // without the stream lock.
        let value: serde_json::Value = serde_json::from_slice(&json_bytes)
            .map_err(|e| IpcError::deserialization(e.to_string()))?;
        json_value_to_py(py, &value)
    }
}