1use std::{collections::HashMap, sync::Arc};
5
6use anyhow::{anyhow, Result};
7use futures::{SinkExt, StreamExt};
8use indexmap::IndexMap;
9use kcmc::{
10 websocket::{
11 BatchResponse, FailureWebSocketResponse, ModelingCmdReq, ModelingSessionData, OkWebSocketResponseData,
12 SuccessWebSocketResponse, WebSocketRequest, WebSocketResponse,
13 },
14 ModelingCmd,
15};
16use kittycad_modeling_cmds::{self as kcmc};
17use tokio::sync::{mpsc, oneshot, RwLock};
18use tokio_tungstenite::tungstenite::Message as WsMsg;
19use uuid::Uuid;
20
21#[cfg(feature = "artifact-graph")]
22use crate::execution::ArtifactCommand;
23use crate::{
24 engine::{AsyncTasks, EngineManager, EngineStats},
25 errors::{KclError, KclErrorDetails},
26 execution::{DefaultPlanes, IdGenerator},
27 SourceRange,
28};
29
/// Liveness of the engine WebSocket as observed by the background read task.
#[derive(Debug, PartialEq)]
enum SocketHealth {
    /// State a connection starts in when `EngineConnection::new` creates it.
    Active,
    /// Set by the read task once reading or decoding a message fails.
    Inactive,
}
35
/// Sink half obtained by splitting the engine `WebSocketStream`; accepts `WsMsg` frames.
type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
/// A connection to the engine over an upgraded WebSocket.
///
/// Outgoing requests are handed to a write task through `engine_req_tx`; incoming messages
/// are consumed by a read task that fills the shared state held in the fields below.
#[derive(Debug)]
pub struct EngineConnection {
    /// Queue of requests for the write task, which owns the socket's write half.
    engine_req_tx: mpsc::Sender<ToEngineReq>,
    /// Tells the write task to send a close frame and stop.
    shutdown_tx: mpsc::Sender<()>,
    /// Responses received from the engine, keyed by request id.
    responses: ResponseInformation,
    /// Messages of failure responses that arrived without a request id (de-duplicated).
    pending_errors: Arc<RwLock<Vec<String>>>,
    /// Never read directly; holding it keeps the read task alive, and `TcpReadHandle`'s
    /// `Drop` aborts that task.
    #[allow(dead_code)]
    tcp_read_handle: Arc<TcpReadHandle>,
    /// Flipped to `Inactive` by the read task when it stops because of an error.
    socket_health: Arc<RwLock<SocketHealth>>,
    /// Exposed through `EngineManager::batch`.
    // NOTE(review): presumably commands queued for the next batch — confirm against the trait.
    batch: Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>>,
    /// Exposed through `EngineManager::batch_end`.
    // NOTE(review): presumably commands deferred to the end of the batch — confirm against the trait.
    batch_end: Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>>,
    /// Exposed through `EngineManager::artifact_commands` (only with the `artifact-graph` feature).
    #[cfg(feature = "artifact-graph")]
    artifact_commands: Arc<RwLock<Vec<ArtifactCommand>>>,
    /// Exposed through `EngineManager::ids_of_async_commands`.
    ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>>,

    /// Default planes; replaced by `clear_scene_post_hook`.
    default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
    /// Most recent `ModelingSessionData` message received from the engine, if any.
    session_data: Arc<RwLock<Option<ModelingSessionData>>>,

    /// Exposed through `EngineManager::stats`.
    stats: EngineStats,

    /// Exposed through `EngineManager::async_tasks`.
    async_tasks: AsyncTasks,

    /// Most recent `Debug` response received from the engine, if any.
    debug_info: Arc<RwLock<Option<OkWebSocketResponseData>>>,
}
63
/// Read half of the engine WebSocket; see `TcpRead::read` for message decoding.
pub struct TcpRead {
    /// Stream half obtained by splitting the engine `WebSocketStream`.
    stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
}
67
/// Error returned by `TcpRead::read`.
#[allow(clippy::large_enum_variant)]
pub enum WebSocketReadError {
    /// A WebSocket error passed through unchanged; `TcpRead::read` uses this for
    /// tungstenite `Protocol` errors only.
    Read(tokio_tungstenite::tungstenite::Error),
    /// Any failure expressed as an `anyhow::Error`. Besides JSON/BSON decoding failures,
    /// `TcpRead::read` also uses this for an ended stream, non-protocol socket errors and
    /// unexpected message types.
    Deser(anyhow::Error),
}
77
78impl From<anyhow::Error> for WebSocketReadError {
79 fn from(e: anyhow::Error) -> Self {
80 Self::Deser(e)
81 }
82}
83
84impl TcpRead {
85 pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
86 let Some(msg) = self.stream.next().await else {
87 return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
88 };
89 let msg = match msg {
90 Ok(msg) => msg,
91 Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
92 return Err(WebSocketReadError::Read(e))
93 }
94 Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
95 };
96 let msg: WebSocketResponse = match msg {
97 WsMsg::Text(text) => serde_json::from_str(&text)
98 .map_err(anyhow::Error::from)
99 .map_err(WebSocketReadError::from)?,
100 WsMsg::Binary(bin) => bson::from_slice(&bin)
101 .map_err(anyhow::Error::from)
102 .map_err(WebSocketReadError::from)?,
103 other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
104 };
105 Ok(msg)
106 }
107}
108
/// Owner of the background read task; dropping it aborts the task (see the `Drop` impl).
pub struct TcpReadHandle {
    /// Join handle of the spawned read task, which resolves to the task's result.
    handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
}
112
113impl std::fmt::Debug for TcpReadHandle {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 write!(f, "TcpReadHandle")
116 }
117}
118
119impl Drop for TcpReadHandle {
120 fn drop(&mut self) {
121 self.handle.abort();
123 }
124}
125
/// Cheaply cloneable handle to the map of responses received from the engine.
#[derive(Clone, Debug)]
struct ResponseInformation {
    /// Responses keyed by the id of the request they answer.
    responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
}
132
133impl ResponseInformation {
134 pub async fn add(&self, id: Uuid, response: WebSocketResponse) {
135 self.responses.write().await.insert(id, response);
136 }
137}
138
/// A request handed to the write task together with a channel for reporting the outcome.
struct ToEngineReq {
    /// The request to write to the engine WebSocket.
    req: WebSocketRequest,
    /// Receives the result of writing `req` to the socket (not the engine's response).
    request_sent: oneshot::Sender<Result<()>>,
}
148
149impl EngineConnection {
150 async fn start_write_actor(
152 mut tcp_write: WebSocketTcpWrite,
153 mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
154 mut shutdown_rx: mpsc::Receiver<()>,
155 ) {
156 loop {
157 tokio::select! {
158 maybe_req = engine_req_rx.recv() => {
159 match maybe_req {
160 Some(ToEngineReq { req, request_sent }) => {
161 let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
164 cmd: ModelingCmd::ImportFiles { .. },
165 cmd_id: _,
166 }) = &req
167 {
168 Self::inner_send_to_engine_binary(req, &mut tcp_write).await
169 } else {
170 Self::inner_send_to_engine(req, &mut tcp_write).await
171 };
172
173 let _ = request_sent.send(res);
175 }
176 None => {
177 break;
180 }
181 }
182 },
183
184 _ = shutdown_rx.recv() => {
186 let _ = Self::inner_close_engine(&mut tcp_write).await;
187 return;
188 }
189 }
190 }
191
192 let _ = Self::inner_close_engine(&mut tcp_write).await;
195 }
196
197 async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
199 tcp_write
200 .send(WsMsg::Close(None))
201 .await
202 .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
203 Ok(())
204 }
205
206 async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
208 let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
209 tcp_write
210 .send(WsMsg::Text(msg.into()))
211 .await
212 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
213 Ok(())
214 }
215
216 async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
218 let msg = bson::to_vec(&request).map_err(|e| anyhow!("could not serialize bson: {e}"))?;
219 tcp_write
220 .send(WsMsg::Binary(msg.into()))
221 .await
222 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
223 Ok(())
224 }
225
226 pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
227 let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
228 .max_message_size(Some(usize::MAX))
230 .max_frame_size(Some(usize::MAX));
231
232 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
233 ws,
234 tokio_tungstenite::tungstenite::protocol::Role::Client,
235 Some(wsconfig),
236 )
237 .await;
238
239 let (tcp_write, tcp_read) = ws_stream.split();
240 let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
241 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
242 tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));
243
244 let mut tcp_read = TcpRead { stream: tcp_read };
245
246 let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
247 let session_data2 = session_data.clone();
248 let ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>> = Arc::new(RwLock::new(IndexMap::new()));
249 let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
250 let pending_errors = Arc::new(RwLock::new(Vec::new()));
251 let pending_errors_clone = pending_errors.clone();
252 let response_information = ResponseInformation {
253 responses: Arc::new(RwLock::new(IndexMap::new())),
254 };
255 let response_information_cloned = response_information.clone();
256 let debug_info = Arc::new(RwLock::new(None));
257 let debug_info_cloned = debug_info.clone();
258
259 let socket_health_tcp_read = socket_health.clone();
260 let tcp_read_handle = tokio::spawn(async move {
261 loop {
263 match tcp_read.read().await {
264 Ok(ws_resp) => {
265 let id = ws_resp.request_id();
267 match &ws_resp {
268 WebSocketResponse::Success(SuccessWebSocketResponse {
269 resp: OkWebSocketResponseData::ModelingBatch { responses },
270 ..
271 }) => {
272 #[expect(
273 clippy::iter_over_hash_type,
274 reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
275 )]
276 for (resp_id, batch_response) in responses {
277 let id: uuid::Uuid = (*resp_id).into();
278 match batch_response {
279 BatchResponse::Success { response } => {
280 response_information_cloned
283 .add(
284 id,
285 WebSocketResponse::Success(SuccessWebSocketResponse {
286 success: true,
287 request_id: Some(id),
288 resp: OkWebSocketResponseData::Modeling {
289 modeling_response: response.clone(),
290 },
291 }),
292 )
293 .await;
294 }
295 BatchResponse::Failure { errors } => {
296 response_information_cloned
297 .add(
298 id,
299 WebSocketResponse::Failure(FailureWebSocketResponse {
300 success: false,
301 request_id: Some(id),
302 errors: errors.clone(),
303 }),
304 )
305 .await;
306 }
307 }
308 }
309 }
310 WebSocketResponse::Success(SuccessWebSocketResponse {
311 resp: OkWebSocketResponseData::ModelingSessionData { session },
312 ..
313 }) => {
314 let mut sd = session_data2.write().await;
315 sd.replace(session.clone());
316 }
317 WebSocketResponse::Failure(FailureWebSocketResponse {
318 success: _,
319 request_id,
320 errors,
321 }) => {
322 if let Some(id) = request_id {
323 response_information_cloned
324 .add(
325 *id,
326 WebSocketResponse::Failure(FailureWebSocketResponse {
327 success: false,
328 request_id: *request_id,
329 errors: errors.clone(),
330 }),
331 )
332 .await;
333 } else {
334 let mut pe = pending_errors_clone.write().await;
336 for error in errors {
337 if !pe.contains(&error.message) {
338 pe.push(error.message.clone());
339 }
340 }
341 drop(pe);
342 }
343 }
344 WebSocketResponse::Success(SuccessWebSocketResponse {
345 resp: debug @ OkWebSocketResponseData::Debug { .. },
346 ..
347 }) => {
348 let mut handle = debug_info_cloned.write().await;
349 *handle = Some(debug.clone());
350 }
351 _ => {}
352 }
353
354 if let Some(id) = id {
355 response_information_cloned.add(id, ws_resp.clone()).await;
356 }
357 }
358 Err(e) => {
359 match &e {
360 WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
361 WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
362 }
363 *socket_health_tcp_read.write().await = SocketHealth::Inactive;
364 return Err(e);
365 }
366 }
367 }
368 });
369
370 Ok(EngineConnection {
371 engine_req_tx,
372 shutdown_tx,
373 tcp_read_handle: Arc::new(TcpReadHandle {
374 handle: Arc::new(tcp_read_handle),
375 }),
376 responses: response_information,
377 pending_errors,
378 socket_health,
379 batch: Arc::new(RwLock::new(Vec::new())),
380 batch_end: Arc::new(RwLock::new(IndexMap::new())),
381 #[cfg(feature = "artifact-graph")]
382 artifact_commands: Arc::new(RwLock::new(Vec::new())),
383 ids_of_async_commands,
384 default_planes: Default::default(),
385 session_data,
386 stats: Default::default(),
387 async_tasks: AsyncTasks::new(),
388 debug_info,
389 })
390 }
391}
392
393#[async_trait::async_trait]
394impl EngineManager for EngineConnection {
395 fn batch(&self) -> Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>> {
396 self.batch.clone()
397 }
398
399 fn batch_end(&self) -> Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>> {
400 self.batch_end.clone()
401 }
402
403 fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
404 self.responses.responses.clone()
405 }
406
407 #[cfg(feature = "artifact-graph")]
408 fn artifact_commands(&self) -> Arc<RwLock<Vec<ArtifactCommand>>> {
409 self.artifact_commands.clone()
410 }
411
412 fn ids_of_async_commands(&self) -> Arc<RwLock<IndexMap<Uuid, SourceRange>>> {
413 self.ids_of_async_commands.clone()
414 }
415
416 fn async_tasks(&self) -> AsyncTasks {
417 self.async_tasks.clone()
418 }
419
420 fn stats(&self) -> &EngineStats {
421 &self.stats
422 }
423
424 fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
425 self.default_planes.clone()
426 }
427
428 async fn get_debug(&self) -> Option<OkWebSocketResponseData> {
429 self.debug_info.read().await.clone()
430 }
431
432 async fn fetch_debug(&self) -> Result<(), KclError> {
433 let (tx, rx) = oneshot::channel();
434
435 self.engine_req_tx
436 .send(ToEngineReq {
437 req: WebSocketRequest::Debug {},
438 request_sent: tx,
439 })
440 .await
441 .map_err(|e| KclError::new_engine(KclErrorDetails::new(format!("Failed to send debug: {}", e), vec![])))?;
442
443 let _ = rx.await;
444 Ok(())
445 }
446
447 async fn clear_scene_post_hook(
448 &self,
449 id_generator: &mut IdGenerator,
450 source_range: SourceRange,
451 ) -> Result<(), KclError> {
452 let new_planes = self.new_default_planes(id_generator, source_range).await?;
454 *self.default_planes.write().await = Some(new_planes);
455
456 Ok(())
457 }
458
459 async fn inner_fire_modeling_cmd(
460 &self,
461 _id: uuid::Uuid,
462 source_range: SourceRange,
463 cmd: WebSocketRequest,
464 _id_to_source_range: HashMap<Uuid, SourceRange>,
465 ) -> Result<(), KclError> {
466 let (tx, rx) = oneshot::channel();
467
468 self.engine_req_tx
470 .send(ToEngineReq {
471 req: cmd.clone(),
472 request_sent: tx,
473 })
474 .await
475 .map_err(|e| {
476 KclError::new_engine(KclErrorDetails::new(
477 format!("Failed to send modeling command: {}", e),
478 vec![source_range],
479 ))
480 })?;
481
482 rx.await
484 .map_err(|e| {
485 KclError::new_engine(KclErrorDetails::new(
486 format!("could not send request to the engine actor: {e}"),
487 vec![source_range],
488 ))
489 })?
490 .map_err(|e| {
491 KclError::new_engine(KclErrorDetails::new(
492 format!("could not send request to the engine: {e}"),
493 vec![source_range],
494 ))
495 })?;
496
497 Ok(())
498 }
499
500 async fn inner_send_modeling_cmd(
501 &self,
502 id: uuid::Uuid,
503 source_range: SourceRange,
504 cmd: WebSocketRequest,
505 id_to_source_range: HashMap<Uuid, SourceRange>,
506 ) -> Result<WebSocketResponse, KclError> {
507 self.inner_fire_modeling_cmd(id, source_range, cmd, id_to_source_range)
508 .await?;
509
510 let current_time = std::time::Instant::now();
512 while current_time.elapsed().as_secs() < 60 {
513 let guard = self.socket_health.read().await;
514 if *guard == SocketHealth::Inactive {
515 let pe = self.pending_errors.read().await;
517 if !pe.is_empty() {
518 return Err(KclError::new_engine(KclErrorDetails::new(
519 pe.join(", ").to_string(),
520 vec![source_range],
521 )));
522 } else {
523 return Err(KclError::new_engine(KclErrorDetails::new(
524 "Modeling command failed: websocket closed early".to_string(),
525 vec![source_range],
526 )));
527 }
528 }
529
530 #[cfg(feature = "artifact-graph")]
531 {
532 if let Some(resp) = self.responses.responses.read().await.get(&id) {
534 return Ok(resp.clone());
535 }
536 }
537 #[cfg(not(feature = "artifact-graph"))]
538 {
539 if let Some(resp) = self.responses.responses.write().await.shift_remove(&id) {
540 return Ok(resp);
541 }
542 }
543 }
544
545 Err(KclError::new_engine(KclErrorDetails::new(
546 format!("Modeling command timed out `{}`", id),
547 vec![source_range],
548 )))
549 }
550
551 async fn get_session_data(&self) -> Option<ModelingSessionData> {
552 self.session_data.read().await.clone()
553 }
554
555 async fn close(&self) {
556 let _ = self.shutdown_tx.send(()).await;
557 loop {
558 let guard = self.socket_health.read().await;
559 if *guard == SocketHealth::Inactive {
560 return;
561 }
562 }
563 }
564}