Skip to main content

dk_protocol/
server.rs

1use std::sync::Arc;
2
3use dk_engine::conflict::SymbolClaimTracker;
4use dk_engine::repo::Engine;
5use tonic::{Request, Response, Status};
6
7use crate::auth::AuthConfig;
8use crate::events::EventBus;
9use crate::session::{AgentSession, SessionManager};
10
11/// The gRPC server that implements the `AgentService` trait generated by
12/// tonic from the `agent.proto` definition.
13///
14/// Holds a shared reference to the [`Engine`] (internally concurrent via
15/// fine-grained locks) and a [`SessionManager`] for stateful agent sessions.
16pub struct ProtocolServer {
17    pub(crate) engine: Arc<Engine>,
18    pub(crate) session_mgr: Arc<SessionManager>,
19    pub(crate) auth_config: AuthConfig,
20    pub(crate) event_bus: Arc<EventBus>,
21    pub(crate) claim_tracker: Arc<SymbolClaimTracker>,
22}
23
24impl ProtocolServer {
25    /// Create a new `ProtocolServer`.
26    ///
27    /// `auth_config` controls how agents authenticate on every
28    /// `ConnectRequest`.  The session timeout is fixed at 30 minutes.
29    pub fn new(engine: Arc<Engine>, auth_config: AuthConfig) -> Self {
30        Self {
31            engine,
32            session_mgr: Arc::new(SessionManager::new(std::time::Duration::from_secs(
33                30 * 60,
34            ))),
35            auth_config,
36            event_bus: Arc::new(EventBus::new()),
37            claim_tracker: Arc::new(SymbolClaimTracker::new()),
38        }
39    }
40
41    /// Borrow the engine.
42    pub fn engine(&self) -> &Engine {
43        &self.engine
44    }
45
46    /// Borrow the session manager.
47    pub fn session_mgr(&self) -> &SessionManager {
48        &self.session_mgr
49    }
50
51    /// Borrow the shared event bus.
52    pub fn event_bus(&self) -> &EventBus {
53        &self.event_bus
54    }
55
56    /// Return a cloned `Arc` handle to the shared event bus.
57    ///
58    /// Callers that need to publish to the engine event bus from a spawned
59    /// task (where a borrow of `&ProtocolServer` is not available) should
60    /// use this to obtain a cheaply-clonable, ownership-safe handle.
61    pub fn event_bus_arc(&self) -> Arc<EventBus> {
62        Arc::clone(&self.event_bus)
63    }
64
65    /// Borrow the shared symbol claim tracker.
66    pub fn claim_tracker(&self) -> &SymbolClaimTracker {
67        &self.claim_tracker
68    }
69
70    /// Validate an auth token against the configured secret.
71    pub(crate) fn validate_auth(&self, token: &str) -> Result<String, Status> {
72        self.auth_config.validate(token)
73    }
74
75    /// Look up a session by its string-encoded UUID.  Returns an error if the
76    /// ID is malformed or the session has expired / does not exist.
77    pub(crate) fn validate_session(&self, session_id_str: &str) -> Result<AgentSession, Status> {
78        let sid = session_id_str
79            .parse::<uuid::Uuid>()
80            .map_err(|_| Status::invalid_argument("Invalid session ID format"))?;
81        self.session_mgr
82            .get_session(&sid)
83            .ok_or_else(|| Status::not_found("Session not found or expired"))
84    }
85}
86
87// ── AgentService tonic trait implementation ──
88
89#[tonic::async_trait]
90impl crate::agent_service_server::AgentService for ProtocolServer {
91    async fn connect(
92        &self,
93        request: Request<crate::ConnectRequest>,
94    ) -> Result<Response<crate::ConnectResponse>, Status> {
95        crate::connect::handle_connect(self, request.into_inner()).await
96    }
97
98    async fn context(
99        &self,
100        request: Request<crate::ContextRequest>,
101    ) -> Result<Response<crate::ContextResponse>, Status> {
102        crate::context::handle_context(self, request.into_inner()).await
103    }
104
105    async fn submit(
106        &self,
107        request: Request<crate::SubmitRequest>,
108    ) -> Result<Response<crate::SubmitResponse>, Status> {
109        crate::submit::handle_submit(self, request.into_inner()).await
110    }
111
112    type VerifyStream = tokio_stream::wrappers::ReceiverStream<Result<crate::VerifyStepResult, Status>>;
113
114    async fn verify(
115        &self,
116        request: Request<crate::VerifyRequest>,
117    ) -> Result<Response<Self::VerifyStream>, Status> {
118        let req = request.into_inner();
119        let (tx, rx) = tokio::sync::mpsc::channel(32);
120
121        let server_clone = ProtocolServer {
122            engine: self.engine.clone(),
123            session_mgr: self.session_mgr.clone(),
124            auth_config: self.auth_config.clone(),
125            event_bus: self.event_bus.clone(),
126            claim_tracker: self.claim_tracker.clone(),
127        };
128
129        tokio::spawn(async move {
130            crate::verify::handle_verify(&server_clone, req, tx).await;
131        });
132
133        Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new(rx)))
134    }
135
136    async fn merge(
137        &self,
138        request: Request<crate::MergeRequest>,
139    ) -> Result<Response<crate::MergeResponse>, Status> {
140        let resp = crate::merge::handle_merge(self, request.into_inner()).await?;
141        Ok(Response::new(resp))
142    }
143
144    type WatchStream = tokio_stream::wrappers::ReceiverStream<Result<crate::WatchEvent, Status>>;
145
146    async fn watch(
147        &self,
148        request: Request<crate::WatchRequest>,
149    ) -> Result<Response<Self::WatchStream>, Status> {
150        let req = request.into_inner();
151        let (tx, rx) = tokio::sync::mpsc::channel(64);
152        let server_clone = ProtocolServer {
153            engine: self.engine.clone(),
154            session_mgr: self.session_mgr.clone(),
155            auth_config: self.auth_config.clone(),
156            event_bus: self.event_bus.clone(),
157            claim_tracker: self.claim_tracker.clone(),
158        };
159        tokio::spawn(async move {
160            crate::watch::handle_watch(&server_clone, req, tx).await;
161        });
162        Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new(rx)))
163    }
164
165    async fn file_read(
166        &self,
167        request: Request<crate::FileReadRequest>,
168    ) -> Result<Response<crate::FileReadResponse>, Status> {
169        crate::file_read::handle_file_read(self, request.into_inner()).await
170    }
171
172    async fn file_write(
173        &self,
174        request: Request<crate::FileWriteRequest>,
175    ) -> Result<Response<crate::FileWriteResponse>, Status> {
176        crate::file_write::handle_file_write(self, request.into_inner()).await
177    }
178
179    async fn file_list(
180        &self,
181        request: Request<crate::FileListRequest>,
182    ) -> Result<Response<crate::FileListResponse>, Status> {
183        crate::file_list::handle_file_list(self, request.into_inner()).await
184    }
185
186    async fn pre_submit_check(
187        &self,
188        request: Request<crate::PreSubmitCheckRequest>,
189    ) -> Result<Response<crate::PreSubmitCheckResponse>, Status> {
190        crate::pre_submit::handle_pre_submit_check(self, request.into_inner()).await
191    }
192
193    async fn get_session_status(
194        &self,
195        request: Request<crate::SessionStatusRequest>,
196    ) -> Result<Response<crate::SessionStatusResponse>, Status> {
197        crate::session_status::handle_get_session_status(self, request.into_inner()).await
198    }
199
200    async fn push(
201        &self,
202        request: Request<crate::PushRequest>,
203    ) -> Result<Response<crate::PushResponse>, Status> {
204        let resp = crate::push::handle_push(self, request.into_inner()).await?;
205        Ok(Response::new(resp))
206    }
207
208    async fn approve(
209        &self,
210        _request: Request<crate::ApproveRequest>,
211    ) -> Result<Response<crate::ApproveResponse>, Status> {
212        Err(Status::unimplemented(
213            "approve is a platform-level operation; use the managed server",
214        ))
215    }
216
217    async fn resolve(
218        &self,
219        _request: Request<crate::ResolveRequest>,
220    ) -> Result<Response<crate::ResolveResponse>, Status> {
221        Err(Status::unimplemented(
222            "resolve is a platform-level operation; use the managed server",
223        ))
224    }
225}