1use anyhow::{Context, Result, anyhow};
2use candle_core::quantized::gguf_file;
3use candle_core::{Device, Tensor};
4use candle_transformers::generation::LogitsProcessor;
5
6use candle_transformers::models::quantized_gemma3;
7use candle_transformers::models::{quantized_llama, quantized_qwen2};
8use candle_transformers::utils::apply_repeat_penalty;
9use reqwest::Client;
10use serde::{Deserialize, Serialize};
11use std::collections::HashSet;
12use std::fs::File;
13use std::io::BufReader;
14use std::sync::{Arc, Mutex};
15use std::time::{Duration, Instant};
16use tokenizers::Tokenizer;
17
18use crate::provider::bedrock::{AwsCredentials, BedrockProvider};
19
/// Which inference backend powers the thinker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkerBackend {
    /// Any OpenAI-compatible HTTP chat-completions endpoint.
    OpenAICompat,
    /// Local in-process inference via candle.
    Candle,
    /// AWS Bedrock converse API.
    Bedrock,
}

impl ThinkerBackend {
    /// Parses a backend name out of an environment-variable value.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace.
    /// Anything unrecognized falls back to `OpenAICompat`.
    pub fn from_env(value: &str) -> Self {
        let normalized = value.trim().to_ascii_lowercase();
        if normalized == "candle" {
            Self::Candle
        } else if matches!(normalized.as_str(), "bedrock" | "aws" | "aws_bedrock") {
            Self::Bedrock
        } else {
            // "openai", "openai_compat", "openai-compatible", "http", and any
            // unknown value all resolve to the HTTP-compatible backend.
            Self::OpenAICompat
        }
    }
}
37
/// Device-selection policy for the candle backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandleDevicePreference {
    /// Try CUDA first and fall back to CPU.
    Auto,
    /// Force CPU execution.
    Cpu,
    /// Require a CUDA device.
    Cuda,
}

impl CandleDevicePreference {
    /// Parses a device preference out of an environment-variable value.
    ///
    /// "cpu" and "cuda"/"gpu" are recognized (case-insensitively, trimmed);
    /// everything else means `Auto`.
    pub fn from_env(value: &str) -> Self {
        let normalized = value.trim().to_ascii_lowercase();
        match normalized.as_str() {
            "cuda" | "gpu" => Self::Cuda,
            "cpu" => Self::Cpu,
            _ => Self::Auto,
        }
    }
}
54
/// Runtime configuration for the thinker subsystem, shared by all backends.
#[derive(Debug, Clone)]
pub struct ThinkerConfig {
    /// Master switch; when false no thinker is used.
    pub enabled: bool,
    /// Which backend implementation `ThinkerClient::new` constructs.
    pub backend: ThinkerBackend,
    /// Chat-completions URL for the OpenAI-compatible backend.
    pub endpoint: String,
    /// Model name (HTTP backend) or Bedrock model id.
    pub model: String,
    /// Optional bearer token for the OpenAI-compatible endpoint.
    pub api_key: Option<String>,
    /// Sampling temperature, used by every backend.
    pub temperature: f32,
    /// Optional nucleus-sampling cutoff; `None` leaves it unset.
    pub top_p: Option<f32>,
    /// Upper bound on generated tokens per request.
    pub max_tokens: usize,
    /// HTTP request timeout in milliseconds (clamped to >= 1000 in `new`).
    pub timeout_ms: u64,
    /// Path to a local gguf model file (candle backend only).
    pub candle_model_path: Option<String>,
    /// Path to the matching tokenizer file (candle backend only).
    pub candle_tokenizer_path: Option<String>,
    /// Explicit architecture override; otherwise read from gguf metadata.
    pub candle_arch: Option<String>,
    /// CPU / CUDA / auto device policy for candle.
    pub candle_device: CandleDevicePreference,
    /// CUDA device index used when a GPU is selected.
    pub candle_cuda_ordinal: usize,
    /// Repetition penalty; values > 1.0 enable it.
    pub candle_repeat_penalty: f32,
    /// How many trailing tokens the repeat penalty considers.
    pub candle_repeat_last_n: usize,
    /// Base RNG seed; each request offsets it (see `CandleThinker::think`).
    pub candle_seed: u64,
    /// AWS region for the Bedrock backend.
    pub bedrock_region: String,
}
76
impl Default for ThinkerConfig {
    /// Conservative defaults: thinker disabled, pointing at a local
    /// OpenAI-compatible endpoint on port 11434 (Ollama's default port).
    fn default() -> Self {
        Self {
            enabled: false,
            backend: ThinkerBackend::OpenAICompat,
            endpoint: "http://127.0.0.1:11434/v1/chat/completions".to_string(),
            model: "qwen2.5:3b-instruct".to_string(),
            api_key: None,
            // Low sampling temperature for mostly-focused output.
            temperature: 0.2,
            top_p: None,
            max_tokens: 256,
            // 30-second HTTP timeout.
            timeout_ms: 30_000,
            candle_model_path: None,
            candle_tokenizer_path: None,
            // None: architecture is detected from gguf metadata at load time.
            candle_arch: None,
            candle_device: CandleDevicePreference::Auto,
            candle_cuda_ordinal: 0,
            candle_repeat_penalty: 1.1,
            candle_repeat_last_n: 64,
            candle_seed: 42,
            bedrock_region: "us-west-2".to_string(),
        }
    }
}
101
/// Result of a single thinker generation, normalized across backends.
#[derive(Debug, Clone)]
pub struct ThinkerOutput {
    /// Label of the model that produced the text.
    pub model: String,
    /// Backend-reported stop reason (e.g. "stop", "length"), if any.
    pub finish_reason: Option<String>,
    /// The generated text.
    pub text: String,
    /// Prompt token count, when the backend reports usage.
    pub prompt_tokens: Option<u32>,
    /// Completion token count, when the backend reports usage.
    pub completion_tokens: Option<u32>,
    /// Total token count, when the backend reports usage.
    pub total_tokens: Option<u32>,
}

/// Front-end handle that dispatches `think` calls to the configured backend.
#[derive(Clone)]
pub struct ThinkerClient {
    /// Configuration captured at construction time.
    config: ThinkerConfig,
    /// Backend-specific state (HTTP client, candle runtime, or Bedrock provider).
    backend: ThinkerClientBackend,
}
117
118impl std::fmt::Debug for ThinkerClient {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 f.debug_struct("ThinkerClient")
121 .field("backend", &self.config.backend)
122 .field("model", &self.config.model)
123 .finish()
124 }
125}
126
/// Backend-specific state held by `ThinkerClient`.
#[derive(Clone)]
enum ThinkerClientBackend {
    /// Reqwest client for an OpenAI-compatible HTTP endpoint.
    OpenAICompat { http: Client },
    /// In-process candle runtime; the mutex serializes generations.
    Candle { runtime: Arc<Mutex<CandleThinker>> },
    /// Shared Bedrock request sender.
    Bedrock { provider: Arc<BedrockProvider> },
}
133
134impl ThinkerClient {
135 pub fn new(config: ThinkerConfig) -> Result<Self> {
136 let backend = match config.backend {
137 ThinkerBackend::OpenAICompat => {
138 let timeout = Duration::from_millis(config.timeout_ms.max(1_000));
139 let http = Client::builder()
140 .timeout(timeout)
141 .build()
142 .context("failed to build thinker HTTP client")?;
143 ThinkerClientBackend::OpenAICompat { http }
144 }
145 ThinkerBackend::Candle => {
146 let runtime = CandleThinker::new(&config)?;
147 ThinkerClientBackend::Candle {
148 runtime: Arc::new(Mutex::new(runtime)),
149 }
150 }
151 ThinkerBackend::Bedrock => {
152 let creds = AwsCredentials::from_environment()
153 .ok_or_else(|| anyhow!("Bedrock thinker requires AWS credentials (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY or ~/.aws/credentials)"))?;
154 let provider =
155 BedrockProvider::with_credentials(creds, config.bedrock_region.clone())?;
156 ThinkerClientBackend::Bedrock {
157 provider: Arc::new(provider),
158 }
159 }
160 };
161
162 Ok(Self { config, backend })
163 }
164
165 pub fn config(&self) -> &ThinkerConfig {
166 &self.config
167 }
168
169 pub async fn think(&self, system_prompt: &str, user_prompt: &str) -> Result<ThinkerOutput> {
170 match &self.backend {
171 ThinkerClientBackend::OpenAICompat { http } => {
172 self.think_openai_compat(http, system_prompt, user_prompt)
173 .await
174 }
175 ThinkerClientBackend::Bedrock { provider } => {
176 self.think_bedrock(provider, system_prompt, user_prompt)
177 .await
178 }
179 ThinkerClientBackend::Candle { runtime } => {
180 let runtime = Arc::clone(runtime);
181 let system_prompt = system_prompt.to_string();
182 let user_prompt = user_prompt.to_string();
183 tokio::task::spawn_blocking(move || {
184 let mut guard = match runtime.try_lock() {
185 Ok(g) => g,
186 Err(std::sync::TryLockError::WouldBlock) => {
187 return Err(anyhow!("candle thinker is busy"));
188 }
189 Err(std::sync::TryLockError::Poisoned(_)) => {
190 return Err(anyhow!("candle thinker mutex poisoned"));
191 }
192 };
193 guard.think(&system_prompt, &user_prompt)
194 })
195 .await
196 .context("candle thinker task join failed")?
197 }
198 }
199 }
200
201 async fn think_bedrock(
202 &self,
203 provider: &BedrockProvider,
204 system_prompt: &str,
205 user_prompt: &str,
206 ) -> Result<ThinkerOutput> {
207 let started_at = Instant::now();
208 let model_id = &self.config.model;
209
210 let body = serde_json::json!({
212 "system": [{"text": system_prompt}],
213 "messages": [{
214 "role": "user",
215 "content": [{"text": user_prompt}]
216 }],
217 "inferenceConfig": {
218 "maxTokens": self.config.max_tokens,
219 "temperature": self.config.temperature
220 }
221 });
222
223 let body_bytes = serde_json::to_vec(&body)?;
224 let encoded_model_id = model_id.replace(':', "%3A");
225 let url = format!(
226 "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
227 self.config.bedrock_region, encoded_model_id
228 );
229
230 let response = provider
231 .send_converse_request(&url, &body_bytes)
232 .await
233 .context("Bedrock thinker converse request failed")?;
234
235 let status = response.status();
236 let text = response
237 .text()
238 .await
239 .context("Failed to read Bedrock thinker response")?;
240
241 if !status.is_success() {
242 return Err(anyhow!(
243 "Bedrock thinker error ({}): {}",
244 status,
245 &text[..text.len().min(500)]
246 ));
247 }
248
249 let parsed: serde_json::Value =
250 serde_json::from_str(&text).context("Failed to parse Bedrock thinker response")?;
251
252 let output_text = parsed["output"]["message"]["content"]
253 .as_array()
254 .and_then(|arr| arr.first())
255 .and_then(|c| c["text"].as_str())
256 .unwrap_or_default()
257 .to_string();
258
259 let usage = &parsed["usage"];
260 let prompt_tokens = usage["inputTokens"].as_u64().map(|v| v as u32);
261 let completion_tokens = usage["outputTokens"].as_u64().map(|v| v as u32);
262
263 tracing::debug!(
264 model = model_id,
265 latency_ms = started_at.elapsed().as_millis(),
266 prompt_tokens = ?prompt_tokens,
267 completion_tokens = ?completion_tokens,
268 "bedrock thinker generated thought"
269 );
270
271 Ok(ThinkerOutput {
272 model: model_id.clone(),
273 finish_reason: parsed["stopReason"].as_str().map(|s| s.to_string()),
274 text: output_text,
275 prompt_tokens,
276 completion_tokens,
277 total_tokens: prompt_tokens.zip(completion_tokens).map(|(p, c)| p + c),
278 })
279 }
280
281 async fn think_openai_compat(
282 &self,
283 http: &Client,
284 system_prompt: &str,
285 user_prompt: &str,
286 ) -> Result<ThinkerOutput> {
287 let started_at = Instant::now();
288 let body = OpenAIChatRequest {
289 model: self.config.model.clone(),
290 messages: vec![
291 OpenAIMessage {
292 role: "system".to_string(),
293 content: system_prompt.to_string(),
294 },
295 OpenAIMessage {
296 role: "user".to_string(),
297 content: user_prompt.to_string(),
298 },
299 ],
300 temperature: self.config.temperature,
301 top_p: self.config.top_p,
302 max_tokens: self.config.max_tokens,
303 stream: false,
304 };
305
306 let max_attempts: u32 = 2;
308 let mut last_err: Option<anyhow::Error> = None;
309
310 for attempt in 0..max_attempts {
311 if attempt > 0 {
312 tokio::time::sleep(Duration::from_millis(500 * attempt as u64)).await;
313 tracing::debug!(attempt, "retrying thinker HTTP request");
314 }
315
316 let mut request = http.post(&self.config.endpoint).json(&body);
317 if let Some(key) = self.config.api_key.as_ref() {
318 request = request.bearer_auth(key);
319 }
320
321 let response = match request.send().await {
322 Ok(resp) => resp,
323 Err(e) => {
324 if is_transient_reqwest_error(&e) {
325 tracing::warn!(attempt, error = %e, "thinker HTTP request failed (transient)");
326 last_err =
327 Some(anyhow::Error::from(e).context("transient thinker send error"));
328 continue;
329 }
330 return Err(anyhow::Error::from(e).context("non-transient thinker send error"));
331 }
332 };
333
334 let status = response.status();
335 if is_transient_http_error(status.as_u16()) {
336 let body_text = response.text().await.unwrap_or_default();
337 tracing::warn!(attempt, status = %status, "thinker received transient HTTP error");
338 last_err = Some(anyhow!(
339 "thinker request failed with status {}: {}",
340 status,
341 body_text
342 ));
343 continue;
344 }
345
346 if !status.is_success() {
347 let body_text = response
348 .text()
349 .await
350 .unwrap_or_else(|_| "<empty>".to_string());
351 return Err(anyhow!(
352 "thinker request failed with status {}: {}",
353 status,
354 body_text
355 ));
356 }
357
358 let payload: OpenAIChatResponse = response
359 .json()
360 .await
361 .context("failed to decode thinker response")?;
362 let choice = payload
363 .choices
364 .first()
365 .ok_or_else(|| anyhow!("thinker response did not include choices"))?;
366 let text = choice.message.extract_text();
367 let usage = payload.usage.unwrap_or_default();
368
369 let output = ThinkerOutput {
370 model: payload.model.unwrap_or_else(|| self.config.model.clone()),
371 finish_reason: choice.finish_reason.clone(),
372 text,
373 prompt_tokens: usage.prompt_tokens,
374 completion_tokens: usage.completion_tokens,
375 total_tokens: usage.total_tokens,
376 };
377
378 tracing::debug!(
379 model = %output.model,
380 latency_ms = started_at.elapsed().as_millis(),
381 prompt_tokens = ?output.prompt_tokens,
382 completion_tokens = ?output.completion_tokens,
383 attempt,
384 "openai-compat thinker generated thought"
385 );
386
387 return Ok(output);
388 }
389
390 Err(last_err.unwrap_or_else(|| {
391 anyhow!("thinker HTTP request failed after {max_attempts} attempts")
392 }))
393 }
394}
395
/// In-process quantized-gguf inference runtime for the candle backend.
pub(crate) struct CandleThinker {
    /// Loaded model weights, wrapped per supported architecture.
    model: CandleModel,
    /// Tokenizer matching the model.
    tokenizer: Tokenizer,
    /// Device the tensors live on (CPU or CUDA).
    device: Device,
    /// Human-readable label, formatted as "candle:<arch>:<device>@<path>".
    model_label: String,
    /// Lowercased architecture name from config or gguf metadata.
    architecture: String,
    /// Maximum sequence length (prompt + generation) the model supports.
    context_window: usize,
    /// Sampling temperature.
    temperature: f32,
    /// Optional nucleus-sampling cutoff.
    top_p: Option<f32>,
    /// Cap on generated tokens per request (at least 1).
    max_tokens: usize,
    /// Repetition penalty; values > 1.0 enable it.
    repeat_penalty: f32,
    /// Number of trailing tokens the repeat penalty inspects.
    repeat_last_n: usize,
    /// Base RNG seed; combined with `request_index` per request.
    seed: u64,
    /// Monotonic request counter used to vary the per-request seed.
    request_index: u64,
    /// Token ids that terminate generation.
    eos_token_ids: HashSet<u32>,
    // NOTE(review): initialized empty in `new` and never read anywhere in
    // this file — confirm whether this field is still needed.
    cached_tokens: Vec<u32>,
}
413
/// Architecture-specific wrapper around candle quantized model weights.
enum CandleModel {
    Llama(quantized_llama::ModelWeights),
    Qwen2(quantized_qwen2::ModelWeights),

    Gemma3(quantized_gemma3::ModelWeights),
}

impl CandleModel {
    /// Runs one forward pass and returns the logits.
    ///
    /// `index_pos` is the offset of `x` within the sequence: `think` passes 0
    /// with the full prompt, then feeds one token at a time at increasing
    /// offsets.
    fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        match self {
            Self::Llama(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen2(model) => Ok(model.forward(x, index_pos)?),

            Self::Gemma3(model) => Ok(model.forward(x, index_pos)?),
        }
    }
}
431
impl CandleThinker {
    /// Loads the gguf model, tokenizer, and metadata described by `config`.
    ///
    /// Errors when paths are missing, the gguf file cannot be parsed, the
    /// tokenizer fails to load, or the architecture is unsupported.
    pub(crate) fn new(config: &ThinkerConfig) -> Result<Self> {
        let model_path = config.candle_model_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_MODEL_PATH")
        })?;
        let tokenizer_path = config.candle_tokenizer_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_TOKENIZER_PATH")
        })?;

        let (device, device_label) = select_candle_device(config)?;
        let mut reader = BufReader::new(
            File::open(model_path)
                .with_context(|| format!("failed to open candle model file at {}", model_path))?,
        );
        // Reads only the metadata/header here; tensor data is consumed later
        // by `from_gguf` from the same reader.
        let content = gguf_file::Content::read(&mut reader)
            .with_context(|| format!("failed to parse gguf model metadata from {}", model_path))?;

        // Architecture: explicit config override wins, then gguf metadata,
        // then "llama" as the default. Normalized to lowercase.
        let architecture = config
            .candle_arch
            .clone()
            .or_else(|| {
                content
                    .metadata
                    .get("general.architecture")
                    .and_then(|v| v.to_string().ok())
                    .cloned()
            })
            .unwrap_or_else(|| "llama".to_string())
            .to_ascii_lowercase();

        // 4096 is the fallback when the metadata has no context length.
        let context_window = detect_context_window(&content, &architecture).unwrap_or(4096);
        let model_label = format!("candle:{}:{}@{}", architecture, device_label, model_path);

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer from {}: {}", tokenizer_path, e))?;

        // Collected before `content` is moved into `from_gguf` below.
        let gguf_eos_ids = extract_gguf_eos_ids(&content);

        let model = match architecture.as_str() {
            "llama" => CandleModel::Llama(
                quantized_llama::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load llama gguf from {}", model_path))?,
            ),
            "qwen2" => CandleModel::Qwen2(
                quantized_qwen2::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen2 gguf from {}", model_path))?,
            ),

            "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => CandleModel::Gemma3(
                quantized_gemma3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load gemma3 gguf from {}", model_path))?,
            ),
            other => {
                // NOTE(review): this cfg-gated guard is unreachable — the
                // unconditional `"gemma" | ...` arm above already matches all
                // gemma names, so `other` can never be one of them. It looks
                // like the gemma arm (and the quantized_gemma3 import) were
                // meant to be gated behind the "functiongemma" feature —
                // confirm intent.
                #[cfg(not(feature = "functiongemma"))]
                if matches!(other, "gemma" | "gemma2" | "gemma3" | "gemma-embedding") {
                    return Err(anyhow!(
                        "gemma architecture '{}' requires the 'functiongemma' feature; rebuild with --features functiongemma",
                        other
                    ));
                }
                return Err(anyhow!(
                    "unsupported candle architecture '{}' (supported: llama, qwen2{})",
                    other,
                    if cfg!(feature = "functiongemma") {
                        ", gemma/gemma2/gemma3"
                    } else {
                        ""
                    }
                ));
            }
        };

        let eos_token_ids: HashSet<u32> = collect_eos_token_ids(&tokenizer, &gguf_eos_ids);
        if eos_token_ids.is_empty() {
            tracing::warn!(
                "No EOS tokens found in tokenizer; generation will stop on max token limit"
            );
        }

        Ok(Self {
            model,
            tokenizer,
            device,
            model_label,
            architecture,
            context_window,
            temperature: config.temperature,
            top_p: config.top_p,
            // Clamp user-supplied values into sane ranges.
            max_tokens: config.max_tokens.max(1),
            repeat_penalty: config.candle_repeat_penalty.max(1.0),
            repeat_last_n: config.candle_repeat_last_n.max(1),
            seed: config.candle_seed,
            request_index: 0,
            eos_token_ids,
            cached_tokens: Vec::new(),
        })
    }

    /// Runs one chat-formatted generation and returns the decoded text with
    /// token counts and a finish reason ("stop" on EOS, "length" otherwise).
    pub(crate) fn think(
        &mut self,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let prompt = format_chat_prompt(&self.architecture, system_prompt, user_prompt);
        let encoding = self
            .tokenizer
            .encode(prompt.as_str(), true)
            .map_err(|e| anyhow!("tokenizer encode failed: {}", e))?;
        let mut tokens = encoding.get_ids().to_vec();
        if tokens.is_empty() {
            return Err(anyhow!("tokenizer produced an empty prompt token set"));
        }

        // Prompt too long for the context window: keep the system portion
        // (when it fits) plus the tail of the full prompt, reserving 8 tokens
        // of headroom for generation.
        if self.context_window > 8 && tokens.len() >= self.context_window {
            let system_only = format_chat_prompt(&self.architecture, system_prompt, "");
            let sys_encoding = self
                .tokenizer
                .encode(system_only.as_str(), true)
                .map_err(|e| anyhow!("tokenizer encode failed (system): {}", e))?;
            let sys_len = sys_encoding.get_ids().len();
            let budget = self.context_window.saturating_sub(8);
            if sys_len < budget {
                // System prompt fits: prepend it, then fill the remaining
                // budget with the most recent tokens of the full prompt.
                let tail_budget = budget.saturating_sub(sys_len);
                let tail_start = tokens.len().saturating_sub(tail_budget);
                let mut truncated = sys_encoding.get_ids().to_vec();
                truncated.extend_from_slice(&tokens[tail_start..]);
                tokens = truncated;
            } else {
                // Even the system prompt exceeds the budget: keep only the
                // trailing `budget` tokens of the full prompt.
                let keep = budget;
                tokens = tokens[tokens.len().saturating_sub(keep)..].to_vec();
            }
        }
        let prompt_token_count = tokens.len() as u32;

        // Per-request seed: base seed plus a monotonic counter, so repeated
        // identical prompts do not produce identical samples.
        let request_seed = self.seed.wrapping_add(self.request_index);
        self.request_index = self.request_index.wrapping_add(1);
        let mut logits_processor = LogitsProcessor::new(
            request_seed,
            Some(self.temperature as f64),
            self.top_p.map(|v| v as f64),
        );

        let mut index_pos = 0usize;
        let mut generated: Vec<u32> = Vec::with_capacity(self.max_tokens);
        let mut finish_reason = "length".to_string();

        // Autoregressive loop: feed the whole prompt on the first pass, then
        // one token at a time at increasing offsets.
        // NOTE(review): assumes the model's internal KV cache restarts when
        // forward is called again with index_pos == 0 on a later request —
        // confirm against candle's quantized model implementations.
        for _ in 0..self.max_tokens {
            let ctxt: &[u32] = if index_pos == 0 {
                tokens.as_slice()
            } else {
                &tokens[tokens.len() - 1..]
            };

            let input = Tensor::new(ctxt, &self.device)?
                .unsqueeze(0)
                .context("failed to create candle input tensor")?;
            let mut logits = self
                .model
                .forward(&input, index_pos)
                .context("candle model forward failed")?;
            index_pos += ctxt.len();
            logits = logits
                .squeeze(0)
                .context("failed to squeeze logits batch dimension")?;

            // Penalize tokens seen in the last `repeat_last_n` positions
            // (prompt included) when the penalty is enabled.
            let logits = if self.repeat_penalty > 1.0 {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                apply_repeat_penalty(&logits, self.repeat_penalty, &tokens[start_at..])
                    .context("failed to apply repeat penalty")?
            } else {
                logits
            };

            let next_token = logits_processor
                .sample(&logits)
                .context("token sampling failed")?;
            // EOS is checked before pushing, so it never appears in output.
            if self.eos_token_ids.contains(&next_token) {
                finish_reason = "stop".to_string();
                break;
            }

            tokens.push(next_token);
            generated.push(next_token);

            // Stop before the next step would overflow the context window.
            if tokens.len() + 1 >= self.context_window {
                finish_reason = "length".to_string();
                break;
            }
        }

        let text = self
            .tokenizer
            .decode(&generated, true)
            .map_err(|e| anyhow!("tokenizer decode failed: {}", e))?;
        let completion_tokens = generated.len() as u32;

        tracing::debug!(
            model = %self.model_label,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = prompt_token_count,
            completion_tokens = completion_tokens,
            "candle thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: self.model_label.clone(),
            finish_reason: Some(finish_reason),
            text,
            prompt_tokens: Some(prompt_token_count),
            completion_tokens: Some(completion_tokens),
            total_tokens: Some(prompt_token_count + completion_tokens),
        })
    }
}
651
/// Renders the (system, user) prompt pair into the chat template expected by
/// the given model architecture; unknown architectures get a plain layout.
fn format_chat_prompt(architecture: &str, system_prompt: &str, user_prompt: &str) -> String {
    match architecture {
        // ChatML template used by Qwen2 models.
        "qwen2" => format!(
            "<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
            system_prompt, user_prompt,
        ),
        // Llama 3 header/eot template.
        "llama" => format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
            system_prompt, user_prompt,
        ),
        // Gemma has no system role; system text is folded into the user turn.
        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => format!(
            "<start_of_turn>user\n{}\n\n{}<end_of_turn>\n<start_of_turn>model\n",
            system_prompt, user_prompt,
        ),
        // Generic plain-text fallback for unrecognized architectures.
        _ => format!(
            "System:\n{}\n\nUser:\n{}\n\nAssistant:\n",
            system_prompt, user_prompt,
        ),
    }
}
681
682fn select_candle_device(config: &ThinkerConfig) -> Result<(Device, String)> {
683 match config.candle_device {
684 CandleDevicePreference::Cpu => Ok((Device::Cpu, "cpu".to_string())),
685 CandleDevicePreference::Cuda => {
686 let device = try_cuda_device(config.candle_cuda_ordinal)?;
687 Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
688 }
689 CandleDevicePreference::Auto => match try_cuda_device(config.candle_cuda_ordinal) {
690 Ok(device) => {
691 tracing::info!(
692 ordinal = config.candle_cuda_ordinal,
693 "Candle thinker selected CUDA device"
694 );
695 Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
696 }
697 Err(error) => {
698 tracing::warn!(
699 %error,
700 "CUDA unavailable for Candle thinker, falling back to CPU"
701 );
702 Ok((Device::Cpu, "cpu".to_string()))
703 }
704 },
705 }
706}
707
// Initializes the CUDA device at `ordinal` (only compiled when the
// "candle-cuda" feature is enabled).
#[cfg(feature = "candle-cuda")]
fn try_cuda_device(ordinal: usize) -> Result<Device> {
    Device::new_cuda(ordinal)
        .with_context(|| format!("failed to initialize CUDA device ordinal {}", ordinal))
}

// Stub used when the build has no CUDA support: always errors so Auto mode
// falls back to CPU and explicit Cuda mode fails with a clear message.
#[cfg(not(feature = "candle-cuda"))]
fn try_cuda_device(_ordinal: usize) -> Result<Device> {
    Err(anyhow!(
        "candle-cuda feature is not enabled in this build; rebuild with --features candle-cuda"
    ))
}
720
721fn detect_context_window(content: &gguf_file::Content, architecture: &str) -> Option<usize> {
722 let key = match architecture {
723 "qwen2" => "qwen2.context_length",
724 "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => {
725 for prefix in ["gemma3", "gemma2", "gemma"] {
727 let k = format!("{prefix}.context_length");
728 if let Some(v) = content.metadata.get(&k) {
729 return v.to_u32().ok().map(|v| v as usize);
730 }
731 }
732 return None;
733 }
734 _ => "llama.context_length",
735 };
736 content
737 .metadata
738 .get(key)
739 .and_then(|v| v.to_u32().ok())
740 .map(|v| v as usize)
741}
742
743fn extract_gguf_eos_ids(content: &gguf_file::Content) -> Vec<u32> {
745 let mut ids = Vec::new();
746 for key in ["tokenizer.ggml.eos_token_id", "tokenizer.ggml.eot_token_id"] {
747 if let Some(v) = content.metadata.get(key) {
748 if let Ok(id) = v.to_u32() {
749 if !ids.contains(&id) {
750 ids.push(id);
751 }
752 }
753 }
754 }
755 ids
756}
757
758fn collect_eos_token_ids(tokenizer: &Tokenizer, gguf_eos_ids: &[u32]) -> HashSet<u32> {
759 let mut ids: HashSet<u32> = gguf_eos_ids.iter().copied().collect();
760
761 let candidates = [
763 "<|im_end|>",
764 "<|eot_id|>",
765 "<|endoftext|>",
766 "</s>",
767 "<|end|>",
768 "<end_of_turn>",
769 ];
770 for token in candidates {
771 if let Some(id) = tokenizer.token_to_id(token) {
772 ids.insert(id);
773 }
774 }
775 ids
776}
777
/// True for HTTP statuses worth retrying: rate limiting (429) and upstream
/// gateway failures (502-504).
fn is_transient_http_error(status: u16) -> bool {
    status == 429 || (502..=504).contains(&status)
}
782
// Classifies reqwest transport failures that are worth a retry (timeouts
// and connection failures).
// NOTE(review): `is_request()` also covers some errors raised while sending
// the request that may not be retryable — confirm this breadth is intended.
fn is_transient_reqwest_error(e: &reqwest::Error) -> bool {
    e.is_timeout() || e.is_connect() || e.is_request()
}
787
/// Request payload for an OpenAI-compatible `/chat/completions` call.
#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    temperature: f32,
    /// Omitted from the JSON entirely when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    max_tokens: usize,
    /// Always set to false by this client; streaming is not supported.
    stream: bool,
}

/// One chat message (role + plain-text content) in the request.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

/// Top-level chat-completions response; only the fields this client reads.
#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    /// Server-reported model name, when present.
    model: Option<String>,
    choices: Vec<OpenAIChatChoice>,
    #[serde(default)]
    usage: Option<OpenAIUsage>,
}

/// One completion choice; only the first is consumed.
#[derive(Debug, Deserialize)]
struct OpenAIChatChoice {
    message: OpenAIChatChoiceMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message payload. Besides `content`, some servers return
/// chain-of-thought text under `reasoning` / `reasoning_content`, which
/// `extract_text` uses as fallbacks.
#[derive(Debug, Deserialize)]
struct OpenAIChatChoiceMessage {
    #[serde(default)]
    content: Option<OpenAIChatContent>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
}

/// Token accounting block; all fields optional since servers vary.
#[derive(Debug, Default, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    total_tokens: Option<u32>,
}

/// Message content may arrive as a plain string, an array of typed parts,
/// or a single part object; `untagged` tries each shape in order.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAIChatContent {
    Text(String),
    Parts(Vec<OpenAIChatContentPart>),
    Part(OpenAIChatContentPart),
}

/// One structured content part; text may live under `text` or `content`.
#[derive(Debug, Deserialize)]
struct OpenAIChatContentPart {
    /// The JSON "type" discriminator (e.g. "text", "output_text").
    #[serde(rename = "type")]
    kind: Option<String>,
    #[serde(default)]
    text: Option<String>,
    #[serde(default)]
    content: Option<String>,
}
854
855impl OpenAIChatChoiceMessage {
856 fn extract_text(&self) -> String {
857 let content_text = self
858 .content
859 .as_ref()
860 .map(OpenAIChatContent::to_text)
861 .unwrap_or_default();
862 if !content_text.trim().is_empty() {
863 return content_text;
864 }
865
866 if let Some(reasoning) = self
867 .reasoning
868 .as_deref()
869 .filter(|text| !text.trim().is_empty())
870 {
871 return reasoning.to_string();
872 }
873
874 self.reasoning_content
875 .as_deref()
876 .filter(|text| !text.trim().is_empty())
877 .unwrap_or_default()
878 .to_string()
879 }
880}
881
882impl OpenAIChatContent {
883 fn to_text(&self) -> String {
884 match self {
885 Self::Text(text) => text.clone(),
886 Self::Parts(parts) => parts
887 .iter()
888 .filter_map(OpenAIChatContentPart::text_fragment)
889 .collect::<Vec<_>>()
890 .join("\n"),
891 Self::Part(part) => part.text_fragment().unwrap_or_default(),
892 }
893 }
894}
895
896impl OpenAIChatContentPart {
897 fn text_fragment(&self) -> Option<String> {
898 if let Some(kind) = self.kind.as_deref()
899 && !kind.eq_ignore_ascii_case("text")
900 && !kind.eq_ignore_ascii_case("output_text")
901 {
902 return None;
903 }
904
905 self.text
906 .as_deref()
907 .or(self.content.as_deref())
908 .map(str::trim)
909 .filter(|text| !text.is_empty())
910 .map(ToString::to_string)
911 }
912}