ragit_api/
request.rs

1use async_std::task;
2use chrono::Local;
3use crate::{ApiProvider, Error};
4use crate::audit::{
5    AuditRecordAt,
6    dump_api_usage,
7    dump_pdl,
8};
9use crate::message::{message_contents_to_json_array, message_to_json};
10use crate::model::{Model, ModelRaw};
11use crate::response::Response;
12use ragit_fs::{
13    WriteMode,
14    create_dir_all,
15    exists,
16    join,
17    write_log,
18    write_string,
19};
20use ragit_pdl::{Message, Role, Schema};
21use serde::de::DeserializeOwned;
22use serde_json::{Map, Value};
23use std::time::{Duration, Instant};
24
25#[derive(Clone, Debug)]
26pub struct Request {
27    pub messages: Vec<Message>,
28    pub model: Model,
29    pub temperature: Option<f64>,
30    pub frequency_penalty: Option<f64>,
31    pub max_tokens: Option<usize>,
32
33    /// milliseconds
34    pub timeout: Option<u64>,
35
36    /// It tries 1 + max_retry times.
37    pub max_retry: usize,
38
39    /// milliseconds
40    pub sleep_between_retries: u64,
41    pub dump_api_usage_at: Option<AuditRecordAt>,
42
43    /// It dumps the AI conversation in pdl format. See <https://crates.io/crates/ragit-pdl> to read about pdl.
44    pub dump_pdl_at: Option<String>,
45
46    /// It's a directory, not a file. If given, it dumps `dir/request-<timestamp>.json` and `dir/response-<timestamp>.json`.
47    pub dump_json_at: Option<String>,
48
49    /// It can force LLMs to create a json output with a given schema.
50    /// You have to call `send_and_validate` instead of `send` if you want
51    /// to force the schema.
52    pub schema: Option<Schema>,
53
54    /// If LLMs fail to generate a valid schema `schema_max_try` times,
55    /// it returns a default value. If it's 0, it wouldn't call LLM at all!
56    pub schema_max_try: usize,
57}
58
59impl Request {
60    pub fn is_valid(&self) -> bool {
61        self.messages.len() > 1
62        && self.messages.len() & 1 == 0  // the last message must be user's
63        && self.messages[0].is_valid_system_prompt()  // I'm not sure whether all the models require the first message to be a system prompt. but it would be safer to guarantee that
64        && {
65            let mut flag = true;
66
67            for (index, message) in self.messages[1..].iter().enumerate() {
68                if index & 1 == 0 && !message.is_user_prompt() {
69                    flag = false;
70                    break;
71                }
72
73                else if index & 1 == 1 && !message.is_assistant_prompt() {
74                    flag = false;
75                    break;
76                }
77            }
78
79            flag
80        }
81    }
82
83    /// It panics if its fields are not complete. If you're not sure, run `self.is_valid()` before sending a request.
84    pub fn build_json_body(&self) -> Value {
85        match &self.model.api_provider {
86            ApiProvider::Google => {
87                let mut result = Map::new();
88                let mut contents = vec![];
89                let mut system_prompt = vec![];
90
91                for message in self.messages.iter() {
92                    if message.role == Role::System {
93                        match message_contents_to_json_array(&message.content, &ApiProvider::Google) {
94                            Value::Array(parts) => {
95                                system_prompt.push(parts);
96                            },
97                            _ => unreachable!(),
98                        }
99                    }
100
101                    else {
102                        contents.push(message_to_json(message, &self.model.api_provider));
103                    }
104                }
105
106                if !system_prompt.is_empty() {
107                    let parts = system_prompt.concat();
108                    let mut system_prompt = Map::new();
109                    system_prompt.insert(String::from("parts"), parts.into());
110                    result.insert(String::from("system_instruction"), system_prompt.into());
111                }
112
113                // TODO: temperature
114
115                result.insert(String::from("contents"), contents.into());
116                result.into()
117            },
118            ApiProvider::OpenAi { .. } | ApiProvider::Cohere => {
119                let mut result = Map::new();
120                result.insert(String::from("model"), self.model.api_name.clone().into());
121                let mut messages = vec![];
122
123                for message in self.messages.iter() {
124                    messages.push(message_to_json(message, &self.model.api_provider));
125                }
126
127                result.insert(String::from("messages"), messages.into());
128
129                if let Some(temperature) = self.temperature {
130                    result.insert(String::from("temperature"), temperature.into());
131                }
132
133                if let Some(frequency_penalty) = self.frequency_penalty {
134                    result.insert(String::from("frequency_penalty"), frequency_penalty.into());
135                }
136
137                if let Some(max_tokens) = self.max_tokens {
138                    result.insert(String::from("max_tokens"), max_tokens.into());
139                }
140
141                // NOTE: It's a temporary fix. Read `tests/unnecessary_reasoning.py` for more details.
142                if self.model.api_name.contains("gpt-5") {
143                    result.insert(String::from("reasoning_effort"), "low".into());
144                }
145
146                result.into()
147            },
148            ApiProvider::Anthropic => {
149                let mut result = Map::new();
150                result.insert(String::from("model"), self.model.api_name.clone().into());
151                let mut messages = vec![];
152                let mut system_prompt = vec![];
153
154                for message in self.messages.iter() {
155                    if message.role == Role::System {
156                        system_prompt.push(message.content[0].unwrap_str().to_string());
157                    }
158
159                    else {
160                        messages.push(message_to_json(message, &ApiProvider::Anthropic));
161                    }
162                }
163
164                let system_prompt = system_prompt.concat();
165
166                if !system_prompt.is_empty() {
167                    result.insert(String::from("system"), system_prompt.into());
168                }
169
170                result.insert(String::from("messages"), messages.into());
171
172                if let Some(temperature) = self.temperature {
173                    result.insert(String::from("temperature"), temperature.into());
174                }
175
176                if let Some(frequency_penalty) = self.frequency_penalty {
177                    result.insert(String::from("frequency_penalty"), frequency_penalty.into());
178                }
179
180                // it's a required field
181                result.insert(String::from("max_tokens"), self.max_tokens.unwrap_or(16384).into());
182
183                // TODO: make it configurable
184                // let mut thinking = Map::new();
185                // thinking.insert(String::from("type"), "disabled".into());
186                // result.insert(String::from("thinking"), thinking.into());
187
188                result.into()
189            },
190            ApiProvider::Test(_) => Value::Null,
191        }
192    }
193
194    /// It panics if `schema` field is missing.
195    /// It doesn't tell you whether the default value is used or not.
196    pub async fn send_and_validate<T: DeserializeOwned>(&self, default: T) -> Result<T, Error> {
197        let mut state = self.clone();
198        let mut messages = self.messages.clone();
199
200        for _ in 0..state.schema_max_try {
201            state.messages = messages.clone();
202            let response = state.send().await?;
203            let response = response.get_message(0).unwrap();
204
205            match state.schema.as_ref().unwrap().validate(&response) {
206                Ok(v) => {
207                    return Ok(serde_json::from_value::<T>(v)?);
208                },
209                Err(error_message) => {
210                    messages.push(Message::simple_message(Role::Assistant, response.to_string()));
211                    messages.push(Message::simple_message(Role::User, error_message));
212                },
213            }
214        }
215
216        Ok(default)
217    }
218
219    /// NOTE: this function dies ocassionally, for no reason.
220    ///
221    /// It panics if its fields are not complete. If you're not sure, run `self.is_valid()` before sending a request.
222    pub fn blocking_send(&self) -> Result<Response, Error> {
223        futures::executor::block_on(self.send())
224    }
225
226    /// It panics if its fields are not complete. If you're not sure, run `self.is_valid()` before sending a request.
227    pub async fn send(&self) -> Result<Response, Error> {
228        let started_at = Instant::now();
229        let client = reqwest::Client::new();
230        let mut curr_error = Error::NoTry;
231
232        let post_url = self.model.get_api_url()?;
233        let body = self.build_json_body();
234
235        if let Err(e) = self.dump_json(&body, "request") {
236            write_log(
237                "dump_json",
238                &format!("dump_json(\"request\", ..) failed with {e:?}"),
239            );
240        }
241
242        if let ApiProvider::Test(test_model) = &self.model.api_provider {
243            let response = test_model.get_dummy_response(&self.messages)?;
244
245            if let Some(key) = &self.dump_api_usage_at {
246                if let Err(e) = dump_api_usage(
247                    key,
248                    0,
249                    0,
250                    self.model.dollars_per_1b_input_tokens,
251                    self.model.dollars_per_1b_output_tokens,
252                    false,
253                ) {
254                    write_log(
255                        "dump_api_usage",
256                        &format!("dump_api_usage({key:?}, ..) failed with {e:?}"),
257                    );
258                }
259            }
260
261            if let Some(path) = &self.dump_pdl_at {
262                if let Err(e) = dump_pdl(
263                    &self.messages,
264                    &response,
265                    &None,
266                    path,
267                    String::from("model: dummy, input_tokens: 0, output_tokens: 0, took: 0ms"),
268                ) {
269                    write_log(
270                        "dump_pdl",
271                        &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
272                    );
273
274                    // TODO: should it return an error?
275                    //       the api call was successful
276                }
277            }
278
279            return Ok(Response::dummy(response));
280        }
281
282        let body = serde_json::to_string(&body)?;
283        let api_key = self.model.get_api_key()?;
284        write_log(
285            "chat_request::send",
286            &format!("entered chat_request::send() with {} bytes, model: {}", body.len(), self.model.name),
287        );
288
289        for _ in 0..(self.max_retry + 1) {
290            let mut request = client.post(&post_url)
291                .header(reqwest::header::CONTENT_TYPE, "application/json")
292                .body(body.clone());
293
294            match &self.model.api_provider {
295                ApiProvider::Anthropic => {
296                    request = request.header("x-api-key", api_key.clone())
297                        .header("anthropic-version", "2023-06-01");
298                },
299                ApiProvider::Google => {},
300                _ if !api_key.is_empty() => {
301                    request = request.bearer_auth(api_key.clone());
302                },
303                _ => {},
304            }
305
306            if let Some(t) = self.timeout {
307                request = request.timeout(Duration::from_millis(t));
308            }
309
310            write_log(
311                "chat_request::send",
312                "a request sent",
313            );
314            let response = request.send().await;
315            write_log(
316                "chat_request::send",
317                "got a response from a request",
318            );
319
320            match response {
321                Ok(response) => match response.status().as_u16() {
322                    200 => match response.text().await {
323                        Ok(text) => {
324                            match serde_json::from_str::<Value>(&text) {
325                                Ok(v) => match self.dump_json(&v, "response") {
326                                    Err(e) => {
327                                        write_log(
328                                            "dump_json",
329                                            &format!("dump_json(\"response\", ..) failed with {e:?}"),
330                                        );
331                                    },
332                                    Ok(_) => {},
333                                },
334                                Err(e) => {
335                                    write_log(
336                                        "dump_json",
337                                        &format!("dump_json(\"response\", ..) failed with {e:?}"),
338                                    );
339                                },
340                            }
341
342                            match Response::from_str(&text, &self.model.api_provider) {
343                                Ok(result) => {
344                                    if let Some(key) = &self.dump_api_usage_at {
345                                        if let Err(e) = dump_api_usage(
346                                            key,
347                                            result.get_prompt_token_count() as u64,
348                                            result.get_output_token_count() as u64,
349                                            self.model.dollars_per_1b_input_tokens,
350                                            self.model.dollars_per_1b_output_tokens,
351                                            false,
352                                        ) {
353                                            write_log(
354                                                "dump_api_usage",
355                                                &format!("dump_api_usage({key:?}, ..) failed with {e:?}"),
356                                            );
357                                        }
358                                    }
359
360                                    if let Some(path) = &self.dump_pdl_at {
361                                        if let Err(e) = dump_pdl(
362                                            &self.messages,
363                                            &result.get_message(0).map(|m| m.to_string()).unwrap_or(String::new()),
364                                            &result.get_reasoning(0).map(|m| m.to_string()),
365                                            path,
366                                            format!(
367                                                "model: {}, input_tokens: {}, output_tokens: {}, took: {}ms",
368                                                self.model.name,
369                                                result.get_prompt_token_count(),
370                                                result.get_output_token_count(),
371                                                Instant::now().duration_since(started_at.clone()).as_millis(),
372                                            ),
373                                        ) {
374                                            write_log(
375                                                "dump_pdl",
376                                                &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
377                                            );
378
379                                            // TODO: should it return an error?
380                                            //       the api call was successful
381                                        }
382                                    }
383
384                                    return Ok(result);
385                                },
386                                Err(e) => {
387                                    write_log(
388                                        "Response::from_str",
389                                        &format!("Response::from_str(..) failed with {e:?}"),
390                                    );
391                                    curr_error = e;
392                                },
393                            }
394                        },
395                        Err(e) => {
396                            write_log(
397                                "response.text()",
398                                &format!("response.text() failed with {e:?}"),
399                            );
400                            curr_error = Error::ReqwestError(e);
401                        },
402                    },
403                    status_code => {
404                        curr_error = Error::ServerError {
405                            status_code,
406                            body: response.text().await,
407                        };
408
409                        if let Some(path) = &self.dump_pdl_at {
410                            if let Err(e) = dump_pdl(
411                                &self.messages,
412                                "",
413                                &None,
414                                path,
415                                format!("{}# error: {curr_error:?} #{}", '{', '}'),
416                            ) {
417                                write_log(
418                                    "dump_pdl",
419                                    &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
420                                );
421                            }
422                        }
423
424                        // There are 2 cases.
425                        // 1. `self.model.can_read_images` is false, but it can actually read images.
426                        //   - Maybe `self.model` is outdated.
427                        //   - That's why it tries once even though there is an image.
428                        // 2. `self.model.can_read_images` is false, and it cannot read images.
429                        //   - There's no point in retrying, so it just escapes immediately with a better error.
430                        if !self.model.can_read_images && self.messages.iter().any(|message| message.has_image()) {
431                            return Err(Error::CannotReadImage(self.model.name.clone()));
432                        }
433
434                        // Assumption: if the input is invalid, there's no point in retrying over and over.
435                        if status_code == 400 {
436                            return Err(curr_error);
437                        }
438                    },
439                },
440                Err(e) => {
441                    write_log(
442                        "request.send().await",
443                        &format!("request.send().await failed with {e:?}"),
444                    );
445                    curr_error = Error::ReqwestError(e);
446                },
447            }
448
449            task::sleep(Duration::from_millis(self.sleep_between_retries)).await
450        }
451
452        Err(curr_error)
453    }
454
455    fn dump_json(&self, j: &Value, header: &str) -> Result<(), Error> {
456        if let Some(dir) = &self.dump_json_at {
457            if !exists(dir) {
458                create_dir_all(dir)?;
459            }
460
461            let path = join(
462                &dir,
463                &format!("{header}-{}.json", Local::now().to_rfc3339()),
464            )?;
465            write_string(&path, &serde_json::to_string_pretty(j)?, WriteMode::AlwaysCreate)?;
466        }
467
468        Ok(())
469    }
470}
471
472impl Default for Request {
473    fn default() -> Self {
474        Request {
475            messages: vec![],
476            model: (&ModelRaw::llama_70b()).try_into().unwrap(),
477            temperature: None,
478            frequency_penalty: None,
479            max_tokens: None,
480            timeout: Some(20_000),
481            max_retry: 2,
482            sleep_between_retries: 6_000,
483            dump_api_usage_at: None,
484            dump_pdl_at: None,
485            dump_json_at: None,
486            schema: None,
487            schema_max_try: 3,
488        }
489    }
490}