1use anyhow::{Context, Result};
9use std::path::PathBuf;
10use std::process::{Command, Stdio};
11use std::time::Duration;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::UnixStream;
14
15use st_protocol::{Frame, Verb};
16
/// Returns the path of the daemon's Unix domain socket.
///
/// Uses `$XDG_RUNTIME_DIR/st.sock` when that variable holds an absolute path,
/// and `/tmp/st.sock` otherwise (unset, empty, relative, or non-UTF-8 value).
pub fn socket_path() -> PathBuf {
    // The XDG Base Directory spec says relative paths in XDG_* variables are
    // invalid and must be ignored. Without this filter an empty value would
    // yield `./st.sock`, making the socket location depend on the current
    // working directory of whichever process (client or daemon) computes it.
    std::env::var("XDG_RUNTIME_DIR")
        .ok()
        .map(PathBuf::from)
        .filter(|dir| dir.is_absolute())
        .unwrap_or_else(|| PathBuf::from("/tmp"))
        .join("st.sock")
}
24
25pub async fn is_daemon_running() -> bool {
27 let path = socket_path();
28 if !path.exists() {
29 return false;
30 }
31
32 match UnixStream::connect(&path).await {
34 Ok(mut stream) => {
35 let ping = Frame::ping();
36 if stream.write_all(&ping.encode()).await.is_err() {
37 return false;
38 }
39
40 let mut buf = [0u8; 256];
41 match tokio::time::timeout(Duration::from_millis(500), stream.read(&mut buf)).await {
42 Ok(Ok(n)) if n > 0 => {
43 true
45 }
46 _ => false,
47 }
48 }
49 Err(_) => false,
50 }
51}
52
53pub async fn start_daemon() -> Result<bool> {
55 if is_daemon_running().await {
56 return Ok(false); }
58
59 let exe_path = std::env::current_exe().ok();
61 let exe_dir = exe_path.as_ref().and_then(|p| p.parent());
62
63 let std_path = if let Some(dir) = exe_dir {
64 let candidate = dir.join("std");
65 if candidate.exists() {
66 candidate
67 } else {
68 PathBuf::from("std")
70 }
71 } else {
72 PathBuf::from("std")
73 };
74
75 #[cfg(unix)]
77 {
78 use std::os::unix::process::CommandExt;
79
80 let mut cmd = Command::new(&std_path);
82 cmd.arg("start")
83 .stdin(Stdio::null())
84 .stdout(Stdio::null())
85 .stderr(Stdio::null());
86
87 unsafe {
89 cmd.pre_exec(|| {
90 libc::setsid();
91 Ok(())
92 });
93 }
94
95 cmd.spawn().context("Failed to start std daemon")?;
96 }
97
98 #[cfg(windows)]
99 {
100 Command::new(&std_path)
101 .arg("start")
102 .creation_flags(0x00000008) .spawn()
104 .context("Failed to start std daemon")?;
105 }
106
107 for _ in 0..50 {
110 tokio::time::sleep(Duration::from_millis(100)).await;
111 if is_daemon_running().await {
112 return Ok(true);
113 }
114 }
115
116 Err(anyhow::anyhow!("Daemon started but not responding after 5 seconds"))
117}
118
119pub struct StdClient {
121 stream: Option<UnixStream>,
122}
123
124impl StdClient {
125 pub async fn connect() -> Option<Self> {
127 let path = socket_path();
128 match UnixStream::connect(&path).await {
129 Ok(stream) => Some(Self {
130 stream: Some(stream),
131 }),
132 Err(_) => None,
133 }
134 }
135
136 pub async fn connect_or_start() -> Result<Self> {
138 if let Some(client) = Self::connect().await {
139 return Ok(client);
140 }
141
142 start_daemon().await?;
144
145 Self::connect()
147 .await
148 .ok_or_else(|| anyhow::anyhow!("Failed to connect after starting daemon"))
149 }
150
151 pub async fn send(&mut self, frame: &Frame) -> Result<Vec<u8>> {
153 let stream = self
154 .stream
155 .as_mut()
156 .ok_or_else(|| anyhow::anyhow!("Not connected"))?;
157
158 stream
159 .write_all(&frame.encode())
160 .await
161 .context("Failed to send frame")?;
162
163 let mut buf = vec![0u8; 65536];
164 let n = stream.read(&mut buf).await.context("Failed to read response")?;
165 buf.truncate(n);
166 Ok(buf)
167 }
168
169 pub async fn ping(&mut self) -> Result<bool> {
171 let resp = self.send(&Frame::ping()).await?;
172 Ok(!resp.is_empty() && resp[0] == Verb::Ping as u8)
173 }
174
175 pub async fn scan(&mut self, path: &str, depth: u8) -> Result<String> {
177 let frame = Frame::scan(path, depth);
178 let resp = self.send(&frame).await?;
179
180 if resp.is_empty() {
182 return Ok(String::new());
183 }
184
185 if resp.len() > 2 {
187 let payload = &resp[1..resp.len() - 1];
188 String::from_utf8(payload.to_vec()).context("Invalid UTF-8 in scan response")
189 } else {
190 Ok(String::new())
191 }
192 }
193
194 pub async fn format(&mut self, path: &str, depth: u8, mode: &str) -> Result<String> {
196 let frame = Frame::format_path(mode, path, depth);
197 let resp = self.send(&frame).await?;
198
199 if resp.len() > 2 {
200 let payload = &resp[1..resp.len() - 1];
201 String::from_utf8(payload.to_vec()).context("Invalid UTF-8 in format response")
202 } else {
203 Ok(String::new())
204 }
205 }
206
207 pub async fn search(&mut self, path: &str, pattern: &str, max_results: u8) -> Result<String> {
209 let frame = Frame::search_path(path, pattern, max_results);
210 let resp = self.send(&frame).await?;
211
212 if resp.len() > 2 {
213 let payload = &resp[1..resp.len() - 1];
214 String::from_utf8(payload.to_vec()).context("Invalid UTF-8 in search response")
215 } else {
216 Ok(String::new())
217 }
218 }
219
220 pub async fn remember(
222 &mut self,
223 content: &str,
224 keywords: &str,
225 memory_type: &str,
226 ) -> Result<String> {
227 let frame = Frame::remember(content, keywords, memory_type);
228 let resp = self.send(&frame).await?;
229
230 if resp.len() > 2 {
231 let payload = &resp[1..resp.len() - 1];
232 String::from_utf8(payload.to_vec()).context("Invalid UTF-8 in remember response")
233 } else {
234 Ok(String::new())
235 }
236 }
237
238 pub async fn recall(&mut self, keywords: &str, max_results: u8) -> Result<String> {
240 let frame = Frame::recall(keywords, max_results);
241 let resp = self.send(&frame).await?;
242
243 if resp.len() > 2 {
244 let payload = &resp[1..resp.len() - 1];
245 String::from_utf8(payload.to_vec()).context("Invalid UTF-8 in recall response")
246 } else {
247 Ok(String::new())
248 }
249 }
250
251 pub async fn stats(&mut self) -> Result<serde_json::Value> {
253 let frame = Frame::stats();
254 let resp = self.send(&frame).await?;
255
256 if resp.len() > 2 {
257 let payload = &resp[1..resp.len() - 1];
258 let json_str = String::from_utf8(payload.to_vec())
259 .context("Invalid UTF-8 in stats response")?;
260 serde_json::from_str(&json_str).context("Invalid JSON in stats response")
261 } else {
262 Ok(serde_json::json!({}))
263 }
264 }
265
266 pub async fn m8_wave(&mut self) -> Result<String> {
268 let frame = Frame::m8_wave();
269 let resp = self.send(&frame).await?;
270
271 if resp.len() > 2 {
272 let payload = &resp[1..resp.len() - 1];
273 String::from_utf8(payload.to_vec()).context("Invalid UTF-8 in wave response")
274 } else {
275 Ok(String::new())
276 }
277 }
278
279 pub async fn audio(&mut self, acoustic_bytes: &[u8]) -> Result<String> {
283 let frame = Frame::audio(acoustic_bytes);
284 let resp = self.send(&frame).await?;
285
286 if resp.len() > 2 {
287 let payload = &resp[1..resp.len() - 1];
288 String::from_utf8(payload.to_vec()).context("Invalid UTF-8 in audio response")
289 } else {
290 Ok(String::new())
291 }
292 }
293
294 pub async fn audio_simple(&mut self, text: &str, valence: f32, arousal: f32) -> Result<String> {
296 let frame = Frame::audio_simple(text, valence, arousal);
297 let resp = self.send(&frame).await?;
298
299 if resp.len() > 2 {
300 let payload = &resp[1..resp.len() - 1];
301 String::from_utf8(payload.to_vec()).context("Invalid UTF-8 in audio response")
302 } else {
303 Ok(String::new())
304 }
305 }
306}
307
308pub async fn ensure_daemon(quiet: bool) -> Result<()> {
310 if is_daemon_running().await {
311 return Ok(());
312 }
313
314 if !quiet {
315 eprintln!("🌳 Starting Smart Tree daemon...");
316 }
317
318 start_daemon().await?;
319
320 if !quiet {
321 eprintln!("✓ Daemon ready");
322 }
323
324 Ok(())
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// `socket_path` is synchronous, so a plain `#[test]` suffices and no async
    /// runtime is spun up. The exact file-name check is stricter than a
    /// substring match, which would also accept e.g. `st.sock.bak`.
    #[test]
    fn test_socket_path() {
        let path = socket_path();
        assert_eq!(
            path.file_name().and_then(|name| name.to_str()),
            Some("st.sock")
        );
    }

    /// Smoke test: the probe must complete without panicking whether or not a
    /// daemon is running on this machine. The result itself is environment-dependent.
    #[tokio::test]
    async fn test_daemon_check() {
        let _ = is_daemon_running().await;
    }
}