use crate::error::Result;
use bon::{Builder, bon, builder};
use reqwest::Client;
use serde::Serialize;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::time::Duration;
use tap::Pipe;
/// Author of a chat turn.
///
/// Serializes as the lowercase variant name (`"user"` / `"assistant"`).
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// A turn written by the user.
    User,
    /// A turn produced by the assistant.
    Assistant,
}
/// A single chat turn: the role that authored it and its text.
#[derive(Debug, Clone, Serialize)]
pub struct Message {
/// Author of this turn.
pub role: MessageRole,
/// Text of this turn.
pub content: String,
}
impl Message {
    /// Builds a message with the given role, converting `content` into an owned `String`.
    pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
        let content = content.into();
        Self { role, content }
    }

    /// Builds a message authored by the user.
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: MessageRole::User,
            content: content.into(),
        }
    }

    /// Builds a message authored by the assistant.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: MessageRole::Assistant,
            content: content.into(),
        }
    }

    /// Returns `true` when the message role is [`MessageRole::User`].
    pub fn is_user(&self) -> bool {
        matches!(self.role, MessageRole::User)
    }

    /// Returns `true` when the message role is [`MessageRole::Assistant`].
    pub fn is_assistant(&self) -> bool {
        matches!(self.role, MessageRole::Assistant)
    }
}
/// Structured-output request attached to a generation config.
///
/// Serializes as an object with a `type` key and, when present, a `schema` key.
#[derive(Debug, Clone, Serialize)]
pub struct ResponseFormat {
/// Format tag; written to JSON under the key `type`.
#[serde(rename = "type")]
pub format_type: String,
/// Optional schema value; left out of the JSON entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub schema: Option<Value>,
}
impl ResponseFormat {
    /// Requests output tagged `json_object` with no schema attached.
    pub fn json_object() -> Self {
        Self {
            schema: None,
            format_type: String::from("json_object"),
        }
    }

    /// Requests output tagged `json_object` with `schema` attached.
    ///
    /// NOTE(review): the tag is `json_object` here as well, not `json_schema`;
    /// confirm this is what the target server expects.
    pub fn json_schema(schema: Value) -> Self {
        Self {
            schema: Some(schema),
            ..Self::json_object()
        }
    }
}
/// Optional generation parameters sent alongside the chat messages.
///
/// Every field is an `Option`. `to_json` emits only the fields that are `Some`,
/// each under a key equal to the field's own name, and `merge` layers one
/// config over another field by field.
///
/// NOTE(review): the field names presumably mirror the request parameters of
/// the target LLM server; confirm against that server's API documentation.
#[derive(Debug, Clone, Default, Builder)]
pub struct GenerationConfig {
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub top_k: Option<u32>,
pub min_p: Option<f32>,
pub typical_p: Option<f32>,
pub frequency_penalty: Option<f32>,
pub presence_penalty: Option<f32>,
pub repeat_penalty: Option<f32>,
pub repeat_last_n: Option<u32>,
pub tfs_z: Option<f32>,
pub mirostat_mode: Option<u32>,
pub mirostat_tau: Option<f32>,
pub mirostat_eta: Option<f32>,
pub samplers: Option<Vec<String>>,
pub seed: Option<i32>,
pub stop: Option<Vec<String>>,
// NOTE(review): `LlmClient::call_api` decodes the reply as one JSON document;
// confirm whether `stream: Some(true)` is actually supported end to end.
pub stream: Option<bool>,
pub echo: Option<bool>,
pub logprobs: Option<u32>,
pub top_logprobs: Option<u32>,
pub logit_bias: Option<HashMap<String, f32>>,
/// Serialized under the key `response_format` via `ResponseFormat`'s `Serialize` impl.
pub response_format: Option<ResponseFormat>,
pub grammar: Option<String>,
// NOTE(review): `LlmClient::call_api` only reads `/choices/0`, so any
// additional choices requested through `n` would be ignored — confirm intent.
pub n: Option<u32>,
pub penalize_nl: Option<bool>,
pub ignore_eos: Option<bool>,
}
impl GenerationConfig {
pub fn merge(&self, override_config: &GenerationConfig) -> Self {
Self {
max_tokens: override_config.max_tokens.or(self.max_tokens),
temperature: override_config.temperature.or(self.temperature),
top_p: override_config.top_p.or(self.top_p),
top_k: override_config.top_k.or(self.top_k),
min_p: override_config.min_p.or(self.min_p),
typical_p: override_config.typical_p.or(self.typical_p),
frequency_penalty: override_config.frequency_penalty.or(self.frequency_penalty),
presence_penalty: override_config.presence_penalty.or(self.presence_penalty),
repeat_penalty: override_config.repeat_penalty.or(self.repeat_penalty),
repeat_last_n: override_config.repeat_last_n.or(self.repeat_last_n),
tfs_z: override_config.tfs_z.or(self.tfs_z),
mirostat_mode: override_config.mirostat_mode.or(self.mirostat_mode),
mirostat_tau: override_config.mirostat_tau.or(self.mirostat_tau),
mirostat_eta: override_config.mirostat_eta.or(self.mirostat_eta),
samplers: override_config
.samplers
.clone()
.or_else(|| self.samplers.clone()),
seed: override_config.seed.or(self.seed),
stop: override_config.stop.clone().or_else(|| self.stop.clone()),
stream: override_config.stream.or(self.stream),
echo: override_config.echo.or(self.echo),
logprobs: override_config.logprobs.or(self.logprobs),
top_logprobs: override_config.top_logprobs.or(self.top_logprobs),
logit_bias: override_config
.logit_bias
.clone()
.or_else(|| self.logit_bias.clone()),
response_format: override_config
.response_format
.clone()
.or_else(|| self.response_format.clone()),
grammar: override_config
.grammar
.clone()
.or_else(|| self.grammar.clone()),
n: override_config.n.or(self.n),
penalize_nl: override_config.penalize_nl.or(self.penalize_nl),
ignore_eos: override_config.ignore_eos.or(self.ignore_eos),
}
}
pub fn to_json(&self) -> Value {
let mut obj = serde_json::Map::new();
macro_rules! add_if_some {
($field:ident) => {
if let Some(val) = &self.$field {
obj.insert(stringify!($field).to_string(), json!(val));
}
};
}
add_if_some!(max_tokens);
add_if_some!(temperature);
add_if_some!(top_p);
add_if_some!(top_k);
add_if_some!(min_p);
add_if_some!(typical_p);
add_if_some!(frequency_penalty);
add_if_some!(presence_penalty);
add_if_some!(repeat_penalty);
add_if_some!(repeat_last_n);
add_if_some!(tfs_z);
add_if_some!(mirostat_mode);
add_if_some!(mirostat_tau);
add_if_some!(mirostat_eta);
add_if_some!(samplers);
add_if_some!(seed);
add_if_some!(stop);
add_if_some!(stream);
add_if_some!(echo);
add_if_some!(logprobs);
add_if_some!(top_logprobs);
add_if_some!(logit_bias);
add_if_some!(response_format);
add_if_some!(grammar);
add_if_some!(n);
add_if_some!(penalize_nl);
add_if_some!(ignore_eos);
Value::Object(obj)
}
}
/// A conversation: an optional system prompt plus the ordered message history.
#[derive(Debug, Clone, Builder)]
pub struct Chat {
/// Optional system prompt; `to_json` emits it first with role `"system"`.
pub system_prompt: Option<String>,
/// Conversation turns in order; the builder defaults this to an empty list.
#[builder(default)]
pub messages: Vec<Message>,
}
impl Chat {
    /// Creates an empty chat: no system prompt and no messages.
    pub fn new() -> Self {
        Self {
            messages: Vec::new(),
            system_prompt: None,
        }
    }

    /// Appends `message` to the end of the history.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }

    /// Renders the chat as a list of `{"role", "content"}` JSON objects.
    ///
    /// The system prompt, when set, comes first with role `"system"`; the
    /// stored messages follow in order.
    pub fn to_json(&self) -> Vec<Value> {
        // `Option::iter` yields zero or one item, so the system entry is only
        // produced when a prompt is set.
        let system_entry = self.system_prompt.iter().map(|system| {
            json!({
                "role": "system",
                "content": system
            })
        });
        let history = self.messages.iter().map(|message| {
            json!({
                "role": message.role,
                "content": message.content
            })
        });
        system_entry.chain(history).collect()
    }

    /// Number of entries `to_json` would produce: the messages plus one when a
    /// system prompt is set.
    pub fn len_with_system_prompt(&self) -> usize {
        let system_entries = usize::from(self.system_prompt.is_some());
        system_entries + self.messages.len()
    }
}
/// Client for a chat-completion HTTP endpoint.
///
/// Holds the HTTP client, the default endpoint and credentials, and the
/// baseline generation settings applied to every request.
#[derive(Debug, Clone, Builder)]
#[builder(on(String, into))]
pub struct LlmClient {
/// HTTP client used to send requests.
client: Client,
/// Endpoint URL used when a call does not supply its own.
api_url: String,
/// Optional API key; when present it is sent as a `Bearer` authorization header.
api_key: Option<String>,
/// Baseline generation settings; per-call configs are merged over these.
/// When not set on the builder: 512 max tokens, temperature 0.7,
/// frequency penalty 0.0, and everything else unset.
#[builder(default = GenerationConfig {
max_tokens: Some(512),
temperature: Some(0.7),
frequency_penalty: Some(0.0),
..Default::default()
})]
default_config: GenerationConfig,
}
#[bon]
impl LlmClient {
pub fn builder_with_default_client() -> Result<LlmClientBuilder<llm_client_builder::SetClient>>
{
Ok(LlmClient::builder().client(
reqwest::ClientBuilder::new()
.connect_timeout(Duration::from_secs(10))
.read_timeout(Duration::from_secs(360))
.build()?,
))
}
#[builder]
pub async fn generate_response_and_add_to_chat(
&self,
chat: &mut Chat,
config: Option<GenerationConfig>,
api_url: Option<&str>,
api_key: Option<&str>,
) -> Result<String> {
let messages = chat.to_json();
let response = self.call_api(messages, config, api_url, api_key).await?;
chat.add_message(Message::new(MessageRole::Assistant, &response));
Ok(response)
}
#[builder]
pub async fn generate_response(
&self,
chat: &Chat,
config: Option<GenerationConfig>,
api_url: Option<&str>,
api_key: Option<&str>,
) -> Result<String> {
let messages = chat.to_json();
self.call_api(messages, config, api_url, api_key).await
}
#[cfg(feature = "mock_llm_api")]
async fn call_api(
&self,
messages: Vec<Value>,
_config: Option<GenerationConfig>,
_api_url: Option<&str>,
_api_key: Option<&str>,
) -> Result<String> {
Ok(messages
.iter()
.map(|message| message.to_string())
.collect::<Vec<String>>()
.join("\n"))
}
#[cfg(not(feature = "mock_llm_api"))]
async fn call_api(
&self,
messages: Vec<Value>,
config: Option<GenerationConfig>,
api_url: Option<&str>,
api_key: Option<&str>,
) -> Result<String> {
use crate::error::Error;
let api_url = api_url.unwrap_or(&self.api_url);
let api_key = api_key.or(self.api_key.as_deref());
let final_config = if let Some(override_config) = config {
self.default_config.merge(&override_config)
} else {
self.default_config.clone()
};
let mut body = json!({
"messages": messages,
});
if let Value::Object(config_map) = final_config.to_json() {
if let Value::Object(body_map) = &mut body {
body_map.extend(config_map);
}
}
let response = self
.client
.post(api_url)
.pipe(|req| {
if let Some(key) = api_key {
req.header("Authorization", format!("Bearer {}", key))
} else {
req
}
})
.header("Content-Type", "application/json")
.json(&body)
.send()
.await?
.json::<Value>()
.await?;
let content = response
.pointer("/choices/0/message/content")
.ok_or(Error::IssueWithLlmApiReturnedJson)?
.as_str()
.ok_or(Error::FailedToExtractResponseContent)?
.to_string();
Ok(content)
}
}