Skip to main content

orchard/ipc/
client.rs

1//! High-performance IPC client for communicating with PIE.
2//!
3//! Uses NNG sockets with a dedicated listener thread for response handling.
4
5use crate::engine::lifecycle::{current_engine_pid_file, EnginePaths};
6use crate::engine::multiprocess::{pid_is_alive, read_pid_file};
7use crate::error::{Error, Result};
8use crate::ipc::endpoints::{management_url, request_url, response_url, EVENT_TOPIC_PREFIX};
9use crate::ipc::serialization::{build_batch_request_payload, PromptPayload, RequestType};
10
11use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
12use nng::options::Options;
13use nng::{Protocol, Socket};
14use serde::de::Error as DeError;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::path::{Path, PathBuf};
19use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
20use std::sync::{Arc, Mutex};
21use std::thread::{self, JoinHandle};
22use std::time::{Duration, Instant};
23use tokio::sync::mpsc;
24
25/// Callback type for engine events (telemetry, model_loaded, etc.)
26pub type EventCallback = Arc<dyn Fn(&str, &Value) + Send + Sync>;
27
28const ENGINE_LIVENESS_POLL_INTERVAL: Duration = Duration::from_secs(10);
29const RESPONSE_RECV_TIMEOUT: Duration = Duration::from_millis(10);
30const RESPONSE_SOCKET_BUFFER_MESSAGES: i32 = 1024;
31
32/// A single token's log probability info from PIE.
33#[derive(Debug, Clone, Default, Serialize, Deserialize)]
34#[serde(default)]
35pub struct TokenLogProb {
36    /// The token string (or token ID as string)
37    pub token: String,
38    /// The log probability value
39    pub logprob: f64,
40    /// Optional bytes representation
41    pub bytes: Option<Vec<u8>>,
42}
43
44/// A single state transition event emitted by PIE/PSE for structured outputs.
45#[derive(Debug, Clone, Default, Serialize, Deserialize)]
46#[serde(default)]
47pub struct ResponseStateEvent {
48    /// Event type (e.g., item_started, content_delta, item_completed)
49    pub event_type: String,
50    /// Item type (e.g., message, tool_call, reasoning)
51    pub item_type: String,
52    /// Output index for this item in the response output array
53    pub output_index: u32,
54    /// Identifier for sub-items or tool names
55    pub identifier: String,
56    /// Delta text for streaming content updates
57    pub delta: String,
58    /// Optional final value for completion events
59    pub value: Option<Value>,
60}
61
62fn deserialize_optional_bytes<'de, D>(
63    deserializer: D,
64) -> std::result::Result<Option<Vec<u8>>, D::Error>
65where
66    D: serde::Deserializer<'de>,
67{
68    #[derive(Deserialize)]
69    #[serde(untagged)]
70    enum BytePayload {
71        Bytes(Vec<u8>),
72        Base64(String),
73    }
74
75    match Option::<BytePayload>::deserialize(deserializer)? {
76        None => Ok(None),
77        Some(BytePayload::Bytes(bytes)) => Ok(Some(bytes)),
78        Some(BytePayload::Base64(encoded)) => {
79            BASE64.decode(encoded).map(Some).map_err(D::Error::custom)
80        }
81    }
82}
83
84/// Response delta from PIE.
85///
86/// Uses serde for deserialization with sensible defaults for missing fields.
87#[derive(Debug, Clone, Default, Serialize, Deserialize)]
88#[serde(default)]
89pub struct ResponseDelta {
90    /// Request ID this delta belongs to
91    pub request_id: u64,
92    /// Sequence ID for ordering
93    pub sequence_id: Option<u64>,
94    /// Prompt index for batched requests (identifies which prompt in the batch)
95    pub prompt_index: Option<u32>,
96    /// Candidate index (for multi-candidate generation)
97    pub candidate_index: Option<u32>,
98    /// Generated content (token text)
99    pub content: Option<String>,
100    /// Content length in characters
101    pub content_len: Option<u32>,
102    /// Inline content bytes
103    pub inline_content_bytes: Option<u32>,
104    /// Whether this is the final delta
105    pub is_final_delta: bool,
106    /// Finish reason (e.g., "stop", "length")
107    pub finish_reason: Option<String>,
108    /// Error message if request failed
109    #[serde(alias = "error_message")]
110    pub error: Option<String>,
111    /// Prompt token count
112    pub prompt_token_count: Option<u32>,
113    /// Number of tokens in this delta
114    pub num_tokens_in_delta: Option<u32>,
115    /// Generation length so far
116    pub generation_len: Option<u32>,
117    /// Token IDs in this delta
118    pub tokens: Vec<i32>,
119    /// Top log probabilities for each token position
120    pub top_logprobs: Vec<TokenLogProb>,
121    /// Cumulative log probability
122    pub cumulative_logprob: Option<f64>,
123    /// Modal decoder identifier (e.g., "moondream3.coord")
124    pub modal_decoder_id: Option<String>,
125    /// Base64-encoded modal decoder output bytes
126    pub modal_bytes_b64: Option<String>,
127    /// Raw embedding bytes from PIE, when request_type is embedding.
128    #[serde(default, deserialize_with = "deserialize_optional_bytes")]
129    pub embedding_bytes: Option<Vec<u8>>,
130    /// Structured state transition events used by Responses API.
131    pub state_events: Vec<ResponseStateEvent>,
132    /// Cached token count (input token cache hits).
133    pub cached_token_count: Option<u32>,
134    /// Reasoning token count, when available.
135    pub reasoning_tokens: Option<u32>,
136}
137
138/// High-performance IPC client for communicating with PIE.
139///
140/// Uses a lock-based design instead of actors to minimize overhead in the hot path.
141/// All socket operations are thread-safe via internal locks.
142pub struct IPCClient {
143    request_socket: Option<Socket>,
144    response_socket: Option<Socket>,
145    /// Management socket wrapped in Arc<Mutex> for async access via spawn_blocking
146    management_socket: Arc<Mutex<Option<Socket>>>,
147    response_channel_id: u64,
148    request_id_counter: AtomicU64,
149    active_requests: Arc<Mutex<HashMap<u64, ActiveRequest>>>,
150    listener_handle: Option<JoinHandle<()>>,
151    should_stop: Arc<AtomicBool>,
152    event_callback: Option<EventCallback>,
153}
154
155struct ActiveRequest {
156    sender: mpsc::UnboundedSender<ResponseDelta>,
157    remaining_finals: usize,
158}
159
160impl IPCClient {
161    /// Create a new IPC client (not connected).
162    pub fn new() -> Self {
163        Self {
164            request_socket: None,
165            response_socket: None,
166            management_socket: Arc::new(Mutex::new(None)),
167            response_channel_id: rand_u64(),
168            request_id_counter: AtomicU64::new(0),
169            active_requests: Arc::new(Mutex::new(HashMap::new())),
170            listener_handle: None,
171            should_stop: Arc::new(AtomicBool::new(false)),
172            event_callback: None,
173        }
174    }
175
176    /// Create a new IPC client with an event callback.
177    pub fn with_event_callback(callback: EventCallback) -> Self {
178        Self {
179            request_socket: None,
180            response_socket: None,
181            management_socket: Arc::new(Mutex::new(None)),
182            response_channel_id: rand_u64(),
183            request_id_counter: AtomicU64::new(0),
184            active_requests: Arc::new(Mutex::new(HashMap::new())),
185            listener_handle: None,
186            should_stop: Arc::new(AtomicBool::new(false)),
187            event_callback: Some(callback),
188        }
189    }
190
191    /// Get a clone of the management socket Arc for async operations.
192    pub fn management_socket(&self) -> Arc<Mutex<Option<Socket>>> {
193        Arc::clone(&self.management_socket)
194    }
195
196    /// Set the event callback for handling engine events.
197    pub fn set_event_callback(&mut self, callback: EventCallback) {
198        self.event_callback = Some(callback);
199    }
200
201    /// Connect to PIE IPC endpoints.
202    pub fn connect(&mut self) -> Result<()> {
203        let engine_pid_file = current_engine_pid_file()
204            .or_else(|| EnginePaths::new().ok().map(|paths| paths.pid_file))
205            .ok_or_else(|| Error::Internal("Cannot determine engine PID file path".into()))?;
206
207        // Create and connect request socket (PUSH)
208        let request_socket = Socket::new(Protocol::Push0)?;
209        request_socket.dial(&request_url())?;
210        self.request_socket = Some(request_socket);
211
212        // Create response socket (SUB) - subscribe BEFORE dial
213        let response_socket = Socket::new(Protocol::Sub0)?;
214        response_socket.set_opt::<nng::options::RecvBufferSize>(RESPONSE_SOCKET_BUFFER_MESSAGES)?;
215
216        // Subscribe to our response topic
217        let response_topic = format!("resp:{:x}:", self.response_channel_id);
218        response_socket.set_opt::<nng::options::protocol::pubsub::Subscribe>(
219            response_topic.as_bytes().to_vec(),
220        )?;
221
222        // Subscribe to global events
223        response_socket
224            .set_opt::<nng::options::protocol::pubsub::Subscribe>(EVENT_TOPIC_PREFIX.to_vec())?;
225
226        response_socket.dial(&response_url())?;
227        self.response_socket = Some(response_socket);
228
229        // Create management socket (REQ)
230        let management_socket = Socket::new(Protocol::Req0)?;
231        management_socket.dial(&management_url())?;
232        {
233            let mut mgmt = self
234                .management_socket
235                .lock()
236                .unwrap_or_else(|e| e.into_inner());
237            *mgmt = Some(management_socket);
238        }
239
240        // Start listener thread
241        self.should_stop.store(false, Ordering::SeqCst);
242        self.start_listener(engine_pid_file);
243
244        Ok(())
245    }
246
247    /// Disconnect from PIE with graceful shutdown.
248    ///
249    /// Sends error deltas to all pending requests before closing.
250    pub fn disconnect(&mut self) {
251        self.should_stop.store(true, Ordering::SeqCst);
252
253        // Send error deltas to all pending requests (graceful shutdown)
254        {
255            let requests = self
256                .active_requests
257                .lock()
258                .unwrap_or_else(|e| e.into_inner());
259
260            for (request_id, entry) in requests.iter() {
261                let error_delta = ResponseDelta {
262                    request_id: *request_id,
263                    is_final_delta: true,
264                    finish_reason: Some("error".to_string()),
265                    content: Some("Engine process disconnected.".to_string()),
266                    error: Some("Engine process disconnected.".to_string()),
267                    ..Default::default()
268                };
269                let _ = entry.sender.send(error_delta);
270            }
271        }
272
273        if let Some(handle) = self.listener_handle.take() {
274            let _ = handle.join();
275        }
276
277        self.request_socket = None;
278        self.response_socket = None;
279        {
280            let mut mgmt = self
281                .management_socket
282                .lock()
283                .unwrap_or_else(|e| e.into_inner());
284            *mgmt = None;
285        }
286
287        if let Ok(mut requests) = self.active_requests.lock() {
288            requests.clear();
289        }
290    }
291
292    /// Get the next request ID.
293    pub fn next_request_id(&self) -> u64 {
294        let id = self.request_id_counter.fetch_add(1, Ordering::SeqCst);
295        if id >= u64::MAX - 1 {
296            self.request_id_counter.store(1, Ordering::SeqCst);
297        }
298        id + 1
299    }
300
301    /// Send a batched request with multiple prompts in ONE IPC message.
302    pub fn send_batch_request(
303        &self,
304        request_id: u64,
305        model_id: &str,
306        model_path: &str,
307        prompts: &[PromptPayload],
308    ) -> Result<(usize, mpsc::UnboundedReceiver<ResponseDelta>)> {
309        self.send_batch_request_with_type(
310            request_id,
311            model_id,
312            model_path,
313            RequestType::Generation,
314            prompts,
315        )
316    }
317
318    pub(crate) fn send_batch_request_with_type(
319        &self,
320        request_id: u64,
321        model_id: &str,
322        model_path: &str,
323        request_type: RequestType,
324        prompts: &[PromptPayload],
325    ) -> Result<(usize, mpsc::UnboundedReceiver<ResponseDelta>)> {
326        let socket = self.request_socket.as_ref().ok_or(Error::NotConnected)?;
327        tracing::debug!(
328            request_id,
329            model_id = %model_id,
330            ?request_type,
331            prompt_count = prompts.len(),
332            "Serializing and sending IPC batch request"
333        );
334
335        let payload = build_batch_request_payload(
336            request_id,
337            model_id,
338            model_path,
339            request_type,
340            self.response_channel_id,
341            prompts,
342        )?;
343        tracing::debug!(
344            request_id,
345            model_id = %model_id,
346            ?request_type,
347            payload_bytes = payload.len(),
348            "Built IPC batch payload"
349        );
350
351        let (tx, rx) = mpsc::unbounded_channel();
352        let remaining_finals = prompts
353            .iter()
354            .map(|prompt| {
355                let num_candidates = prompt.num_candidates.max(1);
356                let best_of = prompt.best_of.unwrap_or(num_candidates).max(1);
357                let final_candidates = prompt.final_candidates.unwrap_or(best_of).max(1);
358                final_candidates as usize
359            })
360            .sum::<usize>()
361            .max(1);
362
363        self.active_requests
364            .lock()
365            .unwrap_or_else(|e| e.into_inner())
366            .insert(
367                request_id,
368                ActiveRequest {
369                    sender: tx,
370                    remaining_finals,
371                },
372            );
373
374        let msg = nng::Message::from(payload.as_slice());
375        socket.send(msg).map_err(|(_, e)| Error::Nng(e))?;
376        tracing::debug!(
377            request_id,
378            model_id = %model_id,
379            ?request_type,
380            expected_final_count = remaining_finals,
381            "IPC batch request sent"
382        );
383
384        Ok((prompts.len(), rx))
385    }
386
387    /// Send a management command asynchronously.
388    ///
389    /// Uses spawn_blocking internally since NNG sockets are sync.
390    pub async fn send_management_command_async(
391        &self,
392        command: Value,
393        timeout: Duration,
394    ) -> Result<Value> {
395        let socket_arc = Arc::clone(&self.management_socket);
396
397        tokio::task::spawn_blocking(move || {
398            let guard = socket_arc.lock().unwrap_or_else(|e| e.into_inner());
399            let socket = guard.as_ref().ok_or(Error::NotConnected)?;
400
401            // Set timeout
402            socket.set_opt::<nng::options::RecvTimeout>(Some(timeout))?;
403
404            // Send command
405            let data = serde_json::to_vec(&command)?;
406            let msg = nng::Message::from(data.as_slice());
407            socket.send(msg).map_err(|(_, e)| Error::Nng(e))?;
408
409            // Receive response
410            let response = socket.recv()?;
411            let json: Value = serde_json::from_slice(&response)?;
412
413            Ok(json)
414        })
415        .await
416        .map_err(|e| Error::Internal(format!("Task join error: {}", e)))?
417    }
418
419    /// Send a management command synchronously (blocking).
420    ///
421    /// Prefer `send_management_command_async` in async contexts.
422    pub fn send_management_command(&self, command: &Value, timeout: Duration) -> Result<Value> {
423        let guard = self
424            .management_socket
425            .lock()
426            .unwrap_or_else(|e| e.into_inner());
427        let socket = guard.as_ref().ok_or(Error::NotConnected)?;
428
429        // Set timeout
430        socket.set_opt::<nng::options::RecvTimeout>(Some(timeout))?;
431
432        // Send command
433        let data = serde_json::to_vec(command)?;
434        let msg = nng::Message::from(data.as_slice());
435        socket.send(msg).map_err(|(_, e)| Error::Nng(e))?;
436
437        // Receive response
438        let response = socket.recv()?;
439        let json: Value = serde_json::from_slice(&response)?;
440
441        Ok(json)
442    }
443
444    /// Start the response listener thread.
445    fn start_listener(&mut self, engine_pid_file: PathBuf) {
446        let response_socket = self.response_socket.take();
447        let active_requests = Arc::clone(&self.active_requests);
448        let should_stop = Arc::clone(&self.should_stop);
449        let response_channel_id = self.response_channel_id;
450        let event_callback = self.event_callback.clone();
451
452        let handle = thread::Builder::new()
453            .name("orchard-ipc-listener".to_string())
454            .spawn(move || {
455                if let Some(socket) = response_socket {
456                    run_response_listener(
457                        socket,
458                        active_requests,
459                        should_stop,
460                        response_channel_id,
461                        engine_pid_file,
462                        event_callback,
463                    );
464                }
465            });
466
467        match handle {
468            Ok(h) => self.listener_handle = Some(h),
469            Err(e) => tracing::error!("Failed to spawn IPC listener thread: {}", e),
470        }
471    }
472}
473
474impl Default for IPCClient {
475    fn default() -> Self {
476        Self::new()
477    }
478}
479
480impl Drop for IPCClient {
481    fn drop(&mut self) {
482        self.disconnect();
483    }
484}
485
486fn engine_process_is_alive(engine_pid_file: &Path) -> bool {
487    read_pid_file(engine_pid_file)
488        .map(pid_is_alive)
489        .unwrap_or(false)
490}
491
492/// Response listener - runs on dedicated thread for minimal latency.
493fn run_response_listener(
494    socket: Socket,
495    active_requests: Arc<Mutex<HashMap<u64, ActiveRequest>>>,
496    should_stop: Arc<AtomicBool>,
497    response_channel_id: u64,
498    engine_pid_file: PathBuf,
499    event_callback: Option<EventCallback>,
500) {
501    let response_topic = format!("resp:{:x}:", response_channel_id);
502    let response_topic_bytes = response_topic.as_bytes();
503
504    // Set receive timeout for responsive polling (10ms for better latency)
505    let _ = socket.set_opt::<nng::options::RecvTimeout>(Some(RESPONSE_RECV_TIMEOUT));
506    let mut last_engine_check = Instant::now();
507
508    while !should_stop.load(Ordering::SeqCst) {
509        match socket.recv() {
510            Ok(msg) => {
511                let data = msg.as_slice();
512
513                // Check if it's a response for us
514                if data.starts_with(response_topic_bytes) {
515                    let json_data = &data[response_topic_bytes.len()..];
516
517                    if let Ok(delta) = serde_json::from_slice::<ResponseDelta>(json_data) {
518                        let request_id = delta.request_id;
519                        let is_final = delta.is_final_delta;
520
521                        let sender = {
522                            let mut requests =
523                                active_requests.lock().unwrap_or_else(|e| e.into_inner());
524                            if let Some(entry) = requests.get_mut(&request_id) {
525                                if is_final {
526                                    entry.remaining_finals =
527                                        entry.remaining_finals.saturating_sub(1);
528                                    if entry.remaining_finals == 0 {
529                                        let sender = entry.sender.clone();
530                                        requests.remove(&request_id);
531                                        Some(sender)
532                                    } else {
533                                        Some(entry.sender.clone())
534                                    }
535                                } else {
536                                    Some(entry.sender.clone())
537                                }
538                            } else {
539                                None
540                            }
541                        };
542
543                        if let Some(tx) = sender {
544                            let _ = tx.send(delta);
545                        }
546                    } else {
547                        tracing::warn!(
548                            response_channel_id,
549                            payload_bytes = json_data.len(),
550                            "Failed to deserialize IPC response payload"
551                        );
552                    }
553                }
554                // Check if it's an engine event
555                else if data.starts_with(EVENT_TOPIC_PREFIX) {
556                    handle_engine_event(data, &event_callback);
557                }
558            }
559            Err(nng::Error::TimedOut) => {
560                if last_engine_check.elapsed() >= ENGINE_LIVENESS_POLL_INTERVAL {
561                    last_engine_check = Instant::now();
562                    if !engine_process_is_alive(&engine_pid_file) {
563                        tracing::error!(
564                            pid_file = %engine_pid_file.display(),
565                            "PIE is no longer alive; shutting down IPC listener"
566                        );
567                        should_stop.store(true, Ordering::SeqCst);
568                        break;
569                    }
570                }
571                continue;
572            }
573            Err(error) => {
574                if should_stop.load(Ordering::SeqCst) {
575                    break;
576                }
577                if !engine_process_is_alive(&engine_pid_file) {
578                    tracing::error!(
579                        pid_file = %engine_pid_file.display(),
580                        error = %error,
581                        "PIE is no longer alive; shutting down IPC listener"
582                    );
583                    should_stop.store(true, Ordering::SeqCst);
584                    break;
585                }
586            }
587        }
588    }
589
590    // Graceful shutdown: notify any remaining pending requests
591    tracing::info!("IPC listener shutting down");
592    let requests = active_requests.lock().unwrap_or_else(|e| e.into_inner());
593
594    if !requests.is_empty() {
595        tracing::warn!(
596            "IPC listener exiting with {} active requests; failing them.",
597            requests.len()
598        );
599
600        for (request_id, entry) in requests.iter() {
601            let error_delta = ResponseDelta {
602                request_id: *request_id,
603                is_final_delta: true,
604                finish_reason: Some("error".to_string()),
605                content: Some("Engine process disconnected.".to_string()),
606                error: Some("Engine process disconnected.".to_string()),
607                ..Default::default()
608            };
609            let _ = entry.sender.send(error_delta);
610        }
611    }
612}
613
614/// Handle an engine event (telemetry, model_loaded, etc.)
615fn handle_engine_event(data: &[u8], event_callback: &Option<EventCallback>) {
616    // Event format: __PIE_EVENT__:<event_name>\x00<json_body>
617    let parts: Vec<&[u8]> = data.splitn(2, |&b| b == 0).collect();
618    if parts.len() != 2 {
619        tracing::warn!("Received malformed event message");
620        return;
621    }
622
623    let (topic_part, json_body) = (parts[0], parts[1]);
624
625    // Extract event name from topic: "__PIE_EVENT__:<event_name>"
626    let event_name = if topic_part.len() > EVENT_TOPIC_PREFIX.len() {
627        String::from_utf8_lossy(&topic_part[EVENT_TOPIC_PREFIX.len()..]).to_string()
628    } else {
629        tracing::warn!("Event message has empty event name");
630        return;
631    };
632
633    // Parse JSON payload
634    let payload: Value = match serde_json::from_slice(json_body) {
635        Ok(v) => v,
636        Err(e) => {
637            tracing::error!("Failed to parse engine event payload: {}", e);
638            return;
639        }
640    };
641
642    if event_name != "telemetry" {
643        tracing::debug!("Received engine event: {}", event_name);
644    }
645
646    // Dispatch to callback if registered
647    if let Some(callback) = event_callback {
648        callback(&event_name, &payload);
649    }
650}
651
652/// Generate a unique response channel ID.
653/// Format: (PID << 32) | random_32_bits
654fn rand_u64() -> u64 {
655    use rand::Rng;
656
657    let pid = std::process::id() as u64 & 0xFFFFFFFF;
658    let random: u32 = rand::thread_rng().gen();
659
660    let channel_id = (pid << 32) | (random as u64);
661    if channel_id == 0 {
662        1
663    } else {
664        channel_id
665    }
666}
667
668#[cfg(test)]
669mod tests {
670    use super::*;
671    use tempfile::tempdir;
672
673    #[test]
674    fn test_client_creation() {
675        let client = IPCClient::new();
676        assert!(client.request_socket.is_none());
677        assert!(client.response_channel_id > 0);
678    }
679
680    #[test]
681    fn test_request_id_increment() {
682        let client = IPCClient::new();
683        let id1 = client.next_request_id();
684        let id2 = client.next_request_id();
685        assert_eq!(id2, id1 + 1);
686    }
687
688    #[test]
689    fn test_response_delta_default() {
690        let delta = ResponseDelta::default();
691        assert_eq!(delta.request_id, 0);
692        assert!(!delta.is_final_delta);
693        assert!(delta.tokens.is_empty());
694        assert!(delta.top_logprobs.is_empty());
695        assert!(delta.embedding_bytes.is_none());
696        assert!(delta.state_events.is_empty());
697    }
698
699    #[test]
700    fn test_response_delta_deserialize() {
701        let json = serde_json::json!({
702            "request_id": 123,
703            "sequence_id": 1,
704            "prompt_index": 0,
705            "candidate_index": 0,
706            "content": "Hello",
707            "content_len": 5,
708            "inline_content_bytes": 5,
709            "is_final_delta": false,
710            "num_tokens_in_delta": 3,
711            "tokens": [1, 2, 3],
712            "top_logprobs": [{"token": "hello", "logprob": -0.5}, {"token": "world", "logprob": -1.0}],
713            "cumulative_logprob": -1.5,
714            "modal_decoder_id": "moondream3.coord",
715            "modal_bytes_b64": "AAAA",
716            "embedding_bytes": [0, 0, 128, 63]
717        });
718        let delta: ResponseDelta = serde_json::from_value(json).expect("deserialize failed");
719        assert_eq!(delta.request_id, 123);
720        assert_eq!(delta.sequence_id, Some(1));
721        assert_eq!(delta.candidate_index, Some(0));
722        assert_eq!(delta.content_len, Some(5));
723        assert_eq!(delta.num_tokens_in_delta, Some(3));
724        assert_eq!(delta.tokens, vec![1, 2, 3]);
725        assert_eq!(delta.top_logprobs.len(), 2);
726        assert_eq!(delta.cumulative_logprob, Some(-1.5));
727        assert_eq!(delta.modal_decoder_id, Some("moondream3.coord".to_string()));
728        assert_eq!(delta.embedding_bytes, Some(vec![0, 0, 128, 63]));
729        assert!(delta.state_events.is_empty());
730    }
731
732    #[test]
733    fn test_response_delta_deserialize_with_defaults() {
734        // Test that missing fields get sensible defaults
735        let json = serde_json::json!({
736            "request_id": 42,
737            "is_final_delta": true
738        });
739        let delta: ResponseDelta = serde_json::from_value(json).expect("deserialize failed");
740        assert_eq!(delta.request_id, 42);
741        assert!(delta.is_final_delta);
742        assert!(delta.tokens.is_empty());
743        assert!(delta.content.is_none());
744        assert!(delta.state_events.is_empty());
745    }
746
747    #[test]
748    fn test_response_delta_deserialize_embedding_bytes_from_base64() {
749        let json = serde_json::json!({
750            "request_id": 42,
751            "is_final_delta": true,
752            "embedding_bytes": "AAAAAA==",
753        });
754
755        let delta: ResponseDelta = serde_json::from_value(json).expect("deserialize failed");
756        assert_eq!(delta.embedding_bytes, Some(vec![0, 0, 0, 0]));
757    }
758
759    #[test]
760    fn test_response_delta_deserialize_error_message_alias() {
761        let json = serde_json::json!({
762            "request_id": 42,
763            "is_final_delta": true,
764            "error_message": "boom",
765        });
766
767        let delta: ResponseDelta = serde_json::from_value(json).expect("deserialize failed");
768        assert_eq!(delta.error.as_deref(), Some("boom"));
769    }
770
771    #[test]
772    fn test_engine_process_is_alive_reads_pid_file() {
773        let dir = tempdir().expect("tempdir should be available");
774        let pid_file = dir.path().join("engine.pid");
775        std::fs::write(&pid_file, format!("{}\n", std::process::id()))
776            .expect("pid file should be written");
777
778        assert!(engine_process_is_alive(&pid_file));
779    }
780
781    #[test]
782    fn test_engine_process_is_alive_handles_missing_pid_file() {
783        let dir = tempdir().expect("tempdir should be available");
784        let pid_file = dir.path().join("missing.pid");
785
786        assert!(!engine_process_is_alive(&pid_file));
787    }
788
789    #[test]
790    fn test_response_delta_deserialize_with_state_events() {
791        let json = serde_json::json!({
792            "request_id": 7,
793            "is_final_delta": false,
794            "state_events": [
795                {
796                    "event_type": "item_started",
797                    "item_type": "message",
798                    "output_index": 0,
799                    "identifier": "",
800                    "delta": ""
801                },
802                {
803                    "event_type": "content_delta",
804                    "item_type": "message",
805                    "output_index": 0,
806                    "identifier": "",
807                    "delta": "hello"
808                }
809            ]
810        });
811
812        let delta: ResponseDelta = serde_json::from_value(json).expect("deserialize failed");
813        assert_eq!(delta.request_id, 7);
814        assert_eq!(delta.state_events.len(), 2);
815        assert_eq!(delta.state_events[0].event_type, "item_started");
816        assert_eq!(delta.state_events[1].delta, "hello");
817    }
818}