1use lean_agentic::{Arena, Environment, SymbolTable};
15use std::sync::Arc;
16use std::time::Instant;
17
18pub mod policy;
19pub mod proof;
20pub mod router;
21pub mod audit;
22
23pub use policy::{Policy, PolicyEngine, PolicyViolation};
24pub use proof::{ProofCertificate, ProofKind};
25pub use router::{CostAwareRouter, Lane, RoutingDecision};
26pub use audit::{AuditLog, AuditEvent};
27
/// A single question submitted to the gateway, together with the caller's
/// identity and optional routing constraints.
///
/// Derives `PartialEq` so queries can be compared directly (for example in
/// `assert_eq!`); `Eq` is not derivable because `cost_budget` is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct RagQuery {
    /// The natural-language question to answer.
    pub question: String,

    /// Identifiers of the sources the answer may draw on; may be empty.
    pub sources: Vec<String>,

    /// Identity of the requesting user.
    pub user_id: String,

    /// Optional latency SLA forwarded to the router; `RagGateway::process`
    /// falls back to 150 when this is `None`.
    /// NOTE(review): presumably milliseconds — confirm against the router.
    pub latency_sla: Option<u64>,

    /// Optional cost budget forwarded to the router; `RagGateway::process`
    /// falls back to 0.01 when this is `None`.
    /// NOTE(review): presumably US dollars — confirm against the router.
    pub cost_budget: Option<f64>,
}
46
47#[derive(Debug, Clone)]
49pub struct RagResponse {
50 pub answer: String,
52
53 pub citations: Vec<Citation>,
55
56 pub proof: ProofCertificate,
58
59 pub metrics: ResponseMetrics,
61}
62
/// A reference to a source passage used to support an answer.
///
/// Derives `PartialEq` so citations can be compared directly; `Eq` is not
/// derivable because `relevance_score` is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct Citation {
    /// Identifier of the cited source.
    pub source: String,
    /// Text excerpt taken from the source.
    pub excerpt: String,
    /// Relevance of the excerpt to the question.
    /// NOTE(review): the value range is not established in this file — confirm.
    pub relevance_score: f64,
}
69
/// Operational figures reported alongside a response.
///
/// Derives `PartialEq` so metrics can be compared directly; `Eq` is not
/// derivable because `cost_usd` is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct ResponseMetrics {
    /// Wall-clock time spent processing the request, in milliseconds.
    pub latency_ms: u64,
    /// Cost attributed to the request, in US dollars.
    pub cost_usd: f64,
    /// Number of tokens consumed while producing the answer.
    pub tokens_used: usize,
    /// Name of the routing lane that served the request.
    pub lane_used: String,
}
77
78pub struct RagGateway {
80 policy_engine: PolicyEngine,
82
83 router: CostAwareRouter,
85
86 audit_log: Arc<AuditLog>,
88
89 arena: Arena,
91 env: Environment,
92 symbols: SymbolTable,
93}
94
95impl RagGateway {
96 pub fn new(policies: Vec<Policy>) -> Self {
98 Self {
99 policy_engine: PolicyEngine::new(policies),
100 router: CostAwareRouter::new(),
101 audit_log: Arc::new(AuditLog::new()),
102 arena: Arena::new(),
103 env: Environment::new(),
104 symbols: SymbolTable::new(),
105 }
106 }
107
108 pub fn process(&mut self, query: RagQuery) -> Result<RagResponse, GatewayError> {
110 let start = Instant::now();
111
112 let access_check = self.policy_engine.check_access(&query)?;
114 if !access_check.allowed {
115 self.audit_log.log_blocked(&query, format!("{:?}", access_check.violation));
116 return Err(GatewayError::PolicyViolation(access_check.violation));
117 }
118
119 let routing = self.router.select_lane(
121 query.latency_sla.unwrap_or(150),
122 query.cost_budget.unwrap_or(0.01),
123 )?;
124
125 let (answer, citations, tokens) = self.retrieve_and_generate(
127 &query,
128 &routing.lane,
129 )?;
130
131 let masked_answer = self.policy_engine.mask_pii(&answer)?;
133
134 let proof = self.generate_proof(&query, &masked_answer, &access_check)?;
136
137 let latency = start.elapsed().as_millis() as u64;
138
139 self.audit_log.log_success(&query, latency, routing.estimated_cost, &routing.lane.name);
141
142 Ok(RagResponse {
143 answer: masked_answer,
144 citations,
145 proof,
146 metrics: ResponseMetrics {
147 latency_ms: latency,
148 cost_usd: routing.estimated_cost,
149 tokens_used: tokens,
150 lane_used: routing.lane.name.clone(),
151 },
152 })
153 }
154
155 fn retrieve_and_generate(
157 &self,
158 query: &RagQuery,
159 _lane: &Lane,
160 ) -> Result<(String, Vec<Citation>, usize), GatewayError> {
161 let answer = format!(
165 "Based on the sources, here is the answer to '{}': \
166 [Generated content respecting all policies]",
167 query.question
168 );
169
170 let citations = vec![
171 Citation {
172 source: query.sources.first().cloned().unwrap_or_default(),
173 excerpt: "Relevant excerpt from source...".to_string(),
174 relevance_score: 0.92,
175 }
176 ];
177
178 let tokens = 450; Ok((answer, citations, tokens))
181 }
182
183 fn generate_proof(
185 &self,
186 query: &RagQuery,
187 answer: &str,
188 _access_check: &AccessCheckResult,
189 ) -> Result<ProofCertificate, GatewayError> {
190 let claims = vec![
191 format!("access_granted(user={})", query.user_id),
192 format!("pii_masked(answer)"),
193 format!("sources_authorized({:?})", query.sources),
194 ];
195
196 Ok(ProofCertificate::new(
197 ProofKind::PolicyRespected,
198 claims,
199 answer,
200 ))
201 }
202
203 pub fn audit_log(&self) -> Arc<AuditLog> {
205 Arc::clone(&self.audit_log)
206 }
207}
208
209#[derive(Debug)]
210pub struct AccessCheckResult {
211 pub allowed: bool,
212 pub violation: PolicyViolation,
213}
214
215#[derive(Debug)]
216pub enum GatewayError {
217 PolicyViolation(PolicyViolation),
218 RoutingError(String),
219 RetrievalError(String),
220 ProofGenerationError(String),
221 Internal(String),
222}
223
224impl std::fmt::Display for GatewayError {
225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 match self {
227 GatewayError::PolicyViolation(v) => write!(f, "Policy violation: {:?}", v),
228 GatewayError::RoutingError(e) => write!(f, "Routing error: {}", e),
229 GatewayError::RetrievalError(e) => write!(f, "Retrieval error: {}", e),
230 GatewayError::ProofGenerationError(e) => write!(f, "Proof error: {}", e),
231 GatewayError::Internal(e) => write!(f, "Internal error: {}", e),
232 }
233 }
234}
235
236impl std::error::Error for GatewayError {}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the refund-policy query shared by several tests.
    fn refund_query(user_id: &str) -> RagQuery {
        RagQuery {
            question: "What is our refund policy?".to_string(),
            sources: vec!["policies.txt".to_string()],
            user_id: user_id.to_string(),
            latency_sla: Some(150),
            cost_budget: Some(0.01),
        }
    }

    /// An allowed user gets an answer within the SLA, with a non-empty proof,
    /// served from the "local" lane.
    #[test]
    fn test_basic_query() {
        let policies = vec![Policy::allow_user("user123"), Policy::mask_pii()];
        let mut gateway = RagGateway::new(policies);

        let response = gateway.process(refund_query("user123")).unwrap();

        assert!(response.metrics.latency_ms < 150);
        assert!(!response.proof.claims.is_empty());
        assert_eq!(response.metrics.lane_used, "local");
    }

    /// A denied user is rejected with `GatewayError::PolicyViolation`.
    #[test]
    fn test_policy_violation() {
        let policies = vec![Policy::deny_user("blocked_user")];
        let mut gateway = RagGateway::new(policies);

        let result = gateway.process(refund_query("blocked_user"));

        assert!(matches!(result, Err(GatewayError::PolicyViolation(_))));
    }

    /// PII in the question must never be echoed back in the answer.
    #[test]
    fn test_pii_masking() {
        let policies = vec![Policy::mask_pii()];
        let mut gateway = RagGateway::new(policies);

        let query = RagQuery {
            question: "My SSN is 123-45-6789".to_string(),
            sources: vec![],
            user_id: "user123".to_string(),
            latency_sla: None,
            cost_budget: None,
        };

        let response = gateway.process(query).unwrap();

        // The presence of a redaction marker is not sufficient on its own:
        // the raw SSN must be absent regardless of how it was masked.
        assert!(!response.answer.contains("123-45-6789"));
    }
}