Skip to main content

kapsl_ipc/
server.rs

1use crate::protocol::{
2    HybridRequest, HybridResponse, RequestHeader, ResponseHeader, OP_HYBRID_INFER, OP_INFER,
3    OP_INFER_STREAM, STATUS_ERR, STATUS_OK, STATUS_STREAM_CHUNK, STATUS_STREAM_END,
4};
5use async_trait::async_trait;
6use bincode;
7use kapsl_engine_api::{BinaryTensorPacket, InferenceRequest, NamedTensor, TensorDtype};
8use kapsl_scheduler::{Priority, ReplicaScheduler};
9use kapsl_transport::{ResponseMetadata, TransportError, TransportServer};
10use serde::Deserialize;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
14
15#[cfg(unix)]
16use std::os::unix::fs::PermissionsExt;
17#[cfg(windows)]
18use tokio::net::windows::named_pipe::ServerOptions;
19#[cfg(unix)]
20use tokio::net::{UnixListener, UnixStream};
21
22use kapsl_shm::memory::{ShmManager, TensorHeader};
23
/// Resolves a request's `model_id` to the scheduler serving that model.
///
/// Returning `None` makes the server answer with a "Model {id} not found" /
/// `ModelNotLoaded` error. The server invokes the lookup on every request rather than
/// caching the result. This file only builds it from a fixed `HashMap` (see
/// `IpcServer::new`); other callers may back it with mutable state.
pub type SchedulerLookup =
    Arc<dyn Fn(u32) -> Option<Arc<dyn ReplicaScheduler + Send + Sync>> + Send + Sync>;
26
/// Request layout without the `metadata` and `cancellation` fields of [`InferenceRequest`].
///
/// Per its name, this is presumably the V1 wire format. `decode_inference_request` tries it
/// as a fallback when decoding as `InferenceRequest` fails.
///
/// The field order is part of the bincode wire format, so do not reorder these fields.
#[derive(Debug, Deserialize)]
struct LegacyInferenceRequestV1 {
    /// Primary input tensor.
    input: BinaryTensorPacket,
    // NOTE(review): bincode is not self-describing, so `#[serde(default)]` most likely does
    // not let these trailing fields be omitted on the wire. Confirm before relying on it.
    #[serde(default)]
    additional_inputs: Vec<NamedTensor>,
    #[serde(default)]
    session_id: Option<String>,
}
35
36fn check_auth(request: &InferenceRequest, expected: Option<&str>) -> Option<String> {
37    let Some(expected_token) = expected else {
38        return None; // auth not configured — allow all
39    };
40    let presented = request
41        .metadata
42        .as_ref()
43        .and_then(|m| m.auth_token.as_deref());
44    if presented != Some(expected_token) {
45        Some("Unauthorized".to_string())
46    } else {
47        None
48    }
49}
50
51fn decode_inference_request(payload: &[u8]) -> Result<InferenceRequest, String> {
52    match bincode::deserialize::<InferenceRequest>(payload) {
53        Ok(request) => Ok(request),
54        Err(primary_err) => {
55            if let Ok(legacy) = bincode::deserialize::<LegacyInferenceRequestV1>(payload) {
56                return Ok(InferenceRequest {
57                    input: legacy.input,
58                    additional_inputs: legacy.additional_inputs,
59                    session_id: legacy.session_id,
60                    metadata: None,
61                    cancellation: None,
62                });
63            }
64            Err(format!("Deserialization error: {}", primary_err))
65        }
66    }
67}
68
/// Local IPC inference server: a Unix domain socket on Unix, a named pipe on Windows.
pub struct IpcServer {
    /// Path of the Unix socket file, or the pipe name on Windows.
    socket_path: String,
    /// Resolves a request's model id to the scheduler that serves it.
    scheduler_lookup: SchedulerLookup,
    /// Shared memory used by `OP_HYBRID_INFER`; with `None`, that op answers
    /// "SHM Manager not configured".
    shm_manager: Option<Arc<ShmManager>>,
    /// Token every inference request must present; `None` disables auth.
    auth_token: Option<Arc<str>>,
}
75
76impl IpcServer {
77    pub fn new(
78        socket_path: &str,
79        schedulers: HashMap<u32, Arc<dyn ReplicaScheduler + Send + Sync>>,
80        shm_manager: Option<Arc<ShmManager>>,
81    ) -> Self {
82        let schedulers = Arc::new(schedulers);
83        let scheduler_lookup: SchedulerLookup =
84            Arc::new(move |model_id| schedulers.get(&model_id).cloned());
85        Self::new_with_lookup(socket_path, scheduler_lookup, shm_manager)
86    }
87
88    pub fn new_with_lookup(
89        socket_path: &str,
90        scheduler_lookup: SchedulerLookup,
91        shm_manager: Option<Arc<ShmManager>>,
92    ) -> Self {
93        Self {
94            socket_path: socket_path.to_string(),
95            scheduler_lookup,
96            shm_manager,
97            auth_token: None,
98        }
99    }
100
101    /// Require every inference request to carry this token in
102    /// `request.metadata.auth_token`. Requests without the token
103    /// or with a wrong token receive `STATUS_ERR: Unauthorized`.
104    pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
105        self.auth_token = Some(Arc::from(token.into().as_str()));
106        self
107    }
108
109    async fn run_internal(&self) -> std::io::Result<()> {
110        let scheduler_lookup = self.scheduler_lookup.clone();
111        let auth_token = self.auth_token.clone();
112
113        #[cfg(unix)]
114        {
115            if std::path::Path::new(&self.socket_path).exists() {
116                // Avoid clobbering a live socket from another runtime: if we can connect,
117                // it is in-use and we should refuse to start.
118                if UnixStream::connect(&self.socket_path).await.is_ok() {
119                    return Err(std::io::Error::new(
120                        std::io::ErrorKind::AddrInUse,
121                        format!(
122                            "IPC socket path {} is already in use. Is another kapsl runtime running? Use --socket to choose a different path.",
123                            self.socket_path
124                        ),
125                    ));
126                }
127
128                // Stale socket (or leftover file) from a previous crash.
129                std::fs::remove_file(&self.socket_path)?;
130            }
131
132            let listener = UnixListener::bind(&self.socket_path)?;
133            std::fs::set_permissions(&self.socket_path, std::fs::Permissions::from_mode(0o600))?;
134            log::info!("IPC Server listening on {}", self.socket_path);
135
136            loop {
137                let (stream, _) = listener.accept().await?;
138                let scheduler_lookup = scheduler_lookup.clone();
139                let shm_manager = self.shm_manager.clone();
140                let auth_token = auth_token.clone();
141
142                tokio::spawn(async move {
143                    if let Err(e) =
144                        handle_connection(stream, scheduler_lookup, shm_manager, auth_token).await
145                    {
146                        log::error!("Connection error: {}", e);
147                    }
148                });
149            }
150        }
151
152        #[cfg(windows)]
153        {
154            loop {
155                let server = ServerOptions::new().create(&self.socket_path)?;
156
157                server.connect().await?;
158                let scheduler_lookup = scheduler_lookup.clone();
159                let shm_manager = self.shm_manager.clone();
160                let auth_token = auth_token.clone();
161
162                tokio::spawn(async move {
163                    if let Err(e) =
164                        handle_connection(server, scheduler_lookup, shm_manager, auth_token).await
165                    {
166                        log::error!("Connection error: {}", e);
167                    }
168                });
169            }
170        }
171    }
172}
173
174#[async_trait]
175impl TransportServer for IpcServer {
176    async fn run(&self) -> Result<(), TransportError> {
177        self.run_internal().await.map_err(TransportError::Io)
178    }
179
180    async fn shutdown(&self) -> Result<(), TransportError> {
181        // Clean up socket file on shutdown
182        #[cfg(unix)]
183        {
184            if std::path::Path::new(&self.socket_path).exists() {
185                std::fs::remove_file(&self.socket_path).map_err(TransportError::Io)?;
186            }
187        }
188        Ok(())
189    }
190
191    fn transport_type(&self) -> &'static str {
192        "socket"
193    }
194}
195
196pub(crate) async fn handle_connection<T>(
197    mut connection: T,
198    scheduler_lookup: SchedulerLookup,
199    shm_manager: Option<Arc<ShmManager>>,
200    auth_token: Option<Arc<str>>,
201) -> std::io::Result<()>
202where
203    T: AsyncRead + AsyncWrite + Unpin,
204{
205    loop {
206        // Read header as raw bytes (not bincode)
207        let mut model_id_buf = [0u8; 4];
208        if connection.read_exact(&mut model_id_buf).await.is_err() {
209            return Ok(()); // Connection closed
210        }
211        let mut op_code_buf = [0u8; 4];
212        connection.read_exact(&mut op_code_buf).await?;
213        let mut payload_size_buf = [0u8; 4];
214        connection.read_exact(&mut payload_size_buf).await?;
215
216        let header = RequestHeader {
217            model_id: u32::from_le_bytes(model_id_buf),
218            op_code: u32::from_le_bytes(op_code_buf),
219            payload_size: u32::from_le_bytes(payload_size_buf),
220        };
221
222        // Read payload
223        let mut payload = vec![0u8; header.payload_size as usize];
224        connection.read_exact(&mut payload).await?;
225
226        match header.op_code {
227            OP_INFER_STREAM => {
228                // Deserialize request
229                let request: InferenceRequest = match decode_inference_request(&payload) {
230                    Ok(req) => req,
231                    Err(error_msg) => {
232                        let resp_header = ResponseHeader {
233                            status: STATUS_ERR,
234                            payload_size: error_msg.len() as u32,
235                        };
236                        connection
237                            .write_all(&resp_header.status.to_le_bytes())
238                            .await?;
239                        connection
240                            .write_all(&resp_header.payload_size.to_le_bytes())
241                            .await?;
242                        connection.write_all(error_msg.as_bytes()).await?;
243                        continue;
244                    }
245                };
246
247                if let Some(error_msg) = check_auth(&request, auth_token.as_deref()) {
248                    let resp_header = ResponseHeader {
249                        status: STATUS_ERR,
250                        payload_size: error_msg.len() as u32,
251                    };
252                    connection
253                        .write_all(&resp_header.status.to_le_bytes())
254                        .await?;
255                    connection
256                        .write_all(&resp_header.payload_size.to_le_bytes())
257                        .await?;
258                    connection.write_all(error_msg.as_bytes()).await?;
259                    continue;
260                }
261
262                // Get scheduler for model
263                let scheduler = match scheduler_lookup(header.model_id) {
264                    Some(s) => s,
265                    None => {
266                        let error_msg = format!("Model {} not found", header.model_id);
267                        let resp_header = ResponseHeader {
268                            status: STATUS_ERR,
269                            payload_size: error_msg.len() as u32,
270                        };
271                        connection
272                            .write_all(&resp_header.status.to_le_bytes())
273                            .await?;
274                        connection
275                            .write_all(&resp_header.payload_size.to_le_bytes())
276                            .await?;
277                        connection.write_all(error_msg.as_bytes()).await?;
278                        continue;
279                    }
280                };
281
282                // Execute streaming inference
283                let stream_result = scheduler
284                    .infer_stream(request, Priority::LatencyCritical, false)
285                    .await;
286
287                use futures::StreamExt;
288                match stream_result {
289                    Ok(mut inference_stream) => {
290                        while let Some(result) = inference_stream.next().await {
291                            match result {
292                                Ok(packet) => {
293                                    // Serialize packet
294                                    let response_bytes = match bincode::serialize(&packet) {
295                                        Ok(b) => b,
296                                        Err(e) => {
297                                            log::error!("Serialization error: {}", e);
298                                            break;
299                                        }
300                                    };
301
302                                    // Send chunk header
303                                    let response_header = ResponseHeader {
304                                        status: STATUS_STREAM_CHUNK,
305                                        payload_size: response_bytes.len() as u32,
306                                    };
307
308                                    connection
309                                        .write_all(&response_header.status.to_le_bytes())
310                                        .await?;
311                                    connection
312                                        .write_all(&response_header.payload_size.to_le_bytes())
313                                        .await?;
314                                    connection.write_all(&response_bytes).await?;
315                                    connection.flush().await?;
316                                }
317                                Err(e) => {
318                                    // Send error frame and stop
319                                    let error_msg = e.to_string();
320                                    let response_bytes = error_msg.as_bytes();
321                                    let response_header = ResponseHeader {
322                                        status: STATUS_ERR,
323                                        payload_size: response_bytes.len() as u32,
324                                    };
325                                    connection
326                                        .write_all(&response_header.status.to_le_bytes())
327                                        .await?;
328                                    connection
329                                        .write_all(&response_header.payload_size.to_le_bytes())
330                                        .await?;
331                                    connection.write_all(response_bytes).await?;
332                                    connection.flush().await?;
333                                    break;
334                                }
335                            }
336                        }
337
338                        // Send End of Stream frame
339                        let response_header = ResponseHeader {
340                            status: STATUS_STREAM_END,
341                            payload_size: 0,
342                        };
343                        connection
344                            .write_all(&response_header.status.to_le_bytes())
345                            .await?;
346                        connection
347                            .write_all(&response_header.payload_size.to_le_bytes())
348                            .await?;
349                        connection.flush().await?;
350                    }
351                    Err(e) => {
352                        // Send error frame for initial failure
353                        let error_msg = e.to_string();
354                        let response_bytes = error_msg.as_bytes();
355                        let response_header = ResponseHeader {
356                            status: STATUS_ERR,
357                            payload_size: response_bytes.len() as u32,
358                        };
359                        connection
360                            .write_all(&response_header.status.to_le_bytes())
361                            .await?;
362                        connection
363                            .write_all(&response_header.payload_size.to_le_bytes())
364                            .await?;
365                        connection.write_all(response_bytes).await?;
366                        connection.flush().await?;
367                    }
368                }
369            }
370            OP_INFER => {
371                // Find scheduler for model_id
372                if let Some(scheduler) = scheduler_lookup(header.model_id) {
373                    // Deserialize payload to InferenceRequest
374                    let request: InferenceRequest = match decode_inference_request(&payload) {
375                        Ok(req) => req,
376                        Err(error_msg) => {
377                            let resp_header = ResponseHeader {
378                                status: STATUS_ERR,
379                                payload_size: error_msg.len() as u32,
380                            };
381                            connection
382                                .write_all(&resp_header.status.to_le_bytes())
383                                .await?;
384                            connection
385                                .write_all(&resp_header.payload_size.to_le_bytes())
386                                .await?;
387                            connection.write_all(error_msg.as_bytes()).await?;
388                            continue;
389                        }
390                    };
391
392                    if let Some(error_msg) = check_auth(&request, auth_token.as_deref()) {
393                        let resp_header = ResponseHeader {
394                            status: STATUS_ERR,
395                            payload_size: error_msg.len() as u32,
396                        };
397                        connection
398                            .write_all(&resp_header.status.to_le_bytes())
399                            .await?;
400                        connection
401                            .write_all(&resp_header.payload_size.to_le_bytes())
402                            .await?;
403                        connection.write_all(error_msg.as_bytes()).await?;
404                        continue;
405                    }
406
407                    // Process
408                    // Default to Throughput priority and allow GPU (force_cpu = false)
409                    let result = scheduler.infer(&request, Priority::Throughput, false).await;
410
411                    match result {
412                        Ok(output) => {
413                            let output_bytes =
414                                bincode::serialize(&output).map_err(std::io::Error::other)?;
415
416                            let resp_header = ResponseHeader {
417                                status: STATUS_OK,
418                                payload_size: output_bytes.len() as u32,
419                            };
420
421                            // Write header as raw bytes (not bincode)
422                            connection
423                                .write_all(&resp_header.status.to_le_bytes())
424                                .await?;
425                            connection
426                                .write_all(&resp_header.payload_size.to_le_bytes())
427                                .await?;
428                            connection.write_all(&output_bytes).await?;
429                        }
430                        Err(e) => {
431                            let error_msg = e.to_string();
432                            let resp_header = ResponseHeader {
433                                status: STATUS_ERR,
434                                payload_size: error_msg.len() as u32,
435                            };
436                            connection
437                                .write_all(&resp_header.status.to_le_bytes())
438                                .await?;
439                            connection
440                                .write_all(&resp_header.payload_size.to_le_bytes())
441                                .await?;
442                            connection.write_all(error_msg.as_bytes()).await?;
443                        }
444                    }
445                } else {
446                    // Model not found
447                    let error_msg = format!("Model {} not found", header.model_id);
448                    let resp_header = ResponseHeader {
449                        status: STATUS_ERR,
450                        payload_size: error_msg.len() as u32,
451                    };
452                    connection
453                        .write_all(&resp_header.status.to_le_bytes())
454                        .await?;
455                    connection
456                        .write_all(&resp_header.payload_size.to_le_bytes())
457                        .await?;
458                    connection.write_all(error_msg.as_bytes()).await?;
459                }
460            }
461            OP_HYBRID_INFER => {
462                // Payload already read at line 131-132, just deserialize it
463                // Deserialize HybridRequest
464                let hybrid_req: HybridRequest = bincode::deserialize(&payload)
465                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
466
467                if let Some(shm_manager) = &shm_manager {
468                    let base_ptr = shm_manager.as_ptr();
469
470                    // Read TensorHeader from SHM
471                    let header_ptr = unsafe {
472                        base_ptr.add(hybrid_req.shm_offset as usize) as *const TensorHeader
473                    };
474                    let tensor_header = unsafe { &*header_ptr };
475
476                    // Read tensor data
477                    let data_ptr = unsafe {
478                        base_ptr.add(
479                            hybrid_req.shm_offset as usize + std::mem::size_of::<TensorHeader>(),
480                        )
481                    };
482                    let data_slice = unsafe {
483                        std::slice::from_raw_parts(data_ptr, tensor_header.data_size as usize)
484                    };
485
486                    // Build InferenceRequest
487                    let shape = tensor_header.shape[0..tensor_header.ndim as usize].to_vec();
488                    let dtype = match tensor_header.dtype {
489                        0 => TensorDtype::Float32,
490                        1 => TensorDtype::Float64,
491                        2 => TensorDtype::Int32,
492                        3 => TensorDtype::Int64,
493                        _ => TensorDtype::Float32,
494                    };
495
496                    let packet = BinaryTensorPacket {
497                        shape,
498                        dtype,
499                        data: data_slice.to_vec(),
500                    };
501
502                    let request = InferenceRequest {
503                        input: packet,
504                        additional_inputs: Vec::new(),
505                        session_id: None,
506                        metadata: None,
507                        cancellation: None,
508                    };
509
510                    // Perform inference
511                    let result =
512                        if let Some(scheduler) = scheduler_lookup(hybrid_req.metadata.model_id) {
513                            scheduler
514                                .infer(
515                                    &request,
516                                    Priority::Throughput,
517                                    hybrid_req.metadata.force_cpu,
518                                )
519                                .await
520                        } else {
521                            Err(kapsl_engine_api::EngineError::ModelNotLoaded)
522                        };
523
524                    match result {
525                        Ok(output) => {
526                            // Serialize output to BinaryTensorPacket
527                            let packet = BinaryTensorPacket {
528                                shape: output.shape,
529                                dtype: output.dtype,
530                                data: output.data,
531                            };
532
533                            // Calculate required size
534                            let output_size =
535                                std::mem::size_of::<TensorHeader>() + packet.data.len();
536
537                            // Allocate output slot with bounds checking
538                            // Use smaller slots (1MB) and more of them (400 slots from 512MB to 912MB)
539                            static SERVER_SLOT_COUNTER: std::sync::atomic::AtomicUsize =
540                                std::sync::atomic::AtomicUsize::new(0);
541                            let slot = SERVER_SLOT_COUNTER
542                                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
543                            let output_offset = 512 * 1024 * 1024 + (slot % 400) * 1_000_000; // 1MB slots, 400 slots
544
545                            // Bounds check
546                            let shm_size = shm_manager.size();
547                            if output_offset + output_size > shm_size {
548                                let error_msg = format!("Output would exceed SHM bounds: offset={}, size={}, shm_size={}",
549                                    output_offset, output_size, shm_size);
550                                let resp_header = ResponseHeader {
551                                    status: STATUS_ERR,
552                                    payload_size: error_msg.len() as u32,
553                                };
554                                connection
555                                    .write_all(&resp_header.status.to_le_bytes())
556                                    .await?;
557                                connection
558                                    .write_all(&resp_header.payload_size.to_le_bytes())
559                                    .await?;
560                                connection.write_all(error_msg.as_bytes()).await?;
561                                continue;
562                            }
563
564                            // Write result to SHM
565                            // Re-acquire base_ptr to avoid holding !Send raw pointer across await
566                            let base_ptr = shm_manager.as_ptr();
567                            unsafe {
568                                // Write header
569                                let out_header = TensorHeader {
570                                    ndim: packet.shape.len() as u32,
571                                    dtype: match packet.dtype {
572                                        TensorDtype::Float32 => 0,
573                                        TensorDtype::Float64 => 1,
574                                        TensorDtype::Int32 => 2,
575                                        TensorDtype::Int64 => 3,
576                                        _ => 0,
577                                    },
578                                    _padding: [0; 3],
579                                    shape: {
580                                        let mut arr = [0i64; 8];
581                                        for (i, &v) in packet.shape.iter().enumerate() {
582                                            arr[i] = v;
583                                        }
584                                        arr
585                                    },
586                                    data_size: packet.data.len() as u64,
587                                };
588
589                                let hdr_ptr = base_ptr.add(output_offset) as *mut TensorHeader;
590                                std::ptr::write(hdr_ptr, out_header);
591
592                                let data_ptr = base_ptr
593                                    .add(output_offset + std::mem::size_of::<TensorHeader>());
594                                std::ptr::copy_nonoverlapping(
595                                    packet.data.as_ptr(),
596                                    data_ptr,
597                                    packet.data.len(),
598                                );
599                            }
600
601                            // Send HybridResponse
602                            let resp = HybridResponse {
603                                metadata: ResponseMetadata {
604                                    request_id: hybrid_req.metadata.request_id,
605                                    status: STATUS_OK as u8,
606                                    _padding: [0; 7],
607                                    latency_ns: 0,
608                                },
609                                shm_offset: output_offset as u64,
610                                shm_size: (std::mem::size_of::<TensorHeader>() + packet.data.len())
611                                    as u64,
612                            };
613
614                            let resp_bytes =
615                                bincode::serialize(&resp).map_err(std::io::Error::other)?;
616
617                            let resp_header = ResponseHeader {
618                                status: STATUS_OK,
619                                payload_size: resp_bytes.len() as u32,
620                            };
621
622                            connection
623                                .write_all(&resp_header.status.to_le_bytes())
624                                .await?;
625                            connection
626                                .write_all(&resp_header.payload_size.to_le_bytes())
627                                .await?;
628                            connection.write_all(&resp_bytes).await?;
629                        }
630                        Err(e) => {
631                            let error_msg = e.to_string();
632                            let resp_header = ResponseHeader {
633                                status: STATUS_ERR,
634                                payload_size: error_msg.len() as u32,
635                            };
636                            connection
637                                .write_all(&resp_header.status.to_le_bytes())
638                                .await?;
639                            connection
640                                .write_all(&resp_header.payload_size.to_le_bytes())
641                                .await?;
642                            connection.write_all(error_msg.as_bytes()).await?;
643                        }
644                    }
645                } else {
646                    let error_msg = "SHM Manager not configured".to_string();
647                    let resp_header = ResponseHeader {
648                        status: STATUS_ERR,
649                        payload_size: error_msg.len() as u32,
650                    };
651                    connection
652                        .write_all(&resp_header.status.to_le_bytes())
653                        .await?;
654                    connection
655                        .write_all(&resp_header.payload_size.to_le_bytes())
656                        .await?;
657                    connection.write_all(error_msg.as_bytes()).await?;
658                }
659            }
660            _ => {
661                // Unsupported op
662                let resp_header = ResponseHeader {
663                    status: STATUS_ERR,
664                    payload_size: 0,
665                };
666                // Write header as raw bytes (not bincode)
667                connection
668                    .write_all(&resp_header.status.to_le_bytes())
669                    .await?;
670                connection
671                    .write_all(&resp_header.payload_size.to_le_bytes())
672                    .await?;
673            }
674        }
675    }
676}