#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use candle_core::Tensor;
use serde::{Deserialize, Serialize};
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use crate::backend::CandleBackend;
use crate::InferenceError;
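/// Workload class attached to a request so the router can trade off
/// quality, latency, and cost when choosing a backend.
///
/// A minimal scoring sketch; `quality`, `latency_score`, `cost_score`, and
/// `local` are illustrative names, not part of this crate's API:
///
/// ```ignore
/// let w = RoutingWorkload::Batch;
/// let (q, l, c) = w.weights();
/// let score = q * quality + l * latency_score + c * cost_score
///     + if local { w.local_bonus() } else { 0.0 };
/// ```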
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RoutingWorkload {
#[default]
Interactive,
Batch,
Background,
LocalPreferred,
Fastest,
}
impl RoutingWorkload {
pub fn is_latency_sensitive(self) -> bool {
matches!(
self,
RoutingWorkload::Interactive
| RoutingWorkload::LocalPreferred
| RoutingWorkload::Fastest
)
}
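/// Scoring weights as `(quality, latency, cost)`; every triple sums to 1.0
/// (checked by `workload_tests::all_workload_weights_sum_to_one`).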
pub fn weights(self) -> (f64, f64, f64) {
match self {
RoutingWorkload::Interactive => (0.45, 0.40, 0.15),
RoutingWorkload::Batch => (0.60, 0.15, 0.25),
RoutingWorkload::Background => (0.65, 0.05, 0.30),
RoutingWorkload::LocalPreferred => (0.55, 0.20, 0.25),
RoutingWorkload::Fastest => (0.10, 0.85, 0.05),
}
}
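/// Additive score bonus for running on a local backend; higher values bias
/// routing toward local execution. `Fastest` presumably gets the maximum
/// bonus because local inference avoids network round-trips.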
pub fn local_bonus(self) -> f64 {
match self {
RoutingWorkload::Interactive => 0.0,
RoutingWorkload::Batch => 0.08,
RoutingWorkload::Background => 0.15,
RoutingWorkload::LocalPreferred => 0.20,
RoutingWorkload::Fastest => 0.20,
}
}
}
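/// Controls whether the model emits a `<think>...</think>` reasoning block:
/// `On` injects `/think`, `Off` injects `/no_think` and pre-fills an empty
/// think block so generation starts directly on the answer, and `Auto`
/// leaves the prompt untouched.
///
/// ```ignore
/// assert_eq!(ThinkingMode::Off.directive(), Some("/no_think"));
/// assert_eq!(ThinkingMode::Auto.directive(), None);
/// ```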
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingMode {
#[default]
Auto,
On,
Off,
}
impl ThinkingMode {
pub fn directive(self) -> Option<&'static str> {
match self {
ThinkingMode::Auto => None,
ThinkingMode::On => Some("/think"),
ThinkingMode::Off => Some("/no_think"),
}
}
}
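/// Sampling and routing parameters for a generation request. `top_k == 0`
/// disables top-k filtering (see `sample_token`); every field has a serde
/// default, so an empty JSON object deserializes to the same values as
/// `GenerateParams::default()`.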
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateParams {
#[serde(default = "default_temperature")]
pub temperature: f64,
#[serde(default = "default_top_p")]
pub top_p: f64,
#[serde(default)]
pub top_k: usize,
#[serde(default = "default_max_tokens")]
pub max_tokens: usize,
#[serde(default)]
pub stop: Vec<String>,
#[serde(default)]
pub budget_tokens: usize,
#[serde(default)]
pub workload: RoutingWorkload,
#[serde(default)]
pub tool_choice: Option<String>,
#[serde(default)]
pub parallel_tool_calls: Option<bool>,
#[serde(default)]
pub thinking: ThinkingMode,
}
fn default_temperature() -> f64 {
0.7
}
fn default_top_p() -> f64 {
0.9
}
fn default_max_tokens() -> usize {
4096
}
impl Default for GenerateParams {
fn default() -> Self {
Self {
temperature: default_temperature(),
top_p: default_top_p(),
top_k: 0,
max_tokens: default_max_tokens(),
stop: Vec::new(),
budget_tokens: 0,
workload: RoutingWorkload::Interactive,
tool_choice: None,
parallel_tool_calls: None,
thinking: ThinkingMode::default(),
}
}
}
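/// A single tool invocation requested by the model; `id` is optional and
/// omitted from serialization when absent.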
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
pub name: String,
pub arguments: std::collections::HashMap<String, serde_json::Value>,
}
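/// One block of multimodal content. Media can arrive inline (`*Base64`),
/// by reference (`*Url`), or from the local filesystem (`*Path`); the serde
/// `type` tag uses snake_case variant names.
///
/// ```ignore
/// let block: ContentBlock = serde_json::from_str(
///     r#"{"type": "image_url", "url": "https://example.com/a.png"}"#,
/// )?;
/// ```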
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
Text { text: String },
ImageBase64 {
data: String,
media_type: String,
},
ImageUrl {
url: String,
#[serde(default = "default_detail")]
detail: String,
},
VideoPath {
path: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
fps: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
max_frames: Option<u32>,
},
VideoUrl {
url: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
fps: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
max_frames: Option<u32>,
},
VideoBase64 {
data: String,
media_type: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
fps: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
max_frames: Option<u32>,
},
AudioPath {
path: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
sample_rate: Option<u32>,
},
AudioUrl {
url: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
sample_rate: Option<u32>,
},
AudioBase64 {
data: String,
media_type: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
sample_rate: Option<u32>,
},
}
impl ContentBlock {
pub fn is_video(&self) -> bool {
matches!(
self,
ContentBlock::VideoPath { .. }
| ContentBlock::VideoUrl { .. }
| ContentBlock::VideoBase64 { .. }
)
}
pub fn is_audio(&self) -> bool {
matches!(
self,
ContentBlock::AudioPath { .. }
| ContentBlock::AudioUrl { .. }
| ContentBlock::AudioBase64 { .. }
)
}
}
fn default_detail() -> String {
"auto".to_string()
}
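/// A chat message, tagged by `role` when serialized. `ProviderOutputItems`
/// carries provider-native output items verbatim, presumably so they can be
/// round-tripped to the same protocol.
///
/// ```ignore
/// let m: Message = serde_json::from_str(r#"{"role": "user", "content": "hi"}"#)?;
/// ```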
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum Message {
System { content: String },
User { content: String },
UserMultimodal { content: Vec<ContentBlock> },
Assistant {
#[serde(default)]
content: String,
#[serde(default)]
tool_calls: Vec<ToolCall>,
},
ToolResult {
tool_use_id: String,
content: String,
},
ProviderOutputItems {
protocol: String,
items: Vec<serde_json::Value>,
},
}
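/// Structured-output constraint: free-form JSON (`json_object`) or JSON
/// validated against a schema (`json_schema`).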
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
JsonSchema {
schema: serde_json::Value,
#[serde(default)]
strict: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
JsonObject,
}
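/// A full generation request: `prompt` plus optional `context` is the simple
/// path, while `messages` carries a structured chat history. Thanks to the
/// serde defaults, a bare prompt is enough to deserialize:
///
/// ```ignore
/// let req: GenerateRequest = serde_json::from_str(r#"{"prompt": "hi"}"#)?;
/// ```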
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerateRequest {
pub prompt: String,
pub model: Option<String>,
#[serde(default)]
pub params: GenerateParams,
#[serde(default)]
pub context: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<serde_json::Value>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub images: Option<Vec<ContentBlock>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub messages: Option<Vec<Message>>,
#[serde(default)]
pub cache_control: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub intent: Option<crate::intent::IntentHint>,
}
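/// Wraps a bare prompt in a ChatML (`<|im_start|>`/`<|im_end|>`) template,
/// injecting the thinking directive after any retrieval context so the
/// context cannot displace it. Prompts that already contain `<|im_start|>`
/// pass through unchanged.
///
/// ```ignore
/// let p = apply_chat_template("hi", None, ThinkingMode::Off);
/// assert!(p.ends_with("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
/// ```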
pub fn apply_chat_template(prompt: &str, context: Option<&str>, thinking: ThinkingMode) -> String {
if prompt.contains("<|im_start|>") {
return prompt.to_string();
}
let directive_line = match thinking.directive() {
Some(d) => format!("\n{d}"),
None => String::new(),
};
let thinking_prefill = match thinking {
ThinkingMode::Off => "<think>\n\n</think>\n\n",
_ => "",
};
match context {
Some(ctx) => format!(
"<|im_start|>system\nYou are a helpful assistant. Use the following context to inform your response.\n\n{ctx}{directive_line}<|im_end|>\n\
<|im_start|>user\n{prompt}<|im_end|>\n\
<|im_start|>assistant\n{thinking_prefill}"
),
None => format!(
"<|im_start|>system\nYou are a helpful assistant.{directive_line}<|im_end|>\n\
<|im_start|>user\n{prompt}<|im_end|>\n\
<|im_start|>assistant\n{thinking_prefill}"
),
}
}
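/// Removes a leading `<think>...</think>` block unless the caller asked for
/// the reasoning to be kept (`ThinkingMode::On`). An opened-but-unclosed
/// block yields an empty string (see `strip_thinking_block`).
///
/// ```ignore
/// assert_eq!(strip_thinking("<think>why</think>answer", ThinkingMode::Auto), "answer");
/// ```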
pub fn strip_thinking(text: &str, thinking: ThinkingMode) -> String {
if matches!(thinking, ThinkingMode::On) {
return text.to_string();
}
strip_thinking_block(text)
}
fn strip_thinking_block(text: &str) -> String {
if let Some(end) = text.find("</think>") {
text[end + "</think>".len()..].trim_start().to_string()
} else if text.contains("<think>") {
tracing::warn!(
target: "car_inference::tasks::generate",
raw_len = text.len(),
"model output opened <think> but never closed it — \
likely truncated by max_tokens; returning empty text. \
Increase max_tokens, or set thinking=off to suppress \
the reasoning phase."
);
String::new()
} else {
text.to_string()
}
}
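/// Callback handed the partial decode so far; returning `Some(context)`
/// triggers a re-prompt with the extra context in `generate_with_retrieval`.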
pub type RetrievalCallback = Box<dyn Fn(&str) -> Option<String> + Send>;
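/// Autoregressive generation on the Candle backend. Returns the
/// thinking-stripped text and the time to first token in milliseconds.
/// Prompts longer than the context window minus a generation headroom are
/// truncated from the front, keeping the most recent tokens.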
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
pub async fn generate(
backend: &mut CandleBackend,
req: GenerateRequest,
) -> Result<(String, Option<u64>), InferenceError> {
let start = std::time::Instant::now();
backend.clear_kv_cache();
let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
let tokens = backend.encode(&formatted)?;
let eos = backend.eos_token_id();
let eos_alt = backend.token_id("<|im_end|>");
let params = &req.params;
if tokens.is_empty() {
return Ok((String::new(), None));
}
let max_ctx = backend.context_length().unwrap_or(32768);
let headroom = params.max_tokens.min(max_ctx / 4);
let max_prompt = max_ctx.saturating_sub(headroom);
let tokens = if tokens.len() > max_prompt {
tracing::warn!(
target: "car_inference::tasks::generate",
prompt_tokens = tokens.len(),
kept_tokens = max_prompt,
context_length = max_ctx,
"truncating prompt to fit the context window"
);
tokens[tokens.len() - max_prompt..].to_vec()
} else {
tokens
};
let mut generated = Vec::new();
let logits = backend.forward(&tokens, 0)?;
let mut next_token = sample_token(&logits, params)?;
let ttft_ms = Some(start.elapsed().as_millis() as u64);
for _i in 0..params.max_tokens {
if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
{
break;
}
generated.push(next_token);
if !params.stop.is_empty() {
let text_so_far = backend.decode(&generated)?;
if params.stop.iter().any(|s| text_so_far.contains(s)) {
break;
}
}
let pos = tokens.len() + generated.len() - 1;
let logits = backend.forward(&[next_token], pos)?;
next_token = sample_token(&logits, params)?;
}
let text = backend.decode(&generated)?;
Ok((strip_thinking(&text, params.thinking), ttft_ms))
}
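/// Like `generate`, but watches per-token confidence (the softmax
/// probability of the most likely next token). After three consecutive
/// low-confidence tokens it asks `retrieval_cb` for extra context and
/// restarts generation with the augmented prompt, at most twice. Unlike
/// `generate`, stop strings are not checked here.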
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
pub async fn generate_with_retrieval(
backend: &mut CandleBackend,
mut req: GenerateRequest,
retrieval_cb: RetrievalCallback,
) -> Result<String, InferenceError> {
backend.clear_kv_cache();
let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
let tokens = backend.encode(&formatted)?;
let eos = backend.eos_token_id();
let eos_alt = backend.token_id("<|im_end|>");
let params = req.params.clone();
if tokens.is_empty() {
return Ok(String::new());
}
// Track the live prompt length; it changes when retrieval re-prompts.
let mut prompt_len = tokens.len();
let mut generated = Vec::new();
let mut low_confidence_count = 0u32;
let mut retrieval_attempts = 0u32;
let max_retrievals = 2;
let confidence_threshold = 0.4f32;
let low_confidence_window = 3u32;
let logits = backend.forward(&tokens, 0)?;
let mut next_token = sample_token(&logits, &params)?;
for _i in 0..params.max_tokens {
if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
{
break;
}
generated.push(next_token);
let pos = prompt_len + generated.len() - 1;
let logits = backend.forward(&[next_token], pos)?;
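// Best-effort confidence probe over the raw logits; squeeze and to_vec1
// failures degrade to an empty vec, which skips the check below.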
let logits_f32: Vec<f32> = logits
.squeeze(0)
.unwrap_or(logits.clone())
.to_dtype(candle_core::DType::F32)
.map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
.to_vec1()
.unwrap_or_default();
if !logits_f32.is_empty() {
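// exp(max - max) = 1, so the top token's softmax probability is
// 1 / sum_i exp(logit_i - max_logit).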
let max_logit = logits_f32.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_sum: f32 = logits_f32.iter().map(|&v| (v - max_logit).exp()).sum();
let max_prob = 1.0 / exp_sum;
if max_prob < confidence_threshold {
low_confidence_count += 1;
} else {
low_confidence_count = 0;
}
if low_confidence_count >= low_confidence_window && retrieval_attempts < max_retrievals
{
retrieval_attempts += 1;
low_confidence_count = 0;
let partial = backend.decode(&generated)?;
if let Some(new_context) = retrieval_cb(&partial) {
let combined_context = match req.context.take() {
Some(old) => format!("{}\n\n{}", old, new_context),
None => new_context,
};
req.context = Some(combined_context);
backend.clear_kv_cache();
let new_formatted = apply_chat_template(
&req.prompt,
req.context.as_deref(),
req.params.thinking,
);
let new_tokens = backend.encode(&new_formatted)?;
prompt_len = new_tokens.len();
generated.clear();
let logits = backend.forward(&new_tokens, 0)?;
next_token = sample_token(&logits, &params)?;
continue;
}
}
}
next_token = sample_token(&logits, &params)?;
}
let text = backend.decode(&generated)?;
Ok(strip_thinking(&text, params.thinking))
}
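/// `sample_token` with the given token ids masked to negative infinity
/// first, so they can never be sampled (e.g. to ban special tokens).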
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
pub fn sample_token_suppress(
logits: &Tensor,
params: &GenerateParams,
suppress: &[u32],
) -> Result<u32, InferenceError> {
if suppress.is_empty() {
return sample_token(logits, params);
}
let mut logits_vec: Vec<f32> = logits
.squeeze(0)
.unwrap_or(logits.clone())
.to_dtype(candle_core::DType::F32)
.map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
.to_vec1()
.map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
let dims = logits.dims();
if dims.len() == 2 {
let vocab = dims[dims.len() - 1];
let start = logits_vec.len() - vocab;
logits_vec = logits_vec[start..].to_vec();
}
for &id in suppress {
if (id as usize) < logits_vec.len() {
logits_vec[id as usize] = f32::NEG_INFINITY;
}
}
let modified = Tensor::from_vec(
logits_vec,
logits.squeeze(0).unwrap_or(logits.clone()).shape(),
logits.device(),
)
.map_err(|e| InferenceError::InferenceFailed(format!("from_vec: {e}")))?;
sample_token(&modified, params)
}
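/// Samples the next token id from (possibly two-dimensional, last-position)
/// logits: `temperature <= 0.0` means greedy argmax; otherwise temperature
/// scaling, optional top-k masking, softmax, optional top-p (nucleus)
/// filtering, then inverse-CDF sampling.
///
/// ```ignore
/// // Greedy decoding sketch: zero temperature short-circuits to argmax.
/// let params = GenerateParams { temperature: 0.0, ..Default::default() };
/// let id = sample_token(&logits, &params)?;
/// ```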
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
pub fn sample_token(logits: &Tensor, params: &GenerateParams) -> Result<u32, InferenceError> {
let logits = logits
.squeeze(0)
.map_err(|e| InferenceError::InferenceFailed(format!("squeeze: {e}")))?;
let logits = logits
.to_dtype(candle_core::DType::F32)
.map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?;
let dim = logits.dims();
let logits = if dim.len() == 2 {
logits
.get(dim[0] - 1)
.map_err(|e| InferenceError::InferenceFailed(format!("get last: {e}")))?
} else {
logits
};
if params.temperature <= 0.0 {
let token = logits
.argmax(0)
.map_err(|e| InferenceError::InferenceFailed(format!("argmax: {e}")))?
.to_scalar::<u32>()
.map_err(|e| InferenceError::InferenceFailed(format!("scalar: {e}")))?;
return Ok(token);
}
let logits = (&logits / params.temperature)
.map_err(|e| InferenceError::InferenceFailed(format!("temp scale: {e}")))?;
let mut logits_vec: Vec<f32> = logits
.to_vec1()
.map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
if params.top_k > 0 && params.top_k < logits_vec.len() {
let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
// Mask everything strictly below the k-th largest logit; indexing with
// `params.top_k` kept k + 1 candidates (off-by-one).
let threshold = indexed[params.top_k - 1].1;
for v in &mut logits_vec {
if *v < threshold {
*v = f32::NEG_INFINITY;
}
}
}
let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp: Vec<f32> = logits_vec.iter().map(|&v| (v - max_logit).exp()).collect();
let sum: f32 = exp.iter().sum();
let mut probs: Vec<f32> = exp.iter().map(|&v| v / sum).collect();
if params.top_p < 1.0 {
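// Nucleus (top-p) filtering: keep the smallest prefix of tokens, in
// descending probability order, whose cumulative mass exceeds top_p (the
// crossing token is included), then renormalize.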
let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
sorted_indices.sort_by(|&a, &b| {
probs[b]
.partial_cmp(&probs[a])
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut cumsum = 0.0f32;
let mut cutoff_idx = sorted_indices.len();
for (i, &idx) in sorted_indices.iter().enumerate() {
cumsum += probs[idx];
if cumsum > params.top_p as f32 {
cutoff_idx = i + 1;
break;
}
}
let keep: std::collections::HashSet<usize> =
sorted_indices[..cutoff_idx].iter().copied().collect();
for (i, p) in probs.iter_mut().enumerate() {
if !keep.contains(&i) {
*p = 0.0;
}
}
let sum: f32 = probs.iter().sum();
if sum > 0.0 {
for p in &mut probs {
*p /= sum;
}
}
}
let r: f32 = rand_f32();
let mut cumsum = 0.0f32;
for (i, &p) in probs.iter().enumerate() {
cumsum += p;
if cumsum >= r {
return Ok(i as u32);
}
}
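// Float round-off can leave the CDF just short of r; fall back to the
// highest-probability token.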
Ok(probs
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i as u32)
.unwrap_or(0))
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
fn rand_f32() -> f32 {
rand::random::<f32>()
}
#[cfg(test)]
mod thinking_tests {
use super::*;
#[test]
fn auto_injects_no_directive_and_no_prefill() {
let out = apply_chat_template("hi", None, ThinkingMode::Auto);
assert!(!out.contains("/no_think"));
assert!(!out.contains("/think"));
assert!(!out.contains("<think>"));
assert!(out.contains("<|im_start|>user\nhi<|im_end|>"));
}
#[test]
fn off_injects_no_think_on_own_line_and_prefills_empty_think() {
let out = apply_chat_template("hi", None, ThinkingMode::Off);
assert!(out.contains("\n/no_think<|im_end|>"));
assert!(!out.contains(" /no_think"));
assert!(out.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
}
#[test]
fn on_injects_think_and_no_prefill() {
let out = apply_chat_template("hi", None, ThinkingMode::On);
assert!(out.contains("\n/think<|im_end|>"));
assert!(!out.contains("/no_think"));
assert!(!out.contains("<think>"));
}
#[test]
fn pre_formatted_prompt_is_untouched() {
let pre = "<|im_start|>system\ncustom<|im_end|>\n<|im_start|>user\nhi<|im_end|>";
let out = apply_chat_template(pre, None, ThinkingMode::Off);
assert_eq!(out, pre);
}
#[test]
fn directive_appears_after_context_not_before() {
let out = apply_chat_template("q?", Some("some memory"), ThinkingMode::Off);
let ctx_idx = out.find("some memory").unwrap();
let directive_idx = out.find("/no_think").unwrap();
assert!(
directive_idx > ctx_idx,
"directive must appear after context so user memory cannot nudge the parse"
);
}
#[test]
fn default_params_is_auto() {
assert_eq!(GenerateParams::default().thinking, ThinkingMode::Auto);
}
#[test]
fn thinking_mode_serde_snake_case() {
let json = serde_json::to_string(&ThinkingMode::Off).unwrap();
assert_eq!(json, "\"off\"");
let parsed: ThinkingMode = serde_json::from_str("\"on\"").unwrap();
assert_eq!(parsed, ThinkingMode::On);
}
#[test]
fn strip_preserves_thinking_when_on() {
let text = "<think>reasoning here</think>the answer";
let out = strip_thinking(text, ThinkingMode::On);
assert_eq!(
out, text,
"On mode must return raw text with <think> visible"
);
}
#[test]
fn strip_removes_thinking_when_auto_or_off() {
let text = "<think>reasoning</think>the answer";
assert_eq!(strip_thinking(text, ThinkingMode::Auto), "the answer");
assert_eq!(strip_thinking(text, ThinkingMode::Off), "the answer");
}
#[test]
fn strip_returns_empty_on_unterminated_think() {
let text = "<think>mid-reasoning, never closed";
assert_eq!(strip_thinking(text, ThinkingMode::Auto), "");
assert_eq!(strip_thinking(text, ThinkingMode::Off), "");
assert_eq!(strip_thinking(text, ThinkingMode::On), text);
}
#[test]
fn strip_is_noop_when_no_think_tag() {
let text = "just a plain answer";
assert_eq!(strip_thinking(text, ThinkingMode::Auto), text);
assert_eq!(strip_thinking(text, ThinkingMode::Off), text);
assert_eq!(strip_thinking(text, ThinkingMode::On), text);
}
}
#[cfg(test)]
mod workload_tests {
use super::*;
#[test]
fn all_workload_weights_sum_to_one() {
for w in [
RoutingWorkload::Interactive,
RoutingWorkload::Batch,
RoutingWorkload::Background,
RoutingWorkload::LocalPreferred,
RoutingWorkload::Fastest,
] {
let (q, l, c) = w.weights();
let sum = q + l + c;
assert!(
(sum - 1.0).abs() < 1e-6,
"weights for {w:?} sum to {sum}, expected 1.0"
);
}
}
#[test]
fn fastest_weights_dominate_on_latency() {
let (q, l, c) = RoutingWorkload::Fastest.weights();
assert!(l > q && l > c);
assert!(l >= 0.7, "latency weight too small: {l}");
}
#[test]
fn fastest_is_latency_sensitive() {
assert!(RoutingWorkload::Fastest.is_latency_sensitive());
}
}