car-inference 0.15.0

Local model inference for CAR — Candle backend with Qwen3 models
//! Text generation with sampling.

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use candle_core::Tensor;
use serde::{Deserialize, Serialize};

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use crate::backend::CandleBackend;
use crate::InferenceError;

/// How latency-sensitive this generation request is.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RoutingWorkload {
    /// User-facing, interactive request where latency matters.
    #[default]
    Interactive,
    /// Batch job where latency matters somewhat, but quality/cost matter more.
    Batch,
    /// Background or offline work where latency is a weak concern.
    Background,
    /// Caller explicitly prefers on-device models. Distinct from
    /// `Background` (which is "this is a background job, latency
    /// barely matters"). The caller may be doing latency-sensitive
    /// interactive work but wants the privacy / cost / offline
    /// properties of local inference. Same `local_bonus` as
    /// `Background` plus a slightly more quality-aware weight profile
    /// — the caller chose local for a reason, not because the work is
    /// throwaway.
    LocalPreferred,
    /// Aggressive latency bias for time-to-first-token. Voice turns
    /// (specifically the fast track in the two-track sidecar pattern)
    /// pick this. Quality and cost are heavily downweighted; the
    /// router prefers whichever model produces a first token soonest.
    /// On macOS 26+ this typically resolves to `apple/foundation:default`
    /// via the Foundation Models system-LLM bonus. Reached via the
    /// `IntentHint::prefer_fast` flag (or `RoutingWorkload::Fastest`
    /// directly when callers know they want it).
    Fastest,
}

impl RoutingWorkload {
    pub fn is_latency_sensitive(self) -> bool {
        matches!(
            self,
            RoutingWorkload::Interactive
                | RoutingWorkload::LocalPreferred
                | RoutingWorkload::Fastest,
        )
    }

    pub fn weights(self) -> (f64, f64, f64) {
        // Tuple is `(quality, latency, cost)` — destructured at the
        // single use site `adaptive_router.rs:892`.
        match self {
            RoutingWorkload::Interactive => (0.45, 0.40, 0.15),
            RoutingWorkload::Batch => (0.60, 0.15, 0.25),
            RoutingWorkload::Background => (0.65, 0.05, 0.30),
            // Quality-aware (closer to Interactive) but tolerant of
            // some latency hit since the caller chose local. Cost
            // weight matches Batch.
            RoutingWorkload::LocalPreferred => (0.55, 0.20, 0.25),
            // Voice fast track: latency is everything. Quality and
            // cost are deliberately near-floor — first audio in
            // <500ms beats any quality gain that takes another
            // round-trip. Sums to 1.0 like every other variant.
            RoutingWorkload::Fastest => (0.10, 0.85, 0.05),
        }
    }

    pub fn local_bonus(self) -> f64 {
        match self {
            RoutingWorkload::Interactive => 0.0,
            RoutingWorkload::Batch => 0.08,
            RoutingWorkload::Background => 0.15,
            // Stronger push than Background — "prefer local" should
            // win ties decisively, otherwise the hint is ineffective.
            RoutingWorkload::LocalPreferred => 0.20,
            // Local inference avoids network round-trips entirely —
            // the single biggest latency win available. On macOS the
            // Foundation Models `system_llm_bonus` stacks on top of
            // this for `apple/foundation:default`. Match
            // LocalPreferred's bonus so cloud-streamed fast paths
            // (e.g. gpt-4o-mini) can still win when no local model is
            // loaded; the weight profile already strongly
            // favours latency.
            RoutingWorkload::Fastest => 0.20,
        }
    }
}
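
// Illustrative sketch only: the real scoring lives in `adaptive_router.rs`
// (see the comment on `weights()` above). The hypothetical `score_candidate`
// below shows how a router *might* combine the `(quality, latency, cost)`
// weights with `local_bonus()` — higher is better, all inputs assumed
// normalized to [0, 1]. The combining formula is an assumption for the
// example, not the actual router implementation.
#[cfg(test)]
mod workload_sketch {
    use super::*;

    fn score_candidate(
        w: RoutingWorkload,
        quality: f64,      // normalized model quality, 1.0 = best
        latency_norm: f64, // normalized latency, 1.0 = slowest
        cost_norm: f64,    // normalized cost, 1.0 = most expensive
        is_local: bool,
    ) -> f64 {
        let (q, l, c) = w.weights();
        q * quality + l * (1.0 - latency_norm) + c * (1.0 - cost_norm)
            + if is_local { w.local_bonus() } else { 0.0 }
    }

    #[test]
    fn fastest_prefers_the_quicker_model() {
        // Under `Fastest`, a fast, cheap local model should beat a slow,
        // high-quality remote one even though the remote scores higher on quality.
        let local = score_candidate(RoutingWorkload::Fastest, 0.6, 0.1, 0.1, true);
        let remote = score_candidate(RoutingWorkload::Fastest, 0.95, 0.8, 0.6, false);
        assert!(local > remote);
    }
}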

/// Qwen3 hybrid thinking control. Qwen3 models were trained with both a
/// "thinking" (chain-of-thought inside `<think>...</think>`) and a
/// non-thinking mode. Upstream defaults thinking ON; `/no_think` and
/// `/think` are the documented per-turn overrides in the chat template.
///
/// Scope: applies to the *single-turn* local Qwen3 path driven by
/// [`apply_chat_template`]. The multi-turn `messages: Vec<Message>`
/// field on [`GenerateRequest`] is consumed by remote protocol
/// handlers (OpenAI/Anthropic/Google) which pass through user-supplied
/// system messages verbatim; this flag is not injected there. If you
/// need Qwen3 thinking control over a remote API, include `/think` or
/// `/no_think` explicitly in your own system message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingMode {
    /// Let the model decide. No explicit `/think` or `/no_think` directive
    /// is injected into the system prompt, so Qwen3's trained default
    /// (thinking ON) applies. `<think>...</think>` output is stripped
    /// from the returned text.
    #[default]
    Auto,
    /// Inject `/think` into the system prompt to explicitly request the
    /// thinking phase. Useful when callers want to force reasoning even
    /// on short prompts the model would normally answer directly.
    On,
    /// Inject `/no_think` into the system prompt to suppress the
    /// thinking phase for faster, more direct responses. This was the
    /// prior hard-coded behavior; callers now opt into it explicitly.
    Off,
}

impl ThinkingMode {
    /// Return the directive marker to append to the system prompt, or
    /// `None` when `Auto` (don't inject anything — trust model default).
    pub fn directive(self) -> Option<&'static str> {
        match self {
            ThinkingMode::Auto => None,
            ThinkingMode::On => Some("/think"),
            ThinkingMode::Off => Some("/no_think"),
        }
    }
}

/// Parameters controlling generation behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateParams {
    /// Sampling temperature (0.0 = greedy, 1.0 = full distribution).
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Top-p (nucleus) sampling threshold.
    #[serde(default = "default_top_p")]
    pub top_p: f64,
    /// Top-k sampling (0 = disabled).
    #[serde(default)]
    pub top_k: usize,
    /// Maximum tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Stop sequences — generation halts when any is produced.
    #[serde(default)]
    pub stop: Vec<String>,
    /// Extended thinking budget (tokens). When > 0, enables the model's
    /// internal reasoning/planning phase before responding. Only supported
    /// by models with the ExtendedThinking capability (e.g., Claude).
    #[serde(default)]
    pub budget_tokens: usize,
    /// Routing workload class. Interactive requests bias toward lower latency,
    /// while batch/background work can tolerate slower high-quality local models.
    #[serde(default)]
    pub workload: RoutingWorkload,
    /// Tool choice mode: "auto" (default when tools present), "required" (must use a tool),
    /// "none" (disable tools). When "required", the model must respond with a tool call,
    /// eliminating mixed text+JSON responses.
    #[serde(default)]
    pub tool_choice: Option<String>,
    /// OpenAI-compatible parallel tool call control.
    #[serde(default)]
    pub parallel_tool_calls: Option<bool>,
    /// Qwen3 hybrid thinking mode control. `Auto` (default) leaves the
    /// model at its trained default (thinking on). `On`/`Off` inject
    /// the documented `/think` or `/no_think` directive into the chat
    /// template. Ignored by non-Qwen3 models.
    #[serde(default)]
    pub thinking: ThinkingMode,
}

fn default_temperature() -> f64 {
    0.7
}
fn default_top_p() -> f64 {
    0.9
}
fn default_max_tokens() -> usize {
    4096
}

impl Default for GenerateParams {
    fn default() -> Self {
        Self {
            temperature: default_temperature(),
            top_p: default_top_p(),
            top_k: 0,
            max_tokens: default_max_tokens(),
            stop: Vec::new(),
            budget_tokens: 0,
            workload: RoutingWorkload::Interactive,
            tool_choice: None,
            parallel_tool_calls: None,
            thinking: ThinkingMode::default(),
        }
    }
}
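
// Every field on `GenerateParams` carries a `#[serde(default)]`, so callers
// can send sparse JSON and rely on the defaults above. A small illustrative
// check (new test module, separate from the existing suites):
#[cfg(test)]
mod params_defaults_tests {
    use super::*;

    #[test]
    fn sparse_json_fills_defaults() {
        let p: GenerateParams = serde_json::from_str(r#"{"temperature": 0.2}"#).unwrap();
        assert_eq!(p.temperature, 0.2);
        assert_eq!(p.top_p, 0.9);
        assert_eq!(p.max_tokens, 4096);
        assert_eq!(p.workload, RoutingWorkload::Interactive);
        assert_eq!(p.thinking, ThinkingMode::Auto);
    }
}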

/// A tool call returned by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned tool call ID (e.g. OpenAI `call_abc123`, Anthropic `toolu_abc123`).
    /// When present, protocol handlers use this for round-trip correlation instead of
    /// synthesizing positional IDs like `call_0`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Tool/function name.
    pub name: String,
    /// Arguments as key-value pairs.
    pub arguments: std::collections::HashMap<String, serde_json::Value>,
}

/// A content block in a multimodal message.
///
/// The image variants (`ImageBase64`, `ImageUrl`) are fully wired on
/// the native Qwen2.5-VL backend. The video variants
/// (`VideoPath`, `VideoUrl`, `VideoBase64`) are defined on the public
/// request surface so higher-level tooling can express Qwen2.5-VL
/// video-understanding payloads, but the native backend returns
/// [`crate::InferenceError::UnsupportedMode`] for them until the
/// video-tokenization path lands. Remote multimodal providers
/// (Anthropic, Google Vertex) accept them through the protocol
/// handlers today.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text content.
    Text { text: String },
    /// Base64-encoded image.
    ImageBase64 {
        /// Base64-encoded image data.
        data: String,
        /// MIME type (e.g., "image/png", "image/jpeg").
        media_type: String,
    },
    /// Image from URL.
    ImageUrl {
        /// URL of the image.
        url: String,
        /// Detail level for image processing ("auto", "low", "high").
        #[serde(default = "default_detail")]
        detail: String,
    },
    /// Video loaded from a local filesystem path. Qwen2.5-VL samples
    /// the clip at `fps` frames/sec (default: backend-chosen) and
    /// caps at `max_frames` to respect context budgets.
    VideoPath {
        path: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Video accessible over HTTP(S). Semantics as [`ContentBlock::VideoPath`].
    VideoUrl {
        url: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Base64-encoded video bytes. Prefer `VideoPath` when possible;
    /// inline base64 is expensive to round-trip.
    VideoBase64 {
        data: String,
        media_type: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        fps: Option<f32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_frames: Option<u32>,
    },
    /// Audio loaded from a local filesystem path. Used for
    /// audio-understanding models (Gemma 4 small variants, Gemini).
    AudioPath {
        path: String,
        /// Optional explicit sample-rate hint. Most backends will
        /// resample internally; this is a best-effort declaration.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
    /// Audio accessible over HTTP(S).
    AudioUrl {
        url: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
    /// Base64-encoded audio bytes.
    AudioBase64 {
        data: String,
        media_type: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        sample_rate: Option<u32>,
    },
}

impl ContentBlock {
    /// Return true if this block carries video data (any encoding).
    /// Used by backends that need to refuse video inputs until the
    /// tokenization path is wired.
    pub fn is_video(&self) -> bool {
        matches!(
            self,
            ContentBlock::VideoPath { .. }
                | ContentBlock::VideoUrl { .. }
                | ContentBlock::VideoBase64 { .. }
        )
    }

    /// Return true if this block carries audio data (any encoding).
    /// Used by backends that need to refuse audio inputs until the
    /// tokenization path is wired. Gemma 4 small variants and Gemini
    /// accept audio; everything else in CAR rejects with
    /// `UnsupportedMode`.
    pub fn is_audio(&self) -> bool {
        matches!(
            self,
            ContentBlock::AudioPath { .. }
                | ContentBlock::AudioUrl { .. }
                | ContentBlock::AudioBase64 { .. }
        )
    }
}

fn default_detail() -> String {
    "auto".to_string()
}
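
// Illustrative checks for the multimodal blocks: `detail` falls back to
// "auto" when omitted on the wire, and the `is_video` / `is_audio`
// predicates classify blocks by payload kind regardless of encoding.
// The URLs and paths below are placeholders.
#[cfg(test)]
mod content_block_tests {
    use super::*;

    #[test]
    fn image_url_detail_defaults_to_auto() {
        let block: ContentBlock =
            serde_json::from_str(r#"{"type": "image_url", "url": "https://example.com/cat.png"}"#)
                .unwrap();
        match block {
            ContentBlock::ImageUrl { detail, .. } => assert_eq!(detail, "auto"),
            other => panic!("expected ImageUrl, got {other:?}"),
        }
    }

    #[test]
    fn video_and_audio_predicates() {
        let video = ContentBlock::VideoPath {
            path: "/tmp/clip.mp4".into(),
            fps: Some(1.0),
            max_frames: Some(32),
        };
        let audio = ContentBlock::AudioUrl {
            url: "https://example.com/turn.wav".into(),
            sample_rate: None,
        };
        assert!(video.is_video() && !video.is_audio());
        assert!(audio.is_audio() && !audio.is_video());
        assert!(!ContentBlock::Text { text: "hi".into() }.is_video());
    }
}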

/// A message in a multi-turn conversation.
///
/// The `System` variant exists so callers can express a first-class
/// system prompt inside `messages: Vec<Message>` without threading it
/// through the legacy `context: Option<String>` field on
/// [`GenerateRequest`]. Protocol handlers and local chat templates
/// that have a native system-role slot (OpenAI, Anthropic, Gemini,
/// Gemma 4, Qwen) emit it in the right place; ones that don't can
/// fold it into the first user turn.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum Message {
    /// A system prompt. Appears once, at the start of the conversation.
    System { content: String },
    /// A user message (text only).
    User { content: String },
    /// A user message with multimodal content (text + images + video + audio).
    UserMultimodal { content: Vec<ContentBlock> },
    /// An assistant response, possibly with tool calls.
    Assistant {
        #[serde(default)]
        content: String,
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
    },
    /// The result of executing a tool call.
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    /// Provider-specific output items that need to round-trip
    /// verbatim across turns. The OpenAI Responses API returns
    /// reasoning blobs, encrypted_content, web-search results, etc.
    /// as opaque structured items; the next request must include
    /// them in the same form to preserve provider-side state.
    ///
    /// `protocol` identifies the provider format that produced the
    /// items (currently `"openai-responses"`). Builder paths that
    /// don't recognize the protocol drop the variant — there is no
    /// portable rendering across providers.
    ProviderOutputItems {
        protocol: String,
        items: Vec<serde_json::Value>,
    },
}
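
// How the role tagging serializes on the wire (`#[serde(tag = "role")]`
// above): each variant becomes a flat object with a snake_case `role`
// field. Illustrative round-trip, not exhaustive:
#[cfg(test)]
mod message_serde_tests {
    use super::*;

    #[test]
    fn roles_serialize_as_snake_case_tags() {
        let msgs = vec![
            Message::System { content: "You are terse.".into() },
            Message::User { content: "ping".into() },
            Message::Assistant { content: "pong".into(), tool_calls: Vec::new() },
        ];
        let json = serde_json::to_value(&msgs).unwrap();
        assert_eq!(json[0]["role"], "system");
        assert_eq!(json[1]["role"], "user");
        assert_eq!(json[2]["role"], "assistant");
        assert_eq!(json[2]["content"], "pong");
    }
}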

/// Constraint on the model's response shape. Distinct from `tools` —
/// tools are a side-channel for action invocation; `response_format`
/// constrains the *primary* text output to be parseable JSON, optionally
/// against a caller-supplied schema.
///
/// Provider mapping (handled in `protocol.rs`):
/// * **OpenAI / Azure / OpenAI-compatible**: `response_format: {type: "json_schema", json_schema: {schema, strict, name}}`
///   (or `{type: "json_object"}` for the looser variant). Strict mode rejects
///   any deviation from the schema.
/// * **Google (Gemini)**: `response_mime_type: "application/json"` plus
///   optional `response_schema`.
/// * **Anthropic**: no native field as of early 2026 — the schema is
///   logged at `warn` level and dropped. Callers needing schema-validated
///   output on Claude should fall back to the `tools` + `tool_choice="required"`
///   coercion idiom (which is what worked before this field landed).
///
/// `JsonObject` is the looser variant — tells the provider "emit valid
/// JSON, no schema check required". Use when the schema is too dynamic
/// to spell out but the parse contract still matters.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// JSON output validated against the provided schema. `strict: true`
    /// asks the provider to reject any deviation; `false` makes the
    /// schema a best-effort hint. `name` is OpenAI-specific (the
    /// `json_schema.name` field, max 64 chars, alphanumerics + `-_`);
    /// other providers ignore it.
    JsonSchema {
        schema: serde_json::Value,
        #[serde(default)]
        strict: bool,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Plain JSON-mode output — the provider emits valid JSON without
    /// schema enforcement.
    JsonObject,
}
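
// What the strict-schema variant looks like when constructed and put on
// the wire. The schema itself is caller-supplied JSON Schema; the one
// below is a throwaway example.
#[cfg(test)]
mod response_format_tests {
    use super::*;

    #[test]
    fn json_schema_variant_serializes_with_type_tag() {
        let rf = ResponseFormat::JsonSchema {
            schema: serde_json::json!({
                "type": "object",
                "properties": { "answer": { "type": "string" } },
                "required": ["answer"]
            }),
            strict: true,
            name: Some("answer_envelope".into()),
        };
        let v = serde_json::to_value(&rf).unwrap();
        assert_eq!(v["type"], "json_schema");
        assert_eq!(v["strict"], true);
        assert_eq!(v["name"], "answer_envelope");

        let loose = serde_json::to_value(&ResponseFormat::JsonObject).unwrap();
        assert_eq!(loose["type"], "json_object");
    }
}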

/// A text generation request.
///
/// `Default` is derived so call sites can mutate just the fields
/// they care about: `GenerateRequest { prompt: "...".into(), ..Default::default() }`.
/// The default `prompt = ""` is useless on its own — callers always
/// override it — but the `..Default::default()` shorthand stops the
/// per-call-site mechanical churn every time a new optional field
/// lands (closes #109).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerateRequest {
    /// The prompt to complete (first user message for single-turn).
    pub prompt: String,
    /// Optional model override.
    pub model: Option<String>,
    /// Generation parameters.
    #[serde(default)]
    pub params: GenerateParams,
    /// Optional memory context to prepend to the prompt.
    /// When provided, this is injected as a system-level context block
    /// before the user prompt, grounding the model's response.
    #[serde(default)]
    pub context: Option<String>,
    /// Optional tool definitions for structured tool_use.
    /// When provided, the model may return tool_calls instead of text.
    /// Each tool is a JSON object with: name, description, parameters (JSON Schema).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<serde_json::Value>>,
    /// Optional images for vision models.
    /// When provided with a single-turn prompt, these are included as image content blocks
    /// in the user message. For multi-turn with `messages`, use `UserMultimodal` variants instead.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<ContentBlock>>,
    /// Optional multi-turn conversation history.
    /// When provided, the backend builds a proper multi-turn message array
    /// instead of a single user message. The `prompt` field is ignored when
    /// messages are present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub messages: Option<Vec<Message>>,
    /// Enable prompt caching for Anthropic API.
    /// When true, system prompt and tools are marked with cache_control breakpoints,
    /// enabling cache reuse across parent/child agent calls sharing the same prefix.
    #[serde(default)]
    pub cache_control: bool,
    /// Constrain output to JSON (optionally schema-validated). See
    /// [`ResponseFormat`] for the per-provider mapping. Defaults to
    /// `None` — free-form text.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Caller-supplied routing intent. None preserves the existing
    /// adaptive vs. pinned-model behavior. When `Some`, the adaptive
    /// router uses the hint to filter candidates (hard `require`),
    /// override task selection, and bias the score profile
    /// (`prefer_local`). See [`crate::intent::IntentHint`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub intent: Option<crate::intent::IntentHint>,
}
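
// The `..Default::default()` idiom the struct docs describe — set only the
// fields you care about and let everything else stay at its default.
// Illustrative, using only this module's own types:
#[cfg(test)]
mod request_default_tests {
    use super::*;

    #[test]
    fn struct_update_syntax_fills_the_rest() {
        let req = GenerateRequest {
            prompt: "Summarize the release notes.".into(),
            params: GenerateParams {
                max_tokens: 256,
                thinking: ThinkingMode::Off,
                ..GenerateParams::default()
            },
            ..GenerateRequest::default()
        };
        assert_eq!(req.params.max_tokens, 256);
        assert!(req.model.is_none());
        assert!(req.messages.is_none());
        assert!(!req.cache_control);
    }
}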

/// Wrap a raw prompt in Qwen3 chat format if it's not already formatted.
///
/// Thinking behavior follows the caller-supplied [`ThinkingMode`]:
/// * `Auto` — no directive injected; Qwen3's trained default (thinking
///   on) applies.
/// * `On` — the documented `/think` directive is appended to the system
///   message on its own line, and the model is allowed to emit a full
///   `<think>...</think>` block before its answer.
/// * `Off` — `/no_think` is appended to the system message *and* an
///   empty `<think>\n\n</think>` block is pre-filled after the assistant
///   marker, matching upstream Qwen3's `enable_thinking=False` jinja
///   template. The pre-filled closed tags structurally prevent the
///   model from emitting a thinking block even if the directive is
///   contradicted later in the prompt.
///
/// When context is provided it is injected into the system message to
/// ground the model's response with memory. The directive always
/// appears *after* the context blob so user-supplied memory cannot
/// nudge the directive's parse position.
pub fn apply_chat_template(prompt: &str, context: Option<&str>, thinking: ThinkingMode) -> String {
    if prompt.contains("<|im_start|>") {
        return prompt.to_string();
    }
    // Directive goes on its own line at the end of the system message
    // (never concatenated onto prose) so Qwen3's chat template parser
    // sees `/think`/`/no_think` as a standalone token, not as part of
    // "assistant. /no_think".
    let directive_line = match thinking.directive() {
        Some(d) => format!("\n{d}"),
        None => String::new(),
    };
    // For Off, pre-fill a closed empty thinking block after the
    // assistant marker. This mirrors upstream Qwen3's jinja behavior
    // when `enable_thinking=False` and is the hard-switch (structural)
    // form of the mode, whereas `/no_think` alone is a soft directive.
    let thinking_prefill = match thinking {
        ThinkingMode::Off => "<think>\n\n</think>\n\n",
        _ => "",
    };
    match context {
        Some(ctx) => format!(
            "<|im_start|>system\nYou are a helpful assistant. Use the following context to inform your response.\n\n{ctx}{directive_line}<|im_end|>\n\
             <|im_start|>user\n{prompt}<|im_end|>\n\
             <|im_start|>assistant\n{thinking_prefill}"
        ),
        None => format!(
            "<|im_start|>system\nYou are a helpful assistant.{directive_line}<|im_end|>\n\
             <|im_start|>user\n{prompt}<|im_end|>\n\
             <|im_start|>assistant\n{thinking_prefill}"
        ),
    }
}

/// Strip Qwen3 `<think>...</think>` blocks from model output, honoring
/// the caller's requested [`ThinkingMode`]:
///
/// * `On` — the caller explicitly asked for reasoning; return the raw
///   text verbatim so `<think>...</think>` is visible.
/// * `Auto` / `Off` — strip the thinking block and return only the
///   post-thinking answer. If the output contains an opening `<think>`
///   without a closing tag (truncation or stop before the model
///   finished thinking) return an empty string rather than leaking a
///   dangling tag to the caller.
pub fn strip_thinking(text: &str, thinking: ThinkingMode) -> String {
    if matches!(thinking, ThinkingMode::On) {
        return text.to_string();
    }
    strip_thinking_block(text)
}

/// Remove a leading `<think>...</think>` block unconditionally.
/// Returns "" if `<think>` opens but never closes (incomplete output).
///
/// When that "opened but never closed" branch fires, log a warn line
/// — the caller is about to receive an empty string for what was
/// almost certainly a budget-truncation. Surfaces issue #168's root
/// cause without changing the return contract: callers (e.g. car-cli)
/// that look at stderr can tell users to either bump
/// `--max-tokens` or pass `--thinking off`. The decision lives in
/// the strip helper because every text-completion path funnels
/// through it; logging at the call sites would be a lot of
/// duplication.
fn strip_thinking_block(text: &str) -> String {
    if let Some(end) = text.find("</think>") {
        text[end + 8..].trim_start().to_string()
    } else if text.contains("<think>") {
        tracing::warn!(
            target: "car_inference::tasks::generate",
            raw_len = text.len(),
            "model output opened <think> but never closed it — \
             likely truncated by max_tokens; returning empty text. \
             Increase max_tokens, or set thinking=off to suppress \
             the reasoning phase."
        );
        String::new()
    } else {
        text.to_string()
    }
}

/// Callback for FLARE-style re-retrieval during generation.
/// Called with partial generation text, returns additional context or None.
pub type RetrievalCallback = Box<dyn Fn(&str) -> Option<String> + Send>;
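
// A minimal sketch of a retrieval callback: inspect the partial generation
// and hand back extra context when it mentions something we can ground.
// The keyword trigger here is a stand-in for a real memory/vector lookup;
// the strings are placeholders.
#[cfg(test)]
mod retrieval_callback_tests {
    use super::*;

    #[test]
    fn callback_returns_context_on_match() {
        let cb: RetrievalCallback = Box::new(|partial: &str| {
            if partial.contains("release date") {
                Some("retrieved: project release history snippet".to_string())
            } else {
                None
            }
        });
        assert!(cb("thinking about the release date of").is_some());
        assert!(cb("unrelated text").is_none());
    }
}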

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
/// Generate text from a prompt using the loaded model.
///
/// Returns `(text, time_to_first_token_ms)`. TTFT is measured from
/// function entry through prefill to the moment the first generated
/// token has been sampled — the user-visible "did anything happen yet"
/// gate. `None` only when the prompt encodes to zero tokens (degenerate
/// input).
pub async fn generate(
    backend: &mut CandleBackend,
    req: GenerateRequest,
) -> Result<(String, Option<u64>), InferenceError> {
    let start = std::time::Instant::now();

    // Reset KV cache so each generation starts fresh (prevents cross-call state bleed)
    backend.clear_kv_cache();

    let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
    let tokens = backend.encode(&formatted)?;
    let eos = backend.eos_token_id();
    let eos_alt = backend.token_id("<|im_end|>");
    let params = &req.params;

    if tokens.is_empty() {
        return Ok((String::new(), None));
    }

    // Truncate to model's max context length minus generation headroom.
    // This prevents KV cache overflow on long prompts.
    let max_ctx = backend.context_length().unwrap_or(32768);
    let headroom = params.max_tokens.min(max_ctx / 4);
    let max_prompt = max_ctx.saturating_sub(headroom);
    let tokens = if tokens.len() > max_prompt {
        eprintln!(
            "[car-inference] truncating prompt from {} to {} tokens (context_length={})",
            tokens.len(),
            max_prompt,
            max_ctx
        );
        tokens[tokens.len() - max_prompt..].to_vec()
    } else {
        tokens
    };

    let mut generated = Vec::new();

    // Prefill: process all prompt tokens, sample first generated token from prefill logits
    let logits = backend.forward(&tokens, 0)?;
    let mut next_token = sample_token(&logits, params)?;
    let ttft_ms = Some(start.elapsed().as_millis() as u64);

    for _i in 0..params.max_tokens {
        // Check EOS
        if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
        {
            break;
        }

        generated.push(next_token);

        // Check stop sequences
        if !params.stop.is_empty() {
            let text_so_far = backend.decode(&generated)?;
            if params.stop.iter().any(|s| text_so_far.contains(s)) {
                break;
            }
        }

        // Generate next token
        let pos = tokens.len() + generated.len() - 1;
        let logits = backend.forward(&[next_token], pos)?;
        next_token = sample_token(&logits, params)?;
    }

    let text = backend.decode(&generated)?;
    Ok((strip_thinking(&text, params.thinking), ttft_ms))
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
/// Generate with FLARE-style confidence-triggered re-retrieval.
///
/// Monitors token logit confidence during generation. When a window of
/// low-confidence tokens is detected, pauses, re-queries memory with the
/// partial generation, and resumes with augmented context.
pub async fn generate_with_retrieval(
    backend: &mut CandleBackend,
    mut req: GenerateRequest,
    retrieval_cb: RetrievalCallback,
) -> Result<String, InferenceError> {
    // First pass: generate normally
    backend.clear_kv_cache();
    let formatted = apply_chat_template(&req.prompt, req.context.as_deref(), req.params.thinking);
    let tokens = backend.encode(&formatted)?;
    let eos = backend.eos_token_id();
    let eos_alt = backend.token_id("<|im_end|>");
    let params = req.params.clone();

    if tokens.is_empty() {
        return Ok(String::new());
    }

    let mut generated = Vec::new();
    let mut low_confidence_count = 0u32;
    let mut retrieval_attempts = 0u32;
    let max_retrievals = 2;
    let confidence_threshold = 0.4f32;
    let low_confidence_window = 3u32;

    // Track the current prompt length so token positions stay correct if
    // generation restarts with a longer, re-retrieved context below.
    let mut prompt_len = tokens.len();
    let logits = backend.forward(&tokens, 0)?;
    let mut next_token = sample_token(&logits, &params)?;

    for _i in 0..params.max_tokens {
        if eos.map_or(false, |id| next_token == id) || eos_alt.map_or(false, |id| next_token == id)
        {
            break;
        }

        generated.push(next_token);

        // Generate next token and check confidence
        let pos = prompt_len + generated.len() - 1;
        let logits = backend.forward(&[next_token], pos)?;

        // Check max logit probability for confidence
        let logits_f32: Vec<f32> = logits
            .squeeze(0)
            .unwrap_or(logits.clone())
            .to_dtype(candle_core::DType::F32)
            .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
            .to_vec1()
            .unwrap_or_default();

        if !logits_f32.is_empty() {
            // Compute softmax max probability
            let max_logit = logits_f32.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exp_sum: f32 = logits_f32.iter().map(|&v| (v - max_logit).exp()).sum();
            let max_prob = 1.0 / exp_sum; // probability of the top token

            if max_prob < confidence_threshold {
                low_confidence_count += 1;
            } else {
                low_confidence_count = 0;
            }

            // Trigger re-retrieval after sustained low confidence
            if low_confidence_count >= low_confidence_window && retrieval_attempts < max_retrievals
            {
                retrieval_attempts += 1;
                low_confidence_count = 0;

                // Use partial generation as re-retrieval query
                let partial = backend.decode(&generated)?;
                if let Some(new_context) = retrieval_cb(&partial) {
                    // Restart generation with augmented context
                    let combined_context = match req.context.take() {
                        Some(old) => format!("{}\n\n{}", old, new_context),
                        None => new_context,
                    };
                    req.context = Some(combined_context);

                    // Re-encode and restart
                    backend.clear_kv_cache();
                    let new_formatted = apply_chat_template(
                        &req.prompt,
                        req.context.as_deref(),
                        req.params.thinking,
                    );
                    let new_tokens = backend.encode(&new_formatted)?;
                    prompt_len = new_tokens.len();
                    generated.clear();

                    let logits = backend.forward(&new_tokens, 0)?;
                    next_token = sample_token(&logits, &params)?;
                    continue;
                }
            }
        }

        next_token = sample_token(&logits, &params)?;
    }

    let text = backend.decode(&generated)?;
    Ok(strip_thinking(&text, params.thinking))
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
/// Sample a token, suppressing specific token IDs (set to -inf before sampling).
pub fn sample_token_suppress(
    logits: &Tensor,
    params: &GenerateParams,
    suppress: &[u32],
) -> Result<u32, InferenceError> {
    if suppress.is_empty() {
        return sample_token(logits, params);
    }
    // Clone logits and set suppressed tokens to -inf
    let mut logits_vec: Vec<f32> = logits
        .squeeze(0)
        .unwrap_or(logits.clone())
        .to_dtype(candle_core::DType::F32)
        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
        .to_vec1()
        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
    // Handle 2D logits (take last row)
    let dims = logits.dims();
    if dims.len() == 2 {
        let vocab = dims[dims.len() - 1];
        let start = logits_vec.len() - vocab;
        logits_vec = logits_vec[start..].to_vec();
    }
    for &id in suppress {
        if (id as usize) < logits_vec.len() {
            logits_vec[id as usize] = f32::NEG_INFINITY;
        }
    }
    let modified = Tensor::from_vec(
        logits_vec,
        logits.squeeze(0).unwrap_or(logits.clone()).shape(),
        logits.device(),
    )
    .map_err(|e| InferenceError::InferenceFailed(format!("from_vec: {e}")))?;
    sample_token(&modified, params)
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
/// Sample a token from logits using temperature + top-p + top-k.
pub fn sample_token(logits: &Tensor, params: &GenerateParams) -> Result<u32, InferenceError> {
    let logits = logits
        .squeeze(0)
        .map_err(|e| InferenceError::InferenceFailed(format!("squeeze: {e}")))?;
    let logits = logits
        .to_dtype(candle_core::DType::F32)
        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?;

    // Get last position's logits
    let dim = logits.dims();
    let logits = if dim.len() == 2 {
        logits
            .get(dim[0] - 1)
            .map_err(|e| InferenceError::InferenceFailed(format!("get last: {e}")))?
    } else {
        logits
    };

    // Greedy decoding
    if params.temperature <= 0.0 {
        let token = logits
            .argmax(0)
            .map_err(|e| InferenceError::InferenceFailed(format!("argmax: {e}")))?
            .to_scalar::<u32>()
            .map_err(|e| InferenceError::InferenceFailed(format!("scalar: {e}")))?;
        return Ok(token);
    }

    // Temperature scaling
    let logits = (&logits / params.temperature)
        .map_err(|e| InferenceError::InferenceFailed(format!("temp scale: {e}")))?;

    let mut logits_vec: Vec<f32> = logits
        .to_vec1()
        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;

    // Top-k filtering
    if params.top_k > 0 && params.top_k < logits_vec.len() {
        let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        // Threshold is the k-th largest logit (index top_k - 1 after the
        // descending sort) so the top_k tokens (plus any ties at the
        // boundary) survive the filter.
        let threshold = indexed[params.top_k - 1].1;
        for v in &mut logits_vec {
            if *v < threshold {
                *v = f32::NEG_INFINITY;
            }
        }
    }

    // Softmax
    let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp: Vec<f32> = logits_vec.iter().map(|&v| (v - max_logit).exp()).collect();
    let sum: f32 = exp.iter().sum();
    let mut probs: Vec<f32> = exp.iter().map(|&v| v / sum).collect();

    // Top-p (nucleus) filtering
    if params.top_p < 1.0 {
        let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
        sorted_indices.sort_by(|&a, &b| {
            probs[b]
                .partial_cmp(&probs[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let mut cumsum = 0.0f32;
        let mut cutoff_idx = sorted_indices.len();
        for (i, &idx) in sorted_indices.iter().enumerate() {
            cumsum += probs[idx];
            if cumsum > params.top_p as f32 {
                cutoff_idx = i + 1;
                break;
            }
        }

        let keep: std::collections::HashSet<usize> =
            sorted_indices[..cutoff_idx].iter().copied().collect();
        for (i, p) in probs.iter_mut().enumerate() {
            if !keep.contains(&i) {
                *p = 0.0;
            }
        }

        // Renormalize
        let sum: f32 = probs.iter().sum();
        if sum > 0.0 {
            for p in &mut probs {
                *p /= sum;
            }
        }
    }

    // Categorical sample
    let r: f32 = rand_f32();
    let mut cumsum = 0.0f32;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if cumsum >= r {
            return Ok(i as u32);
        }
    }

    // Fallback: return highest prob token
    Ok(probs
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i as u32)
        .unwrap_or(0))
}
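
// Small CPU-only checks for the sampling helpers, gated the same way as the
// functions above. The `[1, vocab]` tensor shape mirrors what a single-position
// forward pass typically hands back (an assumption for the test, not a contract).
#[cfg(all(
    test,
    not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))
))]
mod sampling_tests {
    use super::*;
    use candle_core::{Device, Tensor};

    fn greedy() -> GenerateParams {
        GenerateParams { temperature: 0.0, ..GenerateParams::default() }
    }

    #[test]
    fn greedy_sampling_picks_argmax() {
        let logits = Tensor::new(&[[0.1f32, 2.5, -1.0, 0.3]], &Device::Cpu).unwrap();
        assert_eq!(sample_token(&logits, &greedy()).unwrap(), 1);
    }

    #[test]
    fn suppressing_the_top_token_falls_back_to_the_next_best() {
        let logits = Tensor::new(&[[0.1f32, 2.5, -1.0, 0.3]], &Device::Cpu).unwrap();
        assert_eq!(sample_token_suppress(&logits, &greedy(), &[1]).unwrap(), 3);
    }
}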

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
/// Random float in [0, 1) using the rand crate.
fn rand_f32() -> f32 {
    rand::random::<f32>()
}

#[cfg(test)]
mod thinking_tests {
    use super::*;

    #[test]
    fn auto_injects_no_directive_and_no_prefill() {
        let out = apply_chat_template("hi", None, ThinkingMode::Auto);
        assert!(!out.contains("/no_think"));
        assert!(!out.contains("/think"));
        assert!(!out.contains("<think>"));
        assert!(out.contains("<|im_start|>user\nhi<|im_end|>"));
    }

    #[test]
    fn off_injects_no_think_on_own_line_and_prefills_empty_think() {
        let out = apply_chat_template("hi", None, ThinkingMode::Off);
        // Directive on its own line, not concatenated onto prose.
        assert!(out.contains("\n/no_think<|im_end|>"));
        assert!(!out.contains(" /no_think"));
        // Closed empty thinking block pre-filled after assistant marker
        // — the upstream jinja hard-switch for enable_thinking=False.
        assert!(out.contains("<|im_start|>assistant\n<think>\n\n</think>\n\n"));
    }

    #[test]
    fn on_injects_think_and_no_prefill() {
        let out = apply_chat_template("hi", None, ThinkingMode::On);
        assert!(out.contains("\n/think<|im_end|>"));
        assert!(!out.contains("/no_think"));
        assert!(!out.contains("<think>"));
    }

    #[test]
    fn pre_formatted_prompt_is_untouched() {
        let pre = "<|im_start|>system\ncustom<|im_end|>\n<|im_start|>user\nhi<|im_end|>";
        let out = apply_chat_template(pre, None, ThinkingMode::Off);
        assert_eq!(out, pre);
    }

    #[test]
    fn directive_appears_after_context_not_before() {
        let out = apply_chat_template("q?", Some("some memory"), ThinkingMode::Off);
        let ctx_idx = out.find("some memory").unwrap();
        let directive_idx = out.find("/no_think").unwrap();
        assert!(
            directive_idx > ctx_idx,
            "directive must appear after context so user memory cannot nudge the parse"
        );
    }

    #[test]
    fn default_params_is_auto() {
        assert_eq!(GenerateParams::default().thinking, ThinkingMode::Auto);
    }

    #[test]
    fn thinking_mode_serde_snake_case() {
        let json = serde_json::to_string(&ThinkingMode::Off).unwrap();
        assert_eq!(json, "\"off\"");
        let parsed: ThinkingMode = serde_json::from_str("\"on\"").unwrap();
        assert_eq!(parsed, ThinkingMode::On);
    }

    #[test]
    fn strip_preserves_thinking_when_on() {
        let text = "<think>reasoning here</think>the answer";
        let out = strip_thinking(text, ThinkingMode::On);
        assert_eq!(
            out, text,
            "On mode must return raw text with <think> visible"
        );
    }

    #[test]
    fn strip_removes_thinking_when_auto_or_off() {
        let text = "<think>reasoning</think>the answer";
        assert_eq!(strip_thinking(text, ThinkingMode::Auto), "the answer");
        assert_eq!(strip_thinking(text, ThinkingMode::Off), "the answer");
    }

    #[test]
    fn strip_returns_empty_on_unterminated_think() {
        // Output was cut off mid-thinking — don't leak the dangling tag.
        let text = "<think>mid-reasoning, never closed";
        assert_eq!(strip_thinking(text, ThinkingMode::Auto), "");
        assert_eq!(strip_thinking(text, ThinkingMode::Off), "");
        // On mode still returns the raw text — caller asked for it.
        assert_eq!(strip_thinking(text, ThinkingMode::On), text);
    }

    #[test]
    fn strip_is_noop_when_no_think_tag() {
        let text = "just a plain answer";
        assert_eq!(strip_thinking(text, ThinkingMode::Auto), text);
        assert_eq!(strip_thinking(text, ThinkingMode::Off), text);
        assert_eq!(strip_thinking(text, ThinkingMode::On), text);
    }
}

#[cfg(test)]
mod workload_tests {
    use super::*;

    #[test]
    fn all_workload_weights_sum_to_one() {
        for w in [
            RoutingWorkload::Interactive,
            RoutingWorkload::Batch,
            RoutingWorkload::Background,
            RoutingWorkload::LocalPreferred,
            RoutingWorkload::Fastest,
        ] {
            let (q, l, c) = w.weights();
            let sum = q + l + c;
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "weights for {w:?} sum to {sum}, expected 1.0"
            );
        }
    }

    #[test]
    fn fastest_weights_dominate_on_latency() {
        let (q, l, c) = RoutingWorkload::Fastest.weights();
        // Latency should be the largest by a wide margin — that's the
        // whole point of this workload class.
        assert!(l > q && l > c);
        assert!(l >= 0.7, "latency weight too small: {l}");
    }

    #[test]
    fn fastest_is_latency_sensitive() {
        assert!(RoutingWorkload::Fastest.is_latency_sensitive());
    }
}