1use crate::constants::SQLITE_GRAPHRAG_VERSION;
2use crate::errors::AppError;
3use crate::{embedder, shutdown_requested};
4use interprocess::local_socket::{
5 prelude::LocalSocketStream,
6 traits::{Listener as _, Stream as _},
7 GenericFilePath, GenericNamespaced, ListenerNonblockingMode, ListenerOptions, ToFsName,
8 ToNsName,
9};
10use serde::{Deserialize, Serialize};
11use std::io::{BufRead, BufReader, Write};
12use std::path::Path;
13use std::thread;
14use std::time::{Duration, Instant};
15
16#[derive(Debug, Serialize, Deserialize)]
17#[serde(tag = "request", rename_all = "snake_case")]
18pub enum DaemonRequest {
19 Ping,
20 Shutdown,
21 EmbedPassage {
22 text: String,
23 },
24 EmbedQuery {
25 text: String,
26 },
27 EmbedPassages {
28 texts: Vec<String>,
29 token_counts: Vec<usize>,
30 },
31}
32
33#[derive(Debug, Serialize, Deserialize)]
34#[serde(tag = "status", rename_all = "snake_case")]
35pub enum DaemonResponse {
36 Listening {
37 pid: u32,
38 socket: String,
39 idle_shutdown_secs: u64,
40 },
41 Ok {
42 pid: u32,
43 version: String,
44 handled_embed_requests: u64,
45 },
46 PassageEmbedding {
47 embedding: Vec<f32>,
48 handled_embed_requests: u64,
49 },
50 QueryEmbedding {
51 embedding: Vec<f32>,
52 handled_embed_requests: u64,
53 },
54 PassageEmbeddings {
55 embeddings: Vec<Vec<f32>>,
56 handled_embed_requests: u64,
57 },
58 ShuttingDown {
59 handled_embed_requests: u64,
60 },
61 Error {
62 message: String,
63 },
64}
65
66pub fn daemon_label(models_dir: &Path) -> String {
67 let hash = blake3::hash(models_dir.to_string_lossy().as_bytes())
68 .to_hex()
69 .to_string();
70 format!("sqlite-graphrag-daemon-{}", &hash[..16])
71}
72
73pub fn try_ping(models_dir: &Path) -> Result<Option<DaemonResponse>, AppError> {
74 request_if_available(models_dir, &DaemonRequest::Ping)
75}
76
77pub fn try_shutdown(models_dir: &Path) -> Result<Option<DaemonResponse>, AppError> {
78 request_if_available(models_dir, &DaemonRequest::Shutdown)
79}
80
81pub fn embed_passage_or_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
82 match request_if_available(
83 models_dir,
84 &DaemonRequest::EmbedPassage {
85 text: text.to_string(),
86 },
87 )? {
88 Some(DaemonResponse::PassageEmbedding { embedding, .. }) => Ok(embedding),
89 Some(DaemonResponse::Error { message }) => Err(AppError::Embedding(message)),
90 Some(other) => Err(AppError::Internal(anyhow::anyhow!(
91 "unexpected daemon response for passage embedding: {other:?}"
92 ))),
93 None => {
94 let embedder = embedder::get_embedder(models_dir)?;
95 embedder::embed_passage(embedder, text)
96 }
97 }
98}
99
100pub fn embed_query_or_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
101 match request_if_available(
102 models_dir,
103 &DaemonRequest::EmbedQuery {
104 text: text.to_string(),
105 },
106 )? {
107 Some(DaemonResponse::QueryEmbedding { embedding, .. }) => Ok(embedding),
108 Some(DaemonResponse::Error { message }) => Err(AppError::Embedding(message)),
109 Some(other) => Err(AppError::Internal(anyhow::anyhow!(
110 "unexpected daemon response for query embedding: {other:?}"
111 ))),
112 None => {
113 let embedder = embedder::get_embedder(models_dir)?;
114 embedder::embed_query(embedder, text)
115 }
116 }
117}
118
119pub fn embed_passages_controlled_or_local(
120 models_dir: &Path,
121 texts: &[&str],
122 token_counts: &[usize],
123) -> Result<Vec<Vec<f32>>, AppError> {
124 let request = DaemonRequest::EmbedPassages {
125 texts: texts.iter().map(|t| (*t).to_string()).collect(),
126 token_counts: token_counts.to_vec(),
127 };
128
129 match request_if_available(models_dir, &request)? {
130 Some(DaemonResponse::PassageEmbeddings { embeddings, .. }) => Ok(embeddings),
131 Some(DaemonResponse::Error { message }) => Err(AppError::Embedding(message)),
132 Some(other) => Err(AppError::Internal(anyhow::anyhow!(
133 "unexpected daemon response for batch passage embeddings: {other:?}"
134 ))),
135 None => {
136 let embedder = embedder::get_embedder(models_dir)?;
137 embedder::embed_passages_controlled(embedder, texts, token_counts)
138 }
139 }
140}
141
142pub fn run(models_dir: &Path, idle_shutdown_secs: u64) -> Result<(), AppError> {
143 let socket = daemon_label(models_dir);
144 let name = to_local_socket_name(&socket)?;
145 let listener = ListenerOptions::new()
146 .name(name)
147 .nonblocking(ListenerNonblockingMode::Accept)
148 .try_overwrite(true)
149 .create_sync()
150 .map_err(AppError::Io)?;
151
152 let _ = embedder::get_embedder(models_dir)?;
154
155 crate::output::emit_json(&DaemonResponse::Listening {
156 pid: std::process::id(),
157 socket,
158 idle_shutdown_secs,
159 })?;
160
161 let mut handled_embed_requests = 0_u64;
162 let mut last_activity = Instant::now();
163
164 loop {
165 if shutdown_requested() {
166 break;
167 }
168
169 match listener.accept() {
170 Ok(stream) => {
171 last_activity = Instant::now();
172 let should_exit = handle_client(stream, models_dir, &mut handled_embed_requests)?;
173 if should_exit {
174 break;
175 }
176 }
177 Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
178 if last_activity.elapsed() >= Duration::from_secs(idle_shutdown_secs) {
179 tracing::info!(
180 idle_shutdown_secs,
181 handled_embed_requests,
182 "daemon idle timeout reached"
183 );
184 break;
185 }
186 thread::sleep(Duration::from_millis(50));
187 }
188 Err(err) => return Err(AppError::Io(err)),
189 }
190 }
191
192 Ok(())
193}
194
195fn handle_client(
196 stream: LocalSocketStream,
197 models_dir: &Path,
198 handled_embed_requests: &mut u64,
199) -> Result<bool, AppError> {
200 let mut reader = BufReader::new(stream);
201 let mut line = String::new();
202 reader.read_line(&mut line).map_err(AppError::Io)?;
203
204 if line.trim().is_empty() {
205 write_response(
206 reader.get_mut(),
207 &DaemonResponse::Error {
208 message: "empty daemon request".to_string(),
209 },
210 )?;
211 return Ok(false);
212 }
213
214 let request: DaemonRequest = serde_json::from_str(line.trim()).map_err(AppError::Json)?;
215 let (response, should_exit) = match request {
216 DaemonRequest::Ping => (
217 DaemonResponse::Ok {
218 pid: std::process::id(),
219 version: SQLITE_GRAPHRAG_VERSION.to_string(),
220 handled_embed_requests: *handled_embed_requests,
221 },
222 false,
223 ),
224 DaemonRequest::Shutdown => (
225 DaemonResponse::ShuttingDown {
226 handled_embed_requests: *handled_embed_requests,
227 },
228 true,
229 ),
230 DaemonRequest::EmbedPassage { text } => {
231 let embedder = embedder::get_embedder(models_dir)?;
232 let embedding = embedder::embed_passage(embedder, &text)?;
233 *handled_embed_requests += 1;
234 (
235 DaemonResponse::PassageEmbedding {
236 embedding,
237 handled_embed_requests: *handled_embed_requests,
238 },
239 false,
240 )
241 }
242 DaemonRequest::EmbedQuery { text } => {
243 let embedder = embedder::get_embedder(models_dir)?;
244 let embedding = embedder::embed_query(embedder, &text)?;
245 *handled_embed_requests += 1;
246 (
247 DaemonResponse::QueryEmbedding {
248 embedding,
249 handled_embed_requests: *handled_embed_requests,
250 },
251 false,
252 )
253 }
254 DaemonRequest::EmbedPassages {
255 texts,
256 token_counts,
257 } => {
258 let embedder = embedder::get_embedder(models_dir)?;
259 let text_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
260 let embeddings =
261 embedder::embed_passages_controlled(embedder, &text_refs, &token_counts)?;
262 *handled_embed_requests += 1;
263 (
264 DaemonResponse::PassageEmbeddings {
265 embeddings,
266 handled_embed_requests: *handled_embed_requests,
267 },
268 false,
269 )
270 }
271 };
272
273 write_response(reader.get_mut(), &response)?;
274 Ok(should_exit)
275}
276
277fn write_response(
278 stream: &mut LocalSocketStream,
279 response: &DaemonResponse,
280) -> Result<(), AppError> {
281 serde_json::to_writer(&mut *stream, response).map_err(AppError::Json)?;
282 stream.write_all(b"\n").map_err(AppError::Io)?;
283 stream.flush().map_err(AppError::Io)?;
284 Ok(())
285}
286
287fn request_if_available(
288 models_dir: &Path,
289 request: &DaemonRequest,
290) -> Result<Option<DaemonResponse>, AppError> {
291 let socket = daemon_label(models_dir);
292 let name = match to_local_socket_name(&socket) {
293 Ok(name) => name,
294 Err(err) => return Err(AppError::Io(err)),
295 };
296
297 let mut stream = match LocalSocketStream::connect(name) {
298 Ok(stream) => stream,
299 Err(err)
300 if matches!(
301 err.kind(),
302 std::io::ErrorKind::NotFound
303 | std::io::ErrorKind::ConnectionRefused
304 | std::io::ErrorKind::AddrNotAvailable
305 | std::io::ErrorKind::TimedOut
306 ) =>
307 {
308 return Ok(None);
309 }
310 Err(err) => return Err(AppError::Io(err)),
311 };
312
313 serde_json::to_writer(&mut stream, request).map_err(AppError::Json)?;
314 stream.write_all(b"\n").map_err(AppError::Io)?;
315 stream.flush().map_err(AppError::Io)?;
316
317 let mut reader = BufReader::new(stream);
318 let mut line = String::new();
319 reader.read_line(&mut line).map_err(AppError::Io)?;
320 if line.trim().is_empty() {
321 return Err(AppError::Embedding("daemon returned empty response".into()));
322 }
323
324 let response = serde_json::from_str(line.trim()).map_err(AppError::Json)?;
325 Ok(Some(response))
326}
327
328fn to_local_socket_name(name: &str) -> std::io::Result<interprocess::local_socket::Name<'static>> {
329 if let Ok(ns_name) = name.to_string().to_ns_name::<GenericNamespaced>() {
330 return Ok(ns_name);
331 }
332
333 let path = if cfg!(unix) {
334 format!("/tmp/{name}.sock")
335 } else {
336 format!(r"\\.\pipe\{name}")
337 };
338 path.to_fs_name::<GenericFilePath>()
339}