1use async_std::task;
2use chrono::Local;
3use crate::{ApiProvider, Error};
4use crate::audit::{
5 AuditRecordAt,
6 dump_api_usage,
7 dump_pdl,
8};
9use crate::message::{message_contents_to_json_array, message_to_json};
10use crate::model::{Model, ModelRaw};
11use crate::response::Response;
12use ragit_fs::{
13 WriteMode,
14 create_dir_all,
15 exists,
16 join,
17 write_log,
18 write_string,
19};
20use ragit_pdl::{Message, Role, Schema};
21use serde::de::DeserializeOwned;
22use serde_json::{Map, Value};
23use std::time::{Duration, Instant};
24
/// A single chat request: the conversation, the model to send it to,
/// sampling options, the retry policy and optional audit/debug dumps.
#[derive(Clone, Debug)]
pub struct Request {
    /// The conversation. [`Request::is_valid`] expects a valid system prompt
    /// first, followed by alternating user/assistant messages that end with a
    /// user message.
    pub messages: Vec<Message>,
    /// Model (and API provider) the request is sent to.
    pub model: Model,
    /// Sampling temperature. Sent to OpenAI/Cohere and Anthropic when set;
    /// not included in the Google request body.
    pub temperature: Option<f64>,
    /// Frequency penalty. Sent to OpenAI/Cohere and Anthropic when set;
    /// not included in the Google request body.
    pub frequency_penalty: Option<f64>,
    /// Output-token limit. Sent to OpenAI/Cohere only when set; Anthropic
    /// always receives one (16384 when `None`); not sent to Google.
    pub max_tokens: Option<usize>,

    /// Per-attempt HTTP timeout in milliseconds. `None` sets no explicit
    /// timeout on the HTTP request.
    pub timeout: Option<u64>,

    /// How many times a failed request is retried: [`Request::send`] makes up
    /// to `max_retry + 1` attempts. A 400 response is never retried.
    pub max_retry: usize,

    /// Delay between attempts, in milliseconds.
    pub sleep_between_retries: u64,
    /// If set, token usage is recorded here via `dump_api_usage` after every
    /// successful call.
    pub dump_api_usage_at: Option<AuditRecordAt>,

    /// If set, the conversation together with the reply (or the server error)
    /// is written to this path via `dump_pdl`.
    pub dump_pdl_at: Option<String>,

    /// If set, request and response JSON bodies are written into this
    /// directory, which is created if it does not exist.
    pub dump_json_at: Option<String>,

    /// Schema the reply must satisfy. Required by
    /// [`Request::send_and_validate`], which panics when it is `None`.
    pub schema: Option<Schema>,

    /// Maximum number of attempts [`Request::send_and_validate`] makes before
    /// falling back to the caller-supplied default.
    pub schema_max_try: usize,
}
58
59impl Request {
60 pub fn is_valid(&self) -> bool {
61 self.messages.len() > 1
62 && self.messages.len() & 1 == 0 && self.messages[0].is_valid_system_prompt() && {
65 let mut flag = true;
66
67 for (index, message) in self.messages[1..].iter().enumerate() {
68 if index & 1 == 0 && !message.is_user_prompt() {
69 flag = false;
70 break;
71 }
72
73 else if index & 1 == 1 && !message.is_assistant_prompt() {
74 flag = false;
75 break;
76 }
77 }
78
79 flag
80 }
81 }
82
83 pub fn build_json_body(&self) -> Value {
85 match &self.model.api_provider {
86 ApiProvider::Google => {
87 let mut result = Map::new();
88 let mut contents = vec![];
89 let mut system_prompt = vec![];
90
91 for message in self.messages.iter() {
92 if message.role == Role::System {
93 match message_contents_to_json_array(&message.content, &ApiProvider::Google) {
94 Value::Array(parts) => {
95 system_prompt.push(parts);
96 },
97 _ => unreachable!(),
98 }
99 }
100
101 else {
102 contents.push(message_to_json(message, &self.model.api_provider));
103 }
104 }
105
106 if !system_prompt.is_empty() {
107 let parts = system_prompt.concat();
108 let mut system_prompt = Map::new();
109 system_prompt.insert(String::from("parts"), parts.into());
110 result.insert(String::from("system_instruction"), system_prompt.into());
111 }
112
113 result.insert(String::from("contents"), contents.into());
116 result.into()
117 },
118 ApiProvider::OpenAi { .. } | ApiProvider::Cohere => {
119 let mut result = Map::new();
120 result.insert(String::from("model"), self.model.api_name.clone().into());
121 let mut messages = vec![];
122
123 for message in self.messages.iter() {
124 messages.push(message_to_json(message, &self.model.api_provider));
125 }
126
127 result.insert(String::from("messages"), messages.into());
128
129 if let Some(temperature) = self.temperature {
130 result.insert(String::from("temperature"), temperature.into());
131 }
132
133 if let Some(frequency_penalty) = self.frequency_penalty {
134 result.insert(String::from("frequency_penalty"), frequency_penalty.into());
135 }
136
137 if let Some(max_tokens) = self.max_tokens {
138 result.insert(String::from("max_tokens"), max_tokens.into());
139 }
140
141 if self.model.api_name.contains("gpt-5") {
143 result.insert(String::from("reasoning_effort"), "low".into());
144 }
145
146 result.into()
147 },
148 ApiProvider::Anthropic => {
149 let mut result = Map::new();
150 result.insert(String::from("model"), self.model.api_name.clone().into());
151 let mut messages = vec![];
152 let mut system_prompt = vec![];
153
154 for message in self.messages.iter() {
155 if message.role == Role::System {
156 system_prompt.push(message.content[0].unwrap_str().to_string());
157 }
158
159 else {
160 messages.push(message_to_json(message, &ApiProvider::Anthropic));
161 }
162 }
163
164 let system_prompt = system_prompt.concat();
165
166 if !system_prompt.is_empty() {
167 result.insert(String::from("system"), system_prompt.into());
168 }
169
170 result.insert(String::from("messages"), messages.into());
171
172 if let Some(temperature) = self.temperature {
173 result.insert(String::from("temperature"), temperature.into());
174 }
175
176 if let Some(frequency_penalty) = self.frequency_penalty {
177 result.insert(String::from("frequency_penalty"), frequency_penalty.into());
178 }
179
180 result.insert(String::from("max_tokens"), self.max_tokens.unwrap_or(16384).into());
182
183 result.into()
189 },
190 ApiProvider::Test(_) => Value::Null,
191 }
192 }
193
194 pub async fn send_and_validate<T: DeserializeOwned>(&self, default: T) -> Result<T, Error> {
197 let mut state = self.clone();
198 let mut messages = self.messages.clone();
199
200 for _ in 0..state.schema_max_try {
201 state.messages = messages.clone();
202 let response = state.send().await?;
203 let response = response.get_message(0).unwrap();
204
205 match state.schema.as_ref().unwrap().validate(&response) {
206 Ok(v) => {
207 return Ok(serde_json::from_value::<T>(v)?);
208 },
209 Err(error_message) => {
210 messages.push(Message::simple_message(Role::Assistant, response.to_string()));
211 messages.push(Message::simple_message(Role::User, error_message));
212 },
213 }
214 }
215
216 Ok(default)
217 }
218
219 pub fn blocking_send(&self) -> Result<Response, Error> {
223 futures::executor::block_on(self.send())
224 }
225
226 pub async fn send(&self) -> Result<Response, Error> {
228 let started_at = Instant::now();
229 let client = reqwest::Client::new();
230 let mut curr_error = Error::NoTry;
231
232 let post_url = self.model.get_api_url()?;
233 let body = self.build_json_body();
234
235 if let Err(e) = self.dump_json(&body, "request") {
236 write_log(
237 "dump_json",
238 &format!("dump_json(\"request\", ..) failed with {e:?}"),
239 );
240 }
241
242 if let ApiProvider::Test(test_model) = &self.model.api_provider {
243 let response = test_model.get_dummy_response(&self.messages)?;
244
245 if let Some(key) = &self.dump_api_usage_at {
246 if let Err(e) = dump_api_usage(
247 key,
248 0,
249 0,
250 self.model.dollars_per_1b_input_tokens,
251 self.model.dollars_per_1b_output_tokens,
252 false,
253 ) {
254 write_log(
255 "dump_api_usage",
256 &format!("dump_api_usage({key:?}, ..) failed with {e:?}"),
257 );
258 }
259 }
260
261 if let Some(path) = &self.dump_pdl_at {
262 if let Err(e) = dump_pdl(
263 &self.messages,
264 &response,
265 &None,
266 path,
267 String::from("model: dummy, input_tokens: 0, output_tokens: 0, took: 0ms"),
268 ) {
269 write_log(
270 "dump_pdl",
271 &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
272 );
273
274 }
277 }
278
279 return Ok(Response::dummy(response));
280 }
281
282 let body = serde_json::to_string(&body)?;
283 let api_key = self.model.get_api_key()?;
284 write_log(
285 "chat_request::send",
286 &format!("entered chat_request::send() with {} bytes, model: {}", body.len(), self.model.name),
287 );
288
289 for _ in 0..(self.max_retry + 1) {
290 let mut request = client.post(&post_url)
291 .header(reqwest::header::CONTENT_TYPE, "application/json")
292 .body(body.clone());
293
294 match &self.model.api_provider {
295 ApiProvider::Anthropic => {
296 request = request.header("x-api-key", api_key.clone())
297 .header("anthropic-version", "2023-06-01");
298 },
299 ApiProvider::Google => {},
300 _ if !api_key.is_empty() => {
301 request = request.bearer_auth(api_key.clone());
302 },
303 _ => {},
304 }
305
306 if let Some(t) = self.timeout {
307 request = request.timeout(Duration::from_millis(t));
308 }
309
310 write_log(
311 "chat_request::send",
312 "a request sent",
313 );
314 let response = request.send().await;
315 write_log(
316 "chat_request::send",
317 "got a response from a request",
318 );
319
320 match response {
321 Ok(response) => match response.status().as_u16() {
322 200 => match response.text().await {
323 Ok(text) => {
324 match serde_json::from_str::<Value>(&text) {
325 Ok(v) => match self.dump_json(&v, "response") {
326 Err(e) => {
327 write_log(
328 "dump_json",
329 &format!("dump_json(\"response\", ..) failed with {e:?}"),
330 );
331 },
332 Ok(_) => {},
333 },
334 Err(e) => {
335 write_log(
336 "dump_json",
337 &format!("dump_json(\"response\", ..) failed with {e:?}"),
338 );
339 },
340 }
341
342 match Response::from_str(&text, &self.model.api_provider) {
343 Ok(result) => {
344 if let Some(key) = &self.dump_api_usage_at {
345 if let Err(e) = dump_api_usage(
346 key,
347 result.get_prompt_token_count() as u64,
348 result.get_output_token_count() as u64,
349 self.model.dollars_per_1b_input_tokens,
350 self.model.dollars_per_1b_output_tokens,
351 false,
352 ) {
353 write_log(
354 "dump_api_usage",
355 &format!("dump_api_usage({key:?}, ..) failed with {e:?}"),
356 );
357 }
358 }
359
360 if let Some(path) = &self.dump_pdl_at {
361 if let Err(e) = dump_pdl(
362 &self.messages,
363 &result.get_message(0).map(|m| m.to_string()).unwrap_or(String::new()),
364 &result.get_reasoning(0).map(|m| m.to_string()),
365 path,
366 format!(
367 "model: {}, input_tokens: {}, output_tokens: {}, took: {}ms",
368 self.model.name,
369 result.get_prompt_token_count(),
370 result.get_output_token_count(),
371 Instant::now().duration_since(started_at.clone()).as_millis(),
372 ),
373 ) {
374 write_log(
375 "dump_pdl",
376 &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
377 );
378
379 }
382 }
383
384 return Ok(result);
385 },
386 Err(e) => {
387 write_log(
388 "Response::from_str",
389 &format!("Response::from_str(..) failed with {e:?}"),
390 );
391 curr_error = e;
392 },
393 }
394 },
395 Err(e) => {
396 write_log(
397 "response.text()",
398 &format!("response.text() failed with {e:?}"),
399 );
400 curr_error = Error::ReqwestError(e);
401 },
402 },
403 status_code => {
404 curr_error = Error::ServerError {
405 status_code,
406 body: response.text().await,
407 };
408
409 if let Some(path) = &self.dump_pdl_at {
410 if let Err(e) = dump_pdl(
411 &self.messages,
412 "",
413 &None,
414 path,
415 format!("{}# error: {curr_error:?} #{}", '{', '}'),
416 ) {
417 write_log(
418 "dump_pdl",
419 &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
420 );
421 }
422 }
423
424 if !self.model.can_read_images && self.messages.iter().any(|message| message.has_image()) {
431 return Err(Error::CannotReadImage(self.model.name.clone()));
432 }
433
434 if status_code == 400 {
436 return Err(curr_error);
437 }
438 },
439 },
440 Err(e) => {
441 write_log(
442 "request.send().await",
443 &format!("request.send().await failed with {e:?}"),
444 );
445 curr_error = Error::ReqwestError(e);
446 },
447 }
448
449 task::sleep(Duration::from_millis(self.sleep_between_retries)).await
450 }
451
452 Err(curr_error)
453 }
454
455 fn dump_json(&self, j: &Value, header: &str) -> Result<(), Error> {
456 if let Some(dir) = &self.dump_json_at {
457 if !exists(dir) {
458 create_dir_all(dir)?;
459 }
460
461 let path = join(
462 &dir,
463 &format!("{header}-{}.json", Local::now().to_rfc3339()),
464 )?;
465 write_string(&path, &serde_json::to_string_pretty(j)?, WriteMode::AlwaysCreate)?;
466 }
467
468 Ok(())
469 }
470}
471
472impl Default for Request {
473 fn default() -> Self {
474 Request {
475 messages: vec![],
476 model: (&ModelRaw::llama_70b()).try_into().unwrap(),
477 temperature: None,
478 frequency_penalty: None,
479 max_tokens: None,
480 timeout: Some(20_000),
481 max_retry: 2,
482 sleep_between_retries: 6_000,
483 dump_api_usage_at: None,
484 dump_pdl_at: None,
485 dump_json_at: None,
486 schema: None,
487 schema_max_try: 3,
488 }
489 }
490}