// dk_protocol/server.rs

use std::sync::Arc;
use std::time::Duration;

use dk_engine::conflict::{ClaimTracker, LocalClaimTracker};
use dk_engine::repo::Engine;
use tonic::{Request, Response, Status};

use crate::auth::AuthConfig;
use crate::events::EventBus;
use crate::session::{AgentSession, SessionManager};
11/// The gRPC server that implements the `AgentService` trait generated by
12/// tonic from the `agent.proto` definition.
13///
14/// Holds a shared reference to the [`Engine`] (internally concurrent via
15/// fine-grained locks) and a [`SessionManager`] for stateful agent sessions.
16pub struct ProtocolServer {
17    pub(crate) engine: Arc<Engine>,
18    pub(crate) session_mgr: Arc<SessionManager>,
19    pub(crate) auth_config: AuthConfig,
20    pub(crate) event_bus: Arc<EventBus>,
21    pub(crate) claim_tracker: Arc<dyn ClaimTracker>,
22}
23
24impl ProtocolServer {
25    /// Create a new `ProtocolServer`.
26    ///
27    /// `auth_config` controls how agents authenticate on every
28    /// `ConnectRequest`.  The session timeout is fixed at 30 minutes.
29    pub fn new(engine: Arc<Engine>, auth_config: AuthConfig) -> Self {
30        Self {
31            engine,
32            session_mgr: Arc::new(SessionManager::new(std::time::Duration::from_secs(
33                30 * 60,
34            ))),
35            auth_config,
36            event_bus: Arc::new(EventBus::new()),
37            claim_tracker: Arc::new(LocalClaimTracker::new()),
38        }
39    }
40
41    /// Borrow the engine.
42    pub fn engine(&self) -> &Engine {
43        &self.engine
44    }
45
46    /// Borrow the session manager.
47    pub fn session_mgr(&self) -> &SessionManager {
48        &self.session_mgr
49    }
50
51    /// Borrow the shared event bus.
52    pub fn event_bus(&self) -> &EventBus {
53        &self.event_bus
54    }
55
56    /// Return a cloned `Arc` handle to the shared event bus.
57    ///
58    /// Callers that need to publish to the engine event bus from a spawned
59    /// task (where a borrow of `&ProtocolServer` is not available) should
60    /// use this to obtain a cheaply-clonable, ownership-safe handle.
61    pub fn event_bus_arc(&self) -> Arc<EventBus> {
62        Arc::clone(&self.event_bus)
63    }
64
65    /// Borrow the shared symbol claim tracker.
66    pub fn claim_tracker(&self) -> &dyn ClaimTracker {
67        &*self.claim_tracker
68    }
69
70    /// Create a `ProtocolServer` with a custom claim tracker implementation.
71    pub fn with_claim_tracker(
72        engine: Arc<Engine>,
73        auth_config: AuthConfig,
74        claim_tracker: Arc<dyn ClaimTracker>,
75    ) -> Self {
76        Self {
77            engine,
78            session_mgr: Arc::new(SessionManager::new(std::time::Duration::from_secs(
79                30 * 60,
80            ))),
81            auth_config,
82            event_bus: Arc::new(EventBus::new()),
83            claim_tracker,
84        }
85    }
86
87    /// Validate an auth token against the configured secret.
88    pub(crate) fn validate_auth(&self, token: &str) -> Result<String, Status> {
89        self.auth_config.validate(token)
90    }
91
92    /// Look up a session by its string-encoded UUID.  Returns an error if the
93    /// ID is malformed or the session has expired / does not exist.
94    pub(crate) fn validate_session(&self, session_id_str: &str) -> Result<AgentSession, Status> {
95        let sid = session_id_str
96            .parse::<uuid::Uuid>()
97            .map_err(|_| Status::invalid_argument("Invalid session ID format"))?;
98        self.session_mgr
99            .get_session(&sid)
100            .ok_or_else(|| Status::not_found("Session not found or expired"))
101    }
102}
103
104// ── AgentService tonic trait implementation ──
105
106#[tonic::async_trait]
107impl crate::agent_service_server::AgentService for ProtocolServer {
108    async fn connect(
109        &self,
110        request: Request<crate::ConnectRequest>,
111    ) -> Result<Response<crate::ConnectResponse>, Status> {
112        crate::connect::handle_connect(self, request.into_inner()).await
113    }
114
115    async fn context(
116        &self,
117        request: Request<crate::ContextRequest>,
118    ) -> Result<Response<crate::ContextResponse>, Status> {
119        crate::context::handle_context(self, request.into_inner()).await
120    }
121
122    async fn submit(
123        &self,
124        request: Request<crate::SubmitRequest>,
125    ) -> Result<Response<crate::SubmitResponse>, Status> {
126        crate::submit::handle_submit(self, request.into_inner()).await
127    }
128
129    type VerifyStream = tokio_stream::wrappers::ReceiverStream<Result<crate::VerifyStepResult, Status>>;
130
131    async fn verify(
132        &self,
133        request: Request<crate::VerifyRequest>,
134    ) -> Result<Response<Self::VerifyStream>, Status> {
135        let req = request.into_inner();
136        let (tx, rx) = tokio::sync::mpsc::channel(32);
137
138        let server_clone = ProtocolServer {
139            engine: self.engine.clone(),
140            session_mgr: self.session_mgr.clone(),
141            auth_config: self.auth_config.clone(),
142            event_bus: self.event_bus.clone(),
143            claim_tracker: self.claim_tracker.clone(),
144        };
145
146        tokio::spawn(async move {
147            crate::verify::handle_verify(&server_clone, req, tx).await;
148        });
149
150        Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new(rx)))
151    }
152
153    async fn merge(
154        &self,
155        request: Request<crate::MergeRequest>,
156    ) -> Result<Response<crate::MergeResponse>, Status> {
157        let resp = crate::merge::handle_merge(self, request.into_inner()).await?;
158        Ok(Response::new(resp))
159    }
160
161    type WatchStream = tokio_stream::wrappers::ReceiverStream<Result<crate::WatchEvent, Status>>;
162
163    async fn watch(
164        &self,
165        request: Request<crate::WatchRequest>,
166    ) -> Result<Response<Self::WatchStream>, Status> {
167        let req = request.into_inner();
168        let (tx, rx) = tokio::sync::mpsc::channel(64);
169        let server_clone = ProtocolServer {
170            engine: self.engine.clone(),
171            session_mgr: self.session_mgr.clone(),
172            auth_config: self.auth_config.clone(),
173            event_bus: self.event_bus.clone(),
174            claim_tracker: self.claim_tracker.clone(),
175        };
176        tokio::spawn(async move {
177            crate::watch::handle_watch(&server_clone, req, tx).await;
178        });
179        Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new(rx)))
180    }
181
182    async fn file_read(
183        &self,
184        request: Request<crate::FileReadRequest>,
185    ) -> Result<Response<crate::FileReadResponse>, Status> {
186        crate::file_read::handle_file_read(self, request.into_inner()).await
187    }
188
189    async fn file_write(
190        &self,
191        request: Request<crate::FileWriteRequest>,
192    ) -> Result<Response<crate::FileWriteResponse>, Status> {
193        crate::file_write::handle_file_write(self, request.into_inner()).await
194    }
195
196    async fn file_list(
197        &self,
198        request: Request<crate::FileListRequest>,
199    ) -> Result<Response<crate::FileListResponse>, Status> {
200        crate::file_list::handle_file_list(self, request.into_inner()).await
201    }
202
203    async fn pre_submit_check(
204        &self,
205        request: Request<crate::PreSubmitCheckRequest>,
206    ) -> Result<Response<crate::PreSubmitCheckResponse>, Status> {
207        crate::pre_submit::handle_pre_submit_check(self, request.into_inner()).await
208    }
209
210    async fn get_session_status(
211        &self,
212        request: Request<crate::SessionStatusRequest>,
213    ) -> Result<Response<crate::SessionStatusResponse>, Status> {
214        crate::session_status::handle_get_session_status(self, request.into_inner()).await
215    }
216
217    async fn push(
218        &self,
219        request: Request<crate::PushRequest>,
220    ) -> Result<Response<crate::PushResponse>, Status> {
221        let resp = crate::push::handle_push(self, request.into_inner()).await?;
222        Ok(Response::new(resp))
223    }
224
225    async fn approve(
226        &self,
227        _request: Request<crate::ApproveRequest>,
228    ) -> Result<Response<crate::ApproveResponse>, Status> {
229        Err(Status::unimplemented(
230            "approve is a platform-level operation; use the managed server",
231        ))
232    }
233
234    async fn resolve(
235        &self,
236        _request: Request<crate::ResolveRequest>,
237    ) -> Result<Response<crate::ResolveResponse>, Status> {
238        Err(Status::unimplemented(
239            "resolve is a platform-level operation; use the managed server",
240        ))
241    }
242
243    async fn close(
244        &self,
245        _request: Request<crate::CloseRequest>,
246    ) -> Result<Response<crate::CloseResponse>, Status> {
247        Err(Status::unimplemented(
248            "close is a platform-level operation; use the managed server",
249        ))
250    }
251
252    async fn review(
253        &self,
254        _request: Request<crate::ReviewRequest>,
255    ) -> Result<Response<crate::ReviewResponse>, Status> {
256        Err(Status::unimplemented(
257            "review is a platform-level operation; use the managed server",
258        ))
259    }
260}