use std::path::PathBuf;
use std::time::Duration;

#[cfg(unix)]
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::time::timeout;

use crate::error::WorldModelError;
use crate::{OccupancyWorldModelRequest, OccupancyWorldModelResponse};
/// Deadline, in seconds, for one complete `OccWorldBridge::predict` round trip.
const TIMEOUT_S: u64 = 30;
/// Largest reply line, in bytes and including its trailing newline, accepted from the server (64 MiB).
#[cfg(unix)]
const MAX_RESPONSE_BYTES: usize = 64 << 20;
/// Client handle for an OccWorld server reachable over a Unix domain socket.
///
/// Only the socket path is stored. No connection is held between calls.
#[derive(Clone, Debug)]
pub struct OccWorldBridge {
    /// Filesystem path of the server's Unix socket.
    pub socket_path: PathBuf,
}
impl OccWorldBridge {
    /// Creates a bridge targeting the Unix socket at `socket_path`.
    ///
    /// No connection is attempted here. Every [`OccWorldBridge::predict`] call opens its own
    /// connection.
    pub fn new(socket_path: impl Into<PathBuf>) -> Self {
        Self {
            socket_path: socket_path.into(),
        }
    }

    /// Sends `request` to the OccWorld server and returns the decoded response.
    ///
    /// The whole exchange (connect, write, read, decode) must finish within `TIMEOUT_S`
    /// seconds.
    ///
    /// # Errors
    ///
    /// * [`WorldModelError::Timeout`] if the deadline elapses.
    /// * [`WorldModelError::SocketConnect`] if the socket cannot be opened.
    /// * [`WorldModelError::Protocol`] for any of:
    ///   - write/flush/read I/O failures;
    ///   - the server closing the connection without replying;
    ///   - a reply line longer than `MAX_RESPONSE_BYTES`;
    ///   - builds on non-unix targets.
    /// * [`WorldModelError::VramUnavailable`] if the reply's `model_id` starts with
    ///   `error:vram:`. The remainder of `model_id` is carried as the detail.
    /// * JSON (de)serialization failures, converted from `serde_json::Error`.
    pub async fn predict(
        &self,
        request: OccupancyWorldModelRequest,
    ) -> Result<OccupancyWorldModelResponse, WorldModelError> {
        timeout(Duration::from_secs(TIMEOUT_S), self.send_recv(request))
            .await
            .map_err(|_| WorldModelError::Timeout { timeout_s: TIMEOUT_S })?
    }

    /// Fallback for targets without Unix domain sockets: always returns a protocol error.
    #[cfg(not(unix))]
    async fn send_recv(
        &self,
        _request: OccupancyWorldModelRequest,
    ) -> Result<OccupancyWorldModelResponse, WorldModelError> {
        Err(WorldModelError::Protocol(
            "OccWorld Unix-socket bridge is only supported on unix targets".into(),
        ))
    }

    /// Performs one newline-delimited JSON request/response exchange over a fresh connection.
    ///
    /// The request is written as a single JSON document followed by `\n`. The reply is read
    /// up to the first `\n` (or EOF). At most `MAX_RESPONSE_BYTES` bytes, newline included,
    /// are accepted for the reply line.
    #[cfg(unix)]
    async fn send_recv(
        &self,
        request: OccupancyWorldModelRequest,
    ) -> Result<OccupancyWorldModelResponse, WorldModelError> {
        let stream = self.connect().await?;
        let (reader_half, mut writer_half) = stream.into_split();

        let mut payload = serde_json::to_vec(&request)?;
        payload.push(b'\n');
        writer_half
            .write_all(&payload)
            .await
            .map_err(|e| WorldModelError::Protocol(format!("write error: {e}")))?;
        writer_half
            .flush()
            .await
            .map_err(|e| WorldModelError::Protocol(format!("flush error: {e}")))?;

        // Enforce the size limit *while* reading. Without the `take` cap, the line reader
        // keeps buffering until it sees a newline. A misbehaving server could then exhaust
        // memory before any length check ran. Allowing one byte past the limit is enough to
        // tell "exactly at the limit" from "over it".
        let read_cap = (MAX_RESPONSE_BYTES as u64).saturating_add(1);
        let mut reader = BufReader::new(reader_half.take(read_cap));
        // Read raw bytes rather than a `String`. If the cap lands in the middle of a UTF-8
        // sequence, this reports "too large" instead of a confusing UTF-8 read error.
        let mut line = Vec::new();
        reader
            .read_until(b'\n', &mut line)
            .await
            .map_err(|e| WorldModelError::Protocol(format!("read error: {e}")))?;

        if line.is_empty() {
            return Err(WorldModelError::Protocol(
                "server closed connection before sending a response".into(),
            ));
        }
        if line.len() > MAX_RESPONSE_BYTES {
            return Err(WorldModelError::Protocol(format!(
                "response line too large (exceeds {} byte limit)",
                MAX_RESPONSE_BYTES
            )));
        }

        // serde_json skips surrounding whitespace, so a trailing "\n" or "\r\n" is accepted.
        let response: OccupancyWorldModelResponse = serde_json::from_slice(&line)?;

        // A `model_id` of the form `error:vram:<detail>` is treated as an in-band VRAM error.
        if let Some(detail) = response.model_id.strip_prefix("error:vram:") {
            return Err(WorldModelError::VramUnavailable(detail.to_owned()));
        }
        Ok(response)
    }

    /// Opens a new connection to `self.socket_path`.
    ///
    /// Failures are wrapped in [`WorldModelError::SocketConnect`], together with the path
    /// that was tried.
    #[cfg(unix)]
    async fn connect(&self) -> Result<UnixStream, WorldModelError> {
        UnixStream::connect(&self.socket_path)
            .await
            .map_err(|e| WorldModelError::SocketConnect {
                path: self.socket_path.display().to_string(),
                source: e,
            })
    }
}
/// Returns the fixed default socket path, `/tmp/occworld.sock`.
///
/// The result is the same on every call.
pub fn default_socket_path() -> PathBuf {
    let dir = PathBuf::from("/tmp");
    dir.join("occworld.sock")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` must keep the supplied path unchanged.
    #[test]
    fn bridge_new_stores_path() {
        let expected = PathBuf::from("/tmp/test.sock");
        let bridge = OccWorldBridge::new("/tmp/test.sock");
        assert_eq!(bridge.socket_path, expected);
    }

    /// The default path is a fixed constant.
    #[test]
    fn default_socket_path_is_deterministic() {
        let expected = PathBuf::from("/tmp/occworld.sock");
        assert_eq!(default_socket_path(), expected);
    }

    /// Predicting against a socket that does not exist must surface `SocketConnect`.
    #[cfg(unix)]
    #[tokio::test]
    async fn connect_to_missing_socket_returns_error() {
        use crate::{OccupancyGrid3D, OccupancyWorldModelRequest, SceneBoundsJson};

        let frame = OccupancyGrid3D {
            width: 200,
            height: 200,
            depth: 16,
            voxels: vec![17u8; 200 * 200 * 16],
        };
        let bounds = SceneBoundsJson {
            min_e: -10.0,
            min_n: -10.0,
            max_e: 10.0,
            max_n: 10.0,
        };
        let request = OccupancyWorldModelRequest {
            past_frames: vec![frame],
            voxel_resolution_m: 0.1,
            scene_bounds: bounds,
            prediction_steps: 1,
        };

        let bridge = OccWorldBridge::new("/tmp/__occworld_nonexistent_test__.sock");
        let err = bridge.predict(request).await.unwrap_err();
        assert!(
            matches!(err, WorldModelError::SocketConnect { .. }),
            "expected SocketConnect, got {err:?}"
        );
    }
}