// openai_chat/completion.rs
//! Mainly uses the model text-davinci-003.
//! API: POST https://api.openai.com/v1/completions
3use crate::error;
4use openai_api_client::{
5    ClientError, CompletionsParams, CompletionsResponse, ErrorResponse, Request,
6};
7use reqwest::{header, ClientBuilder};
8use std::time::Duration;
9
10static TEXT_DAVINCI_MODEL: &str = "text-davinci-003";
11/// haedcode max-token and model inside
12///
13/// model: text-davinci-003
14///
15/// max_tokens = 3000(exclude prompt token)
16pub async fn completions(prompt: &str, api_key: &str) -> std::result::Result<String, error::Errpr> {
17    let model = TEXT_DAVINCI_MODEL;
18    let max_tokens: u32 = 3000;
19    let res = completions_inner(prompt, model, max_tokens, api_key).await?;
20    Ok(res)
21}
22/// use reqwest crate to make http request
23pub async fn completions_inner(
24    prompt: &str,
25    model: &str,
26    max_tokens: u32,
27    api_key: &str,
28) -> std::result::Result<String, ClientError> {
29    let params = CompletionsParams {
30        model: model.to_string(),
31        temperature: 0,
32        max_tokens,
33        top_p: 1.0,
34        frequency_penalty: 0.0,
35        presence_penalty: 0.0,
36        stop: None,
37        suffix: None,
38        n: 1,
39        stream: false,
40        logprobs: None,
41        echo: false,
42        best_of: 1,
43        logit_bias: None,
44        user: None,
45    };
46
47    let request = Request {
48        model: params.model.clone(),
49        prompt: prompt.to_string(),
50        temperature: params.temperature,
51        max_tokens: params.max_tokens,
52        top_p: params.top_p,
53        frequency_penalty: params.frequency_penalty,
54        presence_penalty: params.presence_penalty,
55        stop: params.stop.clone(),
56        suffix: params.suffix.clone(),
57        logprobs: params.logprobs,
58        echo: params.echo,
59        best_of: params.best_of,
60        n: params.n,
61        stream: params.stream,
62        logit_bias: params.logit_bias.clone(),
63        user: params.user.clone(),
64    };
65    let request = serde_json::to_string(&request).unwrap();
66    let mut header = header::HeaderMap::new();
67    header.insert("Content-Type", "application/json".parse().unwrap());
68    header.insert(
69        "Authorization",
70        format!("Bearer {api_key}").parse().unwrap(),
71    );
72    let client = ClientBuilder::new()
73        .default_headers(header)
74        .build()
75        .unwrap();
76    let response = client
77        .post("https://api.openai.com/v1/completions")
78        .timeout(Duration::from_secs(60))
79        .body(request)
80        .send()
81        .await
82        .map_err(|e| ClientError::NetworkError(format!("{e:?}")))?
83        .bytes()
84        .await
85        .map_err(|e| ClientError::NetworkError(format!("{e:?}")))?;
86
87    let response_str =
88        std::str::from_utf8(&response).map_err(|e| ClientError::OtherError(format!("{e:?}")))?;
89
90    let completions_response: CompletionsResponse = match serde_json::from_str(response_str) {
91        Ok(response) => response,
92        Err(e1) => {
93            let error_response: ErrorResponse = match serde_json::from_str(response_str) {
94                Ok(response) => response,
95                Err(e2) => {
96                    return Err(ClientError::OtherError(format!("{e2:?} {e1:?}")));
97                }
98            };
99            return Err(ClientError::APIError(error_response.error.message));
100        }
101    };
102    Ok(completions_response.choices[0].text.clone())
103}