ragit_api/
request.rs

1use async_std::task;
2use chrono::Local;
3use crate::{ApiProvider, Error};
4use crate::audit::{
5    AuditRecordAt,
6    dump_api_usage,
7    dump_pdl,
8};
9use crate::message::{message_contents_to_json_array, message_to_json};
10use crate::model::{Model, ModelRaw};
11use crate::response::Response;
12use ragit_fs::{
13    WriteMode,
14    create_dir_all,
15    exists,
16    join,
17    write_log,
18    write_string,
19};
20use ragit_pdl::{Message, Role, Schema};
21use serde::de::DeserializeOwned;
22use serde_json::{Map, Value};
23use std::time::{Duration, Instant};
24
/// A single chat request: the conversation, the target model, and the knobs
/// that control retries, timeouts, audit dumps and schema validation.
#[derive(Clone, Debug)]
pub struct Request {
    /// The conversation to send. `is_valid` expects a system prompt first,
    /// followed by alternating user / assistant turns that end with a user turn.
    pub messages: Vec<Message>,

    /// The model to call. Its `api_provider` selects the body format that
    /// `build_json_body` produces.
    pub model: Model,

    /// Sent only when set. The Google request body currently never includes it.
    pub temperature: Option<f64>,

    /// Sent only when set. The Google request body currently never includes it.
    pub frequency_penalty: Option<f64>,

    /// When unset, Anthropic requests fall back to 2048 (the field is required
    /// there) and OpenAI/Cohere requests omit it. Not sent to Google.
    pub max_tokens: Option<usize>,

    /// milliseconds
    ///
    /// Applied to every HTTP attempt individually. `None` sets no per-request timeout.
    pub timeout: Option<u64>,

    /// It tries 1 + max_retry times.
    pub max_retry: usize,

    /// milliseconds
    pub sleep_between_retries: u64,

    /// If given, token usage is recorded there via `dump_api_usage`.
    /// A failure to record is only logged; it never fails the request.
    pub dump_api_usage_at: Option<AuditRecordAt>,

    /// It dumps the AI conversation in pdl format. See <https://crates.io/crates/ragit-pdl> to read about pdl.
    pub dump_pdl_at: Option<String>,

    /// It's a directory, not a file. If given, it dumps `dir/request-<timestamp>.json` and `dir/response-<timestamp>.json`.
    pub dump_json_at: Option<String>,

    /// It can force LLMs to create a json output with a given schema.
    /// You have to call `send_and_validate` instead of `send` if you want
    /// to force the schema.
    pub schema: Option<Schema>,

    /// If LLMs fail to generate a valid schema `schema_max_try` times,
    /// it returns a default value. If it's 0, it wouldn't call LLM at all!
    pub schema_max_try: usize,
}
58
59impl Request {
60    pub fn is_valid(&self) -> bool {
61        self.messages.len() > 1
62        && self.messages.len() & 1 == 0  // the last message must be user's
63        && self.messages[0].is_valid_system_prompt()  // I'm not sure whether all the models require the first message to be a system prompt. but it would be safer to guarantee that
64        && {
65            let mut flag = true;
66
67            for (index, message) in self.messages[1..].iter().enumerate() {
68                if index & 1 == 0 && !message.is_user_prompt() {
69                    flag = false;
70                    break;
71                }
72
73                else if index & 1 == 1 && !message.is_assistant_prompt() {
74                    flag = false;
75                    break;
76                }
77            }
78
79            flag
80        }
81    }
82
83    /// It panics if its fields are not complete. If you're not sure, run `self.is_valid()` before sending a request.
84    pub fn build_json_body(&self) -> Value {
85        match &self.model.api_provider {
86            ApiProvider::Google => {
87                let mut result = Map::new();
88                let mut contents = vec![];
89                let mut system_prompt = vec![];
90
91                for message in self.messages.iter() {
92                    if message.role == Role::System {
93                        match message_contents_to_json_array(&message.content, &ApiProvider::Google) {
94                            Value::Array(parts) => {
95                                system_prompt.push(parts);
96                            },
97                            _ => unreachable!(),
98                        }
99                    }
100
101                    else {
102                        contents.push(message_to_json(message, &self.model.api_provider));
103                    }
104                }
105
106                if !system_prompt.is_empty() {
107                    let parts = system_prompt.concat();
108                    let mut system_prompt = Map::new();
109                    system_prompt.insert(String::from("parts"), parts.into());
110                    result.insert(String::from("system_instruction"), system_prompt.into());
111                }
112
113                // TODO: temperature
114
115                result.insert(String::from("contents"), contents.into());
116                result.into()
117            },
118            ApiProvider::OpenAi { .. } | ApiProvider::Cohere => {
119                let mut result = Map::new();
120                result.insert(String::from("model"), self.model.api_name.clone().into());
121                let mut messages = vec![];
122
123                for message in self.messages.iter() {
124                    messages.push(message_to_json(message, &self.model.api_provider));
125                }
126
127                result.insert(String::from("messages"), messages.into());
128
129                if let Some(temperature) = self.temperature {
130                    result.insert(String::from("temperature"), temperature.into());
131                }
132
133                if let Some(frequency_penalty) = self.frequency_penalty {
134                    result.insert(String::from("frequency_penalty"), frequency_penalty.into());
135                }
136
137                if let Some(max_tokens) = self.max_tokens {
138                    result.insert(String::from("max_tokens"), max_tokens.into());
139                }
140
141                result.into()
142            },
143            ApiProvider::Anthropic => {
144                let mut result = Map::new();
145                result.insert(String::from("model"), self.model.api_name.clone().into());
146                let mut messages = vec![];
147                let mut system_prompt = vec![];
148
149                for message in self.messages.iter() {
150                    if message.role == Role::System {
151                        system_prompt.push(message.content[0].unwrap_str().to_string());
152                    }
153
154                    else {
155                        messages.push(message_to_json(message, &ApiProvider::Anthropic));
156                    }
157                }
158
159                let system_prompt = system_prompt.concat();
160
161                if !system_prompt.is_empty() {
162                    result.insert(String::from("system"), system_prompt.into());
163                }
164
165                result.insert(String::from("messages"), messages.into());
166
167                if let Some(temperature) = self.temperature {
168                    result.insert(String::from("temperature"), temperature.into());
169                }
170
171                if let Some(frequency_penalty) = self.frequency_penalty {
172                    result.insert(String::from("frequency_penalty"), frequency_penalty.into());
173                }
174
175                // it's a required field
176                result.insert(String::from("max_tokens"), self.max_tokens.unwrap_or(2048).into());
177
178                result.into()
179            },
180            ApiProvider::Test(_) => Value::Null,
181        }
182    }
183
184    /// It panics if `schema` field is missing.
185    /// It doesn't tell you whether the default value is used or not.
186    pub async fn send_and_validate<T: DeserializeOwned>(&self, default: T) -> Result<T, Error> {
187        let mut state = self.clone();
188        let mut messages = self.messages.clone();
189
190        for _ in 0..state.schema_max_try {
191            state.messages = messages.clone();
192            let response = state.send().await?;
193            let response = response.get_message(0).unwrap();
194
195            match state.schema.as_ref().unwrap().validate(&response) {
196                Ok(v) => {
197                    return Ok(serde_json::from_value::<T>(v)?);
198                },
199                Err(error_message) => {
200                    messages.push(Message::simple_message(Role::Assistant, response.to_string()));
201                    messages.push(Message::simple_message(Role::User, error_message));
202                },
203            }
204        }
205
206        Ok(default)
207    }
208
209    /// NOTE: this function dies ocassionally, for no reason.
210    ///
211    /// It panics if its fields are not complete. If you're not sure, run `self.is_valid()` before sending a request.
212    pub fn blocking_send(&self) -> Result<Response, Error> {
213        futures::executor::block_on(self.send())
214    }
215
216    /// It panics if its fields are not complete. If you're not sure, run `self.is_valid()` before sending a request.
217    pub async fn send(&self) -> Result<Response, Error> {
218        let started_at = Instant::now();
219        let client = reqwest::Client::new();
220        let mut curr_error = Error::NoTry;
221
222        let post_url = self.model.get_api_url()?;
223        let body = self.build_json_body();
224
225        if let Err(e) = self.dump_json(&body, "request") {
226            write_log(
227                "dump_json",
228                &format!("dump_json(\"request\", ..) failed with {e:?}"),
229            );
230        }
231
232        if let ApiProvider::Test(test_model) = &self.model.api_provider {
233            let response = test_model.get_dummy_response(&self.messages)?;
234
235            if let Some(key) = &self.dump_api_usage_at {
236                if let Err(e) = dump_api_usage(
237                    key,
238                    0,
239                    0,
240                    self.model.dollars_per_1b_input_tokens,
241                    self.model.dollars_per_1b_output_tokens,
242                    false,
243                ) {
244                    write_log(
245                        "dump_api_usage",
246                        &format!("dump_api_usage({key:?}, ..) failed with {e:?}"),
247                    );
248                }
249            }
250
251            if let Some(path) = &self.dump_pdl_at {
252                if let Err(e) = dump_pdl(
253                    &self.messages,
254                    &response,
255                    &None,
256                    path,
257                    String::from("model: dummy, input_tokens: 0, output_tokens: 0, took: 0ms"),
258                ) {
259                    write_log(
260                        "dump_pdl",
261                        &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
262                    );
263
264                    // TODO: should it return an error?
265                    //       the api call was successful
266                }
267            }
268
269            return Ok(Response::dummy(response));
270        }
271
272        let body = serde_json::to_string(&body)?;
273        let api_key = self.model.get_api_key()?;
274        write_log(
275            "chat_request::send",
276            &format!("entered chat_request::send() with {} bytes, model: {}", body.len(), self.model.name),
277        );
278
279        for _ in 0..(self.max_retry + 1) {
280            let mut request = client.post(&post_url)
281                .header(reqwest::header::CONTENT_TYPE, "application/json")
282                .body(body.clone());
283
284            match &self.model.api_provider {
285                ApiProvider::Anthropic => {
286                    request = request.header("x-api-key", api_key.clone())
287                        .header("anthropic-version", "2023-06-01");
288                },
289                ApiProvider::Google => {},
290                _ if !api_key.is_empty() => {
291                    request = request.bearer_auth(api_key.clone());
292                },
293                _ => {},
294            }
295
296            if let Some(t) = self.timeout {
297                request = request.timeout(Duration::from_millis(t));
298            }
299
300            write_log(
301                "chat_request::send",
302                "a request sent",
303            );
304            let response = request.send().await;
305            write_log(
306                "chat_request::send",
307                "got a response from a request",
308            );
309
310            match response {
311                Ok(response) => match response.status().as_u16() {
312                    200 => match response.text().await {
313                        Ok(text) => {
314                            match serde_json::from_str::<Value>(&text) {
315                                Ok(v) => match self.dump_json(&v, "response") {
316                                    Err(e) => {
317                                        write_log(
318                                            "dump_json",
319                                            &format!("dump_json(\"response\", ..) failed with {e:?}"),
320                                        );
321                                    },
322                                    Ok(_) => {},
323                                },
324                                Err(e) => {
325                                    write_log(
326                                        "dump_json",
327                                        &format!("dump_json(\"response\", ..) failed with {e:?}"),
328                                    );
329                                },
330                            }
331
332                            match Response::from_str(&text, &self.model.api_provider) {
333                                Ok(result) => {
334                                    if let Some(key) = &self.dump_api_usage_at {
335                                        if let Err(e) = dump_api_usage(
336                                            key,
337                                            result.get_prompt_token_count() as u64,
338                                            result.get_output_token_count() as u64,
339                                            self.model.dollars_per_1b_input_tokens,
340                                            self.model.dollars_per_1b_output_tokens,
341                                            false,
342                                        ) {
343                                            write_log(
344                                                "dump_api_usage",
345                                                &format!("dump_api_usage({key:?}, ..) failed with {e:?}"),
346                                            );
347                                        }
348                                    }
349
350                                    if let Some(path) = &self.dump_pdl_at {
351                                        if let Err(e) = dump_pdl(
352                                            &self.messages,
353                                            &result.get_message(0).map(|m| m.to_string()).unwrap_or(String::new()),
354                                            &result.get_reasoning(0).map(|m| m.to_string()),
355                                            path,
356                                            format!(
357                                                "model: {}, input_tokens: {}, output_tokens: {}, took: {}ms",
358                                                self.model.name,
359                                                result.get_prompt_token_count(),
360                                                result.get_output_token_count(),
361                                                Instant::now().duration_since(started_at.clone()).as_millis(),
362                                            ),
363                                        ) {
364                                            write_log(
365                                                "dump_pdl",
366                                                &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
367                                            );
368
369                                            // TODO: should it return an error?
370                                            //       the api call was successful
371                                        }
372                                    }
373
374                                    return Ok(result);
375                                },
376                                Err(e) => {
377                                    write_log(
378                                        "Response::from_str",
379                                        &format!("Response::from_str(..) failed with {e:?}"),
380                                    );
381                                    curr_error = e;
382                                },
383                            }
384                        },
385                        Err(e) => {
386                            write_log(
387                                "response.text()",
388                                &format!("response.text() failed with {e:?}"),
389                            );
390                            curr_error = Error::ReqwestError(e);
391                        },
392                    },
393                    status_code => {
394                        curr_error = Error::ServerError {
395                            status_code,
396                            body: response.text().await,
397                        };
398
399                        if let Some(path) = &self.dump_pdl_at {
400                            if let Err(e) = dump_pdl(
401                                &self.messages,
402                                "",
403                                &None,
404                                path,
405                                format!("{}# error: {curr_error:?} #{}", '{', '}'),
406                            ) {
407                                write_log(
408                                    "dump_pdl",
409                                    &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
410                                );
411                            }
412                        }
413
414                        // There are 2 cases.
415                        // 1. `self.model.can_read_images` is false, but it can actually read images.
416                        //   - Maybe `self.model` is outdated.
417                        //   - That's why it tries once even though there is an image.
418                        // 2. `self.model.can_read_images` is false, and it cannot read images.
419                        //   - There's no point in retrying, so it just escapes immediately with a better error.
420                        if !self.model.can_read_images && self.messages.iter().any(|message| message.has_image()) {
421                            return Err(Error::CannotReadImage(self.model.name.clone()));
422                        }
423                    },
424                },
425                Err(e) => {
426                    write_log(
427                        "request.send().await",
428                        &format!("request.send().await failed with {e:?}"),
429                    );
430                    curr_error = Error::ReqwestError(e);
431                },
432            }
433
434            task::sleep(Duration::from_millis(self.sleep_between_retries)).await
435        }
436
437        Err(curr_error)
438    }
439
440    fn dump_json(&self, j: &Value, header: &str) -> Result<(), Error> {
441        if let Some(dir) = &self.dump_json_at {
442            if !exists(dir) {
443                create_dir_all(dir)?;
444            }
445
446            let path = join(
447                &dir,
448                &format!("{header}-{}.json", Local::now().to_rfc3339()),
449            )?;
450            write_string(&path, &serde_json::to_string_pretty(j)?, WriteMode::AlwaysCreate)?;
451        }
452
453        Ok(())
454    }
455}
456
457impl Default for Request {
458    fn default() -> Self {
459        Request {
460            messages: vec![],
461            model: (&ModelRaw::llama_70b()).try_into().unwrap(),
462            temperature: None,
463            frequency_penalty: None,
464            max_tokens: None,
465            timeout: Some(20_000),
466            max_retry: 2,
467            sleep_between_retries: 6_000,
468            dump_api_usage_at: None,
469            dump_pdl_at: None,
470            dump_json_at: None,
471            schema: None,
472            schema_max_try: 3,
473        }
474    }
475}