elikoga_textsynth/
completions.rs

//! Provides the completion API.

pub mod logprob;

use std::{collections::HashMap, fmt, marker::PhantomData};

use bytes::{Buf, BytesMut};
use derive_builder::Builder;
use futures::{stream, Stream, StreamExt};
use serde::{de, Deserialize, Deserializer, Serialize};
use serde_with::skip_serializing_none;
use thiserror::Error;

use crate::{IsEngine, TextSynthClient};

/// Enum of the completion engines available in TextSynth
#[derive(strum::Display)]
pub enum Engine {
    /// GPT-J is a language model with 6 billion parameters trained on the Pile
    /// (825 GB of text data) published by EleutherAI. Its main language is
    /// English, but it is also fluent in several other languages. It is also
    /// trained on several computer languages.
    #[strum(serialize = "gptj_6B")]
    GPTJ6B,
    /// Boris is a fine-tuned version of GPT-J for the French language. Use this
    /// model if you want the best performance with the French language.
    #[strum(serialize = "boris_6B")]
    Boris6B,
    /// Fairseq GPT 13B is an English language model with 13 billion parameters.
    /// Its training corpus is less diverse than GPT-J's, but it has better
    /// performance, at least on pure English language tasks.
    #[strum(serialize = "fairseq_gpt_13B")]
    FairseqGPT13B,
    /// GPT-NeoX-20B is the largest publicly available English language model,
    /// with 20 billion parameters. It was trained on the same corpus as GPT-J.
    #[strum(serialize = "gptneox_20B")]
    GPTNeoX20B,
}

impl IsEngine for Engine {
    fn is_completion(&self) -> bool {
        true
    }
}
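
// Sketch (illustrative test, not part of the upstream API surface): the strum
// `serialize` values above are what `Display` produces, and therefore what
// ends up as the engine path segment of the completions URL.
#[cfg(test)]
mod engine_name_tests {
    use super::*;

    #[test]
    fn engine_display_matches_api_identifiers() {
        assert_eq!(Engine::GPTJ6B.to_string(), "gptj_6B");
        assert_eq!(Engine::FairseqGPT13B.to_string(), "fairseq_gpt_13B");
    }
}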

/// Struct for a completion request
#[skip_serializing_none]
#[derive(Serialize, Builder)]
#[builder(setter(into))]
#[builder(build_fn(validate = "Self::validate"))]
pub struct Request {
    /// The input text to complete.
    ///
    /// NOTE: The prompt is not included in the output.
    prompt: String,
    /// Maximum number of tokens to generate. A token represents about 4
    /// characters for English texts. The total number of tokens (prompt +
    /// generated text) cannot exceed the model's maximum context length,
    /// which is 2048 for GPT-J and 1024 for the other models.
    #[builder(setter(strip_option))]
    #[builder(default)]
    max_tokens: Option<u32>,
    /// If true, the output is streamed so that it is possible to display the
    /// result before the complete output is generated. Several JSON answers
    /// are output. Each answer is followed by two line feed characters.
    #[builder(setter(strip_option))]
    #[builder(default)]
    stream: Option<bool>,
    /// Stop the generation when the string(s) are encountered. The generated
    /// text does not contain the string. The length of the array is at most 5.
    #[builder(setter(strip_option))]
    #[builder(default)]
    stop: Option<Vec<String>>,
    /// Generate n completions from a single prompt.
    #[builder(setter(strip_option))]
    #[builder(default)]
    n: Option<u32>,
    /// Sampling temperature. A higher temperature means the model will select
    /// less common tokens, leading to larger diversity but potentially less
    /// relevant output. It is usually better to tune top_p or top_k.
    #[builder(setter(strip_option))]
    #[builder(default)]
    temperature: Option<f64>,
    /// Select the next output token among the top_k most likely ones. A higher
    /// top_k gives more diversity but a potentially less relevant output.
    #[builder(setter(strip_option))]
    #[builder(default)]
    top_k: Option<u32>,
    /// Select the next output token among the most probable ones so that their
    /// cumulative probability is larger than top_p. A higher top_p gives more
    /// diversity but a potentially less relevant output. top_p and top_k are
    /// combined, meaning that at most top_k tokens are selected. A value of 1
    /// disables this sampling.
    #[builder(setter(strip_option))]
    #[builder(default)]
    top_p: Option<f64>,
    // More advanced sampling parameters are available:
    /// Modify the likelihood of the specified tokens in the completion.
    /// The specified object is a map between the token indexes and the
    /// corresponding logit bias. A negative bias reduces the likelihood of the
    /// corresponding token. The bias must be between -100 and 100. Note that
    /// the token indexes are specific to the selected model. You can use the
    /// tokenize API endpoint to retrieve the token indexes of a given model.
    /// Example: if you want to ban the " unicorn" token for GPT-J, you can use:
    /// logit_bias: { "44986": -100 }
    #[builder(setter(strip_option))]
    #[builder(default)]
    logit_bias: Option<HashMap<String, f64>>,
    /// A positive value penalizes tokens which already appeared in the
    /// generated text. Hence it forces the model to have a more diverse output.
    #[builder(setter(strip_option))]
    #[builder(default)]
    presence_penalty: Option<f64>,
    /// A positive value penalizes tokens which already appeared in the
    /// generated text proportionally to their frequency. Hence it forces the
    /// model to have a more diverse output.
    #[builder(setter(strip_option))]
    #[builder(default)]
    frequency_penalty: Option<f64>,
    /// Divide by repetition_penalty the logits corresponding to tokens which
    /// already appeared in the generated text. A value of 1 effectively
    /// disables it.
    #[builder(setter(strip_option))]
    #[builder(default)]
    repetition_penalty: Option<f64>,
    /// Alternative to top_p sampling: instead of selecting the tokens starting
    /// from the most probable one, start from the ones whose log likelihood is
    /// the closest to the symbol entropy. This is useful for models with a
    /// low top_p value.
    /// A value of 1 disables this sampling.
    #[builder(setter(strip_option))]
    #[builder(default)]
    typical_p: Option<f64>,
}

impl RequestBuilder {
    fn validate(&self) -> Result<(), String> {
        // n must be between 1 and 16
        match self.n {
            Some(Some(n)) if !(1..=16).contains(&n) => {
                return Err("n must be between 1 and 16".to_string());
            }
            _ => {}
        };
        // top_k must be between 1 and 1000
        match self.top_k {
            Some(Some(top_k)) if !(1..=1000).contains(&top_k) => {
                return Err("top_k must be between 1 and 1000".to_string());
            }
            _ => {}
        };
        // top_p must be between 0.0 and 1.0
        match self.top_p {
            Some(Some(top_p)) if !(0.0..=1.0).contains(&top_p) => {
                return Err("top_p must be between 0.0 and 1.0".to_string());
            }
            _ => {}
        };
        // presence_penalty must be between -2.0 and 2.0
        match self.presence_penalty {
            Some(Some(presence_penalty)) if !(-2.0..=2.0).contains(&presence_penalty) => {
                return Err("presence_penalty must be between -2.0 and 2.0".to_string());
            }
            _ => {}
        };
        // frequency_penalty must be between -2.0 and 2.0
        match self.frequency_penalty {
            Some(Some(frequency_penalty)) if !(-2.0..=2.0).contains(&frequency_penalty) => {
                return Err("frequency_penalty must be between -2.0 and 2.0".to_string());
            }
            _ => {}
        };
        // typical_p must be > 0.0 and <= 1.0
        match self.typical_p {
            Some(Some(typical_p)) if !(typical_p > 0.0 && typical_p <= 1.0) => {
                return Err("typical_p must be > 0.0 and <= 1.0".to_string());
            }
            _ => {}
        };
        Ok(())
    }
}
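
// Sketch of builder usage (the prompt and parameter values are illustrative):
// `derive_builder` generates `RequestBuilder`, so a request can be assembled
// field by field and is validated by `validate()` when `build()` is called.
#[cfg(test)]
mod request_builder_tests {
    use super::*;

    #[test]
    fn builds_a_valid_request() {
        // `setter(into)` lets a &str be passed for the String prompt, and
        // `strip_option` lets the inner value be passed for Option fields.
        let request = RequestBuilder::default()
            .prompt("Once upon a time")
            .max_tokens(32u32)
            .temperature(0.8)
            .build();
        assert!(request.is_ok());
    }

    #[test]
    fn rejects_out_of_range_top_p() {
        // validate() rejects top_p outside 0.0..=1.0.
        let request = RequestBuilder::default()
            .prompt("Once upon a time")
            .top_p(1.5)
            .build();
        assert!(request.is_err());
    }
}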

fn string_or_seq_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
    D: Deserializer<'de>,
{
    struct StringOrVec(PhantomData<Vec<String>>);

    impl<'de> de::Visitor<'de> for StringOrVec {
        type Value = Vec<String>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("string or list of strings")
        }

        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(vec![value.to_owned()])
        }

        fn visit_seq<S>(self, visitor: S) -> Result<Self::Value, S::Error>
        where
            S: de::SeqAccess<'de>,
        {
            Deserialize::deserialize(de::value::SeqAccessDeserializer::new(visitor))
        }
    }

    deserializer.deserialize_any(StringOrVec(PhantomData))
}

/// Struct for a completion answer
#[derive(Deserialize, Debug)]
pub struct ResponseChunk {
    /// The completed text.
    #[serde(deserialize_with = "string_or_seq_string")]
    pub text: Vec<String>,
    /// If true, indicates that this is the last answer.
    pub reached_end: bool,
    /// If true, indicates that the prompt was truncated because it was too large.
    pub truncated_prompt: Option<bool>,
    /// Indicates the number of input tokens.
    pub input_tokens: Option<u32>,
    /// Indicates the total number of generated tokens.
    pub output_tokens: Option<u32>,
}
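
// Sketch (illustrative JSON, not captured from the live API): the custom
// deserializer above lets `text` arrive either as a single string or as a
// list of strings (e.g. when several completions are requested), and the
// optional bookkeeping fields may be absent.
#[cfg(test)]
mod response_chunk_tests {
    use super::*;

    #[test]
    fn text_can_be_a_single_string_or_a_list() {
        let single: ResponseChunk =
            serde_json::from_str(r#"{"text": "Hello", "reached_end": true}"#).unwrap();
        assert_eq!(single.text, vec!["Hello".to_string()]);

        let many: ResponseChunk = serde_json::from_str(
            r#"{"text": ["Hello", "Bonjour"], "reached_end": true, "input_tokens": 4, "output_tokens": 16}"#,
        )
        .unwrap();
        assert_eq!(many.text.len(), 2);
        assert_eq!(many.output_tokens, Some(16));
    }
}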

/// Error for a completion answer
#[derive(Error, Debug)]
pub enum Error {
    /// Serde error
    #[error("Serde error: {0}")]
    SerdeError(#[from] serde_json::Error),
    /// Error from Reqwest
    #[error("Reqwest error: {0}")]
    RequestError(#[from] reqwest::Error),
    /// Couldn't parse the response as a completion
    #[error("Couldn't parse the response as a completion")]
    ParseError(bytes::Bytes),
}

impl TextSynthClient {
    /// Perform a completion request.
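    ///
    /// # Example
    ///
    /// A minimal sketch: module paths, the prompt, and the parameter values
    /// are illustrative, and constructing the client is assumed to happen
    /// elsewhere.
    ///
    /// ```no_run
    /// use elikoga_textsynth::completions::{Engine, RequestBuilder};
    /// use elikoga_textsynth::TextSynthClient;
    /// use futures::StreamExt;
    ///
    /// async fn print_completion(client: &TextSynthClient) -> Result<(), Box<dyn std::error::Error>> {
    ///     let request = RequestBuilder::default()
    ///         .prompt("Once upon a time")
    ///         .max_tokens(32u32)
    ///         .build()?;
    ///     let mut chunks = client.completions(&Engine::GPTJ6B, &request).await?;
    ///     while let Some(chunk) = chunks.next().await {
    ///         println!("{}", chunk?.text.join(""));
    ///     }
    ///     Ok(())
    /// }
    /// ```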
    pub async fn completions(
        &self,
        engine: &Engine,
        request: &Request,
    ) -> Result<impl Stream<Item = Result<ResponseChunk, Error>>, Error> {
        let request_json = serde_json::to_string(&request)?;
        let url = format!("{}/engines/{}/completions", self.base_url, engine);
        let response = self.client.post(&url).body(request_json).send().await?;

        struct StreamState<S> {
            inner: S,
            chunks: BytesMut,
        }
        let state = StreamState {
            inner: response.bytes_stream(),
            chunks: BytesMut::new(),
        };
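        // When streaming is enabled, the API sends several JSON answers, each
        // followed by two line feed characters. Buffer the raw bytes and try
        // to parse one complete JSON value at a time from the front of the
        // buffer, yielding a `ResponseChunk` per parsed value.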
        let response_stream = stream::unfold(state, |mut state| async move {
            loop {
                if let Some(chunk) = state.inner.next().await {
                    let chunk = match chunk {
                        Ok(chunk) => chunk,
                        Err(err) => break Some((Err(err.into()), state)),
                    };
                    state.chunks.extend_from_slice(&chunk);
                    // incrementally parse the buffered bytes as JSON
                    let mut stream = serde_json::Deserializer::from_slice(&state.chunks)
                        .into_iter::<ResponseChunk>();
                    // try to read the next complete answer
                    let next = Iterator::next(&mut stream);
                    if let Some(Ok(chunk)) = next {
                        // remove the parsed chunk from the buffer
                        state.chunks.advance(stream.byte_offset());
                        // skip the leading whitespace that separates answers
                        let mut i = 0;
                        while i < state.chunks.len() {
                            if state.chunks[i].is_ascii_whitespace() {
                                i += 1;
                            } else {
                                break;
                            }
                        }
                        state.chunks.advance(i);
                        break Some((Ok(chunk), state));
                    }
                } else {
                    // end of stream
                    // if unparsed data is left in the buffer, surface it as a parse error
                    if state.chunks.is_empty() {
                        break None;
                    } else {
                        break Some((
                            Err(Error::ParseError(state.chunks.freeze())),
                            StreamState {
                                chunks: BytesMut::new(),
                                ..state
                            },
                        ));
                    }
                }
            }
        });
        Ok(Box::pin(response_stream))
    }
}