langchain_rust/language_models/llm.rs

1use std::pin::Pin;
2
3use async_trait::async_trait;
4use futures::Stream;
5
6use crate::schemas::{Message, StreamData};
7
8use super::{options::CallOptions, GenerateResult, LLMError};
9
10#[async_trait]
11pub trait LLM: Sync + Send + LLMClone {
12    async fn generate(&self, messages: &[Message]) -> Result<GenerateResult, LLMError>;
13    async fn invoke(&self, prompt: &str) -> Result<String, LLMError> {
14        self.generate(&[Message::new_human_message(prompt)])
15            .await
16            .map(|res| res.generation)
17    }
18    async fn stream(
19        &self,
20        _messages: &[Message],
21    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, LLMError>> + Send>>, LLMError>;
22
23    /// This is usefull when you want to create a chain and override
24    /// LLM options
25    fn add_options(&mut self, _options: CallOptions) {
26        // No action taken
27    }
28    //This is usefull when using non chat models
29    fn messages_to_string(&self, messages: &[Message]) -> String {
30        messages
31            .iter()
32            .map(|m| format!("{:?}: {}", m.message_type, m.content))
33            .collect::<Vec<String>>()
34            .join("\n")
35    }
36}
37
38pub trait LLMClone {
39    fn clone_box(&self) -> Box<dyn LLM>;
40}
41
42impl<T> LLMClone for T
43where
44    T: 'static + LLM + Clone,
45{
46    fn clone_box(&self) -> Box<dyn LLM> {
47        Box::new(self.clone())
48    }
49}
50
51impl<L> From<L> for Box<dyn LLM>
52where
53    L: 'static + LLM,
54{
55    fn from(llm: L) -> Self {
56        Box::new(llm)
57    }
58}