// orchard/client/mod.rs

1//! High-level client API for Orchard.
2//!
3//! Provides the main user-facing interface for LLM inference.
4
5mod moondream;
6mod response;
7mod responses;
8
9use std::collections::HashMap;
10use std::sync::{Arc, OnceLock};
11
12use rand::Rng;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use thiserror::Error;
16use tokio::sync::mpsc;
17
18/// Global runtime for synchronous operations.
19/// Uses current_thread for efficiency - sync callers don't need multi-thread.
20static SYNC_RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
21
22fn get_sync_runtime() -> &'static tokio::runtime::Runtime {
23    SYNC_RUNTIME.get_or_init(|| {
24        tokio::runtime::Builder::new_current_thread()
25            .enable_all()
26            .build()
27            .expect("Failed to create sync runtime")
28    })
29}
30
31use crate::formatter::multimodal::{
32    build_multimodal_layout, build_multimodal_messages, CapabilityInput, LayoutSegment,
33};
34use crate::ipc::client::{EventCallback, IPCClient, ResponseDelta};
35use crate::ipc::serialization::{CapabilityEntry, LayoutEntry, PromptPayload, RequestType};
36use crate::model::registry::ModelRegistry;
37
38pub use moondream::{
39    BoundingBox, CaptionResult, DetectResult, DetectedObject, GazeResult, GroundingSpan,
40    MoondreamClient, Point, PointResult, QueryResult, ReasoningOutput, SpatialRef,
41    MOONDREAM_MODEL_ID,
42};
43pub use response::{BatchChatResult, ClientDelta, ClientResponse, UsageStats};
44pub use responses::{
45    ContentPartAddedEvent, ContentPartDoneEvent, FunctionCallArgumentsDeltaEvent,
46    FunctionCallArgumentsDoneEvent, IncompleteDetails, InputTokensDetails, OutputFunctionCall,
47    OutputItemAddedEvent, OutputItemDoneEvent, OutputMessage, OutputReasoning, OutputStatus,
48    OutputTextContent, OutputTextDeltaEvent, OutputTextDoneEvent, OutputTokensDetails,
49    ReasoningContent, ReasoningDeltaEvent, ReasoningDoneEvent, ReasoningSummaryTextContent,
50    ReasoningSummaryTextDeltaEvent, ReasoningSummaryTextDoneEvent, ResponseCompletedEvent,
51    ResponseCreatedEvent, ResponseError, ResponseEvent, ResponseFailedEvent,
52    ResponseInProgressEvent, ResponseIncompleteEvent, ResponseInputItem, ResponseObject,
53    ResponseOutputItem, ResponseSnapshot, ResponseUsage, ResponsesInput, ResponsesRequest,
54    ResponsesResult, StreamErrorDetail, StreamErrorEvent,
55};
56
/// Errors that can occur during client operations.
///
/// Most variants carry a pre-formatted message (`#[error("{0}")]`), so their
/// `Display` output is exactly the wrapped string.
#[derive(Error, Debug)]
pub enum ClientError {
    /// The requested model id could not be resolved.
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// The model exists but is not yet ready to serve requests.
    #[error("{0}")]
    ModelNotReady(String),

    /// Transport-level failure while talking to the engine over IPC.
    #[error("{0}")]
    Ipc(String),

    /// Chat-template / prompt-formatting failure.
    #[error("{0}")]
    Formatter(String),

    /// Failure while expanding multimodal content (images, capabilities).
    #[error("{0}")]
    Multimodal(String),

    /// Catch-all for request failures not covered by the variants above.
    #[error("{0}")]
    RequestFailed(String),
}
78
/// Map the crate-wide `Error` onto the coarser client-facing `ClientError`
/// categories used by this module's public API.
impl From<crate::error::Error> for ClientError {
    fn from(err: crate::error::Error) -> Self {
        use crate::error::Error;
        match err {
            // Model lookup / readiness errors map one-to-one.
            Error::ModelNotFound(s) => ClientError::ModelNotFound(s),
            Error::ModelNotReady(s) => ClientError::ModelNotReady(s),
            // Transport-level failures collapse into a single Ipc variant.
            // `err` remains usable in the arm because these patterns bind
            // nothing by move.
            Error::NotConnected
            | Error::InvalidResponse
            | Error::Nng(_)
            | Error::Timeout
            | Error::ChannelClosed => ClientError::Ipc(err.to_string()),
            Error::Template(s) => ClientError::Formatter(s),
            // Content-validation failures become Multimodal errors.
            Error::InvalidImageUrl
            | Error::InvalidBase64
            | Error::MissingContentType(_, _)
            | Error::InvalidContent
            | Error::PlaceholderMismatch(_, _)
            | Error::EmptyRequest => ClientError::Multimodal(err.to_string()),
            // Anything else is reported as a generic request failure.
            _ => ClientError::RequestFailed(err.to_string()),
        }
    }
}
101
/// Convenience alias: client operations fail with [`ClientError`].
pub type Result<T> = std::result::Result<T, ClientError>;
103
104use crate::defaults;
105
/// Sampling parameters for generation.
///
/// Every field carries a `#[serde(default)]` (or a named default function),
/// so partially-specified parameter objects deserialize cleanly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingParams {
    /// Maximum number of tokens to generate.
    #[serde(default = "defaults::max_tokens")]
    pub max_tokens: i32,
    /// Softmax temperature; higher values increase randomness.
    #[serde(default = "defaults::temperature")]
    pub temperature: f64,
    /// Nucleus (top-p) sampling cutoff.
    #[serde(default = "defaults::top_p")]
    pub top_p: f64,
    /// Top-k sampling cutoff.
    #[serde(default = "defaults::top_k")]
    pub top_k: i32,
    /// Minimum-probability sampling cutoff (0.0 disables).
    #[serde(default)]
    pub min_p: f64,
    /// RNG seed; 0 means "generate a fresh random seed per request".
    #[serde(default)]
    pub rng_seed: u64,
    /// Stop sequences that terminate generation.
    #[serde(default)]
    pub stop: Vec<String>,
    /// OpenAI-style frequency penalty.
    #[serde(default)]
    pub frequency_penalty: f64,
    /// OpenAI-style presence penalty.
    #[serde(default)]
    pub presence_penalty: f64,
    /// Multiplicative repetition penalty.
    #[serde(default = "defaults::repetition_penalty")]
    pub repetition_penalty: f64,
    /// Window (in tokens) over which the repetition penalty applies.
    #[serde(default = "defaults::repetition_context_size")]
    pub repetition_context_size: i32,
    /// Number of candidate completions to return.
    #[serde(default = "defaults::num_candidates")]
    pub n: i32,
    /// When set, generate this many candidates before selecting winners.
    #[serde(default)]
    pub best_of: Option<i32>,
    /// When set, number of candidates kept after scoring.
    #[serde(default)]
    pub final_candidates: Option<i32>,
    /// Number of top logprobs to report per token (0 disables).
    #[serde(default)]
    pub top_logprobs: i32,
    /// Additive logit bias keyed by token id.
    #[serde(default)]
    pub logit_bias: HashMap<i32, f64>,
    /// Tool schemas (JSON), serialized and forwarded to the engine.
    #[serde(default)]
    pub tools: Vec<serde_json::Value>,
    /// Tool choice directive; `None`/null is treated as "auto".
    #[serde(default)]
    pub tool_choice: Option<serde_json::Value>,
    /// Upper bound on tool invocations; `None` maps to 0 on the wire.
    #[serde(default)]
    pub max_tool_calls: Option<i32>,
    /// Response format spec (e.g. JSON schema), forwarded verbatim as JSON.
    #[serde(default)]
    pub response_format: Option<serde_json::Value>,
    /// Enable reasoning mode (also implied by `reasoning_effort`).
    #[serde(default)]
    pub reasoning: bool,
    /// Optional reasoning effort hint; presence implies reasoning mode.
    #[serde(default)]
    pub reasoning_effort: Option<String>,
    /// Optional system instructions injected during multimodal expansion.
    #[serde(default)]
    pub instructions: Option<String>,
    /// Optional task name forwarded to the chat template.
    #[serde(default)]
    pub task_name: Option<String>,
}
158
/// Programmatic defaults. Keep these values in sync with the
/// `#[serde(default = ...)]` attributes on `SamplingParams` so that
/// `SamplingParams::default()` and deserializing `{}` agree.
impl Default for SamplingParams {
    fn default() -> Self {
        Self {
            max_tokens: defaults::MAX_TOKENS,
            temperature: defaults::TEMPERATURE,
            top_p: defaults::TOP_P,
            top_k: defaults::TOP_K,
            min_p: 0.0,
            rng_seed: 0,
            stop: Vec::new(),
            frequency_penalty: 0.0,
            presence_penalty: 0.0,
            repetition_penalty: defaults::REPETITION_PENALTY,
            repetition_context_size: defaults::REPETITION_CONTEXT_SIZE,
            n: defaults::NUM_CANDIDATES,
            best_of: None,
            final_candidates: None,
            top_logprobs: 0,
            logit_bias: HashMap::new(),
            tools: Vec::new(),
            tool_choice: None,
            max_tool_calls: None,
            response_format: None,
            reasoning: false,
            reasoning_effort: None,
            instructions: None,
            task_name: None,
        }
    }
}
189
190fn tool_choice_to_string(tool_choice: Option<&Value>) -> String {
191    match tool_choice {
192        None | Some(Value::Null) => "auto".to_string(),
193        Some(Value::String(value)) => value.clone(),
194        Some(Value::Object(value)) => serde_json::to_string(value).unwrap_or_default(),
195        Some(other) => other.to_string(),
196    }
197}
198
/// A high-level client for the Proxy Inference Engine.
///
/// Provides both synchronous and asynchronous interfaces for LLM inference.
pub struct Client {
    /// Shared IPC transport to the engine.
    ipc: Arc<IPCClient>,
    /// Registry used to resolve, load, and describe models.
    registry: Arc<ModelRegistry>,
}
206
207impl Client {
208    /// Create a new client with the given IPC client and model registry.
209    pub fn new(ipc: Arc<IPCClient>, registry: Arc<ModelRegistry>) -> Self {
210        Self { ipc, registry }
211    }
212
    /// Create a client and connect to the engine (async).
    ///
    /// This sets up:
    /// - Event callback for handling model lifecycle events
    /// - IPC client shared with registry for management commands
    ///
    /// # Errors
    /// Fails if the underlying IPC connection cannot be established
    /// (surfaced as `ClientError` via `From`).
    pub async fn connect(registry: Arc<ModelRegistry>) -> Result<Self> {
        // Create event callback that routes model lifecycle events to registry
        let registry_for_events = Arc::clone(&registry);
        // Capture the current runtime handle so the synchronous callback can
        // spawn async registry updates regardless of which thread it fires on.
        let runtime_handle = tokio::runtime::Handle::current();
        let event_callback: EventCallback =
            Arc::new(move |event_name: &str, payload: &Value| match event_name {
                "model_loaded" => {
                    // Clone per-event: the callback may fire many times.
                    let registry = Arc::clone(&registry_for_events);
                    let payload = payload.clone();
                    let handle = runtime_handle.clone();
                    handle.spawn(async move {
                        registry.handle_model_loaded(&payload).await;
                    });
                }
                "model_load_failed" => {
                    let registry = Arc::clone(&registry_for_events);
                    let payload = payload.clone();
                    let handle = runtime_handle.clone();
                    handle.spawn(async move {
                        registry.handle_model_load_failed(&payload).await;
                    });
                }
                // All other event types are ignored here.
                _ => {}
            });

        let mut ipc = IPCClient::with_event_callback(event_callback);
        ipc.connect()?;
        let ipc = Arc::new(ipc);

        // Share the IPC client with the registry for management commands
        registry.set_ipc_client(Arc::clone(&ipc)).await;

        Ok(Self { ipc, registry })
    }
252
253    /// Resolve control token capabilities for a model.
254    pub async fn resolve_capabilities(&self, model_id: &str) -> Result<HashMap<String, i32>> {
255        let info = self.registry.ensure_loaded(model_id).await?;
256
257        let capabilities = info.capabilities.as_ref().cloned().unwrap_or_default();
258        let mut resolved = HashMap::new();
259
260        for (name, token_ids) in capabilities {
261            if let Some(&first) = token_ids.first() {
262                resolved.insert(name, first);
263            }
264        }
265
266        Ok(resolved)
267    }
268
    /// Perform asynchronous chat completion.
    ///
    /// # Arguments
    /// * `model_id` - Model to use for generation
    /// * `messages` - Conversation messages
    /// * `params` - Sampling parameters
    /// * `stream` - Whether to stream responses
    ///
    /// # Returns
    /// `ChatResult::Stream` (a delta receiver) when `stream` is true,
    /// otherwise `ChatResult::Complete` with the aggregated response.
    ///
    /// # Errors
    /// Fails if the model cannot be loaded, the template cannot be applied,
    /// multimodal expansion fails, or IPC dispatch fails.
    pub async fn achat(
        &self,
        model_id: &str,
        messages: Vec<HashMap<String, serde_json::Value>>,
        params: SamplingParams,
        stream: bool,
    ) -> Result<ChatResult> {
        let info = self.registry.ensure_loaded(model_id).await?;
        let formatter = info.require_formatter()?;

        let request_id = self.ipc.next_request_id();
        tracing::debug!(
            request_id,
            model_id = %model_id,
            stream,
            message_count = messages.len(),
            "Building chat request"
        );
        tracing::trace!(
            request_id,
            model_id = %model_id,
            messages = ?messages,
            "Chat messages before template application"
        );

        // Compute reasoning flag (same as Python: reasoning OR reasoning_effort present)
        let reasoning_flag = params.reasoning || params.reasoning_effort.is_some();

        // Build multimodal content (pass instructions if provided)
        let (messages_for_template, image_buffers, capabilities, content_order) =
            build_multimodal_messages(formatter, &messages, params.instructions.as_deref())
                .map_err(|e| ClientError::Multimodal(e.to_string()))?;

        if messages_for_template.is_empty() {
            return Err(ClientError::RequestFailed(
                "Chat request must include at least one message".into(),
            ));
        }
        tracing::trace!(
            request_id,
            model_id = %model_id,
            messages_for_template = ?messages_for_template,
            "Chat messages after multimodal expansion"
        );

        // Apply template with reasoning flag
        let prompt_text = formatter
            .apply_template(
                &messages_for_template,
                true,
                reasoning_flag,
                params.task_name.as_deref(),
            )
            .map_err(|e| ClientError::Formatter(e.to_string()))?;

        let capability_placeholder = formatter.capability_placeholder_token();

        // Build layout for multimodal content
        let layout_segments = build_multimodal_layout(
            &prompt_text,
            &image_buffers,
            &capabilities,
            &content_order,
            formatter.image_placeholder_token(),
            formatter.should_clip_image_placeholder(),
            capability_placeholder,
        )
        .map_err(|e| ClientError::Multimodal(e.to_string()))?;

        // Placeholders were only needed to compute the layout; strip them
        // before sending the prompt.
        let final_prompt = formatter.strip_template_placeholders(&prompt_text);
        tracing::debug!(
            request_id,
            model_id = %model_id,
            prompt_chars = final_prompt.chars().count(),
            image_count = image_buffers.len(),
            capability_count = capabilities.len(),
            layout_segment_count = layout_segments.len(),
            "Prepared chat prompt payload"
        );
        tracing::trace!(
            request_id,
            model_id = %model_id,
            prompt = %final_prompt,
            "Chat prompt sent to PIE"
        );

        // Serialize tools and response_format to JSON strings (matching Python)
        let tool_schemas_json = if params.tools.is_empty() {
            String::new()
        } else {
            serde_json::to_string(&params.tools).unwrap_or_default()
        };
        let response_format_json = params
            .response_format
            .as_ref()
            .map(|rf| serde_json::to_string(rf).unwrap_or_default())
            .unwrap_or_default();
        let tool_calling_tokens = formatter.get_tool_calling_tokens().clone();
        let tool_choice = tool_choice_to_string(params.tool_choice.as_ref());
        let max_tool_calls = params.max_tool_calls.unwrap_or(0).max(0);

        // Build PromptPayload with full multimodal data
        // Generate unique RNG seed if not explicitly provided
        let rng_seed = if params.rng_seed == 0 {
            rand::thread_rng().gen::<u64>()
        } else {
            params.rng_seed
        };

        let prompt_payload = PromptPayload {
            prompt: final_prompt,
            image_buffers,
            capabilities: convert_capabilities(&capabilities),
            layout: convert_layout(&layout_segments),
            max_generated_tokens: params.max_tokens,
            temperature: params.temperature,
            top_p: params.top_p,
            top_k: params.top_k,
            min_p: params.min_p,
            rng_seed,
            stop_sequences: params.stop.clone(),
            num_candidates: params.n,
            best_of: params.best_of,
            final_candidates: params.final_candidates,
            frequency_penalty: params.frequency_penalty,
            presence_penalty: params.presence_penalty,
            repetition_penalty: params.repetition_penalty,
            repetition_context_size: params.repetition_context_size,
            top_logprobs: params.top_logprobs,
            logit_bias: params.logit_bias.clone(),
            tool_schemas_json,
            tool_calling_tokens,
            tool_choice,
            max_tool_calls,
            response_format_json,
            task_name: params.task_name.clone(),
            reasoning_effort: params.reasoning_effort.clone(),
        };

        // Use unified batch request path (even for single prompts)
        tracing::debug!(
            request_id,
            model_id = %model_id,
            stream,
            "Dispatching chat request to PIE"
        );
        let (_batch_size, rx) = self.ipc.send_batch_request(
            request_id,
            model_id,
            &info.model_path,
            &[prompt_payload],
        )?;

        if stream {
            Ok(ChatResult::Stream(rx))
        } else {
            // Determine how many candidates to expect
            let best_of = params.best_of.unwrap_or(params.n).max(1) as usize;
            let final_candidates = params.final_candidates.unwrap_or(params.n).max(1) as usize;

            // Collect deltas grouped by candidate_index (matching Python's gather_non_streaming_batch_response)
            let mut candidate_states: Vec<CandidateState> =
                (0..best_of).map(|_| CandidateState::default()).collect();
            let mut remaining_sequences = best_of;
            let mut rx = rx;

            // Drain the channel until every candidate reports a final delta
            // (or the channel closes early, e.g. on engine shutdown).
            while remaining_sequences > 0 {
                match rx.recv().await {
                    Some(delta) => {
                        // Deltas without an explicit index belong to candidate 0.
                        let candidate_index = delta.candidate_index.unwrap_or(0) as usize;
                        if candidate_index >= candidate_states.len() {
                            // Ignore out-of-range indices rather than panic.
                            continue;
                        }

                        let state = &mut candidate_states[candidate_index];

                        if let Some(content) = &delta.content {
                            state.content.push_str(content);
                        }
                        state.completion_tokens += delta.tokens.len() as u32;
                        if let Some(count) = delta.prompt_token_count {
                            // Prompt token count may be repeated; keep the max.
                            state.prompt_tokens = state.prompt_tokens.max(count);
                        }

                        let client_delta = ClientDelta::from(delta.clone());
                        state.deltas.push(client_delta);

                        // Only the first final delta per candidate counts.
                        if delta.is_final_delta && !state.completed {
                            state.completed = true;
                            state.finish_reason = delta.finish_reason.clone();
                            state.cumulative_logprob = delta.cumulative_logprob;
                            state.generation_len = delta.generation_len;
                            remaining_sequences -= 1;
                        }
                    }
                    None => break,
                }
            }

            let total_completion_tokens: u32 =
                candidate_states.iter().map(|c| c.completion_tokens).sum();

            // Score and select best candidates (matching Python logic)
            let selected = select_best_candidates(candidate_states, best_of, final_candidates);

            Ok(ChatResult::Complete(build_response_from_candidates(
                selected,
                total_completion_tokens,
            )))
        }
    }
487
488    /// Perform synchronous chat completion (blocking).
489    ///
490    /// Handles nested async contexts properly - safe to call from any context.
491    pub fn chat(
492        &self,
493        model_id: &str,
494        messages: Vec<HashMap<String, serde_json::Value>>,
495        params: SamplingParams,
496    ) -> Result<ClientResponse> {
497        let future = async {
498            match self.achat(model_id, messages, params, false).await? {
499                ChatResult::Complete(response) => Ok(response),
500                ChatResult::Stream(_) => Err(ClientError::RequestFailed(
501                    "Unexpected stream result".into(),
502                )),
503            }
504        };
505
506        match tokio::runtime::Handle::try_current() {
507            Ok(handle) => {
508                // Already in async context - use block_in_place to avoid panic
509                tokio::task::block_in_place(|| handle.block_on(future))
510            }
511            Err(_) => {
512                // Not in async context - use the global sync runtime
513                get_sync_runtime().block_on(future)
514            }
515        }
516    }
517
    /// Perform batched chat completion.
    ///
    /// This sends ALL conversations in ONE IPC message, allowing the engine
    /// to schedule them together efficiently. Responses are demultiplexed
    /// by prompt_index and returned in order.
    ///
    /// # Arguments
    /// * `model_id` - Model to use for generation
    /// * `conversations` - List of conversation message lists
    /// * `params` - Sampling parameters
    /// * `stream` - Whether to stream responses (deltas contain prompt_index)
    ///
    /// # Errors
    /// Fails if the model cannot be loaded, any conversation is empty or
    /// fails template/multimodal processing, or IPC dispatch fails.
    pub async fn achat_batch(
        &self,
        model_id: &str,
        conversations: Vec<Vec<HashMap<String, serde_json::Value>>>,
        params: SamplingParams,
        stream: bool,
    ) -> Result<BatchChatResult> {
        // Empty batch: nothing to do, skip the engine entirely.
        if conversations.is_empty() {
            return Ok(BatchChatResult::Complete(Vec::new()));
        }

        let info = self.registry.ensure_loaded(model_id).await?;
        let formatter = info.require_formatter()?;

        let request_id = self.ipc.next_request_id();
        let num_prompts = conversations.len();
        tracing::debug!(
            request_id,
            model_id = %model_id,
            stream,
            prompt_count = num_prompts,
            "Building batched chat request"
        );

        // Compute reasoning flag (same as Python: reasoning OR reasoning_effort present)
        let reasoning_flag = params.reasoning || params.reasoning_effort.is_some();

        // Serialize tools and response_format to JSON strings (matching Python).
        // These are shared across all prompts, so compute them once.
        let tool_schemas_json = if params.tools.is_empty() {
            String::new()
        } else {
            serde_json::to_string(&params.tools).unwrap_or_default()
        };
        let response_format_json = params
            .response_format
            .as_ref()
            .map(|rf| serde_json::to_string(rf).unwrap_or_default())
            .unwrap_or_default();
        let tool_calling_tokens = formatter.get_tool_calling_tokens().clone();
        let tool_choice = tool_choice_to_string(params.tool_choice.as_ref());
        let max_tool_calls = params.max_tool_calls.unwrap_or(0).max(0);

        // Build all prompt payloads
        let mut prompt_payloads = Vec::with_capacity(num_prompts);

        for (prompt_index, messages) in conversations.iter().enumerate() {
            // Build multimodal content (pass instructions if provided)
            let (messages_for_template, image_buffers, capabilities, content_order) =
                build_multimodal_messages(formatter, messages, params.instructions.as_deref())
                    .map_err(|e| ClientError::Multimodal(e.to_string()))?;

            // Any empty conversation fails the whole batch up front.
            if messages_for_template.is_empty() {
                return Err(ClientError::RequestFailed(
                    "Chat request must include at least one message".into(),
                ));
            }
            tracing::trace!(
                request_id,
                model_id = %model_id,
                prompt_index,
                messages = ?messages,
                messages_for_template = ?messages_for_template,
                "Prepared batch messages for prompt"
            );

            // Apply template with reasoning flag
            let prompt_text = formatter
                .apply_template(
                    &messages_for_template,
                    true,
                    reasoning_flag,
                    params.task_name.as_deref(),
                )
                .map_err(|e| ClientError::Formatter(e.to_string()))?;

            let capability_placeholder = formatter.capability_placeholder_token();

            // Build layout for multimodal content
            let layout_segments = build_multimodal_layout(
                &prompt_text,
                &image_buffers,
                &capabilities,
                &content_order,
                formatter.image_placeholder_token(),
                formatter.should_clip_image_placeholder(),
                capability_placeholder,
            )
            .map_err(|e| ClientError::Multimodal(e.to_string()))?;

            let final_prompt = formatter.strip_template_placeholders(&prompt_text);
            tracing::debug!(
                request_id,
                model_id = %model_id,
                prompt_index,
                prompt_chars = final_prompt.chars().count(),
                image_count = image_buffers.len(),
                capability_count = capabilities.len(),
                layout_segment_count = layout_segments.len(),
                "Prepared batched prompt payload"
            );
            tracing::trace!(
                request_id,
                model_id = %model_id,
                prompt_index,
                prompt = %final_prompt,
                "Batch prompt sent to PIE"
            );

            // Generate unique RNG seed for EACH prompt in batch
            let rng_seed = if params.rng_seed == 0 {
                rand::thread_rng().gen::<u64>()
            } else {
                params.rng_seed
            };

            prompt_payloads.push(PromptPayload {
                prompt: final_prompt,
                image_buffers,
                capabilities: convert_capabilities(&capabilities),
                layout: convert_layout(&layout_segments),
                max_generated_tokens: params.max_tokens,
                temperature: params.temperature,
                top_p: params.top_p,
                top_k: params.top_k,
                min_p: params.min_p,
                rng_seed,
                stop_sequences: params.stop.clone(),
                num_candidates: params.n,
                best_of: params.best_of,
                final_candidates: params.final_candidates,
                frequency_penalty: params.frequency_penalty,
                presence_penalty: params.presence_penalty,
                repetition_penalty: params.repetition_penalty,
                repetition_context_size: params.repetition_context_size,
                top_logprobs: params.top_logprobs,
                logit_bias: params.logit_bias.clone(),
                tool_schemas_json: tool_schemas_json.clone(),
                tool_calling_tokens: tool_calling_tokens.clone(),
                tool_choice: tool_choice.clone(),
                max_tool_calls,
                response_format_json: response_format_json.clone(),
                task_name: params.task_name.clone(),
                reasoning_effort: params.reasoning_effort.clone(),
            });
        }

        // Send ONE batch request with all prompts
        tracing::debug!(
            request_id,
            model_id = %model_id,
            stream,
            prompt_count = prompt_payloads.len(),
            "Dispatching batched chat request to PIE"
        );
        let (_batch_size, rx) = self.ipc.send_batch_request(
            request_id,
            model_id,
            &info.model_path,
            &prompt_payloads,
        )?;

        if stream {
            // Convert ResponseDelta receiver to ClientDelta receiver
            let (tx, client_rx) = mpsc::channel(256);
            tokio::spawn(async move {
                let mut rx = rx;
                while let Some(delta) = rx.recv().await {
                    // Stop forwarding once the consumer drops its receiver.
                    if tx.send(ClientDelta::from(delta)).await.is_err() {
                        break;
                    }
                }
            });
            return Ok(BatchChatResult::Stream(client_rx));
        }

        // Collect responses grouped by prompt_index
        let mut deltas_by_prompt: HashMap<u32, Vec<ClientDelta>> = HashMap::new();
        let mut finals_received = 0usize;
        let mut rx = rx;

        // Drain until every prompt has reported a final delta, or the channel
        // closes early (e.g. engine shutdown).
        while finals_received < num_prompts {
            match rx.recv().await {
                Some(delta) => {
                    // Deltas without an index belong to prompt 0.
                    let prompt_index = delta.prompt_index.unwrap_or(0);
                    let is_final = delta.is_final_delta;

                    deltas_by_prompt
                        .entry(prompt_index)
                        .or_default()
                        .push(ClientDelta::from(delta));

                    if is_final {
                        finals_received += 1;
                    }
                }
                None => break, // Channel closed
            }
        }

        // Build responses in order
        let mut responses = Vec::with_capacity(num_prompts);
        for idx in 0..num_prompts {
            // Prompts that produced no deltas still get an (empty) response.
            let deltas = deltas_by_prompt.remove(&(idx as u32)).unwrap_or_default();
            responses.push(aggregate_response(deltas));
        }

        Ok(BatchChatResult::Complete(responses))
    }
737
738    /// Generate an embedding for a single text input.
739    pub async fn aembed(&self, model_id: &str, text: &str) -> Result<Vec<f32>> {
740        let mut embeddings = self.aembed_batch(model_id, vec![text.to_string()]).await?;
741        embeddings.pop().ok_or_else(|| {
742            ClientError::RequestFailed("Embedding response missing result".to_string())
743        })
744    }
745
    /// Generate embeddings for multiple text inputs in a single IPC request.
    ///
    /// An empty input list short-circuits to an empty result without
    /// contacting the engine. Aggregation of the per-prompt results is
    /// delegated to `collect_embeddings`.
    ///
    /// # Errors
    /// Fails if the model cannot be loaded or IPC dispatch fails.
    pub async fn aembed_batch(&self, model_id: &str, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let info = self.registry.ensure_loaded(model_id).await?;

        let request_id = self.ipc.next_request_id();
        tracing::debug!(
            request_id,
            model_id = %model_id,
            prompt_count = texts.len(),
            "Building batched embedding request"
        );

        let mut prompt_payloads = Vec::with_capacity(texts.len());
        for (prompt_index, text) in texts.into_iter().enumerate() {
            let prompt_chars = text.chars().count();
            tracing::debug!(
                request_id,
                model_id = %model_id,
                prompt_index,
                prompt_chars,
                "Prepared embedding prompt payload"
            );
            tracing::trace!(
                request_id,
                model_id = %model_id,
                prompt_index,
                prompt = %text,
                "Embedding prompt sent to PIE"
            );

            prompt_payloads.push(build_embedding_prompt_payload(text));
        }

        tracing::debug!(
            request_id,
            model_id = %model_id,
            prompt_count = prompt_payloads.len(),
            "Dispatching batched embedding request to PIE"
        );
        // All prompts go out in ONE IPC message tagged as an Embedding request.
        let (_batch_size, rx) = self.ipc.send_batch_request_with_type(
            request_id,
            model_id,
            &info.model_path,
            RequestType::Embedding,
            &prompt_payloads,
        )?;

        collect_embeddings(rx, prompt_payloads.len()).await
    }
799
    /// Transcribe float32 PCM audio with a local speech-to-text model.
    ///
    /// Empty input returns an empty transcript without contacting the engine.
    ///
    /// # Errors
    /// Fails if the model cannot be loaded or IPC dispatch fails.
    pub async fn atranscribe_audio(&self, model_id: &str, pcm: &[f32]) -> Result<String> {
        if pcm.is_empty() {
            return Ok(String::new());
        }

        let info = self.registry.ensure_loaded(model_id).await?;

        let request_id = self.ipc.next_request_id();
        tracing::debug!(
            request_id,
            model_id = %model_id,
            sample_count = pcm.len(),
            "Building speech-to-text request"
        );

        let prompt_payload = build_stt_prompt_payload(pcm);

        tracing::debug!(
            request_id,
            model_id = %model_id,
            // NOTE(review): assumes build_stt_prompt_payload always emits at
            // least one capability entry — this indexing panics otherwise;
            // confirm against its implementation.
            payload_bytes = prompt_payload.capabilities[0].payload.len(),
            "Dispatching speech-to-text request to PIE"
        );
        let (_batch_size, rx) = self.ipc.send_batch_request_with_type(
            request_id,
            model_id,
            &info.model_path,
            RequestType::Omni,
            &[prompt_payload],
        )?;

        collect_transcription(rx).await
    }
834
835    /// Synchronous speech-to-text wrapper.
836    pub fn transcribe_audio(&self, model_id: &str, pcm: &[f32]) -> Result<String> {
837        let model_id = model_id.to_string();
838        let pcm = pcm.to_vec();
839        let future = async move { self.atranscribe_audio(&model_id, &pcm).await };
840
841        match tokio::runtime::Handle::try_current() {
842            Ok(handle) => tokio::task::block_in_place(|| handle.block_on(future)),
843            Err(_) => get_sync_runtime().block_on(future),
844        }
845    }
846}
847
848/// Convert CapabilityInput from multimodal to CapabilityEntry for serialization.
849/// Position is always 0 (matching Python behavior).
850fn convert_capabilities(capabilities: &[CapabilityInput]) -> Vec<CapabilityEntry> {
851    capabilities
852        .iter()
853        .map(|cap| CapabilityEntry {
854            name: cap.name.clone(),
855            position: 0, // Always 0, matching Python
856            payload: cap.payload.clone(),
857        })
858        .collect()
859}
860
861/// Convert LayoutSegment from multimodal to LayoutEntry for serialization.
862fn convert_layout(segments: &[LayoutSegment]) -> Vec<LayoutEntry> {
863    segments
864        .iter()
865        .map(|seg| LayoutEntry {
866            segment_type: seg.segment_type.clone(),
867            length: seg.length,
868        })
869        .collect()
870}
871
872fn build_embedding_prompt_payload(prompt: String) -> PromptPayload {
873    let prompt_len = prompt.len();
874
875    PromptPayload {
876        prompt,
877        image_buffers: Vec::new(),
878        capabilities: Vec::new(),
879        layout: vec![LayoutEntry {
880            segment_type: "text".to_string(),
881            length: prompt_len,
882        }],
883        max_generated_tokens: 0,
884        temperature: defaults::TEMPERATURE,
885        top_p: defaults::TOP_P,
886        top_k: defaults::TOP_K,
887        min_p: 0.0,
888        rng_seed: rand::thread_rng().gen::<u64>(),
889        stop_sequences: Vec::new(),
890        num_candidates: 1,
891        best_of: Some(1),
892        final_candidates: Some(1),
893        frequency_penalty: 0.0,
894        presence_penalty: 0.0,
895        repetition_penalty: defaults::REPETITION_PENALTY,
896        repetition_context_size: 0,
897        top_logprobs: 0,
898        logit_bias: HashMap::new(),
899        tool_schemas_json: String::new(),
900        tool_calling_tokens: Default::default(),
901        tool_choice: "auto".to_string(),
902        max_tool_calls: 0,
903        response_format_json: String::new(),
904        task_name: None,
905        reasoning_effort: None,
906    }
907}
908
909fn build_stt_prompt_payload(pcm: &[f32]) -> PromptPayload {
910    let audio_payload = encode_float32_pcm_bytes(pcm);
911    let audio_payload_size = audio_payload.len();
912
913    PromptPayload {
914        prompt: String::new(),
915        image_buffers: Vec::new(),
916        capabilities: vec![CapabilityEntry {
917            name: "audio".to_string(),
918            position: 0,
919            payload: audio_payload,
920        }],
921        layout: vec![
922            LayoutEntry {
923                segment_type: "text".to_string(),
924                length: 0,
925            },
926            LayoutEntry {
927                segment_type: "capability".to_string(),
928                length: audio_payload_size,
929            },
930        ],
931        max_generated_tokens: 0,
932        temperature: defaults::TEMPERATURE,
933        top_p: defaults::TOP_P,
934        top_k: defaults::TOP_K,
935        min_p: 0.0,
936        rng_seed: rand::thread_rng().gen::<u64>(),
937        stop_sequences: Vec::new(),
938        num_candidates: 1,
939        best_of: Some(1),
940        final_candidates: Some(1),
941        frequency_penalty: 0.0,
942        presence_penalty: 0.0,
943        repetition_penalty: defaults::REPETITION_PENALTY,
944        repetition_context_size: 0,
945        top_logprobs: 0,
946        logit_bias: HashMap::new(),
947        tool_schemas_json: String::new(),
948        tool_calling_tokens: Default::default(),
949        tool_choice: "auto".to_string(),
950        max_tool_calls: 0,
951        response_format_json: String::new(),
952        task_name: None,
953        reasoning_effort: None,
954    }
955}
956
/// Serialize f32 PCM samples to little-endian bytes (4 bytes per sample).
fn encode_float32_pcm_bytes(pcm: &[f32]) -> Vec<u8> {
    // Preallocate the exact byte count, then append each sample's LE encoding.
    let mut bytes = Vec::with_capacity(std::mem::size_of_val(pcm));
    bytes.extend(pcm.iter().flat_map(|sample| sample.to_le_bytes()));
    bytes
}
964
965async fn collect_embeddings(
966    mut rx: mpsc::UnboundedReceiver<ResponseDelta>,
967    prompt_count: usize,
968) -> Result<Vec<Vec<f32>>> {
969    let mut embeddings_by_prompt: Vec<Option<Vec<f32>>> = vec![None; prompt_count];
970    let mut completed_prompts = vec![false; prompt_count];
971    let mut finals_received = 0usize;
972
973    while finals_received < prompt_count {
974        match rx.recv().await {
975            Some(delta) => {
976                if let Some(error) = delta.error {
977                    return Err(ClientError::RequestFailed(error));
978                }
979
980                let prompt_index = delta.prompt_index.unwrap_or(0) as usize;
981                if prompt_index >= prompt_count {
982                    continue;
983                }
984
985                if let Some(bytes) = delta.embedding_bytes.as_deref() {
986                    embeddings_by_prompt[prompt_index] = Some(decode_embedding_bytes(bytes)?);
987                }
988
989                if delta.is_final_delta && !completed_prompts[prompt_index] {
990                    completed_prompts[prompt_index] = true;
991                    finals_received += 1;
992                }
993            }
994            None => {
995                return Err(ClientError::RequestFailed(
996                    "Embedding response channel closed before completion".to_string(),
997                ));
998            }
999        }
1000    }
1001
1002    embeddings_by_prompt
1003        .into_iter()
1004        .enumerate()
1005        .map(|(prompt_index, embedding)| {
1006            embedding.ok_or_else(|| {
1007                ClientError::RequestFailed(format!(
1008                    "Embedding response missing bytes for prompt_index={}",
1009                    prompt_index
1010                ))
1011            })
1012        })
1013        .collect()
1014}
1015
1016fn decode_embedding_bytes(bytes: &[u8]) -> Result<Vec<f32>> {
1017    let mut chunks = bytes.chunks_exact(std::mem::size_of::<f32>());
1018    if !chunks.remainder().is_empty() {
1019        return Err(ClientError::RequestFailed(format!(
1020            "Embedding payload length {} is not divisible by {}",
1021            bytes.len(),
1022            std::mem::size_of::<f32>()
1023        )));
1024    }
1025
1026    Ok(chunks
1027        .by_ref()
1028        .map(|chunk| f32::from_le_bytes(chunk.try_into().expect("f32 chunk size")))
1029        .collect())
1030}
1031
1032async fn collect_transcription(mut rx: mpsc::UnboundedReceiver<ResponseDelta>) -> Result<String> {
1033    let mut transcription = String::new();
1034
1035    loop {
1036        match rx.recv().await {
1037            Some(delta) => {
1038                if let Some(error) = delta.error {
1039                    return Err(ClientError::RequestFailed(error));
1040                }
1041
1042                if let Some(content) = delta.content {
1043                    transcription.push_str(&content);
1044                }
1045
1046                if delta.is_final_delta {
1047                    return Ok(transcription);
1048                }
1049            }
1050            None => {
1051                return Err(ClientError::RequestFailed(
1052                    "Speech-to-text response channel closed before completion".to_string(),
1053                ));
1054            }
1055        }
1056    }
1057}
1058
/// Result of a chat operation.
///
/// Non-streaming calls resolve to a fully aggregated [`ClientResponse`];
/// streaming calls hand back the raw delta receiver for the caller to drain.
pub enum ChatResult {
    /// Complete response (non-streaming)
    Complete(ClientResponse),
    /// Streaming response receiver
    Stream(mpsc::UnboundedReceiver<ResponseDelta>),
}
1066
/// State for a single candidate during best_of collection.
/// Mirrors Python's candidate state dict in gather_non_streaming_batch_response.
#[derive(Default)]
struct CandidateState {
    /// Accumulated generated text for this candidate.
    content: String,
    /// Finish reason from this candidate's final delta, if any.
    finish_reason: Option<String>,
    // Per-candidate completion token count — presumably filled by the
    // collector from delta metadata; verify against the gathering loop.
    completion_tokens: u32,
    /// Prompt token count reported for this candidate's request.
    prompt_tokens: u32,
    /// Sum of token logprobs; numerator of the best_of score.
    cumulative_logprob: Option<f64>,
    /// Generated-token count; denominator of the best_of score.
    generation_len: Option<u32>,
    // Whether a final delta has been observed — set by the collecting loop
    // elsewhere in this file; not read by the selection helpers below.
    completed: bool,
    /// Raw deltas collected for this candidate, in arrival order.
    deltas: Vec<ClientDelta>,
}
1080
1081impl CandidateState {
1082    /// Score for best_of selection: cumulative_logprob / generation_len
1083    #[inline]
1084    fn score(&self) -> f64 {
1085        match (self.cumulative_logprob, self.generation_len) {
1086            (Some(cumulative), Some(gen_len)) if gen_len > 0 => cumulative / gen_len as f64,
1087            _ => f64::NEG_INFINITY,
1088        }
1089    }
1090}
1091
1092/// Select the best candidates based on cumulative_logprob / generation_len.
1093/// Sorts in-place and truncates to avoid allocations.
1094fn select_best_candidates(
1095    mut candidates: Vec<CandidateState>,
1096    fanout: usize,
1097    final_target: usize,
1098) -> Vec<CandidateState> {
1099    let final_target = final_target.min(candidates.len()).max(1);
1100
1101    if final_target >= fanout {
1102        return candidates;
1103    }
1104
1105    // Sort in-place by score descending
1106    candidates.sort_by(|a, b| {
1107        b.score()
1108            .partial_cmp(&a.score())
1109            .unwrap_or(std::cmp::Ordering::Equal)
1110    });
1111
1112    candidates.truncate(final_target);
1113    candidates
1114}
1115
1116/// Build a ClientResponse from selected candidates, using completion tokens from the full best_of fan-out.
1117fn build_response_from_candidates(
1118    candidates: Vec<CandidateState>,
1119    total_completion_tokens: u32,
1120) -> ClientResponse {
1121    let prompt_tokens = candidates
1122        .iter()
1123        .map(|c| c.prompt_tokens)
1124        .max()
1125        .unwrap_or(0);
1126
1127    let capacity: usize = candidates.iter().map(|c| c.deltas.len()).sum();
1128    let mut all_deltas = Vec::with_capacity(capacity);
1129    let mut text = String::new();
1130    let mut finish_reason = None;
1131
1132    for candidate in candidates {
1133        text.push_str(&candidate.content);
1134        if candidate.finish_reason.is_some() {
1135            finish_reason = candidate.finish_reason;
1136        }
1137        all_deltas.extend(candidate.deltas);
1138    }
1139
1140    ClientResponse {
1141        text,
1142        finish_reason,
1143        usage: UsageStats {
1144            prompt_tokens,
1145            completion_tokens: total_completion_tokens,
1146            total_tokens: prompt_tokens + total_completion_tokens,
1147        },
1148        deltas: all_deltas,
1149    }
1150}
1151
1152/// Aggregate deltas into a complete response.
1153fn aggregate_response(deltas: Vec<ClientDelta>) -> ClientResponse {
1154    let text: String = deltas
1155        .iter()
1156        .filter_map(|d| d.content.as_ref())
1157        .cloned()
1158        .collect();
1159
1160    let finish_reason = deltas
1161        .iter()
1162        .rev()
1163        .find_map(|d| d.finish_reason.as_ref())
1164        .cloned();
1165
1166    let usage = extract_usage(&deltas);
1167
1168    ClientResponse {
1169        text,
1170        finish_reason,
1171        usage,
1172        deltas,
1173    }
1174}
1175
1176fn extract_usage(deltas: &[ClientDelta]) -> UsageStats {
1177    let mut usage = UsageStats::default();
1178
1179    for delta in deltas {
1180        if let Some(count) = delta.prompt_token_count {
1181            usage.prompt_tokens = usage.prompt_tokens.max(count);
1182        }
1183        if let Some(len) = delta.generation_len {
1184            usage.completion_tokens = usage.completion_tokens.max(len);
1185        }
1186    }
1187
1188    usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
1189    usage
1190}
1191
#[cfg(test)]
mod tests {
    use super::*;

    // Pins every default sampling parameter so accidental changes are caught.
    #[test]
    fn test_sampling_params_default() {
        let params = SamplingParams::default();
        assert_eq!(params.max_tokens, 1024);
        assert_eq!(params.temperature, 1.0);
        assert_eq!(params.top_p, 1.0);
        assert_eq!(params.top_k, -1);
        assert_eq!(params.repetition_context_size, 60);
        assert_eq!(params.top_logprobs, 0);
        assert!(params.logit_bias.is_empty());
        assert!(params.tools.is_empty());
        assert!(params.response_format.is_none());
        assert!(!params.reasoning);
        assert!(params.reasoning_effort.is_none());
        assert!(params.instructions.is_none());
    }

    // Content should concatenate in order and the last finish reason wins.
    #[test]
    fn test_aggregate_response() {
        let deltas = vec![
            ClientDelta {
                content: Some("Hello".to_string()),
                is_final: false,
                ..Default::default()
            },
            ClientDelta {
                content: Some(" World".to_string()),
                is_final: true,
                finish_reason: Some("stop".to_string()),
                ..Default::default()
            },
        ];

        let response = aggregate_response(deltas);
        assert_eq!(response.text, "Hello World");
        assert_eq!(response.finish_reason, Some("stop".to_string()));
    }

    // Embedding payloads carry a single text segment sized in bytes and
    // request no generated tokens.
    #[test]
    fn test_build_embedding_prompt_payload() {
        let payload = build_embedding_prompt_payload("hello".to_string());

        assert_eq!(payload.prompt, "hello");
        assert_eq!(payload.max_generated_tokens, 0);
        assert_eq!(payload.layout.len(), 1);
        assert_eq!(payload.layout[0].segment_type, "text");
        assert_eq!(payload.layout[0].length, 5);
        assert_eq!(payload.num_candidates, 1);
        assert_eq!(payload.best_of, Some(1));
        assert_eq!(payload.final_candidates, Some(1));
    }

    #[test]
    fn test_decode_embedding_bytes() {
        let bytes = [
            0.0f32.to_le_bytes(),
            1.5f32.to_le_bytes(),
            (-2.25f32).to_le_bytes(),
        ]
        .concat();

        let embedding = decode_embedding_bytes(&bytes).expect("embedding should decode");
        assert_eq!(embedding, vec![0.0, 1.5, -2.25]);
    }

    // A 3-byte buffer is not a whole number of f32s and must be rejected.
    #[test]
    fn test_decode_embedding_bytes_rejects_partial_float() {
        let error = decode_embedding_bytes(&[0, 0, 128]).expect_err("decode should fail");
        assert!(matches!(error, ClientError::RequestFailed(_)));
    }

    // STT payloads pair an empty text segment with one audio capability
    // whose length is 4 bytes per sample.
    #[test]
    fn test_build_stt_prompt_payload() {
        let payload = build_stt_prompt_payload(&[0.25, -0.5]);

        assert!(payload.prompt.is_empty());
        assert_eq!(payload.capabilities.len(), 1);
        assert_eq!(payload.capabilities[0].name, "audio");
        assert_eq!(payload.capabilities[0].payload.len(), 8);
        assert_eq!(payload.layout.len(), 2);
        assert_eq!(payload.layout[0].segment_type, "text");
        assert_eq!(payload.layout[0].length, 0);
        assert_eq!(payload.layout[1].segment_type, "capability");
        assert_eq!(payload.layout[1].length, 8);
    }

    // Round-trip: encode then decode must reproduce the original samples.
    #[test]
    fn test_encode_float32_pcm_bytes() {
        let bytes = encode_float32_pcm_bytes(&[0.0, 1.5, -2.25]);
        let decoded = decode_embedding_bytes(&bytes).expect("audio bytes should decode");
        assert_eq!(decoded, vec![0.0, 1.5, -2.25]);
    }

    // Usage must report the fan-out-wide completion total, not the winning
    // candidate's own count.
    #[test]
    fn test_build_response_from_candidates_uses_total_completion_tokens() {
        let response = build_response_from_candidates(
            vec![CandidateState {
                content: "winner".to_string(),
                finish_reason: Some("stop".to_string()),
                completion_tokens: 2,
                prompt_tokens: 5,
                deltas: vec![ClientDelta {
                    content: Some("winner".to_string()),
                    ..Default::default()
                }],
                ..Default::default()
            }],
            /*total_completion_tokens=*/ 7,
        );

        assert_eq!(response.text, "winner");
        assert_eq!(response.finish_reason, Some("stop".to_string()));
        assert_eq!(response.usage.prompt_tokens, 5);
        assert_eq!(response.usage.completion_tokens, 7);
        assert_eq!(response.usage.total_tokens, 12);
    }
}