1use async_std::task;
2use chrono::Local;
3use crate::{ApiProvider, Error};
4use crate::message::message_to_json;
5use crate::model::{Model, ModelRaw};
6use crate::record::{
7 RecordAt,
8 dump_pdl,
9 record_api_usage,
10};
11use crate::response::Response;
12use ragit_fs::{WriteMode, join, write_log, write_string};
13use ragit_pdl::{Message, Role, Schema};
14use serde::de::DeserializeOwned;
15use serde_json::{Map, Value};
16use std::time::{Duration, Instant};
17
18#[derive(Clone, Debug)]
19pub struct Request {
20 pub messages: Vec<Message>,
21 pub model: Model,
22 pub temperature: Option<f64>,
23 pub frequency_penalty: Option<f64>,
24 pub max_tokens: Option<usize>,
25
26 pub timeout: Option<u64>,
28
29 pub max_retry: usize,
31
32 pub sleep_between_retries: u64,
34 pub record_api_usage_at: Option<RecordAt>,
35
36 pub dump_pdl_at: Option<String>,
38
39 pub dump_json_at: Option<String>,
41
42 pub schema: Option<Schema>,
46
47 pub schema_max_try: usize,
50}
51
52impl Request {
53 pub fn is_valid(&self) -> bool {
54 self.messages.len() > 1
55 && self.messages.len() & 1 == 0 && self.messages[0].is_valid_system_prompt() && {
58 let mut flag = true;
59
60 for (index, message) in self.messages[1..].iter().enumerate() {
61 if index & 1 == 0 && !message.is_user_prompt() {
62 flag = false;
63 break;
64 }
65
66 else if index & 1 == 1 && !message.is_assistant_prompt() {
67 flag = false;
68 break;
69 }
70 }
71
72 flag
73 }
74 }
75
76 pub fn build_json_body(&self) -> Value {
78 match &self.model.api_provider {
79 ApiProvider::OpenAi { .. } | ApiProvider::Cohere => {
80 let mut result = Map::new();
81 result.insert(String::from("model"), self.model.api_name.clone().into());
82 let mut messages = vec![];
83
84 for message in self.messages.iter() {
85 messages.push(message_to_json(message, &self.model.api_provider));
86 }
87
88 result.insert(String::from("messages"), messages.into());
89
90 if let Some(temperature) = self.temperature {
91 result.insert(String::from("temperature"), temperature.into());
92 }
93
94 if let Some(frequency_penalty) = self.frequency_penalty {
95 result.insert(String::from("frequency_penalty"), frequency_penalty.into());
96 }
97
98 if let Some(max_tokens) = self.max_tokens {
99 result.insert(String::from("max_tokens"), max_tokens.into());
100 }
101
102 result.into()
103 },
104 ApiProvider::Anthropic => {
105 let mut result = Map::new();
106 result.insert(String::from("model"), self.model.api_name.clone().into());
107 let mut messages = vec![];
108 let mut system_prompt = vec![];
109
110 for message in self.messages.iter() {
111 if message.role == Role::System {
112 system_prompt.push(message.content[0].unwrap_str().to_string());
113 }
114
115 else {
116 messages.push(message_to_json(message, &ApiProvider::Anthropic));
117 }
118 }
119
120 let system_prompt = system_prompt.concat();
121
122 if !system_prompt.is_empty() {
123 result.insert(String::from("system"), system_prompt.into());
124 }
125
126 result.insert(String::from("messages"), messages.into());
127
128 if let Some(temperature) = self.temperature {
129 result.insert(String::from("temperature"), temperature.into()).unwrap();
130 }
131
132 if let Some(frequency_penalty) = self.frequency_penalty {
133 result.insert(String::from("frequency_penalty"), frequency_penalty.into());
134 }
135
136 result.insert(String::from("max_tokens"), self.max_tokens.unwrap_or(2048).into());
138
139 result.into()
140 },
141 ApiProvider::Test(_) => Value::Null,
142 }
143 }
144
145 pub async fn send_and_validate<T: DeserializeOwned>(&self, default: T) -> Result<T, Error> {
148 let mut state = self.clone();
149 let mut messages = self.messages.clone();
150
151 for _ in 0..state.schema_max_try {
152 state.messages = messages.clone();
153 let response = state.send().await?;
154 let response = response.get_message(0).unwrap();
155
156 match state.schema.as_ref().unwrap().validate(&response) {
157 Ok(v) => {
158 return Ok(serde_json::from_value::<T>(v)?);
159 },
160 Err(error_message) => {
161 messages.push(Message::simple_message(Role::Assistant, response.to_string()));
162 messages.push(Message::simple_message(Role::User, error_message));
163 },
164 }
165 }
166
167 Ok(default)
168 }
169
170 pub fn blocking_send(&self) -> Result<Response, Error> {
174 futures::executor::block_on(self.send())
175 }
176
177 pub async fn send(&self) -> Result<Response, Error> {
179 let started_at = Instant::now();
180 let client = reqwest::Client::new();
181 let mut curr_error = Error::NoTry;
182
183 let post_url = self.model.get_api_url();
184 let body = self.build_json_body();
185
186 if let Err(e) = self.dump_json(&body, "request") {
187 write_log(
188 "dump_json",
189 &format!("dump_json(\"request\", ..) failed with {e:?}"),
190 );
191 }
192
193 if let ApiProvider::Test(test_model) = &self.model.api_provider {
194 let response = test_model.get_dummy_response(&self.messages);
195
196 if let Some(key) = &self.record_api_usage_at {
197 if let Err(e) = record_api_usage(
198 key,
199 0,
200 0,
201 self.model.dollars_per_1b_input_tokens,
202 self.model.dollars_per_1b_output_tokens,
203 false,
204 ) {
205 write_log(
206 "record_api_usage",
207 &format!("record_api_usage({key:?}, ..) failed with {e:?}"),
208 );
209 }
210 }
211
212 if let Some(path) = &self.dump_pdl_at {
213 if let Err(e) = dump_pdl(
214 &self.messages,
215 &response,
216 &None,
217 path,
218 String::from("model: dummy, input_tokens: 0, output_tokens: 0, took: 0ms"),
219 ) {
220 write_log(
221 "dump_pdl",
222 &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
223 );
224
225 }
228 }
229
230 return Ok(Response::dummy(response));
231 }
232
233 let body = serde_json::to_string(&body)?;
234 let api_key = self.model.get_api_key()?;
235 write_log(
236 "chat_request::send",
237 &format!("entered chat_request::send() with {} bytes, model: {}", body.len(), self.model.name),
238 );
239
240 for _ in 0..(self.max_retry + 1) {
241 let mut request = client.post(post_url)
242 .header(reqwest::header::CONTENT_TYPE, "application/json")
243 .body(body.clone());
244
245 if let ApiProvider::Anthropic = &self.model.api_provider {
246 request = request.header("x-api-key", api_key.clone())
247 .header("anthropic-version", "2023-06-01");
248 }
249
250 else if !api_key.is_empty() {
251 request = request.bearer_auth(api_key.clone());
252 }
253
254 if let Some(t) = self.timeout {
255 request = request.timeout(Duration::from_millis(t));
256 }
257
258 write_log(
259 "chat_request::send",
260 "a request sent",
261 );
262 let response = request.send().await;
263 write_log(
264 "chat_request::send",
265 "got a response from a request",
266 );
267
268 match response {
269 Ok(response) => match response.status().as_u16() {
270 200 => match response.text().await {
271 Ok(text) => {
272 match serde_json::from_str::<Value>(&text) {
273 Ok(v) => match self.dump_json(&v, "response") {
274 Err(e) => {
275 write_log(
276 "dump_json",
277 &format!("dump_json(\"response\", ..) failed with {e:?}"),
278 );
279 },
280 Ok(_) => {},
281 },
282 Err(e) => {
283 write_log(
284 "dump_json",
285 &format!("dump_json(\"response\", ..) failed with {e:?}"),
286 );
287 },
288 }
289
290 match Response::from_str(&text, &self.model.api_provider) {
291 Ok(result) => {
292 if let Some(key) = &self.record_api_usage_at {
293 if let Err(e) = record_api_usage(
294 key,
295 result.get_prompt_token_count() as u64,
296 result.get_output_token_count() as u64,
297 self.model.dollars_per_1b_input_tokens,
298 self.model.dollars_per_1b_output_tokens,
299 false,
300 ) {
301 write_log(
302 "record_api_usage",
303 &format!("record_api_usage({key:?}, ..) failed with {e:?}"),
304 );
305 }
306 }
307
308 if let Some(path) = &self.dump_pdl_at {
309 if let Err(e) = dump_pdl(
310 &self.messages,
311 &result.get_message(0).map(|m| m.to_string()).unwrap_or(String::new()),
312 &result.get_reasoning(0).map(|m| m.to_string()),
313 path,
314 format!(
315 "model: {}, input_tokens: {}, output_tokens: {}, took: {}ms",
316 self.model.name,
317 result.get_prompt_token_count(),
318 result.get_output_token_count(),
319 Instant::now().duration_since(started_at.clone()).as_millis(),
320 ),
321 ) {
322 write_log(
323 "dump_pdl",
324 &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
325 );
326
327 }
330 }
331
332 return Ok(result);
333 },
334 Err(e) => {
335 write_log(
336 "Response::from_str",
337 &format!("Response::from_str(..) failed with {e:?}"),
338 );
339 curr_error = e;
340 },
341 }
342 },
343 Err(e) => {
344 write_log(
345 "response.text()",
346 &format!("response.text() failed with {e:?}"),
347 );
348 curr_error = Error::ReqwestError(e);
349 },
350 },
351 status_code => {
352 curr_error = Error::ServerError {
353 status_code,
354 body: response.text().await,
355 };
356
357 if let Some(path) = &self.dump_pdl_at {
358 if let Err(e) = dump_pdl(
359 &self.messages,
360 "",
361 &None,
362 path,
363 format!("{}# error: {curr_error:?} #{}", '{', '}'),
364 ) {
365 write_log(
366 "dump_pdl",
367 &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
368 );
369 }
370 }
371 },
372 },
373 Err(e) => {
374 write_log(
375 "request.send().await",
376 &format!("request.send().await failed with {e:?}"),
377 );
378 curr_error = Error::ReqwestError(e);
379 },
380 }
381
382 task::sleep(Duration::from_millis(self.sleep_between_retries)).await
383 }
384
385 Err(curr_error)
386 }
387
388 fn dump_json(&self, j: &Value, header: &str) -> Result<(), Error> {
389 if let Some(dir) = &self.dump_json_at {
390 let path = join(
391 &dir,
392 &format!("{header}-{}.json", Local::now().to_rfc3339()),
393 )?;
394 write_string(&path, &serde_json::to_string_pretty(j)?, WriteMode::AlwaysCreate)?;
395 }
396
397 Ok(())
398 }
399}
400
401impl Default for Request {
402 fn default() -> Self {
403 Request {
404 messages: vec![],
405 model: (&ModelRaw::llama_70b()).try_into().unwrap(),
406 temperature: None,
407 frequency_penalty: None,
408 max_tokens: None,
409 timeout: Some(20_000),
410 max_retry: 2,
411 sleep_between_retries: 6_000,
412 record_api_usage_at: None,
413 dump_pdl_at: None,
414 dump_json_at: None,
415 schema: None,
416 schema_max_try: 3,
417 }
418 }
419}