codetether_agent/cognition/thinker.rs
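//! Thinker client for the cognition layer: generates short "thoughts" from a
//! system prompt and a user prompt.
//!
//! Three backends are supported: an OpenAI-compatible HTTP endpoint, local
//! quantized GGUF inference via Candle (llama/qwen2, plus gemma behind the
//! `functiongemma` feature), and the AWS Bedrock Converse API.
//!
//! A minimal usage sketch, assuming this module is reachable at
//! `codetether_agent::cognition::thinker` and an OpenAI-compatible endpoint is
//! running at the default address (both are assumptions, not guaranteed here):
//!
//! ```ignore
//! use codetether_agent::cognition::thinker::{ThinkerClient, ThinkerConfig};
//!
//! async fn demo() -> anyhow::Result<()> {
//!     // The default config targets an OpenAI-compatible server on 127.0.0.1:11434.
//!     let config = ThinkerConfig {
//!         enabled: true,
//!         ..ThinkerConfig::default()
//!     };
//!     let client = ThinkerClient::new(config)?;
//!     let thought = client
//!         .think("You are a concise planner.", "Summarize the next step.")
//!         .await?;
//!     println!("{}", thought.text);
//!     Ok(())
//! }
//! ```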

use anyhow::{Context, Result, anyhow};
use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
#[cfg(feature = "functiongemma")]
use candle_transformers::models::quantized_gemma3;
use candle_transformers::models::{quantized_llama, quantized_qwen2};
use candle_transformers::utils::apply_repeat_penalty;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fs::File;
use std::io::BufReader;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokenizers::Tokenizer;

use crate::provider::bedrock::{AwsCredentials, BedrockProvider};

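/// Backend used by the thinker to generate thoughts.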
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkerBackend {
    OpenAICompat,
    Candle,
    Bedrock,
}

impl ThinkerBackend {
    pub fn from_env(value: &str) -> Self {
        match value.trim().to_ascii_lowercase().as_str() {
            "candle" => Self::Candle,
            "openai" | "openai_compat" | "openai-compatible" | "http" => Self::OpenAICompat,
            "bedrock" | "aws" | "aws_bedrock" => Self::Bedrock,
            _ => Self::OpenAICompat,
        }
    }
}

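/// Preferred compute device for the Candle backend.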
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandleDevicePreference {
    Auto,
    Cpu,
    Cuda,
}

impl CandleDevicePreference {
    pub fn from_env(value: &str) -> Self {
        match value.trim().to_ascii_lowercase().as_str() {
            "cpu" => Self::Cpu,
            "cuda" | "gpu" => Self::Cuda,
            _ => Self::Auto,
        }
    }
}

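/// Runtime configuration for the thinker, covering all backends.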
#[derive(Debug, Clone)]
pub struct ThinkerConfig {
    pub enabled: bool,
    pub backend: ThinkerBackend,
    pub endpoint: String,
    pub model: String,
    pub api_key: Option<String>,
    pub temperature: f32,
    pub top_p: Option<f32>,
    pub max_tokens: usize,
    pub timeout_ms: u64,
    pub candle_model_path: Option<String>,
    pub candle_tokenizer_path: Option<String>,
    pub candle_arch: Option<String>,
    pub candle_device: CandleDevicePreference,
    pub candle_cuda_ordinal: usize,
    pub candle_repeat_penalty: f32,
    pub candle_repeat_last_n: usize,
    pub candle_seed: u64,
    pub bedrock_region: String,
}

impl Default for ThinkerConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            backend: ThinkerBackend::OpenAICompat,
            endpoint: "http://127.0.0.1:11434/v1/chat/completions".to_string(),
            model: "qwen2.5:3b-instruct".to_string(),
            api_key: None,
            temperature: 0.2,
            top_p: None,
            max_tokens: 256,
            timeout_ms: 30_000,
            candle_model_path: None,
            candle_tokenizer_path: None,
            candle_arch: None,
            candle_device: CandleDevicePreference::Auto,
            candle_cuda_ordinal: 0,
            candle_repeat_penalty: 1.1,
            candle_repeat_last_n: 64,
            candle_seed: 42,
            bedrock_region: "us-west-2".to_string(),
        }
    }
}

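/// A single generated thought plus model identity and token-usage metadata.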
#[derive(Debug, Clone)]
pub struct ThinkerOutput {
    pub model: String,
    pub finish_reason: Option<String>,
    pub text: String,
    pub prompt_tokens: Option<u32>,
    pub completion_tokens: Option<u32>,
    pub total_tokens: Option<u32>,
}

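/// Client that dispatches thought requests to the configured backend.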
#[derive(Clone)]
pub struct ThinkerClient {
    config: ThinkerConfig,
    backend: ThinkerClientBackend,
}

impl std::fmt::Debug for ThinkerClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ThinkerClient")
            .field("backend", &self.config.backend)
            .field("model", &self.config.model)
            .finish()
    }
}

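/// Backend-specific state held by [`ThinkerClient`].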
#[derive(Clone)]
enum ThinkerClientBackend {
    OpenAICompat { http: Client },
    Candle { runtime: Arc<Mutex<CandleThinker>> },
    Bedrock { provider: Arc<BedrockProvider> },
}

impl ThinkerClient {
    pub fn new(config: ThinkerConfig) -> Result<Self> {
        let backend = match config.backend {
            ThinkerBackend::OpenAICompat => {
                let timeout = Duration::from_millis(config.timeout_ms.max(1_000));
                let http = Client::builder()
                    .timeout(timeout)
                    .build()
                    .context("failed to build thinker HTTP client")?;
                ThinkerClientBackend::OpenAICompat { http }
            }
            ThinkerBackend::Candle => {
                let runtime = CandleThinker::new(&config)?;
                ThinkerClientBackend::Candle {
                    runtime: Arc::new(Mutex::new(runtime)),
                }
            }
            ThinkerBackend::Bedrock => {
                let creds = AwsCredentials::from_environment()
                    .ok_or_else(|| anyhow!("Bedrock thinker requires AWS credentials (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY or ~/.aws/credentials)"))?;
                let provider =
                    BedrockProvider::with_credentials(creds, config.bedrock_region.clone())?;
                ThinkerClientBackend::Bedrock {
                    provider: Arc::new(provider),
                }
            }
        };

        Ok(Self { config, backend })
    }

    pub fn config(&self) -> &ThinkerConfig {
        &self.config
    }

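    /// Generate a single thought from the given system and user prompts,
    /// dispatching to the backend selected in [`ThinkerConfig`].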
    pub async fn think(&self, system_prompt: &str, user_prompt: &str) -> Result<ThinkerOutput> {
        match &self.backend {
            ThinkerClientBackend::OpenAICompat { http } => {
                self.think_openai_compat(http, system_prompt, user_prompt)
                    .await
            }
            ThinkerClientBackend::Bedrock { provider } => {
                self.think_bedrock(provider, system_prompt, user_prompt)
                    .await
            }
            ThinkerClientBackend::Candle { runtime } => {
                let runtime = Arc::clone(runtime);
                let system_prompt = system_prompt.to_string();
                let user_prompt = user_prompt.to_string();
                tokio::task::spawn_blocking(move || {
                    let mut guard = match runtime.try_lock() {
                        Ok(g) => g,
                        Err(std::sync::TryLockError::WouldBlock) => {
                            return Err(anyhow!("candle thinker is busy"));
                        }
                        Err(std::sync::TryLockError::Poisoned(_)) => {
                            return Err(anyhow!("candle thinker mutex poisoned"));
                        }
                    };
                    guard.think(&system_prompt, &user_prompt)
                })
                .await
                .context("candle thinker task join failed")?
            }
        }
    }

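    /// Call the AWS Bedrock Converse API and map the response into a
    /// [`ThinkerOutput`].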
    async fn think_bedrock(
        &self,
        provider: &BedrockProvider,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let model_id = &self.config.model;

        // Build Bedrock Converse request body
        let body = serde_json::json!({
            "system": [{"text": system_prompt}],
            "messages": [{
                "role": "user",
                "content": [{"text": user_prompt}]
            }],
            "inferenceConfig": {
                "maxTokens": self.config.max_tokens,
                "temperature": self.config.temperature
            }
        });

        let body_bytes = serde_json::to_vec(&body)?;
        let encoded_model_id = model_id.replace(':', "%3A");
        let url = format!(
            "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
            self.config.bedrock_region, encoded_model_id
        );

        let response = provider
            .send_converse_request(&url, &body_bytes)
            .await
            .context("Bedrock thinker converse request failed")?;

        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Bedrock thinker response")?;

        if !status.is_success() {
            // Truncate on a character boundary so a long error body with
            // multi-byte UTF-8 near the cutoff cannot cause a slice panic.
            let preview: String = text.chars().take(500).collect();
            return Err(anyhow!("Bedrock thinker error ({}): {}", status, preview));
        }

        let parsed: serde_json::Value =
            serde_json::from_str(&text).context("Failed to parse Bedrock thinker response")?;

        let output_text = parsed["output"]["message"]["content"]
            .as_array()
            .and_then(|arr| arr.first())
            .and_then(|c| c["text"].as_str())
            .unwrap_or_default()
            .to_string();

        let usage = &parsed["usage"];
        let prompt_tokens = usage["inputTokens"].as_u64().map(|v| v as u32);
        let completion_tokens = usage["outputTokens"].as_u64().map(|v| v as u32);

        tracing::debug!(
            model = model_id,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = ?prompt_tokens,
            completion_tokens = ?completion_tokens,
            "bedrock thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: model_id.clone(),
            finish_reason: parsed["stopReason"].as_str().map(|s| s.to_string()),
            text: output_text,
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.zip(completion_tokens).map(|(p, c)| p + c),
        })
    }

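    /// Send the prompts to an OpenAI-compatible `/chat/completions` endpoint,
    /// retrying once on transient failures.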
    async fn think_openai_compat(
        &self,
        http: &Client,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let body = OpenAIChatRequest {
            model: self.config.model.clone(),
            messages: vec![
                OpenAIMessage {
                    role: "system".to_string(),
                    content: system_prompt.to_string(),
                },
                OpenAIMessage {
                    role: "user".to_string(),
                    content: user_prompt.to_string(),
                },
            ],
            temperature: self.config.temperature,
            top_p: self.config.top_p,
            max_tokens: self.config.max_tokens,
            stream: false,
        };

        // Retry once on transient failures (connection errors, 429, 502-504).
        let max_attempts: u32 = 2;
        let mut last_err: Option<anyhow::Error> = None;

        for attempt in 0..max_attempts {
            if attempt > 0 {
                tokio::time::sleep(Duration::from_millis(500 * attempt as u64)).await;
                tracing::debug!(attempt, "retrying thinker HTTP request");
            }

            let mut request = http.post(&self.config.endpoint).json(&body);
            if let Some(key) = self.config.api_key.as_ref() {
                request = request.bearer_auth(key);
            }

            let response = match request.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    if is_transient_reqwest_error(&e) {
                        tracing::warn!(attempt, error = %e, "thinker HTTP request failed (transient)");
                        last_err =
                            Some(anyhow::Error::from(e).context("transient thinker send error"));
                        continue;
                    }
                    return Err(anyhow::Error::from(e).context("non-transient thinker send error"));
                }
            };

            let status = response.status();
            if is_transient_http_error(status.as_u16()) {
                let body_text = response.text().await.unwrap_or_default();
                tracing::warn!(attempt, status = %status, "thinker received transient HTTP error");
                last_err = Some(anyhow!(
                    "thinker request failed with status {}: {}",
                    status,
                    body_text
                ));
                continue;
            }

            if !status.is_success() {
                let body_text = response
                    .text()
                    .await
                    .unwrap_or_else(|_| "<empty>".to_string());
                return Err(anyhow!(
                    "thinker request failed with status {}: {}",
                    status,
                    body_text
                ));
            }

            let payload: OpenAIChatResponse = response
                .json()
                .await
                .context("failed to decode thinker response")?;
            let choice = payload
                .choices
                .first()
                .ok_or_else(|| anyhow!("thinker response did not include choices"))?;
            let text = choice.message.extract_text();
            let usage = payload.usage.unwrap_or_default();

            let output = ThinkerOutput {
                model: payload.model.unwrap_or_else(|| self.config.model.clone()),
                finish_reason: choice.finish_reason.clone(),
                text,
                prompt_tokens: usage.prompt_tokens,
                completion_tokens: usage.completion_tokens,
                total_tokens: usage.total_tokens,
            };

            tracing::debug!(
                model = %output.model,
                latency_ms = started_at.elapsed().as_millis(),
                prompt_tokens = ?output.prompt_tokens,
                completion_tokens = ?output.completion_tokens,
                attempt,
                "openai-compat thinker generated thought"
            );

            return Ok(output);
        }

        Err(last_err.unwrap_or_else(|| {
            anyhow!("thinker HTTP request failed after {max_attempts} attempts")
        }))
    }
}

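/// Local GGUF inference state for the Candle backend.
///
/// The model and tokenizer are loaded from disk once and reused across
/// requests; `think` is synchronous and is driven from a blocking task.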
pub(crate) struct CandleThinker {
    model: CandleModel,
    tokenizer: Tokenizer,
    device: Device,
    model_label: String,
    architecture: String,
    context_window: usize,
    temperature: f32,
    top_p: Option<f32>,
    max_tokens: usize,
    repeat_penalty: f32,
    repeat_last_n: usize,
    seed: u64,
    request_index: u64,
    eos_token_ids: HashSet<u32>,
}

enum CandleModel {
    Llama(quantized_llama::ModelWeights),
    Qwen2(quantized_qwen2::ModelWeights),
    #[cfg(feature = "functiongemma")]
    Gemma3(quantized_gemma3::ModelWeights),
}

impl CandleModel {
    fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        match self {
            Self::Llama(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen2(model) => Ok(model.forward(x, index_pos)?),
            #[cfg(feature = "functiongemma")]
            Self::Gemma3(model) => Ok(model.forward(x, index_pos)?),
        }
    }
}

impl CandleThinker {
    pub(crate) fn new(config: &ThinkerConfig) -> Result<Self> {
        let model_path = config.candle_model_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_MODEL_PATH")
        })?;
        let tokenizer_path = config.candle_tokenizer_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_TOKENIZER_PATH")
        })?;

        let (device, device_label) = select_candle_device(config)?;
        let mut reader = BufReader::new(
            File::open(model_path)
                .with_context(|| format!("failed to open candle model file at {}", model_path))?,
        );
        let content = gguf_file::Content::read(&mut reader)
            .with_context(|| format!("failed to parse gguf model metadata from {}", model_path))?;

        let architecture = config
            .candle_arch
            .clone()
            .or_else(|| {
                content
                    .metadata
                    .get("general.architecture")
                    .and_then(|v| v.to_string().ok())
                    .cloned()
            })
            .unwrap_or_else(|| "llama".to_string())
            .to_ascii_lowercase();

        let context_window = detect_context_window(&content, &architecture).unwrap_or(4096);
        let model_label = format!("candle:{}:{}@{}", architecture, device_label, model_path);

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer from {}: {}", tokenizer_path, e))?;

        // Extract EOS metadata from content before it is moved into from_gguf.
        let gguf_eos_ids = extract_gguf_eos_ids(&content);

        let model = match architecture.as_str() {
            "llama" => CandleModel::Llama(
                quantized_llama::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load llama gguf from {}", model_path))?,
            ),
            "qwen2" => CandleModel::Qwen2(
                quantized_qwen2::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen2 gguf from {}", model_path))?,
            ),
            #[cfg(feature = "functiongemma")]
            "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => CandleModel::Gemma3(
                quantized_gemma3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load gemma3 gguf from {}", model_path))?,
            ),
            other => {
                #[cfg(not(feature = "functiongemma"))]
                if matches!(other, "gemma" | "gemma2" | "gemma3" | "gemma-embedding") {
                    return Err(anyhow!(
                        "gemma architecture '{}' requires the 'functiongemma' feature; rebuild with --features functiongemma",
                        other
                    ));
                }
                return Err(anyhow!(
                    "unsupported candle architecture '{}' (supported: llama, qwen2{})",
                    other,
                    if cfg!(feature = "functiongemma") {
                        ", gemma/gemma2/gemma3"
                    } else {
                        ""
                    }
                ));
            }
        };

        let eos_token_ids: HashSet<u32> = collect_eos_token_ids(&tokenizer, &gguf_eos_ids);
        if eos_token_ids.is_empty() {
            tracing::warn!(
                "No EOS tokens found in tokenizer; generation will stop on max token limit"
            );
        }

        Ok(Self {
            model,
            tokenizer,
            device,
            model_label,
            architecture,
            context_window,
            temperature: config.temperature,
            top_p: config.top_p,
            max_tokens: config.max_tokens.max(1),
            repeat_penalty: config.candle_repeat_penalty.max(1.0),
            repeat_last_n: config.candle_repeat_last_n.max(1),
            seed: config.candle_seed,
            request_index: 0,
            eos_token_ids,
        })
    }

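    /// Format the prompts with the architecture's chat template, then sample
    /// tokens one at a time until an EOS token or a length limit is reached.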
    pub(crate) fn think(
        &mut self,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let prompt = format_chat_prompt(&self.architecture, system_prompt, user_prompt);
        let encoding = self
            .tokenizer
            .encode(prompt.as_str(), true)
            .map_err(|e| anyhow!("tokenizer encode failed: {}", e))?;
        let mut tokens = encoding.get_ids().to_vec();
        if tokens.is_empty() {
            return Err(anyhow!("tokenizer produced an empty prompt token set"));
        }

        // Truncate user content while preserving the system prompt prefix.
        if self.context_window > 8 && tokens.len() >= self.context_window {
            let system_only = format_chat_prompt(&self.architecture, system_prompt, "");
            let sys_encoding = self
                .tokenizer
                .encode(system_only.as_str(), true)
                .map_err(|e| anyhow!("tokenizer encode failed (system): {}", e))?;
            let sys_len = sys_encoding.get_ids().len();
            let budget = self.context_window.saturating_sub(8);
            if sys_len < budget {
                // Keep system prefix + tail of user content that fits
                let tail_budget = budget.saturating_sub(sys_len);
                let tail_start = tokens.len().saturating_sub(tail_budget);
                let mut truncated = sys_encoding.get_ids().to_vec();
                truncated.extend_from_slice(&tokens[tail_start..]);
                tokens = truncated;
            } else {
                // System alone exceeds budget; keep only the tail
                let keep = budget;
                tokens = tokens[tokens.len().saturating_sub(keep)..].to_vec();
            }
        }
        let prompt_token_count = tokens.len() as u32;

        let request_seed = self.seed.wrapping_add(self.request_index);
        self.request_index = self.request_index.wrapping_add(1);
        let mut logits_processor = LogitsProcessor::new(
            request_seed,
            Some(self.temperature as f64),
            self.top_p.map(|v| v as f64),
        );

        let mut index_pos = 0usize;
        let mut generated: Vec<u32> = Vec::with_capacity(self.max_tokens);
        let mut finish_reason = "length".to_string();

        for _ in 0..self.max_tokens {
            let ctxt: &[u32] = if index_pos == 0 {
                tokens.as_slice()
            } else {
                &tokens[tokens.len() - 1..]
            };

            let input = Tensor::new(ctxt, &self.device)?
                .unsqueeze(0)
                .context("failed to create candle input tensor")?;
            let mut logits = self
                .model
                .forward(&input, index_pos)
                .context("candle model forward failed")?;
            index_pos += ctxt.len();
            logits = logits
                .squeeze(0)
                .context("failed to squeeze logits batch dimension")?;

            let logits = if self.repeat_penalty > 1.0 {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                apply_repeat_penalty(&logits, self.repeat_penalty, &tokens[start_at..])
                    .context("failed to apply repeat penalty")?
            } else {
                logits
            };

            let next_token = logits_processor
                .sample(&logits)
                .context("token sampling failed")?;
            if self.eos_token_ids.contains(&next_token) {
                finish_reason = "stop".to_string();
                break;
            }

            tokens.push(next_token);
            generated.push(next_token);

            if tokens.len() + 1 >= self.context_window {
                finish_reason = "length".to_string();
                break;
            }
        }

        let text = self
            .tokenizer
            .decode(&generated, true)
            .map_err(|e| anyhow!("tokenizer decode failed: {}", e))?;
        let completion_tokens = generated.len() as u32;

        tracing::debug!(
            model = %self.model_label,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = prompt_token_count,
            completion_tokens = completion_tokens,
            "candle thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: self.model_label.clone(),
            finish_reason: Some(finish_reason),
            text,
            prompt_tokens: Some(prompt_token_count),
            completion_tokens: Some(completion_tokens),
            total_tokens: Some(prompt_token_count + completion_tokens),
        })
    }
}

/// Build a chat prompt using the proper template for each model architecture.
fn format_chat_prompt(architecture: &str, system_prompt: &str, user_prompt: &str) -> String {
    match architecture {
        // ChatML template (Qwen2, Yi, etc.)
        "qwen2" => format!(
            "<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
            system = system_prompt,
            user = user_prompt,
        ),
        // Llama 3 instruct template
        "llama" => format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
            system = system_prompt,
            user = user_prompt,
        ),
        // Gemma instruct template
        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => format!(
            "<start_of_turn>user\n{system}\n\n{user}<end_of_turn>\n<start_of_turn>model\n",
            system = system_prompt,
            user = user_prompt,
        ),
        // Fallback for unknown architectures
        _ => format!(
            "System:\n{system}\n\nUser:\n{user}\n\nAssistant:\n",
            system = system_prompt,
            user = user_prompt,
        ),
    }
}

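/// Resolve the configured device preference to a concrete Candle device,
/// falling back to the CPU when CUDA is unavailable in `Auto` mode.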
fn select_candle_device(config: &ThinkerConfig) -> Result<(Device, String)> {
    match config.candle_device {
        CandleDevicePreference::Cpu => Ok((Device::Cpu, "cpu".to_string())),
        CandleDevicePreference::Cuda => {
            let device = try_cuda_device(config.candle_cuda_ordinal)?;
            Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
        }
        CandleDevicePreference::Auto => match try_cuda_device(config.candle_cuda_ordinal) {
            Ok(device) => {
                tracing::info!(
                    ordinal = config.candle_cuda_ordinal,
                    "Candle thinker selected CUDA device"
                );
                Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
            }
            Err(error) => {
                tracing::warn!(
                    %error,
                    "CUDA unavailable for Candle thinker, falling back to CPU"
                );
                Ok((Device::Cpu, "cpu".to_string()))
            }
        },
    }
}

#[cfg(feature = "candle-cuda")]
fn try_cuda_device(ordinal: usize) -> Result<Device> {
    Device::new_cuda(ordinal)
        .with_context(|| format!("failed to initialize CUDA device ordinal {}", ordinal))
}

#[cfg(not(feature = "candle-cuda"))]
fn try_cuda_device(_ordinal: usize) -> Result<Device> {
    Err(anyhow!(
        "candle-cuda feature is not enabled in this build; rebuild with --features candle-cuda"
    ))
}

fn detect_context_window(content: &gguf_file::Content, architecture: &str) -> Option<usize> {
    let key = match architecture {
        "qwen2" => "qwen2.context_length",
        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => {
            // Try gemma3 first, then fall back to gemma2, gemma
            for prefix in ["gemma3", "gemma2", "gemma"] {
                let k = format!("{prefix}.context_length");
                if let Some(v) = content.metadata.get(&k) {
                    return v.to_u32().ok().map(|v| v as usize);
                }
            }
            return None;
        }
        _ => "llama.context_length",
    };
    content
        .metadata
        .get(key)
        .and_then(|v| v.to_u32().ok())
        .map(|v| v as usize)
}

/// Extract EOS token IDs from GGUF metadata before the content is consumed.
fn extract_gguf_eos_ids(content: &gguf_file::Content) -> Vec<u32> {
    let mut ids = Vec::new();
    for key in ["tokenizer.ggml.eos_token_id", "tokenizer.ggml.eot_token_id"] {
        if let Some(v) = content.metadata.get(key) {
            if let Ok(id) = v.to_u32() {
                if !ids.contains(&id) {
                    ids.push(id);
                }
            }
        }
    }
    ids
}

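/// Merge EOS ids from GGUF metadata with well-known end-of-sequence token
/// strings that the tokenizer may define.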
fn collect_eos_token_ids(tokenizer: &Tokenizer, gguf_eos_ids: &[u32]) -> HashSet<u32> {
    let mut ids: HashSet<u32> = gguf_eos_ids.iter().copied().collect();

    // Also check well-known special token strings as fallback.
    let candidates = [
        "<|im_end|>",
        "<|eot_id|>",
        "<|endoftext|>",
        "</s>",
        "<|end|>",
        "<end_of_turn>",
    ];
    for token in candidates {
        if let Some(id) = tokenizer.token_to_id(token) {
            ids.insert(id);
        }
    }
    ids
}

/// Returns true for HTTP status codes that are worth retrying.
fn is_transient_http_error(status: u16) -> bool {
    matches!(status, 429 | 502 | 503 | 504)
}

/// Returns true for reqwest errors that are worth retrying (timeouts, connection resets).
fn is_transient_reqwest_error(e: &reqwest::Error) -> bool {
    e.is_timeout() || e.is_connect() || e.is_request()
}

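/// Request body for an OpenAI-compatible chat completions call.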
#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    temperature: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    max_tokens: usize,
    stream: bool,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    model: Option<String>,
    choices: Vec<OpenAIChatChoice>,
    #[serde(default)]
    usage: Option<OpenAIUsage>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatChoice {
    message: OpenAIChatChoiceMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatChoiceMessage {
    #[serde(default)]
    content: Option<OpenAIChatContent>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
}

#[derive(Debug, Default, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    total_tokens: Option<u32>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAIChatContent {
    Text(String),
    Parts(Vec<OpenAIChatContentPart>),
    Part(OpenAIChatContentPart),
}

#[derive(Debug, Deserialize)]
struct OpenAIChatContentPart {
    #[serde(rename = "type")]
    kind: Option<String>,
    #[serde(default)]
    text: Option<String>,
    #[serde(default)]
    content: Option<String>,
}

impl OpenAIChatChoiceMessage {
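    /// Prefer the regular `content` field; fall back to `reasoning` or
    /// `reasoning_content` when the content is empty.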
    fn extract_text(&self) -> String {
        let content_text = self
            .content
            .as_ref()
            .map(OpenAIChatContent::to_text)
            .unwrap_or_default();
        if !content_text.trim().is_empty() {
            return content_text;
        }

        if let Some(reasoning) = self
            .reasoning
            .as_deref()
            .filter(|text| !text.trim().is_empty())
        {
            return reasoning.to_string();
        }

        self.reasoning_content
            .as_deref()
            .filter(|text| !text.trim().is_empty())
            .unwrap_or_default()
            .to_string()
    }
}

impl OpenAIChatContent {
    fn to_text(&self) -> String {
        match self {
            Self::Text(text) => text.clone(),
            Self::Parts(parts) => parts
                .iter()
                .filter_map(OpenAIChatContentPart::text_fragment)
                .collect::<Vec<_>>()
                .join("\n"),
            Self::Part(part) => part.text_fragment().unwrap_or_default(),
        }
    }
}

impl OpenAIChatContentPart {
    fn text_fragment(&self) -> Option<String> {
        if let Some(kind) = self.kind.as_deref()
            && !kind.eq_ignore_ascii_case("text")
            && !kind.eq_ignore_ascii_case("output_text")
        {
            return None;
        }

        self.text
            .as_deref()
            .or(self.content.as_deref())
            .map(str::trim)
            .filter(|text| !text.is_empty())
            .map(ToString::to_string)
    }
}