car_inference/tasks/generate.rs
//! Text generation with sampling.

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use candle_core::Tensor;
use serde::{Deserialize, Serialize};

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use crate::backend::CandleBackend;
use crate::InferenceError;

/// How latency-sensitive this generation request is.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RoutingWorkload {
    /// User-facing, interactive request where latency matters.
    #[default]
    Interactive,
    /// Batch job where latency matters somewhat, but quality/cost matter more.
    Batch,
    /// Background or offline work where latency is a weak concern.
    Background,
    /// Caller explicitly prefers on-device models. Distinct from
    /// `Background` (which is "this is a background job, latency
    /// barely matters"). The caller may be doing latency-sensitive
    /// interactive work but wants the privacy / cost / offline
    /// properties of local inference. Same `local_bonus` as
    /// `Background` plus a slightly more quality-aware weight profile
    /// — the caller chose local for a reason, not because the work is
    /// throwaway.
    LocalPreferred,
    /// Aggressive latency bias for time-to-first-token. Voice turns
    /// (specifically the fast track in the two-track sidecar pattern)
    /// pick this. Quality and cost are heavily downweighted; the
    /// router prefers whichever model produces a first token soonest.
    /// On macOS 26+ this typically resolves to `apple/foundation:default`
    /// via the Foundation Models system-LLM bonus. Reached via the
    /// `IntentHint::prefer_fast` flag (or `RoutingWorkload::Fastest`
    /// directly when callers know they want it).
    Fastest,
}

impl RoutingWorkload {
    pub fn is_latency_sensitive(self) -> bool {
        matches!(
            self,
            RoutingWorkload::Interactive
                | RoutingWorkload::LocalPreferred
                | RoutingWorkload::Fastest,
        )
    }

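    /// Score-weight profile as `(quality, latency, cost)`; every
    /// variant sums to 1.0. A minimal sketch of how a router might
    /// fold these into one score (the `quality`, `latency_norm`,
    /// `cost_norm`, and `is_local` inputs here are illustrative
    /// stand-ins, not the actual fields used at `adaptive_router.rs:892`):
    ///
    /// ```ignore
    /// let w = RoutingWorkload::Fastest;
    /// let (q_w, l_w, c_w) = w.weights();
    /// // Inputs normalized to [0, 1] with higher = better, so the
    /// // latency and cost terms are inverted before weighting.
    /// let score = q_w * quality
    ///     + l_w * (1.0 - latency_norm)
    ///     + c_w * (1.0 - cost_norm)
    ///     + if is_local { w.local_bonus() } else { 0.0 };
    /// ```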
    pub fn weights(self) -> (f64, f64, f64) {
        // Tuple is `(quality, latency, cost)` — destructured at the
        // single use site `adaptive_router.rs:892`.
        match self {
            RoutingWorkload::Interactive => (0.45, 0.40, 0.15),
            RoutingWorkload::Batch => (0.60, 0.15, 0.25),
            RoutingWorkload::Background => (0.65, 0.05, 0.30),
            // Quality-aware (closer to Interactive) but tolerant of
            // some latency hit since the caller chose local. Cost
            // weight matches Batch.
            RoutingWorkload::LocalPreferred => (0.55, 0.20, 0.25),
            // Voice fast track: latency is everything. Quality and
            // cost are deliberately near-floor — first audio in
            // <500ms beats any quality gain that takes another
            // round-trip. Sums to 1.0 like every other variant.
            RoutingWorkload::Fastest => (0.10, 0.85, 0.05),
        }
    }

    pub fn local_bonus(self) -> f64 {
        match self {
            RoutingWorkload::Interactive => 0.0,
            RoutingWorkload::Batch => 0.08,
            RoutingWorkload::Background => 0.15,
            // Stronger push than Background — "prefer local" should
            // win ties decisively, otherwise the hint is ineffective.
            RoutingWorkload::LocalPreferred => 0.20,
            // Local inference avoids network round-trips entirely —
            // the single biggest latency win available. On macOS the
            // Foundation Models `system_llm_bonus` stacks on top of
            // this for `apple/foundation:default`. Match
            // LocalPreferred's bonus so cloud-streamed fast paths
            // (e.g. gpt-4o-mini) can still win when there's no local
            // model loaded; the weight profile already strongly
            // favours latency.
            RoutingWorkload::Fastest => 0.20,
        }
    }
}

/// Qwen3 hybrid thinking control. Qwen3 models were trained with both a
/// "thinking" (chain-of-thought inside `<think>...</think>`) and a
/// non-thinking mode. Upstream defaults thinking ON; `/no_think` and
/// `/think` are the documented per-turn overrides in the chat template.
///
/// Scope: applies to the *single-turn* local Qwen3 path driven by
/// [`apply_chat_template`]. The multi-turn `messages: Vec<Message>`
/// field on [`GenerateRequest`] is consumed by remote protocol
/// handlers (OpenAI/Anthropic/Google), which pass through user-supplied
/// system messages verbatim; this flag is not injected there. If you
/// need Qwen3 thinking control over a remote API, include `/think` or
/// `/no_think` explicitly in your own system message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingMode {
    /// Let the model decide. No explicit `/think` or `/no_think` directive
    /// is injected into the system prompt, so Qwen3's trained default
    /// (thinking ON) applies. `<think>...</think>` output is stripped
    /// from the returned text.
    #[default]
    Auto,
    /// Inject `/think` into the system prompt to explicitly request the
    /// thinking phase. Useful when callers want to force reasoning even
    /// on short prompts the model would normally answer directly.
    On,
    /// Inject `/no_think` into the system prompt to suppress the
    /// thinking phase for faster, more direct responses. This was the
    /// prior hard-coded behavior; callers now opt into it explicitly.
    Off,
}

impl ThinkingMode {
    /// Return the directive marker to append to the system prompt, or
    /// `None` when `Auto` (don't inject anything — trust the model default).
    pub fn directive(self) -> Option<&'static str> {
        match self {
            ThinkingMode::Auto => None,
            ThinkingMode::On => Some("/think"),
            ThinkingMode::Off => Some("/no_think"),
        }
    }
}

/// Parameters controlling generation behavior.
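///
/// Every field is serde-defaulted, so a partial JSON body deserializes
/// cleanly. A small sketch (hedged: assumes `serde_json`, which the
/// test modules below already use):
///
/// ```ignore
/// let p: GenerateParams = serde_json::from_str(r#"{"temperature": 0.2}"#).unwrap();
/// assert_eq!(p.max_tokens, 4096);             // default_max_tokens()
/// assert_eq!(p.thinking, ThinkingMode::Auto); // ThinkingMode::default()
/// ```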
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateParams {
    /// Sampling temperature (0.0 = greedy, 1.0 = full distribution).
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Top-p (nucleus) sampling threshold.
    #[serde(default = "default_top_p")]
    pub top_p: f64,
    /// Top-k sampling (0 = disabled).
    #[serde(default)]
    pub top_k: usize,
    /// Maximum tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Stop sequences — generation halts when any is produced.
    #[serde(default)]
    pub stop: Vec<String>,
    /// Extended thinking budget (tokens). When > 0, enables the model's
    /// internal reasoning/planning phase before responding. Only supported
    /// by models with the ExtendedThinking capability (e.g., Claude).
    #[serde(default)]
    pub budget_tokens: usize,
    /// Routing workload class. Interactive requests bias toward lower latency,
    /// while batch/background work can tolerate slower high-quality local models.
    #[serde(default)]
    pub workload: RoutingWorkload,
    /// Tool choice mode: "auto" (default when tools present), "required" (must use a tool),
    /// "none" (disable tools). When "required", the model must respond with a tool call,
    /// eliminating mixed text+JSON responses.
    #[serde(default)]
    pub tool_choice: Option<String>,
    /// OpenAI-compatible parallel tool call control.
    #[serde(default)]
    pub parallel_tool_calls: Option<bool>,
    /// Qwen3 hybrid thinking mode control. `Auto` (default) leaves the
    /// model at its trained default (thinking on). `On`/`Off` inject
    /// the documented `/think` or `/no_think` directive into the chat
    /// template. Ignored by non-Qwen3 models.
    #[serde(default)]
    pub thinking: ThinkingMode,
}

fn default_temperature() -> f64 {
    0.7
}
fn default_top_p() -> f64 {
    0.9
}
fn default_max_tokens() -> usize {
    4096
}

impl Default for GenerateParams {
    fn default() -> Self {
        Self {
            temperature: default_temperature(),
            top_p: default_top_p(),
            top_k: 0,
            max_tokens: default_max_tokens(),
            stop: Vec::new(),
            budget_tokens: 0,
            workload: RoutingWorkload::Interactive,
            tool_choice: None,
            parallel_tool_calls: None,
            thinking: ThinkingMode::default(),
        }
    }
}

/// A tool call returned by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned tool call ID (e.g. OpenAI `call_abc123`, Anthropic `toolu_abc123`).
    /// When present, protocol handlers use this for round-trip correlation instead of
    /// synthesizing positional IDs like `call_0`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Tool/function name.
    pub name: String,
    /// Arguments as key-value pairs.
    pub arguments: std::collections::HashMap<String, serde_json::Value>,
}

/// A content block in a multimodal message.
///
/// The image variants (`ImageBase64`, `ImageUrl`) are fully wired on
/// the native Qwen2.5-VL backend. The video variants
/// (`VideoPath`, `VideoUrl`, `VideoBase64`) are defined on the public
/// request surface so higher-level tooling can express Qwen2.5-VL
/// video-understanding payloads, but the native backend returns
/// [`crate::InferenceError::UnsupportedMode`] for them until the
/// video-tokenization path lands. Remote multimodal providers
/// (Anthropic, Google Vertex) accept them through the protocol
/// handlers today.
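///
/// Wire shape: internally tagged on `"type"` with snake_case variant
/// names. A hedged sketch (`detail` falls back to its serde default):
///
/// ```ignore
/// let block: ContentBlock = serde_json::from_value(serde_json::json!({
///     "type": "image_url",
///     "url": "https://example.com/cat.png"
/// })).unwrap();
/// // `detail` was omitted, so default_detail() fills in "auto".
/// ```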
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text content.
    Text { text: String },
    /// Base64-encoded image.
    ImageBase64 {
        /// Base64-encoded image data.
        data: String,
        /// MIME type (e.g., "image/png", "image/jpeg").
        media_type: String,
    },
    /// Image from URL.
    ImageUrl {
        /// URL of the image.
        url: String,
        /// Detail level for image processing ("auto", "low", "high").
        #[serde(default = "default_detail")]
        detail: String,
    },
    /// Video loaded from a local filesystem path. Qwen2.5-VL samples
    /// the clip at `fps` frames/sec (default: backend-chosen) and
    /// caps at `max_frames` to respect context budgets.
    VideoPath {
        path: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Video accessible over HTTP(S). Semantics as [`ContentBlock::VideoPath`].
    VideoUrl {
        url: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Base64-encoded video bytes. Prefer `VideoPath` when possible;
    /// inline base64 is expensive to round-trip.
    VideoBase64 {
        data: String,
        media_type: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Audio loaded from a local filesystem path. Used for
    /// audio-understanding models (Gemma 4 small variants, Gemini).
    AudioPath {
        path: String,
        /// Optional explicit sample-rate hint. Most backends will
        /// resample internally; this is a best-effort declaration.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
    /// Audio accessible over HTTP(S).
    AudioUrl {
        url: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
    /// Base64-encoded audio bytes.
    AudioBase64 {
        data: String,
        media_type: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
}

impl ContentBlock {
    /// Return true if this block carries video data (any encoding).
    /// Used by backends that need to refuse video inputs until the
    /// tokenization path is wired.
    pub fn is_video(&self) -> bool {
        matches!(
            self,
            ContentBlock::VideoPath { .. }
                | ContentBlock::VideoUrl { .. }
                | ContentBlock::VideoBase64 { .. }
        )
    }

    /// Return true if this block carries audio data (any encoding).
    /// Used by backends that need to refuse audio inputs until the
    /// tokenization path is wired. Gemma 4 small variants and Gemini
    /// accept audio; everything else in CAR rejects with
    /// `UnsupportedMode`.
    pub fn is_audio(&self) -> bool {
        matches!(
            self,
            ContentBlock::AudioPath { .. }
                | ContentBlock::AudioUrl { .. }
                | ContentBlock::AudioBase64 { .. }
        )
    }
}

fn default_detail() -> String {
    "auto".to_string()
}

/// A message in a multi-turn conversation.
///
/// The `System` variant exists so callers can express a first-class
/// system prompt inside `messages: Vec<Message>` without threading it
/// through the legacy `context: Option<String>` field on
/// [`GenerateRequest`]. Protocol handlers and local chat templates
/// that have a native system-role slot (OpenAI, Anthropic, Gemini,
/// Gemma 4, Qwen) emit it in the right place; ones that don't can
/// fold it into the first user turn.
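///
/// Wire shape: internally tagged on `"role"` with snake_case variant
/// names, per the serde attributes below. A hedged sketch:
///
/// ```ignore
/// let msg: Message = serde_json::from_value(serde_json::json!({
///     "role": "tool_result",
///     "tool_use_id": "toolu_abc123",
///     "content": "42"
/// })).unwrap();
/// ```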
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum Message {
    /// A system prompt. Appears once, at the start of the conversation.
    System { content: String },
    /// A user message (text only).
    User { content: String },
    /// A user message with multimodal content (text + images + video + audio).
    UserMultimodal { content: Vec<ContentBlock> },
    /// An assistant response, possibly with tool calls.
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
    },
    /// The result of executing a tool call.
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    /// Provider-specific output items that need to round-trip
    /// verbatim across turns. The OpenAI Responses API returns
    /// reasoning blobs, encrypted_content, web-search results, etc.
    /// as opaque structured items; the next request must include
    /// them in the same form to preserve provider-side state.
    ///
    /// `protocol` identifies the provider format that produced the
    /// items (currently `"openai-responses"`). Builder paths that
    /// don't recognize the protocol drop the variant — there is no
    /// portable rendering across providers.
    ProviderOutputItems {
        protocol: String,
        items: Vec<serde_json::Value>,
    },
}

/// Constraint on the model's response shape. Distinct from `tools` —
/// tools are a side-channel for action invocation; `response_format`
/// constrains the *primary* text output to be parseable JSON, optionally
/// against a caller-supplied schema.
///
/// Provider mapping (handled in `protocol.rs`):
/// * **OpenAI / Azure / OpenAI-compatible**: `response_format: {type: "json_schema", json_schema: {schema, strict, name}}`
///   (or `{type: "json_object"}` for the looser variant). Strict mode rejects
///   any deviation from the schema.
/// * **Google (Gemini)**: `response_mime_type: "application/json"` plus
///   optional `response_schema`.
/// * **Anthropic**: no native field as of early 2026 — the schema is
///   logged at `warn` level and dropped. Callers needing schema-validated
///   output on Claude should fall back to the `tools` + `tool_choice="required"`
///   coercion idiom (which is what worked before this field landed).
///
/// `JsonObject` is the looser variant — tells the provider "emit valid
/// JSON, no schema check required". Use when the schema is too dynamic
/// to spell out but the parse contract still matters.
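///
/// A construction sketch (hedged: the schema and `name` shown are
/// illustrative, not taken from the codebase):
///
/// ```ignore
/// let fmt = ResponseFormat::JsonSchema {
///     schema: serde_json::json!({
///         "type": "object",
///         "properties": { "answer": { "type": "string" } },
///         "required": ["answer"]
///     }),
///     strict: true,
///     name: Some("answer_payload".into()),
/// };
/// ```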
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// JSON output validated against the provided schema. `strict: true`
    /// asks the provider to reject any deviation; `false` makes the
    /// schema a best-effort hint. `name` is OpenAI-specific (the
    /// `json_schema.name` field, max 64 chars, alphanumerics + `-_`);
    /// other providers ignore it.
    JsonSchema {
        schema: serde_json::Value,
        #[serde(default)]
        strict: bool,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Plain JSON-mode output — the provider emits valid JSON without
    /// schema enforcement.
    JsonObject,
}

/// A text generation request.
///
/// `Default` is derived so call sites can mutate just the fields
/// they care about: `GenerateRequest { prompt: "...".into(), ..Default::default() }`.
/// The default `prompt = ""` is useless on its own — callers always
/// override it — but the `..Default::default()` shorthand stops the
/// per-call-site mechanical churn every time a new optional field
/// lands (closes #109).
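///
/// A multi-turn sketch (hedged; the message contents are illustrative):
///
/// ```ignore
/// let req = GenerateRequest {
///     messages: Some(vec![
///         Message::System { content: "Be terse.".into() },
///         Message::User { content: "What is 2 + 2?".into() },
///     ]),
///     ..Default::default()
/// };
/// // `prompt` stays "" and is ignored because `messages` is present.
/// ```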
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerateRequest {
    /// The prompt to complete (first user message for single-turn).
    pub prompt: String,
    /// Optional model override.
    pub model: Option<String>,
    /// Generation parameters.
    #[serde(default)]
    pub params: GenerateParams,
    /// Optional memory context to prepend to the prompt.
    /// When provided, this is injected as a system-level context block
    /// before the user prompt, grounding the model's response.
    #[serde(default)]
    pub context: Option<String>,
    /// Optional tool definitions for structured tool_use.
    /// When provided, the model may return tool_calls instead of text.
    /// Each tool is a JSON object with: name, description, parameters (JSON Schema).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<serde_json::Value>>,
    /// Optional images for vision models.
    /// When provided with a single-turn prompt, these are included as image content blocks
    /// in the user message. For multi-turn with `messages`, use `UserMultimodal` variants instead.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<ContentBlock>>,
    /// Optional multi-turn conversation history.
    /// When provided, the backend builds a proper multi-turn message array
    /// instead of a single user message. The `prompt` field is ignored when
    /// messages are present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub messages: Option<Vec<Message>>,
    /// Enable prompt caching for the Anthropic API.
    /// When true, the system prompt and tools are marked with cache_control breakpoints,
    /// enabling cache reuse across parent/child agent calls sharing the same prefix.
    #[serde(default)]
    pub cache_control: bool,
    /// Constrain output to JSON (optionally schema-validated). See
    /// [`ResponseFormat`] for the per-provider mapping. Defaults to
    /// `None` — free-form text.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Caller-supplied routing intent. `None` preserves the existing
    /// adaptive vs. pinned-model behavior. When `Some`, the adaptive
    /// router uses the hint to filter candidates (hard `require`),
    /// override task selection, and bias the score profile
    /// (`prefer_local`). See [`crate::intent::IntentHint`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub intent: Option<crate::intent::IntentHint>,
}

/// Wrap a raw prompt in Qwen3 chat format if it's not already formatted.
///
/// Thinking behavior follows the caller-supplied [`ThinkingMode`]:
/// * `Auto` — no directive injected; Qwen3's trained default (thinking
///   on) applies.
/// * `On` — the documented `/think` directive is appended to the system
///   message on its own line, and the model is allowed to emit a full
///   `<think>...</think>` block before its answer.
/// * `Off` — `/no_think` is appended to the system message *and* an
///   empty `<think>\n\n</think>` block is pre-filled after the assistant
///   marker, matching upstream Qwen3's `enable_thinking=False` jinja
///   template. The pre-filled closed tags structurally prevent the
///   model from emitting a thinking block even if the directive is
///   contradicted later in the prompt.
///
/// When context is provided it is injected into the system message to
/// ground the model's response with memory. The directive always
/// appears *after* the context blob so user-supplied memory cannot
/// nudge the directive's parse position.
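///
/// For concreteness, the exact `Off`-mode rendering for prompt `"hi"`
/// with no context (derived directly from the format strings below):
///
/// ```text
/// <|im_start|>system
/// You are a helpful assistant.
/// /no_think<|im_end|>
/// <|im_start|>user
/// hi<|im_end|>
/// <|im_start|>assistant
/// <think>
///
/// </think>
///
/// ```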
pub fn apply_chat_template(prompt: &str, context: Option<&str>, thinking: ThinkingMode) -> String {
    if prompt.contains("<|im_start|>") {
        return prompt.to_string();
    }
    // Directive goes on its own line at the end of the system message
    // (never concatenated onto prose) so Qwen3's chat template parser
    // sees `/think`/`/no_think` as a standalone token, not as part of
    // "assistant. /no_think".
    let directive_line = match thinking.directive() {
        Some(d) => format!("\n{d}"),
        None => String::new(),
    };
    // For Off, pre-fill a closed empty thinking block after the
    // assistant marker. This mirrors upstream Qwen3's jinja behavior
    // when `enable_thinking=False` and is the hard-switch (structural)
    // form of the mode, whereas `/no_think` alone is a soft directive.
    let thinking_prefill = match thinking {
        ThinkingMode::Off => "<think>\n\n</think>\n\n",
        _ => "",
    };
    match context {
        Some(ctx) => format!(
            "<|im_start|>system\nYou are a helpful assistant. Use the following context to inform your response.\n\n{ctx}{directive_line}<|im_end|>\n\
             <|im_start|>user\n{prompt}<|im_end|>\n\
             <|im_start|>assistant\n{thinking_prefill}"
        ),
        None => format!(
            "<|im_start|>system\nYou are a helpful assistant.{directive_line}<|im_end|>\n\
             <|im_start|>user\n{prompt}<|im_end|>\n\
             <|im_start|>assistant\n{thinking_prefill}"
        ),
    }
}

/// Strip Qwen3 `<think>...</think>` blocks from model output, honoring
/// the caller's requested [`ThinkingMode`]:
///
/// * `On` — the caller explicitly asked for reasoning; return the raw
///   text verbatim so `<think>...</think>` is visible.
/// * `Auto` / `Off` — strip the thinking block and return only the
///   post-thinking answer. If the output contains an opening `<think>`
///   without a closing tag (truncation or stop before the model
///   finished thinking), return an empty string rather than leaking a
///   dangling tag to the caller.
pub fn strip_thinking(text: &str, thinking: ThinkingMode) -> String {
    if matches!(thinking, ThinkingMode::On) {
        return text.to_string();
    }
    strip_thinking_block(text)
}

/// Remove a leading `<think>...</think>` block unconditionally.
/// Returns "" if `<think>` opens but never closes (incomplete output).
///
/// When that "opened but never closed" branch fires, log a warn line
/// — the caller is about to receive an empty string for what was
/// almost certainly a budget-truncation. Surfaces issue #168's root
/// cause without changing the return contract: callers (e.g. car-cli)
/// that look at stderr can tell users to either bump
/// `--max-tokens` or pass `--thinking off`. The decision lives in
/// the strip helper because every text-completion path funnels
/// through it; logging at the call sites would be a lot of
/// duplication.
fn strip_thinking_block(text: &str) -> String {
    if let Some(end) = text.find("</think>") {
        text[end + "</think>".len()..].trim_start().to_string()
    } else if text.contains("<think>") {
        tracing::warn!(
            target: "car_inference::tasks::generate",
            raw_len = text.len(),
            "model output opened <think> but never closed it — \
             likely truncated by max_tokens; returning empty text. \
             Increase max_tokens, or set thinking=off to suppress \
             the reasoning phase."
        );
        String::new()
    } else {
        text.to_string()
    }
}

/// Callback for FLARE-style re-retrieval during generation.
/// Called with partial generation text; returns additional context or None.
pub type RetrievalCallback = Box<dyn Fn(&str) -> Option<String> + Send>;

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
/// Generate text from a prompt using the loaded model.
///
/// Returns `(text, time_to_first_token_ms)`. TTFT is measured from
/// function entry through prefill to the moment the first generated
/// token has been sampled — the user-visible "did anything happen yet"
/// gate. `None` only when the prompt encodes to zero tokens (degenerate
/// input).
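///
/// A call-shape sketch (hedged; assumes a loaded `CandleBackend` is in
/// scope as `backend`):
///
/// ```ignore
/// let req = GenerateRequest { prompt: "Why is the sky blue?".into(), ..Default::default() };
/// let (text, ttft_ms) = generate(&mut backend, req).await?;
/// ```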
pub async fn generate(
    backend: &mut CandleBackend,
    req: GenerateRequest,
) -> Result<(String, Option<u64>), InferenceError> {
    let start = std::time::Instant::now();

    // Reset KV cache so each generation starts fresh (prevents cross-call state bleed)
    backend.clear_kv_cache();

    let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
    let tokens = backend.encode(&formatted)?;
    let eos = backend.eos_token_id();
    let eos_alt = backend.token_id("<|im_end|>");
    let params = &req.params;

    if tokens.is_empty() {
        return Ok((String::new(), None));
    }

    // Truncate to the model's max context length minus generation headroom.
    // This prevents KV cache overflow on long prompts.
    let max_ctx = backend.context_length().unwrap_or(32768);
    let headroom = params.max_tokens.min(max_ctx / 4);
    let max_prompt = max_ctx.saturating_sub(headroom);
    let tokens = if tokens.len() > max_prompt {
        eprintln!(
            "[car-inference] truncating prompt from {} to {} tokens (context_length={})",
            tokens.len(),
            max_prompt,
            max_ctx
        );
        tokens[tokens.len() - max_prompt..].to_vec()
    } else {
        tokens
    };

    let mut generated = Vec::new();

    // Prefill: process all prompt tokens, sample first generated token from prefill logits
    let logits = backend.forward(&tokens, 0)?;
    let mut next_token = sample_token(&logits, params)?;
    let ttft_ms = Some(start.elapsed().as_millis() as u64);

    for _i in 0..params.max_tokens {
        // Check EOS
        if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
        {
            break;
        }

        generated.push(next_token);

        // Check stop sequences
        if !params.stop.is_empty() {
            let text_so_far = backend.decode(&generated)?;
            if params.stop.iter().any(|s| text_so_far.contains(s)) {
                break;
            }
        }

        // Generate next token
        let pos = tokens.len() + generated.len() - 1;
        let logits = backend.forward(&[next_token], pos)?;
        next_token = sample_token(&logits, params)?;
    }

    let text = backend.decode(&generated)?;
    Ok((strip_thinking(&text, params.thinking), ttft_ms))
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
/// Generate with FLARE-style confidence-triggered re-retrieval.
///
/// Monitors token logit confidence during generation. When a window of
/// low-confidence tokens is detected, pauses, re-queries memory with the
/// partial generation, and resumes with augmented context.
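///
/// A callback sketch (hedged; `memory` and its `search` method are
/// illustrative stand-ins, not a real CAR API):
///
/// ```ignore
/// let cb: RetrievalCallback = Box::new(move |partial| {
///     // Re-query memory with the partial generation; None = keep going.
///     memory.search(partial).map(|hits| hits.join("\n"))
/// });
/// let text = generate_with_retrieval(&mut backend, req, cb).await?;
/// ```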
pub async fn generate_with_retrieval(
    backend: &mut CandleBackend,
    mut req: GenerateRequest,
    retrieval_cb: RetrievalCallback,
) -> Result<String, InferenceError> {
    // First pass: generate normally
    backend.clear_kv_cache();
    let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
    // Mutable: re-assigned when a re-retrieval restarts generation, so
    // the position bookkeeping below tracks the current prompt length.
    let mut tokens = backend.encode(&formatted)?;
    let eos = backend.eos_token_id();
    let eos_alt = backend.token_id("<|im_end|>");
    let params = req.params.clone();

    if tokens.is_empty() {
        return Ok(String::new());
    }

    let mut generated = Vec::new();
    let mut low_confidence_count = 0u32;
    let mut retrieval_attempts = 0u32;
    let max_retrievals = 2;
    let confidence_threshold = 0.4f32;
    let low_confidence_window = 3u32;

    let logits = backend.forward(&tokens, 0)?;
    let mut next_token = sample_token(&logits, &params)?;

    for _i in 0..params.max_tokens {
        if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
        {
            break;
        }

        generated.push(next_token);

        // Generate next token and check confidence
        let pos = tokens.len() + generated.len() - 1;
        let logits = backend.forward(&[next_token], pos)?;

        // Check max logit probability for confidence
        let logits_f32: Vec<f32> = logits
            .squeeze(0)
            .unwrap_or(logits.clone())
            .to_dtype(candle_core::DType::F32)
            .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
            .to_vec1()
            .unwrap_or_default();

        if !logits_f32.is_empty() {
            // Compute softmax max probability
            let max_logit = logits_f32.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exp_sum: f32 = logits_f32.iter().map(|&v| (v - max_logit).exp()).sum();
            let max_prob = 1.0 / exp_sum; // probability of the top token

            if max_prob < confidence_threshold {
                low_confidence_count += 1;
            } else {
                low_confidence_count = 0;
            }

            // Trigger re-retrieval after sustained low confidence
            if low_confidence_count >= low_confidence_window && retrieval_attempts < max_retrievals
            {
                retrieval_attempts += 1;
                low_confidence_count = 0;

                // Use partial generation as re-retrieval query
                let partial = backend.decode(&generated)?;
                if let Some(new_context) = retrieval_cb(&partial) {
                    // Restart generation with augmented context
                    let combined_context = match req.context.take() {
                        Some(old) => format!("{}\n\n{}", old, new_context),
                        None => new_context,
                    };
                    req.context = Some(combined_context);

                    // Re-encode and restart. Re-assign `tokens` so the
                    // `pos` computed above stays aligned with the new,
                    // longer prompt on subsequent iterations.
                    backend.clear_kv_cache();
                    let new_formatted = apply_chat_template(
                        &req.prompt,
                        req.context.as_deref(),
                        req.params.thinking,
                    );
                    tokens = backend.encode(&new_formatted)?;
                    generated.clear();

                    let logits = backend.forward(&tokens, 0)?;
                    next_token = sample_token(&logits, &params)?;
                    continue;
                }
            }
        }

        next_token = sample_token(&logits, &params)?;
    }

    let text = backend.decode(&generated)?;
    Ok(strip_thinking(&text, params.thinking))
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
/// Sample a token, suppressing specific token IDs (set to -inf before sampling).
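///
/// A usage sketch (hedged; assumes `logits` from a prior `forward` call):
///
/// ```ignore
/// // Forbid the chat end-of-turn token so generation keeps going:
/// let banned: Vec<u32> = backend.token_id("<|im_end|>").into_iter().collect();
/// let tok = sample_token_suppress(&logits, &params, &banned)?;
/// ```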
pub fn sample_token_suppress(
    logits: &Tensor,
    params: &GenerateParams,
    suppress: &[u32],
) -> Result<u32, InferenceError> {
    if suppress.is_empty() {
        return sample_token(logits, params);
    }
    // Clone the logits and set suppressed tokens to -inf. Flatten to 1D
    // first: `to_vec1` errors on rank-2 tensors, which would defeat the
    // last-row handling below.
    let mut logits_vec: Vec<f32> = logits
        .to_dtype(candle_core::DType::F32)
        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
        .flatten_all()
        .map_err(|e| InferenceError::InferenceFailed(format!("flatten: {e}")))?
        .to_vec1()
        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
    // Handle 2D logits (take last row)
    let dims = logits.dims();
    if dims.len() == 2 {
        let vocab = dims[dims.len() - 1];
        let start = logits_vec.len() - vocab;
        logits_vec = logits_vec[start..].to_vec();
    }
    for &id in suppress {
        if (id as usize) < logits_vec.len() {
            logits_vec[id as usize] = f32::NEG_INFINITY;
        }
    }
    // Rebuild as a 1D tensor: after the last-row extraction above the
    // vec holds exactly one row, so reusing the input's (possibly 2D)
    // shape would mismatch the data length.
    let vocab_len = logits_vec.len();
    let modified = Tensor::from_vec(logits_vec, (vocab_len,), logits.device())
        .map_err(|e| InferenceError::InferenceFailed(format!("from_vec: {e}")))?;
    sample_token(&modified, params)
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
/// Sample a token from logits using temperature + top-p + top-k.
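///
/// A greedy-path sketch (hedged; assumes `Tensor` and `Device` imports
/// are in scope, and index 1 holds the max logit):
///
/// ```ignore
/// let logits = Tensor::new(&[0.1f32, 2.0, 0.3], &Device::Cpu)?;
/// let params = GenerateParams { temperature: 0.0, ..Default::default() };
/// assert_eq!(sample_token(&logits, &params)?, 1);
/// ```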
pub fn sample_token(logits: &Tensor, params: &GenerateParams) -> Result<u32, InferenceError> {
    let logits = logits
        .squeeze(0)
        .map_err(|e| InferenceError::InferenceFailed(format!("squeeze: {e}")))?;
    let logits = logits
        .to_dtype(candle_core::DType::F32)
        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?;

    // Get last position's logits
    let dim = logits.dims();
    let logits = if dim.len() == 2 {
        logits
            .get(dim[0] - 1)
            .map_err(|e| InferenceError::InferenceFailed(format!("get last: {e}")))?
    } else {
        logits
    };

    // Greedy decoding
    if params.temperature <= 0.0 {
        let token = logits
            .argmax(0)
            .map_err(|e| InferenceError::InferenceFailed(format!("argmax: {e}")))?
            .to_scalar::<u32>()
            .map_err(|e| InferenceError::InferenceFailed(format!("scalar: {e}")))?;
        return Ok(token);
    }

    // Temperature scaling
    let logits = (&logits / params.temperature)
        .map_err(|e| InferenceError::InferenceFailed(format!("temp scale: {e}")))?;

    let mut logits_vec: Vec<f32> = logits
        .to_vec1()
        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;

    // Top-k filtering: keep everything at least as large as the k-th
    // largest logit (indexing at `top_k` here instead of `top_k - 1`
    // would keep k+1 candidates, a small off-by-one).
    if params.top_k > 0 && params.top_k < logits_vec.len() {
        let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let threshold = indexed[params.top_k - 1].1;
        for v in &mut logits_vec {
            if *v < threshold {
                *v = f32::NEG_INFINITY;
            }
        }
    }

    // Softmax
    let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp: Vec<f32> = logits_vec.iter().map(|&v| (v - max_logit).exp()).collect();
    let sum: f32 = exp.iter().sum();
    let mut probs: Vec<f32> = exp.iter().map(|&v| v / sum).collect();

    // Top-p (nucleus) filtering
    if params.top_p < 1.0 {
        let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
        sorted_indices.sort_by(|&a, &b| {
            probs[b]
                .partial_cmp(&probs[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let mut cumsum = 0.0f32;
        let mut cutoff_idx = sorted_indices.len();
        for (i, &idx) in sorted_indices.iter().enumerate() {
            cumsum += probs[idx];
            if cumsum > params.top_p as f32 {
                cutoff_idx = i + 1;
                break;
            }
        }

        let keep: std::collections::HashSet<usize> =
            sorted_indices[..cutoff_idx].iter().copied().collect();
        for (i, p) in probs.iter_mut().enumerate() {
            if !keep.contains(&i) {
                *p = 0.0;
            }
        }

        // Renormalize
        let sum: f32 = probs.iter().sum();
        if sum > 0.0 {
            for p in &mut probs {
                *p /= sum;
            }
        }
    }

    // Categorical sample
    let r: f32 = rand_f32();
    let mut cumsum = 0.0f32;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if cumsum >= r {
            return Ok(i as u32);
        }
    }

    // Fallback: return the highest-probability token
    Ok(probs
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i as u32)
        .unwrap_or(0))
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
/// Random float in [0, 1) using the rand crate.
fn rand_f32() -> f32 {
    rand::random::<f32>()
}

#[cfg(test)]
mod thinking_tests {
    use super::*;

    #[test]
    fn auto_injects_no_directive_and_no_prefill() {
        let out = apply_chat_template("hi", None, ThinkingMode::Auto);
        assert!(!out.contains("/no_think"));
        assert!(!out.contains("/think"));
        assert!(!out.contains("<think>"));
        assert!(out.contains("<|im_start|>user\nhi<|im_end|>"));
    }

    #[test]
    fn off_injects_no_think_on_own_line_and_prefills_empty_think() {
        let out = apply_chat_template("hi", None, ThinkingMode::Off);
        // Directive on its own line, not concatenated onto prose.
        assert!(out.contains("\n/no_think<|im_end|>"));
        assert!(!out.contains(" /no_think"));
        // Closed empty thinking block pre-filled after assistant marker
        // — the upstream jinja hard-switch for enable_thinking=False.
        assert!(out.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
    }

    #[test]
    fn on_injects_think_and_no_prefill() {
        let out = apply_chat_template("hi", None, ThinkingMode::On);
        assert!(out.contains("\n/think<|im_end|>"));
        assert!(!out.contains("/no_think"));
        assert!(!out.contains("<think>"));
    }

    #[test]
    fn pre_formatted_prompt_is_untouched() {
        let pre = "<|im_start|>system\ncustom<|im_end|>\n<|im_start|>user\nhi<|im_end|>";
        let out = apply_chat_template(pre, None, ThinkingMode::Off);
        assert_eq!(out, pre);
    }

    #[test]
    fn directive_appears_after_context_not_before() {
        let out = apply_chat_template("q?", Some("some memory"), ThinkingMode::Off);
        let ctx_idx = out.find("some memory").unwrap();
        let directive_idx = out.find("/no_think").unwrap();
        assert!(
            directive_idx > ctx_idx,
            "directive must appear after context so user memory cannot nudge the parse"
        );
    }

    #[test]
    fn default_params_is_auto() {
        assert_eq!(GenerateParams::default().thinking, ThinkingMode::Auto);
    }

    #[test]
    fn thinking_mode_serde_snake_case() {
        let json = serde_json::to_string(&ThinkingMode::Off).unwrap();
        assert_eq!(json, "\"off\"");
        let parsed: ThinkingMode = serde_json::from_str("\"on\"").unwrap();
        assert_eq!(parsed, ThinkingMode::On);
    }

    #[test]
    fn strip_preserves_thinking_when_on() {
        let text = "<think>reasoning here</think>the answer";
        let out = strip_thinking(text, ThinkingMode::On);
        assert_eq!(
            out, text,
            "On mode must return raw text with <think> visible"
        );
    }

    #[test]
    fn strip_removes_thinking_when_auto_or_off() {
        let text = "<think>reasoning</think>the answer";
        assert_eq!(strip_thinking(text, ThinkingMode::Auto), "the answer");
        assert_eq!(strip_thinking(text, ThinkingMode::Off), "the answer");
    }

    #[test]
    fn strip_returns_empty_on_unterminated_think() {
        // Output was cut off mid-thinking — don't leak the dangling tag.
        let text = "<think>mid-reasoning, never closed";
        assert_eq!(strip_thinking(text, ThinkingMode::Auto), "");
        assert_eq!(strip_thinking(text, ThinkingMode::Off), "");
        // On mode still returns the raw text — caller asked for it.
        assert_eq!(strip_thinking(text, ThinkingMode::On), text);
    }

    #[test]
    fn strip_is_noop_when_no_think_tag() {
        let text = "just a plain answer";
        assert_eq!(strip_thinking(text, ThinkingMode::Auto), text);
        assert_eq!(strip_thinking(text, ThinkingMode::Off), text);
        assert_eq!(strip_thinking(text, ThinkingMode::On), text);
    }
}

#[cfg(test)]
mod workload_tests {
    use super::*;

    #[test]
    fn all_workload_weights_sum_to_one() {
        for w in [
            RoutingWorkload::Interactive,
            RoutingWorkload::Batch,
            RoutingWorkload::Background,
            RoutingWorkload::LocalPreferred,
            RoutingWorkload::Fastest,
        ] {
            let (q, l, c) = w.weights();
            let sum = q + l + c;
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "weights for {w:?} sum to {sum}, expected 1.0"
            );
        }
    }

    #[test]
    fn fastest_weights_dominate_on_latency() {
        let (q, l, c) = RoutingWorkload::Fastest.weights();
        // Latency should be the largest by a wide margin — that's the
        // whole point of this workload class.
        assert!(l > q && l > c);
        assert!(l >= 0.7, "latency weight too small: {l}");
    }

    #[test]
    fn fastest_is_latency_sensitive() {
        assert!(RoutingWorkload::Fastest.is_latency_sensitive());
    }
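
    // Added check, grounded in the literal values in local_bonus().
    #[test]
    fn local_bonus_is_monotone_in_local_preference() {
        // Stronger local preference gets a strictly larger bonus, with
        // Fastest deliberately matching LocalPreferred, not exceeding it.
        assert!(RoutingWorkload::Interactive.local_bonus() < RoutingWorkload::Batch.local_bonus());
        assert!(RoutingWorkload::Batch.local_bonus() < RoutingWorkload::Background.local_bonus());
        assert!(
            RoutingWorkload::Background.local_bonus()
                < RoutingWorkload::LocalPreferred.local_bonus()
        );
        assert_eq!(
            RoutingWorkload::LocalPreferred.local_bonus(),
            RoutingWorkload::Fastest.local_bonus()
        );
    }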
}