1use anyhow::anyhow;
2use prost::Message;
3use serde::{Deserialize, Serialize};
4
5use crate::proto as dsl;
6use crate::proto::{ItemCore, Query};
7use crate::proto::prompt_graph_node_loader::LoadFrom;
8
9fn map_string_to_vector_database(encoding: &str) -> anyhow::Result<dsl::SupportedVectorDatabase> {
11 match encoding {
12 "IN_MEMORY" => Ok(dsl::SupportedVectorDatabase::InMemory),
13 "CHROMA" => Ok(dsl::SupportedVectorDatabase::Chroma),
14 "PINECONEDB" => Ok(dsl::SupportedVectorDatabase::Pineconedb),
15 "QDRANT" => Ok(dsl::SupportedVectorDatabase::Qdrant),
16 _ => {
17 Err(anyhow!("Unknown vector database: {}", encoding))
18 },
19 }
20}
21
22fn map_string_to_embedding_model(encoding: &str) -> anyhow::Result<dsl::SupportedEmebddingModel> {
24 match encoding {
25 "TEXT_EMBEDDING_ADA_002" => Ok(dsl::SupportedEmebddingModel::TextEmbeddingAda002),
26 "TEXT_SEARCH_ADA_DOC_001" => Ok(dsl::SupportedEmebddingModel::TextSearchAdaDoc001),
27 _ => {
28 Err(anyhow!("Unknown embedding model: {}", encoding))
29 },
30 }
31}
32
33fn map_string_to_chat_model(encoding: &str) -> anyhow::Result<dsl::SupportedChatModel> {
35 match encoding {
36 "GPT_4" => Ok(dsl::SupportedChatModel::Gpt4),
37 "GPT_4_0314" => Ok(dsl::SupportedChatModel::Gpt40314),
38 "GPT_4_32K" => Ok(dsl::SupportedChatModel::Gpt432k),
39 "GPT_4_32K_0314" => Ok(dsl::SupportedChatModel::Gpt432k0314),
40 "GPT_3_5_TURBO" => Ok(dsl::SupportedChatModel::Gpt35Turbo),
41 "GPT_3_5_TURBO_0301" => Ok(dsl::SupportedChatModel::Gpt35Turbo0301),
42 _ => {
43 Err(anyhow!("Unknown chat model: {}", encoding))
44 },
45 }
46}
47
48fn map_string_to_supported_source_langauge(encoding: &str) -> anyhow::Result<dsl::SupportedSourceCodeLanguages> {
50 match encoding {
51 "DENO" => Ok(dsl::SupportedSourceCodeLanguages::Deno),
52 "STARLARK" => Ok(dsl::SupportedSourceCodeLanguages::Starlark),
53 _ => {
54 Err(anyhow!("Unknown source language: {}", encoding))
55 },
56 }
57}
58
59fn create_query(query_def: Option<String>) -> dsl::Query {
61 dsl::Query {
62 query: query_def.map(|d|d),
63 }
64}
65
66fn create_output(output_def: &str) -> Option<dsl::OutputType> {
68 Some(dsl::OutputType {
69 output: output_def.to_string(),
70 })
71}
72
/// Where the code of a code node comes from; consumed by `create_code_node`, which maps
/// each variant onto a `dsl::prompt_graph_node_code::Source`.
#[derive(Debug, Serialize, Deserialize)]
pub enum SourceNodeType {
    /// Inline source as `(language, code, template)`.
    ///
    /// `language` must be `"DENO"` or `"STARLARK"` (`create_code_node` panics otherwise);
    /// `code` becomes the source text and `template` is copied into the generated
    /// source-code message unchanged.
    Code(String, String, bool),
    /// A path string stored as the node's `S3Path` source (presumably an S3 object
    /// location — confirm against the runtime that resolves it).
    S3(String),
    /// Raw bytes stored as the node's `Zipfile` source (presumably zip-archive
    /// contents — confirm against the runtime that unpacks it).
    Zipfile(Vec<u8>),
}
79
/// A graph definition backed by the protobuf `dsl::File` message.
///
/// Nodes are appended with `register_node` / `register_node_bytes`, and the whole file
/// can be decoded from and encoded to protobuf bytes (`new` / `serialize`).
#[derive(Debug, Serialize, Deserialize)]
pub struct DefinitionGraph {
    // The wrapped protobuf message; its `nodes` field holds the registered items.
    internal: dsl::File,
}
84
85
86impl DefinitionGraph {
90
91 pub fn get_file(&self) -> &dsl::File {
93 &self.internal
94 }
95
96 pub fn zero() -> Self {
98 Self {
99 internal: dsl::File::default()
100 }
101 }
102
103 pub fn from_file(file: dsl::File) -> Self {
105 Self {
106 internal: file
107 }
108 }
109
110 pub fn new(bytes: &[u8]) -> Self {
113 Self {
114 internal: dsl::File::decode(bytes).unwrap()
115 }
116 }
117
118 pub(crate) fn get_nodes(&self) -> &Vec<dsl::Item> {
120 &self.internal.nodes
121 }
122
123 pub(crate) fn get_nodes_mut(&mut self) -> &Vec<dsl::Item> {
125 &self.internal.nodes
126 }
127
128 pub(crate) fn serialize(&self) -> Vec<u8> {
130 let mut buffer = Vec::new();
131 self.internal.encode(&mut buffer).unwrap();
132 buffer
133 }
134
135 pub fn register_node(&mut self, item: dsl::Item) {
137 self.internal.nodes.push(item);
138 }
139
140 pub fn register_node_bytes(&mut self, item: &[u8]) {
142 let item = dsl::Item::decode(item).unwrap();
143 self.internal.nodes.push(item);
144 }
145}
146
147
148#[deprecated(since="0.1.0", note="do not use")]
149pub fn create_entrypoint_query(
150 query_def: Option<String>
151) -> dsl::Item {
152 let query_element = dsl::Query {
153 query: query_def.map(|x| x.to_string()),
154 };
155 let _node = dsl::PromptGraphNodeCode {
156 source: None,
157 };
158 dsl::Item {
159 core: Some(ItemCore {
160 name: "RegistrationCodeNode".to_string(),
161 triggers: vec![query_element],
162 output: Default::default(),
163 output_tables: vec![],
164 }),
165 item: None,
166 }
167}
168
169pub fn create_node_parameter(
171 name: String,
172 output_def: String
173) -> dsl::Item {
174 dsl::Item {
175 core: Some(ItemCore {
176 name: name.to_string(),
177 output: create_output(&output_def),
178 triggers: vec![Query { query: None }],
179 output_tables: vec![],
180 }),
181 item: Some(dsl::item::Item::NodeParameter(dsl::PromptGraphParameterNode {
182 })),
183 }
184}
185
186pub fn create_op_map(
188 name: String,
189 query_defs: Vec<Option<String>>,
190 path: String,
191 output_tables: Vec<String>
192) -> dsl::Item {
193 dsl::Item {
194 core: Some(ItemCore {
195 name: name.to_string(),
196 triggers: query_defs.into_iter().map(create_query).collect(),
197 output: create_output(r#"
199 {
200 result: String
201 }
202 "#),
203 output_tables,
204 }),
205 item: Some(dsl::item::Item::Map(dsl::PromptGraphMap {
206 path: path.to_string(),
207 })),
208 }
209}
210
211pub fn create_code_node(
219 name: String,
220 query_defs: Vec<Option<String>>,
221 output_def: String,
222 source_type: SourceNodeType,
223 output_tables: Vec<String>,
224) -> dsl::Item {
225 let source = match source_type {
226 SourceNodeType::Code(language, code, template) => {
227 dsl::prompt_graph_node_code::Source::SourceCode( dsl::PromptGraphNodeCodeSourceCode{
230 template,
231 language: map_string_to_supported_source_langauge(&language).unwrap() as i32,
232 source_code: code.to_string(),
233 })
234 }
235 SourceNodeType::S3(path) => {
236 dsl::prompt_graph_node_code::Source::S3Path(path)
237 }
238 SourceNodeType::Zipfile(file) => {
239 dsl::prompt_graph_node_code::Source::Zipfile(file)
240 }
241 };
242
243 dsl::Item {
244 core: Some(ItemCore {
245 name: name.to_string(),
246 triggers: query_defs.into_iter().map(create_query).collect(),
247 output: create_output(&output_def),
248 output_tables
249 }),
250 item: Some(dsl::item::Item::NodeCode(dsl::PromptGraphNodeCode{
251 source: Some(source),
252 })),
253 }
254}
255
256
257
258pub fn create_custom_node(
266 name: String,
267 query_defs: Vec<Option<String>>,
268 output_def: String,
269 type_name: String,
270 output_tables: Vec<String>
271) -> dsl::Item {
272 dsl::Item {
273 core: Some(ItemCore {
274 name: name.to_string(),
275 triggers: query_defs.into_iter().map(create_query).collect(),
276 output: create_output(&output_def),
277 output_tables
278 }),
279 item: Some(dsl::item::Item::NodeCustom(dsl::PromptGraphNodeCustom{
280 type_name,
281 })),
282 }
283}
284
285pub fn create_observation_node(
287 name: String,
288 query_defs: Vec<Option<String>>,
289 output_def: String,
290 output_tables: Vec<String>
291) -> dsl::Item {
292 dsl::Item {
293 core: Some(ItemCore {
294 name: name.to_string(),
295 triggers: query_defs.into_iter().map(create_query).collect(),
296 output: create_output(&output_def),
297 output_tables
298 }),
299 item: Some(dsl::item::Item::NodeObservation(dsl::PromptGraphNodeObservation{
300 integration: "".to_string(),
301 })),
302 }
303}
304
305pub fn create_vector_memory_node(
309 name: String,
310 query_defs: Vec<Option<String>>,
311 output_def: String,
312 action: String,
313 embedding_model: String,
314 template: String,
315 db_vendor: String,
316 collection_name: String,
317 output_tables: Vec<String>
318) -> anyhow::Result<dsl::Item> {
319 let model = dsl::prompt_graph_node_memory::EmbeddingModel::Model(map_string_to_embedding_model(&embedding_model)? as i32);
320 let vector_db = dsl::prompt_graph_node_memory::VectorDbProvider::Db(map_string_to_vector_database(&db_vendor)? as i32);
321
322 let action = match action.as_str() {
323 "READ" => {
324 dsl::MemoryAction::Read as i32
325 },
326 "WRITE" => {
327 dsl::MemoryAction::Write as i32
328 },
329 "DELETE" => {
330 dsl::MemoryAction::Delete as i32
331 }
332 _ => { unreachable!("Invalid action") }
333 };
334
335 Ok(dsl::Item {
336 core: Some(ItemCore {
337 name: name.to_string(),
338 triggers: query_defs.into_iter().map(create_query).collect(),
339 output: create_output(&output_def),
340 output_tables
341 }),
342 item: Some(dsl::item::Item::NodeMemory(dsl::PromptGraphNodeMemory{
343 collection_name: collection_name,
344 action,
345 embedding_model: Some(model),
346 template: template,
347 vector_db_provider: Some(vector_db),
348 })),
349 })
350}
351
352pub fn create_component_node(
359 name: String,
360 query_defs: Vec<Option<String>>,
361 output_def: String,
362 output_tables: Vec<String>,
363) -> dsl::Item {
364 dsl::Item {
365 core: Some(ItemCore {
366 name: name.to_string(),
367 triggers: query_defs.into_iter().map(create_query).collect(),
368 output: create_output(&output_def),
369 output_tables
370 }),
371 item: Some(dsl::item::Item::NodeComponent(dsl::PromptGraphNodeComponent {
372 transclusion: None,
373 })),
374 }
375}
376
377pub fn create_loader_node(
379 name: String,
380 query_defs: Vec<Option<String>>,
381 output_def: String,
382 load_from: LoadFrom,
383 output_tables: Vec<String>,
384) -> dsl::Item {
385 dsl::Item {
386 core: Some(ItemCore {
387 name: name.to_string(),
388 triggers: query_defs.into_iter().map(create_query).collect(),
389 output: create_output(&output_def),
390 output_tables
391 }),
392 item: Some(dsl::item::Item::NodeLoader(dsl::PromptGraphNodeLoader {
393 load_from: Some(load_from),
394 })),
395 }
396}
397
398pub fn create_prompt_node(
402 name: String,
403 query_defs: Vec<Option<String>>,
404 template: String,
405 model: String,
406 output_tables: Vec<String>,
407) -> anyhow::Result<dsl::Item> {
408 let chat_model = map_string_to_chat_model(&model)?;
409 let model = dsl::prompt_graph_node_prompt::Model::ChatModel(chat_model as i32);
410 Ok(dsl::Item {
415 core: Some(ItemCore {
416 name: name.to_string(),
417 triggers: query_defs.into_iter().map(create_query).collect(),
418 output: create_output(r#"
419 {
420 promptResult: String
421 }
422 "#),
423 output_tables
424 }),
425 item: Some(dsl::item::Item::NodePrompt(dsl::PromptGraphNodePrompt{
426 template: template.to_string(),
427 model: Some(model),
428 temperature: 1.0,
430 top_p: 1.0,
431 max_tokens: 100,
432 presence_penalty: 0.0,
433 frequency_penalty: 0.0,
434 stop: vec![],
435 })),
436 })
437}