ragit_api/
request.rs

1use async_std::task;
2use chrono::Local;
3use crate::{ApiProvider, Error};
4use crate::message::message_to_json;
5use crate::model::{Model, ModelRaw};
6use crate::record::{
7    RecordAt,
8    dump_pdl,
9    record_api_usage,
10};
11use crate::response::Response;
12use ragit_fs::{WriteMode, join, write_log, write_string};
13use ragit_pdl::{Message, Role, Schema};
14use serde::de::DeserializeOwned;
15use serde_json::{Map, Value};
16use std::time::{Duration, Instant};
17
18#[derive(Clone, Debug)]
19pub struct Request {
20    pub messages: Vec<Message>,
21    pub model: Model,
22    pub temperature: Option<f64>,
23    pub frequency_penalty: Option<f64>,
24    pub max_tokens: Option<usize>,
25
26    /// milliseconds
27    pub timeout: Option<u64>,
28
29    /// It tries 1 + max_retry times.
30    pub max_retry: usize,
31
32    /// milliseconds
33    pub sleep_between_retries: u64,
34    pub record_api_usage_at: Option<RecordAt>,
35
36    /// It dumps the AI conversation in pdl format. See <https://crates.io/crates/ragit-pdl> to read about pdl.
37    pub dump_pdl_at: Option<String>,
38
39    /// It's a directory, not a file. If given, it dumps `dir/request-<timestamp>.json` and `dir/response-<timestamp>.json`.
40    pub dump_json_at: Option<String>,
41
42    /// It can force LLMs to create a json output with a given schema.
43    /// You have to call `send_and_validate` instead of `send` if you want
44    /// to force the schema.
45    pub schema: Option<Schema>,
46
47    /// If LLMs fail to generate a valid schema `schema_max_try` times,
48    /// it returns a default value. If it's 0, it wouldn't call LLM at all!
49    pub schema_max_try: usize,
50}
51
52impl Request {
53    pub fn is_valid(&self) -> bool {
54        self.messages.len() > 1
55        && self.messages.len() & 1 == 0  // the last message must be user's
56        && self.messages[0].is_valid_system_prompt()  // I'm not sure whether all the models require the first message to be a system prompt. but it would be safer to guarantee that
57        && {
58            let mut flag = true;
59
60            for (index, message) in self.messages[1..].iter().enumerate() {
61                if index & 1 == 0 && !message.is_user_prompt() {
62                    flag = false;
63                    break;
64                }
65
66                else if index & 1 == 1 && !message.is_assistant_prompt() {
67                    flag = false;
68                    break;
69                }
70            }
71
72            flag
73        }
74    }
75
76    /// It panics if its fields are not complete. If you're not sure, run `self.is_valid()` before sending a request.
77    pub fn build_json_body(&self) -> Value {
78        match &self.model.api_provider {
79            ApiProvider::OpenAi { .. } | ApiProvider::Cohere => {
80                let mut result = Map::new();
81                result.insert(String::from("model"), self.model.api_name.clone().into());
82                let mut messages = vec![];
83
84                for message in self.messages.iter() {
85                    messages.push(message_to_json(message, &self.model.api_provider));
86                }
87
88                result.insert(String::from("messages"), messages.into());
89
90                if let Some(temperature) = self.temperature {
91                    result.insert(String::from("temperature"), temperature.into());
92                }
93
94                if let Some(frequency_penalty) = self.frequency_penalty {
95                    result.insert(String::from("frequency_penalty"), frequency_penalty.into());
96                }
97
98                if let Some(max_tokens) = self.max_tokens {
99                    result.insert(String::from("max_tokens"), max_tokens.into());
100                }
101
102                result.into()
103            },
104            ApiProvider::Anthropic => {
105                let mut result = Map::new();
106                result.insert(String::from("model"), self.model.api_name.clone().into());
107                let mut messages = vec![];
108                let mut system_prompt = vec![];
109
110                for message in self.messages.iter() {
111                    if message.role == Role::System {
112                        system_prompt.push(message.content[0].unwrap_str().to_string());
113                    }
114
115                    else {
116                        messages.push(message_to_json(message, &ApiProvider::Anthropic));
117                    }
118                }
119
120                let system_prompt = system_prompt.concat();
121
122                if !system_prompt.is_empty() {
123                    result.insert(String::from("system"), system_prompt.into());
124                }
125
126                result.insert(String::from("messages"), messages.into());
127
128                if let Some(temperature) = self.temperature {
129                    result.insert(String::from("temperature"), temperature.into()).unwrap();
130                }
131
132                if let Some(frequency_penalty) = self.frequency_penalty {
133                    result.insert(String::from("frequency_penalty"), frequency_penalty.into());
134                }
135
136                // it's a required field
137                result.insert(String::from("max_tokens"), self.max_tokens.unwrap_or(2048).into());
138
139                result.into()
140            },
141            ApiProvider::Test(_) => Value::Null,
142        }
143    }
144
145    /// It panics if `schema` field is missing.
146    /// It doesn't tell you whether the default value is used or not.
147    pub async fn send_and_validate<T: DeserializeOwned>(&self, default: T) -> Result<T, Error> {
148        let mut state = self.clone();
149        let mut messages = self.messages.clone();
150
151        for _ in 0..state.schema_max_try {
152            state.messages = messages.clone();
153            let response = state.send().await?;
154            let response = response.get_message(0).unwrap();
155
156            match state.schema.as_ref().unwrap().validate(&response) {
157                Ok(v) => {
158                    return Ok(serde_json::from_value::<T>(v)?);
159                },
160                Err(error_message) => {
161                    messages.push(Message::simple_message(Role::Assistant, response.to_string()));
162                    messages.push(Message::simple_message(Role::User, error_message));
163                },
164            }
165        }
166
167        Ok(default)
168    }
169
170    /// NOTE: this function dies ocassionally, for no reason.
171    ///
172    /// It panics if its fields are not complete. If you're not sure, run `self.is_valid()` before sending a request.
173    pub fn blocking_send(&self) -> Result<Response, Error> {
174        futures::executor::block_on(self.send())
175    }
176
177    /// It panics if its fields are not complete. If you're not sure, run `self.is_valid()` before sending a request.
178    pub async fn send(&self) -> Result<Response, Error> {
179        let started_at = Instant::now();
180        let client = reqwest::Client::new();
181        let mut curr_error = Error::NoTry;
182
183        let post_url = self.model.get_api_url();
184        let body = self.build_json_body();
185
186        if let Err(e) = self.dump_json(&body, "request") {
187            write_log(
188                "dump_json",
189                &format!("dump_json(\"request\", ..) failed with {e:?}"),
190            );
191        }
192
193        if let ApiProvider::Test(test_model) = &self.model.api_provider {
194            let response = test_model.get_dummy_response(&self.messages);
195
196            if let Some(key) = &self.record_api_usage_at {
197                if let Err(e) = record_api_usage(
198                    key,
199                    0,
200                    0,
201                    self.model.dollars_per_1b_input_tokens,
202                    self.model.dollars_per_1b_output_tokens,
203                    false,
204                ) {
205                    write_log(
206                        "record_api_usage",
207                        &format!("record_api_usage({key:?}, ..) failed with {e:?}"),
208                    );
209                }
210            }
211
212            if let Some(path) = &self.dump_pdl_at {
213                if let Err(e) = dump_pdl(
214                    &self.messages,
215                    &response,
216                    &None,
217                    path,
218                    String::from("model: dummy, input_tokens: 0, output_tokens: 0, took: 0ms"),
219                ) {
220                    write_log(
221                        "dump_pdl",
222                        &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
223                    );
224
225                    // TODO: should it return an error?
226                    //       the api call was successful
227                }
228            }
229
230            return Ok(Response::dummy(response));
231        }
232
233        let body = serde_json::to_string(&body)?;
234        let api_key = self.model.get_api_key()?;
235        write_log(
236            "chat_request::send",
237            &format!("entered chat_request::send() with {} bytes, model: {}", body.len(), self.model.name),
238        );
239
240        for _ in 0..(self.max_retry + 1) {
241            let mut request = client.post(post_url)
242                .header(reqwest::header::CONTENT_TYPE, "application/json")
243                .body(body.clone());
244
245            if let ApiProvider::Anthropic = &self.model.api_provider {
246                request = request.header("x-api-key", api_key.clone())
247                    .header("anthropic-version", "2023-06-01");
248            }
249
250            else if !api_key.is_empty() {
251                request = request.bearer_auth(api_key.clone());
252            }
253
254            if let Some(t) = self.timeout {
255                request = request.timeout(Duration::from_millis(t));
256            }
257
258            write_log(
259                "chat_request::send",
260                "a request sent",
261            );
262            let response = request.send().await;
263            write_log(
264                "chat_request::send",
265                "got a response from a request",
266            );
267
268            match response {
269                Ok(response) => match response.status().as_u16() {
270                    200 => match response.text().await {
271                        Ok(text) => {
272                            match serde_json::from_str::<Value>(&text) {
273                                Ok(v) => match self.dump_json(&v, "response") {
274                                    Err(e) => {
275                                        write_log(
276                                            "dump_json",
277                                            &format!("dump_json(\"response\", ..) failed with {e:?}"),
278                                        );
279                                    },
280                                    Ok(_) => {},
281                                },
282                                Err(e) => {
283                                    write_log(
284                                        "dump_json",
285                                        &format!("dump_json(\"response\", ..) failed with {e:?}"),
286                                    );
287                                },
288                            }
289
290                            match Response::from_str(&text, &self.model.api_provider) {
291                                Ok(result) => {
292                                    if let Some(key) = &self.record_api_usage_at {
293                                        if let Err(e) = record_api_usage(
294                                            key,
295                                            result.get_prompt_token_count() as u64,
296                                            result.get_output_token_count() as u64,
297                                            self.model.dollars_per_1b_input_tokens,
298                                            self.model.dollars_per_1b_output_tokens,
299                                            false,
300                                        ) {
301                                            write_log(
302                                                "record_api_usage",
303                                                &format!("record_api_usage({key:?}, ..) failed with {e:?}"),
304                                            );
305                                        }
306                                    }
307
308                                    if let Some(path) = &self.dump_pdl_at {
309                                        if let Err(e) = dump_pdl(
310                                            &self.messages,
311                                            &result.get_message(0).map(|m| m.to_string()).unwrap_or(String::new()),
312                                            &result.get_reasoning(0).map(|m| m.to_string()),
313                                            path,
314                                            format!(
315                                                "model: {}, input_tokens: {}, output_tokens: {}, took: {}ms",
316                                                self.model.name,
317                                                result.get_prompt_token_count(),
318                                                result.get_output_token_count(),
319                                                Instant::now().duration_since(started_at.clone()).as_millis(),
320                                            ),
321                                        ) {
322                                            write_log(
323                                                "dump_pdl",
324                                                &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
325                                            );
326
327                                            // TODO: should it return an error?
328                                            //       the api call was successful
329                                        }
330                                    }
331
332                                    return Ok(result);
333                                },
334                                Err(e) => {
335                                    write_log(
336                                        "Response::from_str",
337                                        &format!("Response::from_str(..) failed with {e:?}"),
338                                    );
339                                    curr_error = e;
340                                },
341                            }
342                        },
343                        Err(e) => {
344                            write_log(
345                                "response.text()",
346                                &format!("response.text() failed with {e:?}"),
347                            );
348                            curr_error = Error::ReqwestError(e);
349                        },
350                    },
351                    status_code => {
352                        curr_error = Error::ServerError {
353                            status_code,
354                            body: response.text().await,
355                        };
356
357                        if let Some(path) = &self.dump_pdl_at {
358                            if let Err(e) = dump_pdl(
359                                &self.messages,
360                                "",
361                                &None,
362                                path,
363                                format!("{}# error: {curr_error:?} #{}", '{', '}'),
364                            ) {
365                                write_log(
366                                    "dump_pdl",
367                                    &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
368                                );
369                            }
370                        }
371                    },
372                },
373                Err(e) => {
374                    write_log(
375                        "request.send().await",
376                        &format!("request.send().await failed with {e:?}"),
377                    );
378                    curr_error = Error::ReqwestError(e);
379                },
380            }
381
382            task::sleep(Duration::from_millis(self.sleep_between_retries)).await
383        }
384
385        Err(curr_error)
386    }
387
388    fn dump_json(&self, j: &Value, header: &str) -> Result<(), Error> {
389        if let Some(dir) = &self.dump_json_at {
390            let path = join(
391                &dir,
392                &format!("{header}-{}.json", Local::now().to_rfc3339()),
393            )?;
394            write_string(&path, &serde_json::to_string_pretty(j)?, WriteMode::AlwaysCreate)?;
395        }
396
397        Ok(())
398    }
399}
400
401impl Default for Request {
402    fn default() -> Self {
403        Request {
404            messages: vec![],
405            model: (&ModelRaw::llama_70b()).try_into().unwrap(),
406            temperature: None,
407            frequency_penalty: None,
408            max_tokens: None,
409            timeout: Some(20_000),
410            max_retry: 2,
411            sleep_between_retries: 6_000,
412            record_api_usage_at: None,
413            dump_pdl_at: None,
414            dump_json_at: None,
415            schema: None,
416            schema_max_try: 3,
417        }
418    }
419}