use anyhow::{Context, Result, anyhow};
use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
#[cfg(feature = "functiongemma")]
use candle_transformers::models::quantized_gemma3;
use candle_transformers::models::{quantized_llama, quantized_qwen2};
use candle_transformers::utils::apply_repeat_penalty;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fs::File;
use std::io::BufReader;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokenizers::Tokenizer;

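/// Which inference backend the cognition thinker uses: a remote
/// OpenAI-compatible HTTP endpoint or an in-process Candle GGUF runtime.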
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkerBackend {
    OpenAICompat,
    Candle,
}

impl ThinkerBackend {
    pub fn from_env(value: &str) -> Self {
        match value.trim().to_ascii_lowercase().as_str() {
            "candle" => Self::Candle,
            "openai" | "openai_compat" | "openai-compatible" | "http" => Self::OpenAICompat,
            _ => Self::OpenAICompat,
        }
    }
}

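/// Device preference for the Candle backend; `Auto` tries CUDA first and
/// falls back to CPU when it is unavailable.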
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandleDevicePreference {
    Auto,
    Cpu,
    Cuda,
}

impl CandleDevicePreference {
    pub fn from_env(value: &str) -> Self {
        match value.trim().to_ascii_lowercase().as_str() {
            "cpu" => Self::Cpu,
            "cuda" | "gpu" => Self::Cuda,
            _ => Self::Auto,
        }
    }
}

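/// Runtime configuration for the thinker, typically populated from environment
/// variables upstream.
///
/// A minimal construction sketch (paths and model names are illustrative only):
///
/// ```ignore
/// let config = ThinkerConfig {
///     enabled: true,
///     backend: ThinkerBackend::Candle,
///     candle_model_path: Some("/models/model-q4_k_m.gguf".into()),
///     candle_tokenizer_path: Some("/models/tokenizer.json".into()),
///     ..ThinkerConfig::default()
/// };
/// ```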
#[derive(Debug, Clone)]
pub struct ThinkerConfig {
    pub enabled: bool,
    pub backend: ThinkerBackend,
    pub endpoint: String,
    pub model: String,
    pub api_key: Option<String>,
    pub temperature: f32,
    pub top_p: Option<f32>,
    pub max_tokens: usize,
    pub timeout_ms: u64,
    pub candle_model_path: Option<String>,
    pub candle_tokenizer_path: Option<String>,
    pub candle_arch: Option<String>,
    pub candle_device: CandleDevicePreference,
    pub candle_cuda_ordinal: usize,
    pub candle_repeat_penalty: f32,
    pub candle_repeat_last_n: usize,
    pub candle_seed: u64,
}

impl Default for ThinkerConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            backend: ThinkerBackend::OpenAICompat,
            endpoint: "http://127.0.0.1:11434/v1/chat/completions".to_string(),
            model: "qwen2.5:3b-instruct".to_string(),
            api_key: None,
            temperature: 0.2,
            top_p: None,
            max_tokens: 256,
            timeout_ms: 30_000,
            candle_model_path: None,
            candle_tokenizer_path: None,
            candle_arch: None,
            candle_device: CandleDevicePreference::Auto,
            candle_cuda_ordinal: 0,
            candle_repeat_penalty: 1.1,
            candle_repeat_last_n: 64,
            candle_seed: 42,
        }
    }
}

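/// A single generated "thought" plus the metadata needed for logging and
/// budgeting (model label, finish reason, token usage when available).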
#[derive(Debug, Clone)]
pub struct ThinkerOutput {
    pub model: String,
    pub finish_reason: Option<String>,
    pub text: String,
    pub prompt_tokens: Option<u32>,
    pub completion_tokens: Option<u32>,
    pub total_tokens: Option<u32>,
}

#[derive(Clone)]
pub struct ThinkerClient {
    config: ThinkerConfig,
    backend: ThinkerClientBackend,
}

impl std::fmt::Debug for ThinkerClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ThinkerClient")
            .field("backend", &self.config.backend)
            .field("model", &self.config.model)
            .finish()
    }
}

#[derive(Clone)]
enum ThinkerClientBackend {
    OpenAICompat { http: Client },
    Candle { runtime: Arc<Mutex<CandleThinker>> },
}

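// Construction and use (a sketch; assumes an async context and prompts of the
// caller's choosing):
//
//     let client = ThinkerClient::new(ThinkerConfig::default())?;
//     let thought = client.think("system prompt", "user prompt").await?;
//     tracing::info!(text = %thought.text, "thinker replied");
//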
impl ThinkerClient {
    pub fn new(config: ThinkerConfig) -> Result<Self> {
        let backend = match config.backend {
            ThinkerBackend::OpenAICompat => {
                let timeout = Duration::from_millis(config.timeout_ms.max(1_000));
                let http = Client::builder()
                    .timeout(timeout)
                    .build()
                    .context("failed to build thinker HTTP client")?;
                ThinkerClientBackend::OpenAICompat { http }
            }
            ThinkerBackend::Candle => {
                let runtime = CandleThinker::new(&config)?;
                ThinkerClientBackend::Candle {
                    runtime: Arc::new(Mutex::new(runtime)),
                }
            }
        };

        Ok(Self { config, backend })
    }

    pub fn config(&self) -> &ThinkerConfig {
        &self.config
    }

    pub async fn think(&self, system_prompt: &str, user_prompt: &str) -> Result<ThinkerOutput> {
        match &self.backend {
            ThinkerClientBackend::OpenAICompat { http } => {
                self.think_openai_compat(http, system_prompt, user_prompt)
                    .await
            }
            ThinkerClientBackend::Candle { runtime } => {
                let runtime = Arc::clone(runtime);
                let system_prompt = system_prompt.to_string();
                let user_prompt = user_prompt.to_string();
                // Candle generation is synchronous and compute-bound, so run it
                // off the async executor and refuse to queue behind an in-flight
                // request instead of blocking on the mutex.
                tokio::task::spawn_blocking(move || {
                    let mut guard = match runtime.try_lock() {
                        Ok(g) => g,
                        Err(std::sync::TryLockError::WouldBlock) => {
                            return Err(anyhow!("candle thinker is busy"));
                        }
                        Err(std::sync::TryLockError::Poisoned(_)) => {
                            return Err(anyhow!("candle thinker mutex poisoned"));
                        }
                    };
                    guard.think(&system_prompt, &user_prompt)
                })
                .await
                .context("candle thinker task join failed")?
            }
        }
    }

    async fn think_openai_compat(
        &self,
        http: &Client,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let body = OpenAIChatRequest {
            model: self.config.model.clone(),
            messages: vec![
                OpenAIMessage {
                    role: "system".to_string(),
                    content: system_prompt.to_string(),
                },
                OpenAIMessage {
                    role: "user".to_string(),
                    content: user_prompt.to_string(),
                },
            ],
            temperature: self.config.temperature,
            top_p: self.config.top_p,
            max_tokens: self.config.max_tokens,
            stream: false,
        };

        // Retry transient failures (timeouts, 429/5xx) with a short linear backoff.
        let max_attempts: u32 = 2;
        let mut last_err: Option<anyhow::Error> = None;

        for attempt in 0..max_attempts {
            if attempt > 0 {
                tokio::time::sleep(Duration::from_millis(500 * attempt as u64)).await;
                tracing::debug!(attempt, "retrying thinker HTTP request");
            }

            let mut request = http.post(&self.config.endpoint).json(&body);
            if let Some(key) = self.config.api_key.as_ref() {
                request = request.bearer_auth(key);
            }

            let response = match request.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    if is_transient_reqwest_error(&e) {
                        tracing::warn!(attempt, error = %e, "thinker HTTP request failed (transient)");
                        last_err =
                            Some(anyhow::Error::from(e).context("transient thinker send error"));
                        continue;
                    }
                    return Err(anyhow::Error::from(e).context("non-transient thinker send error"));
                }
            };

            let status = response.status();
            if is_transient_http_error(status.as_u16()) {
                let body_text = response.text().await.unwrap_or_default();
                tracing::warn!(attempt, status = %status, "thinker received transient HTTP error");
                last_err = Some(anyhow!(
                    "thinker request failed with status {}: {}",
                    status,
                    body_text
                ));
                continue;
            }

            if !status.is_success() {
                let body_text = response
                    .text()
                    .await
                    .unwrap_or_else(|_| "<empty>".to_string());
                return Err(anyhow!(
                    "thinker request failed with status {}: {}",
                    status,
                    body_text
                ));
            }

            let payload: OpenAIChatResponse = response
                .json()
                .await
                .context("failed to decode thinker response")?;
            let choice = payload
                .choices
                .first()
                .ok_or_else(|| anyhow!("thinker response did not include choices"))?;
            let text = choice.message.extract_text();
            let usage = payload.usage.unwrap_or_default();

            let output = ThinkerOutput {
                model: payload.model.unwrap_or_else(|| self.config.model.clone()),
                finish_reason: choice.finish_reason.clone(),
                text,
                prompt_tokens: usage.prompt_tokens,
                completion_tokens: usage.completion_tokens,
                total_tokens: usage.total_tokens,
            };

            tracing::debug!(
                model = %output.model,
                latency_ms = started_at.elapsed().as_millis(),
                prompt_tokens = ?output.prompt_tokens,
                completion_tokens = ?output.completion_tokens,
                attempt,
                "openai-compat thinker generated thought"
            );

            return Ok(output);
        }

        Err(last_err.unwrap_or_else(|| {
            anyhow!("thinker HTTP request failed after {max_attempts} attempts")
        }))
    }
}

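/// Local GGUF inference state for the Candle backend. Construction loads the
/// model and tokenizer eagerly; `think` runs a blocking sampling loop, so the
/// client wraps this in `Arc<Mutex<_>>` and calls it via `spawn_blocking`.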
pub(crate) struct CandleThinker {
    model: CandleModel,
    tokenizer: Tokenizer,
    device: Device,
    model_label: String,
    architecture: String,
    context_window: usize,
    temperature: f32,
    top_p: Option<f32>,
    max_tokens: usize,
    repeat_penalty: f32,
    repeat_last_n: usize,
    seed: u64,
    request_index: u64,
    eos_token_ids: HashSet<u32>,
}

enum CandleModel {
    Llama(quantized_llama::ModelWeights),
    Qwen2(quantized_qwen2::ModelWeights),
    #[cfg(feature = "functiongemma")]
    Gemma3(quantized_gemma3::ModelWeights),
}

impl CandleModel {
    fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        match self {
            Self::Llama(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen2(model) => Ok(model.forward(x, index_pos)?),
            #[cfg(feature = "functiongemma")]
            Self::Gemma3(model) => Ok(model.forward(x, index_pos)?),
        }
    }
}

impl CandleThinker {
    pub(crate) fn new(config: &ThinkerConfig) -> Result<Self> {
        let model_path = config.candle_model_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_MODEL_PATH")
        })?;
        let tokenizer_path = config.candle_tokenizer_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_TOKENIZER_PATH")
        })?;

        let (device, device_label) = select_candle_device(config)?;
        let mut reader = BufReader::new(
            File::open(model_path)
                .with_context(|| format!("failed to open candle model file at {}", model_path))?,
        );
        let content = gguf_file::Content::read(&mut reader)
            .with_context(|| format!("failed to parse gguf model metadata from {}", model_path))?;

        let architecture = config
            .candle_arch
            .clone()
            .or_else(|| {
                content
                    .metadata
                    .get("general.architecture")
                    .and_then(|v| v.to_string().ok())
                    .cloned()
            })
            .unwrap_or_else(|| "llama".to_string())
            .to_ascii_lowercase();

        let context_window = detect_context_window(&content, &architecture).unwrap_or(4096);
        let model_label = format!("candle:{}:{}@{}", architecture, device_label, model_path);

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer from {}: {}", tokenizer_path, e))?;

        // Capture EOS ids from the GGUF metadata before `content` is consumed below.
        let gguf_eos_ids = extract_gguf_eos_ids(&content);

        let model = match architecture.as_str() {
            "llama" => CandleModel::Llama(
                quantized_llama::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load llama gguf from {}", model_path))?,
            ),
            "qwen2" => CandleModel::Qwen2(
                quantized_qwen2::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen2 gguf from {}", model_path))?,
            ),
            #[cfg(feature = "functiongemma")]
            "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => CandleModel::Gemma3(
                quantized_gemma3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load gemma3 gguf from {}", model_path))?,
            ),
            other => {
                #[cfg(not(feature = "functiongemma"))]
                if matches!(other, "gemma" | "gemma2" | "gemma3" | "gemma-embedding") {
                    return Err(anyhow!(
                        "gemma architecture '{}' requires the 'functiongemma' feature; rebuild with --features functiongemma",
                        other
                    ));
                }
                return Err(anyhow!(
                    "unsupported candle architecture '{}' (supported: llama, qwen2{})",
                    other,
                    if cfg!(feature = "functiongemma") {
                        ", gemma/gemma2/gemma3"
                    } else {
                        ""
                    }
                ));
            }
        };

        let eos_token_ids: HashSet<u32> = collect_eos_token_ids(&tokenizer, &gguf_eos_ids);
        if eos_token_ids.is_empty() {
            tracing::warn!(
                "No EOS tokens found in tokenizer; generation will stop on max token limit"
            );
        }

        Ok(Self {
            model,
            tokenizer,
            device,
            model_label,
            architecture,
            context_window,
            temperature: config.temperature,
            top_p: config.top_p,
            max_tokens: config.max_tokens.max(1),
            repeat_penalty: config.candle_repeat_penalty.max(1.0),
            repeat_last_n: config.candle_repeat_last_n.max(1),
            seed: config.candle_seed,
            request_index: 0,
            eos_token_ids,
        })
    }

    pub(crate) fn think(
        &mut self,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let prompt = format_chat_prompt(&self.architecture, system_prompt, user_prompt);
        let encoding = self
            .tokenizer
            .encode(prompt.as_str(), true)
            .map_err(|e| anyhow!("tokenizer encode failed: {}", e))?;
        let mut tokens = encoding.get_ids().to_vec();
        if tokens.is_empty() {
            return Err(anyhow!("tokenizer produced an empty prompt token set"));
        }

        // If the prompt overflows the context window, keep the full system
        // portion (when it fits) and the most recent tail of the user portion.
        if self.context_window > 8 && tokens.len() >= self.context_window {
            let system_only = format_chat_prompt(&self.architecture, system_prompt, "");
            let sys_encoding = self
                .tokenizer
                .encode(system_only.as_str(), true)
                .map_err(|e| anyhow!("tokenizer encode failed (system): {}", e))?;
            let sys_len = sys_encoding.get_ids().len();
            let budget = self.context_window.saturating_sub(8);
            if sys_len < budget {
                // The system prompt fits: splice it together with the tail of the full prompt.
                let tail_budget = budget.saturating_sub(sys_len);
                let tail_start = tokens.len().saturating_sub(tail_budget);
                let mut truncated = sys_encoding.get_ids().to_vec();
                truncated.extend_from_slice(&tokens[tail_start..]);
                tokens = truncated;
            } else {
                // The system prompt alone exceeds the budget: keep only the tail.
                let keep = budget;
                tokens = tokens[tokens.len().saturating_sub(keep)..].to_vec();
            }
        }
        let prompt_token_count = tokens.len() as u32;

        let request_seed = self.seed.wrapping_add(self.request_index);
        self.request_index = self.request_index.wrapping_add(1);
        let mut logits_processor = LogitsProcessor::new(
            request_seed,
            Some(self.temperature as f64),
            self.top_p.map(|v| v as f64),
        );

        let mut index_pos = 0usize;
        let mut generated: Vec<u32> = Vec::with_capacity(self.max_tokens);
        let mut finish_reason = "length".to_string();

        for _ in 0..self.max_tokens {
            // The first pass feeds the whole prompt; afterwards only the newest
            // token is fed and the model's KV cache carries the rest.
            let ctxt: &[u32] = if index_pos == 0 {
                tokens.as_slice()
            } else {
                &tokens[tokens.len() - 1..]
            };

            let input = Tensor::new(ctxt, &self.device)?
                .unsqueeze(0)
                .context("failed to create candle input tensor")?;
            let mut logits = self
                .model
                .forward(&input, index_pos)
                .context("candle model forward failed")?;
            index_pos += ctxt.len();
            logits = logits
                .squeeze(0)
                .context("failed to squeeze logits batch dimension")?;

            let logits = if self.repeat_penalty > 1.0 {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                apply_repeat_penalty(&logits, self.repeat_penalty, &tokens[start_at..])
                    .context("failed to apply repeat penalty")?
            } else {
                logits
            };

            let next_token = logits_processor
                .sample(&logits)
                .context("token sampling failed")?;
            if self.eos_token_ids.contains(&next_token) {
                finish_reason = "stop".to_string();
                break;
            }

            tokens.push(next_token);
            generated.push(next_token);

            if tokens.len() + 1 >= self.context_window {
                finish_reason = "length".to_string();
                break;
            }
        }

        let text = self
            .tokenizer
            .decode(&generated, true)
            .map_err(|e| anyhow!("tokenizer decode failed: {}", e))?;
        let completion_tokens = generated.len() as u32;

        tracing::debug!(
            model = %self.model_label,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = prompt_token_count,
            completion_tokens = completion_tokens,
            "candle thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: self.model_label.clone(),
            finish_reason: Some(finish_reason),
            text,
            prompt_tokens: Some(prompt_token_count),
            completion_tokens: Some(completion_tokens),
            total_tokens: Some(prompt_token_count + completion_tokens),
        })
    }
}

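// For reference, `format_chat_prompt("qwen2", "Be brief.", "Hi")` produces the
// ChatML string below (prompts are illustrative):
//
//     <|im_start|>system
//     Be brief.<|im_end|>
//     <|im_start|>user
//     Hi<|im_end|>
//     <|im_start|>assistant
//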
fn format_chat_prompt(architecture: &str, system_prompt: &str, user_prompt: &str) -> String {
    // Render the prompt in the chat template expected by each model family.
    match architecture {
        // Qwen2 uses ChatML-style <|im_start|>/<|im_end|> markers.
        "qwen2" => format!(
            "<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
            system = system_prompt,
            user = user_prompt,
        ),
        // Llama 3 instruct header/eot format.
        "llama" => format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
            system = system_prompt,
            user = user_prompt,
        ),
        // Gemma turn format; the system prompt is folded into the user turn.
        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => format!(
            "<start_of_turn>user\n{system}\n\n{user}<end_of_turn>\n<start_of_turn>model\n",
            system = system_prompt,
            user = user_prompt,
        ),
        // Plain-text fallback for unknown architectures.
        _ => format!(
            "System:\n{system}\n\nUser:\n{user}\n\nAssistant:\n",
            system = system_prompt,
            user = user_prompt,
        ),
    }
}

fn select_candle_device(config: &ThinkerConfig) -> Result<(Device, String)> {
    match config.candle_device {
        CandleDevicePreference::Cpu => Ok((Device::Cpu, "cpu".to_string())),
        CandleDevicePreference::Cuda => {
            let device = try_cuda_device(config.candle_cuda_ordinal)?;
            Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
        }
        CandleDevicePreference::Auto => match try_cuda_device(config.candle_cuda_ordinal) {
            Ok(device) => {
                tracing::info!(
                    ordinal = config.candle_cuda_ordinal,
                    "Candle thinker selected CUDA device"
                );
                Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
            }
            Err(error) => {
                tracing::warn!(
                    %error,
                    "CUDA unavailable for Candle thinker, falling back to CPU"
                );
                Ok((Device::Cpu, "cpu".to_string()))
            }
        },
    }
}

#[cfg(feature = "candle-cuda")]
fn try_cuda_device(ordinal: usize) -> Result<Device> {
    Device::new_cuda(ordinal)
        .with_context(|| format!("failed to initialize CUDA device ordinal {}", ordinal))
}

#[cfg(not(feature = "candle-cuda"))]
fn try_cuda_device(_ordinal: usize) -> Result<Device> {
    Err(anyhow!(
        "candle-cuda feature is not enabled in this build; rebuild with --features candle-cuda"
    ))
}

fn detect_context_window(content: &gguf_file::Content, architecture: &str) -> Option<usize> {
    let key = match architecture {
        "qwen2" => "qwen2.context_length",
        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => {
            // Gemma GGUFs vary in which prefix they use for the context-length key.
            for prefix in ["gemma3", "gemma2", "gemma"] {
                let k = format!("{prefix}.context_length");
                if let Some(v) = content.metadata.get(&k) {
                    return v.to_u32().ok().map(|v| v as usize);
                }
            }
            return None;
        }
        _ => "llama.context_length",
    };
    content
        .metadata
        .get(key)
        .and_then(|v| v.to_u32().ok())
        .map(|v| v as usize)
}

// Pull EOS/EOT token ids declared in the GGUF metadata, if any.
fn extract_gguf_eos_ids(content: &gguf_file::Content) -> Vec<u32> {
    let mut ids = Vec::new();
    for key in ["tokenizer.ggml.eos_token_id", "tokenizer.ggml.eot_token_id"] {
        if let Some(v) = content.metadata.get(key) {
            if let Ok(id) = v.to_u32() {
                if !ids.contains(&id) {
                    ids.push(id);
                }
            }
        }
    }
    ids
}

fn collect_eos_token_ids(tokenizer: &Tokenizer, gguf_eos_ids: &[u32]) -> HashSet<u32> {
    let mut ids: HashSet<u32> = gguf_eos_ids.iter().copied().collect();

    // Common end-of-sequence / end-of-turn markers across model families.
    let candidates = [
        "<|im_end|>",
        "<|eot_id|>",
        "<|endoftext|>",
        "</s>",
        "<|end|>",
        "<end_of_turn>",
    ];
    for token in candidates {
        if let Some(id) = tokenizer.token_to_id(token) {
            ids.insert(id);
        }
    }
    ids
}

// HTTP statuses worth retrying: rate limiting and upstream/gateway failures.
fn is_transient_http_error(status: u16) -> bool {
    matches!(status, 429 | 502 | 503 | 504)
}

// Network-level failures that a retry has a reasonable chance of fixing.
fn is_transient_reqwest_error(e: &reqwest::Error) -> bool {
    e.is_timeout() || e.is_connect() || e.is_request()
}

#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    temperature: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    max_tokens: usize,
    stream: bool,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    model: Option<String>,
    choices: Vec<OpenAIChatChoice>,
    #[serde(default)]
    usage: Option<OpenAIUsage>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatChoice {
    message: OpenAIChatChoiceMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatChoiceMessage {
    #[serde(default)]
    content: Option<OpenAIChatContent>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
}

#[derive(Debug, Default, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    total_tokens: Option<u32>,
}

// Servers differ: `content` may be a plain string, a list of typed parts, or a
// single part object, so deserialize it untagged and normalize to text later.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAIChatContent {
    Text(String),
    Parts(Vec<OpenAIChatContentPart>),
    Part(OpenAIChatContentPart),
}

#[derive(Debug, Deserialize)]
struct OpenAIChatContentPart {
    #[serde(rename = "type")]
    kind: Option<String>,
    #[serde(default)]
    text: Option<String>,
    #[serde(default)]
    content: Option<String>,
}

impl OpenAIChatChoiceMessage {
    fn extract_text(&self) -> String {
        let content_text = self
            .content
            .as_ref()
            .map(OpenAIChatContent::to_text)
            .unwrap_or_default();
        if !content_text.trim().is_empty() {
            return content_text;
        }

        // Fall back to reasoning fields when the main content is empty.
        if let Some(reasoning) = self
            .reasoning
            .as_deref()
            .filter(|text| !text.trim().is_empty())
        {
            return reasoning.to_string();
        }

        self.reasoning_content
            .as_deref()
            .filter(|text| !text.trim().is_empty())
            .unwrap_or_default()
            .to_string()
    }
}

impl OpenAIChatContent {
    fn to_text(&self) -> String {
        match self {
            Self::Text(text) => text.clone(),
            Self::Parts(parts) => parts
                .iter()
                .filter_map(OpenAIChatContentPart::text_fragment)
                .collect::<Vec<_>>()
                .join("\n"),
            Self::Part(part) => part.text_fragment().unwrap_or_default(),
        }
    }
}

impl OpenAIChatContentPart {
    fn text_fragment(&self) -> Option<String> {
        // Only keep text-like parts; skip any other declared kind.
        if let Some(kind) = self.kind.as_deref()
            && !kind.eq_ignore_ascii_case("text")
            && !kind.eq_ignore_ascii_case("output_text")
        {
            return None;
        }

        self.text
            .as_deref()
            .or(self.content.as_deref())
            .map(str::trim)
            .filter(|text| !text.is_empty())
            .map(ToString::to_string)
    }
}