1use std::{collections::HashMap, sync::Arc};
5
6use anyhow::{anyhow, Result};
7use futures::{SinkExt, StreamExt};
8use indexmap::IndexMap;
9use kcmc::{
10 websocket::{
11 BatchResponse, FailureWebSocketResponse, ModelingCmdReq, ModelingSessionData, OkWebSocketResponseData,
12 SuccessWebSocketResponse, WebSocketRequest, WebSocketResponse,
13 },
14 ModelingCmd,
15};
16use kittycad_modeling_cmds::{self as kcmc};
17use tokio::sync::{mpsc, oneshot, RwLock};
18use tokio_tungstenite::tungstenite::Message as WsMsg;
19use uuid::Uuid;
20
21use super::ExecutionKind;
22use crate::{
23 engine::EngineManager,
24 errors::{KclError, KclErrorDetails},
25 execution::{ArtifactCommand, DefaultPlanes, IdGenerator},
26 SourceRange,
27};
28
/// Whether the reader task is still receiving messages from the engine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SocketHealth {
    /// The reader task is running.
    Active,
    /// The reader task hit an error and stopped; no further responses will be stored.
    Inactive,
}
34
/// The write half of the engine WebSocket, produced by splitting the upgraded stream.
type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
/// A connection to the modeling engine over a WebSocket.
///
/// Outgoing requests are handed to a writer actor task; a separate reader task stores
/// incoming responses in the shared maps below, which request senders poll.
#[derive(Debug)]
pub struct EngineConnection {
    /// Queue of requests for the writer actor.
    engine_req_tx: mpsc::Sender<ToEngineReq>,
    /// Signals the writer actor to send a close frame and stop.
    shutdown_tx: mpsc::Sender<()>,
    /// Responses received from the engine, keyed by request id.
    responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
    /// Deduplicated error messages from engine failures that did not name a request.
    pending_errors: Arc<RwLock<Vec<String>>>,
    /// Handle to the reader task; dropping the handle aborts the task.
    #[allow(dead_code)]
    tcp_read_handle: Arc<TcpReadHandle>,
    /// Set to `Inactive` by the reader task when reading fails.
    socket_health: Arc<RwLock<SocketHealth>>,
    /// Exposed through `EngineManager::batch`. NOTE(review): presumably commands queued to be
    /// sent together as one batch — confirm against `EngineManager`.
    batch: Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>>,
    /// Exposed through `EngineManager::batch_end`; keyed by command id.
    batch_end: Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>>,
    /// Exposed through `EngineManager::artifact_commands`.
    artifact_commands: Arc<RwLock<Vec<ArtifactCommand>>>,

    /// Default planes, created on first use and cached.
    default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
    /// Session data most recently reported by the engine.
    session_data: Arc<RwLock<Option<ModelingSessionData>>>,

    /// Current execution kind, read and swapped through `EngineManager`.
    execution_kind: Arc<RwLock<ExecutionKind>>,
}
56
/// The read half of the engine WebSocket.
pub struct TcpRead {
    stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
}
60
61pub enum WebSocketReadError {
64 Read(tokio_tungstenite::tungstenite::Error),
66 Deser(anyhow::Error),
68}
69
70impl From<anyhow::Error> for WebSocketReadError {
71 fn from(e: anyhow::Error) -> Self {
72 Self::Deser(e)
73 }
74}
75
76impl TcpRead {
77 pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
78 let Some(msg) = self.stream.next().await else {
79 return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
80 };
81 let msg = match msg {
82 Ok(msg) => msg,
83 Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
84 return Err(WebSocketReadError::Read(e))
85 }
86 Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
87 };
88 let msg: WebSocketResponse = match msg {
89 WsMsg::Text(text) => serde_json::from_str(&text)
90 .map_err(anyhow::Error::from)
91 .map_err(WebSocketReadError::from)?,
92 WsMsg::Binary(bin) => bson::from_slice(&bin)
93 .map_err(anyhow::Error::from)
94 .map_err(WebSocketReadError::from)?,
95 other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
96 };
97 Ok(msg)
98 }
99}
100
/// Owns the background task that reads engine responses; dropping this aborts the task.
pub struct TcpReadHandle {
    handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
}
104
105impl std::fmt::Debug for TcpReadHandle {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 write!(f, "TcpReadHandle")
108 }
109}
110
111impl Drop for TcpReadHandle {
112 fn drop(&mut self) {
113 self.handle.abort();
115 }
116}
117
/// A request queued for the writer actor.
struct ToEngineReq {
    /// The request to write to the engine socket.
    req: WebSocketRequest,
    /// Receives the result of writing `req` to the socket (not the engine's reply).
    request_sent: oneshot::Sender<Result<()>>,
}
127
128impl EngineConnection {
129 async fn start_write_actor(
131 mut tcp_write: WebSocketTcpWrite,
132 mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
133 mut shutdown_rx: mpsc::Receiver<()>,
134 ) {
135 loop {
136 tokio::select! {
137 maybe_req = engine_req_rx.recv() => {
138 match maybe_req {
139 Some(ToEngineReq { req, request_sent }) => {
140 let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
143 cmd: ModelingCmd::ImportFiles { .. },
144 cmd_id: _,
145 }) = &req
146 {
147 Self::inner_send_to_engine_binary(req, &mut tcp_write).await
148 } else {
149 Self::inner_send_to_engine(req, &mut tcp_write).await
150 };
151
152 let _ = request_sent.send(res);
154 }
155 None => {
156 break;
159 }
160 }
161 },
162
163 _ = shutdown_rx.recv() => {
165 let _ = Self::inner_close_engine(&mut tcp_write).await;
166 return;
167 }
168 }
169 }
170
171 let _ = Self::inner_close_engine(&mut tcp_write).await;
174 }
175
176 async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
178 tcp_write
179 .send(WsMsg::Close(None))
180 .await
181 .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
182 Ok(())
183 }
184
185 async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
187 let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
188 tcp_write
189 .send(WsMsg::Text(msg))
190 .await
191 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
192 Ok(())
193 }
194
195 async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
197 let msg = bson::to_vec(&request).map_err(|e| anyhow!("could not serialize bson: {e}"))?;
198 tcp_write
199 .send(WsMsg::Binary(msg))
200 .await
201 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
202 Ok(())
203 }
204
205 pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
206 let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
207 max_message_size: Some(usize::MAX),
209 max_frame_size: Some(usize::MAX),
210 ..Default::default()
211 };
212
213 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
214 ws,
215 tokio_tungstenite::tungstenite::protocol::Role::Client,
216 Some(wsconfig),
217 )
218 .await;
219
220 let (tcp_write, tcp_read) = ws_stream.split();
221 let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
222 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
223 tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));
224
225 let mut tcp_read = TcpRead { stream: tcp_read };
226
227 let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
228 let session_data2 = session_data.clone();
229 let responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>> = Arc::new(RwLock::new(IndexMap::new()));
230 let responses_clone = responses.clone();
231 let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
232 let pending_errors = Arc::new(RwLock::new(Vec::new()));
233 let pending_errors_clone = pending_errors.clone();
234
235 let socket_health_tcp_read = socket_health.clone();
236 let tcp_read_handle = tokio::spawn(async move {
237 loop {
239 match tcp_read.read().await {
240 Ok(ws_resp) => {
241 let id = ws_resp.request_id();
243 match &ws_resp {
244 WebSocketResponse::Success(SuccessWebSocketResponse {
245 resp: OkWebSocketResponseData::ModelingBatch { responses },
246 ..
247 }) =>
248 {
249 #[expect(
250 clippy::iter_over_hash_type,
251 reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
252 )]
253 for (resp_id, batch_response) in responses {
254 let id: uuid::Uuid = (*resp_id).into();
255 match batch_response {
256 BatchResponse::Success { response } => {
257 responses_clone.write().await.insert(
258 id,
259 WebSocketResponse::Success(SuccessWebSocketResponse {
260 success: true,
261 request_id: Some(id),
262 resp: OkWebSocketResponseData::Modeling {
263 modeling_response: response.clone(),
264 },
265 }),
266 );
267 }
268 BatchResponse::Failure { errors } => {
269 responses_clone.write().await.insert(
270 id,
271 WebSocketResponse::Failure(FailureWebSocketResponse {
272 success: false,
273 request_id: Some(id),
274 errors: errors.clone(),
275 }),
276 );
277 }
278 }
279 }
280 }
281 WebSocketResponse::Success(SuccessWebSocketResponse {
282 resp: OkWebSocketResponseData::ModelingSessionData { session },
283 ..
284 }) => {
285 let mut sd = session_data2.write().await;
286 sd.replace(session.clone());
287 }
288 WebSocketResponse::Failure(FailureWebSocketResponse {
289 success: _,
290 request_id,
291 errors,
292 }) => {
293 if let Some(id) = request_id {
294 responses_clone.write().await.insert(
295 *id,
296 WebSocketResponse::Failure(FailureWebSocketResponse {
297 success: false,
298 request_id: *request_id,
299 errors: errors.clone(),
300 }),
301 );
302 } else {
303 let mut pe = pending_errors_clone.write().await;
305 for error in errors {
306 if !pe.contains(&error.message) {
307 pe.push(error.message.clone());
308 }
309 }
310 drop(pe);
311 }
312 }
313 _ => {}
314 }
315
316 if let Some(id) = id {
317 responses_clone.write().await.insert(id, ws_resp.clone());
318 }
319 }
320 Err(e) => {
321 match &e {
322 WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
323 WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
324 }
325 *socket_health_tcp_read.write().await = SocketHealth::Inactive;
326 return Err(e);
327 }
328 }
329 }
330 });
331
332 Ok(EngineConnection {
333 engine_req_tx,
334 shutdown_tx,
335 tcp_read_handle: Arc::new(TcpReadHandle {
336 handle: Arc::new(tcp_read_handle),
337 }),
338 responses,
339 pending_errors,
340 socket_health,
341 batch: Arc::new(RwLock::new(Vec::new())),
342 batch_end: Arc::new(RwLock::new(IndexMap::new())),
343 artifact_commands: Arc::new(RwLock::new(Vec::new())),
344 default_planes: Default::default(),
345 session_data,
346 execution_kind: Default::default(),
347 })
348 }
349}
350
351#[async_trait::async_trait]
352impl EngineManager for EngineConnection {
353 fn batch(&self) -> Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>> {
354 self.batch.clone()
355 }
356
357 fn batch_end(&self) -> Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>> {
358 self.batch_end.clone()
359 }
360
361 fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
362 self.responses.clone()
363 }
364
365 fn artifact_commands(&self) -> Arc<RwLock<Vec<ArtifactCommand>>> {
366 self.artifact_commands.clone()
367 }
368
369 async fn execution_kind(&self) -> ExecutionKind {
370 let guard = self.execution_kind.read().await;
371 *guard
372 }
373
374 async fn replace_execution_kind(&self, execution_kind: ExecutionKind) -> ExecutionKind {
375 let mut guard = self.execution_kind.write().await;
376 let original = *guard;
377 *guard = execution_kind;
378 original
379 }
380
381 async fn default_planes(
382 &self,
383 id_generator: &mut IdGenerator,
384 source_range: SourceRange,
385 ) -> Result<DefaultPlanes, KclError> {
386 {
387 let opt = self.default_planes.read().await.as_ref().cloned();
388 if let Some(planes) = opt {
389 return Ok(planes);
390 }
391 } let new_planes = self.new_default_planes(id_generator, source_range).await?;
394 *self.default_planes.write().await = Some(new_planes.clone());
395
396 Ok(new_planes)
397 }
398
399 async fn clear_scene_post_hook(
400 &self,
401 id_generator: &mut IdGenerator,
402 source_range: SourceRange,
403 ) -> Result<(), KclError> {
404 let new_planes = self.new_default_planes(id_generator, source_range).await?;
406 *self.default_planes.write().await = Some(new_planes);
407
408 Ok(())
409 }
410
411 async fn inner_send_modeling_cmd(
412 &self,
413 id: uuid::Uuid,
414 source_range: SourceRange,
415 cmd: WebSocketRequest,
416 _id_to_source_range: HashMap<Uuid, SourceRange>,
417 ) -> Result<WebSocketResponse, KclError> {
418 let (tx, rx) = oneshot::channel();
419
420 self.engine_req_tx
422 .send(ToEngineReq {
423 req: cmd.clone(),
424 request_sent: tx,
425 })
426 .await
427 .map_err(|e| {
428 KclError::Engine(KclErrorDetails {
429 message: format!("Failed to send modeling command: {}", e),
430 source_ranges: vec![source_range],
431 })
432 })?;
433
434 rx.await
436 .map_err(|e| {
437 KclError::Engine(KclErrorDetails {
438 message: format!("could not send request to the engine actor: {e}"),
439 source_ranges: vec![source_range],
440 })
441 })?
442 .map_err(|e| {
443 KclError::Engine(KclErrorDetails {
444 message: format!("could not send request to the engine: {e}"),
445 source_ranges: vec![source_range],
446 })
447 })?;
448
449 let current_time = std::time::Instant::now();
451 while current_time.elapsed().as_secs() < 60 {
452 let guard = self.socket_health.read().await;
453 if *guard == SocketHealth::Inactive {
454 let pe = self.pending_errors.read().await;
456 if !pe.is_empty() {
457 return Err(KclError::Engine(KclErrorDetails {
458 message: pe.join(", ").to_string(),
459 source_ranges: vec![source_range],
460 }));
461 } else {
462 return Err(KclError::Engine(KclErrorDetails {
463 message: "Modeling command failed: websocket closed early".to_string(),
464 source_ranges: vec![source_range],
465 }));
466 }
467 }
468 if let Some(resp) = self.responses.write().await.shift_remove(&id) {
470 return Ok(resp);
471 }
472 }
473
474 Err(KclError::Engine(KclErrorDetails {
475 message: format!("Modeling command timed out `{}`", id),
476 source_ranges: vec![source_range],
477 }))
478 }
479
480 async fn get_session_data(&self) -> Option<ModelingSessionData> {
481 self.session_data.read().await.clone()
482 }
483
484 async fn close(&self) {
485 let _ = self.shutdown_tx.send(()).await;
486 loop {
487 let guard = self.socket_health.read().await;
488 if *guard == SocketHealth::Inactive {
489 return;
490 }
491 }
492 }
493}