oxirs_graphrag/generation/
context_builder.rs

1//! Context building for LLM generation
2
3use crate::{CommunitySummary, GraphRAGResult, Triple};
4use serde::{Deserialize, Serialize};
5
6/// Context builder configuration
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct ContextConfig {
9    /// Maximum context length in characters
10    pub max_length: usize,
11    /// Include community summaries
12    pub include_communities: bool,
13    /// Include raw triples
14    pub include_triples: bool,
15    /// Triple format
16    pub triple_format: TripleFormat,
17    /// Prioritize triples by score
18    pub score_weighted: bool,
19}
20
21impl Default for ContextConfig {
22    fn default() -> Self {
23        Self {
24            max_length: 8000,
25            include_communities: true,
26            include_triples: true,
27            triple_format: TripleFormat::NaturalLanguage,
28            score_weighted: true,
29        }
30    }
31}
32
33/// Triple formatting options
34#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
35pub enum TripleFormat {
36    /// Natural language: "Entity A is related to Entity B"
37    NaturalLanguage,
38    /// Structured: "subject → predicate → object"
39    Structured,
40    /// Turtle-like: `<subject> <predicate> <object> .`
41    Turtle,
42    /// JSON-LD style
43    JsonLd,
44}
45
46/// Context builder for LLM input
47pub struct ContextBuilder {
48    config: ContextConfig,
49}
50
51impl Default for ContextBuilder {
52    fn default() -> Self {
53        Self::new(ContextConfig::default())
54    }
55}
56
57impl ContextBuilder {
58    pub fn new(config: ContextConfig) -> Self {
59        Self { config }
60    }
61
62    /// Build context string from subgraph and communities
63    pub fn build(
64        &self,
65        query: &str,
66        triples: &[Triple],
67        communities: &[CommunitySummary],
68    ) -> GraphRAGResult<String> {
69        let mut context = String::new();
70        let mut remaining_length = self.config.max_length;
71
72        // Add query context
73        let query_section = format!("## Query\n{}\n\n", query);
74        if query_section.len() < remaining_length {
75            context.push_str(&query_section);
76            remaining_length -= query_section.len();
77        }
78
79        // Add community summaries
80        if self.config.include_communities && !communities.is_empty() {
81            let community_section = self.format_communities(communities, remaining_length / 3);
82            if community_section.len() < remaining_length {
83                context.push_str(&community_section);
84                remaining_length -= community_section.len();
85            }
86        }
87
88        // Add triples
89        if self.config.include_triples && !triples.is_empty() {
90            let triples_section = self.format_triples(triples, remaining_length);
91            context.push_str(&triples_section);
92        }
93
94        Ok(context)
95    }
96
97    /// Format community summaries
98    fn format_communities(&self, communities: &[CommunitySummary], max_length: usize) -> String {
99        let mut result = String::from("## Knowledge Graph Communities\n\n");
100
101        for community in communities {
102            let entry = format!(
103                "### {}\n{}\n**Entities:** {}\n\n",
104                community.id,
105                community.summary,
106                community
107                    .entities
108                    .iter()
109                    .take(5)
110                    .cloned()
111                    .collect::<Vec<_>>()
112                    .join(", ")
113            );
114
115            if result.len() + entry.len() > max_length {
116                break;
117            }
118            result.push_str(&entry);
119        }
120
121        result
122    }
123
124    /// Format triples according to configured format
125    fn format_triples(&self, triples: &[Triple], max_length: usize) -> String {
126        let mut result = String::from("## Knowledge Graph Facts\n\n");
127
128        for triple in triples {
129            let entry = match self.config.triple_format {
130                TripleFormat::NaturalLanguage => self.triple_to_natural_language(triple),
131                TripleFormat::Structured => self.triple_to_structured(triple),
132                TripleFormat::Turtle => self.triple_to_turtle(triple),
133                TripleFormat::JsonLd => self.triple_to_jsonld(triple),
134            };
135
136            if result.len() + entry.len() > max_length {
137                break;
138            }
139            result.push_str(&entry);
140            result.push('\n');
141        }
142
143        result
144    }
145
146    /// Convert triple to natural language
147    fn triple_to_natural_language(&self, triple: &Triple) -> String {
148        let subject = self.extract_local_name(&triple.subject);
149        let predicate = self.predicate_to_phrase(&triple.predicate);
150        let object = self.extract_local_name(&triple.object);
151
152        format!("- {} {} {}", subject, predicate, object)
153    }
154
155    /// Convert triple to structured format
156    fn triple_to_structured(&self, triple: &Triple) -> String {
157        let subject = self.extract_local_name(&triple.subject);
158        let predicate = self.extract_local_name(&triple.predicate);
159        let object = self.extract_local_name(&triple.object);
160
161        format!("- {} → {} → {}", subject, predicate, object)
162    }
163
164    /// Convert triple to Turtle format
165    fn triple_to_turtle(&self, triple: &Triple) -> String {
166        format!(
167            "<{}> <{}> <{}> .",
168            triple.subject, triple.predicate, triple.object
169        )
170    }
171
172    /// Convert triple to JSON-LD style
173    fn triple_to_jsonld(&self, triple: &Triple) -> String {
174        let subject = self.extract_local_name(&triple.subject);
175        let predicate = self.extract_local_name(&triple.predicate);
176        let object = self.extract_local_name(&triple.object);
177
178        format!(
179            "{{ \"@id\": \"{}\", \"{}\": \"{}\" }}",
180            subject, predicate, object
181        )
182    }
183
184    /// Extract local name from URI
185    fn extract_local_name(&self, uri: &str) -> String {
186        // Try '#' first (for RDF namespace URIs), then '/'
187        uri.rsplit('#')
188            .next()
189            .filter(|s| s != &uri) // Only use if '#' was found
190            .or_else(|| uri.rsplit('/').next())
191            .unwrap_or(uri)
192            .to_string()
193    }
194
195    /// Convert predicate URI to natural language phrase
196    fn predicate_to_phrase(&self, predicate: &str) -> String {
197        let local = self.extract_local_name(predicate);
198
199        // Common predicate mappings
200        match local.as_str() {
201            "type" | "rdf:type" => "is a".to_string(),
202            "label" | "rdfs:label" => "is labeled".to_string(),
203            "subClassOf" => "is a subclass of".to_string(),
204            "partOf" => "is part of".to_string(),
205            "hasPart" => "has part".to_string(),
206            "relatedTo" => "is related to".to_string(),
207            "sameAs" => "is the same as".to_string(),
208            "knows" => "knows".to_string(),
209            "worksFor" => "works for".to_string(),
210            "locatedIn" => "is located in".to_string(),
211            _ => {
212                // Convert camelCase to spaces
213                let mut result = String::new();
214                for (i, c) in local.chars().enumerate() {
215                    if i > 0 && c.is_uppercase() {
216                        result.push(' ');
217                    }
218                    result.push(c.to_lowercase().next().unwrap_or(c));
219                }
220                result
221            }
222        }
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// A default builder given two triples and one community produces a
    /// context that mentions both the query header and the shared subject.
    #[test]
    fn test_context_building() {
        let battery = "http://example.org/Battery1";
        let triples = [
            Triple::new(
                battery,
                "http://example.org/hasStatus",
                "http://example.org/Critical",
            ),
            Triple::new(battery, "http://example.org/temperature", "85"),
        ];

        let communities = [CommunitySummary {
            id: "community_0".to_string(),
            summary: "Battery monitoring entities".to_string(),
            entities: vec!["Battery1".to_string(), "Sensor1".to_string()],
            representative_triples: vec![],
            level: 0,
            modularity: 0.5,
        }];

        let context = ContextBuilder::default()
            .build("What is the battery status?", &triples, &communities)
            .expect("building the context should succeed");

        for needle in ["Query", "Battery1"] {
            assert!(context.contains(needle));
        }
    }

    /// Known predicates map to fixed phrases; unknown camelCase predicates
    /// are split into lowercase words.
    #[test]
    fn test_predicate_to_phrase() {
        let builder = ContextBuilder::default();

        let cases = [
            ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", "is a"),
            ("http://example.org/partOf", "is part of"),
            ("http://example.org/hasTemperature", "has temperature"),
        ];

        for (predicate, expected) in cases {
            assert_eq!(builder.predicate_to_phrase(predicate), expected);
        }
    }
}