// openai_chat/completion.rs
use crate::error;

use openai_api_client::{
    ClientError, CompletionsParams, CompletionsResponse, ErrorResponse, Request,
};
use reqwest::{header, ClientBuilder};
use std::time::Duration;

/// Model identifier that `completions` sends to the completions endpoint.
const TEXT_DAVINCI_MODEL: &str = "text-davinci-003";

11pub async fn completions(prompt: &str, api_key: &str) -> std::result::Result<String, error::Errpr> {
17 let model = TEXT_DAVINCI_MODEL;
18 let max_tokens: u32 = 3000;
19 let res = completions_inner(prompt, model, max_tokens, api_key).await?;
20 Ok(res)
21}
22pub async fn completions_inner(
24 prompt: &str,
25 model: &str,
26 max_tokens: u32,
27 api_key: &str,
28) -> std::result::Result<String, ClientError> {
29 let params = CompletionsParams {
30 model: model.to_string(),
31 temperature: 0,
32 max_tokens,
33 top_p: 1.0,
34 frequency_penalty: 0.0,
35 presence_penalty: 0.0,
36 stop: None,
37 suffix: None,
38 n: 1,
39 stream: false,
40 logprobs: None,
41 echo: false,
42 best_of: 1,
43 logit_bias: None,
44 user: None,
45 };
46
47 let request = Request {
48 model: params.model.clone(),
49 prompt: prompt.to_string(),
50 temperature: params.temperature,
51 max_tokens: params.max_tokens,
52 top_p: params.top_p,
53 frequency_penalty: params.frequency_penalty,
54 presence_penalty: params.presence_penalty,
55 stop: params.stop.clone(),
56 suffix: params.suffix.clone(),
57 logprobs: params.logprobs,
58 echo: params.echo,
59 best_of: params.best_of,
60 n: params.n,
61 stream: params.stream,
62 logit_bias: params.logit_bias.clone(),
63 user: params.user.clone(),
64 };
65 let request = serde_json::to_string(&request).unwrap();
66 let mut header = header::HeaderMap::new();
67 header.insert("Content-Type", "application/json".parse().unwrap());
68 header.insert(
69 "Authorization",
70 format!("Bearer {api_key}").parse().unwrap(),
71 );
72 let client = ClientBuilder::new()
73 .default_headers(header)
74 .build()
75 .unwrap();
76 let response = client
77 .post("https://api.openai.com/v1/completions")
78 .timeout(Duration::from_secs(60))
79 .body(request)
80 .send()
81 .await
82 .map_err(|e| ClientError::NetworkError(format!("{e:?}")))?
83 .bytes()
84 .await
85 .map_err(|e| ClientError::NetworkError(format!("{e:?}")))?;
86
87 let response_str =
88 std::str::from_utf8(&response).map_err(|e| ClientError::OtherError(format!("{e:?}")))?;
89
90 let completions_response: CompletionsResponse = match serde_json::from_str(response_str) {
91 Ok(response) => response,
92 Err(e1) => {
93 let error_response: ErrorResponse = match serde_json::from_str(response_str) {
94 Ok(response) => response,
95 Err(e2) => {
96 return Err(ClientError::OtherError(format!("{e2:?} {e1:?}")));
97 }
98 };
99 return Err(ClientError::APIError(error_response.error.message));
100 }
101 };
102 Ok(completions_response.choices[0].text.clone())
103}