rrag_graph/rrag_integration.rs

//! # RRAG Integration
//!
//! Integration layer between RGraph and RRAG for RAG-powered agent workflows.
//! This module provides nodes that can leverage RRAG's retrieval and generation capabilities.
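//!
//! ## Example
//!
//! A minimal construction sketch. It is marked `ignore` because wiring the nodes
//! into a graph uses RGraph's builder APIs, which are not shown in this module,
//! and the string ids assume an `Into<NodeId>` conversion exists for `&str`.
//!
//! ```ignore
//! let builder = RagWorkflowBuilder::new();
//!
//! let retrieval = builder.build_retrieval_node(
//!     "retrieve",
//!     "Retrieve documents",
//!     RagRetrievalConfig::default(),
//! )?;
//! let evaluation = builder.build_evaluation_node(
//!     "evaluate",
//!     "Filter context by relevance",
//!     ContextEvaluationConfig::default(),
//! )?;
//! let generation = builder.build_generation_node(
//!     "generate",
//!     "Generate an answer",
//!     RagGenerationConfig::default(),
//! )?;
//!
//! // The resulting `Arc` nodes can then be registered in an RGraph graph,
//! // e.g. as retrieve -> evaluate -> generate.
//! ```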

use crate::core::{ExecutionContext, ExecutionResult, Node, NodeId};
use crate::state::{GraphState, StateValue};
use crate::{RGraphError, RGraphResult};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/// Configuration for RAG retrieval nodes
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RagRetrievalConfig {
    pub query_key: String,
    pub context_key: String,
    pub top_k: usize,
    pub similarity_threshold: Option<f32>,
    pub metadata_filters: Vec<(String, String)>,
}

impl Default for RagRetrievalConfig {
    fn default() -> Self {
        Self {
            query_key: "user_query".to_string(),
            context_key: "retrieval_context".to_string(),
            top_k: 5,
            similarity_threshold: Some(0.7),
            metadata_filters: Vec::new(),
        }
    }
}

/// Configuration for RAG generation nodes
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RagGenerationConfig {
    pub query_key: String,
    pub context_key: String,
    pub response_key: String,
    pub system_prompt: Option<String>,
    pub max_tokens: Option<usize>,
    pub temperature: Option<f32>,
}

impl Default for RagGenerationConfig {
    fn default() -> Self {
        Self {
            query_key: "user_query".to_string(),
            context_key: "retrieval_context".to_string(),
            response_key: "rag_response".to_string(),
            system_prompt: Some(
                "You are a helpful AI assistant. Use the provided context to answer the user's question accurately and comprehensively.".to_string()
            ),
            max_tokens: Some(512),
            temperature: Some(0.7),
        }
    }
}

/// Configuration for context evaluation nodes
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ContextEvaluationConfig {
    pub context_key: String,
    pub query_key: String,
    pub relevance_key: String,
    pub min_relevance_score: f32,
}

impl Default for ContextEvaluationConfig {
    fn default() -> Self {
        Self {
            context_key: "retrieval_context".to_string(),
            query_key: "user_query".to_string(),
            relevance_key: "context_relevance".to_string(),
            min_relevance_score: 0.6,
        }
    }
}

/// A node that performs RAG retrieval (simplified mock implementation)
pub struct RagRetrievalNode {
    id: NodeId,
    name: String,
    config: RagRetrievalConfig,
}

impl RagRetrievalNode {
    pub fn new(id: impl Into<NodeId>, name: impl Into<String>, config: RagRetrievalConfig) -> Self {
        Self {
            id: id.into(),
            name: name.into(),
            config,
        }
    }
}

#[async_trait]
impl Node for RagRetrievalNode {
    async fn execute(
        &self,
        state: &mut GraphState,
        context: &ExecutionContext,
    ) -> RGraphResult<ExecutionResult> {
        // Get the query from state
        let query = state.get(&self.config.query_key)?;
        let query_text = query
            .as_string()
            .ok_or_else(|| RGraphError::node(self.id.as_str(), "Query must be a string"))?;

        // Simulate document retrieval (in a real implementation, this would use RRAG)
        let mock_documents = vec![
            create_mock_document("Machine learning is a method of data analysis that automates analytical model building.", 0.85),
            create_mock_document("It is a branch of artificial intelligence based on the idea that systems can learn from data.", 0.78),
            create_mock_document("ML algorithms build a model based on training data to make predictions or decisions.", 0.72),
        ];

        // Filter based on similarity threshold if provided
        let filtered_docs: Vec<StateValue> =
            if let Some(threshold) = self.config.similarity_threshold {
                mock_documents
                    .into_iter()
                    .filter(|doc| {
                        if let Some(obj) = doc.as_object() {
                            if let Some(score_val) = obj.get("score") {
                                if let Some(score) = score_val.as_float() {
                                    return score >= threshold as f64;
                                }
                            }
                        }
                        false
                    })
                    .take(self.config.top_k)
                    .collect()
            } else {
                mock_documents.into_iter().take(self.config.top_k).collect()
            };

        // Store retrieval context
        state.set_with_context(
            context.current_node.as_str(),
            &self.config.context_key,
            StateValue::Array(filtered_docs.clone()),
        );

        // Store retrieval metadata
        state.set_with_context(
            context.current_node.as_str(),
            "retrieval_metadata",
            StateValue::from(serde_json::json!({
                "total_results": filtered_docs.len(),
                "retrieved_count": filtered_docs.len(),
                "query": query_text
            })),
        );

        Ok(ExecutionResult::Continue)
    }

    fn id(&self) -> &NodeId {
        &self.id
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn input_keys(&self) -> Vec<&str> {
        vec![&self.config.query_key]
    }

    fn output_keys(&self) -> Vec<&str> {
        vec![&self.config.context_key, "retrieval_metadata"]
    }
}
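
// Sketch of what `RagRetrievalNode` writes into the graph state with the default
// configuration (all three mock documents clear the 0.7 similarity threshold):
//
//   retrieval_context:  [ { "content": "...", "score": 0.85, "metadata": {} }, ... ]
//   retrieval_metadata: { "total_results": 3, "retrieved_count": 3, "query": "<user_query value>" }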

/// A node that performs RAG generation (simplified mock implementation)
pub struct RagGenerationNode {
    id: NodeId,
    name: String,
    config: RagGenerationConfig,
}

impl RagGenerationNode {
    pub fn new(
        id: impl Into<NodeId>,
        name: impl Into<String>,
        config: RagGenerationConfig,
    ) -> Self {
        Self {
            id: id.into(),
            name: name.into(),
            config,
        }
    }
}

#[async_trait]
impl Node for RagGenerationNode {
    async fn execute(
        &self,
        state: &mut GraphState,
        context: &ExecutionContext,
    ) -> RGraphResult<ExecutionResult> {
        // Get query and context from state
        let query = state.get(&self.config.query_key)?;
        let query_text = query
            .as_string()
            .ok_or_else(|| RGraphError::node(self.id.as_str(), "Query must be a string"))?;

        let context_value = state.get(&self.config.context_key)?;
        let context_docs = if let Some(array) = context_value.as_array() {
            array
                .iter()
                .filter_map(|v| {
                    if let Some(obj) = v.as_object() {
                        if let Some(content) = obj.get("content") {
                            content.as_string()
                        } else {
                            None
                        }
                    } else {
                        None
                    }
                })
                .collect::<Vec<&str>>()
                .join("\n\n")
        } else {
            return Err(RGraphError::node(
                self.id.as_str(),
                "Context must be an array of documents",
            ));
        };

        // Simulate response generation (in a real implementation, this would use RRAG's generation engine)
        let response = format!(
            "Based on the provided context, here's what I can tell you about {}: {}",
            query_text,
            if context_docs.is_empty() {
                "I don't have specific information available, but I can provide a general response."
            } else {
                "The context provides relevant information that I can use to answer your question comprehensively."
            }
        );

        // Rough token estimate (~4 characters per token), computed before the response is moved
        let token_estimate = response.len() / 4;

        // Store the generated response
        state.set_with_context(
            context.current_node.as_str(),
            &self.config.response_key,
            StateValue::String(response),
        );

        // Store generation metadata
        state.set_with_context(
            context.current_node.as_str(),
            "generation_metadata",
            StateValue::from(serde_json::json!({
                "tokens_used": token_estimate,
                "model": "mock-model",
                "finish_reason": "complete"
            })),
        );

        Ok(ExecutionResult::Continue)
    }

    fn id(&self) -> &NodeId {
        &self.id
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn input_keys(&self) -> Vec<&str> {
        vec![&self.config.query_key, &self.config.context_key]
    }

    fn output_keys(&self) -> Vec<&str> {
        vec![&self.config.response_key, "generation_metadata"]
    }
}
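
// Sketch of what `RagGenerationNode` writes into the graph state with the default
// configuration (the response text itself is mocked above):
//
//   rag_response:        "Based on the provided context, here's what I can tell you about ..."
//   generation_metadata: { "tokens_used": <response.len() / 4>, "model": "mock-model", "finish_reason": "complete" }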

/// A node that evaluates context relevance
pub struct ContextEvaluationNode {
    id: NodeId,
    name: String,
    config: ContextEvaluationConfig,
}

impl ContextEvaluationNode {
    pub fn new(
        id: impl Into<NodeId>,
        name: impl Into<String>,
        config: ContextEvaluationConfig,
    ) -> Self {
        Self {
            id: id.into(),
            name: name.into(),
            config,
        }
    }
}

#[async_trait]
impl Node for ContextEvaluationNode {
    async fn execute(
        &self,
        state: &mut GraphState,
        context: &ExecutionContext,
    ) -> RGraphResult<ExecutionResult> {
        // Get context and query from state
        let context_value = state.get(&self.config.context_key)?;
        let query_value = state.get(&self.config.query_key)?;

        let query_text = query_value
            .as_string()
            .ok_or_else(|| RGraphError::node(self.id.as_str(), "Query must be a string"))?;

        let context_docs = if let Some(array) = context_value.as_array() {
            array
        } else {
            return Err(RGraphError::node(
                self.id.as_str(),
                "Context must be an array of documents",
            ));
        };

        // Evaluate relevance for each document
        let mut relevant_docs = Vec::new();
        let mut total_score = 0.0;
        let mut evaluated_count = 0;

        for doc in context_docs {
            if let Some(obj) = doc.as_object() {
                if let Some(content_val) = obj.get("content") {
                    if let Some(content) = content_val.as_string() {
                        // Simple relevance scoring based on keyword overlap
                        let relevance_score = self.calculate_relevance_score(query_text, content);

                        if relevance_score >= self.config.min_relevance_score {
                            let mut relevant_doc_map = obj.clone();
                            relevant_doc_map.insert(
                                "relevance_score".to_string(),
                                StateValue::Float(relevance_score as f64),
                            );
                            let relevant_doc = StateValue::Object(relevant_doc_map);
                            relevant_docs.push(relevant_doc);
                        }

                        total_score += relevance_score;
                        evaluated_count += 1;
                    }
                }
            }
        }

        let average_relevance = if evaluated_count > 0 {
            total_score / evaluated_count as f32
        } else {
            0.0
        };

        // Store filtered relevant context
        state.set_with_context(
            context.current_node.as_str(),
            "filtered_context",
            StateValue::Array(relevant_docs.clone()),
        );

        // Store relevance metrics
        state.set_with_context(
            context.current_node.as_str(),
            &self.config.relevance_key,
            StateValue::from(serde_json::json!({
                "average_score": average_relevance,
                "relevant_docs_count": relevant_docs.len(),
                "total_docs_evaluated": evaluated_count,
                "min_threshold": self.config.min_relevance_score
            })),
        );

        Ok(ExecutionResult::Continue)
    }

    fn id(&self) -> &NodeId {
        &self.id
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn input_keys(&self) -> Vec<&str> {
        vec![&self.config.context_key, &self.config.query_key]
    }

    fn output_keys(&self) -> Vec<&str> {
        vec!["filtered_context", &self.config.relevance_key]
    }
}

impl ContextEvaluationNode {
    /// Simple relevance scoring based on keyword overlap (Jaccard similarity over word sets)
    fn calculate_relevance_score(&self, query: &str, content: &str) -> f32 {
        let query_words: std::collections::HashSet<String> = query
            .to_lowercase()
            .split_whitespace()
            .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|w| !w.is_empty())
            .map(|w| w.to_string())
            .collect();

        let content_words: std::collections::HashSet<String> = content
            .to_lowercase()
            .split_whitespace()
            .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|w| !w.is_empty())
            .map(|w| w.to_string())
            .collect();

        if query_words.is_empty() || content_words.is_empty() {
            return 0.0;
        }

        let intersection_count = query_words.intersection(&content_words).count();
        let union_count = query_words.union(&content_words).count();

        if union_count == 0 {
            0.0
        } else {
            intersection_count as f32 / union_count as f32
        }
    }
}
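
// Worked example of the score above, using the first mock document from
// `RagRetrievalNode`:
//
//   query   = "what is machine learning"
//             -> { "what", "is", "machine", "learning" }          (4 words)
//   content = "Machine learning is a method of data analysis that
//              automates analytical model building."
//             -> 13 distinct words after lowercasing and trimming
//
//   intersection = { "machine", "learning", "is" }  -> 3
//   union        = 4 + 13 - 3                       -> 14
//   score        = 3 / 14                           ≈ 0.21
//
// Because the union grows with document length, short queries against longer
// passages tend to score well below ContextEvaluationConfig's default
// `min_relevance_score` of 0.6.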

/// Builder for RAG-powered workflows
pub struct RagWorkflowBuilder;

impl RagWorkflowBuilder {
    pub fn new() -> Self {
        Self
    }

    /// Create a RAG retrieval node
    pub fn build_retrieval_node(
        &self,
        id: impl Into<NodeId>,
        name: impl Into<String>,
        config: RagRetrievalConfig,
    ) -> RGraphResult<Arc<RagRetrievalNode>> {
        Ok(Arc::new(RagRetrievalNode::new(id, name, config)))
    }

    /// Create a RAG generation node
    pub fn build_generation_node(
        &self,
        id: impl Into<NodeId>,
        name: impl Into<String>,
        config: RagGenerationConfig,
    ) -> RGraphResult<Arc<RagGenerationNode>> {
        Ok(Arc::new(RagGenerationNode::new(id, name, config)))
    }

    /// Create a context evaluation node
    pub fn build_evaluation_node(
        &self,
        id: impl Into<NodeId>,
        name: impl Into<String>,
        config: ContextEvaluationConfig,
    ) -> RGraphResult<Arc<ContextEvaluationNode>> {
        Ok(Arc::new(ContextEvaluationNode::new(id, name, config)))
    }
}

impl Default for RagWorkflowBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// Helper function to create mock documents for demonstration
fn create_mock_document(content: &str, score: f64) -> StateValue {
    let mut doc = HashMap::new();
    doc.insert(
        "content".to_string(),
        StateValue::String(content.to_string()),
    );
    doc.insert("score".to_string(), StateValue::Float(score));
    doc.insert("metadata".to_string(), StateValue::Object(HashMap::new()));
    StateValue::Object(doc)
}
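
// A few smoke tests for the pieces that are fully local to this module. They
// assume the `StateValue` accessors behave as they are used above
// (`as_object`, `as_string`, and `as_float` returning `Option`s over the stored
// data).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_configs_agree_on_query_and_context_keys() {
        let retrieval = RagRetrievalConfig::default();
        let generation = RagGenerationConfig::default();
        let evaluation = ContextEvaluationConfig::default();

        assert_eq!(retrieval.query_key, generation.query_key);
        assert_eq!(retrieval.context_key, generation.context_key);
        assert_eq!(retrieval.context_key, evaluation.context_key);
        assert_eq!(retrieval.top_k, 5);
    }

    #[test]
    fn configs_can_be_customized_with_struct_update_syntax() {
        let config = RagRetrievalConfig {
            top_k: 3,
            similarity_threshold: Some(0.8),
            ..RagRetrievalConfig::default()
        };

        assert_eq!(config.top_k, 3);
        assert_eq!(config.similarity_threshold, Some(0.8));
        assert_eq!(config.query_key, "user_query");
    }

    #[test]
    fn mock_documents_expose_content_and_score() {
        let doc = create_mock_document("example content", 0.9);
        let obj = doc.as_object().expect("mock document should be an object");

        assert_eq!(
            obj.get("content").and_then(|v| v.as_string()),
            Some("example content")
        );
        assert_eq!(obj.get("score").and_then(|v| v.as_float()), Some(0.9));
    }
}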