mod moondream;
mod response;
mod responses;

use std::collections::HashMap;
use std::sync::{Arc, OnceLock};

use rand::Rng;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use tokio::sync::mpsc;

static SYNC_RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();

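/// Returns a lazily-created, single-threaded Tokio runtime used by the
/// blocking wrappers below when no async runtime is already running.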
fn get_sync_runtime() -> &'static tokio::runtime::Runtime {
    SYNC_RUNTIME.get_or_init(|| {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to create sync runtime")
    })
}

use crate::formatter::multimodal::{
    build_multimodal_layout, build_multimodal_messages, CapabilityInput, LayoutSegment,
};
use crate::ipc::client::{EventCallback, IPCClient, ResponseDelta};
use crate::ipc::serialization::{CapabilityEntry, LayoutEntry, PromptPayload, RequestType};
use crate::model::registry::ModelRegistry;

pub use moondream::{
    BoundingBox, CaptionResult, DetectResult, DetectedObject, GazeResult, GroundingSpan,
    MoondreamClient, Point, PointResult, QueryResult, ReasoningOutput, SpatialRef,
    MOONDREAM_MODEL_ID,
};
pub use response::{BatchChatResult, ClientDelta, ClientResponse, UsageStats};
pub use responses::{
    ContentPartAddedEvent, ContentPartDoneEvent, FunctionCallArgumentsDeltaEvent,
    FunctionCallArgumentsDoneEvent, IncompleteDetails, InputTokensDetails, OutputFunctionCall,
    OutputItemAddedEvent, OutputItemDoneEvent, OutputMessage, OutputReasoning, OutputStatus,
    OutputTextContent, OutputTextDeltaEvent, OutputTextDoneEvent, OutputTokensDetails,
    ReasoningContent, ReasoningDeltaEvent, ReasoningDoneEvent, ReasoningSummaryTextContent,
    ReasoningSummaryTextDeltaEvent, ReasoningSummaryTextDoneEvent, ResponseCompletedEvent,
    ResponseCreatedEvent, ResponseError, ResponseEvent, ResponseFailedEvent,
    ResponseInProgressEvent, ResponseIncompleteEvent, ResponseInputItem, ResponseObject,
    ResponseOutputItem, ResponseSnapshot, ResponseUsage, ResponsesInput, ResponsesRequest,
    ResponsesResult, StreamErrorDetail, StreamErrorEvent,
};

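/// Errors surfaced to callers of the client API, grouping internal error
/// variants into coarse categories suitable for reporting upstream.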
#[derive(Error, Debug)]
pub enum ClientError {
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    #[error("{0}")]
    ModelNotReady(String),

    #[error("{0}")]
    Ipc(String),

    #[error("{0}")]
    Formatter(String),

    #[error("{0}")]
    Multimodal(String),

    #[error("{0}")]
    RequestFailed(String),
}

impl From<crate::error::Error> for ClientError {
    fn from(err: crate::error::Error) -> Self {
        use crate::error::Error;
        match err {
            Error::ModelNotFound(s) => ClientError::ModelNotFound(s),
            Error::ModelNotReady(s) => ClientError::ModelNotReady(s),
            Error::NotConnected
            | Error::InvalidResponse
            | Error::Nng(_)
            | Error::Timeout
            | Error::ChannelClosed => ClientError::Ipc(err.to_string()),
            Error::Template(s) => ClientError::Formatter(s),
            Error::InvalidImageUrl
            | Error::InvalidBase64
            | Error::MissingContentType(_, _)
            | Error::InvalidContent
            | Error::PlaceholderMismatch(_, _)
            | Error::EmptyRequest => ClientError::Multimodal(err.to_string()),
            _ => ClientError::RequestFailed(err.to_string()),
        }
    }
}

pub type Result<T> = std::result::Result<T, ClientError>;

use crate::defaults;

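/// Sampling and decoding parameters for chat-style requests. Every field has
/// a serde default, so callers can deserialize sparse JSON or start from
/// `Default::default()` and override selectively.
///
/// A minimal sketch of overriding a few fields (the import path is assumed
/// and the values are illustrative, not recommendations):
///
/// ```ignore
/// use crate::client::SamplingParams; // hypothetical path
///
/// let params = SamplingParams {
///     max_tokens: 256,
///     temperature: 0.7,
///     stop: vec!["</s>".to_string()],
///     ..Default::default()
/// };
/// ```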
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingParams {
    #[serde(default = "defaults::max_tokens")]
    pub max_tokens: i32,
    #[serde(default = "defaults::temperature")]
    pub temperature: f64,
    #[serde(default = "defaults::top_p")]
    pub top_p: f64,
    #[serde(default = "defaults::top_k")]
    pub top_k: i32,
    #[serde(default)]
    pub min_p: f64,
    #[serde(default)]
    pub rng_seed: u64,
    #[serde(default)]
    pub stop: Vec<String>,
    #[serde(default)]
    pub frequency_penalty: f64,
    #[serde(default)]
    pub presence_penalty: f64,
    #[serde(default = "defaults::repetition_penalty")]
    pub repetition_penalty: f64,
    #[serde(default = "defaults::repetition_context_size")]
    pub repetition_context_size: i32,
    #[serde(default = "defaults::num_candidates")]
    pub n: i32,
    #[serde(default)]
    pub best_of: Option<i32>,
    #[serde(default)]
    pub final_candidates: Option<i32>,
    #[serde(default)]
    pub top_logprobs: i32,
    #[serde(default)]
    pub logit_bias: HashMap<i32, f64>,
    #[serde(default)]
    pub tools: Vec<serde_json::Value>,
    #[serde(default)]
    pub tool_choice: Option<serde_json::Value>,
    #[serde(default)]
    pub max_tool_calls: Option<i32>,
    #[serde(default)]
    pub response_format: Option<serde_json::Value>,
    #[serde(default)]
    pub reasoning: bool,
    #[serde(default)]
    pub reasoning_effort: Option<String>,
    #[serde(default)]
    pub instructions: Option<String>,
    #[serde(default)]
    pub task_name: Option<String>,
}

impl Default for SamplingParams {
    fn default() -> Self {
        Self {
            max_tokens: defaults::MAX_TOKENS,
            temperature: defaults::TEMPERATURE,
            top_p: defaults::TOP_P,
            top_k: defaults::TOP_K,
            min_p: 0.0,
            rng_seed: 0,
            stop: Vec::new(),
            frequency_penalty: 0.0,
            presence_penalty: 0.0,
            repetition_penalty: defaults::REPETITION_PENALTY,
            repetition_context_size: defaults::REPETITION_CONTEXT_SIZE,
            n: defaults::NUM_CANDIDATES,
            best_of: None,
            final_candidates: None,
            top_logprobs: 0,
            logit_bias: HashMap::new(),
            tools: Vec::new(),
            tool_choice: None,
            max_tool_calls: None,
            response_format: None,
            reasoning: false,
            reasoning_effort: None,
            instructions: None,
            task_name: None,
        }
    }
}

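/// Normalizes the OpenAI-style `tool_choice` value into the string form the
/// backend expects: absent/null becomes "auto", strings pass through, and
/// object forms are serialized to JSON.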
fn tool_choice_to_string(tool_choice: Option<&Value>) -> String {
    match tool_choice {
        None | Some(Value::Null) => "auto".to_string(),
        Some(Value::String(value)) => value.clone(),
        Some(Value::Object(value)) => serde_json::to_string(value).unwrap_or_default(),
        Some(other) => other.to_string(),
    }
}

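/// High-level client that formats prompts, dispatches requests over IPC, and
/// aggregates streamed deltas into responses.
///
/// A minimal usage sketch, assuming a Tokio runtime; the registry
/// construction and model id are illustrative:
///
/// ```ignore
/// use std::sync::Arc;
///
/// let registry = Arc::new(/* construct a ModelRegistry */);
/// let client = Client::connect(registry).await?;
/// let messages = vec![/* chat messages as HashMap<String, serde_json::Value> */];
/// let response = client.chat("my-model", messages, SamplingParams::default())?;
/// println!("{}", response.text);
/// ```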
pub struct Client {
    ipc: Arc<IPCClient>,
    registry: Arc<ModelRegistry>,
}

impl Client {
    pub fn new(ipc: Arc<IPCClient>, registry: Arc<ModelRegistry>) -> Self {
        Self { ipc, registry }
    }

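    /// Connects to the IPC backend and registers an event callback that
    /// forwards `model_loaded` / `model_load_failed` notifications to the
    /// registry on the current Tokio runtime.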
    pub async fn connect(registry: Arc<ModelRegistry>) -> Result<Self> {
        let registry_for_events = Arc::clone(&registry);
        let runtime_handle = tokio::runtime::Handle::current();
        let event_callback: EventCallback =
            Arc::new(move |event_name: &str, payload: &Value| match event_name {
                "model_loaded" => {
                    let registry = Arc::clone(&registry_for_events);
                    let payload = payload.clone();
                    let handle = runtime_handle.clone();
                    handle.spawn(async move {
                        registry.handle_model_loaded(&payload).await;
                    });
                }
                "model_load_failed" => {
                    let registry = Arc::clone(&registry_for_events);
                    let payload = payload.clone();
                    let handle = runtime_handle.clone();
                    handle.spawn(async move {
                        registry.handle_model_load_failed(&payload).await;
                    });
                }
                _ => {}
            });

        let mut ipc = IPCClient::with_event_callback(event_callback);
        ipc.connect()?;
        let ipc = Arc::new(ipc);

        registry.set_ipc_client(Arc::clone(&ipc)).await;

        Ok(Self { ipc, registry })
    }

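    /// Ensures the model is loaded and maps each capability name to the first
    /// token id registered for it.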
    pub async fn resolve_capabilities(&self, model_id: &str) -> Result<HashMap<String, i32>> {
        let info = self.registry.ensure_loaded(model_id).await?;

        let capabilities = info.capabilities.as_ref().cloned().unwrap_or_default();
        let mut resolved = HashMap::new();

        for (name, token_ids) in capabilities {
            if let Some(&first) = token_ids.first() {
                resolved.insert(name, first);
            }
        }

        Ok(resolved)
    }

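    /// Sends a single chat request. With `stream: true` the raw delta channel
    /// is returned; otherwise deltas are collected per candidate, the best
    /// candidates are selected by length-normalized log-probability, and a
    /// complete response is built.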
    pub async fn achat(
        &self,
        model_id: &str,
        messages: Vec<HashMap<String, serde_json::Value>>,
        params: SamplingParams,
        stream: bool,
    ) -> Result<ChatResult> {
        let info = self.registry.ensure_loaded(model_id).await?;
        let formatter = info.require_formatter()?;

        let request_id = self.ipc.next_request_id();
        tracing::debug!(
            request_id,
            model_id = %model_id,
            stream,
            message_count = messages.len(),
            "Building chat request"
        );
        tracing::trace!(
            request_id,
            model_id = %model_id,
            messages = ?messages,
            "Chat messages before template application"
        );

        let reasoning_flag = params.reasoning || params.reasoning_effort.is_some();

        let (messages_for_template, image_buffers, capabilities, content_order) =
            build_multimodal_messages(formatter, &messages, params.instructions.as_deref())
                .map_err(|e| ClientError::Multimodal(e.to_string()))?;

        if messages_for_template.is_empty() {
            return Err(ClientError::RequestFailed(
                "Chat request must include at least one message".into(),
            ));
        }
        tracing::trace!(
            request_id,
            model_id = %model_id,
            messages_for_template = ?messages_for_template,
            "Chat messages after multimodal expansion"
        );

        let prompt_text = formatter
            .apply_template(
                &messages_for_template,
                true,
                reasoning_flag,
                params.task_name.as_deref(),
            )
            .map_err(|e| ClientError::Formatter(e.to_string()))?;

        let capability_placeholder = formatter.capability_placeholder_token();

        let layout_segments = build_multimodal_layout(
            &prompt_text,
            &image_buffers,
            &capabilities,
            &content_order,
            formatter.image_placeholder_token(),
            formatter.should_clip_image_placeholder(),
            capability_placeholder,
        )
        .map_err(|e| ClientError::Multimodal(e.to_string()))?;

        let final_prompt = formatter.strip_template_placeholders(&prompt_text);
        tracing::debug!(
            request_id,
            model_id = %model_id,
            prompt_chars = final_prompt.chars().count(),
            image_count = image_buffers.len(),
            capability_count = capabilities.len(),
            layout_segment_count = layout_segments.len(),
            "Prepared chat prompt payload"
        );
        tracing::trace!(
            request_id,
            model_id = %model_id,
            prompt = %final_prompt,
            "Chat prompt sent to PIE"
        );

        let tool_schemas_json = if params.tools.is_empty() {
            String::new()
        } else {
            serde_json::to_string(&params.tools).unwrap_or_default()
        };
        let response_format_json = params
            .response_format
            .as_ref()
            .map(|rf| serde_json::to_string(rf).unwrap_or_default())
            .unwrap_or_default();
        let tool_calling_tokens = formatter.get_tool_calling_tokens().clone();
        let tool_choice = tool_choice_to_string(params.tool_choice.as_ref());
        let max_tool_calls = params.max_tool_calls.unwrap_or(0).max(0);

        let rng_seed = if params.rng_seed == 0 {
            rand::thread_rng().gen::<u64>()
        } else {
            params.rng_seed
        };

        let prompt_payload = PromptPayload {
            prompt: final_prompt,
            image_buffers,
            capabilities: convert_capabilities(&capabilities),
            layout: convert_layout(&layout_segments),
            max_generated_tokens: params.max_tokens,
            temperature: params.temperature,
            top_p: params.top_p,
            top_k: params.top_k,
            min_p: params.min_p,
            rng_seed,
            stop_sequences: params.stop.clone(),
            num_candidates: params.n,
            best_of: params.best_of,
            final_candidates: params.final_candidates,
            frequency_penalty: params.frequency_penalty,
            presence_penalty: params.presence_penalty,
            repetition_penalty: params.repetition_penalty,
            repetition_context_size: params.repetition_context_size,
            top_logprobs: params.top_logprobs,
            logit_bias: params.logit_bias.clone(),
            tool_schemas_json,
            tool_calling_tokens,
            tool_choice,
            max_tool_calls,
            response_format_json,
            task_name: params.task_name.clone(),
            reasoning_effort: params.reasoning_effort.clone(),
        };

        tracing::debug!(
            request_id,
            model_id = %model_id,
            stream,
            "Dispatching chat request to PIE"
        );
        let (_batch_size, rx) = self.ipc.send_batch_request(
            request_id,
            model_id,
            &info.model_path,
            &[prompt_payload],
        )?;

        if stream {
            Ok(ChatResult::Stream(rx))
        } else {
            let best_of = params.best_of.unwrap_or(params.n).max(1) as usize;
            let final_candidates = params.final_candidates.unwrap_or(params.n).max(1) as usize;

            let mut candidate_states: Vec<CandidateState> =
                (0..best_of).map(|_| CandidateState::default()).collect();
            let mut remaining_sequences = best_of;
            let mut rx = rx;

            while remaining_sequences > 0 {
                match rx.recv().await {
                    Some(delta) => {
                        let candidate_index = delta.candidate_index.unwrap_or(0) as usize;
                        if candidate_index >= candidate_states.len() {
                            continue;
                        }

                        let state = &mut candidate_states[candidate_index];

                        if let Some(content) = &delta.content {
                            state.content.push_str(content);
                        }
                        state.completion_tokens += delta.tokens.len() as u32;
                        if let Some(count) = delta.prompt_token_count {
                            state.prompt_tokens = state.prompt_tokens.max(count);
                        }

                        let client_delta = ClientDelta::from(delta.clone());
                        state.deltas.push(client_delta);

                        if delta.is_final_delta && !state.completed {
                            state.completed = true;
                            state.finish_reason = delta.finish_reason.clone();
                            state.cumulative_logprob = delta.cumulative_logprob;
                            state.generation_len = delta.generation_len;
                            remaining_sequences -= 1;
                        }
                    }
                    None => break,
                }
            }

            let total_completion_tokens: u32 =
                candidate_states.iter().map(|c| c.completion_tokens).sum();

            let selected = select_best_candidates(candidate_states, best_of, final_candidates);

            Ok(ChatResult::Complete(build_response_from_candidates(
                selected,
                total_completion_tokens,
            )))
        }
    }

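    /// Blocking wrapper around [`achat`](Self::achat) for non-streaming use.
    /// Runs on the current Tokio runtime via `block_in_place` when one exists,
    /// otherwise on the lazily-created sync runtime.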
    pub fn chat(
        &self,
        model_id: &str,
        messages: Vec<HashMap<String, serde_json::Value>>,
        params: SamplingParams,
    ) -> Result<ClientResponse> {
        let future = async {
            match self.achat(model_id, messages, params, false).await? {
                ChatResult::Complete(response) => Ok(response),
                ChatResult::Stream(_) => Err(ClientError::RequestFailed(
                    "Unexpected stream result".into(),
                )),
            }
        };

        match tokio::runtime::Handle::try_current() {
            Ok(handle) => tokio::task::block_in_place(|| handle.block_on(future)),
            Err(_) => get_sync_runtime().block_on(future),
        }
    }

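    /// Batched variant of [`achat`](Self::achat): formats one prompt payload
    /// per conversation and dispatches them as a single request. Streaming
    /// returns a channel of `ClientDelta`s tagged with their prompt index;
    /// otherwise deltas are grouped by prompt and aggregated into one
    /// response per conversation.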
    pub async fn achat_batch(
        &self,
        model_id: &str,
        conversations: Vec<Vec<HashMap<String, serde_json::Value>>>,
        params: SamplingParams,
        stream: bool,
    ) -> Result<BatchChatResult> {
        if conversations.is_empty() {
            return Ok(BatchChatResult::Complete(Vec::new()));
        }

        let info = self.registry.ensure_loaded(model_id).await?;
        let formatter = info.require_formatter()?;

        let request_id = self.ipc.next_request_id();
        let num_prompts = conversations.len();
        tracing::debug!(
            request_id,
            model_id = %model_id,
            stream,
            prompt_count = num_prompts,
            "Building batched chat request"
        );

        let reasoning_flag = params.reasoning || params.reasoning_effort.is_some();

        let tool_schemas_json = if params.tools.is_empty() {
            String::new()
        } else {
            serde_json::to_string(&params.tools).unwrap_or_default()
        };
        let response_format_json = params
            .response_format
            .as_ref()
            .map(|rf| serde_json::to_string(rf).unwrap_or_default())
            .unwrap_or_default();
        let tool_calling_tokens = formatter.get_tool_calling_tokens().clone();
        let tool_choice = tool_choice_to_string(params.tool_choice.as_ref());
        let max_tool_calls = params.max_tool_calls.unwrap_or(0).max(0);

        let mut prompt_payloads = Vec::with_capacity(num_prompts);

        for (prompt_index, messages) in conversations.iter().enumerate() {
            let (messages_for_template, image_buffers, capabilities, content_order) =
                build_multimodal_messages(formatter, messages, params.instructions.as_deref())
                    .map_err(|e| ClientError::Multimodal(e.to_string()))?;

            if messages_for_template.is_empty() {
                return Err(ClientError::RequestFailed(
                    "Chat request must include at least one message".into(),
                ));
            }
            tracing::trace!(
                request_id,
                model_id = %model_id,
                prompt_index,
                messages = ?messages,
                messages_for_template = ?messages_for_template,
                "Prepared batch messages for prompt"
            );

            let prompt_text = formatter
                .apply_template(
                    &messages_for_template,
                    true,
                    reasoning_flag,
                    params.task_name.as_deref(),
                )
                .map_err(|e| ClientError::Formatter(e.to_string()))?;

            let capability_placeholder = formatter.capability_placeholder_token();

            let layout_segments = build_multimodal_layout(
                &prompt_text,
                &image_buffers,
                &capabilities,
                &content_order,
                formatter.image_placeholder_token(),
                formatter.should_clip_image_placeholder(),
                capability_placeholder,
            )
            .map_err(|e| ClientError::Multimodal(e.to_string()))?;

            let final_prompt = formatter.strip_template_placeholders(&prompt_text);
            tracing::debug!(
                request_id,
                model_id = %model_id,
                prompt_index,
                prompt_chars = final_prompt.chars().count(),
                image_count = image_buffers.len(),
                capability_count = capabilities.len(),
                layout_segment_count = layout_segments.len(),
                "Prepared batched prompt payload"
            );
            tracing::trace!(
                request_id,
                model_id = %model_id,
                prompt_index,
                prompt = %final_prompt,
                "Batch prompt sent to PIE"
            );

            let rng_seed = if params.rng_seed == 0 {
                rand::thread_rng().gen::<u64>()
            } else {
                params.rng_seed
            };

            prompt_payloads.push(PromptPayload {
                prompt: final_prompt,
                image_buffers,
                capabilities: convert_capabilities(&capabilities),
                layout: convert_layout(&layout_segments),
                max_generated_tokens: params.max_tokens,
                temperature: params.temperature,
                top_p: params.top_p,
                top_k: params.top_k,
                min_p: params.min_p,
                rng_seed,
                stop_sequences: params.stop.clone(),
                num_candidates: params.n,
                best_of: params.best_of,
                final_candidates: params.final_candidates,
                frequency_penalty: params.frequency_penalty,
                presence_penalty: params.presence_penalty,
                repetition_penalty: params.repetition_penalty,
                repetition_context_size: params.repetition_context_size,
                top_logprobs: params.top_logprobs,
                logit_bias: params.logit_bias.clone(),
                tool_schemas_json: tool_schemas_json.clone(),
                tool_calling_tokens: tool_calling_tokens.clone(),
                tool_choice: tool_choice.clone(),
                max_tool_calls,
                response_format_json: response_format_json.clone(),
                task_name: params.task_name.clone(),
                reasoning_effort: params.reasoning_effort.clone(),
            });
        }

        tracing::debug!(
            request_id,
            model_id = %model_id,
            stream,
            prompt_count = prompt_payloads.len(),
            "Dispatching batched chat request to PIE"
        );
        let (_batch_size, rx) = self.ipc.send_batch_request(
            request_id,
            model_id,
            &info.model_path,
            &prompt_payloads,
        )?;

        if stream {
            let (tx, client_rx) = mpsc::channel(256);
            tokio::spawn(async move {
                let mut rx = rx;
                while let Some(delta) = rx.recv().await {
                    if tx.send(ClientDelta::from(delta)).await.is_err() {
                        break;
                    }
                }
            });
            return Ok(BatchChatResult::Stream(client_rx));
        }

        let mut deltas_by_prompt: HashMap<u32, Vec<ClientDelta>> = HashMap::new();
        let mut finals_received = 0usize;
        let mut rx = rx;

        while finals_received < num_prompts {
            match rx.recv().await {
                Some(delta) => {
                    let prompt_index = delta.prompt_index.unwrap_or(0);
                    let is_final = delta.is_final_delta;

                    deltas_by_prompt
                        .entry(prompt_index)
                        .or_default()
                        .push(ClientDelta::from(delta));

                    if is_final {
                        finals_received += 1;
                    }
                }
                None => break,
            }
        }

        let mut responses = Vec::with_capacity(num_prompts);
        for idx in 0..num_prompts {
            let deltas = deltas_by_prompt.remove(&(idx as u32)).unwrap_or_default();
            responses.push(aggregate_response(deltas));
        }

        Ok(BatchChatResult::Complete(responses))
    }

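    /// Embeds a single text by delegating to [`aembed_batch`](Self::aembed_batch)
    /// and returning the sole embedding.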
    pub async fn aembed(&self, model_id: &str, text: &str) -> Result<Vec<f32>> {
        let mut embeddings = self.aembed_batch(model_id, vec![text.to_string()]).await?;
        embeddings.pop().ok_or_else(|| {
            ClientError::RequestFailed("Embedding response missing result".to_string())
        })
    }

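    /// Embeds a batch of texts in one request, returning one embedding per
    /// input in the original order.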
    pub async fn aembed_batch(&self, model_id: &str, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let info = self.registry.ensure_loaded(model_id).await?;

        let request_id = self.ipc.next_request_id();
        tracing::debug!(
            request_id,
            model_id = %model_id,
            prompt_count = texts.len(),
            "Building batched embedding request"
        );

        let mut prompt_payloads = Vec::with_capacity(texts.len());
        for (prompt_index, text) in texts.into_iter().enumerate() {
            let prompt_chars = text.chars().count();
            tracing::debug!(
                request_id,
                model_id = %model_id,
                prompt_index,
                prompt_chars,
                "Prepared embedding prompt payload"
            );
            tracing::trace!(
                request_id,
                model_id = %model_id,
                prompt_index,
                prompt = %text,
                "Embedding prompt sent to PIE"
            );

            prompt_payloads.push(build_embedding_prompt_payload(text));
        }

        tracing::debug!(
            request_id,
            model_id = %model_id,
            prompt_count = prompt_payloads.len(),
            "Dispatching batched embedding request to PIE"
        );
        let (_batch_size, rx) = self.ipc.send_batch_request_with_type(
            request_id,
            model_id,
            &info.model_path,
            RequestType::Embedding,
            &prompt_payloads,
        )?;

        collect_embeddings(rx, prompt_payloads.len()).await
    }

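    /// Transcribes a buffer of 32-bit float PCM samples via a speech-to-text
    /// capable model, returning the transcribed text.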
    pub async fn atranscribe_audio(&self, model_id: &str, pcm: &[f32]) -> Result<String> {
        if pcm.is_empty() {
            return Ok(String::new());
        }

        let info = self.registry.ensure_loaded(model_id).await?;

        let request_id = self.ipc.next_request_id();
        tracing::debug!(
            request_id,
            model_id = %model_id,
            sample_count = pcm.len(),
            "Building speech-to-text request"
        );

        let prompt_payload = build_stt_prompt_payload(pcm);

        tracing::debug!(
            request_id,
            model_id = %model_id,
            payload_bytes = prompt_payload.capabilities[0].payload.len(),
            "Dispatching speech-to-text request to PIE"
        );
        let (_batch_size, rx) = self.ipc.send_batch_request_with_type(
            request_id,
            model_id,
            &info.model_path,
            RequestType::Omni,
            &[prompt_payload],
        )?;

        collect_transcription(rx).await
    }

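    /// Blocking wrapper around [`atranscribe_audio`](Self::atranscribe_audio),
    /// mirroring the runtime-selection logic of [`chat`](Self::chat).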
    pub fn transcribe_audio(&self, model_id: &str, pcm: &[f32]) -> Result<String> {
        let model_id = model_id.to_string();
        let pcm = pcm.to_vec();
        let future = async move { self.atranscribe_audio(&model_id, &pcm).await };

        match tokio::runtime::Handle::try_current() {
            Ok(handle) => tokio::task::block_in_place(|| handle.block_on(future)),
            Err(_) => get_sync_runtime().block_on(future),
        }
    }
}

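/// Converts formatter capability inputs into IPC capability entries.
/// `position` is always sent as 0 here; placement within the prompt is
/// presumably carried by the layout entries instead.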
fn convert_capabilities(capabilities: &[CapabilityInput]) -> Vec<CapabilityEntry> {
    capabilities
        .iter()
        .map(|cap| CapabilityEntry {
            name: cap.name.clone(),
            position: 0,
            payload: cap.payload.clone(),
        })
        .collect()
}

fn convert_layout(segments: &[LayoutSegment]) -> Vec<LayoutEntry> {
    segments
        .iter()
        .map(|seg| LayoutEntry {
            segment_type: seg.segment_type.clone(),
            length: seg.length,
        })
        .collect()
}

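/// Builds a generation-free prompt payload for embedding requests: a single
/// text layout segment with `max_generated_tokens` set to 0.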
fn build_embedding_prompt_payload(prompt: String) -> PromptPayload {
    let prompt_len = prompt.len();

    PromptPayload {
        prompt,
        image_buffers: Vec::new(),
        capabilities: Vec::new(),
        layout: vec![LayoutEntry {
            segment_type: "text".to_string(),
            length: prompt_len,
        }],
        max_generated_tokens: 0,
        temperature: defaults::TEMPERATURE,
        top_p: defaults::TOP_P,
        top_k: defaults::TOP_K,
        min_p: 0.0,
        rng_seed: rand::thread_rng().gen::<u64>(),
        stop_sequences: Vec::new(),
        num_candidates: 1,
        best_of: Some(1),
        final_candidates: Some(1),
        frequency_penalty: 0.0,
        presence_penalty: 0.0,
        repetition_penalty: defaults::REPETITION_PENALTY,
        repetition_context_size: 0,
        top_logprobs: 0,
        logit_bias: HashMap::new(),
        tool_schemas_json: String::new(),
        tool_calling_tokens: Default::default(),
        tool_choice: "auto".to_string(),
        max_tool_calls: 0,
        response_format_json: String::new(),
        task_name: None,
        reasoning_effort: None,
    }
}

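/// Builds a prompt payload for speech-to-text: no text, one "audio"
/// capability carrying little-endian f32 PCM bytes, and a layout describing
/// an empty text segment followed by the audio capability segment.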
fn build_stt_prompt_payload(pcm: &[f32]) -> PromptPayload {
    let audio_payload = encode_float32_pcm_bytes(pcm);
    let audio_payload_size = audio_payload.len();

    PromptPayload {
        prompt: String::new(),
        image_buffers: Vec::new(),
        capabilities: vec![CapabilityEntry {
            name: "audio".to_string(),
            position: 0,
            payload: audio_payload,
        }],
        layout: vec![
            LayoutEntry {
                segment_type: "text".to_string(),
                length: 0,
            },
            LayoutEntry {
                segment_type: "capability".to_string(),
                length: audio_payload_size,
            },
        ],
        max_generated_tokens: 0,
        temperature: defaults::TEMPERATURE,
        top_p: defaults::TOP_P,
        top_k: defaults::TOP_K,
        min_p: 0.0,
        rng_seed: rand::thread_rng().gen::<u64>(),
        stop_sequences: Vec::new(),
        num_candidates: 1,
        best_of: Some(1),
        final_candidates: Some(1),
        frequency_penalty: 0.0,
        presence_penalty: 0.0,
        repetition_penalty: defaults::REPETITION_PENALTY,
        repetition_context_size: 0,
        top_logprobs: 0,
        logit_bias: HashMap::new(),
        tool_schemas_json: String::new(),
        tool_calling_tokens: Default::default(),
        tool_choice: "auto".to_string(),
        max_tool_calls: 0,
        response_format_json: String::new(),
        task_name: None,
        reasoning_effort: None,
    }
}

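/// Serializes f32 PCM samples to little-endian bytes, 4 bytes per sample.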
fn encode_float32_pcm_bytes(pcm: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(std::mem::size_of_val(pcm));
    for sample in pcm {
        bytes.extend_from_slice(&sample.to_le_bytes());
    }
    bytes
}

async fn collect_embeddings(
    mut rx: mpsc::UnboundedReceiver<ResponseDelta>,
    prompt_count: usize,
) -> Result<Vec<Vec<f32>>> {
    let mut embeddings_by_prompt: Vec<Option<Vec<f32>>> = vec![None; prompt_count];
    let mut completed_prompts = vec![false; prompt_count];
    let mut finals_received = 0usize;

    while finals_received < prompt_count {
        match rx.recv().await {
            Some(delta) => {
                if let Some(error) = delta.error {
                    return Err(ClientError::RequestFailed(error));
                }

                let prompt_index = delta.prompt_index.unwrap_or(0) as usize;
                if prompt_index >= prompt_count {
                    continue;
                }

                if let Some(bytes) = delta.embedding_bytes.as_deref() {
                    embeddings_by_prompt[prompt_index] = Some(decode_embedding_bytes(bytes)?);
                }

                if delta.is_final_delta && !completed_prompts[prompt_index] {
                    completed_prompts[prompt_index] = true;
                    finals_received += 1;
                }
            }
            None => {
                return Err(ClientError::RequestFailed(
                    "Embedding response channel closed before completion".to_string(),
                ));
            }
        }
    }

    embeddings_by_prompt
        .into_iter()
        .enumerate()
        .map(|(prompt_index, embedding)| {
            embedding.ok_or_else(|| {
                ClientError::RequestFailed(format!(
                    "Embedding response missing bytes for prompt_index={}",
                    prompt_index
                ))
            })
        })
        .collect()
}

fn decode_embedding_bytes(bytes: &[u8]) -> Result<Vec<f32>> {
    let mut chunks = bytes.chunks_exact(std::mem::size_of::<f32>());
    if !chunks.remainder().is_empty() {
        return Err(ClientError::RequestFailed(format!(
            "Embedding payload length {} is not divisible by {}",
            bytes.len(),
            std::mem::size_of::<f32>()
        )));
    }

    Ok(chunks
        .by_ref()
        .map(|chunk| f32::from_le_bytes(chunk.try_into().expect("f32 chunk size")))
        .collect())
}

async fn collect_transcription(mut rx: mpsc::UnboundedReceiver<ResponseDelta>) -> Result<String> {
    let mut transcription = String::new();

    loop {
        match rx.recv().await {
            Some(delta) => {
                if let Some(error) = delta.error {
                    return Err(ClientError::RequestFailed(error));
                }

                if let Some(content) = delta.content {
                    transcription.push_str(&content);
                }

                if delta.is_final_delta {
                    return Ok(transcription);
                }
            }
            None => {
                return Err(ClientError::RequestFailed(
                    "Speech-to-text response channel closed before completion".to_string(),
                ));
            }
        }
    }
}

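/// Result of a chat request: either a fully aggregated response or, when
/// streaming, the receiver for raw response deltas.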
pub enum ChatResult {
    Complete(ClientResponse),
    Stream(mpsc::UnboundedReceiver<ResponseDelta>),
}

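/// Accumulated per-candidate state while draining a non-streaming chat
/// request.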
#[derive(Default)]
struct CandidateState {
    content: String,
    finish_reason: Option<String>,
    completion_tokens: u32,
    prompt_tokens: u32,
    cumulative_logprob: Option<f64>,
    generation_len: Option<u32>,
    completed: bool,
    deltas: Vec<ClientDelta>,
}

impl CandidateState {
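    /// Length-normalized score: cumulative log-probability divided by the
    /// generated length. Candidates missing either value sort last.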
    #[inline]
    fn score(&self) -> f64 {
        match (self.cumulative_logprob, self.generation_len) {
            (Some(cumulative), Some(gen_len)) if gen_len > 0 => cumulative / gen_len as f64,
            _ => f64::NEG_INFINITY,
        }
    }
}

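/// Keeps the `final_target` best of the `fanout` generated candidates, ranked
/// by length-normalized log-probability; returns all candidates untouched
/// when no narrowing is requested.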
fn select_best_candidates(
    mut candidates: Vec<CandidateState>,
    fanout: usize,
    final_target: usize,
) -> Vec<CandidateState> {
    let final_target = final_target.min(candidates.len()).max(1);

    if final_target >= fanout {
        return candidates;
    }

    candidates.sort_by(|a, b| {
        b.score()
            .partial_cmp(&a.score())
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    candidates.truncate(final_target);
    candidates
}

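/// Builds a single response from the selected candidates: texts are
/// concatenated in order, the last reported finish reason wins, and usage
/// counts completion tokens across all generated candidates.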
fn build_response_from_candidates(
    candidates: Vec<CandidateState>,
    total_completion_tokens: u32,
) -> ClientResponse {
    let prompt_tokens = candidates
        .iter()
        .map(|c| c.prompt_tokens)
        .max()
        .unwrap_or(0);

    let capacity: usize = candidates.iter().map(|c| c.deltas.len()).sum();
    let mut all_deltas = Vec::with_capacity(capacity);
    let mut text = String::new();
    let mut finish_reason = None;

    for candidate in candidates {
        text.push_str(&candidate.content);
        if candidate.finish_reason.is_some() {
            finish_reason = candidate.finish_reason;
        }
        all_deltas.extend(candidate.deltas);
    }

    ClientResponse {
        text,
        finish_reason,
        usage: UsageStats {
            prompt_tokens,
            completion_tokens: total_completion_tokens,
            total_tokens: prompt_tokens + total_completion_tokens,
        },
        deltas: all_deltas,
    }
}

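/// Aggregates one prompt's deltas into a response: concatenates content in
/// arrival order and takes the last finish reason seen.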
fn aggregate_response(deltas: Vec<ClientDelta>) -> ClientResponse {
    let text: String = deltas
        .iter()
        .filter_map(|d| d.content.as_ref())
        .cloned()
        .collect();

    let finish_reason = deltas
        .iter()
        .rev()
        .find_map(|d| d.finish_reason.as_ref())
        .cloned();

    let usage = extract_usage(&deltas);

    ClientResponse {
        text,
        finish_reason,
        usage,
        deltas,
    }
}

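/// Derives usage statistics from a delta stream, taking the maximum reported
/// prompt token count and generation length.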
fn extract_usage(deltas: &[ClientDelta]) -> UsageStats {
    let mut usage = UsageStats::default();

    for delta in deltas {
        if let Some(count) = delta.prompt_token_count {
            usage.prompt_tokens = usage.prompt_tokens.max(count);
        }
        if let Some(len) = delta.generation_len {
            usage.completion_tokens = usage.completion_tokens.max(len);
        }
    }

    usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
    usage
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sampling_params_default() {
        let params = SamplingParams::default();
        assert_eq!(params.max_tokens, 1024);
        assert_eq!(params.temperature, 1.0);
        assert_eq!(params.top_p, 1.0);
        assert_eq!(params.top_k, -1);
        assert_eq!(params.repetition_context_size, 60);
        assert_eq!(params.top_logprobs, 0);
        assert!(params.logit_bias.is_empty());
        assert!(params.tools.is_empty());
        assert!(params.response_format.is_none());
        assert!(!params.reasoning);
        assert!(params.reasoning_effort.is_none());
        assert!(params.instructions.is_none());
    }

    #[test]
    fn test_aggregate_response() {
        let deltas = vec![
            ClientDelta {
                content: Some("Hello".to_string()),
                is_final: false,
                ..Default::default()
            },
            ClientDelta {
                content: Some(" World".to_string()),
                is_final: true,
                finish_reason: Some("stop".to_string()),
                ..Default::default()
            },
        ];

        let response = aggregate_response(deltas);
        assert_eq!(response.text, "Hello World");
        assert_eq!(response.finish_reason, Some("stop".to_string()));
    }

    #[test]
    fn test_build_embedding_prompt_payload() {
        let payload = build_embedding_prompt_payload("hello".to_string());

        assert_eq!(payload.prompt, "hello");
        assert_eq!(payload.max_generated_tokens, 0);
        assert_eq!(payload.layout.len(), 1);
        assert_eq!(payload.layout[0].segment_type, "text");
        assert_eq!(payload.layout[0].length, 5);
        assert_eq!(payload.num_candidates, 1);
        assert_eq!(payload.best_of, Some(1));
        assert_eq!(payload.final_candidates, Some(1));
    }

    #[test]
    fn test_decode_embedding_bytes() {
        let bytes = [
            0.0f32.to_le_bytes(),
            1.5f32.to_le_bytes(),
            (-2.25f32).to_le_bytes(),
        ]
        .concat();

        let embedding = decode_embedding_bytes(&bytes).expect("embedding should decode");
        assert_eq!(embedding, vec![0.0, 1.5, -2.25]);
    }

    #[test]
    fn test_decode_embedding_bytes_rejects_partial_float() {
        let error = decode_embedding_bytes(&[0, 0, 128]).expect_err("decode should fail");
        assert!(matches!(error, ClientError::RequestFailed(_)));
    }

    #[test]
    fn test_build_stt_prompt_payload() {
        let payload = build_stt_prompt_payload(&[0.25, -0.5]);

        assert!(payload.prompt.is_empty());
        assert_eq!(payload.capabilities.len(), 1);
        assert_eq!(payload.capabilities[0].name, "audio");
        assert_eq!(payload.capabilities[0].payload.len(), 8);
        assert_eq!(payload.layout.len(), 2);
        assert_eq!(payload.layout[0].segment_type, "text");
        assert_eq!(payload.layout[0].length, 0);
        assert_eq!(payload.layout[1].segment_type, "capability");
        assert_eq!(payload.layout[1].length, 8);
    }

    #[test]
    fn test_encode_float32_pcm_bytes() {
        let bytes = encode_float32_pcm_bytes(&[0.0, 1.5, -2.25]);
        let decoded = decode_embedding_bytes(&bytes).expect("audio bytes should decode");
        assert_eq!(decoded, vec![0.0, 1.5, -2.25]);
    }

    #[test]
    fn test_build_response_from_candidates_uses_total_completion_tokens() {
        let response = build_response_from_candidates(
            vec![CandidateState {
                content: "winner".to_string(),
                finish_reason: Some("stop".to_string()),
                completion_tokens: 2,
                prompt_tokens: 5,
                deltas: vec![ClientDelta {
                    content: Some("winner".to_string()),
                    ..Default::default()
                }],
                ..Default::default()
            }],
            7,
        );

        assert_eq!(response.text, "winner");
        assert_eq!(response.finish_reason, Some("stop".to_string()));
        assert_eq!(response.usage.prompt_tokens, 5);
        assert_eq!(response.usage.completion_tokens, 7);
        assert_eq!(response.usage.total_tokens, 12);
    }
}