1use crate::protocol::{
2 HybridRequest, HybridResponse, RequestHeader, ResponseHeader, OP_HYBRID_INFER, OP_INFER,
3 OP_INFER_STREAM, STATUS_ERR, STATUS_OK, STATUS_STREAM_CHUNK, STATUS_STREAM_END,
4};
5use async_trait::async_trait;
6use bincode;
7use kapsl_engine_api::{BinaryTensorPacket, InferenceRequest, NamedTensor, TensorDtype};
8use kapsl_scheduler::{Priority, ReplicaScheduler};
9use kapsl_transport::{ResponseMetadata, TransportError, TransportServer};
10use serde::Deserialize;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
14
15#[cfg(unix)]
16use std::os::unix::fs::PermissionsExt;
17#[cfg(windows)]
18use tokio::net::windows::named_pipe::ServerOptions;
19#[cfg(unix)]
20use tokio::net::{UnixListener, UnixStream};
21
22use kapsl_shm::memory::{ShmManager, TensorHeader};
23
/// Shared callback that resolves a model id to the scheduler serving it.
///
/// The callback returns `None` when no scheduler is available for the id; the
/// connection handler reports that case to the client as an error.
pub type SchedulerLookup =
    Arc<dyn Fn(u32) -> Option<Arc<dyn ReplicaScheduler + Send + Sync>> + Send + Sync>;
26
/// Older request layout that `decode_inference_request` falls back to when a
/// payload does not decode as the current [`InferenceRequest`].
///
/// It has no `metadata` or `cancellation` fields; the fallback fills both with
/// `None` when converting.
// NOTE(review): this struct is decoded with bincode, which is presumably
// sensitive to field order — confirm before reordering or inserting fields.
#[derive(Debug, Deserialize)]
struct LegacyInferenceRequestV1 {
    // Primary input tensor.
    input: BinaryTensorPacket,
    // Extra named inputs; defaults to an empty list when absent.
    #[serde(default)]
    additional_inputs: Vec<NamedTensor>,
    // Optional session identifier; defaults to `None` when absent.
    #[serde(default)]
    session_id: Option<String>,
}
35
36fn check_auth(request: &InferenceRequest, expected: Option<&str>) -> Option<String> {
37 let Some(expected_token) = expected else {
38 return None; };
40 let presented = request
41 .metadata
42 .as_ref()
43 .and_then(|m| m.auth_token.as_deref());
44 if presented != Some(expected_token) {
45 Some("Unauthorized".to_string())
46 } else {
47 None
48 }
49}
50
51fn decode_inference_request(payload: &[u8]) -> Result<InferenceRequest, String> {
52 match bincode::deserialize::<InferenceRequest>(payload) {
53 Ok(request) => Ok(request),
54 Err(primary_err) => {
55 if let Ok(legacy) = bincode::deserialize::<LegacyInferenceRequestV1>(payload) {
56 return Ok(InferenceRequest {
57 input: legacy.input,
58 additional_inputs: legacy.additional_inputs,
59 session_id: legacy.session_id,
60 metadata: None,
61 cancellation: None,
62 });
63 }
64 Err(format!("Deserialization error: {}", primary_err))
65 }
66 }
67}
68
/// Local IPC transport server that forwards inference requests to schedulers.
pub struct IpcServer {
    // Unix socket path (unix) or named pipe name (windows) to listen on.
    socket_path: String,
    // Resolves the model id from a request header to a scheduler.
    scheduler_lookup: SchedulerLookup,
    // Shared-memory region used by hybrid requests; `None` makes the handler
    // answer those requests with an error.
    shm_manager: Option<Arc<ShmManager>>,
    // Token requests must present (see `check_auth`); `None` disables the check.
    auth_token: Option<Arc<str>>,
}
75
76impl IpcServer {
77 pub fn new(
78 socket_path: &str,
79 schedulers: HashMap<u32, Arc<dyn ReplicaScheduler + Send + Sync>>,
80 shm_manager: Option<Arc<ShmManager>>,
81 ) -> Self {
82 let schedulers = Arc::new(schedulers);
83 let scheduler_lookup: SchedulerLookup =
84 Arc::new(move |model_id| schedulers.get(&model_id).cloned());
85 Self::new_with_lookup(socket_path, scheduler_lookup, shm_manager)
86 }
87
88 pub fn new_with_lookup(
89 socket_path: &str,
90 scheduler_lookup: SchedulerLookup,
91 shm_manager: Option<Arc<ShmManager>>,
92 ) -> Self {
93 Self {
94 socket_path: socket_path.to_string(),
95 scheduler_lookup,
96 shm_manager,
97 auth_token: None,
98 }
99 }
100
101 pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
105 self.auth_token = Some(Arc::from(token.into().as_str()));
106 self
107 }
108
109 async fn run_internal(&self) -> std::io::Result<()> {
110 let scheduler_lookup = self.scheduler_lookup.clone();
111 let auth_token = self.auth_token.clone();
112
113 #[cfg(unix)]
114 {
115 if std::path::Path::new(&self.socket_path).exists() {
116 if UnixStream::connect(&self.socket_path).await.is_ok() {
119 return Err(std::io::Error::new(
120 std::io::ErrorKind::AddrInUse,
121 format!(
122 "IPC socket path {} is already in use. Is another kapsl runtime running? Use --socket to choose a different path.",
123 self.socket_path
124 ),
125 ));
126 }
127
128 std::fs::remove_file(&self.socket_path)?;
130 }
131
132 let listener = UnixListener::bind(&self.socket_path)?;
133 std::fs::set_permissions(&self.socket_path, std::fs::Permissions::from_mode(0o600))?;
134 log::info!("IPC Server listening on {}", self.socket_path);
135
136 loop {
137 let (stream, _) = listener.accept().await?;
138 let scheduler_lookup = scheduler_lookup.clone();
139 let shm_manager = self.shm_manager.clone();
140 let auth_token = auth_token.clone();
141
142 tokio::spawn(async move {
143 if let Err(e) =
144 handle_connection(stream, scheduler_lookup, shm_manager, auth_token).await
145 {
146 log::error!("Connection error: {}", e);
147 }
148 });
149 }
150 }
151
152 #[cfg(windows)]
153 {
154 loop {
155 let server = ServerOptions::new().create(&self.socket_path)?;
156
157 server.connect().await?;
158 let scheduler_lookup = scheduler_lookup.clone();
159 let shm_manager = self.shm_manager.clone();
160 let auth_token = auth_token.clone();
161
162 tokio::spawn(async move {
163 if let Err(e) =
164 handle_connection(server, scheduler_lookup, shm_manager, auth_token).await
165 {
166 log::error!("Connection error: {}", e);
167 }
168 });
169 }
170 }
171 }
172}
173
174#[async_trait]
175impl TransportServer for IpcServer {
176 async fn run(&self) -> Result<(), TransportError> {
177 self.run_internal().await.map_err(TransportError::Io)
178 }
179
180 async fn shutdown(&self) -> Result<(), TransportError> {
181 #[cfg(unix)]
183 {
184 if std::path::Path::new(&self.socket_path).exists() {
185 std::fs::remove_file(&self.socket_path).map_err(TransportError::Io)?;
186 }
187 }
188 Ok(())
189 }
190
191 fn transport_type(&self) -> &'static str {
192 "socket"
193 }
194}
195
196pub(crate) async fn handle_connection<T>(
197 mut connection: T,
198 scheduler_lookup: SchedulerLookup,
199 shm_manager: Option<Arc<ShmManager>>,
200 auth_token: Option<Arc<str>>,
201) -> std::io::Result<()>
202where
203 T: AsyncRead + AsyncWrite + Unpin,
204{
205 loop {
206 let mut model_id_buf = [0u8; 4];
208 if connection.read_exact(&mut model_id_buf).await.is_err() {
209 return Ok(()); }
211 let mut op_code_buf = [0u8; 4];
212 connection.read_exact(&mut op_code_buf).await?;
213 let mut payload_size_buf = [0u8; 4];
214 connection.read_exact(&mut payload_size_buf).await?;
215
216 let header = RequestHeader {
217 model_id: u32::from_le_bytes(model_id_buf),
218 op_code: u32::from_le_bytes(op_code_buf),
219 payload_size: u32::from_le_bytes(payload_size_buf),
220 };
221
222 let mut payload = vec![0u8; header.payload_size as usize];
224 connection.read_exact(&mut payload).await?;
225
226 match header.op_code {
227 OP_INFER_STREAM => {
228 let request: InferenceRequest = match decode_inference_request(&payload) {
230 Ok(req) => req,
231 Err(error_msg) => {
232 let resp_header = ResponseHeader {
233 status: STATUS_ERR,
234 payload_size: error_msg.len() as u32,
235 };
236 connection
237 .write_all(&resp_header.status.to_le_bytes())
238 .await?;
239 connection
240 .write_all(&resp_header.payload_size.to_le_bytes())
241 .await?;
242 connection.write_all(error_msg.as_bytes()).await?;
243 continue;
244 }
245 };
246
247 if let Some(error_msg) = check_auth(&request, auth_token.as_deref()) {
248 let resp_header = ResponseHeader {
249 status: STATUS_ERR,
250 payload_size: error_msg.len() as u32,
251 };
252 connection
253 .write_all(&resp_header.status.to_le_bytes())
254 .await?;
255 connection
256 .write_all(&resp_header.payload_size.to_le_bytes())
257 .await?;
258 connection.write_all(error_msg.as_bytes()).await?;
259 continue;
260 }
261
262 let scheduler = match scheduler_lookup(header.model_id) {
264 Some(s) => s,
265 None => {
266 let error_msg = format!("Model {} not found", header.model_id);
267 let resp_header = ResponseHeader {
268 status: STATUS_ERR,
269 payload_size: error_msg.len() as u32,
270 };
271 connection
272 .write_all(&resp_header.status.to_le_bytes())
273 .await?;
274 connection
275 .write_all(&resp_header.payload_size.to_le_bytes())
276 .await?;
277 connection.write_all(error_msg.as_bytes()).await?;
278 continue;
279 }
280 };
281
282 let stream_result = scheduler
284 .infer_stream(request, Priority::LatencyCritical, false)
285 .await;
286
287 use futures::StreamExt;
288 match stream_result {
289 Ok(mut inference_stream) => {
290 while let Some(result) = inference_stream.next().await {
291 match result {
292 Ok(packet) => {
293 let response_bytes = match bincode::serialize(&packet) {
295 Ok(b) => b,
296 Err(e) => {
297 log::error!("Serialization error: {}", e);
298 break;
299 }
300 };
301
302 let response_header = ResponseHeader {
304 status: STATUS_STREAM_CHUNK,
305 payload_size: response_bytes.len() as u32,
306 };
307
308 connection
309 .write_all(&response_header.status.to_le_bytes())
310 .await?;
311 connection
312 .write_all(&response_header.payload_size.to_le_bytes())
313 .await?;
314 connection.write_all(&response_bytes).await?;
315 connection.flush().await?;
316 }
317 Err(e) => {
318 let error_msg = e.to_string();
320 let response_bytes = error_msg.as_bytes();
321 let response_header = ResponseHeader {
322 status: STATUS_ERR,
323 payload_size: response_bytes.len() as u32,
324 };
325 connection
326 .write_all(&response_header.status.to_le_bytes())
327 .await?;
328 connection
329 .write_all(&response_header.payload_size.to_le_bytes())
330 .await?;
331 connection.write_all(response_bytes).await?;
332 connection.flush().await?;
333 break;
334 }
335 }
336 }
337
338 let response_header = ResponseHeader {
340 status: STATUS_STREAM_END,
341 payload_size: 0,
342 };
343 connection
344 .write_all(&response_header.status.to_le_bytes())
345 .await?;
346 connection
347 .write_all(&response_header.payload_size.to_le_bytes())
348 .await?;
349 connection.flush().await?;
350 }
351 Err(e) => {
352 let error_msg = e.to_string();
354 let response_bytes = error_msg.as_bytes();
355 let response_header = ResponseHeader {
356 status: STATUS_ERR,
357 payload_size: response_bytes.len() as u32,
358 };
359 connection
360 .write_all(&response_header.status.to_le_bytes())
361 .await?;
362 connection
363 .write_all(&response_header.payload_size.to_le_bytes())
364 .await?;
365 connection.write_all(response_bytes).await?;
366 connection.flush().await?;
367 }
368 }
369 }
370 OP_INFER => {
371 if let Some(scheduler) = scheduler_lookup(header.model_id) {
373 let request: InferenceRequest = match decode_inference_request(&payload) {
375 Ok(req) => req,
376 Err(error_msg) => {
377 let resp_header = ResponseHeader {
378 status: STATUS_ERR,
379 payload_size: error_msg.len() as u32,
380 };
381 connection
382 .write_all(&resp_header.status.to_le_bytes())
383 .await?;
384 connection
385 .write_all(&resp_header.payload_size.to_le_bytes())
386 .await?;
387 connection.write_all(error_msg.as_bytes()).await?;
388 continue;
389 }
390 };
391
392 if let Some(error_msg) = check_auth(&request, auth_token.as_deref()) {
393 let resp_header = ResponseHeader {
394 status: STATUS_ERR,
395 payload_size: error_msg.len() as u32,
396 };
397 connection
398 .write_all(&resp_header.status.to_le_bytes())
399 .await?;
400 connection
401 .write_all(&resp_header.payload_size.to_le_bytes())
402 .await?;
403 connection.write_all(error_msg.as_bytes()).await?;
404 continue;
405 }
406
407 let result = scheduler.infer(&request, Priority::Throughput, false).await;
410
411 match result {
412 Ok(output) => {
413 let output_bytes =
414 bincode::serialize(&output).map_err(std::io::Error::other)?;
415
416 let resp_header = ResponseHeader {
417 status: STATUS_OK,
418 payload_size: output_bytes.len() as u32,
419 };
420
421 connection
423 .write_all(&resp_header.status.to_le_bytes())
424 .await?;
425 connection
426 .write_all(&resp_header.payload_size.to_le_bytes())
427 .await?;
428 connection.write_all(&output_bytes).await?;
429 }
430 Err(e) => {
431 let error_msg = e.to_string();
432 let resp_header = ResponseHeader {
433 status: STATUS_ERR,
434 payload_size: error_msg.len() as u32,
435 };
436 connection
437 .write_all(&resp_header.status.to_le_bytes())
438 .await?;
439 connection
440 .write_all(&resp_header.payload_size.to_le_bytes())
441 .await?;
442 connection.write_all(error_msg.as_bytes()).await?;
443 }
444 }
445 } else {
446 let error_msg = format!("Model {} not found", header.model_id);
448 let resp_header = ResponseHeader {
449 status: STATUS_ERR,
450 payload_size: error_msg.len() as u32,
451 };
452 connection
453 .write_all(&resp_header.status.to_le_bytes())
454 .await?;
455 connection
456 .write_all(&resp_header.payload_size.to_le_bytes())
457 .await?;
458 connection.write_all(error_msg.as_bytes()).await?;
459 }
460 }
461 OP_HYBRID_INFER => {
462 let hybrid_req: HybridRequest = bincode::deserialize(&payload)
465 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
466
467 if let Some(shm_manager) = &shm_manager {
468 let base_ptr = shm_manager.as_ptr();
469
470 let header_ptr = unsafe {
472 base_ptr.add(hybrid_req.shm_offset as usize) as *const TensorHeader
473 };
474 let tensor_header = unsafe { &*header_ptr };
475
476 let data_ptr = unsafe {
478 base_ptr.add(
479 hybrid_req.shm_offset as usize + std::mem::size_of::<TensorHeader>(),
480 )
481 };
482 let data_slice = unsafe {
483 std::slice::from_raw_parts(data_ptr, tensor_header.data_size as usize)
484 };
485
486 let shape = tensor_header.shape[0..tensor_header.ndim as usize].to_vec();
488 let dtype = match tensor_header.dtype {
489 0 => TensorDtype::Float32,
490 1 => TensorDtype::Float64,
491 2 => TensorDtype::Int32,
492 3 => TensorDtype::Int64,
493 _ => TensorDtype::Float32,
494 };
495
496 let packet = BinaryTensorPacket {
497 shape,
498 dtype,
499 data: data_slice.to_vec(),
500 };
501
502 let request = InferenceRequest {
503 input: packet,
504 additional_inputs: Vec::new(),
505 session_id: None,
506 metadata: None,
507 cancellation: None,
508 };
509
510 let result =
512 if let Some(scheduler) = scheduler_lookup(hybrid_req.metadata.model_id) {
513 scheduler
514 .infer(
515 &request,
516 Priority::Throughput,
517 hybrid_req.metadata.force_cpu,
518 )
519 .await
520 } else {
521 Err(kapsl_engine_api::EngineError::ModelNotLoaded)
522 };
523
524 match result {
525 Ok(output) => {
526 let packet = BinaryTensorPacket {
528 shape: output.shape,
529 dtype: output.dtype,
530 data: output.data,
531 };
532
533 let output_size =
535 std::mem::size_of::<TensorHeader>() + packet.data.len();
536
537 static SERVER_SLOT_COUNTER: std::sync::atomic::AtomicUsize =
540 std::sync::atomic::AtomicUsize::new(0);
541 let slot = SERVER_SLOT_COUNTER
542 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
543 let output_offset = 512 * 1024 * 1024 + (slot % 400) * 1_000_000; let shm_size = shm_manager.size();
547 if output_offset + output_size > shm_size {
548 let error_msg = format!("Output would exceed SHM bounds: offset={}, size={}, shm_size={}",
549 output_offset, output_size, shm_size);
550 let resp_header = ResponseHeader {
551 status: STATUS_ERR,
552 payload_size: error_msg.len() as u32,
553 };
554 connection
555 .write_all(&resp_header.status.to_le_bytes())
556 .await?;
557 connection
558 .write_all(&resp_header.payload_size.to_le_bytes())
559 .await?;
560 connection.write_all(error_msg.as_bytes()).await?;
561 continue;
562 }
563
564 let base_ptr = shm_manager.as_ptr();
567 unsafe {
568 let out_header = TensorHeader {
570 ndim: packet.shape.len() as u32,
571 dtype: match packet.dtype {
572 TensorDtype::Float32 => 0,
573 TensorDtype::Float64 => 1,
574 TensorDtype::Int32 => 2,
575 TensorDtype::Int64 => 3,
576 _ => 0,
577 },
578 _padding: [0; 3],
579 shape: {
580 let mut arr = [0i64; 8];
581 for (i, &v) in packet.shape.iter().enumerate() {
582 arr[i] = v;
583 }
584 arr
585 },
586 data_size: packet.data.len() as u64,
587 };
588
589 let hdr_ptr = base_ptr.add(output_offset) as *mut TensorHeader;
590 std::ptr::write(hdr_ptr, out_header);
591
592 let data_ptr = base_ptr
593 .add(output_offset + std::mem::size_of::<TensorHeader>());
594 std::ptr::copy_nonoverlapping(
595 packet.data.as_ptr(),
596 data_ptr,
597 packet.data.len(),
598 );
599 }
600
601 let resp = HybridResponse {
603 metadata: ResponseMetadata {
604 request_id: hybrid_req.metadata.request_id,
605 status: STATUS_OK as u8,
606 _padding: [0; 7],
607 latency_ns: 0,
608 },
609 shm_offset: output_offset as u64,
610 shm_size: (std::mem::size_of::<TensorHeader>() + packet.data.len())
611 as u64,
612 };
613
614 let resp_bytes =
615 bincode::serialize(&resp).map_err(std::io::Error::other)?;
616
617 let resp_header = ResponseHeader {
618 status: STATUS_OK,
619 payload_size: resp_bytes.len() as u32,
620 };
621
622 connection
623 .write_all(&resp_header.status.to_le_bytes())
624 .await?;
625 connection
626 .write_all(&resp_header.payload_size.to_le_bytes())
627 .await?;
628 connection.write_all(&resp_bytes).await?;
629 }
630 Err(e) => {
631 let error_msg = e.to_string();
632 let resp_header = ResponseHeader {
633 status: STATUS_ERR,
634 payload_size: error_msg.len() as u32,
635 };
636 connection
637 .write_all(&resp_header.status.to_le_bytes())
638 .await?;
639 connection
640 .write_all(&resp_header.payload_size.to_le_bytes())
641 .await?;
642 connection.write_all(error_msg.as_bytes()).await?;
643 }
644 }
645 } else {
646 let error_msg = "SHM Manager not configured".to_string();
647 let resp_header = ResponseHeader {
648 status: STATUS_ERR,
649 payload_size: error_msg.len() as u32,
650 };
651 connection
652 .write_all(&resp_header.status.to_le_bytes())
653 .await?;
654 connection
655 .write_all(&resp_header.payload_size.to_le_bytes())
656 .await?;
657 connection.write_all(error_msg.as_bytes()).await?;
658 }
659 }
660 _ => {
661 let resp_header = ResponseHeader {
663 status: STATUS_ERR,
664 payload_size: 0,
665 };
666 connection
668 .write_all(&resp_header.status.to_le_bytes())
669 .await?;
670 connection
671 .write_all(&resp_header.payload_size.to_le_bytes())
672 .await?;
673 }
674 }
675 }
676}