mod factory;
use std::{ops::Deref, sync::Arc};
use crate::{
provider::LlmBackend,
types::chat::{ChatCompletionRequest, ChatMessage},
};
pub use factory::BackendFactory;
/// Configuration consumed by [`ChatClient::new`].
///
/// Carries the model identifier that requests will target and an optional
/// system prompt that is attached to each request when present.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChatClientOptions {
    /// Model identifier placed on every request.
    model: String,
    /// Optional system prompt; `None` means no system prompt is configured.
    system_prompt: Option<String>,
}

impl ChatClientOptions {
    /// Creates options targeting `model` with no system prompt.
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            system_prompt: None,
        }
    }

    /// Returns these options with `system_prompt` set, replacing any previous prompt.
    ///
    /// Consumes `self`; ignoring the return value discards the change, hence `#[must_use]`.
    #[must_use = "builder methods consume `self`; use the returned options"]
    pub fn with_system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(system_prompt.into());
        self
    }

    /// Returns the configured model identifier.
    #[must_use]
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Returns the configured system prompt, or `None` if none was set.
    #[must_use]
    pub fn system_prompt(&self) -> Option<&str> {
        self.system_prompt.as_deref()
    }
}
/// A chat client pairing a shared [`LlmBackend`] with per-client request defaults.
///
/// The backend is stored in an `Arc`, so cloning a `ChatClient` shares the same
/// backend instance while copying the model name and system prompt.
#[derive(Clone)]
pub struct ChatClient {
    /// Model identifier used by [`ChatClient::create_request`].
    model: String,
    /// System prompt attached by [`ChatClient::create_request`] when `Some`.
    system_prompt: Option<String>,
    /// Shared backend implementation.
    backend: Arc<dyn LlmBackend>,
}

impl ChatClient {
    /// Builds a client from `backend` and the model/system-prompt defaults in `options`.
    #[must_use]
    pub fn new(backend: Arc<dyn LlmBackend>, options: ChatClientOptions) -> Self {
        Self {
            model: options.model,
            system_prompt: options.system_prompt,
            backend,
        }
    }

    /// Returns the model identifier used for new requests.
    #[must_use]
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Returns the system prompt used for new requests, if any.
    #[must_use]
    pub fn system_prompt(&self) -> Option<&str> {
        self.system_prompt.as_deref()
    }

    /// Returns this client targeting `model` instead of the current model.
    ///
    /// Consumes `self`; the returned client carries the change.
    #[must_use = "builder methods consume `self`; use the returned client"]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Returns this client with `system_prompt` set, replacing any previous prompt.
    ///
    /// Consumes `self`; the returned client carries the change.
    #[must_use = "builder methods consume `self`; use the returned client"]
    pub fn with_system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(system_prompt.into());
        self
    }

    /// Returns this client with no system prompt configured.
    ///
    /// Consumes `self`; the returned client carries the change.
    #[must_use = "builder methods consume `self`; use the returned client"]
    pub fn clear_system_prompt(mut self) -> Self {
        self.system_prompt = None;
        self
    }

    /// Builds a [`ChatCompletionRequest`] for `messages` using this client's model.
    ///
    /// If a system prompt is configured it is attached to the request; otherwise
    /// the request is returned as constructed by `ChatCompletionRequest::new`.
    /// The model and prompt strings are cloned because `self` is only borrowed.
    #[must_use = "the returned request must be passed to a backend method to have any effect"]
    pub fn create_request(&self, messages: Vec<ChatMessage>) -> ChatCompletionRequest {
        let mut request = ChatCompletionRequest::new(self.model.clone(), messages);
        if let Some(system_prompt) = &self.system_prompt {
            request = request.with_system_prompt(system_prompt.clone());
        }
        request
    }
}
/// Lets callers invoke the backend's trait methods directly on a `ChatClient`.
///
/// NOTE(review): this uses `Deref` for method forwarding rather than
/// smart-pointer semantics; an explicit accessor would be clearer, but
/// removing this impl would break existing callers.
impl Deref for ChatClient {
    type Target = dyn crate::LlmBackend;

    /// Borrows the backend held inside the shared `Arc`.
    fn deref(&self) -> &Self::Target {
        self.backend.as_ref()
    }
}
/// Debug output reports the backend family, the model name, and only whether a
/// system prompt is set. The prompt text itself is never printed.
impl std::fmt::Debug for ChatClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ChatClient");
        builder.field("family", &self.backend.family());
        builder.field("model", &self.model);
        builder.field("has_system_prompt", &self.system_prompt.is_some());
        builder.finish()
    }
}