use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use crate::schemas::{Message, StreamData};
use super::{invocation_config::InvocationConfig, options::CallOptions, GenerateResult, LLMError};
/// A large language model backend.
///
/// Implementors must provide [`LLM::generate`] and [`LLM::stream`]. Every other
/// method has a default implementation layered on top of those (directly, or via
/// [`LLM::invoke`]), so overriding is optional. The [`LLMClone`] supertrait lets a
/// `Box<dyn LLM>` be duplicated through `clone_box`.
#[async_trait]
pub trait LLM: Sync + Send + LLMClone {
    /// Generates a response for the given conversation.
    ///
    /// # Errors
    ///
    /// Returns an [`LLMError`] if generation fails.
    async fn generate(&self, messages: &[Message]) -> Result<GenerateResult, LLMError>;

    /// Sends `prompt` as a single human message and returns only the generated text.
    ///
    /// The default implementation wraps `prompt` with `Message::new_human_message`,
    /// calls [`LLM::generate`], and extracts the `generation` field of the result.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`LLM::generate`].
    async fn invoke(&self, prompt: &str) -> Result<String, LLMError> {
        self.generate(&[Message::new_human_message(prompt)])
            .await
            .map(|res| res.generation)
    }

    /// Streams a response for the given conversation.
    ///
    /// On success, returns a pinned, boxed, `Send` stream of [`StreamData`] chunks.
    ///
    /// # Errors
    ///
    /// The outer `Result` can fail before any stream is produced. Each stream item
    /// is itself a `Result`, so failures can also surface mid-stream.
    async fn stream(
        &self,
        _messages: &[Message],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, LLMError>> + Send>>, LLMError>;

    /// Applies call options to this model.
    ///
    /// The default implementation ignores `_options` and does nothing. Implementors
    /// that support per-call configuration should override it.
    fn add_options(&mut self, _options: CallOptions) {}

    /// Renders `messages` as plain text, one message per line.
    ///
    /// Each line has the form `{message_type:?}: {content}`. An empty slice yields
    /// an empty string, and there is no trailing newline.
    fn messages_to_string(&self, messages: &[Message]) -> String {
        messages
            .iter()
            .map(|m| format!("{:?}: {}", m.message_type, m.content))
            .collect::<Vec<String>>()
            .join("\n")
    }

    /// Like [`LLM::invoke`], but accepts an optional [`InvocationConfig`].
    ///
    /// The default implementation ignores `_config` and simply calls [`LLM::invoke`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`LLM::invoke`].
    async fn invoke_with_config(
        &self,
        prompt: &str,
        _config: Option<&InvocationConfig>,
    ) -> Result<String, LLMError> {
        self.invoke(prompt).await
    }

    /// Like [`LLM::generate`], but accepts an optional [`InvocationConfig`].
    ///
    /// The default implementation ignores `_config` and simply calls [`LLM::generate`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`LLM::generate`].
    async fn generate_with_config(
        &self,
        messages: &[Message],
        _config: Option<&InvocationConfig>,
    ) -> Result<GenerateResult, LLMError> {
        self.generate(messages).await
    }

    /// Invokes every prompt in order and collects one result per prompt.
    ///
    /// Prompts are awaited sequentially, one at a time. A failing prompt does not
    /// abort the batch: its error is stored as an `Err` at the matching index of the
    /// returned vector.
    ///
    /// # Errors
    ///
    /// The default implementation never returns the outer `Err`; it always yields
    /// `Ok` with `prompts.len()` entries.
    async fn batch(&self, prompts: &[&str]) -> Result<Vec<Result<String, LLMError>>, LLMError> {
        // Exactly one result per prompt, so size the buffer up front.
        let mut results = Vec::with_capacity(prompts.len());
        for prompt in prompts {
            results.push(self.invoke(prompt).await);
        }
        Ok(results)
    }

    /// Runs [`LLM::generate`] on every message set in order and collects one result
    /// per set.
    ///
    /// Message sets are awaited sequentially, one at a time. A failing set does not
    /// abort the batch: its error is stored as an `Err` at the matching index of the
    /// returned vector.
    ///
    /// # Errors
    ///
    /// The default implementation never returns the outer `Err`; it always yields
    /// `Ok` with `message_sets.len()` entries.
    async fn batch_generate(
        &self,
        message_sets: &[&[Message]],
    ) -> Result<Vec<Result<GenerateResult, LLMError>>, LLMError> {
        // Exactly one result per message set, so size the buffer up front.
        let mut results = Vec::with_capacity(message_sets.len());
        for messages in message_sets {
            results.push(self.generate(messages).await);
        }
        Ok(results)
    }
}
/// Object-safe cloning for [`LLM`] trait objects.
///
/// `Clone` cannot be a supertrait of `LLM` directly: `Clone` requires `Sized`,
/// which would rule out `dyn LLM`. This trait instead hands back the copy already
/// boxed as a trait object.
pub trait LLMClone {
    /// Returns a heap-allocated clone of `self` as a `Box<dyn LLM>`.
    fn clone_box(&self) -> Box<dyn LLM>;
}
/// Blanket implementation: every `'static` [`LLM`] that is also `Clone` can
/// duplicate itself into a `Box<dyn LLM>`.
impl<T: LLM + Clone + 'static> LLMClone for T {
    fn clone_box(&self) -> Box<dyn LLM> {
        // Clone the concrete value first, then erase its type by boxing it.
        let copy: T = self.clone();
        Box::new(copy)
    }
}
/// Converts any `'static` [`LLM`] into a boxed trait object, so a concrete model
/// can be turned into a `Box<dyn LLM>` with `.into()`.
impl<L: LLM + 'static> From<L> for Box<dyn LLM> {
    fn from(model: L) -> Self {
        Box::new(model)
    }
}