1use std::path::PathBuf;
2use std::process::Stdio;
3use std::time::Duration;
4
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::UnixStream;
8use tokio::process::Command;
9
10use crate::config::ScalePolicy;
11use crate::error::{SdkError, SdkResult};
12
/// Error text in a pilot status response that `PilotClient::classify_status`
/// treats as "not ready yet, poll again" instead of a hard failure.
const PILOT_CONNECTING_ERROR: &str = "runtime session not attached";
/// Fallback returned by `pilot_ensure_timeout` when `RUNE_PILOT_ENSURE_TIMEOUT_SECS`
/// is unset or does not parse as a whole number of seconds.
const DEFAULT_PILOT_ENSURE_TIMEOUT: Duration = Duration::from_secs(5);
/// Fallback returned by `pilot_request_timeout` when `RUNE_PILOT_REQUEST_TIMEOUT_SECS`
/// is unset or does not parse as a whole number of seconds.
const DEFAULT_PILOT_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
16
17fn pilot_ensure_timeout() -> Duration {
18 std::env::var("RUNE_PILOT_ENSURE_TIMEOUT_SECS")
19 .ok()
20 .and_then(|v| v.parse::<u64>().ok())
21 .map(Duration::from_secs)
22 .unwrap_or(DEFAULT_PILOT_ENSURE_TIMEOUT)
23}
24
25fn pilot_request_timeout() -> Duration {
26 std::env::var("RUNE_PILOT_REQUEST_TIMEOUT_SECS")
27 .ok()
28 .and_then(|v| v.parse::<u64>().ok())
29 .map(Duration::from_secs)
30 .unwrap_or(DEFAULT_PILOT_REQUEST_TIMEOUT)
31}
32
/// Commands this client sends to the pilot daemon.
///
/// Each value is serialized as a JSON object whose `command` field holds the
/// snake_case variant name (for example `"status"`), with the variant's own
/// fields alongside it.
#[derive(Debug, Serialize)]
#[serde(tag = "command", rename_all = "snake_case")]
enum PilotRequest {
    /// Built by `PilotClient::register`.
    Register {
        caster_id: String,
        /// Process id of the registering process (`std::process::id()` at the call site).
        pid: u32,
        /// Copied from `ScalePolicy::group`.
        group: String,
        /// Copied from `ScalePolicy::spawn_command`.
        spawn_command: String,
        /// Copied from `ScalePolicy::shutdown_signal`.
        shutdown_signal: String,
    },
    /// Built by `PilotClient::deregister`.
    Deregister {
        caster_id: String,
    },
    /// Readiness probe used by `PilotClient::ensure` and `PilotClient::wait_until_ready`.
    Status,
    /// Sent by `PilotClient::ensure` when the answering pilot reports a different runtime.
    Stop,
}
49
/// Reply decoded from the pilot daemon's JSON response.
///
/// `ok`, `pilot_id` and `runtime` fall back to their `Default` values (`false`
/// and the empty string) when the reply omits them, so an error-only reply
/// still deserializes.
#[derive(Debug, Deserialize)]
struct PilotResponse {
    /// Whether the pilot reported success.
    #[serde(default)]
    ok: bool,
    /// Identifier the pilot reports for itself; stored in `PilotClient` once ready.
    #[serde(default)]
    pilot_id: String,
    /// Runtime the pilot reports; compared against the normalized requested runtime.
    #[serde(default)]
    runtime: String,
    /// Error text, if the reply carried one.
    error: Option<String>,
}
60
/// Handle to a pilot daemon that has reported itself ready.
///
/// Obtained through `PilotClient::ensure`; it holds only the pilot's id.
#[derive(Debug, Clone)]
pub struct PilotClient {
    /// Id taken from the status response that reported the pilot as ready.
    pilot_id: String,
}
65
66impl PilotClient {
67 pub async fn ensure(runtime: &str, key: Option<&str>) -> SdkResult<Self> {
68 let normalized = normalize_runtime(runtime);
69 let deadline = tokio::time::Instant::now() + pilot_ensure_timeout();
70 if let Ok(response) = send_request(&PilotRequest::Status).await {
71 match Self::classify_status(response, &normalized) {
72 EnsureStatus::Ready(client) => return Ok(client),
73 EnsureStatus::Retry => {
74 return Self::wait_until_ready(&normalized, deadline, Some(runtime), key).await;
75 }
76 EnsureStatus::Mismatch => {
77 let _ = send_request(&PilotRequest::Stop).await;
79 }
80 EnsureStatus::Failed(error) => return Err(SdkError::Other(error)),
81 }
82 }
83
84 start_pilot(runtime, key).await?;
85 Self::wait_until_ready(&normalized, deadline, Some(runtime), key).await
86 }
87
88 async fn wait_until_ready(
92 normalized: &str,
93 deadline: tokio::time::Instant,
94 start_runtime: Option<&str>,
95 start_key: Option<&str>,
96 ) -> SdkResult<Self> {
97 let mut last_start = tokio::time::Instant::now();
98 loop {
99 match send_request(&PilotRequest::Status).await {
100 Ok(response) => match Self::classify_status(response, normalized) {
101 EnsureStatus::Ready(client) => return Ok(client),
102 EnsureStatus::Retry | EnsureStatus::Mismatch => {}
103 EnsureStatus::Failed(error) => return Err(SdkError::Other(error)),
104 },
105 Err(_) => {
106 if let Some(rt) = start_runtime {
108 if last_start.elapsed() >= Duration::from_secs(1) {
109 let _ = start_pilot(rt, start_key).await;
110 last_start = tokio::time::Instant::now();
111 }
112 }
113 }
114 }
115 if tokio::time::Instant::now() >= deadline {
116 break;
117 }
118 tokio::time::sleep(Duration::from_millis(100)).await;
119 }
120
121 Err(SdkError::Other("pilot did not become ready".into()))
122 }
123
124 pub fn pilot_id(&self) -> &str {
125 &self.pilot_id
126 }
127
128 pub async fn register(&self, caster_id: &str, policy: &ScalePolicy) -> SdkResult<()> {
129 let response = send_request(&PilotRequest::Register {
130 caster_id: caster_id.to_string(),
131 pid: std::process::id(),
132 group: policy.group.clone(),
133 spawn_command: policy.spawn_command.clone(),
134 shutdown_signal: policy.shutdown_signal.clone(),
135 })
136 .await?;
137 Self::ensure_ok(response)
138 }
139
140 pub async fn deregister(&self, caster_id: &str) -> SdkResult<()> {
141 let response = send_request(&PilotRequest::Deregister {
142 caster_id: caster_id.to_string(),
143 })
144 .await?;
145 Self::ensure_ok(response)
146 }
147
148 fn ensure_ok(response: PilotResponse) -> SdkResult<()> {
149 if response.ok {
150 Ok(())
151 } else {
152 Err(SdkError::Other(
153 response
154 .error
155 .unwrap_or_else(|| "pilot request failed".into()),
156 ))
157 }
158 }
159
160 fn classify_status(response: PilotResponse, normalized: &str) -> EnsureStatus {
161 if response.runtime != normalized {
162 return EnsureStatus::Mismatch;
163 }
164 if response.ok {
165 return EnsureStatus::Ready(Self {
166 pilot_id: response.pilot_id,
167 });
168 }
169 match response.error {
170 Some(error) if error == PILOT_CONNECTING_ERROR || error.is_empty() => {
171 EnsureStatus::Retry
172 }
173 Some(error) => EnsureStatus::Failed(error),
174 None => EnsureStatus::Retry,
175 }
176 }
177}
178
/// Outcome of classifying a pilot `Status` reply (see `PilotClient::classify_status`).
enum EnsureStatus {
    /// The pilot serves the requested runtime and reported `ok`.
    Ready(PilotClient),
    /// The pilot serves the requested runtime but is not ready yet; poll again.
    Retry,
    /// The pilot reported a runtime different from the requested one.
    Mismatch,
    /// The pilot reported an error that is not retryable; carries the error text.
    Failed(String),
}
185
186async fn send_request(request: &PilotRequest) -> SdkResult<PilotResponse> {
187 tokio::time::timeout(pilot_request_timeout(), send_request_inner(request))
188 .await
189 .map_err(|_| SdkError::Other("pilot request timed out".into()))?
190}
191
192async fn send_request_inner(request: &PilotRequest) -> SdkResult<PilotResponse> {
193 let socket_path = socket_path()?;
194 let mut stream = UnixStream::connect(&socket_path)
195 .await
196 .map_err(|err| SdkError::Other(format!("failed to connect to pilot: {err}")))?;
197 let payload = serde_json::to_vec(request)
198 .map_err(|err| SdkError::Other(format!("failed to encode pilot request: {err}")))?;
199 stream
200 .write_all(&payload)
201 .await
202 .map_err(|err| SdkError::Other(format!("failed to write pilot request: {err}")))?;
203 stream
204 .shutdown()
205 .await
206 .map_err(|err| SdkError::Other(format!("failed to flush pilot request: {err}")))?;
207 const MAX_RESPONSE_SIZE: u64 = 256 * 1024; let mut response = Vec::new();
209 stream
210 .take(MAX_RESPONSE_SIZE)
211 .read_to_end(&mut response)
212 .await
213 .map_err(|err| SdkError::Other(format!("failed to read pilot response: {err}")))?;
214 serde_json::from_slice(&response)
215 .map_err(|err| SdkError::Other(format!("failed to decode pilot response: {err}")))
216}
217
218async fn start_pilot(runtime: &str, key: Option<&str>) -> SdkResult<()> {
219 let mut command = Command::new(find_rune_binary()?);
220 command
221 .arg("pilot")
222 .arg("daemon")
223 .arg("--runtime")
224 .arg(normalize_runtime(runtime))
225 .stdin(Stdio::null())
226 .stdout(Stdio::null())
227 .stderr(Stdio::null());
228 if let Some(key) = key {
229 command.env("RUNE_KEY", key);
230 }
231 #[cfg(unix)]
232 unsafe {
237 command.pre_exec(|| {
238 libc::setsid();
239 Ok(())
240 });
241 }
242 command
243 .spawn()
244 .map_err(|err| SdkError::Other(format!("failed to start pilot daemon: {err}")))?;
245 Ok(())
246}
247
/// Canonicalizes a runtime address for comparison: surrounding whitespace is
/// removed first, then every trailing `/`.
fn normalize_runtime(runtime: &str) -> String {
    let without_padding = runtime.trim();
    let without_slashes = without_padding.trim_end_matches('/');
    String::from(without_slashes)
}
251
252fn socket_path() -> SdkResult<PathBuf> {
253 Ok(home_dir()?.join(".rune").join("pilot.sock"))
254}
255
256fn find_rune_binary() -> SdkResult<PathBuf> {
257 if let Ok(path) = std::env::var("RUNE_BIN") {
258 return Ok(PathBuf::from(path));
259 }
260
261 if let Some(paths) = std::env::var_os("PATH") {
262 for dir in std::env::split_paths(&paths) {
263 let candidate = dir.join("rune");
264 if candidate.is_file() {
265 return Ok(candidate);
266 }
267 }
268 }
269
270 Err(SdkError::Other(
271 "failed to locate rune binary; set RUNE_BIN or add rune to PATH".into(),
272 ))
273}
274
275fn home_dir() -> SdkResult<PathBuf> {
276 std::env::var_os("HOME")
277 .map(PathBuf::from)
278 .ok_or_else(|| SdkError::Other("failed to determine HOME".into()))
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::ffi::OsString;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::{Mutex, OnceLock};
    use std::time::{SystemTime, UNIX_EPOCH};
    use tokio::net::UnixListener;

    /// Serializes tests that change the process-wide `HOME` variable.
    static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    /// Points `HOME` at a fresh scratch directory for the lifetime of the guard.
    ///
    /// Holding the guard also holds `ENV_LOCK`. Dropping it restores the previous
    /// `HOME` and deletes the scratch directory.
    struct HomeGuard {
        /// Value of `HOME` before the guard was created, if any.
        previous: Option<OsString>,
        /// Scratch directory that `HOME` points at while the guard lives.
        root: PathBuf,
        /// Declared last so the lock is released only after the other fields drop.
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl HomeGuard {
        /// Takes the environment lock, creates `<root>/.rune`, and redirects `HOME`.
        fn set() -> Self {
            let mutex = ENV_LOCK.get_or_init(|| Mutex::new(()));
            // A poisoned lock only means another test panicked; the guard is still usable.
            let held = match mutex.lock() {
                Ok(guard) => guard,
                Err(poisoned) => poisoned.into_inner(),
            };
            let since_epoch = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default();
            let unique = since_epoch.as_millis();
            // NOTE(review): presumably kept under /tmp with a short name because
            // Unix socket paths have a small length limit. Confirm.
            let root = PathBuf::from(format!("/tmp/rpc-{}-{unique}", std::process::id()));
            let rune_dir = root.join(".rune");
            fs::create_dir_all(rune_dir).unwrap();
            let previous = std::env::var_os("HOME");
            std::env::set_var("HOME", &root);
            Self {
                previous,
                root,
                _lock: held,
            }
        }
    }

    impl Drop for HomeGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(original) => std::env::set_var("HOME", original),
                None => std::env::remove_var("HOME"),
            }
            // Cleanup is best effort; a leftover directory must not fail the test.
            let _ = fs::remove_dir_all(&self.root);
        }
    }

    /// Binds a mock pilot on `socket` that answers one connection per entry in
    /// `responses`, in order, then exits.
    ///
    /// Every request must be a `status` command. Each connection is closed after
    /// its reply is written, so the client sees end-of-stream.
    async fn spawn_status_server(
        socket: &Path,
        responses: Vec<serde_json::Value>,
    ) -> tokio::task::JoinHandle<()> {
        let _ = fs::remove_file(socket);
        let listener = UnixListener::bind(socket).unwrap();
        tokio::spawn(async move {
            for reply in responses {
                let (mut conn, _addr) = listener.accept().await.unwrap();
                let mut raw_request = Vec::new();
                conn.read_to_end(&mut raw_request).await.unwrap();
                let parsed: serde_json::Value = serde_json::from_slice(&raw_request).unwrap();
                assert_eq!(parsed["command"], "status");
                let encoded = serde_json::to_vec(&reply).unwrap();
                conn.write_all(&encoded).await.unwrap();
            }
        })
    }

    /// `ensure` keeps polling a pilot that reports the "still connecting" error and
    /// succeeds once the pilot reports ready.
    #[tokio::test]
    async fn test_fix_ensure_waits_for_matching_runtime_to_become_ready() {
        let _home = HomeGuard::set();
        let socket = socket_path().unwrap();
        let still_connecting = json!({
            "ok": false,
            "pilot_id": "pilot-1",
            "runtime": "127.0.0.1:50051",
            "error": "runtime session not attached"
        });
        let ready = json!({
            "ok": true,
            "pilot_id": "pilot-1",
            "runtime": "127.0.0.1:50051",
            "error": null
        });
        let server = spawn_status_server(&socket, vec![still_connecting, ready]).await;

        let client = PilotClient::ensure("127.0.0.1:50051", None)
            .await
            .expect("ensure should keep polling while pilot is still connecting");
        assert_eq!(client.pilot_id(), "pilot-1");

        server.await.unwrap();
    }

    /// Replies that omit `pilot_id` (and even `ok` / `runtime`) still deserialize,
    /// with the missing fields taking their defaults.
    #[test]
    fn test_fix_pilot_response_deserialize_without_pilot_id() {
        let without_pilot_id = r#"{"ok": false, "error": "connection refused"}"#;
        let first: PilotResponse = serde_json::from_str(without_pilot_id)
            .expect("should deserialize error response missing pilot_id");
        assert!(!first.ok);
        assert_eq!(first.pilot_id, "");
        assert_eq!(first.error.as_deref(), Some("connection refused"));

        let error_only = r#"{"error": "socket not found"}"#;
        let second: PilotResponse = serde_json::from_str(error_only)
            .expect("should deserialize minimal error-only response");
        assert!(!second.ok);
        assert_eq!(second.pilot_id, "");
        assert_eq!(second.runtime, "");
        assert_eq!(second.error.as_deref(), Some("socket not found"));
    }

    /// An empty error string on a matching, not-ok reply means "retry", not "failed".
    #[test]
    fn test_fix_classify_status_empty_error_retries() {
        let reply = PilotResponse {
            ok: false,
            pilot_id: "pilot-1".into(),
            runtime: "127.0.0.1:50051".into(),
            error: Some("".into()),
        };
        let verdict = PilotClient::classify_status(reply, "127.0.0.1:50051");
        assert!(
            matches!(verdict, EnsureStatus::Retry),
            "empty error string should classify as Retry, not Failed"
        );
    }
}