aleph_alpha_client/chat.rs
use std::borrow::Cow;

use serde::{Deserialize, Serialize};

use crate::Task;
/// A single message in a chat conversation, i.e. a role (such as "user", "assistant" or
/// "system") together with its text content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
    pub role: Cow<'a, str>,
    pub content: Cow<'a, str>,
}

impl<'a> Message<'a> {
    /// Creates a message with an arbitrary role.
    pub fn new(role: impl Into<Cow<'a, str>>, content: impl Into<Cow<'a, str>>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }

    /// Creates a message with the "user" role.
    pub fn user(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("user", content)
    }

    /// Creates a message with the "assistant" role.
    pub fn assistant(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("assistant", content)
    }

    /// Creates a message with the "system" role.
    pub fn system(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("system", content)
    }
}
/// Task describing a chat conversation to be completed by the model.
pub struct TaskChat<'a> {
    /// The list of messages comprising the conversation so far.
    pub messages: Vec<Message<'a>>,
    /// The maximum number of tokens to be generated. Completion will terminate after the maximum
    /// number of tokens is reached. Increase this value to allow for longer outputs. A text is
    /// split into tokens; usually there are more tokens than words. The maximum supported total
    /// of prompt tokens and maximum_tokens depends on the model.
    /// If maximum_tokens is set to None, no outside limit is imposed on the number of generated
    /// tokens. The model will generate tokens until it generates one of the specified
    /// stop_sequences or it reaches its technical limit, which usually is its context window.
    pub maximum_tokens: Option<u32>,
    /// A higher temperature encourages the model to produce less probable outputs ("be more
    /// creative"). Values are expected to be between 0 and 1. Try high values for a more random
    /// ("creative") response.
    pub temperature: Option<f64>,
    /// Introduces random sampling for generated tokens by randomly selecting the next token from
    /// the smallest possible set of tokens whose cumulative probability exceeds the probability
    /// top_p. Set to 0 to get the same behaviour as `None`.
    pub top_p: Option<f64>,
}
impl<'a> TaskChat<'a> {
    /// Creates a new TaskChat containing the given message.
    /// All optional TaskChat attributes are left unset.
    pub fn with_message(message: Message<'a>) -> Self {
        TaskChat {
            messages: vec![message],
            maximum_tokens: None,
            temperature: None,
            top_p: None,
        }
    }

    /// Creates a new TaskChat containing the given messages.
    /// All optional TaskChat attributes are left unset.
    pub fn with_messages(messages: Vec<Message<'a>>) -> Self {
        TaskChat {
            messages,
            maximum_tokens: None,
            temperature: None,
            top_p: None,
        }
    }

    /// Pushes a new Message to this TaskChat.
    pub fn push_message(mut self, message: Message<'a>) -> Self {
        self.messages.push(message);
        self
    }

    /// Sets the maximum_tokens attribute of this TaskChat.
    pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
        self.maximum_tokens = Some(maximum_tokens);
        self
    }

    /// Sets the temperature attribute of this TaskChat.
    pub fn with_temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the top_p attribute of this TaskChat.
    pub fn with_top_p(mut self, top_p: f64) -> Self {
        self.top_p = Some(top_p);
        self
    }
}
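
// A minimal usage sketch (added for illustration, not part of the original source): building a
// chat task with the constructors and builder methods above. Only identifiers defined in this
// file are used.
#[cfg(test)]
mod task_chat_builder_sketch {
    use super::{Message, TaskChat};

    #[test]
    fn build_task_chat_with_options() {
        // Start from a single user message, extend the conversation, and set sampling options.
        let task = TaskChat::with_message(Message::user("Hello, world!"))
            .push_message(Message::assistant("Hi! How can I help you?"))
            .push_message(Message::user("Tell me a joke."))
            .with_maximum_tokens(64)
            .with_temperature(0.7);

        assert_eq!(task.messages.len(), 3);
        assert_eq!(task.maximum_tokens, Some(64));
        assert_eq!(task.temperature, Some(0.7));
        assert_eq!(task.top_p, None);
    }
}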
/// One choice generated by the model, together with the reason generation stopped.
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct ChatOutput {
    pub message: Message<'static>,
    pub finish_reason: String,
}

/// Response body of the chat completion endpoint.
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct ResponseChat {
    pub choices: Vec<ChatOutput>,
}
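
// A hedged sketch of deserializing a response, assuming `serde_json` is available as a
// (dev-)dependency. The JSON below mirrors the struct definitions above; any additional fields
// in a real response would simply be ignored by serde's default behaviour.
#[cfg(test)]
mod response_chat_sketch {
    use super::{ChatOutput, Message, ResponseChat};

    #[test]
    fn deserialize_minimal_chat_response() {
        let json = r#"{
            "choices": [
                {
                    "message": { "role": "assistant", "content": "Hello!" },
                    "finish_reason": "stop"
                }
            ]
        }"#;

        let response: ResponseChat = serde_json::from_str(json).unwrap();
        assert_eq!(
            response.choices,
            vec![ChatOutput {
                message: Message::assistant("Hello!"),
                finish_reason: "stop".to_owned(),
            }]
        );
    }
}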
/// Request body sent to the chat completion endpoint.
#[derive(Serialize)]
struct ChatBody<'a> {
    /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
    pub model: &'a str,
    /// The list of messages comprising the conversation so far.
    messages: &'a [Message<'a>],
    /// Limits the number of tokens which are generated for the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub maximum_tokens: Option<u32>,
    /// Controls the randomness of the model. Lower values make the model more deterministic and
    /// higher values make it more random. Mathematically, the temperature is used to divide the
    /// logits before sampling. A temperature of 0 will always return the most likely token.
    /// When no value is provided, the default value of 1 is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// "Nucleus" parameter to dynamically adjust the number of choices for each predicted token
    /// based on the cumulative probabilities. It specifies a probability threshold below which
    /// all less likely tokens are filtered out.
    /// When no value is provided, the default value of 1 is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
}

impl<'a> ChatBody<'a> {
    pub fn new(model: &'a str, task: &'a TaskChat) -> Self {
        Self {
            model,
            messages: &task.messages,
            maximum_tokens: task.maximum_tokens,
            temperature: task.temperature,
            top_p: task.top_p,
        }
    }
}
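
// A hedged sketch of the serialized request body, assuming `serde_json` is available as a
// (dev-)dependency. Sampling parameters left as `None` are omitted entirely thanks to the
// `skip_serializing_if` attributes above.
#[cfg(test)]
mod chat_body_serialization_sketch {
    use super::{ChatBody, Message, TaskChat};

    #[test]
    fn none_fields_are_omitted_from_the_request_body() {
        let task = TaskChat::with_message(Message::user("Hello")).with_maximum_tokens(20);
        let body = ChatBody::new("luminous-base", &task);

        let json = serde_json::to_value(&body).unwrap();
        assert_eq!(json["model"], "luminous-base");
        assert_eq!(json["maximum_tokens"], 20);
        assert_eq!(json["messages"][0]["role"], "user");
        // Unset sampling parameters do not show up in the body at all.
        assert!(json.get("temperature").is_none());
        assert!(json.get("top_p").is_none());
    }
}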
impl<'a> Task for TaskChat<'a> {
    type Output = ChatOutput;
    type ResponseBody = ResponseChat;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = ChatBody::new(model, self);
        client.post(format!("{base}/chat/completions")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
        // The endpoint is expected to return at least one choice; popping the last one hands it
        // back by value without cloning.
        response.choices.pop().unwrap()
    }
}
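
// A hedged sketch of how the Task implementation assembles its HTTP request, assuming a plain
// `reqwest::Client`. The base URL is a placeholder, and the authentication headers a real client
// would add before sending are omitted here.
#[cfg(test)]
mod build_request_sketch {
    use super::{Message, TaskChat};
    use crate::Task;

    #[test]
    fn request_targets_the_chat_completions_route() {
        let client = reqwest::Client::new();
        let task = TaskChat::with_message(Message::user("Hello"));

        // `build_request` only assembles the request; nothing is sent here.
        let request = task
            .build_request(&client, "https://api.example.com", "luminous-base")
            .build()
            .unwrap();

        assert_eq!(request.url().path(), "/chat/completions");
    }
}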