1use std::sync::Arc;
2
3use dk_engine::conflict::SymbolClaimTracker;
4use dk_engine::repo::Engine;
5use tonic::{Request, Response, Status};
6
7use crate::auth::AuthConfig;
8use crate::events::EventBus;
9use crate::session::{AgentSession, SessionManager};
10
11pub struct ProtocolServer {
17 pub(crate) engine: Arc<Engine>,
18 pub(crate) session_mgr: Arc<SessionManager>,
19 pub(crate) auth_config: AuthConfig,
20 pub(crate) event_bus: Arc<EventBus>,
21 pub(crate) claim_tracker: Arc<SymbolClaimTracker>,
22}
23
24impl ProtocolServer {
25 pub fn new(engine: Arc<Engine>, auth_config: AuthConfig) -> Self {
30 Self {
31 engine,
32 session_mgr: Arc::new(SessionManager::new(std::time::Duration::from_secs(
33 30 * 60,
34 ))),
35 auth_config,
36 event_bus: Arc::new(EventBus::new()),
37 claim_tracker: Arc::new(SymbolClaimTracker::new()),
38 }
39 }
40
41 pub fn engine(&self) -> &Engine {
43 &self.engine
44 }
45
46 pub fn session_mgr(&self) -> &SessionManager {
48 &self.session_mgr
49 }
50
51 pub fn event_bus(&self) -> &EventBus {
53 &self.event_bus
54 }
55
56 pub fn event_bus_arc(&self) -> Arc<EventBus> {
62 Arc::clone(&self.event_bus)
63 }
64
65 pub fn claim_tracker(&self) -> &SymbolClaimTracker {
67 &self.claim_tracker
68 }
69
70 pub(crate) fn validate_auth(&self, token: &str) -> Result<String, Status> {
72 self.auth_config.validate(token)
73 }
74
75 pub(crate) fn validate_session(&self, session_id_str: &str) -> Result<AgentSession, Status> {
78 let sid = session_id_str
79 .parse::<uuid::Uuid>()
80 .map_err(|_| Status::invalid_argument("Invalid session ID format"))?;
81 self.session_mgr
82 .get_session(&sid)
83 .ok_or_else(|| Status::not_found("Session not found or expired"))
84 }
85}
86
87#[tonic::async_trait]
90impl crate::agent_service_server::AgentService for ProtocolServer {
91 async fn connect(
92 &self,
93 request: Request<crate::ConnectRequest>,
94 ) -> Result<Response<crate::ConnectResponse>, Status> {
95 crate::connect::handle_connect(self, request.into_inner()).await
96 }
97
98 async fn context(
99 &self,
100 request: Request<crate::ContextRequest>,
101 ) -> Result<Response<crate::ContextResponse>, Status> {
102 crate::context::handle_context(self, request.into_inner()).await
103 }
104
105 async fn submit(
106 &self,
107 request: Request<crate::SubmitRequest>,
108 ) -> Result<Response<crate::SubmitResponse>, Status> {
109 crate::submit::handle_submit(self, request.into_inner()).await
110 }
111
112 type VerifyStream = tokio_stream::wrappers::ReceiverStream<Result<crate::VerifyStepResult, Status>>;
113
114 async fn verify(
115 &self,
116 request: Request<crate::VerifyRequest>,
117 ) -> Result<Response<Self::VerifyStream>, Status> {
118 let req = request.into_inner();
119 let (tx, rx) = tokio::sync::mpsc::channel(32);
120
121 let server_clone = ProtocolServer {
122 engine: self.engine.clone(),
123 session_mgr: self.session_mgr.clone(),
124 auth_config: self.auth_config.clone(),
125 event_bus: self.event_bus.clone(),
126 claim_tracker: self.claim_tracker.clone(),
127 };
128
129 tokio::spawn(async move {
130 crate::verify::handle_verify(&server_clone, req, tx).await;
131 });
132
133 Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new(rx)))
134 }
135
136 async fn merge(
137 &self,
138 request: Request<crate::MergeRequest>,
139 ) -> Result<Response<crate::MergeResponse>, Status> {
140 let resp = crate::merge::handle_merge(self, request.into_inner()).await?;
141 Ok(Response::new(resp))
142 }
143
144 type WatchStream = tokio_stream::wrappers::ReceiverStream<Result<crate::WatchEvent, Status>>;
145
146 async fn watch(
147 &self,
148 request: Request<crate::WatchRequest>,
149 ) -> Result<Response<Self::WatchStream>, Status> {
150 let req = request.into_inner();
151 let (tx, rx) = tokio::sync::mpsc::channel(64);
152 let server_clone = ProtocolServer {
153 engine: self.engine.clone(),
154 session_mgr: self.session_mgr.clone(),
155 auth_config: self.auth_config.clone(),
156 event_bus: self.event_bus.clone(),
157 claim_tracker: self.claim_tracker.clone(),
158 };
159 tokio::spawn(async move {
160 crate::watch::handle_watch(&server_clone, req, tx).await;
161 });
162 Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new(rx)))
163 }
164
165 async fn file_read(
166 &self,
167 request: Request<crate::FileReadRequest>,
168 ) -> Result<Response<crate::FileReadResponse>, Status> {
169 crate::file_read::handle_file_read(self, request.into_inner()).await
170 }
171
172 async fn file_write(
173 &self,
174 request: Request<crate::FileWriteRequest>,
175 ) -> Result<Response<crate::FileWriteResponse>, Status> {
176 crate::file_write::handle_file_write(self, request.into_inner()).await
177 }
178
179 async fn file_list(
180 &self,
181 request: Request<crate::FileListRequest>,
182 ) -> Result<Response<crate::FileListResponse>, Status> {
183 crate::file_list::handle_file_list(self, request.into_inner()).await
184 }
185
186 async fn pre_submit_check(
187 &self,
188 request: Request<crate::PreSubmitCheckRequest>,
189 ) -> Result<Response<crate::PreSubmitCheckResponse>, Status> {
190 crate::pre_submit::handle_pre_submit_check(self, request.into_inner()).await
191 }
192
193 async fn get_session_status(
194 &self,
195 request: Request<crate::SessionStatusRequest>,
196 ) -> Result<Response<crate::SessionStatusResponse>, Status> {
197 crate::session_status::handle_get_session_status(self, request.into_inner()).await
198 }
199
200 async fn push(
201 &self,
202 request: Request<crate::PushRequest>,
203 ) -> Result<Response<crate::PushResponse>, Status> {
204 let resp = crate::push::handle_push(self, request.into_inner()).await?;
205 Ok(Response::new(resp))
206 }
207
208 async fn approve(
209 &self,
210 _request: Request<crate::ApproveRequest>,
211 ) -> Result<Response<crate::ApproveResponse>, Status> {
212 Err(Status::unimplemented(
213 "approve is a platform-level operation; use the managed server",
214 ))
215 }
216
217 async fn resolve(
218 &self,
219 _request: Request<crate::ResolveRequest>,
220 ) -> Result<Response<crate::ResolveResponse>, Status> {
221 Err(Status::unimplemented(
222 "resolve is a platform-level operation; use the managed server",
223 ))
224 }
225}