1use std::{collections::HashMap, sync::Arc};
5
6use anyhow::{Result, anyhow};
7use futures::{SinkExt, StreamExt};
8use indexmap::IndexMap;
9use kcmc::{
10 ModelingCmd,
11 websocket::{
12 BatchResponse, FailureWebSocketResponse, ModelingCmdReq, ModelingSessionData, OkWebSocketResponseData,
13 SuccessWebSocketResponse, WebSocketRequest, WebSocketResponse,
14 },
15};
16use kittycad_modeling_cmds::{self as kcmc};
17use tokio::sync::{RwLock, mpsc, oneshot};
18use tokio_tungstenite::tungstenite::Message as WsMsg;
19use uuid::Uuid;
20
21use crate::{
22 SourceRange,
23 engine::{AsyncTasks, EngineManager, EngineStats},
24 errors::{KclError, KclErrorDetails},
25 execution::{DefaultPlanes, IdGenerator},
26};
27
/// Whether the engine WebSocket can still deliver responses.
///
/// Starts as `Active`; the reader task switches it to `Inactive` once reading from the socket fails.
#[derive(PartialEq, Debug)]
enum SocketHealth {
    /// The socket has not reported a read failure.
    Active,
    /// Reading failed; the reader task has stopped and no further responses will arrive.
    Inactive,
}
33
34type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
35#[derive(Debug)]
36pub struct EngineConnection {
37 engine_req_tx: mpsc::Sender<ToEngineReq>,
38 shutdown_tx: mpsc::Sender<()>,
39 responses: ResponseInformation,
40 pending_errors: Arc<RwLock<Vec<String>>>,
41 #[allow(dead_code)]
42 tcp_read_handle: Arc<TcpReadHandle>,
43 socket_health: Arc<RwLock<SocketHealth>>,
44 batch: Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>>,
45 batch_end: Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>>,
46 ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>>,
47
48 default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
50 session_data: Arc<RwLock<Option<ModelingSessionData>>>,
52
53 stats: EngineStats,
54
55 async_tasks: AsyncTasks,
56
57 debug_info: Arc<RwLock<Option<OkWebSocketResponseData>>>,
58}
59
60pub struct TcpRead {
61 stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
62}
63
64#[allow(clippy::large_enum_variant)]
67pub enum WebSocketReadError {
68 Read(tokio_tungstenite::tungstenite::Error),
70 Deser(anyhow::Error),
72}
73
74impl From<anyhow::Error> for WebSocketReadError {
75 fn from(e: anyhow::Error) -> Self {
76 Self::Deser(e)
77 }
78}
79
80impl TcpRead {
81 pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
82 let Some(msg) = self.stream.next().await else {
83 return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
84 };
85 let msg = match msg {
86 Ok(msg) => msg,
87 Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
88 return Err(WebSocketReadError::Read(e));
89 }
90 Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
91 };
92 let msg: WebSocketResponse = match msg {
93 WsMsg::Text(text) => serde_json::from_str(&text)
94 .map_err(anyhow::Error::from)
95 .map_err(WebSocketReadError::from)?,
96 WsMsg::Binary(bin) => match rmp_serde::from_slice(&bin) {
97 Ok(resp) => resp,
98 Err(_) => bson::from_slice(&bin)
99 .map_err(anyhow::Error::from)
100 .map_err(WebSocketReadError::from)?,
101 },
102 other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
103 };
104 Ok(msg)
105 }
106}
107
108pub struct TcpReadHandle {
109 handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
110}
111
112impl std::fmt::Debug for TcpReadHandle {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 write!(f, "TcpReadHandle")
115 }
116}
117
118impl Drop for TcpReadHandle {
119 fn drop(&mut self) {
120 self.handle.abort();
122 }
123}
124
125#[derive(Clone, Debug)]
127struct ResponseInformation {
128 responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
130}
131
132impl ResponseInformation {
133 pub async fn add(&self, id: Uuid, response: WebSocketResponse) {
134 self.responses.write().await.insert(id, response);
135 }
136}
137
138struct ToEngineReq {
140 req: WebSocketRequest,
142 request_sent: oneshot::Sender<Result<()>>,
146}
147
148impl EngineConnection {
149 async fn start_write_actor(
151 mut tcp_write: WebSocketTcpWrite,
152 mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
153 mut shutdown_rx: mpsc::Receiver<()>,
154 ) {
155 loop {
156 tokio::select! {
157 maybe_req = engine_req_rx.recv() => {
158 match maybe_req {
159 Some(ToEngineReq { req, request_sent }) => {
160 let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
163 cmd: ModelingCmd::ImportFiles { .. },
164 cmd_id: _,
165 }) = &req
166 {
167 Self::inner_send_to_engine_binary(req, &mut tcp_write).await
168 } else {
169 Self::inner_send_to_engine(req, &mut tcp_write).await
170 };
171
172 let _ = request_sent.send(res);
174 }
175 None => {
176 break;
179 }
180 }
181 },
182
183 _ = shutdown_rx.recv() => {
185 let _ = Self::inner_close_engine(&mut tcp_write).await;
186 return;
187 }
188 }
189 }
190
191 let _ = Self::inner_close_engine(&mut tcp_write).await;
194 }
195
196 async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
198 tcp_write
199 .send(WsMsg::Close(None))
200 .await
201 .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
202 Ok(())
203 }
204
205 async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
207 let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
208 tcp_write
209 .send(WsMsg::Text(msg.into()))
210 .await
211 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
212 Ok(())
213 }
214
215 async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
217 let msg = bson::to_vec(&request).map_err(|e| anyhow!("could not serialize bson: {e}"))?;
218 tcp_write
219 .send(WsMsg::Binary(msg.into()))
220 .await
221 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
222 Ok(())
223 }
224
225 pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
226 let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
227 .max_message_size(Some(usize::MAX))
229 .max_frame_size(Some(usize::MAX));
230
231 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
232 ws,
233 tokio_tungstenite::tungstenite::protocol::Role::Client,
234 Some(wsconfig),
235 )
236 .await;
237
238 let (tcp_write, tcp_read) = ws_stream.split();
239 let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
240 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
241 tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));
242
243 let mut tcp_read = TcpRead { stream: tcp_read };
244
245 let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
246 let session_data2 = session_data.clone();
247 let ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>> = Arc::new(RwLock::new(IndexMap::new()));
248 let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
249 let pending_errors = Arc::new(RwLock::new(Vec::new()));
250 let pending_errors_clone = pending_errors.clone();
251 let response_information = ResponseInformation {
252 responses: Arc::new(RwLock::new(IndexMap::new())),
253 };
254 let response_information_cloned = response_information.clone();
255 let debug_info = Arc::new(RwLock::new(None));
256 let debug_info_cloned = debug_info.clone();
257
258 let socket_health_tcp_read = socket_health.clone();
259 let tcp_read_handle = tokio::spawn(async move {
260 loop {
262 match tcp_read.read().await {
263 Ok(ws_resp) => {
264 let id = ws_resp.request_id();
266 match &ws_resp {
267 WebSocketResponse::Success(SuccessWebSocketResponse {
268 resp: OkWebSocketResponseData::ModelingBatch { responses },
269 ..
270 }) => {
271 #[expect(
272 clippy::iter_over_hash_type,
273 reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
274 )]
275 for (resp_id, batch_response) in responses {
276 let id: uuid::Uuid = (*resp_id).into();
277 match batch_response {
278 BatchResponse::Success { response } => {
279 response_information_cloned
282 .add(
283 id,
284 WebSocketResponse::Success(SuccessWebSocketResponse {
285 success: true,
286 request_id: Some(id),
287 resp: OkWebSocketResponseData::Modeling {
288 modeling_response: response.clone(),
289 },
290 }),
291 )
292 .await;
293 }
294 BatchResponse::Failure { errors } => {
295 response_information_cloned
296 .add(
297 id,
298 WebSocketResponse::Failure(FailureWebSocketResponse {
299 success: false,
300 request_id: Some(id),
301 errors: errors.clone(),
302 }),
303 )
304 .await;
305 }
306 }
307 }
308 }
309 WebSocketResponse::Success(SuccessWebSocketResponse {
310 resp: OkWebSocketResponseData::ModelingSessionData { session },
311 ..
312 }) => {
313 let mut sd = session_data2.write().await;
314 sd.replace(session.clone());
315 }
316 WebSocketResponse::Failure(FailureWebSocketResponse {
317 success: _,
318 request_id,
319 errors,
320 }) => {
321 if let Some(id) = request_id {
322 response_information_cloned
323 .add(
324 *id,
325 WebSocketResponse::Failure(FailureWebSocketResponse {
326 success: false,
327 request_id: *request_id,
328 errors: errors.clone(),
329 }),
330 )
331 .await;
332 } else {
333 let mut pe = pending_errors_clone.write().await;
335 for error in errors {
336 if !pe.contains(&error.message) {
337 pe.push(error.message.clone());
338 }
339 }
340 drop(pe);
341 }
342 }
343 WebSocketResponse::Success(SuccessWebSocketResponse {
344 resp: debug @ OkWebSocketResponseData::Debug { .. },
345 ..
346 }) => {
347 let mut handle = debug_info_cloned.write().await;
348 *handle = Some(debug.clone());
349 }
350 _ => {}
351 }
352
353 if let Some(id) = id {
354 response_information_cloned.add(id, ws_resp.clone()).await;
355 }
356 }
357 Err(e) => {
358 match &e {
359 WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
360 WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
361 }
362 *socket_health_tcp_read.write().await = SocketHealth::Inactive;
363 return Err(e);
364 }
365 }
366 }
367 });
368
369 Ok(EngineConnection {
370 engine_req_tx,
371 shutdown_tx,
372 tcp_read_handle: Arc::new(TcpReadHandle {
373 handle: Arc::new(tcp_read_handle),
374 }),
375 responses: response_information,
376 pending_errors,
377 socket_health,
378 batch: Arc::new(RwLock::new(Vec::new())),
379 batch_end: Arc::new(RwLock::new(IndexMap::new())),
380 ids_of_async_commands,
381 default_planes: Default::default(),
382 session_data,
383 stats: Default::default(),
384 async_tasks: AsyncTasks::new(),
385 debug_info,
386 })
387 }
388}
389
390#[async_trait::async_trait]
391impl EngineManager for EngineConnection {
392 fn batch(&self) -> Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>> {
393 self.batch.clone()
394 }
395
396 fn batch_end(&self) -> Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>> {
397 self.batch_end.clone()
398 }
399
400 fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
401 self.responses.responses.clone()
402 }
403
404 fn ids_of_async_commands(&self) -> Arc<RwLock<IndexMap<Uuid, SourceRange>>> {
405 self.ids_of_async_commands.clone()
406 }
407
408 fn async_tasks(&self) -> AsyncTasks {
409 self.async_tasks.clone()
410 }
411
412 fn stats(&self) -> &EngineStats {
413 &self.stats
414 }
415
416 fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
417 self.default_planes.clone()
418 }
419
420 async fn get_debug(&self) -> Option<OkWebSocketResponseData> {
421 self.debug_info.read().await.clone()
422 }
423
424 async fn fetch_debug(&self) -> Result<(), KclError> {
425 let (tx, rx) = oneshot::channel();
426
427 self.engine_req_tx
428 .send(ToEngineReq {
429 req: WebSocketRequest::Debug {},
430 request_sent: tx,
431 })
432 .await
433 .map_err(|e| KclError::new_engine(KclErrorDetails::new(format!("Failed to send debug: {e}"), vec![])))?;
434
435 let _ = rx.await;
436 Ok(())
437 }
438
439 async fn clear_scene_post_hook(
440 &self,
441 id_generator: &mut IdGenerator,
442 source_range: SourceRange,
443 ) -> Result<(), KclError> {
444 let new_planes = self.new_default_planes(id_generator, source_range).await?;
446 *self.default_planes.write().await = Some(new_planes);
447
448 Ok(())
449 }
450
451 async fn inner_fire_modeling_cmd(
452 &self,
453 _id: uuid::Uuid,
454 source_range: SourceRange,
455 cmd: WebSocketRequest,
456 _id_to_source_range: HashMap<Uuid, SourceRange>,
457 ) -> Result<(), KclError> {
458 let (tx, rx) = oneshot::channel();
459
460 self.engine_req_tx
462 .send(ToEngineReq {
463 req: cmd.clone(),
464 request_sent: tx,
465 })
466 .await
467 .map_err(|e| {
468 KclError::new_engine(KclErrorDetails::new(
469 format!("Failed to send modeling command: {e}"),
470 vec![source_range],
471 ))
472 })?;
473
474 rx.await
476 .map_err(|e| {
477 KclError::new_engine(KclErrorDetails::new(
478 format!("could not send request to the engine actor: {e}"),
479 vec![source_range],
480 ))
481 })?
482 .map_err(|e| {
483 KclError::new_engine(KclErrorDetails::new(
484 format!("could not send request to the engine: {e}"),
485 vec![source_range],
486 ))
487 })?;
488
489 Ok(())
490 }
491
492 async fn inner_send_modeling_cmd(
493 &self,
494 id: uuid::Uuid,
495 source_range: SourceRange,
496 cmd: WebSocketRequest,
497 id_to_source_range: HashMap<Uuid, SourceRange>,
498 ) -> Result<WebSocketResponse, KclError> {
499 self.inner_fire_modeling_cmd(id, source_range, cmd, id_to_source_range)
500 .await?;
501
502 let response_timeout = 300;
504 let current_time = std::time::Instant::now();
505 while current_time.elapsed().as_secs() < response_timeout {
506 let guard = self.socket_health.read().await;
507 if *guard == SocketHealth::Inactive {
508 let pe = self.pending_errors.read().await;
510 if !pe.is_empty() {
511 return Err(KclError::new_engine(KclErrorDetails::new(
512 pe.join(", "),
513 vec![source_range],
514 )));
515 } else {
516 return Err(KclError::new_engine(KclErrorDetails::new(
517 "Modeling command failed: websocket closed early".to_string(),
518 vec![source_range],
519 )));
520 }
521 }
522
523 #[cfg(feature = "artifact-graph")]
524 {
525 if let Some(resp) = self.responses.responses.read().await.get(&id) {
527 return Ok(resp.clone());
528 }
529 }
530 #[cfg(not(feature = "artifact-graph"))]
531 {
532 if let Some(resp) = self.responses.responses.write().await.shift_remove(&id) {
533 return Ok(resp);
534 }
535 }
536 }
537
538 Err(KclError::new_engine(KclErrorDetails::new(
539 format!("Modeling command timed out `{id}`"),
540 vec![source_range],
541 )))
542 }
543
544 async fn get_session_data(&self) -> Option<ModelingSessionData> {
545 self.session_data.read().await.clone()
546 }
547
548 async fn close(&self) {
549 let _ = self.shutdown_tx.send(()).await;
550 loop {
551 let guard = self.socket_health.read().await;
552 if *guard == SocketHealth::Inactive {
553 return;
554 }
555 }
556 }
557}