1use std::path::Path;
6use std::sync::Arc;
7
8use nexo_driver_permission::{PermissionDecider, PermissionRequest, PermissionResponse};
9use serde::{Deserialize, Serialize};
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::net::{UnixListener, UnixStream};
12use tokio_util::sync::CancellationToken;
13
14use crate::error::DriverError;
15
/// One frame of the driver socket protocol.
///
/// Frames are exchanged as newline-delimited JSON: `write_message` appends a
/// `\n` after each serialized frame and `serve_connection` reads one frame per
/// line. The variant is encoded in an internal `"kind"` field using snake_case
/// names, e.g. `{"kind":"decide","id":"1","request":{...}}` or
/// `{"kind":"shutdown"}`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SocketMessage {
    /// Client → server: ask the server's decider to rule on `request`.
    /// The server echoes `id` on its reply so the client can correlate it.
    Decide {
        id: String,
        request: PermissionRequest,
    },
    /// Server → client: the decider's answer to the `Decide` frame that
    /// carried the same `id`.
    Decision {
        id: String,
        response: PermissionResponse,
    },
    /// Server → client: a request could not be decided, or a frame could not
    /// be parsed. `id` may be empty when the failing frame's id is unknown.
    Error {
        id: String,
        message: String,
    },
    /// Client → server: close this connection. Other connections and the
    /// server's accept loop are not affected.
    Shutdown,
}
35
/// Unix-domain socket server that answers permission requests from driver
/// clients using a shared [`PermissionDecider`].
///
/// Created with [`DriverSocketServer::bind`] and driven with
/// [`DriverSocketServer::run`], which stops accepting once `cancel` fires.
pub struct DriverSocketServer {
    /// Bound listener; connections are accepted from it in `run`.
    listener: UnixListener,
    /// Decider shared (via `Arc` clones) with every connection task.
    decider: Arc<dyn PermissionDecider>,
    /// Shutdown signal for the accept loop; a clone is handed to each
    /// connection task as well.
    cancel: CancellationToken,
}
41
42impl DriverSocketServer {
43 pub async fn bind(
47 path: &Path,
48 decider: Arc<dyn PermissionDecider>,
49 cancel: CancellationToken,
50 ) -> Result<Self, DriverError> {
51 if let Some(parent) = path.parent() {
52 tokio::fs::create_dir_all(parent).await?;
53 }
54 if path.exists() {
55 match tokio::fs::remove_file(path).await {
57 Ok(()) => {}
58 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
59 Err(e) => return Err(DriverError::Socket(e.to_string())),
60 }
61 }
62 let listener = UnixListener::bind(path).map_err(|e| DriverError::Socket(e.to_string()))?;
63 #[cfg(unix)]
64 {
65 use std::os::unix::fs::PermissionsExt;
66 let mut perms = std::fs::metadata(path)
67 .map_err(|e| DriverError::Socket(e.to_string()))?
68 .permissions();
69 perms.set_mode(0o600);
70 std::fs::set_permissions(path, perms)
71 .map_err(|e| DriverError::Socket(e.to_string()))?;
72 }
73 Ok(Self {
74 listener,
75 decider,
76 cancel,
77 })
78 }
79
80 pub async fn run(self) -> Result<(), DriverError> {
83 loop {
84 tokio::select! {
85 _ = self.cancel.cancelled() => {
86 tracing::info!(target: "driver-socket", "shutdown signalled, draining");
87 return Ok(());
88 }
89 accept = self.listener.accept() => {
90 match accept {
91 Ok((stream, _addr)) => {
92 let decider = Arc::clone(&self.decider);
93 let cancel = self.cancel.clone();
94 tokio::spawn(async move {
95 if let Err(e) = serve_connection(stream, decider, cancel).await {
96 tracing::warn!(
97 target: "driver-socket",
98 "connection error: {e}"
99 );
100 }
101 });
102 }
103 Err(e) => {
104 tracing::warn!(target: "driver-socket", "accept failed: {e}");
105 return Err(DriverError::Socket(e.to_string()));
106 }
107 }
108 }
109 }
110 }
111 }
112}
113
114async fn serve_connection(
115 stream: UnixStream,
116 decider: Arc<dyn PermissionDecider>,
117 cancel: CancellationToken,
118) -> Result<(), DriverError> {
119 let (read_half, mut write_half) = stream.into_split();
120 let mut reader = BufReader::new(read_half).lines();
121 loop {
122 let next = tokio::select! {
123 _ = cancel.cancelled() => return Ok(()),
124 line = reader.next_line() => line?,
125 };
126 let raw = match next {
127 None => return Ok(()),
128 Some(s) if s.trim().is_empty() => continue,
129 Some(s) => s,
130 };
131 let msg: SocketMessage = match serde_json::from_str(&raw) {
132 Ok(m) => m,
133 Err(e) => {
134 let err = SocketMessage::Error {
135 id: String::new(),
136 message: format!("parse error: {e}"),
137 };
138 write_message(&mut write_half, &err).await?;
139 continue;
140 }
141 };
142 match msg {
143 SocketMessage::Decide { id, request } => {
144 let response = match decider.decide(request).await {
145 Ok(r) => r,
146 Err(e) => {
147 let err = SocketMessage::Error {
148 id: id.clone(),
149 message: e.to_string(),
150 };
151 write_message(&mut write_half, &err).await?;
152 continue;
153 }
154 };
155 let out = SocketMessage::Decision { id, response };
156 write_message(&mut write_half, &out).await?;
157 }
158 SocketMessage::Shutdown => return Ok(()),
159 other => {
160 tracing::warn!(target: "driver-socket", "unexpected message kind: {other:?}");
161 }
162 }
163 }
164}
165
166async fn write_message(
167 w: &mut (impl AsyncWriteExt + Unpin),
168 msg: &SocketMessage,
169) -> Result<(), DriverError> {
170 let mut bytes = serde_json::to_vec(msg)?;
171 bytes.push(b'\n');
172 w.write_all(&bytes).await?;
173 w.flush().await?;
174 Ok(())
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use nexo_driver_permission::{AllowAllDecider, AllowScope, PermissionOutcome, ScriptedDecider};
    use nexo_driver_types::GoalId;

    /// A running server plus one connected client.
    struct Harness {
        cancel: CancellationToken,
        server: tokio::task::JoinHandle<Result<(), DriverError>>,
        lines: tokio::io::Lines<BufReader<tokio::net::unix::OwnedReadHalf>>,
        writer: tokio::net::unix::OwnedWriteHalf,
        // Declared last so the temp dir is dropped after the client halves.
        _dir: tempfile::TempDir,
    }

    impl Harness {
        /// Binds `driver.sock` in a fresh temp dir, spawns `run`, and connects
        /// a single client.
        async fn start(decider: Arc<dyn PermissionDecider>) -> Self {
            let dir = tempfile::tempdir().unwrap();
            let socket_path = dir.path().join("driver.sock");
            let cancel = CancellationToken::new();
            let bound = DriverSocketServer::bind(&socket_path, decider, cancel.clone())
                .await
                .unwrap();
            let server = tokio::spawn(bound.run());
            let (read_half, writer) = UnixStream::connect(&socket_path)
                .await
                .unwrap()
                .into_split();
            Self {
                cancel,
                server,
                lines: BufReader::new(read_half).lines(),
                writer,
                _dir: dir,
            }
        }

        /// Sends one `Decide` frame and parses the single reply line.
        async fn ask(&mut self, id: &'static str, request: PermissionRequest) -> SocketMessage {
            let frame = SocketMessage::Decide {
                id: id.into(),
                request,
            };
            write_message(&mut self.writer, &frame).await.unwrap();
            let reply = self.lines.next_line().await.unwrap().unwrap();
            serde_json::from_str(&reply).unwrap()
        }

        /// Signals cancellation and waits for the accept loop to exit.
        async fn shutdown(self) {
            self.cancel.cancel();
            let _ = self.server.await;
        }
    }

    /// Builds an `Edit` permission request with a fresh goal id.
    fn edit_request(tool_use_id: &'static str, input: serde_json::Value) -> PermissionRequest {
        PermissionRequest {
            goal_id: GoalId::new(),
            tool_use_id: tool_use_id.into(),
            tool_name: "Edit".into(),
            input,
            metadata: serde_json::Map::new(),
        }
    }

    #[tokio::test]
    async fn server_round_trip_with_allow_all() {
        let mut harness = Harness::start(Arc::new(AllowAllDecider)).await;

        let reply = harness
            .ask("1", edit_request("tu_x", serde_json::json!({"file":"x"})))
            .await;
        match reply {
            SocketMessage::Decision { id, response } => {
                assert_eq!(id, "1");
                assert!(matches!(
                    response.outcome,
                    PermissionOutcome::AllowOnce { .. }
                ));
            }
            other => panic!("expected Decision, got {other:?}"),
        }

        harness.shutdown().await;
    }

    #[tokio::test]
    async fn scripted_decider_iterates_across_socket() {
        let scripted = ScriptedDecider::new([
            PermissionOutcome::AllowSession {
                scope: AllowScope::Turn,
                updated_input: None,
            },
            PermissionOutcome::Deny {
                message: "denied".into(),
            },
        ]);
        let mut harness = Harness::start(Arc::new(scripted)).await;

        // Sequential round trips: each reply is read before the next request.
        let first = harness.ask("1", edit_request("tu", serde_json::json!({}))).await;
        let second = harness.ask("2", edit_request("tu", serde_json::json!({}))).await;

        assert!(matches!(
            first,
            SocketMessage::Decision {
                response: PermissionResponse {
                    outcome: PermissionOutcome::AllowSession { .. },
                    ..
                },
                ..
            }
        ));
        assert!(matches!(
            second,
            SocketMessage::Decision {
                response: PermissionResponse {
                    outcome: PermissionOutcome::Deny { .. },
                    ..
                },
                ..
            }
        ));

        harness.shutdown().await;
    }
}