1use crate::engine::lifecycle::{current_engine_pid_file, EnginePaths};
6use crate::engine::multiprocess::{pid_is_alive, read_pid_file};
7use crate::error::{Error, Result};
8use crate::ipc::endpoints::{management_url, request_url, response_url, EVENT_TOPIC_PREFIX};
9use crate::ipc::serialization::{build_batch_request_payload, PromptPayload, RequestType};
10
11use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
12use nng::options::Options;
13use nng::{Protocol, Socket};
14use serde::de::Error as DeError;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::path::{Path, PathBuf};
19use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
20use std::sync::{Arc, Mutex};
21use std::thread::{self, JoinHandle};
22use std::time::{Duration, Instant};
23use tokio::sync::mpsc;
24
/// Shared, thread-safe callback invoked for engine events.
///
/// It is called with the event name and the parsed JSON payload of the event.
pub type EventCallback = Arc<dyn Fn(&str, &Value) + Send + Sync>;
27
/// Minimum time between engine liveness checks made by the response listener
/// when its receive call times out.
const ENGINE_LIVENESS_POLL_INTERVAL: Duration = Duration::from_secs(10);
/// Receive timeout applied to the response socket, so the listener loop wakes
/// up regularly to re-check its stop flag.
const RESPONSE_RECV_TIMEOUT: Duration = Duration::from_millis(10);
/// Value passed to the response socket's `RecvBufferSize` option.
const RESPONSE_SOCKET_BUFFER_MESSAGES: i32 = 1024;
31
32#[derive(Debug, Clone, Default, Serialize, Deserialize)]
34#[serde(default)]
35pub struct TokenLogProb {
36 pub token: String,
38 pub logprob: f64,
40 pub bytes: Option<Vec<u8>>,
42}
43
44#[derive(Debug, Clone, Default, Serialize, Deserialize)]
46#[serde(default)]
47pub struct ResponseStateEvent {
48 pub event_type: String,
50 pub item_type: String,
52 pub output_index: u32,
54 pub identifier: String,
56 pub delta: String,
58 pub value: Option<Value>,
60}
61
62fn deserialize_optional_bytes<'de, D>(
63 deserializer: D,
64) -> std::result::Result<Option<Vec<u8>>, D::Error>
65where
66 D: serde::Deserializer<'de>,
67{
68 #[derive(Deserialize)]
69 #[serde(untagged)]
70 enum BytePayload {
71 Bytes(Vec<u8>),
72 Base64(String),
73 }
74
75 match Option::<BytePayload>::deserialize(deserializer)? {
76 None => Ok(None),
77 Some(BytePayload::Bytes(bytes)) => Ok(Some(bytes)),
78 Some(BytePayload::Base64(encoded)) => {
79 BASE64.decode(encoded).map(Some).map_err(D::Error::custom)
80 }
81 }
82}
83
84#[derive(Debug, Clone, Default, Serialize, Deserialize)]
88#[serde(default)]
89pub struct ResponseDelta {
90 pub request_id: u64,
92 pub sequence_id: Option<u64>,
94 pub prompt_index: Option<u32>,
96 pub candidate_index: Option<u32>,
98 pub content: Option<String>,
100 pub content_len: Option<u32>,
102 pub inline_content_bytes: Option<u32>,
104 pub is_final_delta: bool,
106 pub finish_reason: Option<String>,
108 #[serde(alias = "error_message")]
110 pub error: Option<String>,
111 pub prompt_token_count: Option<u32>,
113 pub num_tokens_in_delta: Option<u32>,
115 pub generation_len: Option<u32>,
117 pub tokens: Vec<i32>,
119 pub top_logprobs: Vec<TokenLogProb>,
121 pub cumulative_logprob: Option<f64>,
123 pub modal_decoder_id: Option<String>,
125 pub modal_bytes_b64: Option<String>,
127 #[serde(default, deserialize_with = "deserialize_optional_bytes")]
129 pub embedding_bytes: Option<Vec<u8>>,
130 pub state_events: Vec<ResponseStateEvent>,
132 pub cached_token_count: Option<u32>,
134 pub reasoning_tokens: Option<u32>,
136}
137
/// Client for the engine's IPC endpoints: a request socket, a response/event
/// subscription serviced by a background listener thread, and a management
/// request/reply socket.
pub struct IPCClient {
    /// Push socket used to send batch requests; `None` until `connect` succeeds.
    request_socket: Option<Socket>,
    /// Subscription socket for responses and events; handed over to the
    /// listener thread when it is started.
    response_socket: Option<Socket>,
    /// Request/reply socket for management commands, shared behind a mutex.
    management_socket: Arc<Mutex<Option<Socket>>>,
    /// Channel id embedded in the `resp:{id:x}:` topic this client subscribes
    /// to and included in every request payload.
    response_channel_id: u64,
    /// Source of request ids handed out by `next_request_id`.
    request_id_counter: AtomicU64,
    /// In-flight requests keyed by request id, shared with the listener thread.
    active_requests: Arc<Mutex<HashMap<u64, ActiveRequest>>>,
    /// Join handle of the listener thread, when one has been spawned.
    listener_handle: Option<JoinHandle<()>>,
    /// Flag telling the listener thread to exit its receive loop.
    should_stop: Arc<AtomicBool>,
    /// Optional callback invoked by the listener for engine events.
    event_callback: Option<EventCallback>,
}
154
/// Bookkeeping for one in-flight request.
struct ActiveRequest {
    /// Channel on which deltas for this request are delivered to the caller.
    sender: mpsc::UnboundedSender<ResponseDelta>,
    /// Number of final deltas still expected before the request is complete.
    remaining_finals: usize,
}
159
160impl IPCClient {
161 pub fn new() -> Self {
163 Self {
164 request_socket: None,
165 response_socket: None,
166 management_socket: Arc::new(Mutex::new(None)),
167 response_channel_id: rand_u64(),
168 request_id_counter: AtomicU64::new(0),
169 active_requests: Arc::new(Mutex::new(HashMap::new())),
170 listener_handle: None,
171 should_stop: Arc::new(AtomicBool::new(false)),
172 event_callback: None,
173 }
174 }
175
176 pub fn with_event_callback(callback: EventCallback) -> Self {
178 Self {
179 request_socket: None,
180 response_socket: None,
181 management_socket: Arc::new(Mutex::new(None)),
182 response_channel_id: rand_u64(),
183 request_id_counter: AtomicU64::new(0),
184 active_requests: Arc::new(Mutex::new(HashMap::new())),
185 listener_handle: None,
186 should_stop: Arc::new(AtomicBool::new(false)),
187 event_callback: Some(callback),
188 }
189 }
190
191 pub fn management_socket(&self) -> Arc<Mutex<Option<Socket>>> {
193 Arc::clone(&self.management_socket)
194 }
195
196 pub fn set_event_callback(&mut self, callback: EventCallback) {
198 self.event_callback = Some(callback);
199 }
200
201 pub fn connect(&mut self) -> Result<()> {
203 let engine_pid_file = current_engine_pid_file()
204 .or_else(|| EnginePaths::new().ok().map(|paths| paths.pid_file))
205 .ok_or_else(|| Error::Internal("Cannot determine engine PID file path".into()))?;
206
207 let request_socket = Socket::new(Protocol::Push0)?;
209 request_socket.dial(&request_url())?;
210 self.request_socket = Some(request_socket);
211
212 let response_socket = Socket::new(Protocol::Sub0)?;
214 response_socket.set_opt::<nng::options::RecvBufferSize>(RESPONSE_SOCKET_BUFFER_MESSAGES)?;
215
216 let response_topic = format!("resp:{:x}:", self.response_channel_id);
218 response_socket.set_opt::<nng::options::protocol::pubsub::Subscribe>(
219 response_topic.as_bytes().to_vec(),
220 )?;
221
222 response_socket
224 .set_opt::<nng::options::protocol::pubsub::Subscribe>(EVENT_TOPIC_PREFIX.to_vec())?;
225
226 response_socket.dial(&response_url())?;
227 self.response_socket = Some(response_socket);
228
229 let management_socket = Socket::new(Protocol::Req0)?;
231 management_socket.dial(&management_url())?;
232 {
233 let mut mgmt = self
234 .management_socket
235 .lock()
236 .unwrap_or_else(|e| e.into_inner());
237 *mgmt = Some(management_socket);
238 }
239
240 self.should_stop.store(false, Ordering::SeqCst);
242 self.start_listener(engine_pid_file);
243
244 Ok(())
245 }
246
247 pub fn disconnect(&mut self) {
251 self.should_stop.store(true, Ordering::SeqCst);
252
253 {
255 let requests = self
256 .active_requests
257 .lock()
258 .unwrap_or_else(|e| e.into_inner());
259
260 for (request_id, entry) in requests.iter() {
261 let error_delta = ResponseDelta {
262 request_id: *request_id,
263 is_final_delta: true,
264 finish_reason: Some("error".to_string()),
265 content: Some("Engine process disconnected.".to_string()),
266 error: Some("Engine process disconnected.".to_string()),
267 ..Default::default()
268 };
269 let _ = entry.sender.send(error_delta);
270 }
271 }
272
273 if let Some(handle) = self.listener_handle.take() {
274 let _ = handle.join();
275 }
276
277 self.request_socket = None;
278 self.response_socket = None;
279 {
280 let mut mgmt = self
281 .management_socket
282 .lock()
283 .unwrap_or_else(|e| e.into_inner());
284 *mgmt = None;
285 }
286
287 if let Ok(mut requests) = self.active_requests.lock() {
288 requests.clear();
289 }
290 }
291
292 pub fn next_request_id(&self) -> u64 {
294 let id = self.request_id_counter.fetch_add(1, Ordering::SeqCst);
295 if id >= u64::MAX - 1 {
296 self.request_id_counter.store(1, Ordering::SeqCst);
297 }
298 id + 1
299 }
300
301 pub fn send_batch_request(
303 &self,
304 request_id: u64,
305 model_id: &str,
306 model_path: &str,
307 prompts: &[PromptPayload],
308 ) -> Result<(usize, mpsc::UnboundedReceiver<ResponseDelta>)> {
309 self.send_batch_request_with_type(
310 request_id,
311 model_id,
312 model_path,
313 RequestType::Generation,
314 prompts,
315 )
316 }
317
318 pub(crate) fn send_batch_request_with_type(
319 &self,
320 request_id: u64,
321 model_id: &str,
322 model_path: &str,
323 request_type: RequestType,
324 prompts: &[PromptPayload],
325 ) -> Result<(usize, mpsc::UnboundedReceiver<ResponseDelta>)> {
326 let socket = self.request_socket.as_ref().ok_or(Error::NotConnected)?;
327 tracing::debug!(
328 request_id,
329 model_id = %model_id,
330 ?request_type,
331 prompt_count = prompts.len(),
332 "Serializing and sending IPC batch request"
333 );
334
335 let payload = build_batch_request_payload(
336 request_id,
337 model_id,
338 model_path,
339 request_type,
340 self.response_channel_id,
341 prompts,
342 )?;
343 tracing::debug!(
344 request_id,
345 model_id = %model_id,
346 ?request_type,
347 payload_bytes = payload.len(),
348 "Built IPC batch payload"
349 );
350
351 let (tx, rx) = mpsc::unbounded_channel();
352 let remaining_finals = prompts
353 .iter()
354 .map(|prompt| {
355 let num_candidates = prompt.num_candidates.max(1);
356 let best_of = prompt.best_of.unwrap_or(num_candidates).max(1);
357 let final_candidates = prompt.final_candidates.unwrap_or(best_of).max(1);
358 final_candidates as usize
359 })
360 .sum::<usize>()
361 .max(1);
362
363 self.active_requests
364 .lock()
365 .unwrap_or_else(|e| e.into_inner())
366 .insert(
367 request_id,
368 ActiveRequest {
369 sender: tx,
370 remaining_finals,
371 },
372 );
373
374 let msg = nng::Message::from(payload.as_slice());
375 socket.send(msg).map_err(|(_, e)| Error::Nng(e))?;
376 tracing::debug!(
377 request_id,
378 model_id = %model_id,
379 ?request_type,
380 expected_final_count = remaining_finals,
381 "IPC batch request sent"
382 );
383
384 Ok((prompts.len(), rx))
385 }
386
387 pub async fn send_management_command_async(
391 &self,
392 command: Value,
393 timeout: Duration,
394 ) -> Result<Value> {
395 let socket_arc = Arc::clone(&self.management_socket);
396
397 tokio::task::spawn_blocking(move || {
398 let guard = socket_arc.lock().unwrap_or_else(|e| e.into_inner());
399 let socket = guard.as_ref().ok_or(Error::NotConnected)?;
400
401 socket.set_opt::<nng::options::RecvTimeout>(Some(timeout))?;
403
404 let data = serde_json::to_vec(&command)?;
406 let msg = nng::Message::from(data.as_slice());
407 socket.send(msg).map_err(|(_, e)| Error::Nng(e))?;
408
409 let response = socket.recv()?;
411 let json: Value = serde_json::from_slice(&response)?;
412
413 Ok(json)
414 })
415 .await
416 .map_err(|e| Error::Internal(format!("Task join error: {}", e)))?
417 }
418
419 pub fn send_management_command(&self, command: &Value, timeout: Duration) -> Result<Value> {
423 let guard = self
424 .management_socket
425 .lock()
426 .unwrap_or_else(|e| e.into_inner());
427 let socket = guard.as_ref().ok_or(Error::NotConnected)?;
428
429 socket.set_opt::<nng::options::RecvTimeout>(Some(timeout))?;
431
432 let data = serde_json::to_vec(command)?;
434 let msg = nng::Message::from(data.as_slice());
435 socket.send(msg).map_err(|(_, e)| Error::Nng(e))?;
436
437 let response = socket.recv()?;
439 let json: Value = serde_json::from_slice(&response)?;
440
441 Ok(json)
442 }
443
444 fn start_listener(&mut self, engine_pid_file: PathBuf) {
446 let response_socket = self.response_socket.take();
447 let active_requests = Arc::clone(&self.active_requests);
448 let should_stop = Arc::clone(&self.should_stop);
449 let response_channel_id = self.response_channel_id;
450 let event_callback = self.event_callback.clone();
451
452 let handle = thread::Builder::new()
453 .name("orchard-ipc-listener".to_string())
454 .spawn(move || {
455 if let Some(socket) = response_socket {
456 run_response_listener(
457 socket,
458 active_requests,
459 should_stop,
460 response_channel_id,
461 engine_pid_file,
462 event_callback,
463 );
464 }
465 });
466
467 match handle {
468 Ok(h) => self.listener_handle = Some(h),
469 Err(e) => tracing::error!("Failed to spawn IPC listener thread: {}", e),
470 }
471 }
472}
473
474impl Default for IPCClient {
475 fn default() -> Self {
476 Self::new()
477 }
478}
479
480impl Drop for IPCClient {
481 fn drop(&mut self) {
482 self.disconnect();
483 }
484}
485
486fn engine_process_is_alive(engine_pid_file: &Path) -> bool {
487 read_pid_file(engine_pid_file)
488 .map(pid_is_alive)
489 .unwrap_or(false)
490}
491
492fn run_response_listener(
494 socket: Socket,
495 active_requests: Arc<Mutex<HashMap<u64, ActiveRequest>>>,
496 should_stop: Arc<AtomicBool>,
497 response_channel_id: u64,
498 engine_pid_file: PathBuf,
499 event_callback: Option<EventCallback>,
500) {
501 let response_topic = format!("resp:{:x}:", response_channel_id);
502 let response_topic_bytes = response_topic.as_bytes();
503
504 let _ = socket.set_opt::<nng::options::RecvTimeout>(Some(RESPONSE_RECV_TIMEOUT));
506 let mut last_engine_check = Instant::now();
507
508 while !should_stop.load(Ordering::SeqCst) {
509 match socket.recv() {
510 Ok(msg) => {
511 let data = msg.as_slice();
512
513 if data.starts_with(response_topic_bytes) {
515 let json_data = &data[response_topic_bytes.len()..];
516
517 if let Ok(delta) = serde_json::from_slice::<ResponseDelta>(json_data) {
518 let request_id = delta.request_id;
519 let is_final = delta.is_final_delta;
520
521 let sender = {
522 let mut requests =
523 active_requests.lock().unwrap_or_else(|e| e.into_inner());
524 if let Some(entry) = requests.get_mut(&request_id) {
525 if is_final {
526 entry.remaining_finals =
527 entry.remaining_finals.saturating_sub(1);
528 if entry.remaining_finals == 0 {
529 let sender = entry.sender.clone();
530 requests.remove(&request_id);
531 Some(sender)
532 } else {
533 Some(entry.sender.clone())
534 }
535 } else {
536 Some(entry.sender.clone())
537 }
538 } else {
539 None
540 }
541 };
542
543 if let Some(tx) = sender {
544 let _ = tx.send(delta);
545 }
546 } else {
547 tracing::warn!(
548 response_channel_id,
549 payload_bytes = json_data.len(),
550 "Failed to deserialize IPC response payload"
551 );
552 }
553 }
554 else if data.starts_with(EVENT_TOPIC_PREFIX) {
556 handle_engine_event(data, &event_callback);
557 }
558 }
559 Err(nng::Error::TimedOut) => {
560 if last_engine_check.elapsed() >= ENGINE_LIVENESS_POLL_INTERVAL {
561 last_engine_check = Instant::now();
562 if !engine_process_is_alive(&engine_pid_file) {
563 tracing::error!(
564 pid_file = %engine_pid_file.display(),
565 "PIE is no longer alive; shutting down IPC listener"
566 );
567 should_stop.store(true, Ordering::SeqCst);
568 break;
569 }
570 }
571 continue;
572 }
573 Err(error) => {
574 if should_stop.load(Ordering::SeqCst) {
575 break;
576 }
577 if !engine_process_is_alive(&engine_pid_file) {
578 tracing::error!(
579 pid_file = %engine_pid_file.display(),
580 error = %error,
581 "PIE is no longer alive; shutting down IPC listener"
582 );
583 should_stop.store(true, Ordering::SeqCst);
584 break;
585 }
586 }
587 }
588 }
589
590 tracing::info!("IPC listener shutting down");
592 let requests = active_requests.lock().unwrap_or_else(|e| e.into_inner());
593
594 if !requests.is_empty() {
595 tracing::warn!(
596 "IPC listener exiting with {} active requests; failing them.",
597 requests.len()
598 );
599
600 for (request_id, entry) in requests.iter() {
601 let error_delta = ResponseDelta {
602 request_id: *request_id,
603 is_final_delta: true,
604 finish_reason: Some("error".to_string()),
605 content: Some("Engine process disconnected.".to_string()),
606 error: Some("Engine process disconnected.".to_string()),
607 ..Default::default()
608 };
609 let _ = entry.sender.send(error_delta);
610 }
611 }
612}
613
614fn handle_engine_event(data: &[u8], event_callback: &Option<EventCallback>) {
616 let parts: Vec<&[u8]> = data.splitn(2, |&b| b == 0).collect();
618 if parts.len() != 2 {
619 tracing::warn!("Received malformed event message");
620 return;
621 }
622
623 let (topic_part, json_body) = (parts[0], parts[1]);
624
625 let event_name = if topic_part.len() > EVENT_TOPIC_PREFIX.len() {
627 String::from_utf8_lossy(&topic_part[EVENT_TOPIC_PREFIX.len()..]).to_string()
628 } else {
629 tracing::warn!("Event message has empty event name");
630 return;
631 };
632
633 let payload: Value = match serde_json::from_slice(json_body) {
635 Ok(v) => v,
636 Err(e) => {
637 tracing::error!("Failed to parse engine event payload: {}", e);
638 return;
639 }
640 };
641
642 if event_name != "telemetry" {
643 tracing::debug!("Received engine event: {}", event_name);
644 }
645
646 if let Some(callback) = event_callback {
648 callback(&event_name, &payload);
649 }
650}
651
652fn rand_u64() -> u64 {
655 use rand::Rng;
656
657 let pid = std::process::id() as u64 & 0xFFFFFFFF;
658 let random: u32 = rand::thread_rng().gen();
659
660 let channel_id = (pid << 32) | (random as u64);
661 if channel_id == 0 {
662 1
663 } else {
664 channel_id
665 }
666}
667
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A new client has no request socket and a non-zero channel id.
    #[test]
    fn test_client_creation() {
        let client = IPCClient::new();
        assert!(client.request_socket.is_none());
        assert!(client.response_channel_id > 0);
    }

    /// Consecutive request ids differ by exactly one.
    #[test]
    fn test_request_id_increment() {
        let client = IPCClient::new();
        let id1 = client.next_request_id();
        let id2 = client.next_request_id();
        assert_eq!(id2, id1 + 1);
    }

    /// The default delta is empty and not final.
    #[test]
    fn test_response_delta_default() {
        let delta = ResponseDelta::default();
        assert_eq!(delta.request_id, 0);
        assert!(!delta.is_final_delta);
        assert!(delta.tokens.is_empty());
        assert!(delta.top_logprobs.is_empty());
        assert!(delta.embedding_bytes.is_none());
        assert!(delta.state_events.is_empty());
    }

    /// A fully populated payload deserializes field by field, including
    /// `embedding_bytes` given as a byte array.
    #[test]
    fn test_response_delta_deserialize() {
        let json = serde_json::json!({
            "request_id": 123,
            "sequence_id": 1,
            "prompt_index": 0,
            "candidate_index": 0,
            "content": "Hello",
            "content_len": 5,
            "inline_content_bytes": 5,
            "is_final_delta": false,
            "num_tokens_in_delta": 3,
            "tokens": [1, 2, 3],
            "top_logprobs": [{"token": "hello", "logprob": -0.5}, {"token": "world", "logprob": -1.0}],
            "cumulative_logprob": -1.5,
            "modal_decoder_id": "moondream3.coord",
            "modal_bytes_b64": "AAAA",
            "embedding_bytes": [0, 0, 128, 63]
        });
        let delta: ResponseDelta = serde_json::from_value(json).expect("deserialize failed");
        assert_eq!(delta.request_id, 123);
        assert_eq!(delta.sequence_id, Some(1));
        assert_eq!(delta.candidate_index, Some(0));
        assert_eq!(delta.content_len, Some(5));
        assert_eq!(delta.num_tokens_in_delta, Some(3));
        assert_eq!(delta.tokens, vec![1, 2, 3]);
        assert_eq!(delta.top_logprobs.len(), 2);
        assert_eq!(delta.cumulative_logprob, Some(-1.5));
        assert_eq!(delta.modal_decoder_id, Some("moondream3.coord".to_string()));
        assert_eq!(delta.embedding_bytes, Some(vec![0, 0, 128, 63]));
        assert!(delta.state_events.is_empty());
    }

    /// Fields absent from the payload fall back to their defaults.
    #[test]
    fn test_response_delta_deserialize_with_defaults() {
        let json = serde_json::json!({
            // Only the two fields below are supplied.
            "request_id": 42,
            "is_final_delta": true
        });
        let delta: ResponseDelta = serde_json::from_value(json).expect("deserialize failed");
        assert_eq!(delta.request_id, 42);
        assert!(delta.is_final_delta);
        assert!(delta.tokens.is_empty());
        assert!(delta.content.is_none());
        assert!(delta.state_events.is_empty());
    }

    /// `embedding_bytes` is also accepted as a base64 string.
    #[test]
    fn test_response_delta_deserialize_embedding_bytes_from_base64() {
        let json = serde_json::json!({
            "request_id": 42,
            "is_final_delta": true,
            "embedding_bytes": "AAAAAA==",
        });

        let delta: ResponseDelta = serde_json::from_value(json).expect("deserialize failed");
        assert_eq!(delta.embedding_bytes, Some(vec![0, 0, 0, 0]));
    }

    /// The `error_message` key is accepted as an alias for `error`.
    #[test]
    fn test_response_delta_deserialize_error_message_alias() {
        let json = serde_json::json!({
            "request_id": 42,
            "is_final_delta": true,
            "error_message": "boom",
        });

        let delta: ResponseDelta = serde_json::from_value(json).expect("deserialize failed");
        assert_eq!(delta.error.as_deref(), Some("boom"));
    }

    /// A PID file naming the current (running) process reports alive.
    #[test]
    fn test_engine_process_is_alive_reads_pid_file() {
        let dir = tempdir().expect("tempdir should be available");
        let pid_file = dir.path().join("engine.pid");
        std::fs::write(&pid_file, format!("{}\n", std::process::id()))
            .expect("pid file should be written");

        assert!(engine_process_is_alive(&pid_file));
    }

    /// A missing PID file reports not alive.
    #[test]
    fn test_engine_process_is_alive_handles_missing_pid_file() {
        let dir = tempdir().expect("tempdir should be available");
        let pid_file = dir.path().join("missing.pid");

        assert!(!engine_process_is_alive(&pid_file));
    }

    /// State events nested in a delta deserialize in order.
    #[test]
    fn test_response_delta_deserialize_with_state_events() {
        let json = serde_json::json!({
            "request_id": 7,
            "is_final_delta": false,
            "state_events": [
                {
                    "event_type": "item_started",
                    "item_type": "message",
                    "output_index": 0,
                    "identifier": "",
                    "delta": ""
                },
                {
                    "event_type": "content_delta",
                    "item_type": "message",
                    "output_index": 0,
                    "identifier": "",
                    "delta": "hello"
                }
            ]
        });

        let delta: ResponseDelta = serde_json::from_value(json).expect("deserialize failed");
        assert_eq!(delta.request_id, 7);
        assert_eq!(delta.state_events.len(), 2);
        assert_eq!(delta.state_events[0].event_type, "item_started");
        assert_eq!(delta.state_events[1].delta, "hello");
    }
}