1use std::sync::Arc;
2
3use dk_engine::conflict::SymbolClaimTracker;
4use dk_engine::repo::Engine;
5use tonic::{Request, Response, Status};
6
7use crate::auth::AuthConfig;
8use crate::events::EventBus;
9use crate::session::{AgentSession, SessionManager};
10
11pub struct ProtocolServer {
17 pub(crate) engine: Arc<Engine>,
18 pub(crate) session_mgr: Arc<SessionManager>,
19 pub(crate) auth_config: AuthConfig,
20 pub(crate) event_bus: Arc<EventBus>,
21 pub(crate) claim_tracker: Arc<SymbolClaimTracker>,
22}
23
24impl ProtocolServer {
25 pub fn new(engine: Arc<Engine>, auth_config: AuthConfig) -> Self {
30 Self {
31 engine,
32 session_mgr: Arc::new(SessionManager::new(std::time::Duration::from_secs(
33 30 * 60,
34 ))),
35 auth_config,
36 event_bus: Arc::new(EventBus::new()),
37 claim_tracker: Arc::new(SymbolClaimTracker::new()),
38 }
39 }
40
41 pub fn engine(&self) -> &Engine {
43 &self.engine
44 }
45
46 pub fn session_mgr(&self) -> &SessionManager {
48 &self.session_mgr
49 }
50
51 pub fn event_bus(&self) -> &EventBus {
53 &self.event_bus
54 }
55
56 pub fn claim_tracker(&self) -> &SymbolClaimTracker {
58 &self.claim_tracker
59 }
60
61 pub(crate) fn validate_auth(&self, token: &str) -> Result<String, Status> {
63 self.auth_config.validate(token)
64 }
65
66 pub(crate) fn validate_session(&self, session_id_str: &str) -> Result<AgentSession, Status> {
69 let sid = session_id_str
70 .parse::<uuid::Uuid>()
71 .map_err(|_| Status::invalid_argument("Invalid session ID format"))?;
72 self.session_mgr
73 .get_session(&sid)
74 .ok_or_else(|| Status::not_found("Session not found or expired"))
75 }
76}
77
78#[tonic::async_trait]
81impl crate::agent_service_server::AgentService for ProtocolServer {
82 async fn connect(
83 &self,
84 request: Request<crate::ConnectRequest>,
85 ) -> Result<Response<crate::ConnectResponse>, Status> {
86 crate::connect::handle_connect(self, request.into_inner()).await
87 }
88
89 async fn context(
90 &self,
91 request: Request<crate::ContextRequest>,
92 ) -> Result<Response<crate::ContextResponse>, Status> {
93 crate::context::handle_context(self, request.into_inner()).await
94 }
95
96 async fn submit(
97 &self,
98 request: Request<crate::SubmitRequest>,
99 ) -> Result<Response<crate::SubmitResponse>, Status> {
100 crate::submit::handle_submit(self, request.into_inner()).await
101 }
102
103 type VerifyStream = tokio_stream::wrappers::ReceiverStream<Result<crate::VerifyStepResult, Status>>;
104
105 async fn verify(
106 &self,
107 request: Request<crate::VerifyRequest>,
108 ) -> Result<Response<Self::VerifyStream>, Status> {
109 let req = request.into_inner();
110 let (tx, rx) = tokio::sync::mpsc::channel(32);
111
112 let server_clone = ProtocolServer {
113 engine: self.engine.clone(),
114 session_mgr: self.session_mgr.clone(),
115 auth_config: self.auth_config.clone(),
116 event_bus: self.event_bus.clone(),
117 claim_tracker: self.claim_tracker.clone(),
118 };
119
120 tokio::spawn(async move {
121 crate::verify::handle_verify(&server_clone, req, tx).await;
122 });
123
124 Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new(rx)))
125 }
126
127 async fn merge(
128 &self,
129 request: Request<crate::MergeRequest>,
130 ) -> Result<Response<crate::MergeResponse>, Status> {
131 let resp = crate::merge::handle_merge(self, request.into_inner()).await?;
132 Ok(Response::new(resp))
133 }
134
135 type WatchStream = tokio_stream::wrappers::ReceiverStream<Result<crate::WatchEvent, Status>>;
136
137 async fn watch(
138 &self,
139 request: Request<crate::WatchRequest>,
140 ) -> Result<Response<Self::WatchStream>, Status> {
141 let req = request.into_inner();
142 let (tx, rx) = tokio::sync::mpsc::channel(64);
143 let server_clone = ProtocolServer {
144 engine: self.engine.clone(),
145 session_mgr: self.session_mgr.clone(),
146 auth_config: self.auth_config.clone(),
147 event_bus: self.event_bus.clone(),
148 claim_tracker: self.claim_tracker.clone(),
149 };
150 tokio::spawn(async move {
151 crate::watch::handle_watch(&server_clone, req, tx).await;
152 });
153 Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new(rx)))
154 }
155
156 async fn file_read(
157 &self,
158 request: Request<crate::FileReadRequest>,
159 ) -> Result<Response<crate::FileReadResponse>, Status> {
160 crate::file_read::handle_file_read(self, request.into_inner()).await
161 }
162
163 async fn file_write(
164 &self,
165 request: Request<crate::FileWriteRequest>,
166 ) -> Result<Response<crate::FileWriteResponse>, Status> {
167 crate::file_write::handle_file_write(self, request.into_inner()).await
168 }
169
170 async fn file_list(
171 &self,
172 request: Request<crate::FileListRequest>,
173 ) -> Result<Response<crate::FileListResponse>, Status> {
174 crate::file_list::handle_file_list(self, request.into_inner()).await
175 }
176
177 async fn pre_submit_check(
178 &self,
179 request: Request<crate::PreSubmitCheckRequest>,
180 ) -> Result<Response<crate::PreSubmitCheckResponse>, Status> {
181 crate::pre_submit::handle_pre_submit_check(self, request.into_inner()).await
182 }
183
184 async fn get_session_status(
185 &self,
186 request: Request<crate::SessionStatusRequest>,
187 ) -> Result<Response<crate::SessionStatusResponse>, Status> {
188 crate::session_status::handle_get_session_status(self, request.into_inner()).await
189 }
190
191 async fn push(
192 &self,
193 request: Request<crate::PushRequest>,
194 ) -> Result<Response<crate::PushResponse>, Status> {
195 let resp = crate::push::handle_push(self, request.into_inner()).await?;
196 Ok(Response::new(resp))
197 }
198
199 async fn approve(
200 &self,
201 _request: Request<crate::ApproveRequest>,
202 ) -> Result<Response<crate::ApproveResponse>, Status> {
203 Err(Status::unimplemented(
204 "approve is a platform-level operation; use the managed server",
205 ))
206 }
207}