oxirs_graphrag/generation/
context_builder.rs1use crate::{CommunitySummary, GraphRAGResult, Triple};
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct ContextConfig {
9 pub max_length: usize,
11 pub include_communities: bool,
13 pub include_triples: bool,
15 pub triple_format: TripleFormat,
17 pub score_weighted: bool,
19}
20
21impl Default for ContextConfig {
22 fn default() -> Self {
23 Self {
24 max_length: 8000,
25 include_communities: true,
26 include_triples: true,
27 triple_format: TripleFormat::NaturalLanguage,
28 score_weighted: true,
29 }
30 }
31}
32
33#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
35pub enum TripleFormat {
36 NaturalLanguage,
38 Structured,
40 Turtle,
42 JsonLd,
44}
45
46pub struct ContextBuilder {
48 config: ContextConfig,
49}
50
51impl Default for ContextBuilder {
52 fn default() -> Self {
53 Self::new(ContextConfig::default())
54 }
55}
56
57impl ContextBuilder {
58 pub fn new(config: ContextConfig) -> Self {
59 Self { config }
60 }
61
62 pub fn build(
64 &self,
65 query: &str,
66 triples: &[Triple],
67 communities: &[CommunitySummary],
68 ) -> GraphRAGResult<String> {
69 let mut context = String::new();
70 let mut remaining_length = self.config.max_length;
71
72 let query_section = format!("## Query\n{}\n\n", query);
74 if query_section.len() < remaining_length {
75 context.push_str(&query_section);
76 remaining_length -= query_section.len();
77 }
78
79 if self.config.include_communities && !communities.is_empty() {
81 let community_section = self.format_communities(communities, remaining_length / 3);
82 if community_section.len() < remaining_length {
83 context.push_str(&community_section);
84 remaining_length -= community_section.len();
85 }
86 }
87
88 if self.config.include_triples && !triples.is_empty() {
90 let triples_section = self.format_triples(triples, remaining_length);
91 context.push_str(&triples_section);
92 }
93
94 Ok(context)
95 }
96
97 fn format_communities(&self, communities: &[CommunitySummary], max_length: usize) -> String {
99 let mut result = String::from("## Knowledge Graph Communities\n\n");
100
101 for community in communities {
102 let entry = format!(
103 "### {}\n{}\n**Entities:** {}\n\n",
104 community.id,
105 community.summary,
106 community
107 .entities
108 .iter()
109 .take(5)
110 .cloned()
111 .collect::<Vec<_>>()
112 .join(", ")
113 );
114
115 if result.len() + entry.len() > max_length {
116 break;
117 }
118 result.push_str(&entry);
119 }
120
121 result
122 }
123
124 fn format_triples(&self, triples: &[Triple], max_length: usize) -> String {
126 let mut result = String::from("## Knowledge Graph Facts\n\n");
127
128 for triple in triples {
129 let entry = match self.config.triple_format {
130 TripleFormat::NaturalLanguage => self.triple_to_natural_language(triple),
131 TripleFormat::Structured => self.triple_to_structured(triple),
132 TripleFormat::Turtle => self.triple_to_turtle(triple),
133 TripleFormat::JsonLd => self.triple_to_jsonld(triple),
134 };
135
136 if result.len() + entry.len() > max_length {
137 break;
138 }
139 result.push_str(&entry);
140 result.push('\n');
141 }
142
143 result
144 }
145
146 fn triple_to_natural_language(&self, triple: &Triple) -> String {
148 let subject = self.extract_local_name(&triple.subject);
149 let predicate = self.predicate_to_phrase(&triple.predicate);
150 let object = self.extract_local_name(&triple.object);
151
152 format!("- {} {} {}", subject, predicate, object)
153 }
154
155 fn triple_to_structured(&self, triple: &Triple) -> String {
157 let subject = self.extract_local_name(&triple.subject);
158 let predicate = self.extract_local_name(&triple.predicate);
159 let object = self.extract_local_name(&triple.object);
160
161 format!("- {} → {} → {}", subject, predicate, object)
162 }
163
164 fn triple_to_turtle(&self, triple: &Triple) -> String {
166 format!(
167 "<{}> <{}> <{}> .",
168 triple.subject, triple.predicate, triple.object
169 )
170 }
171
172 fn triple_to_jsonld(&self, triple: &Triple) -> String {
174 let subject = self.extract_local_name(&triple.subject);
175 let predicate = self.extract_local_name(&triple.predicate);
176 let object = self.extract_local_name(&triple.object);
177
178 format!(
179 "{{ \"@id\": \"{}\", \"{}\": \"{}\" }}",
180 subject, predicate, object
181 )
182 }
183
184 fn extract_local_name(&self, uri: &str) -> String {
186 uri.rsplit('#')
188 .next()
189 .filter(|s| s != &uri) .or_else(|| uri.rsplit('/').next())
191 .unwrap_or(uri)
192 .to_string()
193 }
194
195 fn predicate_to_phrase(&self, predicate: &str) -> String {
197 let local = self.extract_local_name(predicate);
198
199 match local.as_str() {
201 "type" | "rdf:type" => "is a".to_string(),
202 "label" | "rdfs:label" => "is labeled".to_string(),
203 "subClassOf" => "is a subclass of".to_string(),
204 "partOf" => "is part of".to_string(),
205 "hasPart" => "has part".to_string(),
206 "relatedTo" => "is related to".to_string(),
207 "sameAs" => "is the same as".to_string(),
208 "knows" => "knows".to_string(),
209 "worksFor" => "works for".to_string(),
210 "locatedIn" => "is located in".to_string(),
211 _ => {
212 let mut result = String::new();
214 for (i, c) in local.chars().enumerate() {
215 if i > 0 && c.is_uppercase() {
216 result.push(' ');
217 }
218 result.push(c.to_lowercase().next().unwrap_or(c));
219 }
220 result
221 }
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-configured builder should emit the query heading and mention
    /// the entity shared by the triples and the community.
    #[test]
    fn test_context_building() {
        let sample_triples = vec![
            Triple::new(
                "http://example.org/Battery1",
                "http://example.org/hasStatus",
                "http://example.org/Critical",
            ),
            Triple::new(
                "http://example.org/Battery1",
                "http://example.org/temperature",
                "85",
            ),
        ];

        let sample_communities = vec![CommunitySummary {
            id: "community_0".to_string(),
            summary: "Battery monitoring entities".to_string(),
            entities: vec!["Battery1".to_string(), "Sensor1".to_string()],
            representative_triples: vec![],
            level: 0,
            modularity: 0.5,
        }];

        let rendered = ContextBuilder::default()
            .build(
                "What is the battery status?",
                &sample_triples,
                &sample_communities,
            )
            .unwrap();

        for needle in ["Query", "Battery1"] {
            assert!(rendered.contains(needle), "missing {:?} in context", needle);
        }
    }

    /// Known predicates map to fixed phrases; unknown camelCase names are split.
    #[test]
    fn test_predicate_to_phrase() {
        let builder = ContextBuilder::default();
        let cases = [
            ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", "is a"),
            ("http://example.org/partOf", "is part of"),
            ("http://example.org/hasTemperature", "has temperature"),
        ];

        for (predicate, expected) in cases {
            assert_eq!(builder.predicate_to_phrase(predicate), expected);
        }
    }
}