1use std::{collections::HashMap, sync::Arc};
5
6use anyhow::{anyhow, Result};
7use futures::{SinkExt, StreamExt};
8use indexmap::IndexMap;
9use kcmc::{
10 websocket::{
11 BatchResponse, FailureWebSocketResponse, ModelingCmdReq, ModelingSessionData, OkWebSocketResponseData,
12 SuccessWebSocketResponse, WebSocketRequest, WebSocketResponse,
13 },
14 ModelingCmd,
15};
16use kittycad_modeling_cmds::{self as kcmc};
17use tokio::sync::{mpsc, oneshot, RwLock};
18use tokio_tungstenite::tungstenite::Message as WsMsg;
19use uuid::Uuid;
20
21use super::{EngineStats, ExecutionKind};
22use crate::{
23 engine::EngineManager,
24 errors::{KclError, KclErrorDetails},
25 execution::{ArtifactCommand, DefaultPlanes, IdGenerator},
26 SourceRange,
27};
28
/// Liveness of the engine WebSocket as seen by the background read task.
///
/// A connection starts `Active`. The read task switches it to `Inactive` when
/// a read from the socket fails, and then stops reading.
#[derive(PartialEq, Debug)]
enum SocketHealth {
    /// No read failure has been observed yet.
    Active,
    /// Reading from the socket failed; no further responses will be recorded.
    Inactive,
}
34
35type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
/// A connection to the modeling engine over a WebSocket.
///
/// Requests are handed to a background write task through `engine_req_tx`.
/// A background read task stores incoming responses in `responses`, keyed by
/// request id, where `inner_send_modeling_cmd` polls for them.
#[derive(Debug)]
pub struct EngineConnection {
    /// Queue feeding the background write task.
    engine_req_tx: mpsc::Sender<ToEngineReq>,
    /// Tells the write task to send a close frame and stop.
    shutdown_tx: mpsc::Sender<()>,
    /// Responses received from the engine, keyed by request id. Entries are
    /// removed when the waiting caller takes them.
    responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
    /// Deduplicated messages from failure responses that carried no request id.
    /// They are reported to callers once the socket goes inactive.
    pending_errors: Arc<RwLock<Vec<String>>>,
    /// Ties the background read task to this connection. It is never read,
    /// only held; dropping the handle aborts the task.
    #[allow(dead_code)]
    tcp_read_handle: Arc<TcpReadHandle>,
    /// Set to `Inactive` by the read task after its first read error.
    socket_health: Arc<RwLock<SocketHealth>>,
    /// Exposed through `EngineManager::batch`.
    /// NOTE(review): presumably commands queued for a batched send — confirm.
    batch: Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>>,
    /// Exposed through `EngineManager::batch_end`, keyed by `Uuid`.
    batch_end: Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>>,
    /// Exposed through `EngineManager::artifact_commands`.
    artifact_commands: Arc<RwLock<Vec<ArtifactCommand>>>,

    /// Replaced by `clear_scene_post_hook` with newly created default planes.
    default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
    /// Latest session data sent by the engine; `None` until one arrives.
    session_data: Arc<RwLock<Option<ModelingSessionData>>>,

    /// Read and swapped through `execution_kind` / `replace_execution_kind`.
    execution_kind: Arc<RwLock<ExecutionKind>>,
    /// Returned by reference from `stats`.
    stats: EngineStats,
}
57
/// Read half of the engine WebSocket.
pub struct TcpRead {
    /// Incoming message stream, decoded into `WebSocketResponse`s by `TcpRead::read`.
    stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
}
61
62pub enum WebSocketReadError {
65 Read(tokio_tungstenite::tungstenite::Error),
67 Deser(anyhow::Error),
69}
70
71impl From<anyhow::Error> for WebSocketReadError {
72 fn from(e: anyhow::Error) -> Self {
73 Self::Deser(e)
74 }
75}
76
77impl TcpRead {
78 pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
79 let Some(msg) = self.stream.next().await else {
80 return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
81 };
82 let msg = match msg {
83 Ok(msg) => msg,
84 Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
85 return Err(WebSocketReadError::Read(e))
86 }
87 Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
88 };
89 let msg: WebSocketResponse = match msg {
90 WsMsg::Text(text) => serde_json::from_str(&text)
91 .map_err(anyhow::Error::from)
92 .map_err(WebSocketReadError::from)?,
93 WsMsg::Binary(bin) => bson::from_slice(&bin)
94 .map_err(anyhow::Error::from)
95 .map_err(WebSocketReadError::from)?,
96 other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
97 };
98 Ok(msg)
99 }
100}
101
/// Owner of the background task that reads engine responses.
///
/// Dropping this handle aborts the task.
pub struct TcpReadHandle {
    /// Join handle of the read task; it resolves to the read error that stopped it.
    handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
}
105
106impl std::fmt::Debug for TcpReadHandle {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 write!(f, "TcpReadHandle")
109 }
110}
111
112impl Drop for TcpReadHandle {
113 fn drop(&mut self) {
114 self.handle.abort();
116 }
117}
118
/// A request queued for the write actor, plus the channel on which the actor
/// reports whether the request was written to the socket.
struct ToEngineReq {
    /// The request to send to the engine.
    req: WebSocketRequest,
    /// Receives `Ok(())` once the request has been written, or the send error.
    request_sent: oneshot::Sender<Result<()>>,
}
128
129impl EngineConnection {
130 async fn start_write_actor(
132 mut tcp_write: WebSocketTcpWrite,
133 mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
134 mut shutdown_rx: mpsc::Receiver<()>,
135 ) {
136 loop {
137 tokio::select! {
138 maybe_req = engine_req_rx.recv() => {
139 match maybe_req {
140 Some(ToEngineReq { req, request_sent }) => {
141 let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
144 cmd: ModelingCmd::ImportFiles { .. },
145 cmd_id: _,
146 }) = &req
147 {
148 Self::inner_send_to_engine_binary(req, &mut tcp_write).await
149 } else {
150 Self::inner_send_to_engine(req, &mut tcp_write).await
151 };
152
153 let _ = request_sent.send(res);
155 }
156 None => {
157 break;
160 }
161 }
162 },
163
164 _ = shutdown_rx.recv() => {
166 let _ = Self::inner_close_engine(&mut tcp_write).await;
167 return;
168 }
169 }
170 }
171
172 let _ = Self::inner_close_engine(&mut tcp_write).await;
175 }
176
177 async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
179 tcp_write
180 .send(WsMsg::Close(None))
181 .await
182 .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
183 Ok(())
184 }
185
186 async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
188 let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
189 tcp_write
190 .send(WsMsg::Text(msg))
191 .await
192 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
193 Ok(())
194 }
195
196 async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
198 let msg = bson::to_vec(&request).map_err(|e| anyhow!("could not serialize bson: {e}"))?;
199 tcp_write
200 .send(WsMsg::Binary(msg))
201 .await
202 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
203 Ok(())
204 }
205
206 pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
207 let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
208 max_message_size: Some(usize::MAX),
210 max_frame_size: Some(usize::MAX),
211 ..Default::default()
212 };
213
214 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
215 ws,
216 tokio_tungstenite::tungstenite::protocol::Role::Client,
217 Some(wsconfig),
218 )
219 .await;
220
221 let (tcp_write, tcp_read) = ws_stream.split();
222 let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
223 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
224 tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));
225
226 let mut tcp_read = TcpRead { stream: tcp_read };
227
228 let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
229 let session_data2 = session_data.clone();
230 let responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>> = Arc::new(RwLock::new(IndexMap::new()));
231 let responses_clone = responses.clone();
232 let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
233 let pending_errors = Arc::new(RwLock::new(Vec::new()));
234 let pending_errors_clone = pending_errors.clone();
235
236 let socket_health_tcp_read = socket_health.clone();
237 let tcp_read_handle = tokio::spawn(async move {
238 loop {
240 match tcp_read.read().await {
241 Ok(ws_resp) => {
242 let id = ws_resp.request_id();
244 match &ws_resp {
245 WebSocketResponse::Success(SuccessWebSocketResponse {
246 resp: OkWebSocketResponseData::ModelingBatch { responses },
247 ..
248 }) =>
249 {
250 #[expect(
251 clippy::iter_over_hash_type,
252 reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
253 )]
254 for (resp_id, batch_response) in responses {
255 let id: uuid::Uuid = (*resp_id).into();
256 match batch_response {
257 BatchResponse::Success { response } => {
258 responses_clone.write().await.insert(
259 id,
260 WebSocketResponse::Success(SuccessWebSocketResponse {
261 success: true,
262 request_id: Some(id),
263 resp: OkWebSocketResponseData::Modeling {
264 modeling_response: response.clone(),
265 },
266 }),
267 );
268 }
269 BatchResponse::Failure { errors } => {
270 responses_clone.write().await.insert(
271 id,
272 WebSocketResponse::Failure(FailureWebSocketResponse {
273 success: false,
274 request_id: Some(id),
275 errors: errors.clone(),
276 }),
277 );
278 }
279 }
280 }
281 }
282 WebSocketResponse::Success(SuccessWebSocketResponse {
283 resp: OkWebSocketResponseData::ModelingSessionData { session },
284 ..
285 }) => {
286 let mut sd = session_data2.write().await;
287 sd.replace(session.clone());
288 }
289 WebSocketResponse::Failure(FailureWebSocketResponse {
290 success: _,
291 request_id,
292 errors,
293 }) => {
294 if let Some(id) = request_id {
295 responses_clone.write().await.insert(
296 *id,
297 WebSocketResponse::Failure(FailureWebSocketResponse {
298 success: false,
299 request_id: *request_id,
300 errors: errors.clone(),
301 }),
302 );
303 } else {
304 let mut pe = pending_errors_clone.write().await;
306 for error in errors {
307 if !pe.contains(&error.message) {
308 pe.push(error.message.clone());
309 }
310 }
311 drop(pe);
312 }
313 }
314 _ => {}
315 }
316
317 if let Some(id) = id {
318 responses_clone.write().await.insert(id, ws_resp.clone());
319 }
320 }
321 Err(e) => {
322 match &e {
323 WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
324 WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
325 }
326 *socket_health_tcp_read.write().await = SocketHealth::Inactive;
327 return Err(e);
328 }
329 }
330 }
331 });
332
333 Ok(EngineConnection {
334 engine_req_tx,
335 shutdown_tx,
336 tcp_read_handle: Arc::new(TcpReadHandle {
337 handle: Arc::new(tcp_read_handle),
338 }),
339 responses,
340 pending_errors,
341 socket_health,
342 batch: Arc::new(RwLock::new(Vec::new())),
343 batch_end: Arc::new(RwLock::new(IndexMap::new())),
344 artifact_commands: Arc::new(RwLock::new(Vec::new())),
345 default_planes: Default::default(),
346 session_data,
347 execution_kind: Default::default(),
348 stats: Default::default(),
349 })
350 }
351}
352
353#[async_trait::async_trait]
354impl EngineManager for EngineConnection {
355 fn batch(&self) -> Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>> {
356 self.batch.clone()
357 }
358
359 fn batch_end(&self) -> Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>> {
360 self.batch_end.clone()
361 }
362
363 fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
364 self.responses.clone()
365 }
366
367 fn artifact_commands(&self) -> Arc<RwLock<Vec<ArtifactCommand>>> {
368 self.artifact_commands.clone()
369 }
370
371 async fn execution_kind(&self) -> ExecutionKind {
372 let guard = self.execution_kind.read().await;
373 *guard
374 }
375
376 async fn replace_execution_kind(&self, execution_kind: ExecutionKind) -> ExecutionKind {
377 let mut guard = self.execution_kind.write().await;
378 let original = *guard;
379 *guard = execution_kind;
380 original
381 }
382
383 fn stats(&self) -> &EngineStats {
384 &self.stats
385 }
386
387 fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
388 self.default_planes.clone()
389 }
390
391 async fn clear_scene_post_hook(
392 &self,
393 id_generator: &mut IdGenerator,
394 source_range: SourceRange,
395 ) -> Result<(), KclError> {
396 let new_planes = self.new_default_planes(id_generator, source_range).await?;
398 *self.default_planes.write().await = Some(new_planes);
399
400 Ok(())
401 }
402
403 async fn inner_send_modeling_cmd(
404 &self,
405 id: uuid::Uuid,
406 source_range: SourceRange,
407 cmd: WebSocketRequest,
408 _id_to_source_range: HashMap<Uuid, SourceRange>,
409 ) -> Result<WebSocketResponse, KclError> {
410 let (tx, rx) = oneshot::channel();
411
412 self.engine_req_tx
414 .send(ToEngineReq {
415 req: cmd.clone(),
416 request_sent: tx,
417 })
418 .await
419 .map_err(|e| {
420 KclError::Engine(KclErrorDetails {
421 message: format!("Failed to send modeling command: {}", e),
422 source_ranges: vec![source_range],
423 })
424 })?;
425
426 rx.await
428 .map_err(|e| {
429 KclError::Engine(KclErrorDetails {
430 message: format!("could not send request to the engine actor: {e}"),
431 source_ranges: vec![source_range],
432 })
433 })?
434 .map_err(|e| {
435 KclError::Engine(KclErrorDetails {
436 message: format!("could not send request to the engine: {e}"),
437 source_ranges: vec![source_range],
438 })
439 })?;
440
441 let current_time = std::time::Instant::now();
443 while current_time.elapsed().as_secs() < 60 {
444 let guard = self.socket_health.read().await;
445 if *guard == SocketHealth::Inactive {
446 let pe = self.pending_errors.read().await;
448 if !pe.is_empty() {
449 return Err(KclError::Engine(KclErrorDetails {
450 message: pe.join(", ").to_string(),
451 source_ranges: vec![source_range],
452 }));
453 } else {
454 return Err(KclError::Engine(KclErrorDetails {
455 message: "Modeling command failed: websocket closed early".to_string(),
456 source_ranges: vec![source_range],
457 }));
458 }
459 }
460 if let Some(resp) = self.responses.write().await.shift_remove(&id) {
462 return Ok(resp);
463 }
464 }
465
466 Err(KclError::Engine(KclErrorDetails {
467 message: format!("Modeling command timed out `{}`", id),
468 source_ranges: vec![source_range],
469 }))
470 }
471
472 async fn get_session_data(&self) -> Option<ModelingSessionData> {
473 self.session_data.read().await.clone()
474 }
475
476 async fn close(&self) {
477 let _ = self.shutdown_tx.send(()).await;
478 loop {
479 let guard = self.socket_health.read().await;
480 if *guard == SocketHealth::Inactive {
481 return;
482 }
483 }
484 }
485}