1use anyhow::{Context, Result, anyhow};
2use candle_core::quantized::gguf_file;
3use candle_core::{Device, Tensor};
4use candle_transformers::generation::LogitsProcessor;
5#[cfg(feature = "functiongemma")]
6use candle_transformers::models::quantized_gemma3;
7use candle_transformers::models::{quantized_llama, quantized_qwen2};
8use candle_transformers::utils::apply_repeat_penalty;
9use reqwest::Client;
10use serde::{Deserialize, Serialize};
11use std::collections::HashSet;
12use std::fs::File;
13use std::io::BufReader;
14use std::sync::{Arc, Mutex};
15use std::time::{Duration, Instant};
16use tokenizers::Tokenizer;
17
18use crate::provider::bedrock::{AwsCredentials, BedrockProvider};
19
/// Selects which inference backend powers the thinker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkerBackend {
    OpenAICompat,
    Candle,
    Bedrock,
}

impl ThinkerBackend {
    /// Parses a backend name from an environment-variable value.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace;
    /// unrecognized values fall back to `OpenAICompat`.
    pub fn from_env(value: &str) -> Self {
        let normalized = value.trim().to_ascii_lowercase();
        match normalized.as_str() {
            "candle" => ThinkerBackend::Candle,
            "bedrock" | "aws" | "aws_bedrock" => ThinkerBackend::Bedrock,
            "openai" | "openai_compat" | "openai-compatible" | "http" => {
                ThinkerBackend::OpenAICompat
            }
            _ => ThinkerBackend::OpenAICompat,
        }
    }
}
37
/// User preference for where candle inference should run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandleDevicePreference {
    /// Try CUDA first, fall back to CPU.
    Auto,
    Cpu,
    Cuda,
}

impl CandleDevicePreference {
    /// Parses a device preference from an environment-variable value.
    /// Case-insensitive; anything unrecognized means `Auto`.
    pub fn from_env(value: &str) -> Self {
        let normalized = value.trim().to_ascii_lowercase();
        if normalized == "cpu" {
            Self::Cpu
        } else if normalized == "cuda" || normalized == "gpu" {
            Self::Cuda
        } else {
            Self::Auto
        }
    }
}
54
/// Runtime configuration for the thinker, typically assembled from
/// environment variables by the caller.
#[derive(Debug, Clone)]
pub struct ThinkerConfig {
    pub enabled: bool,
    // Which backend ThinkerClient::new will construct.
    pub backend: ThinkerBackend,
    // Chat-completions URL used by the OpenAI-compatible backend.
    pub endpoint: String,
    // Model name (HTTP backend) or model id (Bedrock).
    pub model: String,
    // Optional bearer token for the OpenAI-compatible endpoint.
    pub api_key: Option<String>,
    pub temperature: f32,
    pub top_p: Option<f32>,
    // Upper bound on generated tokens per request.
    pub max_tokens: usize,
    // HTTP request timeout; clamped to >= 1s when the client is built.
    pub timeout_ms: u64,
    // Path to a local GGUF model file (candle backend only).
    pub candle_model_path: Option<String>,
    // Path to the tokenizer file matching the GGUF model.
    pub candle_tokenizer_path: Option<String>,
    // Explicit architecture override; otherwise read from GGUF metadata.
    pub candle_arch: Option<String>,
    pub candle_device: CandleDevicePreference,
    // CUDA device index used when a GPU is selected.
    pub candle_cuda_ordinal: usize,
    pub candle_repeat_penalty: f32,
    // Number of trailing tokens the repeat penalty looks back over.
    pub candle_repeat_last_n: usize,
    // Base RNG seed; each request perturbs it with a running counter.
    pub candle_seed: u64,
    pub bedrock_region: String,
}
76
impl Default for ThinkerConfig {
    /// Conservative defaults: thinker disabled, pointed at a local Ollama
    /// OpenAI-compatible endpoint with a small instruct model.
    fn default() -> Self {
        Self {
            enabled: false,
            backend: ThinkerBackend::OpenAICompat,
            endpoint: "http://127.0.0.1:11434/v1/chat/completions".to_string(),
            model: "qwen2.5:3b-instruct".to_string(),
            api_key: None,
            temperature: 0.2, // low temperature: keep generations focused
            top_p: None,
            max_tokens: 256,
            timeout_ms: 30_000,
            candle_model_path: None,
            candle_tokenizer_path: None,
            candle_arch: None,
            candle_device: CandleDevicePreference::Auto,
            candle_cuda_ordinal: 0,
            candle_repeat_penalty: 1.1,
            candle_repeat_last_n: 64,
            candle_seed: 42,
            bedrock_region: "us-west-2".to_string(),
        }
    }
}
101
/// Normalized result of one thinker invocation, regardless of backend.
#[derive(Debug, Clone)]
pub struct ThinkerOutput {
    // Label of the model that produced the text (backend-specific format).
    pub model: String,
    // Backend's stop reason (e.g. "stop"/"length") when reported.
    pub finish_reason: Option<String>,
    pub text: String,
    // Token accounting; None when the backend did not report it.
    pub prompt_tokens: Option<u32>,
    pub completion_tokens: Option<u32>,
    pub total_tokens: Option<u32>,
}
111
/// Facade over the configured thinker backend. Cheap to clone: the heavy
/// backend state is behind `Arc`s inside `ThinkerClientBackend`.
#[derive(Clone)]
pub struct ThinkerClient {
    config: ThinkerConfig,
    backend: ThinkerClientBackend,
}
117
118impl std::fmt::Debug for ThinkerClient {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 f.debug_struct("ThinkerClient")
121 .field("backend", &self.config.backend)
122 .field("model", &self.config.model)
123 .finish()
124 }
125}
126
/// Backend-specific runtime state, kept out of the public API surface.
#[derive(Clone)]
enum ThinkerClientBackend {
    // Shared reqwest client (clonable, connection-pooled).
    OpenAICompat { http: Client },
    // Mutex serializes access: one candle generation at a time per runtime.
    Candle { runtime: Arc<Mutex<CandleThinker>> },
    Bedrock { provider: Arc<BedrockProvider> },
}
133
134impl ThinkerClient {
135 pub fn new(config: ThinkerConfig) -> Result<Self> {
136 let backend = match config.backend {
137 ThinkerBackend::OpenAICompat => {
138 let timeout = Duration::from_millis(config.timeout_ms.max(1_000));
139 let http = Client::builder()
140 .timeout(timeout)
141 .build()
142 .context("failed to build thinker HTTP client")?;
143 ThinkerClientBackend::OpenAICompat { http }
144 }
145 ThinkerBackend::Candle => {
146 let runtime = CandleThinker::new(&config)?;
147 ThinkerClientBackend::Candle {
148 runtime: Arc::new(Mutex::new(runtime)),
149 }
150 }
151 ThinkerBackend::Bedrock => {
152 let creds = AwsCredentials::from_environment()
153 .ok_or_else(|| anyhow!("Bedrock thinker requires AWS credentials (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY or ~/.aws/credentials)"))?;
154 let provider = BedrockProvider::with_credentials(creds, config.bedrock_region.clone())?;
155 ThinkerClientBackend::Bedrock { provider: Arc::new(provider) }
156 }
157 };
158
159 Ok(Self { config, backend })
160 }
161
162 pub fn config(&self) -> &ThinkerConfig {
163 &self.config
164 }
165
166 pub async fn think(&self, system_prompt: &str, user_prompt: &str) -> Result<ThinkerOutput> {
167 match &self.backend {
168 ThinkerClientBackend::OpenAICompat { http } => {
169 self.think_openai_compat(http, system_prompt, user_prompt)
170 .await
171 }
172 ThinkerClientBackend::Bedrock { provider } => {
173 self.think_bedrock(provider, system_prompt, user_prompt).await
174 }
175 ThinkerClientBackend::Candle { runtime } => {
176 let runtime = Arc::clone(runtime);
177 let system_prompt = system_prompt.to_string();
178 let user_prompt = user_prompt.to_string();
179 tokio::task::spawn_blocking(move || {
180 let mut guard = match runtime.try_lock() {
181 Ok(g) => g,
182 Err(std::sync::TryLockError::WouldBlock) => {
183 return Err(anyhow!("candle thinker is busy"));
184 }
185 Err(std::sync::TryLockError::Poisoned(_)) => {
186 return Err(anyhow!("candle thinker mutex poisoned"));
187 }
188 };
189 guard.think(&system_prompt, &user_prompt)
190 })
191 .await
192 .context("candle thinker task join failed")?
193 }
194 }
195 }
196
197 async fn think_bedrock(
198 &self,
199 provider: &BedrockProvider,
200 system_prompt: &str,
201 user_prompt: &str,
202 ) -> Result<ThinkerOutput> {
203 let started_at = Instant::now();
204 let model_id = &self.config.model;
205
206 let body = serde_json::json!({
208 "system": [{"text": system_prompt}],
209 "messages": [{
210 "role": "user",
211 "content": [{"text": user_prompt}]
212 }],
213 "inferenceConfig": {
214 "maxTokens": self.config.max_tokens,
215 "temperature": self.config.temperature
216 }
217 });
218
219 let body_bytes = serde_json::to_vec(&body)?;
220 let encoded_model_id = model_id.replace(':', "%3A");
221 let url = format!(
222 "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
223 self.config.bedrock_region, encoded_model_id
224 );
225
226 let response = provider
227 .send_converse_request(&url, &body_bytes)
228 .await
229 .context("Bedrock thinker converse request failed")?;
230
231 let status = response.status();
232 let text = response.text().await.context("Failed to read Bedrock thinker response")?;
233
234 if !status.is_success() {
235 return Err(anyhow!("Bedrock thinker error ({}): {}", status, &text[..text.len().min(500)]));
236 }
237
238 let parsed: serde_json::Value = serde_json::from_str(&text)
239 .context("Failed to parse Bedrock thinker response")?;
240
241 let output_text = parsed["output"]["message"]["content"]
242 .as_array()
243 .and_then(|arr| arr.first())
244 .and_then(|c| c["text"].as_str())
245 .unwrap_or_default()
246 .to_string();
247
248 let usage = &parsed["usage"];
249 let prompt_tokens = usage["inputTokens"].as_u64().map(|v| v as u32);
250 let completion_tokens = usage["outputTokens"].as_u64().map(|v| v as u32);
251
252 tracing::debug!(
253 model = model_id,
254 latency_ms = started_at.elapsed().as_millis(),
255 prompt_tokens = ?prompt_tokens,
256 completion_tokens = ?completion_tokens,
257 "bedrock thinker generated thought"
258 );
259
260 Ok(ThinkerOutput {
261 model: model_id.clone(),
262 finish_reason: parsed["stopReason"].as_str().map(|s| s.to_string()),
263 text: output_text,
264 prompt_tokens,
265 completion_tokens,
266 total_tokens: prompt_tokens.zip(completion_tokens).map(|(p, c)| p + c),
267 })
268 }
269
270 async fn think_openai_compat(
271 &self,
272 http: &Client,
273 system_prompt: &str,
274 user_prompt: &str,
275 ) -> Result<ThinkerOutput> {
276 let started_at = Instant::now();
277 let body = OpenAIChatRequest {
278 model: self.config.model.clone(),
279 messages: vec![
280 OpenAIMessage {
281 role: "system".to_string(),
282 content: system_prompt.to_string(),
283 },
284 OpenAIMessage {
285 role: "user".to_string(),
286 content: user_prompt.to_string(),
287 },
288 ],
289 temperature: self.config.temperature,
290 top_p: self.config.top_p,
291 max_tokens: self.config.max_tokens,
292 stream: false,
293 };
294
295 let max_attempts: u32 = 2;
297 let mut last_err: Option<anyhow::Error> = None;
298
299 for attempt in 0..max_attempts {
300 if attempt > 0 {
301 tokio::time::sleep(Duration::from_millis(500 * attempt as u64)).await;
302 tracing::debug!(attempt, "retrying thinker HTTP request");
303 }
304
305 let mut request = http.post(&self.config.endpoint).json(&body);
306 if let Some(key) = self.config.api_key.as_ref() {
307 request = request.bearer_auth(key);
308 }
309
310 let response = match request.send().await {
311 Ok(resp) => resp,
312 Err(e) => {
313 if is_transient_reqwest_error(&e) {
314 tracing::warn!(attempt, error = %e, "thinker HTTP request failed (transient)");
315 last_err =
316 Some(anyhow::Error::from(e).context("transient thinker send error"));
317 continue;
318 }
319 return Err(anyhow::Error::from(e).context("non-transient thinker send error"));
320 }
321 };
322
323 let status = response.status();
324 if is_transient_http_error(status.as_u16()) {
325 let body_text = response.text().await.unwrap_or_default();
326 tracing::warn!(attempt, status = %status, "thinker received transient HTTP error");
327 last_err = Some(anyhow!(
328 "thinker request failed with status {}: {}",
329 status,
330 body_text
331 ));
332 continue;
333 }
334
335 if !status.is_success() {
336 let body_text = response
337 .text()
338 .await
339 .unwrap_or_else(|_| "<empty>".to_string());
340 return Err(anyhow!(
341 "thinker request failed with status {}: {}",
342 status,
343 body_text
344 ));
345 }
346
347 let payload: OpenAIChatResponse = response
348 .json()
349 .await
350 .context("failed to decode thinker response")?;
351 let choice = payload
352 .choices
353 .first()
354 .ok_or_else(|| anyhow!("thinker response did not include choices"))?;
355 let text = choice.message.extract_text();
356 let usage = payload.usage.unwrap_or_default();
357
358 let output = ThinkerOutput {
359 model: payload.model.unwrap_or_else(|| self.config.model.clone()),
360 finish_reason: choice.finish_reason.clone(),
361 text,
362 prompt_tokens: usage.prompt_tokens,
363 completion_tokens: usage.completion_tokens,
364 total_tokens: usage.total_tokens,
365 };
366
367 tracing::debug!(
368 model = %output.model,
369 latency_ms = started_at.elapsed().as_millis(),
370 prompt_tokens = ?output.prompt_tokens,
371 completion_tokens = ?output.completion_tokens,
372 attempt,
373 "openai-compat thinker generated thought"
374 );
375
376 return Ok(output);
377 }
378
379 Err(last_err.unwrap_or_else(|| {
380 anyhow!("thinker HTTP request failed after {max_attempts} attempts")
381 }))
382 }
383}
384
/// Local GGUF-based generation runtime; one instance per loaded model.
/// Not internally synchronized — `ThinkerClient` wraps it in a `Mutex`.
pub(crate) struct CandleThinker {
    model: CandleModel,
    tokenizer: Tokenizer,
    device: Device,
    // Human-readable "candle:<arch>:<device>@<path>" label for logs/output.
    model_label: String,
    // Lowercased architecture name; drives prompt templating and loading.
    architecture: String,
    // Maximum sequence length; prompts are truncated to fit under this.
    context_window: usize,
    temperature: f32,
    top_p: Option<f32>,
    // Hard cap on generated tokens per request (at least 1).
    max_tokens: usize,
    repeat_penalty: f32,
    // Window of trailing tokens the repeat penalty considers.
    repeat_last_n: usize,
    // Base RNG seed; combined with request_index per request.
    seed: u64,
    // Monotonic counter so successive requests sample differently.
    request_index: u64,
    // Token ids that terminate generation with finish_reason "stop".
    eos_token_ids: HashSet<u32>,
}
401
/// The loaded quantized model, one variant per supported architecture.
/// Gemma support is compiled in only with the `functiongemma` feature.
enum CandleModel {
    Llama(quantized_llama::ModelWeights),
    Qwen2(quantized_qwen2::ModelWeights),
    #[cfg(feature = "functiongemma")]
    Gemma3(quantized_gemma3::ModelWeights),
}
408
impl CandleModel {
    /// Runs one forward pass over `x` starting at cache position
    /// `index_pos`, dispatching to the loaded architecture's weights.
    fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        match self {
            Self::Llama(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen2(model) => Ok(model.forward(x, index_pos)?),
            #[cfg(feature = "functiongemma")]
            Self::Gemma3(model) => Ok(model.forward(x, index_pos)?),
        }
    }
}
419
impl CandleThinker {
    /// Loads the GGUF model and tokenizer named in `config` and prepares a
    /// reusable generation runtime.
    ///
    /// # Errors
    /// Fails when the model/tokenizer paths are missing from `config`, the
    /// files cannot be opened or parsed, the architecture is unsupported, or
    /// an explicitly requested CUDA device cannot be initialized.
    pub(crate) fn new(config: &ThinkerConfig) -> Result<Self> {
        let model_path = config.candle_model_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_MODEL_PATH")
        })?;
        let tokenizer_path = config.candle_tokenizer_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_TOKENIZER_PATH")
        })?;

        let (device, device_label) = select_candle_device(config)?;
        let mut reader = BufReader::new(
            File::open(model_path)
                .with_context(|| format!("failed to open candle model file at {}", model_path))?,
        );
        // Reads the GGUF header/metadata only; tensor data is pulled from the
        // same reader later by the architecture-specific from_gguf call.
        let content = gguf_file::Content::read(&mut reader)
            .with_context(|| format!("failed to parse gguf model metadata from {}", model_path))?;

        // Architecture resolution: explicit config override, then GGUF
        // metadata, then "llama" as the default.
        let architecture = config
            .candle_arch
            .clone()
            .or_else(|| {
                content
                    .metadata
                    .get("general.architecture")
                    .and_then(|v| v.to_string().ok())
                    .cloned()
            })
            .unwrap_or_else(|| "llama".to_string())
            .to_ascii_lowercase();

        let context_window = detect_context_window(&content, &architecture).unwrap_or(4096);
        let model_label = format!("candle:{}:{}@{}", architecture, device_label, model_path);

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer from {}: {}", tokenizer_path, e))?;

        // Extract EOS ids now: `content` is moved into from_gguf below.
        let gguf_eos_ids = extract_gguf_eos_ids(&content);

        let model = match architecture.as_str() {
            "llama" => CandleModel::Llama(
                quantized_llama::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load llama gguf from {}", model_path))?,
            ),
            "qwen2" => CandleModel::Qwen2(
                quantized_qwen2::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen2 gguf from {}", model_path))?,
            ),
            #[cfg(feature = "functiongemma")]
            "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => CandleModel::Gemma3(
                quantized_gemma3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load gemma3 gguf from {}", model_path))?,
            ),
            other => {
                // Targeted hint when gemma was requested but the feature that
                // provides it is compiled out.
                #[cfg(not(feature = "functiongemma"))]
                if matches!(other, "gemma" | "gemma2" | "gemma3" | "gemma-embedding") {
                    return Err(anyhow!(
                        "gemma architecture '{}' requires the 'functiongemma' feature; rebuild with --features functiongemma",
                        other
                    ));
                }
                return Err(anyhow!(
                    "unsupported candle architecture '{}' (supported: llama, qwen2{})",
                    other,
                    if cfg!(feature = "functiongemma") {
                        ", gemma/gemma2/gemma3"
                    } else {
                        ""
                    }
                ));
            }
        };

        let eos_token_ids: HashSet<u32> = collect_eos_token_ids(&tokenizer, &gguf_eos_ids);
        if eos_token_ids.is_empty() {
            tracing::warn!(
                "No EOS tokens found in tokenizer; generation will stop on max token limit"
            );
        }

        Ok(Self {
            model,
            tokenizer,
            device,
            model_label,
            architecture,
            context_window,
            temperature: config.temperature,
            top_p: config.top_p,
            // Floors keep degenerate configs (0 tokens, <1.0 penalty) sane.
            max_tokens: config.max_tokens.max(1),
            repeat_penalty: config.candle_repeat_penalty.max(1.0),
            repeat_last_n: config.candle_repeat_last_n.max(1),
            seed: config.candle_seed,
            request_index: 0,
            eos_token_ids,
        })
    }

    /// Generates a completion for the prompt pair, blocking the calling
    /// thread until done (the async caller wraps this in spawn_blocking).
    pub(crate) fn think(
        &mut self,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let prompt = format_chat_prompt(&self.architecture, system_prompt, user_prompt);
        let encoding = self
            .tokenizer
            .encode(prompt.as_str(), true)
            .map_err(|e| anyhow!("tokenizer encode failed: {}", e))?;
        let mut tokens = encoding.get_ids().to_vec();
        if tokens.is_empty() {
            return Err(anyhow!("tokenizer produced an empty prompt token set"));
        }

        // Prompt longer than the context window: keep the system portion
        // intact and splice in as much of the tail of the full prompt as
        // fits, reserving 8 tokens of headroom.
        if self.context_window > 8 && tokens.len() >= self.context_window {
            let system_only = format_chat_prompt(&self.architecture, system_prompt, "");
            let sys_encoding = self
                .tokenizer
                .encode(system_only.as_str(), true)
                .map_err(|e| anyhow!("tokenizer encode failed (system): {}", e))?;
            let sys_len = sys_encoding.get_ids().len();
            let budget = self.context_window.saturating_sub(8);
            if sys_len < budget {
                let tail_budget = budget.saturating_sub(sys_len);
                let tail_start = tokens.len().saturating_sub(tail_budget);
                let mut truncated = sys_encoding.get_ids().to_vec();
                truncated.extend_from_slice(&tokens[tail_start..]);
                tokens = truncated;
            } else {
                // System prompt alone exceeds the budget; keep only the tail.
                let keep = budget;
                tokens = tokens[tokens.len().saturating_sub(keep)..].to_vec();
            }
        }
        let prompt_token_count = tokens.len() as u32;

        // Per-request seed: deterministic per process, but varies between
        // successive requests so retries don't repeat the same sample path.
        let request_seed = self.seed.wrapping_add(self.request_index);
        self.request_index = self.request_index.wrapping_add(1);
        let mut logits_processor = LogitsProcessor::new(
            request_seed,
            Some(self.temperature as f64),
            self.top_p.map(|v| v as f64),
        );

        let mut index_pos = 0usize;
        let mut generated: Vec<u32> = Vec::with_capacity(self.max_tokens);
        let mut finish_reason = "length".to_string();

        for _ in 0..self.max_tokens {
            // First pass feeds the whole prompt; afterwards only the newest
            // token is fed and the model's cache carries the prefix.
            let ctxt: &[u32] = if index_pos == 0 {
                tokens.as_slice()
            } else {
                &tokens[tokens.len() - 1..]
            };

            let input = Tensor::new(ctxt, &self.device)?
                .unsqueeze(0)
                .context("failed to create candle input tensor")?;
            let mut logits = self
                .model
                .forward(&input, index_pos)
                .context("candle model forward failed")?;
            index_pos += ctxt.len();
            logits = logits
                .squeeze(0)
                .context("failed to squeeze logits batch dimension")?;

            // Penalize tokens seen in the recent window to curb repetition.
            let logits = if self.repeat_penalty > 1.0 {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                apply_repeat_penalty(&logits, self.repeat_penalty, &tokens[start_at..])
                    .context("failed to apply repeat penalty")?
            } else {
                logits
            };

            let next_token = logits_processor
                .sample(&logits)
                .context("token sampling failed")?;
            // EOS is never pushed to `generated`, so it won't appear in the
            // decoded output.
            if self.eos_token_ids.contains(&next_token) {
                finish_reason = "stop".to_string();
                break;
            }

            tokens.push(next_token);
            generated.push(next_token);

            if tokens.len() + 1 >= self.context_window {
                finish_reason = "length".to_string();
                break;
            }
        }

        let text = self
            .tokenizer
            .decode(&generated, true)
            .map_err(|e| anyhow!("tokenizer decode failed: {}", e))?;
        let completion_tokens = generated.len() as u32;

        tracing::debug!(
            model = %self.model_label,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = prompt_token_count,
            completion_tokens = completion_tokens,
            "candle thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: self.model_label.clone(),
            finish_reason: Some(finish_reason),
            text,
            prompt_tokens: Some(prompt_token_count),
            completion_tokens: Some(completion_tokens),
            total_tokens: Some(prompt_token_count + completion_tokens),
        })
    }
}
638
/// Renders a (system, user) prompt pair into the chat template expected by
/// `architecture`. Unknown architectures fall back to a plain labeled-text
/// template.
fn format_chat_prompt(architecture: &str, system_prompt: &str, user_prompt: &str) -> String {
    let (system, user) = (system_prompt, user_prompt);
    match architecture {
        // ChatML template used by Qwen2 models.
        "qwen2" => format!(
            "<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"
        ),
        // Llama 3 header-tag template.
        "llama" => format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        ),
        // Gemma's template has no system role, so the system text is folded
        // into the user turn.
        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => format!(
            "<start_of_turn>user\n{system}\n\n{user}<end_of_turn>\n<start_of_turn>model\n"
        ),
        _ => format!("System:\n{system}\n\nUser:\n{user}\n\nAssistant:\n"),
    }
}
668
669fn select_candle_device(config: &ThinkerConfig) -> Result<(Device, String)> {
670 match config.candle_device {
671 CandleDevicePreference::Cpu => Ok((Device::Cpu, "cpu".to_string())),
672 CandleDevicePreference::Cuda => {
673 let device = try_cuda_device(config.candle_cuda_ordinal)?;
674 Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
675 }
676 CandleDevicePreference::Auto => match try_cuda_device(config.candle_cuda_ordinal) {
677 Ok(device) => {
678 tracing::info!(
679 ordinal = config.candle_cuda_ordinal,
680 "Candle thinker selected CUDA device"
681 );
682 Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
683 }
684 Err(error) => {
685 tracing::warn!(
686 %error,
687 "CUDA unavailable for Candle thinker, falling back to CPU"
688 );
689 Ok((Device::Cpu, "cpu".to_string()))
690 }
691 },
692 }
693}
694
/// Attempts to initialize the requested CUDA device (candle-cuda builds).
#[cfg(feature = "candle-cuda")]
fn try_cuda_device(ordinal: usize) -> Result<Device> {
    Device::new_cuda(ordinal)
        .with_context(|| format!("failed to initialize CUDA device ordinal {}", ordinal))
}
700
/// Stub for builds without CUDA support: always errors, so `Auto` falls
/// back to CPU and an explicit `Cuda` preference fails with a clear hint.
#[cfg(not(feature = "candle-cuda"))]
fn try_cuda_device(_ordinal: usize) -> Result<Device> {
    Err(anyhow!(
        "candle-cuda feature is not enabled in this build; rebuild with --features candle-cuda"
    ))
}
707
708fn detect_context_window(content: &gguf_file::Content, architecture: &str) -> Option<usize> {
709 let key = match architecture {
710 "qwen2" => "qwen2.context_length",
711 "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => {
712 for prefix in ["gemma3", "gemma2", "gemma"] {
714 let k = format!("{prefix}.context_length");
715 if let Some(v) = content.metadata.get(&k) {
716 return v.to_u32().ok().map(|v| v as usize);
717 }
718 }
719 return None;
720 }
721 _ => "llama.context_length",
722 };
723 content
724 .metadata
725 .get(key)
726 .and_then(|v| v.to_u32().ok())
727 .map(|v| v as usize)
728}
729
730fn extract_gguf_eos_ids(content: &gguf_file::Content) -> Vec<u32> {
732 let mut ids = Vec::new();
733 for key in ["tokenizer.ggml.eos_token_id", "tokenizer.ggml.eot_token_id"] {
734 if let Some(v) = content.metadata.get(key) {
735 if let Ok(id) = v.to_u32() {
736 if !ids.contains(&id) {
737 ids.push(id);
738 }
739 }
740 }
741 }
742 ids
743}
744
745fn collect_eos_token_ids(tokenizer: &Tokenizer, gguf_eos_ids: &[u32]) -> HashSet<u32> {
746 let mut ids: HashSet<u32> = gguf_eos_ids.iter().copied().collect();
747
748 let candidates = [
750 "<|im_end|>",
751 "<|eot_id|>",
752 "<|endoftext|>",
753 "</s>",
754 "<|end|>",
755 "<end_of_turn>",
756 ];
757 for token in candidates {
758 if let Some(id) = tokenizer.token_to_id(token) {
759 ids.insert(id);
760 }
761 }
762 ids
763}
764
/// True for HTTP status codes worth retrying: 429 (rate limited) and the
/// upstream/gateway trio 502/503/504.
fn is_transient_http_error(status: u16) -> bool {
    status == 429 || status == 502 || status == 503 || status == 504
}
769
/// True for reqwest errors that typically succeed on retry: timeouts,
/// connection failures, and errors raised while sending the request.
fn is_transient_reqwest_error(e: &reqwest::Error) -> bool {
    e.is_timeout() || e.is_connect() || e.is_request()
}
774
/// Request payload for an OpenAI-compatible /chat/completions endpoint.
#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    temperature: f32,
    // Omitted from JSON when unset so the server applies its own default.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    max_tokens: usize,
    stream: bool,
}
785
/// One chat message in the outgoing request ("system" or "user" here).
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}
791
/// Minimal view of an OpenAI-compatible chat completion response.
#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    model: Option<String>,
    choices: Vec<OpenAIChatChoice>,
    #[serde(default)]
    usage: Option<OpenAIUsage>,
}
799
/// A single completion choice; only the first one is consumed.
#[derive(Debug, Deserialize)]
struct OpenAIChatChoice {
    message: OpenAIChatChoiceMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}
806
/// Assistant message body. Servers usually put the answer in `content`;
/// reasoning-model servers may use `reasoning`/`reasoning_content` instead.
#[derive(Debug, Deserialize)]
struct OpenAIChatChoiceMessage {
    #[serde(default)]
    content: Option<OpenAIChatContent>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
}
816
/// Token usage block; every field is optional since servers vary.
#[derive(Debug, Default, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    total_tokens: Option<u32>,
}
823
/// `content` may arrive as a bare string, a list of parts, or a single part
/// object depending on the server; `untagged` tries each shape in order.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAIChatContent {
    Text(String),
    Parts(Vec<OpenAIChatContentPart>),
    Part(OpenAIChatContentPart),
}
831
/// One element of a multi-part content payload.
#[derive(Debug, Deserialize)]
struct OpenAIChatContentPart {
    // Part type, e.g. "text" or "output_text"; absent is treated as text.
    #[serde(rename = "type")]
    kind: Option<String>,
    #[serde(default)]
    text: Option<String>,
    // Some servers put the payload under `content` instead of `text`.
    #[serde(default)]
    content: Option<String>,
}
841
842impl OpenAIChatChoiceMessage {
843 fn extract_text(&self) -> String {
844 let content_text = self
845 .content
846 .as_ref()
847 .map(OpenAIChatContent::to_text)
848 .unwrap_or_default();
849 if !content_text.trim().is_empty() {
850 return content_text;
851 }
852
853 if let Some(reasoning) = self
854 .reasoning
855 .as_deref()
856 .filter(|text| !text.trim().is_empty())
857 {
858 return reasoning.to_string();
859 }
860
861 self.reasoning_content
862 .as_deref()
863 .filter(|text| !text.trim().is_empty())
864 .unwrap_or_default()
865 .to_string()
866 }
867}
868
869impl OpenAIChatContent {
870 fn to_text(&self) -> String {
871 match self {
872 Self::Text(text) => text.clone(),
873 Self::Parts(parts) => parts
874 .iter()
875 .filter_map(OpenAIChatContentPart::text_fragment)
876 .collect::<Vec<_>>()
877 .join("\n"),
878 Self::Part(part) => part.text_fragment().unwrap_or_default(),
879 }
880 }
881}
882
883impl OpenAIChatContentPart {
884 fn text_fragment(&self) -> Option<String> {
885 if let Some(kind) = self.kind.as_deref()
886 && !kind.eq_ignore_ascii_case("text")
887 && !kind.eq_ignore_ascii_case("output_text")
888 {
889 return None;
890 }
891
892 self.text
893 .as_deref()
894 .or(self.content.as_deref())
895 .map(str::trim)
896 .filter(|text| !text.is_empty())
897 .map(ToString::to_string)
898 }
899}