Skip to main content

nexo_driver_loop/
socket.rs

//! Daemon-side Unix socket server. Helper binaries spawned by Claude
//! (`SocketDecider` in `nexo-driver-permission`) connect here and forward
//! each `permission_prompt` request to the in-process decider.
4
5use std::path::Path;
6use std::sync::Arc;
7
8use nexo_driver_permission::{PermissionDecider, PermissionRequest, PermissionResponse};
9use serde::{Deserialize, Serialize};
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::net::{UnixListener, UnixStream};
12use tokio_util::sync::CancellationToken;
13
14use crate::error::DriverError;
15
16/// Wire frame on the line-delimited JSON socket protocol. Both sides
17/// send and receive these.
18#[derive(Clone, Debug, Serialize, Deserialize)]
19#[serde(tag = "kind", rename_all = "snake_case")]
20pub enum SocketMessage {
21    Decide {
22        id: String,
23        request: PermissionRequest,
24    },
25    Decision {
26        id: String,
27        response: PermissionResponse,
28    },
29    Error {
30        id: String,
31        message: String,
32    },
33    Shutdown,
34}
35
36pub struct DriverSocketServer {
37    listener: UnixListener,
38    decider: Arc<dyn PermissionDecider>,
39    cancel: CancellationToken,
40}
41
42impl DriverSocketServer {
43    /// Bind a Unix listener at `path`. If the path exists, an `unlink`
44    /// is attempted first (operator may have a stale socket from a
45    /// crashed daemon). File mode is set to 0600 after bind.
46    pub async fn bind(
47        path: &Path,
48        decider: Arc<dyn PermissionDecider>,
49        cancel: CancellationToken,
50    ) -> Result<Self, DriverError> {
51        if let Some(parent) = path.parent() {
52            tokio::fs::create_dir_all(parent).await?;
53        }
54        if path.exists() {
55            // Best-effort cleanup of stale sock; ignore NotFound.
56            match tokio::fs::remove_file(path).await {
57                Ok(()) => {}
58                Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
59                Err(e) => return Err(DriverError::Socket(e.to_string())),
60            }
61        }
62        let listener = UnixListener::bind(path).map_err(|e| DriverError::Socket(e.to_string()))?;
63        #[cfg(unix)]
64        {
65            use std::os::unix::fs::PermissionsExt;
66            let mut perms = std::fs::metadata(path)
67                .map_err(|e| DriverError::Socket(e.to_string()))?
68                .permissions();
69            perms.set_mode(0o600);
70            std::fs::set_permissions(path, perms)
71                .map_err(|e| DriverError::Socket(e.to_string()))?;
72        }
73        Ok(Self {
74            listener,
75            decider,
76            cancel,
77        })
78    }
79
80    /// Accept loop. Returns when `cancel` fires or the listener
81    /// errors.
82    pub async fn run(self) -> Result<(), DriverError> {
83        loop {
84            tokio::select! {
85                _ = self.cancel.cancelled() => {
86                    tracing::info!(target: "driver-socket", "shutdown signalled, draining");
87                    return Ok(());
88                }
89                accept = self.listener.accept() => {
90                    match accept {
91                        Ok((stream, _addr)) => {
92                            let decider = Arc::clone(&self.decider);
93                            let cancel = self.cancel.clone();
94                            tokio::spawn(async move {
95                                if let Err(e) = serve_connection(stream, decider, cancel).await {
96                                    tracing::warn!(
97                                        target: "driver-socket",
98                                        "connection error: {e}"
99                                    );
100                                }
101                            });
102                        }
103                        Err(e) => {
104                            tracing::warn!(target: "driver-socket", "accept failed: {e}");
105                            return Err(DriverError::Socket(e.to_string()));
106                        }
107                    }
108                }
109            }
110        }
111    }
112}
113
114async fn serve_connection(
115    stream: UnixStream,
116    decider: Arc<dyn PermissionDecider>,
117    cancel: CancellationToken,
118) -> Result<(), DriverError> {
119    let (read_half, mut write_half) = stream.into_split();
120    let mut reader = BufReader::new(read_half).lines();
121    loop {
122        let next = tokio::select! {
123            _ = cancel.cancelled() => return Ok(()),
124            line = reader.next_line() => line?,
125        };
126        let raw = match next {
127            None => return Ok(()),
128            Some(s) if s.trim().is_empty() => continue,
129            Some(s) => s,
130        };
131        let msg: SocketMessage = match serde_json::from_str(&raw) {
132            Ok(m) => m,
133            Err(e) => {
134                let err = SocketMessage::Error {
135                    id: String::new(),
136                    message: format!("parse error: {e}"),
137                };
138                write_message(&mut write_half, &err).await?;
139                continue;
140            }
141        };
142        match msg {
143            SocketMessage::Decide { id, request } => {
144                let response = match decider.decide(request).await {
145                    Ok(r) => r,
146                    Err(e) => {
147                        let err = SocketMessage::Error {
148                            id: id.clone(),
149                            message: e.to_string(),
150                        };
151                        write_message(&mut write_half, &err).await?;
152                        continue;
153                    }
154                };
155                let out = SocketMessage::Decision { id, response };
156                write_message(&mut write_half, &out).await?;
157            }
158            SocketMessage::Shutdown => return Ok(()),
159            other => {
160                tracing::warn!(target: "driver-socket", "unexpected message kind: {other:?}");
161            }
162        }
163    }
164}
165
166async fn write_message(
167    w: &mut (impl AsyncWriteExt + Unpin),
168    msg: &SocketMessage,
169) -> Result<(), DriverError> {
170    let mut bytes = serde_json::to_vec(msg)?;
171    bytes.push(b'\n');
172    w.write_all(&bytes).await?;
173    w.flush().await?;
174    Ok(())
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use nexo_driver_permission::{AllowAllDecider, AllowScope, PermissionOutcome, ScriptedDecider};
    use nexo_driver_types::GoalId;
    use tokio::io::Lines;
    use tokio::net::unix::{OwnedReadHalf, OwnedWriteHalf};
    use tokio::task::JoinHandle;

    /// A running server bound inside a throwaway directory, plus one
    /// connected client.
    struct Harness {
        cancel: CancellationToken,
        server: JoinHandle<Result<(), DriverError>>,
        rx: Lines<BufReader<OwnedReadHalf>>,
        tx: OwnedWriteHalf,
        // Keeps the temporary directory (and the socket in it) alive for the
        // whole test; declared last so it is dropped last.
        _dir: tempfile::TempDir,
    }

    impl Harness {
        /// Bind a server backed by `decider`, spawn its accept loop, and
        /// connect a client to it.
        async fn start(decider: Arc<dyn PermissionDecider>) -> Self {
            let dir = tempfile::tempdir().unwrap();
            let sock = dir.path().join("driver.sock");
            let cancel = CancellationToken::new();
            let bound = DriverSocketServer::bind(&sock, decider, cancel.clone())
                .await
                .unwrap();
            let server = tokio::spawn(bound.run());
            let (read_half, tx) = UnixStream::connect(&sock).await.unwrap().into_split();
            Self {
                cancel,
                server,
                rx: BufReader::new(read_half).lines(),
                tx,
                _dir: dir,
            }
        }

        /// Send one `Decide` frame tagged `id` and return the parsed reply.
        async fn ask(&mut self, id: &str, request: PermissionRequest) -> SocketMessage {
            let frame = SocketMessage::Decide {
                id: id.into(),
                request,
            };
            write_message(&mut self.tx, &frame).await.unwrap();
            let line = self.rx.next_line().await.unwrap().unwrap();
            serde_json::from_str(&line).unwrap()
        }

        /// Signal shutdown and wait for the accept loop to finish.
        async fn stop(self) {
            self.cancel.cancel();
            let _ = self.server.await;
        }
    }

    #[tokio::test]
    async fn server_round_trip_with_allow_all() {
        let mut h = Harness::start(Arc::new(AllowAllDecider)).await;

        let reply = h
            .ask(
                "1",
                PermissionRequest {
                    goal_id: GoalId::new(),
                    tool_use_id: "tu_x".into(),
                    tool_name: "Edit".into(),
                    input: serde_json::json!({"file":"x"}),
                    metadata: serde_json::Map::new(),
                },
            )
            .await;

        match reply {
            SocketMessage::Decision { id, response } => {
                assert_eq!(id, "1");
                assert!(matches!(
                    response.outcome,
                    PermissionOutcome::AllowOnce { .. }
                ));
            }
            other => panic!("expected Decision, got {other:?}"),
        }

        h.stop().await;
    }

    #[tokio::test]
    async fn scripted_decider_iterates_across_socket() {
        let decider = Arc::new(ScriptedDecider::new([
            PermissionOutcome::AllowSession {
                scope: AllowScope::Turn,
                updated_input: None,
            },
            PermissionOutcome::Deny {
                message: "denied".into(),
            },
        ]));
        let mut h = Harness::start(decider).await;

        let blank_request = || PermissionRequest {
            goal_id: GoalId::new(),
            tool_use_id: "tu".into(),
            tool_name: "Edit".into(),
            input: serde_json::json!({}),
            metadata: serde_json::Map::new(),
        };
        // Each reply is read before the next request is sent.
        let first = h.ask("1", blank_request()).await;
        let second = h.ask("2", blank_request()).await;

        assert!(matches!(
            first,
            SocketMessage::Decision {
                response: PermissionResponse {
                    outcome: PermissionOutcome::AllowSession { .. },
                    ..
                },
                ..
            }
        ));
        assert!(matches!(
            second,
            SocketMessage::Decision {
                response: PermissionResponse {
                    outcome: PermissionOutcome::Deny { .. },
                    ..
                },
                ..
            }
        ));

        h.stop().await;
    }
}