use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use anyhow::{anyhow, Result};
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use kittycad::types::{ModelingSessionData, WebSocketRequest, WebSocketResponse};
use tokio::sync::{mpsc, oneshot, RwLock};
use tokio_tungstenite::tungstenite::Message as WsMsg;
use crate::{
engine::EngineManager,
errors::{KclError, KclErrorDetails},
executor::DefaultPlanes,
};
/// Whether the reader task still considers the engine WebSocket usable.
#[derive(PartialEq, Debug)]
enum SocketHealth {
    /// No read failure has been observed.
    Active,
    /// A read failed and the reader task has stopped.
    Inactive,
}
/// Write half of the split engine WebSocket, owned by the writer task.
type WebSocketTcpWrite = futures::stream::SplitSink<
    tokio_tungstenite::WebSocketStream<reqwest::Upgraded>,
    WsMsg,
>;
/// Connection to the modeling engine over a WebSocket.
///
/// Outgoing requests are handed to a writer task through `engine_req_tx`; a reader task
/// fills `responses`, `session_data` and `socket_health`. Every field is a channel sender
/// or an `Arc`, so clones share the same underlying connection state.
#[derive(Debug, Clone)]
pub struct EngineConnection {
    /// Queue to the writer task that owns the socket's write half.
    engine_req_tx: mpsc::Sender<ToEngineReq>,
    /// Responses received from the engine, keyed by request id; entries are removed by
    /// `inner_send_modeling_cmd` once consumed.
    responses: Arc<DashMap<uuid::Uuid, WebSocketResponse>>,
    // Never read directly: it is held only so that dropping the last clone of the connection
    // drops the handle, whose `Drop` impl aborts the reader task.
    #[allow(dead_code)]
    tcp_read_handle: Arc<TcpReadHandle>,
    /// Set to `SocketHealth::Inactive` by the reader task when a read fails.
    socket_health: Arc<Mutex<SocketHealth>>,
    /// Pending (request, source range) pairs exposed through `EngineManager::batch`.
    /// NOTE(review): presumably flushed by the trait's default methods — not visible here.
    batch: Arc<Mutex<Vec<(WebSocketRequest, crate::executor::SourceRange)>>>,
    /// Requests keyed by id, exposed through `EngineManager::batch_end`.
    /// NOTE(review): consumers live outside this file — confirm semantics there.
    batch_end: Arc<Mutex<HashMap<uuid::Uuid, (WebSocketRequest, crate::executor::SourceRange)>>>,
    /// Cached default planes: filled lazily by `default_planes` and recreated by
    /// `clear_scene_post_hook`.
    default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
    /// Most recent session data reported by the engine, written by the reader task.
    session_data: Arc<Mutex<Option<ModelingSessionData>>>,
}
/// Read half of the engine WebSocket; `TcpRead::read` decodes incoming frames into
/// [`WebSocketResponse`]s.
pub struct TcpRead {
    /// The split-off receiving side of the WebSocket stream.
    stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
}
/// Why reading the next message from the engine's WebSocket failed.
///
/// Derives `Debug` so that results carrying this error (for example the reader task's
/// `JoinHandle` output) can be logged with `{:?}` or unwrapped.
#[derive(Debug)]
pub enum WebSocketReadError {
    /// A tungstenite-level error; `TcpRead::read` uses this for protocol errors.
    Read(tokio_tungstenite::tungstenite::Error),
    /// Any other failure (end of stream, other transport errors, decode failures,
    /// unexpected frames), wrapped in `anyhow`.
    Deser(anyhow::Error),
}

/// Lets `?` turn an `anyhow::Error` into `WebSocketReadError::Deser`.
impl From<anyhow::Error> for WebSocketReadError {
    fn from(e: anyhow::Error) -> Self {
        Self::Deser(e)
    }
}
impl TcpRead {
pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
let Some(msg) = self.stream.next().await else {
return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
};
let msg = match msg {
Ok(msg) => msg,
Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
return Err(WebSocketReadError::Read(e))
}
Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
};
let msg: WebSocketResponse = match msg {
WsMsg::Text(text) => serde_json::from_str(&text)
.map_err(anyhow::Error::from)
.map_err(WebSocketReadError::from)?,
WsMsg::Binary(bin) => bson::from_slice(&bin)
.map_err(anyhow::Error::from)
.map_err(WebSocketReadError::from)?,
other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
};
Ok(msg)
}
}
/// Owns the engine reader task spawned in `EngineConnection::new`; dropping it aborts the task.
pub struct TcpReadHandle {
    /// Join handle of the reader task, which ends with `Err` after the first failed read.
    handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
}
impl std::fmt::Debug for TcpReadHandle {
    /// Prints only the type name, not the wrapped task handle.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("TcpReadHandle")
    }
}
impl Drop for TcpReadHandle {
    fn drop(&mut self) {
        // Dropping a JoinHandle only detaches its task, so stop the reader explicitly.
        let reader_task = &*self.handle;
        reader_task.abort();
    }
}
/// A request for the writer task, plus a channel on which the writer reports whether the
/// request was successfully written to the socket.
struct ToEngineReq {
    /// The request to send to the engine.
    req: WebSocketRequest,
    /// Receives the outcome of writing `req` to the socket (not the engine's reply to it).
    request_sent: oneshot::Sender<Result<()>>,
}
impl EngineConnection {
async fn start_write_actor(mut tcp_write: WebSocketTcpWrite, mut engine_req_rx: mpsc::Receiver<ToEngineReq>) {
while let Some(req) = engine_req_rx.recv().await {
let ToEngineReq { req, request_sent } = req;
let res = if let kittycad::types::WebSocketRequest::ModelingCmdReq {
cmd: kittycad::types::ModelingCmd::ImportFiles { .. },
cmd_id: _,
} = &req
{
Self::inner_send_to_engine_binary(req, &mut tcp_write).await
} else {
Self::inner_send_to_engine(req, &mut tcp_write).await
};
let _ = request_sent.send(res);
}
let _ = Self::inner_close_engine(&mut tcp_write).await;
}
async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
tcp_write
.send(WsMsg::Close(None))
.await
.map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
Ok(())
}
async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
tcp_write
.send(WsMsg::Text(msg))
.await
.map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
Ok(())
}
async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
let msg = bson::to_vec(&request).map_err(|e| anyhow!("could not serialize bson: {e}"))?;
tcp_write
.send(WsMsg::Binary(msg))
.await
.map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
Ok(())
}
pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
max_message_size: Some(0x100000000),
max_frame_size: Some(0x100000000),
..Default::default()
};
let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
ws,
tokio_tungstenite::tungstenite::protocol::Role::Client,
Some(wsconfig),
)
.await;
let (tcp_write, tcp_read) = ws_stream.split();
let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx));
let mut tcp_read = TcpRead { stream: tcp_read };
let session_data: Arc<Mutex<Option<ModelingSessionData>>> = Arc::new(Mutex::new(None));
let session_data2 = session_data.clone();
let responses: Arc<DashMap<uuid::Uuid, WebSocketResponse>> = Arc::new(DashMap::new());
let responses_clone = responses.clone();
let socket_health = Arc::new(Mutex::new(SocketHealth::Active));
let socket_health_tcp_read = socket_health.clone();
let tcp_read_handle = tokio::spawn(async move {
loop {
match tcp_read.read().await {
Ok(ws_resp) => {
match &ws_resp.resp {
Some(kittycad::types::OkWebSocketResponseData::ModelingBatch { responses }) => {
for (resp_id, batch_response) in responses {
let id: uuid::Uuid = resp_id.parse().unwrap();
if let Some(response) = &batch_response.response {
responses_clone.insert(
id,
kittycad::types::WebSocketResponse {
request_id: Some(id),
resp: Some(kittycad::types::OkWebSocketResponseData::Modeling {
modeling_response: response.clone(),
}),
errors: None,
success: Some(true),
},
);
} else {
responses_clone.insert(
id,
kittycad::types::WebSocketResponse {
request_id: Some(id),
resp: None,
errors: batch_response.errors.clone(),
success: Some(false),
},
);
}
}
}
Some(kittycad::types::OkWebSocketResponseData::ModelingSessionData { session }) => {
let mut sd = session_data2.lock().unwrap();
sd.replace(session.clone());
}
_ => {}
}
if let Some(id) = ws_resp.request_id {
responses_clone.insert(id, ws_resp.clone());
}
}
Err(e) => {
match &e {
WebSocketReadError::Read(e) => eprintln!("could not read from WS: {:?}", e),
WebSocketReadError::Deser(e) => eprintln!("could not deserialize msg from WS: {:?}", e),
}
*socket_health_tcp_read.lock().unwrap() = SocketHealth::Inactive;
return Err(e);
}
}
}
});
Ok(EngineConnection {
engine_req_tx,
tcp_read_handle: Arc::new(TcpReadHandle {
handle: Arc::new(tcp_read_handle),
}),
responses,
socket_health,
batch: Arc::new(Mutex::new(Vec::new())),
batch_end: Arc::new(Mutex::new(HashMap::new())),
default_planes: Default::default(),
session_data,
})
}
}
#[async_trait::async_trait]
impl EngineManager for EngineConnection {
fn batch(&self) -> Arc<Mutex<Vec<(WebSocketRequest, crate::executor::SourceRange)>>> {
self.batch.clone()
}
fn batch_end(&self) -> Arc<Mutex<HashMap<uuid::Uuid, (WebSocketRequest, crate::executor::SourceRange)>>> {
self.batch_end.clone()
}
async fn default_planes(&self, source_range: crate::executor::SourceRange) -> Result<DefaultPlanes, KclError> {
{
let opt = self.default_planes.read().await.as_ref().cloned();
if let Some(planes) = opt {
return Ok(planes);
}
} let new_planes = self.new_default_planes(source_range).await?;
*self.default_planes.write().await = Some(new_planes.clone());
Ok(new_planes)
}
async fn clear_scene_post_hook(&self, source_range: crate::executor::SourceRange) -> Result<(), KclError> {
let new_planes = self.new_default_planes(source_range).await?;
*self.default_planes.write().await = Some(new_planes);
Ok(())
}
async fn inner_send_modeling_cmd(
&self,
id: uuid::Uuid,
source_range: crate::executor::SourceRange,
cmd: kittycad::types::WebSocketRequest,
_id_to_source_range: std::collections::HashMap<uuid::Uuid, crate::executor::SourceRange>,
) -> Result<WebSocketResponse, KclError> {
let (tx, rx) = oneshot::channel();
self.engine_req_tx
.send(ToEngineReq {
req: cmd.clone(),
request_sent: tx,
})
.await
.map_err(|e| {
KclError::Engine(KclErrorDetails {
message: format!("Failed to send modeling command: {}", e),
source_ranges: vec![source_range],
})
})?;
rx.await
.map_err(|e| {
KclError::Engine(KclErrorDetails {
message: format!("could not send request to the engine actor: {e}"),
source_ranges: vec![source_range],
})
})?
.map_err(|e| {
KclError::Engine(KclErrorDetails {
message: format!("could not send request to the engine: {e}"),
source_ranges: vec![source_range],
})
})?;
let current_time = std::time::Instant::now();
while current_time.elapsed().as_secs() < 60 {
if let Ok(guard) = self.socket_health.lock() {
if *guard == SocketHealth::Inactive {
return Err(KclError::Engine(KclErrorDetails {
message: "Modeling command failed: websocket closed early".to_string(),
source_ranges: vec![source_range],
}));
}
}
if let Some((_, resp)) = self.responses.remove(&id) {
return Ok(resp);
}
}
Err(KclError::Engine(KclErrorDetails {
message: format!("Modeling command timed out `{}`", id),
source_ranges: vec![source_range],
}))
}
fn get_session_data(&self) -> Option<ModelingSessionData> {
self.session_data.lock().unwrap().clone()
}
}