// openai_req/completion/mod.rs

1use async_trait::async_trait;
2use crate::{Input, JsonRequest, Usage};
3use std::collections::HashMap;
4use serde::{Serialize,Deserialize};
5
6///text completion request
7///detailed description of params at https://platform.openai.com/docs/api-reference/completions
8///
9/// # Usage example
10///```
11///    use openai_req::completion::CompletionRequest;
12///    use openai_req::JsonRequest;
13///
14///    let completion_request =
15///         CompletionRequest::new("long long time ago".into());
16///     let response = completion_request.run(&client).await?;
17/// ```
18#[derive(Clone, Serialize, Deserialize, Debug)]
19pub struct CompletionRequest {
20    model: String,
21    prompt: Input,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    suffix: Option<String>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    max_tokens: Option<u32>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    temperature: Option<f32>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    top_p: Option<f32>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    n: Option<u16>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    stream: Option<bool>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    logprobs: Option<u32>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    echo: Option<bool>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    stop: Option<Vec<String>>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    presence_penalty: Option<f32>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    frequency_penalty: Option<f32>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    best_of: Option<u16>,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    logit_bias: Option<HashMap<String, f32>>,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    user: Option<String>,
50}
51
52
53#[async_trait(?Send)]
54impl JsonRequest<CompletionSuccess> for CompletionRequest {
55    const ENDPOINT: &'static str = "/completions";
56}
57
58impl CompletionRequest {
59    pub fn new(prompt: Input) -> CompletionRequest {
60        CompletionRequest {
61            model: "text-davinci-003".to_string(),
62            prompt,
63            suffix: None,
64            max_tokens: None,
65            temperature: None,
66            top_p: None,
67            n: None,
68            stream: None,
69            logprobs: None,
70            echo: None,
71            stop: None,
72            presence_penalty: None,
73            frequency_penalty: None,
74            best_of: None,
75            logit_bias: None,
76            user: None,
77        }
78    }
79    pub fn with_model(model: &str, prompt: Input) -> CompletionRequest {
80        CompletionRequest {
81            model: model.to_string(),
82            prompt,
83            suffix: None,
84            max_tokens: None,
85            temperature: None,
86            top_p: None,
87            n: None,
88            stream: None,
89            logprobs: None,
90            echo: None,
91            stop: None,
92            presence_penalty: None,
93            frequency_penalty: None,
94            best_of: None,
95            logit_bias: None,
96            user: None,
97        }
98    }
99
100    pub fn set_suffix(&mut self, suffix: &str) -> &mut Self {
101        self.suffix = Some(suffix.to_string());
102        self
103    }
104
105    pub fn set_max_tokens(&mut self, max_tokens: u32) -> &mut Self {
106        self.max_tokens = Some(max_tokens);
107        self
108    }
109
110    pub fn set_temperature(&mut self, temperature: f32) -> &mut Self {
111        self.temperature = Some(temperature);
112        self
113    }
114
115    pub fn set_top_p(&mut self, top_p: f32) -> &mut Self {
116        self.top_p = Some(top_p);
117        self
118    }
119
120    pub fn set_n(&mut self, n: u16) -> &mut Self {
121        self.n = Some(n);
122        self
123    }
124
125    pub fn set_stream(&mut self, stream: bool) -> &mut Self {
126        self.stream = Some(stream);
127        self
128    }
129
130    pub fn set_logprobs(&mut self, logprobs: u32) -> &mut Self {
131        self.logprobs = Some(logprobs);
132        self
133    }
134
135    pub fn set_echo(&mut self, echo: bool) -> &mut Self {
136        self.echo = Some(echo);
137        self
138    }
139
140    pub fn set_stop(mut self, stop: impl Into<Vec<String>>) -> Self {
141        self.stop = Some(stop.into());
142        self
143    }
144
145    pub fn set_presence_penalty(mut self, presence_penalty: f32) -> Self {
146        self.presence_penalty = Some(presence_penalty);
147        self
148    }
149
150    pub fn set_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
151        self.frequency_penalty = Some(frequency_penalty);
152        self
153    }
154
155    pub fn set_best_of(mut self, best_of: u16) -> Self {
156        self.best_of = Some(best_of);
157        self
158    }
159
160    pub fn set_logit_bias(mut self, logit_bias: HashMap<String, f32>) -> Self {
161        self.logit_bias = Some(logit_bias);
162        self
163    }
164
165    pub fn set_user(mut self, user: String) -> Self {
166        self.user = Some(user);
167        self
168    }
169}
170
171#[derive(Clone, Serialize, Deserialize, Debug)]
172pub  struct CompletionChoice {
173    pub text: String,
174    pub index: i64,
175    pub logprobs: Option<u32>,
176    pub finish_reason: String,
177}
178
179#[derive(Clone, Serialize, Deserialize, Debug)]
180pub  struct CompletionSuccess {
181    pub id: String,
182    pub object: String,
183    pub created: i64,
184    pub model: String,
185    pub choices: Vec<CompletionChoice>,
186    pub usage: Usage,
187}