Skip to main content

car_inference/tasks/
generate.rs

1//! Text generation with sampling.
2
3#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
4use candle_core::Tensor;
5use serde::{Deserialize, Serialize};
6
7#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
8use crate::backend::CandleBackend;
9use crate::InferenceError;
10
/// How latency-sensitive this generation request is.
///
/// Each variant selects a `(quality, latency, cost)` weight profile via
/// [`RoutingWorkload::weights`] and a bonus for local models via
/// [`RoutingWorkload::local_bonus`]. Serialized in `snake_case`
/// (e.g. `"local_preferred"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RoutingWorkload {
    /// User-facing, interactive request where latency matters.
    /// This is the default variant.
    #[default]
    Interactive,
    /// Batch job where latency matters somewhat, but quality/cost matter more.
    Batch,
    /// Background or offline work where latency is a weak concern.
    Background,
    /// Caller explicitly prefers on-device models. Distinct from
    /// `Background` (which is "this is a background job, latency
    /// barely matters"). The caller may be doing latency-sensitive
    /// interactive work but wants the privacy / cost / offline
    /// properties of local inference. Carries a larger `local_bonus`
    /// than `Background` (0.20 vs 0.15) and a weight profile that sits
    /// between `Interactive` and `Batch` — the caller chose local for
    /// a reason, not because the work is throwaway.
    LocalPreferred,
    /// Aggressive latency bias for time-to-first-token. Voice turns
    /// (specifically the fast track in the two-track sidecar pattern)
    /// pick this. Quality and cost are heavily downweighted; the
    /// router prefers whichever model produces a first token soonest.
    /// On macOS 26+ this typically resolves to `apple/foundation:default`
    /// via the Foundation Models system-LLM bonus. Reached via the
    /// `IntentHint::prefer_fast` flag (or `RoutingWorkload::Fastest`
    /// directly when callers know they want it).
    Fastest,
}
41
42impl RoutingWorkload {
43    pub fn is_latency_sensitive(self) -> bool {
44        matches!(
45            self,
46            RoutingWorkload::Interactive
47                | RoutingWorkload::LocalPreferred
48                | RoutingWorkload::Fastest,
49        )
50    }
51
52    pub fn weights(self) -> (f64, f64, f64) {
53        // Tuple is `(quality, latency, cost)` — destructured at the
54        // single use site `adaptive_router.rs:892`.
55        match self {
56            RoutingWorkload::Interactive => (0.45, 0.40, 0.15),
57            RoutingWorkload::Batch => (0.60, 0.15, 0.25),
58            RoutingWorkload::Background => (0.65, 0.05, 0.30),
59            // Quality-aware (closer to Interactive) but tolerant of
60            // some latency hit since the caller chose local. Cost
61            // weight matches Batch.
62            RoutingWorkload::LocalPreferred => (0.55, 0.20, 0.25),
63            // Voice fast track: latency is everything. Quality and
64            // cost are deliberately near-floor — first audio in
65            // <500ms beats any quality gain that takes another
66            // round-trip. Sums to 1.0 like every other variant.
67            RoutingWorkload::Fastest => (0.10, 0.85, 0.05),
68        }
69    }
70
71    pub fn local_bonus(self) -> f64 {
72        match self {
73            RoutingWorkload::Interactive => 0.0,
74            RoutingWorkload::Batch => 0.08,
75            RoutingWorkload::Background => 0.15,
76            // Stronger push than Background — "prefer local" should
77            // win ties decisively, otherwise the hint is ineffective.
78            RoutingWorkload::LocalPreferred => 0.20,
79            // Local inference avoids network round-trips entirely —
80            // the single biggest latency win available. On macOS the
81            // Foundation Models `system_llm_bonus` stacks on top of
82            // this for `apple/foundation:default`. Match
83            // LocalPreferred's bonus so cloud-streamed fast paths
84            // (e.g. gpt-4o-mini) can still win when locally there's
85            // no model loaded; the weight profile already strongly
86            // favours latency.
87            RoutingWorkload::Fastest => 0.20,
88        }
89    }
90}
91
/// Qwen3 hybrid thinking control. Qwen3 models were trained with both a
/// "thinking" (chain-of-thought inside `<think>...</think>`) and a
/// non-thinking mode. Upstream defaults thinking ON; `/no_think` and
/// `/think` are the documented per-turn overrides in the chat template.
///
/// Scope: applies to the *single-turn* local Qwen3 path driven by
/// [`apply_chat_template`]. The multi-turn `messages: Vec<Message>`
/// field on [`GenerateRequest`] is consumed by remote protocol
/// handlers (OpenAI/Anthropic/Google) which pass through user-supplied
/// system messages verbatim; this flag is not injected there. If you
/// need Qwen3 thinking control over a remote API, include `/think` or
/// `/no_think` explicitly in your own system message.
///
/// Serialized in `snake_case` (`"auto"`, `"on"`, `"off"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingMode {
    /// Let the model decide. No explicit `/think` or `/no_think` directive
    /// is injected into the system prompt, so Qwen3's trained default
    /// (thinking ON) applies. `<think>...</think>` output is stripped
    /// from the returned text.
    #[default]
    Auto,
    /// Inject `/think` into the system prompt to explicitly request the
    /// thinking phase. Useful when callers want to force reasoning even
    /// on short prompts the model would normally answer directly.
    /// In this mode [`strip_thinking`] returns the raw text, so the
    /// `<think>...</think>` block stays visible to the caller.
    On,
    /// Inject `/no_think` into the system prompt to suppress the
    /// thinking phase for faster, more direct responses. This was the
    /// prior hard-coded behavior; callers now opt into it explicitly.
    /// [`apply_chat_template`] additionally pre-fills an empty
    /// `<think>` block after the assistant marker in this mode.
    Off,
}
122
123impl ThinkingMode {
124    /// Return the directive marker to append to the system prompt, or
125    /// `None` when `Auto` (don't inject anything — trust model default).
126    pub fn directive(self) -> Option<&'static str> {
127        match self {
128            ThinkingMode::Auto => None,
129            ThinkingMode::On => Some("/think"),
130            ThinkingMode::Off => Some("/no_think"),
131        }
132    }
133}
134
/// Parameters controlling generation behavior.
///
/// Every field has a serde default, so an empty JSON object
/// deserializes to the same values as [`GenerateParams::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateParams {
    /// Sampling temperature (0.0 = greedy, 1.0 = full distribution).
    /// Defaults to 0.7.
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Top-p (nucleus) sampling threshold. Defaults to 0.9.
    #[serde(default = "default_top_p")]
    pub top_p: f64,
    /// Top-k sampling (0 = disabled). Defaults to 0.
    #[serde(default)]
    pub top_k: usize,
    /// Maximum tokens to generate. Defaults to 4096.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Stop sequences — generation halts when any is produced.
    /// Defaults to an empty list.
    #[serde(default)]
    pub stop: Vec<String>,
    /// Extended thinking budget (tokens). When > 0, enables the model's
    /// internal reasoning/planning phase before responding. Only supported
    /// by models with the ExtendedThinking capability (e.g., Claude).
    #[serde(default)]
    pub budget_tokens: usize,
    /// Routing workload class. Interactive requests bias toward lower latency,
    /// while batch/background work can tolerate slower high-quality local models.
    /// Defaults to [`RoutingWorkload::Interactive`].
    #[serde(default)]
    pub workload: RoutingWorkload,
    /// Tool choice mode: "auto" (default when tools present), "required" (must use a tool),
    /// "none" (disable tools). When "required", the model must respond with a tool call,
    /// eliminating mixed text+JSON responses.
    #[serde(default)]
    pub tool_choice: Option<String>,
    /// OpenAI-compatible parallel tool call control.
    #[serde(default)]
    pub parallel_tool_calls: Option<bool>,
    /// Qwen3 hybrid thinking mode control. `Auto` (default) leaves the
    /// model at its trained default (thinking on). `On`/`Off` inject
    /// the documented `/think` or `/no_think` directive into the chat
    /// template. Ignored by non-Qwen3 models.
    #[serde(default)]
    pub thinking: ThinkingMode,
}
177
/// Serde default for `GenerateParams::temperature`.
fn default_temperature() -> f64 {
    const DEFAULT_TEMPERATURE: f64 = 0.7;
    DEFAULT_TEMPERATURE
}
/// Serde default for `GenerateParams::top_p`.
fn default_top_p() -> f64 {
    const DEFAULT_TOP_P: f64 = 0.9;
    DEFAULT_TOP_P
}
/// Serde default for `GenerateParams::max_tokens`.
fn default_max_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 4096;
    DEFAULT_MAX_TOKENS
}
187
188impl Default for GenerateParams {
189    fn default() -> Self {
190        Self {
191            temperature: default_temperature(),
192            top_p: default_top_p(),
193            top_k: 0,
194            max_tokens: default_max_tokens(),
195            stop: Vec::new(),
196            budget_tokens: 0,
197            workload: RoutingWorkload::Interactive,
198            tool_choice: None,
199            parallel_tool_calls: None,
200            thinking: ThinkingMode::default(),
201        }
202    }
203}
204
/// A tool call returned by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned tool call ID (e.g. OpenAI `call_abc123`, Anthropic `toolu_abc123`).
    /// When present, protocol handlers use this for round-trip correlation instead of
    /// synthesizing positional IDs like `call_0`.
    /// Omitted from serialized output when `None`; a missing field
    /// deserializes to `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Tool/function name.
    pub name: String,
    /// Arguments as key-value pairs (argument name → arbitrary JSON value).
    pub arguments: std::collections::HashMap<String, serde_json::Value>,
}
218
/// A content block in a multimodal message.
///
/// The image variants (`ImageBase64`, `ImageUrl`) are fully wired on
/// the native Qwen2.5-VL backend. The video variants
/// (`VideoPath`, `VideoUrl`, `VideoBase64`) are defined on the public
/// request surface so higher-level tooling can express Qwen2.5-VL
/// video-understanding payloads, but the native backend returns
/// [`crate::InferenceError::UnsupportedMode`] for them until the
/// video-tokenization path lands. Remote multimodal providers
/// (Anthropic, Google Vertex) accept them through the protocol
/// handlers today.
///
/// The audio variants (`AudioPath`, `AudioUrl`, `AudioBase64`) carry
/// audio payloads for audio-understanding models; see
/// [`ContentBlock::is_audio`].
///
/// Serialized as an internally tagged object: the variant is stored in
/// a `type` field using its `snake_case` name (e.g. `"image_base64"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text content.
    Text { text: String },
    /// Base64-encoded image.
    ImageBase64 {
        /// Base64-encoded image data.
        data: String,
        /// MIME type (e.g., "image/png", "image/jpeg").
        media_type: String,
    },
    /// Image from URL.
    ImageUrl {
        /// URL of the image.
        url: String,
        /// Detail level for image processing ("auto", "low", "high").
        /// Defaults to "auto" when absent from the input.
        #[serde(default = "default_detail")]
        detail: String,
    },
    /// Video loaded from a local filesystem path. Qwen2.5-VL samples
    /// the clip at `fps` frames/sec (default: backend-chosen) and
    /// caps at `max_frames` to respect context budgets.
    VideoPath {
        /// Filesystem path of the video file.
        path: String,
        /// Optional sampling rate in frames per second.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        /// Optional upper bound on the number of sampled frames.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Video accessible over HTTP(S). Semantics as [`ContentBlock::VideoPath`].
    VideoUrl {
        /// URL of the video.
        url: String,
        /// Optional sampling rate in frames per second.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        /// Optional upper bound on the number of sampled frames.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Base64-encoded video bytes. Prefer `VideoPath` when possible;
    /// inline base64 is expensive to round-trip.
    VideoBase64 {
        /// Base64-encoded video data.
        data: String,
        /// MIME type of the encoded video.
        media_type: String,
        /// Optional sampling rate in frames per second.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        /// Optional upper bound on the number of sampled frames.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Audio loaded from a local filesystem path. Used for
    /// audio-understanding models (Gemma 4 small variants, Gemini).
    AudioPath {
        /// Filesystem path of the audio file.
        path: String,
        /// Optional explicit sample-rate hint. Most backends will
        /// resample internally; this is a best-effort declaration.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
    /// Audio accessible over HTTP(S).
    AudioUrl {
        /// URL of the audio.
        url: String,
        /// Optional sample-rate hint, as on [`ContentBlock::AudioPath`].
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
    /// Base64-encoded audio bytes.
    AudioBase64 {
        /// Base64-encoded audio data.
        data: String,
        /// MIME type of the encoded audio.
        media_type: String,
        /// Optional sample-rate hint, as on [`ContentBlock::AudioPath`].
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
}
301
302impl ContentBlock {
303    /// Return true if this block carries video data (any encoding).
304    /// Used by backends that need to refuse video inputs until the
305    /// tokenization path is wired.
306    pub fn is_video(&self) -> bool {
307        matches!(
308            self,
309            ContentBlock::VideoPath { .. }
310                | ContentBlock::VideoUrl { .. }
311                | ContentBlock::VideoBase64 { .. }
312        )
313    }
314
315    /// Return true if this block carries audio data (any encoding).
316    /// Used by backends that need to refuse audio inputs until the
317    /// tokenization path is wired. Gemma 4 small variants and Gemini
318    /// accept audio; everything else in CAR rejects with
319    /// `UnsupportedMode`.
320    pub fn is_audio(&self) -> bool {
321        matches!(
322            self,
323            ContentBlock::AudioPath { .. }
324                | ContentBlock::AudioUrl { .. }
325                | ContentBlock::AudioBase64 { .. }
326        )
327    }
328}
329
/// Serde default for the `detail` field of `ContentBlock::ImageUrl`.
fn default_detail() -> String {
    String::from("auto")
}
333
/// A message in a multi-turn conversation.
///
/// The `System` variant exists so callers can express a first-class
/// system prompt inside `messages: Vec<Message>` without threading it
/// through the legacy `context: Option<String>` field on
/// [`GenerateRequest`]. Protocol handlers and local chat templates
/// that have a native system-role slot (OpenAI, Anthropic, Gemini,
/// Gemma 4, Qwen) emit it in the right place; ones that don't can
/// fold it into the first user turn.
///
/// Serialized as an internally tagged object: the variant is stored in
/// a `role` field using its `snake_case` name (e.g. `"tool_result"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum Message {
    /// A system prompt. Appears once, at the start of the conversation.
    System { content: String },
    /// A user message (text only).
    User { content: String },
    /// A user message with multimodal content (text + images + video + audio).
    UserMultimodal { content: Vec<ContentBlock> },
    /// An assistant response, possibly with tool calls.
    Assistant {
        /// Assistant text; defaults to the empty string when absent.
        #[serde(default)]
        content: String,
        /// Tool calls issued in this turn; defaults to an empty list.
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
    },
    /// The result of executing a tool call.
    ToolResult {
        /// Identifier of the tool call this result answers —
        /// presumably the matching `ToolCall::id`; verify against the
        /// protocol handlers.
        tool_use_id: String,
        /// Tool output as text.
        content: String,
    },
    /// Provider-specific output items that need to round-trip
    /// verbatim across turns. The OpenAI Responses API returns
    /// reasoning blobs, encrypted_content, web-search results, etc.
    /// as opaque structured items; the next request must include
    /// them in the same form to preserve provider-side state.
    ///
    /// `protocol` identifies the provider format that produced the
    /// items (currently `"openai-responses"`). Builder paths that
    /// don't recognize the protocol drop the variant — there is no
    /// portable rendering across providers.
    ProviderOutputItems {
        /// Provider format identifier for `items`.
        protocol: String,
        /// Opaque provider items, kept as raw JSON.
        items: Vec<serde_json::Value>,
    },
}
379
/// Constraint on the model's response shape. Distinct from `tools` —
/// tools are a side-channel for action invocation; `response_format`
/// constrains the *primary* text output to be parseable JSON, optionally
/// against a caller-supplied schema.
///
/// Provider mapping (handled in `protocol.rs`):
/// * **OpenAI / Azure / OpenAI-compatible**: `response_format: {type: "json_schema", json_schema: {schema, strict, name}}`
///   (or `{type: "json_object"}` for the looser variant). Strict mode rejects
///   any deviation from the schema.
/// * **Google (Gemini)**: `response_mime_type: "application/json"` plus
///   optional `response_schema`.
/// * **Anthropic**: no native field as of early 2026 — the schema is
///   logged at `warn` level and dropped. Callers needing schema-validated
///   output on Claude should fall back to the `tools` + `tool_choice="required"`
///   coercion idiom (which is what worked before this field landed).
///
/// `JsonObject` is the looser variant — tells the provider "emit valid
/// JSON, no schema check required". Use when the schema is too dynamic
/// to spell out but the parse contract still matters.
///
/// Serialized as an internally tagged object with a `type` field of
/// `"json_schema"` or `"json_object"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// JSON output validated against the provided schema. `strict: true`
    /// asks the provider to reject any deviation; `false` makes the
    /// schema a best-effort hint. `name` is OpenAI-specific (the
    /// `json_schema.name` field, max 64 chars, alphanumerics + `-_`);
    /// other providers ignore it.
    JsonSchema {
        /// The JSON Schema the output must satisfy.
        schema: serde_json::Value,
        /// Strict enforcement flag; defaults to `false` when absent.
        #[serde(default)]
        strict: bool,
        /// Optional schema name; omitted from serialized output when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Plain JSON-mode output — the provider emits valid JSON without
    /// schema enforcement.
    JsonObject,
}
418
/// A text generation request.
///
/// `Default` is derived so call sites can mutate just the fields
/// they care about: `GenerateRequest { prompt: "...".into(), ..Default::default() }`.
/// The default `prompt = ""` is useless on its own — callers always
/// override it — but the `..Default::default()` shorthand stops the
/// per-call-site mechanical churn every time a new optional field
/// lands (closes #109).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerateRequest {
    /// The prompt to complete (first user message for single-turn).
    /// Has no serde default, so it must be present when deserializing.
    pub prompt: String,
    /// Optional model override.
    pub model: Option<String>,
    /// Generation parameters. Falls back to `GenerateParams::default()`
    /// when absent from the input.
    #[serde(default)]
    pub params: GenerateParams,
    /// Optional memory context to prepend to the prompt.
    /// When provided, this is injected as a system-level context block
    /// before the user prompt, grounding the model's response.
    #[serde(default)]
    pub context: Option<String>,
    /// Optional tool definitions for structured tool_use.
    /// When provided, the model may return tool_calls instead of text.
    /// Each tool is a JSON object with: name, description, parameters (JSON Schema).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<serde_json::Value>>,
    /// Optional images for vision models.
    /// When provided with a single-turn prompt, these are included as image content blocks
    /// in the user message. For multi-turn with `messages`, use `UserMultimodal` variants instead.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<ContentBlock>>,
    /// Optional multi-turn conversation history.
    /// When provided, the backend builds a proper multi-turn message array
    /// instead of a single user message. The `prompt` field is ignored when
    /// messages are present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub messages: Option<Vec<Message>>,
    /// Enable prompt caching for Anthropic API.
    /// When true, system prompt and tools are marked with cache_control breakpoints,
    /// enabling cache reuse across parent/child agent calls sharing the same prefix.
    /// Defaults to `false`.
    #[serde(default)]
    pub cache_control: bool,
    /// Constrain output to JSON (optionally schema-validated). See
    /// [`ResponseFormat`] for the per-provider mapping. Defaults to
    /// `None` — free-form text.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Caller-supplied routing intent. None preserves the existing
    /// adaptive vs. pinned-model behavior. When `Some`, the adaptive
    /// router uses the hint to filter candidates (hard `require`),
    /// override task selection, and bias the score profile
    /// (`prefer_local`). See [`crate::intent::IntentHint`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub intent: Option<crate::intent::IntentHint>,
}
475
476/// Wrap a raw prompt in Qwen3 chat format if it's not already formatted.
477///
478/// Thinking behavior follows the caller-supplied [`ThinkingMode`]:
479/// * `Auto` — no directive injected; Qwen3's trained default (thinking
480///   on) applies.
481/// * `On` — the documented `/think` directive is appended to the system
482///   message on its own line, and the model is allowed to emit a full
483///   `<think>...</think>` block before its answer.
484/// * `Off` — `/no_think` is appended to the system message *and* an
485///   empty `<think>\n\n</think>` block is pre-filled after the assistant
486///   marker, matching upstream Qwen3's `enable_thinking=False` jinja
487///   template. The pre-filled closed tags structurally prevent the
488///   model from emitting a thinking block even if the directive is
489///   contradicted later in the prompt.
490///
491/// When context is provided it is injected into the system message to
492/// ground the model's response with memory. The directive always
493/// appears *after* the context blob so user-supplied memory cannot
494/// nudge the directive's parse position.
495pub fn apply_chat_template(
496    prompt: &str,
497    context: Option<&str>,
498    thinking: ThinkingMode,
499) -> String {
500    if prompt.contains("<|im_start|>") {
501        return prompt.to_string();
502    }
503    // Directive goes on its own line at the end of the system message
504    // (never concatenated onto prose) so Qwen3's chat template parser
505    // sees `/think`/`/no_think` as a standalone token, not as part of
506    // "assistant. /no_think".
507    let directive_line = match thinking.directive() {
508        Some(d) => format!("\n{d}"),
509        None => String::new(),
510    };
511    // For Off, pre-fill a closed empty thinking block after the
512    // assistant marker. This mirrors upstream Qwen3's jinja behavior
513    // when `enable_thinking=False` and is the hard-switch (structural)
514    // form of the mode, whereas `/no_think` alone is a soft directive.
515    let thinking_prefill = match thinking {
516        ThinkingMode::Off => "<think>\n\n</think>\n\n",
517        _ => "",
518    };
519    match context {
520        Some(ctx) => format!(
521            "<|im_start|>system\nYou are a helpful assistant. Use the following context to inform your response.\n\n{ctx}{directive_line}<|im_end|>\n\
522             <|im_start|>user\n{prompt}<|im_end|>\n\
523             <|im_start|>assistant\n{thinking_prefill}"
524        ),
525        None => format!(
526            "<|im_start|>system\nYou are a helpful assistant.{directive_line}<|im_end|>\n\
527             <|im_start|>user\n{prompt}<|im_end|>\n\
528             <|im_start|>assistant\n{thinking_prefill}"
529        ),
530    }
531}
532
533/// Strip Qwen3 `<think>...</think>` blocks from model output, honoring
534/// the caller's requested [`ThinkingMode`]:
535///
536/// * `On` — the caller explicitly asked for reasoning; return the raw
537///   text verbatim so `<think>...</think>` is visible.
538/// * `Auto` / `Off` — strip the thinking block and return only the
539///   post-thinking answer. If the output contains an opening `<think>`
540///   without a closing tag (truncation or stop before the model
541///   finished thinking) return an empty string rather than leaking a
542///   dangling tag to the caller.
543pub fn strip_thinking(text: &str, thinking: ThinkingMode) -> String {
544    if matches!(thinking, ThinkingMode::On) {
545        return text.to_string();
546    }
547    strip_thinking_block(text)
548}
549
550/// Remove a leading `<think>...</think>` block unconditionally.
551/// Returns "" if `<think>` opens but never closes (incomplete output).
552///
553/// When that "opened but never closed" branch fires, log a warn line
554/// — the caller is about to receive an empty string for what was
555/// almost certainly a budget-truncation. Surfaces issue #168's root
556/// cause without changing the return contract: callers (e.g. car-cli)
557/// that look at stderr can tell users to either bump
558/// `--max-tokens` or pass `--thinking off`. The decision lives in
559/// the strip helper because every text-completion path funnels
560/// through it; logging at the call sites would be a lot of
561/// duplication.
562fn strip_thinking_block(text: &str) -> String {
563    if let Some(end) = text.find("</think>") {
564        text[end + 8..].trim_start().to_string()
565    } else if text.contains("<think>") {
566        tracing::warn!(
567            target: "car_inference::tasks::generate",
568            raw_len = text.len(),
569            "model output opened <think> but never closed it — \
570             likely truncated by max_tokens; returning empty text. \
571             Increase max_tokens, or set thinking=off to suppress \
572             the reasoning phase."
573        );
574        String::new()
575    } else {
576        text.to_string()
577    }
578}
579
/// Callback for FLARE-style re-retrieval during generation.
/// Called with partial generation text, returns additional context or None.
///
/// Boxed so callers can pass any closure; the `Send` bound lets the
/// callback be moved to another thread.
pub type RetrievalCallback = Box<dyn Fn(&str) -> Option<String> + Send>;
583
584#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
585/// Generate text from a prompt using the loaded model.
586///
587/// Returns `(text, time_to_first_token_ms)`. TTFT is measured from
588/// function entry through prefill to the moment the first generated
589/// token has been sampled — the user-visible "did anything happen yet"
590/// gate. `None` only when the prompt encodes to zero tokens (degenerate
591/// input).
592pub async fn generate(
593    backend: &mut CandleBackend,
594    req: GenerateRequest,
595) -> Result<(String, Option<u64>), InferenceError> {
596    let start = std::time::Instant::now();
597
598    // Reset KV cache so each generation starts fresh (prevents cross-call state bleed)
599    backend.clear_kv_cache();
600
601    let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
602    let tokens = backend.encode(&formatted)?;
603    let eos = backend.eos_token_id();
604    let eos_alt = backend.token_id("<|im_end|>");
605    let params = &req.params;
606
607    if tokens.is_empty() {
608        return Ok((String::new(), None));
609    }
610
611    // Truncate to model's max context length minus generation headroom.
612    // This prevents KV cache overflow on long prompts.
613    let max_ctx = backend.context_length().unwrap_or(32768);
614    let headroom = params.max_tokens.min(max_ctx / 4);
615    let max_prompt = max_ctx.saturating_sub(headroom);
616    let tokens = if tokens.len() > max_prompt {
617        eprintln!(
618            "[car-inference] truncating prompt from {} to {} tokens (context_length={})",
619            tokens.len(),
620            max_prompt,
621            max_ctx
622        );
623        tokens[tokens.len() - max_prompt..].to_vec()
624    } else {
625        tokens
626    };
627
628    let mut generated = Vec::new();
629
630    // Prefill: process all prompt tokens, sample first generated token from prefill logits
631    let logits = backend.forward(&tokens, 0)?;
632    let mut next_token = sample_token(&logits, params)?;
633    let ttft_ms = Some(start.elapsed().as_millis() as u64);
634
635    for _i in 0..params.max_tokens {
636        // Check EOS
637        if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
638        {
639            break;
640        }
641
642        generated.push(next_token);
643
644        // Check stop sequences
645        if !params.stop.is_empty() {
646            let text_so_far = backend.decode(&generated)?;
647            if params.stop.iter().any(|s| text_so_far.contains(s)) {
648                break;
649            }
650        }
651
652        // Generate next token
653        let pos = tokens.len() + generated.len() - 1;
654        let logits = backend.forward(&[next_token], pos)?;
655        next_token = sample_token(&logits, params)?;
656    }
657
658    let text = backend.decode(&generated)?;
659    Ok((strip_thinking(&text, params.thinking), ttft_ms))
660}
661
662#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
663/// Generate with FLARE-style confidence-triggered re-retrieval.
664///
665/// Monitors token logit confidence during generation. When a window of
666/// low-confidence tokens is detected, pauses, re-queries memory with the
667/// partial generation, and resumes with augmented context.
668pub async fn generate_with_retrieval(
669    backend: &mut CandleBackend,
670    mut req: GenerateRequest,
671    retrieval_cb: RetrievalCallback,
672) -> Result<String, InferenceError> {
673    // First pass: generate normally
674    backend.clear_kv_cache();
675    let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
676    let tokens = backend.encode(&formatted)?;
677    let eos = backend.eos_token_id();
678    let eos_alt = backend.token_id("<|im_end|>");
679    let params = req.params.clone();
680
681    if tokens.is_empty() {
682        return Ok(String::new());
683    }
684
685    let mut generated = Vec::new();
686    let mut low_confidence_count = 0u32;
687    let mut retrieval_attempts = 0u32;
688    let max_retrievals = 2;
689    let confidence_threshold = 0.4f32;
690    let low_confidence_window = 3u32;
691
692    let logits = backend.forward(&tokens, 0)?;
693    let mut next_token = sample_token(&logits, &params)?;
694
695    for _i in 0..params.max_tokens {
696        if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
697        {
698            break;
699        }
700
701        generated.push(next_token);
702
703        // Generate next token and check confidence
704        let pos = tokens.len() + generated.len() - 1;
705        let logits = backend.forward(&[next_token], pos)?;
706
707        // Check max logit probability for confidence
708        let logits_f32: Vec<f32> = logits
709            .squeeze(0)
710            .unwrap_or(logits.clone())
711            .to_dtype(candle_core::DType::F32)
712            .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
713            .to_vec1()
714            .unwrap_or_default();
715
716        if !logits_f32.is_empty() {
717            // Compute softmax max probability
718            let max_logit = logits_f32.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
719            let exp_sum: f32 = logits_f32.iter().map(|&v| (v - max_logit).exp()).sum();
720            let max_prob = 1.0 / exp_sum; // probability of the top token
721
722            if max_prob < confidence_threshold {
723                low_confidence_count += 1;
724            } else {
725                low_confidence_count = 0;
726            }
727
728            // Trigger re-retrieval after sustained low confidence
729            if low_confidence_count >= low_confidence_window && retrieval_attempts < max_retrievals
730            {
731                retrieval_attempts += 1;
732                low_confidence_count = 0;
733
734                // Use partial generation as re-retrieval query
735                let partial = backend.decode(&generated)?;
736                if let Some(new_context) = retrieval_cb(&partial) {
737                    // Restart generation with augmented context
738                    let combined_context = match req.context.take() {
739                        Some(old) => format!("{}\n\n{}", old, new_context),
740                        None => new_context,
741                    };
742                    req.context = Some(combined_context);
743
744                    // Re-encode and restart
745                    backend.clear_kv_cache();
746                    let new_formatted =
747                        apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
748                    let new_tokens = backend.encode(&new_formatted)?;
749                    generated.clear();
750
751                    let logits = backend.forward(&new_tokens, 0)?;
752                    next_token = sample_token(&logits, &params)?;
753                    continue;
754                }
755            }
756        }
757
758        next_token = sample_token(&logits, &params)?;
759    }
760
761    let text = backend.decode(&generated)?;
762    Ok(strip_thinking(&text, params.thinking))
763}
764
765#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
766/// Sample a token, suppressing specific token IDs (set to -inf before sampling).
767pub fn sample_token_suppress(
768    logits: &Tensor,
769    params: &GenerateParams,
770    suppress: &[u32],
771) -> Result<u32, InferenceError> {
772    if suppress.is_empty() {
773        return sample_token(logits, params);
774    }
775    // Clone logits and set suppressed tokens to -inf
776    let mut logits_vec: Vec<f32> = logits
777        .squeeze(0)
778        .unwrap_or(logits.clone())
779        .to_dtype(candle_core::DType::F32)
780        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
781        .to_vec1()
782        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
783    // Handle 2D logits (take last row)
784    let dims = logits.dims();
785    if dims.len() == 2 {
786        let vocab = dims[dims.len() - 1];
787        let start = logits_vec.len() - vocab;
788        logits_vec = logits_vec[start..].to_vec();
789    }
790    for &id in suppress {
791        if (id as usize) < logits_vec.len() {
792            logits_vec[id as usize] = f32::NEG_INFINITY;
793        }
794    }
795    let modified = Tensor::from_vec(
796        logits_vec,
797        logits.squeeze(0).unwrap_or(logits.clone()).shape(),
798        logits.device(),
799    )
800    .map_err(|e| InferenceError::InferenceFailed(format!("from_vec: {e}")))?;
801    sample_token(&modified, params)
802}
803
804#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
805/// Sample a token from logits using temperature + top-p + top-k.
806pub fn sample_token(logits: &Tensor, params: &GenerateParams) -> Result<u32, InferenceError> {
807    let logits = logits
808        .squeeze(0)
809        .map_err(|e| InferenceError::InferenceFailed(format!("squeeze: {e}")))?;
810    let logits = logits
811        .to_dtype(candle_core::DType::F32)
812        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?;
813
814    // Get last position's logits
815    let dim = logits.dims();
816    let logits = if dim.len() == 2 {
817        logits
818            .get(dim[0] - 1)
819            .map_err(|e| InferenceError::InferenceFailed(format!("get last: {e}")))?
820    } else {
821        logits
822    };
823
824    // Greedy decoding
825    if params.temperature <= 0.0 {
826        let token = logits
827            .argmax(0)
828            .map_err(|e| InferenceError::InferenceFailed(format!("argmax: {e}")))?
829            .to_scalar::<u32>()
830            .map_err(|e| InferenceError::InferenceFailed(format!("scalar: {e}")))?;
831        return Ok(token);
832    }
833
834    // Temperature scaling
835    let logits = (&logits / params.temperature)
836        .map_err(|e| InferenceError::InferenceFailed(format!("temp scale: {e}")))?;
837
838    let mut logits_vec: Vec<f32> = logits
839        .to_vec1()
840        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
841
842    // Top-k filtering
843    if params.top_k > 0 && params.top_k < logits_vec.len() {
844        let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
845        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
846        let threshold = indexed[params.top_k].1;
847        for v in &mut logits_vec {
848            if *v < threshold {
849                *v = f32::NEG_INFINITY;
850            }
851        }
852    }
853
854    // Softmax
855    let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
856    let exp: Vec<f32> = logits_vec.iter().map(|&v| (v - max_logit).exp()).collect();
857    let sum: f32 = exp.iter().sum();
858    let mut probs: Vec<f32> = exp.iter().map(|&v| v / sum).collect();
859
860    // Top-p (nucleus) filtering
861    if params.top_p < 1.0 {
862        let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
863        sorted_indices.sort_by(|&a, &b| {
864            probs[b]
865                .partial_cmp(&probs[a])
866                .unwrap_or(std::cmp::Ordering::Equal)
867        });
868
869        let mut cumsum = 0.0f32;
870        let mut cutoff_idx = sorted_indices.len();
871        for (i, &idx) in sorted_indices.iter().enumerate() {
872            cumsum += probs[idx];
873            if cumsum > params.top_p as f32 {
874                cutoff_idx = i + 1;
875                break;
876            }
877        }
878
879        let keep: std::collections::HashSet<usize> =
880            sorted_indices[..cutoff_idx].iter().copied().collect();
881        for (i, p) in probs.iter_mut().enumerate() {
882            if !keep.contains(&i) {
883                *p = 0.0;
884            }
885        }
886
887        // Renormalize
888        let sum: f32 = probs.iter().sum();
889        if sum > 0.0 {
890            for p in &mut probs {
891                *p /= sum;
892            }
893        }
894    }
895
896    // Categorical sample
897    let r: f32 = rand_f32();
898    let mut cumsum = 0.0f32;
899    for (i, &p) in probs.iter().enumerate() {
900        cumsum += p;
901        if cumsum >= r {
902            return Ok(i as u32);
903        }
904    }
905
906    // Fallback: return highest prob token
907    Ok(probs
908        .iter()
909        .enumerate()
910        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
911        .map(|(i, _)| i as u32)
912        .unwrap_or(0))
913}
914
915#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
916/// Random float in [0, 1) using the rand crate.
917fn rand_f32() -> f32 {
918    rand::random::<f32>()
919}
920
#[cfg(test)]
mod thinking_tests {
    use super::*;

    /// `Auto` must leave the prompt free of any directive or `<think>` prefill.
    #[test]
    fn auto_injects_no_directive_and_no_prefill() {
        let rendered = apply_chat_template("hi", None, ThinkingMode::Auto);
        for forbidden in ["/no_think", "/think", "<think>"] {
            assert!(!rendered.contains(forbidden));
        }
        assert!(rendered.contains("<|im_start|>user\nhi<|im_end|>"));
    }

    /// `Off` appends `/no_think` on its own line and pre-fills an empty,
    /// already-closed thinking block after the assistant marker.
    #[test]
    fn off_injects_no_think_on_own_line_and_prefills_empty_think() {
        let rendered = apply_chat_template("hi", None, ThinkingMode::Off);

        // The directive sits on a line of its own rather than being glued
        // onto the user's prose with a space.
        let directive_on_own_line = rendered.contains("\n/no_think<|im_end|>");
        let directive_inline = rendered.contains(" /no_think");
        assert!(directive_on_own_line);
        assert!(!directive_inline);

        // Mirrors the upstream jinja hard-switch for enable_thinking=False:
        // a closed, empty thinking block right after the assistant marker.
        let prefill = "<|im_start|>assistant\n<think>\n\n</think>\n\n";
        assert!(rendered.contains(prefill));
    }

    /// `On` appends `/think` and never pre-fills a thinking block.
    #[test]
    fn on_injects_think_and_no_prefill() {
        let rendered = apply_chat_template("hi", None, ThinkingMode::On);
        assert!(rendered.contains("\n/think<|im_end|>"));
        for forbidden in ["/no_think", "<think>"] {
            assert!(!rendered.contains(forbidden));
        }
    }

    /// A prompt that already carries chat markers passes through verbatim.
    #[test]
    fn pre_formatted_prompt_is_untouched() {
        let pre = "<|im_start|>system\ncustom<|im_end|>\n<|im_start|>user\nhi<|im_end|>";
        assert_eq!(apply_chat_template(pre, None, ThinkingMode::Off), pre);
    }

    /// The thinking directive is emitted after any supplied context.
    #[test]
    fn directive_appears_after_context_not_before() {
        let rendered = apply_chat_template("q?", Some("some memory"), ThinkingMode::Off);
        let context_at = rendered.find("some memory").unwrap();
        let directive_at = rendered.find("/no_think").unwrap();
        assert!(
            directive_at > context_at,
            "directive must appear after context so user memory cannot nudge the parse"
        );
    }

    /// Default generation parameters select `ThinkingMode::Auto`.
    #[test]
    fn default_params_is_auto() {
        let defaults = GenerateParams::default();
        assert_eq!(defaults.thinking, ThinkingMode::Auto);
    }

    /// `ThinkingMode` round-trips through JSON using snake_case names.
    #[test]
    fn thinking_mode_serde_snake_case() {
        let encoded = serde_json::to_string(&ThinkingMode::Off).unwrap();
        assert_eq!(encoded, "\"off\"");

        let decoded: ThinkingMode = serde_json::from_str("\"on\"").unwrap();
        assert_eq!(decoded, ThinkingMode::On);
    }

    /// With thinking `On`, the raw `<think>` block is returned to the caller.
    #[test]
    fn strip_preserves_thinking_when_on() {
        let text = "<think>reasoning here</think>the answer";
        assert_eq!(
            strip_thinking(text, ThinkingMode::On),
            text,
            "On mode must return raw text with <think> visible"
        );
    }

    /// `Auto` and `Off` both drop a completed `<think>…</think>` block.
    #[test]
    fn strip_removes_thinking_when_auto_or_off() {
        let raw = "<think>reasoning</think>the answer";
        for mode in [ThinkingMode::Auto, ThinkingMode::Off] {
            assert_eq!(strip_thinking(raw, mode), "the answer");
        }
    }

    /// Output cut off mid-thought must not leak the dangling tag.
    #[test]
    fn strip_returns_empty_on_unterminated_think() {
        let truncated = "<think>mid-reasoning, never closed";
        for mode in [ThinkingMode::Auto, ThinkingMode::Off] {
            assert_eq!(strip_thinking(truncated, mode), "");
        }
        // `On` still hands back the raw text, since the caller asked for it.
        assert_eq!(strip_thinking(truncated, ThinkingMode::On), truncated);
    }

    /// Text without any `<think>` tag is returned unchanged in every mode.
    #[test]
    fn strip_is_noop_when_no_think_tag() {
        let plain = "just a plain answer";
        for mode in [ThinkingMode::Auto, ThinkingMode::Off, ThinkingMode::On] {
            assert_eq!(strip_thinking(plain, mode), plain);
        }
    }
}
1016
#[cfg(test)]
mod workload_tests {
    use super::*;

    /// The three weights of every workload profile must add up to 1.0.
    #[test]
    fn all_workload_weights_sum_to_one() {
        let every_workload = [
            RoutingWorkload::Interactive,
            RoutingWorkload::Batch,
            RoutingWorkload::Background,
            RoutingWorkload::LocalPreferred,
            RoutingWorkload::Fastest,
        ];

        for w in every_workload {
            let (first, second, third) = w.weights();
            let sum = first + second + third;
            // Tolerate floating-point rounding only.
            let drift = (sum - 1.0).abs();
            assert!(
                drift < 1e-6,
                "weights for {w:?} sum to {sum}, expected 1.0"
            );
        }
    }

    /// `Fastest` exists to favour latency, so the latency weight (the middle
    /// tuple element) has to dominate both other weights by a wide margin.
    #[test]
    fn fastest_weights_dominate_on_latency() {
        let (first, l, third) = RoutingWorkload::Fastest.weights();
        let latency_is_largest = l > first && l > third;
        assert!(latency_is_largest);
        assert!(l >= 0.7, "latency weight too small: {l}");
    }

    /// `Fastest` must report itself as latency-sensitive.
    #[test]
    fn fastest_is_latency_sensitive() {
        let workload = RoutingWorkload::Fastest;
        assert!(workload.is_latency_sensitive());
    }
}