use crate::engine::definition::EngineDefinition;
use crate::engine::Engine;
use arrayvec::ArrayVec;
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use tap::Pipe;
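/// Maximum number of tokens to generate, validated against the engine's limit.
///
/// Construct via [`MaxTokens::new`], which returns `None` when the requested
/// value exceeds [`EngineDefinition::max_tokens`].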
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Serialize)]
pub struct MaxTokens(usize);
impl MaxTokens {
pub fn new(max_tokens: usize, engine_definition: &EngineDefinition) -> Option<Self> {
if max_tokens <= engine_definition.max_tokens() {
Some(Self(max_tokens))
} else {
None
}
}
pub fn inner(&self) -> usize {
self.0
}
}
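/// Nucleus sampling cutoff, validated to lie within `0.0..=1.0`.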
#[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Serialize)]
pub struct TopP(f64);
impl TopP {
pub fn new(top_p: f64) -> Option<Self> {
if (0.0..=1.0).contains(&top_p) {
Some(Self(top_p))
} else {
None
}
}
}
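/// Top-k sampling parameter, restricted to the range `1..=1000`.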
pub type TopK = bounded_integer::BoundedU16<1, 1000>;
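/// Up to five stop sequences, backed by a fixed-capacity [`ArrayVec`].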
pub type Stop = ArrayVec<String, 5>;
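/// Request body for the completions endpoint. Optional fields are omitted from
/// the serialized JSON when unset.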
#[derive(Serialize, Default)]
struct TextCompletionRequest {
pub prompt: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<MaxTokens>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<TopK>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<TopP>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Stop>,
}
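/// A text completion returned by the API.
///
/// `truncated_prompt` and `total_tokens` are optional; streamed responses may
/// omit them on intermediate chunks (see the stream test below).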
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Deserialize)]
pub struct TextCompletion {
text: String,
reached_end: bool,
truncated_prompt: Option<bool>,
total_tokens: Option<usize>,
}
impl TextCompletion {
pub fn text(&self) -> &str {
&self.text
}
pub fn reached_end(&self) -> bool {
self.reached_end
}
pub fn truncated_prompt(&self) -> bool {
self.truncated_prompt.unwrap_or(false)
}
pub fn total_tokens(&self) -> Option<usize> {
self.total_tokens
}
}
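/// Item yielded by a completion stream: a transport error, a JSON decoding
/// error, or an API error may each occur per chunk.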
pub type TextCompletionStreamResult =
reqwest::Result<serde_json::Result<crate::Result<TextCompletion>>>;
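/// A stream of [`TextCompletionStreamResult`]s, blanket-implemented for any
/// matching [`Stream`].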
pub trait TextCompletionStream: Stream<Item = TextCompletionStreamResult> {}
impl<T: Stream<Item = TextCompletionStreamResult>> TextCompletionStream for T {}
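/// Builder for a text completion request against an [`Engine`].
///
/// Finish with [`Self::now`], [`Self::now_until`], or [`Self::stream`].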
#[derive(Clone)]
pub struct TextCompletionBuilder<'ts, 'e> {
pub engine: &'e Engine<'ts>,
pub prompt: String,
pub max_tokens: Option<MaxTokens>,
pub temperature: Option<f64>,
pub top_k: Option<TopK>,
pub top_p: Option<TopP>,
}
impl<'ts, 'e> TextCompletionBuilder<'ts, 'e> {
pub const fn new(engine: &'e Engine<'ts>, prompt: String) -> Self {
Self {
engine,
prompt,
max_tokens: None,
temperature: None,
top_k: None,
top_p: None,
}
}
pub fn max_tokens(mut self, max_tokens: MaxTokens) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}
pub fn top_k(mut self, top_k: TopK) -> Self {
self.top_k = Some(top_k);
self
}
pub fn top_p(mut self, top_p: TopP) -> Self {
self.top_p = Some(top_p);
self
}
fn url(&self) -> String {
let engine_id = self.engine.definition.id();
format!("https://api.textsynth.com/v1/engines/{engine_id}/completions")
}
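    // Shared implementation of `now` and `now_until`: serialize the request,
    // POST it, and decode the untagged API result.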
async fn now_impl(self, stop: Option<Stop>) -> reqwest::Result<crate::Result<TextCompletion>> {
let url = self.url();
let request = TextCompletionRequest {
prompt: self.prompt,
max_tokens: self.max_tokens,
temperature: self.temperature,
top_k: self.top_k,
top_p: self.top_p,
stream: None,
stop,
};
self.engine
.text_synth
.post(url)
.json(&request)
.send()
.await?
.json::<crate::UntaggedResult<_>>()
.await
.map(Into::into)
}
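    /// Request the completion immediately and return it whole (non-streaming).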
pub async fn now(self) -> reqwest::Result<crate::Result<TextCompletion>> {
self.now_impl(None).await
}
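    /// Like [`Self::now`], but attaches stop sequences to the request.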
pub async fn now_until(self, stop: Stop) -> reqwest::Result<crate::Result<TextCompletion>> {
self.now_impl(Some(stop)).await
}
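    /// Request the completion as a stream of partial [`TextCompletion`]s,
    /// decoding each chunk of the response body independently.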
pub async fn stream(self) -> reqwest::Result<impl TextCompletionStream> {
let url = self.url();
let request = TextCompletionRequest {
prompt: self.prompt,
max_tokens: self.max_tokens,
temperature: self.temperature,
top_k: self.top_k,
top_p: self.top_p,
stream: Some(true),
stop: None,
};
self.engine
.text_synth
.post(url)
.json(&request)
.send()
.await?
.bytes_stream()
            .map(|bytes| {
                bytes
                    // Drop the two trailing delimiter bytes of each chunk before
                    // parsing; this assumes every chunk is at least two bytes
                    // long and ends with that delimiter.
                    .map(|bytes| bytes.slice(..bytes.len() - 2))
                    // Decode the chunk as an untagged success-or-error payload.
                    .map(|bytes| serde_json::from_slice::<crate::UntaggedResult<_>>(&bytes))
                    .map(|result| result.map(Into::into))
            })
.pipe(Ok)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::prelude::CustomEngineDefinition;
use crate::test_utils;
use once_cell::sync::Lazy;
use test_utils::text_synth;
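    // Builders are constructed once and cloned by each test, hence the name.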
static YOU_SHOULD_CLONE_THIS_BUILDER: Lazy<TextCompletionBuilder> =
Lazy::new(|| text_synth::engine().text_completion("fn main() {".into()));
static BUILDER: Lazy<TextCompletionBuilder> = Lazy::new(|| {
YOU_SHOULD_CLONE_THIS_BUILDER
.clone()
.max_tokens(MaxTokens::new(128, &text_synth::ENGINE_DEFINITION).unwrap())
.temperature(0.5)
.top_k(TopK::new(128).unwrap())
.top_p(TopP::new(0.5).unwrap())
});
static ENGINE_DEFINITION: EngineDefinition =
EngineDefinition::Custom(CustomEngineDefinition::r#static("custom", 1024));
#[test]
fn test_max_tokens_new() {
assert!(MaxTokens::new(1, &ENGINE_DEFINITION).is_some());
assert!(MaxTokens::new(1024, &ENGINE_DEFINITION).is_some());
assert!(MaxTokens::new(1025, &ENGINE_DEFINITION).is_none());
}
#[test]
fn test_max_tokens_inner() {
let max_tokens = MaxTokens::new(1, &ENGINE_DEFINITION).unwrap();
assert_eq!(max_tokens.inner(), 1);
}
#[test]
fn test_text_completion_builder_new() {
let _ = TextCompletionBuilder::new(text_synth::engine(), "fn main() {".into());
}
#[test]
fn test_text_completion_max_tokens() {
let max_tokens = MaxTokens::new(128, &text_synth::ENGINE_DEFINITION).unwrap();
let _ = YOU_SHOULD_CLONE_THIS_BUILDER.clone().max_tokens(max_tokens);
}
#[test]
fn test_text_completion_temperature() {
let _ = YOU_SHOULD_CLONE_THIS_BUILDER.clone().temperature(0.5);
}
#[test]
fn test_text_completion_top_k() {
let top_k = TopK::new(128).unwrap();
let _ = YOU_SHOULD_CLONE_THIS_BUILDER.clone().top_k(top_k);
}
#[test]
fn test_text_completion_top_p() {
let top_p = TopP::new(0.5).unwrap();
let _ = YOU_SHOULD_CLONE_THIS_BUILDER.clone().top_p(top_p);
}
#[tokio::test]
async fn test_text_completion_now_and_friends() {
let text_completion = BUILDER
.clone()
.now()
.await
.expect("network error")
.expect("api error");
assert!(
text_completion.total_tokens().is_some(),
"expected total tokens of immediate text completion to exist since it is not streamed",
);
let _ = text_completion.text();
let _ = text_completion.truncated_prompt();
let _ = text_completion.reached_end();
}
#[tokio::test]
async fn test_text_completion_truncated_prompt_if_prompt_too_long() {
let mut builder = BUILDER.clone();
builder.prompt = format!(
"fn main() {{\n{}}}",
"println('Hello World')\n".repeat(2048)
);
let text_completion = builder
.now()
.await
.expect("network error")
.expect("api error");
        assert!(text_completion.truncated_prompt());
}
#[tokio::test]
async fn test_text_completion_now_until() {
let _ = BUILDER
.clone()
.now_until(Stop::try_from(&["RwLock".into()][..]).unwrap())
.await
.expect("network error")
.expect("api error");
}
#[tokio::test]
async fn test_text_completion_stream() {
fn unwrap_text_completion(
text_completion: Option<&TextCompletionStreamResult>,
) -> &TextCompletion {
text_completion
.expect("at least one text completion")
.as_ref()
.expect("network error")
.as_ref()
.expect("json error")
.as_ref()
.expect("api error")
}
let stream: Vec<TextCompletionStreamResult> = BUILDER
.clone()
.stream()
.await
.expect("network error")
.collect()
.await;
let first_text_completion = stream.first().pipe(unwrap_text_completion);
assert!(first_text_completion.total_tokens().is_none());
let last_text_completion = stream.last().pipe(unwrap_text_completion);
assert!(last_text_completion.total_tokens().is_some());
}
}