1use crate::core::{ExecutionContext, ExecutionResult, Node, NodeId};
7use crate::state::{GraphState, StateValue};
8use crate::{RGraphError, RGraphResult};
9use async_trait::async_trait;
10use std::collections::HashMap;
11use std::sync::Arc;
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15
/// Configuration for the retrieval stage of a RAG workflow.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RagRetrievalConfig {
 /// State key the query text is read from.
 pub query_key: String,
 /// State key the retrieved documents are written to.
 pub context_key: String,
 /// Maximum number of documents to keep.
 pub top_k: usize,
 /// Minimum similarity score a document must reach; `None` disables score filtering.
 pub similarity_threshold: Option<f32>,
 /// Key/value metadata filters. NOTE(review): not consulted by
 /// `RagRetrievalNode::execute` in this file — confirm intended use.
 pub metadata_filters: Vec<(String, String)>,
}
26
27impl Default for RagRetrievalConfig {
28 fn default() -> Self {
29 Self {
30 query_key: "user_query".to_string(),
31 context_key: "retrieval_context".to_string(),
32 top_k: 5,
33 similarity_threshold: Some(0.7),
34 metadata_filters: Vec::new(),
35 }
36 }
37}
38
/// Configuration for the generation stage of a RAG workflow.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RagGenerationConfig {
 /// State key the query text is read from.
 pub query_key: String,
 /// State key the retrieved context documents are read from.
 pub context_key: String,
 /// State key the generated response is written to.
 pub response_key: String,
 /// Optional system prompt for the generator.
 pub system_prompt: Option<String>,
 /// Optional generation token limit.
 pub max_tokens: Option<usize>,
 /// Optional sampling temperature.
 pub temperature: Option<f32>,
}
50
51impl Default for RagGenerationConfig {
52 fn default() -> Self {
53 Self {
54 query_key: "user_query".to_string(),
55 context_key: "retrieval_context".to_string(),
56 response_key: "rag_response".to_string(),
57 system_prompt: Some(
58 "You are a helpful AI assistant. Use the provided context to answer the user's question accurately and comprehensively.".to_string()
59 ),
60 max_tokens: Some(512),
61 temperature: Some(0.7),
62 }
63 }
64}
65
/// Configuration for evaluating how relevant retrieved context is to a query.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ContextEvaluationConfig {
 /// State key the context documents are read from.
 pub context_key: String,
 /// State key the query text is read from.
 pub query_key: String,
 /// State key the relevance statistics are written to.
 pub relevance_key: String,
 /// Minimum per-document relevance score for a document to be kept.
 pub min_relevance_score: f32,
}
75
76impl Default for ContextEvaluationConfig {
77 fn default() -> Self {
78 Self {
79 context_key: "retrieval_context".to_string(),
80 query_key: "user_query".to_string(),
81 relevance_key: "context_relevance".to_string(),
82 min_relevance_score: 0.6,
83 }
84 }
85}
86
/// Graph node that retrieves context documents for a query.
///
/// Backed by a hard-coded mock document set in this file; see its
/// `Node::execute` implementation.
pub struct RagRetrievalNode {
 // Unique node identifier.
 id: NodeId,
 // Human-readable node name.
 name: String,
 // Retrieval behavior settings.
 config: RagRetrievalConfig,
}
93
94impl RagRetrievalNode {
95 pub fn new(id: impl Into<NodeId>, name: impl Into<String>, config: RagRetrievalConfig) -> Self {
96 Self {
97 id: id.into(),
98 name: name.into(),
99 config,
100 }
101 }
102}
103
104#[async_trait]
105impl Node for RagRetrievalNode {
106 async fn execute(
107 &self,
108 state: &mut GraphState,
109 context: &ExecutionContext,
110 ) -> RGraphResult<ExecutionResult> {
111 let query = state.get(&self.config.query_key)?;
113 let query_text = query
114 .as_string()
115 .ok_or_else(|| RGraphError::node(self.id.as_str(), "Query must be a string"))?;
116
117 let mock_documents = vec![
119 create_mock_document("Machine learning is a method of data analysis that automates analytical model building.", 0.85),
120 create_mock_document("It is a branch of artificial intelligence based on the idea that systems can learn from data.", 0.78),
121 create_mock_document("ML algorithms build a model based on training data to make predictions or decisions.", 0.72),
122 ];
123
124 let filtered_docs: Vec<StateValue> =
126 if let Some(threshold) = self.config.similarity_threshold {
127 mock_documents
128 .into_iter()
129 .filter(|doc| {
130 if let Some(obj) = doc.as_object() {
131 if let Some(score_val) = obj.get("score") {
132 if let Some(score) = score_val.as_float() {
133 return score >= threshold as f64;
134 }
135 }
136 }
137 false
138 })
139 .take(self.config.top_k)
140 .collect()
141 } else {
142 mock_documents.into_iter().take(self.config.top_k).collect()
143 };
144
145 state.set_with_context(
147 context.current_node.as_str(),
148 &self.config.context_key,
149 StateValue::Array(filtered_docs.clone()),
150 );
151
152 state.set_with_context(
154 context.current_node.as_str(),
155 "retrieval_metadata",
156 StateValue::from(serde_json::json!({
157 "total_results": filtered_docs.len(),
158 "retrieved_count": filtered_docs.len(),
159 "query": query_text
160 })),
161 );
162
163 Ok(ExecutionResult::Continue)
164 }
165
166 fn id(&self) -> &NodeId {
167 &self.id
168 }
169
170 fn name(&self) -> &str {
171 &self.name
172 }
173
174 fn input_keys(&self) -> Vec<&str> {
175 vec![&self.config.query_key]
176 }
177
178 fn output_keys(&self) -> Vec<&str> {
179 vec![&self.config.context_key, "retrieval_metadata"]
180 }
181}
182
/// Graph node that produces a (mock) answer from the stored query and
/// retrieved context; see its `Node::execute` implementation.
pub struct RagGenerationNode {
 // Unique node identifier.
 id: NodeId,
 // Human-readable node name.
 name: String,
 // Generation behavior settings.
 config: RagGenerationConfig,
}
189
190impl RagGenerationNode {
191 pub fn new(
192 id: impl Into<NodeId>,
193 name: impl Into<String>,
194 config: RagGenerationConfig,
195 ) -> Self {
196 Self {
197 id: id.into(),
198 name: name.into(),
199 config,
200 }
201 }
202}
203
204#[async_trait]
205impl Node for RagGenerationNode {
206 async fn execute(
207 &self,
208 state: &mut GraphState,
209 context: &ExecutionContext,
210 ) -> RGraphResult<ExecutionResult> {
211 let query = state.get(&self.config.query_key)?;
213 let query_text = query
214 .as_string()
215 .ok_or_else(|| RGraphError::node(self.id.as_str(), "Query must be a string"))?;
216
217 let context_value = state.get(&self.config.context_key)?;
218 let context_docs = if let Some(array) = context_value.as_array() {
219 array
220 .iter()
221 .filter_map(|v| {
222 if let Some(obj) = v.as_object() {
223 if let Some(content) = obj.get("content") {
224 content.as_string()
225 } else {
226 None
227 }
228 } else {
229 None
230 }
231 })
232 .collect::<Vec<&str>>()
233 .join("\n\n")
234 } else {
235 return Err(RGraphError::node(
236 self.id.as_str(),
237 "Context must be an array of documents",
238 ));
239 };
240
241 let response = format!(
243 "Based on the provided context, here's what I can tell you about {}: {}",
244 query_text,
245 if context_docs.is_empty() {
246 "I don't have specific information available, but I can provide a general response."
247 } else {
248 "The context provides relevant information that I can use to answer your question comprehensively."
249 }
250 );
251
252 let token_estimate = response.len() / 4;
254
255 state.set_with_context(
257 context.current_node.as_str(),
258 &self.config.response_key,
259 StateValue::String(response),
260 );
261
262 state.set_with_context(
264 context.current_node.as_str(),
265 "generation_metadata",
266 StateValue::from(serde_json::json!({
267 "tokens_used": token_estimate,
268 "model": "mock-model",
269 "finish_reason": "complete"
270 })),
271 );
272
273 Ok(ExecutionResult::Continue)
274 }
275
276 fn id(&self) -> &NodeId {
277 &self.id
278 }
279
280 fn name(&self) -> &str {
281 &self.name
282 }
283
284 fn input_keys(&self) -> Vec<&str> {
285 vec![&self.config.query_key, &self.config.context_key]
286 }
287
288 fn output_keys(&self) -> Vec<&str> {
289 vec![&self.config.response_key, "generation_metadata"]
290 }
291}
292
/// Graph node that scores retrieved documents against the query and filters
/// out those below the configured relevance threshold.
pub struct ContextEvaluationNode {
 // Unique node identifier.
 id: NodeId,
 // Human-readable node name.
 name: String,
 // Evaluation behavior settings.
 config: ContextEvaluationConfig,
}
299
300impl ContextEvaluationNode {
301 pub fn new(
302 id: impl Into<NodeId>,
303 name: impl Into<String>,
304 config: ContextEvaluationConfig,
305 ) -> Self {
306 Self {
307 id: id.into(),
308 name: name.into(),
309 config,
310 }
311 }
312}
313
314#[async_trait]
315impl Node for ContextEvaluationNode {
316 async fn execute(
317 &self,
318 state: &mut GraphState,
319 context: &ExecutionContext,
320 ) -> RGraphResult<ExecutionResult> {
321 let context_value = state.get(&self.config.context_key)?;
323 let query_value = state.get(&self.config.query_key)?;
324
325 let query_text = query_value
326 .as_string()
327 .ok_or_else(|| RGraphError::node(self.id.as_str(), "Query must be a string"))?;
328
329 let context_docs = if let Some(array) = context_value.as_array() {
330 array
331 } else {
332 return Err(RGraphError::node(
333 self.id.as_str(),
334 "Context must be an array of documents",
335 ));
336 };
337
338 let mut relevant_docs = Vec::new();
340 let mut total_score = 0.0;
341 let mut evaluated_count = 0;
342
343 for doc in context_docs {
344 if let Some(obj) = doc.as_object() {
345 if let Some(content_val) = obj.get("content") {
346 if let Some(content) = content_val.as_string() {
347 let relevance_score = self.calculate_relevance_score(query_text, content);
349
350 if relevance_score >= self.config.min_relevance_score {
351 let mut relevant_doc_map = obj.clone();
352 relevant_doc_map.insert(
353 "relevance_score".to_string(),
354 StateValue::Float(relevance_score as f64),
355 );
356 let relevant_doc = StateValue::Object(relevant_doc_map);
357 relevant_docs.push(relevant_doc);
358 }
359
360 total_score += relevance_score;
361 evaluated_count += 1;
362 }
363 }
364 }
365 }
366
367 let average_relevance = if evaluated_count > 0 {
368 total_score / evaluated_count as f32
369 } else {
370 0.0
371 };
372
373 state.set_with_context(
375 context.current_node.as_str(),
376 "filtered_context",
377 StateValue::Array(relevant_docs.clone()),
378 );
379
380 state.set_with_context(
382 context.current_node.as_str(),
383 &self.config.relevance_key,
384 StateValue::from(serde_json::json!({
385 "average_score": average_relevance,
386 "relevant_docs_count": relevant_docs.len(),
387 "total_docs_evaluated": evaluated_count,
388 "min_threshold": self.config.min_relevance_score
389 })),
390 );
391
392 Ok(ExecutionResult::Continue)
393 }
394
395 fn id(&self) -> &NodeId {
396 &self.id
397 }
398
399 fn name(&self) -> &str {
400 &self.name
401 }
402
403 fn input_keys(&self) -> Vec<&str> {
404 vec![&self.config.context_key, &self.config.query_key]
405 }
406
407 fn output_keys(&self) -> Vec<&str> {
408 vec!["filtered_context", &self.config.relevance_key]
409 }
410}
411
412impl ContextEvaluationNode {
413 fn calculate_relevance_score(&self, query: &str, content: &str) -> f32 {
415 let query_words: std::collections::HashSet<String> = query
416 .to_lowercase()
417 .split_whitespace()
418 .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
419 .filter(|w| !w.is_empty())
420 .map(|w| w.to_string())
421 .collect();
422
423 let content_words: std::collections::HashSet<String> = content
424 .to_lowercase()
425 .split_whitespace()
426 .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
427 .filter(|w| !w.is_empty())
428 .map(|w| w.to_string())
429 .collect();
430
431 if query_words.is_empty() || content_words.is_empty() {
432 return 0.0;
433 }
434
435 let intersection_count = query_words.intersection(&content_words).count();
436 let union_count = query_words.union(&content_words).count();
437
438 if union_count == 0 {
439 0.0
440 } else {
441 intersection_count as f32 / union_count as f32
442 }
443 }
444}
445
/// Stateless factory for constructing `Arc`-wrapped RAG workflow nodes.
pub struct RagWorkflowBuilder;
448
449impl RagWorkflowBuilder {
450 pub fn new() -> Self {
451 Self
452 }
453
454 pub fn build_retrieval_node(
456 &self,
457 id: impl Into<NodeId>,
458 name: impl Into<String>,
459 config: RagRetrievalConfig,
460 ) -> RGraphResult<Arc<RagRetrievalNode>> {
461 Ok(Arc::new(RagRetrievalNode::new(id, name, config)))
462 }
463
464 pub fn build_generation_node(
466 &self,
467 id: impl Into<NodeId>,
468 name: impl Into<String>,
469 config: RagGenerationConfig,
470 ) -> RGraphResult<Arc<RagGenerationNode>> {
471 Ok(Arc::new(RagGenerationNode::new(id, name, config)))
472 }
473
474 pub fn build_evaluation_node(
476 &self,
477 id: impl Into<NodeId>,
478 name: impl Into<String>,
479 config: ContextEvaluationConfig,
480 ) -> RGraphResult<Arc<ContextEvaluationNode>> {
481 Ok(Arc::new(ContextEvaluationNode::new(id, name, config)))
482 }
483}
484
485impl Default for RagWorkflowBuilder {
486 fn default() -> Self {
487 Self::new()
488 }
489}
490
491fn create_mock_document(content: &str, score: f64) -> StateValue {
493 let mut doc = HashMap::new();
494 doc.insert(
495 "content".to_string(),
496 StateValue::String(content.to_string()),
497 );
498 doc.insert("score".to_string(), StateValue::Float(score));
499 doc.insert("metadata".to_string(), StateValue::Object(HashMap::new()));
500 StateValue::Object(doc)
501}