1use std::{future::Future, sync::Arc};
46
47use lemon_graph::{
48 nodes::{AsyncNode, GetStoreError, NodeError, NodeWrapper, StoreWrapper},
49 Graph, GraphEdge, GraphNode, Value,
50};
51use petgraph::graph::NodeIndex;
52use thiserror::Error;
53use tracing::error;
54
55#[cfg(feature = "ollama")]
56pub mod ollama;
57#[cfg(feature = "replicate")]
58pub mod replicate;
59
60#[derive(Debug, Clone, Copy)]
61pub struct LlmNode(pub NodeIndex);
62
63impl From<LlmNode> for NodeIndex {
64 fn from(value: LlmNode) -> Self {
65 value.0
66 }
67}
68
69impl NodeWrapper for LlmNode {}
70
71impl LlmNode {
72 pub fn new<T: LlmBackend>(graph: &mut Graph, weight: LlmWeight<T>) -> Self {
73 let index = graph.add_node(GraphNode::AsyncNode(Box::new(weight)));
74
75 let input = graph.add_node(GraphNode::Store(Value::String(Default::default())));
76 graph.add_edge(input, index, GraphEdge::DataMap(0));
77
78 let output = graph.add_node(GraphNode::Store(Value::String(Default::default())));
79 graph.add_edge(index, output, GraphEdge::DataMap(0));
80
81 Self(index)
82 }
83
84 pub fn prompt(&self, graph: &Graph) -> Result<StoreWrapper, GetStoreError> {
85 self.input_stores(graph)
86 .next()
87 .ok_or(GetStoreError::NoStore)
88 }
89
90 pub fn response(&self, graph: &Graph) -> Result<StoreWrapper, GetStoreError> {
91 self.output_stores(graph)
92 .next()
93 .ok_or(GetStoreError::NoStore)
94 }
95}
96
97#[derive(Debug, Error)]
98pub enum GenerateError {
99 #[error("Backend error: {0}")]
100 BackendError(String),
101}
102
103pub trait LlmBackend {
104 fn generate(&self, prompt: &str) -> impl Future<Output = Result<String, GenerateError>>;
105}
106
107pub struct LlmWeight<T: LlmBackend + 'static> {
108 pub backend: Arc<T>,
109}
110
111impl<T: LlmBackend> LlmWeight<T> {
112 pub fn new(backend: Arc<T>) -> Self {
113 Self { backend }
114 }
115}
116
117impl<T: LlmBackend> AsyncNode for LlmWeight<T> {
118 fn run(
119 &self,
120 inputs: Vec<Value>,
121 ) -> Box<dyn Future<Output = Result<Vec<Value>, NodeError>> + Unpin> {
122 let backend = self.backend.clone();
123
124 Box::new(Box::pin(async move {
125 let prompt = match inputs.first() {
126 Some(Value::String(prompt)) => prompt.clone(),
127 Some(v) => return Err(NodeError::ConversionError(v.clone())),
128 None => return Err(NodeError::MissingInput(0)),
129 };
130
131 let response = backend
132 .generate(&prompt)
133 .await
134 .map_err(|e| NodeError::InternalError(format!("Failed to generate: {}", e)))?;
135
136 Ok(vec![Value::String(response)])
137 }))
138 }
139}