1use std::collections::HashMap;
5use std::sync::Arc;
6
7use anyhow::Result;
8use anyhow::anyhow;
9use futures::SinkExt;
10use futures::StreamExt;
11use indexmap::IndexMap;
12use kcmc::ModelingCmd;
13use kcmc::websocket::BatchResponse;
14use kcmc::websocket::FailureWebSocketResponse;
15use kcmc::websocket::ModelingCmdReq;
16use kcmc::websocket::ModelingSessionData;
17use kcmc::websocket::OkWebSocketResponseData;
18use kcmc::websocket::SuccessWebSocketResponse;
19use kcmc::websocket::WebSocketRequest;
20use kcmc::websocket::WebSocketResponse;
21use kittycad_modeling_cmds::{self as kcmc};
22use tokio::sync::RwLock;
23use tokio::sync::mpsc;
24use tokio::sync::oneshot;
25use tokio_tungstenite::tungstenite::Message as WsMsg;
26use uuid::Uuid;
27
28use crate::SourceRange;
29use crate::engine::AsyncTasks;
30use crate::engine::EngineManager;
31use crate::engine::EngineStats;
32use crate::errors::KclError;
33use crate::errors::KclErrorDetails;
34use crate::execution::DefaultPlanes;
35use crate::execution::IdGenerator;
36use crate::log::logln;
37
/// Liveness of the engine WebSocket as observed by the read task.
///
/// The connection starts `Active`; the read task flips it to `Inactive` once a
/// read from the socket fails, and it never goes back.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SocketHealth {
    /// The read task is still receiving messages.
    Active,
    /// Reading failed; no further responses will arrive.
    Inactive,
}
43
44type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
45#[derive(Debug)]
46pub struct EngineConnection {
47 engine_req_tx: mpsc::Sender<ToEngineReq>,
48 shutdown_tx: mpsc::Sender<()>,
49 responses: ResponseInformation,
50 pending_errors: Arc<RwLock<Vec<String>>>,
51 #[allow(dead_code)]
52 tcp_read_handle: Arc<TcpReadHandle>,
53 socket_health: Arc<RwLock<SocketHealth>>,
54 batch: Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>>,
55 batch_end: Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>>,
56 ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>>,
57
58 default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
60 session_data: Arc<RwLock<Option<ModelingSessionData>>>,
62
63 stats: EngineStats,
64
65 async_tasks: AsyncTasks,
66
67 debug_info: Arc<RwLock<Option<OkWebSocketResponseData>>>,
68}
69
70pub struct TcpRead {
71 stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
72}
73
74#[allow(clippy::large_enum_variant)]
77pub enum WebSocketReadError {
78 Read(tokio_tungstenite::tungstenite::Error),
80 Deser(anyhow::Error),
82}
83
84impl From<anyhow::Error> for WebSocketReadError {
85 fn from(e: anyhow::Error) -> Self {
86 Self::Deser(e)
87 }
88}
89
90impl TcpRead {
91 pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
92 let Some(msg) = self.stream.next().await else {
93 return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
94 };
95 let msg = match msg {
96 Ok(msg) => msg,
97 Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
98 return Err(WebSocketReadError::Read(e));
99 }
100 Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
101 };
102 let msg: WebSocketResponse = match msg {
103 WsMsg::Text(text) => serde_json::from_str(&text)
104 .map_err(anyhow::Error::from)
105 .map_err(WebSocketReadError::from)?,
106 WsMsg::Binary(bin) => rmp_serde::from_slice(&bin)
107 .map_err(anyhow::Error::from)
108 .map_err(WebSocketReadError::from)?,
109 other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
110 };
111 Ok(msg)
112 }
113}
114
115pub struct TcpReadHandle {
116 handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
117}
118
119impl std::fmt::Debug for TcpReadHandle {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 write!(f, "TcpReadHandle")
122 }
123}
124
125impl Drop for TcpReadHandle {
126 fn drop(&mut self) {
127 self.handle.abort();
129 }
130}
131
132#[derive(Clone, Debug)]
134struct ResponseInformation {
135 responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
137}
138
139impl ResponseInformation {
140 pub async fn add(&self, id: Uuid, response: WebSocketResponse) {
141 self.responses.write().await.insert(id, response);
142 }
143}
144
145struct ToEngineReq {
147 req: WebSocketRequest,
149 request_sent: oneshot::Sender<Result<()>>,
153}
154
155impl EngineConnection {
156 async fn start_write_actor(
158 mut tcp_write: WebSocketTcpWrite,
159 mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
160 mut shutdown_rx: mpsc::Receiver<()>,
161 ) {
162 loop {
163 tokio::select! {
164 maybe_req = engine_req_rx.recv() => {
165 match maybe_req {
166 Some(ToEngineReq { req, request_sent }) => {
167 let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
170 cmd: ModelingCmd::ImportFiles { .. },
171 cmd_id: _,
172 }) = &req
173 {
174 Self::inner_send_to_engine_binary(req, &mut tcp_write).await
175 } else {
176 Self::inner_send_to_engine(req, &mut tcp_write).await
177 };
178
179 let _ = request_sent.send(res);
181 }
182 None => {
183 break;
186 }
187 }
188 },
189
190 _ = shutdown_rx.recv() => {
192 let _ = Self::inner_close_engine(&mut tcp_write).await;
193 return;
194 }
195 }
196 }
197
198 let _ = Self::inner_close_engine(&mut tcp_write).await;
201 }
202
203 async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
205 tcp_write
206 .send(WsMsg::Close(None))
207 .await
208 .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
209 Ok(())
210 }
211
212 async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
214 let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
215 tcp_write
216 .send(WsMsg::Text(msg.into()))
217 .await
218 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
219 Ok(())
220 }
221
222 async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
224 let msg = rmp_serde::to_vec_named(&request).map_err(|e| anyhow!("could not serialize msgpack: {e}"))?;
225 tcp_write
226 .send(WsMsg::Binary(msg.into()))
227 .await
228 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
229 Ok(())
230 }
231
232 pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
233 let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
234 .max_message_size(Some(usize::MAX))
236 .max_frame_size(Some(usize::MAX));
237
238 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
239 ws,
240 tokio_tungstenite::tungstenite::protocol::Role::Client,
241 Some(wsconfig),
242 )
243 .await;
244
245 let (tcp_write, tcp_read) = ws_stream.split();
246 let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
247 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
248 tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));
249
250 let mut tcp_read = TcpRead { stream: tcp_read };
251
252 let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
253 let session_data2 = session_data.clone();
254 let ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>> = Arc::new(RwLock::new(IndexMap::new()));
255 let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
256 let pending_errors = Arc::new(RwLock::new(Vec::new()));
257 let pending_errors_clone = pending_errors.clone();
258 let response_information = ResponseInformation {
259 responses: Arc::new(RwLock::new(IndexMap::new())),
260 };
261 let response_information_cloned = response_information.clone();
262 let debug_info = Arc::new(RwLock::new(None));
263 let debug_info_cloned = debug_info.clone();
264
265 let socket_health_tcp_read = socket_health.clone();
266 let tcp_read_handle = tokio::spawn(async move {
267 loop {
269 match tcp_read.read().await {
270 Ok(ws_resp) => {
271 let id = ws_resp.request_id();
273 match &ws_resp {
274 WebSocketResponse::Success(SuccessWebSocketResponse {
275 resp: OkWebSocketResponseData::ModelingBatch { responses },
276 ..
277 }) => {
278 #[expect(
279 clippy::iter_over_hash_type,
280 reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
281 )]
282 for (resp_id, batch_response) in responses {
283 let id: uuid::Uuid = (*resp_id).into();
284 match batch_response {
285 BatchResponse::Success { response } => {
286 response_information_cloned
289 .add(
290 id,
291 WebSocketResponse::Success(SuccessWebSocketResponse {
292 success: true,
293 request_id: Some(id),
294 resp: OkWebSocketResponseData::Modeling {
295 modeling_response: response.clone(),
296 },
297 }),
298 )
299 .await;
300 }
301 BatchResponse::Failure { errors } => {
302 response_information_cloned
303 .add(
304 id,
305 WebSocketResponse::Failure(FailureWebSocketResponse {
306 success: false,
307 request_id: Some(id),
308 errors: errors.clone(),
309 }),
310 )
311 .await;
312 }
313 }
314 }
315 }
316 WebSocketResponse::Success(SuccessWebSocketResponse {
317 resp: OkWebSocketResponseData::ModelingSessionData { session },
318 ..
319 }) => {
320 let mut sd = session_data2.write().await;
321 sd.replace(session.clone());
322 logln!("API Call ID: {}", session.api_call_id);
323 }
324 WebSocketResponse::Failure(FailureWebSocketResponse {
325 success: _,
326 request_id,
327 errors,
328 }) => {
329 if let Some(id) = request_id {
330 response_information_cloned
331 .add(
332 *id,
333 WebSocketResponse::Failure(FailureWebSocketResponse {
334 success: false,
335 request_id: *request_id,
336 errors: errors.clone(),
337 }),
338 )
339 .await;
340 } else {
341 let mut pe = pending_errors_clone.write().await;
343 for error in errors {
344 if !pe.contains(&error.message) {
345 pe.push(error.message.clone());
346 }
347 }
348 drop(pe);
349 }
350 }
351 WebSocketResponse::Success(SuccessWebSocketResponse {
352 resp: debug @ OkWebSocketResponseData::Debug { .. },
353 ..
354 }) => {
355 let mut handle = debug_info_cloned.write().await;
356 *handle = Some(debug.clone());
357 }
358 _ => {}
359 }
360
361 if let Some(id) = id {
362 response_information_cloned.add(id, ws_resp.clone()).await;
363 }
364 }
365 Err(e) => {
366 match &e {
367 WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
368 WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
369 }
370 *socket_health_tcp_read.write().await = SocketHealth::Inactive;
371 return Err(e);
372 }
373 }
374 }
375 });
376
377 Ok(EngineConnection {
378 engine_req_tx,
379 shutdown_tx,
380 tcp_read_handle: Arc::new(TcpReadHandle {
381 handle: Arc::new(tcp_read_handle),
382 }),
383 responses: response_information,
384 pending_errors,
385 socket_health,
386 batch: Arc::new(RwLock::new(Vec::new())),
387 batch_end: Arc::new(RwLock::new(IndexMap::new())),
388 ids_of_async_commands,
389 default_planes: Default::default(),
390 session_data,
391 stats: Default::default(),
392 async_tasks: AsyncTasks::new(),
393 debug_info,
394 })
395 }
396}
397
398#[async_trait::async_trait]
399impl EngineManager for EngineConnection {
400 fn batch(&self) -> Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>> {
401 self.batch.clone()
402 }
403
404 fn batch_end(&self) -> Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>> {
405 self.batch_end.clone()
406 }
407
408 fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
409 self.responses.responses.clone()
410 }
411
412 fn ids_of_async_commands(&self) -> Arc<RwLock<IndexMap<Uuid, SourceRange>>> {
413 self.ids_of_async_commands.clone()
414 }
415
416 fn async_tasks(&self) -> AsyncTasks {
417 self.async_tasks.clone()
418 }
419
420 fn stats(&self) -> &EngineStats {
421 &self.stats
422 }
423
424 fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
425 self.default_planes.clone()
426 }
427
428 async fn get_debug(&self) -> Option<OkWebSocketResponseData> {
429 self.debug_info.read().await.clone()
430 }
431
432 async fn fetch_debug(&self) -> Result<(), KclError> {
433 let (tx, rx) = oneshot::channel();
434
435 self.engine_req_tx
436 .send(ToEngineReq {
437 req: WebSocketRequest::Debug {},
438 request_sent: tx,
439 })
440 .await
441 .map_err(|e| KclError::new_engine(KclErrorDetails::new(format!("Failed to send debug: {e}"), vec![])))?;
442
443 let _ = rx.await;
444 Ok(())
445 }
446
447 async fn clear_scene_post_hook(
448 &self,
449 id_generator: &mut IdGenerator,
450 source_range: SourceRange,
451 ) -> Result<(), KclError> {
452 let new_planes = self.new_default_planes(id_generator, source_range).await?;
454 *self.default_planes.write().await = Some(new_planes);
455
456 Ok(())
457 }
458
459 async fn inner_fire_modeling_cmd(
460 &self,
461 _id: uuid::Uuid,
462 source_range: SourceRange,
463 cmd: WebSocketRequest,
464 _id_to_source_range: HashMap<Uuid, SourceRange>,
465 ) -> Result<(), KclError> {
466 let (tx, rx) = oneshot::channel();
467
468 self.engine_req_tx
470 .send(ToEngineReq {
471 req: cmd.clone(),
472 request_sent: tx,
473 })
474 .await
475 .map_err(|e| {
476 KclError::new_engine(KclErrorDetails::new(
477 format!("Failed to send modeling command: {e}"),
478 vec![source_range],
479 ))
480 })?;
481
482 rx.await
484 .map_err(|e| {
485 KclError::new_engine(KclErrorDetails::new(
486 format!("could not send request to the engine actor: {e}"),
487 vec![source_range],
488 ))
489 })?
490 .map_err(|e| {
491 KclError::new_engine(KclErrorDetails::new(
492 format!("could not send request to the engine: {e}"),
493 vec![source_range],
494 ))
495 })?;
496
497 Ok(())
498 }
499
500 async fn inner_send_modeling_cmd(
501 &self,
502 id: uuid::Uuid,
503 source_range: SourceRange,
504 cmd: WebSocketRequest,
505 id_to_source_range: HashMap<Uuid, SourceRange>,
506 ) -> Result<WebSocketResponse, KclError> {
507 self.inner_fire_modeling_cmd(id, source_range, cmd, id_to_source_range)
508 .await?;
509
510 let response_timeout = 300;
512 let current_time = std::time::Instant::now();
513 while current_time.elapsed().as_secs() < response_timeout {
514 let guard = self.socket_health.read().await;
515 if *guard == SocketHealth::Inactive {
516 let pe = self.pending_errors.read().await;
518 if !pe.is_empty() {
519 return Err(KclError::new_engine(KclErrorDetails::new(
520 pe.join(", "),
521 vec![source_range],
522 )));
523 } else {
524 return Err(KclError::new_engine_hangup(KclErrorDetails::new(
525 "Modeling command failed: websocket closed early".to_string(),
526 vec![source_range],
527 )));
528 }
529 }
530
531 #[cfg(feature = "artifact-graph")]
532 {
533 if let Some(resp) = self.responses.responses.read().await.get(&id) {
535 return Ok(resp.clone());
536 }
537 }
538 #[cfg(not(feature = "artifact-graph"))]
539 {
540 if let Some(resp) = self.responses.responses.write().await.shift_remove(&id) {
541 return Ok(resp);
542 }
543 }
544 }
545
546 Err(KclError::new_engine(KclErrorDetails::new(
547 format!("Modeling command timed out `{id}`"),
548 vec![source_range],
549 )))
550 }
551
552 async fn get_session_data(&self) -> Option<ModelingSessionData> {
553 self.session_data.read().await.clone()
554 }
555
556 async fn close(&self) {
557 let _ = self.shutdown_tx.send(()).await;
558 loop {
559 let guard = self.socket_health.read().await;
560 if *guard == SocketHealth::Inactive {
561 return;
562 }
563 }
564 }
565}