1use anyhow::Result;
16use std::io::{BufRead, BufReader, Write};
17use std::path::{Path, PathBuf};
18
19use crate::dispenser::{Dispenser, DispenserState, IdMode, SqliteDispenser};
20
21pub fn default_socket_path() -> PathBuf {
23 if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") {
24 PathBuf::from(runtime_dir).join("aida.sock")
25 } else if let Some(uid) = get_uid() {
26 PathBuf::from(format!("/run/user/{}/aida.sock", uid))
27 } else {
28 std::env::temp_dir().join("aida.sock")
30 }
31}
32
33#[cfg(unix)]
34fn get_uid() -> Option<u32> {
35 Some(unsafe { libc::getuid() })
36}
37
/// Non-Unix stand-in for `get_uid`: no uid is available, so this always
/// returns `None`.
#[cfg(not(unix))]
fn get_uid() -> Option<u32> {
    None
}
42
/// A request as decoded from one line of JSON sent by a client.
#[derive(Debug, serde::Deserialize)]
struct Request {
    /// Name of the operation; `handle_request` recognises `"next"`,
    /// `"peek"`, `"state"` and `"ping"`.
    method: String,
    /// Object type passed to the dispenser for `next`/`peek`.
    /// Defaults to the empty string when the JSON omits the `type` key.
    #[serde(default)]
    r#type: String,
}
50
/// A reply sent back to the client as one line of JSON.
///
/// Every field is optional and is left out of the serialized JSON when it is
/// `None`.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct Response {
    /// Sequence number returned by `next` or `peek`.
    #[serde(skip_serializing_if = "Option::is_none")]
    seq: Option<u32>,
    /// Formatted ID for `next` (when formatting succeeded), or `"pong"` for
    /// `ping`.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    /// Human-readable description of what went wrong, if anything.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
    /// Dispenser state returned by `state`.
    #[serde(skip_serializing_if = "Option::is_none")]
    state: Option<DispenserState>,
}
63
64#[cfg(unix)]
69pub fn run_daemon(db_path: &Path, mode: IdMode, socket_path: Option<&Path>) -> Result<()> {
70 use std::os::unix::net::UnixListener;
71
72 let sock_path = socket_path
73 .map(PathBuf::from)
74 .unwrap_or_else(default_socket_path);
75
76 if sock_path.exists() {
78 std::fs::remove_file(&sock_path)?;
79 }
80
81 let dispenser = SqliteDispenser::open(db_path.to_path_buf(), mode)?;
83
84 let listener = UnixListener::bind(&sock_path)?;
86 eprintln!("aida-daemon listening on {}", sock_path.display());
87
88 for stream in listener.incoming() {
90 match stream {
91 Ok(stream) => {
92 let reader = BufReader::new(stream.try_clone()?);
93 let mut writer = stream;
94
95 for line in reader.lines() {
96 let line = match line {
97 Ok(l) => l,
98 Err(_) => break,
99 };
100
101 let response = handle_request(&line, &dispenser);
102 let json = serde_json::to_string(&response).unwrap_or_else(|_| {
103 r#"{"error":"serialization failed"}"#.to_string()
104 });
105
106 if writeln!(writer, "{}", json).is_err() {
107 break;
108 }
109 }
110 }
111 Err(e) => {
112 eprintln!("aida-daemon: connection error: {}", e);
113 }
114 }
115 }
116
117 let _ = std::fs::remove_file(&sock_path);
119 Ok(())
120}
121
122fn handle_request(line: &str, dispenser: &dyn Dispenser) -> Response {
123 let req: Request = match serde_json::from_str(line) {
124 Ok(r) => r,
125 Err(e) => {
126 return Response {
127 seq: None,
128 id: None,
129 error: Some(format!("Invalid request: {}", e)),
130 state: None,
131 };
132 }
133 };
134
135 match req.method.as_str() {
136 "next" => match dispenser.next(&req.r#type) {
137 Ok(seq) => Response {
138 seq: Some(seq),
139 id: dispenser.format_id(&req.r#type, seq).ok(),
140 error: None,
141 state: None,
142 },
143 Err(e) => Response {
144 seq: None,
145 id: None,
146 error: Some(format!("next failed: {}", e)),
147 state: None,
148 },
149 },
150 "peek" => match dispenser.peek(&req.r#type) {
151 Ok(seq) => Response {
152 seq: Some(seq),
153 id: None,
154 error: None,
155 state: None,
156 },
157 Err(e) => Response {
158 seq: None,
159 id: None,
160 error: Some(format!("peek failed: {}", e)),
161 state: None,
162 },
163 },
164 "state" => match dispenser.state() {
165 Ok(state) => Response {
166 seq: None,
167 id: None,
168 error: None,
169 state: Some(state),
170 },
171 Err(e) => Response {
172 seq: None,
173 id: None,
174 error: Some(format!("state failed: {}", e)),
175 state: None,
176 },
177 },
178 "ping" => Response {
179 seq: None,
180 id: Some("pong".into()),
181 error: None,
182 state: None,
183 },
184 other => Response {
185 seq: None,
186 id: None,
187 error: Some(format!("Unknown method: {}", other)),
188 state: None,
189 },
190 }
191}
192
/// Client side of the daemon protocol; talks to the daemon over a Unix
/// socket.
#[cfg(unix)]
pub struct DaemonClient {
    /// Filesystem path of the socket that requests are sent to.
    socket_path: PathBuf,
}
198
199#[cfg(unix)]
200impl DaemonClient {
201 pub fn new() -> Self {
203 Self {
204 socket_path: default_socket_path(),
205 }
206 }
207
208 pub fn with_path(socket_path: PathBuf) -> Self {
210 Self { socket_path }
211 }
212
213 pub fn is_running(&self) -> bool {
215 self.send_raw(r#"{"method":"ping"}"#).is_ok()
216 }
217
218 fn send_raw(&self, request: &str) -> Result<Response> {
220 use std::os::unix::net::UnixStream;
221
222 let mut stream = UnixStream::connect(&self.socket_path)?;
223 stream.set_read_timeout(Some(std::time::Duration::from_secs(5)))?;
224
225 writeln!(stream, "{}", request)?;
226
227 let mut reader = BufReader::new(stream);
228 let mut line = String::new();
229 reader.read_line(&mut line)?;
230
231 let response: Response = serde_json::from_str(&line)?;
232 Ok(response)
233 }
234}
235
236#[cfg(unix)]
237impl Dispenser for DaemonClient {
238 fn next(&self, object_type: &str) -> Result<u32> {
239 let req = serde_json::json!({"method": "next", "type": object_type});
240 let resp = self.send_raw(&req.to_string())?;
241 resp.seq
242 .ok_or_else(|| anyhow::anyhow!(resp.error.unwrap_or_else(|| "no seq".into())))
243 }
244
245 fn peek(&self, object_type: &str) -> Result<u32> {
246 let req = serde_json::json!({"method": "peek", "type": object_type});
247 let resp = self.send_raw(&req.to_string())?;
248 resp.seq
249 .ok_or_else(|| anyhow::anyhow!(resp.error.unwrap_or_else(|| "no seq".into())))
250 }
251
252 fn state(&self) -> Result<DispenserState> {
253 let req = serde_json::json!({"method": "state"});
254 let resp = self.send_raw(&req.to_string())?;
255 resp.state
256 .ok_or_else(|| anyhow::anyhow!(resp.error.unwrap_or_else(|| "no state".into())))
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// `next` hands out increasing sequence numbers and a formatted ID.
    #[test]
    fn test_handle_request_next() {
        use crate::dispenser::MemoryDispenser;
        let d = MemoryDispenser::new(IdMode::Distributed { node_id: 7 });

        let resp = handle_request(r#"{"method":"next","type":"FR"}"#, &d);
        assert_eq!(resp.seq, Some(1));
        assert_eq!(resp.id, Some("FR-7-001".into()));
        assert!(resp.error.is_none());

        let resp2 = handle_request(r#"{"method":"next","type":"FR"}"#, &d);
        assert_eq!(resp2.seq, Some(2));
    }

    /// `peek` reports the same value on repeated calls.
    #[test]
    fn test_handle_request_peek() {
        use crate::dispenser::MemoryDispenser;
        let d = MemoryDispenser::new(IdMode::Centralized);

        let resp = handle_request(r#"{"method":"peek","type":"FR"}"#, &d);
        assert_eq!(resp.seq, Some(1));

        let resp2 = handle_request(r#"{"method":"peek","type":"FR"}"#, &d);
        assert_eq!(resp2.seq, Some(1));
    }

    /// `state` returns the dispenser state, including its mode.
    #[test]
    fn test_handle_request_state() {
        use crate::dispenser::MemoryDispenser;
        let d = MemoryDispenser::new(IdMode::Distributed { node_id: 42 });
        d.next("FR").unwrap();

        let resp = handle_request(r#"{"method":"state"}"#, &d);
        assert!(resp.state.is_some());
        let state = resp.state.unwrap();
        assert_eq!(state.mode, IdMode::Distributed { node_id: 42 });
    }

    /// `ping` is answered with `pong` in the `id` field.
    #[test]
    fn test_handle_request_ping() {
        use crate::dispenser::MemoryDispenser;
        let d = MemoryDispenser::new(IdMode::Centralized);

        let resp = handle_request(r#"{"method":"ping"}"#, &d);
        assert_eq!(resp.id, Some("pong".into()));
    }

    /// Malformed JSON and unknown methods both produce an error response.
    #[test]
    fn test_handle_request_invalid() {
        use crate::dispenser::MemoryDispenser;
        let d = MemoryDispenser::new(IdMode::Centralized);

        let resp = handle_request("not json", &d);
        assert!(resp.error.is_some());

        let resp2 = handle_request(r#"{"method":"unknown"}"#, &d);
        assert!(resp2.error.is_some());
    }

    /// End-to-end round trip through a real socket. Ignored by default
    /// because it spawns a daemon thread that never terminates.
    #[cfg(unix)]
    #[test]
    #[ignore]
    fn test_daemon_client_server() {
        use std::thread;

        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("dispenser.db");
        let sock_path = dir.path().join("test.sock");

        let sock_clone = sock_path.clone();
        let db_clone = db_path.clone();

        let _handle = thread::spawn(move || {
            let _ = run_daemon(&db_clone, IdMode::Distributed { node_id: 7 }, Some(&sock_clone));
        });

        // Give the daemon up to 2.5 s to bind its socket.
        for _ in 0..50 {
            if sock_path.exists() {
                break;
            }
            thread::sleep(std::time::Duration::from_millis(50));
        }

        // Fail loudly if the daemon never came up; silently skipping the
        // assertions below would let the test pass without testing anything.
        assert!(
            sock_path.exists(),
            "daemon did not create its socket at {}",
            sock_path.display()
        );

        let client = DaemonClient::with_path(sock_path.clone());

        assert!(client.is_running());
        assert_eq!(client.next("FR").unwrap(), 1);
        assert_eq!(client.next("FR").unwrap(), 2);
        assert_eq!(client.peek("FR").unwrap(), 3);

        let state = client.state().unwrap();
        assert_eq!(state.mode, IdMode::Distributed { node_id: 7 });

        let _ = std::fs::remove_file(&sock_path);
    }
}