1use crate::api_error::ApiError;
2use crate::api_error::ApiErrorType;
3use crate::api_result::ApiResult;
4use crate::fine_tune::FineTune;
5use crate::json::AudioTranscriptionResponse;
6use crate::json::ChatRequestInfo;
7use crate::json::CompletionRequestInfo;
8use crate::json::FileDeletedResponse;
9use crate::json::FileInfoResponse;
10use crate::json::FileUploadResponse;
11use crate::json::Files;
12use crate::json::ImageRequestInfo;
15use crate::json::Message;
16use crate::json::ModelReturned;
17use crate::json::Usage;
18use chrono::{NaiveDateTime, TimeZone, Utc};
19use curl::easy::Easy;
20use curl::easy::List;
21use reqwest::blocking::multipart;
22use reqwest::blocking::Client;
23use reqwest::blocking::ClientBuilder;
24use reqwest::blocking::RequestBuilder;
25use reqwest::header::HeaderMap;
26use reqwest::header::{HeaderValue, AUTHORIZATION, CONTENT_TYPE};
27use reqwest::StatusCode;
28use serde_json::json;
29use std::collections::HashMap;
30use std::error::Error;
31use std::fmt;
32use std::fmt::Display;
33use std::io::Read;
34use std::path::Path;
35use std::result::Result;
36use std::time::Instant;
37
38const API_URL: &str = "https://api.openai.com/v1";
63
64#[derive(Debug)]
65pub struct ApiInterface<'a> {
66 client: Client,
68
69 api_key: &'a str,
71
72 pub tokens: u32,
74
75 pub temperature: f32,
77
78 pub context: Vec<String>,
80
81 pub system_prompt: String,
83}
84
85impl Display for ApiInterface<'_> {
86 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87 write!(
88 f,
89 "Temperature: {}\n\
90 Tokens: {}\n\
91 Context length: {}\n\
92 System prompt: {}",
93 self.temperature,
94 self.tokens,
95 self.context.len(),
96 self.system_prompt,
97 )
98 }
99}
100
101impl<'a> ApiInterface<'_> {
102 pub fn new(api_key: &'a str, tokens: u32, temperature: f32) -> ApiInterface<'a> {
103 ApiInterface {
104 client: ClientBuilder::new()
105 .timeout(std::time::Duration::from_secs(1200))
106 .pool_idle_timeout(None)
107 .connection_verbose(false)
108 .build()
109 .unwrap(),
110 api_key,
111 tokens,
112 temperature,
113 context: vec![],
115 system_prompt: String::new(),
116 }
117 }
118
119 pub fn file_info(&self, file_id: String) -> Result<ApiResult<String>, Box<dyn Error>> {
121 let uri = format!("{API_URL}/files/{file_id}");
123 let response = self
124 .client
125 .get(uri.as_str())
126 .header("Content-Type", "application/json")
127 .header(AUTHORIZATION, format!("Bearer {}", self.api_key))
128 .send()?;
129 let headers = Self::header_map_to_hash_map(response.headers());
130 if response.status() != StatusCode::OK {
131 let reason = response
132 .status()
133 .canonical_reason()
134 .unwrap_or("Unknown Reason");
135 Err(Box::new(ApiError::new(
136 ApiErrorType::Status(response.status(), reason.to_string()),
137 headers,
138 )))
139 } else {
140 let fir: FileInfoResponse = response.json()?;
141 let datetime = NaiveDateTime::from_timestamp_opt(fir.created_at as i64, 0).unwrap();
142 let datetime_utc = Utc.from_utc_datetime(&datetime);
143
144 let datetime_string = datetime_utc.format("%Y-%m-%d %H:%M:%S").to_string();
145 Ok(ApiResult {
146 headers,
147 body: format!(
148 "Size: {} Name: {} Created: {}",
149 fir.bytes, fir.filename, datetime_string
150 ), })
152 }
153
154 }
158
159 pub fn file_contents(&self, file_id: String) -> Result<ApiResult<String>, Box<dyn Error>> {
161 let uri = format!("{API_URL}/files/{file_id}/content");
163 let response = self
164 .client
165 .get(uri.as_str())
166 .header("Content-Type", "application/json")
167 .header("Authorization", format!("Bearer {}", self.api_key))
168 .send()?;
169 let headers = Self::header_map_to_hash_map(response.headers());
170 if response.status() != StatusCode::OK {
171 let reason = response
172 .status()
173 .canonical_reason()
174 .unwrap_or("Unknown Reason");
175 Err(Box::new(ApiError::new(
176 ApiErrorType::Status(response.status(), reason.to_string()),
177 headers,
178 )))
179 } else {
180 let content = response.text()?;
181 Ok(ApiResult::new(content, headers))
182 }
183 }
184 pub fn files_delete(&self, file_id: String) -> Result<ApiResult<()>, Box<dyn Error>> {
186 let uri = format!("{API_URL}/files/{file_id}");
188 let response = self
189 .client
190 .delete(uri.as_str())
191 .header("Content-Type", "application/json")
192 .header("Authorization", format!("Bearer {}", self.api_key))
193 .send()?;
194 let headers = Self::header_map_to_hash_map(response.headers());
195 if response.status() != StatusCode::OK {
196 let reason = response
197 .status()
198 .canonical_reason()
199 .unwrap_or("Unknown Reason");
200 Err(Box::new(ApiError::new(
201 ApiErrorType::Status(response.status(), reason.to_string()),
202 headers,
203 )))
204 } else {
205 let fdr: FileDeletedResponse = response.json()?;
206 if !fdr.deleted || fdr.object != *"file" || fdr.id != file_id {
207 Err(Box::new(ApiError::new(
208 ApiErrorType::Error(format!(
209 "File delete response:{:?} file_id: {file_id}",
210 fdr
211 )),
212 headers,
213 )))
214 } else {
215 Ok(ApiResult::new_e(HashMap::new()))
216 }
217 }
218 }
219 pub fn files_list(&self) -> Result<ApiResult<Vec<(String, String)>>, Box<dyn Error>> {
221 let uri = format!("{}/files", API_URL);
223 let response = self
224 .client
225 .get(uri)
226 .header("Authorization", format!("Bearer {}", self.api_key))
227 .send()?;
228
229 let headers = Self::header_map_to_hash_map(response.headers());
230 let response_strings: Vec<(String, String)> = if response.status() != StatusCode::OK {
231 let reason = response
232 .status()
233 .canonical_reason()
234 .unwrap_or("Unknown Reason");
235 return Err(Box::new(ApiError::new(
236 ApiErrorType::Status(response.status(), reason.to_string()),
237 headers,
238 )));
239 } else {
240 response
241 .json::<Files>()?
242 .data
243 .iter()
244 .map(|x| (x.filename.clone(), x.id.clone()))
245 .collect()
246 };
247 Ok(ApiResult::new_v(response_strings, headers))
248 }
249
250 pub fn files_upload_fine_tuning(
252 &self,
253 file: &Path,
254 ) -> Result<ApiResult<String>, Box<dyn Error>> {
255 let uri = format!("{}/files", API_URL);
272
273 let file_field = multipart::Part::file(file)?;
274 let purpose_field = multipart::Part::text("fine-tune");
275 let form = multipart::Form::new()
276 .part("file", file_field)
277 .part("purpose", purpose_field);
278 let response = self
279 .client
280 .post(uri)
281 .header("Authorization", format!("Bearer {}", self.api_key))
282 .multipart(form)
283 .send()?;
284 let headers = Self::header_map_to_hash_map(response.headers());
285 let response_text: String = if response.status() != StatusCode::OK {
286 let reason = response
287 .status()
288 .canonical_reason()
289 .unwrap_or("Unknown Reason");
290 return Err(Box::new(ApiError::new(
291 ApiErrorType::Status(response.status(), reason.to_string()),
292 headers,
293 )));
294 } else {
295 response.json::<FileUploadResponse>()?.id
296 };
297
298 Ok(ApiResult::new(response_text, headers))
299 }
300 pub fn audio_transcription(
304 &mut self,
305 audio_file: &Path,
306 prompt: Option<&str>,
307 ) -> Result<ApiResult<String>, Box<dyn Error>> {
308 let uri = format!("{}/audio/transcriptions", API_URL);
321
322 let file_field = multipart::Part::file(audio_file)?;
323 let model_field = multipart::Part::text("whisper-1");
324 let mut form = multipart::Form::new()
325 .part("file", file_field)
326 .part("model", model_field);
327 if let Some(prompt) = prompt {
328 let p: String = prompt.to_string();
329 let prompt_field = multipart::Part::text(p);
330 form = form.part("prompt", prompt_field);
331 }
332
333 let response = self
335 .client
336 .post(uri)
337 .header("Authorization", format!("Bearer {}", self.api_key))
338 .multipart(form)
339 .send()?;
340
341 let headers = Self::header_map_to_hash_map(response.headers());
342 let response_text: String = if response.status() != StatusCode::OK {
343 format!(
344 "Failed: Status: {}.\nResponse.path({})",
345 response
346 .status()
347 .canonical_reason()
348 .unwrap_or("Unknown Reason"),
349 response.url().path(),
350 )
351 } else {
352 response.json::<AudioTranscriptionResponse>()?.text
353 };
354
355 Ok(ApiResult::new(response_text, headers))
356 }
357 pub fn fine_tune_create(
385 &self,
386 training_file_id: String,
387 ) -> Result<ApiResult<FineTune>, Box<dyn Error>> {
388 let uri = format!("{API_URL}/fine-tunes");
389 let request_body = json!({
390 "training_file": training_file_id.as_str()
391 });
392
393 let mut response = self
394 .client
395 .post(uri)
396 .header("Authorization", format!("Bearer {}", self.api_key))
397 .header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
398 .json(&request_body)
399 .send()?;
400 let mut s = String::new();
401 _ = response.read_to_string(&mut s)?;
402 let headers = Self::header_map_to_hash_map(response.headers());
403 let st = s.as_str();
404 let fine_tune: FineTune = serde_json::from_str(st)?;
405
406 Ok(ApiResult {
420 headers,
421 body: fine_tune,
422 })
423 }
424 pub fn chat(&mut self, prompt: &str, model: &str) -> Result<ApiResult<String>, Box<dyn Error>> {
485 let uri = format!("{}/chat/completions", API_URL);
489
490 let mut messages: Vec<Message> = vec![]; if self.context.is_empty() {
499 messages.push(Message {
501 role: "system".to_string(),
502 content: self.system_prompt.clone(),
503 });
504 } else {
505 for i in 0..self.context.len() {
506 messages.push(Message {
507 role: "user".to_string(),
508 content: self.context[i].clone(),
509 });
510 }
511 }
512
513 let role = "user".to_string();
515 let content = prompt.to_string();
516 messages.push(Message { role, content });
517
518 let data = json!({
520 "messages": messages,
521 "model": model,
522 });
523
524 let (headers, response_string) = self.send_curl(&data, uri.as_str())?;
527 let json: ChatRequestInfo = serde_json::from_str(response_string.as_str())?;
528 let mut headers_ret = Self::usage_headers(json.usage.clone());
529 let cost: f64 = Self::cost(json.usage, model);
530 headers_ret.insert("Cost".to_string(), format!("{}", cost)); headers_ret.extend(headers);
544
545 let content = json.choices[0].message.content.clone();
546 self.context.push(prompt.to_string());
547 self.context.push(content.clone());
548
549 Ok(ApiResult::new(content, headers_ret))
550 }
551
552 pub fn get_context(&self) -> Result<Vec<String>, Box<dyn Error>> {
554 Ok(self.context.clone())
555 }
556
557 pub fn set_context(&mut self, context: Vec<String>) {
559 self.context = context;
560 }
561 pub fn completion(
565 &mut self,
566 prompt: &str,
567 model: &str,
568 ) -> Result<ApiResult<String>, Box<dyn Error>> {
569 let uri: String = format!("{}/completions", API_URL);
570
571 let payload = CompletionRequestInfo::new(prompt, model, self.temperature, self.tokens);
572
573 let response = self
574 .client
575 .post(uri)
576 .header("Authorization", format!("Bearer {}", self.api_key))
577 .header("Content-Type", "application/json")
578 .json(&payload)
579 .send()?;
580
581 let mut headers = Self::header_map_to_hash_map(response.headers());
582 let response_text: String = if response.status() != StatusCode::OK {
583 format!(
586 "Failed: Status: {}.\nResponse.path({})",
587 response
588 .status()
589 .canonical_reason()
590 .unwrap_or("Unknown Reason"),
591 response.url().path(),
592 )
593 } else {
594 let response_debug = format!("{:?}", &response);
596 let json: CompletionRequestInfo = match response.json() {
597 Ok(json) => json,
598 Err(err) => {
599 panic!("Failed to get json. {err}\n{response_debug}")
600 }
601 };
602
603 let finish_reason = json.choices[0].finish_reason.as_str();
606 if finish_reason != "stop" {
607 headers.insert("finsh reason".to_string(), finish_reason.to_string());
608 }
609
610 if json.choices[0].text.is_empty() {
611 panic!("Empty json.choices[0]. {:?}", &json);
612 } else {
613 json.choices[0].text.clone()
614 }
615 };
616 Ok(ApiResult::new(response_text, headers))
617 }
618
619 pub fn image(&mut self, prompt: &str) -> Result<ApiResult<String>, Box<dyn Error>> {
621 let uri: String = format!("{}/images/generations", API_URL);
623
624 let data = json!({
626 "prompt": prompt,
627 "size": "1024x1024",
628 });
629
630 let res = Client::new()
632 .post(uri)
633 .header("Authorization", format!("Bearer {}", self.api_key).as_str())
634 .header("Content-Type", "application/json")
635 .json(&data);
636
637 let response = match res.send() {
639 Ok(r) => r,
640 Err(err) => {
641 return Ok(ApiResult::new(
642 format!("Image: Response::send() failed: '{err}'"),
643 HashMap::new(),
644 ));
645 }
646 };
647
648 let headers = Self::header_map_to_hash_map(&response.headers().clone());
650 if !response.status().is_success() {
651 let reason = response
652 .status()
653 .canonical_reason()
654 .unwrap_or("Unknown Reason");
655 return Err(Box::new(ApiError::new(
656 ApiErrorType::Status(response.status(), reason.to_string()),
657 headers,
658 )));
659 }
661
662 let json: ImageRequestInfo = match response.json() {
664 Ok(json) => json,
665 Err(err) => {
666 return Err(Box::new(ApiError::new(
667 ApiErrorType::BadJson(format!("{err}")),
668 headers,
669 )))
670 }
671 };
672
673 Ok(ApiResult::new(json.data[0].url.clone(), headers))
675 }
676
677 pub fn image_edit(
682 &mut self,
683 prompt: &str,
684 image: &Path,
685 mask: &Path,
686 ) -> Result<ApiResult<String>, Box<dyn Error>> {
687 let uri = format!("{}/images/edits", API_URL);
689
690 let start = Instant::now();
692
693 let form = multipart::Form::new();
702 let form = match form.file("image", image) {
703 Ok(f) => match f.file("mask", mask) {
704 Ok(s) => s
705 .text("prompt", prompt.to_string())
706 .text("size", "1024x1024"),
707 Err(err) => {
708 return Err(Box::new(ApiError::new(
709 ApiErrorType::Error(format!("{err}")),
710 HashMap::new(),
711 )))
712 }
713 },
714 Err(err) => {
716 return Err(Box::new(ApiError::new(
717 ApiErrorType::Error(format!("{err}")),
718 HashMap::new(),
719 )))
720 }
721 };
722
723 let req_build: RequestBuilder = Client::new()
725 .post(uri.as_str())
726 .timeout(std::time::Duration::from_secs(1200))
727 .header("Authorization", format!("Bearer {}", self.api_key).as_str())
728 .multipart(form);
729
730 let response = match req_build.send() {
732 Ok(r) => r,
733 Err(err) => {
734 println!("Failed url: {uri} Err: {err}");
735 return Err(Box::new(err));
736 }
737 };
738
739 let headers = Self::header_map_to_hash_map(&response.headers().clone());
740 println!("Sent message: {:?}", start.elapsed());
741 if !response.status().is_success() {
742 let reason = response
743 .status()
744 .canonical_reason()
745 .unwrap_or("Unknown Reason");
746 return Err(Box::new(ApiError::new(
747 ApiErrorType::Status(response.status(), reason.to_string()),
748 headers,
749 )));
750 }
751 let response_dbg = format!("{:?}", response);
752 let json: ImageRequestInfo = match response.json() {
755 Ok(json) => json,
756 Err(err) => {
757 eprintln!("Failed to get json. {err} Response: {response_dbg}");
758 return Err(Box::new(err));
759 }
760 };
761
762 Ok(ApiResult::new(json.data[0].url.clone(), headers))
763 }
764
765 pub fn model_list(&self) -> Result<Vec<String>, Box<dyn Error>> {
768 let uri: String = format!("{}/models", API_URL);
769 let response = self
770 .client
771 .get(uri.as_str())
772 .header("Content-Type", "application/json")
773 .header("Authorization", format!("Bearer {}", self.api_key))
774 .send()?;
775 if !response.status().is_success() {
776 panic!("Failed call to get model list. {:?}", response);
779 }
780 let model_returned: ModelReturned = response.json().unwrap();
781 println!("{:?}", model_returned);
782 Ok(model_returned.data.iter().map(|x| x.root.clone()).collect())
783 }
785
786 fn cost(usage: Usage, model: &str) -> f64 {
788 if model.starts_with("gpt-4") {
790 usage.completion_tokens as f64 / 1000.0 * 12.0
791 + usage.prompt_tokens as f64 / 1000.0 * 0.06
792 } else if model.starts_with("gpt-3") {
793 usage.total_tokens as f64 / 1000.0 * 0.2
794 } else {
795 panic!("{model}");
796 }
797 }
798 fn usage_headers(usage: Usage) -> HashMap<String, String> {
799 let prompt_tokens = usage.prompt_tokens.to_string();
800 let completion_tokens = usage.completion_tokens.to_string();
801 let total_tokens = usage.total_tokens.to_string();
802 let mut result = HashMap::new();
803 result.insert("Tokens prompt".to_string(), prompt_tokens);
804 result.insert("Tokens completion".to_string(), completion_tokens);
805 result.insert("Tokens total".to_string(), total_tokens);
806 result
807 }
808
809 fn header_map_to_hash_map(header_map: &HeaderMap) -> HashMap<String, String> {
811 let mut hash_map = HashMap::new();
812 for (header_name, header_value) in header_map.iter() {
813 if let (Ok(name), Ok(value)) = (
814 header_name.to_string().as_str().trim().parse::<String>(),
815 header_value.to_str().map(str::to_owned),
816 ) {
817 hash_map.insert(name, value);
818 }
819 }
820 hash_map
821 }
822
823 pub fn clear_context(&mut self) {
825 self.context.clear();
826 }
827
828 fn send_curl(
831 &mut self,
832 data: &serde_json::Value,
833 uri: &str,
834 ) -> Result<(HashMap<String, String>, String), Box<dyn Error>> {
835 let body = format!("{data}");
836
837 let mut body = body.as_bytes();
838 let mut curl_easy = Easy::new();
839 curl_easy.url(uri)?;
840
841 let mut list = List::new();
843 list.append(format!("Authorization: Bearer {}", self.api_key).as_str())?;
844 list.append("Content-Type: application/json")?;
845 curl_easy.http_headers(list)?;
846
847 curl_easy.post_field_size(body.len() as u64)?;
849
850 let mut output_buffer = Vec::new();
852
853 let mut header_buffer = Vec::new();
855
856 let start = Instant::now();
858
859 {
860 let mut transfer = curl_easy.transfer();
863 transfer.header_function(|data| {
864 header_buffer.push(String::from_utf8(data.to_vec()).unwrap());
865 true
866 })?;
867 transfer.read_function(|buf| Ok(body.read(buf).unwrap_or(0)))?;
868 transfer.write_function(|data| {
869 output_buffer.extend_from_slice(data);
870 Ok(data.len())
871 })?;
872 transfer.perform()?;
873 }
874
875 let _duration = start.elapsed();
877
878 let result = String::from_utf8(output_buffer)?; let headers_hm: HashMap<String, String> = header_buffer
881 .into_iter()
882 .filter_map(|item| {
883 let mut parts = item.splitn(2, ':');
884 if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
885 Some((key.to_string(), value.trim().to_string()))
886 } else {
887 None
888 }
889 })
890 .collect();
891 Ok((headers_hm, result))
892 }
893}