1use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::Result;
9use anyhow::anyhow;
10use futures::SinkExt;
11use futures::StreamExt;
12use indexmap::IndexMap;
13use kcmc::ModelingCmd;
14use kcmc::websocket::BatchResponse;
15use kcmc::websocket::FailureWebSocketResponse;
16use kcmc::websocket::ModelingCmdReq;
17use kcmc::websocket::ModelingSessionData;
18use kcmc::websocket::OkWebSocketResponseData;
19use kcmc::websocket::SuccessWebSocketResponse;
20use kcmc::websocket::WebSocketRequest;
21use kcmc::websocket::WebSocketResponse;
22use kittycad_modeling_cmds::{self as kcmc};
23use tokio::sync::RwLock;
24use tokio::sync::mpsc;
25use tokio::sync::oneshot;
26use tokio_tungstenite::tungstenite::Message as WsMsg;
27use uuid::Uuid;
28
29use crate::SourceRange;
30use crate::engine::AsyncTasks;
31use crate::engine::EngineBatchContext;
32use crate::engine::EngineManager;
33use crate::engine::EngineStats;
34use crate::errors::KclError;
35use crate::errors::KclErrorDetails;
36use crate::execution::DefaultPlanes;
37use crate::execution::IdGenerator;
38use crate::log::logln;
39
/// Whether the background reader still considers the engine WebSocket usable.
///
/// Starts as `Active`; the reader task flips it to `Inactive` when a read fails, and
/// nothing in this file sets it back.
#[derive(PartialEq, Debug)]
enum SocketHealth {
    /// Reads are still succeeding.
    Active,
    /// A read failed and the reader task has stopped.
    Inactive,
}
45
46type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
47#[derive(Debug)]
48pub struct EngineConnection {
49 engine_req_tx: mpsc::Sender<ToEngineReq>,
50 shutdown_tx: mpsc::Sender<()>,
51 responses: ResponseInformation,
52 pending_errors: Arc<RwLock<Vec<String>>>,
53 #[allow(dead_code)]
54 tcp_read_handle: Arc<TcpReadHandle>,
55 socket_health: Arc<RwLock<SocketHealth>>,
56 ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>>,
57
58 default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
60 session_data: Arc<RwLock<Option<ModelingSessionData>>>,
62
63 stats: EngineStats,
64
65 async_tasks: AsyncTasks,
66
67 debug_info: Arc<RwLock<Option<OkWebSocketResponseData>>>,
68}
69
70pub struct TcpRead {
71 stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
72}
73
74#[allow(clippy::large_enum_variant)]
77pub enum WebSocketReadError {
78 Read(tokio_tungstenite::tungstenite::Error),
80 Deser(anyhow::Error),
82}
83
84impl From<anyhow::Error> for WebSocketReadError {
85 fn from(e: anyhow::Error) -> Self {
86 Self::Deser(e)
87 }
88}
89
90impl TcpRead {
91 pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
92 let Some(msg) = self.stream.next().await else {
93 return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
94 };
95 let msg = match msg {
96 Ok(msg) => msg,
97 Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
98 return Err(WebSocketReadError::Read(e));
99 }
100 Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
101 };
102 let msg: WebSocketResponse = match msg {
103 WsMsg::Text(text) => serde_json::from_str(&text)
104 .map_err(anyhow::Error::from)
105 .map_err(WebSocketReadError::from)?,
106 WsMsg::Binary(bin) => rmp_serde::from_slice(&bin)
107 .map_err(anyhow::Error::from)
108 .map_err(WebSocketReadError::from)?,
109 other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
110 };
111 Ok(msg)
112 }
113}
114
115pub struct TcpReadHandle {
116 handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
117}
118
119impl std::fmt::Debug for TcpReadHandle {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 write!(f, "TcpReadHandle")
122 }
123}
124
125impl Drop for TcpReadHandle {
126 fn drop(&mut self) {
127 self.handle.abort();
129 }
130}
131
132#[derive(Clone, Debug)]
134struct ResponseInformation {
135 responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
137}
138
139impl ResponseInformation {
140 pub async fn add(&self, id: Uuid, response: WebSocketResponse) {
141 self.responses.write().await.insert(id, response);
142 }
143}
144
145struct ToEngineReq {
147 req: WebSocketRequest,
149 request_sent: oneshot::Sender<Result<()>>,
153}
154
155impl EngineConnection {
156 async fn start_write_actor(
160 mut tcp_write: WebSocketTcpWrite,
161 mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
162 mut shutdown_rx: mpsc::Receiver<()>,
163 heartbeats: Option<u64>,
164 ) {
165 let heartbeats = heartbeats.unwrap_or_default();
166 let send_heartbeats = heartbeats != 0;
167 let period_seconds = if heartbeats == 0 { 5 * 60 } else { heartbeats };
168 let period = Duration::from_secs(period_seconds);
169 let mut heartbeats_stream = tokio::time::interval(period);
170
171 loop {
172 tokio::select! {
173 maybe_req = engine_req_rx.recv() => {
174 match maybe_req {
175 Some(ToEngineReq { req, request_sent }) => {
176 let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
179 cmd: ModelingCmd::ImportFiles { .. },
180 cmd_id: _,
181 }) = &req
182 {
183 Self::inner_send_to_engine_binary(req, &mut tcp_write).await
184 } else {
185 Self::inner_send_to_engine(req, &mut tcp_write).await
186 };
187
188 let _ = request_sent.send(res);
190 }
191 None => {
192 break;
195 }
196 }
197 },
198
199 _ = shutdown_rx.recv() => {
201 let _ = Self::inner_close_engine(&mut tcp_write).await;
202 return;
203 }
204
205 _ = heartbeats_stream.tick(), if send_heartbeats => {
207 let res = Self::inner_send_to_engine(WebSocketRequest::Ping {}, &mut tcp_write).await;
209 let _ = res;
211 }
212 }
213 }
214
215 let _ = Self::inner_close_engine(&mut tcp_write).await;
218 }
219
220 async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
222 tcp_write
223 .send(WsMsg::Close(None))
224 .await
225 .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
226 Ok(())
227 }
228
229 async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
231 let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
232 tcp_write
233 .send(WsMsg::Text(msg.into()))
234 .await
235 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
236 Ok(())
237 }
238
239 async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
241 let msg = rmp_serde::to_vec_named(&request).map_err(|e| anyhow!("could not serialize msgpack: {e}"))?;
242 tcp_write
243 .send(WsMsg::Binary(msg.into()))
244 .await
245 .map_err(|e| anyhow!("could not send MsgPack over websocket: {e}"))?;
246 Ok(())
247 }
248
249 pub async fn new(ws: reqwest::Upgraded, heartbeats: Option<u64>) -> Result<EngineConnection> {
250 let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
251 .max_message_size(Some(usize::MAX))
253 .max_frame_size(Some(usize::MAX));
254
255 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
256 ws,
257 tokio_tungstenite::tungstenite::protocol::Role::Client,
258 Some(wsconfig),
259 )
260 .await;
261
262 let (tcp_write, tcp_read) = ws_stream.split();
263 let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
264 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
265 tokio::task::spawn(Self::start_write_actor(
266 tcp_write,
267 engine_req_rx,
268 shutdown_rx,
269 heartbeats,
270 ));
271
272 let mut tcp_read = TcpRead { stream: tcp_read };
273
274 let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
275 let session_data2 = session_data.clone();
276 let ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>> = Arc::new(RwLock::new(IndexMap::new()));
277 let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
278 let pending_errors = Arc::new(RwLock::new(Vec::new()));
279 let pending_errors_clone = pending_errors.clone();
280 let response_information = ResponseInformation {
281 responses: Arc::new(RwLock::new(IndexMap::new())),
282 };
283 let response_information_cloned = response_information.clone();
284 let debug_info = Arc::new(RwLock::new(None));
285 let debug_info_cloned = debug_info.clone();
286
287 let socket_health_tcp_read = socket_health.clone();
288 let tcp_read_handle = tokio::spawn(async move {
289 loop {
291 match tcp_read.read().await {
292 Ok(ws_resp) => {
293 let id = ws_resp.request_id();
295 match &ws_resp {
296 WebSocketResponse::Success(SuccessWebSocketResponse {
297 resp: OkWebSocketResponseData::ModelingBatch { responses },
298 ..
299 }) => {
300 #[expect(
301 clippy::iter_over_hash_type,
302 reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
303 )]
304 for (resp_id, batch_response) in responses {
305 let id: uuid::Uuid = (*resp_id).into();
306 match batch_response {
307 BatchResponse::Success { response } => {
308 response_information_cloned
311 .add(
312 id,
313 WebSocketResponse::Success(SuccessWebSocketResponse {
314 success: true,
315 request_id: Some(id),
316 resp: OkWebSocketResponseData::Modeling {
317 modeling_response: response.clone(),
318 },
319 }),
320 )
321 .await;
322 }
323 BatchResponse::Failure { errors } => {
324 response_information_cloned
325 .add(
326 id,
327 WebSocketResponse::Failure(FailureWebSocketResponse {
328 success: false,
329 request_id: Some(id),
330 errors: errors.clone(),
331 }),
332 )
333 .await;
334 }
335 }
336 }
337 }
338 WebSocketResponse::Success(SuccessWebSocketResponse {
339 resp: OkWebSocketResponseData::ModelingSessionData { session },
340 ..
341 }) => {
342 let mut sd = session_data2.write().await;
343 sd.replace(session.clone());
344 logln!("API Call ID: {}", session.api_call_id);
345 }
346 WebSocketResponse::Failure(FailureWebSocketResponse {
347 success: _,
348 request_id,
349 errors,
350 }) => {
351 if let Some(id) = request_id {
352 response_information_cloned
353 .add(
354 *id,
355 WebSocketResponse::Failure(FailureWebSocketResponse {
356 success: false,
357 request_id: *request_id,
358 errors: errors.clone(),
359 }),
360 )
361 .await;
362 } else {
363 let mut pe = pending_errors_clone.write().await;
365 for error in errors {
366 if !pe.contains(&error.message) {
367 pe.push(error.message.clone());
368 }
369 }
370 drop(pe);
371 }
372 }
373 WebSocketResponse::Success(SuccessWebSocketResponse {
374 resp: debug @ OkWebSocketResponseData::Debug { .. },
375 ..
376 }) => {
377 let mut handle = debug_info_cloned.write().await;
378 *handle = Some(debug.clone());
379 }
380 _ => {}
381 }
382
383 if let Some(id) = id {
384 response_information_cloned.add(id, ws_resp.clone()).await;
385 }
386 }
387 Err(e) => {
388 match &e {
389 WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
390 WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
391 }
392 *socket_health_tcp_read.write().await = SocketHealth::Inactive;
393 return Err(e);
394 }
395 }
396 }
397 });
398
399 Ok(EngineConnection {
400 engine_req_tx,
401 shutdown_tx,
402 tcp_read_handle: Arc::new(TcpReadHandle {
403 handle: Arc::new(tcp_read_handle),
404 }),
405 responses: response_information,
406 pending_errors,
407 socket_health,
408 ids_of_async_commands,
409 default_planes: Default::default(),
410 session_data,
411 stats: Default::default(),
412 async_tasks: AsyncTasks::new(),
413 debug_info,
414 })
415 }
416}
417
418#[async_trait::async_trait]
419impl EngineManager for EngineConnection {
420 fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
421 self.responses.responses.clone()
422 }
423
424 fn ids_of_async_commands(&self) -> Arc<RwLock<IndexMap<Uuid, SourceRange>>> {
425 self.ids_of_async_commands.clone()
426 }
427
428 fn async_tasks(&self) -> AsyncTasks {
429 self.async_tasks.clone()
430 }
431
432 fn stats(&self) -> &EngineStats {
433 &self.stats
434 }
435
436 fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
437 self.default_planes.clone()
438 }
439
440 async fn get_debug(&self) -> Option<OkWebSocketResponseData> {
441 self.debug_info.read().await.clone()
442 }
443
444 async fn fetch_debug(&self) -> Result<(), KclError> {
445 let (tx, rx) = oneshot::channel();
446
447 self.engine_req_tx
448 .send(ToEngineReq {
449 req: WebSocketRequest::Debug {},
450 request_sent: tx,
451 })
452 .await
453 .map_err(|e| KclError::new_engine(KclErrorDetails::new(format!("Failed to send debug: {e}"), vec![])))?;
454
455 let _ = rx.await;
456 Ok(())
457 }
458
459 async fn clear_scene_post_hook(
460 &self,
461 batch_context: &EngineBatchContext,
462 id_generator: &mut IdGenerator,
463 source_range: SourceRange,
464 ) -> Result<(), KclError> {
465 let new_planes = self
467 .new_default_planes(batch_context, id_generator, source_range)
468 .await?;
469 *self.default_planes.write().await = Some(new_planes);
470
471 Ok(())
472 }
473
474 async fn inner_fire_modeling_cmd(
475 &self,
476 _id: uuid::Uuid,
477 source_range: SourceRange,
478 cmd: WebSocketRequest,
479 _id_to_source_range: HashMap<Uuid, SourceRange>,
480 ) -> Result<(), KclError> {
481 let (tx, rx) = oneshot::channel();
482
483 self.engine_req_tx
485 .send(ToEngineReq {
486 req: cmd.clone(),
487 request_sent: tx,
488 })
489 .await
490 .map_err(|e| {
491 KclError::new_engine(KclErrorDetails::new(
492 format!("Failed to send modeling command: {e}"),
493 vec![source_range],
494 ))
495 })?;
496
497 rx.await
499 .map_err(|e| {
500 KclError::new_engine_hangup(
501 KclErrorDetails::new(
502 format!("could not send request to the engine actor: {e}"),
503 vec![source_range],
504 ),
505 None,
506 )
507 })?
508 .map_err(|e| {
509 KclError::new_engine_hangup(
510 KclErrorDetails::new(format!("could not send request to the engine: {e}"), vec![source_range]),
511 None,
512 )
513 })?;
514
515 Ok(())
516 }
517
518 async fn inner_send_modeling_cmd(
519 &self,
520 id: uuid::Uuid,
521 source_range: SourceRange,
522 cmd: WebSocketRequest,
523 id_to_source_range: HashMap<Uuid, SourceRange>,
524 ) -> Result<WebSocketResponse, KclError> {
525 self.inner_fire_modeling_cmd(id, source_range, cmd, id_to_source_range)
526 .await?;
527
528 let response_timeout = 600;
530 let current_time = std::time::Instant::now();
531 while current_time.elapsed().as_secs() < response_timeout {
532 let guard = self.socket_health.read().await;
533 if *guard == SocketHealth::Inactive {
534 let session_data = self.session_data.read().await;
536 let api_call_id = session_data.as_ref().map(|session| session.api_call_id.to_string());
537 let api_call_id_msg = if let Some(ref id) = api_call_id {
538 format!(" (API call ID: {})", id)
539 } else {
540 String::new()
541 };
542
543 let pe = self.pending_errors.read().await;
545 if !pe.is_empty() {
546 return Err(KclError::new_engine(KclErrorDetails::new(
547 format!("{}{}", pe.join(", "), api_call_id_msg),
548 vec![source_range],
549 )));
550 } else {
551 return Err(KclError::new_engine_hangup(
552 KclErrorDetails::new(
553 format!("Modeling command failed: websocket closed early{}", api_call_id_msg),
554 vec![source_range],
555 ),
556 api_call_id,
557 ));
558 }
559 }
560
561 if let Some(resp) = self.responses.responses.read().await.get(&id) {
563 return Ok(resp.clone());
564 }
565 }
566
567 let session_data = self.session_data.read().await;
569 let api_call_id_msg = if let Some(session) = session_data.as_ref() {
570 format!(" (API call ID: {})", session.api_call_id)
571 } else {
572 String::new()
573 };
574
575 Err(KclError::new_engine(KclErrorDetails::new(
576 format!("Modeling command timed out `{id}`{}", api_call_id_msg),
577 vec![source_range],
578 )))
579 }
580
581 async fn get_session_data(&self) -> Option<ModelingSessionData> {
582 self.session_data.read().await.clone()
583 }
584
585 async fn close(&self) {
586 let _ = self.shutdown_tx.send(()).await;
587 loop {
588 let guard = self.socket_health.read().await;
589 if *guard == SocketHealth::Inactive {
590 return;
591 }
592 }
593 }
594}