
codetether_agent/cognition/thinker.rs

use anyhow::{Context, Result, anyhow};
use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
#[cfg(feature = "functiongemma")]
use candle_transformers::models::quantized_gemma3;
use candle_transformers::models::{quantized_llama, quantized_qwen2};
use candle_transformers::utils::apply_repeat_penalty;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::BufReader;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokenizers::Tokenizer;

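/// Selects which inference backend powers the thinker: a remote
/// OpenAI-compatible chat endpoint or an in-process Candle model.
/// `from_env` treats any unrecognized value as `OpenAICompat`.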
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkerBackend {
    OpenAICompat,
    Candle,
}

impl ThinkerBackend {
    pub fn from_env(value: &str) -> Self {
        match value.trim().to_ascii_lowercase().as_str() {
            "candle" => Self::Candle,
            "openai" | "openai_compat" | "openai-compatible" | "http" => Self::OpenAICompat,
            _ => Self::OpenAICompat,
        }
    }
}

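/// Device preference for the Candle backend. `Auto` tries CUDA first and
/// falls back to CPU (with a warning) when CUDA is unavailable; `from_env`
/// maps unrecognized values to `Auto`.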
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandleDevicePreference {
    Auto,
    Cpu,
    Cuda,
}

impl CandleDevicePreference {
    pub fn from_env(value: &str) -> Self {
        match value.trim().to_ascii_lowercase().as_str() {
            "cpu" => Self::Cpu,
            "cuda" | "gpu" => Self::Cuda,
            _ => Self::Auto,
        }
    }
}

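/// Runtime configuration for the thinker. The `candle_*` fields are only
/// consulted when `backend` is [`ThinkerBackend::Candle`]; `endpoint`,
/// `api_key`, and `timeout_ms` apply to the OpenAI-compatible path.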
#[derive(Debug, Clone)]
pub struct ThinkerConfig {
    pub enabled: bool,
    pub backend: ThinkerBackend,
    pub endpoint: String,
    pub model: String,
    pub api_key: Option<String>,
    pub temperature: f32,
    pub top_p: Option<f32>,
    pub max_tokens: usize,
    pub timeout_ms: u64,
    pub candle_model_path: Option<String>,
    pub candle_tokenizer_path: Option<String>,
    pub candle_arch: Option<String>,
    pub candle_device: CandleDevicePreference,
    pub candle_cuda_ordinal: usize,
    pub candle_repeat_penalty: f32,
    pub candle_repeat_last_n: usize,
    pub candle_seed: u64,
}

impl Default for ThinkerConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            backend: ThinkerBackend::OpenAICompat,
            endpoint: "http://127.0.0.1:11434/v1/chat/completions".to_string(),
            model: "qwen2.5:3b-instruct".to_string(),
            api_key: None,
            temperature: 0.2,
            top_p: None,
            max_tokens: 256,
            timeout_ms: 12_000,
            candle_model_path: None,
            candle_tokenizer_path: None,
            candle_arch: None,
            candle_device: CandleDevicePreference::Auto,
            candle_cuda_ordinal: 0,
            candle_repeat_penalty: 1.1,
            candle_repeat_last_n: 64,
            candle_seed: 42,
        }
    }
}

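/// Result of a single thinking request: the generated text plus the finish
/// reason and token usage when the backend reports them.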
#[derive(Debug, Clone)]
pub struct ThinkerOutput {
    pub model: String,
    pub finish_reason: Option<String>,
    pub text: String,
    pub prompt_tokens: Option<u32>,
    pub completion_tokens: Option<u32>,
    pub total_tokens: Option<u32>,
}

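/// Client facade over the configured thinker backend.
///
/// A minimal usage sketch (assumes an OpenAI-compatible server is reachable at
/// the default endpoint and that this module is importable at the path below):
///
/// ```no_run
/// use codetether_agent::cognition::thinker::{ThinkerClient, ThinkerConfig};
///
/// async fn demo() -> anyhow::Result<()> {
///     // Default config points at a local OpenAI-compatible endpoint.
///     let client = ThinkerClient::new(ThinkerConfig::default())?;
///     let output = client.think("You are a planner.", "Summarize the task.").await?;
///     println!("{}", output.text);
///     Ok(())
/// }
/// ```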
#[derive(Clone)]
pub struct ThinkerClient {
    config: ThinkerConfig,
    backend: ThinkerClientBackend,
}

impl std::fmt::Debug for ThinkerClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ThinkerClient")
            .field("backend", &self.config.backend)
            .field("model", &self.config.model)
            .finish()
    }
}

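/// Backend-specific state: a reusable HTTP client for the OpenAI-compatible
/// path, or a mutex-guarded Candle runtime shared across clones of the client.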
#[derive(Clone)]
enum ThinkerClientBackend {
    OpenAICompat { http: Client },
    Candle { runtime: Arc<Mutex<CandleThinker>> },
}

impl ThinkerClient {
    pub fn new(config: ThinkerConfig) -> Result<Self> {
        let backend = match config.backend {
            ThinkerBackend::OpenAICompat => {
                let timeout = Duration::from_millis(config.timeout_ms.max(1_000));
                let http = Client::builder()
                    .timeout(timeout)
                    .build()
                    .context("failed to build thinker HTTP client")?;
                ThinkerClientBackend::OpenAICompat { http }
            }
            ThinkerBackend::Candle => {
                let runtime = CandleThinker::new(&config)?;
                ThinkerClientBackend::Candle {
                    runtime: Arc::new(Mutex::new(runtime)),
                }
            }
        };

        Ok(Self { config, backend })
    }

    pub fn config(&self) -> &ThinkerConfig {
        &self.config
    }

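    /// Produces a single thought for the given prompts. The HTTP path runs on
    /// the async runtime; Candle generation is CPU/GPU-bound, so it is moved
    /// onto a blocking thread via `spawn_blocking`.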
    pub async fn think(&self, system_prompt: &str, user_prompt: &str) -> Result<ThinkerOutput> {
        match &self.backend {
            ThinkerClientBackend::OpenAICompat { http } => {
                self.think_openai_compat(http, system_prompt, user_prompt)
                    .await
            }
            ThinkerClientBackend::Candle { runtime } => {
                let runtime = Arc::clone(runtime);
                let system_prompt = system_prompt.to_string();
                let user_prompt = user_prompt.to_string();
                tokio::task::spawn_blocking(move || {
                    let mut guard = runtime
                        .lock()
                        .map_err(|_| anyhow!("candle thinker mutex poisoned"))?;
                    guard.think(&system_prompt, &user_prompt)
                })
                .await
                .context("candle thinker task join failed")?
            }
        }
    }

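    /// Sends a non-streaming chat completion request to the configured
    /// OpenAI-compatible endpoint and maps the first choice into a
    /// [`ThinkerOutput`], surfacing non-2xx responses as errors.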
    async fn think_openai_compat(
        &self,
        http: &Client,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let body = OpenAIChatRequest {
            model: self.config.model.clone(),
            messages: vec![
                OpenAIMessage {
                    role: "system".to_string(),
                    content: system_prompt.to_string(),
                },
                OpenAIMessage {
                    role: "user".to_string(),
                    content: user_prompt.to_string(),
                },
            ],
            temperature: self.config.temperature,
            top_p: self.config.top_p,
            max_tokens: self.config.max_tokens,
            stream: false,
        };

        let mut request = http.post(&self.config.endpoint).json(&body);
        if let Some(key) = self.config.api_key.as_ref() {
            request = request.bearer_auth(key);
        }

        let response = request
            .send()
            .await
            .context("thinker request failed to send")?;
        if !response.status().is_success() {
            let status = response.status();
            let body_text = response
                .text()
                .await
                .unwrap_or_else(|_| "<empty>".to_string());
            return Err(anyhow!(
                "thinker request failed with status {}: {}",
                status,
                body_text
            ));
        }

        let payload: OpenAIChatResponse = response
            .json()
            .await
            .context("failed to decode thinker response")?;
        let choice = payload
            .choices
            .first()
            .ok_or_else(|| anyhow!("thinker response did not include choices"))?;
        let text = choice.message.extract_text();
        let usage = payload.usage.unwrap_or_default();

        Ok(ThinkerOutput {
            model: payload.model.unwrap_or_else(|| self.config.model.clone()),
            finish_reason: choice.finish_reason.clone(),
            text,
            prompt_tokens: usage.prompt_tokens,
            completion_tokens: usage.completion_tokens,
            total_tokens: usage.total_tokens,
        })
    }
}

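/// In-process thinker backed by a quantized GGUF model loaded through Candle.
/// Holds the model weights, tokenizer, sampling parameters, and a per-request
/// counter used to vary the sampling seed between calls.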
pub(crate) struct CandleThinker {
    model: CandleModel,
    tokenizer: Tokenizer,
    device: Device,
    model_label: String,
    context_window: usize,
    temperature: f32,
    top_p: Option<f32>,
    max_tokens: usize,
    repeat_penalty: f32,
    repeat_last_n: usize,
    seed: u64,
    request_index: u64,
    eos_token_ids: Vec<u32>,
}

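/// The quantized model families this backend can run. Gemma support is gated
/// behind the `functiongemma` feature.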
enum CandleModel {
    Llama(quantized_llama::ModelWeights),
    Qwen2(quantized_qwen2::ModelWeights),
    #[cfg(feature = "functiongemma")]
    Gemma3(quantized_gemma3::ModelWeights),
}

impl CandleModel {
    fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        match self {
            Self::Llama(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen2(model) => Ok(model.forward(x, index_pos)?),
            #[cfg(feature = "functiongemma")]
            Self::Gemma3(model) => Ok(model.forward(x, index_pos)?),
        }
    }
}

impl CandleThinker {
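    /// Loads the GGUF weights and tokenizer from the configured paths, selects
    /// a device, resolves the architecture (explicit config or the GGUF
    /// `general.architecture` metadata, defaulting to llama), and collects the
    /// EOS token ids used to stop generation.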
    pub(crate) fn new(config: &ThinkerConfig) -> Result<Self> {
        let model_path = config.candle_model_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_MODEL_PATH")
        })?;
        let tokenizer_path = config.candle_tokenizer_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_TOKENIZER_PATH")
        })?;

        let (device, device_label) = select_candle_device(config)?;
        let mut reader = BufReader::new(
            File::open(model_path)
                .with_context(|| format!("failed to open candle model file at {}", model_path))?,
        );
        let content = gguf_file::Content::read(&mut reader)
            .with_context(|| format!("failed to parse gguf model metadata from {}", model_path))?;

        let architecture = config
            .candle_arch
            .clone()
            .or_else(|| {
                content
                    .metadata
                    .get("general.architecture")
                    .and_then(|v| v.to_string().ok())
                    .cloned()
            })
            .unwrap_or_else(|| "llama".to_string())
            .to_ascii_lowercase();

        let context_window = detect_context_window(&content, &architecture).unwrap_or(4096);
        let model_label = format!("candle:{}:{}@{}", architecture, device_label, model_path);

        let model = match architecture.as_str() {
            "llama" => CandleModel::Llama(
                quantized_llama::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load llama gguf from {}", model_path))?,
            ),
            "qwen2" => CandleModel::Qwen2(
                quantized_qwen2::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen2 gguf from {}", model_path))?,
            ),
            #[cfg(feature = "functiongemma")]
            "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => CandleModel::Gemma3(
                quantized_gemma3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load gemma3 gguf from {}", model_path))?,
            ),
            other => {
                #[cfg(not(feature = "functiongemma"))]
                if matches!(other, "gemma" | "gemma2" | "gemma3" | "gemma-embedding") {
                    return Err(anyhow!(
                        "gemma architecture '{}' requires the 'functiongemma' feature; rebuild with --features functiongemma",
                        other
                    ));
                }
                return Err(anyhow!(
                    "unsupported candle architecture '{}' (supported: llama, qwen2{})",
                    other,
                    if cfg!(feature = "functiongemma") {
                        ", gemma/gemma2/gemma3"
                    } else {
                        ""
                    }
                ));
            }
        };

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer from {}: {}", tokenizer_path, e))?;

        let eos_token_ids = collect_eos_token_ids(&tokenizer);
        if eos_token_ids.is_empty() {
            tracing::warn!(
                "No EOS tokens found in tokenizer; generation will stop on max token limit"
            );
        }

        Ok(Self {
            model,
            tokenizer,
            device,
            model_label,
            context_window,
            temperature: config.temperature,
            top_p: config.top_p,
            max_tokens: config.max_tokens.max(1),
            repeat_penalty: config.candle_repeat_penalty.max(1.0),
            repeat_last_n: config.candle_repeat_last_n.max(1),
            seed: config.candle_seed,
            request_index: 0,
            eos_token_ids,
        })
    }

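    /// Generates a completion token by token: encodes a plain
    /// "System/User/Assistant" prompt (no model-specific chat template), trims
    /// it to the context window, then samples until an EOS token, the
    /// `max_tokens` budget, or the context limit is hit.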
    pub(crate) fn think(
        &mut self,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let prompt = format!(
            "System:\n{}\n\nUser:\n{}\n\nAssistant:\n",
            system_prompt, user_prompt
        );
        let encoding = self
            .tokenizer
            .encode(prompt.as_str(), true)
            .map_err(|e| anyhow!("tokenizer encode failed: {}", e))?;
        let mut tokens = encoding.get_ids().to_vec();
        if tokens.is_empty() {
            return Err(anyhow!("tokenizer produced an empty prompt token set"));
        }

        if self.context_window > 8 && tokens.len() >= self.context_window {
            let keep = self.context_window.saturating_sub(8);
            tokens = tokens[tokens.len().saturating_sub(keep)..].to_vec();
        }
        let prompt_token_count = tokens.len() as u32;

        let request_seed = self.seed.wrapping_add(self.request_index);
        self.request_index = self.request_index.wrapping_add(1);
        let mut logits_processor = LogitsProcessor::new(
            request_seed,
            Some(self.temperature as f64),
            self.top_p.map(|v| v as f64),
        );

        let mut index_pos = 0usize;
        let mut generated: Vec<u32> = Vec::with_capacity(self.max_tokens);
        let mut finish_reason = "length".to_string();

        for _ in 0..self.max_tokens {
            let ctxt: &[u32] = if index_pos == 0 {
                tokens.as_slice()
            } else {
                &tokens[tokens.len() - 1..]
            };

            let input = Tensor::new(ctxt, &self.device)?
                .unsqueeze(0)
                .context("failed to create candle input tensor")?;
            let mut logits = self
                .model
                .forward(&input, index_pos)
                .context("candle model forward failed")?;
            index_pos += ctxt.len();
            logits = logits
                .squeeze(0)
                .context("failed to squeeze logits batch dimension")?;

            let logits = if self.repeat_penalty > 1.0 {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                apply_repeat_penalty(&logits, self.repeat_penalty, &tokens[start_at..])
                    .context("failed to apply repeat penalty")?
            } else {
                logits
            };

            let next_token = logits_processor
                .sample(&logits)
                .context("token sampling failed")?;
            if self.eos_token_ids.contains(&next_token) {
                finish_reason = "stop".to_string();
                break;
            }

            tokens.push(next_token);
            generated.push(next_token);

            if tokens.len() + 1 >= self.context_window {
                finish_reason = "length".to_string();
                break;
            }
        }

        let text = self
            .tokenizer
            .decode(&generated, true)
            .map_err(|e| anyhow!("tokenizer decode failed: {}", e))?;
        let completion_tokens = generated.len() as u32;

        tracing::debug!(
            model = %self.model_label,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = prompt_token_count,
            completion_tokens = completion_tokens,
            "candle thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: self.model_label.clone(),
            finish_reason: Some(finish_reason),
            text,
            prompt_tokens: Some(prompt_token_count),
            completion_tokens: Some(completion_tokens),
            total_tokens: Some(prompt_token_count + completion_tokens),
        })
    }
}

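/// Resolves the Candle device from the configured preference. `Auto` attempts
/// CUDA at the configured ordinal and logs a warning before falling back to
/// the CPU.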
fn select_candle_device(config: &ThinkerConfig) -> Result<(Device, String)> {
    match config.candle_device {
        CandleDevicePreference::Cpu => Ok((Device::Cpu, "cpu".to_string())),
        CandleDevicePreference::Cuda => {
            let device = try_cuda_device(config.candle_cuda_ordinal)?;
            Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
        }
        CandleDevicePreference::Auto => match try_cuda_device(config.candle_cuda_ordinal) {
            Ok(device) => {
                tracing::info!(
                    ordinal = config.candle_cuda_ordinal,
                    "Candle thinker selected CUDA device"
                );
                Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
            }
            Err(error) => {
                tracing::warn!(
                    %error,
                    "CUDA unavailable for Candle thinker, falling back to CPU"
                );
                Ok((Device::Cpu, "cpu".to_string()))
            }
        },
    }
}

#[cfg(feature = "candle-cuda")]
fn try_cuda_device(ordinal: usize) -> Result<Device> {
    Device::new_cuda(ordinal)
        .with_context(|| format!("failed to initialize CUDA device ordinal {}", ordinal))
}

#[cfg(not(feature = "candle-cuda"))]
fn try_cuda_device(_ordinal: usize) -> Result<Device> {
    Err(anyhow!(
        "candle-cuda feature is not enabled in this build; rebuild with --features candle-cuda"
    ))
}

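/// Reads the context length from GGUF metadata using the architecture-specific
/// key (e.g. `qwen2.context_length`), returning `None` when the key is absent.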
fn detect_context_window(content: &gguf_file::Content, architecture: &str) -> Option<usize> {
    let key = match architecture {
        "qwen2" => "qwen2.context_length",
        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => {
            // Try gemma3 first, then fall back to gemma2, gemma
            for prefix in ["gemma3", "gemma2", "gemma"] {
                let k = format!("{prefix}.context_length");
                if let Some(v) = content.metadata.get(&k) {
                    return v.to_u32().ok().map(|v| v as usize);
                }
            }
            return None;
        }
        _ => "llama.context_length",
    };
    content
        .metadata
        .get(key)
        .and_then(|v| v.to_u32().ok())
        .map(|v| v as usize)
}

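/// Collects the ids of common end-of-sequence markers that the loaded
/// tokenizer actually defines; generation stops when any of them is sampled.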
fn collect_eos_token_ids(tokenizer: &Tokenizer) -> Vec<u32> {
    let candidates = [
        "<|im_end|>",
        "<|eot_id|>",
        "<|endoftext|>",
        "</s>",
        "<|end|>",
        "<end_of_turn>",
    ];
    let mut ids = Vec::new();
    for token in candidates {
        if let Some(id) = tokenizer.token_to_id(token) {
            if !ids.contains(&id) {
                ids.push(id);
            }
        }
    }
    ids
}

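// Minimal wire types for the OpenAI-compatible chat completions API. Only the
// fields this client sends or reads are modeled.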
#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    temperature: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    max_tokens: usize,
    stream: bool,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    model: Option<String>,
    choices: Vec<OpenAIChatChoice>,
    #[serde(default)]
    usage: Option<OpenAIUsage>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatChoice {
    message: OpenAIChatChoiceMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatChoiceMessage {
    #[serde(default)]
    content: Option<OpenAIChatContent>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
}

#[derive(Debug, Default, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    total_tokens: Option<u32>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAIChatContent {
    Text(String),
    Parts(Vec<OpenAIChatContentPart>),
    Part(OpenAIChatContentPart),
}

#[derive(Debug, Deserialize)]
struct OpenAIChatContentPart {
    #[serde(rename = "type")]
    kind: Option<String>,
    #[serde(default)]
    text: Option<String>,
    #[serde(default)]
    content: Option<String>,
}

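/// Text extraction for chat responses: prefer the regular `content` field,
/// then fall back to `reasoning` and `reasoning_content` so reasoning-only
/// replies still yield usable text.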
impl OpenAIChatChoiceMessage {
    fn extract_text(&self) -> String {
        let content_text = self
            .content
            .as_ref()
            .map(OpenAIChatContent::to_text)
            .unwrap_or_default();
        if !content_text.trim().is_empty() {
            return content_text;
        }

        if let Some(reasoning) = self
            .reasoning
            .as_deref()
            .filter(|text| !text.trim().is_empty())
        {
            return reasoning.to_string();
        }

        self.reasoning_content
            .as_deref()
            .filter(|text| !text.trim().is_empty())
            .unwrap_or_default()
            .to_string()
    }
}

impl OpenAIChatContent {
    fn to_text(&self) -> String {
        match self {
            Self::Text(text) => text.clone(),
            Self::Parts(parts) => parts
                .iter()
                .filter_map(OpenAIChatContentPart::text_fragment)
                .collect::<Vec<_>>()
                .join("\n"),
            Self::Part(part) => part.text_fragment().unwrap_or_default(),
        }
    }
}

impl OpenAIChatContentPart {
    fn text_fragment(&self) -> Option<String> {
        if let Some(kind) = self.kind.as_deref()
            && !kind.eq_ignore_ascii_case("text")
            && !kind.eq_ignore_ascii_case("output_text")
        {
            return None;
        }

        self.text
            .as_deref()
            .or(self.content.as_deref())
            .map(str::trim)
            .filter(|text| !text.is_empty())
            .map(ToString::to_string)
    }
}