1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
use anyhow::{anyhow, Result};
use log::{error, info, warn};
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::domain::AllmsError;
use crate::llm_models::{LLMModel, LLMTools};
use crate::utils::{get_tokenizer, get_type_schema};
/// Builder-style client for a model's Completions API.
///
/// Completions APIs take a list of messages as input and return a model-generated message as output.
/// Although the Completions format is designed to make multi-turn conversations easy,
/// it's just as useful for single-turn tasks without any conversation.
///
/// A value is created with `new`, configured through the chained setter methods
/// (each takes `self` by value and returns it), and finally consumed by `get_answer`.
pub struct Completions<T: LLMModel> {
/// Model implementation used to build request bodies, call the API and extract response data.
model: T,
/// Overall token budget. `get_answer` rejects prompts whose estimate reaches this value and
/// hands the remainder (`max_tokens` minus the prompt estimate) to the model for the response.
max_tokens: usize,
/// Temperature value forwarded to the model when the request body is built.
temperature: f32,
/// Context accumulated by `set_context` as `<name>json</name>` entries separated by blank
/// lines; `None` until the first entry is added.
input_json: Option<String>,
/// When `true`, `get_answer` logs the request body, token estimate and response data via
/// `info!`; the flag is also forwarded to the model's `call_api`.
debug: bool,
/// Function calling switch, forwarded to the model when base instructions and the request
/// body are built and when response data is extracted.
function_call: bool,
/// API key forwarded to the model's `call_api`.
api_key: String,
/// Optional API version forwarded to the model; `None` unless set through `version`.
version: Option<String>,
/// Tools registered through `add_tool`; `None` until the first tool is added.
tools: Option<Vec<LLMTools>>,
/// Optional thinking level forwarded to the model when the request body is built.
thinking_level: Option<ThinkingLevel>,
}
impl<T: LLMModel> Completions<T> {
/// Creates a `Completions` client for `model`.
///
/// * `model` - model implementation that builds and executes the requests.
/// * `api_key` - key forwarded to the model's API call; copied into the client.
/// * `max_tokens` - token budget; when `None`, `model.default_max_tokens()` is used.
/// * `temperature` - target passed through `model.get_normalized_temperature`
///   (the `temperature` setter documents it as a % of the model's accepted range);
///   when `None`, `model.get_default_temperature()` is used.
///
/// The returned client has debug mode off, no context, version, tools or thinking level,
/// and function calling set to `model.function_call_default()`.
pub fn new(
    model: T,
    api_key: &str,
    max_tokens: Option<usize>,
    temperature: Option<u32>,
) -> Self {
    // Only ask the model for its default temperature when the caller gave no target,
    // consistent with the lazy `max_tokens` fallback below.
    let temperature = match temperature {
        Some(target) => model.get_normalized_temperature(target),
        None => model.get_default_temperature(),
    };
    Completions {
        // If no max tokens limit is provided we fall back to the model's default
        max_tokens: max_tokens.unwrap_or_else(|| model.default_max_tokens()),
        // Both model-derived values above must be read before `model` is moved in.
        function_call: model.function_call_default(),
        model,
        temperature,
        input_json: None,
        debug: false,
        api_key: api_key.to_string(),
        version: None,
        tools: None,
        thinking_level: None,
    }
}
///
/// This function turns on debug mode which will info! the prompt to log when executing it.
///
pub fn debug(mut self) -> Self {
self.debug = true;
self
}
///
/// This function turns on/off function calling mode when interacting with OpenAI API.
///
pub fn function_calling(mut self, function_call: bool) -> Self {
self.function_call = function_call;
self
}
///
/// This method can be used to define the model temperature used by the Assistant
/// This method accepts % target of the acceptable range for the model
///
pub fn temperature(mut self, temp_target: u32) -> Self {
self.temperature = self.model.get_normalized_temperature(temp_target);
self
}
///
/// This method can be used to define the model temperature used by the Assistant
/// Using this method the temperature can be set directly without any validation of the range accepted by the model
/// For a range-safe implementation please consider using `OpenAIAssistant::temperature` method
///
pub fn temperature_unchecked(mut self, temp: f32) -> Self {
self.temperature = temp;
self
}
///
/// This method can be used to set the version of Completions API to be used
/// This is currently used for OpenAI models which can be run on OpenAI API or Azure API
///
pub fn version(mut self, version: &str) -> Self {
// TODO: We should use the model trait to check which versions are allowed
self.version = Some(version.to_string());
self
}
/// Registers a tool the model is informed about.
///
/// Can be called repeatedly; tools are kept in the order they were added.
/// Different models support different tool implementations.
pub fn add_tool(mut self, tool: LLMTools) -> Self {
    // Create the list on first use, then append.
    self.tools.get_or_insert_with(Vec::new).push(tool);
    self
}
///
/// This method can be used to set the thinking level for the model
/// This is currently used for Gemini 3 models
///
pub fn thinking_level(mut self, thinking_level: ThinkingLevel) -> Self {
self.thinking_level = Some(thinking_level);
self
}
///
/// This method can be used to provide values that will be used as context for the prompt.
/// Using this function you can provide multiple input values by calling it multiple times. New values will be appended with the category name
/// It accepts any instance that implements the Serialize trait.
///
pub fn set_context<U: Serialize>(mut self, input_name: &str, input_data: &U) -> Result<Self> {
let input_json = if let Ok(json) = serde_json::to_string(&input_data) {
json
} else {
return Err(anyhow!("Unable serialize provided input data."));
};
let line_break = match self.input_json {
Some(_) => "\n\n".to_string(),
None => "".to_string(),
};
let new_json = format!(
"{}{}<{}>{}</{}>",
self.input_json.unwrap_or_default(),
line_break,
input_name,
input_json,
input_name,
);
self.input_json = Some(new_json);
Ok(self)
}
/// Estimates how many tokens the prompt will consume.
///
/// The estimate covers the model's base (system) instructions, the user instructions with
/// any context added through `set_context`, and the JSON schema of the response type `U`.
/// Callers can compare the result with their token budget to see how much is likely to
/// remain for the response.
///
/// # Errors
///
/// Fails when the schema for `U` or the tokenizer for the model cannot be obtained.
pub fn check_prompt_tokens<U: JsonSchema + DeserializeOwned>(
    &self,
    instructions: &str,
) -> Result<usize> {
    // The response schema comes from the type parameter.
    let schema = get_type_schema::<U>()?;

    // Context, when present, follows the instructions after a blank line.
    let context_text = match &self.input_json {
        Some(context) => format!("\n\n{context}"),
        None => String::new(),
    };
    let prompt = format!(
        "Instructions:\n{instructions}{context_text}\nRespond ONLY with the data portion of a valid Json object. No schema definition required. No other words."
    );

    // Base (system) instructions + instructions & context data + output schema.
    let base_instructions = self.model.get_base_instructions(Some(self.function_call));
    let full_prompt = format!("{base_instructions}{prompt}{schema}");

    // Count the tokens of the assembled prompt.
    let tokenizer = get_tokenizer(&self.model)?;
    let token_count = tokenizer.encode_with_special_tokens(&full_prompt).len();

    // Assume another 5% overhead for json formatting; the cast drops the fraction.
    let with_overhead = token_count as f64 * 1.05;
    Ok(with_overhead as usize)
}
///
/// This method is used to submit a prompt to OpenAI and process the response.
/// When calling the function you need to specify the type parameter as the response will match the schema of that type.
/// The prompt in this function is written in a way to instruct OpenAI to behave like a computer function that calculates an output based on provided input and its language model.
///
pub async fn get_answer<U: JsonSchema + DeserializeOwned>(
self,
instructions: &str,
) -> Result<U> {
//Output schema is extracted from the type parameter
let schema = get_type_schema::<U>()?;
let json_schema = serde_json::from_str(&schema)?;
let context_text = self
.input_json
.as_ref()
.map(|context| format!("\n\n{}", &context))
.unwrap_or_default();
let prompt = format!("{instructions}{context_text}");
//Validate how many tokens remain for the response (and how many are used for prompt)
let prompt_tokens = self
.check_prompt_tokens::<U>(instructions)
.unwrap_or_default();
if prompt_tokens >= self.max_tokens {
return Err(anyhow!(
"The provided prompt requires more tokens than allocated."
));
}
let response_tokens = self.max_tokens - prompt_tokens;
//Throw a warning if after processing the prompt there might be not enough tokens for response
//This assumes response will be similar size as input. Because this is not always correct this is a warning and not an error
if prompt_tokens * 2 >= self.max_tokens {
warn!(
"{} tokens remaining for response: {} allocated, {} used for prompt",
response_tokens, self.max_tokens, prompt_tokens,
);
};
//Build the API body depending on the used model
let model_body = self.model.get_version_body(
&prompt,
&json_schema,
self.function_call,
&response_tokens,
&self.temperature,
self.version.clone(),
self.tools.as_deref(),
self.thinking_level.as_ref(),
);
//Display debug info if requested
if self.debug {
info!("[debug] Model body: {:#?}", model_body);
info!(
"[debug] Prompt accounts for approx {} tokens, leaving {} tokens for answer.",
prompt_tokens, response_tokens,
);
}
let response_text = self
.model
.call_api(
&self.api_key,
self.version.clone(),
&model_body,
self.debug,
self.tools.as_deref(),
)
.await?;
//Extract data from the returned response text based on the used model
let response_string = self
.model
.get_version_data(&response_text, self.function_call, self.version)
.map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: format!("assistants::completions::{}", self.model.as_str()),
error_message: format!(
"Completions API response serialization error: {}",
error
),
error_detail: response_text.to_string(),
};
error!("{:?}", error);
anyhow!("{:?}", error)
})?;
if self.debug {
info!("[debug] Completions response data: {}", response_string);
}
//Deserialize the string response into the expected output type
serde_json::from_str(&response_string).map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: format!("assistants::completions::{}", self.model.as_str()),
error_message: format!("Completions API response serialization error: {}", error),
error_detail: response_string,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})
}
}
/// Thinking level that can be requested from a model through
/// `Completions::thinking_level`.
///
/// Serialized by serde in `snake_case` (`"low"` / `"high"`). `High` is the `Default`.
/// The enum is fieldless, so it is `Clone` and `Copy` and can be reused freely across
/// several requests.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingLevel {
    /// Low thinking level (`"low"`).
    Low,
    /// High thinking level (`"high"`); the default.
    #[default]
    High,
}
impl ThinkingLevel {
    /// Returns the lowercase name of the level: `"low"` or `"high"`.
    ///
    /// The result is a string literal, so it is `'static` and does not borrow from `self`.
    pub fn as_str(&self) -> &'static str {
        match self {
            ThinkingLevel::Low => "low",
            ThinkingLevel::High => "high",
        }
    }

    /// Parses a level from its lowercase name.
    ///
    /// Returns `Some` for exactly `"low"` or `"high"` (the strings produced by `as_str`)
    /// and `None` for anything else; the match is case-sensitive and does not trim.
    pub fn try_from_str(s: &str) -> Option<Self> {
        match s {
            "low" => Some(ThinkingLevel::Low),
            "high" => Some(ThinkingLevel::High),
            _ => None,
        }
    }
}