1use anyhow::{Context, Result, anyhow};
9use candle_core::quantized::gguf_file;
10use candle_core::{DType, Device, Tensor};
11use candle_transformers::generation::LogitsProcessor;
12
13#[cfg(feature = "functiongemma")]
14use candle_transformers::models::quantized_gemma3;
15use candle_transformers::models::{
16 quantized_llama, quantized_qwen2, quantized_qwen3, quantized_qwen3_moe,
17};
18use candle_transformers::utils::apply_repeat_penalty;
19use reqwest::Client;
20use serde::{Deserialize, Serialize};
21use std::collections::HashSet;
22use std::fs::File;
23use std::io::BufReader;
24use std::sync::{Arc, Mutex};
25use std::time::{Duration, Instant};
26use tokenizers::Tokenizer;
27
28use crate::provider::bedrock::{AwsCredentials, BedrockProvider};
29
/// Inference backend the thinker talks to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkerBackend {
    /// HTTP endpoint speaking the OpenAI chat-completions protocol.
    OpenAICompat,
    /// In-process inference through Candle.
    Candle,
    /// AWS Bedrock.
    Bedrock,
}

impl ThinkerBackend {
    /// Parses a backend name, ignoring surrounding whitespace and ASCII case.
    ///
    /// * `candle` selects [`ThinkerBackend::Candle`].
    /// * `bedrock`, `aws` and `aws_bedrock` select [`ThinkerBackend::Bedrock`].
    /// * `openai`, `openai_compat`, `openai-compatible`, `http` and every
    ///   unrecognized value select [`ThinkerBackend::OpenAICompat`].
    pub fn from_env(value: &str) -> Self {
        let normalized = value.trim().to_ascii_lowercase();
        let name = normalized.as_str();

        if name == "candle" {
            Self::Candle
        } else if matches!(name, "bedrock" | "aws" | "aws_bedrock") {
            Self::Bedrock
        } else {
            // The explicit OpenAI aliases and the unknown-value fallback share one result.
            Self::OpenAICompat
        }
    }
}
63
/// Which compute device the Candle backend should run on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandleDevicePreference {
    /// Try CUDA first and fall back to the CPU.
    Auto,
    /// Always use the CPU.
    Cpu,
    /// Require a CUDA device.
    Cuda,
}

impl CandleDevicePreference {
    /// Parses a device preference, ignoring surrounding whitespace and ASCII case.
    ///
    /// `cpu` maps to [`CandleDevicePreference::Cpu`], `cuda` and `gpu` map to
    /// [`CandleDevicePreference::Cuda`]; anything else yields
    /// [`CandleDevicePreference::Auto`].
    pub fn from_env(value: &str) -> Self {
        let normalized = value.trim().to_ascii_lowercase();

        if normalized == "cpu" {
            return Self::Cpu;
        }
        if normalized == "cuda" || normalized == "gpu" {
            return Self::Cuda;
        }
        Self::Auto
    }
}
96
/// Settings for constructing a `ThinkerClient`.
///
/// Fields prefixed with `candle_` are only read by the Candle backend and
/// fields prefixed with `bedrock_` only by the Bedrock backend.
#[derive(Debug, Clone)]
pub struct ThinkerConfig {
    /// Enable flag; not read anywhere in this file (presumably checked by the caller).
    pub enabled: bool,
    /// Which backend `ThinkerClient::new` builds.
    pub backend: ThinkerBackend,
    /// URL the OpenAI-compatible backend POSTs chat requests to.
    pub endpoint: String,
    /// Model name sent to the HTTP endpoint, or the model id placed in the Bedrock URL.
    pub model: String,
    /// Optional bearer token attached to OpenAI-compatible requests.
    pub api_key: Option<String>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Optional nucleus-sampling cutoff (used by the OpenAI-compatible and Candle backends).
    pub top_p: Option<f32>,
    /// Maximum number of tokens to generate per request.
    pub max_tokens: usize,
    /// HTTP client timeout in milliseconds; values below 1000 are raised to 1000.
    pub timeout_ms: u64,
    /// Path to the GGUF model file (required by the Candle backend).
    pub candle_model_path: Option<String>,
    /// Path to the tokenizer file (required by the Candle backend).
    pub candle_tokenizer_path: Option<String>,
    /// Architecture override; when `None` the GGUF `general.architecture` value is used.
    pub candle_arch: Option<String>,
    /// Device selection policy for the Candle backend.
    pub candle_device: CandleDevicePreference,
    /// CUDA device ordinal used when a CUDA device is requested or probed.
    pub candle_cuda_ordinal: usize,
    /// Repeat penalty; values below 1.0 are raised to 1.0 by the Candle runtime.
    pub candle_repeat_penalty: f32,
    /// Number of trailing tokens the repeat penalty looks at (at least 1).
    pub candle_repeat_last_n: usize,
    /// Base RNG seed; the Candle runtime adds a per-request counter to it.
    pub candle_seed: u64,
    /// AWS region used for the Bedrock provider and runtime URL.
    pub bedrock_region: String,
    /// Optional value sent as `additionalModelRequestFields.service_tier`.
    pub bedrock_service_tier: Option<String>,
}
128
129impl Default for ThinkerConfig {
130 fn default() -> Self {
131 Self {
132 enabled: false,
133 backend: ThinkerBackend::OpenAICompat,
134 endpoint: "http://127.0.0.1:11434/v1/chat/completions".to_string(),
135 model: "qwen3.5-9b".to_string(),
136 api_key: None,
137 temperature: 0.2,
138 top_p: None,
139 max_tokens: 256,
140 timeout_ms: 30_000,
141 candle_model_path: None,
142 candle_tokenizer_path: None,
143 candle_arch: None,
144 candle_device: CandleDevicePreference::Auto,
145 candle_cuda_ordinal: 0,
146 candle_repeat_penalty: 1.1,
147 candle_repeat_last_n: 64,
148 candle_seed: 42,
149 bedrock_region: "us-west-2".to_string(),
150 bedrock_service_tier: None,
151 }
152 }
153}
154
/// Result of a single thinker invocation.
#[derive(Debug, Clone)]
pub struct ThinkerOutput {
    /// Model name reported by the backend (or a locally built label).
    pub model: String,
    /// Why generation stopped, when the backend reports it.
    pub finish_reason: Option<String>,
    /// Generated text.
    pub text: String,
    /// Number of prompt tokens, when known.
    pub prompt_tokens: Option<u32>,
    /// Number of generated tokens, when known.
    pub completion_tokens: Option<u32>,
    /// Prompt plus completion tokens, when known.
    pub total_tokens: Option<u32>,
    /// Prompt tokens served from the Candle KV cache; `None` for the HTTP and Bedrock backends.
    #[cfg_attr(not(feature = "candle-cuda"), allow(dead_code))]
    pub cache_read_tokens: Option<u32>,
    /// Tokens fed through the Candle model in this request; `None` for the HTTP and Bedrock backends.
    #[cfg_attr(not(feature = "candle-cuda"), allow(dead_code))]
    pub cache_write_tokens: Option<u32>,
}
186
/// Client that produces "thoughts" through the backend selected in its configuration.
///
/// Cloning is cheap relative to construction: the Candle runtime and the
/// Bedrock provider are held behind `Arc`s.
#[derive(Clone)]
pub struct ThinkerClient {
    /// Configuration the client was built from.
    config: ThinkerConfig,
    /// Backend-specific state (HTTP client, Candle runtime, or Bedrock provider).
    backend: ThinkerClientBackend,
}
201
202impl std::fmt::Debug for ThinkerClient {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 f.debug_struct("ThinkerClient")
205 .field("backend", &self.config.backend)
206 .field("model", &self.config.model)
207 .finish()
208 }
209}
210
/// Backend-specific state owned by a `ThinkerClient`.
#[derive(Clone)]
enum ThinkerClientBackend {
    /// HTTP client used for OpenAI-compatible chat requests.
    OpenAICompat { http: Client },
    /// Shared, mutex-guarded in-process Candle runtime.
    Candle { runtime: Arc<Mutex<CandleThinker>> },
    /// Shared Bedrock provider used to send Converse requests.
    Bedrock { provider: Arc<BedrockProvider> },
}
217
218impl ThinkerClient {
219 pub fn new(config: ThinkerConfig) -> Result<Self> {
228 let backend = match config.backend {
229 ThinkerBackend::OpenAICompat => {
230 let timeout = Duration::from_millis(config.timeout_ms.max(1_000));
231 let http = Client::builder()
232 .timeout(timeout)
233 .build()
234 .context("failed to build thinker HTTP client")?;
235 ThinkerClientBackend::OpenAICompat { http }
236 }
237 ThinkerBackend::Candle => {
238 let runtime = CandleThinker::new(&config)?;
239 ThinkerClientBackend::Candle {
240 runtime: Arc::new(Mutex::new(runtime)),
241 }
242 }
243 ThinkerBackend::Bedrock => {
244 let creds = AwsCredentials::from_environment()
245 .ok_or_else(|| anyhow!("Bedrock thinker requires AWS credentials (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY or ~/.aws/credentials)"))?;
246 let provider =
247 BedrockProvider::with_credentials(creds, config.bedrock_region.clone())?;
248 ThinkerClientBackend::Bedrock {
249 provider: Arc::new(provider),
250 }
251 }
252 };
253
254 Ok(Self { config, backend })
255 }
256
    /// Returns the configuration this client was built with.
    pub fn config(&self) -> &ThinkerConfig {
        &self.config
    }
270
271 pub async fn think(&self, system_prompt: &str, user_prompt: &str) -> Result<ThinkerOutput> {
284 match &self.backend {
285 ThinkerClientBackend::OpenAICompat { http } => {
286 self.think_openai_compat(http, system_prompt, user_prompt)
287 .await
288 }
289 ThinkerClientBackend::Bedrock { provider } => {
290 self.think_bedrock(provider, system_prompt, user_prompt)
291 .await
292 }
293 ThinkerClientBackend::Candle { runtime } => {
294 let runtime = Arc::clone(runtime);
295 let system_prompt = system_prompt.to_string();
296 let user_prompt = user_prompt.to_string();
297 tokio::task::spawn_blocking(move || {
298 let mut guard = match runtime.try_lock() {
299 Ok(g) => g,
300 Err(std::sync::TryLockError::WouldBlock) => {
301 return Err(anyhow!("candle thinker is busy"));
302 }
303 Err(std::sync::TryLockError::Poisoned(_)) => {
304 return Err(anyhow!("candle thinker mutex poisoned"));
305 }
306 };
307 guard.think(&system_prompt, &user_prompt)
308 })
309 .await
310 .context("candle thinker task join failed")?
311 }
312 }
313 }
314
315 async fn think_bedrock(
316 &self,
317 provider: &BedrockProvider,
318 system_prompt: &str,
319 user_prompt: &str,
320 ) -> Result<ThinkerOutput> {
321 let started_at = Instant::now();
322 let model_id = &self.config.model;
323
324 let mut body = serde_json::json!({
326 "system": [{"text": system_prompt}],
327 "messages": [{
328 "role": "user",
329 "content": [{"text": user_prompt}]
330 }],
331 "inferenceConfig": {
332 "maxTokens": self.config.max_tokens,
333 "temperature": self.config.temperature
334 }
335 });
336
337 if let Some(service_tier) = self.config.bedrock_service_tier.as_ref() {
338 body["additionalModelRequestFields"] = serde_json::json!({
339 "service_tier": service_tier
340 });
341 }
342
343 let body_bytes = serde_json::to_vec(&body)?;
344 let url = format!(
348 "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
349 self.config.bedrock_region, model_id
350 );
351
352 let response = provider
353 .send_converse_request(&url, &body_bytes)
354 .await
355 .context("Bedrock thinker converse request failed")?;
356
357 let status = response.status();
358 let text = response
359 .text()
360 .await
361 .context("Failed to read Bedrock thinker response")?;
362
363 if !status.is_success() {
364 return Err(anyhow!(
365 "Bedrock thinker error ({}): {}",
366 status,
367 crate::util::truncate_bytes_safe(&text, 500)
368 ));
369 }
370
371 let parsed: serde_json::Value =
372 serde_json::from_str(&text).context("Failed to parse Bedrock thinker response")?;
373
374 let output_text = parsed["output"]["message"]["content"]
375 .as_array()
376 .and_then(|arr| arr.first())
377 .and_then(|c| c["text"].as_str())
378 .unwrap_or_default()
379 .to_string();
380
381 let usage = &parsed["usage"];
382 let prompt_tokens = usage["inputTokens"].as_u64().map(|v| v as u32);
383 let completion_tokens = usage["outputTokens"].as_u64().map(|v| v as u32);
384
385 tracing::debug!(
386 model = model_id,
387 latency_ms = started_at.elapsed().as_millis(),
388 prompt_tokens = ?prompt_tokens,
389 completion_tokens = ?completion_tokens,
390 "bedrock thinker generated thought"
391 );
392
393 Ok(ThinkerOutput {
394 model: model_id.clone(),
395 finish_reason: parsed["stopReason"].as_str().map(|s| s.to_string()),
396 text: output_text,
397 prompt_tokens,
398 completion_tokens,
399 total_tokens: prompt_tokens.zip(completion_tokens).map(|(p, c)| p + c),
400 cache_read_tokens: None,
401 cache_write_tokens: None,
402 })
403 }
404
405 async fn think_openai_compat(
406 &self,
407 http: &Client,
408 system_prompt: &str,
409 user_prompt: &str,
410 ) -> Result<ThinkerOutput> {
411 let started_at = Instant::now();
412 let body = OpenAIChatRequest {
413 model: self.config.model.clone(),
414 messages: vec![
415 OpenAIMessage {
416 role: "system".to_string(),
417 content: system_prompt.to_string(),
418 },
419 OpenAIMessage {
420 role: "user".to_string(),
421 content: user_prompt.to_string(),
422 },
423 ],
424 temperature: self.config.temperature,
425 top_p: self.config.top_p,
426 max_tokens: self.config.max_tokens,
427 stream: false,
428 };
429
430 let max_attempts: u32 = 2;
432 let mut last_err: Option<anyhow::Error> = None;
433
434 for attempt in 0..max_attempts {
435 if attempt > 0 {
436 tokio::time::sleep(Duration::from_millis(500 * attempt as u64)).await;
437 tracing::debug!(attempt, "retrying thinker HTTP request");
438 }
439
440 let mut request = http.post(&self.config.endpoint).json(&body);
441 if let Some(key) = self.config.api_key.as_ref() {
442 request = request.bearer_auth(key);
443 }
444
445 let response = match request.send().await {
446 Ok(resp) => resp,
447 Err(e) => {
448 if is_transient_reqwest_error(&e) {
449 tracing::warn!(attempt, error = %e, "thinker HTTP request failed (transient)");
450 last_err =
451 Some(anyhow::Error::from(e).context("transient thinker send error"));
452 continue;
453 }
454 return Err(anyhow::Error::from(e).context("non-transient thinker send error"));
455 }
456 };
457
458 let status = response.status();
459 if is_transient_http_error(status.as_u16()) {
460 let body_text = response.text().await.unwrap_or_default();
461 tracing::warn!(attempt, status = %status, "thinker received transient HTTP error");
462 last_err = Some(anyhow!(
463 "thinker request failed with status {}: {}",
464 status,
465 body_text
466 ));
467 continue;
468 }
469
470 if !status.is_success() {
471 let body_text = response
472 .text()
473 .await
474 .unwrap_or_else(|_| "<empty>".to_string());
475 return Err(anyhow!(
476 "thinker request failed with status {}: {}",
477 status,
478 body_text
479 ));
480 }
481
482 let payload: OpenAIChatResponse = response
483 .json()
484 .await
485 .context("failed to decode thinker response")?;
486 let choice = payload
487 .choices
488 .first()
489 .ok_or_else(|| anyhow!("thinker response did not include choices"))?;
490 let text = choice.message.extract_text();
491 let usage = payload.usage.unwrap_or_default();
492
493 let output = ThinkerOutput {
494 model: payload.model.unwrap_or_else(|| self.config.model.clone()),
495 finish_reason: choice.finish_reason.clone(),
496 text,
497 prompt_tokens: usage.prompt_tokens,
498 completion_tokens: usage.completion_tokens,
499 total_tokens: usage.total_tokens,
500 cache_read_tokens: None,
501 cache_write_tokens: None,
502 };
503
504 tracing::debug!(
505 model = %output.model,
506 latency_ms = started_at.elapsed().as_millis(),
507 prompt_tokens = ?output.prompt_tokens,
508 completion_tokens = ?output.completion_tokens,
509 attempt,
510 "openai-compat thinker generated thought"
511 );
512
513 return Ok(output);
514 }
515
516 Err(last_err.unwrap_or_else(|| {
517 anyhow!("thinker HTTP request failed after {max_attempts} attempts")
518 }))
519 }
520}
521
/// In-process text generator backed by a quantized GGUF model loaded through Candle.
pub(crate) struct CandleThinker {
    /// Loaded model weights for the detected architecture.
    model: CandleModel,
    /// Tokenizer used to encode prompts and decode generated tokens.
    tokenizer: Tokenizer,
    /// Device that input tensors are created on.
    device: Device,
    /// Label reported as the model name: `candle:<arch>:<device>@<path>`.
    model_label: String,
    /// Lower-cased architecture name; selects the chat prompt template.
    architecture: String,
    /// Context length from GGUF metadata (4096 when it cannot be detected).
    context_window: usize,
    /// Sampling temperature.
    temperature: f32,
    /// Optional nucleus-sampling cutoff.
    top_p: Option<f32>,
    /// Maximum number of tokens generated per request (at least 1).
    max_tokens: usize,
    /// Repeat penalty (at least 1.0; exactly 1.0 disables it).
    repeat_penalty: f32,
    /// Number of trailing tokens the repeat penalty considers (at least 1).
    repeat_last_n: usize,
    /// Base RNG seed.
    seed: u64,
    /// Per-request counter added to `seed`.
    request_index: u64,
    /// Token ids that end generation.
    eos_token_ids: HashSet<u32>,
    /// Prompt plus generated tokens of the previous request, used to detect prefix reuse.
    cached_tokens: Vec<u32>,
}
539
/// Quantized model weights, one variant per supported GGUF architecture.
enum CandleModel {
    /// `llama` architecture.
    Llama(quantized_llama::ModelWeights),
    /// `qwen2` architecture.
    Qwen2(quantized_qwen2::ModelWeights),
    /// `qwen3` architecture.
    Qwen3(quantized_qwen3::ModelWeights),
    /// `qwen3moe` / `qwen3_moe` architecture.
    Qwen3Moe(quantized_qwen3_moe::GGUFQWenMoE),

    /// Gemma-family architectures; only available with the `functiongemma` feature.
    #[cfg(feature = "functiongemma")]
    Gemma3(quantized_gemma3::ModelWeights),
}
549
550impl CandleModel {
551 fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
552 match self {
553 Self::Llama(model) => Ok(model.forward(x, index_pos)?),
554 Self::Qwen2(model) => Ok(model.forward(x, index_pos)?),
555 Self::Qwen3(model) => Ok(model.forward(x, index_pos)?),
556 Self::Qwen3Moe(model) => Ok(model.forward(x, index_pos)?),
557
558 #[cfg(feature = "functiongemma")]
559 Self::Gemma3(model) => Ok(model.forward(x, index_pos)?),
560 }
561 }
562
563 fn reset_kv_cache_for_new_request(&mut self) -> Result<()> {
564 match self {
565 Self::Qwen3(model) => {
567 model.clear_kv_cache();
568 Ok(())
569 }
570 Self::Qwen3Moe(_) => Err(anyhow!(
572 "qwen3_moe runtime cannot reset KV cache in this build; restart local runtime or use qwen3"
573 )),
574 Self::Llama(_) | Self::Qwen2(_) => Ok(()),
575
576 #[cfg(feature = "functiongemma")]
577 Self::Gemma3(_) => Ok(()),
578 }
579 }
580
    /// Reports whether a new prompt may continue from the cached prefix.
    /// Currently `true` for every variant.
    fn can_extend_cached_prefix(&self) -> bool {
        true
    }
584}
585
586impl CandleThinker {
    /// Loads the GGUF model and tokenizer named in `config` and prepares a runtime.
    ///
    /// The architecture comes from `config.candle_arch` when set, otherwise from
    /// the GGUF `general.architecture` metadata, otherwise `"llama"`.
    ///
    /// # Errors
    ///
    /// Fails when the model or tokenizer path is missing, the device cannot be
    /// selected, the files cannot be opened or parsed, or the architecture is
    /// unsupported (Gemma additionally requires the `functiongemma` feature).
    pub(crate) fn new(config: &ThinkerConfig) -> Result<Self> {
        let model_path = config.candle_model_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_MODEL_PATH")
        })?;
        let tokenizer_path = config.candle_tokenizer_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_TOKENIZER_PATH")
        })?;

        let (device, device_label) = select_candle_device(config)?;
        let mut reader = BufReader::new(
            File::open(model_path)
                .with_context(|| format!("failed to open candle model file at {}", model_path))?,
        );
        // Reads the GGUF header/metadata; the same reader is reused below to load tensors.
        let content = gguf_file::Content::read(&mut reader)
            .with_context(|| format!("failed to parse gguf model metadata from {}", model_path))?;

        // Explicit override wins, then GGUF metadata, then "llama".
        let architecture = config
            .candle_arch
            .clone()
            .or_else(|| {
                content
                    .metadata
                    .get("general.architecture")
                    .and_then(|v| v.to_string().ok())
                    .cloned()
            })
            .unwrap_or_else(|| "llama".to_string())
            .to_ascii_lowercase();

        let context_window = detect_context_window(&content, &architecture).unwrap_or(4096);
        let model_label = format!("candle:{}:{}@{}", architecture, device_label, model_path);

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer from {}: {}", tokenizer_path, e))?;

        // Must be extracted here: `content` is moved into the model loader below.
        let gguf_eos_ids = extract_gguf_eos_ids(&content);

        let model = match architecture.as_str() {
            "llama" => CandleModel::Llama(
                quantized_llama::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load llama gguf from {}", model_path))?,
            ),
            "qwen2" => CandleModel::Qwen2(
                quantized_qwen2::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen2 gguf from {}", model_path))?,
            ),
            "qwen3" => CandleModel::Qwen3(
                quantized_qwen3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen3 gguf from {}", model_path))?,
            ),
            "qwen3moe" | "qwen3_moe" => CandleModel::Qwen3Moe(
                quantized_qwen3_moe::GGUFQWenMoE::from_gguf(
                    content,
                    &mut reader,
                    &device,
                    DType::F16,
                )
                .with_context(|| format!("failed to load qwen3_moe gguf from {}", model_path))?,
            ),

            #[cfg(feature = "functiongemma")]
            "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => CandleModel::Gemma3(
                quantized_gemma3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load gemma3 gguf from {}", model_path))?,
            ),
            other => {
                // Without the feature, Gemma names land here; give a targeted hint.
                #[cfg(not(feature = "functiongemma"))]
                if matches!(other, "gemma" | "gemma2" | "gemma3" | "gemma-embedding") {
                    return Err(anyhow!(
                        "gemma architecture '{}' requires the 'functiongemma' feature; rebuild with --features functiongemma",
                        other
                    ));
                }
                return Err(anyhow!(
                    "unsupported candle architecture '{}' (supported: llama, qwen2, qwen3, qwen3_moe{})",
                    other,
                    if cfg!(feature = "functiongemma") {
                        ", gemma/gemma2/gemma3"
                    } else {
                        ""
                    }
                ));
            }
        };

        let eos_token_ids: HashSet<u32> = collect_eos_token_ids(&tokenizer, &gguf_eos_ids);
        if eos_token_ids.is_empty() {
            tracing::warn!(
                "No EOS tokens found in tokenizer; generation will stop on max token limit"
            );
        }

        Ok(Self {
            model,
            tokenizer,
            device,
            model_label,
            architecture,
            context_window,
            temperature: config.temperature,
            top_p: config.top_p,
            // Clamp to sane minimums so generation always makes progress.
            max_tokens: config.max_tokens.max(1),
            repeat_penalty: config.candle_repeat_penalty.max(1.0),
            repeat_last_n: config.candle_repeat_last_n.max(1),
            seed: config.candle_seed,
            request_index: 0,
            eos_token_ids,
            cached_tokens: Vec::new(),
        })
    }
698
    /// Generates from `raw_prompt` exactly as given, without applying a chat template.
    #[cfg(feature = "functiongemma")]
    pub(crate) fn think_raw(&mut self, raw_prompt: &str) -> Result<ThinkerOutput> {
        self.think_inner(raw_prompt)
    }
704
    /// Wraps the prompts in the chat template for this model's architecture
    /// and generates a completion.
    pub(crate) fn think(
        &mut self,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let prompt = format_chat_prompt(&self.architecture, system_prompt, user_prompt);
        self.think_inner(&prompt)
    }
713
    /// Tokenizes `prompt`, runs prefill plus token-by-token sampling, and returns the decoded text.
    ///
    /// When the new token sequence strictly extends the tokens cached from the
    /// previous request, only the new suffix is fed to the model and the cached
    /// length is reported as `cache_read_tokens`. Otherwise the model is reset
    /// and the whole prompt is processed from position 0.
    ///
    /// Generation ends with finish reason `"stop"` on an EOS token, or
    /// `"length"` when `max_tokens` or the context window is reached.
    ///
    /// # Errors
    ///
    /// Fails when tokenization yields no tokens, when the KV cache cannot be
    /// reset, or when tensor creation, a forward pass, sampling, or decoding fails.
    fn think_inner(&mut self, prompt: &str) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let encoding = self
            .tokenizer
            .encode(prompt, true)
            .map_err(|e| anyhow!("tokenizer encode failed: {}", e))?;
        let mut tokens = encoding.get_ids().to_vec();
        if tokens.is_empty() {
            return Err(anyhow!("tokenizer produced an empty prompt token set"));
        }

        // Over-long prompts keep only their tail, leaving 8 positions of headroom.
        if self.context_window > 8 && tokens.len() >= self.context_window {
            let budget = self.context_window.saturating_sub(8);
            tokens = tokens[tokens.len().saturating_sub(budget)..].to_vec();
        }
        let prompt_token_count = tokens.len() as u32;

        // A distinct seed per request: base seed plus a wrapping request counter.
        let request_seed = self.seed.wrapping_add(self.request_index);
        self.request_index = self.request_index.wrapping_add(1);
        let mut logits_processor = LogitsProcessor::new(
            request_seed,
            Some(self.temperature as f64),
            self.top_p.map(|v| v as f64),
        );

        let mut cache_read_tokens = 0u32;
        let mut cache_write_tokens = 0u32;

        // Reuse is only possible when the new sequence is strictly longer than
        // the cached one and begins with it.
        let can_extend_prefix = self.model.can_extend_cached_prefix()
            && !self.cached_tokens.is_empty()
            && tokens.len() > self.cached_tokens.len()
            && tokens.starts_with(&self.cached_tokens);

        let mut index_pos = if can_extend_prefix {
            cache_read_tokens = self.cached_tokens.len() as u32;
            self.cached_tokens.len()
        } else {
            // Only reset when a previous request left state behind.
            if !self.cached_tokens.is_empty() {
                self.model.reset_kv_cache_for_new_request()?;
            }
            0
        };

        let prefill = if index_pos == 0 {
            tokens.as_slice()
        } else {
            &tokens[index_pos..]
        };
        // Defensive: with the checks above the slice is never empty (the
        // extend path requires at least one new token and `tokens` is non-empty).
        if prefill.is_empty() {
            self.model.reset_kv_cache_for_new_request()?;
            index_pos = 0;
        }

        // Recomputed because `index_pos` may have been reset just above.
        let prefill = if index_pos == 0 {
            tokens.as_slice()
        } else {
            &tokens[index_pos..]
        };
        cache_write_tokens = cache_write_tokens.saturating_add(prefill.len() as u32);

        let input = Tensor::new(prefill, &self.device)?
            .unsqueeze(0)
            .context("failed to create candle input tensor")?;
        let mut logits = self
            .model
            .forward(&input, index_pos)
            .context("candle model forward failed")?;
        index_pos += prefill.len();
        logits = logits
            .squeeze(0)
            .context("failed to squeeze logits batch dimension")?;

        // Capacity is capped so a huge `max_tokens` does not pre-allocate excessively.
        let mut generated: Vec<u32> = Vec::with_capacity(self.max_tokens.min(64 * 1024));
        let mut finish_reason = "length".to_string();

        for _ in 0..self.max_tokens {
            // The penalty looks at the last `repeat_last_n` tokens (prompt and generated).
            let sampling_logits = if self.repeat_penalty > 1.0 {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                apply_repeat_penalty(&logits, self.repeat_penalty, &tokens[start_at..])
                    .context("failed to apply repeat penalty")?
            } else {
                logits.clone()
            };

            let next_token =
                sample_next_token_with_fallback(&mut logits_processor, &sampling_logits)?;
            // The EOS token itself is not appended to the output or the cache.
            if self.eos_token_ids.contains(&next_token) {
                finish_reason = "stop".to_string();
                break;
            }

            tokens.push(next_token);
            generated.push(next_token);
            cache_write_tokens = cache_write_tokens.saturating_add(1);

            // Stop before the sequence would fill the context window.
            if tokens.len() + 1 >= self.context_window {
                finish_reason = "length".to_string();
                break;
            }

            // Feed only the newly sampled token; earlier positions live in the model's cache.
            let input = Tensor::new(&tokens[tokens.len() - 1..], &self.device)?
                .unsqueeze(0)
                .context("failed to create candle input tensor")?;
            logits = self
                .model
                .forward(&input, index_pos)
                .context("candle model forward failed")?;
            index_pos += 1;
            logits = logits
                .squeeze(0)
                .context("failed to squeeze logits batch dimension")?;
        }

        let text = self
            .tokenizer
            .decode(&generated, true)
            .map_err(|e| anyhow!("tokenizer decode failed: {}", e))?;
        let completion_tokens = generated.len() as u32;
        // NOTE(review): `cached_tokens` is only updated on success. If a forward
        // pass fails after the model state changed, the old value stays in place
        // and a later request could match a prefix the model no longer holds —
        // confirm whether early returns above should invalidate it.
        self.cached_tokens = tokens;

        tracing::debug!(
            model = %self.model_label,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = prompt_token_count,
            completion_tokens = completion_tokens,
            cache_read_tokens = cache_read_tokens,
            cache_write_tokens = cache_write_tokens,
            "candle thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: self.model_label.clone(),
            finish_reason: Some(finish_reason),
            text,
            prompt_tokens: Some(prompt_token_count),
            completion_tokens: Some(completion_tokens),
            total_tokens: Some(prompt_token_count + completion_tokens),
            cache_read_tokens: Some(cache_read_tokens),
            cache_write_tokens: Some(cache_write_tokens),
        })
    }
859}
860
/// Renders the system and user prompts into the chat template for `architecture`.
///
/// * Qwen family (`qwen2`, `qwen3`, `qwen3moe`, `qwen3_moe`): ChatML-style `<|im_start|>` turns.
/// * `llama`: header-id template with `<|eot_id|>` terminators.
/// * Gemma family: a single user turn holding the system text followed by the user text.
/// * Anything else: a plain `System:` / `User:` / `Assistant:` transcript.
///
/// Every template ends with the opening of the assistant turn.
fn format_chat_prompt(architecture: &str, system_prompt: &str, user_prompt: &str) -> String {
    if matches!(architecture, "qwen2" | "qwen3" | "qwen3moe" | "qwen3_moe") {
        return format!(
            "<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
            system = system_prompt,
            user = user_prompt,
        );
    }

    if architecture == "llama" {
        return format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
            system = system_prompt,
            user = user_prompt,
        );
    }

    if matches!(architecture, "gemma" | "gemma2" | "gemma3" | "gemma-embedding") {
        return format!(
            "<start_of_turn>user\n{system}\n\n{user}<end_of_turn>\n<start_of_turn>model\n",
            system = system_prompt,
            user = user_prompt,
        );
    }

    format!(
        "System:\n{system}\n\nUser:\n{user}\n\nAssistant:\n",
        system = system_prompt,
        user = user_prompt,
    )
}
890
891fn select_candle_device(config: &ThinkerConfig) -> Result<(Device, String)> {
892 match config.candle_device {
893 CandleDevicePreference::Cpu => Ok((Device::Cpu, "cpu".to_string())),
894 CandleDevicePreference::Cuda => {
895 let device = try_cuda_device(config.candle_cuda_ordinal)?;
896 Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
897 }
898 CandleDevicePreference::Auto => match try_cuda_device(config.candle_cuda_ordinal) {
899 Ok(device) => {
900 tracing::info!(
901 ordinal = config.candle_cuda_ordinal,
902 "Candle thinker selected CUDA device"
903 );
904 Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
905 }
906 Err(error) => {
907 tracing::warn!(
908 %error,
909 "CUDA unavailable for Candle thinker, falling back to CPU"
910 );
911 Ok((Device::Cpu, "cpu".to_string()))
912 }
913 },
914 }
915}
916
/// Initializes the CUDA device with the given ordinal.
///
/// # Errors
///
/// Returns the initialization error, annotated with the ordinal.
#[cfg(feature = "candle-cuda")]
fn try_cuda_device(ordinal: usize) -> Result<Device> {
    let device = Device::new_cuda(ordinal)
        .with_context(|| format!("failed to initialize CUDA device ordinal {}", ordinal))?;
    Ok(device)
}
922
923#[cfg(not(feature = "candle-cuda"))]
924fn try_cuda_device(_ordinal: usize) -> Result<Device> {
925 Err(anyhow!(
926 "candle-cuda feature is not enabled in this build; rebuild with --features candle-cuda"
927 ))
928}
929
930fn detect_context_window(content: &gguf_file::Content, architecture: &str) -> Option<usize> {
931 let key = match architecture {
932 "qwen2" => "qwen2.context_length",
933 "qwen3" | "qwen3moe" | "qwen3_moe" => "qwen3.context_length",
934 "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => {
935 for prefix in ["gemma3", "gemma2", "gemma"] {
937 let k = format!("{prefix}.context_length");
938 if let Some(v) = content.metadata.get(&k) {
939 return v.to_u32().ok().map(|v| v as usize);
940 }
941 }
942 return None;
943 }
944 _ => "llama.context_length",
945 };
946 content
947 .metadata
948 .get(key)
949 .and_then(|v| v.to_u32().ok())
950 .map(|v| v as usize)
951}
952
953fn extract_gguf_eos_ids(content: &gguf_file::Content) -> Vec<u32> {
955 let mut ids = Vec::new();
956 for key in ["tokenizer.ggml.eos_token_id", "tokenizer.ggml.eot_token_id"] {
957 if let Some(v) = content.metadata.get(key)
958 && let Ok(id) = v.to_u32()
959 && !ids.contains(&id)
960 {
961 ids.push(id);
962 }
963 }
964 ids
965}
966
967fn collect_eos_token_ids(tokenizer: &Tokenizer, gguf_eos_ids: &[u32]) -> HashSet<u32> {
968 let mut ids: HashSet<u32> = gguf_eos_ids.iter().copied().collect();
969
970 let candidates = [
972 "<|im_end|>",
973 "<|eot_id|>",
974 "<|endoftext|>",
975 "</s>",
976 "<|end|>",
977 "<end_of_turn>",
978 ];
979 for token in candidates {
980 if let Some(id) = tokenizer.token_to_id(token) {
981 ids.insert(id);
982 }
983 }
984 ids
985}
986
987fn sample_next_token_with_fallback(
988 logits_processor: &mut LogitsProcessor,
989 logits: &Tensor,
990) -> Result<u32> {
991 match logits_processor.sample(logits) {
992 Ok(token) => Ok(token),
993 Err(sample_error) => {
994 let logits_vec = logits
995 .to_vec1::<f32>()
996 .context("token sampling failed and fallback logits extraction failed")?;
997 let mut best_token = None;
998 let mut best_logit = f32::NEG_INFINITY;
999
1000 for (idx, logit) in logits_vec.into_iter().enumerate() {
1001 if !logit.is_finite() {
1002 continue;
1003 }
1004 if best_token.is_none() || logit > best_logit {
1005 best_token = Some(idx as u32);
1006 best_logit = logit;
1007 }
1008 }
1009
1010 if let Some(token) = best_token {
1011 tracing::warn!(
1012 error = %sample_error,
1013 token,
1014 "Token sampling produced invalid weights; using greedy argmax fallback"
1015 );
1016 Ok(token)
1017 } else {
1018 Err(anyhow!(
1019 "token sampling failed and fallback argmax found no finite logits: {}",
1020 sample_error
1021 ))
1022 }
1023 }
1024 }
1025}
1026
/// Returns `true` for HTTP statuses worth retrying: 429, 502, 503 and 504.
fn is_transient_http_error(status: u16) -> bool {
    const RETRYABLE: [u16; 4] = [429, 502, 503, 504];
    RETRYABLE.contains(&status)
}
1031
1032fn is_transient_reqwest_error(e: &reqwest::Error) -> bool {
1034 e.is_timeout() || e.is_connect() || e.is_request()
1035}
1036
/// JSON body of an OpenAI-compatible chat-completions request.
#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    /// Model name to run.
    model: String,
    /// Conversation messages, in order.
    messages: Vec<OpenAIMessage>,
    /// Sampling temperature.
    temperature: f32,
    /// Nucleus-sampling cutoff; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// Maximum number of tokens to generate.
    max_tokens: usize,
    /// Whether to stream the response; this client always sends `false`.
    stream: bool,
}
1047
/// One chat message in an outgoing request.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    /// Speaker role; this client sends `"system"` and `"user"`.
    role: String,
    /// Message text.
    content: String,
}
1053
/// Subset of a chat-completions response that this client reads.
#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    /// Model name echoed by the server, if any.
    model: Option<String>,
    /// Completion choices; only the first one is used.
    choices: Vec<OpenAIChatChoice>,
    /// Token accounting; tolerated as missing.
    #[serde(default)]
    usage: Option<OpenAIUsage>,
}
1061
/// One completion choice in a response.
#[derive(Debug, Deserialize)]
struct OpenAIChatChoice {
    /// The assistant message for this choice.
    message: OpenAIChatChoiceMessage,
    /// Why generation stopped; tolerated as missing.
    #[serde(default)]
    finish_reason: Option<String>,
}
1068
/// Assistant message inside a choice. Every field is optional on the wire.
#[derive(Debug, Deserialize)]
struct OpenAIChatChoiceMessage {
    /// Main content: a string, a list of parts, or a single part.
    #[serde(default)]
    content: Option<OpenAIChatContent>,
    /// Reasoning text, used when `content` is blank.
    #[serde(default)]
    reasoning: Option<String>,
    /// Alternative reasoning field, used when both `content` and `reasoning` are blank.
    #[serde(default)]
    reasoning_content: Option<String>,
}
1078
/// Token accounting reported by the server; each count may be absent.
#[derive(Debug, Default, Deserialize)]
struct OpenAIUsage {
    /// Tokens in the prompt.
    prompt_tokens: Option<u32>,
    /// Tokens in the completion.
    completion_tokens: Option<u32>,
    /// Total tokens as reported by the server.
    total_tokens: Option<u32>,
}
1085
/// The shapes the `content` field may take. Being untagged, the variants are
/// tried in declaration order during deserialization.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAIChatContent {
    /// A plain string.
    Text(String),
    /// A list of content parts.
    Parts(Vec<OpenAIChatContentPart>),
    /// A single content part object.
    Part(OpenAIChatContentPart),
}
1093
/// One structured content part of a message.
#[derive(Debug, Deserialize)]
struct OpenAIChatContentPart {
    /// Part type (JSON key `type`); only `text` and `output_text` parts yield text.
    #[serde(rename = "type")]
    kind: Option<String>,
    /// Text payload; takes precedence over `content`.
    #[serde(default)]
    text: Option<String>,
    /// Alternative text payload, consulted only when `text` is absent.
    #[serde(default)]
    content: Option<String>,
}
1103
1104impl OpenAIChatChoiceMessage {
1105 fn extract_text(&self) -> String {
1106 let content_text = self
1107 .content
1108 .as_ref()
1109 .map(OpenAIChatContent::to_text)
1110 .unwrap_or_default();
1111 if !content_text.trim().is_empty() {
1112 return content_text;
1113 }
1114
1115 if let Some(reasoning) = self
1116 .reasoning
1117 .as_deref()
1118 .filter(|text| !text.trim().is_empty())
1119 {
1120 return reasoning.to_string();
1121 }
1122
1123 self.reasoning_content
1124 .as_deref()
1125 .filter(|text| !text.trim().is_empty())
1126 .unwrap_or_default()
1127 .to_string()
1128 }
1129}
1130
1131impl OpenAIChatContent {
1132 fn to_text(&self) -> String {
1133 match self {
1134 Self::Text(text) => text.clone(),
1135 Self::Parts(parts) => parts
1136 .iter()
1137 .filter_map(OpenAIChatContentPart::text_fragment)
1138 .collect::<Vec<_>>()
1139 .join("\n"),
1140 Self::Part(part) => part.text_fragment().unwrap_or_default(),
1141 }
1142 }
1143}
1144
1145impl OpenAIChatContentPart {
1146 fn text_fragment(&self) -> Option<String> {
1147 if let Some(kind) = self.kind.as_deref()
1148 && !kind.eq_ignore_ascii_case("text")
1149 && !kind.eq_ignore_ascii_case("output_text")
1150 {
1151 return None;
1152 }
1153
1154 self.text
1155 .as_deref()
1156 .or(self.content.as_deref())
1157 .map(str::trim)
1158 .filter(|text| !text.is_empty())
1159 .map(ToString::to_string)
1160 }
1161}