1use anyhow::{Context, Result, anyhow};
9use candle_core::quantized::gguf_file;
10use candle_core::{DType, Device, Tensor};
11use candle_transformers::generation::LogitsProcessor;
12
13#[cfg(feature = "functiongemma")]
14use candle_transformers::models::quantized_gemma3;
15use candle_transformers::models::{
16 quantized_llama, quantized_qwen2, quantized_qwen3, quantized_qwen3_moe,
17};
18use candle_transformers::utils::apply_repeat_penalty;
19use reqwest::Client;
20use serde::{Deserialize, Serialize};
21use std::collections::HashSet;
22use std::fs::File;
23use std::io::BufReader;
24use std::sync::{Arc, Mutex};
25use std::time::{Duration, Instant};
26use tokenizers::Tokenizer;
27
28use crate::provider::bedrock::{AwsCredentials, BedrockProvider};
29
/// Which backend implementation serves thinker requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkerBackend {
    OpenAICompat,
    Candle,
    Bedrock,
}

impl ThinkerBackend {
    /// Parse a backend selector from an environment-variable value.
    ///
    /// Matching is whitespace-tolerant and case-insensitive; any
    /// unrecognized value falls back to the OpenAI-compatible HTTP backend.
    pub fn from_env(value: &str) -> Self {
        let normalized = value.trim().to_ascii_lowercase();
        match normalized.as_str() {
            "candle" => Self::Candle,
            "bedrock" | "aws" | "aws_bedrock" => Self::Bedrock,
            "openai" | "openai_compat" | "openai-compatible" | "http" => Self::OpenAICompat,
            _ => Self::OpenAICompat,
        }
    }
}
63
/// CPU/CUDA selection policy for the candle backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandleDevicePreference {
    Auto,
    Cpu,
    Cuda,
}

impl CandleDevicePreference {
    /// Parse a device preference from an environment-variable value.
    ///
    /// Whitespace is trimmed and matching is case-insensitive; anything
    /// other than "cpu"/"cuda"/"gpu" yields `Auto`.
    pub fn from_env(value: &str) -> Self {
        let normalized = value.trim().to_ascii_lowercase();
        if normalized == "cpu" {
            Self::Cpu
        } else if normalized == "cuda" || normalized == "gpu" {
            Self::Cuda
        } else {
            Self::Auto
        }
    }
}
96
/// Runtime configuration for the thinker subsystem; fields are normally
/// populated from environment variables (see the `from_env` parsers).
#[derive(Debug, Clone)]
pub struct ThinkerConfig {
    /// Master switch; when false the thinker is not used.
    pub enabled: bool,
    /// Which backend implementation serves requests.
    pub backend: ThinkerBackend,
    /// Chat-completions URL for the OpenAI-compatible backend.
    pub endpoint: String,
    /// Model name (HTTP backend) or Bedrock model id.
    pub model: String,
    /// Optional bearer token for the HTTP backend.
    pub api_key: Option<String>,
    /// Sampling temperature, passed to every backend.
    pub temperature: f32,
    /// Optional nucleus-sampling cutoff (HTTP and candle backends).
    pub top_p: Option<f32>,
    /// Upper bound on generated tokens per request.
    pub max_tokens: usize,
    /// HTTP request timeout; floored to 1s when the client is built.
    pub timeout_ms: u64,
    /// Path to the GGUF weights file (candle backend only).
    pub candle_model_path: Option<String>,
    /// Path to the tokenizer file (candle backend only).
    pub candle_tokenizer_path: Option<String>,
    /// Optional architecture override; otherwise read from GGUF metadata.
    pub candle_arch: Option<String>,
    /// CPU/CUDA selection policy for candle.
    pub candle_device: CandleDevicePreference,
    /// CUDA device ordinal used when CUDA is selected or auto-probed.
    pub candle_cuda_ordinal: usize,
    /// Repetition penalty; clamped to >= 1.0 when the runtime loads.
    pub candle_repeat_penalty: f32,
    /// How many trailing tokens the repeat penalty considers (min 1).
    pub candle_repeat_last_n: usize,
    /// Base RNG seed; each request adds its request index to vary sampling.
    pub candle_seed: u64,
    /// AWS region for the Bedrock backend.
    pub bedrock_region: String,
    /// Optional Bedrock service tier, forwarded verbatim in the request.
    pub bedrock_service_tier: Option<String>,
}
128
impl Default for ThinkerConfig {
    /// Conservative defaults: disabled, OpenAI-compatible backend pointed at
    /// a local Ollama-style endpoint, modest token budget, CPU-friendly
    /// candle settings, and us-west-2 for Bedrock.
    fn default() -> Self {
        Self {
            enabled: false,
            backend: ThinkerBackend::OpenAICompat,
            endpoint: "http://127.0.0.1:11434/v1/chat/completions".to_string(),
            model: "qwen3.5-9b".to_string(),
            api_key: None,
            temperature: 0.2,
            top_p: None,
            max_tokens: 256,
            timeout_ms: 30_000,
            candle_model_path: None,
            candle_tokenizer_path: None,
            candle_arch: None,
            candle_device: CandleDevicePreference::Auto,
            candle_cuda_ordinal: 0,
            candle_repeat_penalty: 1.1,
            candle_repeat_last_n: 64,
            candle_seed: 42,
            bedrock_region: "us-west-2".to_string(),
            bedrock_service_tier: None,
        }
    }
}
154
/// Result of one thinker invocation, normalized across backends.
#[derive(Debug, Clone)]
pub struct ThinkerOutput {
    /// Label of the model that produced the text.
    pub model: String,
    /// Backend-reported stop reason ("stop"/"length" for candle), if any.
    pub finish_reason: Option<String>,
    /// The generated text.
    pub text: String,
    /// Prompt-side token count, when the backend reports it.
    pub prompt_tokens: Option<u32>,
    /// Generated token count, when the backend reports it.
    pub completion_tokens: Option<u32>,
    /// Sum of prompt and completion tokens, when both are known.
    pub total_tokens: Option<u32>,
    /// Tokens replayed from the local KV cache (candle backend only); the
    /// cfg_attr silences dead-code warnings in non-CUDA builds.
    #[cfg_attr(not(feature = "candle-cuda"), allow(dead_code))]
    pub cache_read_tokens: Option<u32>,
    /// Tokens newly written to the local KV cache (candle backend only).
    #[cfg_attr(not(feature = "candle-cuda"), allow(dead_code))]
    pub cache_write_tokens: Option<u32>,
}
186
/// Cheap-to-clone handle over one configured thinker backend.
///
/// `Clone` is manual-friendly because the backend variants hold `Arc`s
/// (candle runtime, Bedrock provider) or a `reqwest::Client` (itself an
/// internal Arc), so clones share the underlying resources.
#[derive(Clone)]
pub struct ThinkerClient {
    config: ThinkerConfig,
    backend: ThinkerClientBackend,
}
201
impl std::fmt::Debug for ThinkerClient {
    // Manual Debug: `CandleThinker` derives no Debug, and the config can
    // hold an API key, so only the backend kind and model name are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ThinkerClient")
            .field("backend", &self.config.backend)
            .field("model", &self.config.model)
            .finish()
    }
}
210
/// Backend-specific state owned by a [`ThinkerClient`].
#[derive(Clone)]
enum ThinkerClientBackend {
    /// Remote OpenAI-compatible chat-completions endpoint.
    OpenAICompat { http: Client },
    /// Local candle runtime; the Mutex serializes generation requests.
    Candle { runtime: Arc<Mutex<CandleThinker>> },
    /// AWS Bedrock converse API.
    Bedrock { provider: Arc<BedrockProvider> },
}
217
impl ThinkerClient {
    /// Build a client for the backend selected in `config`.
    ///
    /// # Errors
    /// - OpenAI-compat: HTTP client construction failed.
    /// - Candle: model/tokenizer loading failed (see [`CandleThinker::new`]).
    /// - Bedrock: AWS credentials could not be resolved, or the provider
    ///   could not be built for the configured region.
    pub fn new(config: ThinkerConfig) -> Result<Self> {
        let backend = match config.backend {
            ThinkerBackend::OpenAICompat => {
                // Floor the timeout at 1s so a zero/tiny `timeout_ms` cannot
                // make every request abort immediately.
                let timeout = Duration::from_millis(config.timeout_ms.max(1_000));
                let http = Client::builder()
                    .timeout(timeout)
                    .build()
                    .context("failed to build thinker HTTP client")?;
                ThinkerClientBackend::OpenAICompat { http }
            }
            ThinkerBackend::Candle => {
                // Loads GGUF weights + tokenizer synchronously, right here.
                let runtime = CandleThinker::new(&config)?;
                ThinkerClientBackend::Candle {
                    runtime: Arc::new(Mutex::new(runtime)),
                }
            }
            ThinkerBackend::Bedrock => {
                let creds = AwsCredentials::from_environment()
                    .ok_or_else(|| anyhow!("Bedrock thinker requires AWS credentials (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY or ~/.aws/credentials)"))?;
                let provider =
                    BedrockProvider::with_credentials(creds, config.bedrock_region.clone())?;
                ThinkerClientBackend::Bedrock {
                    provider: Arc::new(provider),
                }
            }
        };

        Ok(Self { config, backend })
    }

    /// Borrow the configuration this client was built with.
    pub fn config(&self) -> &ThinkerConfig {
        &self.config
    }

    /// Generate one "thought" for the given system/user prompt pair,
    /// dispatching to whichever backend this client wraps.
    pub async fn think(&self, system_prompt: &str, user_prompt: &str) -> Result<ThinkerOutput> {
        match &self.backend {
            ThinkerClientBackend::OpenAICompat { http } => {
                self.think_openai_compat(http, system_prompt, user_prompt)
                    .await
            }
            ThinkerClientBackend::Bedrock { provider } => {
                self.think_bedrock(provider, system_prompt, user_prompt)
                    .await
            }
            ThinkerClientBackend::Candle { runtime } => {
                // Local inference is compute-bound and synchronous, so it is
                // moved off the async executor onto a blocking thread.
                let runtime = Arc::clone(runtime);
                let system_prompt = system_prompt.to_string();
                let user_prompt = user_prompt.to_string();
                tokio::task::spawn_blocking(move || {
                    // try_lock (not lock): a concurrent request fails fast
                    // with "busy" instead of queueing behind a generation
                    // that may take many seconds.
                    let mut guard = match runtime.try_lock() {
                        Ok(g) => g,
                        Err(std::sync::TryLockError::WouldBlock) => {
                            return Err(anyhow!("candle thinker is busy"));
                        }
                        Err(std::sync::TryLockError::Poisoned(_)) => {
                            return Err(anyhow!("candle thinker mutex poisoned"));
                        }
                    };
                    guard.think(&system_prompt, &user_prompt)
                })
                .await
                .context("candle thinker task join failed")?
            }
        }
    }

    /// Call the Bedrock `converse` REST endpoint once (no retries) and map
    /// the JSON response into a [`ThinkerOutput`].
    async fn think_bedrock(
        &self,
        provider: &BedrockProvider,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let model_id = &self.config.model;

        let mut body = serde_json::json!({
            "system": [{"text": system_prompt}],
            "messages": [{
                "role": "user",
                "content": [{"text": user_prompt}]
            }],
            "inferenceConfig": {
                "maxTokens": self.config.max_tokens,
                "temperature": self.config.temperature
            }
        });

        // The optional service tier rides along as a model-specific field.
        if let Some(service_tier) = self.config.bedrock_service_tier.as_ref() {
            body["additionalModelRequestFields"] = serde_json::json!({
                "service_tier": service_tier
            });
        }

        let body_bytes = serde_json::to_vec(&body)?;
        let url = format!(
            "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
            self.config.bedrock_region, model_id
        );

        let response = provider
            .send_converse_request(&url, &body_bytes)
            .await
            .context("Bedrock thinker converse request failed")?;

        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Bedrock thinker response")?;

        if !status.is_success() {
            return Err(anyhow!(
                "Bedrock thinker error ({}): {}",
                status,
                crate::util::truncate_bytes_safe(&text, 500)
            ));
        }

        let parsed: serde_json::Value =
            serde_json::from_str(&text).context("Failed to parse Bedrock thinker response")?;

        // Only the first content block is read; assumes single-block text
        // replies — TODO confirm whether multi-block replies can occur here.
        let output_text = parsed["output"]["message"]["content"]
            .as_array()
            .and_then(|arr| arr.first())
            .and_then(|c| c["text"].as_str())
            .unwrap_or_default()
            .to_string();

        let usage = &parsed["usage"];
        let prompt_tokens = usage["inputTokens"].as_u64().map(|v| v as u32);
        let completion_tokens = usage["outputTokens"].as_u64().map(|v| v as u32);

        tracing::debug!(
            model = model_id,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = ?prompt_tokens,
            completion_tokens = ?completion_tokens,
            "bedrock thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: model_id.clone(),
            finish_reason: parsed["stopReason"].as_str().map(|s| s.to_string()),
            text: output_text,
            prompt_tokens,
            completion_tokens,
            // `zip` yields None if either side is missing, so a total is
            // only reported when both counts are present.
            total_tokens: prompt_tokens.zip(completion_tokens).map(|(p, c)| p + c),
            cache_read_tokens: None,
            cache_write_tokens: None,
        })
    }

    /// POST an OpenAI-style chat completion, retrying transient failures
    /// (timeouts/connect errors and 429/502/503/504) up to two attempts
    /// with a 500ms backoff before the second try.
    async fn think_openai_compat(
        &self,
        http: &Client,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let body = OpenAIChatRequest {
            model: self.config.model.clone(),
            messages: vec![
                OpenAIMessage {
                    role: "system".to_string(),
                    content: system_prompt.to_string(),
                },
                OpenAIMessage {
                    role: "user".to_string(),
                    content: user_prompt.to_string(),
                },
            ],
            temperature: self.config.temperature,
            top_p: self.config.top_p,
            max_tokens: self.config.max_tokens,
            stream: false,
        };

        let max_attempts: u32 = 2;
        let mut last_err: Option<anyhow::Error> = None;

        for attempt in 0..max_attempts {
            if attempt > 0 {
                // Linear backoff: 500ms before attempt 1 (the only retry).
                tokio::time::sleep(Duration::from_millis(500 * attempt as u64)).await;
                tracing::debug!(attempt, "retrying thinker HTTP request");
            }

            let mut request = http.post(&self.config.endpoint).json(&body);
            if let Some(key) = self.config.api_key.as_ref() {
                request = request.bearer_auth(key);
            }

            let response = match request.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    if is_transient_reqwest_error(&e) {
                        tracing::warn!(attempt, error = %e, "thinker HTTP request failed (transient)");
                        last_err =
                            Some(anyhow::Error::from(e).context("transient thinker send error"));
                        continue;
                    }
                    return Err(anyhow::Error::from(e).context("non-transient thinker send error"));
                }
            };

            let status = response.status();
            if is_transient_http_error(status.as_u16()) {
                let body_text = response.text().await.unwrap_or_default();
                tracing::warn!(attempt, status = %status, "thinker received transient HTTP error");
                last_err = Some(anyhow!(
                    "thinker request failed with status {}: {}",
                    status,
                    body_text
                ));
                continue;
            }

            // Non-transient HTTP failure: surface immediately, no retry.
            if !status.is_success() {
                let body_text = response
                    .text()
                    .await
                    .unwrap_or_else(|_| "<empty>".to_string());
                return Err(anyhow!(
                    "thinker request failed with status {}: {}",
                    status,
                    body_text
                ));
            }

            let payload: OpenAIChatResponse = response
                .json()
                .await
                .context("failed to decode thinker response")?;
            let choice = payload
                .choices
                .first()
                .ok_or_else(|| anyhow!("thinker response did not include choices"))?;
            let text = choice.message.extract_text();
            let usage = payload.usage.unwrap_or_default();

            let output = ThinkerOutput {
                model: payload.model.unwrap_or_else(|| self.config.model.clone()),
                finish_reason: choice.finish_reason.clone(),
                text,
                prompt_tokens: usage.prompt_tokens,
                completion_tokens: usage.completion_tokens,
                total_tokens: usage.total_tokens,
                cache_read_tokens: None,
                cache_write_tokens: None,
            };

            tracing::debug!(
                model = %output.model,
                latency_ms = started_at.elapsed().as_millis(),
                prompt_tokens = ?output.prompt_tokens,
                completion_tokens = ?output.completion_tokens,
                attempt,
                "openai-compat thinker generated thought"
            );

            return Ok(output);
        }

        // All attempts consumed; report the last transient error seen.
        Err(last_err.unwrap_or_else(|| {
            anyhow!("thinker HTTP request failed after {max_attempts} attempts")
        }))
    }
}
521
/// Synchronous local-inference runtime backing [`ThinkerBackend::Candle`].
///
/// Holds the loaded quantized model plus generation settings, and caches
/// the previous request's token sequence so a prompt that extends it can
/// reuse the model's KV cache.
pub(crate) struct CandleThinker {
    /// Architecture-specific quantized model weights.
    model: CandleModel,
    tokenizer: Tokenizer,
    /// Device tensors are created on (CPU or CUDA).
    device: Device,
    /// "candle:<arch>:<device>@<path>" label reported in outputs.
    model_label: String,
    /// Normalized lowercase architecture string (config override or GGUF).
    architecture: String,
    /// Max sequence length; prompts are truncated to fit under this.
    context_window: usize,
    temperature: f32,
    top_p: Option<f32>,
    /// Per-request cap on generated tokens (at least 1).
    max_tokens: usize,
    /// Repetition penalty; 1.0 disables it (clamped to >= 1.0 on load).
    repeat_penalty: f32,
    /// Window of trailing tokens the repeat penalty applies to.
    repeat_last_n: usize,
    /// Base sampling seed; combined with `request_index` per request.
    seed: u64,
    /// Monotonic counter so each request samples with a different seed.
    request_index: u64,
    /// Token ids that terminate generation.
    eos_token_ids: HashSet<u32>,
    /// Prompt+completion tokens of the last request, for prefix reuse.
    cached_tokens: Vec<u32>,
}
539
/// Architecture-specific quantized model wrapper.
enum CandleModel {
    Llama(quantized_llama::ModelWeights),
    Qwen2(quantized_qwen2::ModelWeights),
    Qwen3(quantized_qwen3::ModelWeights),
    Qwen3Moe(quantized_qwen3_moe::GGUFQWenMoE),

    // Gemma support is optional to keep default builds smaller.
    #[cfg(feature = "functiongemma")]
    Gemma3(quantized_gemma3::ModelWeights),
}
549
impl CandleModel {
    /// Run one forward pass over `x` with the sequence offset `index_pos`,
    /// returning the model's logits tensor.
    fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        match self {
            Self::Llama(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen2(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen3(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen3Moe(model) => Ok(model.forward(x, index_pos)?),

            #[cfg(feature = "functiongemma")]
            Self::Gemma3(model) => Ok(model.forward(x, index_pos)?),
        }
    }

    /// Prepare the KV cache for an unrelated prompt.
    ///
    /// Only qwen3 exposes an explicit clear. Llama/qwen2 (and gemma3)
    /// return Ok(()) without clearing — presumably those runtimes reset
    /// internally when `forward` is next called with `index_pos == 0`;
    /// verify against the candle implementations. qwen3_moe has no reset
    /// path at all and reports an error instead.
    fn reset_kv_cache_for_new_request(&mut self) -> Result<()> {
        match self {
            Self::Qwen3(model) => {
                model.clear_kv_cache();
                Ok(())
            }
            Self::Qwen3Moe(_) => Err(anyhow!(
                "qwen3_moe runtime cannot reset KV cache in this build; restart local runtime or use qwen3"
            )),
            Self::Llama(_) | Self::Qwen2(_) => Ok(()),

            #[cfg(feature = "functiongemma")]
            Self::Gemma3(_) => Ok(()),
        }
    }

    /// Whether a prompt that extends the previously cached token prefix may
    /// reuse the existing KV cache. Currently unconditional for all arches.
    fn can_extend_cached_prefix(&self) -> bool {
        true
    }
}
585
impl CandleThinker {
    /// Load the GGUF model and tokenizer named in `config` and build a
    /// ready-to-generate runtime.
    ///
    /// # Errors
    /// Missing model/tokenizer paths, unreadable or unparsable GGUF,
    /// unsupported architecture, or device initialization failure.
    pub(crate) fn new(config: &ThinkerConfig) -> Result<Self> {
        let model_path = config.candle_model_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_MODEL_PATH")
        })?;
        let tokenizer_path = config.candle_tokenizer_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_TOKENIZER_PATH")
        })?;

        let (device, device_label) = select_candle_device(config)?;
        let mut reader = BufReader::new(
            File::open(model_path)
                .with_context(|| format!("failed to open candle model file at {}", model_path))?,
        );
        let content = gguf_file::Content::read(&mut reader)
            .with_context(|| format!("failed to parse gguf model metadata from {}", model_path))?;

        // Explicit config override wins; otherwise trust GGUF metadata,
        // defaulting to llama when neither is present.
        let architecture = config
            .candle_arch
            .clone()
            .or_else(|| {
                content
                    .metadata
                    .get("general.architecture")
                    .and_then(|v| v.to_string().ok())
                    .cloned()
            })
            .unwrap_or_else(|| "llama".to_string())
            .to_ascii_lowercase();

        let context_window = detect_context_window(&content, &architecture).unwrap_or(4096);
        let model_label = format!("candle:{}:{}@{}", architecture, device_label, model_path);

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer from {}: {}", tokenizer_path, e))?;

        // Read EOS ids from metadata before `content` is consumed below.
        let gguf_eos_ids = extract_gguf_eos_ids(&content);

        let model = match architecture.as_str() {
            "llama" => CandleModel::Llama(
                quantized_llama::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load llama gguf from {}", model_path))?,
            ),
            "qwen2" => CandleModel::Qwen2(
                quantized_qwen2::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen2 gguf from {}", model_path))?,
            ),
            "qwen3" => CandleModel::Qwen3(
                quantized_qwen3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen3 gguf from {}", model_path))?,
            ),
            "qwen3moe" | "qwen3_moe" => CandleModel::Qwen3Moe(
                quantized_qwen3_moe::GGUFQWenMoE::from_gguf(
                    content,
                    &mut reader,
                    &device,
                    DType::F16,
                )
                .with_context(|| format!("failed to load qwen3_moe gguf from {}", model_path))?,
            ),

            #[cfg(feature = "functiongemma")]
            "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => CandleModel::Gemma3(
                quantized_gemma3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load gemma3 gguf from {}", model_path))?,
            ),
            other => {
                // Distinguish "gemma without the feature" from a genuinely
                // unknown architecture for a more actionable error message.
                #[cfg(not(feature = "functiongemma"))]
                if matches!(other, "gemma" | "gemma2" | "gemma3" | "gemma-embedding") {
                    return Err(anyhow!(
                        "gemma architecture '{}' requires the 'functiongemma' feature; rebuild with --features functiongemma",
                        other
                    ));
                }
                return Err(anyhow!(
                    "unsupported candle architecture '{}' (supported: llama, qwen2, qwen3, qwen3_moe{})",
                    other,
                    if cfg!(feature = "functiongemma") {
                        ", gemma/gemma2/gemma3"
                    } else {
                        ""
                    }
                ));
            }
        };

        let eos_token_ids: HashSet<u32> = collect_eos_token_ids(&tokenizer, &gguf_eos_ids);
        if eos_token_ids.is_empty() {
            tracing::warn!(
                "No EOS tokens found in tokenizer; generation will stop on max token limit"
            );
        }

        Ok(Self {
            model,
            tokenizer,
            device,
            model_label,
            architecture,
            context_window,
            temperature: config.temperature,
            top_p: config.top_p,
            // Clamp to sane minimums so pathological config cannot wedge
            // generation or the repeat-penalty window.
            max_tokens: config.max_tokens.max(1),
            repeat_penalty: config.candle_repeat_penalty.max(1.0),
            repeat_last_n: config.candle_repeat_last_n.max(1),
            seed: config.candle_seed,
            request_index: 0,
            eos_token_ids,
            cached_tokens: Vec::new(),
        })
    }

    /// Generate from an already-formatted prompt (no chat template applied).
    #[cfg(feature = "functiongemma")]
    pub(crate) fn think_raw(&mut self, raw_prompt: &str) -> Result<ThinkerOutput> {
        self.think_inner(raw_prompt)
    }

    /// Wrap the prompts in the architecture's chat template and generate.
    pub(crate) fn think(
        &mut self,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let prompt = format_chat_prompt(&self.architecture, system_prompt, user_prompt);
        self.think_inner(&prompt)
    }

    /// Core generation loop: tokenize, prefill (reusing the KV cache when
    /// the prompt extends the previous request), then sample token-by-token
    /// until EOS, `max_tokens`, or the context window is reached.
    fn think_inner(&mut self, prompt: &str) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let encoding = self
            .tokenizer
            .encode(prompt, true)
            .map_err(|e| anyhow!("tokenizer encode failed: {}", e))?;
        let mut tokens = encoding.get_ids().to_vec();
        if tokens.is_empty() {
            return Err(anyhow!("tokenizer produced an empty prompt token set"));
        }

        // Keep the tail of an oversized prompt, reserving 8 slots of the
        // context window for generation headroom.
        if self.context_window > 8 && tokens.len() >= self.context_window {
            let budget = self.context_window.saturating_sub(8);
            tokens = tokens[tokens.len().saturating_sub(budget)..].to_vec();
        }
        let prompt_token_count = tokens.len() as u32;

        // Per-request seed: deterministic for a given (seed, request_index)
        // pair, but different across consecutive requests.
        let request_seed = self.seed.wrapping_add(self.request_index);
        self.request_index = self.request_index.wrapping_add(1);
        let mut logits_processor = LogitsProcessor::new(
            request_seed,
            Some(self.temperature as f64),
            self.top_p.map(|v| v as f64),
        );

        let mut cache_read_tokens = 0u32;
        let mut cache_write_tokens = 0u32;

        // Prefix reuse: only when the new prompt strictly extends the token
        // sequence left in the KV cache by the previous request.
        let can_extend_prefix = self.model.can_extend_cached_prefix()
            && !self.cached_tokens.is_empty()
            && tokens.len() > self.cached_tokens.len()
            && tokens.starts_with(&self.cached_tokens);

        let mut index_pos = if can_extend_prefix {
            cache_read_tokens = self.cached_tokens.len() as u32;
            self.cached_tokens.len()
        } else {
            if !self.cached_tokens.is_empty() {
                self.model.reset_kv_cache_for_new_request()?;
            }
            0
        };

        let prefill = if index_pos == 0 {
            tokens.as_slice()
        } else {
            &tokens[index_pos..]
        };
        // Defensive: an empty prefill (should not occur given the strict
        // "extends" check above) falls back to a full re-prefill.
        if prefill.is_empty() {
            self.model.reset_kv_cache_for_new_request()?;
            index_pos = 0;
        }

        // Recomputed because `index_pos` may have been reset just above.
        let prefill = if index_pos == 0 {
            tokens.as_slice()
        } else {
            &tokens[index_pos..]
        };
        cache_write_tokens = cache_write_tokens.saturating_add(prefill.len() as u32);

        let input = Tensor::new(prefill, &self.device)?
            .unsqueeze(0)
            .context("failed to create candle input tensor")?;
        let mut logits = self
            .model
            .forward(&input, index_pos)
            .context("candle model forward failed")?;
        index_pos += prefill.len();
        logits = logits
            .squeeze(0)
            .context("failed to squeeze logits batch dimension")?;

        let mut generated: Vec<u32> = Vec::with_capacity(self.max_tokens);
        let mut finish_reason = "length".to_string();

        for _ in 0..self.max_tokens {
            let sampling_logits = if self.repeat_penalty > 1.0 {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                apply_repeat_penalty(&logits, self.repeat_penalty, &tokens[start_at..])
                    .context("failed to apply repeat penalty")?
            } else {
                logits.clone()
            };

            let next_token =
                sample_next_token_with_fallback(&mut logits_processor, &sampling_logits)?;
            if self.eos_token_ids.contains(&next_token) {
                finish_reason = "stop".to_string();
                break;
            }

            tokens.push(next_token);
            generated.push(next_token);
            cache_write_tokens = cache_write_tokens.saturating_add(1);

            // NOTE(review): this break leaves `next_token` in `tokens`
            // without ever forwarding it, so the KV cache holds one entry
            // fewer than `cached_tokens` records below — the next request's
            // prefix-reuse bookkeeping may then be off by one. Verify.
            if tokens.len() + 1 >= self.context_window {
                finish_reason = "length".to_string();
                break;
            }

            let input = Tensor::new(&tokens[tokens.len() - 1..], &self.device)?
                .unsqueeze(0)
                .context("failed to create candle input tensor")?;
            logits = self
                .model
                .forward(&input, index_pos)
                .context("candle model forward failed")?;
            index_pos += 1;
            logits = logits
                .squeeze(0)
                .context("failed to squeeze logits batch dimension")?;
        }

        let text = self
            .tokenizer
            .decode(&generated, true)
            .map_err(|e| anyhow!("tokenizer decode failed: {}", e))?;
        let completion_tokens = generated.len() as u32;
        // Remember the full sequence for prefix reuse on the next request.
        self.cached_tokens = tokens;

        tracing::debug!(
            model = %self.model_label,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = prompt_token_count,
            completion_tokens = completion_tokens,
            cache_read_tokens = cache_read_tokens,
            cache_write_tokens = cache_write_tokens,
            "candle thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: self.model_label.clone(),
            finish_reason: Some(finish_reason),
            text,
            prompt_tokens: Some(prompt_token_count),
            completion_tokens: Some(completion_tokens),
            total_tokens: Some(prompt_token_count + completion_tokens),
            cache_read_tokens: Some(cache_read_tokens),
            cache_write_tokens: Some(cache_write_tokens),
        })
    }
}
858
/// Render the system/user prompts into the chat template that matches the
/// model family, ready for single-turn generation.
///
/// Unknown architectures get a plain "System:/User:/Assistant:" layout.
fn format_chat_prompt(architecture: &str, system_prompt: &str, user_prompt: &str) -> String {
    match architecture {
        // ChatML-style template shared by the qwen families.
        "qwen2" | "qwen3" | "qwen3moe" | "qwen3_moe" => format!(
            "<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
        ),
        // Llama 3 header-token template.
        "llama" => format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        ),
        // Gemma has no system role; the system text is folded into the
        // user turn.
        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => format!(
            "<start_of_turn>user\n{system_prompt}\n\n{user_prompt}<end_of_turn>\n<start_of_turn>model\n"
        ),
        _ => format!(
            "System:\n{system_prompt}\n\nUser:\n{user_prompt}\n\nAssistant:\n"
        ),
    }
}
888
889fn select_candle_device(config: &ThinkerConfig) -> Result<(Device, String)> {
890 match config.candle_device {
891 CandleDevicePreference::Cpu => Ok((Device::Cpu, "cpu".to_string())),
892 CandleDevicePreference::Cuda => {
893 let device = try_cuda_device(config.candle_cuda_ordinal)?;
894 Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
895 }
896 CandleDevicePreference::Auto => match try_cuda_device(config.candle_cuda_ordinal) {
897 Ok(device) => {
898 tracing::info!(
899 ordinal = config.candle_cuda_ordinal,
900 "Candle thinker selected CUDA device"
901 );
902 Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
903 }
904 Err(error) => {
905 tracing::warn!(
906 %error,
907 "CUDA unavailable for Candle thinker, falling back to CPU"
908 );
909 Ok((Device::Cpu, "cpu".to_string()))
910 }
911 },
912 }
913}
914
/// CUDA-enabled build: initialize the CUDA device at `ordinal`.
#[cfg(feature = "candle-cuda")]
fn try_cuda_device(ordinal: usize) -> Result<Device> {
    Device::new_cuda(ordinal)
        .with_context(|| format!("failed to initialize CUDA device ordinal {}", ordinal))
}
920
/// CPU-only build: always fail so callers fall back (or error) cleanly.
#[cfg(not(feature = "candle-cuda"))]
fn try_cuda_device(_ordinal: usize) -> Result<Device> {
    Err(anyhow!(
        "candle-cuda feature is not enabled in this build; rebuild with --features candle-cuda"
    ))
}
927
928fn detect_context_window(content: &gguf_file::Content, architecture: &str) -> Option<usize> {
929 let key = match architecture {
930 "qwen2" => "qwen2.context_length",
931 "qwen3" | "qwen3moe" | "qwen3_moe" => "qwen3.context_length",
932 "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => {
933 for prefix in ["gemma3", "gemma2", "gemma"] {
935 let k = format!("{prefix}.context_length");
936 if let Some(v) = content.metadata.get(&k) {
937 return v.to_u32().ok().map(|v| v as usize);
938 }
939 }
940 return None;
941 }
942 _ => "llama.context_length",
943 };
944 content
945 .metadata
946 .get(key)
947 .and_then(|v| v.to_u32().ok())
948 .map(|v| v as usize)
949}
950
951fn extract_gguf_eos_ids(content: &gguf_file::Content) -> Vec<u32> {
953 let mut ids = Vec::new();
954 for key in ["tokenizer.ggml.eos_token_id", "tokenizer.ggml.eot_token_id"] {
955 if let Some(v) = content.metadata.get(key)
956 && let Ok(id) = v.to_u32()
957 && !ids.contains(&id)
958 {
959 ids.push(id);
960 }
961 }
962 ids
963}
964
965fn collect_eos_token_ids(tokenizer: &Tokenizer, gguf_eos_ids: &[u32]) -> HashSet<u32> {
966 let mut ids: HashSet<u32> = gguf_eos_ids.iter().copied().collect();
967
968 let candidates = [
970 "<|im_end|>",
971 "<|eot_id|>",
972 "<|endoftext|>",
973 "</s>",
974 "<|end|>",
975 "<end_of_turn>",
976 ];
977 for token in candidates {
978 if let Some(id) = tokenizer.token_to_id(token) {
979 ids.insert(id);
980 }
981 }
982 ids
983}
984
985fn sample_next_token_with_fallback(
986 logits_processor: &mut LogitsProcessor,
987 logits: &Tensor,
988) -> Result<u32> {
989 match logits_processor.sample(logits) {
990 Ok(token) => Ok(token),
991 Err(sample_error) => {
992 let logits_vec = logits
993 .to_vec1::<f32>()
994 .context("token sampling failed and fallback logits extraction failed")?;
995 let mut best_token = None;
996 let mut best_logit = f32::NEG_INFINITY;
997
998 for (idx, logit) in logits_vec.into_iter().enumerate() {
999 if !logit.is_finite() {
1000 continue;
1001 }
1002 if best_token.is_none() || logit > best_logit {
1003 best_token = Some(idx as u32);
1004 best_logit = logit;
1005 }
1006 }
1007
1008 if let Some(token) = best_token {
1009 tracing::warn!(
1010 error = %sample_error,
1011 token,
1012 "Token sampling produced invalid weights; using greedy argmax fallback"
1013 );
1014 Ok(token)
1015 } else {
1016 Err(anyhow!(
1017 "token sampling failed and fallback argmax found no finite logits: {}",
1018 sample_error
1019 ))
1020 }
1021 }
1022 }
1023}
1024
/// True for HTTP statuses worth retrying: 429 (rate limited) and the
/// 502/503/504 gateway errors. Everything else is treated as final.
fn is_transient_http_error(status: u16) -> bool {
    status == 429 || (502..=504).contains(&status)
}
1029
1030fn is_transient_reqwest_error(e: &reqwest::Error) -> bool {
1032 e.is_timeout() || e.is_connect() || e.is_request()
1033}
1034
/// Request body for an OpenAI-compatible chat-completions call.
#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    temperature: f32,
    // Omitted from JSON when unset, since some servers reject null.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    max_tokens: usize,
    // Always false: streaming is not used by the thinker.
    stream: bool,
}
1045
/// A single chat message ("system"/"user") in the outbound request.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}
1051
/// Top-level chat-completions response; only the fields the thinker reads.
#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    model: Option<String>,
    choices: Vec<OpenAIChatChoice>,
    // Tolerate servers that omit usage accounting.
    #[serde(default)]
    usage: Option<OpenAIUsage>,
}
1059
/// One completion candidate; the thinker only uses the first.
#[derive(Debug, Deserialize)]
struct OpenAIChatChoice {
    message: OpenAIChatChoiceMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}
1066
/// Assistant message payload. `reasoning`/`reasoning_content` are
/// non-standard fields some reasoning-capable servers emit; they serve as
/// fallbacks when `content` is empty (see `extract_text`).
#[derive(Debug, Deserialize)]
struct OpenAIChatChoiceMessage {
    #[serde(default)]
    content: Option<OpenAIChatContent>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
}
1076
/// Token accounting block; every field optional since servers vary.
#[derive(Debug, Default, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    total_tokens: Option<u32>,
}
1083
/// Message content, which servers send either as a plain string, a list of
/// typed parts, or a single part object. Untagged: serde tries the
/// variants in declaration order.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAIChatContent {
    Text(String),
    Parts(Vec<OpenAIChatContentPart>),
    Part(OpenAIChatContentPart),
}
1091
/// One typed content part; text may live under either `text` or `content`.
#[derive(Debug, Deserialize)]
struct OpenAIChatContentPart {
    // "text" / "output_text" are accepted; other kinds are skipped.
    #[serde(rename = "type")]
    kind: Option<String>,
    #[serde(default)]
    text: Option<String>,
    #[serde(default)]
    content: Option<String>,
}
1101
1102impl OpenAIChatChoiceMessage {
1103 fn extract_text(&self) -> String {
1104 let content_text = self
1105 .content
1106 .as_ref()
1107 .map(OpenAIChatContent::to_text)
1108 .unwrap_or_default();
1109 if !content_text.trim().is_empty() {
1110 return content_text;
1111 }
1112
1113 if let Some(reasoning) = self
1114 .reasoning
1115 .as_deref()
1116 .filter(|text| !text.trim().is_empty())
1117 {
1118 return reasoning.to_string();
1119 }
1120
1121 self.reasoning_content
1122 .as_deref()
1123 .filter(|text| !text.trim().is_empty())
1124 .unwrap_or_default()
1125 .to_string()
1126 }
1127}
1128
1129impl OpenAIChatContent {
1130 fn to_text(&self) -> String {
1131 match self {
1132 Self::Text(text) => text.clone(),
1133 Self::Parts(parts) => parts
1134 .iter()
1135 .filter_map(OpenAIChatContentPart::text_fragment)
1136 .collect::<Vec<_>>()
1137 .join("\n"),
1138 Self::Part(part) => part.text_fragment().unwrap_or_default(),
1139 }
1140 }
1141}
1142
1143impl OpenAIChatContentPart {
1144 fn text_fragment(&self) -> Option<String> {
1145 if let Some(kind) = self.kind.as_deref()
1146 && !kind.eq_ignore_ascii_case("text")
1147 && !kind.eq_ignore_ascii_case("output_text")
1148 {
1149 return None;
1150 }
1151
1152 self.text
1153 .as_deref()
1154 .or(self.content.as_deref())
1155 .map(str::trim)
1156 .filter(|text| !text.is_empty())
1157 .map(ToString::to_string)
1158 }
1159}