Skip to main content

claw_guard/
grpc.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use tonic::transport::Server;
5use tonic::{Request, Response, Status};
6use uuid::Uuid;
7
8use crate::audit::AuditFilter;
9use crate::error::{GuardError, GuardResult};
10use crate::guard::Guard;
11use crate::proto::guard::{
12    guard_service_server::{GuardService, GuardServiceServer},
13    AddPolicyRequest, AddPolicyResponse, AuditEntry as AuditEntryMessage, CheckAccessRequest,
14    CheckAccessResponse, CreateSessionRequest, CreateSessionResponse, ListPoliciesRequest,
15    ListPoliciesResponse, Policy as PolicyMessage, QueryAuditLogRequest, QueryAuditLogResponse,
16    RemovePolicyRequest, RemovePolicyResponse, RevokeSessionRequest, RevokeSessionResponse,
17    ValidateSessionRequest, ValidateSessionResponse,
18};
19use crate::types::PolicyDecision;
20
21/// gRPC wrapper over the guard engine.
22#[derive(Clone)]
23pub struct GrpcGuardService {
24    guard: Arc<Guard>,
25}
26
27impl GrpcGuardService {
28    /// Creates a new gRPC service.
29    pub fn new(guard: Arc<Guard>) -> Self {
30        Self { guard }
31    }
32
33    /// Returns the tonic service implementation.
34    pub fn service(self) -> GuardServiceServer<Self> {
35        GuardServiceServer::new(self)
36    }
37}
38
39#[tonic::async_trait]
40impl GuardService for GrpcGuardService {
41    async fn check_access(
42        &self,
43        request: Request<CheckAccessRequest>,
44    ) -> Result<Response<CheckAccessResponse>, Status> {
45        let payload = request.into_inner();
46        let session = self
47            .guard
48            .sessions()
49            .validate_session(&payload.session_token)
50            .await?;
51        let decision = self
52            .guard
53            .check_access_with_task(&session, &payload.action, &payload.resource, &payload.task)
54            .await?;
55
56        let response = match decision {
57            PolicyDecision::Allow => CheckAccessResponse {
58                decision: "allow".to_owned(),
59                reason: String::new(),
60                masked_fields: Vec::new(),
61            },
62            PolicyDecision::Deny { reason } => CheckAccessResponse {
63                decision: "deny".to_owned(),
64                reason,
65                masked_fields: Vec::new(),
66            },
67            PolicyDecision::Mask { fields } => CheckAccessResponse {
68                decision: "mask".to_owned(),
69                reason: String::new(),
70                masked_fields: fields,
71            },
72        };
73        Ok(Response::new(response))
74    }
75
76    async fn create_session(
77        &self,
78        request: Request<CreateSessionRequest>,
79    ) -> Result<Response<CreateSessionResponse>, Status> {
80        let payload = request.into_inner();
81        let session = self
82            .guard
83            .sessions()
84            .create_session(
85                Uuid::parse_str(&payload.agent_id).map_err(invalid_argument)?,
86                Uuid::parse_str(&payload.workspace_id).map_err(invalid_argument)?,
87                &payload.role,
88                payload.scopes,
89                payload.ttl_secs as u64,
90            )
91            .await?;
92
93        Ok(Response::new(CreateSessionResponse {
94            session_id: session.id.to_string(),
95            agent_id: session.agent_id.to_string(),
96            workspace_id: session.workspace_id.to_string(),
97            role: session.role,
98            scopes: session.scopes,
99            token: session.token,
100            expires_at: session.expires_at.timestamp(),
101        }))
102    }
103
104    async fn validate_session(
105        &self,
106        request: Request<ValidateSessionRequest>,
107    ) -> Result<Response<ValidateSessionResponse>, Status> {
108        let session = self
109            .guard
110            .sessions()
111            .validate_session(&request.into_inner().token)
112            .await;
113        let response = match session {
114            Ok(session) => ValidateSessionResponse {
115                valid: true,
116                session_id: session.id.to_string(),
117                agent_id: session.agent_id.to_string(),
118                workspace_id: session.workspace_id.to_string(),
119                role: session.role,
120                scopes: session.scopes,
121                expires_at: session.expires_at.timestamp(),
122            },
123            Err(_) => ValidateSessionResponse {
124                valid: false,
125                session_id: String::new(),
126                agent_id: String::new(),
127                workspace_id: String::new(),
128                role: String::new(),
129                scopes: Vec::new(),
130                expires_at: 0,
131            },
132        };
133        Ok(Response::new(response))
134    }
135
136    async fn revoke_session(
137        &self,
138        request: Request<RevokeSessionRequest>,
139    ) -> Result<Response<RevokeSessionResponse>, Status> {
140        let session_id =
141            Uuid::parse_str(&request.into_inner().session_id).map_err(invalid_argument)?;
142        self.guard.sessions().revoke_session(session_id).await?;
143        Ok(Response::new(RevokeSessionResponse { revoked: true }))
144    }
145
146    async fn add_policy(
147        &self,
148        request: Request<AddPolicyRequest>,
149    ) -> Result<Response<AddPolicyResponse>, Status> {
150        let payload = request.into_inner();
151        let policy = self
152            .guard
153            .policy_engine()
154            .add_policy_from_toml(&payload.policy_toml, &payload.name)
155            .await?;
156        Ok(Response::new(AddPolicyResponse {
157            policy: Some(policy_to_proto(policy)?),
158        }))
159    }
160
161    async fn list_policies(
162        &self,
163        _request: Request<ListPoliciesRequest>,
164    ) -> Result<Response<ListPoliciesResponse>, Status> {
165        let policies = self
166            .guard
167            .policy_engine()
168            .list_policies()
169            .await?
170            .into_iter()
171            .map(policy_to_proto)
172            .collect::<Result<Vec<_>, _>>()?;
173        Ok(Response::new(ListPoliciesResponse { policies }))
174    }
175
176    async fn remove_policy(
177        &self,
178        request: Request<RemovePolicyRequest>,
179    ) -> Result<Response<RemovePolicyResponse>, Status> {
180        let policy_id =
181            Uuid::parse_str(&request.into_inner().policy_id).map_err(invalid_argument)?;
182        self.guard.policy_engine().remove_policy(policy_id).await?;
183        Ok(Response::new(RemovePolicyResponse { removed: true }))
184    }
185
186    async fn query_audit_log(
187        &self,
188        request: Request<QueryAuditLogRequest>,
189    ) -> Result<Response<QueryAuditLogResponse>, Status> {
190        let payload = request.into_inner();
191        let entries = self
192            .guard
193            .query_audit(AuditFilter {
194                workspace_id: parse_optional_uuid(&payload.workspace_id)?,
195                session_id: parse_optional_uuid(&payload.session_id)?,
196                decision: (!payload.decision.is_empty()).then_some(payload.decision),
197                start_time: (payload.start_ts > 0).then_some(from_secs(payload.start_ts)?),
198                end_time: (payload.end_ts > 0).then_some(from_secs(payload.end_ts)?),
199                resource: (!payload.resource.is_empty()).then_some(payload.resource),
200                limit: (payload.limit > 0).then_some(payload.limit),
201            })
202            .await?;
203
204        Ok(Response::new(QueryAuditLogResponse {
205            entries: entries
206                .into_iter()
207                .map(|entry| AuditEntryMessage {
208                    id: entry.id.to_string(),
209                    session_id: entry
210                        .session_id
211                        .map(|value| value.to_string())
212                        .unwrap_or_default(),
213                    workspace_id: entry.workspace_id.to_string(),
214                    agent_id: entry
215                        .agent_id
216                        .map(|value| value.to_string())
217                        .unwrap_or_default(),
218                    action: entry.action,
219                    resource: entry.resource,
220                    resource_id: entry.resource_id.unwrap_or_default(),
221                    decision: entry.decision,
222                    reason: entry.reason.unwrap_or_default(),
223                    risk_score: entry.risk_score,
224                    metadata_json: serde_json::to_string(&entry.metadata)
225                        .unwrap_or_else(|_| "{}".to_owned()),
226                    ts: entry.ts.timestamp(),
227                })
228                .collect(),
229        }))
230    }
231}
232
233/// Starts the gRPC server.
234pub async fn serve(guard: Arc<Guard>, addr: SocketAddr) -> GuardResult<()> {
235    Server::builder()
236        .add_service(GrpcGuardService::new(guard).service())
237        .serve(addr)
238        .await
239        .map_err(|error| GuardError::ConfigError(format!("gRPC server failed: {error}")))
240}
241
242fn parse_optional_uuid(value: &str) -> Result<Option<Uuid>, Status> {
243    if value.is_empty() {
244        Ok(None)
245    } else {
246        Uuid::parse_str(value).map(Some).map_err(invalid_argument)
247    }
248}
249
250fn invalid_argument(error: impl std::fmt::Display) -> Status {
251    Status::invalid_argument(error.to_string())
252}
253
254fn from_secs(value: i64) -> Result<chrono::DateTime<chrono::Utc>, Status> {
255    chrono::DateTime::from_timestamp(value, 0)
256        .ok_or_else(|| Status::invalid_argument("invalid timestamp"))
257}
258
259fn policy_to_proto(policy: crate::policy::Policy) -> Result<PolicyMessage, Status> {
260    Ok(PolicyMessage {
261        id: policy.id.to_string(),
262        name: policy.name,
263        description: policy.description.unwrap_or_default(),
264        priority: policy.priority,
265        enabled: policy.enabled,
266        rules_json: serde_json::to_string(&policy.rules).map_err(|error| {
267            Status::internal(format!("failed to serialize policy rules: {error}"))
268        })?,
269    })
270}