prompt_graph_core/
graph_definition.rs

1use anyhow::anyhow;
2use prost::Message;
3use serde::{Deserialize, Serialize};
4
5use crate::proto as dsl;
6use crate::proto::{ItemCore, Query};
7use crate::proto::prompt_graph_node_loader::LoadFrom;
8
9/// Maps a string to a supported vector database type
10fn map_string_to_vector_database(encoding: &str) -> anyhow::Result<dsl::SupportedVectorDatabase> {
11    match encoding {
12        "IN_MEMORY" => Ok(dsl::SupportedVectorDatabase::InMemory),
13        "CHROMA" => Ok(dsl::SupportedVectorDatabase::Chroma),
14        "PINECONEDB" => Ok(dsl::SupportedVectorDatabase::Pineconedb),
15        "QDRANT" => Ok(dsl::SupportedVectorDatabase::Qdrant),
16        _ => {
17            Err(anyhow!("Unknown vector database: {}", encoding))
18        },
19    }
20}
21
22/// Maps a string to a supported embedding model type
23fn map_string_to_embedding_model(encoding: &str) -> anyhow::Result<dsl::SupportedEmebddingModel> {
24    match encoding {
25        "TEXT_EMBEDDING_ADA_002" => Ok(dsl::SupportedEmebddingModel::TextEmbeddingAda002),
26        "TEXT_SEARCH_ADA_DOC_001" => Ok(dsl::SupportedEmebddingModel::TextSearchAdaDoc001),
27        _ => {
28            Err(anyhow!("Unknown embedding model: {}", encoding))
29        },
30    }
31}
32
33/// Maps a string to a supported chat model type
34fn map_string_to_chat_model(encoding: &str) -> anyhow::Result<dsl::SupportedChatModel> {
35    match encoding {
36        "GPT_4" => Ok(dsl::SupportedChatModel::Gpt4),
37        "GPT_4_0314" => Ok(dsl::SupportedChatModel::Gpt40314),
38        "GPT_4_32K" => Ok(dsl::SupportedChatModel::Gpt432k),
39        "GPT_4_32K_0314" => Ok(dsl::SupportedChatModel::Gpt432k0314),
40        "GPT_3_5_TURBO" => Ok(dsl::SupportedChatModel::Gpt35Turbo),
41        "GPT_3_5_TURBO_0301" => Ok(dsl::SupportedChatModel::Gpt35Turbo0301),
42        _ => {
43            Err(anyhow!("Unknown chat model: {}", encoding))
44        },
45    }
46}
47
48/// Maps a string to a supported source language type
49fn map_string_to_supported_source_langauge(encoding: &str) -> anyhow::Result<dsl::SupportedSourceCodeLanguages> {
50    match encoding {
51        "DENO" => Ok(dsl::SupportedSourceCodeLanguages::Deno),
52        "STARLARK" => Ok(dsl::SupportedSourceCodeLanguages::Starlark),
53        _ => {
54            Err(anyhow!("Unknown source language: {}", encoding))
55        },
56    }
57}
58
59/// Converts a string representing a query definition to a Query type
60fn create_query(query_def: Option<String>) -> dsl::Query {
61     dsl::Query {
62        query: query_def.map(|d|d),
63    }
64}
65
66/// Converts a string representing an output definition to an OutputType type
67fn create_output(output_def: &str) -> Option<dsl::OutputType> {
68    Some(dsl::OutputType {
69        output: output_def.to_string(),
70    })
71}
72
73#[derive(Debug, Serialize, Deserialize)]
74pub enum SourceNodeType {
75    Code(String, String, bool),
76    S3(String),
77    Zipfile(Vec<u8>),
78}
79
80#[derive(Debug, Serialize, Deserialize)]
81pub struct DefinitionGraph {
82    internal: dsl::File,
83}
84
85
86/// A graph definition or DefinitionGraph defines a graph of executable nodes connected by edges or 'triggers'.
87/// The graph is defined in a DSL (domain specific language) that is compiled into a binary formatted File that can be
88/// executed by the prompt-graph-core runtime.
89impl DefinitionGraph {
90
91    /// Returns the File object representing this graph definition
92    pub fn get_file(&self) -> &dsl::File {
93        &self.internal
94    }
95
96    /// Returns an empty graph definition
97    pub fn zero() -> Self {
98        Self {
99            internal: dsl::File::default()
100        }
101    }
102
103    /// Sets this graph definition to read from & write to the given File object
104    pub fn from_file(file: dsl::File) -> Self {
105        Self {
106            internal: file
107        }
108    }
109
110    /// Store the given bytes (representing protobuf graph definition) as a
111    /// new File object and associate this graph definition with it
112    pub fn new(bytes: &[u8]) -> Self {
113        Self {
114            internal: dsl::File::decode(bytes).unwrap()
115        }
116    }
117
118    /// Read and return the nodes from internal File object
119    pub(crate) fn get_nodes(&self) -> &Vec<dsl::Item> {
120        &self.internal.nodes
121    }
122
123    /// Read and return a mutable collection of nodes from internal File object
124    pub(crate) fn get_nodes_mut(&mut self) -> &Vec<dsl::Item> {
125        &self.internal.nodes
126    }
127
128    /// Serialize the internal File object to bytes and return them
129    pub(crate) fn serialize(&self) -> Vec<u8> {
130        let mut buffer = Vec::new();
131        self.internal.encode(&mut buffer).unwrap();
132        buffer
133    }
134
135    /// Push a given node (defined as Item type) to the internal graph definition
136    pub fn register_node(&mut self, item: dsl::Item) {
137        self.internal.nodes.push(item);
138    }
139
140    /// Push a given node (defined as bytes) to the internal graph definition
141    pub fn register_node_bytes(&mut self, item: &[u8]) {
142        let item = dsl::Item::decode(item).unwrap();
143        self.internal.nodes.push(item);
144    }
145}
146
147
148#[deprecated(since="0.1.0", note="do not use")]
149pub fn create_entrypoint_query(
150    query_def: Option<String>
151) -> dsl::Item {
152    let query_element = dsl::Query {
153        query: query_def.map(|x| x.to_string()),
154    };
155    let _node = dsl::PromptGraphNodeCode {
156        source: None,
157    };
158    dsl::Item {
159        core: Some(ItemCore {
160            name: "RegistrationCodeNode".to_string(),
161            triggers: vec![query_element],
162            output: Default::default(),
163            output_tables: vec![],
164        }),
165        item: None,
166    }
167}
168
169/// Takes in common node parameters and returns a fulfilled node type (a dsl::Item type)
170pub fn create_node_parameter(
171    name: String,
172    output_def: String
173) -> dsl::Item {
174    dsl::Item {
175        core: Some(ItemCore {
176            name: name.to_string(),
177            output: create_output(&output_def),
178            triggers: vec![Query { query: None }],
179            output_tables: vec![],
180        }),
181        item: Some(dsl::item::Item::NodeParameter(dsl::PromptGraphParameterNode {
182        })),
183    }
184}
185
186/// Returns a Map type node, which maps a Path (key) to a given String (value)
187pub fn create_op_map(
188    name: String,
189    query_defs: Vec<Option<String>>,
190    path: String,
191    output_tables: Vec<String>
192) -> dsl::Item {
193    dsl::Item {
194        core: Some(ItemCore {
195            name: name.to_string(),
196            triggers: query_defs.into_iter().map(create_query).collect(),
197            // TODO: needs to have the type of the input
198            output: create_output(r#"
199                {
200                    result: String
201                }
202            "#),
203            output_tables,
204        }),
205        item: Some(dsl::item::Item::Map(dsl::PromptGraphMap {
206            path: path.to_string(),
207        })),
208    }
209}
210
211// TODO: automatically wire these into prompt nodes that support function calling
212// TODO: https://platform.openai.com/docs/guides/gpt/function-calling
213/// Takes in executable code and returns a node that executes said code when triggered
214/// This executable code can take the format of:
215/// - a raw string of code in a supported language
216/// - a path to an S3 bucket containing code in a supported language
217/// - a zip file containing code in a supported language
218pub fn create_code_node(
219    name: String,
220    query_defs: Vec<Option<String>>,
221    output_def: String,
222    source_type: SourceNodeType,
223    output_tables: Vec<String>,
224) -> dsl::Item {
225    let source = match source_type {
226        SourceNodeType::Code(language, code, template) => {
227            // https://github.com/denoland/deno/discussions/17345
228            // https://github.com/a-poor/js-in-rs/blob/main/src/main.rs
229            dsl::prompt_graph_node_code::Source::SourceCode( dsl::PromptGraphNodeCodeSourceCode{
230                template,
231                language: map_string_to_supported_source_langauge(&language).unwrap() as i32,
232                source_code: code.to_string(),
233            })
234        }
235        SourceNodeType::S3(path) => {
236            dsl::prompt_graph_node_code::Source::S3Path(path)
237        }
238        SourceNodeType::Zipfile(file) => {
239            dsl::prompt_graph_node_code::Source::Zipfile(file)
240        }
241    };
242
243    dsl::Item {
244        core: Some(ItemCore {
245            name: name.to_string(),
246            triggers: query_defs.into_iter().map(create_query).collect(),
247            output: create_output(&output_def),
248            output_tables
249        }),
250        item: Some(dsl::item::Item::NodeCode(dsl::PromptGraphNodeCode{
251            source: Some(source),
252        })),
253    }
254}
255
256
257
258// TODO: automatically wire these into prompt nodes that support function calling
259// TODO: https://platform.openai.com/docs/guides/gpt/function-calling
260/// Returns a custom node that executes a given function
261/// When registering a custom node in the SDK, you provide an in-language function and
262/// tell chidori to register that function under the given "type_name".
263/// This function executed is then executed in the graph
264/// when referenced by this "type_name" parameter
265pub fn create_custom_node(
266    name: String,
267    query_defs: Vec<Option<String>>,
268    output_def: String,
269    type_name: String,
270    output_tables: Vec<String>
271) -> dsl::Item {
272    dsl::Item {
273        core: Some(ItemCore {
274            name: name.to_string(),
275            triggers: query_defs.into_iter().map(create_query).collect(),
276            output: create_output(&output_def),
277            output_tables
278        }),
279        item: Some(dsl::item::Item::NodeCustom(dsl::PromptGraphNodeCustom{
280            type_name,
281        })),
282    }
283}
284
285/// Returns a node that, when triggered, echoes back its input for easier querying
286pub fn create_observation_node(
287    name: String,
288    query_defs: Vec<Option<String>>,
289    output_def: String,
290    output_tables: Vec<String>
291) -> dsl::Item {
292    dsl::Item {
293        core: Some(ItemCore {
294            name: name.to_string(),
295            triggers: query_defs.into_iter().map(create_query).collect(),
296            output: create_output(&output_def),
297            output_tables
298        }),
299        item: Some(dsl::item::Item::NodeObservation(dsl::PromptGraphNodeObservation{
300            integration: "".to_string(),
301        })),
302    }
303}
304
305/// Returns a node that can perform some READ/WRITE/DELETE operation on
306/// a specified Vector database, using the specified configuration options
307/// (options like the embedding_model to use and collection_name namespace to query within)
308pub fn create_vector_memory_node(
309    name: String,
310    query_defs: Vec<Option<String>>,
311    output_def: String,
312    action: String,
313    embedding_model: String,
314    template: String,
315    db_vendor: String,
316    collection_name: String,
317    output_tables: Vec<String>
318) -> anyhow::Result<dsl::Item> {
319    let model = dsl::prompt_graph_node_memory::EmbeddingModel::Model(map_string_to_embedding_model(&embedding_model)? as i32);
320    let vector_db = dsl::prompt_graph_node_memory::VectorDbProvider::Db(map_string_to_vector_database(&db_vendor)? as i32);
321
322    let action = match action.as_str() {
323        "READ" => {
324            dsl::MemoryAction::Read as i32
325        },
326        "WRITE" => {
327            dsl::MemoryAction::Write as i32
328        },
329        "DELETE" => {
330            dsl::MemoryAction::Delete as i32
331        }
332        _ => { unreachable!("Invalid action") }
333    };
334
335    Ok(dsl::Item {
336        core: Some(ItemCore {
337            name: name.to_string(),
338            triggers: query_defs.into_iter().map(create_query).collect(),
339            output: create_output(&output_def),
340            output_tables
341        }),
342        item: Some(dsl::item::Item::NodeMemory(dsl::PromptGraphNodeMemory{
343            collection_name: collection_name,
344            action,
345            embedding_model: Some(model),
346            template: template,
347            vector_db_provider: Some(vector_db),
348        })),
349    })
350}
351
352/// Returns a node that can implement logic from another graph definition
353/// This is useful for reusing logic across multiple graphs
354/// The graph definition to transclude is specified by either
355/// - a path to an S3 bucket containing a graph definition
356/// - raw bytes of a graph definition
357/// - a File object containing a graph definition
358pub fn create_component_node(
359    name: String,
360    query_defs: Vec<Option<String>>,
361    output_def: String,
362    output_tables: Vec<String>,
363) -> dsl::Item {
364    dsl::Item {
365        core: Some(ItemCore {
366            name: name.to_string(),
367            triggers: query_defs.into_iter().map(create_query).collect(),
368            output: create_output(&output_def),
369            output_tables
370        }),
371        item: Some(dsl::item::Item::NodeComponent(dsl::PromptGraphNodeComponent {
372            transclusion: None,
373        })),
374    }
375}
376
377/// Returns a node that can read bytes from a given source
378pub fn create_loader_node(
379    name: String,
380    query_defs: Vec<Option<String>>,
381    output_def: String,
382    load_from: LoadFrom,
383    output_tables: Vec<String>,
384) -> dsl::Item {
385    dsl::Item {
386        core: Some(ItemCore {
387            name: name.to_string(),
388            triggers: query_defs.into_iter().map(create_query).collect(),
389            output: create_output(&output_def),
390            output_tables
391        }),
392        item: Some(dsl::item::Item::NodeLoader(dsl::PromptGraphNodeLoader {
393            load_from: Some(load_from),
394        })),
395    }
396}
397
398/// Returns a node that, when triggered, performs an API call to a given language model endpoint,
399/// using the template parameter as the prompt input to the language model, and returns the result
400/// to the graph as a String type labeled "promptResult"
401pub fn create_prompt_node(
402    name: String,
403    query_defs: Vec<Option<String>>,
404    template: String,
405    model: String,
406    output_tables: Vec<String>,
407) -> anyhow::Result<dsl::Item> {
408    let chat_model = map_string_to_chat_model(&model)?;
409    let model = dsl::prompt_graph_node_prompt::Model::ChatModel(chat_model as i32);
410    // TODO: use handlebars Template object in order to inspect the contents of and validate the template against the query
411    // https://github.com/sunng87/handlebars-rust/blob/23ca8d76bee783bf72f627b4c4995d1d11008d17/src/template.rs#L963
412    // self.handlebars.register_template_string(name, template).unwrap();
413    // println!("{:?}", Template::compile(&template).unwrap());
414    Ok(dsl::Item {
415        core: Some(ItemCore {
416            name: name.to_string(),
417            triggers: query_defs.into_iter().map(create_query).collect(),
418            output: create_output(r#"
419              {
420                  promptResult: String
421              }
422            "#),
423            output_tables
424        }),
425        item: Some(dsl::item::Item::NodePrompt(dsl::PromptGraphNodePrompt{
426            template: template.to_string(),
427            model: Some(model),
428            // TODO: add output but set it to some sane defaults
429            temperature: 1.0,
430            top_p: 1.0,
431            max_tokens: 100,
432            presence_penalty: 0.0,
433            frequency_penalty: 0.0,
434            stop: vec![],
435        })),
436    })
437}