llm_rs/
openai_interface.rs

1use crate::api_error::ApiError;
2use crate::api_error::ApiErrorType;
3use crate::api_result::ApiResult;
4use crate::fine_tune::FineTune;
5use crate::json::AudioTranscriptionResponse;
6use crate::json::ChatRequestInfo;
7use crate::json::CompletionRequestInfo;
8use crate::json::FileDeletedResponse;
9use crate::json::FileInfoResponse;
10use crate::json::FileUploadResponse;
11use crate::json::Files;
12// use crate::json::FineTuneCreateResponse;
13// use crate::json::FtRoot;
14use crate::json::ImageRequestInfo;
15use crate::json::Message;
16use crate::json::ModelReturned;
17use crate::json::Usage;
18use chrono::{NaiveDateTime, TimeZone, Utc};
19use curl::easy::Easy;
20use curl::easy::List;
21use reqwest::blocking::multipart;
22use reqwest::blocking::Client;
23use reqwest::blocking::ClientBuilder;
24use reqwest::blocking::RequestBuilder;
25use reqwest::header::HeaderMap;
26use reqwest::header::{HeaderValue, AUTHORIZATION, CONTENT_TYPE};
27use reqwest::StatusCode;
28use serde_json::json;
29use std::collections::HashMap;
30use std::error::Error;
31use std::fmt;
32use std::fmt::Display;
33use std::io::Read;
34use std::path::Path;
35use std::result::Result;
36use std::time::Instant;
37
// URLs (a leading `*` marks endpoints implemented in this file):
// * Completions: POST https://api.openai.com/v1/completions
// * Chat: POST https://api.openai.com/v1/chat/completions
// Edits: POST https://api.openai.com/v1/edits
// * Images, create: POST https://api.openai.com/v1/images/generations
// * Images, edit: POST https://api.openai.com/v1/images/edits
// Images, variations: POST https://api.openai.com/v1/images/variations
// * Audio, transcription: POST https://api.openai.com/v1/audio/transcriptions
// Audio, translation: POST https://api.openai.com/v1/audio/translations
// * Files, list: GET https://api.openai.com/v1/files
// * Files, upload: POST https://api.openai.com/v1/files
// * Files, delete: DELETE https://api.openai.com/v1/files/{file_id}
// * Files, retrieve: GET https://api.openai.com/v1/files/{file_id}
// * Files, retrieve content: GET https://api.openai.com/v1/files/{file_id}/content
// * Models, list: GET https://api.openai.com/v1/models
// * Fine tune, create: POST https://api.openai.com/v1/fine-tunes
// Fine tune, list: GET https://api.openai.com/v1/fine-tunes
// Fine tune, retrieve: GET https://api.openai.com/v1/fine-tunes/{fine_tune_id}
// Fine tune, cancel: POST https://api.openai.com/v1/fine-tunes/{fine_tune_id}/cancel
// Fine tune, events: GET https://api.openai.com/v1/fine-tunes/{fine_tune_id}/events
// Fine-tuned model, delete: DELETE https://api.openai.com/v1/models/{model}
// Moderations: POST https://api.openai.com/v1/moderations
60
/// Base URL shared by every OpenAI endpoint used in this module.
///
/// Request URLs are built as `format!("{API_URL}/<path>")`, so this value
/// must not end with a `/`.
const API_URL: &str = "https://api.openai.com/v1";
63
64#[derive(Debug)]
65pub struct ApiInterface<'a> {
66    /// Handles the communications with OpenAI
67    client: Client,
68
69    /// The secret key from OpenAI
70    api_key: &'a str,
71
72    /// Restricts the amount of text returned
73    pub tokens: u32,
74
75    /// Influences the predictability/repeatability of the model
76    pub temperature: f32,
77
78    /// Chat keeps its state here.
79    pub context: Vec<String>,
80
81    /// The chat model system prompt
82    pub system_prompt: String,
83}
84
85impl Display for ApiInterface<'_> {
86    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87        write!(
88            f,
89            "Temperature: {}\n\
90		     Tokens: {}\n\
91		     Context length: {}\n\
92		     System prompt: {}",
93            self.temperature,
94            self.tokens,
95            self.context.len(),
96            self.system_prompt,
97        )
98    }
99}
100
101impl<'a> ApiInterface<'_> {
102    pub fn new(api_key: &'a str, tokens: u32, temperature: f32) -> ApiInterface<'a> {
103        ApiInterface {
104            client: ClientBuilder::new()
105                .timeout(std::time::Duration::from_secs(1200))
106                .pool_idle_timeout(None)
107                .connection_verbose(false)
108                .build()
109                .unwrap(),
110            api_key,
111            tokens,
112            temperature,
113            // model: model.to_string(),
114            context: vec![],
115            system_prompt: String::new(),
116        }
117    }
118
119    /// Get information about a file
120    pub fn file_info(&self, file_id: String) -> Result<ApiResult<String>, Box<dyn Error>> {
121        // GET https://api.openai.com/v1/files/{file_id}
122        let uri = format!("{API_URL}/files/{file_id}");
123        let response = self
124            .client
125            .get(uri.as_str())
126            .header("Content-Type", "application/json")
127            .header(AUTHORIZATION, format!("Bearer {}", self.api_key))
128            .send()?;
129        let headers = Self::header_map_to_hash_map(response.headers());
130        if response.status() != StatusCode::OK {
131            let reason = response
132                .status()
133                .canonical_reason()
134                .unwrap_or("Unknown Reason");
135            Err(Box::new(ApiError::new(
136                ApiErrorType::Status(response.status(), reason.to_string()),
137                headers,
138            )))
139        } else {
140            let fir: FileInfoResponse = response.json()?;
141            let datetime = NaiveDateTime::from_timestamp_opt(fir.created_at as i64, 0).unwrap();
142            let datetime_utc = Utc.from_utc_datetime(&datetime);
143
144            let datetime_string = datetime_utc.format("%Y-%m-%d %H:%M:%S").to_string();
145            Ok(ApiResult {
146                headers,
147                body: format!(
148                    "Size: {} Name: {} Created: {}",
149                    fir.bytes, fir.filename, datetime_string
150                ), //fir.to_string(),
151            })
152        }
153
154        // //let result = ;
155
156        // Ok(ApiResult::new("".to_string(), HashMap::new())) // { headers: (), body: ()
157    }
158
159    /// Get file cotents
160    pub fn file_contents(&self, file_id: String) -> Result<ApiResult<String>, Box<dyn Error>> {
161        // GET https://api.openai.com/v1/files/{file_id}/content
162        let uri = format!("{API_URL}/files/{file_id}/content");
163        let response = self
164            .client
165            .get(uri.as_str())
166            .header("Content-Type", "application/json")
167            .header("Authorization", format!("Bearer {}", self.api_key))
168            .send()?;
169        let headers = Self::header_map_to_hash_map(response.headers());
170        if response.status() != StatusCode::OK {
171            let reason = response
172                .status()
173                .canonical_reason()
174                .unwrap_or("Unknown Reason");
175            Err(Box::new(ApiError::new(
176                ApiErrorType::Status(response.status(), reason.to_string()),
177                headers,
178            )))
179        } else {
180            let content = response.text()?;
181            Ok(ApiResult::new(content, headers))
182        }
183    }
184    /// Delete a file
185    pub fn files_delete(&self, file_id: String) -> Result<ApiResult<()>, Box<dyn Error>> {
186        // DELETE https://api.openai.com/v1/files/{file_id}
187        let uri = format!("{API_URL}/files/{file_id}");
188        let response = self
189            .client
190            .delete(uri.as_str())
191            .header("Content-Type", "application/json")
192            .header("Authorization", format!("Bearer {}", self.api_key))
193            .send()?;
194        let headers = Self::header_map_to_hash_map(response.headers());
195        if response.status() != StatusCode::OK {
196            let reason = response
197                .status()
198                .canonical_reason()
199                .unwrap_or("Unknown Reason");
200            Err(Box::new(ApiError::new(
201                ApiErrorType::Status(response.status(), reason.to_string()),
202                headers,
203            )))
204        } else {
205            let fdr: FileDeletedResponse = response.json()?;
206            if !fdr.deleted || fdr.object != *"file" || fdr.id != file_id {
207                Err(Box::new(ApiError::new(
208                    ApiErrorType::Error(format!(
209                        "File delete response:{:?}  file_id: {file_id}",
210                        fdr
211                    )),
212                    headers,
213                )))
214            } else {
215                Ok(ApiResult::new_e(HashMap::new()))
216            }
217        }
218    }
219    /// Get a list of all files stored on OpenAI
220    pub fn files_list(&self) -> Result<ApiResult<Vec<(String, String)>>, Box<dyn Error>> {
221        // GET https://api.openai.com/v1/files
222        let uri = format!("{}/files", API_URL);
223        let response = self
224            .client
225            .get(uri)
226            .header("Authorization", format!("Bearer {}", self.api_key))
227            .send()?;
228
229        let headers = Self::header_map_to_hash_map(response.headers());
230        let response_strings: Vec<(String, String)> = if response.status() != StatusCode::OK {
231            let reason = response
232                .status()
233                .canonical_reason()
234                .unwrap_or("Unknown Reason");
235            return Err(Box::new(ApiError::new(
236                ApiErrorType::Status(response.status(), reason.to_string()),
237                headers,
238            )));
239        } else {
240            response
241                .json::<Files>()?
242                .data
243                .iter()
244                .map(|x| (x.filename.clone(), x.id.clone()))
245                .collect()
246        };
247        Ok(ApiResult::new_v(response_strings, headers))
248    }
249
250    /// Upload a file for fine-tuning.
251    pub fn files_upload_fine_tuning(
252        &self,
253        file: &Path,
254    ) -> Result<ApiResult<String>, Box<dyn Error>> {
255        // Request
256        // curl https://api.openai.com/v1/files \
257        // -H "Authorization: Bearer $OPENAI_API_KEY" \
258        // -F purpose="fine-tune" \
259        // -F file="@mydata.jsonl"
260
261        // Response
262        // {
263        //   "id": "file-XjGxS3KTG0uNmNOK362iJua3",
264        //   "object": "file",
265        //   "bytes": 140,
266        //   "created_at": 1613779121,
267        //   "filename": "mydata.jsonl",
268        //   "purpose": "fine-tune"
269        // }
270
271        let uri = format!("{}/files", API_URL);
272
273        let file_field = multipart::Part::file(file)?;
274        let purpose_field = multipart::Part::text("fine-tune");
275        let form = multipart::Form::new()
276            .part("file", file_field)
277            .part("purpose", purpose_field);
278        let response = self
279            .client
280            .post(uri)
281            .header("Authorization", format!("Bearer {}", self.api_key))
282            .multipart(form)
283            .send()?;
284        let headers = Self::header_map_to_hash_map(response.headers());
285        let response_text: String = if response.status() != StatusCode::OK {
286            let reason = response
287                .status()
288                .canonical_reason()
289                .unwrap_or("Unknown Reason");
290            return Err(Box::new(ApiError::new(
291                ApiErrorType::Status(response.status(), reason.to_string()),
292                headers,
293            )));
294        } else {
295            response.json::<FileUploadResponse>()?.id
296        };
297
298        Ok(ApiResult::new(response_text, headers))
299    }
300    /// The audio file `audio_file` is tracscribed.  No `Usage` data
301    /// returned from this endpoint
302    /// Get an audio transcription
303    pub fn audio_transcription(
304        &mut self,
305        audio_file: &Path,
306        prompt: Option<&str>,
307    ) -> Result<ApiResult<String>, Box<dyn Error>> {
308        // Request
309        // curl https://api.openai.com/v1/audio/transcriptions \
310        //   -H "Authorization: Bearer $OPENAI_API_KEY" \
311        //   -H "Content-Type: multipart/form-data" \
312        //   -F file="@/path/to/file/audio.mp3" \
313        //   -F model="whisper-1"
314
315        // Respponse
316        // {
317        //   "text": "Imagine the....that."
318        // }
319
320        let uri = format!("{}/audio/transcriptions", API_URL);
321
322        let file_field = multipart::Part::file(audio_file)?;
323        let model_field = multipart::Part::text("whisper-1");
324        let mut form = multipart::Form::new()
325            .part("file", file_field)
326            .part("model", model_field);
327        if let Some(prompt) = prompt {
328            let p: String = prompt.to_string();
329            let prompt_field = multipart::Part::text(p);
330            form = form.part("prompt", prompt_field);
331        }
332
333        // let client = reqwest::blocking::Client::new();
334        let response = self
335            .client
336            .post(uri)
337            .header("Authorization", format!("Bearer {}", self.api_key))
338            .multipart(form)
339            .send()?;
340
341        let headers = Self::header_map_to_hash_map(response.headers());
342        let response_text: String = if response.status() != StatusCode::OK {
343            format!(
344                "Failed: Status: {}.\nResponse.path({})",
345                response
346                    .status()
347                    .canonical_reason()
348                    .unwrap_or("Unknown Reason"),
349                response.url().path(),
350            )
351        } else {
352            response.json::<AudioTranscriptionResponse>()?.text
353        };
354
355        Ok(ApiResult::new(response_text, headers))
356    }
357    // pub fn fine_tune_retrieve(
358    //     &self,
359    //     id: String,
360    // ) -> Result<ApiResult<FineTuneCreateResponse>, Box<dyn Error>> {
361    //     let uri = format!("{API_URL}/fine-tunes/{id}");
362    //     let response = self
363    //         .client
364    //         .get(uri)
365    //         .header("Authorization", format!("Bearer {}", self.api_key))
366    //         .send()?;
367
368    //     let headers = Self::header_map_to_hash_map(response.headers());
369    //     let body: FineTuneCreateResponse = if response.status() != StatusCode::OK {
370    //         let reason = response
371    //             .status()
372    //             .canonical_reason()
373    //             .unwrap_or("Unknown Reason");
374    //         return Err(Box::new(ApiError::new(
375    //             ApiErrorType::Status(response.status(), reason.to_string()),
376    //             headers,
377    //         )));
378    //     } else {
379    //         response.json::<FineTuneCreateResponse>()?
380    //     };
381
382    //     Ok(ApiResult { headers, body })
383    // }
384    pub fn fine_tune_create(
385        &self,
386        training_file_id: String,
387    ) -> Result<ApiResult<FineTune>, Box<dyn Error>> {
388        let uri = format!("{API_URL}/fine-tunes");
389        let request_body = json!({
390                "training_file": training_file_id.as_str()
391        });
392
393        let mut response = self
394            .client
395            .post(uri)
396            .header("Authorization", format!("Bearer {}", self.api_key))
397            .header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
398            .json(&request_body)
399            .send()?;
400        let mut s = String::new();
401        _ = response.read_to_string(&mut s)?;
402        let headers = Self::header_map_to_hash_map(response.headers());
403        let st = s.as_str();
404        let fine_tune: FineTune = serde_json::from_str(st)?;
405
406        // let body: FineTune = if response.status() != StatusCode::OK {
407        //     let reason = response
408        //         .status()
409        //         .canonical_reason()
410        //         .unwrap_or("Unknown Reason");
411        //     return Err(Box::new(ApiError::new(
412        //         ApiErrorType::Status(response.status(), reason.to_string()),
413        //         headers,
414        //     )));
415        // } else {
416        //     response.json::<FineTune>()?
417        // };
418
419        Ok(ApiResult {
420            headers,
421            body: fine_tune,
422        })
423    }
424    // pub fn fine_tune_info(
425    //     &self,
426    //     id: String,
427    // ) -> Result<ApiResult<FineTuneCreateResponse>, Box<dyn Error>> {
428    //     let uri = format!("{API_URL}/fine-tunes/{id}");
429    //     let response = self
430    //         .client
431    //         .get(uri)
432    //         .header("Authorization", format!("Bearer {}", self.api_key))
433    //         .send()?;
434
435    //     let headers = Self::header_map_to_hash_map(response.headers());
436    //     let body: FineTuneCreateResponse = if response.status() != StatusCode::OK {
437    //         let reason = response
438    //             .status()
439    //             .canonical_reason()
440    //             .unwrap_or("Unknown Reason");
441    //         return Err(Box::new(ApiError::new(
442    //             ApiErrorType::Status(response.status(), reason.to_string()),
443    //             headers,
444    //         )));
445    //     } else {
446    //         response.json::<FineTuneCreateResponse>()?
447    //     };
448
449    //     Ok(ApiResult { headers, body })
450    // }
451    // pub fn fine_tunes_list(&self) -> Result<ApiResult<FtRoot>, Box<dyn Error>> {
452    //     // endpoint
453    //     let uri = format!("{API_URL}/fine-tunes");
454    //     let response = self
455    //         .client
456    //         .get(uri)
457    //         .header("Authorization", format!("Bearer {}", self.api_key))
458    //         .send()?;
459
460    //     let headers = Self::header_map_to_hash_map(response.headers());
461    //     let response: FtRoot = if response.status() != StatusCode::OK {
462    //         let reason = response
463    //             .status()
464    //             .canonical_reason()
465    //             .unwrap_or("Unknown Reason");
466    //         return Err(Box::new(ApiError::new(
467    //             ApiErrorType::Status(response.status(), reason.to_string()),
468    //             headers,
469    //         )));
470    //     } else {
471    //         response.json::<FtRoot>()?
472    //     };
473
474    //     Ok(ApiResult {
475    //         headers, // HashMap::new(),
476    //         body: response,
477    //     })
478    // }
479    /// Finetune
480    /// Workflow:
481    ///
482
483    /// Documented [here](https://platform.openai.com/docs/api-reference/chat)
484    pub fn chat(&mut self, prompt: &str, model: &str) -> Result<ApiResult<String>, Box<dyn Error>> {
485        // An ongoing conversation with the LLM
486
487        // endpoint
488        let uri = format!("{}/chat/completions", API_URL);
489
490        // Model can be any of: gpt-4, gpt-4-0314, gpt-4-32k,
491        // gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
492        // https://platform.openai.com/docs/models/model-endpoint-compatibility
493
494        // Put the conversation so far in here
495        let mut messages: Vec<Message> = vec![]; // = [Message { role, content }];
496
497        // If here is any context, supply it
498        if self.context.is_empty() {
499            // Conversation starting.  Append system prompt to context
500            messages.push(Message {
501                role: "system".to_string(),
502                content: self.system_prompt.clone(),
503            });
504        } else {
505            for i in 0..self.context.len() {
506                messages.push(Message {
507                    role: "user".to_string(),
508                    content: self.context[i].clone(),
509                });
510            }
511        }
512
513        // Add in the latest installment, the prompt for this function
514        let role = "user".to_string();
515        let content = prompt.to_string();
516        messages.push(Message { role, content });
517
518        // The payload
519        let data = json!({
520            "messages": messages,
521            "model": model,
522        });
523
524        // Send the request and get the Json data as a String, convert
525        // into ``ChatRequestInfo`
526        let (headers, response_string) = self.send_curl(&data, uri.as_str())?;
527        let json: ChatRequestInfo = serde_json::from_str(response_string.as_str())?;
528        let mut headers_ret = Self::usage_headers(json.usage.clone());
529        let cost: f64 = Self::cost(json.usage, model);
530        // Define a function that returns a closure
531        // fn update_spent(cost: f64) -> impl FnMut(SharedState) -> SharedState {
532        //     move |mut ss| {
533        //         ss.spent += cost;
534        //         ss
535        //     }
536        // }
537
538        // // Call read_write_atomic with the closure
539        // let ss: SharedState = SharedState::read_write_atomic(update_spent(cost))?
540
541        headers_ret.insert("Cost".to_string(), format!("{}", cost)); //ss.spent));
542
543        headers_ret.extend(headers);
544
545        let content = json.choices[0].message.content.clone();
546        self.context.push(prompt.to_string());
547        self.context.push(content.clone());
548
549        Ok(ApiResult::new(content, headers_ret))
550    }
551
552    /// Read the record of the conversation
553    pub fn get_context(&self) -> Result<Vec<String>, Box<dyn Error>> {
554        Ok(self.context.clone())
555    }
556
557    /// Restore a record of a conversation
558    pub fn set_context(&mut self, context: Vec<String>) {
559        self.context = context;
560    }
561    /// [Documented](https://platform.openai.com/docs/api-reference/completions)
562    /// Takes the `prompt` and sends it to the LLM with no context.
563    /// The interface has to manage no state
564    pub fn completion(
565        &mut self,
566        prompt: &str,
567        model: &str,
568    ) -> Result<ApiResult<String>, Box<dyn Error>> {
569        let uri: String = format!("{}/completions", API_URL);
570
571        let payload = CompletionRequestInfo::new(prompt, model, self.temperature, self.tokens);
572
573        let response = self
574            .client
575            .post(uri)
576            .header("Authorization", format!("Bearer {}", self.api_key))
577            .header("Content-Type", "application/json")
578            .json(&payload)
579            .send()?;
580
581        let mut headers = Self::header_map_to_hash_map(response.headers());
582        let response_text: String = if response.status() != StatusCode::OK {
583            // There was some sort of failure.  Probably a network
584            // failure
585            format!(
586                "Failed: Status: {}.\nResponse.path({})",
587                response
588                    .status()
589                    .canonical_reason()
590                    .unwrap_or("Unknown Reason"),
591                response.url().path(),
592            )
593        } else {
594            // Got a good response from the LLM
595            let response_debug = format!("{:?}", &response);
596            let json: CompletionRequestInfo = match response.json() {
597                Ok(json) => json,
598                Err(err) => {
599                    panic!("Failed to get json.  {err}\n{response_debug}")
600                }
601            };
602
603            // The data about the query
604            // let choice_count = json.choices.len();
605            let finish_reason = json.choices[0].finish_reason.as_str();
606            if finish_reason != "stop" {
607                headers.insert("finsh reason".to_string(), finish_reason.to_string());
608            }
609
610            if json.choices[0].text.is_empty() {
611                panic!("Empty json.choices[0].  {:?}", &json);
612            } else {
613                json.choices[0].text.clone()
614            }
615        };
616        Ok(ApiResult::new(response_text, headers))
617    }
618
619    /// Handle image mode prompts
620    pub fn image(&mut self, prompt: &str) -> Result<ApiResult<String>, Box<dyn Error>> {
621        // Endpoint
622        let uri: String = format!("{}/images/generations", API_URL);
623
624        // Payload
625        let data = json!({
626                  "prompt":  prompt,
627                  "size": "1024x1024",
628        });
629
630        // Set up network comms
631        let res = Client::new()
632            .post(uri)
633            .header("Authorization", format!("Bearer {}", self.api_key).as_str())
634            .header("Content-Type", "application/json")
635            .json(&data);
636
637        // Send network request
638        let response = match res.send() {
639            Ok(r) => r,
640            Err(err) => {
641                return Ok(ApiResult::new(
642                    format!("Image: Response::send() failed: '{err}'"),
643                    HashMap::new(),
644                ));
645            }
646        };
647
648        // Prepare diagnostic data
649        let headers = Self::header_map_to_hash_map(&response.headers().clone());
650        if !response.status().is_success() {
651            let reason = response
652                .status()
653                .canonical_reason()
654                .unwrap_or("Unknown Reason");
655            return Err(Box::new(ApiError::new(
656                ApiErrorType::Status(response.status(), reason.to_string()),
657                headers,
658            )));
659            //return Ok(ApiResult::new("Request failed".to_string(), headers));
660        }
661
662        // Have a normal result.  Process it
663        let json: ImageRequestInfo = match response.json() {
664            Ok(json) => json,
665            Err(err) => {
666                return Err(Box::new(ApiError::new(
667                    ApiErrorType::BadJson(format!("{err}")),
668                    headers,
669                )))
670            }
671        };
672
673        // Success.
674        Ok(ApiResult::new(json.data[0].url.clone(), headers))
675    }
676
677    // Editing an image.  The mask defines the region to edit
678    // according to the prompt.  ??The prompt describes the whole
679    // image??
680    // https://platform.openai.com/docs/api-reference/images/create-edit
681    pub fn image_edit(
682        &mut self,
683        prompt: &str,
684        image: &Path,
685        mask: &Path,
686    ) -> Result<ApiResult<String>, Box<dyn Error>> {
687        // Endpoint
688        let uri = format!("{}/images/edits", API_URL);
689
690        // Some timeing.  TODO: Why here, in this function, and not everywhere?
691        let start = Instant::now();
692
693        // Need an image to edit.  If there is an image in `self.image`
694        // prefer that.  Failing that use `self.focus_image_url` In the
695        // second case the image refered to in the url is downloaded and
696        // put into `self.image`
697
698        // let mask_path = mask_file.path().to_owned();
699
700        // Prepare the payload to send to OpenAI
701        let form = multipart::Form::new();
702        let form = match form.file("image", image) {
703            Ok(f) => match f.file("mask", mask) {
704                Ok(s) => s
705                    .text("prompt", prompt.to_string())
706                    .text("size", "1024x1024"),
707                Err(err) => {
708                    return Err(Box::new(ApiError::new(
709                        ApiErrorType::Error(format!("{err}")),
710                        HashMap::new(),
711                    )))
712                }
713            },
714            // Err(err) => return Err(Box::new(ApiErrorType::Error("Err path: {err}".to_string()))),
715            Err(err) => {
716                return Err(Box::new(ApiError::new(
717                    ApiErrorType::Error(format!("{err}")),
718                    HashMap::new(),
719                )))
720            }
721        };
722
723        // Set up network comms
724        let req_build: RequestBuilder = Client::new()
725            .post(uri.as_str())
726            .timeout(std::time::Duration::from_secs(1200))
727            .header("Authorization", format!("Bearer {}", self.api_key).as_str())
728            .multipart(form);
729
730        // Send request
731        let response = match req_build.send() {
732            Ok(r) => r,
733            Err(err) => {
734                println!("Failed url: {uri} Err: {err}");
735                return Err(Box::new(err));
736            }
737        };
738
739        let headers = Self::header_map_to_hash_map(&response.headers().clone());
740        println!("Sent message: {:?}", start.elapsed());
741        if !response.status().is_success() {
742            let reason = response
743                .status()
744                .canonical_reason()
745                .unwrap_or("Unknown Reason");
746            return Err(Box::new(ApiError::new(
747                ApiErrorType::Status(response.status(), reason.to_string()),
748                headers,
749            )));
750        }
751        let response_dbg = format!("{:?}", response);
752        // let response_text = response.text()?;
753        // Ok(response_text)
754        let json: ImageRequestInfo = match response.json() {
755            Ok(json) => json,
756            Err(err) => {
757                eprintln!("Failed to get json. {err} Response: {response_dbg}");
758                return Err(Box::new(err));
759            }
760        };
761
762        Ok(ApiResult::new(json.data[0].url.clone(), headers))
763    }
764
765    /// Handle the response if the user queries what models there are
766    /// ("! md" prompt in cli).  
767    pub fn model_list(&self) -> Result<Vec<String>, Box<dyn Error>> {
768        let uri: String = format!("{}/models", API_URL);
769        let response = self
770            .client
771            .get(uri.as_str())
772            .header("Content-Type", "application/json")
773            .header("Authorization", format!("Bearer {}", self.api_key))
774            .send()?;
775        if !response.status().is_success() {
776            // If it were not a success the previous cal will have failed
777            // This will not happen
778            panic!("Failed call to get model list. {:?}", response);
779        }
780        let model_returned: ModelReturned = response.json().unwrap();
781        println!("{:?}", model_returned);
782        Ok(model_returned.data.iter().map(|x| x.root.clone()).collect())
783        // Ok(vec![])
784    }
785
786    /// Convert the usege into a price.
787    fn cost(usage: Usage, model: &str) -> f64 {
788        // GPT-4is more expensive
789        if model.starts_with("gpt-4") {
790            usage.completion_tokens as f64 / 1000.0 * 12.0
791                + usage.prompt_tokens as f64 / 1000.0 * 0.06
792        } else if model.starts_with("gpt-3") {
793            usage.total_tokens as f64 / 1000.0 * 0.2
794        } else {
795            panic!("{model}");
796        }
797    }
798    fn usage_headers(usage: Usage) -> HashMap<String, String> {
799        let prompt_tokens = usage.prompt_tokens.to_string();
800        let completion_tokens = usage.completion_tokens.to_string();
801        let total_tokens = usage.total_tokens.to_string();
802        let mut result = HashMap::new();
803        result.insert("Tokens prompt".to_string(), prompt_tokens);
804        result.insert("Tokens completion".to_string(), completion_tokens);
805        result.insert("Tokens total".to_string(), total_tokens);
806        result
807    }
808
809    /// Used to adapt headers reported from Reqwest
810    fn header_map_to_hash_map(header_map: &HeaderMap) -> HashMap<String, String> {
811        let mut hash_map = HashMap::new();
812        for (header_name, header_value) in header_map.iter() {
813            if let (Ok(name), Ok(value)) = (
814                header_name.to_string().as_str().trim().parse::<String>(),
815                header_value.to_str().map(str::to_owned),
816            ) {
817                hash_map.insert(name, value);
818            }
819        }
820        hash_map
821    }
822
823    /// Clear the context used to maintain chat history
824    pub fn clear_context(&mut self) {
825        self.context.clear();
826    }
827
828    /// Send a request, the body of which is coded in `data`, to `uri`.
829    /// Return the Json data as a String
830    fn send_curl(
831        &mut self,
832        data: &serde_json::Value,
833        uri: &str,
834    ) -> Result<(HashMap<String, String>, String), Box<dyn Error>> {
835        let body = format!("{data}");
836
837        let mut body = body.as_bytes();
838        let mut curl_easy = Easy::new();
839        curl_easy.url(uri)?;
840
841        // Prepare the headers
842        let mut list = List::new();
843        list.append(format!("Authorization: Bearer {}", self.api_key).as_str())?;
844        list.append("Content-Type: application/json")?;
845        curl_easy.http_headers(list)?;
846
847        // I am unsure why I have to do this magick incantation
848        curl_easy.post_field_size(body.len() as u64)?;
849
850        // To get the normal output of the server
851        let mut output_buffer = Vec::new();
852
853        // To get the headers
854        let mut header_buffer = Vec::new();
855
856        // Time the process.
857        let start = Instant::now();
858
859        {
860            // Start a block so `transfer` is destroyed and releases the
861            // borrow it has on `header_buffer` and `output_buffer`
862            let mut transfer = curl_easy.transfer();
863            transfer.header_function(|data| {
864                header_buffer.push(String::from_utf8(data.to_vec()).unwrap());
865                true
866            })?;
867            transfer.read_function(|buf| Ok(body.read(buf).unwrap_or(0)))?;
868            transfer.write_function(|data| {
869                output_buffer.extend_from_slice(data);
870                Ok(data.len())
871            })?;
872            transfer.perform()?;
873        }
874
875        // Made the call, got the output,  Close the timer
876        let _duration = start.elapsed();
877
878        let result = String::from_utf8(output_buffer)?; // Process the output
879
880        let headers_hm: HashMap<String, String> = header_buffer
881            .into_iter()
882            .filter_map(|item| {
883                let mut parts = item.splitn(2, ':');
884                if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
885                    Some((key.to_string(), value.trim().to_string()))
886                } else {
887                    None
888                }
889            })
890            .collect();
891        Ok((headers_hm, result))
892    }
893}