// langchain_rust/language_models/llm.rs
use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;

use crate::schemas::{Message, StreamData};

use super::{options::CallOptions, GenerateResult, LLMError};

10#[async_trait]
11pub trait LLM: Sync + Send + LLMClone {
12 async fn generate(&self, messages: &[Message]) -> Result<GenerateResult, LLMError>;
13 async fn invoke(&self, prompt: &str) -> Result<String, LLMError> {
14 self.generate(&[Message::new_human_message(prompt)])
15 .await
16 .map(|res| res.generation)
17 }
18 async fn stream(
19 &self,
20 _messages: &[Message],
21 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, LLMError>> + Send>>, LLMError>;
22
23 fn add_options(&mut self, _options: CallOptions) {
26 }
28 fn messages_to_string(&self, messages: &[Message]) -> String {
30 messages
31 .iter()
32 .map(|m| format!("{:?}: {}", m.message_type, m.content))
33 .collect::<Vec<String>>()
34 .join("\n")
35 }
36}
37
38pub trait LLMClone {
39 fn clone_box(&self) -> Box<dyn LLM>;
40}
41
42impl<T> LLMClone for T
43where
44 T: 'static + LLM + Clone,
45{
46 fn clone_box(&self) -> Box<dyn LLM> {
47 Box::new(self.clone())
48 }
49}
50
51impl<L> From<L> for Box<dyn LLM>
52where
53 L: 'static + LLM,
54{
55 fn from(llm: L) -> Self {
56 Box::new(llm)
57 }
58}