1use std::collections::HashMap;
5use std::sync::Arc;
6
7use anyhow::Result;
8use anyhow::anyhow;
9use futures::SinkExt;
10use futures::StreamExt;
11use indexmap::IndexMap;
12use kcmc::ModelingCmd;
13use kcmc::websocket::BatchResponse;
14use kcmc::websocket::FailureWebSocketResponse;
15use kcmc::websocket::ModelingCmdReq;
16use kcmc::websocket::ModelingSessionData;
17use kcmc::websocket::OkWebSocketResponseData;
18use kcmc::websocket::SuccessWebSocketResponse;
19use kcmc::websocket::WebSocketRequest;
20use kcmc::websocket::WebSocketResponse;
21use kittycad_modeling_cmds::{self as kcmc};
22use tokio::sync::RwLock;
23use tokio::sync::mpsc;
24use tokio::sync::oneshot;
25use tokio_tungstenite::tungstenite::Message as WsMsg;
26use uuid::Uuid;
27
28use crate::SourceRange;
29use crate::engine::AsyncTasks;
30use crate::engine::EngineBatchContext;
31use crate::engine::EngineManager;
32use crate::engine::EngineStats;
33use crate::errors::KclError;
34use crate::errors::KclErrorDetails;
35use crate::execution::DefaultPlanes;
36use crate::execution::IdGenerator;
37use crate::log::logln;
38
/// Liveness of the WebSocket connection as observed by the read task.
///
/// A connection starts out `Active`; the read task switches it to `Inactive`
/// once reading from the socket fails.
///
/// The enum is a fieldless flag, so it is `Copy` and has total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SocketHealth {
    /// The read task is still receiving messages.
    Active,
    /// Reading failed; no further responses will arrive.
    Inactive,
}
44
45type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
46#[derive(Debug)]
47pub struct EngineConnection {
48 engine_req_tx: mpsc::Sender<ToEngineReq>,
49 shutdown_tx: mpsc::Sender<()>,
50 responses: ResponseInformation,
51 pending_errors: Arc<RwLock<Vec<String>>>,
52 #[allow(dead_code)]
53 tcp_read_handle: Arc<TcpReadHandle>,
54 socket_health: Arc<RwLock<SocketHealth>>,
55 ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>>,
56
57 default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
59 session_data: Arc<RwLock<Option<ModelingSessionData>>>,
61
62 stats: EngineStats,
63
64 async_tasks: AsyncTasks,
65
66 debug_info: Arc<RwLock<Option<OkWebSocketResponseData>>>,
67}
68
69pub struct TcpRead {
70 stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
71}
72
73#[allow(clippy::large_enum_variant)]
76pub enum WebSocketReadError {
77 Read(tokio_tungstenite::tungstenite::Error),
79 Deser(anyhow::Error),
81}
82
83impl From<anyhow::Error> for WebSocketReadError {
84 fn from(e: anyhow::Error) -> Self {
85 Self::Deser(e)
86 }
87}
88
89impl TcpRead {
90 pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
91 let Some(msg) = self.stream.next().await else {
92 return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
93 };
94 let msg = match msg {
95 Ok(msg) => msg,
96 Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
97 return Err(WebSocketReadError::Read(e));
98 }
99 Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
100 };
101 let msg: WebSocketResponse = match msg {
102 WsMsg::Text(text) => serde_json::from_str(&text)
103 .map_err(anyhow::Error::from)
104 .map_err(WebSocketReadError::from)?,
105 WsMsg::Binary(bin) => rmp_serde::from_slice(&bin)
106 .map_err(anyhow::Error::from)
107 .map_err(WebSocketReadError::from)?,
108 other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
109 };
110 Ok(msg)
111 }
112}
113
114pub struct TcpReadHandle {
115 handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
116}
117
118impl std::fmt::Debug for TcpReadHandle {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 write!(f, "TcpReadHandle")
121 }
122}
123
124impl Drop for TcpReadHandle {
125 fn drop(&mut self) {
126 self.handle.abort();
128 }
129}
130
131#[derive(Clone, Debug)]
133struct ResponseInformation {
134 responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
136}
137
138impl ResponseInformation {
139 pub async fn add(&self, id: Uuid, response: WebSocketResponse) {
140 self.responses.write().await.insert(id, response);
141 }
142}
143
144struct ToEngineReq {
146 req: WebSocketRequest,
148 request_sent: oneshot::Sender<Result<()>>,
152}
153
154impl EngineConnection {
155 async fn start_write_actor(
157 mut tcp_write: WebSocketTcpWrite,
158 mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
159 mut shutdown_rx: mpsc::Receiver<()>,
160 ) {
161 loop {
162 tokio::select! {
163 maybe_req = engine_req_rx.recv() => {
164 match maybe_req {
165 Some(ToEngineReq { req, request_sent }) => {
166 let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
169 cmd: ModelingCmd::ImportFiles { .. },
170 cmd_id: _,
171 }) = &req
172 {
173 Self::inner_send_to_engine_binary(req, &mut tcp_write).await
174 } else {
175 Self::inner_send_to_engine(req, &mut tcp_write).await
176 };
177
178 let _ = request_sent.send(res);
180 }
181 None => {
182 break;
185 }
186 }
187 },
188
189 _ = shutdown_rx.recv() => {
191 let _ = Self::inner_close_engine(&mut tcp_write).await;
192 return;
193 }
194 }
195 }
196
197 let _ = Self::inner_close_engine(&mut tcp_write).await;
200 }
201
202 async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
204 tcp_write
205 .send(WsMsg::Close(None))
206 .await
207 .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
208 Ok(())
209 }
210
211 async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
213 let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
214 tcp_write
215 .send(WsMsg::Text(msg.into()))
216 .await
217 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
218 Ok(())
219 }
220
221 async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
223 let msg = rmp_serde::to_vec_named(&request).map_err(|e| anyhow!("could not serialize msgpack: {e}"))?;
224 tcp_write
225 .send(WsMsg::Binary(msg.into()))
226 .await
227 .map_err(|e| anyhow!("could not send MsgPack over websocket: {e}"))?;
228 Ok(())
229 }
230
231 pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
232 let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
233 .max_message_size(Some(usize::MAX))
235 .max_frame_size(Some(usize::MAX));
236
237 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
238 ws,
239 tokio_tungstenite::tungstenite::protocol::Role::Client,
240 Some(wsconfig),
241 )
242 .await;
243
244 let (tcp_write, tcp_read) = ws_stream.split();
245 let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
246 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
247 tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));
248
249 let mut tcp_read = TcpRead { stream: tcp_read };
250
251 let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
252 let session_data2 = session_data.clone();
253 let ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>> = Arc::new(RwLock::new(IndexMap::new()));
254 let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
255 let pending_errors = Arc::new(RwLock::new(Vec::new()));
256 let pending_errors_clone = pending_errors.clone();
257 let response_information = ResponseInformation {
258 responses: Arc::new(RwLock::new(IndexMap::new())),
259 };
260 let response_information_cloned = response_information.clone();
261 let debug_info = Arc::new(RwLock::new(None));
262 let debug_info_cloned = debug_info.clone();
263
264 let socket_health_tcp_read = socket_health.clone();
265 let tcp_read_handle = tokio::spawn(async move {
266 loop {
268 match tcp_read.read().await {
269 Ok(ws_resp) => {
270 let id = ws_resp.request_id();
272 match &ws_resp {
273 WebSocketResponse::Success(SuccessWebSocketResponse {
274 resp: OkWebSocketResponseData::ModelingBatch { responses },
275 ..
276 }) => {
277 #[expect(
278 clippy::iter_over_hash_type,
279 reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
280 )]
281 for (resp_id, batch_response) in responses {
282 let id: uuid::Uuid = (*resp_id).into();
283 match batch_response {
284 BatchResponse::Success { response } => {
285 response_information_cloned
288 .add(
289 id,
290 WebSocketResponse::Success(SuccessWebSocketResponse {
291 success: true,
292 request_id: Some(id),
293 resp: OkWebSocketResponseData::Modeling {
294 modeling_response: response.clone(),
295 },
296 }),
297 )
298 .await;
299 }
300 BatchResponse::Failure { errors } => {
301 response_information_cloned
302 .add(
303 id,
304 WebSocketResponse::Failure(FailureWebSocketResponse {
305 success: false,
306 request_id: Some(id),
307 errors: errors.clone(),
308 }),
309 )
310 .await;
311 }
312 }
313 }
314 }
315 WebSocketResponse::Success(SuccessWebSocketResponse {
316 resp: OkWebSocketResponseData::ModelingSessionData { session },
317 ..
318 }) => {
319 let mut sd = session_data2.write().await;
320 sd.replace(session.clone());
321 logln!("API Call ID: {}", session.api_call_id);
322 }
323 WebSocketResponse::Failure(FailureWebSocketResponse {
324 success: _,
325 request_id,
326 errors,
327 }) => {
328 if let Some(id) = request_id {
329 response_information_cloned
330 .add(
331 *id,
332 WebSocketResponse::Failure(FailureWebSocketResponse {
333 success: false,
334 request_id: *request_id,
335 errors: errors.clone(),
336 }),
337 )
338 .await;
339 } else {
340 let mut pe = pending_errors_clone.write().await;
342 for error in errors {
343 if !pe.contains(&error.message) {
344 pe.push(error.message.clone());
345 }
346 }
347 drop(pe);
348 }
349 }
350 WebSocketResponse::Success(SuccessWebSocketResponse {
351 resp: debug @ OkWebSocketResponseData::Debug { .. },
352 ..
353 }) => {
354 let mut handle = debug_info_cloned.write().await;
355 *handle = Some(debug.clone());
356 }
357 _ => {}
358 }
359
360 if let Some(id) = id {
361 response_information_cloned.add(id, ws_resp.clone()).await;
362 }
363 }
364 Err(e) => {
365 match &e {
366 WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
367 WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
368 }
369 *socket_health_tcp_read.write().await = SocketHealth::Inactive;
370 return Err(e);
371 }
372 }
373 }
374 });
375
376 Ok(EngineConnection {
377 engine_req_tx,
378 shutdown_tx,
379 tcp_read_handle: Arc::new(TcpReadHandle {
380 handle: Arc::new(tcp_read_handle),
381 }),
382 responses: response_information,
383 pending_errors,
384 socket_health,
385 ids_of_async_commands,
386 default_planes: Default::default(),
387 session_data,
388 stats: Default::default(),
389 async_tasks: AsyncTasks::new(),
390 debug_info,
391 })
392 }
393}
394
395#[async_trait::async_trait]
396impl EngineManager for EngineConnection {
397 fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
398 self.responses.responses.clone()
399 }
400
401 fn ids_of_async_commands(&self) -> Arc<RwLock<IndexMap<Uuid, SourceRange>>> {
402 self.ids_of_async_commands.clone()
403 }
404
405 fn async_tasks(&self) -> AsyncTasks {
406 self.async_tasks.clone()
407 }
408
409 fn stats(&self) -> &EngineStats {
410 &self.stats
411 }
412
413 fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
414 self.default_planes.clone()
415 }
416
417 async fn get_debug(&self) -> Option<OkWebSocketResponseData> {
418 self.debug_info.read().await.clone()
419 }
420
421 async fn fetch_debug(&self) -> Result<(), KclError> {
422 let (tx, rx) = oneshot::channel();
423
424 self.engine_req_tx
425 .send(ToEngineReq {
426 req: WebSocketRequest::Debug {},
427 request_sent: tx,
428 })
429 .await
430 .map_err(|e| KclError::new_engine(KclErrorDetails::new(format!("Failed to send debug: {e}"), vec![])))?;
431
432 let _ = rx.await;
433 Ok(())
434 }
435
436 async fn clear_scene_post_hook(
437 &self,
438 batch_context: &EngineBatchContext,
439 id_generator: &mut IdGenerator,
440 source_range: SourceRange,
441 ) -> Result<(), KclError> {
442 let new_planes = self
444 .new_default_planes(batch_context, id_generator, source_range)
445 .await?;
446 *self.default_planes.write().await = Some(new_planes);
447
448 Ok(())
449 }
450
451 async fn inner_fire_modeling_cmd(
452 &self,
453 _id: uuid::Uuid,
454 source_range: SourceRange,
455 cmd: WebSocketRequest,
456 _id_to_source_range: HashMap<Uuid, SourceRange>,
457 ) -> Result<(), KclError> {
458 let (tx, rx) = oneshot::channel();
459
460 self.engine_req_tx
462 .send(ToEngineReq {
463 req: cmd.clone(),
464 request_sent: tx,
465 })
466 .await
467 .map_err(|e| {
468 KclError::new_engine(KclErrorDetails::new(
469 format!("Failed to send modeling command: {e}"),
470 vec![source_range],
471 ))
472 })?;
473
474 rx.await
476 .map_err(|e| {
477 KclError::new_engine_hangup(
478 KclErrorDetails::new(
479 format!("could not send request to the engine actor: {e}"),
480 vec![source_range],
481 ),
482 None,
483 )
484 })?
485 .map_err(|e| {
486 KclError::new_engine_hangup(
487 KclErrorDetails::new(format!("could not send request to the engine: {e}"), vec![source_range]),
488 None,
489 )
490 })?;
491
492 Ok(())
493 }
494
495 async fn inner_send_modeling_cmd(
496 &self,
497 id: uuid::Uuid,
498 source_range: SourceRange,
499 cmd: WebSocketRequest,
500 id_to_source_range: HashMap<Uuid, SourceRange>,
501 ) -> Result<WebSocketResponse, KclError> {
502 self.inner_fire_modeling_cmd(id, source_range, cmd, id_to_source_range)
503 .await?;
504
505 let response_timeout = 300;
507 let current_time = std::time::Instant::now();
508 while current_time.elapsed().as_secs() < response_timeout {
509 let guard = self.socket_health.read().await;
510 if *guard == SocketHealth::Inactive {
511 let session_data = self.session_data.read().await;
513 let api_call_id = session_data.as_ref().map(|session| session.api_call_id.to_string());
514 let api_call_id_msg = if let Some(ref id) = api_call_id {
515 format!(" (API call ID: {})", id)
516 } else {
517 String::new()
518 };
519
520 let pe = self.pending_errors.read().await;
522 if !pe.is_empty() {
523 return Err(KclError::new_engine(KclErrorDetails::new(
524 format!("{}{}", pe.join(", "), api_call_id_msg),
525 vec![source_range],
526 )));
527 } else {
528 return Err(KclError::new_engine_hangup(
529 KclErrorDetails::new(
530 format!("Modeling command failed: websocket closed early{}", api_call_id_msg),
531 vec![source_range],
532 ),
533 api_call_id,
534 ));
535 }
536 }
537
538 #[cfg(feature = "artifact-graph")]
539 {
540 if let Some(resp) = self.responses.responses.read().await.get(&id) {
542 return Ok(resp.clone());
543 }
544 }
545 #[cfg(not(feature = "artifact-graph"))]
546 {
547 if let Some(resp) = self.responses.responses.write().await.shift_remove(&id) {
548 return Ok(resp);
549 }
550 }
551 }
552
553 let session_data = self.session_data.read().await;
555 let api_call_id_msg = if let Some(session) = session_data.as_ref() {
556 format!(" (API call ID: {})", session.api_call_id)
557 } else {
558 String::new()
559 };
560
561 Err(KclError::new_engine(KclErrorDetails::new(
562 format!("Modeling command timed out `{id}`{}", api_call_id_msg),
563 vec![source_range],
564 )))
565 }
566
567 async fn get_session_data(&self) -> Option<ModelingSessionData> {
568 self.session_data.read().await.clone()
569 }
570
571 async fn close(&self) {
572 let _ = self.shutdown_tx.send(()).await;
573 loop {
574 let guard = self.socket_health.read().await;
575 if *guard == SocketHealth::Inactive {
576 return;
577 }
578 }
579 }
580}