use anyhow::{Context, Result, anyhow};
use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
#[cfg(feature = "functiongemma")]
use candle_transformers::models::quantized_gemma3;
use candle_transformers::models::{quantized_llama, quantized_qwen2};
use candle_transformers::utils::apply_repeat_penalty;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::BufReader;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokenizers::Tokenizer;

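/// Inference backend used by the thinker: a remote OpenAI-compatible HTTP
/// endpoint or a locally loaded Candle model.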
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkerBackend {
    OpenAICompat,
    Candle,
}

impl ThinkerBackend {
    pub fn from_env(value: &str) -> Self {
        match value.trim().to_ascii_lowercase().as_str() {
            "candle" => Self::Candle,
            "openai" | "openai_compat" | "openai-compatible" | "http" => Self::OpenAICompat,
            _ => Self::OpenAICompat,
        }
    }
}

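/// Device preference for the Candle backend; `Auto` tries CUDA first and
/// falls back to CPU.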
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandleDevicePreference {
    Auto,
    Cpu,
    Cuda,
}

impl CandleDevicePreference {
    pub fn from_env(value: &str) -> Self {
        match value.trim().to_ascii_lowercase().as_str() {
            "cpu" => Self::Cpu,
            "cuda" | "gpu" => Self::Cuda,
            _ => Self::Auto,
        }
    }
}

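/// Runtime configuration for the thinker, covering both the OpenAI-compatible
/// HTTP backend and the local Candle backend.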
#[derive(Debug, Clone)]
pub struct ThinkerConfig {
    pub enabled: bool,
    pub backend: ThinkerBackend,
    pub endpoint: String,
    pub model: String,
    pub api_key: Option<String>,
    pub temperature: f32,
    pub top_p: Option<f32>,
    pub max_tokens: usize,
    pub timeout_ms: u64,
    pub candle_model_path: Option<String>,
    pub candle_tokenizer_path: Option<String>,
    pub candle_arch: Option<String>,
    pub candle_device: CandleDevicePreference,
    pub candle_cuda_ordinal: usize,
    pub candle_repeat_penalty: f32,
    pub candle_repeat_last_n: usize,
    pub candle_seed: u64,
}

impl Default for ThinkerConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            backend: ThinkerBackend::OpenAICompat,
            endpoint: "http://127.0.0.1:11434/v1/chat/completions".to_string(),
            model: "qwen2.5:3b-instruct".to_string(),
            api_key: None,
            temperature: 0.2,
            top_p: None,
            max_tokens: 256,
            timeout_ms: 12_000,
            candle_model_path: None,
            candle_tokenizer_path: None,
            candle_arch: None,
            candle_device: CandleDevicePreference::Auto,
            candle_cuda_ordinal: 0,
            candle_repeat_penalty: 1.1,
            candle_repeat_last_n: 64,
            candle_seed: 42,
        }
    }
}

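/// Result of a single thinker invocation: the generated text plus token usage
/// when the backend reports it.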
#[derive(Debug, Clone)]
pub struct ThinkerOutput {
    pub model: String,
    pub finish_reason: Option<String>,
    pub text: String,
    pub prompt_tokens: Option<u32>,
    pub completion_tokens: Option<u32>,
    pub total_tokens: Option<u32>,
}

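/// Client that produces a "thought" from a system/user prompt pair via the
/// configured backend.
///
/// A minimal usage sketch (not compiled as a doctest), assuming an async
/// context and that the default OpenAI-compatible endpoint is reachable:
///
/// ```ignore
/// let config = ThinkerConfig {
///     enabled: true,
///     ..ThinkerConfig::default()
/// };
/// let client = ThinkerClient::new(config)?;
/// let output = client
///     .think("You are a concise planner.", "Outline the next step.")
///     .await?;
/// println!("{}", output.text);
/// ```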
#[derive(Clone)]
pub struct ThinkerClient {
    config: ThinkerConfig,
    backend: ThinkerClientBackend,
}

impl std::fmt::Debug for ThinkerClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ThinkerClient")
            .field("backend", &self.config.backend)
            .field("model", &self.config.model)
            .finish()
    }
}

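/// Backend-specific state held by `ThinkerClient`.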
#[derive(Clone)]
enum ThinkerClientBackend {
    OpenAICompat { http: Client },
    Candle { runtime: Arc<Mutex<CandleThinker>> },
}

impl ThinkerClient {
    pub fn new(config: ThinkerConfig) -> Result<Self> {
        let backend = match config.backend {
            ThinkerBackend::OpenAICompat => {
                let timeout = Duration::from_millis(config.timeout_ms.max(1_000));
                let http = Client::builder()
                    .timeout(timeout)
                    .build()
                    .context("failed to build thinker HTTP client")?;
                ThinkerClientBackend::OpenAICompat { http }
            }
            ThinkerBackend::Candle => {
                let runtime = CandleThinker::new(&config)?;
                ThinkerClientBackend::Candle {
                    runtime: Arc::new(Mutex::new(runtime)),
                }
            }
        };

        Ok(Self { config, backend })
    }

    pub fn config(&self) -> &ThinkerConfig {
        &self.config
    }

    pub async fn think(&self, system_prompt: &str, user_prompt: &str) -> Result<ThinkerOutput> {
        match &self.backend {
            ThinkerClientBackend::OpenAICompat { http } => {
                self.think_openai_compat(http, system_prompt, user_prompt)
                    .await
            }
            ThinkerClientBackend::Candle { runtime } => {
                let runtime = Arc::clone(runtime);
                let system_prompt = system_prompt.to_string();
                let user_prompt = user_prompt.to_string();
                tokio::task::spawn_blocking(move || {
                    let mut guard = runtime
                        .lock()
                        .map_err(|_| anyhow!("candle thinker mutex poisoned"))?;
                    guard.think(&system_prompt, &user_prompt)
                })
                .await
                .context("candle thinker task join failed")?
            }
        }
    }

    async fn think_openai_compat(
        &self,
        http: &Client,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let body = OpenAIChatRequest {
            model: self.config.model.clone(),
            messages: vec![
                OpenAIMessage {
                    role: "system".to_string(),
                    content: system_prompt.to_string(),
                },
                OpenAIMessage {
                    role: "user".to_string(),
                    content: user_prompt.to_string(),
                },
            ],
            temperature: self.config.temperature,
            top_p: self.config.top_p,
            max_tokens: self.config.max_tokens,
            stream: false,
        };

        let mut request = http.post(&self.config.endpoint).json(&body);
        if let Some(key) = self.config.api_key.as_ref() {
            request = request.bearer_auth(key);
        }

        let response = request
            .send()
            .await
            .context("thinker request failed to send")?;
        if !response.status().is_success() {
            let status = response.status();
            let body_text = response
                .text()
                .await
                .unwrap_or_else(|_| "<empty>".to_string());
            return Err(anyhow!(
                "thinker request failed with status {}: {}",
                status,
                body_text
            ));
        }

        let payload: OpenAIChatResponse = response
            .json()
            .await
            .context("failed to decode thinker response")?;
        let choice = payload
            .choices
            .first()
            .ok_or_else(|| anyhow!("thinker response did not include choices"))?;
        let text = choice.message.extract_text();
        let usage = payload.usage.unwrap_or_default();

        Ok(ThinkerOutput {
            model: payload.model.unwrap_or_else(|| self.config.model.clone()),
            finish_reason: choice.finish_reason.clone(),
            text,
            prompt_tokens: usage.prompt_tokens,
            completion_tokens: usage.completion_tokens,
            total_tokens: usage.total_tokens,
        })
    }
}

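/// Local GGUF-backed generator built on Candle; runs quantized llama, qwen2,
/// or (behind the `functiongemma` feature) gemma3 weights.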
pub(crate) struct CandleThinker {
    model: CandleModel,
    tokenizer: Tokenizer,
    device: Device,
    model_label: String,
    context_window: usize,
    temperature: f32,
    top_p: Option<f32>,
    max_tokens: usize,
    repeat_penalty: f32,
    repeat_last_n: usize,
    seed: u64,
    request_index: u64,
    eos_token_ids: Vec<u32>,
}

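/// Architecture-specific quantized weights, dispatched per forward pass.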
enum CandleModel {
    Llama(quantized_llama::ModelWeights),
    Qwen2(quantized_qwen2::ModelWeights),
    #[cfg(feature = "functiongemma")]
    Gemma3(quantized_gemma3::ModelWeights),
}

impl CandleModel {
    fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        match self {
            Self::Llama(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen2(model) => Ok(model.forward(x, index_pos)?),
            #[cfg(feature = "functiongemma")]
            Self::Gemma3(model) => Ok(model.forward(x, index_pos)?),
        }
    }
}

impl CandleThinker {
    pub(crate) fn new(config: &ThinkerConfig) -> Result<Self> {
        let model_path = config.candle_model_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_MODEL_PATH")
        })?;
        let tokenizer_path = config.candle_tokenizer_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_TOKENIZER_PATH")
        })?;

        let (device, device_label) = select_candle_device(config)?;
        let mut reader = BufReader::new(
            File::open(model_path)
                .with_context(|| format!("failed to open candle model file at {}", model_path))?,
        );
        let content = gguf_file::Content::read(&mut reader)
            .with_context(|| format!("failed to parse gguf model metadata from {}", model_path))?;

        let architecture = config
            .candle_arch
            .clone()
            .or_else(|| {
                content
                    .metadata
                    .get("general.architecture")
                    .and_then(|v| v.to_string().ok())
                    .cloned()
            })
            .unwrap_or_else(|| "llama".to_string())
            .to_ascii_lowercase();

        let context_window = detect_context_window(&content, &architecture).unwrap_or(4096);
        let model_label = format!("candle:{}:{}@{}", architecture, device_label, model_path);

        let model = match architecture.as_str() {
            "llama" => CandleModel::Llama(
                quantized_llama::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load llama gguf from {}", model_path))?,
            ),
            "qwen2" => CandleModel::Qwen2(
                quantized_qwen2::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen2 gguf from {}", model_path))?,
            ),
            #[cfg(feature = "functiongemma")]
            "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => CandleModel::Gemma3(
                quantized_gemma3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load gemma3 gguf from {}", model_path))?,
            ),
            other => {
                #[cfg(not(feature = "functiongemma"))]
                if matches!(other, "gemma" | "gemma2" | "gemma3" | "gemma-embedding") {
                    return Err(anyhow!(
                        "gemma architecture '{}' requires the 'functiongemma' feature; rebuild with --features functiongemma",
                        other
                    ));
                }
                return Err(anyhow!(
                    "unsupported candle architecture '{}' (supported: llama, qwen2{})",
                    other,
                    if cfg!(feature = "functiongemma") {
                        ", gemma/gemma2/gemma3"
                    } else {
                        ""
                    }
                ));
            }
        };

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer from {}: {}", tokenizer_path, e))?;

        let eos_token_ids = collect_eos_token_ids(&tokenizer);
        if eos_token_ids.is_empty() {
            tracing::warn!(
                "No EOS tokens found in tokenizer; generation will stop on max token limit"
            );
        }

        Ok(Self {
            model,
            tokenizer,
            device,
            model_label,
            context_window,
            temperature: config.temperature,
            top_p: config.top_p,
            max_tokens: config.max_tokens.max(1),
            repeat_penalty: config.candle_repeat_penalty.max(1.0),
            repeat_last_n: config.candle_repeat_last_n.max(1),
            seed: config.candle_seed,
            request_index: 0,
            eos_token_ids,
        })
    }

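    /// Generates a single completion from the given prompts, sampling token by
    /// token until an EOS token, the configured max, or the context window is hit.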
    pub(crate) fn think(
        &mut self,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let prompt = format!(
            "System:\n{}\n\nUser:\n{}\n\nAssistant:\n",
            system_prompt, user_prompt
        );
        let encoding = self
            .tokenizer
            .encode(prompt.as_str(), true)
            .map_err(|e| anyhow!("tokenizer encode failed: {}", e))?;
        let mut tokens = encoding.get_ids().to_vec();
        if tokens.is_empty() {
            return Err(anyhow!("tokenizer produced an empty prompt token set"));
        }

        if self.context_window > 8 && tokens.len() >= self.context_window {
            let keep = self.context_window.saturating_sub(8);
            tokens = tokens[tokens.len().saturating_sub(keep)..].to_vec();
        }
        let prompt_token_count = tokens.len() as u32;

        let request_seed = self.seed.wrapping_add(self.request_index);
        self.request_index = self.request_index.wrapping_add(1);
        let mut logits_processor = LogitsProcessor::new(
            request_seed,
            Some(self.temperature as f64),
            self.top_p.map(|v| v as f64),
        );

        let mut index_pos = 0usize;
        let mut generated: Vec<u32> = Vec::with_capacity(self.max_tokens);
        let mut finish_reason = "length".to_string();

        for _ in 0..self.max_tokens {
            let ctxt: &[u32] = if index_pos == 0 {
                tokens.as_slice()
            } else {
                &tokens[tokens.len() - 1..]
            };

            let input = Tensor::new(ctxt, &self.device)?
                .unsqueeze(0)
                .context("failed to create candle input tensor")?;
            let mut logits = self
                .model
                .forward(&input, index_pos)
                .context("candle model forward failed")?;
            index_pos += ctxt.len();
            logits = logits
                .squeeze(0)
                .context("failed to squeeze logits batch dimension")?;

            let logits = if self.repeat_penalty > 1.0 {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                apply_repeat_penalty(&logits, self.repeat_penalty, &tokens[start_at..])
                    .context("failed to apply repeat penalty")?
            } else {
                logits
            };

            let next_token = logits_processor
                .sample(&logits)
                .context("token sampling failed")?;
            if self.eos_token_ids.contains(&next_token) {
                finish_reason = "stop".to_string();
                break;
            }

            tokens.push(next_token);
            generated.push(next_token);

            if tokens.len() + 1 >= self.context_window {
                finish_reason = "length".to_string();
                break;
            }
        }

        let text = self
            .tokenizer
            .decode(&generated, true)
            .map_err(|e| anyhow!("tokenizer decode failed: {}", e))?;
        let completion_tokens = generated.len() as u32;

        tracing::debug!(
            model = %self.model_label,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = prompt_token_count,
            completion_tokens = completion_tokens,
            "candle thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: self.model_label.clone(),
            finish_reason: Some(finish_reason),
            text,
            prompt_tokens: Some(prompt_token_count),
            completion_tokens: Some(completion_tokens),
            total_tokens: Some(prompt_token_count + completion_tokens),
        })
    }
}

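/// Resolves the Candle device from the configured preference, returning the
/// device and a human-readable label for logging.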
fn select_candle_device(config: &ThinkerConfig) -> Result<(Device, String)> {
    match config.candle_device {
        CandleDevicePreference::Cpu => Ok((Device::Cpu, "cpu".to_string())),
        CandleDevicePreference::Cuda => {
            let device = try_cuda_device(config.candle_cuda_ordinal)?;
            Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
        }
        CandleDevicePreference::Auto => match try_cuda_device(config.candle_cuda_ordinal) {
            Ok(device) => {
                tracing::info!(
                    ordinal = config.candle_cuda_ordinal,
                    "Candle thinker selected CUDA device"
                );
                Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
            }
            Err(error) => {
                tracing::warn!(
                    %error,
                    "CUDA unavailable for Candle thinker, falling back to CPU"
                );
                Ok((Device::Cpu, "cpu".to_string()))
            }
        },
    }
}

#[cfg(feature = "candle-cuda")]
fn try_cuda_device(ordinal: usize) -> Result<Device> {
    Device::new_cuda(ordinal)
        .with_context(|| format!("failed to initialize CUDA device ordinal {}", ordinal))
}

#[cfg(not(feature = "candle-cuda"))]
fn try_cuda_device(_ordinal: usize) -> Result<Device> {
    Err(anyhow!(
        "candle-cuda feature is not enabled in this build; rebuild with --features candle-cuda"
    ))
}

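/// Reads the context length from GGUF metadata using architecture-specific keys.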
fn detect_context_window(content: &gguf_file::Content, architecture: &str) -> Option<usize> {
    let key = match architecture {
        "qwen2" => "qwen2.context_length",
        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => {
            for prefix in ["gemma3", "gemma2", "gemma"] {
                let k = format!("{prefix}.context_length");
                if let Some(v) = content.metadata.get(&k) {
                    return v.to_u32().ok().map(|v| v as usize);
                }
            }
            return None;
        }
        _ => "llama.context_length",
    };
    content
        .metadata
        .get(key)
        .and_then(|v| v.to_u32().ok())
        .map(|v| v as usize)
}

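/// Collects tokenizer IDs for common end-of-sequence/end-of-turn markers used
/// by popular chat templates.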
fn collect_eos_token_ids(tokenizer: &Tokenizer) -> Vec<u32> {
    let candidates = [
        "<|im_end|>",
        "<|eot_id|>",
        "<|endoftext|>",
        "</s>",
        "<|end|>",
        "<end_of_turn>",
    ];
    let mut ids = Vec::new();
    for token in candidates {
        if let Some(id) = tokenizer.token_to_id(token) {
            if !ids.contains(&id) {
                ids.push(id);
            }
        }
    }
    ids
}

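/// Request body for an OpenAI-compatible `/v1/chat/completions` endpoint.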
#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    temperature: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    max_tokens: usize,
    stream: bool,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    model: Option<String>,
    choices: Vec<OpenAIChatChoice>,
    #[serde(default)]
    usage: Option<OpenAIUsage>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatChoice {
    message: OpenAIChatChoiceMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatChoiceMessage {
    #[serde(default)]
    content: Option<OpenAIChatContent>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
}

#[derive(Debug, Default, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    total_tokens: Option<u32>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAIChatContent {
    Text(String),
    Parts(Vec<OpenAIChatContentPart>),
    Part(OpenAIChatContentPart),
}

#[derive(Debug, Deserialize)]
struct OpenAIChatContentPart {
    #[serde(rename = "type")]
    kind: Option<String>,
    #[serde(default)]
    text: Option<String>,
    #[serde(default)]
    content: Option<String>,
}

impl OpenAIChatChoiceMessage {
    fn extract_text(&self) -> String {
        let content_text = self
            .content
            .as_ref()
            .map(OpenAIChatContent::to_text)
            .unwrap_or_default();
        if !content_text.trim().is_empty() {
            return content_text;
        }

        if let Some(reasoning) = self
            .reasoning
            .as_deref()
            .filter(|text| !text.trim().is_empty())
        {
            return reasoning.to_string();
        }

        self.reasoning_content
            .as_deref()
            .filter(|text| !text.trim().is_empty())
            .unwrap_or_default()
            .to_string()
    }
}

impl OpenAIChatContent {
    fn to_text(&self) -> String {
        match self {
            Self::Text(text) => text.clone(),
            Self::Parts(parts) => parts
                .iter()
                .filter_map(OpenAIChatContentPart::text_fragment)
                .collect::<Vec<_>>()
                .join("\n"),
            Self::Part(part) => part.text_fragment().unwrap_or_default(),
        }
    }
}

impl OpenAIChatContentPart {
    fn text_fragment(&self) -> Option<String> {
        if let Some(kind) = self.kind.as_deref()
            && !kind.eq_ignore_ascii_case("text")
            && !kind.eq_ignore_ascii_case("output_text")
        {
            return None;
        }

        self.text
            .as_deref()
            .or(self.content.as_deref())
            .map(str::trim)
            .filter(|text| !text.is_empty())
            .map(ToString::to_string)
    }
}