Skip to main content

car_inference/tasks/
generate.rs

1//! Text generation with sampling.
2
3#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
4use candle_core::Tensor;
5use serde::{Deserialize, Serialize};
6
7#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
8use crate::backend::CandleBackend;
9use crate::InferenceError;
10
/// How latency-sensitive this generation request is.
///
/// Serialized in `snake_case` (`"interactive"`, `"batch"`, `"background"`,
/// `"local_preferred"`). The default is [`RoutingWorkload::Interactive`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RoutingWorkload {
    /// User-facing, interactive request where latency matters.
    #[default]
    Interactive,
    /// Batch job where latency matters somewhat, but quality/cost matter more.
    Batch,
    /// Background or offline work where latency is a weak concern.
    Background,
    /// Caller explicitly prefers on-device models. Distinct from
    /// `Background` (which is "this is a background job, latency
    /// barely matters"). The caller may be doing latency-sensitive
    /// interactive work but wants the privacy / cost / offline
    /// properties of local inference. Carries a *larger* `local_bonus`
    /// than `Background` (so "prefer local" wins ties decisively) and
    /// its own weight profile — the caller chose local for a reason,
    /// not because the work is throwaway.
    LocalPreferred,
}
32
33impl RoutingWorkload {
34    pub fn is_latency_sensitive(self) -> bool {
35        matches!(
36            self,
37            RoutingWorkload::Interactive | RoutingWorkload::LocalPreferred,
38        )
39    }
40
41    pub fn weights(self) -> (f64, f64, f64) {
42        match self {
43            RoutingWorkload::Interactive => (0.45, 0.40, 0.15),
44            RoutingWorkload::Batch => (0.60, 0.15, 0.25),
45            RoutingWorkload::Background => (0.65, 0.05, 0.30),
46            // Quality-aware (closer to Interactive) but tolerant of
47            // some latency hit since the caller chose local. Cost
48            // weight matches Batch.
49            RoutingWorkload::LocalPreferred => (0.55, 0.20, 0.25),
50        }
51    }
52
53    pub fn local_bonus(self) -> f64 {
54        match self {
55            RoutingWorkload::Interactive => 0.0,
56            RoutingWorkload::Batch => 0.08,
57            RoutingWorkload::Background => 0.15,
58            // Stronger push than Background — "prefer local" should
59            // win ties decisively, otherwise the hint is ineffective.
60            RoutingWorkload::LocalPreferred => 0.20,
61        }
62    }
63}
64
/// Qwen3 hybrid thinking control. Qwen3 models were trained with both a
/// "thinking" (chain-of-thought inside `<think>...</think>`) and a
/// non-thinking mode. Upstream defaults thinking ON; `/no_think` and
/// `/think` are the documented per-turn overrides in the chat template.
///
/// Scope: applies to the *single-turn* local Qwen3 path driven by
/// [`apply_chat_template`]. The multi-turn `messages: Vec<Message>`
/// field on [`GenerateRequest`] is consumed by remote protocol
/// handlers (OpenAI/Anthropic/Google) which pass through user-supplied
/// system messages verbatim; this flag is not injected there. If you
/// need Qwen3 thinking control over a remote API, include `/think` or
/// `/no_think` explicitly in your own system message.
///
/// Serialized in `snake_case` (`"auto"`, `"on"`, `"off"`); the default
/// is [`ThinkingMode::Auto`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingMode {
    /// Let the model decide. No explicit `/think` or `/no_think` directive
    /// is injected into the system prompt, so Qwen3's trained default
    /// (thinking ON) applies. `<think>...</think>` output is stripped
    /// from the returned text.
    #[default]
    Auto,
    /// Inject `/think` into the system prompt to explicitly request the
    /// thinking phase. Useful when callers want to force reasoning even
    /// on short prompts the model would normally answer directly.
    /// In this mode [`strip_thinking`] returns the raw text, so the
    /// `<think>...</think>` block stays visible to the caller.
    On,
    /// Inject `/no_think` into the system prompt to suppress the
    /// thinking phase for faster, more direct responses. This was the
    /// prior hard-coded behavior; callers now opt into it explicitly.
    /// [`apply_chat_template`] additionally pre-fills an empty, closed
    /// `<think>` block after the assistant marker in this mode.
    Off,
}
95
96impl ThinkingMode {
97    /// Return the directive marker to append to the system prompt, or
98    /// `None` when `Auto` (don't inject anything — trust model default).
99    pub fn directive(self) -> Option<&'static str> {
100        match self {
101            ThinkingMode::Auto => None,
102            ThinkingMode::On => Some("/think"),
103            ThinkingMode::Off => Some("/no_think"),
104        }
105    }
106}
107
/// Parameters controlling generation behavior.
///
/// Every field has a serde default, so an empty JSON object
/// deserializes to the same values as [`GenerateParams::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateParams {
    /// Sampling temperature (0.0 = greedy, 1.0 = full distribution).
    /// Defaults to 0.7.
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Top-p (nucleus) sampling threshold. Defaults to 0.9.
    #[serde(default = "default_top_p")]
    pub top_p: f64,
    /// Top-k sampling (0 = disabled, which is also the default).
    #[serde(default)]
    pub top_k: usize,
    /// Maximum tokens to generate. Defaults to 4096.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Stop sequences — generation halts when any is produced.
    /// Empty by default.
    #[serde(default)]
    pub stop: Vec<String>,
    /// Extended thinking budget (tokens). When > 0, enables the model's
    /// internal reasoning/planning phase before responding. Only supported
    /// by models with the ExtendedThinking capability (e.g., Claude).
    /// Defaults to 0 (disabled).
    #[serde(default)]
    pub budget_tokens: usize,
    /// Routing workload class. Interactive requests bias toward lower latency,
    /// while batch/background work can tolerate slower high-quality local models.
    /// Defaults to `RoutingWorkload::Interactive`.
    #[serde(default)]
    pub workload: RoutingWorkload,
    /// Tool choice mode: "auto" (default when tools present), "required" (must use a tool),
    /// "none" (disable tools). When "required", the model must respond with a tool call,
    /// eliminating mixed text+JSON responses. `None` leaves the choice unset.
    #[serde(default)]
    pub tool_choice: Option<String>,
    /// OpenAI-compatible parallel tool call control. `None` leaves the
    /// setting unspecified.
    #[serde(default)]
    pub parallel_tool_calls: Option<bool>,
    /// Qwen3 hybrid thinking mode control. `Auto` (default) leaves the
    /// model at its trained default (thinking on). `On`/`Off` inject
    /// the documented `/think` or `/no_think` directive into the chat
    /// template. Ignored by non-Qwen3 models.
    #[serde(default)]
    pub thinking: ThinkingMode,
}
150
/// Serde default for `GenerateParams::temperature`; also used by the
/// `Default` impl so both construction paths agree.
fn default_temperature() -> f64 {
    0.7
}
/// Serde default for `GenerateParams::top_p`; also used by the
/// `Default` impl so both construction paths agree.
fn default_top_p() -> f64 {
    0.9
}
/// Serde default for `GenerateParams::max_tokens`; also used by the
/// `Default` impl so both construction paths agree.
fn default_max_tokens() -> usize {
    4096
}
160
161impl Default for GenerateParams {
162    fn default() -> Self {
163        Self {
164            temperature: default_temperature(),
165            top_p: default_top_p(),
166            top_k: 0,
167            max_tokens: default_max_tokens(),
168            stop: Vec::new(),
169            budget_tokens: 0,
170            workload: RoutingWorkload::Interactive,
171            tool_choice: None,
172            parallel_tool_calls: None,
173            thinking: ThinkingMode::default(),
174        }
175    }
176}
177
/// A tool call returned by the model.
///
/// Carried on `Message::Assistant::tool_calls` in multi-turn histories.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned tool call ID (e.g. OpenAI `call_abc123`, Anthropic `toolu_abc123`).
    /// When present, protocol handlers use this for round-trip correlation instead of
    /// synthesizing positional IDs like `call_0`. Omitted from the
    /// serialized form when `None`; absent on input deserializes to `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Tool/function name.
    pub name: String,
    /// Arguments as key-value pairs (argument name → arbitrary JSON value).
    pub arguments: std::collections::HashMap<String, serde_json::Value>,
}
191
/// A content block in a multimodal message.
///
/// The image variants (`ImageBase64`, `ImageUrl`) are fully wired on
/// the native Qwen2.5-VL backend. The video variants
/// (`VideoPath`, `VideoUrl`, `VideoBase64`) are defined on the public
/// request surface so higher-level tooling can express Qwen2.5-VL
/// video-understanding payloads, but the native backend returns
/// [`crate::InferenceError::UnsupportedMode`] for them until the
/// video-tokenization path lands. Remote multimodal providers
/// (Anthropic, Google Vertex) accept them through the protocol
/// handlers today.
///
/// The audio variants (`AudioPath`, `AudioUrl`, `AudioBase64`) target
/// audio-understanding models; see [`ContentBlock::is_audio`].
///
/// Serialized as an internally tagged object: the variant name goes in
/// a `"type"` field in `snake_case` (e.g. `"image_base64"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text content.
    Text { text: String },
    /// Base64-encoded image.
    ImageBase64 {
        /// Base64-encoded image data.
        data: String,
        /// MIME type (e.g., "image/png", "image/jpeg").
        media_type: String,
    },
    /// Image from URL.
    ImageUrl {
        /// URL of the image.
        url: String,
        /// Detail level for image processing ("auto", "low", "high").
        /// Defaults to "auto" when omitted.
        #[serde(default = "default_detail")]
        detail: String,
    },
    /// Video loaded from a local filesystem path. Qwen2.5-VL samples
    /// the clip at `fps` frames/sec (default: backend-chosen) and
    /// caps at `max_frames` to respect context budgets.
    VideoPath {
        /// Filesystem path of the video file.
        path: String,
        /// Sampling rate in frames per second; `None` lets the backend choose.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        /// Upper bound on the number of sampled frames; `None` means unset.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Video accessible over HTTP(S). Semantics as [`ContentBlock::VideoPath`].
    VideoUrl {
        /// URL of the video.
        url: String,
        /// Sampling rate in frames per second; `None` lets the backend choose.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        /// Upper bound on the number of sampled frames; `None` means unset.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Base64-encoded video bytes. Prefer `VideoPath` when possible;
    /// inline base64 is expensive to round-trip.
    VideoBase64 {
        /// Base64-encoded video data.
        data: String,
        /// MIME type of the encoded video.
        media_type: String,
        /// Sampling rate in frames per second; `None` lets the backend choose.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        /// Upper bound on the number of sampled frames; `None` means unset.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Audio loaded from a local filesystem path. Used for
    /// audio-understanding models (Gemma 4 small variants, Gemini).
    AudioPath {
        /// Filesystem path of the audio file.
        path: String,
        /// Optional explicit sample-rate hint. Most backends will
        /// resample internally; this is a best-effort declaration.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
    /// Audio accessible over HTTP(S).
    AudioUrl {
        /// URL of the audio.
        url: String,
        /// Optional sample-rate hint, as on [`ContentBlock::AudioPath`].
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
    /// Base64-encoded audio bytes.
    AudioBase64 {
        /// Base64-encoded audio data.
        data: String,
        /// MIME type of the encoded audio.
        media_type: String,
        /// Optional sample-rate hint, as on [`ContentBlock::AudioPath`].
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
}
274
275impl ContentBlock {
276    /// Return true if this block carries video data (any encoding).
277    /// Used by backends that need to refuse video inputs until the
278    /// tokenization path is wired.
279    pub fn is_video(&self) -> bool {
280        matches!(
281            self,
282            ContentBlock::VideoPath { .. }
283                | ContentBlock::VideoUrl { .. }
284                | ContentBlock::VideoBase64 { .. }
285        )
286    }
287
288    /// Return true if this block carries audio data (any encoding).
289    /// Used by backends that need to refuse audio inputs until the
290    /// tokenization path is wired. Gemma 4 small variants and Gemini
291    /// accept audio; everything else in CAR rejects with
292    /// `UnsupportedMode`.
293    pub fn is_audio(&self) -> bool {
294        matches!(
295            self,
296            ContentBlock::AudioPath { .. }
297                | ContentBlock::AudioUrl { .. }
298                | ContentBlock::AudioBase64 { .. }
299        )
300    }
301}
302
/// Serde default for `ContentBlock::ImageUrl::detail`: the `"auto"`
/// detail level.
fn default_detail() -> String {
    String::from("auto")
}
306
/// A message in a multi-turn conversation.
///
/// The `System` variant exists so callers can express a first-class
/// system prompt inside `messages: Vec<Message>` without threading it
/// through the legacy `context: Option<String>` field on
/// [`GenerateRequest`]. Protocol handlers and local chat templates
/// that have a native system-role slot (OpenAI, Anthropic, Gemini,
/// Gemma 4, Qwen) emit it in the right place; ones that don't can
/// fold it into the first user turn.
///
/// Serialized as an internally tagged object: the variant name goes in
/// a `"role"` field in `snake_case` (`"system"`, `"user"`,
/// `"user_multimodal"`, `"assistant"`, `"tool_result"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum Message {
    /// A system prompt. Appears once, at the start of the conversation.
    System { content: String },
    /// A user message (text only).
    User { content: String },
    /// A user message with multimodal content (text + images + video + audio).
    UserMultimodal { content: Vec<ContentBlock> },
    /// An assistant response, possibly with tool calls.
    Assistant {
        /// Assistant text; defaults to empty when omitted (e.g. a
        /// response that consists only of tool calls).
        #[serde(default)]
        content: String,
        /// Tool calls issued in this turn; defaults to empty when omitted.
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
    },
    /// The result of executing a tool call.
    ToolResult {
        /// Identifier of the tool call this result answers.
        /// NOTE(review): presumably matches `ToolCall::id` — confirm
        /// against the protocol handlers.
        tool_use_id: String,
        /// Tool output, as text.
        content: String,
    },
}
338
/// Constraint on the model's response shape. Distinct from `tools` —
/// tools are a side-channel for action invocation; `response_format`
/// constrains the *primary* text output to be parseable JSON, optionally
/// against a caller-supplied schema.
///
/// Provider mapping (handled in `protocol.rs`):
/// * **OpenAI / Azure / OpenAI-compatible**: `response_format: {type: "json_schema", json_schema: {schema, strict, name}}`
///   (or `{type: "json_object"}` for the looser variant). Strict mode rejects
///   any deviation from the schema.
/// * **Google (Gemini)**: `response_mime_type: "application/json"` plus
///   optional `response_schema`.
/// * **Anthropic**: no native field as of early 2026 — the schema is
///   logged at `warn` level and dropped. Callers needing schema-validated
///   output on Claude should fall back to the `tools` + `tool_choice="required"`
///   coercion idiom (which is what worked before this field landed).
///
/// `JsonObject` is the looser variant — tells the provider "emit valid
/// JSON, no schema check required". Use when the schema is too dynamic
/// to spell out but the parse contract still matters.
///
/// Serialized as an internally tagged object with a `"type"` field of
/// `"json_schema"` or `"json_object"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// JSON output validated against the provided schema. `strict: true`
    /// asks the provider to reject any deviation; `false` makes the
    /// schema a best-effort hint. `name` is OpenAI-specific (the
    /// `json_schema.name` field, max 64 chars, alphanumerics + `-_`);
    /// other providers ignore it.
    JsonSchema {
        /// The JSON Schema the output should conform to.
        schema: serde_json::Value,
        /// Strict enforcement flag; defaults to `false` when omitted.
        #[serde(default)]
        strict: bool,
        /// Optional schema name; omitted from the serialized form when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Plain JSON-mode output — the provider emits valid JSON without
    /// schema enforcement.
    JsonObject,
}
377
/// A text generation request.
///
/// `Default` is derived so call sites can mutate just the fields
/// they care about: `GenerateRequest { prompt: "...".into(), ..Default::default() }`.
/// The default `prompt = ""` is useless on its own — callers always
/// override it — but the `..Default::default()` shorthand stops the
/// per-call-site mechanical churn every time a new optional field
/// lands (closes #109).
///
/// Fields marked `skip_serializing_if = "Option::is_none"` are left out
/// of the serialized form entirely when unset.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerateRequest {
    /// The prompt to complete (first user message for single-turn).
    /// This field carries no serde default.
    pub prompt: String,
    /// Optional model override. `None` leaves model selection to the
    /// caller's routing logic.
    pub model: Option<String>,
    /// Generation parameters. Falls back to `GenerateParams::default()`
    /// when omitted.
    #[serde(default)]
    pub params: GenerateParams,
    /// Optional memory context to prepend to the prompt.
    /// When provided, this is injected as a system-level context block
    /// before the user prompt, grounding the model's response.
    #[serde(default)]
    pub context: Option<String>,
    /// Optional tool definitions for structured tool_use.
    /// When provided, the model may return tool_calls instead of text.
    /// Each tool is a JSON object with: name, description, parameters (JSON Schema).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<serde_json::Value>>,
    /// Optional images for vision models.
    /// When provided with a single-turn prompt, these are included as image content blocks
    /// in the user message. For multi-turn with `messages`, use `UserMultimodal` variants instead.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<ContentBlock>>,
    /// Optional multi-turn conversation history.
    /// When provided, the backend builds a proper multi-turn message array
    /// instead of a single user message. The `prompt` field is ignored when
    /// messages are present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub messages: Option<Vec<Message>>,
    /// Enable prompt caching for Anthropic API.
    /// When true, system prompt and tools are marked with cache_control breakpoints,
    /// enabling cache reuse across parent/child agent calls sharing the same prefix.
    /// Defaults to `false`.
    #[serde(default)]
    pub cache_control: bool,
    /// Constrain output to JSON (optionally schema-validated). See
    /// [`ResponseFormat`] for the per-provider mapping. Defaults to
    /// `None` — free-form text.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Caller-supplied routing intent. None preserves the existing
    /// adaptive vs. pinned-model behavior. When `Some`, the adaptive
    /// router uses the hint to filter candidates (hard `require`),
    /// override task selection, and bias the score profile
    /// (`prefer_local`). See [`crate::intent::IntentHint`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub intent: Option<crate::intent::IntentHint>,
}
434
435/// Wrap a raw prompt in Qwen3 chat format if it's not already formatted.
436///
437/// Thinking behavior follows the caller-supplied [`ThinkingMode`]:
438/// * `Auto` — no directive injected; Qwen3's trained default (thinking
439///   on) applies.
440/// * `On` — the documented `/think` directive is appended to the system
441///   message on its own line, and the model is allowed to emit a full
442///   `<think>...</think>` block before its answer.
443/// * `Off` — `/no_think` is appended to the system message *and* an
444///   empty `<think>\n\n</think>` block is pre-filled after the assistant
445///   marker, matching upstream Qwen3's `enable_thinking=False` jinja
446///   template. The pre-filled closed tags structurally prevent the
447///   model from emitting a thinking block even if the directive is
448///   contradicted later in the prompt.
449///
450/// When context is provided it is injected into the system message to
451/// ground the model's response with memory. The directive always
452/// appears *after* the context blob so user-supplied memory cannot
453/// nudge the directive's parse position.
454pub fn apply_chat_template(
455    prompt: &str,
456    context: Option<&str>,
457    thinking: ThinkingMode,
458) -> String {
459    if prompt.contains("<|im_start|>") {
460        return prompt.to_string();
461    }
462    // Directive goes on its own line at the end of the system message
463    // (never concatenated onto prose) so Qwen3's chat template parser
464    // sees `/think`/`/no_think` as a standalone token, not as part of
465    // "assistant. /no_think".
466    let directive_line = match thinking.directive() {
467        Some(d) => format!("\n{d}"),
468        None => String::new(),
469    };
470    // For Off, pre-fill a closed empty thinking block after the
471    // assistant marker. This mirrors upstream Qwen3's jinja behavior
472    // when `enable_thinking=False` and is the hard-switch (structural)
473    // form of the mode, whereas `/no_think` alone is a soft directive.
474    let thinking_prefill = match thinking {
475        ThinkingMode::Off => "<think>\n\n</think>\n\n",
476        _ => "",
477    };
478    match context {
479        Some(ctx) => format!(
480            "<|im_start|>system\nYou are a helpful assistant. Use the following context to inform your response.\n\n{ctx}{directive_line}<|im_end|>\n\
481             <|im_start|>user\n{prompt}<|im_end|>\n\
482             <|im_start|>assistant\n{thinking_prefill}"
483        ),
484        None => format!(
485            "<|im_start|>system\nYou are a helpful assistant.{directive_line}<|im_end|>\n\
486             <|im_start|>user\n{prompt}<|im_end|>\n\
487             <|im_start|>assistant\n{thinking_prefill}"
488        ),
489    }
490}
491
492/// Strip Qwen3 `<think>...</think>` blocks from model output, honoring
493/// the caller's requested [`ThinkingMode`]:
494///
495/// * `On` — the caller explicitly asked for reasoning; return the raw
496///   text verbatim so `<think>...</think>` is visible.
497/// * `Auto` / `Off` — strip the thinking block and return only the
498///   post-thinking answer. If the output contains an opening `<think>`
499///   without a closing tag (truncation or stop before the model
500///   finished thinking) return an empty string rather than leaking a
501///   dangling tag to the caller.
502pub fn strip_thinking(text: &str, thinking: ThinkingMode) -> String {
503    if matches!(thinking, ThinkingMode::On) {
504        return text.to_string();
505    }
506    strip_thinking_block(text)
507}
508
/// Drop everything up to and including the first `</think>` tag, then
/// trim leading whitespace from what follows.
///
/// * If the text has no `</think>` but does contain `<think>`, the
///   thinking block is incomplete and `""` is returned.
/// * If neither tag is present the text is returned unchanged.
fn strip_thinking_block(text: &str) -> String {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    match text.split_once(CLOSE_TAG) {
        Some((_reasoning, answer)) => answer.trim_start().to_owned(),
        None if text.contains(OPEN_TAG) => String::new(),
        None => text.to_owned(),
    }
}
520
/// Callback for FLARE-style re-retrieval during generation.
/// Called with partial generation text, returns additional context or None.
///
/// Returning `None` means "nothing new to add": generation simply
/// continues with the current context. The closure must be `Send`.
pub type RetrievalCallback = Box<dyn Fn(&str) -> Option<String> + Send>;
524
525#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
526/// Generate text from a prompt using the loaded model.
527///
528/// Returns `(text, time_to_first_token_ms)`. TTFT is measured from
529/// function entry through prefill to the moment the first generated
530/// token has been sampled — the user-visible "did anything happen yet"
531/// gate. `None` only when the prompt encodes to zero tokens (degenerate
532/// input).
533pub async fn generate(
534    backend: &mut CandleBackend,
535    req: GenerateRequest,
536) -> Result<(String, Option<u64>), InferenceError> {
537    let start = std::time::Instant::now();
538
539    // Reset KV cache so each generation starts fresh (prevents cross-call state bleed)
540    backend.clear_kv_cache();
541
542    let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
543    let tokens = backend.encode(&formatted)?;
544    let eos = backend.eos_token_id();
545    let eos_alt = backend.token_id("<|im_end|>");
546    let params = &req.params;
547
548    if tokens.is_empty() {
549        return Ok((String::new(), None));
550    }
551
552    // Truncate to model's max context length minus generation headroom.
553    // This prevents KV cache overflow on long prompts.
554    let max_ctx = backend.context_length().unwrap_or(32768);
555    let headroom = params.max_tokens.min(max_ctx / 4);
556    let max_prompt = max_ctx.saturating_sub(headroom);
557    let tokens = if tokens.len() > max_prompt {
558        eprintln!(
559            "[car-inference] truncating prompt from {} to {} tokens (context_length={})",
560            tokens.len(),
561            max_prompt,
562            max_ctx
563        );
564        tokens[tokens.len() - max_prompt..].to_vec()
565    } else {
566        tokens
567    };
568
569    let mut generated = Vec::new();
570
571    // Prefill: process all prompt tokens, sample first generated token from prefill logits
572    let logits = backend.forward(&tokens, 0)?;
573    let mut next_token = sample_token(&logits, params)?;
574    let ttft_ms = Some(start.elapsed().as_millis() as u64);
575
576    for _i in 0..params.max_tokens {
577        // Check EOS
578        if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
579        {
580            break;
581        }
582
583        generated.push(next_token);
584
585        // Check stop sequences
586        if !params.stop.is_empty() {
587            let text_so_far = backend.decode(&generated)?;
588            if params.stop.iter().any(|s| text_so_far.contains(s)) {
589                break;
590            }
591        }
592
593        // Generate next token
594        let pos = tokens.len() + generated.len() - 1;
595        let logits = backend.forward(&[next_token], pos)?;
596        next_token = sample_token(&logits, params)?;
597    }
598
599    let text = backend.decode(&generated)?;
600    Ok((strip_thinking(&text, params.thinking), ttft_ms))
601}
602
603#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
604/// Generate with FLARE-style confidence-triggered re-retrieval.
605///
606/// Monitors token logit confidence during generation. When a window of
607/// low-confidence tokens is detected, pauses, re-queries memory with the
608/// partial generation, and resumes with augmented context.
609pub async fn generate_with_retrieval(
610    backend: &mut CandleBackend,
611    mut req: GenerateRequest,
612    retrieval_cb: RetrievalCallback,
613) -> Result<String, InferenceError> {
614    // First pass: generate normally
615    backend.clear_kv_cache();
616    let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
617    let tokens = backend.encode(&formatted)?;
618    let eos = backend.eos_token_id();
619    let eos_alt = backend.token_id("<|im_end|>");
620    let params = req.params.clone();
621
622    if tokens.is_empty() {
623        return Ok(String::new());
624    }
625
626    let mut generated = Vec::new();
627    let mut low_confidence_count = 0u32;
628    let mut retrieval_attempts = 0u32;
629    let max_retrievals = 2;
630    let confidence_threshold = 0.4f32;
631    let low_confidence_window = 3u32;
632
633    let logits = backend.forward(&tokens, 0)?;
634    let mut next_token = sample_token(&logits, &params)?;
635
636    for _i in 0..params.max_tokens {
637        if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
638        {
639            break;
640        }
641
642        generated.push(next_token);
643
644        // Generate next token and check confidence
645        let pos = tokens.len() + generated.len() - 1;
646        let logits = backend.forward(&[next_token], pos)?;
647
648        // Check max logit probability for confidence
649        let logits_f32: Vec<f32> = logits
650            .squeeze(0)
651            .unwrap_or(logits.clone())
652            .to_dtype(candle_core::DType::F32)
653            .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
654            .to_vec1()
655            .unwrap_or_default();
656
657        if !logits_f32.is_empty() {
658            // Compute softmax max probability
659            let max_logit = logits_f32.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
660            let exp_sum: f32 = logits_f32.iter().map(|&v| (v - max_logit).exp()).sum();
661            let max_prob = 1.0 / exp_sum; // probability of the top token
662
663            if max_prob < confidence_threshold {
664                low_confidence_count += 1;
665            } else {
666                low_confidence_count = 0;
667            }
668
669            // Trigger re-retrieval after sustained low confidence
670            if low_confidence_count >= low_confidence_window && retrieval_attempts < max_retrievals
671            {
672                retrieval_attempts += 1;
673                low_confidence_count = 0;
674
675                // Use partial generation as re-retrieval query
676                let partial = backend.decode(&generated)?;
677                if let Some(new_context) = retrieval_cb(&partial) {
678                    // Restart generation with augmented context
679                    let combined_context = match req.context.take() {
680                        Some(old) => format!("{}\n\n{}", old, new_context),
681                        None => new_context,
682                    };
683                    req.context = Some(combined_context);
684
685                    // Re-encode and restart
686                    backend.clear_kv_cache();
687                    let new_formatted =
688                        apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
689                    let new_tokens = backend.encode(&new_formatted)?;
690                    generated.clear();
691
692                    let logits = backend.forward(&new_tokens, 0)?;
693                    next_token = sample_token(&logits, &params)?;
694                    continue;
695                }
696            }
697        }
698
699        next_token = sample_token(&logits, &params)?;
700    }
701
702    let text = backend.decode(&generated)?;
703    Ok(strip_thinking(&text, params.thinking))
704}
705
706#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
707/// Sample a token, suppressing specific token IDs (set to -inf before sampling).
708pub fn sample_token_suppress(
709    logits: &Tensor,
710    params: &GenerateParams,
711    suppress: &[u32],
712) -> Result<u32, InferenceError> {
713    if suppress.is_empty() {
714        return sample_token(logits, params);
715    }
716    // Clone logits and set suppressed tokens to -inf
717    let mut logits_vec: Vec<f32> = logits
718        .squeeze(0)
719        .unwrap_or(logits.clone())
720        .to_dtype(candle_core::DType::F32)
721        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
722        .to_vec1()
723        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
724    // Handle 2D logits (take last row)
725    let dims = logits.dims();
726    if dims.len() == 2 {
727        let vocab = dims[dims.len() - 1];
728        let start = logits_vec.len() - vocab;
729        logits_vec = logits_vec[start..].to_vec();
730    }
731    for &id in suppress {
732        if (id as usize) < logits_vec.len() {
733            logits_vec[id as usize] = f32::NEG_INFINITY;
734        }
735    }
736    let modified = Tensor::from_vec(
737        logits_vec,
738        logits.squeeze(0).unwrap_or(logits.clone()).shape(),
739        logits.device(),
740    )
741    .map_err(|e| InferenceError::InferenceFailed(format!("from_vec: {e}")))?;
742    sample_token(&modified, params)
743}
744
745#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
746/// Sample a token from logits using temperature + top-p + top-k.
747pub fn sample_token(logits: &Tensor, params: &GenerateParams) -> Result<u32, InferenceError> {
748    let logits = logits
749        .squeeze(0)
750        .map_err(|e| InferenceError::InferenceFailed(format!("squeeze: {e}")))?;
751    let logits = logits
752        .to_dtype(candle_core::DType::F32)
753        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?;
754
755    // Get last position's logits
756    let dim = logits.dims();
757    let logits = if dim.len() == 2 {
758        logits
759            .get(dim[0] - 1)
760            .map_err(|e| InferenceError::InferenceFailed(format!("get last: {e}")))?
761    } else {
762        logits
763    };
764
765    // Greedy decoding
766    if params.temperature <= 0.0 {
767        let token = logits
768            .argmax(0)
769            .map_err(|e| InferenceError::InferenceFailed(format!("argmax: {e}")))?
770            .to_scalar::<u32>()
771            .map_err(|e| InferenceError::InferenceFailed(format!("scalar: {e}")))?;
772        return Ok(token);
773    }
774
775    // Temperature scaling
776    let logits = (&logits / params.temperature)
777        .map_err(|e| InferenceError::InferenceFailed(format!("temp scale: {e}")))?;
778
779    let mut logits_vec: Vec<f32> = logits
780        .to_vec1()
781        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
782
783    // Top-k filtering
784    if params.top_k > 0 && params.top_k < logits_vec.len() {
785        let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
786        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
787        let threshold = indexed[params.top_k].1;
788        for v in &mut logits_vec {
789            if *v < threshold {
790                *v = f32::NEG_INFINITY;
791            }
792        }
793    }
794
795    // Softmax
796    let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
797    let exp: Vec<f32> = logits_vec.iter().map(|&v| (v - max_logit).exp()).collect();
798    let sum: f32 = exp.iter().sum();
799    let mut probs: Vec<f32> = exp.iter().map(|&v| v / sum).collect();
800
801    // Top-p (nucleus) filtering
802    if params.top_p < 1.0 {
803        let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
804        sorted_indices.sort_by(|&a, &b| {
805            probs[b]
806                .partial_cmp(&probs[a])
807                .unwrap_or(std::cmp::Ordering::Equal)
808        });
809
810        let mut cumsum = 0.0f32;
811        let mut cutoff_idx = sorted_indices.len();
812        for (i, &idx) in sorted_indices.iter().enumerate() {
813            cumsum += probs[idx];
814            if cumsum > params.top_p as f32 {
815                cutoff_idx = i + 1;
816                break;
817            }
818        }
819
820        let keep: std::collections::HashSet<usize> =
821            sorted_indices[..cutoff_idx].iter().copied().collect();
822        for (i, p) in probs.iter_mut().enumerate() {
823            if !keep.contains(&i) {
824                *p = 0.0;
825            }
826        }
827
828        // Renormalize
829        let sum: f32 = probs.iter().sum();
830        if sum > 0.0 {
831            for p in &mut probs {
832                *p /= sum;
833            }
834        }
835    }
836
837    // Categorical sample
838    let r: f32 = rand_f32();
839    let mut cumsum = 0.0f32;
840    for (i, &p) in probs.iter().enumerate() {
841        cumsum += p;
842        if cumsum >= r {
843            return Ok(i as u32);
844        }
845    }
846
847    // Fallback: return highest prob token
848    Ok(probs
849        .iter()
850        .enumerate()
851        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
852        .map(|(i, _)| i as u32)
853        .unwrap_or(0))
854}
855
856#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
857/// Random float in [0, 1) using the rand crate.
858fn rand_f32() -> f32 {
859    rand::random::<f32>()
860}
861
#[cfg(test)]
mod thinking_tests {
    //! Checks for how the thinking mode shapes the chat template and how
    //! `strip_thinking` post-processes generated text.

    use super::*;

    #[test]
    fn auto_injects_no_directive_and_no_prefill() {
        let rendered = apply_chat_template("hi", None, ThinkingMode::Auto);
        // Auto mode adds neither directive nor a pre-filled thinking block.
        for forbidden in ["/no_think", "/think", "<think>"] {
            assert!(!rendered.contains(forbidden));
        }
        assert!(rendered.contains("<|im_start|>user\nhi<|im_end|>"));
    }

    #[test]
    fn off_injects_no_think_on_own_line_and_prefills_empty_think() {
        let rendered = apply_chat_template("hi", None, ThinkingMode::Off);
        // The directive sits on a line of its own rather than being glued
        // onto the end of the user's prose.
        assert!(rendered.contains("\n/no_think<|im_end|>"));
        assert!(!rendered.contains(" /no_think"));
        // An already-closed, empty thinking block follows the assistant
        // marker — the upstream jinja hard-switch for enable_thinking=False.
        assert!(rendered.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
    }

    #[test]
    fn on_injects_think_and_no_prefill() {
        let rendered = apply_chat_template("hi", None, ThinkingMode::On);
        assert!(rendered.contains("\n/think<|im_end|>"));
        for forbidden in ["/no_think", "<think>"] {
            assert!(!rendered.contains(forbidden));
        }
    }

    #[test]
    fn pre_formatted_prompt_is_untouched() {
        let already_templated =
            "<|im_start|>system\ncustom<|im_end|>\n<|im_start|>user\nhi<|im_end|>";
        let rendered = apply_chat_template(already_templated, None, ThinkingMode::Off);
        assert_eq!(rendered, already_templated);
    }

    #[test]
    fn directive_appears_after_context_not_before() {
        let rendered = apply_chat_template("q?", Some("some memory"), ThinkingMode::Off);
        let context_at = rendered.find("some memory").unwrap();
        let directive_at = rendered.find("/no_think").unwrap();
        assert!(
            directive_at > context_at,
            "directive must appear after context so user memory cannot nudge the parse"
        );
    }

    #[test]
    fn default_params_is_auto() {
        let defaults = GenerateParams::default();
        assert_eq!(defaults.thinking, ThinkingMode::Auto);
    }

    #[test]
    fn thinking_mode_serde_snake_case() {
        // Serialization emits the lowercase variant name...
        let encoded = serde_json::to_string(&ThinkingMode::Off).unwrap();
        assert_eq!(encoded, "\"off\"");
        // ...and deserialization accepts it.
        let decoded: ThinkingMode = serde_json::from_str("\"on\"").unwrap();
        assert_eq!(decoded, ThinkingMode::On);
    }

    #[test]
    fn strip_preserves_thinking_when_on() {
        let raw = "<think>reasoning here</think>the answer";
        let stripped = strip_thinking(raw, ThinkingMode::On);
        assert_eq!(stripped, raw, "On mode must return raw text with <think> visible");
    }

    #[test]
    fn strip_removes_thinking_when_auto_or_off() {
        let raw = "<think>reasoning</think>the answer";
        for mode in [ThinkingMode::Auto, ThinkingMode::Off] {
            assert_eq!(strip_thinking(raw, mode), "the answer");
        }
    }

    #[test]
    fn strip_returns_empty_on_unterminated_think() {
        // Generation stopped mid-thought: the dangling tag must not leak.
        let raw = "<think>mid-reasoning, never closed";
        for mode in [ThinkingMode::Auto, ThinkingMode::Off] {
            assert_eq!(strip_thinking(raw, mode), "");
        }
        // On mode hands back the raw text because the caller asked for it.
        assert_eq!(strip_thinking(raw, ThinkingMode::On), raw);
    }

    #[test]
    fn strip_is_noop_when_no_think_tag() {
        let raw = "just a plain answer";
        for mode in [ThinkingMode::Auto, ThinkingMode::Off, ThinkingMode::On] {
            assert_eq!(strip_thinking(raw, mode), raw);
        }
    }
}