Skip to main content

car_inference/tasks/
generate.rs

1//! Text generation with sampling.
2
3#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4use candle_core::Tensor;
5use serde::{Deserialize, Serialize};
6
7#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
8use crate::backend::CandleBackend;
9use crate::InferenceError;
10
11/// How latency-sensitive this generation request is.
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
13#[serde(rename_all = "snake_case")]
14pub enum RoutingWorkload {
15    /// User-facing, interactive request where latency matters.
16    #[default]
17    Interactive,
18    /// Batch job where latency matters somewhat, but quality/cost matter more.
19    Batch,
20    /// Background or offline work where latency is a weak concern.
21    Background,
22    /// Caller explicitly prefers on-device models. Distinct from
23    /// `Background` (which is "this is a background job, latency
24    /// barely matters"). The caller may be doing latency-sensitive
25    /// interactive work but wants the privacy / cost / offline
26    /// properties of local inference. Same `local_bonus` as
27    /// `Background` plus a slightly more quality-aware weight profile
28    /// — the caller chose local for a reason, not because the work is
29    /// throwaway.
30    LocalPreferred,
31    /// Aggressive latency bias for time-to-first-token. Voice turns
32    /// (specifically the fast track in the two-track sidecar pattern)
33    /// pick this. Quality and cost are heavily downweighted; the
34    /// router prefers whichever model produces a first token soonest.
35    /// On macOS 26+ this typically resolves to `apple/foundation:default`
36    /// via the Foundation Models system-LLM bonus. Reached via the
37    /// `IntentHint::prefer_fast` flag (or `RoutingWorkload::Fastest`
38    /// directly when callers know they want it).
39    Fastest,
40}
41
42impl RoutingWorkload {
43    pub fn is_latency_sensitive(self) -> bool {
44        matches!(
45            self,
46            RoutingWorkload::Interactive
47                | RoutingWorkload::LocalPreferred
48                | RoutingWorkload::Fastest,
49        )
50    }
51
52    pub fn weights(self) -> (f64, f64, f64) {
53        // Tuple is `(quality, latency, cost)` — destructured at the
54        // single use site `adaptive_router.rs:892`.
55        match self {
56            RoutingWorkload::Interactive => (0.45, 0.40, 0.15),
57            RoutingWorkload::Batch => (0.60, 0.15, 0.25),
58            RoutingWorkload::Background => (0.65, 0.05, 0.30),
59            // Quality-aware (closer to Interactive) but tolerant of
60            // some latency hit since the caller chose local. Cost
61            // weight matches Batch.
62            RoutingWorkload::LocalPreferred => (0.55, 0.20, 0.25),
63            // Voice fast track: latency is everything. Quality and
64            // cost are deliberately near-floor — first audio in
65            // <500ms beats any quality gain that takes another
66            // round-trip. Sums to 1.0 like every other variant.
67            RoutingWorkload::Fastest => (0.10, 0.85, 0.05),
68        }
69    }
70
71    pub fn local_bonus(self) -> f64 {
72        match self {
73            RoutingWorkload::Interactive => 0.0,
74            RoutingWorkload::Batch => 0.08,
75            RoutingWorkload::Background => 0.15,
76            // Stronger push than Background — "prefer local" should
77            // win ties decisively, otherwise the hint is ineffective.
78            RoutingWorkload::LocalPreferred => 0.20,
79            // Local inference avoids network round-trips entirely —
80            // the single biggest latency win available. On macOS the
81            // Foundation Models `system_llm_bonus` stacks on top of
82            // this for `apple/foundation:default`. Match
83            // LocalPreferred's bonus so cloud-streamed fast paths
84            // (e.g. gpt-4o-mini) can still win when locally there's
85            // no model loaded; the weight profile already strongly
86            // favours latency.
87            RoutingWorkload::Fastest => 0.20,
88        }
89    }
90}
91
92/// Qwen3 hybrid thinking control. Qwen3 models were trained with both a
93/// "thinking" (chain-of-thought inside `<think>...</think>`) and a
94/// non-thinking mode. Upstream defaults thinking ON; `/no_think` and
95/// `/think` are the documented per-turn overrides in the chat template.
96///
97/// Scope: applies to the *single-turn* local Qwen3 path driven by
98/// [`apply_chat_template`]. The multi-turn `messages: Vec<Message>`
99/// field on [`GenerateRequest`] is consumed by remote protocol
100/// handlers (OpenAI/Anthropic/Google) which pass through user-supplied
101/// system messages verbatim; this flag is not injected there. If you
102/// need Qwen3 thinking control over a remote API, include `/think` or
103/// `/no_think` explicitly in your own system message.
104#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
105#[serde(rename_all = "snake_case")]
106pub enum ThinkingMode {
107    /// Let the model decide. No explicit `/think` or `/no_think` directive
108    /// is injected into the system prompt, so Qwen3's trained default
109    /// (thinking ON) applies. `<think>...</think>` output is stripped
110    /// from the returned text.
111    #[default]
112    Auto,
113    /// Inject `/think` into the system prompt to explicitly request the
114    /// thinking phase. Useful when callers want to force reasoning even
115    /// on short prompts the model would normally answer directly.
116    On,
117    /// Inject `/no_think` into the system prompt to suppress the
118    /// thinking phase for faster, more direct responses. This was the
119    /// prior hard-coded behavior; callers now opt into it explicitly.
120    Off,
121}
122
123impl ThinkingMode {
124    /// Return the directive marker to append to the system prompt, or
125    /// `None` when `Auto` (don't inject anything — trust model default).
126    pub fn directive(self) -> Option<&'static str> {
127        match self {
128            ThinkingMode::Auto => None,
129            ThinkingMode::On => Some("/think"),
130            ThinkingMode::Off => Some("/no_think"),
131        }
132    }
133}
134
135/// Parameters controlling generation behavior.
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct GenerateParams {
138    /// Sampling temperature (0.0 = greedy, 1.0 = full distribution).
139    #[serde(default = "default_temperature")]
140    pub temperature: f64,
141    /// Top-p (nucleus) sampling threshold.
142    #[serde(default = "default_top_p")]
143    pub top_p: f64,
144    /// Top-k sampling (0 = disabled).
145    #[serde(default)]
146    pub top_k: usize,
147    /// Maximum tokens to generate.
148    #[serde(default = "default_max_tokens")]
149    pub max_tokens: usize,
150    /// Stop sequences — generation halts when any is produced.
151    #[serde(default)]
152    pub stop: Vec<String>,
153    /// Extended thinking budget (tokens). When > 0, enables the model's
154    /// internal reasoning/planning phase before responding. Only supported
155    /// by models with the ExtendedThinking capability (e.g., Claude).
156    #[serde(default)]
157    pub budget_tokens: usize,
158    /// Routing workload class. Interactive requests bias toward lower latency,
159    /// while batch/background work can tolerate slower high-quality local models.
160    #[serde(default)]
161    pub workload: RoutingWorkload,
162    /// Tool choice mode: "auto" (default when tools present), "required" (must use a tool),
163    /// "none" (disable tools). When "required", the model must respond with a tool call,
164    /// eliminating mixed text+JSON responses.
165    #[serde(default)]
166    pub tool_choice: Option<String>,
167    /// OpenAI-compatible parallel tool call control.
168    #[serde(default)]
169    pub parallel_tool_calls: Option<bool>,
170    /// Qwen3 hybrid thinking mode control. `Auto` (default) leaves the
171    /// model at its trained default (thinking on). `On`/`Off` inject
172    /// the documented `/think` or `/no_think` directive into the chat
173    /// template. Ignored by non-Qwen3 models.
174    #[serde(default)]
175    pub thinking: ThinkingMode,
176}
177
/// Serde / `Default` fallback for `GenerateParams::temperature`.
fn default_temperature() -> f64 {
    0.7_f64
}

/// Serde / `Default` fallback for `GenerateParams::top_p`.
fn default_top_p() -> f64 {
    0.9_f64
}

/// Serde / `Default` fallback for `GenerateParams::max_tokens`.
fn default_max_tokens() -> usize {
    4_096
}
187
188impl Default for GenerateParams {
189    fn default() -> Self {
190        Self {
191            temperature: default_temperature(),
192            top_p: default_top_p(),
193            top_k: 0,
194            max_tokens: default_max_tokens(),
195            stop: Vec::new(),
196            budget_tokens: 0,
197            workload: RoutingWorkload::Interactive,
198            tool_choice: None,
199            parallel_tool_calls: None,
200            thinking: ThinkingMode::default(),
201        }
202    }
203}
204
205/// A tool call returned by the model.
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct ToolCall {
208    /// Provider-assigned tool call ID (e.g. OpenAI `call_abc123`, Anthropic `toolu_abc123`).
209    /// When present, protocol handlers use this for round-trip correlation instead of
210    /// synthesizing positional IDs like `call_0`.
211    #[serde(default, skip_serializing_if = "Option::is_none")]
212    pub id: Option<String>,
213    /// Tool/function name.
214    pub name: String,
215    /// Arguments as key-value pairs.
216    pub arguments: std::collections::HashMap<String, serde_json::Value>,
217}
218
219/// A content block in a multimodal message.
220///
221/// The image variants (`ImageBase64`, `ImageUrl`) are fully wired on
222/// the native Qwen2.5-VL backend. The video variants
223/// (`VideoPath`, `VideoUrl`, `VideoBase64`) are defined on the public
224/// request surface so higher-level tooling can express Qwen2.5-VL
225/// video-understanding payloads, but the native backend returns
226/// [`crate::InferenceError::UnsupportedMode`] for them until the
227/// video-tokenization path lands. Remote multimodal providers
228/// (Anthropic, Google Vertex) accept them through the protocol
229/// handlers today.
230#[derive(Debug, Clone, Serialize, Deserialize)]
231#[serde(tag = "type", rename_all = "snake_case")]
232pub enum ContentBlock {
233    /// Plain text content.
234    Text { text: String },
235    /// Base64-encoded image.
236    ImageBase64 {
237        /// Base64-encoded image data.
238        data: String,
239        /// MIME type (e.g., "image/png", "image/jpeg").
240        media_type: String,
241    },
242    /// Image from URL.
243    ImageUrl {
244        /// URL of the image.
245        url: String,
246        /// Detail level for image processing ("auto", "low", "high").
247        #[serde(default = "default_detail")]
248        detail: String,
249    },
250    /// Video loaded from a local filesystem path. Qwen2.5-VL samples
251    /// the clip at `fps` frames/sec (default: backend-chosen) and
252    /// caps at `max_frames` to respect context budgets.
253    VideoPath {
254        path: String,
255        #[serde(default, skip_serializing_if = "Option::is_none")]
256        fps: Option<f32>,
257        #[serde(default, skip_serializing_if = "Option::is_none")]
258        max_frames: Option<u32>,
259    },
260    /// Video accessible over HTTP(S). Semantics as [`ContentBlock::VideoPath`].
261    VideoUrl {
262        url: String,
263        #[serde(default, skip_serializing_if = "Option::is_none")]
264        fps: Option<f32>,
265        #[serde(default, skip_serializing_if = "Option::is_none")]
266        max_frames: Option<u32>,
267    },
268    /// Base64-encoded video bytes. Prefer `VideoPath` when possible;
269    /// inline base64 is expensive to round-trip.
270    VideoBase64 {
271        data: String,
272        media_type: String,
273        #[serde(default, skip_serializing_if = "Option::is_none")]
274        fps: Option<f32>,
275        #[serde(default, skip_serializing_if = "Option::is_none")]
276        max_frames: Option<u32>,
277    },
278    /// Audio loaded from a local filesystem path. Used for
279    /// audio-understanding models (Gemma 4 small variants, Gemini).
280    AudioPath {
281        path: String,
282        /// Optional explicit sample-rate hint. Most backends will
283        /// resample internally; this is a best-effort declaration.
284        #[serde(default, skip_serializing_if = "Option::is_none")]
285        sample_rate: Option<u32>,
286    },
287    /// Audio accessible over HTTP(S).
288    AudioUrl {
289        url: String,
290        #[serde(default, skip_serializing_if = "Option::is_none")]
291        sample_rate: Option<u32>,
292    },
293    /// Base64-encoded audio bytes.
294    AudioBase64 {
295        data: String,
296        media_type: String,
297        #[serde(default, skip_serializing_if = "Option::is_none")]
298        sample_rate: Option<u32>,
299    },
300}
301
302impl ContentBlock {
303    /// Return true if this block carries video data (any encoding).
304    /// Used by backends that need to refuse video inputs until the
305    /// tokenization path is wired.
306    pub fn is_video(&self) -> bool {
307        matches!(
308            self,
309            ContentBlock::VideoPath { .. }
310                | ContentBlock::VideoUrl { .. }
311                | ContentBlock::VideoBase64 { .. }
312        )
313    }
314
315    /// Return true if this block carries audio data (any encoding).
316    /// Used by backends that need to refuse audio inputs until the
317    /// tokenization path is wired. Gemma 4 small variants and Gemini
318    /// accept audio; everything else in CAR rejects with
319    /// `UnsupportedMode`.
320    pub fn is_audio(&self) -> bool {
321        matches!(
322            self,
323            ContentBlock::AudioPath { .. }
324                | ContentBlock::AudioUrl { .. }
325                | ContentBlock::AudioBase64 { .. }
326        )
327    }
328}
329
/// Serde fallback for the `detail` field of `ContentBlock::ImageUrl`.
fn default_detail() -> String {
    String::from("auto")
}
333
334/// A message in a multi-turn conversation.
335///
336/// The `System` variant exists so callers can express a first-class
337/// system prompt inside `messages: Vec<Message>` without threading it
338/// through the legacy `context: Option<String>` field on
339/// [`GenerateRequest`]. Protocol handlers and local chat templates
340/// that have a native system-role slot (OpenAI, Anthropic, Gemini,
341/// Gemma 4, Qwen) emit it in the right place; ones that don't can
342/// fold it into the first user turn.
343#[derive(Debug, Clone, Serialize, Deserialize)]
344#[serde(tag = "role", rename_all = "snake_case")]
345pub enum Message {
346    /// A system prompt. Appears once, at the start of the conversation.
347    System { content: String },
348    /// A user message (text only).
349    User { content: String },
350    /// A user message with multimodal content (text + images + video + audio).
351    UserMultimodal { content: Vec<ContentBlock> },
352    /// An assistant response, possibly with tool calls.
353    Assistant {
354        #[serde(default)]
355        content: String,
356        #[serde(default)]
357        tool_calls: Vec<ToolCall>,
358    },
359    /// The result of executing a tool call.
360    ToolResult {
361        tool_use_id: String,
362        content: String,
363    },
364    /// Provider-specific output items that need to round-trip
365    /// verbatim across turns. The OpenAI Responses API returns
366    /// reasoning blobs, encrypted_content, web-search results, etc.
367    /// as opaque structured items; the next request must include
368    /// them in the same form to preserve provider-side state.
369    ///
370    /// `protocol` identifies the provider format that produced the
371    /// items (currently `"openai-responses"`). Builder paths that
372    /// don't recognize the protocol drop the variant — there is no
373    /// portable rendering across providers.
374    ProviderOutputItems {
375        protocol: String,
376        items: Vec<serde_json::Value>,
377    },
378}
379
380/// Constraint on the model's response shape. Distinct from `tools` —
381/// tools are a side-channel for action invocation; `response_format`
382/// constrains the *primary* text output to be parseable JSON, optionally
383/// against a caller-supplied schema.
384///
385/// Provider mapping (handled in `protocol.rs`):
386/// * **OpenAI / Azure / OpenAI-compatible**: `response_format: {type: "json_schema", json_schema: {schema, strict, name}}`
387///   (or `{type: "json_object"}` for the looser variant). Strict mode rejects
388///   any deviation from the schema.
389/// * **Google (Gemini)**: `response_mime_type: "application/json"` plus
390///   optional `response_schema`.
391/// * **Anthropic**: no native field as of early 2026 — the schema is
392///   logged at `warn` level and dropped. Callers needing schema-validated
393///   output on Claude should fall back to the `tools` + `tool_choice="required"`
394///   coercion idiom (which is what worked before this field landed).
395///
396/// `JsonObject` is the looser variant — tells the provider "emit valid
397/// JSON, no schema check required". Use when the schema is too dynamic
398/// to spell out but the parse contract still matters.
399#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
400#[serde(tag = "type", rename_all = "snake_case")]
401pub enum ResponseFormat {
402    /// JSON output validated against the provided schema. `strict: true`
403    /// asks the provider to reject any deviation; `false` makes the
404    /// schema a best-effort hint. `name` is OpenAI-specific (the
405    /// `json_schema.name` field, max 64 chars, alphanumerics + `-_`);
406    /// other providers ignore it.
407    JsonSchema {
408        schema: serde_json::Value,
409        #[serde(default)]
410        strict: bool,
411        #[serde(default, skip_serializing_if = "Option::is_none")]
412        name: Option<String>,
413    },
414    /// Plain JSON-mode output — the provider emits valid JSON without
415    /// schema enforcement.
416    JsonObject,
417}
418
419/// A text generation request.
420///
421/// `Default` is derived so call sites can mutate just the fields
422/// they care about: `GenerateRequest { prompt: "...".into(), ..Default::default() }`.
423/// The default `prompt = ""` is useless on its own — callers always
424/// override it — but the `..Default::default()` shorthand stops the
425/// per-call-site mechanical churn every time a new optional field
426/// lands (closes #109).
427#[derive(Debug, Clone, Default, Serialize, Deserialize)]
428pub struct GenerateRequest {
429    /// The prompt to complete (first user message for single-turn).
430    pub prompt: String,
431    /// Optional model override.
432    pub model: Option<String>,
433    /// Generation parameters.
434    #[serde(default)]
435    pub params: GenerateParams,
436    /// Optional memory context to prepend to the prompt.
437    /// When provided, this is injected as a system-level context block
438    /// before the user prompt, grounding the model's response.
439    #[serde(default)]
440    pub context: Option<String>,
441    /// Optional tool definitions for structured tool_use.
442    /// When provided, the model may return tool_calls instead of text.
443    /// Each tool is a JSON object with: name, description, parameters (JSON Schema).
444    #[serde(default, skip_serializing_if = "Option::is_none")]
445    pub tools: Option<Vec<serde_json::Value>>,
446    /// Optional images for vision models.
447    /// When provided with a single-turn prompt, these are included as image content blocks
448    /// in the user message. For multi-turn with `messages`, use `UserMultimodal` variants instead.
449    #[serde(default, skip_serializing_if = "Option::is_none")]
450    pub images: Option<Vec<ContentBlock>>,
451    /// Optional multi-turn conversation history.
452    /// When provided, the backend builds a proper multi-turn message array
453    /// instead of a single user message. The `prompt` field is ignored when
454    /// messages are present.
455    #[serde(default, skip_serializing_if = "Option::is_none")]
456    pub messages: Option<Vec<Message>>,
457    /// Enable prompt caching for Anthropic API.
458    /// When true, system prompt and tools are marked with cache_control breakpoints,
459    /// enabling cache reuse across parent/child agent calls sharing the same prefix.
460    #[serde(default)]
461    pub cache_control: bool,
462    /// Constrain output to JSON (optionally schema-validated). See
463    /// [`ResponseFormat`] for the per-provider mapping. Defaults to
464    /// `None` — free-form text.
465    #[serde(default, skip_serializing_if = "Option::is_none")]
466    pub response_format: Option<ResponseFormat>,
467    /// Caller-supplied routing intent. None preserves the existing
468    /// adaptive vs. pinned-model behavior. When `Some`, the adaptive
469    /// router uses the hint to filter candidates (hard `require`),
470    /// override task selection, and bias the score profile
471    /// (`prefer_local`). See [`crate::intent::IntentHint`].
472    #[serde(default, skip_serializing_if = "Option::is_none")]
473    pub intent: Option<crate::intent::IntentHint>,
474}
475
476/// Wrap a raw prompt in Qwen3 chat format if it's not already formatted.
477///
478/// Thinking behavior follows the caller-supplied [`ThinkingMode`]:
479/// * `Auto` — no directive injected; Qwen3's trained default (thinking
480///   on) applies.
481/// * `On` — the documented `/think` directive is appended to the system
482///   message on its own line, and the model is allowed to emit a full
483///   `<think>...</think>` block before its answer.
484/// * `Off` — `/no_think` is appended to the system message *and* an
485///   empty `<think>\n\n</think>` block is pre-filled after the assistant
486///   marker, matching upstream Qwen3's `enable_thinking=False` jinja
487///   template. The pre-filled closed tags structurally prevent the
488///   model from emitting a thinking block even if the directive is
489///   contradicted later in the prompt.
490///
491/// When context is provided it is injected into the system message to
492/// ground the model's response with memory. The directive always
493/// appears *after* the context blob so user-supplied memory cannot
494/// nudge the directive's parse position.
495pub fn apply_chat_template(prompt: &str, context: Option<&str>, thinking: ThinkingMode) -> String {
496    if prompt.contains("<|im_start|>") {
497        return prompt.to_string();
498    }
499    // Directive goes on its own line at the end of the system message
500    // (never concatenated onto prose) so Qwen3's chat template parser
501    // sees `/think`/`/no_think` as a standalone token, not as part of
502    // "assistant. /no_think".
503    let directive_line = match thinking.directive() {
504        Some(d) => format!("\n{d}"),
505        None => String::new(),
506    };
507    // For Off, pre-fill a closed empty thinking block after the
508    // assistant marker. This mirrors upstream Qwen3's jinja behavior
509    // when `enable_thinking=False` and is the hard-switch (structural)
510    // form of the mode, whereas `/no_think` alone is a soft directive.
511    let thinking_prefill = match thinking {
512        ThinkingMode::Off => "<think>\n\n</think>\n\n",
513        _ => "",
514    };
515    match context {
516        Some(ctx) => format!(
517            "<|im_start|>system\nYou are a helpful assistant. Use the following context to inform your response.\n\n{ctx}{directive_line}<|im_end|>\n\
518             <|im_start|>user\n{prompt}<|im_end|>\n\
519             <|im_start|>assistant\n{thinking_prefill}"
520        ),
521        None => format!(
522            "<|im_start|>system\nYou are a helpful assistant.{directive_line}<|im_end|>\n\
523             <|im_start|>user\n{prompt}<|im_end|>\n\
524             <|im_start|>assistant\n{thinking_prefill}"
525        ),
526    }
527}
528
529/// Strip Qwen3 `<think>...</think>` blocks from model output, honoring
530/// the caller's requested [`ThinkingMode`]:
531///
532/// * `On` — the caller explicitly asked for reasoning; return the raw
533///   text verbatim so `<think>...</think>` is visible.
534/// * `Auto` / `Off` — strip the thinking block and return only the
535///   post-thinking answer. If the output contains an opening `<think>`
536///   without a closing tag (truncation or stop before the model
537///   finished thinking) return an empty string rather than leaking a
538///   dangling tag to the caller.
539pub fn strip_thinking(text: &str, thinking: ThinkingMode) -> String {
540    if matches!(thinking, ThinkingMode::On) {
541        return text.to_string();
542    }
543    strip_thinking_block(text)
544}
545
546/// Remove a leading `<think>...</think>` block unconditionally.
547/// Returns "" if `<think>` opens but never closes (incomplete output).
548///
549/// When that "opened but never closed" branch fires, log a warn line
550/// — the caller is about to receive an empty string for what was
551/// almost certainly a budget-truncation. Surfaces issue #168's root
552/// cause without changing the return contract: callers (e.g. car-cli)
553/// that look at stderr can tell users to either bump
554/// `--max-tokens` or pass `--thinking off`. The decision lives in
555/// the strip helper because every text-completion path funnels
556/// through it; logging at the call sites would be a lot of
557/// duplication.
558fn strip_thinking_block(text: &str) -> String {
559    if let Some(end) = text.find("</think>") {
560        text[end + 8..].trim_start().to_string()
561    } else if text.contains("<think>") {
562        tracing::warn!(
563            target: "car_inference::tasks::generate",
564            raw_len = text.len(),
565            "model output opened <think> but never closed it — \
566             likely truncated by max_tokens; returning empty text. \
567             Increase max_tokens, or set thinking=off to suppress \
568             the reasoning phase."
569        );
570        String::new()
571    } else {
572        text.to_string()
573    }
574}
575
/// Re-retrieval hook for FLARE-style generation.
///
/// `generate_with_retrieval` invokes it with the text generated so far when
/// model confidence stays low. Returning `Some(ctx)` appends `ctx` to the
/// request context and restarts generation; `None` lets generation carry
/// on unchanged.
pub type RetrievalCallback = Box<dyn Send + Fn(&str) -> Option<String>>;
579
580#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
581/// Generate text from a prompt using the loaded model.
582///
583/// Returns `(text, time_to_first_token_ms)`. TTFT is measured from
584/// function entry through prefill to the moment the first generated
585/// token has been sampled — the user-visible "did anything happen yet"
586/// gate. `None` only when the prompt encodes to zero tokens (degenerate
587/// input).
588pub async fn generate(
589    backend: &mut CandleBackend,
590    req: GenerateRequest,
591) -> Result<(String, Option<u64>), InferenceError> {
592    let start = std::time::Instant::now();
593
594    // Reset KV cache so each generation starts fresh (prevents cross-call state bleed)
595    backend.clear_kv_cache();
596
597    let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
598    let tokens = backend.encode(&formatted)?;
599    let eos = backend.eos_token_id();
600    let eos_alt = backend.token_id("<|im_end|>");
601    let params = &req.params;
602
603    if tokens.is_empty() {
604        return Ok((String::new(), None));
605    }
606
607    // Truncate to model's max context length minus generation headroom.
608    // This prevents KV cache overflow on long prompts.
609    let max_ctx = backend.context_length().unwrap_or(32768);
610    let headroom = params.max_tokens.min(max_ctx / 4);
611    let max_prompt = max_ctx.saturating_sub(headroom);
612    let tokens = if tokens.len() > max_prompt {
613        eprintln!(
614            "[car-inference] truncating prompt from {} to {} tokens (context_length={})",
615            tokens.len(),
616            max_prompt,
617            max_ctx
618        );
619        tokens[tokens.len() - max_prompt..].to_vec()
620    } else {
621        tokens
622    };
623
624    let mut generated = Vec::new();
625
626    // Prefill: process all prompt tokens, sample first generated token from prefill logits
627    let logits = backend.forward(&tokens, 0)?;
628    let mut next_token = sample_token(&logits, params)?;
629    let ttft_ms = Some(start.elapsed().as_millis() as u64);
630
631    for _i in 0..params.max_tokens {
632        // Check EOS
633        if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
634        {
635            break;
636        }
637
638        generated.push(next_token);
639
640        // Check stop sequences
641        if !params.stop.is_empty() {
642            let text_so_far = backend.decode(&generated)?;
643            if params.stop.iter().any(|s| text_so_far.contains(s)) {
644                break;
645            }
646        }
647
648        // Generate next token
649        let pos = tokens.len() + generated.len() - 1;
650        let logits = backend.forward(&[next_token], pos)?;
651        next_token = sample_token(&logits, params)?;
652    }
653
654    let text = backend.decode(&generated)?;
655    Ok((strip_thinking(&text, params.thinking), ttft_ms))
656}
657
658#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
659/// Generate with FLARE-style confidence-triggered re-retrieval.
660///
661/// Monitors token logit confidence during generation. When a window of
662/// low-confidence tokens is detected, pauses, re-queries memory with the
663/// partial generation, and resumes with augmented context.
664pub async fn generate_with_retrieval(
665    backend: &mut CandleBackend,
666    mut req: GenerateRequest,
667    retrieval_cb: RetrievalCallback,
668) -> Result<String, InferenceError> {
669    // First pass: generate normally
670    backend.clear_kv_cache();
671    let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
672    let tokens = backend.encode(&formatted)?;
673    let eos = backend.eos_token_id();
674    let eos_alt = backend.token_id("<|im_end|>");
675    let params = req.params.clone();
676
677    if tokens.is_empty() {
678        return Ok(String::new());
679    }
680
681    let mut generated = Vec::new();
682    let mut low_confidence_count = 0u32;
683    let mut retrieval_attempts = 0u32;
684    let max_retrievals = 2;
685    let confidence_threshold = 0.4f32;
686    let low_confidence_window = 3u32;
687
688    let logits = backend.forward(&tokens, 0)?;
689    let mut next_token = sample_token(&logits, &params)?;
690
691    for _i in 0..params.max_tokens {
692        if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
693        {
694            break;
695        }
696
697        generated.push(next_token);
698
699        // Generate next token and check confidence
700        let pos = tokens.len() + generated.len() - 1;
701        let logits = backend.forward(&[next_token], pos)?;
702
703        // Check max logit probability for confidence
704        let logits_f32: Vec<f32> = logits
705            .squeeze(0)
706            .unwrap_or(logits.clone())
707            .to_dtype(candle_core::DType::F32)
708            .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
709            .to_vec1()
710            .unwrap_or_default();
711
712        if !logits_f32.is_empty() {
713            // Compute softmax max probability
714            let max_logit = logits_f32.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
715            let exp_sum: f32 = logits_f32.iter().map(|&v| (v - max_logit).exp()).sum();
716            let max_prob = 1.0 / exp_sum; // probability of the top token
717
718            if max_prob < confidence_threshold {
719                low_confidence_count += 1;
720            } else {
721                low_confidence_count = 0;
722            }
723
724            // Trigger re-retrieval after sustained low confidence
725            if low_confidence_count >= low_confidence_window && retrieval_attempts < max_retrievals
726            {
727                retrieval_attempts += 1;
728                low_confidence_count = 0;
729
730                // Use partial generation as re-retrieval query
731                let partial = backend.decode(&generated)?;
732                if let Some(new_context) = retrieval_cb(&partial) {
733                    // Restart generation with augmented context
734                    let combined_context = match req.context.take() {
735                        Some(old) => format!("{}\n\n{}", old, new_context),
736                        None => new_context,
737                    };
738                    req.context = Some(combined_context);
739
740                    // Re-encode and restart
741                    backend.clear_kv_cache();
742                    let new_formatted = apply_chat_template(
743                        &req.prompt,
744                        req.context.as_deref(),
745                        req.params.thinking,
746                    );
747                    let new_tokens = backend.encode(&new_formatted)?;
748                    generated.clear();
749
750                    let logits = backend.forward(&new_tokens, 0)?;
751                    next_token = sample_token(&logits, &params)?;
752                    continue;
753                }
754            }
755        }
756
757        next_token = sample_token(&logits, &params)?;
758    }
759
760    let text = backend.decode(&generated)?;
761    Ok(strip_thinking(&text, params.thinking))
762}
763
764#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
765/// Sample a token, suppressing specific token IDs (set to -inf before sampling).
766pub fn sample_token_suppress(
767    logits: &Tensor,
768    params: &GenerateParams,
769    suppress: &[u32],
770) -> Result<u32, InferenceError> {
771    if suppress.is_empty() {
772        return sample_token(logits, params);
773    }
774    // Clone logits and set suppressed tokens to -inf
775    let mut logits_vec: Vec<f32> = logits
776        .squeeze(0)
777        .unwrap_or(logits.clone())
778        .to_dtype(candle_core::DType::F32)
779        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
780        .to_vec1()
781        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
782    // Handle 2D logits (take last row)
783    let dims = logits.dims();
784    if dims.len() == 2 {
785        let vocab = dims[dims.len() - 1];
786        let start = logits_vec.len() - vocab;
787        logits_vec = logits_vec[start..].to_vec();
788    }
789    for &id in suppress {
790        if (id as usize) < logits_vec.len() {
791            logits_vec[id as usize] = f32::NEG_INFINITY;
792        }
793    }
794    let modified = Tensor::from_vec(
795        logits_vec,
796        logits.squeeze(0).unwrap_or(logits.clone()).shape(),
797        logits.device(),
798    )
799    .map_err(|e| InferenceError::InferenceFailed(format!("from_vec: {e}")))?;
800    sample_token(&modified, params)
801}
802
803#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
804/// Sample a token from logits using temperature + top-p + top-k.
805pub fn sample_token(logits: &Tensor, params: &GenerateParams) -> Result<u32, InferenceError> {
806    let logits = logits
807        .squeeze(0)
808        .map_err(|e| InferenceError::InferenceFailed(format!("squeeze: {e}")))?;
809    let logits = logits
810        .to_dtype(candle_core::DType::F32)
811        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?;
812
813    // Get last position's logits
814    let dim = logits.dims();
815    let logits = if dim.len() == 2 {
816        logits
817            .get(dim[0] - 1)
818            .map_err(|e| InferenceError::InferenceFailed(format!("get last: {e}")))?
819    } else {
820        logits
821    };
822
823    // Greedy decoding
824    if params.temperature <= 0.0 {
825        let token = logits
826            .argmax(0)
827            .map_err(|e| InferenceError::InferenceFailed(format!("argmax: {e}")))?
828            .to_scalar::<u32>()
829            .map_err(|e| InferenceError::InferenceFailed(format!("scalar: {e}")))?;
830        return Ok(token);
831    }
832
833    // Temperature scaling
834    let logits = (&logits / params.temperature)
835        .map_err(|e| InferenceError::InferenceFailed(format!("temp scale: {e}")))?;
836
837    let mut logits_vec: Vec<f32> = logits
838        .to_vec1()
839        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
840
841    // Top-k filtering
842    if params.top_k > 0 && params.top_k < logits_vec.len() {
843        let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
844        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
845        let threshold = indexed[params.top_k].1;
846        for v in &mut logits_vec {
847            if *v < threshold {
848                *v = f32::NEG_INFINITY;
849            }
850        }
851    }
852
853    // Softmax
854    let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
855    let exp: Vec<f32> = logits_vec.iter().map(|&v| (v - max_logit).exp()).collect();
856    let sum: f32 = exp.iter().sum();
857    let mut probs: Vec<f32> = exp.iter().map(|&v| v / sum).collect();
858
859    // Top-p (nucleus) filtering
860    if params.top_p < 1.0 {
861        let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
862        sorted_indices.sort_by(|&a, &b| {
863            probs[b]
864                .partial_cmp(&probs[a])
865                .unwrap_or(std::cmp::Ordering::Equal)
866        });
867
868        let mut cumsum = 0.0f32;
869        let mut cutoff_idx = sorted_indices.len();
870        for (i, &idx) in sorted_indices.iter().enumerate() {
871            cumsum += probs[idx];
872            if cumsum > params.top_p as f32 {
873                cutoff_idx = i + 1;
874                break;
875            }
876        }
877
878        let keep: std::collections::HashSet<usize> =
879            sorted_indices[..cutoff_idx].iter().copied().collect();
880        for (i, p) in probs.iter_mut().enumerate() {
881            if !keep.contains(&i) {
882                *p = 0.0;
883            }
884        }
885
886        // Renormalize
887        let sum: f32 = probs.iter().sum();
888        if sum > 0.0 {
889            for p in &mut probs {
890                *p /= sum;
891            }
892        }
893    }
894
895    // Categorical sample
896    let r: f32 = rand_f32();
897    let mut cumsum = 0.0f32;
898    for (i, &p) in probs.iter().enumerate() {
899        cumsum += p;
900        if cumsum >= r {
901            return Ok(i as u32);
902        }
903    }
904
905    // Fallback: return highest prob token
906    Ok(probs
907        .iter()
908        .enumerate()
909        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
910        .map(|(i, _)| i as u32)
911        .unwrap_or(0))
912}
913
914#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
915/// Random float in [0, 1) using the rand crate.
916fn rand_f32() -> f32 {
917    rand::random::<f32>()
918}
919
#[cfg(test)]
mod thinking_tests {
    // Covers chat-template directive injection and `<think>` stripping for each
    // `ThinkingMode`.
    use super::*;

    #[test]
    fn auto_injects_no_directive_and_no_prefill() {
        let rendered = apply_chat_template("hi", None, ThinkingMode::Auto);
        // Auto leaves the choice to the model: no directive, no pre-filled block.
        for marker in ["/no_think", "/think", "<think>"] {
            assert!(!rendered.contains(marker));
        }
        assert!(rendered.contains("<|im_start|>user\nhi<|im_end|>"));
    }

    #[test]
    fn off_injects_no_think_on_own_line_and_prefills_empty_think() {
        let rendered = apply_chat_template("hi", None, ThinkingMode::Off);
        // The directive must sit on its own line, never glued onto the prompt prose.
        assert!(rendered.contains("\n/no_think<|im_end|>"));
        assert!(!rendered.contains(" /no_think"));
        // An already-closed, empty thinking block follows the assistant marker,
        // matching the upstream jinja behaviour for enable_thinking=False.
        assert!(rendered.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
    }

    #[test]
    fn on_injects_think_and_no_prefill() {
        let rendered = apply_chat_template("hi", None, ThinkingMode::On);
        assert!(rendered.contains("\n/think<|im_end|>"));
        assert!(!rendered.contains("/no_think") && !rendered.contains("<think>"));
    }

    #[test]
    fn pre_formatted_prompt_is_untouched() {
        let already_templated =
            "<|im_start|>system\ncustom<|im_end|>\n<|im_start|>user\nhi<|im_end|>";
        assert_eq!(
            apply_chat_template(already_templated, None, ThinkingMode::Off),
            already_templated
        );
    }

    #[test]
    fn directive_appears_after_context_not_before() {
        let rendered = apply_chat_template("q?", Some("some memory"), ThinkingMode::Off);
        let position_of = |needle: &str| rendered.find(needle).unwrap();
        let context_at = position_of("some memory");
        let directive_at = position_of("/no_think");
        assert!(
            context_at < directive_at,
            "directive must appear after context so user memory cannot nudge the parse"
        );
    }

    #[test]
    fn default_params_is_auto() {
        let defaults = GenerateParams::default();
        assert_eq!(defaults.thinking, ThinkingMode::Auto);
    }

    #[test]
    fn thinking_mode_serde_snake_case() {
        assert_eq!(serde_json::to_string(&ThinkingMode::Off).unwrap(), "\"off\"");
        assert_eq!(
            serde_json::from_str::<ThinkingMode>("\"on\"").unwrap(),
            ThinkingMode::On
        );
    }

    #[test]
    fn strip_preserves_thinking_when_on() {
        let raw = "<think>reasoning here</think>the answer";
        assert_eq!(
            strip_thinking(raw, ThinkingMode::On),
            raw,
            "On mode must return raw text with <think> visible"
        );
    }

    #[test]
    fn strip_removes_thinking_when_auto_or_off() {
        let raw = "<think>reasoning</think>the answer";
        for mode in [ThinkingMode::Auto, ThinkingMode::Off] {
            assert_eq!(strip_thinking(raw, mode), "the answer");
        }
    }

    #[test]
    fn strip_returns_empty_on_unterminated_think() {
        // Generation stopped inside the reasoning block; the dangling tag must not leak.
        let truncated = "<think>mid-reasoning, never closed";
        for mode in [ThinkingMode::Auto, ThinkingMode::Off] {
            assert_eq!(strip_thinking(truncated, mode), "");
        }
        // With thinking On the caller asked for raw output, so it comes back unchanged.
        assert_eq!(strip_thinking(truncated, ThinkingMode::On), truncated);
    }

    #[test]
    fn strip_is_noop_when_no_think_tag() {
        let plain = "just a plain answer";
        for mode in [ThinkingMode::Auto, ThinkingMode::Off, ThinkingMode::On] {
            assert_eq!(strip_thinking(plain, mode), plain);
        }
    }
}
1018
#[cfg(test)]
mod workload_tests {
    // Sanity checks on the (quality, latency, cost) weight profiles returned by
    // `RoutingWorkload::weights`.
    use super::*;

    // Every variant, so the normalisation check covers the whole enum.
    const ALL_WORKLOADS: [RoutingWorkload; 5] = [
        RoutingWorkload::Interactive,
        RoutingWorkload::Batch,
        RoutingWorkload::Background,
        RoutingWorkload::LocalPreferred,
        RoutingWorkload::Fastest,
    ];

    #[test]
    fn all_workload_weights_sum_to_one() {
        for w in ALL_WORKLOADS {
            let (quality, latency, cost) = w.weights();
            let sum = quality + latency + cost;
            let drift = (sum - 1.0).abs();
            assert!(drift < 1e-6, "weights for {w:?} sum to {sum}, expected 1.0");
        }
    }

    #[test]
    fn fastest_weights_dominate_on_latency() {
        let (q, l, c) = RoutingWorkload::Fastest.weights();
        // Chasing latency is the entire purpose of this profile, so its latency
        // weight must lead the others by a wide margin.
        assert!(q < l && c < l);
        assert!(l >= 0.7, "latency weight too small: {l}");
    }

    #[test]
    fn fastest_is_latency_sensitive() {
        let workload = RoutingWorkload::Fastest;
        assert!(workload.is_latency_sensitive());
    }
}