lemon_llm/
lib.rs

//! LLM nodes for [lemon-graph](https://github.com/unavi-xyz/lemon/tree/main/crates/lemon-graph).
//!
//! # Usage
//!
//! ```
//! use std::sync::Arc;
//!
//! use lemon_graph::{Graph, Executor, nodes::{NodeWrapper, LogNode}};
//! use lemon_llm::{ollama::{OllamaBackend, OllamaModel}, LlmBackend, LlmNode, LlmWeight};
//!
//! #[tokio::main]
//! async fn main() {
//!    let mut graph = Graph::new();
//!
//!    // Create a new Ollama backend.
//!    let backend = Arc::new(OllamaBackend {
//!        model: OllamaModel::Mistral7B,
//!        ..Default::default()
//!    });
//!
//!    // Create an LLM node using our Ollama backend.
//!    let llm = LlmNode::new(&mut graph, LlmWeight::new(backend.clone()));
//!
//!    // Set the prompt manually.
//!    let prompt = llm.prompt(&graph).unwrap();
//!    prompt.set_value(
//!        &mut graph,
//!        "Tell me your favorite lemon fact.".to_string().into(),
//!    );
//!
//!    // Connect the response to a log node.
//!    let response = llm.response(&graph).unwrap();
//!
//!    let log = LogNode::new(&mut graph);
//!    log.run_after(&mut graph, llm.0);
//!
//!    let message = log.message(&graph).unwrap();
//!    message.set_input(&mut graph, Some(response));
//!
//!    // Execute the graph (commented out here, as it would call the live backend).
//!    // Executor::execute(&mut graph, llm.0).await.unwrap();
//! }
//! ```

use std::{future::Future, sync::Arc};

use lemon_graph::{
    nodes::{AsyncNode, GetStoreError, NodeError, NodeWrapper, StoreWrapper},
    Graph, GraphEdge, GraphNode, Value,
};
use petgraph::graph::NodeIndex;
use thiserror::Error;

#[cfg(feature = "ollama")]
pub mod ollama;
#[cfg(feature = "replicate")]
pub mod replicate;
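/// A graph node that sends a prompt to an [`LlmBackend`] and stores its response.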
#[derive(Debug, Clone, Copy)]
pub struct LlmNode(pub NodeIndex);

impl From<LlmNode> for NodeIndex {
    fn from(value: LlmNode) -> Self {
        value.0
    }
}

impl NodeWrapper for LlmNode {}

impl LlmNode {
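    /// Adds an LLM node to the graph, along with a string input store for the
    /// prompt and a string output store for the response, each connected via a
    /// `GraphEdge::DataMap(0)` edge.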
    pub fn new<T: LlmBackend>(graph: &mut Graph, weight: LlmWeight<T>) -> Self {
        let index = graph.add_node(GraphNode::AsyncNode(Box::new(weight)));

        let input = graph.add_node(GraphNode::Store(Value::String(Default::default())));
        graph.add_edge(input, index, GraphEdge::DataMap(0));

        let output = graph.add_node(GraphNode::Store(Value::String(Default::default())));
        graph.add_edge(index, output, GraphEdge::DataMap(0));

        Self(index)
    }

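    /// Returns the prompt input store.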
    pub fn prompt(&self, graph: &Graph) -> Result<StoreWrapper, GetStoreError> {
        self.input_stores(graph)
            .next()
            .ok_or(GetStoreError::NoStore)
    }

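    /// Returns the response output store.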
    pub fn response(&self, graph: &Graph) -> Result<StoreWrapper, GetStoreError> {
        self.output_stores(graph)
            .next()
            .ok_or(GetStoreError::NoStore)
    }
}

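/// Error returned by [`LlmBackend::generate`].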
#[derive(Debug, Error)]
pub enum GenerateError {
    #[error("Backend error: {0}")]
    BackendError(String),
}

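/// A text-generation backend, such as the Ollama or Replicate backends
/// provided behind this crate's feature flags.
///
/// Implementing this trait is enough to plug a custom provider into
/// [`LlmNode`]. A minimal sketch (the `EchoBackend` type below is
/// illustrative, not part of this crate):
///
/// ```
/// use std::future::Future;
///
/// use lemon_llm::{GenerateError, LlmBackend};
///
/// struct EchoBackend;
///
/// impl LlmBackend for EchoBackend {
///     fn generate(&self, prompt: &str) -> impl Future<Output = Result<String, GenerateError>> {
///         // A real backend would call out to a model here.
///         let response = format!("echo: {prompt}");
///         async move { Ok(response) }
///     }
/// }
/// ```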
pub trait LlmBackend {
    /// Generate a response to the given prompt.
    fn generate(&self, prompt: &str) -> impl Future<Output = Result<String, GenerateError>>;
}

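/// Node weight holding a shared [`LlmBackend`]; this is the async node that
/// [`LlmNode::new`] adds to the graph.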
pub struct LlmWeight<T: LlmBackend + 'static> {
    pub backend: Arc<T>,
}

impl<T: LlmBackend> LlmWeight<T> {
    pub fn new(backend: Arc<T>) -> Self {
        Self { backend }
    }
}

impl<T: LlmBackend> AsyncNode for LlmWeight<T> {
    fn run(
        &self,
        inputs: Vec<Value>,
    ) -> Box<dyn Future<Output = Result<Vec<Value>, NodeError>> + Unpin> {
        // Clone the backend handle so the returned future owns it.
        let backend = self.backend.clone();

        Box::new(Box::pin(async move {
            // The prompt arrives as the first (and only) input value.
            let prompt = match inputs.first() {
                Some(Value::String(prompt)) => prompt.clone(),
                Some(v) => return Err(NodeError::ConversionError(v.clone())),
                None => return Err(NodeError::MissingInput(0)),
            };

            let response = backend
                .generate(&prompt)
                .await
                .map_err(|e| NodeError::InternalError(format!("Failed to generate: {}", e)))?;

            Ok(vec![Value::String(response)])
        }))
    }
}