
codetether_agent/cognition/thinker.rs

use anyhow::{Context, Result, anyhow};
use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
#[cfg(feature = "functiongemma")]
use candle_transformers::models::quantized_gemma3;
use candle_transformers::models::{quantized_llama, quantized_qwen2};
use candle_transformers::utils::apply_repeat_penalty;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fs::File;
use std::io::BufReader;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokenizers::Tokenizer;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkerBackend {
    OpenAICompat,
    Candle,
}

impl ThinkerBackend {
    pub fn from_env(value: &str) -> Self {
        match value.trim().to_ascii_lowercase().as_str() {
            "candle" => Self::Candle,
            "openai" | "openai_compat" | "openai-compatible" | "http" => Self::OpenAICompat,
            _ => Self::OpenAICompat,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandleDevicePreference {
    Auto,
    Cpu,
    Cuda,
}

impl CandleDevicePreference {
    pub fn from_env(value: &str) -> Self {
        match value.trim().to_ascii_lowercase().as_str() {
            "cpu" => Self::Cpu,
            "cuda" | "gpu" => Self::Cuda,
            _ => Self::Auto,
        }
    }
}

#[derive(Debug, Clone)]
pub struct ThinkerConfig {
    pub enabled: bool,
    pub backend: ThinkerBackend,
    pub endpoint: String,
    pub model: String,
    pub api_key: Option<String>,
    pub temperature: f32,
    pub top_p: Option<f32>,
    pub max_tokens: usize,
    pub timeout_ms: u64,
    pub candle_model_path: Option<String>,
    pub candle_tokenizer_path: Option<String>,
    pub candle_arch: Option<String>,
    pub candle_device: CandleDevicePreference,
    pub candle_cuda_ordinal: usize,
    pub candle_repeat_penalty: f32,
    pub candle_repeat_last_n: usize,
    pub candle_seed: u64,
}

impl Default for ThinkerConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            backend: ThinkerBackend::OpenAICompat,
            endpoint: "http://127.0.0.1:11434/v1/chat/completions".to_string(),
            model: "qwen2.5:3b-instruct".to_string(),
            api_key: None,
            temperature: 0.2,
            top_p: None,
            max_tokens: 256,
            timeout_ms: 30_000,
            candle_model_path: None,
            candle_tokenizer_path: None,
            candle_arch: None,
            candle_device: CandleDevicePreference::Auto,
            candle_cuda_ordinal: 0,
            candle_repeat_penalty: 1.1,
            candle_repeat_last_n: 64,
            candle_seed: 42,
        }
    }
}
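
// A Candle-backed configuration is typically built by overriding these defaults.
// Minimal sketch (the gguf/tokenizer paths are placeholders, not shipped defaults):
//
//     let config = ThinkerConfig {
//         enabled: true,
//         backend: ThinkerBackend::Candle,
//         candle_model_path: Some("/models/model-q4_k_m.gguf".to_string()),
//         candle_tokenizer_path: Some("/models/tokenizer.json".to_string()),
//         ..ThinkerConfig::default()
//     };
//     let client = ThinkerClient::new(config)?;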

#[derive(Debug, Clone)]
pub struct ThinkerOutput {
    pub model: String,
    pub finish_reason: Option<String>,
    pub text: String,
    pub prompt_tokens: Option<u32>,
    pub completion_tokens: Option<u32>,
    pub total_tokens: Option<u32>,
}

#[derive(Clone)]
pub struct ThinkerClient {
    config: ThinkerConfig,
    backend: ThinkerClientBackend,
}

impl std::fmt::Debug for ThinkerClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ThinkerClient")
            .field("backend", &self.config.backend)
            .field("model", &self.config.model)
            .finish()
    }
}

#[derive(Clone)]
enum ThinkerClientBackend {
    OpenAICompat { http: Client },
    Candle { runtime: Arc<Mutex<CandleThinker>> },
}

impl ThinkerClient {
    pub fn new(config: ThinkerConfig) -> Result<Self> {
        let backend = match config.backend {
            ThinkerBackend::OpenAICompat => {
                let timeout = Duration::from_millis(config.timeout_ms.max(1_000));
                let http = Client::builder()
                    .timeout(timeout)
                    .build()
                    .context("failed to build thinker HTTP client")?;
                ThinkerClientBackend::OpenAICompat { http }
            }
            ThinkerBackend::Candle => {
                let runtime = CandleThinker::new(&config)?;
                ThinkerClientBackend::Candle {
                    runtime: Arc::new(Mutex::new(runtime)),
                }
            }
        };

        Ok(Self { config, backend })
    }

    pub fn config(&self) -> &ThinkerConfig {
        &self.config
    }

    pub async fn think(&self, system_prompt: &str, user_prompt: &str) -> Result<ThinkerOutput> {
        match &self.backend {
            ThinkerClientBackend::OpenAICompat { http } => {
                self.think_openai_compat(http, system_prompt, user_prompt)
                    .await
            }
            ThinkerClientBackend::Candle { runtime } => {
                let runtime = Arc::clone(runtime);
                let system_prompt = system_prompt.to_string();
                let user_prompt = user_prompt.to_string();
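                // Candle inference is synchronous and compute-bound, so it runs on the
                // blocking thread pool; try_lock fails fast with "busy" rather than
                // queuing behind a request that is already generating.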
                tokio::task::spawn_blocking(move || {
                    let mut guard = match runtime.try_lock() {
                        Ok(g) => g,
                        Err(std::sync::TryLockError::WouldBlock) => {
                            return Err(anyhow!("candle thinker is busy"));
                        }
                        Err(std::sync::TryLockError::Poisoned(_)) => {
                            return Err(anyhow!("candle thinker mutex poisoned"));
                        }
                    };
                    guard.think(&system_prompt, &user_prompt)
                })
                .await
                .context("candle thinker task join failed")?
            }
        }
    }

    async fn think_openai_compat(
        &self,
        http: &Client,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let body = OpenAIChatRequest {
            model: self.config.model.clone(),
            messages: vec![
                OpenAIMessage {
                    role: "system".to_string(),
                    content: system_prompt.to_string(),
                },
                OpenAIMessage {
                    role: "user".to_string(),
                    content: user_prompt.to_string(),
                },
            ],
            temperature: self.config.temperature,
            top_p: self.config.top_p,
            max_tokens: self.config.max_tokens,
            stream: false,
        };
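        // Serializes to an OpenAI-compatible chat-completions request (illustrative values):
        //   {"model":"...","messages":[{"role":"system","content":"..."},
        //    {"role":"user","content":"..."}],"temperature":0.2,"max_tokens":256,"stream":false}
        // `top_p` is omitted entirely when it is None.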

        // Retry once on transient failures (connection errors, 429, 502-504).
        let max_attempts: u32 = 2;
        let mut last_err: Option<anyhow::Error> = None;

        for attempt in 0..max_attempts {
            if attempt > 0 {
                tokio::time::sleep(Duration::from_millis(500 * attempt as u64)).await;
                tracing::debug!(attempt, "retrying thinker HTTP request");
            }

            let mut request = http.post(&self.config.endpoint).json(&body);
            if let Some(key) = self.config.api_key.as_ref() {
                request = request.bearer_auth(key);
            }

            let response = match request.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    if is_transient_reqwest_error(&e) {
                        tracing::warn!(attempt, error = %e, "thinker HTTP request failed (transient)");
                        last_err =
                            Some(anyhow::Error::from(e).context("transient thinker send error"));
                        continue;
                    }
                    return Err(anyhow::Error::from(e).context("non-transient thinker send error"));
                }
            };

            let status = response.status();
            if is_transient_http_error(status.as_u16()) {
                let body_text = response.text().await.unwrap_or_default();
                tracing::warn!(attempt, status = %status, "thinker received transient HTTP error");
                last_err = Some(anyhow!(
                    "thinker request failed with status {}: {}",
                    status,
                    body_text
                ));
                continue;
            }

            if !status.is_success() {
                let body_text = response
                    .text()
                    .await
                    .unwrap_or_else(|_| "<empty>".to_string());
                return Err(anyhow!(
                    "thinker request failed with status {}: {}",
                    status,
                    body_text
                ));
            }

            let payload: OpenAIChatResponse = response
                .json()
                .await
                .context("failed to decode thinker response")?;
            let choice = payload
                .choices
                .first()
                .ok_or_else(|| anyhow!("thinker response did not include choices"))?;
            let text = choice.message.extract_text();
            let usage = payload.usage.unwrap_or_default();

            let output = ThinkerOutput {
                model: payload.model.unwrap_or_else(|| self.config.model.clone()),
                finish_reason: choice.finish_reason.clone(),
                text,
                prompt_tokens: usage.prompt_tokens,
                completion_tokens: usage.completion_tokens,
                total_tokens: usage.total_tokens,
            };

            tracing::debug!(
                model = %output.model,
                latency_ms = started_at.elapsed().as_millis(),
                prompt_tokens = ?output.prompt_tokens,
                completion_tokens = ?output.completion_tokens,
                attempt,
                "openai-compat thinker generated thought"
            );

            return Ok(output);
        }

        Err(last_err.unwrap_or_else(|| {
            anyhow!("thinker HTTP request failed after {max_attempts} attempts")
        }))
    }
}
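
// Usage sketch (inside an async context; the prompts are illustrative):
//
//     let client = ThinkerClient::new(ThinkerConfig::default())?;
//     let output = client
//         .think("You are a concise planner.", "What should happen next?")
//         .await?;
//     tracing::debug!(reason = ?output.finish_reason, "thought: {}", output.text);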

pub(crate) struct CandleThinker {
    model: CandleModel,
    tokenizer: Tokenizer,
    device: Device,
    model_label: String,
    architecture: String,
    context_window: usize,
    temperature: f32,
    top_p: Option<f32>,
    max_tokens: usize,
    repeat_penalty: f32,
    repeat_last_n: usize,
    seed: u64,
    request_index: u64,
    eos_token_ids: HashSet<u32>,
}

enum CandleModel {
    Llama(quantized_llama::ModelWeights),
    Qwen2(quantized_qwen2::ModelWeights),
    #[cfg(feature = "functiongemma")]
    Gemma3(quantized_gemma3::ModelWeights),
}

impl CandleModel {
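    /// One forward pass over the wrapped quantized model. `index_pos` is the number
    /// of tokens already held in the KV cache: 0 means the full prompt is being
    /// processed, and later calls pass only the newly sampled token.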
    fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        match self {
            Self::Llama(model) => Ok(model.forward(x, index_pos)?),
            Self::Qwen2(model) => Ok(model.forward(x, index_pos)?),
            #[cfg(feature = "functiongemma")]
            Self::Gemma3(model) => Ok(model.forward(x, index_pos)?),
        }
    }
}

impl CandleThinker {
    pub(crate) fn new(config: &ThinkerConfig) -> Result<Self> {
        let model_path = config.candle_model_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_MODEL_PATH")
        })?;
        let tokenizer_path = config.candle_tokenizer_path.as_ref().ok_or_else(|| {
            anyhow!("candle backend requires CODETETHER_COGNITION_THINKER_CANDLE_TOKENIZER_PATH")
        })?;

        let (device, device_label) = select_candle_device(config)?;
        let mut reader = BufReader::new(
            File::open(model_path)
                .with_context(|| format!("failed to open candle model file at {}", model_path))?,
        );
        let content = gguf_file::Content::read(&mut reader)
            .with_context(|| format!("failed to parse gguf model metadata from {}", model_path))?;

        let architecture = config
            .candle_arch
            .clone()
            .or_else(|| {
                content
                    .metadata
                    .get("general.architecture")
                    .and_then(|v| v.to_string().ok())
                    .cloned()
            })
            .unwrap_or_else(|| "llama".to_string())
            .to_ascii_lowercase();

        let context_window = detect_context_window(&content, &architecture).unwrap_or(4096);
        let model_label = format!("candle:{}:{}@{}", architecture, device_label, model_path);

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer from {}: {}", tokenizer_path, e))?;

        // Extract EOS metadata from content before it is moved into from_gguf.
        let gguf_eos_ids = extract_gguf_eos_ids(&content);

        let model = match architecture.as_str() {
            "llama" => CandleModel::Llama(
                quantized_llama::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load llama gguf from {}", model_path))?,
            ),
            "qwen2" => CandleModel::Qwen2(
                quantized_qwen2::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load qwen2 gguf from {}", model_path))?,
            ),
            #[cfg(feature = "functiongemma")]
            "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => CandleModel::Gemma3(
                quantized_gemma3::ModelWeights::from_gguf(content, &mut reader, &device)
                    .with_context(|| format!("failed to load gemma3 gguf from {}", model_path))?,
            ),
            other => {
                #[cfg(not(feature = "functiongemma"))]
                if matches!(other, "gemma" | "gemma2" | "gemma3" | "gemma-embedding") {
                    return Err(anyhow!(
                        "gemma architecture '{}' requires the 'functiongemma' feature; rebuild with --features functiongemma",
                        other
                    ));
                }
                return Err(anyhow!(
                    "unsupported candle architecture '{}' (supported: llama, qwen2{})",
                    other,
                    if cfg!(feature = "functiongemma") {
                        ", gemma/gemma2/gemma3"
                    } else {
                        ""
                    }
                ));
            }
        };

        let eos_token_ids: HashSet<u32> = collect_eos_token_ids(&tokenizer, &gguf_eos_ids);
        if eos_token_ids.is_empty() {
            tracing::warn!(
                "No EOS tokens found in tokenizer; generation will stop on max token limit"
            );
        }

        Ok(Self {
            model,
            tokenizer,
            device,
            model_label,
            architecture,
            context_window,
            temperature: config.temperature,
            top_p: config.top_p,
            max_tokens: config.max_tokens.max(1),
            repeat_penalty: config.candle_repeat_penalty.max(1.0),
            repeat_last_n: config.candle_repeat_last_n.max(1),
            seed: config.candle_seed,
            request_index: 0,
            eos_token_ids,
        })
    }

    pub(crate) fn think(
        &mut self,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<ThinkerOutput> {
        let started_at = Instant::now();
        let prompt = format_chat_prompt(&self.architecture, system_prompt, user_prompt);
        let encoding = self
            .tokenizer
            .encode(prompt.as_str(), true)
            .map_err(|e| anyhow!("tokenizer encode failed: {}", e))?;
        let mut tokens = encoding.get_ids().to_vec();
        if tokens.is_empty() {
            return Err(anyhow!("tokenizer produced an empty prompt token set"));
        }

        // Truncate user content while preserving the system prompt prefix.
        if self.context_window > 8 && tokens.len() >= self.context_window {
            let system_only = format_chat_prompt(&self.architecture, system_prompt, "");
            let sys_encoding = self
                .tokenizer
                .encode(system_only.as_str(), true)
                .map_err(|e| anyhow!("tokenizer encode failed (system): {}", e))?;
            let sys_len = sys_encoding.get_ids().len();
            let budget = self.context_window.saturating_sub(8);
            if sys_len < budget {
                // Keep system prefix + tail of user content that fits
                let tail_budget = budget.saturating_sub(sys_len);
                let tail_start = tokens.len().saturating_sub(tail_budget);
                let mut truncated = sys_encoding.get_ids().to_vec();
                truncated.extend_from_slice(&tokens[tail_start..]);
                tokens = truncated;
            } else {
                // System alone exceeds budget; keep only the tail
                let keep = budget;
                tokens = tokens[tokens.len().saturating_sub(keep)..].to_vec();
            }
        }
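        // Worked example (illustrative numbers): with context_window = 4096, a
        // 50-token system prefix, and a 5000-token prompt, budget = 4088 and
        // tail_budget = 4038, so the truncated prompt is the system prefix plus
        // the last 4038 tokens of the original encoding (4088 tokens total).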
        let prompt_token_count = tokens.len() as u32;

        let request_seed = self.seed.wrapping_add(self.request_index);
        self.request_index = self.request_index.wrapping_add(1);
        let mut logits_processor = LogitsProcessor::new(
            request_seed,
            Some(self.temperature as f64),
            self.top_p.map(|v| v as f64),
        );

        let mut index_pos = 0usize;
        let mut generated: Vec<u32> = Vec::with_capacity(self.max_tokens);
        let mut finish_reason = "length".to_string();

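        // The first iteration feeds the full prompt; every later iteration feeds only
        // the newly sampled token, with index_pos tracking the KV-cache position.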
        for _ in 0..self.max_tokens {
            let ctxt: &[u32] = if index_pos == 0 {
                tokens.as_slice()
            } else {
                &tokens[tokens.len() - 1..]
            };

            let input = Tensor::new(ctxt, &self.device)?
                .unsqueeze(0)
                .context("failed to create candle input tensor")?;
            let mut logits = self
                .model
                .forward(&input, index_pos)
                .context("candle model forward failed")?;
            index_pos += ctxt.len();
            logits = logits
                .squeeze(0)
                .context("failed to squeeze logits batch dimension")?;

            let logits = if self.repeat_penalty > 1.0 {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                apply_repeat_penalty(&logits, self.repeat_penalty, &tokens[start_at..])
                    .context("failed to apply repeat penalty")?
            } else {
                logits
            };

            let next_token = logits_processor
                .sample(&logits)
                .context("token sampling failed")?;
            if self.eos_token_ids.contains(&next_token) {
                finish_reason = "stop".to_string();
                break;
            }

            tokens.push(next_token);
            generated.push(next_token);

            if tokens.len() + 1 >= self.context_window {
                finish_reason = "length".to_string();
                break;
            }
        }

        let text = self
            .tokenizer
            .decode(&generated, true)
            .map_err(|e| anyhow!("tokenizer decode failed: {}", e))?;
        let completion_tokens = generated.len() as u32;

        tracing::debug!(
            model = %self.model_label,
            latency_ms = started_at.elapsed().as_millis(),
            prompt_tokens = prompt_token_count,
            completion_tokens = completion_tokens,
            "candle thinker generated thought"
        );

        Ok(ThinkerOutput {
            model: self.model_label.clone(),
            finish_reason: Some(finish_reason),
            text,
            prompt_tokens: Some(prompt_token_count),
            completion_tokens: Some(completion_tokens),
            total_tokens: Some(prompt_token_count + completion_tokens),
        })
    }
}

/// Build a chat prompt using the proper template for each model architecture.
fn format_chat_prompt(architecture: &str, system_prompt: &str, user_prompt: &str) -> String {
    match architecture {
        // ChatML template (Qwen2, Yi, etc.)
        "qwen2" => format!(
            "<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
            system = system_prompt,
            user = user_prompt,
        ),
        // Llama 3 instruct template
        "llama" => format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
            system = system_prompt,
            user = user_prompt,
        ),
        // Gemma instruct template
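        // (Gemma chat formatting has no separate system role, so the system prompt is
        // folded into the user turn.)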
        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => format!(
            "<start_of_turn>user\n{system}\n\n{user}<end_of_turn>\n<start_of_turn>model\n",
            system = system_prompt,
            user = user_prompt,
        ),
        // Fallback for unknown architectures
        _ => format!(
            "System:\n{system}\n\nUser:\n{user}\n\nAssistant:\n",
            system = system_prompt,
            user = user_prompt,
        ),
    }
}

fn select_candle_device(config: &ThinkerConfig) -> Result<(Device, String)> {
    match config.candle_device {
        CandleDevicePreference::Cpu => Ok((Device::Cpu, "cpu".to_string())),
        CandleDevicePreference::Cuda => {
            let device = try_cuda_device(config.candle_cuda_ordinal)?;
            Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
        }
        CandleDevicePreference::Auto => match try_cuda_device(config.candle_cuda_ordinal) {
            Ok(device) => {
                tracing::info!(
                    ordinal = config.candle_cuda_ordinal,
                    "Candle thinker selected CUDA device"
                );
                Ok((device, format!("cuda:{}", config.candle_cuda_ordinal)))
            }
            Err(error) => {
                tracing::warn!(
                    %error,
                    "CUDA unavailable for Candle thinker, falling back to CPU"
                );
                Ok((Device::Cpu, "cpu".to_string()))
            }
        },
    }
}

#[cfg(feature = "candle-cuda")]
fn try_cuda_device(ordinal: usize) -> Result<Device> {
    Device::new_cuda(ordinal)
        .with_context(|| format!("failed to initialize CUDA device ordinal {}", ordinal))
}

#[cfg(not(feature = "candle-cuda"))]
fn try_cuda_device(_ordinal: usize) -> Result<Device> {
    Err(anyhow!(
        "candle-cuda feature is not enabled in this build; rebuild with --features candle-cuda"
    ))
}

fn detect_context_window(content: &gguf_file::Content, architecture: &str) -> Option<usize> {
    let key = match architecture {
        "qwen2" => "qwen2.context_length",
        "gemma" | "gemma2" | "gemma3" | "gemma-embedding" => {
            // Try gemma3 first, then fall back to gemma2, gemma
            for prefix in ["gemma3", "gemma2", "gemma"] {
                let k = format!("{prefix}.context_length");
                if let Some(v) = content.metadata.get(&k) {
                    return v.to_u32().ok().map(|v| v as usize);
                }
            }
            return None;
        }
        _ => "llama.context_length",
    };
    content
        .metadata
        .get(key)
        .and_then(|v| v.to_u32().ok())
        .map(|v| v as usize)
}

/// Extract EOS token IDs from GGUF metadata before the content is consumed.
fn extract_gguf_eos_ids(content: &gguf_file::Content) -> Vec<u32> {
    let mut ids = Vec::new();
    for key in ["tokenizer.ggml.eos_token_id", "tokenizer.ggml.eot_token_id"] {
        if let Some(v) = content.metadata.get(key) {
            if let Ok(id) = v.to_u32() {
                if !ids.contains(&id) {
                    ids.push(id);
                }
            }
        }
    }
    ids
}

fn collect_eos_token_ids(tokenizer: &Tokenizer, gguf_eos_ids: &[u32]) -> HashSet<u32> {
    let mut ids: HashSet<u32> = gguf_eos_ids.iter().copied().collect();

    // Also check well-known special token strings as fallback.
    let candidates = [
        "<|im_end|>",
        "<|eot_id|>",
        "<|endoftext|>",
        "</s>",
        "<|end|>",
        "<end_of_turn>",
    ];
    for token in candidates {
        if let Some(id) = tokenizer.token_to_id(token) {
            ids.insert(id);
        }
    }
    ids
}

/// Returns true for HTTP status codes that are worth retrying.
fn is_transient_http_error(status: u16) -> bool {
    matches!(status, 429 | 502 | 503 | 504)
}

/// Returns true for reqwest errors that are worth retrying (timeouts, connection resets).
fn is_transient_reqwest_error(e: &reqwest::Error) -> bool {
    e.is_timeout() || e.is_connect() || e.is_request()
}

#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    temperature: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    max_tokens: usize,
    stream: bool,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    model: Option<String>,
    choices: Vec<OpenAIChatChoice>,
    #[serde(default)]
    usage: Option<OpenAIUsage>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatChoice {
    message: OpenAIChatChoiceMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatChoiceMessage {
    #[serde(default)]
    content: Option<OpenAIChatContent>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
}

#[derive(Debug, Default, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    total_tokens: Option<u32>,
}

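/// `content` arrives either as a plain string or as structured parts
/// (`[{"type": "text", "text": "..."}]`, or a single part object) from some
/// OpenAI-compatible servers; the untagged enum accepts all three shapes.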
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAIChatContent {
    Text(String),
    Parts(Vec<OpenAIChatContentPart>),
    Part(OpenAIChatContentPart),
}

#[derive(Debug, Deserialize)]
struct OpenAIChatContentPart {
    #[serde(rename = "type")]
    kind: Option<String>,
    #[serde(default)]
    text: Option<String>,
    #[serde(default)]
    content: Option<String>,
}

impl OpenAIChatChoiceMessage {
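    /// Prefer the regular `content` text; if it is empty, fall back to `reasoning`
    /// and then `reasoning_content`, which some reasoning-capable servers populate
    /// instead.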
    fn extract_text(&self) -> String {
        let content_text = self
            .content
            .as_ref()
            .map(OpenAIChatContent::to_text)
            .unwrap_or_default();
        if !content_text.trim().is_empty() {
            return content_text;
        }

        if let Some(reasoning) = self
            .reasoning
            .as_deref()
            .filter(|text| !text.trim().is_empty())
        {
            return reasoning.to_string();
        }

        self.reasoning_content
            .as_deref()
            .filter(|text| !text.trim().is_empty())
            .unwrap_or_default()
            .to_string()
    }
}

impl OpenAIChatContent {
    fn to_text(&self) -> String {
        match self {
            Self::Text(text) => text.clone(),
            Self::Parts(parts) => parts
                .iter()
                .filter_map(OpenAIChatContentPart::text_fragment)
                .collect::<Vec<_>>()
                .join("\n"),
            Self::Part(part) => part.text_fragment().unwrap_or_default(),
        }
    }
}

impl OpenAIChatContentPart {
    fn text_fragment(&self) -> Option<String> {
        if let Some(kind) = self.kind.as_deref()
            && !kind.eq_ignore_ascii_case("text")
            && !kind.eq_ignore_ascii_case("output_text")
        {
            return None;
        }

        self.text
            .as_deref()
            .or(self.content.as_deref())
            .map(str::trim)
            .filter(|text| !text.is_empty())
            .map(ToString::to_string)
    }
}