1#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4use candle_core::Tensor;
5use serde::{Deserialize, Serialize};
6
7#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
8use crate::backend::CandleBackend;
9use crate::InferenceError;
10
/// Workload class attached to a generation request (see `GenerateParams::workload`).
///
/// Serialized in `snake_case` (`"interactive"`, `"batch"`, `"background"`,
/// `"local_preferred"`). The per-variant tuning values live in the
/// `weights`, `local_bonus` and `is_latency_sensitive` methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RoutingWorkload {
    /// Default workload; treated as latency sensitive and given no local bonus.
    #[default]
    Interactive,
    /// Not latency sensitive; small local bonus (0.08).
    Batch,
    /// Not latency sensitive; medium local bonus (0.15).
    Background,
    /// Treated as latency sensitive; largest local bonus (0.20).
    LocalPreferred,
}
32
33impl RoutingWorkload {
34 pub fn is_latency_sensitive(self) -> bool {
35 matches!(
36 self,
37 RoutingWorkload::Interactive | RoutingWorkload::LocalPreferred,
38 )
39 }
40
41 pub fn weights(self) -> (f64, f64, f64) {
42 match self {
43 RoutingWorkload::Interactive => (0.45, 0.40, 0.15),
44 RoutingWorkload::Batch => (0.60, 0.15, 0.25),
45 RoutingWorkload::Background => (0.65, 0.05, 0.30),
46 RoutingWorkload::LocalPreferred => (0.55, 0.20, 0.25),
50 }
51 }
52
53 pub fn local_bonus(self) -> f64 {
54 match self {
55 RoutingWorkload::Interactive => 0.0,
56 RoutingWorkload::Batch => 0.08,
57 RoutingWorkload::Background => 0.15,
58 RoutingWorkload::LocalPreferred => 0.20,
61 }
62 }
63}
64
/// Controls whether a reasoning ("thinking") directive is injected into the
/// chat template and whether `<think>` output is stripped from the result.
///
/// Serialized in `snake_case` (`"auto"`, `"on"`, `"off"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingMode {
    /// Default: no directive and no prefill are injected; any `<think>` block
    /// in the output is still stripped by `strip_thinking`.
    #[default]
    Auto,
    /// Injects the `/think` directive; `strip_thinking` returns the raw text.
    On,
    /// Injects the `/no_think` directive and prefills an empty `<think>` block;
    /// `<think>` output is stripped.
    Off,
}
95
96impl ThinkingMode {
97 pub fn directive(self) -> Option<&'static str> {
100 match self {
101 ThinkingMode::Auto => None,
102 ThinkingMode::On => Some("/think"),
103 ThinkingMode::Off => Some("/no_think"),
104 }
105 }
106}
107
/// Sampling and routing parameters for one generation request.
///
/// Every field has a serde default, so an empty JSON object deserializes to
/// the same values as `GenerateParams::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateParams {
    /// Sampling temperature (default 0.7). `sample_token` switches to greedy
    /// argmax when this is `<= 0.0`.
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Nucleus (top-p) cutoff (default 0.9). `sample_token` skips the filter
    /// when this is `>= 1.0`.
    #[serde(default = "default_top_p")]
    pub top_p: f64,
    /// Top-k cutoff (default 0). `sample_token` skips the filter when this is 0
    /// or not smaller than the vocabulary size.
    #[serde(default)]
    pub top_k: usize,
    /// Upper bound on the number of decode iterations (default 4096).
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Stop strings; `generate` ends decoding once the decoded output contains
    /// any of them.
    #[serde(default)]
    pub stop: Vec<String>,
    /// Token budget (default 0). NOTE(review): not read anywhere in this file —
    /// presumably a reasoning budget consumed elsewhere; confirm.
    #[serde(default)]
    pub budget_tokens: usize,
    /// Workload class for this request (default `Interactive`).
    #[serde(default)]
    pub workload: RoutingWorkload,
    /// Optional tool-choice setting; not interpreted in this file.
    #[serde(default)]
    pub tool_choice: Option<String>,
    /// Optional parallel-tool-calls flag; not interpreted in this file.
    #[serde(default)]
    pub parallel_tool_calls: Option<bool>,
    /// Thinking mode used for the chat template and output stripping
    /// (default `Auto`).
    #[serde(default)]
    pub thinking: ThinkingMode,
}
150
/// Serde default for `GenerateParams::temperature`.
fn default_temperature() -> f64 {
    const DEFAULT_TEMPERATURE: f64 = 0.7;
    DEFAULT_TEMPERATURE
}
/// Serde default for `GenerateParams::top_p`.
fn default_top_p() -> f64 {
    const DEFAULT_TOP_P: f64 = 0.9;
    DEFAULT_TOP_P
}
/// Serde default for `GenerateParams::max_tokens`.
fn default_max_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 4096;
    DEFAULT_MAX_TOKENS
}
160
161impl Default for GenerateParams {
162 fn default() -> Self {
163 Self {
164 temperature: default_temperature(),
165 top_p: default_top_p(),
166 top_k: 0,
167 max_tokens: default_max_tokens(),
168 stop: Vec::new(),
169 budget_tokens: 0,
170 workload: RoutingWorkload::Interactive,
171 tool_choice: None,
172 parallel_tool_calls: None,
173 thinking: ThinkingMode::default(),
174 }
175 }
176}
177
/// A single tool invocation carried on an assistant message
/// (see `Message::Assistant::tool_calls`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Optional call identifier; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Name of the tool being called.
    pub name: String,
    /// Arguments keyed by parameter name, each an arbitrary JSON value.
    pub arguments: std::collections::HashMap<String, serde_json::Value>,
}
191
/// One piece of multimodal message content.
///
/// Internally tagged: the JSON object carries a `"type"` field holding the
/// `snake_case` variant name (`"text"`, `"image_base64"`, `"image_url"`,
/// `"video_path"`, ...). Optional fields are omitted from the serialized form
/// when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text.
    Text { text: String },
    /// Inline image: `data` is presumably the base64 payload and `media_type`
    /// its MIME type (not validated here).
    ImageBase64 {
        data: String,
        media_type: String,
    },
    /// Image referenced by URL.
    ImageUrl {
        url: String,
        /// Detail level; defaults to `"auto"` when absent.
        #[serde(default = "default_detail")]
        detail: String,
    },
    /// Video referenced by a filesystem path, with optional sampling hints.
    VideoPath {
        path: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Video referenced by URL, with optional sampling hints.
    VideoUrl {
        url: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Inline video payload, with optional sampling hints.
    VideoBase64 {
        data: String,
        media_type: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Audio referenced by a filesystem path, with an optional sample rate.
    AudioPath {
        path: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
    /// Audio referenced by URL, with an optional sample rate.
    AudioUrl {
        url: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
    /// Inline audio payload, with an optional sample rate.
    AudioBase64 {
        data: String,
        media_type: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
}
274
275impl ContentBlock {
276 pub fn is_video(&self) -> bool {
280 matches!(
281 self,
282 ContentBlock::VideoPath { .. }
283 | ContentBlock::VideoUrl { .. }
284 | ContentBlock::VideoBase64 { .. }
285 )
286 }
287
288 pub fn is_audio(&self) -> bool {
294 matches!(
295 self,
296 ContentBlock::AudioPath { .. }
297 | ContentBlock::AudioUrl { .. }
298 | ContentBlock::AudioBase64 { .. }
299 )
300 }
301}
302
/// Serde default for `ContentBlock::ImageUrl::detail`.
fn default_detail() -> String {
    String::from("auto")
}
306
/// One turn of a conversation.
///
/// Internally tagged: the JSON object carries a `"role"` field holding the
/// `snake_case` variant name (`"system"`, `"user"`, `"user_multimodal"`,
/// `"assistant"`, `"tool_result"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum Message {
    /// System instruction text.
    System { content: String },
    /// Text-only user turn.
    User { content: String },
    /// User turn made of multiple content blocks (text, image, video, audio).
    UserMultimodal { content: Vec<ContentBlock> },
    /// Assistant turn; both fields default to empty when absent.
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool call; `tool_use_id` presumably matches `ToolCall::id`
    /// (verify against the producer).
    ToolResult {
        tool_use_id: String,
        content: String,
    },
}
338
/// Requested output format for a generation.
///
/// Internally tagged by `"type"` with `snake_case` names (`"json_schema"`,
/// `"json_object"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Output described by a JSON schema.
    JsonSchema {
        /// The schema, as an arbitrary JSON value.
        schema: serde_json::Value,
        /// Strictness flag; defaults to `false` when absent.
        #[serde(default)]
        strict: bool,
        /// Optional schema name; omitted from the serialized form when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Output requested as a JSON object, with no schema attached.
    JsonObject,
}
377
/// A generation request.
///
/// The local generation paths in this file (`generate`,
/// `generate_with_retrieval`) read only `prompt`, `context` and `params`; the
/// remaining fields are carried for other consumers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerateRequest {
    /// User prompt. Passed through untouched by `apply_chat_template` when it
    /// already contains `<|im_start|>`.
    pub prompt: String,
    /// Optional model identifier; not interpreted in this file.
    pub model: Option<String>,
    /// Sampling / routing parameters; defaults to `GenerateParams::default()`.
    #[serde(default)]
    pub params: GenerateParams,
    /// Optional context text that `apply_chat_template` appends to the system
    /// message.
    #[serde(default)]
    pub context: Option<String>,
    /// Optional tool definitions as raw JSON values; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<serde_json::Value>>,
    /// Optional content blocks; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<ContentBlock>>,
    /// Optional structured conversation; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub messages: Option<Vec<Message>>,
    /// Cache-control flag (default `false`); not interpreted in this file.
    #[serde(default)]
    pub cache_control: bool,
    /// Optional requested output format; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Optional intent hint; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub intent: Option<crate::intent::IntentHint>,
}
434
435pub fn apply_chat_template(
455 prompt: &str,
456 context: Option<&str>,
457 thinking: ThinkingMode,
458) -> String {
459 if prompt.contains("<|im_start|>") {
460 return prompt.to_string();
461 }
462 let directive_line = match thinking.directive() {
467 Some(d) => format!("\n{d}"),
468 None => String::new(),
469 };
470 let thinking_prefill = match thinking {
475 ThinkingMode::Off => "<think>\n\n</think>\n\n",
476 _ => "",
477 };
478 match context {
479 Some(ctx) => format!(
480 "<|im_start|>system\nYou are a helpful assistant. Use the following context to inform your response.\n\n{ctx}{directive_line}<|im_end|>\n\
481 <|im_start|>user\n{prompt}<|im_end|>\n\
482 <|im_start|>assistant\n{thinking_prefill}"
483 ),
484 None => format!(
485 "<|im_start|>system\nYou are a helpful assistant.{directive_line}<|im_end|>\n\
486 <|im_start|>user\n{prompt}<|im_end|>\n\
487 <|im_start|>assistant\n{thinking_prefill}"
488 ),
489 }
490}
491
492pub fn strip_thinking(text: &str, thinking: ThinkingMode) -> String {
503 if matches!(thinking, ThinkingMode::On) {
504 return text.to_string();
505 }
506 strip_thinking_block(text)
507}
508
/// Removes a leading reasoning block from `text`.
///
/// * If `</think>` occurs, everything up to and including its first
///   occurrence is dropped and leading whitespace of the remainder is trimmed.
/// * If only an unterminated `<think>` is present, the result is empty.
/// * Otherwise the text is returned unchanged.
fn strip_thinking_block(text: &str) -> String {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    match text.split_once(CLOSE_TAG) {
        Some((_reasoning, answer)) => answer.trim_start().to_owned(),
        None if text.contains(OPEN_TAG) => String::new(),
        None => text.to_owned(),
    }
}
520
/// Callback used by `generate_with_retrieval` to fetch extra context.
///
/// It receives the text decoded so far and returns `Some(context)` to restart
/// generation with that context appended, or `None` to continue unchanged.
pub type RetrievalCallback = Box<dyn Fn(&str) -> Option<String> + Send>;
524
525#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
526pub async fn generate(
534 backend: &mut CandleBackend,
535 req: GenerateRequest,
536) -> Result<(String, Option<u64>), InferenceError> {
537 let start = std::time::Instant::now();
538
539 backend.clear_kv_cache();
541
542 let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
543 let tokens = backend.encode(&formatted)?;
544 let eos = backend.eos_token_id();
545 let eos_alt = backend.token_id("<|im_end|>");
546 let params = &req.params;
547
548 if tokens.is_empty() {
549 return Ok((String::new(), None));
550 }
551
552 let max_ctx = backend.context_length().unwrap_or(32768);
555 let headroom = params.max_tokens.min(max_ctx / 4);
556 let max_prompt = max_ctx.saturating_sub(headroom);
557 let tokens = if tokens.len() > max_prompt {
558 eprintln!(
559 "[car-inference] truncating prompt from {} to {} tokens (context_length={})",
560 tokens.len(),
561 max_prompt,
562 max_ctx
563 );
564 tokens[tokens.len() - max_prompt..].to_vec()
565 } else {
566 tokens
567 };
568
569 let mut generated = Vec::new();
570
571 let logits = backend.forward(&tokens, 0)?;
573 let mut next_token = sample_token(&logits, params)?;
574 let ttft_ms = Some(start.elapsed().as_millis() as u64);
575
576 for _i in 0..params.max_tokens {
577 if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
579 {
580 break;
581 }
582
583 generated.push(next_token);
584
585 if !params.stop.is_empty() {
587 let text_so_far = backend.decode(&generated)?;
588 if params.stop.iter().any(|s| text_so_far.contains(s)) {
589 break;
590 }
591 }
592
593 let pos = tokens.len() + generated.len() - 1;
595 let logits = backend.forward(&[next_token], pos)?;
596 next_token = sample_token(&logits, params)?;
597 }
598
599 let text = backend.decode(&generated)?;
600 Ok((strip_thinking(&text, params.thinking), ttft_ms))
601}
602
603#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
604pub async fn generate_with_retrieval(
610 backend: &mut CandleBackend,
611 mut req: GenerateRequest,
612 retrieval_cb: RetrievalCallback,
613) -> Result<String, InferenceError> {
614 backend.clear_kv_cache();
616 let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
617 let tokens = backend.encode(&formatted)?;
618 let eos = backend.eos_token_id();
619 let eos_alt = backend.token_id("<|im_end|>");
620 let params = req.params.clone();
621
622 if tokens.is_empty() {
623 return Ok(String::new());
624 }
625
626 let mut generated = Vec::new();
627 let mut low_confidence_count = 0u32;
628 let mut retrieval_attempts = 0u32;
629 let max_retrievals = 2;
630 let confidence_threshold = 0.4f32;
631 let low_confidence_window = 3u32;
632
633 let logits = backend.forward(&tokens, 0)?;
634 let mut next_token = sample_token(&logits, ¶ms)?;
635
636 for _i in 0..params.max_tokens {
637 if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
638 {
639 break;
640 }
641
642 generated.push(next_token);
643
644 let pos = tokens.len() + generated.len() - 1;
646 let logits = backend.forward(&[next_token], pos)?;
647
648 let logits_f32: Vec<f32> = logits
650 .squeeze(0)
651 .unwrap_or(logits.clone())
652 .to_dtype(candle_core::DType::F32)
653 .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
654 .to_vec1()
655 .unwrap_or_default();
656
657 if !logits_f32.is_empty() {
658 let max_logit = logits_f32.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
660 let exp_sum: f32 = logits_f32.iter().map(|&v| (v - max_logit).exp()).sum();
661 let max_prob = 1.0 / exp_sum; if max_prob < confidence_threshold {
664 low_confidence_count += 1;
665 } else {
666 low_confidence_count = 0;
667 }
668
669 if low_confidence_count >= low_confidence_window && retrieval_attempts < max_retrievals
671 {
672 retrieval_attempts += 1;
673 low_confidence_count = 0;
674
675 let partial = backend.decode(&generated)?;
677 if let Some(new_context) = retrieval_cb(&partial) {
678 let combined_context = match req.context.take() {
680 Some(old) => format!("{}\n\n{}", old, new_context),
681 None => new_context,
682 };
683 req.context = Some(combined_context);
684
685 backend.clear_kv_cache();
687 let new_formatted =
688 apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
689 let new_tokens = backend.encode(&new_formatted)?;
690 generated.clear();
691
692 let logits = backend.forward(&new_tokens, 0)?;
693 next_token = sample_token(&logits, ¶ms)?;
694 continue;
695 }
696 }
697 }
698
699 next_token = sample_token(&logits, ¶ms)?;
700 }
701
702 let text = backend.decode(&generated)?;
703 Ok(strip_thinking(&text, params.thinking))
704}
705
706#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
707pub fn sample_token_suppress(
709 logits: &Tensor,
710 params: &GenerateParams,
711 suppress: &[u32],
712) -> Result<u32, InferenceError> {
713 if suppress.is_empty() {
714 return sample_token(logits, params);
715 }
716 let mut logits_vec: Vec<f32> = logits
718 .squeeze(0)
719 .unwrap_or(logits.clone())
720 .to_dtype(candle_core::DType::F32)
721 .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
722 .to_vec1()
723 .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
724 let dims = logits.dims();
726 if dims.len() == 2 {
727 let vocab = dims[dims.len() - 1];
728 let start = logits_vec.len() - vocab;
729 logits_vec = logits_vec[start..].to_vec();
730 }
731 for &id in suppress {
732 if (id as usize) < logits_vec.len() {
733 logits_vec[id as usize] = f32::NEG_INFINITY;
734 }
735 }
736 let modified = Tensor::from_vec(
737 logits_vec,
738 logits.squeeze(0).unwrap_or(logits.clone()).shape(),
739 logits.device(),
740 )
741 .map_err(|e| InferenceError::InferenceFailed(format!("from_vec: {e}")))?;
742 sample_token(&modified, params)
743}
744
745#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
746pub fn sample_token(logits: &Tensor, params: &GenerateParams) -> Result<u32, InferenceError> {
748 let logits = logits
749 .squeeze(0)
750 .map_err(|e| InferenceError::InferenceFailed(format!("squeeze: {e}")))?;
751 let logits = logits
752 .to_dtype(candle_core::DType::F32)
753 .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?;
754
755 let dim = logits.dims();
757 let logits = if dim.len() == 2 {
758 logits
759 .get(dim[0] - 1)
760 .map_err(|e| InferenceError::InferenceFailed(format!("get last: {e}")))?
761 } else {
762 logits
763 };
764
765 if params.temperature <= 0.0 {
767 let token = logits
768 .argmax(0)
769 .map_err(|e| InferenceError::InferenceFailed(format!("argmax: {e}")))?
770 .to_scalar::<u32>()
771 .map_err(|e| InferenceError::InferenceFailed(format!("scalar: {e}")))?;
772 return Ok(token);
773 }
774
775 let logits = (&logits / params.temperature)
777 .map_err(|e| InferenceError::InferenceFailed(format!("temp scale: {e}")))?;
778
779 let mut logits_vec: Vec<f32> = logits
780 .to_vec1()
781 .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
782
783 if params.top_k > 0 && params.top_k < logits_vec.len() {
785 let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
786 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
787 let threshold = indexed[params.top_k].1;
788 for v in &mut logits_vec {
789 if *v < threshold {
790 *v = f32::NEG_INFINITY;
791 }
792 }
793 }
794
795 let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
797 let exp: Vec<f32> = logits_vec.iter().map(|&v| (v - max_logit).exp()).collect();
798 let sum: f32 = exp.iter().sum();
799 let mut probs: Vec<f32> = exp.iter().map(|&v| v / sum).collect();
800
801 if params.top_p < 1.0 {
803 let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
804 sorted_indices.sort_by(|&a, &b| {
805 probs[b]
806 .partial_cmp(&probs[a])
807 .unwrap_or(std::cmp::Ordering::Equal)
808 });
809
810 let mut cumsum = 0.0f32;
811 let mut cutoff_idx = sorted_indices.len();
812 for (i, &idx) in sorted_indices.iter().enumerate() {
813 cumsum += probs[idx];
814 if cumsum > params.top_p as f32 {
815 cutoff_idx = i + 1;
816 break;
817 }
818 }
819
820 let keep: std::collections::HashSet<usize> =
821 sorted_indices[..cutoff_idx].iter().copied().collect();
822 for (i, p) in probs.iter_mut().enumerate() {
823 if !keep.contains(&i) {
824 *p = 0.0;
825 }
826 }
827
828 let sum: f32 = probs.iter().sum();
830 if sum > 0.0 {
831 for p in &mut probs {
832 *p /= sum;
833 }
834 }
835 }
836
837 let r: f32 = rand_f32();
839 let mut cumsum = 0.0f32;
840 for (i, &p) in probs.iter().enumerate() {
841 cumsum += p;
842 if cumsum >= r {
843 return Ok(i as u32);
844 }
845 }
846
847 Ok(probs
849 .iter()
850 .enumerate()
851 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
852 .map(|(i, _)| i as u32)
853 .unwrap_or(0))
854}
855
856#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
857fn rand_f32() -> f32 {
859 rand::random::<f32>()
860}
861
/// Unit tests for the thinking-mode behavior of `apply_chat_template` and
/// `strip_thinking`, plus the `ThinkingMode` defaults and serde form.
#[cfg(test)]
mod thinking_tests {
    use super::*;

    // Auto mode must leave the template free of directives and prefill.
    #[test]
    fn auto_injects_no_directive_and_no_prefill() {
        let out = apply_chat_template("hi", None, ThinkingMode::Auto);
        assert!(!out.contains("/no_think"));
        assert!(!out.contains("/think"));
        assert!(!out.contains("<think>"));
        assert!(out.contains("<|im_start|>user\nhi<|im_end|>"));
    }

    // Off mode: directive on its own line at the end of the system message,
    // and an empty think block prefilled in the assistant turn.
    #[test]
    fn off_injects_no_think_on_own_line_and_prefills_empty_think() {
        let out = apply_chat_template("hi", None, ThinkingMode::Off);
        assert!(out.contains("\n/no_think<|im_end|>"));
        // The directive must not be appended inline after a space.
        assert!(!out.contains(" /no_think"));
        assert!(out.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
    }

    // On mode: `/think` directive, no prefill.
    #[test]
    fn on_injects_think_and_no_prefill() {
        let out = apply_chat_template("hi", None, ThinkingMode::On);
        assert!(out.contains("\n/think<|im_end|>"));
        assert!(!out.contains("/no_think"));
        assert!(!out.contains("<think>"));
    }

    // A prompt that already carries chat markers is passed through verbatim.
    #[test]
    fn pre_formatted_prompt_is_untouched() {
        let pre = "<|im_start|>system\ncustom<|im_end|>\n<|im_start|>user\nhi<|im_end|>";
        let out = apply_chat_template(pre, None, ThinkingMode::Off);
        assert_eq!(out, pre);
    }

    // The directive is placed after the injected context.
    #[test]
    fn directive_appears_after_context_not_before() {
        let out = apply_chat_template("q?", Some("some memory"), ThinkingMode::Off);
        let ctx_idx = out.find("some memory").unwrap();
        let directive_idx = out.find("/no_think").unwrap();
        assert!(
            directive_idx > ctx_idx,
            "directive must appear after context so user memory cannot nudge the parse"
        );
    }

    #[test]
    fn default_params_is_auto() {
        assert_eq!(GenerateParams::default().thinking, ThinkingMode::Auto);
    }

    // Variants serialize and parse as lowercase snake_case strings.
    #[test]
    fn thinking_mode_serde_snake_case() {
        let json = serde_json::to_string(&ThinkingMode::Off).unwrap();
        assert_eq!(json, "\"off\"");
        let parsed: ThinkingMode = serde_json::from_str("\"on\"").unwrap();
        assert_eq!(parsed, ThinkingMode::On);
    }

    #[test]
    fn strip_preserves_thinking_when_on() {
        let text = "<think>reasoning here</think>the answer";
        let out = strip_thinking(text, ThinkingMode::On);
        assert_eq!(out, text, "On mode must return raw text with <think> visible");
    }

    #[test]
    fn strip_removes_thinking_when_auto_or_off() {
        let text = "<think>reasoning</think>the answer";
        assert_eq!(strip_thinking(text, ThinkingMode::Auto), "the answer");
        assert_eq!(strip_thinking(text, ThinkingMode::Off), "the answer");
    }

    // An unterminated think block yields an empty answer, except in On mode
    // where the raw text is returned.
    #[test]
    fn strip_returns_empty_on_unterminated_think() {
        let text = "<think>mid-reasoning, never closed";
        assert_eq!(strip_thinking(text, ThinkingMode::Auto), "");
        assert_eq!(strip_thinking(text, ThinkingMode::Off), "");
        assert_eq!(strip_thinking(text, ThinkingMode::On), text);
    }

    #[test]
    fn strip_is_noop_when_no_think_tag() {
        let text = "just a plain answer";
        assert_eq!(strip_thinking(text, ThinkingMode::Auto), text);
        assert_eq!(strip_thinking(text, ThinkingMode::Off), text);
        assert_eq!(strip_thinking(text, ThinkingMode::On), text);
    }
}