1use async_std::task;
2use chrono::Local;
3use crate::{ApiProvider, Error};
4use crate::audit::{
5 AuditRecordAt,
6 dump_api_usage,
7 dump_pdl,
8};
9use crate::message::{message_contents_to_json_array, message_to_json};
10use crate::model::{Model, ModelRaw};
11use crate::response::Response;
12use ragit_fs::{
13 WriteMode,
14 create_dir_all,
15 exists,
16 join,
17 write_log,
18 write_string,
19};
20use ragit_pdl::{Message, Role, Schema};
21use serde::de::DeserializeOwned;
22use serde_json::{Map, Value};
23use std::time::{Duration, Instant};
24
/// A chat request to a single model, together with the retry, timeout,
/// schema-validation and debug-dump settings used by `send` / `send_and_validate`.
#[derive(Clone, Debug)]
pub struct Request {
    /// The conversation to send. `is_valid` expects a system prompt followed by
    /// alternating user/assistant turns that start and end with a user turn.
    pub messages: Vec<Message>,
    /// The model to call; its `api_provider` selects the request-body format and
    /// how the API key is attached.
    pub model: Model,
    /// Sampling temperature. When `Some`, it is forwarded in the request body
    /// (not every provider's body includes it).
    pub temperature: Option<f64>,
    /// Frequency penalty. When `Some`, it is forwarded in the request body for
    /// providers whose body includes it.
    pub frequency_penalty: Option<f64>,
    /// Maximum number of output tokens. When `None`, OpenAI-compatible/Cohere
    /// bodies omit the field, while Anthropic bodies fall back to 2048.
    pub max_tokens: Option<usize>,

    /// Per-attempt HTTP timeout in milliseconds; `None` sets no per-request timeout.
    pub timeout: Option<u64>,

    /// How many times `send` retries after the first attempt fails
    /// (up to `max_retry + 1` attempts in total).
    pub max_retry: usize,

    /// Back-off sleep, in milliseconds, that `send` applies after a failed attempt.
    pub sleep_between_retries: u64,
    /// When set, `send` passes each call's token counts to `dump_api_usage` with this key.
    pub dump_api_usage_at: Option<AuditRecordAt>,

    /// When set, `send` calls `dump_pdl` with this path for every call (successful
    /// or failed with a non-200 status).
    pub dump_pdl_at: Option<String>,

    /// When set, `send` writes the request and response JSON bodies as files in
    /// this directory, creating the directory if it does not exist.
    pub dump_json_at: Option<String>,

    /// Schema that `send_and_validate` checks the response against; it must be
    /// `Some` whenever `send_and_validate` actually sends a request.
    pub schema: Option<Schema>,

    /// Maximum number of send-and-validate rounds in `send_and_validate` before it
    /// gives up and returns the caller-supplied default.
    pub schema_max_try: usize,
}
58
59impl Request {
60 pub fn is_valid(&self) -> bool {
61 self.messages.len() > 1
62 && self.messages.len() & 1 == 0 && self.messages[0].is_valid_system_prompt() && {
65 let mut flag = true;
66
67 for (index, message) in self.messages[1..].iter().enumerate() {
68 if index & 1 == 0 && !message.is_user_prompt() {
69 flag = false;
70 break;
71 }
72
73 else if index & 1 == 1 && !message.is_assistant_prompt() {
74 flag = false;
75 break;
76 }
77 }
78
79 flag
80 }
81 }
82
83 pub fn build_json_body(&self) -> Value {
85 match &self.model.api_provider {
86 ApiProvider::Google => {
87 let mut result = Map::new();
88 let mut contents = vec![];
89 let mut system_prompt = vec![];
90
91 for message in self.messages.iter() {
92 if message.role == Role::System {
93 match message_contents_to_json_array(&message.content, &ApiProvider::Google) {
94 Value::Array(parts) => {
95 system_prompt.push(parts);
96 },
97 _ => unreachable!(),
98 }
99 }
100
101 else {
102 contents.push(message_to_json(message, &self.model.api_provider));
103 }
104 }
105
106 if !system_prompt.is_empty() {
107 let parts = system_prompt.concat();
108 let mut system_prompt = Map::new();
109 system_prompt.insert(String::from("parts"), parts.into());
110 result.insert(String::from("system_instruction"), system_prompt.into());
111 }
112
113 result.insert(String::from("contents"), contents.into());
116 result.into()
117 },
118 ApiProvider::OpenAi { .. } | ApiProvider::Cohere => {
119 let mut result = Map::new();
120 result.insert(String::from("model"), self.model.api_name.clone().into());
121 let mut messages = vec![];
122
123 for message in self.messages.iter() {
124 messages.push(message_to_json(message, &self.model.api_provider));
125 }
126
127 result.insert(String::from("messages"), messages.into());
128
129 if let Some(temperature) = self.temperature {
130 result.insert(String::from("temperature"), temperature.into());
131 }
132
133 if let Some(frequency_penalty) = self.frequency_penalty {
134 result.insert(String::from("frequency_penalty"), frequency_penalty.into());
135 }
136
137 if let Some(max_tokens) = self.max_tokens {
138 result.insert(String::from("max_tokens"), max_tokens.into());
139 }
140
141 result.into()
142 },
143 ApiProvider::Anthropic => {
144 let mut result = Map::new();
145 result.insert(String::from("model"), self.model.api_name.clone().into());
146 let mut messages = vec![];
147 let mut system_prompt = vec![];
148
149 for message in self.messages.iter() {
150 if message.role == Role::System {
151 system_prompt.push(message.content[0].unwrap_str().to_string());
152 }
153
154 else {
155 messages.push(message_to_json(message, &ApiProvider::Anthropic));
156 }
157 }
158
159 let system_prompt = system_prompt.concat();
160
161 if !system_prompt.is_empty() {
162 result.insert(String::from("system"), system_prompt.into());
163 }
164
165 result.insert(String::from("messages"), messages.into());
166
167 if let Some(temperature) = self.temperature {
168 result.insert(String::from("temperature"), temperature.into());
169 }
170
171 if let Some(frequency_penalty) = self.frequency_penalty {
172 result.insert(String::from("frequency_penalty"), frequency_penalty.into());
173 }
174
175 result.insert(String::from("max_tokens"), self.max_tokens.unwrap_or(2048).into());
177
178 result.into()
179 },
180 ApiProvider::Test(_) => Value::Null,
181 }
182 }
183
184 pub async fn send_and_validate<T: DeserializeOwned>(&self, default: T) -> Result<T, Error> {
187 let mut state = self.clone();
188 let mut messages = self.messages.clone();
189
190 for _ in 0..state.schema_max_try {
191 state.messages = messages.clone();
192 let response = state.send().await?;
193 let response = response.get_message(0).unwrap();
194
195 match state.schema.as_ref().unwrap().validate(&response) {
196 Ok(v) => {
197 return Ok(serde_json::from_value::<T>(v)?);
198 },
199 Err(error_message) => {
200 messages.push(Message::simple_message(Role::Assistant, response.to_string()));
201 messages.push(Message::simple_message(Role::User, error_message));
202 },
203 }
204 }
205
206 Ok(default)
207 }
208
209 pub fn blocking_send(&self) -> Result<Response, Error> {
213 futures::executor::block_on(self.send())
214 }
215
216 pub async fn send(&self) -> Result<Response, Error> {
218 let started_at = Instant::now();
219 let client = reqwest::Client::new();
220 let mut curr_error = Error::NoTry;
221
222 let post_url = self.model.get_api_url()?;
223 let body = self.build_json_body();
224
225 if let Err(e) = self.dump_json(&body, "request") {
226 write_log(
227 "dump_json",
228 &format!("dump_json(\"request\", ..) failed with {e:?}"),
229 );
230 }
231
232 if let ApiProvider::Test(test_model) = &self.model.api_provider {
233 let response = test_model.get_dummy_response(&self.messages)?;
234
235 if let Some(key) = &self.dump_api_usage_at {
236 if let Err(e) = dump_api_usage(
237 key,
238 0,
239 0,
240 self.model.dollars_per_1b_input_tokens,
241 self.model.dollars_per_1b_output_tokens,
242 false,
243 ) {
244 write_log(
245 "dump_api_usage",
246 &format!("dump_api_usage({key:?}, ..) failed with {e:?}"),
247 );
248 }
249 }
250
251 if let Some(path) = &self.dump_pdl_at {
252 if let Err(e) = dump_pdl(
253 &self.messages,
254 &response,
255 &None,
256 path,
257 String::from("model: dummy, input_tokens: 0, output_tokens: 0, took: 0ms"),
258 ) {
259 write_log(
260 "dump_pdl",
261 &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
262 );
263
264 }
267 }
268
269 return Ok(Response::dummy(response));
270 }
271
272 let body = serde_json::to_string(&body)?;
273 let api_key = self.model.get_api_key()?;
274 write_log(
275 "chat_request::send",
276 &format!("entered chat_request::send() with {} bytes, model: {}", body.len(), self.model.name),
277 );
278
279 for _ in 0..(self.max_retry + 1) {
280 let mut request = client.post(&post_url)
281 .header(reqwest::header::CONTENT_TYPE, "application/json")
282 .body(body.clone());
283
284 match &self.model.api_provider {
285 ApiProvider::Anthropic => {
286 request = request.header("x-api-key", api_key.clone())
287 .header("anthropic-version", "2023-06-01");
288 },
289 ApiProvider::Google => {},
290 _ if !api_key.is_empty() => {
291 request = request.bearer_auth(api_key.clone());
292 },
293 _ => {},
294 }
295
296 if let Some(t) = self.timeout {
297 request = request.timeout(Duration::from_millis(t));
298 }
299
300 write_log(
301 "chat_request::send",
302 "a request sent",
303 );
304 let response = request.send().await;
305 write_log(
306 "chat_request::send",
307 "got a response from a request",
308 );
309
310 match response {
311 Ok(response) => match response.status().as_u16() {
312 200 => match response.text().await {
313 Ok(text) => {
314 match serde_json::from_str::<Value>(&text) {
315 Ok(v) => match self.dump_json(&v, "response") {
316 Err(e) => {
317 write_log(
318 "dump_json",
319 &format!("dump_json(\"response\", ..) failed with {e:?}"),
320 );
321 },
322 Ok(_) => {},
323 },
324 Err(e) => {
325 write_log(
326 "dump_json",
327 &format!("dump_json(\"response\", ..) failed with {e:?}"),
328 );
329 },
330 }
331
332 match Response::from_str(&text, &self.model.api_provider) {
333 Ok(result) => {
334 if let Some(key) = &self.dump_api_usage_at {
335 if let Err(e) = dump_api_usage(
336 key,
337 result.get_prompt_token_count() as u64,
338 result.get_output_token_count() as u64,
339 self.model.dollars_per_1b_input_tokens,
340 self.model.dollars_per_1b_output_tokens,
341 false,
342 ) {
343 write_log(
344 "dump_api_usage",
345 &format!("dump_api_usage({key:?}, ..) failed with {e:?}"),
346 );
347 }
348 }
349
350 if let Some(path) = &self.dump_pdl_at {
351 if let Err(e) = dump_pdl(
352 &self.messages,
353 &result.get_message(0).map(|m| m.to_string()).unwrap_or(String::new()),
354 &result.get_reasoning(0).map(|m| m.to_string()),
355 path,
356 format!(
357 "model: {}, input_tokens: {}, output_tokens: {}, took: {}ms",
358 self.model.name,
359 result.get_prompt_token_count(),
360 result.get_output_token_count(),
361 Instant::now().duration_since(started_at.clone()).as_millis(),
362 ),
363 ) {
364 write_log(
365 "dump_pdl",
366 &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
367 );
368
369 }
372 }
373
374 return Ok(result);
375 },
376 Err(e) => {
377 write_log(
378 "Response::from_str",
379 &format!("Response::from_str(..) failed with {e:?}"),
380 );
381 curr_error = e;
382 },
383 }
384 },
385 Err(e) => {
386 write_log(
387 "response.text()",
388 &format!("response.text() failed with {e:?}"),
389 );
390 curr_error = Error::ReqwestError(e);
391 },
392 },
393 status_code => {
394 curr_error = Error::ServerError {
395 status_code,
396 body: response.text().await,
397 };
398
399 if let Some(path) = &self.dump_pdl_at {
400 if let Err(e) = dump_pdl(
401 &self.messages,
402 "",
403 &None,
404 path,
405 format!("{}# error: {curr_error:?} #{}", '{', '}'),
406 ) {
407 write_log(
408 "dump_pdl",
409 &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
410 );
411 }
412 }
413
414 if !self.model.can_read_images && self.messages.iter().any(|message| message.has_image()) {
421 return Err(Error::CannotReadImage(self.model.name.clone()));
422 }
423 },
424 },
425 Err(e) => {
426 write_log(
427 "request.send().await",
428 &format!("request.send().await failed with {e:?}"),
429 );
430 curr_error = Error::ReqwestError(e);
431 },
432 }
433
434 task::sleep(Duration::from_millis(self.sleep_between_retries)).await
435 }
436
437 Err(curr_error)
438 }
439
440 fn dump_json(&self, j: &Value, header: &str) -> Result<(), Error> {
441 if let Some(dir) = &self.dump_json_at {
442 if !exists(dir) {
443 create_dir_all(dir)?;
444 }
445
446 let path = join(
447 &dir,
448 &format!("{header}-{}.json", Local::now().to_rfc3339()),
449 )?;
450 write_string(&path, &serde_json::to_string_pretty(j)?, WriteMode::AlwaysCreate)?;
451 }
452
453 Ok(())
454 }
455}
456
457impl Default for Request {
458 fn default() -> Self {
459 Request {
460 messages: vec![],
461 model: (&ModelRaw::llama_70b()).try_into().unwrap(),
462 temperature: None,
463 frequency_penalty: None,
464 max_tokens: None,
465 timeout: Some(20_000),
466 max_retry: 2,
467 sleep_between_retries: 6_000,
468 dump_api_usage_at: None,
469 dump_pdl_at: None,
470 dump_json_at: None,
471 schema: None,
472 schema_max_try: 3,
473 }
474 }
475}