1use async_std::task;
2use chrono::Local;
3use crate::{ApiProvider, Error};
4use crate::message::message_to_json;
5use crate::model::{Model, ModelRaw};
6use crate::record::{
7 RecordAt,
8 dump_pdl,
9 record_api_usage,
10};
11use crate::response::Response;
12use ragit_fs::{
13 WriteMode,
14 create_dir_all,
15 exists,
16 join,
17 write_log,
18 write_string,
19};
20use ragit_pdl::{Message, Role, Schema};
21use serde::de::DeserializeOwned;
22use serde_json::{Map, Value};
23use std::time::{Duration, Instant};
24
/// A chat-completion request: the conversation, the model to send it to,
/// optional sampling settings, and retry / logging / schema-validation settings.
#[derive(Clone, Debug)]
pub struct Request {
    /// The conversation to send. `Request::is_valid` expects a valid system prompt
    /// followed by alternating user/assistant messages that start and end with a user message.
    pub messages: Vec<Message>,
    /// Model to call; its `api_provider` decides the JSON body format and auth headers.
    pub model: Model,
    /// Sampling temperature; included in the request body only when `Some`.
    pub temperature: Option<f64>,
    /// Frequency penalty; included in the request body only when `Some`.
    pub frequency_penalty: Option<f64>,
    /// Maximum number of output tokens. Sent only when `Some`, except for Anthropic,
    /// which always receives it (2048 when `None`).
    pub max_tokens: Option<usize>,

    /// Per-attempt HTTP timeout in milliseconds; when `None` no per-request timeout is set.
    pub timeout: Option<u64>,

    /// Number of retries after the first attempt in `send` (up to `max_retry + 1` attempts).
    pub max_retry: usize,

    /// Milliseconds to sleep between attempts in `send`'s retry loop.
    pub sleep_between_retries: u64,
    /// If set, token usage of each call is recorded via `record_api_usage` under this key.
    pub record_api_usage_at: Option<RecordAt>,

    /// If set, the conversation and its response (or the server error) are dumped
    /// to this path via `dump_pdl`.
    pub dump_pdl_at: Option<String>,

    /// If set, request and response JSON bodies are written as files into this directory.
    pub dump_json_at: Option<String>,

    /// Schema that `send_and_validate` checks the model's first message against.
    pub schema: Option<Schema>,

    /// Maximum number of send-and-validate rounds in `send_and_validate`.
    pub schema_max_try: usize,
}
58
59impl Request {
60 pub fn is_valid(&self) -> bool {
61 self.messages.len() > 1
62 && self.messages.len() & 1 == 0 && self.messages[0].is_valid_system_prompt() && {
65 let mut flag = true;
66
67 for (index, message) in self.messages[1..].iter().enumerate() {
68 if index & 1 == 0 && !message.is_user_prompt() {
69 flag = false;
70 break;
71 }
72
73 else if index & 1 == 1 && !message.is_assistant_prompt() {
74 flag = false;
75 break;
76 }
77 }
78
79 flag
80 }
81 }
82
83 pub fn build_json_body(&self) -> Value {
85 match &self.model.api_provider {
86 ApiProvider::OpenAi { .. } | ApiProvider::Cohere => {
87 let mut result = Map::new();
88 result.insert(String::from("model"), self.model.api_name.clone().into());
89 let mut messages = vec![];
90
91 for message in self.messages.iter() {
92 messages.push(message_to_json(message, &self.model.api_provider));
93 }
94
95 result.insert(String::from("messages"), messages.into());
96
97 if let Some(temperature) = self.temperature {
98 result.insert(String::from("temperature"), temperature.into());
99 }
100
101 if let Some(frequency_penalty) = self.frequency_penalty {
102 result.insert(String::from("frequency_penalty"), frequency_penalty.into());
103 }
104
105 if let Some(max_tokens) = self.max_tokens {
106 result.insert(String::from("max_tokens"), max_tokens.into());
107 }
108
109 result.into()
110 },
111 ApiProvider::Anthropic => {
112 let mut result = Map::new();
113 result.insert(String::from("model"), self.model.api_name.clone().into());
114 let mut messages = vec![];
115 let mut system_prompt = vec![];
116
117 for message in self.messages.iter() {
118 if message.role == Role::System {
119 system_prompt.push(message.content[0].unwrap_str().to_string());
120 }
121
122 else {
123 messages.push(message_to_json(message, &ApiProvider::Anthropic));
124 }
125 }
126
127 let system_prompt = system_prompt.concat();
128
129 if !system_prompt.is_empty() {
130 result.insert(String::from("system"), system_prompt.into());
131 }
132
133 result.insert(String::from("messages"), messages.into());
134
135 if let Some(temperature) = self.temperature {
136 result.insert(String::from("temperature"), temperature.into()).unwrap();
137 }
138
139 if let Some(frequency_penalty) = self.frequency_penalty {
140 result.insert(String::from("frequency_penalty"), frequency_penalty.into());
141 }
142
143 result.insert(String::from("max_tokens"), self.max_tokens.unwrap_or(2048).into());
145
146 result.into()
147 },
148 ApiProvider::Test(_) => Value::Null,
149 }
150 }
151
152 pub async fn send_and_validate<T: DeserializeOwned>(&self, default: T) -> Result<T, Error> {
155 let mut state = self.clone();
156 let mut messages = self.messages.clone();
157
158 for _ in 0..state.schema_max_try {
159 state.messages = messages.clone();
160 let response = state.send().await?;
161 let response = response.get_message(0).unwrap();
162
163 match state.schema.as_ref().unwrap().validate(&response) {
164 Ok(v) => {
165 return Ok(serde_json::from_value::<T>(v)?);
166 },
167 Err(error_message) => {
168 messages.push(Message::simple_message(Role::Assistant, response.to_string()));
169 messages.push(Message::simple_message(Role::User, error_message));
170 },
171 }
172 }
173
174 Ok(default)
175 }
176
177 pub fn blocking_send(&self) -> Result<Response, Error> {
181 futures::executor::block_on(self.send())
182 }
183
184 pub async fn send(&self) -> Result<Response, Error> {
186 let started_at = Instant::now();
187 let client = reqwest::Client::new();
188 let mut curr_error = Error::NoTry;
189
190 let post_url = self.model.get_api_url();
191 let body = self.build_json_body();
192
193 if let Err(e) = self.dump_json(&body, "request") {
194 write_log(
195 "dump_json",
196 &format!("dump_json(\"request\", ..) failed with {e:?}"),
197 );
198 }
199
200 if let ApiProvider::Test(test_model) = &self.model.api_provider {
201 let response = test_model.get_dummy_response(&self.messages);
202
203 if let Some(key) = &self.record_api_usage_at {
204 if let Err(e) = record_api_usage(
205 key,
206 0,
207 0,
208 self.model.dollars_per_1b_input_tokens,
209 self.model.dollars_per_1b_output_tokens,
210 false,
211 ) {
212 write_log(
213 "record_api_usage",
214 &format!("record_api_usage({key:?}, ..) failed with {e:?}"),
215 );
216 }
217 }
218
219 if let Some(path) = &self.dump_pdl_at {
220 if let Err(e) = dump_pdl(
221 &self.messages,
222 &response,
223 &None,
224 path,
225 String::from("model: dummy, input_tokens: 0, output_tokens: 0, took: 0ms"),
226 ) {
227 write_log(
228 "dump_pdl",
229 &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
230 );
231
232 }
235 }
236
237 return Ok(Response::dummy(response));
238 }
239
240 let body = serde_json::to_string(&body)?;
241 let api_key = self.model.get_api_key()?;
242 write_log(
243 "chat_request::send",
244 &format!("entered chat_request::send() with {} bytes, model: {}", body.len(), self.model.name),
245 );
246
247 for _ in 0..(self.max_retry + 1) {
248 let mut request = client.post(post_url)
249 .header(reqwest::header::CONTENT_TYPE, "application/json")
250 .body(body.clone());
251
252 if let ApiProvider::Anthropic = &self.model.api_provider {
253 request = request.header("x-api-key", api_key.clone())
254 .header("anthropic-version", "2023-06-01");
255 }
256
257 else if !api_key.is_empty() {
258 request = request.bearer_auth(api_key.clone());
259 }
260
261 if let Some(t) = self.timeout {
262 request = request.timeout(Duration::from_millis(t));
263 }
264
265 write_log(
266 "chat_request::send",
267 "a request sent",
268 );
269 let response = request.send().await;
270 write_log(
271 "chat_request::send",
272 "got a response from a request",
273 );
274
275 match response {
276 Ok(response) => match response.status().as_u16() {
277 200 => match response.text().await {
278 Ok(text) => {
279 match serde_json::from_str::<Value>(&text) {
280 Ok(v) => match self.dump_json(&v, "response") {
281 Err(e) => {
282 write_log(
283 "dump_json",
284 &format!("dump_json(\"response\", ..) failed with {e:?}"),
285 );
286 },
287 Ok(_) => {},
288 },
289 Err(e) => {
290 write_log(
291 "dump_json",
292 &format!("dump_json(\"response\", ..) failed with {e:?}"),
293 );
294 },
295 }
296
297 match Response::from_str(&text, &self.model.api_provider) {
298 Ok(result) => {
299 if let Some(key) = &self.record_api_usage_at {
300 if let Err(e) = record_api_usage(
301 key,
302 result.get_prompt_token_count() as u64,
303 result.get_output_token_count() as u64,
304 self.model.dollars_per_1b_input_tokens,
305 self.model.dollars_per_1b_output_tokens,
306 false,
307 ) {
308 write_log(
309 "record_api_usage",
310 &format!("record_api_usage({key:?}, ..) failed with {e:?}"),
311 );
312 }
313 }
314
315 if let Some(path) = &self.dump_pdl_at {
316 if let Err(e) = dump_pdl(
317 &self.messages,
318 &result.get_message(0).map(|m| m.to_string()).unwrap_or(String::new()),
319 &result.get_reasoning(0).map(|m| m.to_string()),
320 path,
321 format!(
322 "model: {}, input_tokens: {}, output_tokens: {}, took: {}ms",
323 self.model.name,
324 result.get_prompt_token_count(),
325 result.get_output_token_count(),
326 Instant::now().duration_since(started_at.clone()).as_millis(),
327 ),
328 ) {
329 write_log(
330 "dump_pdl",
331 &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
332 );
333
334 }
337 }
338
339 return Ok(result);
340 },
341 Err(e) => {
342 write_log(
343 "Response::from_str",
344 &format!("Response::from_str(..) failed with {e:?}"),
345 );
346 curr_error = e;
347 },
348 }
349 },
350 Err(e) => {
351 write_log(
352 "response.text()",
353 &format!("response.text() failed with {e:?}"),
354 );
355 curr_error = Error::ReqwestError(e);
356 },
357 },
358 status_code => {
359 curr_error = Error::ServerError {
360 status_code,
361 body: response.text().await,
362 };
363
364 if let Some(path) = &self.dump_pdl_at {
365 if let Err(e) = dump_pdl(
366 &self.messages,
367 "",
368 &None,
369 path,
370 format!("{}# error: {curr_error:?} #{}", '{', '}'),
371 ) {
372 write_log(
373 "dump_pdl",
374 &format!("dump_pdl({path:?}, ..) failed with {e:?}"),
375 );
376 }
377 }
378 },
379 },
380 Err(e) => {
381 write_log(
382 "request.send().await",
383 &format!("request.send().await failed with {e:?}"),
384 );
385 curr_error = Error::ReqwestError(e);
386 },
387 }
388
389 task::sleep(Duration::from_millis(self.sleep_between_retries)).await
390 }
391
392 Err(curr_error)
393 }
394
395 fn dump_json(&self, j: &Value, header: &str) -> Result<(), Error> {
396 if let Some(dir) = &self.dump_json_at {
397 if !exists(dir) {
398 create_dir_all(dir)?;
399 }
400
401 let path = join(
402 &dir,
403 &format!("{header}-{}.json", Local::now().to_rfc3339()),
404 )?;
405 write_string(&path, &serde_json::to_string_pretty(j)?, WriteMode::AlwaysCreate)?;
406 }
407
408 Ok(())
409 }
410}
411
412impl Default for Request {
413 fn default() -> Self {
414 Request {
415 messages: vec![],
416 model: (&ModelRaw::llama_70b()).try_into().unwrap(),
417 temperature: None,
418 frequency_penalty: None,
419 max_tokens: None,
420 timeout: Some(20_000),
421 max_retry: 2,
422 sleep_between_retries: 6_000,
423 record_api_usage_at: None,
424 dump_pdl_at: None,
425 dump_json_at: None,
426 schema: None,
427 schema_max_try: 3,
428 }
429 }
430}