Skip to main content

dk_protocol/
server.rs

1use std::sync::Arc;
2
3use dk_engine::conflict::SymbolClaimTracker;
4use dk_engine::repo::Engine;
5use tonic::{Request, Response, Status};
6
7use crate::auth::AuthConfig;
8use crate::events::EventBus;
9use crate::session::{AgentSession, SessionManager};
10
/// The gRPC server that implements the `AgentService` trait generated by
/// tonic from the `agent.proto` definition.
///
/// Holds a shared reference to the [`Engine`] (internally concurrent via
/// fine-grained locks) and a [`SessionManager`] for stateful agent sessions.
///
/// Every field except `auth_config` is an [`Arc`] handle. The streaming RPCs
/// (`verify`, `watch`) depend on that: they build a second `ProtocolServer`
/// from clones of these fields and move it into a spawned task.
pub struct ProtocolServer {
    /// Shared handle to the engine; borrowed via [`ProtocolServer::engine`].
    pub(crate) engine: Arc<Engine>,
    /// Session store; `validate_session` looks sessions up here by UUID.
    pub(crate) session_mgr: Arc<SessionManager>,
    /// Authentication settings; `validate_auth` delegates to
    /// `AuthConfig::validate` on this value.
    pub(crate) auth_config: AuthConfig,
    /// Shared event bus; borrowed via [`ProtocolServer::event_bus`].
    pub(crate) event_bus: Arc<EventBus>,
    /// Shared symbol claim tracker; borrowed via
    /// [`ProtocolServer::claim_tracker`].
    pub(crate) claim_tracker: Arc<SymbolClaimTracker>,
}
23
24impl ProtocolServer {
25    /// Create a new `ProtocolServer`.
26    ///
27    /// `auth_config` controls how agents authenticate on every
28    /// `ConnectRequest`.  The session timeout is fixed at 30 minutes.
29    pub fn new(engine: Arc<Engine>, auth_config: AuthConfig) -> Self {
30        Self {
31            engine,
32            session_mgr: Arc::new(SessionManager::new(std::time::Duration::from_secs(
33                30 * 60,
34            ))),
35            auth_config,
36            event_bus: Arc::new(EventBus::new()),
37            claim_tracker: Arc::new(SymbolClaimTracker::new()),
38        }
39    }
40
41    /// Borrow the engine.
42    pub fn engine(&self) -> &Engine {
43        &self.engine
44    }
45
46    /// Borrow the session manager.
47    pub fn session_mgr(&self) -> &SessionManager {
48        &self.session_mgr
49    }
50
51    /// Borrow the shared event bus.
52    pub fn event_bus(&self) -> &EventBus {
53        &self.event_bus
54    }
55
56    /// Borrow the shared symbol claim tracker.
57    pub fn claim_tracker(&self) -> &SymbolClaimTracker {
58        &self.claim_tracker
59    }
60
61    /// Validate an auth token against the configured secret.
62    pub(crate) fn validate_auth(&self, token: &str) -> Result<String, Status> {
63        self.auth_config.validate(token)
64    }
65
66    /// Look up a session by its string-encoded UUID.  Returns an error if the
67    /// ID is malformed or the session has expired / does not exist.
68    pub(crate) fn validate_session(&self, session_id_str: &str) -> Result<AgentSession, Status> {
69        let sid = session_id_str
70            .parse::<uuid::Uuid>()
71            .map_err(|_| Status::invalid_argument("Invalid session ID format"))?;
72        self.session_mgr
73            .get_session(&sid)
74            .ok_or_else(|| Status::not_found("Session not found or expired"))
75    }
76}
77
78// ── AgentService tonic trait implementation ──
79
80#[tonic::async_trait]
81impl crate::agent_service_server::AgentService for ProtocolServer {
82    async fn connect(
83        &self,
84        request: Request<crate::ConnectRequest>,
85    ) -> Result<Response<crate::ConnectResponse>, Status> {
86        crate::connect::handle_connect(self, request.into_inner()).await
87    }
88
89    async fn context(
90        &self,
91        request: Request<crate::ContextRequest>,
92    ) -> Result<Response<crate::ContextResponse>, Status> {
93        crate::context::handle_context(self, request.into_inner()).await
94    }
95
96    async fn submit(
97        &self,
98        request: Request<crate::SubmitRequest>,
99    ) -> Result<Response<crate::SubmitResponse>, Status> {
100        crate::submit::handle_submit(self, request.into_inner()).await
101    }
102
103    type VerifyStream = tokio_stream::wrappers::ReceiverStream<Result<crate::VerifyStepResult, Status>>;
104
105    async fn verify(
106        &self,
107        request: Request<crate::VerifyRequest>,
108    ) -> Result<Response<Self::VerifyStream>, Status> {
109        let req = request.into_inner();
110        let (tx, rx) = tokio::sync::mpsc::channel(32);
111
112        let server_clone = ProtocolServer {
113            engine: self.engine.clone(),
114            session_mgr: self.session_mgr.clone(),
115            auth_config: self.auth_config.clone(),
116            event_bus: self.event_bus.clone(),
117            claim_tracker: self.claim_tracker.clone(),
118        };
119
120        tokio::spawn(async move {
121            crate::verify::handle_verify(&server_clone, req, tx).await;
122        });
123
124        Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new(rx)))
125    }
126
127    async fn merge(
128        &self,
129        request: Request<crate::MergeRequest>,
130    ) -> Result<Response<crate::MergeResponse>, Status> {
131        let resp = crate::merge::handle_merge(self, request.into_inner()).await?;
132        Ok(Response::new(resp))
133    }
134
135    type WatchStream = tokio_stream::wrappers::ReceiverStream<Result<crate::WatchEvent, Status>>;
136
137    async fn watch(
138        &self,
139        request: Request<crate::WatchRequest>,
140    ) -> Result<Response<Self::WatchStream>, Status> {
141        let req = request.into_inner();
142        let (tx, rx) = tokio::sync::mpsc::channel(64);
143        let server_clone = ProtocolServer {
144            engine: self.engine.clone(),
145            session_mgr: self.session_mgr.clone(),
146            auth_config: self.auth_config.clone(),
147            event_bus: self.event_bus.clone(),
148            claim_tracker: self.claim_tracker.clone(),
149        };
150        tokio::spawn(async move {
151            crate::watch::handle_watch(&server_clone, req, tx).await;
152        });
153        Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new(rx)))
154    }
155
156    async fn file_read(
157        &self,
158        request: Request<crate::FileReadRequest>,
159    ) -> Result<Response<crate::FileReadResponse>, Status> {
160        crate::file_read::handle_file_read(self, request.into_inner()).await
161    }
162
163    async fn file_write(
164        &self,
165        request: Request<crate::FileWriteRequest>,
166    ) -> Result<Response<crate::FileWriteResponse>, Status> {
167        crate::file_write::handle_file_write(self, request.into_inner()).await
168    }
169
170    async fn file_list(
171        &self,
172        request: Request<crate::FileListRequest>,
173    ) -> Result<Response<crate::FileListResponse>, Status> {
174        crate::file_list::handle_file_list(self, request.into_inner()).await
175    }
176
177    async fn pre_submit_check(
178        &self,
179        request: Request<crate::PreSubmitCheckRequest>,
180    ) -> Result<Response<crate::PreSubmitCheckResponse>, Status> {
181        crate::pre_submit::handle_pre_submit_check(self, request.into_inner()).await
182    }
183
184    async fn get_session_status(
185        &self,
186        request: Request<crate::SessionStatusRequest>,
187    ) -> Result<Response<crate::SessionStatusResponse>, Status> {
188        crate::session_status::handle_get_session_status(self, request.into_inner()).await
189    }
190
191    async fn push(
192        &self,
193        request: Request<crate::PushRequest>,
194    ) -> Result<Response<crate::PushResponse>, Status> {
195        let resp = crate::push::handle_push(self, request.into_inner()).await?;
196        Ok(Response::new(resp))
197    }
198
199    async fn approve(
200        &self,
201        _request: Request<crate::ApproveRequest>,
202    ) -> Result<Response<crate::ApproveResponse>, Status> {
203        Err(Status::unimplemented(
204            "approve is a platform-level operation; use the managed server",
205        ))
206    }
207}