openai_mock/utils/choices.rs

1use crate::models::completion::Choice;
2use crate::utils::token_counting::TokenCounter;
3use std::collections::HashMap;
4use rand::{thread_rng, Rng};
5use crate::models::completion::Logprobs;
6
7
8impl Choice {
9    pub fn new(index: i32, text: String, echo: bool, prompt: &str) -> Self {
10        let final_text = if echo {
11            format!("{}{}", prompt, text)
12        } else {
13            text
14        };
15
16        Choice {
17            text: final_text,
18            index,
19            logprobs: None,
20            finish_reason: None,
21        }
22    }
23
24    fn generate_mock_logprobs(&self, text: &str, logprobs_n: u32) -> Logprobs {
25        let mut rng = thread_rng();
26
27        // Split text into mock tokens (simple word-based splitting for mock data)
28        let tokens: Vec<String> = text
29            .split_whitespace()
30            .map(|s| s.to_string())
31            .collect();
32
33        let mut current_offset = 0;
34        let mut text_offset: Vec<usize> = Vec::new();
35
36        // Generate text offsets
37        for token in &tokens {
38            text_offset.push(current_offset);
39            current_offset += token.len() + 1; // +1 for space
40        }
41
42        // Generate mock token logprobs
43        let token_logprobs: Vec<f32> = (0..tokens.len())
44            .map(|_| -rng.gen_range(0.0..5.0))
45            .collect();
46
47        // Generate top logprobs for each token
48        let top_logprobs: Vec<HashMap<String, f32>> = tokens
49            .iter()
50            .map(|_| {
51                let mut map = HashMap::new();
52                for _ in 0..logprobs_n {
53                    let mock_token = format!("token_{}", rng.gen_range(0..100));
54                    let mock_logprob = -rng.gen_range(0.0..10.0);
55                    map.insert(mock_token, mock_logprob);
56                }
57                map
58            })
59            .collect();
60
61        Logprobs {
62            tokens,
63            token_logprobs,
64            text_offset,
65            top_logprobs,
66        }
67    }
68
69    pub fn generate_text(
70        &mut self,
71        prompt: &str,
72        stop_sequences: &[String],
73        max_tokens: u32,
74        echo: bool,
75        logprobs_n: Option<u32>,
76        model: &str
77    ) {
78        let mut generated = if echo {
79            prompt.to_string()
80        } else {
81            String::new()
82        };
83
84        // Check for stop sequences
85        for stop_seq in stop_sequences {
86            if generated.contains(stop_seq) {
87                self.finish_reason = Some("stop".to_string());
88                generated = generated.split(stop_seq).next().unwrap_or("").to_string();
89                self.text = generated;
90                return;
91            }
92        }
93        let token_counter = TokenCounter::new(&model);
94        match token_counter {
95            Ok(token_counter) => {
96                // More robust token count estimation
97                let estimated_tokens = token_counter.count_tokens(&generated);
98                if estimated_tokens >= max_tokens {
99            self.finish_reason = Some("length".to_string());
100            self.text = token_counter.truncate_to_tokens(&generated, max_tokens);
101                    return;
102                }
103            },
104            Err(e) => {
105                eprintln!("Error creating token counter: {}", e);
106            }
107        }
108
109        self.text = generated;
110        self.finish_reason = Some("content".to_string());
111
112        // Generate logprobs if requested
113        if let Some(n) = logprobs_n {
114            self.logprobs = Some(self.generate_mock_logprobs(&self.text, n));
115        }
116    }
117}
118
119pub fn create_choices(
120    n: i32,
121    prompt: &str,
122    stop_sequences: &[String],
123    max_tokens: u32,
124    echo: bool,
125    logprobs: Option<u32>,
126    model: &str
127) -> Vec<Choice> {
128    let mut choices = Vec::with_capacity(n as usize);
129
130    for i in 0..n {
131        let mut choice = Choice::new(i, String::new(), echo, prompt);
132        choice.generate_text(prompt, stop_sequences, max_tokens, echo, logprobs, model);
133        choices.push(choice);
134    }
135
136    choices
137}