elikoga_textsynth/completions.rs
//! Provides the completion API

pub mod logprob;

use std::{collections::HashMap, fmt, marker::PhantomData};

use bytes::{Buf, BytesMut};
use derive_builder::Builder;
use futures::{stream, Stream, StreamExt};
use serde::{de, Deserialize, Deserializer, Serialize};
use serde_with::skip_serializing_none;
use thiserror::Error;

use crate::{IsEngine, TextSynthClient};

/// Enum for the different completion engines available for TextSynth
#[derive(strum::Display)]
pub enum Engine {
    /// GPT-J is a language model with 6 billion parameters trained on the Pile
    /// (825 GB of text data) published by EleutherAI. Its main language is
    /// English, but it is also fluent in several other languages. It is also
    /// trained on several computer languages.
    #[strum(serialize = "gptj_6B")]
    GPTJ6B,
    /// Boris is a fine-tuned version of GPT-J for the French language. Use this
    /// model if you want the best performance with the French language.
    #[strum(serialize = "boris_6B")]
    Boris6B,
    /// Fairseq GPT 13B is an English language model with 13 billion parameters.
    /// Its training corpus is less diverse than GPT-J's, but it has better
    /// performance, at least on pure English language tasks.
    #[strum(serialize = "fairseq_gpt_13B")]
    FairseqGPT13B,
    /// GPT-NeoX-20B is the largest publicly available English language model,
    /// with 20 billion parameters. It was trained on the same corpus as GPT-J.
    #[strum(serialize = "gptneox_20B")]
    GPTNeoX20B,
}
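
// Illustrative sketch, not part of the original file: the strum-derived
// `Display` impl maps each `Engine` variant to the engine identifier used in
// the request URL (see `completions` below).
#[cfg(test)]
mod engine_display_tests {
    use super::Engine;

    #[test]
    fn engine_formats_as_textsynth_id() {
        assert_eq!(Engine::GPTJ6B.to_string(), "gptj_6B");
        assert_eq!(Engine::Boris6B.to_string(), "boris_6B");
        assert_eq!(Engine::FairseqGPT13B.to_string(), "fairseq_gpt_13B");
        assert_eq!(Engine::GPTNeoX20B.to_string(), "gptneox_20B");
    }
}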

impl IsEngine for Engine {
    fn is_completion(&self) -> bool {
        true
    }
}

/// Struct for a completion request
#[skip_serializing_none]
#[derive(Serialize, Builder)]
#[builder(setter(into))]
#[builder(build_fn(validate = "Self::validate"))]
pub struct Request {
    /// The input text to complete.
    ///
    /// NOTE: The prompt is not included in the output.
    prompt: String,
    /// Maximum number of tokens to generate. A token represents about 4
    /// characters for English texts. The total number of tokens (prompt +
    /// generated text) cannot exceed the model's maximum context length,
    /// which is 2048 for GPT-J and 1024 for the other models.
    #[builder(setter(strip_option))]
    #[builder(default)]
    max_tokens: Option<u32>,
    /// If true, the output is streamed so that it is possible to display the
    /// result before the complete output is generated. Several JSON answers
    /// are output. Each answer is followed by two line feed characters.
    #[builder(setter(strip_option))]
    #[builder(default)]
    stream: Option<bool>,
    /// Stop the generation when the string(s) are encountered. The generated
    /// text does not contain the string. The length of the array is at most 5.
    #[builder(setter(strip_option))]
    #[builder(default)]
    stop: Option<Vec<String>>,
    /// Generate n completions from a single prompt.
    #[builder(setter(strip_option))]
    #[builder(default)]
    n: Option<u32>,
    /// Sampling temperature. A higher temperature means the model will select
    /// less common tokens, leading to larger diversity but potentially less
    /// relevant output. It is usually better to tune top_p or top_k.
    #[builder(setter(strip_option))]
    #[builder(default)]
    temperature: Option<f64>,
    /// Select the next output token among the top_k most likely ones. A higher
    /// top_k gives more diversity but potentially less relevant output.
    #[builder(setter(strip_option))]
    #[builder(default)]
    top_k: Option<u32>,
    /// Select the next output token among the most probable ones so that their
    /// cumulative probability is larger than top_p. A higher top_p gives more
    /// diversity but potentially less relevant output. top_p and top_k are
    /// combined, meaning that at most top_k tokens are selected. A value of 1
    /// disables this sampling.
    #[builder(setter(strip_option))]
    #[builder(default)]
    top_p: Option<f64>,
    // More advanced sampling parameters are available:
    /// Modify the likelihood of the specified tokens in the completion.
    /// The specified object is a map between the token indexes and the
    /// corresponding logit bias. A negative bias reduces the likelihood of the
    /// corresponding token. The bias must be between -100 and 100. Note that
    /// the token indexes are specific to the selected model. You can use the
    /// tokenize API endpoint to retrieve the token indexes of a given model.
    ///
    /// Example: if you want to ban the " unicorn" token for GPT-J, you can use:
    /// `logit_bias: { "44986": -100 }`
    #[builder(setter(strip_option))]
    #[builder(default)]
    logit_bias: Option<HashMap<String, f64>>,
    /// A positive value penalizes tokens which already appeared in the
    /// generated text. Hence it forces the model to have a more diverse output.
    #[builder(setter(strip_option))]
    #[builder(default)]
    presence_penalty: Option<f64>,
    /// A positive value penalizes tokens which already appeared in the
    /// generated text, proportionally to their frequency. Hence it forces the
    /// model to have a more diverse output.
    #[builder(setter(strip_option))]
    #[builder(default)]
    frequency_penalty: Option<f64>,
    /// Divide by repetition_penalty the logits corresponding to tokens which
    /// already appeared in the generated text. A value of 1 effectively
    /// disables it.
    #[builder(setter(strip_option))]
    #[builder(default)]
    repetition_penalty: Option<f64>,
    /// Alternative to top_p sampling: instead of selecting the tokens starting
    /// from the most probable one, start from the ones whose log likelihood is
    /// the closest to the symbol entropy. This is useful for models with a
    /// low top_p value. A value of 1 disables this sampling.
    #[builder(setter(strip_option))]
    #[builder(default)]
    typical_p: Option<f64>,
}

impl RequestBuilder {
    fn validate(&self) -> Result<(), String> {
        // n must be between 1 and 16
        match self.n {
            Some(Some(n)) if !(1..=16).contains(&n) => {
                return Err("n must be between 1 and 16".to_string());
            }
            _ => {}
        };
        // top_k must be between 1 and 1000
        match self.top_k {
            Some(Some(top_k)) if !(1..=1000).contains(&top_k) => {
                return Err("top_k must be between 1 and 1000".to_string());
            }
            _ => {}
        };
        // top_p must be between 0.0 and 1.0
        match self.top_p {
            Some(Some(top_p)) if !(0.0..=1.0).contains(&top_p) => {
                return Err("top_p must be between 0.0 and 1.0".to_string());
            }
            _ => {}
        };
        // presence_penalty must be between -2.0 and 2.0
        match self.presence_penalty {
            Some(Some(presence_penalty)) if !(-2.0..=2.0).contains(&presence_penalty) => {
                return Err("presence_penalty must be between -2.0 and 2.0".to_string());
            }
            _ => {}
        };
        // frequency_penalty must be between -2.0 and 2.0
        match self.frequency_penalty {
            Some(Some(frequency_penalty)) if !(-2.0..=2.0).contains(&frequency_penalty) => {
                return Err("frequency_penalty must be between -2.0 and 2.0".to_string());
            }
            _ => {}
        };
        // typical_p must be > 0.0 and <= 1.0
        match self.typical_p {
            Some(Some(typical_p)) if !(typical_p > 0.0 && typical_p <= 1.0) => {
                return Err("typical_p must be > 0.0 and <= 1.0".to_string());
            }
            _ => {}
        };
        Ok(())
    }
}
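
// Illustrative usage sketch, not part of the original file: `derive_builder`
// generates `RequestBuilder`, whose setters accept plain values thanks to
// `setter(into)` and `setter(strip_option)`, and whose `build` runs the
// `validate` checks above.
#[cfg(test)]
mod request_builder_tests {
    use super::RequestBuilder;

    #[test]
    fn builds_a_request_within_validated_ranges() {
        let request = RequestBuilder::default()
            .prompt("Once upon a time")
            .max_tokens(100u32)
            .temperature(0.9)
            .top_k(40u32)
            .build()
            .expect("parameters are within the validated ranges");
        let _ = request;
    }

    #[test]
    fn rejects_out_of_range_n() {
        // `validate` requires n to be between 1 and 16.
        assert!(RequestBuilder::default()
            .prompt("Once upon a time")
            .n(0u32)
            .build()
            .is_err());
    }
}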

fn string_or_seq_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
    D: Deserializer<'de>,
{
    struct StringOrVec(PhantomData<Vec<String>>);

    impl<'de> de::Visitor<'de> for StringOrVec {
        type Value = Vec<String>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("string or list of strings")
        }

        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(vec![value.to_owned()])
        }

        fn visit_seq<S>(self, visitor: S) -> Result<Self::Value, S::Error>
        where
            S: de::SeqAccess<'de>,
        {
            Deserialize::deserialize(de::value::SeqAccessDeserializer::new(visitor))
        }
    }

    deserializer.deserialize_any(StringOrVec(PhantomData))
}

/// Struct for a completion answer
#[derive(Deserialize, Debug)]
pub struct ResponseChunk {
    /// The completed text.
    #[serde(deserialize_with = "string_or_seq_string")]
    pub text: Vec<String>,
    /// If true, indicates that this is the last answer.
    pub reached_end: bool,
    /// If true, indicates that the prompt was truncated because it was too large.
    pub truncated_prompt: Option<bool>,
    /// Indicates the number of input tokens.
    pub input_tokens: Option<u32>,
    /// Indicates the total number of generated tokens.
    pub output_tokens: Option<u32>,
}
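
// Illustrative sketch, not part of the original file: `string_or_seq_string`
// lets `ResponseChunk::text` accept either shape of the `text` field, a
// single string or an array of strings.
#[cfg(test)]
mod response_chunk_tests {
    use super::ResponseChunk;

    #[test]
    fn text_accepts_string_or_array() {
        let single: ResponseChunk =
            serde_json::from_str(r#"{"text": "Hello", "reached_end": true}"#).unwrap();
        assert_eq!(single.text, vec!["Hello".to_string()]);

        let many: ResponseChunk =
            serde_json::from_str(r#"{"text": ["Hello", "Bonjour"], "reached_end": true}"#)
                .unwrap();
        assert_eq!(many.text, vec!["Hello".to_string(), "Bonjour".to_string()]);
    }
}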

/// Error for a completion answer
#[derive(Error, Debug)]
pub enum Error {
    /// Serde error
    #[error("Serde error: {0}")]
    SerdeError(#[from] serde_json::Error),
    /// Error from Reqwest
    #[error("Reqwest error: {0}")]
    RequestError(#[from] reqwest::Error),
    /// Couldn't parse the response to a completion
    #[error("Couldn't parse the response to a completion")]
    ParseError(bytes::Bytes),
}
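
// Illustrative sketch, not part of the original file: `completions` below
// frames the streamed body by running serde_json's `StreamDeserializer` over a
// growing byte buffer; this shows the incremental parsing it relies on.
#[cfg(test)]
mod incremental_parse_tests {
    use super::ResponseChunk;

    #[test]
    fn stream_deserializer_splits_concatenated_answers() {
        // Two JSON answers separated by two line feeds, as the streaming API emits them.
        let body =
            "{\"text\": \"Hello\", \"reached_end\": false}\n\n{\"text\": \" world\", \"reached_end\": true}";
        let mut answers = serde_json::Deserializer::from_slice(body.as_bytes())
            .into_iter::<ResponseChunk>();

        let first = answers.next().unwrap().unwrap();
        assert_eq!(first.text, vec!["Hello".to_string()]);
        // `byte_offset` reports how many bytes have been consumed so far,
        // which is what `completions` uses to drain its buffer.
        assert!(answers.byte_offset() > 0);

        let second = answers.next().unwrap().unwrap();
        assert!(second.reached_end);
        assert!(answers.next().is_none());
    }
}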

impl TextSynthClient {
    /// Perform a completion request
    pub async fn completions(
        &self,
        engine: &Engine,
        request: &Request,
    ) -> Result<impl Stream<Item = Result<ResponseChunk, Error>>, Error> {
        let request_json = serde_json::to_string(&request)?;
        let url = format!("{}/engines/{}/completions", self.base_url, engine);
        let response = self.client.post(&url).body(request_json).send().await?;

        struct StreamState<S> {
            inner: S,
            chunks: BytesMut,
        }
        let state = StreamState {
            inner: response.bytes_stream(),
            chunks: BytesMut::new(),
        };
        let response_stream = stream::unfold(state, |mut state| async move {
            loop {
                if let Some(chunk) = state.inner.next().await {
                    let chunk = match chunk {
                        Ok(chunk) => chunk,
                        Err(err) => break Some((Err(err.into()), state)),
                    };
                    state.chunks.extend_from_slice(&chunk);
                    // Incrementally parse the buffered bytes into JSON answers.
                    let mut stream = serde_json::Deserializer::from_slice(&state.chunks)
                        .into_iter::<ResponseChunk>();
                    // Try to get the next complete answer; if the buffer only holds a
                    // partial answer, keep reading from the network.
                    let next = Iterator::next(&mut stream);
                    if let Some(Ok(chunk)) = next {
                        // Remove the parsed answer from the buffer.
                        state.chunks.advance(stream.byte_offset());
                        // Remove the leading whitespace that separates answers.
                        let mut i = 0;
                        while i < state.chunks.len() {
                            if state.chunks[i].is_ascii_whitespace() {
                                i += 1;
                            } else {
                                break;
                            }
                        }
                        state.chunks.advance(i);
                        break Some((Ok(chunk), state));
                    }
                } else {
                    // End of stream: if there is unparsed data left in the buffer,
                    // return it as a parse error.
                    if state.chunks.is_empty() {
                        break None;
                    } else {
                        break Some((
                            Err(Error::ParseError(state.chunks.freeze())),
                            StreamState {
                                chunks: BytesMut::new(),
                                ..state
                            },
                        ));
                    }
                }
            }
        });
        Ok(Box::pin(response_stream))
    }
}
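
// Illustrative usage sketch, not part of the original file: one way a caller
// might drive the returned stream. Constructing the `TextSynthClient` is
// assumed to happen elsewhere; only items defined in this module (plus the
// already-imported `futures::StreamExt`) are used here.
#[allow(dead_code)]
async fn print_completion_example(client: &TextSynthClient) -> Result<(), Error> {
    let request = RequestBuilder::default()
        .prompt("The quick brown fox")
        .max_tokens(50u32)
        .stream(true)
        .build()
        .expect("parameters are within the validated ranges");

    let mut answers = client.completions(&Engine::GPTJ6B, &request).await?;
    while let Some(answer) = answers.next().await {
        let answer = answer?;
        // Each answer may carry one or several completion strings.
        for text in &answer.text {
            print!("{}", text);
        }
        if answer.reached_end {
            break;
        }
    }
    Ok(())
}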