1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use tonic::transport::Server;
5use tonic::{Request, Response, Status};
6use uuid::Uuid;
7
8use crate::audit::AuditFilter;
9use crate::error::{GuardError, GuardResult};
10use crate::guard::Guard;
11use crate::proto::guard::{
12 guard_service_server::{GuardService, GuardServiceServer},
13 AddPolicyRequest, AddPolicyResponse, AuditEntry as AuditEntryMessage, CheckAccessRequest,
14 CheckAccessResponse, CreateSessionRequest, CreateSessionResponse, ListPoliciesRequest,
15 ListPoliciesResponse, Policy as PolicyMessage, QueryAuditLogRequest, QueryAuditLogResponse,
16 RemovePolicyRequest, RemovePolicyResponse, RevokeSessionRequest, RevokeSessionResponse,
17 ValidateSessionRequest, ValidateSessionResponse,
18};
19use crate::types::PolicyDecision;
20
21#[derive(Clone)]
23pub struct GrpcGuardService {
24 guard: Arc<Guard>,
25}
26
27impl GrpcGuardService {
28 pub fn new(guard: Arc<Guard>) -> Self {
30 Self { guard }
31 }
32
33 pub fn service(self) -> GuardServiceServer<Self> {
35 GuardServiceServer::new(self)
36 }
37}
38
39#[tonic::async_trait]
40impl GuardService for GrpcGuardService {
41 async fn check_access(
42 &self,
43 request: Request<CheckAccessRequest>,
44 ) -> Result<Response<CheckAccessResponse>, Status> {
45 let payload = request.into_inner();
46 let session = self
47 .guard
48 .sessions()
49 .validate_session(&payload.session_token)
50 .await?;
51 let decision = self
52 .guard
53 .check_access_with_task(&session, &payload.action, &payload.resource, &payload.task)
54 .await?;
55
56 let response = match decision {
57 PolicyDecision::Allow => CheckAccessResponse {
58 decision: "allow".to_owned(),
59 reason: String::new(),
60 masked_fields: Vec::new(),
61 },
62 PolicyDecision::Deny { reason } => CheckAccessResponse {
63 decision: "deny".to_owned(),
64 reason,
65 masked_fields: Vec::new(),
66 },
67 PolicyDecision::Mask { fields } => CheckAccessResponse {
68 decision: "mask".to_owned(),
69 reason: String::new(),
70 masked_fields: fields,
71 },
72 };
73 Ok(Response::new(response))
74 }
75
76 async fn create_session(
77 &self,
78 request: Request<CreateSessionRequest>,
79 ) -> Result<Response<CreateSessionResponse>, Status> {
80 let payload = request.into_inner();
81 let session = self
82 .guard
83 .sessions()
84 .create_session(
85 Uuid::parse_str(&payload.agent_id).map_err(invalid_argument)?,
86 Uuid::parse_str(&payload.workspace_id).map_err(invalid_argument)?,
87 &payload.role,
88 payload.scopes,
89 payload.ttl_secs as u64,
90 )
91 .await?;
92
93 Ok(Response::new(CreateSessionResponse {
94 session_id: session.id.to_string(),
95 agent_id: session.agent_id.to_string(),
96 workspace_id: session.workspace_id.to_string(),
97 role: session.role,
98 scopes: session.scopes,
99 token: session.token,
100 expires_at: session.expires_at.timestamp(),
101 }))
102 }
103
104 async fn validate_session(
105 &self,
106 request: Request<ValidateSessionRequest>,
107 ) -> Result<Response<ValidateSessionResponse>, Status> {
108 let session = self
109 .guard
110 .sessions()
111 .validate_session(&request.into_inner().token)
112 .await;
113 let response = match session {
114 Ok(session) => ValidateSessionResponse {
115 valid: true,
116 session_id: session.id.to_string(),
117 agent_id: session.agent_id.to_string(),
118 workspace_id: session.workspace_id.to_string(),
119 role: session.role,
120 scopes: session.scopes,
121 expires_at: session.expires_at.timestamp(),
122 },
123 Err(_) => ValidateSessionResponse {
124 valid: false,
125 session_id: String::new(),
126 agent_id: String::new(),
127 workspace_id: String::new(),
128 role: String::new(),
129 scopes: Vec::new(),
130 expires_at: 0,
131 },
132 };
133 Ok(Response::new(response))
134 }
135
136 async fn revoke_session(
137 &self,
138 request: Request<RevokeSessionRequest>,
139 ) -> Result<Response<RevokeSessionResponse>, Status> {
140 let session_id =
141 Uuid::parse_str(&request.into_inner().session_id).map_err(invalid_argument)?;
142 self.guard.sessions().revoke_session(session_id).await?;
143 Ok(Response::new(RevokeSessionResponse { revoked: true }))
144 }
145
146 async fn add_policy(
147 &self,
148 request: Request<AddPolicyRequest>,
149 ) -> Result<Response<AddPolicyResponse>, Status> {
150 let payload = request.into_inner();
151 let policy = self
152 .guard
153 .policy_engine()
154 .add_policy_from_toml(&payload.policy_toml, &payload.name)
155 .await?;
156 Ok(Response::new(AddPolicyResponse {
157 policy: Some(policy_to_proto(policy)?),
158 }))
159 }
160
161 async fn list_policies(
162 &self,
163 _request: Request<ListPoliciesRequest>,
164 ) -> Result<Response<ListPoliciesResponse>, Status> {
165 let policies = self
166 .guard
167 .policy_engine()
168 .list_policies()
169 .await?
170 .into_iter()
171 .map(policy_to_proto)
172 .collect::<Result<Vec<_>, _>>()?;
173 Ok(Response::new(ListPoliciesResponse { policies }))
174 }
175
176 async fn remove_policy(
177 &self,
178 request: Request<RemovePolicyRequest>,
179 ) -> Result<Response<RemovePolicyResponse>, Status> {
180 let policy_id =
181 Uuid::parse_str(&request.into_inner().policy_id).map_err(invalid_argument)?;
182 self.guard.policy_engine().remove_policy(policy_id).await?;
183 Ok(Response::new(RemovePolicyResponse { removed: true }))
184 }
185
186 async fn query_audit_log(
187 &self,
188 request: Request<QueryAuditLogRequest>,
189 ) -> Result<Response<QueryAuditLogResponse>, Status> {
190 let payload = request.into_inner();
191 let entries = self
192 .guard
193 .query_audit(AuditFilter {
194 workspace_id: parse_optional_uuid(&payload.workspace_id)?,
195 session_id: parse_optional_uuid(&payload.session_id)?,
196 decision: (!payload.decision.is_empty()).then_some(payload.decision),
197 start_time: (payload.start_ts > 0).then_some(from_secs(payload.start_ts)?),
198 end_time: (payload.end_ts > 0).then_some(from_secs(payload.end_ts)?),
199 resource: (!payload.resource.is_empty()).then_some(payload.resource),
200 limit: (payload.limit > 0).then_some(payload.limit),
201 })
202 .await?;
203
204 Ok(Response::new(QueryAuditLogResponse {
205 entries: entries
206 .into_iter()
207 .map(|entry| AuditEntryMessage {
208 id: entry.id.to_string(),
209 session_id: entry
210 .session_id
211 .map(|value| value.to_string())
212 .unwrap_or_default(),
213 workspace_id: entry.workspace_id.to_string(),
214 agent_id: entry
215 .agent_id
216 .map(|value| value.to_string())
217 .unwrap_or_default(),
218 action: entry.action,
219 resource: entry.resource,
220 resource_id: entry.resource_id.unwrap_or_default(),
221 decision: entry.decision,
222 reason: entry.reason.unwrap_or_default(),
223 risk_score: entry.risk_score,
224 metadata_json: serde_json::to_string(&entry.metadata)
225 .unwrap_or_else(|_| "{}".to_owned()),
226 ts: entry.ts.timestamp(),
227 })
228 .collect(),
229 }))
230 }
231}
232
233pub async fn serve(guard: Arc<Guard>, addr: SocketAddr) -> GuardResult<()> {
235 Server::builder()
236 .add_service(GrpcGuardService::new(guard).service())
237 .serve(addr)
238 .await
239 .map_err(|error| GuardError::ConfigError(format!("gRPC server failed: {error}")))
240}
241
242fn parse_optional_uuid(value: &str) -> Result<Option<Uuid>, Status> {
243 if value.is_empty() {
244 Ok(None)
245 } else {
246 Uuid::parse_str(value).map(Some).map_err(invalid_argument)
247 }
248}
249
250fn invalid_argument(error: impl std::fmt::Display) -> Status {
251 Status::invalid_argument(error.to_string())
252}
253
254fn from_secs(value: i64) -> Result<chrono::DateTime<chrono::Utc>, Status> {
255 chrono::DateTime::from_timestamp(value, 0)
256 .ok_or_else(|| Status::invalid_argument("invalid timestamp"))
257}
258
259fn policy_to_proto(policy: crate::policy::Policy) -> Result<PolicyMessage, Status> {
260 Ok(PolicyMessage {
261 id: policy.id.to_string(),
262 name: policy.name,
263 description: policy.description.unwrap_or_default(),
264 priority: policy.priority,
265 enabled: policy.enabled,
266 rules_json: serde_json::to_string(&policy.rules).map_err(|error| {
267 Status::internal(format!("failed to serialize policy rules: {error}"))
268 })?,
269 })
270}