// codetether_agent/cognition/thinker.rs
1//! Local "thinking" models for edge inference.
2//!
3//! Supports Candle (Qwen/Llama), OpenAI-compatible HTTP, and Bedrock as
4//! backends for generating fast local reasoning completions.
5//!
6//! Start with [`ThinkerClient::new`] to create a client.
7
8use anyhow::{Context, Result, anyhow};
9use candle_core::quantized::gguf_file;
10use candle_core::{DType, Device, Tensor};
11use candle_transformers::generation::LogitsProcessor;
12
13#[cfg(feature = "functiongemma")]
14use candle_transformers::models::quantized_gemma3;
15use candle_transformers::models::{
16    quantized_llama, quantized_qwen2, quantized_qwen3, quantized_qwen3_moe,
17};
18use candle_transformers::utils::apply_repeat_penalty;
19use reqwest::Client;
20use serde::{Deserialize, Serialize};
21use std::collections::HashSet;
22use std::fs::File;
23use std::io::BufReader;
24use std::sync::{Arc, Mutex};
25use std::time::{Duration, Instant};
26use tokenizers::Tokenizer;
27
28use crate::provider::bedrock::{AwsCredentials, BedrockProvider};
29
/// Available backends for the local thinker.
///
/// # Examples
///
/// ```rust
/// use codetether_agent::cognition::thinker::ThinkerBackend;
/// assert_eq!(ThinkerBackend::from_env("candle"), ThinkerBackend::Candle);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkerBackend {
    /// OpenAI-compatible HTTP chat-completions endpoint; the fallback for
    /// unrecognized configuration values.
    OpenAICompat,
    /// In-process quantized GGUF inference via Candle.
    Candle,
    /// AWS Bedrock Converse API.
    Bedrock,
}
44
45impl ThinkerBackend {
46    /// Parse a backend from an env-var string.
47    ///
48    /// # Examples
49    ///
50    /// ```rust
51    /// use codetether_agent::cognition::thinker::ThinkerBackend;
52    /// assert_eq!(ThinkerBackend::from_env("bedrock"), ThinkerBackend::Bedrock);
53    /// ```
54    pub fn from_env(value: &str) -> Self {
55        match value.trim().to_ascii_lowercase().as_str() {
56            "candle" => Self::Candle,
57            "openai" | "openai_compat" | "openai-compatible" | "http" => Self::OpenAICompat,
58            "bedrock" | "aws" | "aws_bedrock" => Self::Bedrock,
59            _ => Self::OpenAICompat,
60        }
61    }
62}
63
/// Device preference for Candle inference.
///
/// # Examples
///
/// ```rust
/// use codetether_agent::cognition::thinker::CandleDevicePreference;
/// assert_eq!(CandleDevicePreference::from_env("cpu"), CandleDevicePreference::Cpu);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandleDevicePreference {
    /// Let the runtime pick a device; the fallback for unrecognized values.
    Auto,
    /// Force CPU inference.
    Cpu,
    /// Force CUDA inference.
    Cuda,
}
78
79impl CandleDevicePreference {
80    /// Parse a device preference from an env-var string.
81    ///
82    /// # Examples
83    ///
84    /// ```rust
85    /// use codetether_agent::cognition::thinker::CandleDevicePreference;
86    /// assert_eq!(CandleDevicePreference::from_env("cuda"), CandleDevicePreference::Cuda);
87    /// ```
88    pub fn from_env(value: &str) -> Self {
89        match value.trim().to_ascii_lowercase().as_str() {
90            "cpu" => Self::Cpu,
91            "cuda" | "gpu" => Self::Cuda,
92            _ => Self::Auto,
93        }
94    }
95}
96
/// Configuration for [`ThinkerClient`].
///
/// # Examples
///
/// ```rust
/// use codetether_agent::cognition::thinker::ThinkerConfig;
/// let cfg = ThinkerConfig::default();
/// assert!(!cfg.enabled);
/// ```
#[derive(Debug, Clone)]
pub struct ThinkerConfig {
    /// Master switch; callers are expected to check this before use.
    pub enabled: bool,
    /// Which inference backend the client will construct.
    pub backend: ThinkerBackend,
    /// Chat-completions URL used by the OpenAI-compatible backend.
    pub endpoint: String,
    /// Model identifier sent to HTTP/Bedrock backends.
    pub model: String,
    /// Optional bearer token for the OpenAI-compatible backend.
    pub api_key: Option<String>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Optional nucleus-sampling cutoff; `None` disables top-p.
    pub top_p: Option<f32>,
    /// Maximum completion tokens to generate per request.
    pub max_tokens: usize,
    /// HTTP request timeout in milliseconds (clamped to >= 1000 when used).
    pub timeout_ms: u64,
    /// Path to a GGUF model file; required by the Candle backend.
    pub candle_model_path: Option<String>,
    /// Path to the tokenizer file; required by the Candle backend.
    pub candle_tokenizer_path: Option<String>,
    /// Architecture override; when `None`, read from GGUF metadata.
    pub candle_arch: Option<String>,
    /// CPU/CUDA/auto device selection for Candle.
    pub candle_device: CandleDevicePreference,
    /// CUDA device ordinal when a GPU is used.
    pub candle_cuda_ordinal: usize,
    /// Repeat penalty factor; values > 1.0 penalize recent tokens.
    pub candle_repeat_penalty: f32,
    /// How many trailing tokens the repeat penalty considers.
    pub candle_repeat_last_n: usize,
    /// Base RNG seed; each request offsets it by a running counter.
    pub candle_seed: u64,
    /// AWS region for the Bedrock backend.
    pub bedrock_region: String,
    /// Optional Bedrock service tier, forwarded via
    /// `additionalModelRequestFields`.
    pub bedrock_service_tier: Option<String>,
}
128
129impl Default for ThinkerConfig {
130    fn default() -> Self {
131        Self {
132            enabled: false,
133            backend: ThinkerBackend::OpenAICompat,
134            endpoint: "http://127.0.0.1:11434/v1/chat/completions".to_string(),
135            model: "qwen3.5-9b".to_string(),
136            api_key: None,
137            temperature: 0.2,
138            top_p: None,
139            max_tokens: 256,
140            timeout_ms: 30_000,
141            candle_model_path: None,
142            candle_tokenizer_path: None,
143            candle_arch: None,
144            candle_device: CandleDevicePreference::Auto,
145            candle_cuda_ordinal: 0,
146            candle_repeat_penalty: 1.1,
147            candle_repeat_last_n: 64,
148            candle_seed: 42,
149            bedrock_region: "us-west-2".to_string(),
150            bedrock_service_tier: None,
151        }
152    }
153}
154
/// Output from a thinker completion.
///
/// # Examples
///
/// ```rust
/// use codetether_agent::cognition::thinker::ThinkerOutput;
/// let out = ThinkerOutput {
///     model: "qwen3.5-9b".into(),
///     finish_reason: Some("stop".into()),
///     text: "hello".into(),
///     prompt_tokens: Some(5),
///     completion_tokens: Some(1),
///     total_tokens: Some(6),
///     cache_read_tokens: None,
///     cache_write_tokens: None,
/// };
/// assert_eq!(out.model, "qwen3.5-9b");
/// ```
#[derive(Debug, Clone)]
pub struct ThinkerOutput {
    /// Model label reported by the backend (falls back to the configured model).
    pub model: String,
    /// Backend-reported stop reason (e.g. "stop", "length"), when available.
    pub finish_reason: Option<String>,
    /// Generated completion text.
    pub text: String,
    /// Tokens consumed by the prompt, when the backend reports usage.
    pub prompt_tokens: Option<u32>,
    /// Tokens produced in the completion, when reported.
    pub completion_tokens: Option<u32>,
    /// Prompt + completion tokens, when both are known.
    pub total_tokens: Option<u32>,
    // KV-cache accounting is only populated by the Candle backend, hence the
    // dead_code allowance when candle-cuda is not compiled in.
    #[cfg_attr(not(feature = "candle-cuda"), allow(dead_code))]
    pub cache_read_tokens: Option<u32>,
    #[cfg_attr(not(feature = "candle-cuda"), allow(dead_code))]
    pub cache_write_tokens: Option<u32>,
}
186
/// Client for local thinker inference across multiple backends.
///
/// Cloning is cheap: the backend state is shared via `Arc` internally.
///
/// # Examples
///
/// ```rust,no_run
/// use codetether_agent::cognition::thinker::{ThinkerClient, ThinkerConfig};
/// let cfg = ThinkerConfig::default();
/// let client = ThinkerClient::new(cfg).unwrap();
/// ```
#[derive(Clone)]
pub struct ThinkerClient {
    /// The configuration the client was built from.
    config: ThinkerConfig,
    /// Backend-specific state (HTTP client, Candle runtime, or Bedrock provider).
    backend: ThinkerClientBackend,
}
201
202impl std::fmt::Debug for ThinkerClient {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        f.debug_struct("ThinkerClient")
205            .field("backend", &self.config.backend)
206            .field("model", &self.config.model)
207            .finish()
208    }
209}
210
/// Backend-specific state held by [`ThinkerClient`]; `Arc`s keep clones cheap.
#[derive(Clone)]
enum ThinkerClientBackend {
    /// Reusable reqwest client for the OpenAI-compatible HTTP backend.
    OpenAICompat { http: Client },
    /// Single-request-at-a-time Candle runtime, guarded by a mutex.
    Candle { runtime: Arc<Mutex<CandleThinker>> },
    /// SigV4-signing Bedrock provider.
    Bedrock { provider: Arc<BedrockProvider> },
}
217
impl ThinkerClient {
    /// Create a new thinker client from the given config.
    ///
    /// Construction is eager per backend: the HTTP client is built, the
    /// Candle model is loaded from disk, or AWS credentials are resolved.
    /// Loading a GGUF model for the Candle backend may take noticeable time.
    ///
    /// # Errors
    ///
    /// Fails if the HTTP client cannot be built, the Candle model/tokenizer
    /// cannot be loaded, or (Bedrock) AWS credentials are unavailable.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// use codetether_agent::cognition::thinker::{ThinkerClient, ThinkerConfig};
    /// let client = ThinkerClient::new(ThinkerConfig::default()).unwrap();
    /// ```
    pub fn new(config: ThinkerConfig) -> Result<Self> {
        let backend = match config.backend {
            ThinkerBackend::OpenAICompat => {
                // Clamp to at least 1s so a zero/near-zero timeout_ms cannot
                // make every request fail immediately.
                let timeout = Duration::from_millis(config.timeout_ms.max(1_000));
                let http = Client::builder()
                    .timeout(timeout)
                    .build()
                    .context("failed to build thinker HTTP client")?;
                ThinkerClientBackend::OpenAICompat { http }
            }
            ThinkerBackend::Candle => {
                // The Candle runtime is mutable single-threaded state; wrap it
                // in Arc<Mutex> so the client itself stays cheaply cloneable.
                let runtime = CandleThinker::new(&config)?;
                ThinkerClientBackend::Candle {
                    runtime: Arc::new(Mutex::new(runtime)),
                }
            }
            ThinkerBackend::Bedrock => {
                let creds = AwsCredentials::from_environment()
                    .ok_or_else(|| anyhow!("Bedrock thinker requires AWS credentials (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY or ~/.aws/credentials)"))?;
                let provider =
                    BedrockProvider::with_credentials(creds, config.bedrock_region.clone())?;
                ThinkerClientBackend::Bedrock {
                    provider: Arc::new(provider),
                }
            }
        };

        Ok(Self { config, backend })
    }

    /// Return a reference to the active config.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// # use codetether_agent::cognition::thinker::{ThinkerClient, ThinkerConfig};
    /// let client = ThinkerClient::new(ThinkerConfig::default()).unwrap();
    /// let cfg = client.config();
    /// assert!(!cfg.enabled);
    /// ```
    pub fn config(&self) -> &ThinkerConfig {
        &self.config
    }

    /// Generate a thinking completion using the configured backend.
    ///
    /// For the Candle backend the blocking generation runs on Tokio's
    /// blocking thread pool; if another request already holds the runtime,
    /// this returns a "busy" error rather than queueing.
    ///
    /// # Errors
    ///
    /// Propagates HTTP/Bedrock request failures, Candle inference failures,
    /// and the busy/poisoned-mutex conditions described above.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// # use codetether_agent::cognition::thinker::{ThinkerClient, ThinkerConfig};
    /// # async fn demo() {
    /// let client = ThinkerClient::new(ThinkerConfig::default()).unwrap();
    /// let output = client.think("You are helpful.", "Explain Rust").await.unwrap();
    /// assert!(!output.text.is_empty());
    /// # }
    /// ```
    pub async fn think(&self, system_prompt: &str, user_prompt: &str) -> Result<ThinkerOutput> {
        match &self.backend {
            ThinkerClientBackend::OpenAICompat { http } => {
                self.think_openai_compat(http, system_prompt, user_prompt)
                    .await
            }
            ThinkerClientBackend::Bedrock { provider } => {
                self.think_bedrock(provider, system_prompt, user_prompt)
                    .await
            }
            ThinkerClientBackend::Candle { runtime } => {
                let runtime = Arc::clone(runtime);
                let system_prompt = system_prompt.to_string();
                let user_prompt = user_prompt.to_string();
                tokio::task::spawn_blocking(move || {
                    // try_lock (not lock): a concurrent request should fail
                    // fast instead of blocking a pool thread on the model.
                    let mut guard = match runtime.try_lock() {
                        Ok(g) => g,
                        Err(std::sync::TryLockError::WouldBlock) => {
                            return Err(anyhow!("candle thinker is busy"));
                        }
                        Err(std::sync::TryLockError::Poisoned(_)) => {
                            return Err(anyhow!("candle thinker mutex poisoned"));
                        }
                    };
                    guard.think(&system_prompt, &user_prompt)
                })
                .await
                .context("candle thinker task join failed")?
            }
        }
    }

    /// Issue a single (non-retried) Converse call against Bedrock and map the
    /// JSON response into a [`ThinkerOutput`].
    async fn think_bedrock(
        &self,
        provider: &BedrockProvider,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let model_id = &self.config.model;

        // Build Bedrock Converse request body
        let mut body = serde_json::json!({
            "system": [{"text": system_prompt}],
            "messages": [{
                "role": "user",
                "content": [{"text": user_prompt}]
            }],
            "inferenceConfig": {
                "maxTokens": self.config.max_tokens,
                "temperature": self.config.temperature
            }
        });

        if let Some(service_tier) = self.config.bedrock_service_tier.as_ref() {
            body["additionalModelRequestFields"] = serde_json::json!({
                "service_tier": service_tier
            });
        }

        let body_bytes = serde_json::to_vec(&body)?;
        // Do NOT percent-encode `:` here — reqwest encodes URL paths natively,
        // and pre-encoding to `%3A` triggers double-encoding to `%253A`,
        // breaking SigV4.
        let url = format!(
            "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
            self.config.bedrock_region, model_id
        );

        let response = provider
            .send_converse_request(&url, &body_bytes)
            .await
            .context("Bedrock thinker converse request failed")?;

        // Read the body before the status check so error responses can be
        // included (truncated) in the error message.
        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Bedrock thinker response")?;

        if !status.is_success() {
            return Err(anyhow!(
                "Bedrock thinker error ({}): {}",
                status,
                crate::util::truncate_bytes_safe(&text, 500)
            ));
        }

        let parsed: serde_json::Value =
            serde_json::from_str(&text).context("Failed to parse Bedrock thinker response")?;

        // Only the first content block is used; defaults to "" when absent.
        let output_text = parsed["output"]["message"]["content"]
            .as_array()
            .and_then(|arr| arr.first())
            .and_then(|c| c["text"].as_str())
            .unwrap_or_default()
            .to_string();

        let usage = &parsed["usage"];
        let prompt_tokens = usage["inputTokens"].as_u64().map(|v| v as u32);
        let completion_tokens = usage["outputTokens"].as_u64().map(|v| v as u32);

        tracing::debug!(
            model = model_id,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = ?prompt_tokens,
            completion_tokens = ?completion_tokens,
            "bedrock thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: model_id.clone(),
            finish_reason: parsed["stopReason"].as_str().map(|s| s.to_string()),
            text: output_text,
            prompt_tokens,
            completion_tokens,
            // Only report a total when both components are present.
            total_tokens: prompt_tokens.zip(completion_tokens).map(|(p, c)| p + c),
            cache_read_tokens: None,
            cache_write_tokens: None,
        })
    }

    /// Call an OpenAI-compatible chat-completions endpoint, retrying once on
    /// transient failures, and map the response into a [`ThinkerOutput`].
    async fn think_openai_compat(
        &self,
        http: &Client,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let body = OpenAIChatRequest {
            model: self.config.model.clone(),
            messages: vec![
                OpenAIMessage {
                    role: "system".to_string(),
                    content: system_prompt.to_string(),
                },
                OpenAIMessage {
                    role: "user".to_string(),
                    content: user_prompt.to_string(),
                },
            ],
            temperature: self.config.temperature,
            top_p: self.config.top_p,
            max_tokens: self.config.max_tokens,
            stream: false,
        };

        // Retry once on transient failures (connection errors, 429, 502-504).
        let max_attempts: u32 = 2;
        let mut last_err: Option<anyhow::Error> = None;

        for attempt in 0..max_attempts {
            if attempt > 0 {
                // Linear backoff: 500ms before the second attempt.
                tokio::time::sleep(Duration::from_millis(500 * attempt as u64)).await;
                tracing::debug!(attempt, "retrying thinker HTTP request");
            }

            let mut request = http.post(&self.config.endpoint).json(&body);
            if let Some(key) = self.config.api_key.as_ref() {
                request = request.bearer_auth(key);
            }

            let response = match request.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    // Only transient transport errors are retried; anything
                    // else (e.g. a bad URL) fails immediately.
                    if is_transient_reqwest_error(&e) {
                        tracing::warn!(attempt, error = %e, "thinker HTTP request failed (transient)");
                        last_err =
                            Some(anyhow::Error::from(e).context("transient thinker send error"));
                        continue;
                    }
                    return Err(anyhow::Error::from(e).context("non-transient thinker send error"));
                }
            };

            let status = response.status();
            if is_transient_http_error(status.as_u16()) {
                let body_text = response.text().await.unwrap_or_default();
                tracing::warn!(attempt, status = %status, "thinker received transient HTTP error");
                last_err = Some(anyhow!(
                    "thinker request failed with status {}: {}",
                    status,
                    body_text
                ));
                continue;
            }

            // Non-transient HTTP failure: surface the body and stop retrying.
            if !status.is_success() {
                let body_text = response
                    .text()
                    .await
                    .unwrap_or_else(|_| "<empty>".to_string());
                return Err(anyhow!(
                    "thinker request failed with status {}: {}",
                    status,
                    body_text
                ));
            }

            let payload: OpenAIChatResponse = response
                .json()
                .await
                .context("failed to decode thinker response")?;
            let choice = payload
                .choices
                .first()
                .ok_or_else(|| anyhow!("thinker response did not include choices"))?;
            let text = choice.message.extract_text();
            let usage = payload.usage.unwrap_or_default();

            let output = ThinkerOutput {
                model: payload.model.unwrap_or_else(|| self.config.model.clone()),
                finish_reason: choice.finish_reason.clone(),
                text,
                prompt_tokens: usage.prompt_tokens,
                completion_tokens: usage.completion_tokens,
                total_tokens: usage.total_tokens,
                cache_read_tokens: None,
                cache_write_tokens: None,
            };

            tracing::debug!(
                model = %output.model,
                latency_ms = started_at.elapsed().as_millis(),
                prompt_tokens = ?output.prompt_tokens,
                completion_tokens = ?output.completion_tokens,
                attempt,
                "openai-compat thinker generated thought"
            );

            return Ok(output);
        }

        // All attempts exhausted; report the last transient error if we saw
        // one (the fallback message should be unreachable in practice).
        Err(last_err.unwrap_or_else(|| {
            anyhow!("thinker HTTP request failed after {max_attempts} attempts")
        }))
    }
}
521
/// Blocking Candle inference runtime: a quantized GGUF model plus the
/// sampling and KV-cache state needed to serve repeated "think" requests.
pub(crate) struct CandleThinker {
    /// Loaded model weights (one variant per supported architecture).
    model: CandleModel,
    /// Tokenizer used for both prompt encoding and output decoding.
    tokenizer: Tokenizer,
    /// Device tensors are created on (CPU or CUDA).
    device: Device,
    /// Human-readable label of the form "candle:<arch>:<device>@<path>".
    model_label: String,
    /// Lowercased architecture name from config or GGUF metadata.
    architecture: String,
    /// Maximum sequence length; long prompts are tail-truncated to fit.
    context_window: usize,
    /// Sampling temperature.
    temperature: f32,
    /// Optional nucleus-sampling cutoff.
    top_p: Option<f32>,
    /// Hard cap on generated tokens per request (clamped to >= 1).
    max_tokens: usize,
    /// Repeat penalty factor (clamped to >= 1.0; 1.0 disables it).
    repeat_penalty: f32,
    /// Number of trailing tokens the repeat penalty considers (>= 1).
    repeat_last_n: usize,
    /// Base RNG seed; combined with request_index for per-request seeding.
    seed: u64,
    /// Monotonic request counter used to vary the per-request seed.
    request_index: u64,
    /// Token ids that terminate generation with finish_reason "stop".
    eos_token_ids: HashSet<u32>,
    /// Token sequence currently materialized in the model's KV cache,
    /// used for prefix-reuse on the next request.
    cached_tokens: Vec<u32>,
}
539
/// Architecture-specific quantized model weights loaded from GGUF.
enum CandleModel {
    Llama(quantized_llama::ModelWeights),
    Qwen2(quantized_qwen2::ModelWeights),
    Qwen3(quantized_qwen3::ModelWeights),
    Qwen3Moe(quantized_qwen3_moe::GGUFQWenMoE),

    // Gemma support is feature-gated to keep the default build lean.
    #[cfg(feature = "functiongemma")]
    Gemma3(quantized_gemma3::ModelWeights),
}
549
impl CandleModel {
    /// Run a forward pass over `x` starting at KV-cache position `index_pos`,
    /// dispatching to the concrete architecture.
    fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        match self {
            Self::Llama(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen2(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen3(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen3Moe(model) => Ok(model.forward(x, index_pos)?),

            #[cfg(feature = "functiongemma")]
            Self::Gemma3(model) => Ok(model.forward(x, index_pos)?),
        }
    }

    /// Prepare the KV cache for a prompt unrelated to the cached one.
    ///
    /// # Errors
    ///
    /// Fails for `Qwen3Moe`, whose candle implementation exposes no reset.
    fn reset_kv_cache_for_new_request(&mut self) -> Result<()> {
        match self {
            // quantized_qwen3 uses ConcatKvCache and requires explicit reset for unrelated prompts.
            Self::Qwen3(model) => {
                model.clear_kv_cache();
                Ok(())
            }
            // quantized_qwen3_moe in candle-transformers 0.9.2 does not expose KV reset.
            Self::Qwen3Moe(_) => Err(anyhow!(
                "qwen3_moe runtime cannot reset KV cache in this build; restart local runtime or use qwen3"
            )),
            // NOTE(review): llama/qwen2 are treated as not needing an explicit
            // reset here — presumably their forward(index_pos=0) overwrites the
            // cache; confirm against the candle implementations.
            Self::Llama(_) | Self::Qwen2(_) => Ok(()),

            #[cfg(feature = "functiongemma")]
            Self::Gemma3(_) => Ok(()),
        }
    }

    /// Whether a prompt extending the cached token prefix may skip re-prefill.
    /// Currently always true for all architectures.
    fn can_extend_cached_prefix(&self) -> bool {
        true
    }
}
585
impl CandleThinker {
    /// Load a GGUF model and tokenizer from the paths in `config` and build a
    /// ready-to-infer runtime.
    ///
    /// The architecture is taken from `config.candle_arch` when set, otherwise
    /// from the GGUF `general.architecture` metadata, defaulting to "llama".
    ///
    /// # Errors
    ///
    /// Fails when model/tokenizer paths are missing, the files cannot be
    /// opened or parsed, or the architecture is unsupported in this build.
    pub(crate) fn new(config: &ThinkerConfig) -> Result<Self> {
        let model_path = config.candle_model_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_MODEL_PATH")
        })?;
        let tokenizer_path = config.candle_tokenizer_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_TOKENIZER_PATH")
        })?;

        let (device, device_label) = select_candle_device(config)?;
        let mut reader = BufReader::new(
            File::open(model_path)
                .with_context(|| format!("failed to open candle model file at {}", model_path))?,
        );
        let content = gguf_file::Content::read(&mut reader)
            .with_context(|| format!("failed to parse gguf model metadata from {}", model_path))?;

        // Explicit config override wins; otherwise read the GGUF metadata.
        let architecture = config
            .candle_arch
            .clone()
            .or_else(|| {
                content
                    .metadata
                    .get("general.architecture")
                    .and_then(|v| v.to_string().ok())
                    .cloned()
            })
            .unwrap_or_else(|| "llama".to_string())
            .to_ascii_lowercase();

        let context_window = detect_context_window(&content, &architecture).unwrap_or(4096);
        let model_label = format!("candle:{}:{}@{}", architecture, device_label, model_path);

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer from {}: {}", tokenizer_path, e))?;

        // Extract EOS metadata from content before it is moved into from_gguf.
        let gguf_eos_ids = extract_gguf_eos_ids(&content);

        let model = match architecture.as_str() {
            "llama" => CandleModel::Llama(
                quantized_llama::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load llama gguf from {}", model_path))?,
            ),
            "qwen2" => CandleModel::Qwen2(
                quantized_qwen2::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen2 gguf from {}", model_path))?,
            ),
            "qwen3" => CandleModel::Qwen3(
                quantized_qwen3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen3 gguf from {}", model_path))?,
            ),
            "qwen3moe" | "qwen3_moe" => CandleModel::Qwen3Moe(
                quantized_qwen3_moe::GGUFQWenMoE::from_gguf(
                    content,
                    &mut reader,
                    &device,
                    DType::F16,
                )
                .with_context(|| format!("failed to load qwen3_moe gguf from {}", model_path))?,
            ),

            #[cfg(feature = "functiongemma")]
            "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => CandleModel::Gemma3(
                quantized_gemma3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load gemma3 gguf from {}", model_path))?,
            ),
            other => {
                // Distinguish "gemma without the feature" from a genuinely
                // unknown architecture, so the user gets an actionable hint.
                #[cfg(not(feature = "functiongemma"))]
                if matches!(other, "gemma" | "gemma2" | "gemma3" | "gemma-embedding") {
                    return Err(anyhow!(
                        "gemma architecture '{}' requires the 'functiongemma' feature; rebuild with --features functiongemma",
                        other
                    ));
                }
                return Err(anyhow!(
                    "unsupported candle architecture '{}' (supported: llama, qwen2, qwen3, qwen3_moe{})",
                    other,
                    if cfg!(feature = "functiongemma") {
                        ", gemma/gemma2/gemma3"
                    } else {
                        ""
                    }
                ));
            }
        };

        let eos_token_ids: HashSet<u32> = collect_eos_token_ids(&tokenizer, &gguf_eos_ids);
        if eos_token_ids.is_empty() {
            tracing::warn!(
                "No EOS tokens found in tokenizer; generation will stop on max token limit"
            );
        }

        Ok(Self {
            model,
            tokenizer,
            device,
            model_label,
            architecture,
            context_window,
            temperature: config.temperature,
            top_p: config.top_p,
            // Clamp configured values into their valid ranges.
            max_tokens: config.max_tokens.max(1),
            repeat_penalty: config.candle_repeat_penalty.max(1.0),
            repeat_last_n: config.candle_repeat_last_n.max(1),
            seed: config.candle_seed,
            request_index: 0,
            eos_token_ids,
            cached_tokens: Vec::new(),
        })
    }

    /// Run inference using an already-formatted prompt string (no chat template wrapping).
    #[cfg(feature = "functiongemma")]
    pub(crate) fn think_raw(&mut self, raw_prompt: &str) -> Result<ThinkerOutput> {
        self.think_inner(raw_prompt)
    }

    /// Wrap the system/user prompts in the architecture's chat template and
    /// run generation.
    pub(crate) fn think(
        &mut self,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let prompt = format_chat_prompt(&self.architecture, system_prompt, user_prompt);
        self.think_inner(&prompt)
    }

    /// Core generation loop: encode, prefill (reusing the KV-cache prefix
    /// where possible), then sample one token at a time until EOS, the token
    /// budget, or the context window is hit.
    fn think_inner(&mut self, prompt: &str) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let encoding = self
            .tokenizer
            .encode(prompt, true)
            .map_err(|e| anyhow!("tokenizer encode failed: {}", e))?;
        let mut tokens = encoding.get_ids().to_vec();
        if tokens.is_empty() {
            return Err(anyhow!("tokenizer produced an empty prompt token set"));
        }

        // Truncate to context window if needed — keep the tail of the prompt.
        if self.context_window > 8 && tokens.len() >= self.context_window {
            let budget = self.context_window.saturating_sub(8);
            tokens = tokens[tokens.len().saturating_sub(budget)..].to_vec();
        }
        let prompt_token_count = tokens.len() as u32;

        // Offset the base seed by a running counter so repeated identical
        // prompts can sample differently while staying deterministic overall.
        let request_seed = self.seed.wrapping_add(self.request_index);
        self.request_index = self.request_index.wrapping_add(1);
        let mut logits_processor = LogitsProcessor::new(
            request_seed,
            Some(self.temperature as f64),
            self.top_p.map(|v| v as f64),
        );

        let mut cache_read_tokens = 0u32;
        let mut cache_write_tokens = 0u32;

        // KV-cache prefix reuse: if the new prompt strictly extends the token
        // sequence already in the cache, only the new suffix needs prefill.
        let can_extend_prefix = self.model.can_extend_cached_prefix()
            && !self.cached_tokens.is_empty()
            && tokens.len() > self.cached_tokens.len()
            && tokens.starts_with(&self.cached_tokens);

        let mut index_pos = if can_extend_prefix {
            cache_read_tokens = self.cached_tokens.len() as u32;
            self.cached_tokens.len()
        } else {
            // Unrelated prompt: drop any stale cache before prefilling.
            if !self.cached_tokens.is_empty() {
                self.model.reset_kv_cache_for_new_request()?;
            }
            0
        };

        let prefill = if index_pos == 0 {
            tokens.as_slice()
        } else {
            &tokens[index_pos..]
        };
        if prefill.is_empty() {
            // Exact-token prompt replay has no extra tokens to prefill, so do a fresh prefill.
            self.model.reset_kv_cache_for_new_request()?;
            index_pos = 0;
        }

        // Recompute after the possible reset above (shadowing the first slice).
        let prefill = if index_pos == 0 {
            tokens.as_slice()
        } else {
            &tokens[index_pos..]
        };
        cache_write_tokens = cache_write_tokens.saturating_add(prefill.len() as u32);

        // Prefill pass: feed all pending prompt tokens in one batch.
        let input = Tensor::new(prefill, &self.device)?
            .unsqueeze(0)
            .context("failed to create candle input tensor")?;
        let mut logits = self
            .model
            .forward(&input, index_pos)
            .context("candle model forward failed")?;
        index_pos += prefill.len();
        logits = logits
            .squeeze(0)
            .context("failed to squeeze logits batch dimension")?;

        let mut generated: Vec<u32> = Vec::with_capacity(self.max_tokens);
        let mut finish_reason = "length".to_string();

        // Decode loop: sample one token per iteration.
        for _ in 0..self.max_tokens {
            // Penalize recently-seen tokens before sampling, when enabled.
            let sampling_logits = if self.repeat_penalty > 1.0 {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                apply_repeat_penalty(&logits, self.repeat_penalty, &tokens[start_at..])
                    .context("failed to apply repeat penalty")?
            } else {
                logits.clone()
            };

            let next_token =
                sample_next_token_with_fallback(&mut logits_processor, &sampling_logits)?;
            // EOS is not appended to tokens/generated — it ends the request.
            if self.eos_token_ids.contains(&next_token) {
                finish_reason = "stop".to_string();
                break;
            }

            tokens.push(next_token);
            generated.push(next_token);
            cache_write_tokens = cache_write_tokens.saturating_add(1);

            // Stop before overflowing the context window.
            if tokens.len() + 1 >= self.context_window {
                finish_reason = "length".to_string();
                break;
            }

            // Feed only the newly sampled token; the KV cache covers the rest.
            let input = Tensor::new(&tokens[tokens.len() - 1..], &self.device)?
                .unsqueeze(0)
                .context("failed to create candle input tensor")?;
            logits = self
                .model
                .forward(&input, index_pos)
                .context("candle model forward failed")?;
            index_pos += 1;
            logits = logits
                .squeeze(0)
                .context("failed to squeeze logits batch dimension")?;
        }

        let text = self
            .tokenizer
            .decode(&generated, true)
            .map_err(|e| anyhow!("tokenizer decode failed: {}", e))?;
        let completion_tokens = generated.len() as u32;
        // Remember what is now in the KV cache for next request's prefix reuse.
        self.cached_tokens = tokens;

        tracing::debug!(
            model = %self.model_label,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = prompt_token_count,
            completion_tokens = completion_tokens,
            cache_read_tokens = cache_read_tokens,
            cache_write_tokens = cache_write_tokens,
            "candle thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: self.model_label.clone(),
            finish_reason: Some(finish_reason),
            text,
            prompt_tokens: Some(prompt_token_count),
            completion_tokens: Some(completion_tokens),
            total_tokens: Some(prompt_token_count + completion_tokens),
            cache_read_tokens: Some(cache_read_tokens),
            cache_write_tokens: Some(cache_write_tokens),
        })
    }
}
858
/// Build a chat prompt using the proper template for each model architecture.
///
/// Each arm renders the template its model family was trained on; unknown
/// architectures get a plain-text fallback.
fn format_chat_prompt(architecture: &str, system_prompt: &str, user_prompt: &str) -> String {
    match architecture {
        // ChatML template (Qwen2, Yi, etc.)
        "qwen2" | "qwen3" | "qwen3moe" | "qwen3_moe" => format!(
            "<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
        ),
        // Llama 3 instruct template
        "llama" => format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        ),
        // Gemma instruct template (no system role; the system text is folded
        // into the user turn)
        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => format!(
            "<start_of_turn>user\n{system_prompt}\n\n{user_prompt}<end_of_turn>\n<start_of_turn>model\n"
        ),
        // Fallback for unknown architectures
        _ => format!("System:\n{system_prompt}\n\nUser:\n{user_prompt}\n\nAssistant:\n"),
    }
}
888
889fn select_candle_device(config: &ThinkerConfig) -> Result<(Device, String)> {
890    match config.candle_device {
891        CandleDevicePreference::Cpu => Ok((Device::Cpu, "cpu".to_string())),
892        CandleDevicePreference::Cuda => {
893            let device = try_cuda_device(config.candle_cuda_ordinal)?;
894            Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
895        }
896        CandleDevicePreference::Auto => match try_cuda_device(config.candle_cuda_ordinal) {
897            Ok(device) => {
898                tracing::info!(
899                    ordinal = config.candle_cuda_ordinal,
900                    "Candle thinker selected CUDA device"
901                );
902                Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
903            }
904            Err(error) => {
905                tracing::warn!(
906                    %error,
907                    "CUDA unavailable for Candle thinker, falling back to CPU"
908                );
909                Ok((Device::Cpu, "cpu".to_string()))
910            }
911        },
912    }
913}
914
/// Initialize the CUDA device at `ordinal` (only compiled with `candle-cuda`).
#[cfg(feature = "candle-cuda")]
fn try_cuda_device(ordinal: usize) -> Result<Device> {
    Device::new_cuda(ordinal)
        .with_context(|| format!("failed to initialize CUDA device ordinal {}", ordinal))
}
920
/// Stub used when the `candle-cuda` feature is off: always errors, so callers
/// either fail fast (explicit CUDA preference) or fall back to CPU (`Auto`).
#[cfg(not(feature = "candle-cuda"))]
fn try_cuda_device(_ordinal: usize) -> Result<Device> {
    Err(anyhow!(
        "candle-cuda feature is not enabled in this build; rebuild with --features candle-cuda"
    ))
}
927
928fn detect_context_window(content: &gguf_file::Content, architecture: &str) -> Option<usize> {
929    let key = match architecture {
930        "qwen2" => "qwen2.context_length",
931        "qwen3" | "qwen3moe" | "qwen3_moe" => "qwen3.context_length",
932        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => {
933            // Try gemma3 first, then fall back to gemma2, gemma
934            for prefix in ["gemma3", "gemma2", "gemma"] {
935                let k = format!("{prefix}.context_length");
936                if let Some(v) = content.metadata.get(&k) {
937                    return v.to_u32().ok().map(|v| v as usize);
938                }
939            }
940            return None;
941        }
942        _ => "llama.context_length",
943    };
944    content
945        .metadata
946        .get(key)
947        .and_then(|v| v.to_u32().ok())
948        .map(|v| v as usize)
949}
950
951/// Extract EOS token IDs from GGUF metadata before the content is consumed.
952fn extract_gguf_eos_ids(content: &gguf_file::Content) -> Vec<u32> {
953    let mut ids = Vec::new();
954    for key in ["tokenizer.ggml.eos_token_id", "tokenizer.ggml.eot_token_id"] {
955        if let Some(v) = content.metadata.get(key)
956            && let Ok(id) = v.to_u32()
957            && !ids.contains(&id)
958        {
959            ids.push(id);
960        }
961    }
962    ids
963}
964
965fn collect_eos_token_ids(tokenizer: &Tokenizer, gguf_eos_ids: &[u32]) -> HashSet<u32> {
966    let mut ids: HashSet<u32> = gguf_eos_ids.iter().copied().collect();
967
968    // Also check well-known special token strings as fallback.
969    let candidates = [
970        "<|im_end|>",
971        "<|eot_id|>",
972        "<|endoftext|>",
973        "</s>",
974        "<|end|>",
975        "<end_of_turn>",
976    ];
977    for token in candidates {
978        if let Some(id) = tokenizer.token_to_id(token) {
979            ids.insert(id);
980        }
981    }
982    ids
983}
984
985fn sample_next_token_with_fallback(
986    logits_processor: &mut LogitsProcessor,
987    logits: &Tensor,
988) -> Result<u32> {
989    match logits_processor.sample(logits) {
990        Ok(token) => Ok(token),
991        Err(sample_error) => {
992            let logits_vec = logits
993                .to_vec1::<f32>()
994                .context("token sampling failed and fallback logits extraction failed")?;
995            let mut best_token = None;
996            let mut best_logit = f32::NEG_INFINITY;
997
998            for (idx, logit) in logits_vec.into_iter().enumerate() {
999                if !logit.is_finite() {
1000                    continue;
1001                }
1002                if best_token.is_none() || logit > best_logit {
1003                    best_token = Some(idx as u32);
1004                    best_logit = logit;
1005                }
1006            }
1007
1008            if let Some(token) = best_token {
1009                tracing::warn!(
1010                    error = %sample_error,
1011                    token,
1012                    "Token sampling produced invalid weights; using greedy argmax fallback"
1013                );
1014                Ok(token)
1015            } else {
1016                Err(anyhow!(
1017                    "token sampling failed and fallback argmax found no finite logits: {}",
1018                    sample_error
1019                ))
1020            }
1021        }
1022    }
1023}
1024
/// Returns true for HTTP status codes that are worth retrying.
///
/// Covers rate limiting (429) and upstream gateway failures (502–504).
fn is_transient_http_error(status: u16) -> bool {
    status == 429 || (502..=504).contains(&status)
}
1029
/// Returns true for reqwest errors that are worth retrying (timeouts, connection resets).
fn is_transient_reqwest_error(e: &reqwest::Error) -> bool {
    // NOTE(review): `is_request()` is broader than connection resets — it also
    // matches some non-transient request-construction failures. Presumably it
    // is included to catch mid-flight request errors; confirm intent.
    e.is_timeout() || e.is_connect() || e.is_request()
}
1034
/// Request payload for an OpenAI-compatible `/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    temperature: f32,
    // Omitted from the JSON body entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    max_tokens: usize,
    stream: bool,
}
1045
/// A single chat message (`role` + `content`) in the request body.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}
1051
/// Top-level response body from an OpenAI-compatible chat endpoint.
#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    model: Option<String>,
    choices: Vec<OpenAIChatChoice>,
    // Defaults to `None` when the server omits usage accounting.
    #[serde(default)]
    usage: Option<OpenAIUsage>,
}
1059
/// A single completion choice within the response.
#[derive(Debug, Deserialize)]
struct OpenAIChatChoice {
    message: OpenAIChatChoiceMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}
1066
/// Assistant message payload. `reasoning` / `reasoning_content` are fields
/// some servers emit for reasoning models; `extract_text` uses them as
/// fallbacks when `content` is empty.
#[derive(Debug, Deserialize)]
struct OpenAIChatChoiceMessage {
    #[serde(default)]
    content: Option<OpenAIChatContent>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
}
1076
/// Token usage accounting; all fields optional because servers vary in what
/// they report.
#[derive(Debug, Default, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    total_tokens: Option<u32>,
}
1083
/// Message `content` in any of the shapes servers return: a plain string,
/// an array of typed parts, or a single part object. `untagged` lets serde
/// try each shape in order.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAIChatContent {
    Text(String),
    Parts(Vec<OpenAIChatContentPart>),
    Part(OpenAIChatContentPart),
}
1091
/// One element of a multi-part `content` value.
#[derive(Debug, Deserialize)]
struct OpenAIChatContentPart {
    // The part kind (e.g. "text"); an absent kind is treated as text
    // by `text_fragment`.
    #[serde(rename = "type")]
    kind: Option<String>,
    #[serde(default)]
    text: Option<String>,
    #[serde(default)]
    content: Option<String>,
}
1101
1102impl OpenAIChatChoiceMessage {
1103    fn extract_text(&self) -> String {
1104        let content_text = self
1105            .content
1106            .as_ref()
1107            .map(OpenAIChatContent::to_text)
1108            .unwrap_or_default();
1109        if !content_text.trim().is_empty() {
1110            return content_text;
1111        }
1112
1113        if let Some(reasoning) = self
1114            .reasoning
1115            .as_deref()
1116            .filter(|text| !text.trim().is_empty())
1117        {
1118            return reasoning.to_string();
1119        }
1120
1121        self.reasoning_content
1122            .as_deref()
1123            .filter(|text| !text.trim().is_empty())
1124            .unwrap_or_default()
1125            .to_string()
1126    }
1127}
1128
1129impl OpenAIChatContent {
1130    fn to_text(&self) -> String {
1131        match self {
1132            Self::Text(text) => text.clone(),
1133            Self::Parts(parts) => parts
1134                .iter()
1135                .filter_map(OpenAIChatContentPart::text_fragment)
1136                .collect::<Vec<_>>()
1137                .join("\n"),
1138            Self::Part(part) => part.text_fragment().unwrap_or_default(),
1139        }
1140    }
1141}
1142
1143impl OpenAIChatContentPart {
1144    fn text_fragment(&self) -> Option<String> {
1145        if let Some(kind) = self.kind.as_deref()
1146            && !kind.eq_ignore_ascii_case("text")
1147            && !kind.eq_ignore_ascii_case("output_text")
1148        {
1149            return None;
1150        }
1151
1152        self.text
1153            .as_deref()
1154            .or(self.content.as_deref())
1155            .map(str::trim)
1156            .filter(|text| !text.is_empty())
1157            .map(ToString::to_string)
1158    }
1159}