ragit_api/
request.rs

1use async_std::task;
2use chrono::Local;
3use crate::{ApiProvider, Error};
4use crate::message::message_to_json;
5use crate::model::{Model, ModelRaw};
6use crate::record::{
7    RecordAt,
8    dump_pdl,
9    record_api_usage,
10};
11use crate::response::Response;
12use ragit_fs::{
13    WriteMode,
14    create_dir_all,
15    exists,
16    join,
17    write_log,
18    write_string,
19};
20use ragit_pdl::{Message, Role, Schema};
21use serde::de::DeserializeOwned;
22use serde_json::{Map, Value};
23use std::time::{Duration, Instant};
24
25#[derive(Clone, Debug)]
26pub struct Request {
27    pub messages: Vec<Message>,
28    pub model: Model,
29    pub temperature: Option<f64>,
30    pub frequency_penalty: Option<f64>,
31    pub max_tokens: Option<usize>,
32
33    /// milliseconds
34    pub timeout: Option<u64>,
35
36    /// It tries 1 + max_retry times.
37    pub max_retry: usize,
38
39    /// milliseconds
40    pub sleep_between_retries: u64,
41    pub record_api_usage_at: Option<RecordAt>,
42
43    /// It dumps the AI conversation in pdl format. See <https://crates.io/crates/ragit-pdl> to read about pdl.
44    pub dump_pdl_at: Option<String>,
45
46    /// It's a directory, not a file. If given, it dumps `dir/request-<timestamp>.json` and `dir/response-<timestamp>.json`.
47    pub dump_json_at: Option<String>,
48
49    /// It can force LLMs to create a json output with a given schema.
50    /// You have to call `send_and_validate` instead of `send` if you want
51    /// to force the schema.
52    pub schema: Option<Schema>,
53
54    /// If LLMs fail to generate a valid schema `schema_max_try` times,
55    /// it returns a default value. If it's 0, it wouldn't call LLM at all!
56    pub schema_max_try: usize,
57}
58
59impl Request {
60    pub fn is_valid(&self) -> bool {
61        self.messages.len() > 1
62        && self.messages.len() & 1 == 0  // the last message must be user's
63        && self.messages[0].is_valid_system_prompt()  // I'm not sure whether all the models require the first message to be a system prompt. but it would be safer to guarantee that
64        && {
65            let mut flag = true;
66
67            for (index, message) in self.messages[1..].iter().enumerate() {
68                if index & 1 == 0 && !message.is_user_prompt() {
69                    flag = false;
70                    break;
71                }
72
73                else if index & 1 == 1 && !message.is_assistant_prompt() {
74                    flag = false;
75                    break;
76                }
77            }
78
79            flag
80        }
81    }
82
83    /// It panics if its fields are not complete. If you're not sure, run `self.is_valid()` before sending a request.
84    pub fn build_json_body(&self) -> Value {
85        match &self.model.api_provider {
86            ApiProvider::OpenAi { .. } | ApiProvider::Cohere => {
87                let mut result = Map::new();
88                result.insert(String::from("model"), self.model.api_name.clone().into());
89                let mut messages = vec![];
90
91                for message in self.messages.iter() {
92                    messages.push(message_to_json(message, &self.model.api_provider));
93                }
94
95                result.insert(String::from("messages"), messages.into());
96
97                if let Some(temperature) = self.temperature {
98                    result.insert(String::from("temperature"), temperature.into());
99                }
100
101                if let Some(frequency_penalty) = self.frequency_penalty {
102                    result.insert(String::from("frequency_penalty"), frequency_penalty.into());
103                }
104
105                if let Some(max_tokens) = self.max_tokens {
106                    result.insert(String::from("max_tokens"), max_tokens.into());
107                }
108
109                result.into()
110            },
111            ApiProvider::Anthropic => {
112                let mut result = Map::new();
113                result.insert(String::from("model"), self.model.api_name.clone().into());
114                let mut messages = vec![];
115                let mut system_prompt = vec![];
116
117                for message in self.messages.iter() {
118                    if message.role == Role::System {
119                        system_prompt.push(message.content[0].unwrap_str().to_string());
120                    }
121
122                    else {
123                        messages.push(message_to_json(message, &ApiProvider::Anthropic));
124                    }
125                }
126
127                let system_prompt = system_prompt.concat();
128
129                if !system_prompt.is_empty() {
130                    result.insert(String::from("system"), system_prompt.into());
131                }
132
133                result.insert(String::from("messages"), messages.into());
134
135                if let Some(temperature) = self.temperature {
136                    result.insert(String::from("temperature"), temperature.into()).unwrap();
137                }
138
139                if let Some(frequency_penalty) = self.frequency_penalty {
140                    result.insert(String::from("frequency_penalty"), frequency_penalty.into());
141                }
142
143                // it's a required field
144                result.insert(String::from("max_tokens"), self.max_tokens.unwrap_or(2048).into());
145
146                result.into()
147            },
148            ApiProvider::Test(_) => Value::Null,
149        }
150    }
151
152    /// It panics if `schema` field is missing.
153    /// It doesn't tell you whether the default value is used or not.
154    pub async fn send_and_validate<T: DeserializeOwned>(&self, default: T) -> Result<T, Error> {
155        let mut state = self.clone();
156        let mut messages = self.messages.clone();
157
158        for _ in 0..state.schema_max_try {
159            state.messages = messages.clone();
160            let response = state.send().await?;
161            let response = response.get_message(0).unwrap();
162
163            match state.schema.as_ref().unwrap().validate(&response) {
164                Ok(v) => {
165                    return Ok(serde_json::from_value::<T>(v)?);
166                },
167                Err(error_message) => {
168                    messages.push(Message::simple_message(Role::Assistant, response.to_string()));
169                    messages.push(Message::simple_message(Role::User, error_message));
170                },
171            }
172        }
173
174        Ok(default)
175    }
176
177    /// NOTE: this function dies ocassionally, for no reason.
178    ///
179    /// It panics if its fields are not complete. If you're not sure, run `self.is_valid()` before sending a request.
180    pub fn blocking_send(&self) -> Result<Response, Error> {
181        futures::executor::block_on(self.send())
182    }
183
184    /// It panics if its fields are not complete. If you're not sure, run `self.is_valid()` before sending a request.
185    pub async fn send(&self) -> Result<Response, Error> {
186        let started_at = Instant::now();
187        let client = reqwest::Client::new();
188        let mut curr_error = Error::NoTry;
189
190        let post_url = self.model.get_api_url();
191        let body = self.build_json_body();
192
193        if let Err(e) = self.dump_json(&body, "request") {
194            write_log(
195                "dump_json",
196                &format!("dump_json(\"request\", ..) failed with {e:?}"),
197            );
198        }
199
200        if let ApiProvider::Test(test_model) = &self.model.api_provider {
201            let response = test_model.get_dummy_response(&self.messages);
202
203            if let Some(key) = &self.record_api_usage_at {
204                if let Err(e) = record_api_usage(
205                    key,
206                    0,
207                    0,
208                    self.model.dollars_per_1b_input_tokens,
209                    self.model.dollars_per_1b_output_tokens,
210                    false,
211                ) {
212                    write_log(
213                        "record_api_usage",
214                        &format!("record_api_usage({key:?}, ..) failed with {e:?}"),
215                    );
216                }
217            }
218
219            if let Some(path) = &self.dump_pdl_at {
220                if let Err(e) = dump_pdl(
221                    &self.messages,
222                    &response,
223                    &None,
224                    path,
225                    String::from("model: dummy, input_tokens: 0, output_tokens: 0, took: 0ms"),
226                ) {
227                    write_log(
228                        "dump_pdl",
229                        &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
230                    );
231
232                    // TODO: should it return an error?
233                    //       the api call was successful
234                }
235            }
236
237            return Ok(Response::dummy(response));
238        }
239
240        let body = serde_json::to_string(&body)?;
241        let api_key = self.model.get_api_key()?;
242        write_log(
243            "chat_request::send",
244            &format!("entered chat_request::send() with {} bytes, model: {}", body.len(), self.model.name),
245        );
246
247        for _ in 0..(self.max_retry + 1) {
248            let mut request = client.post(post_url)
249                .header(reqwest::header::CONTENT_TYPE, "application/json")
250                .body(body.clone());
251
252            if let ApiProvider::Anthropic = &self.model.api_provider {
253                request = request.header("x-api-key", api_key.clone())
254                    .header("anthropic-version", "2023-06-01");
255            }
256
257            else if !api_key.is_empty() {
258                request = request.bearer_auth(api_key.clone());
259            }
260
261            if let Some(t) = self.timeout {
262                request = request.timeout(Duration::from_millis(t));
263            }
264
265            write_log(
266                "chat_request::send",
267                "a request sent",
268            );
269            let response = request.send().await;
270            write_log(
271                "chat_request::send",
272                "got a response from a request",
273            );
274
275            match response {
276                Ok(response) => match response.status().as_u16() {
277                    200 => match response.text().await {
278                        Ok(text) => {
279                            match serde_json::from_str::<Value>(&text) {
280                                Ok(v) => match self.dump_json(&v, "response") {
281                                    Err(e) => {
282                                        write_log(
283                                            "dump_json",
284                                            &format!("dump_json(\"response\", ..) failed with {e:?}"),
285                                        );
286                                    },
287                                    Ok(_) => {},
288                                },
289                                Err(e) => {
290                                    write_log(
291                                        "dump_json",
292                                        &format!("dump_json(\"response\", ..) failed with {e:?}"),
293                                    );
294                                },
295                            }
296
297                            match Response::from_str(&text, &self.model.api_provider) {
298                                Ok(result) => {
299                                    if let Some(key) = &self.record_api_usage_at {
300                                        if let Err(e) = record_api_usage(
301                                            key,
302                                            result.get_prompt_token_count() as u64,
303                                            result.get_output_token_count() as u64,
304                                            self.model.dollars_per_1b_input_tokens,
305                                            self.model.dollars_per_1b_output_tokens,
306                                            false,
307                                        ) {
308                                            write_log(
309                                                "record_api_usage",
310                                                &format!("record_api_usage({key:?}, ..) failed with {e:?}"),
311                                            );
312                                        }
313                                    }
314
315                                    if let Some(path) = &self.dump_pdl_at {
316                                        if let Err(e) = dump_pdl(
317                                            &self.messages,
318                                            &result.get_message(0).map(|m| m.to_string()).unwrap_or(String::new()),
319                                            &result.get_reasoning(0).map(|m| m.to_string()),
320                                            path,
321                                            format!(
322                                                "model: {}, input_tokens: {}, output_tokens: {}, took: {}ms",
323                                                self.model.name,
324                                                result.get_prompt_token_count(),
325                                                result.get_output_token_count(),
326                                                Instant::now().duration_since(started_at.clone()).as_millis(),
327                                            ),
328                                        ) {
329                                            write_log(
330                                                "dump_pdl",
331                                                &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
332                                            );
333
334                                            // TODO: should it return an error?
335                                            //       the api call was successful
336                                        }
337                                    }
338
339                                    return Ok(result);
340                                },
341                                Err(e) => {
342                                    write_log(
343                                        "Response::from_str",
344                                        &format!("Response::from_str(..) failed with {e:?}"),
345                                    );
346                                    curr_error = e;
347                                },
348                            }
349                        },
350                        Err(e) => {
351                            write_log(
352                                "response.text()",
353                                &format!("response.text() failed with {e:?}"),
354                            );
355                            curr_error = Error::ReqwestError(e);
356                        },
357                    },
358                    status_code => {
359                        curr_error = Error::ServerError {
360                            status_code,
361                            body: response.text().await,
362                        };
363
364                        if let Some(path) = &self.dump_pdl_at {
365                            if let Err(e) = dump_pdl(
366                                &self.messages,
367                                "",
368                                &None,
369                                path,
370                                format!("{}# error: {curr_error:?} #{}", '{', '}'),
371                            ) {
372                                write_log(
373                                    "dump_pdl",
374                                    &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
375                                );
376                            }
377                        }
378                    },
379                },
380                Err(e) => {
381                    write_log(
382                        "request.send().await",
383                        &format!("request.send().await failed with {e:?}"),
384                    );
385                    curr_error = Error::ReqwestError(e);
386                },
387            }
388
389            task::sleep(Duration::from_millis(self.sleep_between_retries)).await
390        }
391
392        Err(curr_error)
393    }
394
395    fn dump_json(&self, j: &Value, header: &str) -> Result<(), Error> {
396        if let Some(dir) = &self.dump_json_at {
397            if !exists(dir) {
398                create_dir_all(dir)?;
399            }
400
401            let path = join(
402                &dir,
403                &format!("{header}-{}.json", Local::now().to_rfc3339()),
404            )?;
405            write_string(&path, &serde_json::to_string_pretty(j)?, WriteMode::AlwaysCreate)?;
406        }
407
408        Ok(())
409    }
410}
411
412impl Default for Request {
413    fn default() -> Self {
414        Request {
415            messages: vec![],
416            model: (&ModelRaw::llama_70b()).try_into().unwrap(),
417            temperature: None,
418            frequency_penalty: None,
419            max_tokens: None,
420            timeout: Some(20_000),
421            max_retry: 2,
422            sleep_between_retries: 6_000,
423            record_api_usage_at: None,
424            dump_pdl_at: None,
425            dump_json_at: None,
426            schema: None,
427            schema_max_try: 3,
428        }
429    }
430}