openai_mock/utils/choices.rs

use crate::models::completion::{Choice, Logprobs};
use crate::utils::token_counting::TokenCounter;
use rand::{thread_rng, Rng};
use std::collections::HashMap;

impl Choice {
    /// Builds a choice; when `echo` is set, the prompt is prepended to `text`.
    pub fn new(index: i32, text: String, echo: bool, prompt: &str) -> Self {
        let final_text = if echo {
            format!("{}{}", prompt, text)
        } else {
            text
        };

        Choice {
            text: final_text,
            index,
            logprobs: None,
            finish_reason: None,
        }
    }
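
    /// Fabricates a `Logprobs` payload for `text` in the shape of the OpenAI
    /// completions `logprobs` object: the whitespace-split "tokens", one
    /// random log-probability per token, each token's offset into the text,
    /// and `logprobs_n` random alternatives per token. All values are
    /// placeholders, not real model probabilities.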
    fn generate_mock_logprobs(&self, text: &str, logprobs_n: u32) -> Logprobs {
        let mut rng = thread_rng();

        // Treat whitespace-separated words as "tokens".
        let tokens: Vec<String> = text
            .split_whitespace()
            .map(|s| s.to_string())
            .collect();

        // Byte offset of each token, assuming tokens are separated by
        // exactly one space (an approximation for other whitespace).
        let mut current_offset = 0;
        let mut text_offset: Vec<usize> = Vec::new();
        for token in &tokens {
            text_offset.push(current_offset);
            current_offset += token.len() + 1;
        }

        // One random log-probability per token (log-probs are <= 0).
        let token_logprobs: Vec<f32> = (0..tokens.len())
            .map(|_| -rng.gen_range(0.0..5.0))
            .collect();

        // For each token, up to `logprobs_n` random alternatives; duplicate
        // mock tokens overwrite each other, so a map may hold fewer entries.
        let top_logprobs: Vec<HashMap<String, f32>> = tokens
            .iter()
            .map(|_| {
                let mut map = HashMap::new();
                for _ in 0..logprobs_n {
                    let mock_token = format!("token_{}", rng.gen_range(0..100));
                    let mock_logprob = -rng.gen_range(0.0..10.0);
                    map.insert(mock_token, mock_logprob);
                }
                map
            })
            .collect();

        Logprobs {
            tokens,
            token_logprobs,
            text_offset,
            top_logprobs,
        }
    }

    /// Produces the mock completion text, applying echo, stop sequences, and
    /// the `max_tokens` cap, and setting `finish_reason` accordingly.
    pub fn generate_text(
        &mut self,
        prompt: &str,
        stop_sequences: &[String],
        max_tokens: u32,
        echo: bool,
        logprobs_n: Option<u32>,
        model: &str,
    ) {
        // The mock never samples new tokens: the "generated" text is just
        // the echoed prompt (or empty when echo is off).
        let mut generated = if echo {
            prompt.to_string()
        } else {
            String::new()
        };

        // Cut at the first stop sequence found in the text.
        for stop_seq in stop_sequences {
            if generated.contains(stop_seq) {
                self.finish_reason = Some("stop".to_string());
                generated = generated.split(stop_seq).next().unwrap_or("").to_string();
                self.text = generated;
                return;
            }
        }

        // Truncate when the text already exceeds the requested token budget.
        match TokenCounter::new(model) {
            Ok(token_counter) => {
                let estimated_tokens = token_counter.count_tokens(&generated);
                if estimated_tokens >= max_tokens {
                    self.finish_reason = Some("length".to_string());
                    self.text = token_counter.truncate_to_tokens(&generated, max_tokens);
                    return;
                }
            }
            Err(e) => {
                eprintln!("Error creating token counter: {}", e);
            }
        }

        self.text = generated;
        // Natural end of the completion; the OpenAI API reports this as "stop".
        self.finish_reason = Some("stop".to_string());

        if let Some(n) = logprobs_n {
            self.logprobs = Some(self.generate_mock_logprobs(&self.text, n));
        }
    }
}
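
/// Creates `n` mock `Choice`s for a completion request, one per index.
///
/// A minimal usage sketch (`ignore`d because it depends on crate-internal
/// paths, and the model name is only an example that `TokenCounter::new`
/// may or may not accept):
///
/// ```ignore
/// let stops = vec!["\n".to_string()];
/// let choices = create_choices(2, "Say hi", &stops, 16, false, Some(3), "gpt-3.5-turbo-instruct");
/// assert_eq!(choices.len(), 2);
/// assert_eq!(choices[1].index, 1);
/// ```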
pub fn create_choices(
    n: i32,
    prompt: &str,
    stop_sequences: &[String],
    max_tokens: u32,
    echo: bool,
    logprobs: Option<u32>,
    model: &str,
) -> Vec<Choice> {
    // Clamp so a negative `n` cannot overflow the capacity computation.
    let mut choices = Vec::with_capacity(n.max(0) as usize);

    for i in 0..n {
        let mut choice = Choice::new(i, String::new(), echo, prompt);
        choice.generate_text(prompt, stop_sequences, max_tokens, echo, logprobs, model);
        choices.push(choice);
    }

    choices
}
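
// A minimal test sketch of the behaviour above. It assumes the `Choice`
// fields stay crate-visible and that `TokenCounter::new` either succeeds for
// the example model name or fails gracefully (in which case `generate_text`
// logs the error and continues, so these assertions still hold).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn create_choices_returns_n_indexed_choices() {
        let choices = create_choices(3, "Hello", &[], 16, false, None, "gpt-3.5-turbo-instruct");
        assert_eq!(choices.len(), 3);
        for (i, choice) in choices.iter().enumerate() {
            assert_eq!(choice.index, i as i32);
        }
    }

    #[test]
    fn echoed_prompt_is_cut_at_first_stop_sequence() {
        let stops = vec!["world".to_string()];
        let choices = create_choices(1, "Hello world", &stops, 16, true, None, "gpt-3.5-turbo-instruct");
        assert_eq!(choices[0].finish_reason.as_deref(), Some("stop"));
        assert_eq!(choices[0].text, "Hello ");
    }
}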